diff --git a/sbroad b/sbroad
index 0c57585af4fbfe449b0c0fdcdd56a3d84d766169..f5b5236d24a6974eac2efc238e754079c0274aba 160000
--- a/sbroad
+++ b/sbroad
@@ -1 +1 @@
-Subproject commit 0c57585af4fbfe449b0c0fdcdd56a3d84d766169
+Subproject commit f5b5236d24a6974eac2efc238e754079c0274aba
diff --git a/src/sql/pgproto.rs b/src/sql/pgproto.rs
index 3d5af4b8bd1d43ee6d93b192f6575ba1ccd9dba5..ce2b72d7b47ba82114ce0bf91f6fe1a0b3923c43 100644
--- a/src/sql/pgproto.rs
+++ b/src/sql/pgproto.rs
@@ -677,8 +677,10 @@ impl TryFrom<&Node> for CommandTag {
                 | Relational::GroupBy { .. }
                 | Relational::OrderBy { .. }
                 | Relational::Having { .. }
+                | Relational::Union { .. }
                 | Relational::UnionAll { .. }
                 | Relational::Values { .. }
+                | Relational::OrderBy { .. }
                 | Relational::ValuesRow { .. } => Ok(CommandTag::Select),
             },
             Node::Expression(_) | Node::Parameter => Err(SbroadError::Invalid(
diff --git a/test/conftest.py b/test/conftest.py
index ea100309b1b970eebc42259c5a28973ac039d46b..a0a345c48b18ba177afe862d6a5d311ba6a2bff8 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1892,7 +1892,6 @@ class AuditServer:
             return None
 
         def server(queue: Queue, host: str, port: int) -> None:
-
             class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
                 QUEUE = queue
 
diff --git a/test/int/test_sql.py b/test/int/test_sql.py
index b482bdd9472b0be3f2bae966245a9b4231e5def0..e3eb08e43d7a8d406de3e476c798174f50dc51c0 100644
--- a/test/int/test_sql.py
+++ b/test/int/test_sql.py
@@ -1034,13 +1034,12 @@ def test_union_all_on_global_tbls(cluster: Cluster):
     def check_multiple_union_all():
         data = i1.sql(
             """
-            select * from (
-                select a from g
-                where a = 2
-                union all
-                select d from s
-                group by d
-            ) union all
+            select a from g
+            where a = 2
+            union all
+            select d from s
+            group by d
+            union all
             select a from g
             where b = 1
             """,
@@ -1083,6 +1082,159 @@ def test_union_all_on_global_tbls(cluster: Cluster):
     Retriable(rps=5, timeout=6).call(check_complex_segment_child)
 
 
+def test_union_on_global_tbls(cluster: Cluster):
+    cluster.deploy(instance_count=1)
+    i1 = cluster.instances[0]
+
+    ddl = i1.sql(
+        """
+        create table g (a int not null, b int not null, primary key (a))
+        using memtx
+        distributed globally
+        option (timeout = 3)
+        """
+    )
+    assert ddl["row_count"] == 1
+
+    for i, j in [(1, 1), (2, 2), (3, 2)]:
+        index = i1.cas("insert", "G", [i, j])
+        i1.raft_wait_index(index, 3)
+
+    ddl = i1.sql(
+        """
+        create table s (c int not null, d int not null, primary key (c))
+        using memtx
+        distributed by (c)
+        option (timeout = 3)
+    """
+    )
+    assert ddl["row_count"] == 1
+    data = i1.sql("""insert into s values (1, 2), (2, 2), (3, 2);""")
+    assert data["row_count"] == 3
+
+    expected = [[1], [2]]
+
+    def check_global_vs_any():
+        data = i1.sql(
+            """
+            select b from g
+            union
+            select d from s
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_global_vs_any)
+
+    def check_any_vs_global():
+        data = i1.sql(
+            """
+            select d from s
+            union
+            select b from g
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_any_vs_global)
+
+    def check_global_vs_global():
+        data = i1.sql(
+            """
+            select b from g
+            union
+            select a from g
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == [
+            [1],
+            [2],
+            [3],
+        ]
+
+    Retriable(rps=5, timeout=6).call(check_global_vs_global)
+
+    expected = [[1], [2], [3]]
+
+    def check_global_vs_segment():
+        data = i1.sql(
+            """
+            select a from g
+            union
+            select c from s
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_global_vs_segment)
+
+    def check_segment_vs_global():
+        data = i1.sql(
+            """
+            select c from s
+            union
+            select a from g
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_segment_vs_global)
+
+    expected = [[1], [2], [3]]
+
+    def check_single_vs_global():
+        data = i1.sql(
+            """
+            select sum(c) - 3 from s
+            union
+            select a from g
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_single_vs_global)
+
+    def check_global_vs_single():
+        data = i1.sql(
+            """
+            select a from g
+            union
+            select sum(c) - 3 from s
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == expected
+
+    Retriable(rps=5, timeout=6).call(check_global_vs_single)
+
+    def check_multiple_union():
+        data = i1.sql(
+            """
+            select a from g
+            where a = 2
+            union
+            select d from s
+            group by d
+            union
+            select a from g
+            where b = 1
+            except
+            select null from g
+            where false
+            """,
+            timeout=2,
+        )
+        assert sorted(data["rows"], key=lambda x: x[0]) == [[1], [2]]
+
+    Retriable(rps=5, timeout=6).call(check_multiple_union)
+
+
 def test_trim(instance: Instance):
     instance.sql(
         """