From cb251e1e152e8361577b661590049e604c06583b Mon Sep 17 00:00:00 2001 From: Arseniy Volynets <vol0ncar@yandex.ru> Date: Mon, 4 Dec 2023 01:07:45 +0300 Subject: [PATCH] feat: support union all with global tbls - update sbroad submodule to commit supporting union all with global tbls - add integration tests --- sbroad | 2 +- src/sql/router.rs | 4 +- test/int/test_sql.py | 196 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 192 insertions(+), 10 deletions(-) diff --git a/sbroad b/sbroad index 46accea5f6..4cb04bfc0b 160000 --- a/sbroad +++ b/sbroad @@ -1 +1 @@ -Subproject commit 46accea5f6bb77d77e8b0963d8d6f9a0b54b3a6d +Subproject commit 4cb04bfc0b957aee2d1529db276e3ead758ffe62 diff --git a/src/sql/router.rs b/src/sql/router.rs index 8e5e6af4da..4fb164eabe 100644 --- a/src/sql/router.rs +++ b/src/sql/router.rs @@ -7,7 +7,7 @@ use sbroad::executor::bucket::Buckets; use sbroad::executor::engine::helpers::vshard::{ exec_ir_on_all_buckets, exec_ir_on_some_buckets, get_random_bucket, }; -use sbroad::executor::engine::helpers::{dispatch, explain_format, materialize_motion}; +use sbroad::executor::engine::helpers::{dispatch_impl, explain_format, materialize_motion}; use sbroad::executor::engine::helpers::{sharding_key_from_map, sharding_key_from_tuple}; use sbroad::executor::engine::{QueryCache, Router, Vshard}; use sbroad::executor::ir::{ConnectionType, ExecutionPlan, QueryType}; @@ -235,7 +235,7 @@ impl Router for RouterRuntime { top_id: usize, buckets: &sbroad::executor::bucket::Buckets, ) -> Result<Box<dyn std::any::Any>, SbroadError> { - dispatch(self, plan, top_id, buckets) + dispatch_impl(self, plan, top_id, buckets) } fn explain_format(&self, explain: String) -> Result<Box<dyn std::any::Any>, SbroadError> { diff --git a/test/int/test_sql.py b/test/int/test_sql.py index 07e08fe5b3..26dd30464e 100644 --- a/test/int/test_sql.py +++ b/test/int/test_sql.py @@ -1,13 +1,7 @@ import pytest import re -from conftest import ( - Cluster, - KeyDef, - KeyPart, - Retriable, - ReturnError, -) +from conftest import Cluster, KeyDef, KeyPart, ReturnError, Retriable def test_pico_sql(cluster: Cluster): @@ -510,6 +504,194 @@ def test_join_with_global_tbls(cluster: Cluster): Retriable(rps=5, timeout=6).call(check_left_join_complex_global_child_vs_any) +def test_union_all_on_global_tbls(cluster: Cluster): + cluster.deploy(instance_count=1) + i1 = cluster.instances[0] + + ddl = i1.sql( + """ + create table g (a int not null, b int not null, primary key (a)) + using memtx + distributed globally + option (timeout = 3) + """ + ) + assert ddl["row_count"] == 1 + + for i, j in [(1, 1), (2, 2), (3, 2)]: + index = i1.cas("insert", "G", [i, j]) + i1.raft_wait_index(index, 3) + + ddl = i1.sql( + """ + create table s (c int not null, d int not null, primary key (c)) + using memtx + distributed by (c) + option (timeout = 3) + """ + ) + assert ddl["row_count"] == 1 + data = i1.sql("""insert into s values (1, 2), (2, 2), (3, 2);""") + assert data["row_count"] == 3 + + expected = [[1], [2], [2], [2], [2], [2]] + + def check_global_vs_any(): + data = i1.sql( + """ + select b from g + union all + select d from s + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_global_vs_any) + + def check_any_vs_global(): + data = i1.sql( + """ + select d from s + union all + select b from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_any_vs_global) + + def check_global_vs_global(): + data = i1.sql( + """ + select b from g + union all + select a from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [ + [1], + [1], + [2], + [2], + [2], + [3], + ] + + Retriable(rps=5, timeout=6).call(check_global_vs_global) + + expected = [[1], [1], [2], [2], [3], [3]] + + def check_global_vs_segment(): + data = i1.sql( + """ + select a from g + union all + select c from s + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_global_vs_segment) + + def check_segment_vs_global(): + data = i1.sql( + """ + select c from s + union all + select a from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_segment_vs_global) + + expected = [[1], [2], [3], [6]] + + def check_single_vs_global(): + data = i1.sql( + """ + select sum(c) from s + union all + select a from g + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_single_vs_global) + + def check_global_vs_single(): + data = i1.sql( + """ + select a from g + union all + select sum(c) from s + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == expected + + Retriable(rps=5, timeout=6).call(check_global_vs_single) + + # some arbitrary queries + + def check_multiple_union_all(): + data = i1.sql( + """ + select * from ( + select a from g + where a = 2 + union all + select d from s + group by d + ) union all + select a from g + where b = 1 + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[1], [2], [2]] + + Retriable(rps=5, timeout=6).call(check_multiple_union_all) + + def check_union_with_where(): + data = i1.sql( + """ + select a from g + where a in (select d from s) + union all + select c from s + where c = 3 + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[2], [3]] + + Retriable(rps=5, timeout=6).call(check_union_with_where) + + def check_complex_segment_child(): + data = i1.sql( + """ + select a, b from g + where a in (select d from s) + union all + select d, sum(u) from s + inner join (select c as u from s) + on d = u or u = 1 + group by d + """, + timeout=2, + ) + assert sorted(data["rows"], key=lambda x: x[0]) == [[2, 2], [2, 9]] + + Retriable(rps=5, timeout=6).call(check_complex_segment_child) + + def test_hash(cluster: Cluster): cluster.deploy(instance_count=1) i1 = cluster.instances[0] -- GitLab