Skip to content

fix: Fix EquivalenceClass calculation for Union queries #16185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 129 additions & 1 deletion datafusion/physical-expr/src/equivalence/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ use crate::{
PhysicalSortExpr, PhysicalSortRequirement,
};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{JoinType, ScalarValue};
use datafusion_common::{HashMap, JoinType, ScalarValue};
use datafusion_physical_expr_common::physical_expr::format_physical_expr_list;
use itertools::Itertools;
use std::fmt::Display;
use std::sync::Arc;
use std::vec::IntoIter;
Expand Down Expand Up @@ -422,6 +423,110 @@ impl EquivalenceGroup {
self.bridge_classes()
}

/// Creates a mapping from expressions to their positions in the equivalence group.
///
/// This function builds a HashMap that maps each expression to a tuple containing:
/// 1. A unique index for the expression (based on insertion order)
/// 2. The ID of the equivalence class it belongs to
///
/// # Returns
///
/// A HashMap where:
/// - Key: Reference to a physical expression
/// - Value: A tuple (expr_index, class_id) where:
/// - expr_index: A unique index for the expression (0-based, based on insertion order)
/// - class_id: The index of the equivalence class containing this expression
///
/// # Example
///
/// For an equivalence group with classes:
/// - Class 0: [a, b]
/// - Class 1: [c, d]
///
/// The function returns:
/// {
/// a -> (0, 0),
/// b -> (1, 0),
/// c -> (2, 1),
/// d -> (3, 1)
/// }
fn expressions(&self) -> HashMap<&Arc<dyn PhysicalExpr>, (usize, usize)> {
let mut map = HashMap::new();
for (cls_id, cls) in self.classes.iter().enumerate() {
for expr in cls.exprs.iter() {
map.insert(expr, (map.len(), cls_id));
}
}
map
}

/// Computes the intersection of two equivalence groups.
///
/// This function finds all expressions that are equivalent in both groups by:
/// 1. Creating a mapping of expressions to their positions in the first group
/// 2. For each equivalence class in the second group:
/// - Find all expressions that exist in both groups
/// - Sort them by their position in the first group
/// - Group consecutive expressions that belong to the same equivalence class in the first group
///
/// Computational Complexity: O(NlogN) where N is the total number of expressions in both groups.
/// # Arguments
///
/// * `other` - The other equivalence group to intersect with
///
/// # Returns
///
/// A new equivalence group containing only the expressions that are equivalent
/// in both input groups.
///
/// # Example
///
///
/// Group1: [a, b, c] // a=b=c
/// Group2: [b, c, d] // b=c=d
///
/// Result: [b, c] // b=c in both groups
///
pub fn intersect(&self, other: &Self) -> Self {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise here, can we please add documentation

let self_exprs = self.expressions();

let mut new_classes = Vec::new();
for cls in other.classes.iter() {
let exprs = cls
.exprs
.iter()
.flat_map(|expr| {
self_exprs
.get(expr)
.map(|(expr_id, cls_id)| (expr, *expr_id, *cls_id))
})
.sorted_by_key(|(_, expr_id, _)| *expr_id)
.collect::<Vec<_>>();
let mut start = 0;
if exprs.len() <= 1 {
continue;
}
for i in 0..exprs.len() - 1 {
let cls_id = exprs[i].2;
let next_cls_id = exprs[i + 1].2;
if cls_id != next_cls_id && i > start {
new_classes.push(EquivalenceClass::new(
(start..=i).map(|idx| Arc::clone(exprs[idx].0)).collect(),
));
start = i + 1;
}
}
if exprs.len() > start + 1 {
new_classes.push(EquivalenceClass::new(
(start..exprs.len())
.map(|idx| Arc::clone(exprs[idx].0))
.collect(),
));
}
}
Self::new(new_classes)
}

/// This utility function unifies/bridges classes that have common expressions.
/// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`.
/// Since both classes contain `b`, columns `a`, `b` and `c` are actually all
Expand Down Expand Up @@ -1098,4 +1203,27 @@ mod tests {

Ok(())
}

#[test]
fn test_eq_group_intersect() -> Result<()> {
let eq_group1 = EquivalenceGroup::new(vec![
EquivalenceClass::new(vec![lit(1), lit(2), lit(3)]),
EquivalenceClass::new(vec![lit(5), lit(6), lit(7)]),
]);
let eq_group2 = EquivalenceGroup::new(vec![
EquivalenceClass::new(vec![lit(2), lit(3), lit(4)]),
EquivalenceClass::new(vec![lit(6), lit(7), lit(8)]),
]);
let intersect = eq_group1.intersect(&eq_group2);

assert_eq!(intersect.len(), 2);
for cls in intersect.classes.iter() {
assert_eq!(cls.exprs.len(), 2);
assert!(
(cls.exprs.contains(&lit(2)) && cls.exprs.contains(&lit(3)))
|| (cls.exprs.contains(&lit(6)) && cls.exprs.contains(&lit(7)))
);
}
Ok(())
}
}
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/equivalence/properties/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,12 @@ fn calculate_union_binary(
orderings.add_satisfied_orderings(rhs.normalized_oeq_class(), rhs.constants(), &lhs);
let orderings = orderings.build();

let eq_group = lhs.eq_group().intersect(rhs.eq_group());

let mut eq_properties =
EquivalenceProperties::new(lhs.schema).with_constants(constants);

eq_properties.add_equivalence_group(eq_group);
eq_properties.add_new_orderings(orderings);
Ok(eq_properties)
}
Expand Down
11 changes: 11 additions & 0 deletions datafusion/sqllogictest/test_files/union.slt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ CREATE TABLE t2(
(3, 'John')
;

# union with equivalence class
query TT rowsort
(SELECT name, name as __name FROM t1) UNION ALL (SELECT name, name as __name FROM t1) ORDER BY __name;
----
Alex Alex
Alex Alex
Alice Alice
Alice Alice
Bob Bob
Bob Bob

# union with EXCEPT(JOIN)
query T rowsort
(
Expand Down