From 56310721aa8f48b8333057cdf624ec3e830403ab Mon Sep 17 00:00:00 2001 From: EmirVildanov <reddog201030@gmail.com> Date: Sun, 12 May 2024 19:20:59 +0300 Subject: [PATCH] feat: code re-usage refactoring --- sbroad-core/src/executor/ir.rs | 102 ++- sbroad-core/src/frontend/sql/ir.rs | 75 +- sbroad-core/src/ir/api/parameter.rs | 646 +++++++++++------- sbroad-core/src/ir/expression.rs | 66 +- sbroad-core/src/ir/helpers.rs | 202 +++--- sbroad-core/src/ir/operator.rs | 38 -- sbroad-core/src/ir/transformation.rs | 84 ++- .../src/ir/transformation/not_push_down.rs | 19 +- .../transformation/redistribution/groupby.rs | 10 +- 9 files changed, 680 insertions(+), 562 deletions(-) diff --git a/sbroad-core/src/executor/ir.rs b/sbroad-core/src/executor/ir.rs index 99511fd7e..cde6e9ad8 100644 --- a/sbroad-core/src/executor/ir.rs +++ b/sbroad-core/src/executor/ir.rs @@ -61,6 +61,34 @@ impl From<Plan> for ExecutionPlan { } } +/// Translates the original plan's node id to the new sub-plan one. +struct SubtreeMap { + inner: AHashMap<usize, usize>, +} + +impl SubtreeMap { + fn with_capacity(capacity: usize) -> Self { + SubtreeMap { + inner: AHashMap::with_capacity(capacity), + } + } + + fn get_id(&self, expr_id: usize) -> usize { + *self + .inner + .get(&expr_id) + .unwrap_or_else(|| panic!("Could not find expr with id {expr_id} in subtree map")) + } + + fn contains_key(&self, expr_id: usize) -> bool { + self.inner.contains_key(&expr_id) + } + + fn insert(&mut self, old_id: usize, new_id: usize) { + self.inner.insert(old_id, new_id); + } +} + impl ExecutionPlan { #[must_use] pub fn get_ir_plan(&self) -> &Plan { @@ -116,7 +144,7 @@ impl ExecutionPlan { )) } - /// Add materialize motion result to translation map of virtual tables + /// Add materialize motion result to map of virtual tables. /// /// # Errors /// - invalid motion node @@ -423,8 +451,7 @@ impl ExecutionPlan { } } - // Translates the original plan's node id to the new sub-plan one. 
- let mut translation: AHashMap<usize, usize> = AHashMap::with_capacity(nodes.len()); + let mut subtree_map = SubtreeMap::with_capacity(nodes.len()); let vtables_capacity = self.get_vtables().map_or_else(|| 1, HashMap::len); // Map of { plan node_id -> virtual table }. let mut new_vtables: HashMap<usize, Rc<VirtualTable>> = @@ -435,7 +462,7 @@ impl ExecutionPlan { for (_, node_id) in nodes { // We have already processed this node (sub-queries in BETWEEN // and CTEs can be referred twice). - if translation.contains_key(&node_id) { + if subtree_map.contains_key(node_id) { continue; } @@ -473,7 +500,7 @@ impl ExecutionPlan { // XXX: UNDO operation can cause problems if we introduce more complicated transformations // for filter/condition (but then the UNDO logic should be changed as well). let undo_expr_id = ir_plan.undo.get_oldest(expr_id).unwrap_or(expr_id); - *expr_id = *translation.get(undo_expr_id).unwrap_or_else(|| panic!("Could not find filter/condition node id {undo_expr_id} in the map.")); + *expr_id = subtree_map.get_id(*undo_expr_id); new_plan.replace_parent_in_subtree(*expr_id, None, Some(next_id))?; } Relational::ScanRelation { relation, .. } @@ -508,9 +535,7 @@ impl ExecutionPlan { Relational::GroupBy { gr_cols, .. 
} => { let mut new_cols: Vec<usize> = Vec::with_capacity(gr_cols.len()); for col_id in gr_cols.iter() { - let new_col_id = *translation.get(col_id).unwrap_or_else(|| { - panic!("Grouping column {col_id} in translation map.") - }); + let new_col_id = subtree_map.get_id(*col_id); new_plan.replace_parent_in_subtree( new_col_id, None, @@ -528,10 +553,7 @@ impl ExecutionPlan { for element in order_by_elements.iter() { let new_entity = match element.entity { OrderByEntity::Expression { expr_id } => { - let new_element_id = - *translation.get(&expr_id).unwrap_or_else(|| { - panic!("ORDER BY element {element:?} not found in translation map.") - }); + let new_element_id = subtree_map.get_id(expr_id); new_plan.replace_parent_in_subtree( new_element_id, None, @@ -553,9 +575,7 @@ impl ExecutionPlan { *order_by_elements = new_elements; } Relational::ValuesRow { data, .. } => { - *data = *translation.get(data).unwrap_or_else(|| { - panic!("Could not find data node id {data} in the map.") - }); + *data = subtree_map.get_id(*data); } Relational::Except { .. } | Relational::Intersect { .. } @@ -568,15 +588,11 @@ impl ExecutionPlan { } for child_id in rel.mut_children() { - *child_id = *translation.get(child_id).unwrap_or_else(|| { - panic!("Could not find child node id {child_id} in the map.") - }); + *child_id = subtree_map.get_id(*child_id); } let output = rel.output(); - *rel.mut_output() = *translation.get(&output).unwrap_or_else(|| { - panic!("Node not found as output node {output} in relational node {rel:?}.") - }); + *rel.mut_output() = subtree_map.get_id(output); new_plan.replace_parent_in_subtree(rel.output(), None, Some(next_id))?; } Node::Expression(ref mut expr) => match expr { @@ -584,9 +600,7 @@ impl ExecutionPlan { | Expression::ExprInParentheses { ref mut child } | Expression::Cast { ref mut child, .. } | Expression::Unary { ref mut child, .. 
} => { - *child = *translation.get(child).unwrap_or_else(|| { - panic!("Could not find child node id {child} in the map.") - }); + *child = subtree_map.get_id(*child); } Expression::Bool { ref mut left, @@ -603,12 +617,8 @@ impl ExecutionPlan { ref mut right, .. } => { - *left = *translation.get(left).unwrap_or_else(|| { - panic!("Could not find left child node id {left} in the map.") - }); - *right = *translation.get(right).unwrap_or_else(|| { - panic!("Could not find right child node id {right} in the map.") - }); + *left = subtree_map.get_id(*left); + *right = subtree_map.get_id(*right); } Expression::Trim { ref mut pattern, @@ -616,13 +626,9 @@ impl ExecutionPlan { .. } => { if let Some(pattern) = pattern { - *pattern = *translation.get(pattern).unwrap_or_else(|| { - panic!("Could not find pattern node id {pattern} in the map.") - }); + *pattern = subtree_map.get_id(*pattern); } - *target = *translation.get(target).unwrap_or_else(|| { - panic!("Could not find target node id {target} in the map.") - }); + *target = subtree_map.get_id(*target); } Expression::Reference { ref mut parent, .. } => { // The new parent node id MUST be set while processing the relational nodes. @@ -636,9 +642,7 @@ impl ExecutionPlan { ref mut children, .. } => { for child in children { - *child = *translation.get(child).unwrap_or_else(|| { - panic!("Could not find child node id {child} in the map.") - }); + *child = subtree_map.get_id(*child); } } Expression::Constant { .. } | Expression::CountAsterisk => {} @@ -648,24 +652,14 @@ impl ExecutionPlan { else_expr, } => { if let Some(search_expr) = search_expr { - *search_expr = *translation.get(search_expr).unwrap_or_else(|| { - panic!("Could not find search expression {search_expr} in the map.") - }); + *search_expr = subtree_map.get_id(*search_expr); } for (cond_expr, res_expr) in when_blocks { - *cond_expr = *translation.get(cond_expr).unwrap_or_else(|| { - panic!( - "Could not find cond WHEN expression {cond_expr} in the map." 
- ) - }); - *res_expr = *translation.get(res_expr).unwrap_or_else(|| { - panic!("Could not find res THEN expression {res_expr} in the map.") - }); + *cond_expr = subtree_map.get_id(*cond_expr); + *res_expr = subtree_map.get_id(*res_expr); } if let Some(else_expr) = else_expr { - *else_expr = *translation.get(else_expr).unwrap_or_else(|| { - panic!("Could not find else expression {else_expr} in the map.") - }); + *else_expr = subtree_map.get_id(*else_expr); } } }, @@ -675,7 +669,7 @@ impl ExecutionPlan { } } new_plan.nodes.push(node); - translation.insert(node_id, next_id); + subtree_map.insert(node_id, next_id); if top_id == node_id { new_plan.set_top(next_id)?; } diff --git a/sbroad-core/src/frontend/sql/ir.rs b/sbroad-core/src/frontend/sql/ir.rs index 24ce1df60..33b9e01bb 100644 --- a/sbroad-core/src/frontend/sql/ir.rs +++ b/sbroad-core/src/frontend/sql/ir.rs @@ -14,7 +14,7 @@ use crate::ir::transformation::redistribution::MotionOpcode; use crate::ir::tree::traversal::{PostOrder, EXPR_CAPACITY}; use crate::ir::value::double::Double; use crate::ir::value::Value; -use crate::ir::{Node, Plan}; +use crate::ir::{Node, NodeId, Plan}; use super::Between; @@ -130,6 +130,35 @@ impl SubQuery { } } +struct CloneExprSubtreeMap { + // Map of { old_node_id -> new_node_id } for cloning nodes. 
+ inner: AHashMap<NodeId, NodeId>, +} + +impl CloneExprSubtreeMap { + fn with_capacity(capacity: usize) -> Self { + CloneExprSubtreeMap { + inner: AHashMap::with_capacity(capacity), + } + } + + fn insert(&mut self, old_id: usize, new_id: usize) { + self.inner.insert(old_id, new_id); + } + + fn replace(&self, id: &mut usize) { + let new_id = self.get(*id); + *id = new_id; + } + + fn get(&self, id: usize) -> usize { + *self + .inner + .get(&id) + .unwrap_or_else(|| panic!("Node with id {id} not found in the cloning subtree map.")) + } +} + impl Plan { fn gather_sq_for_replacement(&self) -> Result<HashSet<SubQuery, RepeatableState>, SbroadError> { let mut set: HashSet<SubQuery, RepeatableState> = HashSet::with_hasher(RepeatableState); @@ -291,11 +320,11 @@ impl Plan { } pub(crate) fn clone_expr_subtree(&mut self, top_id: usize) -> Result<usize, SbroadError> { - let mut map = HashMap::new(); let mut subtree = PostOrder::with_capacity(|node| self.nodes.expr_iter(node, false), EXPR_CAPACITY); subtree.populate_nodes(top_id); let nodes = subtree.take_nodes(); + let mut map = CloneExprSubtreeMap::with_capacity(nodes.len()); for (_, id) in nodes { let next_id = self.nodes.next_id(); let mut expr = self.get_expression_node(id)?.clone(); @@ -306,11 +335,7 @@ impl Plan { Expression::Alias { ref mut child, .. } | Expression::ExprInParentheses { ref mut child } | Expression::Cast { ref mut child, .. } - | Expression::Unary { ref mut child, .. } => { - *child = *map.get(child).ok_or_else(|| { - SbroadError::NotFound(Entity::SubTree, format_smolstr!("(id {id})")) - })?; - } + | Expression::Unary { ref mut child, .. } => map.replace(child), Expression::Bool { ref mut left, ref mut right, @@ -326,12 +351,8 @@ impl Plan { ref mut right, .. 
} => { - *left = *map.get(left).ok_or_else(|| { - SbroadError::NotFound(Entity::SubTree, format_smolstr!("(id {id})")) - })?; - *right = *map.get(right).ok_or_else(|| { - SbroadError::NotFound(Entity::SubTree, format_smolstr!("(id {id})")) - })?; + map.replace(left); + map.replace(right); } Expression::Trim { ref mut pattern, @@ -339,13 +360,9 @@ impl Plan { .. } => { if let Some(pattern) = pattern { - *pattern = *map.get(pattern).ok_or_else(|| { - SbroadError::NotFound(Entity::SubTree, format_smolstr!("(id {id})")) - })?; + map.replace(pattern); } - *target = *map.get(target).ok_or_else(|| { - SbroadError::NotFound(Entity::SubTree, format_smolstr!("(id {id})")) - })?; + map.replace(target); } Expression::Row { list: ref mut children, @@ -355,9 +372,7 @@ impl Plan { ref mut children, .. } => { for child in children { - *child = *map.get(child).ok_or_else(|| { - SbroadError::NotFound(Entity::SubTree, format_smolstr!("(id {id})")) - })?; + map.replace(child); } } Expression::Case { @@ -366,22 +381,14 @@ impl Plan { ref mut else_expr, } => { if let Some(search_expr) = search_expr { - *search_expr = *map.get(search_expr).unwrap_or_else(|| { - panic!("Search expression not found for subtree cloning.") - }); + map.replace(search_expr); } for (cond_expr, res_expr) in when_blocks { - *cond_expr = *map.get(cond_expr).unwrap_or_else(|| { - panic!("Condition expression not found for subtree cloning.") - }); - *res_expr = *map.get(res_expr).unwrap_or_else(|| { - panic!("Result expression not found for subtree cloning.") - }); + map.replace(cond_expr); + map.replace(res_expr); } if let Some(else_expr) = else_expr { - *else_expr = *map.get(else_expr).unwrap_or_else(|| { - panic!("Else expression not found for subtree cloning.") - }); + map.replace(else_expr); } } } diff --git a/sbroad-core/src/ir/api/parameter.rs b/sbroad-core/src/ir/api/parameter.rs index 61e9be554..5b5e50e71 100644 --- a/sbroad-core/src/ir/api/parameter.rs +++ b/sbroad-core/src/ir/api/parameter.rs @@ -1,173 
+1,202 @@ -use crate::errors::{Entity, SbroadError}; +use crate::errors::SbroadError; use crate::ir::block::Block; use crate::ir::expression::Expression; use crate::ir::operator::Relational; -use crate::ir::tree::traversal::PostOrder; +use crate::ir::tree::traversal::{LevelNode, PostOrder}; use crate::ir::value::Value; -use crate::ir::{Node, OptionParamValue, Plan}; +use crate::ir::{Node, NodeId, OptionParamValue, Plan, ValueIdx}; use crate::otm::child_span; use sbroad_proc::otm_child_span; -use ahash::RandomState; -use smol_str::format_smolstr; -use std::collections::{HashMap, HashSet}; +use crate::ir::relation::Type; +use ahash::{AHashMap, AHashSet, RandomState}; +use std::collections::HashMap; -impl Plan { - pub fn add_param(&mut self) -> usize { - self.nodes.push(Node::Parameter) - } +struct ParamsBinder<'binder> { + plan: &'binder mut Plan, + /// Plan nodes to traverse during binding. + nodes: Vec<LevelNode>, + /// Number of parameters met in the OPTIONs. + binded_options_counter: usize, + /// Flag indicating whether we use Tarantool parameters notation. + tnt_params_style: bool, + /// Map of { plan param_id -> corresponding value }. + pg_params_map: HashMap<NodeId, ValueIdx>, + /// Plan nodes that correspond to Parameters. + param_node_ids: AHashSet<NodeId>, + /// Params transformed into constant Values. + value_ids: Vec<NodeId>, + /// Values that should be bind. + values: Vec<Value>, + /// We need to use rows instead of values in some cases (AST can solve + /// this problem for non-parameterized queries, but for parameterized + /// queries it is IR responsibility). + /// + /// Map of { param_id -> corresponding row }. + row_map: AHashMap<usize, usize, RandomState>, +} - // Gather all parameter nodes from the tree to a hash set. 
- #[must_use] - pub fn get_param_set(&self) -> HashSet<usize> { - let param_set: HashSet<usize> = self - .nodes - .arena - .iter() - .enumerate() - .filter_map(|(id, node)| { - if let Node::Parameter = node { - Some(id) - } else { - None - } +fn get_param_value( + tnt_params_style: bool, + param_id: usize, + param_index: usize, + value_ids: &[usize], + pg_params_map: &HashMap<NodeId, ValueIdx>, +) -> usize { + let value_index = if tnt_params_style { + // In case non-pg params are used, index is the correct position + param_index + } else { + value_ids.len() + - 1 + - *pg_params_map.get(&param_id).unwrap_or_else(|| { + panic!("Value index not found for parameter with id: {param_id}.") }) - .collect(); - param_set - } - - /// Substitute parameters to the plan. - /// The purpose of this function is to find every `Parameter` node and replace it - /// with `Expression::Constant` (under the row). - /// - /// # Errors - /// - Invalid amount of parameters. - /// - Internal errors. - #[allow(clippy::too_many_lines)] - #[otm_child_span("plan.bind")] - pub fn bind_params(&mut self, mut values: Vec<Value>) -> Result<(), SbroadError> { - // Nothing to do here. 
- if values.is_empty() { - return Ok(()); - } + }; + let val_id = value_ids + .get(value_index) + .unwrap_or_else(|| panic!("Parameter not found in position {value_index}.")); + *val_id +} - let capacity = self.next_id(); - let mut tree = PostOrder::with_capacity(|node| self.subtree_iter(node, false), capacity); - let top_id = self.get_top()?; +impl<'binder> ParamsBinder<'binder> { + fn new(plan: &'binder mut Plan, mut values: Vec<Value>) -> Result<Self, SbroadError> { + let capacity = plan.next_id(); + let mut tree = PostOrder::with_capacity(|node| plan.subtree_iter(node, false), capacity); + let top_id = plan.get_top()?; tree.populate_nodes(top_id); let nodes = tree.take_nodes(); - let mut binded_params_counter = 0; - if !self.raw_options.is_empty() { - binded_params_counter = self.bind_option_params(&mut values)?; + let mut binded_options_counter = 0; + if !plan.raw_options.is_empty() { + binded_options_counter = plan.bind_option_params(&mut values); } - // Gather all parameter nodes from the tree to a hash set. - // `param_node_ids` is used during first plan traversal (`row_ids` populating). - // `param_node_ids_cloned` is used during second plan traversal (nodes transformation). - let mut param_node_ids = self.get_param_set(); - let mut param_node_ids_cloned = param_node_ids.clone(); + let param_node_ids = plan.get_param_set(); + let tnt_params_style = plan.pg_params_map.is_empty(); + let pg_params_map = std::mem::take(&mut plan.pg_params_map); - let tnt_params_style = self.pg_params_map.is_empty(); - - let mut pg_params_map = std::mem::take(&mut self.pg_params_map); + let binder = ParamsBinder { + plan, + nodes, + binded_options_counter, + tnt_params_style, + pg_params_map, + param_node_ids, + value_ids: Vec::new(), + values, + row_map: AHashMap::new(), + }; + Ok(binder) + } - if !tnt_params_style { + /// Copy values to bind for Postgres-style parameters. 
+ fn handle_pg_parameters(&mut self) -> Result<(), SbroadError> { + if !self.tnt_params_style { // Due to how we calculate hash for plan subtree and the // fact that pg parameters can refer to same value multiple // times we currently copy params that are referred more // than once in order to get the same hash. // See https://git.picodata.io/picodata/picodata/sbroad/-/issues/583 - let mut used_values = vec![false; values.len()]; + let mut used_values = vec![false; self.values.len()]; let invalid_idx = |param_id: usize, value_idx: usize| { - SbroadError::Invalid( - Entity::Plan, - Some(format_smolstr!( - "out of bounds value index {value_idx} for pg parameter {param_id}" - )), - ) + panic!("Out of bounds value index {value_idx} for pg parameter {param_id}."); }; // NB: we can't use `param_node_ids`, we need to traverse // parameters in the same order they will be bound, // otherwise we may get different hashes for plans // with tnt and pg parameters. See `subtree_hash*` tests, - for (_, param_id) in &nodes { - if !matches!(self.get_node(*param_id)?, Node::Parameter) { + for (_, param_id) in &self.nodes { + if !matches!(self.plan.get_node(*param_id)?, Node::Parameter) { continue; } - let value_idx = *pg_params_map.get(param_id).ok_or(SbroadError::Invalid( - Entity::Plan, - Some(format_smolstr!( - "value index not found for parameter with id: {param_id}", - )), - ))?; + let value_idx = *self.pg_params_map.get(param_id).unwrap_or_else(|| { + panic!("Value index not found for parameter with id: {param_id}."); + }); if used_values.get(value_idx).copied().unwrap_or(true) { - let Some(value) = values.get(value_idx) else { - return Err(invalid_idx(*param_id, value_idx)); + let Some(value) = self.values.get(value_idx) else { + invalid_idx(*param_id, value_idx) }; - values.push(value.clone()); - pg_params_map + self.values.push(value.clone()); + self.pg_params_map .entry(*param_id) - .and_modify(|value_idx| *value_idx = values.len() - 1); + .and_modify(|value_idx| *value_idx 
= self.values.len() - 1); } else if let Some(used) = used_values.get_mut(value_idx) { *used = true; } else { - return Err(invalid_idx(*param_id, value_idx)); + invalid_idx(*param_id, value_idx) } } } + Ok(()) + } - // Transform parameters to values (plan constants). The result values are stored in the - // opposite to parameters order. - let mut value_ids: Vec<usize> = Vec::with_capacity(values.len()); - while let Some(param) = values.pop() { - value_ids.push(self.add_const(param)); + /// Transform parameters (passed by user) to values (plan constants). + /// The result values are stored in the opposite to parameters order. + /// + /// In case some redundant params were passed, they'll + /// be ignored (just not popped from the `value_ids` stack later). + fn create_parameter_constants(&mut self) { + self.value_ids = Vec::with_capacity(self.values.len()); + while let Some(param) = self.values.pop() { + self.value_ids.push(self.plan.add_const(param)); } + } + + /// Check that number of user passed params equal to the params nodes we have to bind. + fn check_params_count(&self) { + let non_binded_params_len = self.param_node_ids.len() - self.binded_options_counter; + assert!( + !(self.tnt_params_style && non_binded_params_len > self.value_ids.len()), + "Expected at least {} values for parameters. Got {}.", + non_binded_params_len, + self.value_ids.len() + ); + } - // We need to use rows instead of values in some cases (AST can solve - // this problem for non-parameterized queries, but for parameterized - // queries it is IR responsibility). - let mut row_ids: HashMap<usize, usize, RandomState> = - HashMap::with_hasher(RandomState::new()); - - let non_binded_params_len = param_node_ids.len() - binded_params_counter; - if tnt_params_style && non_binded_params_len > value_ids.len() { - return Err(SbroadError::Invalid( - Entity::Value, - Some(format_smolstr!( - "Expected at least {} values for parameters. 
Got {}.", - non_binded_params_len, - value_ids.len() - )), - )); + /// Retrieve a corresponding value (plan constant node) for a parameter node. + fn get_param_value(&self, param_id: usize, param_index: usize) -> usize { + get_param_value( + self.tnt_params_style, + param_id, + param_index, + &self.value_ids, + &self.pg_params_map, + ) + } + + /// 1.) Increase binding param index. + /// 2.) In case `cover_with_row` is set to true, cover the param node with a row. + fn cover_param_with_row( + &self, + param_id: usize, + cover_with_row: bool, + param_index: &mut usize, + row_ids: &mut HashMap<usize, usize, RandomState>, + ) { + if self.param_node_ids.contains(&param_id) { + if row_ids.contains_key(&param_id) { + return; + } + *param_index = param_index.saturating_sub(1); + if cover_with_row { + let val_id = self.get_param_value(param_id, *param_index); + row_ids.insert(param_id, val_id); + } } + } - // Populate rows. - // Number of parameters - `idx` - 1 = index in params we are currently binding. - // Initially pointing to nowhere. - let mut idx = value_ids.len(); + /// Traverse the plan nodes tree and cover parameter nodes with rows if needed. + #[allow(clippy::too_many_lines)] + fn cover_params_with_rows(&mut self) -> Result<(), SbroadError> { + // Len of `value_ids` - `param_index` = param index we are currently binding. 
+ let mut param_index = self.value_ids.len(); - let get_value = |param_id: usize, idx: usize| -> usize { - let value_idx = if tnt_params_style { - // in case non-pg params are used, - // idx is the correct position - idx - } else { - value_ids.len() - - 1 - - *pg_params_map.get(&param_id).unwrap_or_else(|| { - panic!("Value index not found for parameter with id: {param_id}.") - }) - }; - let val_id = value_ids - .get(value_idx) - .unwrap_or_else(|| panic!("Parameter not found in position {value_idx}.")); - *val_id - }; + let mut row_ids = HashMap::with_hasher(RandomState::new()); - for (_, id) in &nodes { - let node = self.get_node(*id)?; + for (_, id) in &self.nodes { + let node = self.plan.get_node(*id)?; match node { // Note: Parameter may not be met at the top of relational operators' expression // trees such as OrderBy and GroupBy, because it won't influence ordering and @@ -185,11 +214,7 @@ impl Plan { condition: ref param_id, .. } => { - if param_node_ids.take(param_id).is_some() { - idx = idx.saturating_sub(1); - let val_id = get_value(*param_id, idx); - row_ids.insert(*param_id, self.nodes.add_row(vec![val_id], None)); - } + self.cover_param_with_row(*param_id, true, &mut param_index, &mut row_ids); } _ => {} }, @@ -209,9 +234,7 @@ impl Plan { child: ref param_id, .. 
} => { - if param_node_ids.take(param_id).is_some() { - idx = idx.saturating_sub(1); - } + self.cover_param_with_row(*param_id, false, &mut param_index, &mut row_ids); } Expression::Bool { ref left, @@ -228,11 +251,12 @@ impl Plan { ref right, } => { for param_id in &[*left, *right] { - if param_node_ids.take(param_id).is_some() { - idx = idx.saturating_sub(1); - let val_id = get_value(*param_id, idx); - row_ids.insert(*param_id, self.nodes.add_row(vec![val_id], None)); - } + self.cover_param_with_row( + *param_id, + true, + &mut param_index, + &mut row_ids, + ); } } Expression::Trim { @@ -245,11 +269,12 @@ impl Plan { None => [None, Some(*target)], }; for param_id in params.into_iter().flatten() { - if param_node_ids.take(&param_id).is_some() { - idx = idx.saturating_sub(1); - let val_id = get_value(param_id, idx); - row_ids.insert(param_id, self.nodes.add_row(vec![val_id], None)); - } + self.cover_param_with_row( + param_id, + true, + &mut param_index, + &mut row_ids, + ); } } Expression::Row { ref list, .. } @@ -257,11 +282,14 @@ impl Plan { children: ref list, .. } => { for param_id in list { - if param_node_ids.take(param_id).is_some() { - // Parameter is already under row/function so that we don't - // have to cover it with `add_row` call. - idx = idx.saturating_sub(1); - } + // Parameter is already under row/function so that we don't + // have to cover it with `add_row` call. 
+ self.cover_param_with_row( + *param_id, + false, + &mut param_index, + &mut row_ids, + ); } } Expression::Case { @@ -270,24 +298,34 @@ impl Plan { ref else_expr, } => { if let Some(search_expr) = search_expr { - if param_node_ids.take(search_expr).is_some() { - idx = idx.saturating_sub(1); - } + self.cover_param_with_row( + *search_expr, + false, + &mut param_index, + &mut row_ids, + ); } - for (cond_expr, res_expr) in when_blocks { - if param_node_ids.take(cond_expr).is_some() { - idx = idx.saturating_sub(1); - } - if param_node_ids.take(res_expr).is_some() { - idx = idx.saturating_sub(1); - } + self.cover_param_with_row( + *cond_expr, + false, + &mut param_index, + &mut row_ids, + ); + self.cover_param_with_row( + *res_expr, + false, + &mut param_index, + &mut row_ids, + ); } - if let Some(else_expr) = else_expr { - if param_node_ids.take(else_expr).is_some() { - idx = idx.saturating_sub(1); - } + self.cover_param_with_row( + *else_expr, + false, + &mut param_index, + &mut row_ids, + ); } } Expression::Reference { .. } @@ -297,11 +335,14 @@ impl Plan { Node::Block(block) => match block { Block::Procedure { ref values, .. } => { for param_id in values { - if param_node_ids.take(param_id).is_some() { - // We don't need to wrap arguments, passed into the - // procedure call, into the rows. - idx = idx.saturating_sub(1); - } + // We don't need to wrap arguments, passed into the + // procedure call, into the rows. + self.cover_param_with_row( + *param_id, + false, + &mut param_index, + &mut row_ids, + ); } } }, @@ -309,38 +350,71 @@ impl Plan { } } - // Closure to retrieve a corresponding row for a parameter node. 
- let get_row = |param_id: usize| -> Result<usize, SbroadError> { - let row_id = row_ids.get(&param_id).ok_or_else(|| { - SbroadError::NotFound( - Entity::Node, - format_smolstr!("(Row) at position {param_id}"), - ) - })?; - Ok(*row_id) - }; + let fixed_row_ids: AHashMap<usize, usize, RandomState> = row_ids + .iter() + .map(|(param_id, val_id)| { + let row_cover = self.plan.nodes.add_row(vec![*val_id], None); + (*param_id, row_cover) + }) + .collect(); + self.row_map = fixed_row_ids; + + Ok(()) + } - // Replace parameters in the plan. - idx = value_ids.len(); - for (_, id) in &nodes { + /// Replace parameters in the plan. + #[allow(clippy::too_many_lines)] + fn bind_params(&mut self) -> Result<(), SbroadError> { + let mut exprs_to_set_ref_type: HashMap<usize, Type> = HashMap::new(); + + for (_, id) in &self.nodes { // Before binding, references that referred to // parameters had scalar type (by default), // but in fact they may refer to different stuff. - { - let mut new_type = None; - if let Node::Expression(expr) = self.get_node(*id)? { - if let Expression::Reference { .. } = expr { - new_type = Some(expr.recalculate_type(self)?); - } - } - if let Some(new_type) = new_type { - let expr = self.get_mut_expression_node(*id)?; - expr.set_ref_type(new_type); + if let Node::Expression(expr) = self.plan.get_node(*id)? { + if let Expression::Reference { .. } = expr { + exprs_to_set_ref_type.insert(*id, expr.recalculate_type(self.plan)?); continue; } } + } + for (id, new_type) in exprs_to_set_ref_type { + let expr = self.plan.get_mut_expression_node(id)?; + expr.set_ref_type(new_type); + } + + // Len of `value_ids` - `param_index` = param index we are currently binding. 
+ let mut param_index = self.value_ids.len(); - let node = self.get_mut_node(*id)?; + let tnt_params_style = self.tnt_params_style; + let row_ids = std::mem::take(&mut self.row_map); + let value_ids = std::mem::take(&mut self.value_ids); + let pg_params_map = std::mem::take(&mut self.pg_params_map); + + let bind_param = |param_id: &mut usize, is_row: bool, param_index: &mut usize| { + *param_id = if self.param_node_ids.contains(param_id) { + *param_index = param_index.saturating_sub(1); + let binding_node_id = if is_row { + *row_ids + .get(param_id) + .unwrap_or_else(|| panic!("Row not found at position {param_id}")) + } else { + get_param_value( + tnt_params_style, + *param_id, + *param_index, + &value_ids, + &pg_params_map, + ) + }; + binding_node_id + } else { + *param_id + } + }; + + for (_, id) in &self.nodes { + let node = self.plan.get_mut_node(*id)?; match node { Node::Relational(rel) => match rel { Relational::Having { @@ -355,11 +429,7 @@ impl Plan { condition: ref mut param_id, .. } => { - if param_node_ids_cloned.take(param_id).is_some() { - idx = idx.saturating_sub(1); - let row_id = get_row(*param_id)?; - *param_id = row_id; - } + bind_param(param_id, true, &mut param_index); } _ => {} }, @@ -379,11 +449,7 @@ impl Plan { child: ref mut param_id, .. 
} => { - if param_node_ids_cloned.take(param_id).is_some() { - idx = idx.saturating_sub(1); - let val_id = get_value(*param_id, idx); - *param_id = val_id; - } + bind_param(param_id, false, &mut param_index); } Expression::Bool { ref mut left, @@ -399,12 +465,8 @@ impl Plan { ref mut left, ref mut right, } => { - for param_id in &mut [left, right].iter_mut() { - if param_node_ids_cloned.take(param_id).is_some() { - idx = idx.saturating_sub(1); - let row_id = get_row(**param_id)?; - **param_id = row_id; - } + for param_id in [left, right] { + bind_param(param_id, true, &mut param_index); } } Expression::Trim { @@ -417,11 +479,7 @@ impl Plan { None => [None, Some(target)], }; for param_id in params.into_iter().flatten() { - if param_node_ids_cloned.take(param_id).is_some() { - idx = idx.saturating_sub(1); - let row_id = get_row(*param_id)?; - *param_id = row_id; - } + bind_param(param_id, true, &mut param_index); } } Expression::Row { ref mut list, .. } @@ -430,11 +488,7 @@ impl Plan { .. 
} => { for param_id in list { - if param_node_ids_cloned.take(param_id).is_some() { - idx = idx.saturating_sub(1); - let val_id = get_value(*param_id, idx); - *param_id = val_id; - } + bind_param(param_id, false, &mut param_index); } } Expression::Case { @@ -442,24 +496,15 @@ impl Plan { ref mut when_blocks, ref mut else_expr, } => { - let mut do_the_work = |param_id: &mut usize| { - if param_node_ids_cloned.take(param_id).is_some() { - idx = idx.saturating_sub(1); - let val_id = get_value(*param_id, idx); - *param_id = val_id; - } - }; - if let Some(search_expr) = search_expr { - do_the_work(search_expr); + if let Some(param_id) = search_expr { + bind_param(param_id, false, &mut param_index); } - - for (cond_expr, res_expr) in when_blocks { - do_the_work(cond_expr); - do_the_work(res_expr); + for (param_id_1, param_id_2) in when_blocks { + bind_param(param_id_1, false, &mut param_index); + bind_param(param_id_2, false, &mut param_index); } - - if let Some(else_expr) = else_expr { - do_the_work(else_expr); + if let Some(param_id) = else_expr { + bind_param(param_id, false, &mut param_index); } } Expression::Reference { .. } @@ -469,11 +514,7 @@ impl Plan { Node::Block(block) => match block { Block::Procedure { ref mut values, .. } => { for param_id in values { - if param_node_ids_cloned.take(param_id).is_some() { - idx = idx.saturating_sub(1); - let val_id = get_value(*param_id, idx); - *param_id = val_id; - } + bind_param(param_id, false, &mut param_index); } } }, @@ -481,22 +522,33 @@ impl Plan { } } - // Update values row output. - for (_, id) in nodes { - if let Ok(Node::Relational(Relational::ValuesRow { .. })) = self.get_node(id) { - self.update_values_row(id)?; + Ok(()) + } + + fn update_value_rows(&mut self) -> Result<(), SbroadError> { + for (_, id) in &self.nodes { + if let Ok(Node::Relational(Relational::ValuesRow { .. 
})) = self.plan.get_node(*id) { + self.plan.update_values_row(*id)?; } } - Ok(()) } +} + +impl Plan { + pub fn add_param(&mut self) -> usize { + self.nodes.push(Node::Parameter) + } /// Bind params related to `Option` clause. /// Returns the number of params binded to options. /// /// # Errors /// - User didn't provide parameter value for corresponding option parameter - pub fn bind_option_params(&mut self, values: &mut Vec<Value>) -> Result<usize, SbroadError> { + /// + /// # Panics + /// - Plan is inconsistent state + pub fn bind_option_params(&mut self, values: &mut Vec<Value>) -> usize { // Bind parameters in options to values. // Because the Option clause is the last clause in the // query the parameters are located in the end of params list. @@ -505,37 +557,101 @@ impl Plan { if let OptionParamValue::Parameter { plan_id: param_id } = opt.val { if !self.pg_params_map.is_empty() { // PG-like params syntax - let value_idx = *self.pg_params_map.get(¶m_id).ok_or_else(|| { - SbroadError::Invalid( - Entity::Plan, - Some(format_smolstr!( - "no value idx in map for option parameter: {opt:?}" - )), - ) - })?; - let value = values.get(value_idx).ok_or_else(|| { - SbroadError::Invalid( - Entity::Plan, - Some(format_smolstr!( - "invalid value idx {value_idx}, for option: {opt:?}" - )), - ) - })?; + let value_idx = *self.pg_params_map.get(¶m_id).unwrap_or_else(|| { + panic!("No value idx in map for option parameter: {opt:?}."); + }); + let value = values.get(value_idx).unwrap_or_else(|| { + panic!("Invalid value idx {value_idx}, for option: {opt:?}."); + }); opt.val = OptionParamValue::Value { val: value.clone() }; } else if let Some(v) = values.pop() { binded_params_counter += 1; opt.val = OptionParamValue::Value { val: v }; } else { - return Err(SbroadError::Invalid( - Entity::Query, - Some(format_smolstr!( - "no parameter value specified for option: {}", - opt.kind - )), - )); + panic!("No parameter value specified for option: {}", opt.kind); } } } - 
Ok(binded_params_counter) + binded_params_counter + } + + // Gather all parameter nodes from the tree to a hash set. + #[must_use] + pub fn get_param_set(&self) -> AHashSet<usize> { + let param_set: AHashSet<usize> = self + .nodes + .arena + .iter() + .enumerate() + .filter_map(|(id, node)| { + if let Node::Parameter = node { + Some(id) + } else { + None + } + }) + .collect(); + param_set + } + + /// Synchronize values row output with the data tuple after parameter binding. + /// + /// # Errors + /// - Node is not values row + /// - Output and data tuples have different number of columns + /// - Output is not a row of aliases + /// + /// # Panics + /// - Plan is inconsistent state + pub fn update_values_row(&mut self, id: usize) -> Result<(), SbroadError> { + let values_row = self.get_node(id)?; + let (output_id, data_id) = + if let Node::Relational(Relational::ValuesRow { output, data, .. }) = values_row { + (*output, *data) + } else { + panic!("Expected a values row: {values_row:?}") + }; + let data = self.get_expression_node(data_id)?; + let data_list = data.clone_row_list()?; + let output = self.get_expression_node(output_id)?; + let output_list = output.clone_row_list()?; + for (pos, alias_id) in output_list.iter().enumerate() { + let new_child_id = *data_list + .get(pos) + .unwrap_or_else(|| panic!("Node not found at position {pos}")); + let alias = self.get_mut_expression_node(*alias_id)?; + if let Expression::Alias { ref mut child, .. } = alias { + *child = new_child_id; + } else { + panic!("Expected an alias: {alias:?}") + } + } + Ok(()) + } + + /// Substitute parameters to the plan. + /// The purpose of this function is to find every `Parameter` node and replace it + /// with `Expression::Constant` (under the row). + /// + /// # Errors + /// - Invalid amount of parameters. + /// - Internal errors. 
+ #[allow(clippy::too_many_lines)] + #[otm_child_span("plan.bind")] + pub fn bind_params(&mut self, values: Vec<Value>) -> Result<(), SbroadError> { + // Nothing to do here. + if values.is_empty() { + return Ok(()); + } + + let mut binder = ParamsBinder::new(self, values)?; + binder.handle_pg_parameters()?; + binder.create_parameter_constants(); + binder.check_params_count(); + binder.cover_params_with_rows()?; + binder.bind_params()?; + binder.update_value_rows()?; + + Ok(()) } } diff --git a/sbroad-core/src/ir/expression.rs b/sbroad-core/src/ir/expression.rs index 1442e4bd5..8dfd78f19 100644 --- a/sbroad-core/src/ir/expression.rs +++ b/sbroad-core/src/ir/expression.rs @@ -494,8 +494,9 @@ impl<'plan> PlanExpr<'plan> { impl<'plan> Hash for PlanExpr<'plan> { fn hash<H: Hasher>(&self, state: &mut H) { - let comp = Comparator::new(ReferencePolicy::ByFields, self.plan); - comp.hash_for_expr(self.id, state, EXPR_HASH_DEPTH); + let mut comp = Comparator::new(ReferencePolicy::ByFields, self.plan); + comp.set_hasher(state); + comp.hash_for_expr(self.id, EXPR_HASH_DEPTH); } } @@ -531,6 +532,7 @@ pub enum ReferencePolicy { pub struct Comparator<'plan> { policy: ReferencePolicy, plan: &'plan Plan, + state: Option<&'plan mut dyn Hasher>, } pub const EXPR_HASH_DEPTH: usize = 5; @@ -538,7 +540,15 @@ pub const EXPR_HASH_DEPTH: usize = 5; impl<'plan> Comparator<'plan> { #[must_use] pub fn new(policy: ReferencePolicy, plan: &'plan Plan) -> Self { - Comparator { policy, plan } + Comparator { + policy, + plan, + state: None, + } + } + + pub fn set_hasher<H: Hasher>(&mut self, state: &'plan mut H) { + self.state = Some(state); } /// Checks whether expression subtrees `lhs` and `rhs` are equal. 
@@ -757,21 +767,33 @@ impl<'plan> Comparator<'plan> { Ok(false) } + pub fn hash_for_child_expr(&mut self, child: usize, depth: usize) { + self.hash_for_expr(child, depth - 1); + } + + /// TODO: See strange [behaviour](https://users.rust-lang.org/t/unintuitive-behaviour-with-passing-a-reference-to-trait-object-to-function/35937) + /// about `&mut dyn Hasher` and why we use `ref mut state`. + /// + /// # Panics + /// - Comparator hasher wasn't set. #[allow(clippy::too_many_lines)] - pub fn hash_for_expr<H: Hasher>(&self, top: usize, state: &mut H, depth: usize) { + pub fn hash_for_expr(&mut self, top: usize, depth: usize) { if depth == 0 { return; } let Ok(node) = self.plan.get_expression_node(top) else { return; }; + let Some(ref mut state) = self.state else { + panic!("Hasher should have been set previously"); + }; match node { Expression::ExprInParentheses { child } => { - self.hash_for_expr(*child, state, depth - 1); + self.hash_for_child_expr(*child, depth); } Expression::Alias { child, name } => { name.hash(state); - self.hash_for_expr(*child, state, depth - 1); + self.hash_for_child_expr(*child, depth); } Expression::Case { search_expr, @@ -779,33 +801,33 @@ impl<'plan> Comparator<'plan> { else_expr, } => { if let Some(search_expr) = search_expr { - self.hash_for_expr(*search_expr, state, depth - 1); + self.hash_for_child_expr(*search_expr, depth); } for (cond_expr, res_expr) in when_blocks { - self.hash_for_expr(*cond_expr, state, depth - 1); - self.hash_for_expr(*res_expr, state, depth - 1); + self.hash_for_child_expr(*cond_expr, depth); + self.hash_for_child_expr(*res_expr, depth); } if let Some(else_expr) = else_expr { - self.hash_for_expr(*else_expr, state, depth - 1); + self.hash_for_child_expr(*else_expr, depth); } } Expression::Bool { op, left, right } => { op.hash(state); - self.hash_for_expr(*left, state, depth - 1); - self.hash_for_expr(*right, state, depth - 1); + self.hash_for_child_expr(*left, depth); + self.hash_for_child_expr(*right, depth); } 
Expression::Arithmetic { op, left, right } => { op.hash(state); - self.hash_for_expr(*left, state, depth - 1); - self.hash_for_expr(*right, state, depth - 1); + self.hash_for_child_expr(*left, depth); + self.hash_for_child_expr(*right, depth); } Expression::Cast { child, to } => { to.hash(state); - self.hash_for_expr(*child, state, depth - 1); + self.hash_for_child_expr(*child, depth); } Expression::Concat { left, right } => { - self.hash_for_expr(*left, state, depth - 1); - self.hash_for_expr(*right, state, depth - 1); + self.hash_for_child_expr(*left, depth); + self.hash_for_child_expr(*right, depth); } Expression::Trim { kind, @@ -814,9 +836,9 @@ impl<'plan> Comparator<'plan> { } => { kind.hash(state); if let Some(pattern) = pattern { - self.hash_for_expr(*pattern, state, depth - 1); + self.hash_for_child_expr(*pattern, depth); } - self.hash_for_expr(*target, state, depth - 1); + self.hash_for_child_expr(*target, depth); } Expression::Constant { value } => { value.hash(state); @@ -842,7 +864,7 @@ impl<'plan> Comparator<'plan> { }, Expression::Row { list, .. } => { for child in list { - self.hash_for_expr(*child, state, depth - 1); + self.hash_for_child_expr(*child, depth); } } Expression::StableFunction { @@ -855,12 +877,12 @@ impl<'plan> Comparator<'plan> { func_type.hash(state); name.hash(state); for child in children { - self.hash_for_expr(*child, state, depth - 1); + self.hash_for_child_expr(*child, depth); } } Expression::Unary { child, op } => { op.hash(state); - self.hash_for_expr(*child, state, depth - 1); + self.hash_for_child_expr(*child, depth); } Expression::CountAsterisk => { "CountAsterisk".hash(state); diff --git a/sbroad-core/src/ir/helpers.rs b/sbroad-core/src/ir/helpers.rs index 5cb60600e..11626f1b3 100644 --- a/sbroad-core/src/ir/helpers.rs +++ b/sbroad-core/src/ir/helpers.rs @@ -55,6 +55,18 @@ fn formatted_tabulate(buf: &mut String, n: i32) -> Result<(), std::fmt::Error> { Ok(()) } +/// Helper formatting function for writing with tabulation. 
+fn write_with_tabulation(buf: &mut String, n: i32, text: &str) -> Result<(), std::fmt::Error> { + formatted_tabulate(buf, n)?; + write!(buf, "{text}") +} + +/// Helper formatting function for writing with tabulation and new line. +fn writeln_with_tabulation(buf: &mut String, n: i32, text: &str) -> Result<(), std::fmt::Error> { + formatted_tabulate(buf, n)?; + writeln!(buf, "{text}") +} + /// Formatting helper debug functions impl Plan { /// Helper function for printing Expression node. @@ -71,24 +83,18 @@ impl Plan { match expr { Expression::Alias { name, child } => { let child_node = self.get_node(*child).expect("Alias must have a child node"); - match child_node { - Node::Expression(child_expr) => { - writeln!(buf, "Alias [name = {name}, child = {child_expr:?}]")?; - } - Node::Parameter => { - writeln!(buf, "Alias [name = {name}, child = parameter]")?; - } - Node::Relational(rel) => { - writeln!(buf, "Alias [name = {name}, child = {rel:?}]")?; - } + let child = match child_node { + Node::Expression(child_expr) => format!("{child_expr:?}"), + Node::Parameter => String::from("parameter"), + Node::Relational(rel) => format!("{rel:?}"), // TODO: fix `fix_betweens` logic to cover SubQueries with References. 
_ => unreachable!("unexpected Alias child node"), - } + }; + writeln!(buf, "Alias [name = {name}, child = {child}]")?; } Expression::ExprInParentheses { child } => { writeln!(buf, "Parentheses")?; - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Child")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Child")?; self.formatted_arena_node(buf, tabulation_number + 1, *child)?; } Expression::Case { @@ -98,39 +104,31 @@ impl Plan { } => { writeln!(buf, "Case")?; if let Some(search_expr) = search_expr { - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Search expr")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Search_expr")?; self.formatted_arena_node(buf, tabulation_number + 1, *search_expr)?; } for (cond_expr, res_expr) in when_blocks { - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "WHEN")?; + writeln_with_tabulation(buf, tabulation_number + 1, "WHEN")?; self.formatted_arena_node(buf, tabulation_number + 1, *cond_expr)?; - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "THEN")?; + writeln_with_tabulation(buf, tabulation_number + 1, "THEN")?; self.formatted_arena_node(buf, tabulation_number + 1, *res_expr)?; } if let Some(else_expr) = else_expr { - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Else expr")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Else expr")?; self.formatted_arena_node(buf, tabulation_number + 1, *else_expr)?; } } Expression::Bool { op, left, right } => { writeln!(buf, "Bool [op: {op}]")?; - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Left child")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Left child")?; self.formatted_arena_node(buf, tabulation_number + 1, *left)?; - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Right child")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Right child")?; self.formatted_arena_node(buf, tabulation_number + 1, *right)?; } 
Expression::Constant { value } => { writeln!(buf, "Constant [value = {value}]")?; } - Expression::CountAsterisk => { - writeln!(buf, "CountAsterisk")?; - } + Expression::CountAsterisk => writeln!(buf, "CountAsterisk")?, Expression::Reference { targets, position, @@ -140,8 +138,11 @@ impl Plan { let alias_name = self.get_alias_from_reference_node(expr).unwrap(); writeln!(buf, "Reference")?; - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Alias: {alias_name}")?; + writeln_with_tabulation( + buf, + tabulation_number + 1, + format!("Alias: {alias_name}").as_str(), + )?; // See explain logic for Reference node let rel_id = self.get_relational_from_reference_node(node_id); @@ -149,31 +150,42 @@ impl Plan { let rel_node = self.get_relation_node(*rel_id); if let Ok(rel_node) = rel_node { if let Ok(Some(name)) = rel_node.scan_name(self, *position) { - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Referenced table name (or alias): {name}")?; + writeln_with_tabulation( + buf, + tabulation_number + 1, + format!("Referenced table name (or alias): {name}").as_str(), + )?; } } } - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Parent: {parent:?}")?; + writeln_with_tabulation( + buf, + tabulation_number + 1, + format!("Parent: {parent:?}").as_str(), + )?; if let Some(targets) = targets { for target_id in targets { - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "target_id: {target_id}")?; + writeln_with_tabulation( + buf, + tabulation_number + 1, + format!("target_id: {target_id}").as_str(), + )?; } } else { writeln!(buf, "NO TARGETS")?; } - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Column type: {col_type}")?; + writeln_with_tabulation( + buf, + tabulation_number + 1, + format!("Column type: {col_type}").as_str(), + )?; } Expression::Row { list, distribution } => { writeln!(buf, "Row [distribution = {distribution:?}]")?; - formatted_tabulate(buf, tabulation_number + 1)?; - 
writeln!(buf, "List:")?; + writeln_with_tabulation(buf, tabulation_number + 1, "List:")?; for value in list { self.formatted_arena_node(buf, tabulation_number + 2, *value)?; } @@ -184,8 +196,7 @@ impl Plan { Expression::StableFunction { .. } => writeln!(buf, "StableFunction")?, Expression::Unary { op, child } => { writeln!(buf, "Unary [op: {op}]")?; - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Child")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Child")?; self.formatted_arena_node(buf, tabulation_number + 1, *child)?; } Expression::Arithmetic { .. } => writeln!(buf, "Arithmetic")?, @@ -205,8 +216,7 @@ impl Plan { if tabulation_number == 0 { writeln!(buf, "---------------------------------------------")?; } - formatted_tabulate(buf, tabulation_number)?; - write!(buf, "[id: {node_id}] ")?; + write_with_tabulation(buf, tabulation_number, format!("[id: {node_id}] ").as_str())?; let relation = self.get_relation_node(node_id); if let Ok(relation) = relation { write!(buf, "relation: ")?; @@ -216,17 +226,22 @@ impl Plan { alias, relation, .. } => { writeln!(buf, "ScanRelation")?; - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Relation: {relation}")?; + writeln_with_tabulation( + buf, + tabulation_number + 1, + format!("Relation: {relation}").as_str(), + )?; if let Some(alias) = alias { - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Alias: {alias}")?; + writeln_with_tabulation( + buf, + tabulation_number + 1, + format!("Alias: {alias}").as_str(), + )?; } } Relational::Join { condition, .. } => { writeln!(buf, "InnerJoin")?; - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Condition:")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Condition:")?; self.formatted_arena_node(buf, tabulation_number + 2, *condition)?; } Relational::Projection { .. } => { @@ -235,16 +250,22 @@ impl Plan { Relational::ScanCte { alias, .. 
} => { writeln!(buf, "ScanCte")?; if !alias.is_empty() { - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Alias: {alias}")?; + writeln_with_tabulation( + buf, + tabulation_number + 1, + format!("Alias: {alias}").as_str(), + )?; } } Relational::ScanSubQuery { alias, .. } => { writeln!(buf, "ScanSubQuery")?; if let Some(alias) = alias { if !alias.is_empty() { - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Alias: {alias}")?; + writeln_with_tabulation( + buf, + tabulation_number + 1, + format!("Alias: {alias}").as_str(), + )?; } } } @@ -254,40 +275,35 @@ impl Plan { output: _, } => { writeln!(buf, "Selection")?; - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Filter")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Filter")?; self.formatted_arena_node(buf, tabulation_number + 1, *filter)?; } Relational::Having { filter, .. } => { writeln!(buf, "Having")?; - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Filter")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Filter")?; self.formatted_arena_node(buf, tabulation_number + 1, *filter)?; } Relational::GroupBy { gr_cols, is_final, .. } => { writeln!(buf, "GroupBy [is_final = {is_final}]")?; - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Gr_cols:")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Gr_cols:")?; for gr_col in gr_cols { - formatted_tabulate(buf, tabulation_number + 2)?; let gl_col_expr = self.get_expression_node(*gr_col); - if let Ok(gl_col_expr) = gl_col_expr { - writeln!(buf, "Gr_col: {gl_col_expr:?}")?; + let text = if let Ok(gl_col_expr) = gl_col_expr { + format!("Gr_col: {gl_col_expr:?}") } else { - writeln!(buf, "Gr_col: {gr_col}")?; - } + format!("Gr_col: {gr_col}") + }; + writeln_with_tabulation(buf, tabulation_number + 2, text.as_str())?; } } Relational::OrderBy { order_by_elements, .. 
} => { writeln!(buf, "OrderBy")?; - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Order_by_elements:")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Order_by_elements:")?; for element in order_by_elements { - formatted_tabulate(buf, tabulation_number + 2)?; let order_by_entity_str = match element.entity { OrderByEntity::Expression { expr_id } => { let order_by_expr = self.get_expression_node(expr_id); @@ -300,7 +316,7 @@ impl Plan { OrderByEntity::Index { value } => format!("{value}"), }; let order_by_type = element.order_type.clone(); - writeln!(buf, "Order_by_element: {order_by_entity_str} [order_type = {order_by_type:?}]")?; + writeln_with_tabulation(buf, tabulation_number + 2, format!("Order_by_element: {order_by_entity_str} [order_type = {order_by_type:?}]").as_str())?; } } Relational::Values { .. } => writeln!(buf, "Values")?, @@ -316,14 +332,9 @@ impl Plan { .. } => { writeln!(buf, "Update")?; - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Update columns map:")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Update columns map:")?; for (rel_pos, proj_pos) in update_columns_map { - formatted_tabulate(buf, tabulation_number + 2)?; - writeln!( - buf, - "Update {relation} column on pos {rel_pos} to child projection column on pos {proj_pos}" - )?; + writeln_with_tabulation(buf, tabulation_number + 2, format!("Update {relation} column on pos {rel_pos} to child projection column on pos {proj_pos}").as_str())?; } } Relational::Delete { .. } => writeln!(buf, "Delete")?, @@ -349,22 +360,25 @@ impl Plan { | Relational::Having { .. } | Relational::GroupBy { .. } | Relational::ValuesRow { .. 
}) => { - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Children:")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Children:")?; for child in &node.children() { - formatted_tabulate(buf, tabulation_number + 2)?; - writeln!(buf, "Child_id = {child}")?; + writeln_with_tabulation( + buf, + tabulation_number + 2, + format!("Child_id = {child}").as_str(), + )?; } } Relational::OrderBy { child, .. } | Relational::ScanCte { child, .. } => { - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Children:")?; - formatted_tabulate(buf, tabulation_number + 2)?; - writeln!(buf, "Child_id = {child}")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Children:")?; + writeln_with_tabulation( + buf, + tabulation_number + 2, + format!("Child_id = {child}").as_str(), + )?; } Relational::ScanRelation { .. } => { - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "[No children]")?; + writeln_with_tabulation(buf, tabulation_number + 1, "[No children]")?; } } // Print output. @@ -388,8 +402,11 @@ impl Plan { | Relational::UnionAll { output, .. } | Relational::Update { output, .. } | Relational::ValuesRow { output, .. 
} => { - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Output_id: {output}")?; + writeln_with_tabulation( + buf, + tabulation_number + 1, + format!("Output_id: {output}").as_str(), + )?; self.formatted_arena_node(buf, tabulation_number + 2, *output)?; } } @@ -480,20 +497,17 @@ impl SyntaxPlan<'_> { } if let Some(left_id) = node.left { - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Left:")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Left:")?; self.formatted_inner(plan, buf, tabulation_number + 2, left_id)?; } if !node.right.is_empty() { - formatted_tabulate(buf, tabulation_number + 1)?; - writeln!(buf, "Right:")?; + writeln_with_tabulation(buf, tabulation_number + 1, "Right:")?; } for right in &node.right { self.formatted_inner(plan, buf, tabulation_number + 2, *right)?; } } else { - formatted_tabulate(buf, tabulation_number)?; - writeln!(buf, "MISSING")?; + writeln_with_tabulation(buf, tabulation_number, "MISSING")?; } Ok(()) } diff --git a/sbroad-core/src/ir/operator.rs b/sbroad-core/src/ir/operator.rs index b986325fd..678003dcd 100644 --- a/sbroad-core/src/ir/operator.rs +++ b/sbroad-core/src/ir/operator.rs @@ -2115,44 +2115,6 @@ impl Plan { Ok(map) } - /// Synchronize values row output with the data tuple after parameter binding. - /// - /// # Errors - /// - Node is not values row - /// - Output and data tuples have different number of columns - /// - Output is not a row of aliases - pub fn update_values_row(&mut self, id: usize) -> Result<(), SbroadError> { - let values_row = self.get_node(id)?; - let (output_id, data_id) = - if let Node::Relational(Relational::ValuesRow { output, data, .. 
}) = values_row { - (*output, *data) - } else { - return Err(SbroadError::Invalid( - Entity::Expression, - Some(format_smolstr!("Expected a values row: {values_row:?}")), - )); - }; - let data = self.get_expression_node(data_id)?; - let data_list = data.clone_row_list()?; - let output = self.get_expression_node(output_id)?; - let output_list = output.clone_row_list()?; - for (pos, alias_id) in output_list.iter().enumerate() { - let new_child_id = *data_list.get(pos).ok_or_else(|| { - SbroadError::NotFound(Entity::Node, format_smolstr!("at position {pos}")) - })?; - let alias = self.get_mut_expression_node(*alias_id)?; - if let Expression::Alias { ref mut child, .. } = alias { - *child = new_child_id; - } else { - return Err(SbroadError::Invalid( - Entity::Expression, - Some(format_smolstr!("expected an alias: {alias:?}")), - )); - } - } - Ok(()) - } - /// Sets children for relational node /// /// # Errors diff --git a/sbroad-core/src/ir/transformation.rs b/sbroad-core/src/ir/transformation.rs index 2c3162c75..c85d27d55 100644 --- a/sbroad-core/src/ir/transformation.rs +++ b/sbroad-core/src/ir/transformation.rs @@ -20,8 +20,40 @@ use crate::ir::{Node, Plan}; use std::collections::HashMap; pub type ExprId = usize; -/// Helper type representing map of (`old_expr_id` -> `changed_expr_id`). -pub type OldNewExpressionMap = HashMap<ExprId, ExprId>; +/// Helper struct representing map of (`old_expr_id` -> `changed_expr_id`). 
+struct OldNewExpressionMap { + inner: HashMap<ExprId, ExprId>, +} + +impl OldNewExpressionMap { + fn new() -> Self { + OldNewExpressionMap { + inner: HashMap::new(), + } + } + + fn insert(&mut self, old_id: usize, new_id: usize) { + self.inner.insert(old_id, new_id); + } + + fn replace(&self, child: &mut usize) { + if let Some(new_id) = self.inner.get(child) { + *child = *new_id; + } + } + + fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + fn get(&self, key: usize) -> Option<&ExprId> { + self.inner.get(&key) + } + + fn len(&self) -> usize { + self.inner.len() + } +} /// Pair of (old tree id, transformed tree id). pub type OldNewTopIdPair = (ExprId, ExprId); @@ -147,7 +179,7 @@ impl Plan { f: TransformFunction, ops: &[Bool], ) -> Result<OldNewTopIdPair, SbroadError> { - let mut map: OldNewExpressionMap = HashMap::new(); + let mut map = OldNewExpressionMap::new(); // Note, that filter accepts nodes: // * On which we'd like to apply transformation // * That will contain transformed nodes as children @@ -190,7 +222,7 @@ impl Plan { let (old_top_id, new_top_id) = if map.is_empty() { (top_id, top_id) } else { - let old_top_id = if map.get(&top_id).is_some() && map.len() == 1 { + let old_top_id = if map.get(top_id).is_some() && map.len() == 1 { top_id } else { self.clone_expr_subtree(top_id)? @@ -208,9 +240,7 @@ impl Plan { | Expression::ExprInParentheses { child, .. } | Expression::Cast { child, .. } | Expression::Unary { child, .. 
} => { - if let Some(new_id) = map.get(child) { - *child = *new_id; - } + map.replace(child); } Expression::Case { search_expr, @@ -218,53 +248,33 @@ impl Plan { else_expr, } => { if let Some(search_expr) = search_expr { - if let Some(new_id) = map.get(search_expr) { - *search_expr = *new_id; - } + map.replace(search_expr); } - for (cond_expr, res_expr) in when_blocks { - if let Some(new_id) = map.get(cond_expr) { - *cond_expr = *new_id; - } - if let Some(new_id) = map.get(res_expr) { - *res_expr = *new_id; - } + map.replace(cond_expr); + map.replace(res_expr); } - if let Some(else_expr) = else_expr { - if let Some(new_id) = map.get(else_expr) { - *else_expr = *new_id; - } + map.replace(else_expr); } } Expression::Bool { left, right, .. } | Expression::Arithmetic { left, right, .. } => { - if let Some(new_id) = map.get(left) { - *left = *new_id; - } - if let Some(new_id) = map.get(right) { - *right = *new_id; - } + map.replace(left); + map.replace(right); } Expression::Trim { pattern, target, .. } => { if let Some(pattern) = pattern { - if let Some(new_id) = map.get(pattern) { - *pattern = *new_id; - } - } - if let Some(new_id) = map.get(target) { - *target = *new_id; + map.replace(pattern); } + map.replace(target); } Expression::Row { list, .. } | Expression::StableFunction { children: list, .. } => { for id in list { - if let Some(new_id) = map.get(id) { - *id = *new_id; - } + map.replace(id); } } Expression::Concat { .. } @@ -274,7 +284,7 @@ impl Plan { } } // Checks if the top node is a new node. 
- if let Some(new_id) = map.get(&top_id) { + if let Some(new_id) = map.get(top_id) { new_top_id = *new_id; } (old_top_id, new_top_id) diff --git a/sbroad-core/src/ir/transformation/not_push_down.rs b/sbroad-core/src/ir/transformation/not_push_down.rs index 920474232..867fb970e 100644 --- a/sbroad-core/src/ir/transformation/not_push_down.rs +++ b/sbroad-core/src/ir/transformation/not_push_down.rs @@ -14,7 +14,6 @@ use crate::ir::{Node, Plan}; use crate::otm::child_span; use sbroad_proc::otm_child_span; use smol_str::{format_smolstr, SmolStr}; -use std::collections::HashMap; /// Enum representing status of Not push down traversal. /// It may be in two states: @@ -53,7 +52,7 @@ fn call_expr_tree_not_push_down( ) -> Result<OldNewTopIdPair, SbroadError> { // Because of the borrow checker we can't change `Bool` and `Row` children during recursive // traversal and have to do it using this map after transformation. - let mut old_new_expression_map = HashMap::new(); + let mut old_new_expression_map = OldNewExpressionMap::new(); let new_top_id = plan.push_down_not_for_expression(top_id, NotState::Off, &mut old_new_expression_map)?; @@ -83,23 +82,15 @@ fn call_expr_tree_not_push_down( let expr = plan.get_mut_expression_node(*id)?; match expr { Expression::ExprInParentheses { child } => { - if let Some(new_id) = old_new_expression_map.get(child) { - *child = *new_id; - } + old_new_expression_map.replace(child); } Expression::Bool { left, right, .. } => { - if let Some(new_id) = old_new_expression_map.get(left) { - *left = *new_id; - } - if let Some(new_id) = old_new_expression_map.get(right) { - *right = *new_id; - } + old_new_expression_map.replace(left); + old_new_expression_map.replace(right); } Expression::Row { list, .. 
} => { for id in list { - if let Some(new_id) = old_new_expression_map.get(id) { - *id = *new_id; - } + old_new_expression_map.replace(id); } } _ => {} diff --git a/sbroad-core/src/ir/transformation/redistribution/groupby.rs b/sbroad-core/src/ir/transformation/redistribution/groupby.rs index 28c9bb557..a534ca881 100644 --- a/sbroad-core/src/ir/transformation/redistribution/groupby.rs +++ b/sbroad-core/src/ir/transformation/redistribution/groupby.rs @@ -104,9 +104,10 @@ impl<'plan, 'args> AggregateSignature<'plan, 'args> { impl<'plan, 'args> Hash for AggregateSignature<'plan, 'args> { fn hash<H: Hasher>(&self, state: &mut H) { self.kind.hash(state); - let comp = Comparator::new(ReferencePolicy::ByAliases, self.plan); + let mut comp = Comparator::new(ReferencePolicy::ByAliases, self.plan); + comp.set_hasher(state); for arg in self.arguments { - comp.hash_for_expr(*arg, state, EXPR_HASH_DEPTH); + comp.hash_for_expr(*arg, EXPR_HASH_DEPTH); } } } @@ -141,8 +142,9 @@ impl<'plan> GroupingExpression<'plan> { impl<'plan> Hash for GroupingExpression<'plan> { fn hash<H: Hasher>(&self, state: &mut H) { - let comp = Comparator::new(ReferencePolicy::ByAliases, self.plan); - comp.hash_for_expr(self.id, state, EXPR_HASH_DEPTH); + let mut comp = Comparator::new(ReferencePolicy::ByAliases, self.plan); + comp.set_hasher(state); + comp.hash_for_expr(self.id, EXPR_HASH_DEPTH); } } -- GitLab