From 389fbff0c623c0f05afce43b4c7af7534f50d43c Mon Sep 17 00:00:00 2001
From: Arseniy Volynets <vol0ncar@yandex.ru>
Date: Mon, 10 Jun 2024 22:05:29 +0000
Subject: [PATCH] feat: infer not null constraint on primary key columns

---
 sbroad-core/src/frontend/sql.rs               |  16 ++-
 sbroad-core/src/frontend/sql/ir/tests.rs      |  16 +--
 sbroad-core/src/frontend/sql/ir/tests/ddl.rs  | 100 ++++++++++++++++++
 .../src/frontend/sql/ir/tests/global.rs       |   6 +-
 4 files changed, 124 insertions(+), 14 deletions(-)
 create mode 100644 sbroad-core/src/frontend/sql/ir/tests/ddl.rs

diff --git a/sbroad-core/src/frontend/sql.rs b/sbroad-core/src/frontend/sql.rs
index eb028387c..add50594b 100644
--- a/sbroad-core/src/frontend/sql.rs
+++ b/sbroad-core/src/frontend/sql.rs
@@ -3,7 +3,7 @@
 //! Parses an SQL statement to the abstract syntax tree (AST)
 //! and builds the intermediate representation (IR).
 
-use ahash::AHashMap;
+use ahash::{AHashMap, AHashSet};
 use core::panic;
 use itertools::Itertools;
 use pest::iterators::{Pair, Pairs};
@@ -504,6 +504,7 @@ fn parse_create_table(ast: &AbstractSyntaxTree, node: &ParseNode) -> Result<Ddl,
     let mut pk_keys: Vec<SmolStr> = Vec::new();
     let mut shard_key: Vec<SmolStr> = Vec::new();
     let mut engine_type: SpaceEngineType = SpaceEngineType::default();
+    let mut explicit_null_columns: AHashSet<SmolStr> = AHashSet::new();
     let mut timeout = get_default_timeout();
     let mut tier = None;
 
@@ -586,6 +587,7 @@ fn parse_create_table(ast: &AbstractSyntaxTree, node: &ParseNode) -> Result<Ddl,
                                     def_child_node.children.get(1),
                                 ) {
                                     (None, None) => {
+                                        explicit_null_columns.insert(column_def.name.clone());
                                         column_def.is_nullable = true;
                                     }
                                     (Some(child_id), None) => {
@@ -606,9 +608,13 @@ fn parse_create_table(ast: &AbstractSyntaxTree, node: &ParseNode) -> Result<Ddl,
                                 if !pk_keys.is_empty() {
                                     return primary_key_already_declared_error;
                                 }
-                                if column_def.is_nullable {
+                                if column_def.is_nullable
+                                    && explicit_null_columns.contains(&column_def.name)
+                                {
                                     return nullable_primary_key_column_error;
                                 }
+                                // Infer not null on primary key column
+                                column_def.is_nullable = false;
                                 pk_keys.push(column_def.name.clone());
                             }
                             _ => panic!("Unexpected rules met under ColumnDef."),
@@ -627,12 +633,14 @@ fn parse_create_table(ast: &AbstractSyntaxTree, node: &ParseNode) -> Result<Ddl,
                 for pk_col_id in pk_node.children.iter().skip(1) {
                     let pk_col_name = parse_identifier(ast, *pk_col_id)?;
                     let mut column_found = false;
-                    for column in &columns {
+                    for column in &mut columns {
                         if column.name == pk_col_name {
                             column_found = true;
-                            if column.is_nullable {
+                            if column.is_nullable && explicit_null_columns.contains(&column.name) {
                                 return nullable_primary_key_column_error;
                             }
+                            // Infer not null on primary key column
+                            column.is_nullable = false;
                         }
                     }
                     if !column_found {
diff --git a/sbroad-core/src/frontend/sql/ir/tests.rs b/sbroad-core/src/frontend/sql/ir/tests.rs
index b99736b65..3bf492d1a 100644
--- a/sbroad-core/src/frontend/sql/ir/tests.rs
+++ b/sbroad-core/src/frontend/sql/ir/tests.rs
@@ -1803,10 +1803,10 @@ vtable_max_rows = 1000
 
 #[test]
 fn front_sql_pg_style_params3() {
-    let input = r#"select "a" + $1 from "t" 
+    let input = r#"select "a" + $1 from "t"
         where "a" = $1
         group by "a" + $1
-        having count("b") > $1 
+        having count("b") > $1
         option(sql_vdbe_max_steps = $1, vtable_max_rows = $1)"#;
 
     let plan = sql_to_optimized_ir(input, vec![Value::Unsigned(42)]);
@@ -2070,7 +2070,7 @@ vtable_max_rows = 5000
 
 #[test]
 fn front_sql_except_single_right() {
-    let input = r#"SELECT "a", "b" from "t" 
+    let input = r#"SELECT "a", "b" from "t"
         EXCEPT
         SELECT sum("a"), count("b") from "t"
     "#;
@@ -2094,7 +2094,7 @@ vtable_max_rows = 5000
 
     assert_eq!(expected_explain, plan.as_explain().unwrap());
 
-    let input = r#"SELECT "b", "a" from "t" 
+    let input = r#"SELECT "b", "a" from "t"
         EXCEPT
         SELECT sum("a"), count("b") from "t"
     "#;
@@ -2122,7 +2122,7 @@ vtable_max_rows = 5000
 
 #[test]
 fn front_sql_except_single_left() {
-    let input = r#"SELECT sum("a"), count("b") from "t" 
+    let input = r#"SELECT sum("a"), count("b") from "t"
         EXCEPT
         SELECT "a", "b" from "t"
     "#;
@@ -2149,7 +2149,7 @@ vtable_max_rows = 5000
 
 #[test]
 fn front_sql_except_single_both() {
-    let input = r#"SELECT sum("a"), count("b") from "t" 
+    let input = r#"SELECT sum("a"), count("b") from "t"
         EXCEPT
         SELECT sum("a"), sum("b") from "t"
     "#;
@@ -2311,7 +2311,7 @@ vtable_max_rows = 5000
 
 #[test]
 fn front_sql_left_join() {
-    let input = r#"SELECT * from (select "a" as a from "t") as o 
+    let input = r#"SELECT * from (select "a" as a from "t") as o
         left outer join (select "b" as c, "d" as d from "t") as i
         on o.a = i.c
         "#;
@@ -3558,6 +3558,8 @@ fn front_count_no_params() {
 #[cfg(test)]
 mod cte;
 #[cfg(test)]
+mod ddl;
+#[cfg(test)]
 mod global;
 #[cfg(test)]
 mod insert;
diff --git a/sbroad-core/src/frontend/sql/ir/tests/ddl.rs b/sbroad-core/src/frontend/sql/ir/tests/ddl.rs
new file mode 100644
index 000000000..394dd0756
--- /dev/null
+++ b/sbroad-core/src/frontend/sql/ir/tests/ddl.rs
@@ -0,0 +1,100 @@
+use crate::frontend::Ast;
+use pretty_assertions::assert_eq;
+use smol_str::SmolStr;
+
+use crate::{
+    executor::engine::mock::RouterConfigurationMock,
+    frontend::sql::ast::AbstractSyntaxTree,
+    ir::{
+        ddl::{ColumnDef, Ddl},
+        relation::Type,
+    },
+};
+
+#[test]
+fn infer_not_null_on_pk1() {
+    // The column carries no explicit nullability, so NOT NULL must be inferred
+    // from the inline PRIMARY KEY clause.
+    let sql = r#"create table t (a int primary key) distributed globally"#;
+
+    let router = RouterConfigurationMock::new();
+    let plan = AbstractSyntaxTree::transform_into_plan(sql, &router).unwrap();
+    let ddl = plan.get_ddl_node(plan.get_top().unwrap()).unwrap();
+
+    let Ddl::CreateTable {
+        format,
+        primary_key,
+        ..
+    } = ddl
+    else {
+        panic!("expected create table")
+    };
+
+    let expected_format = vec![ColumnDef {
+        name: "A".into(),
+        data_type: Type::Integer,
+        is_nullable: false,
+    }];
+    assert_eq!(format, &expected_format);
+
+    let expected_pk: Vec<SmolStr> = vec!["A".into()];
+    assert_eq!(primary_key, &expected_pk);
+}
+
+#[test]
+fn infer_not_null_on_pk2() {
+    // `a` has no explicit nullability but is listed in the table-level primary
+    // key, so it must end up NOT NULL; `c` is outside the key and stays nullable.
+    let sql =
+        r#"create table t (a int, b int not null, c int, primary key (a, b)) distributed globally"#;
+
+    let router = RouterConfigurationMock::new();
+    let plan = AbstractSyntaxTree::transform_into_plan(sql, &router).unwrap();
+    let top_id = plan.get_top().unwrap();
+    let ddl = plan.get_ddl_node(top_id).unwrap();
+
+    let Ddl::CreateTable {
+        format,
+        primary_key,
+        ..
+    } = ddl
+    else {
+        panic!("expected create table")
+    };
+
+    // Every expected column is an integer, so only name and nullability vary.
+    let int_column = |name: &str, is_nullable: bool| ColumnDef {
+        name: name.into(),
+        data_type: Type::Integer,
+        is_nullable,
+    };
+
+    let expected_format = vec![
+        // Primary key member without explicit nullability: expected NOT NULL.
+        int_column("A", false),
+        // Primary key member explicitly declared NOT NULL.
+        int_column("B", false),
+        // Not a primary key member: expected to stay nullable.
+        int_column("C", true),
+    ];
+
+    assert_eq!(format, &expected_format);
+
+    let expected_pk: Vec<SmolStr> = vec!["A".into(), "B".into()];
+    assert_eq!(primary_key, &expected_pk);
+}
+
+#[test]
+fn infer_not_null_on_pk3() {
+    let input = r#"create table t (a int null, b int not null, c int, primary key (a, b)) distributed globally"#;
+
+    let metadata = &RouterConfigurationMock::new();
+    let err = AbstractSyntaxTree::transform_into_plan(input, metadata).unwrap_err();
+
+    // An explicit NULL on a primary key column must be rejected, not overridden.
+    let message = err.to_string();
+    assert!(
+        message.contains("Primary key mustn't contain nullable columns."),
+        "unexpected error: {message}"
+    );
+}
diff --git a/sbroad-core/src/frontend/sql/ir/tests/global.rs b/sbroad-core/src/frontend/sql/ir/tests/global.rs
index a00ce844e..64fb630cf 100644
--- a/sbroad-core/src/frontend/sql/ir/tests/global.rs
+++ b/sbroad-core/src/frontend/sql/ir/tests/global.rs
@@ -944,7 +944,7 @@ vtable_max_rows = 5000
 #[test]
 fn front_sql_global_left_join3() {
     let input = r#"
-    select "e", "b" from 
+    select "e", "b" from
     (select "b" * "b" as "b" from "global_t")
     left join "t2" on true
     "#;
@@ -975,9 +975,9 @@ vtable_max_rows = 5000
 #[test]
 fn front_sql_global_left_join4() {
     let input = r#"
-    select "e", "b" from 
+    select "e", "b" from
     (select "b" * "b" as "b" from "global_t")
-    left join 
+    left join
     (select "e" + 1 as "e" from "t2")
     on true
     "#;
-- 
GitLab