diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 859ee6829e2..870ec57b35b 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -30,6 +30,142 @@ pub fn vortex_array::arrays::PrimitiveArray::with_iterator(&self, f: F) -> pub mod vortex_array::aggregate_fn +pub mod vortex_array::aggregate_fn::combined + +pub struct vortex_array::aggregate_fn::combined::Combined(pub T) + +impl vortex_array::aggregate_fn::combined::Combined + +pub fn vortex_array::aggregate_fn::combined::Combined::new(inner: T) -> Self + +impl core::clone::Clone for vortex_array::aggregate_fn::combined::Combined + +pub fn vortex_array::aggregate_fn::combined::Combined::clone(&self) -> vortex_array::aggregate_fn::combined::Combined + +impl core::fmt::Debug for vortex_array::aggregate_fn::combined::Combined + +pub fn vortex_array::aggregate_fn::combined::Combined::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::combined::Combined + +pub type vortex_array::aggregate_fn::combined::Combined::Options = vortex_array::aggregate_fn::combined::PairOptions<<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Options, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Options> + +pub type vortex_array::aggregate_fn::combined::Combined::Partial = (<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Partial, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Partial) + +pub fn vortex_array::aggregate_fn::combined::Combined::accumulate(&self, _state: &mut Self::Partial, _batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::combined::Combined::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::combine_partials(&self, 
partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::combined::Combined::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::finalize(&self, states: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::finalize_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::combined::Combined::is_saturated(&self, partial: &Self::Partial) -> bool + +pub fn vortex_array::aggregate_fn::combined::Combined::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::Combined::reset(&self, partial: &mut Self::Partial) + +pub fn vortex_array::aggregate_fn::combined::Combined::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::Combined::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::combined::Combined::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::try_accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + +pub struct vortex_array::aggregate_fn::combined::PairOptions(pub L, pub R) + +impl core::marker::StructuralPartialEq for 
vortex_array::aggregate_fn::combined::PairOptions + +impl core::clone::Clone for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::clone(&self) -> vortex_array::aggregate_fn::combined::PairOptions + +impl core::cmp::Eq for vortex_array::aggregate_fn::combined::PairOptions + +impl core::cmp::PartialEq for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::eq(&self, other: &vortex_array::aggregate_fn::combined::PairOptions) -> bool + +impl core::fmt::Debug for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +pub trait vortex_array::aggregate_fn::combined::BinaryCombined: 'static + core::marker::Send + core::marker::Sync + core::clone::Clone + +pub type vortex_array::aggregate_fn::combined::BinaryCombined::Left: vortex_array::aggregate_fn::AggregateFnVTable + +pub type vortex_array::aggregate_fn::combined::BinaryCombined::Right: vortex_array::aggregate_fn::AggregateFnVTable + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::coerce_args(&self, options: &vortex_array::aggregate_fn::combined::CombinedOptions, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn 
vortex_array::aggregate_fn::combined::BinaryCombined::finalize(&self, left: vortex_array::ArrayRef, right: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::finalize_scalar(&self, left_scalar: vortex_array::scalar::Scalar, right_scalar: vortex_array::scalar::Scalar) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::left(&self) -> Self::Left + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::left_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::return_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::right(&self) -> Self::Right + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::right_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::serialize(&self, options: &vortex_array::aggregate_fn::combined::CombinedOptions) -> vortex_error::VortexResult>> + +impl vortex_array::aggregate_fn::combined::BinaryCombined for vortex_array::aggregate_fn::fns::mean::Mean + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Left = vortex_array::aggregate_fn::fns::sum::Sum + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Right = vortex_array::aggregate_fn::fns::count::Count + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::coerce_args(&self, _options: &vortex_array::aggregate_fn::combined::PairOptions<::Options, ::Options>, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::finalize(&self, sum: 
vortex_array::ArrayRef, count: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::finalize_scalar(&self, left_scalar: vortex_array::scalar::Scalar, right_scalar: vortex_array::scalar::Scalar) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::left(&self) -> vortex_array::aggregate_fn::fns::sum::Sum + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::left_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::return_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::right(&self) -> vortex_array::aggregate_fn::fns::count::Count + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::right_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, _options: &vortex_array::aggregate_fn::combined::CombinedOptions) -> vortex_error::VortexResult>> + +pub type vortex_array::aggregate_fn::combined::CombinedOptions = vortex_array::aggregate_fn::combined::PairOptions<<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Options, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Options> + pub mod vortex_array::aggregate_fn::fns pub mod vortex_array::aggregate_fn::fns::count @@ -344,6 +480,52 @@ pub struct vortex_array::aggregate_fn::fns::last::LastPartial pub fn vortex_array::aggregate_fn::fns::last::last(array: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +pub mod vortex_array::aggregate_fn::fns::mean + +pub struct vortex_array::aggregate_fn::fns::mean::Mean + +impl vortex_array::aggregate_fn::fns::mean::Mean + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::combined() -> vortex_array::aggregate_fn::combined::Combined + +impl core::clone::Clone for 
vortex_array::aggregate_fn::fns::mean::Mean + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::clone(&self) -> vortex_array::aggregate_fn::fns::mean::Mean + +impl core::fmt::Debug for vortex_array::aggregate_fn::fns::mean::Mean + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::aggregate_fn::combined::BinaryCombined for vortex_array::aggregate_fn::fns::mean::Mean + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Left = vortex_array::aggregate_fn::fns::sum::Sum + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Right = vortex_array::aggregate_fn::fns::count::Count + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::coerce_args(&self, _options: &vortex_array::aggregate_fn::combined::PairOptions<::Options, ::Options>, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::finalize(&self, sum: vortex_array::ArrayRef, count: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::finalize_scalar(&self, left_scalar: vortex_array::scalar::Scalar, right_scalar: vortex_array::scalar::Scalar) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::left(&self) -> vortex_array::aggregate_fn::fns::sum::Sum + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::left_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::return_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::right(&self) -> vortex_array::aggregate_fn::fns::count::Count + 
+pub fn vortex_array::aggregate_fn::fns::mean::Mean::right_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, _options: &vortex_array::aggregate_fn::combined::CombinedOptions) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::fns::mean::mean(array: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + pub mod vortex_array::aggregate_fn::fns::min_max pub struct vortex_array::aggregate_fn::fns::min_max::MinMax @@ -1070,6 +1252,42 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::to_scalar(&self, partial: &Sel pub fn vortex_array::aggregate_fn::fns::sum::Sum::try_accumulate(&self, _state: &mut Self::Partial, _batch: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::combined::Combined + +pub type vortex_array::aggregate_fn::combined::Combined::Options = vortex_array::aggregate_fn::combined::PairOptions<<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Options, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Options> + +pub type vortex_array::aggregate_fn::combined::Combined::Partial = (<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Partial, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Partial) + +pub fn vortex_array::aggregate_fn::combined::Combined::accumulate(&self, _state: &mut Self::Partial, _batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::combined::Combined::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn 
vortex_array::aggregate_fn::combined::Combined::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::finalize(&self, states: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::finalize_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::combined::Combined::is_saturated(&self, partial: &Self::Partial) -> bool + +pub fn vortex_array::aggregate_fn::combined::Combined::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::Combined::reset(&self, partial: &mut Self::Partial) + +pub fn vortex_array::aggregate_fn::combined::Combined::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::Combined::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::combined::Combined::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::try_accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + pub trait vortex_array::aggregate_fn::AggregateFnVTableExt: vortex_array::aggregate_fn::AggregateFnVTable pub fn vortex_array::aggregate_fn::AggregateFnVTableExt::bind(&self, options: Self::Options) -> vortex_array::aggregate_fn::AggregateFnRef diff --git 
a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 665dcec0747..a751a7c5749 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -311,7 +311,7 @@ impl GroupedAccumulator { if validity.value(i) { let group = elements.slice(offset..offset + size)?; accumulator.accumulate(&group, ctx)?; - states.append_scalar(&accumulator.finish()?)?; + states.append_scalar(&accumulator.flush()?)?; } else { states.append_null() } diff --git a/vortex-array/src/aggregate_fn/combined.rs b/vortex-array/src/aggregate_fn/combined.rs new file mode 100644 index 00000000000..fbd8706d6f7 --- /dev/null +++ b/vortex-array/src/aggregate_fn/combined.rs @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Generic adapter for aggregates whose result is computed from two child +//! aggregate functions, e.g. `Mean = Sum / Count`. + +use std::fmt::Debug; +use std::fmt::Display; +use std::fmt::Formatter; +use std::fmt::{self}; +use std::hash::Hash; + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_err; +use vortex_session::VortexSession; + +use crate::ArrayRef; +use crate::Columnar; +use crate::ExecutionCtx; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::builtins::ArrayBuiltins; +use crate::dtype::DType; +use crate::dtype::FieldName; +use crate::dtype::FieldNames; +use crate::dtype::Nullability; +use crate::dtype::StructFields; +use crate::scalar::Scalar; + +/// Pair of options for the two children of a [`BinaryCombined`] aggregate. +/// +/// Wrapper around `(L, R)` because the [`AggregateFnVTable::Options`] bound +/// requires `Display`, which tuples don't implement. 
+#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct PairOptions(pub L, pub R); + +impl Display for PairOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "({}, {})", self.0, self.1) + } +} + +// Convenience aliases so signatures stay readable. +type LeftOptions = <::Left as AggregateFnVTable>::Options; +type RightOptions = <::Right as AggregateFnVTable>::Options; +type LeftPartial = <::Left as AggregateFnVTable>::Partial; +type RightPartial = <::Right as AggregateFnVTable>::Partial; +/// Combined options for a [`BinaryCombined`] aggregate. +pub type CombinedOptions = PairOptions, RightOptions>; + +/// Declare an aggregate function in terms of two child aggregates. +pub trait BinaryCombined: 'static + Send + Sync + Clone { + /// The left child aggregate vtable. + type Left: AggregateFnVTable; + /// The right child aggregate vtable. + type Right: AggregateFnVTable; + + /// Stable identifier for the combined aggregate. + fn id(&self) -> AggregateFnId; + + /// Construct the left child vtable. + fn left(&self) -> Self::Left; + + /// Construct the right child vtable. + fn right(&self) -> Self::Right; + + /// Field name for the left child in the partial struct dtype. + fn left_name(&self) -> &'static str { + "left" + } + + /// Field name for the right child in the partial struct dtype. + fn right_name(&self) -> &'static str { + "right" + } + + /// Return type of the combined aggregate. + fn return_dtype(&self, input_dtype: &DType) -> Option; + + /// Combine the finalized left and right results into the final aggregate. + fn finalize(&self, left: ArrayRef, right: ArrayRef) -> VortexResult; + + fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult; + + /// Serialize the options for this combined aggregate. Default: not serializable. + fn serialize(&self, options: &CombinedOptions) -> VortexResult>> { + let _ = options; + Ok(None) + } + + /// Deserialize the options for this combined aggregate. Default: bails. 
+ fn deserialize( + &self, + metadata: &[u8], + session: &VortexSession, + ) -> VortexResult> { + let _ = (metadata, session); + vortex_bail!( + "Combined aggregate function {} is not deserializable", + BinaryCombined::id(self) + ); + } + + /// Coerce the input type. Default: chains `right.coerce_args(left.coerce_args(input))`. + fn coerce_args( + &self, + options: &CombinedOptions, + input_dtype: &DType, + ) -> VortexResult { + let left_coerced = self.left().coerce_args(&options.0, input_dtype)?; + self.right().coerce_args(&options.1, &left_coerced) + } +} + +/// Adapter that exposes any [`BinaryCombined`] as an [`AggregateFnVTable`]. +#[derive(Clone, Debug)] +pub struct Combined(pub T); + +impl Combined { + /// Construct a new combined aggregate vtable. + pub fn new(inner: T) -> Self { + Self(inner) + } +} + +impl AggregateFnVTable for Combined { + type Options = CombinedOptions; + type Partial = (LeftPartial, RightPartial); + + fn id(&self) -> AggregateFnId { + self.0.id() + } + + fn serialize(&self, options: &Self::Options) -> VortexResult>> { + BinaryCombined::serialize(&self.0, options) + } + + fn deserialize(&self, metadata: &[u8], session: &VortexSession) -> VortexResult { + BinaryCombined::deserialize(&self.0, metadata, session) + } + + fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult { + BinaryCombined::coerce_args(&self.0, options, input_dtype) + } + + fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { + BinaryCombined::return_dtype(&self.0, input_dtype) + } + + fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { + let l = self.0.left().partial_dtype(&options.0, input_dtype)?; + let r = self.0.right().partial_dtype(&options.1, input_dtype)?; + Some(struct_dtype(self.0.left_name(), self.0.right_name(), l, r)) + } + + fn empty_partial( + &self, + options: &Self::Options, + input_dtype: &DType, + ) -> VortexResult { + Ok(( + 
self.0.left().empty_partial(&options.0, input_dtype)?, + self.0.right().empty_partial(&options.1, input_dtype)?, + )) + } + + fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { + if other.is_null() { + return Ok(()); + } + let s = other.as_struct(); + let lname = self.0.left_name(); + let rname = self.0.right_name(); + let l_field = s + .field(lname) + .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", lname))?; + let r_field = s + .field(rname) + .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", rname))?; + self.0.left().combine_partials(&mut partial.0, l_field)?; + self.0.right().combine_partials(&mut partial.1, r_field)?; + Ok(()) + } + + fn to_scalar(&self, partial: &Self::Partial) -> VortexResult { + let l_scalar = self.0.left().to_scalar(&partial.0)?; + let r_scalar = self.0.right().to_scalar(&partial.1)?; + let dtype = struct_dtype( + self.0.left_name(), + self.0.right_name(), + l_scalar.dtype().clone(), + r_scalar.dtype().clone(), + ); + Ok(Scalar::struct_(dtype, vec![l_scalar, r_scalar])) + } + + fn reset(&self, partial: &mut Self::Partial) { + self.0.left().reset(&mut partial.0); + self.0.right().reset(&mut partial.1); + } + + fn is_saturated(&self, partial: &Self::Partial) -> bool { + self.0.left().is_saturated(&partial.0) && self.0.right().is_saturated(&partial.1) + } + + /// Fans out to each child's `try_accumulate`, falling back to `accumulate` + /// against a lazily-canonicalized batch. We always claim to handle the + /// batch ourselves so [`Self::accumulate`] is unreachable — this is the + /// same trick `Count` uses to opt out of the canonicalization path. + fn try_accumulate( + &self, + state: &mut Self::Partial, + batch: &ArrayRef, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let mut canonical: Option = None; + if !self.0.left().try_accumulate(&mut state.0, batch, ctx)? 
{ + let c = canonical.insert(batch.clone().execute::(ctx)?); + self.0.left().accumulate(&mut state.0, c, ctx)?; + } + if !self.0.right().try_accumulate(&mut state.1, batch, ctx)? { + let c = match canonical.as_ref() { + Some(c) => c, + None => canonical.insert(batch.clone().execute::(ctx)?), + }; + self.0.right().accumulate(&mut state.1, c, ctx)?; + } + Ok(true) + } + + fn accumulate( + &self, + _state: &mut Self::Partial, + _batch: &Columnar, + _ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + unreachable!("Combined::try_accumulate handles all batches") + } + + fn finalize(&self, states: ArrayRef) -> VortexResult { + let l_field = states.get_item(FieldName::from(self.0.left_name()))?; + let r_field = states.get_item(FieldName::from(self.0.right_name()))?; + let l_finalized = self.0.left().finalize(l_field)?; + let r_finalized = self.0.right().finalize(r_field)?; + BinaryCombined::finalize(&self.0, l_finalized, r_finalized) + } + + fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult { + let l_scalar = self.0.left().finalize_scalar(&partial.0)?; + let r_scalar = self.0.right().finalize_scalar(&partial.1)?; + BinaryCombined::finalize_scalar(&self.0, l_scalar, r_scalar) + } +} + +fn struct_dtype(left_name: &str, right_name: &str, left: DType, right: DType) -> DType { + DType::Struct( + StructFields::new( + FieldNames::from_iter([FieldName::from(left_name), FieldName::from(right_name)]), + vec![left, right], + ), + Nullability::NonNullable, + ) +} diff --git a/vortex-array/src/aggregate_fn/fns/mean/mod.rs b/vortex-array/src/aggregate_fn/fns/mean/mod.rs new file mode 100644 index 00000000000..f7d471023ee --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/mean/mod.rs @@ -0,0 +1,262 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; + +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::aggregate_fn::Accumulator; +use 
crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::DynAccumulator; +use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::combined::BinaryCombined; +use crate::aggregate_fn::combined::Combined; +use crate::aggregate_fn::combined::CombinedOptions; +use crate::aggregate_fn::combined::PairOptions; +use crate::aggregate_fn::fns::count::Count; +use crate::aggregate_fn::fns::sum::Sum; +use crate::builtins::ArrayBuiltins; +use crate::dtype::DType; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::scalar::Scalar; +use crate::scalar_fn::fns::operators::Operator; + +/// Compute the arithmetic mean of an array. +/// +/// See [`Mean`] for details. +pub fn mean(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let mut acc = Accumulator::try_new( + Mean::combined(), + PairOptions(EmptyOptions, EmptyOptions), + array.dtype().clone(), + )?; + acc.accumulate(array, ctx)?; + acc.finish() +} + +/// Compute the arithmetic mean of an array. +/// +/// Implemented as `Sum / Count` via [`BinaryCombined`]. +/// +/// Coercion / return type: +/// - Booleans and primitive numeric types are coerced to `f64` and the result +/// is a nullable `f64`. 
+/// - Decimals are intended to stay decimals, but decimal inputs are not implemented yet. +#[derive(Clone, Debug)] +pub struct Mean; + +impl Mean { + pub fn combined() -> Combined { + Combined(Mean) + } +} + +impl BinaryCombined for Mean { + type Left = Sum; + type Right = Count; + + fn id(&self) -> AggregateFnId { + AggregateFnId::new("vortex.mean") + } + + fn left(&self) -> Sum { + Sum + } + + fn right(&self) -> Count { + Count + } + + fn left_name(&self) -> &'static str { + "sum" + } + + fn right_name(&self) -> &'static str { + "count" + } + + fn return_dtype(&self, input_dtype: &DType) -> Option { + Some(mean_output_dtype(input_dtype)?.with_nullability(Nullability::Nullable)) + } + + fn finalize(&self, sum: ArrayRef, count: ArrayRef) -> VortexResult { + let target = match sum.dtype() { + DType::Decimal(..) => sum.dtype().with_nullability(Nullability::Nullable), + _ => DType::Primitive(PType::F64, Nullability::Nullable), + }; + let sum_cast = sum.cast(target.clone())?; + let count_cast = count.cast(target)?; + sum_cast.binary(count_cast, Operator::Div) + } + + fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult { + if let DType::Decimal(..) 
= left_scalar.dtype() { + vortex_bail!("mean::finalize_scalar not yet implemented for decimal inputs"); + } + + let target = DType::Primitive(PType::F64, Nullability::Nullable); + let sum_cast = left_scalar.cast(&target)?; + let count_cast = right_scalar.cast(&target)?; + + let sum = sum_cast.as_primitive().typed_value::(); + let count = count_cast.as_primitive().typed_value::(); + let value = match (sum, count) { + (None, _) | (_, None) | (_, Some(0.0)) => return Ok(Scalar::null(target)), // Null sum (e.g. Sum overflowed), null count, or zero count (no valid values) + (Some(s), Some(c)) => s / c, + }; + Ok(Scalar::primitive(value, Nullability::Nullable)) + } + + fn serialize(&self, _options: &CombinedOptions) -> VortexResult>> { + unimplemented!("mean is not yet serializable"); + } + + fn coerce_args( + &self, + _options: &PairOptions< + ::Options, + ::Options, + >, + input_dtype: &DType, + ) -> VortexResult { + // Advisory hint for query planners: where possible, cast input to the + // type we're going to compute the mean in. + Ok(coerced_input_dtype(input_dtype).unwrap_or_else(|| input_dtype.clone())) + } +} + +/// Hint for callers: what to cast the input to before accumulation. +/// +/// - Bool stays as bool — `Sum` has a native bool path and bool → f64 isn't +/// currently a direct cast in vortex. +/// - Primitive numerics → `f64` so the sum and finalize work without overflow. +fn coerced_input_dtype(input_dtype: &DType) -> Option { + match input_dtype { + DType::Bool(_) => Some(input_dtype.clone()), + DType::Primitive(_, n) => Some(DType::Primitive(PType::F64, *n)), + DType::Decimal(..) => { + unimplemented!("mean is not implemented for decimals yet") + } + _ => None, + } +} + +fn mean_output_dtype(input_dtype: &DType) -> Option { + match input_dtype { + DType::Bool(_) | DType::Primitive(..) => { + Some(DType::Primitive(PType::F64, Nullability::Nullable)) + } + DType::Decimal(..) 
=> { + unimplemented!("mean for decimals is not yet implemented"); + } + _ => None, + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use super::*; + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::arrays::BoolArray; + use crate::arrays::ChunkedArray; + use crate::arrays::ConstantArray; + use crate::arrays::PrimitiveArray; + use crate::validity::Validity; + + #[test] + fn mean_all_valid() -> VortexResult<()> { + let array = PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0, 5.0], Validity::NonNullable) + .into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } + + #[test] + fn mean_with_nulls() -> VortexResult<()> { + let array = PrimitiveArray::from_option_iter([Some(2.0f64), None, Some(4.0)]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } + + #[test] + fn mean_integers() -> VortexResult<()> { + let array = PrimitiveArray::new(buffer![10i32, 20, 30], Validity::NonNullable).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(20.0)); + Ok(()) + } + + #[test] + fn mean_bool() -> VortexResult<()> { + let array: BoolArray = [true, false, true, true].into_iter().collect(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array.into_array(), &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(0.75)); + Ok(()) + } + + #[test] + fn mean_constant_non_null() -> VortexResult<()> { + let array = ConstantArray::new(5.0f64, 4); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array.into_array(), &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(5.0)); 
+ Ok(()) + } + + #[test] + fn mean_chunked() -> VortexResult<()> { + let chunk1 = PrimitiveArray::from_option_iter([Some(1.0f64), None, Some(3.0)]); + let chunk2 = PrimitiveArray::from_option_iter([Some(5.0f64), None]); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&chunked.into_array(), &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } + + #[test] + fn mean_all_null_returns_null() -> VortexResult<()> { + let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), None); + Ok(()) + } + + #[test] + fn mean_multi_batch() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut acc = Accumulator::try_new( + Mean::combined(), + PairOptions(EmptyOptions, EmptyOptions), + dtype, + )?; + + let batch1 = + PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + + let batch2 = PrimitiveArray::new(buffer![4.0f64, 5.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + + let result = acc.finish()?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/mod.rs b/vortex-array/src/aggregate_fn/fns/mod.rs index 38d5340cd1f..f1281c18544 100644 --- a/vortex-array/src/aggregate_fn/fns/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mod.rs @@ -6,6 +6,7 @@ pub mod first; pub mod is_constant; pub mod is_sorted; pub mod last; +pub mod mean; pub mod min_max; pub mod nan_count; pub mod sum; diff --git a/vortex-array/src/aggregate_fn/mod.rs b/vortex-array/src/aggregate_fn/mod.rs index 
b697265e62b..1bcd2cac030 100644 --- a/vortex-array/src/aggregate_fn/mod.rs +++ b/vortex-array/src/aggregate_fn/mod.rs @@ -32,6 +32,7 @@ pub use erased::*; mod options; pub use options::*; +pub mod combined; pub mod fns; pub mod kernels; pub mod proto;