From bfff1b4f5248dfb5c7944c6c3859f14667d3994a Mon Sep 17 00:00:00 2001 From: Arseniy Volynets <vol0ncar@yandex.ru> Date: Mon, 18 Dec 2023 14:04:37 +0300 Subject: [PATCH] feat: support except for global tbls - update sbroad submodule to commit with support of except for global tables - add integration tests for except with global tbls --- sbroad | 2 +- src/sql.rs | 2 +- src/sql/pgproto.rs | 1 + src/sql/router.rs | 30 +++++---- test/int/test_sql.py | 143 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 163 insertions(+), 15 deletions(-) diff --git a/sbroad b/sbroad index 4cb04bfc0b..73576c4bd9 160000 --- a/sbroad +++ b/sbroad @@ -1 +1 @@ -Subproject commit 4cb04bfc0b957aee2d1529db276e3ead758ffe62 +Subproject commit 73576c4bd9dd047016fbcf287e83452aad8b3416 diff --git a/src/sql.rs b/src/sql.rs index 4781c09f5d..a271239772 100644 --- a/src/sql.rs +++ b/src/sql.rs @@ -75,7 +75,7 @@ fn check_table_privileges(plan: &IrPlan) -> traft::Result<()> { false }; let mut plan_traversal = PostOrderWithFilter::with_capacity( - |node| plan.subtree_iter(node), + |node| plan.subtree_iter(node, false), REL_CAPACITY, Box::new(filter), ); diff --git a/src/sql/pgproto.rs b/src/sql/pgproto.rs index 6a8b88b6ec..32d88e81c7 100644 --- a/src/sql/pgproto.rs +++ b/src/sql/pgproto.rs @@ -308,6 +308,7 @@ impl TryFrom<&Node> for CommandTag { | Relational::Projection { .. } | Relational::ScanRelation { .. } | Relational::ScanSubQuery { .. } + | Relational::Intersect { .. } | Relational::Selection { .. } | Relational::GroupBy { .. } | Relational::Having { .. } diff --git a/src/sql/router.rs b/src/sql/router.rs index 4fb164eabe..1c447e0de0 100644 --- a/src/sql/router.rs +++ b/src/sql/router.rs @@ -551,30 +551,34 @@ impl Metadata for RouterMetadata { columns.push(column); } + let normalized_name = normalize_name_from_sql(table_name); + let pk_cols = space_pk_columns(&name, &columns)?; + let pk_cols_str: &[&str] = &pk_cols.iter().map(String::as_str).collect::<Vec<_>>(); + // Try to find the sharding columns of the space in "_pico_table". // If nothing found then the space is local and we can't query it with // distributed SQL. let is_system_table = ClusterwideTable::values() .iter() .any(|sys_name| *sys_name == name.as_str()); - let shard_key_cols: Vec<String> = if is_system_table { - vec![] - } else { - Self::get_shard_cols(&name, &meta)? - }; - let sharding_key_arg: &[&str] = &shard_key_cols + + if is_system_table { + return Table::new_system(&normalized_name, columns, pk_cols_str); + } + let sharded_columns = Self::get_shard_cols(&name, &meta)?; + if sharded_columns.is_empty() { + return Table::new_global(&normalized_name, columns, pk_cols_str); + } + let sharding_columns_str: &[&str] = &sharded_columns .iter() .map(String::as_str) .collect::<Vec<_>>(); - let pk_cols = space_pk_columns(&name, &columns)?; - let pk_arg = &pk_cols.iter().map(String::as_str).collect::<Vec<_>>(); - Table::new( - &normalize_name_from_sql(table_name), + Table::new_sharded( + &normalized_name, columns, - sharding_key_arg, - pk_arg, + sharding_columns_str, + pk_cols_str, engine.into(), - is_system_table, ) } diff --git a/test/int/test_sql.py b/test/int/test_sql.py index 26dd30464e..38fe0db1bd 100644 --- a/test/int/test_sql.py +++ b/test/int/test_sql.py @@ -692,6 +692,149 @@ def test_union_all_on_global_tbls(cluster: Cluster): Retriable(rps=5, timeout=6).call(check_complex_segment_child) +def test_except_on_global_tbls(cluster: Cluster): + cluster.deploy(instance_count=1) + i1 = cluster.instances[0] + + ddl = i1.sql( + """ + create table g (a int not null, b int not null, primary key (a)) + using memtx + distributed globally + option (timeout = 3) + """ + ) + assert ddl["row_count"] == 1 + for i in range(1, 6): + index = i1.cas("insert", "G", [i, i]) + i1.raft_wait_index(index, 3) + + ddl = i1.sql( + """ + create table s (c int not null, d int not null, primary key (c)) + using memtx + distributed by (c) + option (timeout = 3) + """ + ) + assert ddl["row_count"] == 1 + + data = i1.sql("""insert into s values (3, 2), (4, 3), (5, 4), (6, 5), (7, 6);""") + assert data["row_count"] == 5 + + def check_global_vs_global(): + data = i1.sql( + """ + select a from g + except + select a - 1 from g + """, + timeout=2, + ) + assert data["rows"] == [[5]] + + Retriable(rps=5, timeout=6).call(check_global_vs_global) + + def check_global_vs_segment(): + data = i1.sql( + """ + select a from g + except + select c from s + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[1], [2]] + + Retriable(rps=5, timeout=6).call(check_global_vs_segment) + + def check_global_vs_any(): + data = i1.sql( + """ + select b from g + except + select d from s + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[1]] + + Retriable(rps=5, timeout=6).call(check_global_vs_any) + + def check_global_vs_single(): + data = i1.sql( + """ + select b from g + where b = 1 or b = 2 + except + select sum(d) from s + where d = 3 + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[1], [2]] + + Retriable(rps=5, timeout=6).call(check_global_vs_single) + + def check_single_vs_global(): + data = i1.sql( + """ + select sum(d) from s + where d = 3 or d = 2 + except + select b from g + where b = 1 or b = 2 + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[5]] + + Retriable(rps=5, timeout=6).call(check_single_vs_global) + + def check_segment_vs_global(): + data = i1.sql( + """ + select c from s + except + select a from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[6], [7]] + + Retriable(rps=5, timeout=6).call(check_segment_vs_global) + + def check_any_vs_global(): + data = i1.sql( + """ + select d from s + except + select b from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[6]] + + Retriable(rps=5, timeout=6).call(check_any_vs_global) + + def check_multiple_excepts(): + data = i1.sql( + """ + select a + 5 from g + where a = 1 or a = 2 + except select * from ( + select d from s + except + select b from g + ) + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[7]] + + Retriable(rps=5, timeout=6).call(check_multiple_excepts) + + def test_hash(cluster: Cluster): cluster.deploy(instance_count=1) i1 = cluster.instances[0] -- GitLab