diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs index 27620221cf23c..ca2939cf5aa16 100644 --- a/datafusion/functions-aggregate-common/src/min_max.rs +++ b/datafusion/functions-aggregate-common/src/min_max.rs @@ -18,8 +18,8 @@ //! Basic min/max functionality shared across DataFusion aggregate functions use arrow::array::{ - ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, - Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, + ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, + Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, IntervalDayTimeArray, @@ -141,10 +141,25 @@ macro_rules! min_max_generic { }}; } -// min/max of two scalar values of the same type macro_rules! min_max { ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ - Ok(match ($VALUE, $DELTA) { + match choose_min_max!($OP) { + Ordering::Greater => Ok(min_max_scalar_impl!($VALUE, $DELTA, min)), + Ordering::Less => Ok(min_max_scalar_impl!($VALUE, $DELTA, max)), + Ordering::Equal => { + unreachable!("min/max comparisons do not use equal ordering") + } + } + }}; +} + +// min/max of two logically compatible scalar values. +// Dictionary scalars participate by comparing their inner logical values. +// When both inputs are dictionaries, matching key types are preserved in the +// result; differing key types remain an unexpected invariant violation. +macro_rules! min_max_scalar_impl { + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ + match ($VALUE, $DELTA) { (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null, ( lhs @ ScalarValue::Decimal32(lhsv, lhsp, lhss), @@ -413,16 +428,54 @@ macro_rules! 
min_max { min_max_generic!(lhs, rhs, $OP) } + ( + ScalarValue::Dictionary(lhs_dict_key_type, lhs_dict_value), + ScalarValue::Dictionary(rhs_dict_key_type, rhs_dict_value), + ) => { + if lhs_dict_key_type != rhs_dict_key_type { + return internal_err!( + "MIN/MAX is not expected to receive dictionary scalars with different key types ({:?} vs {:?})", + lhs_dict_key_type, + rhs_dict_key_type + ); + } + + let result = min_max_scalar( + lhs_dict_value.as_ref(), + rhs_dict_value.as_ref(), + choose_min_max!($OP), + )?; + ScalarValue::Dictionary(lhs_dict_key_type.clone(), Box::new(result)) + } + (ScalarValue::Dictionary(_, lhs_dict_value), rhs_scalar) => { + min_max_scalar(lhs_dict_value.as_ref(), rhs_scalar, choose_min_max!($OP))? + } + (lhs_scalar, ScalarValue::Dictionary(_, rhs_dict_value)) => { + min_max_scalar(lhs_scalar, rhs_dict_value.as_ref(), choose_min_max!($OP))? + } + e => { return internal_err!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + "MIN/MAX is not expected to receive logically incompatible scalar values {:?}", e ) } - }) + } }}; } +fn min_max_scalar( + lhs: &ScalarValue, + rhs: &ScalarValue, + ordering: Ordering, +) -> Result<ScalarValue> { + match ordering { + Ordering::Greater => Ok(min_max_scalar_impl!(lhs, rhs, min)), + Ordering::Less => Ok(min_max_scalar_impl!(lhs, rhs, max)), + Ordering::Equal => unreachable!("min/max comparisons do not use equal ordering"), + } +} + /// An accumulator to compute the maximum value #[derive(Debug, Clone)] pub struct MaxAccumulator { @@ -760,37 +813,40 @@ pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> { min_binary_view ) } - DataType::Struct(_) => min_max_batch_generic(values, Ordering::Greater)?, - DataType::List(_) => min_max_batch_generic(values, Ordering::Greater)?, - DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Greater)?, - DataType::FixedSizeList(_, _) => { - min_max_batch_generic(values, Ordering::Greater)?
- } - DataType::Dictionary(_, _) => { - let values = values.as_any_dictionary().values(); - min_batch(values)? - } + DataType::Struct(_) + | DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Dictionary(_, _) => min_max_batch_generic(values, Ordering::Greater)?, _ => min_max_batch!(values, min), }) } -/// Generic min/max implementation for complex types -fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> { - if array.len() == array.null_count() { - return ScalarValue::try_from(array.data_type()); - } - let mut extreme = ScalarValue::try_from_array(array, 0)?; - for i in 1..array.len() { - let current = ScalarValue::try_from_array(array, i)?; - if current.is_null() { - continue; +/// Finds the min/max by scanning logical rows via `ScalarValue::try_from_array`. +/// +/// Callers are responsible for routing dictionary arrays to this helper. +/// Passing `dictionary.values()` is semantically incorrect because it can +/// include unreferenced dictionary entries and ignore null key positions. +fn min_max_batch_generic(values: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> { + let mut index = 0; + let mut extreme = loop { + if index == values.len() { + return ScalarValue::try_from(values.data_type()); } - if extreme.is_null() { - extreme = current; - continue; + + let current = ScalarValue::try_from_array(values, index)?; + index += 1; + + if !current.is_null() { + break current; } - let cmp = extreme.try_cmp(&current)?; - if cmp == ordering { + }; + + while index < values.len() { + let current = ScalarValue::try_from_array(values, index)?; + index += 1; + + if !current.is_null() && extreme.try_cmp(&current)?
== ordering { extreme = current; } } @@ -843,14 +899,122 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> { let value = value.map(|e| e.to_vec()); ScalarValue::FixedSizeBinary(*size, value) } - DataType::Struct(_) => min_max_batch_generic(values, Ordering::Less)?, - DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?, - DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?, - DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?, - DataType::Dictionary(_, _) => { - let values = values.as_any_dictionary().values(); - max_batch(values)? - } + DataType::Struct(_) + | DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Dictionary(_, _) => min_max_batch_generic(values, Ordering::Less)?, _ => min_max_batch!(values, max), }) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{AsArray, DictionaryArray}; + use std::sync::Arc; + + #[test] + fn min_max_dictionary_and_scalar_compare_by_inner_value() -> Result<()> { + let dictionary = ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Float32(Some(1.0))), + ); + let scalar = ScalarValue::Float32(Some(2.0)); + + let result = min_max_scalar(&dictionary, &scalar, Ordering::Less)?; + + assert_eq!(result, ScalarValue::Float32(Some(2.0))); + Ok(()) + } + + #[test] + fn min_max_dictionary_same_key_type_rewraps_result() -> Result<()> { + let lhs = ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Float32(Some(1.0))), + ); + let rhs = ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Float32(Some(2.0))), + ); + + let result = min_max_scalar(&lhs, &rhs, Ordering::Less)?; + + assert_eq!( + result, + ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Float32(Some(2.0))), + ) + ); + Ok(()) + } + + #[test] + fn min_max_dictionary_different_key_types_error() -> Result<()> { + let lhs = ScalarValue::Dictionary( +
Box::new(DataType::Int8), + Box::new(ScalarValue::Float32(Some(1.0))), + ); + let rhs = ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Float32(Some(2.0))), + ); + + let error: DataFusionError = + min_max_scalar(&lhs, &rhs, Ordering::Less).unwrap_err(); + + assert!( + error + .to_string() + .contains("dictionary scalars with different key types") + ); + Ok(()) + } + + #[test] + fn min_max_dictionary_and_incompatible_scalar_error() -> Result<()> { + let dictionary = ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Float32(Some(1.0))), + ); + let scalar = ScalarValue::Int32(Some(2)); + + let error: DataFusionError = + min_max_scalar(&dictionary, &scalar, Ordering::Less).unwrap_err(); + + assert!( + error + .to_string() + .contains("logically incompatible scalar values") + ); + Ok(()) + } + + #[test] + fn min_max_batch_dictionary_uses_logical_rows() -> Result<()> { + let keys = Int8Array::from(vec![Some(1), None, Some(1), Some(1)]); + let values = Arc::new(StringArray::from(vec!["zzz", "bbb", "aaa"])); + let array = Arc::new(DictionaryArray::new(keys, values)) as ArrayRef; + let raw_values = array.as_any_dictionary().values(); + let raw_min = min_batch(raw_values)?; + + let min = min_batch(&array)?; + let max = max_batch(&array)?; + + let expected = ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(ScalarValue::Utf8(Some("bbb".to_string()))), + ); + + // raw_min is "aaa" because it is the min of the values, but min/max of the dictionary should be "bbb" + // because the null key is ignored and all non-null keys point to "bbb". 
+ assert_ne!(raw_min, expected); + + assert_eq!(min, expected); + assert_eq!(max, expected); + + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 9d05c57b02e93..78ceba9779310 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -1004,12 +1004,13 @@ mod tests { use super::*; use arrow::{ array::{ - DictionaryArray, Float32Array, Int32Array, IntervalDayTimeArray, - IntervalMonthDayNanoArray, IntervalYearMonthArray, StringArray, + Array, DictionaryArray, Float32Array, Int8Array, Int32Array, + IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, + PrimitiveArray, StringArray, }, datatypes::{ - IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, - IntervalYearMonthType, + ArrowDictionaryKeyType, IntervalDayTimeType, IntervalMonthDayNanoType, + IntervalUnit, IntervalYearMonthType, }, }; use std::sync::Arc; @@ -1259,7 +1260,178 @@ mod tests { let mut max_acc = MaxAccumulator::try_new(&rt_type)?; max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; let max_result = max_acc.evaluate()?; - assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string()))); + assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string()))); + Ok(()) + } + + fn dict_scalar(key_type: DataType, inner: ScalarValue) -> ScalarValue { + ScalarValue::Dictionary(Box::new(key_type), Box::new(inner)) + } + + fn utf8_dict_scalar(key_type: DataType, value: &str) -> ScalarValue { + dict_scalar(key_type, ScalarValue::Utf8(Some(value.to_string()))) + } + + fn string_dictionary_batch(values: &[&str], keys: &[Option<i32>]) -> ArrayRef { + string_dictionary_batch_with_keys(Int32Array::from(keys.to_vec()), values) + } + + fn string_dictionary_batch_with_keys<K>( + keys: PrimitiveArray<K>, + values: &[&str], + ) -> ArrayRef + where + K: ArrowDictionaryKeyType, + { + let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef; +
Arc::new(DictionaryArray::try_new(keys, values).unwrap()) as ArrayRef + } + + fn optional_string_dictionary_batch( + values: &[Option<&str>], + keys: &[Option<i32>], + ) -> ArrayRef { + let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef; + Arc::new( + DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(), + ) as ArrayRef + } + + fn float_dictionary_batch(values: &[f32], keys: &[Option<i32>]) -> ArrayRef { + let values = Arc::new(Float32Array::from(values.to_vec())) as ArrayRef; + Arc::new( + DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(), + ) as ArrayRef + } + + fn evaluate_dictionary_accumulator( + mut acc: impl Accumulator, + batches: &[ArrayRef], + ) -> Result<ScalarValue> { + for batch in batches { + acc.update_batch(&[Arc::clone(batch)])?; + } + acc.evaluate() + } + + fn assert_dictionary_min_max( + dict_type: &DataType, + batches: &[ArrayRef], + expected_min: &str, + expected_max: &str, + ) -> Result<()> { + let key_type = match dict_type { + DataType::Dictionary(key_type, _) => key_type.as_ref().clone(), + other => panic!("expected dictionary type, got {other:?}"), + }; + + let min_result = evaluate_dictionary_accumulator( + MinAccumulator::try_new(dict_type)?, + batches, + )?; + assert_eq!(min_result, utf8_dict_scalar(key_type.clone(), expected_min)); + + let max_result = evaluate_dictionary_accumulator( + MaxAccumulator::try_new(dict_type)?, + batches, + )?; + assert_eq!(max_result, utf8_dict_scalar(key_type, expected_max)); + + Ok(()) + } + + #[test] + fn test_min_max_dictionary_without_coercion() -> Result<()> { + let dict_array_ref = string_dictionary_batch( + &["b", "c", "a", "d"], + &[Some(0), Some(1), Some(2), Some(3)], + ); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d") + } + + #[test] + fn test_min_max_dictionary_with_nulls() -> Result<()> { + let dict_array_ref = string_dictionary_batch( + &["b", "c", "a"], + &[None, Some(0),
None, Some(1), Some(2)], + ); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "c") + } + + #[test] + fn test_min_max_dictionary_ignores_unreferenced_values() -> Result<()> { + let dict_array_ref = + string_dictionary_batch(&["a", "z", "zz_unused"], &[Some(1), Some(1), None]); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "z", "z") + } + + #[test] + fn test_min_max_dictionary_ignores_referenced_null_values() -> Result<()> { + let dict_array_ref = optional_string_dictionary_batch( + &[Some("b"), None, Some("a"), Some("d")], + &[Some(0), Some(1), Some(2), Some(3)], + ); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d") + } + + #[test] + fn test_min_max_dictionary_multi_batch() -> Result<()> { + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let batch1 = string_dictionary_batch(&["b", "c"], &[Some(0), Some(1)]); + let batch2 = string_dictionary_batch(&["a", "d"], &[Some(0), Some(1)]); + + assert_dictionary_min_max(&dict_type, &[batch1, batch2], "a", "d") + } + + #[test] + fn test_min_max_dictionary_int8_keys() -> Result<()> { + let dict_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + let dict_array_ref = string_dictionary_batch_with_keys( + Int8Array::from(vec![Some(0), Some(1), Some(2), Some(3)]), + &["b", "c", "a", "d"], + ); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d") + } + + #[test] + fn test_min_max_dictionary_float_with_nans() -> Result<()> { + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Float32)); + let batch1 = float_dictionary_batch(&[0.0, f32::NAN], &[Some(0), Some(1)]); + let batch2 = float_dictionary_batch(&[f32::NEG_INFINITY], &[Some(0)]); + + let min_result = evaluate_dictionary_accumulator( + 
MinAccumulator::try_new(&dict_type)?, + &[Arc::clone(&batch1), Arc::clone(&batch2)], + )?; + assert_eq!( + min_result, + dict_scalar( + DataType::Int32, + ScalarValue::Float32(Some(f32::NEG_INFINITY)), + ) + ); + + let max_result = evaluate_dictionary_accumulator( + MaxAccumulator::try_new(&dict_type)?, + &[batch1, batch2], + )?; + assert_eq!( + max_result, + dict_scalar(DataType::Int32, ScalarValue::Float32(Some(f32::NAN))) + ); + Ok(()) } }