From b36689c2770b522f88c697bf7d9b85cf7e1a23b9 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Tue, 7 Apr 2026 22:48:18 +0100 Subject: [PATCH 1/2] Support returning row count in prunning aggregate expressions This lets us effectively prune expressions like IsNotNull Signed-off-by: Robert Kruszewski --- Cargo.lock | 1 + java/testfiles/Cargo.lock | 1 + rust-toolchain.toml | 2 +- vortex-array/public-api.lock | 198 +++++++++++++++++ vortex-array/src/aggregate_fn/fns/mod.rs | 1 + .../src/aggregate_fn/fns/row_count/mod.rs | 143 +++++++++++++ vortex-array/src/aggregate_fn/session.rs | 2 + vortex-array/src/expr/exprs.rs | 22 ++ vortex-array/src/expr/pruning/pruning_expr.rs | 4 + vortex-array/src/scalar_fn/fns/is_not_null.rs | 47 ++-- vortex-array/src/scalar_fn/fns/mod.rs | 1 + .../src/scalar_fn/fns/stats_expression.rs | 202 ++++++++++++++++++ vortex-file/src/file.rs | 24 ++- vortex-file/src/v2/file_stats_reader.rs | 75 ++++++- vortex-layout/Cargo.toml | 1 + vortex-layout/public-api.lock | 2 +- vortex-layout/src/layouts/zoned/reader.rs | 30 ++- vortex-layout/src/layouts/zoned/zone_map.rs | 135 +++++++++++- 18 files changed, 820 insertions(+), 71 deletions(-) create mode 100644 vortex-array/src/aggregate_fn/fns/row_count/mod.rs create mode 100644 vortex-array/src/scalar_fn/fns/stats_expression.rs diff --git a/Cargo.lock b/Cargo.lock index c9b74f2ff69..ec8a9ceb7e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10635,6 +10635,7 @@ dependencies = [ "vortex-io", "vortex-mask", "vortex-metrics", + "vortex-runend", "vortex-scan", "vortex-sequence", "vortex-session", diff --git a/java/testfiles/Cargo.lock b/java/testfiles/Cargo.lock index e3fe6731a30..29e7e0d2f4e 100644 --- a/java/testfiles/Cargo.lock +++ b/java/testfiles/Cargo.lock @@ -2399,6 +2399,7 @@ dependencies = [ "vortex-io", "vortex-mask", "vortex-metrics", + "vortex-runend", "vortex-scan", "vortex-sequence", "vortex-session", diff --git a/rust-toolchain.toml b/rust-toolchain.toml index cb6bd203288..dbcc5a2f04e 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,4 +1,4 @@ [toolchain] channel = "1.91.0" components = ["rust-src", "rustfmt", "clippy", "rust-analyzer"] -profile = "minimal" \ No newline at end of file +profile = "minimal" diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 8ee3c97f17b..ffe64490317 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -474,6 +474,54 @@ pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::try_accumulate(&sel pub fn vortex_array::aggregate_fn::fns::nan_count::nan_count(array: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +pub mod vortex_array::aggregate_fn::fns::row_count + +pub struct vortex_array::aggregate_fn::fns::row_count::RowCount + +impl core::clone::Clone for vortex_array::aggregate_fn::fns::row_count::RowCount + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::clone(&self) -> vortex_array::aggregate_fn::fns::row_count::RowCount + +impl core::fmt::Debug for vortex_array::aggregate_fn::fns::row_count::RowCount + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::row_count::RowCount + +pub type vortex_array::aggregate_fn::fns::row_count::RowCount::Options = vortex_array::aggregate_fn::EmptyOptions + +pub type vortex_array::aggregate_fn::fns::row_count::RowCount::Partial = u64 + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::accumulate(&self, _partial: &mut Self::Partial, _batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::empty_partial(&self, _options: &Self::Options, _input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::finalize_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::is_saturated(&self, _partial: &Self::Partial) -> bool + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::reset(&self, partial: &mut Self::Partial) + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::return_dtype(&self, _options: &Self::Options, _input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::try_accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + pub mod vortex_array::aggregate_fn::fns::sum pub enum vortex_array::aggregate_fn::fns::sum::SumState @@ -1034,6 +1082,42 @@ pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::to_scalar(&self, pa pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::try_accumulate(&self, _state: &mut Self::Partial, _batch: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::row_count::RowCount + +pub type vortex_array::aggregate_fn::fns::row_count::RowCount::Options = vortex_array::aggregate_fn::EmptyOptions + +pub type vortex_array::aggregate_fn::fns::row_count::RowCount::Partial = u64 + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::accumulate(&self, _partial: &mut Self::Partial, _batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::empty_partial(&self, _options: &Self::Options, _input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::finalize_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::is_saturated(&self, _partial: &Self::Partial) -> bool + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::reset(&self, partial: &mut Self::Partial) + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::return_dtype(&self, _options: &Self::Options, _input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::row_count::RowCount::try_accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::sum::Sum pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::EmptyOptions @@ -12376,12 +12460,16 @@ pub fn vortex_array::expr::pack(elements: impl core::iter::traits::collect::Into pub fn vortex_array::expr::root() -> vortex_array::expr::Expression +pub fn vortex_array::expr::row_count() -> vortex_array::expr::Expression + pub fn vortex_array::expr::select(field_names: impl core::convert::Into, child: vortex_array::expr::Expression) -> vortex_array::expr::Expression pub fn vortex_array::expr::select_exclude(fields: impl core::convert::Into, child: vortex_array::expr::Expression) -> vortex_array::expr::Expression pub fn vortex_array::expr::split_conjunction(expr: &vortex_array::expr::Expression) -> alloc::vec::Vec +pub fn vortex_array::expr::stats_expression(agg: vortex_array::aggregate_fn::AggregateFnRef) -> vortex_array::expr::Expression + pub fn vortex_array::expr::zip_expr(mask: vortex_array::expr::Expression, if_true: vortex_array::expr::Expression, if_false: vortex_array::expr::Expression) -> vortex_array::expr::Expression pub type vortex_array::expr::Annotations<'a, A> = vortex_utils::aliases::hash_map::HashMap<&'a vortex_array::expr::Expression, vortex_utils::aliases::hash_set::HashSet> @@ -13102,6 +13190,14 @@ pub fn V::matches(array: &vortex_array::ArrayRef) -> bool pub fn V::try_match<'a>(array: &'a vortex_array::ArrayRef) -> core::option::Option> +impl vortex_array::matcher::Matcher for vortex_array::scalar_fn::fns::stats_expression::StatsExpressionOf + +pub type vortex_array::scalar_fn::fns::stats_expression::StatsExpressionOf::Match<'a> = &'a vortex_array::ArrayRef + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpressionOf::matches(array: &vortex_array::ArrayRef) -> bool + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpressionOf::try_match(array: &vortex_array::ArrayRef) -> core::option::Option + pub mod vortex_array::memory pub struct vortex_array::memory::DefaultHostAllocator @@ -17150,6 +17246,70 @@ pub fn vortex_array::scalar_fn::fns::select::Select::stat_falsification(&self, o pub fn vortex_array::scalar_fn::fns::select::Select::validity(&self, options: &Self::Options, expression: &vortex_array::expr::Expression) -> vortex_error::VortexResult> +pub mod vortex_array::scalar_fn::fns::stats_expression + +pub struct vortex_array::scalar_fn::fns::stats_expression::StatsExpression + +impl core::clone::Clone for vortex_array::scalar_fn::fns::stats_expression::StatsExpression + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::clone(&self) -> vortex_array::scalar_fn::fns::stats_expression::StatsExpression + +impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::stats_expression::StatsExpression + +pub type vortex_array::scalar_fn::fns::stats_expression::StatsExpression::Options = vortex_array::aggregate_fn::AggregateFnRef + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::Arity + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::child_name(&self, _options: &Self::Options, _child_idx: usize) -> vortex_array::scalar_fn::ChildName + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::execute(&self, agg: &Self::Options, _args: &dyn vortex_array::scalar_fn::ExecutionArgs, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::fmt_sql(&self, agg: &Self::Options, _expr: &vortex_array::expr::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::is_fallible(&self, _options: &Self::Options) -> bool + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::is_null_sensitive(&self, _options: &Self::Options) -> bool + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::reduce(&self, options: &Self::Options, node: &dyn vortex_array::scalar_fn::ReduceNode, ctx: &dyn vortex_array::scalar_fn::ReduceCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::return_dtype(&self, agg: &Self::Options, _args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::simplify(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, ctx: &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::simplify_untyped(&self, options: &Self::Options, expr: &vortex_array::expr::Expression) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::stat_expression(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, stat: vortex_array::expr::stats::Stat, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::stat_falsification(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::validity(&self, options: &Self::Options, expression: &vortex_array::expr::Expression) -> vortex_error::VortexResult> + +pub struct vortex_array::scalar_fn::fns::stats_expression::StatsExpressionOf(_) + +impl core::fmt::Debug for vortex_array::scalar_fn::fns::stats_expression::StatsExpressionOf + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpressionOf::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::matcher::Matcher for vortex_array::scalar_fn::fns::stats_expression::StatsExpressionOf + +pub type vortex_array::scalar_fn::fns::stats_expression::StatsExpressionOf::Match<'a> = &'a vortex_array::ArrayRef + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpressionOf::matches(array: &vortex_array::ArrayRef) -> bool + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpressionOf::try_match(array: &vortex_array::ArrayRef) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::stats_expression::contains_stats_fn_array(array: &vortex_array::ArrayRef) -> bool + +pub fn vortex_array::scalar_fn::fns::stats_expression::substitute_stats_fn_array(array: vortex_array::ArrayRef, replacement: &vortex_array::ArrayRef) -> vortex_error::VortexResult + pub mod vortex_array::scalar_fn::fns::zip pub struct vortex_array::scalar_fn::fns::zip::Zip @@ -18358,6 +18518,44 @@ pub fn vortex_array::scalar_fn::fns::select::Select::stat_falsification(&self, o pub fn vortex_array::scalar_fn::fns::select::Select::validity(&self, options: &Self::Options, expression: &vortex_array::expr::Expression) -> vortex_error::VortexResult> +impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::stats_expression::StatsExpression + +pub type vortex_array::scalar_fn::fns::stats_expression::StatsExpression::Options = vortex_array::aggregate_fn::AggregateFnRef + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::Arity + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::child_name(&self, _options: &Self::Options, _child_idx: usize) -> vortex_array::scalar_fn::ChildName + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::execute(&self, agg: &Self::Options, _args: &dyn vortex_array::scalar_fn::ExecutionArgs, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::fmt_sql(&self, agg: &Self::Options, _expr: &vortex_array::expr::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::is_fallible(&self, _options: &Self::Options) -> bool + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::is_null_sensitive(&self, _options: &Self::Options) -> bool + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::reduce(&self, options: &Self::Options, node: &dyn vortex_array::scalar_fn::ReduceNode, ctx: &dyn vortex_array::scalar_fn::ReduceCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::return_dtype(&self, agg: &Self::Options, _args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::simplify(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, ctx: &dyn vortex_array::scalar_fn::SimplifyCtx) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::simplify_untyped(&self, options: &Self::Options, expr: &vortex_array::expr::Expression) -> vortex_error::VortexResult> + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::stat_expression(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, stat: vortex_array::expr::stats::Stat, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::stat_falsification(&self, options: &Self::Options, expr: &vortex_array::expr::Expression, catalog: &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option + +pub fn vortex_array::scalar_fn::fns::stats_expression::StatsExpression::validity(&self, options: &Self::Options, expression: &vortex_array::expr::Expression) -> vortex_error::VortexResult> + impl vortex_array::scalar_fn::ScalarFnVTable for vortex_array::scalar_fn::fns::zip::Zip pub type vortex_array::scalar_fn::fns::zip::Zip::Options = vortex_array::scalar_fn::EmptyOptions diff --git a/vortex-array/src/aggregate_fn/fns/mod.rs b/vortex-array/src/aggregate_fn/fns/mod.rs index 38d5340cd1f..b4d3a2f4b2b 100644 --- a/vortex-array/src/aggregate_fn/fns/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mod.rs @@ -8,4 +8,5 @@ pub mod is_sorted; pub mod last; pub mod min_max; pub mod nan_count; +pub mod row_count; pub mod sum; diff --git a/vortex-array/src/aggregate_fn/fns/row_count/mod.rs b/vortex-array/src/aggregate_fn/fns/row_count/mod.rs new file mode 100644 index 00000000000..9fff5edf67c --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/row_count/mod.rs @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexExpect; +use vortex_error::VortexResult; + +use crate::ArrayRef; +use crate::Columnar; +use crate::ExecutionCtx; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::EmptyOptions; +use crate::dtype::DType; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::scalar::Scalar; + +/// Count the total number of elements in an array, including nulls. +/// +/// Applies to all types. Returns a `u64` count. +/// The identity value is zero. +/// +/// Unlike [`Count`][crate::aggregate_fn::fns::count::Count], this aggregate includes +/// null elements in the total. It is primarily used as a marker inside pruning +/// predicates that need to refer to the scope row count. +#[derive(Clone, Debug)] +pub struct RowCount; + +impl AggregateFnVTable for RowCount { + type Options = EmptyOptions; + type Partial = u64; + + fn id(&self) -> AggregateFnId { + AggregateFnId::new("vortex.row_count") + } + + fn serialize(&self, _options: &Self::Options) -> VortexResult>> { + unimplemented!("RowCount is not yet serializable"); + } + + fn return_dtype(&self, _options: &Self::Options, _input_dtype: &DType) -> Option { + Some(DType::Primitive(PType::U64, Nullability::NonNullable)) + } + + fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { + self.return_dtype(options, input_dtype) + } + + fn empty_partial( + &self, + _options: &Self::Options, + _input_dtype: &DType, + ) -> VortexResult { + Ok(0u64) + } + + fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { + let val = other + .as_primitive() + .typed_value::() + .vortex_expect("row_count partial should not be null"); + *partial += val; + Ok(()) + } + + fn to_scalar(&self, partial: &Self::Partial) -> VortexResult { + Ok(Scalar::primitive(*partial, Nullability::NonNullable)) + } + + fn reset(&self, partial: &mut Self::Partial) { + *partial = 0; + } + + #[inline] + fn is_saturated(&self, _partial: &Self::Partial) -> bool { + false + } + + fn try_accumulate( + &self, + state: &mut Self::Partial, + batch: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + *state += batch.len() as u64; + Ok(true) + } + + fn accumulate( + &self, + _partial: &mut Self::Partial, + _batch: &Columnar, + _ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + unreachable!("RowCount::try_accumulate handles all arrays") + } + + fn finalize(&self, partials: ArrayRef) -> VortexResult { + Ok(partials) + } + + fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult { + self.to_scalar(partial) + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::Accumulator; + use crate::aggregate_fn::DynAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::fns::row_count::RowCount; + use crate::arrays::PrimitiveArray; + + #[test] + fn row_count_all_valid() -> VortexResult<()> { + let array = buffer![1i32, 2, 3, 4, 5].into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut acc = Accumulator::try_new(RowCount, EmptyOptions, array.dtype().clone())?; + acc.accumulate(&array, &mut ctx)?; + let result = acc.finish()?; + assert_eq!(result.as_primitive().typed_value::(), Some(5)); + Ok(()) + } + + #[test] + fn row_count_includes_nulls() -> VortexResult<()> { + let array = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)]) + .into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut acc = Accumulator::try_new(RowCount, EmptyOptions, array.dtype().clone())?; + acc.accumulate(&array, &mut ctx)?; + let result = acc.finish()?; + assert_eq!(result.as_primitive().typed_value::(), Some(5)); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index 886b1ea9c2d..9fe5ba78209 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -18,6 +18,7 @@ use crate::aggregate_fn::fns::is_sorted::IsSorted; use crate::aggregate_fn::fns::last::Last; use crate::aggregate_fn::fns::min_max::MinMax; use crate::aggregate_fn::fns::nan_count::NanCount; +use crate::aggregate_fn::fns::row_count::RowCount; use crate::aggregate_fn::fns::sum::Sum; use crate::aggregate_fn::kernels::DynAggregateKernel; use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; @@ -59,6 +60,7 @@ impl Default for AggregateFnSession { this.register(Last); this.register(MinMax); this.register(NanCount); + this.register(RowCount); this.register(Sum); // Register the built-in aggregate kernels. diff --git a/vortex-array/src/expr/exprs.rs b/vortex-array/src/expr/exprs.rs index 13821669034..f4ffe73da9a 100644 --- a/vortex-array/src/expr/exprs.rs +++ b/vortex-array/src/expr/exprs.rs @@ -9,6 +9,9 @@ use vortex_error::VortexExpect; use vortex_error::vortex_panic; use vortex_utils::iter::ReduceBalancedIterExt; +use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::AggregateFnVTableExt; +use crate::aggregate_fn::fns::row_count::RowCount; use crate::dtype::DType; use crate::dtype::FieldName; use crate::dtype::FieldNames; @@ -46,6 +49,7 @@ use crate::scalar_fn::fns::pack::PackOptions; use crate::scalar_fn::fns::root::Root; use crate::scalar_fn::fns::select::FieldSelection; use crate::scalar_fn::fns::select::Select; +use crate::scalar_fn::fns::stats_expression::StatsExpression; use crate::scalar_fn::fns::zip::Zip; // ---- Root ---- @@ -701,3 +705,21 @@ pub fn dynamic( pub fn list_contains(list: Expression, value: Expression) -> Expression { ListContains.new_expr(EmptyOptions, [list, value]) } + +// ---- StatsExpression ---- + +/// Creates a placeholder [`StatsExpression`] wrapping the given aggregate. +/// +/// The expression must be substituted before evaluation by the layer that owns the +/// evaluation scope — see [`StatsExpression`] for details. +pub fn stats_expression(agg: AggregateFnRef) -> Expression { + StatsExpression.new_expr(agg, []) +} + +/// Creates a [`StatsExpression`] wrapping the [`RowCount`] aggregate. +/// +/// This is the canonical way to refer to the row count of the current evaluation scope inside +/// a pruning predicate. +pub fn row_count() -> Expression { + stats_expression(RowCount.bind(crate::aggregate_fn::EmptyOptions)) +} diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index bf985cad861..6d4a86dd17b 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -86,6 +86,10 @@ pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldNa /// cannot hold, and false if it cannot be determined from stats alone whether the positions can /// be pruned. /// +/// Row-count-aware pruning (for example `is_not_null(...)`) emits +/// [`row_count`][crate::expr::row_count] placeholders that the evaluation layer must substitute +/// before executing the returned expression. +/// /// If the falsification logic attempts to access an unknown stat, /// this function will return `None`. pub fn checked_pruning_expr( diff --git a/vortex-array/src/scalar_fn/fns/is_not_null.rs b/vortex-array/src/scalar_fn/fns/is_not_null.rs index f5449a20cad..e1a15d12130 100644 --- a/vortex-array/src/scalar_fn/fns/is_not_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_not_null.rs @@ -14,10 +14,8 @@ use crate::dtype::DType; use crate::dtype::Nullability; use crate::expr::Expression; use crate::expr::StatsCatalog; -use crate::expr::and; use crate::expr::eq; -use crate::expr::gt; -use crate::expr::lit; +use crate::expr::row_count; use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; @@ -106,20 +104,10 @@ impl ScalarFnVTable for IsNotNull { expr: &Expression, catalog: &dyn StatsCatalog, ) -> Option { - // is_not_null is falsified when ALL values are null, i.e. null_count == len. - // Since there is no len stat in the zone map, we approximate using IsConstant: - // if the zone is constant and has any nulls, then all values must be null. - // - // TODO(#7187): Add a len stat to enable the more general falsification: - // null_count == len => is_not_null is all false. - let null_count_expr = expr.child(0).stat_expression(Stat::NullCount, catalog)?; - let is_constant_expr = expr.child(0).stat_expression(Stat::IsConstant, catalog)?; - // If the zone is constant (is_constant == true) and has nulls (null_count > 0), - // then all values must be null, so is_not_null is all false. - Some(and( - eq(is_constant_expr, lit(true)), - gt(null_count_expr, lit(0u64)), - )) + // is_not_null is falsified when ALL values are null, i.e. null_count == row_count. + let child = expr.child(0); + let null_count_expr = child.stat_expression(Stat::NullCount, catalog)?; + Some(eq(null_count_expr, row_count())) } } @@ -267,38 +255,27 @@ mod tests { use crate::dtype::Field; use crate::dtype::FieldPath; use crate::dtype::FieldPathSet; - use crate::expr::and; use crate::expr::col; use crate::expr::eq; - use crate::expr::gt; - use crate::expr::lit; use crate::expr::pruning::checked_pruning_expr; + use crate::expr::row_count; use crate::expr::stats::Stat; let expr = is_not_null(col("a")); let (pruning_expr, st) = checked_pruning_expr( &expr, - &FieldPathSet::from_iter([ - FieldPath::from_iter([Field::Name("a".into()), Field::Name("null_count".into())]), - FieldPath::from_iter([Field::Name("a".into()), Field::Name("is_constant".into())]), - ]), + &FieldPathSet::from_iter([FieldPath::from_iter([ + Field::Name("a".into()), + Field::Name("null_count".into()), + ])]), ) .unwrap(); - assert_eq!( - &pruning_expr, - &and( - eq(col("a_is_constant"), lit(true)), - gt(col("a_null_count"), lit(0u64)), - ) - ); + assert_eq!(&pruning_expr, &eq(col("a_null_count"), row_count())); assert_eq!( st.map(), - &HashMap::from_iter([( - FieldPath::from_name("a"), - HashSet::from([Stat::NullCount, Stat::IsConstant]) - )]) + &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))]) ); } } diff --git a/vortex-array/src/scalar_fn/fns/mod.rs b/vortex-array/src/scalar_fn/fns/mod.rs index 8fa1b66532d..271f54b9988 100644 --- a/vortex-array/src/scalar_fn/fns/mod.rs +++ b/vortex-array/src/scalar_fn/fns/mod.rs @@ -20,4 +20,5 @@ pub mod operators; pub mod pack; pub mod root; pub mod select; +pub mod stats_expression; pub mod zip; diff --git a/vortex-array/src/scalar_fn/fns/stats_expression.rs b/vortex-array/src/scalar_fn/fns/stats_expression.rs new file mode 100644 index 00000000000..519c44630f7 --- /dev/null +++ b/vortex-array/src/scalar_fn/fns/stats_expression.rs @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt::Formatter; +use std::marker::PhantomData; + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; + +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::aggregate_fn::AggregateFnRef; +use crate::aggregate_fn::AggregateFnVTable; +use crate::arrays::ScalarFn; +use crate::arrays::scalar_fn::ExactScalarFn; +use crate::arrays::scalar_fn::ScalarFnArrayExt; +use crate::dtype::DType; +use crate::expr::Expression; +use crate::matcher::Matcher; +use crate::scalar_fn::Arity; +use crate::scalar_fn::ChildName; +use crate::scalar_fn::ExecutionArgs; +use crate::scalar_fn::ScalarFnId; +use crate::scalar_fn::ScalarFnVTable; + +/// A placeholder expression wrapping an [`AggregateFnRef`] that must be substituted with a +/// concrete expression before the containing tree is evaluated. +/// +/// `StatsExpression` nodes are produced while building a pruning predicate to refer to a +/// scope-level statistic that is not derivable from a column of a zone map (for example, the +/// row count of the current scope). The layer that owns the evaluation scope is responsible +/// for walking the expression and substituting each `StatsExpression` with an expression that +/// produces the appropriate array — a literal, a reference to a column it has augmented onto +/// its struct, or any other valid expression. +/// +/// Calling [`ScalarFnVTable::execute`] directly returns an error. +#[derive(Clone)] +pub struct StatsExpression; + +impl ScalarFnVTable for StatsExpression { + type Options = AggregateFnRef; + + fn id(&self) -> ScalarFnId { + ScalarFnId::from("vortex.stats_expression") + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(0) + } + + fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName { + unreachable!("StatsExpression has arity 0") + } + + fn fmt_sql( + &self, + agg: &Self::Options, + _expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "stats_expression({})", agg.id()) + } + + fn return_dtype(&self, agg: &Self::Options, _args: &[DType]) -> VortexResult { + // StatsExpression has no children, so we cannot derive a scope dtype. Aggregates whose + // return type is input-independent (e.g. RowCount) will still produce a valid dtype. + agg.return_dtype(&DType::Null).ok_or_else(|| { + vortex_err!( + "StatsExpression wraps aggregate {} whose return type depends on scope dtype", + agg.id() + ) + }) + } + + fn execute( + &self, + agg: &Self::Options, + _args: &dyn ExecutionArgs, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + vortex_bail!( + "StatsExpression({}) must be substituted before evaluation", + agg.id() + ) + } + + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { + false + } + + fn is_fallible(&self, _options: &Self::Options) -> bool { + false + } +} + +/// Matcher for a [`ScalarFnArray`] whose scalar function is a [`StatsExpression`] wrapping the +/// aggregate `V`. +/// +/// [`ScalarFnArray`]: crate::arrays::ScalarFnArray +#[derive(Debug)] +pub struct StatsExpressionOf(PhantomData); + +impl Matcher for StatsExpressionOf { + type Match<'a> = &'a ArrayRef; + + fn matches(array: &ArrayRef) -> bool { + array + .as_opt::>() + .is_some_and(|view| view.options.is::()) + } + + fn try_match(array: &ArrayRef) -> Option> { + Self::matches(array).then_some(array) + } +} + +/// Returns `true` if the array tree rooted at `array` contains a [`ScalarFnArray`] matching +/// [`StatsExpressionOf`]. +/// +/// Traversal stops at the first non-[`ScalarFn`] array because those are evaluation leaves — any +/// [`StatsExpression`] placeholders live in the lazy [`ScalarFnArray`] tree produced by +/// [`ArrayRef::apply`][crate::ArrayRef::apply]. +/// +/// [`ScalarFnArray`]: crate::arrays::ScalarFnArray +pub fn contains_stats_fn_array(array: &ArrayRef) -> bool { + if array.is::>() { + return true; + } + match array.as_opt::() { + Some(view) => view.iter_children().any(contains_stats_fn_array::), + None => false, + } +} + +/// Walk the array tree rooted at `array` and replace every [`ScalarFnArray`] matching +/// [`StatsExpressionOf`] with `replacement`. +/// +/// [`ScalarFnArray`] ancestors of a replaced node are rewritten in place via slot take/put so +/// that unaffected children are not cloned. Non-[`ScalarFn`] arrays are returned unchanged — +/// they are evaluation leaves. +/// +/// [`ScalarFnArray`]: crate::arrays::ScalarFnArray +pub fn substitute_stats_fn_array( + array: ArrayRef, + replacement: &ArrayRef, +) -> VortexResult { + if array.is::>() { + vortex_ensure!( + replacement.len() == array.len(), + "StatsExpression replacement length {} does not match scope length {}", + replacement.len(), + array.len(), + ); + vortex_ensure!( + replacement.dtype() == array.dtype(), + "StatsExpression replacement dtype {} does not match scope dtype {}", + replacement.dtype(), + array.dtype(), + ); + return Ok(replacement.clone()); + } + + if !array.is::() { + return Ok(array); + } + + let nchildren = array.nchildren(); + let mut array = array; + for slot_idx in 0..nchildren { + // SAFETY: `substitute_stats_fn_array` always returns an array with the same dtype and + // length as its input — `StatsExpression` placeholders are replaced with a checked + // replacement (same dtype and length), and `ScalarFn` recursion preserves both by + // operating on each slot in place. + let (taken, child) = unsafe { array.take_slot_unchecked(slot_idx)? }; + let new_child = substitute_stats_fn_array::(child, replacement)?; + array = unsafe { taken.put_slot_unchecked(slot_idx, new_child)? }; + } + Ok(array) +} + +#[cfg(test)] +mod tests { + use crate::aggregate_fn::fns::row_count::RowCount; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::expr::row_count; + + #[test] + fn row_count_helper_is_stats_expression() { + let expr = row_count(); + let agg = expr.as_::(); + assert!(agg.is::()); + assert_eq!( + expr.return_dtype(&DType::Primitive(PType::I32, Nullability::Nullable)) + .unwrap(), + DType::Primitive(PType::U64, Nullability::NonNullable), + ); + } +} diff --git a/vortex-file/src/file.rs b/vortex-file/src/file.rs index 7257676b05c..fd2e57fcfa3 100644 --- a/vortex-file/src/file.rs +++ b/vortex-file/src/file.rs @@ -12,7 +12,10 @@ use std::sync::Arc; use itertools::Itertools; use vortex_array::ArrayRef; use vortex_array::Columnar; +use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::fns::row_count::RowCount; +use vortex_array::arrays::ConstantArray; use vortex_array::dtype::DType; use vortex_array::dtype::Field; use vortex_array::dtype::FieldMask; @@ -20,6 +23,7 @@ use vortex_array::dtype::FieldPath; use vortex_array::dtype::FieldPathSet; use vortex_array::expr::Expression; use vortex_array::expr::pruning::checked_pruning_expr; +use vortex_array::scalar_fn::fns::stats_expression::substitute_stats_fn_array; use vortex_error::VortexResult; use vortex_layout::LayoutReader; use vortex_layout::scan::layout::LayoutReaderDataSource; @@ -162,16 +166,18 @@ impl VortexFile { return Ok(false); }; + // Apply the predicate, then substitute any row_count placeholders in the resulting array + // tree with a ConstantArray carrying the file-level row count. + let applied = file_stats.apply(&predicate)?; + let row_count_replacement = + ConstantArray::new(self.footer.row_count(), applied.len()).into_array(); + let applied = substitute_stats_fn_array::(applied, &row_count_replacement)?; + let mut ctx = self.session.create_execution_ctx(); - Ok( - match file_stats - .apply(&predicate)? - .execute::(&mut ctx)? - { - Columnar::Constant(s) => s.scalar().as_bool().value() == Some(true), - Columnar::Canonical(_) => false, - }, - ) + Ok(match applied.execute::(&mut ctx)? { + Columnar::Constant(s) => s.scalar().as_bool().value() == Some(true), + Columnar::Canonical(_) => false, + }) } pub fn splits(&self) -> VortexResult>> { diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index e5b504d1af4..0171a2ae684 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -15,6 +15,8 @@ use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::MaskFuture; use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::fns::row_count::RowCount; +use vortex_array::arrays::ConstantArray; use vortex_array::arrays::NullArray; use vortex_array::dtype::DType; use vortex_array::dtype::FieldMask; @@ -26,6 +28,7 @@ use vortex_array::expr::lit; use vortex_array::expr::stats::Stat; use vortex_array::scalar::Scalar; use vortex_array::scalar_fn::fns::literal::Literal; +use vortex_array::scalar_fn::fns::stats_expression::substitute_stats_fn_array; use vortex_error::VortexResult; use vortex_layout::ArrayFuture; use vortex_layout::LayoutReader; @@ -84,9 +87,9 @@ impl FileStatsLayoutReader { return Ok(false); }; - // Given how we implemented the StatsCatalog, we know the expression must be all literals. - // We can therefore optimize with a null scope since there are no field references that - // need to be resolved. + // Given how we implemented the StatsCatalog, we know the expression must be all literals + // or row_count placeholders. We can therefore optimize with a null scope since there are + // no field references that need to be resolved. let simplified = pruning_expr.optimize_recursive(&DType::Null)?; if let Some(result) = simplified.as_opt::() { // Can prune if the result is non-nullable and true @@ -94,8 +97,12 @@ impl FileStatsLayoutReader { } // Sometimes expressions don't implement constant folding to literals... In this case, - // we just execute the expression over a null array. + // we apply the expression over a null array and substitute any row_count placeholders + // in the resulting array tree with the file's row count. let pruning = NullArray::new(1).into_array().apply(&pruning_expr)?; + let row_count_replacement = + ConstantArray::new(self.child.row_count(), pruning.len()).into_array(); + let pruning = substitute_stats_fn_array::(pruning, &row_count_replacement)?; let mut ctx = self.session.create_execution_ctx(); let result = pruning @@ -126,7 +133,8 @@ impl StatsCatalog for FileStatsLayoutReader { let stat_value = field_stats.get(stat)?.as_exact()?; let field_dtype = self.struct_fields.field_by_index(field_idx)?; - let stat_scalar = Scalar::try_new(field_dtype, Some(stat_value)).ok()?; + let stat_dtype = stat.dtype(&field_dtype)?; + let stat_scalar = Scalar::try_new(stat_dtype, Some(stat_value)).ok()?; Some(lit(stat_scalar)) } @@ -209,12 +217,14 @@ mod tests { use vortex_array::ArrayContext; use vortex_array::IntoArray as _; + use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; use vortex_array::expr::get_item; use vortex_array::expr::gt; + use vortex_array::expr::is_not_null; use vortex_array::expr::lit; use vortex_array::expr::root; use vortex_array::expr::stats::Precision; @@ -259,6 +269,18 @@ mod tests { ) } + fn test_file_null_count_stats(null_count: u64) -> FileStatistics { + let mut stats = StatsSet::default(); + stats.set( + Stat::NullCount, + Precision::exact(ScalarValue::from(null_count)), + ); + FileStatistics::new( + Arc::from([stats]), + Arc::from([DType::Primitive(PType::I32, Nullability::Nullable)]), + ) + } + #[test] fn pruning_when_filter_out_of_range() -> VortexResult<()> { block_on(|handle| async { @@ -337,4 +359,47 @@ mod tests { Ok(()) }) } + + #[test] + fn pruning_is_not_null_when_file_is_all_null() -> VortexResult<()> { + block_on(|handle| async { + let session = SESSION.clone().with_handle(handle); + let ctx = ArrayContext::empty(); + let segments = Arc::new(TestSegments::default()); + let (ptr, eof) = SequenceId::root().split(); + let struct_array = StructArray::from_fields( + [( + "col", + PrimitiveArray::from_option_iter([None::, None, None, None, None]) + .into_array(), + )] + .as_slice(), + )?; + let strategy = TableStrategy::new( + Arc::new(FlatLayoutStrategy::default()), + Arc::new(FlatLayoutStrategy::default()), + ); + let layout = strategy + .write_stream( + ctx, + Arc::::clone(&segments), + struct_array.into_array().to_array_stream().sequenced(ptr), + eof, + &session, + ) + .await?; + + let child = layout.new_reader("".into(), segments, &SESSION)?; + + let reader = + FileStatsLayoutReader::new(child, test_file_null_count_stats(5), SESSION.clone()); + + let expr = is_not_null(get_item("col", root())); + let mask = Mask::new_true(5); + let result = reader.pruning_evaluation(&(0..5), &expr, mask)?.await?; + assert_eq!(result, Mask::new_false(5)); + + Ok(()) + }) + } } diff --git a/vortex-layout/Cargo.toml b/vortex-layout/Cargo.toml index 30a0953444a..61b1253ef43 100644 --- a/vortex-layout/Cargo.toml +++ b/vortex-layout/Cargo.toml @@ -47,6 +47,7 @@ vortex-flatbuffers = { workspace = true, features = ["layout"] } vortex-io = { workspace = true } vortex-mask = { workspace = true } vortex-metrics = { workspace = true } +vortex-runend = { workspace = true } vortex-scan = { workspace = true } vortex-sequence = { workspace = true } vortex-session = { workspace = true } diff --git a/vortex-layout/public-api.lock b/vortex-layout/public-api.lock index 6b029d9c50d..9adc658296d 100644 --- a/vortex-layout/public-api.lock +++ b/vortex-layout/public-api.lock @@ -818,7 +818,7 @@ pub unsafe fn vortex_layout::layouts::zoned::zone_map::ZoneMap::new_unchecked(ar pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::present_stats(&self) -> &alloc::sync::Arc<[vortex_array::expr::stats::Stat]> -pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::prune(&self, predicate: &vortex_array::expr::expression::Expression, session: &vortex_session::VortexSession) -> vortex_error::VortexResult +pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::prune(&self, predicate: &vortex_array::expr::expression::Expression, zone_len: u64, row_count: u64, session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::to_stats_set(&self, stats: &[vortex_array::expr::stats::Stat], ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult diff --git a/vortex-layout/src/layouts/zoned/reader.rs b/vortex-layout/src/layouts/zoned/reader.rs index 10f2c1ff598..2a82dbb8b60 100644 --- a/vortex-layout/src/layouts/zoned/reader.rs +++ b/vortex-layout/src/layouts/zoned/reader.rs @@ -104,12 +104,12 @@ impl ZonedReader { .entry(expr.clone()) .or_default() .get_or_init(move || { - let field_path_set = FieldPathSet::from_iter( - self.layout - .present_stats - .iter() - .map(|s| FieldPath::from_name(s.name())), - ); + let field_path_set = self + .layout + .present_stats + .iter() + .map(|s| FieldPath::from_name(s.name())) + .collect::(); checked_pruning_expr(&expr, &field_path_set).map(|(expr, _)| expr) }) .clone() @@ -171,12 +171,15 @@ impl ZonedReader { let zone_map = self.zone_map(); let dynamic_updates = DynamicExprUpdates::new(&expr); let session = self.session.clone(); + let zone_len = self.layout.zone_len as u64; + let row_count = self.layout.row_count(); Some( async move { let zone_map = zone_map.await?; - let initial_mask = - zone_map.prune(&predicate, &session).map_err(|err| { + let initial_mask = zone_map + .prune(&predicate, zone_len, row_count, &session) + .map_err(|err| { err.with_context(format!( "While evaluating pruning predicate {} (derived from {})", predicate, expr @@ -188,6 +191,8 @@ impl ZonedReader { dynamic_updates, latest_result: RwLock::new((0, initial_mask)), session, + zone_len, + row_count, })) } .boxed() @@ -345,6 +350,8 @@ struct PruningResult { dynamic_updates: Option, latest_result: RwLock<(u64, Mask)>, session: VortexSession, + zone_len: u64, + row_count: u64, } impl PruningResult { @@ -385,7 +392,12 @@ impl PruningResult { let next_mask = self .zone_map - .prune(&self.predicate, &self.session) + .prune( + &self.predicate, + self.zone_len, + self.row_count, + &self.session, + ) .map_err(|err| { err.with_context(format!( "While evaluating pruning predicate {}", diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index 00df7f72dd4..f7feba7be67 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -8,7 +8,10 @@ use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::fns::row_count::RowCount; use vortex_array::aggregate_fn::fns::sum::sum; +use vortex_array::arrays::ConstantArray; +use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; use vortex_array::arrays::struct_::StructArrayExt; use vortex_array::dtype::DType; @@ -19,12 +22,16 @@ use vortex_array::expr::Expression; use vortex_array::expr::stats::Precision; use vortex_array::expr::stats::Stat; use vortex_array::expr::stats::StatsProvider; +use vortex_array::scalar_fn::fns::stats_expression::contains_stats_fn_array; +use vortex_array::scalar_fn::fns::stats_expression::substitute_stats_fn_array; use vortex_array::stats::StatsSet; use vortex_array::validity::Validity; +use vortex_buffer::buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_mask::Mask; +use vortex_runend::RunEnd; use vortex_session::VortexSession; use crate::layouts::zoned::builder::MAX_IS_TRUNCATED; @@ -160,18 +167,72 @@ impl ZoneMap { /// be pruned. /// /// The expression provided should be the result of converting an existing `VortexExpr` via - /// [`checked_pruning_expr`][vortex_array::expr::pruning::checked_pruning_expr] into a prunable - /// expression that can be evaluated on a zone map. + /// [`checked_pruning_expr`] into a prunable expression that can be evaluated on a zone map. + /// + /// Before evaluation, any + /// [`row_count`][vortex_array::expr::row_count] placeholders left in the lazy + /// [`ScalarFnArray`][vortex_array::arrays::ScalarFnArray] tree produced by + /// [`ArrayRef::apply`][vortex_array::ArrayRef::apply] are replaced with the per-zone row + /// count array — a [`ConstantArray`] when every zone has the same length, or a run-end + /// encoded array when the final zone is short. /// /// All zones where the predicate evaluates to `true` can be skipped entirely. - pub fn prune(&self, predicate: &Expression, session: &VortexSession) -> VortexResult { + /// + /// [`checked_pruning_expr`]: vortex_array::expr::pruning::checked_pruning_expr + pub fn prune( + &self, + predicate: &Expression, + zone_len: u64, + row_count: u64, + session: &VortexSession, + ) -> VortexResult { let mut ctx = session.create_execution_ctx(); - self.array - .clone() - .into_array() - .apply(predicate)? - .execute::(&mut ctx) + let num_zones = self.array.len(); + + let applied = self.array.clone().into_array().apply(predicate)?; + + if num_zones == 0 || !contains_stats_fn_array::(&applied) { + return applied.execute::(&mut ctx); + } + + let row_count_array = row_count_array(zone_len, row_count, num_zones, &mut ctx)?; + let substituted = substitute_stats_fn_array::(applied, &row_count_array)?; + substituted.execute::(&mut ctx) + } +} + +/// Build the per-zone row count array for a zone map of `num_zones` zones backed by a data array +/// of `row_count` total rows and nominal `zone_len` per zone. +/// +/// When every zone has the same length the result is a [`ConstantArray`]. When the final zone is +/// short, the result is a run-end encoded array whose trailing run carries the short last-zone +/// length. +fn row_count_array( + zone_len: u64, + row_count: u64, + num_zones: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let last_zone_len = row_count - zone_len.saturating_mul((num_zones as u64) - 1); + if num_zones == 1 || last_zone_len == zone_len { + return Ok(ConstantArray::new(last_zone_len, num_zones).into_array()); + } + + let ends = unsafe { + PrimitiveArray::new_unchecked( + buffer![num_zones as u64 - 1, num_zones as u64], + Validity::NonNullable, + ) + } + .into_array(); + let values = unsafe { + PrimitiveArray::new_unchecked(buffer![zone_len, last_zone_len], Validity::NonNullable) } + .into_array(); + + // SAFETY: `ends` are strictly increasing, terminate at `num_zones`, and align one-to-one + // with the non-null run values. + Ok(unsafe { RunEnd::new_unchecked(ends, values, 0, num_zones, ctx) }.into_array()) } // TODO(ngates): we should make it such that the zone map stores a mirror of the DType @@ -292,6 +353,7 @@ mod tests { use vortex_array::dtype::PType; use vortex_array::expr::gt; use vortex_array::expr::gt_eq; + use vortex_array::expr::is_not_null; use vortex_array::expr::lit; use vortex_array::expr::lt; use vortex_array::expr::pruning::checked_pruning_expr; @@ -445,7 +507,7 @@ mod tests { // => A.max < 6 let expr = gt_eq(root(), lit(6i32)); let (pruning_expr, _) = checked_pruning_expr(&expr, &stats).unwrap(); - let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); + let mask = zone_map.prune(&pruning_expr, 1, 3, &SESSION).unwrap(); assert_arrays_eq!( mask.into_array(), BoolArray::from_iter([true, false, false]) @@ -455,7 +517,7 @@ mod tests { // => A.max <= 5 let expr = gt(root(), lit(5i32)); let (pruning_expr, _) = checked_pruning_expr(&expr, &stats).unwrap(); - let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); + let mask = zone_map.prune(&pruning_expr, 1, 3, &SESSION).unwrap(); assert_arrays_eq!( mask.into_array(), BoolArray::from_iter([true, false, false]) @@ -465,7 +527,58 @@ mod tests { // => A.min >= 2 let expr = lt(root(), lit(2i32)); let (pruning_expr, _) = checked_pruning_expr(&expr, &stats).unwrap(); - let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); + let mask = zone_map.prune(&pruning_expr, 1, 3, &SESSION).unwrap(); assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true, true])); } + + #[test] + fn row_count_prunes_short_trailing_zone() { + let zone_map = ZoneMap::try_new( + PType::U64.into(), + StructArray::from_fields(&[( + "null_count", + PrimitiveArray::new(buffer![0u64, 0, 2], Validity::AllValid).into_array(), + )]) + .unwrap(), + Arc::new([Stat::NullCount]), + ) + .unwrap(); + + let available_stats = + FieldPathSet::from_iter([FieldPath::from_iter([Stat::NullCount.name().into()])]); + let expr = is_not_null(root()); + let (pruning_expr, _) = checked_pruning_expr(&expr, &available_stats).unwrap(); + + let mask = zone_map.prune(&pruning_expr, 4, 10, &SESSION).unwrap(); + assert_arrays_eq!( + mask.into_array(), + BoolArray::from_iter([false, false, true]) + ); + } + + #[test] + fn row_count_prunes_all_null_uniform_zones() { + let zone_map = ZoneMap::try_new( + PType::U64.into(), + StructArray::from_fields(&[( + "null_count", + PrimitiveArray::new(buffer![0u64, 4, 0], Validity::AllValid).into_array(), + )]) + .unwrap(), + Arc::new([Stat::NullCount]), + ) + .unwrap(); + + let available_stats = + FieldPathSet::from_iter([FieldPath::from_iter([Stat::NullCount.name().into()])]); + let expr = is_not_null(root()); + let (pruning_expr, _) = checked_pruning_expr(&expr, &available_stats).unwrap(); + + // All three zones have length 4 (total rows = 12). + let mask = zone_map.prune(&pruning_expr, 4, 12, &SESSION).unwrap(); + assert_arrays_eq!( + mask.into_array(), + BoolArray::from_iter([false, true, false]) + ); + } } From fca96facd2003c165d29a566bdfa3c9f9137afe4 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Thu, 23 Apr 2026 17:23:36 +0100 Subject: [PATCH 2/2] Update vortex-array/src/expr/exprs.rs Signed-off-by: Robert Kruszewski --- vortex-array/src/expr/exprs.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vortex-array/src/expr/exprs.rs b/vortex-array/src/expr/exprs.rs index f4ffe73da9a..232fb1809ea 100644 --- a/vortex-array/src/expr/exprs.rs +++ b/vortex-array/src/expr/exprs.rs @@ -721,5 +721,5 @@ pub fn stats_expression(agg: AggregateFnRef) -> Expression { /// This is the canonical way to refer to the row count of the current evaluation scope inside /// a pruning predicate. pub fn row_count() -> Expression { - stats_expression(RowCount.bind(crate::aggregate_fn::EmptyOptions)) + stats_expression(RowCount.bind(EmptyOptions)) }