From fe226dd055c47a3f8ec2085930ff5579b12bdc25 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 2 Apr 2026 16:17:06 +0800 Subject: [PATCH 1/8] Update min_max.rs to support dictionary scalars Return ScalarValue::Dictionary(...) in dictionary batches instead of unwrapping to inner scalars. Enhance min_max! logic to safely handle dictionary-vs-dictionary and dictionary-vs-non-dictionary comparisons. Add regression tests for raw-dictionary covering no-coercion, null-containing, and multi-batch scenarios. --- .../functions-aggregate-common/src/min_max.rs | 36 +++++-- datafusion/functions-aggregate/src/min_max.rs | 95 +++++++++++++++++++ 2 files changed, 125 insertions(+), 6 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs index 27620221cf23..df950323d80d 100644 --- a/datafusion/functions-aggregate-common/src/min_max.rs +++ b/datafusion/functions-aggregate-common/src/min_max.rs @@ -413,6 +413,28 @@ macro_rules! min_max { min_max_generic!(lhs, rhs, $OP) } + ( + ScalarValue::Dictionary(key_type, lhs_inner), + ScalarValue::Dictionary(_, rhs_inner), + ) => { + let winner = min_max_generic!(lhs_inner.as_ref(), rhs_inner.as_ref(), $OP); + ScalarValue::Dictionary(key_type.clone(), Box::new(winner)) + } + + ( + ScalarValue::Dictionary(_, lhs_inner), + rhs, + ) => { + min_max_generic!(lhs_inner.as_ref(), rhs, $OP) + } + + ( + lhs, + ScalarValue::Dictionary(_, rhs_inner), + ) => { + min_max_generic!(lhs, rhs_inner.as_ref(), $OP) + } + e => { return internal_err!( "MIN/MAX is not expected to receive scalars of incompatible types {:?}", @@ -766,9 +788,10 @@ pub fn min_batch(values: &ArrayRef) -> Result { DataType::FixedSizeList(_, _) => { min_max_batch_generic(values, Ordering::Greater)? } - DataType::Dictionary(_, _) => { - let values = values.as_any_dictionary().values(); - min_batch(values)? + DataType::Dictionary(key_type, _) => { + let dict_values = values.as_any_dictionary().values(); + let inner = min_batch(dict_values)?; + ScalarValue::Dictionary(key_type.clone(), Box::new(inner)) } _ => min_max_batch!(values, min), }) @@ -847,9 +870,10 @@ pub fn max_batch(values: &ArrayRef) -> Result { DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?, DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?, DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?, - DataType::Dictionary(_, _) => { - let values = values.as_any_dictionary().values(); - max_batch(values)? + DataType::Dictionary(key_type, _) => { + let dict_values = values.as_any_dictionary().values(); + let inner = max_batch(dict_values)?; + ScalarValue::Dictionary(key_type.clone(), Box::new(inner)) } _ => min_max_batch!(values, max), }) diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 9d05c57b02e9..d7aa9bfd1b5e 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -1262,4 +1262,99 @@ mod tests { assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string()))); Ok(()) } + + fn dict_scalar(key_type: DataType, inner: ScalarValue) -> ScalarValue { + ScalarValue::Dictionary(Box::new(key_type), Box::new(inner)) + } + + #[test] + fn test_min_max_dictionary_without_coercion() -> Result<()> { + let values = StringArray::from(vec!["b", "c", "a", "d"]); + let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), Some(3)]); + let dict_array = + DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap(); + let dict_array_ref = Arc::new(dict_array) as ArrayRef; + let dict_type = dict_array_ref.data_type().clone(); + + let mut min_acc = MinAccumulator::try_new(&dict_type)?; + min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; + let min_result = min_acc.evaluate()?; + assert_eq!( + min_result, + dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string()))) + ); + + let mut max_acc = MaxAccumulator::try_new(&dict_type)?; + max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; + let max_result = max_acc.evaluate()?; + assert_eq!( + max_result, + dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string()))) + ); + Ok(()) + } + + #[test] + fn test_min_max_dictionary_with_nulls() -> Result<()> { + let values = StringArray::from(vec!["b", "c", "a"]); + let keys = Int32Array::from(vec![None, Some(0), None, Some(1), Some(2)]); + let dict_array = + DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap(); + let dict_array_ref = Arc::new(dict_array) as ArrayRef; + let dict_type = dict_array_ref.data_type().clone(); + + let mut min_acc = MinAccumulator::try_new(&dict_type)?; + min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; + let min_result = min_acc.evaluate()?; + assert_eq!( + min_result, + dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string()))) + ); + + let mut max_acc = MaxAccumulator::try_new(&dict_type)?; + max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; + let max_result = max_acc.evaluate()?; + assert_eq!( + max_result, + dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("c".to_string()))) + ); + Ok(()) + } + + #[test] + fn test_min_max_dictionary_multi_batch() -> Result<()> { + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + + let values1 = StringArray::from(vec!["b", "c"]); + let keys1 = Int32Array::from(vec![Some(0), Some(1)]); + let batch1 = Arc::new( + DictionaryArray::try_new(keys1, Arc::new(values1) as ArrayRef).unwrap(), + ) as ArrayRef; + + let values2 = StringArray::from(vec!["a", "d"]); + let keys2 = Int32Array::from(vec![Some(0), Some(1)]); + let batch2 = Arc::new( + DictionaryArray::try_new(keys2, Arc::new(values2) as ArrayRef).unwrap(), + ) as ArrayRef; + + let mut min_acc = MinAccumulator::try_new(&dict_type)?; + min_acc.update_batch(&[Arc::clone(&batch1)])?; + min_acc.update_batch(&[Arc::clone(&batch2)])?; + let min_result = min_acc.evaluate()?; + assert_eq!( + min_result, + dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string()))) + ); + + let mut max_acc = MaxAccumulator::try_new(&dict_type)?; + max_acc.update_batch(&[Arc::clone(&batch1)])?; + max_acc.update_batch(&[Arc::clone(&batch2)])?; + let max_result = max_acc.evaluate()?; + assert_eq!( + max_result, + dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string()))) + ); + Ok(()) + } } From caafe1ce6aee4467f929adbbaf5d6e5718567f07 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 2 Apr 2026 16:24:31 +0800 Subject: [PATCH 2/8] Refactor dictionary min/max logic and tests Centralize dictionary batch handling for min/max operations. Streamline min_max_batch_generic to initialize from the first non-null element. Implement shared setup/assert helpers in dictionary tests to reduce repetition while preserving test coverage. --- .../functions-aggregate-common/src/min_max.rs | 65 +++++---- datafusion/functions-aggregate/src/min_max.rs | 125 ++++++++---------- 2 files changed, 93 insertions(+), 97 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs index df950323d80d..4d984fc4b0c7 100644 --- a/datafusion/functions-aggregate-common/src/min_max.rs +++ b/datafusion/functions-aggregate-common/src/min_max.rs @@ -141,6 +141,16 @@ macro_rules! min_max_generic { }}; } +macro_rules! min_max_dictionary { + ($VALUE:expr, $DELTA:expr, wrap $KEY_TYPE:expr, $OP:ident) => {{ + let winner = min_max_generic!($VALUE, $DELTA, $OP); + ScalarValue::Dictionary($KEY_TYPE.clone(), Box::new(winner)) + }}; + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ + min_max_generic!($VALUE, $DELTA, $OP) + }}; +} + // min/max of two scalar values of the same type macro_rules! min_max { ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ @@ -417,22 +427,26 @@ macro_rules! min_max { ScalarValue::Dictionary(key_type, lhs_inner), ScalarValue::Dictionary(_, rhs_inner), ) => { - let winner = min_max_generic!(lhs_inner.as_ref(), rhs_inner.as_ref(), $OP); - ScalarValue::Dictionary(key_type.clone(), Box::new(winner)) + min_max_dictionary!( + lhs_inner.as_ref(), + rhs_inner.as_ref(), + wrap key_type, + $OP + ) } ( ScalarValue::Dictionary(_, lhs_inner), rhs, ) => { - min_max_generic!(lhs_inner.as_ref(), rhs, $OP) + min_max_dictionary!(lhs_inner.as_ref(), rhs, $OP) } ( lhs, ScalarValue::Dictionary(_, rhs_inner), ) => { - min_max_generic!(lhs, rhs_inner.as_ref(), $OP) + min_max_dictionary!(lhs, rhs_inner.as_ref(), $OP) } e => { @@ -445,6 +459,17 @@ macro_rules! min_max { }}; } +fn dictionary_batch_extreme( + values: &ArrayRef, + extreme_fn: fn(&ArrayRef) -> Result, +) -> Result { + let DataType::Dictionary(key_type, _) = values.data_type() else { + unreachable!("dictionary_batch_extreme requires dictionary arrays") + }; + let inner = extreme_fn(values.as_any_dictionary().values())?; + Ok(ScalarValue::Dictionary(key_type.clone(), Box::new(inner))) +} + /// An accumulator to compute the maximum value #[derive(Debug, Clone)] pub struct MaxAccumulator { @@ -788,32 +813,22 @@ pub fn min_batch(values: &ArrayRef) -> Result { DataType::FixedSizeList(_, _) => { min_max_batch_generic(values, Ordering::Greater)? } - DataType::Dictionary(key_type, _) => { - let dict_values = values.as_any_dictionary().values(); - let inner = min_batch(dict_values)?; - ScalarValue::Dictionary(key_type.clone(), Box::new(inner)) - } + DataType::Dictionary(_, _) => dictionary_batch_extreme(values, min_batch)?, _ => min_max_batch!(values, min), }) } /// Generic min/max implementation for complex types fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result { - if array.len() == array.null_count() { + let mut non_null_indices = (0..array.len()).filter(|&i| !array.is_null(i)); + let Some(first_idx) = non_null_indices.next() else { return ScalarValue::try_from(array.data_type()); - } - let mut extreme = ScalarValue::try_from_array(array, 0)?; - for i in 1..array.len() { + }; + + let mut extreme = ScalarValue::try_from_array(array, first_idx)?; + for i in non_null_indices { let current = ScalarValue::try_from_array(array, i)?; - if current.is_null() { - continue; - } - if extreme.is_null() { - extreme = current; - continue; - } - let cmp = extreme.try_cmp(¤t)?; - if cmp == ordering { + if extreme.try_cmp(¤t)? == ordering { extreme = current; } } @@ -870,11 +885,7 @@ pub fn max_batch(values: &ArrayRef) -> Result { DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?, DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?, DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?, - DataType::Dictionary(key_type, _) => { - let dict_values = values.as_any_dictionary().values(); - let inner = max_batch(dict_values)?; - ScalarValue::Dictionary(key_type.clone(), Box::new(inner)) - } + DataType::Dictionary(_, _) => dictionary_batch_extreme(values, max_batch)?, _ => min_max_batch!(values, max), }) } diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index d7aa9bfd1b5e..5734f2854dd8 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -1267,94 +1267,79 @@ mod tests { ScalarValue::Dictionary(Box::new(key_type), Box::new(inner)) } - #[test] - fn test_min_max_dictionary_without_coercion() -> Result<()> { - let values = StringArray::from(vec!["b", "c", "a", "d"]); - let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), Some(3)]); - let dict_array = - DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap(); - let dict_array_ref = Arc::new(dict_array) as ArrayRef; - let dict_type = dict_array_ref.data_type().clone(); + fn utf8_dict_scalar(key_type: DataType, value: &str) -> ScalarValue { + dict_scalar(key_type, ScalarValue::Utf8(Some(value.to_string()))) + } + + fn string_dictionary_batch( + values: Vec<&str>, + keys: Vec>, + ) -> ArrayRef { + let values = Arc::new(StringArray::from(values)) as ArrayRef; + Arc::new(DictionaryArray::try_new(Int32Array::from(keys), values).unwrap()) + as ArrayRef + } + + fn assert_dictionary_min_max( + dict_type: &DataType, + batches: &[ArrayRef], + expected_min: &str, + expected_max: &str, + ) -> Result<()> { + let key_type = match dict_type { + DataType::Dictionary(key_type, _) => key_type.as_ref().clone(), + other => panic!("expected dictionary type, got {other:?}"), + }; - let mut min_acc = MinAccumulator::try_new(&dict_type)?; - min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; - let min_result = min_acc.evaluate()?; + let mut min_acc = MinAccumulator::try_new(dict_type)?; + for batch in batches { + min_acc.update_batch(&[Arc::clone(batch)])?; + } assert_eq!( - min_result, - dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string()))) + min_acc.evaluate()?, + utf8_dict_scalar(key_type.clone(), expected_min) ); - let mut max_acc = MaxAccumulator::try_new(&dict_type)?; - max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; - let max_result = max_acc.evaluate()?; - assert_eq!( - max_result, - dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string()))) - ); + let mut max_acc = MaxAccumulator::try_new(dict_type)?; + for batch in batches { + max_acc.update_batch(&[Arc::clone(batch)])?; + } + assert_eq!(max_acc.evaluate()?, utf8_dict_scalar(key_type, expected_max)); + Ok(()) } #[test] - fn test_min_max_dictionary_with_nulls() -> Result<()> { - let values = StringArray::from(vec!["b", "c", "a"]); - let keys = Int32Array::from(vec![None, Some(0), None, Some(1), Some(2)]); - let dict_array = - DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap(); - let dict_array_ref = Arc::new(dict_array) as ArrayRef; + fn test_min_max_dictionary_without_coercion() -> Result<()> { + let dict_array_ref = string_dictionary_batch( + vec!["b", "c", "a", "d"], + vec![Some(0), Some(1), Some(2), Some(3)], + ); let dict_type = dict_array_ref.data_type().clone(); - let mut min_acc = MinAccumulator::try_new(&dict_type)?; - min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; - let min_result = min_acc.evaluate()?; - assert_eq!( - min_result, - dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string()))) - ); + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d") + } - let mut max_acc = MaxAccumulator::try_new(&dict_type)?; - max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; - let max_result = max_acc.evaluate()?; - assert_eq!( - max_result, - dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("c".to_string()))) + #[test] + fn test_min_max_dictionary_with_nulls() -> Result<()> { + let dict_array_ref = string_dictionary_batch( + vec!["b", "c", "a"], + vec![None, Some(0), None, Some(1), Some(2)], ); - Ok(()) + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "c") } #[test] fn test_min_max_dictionary_multi_batch() -> Result<()> { let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let batch1 = + string_dictionary_batch(vec!["b", "c"], vec![Some(0), Some(1)]); + let batch2 = + string_dictionary_batch(vec!["a", "d"], vec![Some(0), Some(1)]); - let values1 = StringArray::from(vec!["b", "c"]); - let keys1 = Int32Array::from(vec![Some(0), Some(1)]); - let batch1 = Arc::new( - DictionaryArray::try_new(keys1, Arc::new(values1) as ArrayRef).unwrap(), - ) as ArrayRef; - - let values2 = StringArray::from(vec!["a", "d"]); - let keys2 = Int32Array::from(vec![Some(0), Some(1)]); - let batch2 = Arc::new( - DictionaryArray::try_new(keys2, Arc::new(values2) as ArrayRef).unwrap(), - ) as ArrayRef; - - let mut min_acc = MinAccumulator::try_new(&dict_type)?; - min_acc.update_batch(&[Arc::clone(&batch1)])?; - min_acc.update_batch(&[Arc::clone(&batch2)])?; - let min_result = min_acc.evaluate()?; - assert_eq!( - min_result, - dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string()))) - ); - - let mut max_acc = MaxAccumulator::try_new(&dict_type)?; - max_acc.update_batch(&[Arc::clone(&batch1)])?; - max_acc.update_batch(&[Arc::clone(&batch2)])?; - let max_result = max_acc.evaluate()?; - assert_eq!( - max_result, - dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string()))) - ); - Ok(()) + assert_dictionary_min_max(&dict_type, &[batch1, batch2], "a", "d") } } From 0bbc56e23a68f7ed7500934a939b708f4cbb644e Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 2 Apr 2026 16:30:44 +0800 Subject: [PATCH 3/8] Simplify min/max flow in dictionary handling Refactor dictionary min/max flow by removing the wrap macro arm, making re-wrapping explicit through a private helper. This separates the "choose inner winner" from the "wrap as dictionary" step for easier auditing. In `datafusion/functions-aggregate/src/min_max.rs`, update `string_dictionary_batch` to accept slices instead of owned Vecs, and introduce a small `evaluate_dictionary_accumulator` helper to streamline min/max assertions with a shared accumulator execution path, reducing repeated setup. --- .../functions-aggregate-common/src/min_max.rs | 20 +++---- datafusion/functions-aggregate/src/min_max.rs | 59 ++++++++++--------- 2 files changed, 41 insertions(+), 38 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs index 4d984fc4b0c7..aa802d4003f4 100644 --- a/datafusion/functions-aggregate-common/src/min_max.rs +++ b/datafusion/functions-aggregate-common/src/min_max.rs @@ -142,13 +142,7 @@ macro_rules! min_max_generic { } macro_rules! min_max_dictionary { - ($VALUE:expr, $DELTA:expr, wrap $KEY_TYPE:expr, $OP:ident) => {{ - let winner = min_max_generic!($VALUE, $DELTA, $OP); - ScalarValue::Dictionary($KEY_TYPE.clone(), Box::new(winner)) - }}; - ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ - min_max_generic!($VALUE, $DELTA, $OP) - }}; + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ min_max_generic!($VALUE, $DELTA, $OP) }}; } // min/max of two scalar values of the same type @@ -427,11 +421,13 @@ macro_rules! min_max { ScalarValue::Dictionary(key_type, lhs_inner), ScalarValue::Dictionary(_, rhs_inner), ) => { - min_max_dictionary!( + wrap_dictionary_scalar( + key_type.as_ref(), + min_max_dictionary!( lhs_inner.as_ref(), rhs_inner.as_ref(), - wrap key_type, $OP + ), ) } @@ -467,7 +463,11 @@ fn dictionary_batch_extreme( unreachable!("dictionary_batch_extreme requires dictionary arrays") }; let inner = extreme_fn(values.as_any_dictionary().values())?; - Ok(ScalarValue::Dictionary(key_type.clone(), Box::new(inner))) + Ok(wrap_dictionary_scalar(key_type.as_ref(), inner)) +} + +fn wrap_dictionary_scalar(key_type: &DataType, value: ScalarValue) -> ScalarValue { + ScalarValue::Dictionary(Box::new(key_type.clone()), Box::new(value)) } /// An accumulator to compute the maximum value diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 5734f2854dd8..6be7341ed10c 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -1271,13 +1271,21 @@ mod tests { dict_scalar(key_type, ScalarValue::Utf8(Some(value.to_string()))) } - fn string_dictionary_batch( - values: Vec<&str>, - keys: Vec>, - ) -> ArrayRef { - let values = Arc::new(StringArray::from(values)) as ArrayRef; - Arc::new(DictionaryArray::try_new(Int32Array::from(keys), values).unwrap()) - as ArrayRef + fn string_dictionary_batch(values: &[&str], keys: &[Option]) -> ArrayRef { + let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef; + Arc::new( + DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(), + ) as ArrayRef + } + + fn evaluate_dictionary_accumulator( + mut acc: impl Accumulator, + batches: &[ArrayRef], + ) -> Result { + for batch in batches { + acc.update_batch(&[Arc::clone(batch)])?; + } + acc.evaluate() } fn assert_dictionary_min_max( @@ -1291,20 +1299,17 @@ mod tests { other => panic!("expected dictionary type, got {other:?}"), }; - let mut min_acc = MinAccumulator::try_new(dict_type)?; - for batch in batches { - min_acc.update_batch(&[Arc::clone(batch)])?; - } - assert_eq!( - min_acc.evaluate()?, - utf8_dict_scalar(key_type.clone(), expected_min) - ); + let min_result = evaluate_dictionary_accumulator( + MinAccumulator::try_new(dict_type)?, + batches, + )?; + assert_eq!(min_result, utf8_dict_scalar(key_type.clone(), expected_min)); - let mut max_acc = MaxAccumulator::try_new(dict_type)?; - for batch in batches { - max_acc.update_batch(&[Arc::clone(batch)])?; - } - assert_eq!(max_acc.evaluate()?, utf8_dict_scalar(key_type, expected_max)); + let max_result = evaluate_dictionary_accumulator( + MaxAccumulator::try_new(dict_type)?, + batches, + )?; + assert_eq!(max_result, utf8_dict_scalar(key_type, expected_max)); Ok(()) } @@ -1312,8 +1317,8 @@ mod tests { #[test] fn test_min_max_dictionary_without_coercion() -> Result<()> { let dict_array_ref = string_dictionary_batch( - vec!["b", "c", "a", "d"], - vec![Some(0), Some(1), Some(2), Some(3)], + &["b", "c", "a", "d"], + &[Some(0), Some(1), Some(2), Some(3)], ); let dict_type = dict_array_ref.data_type().clone(); @@ -1323,8 +1328,8 @@ mod tests { #[test] fn test_min_max_dictionary_with_nulls() -> Result<()> { let dict_array_ref = string_dictionary_batch( - vec!["b", "c", "a"], - vec![None, Some(0), None, Some(1), Some(2)], + &["b", "c", "a"], + &[None, Some(0), None, Some(1), Some(2)], ); let dict_type = dict_array_ref.data_type().clone(); @@ -1335,10 +1340,8 @@ mod tests { fn test_min_max_dictionary_multi_batch() -> Result<()> { let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); - let batch1 = - string_dictionary_batch(vec!["b", "c"], vec![Some(0), Some(1)]); - let batch2 = - string_dictionary_batch(vec!["a", "d"], vec![Some(0), Some(1)]); + let batch1 = string_dictionary_batch(&["b", "c"], &[Some(0), Some(1)]); + let batch2 = string_dictionary_batch(&["a", "d"], &[Some(0), Some(1)]); assert_dictionary_min_max(&dict_type, &[batch1, batch2], "a", "d") } From 9240400ba56efcdb96c0622480dad16130cfeebb Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 2 Apr 2026 16:50:23 +0800 Subject: [PATCH 4/8] Fix dictionary min/max behavior in DataFusion Update min_max.rs to ensure dictionary batches iterate actual array rows, comparing referenced scalar values. Unreferenced dictionary entries no longer affect MIN/MAX, and referenced null values are correctly skipped. Expanded tests to cover these changes and updated expectations Added regression tests for unreferenced and referenced null dictionary values. --- .../functions-aggregate-common/src/min_max.rs | 37 +++++++++++++------ datafusion/functions-aggregate/src/min_max.rs | 34 ++++++++++++++++- 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs index aa802d4003f4..21e67cf44207 100644 --- a/datafusion/functions-aggregate-common/src/min_max.rs +++ b/datafusion/functions-aggregate-common/src/min_max.rs @@ -18,8 +18,8 @@ //! Basic min/max functionality shared across DataFusion aggregate functions use arrow::array::{ - ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, - Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, + ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, + Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, IntervalDayTimeArray, @@ -457,13 +457,23 @@ macro_rules! min_max { fn dictionary_batch_extreme( values: &ArrayRef, - extreme_fn: fn(&ArrayRef) -> Result, + ordering: Ordering, ) -> Result { - let DataType::Dictionary(key_type, _) = values.data_type() else { - unreachable!("dictionary_batch_extreme requires dictionary arrays") - }; - let inner = extreme_fn(values.as_any_dictionary().values())?; - Ok(wrap_dictionary_scalar(key_type.as_ref(), inner)) + let mut extreme: Option = None; + + for i in 0..values.len() { + let current = ScalarValue::try_from_array(values, i)?; + if current.is_null() { + continue; + } + + match &extreme { + Some(existing) if existing.try_cmp(¤t)? != ordering => {} + _ => extreme = Some(current), + } + } + + extreme.map_or_else(|| ScalarValue::try_from(values.data_type()), Ok) } fn wrap_dictionary_scalar(key_type: &DataType, value: ScalarValue) -> ScalarValue { @@ -813,7 +823,9 @@ pub fn min_batch(values: &ArrayRef) -> Result { DataType::FixedSizeList(_, _) => { min_max_batch_generic(values, Ordering::Greater)? } - DataType::Dictionary(_, _) => dictionary_batch_extreme(values, min_batch)?, + DataType::Dictionary(_, _) => { + dictionary_batch_extreme(values, Ordering::Greater)? + } _ => min_max_batch!(values, min), }) } @@ -828,7 +840,10 @@ fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result Result { DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?, DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?, DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?, - DataType::Dictionary(_, _) => dictionary_batch_extreme(values, max_batch)?, + DataType::Dictionary(_, _) => dictionary_batch_extreme(values, Ordering::Less)?, _ => min_max_batch!(values, max), }) } diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 6be7341ed10c..09142a4858b5 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -1259,7 +1259,7 @@ mod tests { let mut max_acc = MaxAccumulator::try_new(&rt_type)?; max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; let max_result = max_acc.evaluate()?; - assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string()))); + assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string()))); Ok(()) } @@ -1278,6 +1278,16 @@ mod tests { ) as ArrayRef } + fn optional_string_dictionary_batch( + values: &[Option<&str>], + keys: &[Option], + ) -> ArrayRef { + let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef; + Arc::new( + DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(), + ) as ArrayRef + } + fn evaluate_dictionary_accumulator( mut acc: impl Accumulator, batches: &[ArrayRef], @@ -1336,6 +1346,28 @@ mod tests { assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "c") } + #[test] + fn test_min_max_dictionary_ignores_unreferenced_values() -> Result<()> { + let dict_array_ref = string_dictionary_batch( + &["a", "z", "zz_unused"], + &[Some(1), Some(1), None], + ); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "z", "z") + } + + #[test] + fn test_min_max_dictionary_ignores_referenced_null_values() -> Result<()> { + let dict_array_ref = optional_string_dictionary_batch( + &[Some("b"), None, Some("a"), Some("d")], + &[Some(0), Some(1), Some(2), Some(3)], + ); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d") + } + #[test] fn test_min_max_dictionary_multi_batch() -> Result<()> { let dict_type = From 7127c20c7df88db4ee827d495b2fe3c96f1ae72d Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 2 Apr 2026 17:00:38 +0800 Subject: [PATCH 5/8] Fix MIN/MAX to preserve dictionary types in execution Update get_min_max_result_type to maintain Dictionary instead of unwrapping to V, allowing planned MIN/MAX execution to utilize the dictionary-aware accumulator. Add end-to-end SQL regression test to ensure MIN/MAX properly ignores unreferenced dictionary values and preserves dictionary-typed output schema. Adjust unit expectations for dictionary coercion tests to reflect new planned-path behavior. --- datafusion/core/tests/sql/aggregates/basic.rs | 38 ++++++++++++++++++ datafusion/functions-aggregate/src/min_max.rs | 40 ++++++++++++------- 2 files changed, 63 insertions(+), 15 deletions(-) diff --git a/datafusion/core/tests/sql/aggregates/basic.rs b/datafusion/core/tests/sql/aggregates/basic.rs index 3e5dc6a0b187..1ee98e44f89d 100644 --- a/datafusion/core/tests/sql/aggregates/basic.rs +++ b/datafusion/core/tests/sql/aggregates/basic.rs @@ -442,6 +442,44 @@ async fn count_distinct_dictionary_mixed_values() -> Result<()> { Ok(()) } +#[tokio::test] +async fn min_max_dictionary_uses_planned_dictionary_path() -> Result<()> { + let ctx = SessionContext::new(); + + let dict_values = StringArray::from(vec!["a", "z", "zz_unused"]); + let dict_indices = Int32Array::from(vec![Some(1), Some(1), None]); + let dict = DictionaryArray::new(dict_indices, Arc::new(dict_values)); + + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let schema = Arc::new(Schema::new(vec![Field::new("dict", dict_type.clone(), true)])); + + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(dict)])?; + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + let df = ctx + .sql("SELECT min(dict) AS min_dict, max(dict) AS max_dict FROM t") + .await?; + let results = df.collect().await?; + + assert_eq!(results[0].schema().field(0).data_type(), &dict_type); + assert_eq!(results[0].schema().field(1).data_type(), &dict_type); + + assert_snapshot!( + batches_to_string(&results), + @r" + +----------+----------+ + | min_dict | max_dict | + +----------+----------+ + | z | z | + +----------+----------+ + " + ); + + Ok(()) +} + #[tokio::test] async fn group_by_ree_dict_column() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 09142a4858b5..fbca5522517c 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -53,7 +53,6 @@ use datafusion_expr::{GroupsAccumulator, StatisticsArgs}; use datafusion_macros::user_doc; use half::f16; use std::mem::size_of_val; -use std::ops::Deref; fn get_min_max_result_type(input_types: &[DataType]) -> Result> { // make sure that the input types only has one element. @@ -63,17 +62,12 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result> { input_types.len() ); } - // min and max support the dictionary data type - // unpack the dictionary to get the value - match &input_types[0] { - DataType::Dictionary(_, dict_value_type) => { - // TODO add checker, if the value type is complex data type - Ok(vec![dict_value_type.deref().clone()]) - } - // TODO add checker for datatype which min and max supported - // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function - _ => Ok(input_types.to_vec()), - } + // Preserve dictionary inputs so planned MIN/MAX execution uses the same + // dictionary-aware accumulator/state path as direct accumulator tests. + // + // TODO add checker for datatype which min and max supported. + // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function. + Ok(input_types.to_vec()) } #[user_doc( @@ -1223,6 +1217,10 @@ mod tests { vec![DataType::Decimal128(10, 2)], vec![DataType::Decimal256(1, 1)], vec![DataType::Utf8], + vec![DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + )], ]; for fun in funs { for input_type in &input_types { @@ -1237,7 +1235,13 @@ mod tests { let data_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); let result = get_min_max_result_type(&[data_type])?; - assert_eq!(result, vec![DataType::Utf8]); + assert_eq!( + result, + vec![DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + )] + ); Ok(()) } @@ -1254,12 +1258,18 @@ mod tests { let mut min_acc = MinAccumulator::try_new(&rt_type)?; min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; let min_result = min_acc.evaluate()?; - assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string()))); + assert_eq!( + min_result, + dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string()))) + ); let mut max_acc = MaxAccumulator::try_new(&rt_type)?; max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; let max_result = max_acc.evaluate()?; - assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string()))); + assert_eq!( + max_result, + dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string()))) + ); Ok(()) } From a077a273c795635c156ed06d7bde7c96d2568cd9 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 2 Apr 2026 17:12:58 +0800 Subject: [PATCH 6/8] Enhance MIN/MAX integration test and clean up macros Update the MIN/MAX integration test in basic.rs to use two MemTable partitions, ensuring the physical plan includes both partial and final aggregate stages. Retain checks for dictionary-typed output schema and results. In min_max.rs, remove the no-op min_max_dictionary! macro and inline the existing generic comparison helper for improved clarity and efficiency. --- datafusion/core/tests/sql/aggregates/basic.rs | 34 ++++++++++++++----- .../functions-aggregate-common/src/min_max.rs | 14 ++------ 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/datafusion/core/tests/sql/aggregates/basic.rs b/datafusion/core/tests/sql/aggregates/basic.rs index 1ee98e44f89d..8ebdba5e2d5a 100644 --- a/datafusion/core/tests/sql/aggregates/basic.rs +++ b/datafusion/core/tests/sql/aggregates/basic.rs @@ -19,6 +19,7 @@ use super::*; use datafusion::common::test_util::batches_to_string; use datafusion_catalog::MemTable; use datafusion_common::ScalarValue; +use datafusion_physical_plan::displayable; use insta::assert_snapshot; #[tokio::test] @@ -444,23 +445,40 @@ async fn count_distinct_dictionary_mixed_values() -> Result<()> { #[tokio::test] async fn min_max_dictionary_uses_planned_dictionary_path() -> Result<()> { - let ctx = SessionContext::new(); - - let dict_values = StringArray::from(vec!["a", "z", "zz_unused"]); - let dict_indices = Int32Array::from(vec![Some(1), Some(1), None]); - let dict = DictionaryArray::new(dict_indices, Arc::new(dict_values)); + let ctx = SessionContext::new_with_config(SessionConfig::new().with_target_partitions(2)); let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); let schema = Arc::new(Schema::new(vec![Field::new("dict", dict_type.clone(), true)])); - let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(dict)])?; - let provider = MemTable::try_new(schema, vec![vec![batch]])?; + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(DictionaryArray::new( + Int32Array::from(vec![Some(1), Some(1), None]), + Arc::new(StringArray::from(vec!["a", "z", "zz_unused"])), + ))], + )?; + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(DictionaryArray::new( + Int32Array::from(vec![Some(0), Some(1)]), + Arc::new(StringArray::from(vec!["a", "d"])), + ))], + )?; + let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; ctx.register_table("t", Arc::new(provider))?; let df = ctx .sql("SELECT min(dict) AS min_dict, max(dict) AS max_dict FROM t") .await?; + let physical_plan = df.clone().create_physical_plan().await?; + let formatted_plan = format!("{}", displayable(physical_plan.as_ref()).indent(true)); + assert!(formatted_plan.contains("AggregateExec: mode=Partial, gby=[]")); + assert!( + formatted_plan.contains("AggregateExec: mode=Final, gby=[]") + || formatted_plan.contains("AggregateExec: mode=FinalPartitioned, gby=[]") + ); + let results = df.collect().await?; assert_eq!(results[0].schema().field(0).data_type(), &dict_type); @@ -472,7 +490,7 @@ async fn min_max_dictionary_uses_planned_dictionary_path() -> Result<()> { +----------+----------+ | min_dict | max_dict | +----------+----------+ - | z | z | + | a | z | +----------+----------+ " ); diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs index 21e67cf44207..ba96a2a9211c 100644 --- a/datafusion/functions-aggregate-common/src/min_max.rs +++ b/datafusion/functions-aggregate-common/src/min_max.rs @@ -141,10 +141,6 @@ macro_rules! min_max_generic { }}; } -macro_rules! min_max_dictionary { - ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ min_max_generic!($VALUE, $DELTA, $OP) }}; -} - // min/max of two scalar values of the same type macro_rules! min_max { ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ @@ -423,11 +419,7 @@ macro_rules! min_max { ) => { wrap_dictionary_scalar( key_type.as_ref(), - min_max_dictionary!( - lhs_inner.as_ref(), - rhs_inner.as_ref(), - $OP - ), + min_max_generic!(lhs_inner.as_ref(), rhs_inner.as_ref(), $OP), ) } @@ -435,14 +427,14 @@ macro_rules! min_max { ScalarValue::Dictionary(_, lhs_inner), rhs, ) => { - min_max_dictionary!(lhs_inner.as_ref(), rhs, $OP) + min_max_generic!(lhs_inner.as_ref(), rhs, $OP) } ( lhs, ScalarValue::Dictionary(_, rhs_inner), ) => { - min_max_dictionary!(lhs, rhs_inner.as_ref(), $OP) + min_max_generic!(lhs, rhs_inner.as_ref(), $OP) } e => { From 4bc0ac35e224bd3b2a9628567f95becd8976fc9c Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 2 Apr 2026 17:16:20 +0800 Subject: [PATCH 7/8] feat(tests): improve formatting and readability in aggregate tests - Refactored the formatting in `basic.rs` to enhance readability by breaking long lines into shorter segments. - Updated `min_max.rs` for consistent formatting in the `test_min_max_dictionary_ignores_unreferenced_values` function. --- datafusion/core/tests/sql/aggregates/basic.rs | 9 +++++++-- datafusion/functions-aggregate/src/min_max.rs | 6 ++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/datafusion/core/tests/sql/aggregates/basic.rs b/datafusion/core/tests/sql/aggregates/basic.rs index 8ebdba5e2d5a..b7b8de428378 100644 --- a/datafusion/core/tests/sql/aggregates/basic.rs +++ b/datafusion/core/tests/sql/aggregates/basic.rs @@ -445,11 +445,16 @@ async fn count_distinct_dictionary_mixed_values() -> Result<()> { #[tokio::test] async fn min_max_dictionary_uses_planned_dictionary_path() -> Result<()> { - let ctx = SessionContext::new_with_config(SessionConfig::new().with_target_partitions(2)); + let ctx = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(2)); let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); - let schema = Arc::new(Schema::new(vec![Field::new("dict", dict_type.clone(), true)])); + let schema = Arc::new(Schema::new(vec![Field::new( + "dict", + dict_type.clone(), + true, + )])); let batch1 = RecordBatch::try_new( schema.clone(), diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index fbca5522517c..bb074bc4af18 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -1358,10 +1358,8 @@ mod tests { #[test] fn test_min_max_dictionary_ignores_unreferenced_values() -> Result<()> { - let dict_array_ref = string_dictionary_batch( - &["a", "z", "zz_unused"], - &[Some(1), Some(1), None], - ); + let dict_array_ref = + string_dictionary_batch(&["a", "z", "zz_unused"], &[Some(1), Some(1), None]); let dict_type = dict_array_ref.data_type().clone(); assert_dictionary_min_max(&dict_type, &[dict_array_ref], "z", "z") From a85804b27d260fd6cbdff5fd7e039fd5c48ae3a1 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 2 Apr 2026 17:44:59 +0800 Subject: [PATCH 8/8] Restore dictionary coercion in min_max.rs Ensure MIN/MAX(Dictionary(..., T)) returns T at SQL boundary while retaining the new dictionary comparison logic. Update regression tests to verify that dictionary inputs are accepted and that the result is the underlying scalar type. Adjust planner-level regression test in basic.rs to expect final output schema to be Utf8 instead of dictionary-typed. --- datafusion/core/tests/sql/aggregates/basic.rs | 4 +- datafusion/functions-aggregate/src/min_max.rs | 70 ++++++++++--------- 2 files changed, 38 insertions(+), 36 deletions(-) diff --git a/datafusion/core/tests/sql/aggregates/basic.rs b/datafusion/core/tests/sql/aggregates/basic.rs index b7b8de428378..57bc9a64397d 100644 --- a/datafusion/core/tests/sql/aggregates/basic.rs +++ b/datafusion/core/tests/sql/aggregates/basic.rs @@ -486,8 +486,8 @@ async fn min_max_dictionary_uses_planned_dictionary_path() -> Result<()> { let results = df.collect().await?; - assert_eq!(results[0].schema().field(0).data_type(), &dict_type); - assert_eq!(results[0].schema().field(1).data_type(), &dict_type); + assert_eq!(results[0].schema().field(0).data_type(), &DataType::Utf8); + assert_eq!(results[0].schema().field(1).data_type(), &DataType::Utf8); assert_snapshot!( batches_to_string(&results), diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index bb074bc4af18..9d4389e0a587 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -53,6 +53,7 @@ use datafusion_expr::{GroupsAccumulator, StatisticsArgs}; use datafusion_macros::user_doc; use half::f16; use std::mem::size_of_val; +use std::ops::Deref; fn get_min_max_result_type(input_types: &[DataType]) -> Result> { // make sure that the input types only has one element. @@ -62,12 +63,17 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result> { input_types.len() ); } - // Preserve dictionary inputs so planned MIN/MAX execution uses the same - // dictionary-aware accumulator/state path as direct accumulator tests. - // - // TODO add checker for datatype which min and max supported. - // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function. - Ok(input_types.to_vec()) + // min and max support the dictionary data type + // unpack the dictionary to get the value + match &input_types[0] { + DataType::Dictionary(_, dict_value_type) => { + // TODO add checker, if the value type is complex data type + Ok(vec![dict_value_type.deref().clone()]) + } + // TODO add checker for datatype which min and max supported. + // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function. + _ => Ok(input_types.to_vec()), + } } #[user_doc( @@ -1209,23 +1215,31 @@ mod tests { #[test] fn test_min_max_coerce_types() { - // the coerced types is same with input types let funs: Vec> = vec![Box::new(Min::new()), Box::new(Max::new())]; - let input_types = vec![ - vec![DataType::Int32], - vec![DataType::Decimal128(10, 2)], - vec![DataType::Decimal256(1, 1)], - vec![DataType::Utf8], - vec![DataType::Dictionary( - Box::new(DataType::Int32), - Box::new(DataType::Utf8), - )], + let cases = vec![ + (vec![DataType::Int32], vec![DataType::Int32]), + ( + vec![DataType::Decimal128(10, 2)], + vec![DataType::Decimal128(10, 2)], + ), + ( + vec![DataType::Decimal256(1, 1)], + vec![DataType::Decimal256(1, 1)], + ), + (vec![DataType::Utf8], vec![DataType::Utf8]), + ( + vec![DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + )], + vec![DataType::Utf8], + ), ]; for fun in funs { - for input_type in &input_types { + for (input_type, expected_type) in &cases { let result = fun.coerce_types(input_type); - assert_eq!(*input_type, result.unwrap()); + assert_eq!(*expected_type, result.unwrap()); } } } @@ -1235,18 +1249,12 @@ mod tests { let data_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); let result = get_min_max_result_type(&[data_type])?; - assert_eq!( - result, - vec![DataType::Dictionary( - Box::new(DataType::Int32), - Box::new(DataType::Utf8), - )] - ); + assert_eq!(result, vec![DataType::Utf8]); Ok(()) } #[test] - fn test_min_max_dictionary() -> Result<()> { + fn test_min_max_dictionary_after_coercion() -> Result<()> { let values = StringArray::from(vec!["b", "c", "a", "🦀", "d"]); let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), None, Some(4)]); let dict_array = @@ -1258,18 +1266,12 @@ mod tests { let mut min_acc = MinAccumulator::try_new(&rt_type)?; min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; let min_result = min_acc.evaluate()?; - assert_eq!( - min_result, - dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string()))) - ); + assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string()))); let mut max_acc = MaxAccumulator::try_new(&rt_type)?; max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; let max_result = max_acc.evaluate()?; - assert_eq!( - max_result, - dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string()))) - ); + assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string()))); Ok(()) }