From 6370f968a650bc6f4ca152de6084076505d937eb Mon Sep 17 00:00:00 2001 From: Arseniy Volynets <a.volynets@picodata.io> Date: Mon, 19 Jun 2023 10:02:59 +0000 Subject: [PATCH] feat: add avg, min, max, total, group_concat --- .../test/integration/groupby_test.lua | 314 +++++++++++++++++- sbroad-core/src/executor/tests/exec_plan.rs | 51 ++- sbroad-core/src/frontend/sql.rs | 32 +- sbroad-core/src/frontend/sql/ast.rs | 3 + sbroad-core/src/frontend/sql/ast/tests.rs | 16 + sbroad-core/src/frontend/sql/ir/tests.rs | 114 +++++++ sbroad-core/src/frontend/sql/query.pest | 2 +- sbroad-core/src/ir/aggregates.rs | 139 ++++++-- sbroad-core/src/ir/explain.rs | 1 + sbroad-core/src/ir/function.rs | 41 ++- .../transformation/redistribution/groupby.rs | 12 +- .../frontend/sql/single_quoted_str_ast.yaml | 40 +++ 12 files changed, 683 insertions(+), 82 deletions(-) create mode 100644 sbroad-core/tests/artifactory/frontend/sql/single_quoted_str_ast.yaml diff --git a/sbroad-cartridge/test_app/test/integration/groupby_test.lua b/sbroad-cartridge/test_app/test/integration/groupby_test.lua index 16a8c826f9..04c9608d9d 100644 --- a/sbroad-cartridge/test_app/test/integration/groupby_test.lua +++ b/sbroad-cartridge/test_app/test/integration/groupby_test.lua @@ -1274,7 +1274,7 @@ groupby_queries.test_join_single7 = function() local api = cluster:server("api-1").net_box local r, err = api:call("sbroad.execute", { - [[ select o.a, i.d from (select "c" + 3 as c, "d" + 4 as d from "arithmetic_space") as o + [[ select i.a, o.d from (select "c" + 3 as c, "d" + 4 as d from "arithmetic_space") as o inner join (select sum("a") as a, count("b") as b from "arithmetic_space") as i on i.a = cast(o.d as number) ]], {} @@ -1290,12 +1290,13 @@ groupby_queries.test_join_single7 = function() }) end -groupby_queries.test_join_single7 = function() +groupby_queries.test_join_single8 = function() local api = cluster:server("api-1").net_box + local r, err = api:call("sbroad.execute", { - [[ select i.a, o.d from (select "c" as c, "d" as d from "arithmetic_space") as o + [[ select i.a, o.d from (select "c" + 3 as c, "d" + 4 as d from "arithmetic_space") as o inner join (select sum("a") as a, count("b") as b from "arithmetic_space") as i - on i.a = o.d + 4 and i.b = o.c + 3 + on i.a < 10 ]], {} }) t.assert_equals(err, nil) @@ -1304,18 +1305,19 @@ groupby_queries.test_join_single7 = function() { name = "O.D", type = "integer" }, }) t.assert_items_equals(r.rows, { - { 6, 2 }, - { 6, 2 }, + { 6, 6 }, + { 6, 6 }, + { 6, 5 }, + { 6, 5 }, }) end -groupby_queries.test_join_single8 = function() +groupby_queries.test_join_single9 = function() local api = cluster:server("api-1").net_box - local r, err = api:call("sbroad.execute", { - [[ select i.a, o.d from (select "c" + 3 as c, "d" + 4 as d from "arithmetic_space") as o + [[ select i.a, o.d from (select "c" as c, "d" as d from "arithmetic_space" group by "c", "d") as o inner join (select sum("a") as a, count("b") as b from "arithmetic_space") as i - on i.a < 10 + on i.a < o.d + 5 ]], {} }) t.assert_equals(err, nil) @@ -1324,19 +1326,16 @@ groupby_queries.test_join_single8 = function() { name = "O.D", type = "integer" }, }) t.assert_items_equals(r.rows, { - { 6, 6 }, - { 6, 6 }, - { 6, 5 }, - { 6, 5 }, + { 6, 2 }, }) end -groupby_queries.test_join_single9 = function() +groupby_queries.test_join_single10 = function() local api = cluster:server("api-1").net_box local r, err = api:call("sbroad.execute", { - [[ select i.a, o.d from (select "c" as c, "d" as d from "arithmetic_space" group by "c", "d") as o + [[ select i.a, o.d from (select "c" as c, "d" as d from "arithmetic_space") as o inner join (select sum("a") as a, count("b") as b from "arithmetic_space") as i - on i.a < o.d + 5 + on i.a = o.d + 4 and i.b = o.c + 3 ]], {} }) t.assert_equals(err, nil) @@ -1346,8 +1345,10 @@ groupby_queries.test_join_single9 = function() }) t.assert_items_equals(r.rows, { { 6, 2 }, + { 6, 2 }, }) end + groupby_queries.test_aggr_distinct = function() local api = cluster:server("api-1").net_box @@ -1444,3 +1445,282 @@ groupby_queries.test_count_asterisk_with_groupby = function() {1, 1} }) end + +groupby_queries.test_avg = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT avg("c"), avg(distinct "c"), avg("b"), avg(distinct "b") from "arithmetic_space" + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "COL_1", type = "double" }, + { name = "COL_2", type = "double" }, + { name = "COL_3", type = "double" }, + { name = "COL_4", type = "double" }, + }) + t.assert_items_equals(r.rows, { + {1, 1, 2.25, 2} + }) +end + +groupby_queries.test_avg_with_groupby = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT "a", avg("b"), avg(distinct "b") FROM "arithmetic_space" + GROUP BY "a" + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "a", type = "integer" }, + { name = "COL_1", type = "double" }, + { name = "COL_2", type = "double" }, + }) + t.assert_items_equals(r.rows, { + {1, 1.5, 1.5}, + {2, 3, 3}, + }) +end + +groupby_queries.test_group_concat = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT group_concat(cast("c" as string), ' '), group_concat(distinct cast("c" as string)) + from "arithmetic_space" + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "COL_1", type = "string" }, + { name = "COL_2", type = "string" }, + }) + t.assert_items_equals(r.rows, { + {"1 1 1 1", "1"} + }) +end + +groupby_queries.test_group_concat_with_groupby = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT "a", group_concat(cast("e" as string), '|'), group_concat(distinct cast("e" as string)) + FROM "arithmetic_space" + GROUP BY "a" + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "a", type = "integer" }, + { name = "COL_1", type = "string" }, + { name = "COL_2", type = "string" }, + }) + t.assert_items_equals(r.rows, { + {1, "2|2", "2"}, + {2, "2|2", "2"}, + }) +end + +groupby_queries.test_min = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT min("id"), min(distinct "d" / 2) from "arithmetic_space" + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "COL_1", type = "scalar" }, + { name = "COL_2", type = "scalar" }, + }) + t.assert_items_equals(r.rows, { + {1, 0} + }) +end + +groupby_queries.test_min_with_groupby = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT "a", min("b"), min(distinct "b") FROM "arithmetic_space" + GROUP BY "a" + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "a", type = "integer" }, + { name = "COL_1", type = "scalar" }, + { name = "COL_2", type = "scalar" }, + }) + t.assert_items_equals(r.rows, { + {1, 1, 1}, + {2, 3, 3}, + }) +end + +groupby_queries.test_max = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT max("id"), max(distinct "d" / 2) from "arithmetic_space" + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "COL_1", type = "scalar" }, + { name = "COL_2", type = "scalar" }, + }) + t.assert_items_equals(r.rows, { + {4, 1} + }) +end + +groupby_queries.test_max_with_groupby = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT "a", max("b"), max(distinct "b") FROM "arithmetic_space" + GROUP BY "a" + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "a", type = "integer" }, + { name = "COL_1", type = "scalar" }, + { name = "COL_2", type = "scalar" }, + }) + t.assert_items_equals(r.rows, { + {1, 2, 2}, + {2, 3, 3}, + }) +end + +groupby_queries.test_total = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT total("id"), total(distinct "d" / 2) from "arithmetic_space" + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "COL_1", type = "double" }, + { name = "COL_2", type = "double" }, + }) + t.assert_items_equals(r.rows, { + {10, 1} + }) +end + +groupby_queries.test_total_no_rows = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT total("id") from ( + select * from "arithmetic_space" inner join + "null_t" on false + ) + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "COL_1", type = "double" }, + }) + t.assert_items_equals(r.rows, { + { 0 } + }) +end + +groupby_queries.test_total_null_rows = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT total("nb") from ( + select * from "arithmetic_space" left join + "null_t" on false + ) + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "COL_1", type = "double" }, + }) + t.assert_items_equals(r.rows, { + { 0 } + }) +end + +groupby_queries.test_sum_no_rows = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT sum("id") from ( + select * from "arithmetic_space" inner join + "null_t" on false + ) + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "COL_1", type = "decimal" }, + }) + t.assert_items_equals(r.rows, { + { nil } + }) +end + +groupby_queries.test_sum_null_rows = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT sum("nb") from ( + select * from "arithmetic_space" left join + "null_t" on false + ) + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "COL_1", type = "decimal" }, + }) + t.assert_items_equals(r.rows, { + { nil } + }) +end + +groupby_queries.test_total_with_groupby = function() + local api = cluster:server("api-1").net_box + + local r, err = api:call("sbroad.execute", { + [[ + SELECT "a", total("b"), total(distinct "b") FROM "arithmetic_space" + GROUP BY "a" + ]], {} + }) + t.assert_equals(err, nil) + t.assert_equals(r.metadata, { + { name = "a", type = "integer" }, + { name = "COL_1", type = "double" }, + { name = "COL_2", type = "double" }, + }) + t.assert_items_equals(r.rows, { + {1, 3, 3}, + {2, 6, 3}, + }) +end diff --git a/sbroad-core/src/executor/tests/exec_plan.rs b/sbroad-core/src/executor/tests/exec_plan.rs index 240c7006ed..2cd9343e2f 100644 --- a/sbroad-core/src/executor/tests/exec_plan.rs +++ b/sbroad-core/src/executor/tests/exec_plan.rs @@ -235,9 +235,10 @@ fn exec_plan_subtree_two_stage_groupby_test_2() { #[test] fn exec_plan_subtree_aggregates() { let sql = format!( - "{} {} {}", + "{} {} {} {}", r#"SELECT t1."sys_op" || t1."sys_op", t1."sys_op"*2 + count(t1."sysFrom"),"#, - r#"sum(t1."id"), sum(distinct t1."id"*t1."sys_op") / count(distinct "id")"#, + r#"sum(t1."id"), sum(distinct t1."id"*t1."sys_op") / count(distinct "id"),"#, + r#"group_concat(t1."id", 'o'), avg(t1."id"), total(t1."id"), min(t1."id"), max(t1."id")"#, r#"FROM "test_space" as t1 group by t1."sys_op""# ); let coordinator = RouterRuntimeMock::new(); @@ -277,6 +278,31 @@ fn exec_plan_subtree_aggregates() { r#type: Type::Integer, role: ColumnRole::User, }); + virtual_table.add_column(Column { + name: "group_concat_58".into(), + r#type: Type::String, + role: ColumnRole::User, + }); + virtual_table.add_column(Column { + name: "count_61".into(), + r#type: Type::Integer, + role: ColumnRole::User, + }); + virtual_table.add_column(Column { + name: "total_64".into(), + r#type: Type::Integer, + role: ColumnRole::User, + }); + virtual_table.add_column(Column { + name: "min_67".into(), + r#type: Type::Integer, + role: ColumnRole::User, + }); + virtual_table.add_column(Column { + name: "max_70".into(), + r#type: Type::Integer, + role: ColumnRole::User, + }); if let MotionPolicy::Segment(key) = get_motion_policy(query.exec_plan.get_ir_plan(), motion_id) { virtual_table.reshard(key, &query.coordinator).unwrap(); @@ -306,13 +332,15 @@ fn exec_plan_subtree_aggregates() { sql, PatternWithParams::new( format!( - "{} {} {} {}", + "{} {} {} {} {} {}", r#"SELECT "T1"."sys_op" as "column_12", ("T1"."id") * ("T1"."sys_op") as "column_48","#, - r#""T1"."id" as "column_50", sum ("T1"."id") as "sum_42", count ("T1"."sysFrom") as"#, - r#""count_37" FROM "test_space" as "T1""#, + r#""T1"."id" as "column_50", group_concat ("T1"."id", ?) as "group_concat_58","#, + r#"sum ("T1"."id") as "sum_42", count ("T1"."id") as "count_61", total ("T1"."id") as "total_64","#, + r#"min ("T1"."id") as "min_67", max ("T1"."id") as "max_70", count ("T1"."sysFrom") as "count_37""#, + r#"FROM "test_space" as "T1""#, r#"GROUP BY "T1"."sys_op", ("T1"."id") * ("T1"."sys_op"), "T1"."id""#, ), - vec![] + vec![Value::from("o")] ) ); @@ -323,14 +351,17 @@ fn exec_plan_subtree_aggregates() { sql, PatternWithParams::new( format!( - "{} {} {} {} {}", + "{} {} {} {} {} {} {} {}", r#"SELECT ("column_12") || ("column_12") as "COL_1","#, r#"("column_12") * (?) + (sum ("count_37")) as "COL_2", sum ("sum_42") as "COL_3","#, - r#"(sum (DISTINCT "column_48")) / (count (DISTINCT "column_50")) as "COL_4""#, - r#"FROM (SELECT "sys_op","sum_42","count_37","sum_49","count_51" FROM "TMP_test_39")"#, + r#"(sum (DISTINCT "column_48")) / (count (DISTINCT "column_50")) as "COL_4","#, + r#"group_concat ("group_concat_58", ?) as "COL_5","#, + r#"(sum (CAST ("sum_42" as double)) / sum (CAST ("count_61" as double))) as "COL_6","#, + r#"total ("total_64") as "COL_7", min ("min_67") as "COL_8", max ("max_70") as "COL_9""#, + r#"FROM (SELECT "sys_op","sum_42","count_37","sum_49","count_51","group_concat_58","count_61","total_64","min_67","max_70" FROM "TMP_test_70")"#, r#"GROUP BY "column_12""# ), - vec![Value::Unsigned(2)] + vec![Value::Unsigned(2), Value::from("o")] ) ); } diff --git a/sbroad-core/src/frontend/sql.rs b/sbroad-core/src/frontend/sql.rs index b7e094ec30..36b5e37440 100644 --- a/sbroad-core/src/frontend/sql.rs +++ b/sbroad-core/src/frontend/sql.rs @@ -21,6 +21,7 @@ use crate::ir::value::Value; use crate::ir::{Node, Plan}; use crate::otm::child_span; +use crate::ir::aggregates::AggregateKind; use sbroad_proc::otm_child_span; /// Helper structure to fix the double linking @@ -738,28 +739,13 @@ impl Ast for AbstractSyntaxTree { plan_arg_list.push(plan_child_id); } - if Expression::is_aggregate_name(function_name) { - if plan_arg_list.len() != 1 { - return Err(SbroadError::Invalid( - Entity::Query, - Some(format!( - "Expected one argument for aggregate: {function_name}." - )), - )); - } - let argument = *plan_arg_list.first().ok_or_else(|| { - SbroadError::Invalid( - Entity::Query, - Some(format!( - "aggregate function {function_name} has no arguments!" - )), - ) - })?; + if let Some(kind) = AggregateKind::new(function_name) { let plan_id = plan.add_aggregate_function( &function_name.to_string(), - argument, + kind, + plan_arg_list.clone(), is_distinct, - ); + )?; map.add(id, plan_id); continue; } else if is_distinct { @@ -1039,6 +1025,14 @@ impl Ast for AbstractSyntaxTree { })?; map.add(0, map.get(*ast_child_id)?); } + Type::SingleQuotedString => { + let ast_child_id = node.children.first().ok_or_else(|| { + SbroadError::UnexpectedNumberOfValues( + "SingleQuotedString has no children.".into(), + ) + })?; + map.add(id, map.get(*ast_child_id)?); + } Type::CountAsterisk => { let plan_id = plan.nodes.push(Node::Expression(Expression::CountAsterisk)); map.add(id, plan_id); diff --git a/sbroad-core/src/frontend/sql/ast.rs b/sbroad-core/src/frontend/sql/ast.rs index 5f9619a63b..c559a831b2 100644 --- a/sbroad-core/src/frontend/sql/ast.rs +++ b/sbroad-core/src/frontend/sql/ast.rs @@ -85,6 +85,7 @@ pub enum Type { Select, Selection, String, + SingleQuotedString, SubQuery, Subtract, Table, @@ -172,6 +173,7 @@ impl Type { Rule::Select => Ok(Type::Select), Rule::Selection => Ok(Type::Selection), Rule::String => Ok(Type::String), + Rule::SingleQuotedString => Ok(Type::SingleQuotedString), Rule::SubQuery => Ok(Type::SubQuery), Rule::Subtract => Ok(Type::Subtract), Rule::Table => Ok(Type::Table), @@ -262,6 +264,7 @@ impl fmt::Display for Type { Type::Select => "Select".to_string(), Type::Selection => "Selection".to_string(), Type::String => "String".to_string(), + Type::SingleQuotedString => "SingleQuotedString".to_string(), Type::SubQuery => "SubQuery".to_string(), Type::Subtract => "Subtract".to_string(), Type::Table => "Table".to_string(), diff --git a/sbroad-core/src/frontend/sql/ast/tests.rs b/sbroad-core/src/frontend/sql/ast/tests.rs index aad6ea0e01..ced600ae83 100644 --- a/sbroad-core/src/frontend/sql/ast/tests.rs +++ b/sbroad-core/src/frontend/sql/ast/tests.rs @@ -145,6 +145,22 @@ fn transform_select_6() { assert_eq!(expected, ast); } +#[test] +fn single_quoted_str_ast() { + let query = r#"SELECT ' ' FROM "test_space""#; + let ast = AbstractSyntaxTree::new(query).unwrap(); + + let path = Path::new("") + .join("tests") + .join("artifactory") + .join("frontend") + .join("sql") + .join("single_quoted_str_ast.yaml"); + let s = fs::read_to_string(path).unwrap(); + let expected: AbstractSyntaxTree = AbstractSyntaxTree::from_yaml(&s).unwrap(); + assert_eq!(expected, ast); +} + #[test] fn traversal() { let query = r#"select a from t where a = 1"#; diff --git a/sbroad-core/src/frontend/sql/ir/tests.rs b/sbroad-core/src/frontend/sql/ir/tests.rs index b9b567892d..2606e50f84 100644 --- a/sbroad-core/src/frontend/sql/ir/tests.rs +++ b/sbroad-core/src/frontend/sql/ir/tests.rs @@ -697,6 +697,120 @@ fn front_sql_aggregates() { assert_eq!(expected_explain, plan.as_explain().unwrap()); } +#[test] +fn front_sql_avg_aggregate() { + let input = r#"SELECT avg("b"), avg(distinct "b"), avg("b") * avg("b") FROM "t""#; + + let plan = sql_to_optimized_ir(input, vec![]); + + let expected_explain = String::from( + r#"projection ((sum(("sum_13"::double)) / sum(("count_13"::double))) -> "COL_1", avg(distinct ("column_15"::double)) -> "COL_2", ((sum(("sum_13"::double)) / sum(("count_13"::double)))) * ((sum(("sum_13"::double)) / sum(("count_13"::double)))) -> "COL_3") + motion [policy: full] + scan + projection ("t"."b" -> "column_15", sum(("t"."b")) -> "sum_13", count(("t"."b")) -> "count_13") + group by ("t"."b") output: ("t"."a" -> "a", "t"."b" -> "b", "t"."c" -> "c", "t"."d" -> "d", "t"."bucket_id" -> "bucket_id") + scan "t" +"#, + ); + + assert_eq!(expected_explain, plan.as_explain().unwrap()); +} + +#[test] +fn front_sql_total_aggregate() { + let input = r#"SELECT total("b"), total(distinct "b") FROM "t""#; + + let plan = sql_to_optimized_ir(input, vec![]); + + let expected_explain = String::from( + r#"projection (total(("total_13")) -> "COL_1", total(distinct ("column_15")) -> "COL_2") + motion [policy: full] + scan + projection ("t"."b" -> "column_15", total(("t"."b")) -> "total_13") + group by ("t"."b") output: ("t"."a" -> "a", "t"."b" -> "b", "t"."c" -> "c", "t"."d" -> "d", "t"."bucket_id" -> "bucket_id") + scan "t" +"#, + ); + + assert_eq!(expected_explain, plan.as_explain().unwrap()); +} + +#[test] +fn front_sql_min_aggregate() { + let input = r#"SELECT min("b"), min(distinct "b") FROM "t""#; + + let plan = sql_to_optimized_ir(input, vec![]); + + let expected_explain = String::from( + r#"projection (min(("min_13")) -> "COL_1", min(distinct ("column_15")) -> "COL_2") + motion [policy: full] + scan + projection ("t"."b" -> "column_15", min(("t"."b")) -> "min_13") + group by ("t"."b") output: ("t"."a" -> "a", "t"."b" -> "b", "t"."c" -> "c", "t"."d" -> "d", "t"."bucket_id" -> "bucket_id") + scan "t" +"#, + ); + + assert_eq!(expected_explain, plan.as_explain().unwrap()); +} + +#[test] +fn front_sql_max_aggregate() { + let input = r#"SELECT max("b"), max(distinct "b") FROM "t""#; + + let plan = sql_to_optimized_ir(input, vec![]); + + let expected_explain = String::from( + r#"projection (max(("max_13")) -> "COL_1", max(distinct ("column_15")) -> "COL_2") + motion [policy: full] + scan + projection ("t"."b" -> "column_15", max(("t"."b")) -> "max_13") + group by ("t"."b") output: ("t"."a" -> "a", "t"."b" -> "b", "t"."c" -> "c", "t"."d" -> "d", "t"."bucket_id" -> "bucket_id") + scan "t" +"#, + ); + + assert_eq!(expected_explain, plan.as_explain().unwrap()); +} + +#[test] +fn front_sql_group_concat_aggregate() { + let input = r#"SELECT group_concat("b"), group_concat(distinct "b") FROM "t""#; + + let plan = sql_to_optimized_ir(input, vec![]); + + let expected_explain = String::from( + r#"projection (group_concat(("group_concat_13")) -> "COL_1", group_concat(distinct ("column_15")) -> "COL_2") + motion [policy: full] + scan + projection ("t"."b" -> "column_15", group_concat(("t"."b")) -> "group_concat_13") + group by ("t"."b") output: ("t"."a" -> "a", "t"."b" -> "b", "t"."c" -> "c", "t"."d" -> "d", "t"."bucket_id" -> "bucket_id") + scan "t" +"#, + ); + + assert_eq!(expected_explain, plan.as_explain().unwrap()); +} + +#[test] +fn front_sql_group_concat_aggregate2() { + let input = r#"SELECT group_concat("b", ' '), group_concat(distinct "b") FROM "t""#; + + let plan = sql_to_optimized_ir(input, vec![]); + + let expected_explain = String::from( + r#"projection (group_concat(("group_concat_14", ' ')) -> "COL_1", group_concat(distinct ("column_16")) -> "COL_2") + motion [policy: full] + scan + projection ("t"."b" -> "column_16", group_concat(("t"."b", ' ')) -> "group_concat_14") + group by ("t"."b") output: ("t"."a" -> "a", "t"."b" -> "b", "t"."c" -> "c", "t"."d" -> "d", "t"."bucket_id" -> "bucket_id") + scan "t" +"#, + ); + + assert_eq!(expected_explain, plan.as_explain().unwrap()); +} + #[test] fn front_sql_count_asterisk1() { let input = r#"SELECT count(*), count(*) FROM "t""#; diff --git a/sbroad-core/src/frontend/sql/query.pest b/sbroad-core/src/frontend/sql/query.pest index ab6015aac7..976ad55961 100644 --- a/sbroad-core/src/frontend/sql/query.pest +++ b/sbroad-core/src/frontend/sql/query.pest @@ -147,7 +147,7 @@ Value = _{ Parameter | Row | True | False | Null | Decimal | Double | Unsigned | Double = @{ Integer ~ ("." ~ ASCII_DIGIT*)? ~ (^"e" ~ Integer) } Integer = @{ ("+" | "-")? ~ ASCII_DIGIT+ } Unsigned = @{ ASCII_DIGIT+ } - SingleQuotedString = _{ "'" ~ String ~ "'" } + SingleQuotedString = ${ "'" ~ String ~ "'" } Row = { ("(" ~ Value ~ ("," ~ Value)* ~ ")") | (^"row" ~ "(" ~ Value ~ ("," ~ Value)* ~ ")") diff --git a/sbroad-core/src/ir/aggregates.rs b/sbroad-core/src/ir/aggregates.rs index 80a599124b..5fb60d5004 100644 --- a/sbroad-core/src/ir/aggregates.rs +++ b/sbroad-core/src/ir/aggregates.rs @@ -1,6 +1,7 @@ use crate::errors::{Entity, SbroadError}; +use crate::ir::expression::cast::Type; use crate::ir::expression::Expression; -use crate::ir::expression::Expression::StableFunction; +use crate::ir::operator::Arithmetic; use crate::ir::{Node, Plan}; use std::collections::HashMap; use std::fmt::{Display, Formatter}; @@ -13,6 +14,11 @@ use std::rc::Rc; pub enum AggregateKind { COUNT, SUM, + AVG, + TOTAL, + MIN, + MAX, + GRCONCAT, } impl Display for AggregateKind { @@ -20,6 +26,11 @@ impl Display for AggregateKind { let name = match self { AggregateKind::COUNT => "count", AggregateKind::SUM => "sum", + AggregateKind::AVG => "avg", + AggregateKind::TOTAL => "total", + AggregateKind::MIN => "min", + AggregateKind::MAX => "max", + AggregateKind::GRCONCAT => "group_concat", }; write!(f, "{name}") } @@ -32,6 +43,11 @@ impl AggregateKind { match normalized.as_str() { "count" => Some(AggregateKind::COUNT), "sum" => Some(AggregateKind::SUM), + "avg" => Some(AggregateKind::AVG), + "total" => Some(AggregateKind::TOTAL), + "min" => Some(AggregateKind::MIN), + "max" => Some(AggregateKind::MAX), + "group_concat" => Some(AggregateKind::GRCONCAT), _ => None, } } @@ -41,6 +57,11 @@ impl AggregateKind { match self { AggregateKind::COUNT => vec![AggregateKind::COUNT], AggregateKind::SUM => vec![AggregateKind::SUM], + AggregateKind::AVG => vec![AggregateKind::SUM, AggregateKind::COUNT], + AggregateKind::TOTAL => vec![AggregateKind::TOTAL], + AggregateKind::MIN => vec![AggregateKind::MIN], + AggregateKind::MAX => vec![AggregateKind::MAX], + AggregateKind::GRCONCAT => vec![AggregateKind::GRCONCAT], } } @@ -53,8 +74,12 @@ impl AggregateKind { local_aggregate: &AggregateKind, ) -> Result<AggregateKind, SbroadError> { let res = match (self, local_aggregate) { - (AggregateKind::COUNT, AggregateKind::COUNT) - | (AggregateKind::SUM, AggregateKind::SUM) => AggregateKind::SUM, + (AggregateKind::COUNT | AggregateKind::AVG, AggregateKind::COUNT) + | (AggregateKind::SUM | AggregateKind::AVG, AggregateKind::SUM) => AggregateKind::SUM, + (AggregateKind::TOTAL, AggregateKind::TOTAL) => AggregateKind::TOTAL, + (AggregateKind::MIN, AggregateKind::MIN) => AggregateKind::MIN, + (AggregateKind::MAX, AggregateKind::MAX) => AggregateKind::MAX, + (AggregateKind::GRCONCAT, AggregateKind::GRCONCAT) => AggregateKind::GRCONCAT, (_, _) => { return Err(SbroadError::Invalid( Entity::Aggregate, @@ -130,12 +155,31 @@ impl SimpleAggregate { } impl SimpleAggregate { - /// Create columns with final aggregates in final `Projection` + /// Create final aggregate expression and return its id + /// + /// # Examples + /// Suppose this aggregate is non-distinct `AVG` and at local stage + /// `SUM` and `COUNT` were computed with corresponding local + /// aliases `sum_1` and `count_1`, then this function + /// will create the following expression: + /// + /// ```txt + /// sum(sum_1) / sum(count_1) + /// ``` + /// + /// If we had `AVG(distinct a)` in user query, then at local stage + /// we must have used `a` as `group by` expression and assign it + /// a local alias. Let's say local alias is `column_1`, then this + /// function will create the following expression: + /// + /// ```txt + /// avg(column_1) + /// ``` /// /// # Errors /// - Invalid aggregate /// - Could not find local alias position in child output - /// + #[allow(clippy::too_many_lines)] pub fn create_column_for_final_projection( &self, parent: usize, @@ -143,15 +187,17 @@ impl SimpleAggregate { alias_to_pos: &HashMap<String, usize>, is_distinct: bool, ) -> Result<usize, SbroadError> { - let mut final_aggregates: Vec<usize> = vec![]; + // map local AggregateKind to finalised expression of that aggregate + let mut final_aggregates: HashMap<AggregateKind, usize> = HashMap::new(); let mut create_final_aggr = |local_alias: &str, + local_kind: AggregateKind, final_func: &str| -> Result<(), SbroadError> { let Some(position) = alias_to_pos.get(local_alias) else { let parent_node = plan.get_relation_node(parent)?; return Err(SbroadError::Invalid( Entity::Node, - Some(format!("could find aggregate column in final {parent_node:?} child by local alias: {local_alias}. Aliases: {alias_to_pos:?}")))) + Some(format!("could not find aggregate column in final {parent_node:?} child by local alias: {local_alias}. Aliases: {alias_to_pos:?}")))) }; let ref_node = Expression::Reference { parent: Some(parent), @@ -160,13 +206,42 @@ impl SimpleAggregate { position: *position, }; let ref_id = plan.nodes.push(Node::Expression(ref_node)); - let final_aggr = StableFunction { + let children = match self.kind { + AggregateKind::AVG => vec![plan.add_cast(ref_id, Type::Double)?], + AggregateKind::GRCONCAT => { + if let Expression::StableFunction { children, .. } = + plan.get_expression_node(self.fun_id)? + { + if children.len() > 1 { + let second_arg = { + let a = *children + .get(1) + .ok_or(SbroadError::Invalid(Entity::Aggregate, None))?; + plan.clone_expr_subtree(a)? + }; + vec![ref_id, second_arg] + } else { + vec![ref_id] + } + } else { + return Err(SbroadError::Invalid( + Entity::Aggregate, + Some(format!( + "fun_id ({}) points to other expression node", + self.fun_id + )), + )); + } + } + _ => vec![ref_id], + }; + let final_aggr = Expression::StableFunction { name: final_func.to_string(), - children: vec![ref_id], + children, is_distinct, }; let aggr_id = plan.nodes.push(Node::Expression(final_aggr)); - final_aggregates.push(aggr_id); + final_aggregates.insert(local_kind, aggr_id); Ok(()) }; if is_distinct { @@ -179,7 +254,7 @@ impl SimpleAggregate { ) })?; let final_aggregate_name = self.kind.to_string(); - create_final_aggr(local_alias, final_aggregate_name.as_str())?; + create_final_aggr(local_alias, self.kind, final_aggregate_name.as_str())?; } else { for aggr_kind in self.kind.get_local_aggregates_kinds() { let local_alias = self.lagg_alias.get(&aggr_kind).ok_or_else(|| { @@ -192,20 +267,42 @@ impl SimpleAggregate { })?; let final_aggregate_name = self.kind.get_final_aggregate_kind(&aggr_kind)?.to_string(); - create_final_aggr(local_alias, final_aggregate_name.as_str())?; + create_final_aggr(local_alias, aggr_kind, final_aggregate_name.as_str())?; } } let final_expr_id = if final_aggregates.len() == 1 { - *final_aggregates.first().ok_or_else(|| { - SbroadError::UnexpectedNumberOfValues("final_aggregates is empty".into()) - })? + *final_aggregates + .values() + .into_iter() + .next() + .ok_or_else(|| { + SbroadError::UnexpectedNumberOfValues("final_aggregates is empty".into()) + })? } else { - return Err(SbroadError::Unsupported( - Entity::Aggregate, - Some(format!( - "aggregate with multiple final aggregates: {self:?}" - )), - )); + match self.kind { + AggregateKind::AVG => { + let sum_aggr = *final_aggregates.get(&AggregateKind::SUM).ok_or_else(|| { + SbroadError::UnexpectedNumberOfValues( + "final_aggregates: missing final aggregate for SUM".into(), + ) + })?; + let count_aggr = + *final_aggregates.get(&AggregateKind::COUNT).ok_or_else(|| { + SbroadError::UnexpectedNumberOfValues( + "final_aggregates: missing final aggregate for COUNT".into(), + ) + })?; + plan.add_arithmetic_to_plan(sum_aggr, Arithmetic::Divide, count_aggr, true)? + } + _ => { + return Err(SbroadError::Unsupported( + Entity::Aggregate, + Some(format!( + "aggregate with multiple final aggregates: {self:?}" + )), + )) + } + } }; Ok(final_expr_id) } diff --git a/sbroad-core/src/ir/explain.rs b/sbroad-core/src/ir/explain.rs index 27f81b9937..9c8da5f2d1 100644 --- a/sbroad-core/src/ir/explain.rs +++ b/sbroad-core/src/ir/explain.rs @@ -137,6 +137,7 @@ impl ColExpr { args.push(arg); len -= 1; } + args.reverse(); let args_expr = ColExpr::Row(args); let func_expr = ColExpr::StableFunction(name.clone(), Box::new(args_expr), *is_distinct); diff --git a/sbroad-core/src/ir/function.rs b/sbroad-core/src/ir/function.rs index 22e50e6851..9e5c563fcf 100644 --- a/sbroad-core/src/ir/function.rs +++ b/sbroad-core/src/ir/function.rs @@ -1,4 +1,5 @@ use crate::errors::{Entity, SbroadError}; +use crate::ir::aggregates::AggregateKind; use crate::ir::expression::Expression; use crate::ir::{Node, Plan}; use serde::{Deserialize, Serialize}; @@ -65,18 +66,48 @@ impl Plan { /// Add aggregate function to plan /// /// # Errors - /// - + /// - Invalid arguments for given aggregate function pub fn add_aggregate_function( &mut self, function: &str, - child: usize, + kind: AggregateKind, + children: Vec<usize>, is_distinct: bool, - ) -> usize { + ) -> Result<usize, SbroadError> { + match kind { + AggregateKind::GRCONCAT => { + if children.len() > 2 || children.is_empty() { + return Err(SbroadError::Invalid( + Entity::Query, + Some(format!( + "GROUP_CONCAT aggregate function can have one or two arguments at most. Got: {} arguments", children.len() + )), + )); + } + if is_distinct && children.len() == 2 { + return Err(SbroadError::Invalid( + Entity::Query, + Some(format!( + "distinct GROUP_CONCAT aggregate function has only one argument. Got: {} arguments", children.len() + )), + )); + } + } + _ => { + if children.len() != 1 { + return Err(SbroadError::Invalid( + Entity::Query, + Some(format!("Expected one argument for aggregate: {function}.")), + )); + } + } + } let func_expr = Expression::StableFunction { name: function.to_lowercase(), - children: vec![child], + children, is_distinct, }; - self.nodes.push(Node::Expression(func_expr)) + let id = self.nodes.push(Node::Expression(func_expr)); + Ok(id) } } diff --git a/sbroad-core/src/ir/transformation/redistribution/groupby.rs b/sbroad-core/src/ir/transformation/redistribution/groupby.rs index cab162fb2e..6d9cbbc5b5 100644 --- a/sbroad-core/src/ir/transformation/redistribution/groupby.rs +++ b/sbroad-core/src/ir/transformation/redistribution/groupby.rs @@ -952,20 +952,14 @@ impl Plan { arguments: &[usize], local_alias: &str, ) -> Result<usize, SbroadError> { - // Currently all supported aggregate functions take only one argument - let aggregate_expression = *arguments.first().ok_or_else(|| { - SbroadError::UnexpectedNumberOfValues(format!( - "create_local_aggregate: Aggregate function has no children: {kind}, {local_alias}" - )) - })?; let fun = Function { name: kind.to_string(), behavior: Behavior::Stable, }; - // We can reuse `aggregate_expression` between local aggregates, because + // We can reuse aggregate expression between local aggregates, because // all local aggregates are located inside the same motion subtree and we // assume that each local aggregate does not need to modify its expression - let local_fun_id = self.add_stable_function(&fun, vec![aggregate_expression])?; + let local_fun_id = self.add_stable_function(&fun, arguments.to_vec())?; let alias_id = self.nodes.add_alias(local_alias, local_fun_id)?; Ok(alias_id) } @@ -1046,7 +1040,7 @@ impl Plan { .nodes .expr_iter(info.aggr.fun_id, false) .collect::<Vec<usize>>(); - if args.len() > 1 { + if args.len() > 1 && !matches!(info.aggr.kind, AggregateKind::GRCONCAT) { return Err(SbroadError::UnexpectedNumberOfValues(format!( "aggregate ({info:?}) have more than one argument" ))); diff --git a/sbroad-core/tests/artifactory/frontend/sql/single_quoted_str_ast.yaml b/sbroad-core/tests/artifactory/frontend/sql/single_quoted_str_ast.yaml new file mode 100644 index 0000000000..c7abe3ff9a --- /dev/null +++ b/sbroad-core/tests/artifactory/frontend/sql/single_quoted_str_ast.yaml @@ -0,0 +1,40 @@ +--- +nodes: + arena: + - children: #0 + - 3 + rule: Select + value: ~ + - children: #1 + - 2 + rule: Scan + value: ~ + - children: [] #2 + rule: Table + value: "\"test_space\"" + - children: #3 + - 1 + - 4 + rule: Projection + value: ~ + - children: #4 + - 8 + rule: Column + value: ~ + - children: #5 + - 6 + rule: SingleQuotedString + value: ~ + - children: [] #6 + rule: String + value: " " + - children: [] #7 + rule: AliasName + value: COL_1 + - children: #8 + - 5 + - 7 + rule: Alias + value: ~ +top: 3 +map: {} -- GitLab