diff --git a/Cargo.lock b/Cargo.lock index 02da8661eedea..0601f2ec16be7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2277,6 +2277,7 @@ dependencies = [ "datafusion-expr", "datafusion-expr-common", "datafusion-macros", + "datafusion-physical-expr-common", "env_logger", "hex", "itertools 0.14.0", diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 967b35d2eb985..fdc31476b530f 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -78,6 +78,7 @@ datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-macros = { workspace = true } +datafusion-physical-expr-common = { workspace = true } hex = { workspace = true, optional = true } itertools = { workspace = true } log = { workspace = true } diff --git a/datafusion/functions/src/core/nullif.rs b/datafusion/functions/src/core/nullif.rs index 89c4d4eb0fc81..3014901fe2b83 100644 --- a/datafusion/functions/src/core/nullif.rs +++ b/datafusion/functions/src/core/nullif.rs @@ -18,11 +18,11 @@ use arrow::datatypes::DataType; use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs}; -use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::nullif::nullif; use datafusion_common::{Result, ScalarValue, utils::take_function_args}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; +use datafusion_physical_expr_common::datum::compare_with_eq; #[user_doc( doc_section(label = "Conditional Functions"), @@ -111,25 +111,29 @@ impl ScalarUDFImpl for NullIfFunc { /// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed. 
fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> { let [lhs, rhs] = take_function_args("nullif", args)?; + let is_nested = lhs.data_type().is_nested(); match (lhs, rhs) { (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { let rhs = rhs.to_scalar()?; - let array = nullif(lhs, &eq(&lhs, &rhs)?)?; + let eq_array = compare_with_eq(lhs, &rhs, is_nested)?; + let array = nullif(lhs, &eq_array)?; Ok(ColumnarValue::Array(array)) } (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => { - let array = nullif(lhs, &eq(&lhs, &rhs)?)?; + let eq_array = compare_with_eq(lhs, rhs, is_nested)?; + let array = nullif(lhs, &eq_array)?; Ok(ColumnarValue::Array(array)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => { let lhs_s = lhs.to_scalar()?; let lhs_a = lhs.to_array_of_size(rhs.len())?; + let eq_array = compare_with_eq(&lhs_s, rhs, is_nested)?; let array = nullif( // nullif in arrow-select does not support Datum, so we need to convert to array lhs_a.as_ref(), - &eq(&lhs_s, &rhs)?, + &eq_array, )?; Ok(ColumnarValue::Array(array)) } @@ -148,7 +152,12 @@ fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> { mod tests { use std::sync::Arc; - use arrow::array::*; + use arrow::{ + array::*, + buffer::NullBuffer, + datatypes::{Field, Fields, Int64Type}, + }; + use datafusion_common::assert_batches_eq; use super::*; @@ -251,6 +260,88 @@ mod tests { Ok(()) } + #[test] + fn nullif_struct() -> Result<()> { + let fields = Fields::from(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Utf8, true), + ]); + + let lhs_a = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); + let lhs_b = Arc::new(StringArray::from(vec![Some("1"), Some("2"), None])); + let lhs_nulls = Some(NullBuffer::from(vec![true, true, false])); + let lhs = ColumnarValue::Array(Arc::new(StructArray::new( + fields.clone(), + vec![lhs_a, lhs_b], + lhs_nulls, + ))); + + let rhs_a = Arc::new(Int64Array::from(vec![Some(1), Some(9), None])); + let rhs_b =
Arc::new(StringArray::from(vec![Some("1"), Some("2"), None])); + let rhs_nulls = Some(NullBuffer::from(vec![true, true, false])); + let rhs = ColumnarValue::Array(Arc::new(StructArray::new( + fields, + vec![rhs_a, rhs_b], + rhs_nulls, + ))); + + let result = nullif_func(&[lhs, rhs])?; + let result = result.into_array(0).expect("Failed to convert to array"); + let batch = RecordBatch::try_from_iter([("result", result)])?; + + let expected = [ + "+--------------+", + "| result       |", + "+--------------+", + "|              |", + "| {a: 2, b: 2} |", + "|              |", + "+--------------+", + ]; + + assert_batches_eq!(expected, &[batch]); + + Ok(()) + } + + #[test] + fn nullif_list() -> Result<()> { + let lhs = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3)]), + Some(vec![]), + Some(vec![Some(5), Some(6), Some(7)]), + None, + ])); + let lhs = ColumnarValue::Array(lhs); + + let rhs = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![ + Some(vec![Some(1), Some(2)]), + ])); + let rhs = ColumnarValue::Scalar(ScalarValue::List(rhs)); + + let result = nullif_func(&[lhs, rhs])?; + let result = result.into_array(0).expect("Failed to convert to array"); + + let batch = RecordBatch::try_from_iter([("result", result)])?; + + let expected = [ + "+-----------+", + "| result    |", + "+-----------+", + "|           |", + "| [3]       |", + "| []        |", + "| [5, 6, 7] |", + "|           |", + "+-----------+", + ]; + + assert_batches_eq!(expected, &[batch]); + + Ok(()) + } + #[test] fn nullif_literal_first() -> Result<()> { let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), Some(4)]);