From 67ae2d92d406ceb4b27836014dfed917c7934097 Mon Sep 17 00:00:00 2001 From: Arseniy Volynets <vol0ncar@yandex.ru> Date: Tue, 4 Jun 2024 07:48:21 +0000 Subject: [PATCH] fix: throw error on avg/sum/total on number column --- .../test/integration/groupby_test.lua | 26 ++++++++ sbroad-core/src/executor/tests/exec_plan.rs | 8 +-- sbroad-core/src/frontend/sql/ir/tests.rs | 17 ++--- sbroad-core/src/ir/aggregates.rs | 63 +++++++++++++++++++ sbroad-core/src/ir/function.rs | 4 ++ 5 files changed, 106 insertions(+), 12 deletions(-) diff --git a/sbroad-cartridge/test_app/test/integration/groupby_test.lua b/sbroad-cartridge/test_app/test/integration/groupby_test.lua index 5d2df78e8..5d8fc3359 100644 --- a/sbroad-cartridge/test_app/test/integration/groupby_test.lua +++ b/sbroad-cartridge/test_app/test/integration/groupby_test.lua @@ -2525,3 +2525,29 @@ groupby_queries.test_having_inside_except1 = function() {2}, }) end + +groupby_queries.test_sum_on_decimal_col = function() + local api = cluster:server("api-1").net_box + + local _, err = api:call("sbroad.execute", { + [[ + select sum("number_col") as s from "arithmetic_space" + ]], {} + }) + t.assert_str_contains(tostring(err), "can't compute sum on argument with type number") + + _, err = api:call("sbroad.execute", { + [[ + select total("number_col") as s from "arithmetic_space" + ]], {} + }) + t.assert_str_contains(tostring(err), "can't compute total on argument with type number") + + _, err = api:call("sbroad.execute", { + [[ + select avg("number_col") as s from "arithmetic_space" + ]], {} + }) + t.assert_str_contains(tostring(err), "can't compute avg on argument with type number") +end + diff --git a/sbroad-core/src/executor/tests/exec_plan.rs b/sbroad-core/src/executor/tests/exec_plan.rs index 5a783aadc..603314b7f 100644 --- a/sbroad-core/src/executor/tests/exec_plan.rs +++ b/sbroad-core/src/executor/tests/exec_plan.rs @@ -240,7 +240,7 @@ fn exec_plan_subtree_two_stage_groupby_test_2() { fn exec_plan_subtree_aggregates() { let sql = r#"SELECT t1."sys_op" || t1."sys_op", t1."sys_op"*2 + count(t1."sysFrom"), sum(t1."id"), sum(distinct t1."id"*t1."sys_op") / count(distinct "id"), - group_concat(t1."id", 'o'), avg(t1."id"), total(t1."id"), min(t1."id"), max(t1."id") + group_concat(t1."FIRST_NAME", 'o'), avg(t1."id"), total(t1."id"), min(t1."id"), max(t1."id") FROM "test_space" as t1 group by t1."sys_op""#; let coordinator = RouterRuntimeMock::new(); @@ -298,9 +298,9 @@ fn exec_plan_subtree_aggregates() { format!( "{} {} {} {} {} {}", r#"SELECT "T1"."sys_op" as "column_12", ("T1"."id") * ("T1"."sys_op") as "column_49","#, - r#""T1"."id" as "column_46", total ("T1"."id") as "total_64","#, - r#"sum ("T1"."id") as "sum_42", count ("T1"."id") as "count_61", group_concat ("T1"."id", ?) as "group_concat_58","#, - r#"min ("T1"."id") as "min_67", max ("T1"."id") as "max_70", count ("T1"."sysFrom") as "count_37""#, + r#""T1"."id" as "column_46", count ("T1"."sysFrom") as "count_37","#, + r#"sum ("T1"."id") as "sum_42", count ("T1"."id") as "count_61", total ("T1"."id") as "total_64","#, + r#"min ("T1"."id") as "min_67", max ("T1"."id") as "max_70", group_concat ("T1"."FIRST_NAME", ?) as "group_concat_58""#, r#"FROM "test_space" as "T1""#, r#"GROUP BY "T1"."sys_op", ("T1"."id") * ("T1"."sys_op"), "T1"."id""#, ), diff --git a/sbroad-core/src/frontend/sql/ir/tests.rs b/sbroad-core/src/frontend/sql/ir/tests.rs index 3bf492d1a..95b34689d 100644 --- a/sbroad-core/src/frontend/sql/ir/tests.rs +++ b/sbroad-core/src/frontend/sql/ir/tests.rs @@ -1504,7 +1504,7 @@ vtable_max_rows = 5000 #[test] fn front_sql_group_concat_aggregate() { - let input = r#"SELECT group_concat("b"), group_concat(distinct "b") FROM "t""#; + let input = r#"SELECT group_concat("FIRST_NAME"), group_concat(distinct "FIRST_NAME") FROM "test_space""#; let plan = sql_to_optimized_ir(input, vec![]); @@ -1512,9 +1512,9 @@ fn front_sql_group_concat_aggregate() { r#"projection (group_concat(("group_concat_13"::string))::string -> "COL_1", group_concat(distinct ("column_15"::string))::string -> "COL_2") motion [policy: full] scan - projection ("t"."b"::unsigned -> "column_15", group_concat(("t"."b"::unsigned))::string -> "group_concat_13") - group by ("t"."b"::unsigned) output: ("t"."a"::unsigned -> "a", "t"."b"::unsigned -> "b", "t"."c"::unsigned -> "c", "t"."d"::unsigned -> "d", "t"."bucket_id"::unsigned -> "bucket_id") - scan "t" + projection ("test_space"."FIRST_NAME"::string -> "column_15", group_concat(("test_space"."FIRST_NAME"::string))::string -> "group_concat_13") + group by ("test_space"."FIRST_NAME"::string) output: ("test_space"."id"::unsigned -> "id", "test_space"."sysFrom"::unsigned -> "sysFrom", "test_space"."FIRST_NAME"::string -> "FIRST_NAME", "test_space"."sys_op"::unsigned -> "sys_op", "test_space"."bucket_id"::unsigned -> "bucket_id") + scan "test_space" execution options: sql_vdbe_max_steps = 45000 vtable_max_rows = 5000 @@ -1526,17 +1526,18 @@ vtable_max_rows = 5000 #[test] fn front_sql_group_concat_aggregate2() { - let input = r#"SELECT group_concat("b", ' '), group_concat(distinct "b") FROM "t""#; + let input = r#"SELECT group_concat("FIRST_NAME", ' '), group_concat(distinct "FIRST_NAME") FROM "test_space""#; let plan = sql_to_optimized_ir(input, vec![]); + println!("{}", plan.as_explain().unwrap()); let expected_explain = String::from( r#"projection (group_concat(("group_concat_14"::string, ' '::string))::string -> "COL_1", group_concat(distinct ("column_16"::string))::string -> "COL_2") motion [policy: full] scan - projection ("t"."b"::unsigned -> "column_16", group_concat(("t"."b"::unsigned, ' '::string))::string -> "group_concat_14") - group by ("t"."b"::unsigned) output: ("t"."a"::unsigned -> "a", "t"."b"::unsigned -> "b", "t"."c"::unsigned -> "c", "t"."d"::unsigned -> "d", "t"."bucket_id"::unsigned -> "bucket_id") - scan "t" + projection ("test_space"."FIRST_NAME"::string -> "column_16", group_concat(("test_space"."FIRST_NAME"::string, ' '::string))::string -> "group_concat_14") + group by ("test_space"."FIRST_NAME"::string) output: ("test_space"."id"::unsigned -> "id", "test_space"."sysFrom"::unsigned -> "sysFrom", "test_space"."FIRST_NAME"::string -> "FIRST_NAME", "test_space"."sys_op"::unsigned -> "sys_op", "test_space"."bucket_id"::unsigned -> "bucket_id") + scan "test_space" execution options: sql_vdbe_max_steps = 45000 vtable_max_rows = 5000 diff --git a/sbroad-core/src/ir/aggregates.rs b/sbroad-core/src/ir/aggregates.rs index 4eff3fb95..5adeab519 100644 --- a/sbroad-core/src/ir/aggregates.rs +++ b/sbroad-core/src/ir/aggregates.rs @@ -82,6 +82,69 @@ impl AggregateKind { } } + /// Check agruments types of aggregate function + /// + /// # Errors + /// - Invlid plan/aggregate + /// - Invalid argument type + /// + /// # Panics + /// - Invalid argument count for aggregate + pub fn check_args_types(&self, plan: &Plan, args: &[usize]) -> Result<(), SbroadError> { + use crate::ir::relation::Type; + let get_arg_type = |idx: usize| -> Result<Type, SbroadError> { + let arg_id = *args.get(idx).expect("wrong agregate"); + let expr = plan.get_expression_node(arg_id)?; + expr.calculate_type(plan) + }; + let err = |arg_type: &Type| -> Result<(), SbroadError> { + Err(SbroadError::Invalid( + Entity::Query, + Some(format_smolstr!( + "can't compute {self} on argument with type {arg_type}. \ + Consider using explicit cast." + )), + )) + }; + match self { + AggregateKind::SUM | AggregateKind::AVG | AggregateKind::TOTAL => { + let arg_type = get_arg_type(0)?; + if !matches!( + arg_type, + Type::Decimal | Type::Double | Type::Unsigned | Type::Integer + ) { + err(&arg_type)?; + } + } + AggregateKind::MIN | AggregateKind::MAX => { + let arg_type = get_arg_type(0)?; + if !arg_type.is_scalar() { + err(&arg_type)?; + } + } + AggregateKind::GRCONCAT => { + let first_type = get_arg_type(0)?; + if args.len() == 2 { + let second_type = get_arg_type(1)?; + if first_type != second_type { + return Err(SbroadError::Invalid( + Entity::Query, + Some( + "group concat requires both arguments to be of the same type!" + .into(), + ), + )); + } + } + if !matches!(first_type, Type::String) { + err(&first_type)?; + } + } + AggregateKind::COUNT => {} + } + Ok(()) + } + /// Get final aggregate corresponding to given local aggregate /// /// # Errors diff --git a/sbroad-core/src/ir/function.rs b/sbroad-core/src/ir/function.rs index 289867c78..0ee2e10e5 100644 --- a/sbroad-core/src/ir/function.rs +++ b/sbroad-core/src/ir/function.rs @@ -78,6 +78,9 @@ impl Plan { /// /// # Errors /// - Invalid arguments for given aggregate function + /// + /// # Panics + /// - never pub fn add_aggregate_function( &mut self, function: &str, @@ -115,6 +118,7 @@ impl Plan { } } } + kind.check_args_types(self, &children)?; let feature = if is_distinct { Some(FunctionFeature::Distinct) } else { -- GitLab