Skip to content
Snippets Groups Projects
Commit b2833f6a authored by Arseniy Volynets's avatar Arseniy Volynets :boy_tone5:
Browse files

fix: incorrect equivalence classes

- `propagate_equality` transformation did not
compute equality classes correctly, its 'merge'
function was completely wrong: it tried to add
intersection of classes to a another class,
instead of doing union
- to merge classes correctly we must do it
when we add a new pair of equal expressions:
otherwise later there will too many classes
that contain common elements, so 'merge'
function was removed and 'insert' now merges
two classes that contain common elements
- Also this logic is now covered by tests
parent 29b7bd56
No related branches found
No related tags found
1 merge request!1414sbroad import
......@@ -205,7 +205,7 @@ impl EqClassExpr {
}
/// A set of expressions that are equal to each other.
#[derive(Clone, PartialEq, Debug)]
#[derive(Clone, PartialEq, Debug, Default)]
struct EqClass {
set: HashSet<EqClassExpr, RepeatableState>,
}
......@@ -235,6 +235,7 @@ impl EqClass {
#[derive(Clone, PartialEq, Debug)]
struct EqClassChain {
/// Groups of equivalence classes of the "AND"-e chain (all element are equal to each other).
/// INVARIANT: this is REAL equivalence classes, meaning that they don't intersect.
list: Vec<EqClass>,
/// A set of equalities where both sides are references.
pairs: HashSet<EqClassExpr, RepeatableState>,
......@@ -250,28 +251,64 @@ impl EqClassChain {
/// Insert a new pair to the equality classes chain.
fn insert(&mut self, left: &EqClassExpr, right: &EqClassExpr) {
let mut ok = false;
// Insert a pair to all equality classes that have equal element.
// If one of the sides doesn't satisfy self equivalence, produce
// a new equality class.
for class in &mut self.list {
if (class.set.contains(left) || class.set.contains(right))
&& left.is_self_equivalent()
&& right.is_self_equivalent()
{
class.set.insert(left.clone());
class.set.insert(right.clone());
ok = true;
// a = null and b = null, such pairs can't be used to derive
// equality relation: `a = null and a = c`.
// TODO: `a = null` is always false, so is the whole and-chain,
// in the future we should land an optimization
// that simplifies boolean expressions and does NOT apply optimizations
// like `equality propagation` to such and-chains.
// https://git.picodata.io/picodata/picodata/sbroad/-/issues/855
if !left.is_self_equivalent() || !right.is_self_equivalent() {
return;
}
// Find indexes of classes in which these expressions appear,
// as we maintain an invariant that classes do not intersect
// there may be at most 2 of such classes.
let mut classes = [0, 0];
let mut found_classes = 0;
for (pos, class) in self.list.iter().enumerate() {
if class.set.contains(left) || class.set.contains(right) {
classes[found_classes] = pos;
found_classes += 1;
if found_classes == 2 {
break;
}
}
}
// No matches, so add a new equality class.
if !ok {
let mut class = EqClass::new();
class.set.insert(left.clone());
class.set.insert(right.clone());
self.list.push(class);
match found_classes {
0 => {
// No matches, so add a new equality class.
let mut class = EqClass::new();
class.set.insert(left.clone());
class.set.insert(right.clone());
self.list.push(class);
}
1 => {
// one expression is already bound to class
// another one is not.
self.list[classes[0]].set.insert(left.clone());
self.list[classes[0]].set.insert(right.clone());
}
2 => {
// two expressions appear in the different classes
// merge two classes into one
// TODO: this works in linear time, we could do
// better using Disjoint-set-union. Doing merges
// in O(1).
let mut i = classes[0];
let mut j = classes[1];
if self.list[i].set.len() < self.list[j].set.len() {
(i, j) = (j, i);
}
// Merge the smaller set (j) into bigger set (i)
let smaller = std::mem::take(self.list.get_mut(j).unwrap());
self.list.get_mut(i).unwrap().set.extend(smaller.set);
self.list.swap_remove(j);
}
_ => unreachable!("see break condition in loop above"),
}
// If both sides are references, add them to the pairs set.
......@@ -281,61 +318,6 @@ impl EqClassChain {
}
}
/// Merge equality classes in the chain if they contain common elements.
/// The only exception - we do not merge equality classes containing "NULL"
/// as `NULL != NULL` (the result is "NULL" itself).
fn merge(&self) -> Self {
let mut result = EqClassChain::new();
result.pairs.clone_from(&self.pairs);
// A set of indexes of the equality classes in the chain
// that contain common elements and should not be reinspected.
let mut matched: HashSet<usize> = HashSet::new();
for i in 0..self.list.len() {
if matched.contains(&i) {
continue;
}
if let Some(class) = self.list.get(i) {
let mut new_class = class.clone();
matched.insert(i);
for j in i..self.list.len() {
if matched.contains(&j) {
continue;
}
if let Some(item) = self.list.get(j) {
let mut buf = EqClass::new();
let mut is_self_equivalent = true;
for expr in item.set.intersection(&new_class.set) {
if expr.is_self_equivalent() {
is_self_equivalent = false;
buf = EqClass::new();
break;
}
buf.set.insert(expr.clone());
}
if !buf.set.is_empty() && is_self_equivalent {
matched.insert(j);
}
for expr in buf.set {
new_class.set.insert(expr);
}
}
}
result.list.push(new_class);
}
}
result
}
fn subtract_pairs(&self) -> Self {
let mut result = EqClassChain::new();
result.pairs.clone_from(&self.pairs);
......@@ -394,7 +376,7 @@ impl Chain {
eq_classes.insert(&left_eqe?, &right_eqe?);
}
let ecs = eq_classes.merge().subtract_pairs();
let ecs = eq_classes.subtract_pairs();
for ec in &ecs.list {
// Do not generate new equalities from a empty or single element lists.
......
use crate::backend::sql::ir::PatternWithParams;
use std::collections::HashMap;
use crate::collection;
use crate::ir::relation::Type;
use crate::ir::transformation::helpers::check_transformation;
use crate::ir::value::Value;
use crate::ir::Plan;
use crate::{backend::sql::ir::PatternWithParams, ir::node::NodeId};
use pretty_assertions::assert_eq;
use super::{EqClass, EqClassChain, EqClassConst, EqClassExpr, EqClassRef};
fn derive_equalities(plan: &mut Plan) {
plan.derive_equalities().unwrap();
}
......@@ -123,3 +129,108 @@ fn equality_propagation5() {
expected
);
}
#[derive(Default)]
struct ColumnBuilder {
next_pos: usize,
name_to_pos: HashMap<&'static str, usize>,
}
impl ColumnBuilder {
fn make_test_column(&mut self, name: &'static str) -> super::EqClassExpr {
// assuming all columns refer to the same relational node,
// different name means different position
let position = *self.name_to_pos.entry(name).or_insert_with(|| {
let p = self.next_pos;
self.next_pos += 1;
p
});
EqClassExpr::EqClassRef(EqClassRef {
targets: Some(vec![0]),
position,
parent: Some(NodeId {
offset: 0,
arena_type: crate::ir::node::ArenaType::Arena64,
}),
col_type: Type::Integer,
asterisk_source: None,
})
}
}
fn make_const(value: usize) -> EqClassExpr {
EqClassExpr::EqClassConst(EqClassConst {
value: Value::Unsigned(value as u64),
})
}
#[test]
fn equality_classes() {
let mut builder = ColumnBuilder::default();
let mut eqcs = EqClassChain::new();
let cola = builder.make_test_column("a");
let val1 = make_const(1);
let colb = builder.make_test_column("b");
let colc = builder.make_test_column("c");
let cold = builder.make_test_column("d");
// { a, b, 1}
eqcs.insert(&cola, &val1);
eqcs.insert(&colb, &val1);
assert_eq!(
eqcs.list,
vec![EqClass {
set: collection!(cola.clone(), colb.clone(), val1.clone())
}]
);
// { a, b, 1}, {c, d}
eqcs.insert(&colc, &cold);
assert_eq!(
eqcs.list,
vec![
EqClass {
set: collection!(cola.clone(), colb.clone(), val1.clone())
},
EqClass {
set: collection!(colc.clone(), cold.clone())
}
]
);
// { a, b, 1, c, d}
eqcs.insert(&colc, &val1);
let expected = vec![EqClass {
set: collection!(
cola.clone(),
colb.clone(),
val1.clone(),
colc.clone(),
cold.clone()
),
}];
assert_eq!(eqcs.list, expected);
// test we don't create equality classes with nulls
// as it's useless
let null = EqClassExpr::EqClassConst(EqClassConst { value: Value::Null });
eqcs.insert(&cola, &null);
assert_eq!(eqcs.list, expected);
// we used only c = d, so substruct pairs
// should return {a, b}
// note: we don't need {a, b, 1}, because
// a = 1 and b = 1
// is already present in expression
let substructed = eqcs.subtract_pairs();
assert_eq!(
substructed.list,
vec![EqClass {
set: collection!(cola.clone(), colb.clone())
},]
);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment