diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index 1004fba3d4f45..2d5320b4a5e23 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -33,7 +33,7 @@ use crate::joins::hash_join::shared_bounds::{ PartitionBounds, PartitionBuildData, SharedBuildAccumulator, }; use crate::joins::utils::{ - OnceFut, equal_rows_arr, get_final_indices_from_shared_bitmap, + JoinKeyComparator, OnceFut, equal_rows_arr, get_final_indices_from_shared_bitmap, }; use crate::{ RecordBatchStream, SendableRecordBatchStream, handle_state, @@ -49,6 +49,7 @@ use crate::{ use arrow::array::{Array, ArrayRef, UInt32Array, UInt64Array}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_schema::SortOptions; use datafusion_common::{ JoinSide, JoinType, NullEquality, Result, internal_datafusion_err, internal_err, }; @@ -353,6 +354,89 @@ fn count_distinct_sorted_indices(indices: &UInt32Array) -> usize { count } +/// Optimized probe for RightSemi/RightAnti joins (HashMap path, no filter). +/// +/// For each probe row, walks the build-side collision chain via +/// [`JoinHashMapType::get_first_match`] and stops at the first entry passing +/// equality (verified by [`JoinKeyComparator::is_equal`], zero-alloc per row). +/// +/// - **RightSemi**: collects probe rows where at least one match was found. +/// - **RightAnti**: collects probe rows where no match was found. +/// +/// # Example +/// +/// ```text +/// Build side (customers): +/// Row 0: {id: 10} Row 1: {id: 20} Row 2: {id: 10} +/// Row 3: {id: 30} Row 4: {id: 10} Row 5: {id: 40} +/// +/// Hash(10) chain: row 4 → row 2 → row 0 → end (3 entries) +/// +/// Probe batch (orders): +/// Row 0: {customer_id: 10} Row 1: {customer_id: 99} +/// Row 2: {customer_id: 10} Row 3: {customer_id: 30} +/// +/// RightSemi: output probe rows that have at least one match +/// Probe 0: get_first_match(Hash(10)) → row 4, eq check ✓ → matched +/// Probe 1: get_first_match(Hash(99)) → None → not matched +/// Probe 2: get_first_match(Hash(10)) → row 4, eq check ✓ → matched +/// Probe 3: get_first_match(Hash(30)) → row 3, eq check ✓ → matched +/// Output: probe rows [0, 2, 3] +/// +/// RightAnti: output probe rows with NO match (inverted) +/// Output: probe row [1] +/// ``` +#[expect(clippy::too_many_arguments)] +fn process_probe_batch_right_semi_anti( + probe_batch: &RecordBatch, + probe_values: &[ArrayRef], + hashes_buffer: &[u64], + build_values: &[ArrayRef], + build_batch: &RecordBatch, + map: &dyn JoinHashMapType, + join_type: JoinType, + null_equality: NullEquality, + output_schema: &Schema, + column_indices: &[ColumnIndex], +) -> Result { + let num_probe_rows = probe_batch.num_rows(); + let is_semi = join_type == JoinType::RightSemi; + + let sort_options = vec![SortOptions::default(); build_values.len()]; + let comparator = + JoinKeyComparator::new(build_values, probe_values, &sort_options, null_equality)?; + + let mut matched_probe_indices: Vec = Vec::new(); + + for (probe_row, &hash) in hashes_buffer.iter().enumerate().take(num_probe_rows) { + let found = map + .get_first_match(hash, &mut |build_row| { + comparator.is_equal(build_row as usize, probe_row) + }) + .is_some(); + + if found == is_semi { + matched_probe_indices.push(probe_row as u32); + } + } + + let probe_indices = UInt32Array::from(matched_probe_indices); + // Build-side indices unused for RightSemi/Anti output, + // but build_batch_from_indices requires them. + let build_indices = UInt64Array::from(vec![0u64; probe_indices.len()]); + + build_batch_from_indices( + output_schema, + build_batch, + probe_batch, + &build_indices, + &probe_indices, + column_indices, + JoinSide::Left, + join_type, + ) +} + impl HashJoinStream { #[expect(clippy::too_many_arguments)] pub(super) fn new( @@ -683,6 +767,31 @@ impl HashJoinStream { return Ok(StatefulStreamResult::Continue); } + // Optimized path for RightSemi/RightAnti joins without filter on HashMap: + // Uses get_first_match to stop chain traversal after first verified match. + if matches!(self.join_type, JoinType::RightSemi | JoinType::RightAnti) + && self.filter.is_none() + && !self.null_aware + && let Map::HashMap(map) = build_side.left_data.map() + { + let result = process_probe_batch_right_semi_anti( + &state.batch, + &state.values, + &self.hashes_buffer, + build_side.left_data.values(), + build_side.left_data.batch(), + map.as_ref(), + self.join_type, + self.null_equality, + &self.schema, + &self.column_indices, + )?; + timer.done(); + self.output_buffer.push_batch(result)?; + self.state = HashJoinStreamState::FetchProbeBatch; + return Ok(StatefulStreamResult::Continue); + } + // get the matched by join keys indices let (left_indices, right_indices, next_offset) = match build_side.left_data.map() { diff --git a/datafusion/physical-plan/src/joins/join_hash_map.rs b/datafusion/physical-plan/src/joins/join_hash_map.rs index 8f0fb66b64fbf..c1e608fef65b5 100644 --- a/datafusion/physical-plan/src/joins/join_hash_map.rs +++ b/datafusion/physical-plan/src/joins/join_hash_map.rs @@ -134,6 +134,31 @@ pub trait JoinHashMapType: Send + Sync { /// Returns the number of entries in the join hash map. fn len(&self) -> usize; + + /// Returns the first build-side row index in the collision chain for `hash_value` + /// where `predicate` returns true. Returns `None` if no entry matches. + /// + /// Walks the LIFO chain from head to tail, calling `predicate(build_row_index)` + /// on each entry. Stops and returns as soon as one returns true. + /// + /// Used by RightSemi/RightAnti joins to find the first equality match per probe + /// row without enumerating the full chain. + /// + /// # Example + /// + /// ```text + /// Hash(10) chain: row 4 → row 2 → row 0 → end + /// + /// get_first_match(Hash(10), |row| row == 2) + /// row 4: predicate(4) → false + /// row 2: predicate(2) → true → returns Some(2) + /// (row 0 never visited) + /// ``` + fn get_first_match( + &self, + hash_value: u64, + predicate: &mut dyn FnMut(u64) -> bool, + ) -> Option; } pub struct JoinHashMapU32 { @@ -212,6 +237,14 @@ impl JoinHashMapType for JoinHashMapU32 { fn len(&self) -> usize { self.map.len() } + + fn get_first_match( + &self, + hash_value: u64, + predicate: &mut dyn FnMut(u64) -> bool, + ) -> Option { + get_first_match_impl::(&self.map, &self.next, hash_value, predicate) + } } pub struct JoinHashMapU64 { @@ -290,11 +323,48 @@ impl JoinHashMapType for JoinHashMapU64 { fn len(&self) -> usize { self.map.len() } + + fn get_first_match( + &self, + hash_value: u64, + predicate: &mut dyn FnMut(u64) -> bool, + ) -> Option { + get_first_match_impl::(&self.map, &self.next, hash_value, predicate) + } } use crate::joins::MapOffset; use crate::joins::chain::traverse_chain; +/// Returns the first build-side row in the collision chain where `predicate` is true. +pub fn get_first_match_impl( + map: &HashTable<(u64, T)>, + next_chain: &[T], + hash_value: u64, + predicate: &mut dyn FnMut(u64) -> bool, +) -> Option +where + T: Copy + TryFrom + PartialOrd + Into + Sub, + >::Error: Debug, +{ + let zero = T::try_from(0).unwrap(); + let one = T::try_from(1).unwrap(); + + if let Some((_, idx)) = map.find(hash_value, |(h, _)| hash_value == *h) { + let mut i = *idx - one; + loop { + if predicate(i.into()) { + return Some(i.into()); + } + let next = next_chain[i.into() as usize]; + if next == zero { + return None; + } + i = next - one; + } + } + None +} pub fn update_from_iter<'a, T>( map: &mut HashTable<(u64, T)>, next: &mut [T], diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 571c199abb448..4304df6106480 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -24,8 +24,8 @@ use std::sync::Arc; use crate::joins::MapOffset; use crate::joins::join_hash_map::{ - contain_hashes, get_matched_indices, get_matched_indices_with_limit_offset, - update_from_iter, + contain_hashes, get_first_match_impl, get_matched_indices, + get_matched_indices_with_limit_offset, update_from_iter, }; use crate::joins::utils::{JoinFilter, JoinHashMapType}; use crate::metrics::{ @@ -109,6 +109,15 @@ impl JoinHashMapType for PruningJoinHashMap { fn len(&self) -> usize { self.map.len() } + + fn get_first_match( + &self, + hash_value: u64, + predicate: &mut dyn FnMut(u64) -> bool, + ) -> Option { + let next: Vec = self.next.iter().copied().collect(); + get_first_match_impl::(&self.map, &next, hash_value, predicate) + } } /// The `PruningJoinHashMap` is similar to a regular `JoinHashMap`, but with