From 10442f7d7b30090dcfd6149ecf29545dab9944b7 Mon Sep 17 00:00:00 2001 From: "ms.evilhat" <ms.evilhat@gmail.com> Date: Wed, 8 Feb 2023 10:21:32 +0300 Subject: [PATCH] fix: fix usage alias with arithmetic expression previously we suggested that operand of arithmetic expression is column. it allows use alias (a as a1 + b as b2). ofcourse operand must be value (a), while alias usage must be available for expression (a + b as sum) --- .../test/integration/arithmetic_test.lua | 10 +++ sbroad-core/src/frontend/sql.rs | 51 ++++++++++- sbroad-core/src/frontend/sql/ast.rs | 86 ++++++++++++++++++ sbroad-core/src/frontend/sql/ast/tests.rs | 90 +++++++++++++++---- sbroad-core/src/frontend/sql/query.pest | 5 +- .../sql/arithmetic_projection_ast.yaml | 18 ++-- .../sql/arithmetic_selection_ast.yaml | 40 ++++----- 7 files changed, 244 insertions(+), 56 deletions(-) diff --git a/sbroad-cartridge/test_app/test/integration/arithmetic_test.lua b/sbroad-cartridge/test_app/test/integration/arithmetic_test.lua index c4ab8f7c1e..ad3c89712f 100644 --- a/sbroad-cartridge/test_app/test/integration/arithmetic_test.lua +++ b/sbroad-cartridge/test_app/test/integration/arithmetic_test.lua @@ -218,6 +218,16 @@ end g2.test_arithmetic_invalid = function() local api = cluster:server("api-1").net_box + local _, err = api:call("sbroad.execute", { + [[select "id" as "alias1" + "a" as "alias2" from "arithmetic_space"]], {} + }) + t.assert_str_contains(tostring(err), "rule parsing error") + + local _, err = api:call("sbroad.execute", { + [[select ("id" + "a") as "alias1" + "b" as "alias2" from "arithmetic_space"]], {} + }) + t.assert_str_contains(tostring(err), "rule parsing error") + local _, err = api:call("sbroad.execute", { [[select "id" % 2 from "arithmetic_space"]], {} }) t.assert_str_contains(tostring(err), "rule parsing error") diff --git a/sbroad-core/src/frontend/sql.rs b/sbroad-core/src/frontend/sql.rs index f94346189e..1cb30778ab 100644 --- a/sbroad-core/src/frontend/sql.rs +++ b/sbroad-core/src/frontend/sql.rs @@ -797,21 +797,63 @@ impl Ast for AbstractSyntaxTree { )); } } - Type::Multiplication | Type::Addition => { + Type::ArithmeticExprAlias => { + // left child is Addition, Multiplication or ArithParentheses + let ast_left_id = ast_column.children.first().ok_or_else(|| { + SbroadError::UnexpectedNumberOfValues( + "ArithmeticExprAlias has no children.".into(), + ) + })?; + + let arithmetic_parse_node = self.nodes.get_node(*ast_left_id)?; + if arithmetic_parse_node.rule != Type::Multiplication + && arithmetic_parse_node.rule != Type::Addition + { + return Err(SbroadError::Invalid( + Entity::Node, + Some(format!("expected Multiplication or Addition as the first child of ArithmeticExprAlias, got {}", + arithmetic_parse_node.rule)), + )); + } + let cond_id = get_arithmetic_cond_id( &mut plan, - ast_column, + arithmetic_parse_node, &map, &mut arithmetic_expression_ids, &mut rows, )?; - columns.push(cond_id); + + // right child is AliasName if exists + // else means that arithmetic expression does not have an alias + match ast_column.children.get(1).ok_or_else(|| { + SbroadError::NotFound( + Entity::Node, + "that is right node with index 1 among ArithmeticExprAlias children" + .into(), + ) + }) { + Ok(ast_name_id) => { + let name = self + .nodes + .get_node(*ast_name_id)? + .value + .as_ref() + .ok_or_else(|| SbroadError::NotFound(Entity::Name, "of Alias".into()))?; + + let plan_alias_id = plan + .nodes + .add_alias(&normalize_name_from_sql(name), cond_id)?; + columns.push(plan_alias_id); + }, + Err(_) => { columns.push(cond_id); }, + } } _ => { return Err(SbroadError::Invalid( Entity::Type, Some(format!( - "expected a Column, Asterisk, Multiplication or Addition in projection, got {:?}.", + "expected a Column, Asterisk, ArithmeticExprAlias in projection, got {:?}.", ast_column.rule )), )); @@ -942,6 +984,7 @@ impl Ast for AbstractSyntaxTree { } Type::AliasName | Type::Add + | Type::ArithmeticExprAlias | Type::ArithParentheses | Type::ColumnName | Type::Divide diff --git a/sbroad-core/src/frontend/sql/ast.rs b/sbroad-core/src/frontend/sql/ast.rs index 505c075a70..8c9afc0825 100644 --- a/sbroad-core/src/frontend/sql/ast.rs +++ b/sbroad-core/src/frontend/sql/ast.rs @@ -6,6 +6,7 @@ extern crate pest; use std::collections::{hash_map::Entry, HashMap, HashSet}; +use std::fmt; use std::mem::swap; use pest::iterators::Pair; @@ -30,6 +31,7 @@ pub enum Type { AliasName, And, ArithmeticExpr, + ArithmeticExprAlias, ArithParentheses, Asterisk, Between, @@ -110,6 +112,7 @@ impl Type { Rule::AliasName => Ok(Type::AliasName), Rule::And => Ok(Type::And), Rule::ArithmeticExpr => Ok(Type::ArithmeticExpr), + Rule::ArithmeticExprAlias => Ok(Type::ArithmeticExprAlias), Rule::ArithParentheses => Ok(Type::ArithParentheses), Rule::Asterisk => Ok(Type::Asterisk), Rule::Between => Ok(Type::Between), @@ -185,6 +188,89 @@ impl Type { } } +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let p = match self { + Type::Add => "Add".to_string(), + Type::Addition => "Addition".to_string(), + Type::Alias => "Alias".to_string(), + Type::AliasName => "AliasName".to_string(), + Type::And => "And".to_string(), + Type::ArithmeticExpr => "ArithmeticExpr".to_string(), + Type::ArithmeticExprAlias => "ArithmeticExprAlias".to_string(), + Type::ArithParentheses => "ArithParentheses".to_string(), + Type::Asterisk => "Asterisk".to_string(), + Type::Between => "Between".to_string(), + Type::Cast => "Cast".to_string(), + Type::Column => "Column".to_string(), + Type::ColumnName => "ColumnName".to_string(), + Type::Concat => "Concat".to_string(), + Type::Condition => "Condition".to_string(), + Type::Decimal => "Decimal".to_string(), + Type::Divide => "Divide".to_string(), + Type::Double => "Double".to_string(), + Type::Eq => "Eq".to_string(), + Type::Except => "Except".to_string(), + Type::Explain => "Explain".to_string(), + Type::False => "False".to_string(), + Type::Function => "Function".to_string(), + Type::FunctionName => "FunctionName".to_string(), + Type::Gt => "Gt".to_string(), + Type::GtEq => "GtEq".to_string(), + Type::In => "In".to_string(), + Type::InnerJoin => "InnerJoin".to_string(), + Type::Insert => "Insert".to_string(), + Type::Integer => "Integer".to_string(), + Type::IsNotNull => "IsNotNull".to_string(), + Type::IsNull => "IsNull".to_string(), + Type::Length => "Length".to_string(), + Type::Lt => "Lt".to_string(), + Type::LtEq => "LtEq".to_string(), + Type::Multiplication => "Multiplication".to_string(), + Type::Multiply => "Multiply".to_string(), + Type::Name => "Name".to_string(), + Type::NotEq => "NotEq".to_string(), + Type::NotIn => "NotIn".to_string(), + Type::Null => "Null".to_string(), + Type::Or => "Or".to_string(), + Type::Parameter => "Parameter".to_string(), + Type::Parentheses => "Parentheses".to_string(), + Type::Primary => "Primary".to_string(), + Type::Projection => "Projection".to_string(), + Type::Reference => "Reference".to_string(), + Type::Row => "Row".to_string(), + Type::Scan => "Scan".to_string(), + Type::ScanName => "ScanName".to_string(), + Type::Select => "Select".to_string(), + Type::Selection => "Selection".to_string(), + Type::String => "String".to_string(), + Type::SubQuery => "SubQuery".to_string(), + Type::SubQueryName => "SubQueryName".to_string(), + Type::Subtract => "Subtract".to_string(), + Type::Table => "Table".to_string(), + Type::TargetColumns => "TargetColumns".to_string(), + Type::True => "True".to_string(), + Type::TypeAny => "TypeAny".to_string(), + Type::TypeBool => "TypeBool".to_string(), + Type::TypeDecimal => "TypeDecimal".to_string(), + Type::TypeDouble => "TypeDouble".to_string(), + Type::TypeInt => "TypeInt".to_string(), + Type::TypeNumber => "TypeNumber".to_string(), + Type::TypeScalar => "TypeScalar".to_string(), + Type::TypeString => "TypeString".to_string(), + Type::TypeText => "TypeText".to_string(), + Type::TypeUnsigned => "TypeUnsigned".to_string(), + Type::TypeVarchar => "TypeVarchar".to_string(), + Type::UnionAll => "UnionAll".to_string(), + Type::Unsigned => "Unsigned".to_string(), + Type::Value => "Value".to_string(), + Type::Values => "Values".to_string(), + Type::ValuesRow => "ValuesRow".to_string(), + }; + write!(f, "{p}") + } +} + /// Parse node is a wrapper over the pest pair. #[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] pub struct ParseNode { diff --git a/sbroad-core/src/frontend/sql/ast/tests.rs b/sbroad-core/src/frontend/sql/ast/tests.rs index f60ace70c3..f4f3c9fa2a 100644 --- a/sbroad-core/src/frontend/sql/ast/tests.rs +++ b/sbroad-core/src/frontend/sql/ast/tests.rs @@ -222,10 +222,6 @@ fn sql_arithmetic_selection_ast() { let node = ast.nodes.get_node(a_id).unwrap(); assert_eq!(node.rule, Type::Reference); - let (_, col_id) = iter.next().unwrap(); - let node = ast.nodes.get_node(col_id).unwrap(); - assert_eq!(node.rule, Type::Column); - let (_, add_id) = iter.next().unwrap(); let node = ast.nodes.get_node(add_id).unwrap(); assert_eq!(node.rule, Type::Add); @@ -239,10 +235,6 @@ fn sql_arithmetic_selection_ast() { let node = ast.nodes.get_node(b_id).unwrap(); assert_eq!(node.rule, Type::Reference); - let (_, col_id) = iter.next().unwrap(); - let node = ast.nodes.get_node(col_id).unwrap(); - assert_eq!(node.rule, Type::Column); - let (_, addition_id) = iter.next().unwrap(); let node = ast.nodes.get_node(addition_id).unwrap(); assert_eq!(node.rule, Type::Addition); @@ -327,9 +319,72 @@ fn sql_arithmetic_projection_ast() { let node = ast.nodes.get_node(a_id).unwrap(); assert_eq!(node.rule, Type::Reference); - let (_, col_id) = iter.next().unwrap(); - let node = ast.nodes.get_node(col_id).unwrap(); - assert_eq!(node.rule, Type::Column); + let (_, add_id) = iter.next().unwrap(); + let node = ast.nodes.get_node(add_id).unwrap(); + assert_eq!(node.rule, Type::Add); + + let (_, sel_name_b_id) = iter.next().unwrap(); + let node = ast.nodes.get_node(sel_name_b_id).unwrap(); + assert_eq!(node.rule, Type::ColumnName); + assert_eq!(node.value, Some("b".to_string())); + + let (_, b_id) = iter.next().unwrap(); + let node = ast.nodes.get_node(b_id).unwrap(); + assert_eq!(node.rule, Type::Reference); + + let (_, addition_id) = iter.next().unwrap(); + let node = ast.nodes.get_node(addition_id).unwrap(); + assert_eq!(node.rule, Type::Addition); + + let (_, arithm_id) = iter.next().unwrap(); + let node = ast.nodes.get_node(arithm_id).unwrap(); + assert_eq!(node.rule, Type::ArithmeticExprAlias); + + let (_, proj_id) = iter.next().unwrap(); + let node = ast.nodes.get_node(proj_id).unwrap(); + assert_eq!(node.rule, Type::Projection); + + assert_eq!(None, iter.next()); +} + +#[test] +fn sql_arithmetic_projection_alias_ast() { + let ast = AbstractSyntaxTree::new("select a as alias1 + b as alias2 from t").unwrap_err(); + assert_eq!( + format!( + r#"rule parsing error: --> 1:10 + | +1 | select a as alias1 + b as alias2 from t + | ^--- + | + = expected Multiply, Divide, Add, or Subtract"#, + ), + format!("{ast}"), + ); + + let ast = AbstractSyntaxTree::new("select a + b as sum from t").unwrap(); + + let top = ast.top.unwrap(); + let mut dft_post = PostOrder::with_capacity(|node| ast.nodes.ast_iter(node), 64); + let mut iter = dft_post.iter(top); + + let (_, table_id) = iter.next().unwrap(); + let node = ast.nodes.get_node(table_id).unwrap(); + assert_eq!(node.rule, Type::Table); + assert_eq!(node.value, Some("t".to_string())); + + let (_, scan_id) = iter.next().unwrap(); + let node = ast.nodes.get_node(scan_id).unwrap(); + assert_eq!(node.rule, Type::Scan); + + let (_, sel_name_a_id) = iter.next().unwrap(); + let node = ast.nodes.get_node(sel_name_a_id).unwrap(); + assert_eq!(node.rule, Type::ColumnName); + assert_eq!(node.value, Some("a".to_string())); + + let (_, a_id) = iter.next().unwrap(); + let node = ast.nodes.get_node(a_id).unwrap(); + assert_eq!(node.rule, Type::Reference); let (_, add_id) = iter.next().unwrap(); let node = ast.nodes.get_node(add_id).unwrap(); @@ -346,12 +401,17 @@ fn sql_arithmetic_projection_ast() { let (_, col_id) = iter.next().unwrap(); let node = ast.nodes.get_node(col_id).unwrap(); - assert_eq!(node.rule, Type::Column); - - let (_, addition_id) = iter.next().unwrap(); - let node = ast.nodes.get_node(addition_id).unwrap(); assert_eq!(node.rule, Type::Addition); + let (_, alias_id) = iter.next().unwrap(); + let node = ast.nodes.get_node(alias_id).unwrap(); + assert_eq!(node.rule, Type::AliasName); + assert_eq!(node.value, Some("sum".into())); + + let (_, arithm_id) = iter.next().unwrap(); + let node = ast.nodes.get_node(arithm_id).unwrap(); + assert_eq!(node.rule, Type::ArithmeticExprAlias); + let (_, proj_id) = iter.next().unwrap(); let node = ast.nodes.get_node(proj_id).unwrap(); assert_eq!(node.rule, Type::Projection); diff --git a/sbroad-core/src/frontend/sql/query.pest b/sbroad-core/src/frontend/sql/query.pest index 85e6b8d892..981e6ef906 100644 --- a/sbroad-core/src/frontend/sql/query.pest +++ b/sbroad-core/src/frontend/sql/query.pest @@ -9,9 +9,10 @@ Query = _{ Except | UnionAll | Select | Values | Insert } (((^"inner" ~ ^"join") | ^"join") ~ InnerJoin ~ ^"on" ~ Condition)? ~ (^"where" ~ Selection)? } - Projection = { (Asterisk | ArithmeticExpr | Column) ~ ("," ~ (Asterisk | ArithmeticExpr | Column))*? } + Projection = { (Asterisk | ArithmeticExprAlias | Column) ~ ("," ~ (Asterisk | ArithmeticExprAlias | Column))*? } Column = { Alias | Value } Alias = {Value ~ ^"as" ~ AliasName } + ArithmeticExprAlias = { ArithmeticExpr ~ (^"as" ~ AliasName)? } AliasName = @{ Name } Reference = { (ScanName ~ "." ~ ColumnName) | ColumnName } ColumnName = @{ Name } @@ -33,7 +34,7 @@ Query = _{ Except | UnionAll | Select | Values | Insert } ArithmeticExpr = _{ Addition | Multiplication | ArithParentheses } ArithParentheses = { "(" ~ ArithmeticExpr ~ ")" } - ArithELeft = _{ ArithParentheses | Column } + ArithELeft = _{ ArithParentheses | Value } Multiplication = { ArithELeft ~ (Multiply | Divide) ~ MultiplicationRight } MultiplicationRight = _{ Multiplication | ArithELeft } Addition = { AdditionLeft ~ (Add | Subtract) ~ AdditionRight } diff --git a/sbroad-core/tests/artifactory/frontend/sql/arithmetic_projection_ast.yaml b/sbroad-core/tests/artifactory/frontend/sql/arithmetic_projection_ast.yaml index 9e0ea7e5e0..635c718464 100644 --- a/sbroad-core/tests/artifactory/frontend/sql/arithmetic_projection_ast.yaml +++ b/sbroad-core/tests/artifactory/frontend/sql/arithmetic_projection_ast.yaml @@ -18,14 +18,14 @@ nodes: rule: Projection value: ~ - children: - - 9 - - 8 - - 5 - rule: Addition + - 5 + rule: ArithmeticExprAlias value: ~ - children: + - 9 + - 8 - 6 - rule: Column + rule: Addition value: ~ - children: - 7 @@ -39,10 +39,6 @@ nodes: value: "+" - children: - 10 - rule: Column - value: ~ - - children: - - 11 rule: Reference value: ~ - children: [] @@ -50,7 +46,7 @@ nodes: value: "a" top: 3 map: - 10: - - 1 6: + - 1 + 9: - 1 \ No newline at end of file diff --git a/sbroad-core/tests/artifactory/frontend/sql/arithmetic_selection_ast.yaml b/sbroad-core/tests/artifactory/frontend/sql/arithmetic_selection_ast.yaml index 07d0be3fa2..7d6d4d6349 100644 --- a/sbroad-core/tests/artifactory/frontend/sql/arithmetic_selection_ast.yaml +++ b/sbroad-core/tests/artifactory/frontend/sql/arithmetic_selection_ast.yaml @@ -2,11 +2,11 @@ nodes: arena: - children: - - 14 + - 12 rule: Select value: ~ - children: - - 12 + - 10 - 2 rule: Selection value: ~ @@ -19,17 +19,13 @@ nodes: rule: Unsigned value: 1 - children: - - 9 - 8 + - 7 - 5 rule: Addition value: ~ - children: - 6 - rule: Column - value: ~ - - children: - - 7 rule: Reference value: ~ - children: [] @@ -39,18 +35,14 @@ nodes: rule: Add value: "+" - children: - - 10 - rule: Column - value: ~ - - children: - - 11 + - 9 rule: Reference value: ~ - children: [] rule: ColumnName value: "a" - children: - - 13 + - 11 rule: Scan value: ~ - children: [] @@ -58,15 +50,15 @@ nodes: value: "t" - children: - 1 - - 15 + - 13 rule: Projection value: ~ - children: - - 19 + - 17 rule: Column value: ~ - children: - - 17 + - 15 rule: Reference value: ~ - children: [] @@ -76,15 +68,15 @@ nodes: rule: AliasName value: "a" - children: + - 14 - 16 - - 18 rule: Alias value: ~ -top: 14 +top: 12 map: - 16: - - 1 - 6: - - 12 - 10: - - 12 \ No newline at end of file + 5: + - 10 + 8: + - 10 + 14: + - 1 \ No newline at end of file -- GitLab