Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 110 additions & 1 deletion datafusion/physical-plan/src/joins/hash_join/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::joins::hash_join::shared_bounds::{
PartitionBounds, PartitionBuildData, SharedBuildAccumulator,
};
use crate::joins::utils::{
OnceFut, equal_rows_arr, get_final_indices_from_shared_bitmap,
JoinKeyComparator, OnceFut, equal_rows_arr, get_final_indices_from_shared_bitmap,
};
use crate::{
RecordBatchStream, SendableRecordBatchStream, handle_state,
Expand All @@ -49,6 +49,7 @@ use crate::{
use arrow::array::{Array, ArrayRef, UInt32Array, UInt64Array};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow_schema::SortOptions;
use datafusion_common::{
JoinSide, JoinType, NullEquality, Result, internal_datafusion_err, internal_err,
};
Expand Down Expand Up @@ -353,6 +354,89 @@ fn count_distinct_sorted_indices(indices: &UInt32Array) -> usize {
count
}

/// Optimized probe for RightSemi/RightAnti joins (HashMap path, no filter).
///
/// For each probe row, walks the build-side collision chain via
/// [`JoinHashMapType::get_first_match`] and stops at the first entry passing
/// equality (verified by [`JoinKeyComparator::is_equal`], zero-alloc per row).
///
/// - **RightSemi**: collects probe rows where at least one match was found.
/// - **RightAnti**: collects probe rows where no match was found.
///
/// # Example
///
/// ```text
/// Build side (customers):
/// Row 0: {id: 10} Row 1: {id: 20} Row 2: {id: 10}
/// Row 3: {id: 30} Row 4: {id: 10} Row 5: {id: 40}
///
/// Hash(10) chain: row 4 → row 2 → row 0 → end (3 entries)
///
/// Probe batch (orders):
/// Row 0: {customer_id: 10} Row 1: {customer_id: 99}
/// Row 2: {customer_id: 10} Row 3: {customer_id: 30}
///
/// RightSemi: output probe rows that have at least one match
/// Probe 0: get_first_match(Hash(10)) → row 4, eq check ✓ → matched
/// Probe 1: get_first_match(Hash(99)) → None → not matched
/// Probe 2: get_first_match(Hash(10)) → row 4, eq check ✓ → matched
/// Probe 3: get_first_match(Hash(30)) → row 3, eq check ✓ → matched
/// Output: probe rows [0, 2, 3]
///
/// RightAnti: output probe rows with NO match (inverted)
/// Output: probe row [1]
/// ```
#[expect(clippy::too_many_arguments)]
fn process_probe_batch_right_semi_anti(
probe_batch: &RecordBatch,
probe_values: &[ArrayRef],
hashes_buffer: &[u64],
build_values: &[ArrayRef],
build_batch: &RecordBatch,
map: &dyn JoinHashMapType,
join_type: JoinType,
null_equality: NullEquality,
output_schema: &Schema,
column_indices: &[ColumnIndex],
) -> Result<RecordBatch> {
let num_probe_rows = probe_batch.num_rows();
let is_semi = join_type == JoinType::RightSemi;

let sort_options = vec![SortOptions::default(); build_values.len()];
let comparator =
JoinKeyComparator::new(build_values, probe_values, &sort_options, null_equality)?;

let mut matched_probe_indices: Vec<u32> = Vec::new();

for (probe_row, &hash) in hashes_buffer.iter().enumerate().take(num_probe_rows) {
let found = map
.get_first_match(hash, &mut |build_row| {
comparator.is_equal(build_row as usize, probe_row)
})
.is_some();

if found == is_semi {
matched_probe_indices.push(probe_row as u32);
}
}

let probe_indices = UInt32Array::from(matched_probe_indices);
// Build-side indices unused for RightSemi/Anti output,
// but build_batch_from_indices requires them.
let build_indices = UInt64Array::from(vec![0u64; probe_indices.len()]);

build_batch_from_indices(
output_schema,
build_batch,
probe_batch,
&build_indices,
&probe_indices,
column_indices,
JoinSide::Left,
join_type,
)
}

impl HashJoinStream {
#[expect(clippy::too_many_arguments)]
pub(super) fn new(
Expand Down Expand Up @@ -683,6 +767,31 @@ impl HashJoinStream {
return Ok(StatefulStreamResult::Continue);
}

// Optimized path for RightSemi/RightAnti joins without filter on HashMap:
// Uses get_first_match to stop chain traversal after first verified match.
if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti)
&& self.filter.is_none()
&& !self.null_aware
&& let Map::HashMap(map) = build_side.left_data.map()
{
let result = process_probe_batch_right_semi_anti(
&state.batch,
&state.values,
&self.hashes_buffer,
build_side.left_data.values(),
build_side.left_data.batch(),
map.as_ref(),
self.join_type,
self.null_equality,
&self.schema,
&self.column_indices,
)?;
timer.done();
self.output_buffer.push_batch(result)?;
self.state = HashJoinStreamState::FetchProbeBatch;
return Ok(StatefulStreamResult::Continue);
}

// get the matched by join keys indices
let (left_indices, right_indices, next_offset) = match build_side.left_data.map()
{
Expand Down
70 changes: 70 additions & 0 deletions datafusion/physical-plan/src/joins/join_hash_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,31 @@ pub trait JoinHashMapType: Send + Sync {

/// Returns the number of entries in the join hash map.
fn len(&self) -> usize;

/// Returns the first build-side row index in the collision chain for `hash_value`
/// where `predicate` returns true. Returns `None` if no entry matches.
///
/// Walks the LIFO chain from head to tail, calling `predicate(build_row_index)`
/// on each entry. Stops and returns as soon as one returns true.
///
/// Used by RightSemi/RightAnti joins to find the first equality match per probe
/// row without enumerating the full chain.
///
/// # Example
///
/// ```text
/// Hash(10) chain: row 4 → row 2 → row 0 → end
///
/// get_first_match(Hash(10), |row| row == 2)
/// row 4: predicate(4) → false
/// row 2: predicate(2) → true → returns Some(2)
/// (row 0 never visited)
/// ```
fn get_first_match(
&self,
hash_value: u64,
predicate: &mut dyn FnMut(u64) -> bool,
) -> Option<u64>;
}

pub struct JoinHashMapU32 {
Expand Down Expand Up @@ -212,6 +237,14 @@ impl JoinHashMapType for JoinHashMapU32 {
fn len(&self) -> usize {
self.map.len()
}

fn get_first_match(
&self,
hash_value: u64,
predicate: &mut dyn FnMut(u64) -> bool,
) -> Option<u64> {
get_first_match_impl::<u32>(&self.map, &self.next, hash_value, predicate)
}
}

pub struct JoinHashMapU64 {
Expand Down Expand Up @@ -290,11 +323,48 @@ impl JoinHashMapType for JoinHashMapU64 {
fn len(&self) -> usize {
self.map.len()
}

fn get_first_match(
&self,
hash_value: u64,
predicate: &mut dyn FnMut(u64) -> bool,
) -> Option<u64> {
get_first_match_impl::<u64>(&self.map, &self.next, hash_value, predicate)
}
}

use crate::joins::MapOffset;
use crate::joins::chain::traverse_chain;

/// Returns the first build-side row in the collision chain where `predicate` is true.
pub fn get_first_match_impl<T>(
map: &HashTable<(u64, T)>,
next_chain: &[T],
hash_value: u64,
predicate: &mut dyn FnMut(u64) -> bool,
) -> Option<u64>
where
T: Copy + TryFrom<usize> + PartialOrd + Into<u64> + Sub<Output = T>,
<T as TryFrom<usize>>::Error: Debug,
{
let zero = T::try_from(0).unwrap();
let one = T::try_from(1).unwrap();

if let Some((_, idx)) = map.find(hash_value, |(h, _)| hash_value == *h) {
let mut i = *idx - one;
loop {
if predicate(i.into()) {
return Some(i.into());
}
let next = next_chain[i.into() as usize];
if next == zero {
return None;
}
i = next - one;
}
}
None
}
pub fn update_from_iter<'a, T>(
map: &mut HashTable<(u64, T)>,
next: &mut [T],
Expand Down
13 changes: 11 additions & 2 deletions datafusion/physical-plan/src/joins/stream_join_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ use std::sync::Arc;

use crate::joins::MapOffset;
use crate::joins::join_hash_map::{
contain_hashes, get_matched_indices, get_matched_indices_with_limit_offset,
update_from_iter,
contain_hashes, get_first_match_impl, get_matched_indices,
get_matched_indices_with_limit_offset, update_from_iter,
};
use crate::joins::utils::{JoinFilter, JoinHashMapType};
use crate::metrics::{
Expand Down Expand Up @@ -109,6 +109,15 @@ impl JoinHashMapType for PruningJoinHashMap {
fn len(&self) -> usize {
self.map.len()
}

fn get_first_match(
&self,
hash_value: u64,
predicate: &mut dyn FnMut(u64) -> bool,
) -> Option<u64> {
let next: Vec<u64> = self.next.iter().copied().collect();
get_first_match_impl::<u64>(&self.map, &next, hash_value, predicate)
}
}

/// The `PruningJoinHashMap` is similar to a regular `JoinHashMap`, but with
Expand Down
Loading