From 67ae2d92d406ceb4b27836014dfed917c7934097 Mon Sep 17 00:00:00 2001
From: Arseniy Volynets <vol0ncar@yandex.ru>
Date: Tue, 4 Jun 2024 07:48:21 +0000
Subject: [PATCH] fix: throw error on avg/sum/total on number column

---
 .../test/integration/groupby_test.lua         | 26 ++++++++
 sbroad-core/src/executor/tests/exec_plan.rs   |  8 +--
 sbroad-core/src/frontend/sql/ir/tests.rs      | 17 ++---
 sbroad-core/src/ir/aggregates.rs              | 63 +++++++++++++++++++
 sbroad-core/src/ir/function.rs                |  4 ++
 5 files changed, 106 insertions(+), 12 deletions(-)

diff --git a/sbroad-cartridge/test_app/test/integration/groupby_test.lua b/sbroad-cartridge/test_app/test/integration/groupby_test.lua
index 5d2df78e8..5d8fc3359 100644
--- a/sbroad-cartridge/test_app/test/integration/groupby_test.lua
+++ b/sbroad-cartridge/test_app/test/integration/groupby_test.lua
@@ -2525,3 +2525,29 @@ groupby_queries.test_having_inside_except1 = function()
         {2},
     })
 end
+
+groupby_queries.test_sum_on_decimal_col = function()
+    local api = cluster:server("api-1").net_box
+
+    local _, err = api:call("sbroad.execute", {
+        [[
+            select sum("number_col") as s from "arithmetic_space"
+        ]], {}
+    })
+    t.assert_str_contains(tostring(err), "can't compute sum on argument with type number")
+
+    _, err = api:call("sbroad.execute", {
+        [[
+            select total("number_col") as s from "arithmetic_space"
+        ]], {}
+    })
+    t.assert_str_contains(tostring(err), "can't compute total on argument with type number")
+
+    _, err = api:call("sbroad.execute", {
+        [[
+            select avg("number_col") as s from "arithmetic_space"
+        ]], {}
+    })
+    t.assert_str_contains(tostring(err), "can't compute avg on argument with type number")
+end
+
diff --git a/sbroad-core/src/executor/tests/exec_plan.rs b/sbroad-core/src/executor/tests/exec_plan.rs
index 5a783aadc..603314b7f 100644
--- a/sbroad-core/src/executor/tests/exec_plan.rs
+++ b/sbroad-core/src/executor/tests/exec_plan.rs
@@ -240,7 +240,7 @@ fn exec_plan_subtree_two_stage_groupby_test_2() {
 fn exec_plan_subtree_aggregates() {
     let sql = r#"SELECT t1."sys_op" || t1."sys_op", t1."sys_op"*2 + count(t1."sysFrom"),
                       sum(t1."id"), sum(distinct t1."id"*t1."sys_op") / count(distinct "id"),
-                      group_concat(t1."id", 'o'), avg(t1."id"), total(t1."id"), min(t1."id"), max(t1."id")
+                      group_concat(t1."FIRST_NAME", 'o'), avg(t1."id"), total(t1."id"), min(t1."id"), max(t1."id")
                       FROM "test_space" as t1 group by t1."sys_op""#;
     let coordinator = RouterRuntimeMock::new();
 
@@ -298,9 +298,9 @@ fn exec_plan_subtree_aggregates() {
             format!(
                 "{} {} {} {} {} {}",
                 r#"SELECT "T1"."sys_op" as "column_12", ("T1"."id") * ("T1"."sys_op") as "column_49","#,
-                r#""T1"."id" as "column_46", total ("T1"."id") as "total_64","#,
-                r#"sum ("T1"."id") as "sum_42", count ("T1"."id") as "count_61", group_concat ("T1"."id", ?) as "group_concat_58","#,
-                r#"min ("T1"."id") as "min_67", max ("T1"."id") as "max_70", count ("T1"."sysFrom") as "count_37""#,
+                r#""T1"."id" as "column_46", count ("T1"."sysFrom") as "count_37","#,
+                r#"sum ("T1"."id") as "sum_42", count ("T1"."id") as "count_61", total ("T1"."id") as "total_64","#,
+                r#"min ("T1"."id") as "min_67", max ("T1"."id") as "max_70", group_concat ("T1"."FIRST_NAME", ?) as "group_concat_58""#,
                 r#"FROM "test_space" as "T1""#,
                 r#"GROUP BY "T1"."sys_op", ("T1"."id") * ("T1"."sys_op"), "T1"."id""#,
             ),
diff --git a/sbroad-core/src/frontend/sql/ir/tests.rs b/sbroad-core/src/frontend/sql/ir/tests.rs
index 3bf492d1a..95b34689d 100644
--- a/sbroad-core/src/frontend/sql/ir/tests.rs
+++ b/sbroad-core/src/frontend/sql/ir/tests.rs
@@ -1504,7 +1504,7 @@ vtable_max_rows = 5000
 
 #[test]
 fn front_sql_group_concat_aggregate() {
-    let input = r#"SELECT group_concat("b"), group_concat(distinct "b") FROM "t""#;
+    let input = r#"SELECT group_concat("FIRST_NAME"), group_concat(distinct "FIRST_NAME") FROM "test_space""#;
 
     let plan = sql_to_optimized_ir(input, vec![]);
 
@@ -1512,9 +1512,9 @@ fn front_sql_group_concat_aggregate() {
         r#"projection (group_concat(("group_concat_13"::string))::string -> "COL_1", group_concat(distinct ("column_15"::string))::string -> "COL_2")
     motion [policy: full]
         scan
-            projection ("t"."b"::unsigned -> "column_15", group_concat(("t"."b"::unsigned))::string -> "group_concat_13")
-                group by ("t"."b"::unsigned) output: ("t"."a"::unsigned -> "a", "t"."b"::unsigned -> "b", "t"."c"::unsigned -> "c", "t"."d"::unsigned -> "d", "t"."bucket_id"::unsigned -> "bucket_id")
-                    scan "t"
+            projection ("test_space"."FIRST_NAME"::string -> "column_15", group_concat(("test_space"."FIRST_NAME"::string))::string -> "group_concat_13")
+                group by ("test_space"."FIRST_NAME"::string) output: ("test_space"."id"::unsigned -> "id", "test_space"."sysFrom"::unsigned -> "sysFrom", "test_space"."FIRST_NAME"::string -> "FIRST_NAME", "test_space"."sys_op"::unsigned -> "sys_op", "test_space"."bucket_id"::unsigned -> "bucket_id")
+                    scan "test_space"
 execution options:
 sql_vdbe_max_steps = 45000
 vtable_max_rows = 5000
@@ -1526,17 +1526,18 @@ vtable_max_rows = 5000
 
 #[test]
 fn front_sql_group_concat_aggregate2() {
-    let input = r#"SELECT group_concat("b", ' '), group_concat(distinct "b") FROM "t""#;
+    let input = r#"SELECT group_concat("FIRST_NAME", ' '), group_concat(distinct "FIRST_NAME") FROM "test_space""#;
 
     let plan = sql_to_optimized_ir(input, vec![]);
 
+    println!("{}", plan.as_explain().unwrap());
     let expected_explain = String::from(
         r#"projection (group_concat(("group_concat_14"::string, ' '::string))::string -> "COL_1", group_concat(distinct ("column_16"::string))::string -> "COL_2")
     motion [policy: full]
         scan
-            projection ("t"."b"::unsigned -> "column_16", group_concat(("t"."b"::unsigned, ' '::string))::string -> "group_concat_14")
-                group by ("t"."b"::unsigned) output: ("t"."a"::unsigned -> "a", "t"."b"::unsigned -> "b", "t"."c"::unsigned -> "c", "t"."d"::unsigned -> "d", "t"."bucket_id"::unsigned -> "bucket_id")
-                    scan "t"
+            projection ("test_space"."FIRST_NAME"::string -> "column_16", group_concat(("test_space"."FIRST_NAME"::string, ' '::string))::string -> "group_concat_14")
+                group by ("test_space"."FIRST_NAME"::string) output: ("test_space"."id"::unsigned -> "id", "test_space"."sysFrom"::unsigned -> "sysFrom", "test_space"."FIRST_NAME"::string -> "FIRST_NAME", "test_space"."sys_op"::unsigned -> "sys_op", "test_space"."bucket_id"::unsigned -> "bucket_id")
+                    scan "test_space"
 execution options:
 sql_vdbe_max_steps = 45000
 vtable_max_rows = 5000
diff --git a/sbroad-core/src/ir/aggregates.rs b/sbroad-core/src/ir/aggregates.rs
index 4eff3fb95..5adeab519 100644
--- a/sbroad-core/src/ir/aggregates.rs
+++ b/sbroad-core/src/ir/aggregates.rs
@@ -82,6 +82,69 @@ impl AggregateKind {
         }
     }
 
+    /// Check agruments types of aggregate function
+    ///
+    /// # Errors
+    /// - Invlid plan/aggregate
+    /// - Invalid argument type
+    ///
+    /// # Panics
+    /// - Invalid argument count for aggregate
+    pub fn check_args_types(&self, plan: &Plan, args: &[usize]) -> Result<(), SbroadError> {
+        use crate::ir::relation::Type;
+        let get_arg_type = |idx: usize| -> Result<Type, SbroadError> {
+            let arg_id = *args.get(idx).expect("wrong agregate");
+            let expr = plan.get_expression_node(arg_id)?;
+            expr.calculate_type(plan)
+        };
+        let err = |arg_type: &Type| -> Result<(), SbroadError> {
+            Err(SbroadError::Invalid(
+                Entity::Query,
+                Some(format_smolstr!(
+                    "can't compute {self} on argument with type {arg_type}. \
+                    Consider using explicit cast."
+                )),
+            ))
+        };
+        match self {
+            AggregateKind::SUM | AggregateKind::AVG | AggregateKind::TOTAL => {
+                let arg_type = get_arg_type(0)?;
+                if !matches!(
+                    arg_type,
+                    Type::Decimal | Type::Double | Type::Unsigned | Type::Integer
+                ) {
+                    err(&arg_type)?;
+                }
+            }
+            AggregateKind::MIN | AggregateKind::MAX => {
+                let arg_type = get_arg_type(0)?;
+                if !arg_type.is_scalar() {
+                    err(&arg_type)?;
+                }
+            }
+            AggregateKind::GRCONCAT => {
+                let first_type = get_arg_type(0)?;
+                if args.len() == 2 {
+                    let second_type = get_arg_type(1)?;
+                    if first_type != second_type {
+                        return Err(SbroadError::Invalid(
+                            Entity::Query,
+                            Some(
+                                "group concat requires both arguments to be of the same type!"
+                                    .into(),
+                            ),
+                        ));
+                    }
+                }
+                if !matches!(first_type, Type::String) {
+                    err(&first_type)?;
+                }
+            }
+            AggregateKind::COUNT => {}
+        }
+        Ok(())
+    }
+
     /// Get final aggregate corresponding to given local aggregate
     ///
     /// # Errors
diff --git a/sbroad-core/src/ir/function.rs b/sbroad-core/src/ir/function.rs
index 289867c78..0ee2e10e5 100644
--- a/sbroad-core/src/ir/function.rs
+++ b/sbroad-core/src/ir/function.rs
@@ -78,6 +78,9 @@ impl Plan {
     ///
     /// # Errors
     /// - Invalid arguments for given aggregate function
+    ///
+    /// # Panics
+    /// - never
     pub fn add_aggregate_function(
         &mut self,
         function: &str,
@@ -115,6 +118,7 @@ impl Plan {
                 }
             }
         }
+        kind.check_args_types(self, &children)?;
         let feature = if is_distinct {
             Some(FunctionFeature::Distinct)
         } else {
-- 
GitLab