From ff5d3a447570a6b76d0b7e9069f0e65fd1a8caa8 Mon Sep 17 00:00:00 2001 From: Arseniy Volynets <a.volynets@picodata.io> Date: Tue, 13 Jun 2023 09:26:54 +0000 Subject: [PATCH] feat: support count(*) --- .../test/integration/groupby_test.lua | 69 ++++++++++++++++++ sbroad-core/src/backend/sql/ir.rs | 3 + sbroad-core/src/backend/sql/tree.rs | 2 +- sbroad-core/src/executor/ir.rs | 2 +- sbroad-core/src/executor/tests/subtree.rs | 71 +++++++++++++++++++ sbroad-core/src/frontend/sql.rs | 46 ++++++++---- sbroad-core/src/frontend/sql/ast.rs | 3 + sbroad-core/src/frontend/sql/ast/tests.rs | 18 +++++ sbroad-core/src/frontend/sql/ir.rs | 4 +- sbroad-core/src/frontend/sql/ir/tests.rs | 54 ++++++++++++++ sbroad-core/src/frontend/sql/query.pest | 5 +- sbroad-core/src/ir.rs | 4 +- sbroad-core/src/ir/api/parameter.rs | 8 ++- sbroad-core/src/ir/explain.rs | 15 ++++ sbroad-core/src/ir/expression.rs | 2 + sbroad-core/src/ir/helpers.rs | 3 + .../transformation/redistribution/eq_cols.rs | 2 +- .../transformation/redistribution/groupby.rs | 6 ++ sbroad-core/src/ir/tree/expression.rs | 6 +- sbroad-core/src/ir/tree/subtree.rs | 2 +- 20 files changed, 302 insertions(+), 23 deletions(-) diff --git a/sbroad-cartridge/test_app/test/integration/groupby_test.lua b/sbroad-cartridge/test_app/test/integration/groupby_test.lua index a172a16a91..16a8c826f9 100644 --- a/sbroad-cartridge/test_app/test/integration/groupby_test.lua +++ b/sbroad-cartridge/test_app/test/integration/groupby_test.lua @@ -68,6 +68,24 @@ groupby_queries.before_all( t.assert_equals(err, nil) t.assert_equals(r, {row_count = 4}) + r, err = api:call("sbroad.execute", { + [[ + INSERT INTO "null_t" + ("na", "nb", "nc") + VALUES (?,?,?),(?,?,?), + (?,?,?),(?,?,?),(?,?,?) + ]], + { + 1, nil, 1, + 2, nil, nil, + 3, nil, 3, + 4, 1, 2, + 5, nil, 1, + } + }) + + t.assert_equals(err, nil) + t.assert_equals(r, {row_count = 5}) end ) @@ -1375,3 +1393,54 @@ groupby_queries.test_aggr_distinct_without_groupby = function() {3, 4, 2, 5, 6}, }) end + +groupby_queries.test_count_asterisk = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT count(*) from "arithmetic_space" + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "COL_1", type = "decimal" }, + }) + t.assert_items_equals(r.rows, { + {4} + }) + local api = cluster:server("api-1").net_box + + -- check on table with nulls + r, err = api:call("sbroad.execute", { + [[ + SELECT count(*) from "null_t" + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "COL_1", type = "decimal" }, + }) + t.assert_items_equals(r.rows, { + {5} + }) +end + +groupby_queries.test_count_asterisk_with_groupby = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT count(*), "nb" from "null_t" group by "nb" + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "COL_1", type = "decimal" }, + { name = "nb", type = "integer" }, + }) + t.assert_items_equals(r.rows, { + {4, nil}, + {1, 1} + }) +end diff --git a/sbroad-core/src/backend/sql/ir.rs b/sbroad-core/src/backend/sql/ir.rs index d5d7968774..97e478cbf7 100644 --- a/sbroad-core/src/backend/sql/ir.rs +++ b/sbroad-core/src/backend/sql/ir.rs @@ -386,6 +386,9 @@ impl ExecutionPlan { Expression::StableFunction { name, .. } => { sql.push_str(name.as_str()); } + Expression::CountAsterisk => { + sql.push('*'); + } } } } diff --git a/sbroad-core/src/backend/sql/tree.rs b/sbroad-core/src/backend/sql/tree.rs index bfecec0927..2bd6dd79c6 100644 --- a/sbroad-core/src/backend/sql/tree.rs +++ b/sbroad-core/src/backend/sql/tree.rs @@ -868,7 +868,7 @@ impl<'p> SyntaxPlan<'p> { let sn = SyntaxNode::new_parameter(id); Ok(self.nodes.push_syntax_node(sn)) } - Expression::Reference { .. } => { + Expression::Reference { .. } | Expression::CountAsterisk => { let sn = SyntaxNode::new_pointer(id, None, vec![]); Ok(self.nodes.push_syntax_node(sn)) } diff --git a/sbroad-core/src/executor/ir.rs b/sbroad-core/src/executor/ir.rs index 100689c718..f2aa1f104f 100644 --- a/sbroad-core/src/executor/ir.rs +++ b/sbroad-core/src/executor/ir.rs @@ -450,7 +450,7 @@ impl ExecutionPlan { })?; } } - Expression::Constant { .. } => {} + Expression::Constant { .. } | Expression::CountAsterisk => {} }, Node::Parameter { .. } => {} } diff --git a/sbroad-core/src/executor/tests/subtree.rs b/sbroad-core/src/executor/tests/subtree.rs index f8edde6097..88ae99639a 100644 --- a/sbroad-core/src/executor/tests/subtree.rs +++ b/sbroad-core/src/executor/tests/subtree.rs @@ -403,3 +403,74 @@ fn exec_plan_subtree_aggregates_no_groupby() { vec![] )); } + +#[test] +fn exec_plan_subtree_count_asterisk() { + let sql = r#"SELECT count(*) FROM "test_space""#; + let coordinator = RouterRuntimeMock::new(); + + let mut query = Query::new(&coordinator, sql, vec![]).unwrap(); + let motion_id = *query + .exec_plan + .get_ir_plan() + .clone_slices() + .slice(0) + .unwrap() + .position(0) + .unwrap(); + let mut virtual_table = VirtualTable::new(); + virtual_table.add_column(Column { + name: "count_13".into(), + r#type: Type::Integer, + role: ColumnRole::User, + }); + virtual_table.set_alias("").unwrap(); + if let MotionPolicy::Segment(key) = get_motion_policy(query.exec_plan.get_ir_plan(), motion_id) + { + query.reshard_vtable(&mut virtual_table, key).unwrap(); + } + + let mut vtables: HashMap<usize, Rc<VirtualTable>> = HashMap::new(); + vtables.insert(motion_id, Rc::new(virtual_table)); + + let exec_plan = query.get_mut_exec_plan(); + exec_plan.set_vtables(vtables); + let top_id = exec_plan.get_ir_plan().get_top().unwrap(); + let motion_child_id = exec_plan.get_motion_subtree_root(motion_id).unwrap(); + + // Check groupby local stage + let subplan1 = exec_plan.take_subtree(motion_child_id).unwrap(); + let subplan1_top_id = subplan1.get_ir_plan().get_top().unwrap(); + let sp = SyntaxPlan::new(&subplan1, subplan1_top_id, Snapshot::Oldest).unwrap(); + let ordered = OrderedSyntaxNodes::try_from(sp).unwrap(); + let nodes = ordered.to_syntax_data().unwrap(); + let (sql, _) = subplan1.to_sql(&nodes, &Buckets::All, "test").unwrap(); + if let MotionPolicy::Full = exec_plan.get_motion_policy(motion_id).unwrap() { + } else { + panic!("Expected MotionPolicy::Full for local aggregation stage"); + }; + + assert_eq!( + sql, + PatternWithParams::new( + r#"SELECT count (*) as "count_13" FROM "test_space""#.to_string(), + vec![] + ) + ); + + // Check main query + let subplan2 = exec_plan.take_subtree(top_id).unwrap(); + let subplan2_top_id = subplan2.get_ir_plan().get_top().unwrap(); + let sp = SyntaxPlan::new(&subplan2, subplan2_top_id, Snapshot::Oldest).unwrap(); + let ordered = OrderedSyntaxNodes::try_from(sp).unwrap(); + let nodes = ordered.to_syntax_data().unwrap(); + let (sql, _) = subplan2.to_sql(&nodes, &Buckets::All, "test").unwrap(); + assert_eq!( + sql, + PatternWithParams::new( + r#"SELECT sum ("count_13") as "COL_1" FROM (SELECT "count_13" FROM "TMP_test_7")"# + .to_string(), + vec![] + ) + ); +} diff --git a/sbroad-core/src/frontend/sql.rs b/sbroad-core/src/frontend/sql.rs index f30abf7eaf..9ea584f2e7 100644 --- a/sbroad-core/src/frontend/sql.rs +++ b/sbroad-core/src/frontend/sql.rs @@ -700,15 +700,37 @@ impl Ast for AbstractSyntaxTree { Type::Function => { if let Some((first, mut other)) = node.children.split_first() { let mut is_distinct = false; + let function_name = + self.nodes.get_node(*first)?.value.as_ref().ok_or_else(|| { + SbroadError::NotFound(Entity::Name, "of sql function".into()) + })?; if let Some(first_id) = other.first() { - if let Type::Distinct = self.nodes.get_node(*first_id)?.rule { - is_distinct = true; - let Some((_, args)) = other.split_first() else { - return Err(SbroadError::Invalid( - Entity::AST, - Some("function ast has no arguments".into()))) - }; - other = args; + let rule = &self.nodes.get_node(*first_id)?.rule; + match rule { + Type::Distinct => { + is_distinct = true; + let Some((_, args)) = other.split_first() else { + return Err(SbroadError::Invalid( + Entity::AST, + Some("function ast has no arguments".into()))) + }; + other = args; + } + Type::CountAsterisk => { + if other.len() > 1 { + return Err(SbroadError::UnexpectedNumberOfValues( + "function ast with Asterisk has extra children".into(), + )); + } + let normalized_name = function_name.to_lowercase(); + if "count" != normalized_name.as_str() { + return Err(SbroadError::Invalid( + Entity::Query, + Some(format!("\"*\" is allowed only inside \"count\" aggregate function. Got: {normalized_name}")) + )); + } + } + _ => {} } } let mut plan_arg_list = Vec::new(); @@ -716,10 +738,6 @@ impl Ast for AbstractSyntaxTree { let plan_child_id = map.get(*ast_child_id)?; plan_arg_list.push(plan_child_id); } - let function_name = - self.nodes.get_node(*first)?.value.as_ref().ok_or_else(|| { - SbroadError::NotFound(Entity::Name, "of sql function".into()) - })?; if Expression::is_aggregate_name(function_name) { if plan_arg_list.len() != 1 { @@ -1022,6 +1040,10 @@ impl Ast for AbstractSyntaxTree { })?; map.add(0, map.get(*ast_child_id)?); } + Type::CountAsterisk => { + let plan_id = plan.nodes.push(Node::Expression(Expression::CountAsterisk)); + map.add(id, plan_id); + } Type::AliasName | Type::Add | Type::ColumnName diff --git a/sbroad-core/src/frontend/sql/ast.rs b/sbroad-core/src/frontend/sql/ast.rs index f783f9e907..5f9619a63b 100644 --- a/sbroad-core/src/frontend/sql/ast.rs +++ b/sbroad-core/src/frontend/sql/ast.rs @@ -39,6 +39,7 @@ pub enum Type { ColumnName, Concat, Condition, + CountAsterisk, Distinct, Decimal, Divide, @@ -124,6 +125,7 @@ impl Type { Rule::Column => Ok(Type::Column), Rule::ColumnName => Ok(Type::ColumnName), Rule::Concat => Ok(Type::Concat), + Rule::CountAsterisk => Ok(Type::CountAsterisk), Rule::Condition => Ok(Type::Condition), Rule::Decimal => Ok(Type::Decimal), Rule::Divide => Ok(Type::Divide), @@ -215,6 +217,7 @@ impl fmt::Display for Type { Type::Column => "Column".to_string(), Type::ColumnName => "ColumnName".to_string(), Type::Concat => "Concat".to_string(), + Type::CountAsterisk => "CountAsterisk".to_string(), Type::Condition => "Condition".to_string(), Type::Decimal => "Decimal".to_string(), Type::Distinct => "Distinct".to_string(), diff --git a/sbroad-core/src/frontend/sql/ast/tests.rs b/sbroad-core/src/frontend/sql/ast/tests.rs index b8b0255d28..aad6ea0e01 100644 --- a/sbroad-core/src/frontend/sql/ast/tests.rs +++ b/sbroad-core/src/frontend/sql/ast/tests.rs @@ -225,6 +225,24 @@ fn invalid_query() { ); } +#[test] +fn front_sql_invalid_count_asterisk2() { + let input = r#"SELECT count(distinct *) FROM "t" group by "b""#; + + let ast = AbstractSyntaxTree::new(input); + + assert_eq!(true, ast.is_err()); +} + +#[test] +fn front_sql_invalid_count_asterisk3() { + let input = r#"SELECT count(*, a) FROM "t" group by "b""#; + + let ast = AbstractSyntaxTree::new(input); + + assert_eq!(true, ast.is_err()); +} + #[test] fn invalid_condition() { let query = r#"SELECT "identification_number", "product_code" FROM "test_space" WHERE diff --git a/sbroad-core/src/frontend/sql/ir.rs b/sbroad-core/src/frontend/sql/ir.rs index 192102b124..533be17e09 100644 --- a/sbroad-core/src/frontend/sql/ir.rs +++ b/sbroad-core/src/frontend/sql/ir.rs @@ -377,7 +377,9 @@ impl Plan { let next_id = self.nodes.next_id(); let mut expr = self.get_expression_node(id)?.clone(); match expr { - Expression::Constant { .. } | Expression::Reference { .. } => {} + Expression::Constant { .. } + | Expression::Reference { .. } + | Expression::CountAsterisk => {} Expression::Alias { ref mut child, .. } | Expression::Cast { ref mut child, .. } | Expression::Unary { ref mut child, .. } => { diff --git a/sbroad-core/src/frontend/sql/ir/tests.rs b/sbroad-core/src/frontend/sql/ir/tests.rs index f06a48ae18..0313672519 100644 --- a/sbroad-core/src/frontend/sql/ir/tests.rs +++ b/sbroad-core/src/frontend/sql/ir/tests.rs @@ -713,6 +713,60 @@ fn front_sql_aggregates() { assert_eq!(expected_explain, plan.as_explain().unwrap()); } +#[test] +fn front_sql_count_asterisk1() { + let input = r#"SELECT count(*), count(*) FROM "t""#; + + let plan = sql_to_optimized_ir(input, vec![]); + + let expected_explain = String::from( + r#"projection (sum(("count_13")) -> "COL_1", sum(("count_13")) -> "COL_2") + motion [policy: full] + scan + projection (count((*)) -> "count_13") + scan "t" +"#, + ); + + assert_eq!(expected_explain, plan.as_explain().unwrap()); +} + +#[test] +fn front_sql_count_asterisk2() { + let input = r#"SELECT cOuNt(*), "b" FROM "t" group by "b""#; + + let plan = sql_to_optimized_ir(input, vec![]); + + let expected_explain = String::from( + r#"projection (sum(("count_26")) -> "COL_1", "column_12" -> "b") + group by ("column_12") output: ("column_12" -> "column_12", "count_26" -> "count_26") + motion [policy: segment([ref("column_12")])] + scan + projection ("t"."b" -> "column_12", count((*)) -> "count_26") + group by ("t"."b") output: ("t"."a" -> "a", "t"."b" -> "b", "t"."c" -> "c", "t"."d" -> "d", "t"."bucket_id" -> "bucket_id") + scan "t" +"#, + ); + + assert_eq!(expected_explain, plan.as_explain().unwrap()); +} + +#[test] +fn front_sql_invalid_count_asterisk1() { + let input = r#"SELECT sum(*) FROM "t" group by "b""#; + + let metadata = &RouterConfigurationMock::new(); + let ast = AbstractSyntaxTree::new(input).unwrap(); + let plan = ast.resolve_metadata(metadata); + let err = plan.unwrap_err(); + + assert_eq!( + true, + err.to_string() + .contains("\"*\" is allowed only inside \"count\" aggregate function.") + ); +} + #[test] fn front_sql_aggregates_with_subexpressions() { let input = r#"SELECT "b", count("a" * "b" + 1), count(bucket_id("a")) FROM "t" diff --git a/sbroad-core/src/frontend/sql/query.pest b/sbroad-core/src/frontend/sql/query.pest index 4758297127..8d5aa5273a 100644 --- a/sbroad-core/src/frontend/sql/query.pest +++ b/sbroad-core/src/frontend/sql/query.pest @@ -18,7 +18,7 @@ Query = _{ Except | UnionAll | Select | Values | Insert } ColumnName = @{ Name } ScanName = @{ Name } Name = @{ NameString | ("\"" ~ NameString ~ "\"") } - Asterisk = @{ "*" } + Asterisk = { "*" } Selection = { Expr } Scan = { (SubQuery | Table) ~ (^"as" ~ ScanName)? } Table = @{ Name } @@ -92,10 +92,11 @@ Expr = _{ Or | And | Unary | Between | Cmp | Primary | Parentheses } OrLeft = _{ AndRight } OrRight = _{ Or | OrLeft } -Function = { FunctionName ~ ("(" ~ Distinct? ~ FunctionArgs ~ ")") } +Function = { FunctionName ~ "(" ~ (CountAsterisk | (Distinct? ~ FunctionArgs)) ~ ")" } FunctionName = @{ Name } FunctionArgs = _{ (FunctionExpr ~ ("," ~ FunctionExpr)*)? } FunctionExpr = _{ Parentheses | Primary } + CountAsterisk = { "*" } Distinct = { ^"distinct" } diff --git a/sbroad-core/src/ir.rs b/sbroad-core/src/ir.rs index 5545e9a193..3034de2fab 100644 --- a/sbroad-core/src/ir.rs +++ b/sbroad-core/src/ir.rs @@ -585,7 +585,9 @@ impl Plan { } } } - Expression::Constant { .. } | Expression::Reference { .. } => {} + Expression::Constant { .. } + | Expression::Reference { .. } + | Expression::CountAsterisk => {} } return Err(SbroadError::FailedTo( Action::Replace, diff --git a/sbroad-core/src/ir/api/parameter.rs b/sbroad-core/src/ir/api/parameter.rs index bf59affdda..247fe37519 100644 --- a/sbroad-core/src/ir/api/parameter.rs +++ b/sbroad-core/src/ir/api/parameter.rs @@ -153,7 +153,9 @@ impl Plan { } } } - Expression::Constant { .. } | Expression::Reference { .. } => {} + Expression::Constant { .. } + | Expression::Reference { .. } + | Expression::CountAsterisk => {} }, Node::Parameter => {} } @@ -246,7 +248,9 @@ impl Plan { } } } - Expression::Constant { .. } | Expression::Reference { .. } => {} + Expression::Constant { .. } + | Expression::Reference { .. } + | Expression::CountAsterisk => {} }, Node::Parameter => {} } diff --git a/sbroad-core/src/ir/explain.rs b/sbroad-core/src/ir/explain.rs index 741d331627..549ed9f1db 100644 --- a/sbroad-core/src/ir/explain.rs +++ b/sbroad-core/src/ir/explain.rs @@ -84,6 +84,9 @@ impl ColExpr { let cast_expr = ColExpr::Cast(Box::new(expr), to.clone()); stack.push(cast_expr); } + Expression::CountAsterisk => { + stack.push(ColExpr::Column("*".to_string())); + } Expression::Reference { position, .. } => { let mut col_name = String::new(); @@ -470,6 +473,12 @@ impl Row { let col = Col::new(plan, *child)?; row.add_col(RowVal::Column(col)); } + Expression::CountAsterisk => { + return Err(SbroadError::Invalid( + Entity::Plan, + Some("CountAsterisk can't be present among Row children!".into()), + )) + } } } @@ -549,6 +558,12 @@ impl Selection { let row = Row::from_ir_nodes(plan, &[subtree_node_id], ref_map)?; Selection::Row(row) } + Expression::CountAsterisk => { + return Err(SbroadError::Invalid( + Entity::Plan, + Some("CountAsterisk can't be present in Selection filter!".into()), + )) + } }; Ok(result) diff --git a/sbroad-core/src/ir/expression.rs b/sbroad-core/src/ir/expression.rs index e47cd832f9..6a8c533418 100644 --- a/sbroad-core/src/ir/expression.rs +++ b/sbroad-core/src/ir/expression.rs @@ -150,6 +150,8 @@ pub enum Expression { /// Child expression node index in the plan node arena. child: usize, }, + /// Argument of `count` aggregate in `count(*)` expression + CountAsterisk, } #[allow(dead_code)] diff --git a/sbroad-core/src/ir/helpers.rs b/sbroad-core/src/ir/helpers.rs index 74b6d22bf3..801c61dcb6 100644 --- a/sbroad-core/src/ir/helpers.rs +++ b/sbroad-core/src/ir/helpers.rs @@ -85,6 +85,9 @@ impl Plan { Expression::Constant { value } => { writeln!(buf, "Constant [value = {value}]")?; } + Expression::CountAsterisk => { + writeln!(buf, "CountAsterisk")?; + } Expression::Reference { targets, position, diff --git a/sbroad-core/src/ir/transformation/redistribution/eq_cols.rs b/sbroad-core/src/ir/transformation/redistribution/eq_cols.rs index aea4d6bcd6..86bc512f87 100644 --- a/sbroad-core/src/ir/transformation/redistribution/eq_cols.rs +++ b/sbroad-core/src/ir/transformation/redistribution/eq_cols.rs @@ -83,7 +83,7 @@ impl ReferredMap { .add(referred.get_or_none(*right)); referred.insert(node_id, res); } - Expression::Constant { .. } => { + Expression::Constant { .. } | Expression::CountAsterisk => { referred.insert(node_id, Referred::None); } Expression::Reference { diff --git a/sbroad-core/src/ir/transformation/redistribution/groupby.rs b/sbroad-core/src/ir/transformation/redistribution/groupby.rs index 296a1bce84..d7ebfa907e 100644 --- a/sbroad-core/src/ir/transformation/redistribution/groupby.rs +++ b/sbroad-core/src/ir/transformation/redistribution/groupby.rs @@ -198,6 +198,9 @@ impl Plan { op.hash(state); self.hash_for_expr(*child, state, depth - 1); } + Expression::CountAsterisk => { + "CountAsterisk".hash(state); + } } } } @@ -423,6 +426,9 @@ impl Plan { if let Node::Expression(right) = r { match left { Expression::Alias { .. } => {} + Expression::CountAsterisk => { + return Ok(matches!(right, Expression::CountAsterisk)) + } Expression::Bool { left: left_left, op: op_left, diff --git a/sbroad-core/src/ir/tree/expression.rs b/sbroad-core/src/ir/tree/expression.rs index 3dca6f2d90..40b4583824 100644 --- a/sbroad-core/src/ir/tree/expression.rs +++ b/sbroad-core/src/ir/tree/expression.rs @@ -173,7 +173,11 @@ fn expression_next<'nodes>( } } Some( - Node::Expression(Expression::Constant { .. } | Expression::Reference { .. }) + Node::Expression( + Expression::Constant { .. } + | Expression::Reference { .. } + | Expression::CountAsterisk, + ) | Node::Relational(_) | Node::Parameter, ) diff --git a/sbroad-core/src/ir/tree/subtree.rs b/sbroad-core/src/ir/tree/subtree.rs index b1abbdf63f..5c2ae10b94 100644 --- a/sbroad-core/src/ir/tree/subtree.rs +++ b/sbroad-core/src/ir/tree/subtree.rs @@ -229,7 +229,7 @@ fn subtree_next<'plan>( } }; } - Expression::Constant { .. } => None, + Expression::Constant { .. } | Expression::CountAsterisk => None, Expression::Reference { .. } => { let step = *iter.get_child().borrow(); if step == 0 { -- GitLab