diff --git a/datafusion/core/tests/sql/aggregates/basic.rs b/datafusion/core/tests/sql/aggregates/basic.rs index 3e5dc6a0b187..57bc9a64397d 100644 --- a/datafusion/core/tests/sql/aggregates/basic.rs +++ b/datafusion/core/tests/sql/aggregates/basic.rs @@ -19,6 +19,7 @@ use super::*; use datafusion::common::test_util::batches_to_string; use datafusion_catalog::MemTable; use datafusion_common::ScalarValue; +use datafusion_physical_plan::displayable; use insta::assert_snapshot; #[tokio::test] @@ -442,6 +443,66 @@ async fn count_distinct_dictionary_mixed_values() -> Result<()> { Ok(()) } +#[tokio::test] +async fn min_max_dictionary_uses_planned_dictionary_path() -> Result<()> { + let ctx = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(2)); + + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let schema = Arc::new(Schema::new(vec![Field::new( + "dict", + dict_type.clone(), + true, + )])); + + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(DictionaryArray::new( + Int32Array::from(vec![Some(1), Some(1), None]), + Arc::new(StringArray::from(vec!["a", "z", "zz_unused"])), + ))], + )?; + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(DictionaryArray::new( + Int32Array::from(vec![Some(0), Some(1)]), + Arc::new(StringArray::from(vec!["a", "d"])), + ))], + )?; + let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; + ctx.register_table("t", Arc::new(provider))?; + + let df = ctx + .sql("SELECT min(dict) AS min_dict, max(dict) AS max_dict FROM t") + .await?; + let physical_plan = df.clone().create_physical_plan().await?; + let formatted_plan = format!("{}", displayable(physical_plan.as_ref()).indent(true)); + assert!(formatted_plan.contains("AggregateExec: mode=Partial, gby=[]")); + assert!( + formatted_plan.contains("AggregateExec: mode=Final, gby=[]") + || formatted_plan.contains("AggregateExec: mode=FinalPartitioned, gby=[]") + ); + + let results = df.collect().await?; + + assert_eq!(results[0].schema().field(0).data_type(), &DataType::Utf8); + assert_eq!(results[0].schema().field(1).data_type(), &DataType::Utf8); + + assert_snapshot!( + batches_to_string(&results), + @r" + +----------+----------+ + | min_dict | max_dict | + +----------+----------+ + | a | z | + +----------+----------+ + " + ); + + Ok(()) +} + #[tokio::test] async fn group_by_ree_dict_column() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs index 27620221cf23..ba96a2a9211c 100644 --- a/datafusion/functions-aggregate-common/src/min_max.rs +++ b/datafusion/functions-aggregate-common/src/min_max.rs @@ -18,8 +18,8 @@ //! Basic min/max functionality shared across DataFusion aggregate functions use arrow::array::{ - ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, - Date64Array, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, + ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, + Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, IntervalDayTimeArray, @@ -413,6 +413,30 @@ macro_rules! min_max { min_max_generic!(lhs, rhs, $OP) } + ( + ScalarValue::Dictionary(key_type, lhs_inner), + ScalarValue::Dictionary(_, rhs_inner), + ) => { + wrap_dictionary_scalar( + key_type.as_ref(), + min_max_generic!(lhs_inner.as_ref(), rhs_inner.as_ref(), $OP), + ) + } + + ( + ScalarValue::Dictionary(_, lhs_inner), + rhs, + ) => { + min_max_generic!(lhs_inner.as_ref(), rhs, $OP) + } + + ( + lhs, + ScalarValue::Dictionary(_, rhs_inner), + ) => { + min_max_generic!(lhs, rhs_inner.as_ref(), $OP) + } + e => { return internal_err!( "MIN/MAX is not expected to receive scalars of incompatible types {:?}", @@ -423,6 +447,31 @@ macro_rules! min_max { }}; } +fn dictionary_batch_extreme( + values: &ArrayRef, + ordering: Ordering, +) -> Result { + let mut extreme: Option = None; + + for i in 0..values.len() { + let current = ScalarValue::try_from_array(values, i)?; + if current.is_null() { + continue; + } + + match &extreme { + Some(existing) if existing.try_cmp(¤t)? != ordering => {} + _ => extreme = Some(current), + } + } + + extreme.map_or_else(|| ScalarValue::try_from(values.data_type()), Ok) +} + +fn wrap_dictionary_scalar(key_type: &DataType, value: ScalarValue) -> ScalarValue { + ScalarValue::Dictionary(Box::new(key_type.clone()), Box::new(value)) +} + /// An accumulator to compute the maximum value #[derive(Debug, Clone)] pub struct MaxAccumulator { @@ -767,8 +816,7 @@ pub fn min_batch(values: &ArrayRef) -> Result { min_max_batch_generic(values, Ordering::Greater)? } DataType::Dictionary(_, _) => { - let values = values.as_any_dictionary().values(); - min_batch(values)? + dictionary_batch_extreme(values, Ordering::Greater)? } _ => min_max_batch!(values, min), }) @@ -776,21 +824,18 @@ pub fn min_batch(values: &ArrayRef) -> Result { /// Generic min/max implementation for complex types fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result { - if array.len() == array.null_count() { + let mut non_null_indices = (0..array.len()).filter(|&i| !array.is_null(i)); + let Some(first_idx) = non_null_indices.next() else { return ScalarValue::try_from(array.data_type()); - } - let mut extreme = ScalarValue::try_from_array(array, 0)?; - for i in 1..array.len() { + }; + + let mut extreme = ScalarValue::try_from_array(array, first_idx)?; + for i in non_null_indices { let current = ScalarValue::try_from_array(array, i)?; if current.is_null() { continue; } - if extreme.is_null() { - extreme = current; - continue; - } - let cmp = extreme.try_cmp(¤t)?; - if cmp == ordering { + if extreme.is_null() || extreme.try_cmp(¤t)? == ordering { extreme = current; } } @@ -847,10 +892,7 @@ pub fn max_batch(values: &ArrayRef) -> Result { DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?, DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?, DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?, - DataType::Dictionary(_, _) => { - let values = values.as_any_dictionary().values(); - max_batch(values)? - } + DataType::Dictionary(_, _) => dictionary_batch_extreme(values, Ordering::Less)?, _ => min_max_batch!(values, max), }) } diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 9d05c57b02e9..9d4389e0a587 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -70,8 +70,8 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result> { // TODO add checker, if the value type is complex data type Ok(vec![dict_value_type.deref().clone()]) } - // TODO add checker for datatype which min and max supported - // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function + // TODO add checker for datatype which min and max supported. + // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function. _ => Ok(input_types.to_vec()), } } @@ -1215,19 +1215,31 @@ mod tests { #[test] fn test_min_max_coerce_types() { - // the coerced types is same with input types let funs: Vec> = vec![Box::new(Min::new()), Box::new(Max::new())]; - let input_types = vec![ - vec![DataType::Int32], - vec![DataType::Decimal128(10, 2)], - vec![DataType::Decimal256(1, 1)], - vec![DataType::Utf8], + let cases = vec![ + (vec![DataType::Int32], vec![DataType::Int32]), + ( + vec![DataType::Decimal128(10, 2)], + vec![DataType::Decimal128(10, 2)], + ), + ( + vec![DataType::Decimal256(1, 1)], + vec![DataType::Decimal256(1, 1)], + ), + (vec![DataType::Utf8], vec![DataType::Utf8]), + ( + vec![DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + )], + vec![DataType::Utf8], + ), ]; for fun in funs { - for input_type in &input_types { + for (input_type, expected_type) in &cases { let result = fun.coerce_types(input_type); - assert_eq!(*input_type, result.unwrap()); + assert_eq!(*expected_type, result.unwrap()); } } } @@ -1242,7 +1254,7 @@ mod tests { } #[test] - fn test_min_max_dictionary() -> Result<()> { + fn test_min_max_dictionary_after_coercion() -> Result<()> { let values = StringArray::from(vec!["b", "c", "a", "🦀", "d"]); let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), None, Some(4)]); let dict_array = @@ -1259,7 +1271,120 @@ mod tests { let mut max_acc = MaxAccumulator::try_new(&rt_type)?; max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; let max_result = max_acc.evaluate()?; - assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string()))); + assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string()))); Ok(()) } + + fn dict_scalar(key_type: DataType, inner: ScalarValue) -> ScalarValue { + ScalarValue::Dictionary(Box::new(key_type), Box::new(inner)) + } + + fn utf8_dict_scalar(key_type: DataType, value: &str) -> ScalarValue { + dict_scalar(key_type, ScalarValue::Utf8(Some(value.to_string()))) + } + + fn string_dictionary_batch(values: &[&str], keys: &[Option]) -> ArrayRef { + let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef; + Arc::new( + DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(), + ) as ArrayRef + } + + fn optional_string_dictionary_batch( + values: &[Option<&str>], + keys: &[Option], + ) -> ArrayRef { + let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef; + Arc::new( + DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(), + ) as ArrayRef + } + + fn evaluate_dictionary_accumulator( + mut acc: impl Accumulator, + batches: &[ArrayRef], + ) -> Result { + for batch in batches { + acc.update_batch(&[Arc::clone(batch)])?; + } + acc.evaluate() + } + + fn assert_dictionary_min_max( + dict_type: &DataType, + batches: &[ArrayRef], + expected_min: &str, + expected_max: &str, + ) -> Result<()> { + let key_type = match dict_type { + DataType::Dictionary(key_type, _) => key_type.as_ref().clone(), + other => panic!("expected dictionary type, got {other:?}"), + }; + + let min_result = evaluate_dictionary_accumulator( + MinAccumulator::try_new(dict_type)?, + batches, + )?; + assert_eq!(min_result, utf8_dict_scalar(key_type.clone(), expected_min)); + + let max_result = evaluate_dictionary_accumulator( + MaxAccumulator::try_new(dict_type)?, + batches, + )?; + assert_eq!(max_result, utf8_dict_scalar(key_type, expected_max)); + + Ok(()) + } + + #[test] + fn test_min_max_dictionary_without_coercion() -> Result<()> { + let dict_array_ref = string_dictionary_batch( + &["b", "c", "a", "d"], + &[Some(0), Some(1), Some(2), Some(3)], + ); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d") + } + + #[test] + fn test_min_max_dictionary_with_nulls() -> Result<()> { + let dict_array_ref = string_dictionary_batch( + &["b", "c", "a"], + &[None, Some(0), None, Some(1), Some(2)], + ); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "c") + } + + #[test] + fn test_min_max_dictionary_ignores_unreferenced_values() -> Result<()> { + let dict_array_ref = + string_dictionary_batch(&["a", "z", "zz_unused"], &[Some(1), Some(1), None]); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "z", "z") + } + + #[test] + fn test_min_max_dictionary_ignores_referenced_null_values() -> Result<()> { + let dict_array_ref = optional_string_dictionary_batch( + &[Some("b"), None, Some("a"), Some("d")], + &[Some(0), Some(1), Some(2), Some(3)], + ); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d") + } + + #[test] + fn test_min_max_dictionary_multi_batch() -> Result<()> { + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let batch1 = string_dictionary_batch(&["b", "c"], &[Some(0), Some(1)]); + let batch2 = string_dictionary_batch(&["a", "d"], &[Some(0), Some(1)]); + + assert_dictionary_min_max(&dict_type, &[batch1, batch2], "a", "d") + } }