Skip to content
Snippets Groups Projects
Commit 7fc3ff62 authored by Arseniy Volynets's avatar Arseniy Volynets :boy_tone5:
Browse files

fix: merge tuple transformation didn't group cols

- merge tuple transformation that merges several
and-ed equalities into equalities of rows didn't
group columns by the child they refer to. This led
to rows where we couldn't find sharding keys,
because they were scattered across different
rows:

```
sk(t1) = (a, b), sk(t2) = (e, f)
... on (t1.a, t2.f) = (t2.e, t1.b)
```

But now correct rows are generated:

```
... on (t1.a, t1.b) = (t2.e, t2.f)
```
parent f84401ec
No related branches found
No related tags found
1 merge request!1414sbroad import
......@@ -2577,7 +2577,7 @@ fn front_sql_groupby_with_aggregates() {
let plan = sql_to_optimized_ir(input, vec![]);
let expected_explain = String::from(
r#"projection ("t1"."a"::unsigned -> "a", "t1"."b"::unsigned -> "b", "t1"."c"::decimal -> "c", "t2"."g"::unsigned -> "g", "t2"."e"::unsigned -> "e", "t2"."f"::decimal -> "f")
join on ROW("t1"."a"::unsigned, "t2"."g"::unsigned) = ROW("t2"."e"::unsigned, "t1"."b"::unsigned)
join on ROW("t1"."a"::unsigned, "t1"."b"::unsigned) = ROW("t2"."e"::unsigned, "t2"."g"::unsigned)
scan "t1"
projection ("column_596"::unsigned -> "a", "column_696"::unsigned -> "b", sum(("sum_1596"::decimal))::decimal -> "c")
group by ("column_596"::unsigned, "column_696"::unsigned) output: ("column_596"::unsigned -> "column_596", "column_696"::unsigned -> "column_696", "sum_1596"::decimal -> "sum_1596")
......@@ -2585,7 +2585,7 @@ fn front_sql_groupby_with_aggregates() {
projection ("t"."a"::unsigned -> "column_596", "t"."b"::unsigned -> "column_696", sum(("t"."c"::unsigned))::decimal -> "sum_1596")
group by ("t"."a"::unsigned, "t"."b"::unsigned) output: ("t"."a"::unsigned -> "a", "t"."b"::unsigned -> "b", "t"."c"::unsigned -> "c", "t"."d"::unsigned -> "d", "t"."bucket_id"::unsigned -> "bucket_id")
scan "t"
motion [policy: full]
motion [policy: segment([ref("e"), ref("g")])]
scan "t2"
projection ("column_2496"::unsigned -> "g", "column_2596"::unsigned -> "e", sum(("sum_3496"::decimal))::decimal -> "f")
group by ("column_2496"::unsigned, "column_2596"::unsigned) output: ("column_2496"::unsigned -> "column_2496", "column_2596"::unsigned -> "column_2596", "sum_3496"::decimal -> "sum_3496")
......@@ -2600,7 +2600,6 @@ execution options:
);
assert_eq!(expected_explain, plan.as_explain().unwrap());
println!("{}", plan.as_explain().unwrap());
}
#[test]
......@@ -4091,7 +4090,6 @@ fn front_select_without_scan_3() {
let metadata = &RouterConfigurationMock::new();
let err = AbstractSyntaxTree::transform_into_plan(input, metadata).unwrap_err();
dbg!(&err);
assert_eq!(
"invalid type: expected a Column in SelectWithoutScan, got Asterisk.",
err.to_string()
......@@ -4105,7 +4103,6 @@ fn front_select_without_scan_4() {
let metadata = &RouterConfigurationMock::new();
let err = AbstractSyntaxTree::transform_into_plan(input, metadata).unwrap_err();
dbg!(&err);
assert_eq!(
"invalid type: expected a Column in SelectWithoutScan, got Distinct.",
err.to_string()
......
......@@ -1388,6 +1388,15 @@ impl Plan {
}
}
/// Look up `node_id` and return the `Reference` it holds.
///
/// Returns `None` when the node cannot be fetched as an expression
/// or when the fetched expression is not a reference.
pub fn get_reference(&self, node_id: NodeId) -> Option<&Reference> {
    match self.get_expression_node(node_id) {
        Ok(Expression::Reference(reference)) => Some(reference),
        _ => None,
    }
}
/// Get mutable expression type node
///
/// # Errors
......
......@@ -18,7 +18,7 @@ use std::hash::BuildHasher;
use super::node::expression::Expression;
use super::node::relational::Relational;
use super::node::{Like, Limit, SelectWithoutScan};
use super::node::{ArithmeticExpr, Like, Limit, SelectWithoutScan};
/// Helper macros to build a hash map or set
/// from the list of arguments.
......@@ -221,7 +221,14 @@ impl Plan {
writeln_with_tabulation(buf, tabulation_number + 1, "Child")?;
self.formatted_arena_node(buf, tabulation_number + 1, *child)?;
}
Expression::Arithmetic(_) => writeln!(buf, "Arithmetic")?,
Expression::Arithmetic(ArithmeticExpr { left, right, op }) => {
writeln!(buf, "Arithmetic: [op: {op}]")?;
writeln_with_tabulation(buf, tabulation_number + 1, "Child")?;
writeln_with_tabulation(buf, tabulation_number + 1, "Left child")?;
self.formatted_arena_node(buf, tabulation_number + 1, *left)?;
writeln_with_tabulation(buf, tabulation_number + 1, "Right child")?;
self.formatted_arena_node(buf, tabulation_number + 1, *right)?;
}
};
}
Ok(())
......
......@@ -13,7 +13,8 @@
use crate::errors::{Entity, SbroadError};
use crate::ir::helpers::RepeatableState;
use crate::ir::node::expression::{Expression, MutExpression};
use crate::ir::node::{Alias, ArithmeticExpr, BoolExpr, NodeId, Row};
use crate::ir::node::relational::Relational;
use crate::ir::node::{Alias, ArithmeticExpr, BoolExpr, NodeId, Reference, Row};
use crate::ir::operator::Bool;
use crate::ir::transformation::OldNewTopIdPair;
use crate::ir::tree::traversal::BreadthFirst;
......@@ -162,17 +163,24 @@ impl Chain {
let mut grouped_top_id: Option<NodeId> = None;
let ordered_ops = &[Bool::Eq, Bool::NotEq];
for op in ordered_ops {
if let Some((left, right)) = self.grouped.get(op) {
let left_row_id = plan.nodes.add_row(left.clone(), None);
let right_row_id = plan.nodes.add_row(right.clone(), None);
let cond_id = plan.add_cond(left_row_id, op.clone(), right_row_id)?;
match grouped_top_id {
None => {
grouped_top_id = Some(cond_id);
}
Some(top_id) => {
grouped_top_id = Some(plan.add_cond(top_id, Bool::And, cond_id)?);
}
let Some((left, right)) = self.grouped.get(op) else {
continue;
};
let cond_id = if *op == Bool::Eq {
if let Some(grouped) = plan.split_join_references(left, right) {
grouped.add_rows_to_plan(plan)?
} else {
add_rows_and_cond(plan, left.clone(), right.clone(), op)?
}
} else {
add_rows_and_cond(plan, left.clone(), right.clone(), op)?
};
match grouped_top_id {
None => {
grouped_top_id = Some(cond_id);
}
Some(top_id) => {
grouped_top_id = Some(plan.add_cond(top_id, Bool::And, cond_id)?);
}
}
}
......@@ -205,7 +213,155 @@ impl Chain {
}
}
/// Wrap `left` and `right` into two new rows and connect them with `op`.
///
/// The left row is added to the plan before the right one.
/// Returns whatever `Plan::add_cond` produces for the pair of rows.
fn add_rows_and_cond(
    plan: &mut Plan,
    left: impl Into<Vec<NodeId>>,
    right: impl Into<Vec<NodeId>>,
    op: &Bool,
) -> Result<NodeId, SbroadError> {
    let left_columns: Vec<NodeId> = left.into();
    let right_columns: Vec<NodeId> = right.into();

    let lhs_row = plan.nodes.add_row(left_columns, None);
    let rhs_row = plan.nodes.add_row(right_columns, None);
    plan.add_cond(lhs_row, op.clone(), rhs_row)
}
/// Column pairs of a row equality `(left) = (right)`, split into two
/// position-aligned groups: pairs that equate a reference to the first
/// join child with a reference to the second join child, and all
/// remaining pairs.
struct GroupedRows {
// Left row that contains references to first join child
join_refs_left: Vec<NodeId>,
// Right row that contains references to second join child
join_refs_right: Vec<NodeId>,
// Left sides of all pairs that did not go into the `join_refs_*` rows
other_left: Vec<NodeId>,
// Right sides of those remaining pairs, aligned with `other_left`
other_right: Vec<NodeId>,
}
impl GroupedRows {
fn add_rows_to_plan(self, plan: &mut Plan) -> Result<NodeId, SbroadError> {
fn add_eq(
left: Vec<NodeId>,
right: Vec<NodeId>,
plan: &mut Plan,
) -> Result<Option<NodeId>, SbroadError> {
debug_assert!(left.len() == right.len());
let res = if !left.is_empty() {
add_rows_and_cond(plan, left, right, &Bool::Eq)?.into()
} else {
None
};
Ok(res)
}
let eq1: Option<NodeId> = add_eq(self.join_refs_left, self.join_refs_right, plan)?;
let eq2: Option<NodeId> = add_eq(self.other_left, self.other_right, plan)?;
let res = match (eq1, eq2) {
(Some(id1), Some(id2)) => plan.add_bool(id1, Bool::And, id2)?,
(None, Some(id)) | (Some(id), None) => id,
(None, None) => panic!("at least some row must be non-empty"),
};
Ok(res)
}
}
impl Plan {
/// Split the equality `(left) = (right)` into two groups of column pairs.
///
/// Returns `None` when neither row contains a reference whose parent
/// is a join node. Otherwise every pair `left[i] = right[i]` is put
/// into exactly one group:
///
/// * join references: both sides are references, one targeting the
///   first child (`[0]`) and the other the second child (`[1]`). The
///   pair is stored so that the first-child reference is on the left.
/// * other: every remaining pair, stored in its original orientation.
///
/// This is done for join conflict resolution: the distribution of rows
/// in the `on` condition is calculated to find an equality on sharding
/// keys. For `(t1.a, t2.d) = (t2.c, t1.b)` with `sk(t1) = (a, b)` and
/// `sk(t2) = (c, d)` matching keys would not be found and Motion(Full)
/// would be inserted, so references are grouped by their child.
fn split_join_references(&self, left: &[NodeId], right: &[NodeId]) -> Option<GroupedRows> {
    // A node qualifies when it is a reference whose parent is a join.
    let refers_to_join = |id: &NodeId| -> bool {
        let Some(Reference {
            parent: Some(parent_id),
            ..
        }) = self.get_reference(*id)
        else {
            return false;
        };
        matches!(
            self.get_relation_node(*parent_id),
            Ok(Relational::Join(_))
        )
    };

    // First check that we are in join at all.
    let inside_join = left.iter().any(refers_to_join) || right.iter().any(refers_to_join);
    if !inside_join {
        return None;
    }

    let first_child = Some(vec![0]);
    let second_child = Some(vec![1]);

    let mut grouped = GroupedRows {
        join_refs_left: Vec::new(),
        join_refs_right: Vec::new(),
        other_left: Vec::new(),
        other_right: Vec::new(),
    };

    for (lhs_id, rhs_id) in left.iter().zip(right.iter()) {
        // `Some((first, second))` only for `Reference = Reference` pairs
        // that connect the two join children; already reordered so that
        // the first-child reference comes first.
        let join_pair = self
            .get_reference(*lhs_id)
            .zip(self.get_reference(*rhs_id))
            .and_then(|(lhs_ref, rhs_ref)| {
                debug_assert!(rhs_ref.parent == lhs_ref.parent);
                if lhs_ref.targets == first_child && rhs_ref.targets == second_child {
                    Some((*lhs_id, *rhs_id))
                } else if lhs_ref.targets == second_child && rhs_ref.targets == first_child {
                    Some((*rhs_id, *lhs_id))
                } else {
                    None
                }
            });

        match join_pair {
            Some((first_id, second_id)) => {
                grouped.join_refs_left.push(first_id);
                grouped.join_refs_right.push(second_id);
            }
            None => {
                grouped.other_left.push(*lhs_id);
                grouped.other_right.push(*rhs_id);
            }
        }
    }

    Some(grouped)
}
fn get_columns_or_self(&self, expr_id: NodeId) -> Result<Vec<NodeId>, SbroadError> {
let expr = self.get_expression_node(expr_id)?;
match expr {
......
......@@ -106,3 +106,47 @@ fn merge_tuples6() {
assert_eq!(check_transformation(input, vec![], &merge_tuples), expected);
}
// Two and-ed equalities between the join children, written with
// opposite orientation, must be merged into one equality of rows.
#[test]
fn merge_tuples7() {
let input = r#"
select "a", "f" from "t" inner join "t2"
on "t"."a" = "t2"."e" and "t2"."f" = "t"."b"
"#;
// merge_tuples must group references to the same table on the same
// side of the equality for join conflict resolution to work
// correctly, otherwise we will get Motion(Full) instead of a
// local join here
let expected = PatternWithParams::new(
format!(
"{} {}",
r#"SELECT "t"."a", "t2"."f" FROM (SELECT "t"."a", "t"."b", "t"."c", "t"."d" FROM "t")"#,
r#"as "t" INNER JOIN (SELECT "t2"."e", "t2"."f", "t2"."g", "t2"."h" FROM "t2") as "t2" ON ("t"."a", "t"."b") = ("t2"."e", "t2"."f")"#,
),
vec![],
);
assert_eq!(check_transformation(input, vec![], &merge_tuples), expected);
}
// Mixes equalities between the two join children with equalities
// inside a single child.
#[test]
fn merge_tuples8() {
let input = r#"
select "a", "f" from "t" inner join "t2"
on "t"."a" = "t"."b" and "t"."a" = "t2"."e" and "t2"."f" = "t"."b" and "t2"."f" = "t2"."e"
"#;
// check that merge tuples creates two groups:
// one with the grouped columns and another with all other equalities
let expected = PatternWithParams::new(
format!(
"{} {} {}",
r#"SELECT "t"."a", "t2"."f" FROM (SELECT "t"."a", "t"."b", "t"."c", "t"."d" FROM "t")"#,
r#"as "t" INNER JOIN (SELECT "t2"."e", "t2"."f", "t2"."g", "t2"."h" FROM "t2") as "t2" ON"#,
r#"("t"."b", "t"."a") = ("t2"."f", "t2"."e") and ("t2"."f", "t"."a") = ("t2"."e", "t"."b")"#,
),
vec![],
);
assert_eq!(check_transformation(input, vec![], &merge_tuples), expected);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment