From c6da5b70d1c8a21e769d5ae0ad970db04ee82dc3 Mon Sep 17 00:00:00 2001
From: Arseniy Volynets <vol0ncar@yandex.ru>
Date: Fri, 24 Nov 2023 16:15:10 +0300
Subject: [PATCH] feat: suport join for global tables

---
 sbroad               |   2 +-
 test/int/test_sql.py | 195 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 196 insertions(+), 1 deletion(-)

diff --git a/sbroad b/sbroad
index e2b6207d15..46accea5f6 160000
--- a/sbroad
+++ b/sbroad
@@ -1 +1 @@
-Subproject commit e2b6207d15aa8ec6eacbdeb25b5c53dde000a1e1
+Subproject commit 46accea5f6bb77d77e8b0963d8d6f9a0b54b3a6d
diff --git a/test/int/test_sql.py b/test/int/test_sql.py
index a9e53a18ee..07e08fe5b3 100644
--- a/test/int/test_sql.py
+++ b/test/int/test_sql.py
@@ -315,6 +315,201 @@ def test_aggregates_on_global_tbl(cluster: Cluster):
     assert data["rows"] == [[1, 11]]
 
 
+def test_join_with_global_tbls(cluster: Cluster):
+    cluster.deploy(instance_count=1)
+    i1 = cluster.instances[0]
+
+    ddl = i1.sql(
+        """
+        create table g (a int not null, b int not null, primary key (a))
+        using memtx
+        distributed globally
+        option (timeout = 3)
+        """
+    )
+    assert ddl["row_count"] == 1
+
+    for i, j in [(1, 1), (2, 1), (3, 3)]:
+        index = i1.cas("insert", "G", [i, j])
+        i1.raft_wait_index(index, 3)
+
+    ddl = i1.sql(
+        """
+        create table s (c int not null, primary key (c))
+        using memtx
+        distributed by (c)
+        option (timeout = 3)
+        """
+    )
+    assert ddl["row_count"] == 1
+    data = i1.sql("""insert into s values (1), (2), (3), (4), (5);""")
+    assert data["row_count"] == 5
+
+    expected_rows = [[1], [1], [3]]
+
+    def check_inner_join_global_vs_segment():
+        data = i1.sql(
+            """
+            select b from g
+            join s on g.a = s.c
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda e: e[0]) == expected_rows
+
+    Retriable(rps=5, timeout=6).call(check_inner_join_global_vs_segment)
+
+    def check_inner_join_segment_vs_global():
+        data = i1.sql(
+            """
+            select b from s
+            join g on g.a = s.c
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda e: e[0]) == expected_rows
+
+    Retriable(rps=5, timeout=6).call(check_inner_join_segment_vs_global)
+
+    def check_inner_join_segment_vs_global_sq_in_cond():
+        data = i1.sql(
+            """
+            select c from s
+            join g on 1 = 1 and
+            c in (select a*a from g)
+            group by c
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda e: e[0]) == [[1], [4]]
+
+    Retriable(rps=5, timeout=6).call(check_inner_join_segment_vs_global_sq_in_cond)
+
+    def check_left_join_segment_vs_global_sq_in_cond():
+        data = i1.sql(
+            """
+            select c, cast(sum(a) as int) from s
+            left join g on 1 = 1 and
+            c in (select a*a from g)
+            where c < 4
+            group by c
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda e: e[0]) == [
+            [1, 6],
+            [2, None],
+            [3, None],
+        ]
+
+    Retriable(rps=5, timeout=6).call(check_left_join_segment_vs_global_sq_in_cond)
+
+    def check_left_join_any_vs_global():
+        data = i1.sql(
+            """
+            select c, b from
+            (select c*c as c from s)
+            left join g on c = b
+            where c < 5
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda e: e[0]) == [[1, 1], [1, 1], [4, None]]
+
+    Retriable(rps=5, timeout=6).call(check_left_join_any_vs_global)
+
+    def check_inner_join_any_vs_global():
+        data = i1.sql(
+            """
+            select c, b from
+            (select c*c as c from s)
+            inner join g on c = b
+            where c < 5
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda e: e[0]) == [[1, 1], [1, 1]]
+
+    Retriable(rps=5, timeout=6).call(check_inner_join_any_vs_global)
+
+    def check_left_join_single_vs_global():
+        data = i1.sql(
+            """
+            select c, a from
+            (select count(*) as c from s)
+            left join g on c = a + 2
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda e: e[0]) == [[5, 3]]
+
+    Retriable(rps=5, timeout=6).call(check_left_join_single_vs_global)
+
+    def check_left_join_global_with_expr_in_proj_vs_segment():
+        data = i1.sql(
+            """
+            select b, c from (select b + 3 as b from g)
+            left join s on b = c
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda e: e[0]) == [[4, 4], [4, 4], [6, None]]
+
+    Retriable(rps=5, timeout=6).call(
+        check_left_join_global_with_expr_in_proj_vs_segment
+    )
+
+    def check_left_join_global_vs_any_false_condition():
+        data = i1.sql(
+            """
+            select b, c from g
+            left join
+            (select c*c as c from s where c > 3)
+            on b = c
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda e: e[0]) == [
+            [1, None],
+            [1, None],
+            [3, None],
+        ]
+
+    Retriable(rps=5, timeout=6).call(check_left_join_global_vs_any_false_condition)
+
+    def check_left_join_global_vs_any():
+        data = i1.sql(
+            """
+            select b, c from g
+            left join
+            (select c*c as c from s where c < 3)
+            on b = c
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda e: e[0]) == [[1, 1], [1, 1], [3, None]]
+
+    Retriable(rps=5, timeout=6).call(check_left_join_global_vs_any)
+
+    def check_left_join_complex_global_child_vs_any():
+        data = i1.sql(
+            """
+            select a, b, c from (
+                select a, b from g
+                inner join (select a + 2 as u from g)
+                on a = u
+            )
+            left join
+            (select c + 1 as c from s where c = 2)
+            on b = c
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda e: e[0]) == [[3, 3, 3]]
+
+    Retriable(rps=5, timeout=6).call(check_left_join_complex_global_child_vs_any)
+
+
 def test_hash(cluster: Cluster):
     cluster.deploy(instance_count=1)
     i1 = cluster.instances[0]
-- 
GitLab