From 7857a936f977e93bf1c4ffc1908b652c2c941647 Mon Sep 17 00:00:00 2001
From: Arseniy Volynets <vol0ncar@yandex.ru>
Date: Tue, 31 Jan 2023 15:20:40 +0300
Subject: [PATCH] fix: sbroad fails to plan query with nested sq

When we gathered SQs we traversed only a subset of relational
nodes in `rel_iter`, because in `rel_iter` we didn't go into
children that are located in `filter` or `condition`. Now we
traverse all relational nodes in the plan.
---
 sbroad-core/src/frontend/sql/ir.rs       | 12 +++++-----
 sbroad-core/src/frontend/sql/ir/tests.rs | 28 ++++++++++++++++++++++++
 sbroad-core/src/ir.rs                    |  5 +++++
 3 files changed, 38 insertions(+), 7 deletions(-)

diff --git a/sbroad-core/src/frontend/sql/ir.rs b/sbroad-core/src/frontend/sql/ir.rs
index 3da97977b4..022e1717f5 100644
--- a/sbroad-core/src/frontend/sql/ir.rs
+++ b/sbroad-core/src/frontend/sql/ir.rs
@@ -155,12 +155,10 @@ impl SubQuery {
 impl Plan {
     fn gather_sq_for_replacement(&self) -> Result<HashSet<SubQuery, RepeatableState>, SbroadError> {
         let mut set: HashSet<SubQuery, RepeatableState> = HashSet::with_hasher(RepeatableState);
-        let top = self.get_top()?;
-        let mut rel_post = PostOrder::with_capacity(|node| self.nodes.rel_iter(node), REL_CAPACITY);
         // Traverse expression trees of the selection and join nodes.
         // Gather all sub-queries in the boolean expressions there.
-        for (_, rel_id) in rel_post.iter(top) {
-            match self.get_node(rel_id)? {
+        for (id, node) in self.nodes.iter().enumerate() {
+            match node {
                 Node::Relational(
                     Relational::Selection { filter: tree, .. }
                     | Relational::InnerJoin {
@@ -172,16 +170,16 @@ impl Plan {
                         |node| self.nodes.expr_iter(node, false),
                         capacity,
                     );
-                    for (_, id) in expr_post.iter(*tree) {
+                    for (_, op_id) in expr_post.iter(*tree) {
                         if let Node::Expression(Expression::Bool { left, right, .. }) =
-                            self.get_node(id)?
+                            self.get_node(op_id)?
                         {
                             let children = &[*left, *right];
                             for child in children {
                                 if let Node::Relational(Relational::ScanSubQuery { .. }) =
                                     self.get_node(*child)?
                                 {
-                                    set.insert(SubQuery::new(rel_id, id, *child));
+                                    set.insert(SubQuery::new(id, op_id, *child));
                                 }
                             }
                         }
diff --git a/sbroad-core/src/frontend/sql/ir/tests.rs b/sbroad-core/src/frontend/sql/ir/tests.rs
index 3009a16575..3c22ff3909 100644
--- a/sbroad-core/src/frontend/sql/ir/tests.rs
+++ b/sbroad-core/src/frontend/sql/ir/tests.rs
@@ -359,5 +359,33 @@ fn front_sql20() {
     assert_eq!(expected_explain, plan.as_explain().unwrap());
 }
 
+#[test]
+fn front_sql_nested_subqueries() {
+    let input = r#"SELECT "a" FROM "t"
+        WHERE "a" in (SELECT "a" FROM "t1" WHERE "a" in (SELECT "b" FROM "t1"))"#;
+
+    let plan = sql_to_optimized_ir(input, vec![]);
+
+    let expected_explain = String::from(
+        r#"projection ("t"."a" -> "a")
+    selection ROW("t"."a") in ROW($1)
+        scan "t"
+subquery $0:
+motion [policy: full, generation: none]
+                            scan
+                                projection ("t1"."b" -> "b")
+                                    scan "t1"
+subquery $1:
+motion [policy: full, generation: none]
+            scan
+                projection ("t1"."a" -> "a")
+                    selection ROW("t1"."a") in ROW($0)
+                        scan "t1"
+"#,
+    );
+
+    assert_eq!(expected_explain, plan.as_explain().unwrap());
+}
+
 #[cfg(test)]
 mod params;
diff --git a/sbroad-core/src/ir.rs b/sbroad-core/src/ir.rs
index 6c92043236..fc44e9081f 100644
--- a/sbroad-core/src/ir.rs
+++ b/sbroad-core/src/ir.rs
@@ -4,6 +4,7 @@
 
 use base64ct::{Base64, Encoding};
 use serde::{Deserialize, Serialize};
+use std::slice::Iter;
 
 use expression::Expression;
 use operator::Relational;
@@ -85,6 +86,10 @@ impl Nodes {
         self.arena.len()
     }
 
+    pub fn iter(&self) -> Iter<'_, Node> {
+        self.arena.iter()
+    }
+
     /// Add new node to arena.
     ///
     /// Inserts a new node to the arena and returns its position,
-- 
GitLab