From 43e42c31ab4ed0f489c28c72ad69a8a2ddd4709e Mon Sep 17 00:00:00 2001
From: Arseniy Volynets <vol0ncar@yandex.ru>
Date: Wed, 13 Mar 2024 01:31:37 +0300
Subject: [PATCH] fix: except, groupby, join, sq on bucket_id

- Detect when except, groupby, join or eqality
with subqueries is done on bucket_id column
to avoid adding extra motions
---
 sbroad-core/src/executor/engine/helpers.rs    |   5 -
 sbroad-core/src/frontend/sql/ir/tests.rs      | 174 ++++++++++++++++
 sbroad-core/src/ir.rs                         |  93 +++++++++
 sbroad-core/src/ir/expression.rs              |  22 +-
 sbroad-core/src/ir/operator.rs                |  25 ++-
 sbroad-core/src/ir/relation.rs                |   2 -
 .../src/ir/transformation/redistribution.rs   | 195 ++++++++++++++----
 .../transformation/redistribution/groupby.rs  |  19 ++
 8 files changed, 477 insertions(+), 58 deletions(-)

diff --git a/sbroad-core/src/executor/engine/helpers.rs b/sbroad-core/src/executor/engine/helpers.rs
index 967cadf1d4..27654c65d0 100644
--- a/sbroad-core/src/executor/engine/helpers.rs
+++ b/sbroad-core/src/executor/engine/helpers.rs
@@ -81,11 +81,6 @@ pub fn normalize_name_for_space_api(s: &str) -> String {
     s.to_uppercase()
 }
 
-#[must_use]
-pub fn is_sharding_column_name(name: &str) -> bool {
-    name == "\"bucket_id\"" || name == "bucket_id"
-}
-
 /// A helper function to encode the execution plan into a pair of binary data (see `Message`):
 /// * required data (plan id, parameters, etc.)
 /// * optional data (execution plan, etc.)
diff --git a/sbroad-core/src/frontend/sql/ir/tests.rs b/sbroad-core/src/frontend/sql/ir/tests.rs
index 204f4e8dd3..0a5f4c4f33 100644
--- a/sbroad-core/src/frontend/sql/ir/tests.rs
+++ b/sbroad-core/src/frontend/sql/ir/tests.rs
@@ -2,7 +2,9 @@ use crate::errors::SbroadError;
 use crate::executor::engine::mock::RouterConfigurationMock;
 use crate::frontend::sql::ast::AbstractSyntaxTree;
 use crate::frontend::Ast;
+use crate::ir::operator::Relational;
 use crate::ir::transformation::helpers::sql_to_optimized_ir;
+use crate::ir::tree::traversal::PostOrder;
 use crate::ir::value::Value;
 use pretty_assertions::assert_eq;
 
@@ -517,6 +519,101 @@ vtable_max_rows = 5000
     assert_eq!(expected_explain, plan.as_explain().unwrap());
 }
 
+#[test]
+fn track_shard_col_pos() {
+    let input = r#"
+    select "e", "bucket_id", "f" 
+    from "t2"
+    where "e" + "f" = 3
+    "#;
+    let plan = sql_to_optimized_ir(input, vec![]);
+    let top = plan.get_top().unwrap();
+    let map = plan.track_shard_column_pos(top).unwrap();
+    let mut dfs = PostOrder::with_capacity(|x| plan.nodes.rel_iter(x), 10);
+    for (_, node_id) in dfs.iter(top) {
+        let node = plan.get_relation_node(node_id).unwrap();
+        match node {
+            Relational::ScanRelation { .. } | Relational::Selection { .. } => {
+                assert_eq!(&vec![4_usize], map.get(&node_id).unwrap())
+            }
+            Relational::Projection { .. } => {
+                assert_eq!(&vec![1_usize], map.get(&node_id).unwrap())
+            }
+            _ => {}
+        }
+    }
+
+    let input = r#"select t_mv."bucket_id", "t2"."bucket_id" from "t2" join (
+        select "bucket_id" from "test_space" where "id" = 1
+    ) as t_mv
+    on t_mv."bucket_id" = "t2"."bucket_id";
+    "#;
+    let plan = sql_to_optimized_ir(input, vec![]);
+    let top = plan.get_top().unwrap();
+    let map = plan.track_shard_column_pos(top).unwrap();
+    let mut dfs = PostOrder::with_capacity(|x| plan.nodes.rel_iter(x), 10);
+    for (_, node_id) in dfs.iter(top) {
+        let node = plan.get_relation_node(node_id).unwrap();
+        if let Relational::Join { .. } = node {
+            assert_eq!(&vec![4_usize, 5_usize], map.get(&node_id).unwrap());
+        }
+    }
+    assert_eq!(&vec![0_usize, 1_usize], map.get(&top).unwrap());
+
+    let input = r#"select t_mv."bucket_id", "t2"."bucket_id" from "t2" join (
+        select "bucket_id" from "test_space" where "id" = 1
+    ) as t_mv
+    on t_mv."bucket_id" < "t2"."bucket_id";
+    "#;
+    let plan = sql_to_optimized_ir(input, vec![]);
+    let top = plan.get_top().unwrap();
+    let map = plan.track_shard_column_pos(top).unwrap();
+    let mut dfs = PostOrder::with_capacity(|x| plan.nodes.rel_iter(x), 10);
+    for (_, node_id) in dfs.iter(top) {
+        let node = plan.get_relation_node(node_id).unwrap();
+        if let Relational::Join { .. } = node {
+            assert_eq!(&vec![4_usize], map.get(&node_id).unwrap());
+        }
+    }
+    assert_eq!(&vec![1_usize], map.get(&top).unwrap());
+
+    let input = r#"
+    select "bucket_id", "e" from "t2"
+    union all
+    select "id", "bucket_id" from "test_space"
+    "#;
+    let plan = sql_to_optimized_ir(input, vec![]);
+    let top = plan.get_top().unwrap();
+    let map = plan.track_shard_column_pos(top).unwrap();
+    assert_eq!(None, map.get(&top));
+
+    let input = r#"
+    select "bucket_id", "e" from "t2"
+    union all
+    select "bucket_id", "id" from "test_space"
+    "#;
+    let plan = sql_to_optimized_ir(input, vec![]);
+    let top = plan.get_top().unwrap();
+    let map = plan.track_shard_column_pos(top).unwrap();
+    assert_eq!(&vec![0_usize], map.get(&top).unwrap());
+
+    let input = r#"
+    select "e" from (select "bucket_id" as "e" from "t2")
+    "#;
+    let plan = sql_to_optimized_ir(input, vec![]);
+    let top = plan.get_top().unwrap();
+    let map = plan.track_shard_column_pos(top).unwrap();
+    assert_eq!(&vec![0_usize], map.get(&top).unwrap());
+
+    let input = r#"
+    select "e" as "bucket_id" from "t2"
+    "#;
+    let plan = sql_to_optimized_ir(input, vec![]);
+    let top = plan.get_top().unwrap();
+    let map = plan.track_shard_column_pos(top).unwrap();
+    assert_eq!(None, map.get(&top));
+}
+
 #[test]
 fn front_sql_join_on_bucket_id1() {
     let input = r#"select * from "t2" join (
@@ -576,6 +673,83 @@ vtable_max_rows = 5000
     assert_eq!(expected_explain, plan.as_explain().unwrap());
 }
 
+#[test]
+fn front_sql_groupby_on_bucket_id() {
+    let input = r#"
+    select b, count(*) from (select "bucket_id" as b from "t2") as t 
+    group by b
+    "#;
+
+    let plan = sql_to_optimized_ir(input, vec![]);
+
+    let expected_explain = String::from(
+        r#"projection ("T"."B"::unsigned -> "B", count((*::integer))::integer -> "COL_1")
+    group by ("T"."B"::unsigned) output: ("T"."B"::unsigned -> "B")
+        scan "T"
+            projection ("t2"."bucket_id"::unsigned -> "B")
+                scan "t2"
+execution options:
+sql_vdbe_max_steps = 45000
+vtable_max_rows = 5000
+"#,
+    );
+
+    assert_eq!(expected_explain, plan.as_explain().unwrap());
+}
+
+#[test]
+fn front_sql_sq_on_bucket_id() {
+    let input = r#"
+    select b, e from (select "bucket_id" as b, "e" as e from "t2") as t 
+    where (b, e) in (select "bucket_id", "id" from "test_space")
+    "#;
+
+    let plan = sql_to_optimized_ir(input, vec![]);
+
+    let expected_explain = String::from(
+        r#"projection ("T"."B"::unsigned -> "B", "T"."E"::unsigned -> "E")
+    selection ROW("T"."B"::unsigned, "T"."E"::unsigned) in ROW($0, $0)
+        scan "T"
+            projection ("t2"."bucket_id"::unsigned -> "B", "t2"."e"::unsigned -> "E")
+                scan "t2"
+subquery $0:
+scan
+            projection ("test_space"."bucket_id"::unsigned -> "bucket_id", "test_space"."id"::unsigned -> "id")
+                scan "test_space"
+execution options:
+sql_vdbe_max_steps = 45000
+vtable_max_rows = 5000
+"#,
+    );
+
+    assert_eq!(expected_explain, plan.as_explain().unwrap());
+}
+
+#[test]
+fn front_sql_except_on_bucket_id() {
+    let input = r#"
+    select "e", "bucket_id" from "t2"
+    except
+    select "id", "bucket_id" from "test_space"
+    "#;
+
+    let plan = sql_to_optimized_ir(input, vec![]);
+
+    let expected_explain = String::from(
+        r#"except
+    projection ("t2"."e"::unsigned -> "e", "t2"."bucket_id"::unsigned -> "bucket_id")
+        scan "t2"
+    projection ("test_space"."id"::unsigned -> "id", "test_space"."bucket_id"::unsigned -> "bucket_id")
+        scan "test_space"
+execution options:
+sql_vdbe_max_steps = 45000
+vtable_max_rows = 5000
+"#,
+    );
+
+    assert_eq!(expected_explain, plan.as_explain().unwrap());
+}
+
 #[test]
 fn front_sql_exists_subquery_select_from_table() {
     let input = r#"SELECT "id" FROM "test_space" WHERE EXISTS (SELECT 0 FROM "hash_testing")"#;
diff --git a/sbroad-core/src/ir.rs b/sbroad-core/src/ir.rs
index c07f673c69..fb61602130 100644
--- a/sbroad-core/src/ir.rs
+++ b/sbroad-core/src/ir.rs
@@ -32,8 +32,10 @@ use crate::ir::undo::TransformationLog;
 use crate::ir::value::Value;
 use crate::{collection, error, warn};
 
+use self::expression::Position;
 use self::parameters::Parameters;
 use self::relation::Relations;
+use self::transformation::redistribution::MotionPolicy;
 
 // TODO: remove when rust version in bumped in module
 #[allow(elided_lifetimes_in_associated_constant)]
@@ -1242,6 +1244,97 @@ impl Plan {
     }
 }
 
+/// Relational node id -> positions of columns in output that refer to sharding column.
+pub type ShardColInfo = ahash::AHashMap<NodeId, Vec<Position>>;
+
+impl Plan {
+    /// Helper function to track position of the sharding column
+    /// for any relational node in the subtree defined by `top_id`.
+    ///
+    /// # Errors
+    /// - invalid references in the plan subtree
+    pub fn track_shard_column_pos(&self, top_id: usize) -> Result<ShardColInfo, SbroadError> {
+        let mut memo = ShardColInfo::with_capacity(REL_CAPACITY);
+        let mut dfs = PostOrder::with_capacity(|x| self.nodes.rel_iter(x), REL_CAPACITY);
+
+        for (_, node_id) in dfs.iter(top_id) {
+            let node = self.get_relation_node(node_id)?;
+
+            match node {
+                Relational::ScanRelation { relation, .. } => {
+                    let table = self.get_relation_or_error(relation)?;
+                    if let Ok(Some(pos)) = table.get_bucket_id_position() {
+                        memo.insert(node_id, vec![pos]);
+                    }
+                    continue;
+                }
+                Relational::Motion { policy, .. } => {
+                    // Any motion node that moves data invalidates
+                    // bucket_id column selected from that space.
+                    // Even Segment policy is no help, because it only
+                    // creates index on virtual table but does not actually
+                    // add or update bucket_id column.
+                    if !matches!(policy, MotionPolicy::Local | MotionPolicy::LocalSegment(_)) {
+                        continue;
+                    }
+                }
+                _ => {}
+            }
+
+            let Some(children) = node.children() else {
+                continue;
+            };
+
+            let output = self.get_row_list(node.output())?;
+            for (pos, alias_id) in output.iter().enumerate() {
+                let ref_id = self.get_child_under_alias(*alias_id)?;
+                // If there is a parameter under alias
+                // and we haven't bound parameters yet,
+                // we will get an error.
+                let Ok(Expression::Reference {
+                    targets, position, ..
+                }) = self.get_expression_node(ref_id)
+                else {
+                    continue;
+                };
+                let Some(targets) = targets else {
+                    continue;
+                };
+
+                // For node with multiple targets (Union, Except, Intersect)
+                // we need that ALL targets would refer to the shard column.
+                let mut refers_to_shard_col = true;
+                for target in targets {
+                    let child_id = children.get(*target).ok_or_else(|| {
+                        SbroadError::Invalid(
+                            Entity::Plan,
+                            Some(format!(
+                                "invalid target ({target}) in reference with id: {ref_id}"
+                            )),
+                        )
+                    })?;
+                    let Some(candidates) = memo.get(child_id) else {
+                        refers_to_shard_col = false;
+                        break;
+                    };
+                    if !candidates.contains(position) {
+                        refers_to_shard_col = false;
+                        break;
+                    }
+                }
+
+                if refers_to_shard_col {
+                    memo.entry(node_id)
+                        .and_modify(|v| v.push(pos))
+                        .or_insert(vec![pos]);
+                }
+            }
+        }
+
+        Ok(memo)
+    }
+}
+
 pub mod api;
 mod explain;
 #[cfg(test)]
diff --git a/sbroad-core/src/ir/expression.rs b/sbroad-core/src/ir/expression.rs
index f4a3e8ffd0..2333984137 100644
--- a/sbroad-core/src/ir/expression.rs
+++ b/sbroad-core/src/ir/expression.rs
@@ -14,7 +14,6 @@ use std::hash::{Hash, Hasher};
 use std::ops::Bound::Included;
 
 use crate::errors::{Entity, SbroadError};
-use crate::executor::engine::helpers::is_sharding_column_name;
 use crate::ir::aggregates::AggregateKind;
 use crate::ir::operator::{Bool, Relational};
 use crate::ir::relation::Type;
@@ -1067,6 +1066,17 @@ impl Plan {
         // Vec of (column position in child output, column plan id, new_targets).
         let mut filtered_children_row_list: Vec<(usize, usize, Vec<usize>)> = Vec::new();
 
+        // Helper lambda to retrieve column positions we need to exclude from child `rel_id`.
+        let column_positions_to_exclude = |rel_id| -> Result<Vec<Position>, SbroadError> {
+            let positions = if need_sharding_column {
+                vec![]
+            } else {
+                let mut info = self.track_shard_column_pos(rel_id)?;
+                info.remove(&rel_id).unwrap_or_default()
+            };
+            Ok(positions)
+        };
+
         if let Some(columns_spec) = source.get_columns_spec() {
             let (rel_child, _) = source
                 .iter()
@@ -1087,12 +1097,13 @@ impl Plan {
                 ColumnsRetrievalSpec::Indices(indices) => indices.clone(),
             };
 
+            let exclude_positions = column_positions_to_exclude(rel_child)?;
+
             for index in indices {
                 let col_id = *child_node_row_list
                     .get(index)
                     .expect("Column id not found under relational child output");
-                let alias_name = self.get_expression_node(col_id)?.get_alias_name()?;
-                if is_sharding_column_name(alias_name) {
+                if exclude_positions.contains(&index) {
                     continue;
                 }
                 filtered_children_row_list.push((index, col_id, source.targets()));
@@ -1112,9 +1123,10 @@ impl Plan {
                         filtered_children_row_list.push((pos, *id, new_targets.clone()));
                     });
                 } else {
+                    let exclude_positions = column_positions_to_exclude(child_node_id)?;
+
                     for (pos, expr_id) in child_row_list.iter().enumerate() {
-                        let alias_name = self.get_expression_node(*expr_id)?.get_alias_name()?;
-                        if is_sharding_column_name(alias_name) {
+                        if exclude_positions.contains(&pos) {
                             continue;
                         }
                         filtered_children_row_list.push((pos, *expr_id, new_targets.clone()));
diff --git a/sbroad-core/src/ir/operator.rs b/sbroad-core/src/ir/operator.rs
index e2ea841c5a..49e4953194 100644
--- a/sbroad-core/src/ir/operator.rs
+++ b/sbroad-core/src/ir/operator.rs
@@ -1263,15 +1263,27 @@ impl Plan {
                     EXPR_CAPACITY,
                     EXPR_CAPACITY,
                 );
+                // we should update ONLY references that refer to current child (left, right)
+                let current_target = match join_child {
+                    JoinChild::Inner => Some(vec![1_usize]),
+                    JoinChild::Outer => Some(vec![0_usize]),
+                };
                 let refs = condition_tree
                     .iter(condition)
                     .filter_map(|(_, id)| {
                         let expr = self.get_expression_node(id).ok();
-                        if let Some(Expression::Reference { position, .. }) = expr {
-                            if Some(*position) == sharding_column_pos {
-                                needs_bucket_id_column = true;
+                        if let Some(Expression::Reference {
+                            position, targets, ..
+                        }) = expr
+                        {
+                            if *targets == current_target {
+                                if Some(*position) == sharding_column_pos {
+                                    needs_bucket_id_column = true;
+                                }
+                                Some(id)
+                            } else {
+                                None
                             }
-                            Some(id)
                         } else {
                             None
                         }
@@ -1292,11 +1304,6 @@ impl Plan {
                     continue;
                 }
 
-                // we should update ONLY references that refer to current child (left, right)
-                let current_target = match join_child {
-                    JoinChild::Inner => Some(vec![1_usize]),
-                    JoinChild::Outer => Some(vec![0_usize]),
-                };
                 if let Some(sharding_column_pos) = sharding_column_pos {
                     for ref_id in refs {
                         let expr = self.get_mut_expression_node(ref_id)?;
diff --git a/sbroad-core/src/ir/relation.rs b/sbroad-core/src/ir/relation.rs
index a834b123e8..9ae2e0836c 100644
--- a/sbroad-core/src/ir/relation.rs
+++ b/sbroad-core/src/ir/relation.rs
@@ -27,8 +27,6 @@ use super::distribution::Key;
 
 const DEFAULT_VALUE: Value = Value::Null;
 
-pub const SHARD_COL_NAME: &str = "\"bucket_id\"";
-
 /// Supported column types, which is used in a schema only.
 /// This `Type` is derived from the result's metadata.
 #[derive(Serialize, Default, Deserialize, PartialEq, Hash, Debug, Eq, Clone)]
diff --git a/sbroad-core/src/ir/transformation/redistribution.rs b/sbroad-core/src/ir/transformation/redistribution.rs
index 15a63c3bdf..bd5961b871 100644
--- a/sbroad-core/src/ir/transformation/redistribution.rs
+++ b/sbroad-core/src/ir/transformation/redistribution.rs
@@ -1,6 +1,6 @@
 //! Resolve distribution conflicts and insert motion nodes to IR.
 
-use ahash::{AHashSet, RandomState};
+use ahash::{AHashMap, AHashSet, RandomState};
 use serde::{Deserialize, Serialize};
 use std::cmp::Ordering;
 use std::collections::{hash_map::Entry, HashMap, HashSet};
@@ -12,13 +12,13 @@ use crate::ir::expression::ColumnPositionMap;
 use crate::ir::expression::Expression;
 use crate::ir::operator::{Bool, JoinKind, Relational, Unary, UpdateStrategy};
 
-use crate::ir::relation::{TableKind, SHARD_COL_NAME};
+use crate::ir::relation::TableKind;
 use crate::ir::transformation::redistribution::eq_cols::EqualityCols;
 use crate::ir::tree::traversal::{
     BreadthFirst, LevelNode, PostOrder, PostOrderWithFilter, EXPR_CAPACITY, REL_CAPACITY,
 };
 use crate::ir::value::Value;
-use crate::ir::{Node, Plan};
+use crate::ir::{Node, Plan, ShardColInfo};
 use crate::otm::child_span;
 use sbroad_proc::otm_child_span;
 
@@ -432,6 +432,84 @@ impl Plan {
         }
     }
 
+    /// Check for join/sq equality on `bucket_id` column:
+    /// ```text
+    /// .. on (t1.a, t1.bucket_id) = (t2.b, t2.bucket_id)
+    ///
+    /// select * from t1 where bucket_id in (select bucket_id from t2)
+    /// ```
+    ///
+    /// In such case join/selection can be done locally.
+    fn has_eq_on_bucket_id(
+        &self,
+        left_row_id: usize,
+        right_row_id: usize,
+        rel_id: usize,
+        op: &Bool,
+        shard_col_info: &ShardColInfo,
+    ) -> Result<bool, SbroadError> {
+        if !(Bool::Eq == *op || Bool::In == *op) {
+            return Ok(false);
+        }
+        // It is possible that multiple columns in row refer to the shard column
+        // we need to find if there is a pair of such columns from different
+        // children for local join:
+        //
+        // select * from (select bucket_id as a from t1) as t1
+        // join (select bucket_id as b from t2) as t2
+        // on (1, a, b) = (2, b, 3)
+        //
+        // Equality pair `a = b` allows us to do local join.
+        //
+        // position in row that refers to shard column -> child id
+        let mut memo: AHashMap<usize, usize> = AHashMap::new();
+        let mut search_row = |row_id: usize| -> Result<bool, SbroadError> {
+            let refs = self.get_row_list(row_id)?;
+            for (pos_in_row, ref_id) in refs.iter().enumerate() {
+                let node @ Expression::Reference {
+                    targets,
+                    position: ref_pos,
+                    ..
+                } = self.get_expression_node(*ref_id)?
+                else {
+                    continue;
+                };
+                let targets = targets.as_ref().ok_or_else(|| {
+                    SbroadError::Invalid(
+                        Entity::Node,
+                        Some(format!(
+                            "ref ({ref_id}) in join condition with no targets: {node:?}"
+                        )),
+                    )
+                })?;
+                let child_idx = targets.first().ok_or_else(|| {
+                    SbroadError::Invalid(
+                        Entity::Node,
+                        Some(format!(
+                            "ref ({ref_id}) in join condition with empty targets: {node:?}"
+                        )),
+                    )
+                })?;
+                let child_id = self.get_relational_child(rel_id, *child_idx)?;
+                if let Some(candidates) = shard_col_info.get(&child_id) {
+                    if !candidates.contains(ref_pos) {
+                        continue;
+                    }
+                    if let Some(other_child_id) = memo.get(&pos_in_row) {
+                        if *other_child_id != child_id {
+                            return Ok(true);
+                        }
+                    } else {
+                        memo.insert(pos_in_row, child_id);
+                    }
+                }
+            }
+
+            Ok(false)
+        };
+        Ok(search_row(left_row_id)? || search_row(right_row_id)?)
+    }
+
     /// Choose a `MotionPolicy` strategy for the inner row.
     ///
     /// # Errors
@@ -551,11 +629,33 @@ impl Plan {
         &self,
         rel_id: usize,
         op_id: usize,
+        shard_col_info: &ShardColInfo,
     ) -> Result<Vec<(usize, MotionPolicy)>, SbroadError> {
         let mut strategies: Vec<(usize, MotionPolicy)> = Vec::new();
         let bool_op = BoolOp::from_expr(self, op_id)?;
         let left = self.get_additional_sq(rel_id, bool_op.left)?;
         let right = self.get_additional_sq(rel_id, bool_op.right)?;
+
+        // If we eq/in where both rows contain bucket_id in same position
+        // we don't need Motion nodes.
+        if (left.is_some() || right.is_some())
+            && self.has_eq_on_bucket_id(
+                bool_op.left,
+                bool_op.right,
+                rel_id,
+                &bool_op.op,
+                shard_col_info,
+            )?
+        {
+            if let Some(left_sq) = left {
+                strategies.push((left_sq, MotionPolicy::None));
+            }
+            if let Some(right_sq) = right {
+                strategies.push((right_sq, MotionPolicy::None));
+            }
+            return Ok(strategies);
+        }
+
         match left {
             Some(left_sq) => {
                 match right {
@@ -648,13 +748,15 @@ impl Plan {
         }
 
         let bool_nodes = self.get_bool_nodes_with_row_children(filter_id);
+        let shard_col_info = self.track_shard_column_pos(select_id)?;
         for (_, bool_node) in &bool_nodes {
             let bool_op = BoolOp::from_expr(self, *bool_node)?;
             self.set_distribution(bool_op.left)?;
             self.set_distribution(bool_op.right)?;
         }
         for (_, bool_node) in &bool_nodes {
-            let strategies = self.get_sq_node_strategies_for_bool_op(select_id, *bool_node)?;
+            let strategies =
+                self.get_sq_node_strategies_for_bool_op(select_id, *bool_node, &shard_col_info)?;
             for (id, policy) in strategies {
                 // In case we faced with `not ... in ...`, we
                 // have to change motion policy to Full.
@@ -912,36 +1014,16 @@ impl Plan {
         join_id: usize,
         left_row_id: usize,
         right_row_id: usize,
+        shard_col_info: &ShardColInfo,
     ) -> Result<MotionPolicy, SbroadError> {
-        {
-            // check for (a, t1.bucket_id, b) = (x, t2.bucket_id, y)
-            let get_shard_pos = |row_id: usize| -> Result<Option<usize>, SbroadError> {
-                let mut shard_pos = None;
-                let refs = self.get_row_list(row_id)?;
-                for (pos, ref_id) in refs.iter().enumerate() {
-                    let node @ Expression::Reference { .. } = self.get_expression_node(*ref_id)?
-                    else {
-                        continue;
-                    };
-
-                    // NB: This code assumes that user does not shoot himself in
-                    // the leg by renaming some column into `bucket_id` like here:
-                    // select * from (select "a" as "bucket_id", "bucket_id" as b from "t") join t2 on ...
-                    // If this happens, we will get wrong plan.
-                    // TODO: forbid renaming some column into `bucket_id` or renaming
-                    // `bucket_id` into something else.
-                    if SHARD_COL_NAME == self.get_alias_from_reference_node(node)? {
-                        shard_pos = Some(pos);
-                        break;
-                    }
-                }
-                Ok(shard_pos)
-            };
-            let left_shard_pos = get_shard_pos(left_row_id)?;
-            let right_shard_pos = get_shard_pos(right_row_id)?;
-            if left_shard_pos.is_some() && left_shard_pos == right_shard_pos {
-                return Ok(MotionPolicy::None);
-            }
+        if self.has_eq_on_bucket_id(
+            left_row_id,
+            right_row_id,
+            join_id,
+            &Bool::Eq,
+            shard_col_info,
+        )? {
+            return Ok(MotionPolicy::None);
         }
 
         let left_dist = self.get_distribution(left_row_id)?;
@@ -1141,6 +1223,7 @@ impl Plan {
             (inner, outer)
         };
 
+        let shard_col_info = self.track_shard_column_pos(rel_id)?;
         let mut inner_map: HashMap<usize, MotionPolicy> = HashMap::new();
         let mut new_inner_policy = MotionPolicy::Full;
         let filter = |node_id: usize| -> bool {
@@ -1190,7 +1273,8 @@ impl Plan {
             // Note, that we don't have to call `get_sq_node_strategy_for_unary_op` here, because
             // the only strategy it can return is `Motion::Full` for its child and all subqueries
             // are covered with `Motion::Full` by default.
-            let sq_strategies = self.get_sq_node_strategies_for_bool_op(rel_id, node_id)?;
+            let sq_strategies =
+                self.get_sq_node_strategies_for_bool_op(rel_id, node_id, &shard_col_info)?;
             let sq_strategies_len = sq_strategies.len();
             for (id, policy) in sq_strategies {
                 strategy.add_child(id, policy, Program::default());
@@ -1235,9 +1319,12 @@ impl Plan {
                         Bool::Between => {
                             unreachable!("Between in redistribution")
                         }
-                        Bool::Eq | Bool::In => {
-                            self.join_policy_for_eq(rel_id, bool_op.left, bool_op.right)?
-                        }
+                        Bool::Eq | Bool::In => self.join_policy_for_eq(
+                            rel_id,
+                            bool_op.left,
+                            bool_op.right,
+                            &shard_col_info,
+                        )?,
                         Bool::Gt | Bool::GtEq | Bool::Lt | Bool::LtEq | Bool::NotEq => {
                             MotionPolicy::Full
                         }
@@ -1763,6 +1850,36 @@ impl Plan {
         Ok(map)
     }
 
+    // Helper function to check whether except is done between
+    // sharded tables that both contain the bucket_id column
+    // at the same position in their outputs. In such case
+    // except can be done locally.
+    //
+    // Example:
+    // select "bucket_id" as a from t1
+    // except
+    // select "bucket_id" as b from t1
+    fn is_except_on_bucket_id(
+        &self,
+        rel_id: usize,
+        left_id: usize,
+        right_id: usize,
+    ) -> Result<bool, SbroadError> {
+        let shard_col_info = self.track_shard_column_pos(rel_id)?;
+        let Some(left_shard_positions) = shard_col_info.get(&left_id) else {
+            return Ok(false);
+        };
+        let Some(right_shard_positions) = shard_col_info.get(&right_id) else {
+            return Ok(false);
+        };
+        for l in left_shard_positions {
+            if right_shard_positions.contains(l) {
+                return Ok(true);
+            }
+        }
+        Ok(false)
+    }
+
     #[allow(clippy::too_many_lines)]
     fn resolve_except_conflicts(&mut self, rel_id: usize) -> Result<Strategy, SbroadError> {
         if !matches!(self.get_relation_node(rel_id)?, Relational::Except { .. }) {
@@ -1783,6 +1900,10 @@ impl Plan {
         let left_dist = self.get_rel_distribution(left_id)?;
         let right_dist = self.get_rel_distribution(right_id)?;
 
+        if self.is_except_on_bucket_id(rel_id, left_id, right_id)? {
+            return Ok(map);
+        }
+
         let (left_motion, right_motion) = match (left_dist, right_dist) {
             (
                 Distribution::Segment { keys: left_keys },
diff --git a/sbroad-core/src/ir/transformation/redistribution/groupby.rs b/sbroad-core/src/ir/transformation/redistribution/groupby.rs
index b963c03c1d..4b2972f620 100644
--- a/sbroad-core/src/ir/transformation/redistribution/groupby.rs
+++ b/sbroad-core/src/ir/transformation/redistribution/groupby.rs
@@ -1632,6 +1632,25 @@ impl Plan {
         if grouping_exprs.is_empty() && aggr_infos.is_empty() {
             return Ok(false);
         }
+
+        // Check for group by on bucket_id column
+        // in that case groupby can be done locally.
+        if !grouping_exprs.is_empty() {
+            let shard_col_info = self.track_shard_column_pos(final_proj_id)?;
+            for expr_id in &grouping_exprs {
+                let Expression::Reference { position, .. } = self.get_expression_node(*expr_id)?
+                else {
+                    continue;
+                };
+                let child_id = self.get_relational_from_reference_node(*expr_id)?;
+                if let Some(shard_positions) = shard_col_info.get(child_id) {
+                    if shard_positions.contains(position) {
+                        return Ok(false);
+                    }
+                }
+            }
+        }
+
         let (local_proj_id, grouping_positions, local_aliases_map) =
             self.add_local_projection(upper, &mut aggr_infos, &grouping_exprs)?;
         let sq_id = self.add_sub_query(local_proj_id, None)?;
-- 
GitLab