From 96891c12148401dd85b05bd2dc70e28d016324fe Mon Sep 17 00:00:00 2001
From: Denis Smirnov <sd@picodata.io>
Date: Wed, 15 May 2024 19:45:54 +0700
Subject: [PATCH] perf: speed up track_shard_column_pos()

---
 sbroad-core/src/frontend/sql/ir/tests.rs      | 16 +++----
 sbroad-core/src/ir.rs                         | 47 ++++++++++++-------
 sbroad-core/src/ir/aggregates.rs              |  4 +-
 sbroad-core/src/ir/expression.rs              |  9 ++--
 .../src/ir/transformation/redistribution.rs   | 25 ++++++++--
 .../transformation/redistribution/groupby.rs  |  4 +-
 .../redistribution/left_join.rs               | 38 +--------------
 7 files changed, 73 insertions(+), 70 deletions(-)

diff --git a/sbroad-core/src/frontend/sql/ir/tests.rs b/sbroad-core/src/frontend/sql/ir/tests.rs
index efd742c78..90405d333 100644
--- a/sbroad-core/src/frontend/sql/ir/tests.rs
+++ b/sbroad-core/src/frontend/sql/ir/tests.rs
@@ -813,10 +813,10 @@ fn track_shard_col_pos() {
         let node = plan.get_relation_node(node_id).unwrap();
         match node {
             Relational::ScanRelation { .. } | Relational::Selection { .. } => {
-                assert_eq!(&vec![4_usize], map.get(&node_id).unwrap())
+                assert_eq!([Some(4_usize), None], *map.get(&node_id).unwrap())
             }
             Relational::Projection { .. } => {
-                assert_eq!(&vec![1_usize], map.get(&node_id).unwrap())
+                assert_eq!([Some(1_usize), None], *map.get(&node_id).unwrap())
             }
             _ => {}
         }
@@ -834,10 +834,10 @@ fn track_shard_col_pos() {
     for (_, node_id) in dfs.iter(top) {
         let node = plan.get_relation_node(node_id).unwrap();
         if let Relational::Join { .. } = node {
-            assert_eq!(&vec![4_usize, 5_usize], map.get(&node_id).unwrap());
+            assert_eq!([Some(4_usize), Some(5_usize)], *map.get(&node_id).unwrap());
         }
     }
-    assert_eq!(&vec![0_usize, 1_usize], map.get(&top).unwrap());
+    assert_eq!([Some(0_usize), Some(1_usize)], *map.get(&top).unwrap());
 
     let input = r#"select t_mv."bucket_id", "t2"."bucket_id" from "t2" join (
         select "bucket_id" from "test_space" where "id" = 1
@@ -851,10 +851,10 @@ fn track_shard_col_pos() {
     for (_, node_id) in dfs.iter(top) {
         let node = plan.get_relation_node(node_id).unwrap();
         if let Relational::Join { .. } = node {
-            assert_eq!(&vec![4_usize], map.get(&node_id).unwrap());
+            assert_eq!([Some(4_usize), None], *map.get(&node_id).unwrap());
         }
     }
-    assert_eq!(&vec![1_usize], map.get(&top).unwrap());
+    assert_eq!([Some(1_usize), None], *map.get(&top).unwrap());
 
     let input = r#"
     select "bucket_id", "e" from "t2"
@@ -874,7 +874,7 @@ fn track_shard_col_pos() {
     let plan = sql_to_optimized_ir(input, vec![]);
     let top = plan.get_top().unwrap();
     let map = plan.track_shard_column_pos(top).unwrap();
-    assert_eq!(&vec![0_usize], map.get(&top).unwrap());
+    assert_eq!([Some(0_usize), None], *map.get(&top).unwrap());
 
     let input = r#"
     select "e" from (select "bucket_id" as "e" from "t2")
@@ -882,7 +882,7 @@ fn track_shard_col_pos() {
     let plan = sql_to_optimized_ir(input, vec![]);
     let top = plan.get_top().unwrap();
     let map = plan.track_shard_column_pos(top).unwrap();
-    assert_eq!(&vec![0_usize], map.get(&top).unwrap());
+    assert_eq!([Some(0_usize), None], *map.get(&top).unwrap());
 
     let input = r#"
     select "e" as "bucket_id" from "t2"
diff --git a/sbroad-core/src/ir.rs b/sbroad-core/src/ir.rs
index dfb6ede1f..5c65dd49f 100644
--- a/sbroad-core/src/ir.rs
+++ b/sbroad-core/src/ir.rs
@@ -5,7 +5,7 @@
 use base64ct::{Base64, Encoding};
 use serde::{Deserialize, Serialize};
 use smol_str::{format_smolstr, SmolStr, ToSmolStr};
-use std::collections::hash_map::IntoIter;
+use std::collections::hash_map::{Entry, IntoIter};
 use std::collections::{HashMap, HashSet};
 use std::fmt::{Display, Formatter};
 
@@ -1271,8 +1271,11 @@ impl Plan {
     }
 }
 
+/// A pair of slots holding up to two column positions; an unused slot is `None`.
+pub type Positions = [Option<Position>; 2];
+
 /// Relational node id -> positions of columns in output that refer to sharding column.
-pub type ShardColInfo = ahash::AHashMap<NodeId, Vec<Position>>;
+pub type ShardColInfo = ahash::AHashMap<NodeId, Positions>;
 
 impl Plan {
     /// Helper function to track position of the sharding column
@@ -1280,6 +1283,9 @@ impl Plan {
     ///
     /// # Errors
     /// - invalid references in the plan subtree
+    ///
+    /// # Panics
+    /// - a reference target has no matching child, or a node yields more than two positions
     pub fn track_shard_column_pos(&self, top_id: usize) -> Result<ShardColInfo, SbroadError> {
         let mut memo = ShardColInfo::with_capacity(REL_CAPACITY);
         let mut dfs = PostOrder::with_capacity(|x| self.nodes.rel_iter(x), REL_CAPACITY);
@@ -1291,7 +1297,7 @@ impl Plan {
                 Relational::ScanRelation { relation, .. } => {
                     let table = self.get_relation_or_error(relation)?;
                     if let Ok(Some(pos)) = table.get_bucket_id_position() {
-                        memo.insert(node_id, vec![pos]);
+                        memo.insert(node_id, [Some(pos), None]);
                     }
                     continue;
                 }
@@ -1333,28 +1339,37 @@ impl Plan {
                 // we need that ALL targets would refer to the shard column.
                 let mut refers_to_shard_col = true;
                 for target in targets {
-                    let child_id = children.get(*target).ok_or_else(|| {
-                        SbroadError::Invalid(
-                            Entity::Plan,
-                            Some(format_smolstr!(
-                                "invalid target ({target}) in reference with id: {ref_id}"
-                            )),
-                        )
-                    })?;
-                    let Some(candidates) = memo.get(child_id) else {
+                    let child_id = children.get(*target).expect("invalid reference");
+                    let Some(positions) = memo.get(child_id) else {
                         refers_to_shard_col = false;
                         break;
                     };
-                    if !candidates.contains(position) {
+                    if positions[0] != Some(*position) && positions[1] != Some(*position) {
                         refers_to_shard_col = false;
                         break;
                     }
                 }
 
                 if refers_to_shard_col {
-                    memo.entry(node_id)
-                        .and_modify(|v| v.push(pos))
-                        .or_insert(vec![pos]);
+                    match memo.entry(node_id) {
+                        Entry::Occupied(mut entry) => {
+                            let positions = entry.get_mut();
+                            if positions[0].is_none() {
+                                positions[0] = Some(pos);
+                            } else if positions[0] == Some(pos) {
+                                // Do nothing, we already have this position.
+                            } else if positions[1].is_none() {
+                                positions[1] = Some(pos);
+                            } else if positions[1] == Some(pos) {
+                                // Do nothing, we already have this position.
+                            } else {
+                                unreachable!("more than 2 pointers in the reference");
+                            }
+                        }
+                        Entry::Vacant(entry) => {
+                            entry.insert([Some(pos), None]);
+                        }
+                    }
                 }
             }
         }
diff --git a/sbroad-core/src/ir/aggregates.rs b/sbroad-core/src/ir/aggregates.rs
index 4694764a7..f55403626 100644
--- a/sbroad-core/src/ir/aggregates.rs
+++ b/sbroad-core/src/ir/aggregates.rs
@@ -5,12 +5,12 @@ use crate::ir::expression::cast::Type;
 use crate::ir::expression::Expression;
 use crate::ir::operator::Arithmetic;
 use crate::ir::relation::Type as RelType;
-use crate::ir::{Node, Plan};
+use crate::ir::{Node, Plan, Position};
 use std::collections::HashMap;
 use std::fmt::{Display, Formatter};
 use std::rc::Rc;
 
-use super::expression::{ColumnPositionMap, FunctionFeature, Position};
+use super::expression::{ColumnPositionMap, FunctionFeature};
 
 /// The kind of aggregate function
 ///
diff --git a/sbroad-core/src/ir/expression.rs b/sbroad-core/src/ir/expression.rs
index 2da84617e..1a42867f2 100644
--- a/sbroad-core/src/ir/expression.rs
+++ b/sbroad-core/src/ir/expression.rs
@@ -17,6 +17,7 @@ use crate::errors::{Entity, SbroadError};
 use crate::ir::aggregates::AggregateKind;
 use crate::ir::operator::{Bool, Relational};
 use crate::ir::relation::Type;
+use crate::ir::Positions as Targets;
 
 use super::distribution::Distribution;
 use super::tree::traversal::{PostOrderWithFilter, EXPR_CAPACITY};
@@ -1218,9 +1219,9 @@ impl Plan {
         let mut filtered_children_row_list: Vec<(usize, usize, Vec<usize>)> = Vec::new();
 
         // Helper lambda to retrieve column positions we need to exclude from child `rel_id`.
-        let column_positions_to_exclude = |rel_id| -> Result<Vec<Position>, SbroadError> {
+        let column_positions_to_exclude = |rel_id| -> Result<Targets, SbroadError> {
             let positions = if need_sharding_column {
-                vec![]
+                [None, None]
             } else {
                 let mut info = self.track_shard_column_pos(rel_id)?;
                 info.remove(&rel_id).unwrap_or_default()
@@ -1261,7 +1262,7 @@ impl Plan {
                 let col_id = *child_node_row_list
                     .get(index)
                     .expect("Column id not found under relational child output");
-                if exclude_positions.contains(&index) {
+                if exclude_positions[0] == Some(index) || exclude_positions[1] == Some(index) {
                     continue;
                 }
                 filtered_children_row_list.push((index, col_id, source.targets()));
@@ -1284,7 +1285,7 @@ impl Plan {
                     let exclude_positions = column_positions_to_exclude(child_node_id)?;
 
                     for (pos, expr_id) in child_row_list.iter().enumerate() {
-                        if exclude_positions.contains(&pos) {
+                        if exclude_positions[0] == Some(pos) || exclude_positions[1] == Some(pos) {
                             continue;
                         }
                         filtered_children_row_list.push((pos, *expr_id, new_targets.clone()));
diff --git a/sbroad-core/src/ir/transformation/redistribution.rs b/sbroad-core/src/ir/transformation/redistribution.rs
index d713f6d2a..c6ed46c46 100644
--- a/sbroad-core/src/ir/transformation/redistribution.rs
+++ b/sbroad-core/src/ir/transformation/redistribution.rs
@@ -496,8 +496,8 @@ impl Plan {
                     )
                 })?;
                 let child_id = self.get_relational_child(rel_id, *child_idx)?;
-                if let Some(candidates) = shard_col_info.get(&child_id) {
-                    if !candidates.contains(ref_pos) {
+                if let Some(positions) = shard_col_info.get(&child_id) {
+                    if positions[0] != Some(*ref_pos) && positions[1] != Some(*ref_pos) {
                         continue;
                     }
                     if let Some(other_child_id) = memo.get(&pos_in_row) {
@@ -754,12 +754,31 @@ impl Plan {
         }
 
         let bool_nodes = self.get_bool_nodes_with_row_children(filter_id);
-        let shard_col_info = self.track_shard_column_pos(select_id)?;
         for (_, bool_node) in &bool_nodes {
             let bool_op = BoolOp::from_expr(self, *bool_node)?;
             self.set_distribution(bool_op.left)?;
             self.set_distribution(bool_op.right)?;
         }
+
+        // Check that we actually need to get sharding column positions (it is expensive).
+        let mut need_shard_col_info = false;
+        for (_, bool_node) in &bool_nodes {
+            let bool_op = BoolOp::from_expr(self, *bool_node)?;
+            if need_shard_col_info {
+                continue;
+            }
+            let left = self.get_additional_sq(select_id, bool_op.left)?;
+            let right = self.get_additional_sq(select_id, bool_op.right)?;
+            if left.is_some() || right.is_some() {
+                need_shard_col_info = true;
+            }
+        }
+
+        let shard_col_info = if need_shard_col_info {
+            self.track_shard_column_pos(select_id)?
+        } else {
+            ShardColInfo::new()
+        };
         for (_, bool_node) in &bool_nodes {
             let strategies =
                 self.get_sq_node_strategies_for_bool_op(select_id, *bool_node, &shard_col_info)?;
diff --git a/sbroad-core/src/ir/transformation/redistribution/groupby.rs b/sbroad-core/src/ir/transformation/redistribution/groupby.rs
index 9ddeecf00..08e68daf9 100644
--- a/sbroad-core/src/ir/transformation/redistribution/groupby.rs
+++ b/sbroad-core/src/ir/transformation/redistribution/groupby.rs
@@ -1723,7 +1723,9 @@ impl Plan {
                 };
                 let child_id = self.get_relational_from_reference_node(*expr_id)?;
                 if let Some(shard_positions) = shard_col_info.get(child_id) {
-                    if shard_positions.contains(position) {
+                    if shard_positions[0] == Some(*position)
+                        || shard_positions[1] == Some(*position)
+                    {
                         return Ok(false);
                     }
                 }
diff --git a/sbroad-core/src/ir/transformation/redistribution/left_join.rs b/sbroad-core/src/ir/transformation/redistribution/left_join.rs
index 064d3bb7f..ea0d83d0f 100644
--- a/sbroad-core/src/ir/transformation/redistribution/left_join.rs
+++ b/sbroad-core/src/ir/transformation/redistribution/left_join.rs
@@ -1,13 +1,12 @@
 //! Left Join trasformation logic when outer child has Global distribution
 //! and inner child has Segment or Any distribution.
 
-use smol_str::{format_smolstr, SmolStr};
+use smol_str::format_smolstr;
 
 use crate::{
     errors::{Entity, SbroadError},
     ir::{
         distribution::Distribution,
-        expression::Expression,
         operator::{JoinKind, Relational},
         Plan,
     },
@@ -91,42 +90,9 @@ impl Plan {
 }
 
 fn create_projection(plan: &mut Plan, join_id: usize) -> Result<usize, SbroadError> {
-    let proj_columns_names = collect_projection_columns(plan, join_id)?;
-    let proj_columns_refs: Vec<&str> = proj_columns_names.iter().map(SmolStr::as_str).collect();
-    let proj_id = plan.add_proj(join_id, &proj_columns_refs, false, false)?;
+    let proj_id = plan.add_proj(join_id, &[], false, false)?;
     let output_id = plan.get_relational_output(proj_id)?;
     plan.replace_parent_in_subtree(output_id, Some(join_id), Some(proj_id))?;
     plan.set_distribution(output_id)?;
     Ok(proj_id)
 }
-
-// Returns a list of column aliases from join node output.
-fn collect_projection_columns(
-    plan: &mut Plan,
-    join_id: usize,
-) -> Result<Vec<SmolStr>, SbroadError> {
-    // TODO: currently we use all columns from joined tables,
-    // but it is possible that a lot of columns are not used
-    // above in the plan, we can remove unused columns to
-    // reduce amount of data sent through the network.
-    // https://git.picodata.io/picodata/picodata/sbroad/-/issues/36
-    let output_id = plan.get_relational_output(join_id)?;
-    let columns_len = plan.get_row_list(output_id)?.len();
-    let mut projection_columns: Vec<SmolStr> = Vec::with_capacity(columns_len);
-    for idx in 0..columns_len {
-        let expr_id = *plan.get_row_list(output_id)?.get(idx).ok_or_else(|| {
-            SbroadError::UnexpectedNumberOfValues("output row size changed".into())
-        })?;
-        if let Expression::Alias { name, .. } = plan.get_expression_node(expr_id)? {
-            projection_columns.push(name.clone());
-        } else {
-            return Err(SbroadError::Invalid(
-                Entity::Node,
-                Some(format_smolstr!(
-                    "node ({join_id}) output columns is not alias"
-                )),
-            ));
-        }
-    }
-    Ok(projection_columns)
-}
-- 
GitLab