From ed9808c218dc9270b9673c606c2d0dee1900277d Mon Sep 17 00:00:00 2001
From: Denis Smirnov <sd@picodata.io>
Date: Mon, 7 Oct 2024 18:39:11 +0700
Subject: [PATCH] fix: bucket calculation for duplicated columns

The queries `select * from t where sk = 1 and sk = 2` discovered
the bucket for the constant 1, rather then an empty set. The reason
was that the tuple merge transformed `sk = 1 and sk = 2` to
`(sk, sk) = (1, 2)`, while the distribution took into account only
the first position (constant 1).

To compute all keys we now take a cartesian product between all
groups of columns of a tuple, where each group consists of columns
corresponding to single column of sharding key.

Suppose tuple is (a, b, a). (a, b) refer to sharding columns, then
we have two groups:
a -> {0, 2}
b -> {1}

And the distribution keys are:
{0, 2} x {1} = {(0, 1), (2, 1)}

Co-authored-by: Arseniy Volynets <a.volynets@picodata.io>
---
 sbroad-core/src/executor/bucket/tests.rs      | 26 +++++++++++
 .../src/frontend/sql/ir/tests/insert.rs       |  2 +-
 sbroad-core/src/ir/distribution.rs            | 44 ++++++++++++-------
 3 files changed, 56 insertions(+), 16 deletions(-)

diff --git a/sbroad-core/src/executor/bucket/tests.rs b/sbroad-core/src/executor/bucket/tests.rs
index 3c935089b..bcc6fdf83 100644
--- a/sbroad-core/src/executor/bucket/tests.rs
+++ b/sbroad-core/src/executor/bucket/tests.rs
@@ -486,3 +486,29 @@ fn delete_local() {
 
     assert_eq!(Buckets::Filtered(collection!(6691)), buckets);
 }
+
+#[test]
+fn same_multicolumn_sk_in_eq() {
+    let query = r#"select * from t where a = 1 and b = 1 and b = 2 and a = 2"#;
+
+    let coordinator = RouterRuntimeMock::new();
+    let mut query = Query::new(&coordinator, query, vec![]).unwrap();
+    let plan = query.exec_plan.get_ir_plan();
+    let top = plan.get_top().unwrap();
+    let buckets = query.bucket_discovery(top).unwrap();
+
+    assert_eq!(Buckets::Filtered(collection!()), buckets);
+}
+
+#[test]
+fn same_column_in_eq() {
+    let query = r#"select * from test_space where id = 1 and id = 2"#;
+
+    let coordinator = RouterRuntimeMock::new();
+    let mut query = Query::new(&coordinator, query, vec![]).unwrap();
+    let plan = query.exec_plan.get_ir_plan();
+    let top = plan.get_top().unwrap();
+    let buckets = query.bucket_discovery(top).unwrap();
+
+    assert_eq!(Buckets::Filtered(collection!()), buckets);
+}
diff --git a/sbroad-core/src/frontend/sql/ir/tests/insert.rs b/sbroad-core/src/frontend/sql/ir/tests/insert.rs
index 3b82d03b5..befdc6efa 100644
--- a/sbroad-core/src/frontend/sql/ir/tests/insert.rs
+++ b/sbroad-core/src/frontend/sql/ir/tests/insert.rs
@@ -48,7 +48,7 @@ fn insert3() {
 
     let expected_explain = String::from(
         r#"insert "test_space" on conflict: fail
-    motion [policy: segment([ref("id")])]
+    motion [policy: local segment([ref("id")])]
         projection ("test_space"."id"::unsigned -> "id", "test_space"."id"::unsigned -> "id")
             scan "test_space"
 execution options:
diff --git a/sbroad-core/src/ir/distribution.rs b/sbroad-core/src/ir/distribution.rs
index 1e9aec139..e476f3bfd 100644
--- a/sbroad-core/src/ir/distribution.rs
+++ b/sbroad-core/src/ir/distribution.rs
@@ -1,6 +1,7 @@
 //! Tuple distribution module.
 
 use ahash::{AHashMap, RandomState};
+use itertools::Itertools;
 use smol_str::{format_smolstr, ToSmolStr};
 use std::collections::{HashMap, HashSet};
 
@@ -329,7 +330,7 @@ type ParentColumnPosition = usize;
 /// Set of the relational nodes referred by references under the row.
 struct ReferenceInfo {
     referred_children: ReferredNodes,
-    child_column_to_parent_col: AHashMap<ChildColumnReference, ParentColumnPosition>,
+    child_column_to_parent_col: AHashMap<ChildColumnReference, Vec<ParentColumnPosition>>,
 }
 
 impl ReferenceInfo {
@@ -339,7 +340,8 @@ impl ReferenceInfo {
         parent_children: &Children<'_>,
     ) -> Result<Self, SbroadError> {
         let mut ref_nodes = ReferredNodes::new();
-        let mut ref_map: AHashMap<ChildColumnReference, ParentColumnPosition> = AHashMap::new();
+        let mut ref_map: AHashMap<ChildColumnReference, Vec<ParentColumnPosition>> =
+            AHashMap::new();
         for (parent_column_pos, id) in ir.get_row_list(row_id)?.iter().enumerate() {
             let child_id = ir.get_child_under_alias(*id)?;
             if let Expression::Reference(Reference {
@@ -362,7 +364,10 @@ impl ReferenceInfo {
                         )
                     })?;
                     ref_nodes.append(*referred_id);
-                    ref_map.insert((*referred_id, *position).into(), parent_column_pos);
+                    ref_map
+                        .entry((*referred_id, *position).into())
+                        .or_default()
+                        .push(parent_column_pos);
                 }
             }
         }
@@ -615,7 +620,7 @@ impl Plan {
     fn dist_from_child(
         &self,
         child_rel_node: NodeId,
-        child_pos_map: &AHashMap<ChildColumnReference, ParentColumnPosition>,
+        child_pos_map: &AHashMap<ChildColumnReference, Vec<ParentColumnPosition>>,
     ) -> Result<Distribution, SbroadError> {
         let rel_node = self.get_relation_node(child_rel_node)?;
         let output_expr = self.get_expression_node(rel_node.output())?;
@@ -633,18 +638,27 @@ impl Plan {
                     let mut new_keys: HashSet<Key, RepeatableState> =
                         HashSet::with_hasher(RepeatableState);
                     for key in keys.iter() {
-                        let mut new_key: Key = Key::new(Vec::with_capacity(key.positions.len()));
-                        let all_found = key.positions.iter().all(|pos| {
-                            child_pos_map
-                                .get(&(child_rel_node, *pos).into())
-                                .map_or(false, |v| {
-                                    new_key.positions.push(*v);
-                                    true
-                                })
-                        });
+                        let all_found = key
+                            .positions
+                            .iter()
+                            .all(|pos| child_pos_map.contains_key(&(child_rel_node, *pos).into()));
 
                         if all_found {
-                            new_keys.insert(new_key);
+                            let product = key
+                                .positions
+                                .iter()
+                                .map(|pos| {
+                                    child_pos_map
+                                        .get(&(child_rel_node, *pos).into())
+                                        .unwrap()
+                                        .iter()
+                                        .copied()
+                                })
+                                .multi_cartesian_product();
+
+                            for positions in product {
+                                new_keys.insert(Key::new(positions));
+                            }
                         }
                     }
 
@@ -689,7 +703,7 @@ impl Plan {
 
     fn set_two_children_node_dist(
         &mut self,
-        child_pos_map: &AHashMap<ChildColumnReference, ParentColumnPosition>,
+        child_pos_map: &AHashMap<ChildColumnReference, Vec<ParentColumnPosition>>,
         left_id: NodeId,
         right_id: NodeId,
         parent_id: NodeId,
-- 
GitLab