From 36b54a598c9891ad5ac3ebe6aa1f9547be457f5a Mon Sep 17 00:00:00 2001
From: Arseniy Volynets <vol0ncar@yandex.ru>
Date: Sun, 21 Apr 2024 19:08:07 +0000
Subject: [PATCH] feat(sql): support UNION

- update sbroad submodule to commit
with support of UNION operator
- add integration tests for global
tables (sharded tables were tested
in cartridge tests)
- Usage:

`select a from t union select b from t2`
---
 sbroad               |   2 +-
 src/sql/pgproto.rs   |   2 +
 test/int/test_sql.py | 168 +++++++++++++++++++++++++++++++++++++++++--
 3 files changed, 164 insertions(+), 8 deletions(-)

diff --git a/sbroad b/sbroad
index 0c57585af4..f5b5236d24 160000
--- a/sbroad
+++ b/sbroad
@@ -1 +1 @@
-Subproject commit 0c57585af4fbfe449b0c0fdcdd56a3d84d766169
+Subproject commit f5b5236d24a6974eac2efc238e754079c0274aba
diff --git a/src/sql/pgproto.rs b/src/sql/pgproto.rs
index 3d5af4b8bd..ce2b72d7b4 100644
--- a/src/sql/pgproto.rs
+++ b/src/sql/pgproto.rs
@@ -677,8 +677,10 @@ impl TryFrom<&Node> for CommandTag {
                 | Relational::GroupBy { .. }
                 | Relational::OrderBy { .. }
                 | Relational::Having { .. }
+                | Relational::Union { .. }
                 | Relational::UnionAll { .. }
                 | Relational::Values { .. }
+                | Relational::OrderBy { .. }
                 | Relational::ValuesRow { .. } => Ok(CommandTag::Select),
             },
             Node::Expression(_) | Node::Parameter => Err(SbroadError::Invalid(
diff --git a/test/int/test_sql.py b/test/int/test_sql.py
index b482bdd947..87d459befa 100644
--- a/test/int/test_sql.py
+++ b/test/int/test_sql.py
@@ -1034,13 +1034,12 @@ def test_union_all_on_global_tbls(cluster: Cluster):
     def check_multiple_union_all():
         data = i1.sql(
             """
-            select * from (
-                select a from g
-                where a = 2
-                union all
-                select d from s
-                group by d
-            ) union all
+            select a from g
+            where a = 2
+            union all
+            select d from s
+            group by d
+            union all
             select a from g
             where b = 1
             """,
@@ -1083,6 +1082,161 @@ def test_union_all_on_global_tbls(cluster: Cluster):
     Retriable(rps=5, timeout=6).call(check_complex_segment_child)
 
 
+def test_union_on_global_tbls(cluster: Cluster):
+    cluster.deploy(instance_count=1)
+    i1 = cluster.instances[0]
+
+    ddl = i1.sql(
+        """
+        create table g (a int not null, b int not null, primary key (a))
+        using memtx
+        distributed globally
+        option (timeout = 3)
+        """
+    )
+    assert ddl["row_count"] == 1
+
+    for i, j in [(1, 1), (2, 2), (3, 2)]:
+        index = i1.cas("insert", "G", [i, j])
+        i1.raft_wait_index(index, 3)
+
+    ddl = i1.sql(
+        """
+        create table s (c int not null, d int not null, primary key (c))
+        using memtx
+        distributed by (c)
+        option (timeout = 3)
+    """
+    )
+    assert ddl["row_count"] == 1
+    data = i1.sql("""insert into s values (1, 2), (2, 2), (3, 2);""")
+    assert data["row_count"] == 3
+
+    expected = [[1], [2]]
+
+    def check_global_vs_any():
+        data = i1.sql(
+            """
+            select b from g
+            union
+            select d from s
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_global_vs_any)
+
+    def check_any_vs_global():
+        data = i1.sql(
+            """
+            select d from s
+            union
+            select b from g
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_any_vs_global)
+
+    def check_global_vs_global():
+        data = i1.sql(
+            """
+            select b from g
+            union
+            select a from g
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == [
+            [1],
+            [2],
+            [3],
+        ]
+
+    Retriable(rps=5, timeout=6).call(check_global_vs_global)
+
+    expected = [[1], [2], [3]]
+
+    def check_global_vs_segment():
+        data = i1.sql(
+            """
+            select a from g
+            union
+            select c from s
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_global_vs_segment)
+
+    def check_segment_vs_global():
+        data = i1.sql(
+            """
+            select c from s
+            union
+            select a from g
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_segment_vs_global)
+
+    expected = [[1], [2], [3]]
+
+    def check_single_vs_global():
+        data = i1.sql(
+            """
+            select sum(c) - 3 from s
+            union
+            select a from g
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_single_vs_global)
+
+    def check_global_vs_single():
+        data = i1.sql(
+            """
+            select a from g
+            union
+            select sum(c) - 3 from s
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_global_vs_single)
+
+    def check_multiple_union():
+        data = i1.sql(
+            """
+            select a from g
+            where a = 2
+            union
+            select d from s
+            group by d
+            union
+            select a from g
+            where b = 1
+            except
+            select null from g
+            where false
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == [[1], [2]]
+
+    Retriable(rps=5, timeout=6).call(check_multiple_union)
+
+def test_select_chaining():
+
+
 def test_trim(instance: Instance):
     instance.sql(
         """
-- 
GitLab