From 8f8838f200844416cd1e48cabe0383b5f8d3c6d8 Mon Sep 17 00:00:00 2001
From: Arseniy Volynets <vol0ncar@yandex.ru>
Date: Fri, 31 May 2024 12:51:16 +0000
Subject: [PATCH] fix: use motion ref counter for SAE opcode

---
 .../test_app/test/integration/cte_test.lua    | 25 +++++++
 .../src/executor/engine/helpers/vshard.rs     | 65 +++++++++++++++----
 sbroad-core/src/executor/ir.rs                |  6 +-
 .../src/ir/transformation/redistribution.rs   | 18 ++---
 4 files changed, 88 insertions(+), 26 deletions(-)

diff --git a/sbroad-cartridge/test_app/test/integration/cte_test.lua b/sbroad-cartridge/test_app/test/integration/cte_test.lua
index d41ffafaf..e6102d5f1 100644
--- a/sbroad-cartridge/test_app/test/integration/cte_test.lua
+++ b/sbroad-cartridge/test_app/test/integration/cte_test.lua
@@ -196,5 +196,30 @@ g.test_cte = function ()
     t.assert_equals(err, nil)
     t.assert_items_equals(r["metadata"], { {name = "T.C", type = "integer"} })
     t.assert_items_equals(r["rows"], { {1} })
+
+    -- cte with "serialize as empty table" opcode in motion
+    r, err = api:call("sbroad.execute", { [[
+        WITH cte1(a) as (VALUES(1)),
+        cte2(a) as (SELECT a1.a FROM cte1 a1 JOIN "t" ON true UNION SELECT * FROM cte1 a2)
+        SELECT * FROM cte2
+    ]], })
+    t.assert_equals(err, nil)
+    t.assert_items_equals(r["metadata"], { {name = "CTE2.A", type = "any"} })
+    t.assert_items_equals(r["rows"], { {1} })
+
+    r, err = api:call("sbroad.execute", { [[
+        WITH cte1(a) as (VALUES(1)),
+        cte2(a) as (
+            SELECT a1.a FROM cte1 a1 JOIN "t" ON a1.a = "id"
+            UNION ALL
+            SELECT * FROM cte1 a2
+            UNION ALL
+            SELECT * FROM cte1 a3
+        )
+        SELECT * FROM cte2
+    ]], })
+    t.assert_equals(err, nil)
+    t.assert_items_equals(r["metadata"], { {name = "CTE2.A", type = "any"} })
+    t.assert_items_equals(r["rows"], { {1}, {1}, {1} })
 end
 
diff --git a/sbroad-core/src/executor/engine/helpers/vshard.rs b/sbroad-core/src/executor/engine/helpers/vshard.rs
index 1a5f97664..7eb1499f6 100644
--- a/sbroad-core/src/executor/engine/helpers/vshard.rs
+++ b/sbroad-core/src/executor/engine/helpers/vshard.rs
@@ -14,10 +14,11 @@ use crate::{
             relation::RelationalIterator,
             traversal::{PostOrderWithFilter, REL_CAPACITY},
         },
-        Node, Plan,
+        Node, NodeId, Plan,
     },
     otm::child_span,
 };
+use ahash::AHashMap;
 use rand::{thread_rng, Rng};
 use sbroad_proc::otm_child_span;
 use smol_str::format_smolstr;
@@ -254,14 +255,15 @@ pub fn exec_ir_on_some_buckets(
 struct SerializeAsEmptyInfo {
     // ids of topmost motion nodes which have this opcode
     // with `true` value
-    top_motion_ids: Vec<usize>,
+    top_motion_ids: Vec<NodeId>,
     // ids of motions which have this opcode
-    target_motion_ids: Vec<usize>,
+    target_motion_ids: Vec<NodeId>,
     // ids of motions that are located below
     // top_motion_id, vtables corresponding
     // to those motions must be deleted from
     // replicaset message.
-    unused_motions: Vec<usize>,
+    unused_motions: Vec<NodeId>,
+    motions_ref_count: AHashMap<NodeId, usize>, // motion id -> times met in plan traversal
 }
 
 impl Plan {
@@ -300,6 +302,23 @@ impl Plan {
 
     fn serialize_as_empty_info(&self) -> Result<Option<SerializeAsEmptyInfo>, SbroadError> {
         let top_ids = self.collect_top_ids()?;
+
+        let mut motions_ref_count: AHashMap<NodeId, usize> = AHashMap::new();
+        let filter = |node_id: usize| -> bool {
+            matches!(
+                self.get_node(node_id),
+                Ok(Node::Relational(Relational::Motion { .. }))
+            )
+        };
+        let mut dfs =
+            PostOrderWithFilter::with_capacity(|x| self.nodes.rel_iter(x), 0, Box::new(filter));
+        for (_, motion_id) in dfs.iter(self.get_top()?) {
+            motions_ref_count
+                .entry(motion_id)
+                .and_modify(|cnt| *cnt += 1)
+                .or_insert(1);
+        }
+
         if top_ids.is_empty() {
             return Ok(None);
         }
@@ -338,6 +357,7 @@ impl Plan {
             top_motion_ids: top_ids,
             target_motion_ids: target_motions,
             unused_motions,
+            motions_ref_count,
         }))
     }
 }
@@ -354,14 +374,21 @@ pub fn prepare_rs_to_ir_map(
     let mut rs_ir: HashMap<String, ExecutionPlan> = HashMap::new();
     rs_ir.reserve(rs_bucket_vec.len());
     if let Some((last, other)) = rs_bucket_vec.split_last() {
-        let sae_info = sub_plan.get_ir_plan().serialize_as_empty_info()?;
-        for (rs, bucket_ids) in other {
-            let mut rs_plan = sub_plan.clone();
-            if let Some(ref info) = sae_info {
-                apply_serialize_as_empty_opcode(&mut rs_plan, info)?;
+        let mut sae_info = sub_plan.get_ir_plan().serialize_as_empty_info()?;
+        let mut other_plan = sub_plan.clone();
+
+        if let Some(info) = sae_info.as_mut() {
+            apply_serialize_as_empty_opcode(&mut other_plan, info)?;
+        }
+        if let Some((other_last, other_other)) = other.split_last() {
+            for (rs, bucket_ids) in other_other {
+                let mut rs_plan = other_plan.clone();
+                filter_vtable(&mut rs_plan, bucket_ids)?;
+                rs_ir.insert(rs.clone(), rs_plan);
             }
-            filter_vtable(&mut rs_plan, bucket_ids)?;
-            rs_ir.insert(rs.clone(), rs_plan);
+            let (rs, bucket_ids) = other_last;
+            filter_vtable(&mut other_plan, bucket_ids)?;
+            rs_ir.insert(rs.clone(), other_plan);
         }
 
         if let Some(ref info) = sae_info {
@@ -377,11 +404,21 @@ pub fn prepare_rs_to_ir_map(
 
 fn apply_serialize_as_empty_opcode(
     sub_plan: &mut ExecutionPlan,
-    info: &SerializeAsEmptyInfo,
+    info: &mut SerializeAsEmptyInfo,
 ) -> Result<(), SbroadError> {
     if let Some(vtables_map) = sub_plan.get_mut_vtables() {
-        for motion_id in &info.unused_motions {
-            vtables_map.remove(motion_id);
+        let unused_motions = std::mem::take(&mut info.unused_motions);
+        for motion_id in &unused_motions {
+            let Some(use_count) = info.motions_ref_count.get_mut(motion_id) else {
+                return Err(SbroadError::UnexpectedNumberOfValues(format_smolstr!(
+                    "no ref count for motion={motion_id}"
+                )));
+            };
+            if *use_count > 1 {
+                *use_count -= 1;
+            } else {
+                vtables_map.remove(motion_id);
+            }
         }
     }
 
diff --git a/sbroad-core/src/executor/ir.rs b/sbroad-core/src/executor/ir.rs
index 57c904741..99511fd7e 100644
--- a/sbroad-core/src/executor/ir.rs
+++ b/sbroad-core/src/executor/ir.rs
@@ -105,10 +105,14 @@ impl ExecutionPlan {
                 return Ok(Rc::clone(result));
             }
         }
+        let motion_node = self.get_ir_plan().get_relation_node(motion_id)?;
 
         Err(SbroadError::NotFound(
             Entity::VirtualTable,
-            format_smolstr!("for Motion node ({motion_id})"),
+            format_smolstr!(
+                "for Motion node ({motion_id}): {motion_node:?}. Plan: {:?}",
+                self
+            ),
         ))
     }
 
diff --git a/sbroad-core/src/ir/transformation/redistribution.rs b/sbroad-core/src/ir/transformation/redistribution.rs
index f7ba50d70..eb2a31d7c 100644
--- a/sbroad-core/src/ir/transformation/redistribution.rs
+++ b/sbroad-core/src/ir/transformation/redistribution.rs
@@ -139,20 +139,16 @@ pub enum MotionOpcode {
     AddMissingRowsForLeftJoin {
         motion_id: usize,
     },
-    /// When set to `true` this opcode serializes
-    /// motion subtree to sql that produces
+    /// When set to `true`, this opcode serializes the motion subtree to SQL that produces
+    /// an empty table.
     ///
-    /// Relevant only for Local motion policy.
-    /// Must be initialized to `true` by planner,
-    /// executor garuantees to mark only one replicaset
-    /// which will have `false` value in this opcode.
-    /// For all replicasets that will have `true` value,
-    /// executor will remove all virtual tables used in
-    /// below subtree and unlink the given motion node.
+    /// Relevant only for Local motion policy. Must be initialized to `true` by the planner;
+    /// the executor guarantees to mark only one replicaset which will have `false` value
+    /// in this opcode. For all replicasets that will have `true` value, the executor unlinks
+    /// the sub-trees below the given motion nodes.
     ///
-    /// Note: currently this opcode is only used for
-    /// execution of union all having global child and sharded child.
+    /// Note: currently this opcode is only used for execution of union all having global
+    /// child and sharded child.
     SerializeAsEmptyTable(bool),
     RemoveDuplicates,
 }
-- 
GitLab