From 36b54a598c9891ad5ac3ebe6aa1f9547be457f5a Mon Sep 17 00:00:00 2001 From: Arseniy Volynets <vol0ncar@yandex.ru> Date: Sun, 21 Apr 2024 19:08:07 +0000 Subject: [PATCH] feat(sql): support UNION - update sbroad submodule to commit with support of UNION operator - add integration tests for global tables (sharded tables were tested in cartridge tests) - Usage: `select a from t union select b from t2` --- sbroad | 2 +- src/sql/pgproto.rs | 2 + test/int/test_sql.py | 168 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 164 insertions(+), 8 deletions(-) diff --git a/sbroad b/sbroad index 0c57585af4..f5b5236d24 160000 --- a/sbroad +++ b/sbroad @@ -1 +1 @@ -Subproject commit 0c57585af4fbfe449b0c0fdcdd56a3d84d766169 +Subproject commit f5b5236d24a6974eac2efc238e754079c0274aba diff --git a/src/sql/pgproto.rs b/src/sql/pgproto.rs index 3d5af4b8bd..ce2b72d7b4 100644 --- a/src/sql/pgproto.rs +++ b/src/sql/pgproto.rs @@ -677,8 +677,10 @@ impl TryFrom<&Node> for CommandTag { | Relational::GroupBy { .. } | Relational::OrderBy { .. } | Relational::Having { .. } + | Relational::Union { .. } | Relational::UnionAll { .. } | Relational::Values { .. } + | Relational::OrderBy { .. } | Relational::ValuesRow { .. } => Ok(CommandTag::Select), }, Node::Expression(_) | Node::Parameter => Err(SbroadError::Invalid( diff --git a/test/int/test_sql.py b/test/int/test_sql.py index b482bdd947..87d459befa 100644 --- a/test/int/test_sql.py +++ b/test/int/test_sql.py @@ -1034,13 +1034,12 @@ def test_union_all_on_global_tbls(cluster: Cluster): def check_multiple_union_all(): data = i1.sql( """ - select * from ( - select a from g - where a = 2 - union all - select d from s - group by d - ) union all + select a from g + where a = 2 + union all + select d from s + group by d + union all select a from g where b = 1 """, @@ -1083,6 +1082,161 @@ def test_union_all_on_global_tbls(cluster: Cluster): Retriable(rps=5, timeout=6).call(check_complex_segment_child) +def test_union_on_global_tbls(cluster: Cluster): + cluster.deploy(instance_count=1) + i1 = cluster.instances[0] + + ddl = i1.sql( + """ + create table g (a int not null, b int not null, primary key (a)) + using memtx + distributed globally + option (timeout = 3) + """ + ) + assert ddl["row_count"] == 1 + + for i, j in [(1, 1), (2, 2), (3, 2)]: + index = i1.cas("insert", "G", [i, j]) + i1.raft_wait_index(index, 3) + + ddl = i1.sql( + """ + create table s (c int not null, d int not null, primary key (c)) + using memtx + distributed by (c) + option (timeout = 3) + """ + ) + assert ddl["row_count"] == 1 + data = i1.sql("""insert into s values (1, 2), (2, 2), (3, 2);""") + assert data["row_count"] == 3 + + expected = [[1], [2]] + + def check_global_vs_any(): + data = i1.sql( + """ + select b from g + union + select d from s + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_global_vs_any) + + def check_any_vs_global(): + data = i1.sql( + """ + select d from s + union + select b from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_any_vs_global) + + def check_global_vs_global(): + data = i1.sql( + """ + select b from g + union + select a from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [ + [1], + [2], + [3], + ] + + Retriable(rps=5, timeout=6).call(check_global_vs_global) + + expected = [[1], [2], [3]] + + def check_global_vs_segment(): + data = i1.sql( + """ + select a from g + union + select c from s + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_global_vs_segment) + + def check_segment_vs_global(): + data = i1.sql( + """ + select c from s + union + select a from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_segment_vs_global) + + expected = [[1], [2], [3]] + + def check_single_vs_global(): + data = i1.sql( + """ + select sum(c) - 3 from s + union + select a from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_single_vs_global) + + def check_global_vs_single(): + data = i1.sql( + """ + select a from g + union + select sum(c) - 3 from s + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_global_vs_single) + + def check_multiple_union(): + data = i1.sql( + """ + select a from g + where a = 2 + union + select d from s + group by d + union + select a from g + where b = 1 + except + select null from g + where false + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[1], [2]] + + Retriable(rps=5, timeout=6).call(check_multiple_union) + +def test_select_chaining(): + + def test_trim(instance: Instance): instance.sql( """ -- GitLab