From a1ed892604774a3ad9cf3f080b7892019593e724 Mon Sep 17 00:00:00 2001
From: Arseniy Volynets <vol0ncar@yandex.ru>
Date: Sun, 4 Aug 2024 20:13:18 +0300
Subject: [PATCH] feat: infer shard key from primary key

---
 doc/sql/query.ebnf                           |  2 +-
 sbroad-core/src/frontend/sql.rs              | 10 ++++++++-
 sbroad-core/src/frontend/sql/ir/tests/ddl.rs | 22 ++++++++++++++++++--
 sbroad-core/src/frontend/sql/query.pest      |  2 +-
 4 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/doc/sql/query.ebnf b/doc/sql/query.ebnf
index 80cf51be14..c9f9535fa0 100644
--- a/doc/sql/query.ebnf
+++ b/doc/sql/query.ebnf
@@ -147,7 +147,7 @@ create_table   ::= 'CREATE' 'TABLE' table
                        (',' 'PRIMARY' 'KEY' '(' column (',' column)* ')')?
                    ')'
                    ('USING' ('MEMTX' | 'VINYL'))?
-                   ('DISTRIBUTED' (('BY' '(' column (',' column)*  ')' ('IN' 'TIER' tier)?) | 'GLOBALLY'))
+                   ('DISTRIBUTED' (('BY' '(' column (',' column)*  ')' ('IN' 'TIER' tier)?) | 'GLOBALLY'))?
 create_user    ::= 'CREATE' 'USER' user 'WITH'? 'PASSWORD' "'" password "'"
                    ('USING' ('CHAP-SHA1' | 'LDAP' | 'MD5'))?
 alter_user     ::= 'ALTER' 'USER' user
diff --git a/sbroad-core/src/frontend/sql.rs b/sbroad-core/src/frontend/sql.rs
index 55a72d8548..43557c9d62 100644
--- a/sbroad-core/src/frontend/sql.rs
+++ b/sbroad-core/src/frontend/sql.rs
@@ -507,6 +507,7 @@ fn parse_create_table(ast: &AbstractSyntaxTree, node: &ParseNode) -> Result<Ddl,
     let mut explicit_null_columns: AHashSet<SmolStr> = AHashSet::new();
     let mut timeout = get_default_timeout();
     let mut tier = None;
+    let mut is_global = false;
 
     let nullable_primary_key_column_error = Err(SbroadError::Invalid(
         Entity::Column,
@@ -680,7 +681,9 @@ fn parse_create_table(ast: &AbstractSyntaxTree, node: &ParseNode) -> Result<Ddl,
                 ) {
                     let distribution_type_node = ast.nodes.get_node(*distribution_type_id)?;
                     match distribution_type_node.rule {
-                        Rule::Global => {}
+                        Rule::Global => {
+                            is_global = true;
+                        }
                         Rule::Sharding => {
                             let sharding_node = ast.nodes.get_node(*distribution_type_id)?;
                             for sharding_node_child in &sharding_node.children {
@@ -759,6 +762,10 @@ fn parse_create_table(ast: &AbstractSyntaxTree, node: &ParseNode) -> Result<Ddl,
             Some(format_smolstr!("Primary key must be declared.")),
         ));
     }
+    // infer sharding key from primary key
+    if shard_key.is_empty() && !is_global {
+        shard_key = pk_keys.clone();
+    }
     let create_sharded_table = if shard_key.is_empty() {
         if engine_type != SpaceEngineType::Memtx {
             return Err(SbroadError::Unsupported(
@@ -766,6 +773,7 @@ fn parse_create_table(ast: &AbstractSyntaxTree, node: &ParseNode) -> Result<Ddl,
                 Some("global spaces can use only memtx engine".into()),
             ));
         }
+
         Ddl::CreateTable {
             name: table_name,
             format: columns,
diff --git a/sbroad-core/src/frontend/sql/ir/tests/ddl.rs b/sbroad-core/src/frontend/sql/ir/tests/ddl.rs
index 394dd07566..c0c58a7d70 100644
--- a/sbroad-core/src/frontend/sql/ir/tests/ddl.rs
+++ b/sbroad-core/src/frontend/sql/ir/tests/ddl.rs
@@ -1,6 +1,6 @@
 use crate::frontend::Ast;
 use pretty_assertions::assert_eq;
-use smol_str::SmolStr;
+use smol_str::{SmolStr, ToSmolStr};
 
 use crate::{
     executor::engine::mock::RouterConfigurationMock,
@@ -91,10 +91,28 @@ fn infer_not_null_on_pk3() {
     let metadata = &RouterConfigurationMock::new();
     let err = AbstractSyntaxTree::transform_into_plan(input, metadata).unwrap_err();
 
-    dbg!(&err);
     assert_eq!(
         true,
         err.to_string()
             .contains("Primary key mustn't contain nullable columns.")
     );
 }
+
+#[test]
+fn infer_sk_from_pk() {
+    let input = r#"create table t ("a" int, "b" int, c int, primary key ("a", "b"))"#;
+
+    let metadata = &RouterConfigurationMock::new();
+    let plan = AbstractSyntaxTree::transform_into_plan(input, metadata).unwrap();
+    let top_id = plan.get_top().unwrap();
+    let top_node = plan.get_ddl_node(top_id).unwrap();
+
+    let Ddl::CreateTable { sharding_key, .. } = top_node else {
+        panic!("expected create table")
+    };
+
+    assert_eq!(
+        sharding_key.as_ref().unwrap(),
+        &vec!["a".to_smolstr(), "b".to_smolstr()]
+    );
+}
diff --git a/sbroad-core/src/frontend/sql/query.pest b/sbroad-core/src/frontend/sql/query.pest
index 39b13ccadd..eac66f5284 100644
--- a/sbroad-core/src/frontend/sql/query.pest
+++ b/sbroad-core/src/frontend/sql/query.pest
@@ -60,7 +60,7 @@ DDL = _{ CreateTable | DropTable | CreateIndex | DropIndex
     CreateTable = {
         ^"create" ~ ^"table" ~ NewTable ~
         "(" ~ Columns ~ ("," ~ PrimaryKey)? ~ ")" ~
-        Engine? ~ Distribution ~ TimeoutOption?
+        Engine? ~ Distribution? ~ TimeoutOption?
     }
         NewTable = @{Table}
         Columns = { ColumnDef ~ ("," ~ ColumnDef)* }
-- 
GitLab