From 10442f7d7b30090dcfd6149ecf29545dab9944b7 Mon Sep 17 00:00:00 2001
From: "ms.evilhat" <ms.evilhat@gmail.com>
Date: Wed, 8 Feb 2023 10:21:32 +0300
Subject: [PATCH] fix: fix usage alias with arithmetic expression

previously we suggested that operand of arithmetic expression is column.
it allows use alias (a as a1 + b as b2). ofcourse operand must be value
(a), while alias usage must be available for expression (a + b as sum)
---
 .../test/integration/arithmetic_test.lua      | 10 +++
 sbroad-core/src/frontend/sql.rs               | 51 ++++++++++-
 sbroad-core/src/frontend/sql/ast.rs           | 86 ++++++++++++++++++
 sbroad-core/src/frontend/sql/ast/tests.rs     | 90 +++++++++++++++----
 sbroad-core/src/frontend/sql/query.pest       |  5 +-
 .../sql/arithmetic_projection_ast.yaml        | 18 ++--
 .../sql/arithmetic_selection_ast.yaml         | 40 ++++-----
 7 files changed, 244 insertions(+), 56 deletions(-)

diff --git a/sbroad-cartridge/test_app/test/integration/arithmetic_test.lua b/sbroad-cartridge/test_app/test/integration/arithmetic_test.lua
index c4ab8f7c1e..ad3c89712f 100644
--- a/sbroad-cartridge/test_app/test/integration/arithmetic_test.lua
+++ b/sbroad-cartridge/test_app/test/integration/arithmetic_test.lua
@@ -218,6 +218,16 @@ end
 g2.test_arithmetic_invalid = function()
     local api = cluster:server("api-1").net_box
 
+    local _, err = api:call("sbroad.execute", {
+        [[select "id" as "alias1" + "a" as "alias2" from "arithmetic_space"]], {}
+    })
+    t.assert_str_contains(tostring(err), "rule parsing error")
+
+    local _, err = api:call("sbroad.execute", {
+        [[select ("id" + "a") as "alias1" + "b" as "alias2" from "arithmetic_space"]], {}
+    })
+    t.assert_str_contains(tostring(err), "rule parsing error")
+
     local _, err = api:call("sbroad.execute", { [[select "id" % 2 from "arithmetic_space"]], {} })
     t.assert_str_contains(tostring(err), "rule parsing error")
 
diff --git a/sbroad-core/src/frontend/sql.rs b/sbroad-core/src/frontend/sql.rs
index f94346189e..1cb30778ab 100644
--- a/sbroad-core/src/frontend/sql.rs
+++ b/sbroad-core/src/frontend/sql.rs
@@ -797,21 +797,63 @@ impl Ast for AbstractSyntaxTree {
                                     ));
                                 }
                             }
-                            Type::Multiplication | Type::Addition => {
+                            Type::ArithmeticExprAlias => {
+                                // left child is Addition, Multiplication or ArithParentheses
+                                let ast_left_id = ast_column.children.first().ok_or_else(|| {
+                                    SbroadError::UnexpectedNumberOfValues(
+                                        "ArithmeticExprAlias has no children.".into(),
+                                    )
+                                })?;
+
+                                let arithmetic_parse_node = self.nodes.get_node(*ast_left_id)?;
+                                if arithmetic_parse_node.rule != Type::Multiplication
+                                    && arithmetic_parse_node.rule != Type::Addition
+                                {
+                                    return Err(SbroadError::Invalid(
+                                        Entity::Node,
+                                        Some(format!("expected Multiplication or Addition as the first child of ArithmeticExprAlias, got {}",
+                                        arithmetic_parse_node.rule)),
+                                    ));
+                                }
+
                                 let cond_id = get_arithmetic_cond_id(
                                     &mut plan,
-                                    ast_column,
+                                    arithmetic_parse_node,
                                     &map,
                                     &mut arithmetic_expression_ids,
                                     &mut rows,
                                 )?;
-                                columns.push(cond_id);
+
+                                // right child is AliasName if exists
+                                // else means that arithmetic expression does not have an alias
+                                match ast_column.children.get(1).ok_or_else(|| {
+                                    SbroadError::NotFound(
+                                        Entity::Node,
+                                        "that is right node with index 1 among ArithmeticExprAlias children"
+                                            .into(),
+                                    )
+                                }) {
+                                    Ok(ast_name_id) => {
+                                        let name = self
+                                            .nodes
+                                            .get_node(*ast_name_id)?
+                                            .value
+                                            .as_ref()
+                                            .ok_or_else(|| SbroadError::NotFound(Entity::Name, "of Alias".into()))?;
+
+                                        let plan_alias_id = plan
+                                            .nodes
+                                            .add_alias(&normalize_name_from_sql(name), cond_id)?;
+                                        columns.push(plan_alias_id);
+                                    },
+                                    Err(_) => { columns.push(cond_id); },
+                                }
                             }
                             _ => {
                                 return Err(SbroadError::Invalid(
                                     Entity::Type,
                                     Some(format!(
-                                        "expected a Column, Asterisk, Multiplication or Addition in projection, got {:?}.",
+                                        "expected a Column, Asterisk, ArithmeticExprAlias in projection, got {:?}.",
                                         ast_column.rule
                                     )),
                                 ));
@@ -942,6 +984,7 @@ impl Ast for AbstractSyntaxTree {
                 }
                 Type::AliasName
                 | Type::Add
+                | Type::ArithmeticExprAlias
                 | Type::ArithParentheses
                 | Type::ColumnName
                 | Type::Divide
diff --git a/sbroad-core/src/frontend/sql/ast.rs b/sbroad-core/src/frontend/sql/ast.rs
index 505c075a70..8c9afc0825 100644
--- a/sbroad-core/src/frontend/sql/ast.rs
+++ b/sbroad-core/src/frontend/sql/ast.rs
@@ -6,6 +6,7 @@
 extern crate pest;
 
 use std::collections::{hash_map::Entry, HashMap, HashSet};
+use std::fmt;
 use std::mem::swap;
 
 use pest::iterators::Pair;
@@ -30,6 +31,7 @@ pub enum Type {
     AliasName,
     And,
     ArithmeticExpr,
+    ArithmeticExprAlias,
     ArithParentheses,
     Asterisk,
     Between,
@@ -110,6 +112,7 @@ impl Type {
             Rule::AliasName => Ok(Type::AliasName),
             Rule::And => Ok(Type::And),
             Rule::ArithmeticExpr => Ok(Type::ArithmeticExpr),
+            Rule::ArithmeticExprAlias => Ok(Type::ArithmeticExprAlias),
             Rule::ArithParentheses => Ok(Type::ArithParentheses),
             Rule::Asterisk => Ok(Type::Asterisk),
             Rule::Between => Ok(Type::Between),
@@ -185,6 +188,89 @@ impl Type {
     }
 }
 
+impl fmt::Display for Type {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let p = match self {
+            Type::Add => "Add".to_string(),
+            Type::Addition => "Addition".to_string(),
+            Type::Alias => "Alias".to_string(),
+            Type::AliasName => "AliasName".to_string(),
+            Type::And => "And".to_string(),
+            Type::ArithmeticExpr => "ArithmeticExpr".to_string(),
+            Type::ArithmeticExprAlias => "ArithmeticExprAlias".to_string(),
+            Type::ArithParentheses => "ArithParentheses".to_string(),
+            Type::Asterisk => "Asterisk".to_string(),
+            Type::Between => "Between".to_string(),
+            Type::Cast => "Cast".to_string(),
+            Type::Column => "Column".to_string(),
+            Type::ColumnName => "ColumnName".to_string(),
+            Type::Concat => "Concat".to_string(),
+            Type::Condition => "Condition".to_string(),
+            Type::Decimal => "Decimal".to_string(),
+            Type::Divide => "Divide".to_string(),
+            Type::Double => "Double".to_string(),
+            Type::Eq => "Eq".to_string(),
+            Type::Except => "Except".to_string(),
+            Type::Explain => "Explain".to_string(),
+            Type::False => "False".to_string(),
+            Type::Function => "Function".to_string(),
+            Type::FunctionName => "FunctionName".to_string(),
+            Type::Gt => "Gt".to_string(),
+            Type::GtEq => "GtEq".to_string(),
+            Type::In => "In".to_string(),
+            Type::InnerJoin => "InnerJoin".to_string(),
+            Type::Insert => "Insert".to_string(),
+            Type::Integer => "Integer".to_string(),
+            Type::IsNotNull => "IsNotNull".to_string(),
+            Type::IsNull => "IsNull".to_string(),
+            Type::Length => "Length".to_string(),
+            Type::Lt => "Lt".to_string(),
+            Type::LtEq => "LtEq".to_string(),
+            Type::Multiplication => "Multiplication".to_string(),
+            Type::Multiply => "Multiply".to_string(),
+            Type::Name => "Name".to_string(),
+            Type::NotEq => "NotEq".to_string(),
+            Type::NotIn => "NotIn".to_string(),
+            Type::Null => "Null".to_string(),
+            Type::Or => "Or".to_string(),
+            Type::Parameter => "Parameter".to_string(),
+            Type::Parentheses => "Parentheses".to_string(),
+            Type::Primary => "Primary".to_string(),
+            Type::Projection => "Projection".to_string(),
+            Type::Reference => "Reference".to_string(),
+            Type::Row => "Row".to_string(),
+            Type::Scan => "Scan".to_string(),
+            Type::ScanName => "ScanName".to_string(),
+            Type::Select => "Select".to_string(),
+            Type::Selection => "Selection".to_string(),
+            Type::String => "String".to_string(),
+            Type::SubQuery => "SubQuery".to_string(),
+            Type::SubQueryName => "SubQueryName".to_string(),
+            Type::Subtract => "Subtract".to_string(),
+            Type::Table => "Table".to_string(),
+            Type::TargetColumns => "TargetColumns".to_string(),
+            Type::True => "True".to_string(),
+            Type::TypeAny => "TypeAny".to_string(),
+            Type::TypeBool => "TypeBool".to_string(),
+            Type::TypeDecimal => "TypeDecimal".to_string(),
+            Type::TypeDouble => "TypeDouble".to_string(),
+            Type::TypeInt => "TypeInt".to_string(),
+            Type::TypeNumber => "TypeNumber".to_string(),
+            Type::TypeScalar => "TypeScalar".to_string(),
+            Type::TypeString => "TypeString".to_string(),
+            Type::TypeText => "TypeText".to_string(),
+            Type::TypeUnsigned => "TypeUnsigned".to_string(),
+            Type::TypeVarchar => "TypeVarchar".to_string(),
+            Type::UnionAll => "UnionAll".to_string(),
+            Type::Unsigned => "Unsigned".to_string(),
+            Type::Value => "Value".to_string(),
+            Type::Values => "Values".to_string(),
+            Type::ValuesRow => "ValuesRow".to_string(),
+        };
+        write!(f, "{p}")
+    }
+}
+
 /// Parse node is a wrapper over the pest pair.
 #[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)]
 pub struct ParseNode {
diff --git a/sbroad-core/src/frontend/sql/ast/tests.rs b/sbroad-core/src/frontend/sql/ast/tests.rs
index f60ace70c3..f4f3c9fa2a 100644
--- a/sbroad-core/src/frontend/sql/ast/tests.rs
+++ b/sbroad-core/src/frontend/sql/ast/tests.rs
@@ -222,10 +222,6 @@ fn sql_arithmetic_selection_ast() {
     let node = ast.nodes.get_node(a_id).unwrap();
     assert_eq!(node.rule, Type::Reference);
 
-    let (_, col_id) = iter.next().unwrap();
-    let node = ast.nodes.get_node(col_id).unwrap();
-    assert_eq!(node.rule, Type::Column);
-
     let (_, add_id) = iter.next().unwrap();
     let node = ast.nodes.get_node(add_id).unwrap();
     assert_eq!(node.rule, Type::Add);
@@ -239,10 +235,6 @@ fn sql_arithmetic_selection_ast() {
     let node = ast.nodes.get_node(b_id).unwrap();
     assert_eq!(node.rule, Type::Reference);
 
-    let (_, col_id) = iter.next().unwrap();
-    let node = ast.nodes.get_node(col_id).unwrap();
-    assert_eq!(node.rule, Type::Column);
-
     let (_, addition_id) = iter.next().unwrap();
     let node = ast.nodes.get_node(addition_id).unwrap();
     assert_eq!(node.rule, Type::Addition);
@@ -327,9 +319,72 @@ fn sql_arithmetic_projection_ast() {
     let node = ast.nodes.get_node(a_id).unwrap();
     assert_eq!(node.rule, Type::Reference);
 
-    let (_, col_id) = iter.next().unwrap();
-    let node = ast.nodes.get_node(col_id).unwrap();
-    assert_eq!(node.rule, Type::Column);
+    let (_, add_id) = iter.next().unwrap();
+    let node = ast.nodes.get_node(add_id).unwrap();
+    assert_eq!(node.rule, Type::Add);
+
+    let (_, sel_name_b_id) = iter.next().unwrap();
+    let node = ast.nodes.get_node(sel_name_b_id).unwrap();
+    assert_eq!(node.rule, Type::ColumnName);
+    assert_eq!(node.value, Some("b".to_string()));
+
+    let (_, b_id) = iter.next().unwrap();
+    let node = ast.nodes.get_node(b_id).unwrap();
+    assert_eq!(node.rule, Type::Reference);
+
+    let (_, addition_id) = iter.next().unwrap();
+    let node = ast.nodes.get_node(addition_id).unwrap();
+    assert_eq!(node.rule, Type::Addition);
+
+    let (_, arithm_id) = iter.next().unwrap();
+    let node = ast.nodes.get_node(arithm_id).unwrap();
+    assert_eq!(node.rule, Type::ArithmeticExprAlias);
+
+    let (_, proj_id) = iter.next().unwrap();
+    let node = ast.nodes.get_node(proj_id).unwrap();
+    assert_eq!(node.rule, Type::Projection);
+
+    assert_eq!(None, iter.next());
+}
+
+#[test]
+fn sql_arithmetic_projection_alias_ast() {
+    let ast = AbstractSyntaxTree::new("select a as alias1 + b as alias2 from t").unwrap_err();
+    assert_eq!(
+        format!(
+            r#"rule parsing error:  --> 1:10
+  |
+1 | select a as alias1 + b as alias2 from t
+  |          ^---
+  |
+  = expected Multiply, Divide, Add, or Subtract"#,
+        ),
+        format!("{ast}"),
+    );
+
+    let ast = AbstractSyntaxTree::new("select a + b as sum from t").unwrap();
+
+    let top = ast.top.unwrap();
+    let mut dft_post = PostOrder::with_capacity(|node| ast.nodes.ast_iter(node), 64);
+    let mut iter = dft_post.iter(top);
+
+    let (_, table_id) = iter.next().unwrap();
+    let node = ast.nodes.get_node(table_id).unwrap();
+    assert_eq!(node.rule, Type::Table);
+    assert_eq!(node.value, Some("t".to_string()));
+
+    let (_, scan_id) = iter.next().unwrap();
+    let node = ast.nodes.get_node(scan_id).unwrap();
+    assert_eq!(node.rule, Type::Scan);
+
+    let (_, sel_name_a_id) = iter.next().unwrap();
+    let node = ast.nodes.get_node(sel_name_a_id).unwrap();
+    assert_eq!(node.rule, Type::ColumnName);
+    assert_eq!(node.value, Some("a".to_string()));
+
+    let (_, a_id) = iter.next().unwrap();
+    let node = ast.nodes.get_node(a_id).unwrap();
+    assert_eq!(node.rule, Type::Reference);
 
     let (_, add_id) = iter.next().unwrap();
     let node = ast.nodes.get_node(add_id).unwrap();
@@ -346,12 +401,17 @@ fn sql_arithmetic_projection_ast() {
 
     let (_, col_id) = iter.next().unwrap();
     let node = ast.nodes.get_node(col_id).unwrap();
-    assert_eq!(node.rule, Type::Column);
-
-    let (_, addition_id) = iter.next().unwrap();
-    let node = ast.nodes.get_node(addition_id).unwrap();
     assert_eq!(node.rule, Type::Addition);
 
+    let (_, alias_id) = iter.next().unwrap();
+    let node = ast.nodes.get_node(alias_id).unwrap();
+    assert_eq!(node.rule, Type::AliasName);
+    assert_eq!(node.value, Some("sum".into()));
+
+    let (_, arithm_id) = iter.next().unwrap();
+    let node = ast.nodes.get_node(arithm_id).unwrap();
+    assert_eq!(node.rule, Type::ArithmeticExprAlias);
+
     let (_, proj_id) = iter.next().unwrap();
     let node = ast.nodes.get_node(proj_id).unwrap();
     assert_eq!(node.rule, Type::Projection);
diff --git a/sbroad-core/src/frontend/sql/query.pest b/sbroad-core/src/frontend/sql/query.pest
index 85e6b8d892..981e6ef906 100644
--- a/sbroad-core/src/frontend/sql/query.pest
+++ b/sbroad-core/src/frontend/sql/query.pest
@@ -9,9 +9,10 @@ Query = _{ Except | UnionAll | Select | Values | Insert }
         (((^"inner" ~ ^"join") | ^"join") ~ InnerJoin ~
         ^"on" ~ Condition)? ~ (^"where" ~ Selection)?
     }
-        Projection = { (Asterisk | ArithmeticExpr | Column) ~ ("," ~ (Asterisk | ArithmeticExpr | Column))*? }
+        Projection = { (Asterisk | ArithmeticExprAlias | Column) ~ ("," ~ (Asterisk | ArithmeticExprAlias | Column))*? }
             Column = { Alias | Value }
                 Alias = {Value ~ ^"as" ~ AliasName }
+                ArithmeticExprAlias = { ArithmeticExpr ~ (^"as" ~ AliasName)? }
                 AliasName = @{ Name }
                 Reference = { (ScanName ~ "." ~ ColumnName) | ColumnName }
                     ColumnName = @{ Name }
@@ -33,7 +34,7 @@ Query = _{ Except | UnionAll | Select | Values | Insert }
 
 ArithmeticExpr = _{ Addition | Multiplication | ArithParentheses }
     ArithParentheses = { "(" ~ ArithmeticExpr ~ ")" }
-    ArithELeft = _{ ArithParentheses | Column }
+    ArithELeft = _{ ArithParentheses | Value }
     Multiplication = { ArithELeft ~ (Multiply | Divide) ~ MultiplicationRight }
         MultiplicationRight = _{ Multiplication | ArithELeft }
     Addition = { AdditionLeft ~ (Add | Subtract) ~ AdditionRight }
diff --git a/sbroad-core/tests/artifactory/frontend/sql/arithmetic_projection_ast.yaml b/sbroad-core/tests/artifactory/frontend/sql/arithmetic_projection_ast.yaml
index 9e0ea7e5e0..635c718464 100644
--- a/sbroad-core/tests/artifactory/frontend/sql/arithmetic_projection_ast.yaml
+++ b/sbroad-core/tests/artifactory/frontend/sql/arithmetic_projection_ast.yaml
@@ -18,14 +18,14 @@ nodes:
       rule: Projection
       value: ~
     - children:
-        - 9
-        - 8
-        - 5
-      rule: Addition
+      - 5
+      rule: ArithmeticExprAlias
       value: ~
     - children:
+        - 9
+        - 8
         - 6
-      rule: Column
+      rule: Addition
       value: ~
     - children:
         - 7
@@ -39,10 +39,6 @@ nodes:
       value: "+"
     - children:
         - 10
-      rule: Column
-      value: ~
-    - children:
-        - 11
       rule: Reference
       value: ~
     - children: []
@@ -50,7 +46,7 @@ nodes:
       value: "a"
 top: 3
 map:
-  10:
-    - 1
   6:
+    - 1
+  9:
     - 1
\ No newline at end of file
diff --git a/sbroad-core/tests/artifactory/frontend/sql/arithmetic_selection_ast.yaml b/sbroad-core/tests/artifactory/frontend/sql/arithmetic_selection_ast.yaml
index 07d0be3fa2..7d6d4d6349 100644
--- a/sbroad-core/tests/artifactory/frontend/sql/arithmetic_selection_ast.yaml
+++ b/sbroad-core/tests/artifactory/frontend/sql/arithmetic_selection_ast.yaml
@@ -2,11 +2,11 @@
 nodes:
   arena:
     - children:
-        - 14
+        - 12
       rule: Select
       value: ~
     - children:
-        - 12
+        - 10
         - 2
       rule: Selection
       value: ~
@@ -19,17 +19,13 @@ nodes:
       rule: Unsigned
       value: 1
     - children:
-        - 9
         - 8
+        - 7
         - 5
       rule: Addition
       value: ~
     - children:
         - 6
-      rule: Column
-      value: ~
-    - children:
-        - 7
       rule: Reference
       value: ~
     - children: []
@@ -39,18 +35,14 @@ nodes:
       rule: Add
       value: "+"
     - children:
-        - 10
-      rule: Column
-      value: ~
-    - children:
-        - 11
+        - 9
       rule: Reference
       value: ~
     - children: []
       rule: ColumnName
       value: "a"
     - children:
-        - 13
+        - 11
       rule: Scan
       value: ~
     - children: []
@@ -58,15 +50,15 @@ nodes:
       value: "t"
     - children:
         - 1
-        - 15
+        - 13
       rule: Projection
       value: ~
     - children:
-        - 19
+        - 17
       rule: Column
       value: ~
     - children:
-        - 17
+        - 15
       rule: Reference
       value: ~
     - children: []
@@ -76,15 +68,15 @@ nodes:
       rule: AliasName
       value: "a"
     - children:
+        - 14
         - 16
-        - 18
       rule: Alias
       value: ~
-top: 14
+top: 12
 map:
-  16:
-    - 1
-  6:
-    - 12
-  10:
-    - 12
\ No newline at end of file
+  5:
+    - 10
+  8:
+    - 10
+  14:
+    - 1
\ No newline at end of file
-- 
GitLab