From 96891c12148401dd85b05bd2dc70e28d016324fe Mon Sep 17 00:00:00 2001 From: Denis Smirnov <sd@picodata.io> Date: Wed, 15 May 2024 19:45:54 +0700 Subject: [PATCH] perf: speed-up track_shard_column_pos() --- sbroad-core/src/frontend/sql/ir/tests.rs | 16 +++---- sbroad-core/src/ir.rs | 47 ++++++++++++------- sbroad-core/src/ir/aggregates.rs | 4 +- sbroad-core/src/ir/expression.rs | 9 ++-- .../src/ir/transformation/redistribution.rs | 25 ++++++++-- .../transformation/redistribution/groupby.rs | 4 +- .../redistribution/left_join.rs | 38 +-------------- 7 files changed, 73 insertions(+), 70 deletions(-) diff --git a/sbroad-core/src/frontend/sql/ir/tests.rs b/sbroad-core/src/frontend/sql/ir/tests.rs index efd742c78..90405d333 100644 --- a/sbroad-core/src/frontend/sql/ir/tests.rs +++ b/sbroad-core/src/frontend/sql/ir/tests.rs @@ -813,10 +813,10 @@ fn track_shard_col_pos() { let node = plan.get_relation_node(node_id).unwrap(); match node { Relational::ScanRelation { .. } | Relational::Selection { .. } => { - assert_eq!(&vec![4_usize], map.get(&node_id).unwrap()) + assert_eq!([Some(4_usize), None], *map.get(&node_id).unwrap()) } Relational::Projection { .. } => { - assert_eq!(&vec![1_usize], map.get(&node_id).unwrap()) + assert_eq!([Some(1_usize), None], *map.get(&node_id).unwrap()) } _ => {} } @@ -834,10 +834,10 @@ fn track_shard_col_pos() { for (_, node_id) in dfs.iter(top) { let node = plan.get_relation_node(node_id).unwrap(); if let Relational::Join { .. } = node { - assert_eq!(&vec![4_usize, 5_usize], map.get(&node_id).unwrap()); + assert_eq!([Some(4_usize), Some(5_usize)], *map.get(&node_id).unwrap()); } } - assert_eq!(&vec![0_usize, 1_usize], map.get(&top).unwrap()); + assert_eq!([Some(0_usize), Some(1_usize)], *map.get(&top).unwrap()); let input = r#"select t_mv."bucket_id", "t2"."bucket_id" from "t2" join ( select "bucket_id" from "test_space" where "id" = 1 @@ -851,10 +851,10 @@ fn track_shard_col_pos() { for (_, node_id) in dfs.iter(top) { let node = plan.get_relation_node(node_id).unwrap(); if let Relational::Join { .. } = node { - assert_eq!(&vec![4_usize], map.get(&node_id).unwrap()); + assert_eq!([Some(4_usize), None], *map.get(&node_id).unwrap()); } } - assert_eq!(&vec![1_usize], map.get(&top).unwrap()); + assert_eq!([Some(1_usize), None], *map.get(&top).unwrap()); let input = r#" select "bucket_id", "e" from "t2" @@ -874,7 +874,7 @@ fn track_shard_col_pos() { let plan = sql_to_optimized_ir(input, vec![]); let top = plan.get_top().unwrap(); let map = plan.track_shard_column_pos(top).unwrap(); - assert_eq!(&vec![0_usize], map.get(&top).unwrap()); + assert_eq!([Some(0_usize), None], *map.get(&top).unwrap()); let input = r#" select "e" from (select "bucket_id" as "e" from "t2") @@ -882,7 +882,7 @@ fn track_shard_col_pos() { let plan = sql_to_optimized_ir(input, vec![]); let top = plan.get_top().unwrap(); let map = plan.track_shard_column_pos(top).unwrap(); - assert_eq!(&vec![0_usize], map.get(&top).unwrap()); + assert_eq!([Some(0_usize), None], *map.get(&top).unwrap()); let input = r#" select "e" as "bucket_id" from "t2" diff --git a/sbroad-core/src/ir.rs b/sbroad-core/src/ir.rs index dfb6ede1f..5c65dd49f 100644 --- a/sbroad-core/src/ir.rs +++ b/sbroad-core/src/ir.rs @@ -5,7 +5,7 @@ use base64ct::{Base64, Encoding}; use serde::{Deserialize, Serialize}; use smol_str::{format_smolstr, SmolStr, ToSmolStr}; -use std::collections::hash_map::IntoIter; +use std::collections::hash_map::{Entry, IntoIter}; use std::collections::{HashMap, HashSet}; use std::fmt::{Display, Formatter}; @@ -1271,8 +1271,11 @@ impl Plan { } } +/// Target positions in the reference. +pub type Positions = [Option<Position>; 2]; + /// Relational node id -> positions of columns in output that refer to sharding column. -pub type ShardColInfo = ahash::AHashMap<NodeId, Vec<Position>>; +pub type ShardColInfo = ahash::AHashMap<NodeId, Positions>; impl Plan { /// Helper function to track position of the sharding column @@ -1280,6 +1283,9 @@ impl Plan { /// /// # Errors /// - invalid references in the plan subtree + /// + /// # Panics + /// - plan contains invalid references pub fn track_shard_column_pos(&self, top_id: usize) -> Result<ShardColInfo, SbroadError> { let mut memo = ShardColInfo::with_capacity(REL_CAPACITY); let mut dfs = PostOrder::with_capacity(|x| self.nodes.rel_iter(x), REL_CAPACITY); @@ -1291,7 +1297,7 @@ impl Plan { Relational::ScanRelation { relation, .. } => { let table = self.get_relation_or_error(relation)?; if let Ok(Some(pos)) = table.get_bucket_id_position() { - memo.insert(node_id, vec![pos]); + memo.insert(node_id, [Some(pos), None]); } continue; } @@ -1333,28 +1339,37 @@ impl Plan { // we need that ALL targets would refer to the shard column. let mut refers_to_shard_col = true; for target in targets { - let child_id = children.get(*target).ok_or_else(|| { - SbroadError::Invalid( - Entity::Plan, - Some(format_smolstr!( - "invalid target ({target}) in reference with id: {ref_id}" - )), - ) - })?; - let Some(candidates) = memo.get(child_id) else { + let child_id = children.get(*target).expect("invalid reference"); + let Some(positions) = memo.get(child_id) else { refers_to_shard_col = false; break; }; - if !candidates.contains(position) { + if positions[0] != Some(*position) && positions[1] != Some(*position) { refers_to_shard_col = false; break; } } if refers_to_shard_col { - memo.entry(node_id) - .and_modify(|v| v.push(pos)) - .or_insert(vec![pos]); + match memo.entry(node_id) { + Entry::Occupied(mut entry) => { + let positions = entry.get_mut(); + if positions[0].is_none() { + positions[0] = Some(pos); + } else if positions[0] == Some(pos) { + // Do nothing, we already have this position. + } else if positions[1].is_none() { + positions[1] = Some(pos); + } else if positions[1] == Some(pos) { + // Do nothing, we already have this position. + } else { + unreachable!("more than 2 pointers in the reference"); + } + } + Entry::Vacant(entry) => { + entry.insert([Some(pos), None]); + } + } } } } diff --git a/sbroad-core/src/ir/aggregates.rs b/sbroad-core/src/ir/aggregates.rs index 4694764a7..f55403626 100644 --- a/sbroad-core/src/ir/aggregates.rs +++ b/sbroad-core/src/ir/aggregates.rs @@ -5,12 +5,12 @@ use crate::ir::expression::cast::Type; use crate::ir::expression::Expression; use crate::ir::operator::Arithmetic; use crate::ir::relation::Type as RelType; -use crate::ir::{Node, Plan}; +use crate::ir::{Node, Plan, Position}; use std::collections::HashMap; use std::fmt::{Display, Formatter}; use std::rc::Rc; -use super::expression::{ColumnPositionMap, FunctionFeature, Position}; +use super::expression::{ColumnPositionMap, FunctionFeature}; /// The kind of aggregate function /// diff --git a/sbroad-core/src/ir/expression.rs b/sbroad-core/src/ir/expression.rs index 2da84617e..1a42867f2 100644 --- a/sbroad-core/src/ir/expression.rs +++ b/sbroad-core/src/ir/expression.rs @@ -17,6 +17,7 @@ use crate::errors::{Entity, SbroadError}; use crate::ir::aggregates::AggregateKind; use crate::ir::operator::{Bool, Relational}; use crate::ir::relation::Type; +use crate::ir::Positions as Targets; use super::distribution::Distribution; use super::tree::traversal::{PostOrderWithFilter, EXPR_CAPACITY}; @@ -1218,9 +1219,9 @@ impl Plan { let mut filtered_children_row_list: Vec<(usize, usize, Vec<usize>)> = Vec::new(); // Helper lambda to retrieve column positions we need to exclude from child `rel_id`. - let column_positions_to_exclude = |rel_id| -> Result<Vec<Position>, SbroadError> { + let column_positions_to_exclude = |rel_id| -> Result<Targets, SbroadError> { let positions = if need_sharding_column { - vec![] + [None, None] } else { let mut info = self.track_shard_column_pos(rel_id)?; info.remove(&rel_id).unwrap_or_default() @@ -1261,7 +1262,7 @@ impl Plan { let col_id = *child_node_row_list .get(index) .expect("Column id not found under relational child output"); - if exclude_positions.contains(&index) { + if exclude_positions[0] == Some(index) || exclude_positions[1] == Some(index) { continue; } filtered_children_row_list.push((index, col_id, source.targets())); @@ -1284,7 +1285,7 @@ impl Plan { let exclude_positions = column_positions_to_exclude(child_node_id)?; for (pos, expr_id) in child_row_list.iter().enumerate() { - if exclude_positions.contains(&pos) { + if exclude_positions[0] == Some(pos) || exclude_positions[1] == Some(pos) { continue; } filtered_children_row_list.push((pos, *expr_id, new_targets.clone())); diff --git a/sbroad-core/src/ir/transformation/redistribution.rs b/sbroad-core/src/ir/transformation/redistribution.rs index d713f6d2a..c6ed46c46 100644 --- a/sbroad-core/src/ir/transformation/redistribution.rs +++ b/sbroad-core/src/ir/transformation/redistribution.rs @@ -496,8 +496,8 @@ impl Plan { ) })?; let child_id = self.get_relational_child(rel_id, *child_idx)?; - if let Some(candidates) = shard_col_info.get(&child_id) { - if !candidates.contains(ref_pos) { + if let Some(positions) = shard_col_info.get(&child_id) { + if positions[0] != Some(*ref_pos) && positions[1] != Some(*ref_pos) { continue; } if let Some(other_child_id) = memo.get(&pos_in_row) { @@ -754,12 +754,31 @@ impl Plan { } let bool_nodes = self.get_bool_nodes_with_row_children(filter_id); - let shard_col_info = self.track_shard_column_pos(select_id)?; for (_, bool_node) in &bool_nodes { let bool_op = BoolOp::from_expr(self, *bool_node)?; self.set_distribution(bool_op.left)?; self.set_distribution(bool_op.right)?; } + + // Check that we actually need to get sharding column positions (it is expensive). + let mut need_shard_col_info = false; + for (_, bool_node) in &bool_nodes { + let bool_op = BoolOp::from_expr(self, *bool_node)?; + if need_shard_col_info { + continue; + } + let left = self.get_additional_sq(select_id, bool_op.left)?; + let right = self.get_additional_sq(select_id, bool_op.right)?; + if left.is_some() || right.is_some() { + need_shard_col_info = true; + } + } + + let shard_col_info = if need_shard_col_info { + self.track_shard_column_pos(select_id)? + } else { + ShardColInfo::new() + }; for (_, bool_node) in &bool_nodes { let strategies = self.get_sq_node_strategies_for_bool_op(select_id, *bool_node, &shard_col_info)?; diff --git a/sbroad-core/src/ir/transformation/redistribution/groupby.rs b/sbroad-core/src/ir/transformation/redistribution/groupby.rs index 9ddeecf00..08e68daf9 100644 --- a/sbroad-core/src/ir/transformation/redistribution/groupby.rs +++ b/sbroad-core/src/ir/transformation/redistribution/groupby.rs @@ -1723,7 +1723,9 @@ impl Plan { }; let child_id = self.get_relational_from_reference_node(*expr_id)?; if let Some(shard_positions) = shard_col_info.get(child_id) { - if shard_positions.contains(position) { + if shard_positions[0] == Some(*position) + || shard_positions[1] == Some(*position) + { return Ok(false); } } diff --git a/sbroad-core/src/ir/transformation/redistribution/left_join.rs b/sbroad-core/src/ir/transformation/redistribution/left_join.rs index 064d3bb7f..ea0d83d0f 100644 --- a/sbroad-core/src/ir/transformation/redistribution/left_join.rs +++ b/sbroad-core/src/ir/transformation/redistribution/left_join.rs @@ -1,13 +1,12 @@ //! Left Join trasformation logic when outer child has Global distribution //! and inner child has Segment or Any distribution. -use smol_str::{format_smolstr, SmolStr}; +use smol_str::format_smolstr; use crate::{ errors::{Entity, SbroadError}, ir::{ distribution::Distribution, - expression::Expression, operator::{JoinKind, Relational}, Plan, }, @@ -91,42 +90,9 @@ impl Plan { } fn create_projection(plan: &mut Plan, join_id: usize) -> Result<usize, SbroadError> { - let proj_columns_names = collect_projection_columns(plan, join_id)?; - let proj_columns_refs: Vec<&str> = proj_columns_names.iter().map(SmolStr::as_str).collect(); - let proj_id = plan.add_proj(join_id, &proj_columns_refs, false, false)?; + let proj_id = plan.add_proj(join_id, &[], false, false)?; let output_id = plan.get_relational_output(proj_id)?; plan.replace_parent_in_subtree(output_id, Some(join_id), Some(proj_id))?; plan.set_distribution(output_id)?; Ok(proj_id) } - -// Returns a list of column aliases from join node output. -fn collect_projection_columns( - plan: &mut Plan, - join_id: usize, -) -> Result<Vec<SmolStr>, SbroadError> { - // TODO: currently we use all columns from joined tables, - // but it is possible that a lot of columns are not used - // above in the plan, we can remove unused columns to - // reduce amount of data sent through the network. - // https://git.picodata.io/picodata/picodata/sbroad/-/issues/36 - let output_id = plan.get_relational_output(join_id)?; - let columns_len = plan.get_row_list(output_id)?.len(); - let mut projection_columns: Vec<SmolStr> = Vec::with_capacity(columns_len); - for idx in 0..columns_len { - let expr_id = *plan.get_row_list(output_id)?.get(idx).ok_or_else(|| { - SbroadError::UnexpectedNumberOfValues("output row size changed".into()) - })?; - if let Expression::Alias { name, .. } = plan.get_expression_node(expr_id)? { - projection_columns.push(name.clone()); - } else { - return Err(SbroadError::Invalid( - Entity::Node, - Some(format_smolstr!( - "node ({join_id}) output columns is not alias" - )), - )); - } - } - Ok(projection_columns) -} -- GitLab