From 57138bd794fed1d5e244d910c882849324c17e98 Mon Sep 17 00:00:00 2001
From: Denis Smirnov <sd@picodata.io>
Date: Wed, 24 Aug 2022 11:02:11 +0700
Subject: [PATCH] fix: subtree iterator traversal order

Previously we had an error in the parameter binding order when
a query had sub-queries in filters or join conditions.
---
 .../engine/cartridge/backend/sql/ir.rs        |   2 +-
 .../engine/cartridge/backend/sql/tree.rs      |   2 +-
 src/executor/shard.rs                         |   2 +-
 src/frontend/sql.rs                           |   2 +-
 src/frontend/sql/ir/tests.rs                  |  74 +---------
 src/frontend/sql/ir/tests/params.rs           | 128 ++++++++++++++++++
 src/ir/explain.rs                             |   4 +-
 src/ir/expression.rs                          |  32 +++--
 src/ir/operator.rs                            |   2 +-
 src/ir/tree.rs                                |  67 +++++++--
 10 files changed, 216 insertions(+), 99 deletions(-)
 create mode 100644 src/frontend/sql/ir/tests/params.rs

diff --git a/src/executor/engine/cartridge/backend/sql/ir.rs b/src/executor/engine/cartridge/backend/sql/ir.rs
index 7262894e60..a79b248ea8 100644
--- a/src/executor/engine/cartridge/backend/sql/ir.rs
+++ b/src/executor/engine/cartridge/backend/sql/ir.rs
@@ -169,7 +169,7 @@ impl ExecutionPlan {
                             }
                             Expression::Reference { position, .. } => {
                                 let rel_id: usize =
-                                    ir_plan.get_relational_from_reference_node(*id)?;
+                                    *ir_plan.get_relational_from_reference_node(*id)?;
                                 let rel_node = ir_plan.get_relation_node(rel_id)?;
                                 let alias = &ir_plan.get_alias_from_reference_node(expr)?;
 
diff --git a/src/executor/engine/cartridge/backend/sql/tree.rs b/src/executor/engine/cartridge/backend/sql/tree.rs
index 937df293f3..8a312f847e 100644
--- a/src/executor/engine/cartridge/backend/sql/tree.rs
+++ b/src/executor/engine/cartridge/backend/sql/tree.rs
@@ -863,7 +863,7 @@ impl<'p> SyntaxPlan<'p> {
         let ir_plan = plan.get_ir_plan();
 
         // Wrap plan's nodes and preserve their ids.
-        let dft_post = DftPost::new(&top, |node| ir_plan.nodes.subtree_iter(node));
+        let dft_post = DftPost::new(&top, |node| ir_plan.subtree_iter(node));
         for (_, id) in dft_post {
             // it works only for post-order traversal
             let sn_id = sp.add_plan_node(*id)?;
diff --git a/src/executor/shard.rs b/src/executor/shard.rs
index d4f5da62b1..99d0f03e8d 100644
--- a/src/executor/shard.rs
+++ b/src/executor/shard.rs
@@ -15,7 +15,7 @@ impl<'e> ExecutionPlan<'e> {
         let mut nodes: Vec<usize> = Vec::new();
         let ir_plan = self.get_ir_plan();
 
-        let post_tree = DftPost::new(&top_node_id, |node| ir_plan.nodes.subtree_iter(node));
+        let post_tree = DftPost::new(&top_node_id, |node| ir_plan.subtree_iter(node));
         for (_, node_id) in post_tree {
             if ir_plan.is_bool_eq_with_rows(*node_id) {
                 nodes.push(*node_id);
diff --git a/src/frontend/sql.rs b/src/frontend/sql.rs
index f40efddb5e..b6f8ee3a8a 100644
--- a/src/frontend/sql.rs
+++ b/src/frontend/sql.rs
@@ -765,7 +765,7 @@ impl Plan {
         }
 
         let top_id = self.get_top()?;
-        let tree = DftPost::new(&top_id, |node| self.nodes.subtree_iter(node));
+        let tree = DftPost::new(&top_id, |node| self.subtree_iter(node));
         let nodes: Vec<usize> = tree.map(|(_, id)| *id).collect();
 
         // Transform parameters to values. The result values are stored in the
diff --git a/src/frontend/sql/ir/tests.rs b/src/frontend/sql/ir/tests.rs
index 82eacc5091..c64a708313 100644
--- a/src/frontend/sql/ir/tests.rs
+++ b/src/frontend/sql/ir/tests.rs
@@ -8,7 +8,7 @@ use crate::ir::value::Value;
 use crate::ir::Plan;
 use pretty_assertions::assert_eq;
 
-fn no_transform(_plan: &mut Plan) {}
+pub(super) fn no_transform(_plan: &mut Plan) {}
 
 #[test]
 fn front_sql1() {
@@ -420,73 +420,5 @@ fn front_sql19() {
     assert_eq!(sql_to_sql(input, vec![], &no_transform), expected);
 }
 
-#[test]
-fn front_params1() {
-    let pattern = r#"SELECT "id", "FIRST_NAME" FROM "test_space"
-        WHERE "sys_op" = ? AND "sysFrom" > ?"#;
-    let params = vec![Value::from(0_i64), Value::from(1_i64)];
-    let expected = PatternWithParams::new(
-        format!(
-            "{} {}",
-            r#"SELECT "test_space"."id", "test_space"."FIRST_NAME" FROM "test_space""#,
-            r#"WHERE ("test_space"."sys_op") = (?) and ("test_space"."sysFrom") > (?)"#,
-        ),
-        params.clone(),
-    );
-
-    assert_eq!(sql_to_sql(pattern, params, &no_transform), expected);
-}
-
-#[test]
-fn front_params2() {
-    let pattern = r#"SELECT "id" FROM "test_space"
-        WHERE "sys_op" = ? AND "FIRST_NAME" = ?"#;
-    let params = vec![Value::Null, Value::from("hello")];
-    let expected = PatternWithParams::new(
-        format!(
-            "{} {}",
-            r#"SELECT "test_space"."id" FROM "test_space""#,
-            r#"WHERE ("test_space"."sys_op") = (?) and ("test_space"."FIRST_NAME") = (?)"#,
-        ),
-        params.clone(),
-    );
-
-    assert_eq!(sql_to_sql(pattern, params, &no_transform), expected);
-}
-
-// check cyrillic params support
-#[test]
-fn front_params3() {
-    let pattern = r#"SELECT "id" FROM "test_space"
-        WHERE "sys_op" = ? AND "FIRST_NAME" = ?"#;
-    let params = vec![Value::Null, Value::from("кириллица")];
-    let expected = PatternWithParams::new(
-        format!(
-            "{} {}",
-            r#"SELECT "test_space"."id" FROM "test_space""#,
-            r#"WHERE ("test_space"."sys_op") = (?) and ("test_space"."FIRST_NAME") = (?)"#,
-        ),
-        params.clone(),
-    );
-
-    assert_eq!(sql_to_sql(pattern, params, &no_transform), expected);
-}
-
-// check symbols in values (grammar)
-#[test]
-fn front_params4() {
-    let pattern = r#"SELECT "id" FROM "test_space"
-        WHERE "FIRST_NAME" = '''± !@#$%^&*()_+=-\/><";:,.`~'"#;
-
-    let params = vec![Value::from(r#"''± !@#$%^&*()_+=-\/><";:,.`~"#)];
-    let expected = PatternWithParams::new(
-        format!(
-            "{} {}",
-            r#"SELECT "test_space"."id" FROM "test_space""#,
-            r#"WHERE ("test_space"."FIRST_NAME") = (?)"#,
-        ),
-        params,
-    );
-
-    assert_eq!(sql_to_sql(pattern, vec![], &no_transform), expected);
-}
+#[cfg(test)]
+mod params;
diff --git a/src/frontend/sql/ir/tests/params.rs b/src/frontend/sql/ir/tests/params.rs
new file mode 100644
index 0000000000..ec2fd0233b
--- /dev/null
+++ b/src/frontend/sql/ir/tests/params.rs
@@ -0,0 +1,128 @@
+use super::*;
+use crate::executor::engine::cartridge::backend::sql::ir::PatternWithParams;
+use crate::ir::transformation::helpers::sql_to_sql;
+use crate::ir::value::Value;
+use pretty_assertions::assert_eq;
+
+#[test]
+fn front_params1() {
+    let pattern = r#"SELECT "id", "FIRST_NAME" FROM "test_space"
+        WHERE "sys_op" = ? AND "sysFrom" > ?"#;
+    let params = vec![Value::from(0_i64), Value::from(1_i64)];
+    let expected = PatternWithParams::new(
+        format!(
+            "{} {}",
+            r#"SELECT "test_space"."id", "test_space"."FIRST_NAME" FROM "test_space""#,
+            r#"WHERE ("test_space"."sys_op") = (?) and ("test_space"."sysFrom") > (?)"#,
+        ),
+        params.clone(),
+    );
+
+    assert_eq!(sql_to_sql(pattern, params, &no_transform), expected);
+}
+
+#[test]
+fn front_params2() {
+    let pattern = r#"SELECT "id" FROM "test_space"
+        WHERE "sys_op" = ? AND "FIRST_NAME" = ?"#;
+    let params = vec![Value::Null, Value::from("hello")];
+    let expected = PatternWithParams::new(
+        format!(
+            "{} {}",
+            r#"SELECT "test_space"."id" FROM "test_space""#,
+            r#"WHERE ("test_space"."sys_op") = (?) and ("test_space"."FIRST_NAME") = (?)"#,
+        ),
+        params.clone(),
+    );
+
+    assert_eq!(sql_to_sql(pattern, params, &no_transform), expected);
+}
+
+// check cyrillic params support
+#[test]
+fn front_params3() {
+    let pattern = r#"SELECT "id" FROM "test_space"
+        WHERE "sys_op" = ? AND "FIRST_NAME" = ?"#;
+    let params = vec![Value::Null, Value::from("кириллица")];
+    let expected = PatternWithParams::new(
+        format!(
+            "{} {}",
+            r#"SELECT "test_space"."id" FROM "test_space""#,
+            r#"WHERE ("test_space"."sys_op") = (?) and ("test_space"."FIRST_NAME") = (?)"#,
+        ),
+        params.clone(),
+    );
+
+    assert_eq!(sql_to_sql(pattern, params, &no_transform), expected);
+}
+
+// check symbols in values (grammar)
+#[test]
+fn front_params4() {
+    let pattern = r#"SELECT "id" FROM "test_space"
+        WHERE "FIRST_NAME" = '''± !@#$%^&*()_+=-\/><";:,.`~'"#;
+
+    let params = vec![Value::from(r#"''± !@#$%^&*()_+=-\/><";:,.`~"#)];
+    let expected = PatternWithParams::new(
+        format!(
+            "{} {}",
+            r#"SELECT "test_space"."id" FROM "test_space""#,
+            r#"WHERE ("test_space"."FIRST_NAME") = (?)"#,
+        ),
+        params,
+    );
+
+    assert_eq!(sql_to_sql(pattern, vec![], &no_transform), expected);
+}
+
+// check parameter binding order, when selection has sub-queries
+#[test]
+fn front_params5() {
+    let pattern = r#"
+        SELECT "id" FROM "test_space"
+        WHERE "sys_op" = ? OR "id" IN (
+            SELECT "sysFrom" FROM "test_space_hist"
+            WHERE "sys_op" = ?
+        )
+    "#;
+    let params = vec![Value::from(0_i64), Value::from(1_i64)];
+    let expected = PatternWithParams::new(
+        format!(
+            "{} {} {}",
+            r#"SELECT "test_space"."id" FROM "test_space""#,
+            r#"WHERE (("test_space"."sys_op") = (?) or ("test_space"."id") in"#,
+            r#"(SELECT "test_space_hist"."sysFrom" FROM "test_space_hist" WHERE ("test_space_hist"."sys_op") = (?)))"#,
+        ),
+        params.clone(),
+    );
+
+    assert_eq!(sql_to_sql(pattern, params, &no_transform), expected);
+}
+
+#[test]
+fn front_params6() {
+    let pattern = r#"
+        SELECT "id" FROM "test_space"
+        WHERE "sys_op" = ? OR "id" NOT IN (
+            SELECT "id" FROM "test_space"
+            WHERE "sys_op" = ?
+            UNION ALL
+            SELECT "id" FROM "test_space"
+            WHERE "sys_op" = ?
+        )
+    "#;
+    let params = vec![Value::from(0_i64), Value::from(1_i64), Value::from(2_i64)];
+    let expected = PatternWithParams::new(
+        format!(
+            "{} {} {} {} {}",
+            r#"SELECT "test_space"."id" FROM "test_space""#,
+            r#"WHERE (("test_space"."sys_op") = (?) or ("test_space"."id") not in"#,
+            r#"(SELECT "test_space"."id" FROM "test_space" WHERE ("test_space"."sys_op") = (?)"#,
+            r#"UNION ALL"#,
+            r#"SELECT "test_space"."id" FROM "test_space" WHERE ("test_space"."sys_op") = (?)))"#,
+        ),
+        params.clone(),
+    );
+
+    assert_eq!(sql_to_sql(pattern, params, &no_transform), expected);
+}
diff --git a/src/ir/explain.rs b/src/ir/explain.rs
index b25ab54856..f0a2f26e52 100644
--- a/src/ir/explain.rs
+++ b/src/ir/explain.rs
@@ -47,7 +47,7 @@ impl Col {
                     )));
                 }
                 Expression::Reference { position, .. } => {
-                    let rel_id: usize = plan.get_relational_from_reference_node(*id)?;
+                    let rel_id: usize = *plan.get_relational_from_reference_node(*id)?;
 
                     let rel_node = plan.get_relation_node(rel_id)?;
                     let alias = plan.get_alias_from_reference_node(current_node)?;
@@ -228,7 +228,7 @@ impl Row {
                     )));
                 }
                 Expression::Reference { position, .. } => {
-                    let rel_id: usize = plan.get_relational_from_reference_node(*child)?;
+                    let rel_id: usize = *plan.get_relational_from_reference_node(*child)?;
 
                     let rel_node = plan.get_relation_node(rel_id)?;
 
diff --git a/src/ir/expression.rs b/src/ir/expression.rs
index d56e2be3b5..2a693b49aa 100644
--- a/src/ir/expression.rs
+++ b/src/ir/expression.rs
@@ -189,6 +189,21 @@ impl Expression {
         ))
     }
 
+    /// Gets relational node id containing the reference.
+    ///
+    /// # Errors
+    /// - node isn't reference type
+    /// - reference doesn't have a parent
+    pub fn get_parent(&self) -> Result<usize, QueryPlannerError> {
+        if let Expression::Reference { parent, .. } = self {
+            return parent
+                .ok_or_else(|| QueryPlannerError::CustomError("Reference has no parent".into()));
+        }
+        Err(QueryPlannerError::CustomError(
+            "Node isn't reference type".into(),
+        ))
+    }
+
     /// The node is a row expression.
     #[must_use]
     pub fn is_row(&self) -> bool {
@@ -681,25 +696,24 @@ impl Plan {
         Ok(self.nodes.add_row(list, None))
     }
 
-    /// A list of relational nodes that makes up the reference.
+    /// A relational node pointed by the reference.
     ///
     /// # Errors
     /// - reference is invalid
-    /// - `relational_map` is not initialized
     pub fn get_relational_from_reference_node(
         &self,
         ref_id: usize,
-    ) -> Result<usize, QueryPlannerError> {
+    ) -> Result<&usize, QueryPlannerError> {
         if let Node::Expression(Expression::Reference {
             targets, parent, ..
         }) = self.get_node(ref_id)?
         {
-            let referred_rel_id = parent.ok_or(QueryPlannerError::CustomError(
-                "Reference node has no parent".into(),
-            ))?;
-            let rel = self.get_relation_node(referred_rel_id)?;
+            let referred_rel_id = parent.as_ref().ok_or_else(|| {
+                QueryPlannerError::CustomError("Reference node has no parent".into())
+            });
+            let rel = self.get_relation_node(*referred_rel_id.clone()?)?;
             if let Relational::Insert { .. } = rel {
-                return Ok(referred_rel_id);
+                return referred_rel_id;
             } else if let Some(children) = rel.children() {
                 match targets {
                     None => {
@@ -709,7 +723,7 @@ impl Plan {
                     }
                     Some(positions) => match (positions.get(0), positions.get(1)) {
                         (Some(first), None) => {
-                            let child_id = *children.get(*first).ok_or_else(|| {
+                            let child_id = children.get(*first).ok_or_else(|| {
                                 QueryPlannerError::CustomError(
                                     "Relational node has no child at first position".into(),
                                 )
diff --git a/src/ir/operator.rs b/src/ir/operator.rs
index 744ad97b5f..a7319a102a 100644
--- a/src/ir/operator.rs
+++ b/src/ir/operator.rs
@@ -410,7 +410,7 @@ impl Relational {
                 if let Expression::Alias { child, .. } = col_node {
                     let child_node = plan.get_expression_node(*child)?;
                     if let Expression::Reference { position: pos, .. } = child_node {
-                        let rel_id = plan.get_relational_from_reference_node(*child)?;
+                        let rel_id = *plan.get_relational_from_reference_node(*child)?;
                         let rel_node = plan.get_relation_node(rel_id)?;
                         return rel_node.scan_name(plan, *pos);
                     }
diff --git a/src/ir/tree.rs b/src/ir/tree.rs
index 4a07e8bac4..75778493b9 100644
--- a/src/ir/tree.rs
+++ b/src/ir/tree.rs
@@ -5,7 +5,7 @@ use std::cmp::Ordering;
 
 use super::expression::Expression;
 use super::operator::{Bool, Relational};
-use super::{Node, Nodes};
+use super::{Node, Nodes, Plan};
 
 /// Relational node's child iterator.
 ///
@@ -31,10 +31,10 @@ pub struct ExpressionIterator<'n> {
 
 /// Expression and relational nodes iterator.
 #[derive(Debug)]
-pub struct SubtreeIterator<'n> {
-    current: &'n usize,
+pub struct SubtreeIterator<'p> {
+    current: &'p usize,
     child: RefCell<usize>,
-    nodes: &'n Nodes,
+    plan: &'p Plan,
 }
 
 /// Children iterator for "and"-ed equivalent expressions.
@@ -95,13 +95,15 @@ impl<'n> Nodes {
             nodes: self,
         }
     }
+}
 
+impl<'p> Plan {
     #[must_use]
-    pub fn subtree_iter(&'n self, current: &'n usize) -> SubtreeIterator<'n> {
+    pub fn subtree_iter(&'p self, current: &'p usize) -> SubtreeIterator<'p> {
         SubtreeIterator {
             current,
             child: RefCell::new(0),
-            nodes: self,
+            plan: self,
         }
     }
 }
@@ -263,12 +265,12 @@ impl<'n> Iterator for RelationalIterator<'n> {
     }
 }
 
-impl<'n> Iterator for SubtreeIterator<'n> {
-    type Item = &'n usize;
+impl<'p> Iterator for SubtreeIterator<'p> {
+    type Item = &'p usize;
 
     #[allow(clippy::too_many_lines)]
     fn next(&mut self) -> Option<Self::Item> {
-        if let Some(child) = self.nodes.arena.get(*self.current) {
+        if let Some(child) = self.plan.nodes.arena.get(*self.current) {
             return match child {
                 Node::Parameter => None,
                 Node::Expression(exp) => match exp {
@@ -301,7 +303,48 @@ impl<'n> Iterator for SubtreeIterator<'n> {
                             }
                         };
                     }
-                    Expression::Constant { .. } | Expression::Reference { .. } => None,
+                    Expression::Constant { .. } => None,
+                    Expression::Reference { .. } => {
+                        let step = *self.child.borrow();
+                        if step == 0 {
+                            *self.child.borrow_mut() += 1;
+
+                            // At first we need to detect the place where the reference is used:
+                            // for selection filter or a join condition, we need to check whether
+                            // the reference points to an **additional** sub-query and then traverse
+                            // into it. Otherwise, stop traversal.
+                            let parent_id = match exp.get_parent() {
+                                Ok(parent_id) => parent_id,
+                                Err(_) => return None,
+                            };
+                            if let Ok(rel_id) =
+                                self.plan.get_relational_from_reference_node(*self.current)
+                            {
+                                match self.plan.get_relation_node(*rel_id) {
+                                    Ok(rel_node) if rel_node.is_subquery() => {
+                                        // Check if the sub-query is an additional one.
+                                        let parent = self.plan.get_relation_node(parent_id);
+                                        let mut is_additional = false;
+                                        if let Ok(Relational::Selection { children, .. }) = parent {
+                                            if children.iter().skip(1).any(|&c| c == *rel_id) {
+                                                is_additional = true;
+                                            }
+                                        }
+                                        if let Ok(Relational::InnerJoin { children, .. }) = parent {
+                                            if children.iter().skip(2).any(|&c| c == *rel_id) {
+                                                is_additional = true;
+                                            }
+                                        }
+                                        if is_additional {
+                                            return Some(rel_id);
+                                        }
+                                    }
+                                    _ => {}
+                                }
+                            }
+                        }
+                        None
+                    }
                 },
 
                 Node::Relational(r) => match r {
@@ -313,7 +356,7 @@ impl<'n> Iterator for SubtreeIterator<'n> {
                         let step = *self.child.borrow();
 
                         *self.child.borrow_mut() += 1;
-                        match step.cmp(&children.len()) {
+                        match step.cmp(&2) {
                             Ordering::Less => {
                                 return children.get(step);
                             }
@@ -358,7 +401,7 @@ impl<'n> Iterator for SubtreeIterator<'n> {
                         let step = *self.child.borrow();
 
                         *self.child.borrow_mut() += 1;
-                        match step.cmp(&children.len()) {
+                        match step.cmp(&1) {
                             Ordering::Less => {
                                 return children.get(step);
                             }
-- 
GitLab