From 44efca3afe695d2e910a0f7aa3fa1298861fb332 Mon Sep 17 00:00:00 2001
From: EmirVildanov <reddog201030@gmail.com>
Date: Wed, 4 Sep 2024 14:05:38 +0300
Subject: [PATCH] feat: support asterisk translation into local SQL without
 "bucket_id" selection

---
 .../src/backend/sql/ir/tests/inner_join.rs    |  4 +-
 sbroad-core/src/backend/sql/tree.rs           | 61 +++++++++----------
 sbroad-core/src/executor/ir.rs                | 21 +++++--
 sbroad-core/src/executor/tests.rs             |  9 +--
 sbroad-core/src/ir.rs                         | 51 ++++++++++++++++
 5 files changed, 103 insertions(+), 43 deletions(-)

diff --git a/sbroad-core/src/backend/sql/ir/tests/inner_join.rs b/sbroad-core/src/backend/sql/ir/tests/inner_join.rs
index 43558fd3e..2780d7986 100644
--- a/sbroad-core/src/backend/sql/ir/tests/inner_join.rs
+++ b/sbroad-core/src/backend/sql/ir/tests/inner_join.rs
@@ -64,7 +64,7 @@ fn inner_join2_latest() {
             r#""hash_testing"."product_units","#,
             r#""hash_testing"."sys_op" FROM "hash_testing") as "hash_testing""#,
             r#"INNER JOIN"#,
-            r#"(SELECT * FROM "history" WHERE ("history"."id") = (?)) as "t""#,
+            r#"(SELECT "history"."id" FROM "history" WHERE ("history"."id") = (?)) as "t""#,
             r#"ON ("hash_testing"."identification_number") = ("t"."id")"#,
             r#"WHERE ("hash_testing"."product_code") = (?)"#,
         ),
@@ -89,7 +89,7 @@ fn inner_join2_oldest() {
             r#""hash_testing"."product_units","#,
             r#""hash_testing"."sys_op" FROM "hash_testing") as "hash_testing""#,
             r#"INNER JOIN"#,
-            r#"(SELECT * FROM "history" WHERE ("history"."id") = (?)) as "t""#,
+            r#"(SELECT "history"."id" FROM "history" WHERE ("history"."id") = (?)) as "t""#,
             r#"ON ("hash_testing"."identification_number") = ("t"."id")"#,
             r#"WHERE ("hash_testing"."product_code") = (?)"#,
         ),
diff --git a/sbroad-core/src/backend/sql/tree.rs b/sbroad-core/src/backend/sql/tree.rs
index 1f3423335..5e7b3d7e7 100644
--- a/sbroad-core/src/backend/sql/tree.rs
+++ b/sbroad-core/src/backend/sql/tree.rs
@@ -498,7 +498,9 @@ impl Select {
         id: usize,
     ) -> Result<Option<Select>, SbroadError> {
         let sn = sp.nodes.get_sn(id);
-        if let Some(Node::Relational(Relational::Projection(_))) = sp.get_plan_node(&sn.data)? {
+        if let Some((Node::Relational(Relational::Projection { .. }), _)) =
+            sp.get_plan_node(&sn.data)?
+        {
             let mut select = Select {
                 parent,
                 branch,
@@ -512,7 +514,7 @@ impl Select {
             loop {
                 let left_id = node.left_id_or_err()?;
                 let sn_left = sp.nodes.get_sn(left_id);
-                let plan_node_left = sp.plan_node_or_err(&sn_left.data)?;
+                let (plan_node_left, _) = sp.plan_node_or_err(&sn_left.data)?;
                 if let Node::Relational(
                     Relational::ScanRelation(_)
                     | Relational::ScanCte(_)
@@ -1485,6 +1487,7 @@ impl<'p> SyntaxPlan<'p> {
         let mut children = Vec::with_capacity(list.len() * 2 + 2);
         children.push(self.nodes.push_sn_non_plan(SyntaxNode::new_open()));
 
+        // Helper enum used for solving asterisk to local SQL transformation.
         enum NodeToAdd {
             SnId(usize),
             Asterisk(SyntaxNode),
@@ -1492,8 +1495,12 @@ impl<'p> SyntaxPlan<'p> {
         }
 
         let mut handle_reference =
-            |sn_id: usize, need_comma: bool, expr_node: &Expression| -> Vec<NodeToAdd> {
-                let mut non_reference_nodes = || -> Vec<NodeToAdd> {
+            |sn_id: usize, need_comma: bool, expr_id: NodeId| -> Vec<NodeToAdd> {
+                let ir_plan = self.plan.get_ir_plan();
+                let expr_node = ir_plan
+                    .get_expression_node(expr_id)
+                    .expect("Expression node must exist");
+                let mut non_asterisk_nodes = || -> Vec<NodeToAdd> {
                     let mut nodes_to_add = Vec::new();
                     if last_handled_asterisk_id.is_some() {
                         nodes_to_add.push(NodeToAdd::Comma);
@@ -1513,29 +1520,20 @@ impl<'p> SyntaxPlan<'p> {
                             relation_name,
                             asterisk_id,
                         }),
-                    parent,
-                    targets,
                     ..
                 } = expr_node
                 {
                     // If we reference ScanNode, we don't want to transform asterisks
                     // in order not to select "bucket_id". That's why we save them as a
                     // sequence of references.
-                    if let Some(parent) = parent {
-                        let ir_plan = self.plan.get_ir_plan();
-                        let targets = targets
-                            .clone()
-                            .expect("Reference with parent must have targets.");
-                        let first_target = targets.first().expect("Targets must not be empry");
-                        let child_id = ir_plan
-                            .get_relational_child(*parent, *first_target)
-                            .expect("Rel child must exist");
-                        let target_rel_node = ir_plan
-                            .get_relation_node(child_id)
-                            .expect("target node must exist");
-                        if let Relational::ScanRelation { .. } = target_rel_node {
-                            return non_reference_nodes();
-                        }
+                    let ref_source_node_id = ir_plan
+                        .get_reference_source_relation(expr_id)
+                        .expect("Reference must have a source relation");
+                    let ref_source_node = ir_plan
+                        .get_relation_node(ref_source_node_id)
+                        .expect("Node must be a relational");
+                    if let Relational::ScanRelation { .. } = ref_source_node {
+                        return non_asterisk_nodes();
                     }
 
                     let mut need_comma = false;
@@ -1574,7 +1572,7 @@ impl<'p> SyntaxPlan<'p> {
                         nodes_to_add.push(NodeToAdd::Asterisk(asterisk_node_to_add));
                     }
                 } else {
-                    return non_reference_nodes();
+                    return non_asterisk_nodes();
                 }
                 nodes_to_add
             };
@@ -1583,15 +1581,14 @@ impl<'p> SyntaxPlan<'p> {
                                             need_comma: bool|
          -> Result<(), SbroadError> {
             let sn_node = self.nodes.get_sn(sn_id);
-            let sn_plan_node = self.get_plan_node(&sn_node.data)?;
+            let sn_plan_node_pair = self.get_plan_node(&sn_node.data)?;
 
-            let nodes_to_add = if let Some(Node::Expression(node_expr)) = sn_plan_node {
+            let nodes_to_add = if let Some((Node::Expression(node_expr), sn_plan_node_id)) =
+                sn_plan_node_pair
+            {
                 match node_expr {
-                    Expression::Alias { child, .. } => {
-                        let alias_child = self.plan.get_ir_plan().get_expression_node(*child)?;
-                        handle_reference(sn_id, need_comma, alias_child)
-                    }
-                    _ => handle_reference(sn_id, need_comma, node_expr),
+                    Expression::Alias { child, .. } => handle_reference(sn_id, need_comma, *child),
+                    _ => handle_reference(sn_id, need_comma, sn_plan_node_id),
                 }
             } else {
                 // As it's not ad Alias under Projection output, we don't have to
@@ -1751,9 +1748,9 @@ impl<'p> SyntaxPlan<'p> {
     ///
     /// # Errors
     /// - syntax node wraps an invalid plan node
-    pub fn get_plan_node(&self, data: &SyntaxData) -> Result<Option<Node>, SbroadError> {
+    pub fn get_plan_node(&self, data: &SyntaxData) -> Result<Option<(&Node, NodeId)>, SbroadError> {
         if let SyntaxData::PlanId(id) = data {
-            Ok(Some(self.plan.get_ir_plan().get_node(*id)?))
+            Ok(Some((self.plan.get_ir_plan().get_node(*id)?, *id)))
         } else {
             Ok(None)
         }
@@ -1764,7 +1761,7 @@ impl<'p> SyntaxPlan<'p> {
     /// # Errors
     /// - plan node is invalid
     /// - syntax tree node doesn't have a plan node
-    pub fn plan_node_or_err(&self, data: &SyntaxData) -> Result<Node, SbroadError> {
+    pub fn plan_node_or_err(&self, data: &SyntaxData) -> Result<(&Node, NodeId), SbroadError> {
         self.get_plan_node(data)?.ok_or_else(|| {
             SbroadError::Invalid(
                 Entity::SyntaxPlan,
diff --git a/sbroad-core/src/executor/ir.rs b/sbroad-core/src/executor/ir.rs
index f1c1e4f0f..959c4ad8c 100644
--- a/sbroad-core/src/executor/ir.rs
+++ b/sbroad-core/src/executor/ir.rs
@@ -640,9 +640,12 @@ impl ExecutionPlan {
                     *rel.mut_output() = subtree_map.get_id(*output);
                     relational_output_id = Some(*rel.mut_output());
 
-                    if !rel_renamed_output_lists.is_empty()
-                        && !matches!(rel, Relational::Projection { .. })
-                    {
+                    // If we deal with Projection we have to fix
+                    // only References that have an Asterisk source.
+                    // References without asterisks would be covered with aliases like
+                    // "COL_1 as <alias>".
+                    let is_projection = matches!(rel, Relational::Projection { .. });
+                    if !rel_renamed_output_lists.is_empty() {
                         let rel_output_list: Vec<usize> =
                             new_plan.get_row_list(rel.output())?.to_vec();
 
@@ -650,11 +653,19 @@ impl ExecutionPlan {
                             let ref_under_alias = new_plan.get_child_under_alias(*output_id)?;
                             let ref_expr = new_plan.get_expression_node(ref_under_alias)?;
                             let Expression::Reference {
-                                position, targets, ..
+                                position,
+                                targets,
+                                asterisk_source,
+                                ..
                             } = ref_expr
                             else {
-                                panic!("Expected reference, got {ref_expr:?}");
+                                continue;
                             };
+
+                            if is_projection && asterisk_source.is_none() {
+                                continue;
+                            }
+
                             let mut ref_rel_node = None;
 
                             let target = if let Some(targets) = targets {
diff --git a/sbroad-core/src/executor/tests.rs b/sbroad-core/src/executor/tests.rs
index 9b407d9aa..859cb4777 100644
--- a/sbroad-core/src/executor/tests.rs
+++ b/sbroad-core/src/executor/tests.rs
@@ -787,7 +787,7 @@ fn anonymous_col_index_test() {
             LuaValue::String(String::from(PatternWithParams::new(
                 format!(
                     "{} {} {} {} {} {}",
-                    "SELECT *",
+                    r#"SELECT "test_space"."id", "test_space"."sysFrom", "test_space"."FIRST_NAME", "test_space"."sys_op""#,
                     r#"FROM "test_space""#,
                     r#"WHERE ("test_space"."id") in"#,
                     r#"(SELECT "COL_1" FROM "TMP_test_1136")"#,
@@ -803,7 +803,7 @@ fn anonymous_col_index_test() {
                 format!(
                     "{} {} {} {} {} {}",
                     "SELECT",
-                    r#"* FROM "test_space""#,
+                    r#""test_space"."id", "test_space"."sysFrom", "test_space"."FIRST_NAME", "test_space"."sys_op" FROM "test_space""#,
                     r#"WHERE ("test_space"."id") in"#,
                     r#"(SELECT "COL_1" FROM "TMP_test_1136")"#,
                     r#"or ("test_space"."id") in"#,
@@ -837,7 +837,8 @@ fn sharding_column1_test() {
         LuaValue::String(String::from(PatternWithParams::new(
             format!(
                 "{} {}",
-                r#"SELECT *"#, r#"FROM "test_space" WHERE ("test_space"."id") = (?)"#,
+                r#"SELECT "test_space"."id", "test_space"."sysFrom", "test_space"."FIRST_NAME", "test_space"."sys_op""#,
+                r#"FROM "test_space" WHERE ("test_space"."id") = (?)"#,
             ),
             vec![Value::from(1_u64)],
         ))),
@@ -866,7 +867,7 @@ fn sharding_column2_test() {
         LuaValue::String(String::from(PatternWithParams::new(
             format!(
                 "{} {}",
-                r#"SELECT *,"#,
+                r#"SELECT "test_space"."id", "test_space"."sysFrom", "test_space"."FIRST_NAME", "test_space"."sys_op","#,
                 r#""test_space"."bucket_id" FROM "test_space" WHERE ("test_space"."id") = (?)"#,
             ),
             vec![Value::from(1_u64)],
diff --git a/sbroad-core/src/ir.rs b/sbroad-core/src/ir.rs
index 743e97d76..75472cbf5 100644
--- a/sbroad-core/src/ir.rs
+++ b/sbroad-core/src/ir.rs
@@ -1023,6 +1023,57 @@ impl Plan {
         self.slices.clone()
     }
 
+    /// Find source relational node from which given reference came.
+    /// E.g. in case we're working with a reference under Projection with the following
+    /// children struct:
+    /// Projection <- given reference in the output
+    ///     Selection
+    ///         Scan
+    /// the source would be the Scan. It could also be a Join or Union nodes.
+    pub fn get_reference_source_relation(&self, ref_id: usize) -> Result<usize, SbroadError> {
+        let mut ref_id = ref_id;
+        let mut ref_node = self.get_expression_node(ref_id)?;
+        if let Expression::Alias { child, .. } = ref_node {
+            ref_node = self.get_expression_node(*child)?;
+            ref_id = *child;
+        }
+        let Expression::Reference { position, .. } = ref_node else {
+            panic!("Expected reference")
+        };
+        let ref_parent_node_id = *self.get_relational_from_reference_node(ref_id)?;
+        let ref_source_node = self.get_relation_node(ref_parent_node_id)?;
+        match ref_source_node {
+            Relational::Delete { .. } | Relational::Insert { .. } | Relational::Update { .. } => {
+                panic!("Reference source search shouldn't reach DML node.")
+            }
+            Relational::Selection { output, .. }
+            | Relational::Having { output, .. }
+            | Relational::OrderBy { output, .. }
+            | Relational::Limit { output, .. } => {
+                let source_output_list = self.get_row_list(*output)?;
+                let source_ref_id = source_output_list[*position];
+                self.get_reference_source_relation(source_ref_id)
+            }
+            Relational::ScanRelation { .. }
+            | Relational::Projection { .. }
+            | Relational::ScanCte { .. }
+            | Relational::GroupBy { .. }
+            | Relational::Motion { .. }
+            | Relational::ScanSubQuery { .. }
+            | Relational::Join { .. }
+            | Relational::Except { .. }
+            | Relational::Intersect { .. }
+            | Relational::UnionAll { .. }
+            | Relational::Union { .. }
+            | Relational::Values { .. } => Ok(ref_parent_node_id),
+            Relational::ValuesRow { .. } => {
+                panic!(
+                    "Reference source search shouldn't reach unsupported node {ref_source_node:?}."
+                )
+            }
+        }
+    }
+
     /// Get relation in the plan by its name or returns error.
     ///
     /// # Errors
-- 
GitLab