diff --git a/sbroad-core/src/backend/sql/tree/tests.rs b/sbroad-core/src/backend/sql/tree/tests.rs index 630548633f91c6161c28ea4dc50409f169b31977..8a2ac991e26db2e2a4dc0b58e628031beacd2255 100644 --- a/sbroad-core/src/backend/sql/tree/tests.rs +++ b/sbroad-core/src/backend/sql/tree/tests.rs @@ -50,6 +50,8 @@ fn sql_order_selection() { .join("sql_order_selection.yaml"); let s = fs::read_to_string(path).unwrap(); let expected_plan = Plan::from_yaml(&s).unwrap(); + // This field is not serialized, do not check it + plan.context = None; assert_eq!(expected_plan, plan); let exec_plan = ExecutionPlan::from(plan.clone()); let top_id = exec_plan.get_ir_plan().get_top().unwrap(); @@ -140,6 +142,8 @@ fn sql_arbitrary_projection_plan() { let proj_id = plan.add_proj_internal(scan_id, &[alias_id], false).unwrap(); plan.set_top(proj_id).unwrap(); + // this field is not serialized, do not check it + plan.context = None; // check the plan let path = Path::new("") diff --git a/sbroad-core/src/frontend/sql/ir/tests.rs b/sbroad-core/src/frontend/sql/ir/tests.rs index 69453eeb7b8ba08e10c0d54ff55721898315cf4a..b99736b656af1a0d3096847a2bddec03f3ec9d86 100644 --- a/sbroad-core/src/frontend/sql/ir/tests.rs +++ b/sbroad-core/src/frontend/sql/ir/tests.rs @@ -1,11 +1,12 @@ use crate::errors::SbroadError; -use crate::executor::engine::mock::RouterConfigurationMock; use crate::frontend::sql::ast::AbstractSyntaxTree; use crate::frontend::Ast; use crate::ir::operator::Relational; use crate::ir::transformation::helpers::sql_to_optimized_ir; use crate::ir::tree::traversal::PostOrder; use crate::ir::value::Value; +use crate::ir::Plan; +use crate::{executor::engine::mock::RouterConfigurationMock, ir::Positions}; use pretty_assertions::assert_eq; use time::{format_description, OffsetDateTime, Time}; @@ -798,6 +799,16 @@ vtable_max_rows = 5000 assert_eq!(expected_explain, plan.as_explain().unwrap()); } +impl Plan { + fn get_positions(&self, node_id: usize) -> Option<Positions> { + let mut context = 
self.context_mut(); + context + .get_shard_columns_positions(node_id, self) + .unwrap() + .copied() + } +} + #[test] fn track_shard_col_pos() { let input = r#" @@ -807,16 +818,15 @@ fn track_shard_col_pos() { "#; let plan = sql_to_optimized_ir(input, vec![]); let top = plan.get_top().unwrap(); - let map = plan.track_shard_column_pos(top).unwrap(); let mut dfs = PostOrder::with_capacity(|x| plan.nodes.rel_iter(x), 10); for (_, node_id) in dfs.iter(top) { let node = plan.get_relation_node(node_id).unwrap(); match node { Relational::ScanRelation { .. } | Relational::Selection { .. } => { - assert_eq!([Some(4_usize), None], *map.get(&node_id).unwrap()) + assert_eq!([Some(4_usize), None], plan.get_positions(node_id).unwrap()) } Relational::Projection { .. } => { - assert_eq!([Some(1_usize), None], *map.get(&node_id).unwrap()) + assert_eq!([Some(1_usize), None], plan.get_positions(node_id).unwrap()) } _ => {} } @@ -829,15 +839,20 @@ fn track_shard_col_pos() { "#; let plan = sql_to_optimized_ir(input, vec![]); let top = plan.get_top().unwrap(); - let map = plan.track_shard_column_pos(top).unwrap(); let mut dfs = PostOrder::with_capacity(|x| plan.nodes.rel_iter(x), 10); for (_, node_id) in dfs.iter(top) { let node = plan.get_relation_node(node_id).unwrap(); if let Relational::Join { .. 
} = node { - assert_eq!([Some(4_usize), Some(5_usize)], *map.get(&node_id).unwrap()); + assert_eq!( + [Some(4_usize), Some(5_usize)], + plan.get_positions(node_id).unwrap() + ); } } - assert_eq!([Some(0_usize), Some(1_usize)], *map.get(&top).unwrap()); + assert_eq!( + [Some(0_usize), Some(1_usize)], + plan.get_positions(top).unwrap() + ); let input = r#"select t_mv."bucket_id", "t2"."bucket_id" from "t2" join ( select "bucket_id" from "test_space" where "id" = 1 @@ -846,15 +861,14 @@ fn track_shard_col_pos() { "#; let plan = sql_to_optimized_ir(input, vec![]); let top = plan.get_top().unwrap(); - let map = plan.track_shard_column_pos(top).unwrap(); let mut dfs = PostOrder::with_capacity(|x| plan.nodes.rel_iter(x), 10); for (_, node_id) in dfs.iter(top) { let node = plan.get_relation_node(node_id).unwrap(); if let Relational::Join { .. } = node { - assert_eq!([Some(4_usize), None], *map.get(&node_id).unwrap()); + assert_eq!([Some(4_usize), None], plan.get_positions(node_id).unwrap()); } } - assert_eq!([Some(1_usize), None], *map.get(&top).unwrap()); + assert_eq!([Some(1_usize), None], plan.get_positions(top).unwrap()); let input = r#" select "bucket_id", "e" from "t2" @@ -863,8 +877,7 @@ fn track_shard_col_pos() { "#; let plan = sql_to_optimized_ir(input, vec![]); let top = plan.get_top().unwrap(); - let map = plan.track_shard_column_pos(top).unwrap(); - assert_eq!(None, map.get(&top)); + assert_eq!(None, plan.get_positions(top)); let input = r#" select "bucket_id", "e" from "t2" @@ -873,24 +886,21 @@ fn track_shard_col_pos() { "#; let plan = sql_to_optimized_ir(input, vec![]); let top = plan.get_top().unwrap(); - let map = plan.track_shard_column_pos(top).unwrap(); - assert_eq!([Some(0_usize), None], *map.get(&top).unwrap()); + assert_eq!([Some(0_usize), None], plan.get_positions(top).unwrap()); let input = r#" select "e" from (select "bucket_id" as "e" from "t2") "#; let plan = sql_to_optimized_ir(input, vec![]); let top = plan.get_top().unwrap(); - let map = 
plan.track_shard_column_pos(top).unwrap(); - assert_eq!([Some(0_usize), None], *map.get(&top).unwrap()); + assert_eq!([Some(0_usize), None], plan.get_positions(top).unwrap()); let input = r#" select "e" as "bucket_id" from "t2" "#; let plan = sql_to_optimized_ir(input, vec![]); let top = plan.get_top().unwrap(); - let map = plan.track_shard_column_pos(top).unwrap(); - assert_eq!(None, map.get(&top)); + assert_eq!(None, plan.get_positions(top)); } #[test] diff --git a/sbroad-core/src/ir.rs b/sbroad-core/src/ir.rs index 5c65dd49f7b5f957c29639b16979c766576ff8ef..aead78825415eea0912c8f1afad71da800d6b123 100644 --- a/sbroad-core/src/ir.rs +++ b/sbroad-core/src/ir.rs @@ -5,7 +5,8 @@ use base64ct::{Base64, Encoding}; use serde::{Deserialize, Serialize}; use smol_str::{format_smolstr, SmolStr, ToSmolStr}; -use std::collections::hash_map::{Entry, IntoIter}; +use std::cell::{RefCell, RefMut}; +use std::collections::hash_map::IntoIter; use std::collections::{HashMap, HashSet}; use std::fmt::{Display, Formatter}; @@ -422,6 +423,36 @@ pub struct Plan { /// See `apply_options`. pub options: Options, pub version_map: TableVersionMap, + /// Exists only on the router during plan build. + /// RefCell is used because context can be mutated + /// independently of the plan. It is just stored + /// in the plan for convenience: otherwise we'd + /// have to explicitly pass context to every method + /// of the pipeline. + #[serde(skip)] + pub context: Option<RefCell<BuildContext>>, +} + +/// Helper structures used to build the plan +/// on the router. +#[derive(Clone, Debug, PartialEq, Eq, Default)] +pub struct BuildContext { + shard_col_info: ShardColumnsMap, +} + +impl BuildContext { + /// Returns positions in node's output + /// referring to the shard column. 
+ /// + /// # Errors + /// - Invalid plan + pub fn get_shard_columns_positions( + &mut self, + node_id: NodeId, + plan: &Plan, + ) -> Result<Option<&Positions>, SbroadError> { + self.shard_col_info.get(node_id, plan) + } } impl Default for Plan { @@ -432,6 +463,17 @@ impl Default for Plan { #[allow(dead_code)] impl Plan { + /// Get mut reference to build context + /// + /// # Panics + /// - There are other mut refs + pub fn context_mut(&self) -> RefMut<'_, BuildContext> { + self.context + .as_ref() + .expect("context always exists during plan build") + .borrow_mut() + } + /// Add relation to the plan. /// /// If relation already exists, do nothing. @@ -476,6 +518,7 @@ impl Plan { options: Options::default(), version_map: TableVersionMap::new(), pg_params_map: HashMap::new(), + context: Some(RefCell::new(BuildContext::default())), } } @@ -1277,104 +1320,176 @@ pub type Positions = [Option<Position>; 2]; /// Relational node id -> positions of columns in output that refer to sharding column. pub type ShardColInfo = ahash::AHashMap<NodeId, Positions>; -impl Plan { - /// Helper function to track position of the sharding column - /// for any relational node in the subtree defined by `top_id`. +#[derive(Clone, Debug, PartialEq, Eq, Default)] +pub struct ShardColumnsMap { + /// Maps node id to positions of bucket_id column in + /// the node output. Currently we track only two + /// bucket_id columns appearances for perf reasons. + pub memo: ahash::AHashMap<NodeId, Positions>, + /// ids of nodes which were inserted into the middle + /// of the plan and changed the bucket_id columns + /// positions and thus invalidated all the nodes + /// in the memo which are located above this node. + pub invalid_ids: ahash::AHashSet<NodeId>, +} + +impl ShardColumnsMap { + /// Update information about node's sharding column positions + /// assuming that node's children positions were already computed. 
/// /// # Errors - /// - invalid references in the plan subtree + /// - invalid plan /// /// # Panics - /// - plan contains invalid references - pub fn track_shard_column_pos(&self, top_id: usize) -> Result<ShardColInfo, SbroadError> { - let mut memo = ShardColInfo::with_capacity(REL_CAPACITY); - let mut dfs = PostOrder::with_capacity(|x| self.nodes.rel_iter(x), REL_CAPACITY); - - for (_, node_id) in dfs.iter(top_id) { - let node = self.get_relation_node(node_id)?; - - match node { - Relational::ScanRelation { relation, .. } => { - let table = self.get_relation_or_error(relation)?; - if let Ok(Some(pos)) = table.get_bucket_id_position() { - memo.insert(node_id, [Some(pos), None]); - } - continue; + /// - invalid plan + pub fn update_node(&mut self, node_id: NodeId, plan: &Plan) -> Result<(), SbroadError> { + let node = plan.get_relation_node(node_id)?; + match node { + Relational::ScanRelation { relation, .. } => { + let table = plan.get_relation_or_error(relation)?; + if let Ok(Some(pos)) = table.get_bucket_id_position() { + self.memo.insert(node_id, [Some(pos), None]); } - Relational::Motion { policy, .. } => { - // Any motion node that moves data invalidates - // bucket_id column selected from that space. - // Even Segment policy is no help, because it only - // creates index on virtual table but does not actually - // add or update bucket_id column. - if !matches!(policy, MotionPolicy::Local | MotionPolicy::LocalSegment(_)) { - continue; - } + return Ok(()); + } + Relational::Motion { policy, .. } => { + // Any motion node that moves data invalidates + // bucket_id column selected from that space. + // Even Segment policy is no help, because it only + // creates index on virtual table but does not actually + // add or update bucket_id column. 
+ if !matches!(policy, MotionPolicy::Local | MotionPolicy::LocalSegment(_)) { + return Ok(()); } - _ => {} } + _ => {} + } - let children = node.children(); - if children.is_empty() { + let children = node.children(); + if children.is_empty() { + return Ok(()); + }; + let children_contain_shard_positions = children.iter().any(|c| self.memo.contains_key(c)); + if !children_contain_shard_positions { + // The children do not contain any shard columns, no need to check + // the output. + return Ok(()); + } + + let output_id = node.output(); + let output_len = plan.get_row_list(output_id)?.len(); + let mut new_positions = [None, None]; + for pos in 0..output_len { + let output = plan.get_row_list(output_id)?; + let alias_id = output.get(pos).expect("can't fail"); + let ref_id = plan.get_child_under_alias(*alias_id)?; + // If there is a parameter under alias + // and we haven't bound parameters yet, + // we will get an error. + let Ok(Expression::Reference { + targets, position, .. + }) = plan.get_expression_node(ref_id) + else { + continue; + }; + let Some(targets) = targets else { continue; }; - let output = self.get_row_list(node.output())?; - for (pos, alias_id) in output.iter().enumerate() { - let ref_id = self.get_child_under_alias(*alias_id)?; - // If there is a parameter under alias - // and we haven't bound parameters yet, - // we will get an error. - let Ok(Expression::Reference { - targets, position, .. - }) = self.get_expression_node(ref_id) - else { - continue; - }; - let Some(targets) = targets else { - continue; + let children = plan.get_relational_children(node_id)?; + // For node with multiple targets (Union, Except, Intersect) + // we need that ALL targets would refer to the shard column. 
+ let mut refers_to_shard_col = true; + for target in targets { + let child_id = children.get(*target).expect("invalid reference"); + let Some(positions) = self.memo.get(child_id) else { + refers_to_shard_col = false; + break; }; + if positions[0] != Some(*position) && positions[1] != Some(*position) { + refers_to_shard_col = false; + break; + } + } - // For node with multiple targets (Union, Except, Intersect) - // we need that ALL targets would refer to the shard column. - let mut refers_to_shard_col = true; - for target in targets { - let child_id = children.get(*target).expect("invalid reference"); - let Some(positions) = memo.get(child_id) else { - refers_to_shard_col = false; - break; - }; - if positions[0] != Some(*position) && positions[1] != Some(*position) { - refers_to_shard_col = false; - break; - } + if refers_to_shard_col { + if new_positions[0].is_none() { + new_positions[0] = Some(pos); + } else if new_positions[0] == Some(pos) { + // Do nothing, we already have this position. + } else { + new_positions[1] = Some(pos); + + // We already tracked two positions, + // the node may have more, but we assume + // that's really rare case and just don't + // want to allocate more memory to track them. + break; } + } + } + if new_positions[0].is_some() { + self.memo.insert(node_id, new_positions); + } + Ok(()) + } - if refers_to_shard_col { - match memo.entry(node_id) { - Entry::Occupied(mut entry) => { - let positions = entry.get_mut(); - if positions[0].is_none() { - positions[0] = Some(pos); - } else if positions[0] == Some(pos) { - // Do nothing, we already have this position. - } else if positions[1].is_none() { - positions[1] = Some(pos); - } else if positions[1] == Some(pos) { - // Do nothing, we already have this position. - } else { - unreachable!("more than 2 pointers in the reference"); - } - } - Entry::Vacant(entry) => { - entry.insert([Some(pos), None]); - } - } + /// Handle node insertion into the middle of the plan. 
+ /// Node insertion may invalidate already computed positions + /// for all the nodes located above it (on the path from root to + /// the inserted node). Currently only node that invalidates already + /// computed positions is Motion (non-local). + /// + /// # Errors + /// - Invalid plan + /// + /// # Panics + /// - invalid plan + pub fn handle_node_insertion( + &mut self, + node_id: NodeId, + plan: &Plan, + ) -> Result<(), SbroadError> { + let node = plan.get_relation_node(node_id)?; + if let Relational::Motion { + policy, children, .. + } = node + { + if matches!(policy, MotionPolicy::Local | MotionPolicy::LocalSegment(_)) { + return Ok(()); + } + let child_id = children.first().expect("invalid plan"); + if let Some(positions) = self.memo.get(child_id) { + if positions[0].is_some() || positions[1].is_some() { + self.invalid_ids.insert(node_id); } } } + Ok(()) + } - Ok(memo) + /// Get positions in the node's output which refer + /// to the sharding columns. + /// + /// # Errors + /// - Invalid plan + pub fn get(&mut self, id: NodeId, plan: &Plan) -> Result<Option<&Positions>, SbroadError> { + if !self.invalid_ids.is_empty() { + self.update_subtree(id, plan)?; + } + Ok(self.memo.get(&id)) + } + + fn update_subtree(&mut self, node_id: NodeId, plan: &Plan) -> Result<(), SbroadError> { + let mut dfs = PostOrder::with_capacity(|x| plan.nodes.rel_iter(x), REL_CAPACITY); + for (_, id) in dfs.iter(node_id) { + self.update_node(id, plan)?; + self.invalid_ids.remove(&id); + } + if plan.get_top()? 
!= node_id { + self.invalid_ids.insert(node_id); + } + Ok(()) } } diff --git a/sbroad-core/src/ir/expression.rs b/sbroad-core/src/ir/expression.rs index 1a42867f2d64176abc022b8f7e49678c2ed0eb37..1442e4bd5bba875c3496c4b35be40fde62f50963 100644 --- a/sbroad-core/src/ir/expression.rs +++ b/sbroad-core/src/ir/expression.rs @@ -1223,8 +1223,11 @@ impl Plan { let positions = if need_sharding_column { [None, None] } else { - let mut info = self.track_shard_column_pos(rel_id)?; - info.remove(&rel_id).unwrap_or_default() + let mut context = self.context_mut(); + context + .get_shard_columns_positions(rel_id, self)? + .copied() + .unwrap_or_default() }; Ok(positions) }; diff --git a/sbroad-core/src/ir/operator.rs b/sbroad-core/src/ir/operator.rs index b6594c194b7c6deebc4818b9ed53dd768c682767..b986325fd15f7d7a51808204494380d385d750ef 100644 --- a/sbroad-core/src/ir/operator.rs +++ b/sbroad-core/src/ir/operator.rs @@ -8,6 +8,7 @@ use ahash::RandomState; use crate::collection; use serde::{Deserialize, Serialize}; use smol_str::{format_smolstr, SmolStr, ToSmolStr}; +use std::borrow::BorrowMut; use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; use std::fmt::{Display, Formatter}; @@ -882,6 +883,17 @@ impl Relational { } impl Plan { + /// Add relational node to plan arena and update shard columns info. + /// + /// # Errors + /// - failed to update shard columns info due to invalid plan subtree + pub fn add_relational(&mut self, node: Relational) -> Result<usize, SbroadError> { + let rel_id = self.nodes.push(Node::Relational(node)); + let mut context = self.context_mut(); + context.shard_col_info.update_node(rel_id, self)?; + Ok(rel_id) + } + /// Adds delete node. 
/// /// # Errors @@ -893,7 +905,7 @@ impl Plan { children: vec![child_id], output, }; - let delete_id = self.nodes.push(Node::Relational(delete)); + let delete_id = self.add_relational(delete)?; self.replace_parent_in_subtree(output, None, Some(delete_id))?; Ok(delete_id) } @@ -928,7 +940,7 @@ impl Plan { output, }; - let except_id = self.nodes.push(Node::Relational(except)); + let except_id = self.add_relational(except)?; self.replace_parent_in_subtree(output, None, Some(except_id))?; Ok(except_id) } @@ -1178,7 +1190,7 @@ impl Plan { output: proj_output, is_distinct: false, }; - let proj_id = self.nodes.push(Node::Relational(proj_node)); + let proj_id = self.add_relational(proj_node)?; self.replace_parent_in_subtree(proj_output, None, Some(proj_id))?; let upd_output = self.add_row_for_output(proj_id, &[], false)?; let update_node = Relational::Update { @@ -1189,7 +1201,7 @@ impl Plan { output: upd_output, strategy: update_kind, }; - let update_id = self.nodes.push(Node::Relational(update_node)); + let update_id = self.add_relational(update_node)?; self.replace_parent_in_subtree(upd_output, None, Some(update_id))?; Ok(update_id) @@ -1307,7 +1319,7 @@ impl Plan { alias: alias.map(SmolStr::from), }; - let scan_id = nodes.push(Node::Relational(scan)); + let scan_id = self.add_relational(scan)?; self.replace_parent_in_subtree(output_id, None, Some(scan_id))?; return Ok(scan_id); } @@ -1433,7 +1445,7 @@ impl Plan { kind, }; - let join_id = self.nodes.push(Node::Relational(join)); + let join_id = self.add_relational(join)?; self.replace_parent_in_subtree(condition, None, Some(join_id))?; self.replace_parent_in_subtree(output, None, Some(join_id))?; return Ok(join_id); @@ -1497,8 +1509,13 @@ impl Plan { output, is_child_subquery, }; - let motion_id = self.nodes.push(Node::Relational(motion)); + let motion_id = self.add_relational(motion)?; self.replace_parent_in_subtree(output, None, Some(motion_id))?; + let mut context = self.context_mut(); + context + .shard_col_info + 
.borrow_mut() + .handle_node_insertion(motion_id, self)?; Ok(motion_id) } @@ -1524,7 +1541,7 @@ impl Plan { is_distinct, }; - let proj_id = self.nodes.push(Node::Relational(proj)); + let proj_id = self.add_relational(proj)?; self.replace_parent_in_subtree(output, None, Some(proj_id))?; Ok(proj_id) } @@ -1548,7 +1565,7 @@ impl Plan { is_distinct, }; - let proj_id = self.nodes.push(Node::Relational(proj)); + let proj_id = self.add_relational(proj)?; self.replace_parent_in_subtree(output, None, Some(proj_id))?; Ok(proj_id) } @@ -1582,7 +1599,7 @@ impl Plan { output, }; - let select_id = self.nodes.push(Node::Relational(select)); + let select_id = self.add_relational(select)?; self.replace_parent_in_subtree(filter, None, Some(select_id))?; self.replace_parent_in_subtree(output, None, Some(select_id))?; Ok(select_id) @@ -1626,7 +1643,7 @@ impl Plan { output, }; - let having_id = self.nodes.push(Node::Relational(having)); + let having_id = self.add_relational(having)?; self.replace_parent_in_subtree(filter, None, Some(having_id))?; self.replace_parent_in_subtree(output, None, Some(having_id))?; Ok(having_id) @@ -1652,7 +1669,7 @@ impl Plan { order_by_elements: order_by_elements.clone(), }; - let plan_order_by_id = self.nodes.push(Node::Relational(order_by)); + let plan_order_by_id = self.add_relational(order_by)?; for order_by_element in order_by_elements { if let OrderByElement { entity: OrderByEntity::Expression { expr_id }, @@ -1686,7 +1703,7 @@ impl Plan { output, }; - let sq_id = self.nodes.push(Node::Relational(sq)); + let sq_id = self.add_relational(sq)?; self.replace_parent_in_subtree(output, None, Some(sq_id))?; Ok(sq_id) } @@ -1757,7 +1774,7 @@ impl Plan { child: child_id, output, }; - let cte_id = self.nodes.push(Node::Relational(cte)); + let cte_id = self.add_relational(cte)?; Ok(cte_id) } @@ -1804,7 +1821,7 @@ impl Plan { } }; - let union_id = self.nodes.push(Node::Relational(union_all)); + let union_id = self.add_relational(union_all)?; 
self.replace_parent_in_subtree(output, None, Some(union_id))?; Ok(union_id) } @@ -1836,7 +1853,7 @@ impl Plan { data: row_id, children: vec![], }; - let values_row_id = self.nodes.push(Node::Relational(values_row)); + let values_row_id = self.add_relational(values_row)?; self.replace_parent_in_subtree(row_id, None, Some(values_row_id))?; Ok(values_row_id) } @@ -1916,7 +1933,7 @@ impl Plan { output, children: value_rows, }; - let values_id = self.nodes.push(Node::Relational(values)); + let values_id = self.add_relational(values)?; self.replace_parent_in_subtree(output, None, Some(values_id))?; Ok(values_id) } diff --git a/sbroad-core/src/ir/operator/tests.rs b/sbroad-core/src/ir/operator/tests.rs index 5816c9eac706c4de98dc51249c4b68f26f9dc752..faf8643e315ff85507c8eadea0b45494de679563 100644 --- a/sbroad-core/src/ir/operator/tests.rs +++ b/sbroad-core/src/ir/operator/tests.rs @@ -86,6 +86,8 @@ fn scan_rel_serialized() { .join("operator") .join("scan_rel.yaml"); let s = fs::read_to_string(path).unwrap(); + // This field is not serialized, do not check it + plan.context = None; assert_eq!(plan, Plan::from_yaml(&s).unwrap()); } @@ -491,6 +493,8 @@ fn selection_with_sub_query() { let s = fs::read_to_string(path).unwrap(); let expected_plan = Plan::from_yaml(&s).unwrap(); + // This field is not serialized, do not check it + plan.context = None; assert_eq!(expected_plan, plan); } diff --git a/sbroad-core/src/ir/transformation/redistribution.rs b/sbroad-core/src/ir/transformation/redistribution.rs index eb2a31d7c9821c748b57fcb12e703870b8761f9b..4aec1fc8dafb665ee02b64cc9ed3836ee810753a 100644 --- a/sbroad-core/src/ir/transformation/redistribution.rs +++ b/sbroad-core/src/ir/transformation/redistribution.rs @@ -20,7 +20,7 @@ use crate::ir::tree::traversal::{ BreadthFirst, LevelNode, PostOrder, PostOrderWithFilter, EXPR_CAPACITY, REL_CAPACITY, }; use crate::ir::value::Value; -use crate::ir::{Node, Plan, ShardColInfo}; +use crate::ir::{Node, Plan}; use crate::otm::child_span; 
use sbroad_proc::otm_child_span; @@ -447,7 +447,6 @@ impl Plan { right_row_id: usize, rel_id: usize, op: &Bool, - shard_col_info: &ShardColInfo, ) -> Result<bool, SbroadError> { if !(Bool::Eq == *op || Bool::In == *op) { return Ok(false); @@ -492,7 +491,8 @@ impl Plan { ) })?; let child_id = self.get_relational_child(rel_id, *child_idx)?; - if let Some(positions) = shard_col_info.get(&child_id) { + let mut context = self.context_mut(); + if let Some(positions) = context.get_shard_columns_positions(child_id, self)? { if positions[0] != Some(*ref_pos) && positions[1] != Some(*ref_pos) { continue; } @@ -629,7 +629,6 @@ impl Plan { &self, rel_id: usize, op_id: usize, - shard_col_info: &ShardColInfo, ) -> Result<Vec<(usize, MotionPolicy)>, SbroadError> { let mut strategies: Vec<(usize, MotionPolicy)> = Vec::new(); let bool_op = BoolOp::from_expr(self, op_id)?; @@ -639,13 +638,7 @@ impl Plan { // If we eq/in where both rows contain bucket_id in same position // we don't need Motion nodes. if (left.is_some() || right.is_some()) - && self.has_eq_on_bucket_id( - bool_op.left, - bool_op.right, - rel_id, - &bool_op.op, - shard_col_info, - )? + && self.has_eq_on_bucket_id(bool_op.left, bool_op.right, rel_id, &bool_op.op)? { if let Some(left_sq) = left { strategies.push((left_sq, MotionPolicy::None)); @@ -756,28 +749,8 @@ impl Plan { self.set_distribution(bool_op.right)?; } - // Check that we actually need to get sharding column positions (it is expensive). - let mut need_shard_col_info = false; for (_, bool_node) in &bool_nodes { - let bool_op = BoolOp::from_expr(self, *bool_node)?; - if need_shard_col_info { - continue; - } - let left = self.get_additional_sq(select_id, bool_op.left)?; - let right = self.get_additional_sq(select_id, bool_op.right)?; - if left.is_some() || right.is_some() { - need_shard_col_info = true; - } - } - - let shard_col_info = if need_shard_col_info { - self.track_shard_column_pos(select_id)? 
- } else { - ShardColInfo::new() - }; - for (_, bool_node) in &bool_nodes { - let strategies = - self.get_sq_node_strategies_for_bool_op(select_id, *bool_node, &shard_col_info)?; + let strategies = self.get_sq_node_strategies_for_bool_op(select_id, *bool_node)?; for (id, policy) in strategies { // In case we faced with `not ... in ...`, we // have to change motion policy to Full. @@ -1035,15 +1008,8 @@ impl Plan { join_id: usize, left_row_id: usize, right_row_id: usize, - shard_col_info: &ShardColInfo, ) -> Result<MotionPolicy, SbroadError> { - if self.has_eq_on_bucket_id( - left_row_id, - right_row_id, - join_id, - &Bool::Eq, - shard_col_info, - )? { + if self.has_eq_on_bucket_id(left_row_id, right_row_id, join_id, &Bool::Eq)? { return Ok(MotionPolicy::None); } @@ -1241,7 +1207,6 @@ impl Plan { (inner, outer) }; - let shard_col_info = self.track_shard_column_pos(rel_id)?; let mut inner_map: HashMap<usize, MotionPolicy> = HashMap::new(); let mut new_inner_policy = MotionPolicy::Full; let filter = |node_id: usize| -> bool { @@ -1291,8 +1256,7 @@ impl Plan { // Note, that we don't have to call `get_sq_node_strategy_for_unary_op` here, because // the only strategy it can return is `Motion::Full` for its child and all subqueries // are covered with `Motion::Full` by default. - let sq_strategies = - self.get_sq_node_strategies_for_bool_op(rel_id, node_id, &shard_col_info)?; + let sq_strategies = self.get_sq_node_strategies_for_bool_op(rel_id, node_id)?; let sq_strategies_len = sq_strategies.len(); for (id, policy) in sq_strategies { strategy.add_child(id, policy, Program::default()); @@ -1337,12 +1301,9 @@ impl Plan { Bool::Between => { unreachable!("Between in redistribution") } - Bool::Eq | Bool::In => self.join_policy_for_eq( - rel_id, - bool_op.left, - bool_op.right, - &shard_col_info, - )?, + Bool::Eq | Bool::In => { + self.join_policy_for_eq(rel_id, bool_op.left, bool_op.right)? 
+ } Bool::Gt | Bool::GtEq | Bool::Lt | Bool::LtEq | Bool::NotEq => { MotionPolicy::Full } @@ -1914,21 +1875,21 @@ impl Plan { // select "bucket_id" as a from t1 // except // select "bucket_id" as b from t1 - fn is_except_on_bucket_id( - &self, - rel_id: usize, - left_id: usize, - right_id: usize, - ) -> Result<bool, SbroadError> { - let shard_col_info = self.track_shard_column_pos(rel_id)?; - let Some(left_shard_positions) = shard_col_info.get(&left_id) else { + fn is_except_on_bucket_id(&self, left_id: usize, right_id: usize) -> Result<bool, SbroadError> { + let mut context = self.context_mut(); + let Some(left_shard_positions) = + context.get_shard_columns_positions(left_id, self)?.copied() + else { return Ok(false); }; - let Some(right_shard_positions) = shard_col_info.get(&right_id) else { + let Some(right_shard_positions) = context + .get_shard_columns_positions(right_id, self)? + .copied() + else { return Ok(false); }; - for l in left_shard_positions { - if right_shard_positions.contains(l) { + for l in &left_shard_positions { + if l.is_some() && right_shard_positions.contains(l) { return Ok(true); } } @@ -1955,7 +1916,7 @@ impl Plan { let left_dist = self.get_rel_distribution(left_id)?; let right_dist = self.get_rel_distribution(right_id)?; - if self.is_except_on_bucket_id(rel_id, left_id, right_id)? { + if self.is_except_on_bucket_id(left_id, right_id)? 
{ return Ok(map); } @@ -2083,7 +2044,7 @@ impl Plan { right: cloned_left_id, output: intersect_output_id, }; - let intersect_id = self.nodes.push(Node::Relational(intersect)); + let intersect_id = self.add_relational(intersect)?; self.change_child(except_id, right_id, intersect_id)?; diff --git a/sbroad-core/src/ir/transformation/redistribution/groupby.rs b/sbroad-core/src/ir/transformation/redistribution/groupby.rs index 9c891a705f71f85c1ee45c742b6460cb2a3a6265..28c9bb557be0581ac9b9da6981f808b336310ed4 100644 --- a/sbroad-core/src/ir/transformation/redistribution/groupby.rs +++ b/sbroad-core/src/ir/transformation/redistribution/groupby.rs @@ -683,7 +683,7 @@ impl Plan { is_final, }; - let groupby_id = self.nodes.push(Node::Relational(groupby)); + let groupby_id = self.add_relational(groupby)?; self.replace_parent_in_subtree(final_output, None, Some(groupby_id))?; for expr in grouping_exprs { @@ -1055,7 +1055,7 @@ impl Plan { children: vec![child_id], is_distinct: false, }; - let proj_id = self.nodes.push(Node::Relational(proj)); + let proj_id = self.add_relational(proj)?; for info in aggr_infos { // We take expressions inside aggregate functions from Final projection, // so we need to update parent @@ -1397,7 +1397,7 @@ impl Plan { output, }; self.replace_parent_in_subtree(output, None, Some(final_id))?; - self.nodes.push(Node::Relational(final_groupby)); + self.add_relational(final_groupby)?; Ok(final_id) } @@ -1718,14 +1718,17 @@ impl Plan { // Check for group by on bucket_id column // in that case groupby can be done locally. if !grouping_exprs.is_empty() { - let shard_col_info = self.track_shard_column_pos(final_proj_id)?; + // let shard_col_info = self.track_shard_column_pos(final_proj_id)?; for expr_id in &grouping_exprs { let Expression::Reference { position, .. } = self.get_expression_node(*expr_id)? 
else { continue; }; let child_id = self.get_relational_from_reference_node(*expr_id)?; - if let Some(shard_positions) = shard_col_info.get(child_id) { + let mut context = self.context_mut(); + if let Some(shard_positions) = + context.get_shard_columns_positions(*child_id, self)? + { if shard_positions[0] == Some(*position) || shard_positions[1] == Some(*position) { diff --git a/sbroad-core/src/ir/transformation/redistribution/tests.rs b/sbroad-core/src/ir/transformation/redistribution/tests.rs index 7b215e52a05d98dd53cdee5f491f548dae8de4b8..2771326d07073104976f75bb7de935397c63ec5b 100644 --- a/sbroad-core/src/ir/transformation/redistribution/tests.rs +++ b/sbroad-core/src/ir/transformation/redistribution/tests.rs @@ -69,6 +69,8 @@ fn full_motion_less_for_sub_query() { .join("full_motion_less_for_sub_query.yaml"); let s = fs::read_to_string(path).unwrap(); let expected_plan = Plan::from_yaml(&s).unwrap(); + // This field is not serialized, do not check it + plan.context = None; assert_eq!(plan, expected_plan); } @@ -132,6 +134,8 @@ fn full_motion_non_segment_outer_for_sub_query() { .join("full_motion_non_segment_outer_for_sub_query.yaml"); let s = fs::read_to_string(path).unwrap(); let expected_plan = Plan::from_yaml(&s).unwrap(); + // This field is not serialized, do not check it + plan.context = None; assert_eq!(plan, expected_plan); } @@ -192,6 +196,8 @@ fn local_sub_query() { .join("local_sub_query.yaml"); let s = fs::read_to_string(path).unwrap(); let expected_plan = Plan::from_yaml(&s).unwrap(); + // This field is not serialized, do not check it + plan.context = None; assert_eq!(plan, expected_plan); } @@ -273,6 +279,8 @@ fn multiple_sub_queries() { .join("multiple_sub_queries.yaml"); let s = fs::read_to_string(path).unwrap(); let expected_plan = Plan::from_yaml(&s).unwrap(); + // This field is not serialized, do not check it + plan.context = None; assert_eq!(plan, expected_plan); } diff --git a/sbroad-core/src/ir/transformation/redistribution/tests/segment.rs 
b/sbroad-core/src/ir/transformation/redistribution/tests/segment.rs index 0d3695357d4fce41ac60745cd1d778d67a74153d..fe89c1f62f7a96310ea47dcfbbd649d8bd413425 100644 --- a/sbroad-core/src/ir/transformation/redistribution/tests/segment.rs +++ b/sbroad-core/src/ir/transformation/redistribution/tests/segment.rs @@ -93,6 +93,8 @@ fn sub_query1() { .join("segment_motion_for_sub_query.yaml"); let s = fs::read_to_string(path).unwrap(); let expected_plan = Plan::from_yaml(&s).unwrap(); + // This field is not serialized, do not check it + plan.context = None; assert_eq!(plan, expected_plan); }