From 389fbff0c623c0f05afce43b4c7af7534f50d43c Mon Sep 17 00:00:00 2001 From: Arseniy Volynets <vol0ncar@yandex.ru> Date: Mon, 10 Jun 2024 22:05:29 +0000 Subject: [PATCH] feat: infer not null constraint on primary key columns --- sbroad-core/src/frontend/sql.rs | 16 ++- sbroad-core/src/frontend/sql/ir/tests.rs | 16 +-- sbroad-core/src/frontend/sql/ir/tests/ddl.rs | 100 ++++++++++++++++++ .../src/frontend/sql/ir/tests/global.rs | 6 +- 4 files changed, 124 insertions(+), 14 deletions(-) create mode 100644 sbroad-core/src/frontend/sql/ir/tests/ddl.rs diff --git a/sbroad-core/src/frontend/sql.rs b/sbroad-core/src/frontend/sql.rs index eb028387c..add50594b 100644 --- a/sbroad-core/src/frontend/sql.rs +++ b/sbroad-core/src/frontend/sql.rs @@ -3,7 +3,7 @@ //! Parses an SQL statement to the abstract syntax tree (AST) //! and builds the intermediate representation (IR). -use ahash::AHashMap; +use ahash::{AHashMap, AHashSet}; use core::panic; use itertools::Itertools; use pest::iterators::{Pair, Pairs}; @@ -504,6 +504,7 @@ fn parse_create_table(ast: &AbstractSyntaxTree, node: &ParseNode) -> Result<Ddl, let mut pk_keys: Vec<SmolStr> = Vec::new(); let mut shard_key: Vec<SmolStr> = Vec::new(); let mut engine_type: SpaceEngineType = SpaceEngineType::default(); + let mut explicit_null_columns: AHashSet<SmolStr> = AHashSet::new(); let mut timeout = get_default_timeout(); let mut tier = None; @@ -586,6 +587,7 @@ fn parse_create_table(ast: &AbstractSyntaxTree, node: &ParseNode) -> Result<Ddl, def_child_node.children.get(1), ) { (None, None) => { + explicit_null_columns.insert(column_def.name.clone()); column_def.is_nullable = true; } (Some(child_id), None) => { @@ -606,9 +608,13 @@ fn parse_create_table(ast: &AbstractSyntaxTree, node: &ParseNode) -> Result<Ddl, if !pk_keys.is_empty() { return primary_key_already_declared_error; } - if column_def.is_nullable { + if column_def.is_nullable + && explicit_null_columns.contains(&column_def.name) + { return nullable_primary_key_column_error; } + // Infer not null on primary key column + column_def.is_nullable = false; pk_keys.push(column_def.name.clone()); } _ => panic!("Unexpected rules met under ColumnDef."), @@ -627,12 +633,14 @@ fn parse_create_table(ast: &AbstractSyntaxTree, node: &ParseNode) -> Result<Ddl, for pk_col_id in pk_node.children.iter().skip(1) { let pk_col_name = parse_identifier(ast, *pk_col_id)?; let mut column_found = false; - for column in &columns { + for column in &mut columns { if column.name == pk_col_name { column_found = true; - if column.is_nullable { + if column.is_nullable && explicit_null_columns.contains(&column.name) { return nullable_primary_key_column_error; } + // Infer not null on primary key column + column.is_nullable = false; } } if !column_found { diff --git a/sbroad-core/src/frontend/sql/ir/tests.rs b/sbroad-core/src/frontend/sql/ir/tests.rs index b99736b65..3bf492d1a 100644 --- a/sbroad-core/src/frontend/sql/ir/tests.rs +++ b/sbroad-core/src/frontend/sql/ir/tests.rs @@ -1803,10 +1803,10 @@ vtable_max_rows = 1000 #[test] fn front_sql_pg_style_params3() { - let input = r#"select "a" + $1 from "t" + let input = r#"select "a" + $1 from "t" where "a" = $1 group by "a" + $1 - having count("b") > $1 + having count("b") > $1 option(sql_vdbe_max_steps = $1, vtable_max_rows = $1)"#; let plan = sql_to_optimized_ir(input, vec![Value::Unsigned(42)]); @@ -2070,7 +2070,7 @@ vtable_max_rows = 5000 #[test] fn front_sql_except_single_right() { - let input = r#"SELECT "a", "b" from "t" + let input = r#"SELECT "a", "b" from "t" EXCEPT SELECT sum("a"), count("b") from "t" "#; @@ -2094,7 +2094,7 @@ vtable_max_rows = 5000 assert_eq!(expected_explain, plan.as_explain().unwrap()); - let input = r#"SELECT "b", "a" from "t" + let input = r#"SELECT "b", "a" from "t" EXCEPT SELECT sum("a"), count("b") from "t" "#; @@ -2122,7 +2122,7 @@ vtable_max_rows = 5000 #[test] fn front_sql_except_single_left() { - let input = r#"SELECT sum("a"), count("b") from "t" + let input = r#"SELECT sum("a"), count("b") from "t" EXCEPT SELECT "a", "b" from "t" "#; @@ -2149,7 +2149,7 @@ vtable_max_rows = 5000 #[test] fn front_sql_except_single_both() { - let input = r#"SELECT sum("a"), count("b") from "t" + let input = r#"SELECT sum("a"), count("b") from "t" EXCEPT SELECT sum("a"), sum("b") from "t" "#; @@ -2311,7 +2311,7 @@ vtable_max_rows = 5000 #[test] fn front_sql_left_join() { - let input = r#"SELECT * from (select "a" as a from "t") as o + let input = r#"SELECT * from (select "a" as a from "t") as o left outer join (select "b" as c, "d" as d from "t") as i on o.a = i.c "#; @@ -3558,6 +3558,8 @@ fn front_count_no_params() { #[cfg(test)] mod cte; #[cfg(test)] +mod ddl; +#[cfg(test)] mod global; #[cfg(test)] mod insert; diff --git a/sbroad-core/src/frontend/sql/ir/tests/ddl.rs b/sbroad-core/src/frontend/sql/ir/tests/ddl.rs new file mode 100644 index 000000000..394dd0756 --- /dev/null +++ b/sbroad-core/src/frontend/sql/ir/tests/ddl.rs @@ -0,0 +1,100 @@ +use crate::frontend::Ast; +use pretty_assertions::assert_eq; +use smol_str::SmolStr; + +use crate::{ + executor::engine::mock::RouterConfigurationMock, + frontend::sql::ast::AbstractSyntaxTree, + ir::{ + ddl::{ColumnDef, Ddl}, + relation::Type, + }, +}; + +#[test] +fn infer_not_null_on_pk1() { + let input = r#"create table t (a int primary key) distributed globally"#; + + let metadata = &RouterConfigurationMock::new(); + let plan = AbstractSyntaxTree::transform_into_plan(input, metadata).unwrap(); + let top_id = plan.get_top().unwrap(); + let top_node = plan.get_ddl_node(top_id).unwrap(); + + let Ddl::CreateTable { + format, + primary_key, + .. + } = top_node + else { + panic!("expected create table") + }; + + let def = ColumnDef { + name: "A".into(), + data_type: Type::Integer, + is_nullable: false, + }; + + assert_eq!(format, &vec![def]); + + let expected_pk: Vec<SmolStr> = vec!["A".into()]; + assert_eq!(primary_key, &expected_pk); +} + +#[test] +fn infer_not_null_on_pk2() { + let input = + r#"create table t (a int, b int not null, c int, primary key (a, b)) distributed globally"#; + + let metadata = &RouterConfigurationMock::new(); + let plan = AbstractSyntaxTree::transform_into_plan(input, metadata).unwrap(); + let top_id = plan.get_top().unwrap(); + let top_node = plan.get_ddl_node(top_id).unwrap(); + + let Ddl::CreateTable { + format, + primary_key, + .. + } = top_node + else { + panic!("expected create table") + }; + + let def_a = ColumnDef { + name: "A".into(), + data_type: Type::Integer, + is_nullable: false, + }; + + let def_b = ColumnDef { + name: "B".into(), + data_type: Type::Integer, + is_nullable: false, + }; + + let def_c = ColumnDef { + name: "C".into(), + data_type: Type::Integer, + is_nullable: true, + }; + + assert_eq!(format, &vec![def_a, def_b, def_c]); + + let expected_pk: Vec<SmolStr> = vec!["A".into(), "B".into()]; + assert_eq!(primary_key, &expected_pk); +} + +#[test] +fn infer_not_null_on_pk3() { + let input = r#"create table t (a int null, b int not null, c int, primary key (a, b)) distributed globally"#; + + let metadata = &RouterConfigurationMock::new(); + let err = AbstractSyntaxTree::transform_into_plan(input, metadata).unwrap_err(); + + dbg!(&err); + assert_eq!( + true, + err.to_string() + .contains("Primary key mustn't contain nullable columns.") + ); +} diff --git a/sbroad-core/src/frontend/sql/ir/tests/global.rs b/sbroad-core/src/frontend/sql/ir/tests/global.rs index a00ce844e..64fb630cf 100644 --- a/sbroad-core/src/frontend/sql/ir/tests/global.rs +++ b/sbroad-core/src/frontend/sql/ir/tests/global.rs @@ -944,7 +944,7 @@ vtable_max_rows = 5000 #[test] fn front_sql_global_left_join3() { let input = r#" - select "e", "b" from + select "e", "b" from (select "b" * "b" as "b" from "global_t") left join "t2" on true "#; @@ -975,9 +975,9 @@ vtable_max_rows = 5000 #[test] fn front_sql_global_left_join4() { let input = r#" - select "e", "b" from + select "e", "b" from (select "b" * "b" as "b" from "global_t") - left join + left join (select "e" + 1 as "e" from "t2") on true "#; -- GitLab