diff --git a/sbroad b/sbroad index 0c57585af4fbfe449b0c0fdcdd56a3d84d766169..f5b5236d24a6974eac2efc238e754079c0274aba 160000 --- a/sbroad +++ b/sbroad @@ -1 +1 @@ -Subproject commit 0c57585af4fbfe449b0c0fdcdd56a3d84d766169 +Subproject commit f5b5236d24a6974eac2efc238e754079c0274aba diff --git a/src/sql/pgproto.rs b/src/sql/pgproto.rs index 3d5af4b8bd1d43ee6d93b192f6575ba1ccd9dba5..ce2b72d7b47ba82114ce0bf91f6fe1a0b3923c43 100644 --- a/src/sql/pgproto.rs +++ b/src/sql/pgproto.rs @@ -677,8 +677,10 @@ impl TryFrom<&Node> for CommandTag { | Relational::GroupBy { .. } | Relational::OrderBy { .. } | Relational::Having { .. } + | Relational::Union { .. } | Relational::UnionAll { .. } | Relational::Values { .. } + | Relational::OrderBy { .. } | Relational::ValuesRow { .. } => Ok(CommandTag::Select), }, Node::Expression(_) | Node::Parameter => Err(SbroadError::Invalid( diff --git a/test/conftest.py b/test/conftest.py index ea100309b1b970eebc42259c5a28973ac039d46b..a0a345c48b18ba177afe862d6a5d311ba6a2bff8 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1892,7 +1892,6 @@ class AuditServer: return None def server(queue: Queue, host: str, port: int) -> None: - class SimpleHTTPRequestHandler(BaseHTTPRequestHandler): QUEUE = queue diff --git a/test/int/test_sql.py b/test/int/test_sql.py index b482bdd9472b0be3f2bae966245a9b4231e5def0..e3eb08e43d7a8d406de3e476c798174f50dc51c0 100644 --- a/test/int/test_sql.py +++ b/test/int/test_sql.py @@ -1034,13 +1034,12 @@ def test_union_all_on_global_tbls(cluster: Cluster): def check_multiple_union_all(): data = i1.sql( """ - select * from ( - select a from g - where a = 2 - union all - select d from s - group by d - ) union all + select a from g + where a = 2 + union all + select d from s + group by d + union all select a from g where b = 1 """, @@ -1083,6 +1082,159 @@ def test_union_all_on_global_tbls(cluster: Cluster): Retriable(rps=5, timeout=6).call(check_complex_segment_child) +def test_union_on_global_tbls(cluster: Cluster): + cluster.deploy(instance_count=1) + i1 = cluster.instances[0] + + ddl = i1.sql( + """ + create table g (a int not null, b int not null, primary key (a)) + using memtx + distributed globally + option (timeout = 3) + """ + ) + assert ddl["row_count"] == 1 + + for i, j in [(1, 1), (2, 2), (3, 2)]: + index = i1.cas("insert", "G", [i, j]) + i1.raft_wait_index(index, 3) + + ddl = i1.sql( + """ + create table s (c int not null, d int not null, primary key (c)) + using memtx + distributed by (c) + option (timeout = 3) + """ + ) + assert ddl["row_count"] == 1 + data = i1.sql("""insert into s values (1, 2), (2, 2), (3, 2);""") + assert data["row_count"] == 3 + + expected = [[1], [2]] + + def check_global_vs_any(): + data = i1.sql( + """ + select b from g + union + select d from s + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_global_vs_any) + + def check_any_vs_global(): + data = i1.sql( + """ + select d from s + union + select b from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_any_vs_global) + + def check_global_vs_global(): + data = i1.sql( + """ + select b from g + union + select a from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [ + [1], + [2], + [3], + ] + + Retriable(rps=5, timeout=6).call(check_global_vs_global) + + expected = [[1], [2], [3]] + + def check_global_vs_segment(): + data = i1.sql( + """ + select a from g + union + select c from s + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_global_vs_segment) + + def check_segment_vs_global(): + data = i1.sql( + """ + select c from s + union + select a from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_segment_vs_global) + + expected = [[1], [2], [3]] + + def check_single_vs_global(): + data = i1.sql( + """ + select sum(c) - 3 from s + union + select a from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_single_vs_global) + + def check_global_vs_single(): + data = i1.sql( + """ + select a from g + union + select sum(c) - 3 from s + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_global_vs_single) + + def check_multiple_union(): + data = i1.sql( + """ + select a from g + where a = 2 + union + select d from s + group by d + union + select a from g + where b = 1 + except + select null from g + where false + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[1], [2]] + + Retriable(rps=5, timeout=6).call(check_multiple_union) + + def test_trim(instance: Instance): instance.sql( """