From fad5a250db7e17e5dbb0d7301c12822606509d1b Mon Sep 17 00:00:00 2001 From: Kaitmazian Maksim <m.kaitmazian@picodata.io> Date: Wed, 12 Jun 2024 17:41:28 +0300 Subject: [PATCH] feat: infer parameters types from cast --- sbroad-core/src/backend/sql/ir.rs | 2 +- sbroad-core/src/backend/sql/tree.rs | 2 +- sbroad-core/src/errors.rs | 38 +++++ sbroad-core/src/executor/engine/helpers.rs | 4 +- sbroad-core/src/executor/ir.rs | 2 +- sbroad-core/src/frontend/sql.rs | 14 +- .../src/frontend/sql/ir/tests/params.rs | 141 +++++++++++++++++- sbroad-core/src/ir.rs | 120 ++++++++++++++- sbroad-core/src/ir/acl.rs | 2 +- sbroad-core/src/ir/api/constant.rs | 2 +- sbroad-core/src/ir/api/parameter.rs | 10 +- sbroad-core/src/ir/block.rs | 4 +- sbroad-core/src/ir/ddl.rs | 10 +- sbroad-core/src/ir/distribution.rs | 2 +- sbroad-core/src/ir/expression/types.rs | 2 +- sbroad-core/src/ir/helpers.rs | 2 +- sbroad-core/src/ir/transformation/helpers.rs | 8 +- sbroad-core/src/ir/tree/expression.rs | 2 +- sbroad-core/src/ir/tree/relation.rs | 2 +- sbroad-core/src/ir/tree/subtree.rs | 2 +- 20 files changed, 333 insertions(+), 38 deletions(-) diff --git a/sbroad-core/src/backend/sql/ir.rs b/sbroad-core/src/backend/sql/ir.rs index c74aaf2c2d..f8bea9cf41 100644 --- a/sbroad-core/src/backend/sql/ir.rs +++ b/sbroad-core/src/backend/sql/ir.rs @@ -329,7 +329,7 @@ impl ExecutionPlan { ), )); } - Node::Parameter => { + Node::Parameter(..) => { return Err(SbroadError::Unsupported( Entity::Node, Some( diff --git a/sbroad-core/src/backend/sql/tree.rs b/sbroad-core/src/backend/sql/tree.rs index 85f0c3dfe5..64df3339aa 100644 --- a/sbroad-core/src/backend/sql/tree.rs +++ b/sbroad-core/src/backend/sql/tree.rs @@ -598,7 +598,7 @@ impl<'p> SyntaxPlan<'p> { Node::Ddl(..) => panic!("DDL node {node:?} is not supported in the syntax plan"), Node::Acl(..) => panic!("ACL node {node:?} is not supported in the syntax plan"), Node::Block(..) => panic!("Block node {node:?} is not supported in the syntax plan"), - Node::Parameter => { + Node::Parameter(..) => { let sn = SyntaxNode::new_parameter(id); self.nodes.push_sn_plan(sn); } diff --git a/sbroad-core/src/errors.rs b/sbroad-core/src/errors.rs index 7b2ac481bd..95b8e75ec3 100644 --- a/sbroad-core/src/errors.rs +++ b/sbroad-core/src/errors.rs @@ -1,3 +1,4 @@ +use crate::ir::relation::Type; use serde::Serialize; use smol_str::{format_smolstr, SmolStr, ToSmolStr}; use std::fmt; @@ -278,6 +279,33 @@ impl fmt::Display for Action { } } +#[derive(Clone, Debug, PartialEq, Eq, Serialize)] +pub enum TypeError { + AmbiguousParameterType(usize, Type, Type), + CouldNotDetermineParameterType(usize), +} + +impl fmt::Display for TypeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let p: SmolStr = match self { + TypeError::AmbiguousParameterType(param_idx, ty1, ty2) => { + let param_num = param_idx + 1; + format_smolstr!( + "parameter ${param_num} is ambiguous, it can be either {ty1} or {ty2}" + ) + } + TypeError::CouldNotDetermineParameterType(param_idx) => { + let param_num = param_idx + 1; + format_smolstr!("could not determine data type of parameter ${param_num}") + } + }; + + write!(f, "{p}") + } +} + +impl std::error::Error for TypeError {} + /// Types of error #[derive(Clone, Debug, PartialEq, Eq, Serialize)] pub enum SbroadError { @@ -307,6 +335,7 @@ pub enum SbroadError { /// Unexpected number of values (list length etc.). /// Second param is information what was expected and what got. UnexpectedNumberOfValues(SmolStr), + TypeError(TypeError), /// Object is not supported. /// Second param represents description or name that let to identify object. /// and can be empty (None). @@ -347,6 +376,9 @@ impl fmt::Display for SbroadError { SbroadError::OutdatedStorageSchema => { "storage schema version different from router".into() } + SbroadError::TypeError(err) => { + format_smolstr!("{err}") + } }; write!(f, "{p}") @@ -374,3 +406,9 @@ impl From<Error> for SbroadError { ) } } + +impl From<TypeError> for SbroadError { + fn from(error: TypeError) -> Self { + SbroadError::TypeError(error) + } +} diff --git a/sbroad-core/src/executor/engine/helpers.rs b/sbroad-core/src/executor/engine/helpers.rs index 405c21ee33..f58b40ef28 100644 --- a/sbroad-core/src/executor/engine/helpers.rs +++ b/sbroad-core/src/executor/engine/helpers.rs @@ -788,7 +788,7 @@ pub(crate) fn materialize_values( return Ok(None); } let child_node_ref = plan.get_mut_ir_plan().get_mut_node(child_id)?; - let child_node = std::mem::replace(child_node_ref, Node::Parameter); + let child_node = std::mem::replace(child_node_ref, Node::Parameter(None)); if let Node::Relational(Relational::Values { ref children, output, @@ -849,7 +849,7 @@ pub(crate) fn materialize_values( ) })?; let column_node_ref = plan.get_mut_ir_plan().get_mut_node(column_id)?; - let column_node = std::mem::replace(column_node_ref, Node::Parameter); + let column_node = std::mem::replace(column_node_ref, Node::Parameter(None)); if let Node::Expression(Expression::Constant { value, .. }) = column_node { if let Value::Null = value { nullable_column_indices.insert(idx); diff --git a/sbroad-core/src/executor/ir.rs b/sbroad-core/src/executor/ir.rs index cde6e9ad84..719e6e4c10 100644 --- a/sbroad-core/src/executor/ir.rs +++ b/sbroad-core/src/executor/ir.rs @@ -475,7 +475,7 @@ impl ExecutionPlan { let mut node: Node = if cte_ids.contains(&node_id) { dst_node.clone() } else { - std::mem::replace(dst_node, Node::Parameter) + std::mem::replace(dst_node, Node::Parameter(None)) }; let ir_plan = self.get_ir_plan(); match node { diff --git a/sbroad-core/src/frontend/sql.rs b/sbroad-core/src/frontend/sql.rs index f111cbb86f..c76de1c43f 100644 --- a/sbroad-core/src/frontend/sql.rs +++ b/sbroad-core/src/frontend/sql.rs @@ -1244,6 +1244,14 @@ fn parse_cast_expr<M: Metadata>( assert!(!cast_types.is_empty(), "cast expression has no cast types"); + if let ParseExpression::PlanId { plan_id } = child_parse_expr { + let node = plan.get_mut_node(plan_id)?; + if let Node::Parameter(..) = node { + // Assign parameter type from the cast, just like Postgres. + *node = Node::Parameter(Some(cast_types[0].as_relation_type())); + } + } + Ok(ParseExpression::Cast { cast_types, child: Box::new(child_parse_expr), @@ -1477,14 +1485,14 @@ where } } -#[derive(Clone)] +#[derive(Clone, Debug)] enum ParseExpressionInfixOperator { InfixBool(Bool), InfixArithmetic(Arithmetic), Concat, } -#[derive(Clone)] +#[derive(Clone, Debug)] enum ParseExpression { PlanId { plan_id: usize, @@ -2654,7 +2662,7 @@ impl AbstractSyntaxTree { } } } - Node::Parameter => return Err(SbroadError::Invalid( + Node::Parameter(..) => return Err(SbroadError::Invalid( Entity::Expression, Some(SmolStr::from("Using parameter as a standalone ORDER BY expression doesn't influence sorting.")) )), diff --git a/sbroad-core/src/frontend/sql/ir/tests/params.rs b/sbroad-core/src/frontend/sql/ir/tests/params.rs index 368ea4ec53..11b3ae5be1 100644 --- a/sbroad-core/src/frontend/sql/ir/tests/params.rs +++ b/sbroad-core/src/frontend/sql/ir/tests/params.rs @@ -1,7 +1,146 @@ -use crate::ir::transformation::helpers::sql_to_optimized_ir; +use crate::errors::{SbroadError, TypeError}; +use crate::ir::relation::Type; +use crate::ir::transformation::helpers::{sql_to_ir_without_bind, sql_to_optimized_ir}; use crate::ir::value::Value; use pretty_assertions::assert_eq; +fn infer_pg_parameters_types( + query: &str, + client_types: &[Option<Type>], +) -> Result<Vec<Type>, SbroadError> { + let mut plan = sql_to_ir_without_bind(query); + plan.infer_pg_parameters_types(client_types) +} + +#[test] +fn param_type_inference() { + // no params + let types = infer_pg_parameters_types(r#"SELECT * FROM "test_space""#, &[]); + assert_eq!(types.unwrap(), []); + + // can't infer without cast + let types = infer_pg_parameters_types(r#"SELECT $1 FROM "test_space""#, &[]); + assert!(matches!( + types, + Err(SbroadError::TypeError( + TypeError::CouldNotDetermineParameterType(0) + )) + )); + + let types = infer_pg_parameters_types(r#"SELECT $1 + $2 FROM "test_space""#, &[]); + assert!(matches!( + types, + Err(SbroadError::TypeError( + TypeError::CouldNotDetermineParameterType(..) + )) + )); + + // infer types from cast + let types = infer_pg_parameters_types(r#"SELECT CAST($1 AS INTEGER) FROM "test_space""#, &[]); + assert_eq!(types.unwrap(), [Type::Integer]); + + let types = infer_pg_parameters_types(r#"SELECT $1::integer FROM "test_space""#, &[]); + assert_eq!(types.unwrap(), [Type::Integer]); + + let types = infer_pg_parameters_types(r#"SELECT $1::integer + $1 FROM "test_space""#, &[]); + assert_eq!(types.unwrap(), [Type::Integer]); + + let types = infer_pg_parameters_types(r#"SELECT $1 + $1::integer FROM "test_space""#, &[]); + assert_eq!(types.unwrap(), [Type::Integer]); + + let types = + infer_pg_parameters_types(r#"SELECT $1::integer + $1::integer FROM "test_space""#, &[]); + assert_eq!(types.unwrap(), [Type::Integer]); + + let types = infer_pg_parameters_types( + r#"SELECT $1::integer + $2::unsigned FROM "test_space""#, + &[], + ); + assert_eq!(types.unwrap(), [Type::Integer, Type::Unsigned]); + + let types = infer_pg_parameters_types( + r#"SELECT $1::integer + $3::unsigned FROM "test_space""#, + &[], + ); + assert!(matches!( + types, + Err(SbroadError::TypeError( + TypeError::CouldNotDetermineParameterType(1) + )) + )); + + // client provided a type + let types = infer_pg_parameters_types(r#"SELECT $1 FROM "test_space""#, &[Some(Type::String)]); + assert_eq!(types.unwrap(), [Type::String]); + + // client type has a higher priority + let types = infer_pg_parameters_types( + r#"SELECT $1::integer FROM "test_space""#, + &[Some(Type::String)], + ); + assert_eq!(types.unwrap(), [Type::String]); + + let types = infer_pg_parameters_types( + r#"SELECT $1::integer + $1 FROM "test_space""#, + &[Some(Type::Unsigned)], + ); + assert_eq!(types.unwrap(), [Type::Unsigned]); + + // infer one type and get another from the client + let types = infer_pg_parameters_types( + r#"SELECT $1 + $2::unsigned FROM "test_space""#, + &[Some(Type::Unsigned)], + ); + assert_eq!(types.unwrap(), [Type::Unsigned, Type::Unsigned]); + + let types = infer_pg_parameters_types( + r#"SELECT $1::unsigned + $2 FROM "test_space""#, + &[None, Some(Type::Unsigned)], + ); + assert_eq!(types.unwrap(), [Type::Unsigned, Type::Unsigned]); + + // ambiguous types + let types = infer_pg_parameters_types( + r#"SELECT $1::integer + $1::unsigned FROM "test_space""#, + &[], + ); + assert!(matches!( + types, + Err(SbroadError::TypeError(TypeError::AmbiguousParameterType( + .. + ))) + )); + + // too many client types + let types = infer_pg_parameters_types( + r#"SELECT $1 FROM "test_space""#, + &[Some(Type::String), Some(Type::Unsigned)], + ); + assert!(matches!( + types, + Err(SbroadError::UnexpectedNumberOfValues(..)) + )); + + let types = infer_pg_parameters_types(r#"SELECT $1::integer::text FROM "test_space""#, &[]); + assert_eq!(types.unwrap(), [Type::Integer]); + + let types = infer_pg_parameters_types(r#"SELECT ($1 * 1.0)::integer FROM "test_space""#, &[]); + assert!(matches!( + types, + Err(SbroadError::TypeError( + TypeError::CouldNotDetermineParameterType(0) + )) + )); + + let types = infer_pg_parameters_types(r#"SELECT $1 * 1::integer FROM "test_space""#, &[]); + assert!(matches!( + types, + Err(SbroadError::TypeError( + TypeError::CouldNotDetermineParameterType(0) + )) + )); +} + #[test] fn front_param_in_cast() { let pattern = r#"SELECT CAST(? AS INTEGER) FROM "test_space""#; diff --git a/sbroad-core/src/ir.rs b/sbroad-core/src/ir.rs index 63da79eca1..c513e2ce05 100644 --- a/sbroad-core/src/ir.rs +++ b/sbroad-core/src/ir.rs @@ -18,10 +18,10 @@ use block::Block; use ddl::Ddl; use expression::Expression; use operator::{Arithmetic, Relational}; -use relation::Table; +use relation::{Table, Type}; use crate::errors::Entity::Query; -use crate::errors::{Action, Entity, SbroadError}; +use crate::errors::{Action, Entity, SbroadError, TypeError}; use crate::executor::engine::TableVersionMap; use crate::ir::expression::Expression::StableFunction; use crate::ir::helpers::RepeatableState; @@ -79,7 +79,11 @@ pub enum Node { Ddl(Ddl), Expression(Expression), Relational(Relational), - Parameter, + // The parameter type is inferred from the context. A typical value is None, i. e. any type. + // However, there is a special case where we can be more specific. According to Postgres, + // the parameter type can be specified by casting the parameter to a particular type. + // Thus, when we cast a parameter, we also assign it a type. + Parameter(Option<Type>), } /// Plan nodes storage. @@ -939,7 +943,7 @@ impl Plan { match node { Node::Relational(rel) => Ok(rel), Node::Expression(_) - | Node::Parameter + | Node::Parameter(..) | Node::Ddl(..) | Node::Acl(..) | Node::Block(..) => Err(SbroadError::Invalid( @@ -961,7 +965,7 @@ impl Plan { match self.get_mut_node(node_id)? { Node::Relational(rel) => Ok(rel), Node::Expression(_) - | Node::Parameter + | Node::Parameter(..) | Node::Ddl(..) | Node::Acl(..) | Node::Block(..) => Err(SbroadError::Invalid( @@ -979,7 +983,7 @@ impl Plan { pub fn get_expression_node(&self, node_id: usize) -> Result<&Expression, SbroadError> { match self.get_node(node_id)? { Node::Expression(exp) => Ok(exp), - Node::Parameter => { + Node::Parameter(..) => { let node = self.constants.get(node_id); if let Some(Node::Expression(exp)) = node { Ok(exp) @@ -1009,7 +1013,7 @@ impl Plan { match node { Node::Expression(exp) => Ok(exp), Node::Relational(_) - | Node::Parameter + | Node::Parameter(..) | Node::Ddl(..) | Node::Acl(..) | Node::Block(..) => Err(SbroadError::Invalid( @@ -1326,6 +1330,108 @@ impl Plan { } } +impl Plan { + fn get_param_type(&self, param_id: NodeId) -> Result<Option<Type>, SbroadError> { + let node = self.get_node(param_id)?; + if let Node::Parameter(ty) = node { + return Ok(ty.clone()); + } + Err(SbroadError::Invalid( + Entity::Node, + Some(format_smolstr!("node is not Parameter type: {node:?}")), + )) + } + + fn set_param_type(&mut self, param_id: NodeId, ty: &Type) -> Result<(), SbroadError> { + let node = self.get_mut_node(param_id)?; + if let Node::Parameter(..) = node { + *node = Node::Parameter(Some(ty.clone())); + Ok(()) + } else { + Err(SbroadError::Invalid( + Entity::Node, + Some(format_smolstr!("node is not Parameter type: {node:?}")), + )) + } + } + + fn count_pg_parameters(&self) -> usize { + self.pg_params_map + .values() + .fold(0, |p1, p2| std::cmp::max(p1, *p2 + 1)) // idx 0 stands for $1 + } + + /// Infer parameter types specified via cast. + /// + /// # Errors + /// - Parameter type is ambiguous. + /// + /// # Panics + /// - `self.pg_params_map` missed some parameters. + pub fn infer_pg_parameters_types( + &mut self, + client_types: &[Option<Type>], + ) -> Result<Vec<Type>, SbroadError> { + let params_count = self.count_pg_parameters(); + if params_count < client_types.len() { + return Err(SbroadError::UnexpectedNumberOfValues(format_smolstr!( + "client provided {} types for {} parameters", + client_types.len(), + params_count + ))); + } + let mut inferred_types = vec![None; params_count]; + + for (node_id, param_idx) in &self.pg_params_map { + let param_type = self.get_param_type(*node_id)?; + let inferred_type = inferred_types.get(*param_idx).unwrap_or_else(|| { + panic!("param idx {param_idx} exceeds params count {params_count}") + }); + let client_type = client_types.get(*param_idx).cloned().flatten(); + match (param_type, inferred_type, client_type) { + (_, _, Some(client_type)) => { + // Client provided an explicit type, no additional checks are required. + inferred_types[*param_idx] = Some(client_type.clone()); + } + (Some(param_type), Some(inferred_type), None) => { + if ¶m_type != inferred_type { + // We've inferred 2 different types for the same parameter. + return Err(TypeError::AmbiguousParameterType( + *param_idx, + param_type, + inferred_type.clone(), + ) + .into()); + } + } + (Some(param_type), None, None) => { + // We've inferred a more specific type from the context. + inferred_types[*param_idx] = Some(param_type); + } + _ => {} + } + } + + let types = inferred_types + .into_iter() + .enumerate() + .map(|(idx, ty)| ty.ok_or(TypeError::CouldNotDetermineParameterType(idx).into())) + .collect::<Result<Vec<_>, SbroadError>>()?; + + // Specify inferred types in all parameters nodes, allowing to calculate the result type + // for queries like `SELECT $1::int + $1`. Without this correction there will be an error + // like int and scalar are not supported for arithmetic expression, despite of the fact + // that the type of parameter was specified. + // + // TODO: Avoid cloning of self.pg_params_map. + for (node_id, param_idx) in &self.pg_params_map.clone() { + self.set_param_type(*node_id, &types[*param_idx])?; + } + + Ok(types) + } +} + /// Target positions in the reference. pub type Positions = [Option<Position>; 2]; diff --git a/sbroad-core/src/ir/acl.rs b/sbroad-core/src/ir/acl.rs index b2c2959956..adec244ce3 100644 --- a/sbroad-core/src/ir/acl.rs +++ b/sbroad-core/src/ir/acl.rs @@ -291,7 +291,7 @@ impl Plan { // Check that node is ACL type (before making any distructive operations). let _ = self.get_acl_node(node_id)?; // Replace ACL with parameter node. - let node = std::mem::replace(self.get_mut_node(node_id)?, Node::Parameter); + let node = std::mem::replace(self.get_mut_node(node_id)?, Node::Parameter(None)); match node { Node::Acl(acl) => Ok(acl), _ => Err(SbroadError::Invalid( diff --git a/sbroad-core/src/ir/api/constant.rs b/sbroad-core/src/ir/api/constant.rs index 65356e0f97..c9ce32ae4d 100644 --- a/sbroad-core/src/ir/api/constant.rs +++ b/sbroad-core/src/ir/api/constant.rs @@ -99,7 +99,7 @@ impl Plan { pub fn stash_constants(&mut self) -> Result<(), SbroadError> { let constants = self.get_const_list(); for const_id in constants { - let const_node = self.nodes.replace(const_id, Node::Parameter)?; + let const_node = self.nodes.replace(const_id, Node::Parameter(None))?; self.constants.insert(const_id, const_node); } Ok(()) diff --git a/sbroad-core/src/ir/api/parameter.rs b/sbroad-core/src/ir/api/parameter.rs index 441e724545..fa5cd4e320 100644 --- a/sbroad-core/src/ir/api/parameter.rs +++ b/sbroad-core/src/ir/api/parameter.rs @@ -109,7 +109,7 @@ impl<'binder> ParamsBinder<'binder> { // otherwise we may get different hashes for plans // with tnt and pg parameters. See `subtree_hash*` tests, for (_, param_id) in &self.nodes { - if !matches!(self.plan.get_node(*param_id)?, Node::Parameter) { + if !matches!(self.plan.get_node(*param_id)?, Node::Parameter(..)) { continue; } let value_idx = *self.pg_params_map.get(param_id).unwrap_or_else(|| { @@ -349,7 +349,7 @@ impl<'binder> ParamsBinder<'binder> { } } }, - Node::Parameter | Node::Ddl(..) | Node::Acl(..) => {} + Node::Parameter(..) | Node::Ddl(..) | Node::Acl(..) => {} } } @@ -521,7 +521,7 @@ impl<'binder> ParamsBinder<'binder> { } } }, - Node::Parameter | Node::Ddl(..) | Node::Acl(..) => {} + Node::Parameter(..) | Node::Ddl(..) | Node::Acl(..) => {} } } @@ -540,7 +540,7 @@ impl<'binder> ParamsBinder<'binder> { impl Plan { pub fn add_param(&mut self) -> usize { - self.nodes.push(Node::Parameter) + self.nodes.push(Node::Parameter(None)) } /// Bind params related to `Option` clause. @@ -587,7 +587,7 @@ impl Plan { .iter() .enumerate() .filter_map(|(id, node)| { - if let Node::Parameter = node { + if let Node::Parameter(..) = node { Some(id) } else { None diff --git a/sbroad-core/src/ir/block.rs b/sbroad-core/src/ir/block.rs index df483d3b0e..b7313321ce 100644 --- a/sbroad-core/src/ir/block.rs +++ b/sbroad-core/src/ir/block.rs @@ -38,7 +38,7 @@ impl Plan { | Node::Relational(_) | Node::Ddl(..) | Node::Acl(..) - | Node::Parameter => Err(SbroadError::Invalid( + | Node::Parameter(..) => Err(SbroadError::Invalid( Entity::Node, Some(format_smolstr!( "node {node:?} (id {node_id}) is not Block type" @@ -59,7 +59,7 @@ impl Plan { | Node::Relational(_) | Node::Ddl(..) | Node::Acl(..) - | Node::Parameter => Err(SbroadError::Invalid( + | Node::Parameter(..) => Err(SbroadError::Invalid( Entity::Node, Some(format_smolstr!( "node {node:?} (id {node_id}) is not Block type" diff --git a/sbroad-core/src/ir/ddl.rs b/sbroad-core/src/ir/ddl.rs index f692c28533..811d02ffa3 100644 --- a/sbroad-core/src/ir/ddl.rs +++ b/sbroad-core/src/ir/ddl.rs @@ -1,6 +1,6 @@ use crate::{ errors::{Entity, SbroadError}, - ir::{relation::Type, Node, Plan}, + ir::{relation::Type as RelationType, Node, Plan}, }; use serde::{Deserialize, Serialize}; use smol_str::{format_smolstr, SmolStr, ToSmolStr}; @@ -13,7 +13,7 @@ use tarantool::{ #[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] pub struct ColumnDef { pub name: SmolStr, - pub data_type: Type, + pub data_type: RelationType, pub is_nullable: bool, } @@ -21,7 +21,7 @@ impl Default for ColumnDef { fn default() -> Self { Self { name: SmolStr::default(), - data_type: Type::default(), + data_type: RelationType::default(), is_nullable: true, } } @@ -29,7 +29,7 @@ impl Default for ColumnDef { #[derive(Clone, Debug, Default, Deserialize, PartialEq, Eq, Serialize)] pub struct ParamDef { - pub data_type: Type, + pub data_type: RelationType, } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Eq, Serialize)] @@ -201,7 +201,7 @@ impl Plan { // Check that node is DDL type (before making any distructive operations). let _ = self.get_ddl_node(node_id)?; // Replace DDL with parameter node. - let node = std::mem::replace(self.get_mut_node(node_id)?, Node::Parameter); + let node = std::mem::replace(self.get_mut_node(node_id)?, Node::Parameter(None)); match node { Node::Ddl(ddl) => Ok(ddl), _ => Err(SbroadError::Invalid( diff --git a/sbroad-core/src/ir/distribution.rs b/sbroad-core/src/ir/distribution.rs index 6a60f4a7bc..16bc3e4b69 100644 --- a/sbroad-core/src/ir/distribution.rs +++ b/sbroad-core/src/ir/distribution.rs @@ -795,7 +795,7 @@ impl Plan { .to_smolstr(), ), )), - Node::Parameter => Err(SbroadError::Invalid( + Node::Parameter(..) => Err(SbroadError::Invalid( Entity::Distribution, Some("Failed to get distribution for a parameter node.".to_smolstr()), )), diff --git a/sbroad-core/src/ir/expression/types.rs b/sbroad-core/src/ir/expression/types.rs index 099b677798..4cc8c02fb3 100644 --- a/sbroad-core/src/ir/expression/types.rs +++ b/sbroad-core/src/ir/expression/types.rs @@ -17,7 +17,7 @@ impl Plan { )), // Parameter nodes must recalculate their type during // binding (see `bind_params` function). - Node::Parameter => Ok(Type::Scalar), + Node::Parameter(ty) => Ok(ty.clone().unwrap_or(Type::Scalar)), Node::Ddl(_) => Err(SbroadError::Invalid( Entity::Node, Some("DDL node has no type".to_smolstr()), diff --git a/sbroad-core/src/ir/helpers.rs b/sbroad-core/src/ir/helpers.rs index 11626f1b3f..f553dd4b32 100644 --- a/sbroad-core/src/ir/helpers.rs +++ b/sbroad-core/src/ir/helpers.rs @@ -85,7 +85,7 @@ impl Plan { let child_node = self.get_node(*child).expect("Alias must have a child node"); let child = match child_node { Node::Expression(child_expr) => format!("{child_expr:?}"), - Node::Parameter => String::from("parameter"), + Node::Parameter(..) => String::from("parameter"), Node::Relational(rel) => format!("{rel:?}"), // TODO: fix `fix_betweens` logic to cover SubQueries with References. _ => unreachable!("unexpected Alias child node"), diff --git a/sbroad-core/src/ir/transformation/helpers.rs b/sbroad-core/src/ir/transformation/helpers.rs index edc01e1208..6eb3f40ea5 100644 --- a/sbroad-core/src/ir/transformation/helpers.rs +++ b/sbroad-core/src/ir/transformation/helpers.rs @@ -29,13 +29,17 @@ pub fn sql_to_optimized_ir(query: &str, params: Vec<Value>) -> Plan { /// if query is not correct #[must_use] pub fn sql_to_ir(query: &str, params: Vec<Value>) -> Plan { - let metadata = &RouterConfigurationMock::new(); - let mut plan = AbstractSyntaxTree::transform_into_plan(query, metadata).unwrap(); + let mut plan = sql_to_ir_without_bind(query); plan.bind_params(params).unwrap(); plan.apply_options().unwrap(); plan } +pub fn sql_to_ir_without_bind(query: &str) -> Plan { + let metadata = &RouterConfigurationMock::new(); + AbstractSyntaxTree::transform_into_plan(query, metadata).unwrap() +} + /// Compiles and transforms an SQL query to a new parameterized SQL. /// /// # Panics diff --git a/sbroad-core/src/ir/tree/expression.rs b/sbroad-core/src/ir/tree/expression.rs index b1621e2fe4..53737b4f52 100644 --- a/sbroad-core/src/ir/tree/expression.rs +++ b/sbroad-core/src/ir/tree/expression.rs @@ -170,7 +170,7 @@ fn expression_next<'nodes>( | Node::Block(_) | Node::Ddl(_) | Node::Relational(_) - | Node::Parameter => None, + | Node::Parameter(..) => None, } } None => None, diff --git a/sbroad-core/src/ir/tree/relation.rs b/sbroad-core/src/ir/tree/relation.rs index ebbbe300ce..0bb709fb5a 100644 --- a/sbroad-core/src/ir/tree/relation.rs +++ b/sbroad-core/src/ir/tree/relation.rs @@ -110,7 +110,7 @@ fn relational_next<'nodes>( Some( Node::Relational(Relational::ScanRelation { .. }) | Node::Expression(_) - | Node::Parameter + | Node::Parameter(_) | Node::Ddl(_) | Node::Acl(_) | Node::Block(_), diff --git a/sbroad-core/src/ir/tree/subtree.rs b/sbroad-core/src/ir/tree/subtree.rs index 8066acbc6c..9c69a32e1b 100644 --- a/sbroad-core/src/ir/tree/subtree.rs +++ b/sbroad-core/src/ir/tree/subtree.rs @@ -195,7 +195,7 @@ fn subtree_next<'plan>( ) -> Option<&'plan usize> { if let Some(child) = iter.get_nodes().arena.get(iter.get_current()) { return match child { - Node::Parameter | Node::Ddl(..) | Node::Acl(..) | Node::Block(..) => None, + Node::Parameter(..) | Node::Ddl(..) | Node::Acl(..) | Node::Block(..) => None, Node::Expression(expr) => match expr { Expression::Alias { .. } | Expression::ExprInParentheses { .. } -- GitLab