From 46e47e9c5f3f80b2425c2e906fec95c0f7492cbf Mon Sep 17 00:00:00 2001
From: Igor Kuznetsov <kuznetsovin@gmail.com>
Date: Fri, 26 Aug 2022 13:51:24 +0300
Subject: [PATCH] feat: add `inner join` support to explain

---
 src/ir/explain.rs                      | 55 +++++++++++++++++++++--
 src/ir/explain/tests.rs                | 61 ++++++++++++++++++++++++++
 test_app/test/integration/api_test.lua | 37 ++++++++++++++++
 3 files changed, 150 insertions(+), 3 deletions(-)

diff --git a/src/ir/explain.rs b/src/ir/explain.rs
index 13e9ee5fdc..f20fa8f539 100644
--- a/src/ir/explain.rs
+++ b/src/ir/explain.rs
@@ -420,7 +420,7 @@ impl Display for MotionKey {
 }
 
 #[derive(Debug, Serialize)]
-pub enum Target {
+enum Target {
     Reference(String),
     Value(Value),
 }
@@ -434,10 +434,22 @@ impl Display for Target {
     }
 }
 
+#[derive(Debug, Serialize)]
+struct InnerJoin { // explain-tree payload for an inner join node
+    condition: Selection, // join condition; printed after the `join on ` prefix
+}
+
+impl Display for InnerJoin { // renders the node as `join on <condition>`
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.write_fmt(format_args!("join on {}", self.condition))
+    }
+}
+
 #[derive(Debug, Serialize)]
 #[allow(dead_code)]
 enum ExplainNode {
     Except,
+    InnerJoin(InnerJoin),
     Projection(Projection),
     Scan(Scan),
     Selection(Selection),
@@ -450,6 +462,7 @@ impl Display for ExplainNode {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         let s = match &self {
             ExplainNode::Except => "except".to_string(),
+            ExplainNode::InnerJoin(i) => i.to_string(),
             ExplainNode::Projection(e) => e.to_string(),
             ExplainNode::Scan(s) => s.to_string(),
             ExplainNode::Selection(s) => format!("selection {}", s),
@@ -682,8 +695,44 @@ impl FullExplain {
 
                     Some(ExplainNode::Motion(m))
                 }
-                Relational::InnerJoin { .. }
-                | Relational::Insert { .. }
+                Relational::InnerJoin {
+                    children,
+                    condition,
+                    ..
+                } => {
+                    if children.len() < 2 { // keeps `split_at(2)` and `len() - 2` below in range
+                        return Err(QueryPlannerError::CustomError(
+                            "Join must have at least two children".into(),
+                        ));
+                    }
+                    let (_, subquery_ids) = children.split_at(2); // ids after the first two are treated as sub-queries
+                    let mut sq_ref_map: HashMap<usize, usize> =
+                        HashMap::with_capacity(children.len() - 2);
+                    // Walk the sub-query ids in reverse, popping one node per id and recording its index in `result.subqueries`.
+                    for sq_id in subquery_ids.iter().rev() {
+                        let sq_node = stack.pop().ok_or_else(|| {
+                            QueryPlannerError::CustomError(
+                                "Join node failed to get a sub-query.".into(),
+                            )
+                        })?;
+                        result.subqueries.push(sq_node);
+                        let offset = result.subqueries.len() - 1;
+                        sq_ref_map.insert(*sq_id, offset);
+                    }
+                    // The first pop is bound to `right`, the second to `left`; they are attached in left, right order.
+                    if let (Some(right), Some(left)) = (stack.pop(), stack.pop()) {
+                        current_node.children.push(left);
+                        current_node.children.push(right);
+                    } else { // reached when `stack` yields fewer than two nodes
+                        return Err(QueryPlannerError::CustomError(
+                            "Join node must have exactly two children".into(),
+                        ));
+                    }
+                    // Build the displayed condition, passing the sub-query id -> index map along.
+                    let condition = Selection::new(ir, *condition, &sq_ref_map)?;
+                    Some(ExplainNode::InnerJoin(InnerJoin { condition }))
+                }
+                Relational::Insert { .. }
                 | Relational::Values { .. }
                 | Relational::ValuesRow { .. } => {
                     return Err(QueryPlannerError::CustomError(format!(
diff --git a/src/ir/explain/tests.rs b/src/ir/explain/tests.rs
index 685a975c7e..dea4ebe64f 100644
--- a/src/ir/explain/tests.rs
+++ b/src/ir/explain/tests.rs
@@ -221,3 +221,64 @@ motion [policy: segment([ref("identification_number")]), generation: none]
 
     assert_eq!(actual_explain, explain_tree.to_string())
 }
+
+#[test]
+fn motion_join_plan() {
+    let sql = r#"SELECT "t1"."FIRST_NAME"
+FROM (SELECT "id", "FIRST_NAME" FROM "test_space" WHERE "id" = 3) as "t1"
+    JOIN (SELECT "identification_number", "product_code" FROM "hash_testing") as "t2" ON "t1"."id"="t2"."identification_number"
+WHERE "t2"."product_code" = '123'"#;
+
+    let plan = sql_to_optimized_ir(sql, vec![]);
+
+    let top_id = plan.get_top().unwrap();
+    let explain = FullExplain::new(&plan, top_id).unwrap();
+
+    // Expected rendering: the second join child sits under a segment motion.
+    let expected = String::from(r#"projection ("t1"."FIRST_NAME" -> "FIRST_NAME")
+    selection ROW("t2"."product_code") = ROW('123')
+        join on ROW("t1"."id") = ROW("t2"."identification_number")
+            scan "t1"
+                projection ("test_space"."id" -> "id", "test_space"."FIRST_NAME" -> "FIRST_NAME")
+                    selection ROW("test_space"."id") = ROW(3)
+                        scan "test_space"
+            motion [policy: segment([ref("identification_number")]), generation: none]
+                scan "t2"
+                    projection ("hash_testing"."identification_number" -> "identification_number", "hash_testing"."product_code" -> "product_code")
+                        scan "hash_testing"
+"#);
+
+    assert_eq!(expected, explain.to_string())
+}
+
+#[test]
+fn sq_join_plan() {
+    let sql = r#"SELECT "t1"."FIRST_NAME"
+FROM (SELECT "id", "FIRST_NAME" FROM "test_space" WHERE "id" = 3) as "t1"
+    JOIN "hash_testing" ON "t1"."id"=(SELECT "identification_number" FROM "hash_testing")"#;
+
+    let plan = sql_to_optimized_ir(sql, vec![]);
+
+    let top_id = plan.get_top().unwrap();
+    let explain = FullExplain::new(&plan, top_id).unwrap();
+
+    // Expected rendering: the sub-query in the join condition shows up as `$0` and is listed after the tree.
+    let expected = String::from(r#"projection ("t1"."FIRST_NAME" -> "FIRST_NAME")
+    join on ROW("t1"."id") = ROW($0)
+        scan "t1"
+            projection ("test_space"."id" -> "id", "test_space"."FIRST_NAME" -> "FIRST_NAME")
+                selection ROW("test_space"."id") = ROW(3)
+                    scan "test_space"
+        motion [policy: full, generation: none]
+            scan "hash_testing"
+                projection ("hash_testing"."identification_number" -> "identification_number", "hash_testing"."product_code" -> "product_code", "hash_testing"."product_units" -> "product_units", "hash_testing"."sys_op" -> "sys_op")
+                    scan "hash_testing"
+subquery $0:
+motion [policy: segment([ref("identification_number")]), generation: none]
+            scan
+                projection ("hash_testing"."identification_number" -> "identification_number")
+                    scan "hash_testing"
+"#);
+
+    assert_eq!(expected, explain.to_string())
+}
diff --git a/test_app/test/integration/api_test.lua b/test_app/test/integration/api_test.lua
index 080fc95091..9a9ddaef1b 100644
--- a/test_app/test/integration/api_test.lua
+++ b/test_app/test/integration/api_test.lua
@@ -799,6 +799,43 @@ g.test_motion_explain = function()
     )
 end
 
+g.test_join_explain = function()
+    local conn = cluster:server("api-1").net_box
+    -- Explain an inner join between a UNION ALL sub-query ("t3") and a filtered sub-query ("t8").
+    local plan_lines, err = conn:call("sbroad.explain", { [[SELECT *
+FROM
+    (SELECT "id", "name" FROM "space_simple_shard_key" WHERE "sysOp" < 1
+     UNION ALL
+     SELECT "id", "name" FROM "space_simple_shard_key_hist" WHERE "sysOp" > 0) AS "t3"
+INNER JOIN
+    (SELECT "id" as "tid"  FROM "testing_space" where "id" <> 1) AS "t8"
+    ON "t3"."id" = "t8"."tid"
+WHERE "t3"."name" = '123']] })
+    t.assert_equals(err, nil)
+    t.assert_equals(
+        plan_lines,
+        -- luacheck: max line length 210
+        {
+            "projection (\"t3\".\"id\" -> \"id\", \"t3\".\"name\" -> \"name\", \"t8\".\"tid\" -> \"tid\")",
+            "    selection ROW(\"t3\".\"name\") = ROW('123')",
+            "        join on ROW(\"t3\".\"id\") = ROW(\"t8\".\"tid\")",
+            "            scan \"t3\"",
+            "                union all",
+            "                    projection (\"space_simple_shard_key\".\"id\" -> \"id\", \"space_simple_shard_key\".\"name\" -> \"name\")",
+            "                        selection ROW(\"space_simple_shard_key\".\"sysOp\") < ROW(1)",
+            "                            scan \"space_simple_shard_key\"",
+            "                    projection (\"space_simple_shard_key_hist\".\"id\" -> \"id\", \"space_simple_shard_key_hist\".\"name\" -> \"name\")",
+            "                        selection ROW(\"space_simple_shard_key_hist\".\"sysOp\") > ROW(0)",
+            "                            scan \"space_simple_shard_key_hist\"",
+            "            motion [policy: segment([ref(\"tid\")]), generation: none]",
+            "                scan \"t8\"",
+            "                    projection (\"testing_space\".\"id\" -> \"tid\")",
+            "                        selection ROW(\"testing_space\".\"id\") <> ROW(1)",
+            "                            scan \"testing_space\"",
+        }
+    )
+end
+
 g.test_valid_explain = function()
     local api = cluster:server("api-1").net_box
 
-- 
GitLab