diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 2377edf9375cf..ede493dcf17d5 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -26,29 +26,23 @@ use crate::physical_expr::physical_exprs_bag_equal; use arrow::array::*; use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::compute::SortOptions; use arrow::compute::kernels::boolean::{not, or_kleene}; use arrow::compute::kernels::cmp::eq as arrow_eq; -use arrow::compute::{SortOptions, take}; use arrow::datatypes::*; -use arrow::util::bit_iterator::BitIndexIterator; -use datafusion_common::hash_utils::with_hashes; + use datafusion_common::{ - DFSchema, HashSet, Result, ScalarValue, assert_or_internal_err, exec_datafusion_err, - exec_err, + DFSchema, Result, ScalarValue, assert_or_internal_err, exec_err, }; use datafusion_expr::{ColumnarValue, expr_vec_fmt}; -use datafusion_common::HashMap; -use datafusion_common::hash_utils::RandomState; -use hashbrown::hash_map::RawEntryMut; - -/// Trait for InList static filters -trait StaticFilter { - fn null_count(&self) -> usize; +mod array_static_filter; +mod primitive_filter; +mod static_filter; +mod strategy; - /// Checks if values in `v` are contained in the filter - fn contains(&self, v: &dyn Array, negated: bool) -> Result; -} +use static_filter::StaticFilter; +use strategy::instantiate_static_filter; /// InList pub struct InListExpr { @@ -68,83 +62,6 @@ impl Debug for InListExpr { } } -/// Static filter for InList that stores the array and hash set for O(1) lookups -#[derive(Debug, Clone)] -struct ArrayStaticFilter { - in_array: ArrayRef, - state: RandomState, - /// Used to provide a lookup from value to in list index - /// - /// Note: usize::hash is not used, instead the raw entry - /// API is used to store entries w.r.t their value - map: HashMap, -} - -impl StaticFilter for ArrayStaticFilter { - fn null_count(&self) -> usize { - self.in_array.null_count() - } - - /// Checks if values in `v` are contained in the `in_array` using this hash set for lookup. - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Null type comparisons always return null (SQL three-valued logic) - if v.data_type() == &DataType::Null - || self.in_array.data_type() == &DataType::Null - { - let nulls = NullBuffer::new_null(v.len()); - return Ok(BooleanArray::new( - BooleanBuffer::new_unset(v.len()), - Some(nulls), - )); - } - - // Unwrap dictionary-encoded needles when the value type matches - // in_array, evaluating against the dictionary values and mapping - // back via keys. - downcast_dictionary_array! { - v => { - // Only unwrap when the haystack (in_array) type matches - // the dictionary value type - if v.values().data_type() == self.in_array.data_type() { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())); - } - } - _ => {} - } - - let needle_nulls = v.logical_nulls(); - let needle_nulls = needle_nulls.as_ref(); - let haystack_has_nulls = self.in_array.null_count() != 0; - - with_hashes([v], &self.state, |hashes| { - let cmp = make_comparator(v, &self.in_array, SortOptions::default())?; - Ok((0..v.len()) - .map(|i| { - // SQL three-valued logic: null IN (...) is always null - if needle_nulls.is_some_and(|nulls| nulls.is_null(i)) { - return None; - } - - let hash = hashes[i]; - let contains = self - .map - .raw_entry() - .from_hash(hash, |idx| cmp(i, *idx).is_eq()) - .is_some(); - - match contains { - true => Some(!negated), - false if haystack_has_nulls => None, - false => Some(negated), - } - }) - .collect()) - }) - } -} - /// Returns true if Arrow's vectorized `eq` kernel supports this data type. /// /// Supported: primitives, boolean, strings (Utf8/LargeUtf8/Utf8View), @@ -160,400 +77,6 @@ fn supports_arrow_eq(dt: &DataType) -> bool { } } -fn instantiate_static_filter( - in_array: ArrayRef, -) -> Result> { - match in_array.data_type() { - // Integer primitive types - DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)), - DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)), - DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), - DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), - DataType::UInt8 => Ok(Arc::new(UInt8StaticFilter::try_new(&in_array)?)), - DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)), - DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), - DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), - // Float primitive types (use ordered wrappers for Hash/Eq) - DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)), - DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)), - _ => { - /* fall through to generic implementation for unsupported types (Struct, etc.) */ - Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) - } - } -} - -impl ArrayStaticFilter { - /// Computes a [`StaticFilter`] for the provided [`Array`] if there - /// are nulls present or there are more than the configured number of - /// elements. - /// - /// Note: This is split into a separate function as higher-rank trait bounds currently - /// cause type inference to misbehave - fn try_new(in_array: ArrayRef) -> Result { - // Null type has no natural order - return empty hash set - if in_array.data_type() == &DataType::Null { - return Ok(ArrayStaticFilter { - in_array, - state: RandomState::default(), - map: HashMap::with_hasher(()), - }); - } - - let state = RandomState::default(); - let mut map: HashMap = HashMap::with_hasher(()); - - with_hashes([&in_array], &state, |hashes| -> Result<()> { - let cmp = make_comparator(&in_array, &in_array, SortOptions::default())?; - - let insert_value = |idx| { - let hash = hashes[idx]; - if let RawEntryMut::Vacant(v) = map - .raw_entry_mut() - .from_hash(hash, |x| cmp(*x, idx).is_eq()) - { - v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); - } - }; - - match in_array.nulls() { - Some(nulls) => { - BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) - .for_each(insert_value) - } - None => (0..in_array.len()).for_each(insert_value), - } - - Ok(()) - })?; - - Ok(Self { - in_array, - state, - map, - }) - } -} - -/// Wrapper for f32 that implements Hash and Eq using bit comparison. -/// This treats NaN values as equal to each other when they have the same bit pattern. -#[derive(Clone, Copy)] -struct OrderedFloat32(f32); - -impl Hash for OrderedFloat32 { - fn hash(&self, state: &mut H) { - self.0.to_ne_bytes().hash(state); - } -} - -impl PartialEq for OrderedFloat32 { - fn eq(&self, other: &Self) -> bool { - self.0.to_bits() == other.0.to_bits() - } -} - -impl Eq for OrderedFloat32 {} - -impl From for OrderedFloat32 { - fn from(v: f32) -> Self { - Self(v) - } -} - -/// Wrapper for f64 that implements Hash and Eq using bit comparison. -/// This treats NaN values as equal to each other when they have the same bit pattern. -#[derive(Clone, Copy)] -struct OrderedFloat64(f64); - -impl Hash for OrderedFloat64 { - fn hash(&self, state: &mut H) { - self.0.to_ne_bytes().hash(state); - } -} - -impl PartialEq for OrderedFloat64 { - fn eq(&self, other: &Self) -> bool { - self.0.to_bits() == other.0.to_bits() - } -} - -impl Eq for OrderedFloat64 {} - -impl From for OrderedFloat64 { - fn from(v: f64) -> Self { - Self(v) - } -} - -// Macro to generate specialized StaticFilter implementations for primitive types -macro_rules! primitive_static_filter { - ($Name:ident, $ArrowType:ty) => { - struct $Name { - null_count: usize, - values: HashSet<<$ArrowType as ArrowPrimitiveType>::Native>, - } - - impl $Name { - fn try_new(in_array: &ArrayRef) -> Result { - let in_array = in_array - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; - - let mut values = HashSet::with_capacity(in_array.len()); - let null_count = in_array.null_count(); - - for v in in_array.iter().flatten() { - values.insert(v); - } - - Ok(Self { null_count, values }) - } - } - - impl StaticFilter for $Name { - fn null_count(&self) -> usize { - self.null_count - } - - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Handle dictionary arrays by recursing on the values - downcast_dictionary_array! { - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} - } - - let v = v - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; - - let haystack_has_nulls = self.null_count > 0; - - let needle_values = v.values(); - let needle_nulls = v.nulls(); - let needle_has_nulls = v.null_count() > 0; - - // Truth table for `value [NOT] IN (set)` with SQL three-valued logic: - // ("-" means the value doesn't affect the result) - // - // | needle_null | haystack_null | negated | in set? | result | - // |-------------|---------------|---------|---------|--------| - // | true | - | false | - | null | - // | true | - | true | - | null | - // | false | true | false | yes | true | - // | false | true | false | no | null | - // | false | true | true | yes | false | - // | false | true | true | no | null | - // | false | false | false | yes | true | - // | false | false | false | no | false | - // | false | false | true | yes | false | - // | false | false | true | no | true | - - // Compute the "contains" result using collect_bool (fast batched approach) - // This ignores nulls - we handle them separately - let contains_buffer = if negated { - BooleanBuffer::collect_bool(needle_values.len(), |i| { - !self.values.contains(&needle_values[i]) - }) - } else { - BooleanBuffer::collect_bool(needle_values.len(), |i| { - self.values.contains(&needle_values[i]) - }) - }; - - // Compute the null mask - // Output is null when: - // 1. needle value is null, OR - // 2. needle value is not in set AND haystack has nulls - let result_nulls = match (needle_has_nulls, haystack_has_nulls) { - (false, false) => { - // No nulls anywhere - None - } - (true, false) => { - // Only needle has nulls - just use needle's null mask - needle_nulls.cloned() - } - (false, true) => { - // Only haystack has nulls - result is null when value not in set - // Valid (not null) when original "in set" is true - // For NOT IN: contains_buffer = !original, so validity = !contains_buffer - let validity = if negated { - !&contains_buffer - } else { - contains_buffer.clone() - }; - Some(NullBuffer::new(validity)) - } - (true, true) => { - // Both have nulls - combine needle nulls with haystack-induced nulls - let needle_validity = needle_nulls.map(|n| n.inner().clone()) - .unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len())); - - // Valid when original "in set" is true (see above) - let haystack_validity = if negated { - !&contains_buffer - } else { - contains_buffer.clone() - }; - - // Combined validity: valid only where both are valid - let combined_validity = &needle_validity & &haystack_validity; - Some(NullBuffer::new(combined_validity)) - } - }; - - Ok(BooleanArray::new(contains_buffer, result_nulls)) - } - } - }; -} - -// Generate specialized filters for all integer primitive types -primitive_static_filter!(Int8StaticFilter, Int8Type); -primitive_static_filter!(Int16StaticFilter, Int16Type); -primitive_static_filter!(Int32StaticFilter, Int32Type); -primitive_static_filter!(Int64StaticFilter, Int64Type); -primitive_static_filter!(UInt8StaticFilter, UInt8Type); -primitive_static_filter!(UInt16StaticFilter, UInt16Type); -primitive_static_filter!(UInt32StaticFilter, UInt32Type); -primitive_static_filter!(UInt64StaticFilter, UInt64Type); - -// Macro to generate specialized StaticFilter implementations for float types -// Floats require a wrapper type (OrderedFloat*) to implement Hash/Eq due to NaN semantics -macro_rules! float_static_filter { - ($Name:ident, $ArrowType:ty, $OrderedType:ty) => { - struct $Name { - null_count: usize, - values: HashSet<$OrderedType>, - } - - impl $Name { - fn try_new(in_array: &ArrayRef) -> Result { - let in_array = in_array - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; - - let mut values = HashSet::with_capacity(in_array.len()); - let null_count = in_array.null_count(); - - for v in in_array.iter().flatten() { - values.insert(<$OrderedType>::from(v)); - } - - Ok(Self { null_count, values }) - } - } - - impl StaticFilter for $Name { - fn null_count(&self) -> usize { - self.null_count - } - - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Handle dictionary arrays by recursing on the values - downcast_dictionary_array! { - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} - } - - let v = v - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; - - let haystack_has_nulls = self.null_count > 0; - - let needle_values = v.values(); - let needle_nulls = v.nulls(); - let needle_has_nulls = v.null_count() > 0; - - // Truth table for `value [NOT] IN (set)` with SQL three-valued logic: - // ("-" means the value doesn't affect the result) - // - // | needle_null | haystack_null | negated | in set? | result | - // |-------------|---------------|---------|---------|--------| - // | true | - | false | - | null | - // | true | - | true | - | null | - // | false | true | false | yes | true | - // | false | true | false | no | null | - // | false | true | true | yes | false | - // | false | true | true | no | null | - // | false | false | false | yes | true | - // | false | false | false | no | false | - // | false | false | true | yes | false | - // | false | false | true | no | true | - - // Compute the "contains" result using collect_bool (fast batched approach) - // This ignores nulls - we handle them separately - let contains_buffer = if negated { - BooleanBuffer::collect_bool(needle_values.len(), |i| { - !self.values.contains(&<$OrderedType>::from(needle_values[i])) - }) - } else { - BooleanBuffer::collect_bool(needle_values.len(), |i| { - self.values.contains(&<$OrderedType>::from(needle_values[i])) - }) - }; - - // Compute the null mask - // Output is null when: - // 1. needle value is null, OR - // 2. needle value is not in set AND haystack has nulls - let result_nulls = match (needle_has_nulls, haystack_has_nulls) { - (false, false) => { - // No nulls anywhere - None - } - (true, false) => { - // Only needle has nulls - just use needle's null mask - needle_nulls.cloned() - } - (false, true) => { - // Only haystack has nulls - result is null when value not in set - // Valid (not null) when original "in set" is true - // For NOT IN: contains_buffer = !original, so validity = !contains_buffer - let validity = if negated { - !&contains_buffer - } else { - contains_buffer.clone() - }; - Some(NullBuffer::new(validity)) - } - (true, true) => { - // Both have nulls - combine needle nulls with haystack-induced nulls - let needle_validity = needle_nulls.map(|n| n.inner().clone()) - .unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len())); - - // Valid when original "in set" is true (see above) - let haystack_validity = if negated { - !&contains_buffer - } else { - contains_buffer.clone() - }; - - // Combined validity: valid only where both are valid - let combined_validity = &needle_validity & &haystack_validity; - Some(NullBuffer::new(combined_validity)) - } - }; - - Ok(BooleanArray::new(contains_buffer, result_nulls)) - } - } - }; -} - -// Generate specialized filters for float types using ordered wrappers -float_static_filter!(Float32StaticFilter, Float32Type, OrderedFloat32); -float_static_filter!(Float64StaticFilter, Float64Type, OrderedFloat64); - /// Evaluates the list of expressions into an array, flattening any dictionaries fn evaluate_list( list: &[Arc], diff --git a/datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs b/datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs new file mode 100644 index 0000000000000..93bfcd49600d0 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, BooleanArray, downcast_array, downcast_dictionary_array, + make_comparator, +}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::compute::{SortOptions, take}; +use arrow::datatypes::DataType; +use arrow::util::bit_iterator::BitIndexIterator; +use datafusion_common::HashMap; +use datafusion_common::Result; +use datafusion_common::hash_utils::{RandomState, with_hashes}; +use hashbrown::hash_map::RawEntryMut; + +use super::static_filter::StaticFilter; + +/// Static filter for InList that stores the array and hash set for O(1) lookups +#[derive(Debug, Clone)] +pub(super) struct ArrayStaticFilter { + in_array: ArrayRef, + state: RandomState, + /// Used to provide a lookup from value to in list index + /// + /// Note: usize::hash is not used, instead the raw entry + /// API is used to store entries w.r.t their value + map: HashMap, +} + +impl StaticFilter for ArrayStaticFilter { + fn null_count(&self) -> usize { + self.in_array.null_count() + } + + /// Checks if values in `v` are contained in the `in_array` using this hash set for lookup. + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + // Null type comparisons always return null (SQL three-valued logic) + if v.data_type() == &DataType::Null + || self.in_array.data_type() == &DataType::Null + { + let nulls = NullBuffer::new_null(v.len()); + return Ok(BooleanArray::new( + BooleanBuffer::new_unset(v.len()), + Some(nulls), + )); + } + + // Unwrap dictionary-encoded needles when the value type matches + // in_array, evaluating against the dictionary values and mapping + // back via keys. + downcast_dictionary_array! { + v => { + // Only unwrap when the haystack (in_array) type matches + // the dictionary value type + if v.values().data_type() == self.in_array.data_type() { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())); + } + } + _ => {} + } + + let needle_nulls = v.logical_nulls(); + let needle_nulls = needle_nulls.as_ref(); + let haystack_has_nulls = self.in_array.null_count() != 0; + + with_hashes([v], &self.state, |hashes| { + let cmp = make_comparator(v, &self.in_array, SortOptions::default())?; + Ok((0..v.len()) + .map(|i| { + // SQL three-valued logic: null IN (...) is always null + if needle_nulls.is_some_and(|nulls| nulls.is_null(i)) { + return None; + } + + let hash = hashes[i]; + let contains = self + .map + .raw_entry() + .from_hash(hash, |idx| cmp(i, *idx).is_eq()) + .is_some(); + + match contains { + true => Some(!negated), + false if haystack_has_nulls => None, + false => Some(negated), + } + }) + .collect()) + }) + } +} + +impl ArrayStaticFilter { + /// Computes a [`StaticFilter`] for the provided [`Array`] if there + /// are nulls present or there are more than the configured number of + /// elements. + /// + /// Note: This is split into a separate function as higher-rank trait bounds currently + /// cause type inference to misbehave + pub(super) fn try_new(in_array: ArrayRef) -> Result { + // Null type has no natural order - return empty hash set + if in_array.data_type() == &DataType::Null { + return Ok(ArrayStaticFilter { + in_array, + state: RandomState::default(), + map: HashMap::with_hasher(()), + }); + } + + let state = RandomState::default(); + let mut map: HashMap = HashMap::with_hasher(()); + + with_hashes([&in_array], &state, |hashes| -> Result<()> { + let cmp = make_comparator(&in_array, &in_array, SortOptions::default())?; + + let insert_value = |idx| { + let hash = hashes[idx]; + if let RawEntryMut::Vacant(v) = map + .raw_entry_mut() + .from_hash(hash, |x| cmp(*x, idx).is_eq()) + { + v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); + } + }; + + match in_array.nulls() { + Some(nulls) => { + BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) + .for_each(insert_value) + } + None => (0..in_array.len()).for_each(insert_value), + } + + Ok(()) + })?; + + Ok(Self { + in_array, + state, + map, + }) + } +} diff --git a/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs new file mode 100644 index 0000000000000..2c184b8ea02e9 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs @@ -0,0 +1,344 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, AsArray, BooleanArray, downcast_array, downcast_dictionary_array, +}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::compute::take; +use arrow::datatypes::*; +use datafusion_common::{HashSet, Result, exec_datafusion_err}; +use std::hash::{Hash, Hasher}; + +use super::static_filter::StaticFilter; + +/// Wrapper for f32 that implements Hash and Eq using bit comparison. +/// This treats NaN values as equal to each other when they have the same bit pattern. +#[derive(Clone, Copy)] +struct OrderedFloat32(f32); + +impl Hash for OrderedFloat32 { + fn hash(&self, state: &mut H) { + self.0.to_ne_bytes().hash(state); + } +} + +impl PartialEq for OrderedFloat32 { + fn eq(&self, other: &Self) -> bool { + self.0.to_bits() == other.0.to_bits() + } +} + +impl Eq for OrderedFloat32 {} + +impl From for OrderedFloat32 { + fn from(v: f32) -> Self { + Self(v) + } +} + +/// Wrapper for f64 that implements Hash and Eq using bit comparison. +/// This treats NaN values as equal to each other when they have the same bit pattern. +#[derive(Clone, Copy)] +struct OrderedFloat64(f64); + +impl Hash for OrderedFloat64 { + fn hash(&self, state: &mut H) { + self.0.to_ne_bytes().hash(state); + } +} + +impl PartialEq for OrderedFloat64 { + fn eq(&self, other: &Self) -> bool { + self.0.to_bits() == other.0.to_bits() + } +} + +impl Eq for OrderedFloat64 {} + +impl From for OrderedFloat64 { + fn from(v: f64) -> Self { + Self(v) + } +} + +// Macro to generate specialized StaticFilter implementations for primitive types +macro_rules! primitive_static_filter { + ($Name:ident, $ArrowType:ty) => { + pub(super) struct $Name { + null_count: usize, + values: HashSet<<$ArrowType as ArrowPrimitiveType>::Native>, + } + + impl $Name { + pub(super) fn try_new(in_array: &ArrayRef) -> Result { + let in_array = in_array + .as_primitive_opt::<$ArrowType>() + .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + + let mut values = HashSet::with_capacity(in_array.len()); + let null_count = in_array.null_count(); + + for v in in_array.iter().flatten() { + values.insert(v); + } + + Ok(Self { null_count, values }) + } + } + + impl StaticFilter for $Name { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + // Handle dictionary arrays by recursing on the values + downcast_dictionary_array! { + v => { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + + let v = v + .as_primitive_opt::<$ArrowType>() + .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + + let haystack_has_nulls = self.null_count > 0; + let needle_values = v.values(); + let needle_nulls = v.nulls(); + let needle_has_nulls = v.null_count() > 0; + + // Truth table for `value [NOT] IN (set)` with SQL three-valued logic: + // ("-" means the value doesn't affect the result) + // + // | needle_null | haystack_null | negated | in set? | result | + // |-------------|---------------|---------|---------|--------| + // | true | - | false | - | null | + // | true | - | true | - | null | + // | false | true | false | yes | true | + // | false | true | false | no | null | + // | false | true | true | yes | false | + // | false | true | true | no | null | + // | false | false | false | yes | true | + // | false | false | false | no | false | + // | false | false | true | yes | false | + // | false | false | true | no | true | + + // Compute the "contains" result using collect_bool (fast batched approach) + // This ignores nulls - we handle them separately + let contains_buffer = if negated { + BooleanBuffer::collect_bool(needle_values.len(), |i| { + !self.values.contains(&needle_values[i]) + }) + } else { + BooleanBuffer::collect_bool(needle_values.len(), |i| { + self.values.contains(&needle_values[i]) + }) + }; + + // Compute the null mask + // Output is null when: + // 1. needle value is null, OR + // 2. needle value is not in set AND haystack has nulls + let result_nulls = match (needle_has_nulls, haystack_has_nulls) { + (false, false) => { + // No nulls anywhere + None + } + (true, false) => { + // Only needle has nulls - just use needle's null mask + needle_nulls.cloned() + } + (false, true) => { + // Only haystack has nulls - result is null when value not in set + // Valid (not null) when original "in set" is true + // For NOT IN: contains_buffer = !original, so validity = !contains_buffer + let validity = if negated { + !&contains_buffer + } else { + contains_buffer.clone() + }; + Some(NullBuffer::new(validity)) + } + (true, true) => { + // Both have nulls - combine needle nulls with haystack-induced nulls + let needle_validity = needle_nulls.map(|n| n.inner().clone()) + .unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len())); + + // Valid when original "in set" is true (see above) + let haystack_validity = if negated { + !&contains_buffer + } else { + contains_buffer.clone() + }; + + // Combined validity: valid only where both are valid + let combined_validity = &needle_validity & &haystack_validity; + Some(NullBuffer::new(combined_validity)) + } + }; + + Ok(BooleanArray::new(contains_buffer, result_nulls)) + } + } + }; +} + +// Generate specialized filters for all integer primitive types +primitive_static_filter!(Int8StaticFilter, Int8Type); +primitive_static_filter!(Int16StaticFilter, Int16Type); +primitive_static_filter!(Int32StaticFilter, Int32Type); +primitive_static_filter!(Int64StaticFilter, Int64Type); +primitive_static_filter!(UInt8StaticFilter, UInt8Type); +primitive_static_filter!(UInt16StaticFilter, UInt16Type); +primitive_static_filter!(UInt32StaticFilter, UInt32Type); +primitive_static_filter!(UInt64StaticFilter, UInt64Type); + +// Macro to generate specialized StaticFilter implementations for float types +// Floats require a wrapper type (OrderedFloat*) to implement Hash/Eq due to NaN semantics +macro_rules! float_static_filter { + ($Name:ident, $ArrowType:ty, $OrderedType:ty) => { + pub(super) struct $Name { + null_count: usize, + values: HashSet<$OrderedType>, + } + + impl $Name { + pub(super) fn try_new(in_array: &ArrayRef) -> Result { + let in_array = in_array + .as_primitive_opt::<$ArrowType>() + .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + + let mut values = HashSet::with_capacity(in_array.len()); + let null_count = in_array.null_count(); + + for v in in_array.iter().flatten() { + values.insert(<$OrderedType>::from(v)); + } + + Ok(Self { null_count, values }) + } + } + + impl StaticFilter for $Name { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + // Handle dictionary arrays by recursing on the values + downcast_dictionary_array! { + v => { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + + let v = v + .as_primitive_opt::<$ArrowType>() + .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + + let haystack_has_nulls = self.null_count > 0; + let needle_values = v.values(); + let needle_nulls = v.nulls(); + let needle_has_nulls = v.null_count() > 0; + + // Truth table for `value [NOT] IN (set)` with SQL three-valued logic: + // ("-" means the value doesn't affect the result) + // + // | needle_null | haystack_null | negated | in set? | result | + // |-------------|---------------|---------|---------|--------| + // | true | - | false | - | null | + // | true | - | true | - | null | + // | false | true | false | yes | true | + // | false | true | false | no | null | + // | false | true | true | yes | false | + // | false | true | true | no | null | + // | false | false | false | yes | true | + // | false | false | false | no | false | + // | false | false | true | yes | false | + // | false | false | true | no | true | + + // Compute the "contains" result using collect_bool (fast batched approach) + // This ignores nulls - we handle them separately + let contains_buffer = if negated { + BooleanBuffer::collect_bool(needle_values.len(), |i| { + !self.values.contains(&<$OrderedType>::from(needle_values[i])) + }) + } else { + BooleanBuffer::collect_bool(needle_values.len(), |i| { + self.values.contains(&<$OrderedType>::from(needle_values[i])) + }) + }; + + // Compute the null mask + // Output is null when: + // 1. needle value is null, OR + // 2. needle value is not in set AND haystack has nulls + let result_nulls = match (needle_has_nulls, haystack_has_nulls) { + (false, false) => { + // No nulls anywhere + None + } + (true, false) => { + // Only needle has nulls - just use needle's null mask + needle_nulls.cloned() + } + (false, true) => { + // Only haystack has nulls - result is null when value not in set + // Valid (not null) when original "in set" is true + // For NOT IN: contains_buffer = !original, so validity = !contains_buffer + let validity = if negated { + !&contains_buffer + } else { + contains_buffer.clone() + }; + Some(NullBuffer::new(validity)) + } + (true, true) => { + // Both have nulls - combine needle nulls with haystack-induced nulls + let needle_validity = needle_nulls.map(|n| n.inner().clone()) + .unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len())); + + // Valid when original "in set" is true (see above) + let haystack_validity = if negated { + !&contains_buffer + } else { + contains_buffer.clone() + }; + + // Combined validity: valid only where both are valid + let combined_validity = &needle_validity & &haystack_validity; + Some(NullBuffer::new(combined_validity)) + } + }; + + Ok(BooleanArray::new(contains_buffer, result_nulls)) + } + } + }; +} + +// Generate specialized filters for float types using ordered wrappers +float_static_filter!(Float32StaticFilter, Float32Type, OrderedFloat32); +float_static_filter!(Float64StaticFilter, Float64Type, OrderedFloat64); diff --git a/datafusion/physical-expr/src/expressions/in_list/static_filter.rs b/datafusion/physical-expr/src/expressions/in_list/static_filter.rs new file mode 100644 index 0000000000000..47bffb85ad8c1 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/static_filter.rs @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, BooleanArray}; +use datafusion_common::Result; + +/// Trait for InList static filters +pub(super) trait StaticFilter { + fn null_count(&self) -> usize; + + /// Checks if values in `v` are contained in the filter + fn contains(&self, v: &dyn Array, negated: bool) -> Result; +} diff --git a/datafusion/physical-expr/src/expressions/in_list/strategy.rs b/datafusion/physical-expr/src/expressions/in_list/strategy.rs new file mode 100644 index 0000000000000..955ab5ad290a3 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/strategy.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::datatypes::DataType; +use datafusion_common::Result; + +use super::array_static_filter::ArrayStaticFilter; +use super::primitive_filter::*; +use super::static_filter::StaticFilter; + +pub(super) fn instantiate_static_filter( + in_array: ArrayRef, +) -> Result> { + match in_array.data_type() { + // Integer primitive types + DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)), + DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)), + DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), + DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), + DataType::UInt8 => Ok(Arc::new(UInt8StaticFilter::try_new(&in_array)?)), + DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)), + DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), + DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), + // Float primitive types (use ordered wrappers for Hash/Eq) + DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)), + DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)), + _ => { + /* fall through to generic implementation for unsupported types (Struct, etc.) */ + Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) + } + } +}