From cb251e1e152e8361577b661590049e604c06583b Mon Sep 17 00:00:00 2001
From: Arseniy Volynets <vol0ncar@yandex.ru>
Date: Mon, 4 Dec 2023 01:07:45 +0300
Subject: [PATCH] feat: support union all with global tbls

- update the sbroad submodule to the commit that adds
  support for union all with global tables
- add integration tests
---
 sbroad               |   2 +-
 src/sql/router.rs    |   4 +-
 test/int/test_sql.py | 196 +++++++++++++++++++++++++++++++++++++++++--
 3 files changed, 192 insertions(+), 10 deletions(-)

diff --git a/sbroad b/sbroad
index 46accea5f6..4cb04bfc0b 160000
--- a/sbroad
+++ b/sbroad
@@ -1 +1 @@
-Subproject commit 46accea5f6bb77d77e8b0963d8d6f9a0b54b3a6d
+Subproject commit 4cb04bfc0b957aee2d1529db276e3ead758ffe62
diff --git a/src/sql/router.rs b/src/sql/router.rs
index 8e5e6af4da..4fb164eabe 100644
--- a/src/sql/router.rs
+++ b/src/sql/router.rs
@@ -7,7 +7,7 @@ use sbroad::executor::bucket::Buckets;
 use sbroad::executor::engine::helpers::vshard::{
     exec_ir_on_all_buckets, exec_ir_on_some_buckets, get_random_bucket,
 };
-use sbroad::executor::engine::helpers::{dispatch, explain_format, materialize_motion};
+use sbroad::executor::engine::helpers::{dispatch_impl, explain_format, materialize_motion};
 use sbroad::executor::engine::helpers::{sharding_key_from_map, sharding_key_from_tuple};
 use sbroad::executor::engine::{QueryCache, Router, Vshard};
 use sbroad::executor::ir::{ConnectionType, ExecutionPlan, QueryType};
@@ -235,7 +235,7 @@ impl Router for RouterRuntime {
         top_id: usize,
         buckets: &sbroad::executor::bucket::Buckets,
     ) -> Result<Box<dyn std::any::Any>, SbroadError> {
-        dispatch(self, plan, top_id, buckets)
+        dispatch_impl(self, plan, top_id, buckets)
     }
 
     fn explain_format(&self, explain: String) -> Result<Box<dyn std::any::Any>, SbroadError> {
diff --git a/test/int/test_sql.py b/test/int/test_sql.py
index 07e08fe5b3..26dd30464e 100644
--- a/test/int/test_sql.py
+++ b/test/int/test_sql.py
@@ -1,13 +1,7 @@
 import pytest
 import re
 
-from conftest import (
-    Cluster,
-    KeyDef,
-    KeyPart,
-    Retriable,
-    ReturnError,
-)
+from conftest import Cluster, KeyDef, KeyPart, ReturnError, Retriable
 
 
 def test_pico_sql(cluster: Cluster):
@@ -510,6 +504,194 @@ def test_join_with_global_tbls(cluster: Cluster):
     Retriable(rps=5, timeout=6).call(check_left_join_complex_global_child_vs_any)
 
 
+def test_union_all_on_global_tbls(cluster: Cluster):  # union all with a global table
+    cluster.deploy(instance_count=1)
+    i1 = cluster.instances[0]
+
+    ddl = i1.sql(  # global table g(a, b)
+        """
+        create table g (a int not null, b int not null, primary key (a))
+        using memtx
+        distributed globally
+        option (timeout = 3)
+        """
+    )
+    assert ddl["row_count"] == 1
+
+    for i, j in [(1, 1), (2, 2), (3, 2)]:  # rows (a, b) of g
+        index = i1.cas("insert", "G", [i, j])
+        i1.raft_wait_index(index, 3)
+
+    ddl = i1.sql(  # table s(c, d), distributed by c
+        """
+        create table s (c int not null, d int not null, primary key (c))
+        using memtx
+        distributed by (c)
+        option (timeout = 3)
+    """
+    )
+    assert ddl["row_count"] == 1
+    data = i1.sql("""insert into s values (1, 2), (2, 2), (3, 2);""")
+    assert data["row_count"] == 3
+
+    expected = [[1], [2], [2], [2], [2], [2]]  # g.b (1, 2, 2) + s.d (2, 2, 2)
+
+    def check_global_vs_any():
+        data = i1.sql(
+            """
+            select b from g
+            union all
+            select d from s
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_global_vs_any)
+
+    def check_any_vs_global():
+        data = i1.sql(
+            """
+            select d from s
+            union all
+            select b from g
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_any_vs_global)
+
+    def check_global_vs_global():
+        data = i1.sql(
+            """
+            select b from g
+            union all
+            select a from g
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == [  # g.b + g.a
+            [1],
+            [1],
+            [2],
+            [2],
+            [2],
+            [3],
+        ]
+
+    Retriable(rps=5, timeout=6).call(check_global_vs_global)
+
+    expected = [[1], [1], [2], [2], [3], [3]]  # g.a (1, 2, 3) + s.c (1, 2, 3)
+
+    def check_global_vs_segment():
+        data = i1.sql(
+            """
+            select a from g
+            union all
+            select c from s
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_global_vs_segment)
+
+    def check_segment_vs_global():
+        data = i1.sql(
+            """
+            select c from s
+            union all
+            select a from g
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_segment_vs_global)
+
+    expected = [[1], [2], [3], [6]]  # g.a (1, 2, 3) + sum(s.c) = 6
+
+    def check_single_vs_global():
+        data = i1.sql(
+            """
+            select sum(c) from s
+            union all
+            select a from g
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_single_vs_global)
+
+    def check_global_vs_single():
+        data = i1.sql(
+            """
+            select a from g
+            union all
+            select sum(c) from s
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_global_vs_single)
+
+    # Some arbitrary queries: nested union all, subqueries, join + group by.
+
+    def check_multiple_union_all():
+        data = i1.sql(
+            """
+            select * from (
+                select a from g
+                where a = 2
+                union all
+                select d from s
+                group by d
+            ) union all
+            select a from g
+            where b = 1
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == [[1], [2], [2]]
+
+    Retriable(rps=5, timeout=6).call(check_multiple_union_all)
+
+    def check_union_with_where():
+        data = i1.sql(
+            """
+            select a from g
+            where a in (select d from s)
+            union all
+            select c from s
+            where c = 3
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == [[2], [3]]
+
+    Retriable(rps=5, timeout=6).call(check_union_with_where)
+
+    def check_complex_segment_child():
+        data = i1.sql(
+            """
+            select a, b from g
+            where a in (select d from s)
+            union all
+            select d, sum(u) from s
+            inner join (select c as u from s)
+            on d = u or u = 1
+            group by d
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == [[2, 2], [2, 9]]
+
+    Retriable(rps=5, timeout=6).call(check_complex_segment_child)
+
+
 def test_hash(cluster: Cluster):
     cluster.deploy(instance_count=1)
     i1 = cluster.instances[0]
-- 
GitLab