From 2122fa2408ffad703773cf4cab9ef1b6ec508073 Mon Sep 17 00:00:00 2001 From: "ms.evilhat" <ms.evilhat@gmail.com> Date: Tue, 31 Jan 2023 10:31:30 +0300 Subject: [PATCH] feat: add arithmetic expressions to selection and joins --- .../test/integration/arithmetic_test.lua | 323 ++++++++++-------- sbroad-core/src/backend/sql/ir.rs | 1 + sbroad-core/src/backend/sql/tree.rs | 32 ++ sbroad-core/src/executor/bucket.rs | 13 +- sbroad-core/src/executor/ir.rs | 5 + sbroad-core/src/frontend/sql.rs | 74 +++- sbroad-core/src/frontend/sql/ast.rs | 22 +- sbroad-core/src/frontend/sql/ast/tests.rs | 2 +- sbroad-core/src/frontend/sql/ir.rs | 39 ++- sbroad-core/src/ir.rs | 17 +- sbroad-core/src/ir/api/parameter.rs | 10 + sbroad-core/src/ir/explain.rs | 35 +- sbroad-core/src/ir/expression.rs | 89 +++++ sbroad-core/src/ir/operator.rs | 42 +++ sbroad-core/src/ir/transformation.rs | 3 +- .../src/ir/transformation/merge_tuples.rs | 39 ++- .../src/ir/transformation/redistribution.rs | 4 + sbroad-core/src/ir/tree/expression.rs | 4 +- sbroad-core/src/ir/tree/subtree.rs | 4 +- 19 files changed, 589 insertions(+), 169 deletions(-) diff --git a/sbroad-cartridge/test_app/test/integration/arithmetic_test.lua b/sbroad-cartridge/test_app/test/integration/arithmetic_test.lua index 26b77b3d7f..0d127d0f50 100644 --- a/sbroad-cartridge/test_app/test/integration/arithmetic_test.lua +++ b/sbroad-cartridge/test_app/test/integration/arithmetic_test.lua @@ -15,14 +15,14 @@ g.before_each( function() local api = cluster:server("api-1").net_box - for i = 1, 5 do + for i = 1, 10 do local r, err = api:call("sbroad.execute", { [[ insert into "arithmetic_space" ("id", "a", "b", "c", "d", "e", "f", "boolean_col", "string_col", "number_col") values (?,?,?,?,?,?,?,?,?,?) ]], - {i, i, i, i, i, i, i, true, "123", decimal.new('4.599999')}, + {i, i, i*2, i*3, i, i, i, true, "123", decimal.new('4.6')}, }) t.assert_equals(err, nil) t.assert_equals(r, {row_count = 1}) @@ -33,7 +33,7 @@ g.before_each( ("id", "a", "b", "c", "d", "e", "f", "boolean_col", "string_col", "number_col") values (?,?,?,?,?,?,?,?,?,?) ]], - {i, i, i, i, i, i, i, true, "123", decimal.new('4.599999')}, + {i, i, i, i, i, i, i, false, "123", decimal.new('4.599999')}, }) t.assert_equals(err, nil) t.assert_equals(r, {row_count = 1}) @@ -66,7 +66,7 @@ g1.before_each( function() local api = cluster:server("api-1").net_box - for k = 1,5 do + for k = 1,10 do local r, err = api:call("sbroad.execute", { [[ insert into "arithmetic_space" @@ -95,7 +95,6 @@ g1.after_all(function() helper.stop_test_cluster() end) - g.test_arithmetic_invalid = function() local api = cluster:server("api-1").net_box @@ -124,16 +123,31 @@ g.test_arithmetic_invalid = function() t.assert_str_contains(tostring(err), "rule parsing error") -- arithemic operation on boolean col - local _, err = api:call("sbroad.execute", { [[select "id" from "arithmetic_space" where "boolean_col" + "boolean_col" > 0]], {} }) - t.assert_str_contains(tostring(err), "Type mismatch: can not convert boolean(TRUE) to integer, decimal, double, datetime or interval") + local _, err = api:call("sbroad.execute", + { [[select "id" from "arithmetic_space" where "boolean_col" + "boolean_col" > 0]], {} } + ) + t.assert_str_contains( + tostring(err), + "Type mismatch: can not convert boolean(TRUE) to integer, decimal, double, datetime or interval" + ) -- arithemic operation on string col - local _, err = api:call("sbroad.execute", { [[select "id" from "arithmetic_space" where "string_col" + "string_col" > 0]], {} }) - t.assert_str_contains(tostring(err), "Type mismatch: can not convert string('123') to integer, decimal, double, datetime or interval") + local _, err = api:call("sbroad.execute", + { [[select "id" from "arithmetic_space" where "string_col" + "string_col" > 0]], {} } + ) + t.assert_str_contains( + tostring(err), + "Type mismatch: can not convert string('123') to integer, decimal, double, datetime or interval" + ) -- arithemic operation on number col - local _, err = api:call("sbroad.execute", { [[select "id" from "arithmetic_space" where "number_col" + "number_col" > 0]], {} }) - t.assert_str_contains(tostring(err), "Type mismatch: can not convert number(4.599999) to integer, decimal, double, datetime or interval") + local _, err = api:call("sbroad.execute", + { [[select "id" from "arithmetic_space" where "number_col" + "number_col" > 0]], {} } +) + t.assert_str_contains( + tostring(err), + "Type mismatch: can not convert number(4.6) to integer, decimal, double, datetime or interval" + ) end g.test_arithmetic_valid = function() @@ -168,52 +182,31 @@ g.test_arithmetic_valid = function() r.metadata, { {name = "id", type = "integer"} } ) - -- t.assert_items_equals( - -- r.rows, - -- { ... } - -- ) + t.assert_items_equals( + r.rows, + res_all.rows + ) -- test several operators with different priority local r, err = api:call("sbroad.execute", { [[ select "id" from "arithmetic_space" where - "id" + "id" * "id" + "id" > 0 - and "id" - "id" * "id" - "id" > 0 - and "id" + "id" / "id" + "id" > 0 - and "id" - "id" / "id" - "id" > 0 + "id" + "id" * "id" + "id" >= 0 + and "id" - "id" * "id" - "id" <= 0 + and "id" + "id" / "id" + "id" >= 0 + and "id" - "id" / "id" - "id" <= 0 ]], {} }) t.assert_equals(err, nil) t.assert_items_equals( r.metadata, { {name = "id", type = "integer"} } ) - -- t.assert_items_equals( - -- r.rows, - -- { ... } - -- ) - - -- test priority - local res, err = api:call("sbroad.execute", { [[ - select "id" from "arithmetic_space" - where cast("a" as decimal) / cast("b" as decimal) * cast("c" as decimal) = (cast("a" as decimal) / cast("b" as decimal)) * cast("c" as decimal) - and cast("a" as decimal) / cast("b" as decimal) * cast("c" as decimal) = cast("a" as decimal) / (cast("b" as decimal) * cast("c" as decimal)) - and cast("a" as decimal) * cast("b" as decimal) / cast("c" as decimal) = (cast("a" as decimal) * cast("b" as decimal)) / cast("c" as decimal) - and cast("a" as decimal) * cast("b" as decimal) / cast("c" as decimal) = cast("a" as decimal) * (cast("b" as decimal) / cast("c" as decimal)) - and "id" + "id" * "id" / "id" = "id" + ("id" * "id" / "id") - and "id" * "id" / "id" + "id" = ("id" * "id" / "id") + "id" - ]], {} }) - t.assert_equals(err, nil) t.assert_items_equals( - res.metadata, - { {name = "id", type = "integer"} } + r.rows, + res_all.rows ) - -- t.assert_items_equals( - -- res.rows, - -- res_all.rows - -- ) end - g.test_arithmetic_with_bool = function() local api = cluster:server("api-1").net_box @@ -224,100 +217,104 @@ g.test_arithmetic_with_bool = function() -- test arithmetic_expr [comparison operator] number local r, err = api:call("sbroad.execute", { [[ select "id" from "arithmetic_space" - where "id" + cast("a" as decimal) >= 0 - and "id" + "c" <= 2 + where "id" + "a" >= 0 + and "id" + "b" <= 12 and "id" + "d" > 0 - and "id" + "e" < 0 - and "id" + cast("a" as decimal) = 2 + and "id" + "e" < 8 + and "id" + "d" = 2 + and "id" + "a" != 3 ]], {} }) t.assert_equals(err, nil) t.assert_items_equals( r.metadata, { {name = "id", type = "integer"} } ) - -- t.assert_items_equals( - -- r.rows, - -- { ... } - -- ) + t.assert_items_equals( + r.rows, + { {1} } + ) -- test arithmetic_expr [comparison operator] arithmetic_expr local r, err = api:call("sbroad.execute", { [[ select "id" from "arithmetic_space" - where "id" + "a" >= "id" * 1 - and "id" + "c" <= "id" * 1 - and "id" + "d" > "id" * 1 - and "id" + "e" < "id" * 1 - and "id" + "a" = "id" * 1 + where "id" + "a" >= "id" * 2 + and "id" + "c" <= "id" * 4 + and "id" + "b" > "id" * "a" + and "id" + "a" < "id" + 3 + and "id" + "a" = 2 + and "id" + "a" != 4 ]], {} }) t.assert_equals(err, nil) t.assert_items_equals( r.metadata, { {name = "id", type = "integer"} } ) - -- t.assert_items_equals( - -- r.rows, - -- { ... } - -- ) + t.assert_items_equals( + r.rows, + { {1} } + ) -- test arithmetic_expr [comparison operator] row local r, err = api:call("sbroad.execute", { [[ select "id" from "arithmetic_space" where "id" + "a" >= "id" - and "id" + "c" <= "id" - and "id" + "d" > "id" - and "id" + "e" < "id" - and "id" + "a" = "id" + and "id" + "b" <= "c" + and "id" + "d" > "e" + and "id" + "f" < "c" + and "id" + "a" = "b" + and "id" + "a" != "c" ]], {} }) t.assert_equals(err, nil) t.assert_items_equals( r.metadata, { {name = "id", type = "integer"} } ) - -- t.assert_items_equals( - -- r.rows, - -- { ... } - -- ) + t.assert_items_equals( + r.rows, + res_all.rows + ) -- test number [comparison operator] arithmetic_expr local r, err = api:call("sbroad.execute", { [[ select "id" from "arithmetic_space" - where 0 >= "id" + "a" - and 2 <= "id" + "c" - and 0 > "id" + "d" - and 0 < "id" + "e" - and 2 = "id" + "a" + where 12 >= "id" + "a" + and 4 <= "id" + "d" + and 12 > "id" + "e" + and 4 < "id" + "f" + and 20 = "id" + "c" + and 9 != "id" + "b" ]], {} }) t.assert_equals(err, nil) t.assert_items_equals( r.metadata, { {name = "id", type = "integer"} } ) - -- t.assert_items_equals( - -- r.rows, - -- { ... } - -- ) + t.assert_items_equals( + r.rows, + { {5} } + ) -- test row [comparison operator] arithmetic_expr local r, err = api:call("sbroad.execute", { [[ select "id" from "arithmetic_space" - where "id" >= "id" + "a" - and "id" <= "id" + "c" - and "id" > "id" + "d" - and "id" < "id" + "e" - and "id" = "id" + "a" + where "c" >= "id" + "b" + and "b" <= "id" + "c" + and "c" > "id" + "a" + and "id" < "a" + "e" + and "b" = "id" + "f" + and "c" != "id" + "a" ]], {} }) t.assert_equals(err, nil) t.assert_items_equals( r.metadata, { {name = "id", type = "integer"} } ) - -- t.assert_items_equals( - -- r.rows, - -- { ... } - -- ) + t.assert_items_equals( + r.rows, + res_all.rows + ) end --- Type mismatch: can not convert scalar(1) to integer, decimal or double g.test_join_simple_arithmetic = function() local api = cluster:server("api-1").net_box @@ -327,21 +324,21 @@ g.test_join_simple_arithmetic = function() FROM (SELECT "id", "a" FROM "arithmetic_space" - WHERE "c" > 0 + WHERE "c" < 0 UNION ALL SELECT "id", "a" FROM "arithmetic_space" - WHERE "c" > 0) AS "t3" + WHERE "c" > 0) AS "t3" INNER JOIN (SELECT "id" as "id1", "b" - FROM "arithmetic_space2" - WHERE "b" < 0 + FROM "arithmetic_space2" + WHERE "b" < 0 UNION ALL SELECT "id" as "id1", "b" - FROM "arithmetic_space2" - WHERE "b" > 0) AS "t8" - ON "t3"."id" + 1 = "t8"."id1" + 2 - WHERE "t3"."id" = 1]], { } }) + FROM "arithmetic_space2" + WHERE "b" > 0) AS "t8" + ON "t3"."id" + "t3"."a" * 2 = "t8"."id1" + "t8"."b" + WHERE "t3"."id" = 2]], { } }) t.assert_equals(err, nil) t.assert_equals( @@ -352,70 +349,96 @@ g.test_join_simple_arithmetic = function() {name = "t8.b", type = "integer"}, } ) - -- t.assert_equals( - -- r.rows, - -- { - -- ... - -- } - -- ) + t.assert_equals( + r.rows, + { { 2, 2, 3 } } + ) + + -- check the same query with params + local r2, err = api:call("sbroad.execute", + { [[ + SELECT "t3"."id", "t3"."a", "t8"."b" + FROM + (SELECT "id", "a" + FROM "arithmetic_space" + WHERE "c" < ? + UNION ALL + SELECT "id", "a" + FROM "arithmetic_space" + WHERE "c" > ?) AS "t3" + INNER JOIN + (SELECT "id" as "id1", "b" + FROM "arithmetic_space2" + WHERE "b" < ? + UNION ALL + SELECT "id" as "id1", "b" + FROM "arithmetic_space2" + WHERE "b" > ?) AS "t8" + ON "t3"."id" + "t3"."a" * ? = "t8"."id1" + "t8"."b" + WHERE "t3"."id" = ?]], { 0, 0, 0, 0, 2, 2} }) + + t.assert_equals(err, nil) + t.assert_equals( + r.metadata, + { + {name = "t3.id", type = "integer"}, + {name = "t3.a", type = "integer"}, + {name = "t8.b", type = "integer"}, + } + ) + t.assert_equals( + r.rows, + r2.rows + ) end g.test_selection_simple_arithmetic = function() local api = cluster:server("api-1").net_box + local res_all, err = api:call("sbroad.execute", { [[select "id" from "arithmetic_space"]], {} }) + t.assert_equals(err, nil) + t.assert_not_equals(res_all.rows, {}) + -- check selection with arithmetic expr and `>` comparison operator - local r, err = api:call("sbroad.execute", { [[select "id" from "arithmetic_space" where "id" + 2 > 0]], {} }) + local r, err = api:call("sbroad.execute", { [[select "id" from "arithmetic_space" where "id" + 1 > 8]], {} }) t.assert_equals(err, nil) t.assert_items_equals( r.metadata, { {name = "id", type = "integer"} } ) - -- t.assert_items_equals( - -- r.rows, - -- { ... } - -- ) + t.assert_items_equals( + r.rows, + { {8}, {9}, {10} } + ) -- check selection with arithmetic expr and `between` comparison operator - local r, err = api:call("sbroad.execute", { [[select "id" from "arithmetic_space" where "id" between "id" - 1 and "id" * 4]], {} }) + local r, err = api:call("sbroad.execute", + { [[select "id" from "arithmetic_space" where "id" between "id" - 1 and "id" * 4]], {} } + ) t.assert_equals(err, nil) t.assert_items_equals( r.metadata, { {name = "id", type = "integer"} } ) - -- t.assert_items_equals( - -- r.rows, - -- { ... } - -- ) - - -- check selection with arithmetic expr and `!=` comparison operator - local r, err = api:call("sbroad.execute", { [[ - select "id" from "arithmetic_space" - where "id" * 2 != "a" * 2 - ]], {} }) - t.assert_equals(err, nil) t.assert_items_equals( - r.metadata, - { {name = "id", type = "integer"} } + r.rows, + res_all.rows ) - -- t.assert_items_equals( - -- r.rows, - -- { ... } - -- ) -- check selection with arithmetic expr and boolean operators local r, err = api:call("sbroad.execute", { [[ select "id" from "arithmetic_space" - where "id" * 2 > 2 or "id" * 2 > 0 and "id" * 2 = 0 + where ("id" > "a" * 2 or "id" * 2 > 10) and "id" - 6 != 0 ]], {} }) t.assert_equals(err, nil) t.assert_items_equals( r.metadata, { {name = "id", type = "integer"} } ) - -- t.assert_items_equals( - -- r.rows, - -- { ... } - -- ) + t.assert_items_equals( + r.rows, + { {7}, {8}, {9}, {10} } + ) end g1.test_associativity = function() @@ -427,45 +450,43 @@ g1.test_associativity = function() -- addition and multiplication are associative local res, err = api:call("sbroad.execute", { [[ - select "id" from "arithmetic_space" - where "a" + ("b" + "c") = ("a" + "b") + "c" + select "id" from "arithmetic_space" where "a" + ("b" + "c") = ("a" + "b") + "c" ]], {} }) t.assert_equals(err, nil) t.assert_equals(res.rows, res_all.rows) local res, err = api:call("sbroad.execute", { [[ - select "id" from "arithmetic_space" - where "a" * ("b" * "c") = ("a" * "b") * "c" + select "id" from "arithmetic_space" where "a" * ("b" * "c") = ("a" * "b") * "c" ]], {} }) t.assert_equals(err, nil) t.assert_equals(res.rows, res_all.rows) -- subtraction is left-associative local res, err = api:call("sbroad.execute", { [[ - select "id" from "arithmetic_space" - where ("a" - "b") - "c" = "a" - "b" - "c" + select "id" from "arithmetic_space" where ("a" - "b") - "c" = "a" - "b" - "c" ]], {} }) t.assert_equals(err, nil) t.assert_equals(res.rows, res_all.rows) local res, err = api:call("sbroad.execute", { [[ - select "id" from "arithmetic_space" - where "a" - ("b" - "c" ) = "a" - "b" - "c" + select "id" from "arithmetic_space" where "a" - ("b" - "c" ) = "a" - "b" - "c" ]], {} }) t.assert_equals(err, nil) t.assert_not_equals(res.rows, res_all.rows) -- division is left-associative local res, err = api:call("sbroad.execute", { [[ - select "id" from "arithmetic_space" - where (cast("a" as decimal) / cast("b" as decimal)) / cast("c" as decimal) = cast("a" as decimal) / cast("b" as decimal) / cast("c" as decimal) + select "id" from "arithmetic_space" where + (cast("a" as decimal) / cast("b" as decimal)) / cast("c" as decimal) = + cast("a" as decimal) / cast("b" as decimal) / cast("c" as decimal) ]], {} }) t.assert_equals(err, nil) t.assert_equals(res.rows, res_all.rows) local res, err = api:call("sbroad.execute", { [[ - select "id" from "arithmetic_space" - where cast("a" as decimal) / (cast("b" as decimal) / cast("c" as decimal)) = (cast("a" as decimal) / cast("b" as decimal)) / cast("c" as decimal) + select "id" from "arithmetic_space" where + cast("a" as decimal) / (cast("b" as decimal) / cast("c" as decimal)) = + (cast("a" as decimal) / cast("b" as decimal)) / cast("c" as decimal) ]], {} }) t.assert_equals(err, nil) t.assert_not_equals(res.rows, res_all.rows) @@ -480,15 +501,13 @@ g1.test_commutativity = function() -- addition and multiplication are commutative local res, err = api:call("sbroad.execute", { [[ - select "id" from "arithmetic_space" - where "a" + "b" = "b" + "a" + select "id" from "arithmetic_space" where "a" + "b" = "b" + "a" ]], {} }) t.assert_equals(err, nil) t.assert_equals(res.rows, res_all.rows) local res, err = api:call("sbroad.execute", { [[ - select "id" from "arithmetic_space" - where "a" * "b" = "b" * "a" + select "id" from "arithmetic_space" where "a" * "b" = "b" * "a" ]], {} }) t.assert_equals(err, nil) t.assert_equals(res.rows, res_all.rows) @@ -505,9 +524,11 @@ g1.test_commutativity = function() t.assert_equals(res.rows, {}) local res, err = api:call("sbroad.execute", { [[ - select "id" from "arithmetic_space" where cast("b" as decimal) / cast("a" as decimal) = cast("a" as decimal) / cast("b" as decimal) + select "id" from "arithmetic_space" + where cast("b" as decimal) / cast("a" as decimal) = cast("a" as decimal) / cast("b" as decimal) except - select "id" from "arithmetic_space" where "a" = "b" or "a" = -1 * "b" + select "id" from "arithmetic_space" + where "a" = "b" or "a" = -1 * "b" ]], {} }) t.assert_equals(err, nil) t.assert_equals(res.rows, {}) @@ -537,15 +558,17 @@ g1.test_distributivity = function() -- division is right-distributive over addition|subtraction local res, err = api:call("sbroad.execute", { [[ - select "id" from "arithmetic_space" - where (cast("a" as decimal) + cast("b" as decimal)) / cast("c" as decimal) = cast("a" as decimal) / cast("c" as decimal) + cast("b" as decimal) / cast("c" as decimal) + select "id" from "arithmetic_space" where + (cast("a" as decimal) + cast("b" as decimal)) / cast("c" as decimal) = + cast("a" as decimal) / cast("c" as decimal) + cast("b" as decimal) / cast("c" as decimal) ]], {} }) t.assert_equals(err, nil) t.assert_equals(res.rows, res_all.rows) local res, err = api:call("sbroad.execute", { [[ - select "id" from "arithmetic_space" - where cast("a" as decimal) / (cast("b" as decimal) + cast("c" as decimal)) = cast("a" as decimal) / cast("b" as decimal) + cast("a" as decimal) / cast("c" as decimal) + select "id" from "arithmetic_space" where + cast("a" as decimal) / (cast("b" as decimal) + cast("c" as decimal)) = + cast("a" as decimal) / cast("b" as decimal) + cast("a" as decimal) / cast("c" as decimal) ]], {} }) t.assert_equals(err, nil) t.assert_not_equals(res.rows, res_all.rows) diff --git a/sbroad-core/src/backend/sql/ir.rs b/sbroad-core/src/backend/sql/ir.rs index 8a05f4a99d..df80b80545 100644 --- a/sbroad-core/src/backend/sql/ir.rs +++ b/sbroad-core/src/backend/sql/ir.rs @@ -311,6 +311,7 @@ impl ExecutionPlan { match expr { Expression::Alias { .. } | Expression::Bool { .. } + | Expression::Arithmetic { .. } | Expression::Cast { .. } | Expression::Concat { .. } | Expression::Row { .. } diff --git a/sbroad-core/src/backend/sql/tree.rs b/sbroad-core/src/backend/sql/tree.rs index 73a2966fc6..0b945fc315 100644 --- a/sbroad-core/src/backend/sql/tree.rs +++ b/sbroad-core/src/backend/sql/tree.rs @@ -822,6 +822,38 @@ impl<'p> SyntaxPlan<'p> { }; Ok(self.nodes.push_syntax_node(sn)) } + Expression::Arithmetic { + left, + right, + op, + with_parentheses, + } => { + let sn = if *with_parentheses { + SyntaxNode::new_pointer( + id, + Some(self.nodes.push_syntax_node(SyntaxNode::new_open())), + vec![ + self.nodes.get_syntax_node_id(*left)?, + self.nodes + .push_syntax_node(SyntaxNode::new_operator(&format!("{op}"))), + self.nodes.get_syntax_node_id(*right)?, + self.nodes.push_syntax_node(SyntaxNode::new_close()), + ], + ) + } else { + SyntaxNode::new_pointer( + id, + Some(self.nodes.get_syntax_node_id(*left)?), + vec![ + self.nodes + .push_syntax_node(SyntaxNode::new_operator(&format!("{op}"))), + self.nodes.get_syntax_node_id(*right)?, + ], + ) + }; + + Ok(self.nodes.push_syntax_node(sn)) + } Expression::Unary { child, op, .. } => { let sn = SyntaxNode::new_pointer( id, diff --git a/sbroad-core/src/executor/bucket.rs b/sbroad-core/src/executor/bucket.rs index 30d07c9f59..85b93a474f 100644 --- a/sbroad-core/src/executor/bucket.rs +++ b/sbroad-core/src/executor/bucket.rs @@ -77,6 +77,7 @@ where let mut buckets: Vec<Buckets> = Vec::new(); let ir_plan = self.exec_plan.get_ir_plan(); let expr = ir_plan.get_expression_node(expr_id)?; + if let Expression::Bool { op: Bool::Eq | Bool::In, left, @@ -87,22 +88,30 @@ where let pairs = vec![(*left, *right), (*right, *left)]; for (left_id, right_id) in pairs { let left_expr = ir_plan.get_expression_node(left_id)?; + + if left_expr.is_arithmetic() { + return Ok(Buckets::new_all()); + } if !left_expr.is_row() { return Err(SbroadError::Invalid( Entity::Expression, Some(format!( - "left side of equality expression is not a row: {left_expr:?}" + "left side of equality expression is not a row or arithmetic: {left_expr:?}" )), )); } + let right_expr = ir_plan.get_expression_node(right_id)?; + if right_expr.is_arithmetic() { + return Ok(Buckets::new_all()); + } let right_columns = if let Expression::Row { list, .. } = right_expr { list.clone() } else { return Err(SbroadError::Invalid( Entity::Expression, Some(format!( - "right side of equality expression is not a row: {right_expr:?}" + "right side of equality expression is not a row or arithmetic: {right_expr:?}" )), )); }; diff --git a/sbroad-core/src/executor/ir.rs b/sbroad-core/src/executor/ir.rs index faefbe7f5d..a5c1dee353 100644 --- a/sbroad-core/src/executor/ir.rs +++ b/sbroad-core/src/executor/ir.rs @@ -386,6 +386,11 @@ impl ExecutionPlan { ref mut right, .. } + | Expression::Arithmetic { + ref mut left, + ref mut right, + .. + } | Expression::Concat { ref mut left, ref mut right, diff --git a/sbroad-core/src/frontend/sql.rs b/sbroad-core/src/frontend/sql.rs index ec22742793..65efcf5001 100644 --- a/sbroad-core/src/frontend/sql.rs +++ b/sbroad-core/src/frontend/sql.rs @@ -15,7 +15,7 @@ use crate::frontend::sql::ir::Translation; use crate::frontend::Ast; use crate::ir::expression::cast::Type as CastType; use crate::ir::expression::Expression; -use crate::ir::operator::{Bool, Unary}; +use crate::ir::operator::{Arithmetic, Bool, Unary}; use crate::ir::tree::traversal::PostOrder; use crate::ir::value::Value; use crate::ir::{Node, Plan}; @@ -130,6 +130,7 @@ impl Ast for AbstractSyntaxTree { let mut col_idx: usize = 0; let mut betweens: Vec<Between> = Vec::new(); + let mut arithmetic_expression_ids: Vec<usize> = Vec::new(); for (_, id) in dft_post.iter(top) { let node = self.nodes.get_node(id)?; @@ -304,7 +305,7 @@ impl Ast for AbstractSyntaxTree { } else { return Err(SbroadError::Invalid( Entity::Plan, - Some(format!("left and right plan nodes do not match the AST scan name: {ast_scan_name:?}")), + Some(format!("left {left_name:?} and right {right_name:?} plan nodes do not match the AST scan name: {ast_scan_name:?}")), )); } } else { @@ -738,6 +739,68 @@ impl Ast for AbstractSyntaxTree { let projection_id = plan.add_proj_internal(plan_child_id, &columns)?; map.add(id, projection_id); } + Type::Multiplication | Type::Addition => { + let plan_left_id: usize; + let plan_right_id: usize; + + let ast_left_id = node.children.first().ok_or_else(|| { + SbroadError::UnexpectedNumberOfValues( + "Multiplication or Addition has no children.".into(), + ) + })?; + + // if left child of current multiplication or addition is `(expr)` then + // we need to get expr that is child of `()` and add it to the plan + // also we will mark this expr to add in the future `()` + let ar_left = self.nodes.get_node(*ast_left_id)?; + if ar_left.rule == Type::ArithParentheses { + let arithmetic_id = ar_left.children.first().ok_or_else(|| { + SbroadError::UnexpectedNumberOfValues( + "ArithParentheses has no children.".into(), + ) + })?; + plan_left_id = plan.as_row(map.get(*arithmetic_id)?, &mut rows)?; + arithmetic_expression_ids.push(plan_left_id); + } else { + plan_left_id = plan.as_row(map.get(*ast_left_id)?, &mut rows)?; + } + + let ast_right_id = node.children.get(2).ok_or_else(|| { + SbroadError::NotFound( + Entity::Node, + "that is right node with index 2 among Multiplication or Addition children".into(), + ) + })?; + + // if left child of current multiplication or addition is `(expr)` then + // we need to get expr that is child of `()` and add it to the plan + // also we will mark this expr to add in the future `()` + let ar_right = self.nodes.get_node(*ast_right_id)?; + if ar_right.rule == Type::ArithParentheses { + let arithmetic_id = ar_right.children.first().ok_or_else(|| { + SbroadError::UnexpectedNumberOfValues( + "ArithParentheses has no children.".into(), + ) + })?; + plan_right_id = plan.as_row(map.get(*arithmetic_id)?, &mut rows)?; + arithmetic_expression_ids.push(plan_right_id); + } else { + plan_right_id = plan.as_row(map.get(*ast_right_id)?, &mut rows)?; + } + + let ast_op_id = node.children.get(1).ok_or_else(|| { + SbroadError::NotFound( + Entity::Node, + "that is center node (operator) with index 1 among Multiplication or Addition children".into(), + ) + })?; + let op_node = self.nodes.get_node(*ast_op_id)?; + + let op = Arithmetic::from_node_type(&op_node.rule)?; + let cond_id = + plan.add_arithmetic_to_plan(plan_left_id, op, plan_right_id, false)?; + map.add(id, cond_id); + } Type::Except => { let ast_left_id = node.children.first().ok_or_else(|| { SbroadError::UnexpectedNumberOfValues("Except has no children.".into()) @@ -848,12 +911,17 @@ impl Ast for AbstractSyntaxTree { map.add(0, map.get(*ast_child_id)?); } Type::AliasName + | Type::Add + | Type::ArithParentheses | Type::ColumnName + | Type::Divide | Type::FunctionName | Type::Length + | Type::Multiply | Type::ScanName | Type::Select | Type::SubQueryName + | Type::Subtract | Type::TargetColumns | Type::TypeAny | Type::TypeBool @@ -883,7 +951,7 @@ impl Ast for AbstractSyntaxTree { plan.set_top(plan_top_id)?; let replaces = plan.replace_sq_with_references()?; plan.fix_betweens(&betweens, &replaces)?; - + plan.fix_arithmetic_parentheses(&arithmetic_expression_ids)?; Ok(plan) } } diff --git a/sbroad-core/src/frontend/sql/ast.rs b/sbroad-core/src/frontend/sql/ast.rs index 0427f9def2..505c075a70 100644 --- a/sbroad-core/src/frontend/sql/ast.rs +++ b/sbroad-core/src/frontend/sql/ast.rs @@ -24,9 +24,13 @@ pub(super) struct ParseTree; /// should be also added in the current list. #[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] pub enum Type { + Add, + Addition, Alias, AliasName, And, + ArithmeticExpr, + ArithParentheses, Asterisk, Between, Cast, @@ -35,6 +39,7 @@ pub enum Type { Concat, Condition, Decimal, + Divide, Double, Eq, Except, @@ -48,11 +53,13 @@ pub enum Type { InnerJoin, Insert, Integer, - IsNull, IsNotNull, + IsNull, Length, Lt, LtEq, + Multiplication, + Multiply, Name, NotEq, NotIn, @@ -71,6 +78,7 @@ pub enum Type { String, SubQuery, SubQueryName, + Subtract, Table, TargetColumns, True, @@ -96,9 +104,13 @@ impl Type { #[allow(dead_code)] fn from_rule(rule: Rule) -> Result<Self, SbroadError> { match rule { + Rule::Add => Ok(Type::Add), + Rule::Addition => Ok(Type::Addition), Rule::Alias => Ok(Type::Alias), Rule::AliasName => Ok(Type::AliasName), Rule::And => Ok(Type::And), + Rule::ArithmeticExpr => Ok(Type::ArithmeticExpr), + Rule::ArithParentheses => Ok(Type::ArithParentheses), Rule::Asterisk => Ok(Type::Asterisk), Rule::Between => Ok(Type::Between), Rule::Cast => Ok(Type::Cast), @@ -107,6 +119,7 @@ impl Type { Rule::Concat => Ok(Type::Concat), Rule::Condition => Ok(Type::Condition), Rule::Decimal => Ok(Type::Decimal), + Rule::Divide => Ok(Type::Divide), Rule::Double => Ok(Type::Double), Rule::Eq => Ok(Type::Eq), Rule::Except => Ok(Type::Except), @@ -118,13 +131,15 @@ impl Type { Rule::GtEq => Ok(Type::GtEq), Rule::In => Ok(Type::In), Rule::InnerJoin => Ok(Type::InnerJoin), - Rule::Integer => Ok(Type::Integer), Rule::Insert => Ok(Type::Insert), - Rule::IsNull => Ok(Type::IsNull), + Rule::Integer => Ok(Type::Integer), Rule::IsNotNull => Ok(Type::IsNotNull), + Rule::IsNull => Ok(Type::IsNull), Rule::Length => Ok(Type::Length), Rule::Lt => Ok(Type::Lt), Rule::LtEq => Ok(Type::LtEq), + Rule::Multiplication => Ok(Type::Multiplication), + Rule::Multiply => Ok(Type::Multiply), Rule::Name => Ok(Type::Name), Rule::NotEq => Ok(Type::NotEq), Rule::NotIn => Ok(Type::NotIn), @@ -142,6 +157,7 @@ impl Type { Rule::Selection => Ok(Type::Selection), Rule::String => Ok(Type::String), Rule::SubQuery => Ok(Type::SubQuery), + Rule::Subtract => Ok(Type::Subtract), Rule::Table => Ok(Type::Table), Rule::TargetColumns => Ok(Type::TargetColumns), Rule::True => Ok(Type::True), diff --git a/sbroad-core/src/frontend/sql/ast/tests.rs b/sbroad-core/src/frontend/sql/ast/tests.rs index 7440a8469b..2d46560c64 100644 --- a/sbroad-core/src/frontend/sql/ast/tests.rs +++ b/sbroad-core/src/frontend/sql/ast/tests.rs @@ -177,7 +177,7 @@ fn invalid_condition() { 2 | "identification_number" = 1 "product_code" = 2 | ^--- | - = expected EOI"#, + = expected EOI, Multiply, Divide, Add, or Subtract"#, ), format!("{ast}"), ) diff --git a/sbroad-core/src/frontend/sql/ir.rs b/sbroad-core/src/frontend/sql/ir.rs index 022e1717f5..6d5069e989 100644 --- a/sbroad-core/src/frontend/sql/ir.rs +++ b/sbroad-core/src/frontend/sql/ir.rs @@ -7,7 +7,7 @@ use crate::errors::{Action, Entity, SbroadError}; use crate::frontend::sql::ast::{ParseNode, Type}; use crate::ir::expression::Expression; use crate::ir::helpers::RepeatableState; -use crate::ir::operator::{Bool, Relational, Unary}; +use crate::ir::operator::{Arithmetic, Bool, Relational, Unary}; use crate::ir::tree::traversal::{PostOrder, EXPR_CAPACITY, REL_CAPACITY}; use crate::ir::value::double::Double; use crate::ir::value::Value; @@ -41,6 +41,26 @@ impl Bool { } } +impl Arithmetic { + /// Creates `Arithmetic` from ast node type. + /// + /// # Errors + /// Returns `SbroadError` when the operator is invalid. + #[allow(dead_code)] + pub(super) fn from_node_type(s: &Type) -> Result<Self, SbroadError> { + match s { + Type::Multiply => Ok(Arithmetic::Multiply), + Type::Add => Ok(Arithmetic::Add), + Type::Divide => Ok(Arithmetic::Divide), + Type::Subtract => Ok(Arithmetic::Subtract), + _ => Err(SbroadError::Invalid( + Entity::Operator, + Some(format!("Arithmetic: {s:?}")), + )), + } + } +} + impl Unary { /// Creates `Unary` from ast node type. /// @@ -320,6 +340,18 @@ impl Plan { Ok(()) } + // It is necessary to keep parentheses from original sql query + // We mark this arithmetic expressions to add in the future `()` as a part of SyntaxData + pub(super) fn fix_arithmetic_parentheses( + &mut self, + arithmetic_expression_ids: &[usize], + ) -> Result<(), SbroadError> { + for id in arithmetic_expression_ids { + self.nodes.set_arithmetic_node_parentheses(*id, true)?; + } + Ok(()) + } + fn clone_expr_subtree(&mut self, top_id: usize) -> Result<usize, SbroadError> { let mut map = HashMap::new(); let mut subtree = @@ -343,6 +375,11 @@ impl Plan { ref mut right, .. } + | Expression::Arithmetic { + ref mut left, + ref mut right, + .. + } | Expression::Concat { ref mut left, ref mut right, diff --git a/sbroad-core/src/ir.rs b/sbroad-core/src/ir.rs index fc44e9081f..b1967797f7 100644 --- a/sbroad-core/src/ir.rs +++ b/sbroad-core/src/ir.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use std::slice::Iter; use expression::Expression; -use operator::Relational; +use operator::{Arithmetic, Relational}; use relation::Table; use crate::errors::{Action, Entity, SbroadError}; @@ -391,6 +391,21 @@ impl Plan { self.nodes.add_bool(left, op, right) } + /// Add rithmetic node to the plan. + /// + /// # Errors + /// Returns `SbroadError` when the condition node can't append'. + pub fn add_arithmetic_to_plan( + &mut self, + left: usize, + op: Arithmetic, + right: usize, + with_parentheses: bool, + ) -> Result<usize, SbroadError> { + self.nodes + .add_arithmetic_node(left, op, right, with_parentheses) + } + /// Add unary operator node to the plan. /// /// # Errors diff --git a/sbroad-core/src/ir/api/parameter.rs b/sbroad-core/src/ir/api/parameter.rs index 67b107b2cd..95b16d4092 100644 --- a/sbroad-core/src/ir/api/parameter.rs +++ b/sbroad-core/src/ir/api/parameter.rs @@ -126,6 +126,11 @@ impl Plan { ref right, .. } + | Expression::Arithmetic { + ref left, + ref right, + .. + } | Expression::Concat { ref left, ref right, @@ -211,6 +216,11 @@ impl Plan { ref mut right, .. } + | Expression::Arithmetic { + ref mut left, + ref mut right, + .. + } | Expression::Concat { ref mut left, ref mut right, diff --git a/sbroad-core/src/ir/explain.rs b/sbroad-core/src/ir/explain.rs index 3dc0bdd2f0..6dda5e42dd 100644 --- a/sbroad-core/src/ir/explain.rs +++ b/sbroad-core/src/ir/explain.rs @@ -13,7 +13,7 @@ use crate::ir::transformation::redistribution::{ }; use crate::ir::Plan; -use super::operator::{Bool, Unary}; +use super::operator::{Arithmetic, Bool, Unary}; use super::tree::traversal::{PostOrder, EXPR_CAPACITY, REL_CAPACITY}; use super::value::Value; @@ -133,7 +133,10 @@ impl ColExpr { let row_expr = ColExpr::Row(row); stack.push(row_expr); } - Expression::Alias { .. } | Expression::Bool { .. } | Expression::Unary { .. } => { + Expression::Alias { .. } + | Expression::Bool { .. } + | Expression::Arithmetic { .. } + | Expression::Unary { .. } => { return Err(SbroadError::Unsupported( Entity::Expression, Some(format!( @@ -351,6 +354,7 @@ impl Row { } } Expression::Bool { .. } + | Expression::Arithmetic { .. } | Expression::Cast { .. } | Expression::Concat { .. } | Expression::StableFunction { .. } @@ -380,13 +384,18 @@ impl Display for Row { } } +#[derive(Debug, Serialize)] +enum BinaryOp { + ArithOp(Arithmetic), + BoolOp(Bool), +} /// Recursive type which describe `WHERE` cause in explain #[derive(Debug, Serialize)] enum Selection { Row(Row), BinaryOp { left: Box<Selection>, - op: Bool, + op: BinaryOp, right: Box<Selection>, }, UnaryOp { @@ -407,7 +416,14 @@ impl Selection { let result = match current_node { Expression::Bool { left, op, right } => Selection::BinaryOp { left: Box::new(Selection::new(plan, *left, ref_map)?), - op: op.clone(), + op: BinaryOp::BoolOp(op.clone()), + right: Box::new(Selection::new(plan, *right, ref_map)?), + }, + Expression::Arithmetic { + left, op, right, .. + } => Selection::BinaryOp { + left: Box::new(Selection::new(plan, *left, ref_map)?), + op: BinaryOp::ArithOp(op.clone()), right: Box::new(Selection::new(plan, *right, ref_map)?), }, Expression::Row { list, .. } => { @@ -433,6 +449,17 @@ impl Selection { } } +impl Display for BinaryOp { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let s = match &self { + BinaryOp::ArithOp(a) => a.to_string(), + BinaryOp::BoolOp(b) => b.to_string(), + }; + + write!(f, "{s}") + } +} + impl Display for Selection { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let s = match &self { diff --git a/sbroad-core/src/ir/expression.rs b/sbroad-core/src/ir/expression.rs index 242cb8f801..ad40406977 100644 --- a/sbroad-core/src/ir/expression.rs +++ b/sbroad-core/src/ir/expression.rs @@ -54,6 +54,23 @@ pub enum Expression { /// Right branch expression node index in the plan node arena. right: usize, }, + /// Binary expression returning row result. + /// + /// Example: `a + b > 42`, `a + b < c + 1`, `1 + 2 != 2 * 2`. + Arithmetic { + /// Left branch expression node index in the plan node arena. + left: usize, + /// Arithmetic operator. + op: operator::Arithmetic, + /// Right branch expression node index in the plan node arena. + right: usize, + /// Has expr parentheses or not. Important to keep this information + /// because we can not add parentheses for all exprs: we parse query + /// from the depth and from left to the right and not all arithmetic + /// operations are associative, example: + /// `(6 - 2) - 1 != 6 - (2 - 1)`, `(8 / 4) / 2 != 8 / (4 / 2))`. + with_parentheses: bool, + }, /// Type cast expression. /// /// Example: `cast(a as text)`. @@ -228,6 +245,10 @@ impl Expression { pub fn is_row(&self) -> bool { matches!(self, Expression::Row { .. }) } + #[must_use] + pub fn is_arithmetic(&self) -> bool { + matches!(self, Expression::Arithmetic { .. }) + } /// Replaces parent in the reference node with the new one. pub fn replace_parent_in_reference(&mut self, from_id: Option<usize>, to_id: Option<usize>) { @@ -293,6 +314,73 @@ impl Nodes { Ok(self.push(Node::Expression(Expression::Bool { left, op, right }))) } + /// Adds arithmetic node. + /// + /// # Errors + /// - when left or right nodes are invalid + pub fn add_arithmetic_node( + &mut self, + left: usize, + op: operator::Arithmetic, + right: usize, + with_parentheses: bool, + ) -> Result<usize, SbroadError> { + self.arena.get(left).ok_or_else(|| { + SbroadError::NotFound( + Entity::Node, + format!( + "(left child of Arithmetic node) from arena with index {}", + left + ), + ) + })?; + self.arena.get(right).ok_or_else(|| { + SbroadError::NotFound( + Entity::Node, + format!( + "(right child of Arithmetic node) from arena with index {}", + right + ), + ) + })?; + Ok(self.push(Node::Expression(Expression::Arithmetic { + left, + op, + right, + with_parentheses, + }))) + } + + /// Set `with_parentheses` for arithmetic node. + /// + /// # Errors + /// - when left or right nodes are invalid + pub fn set_arithmetic_node_parentheses( + &mut self, + node_id: usize, + parentheses_to_set: bool, + ) -> Result<(), SbroadError> { + let arith_node = self.arena.get_mut(node_id).ok_or_else(|| { + SbroadError::NotFound( + Entity::Node, + format!("(Arithmetic node) from arena with index {node_id}"), + ) + })?; + + if let Node::Expression(Expression::Arithmetic { + with_parentheses, .. + }) = arith_node + { + *with_parentheses = parentheses_to_set; + Ok(()) + } else { + Err(SbroadError::Invalid( + Entity::Node, + Some(format!("expected Arithmetic with index {node_id}")), + )) + } + } + /// Adds reference node. pub fn add_ref( &mut self, @@ -881,6 +969,7 @@ impl Plan { let expr = self.get_expression_node(expr_id)?; match expr { Expression::Bool { .. } + | Expression::Arithmetic { .. } | Expression::Unary { .. } | Expression::Constant { value: Value::Boolean(_) | Value::Null, diff --git a/sbroad-core/src/ir/operator.rs b/sbroad-core/src/ir/operator.rs index d5b43d94b6..7523741cf8 100644 --- a/sbroad-core/src/ir/operator.rs +++ b/sbroad-core/src/ir/operator.rs @@ -84,6 +84,48 @@ impl Display for Bool { } } +#[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Hash, Clone)] +#[serde(rename_all = "lowercase")] +pub enum Arithmetic { + /// `*` + Multiply, + /// `/` + Divide, + /// `+` + Add, + /// `-` + Subtract, +} + +impl Arithmetic { + /// Creates `Arithmetic` from the operator string. + /// + /// # Errors + /// Returns `SbroadError` when the operator is invalid. + pub fn from(s: &str) -> Result<Self, SbroadError> { + match s.to_lowercase().as_str() { + "*" => Ok(Arithmetic::Multiply), + "/" => Ok(Arithmetic::Divide), + "+" => Ok(Arithmetic::Add), + "-" => Ok(Arithmetic::Subtract), + _ => Err(SbroadError::Unsupported(Entity::Operator, None)), + } + } +} + +impl Display for Arithmetic { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let op = match &self { + Arithmetic::Multiply => "*", + Arithmetic::Divide => "/", + Arithmetic::Add => "+", + Arithmetic::Subtract => "-", + }; + + write!(f, "{op}") + } +} + /// Unary operator returning Bool expression. #[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Hash, Clone)] #[serde(rename_all = "lowercase")] diff --git a/sbroad-core/src/ir/transformation.rs b/sbroad-core/src/ir/transformation.rs index 875ea634f2..19d8974e55 100644 --- a/sbroad-core/src/ir/transformation.rs +++ b/sbroad-core/src/ir/transformation.rs @@ -166,7 +166,8 @@ impl Plan { *child = *new_id; } } - Expression::Bool { left, right, .. } => { + Expression::Bool { left, right, .. } + | Expression::Arithmetic { left, right, .. } => { if let Some(new_id) = map.get(left) { *left = *new_id; } diff --git a/sbroad-core/src/ir/transformation/merge_tuples.rs b/sbroad-core/src/ir/transformation/merge_tuples.rs index b4371fde04..45476ddf7c 100644 --- a/sbroad-core/src/ir/transformation/merge_tuples.rs +++ b/sbroad-core/src/ir/transformation/merge_tuples.rs @@ -84,6 +84,16 @@ impl Chain { _ => (*left, *right, op.clone()), }; + if let Ok(Expression::Arithmetic { .. }) = plan.get_expression_node(left_id) { + self.other.push(expr_id); + return Ok(()); + } + + if let Ok(Expression::Arithmetic { .. }) = plan.get_expression_node(right_id) { + self.other.push(expr_id); + return Ok(()); + } + // If boolean expression contains a reference to an additional // sub-query, it should be added to the "other" list. let left_sq = plan.get_sub_query_from_row_node(left_id)?; @@ -267,7 +277,7 @@ impl Plan { /// # Errors /// - Failed to build an expression subtree for some chain. /// - The plan is invalid (some bugs). - #[allow(clippy::type_complexity)] + #[allow(clippy::type_complexity, clippy::too_many_lines)] pub fn expr_tree_modify_and_chains( &mut self, expr_id: usize, @@ -336,6 +346,33 @@ impl Plan { } } } + Expression::Arithmetic { left, right, .. } => { + let children = vec![*left, *right]; + for (pos, child) in children.iter().enumerate() { + let chain = chains.get(child); + if let Some(chain) = chain { + let new_child_id = f_to_plan(chain, self)?; + let expr_mut = self.get_mut_expression_node(id)?; + if let Expression::Arithmetic { + left: ref mut left_id, + right: ref mut right_id, + .. + } = expr_mut + { + if pos == 0 { + *left_id = new_child_id; + } else { + *right_id = new_child_id; + } + } else { + return Err(SbroadError::Invalid( + Entity::Expression, + Some(format!("expected Arithmetic expression: {expr_mut:?}")), + )); + } + } + } + } Expression::Row { list, .. } => { let children = list.clone(); for (pos, child) in children.iter().enumerate() { diff --git a/sbroad-core/src/ir/transformation/redistribution.rs b/sbroad-core/src/ir/transformation/redistribution.rs index 5246b61224..e344bda275 100644 --- a/sbroad-core/src/ir/transformation/redistribution.rs +++ b/sbroad-core/src/ir/transformation/redistribution.rs @@ -657,6 +657,7 @@ impl Plan { /// /// # Errors /// - Failed to set row distribution in the join condition tree. + #[allow(clippy::too_many_lines)] fn resolve_join_conflicts( &mut self, rel_id: usize, @@ -719,6 +720,9 @@ impl Plan { let left_expr = self.get_expression_node(bool_op.left)?; let right_expr = self.get_expression_node(bool_op.right)?; new_inner_policy = match (left_expr, right_expr) { + (Expression::Arithmetic { .. }, _) | (_, Expression::Arithmetic { .. }) => { + MotionPolicy::Full + } (Expression::Bool { .. }, Expression::Bool { .. }) => { let left_policy = inner_map .get(&bool_op.left) diff --git a/sbroad-core/src/ir/tree/expression.rs b/sbroad-core/src/ir/tree/expression.rs index e37b87155e..2bdab47715 100644 --- a/sbroad-core/src/ir/tree/expression.rs +++ b/sbroad-core/src/ir/tree/expression.rs @@ -78,7 +78,9 @@ fn expression_next<'nodes>( None } Some(Node::Expression( - Expression::Bool { left, right, .. } | Expression::Concat { left, right }, + Expression::Bool { left, right, .. } + | Expression::Arithmetic { left, right, .. } + | Expression::Concat { left, right }, )) => { let child_step = *iter.get_child().borrow(); if child_step == 0 { diff --git a/sbroad-core/src/ir/tree/subtree.rs b/sbroad-core/src/ir/tree/subtree.rs index 672bf05be0..fbc0370890 100644 --- a/sbroad-core/src/ir/tree/subtree.rs +++ b/sbroad-core/src/ir/tree/subtree.rs @@ -205,7 +205,9 @@ fn subtree_next<'plan>( } None } - Expression::Bool { left, right, .. } | Expression::Concat { left, right } => { + Expression::Bool { left, right, .. } + | Expression::Arithmetic { left, right, .. } + | Expression::Concat { left, right } => { let child_step = *iter.get_child().borrow(); if child_step == 0 { *iter.get_child().borrow_mut() += 1; -- GitLab