From d8d171c52ff34cf83208adbe74cbb50170c437a0 Mon Sep 17 00:00:00 2001 From: Sergio Esteves Date: Mon, 16 Mar 2026 17:16:58 +0000 Subject: [PATCH 1/3] feat: add approx_top_k aggregate function Add a new approx_top_k(expression, k) aggregate function that returns the approximate top-k most frequent values with their estimated counts, using the Filtered Space-Saving algorithm. The implementation uses a capacity multiplier of 3 (matching ClickHouse's default) and includes an alpha map for improved accuracy by filtering low-frequency noise before it enters the main summary. Return type is List(Struct({value: T, count: UInt64})) ordered by count descending, where T matches the input column type. Closes #20967 --- .../tests/dataframe/dataframe_functions.rs | 26 +- .../functions-aggregate/src/approx_top_k.rs | 1421 +++++++++++++++++ datafusion/functions-aggregate/src/lib.rs | 3 + .../tests/cases/roundtrip_logical_plan.rs | 7 +- .../sqllogictest/test_files/approx_top_k.slt | 87 + docs/source/user-guide/expressions.md | 1 + .../user-guide/sql/aggregate_functions.md | 25 + 7 files changed, 1566 insertions(+), 4 deletions(-) create mode 100644 datafusion/functions-aggregate/src/approx_top_k.rs create mode 100644 datafusion/sqllogictest/test_files/approx_top_k.slt diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 2ada0411f4f8c..bc83670664c6a 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -32,7 +32,9 @@ use datafusion_common::test_util::batches_to_string; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::Alias; use datafusion_expr::{ExprSchemable, LogicalPlanBuilder, table_scan}; -use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont}; +use datafusion_functions_aggregate::expr_fn::{ + approx_median, approx_percentile_cont, approx_top_k, +}; use 
datafusion_functions_nested::map::map; use insta::assert_snapshot; @@ -409,6 +411,28 @@ async fn test_fn_approx_median() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_fn_approx_top_k() -> Result<()> { + // Column b has values [1, 10, 10, 100] -- 10 appears twice, others once. + // Use k=1 to avoid non-deterministic ordering among tied items. + let expr = approx_top_k(vec![col("b"), lit(1)]); + + let df = create_test_table().await?; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_snapshot!( + batches_to_string(&batches), + @r" + +-------------------------------+ + | approx_top_k(test.b,Int32(1)) | + +-------------------------------+ + | [{value: 10, count: 2}] | + +-------------------------------+ + "); + + Ok(()) +} + #[tokio::test] async fn test_fn_approx_percentile_cont() -> Result<()> { let expr = approx_percentile_cont(col("b").sort(true, false), lit(0.5), None); diff --git a/datafusion/functions-aggregate/src/approx_top_k.rs b/datafusion/functions-aggregate/src/approx_top_k.rs new file mode 100644 index 0000000000000..fdca78235ec34 --- /dev/null +++ b/datafusion/functions-aggregate/src/approx_top_k.rs @@ -0,0 +1,1421 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Approximate top-k aggregate function using the Filtered Space-Saving algorithm. +//! +//! This implements a distributed-friendly approximate top-k aggregation using +//! the Filtered Space-Saving algorithm. The algorithm maintains a fixed-size summary +//! of counters plus an alpha map (filter) that remembers evicted items' frequencies. +//! +//! Usage: `approx_top_k(column, k)` +//! - `column`: The column to find the most frequent values from +//! - `k`: The number of top elements to track (required, literal integer) +//! +//! Returns: `List<Struct { value: T, count: UInt64 }>` ordered by count descending. +//! +//! Algorithm references: +//! - Filtered Space-Saving: Homem, Carvalho. "Finding top-k elements in data streams" (Information Sciences, 2010) +//! - Parallel Space Saving: Cafaro, Pulimeno, Tempesta. "A parallel space saving algorithm for frequent items" (2016) +//! - Space-Saving: Metwally, Agrawal, El Abbadi. "Efficient Computation of Frequent +//! and Top-k Elements in Data Streams" (ICDT 2005) + +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, BinaryArray, Date32Array, Date64Array, Float32Array, Float64Array, + Int8Array, Int16Array, Int32Array, Int64Array, LargeBinaryArray, LargeStringArray, + ListArray, StringArray, StructArray, TimestampMicrosecondArray, + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + UInt8Array, UInt16Array, UInt32Array, UInt64Array, +}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, +}; +use datafusion_macros::user_doc; + +make_udaf_expr_and_func!( + ApproxTopK, + approx_top_k, + "Returns the approximate most frequent (top-k) values and their counts using the Filtered Space-Saving algorithm.", + 
approx_top_k_udaf +); + +// --------------------------------------------------------------------------- +// Algorithm constants +// --------------------------------------------------------------------------- + +/// Suggested constant from the paper "Finding top-k elements in data streams", +/// chap 6, equation (24). Determines the size of the alpha map relative to the capacity. +const ALPHA_MAP_ELEMENTS_PER_COUNTER: usize = 6; + +/// Limit the max alpha value to avoid overflow with merges or weighted additions. +const MAX_ALPHA_VALUE: u64 = u32::MAX as u64; + +/// Maximum allowed value for k in `approx_top_k(column, k)`. +const APPROX_TOP_K_MAX_K: usize = 10_000; + +/// Capacity multiplier for internal tracking (matches ClickHouse's default). +/// +/// We track more items internally than k to improve accuracy. +/// If user asks for top-5, we internally track top `5 * 3 = 15` items. +/// Memory impact: ~100 bytes per counter; target_capacity is 2x the tracked +/// counter count, so top-100 uses ~60 KB per accumulator. +const CAPACITY_MULTIPLIER: usize = 3; + +// --------------------------------------------------------------------------- +// SpaceSavingSummary (core algorithm) +// --------------------------------------------------------------------------- + +/// Counter entry in the Filtered Space-Saving summary. +/// +/// Each entry tracks an item, its estimated count, and the error bound. +/// The algorithm guarantees that the true count lies within `[count - error, count]`. +#[derive(Debug, Clone)] +struct Counter { + /// The serialized bytes representing the tracked item. + item: Vec, + /// FNV-1a hash of the item (cached to avoid recomputation). + hash: u64, + /// The estimated frequency count (may overestimate due to eviction handling). + count: u64, + /// The maximum possible overestimation (error bound). + error: u64, +} + +impl Counter { + /// Compare counters for sorting: higher `(count - error)` wins, + /// then higher `count` breaks ties. 
+ fn is_greater_than(&self, other: &Counter) -> bool { + let self_lb = self.count.saturating_sub(self.error); + let other_lb = other.count.saturating_sub(other.error); + (self_lb > other_lb) || (self_lb == other_lb && self.count > other.count) + } + + /// Ordering for top-k selection: highest-ranked counters sort first. + fn cmp_by_rank(&self, other: &Counter) -> std::cmp::Ordering { + if other.is_greater_than(self) { + std::cmp::Ordering::Greater + } else if self.is_greater_than(other) { + std::cmp::Ordering::Less + } else { + std::cmp::Ordering::Equal + } + } +} + +/// Filtered Space-Saving algorithm summary for approximate top-k / heavy hitters. +/// +/// Uses a hash map for O(1) counter lookups and maintains an alpha map (filter) +/// to remember evicted items' frequencies. +#[derive(Debug, Clone)] +struct SpaceSavingSummary { + counters: Vec, + counter_map: HashMap, usize>, + alpha_map: Vec, + requested_capacity: usize, + /// Internal target capacity to avoid frequent truncations. + /// Set to `max(64, requested_capacity * 2)`. + target_capacity: usize, +} + +impl SpaceSavingSummary { + fn compute_alpha_map_size(capacity: usize) -> usize { + (capacity * ALPHA_MAP_ELEMENTS_PER_COUNTER).next_power_of_two() + } + + /// FNV-1a hash for item bytes. 
+ fn hash_item(item: &[u8]) -> u64 { + let mut hash: u64 = 0xcbf29ce484222325; + for &byte in item { + hash ^= byte as u64; + hash = hash.wrapping_mul(0x100000001b3); + } + hash + } + + fn new(capacity: usize) -> Self { + Self { + counters: Vec::new(), + counter_map: HashMap::new(), + alpha_map: Vec::new(), + requested_capacity: 0, + target_capacity: 0, + } + .resized(capacity) + } + + fn resized(mut self, new_capacity: usize) -> Self { + if self.requested_capacity != new_capacity { + debug_assert!(self.counters.is_empty()); + let alpha_map_size = Self::compute_alpha_map_size(new_capacity); + self.alpha_map = vec![0u64; alpha_map_size]; + self.requested_capacity = new_capacity; + self.target_capacity = std::cmp::max(64, new_capacity * 2); + self.counters.reserve(self.target_capacity); + } + self + } + + fn is_empty(&self) -> bool { + self.counters.is_empty() + } + + #[cfg(test)] + fn len(&self) -> usize { + self.counters.len() + } + + #[cfg(test)] + fn capacity(&self) -> usize { + self.requested_capacity + } + + fn find_counter_mut(&mut self, item: &[u8]) -> Option<&mut Counter> { + self.counter_map + .get(item) + .copied() + .map(|idx| &mut self.counters[idx]) + } + + #[cfg(test)] + fn find_counter(&self, item: &[u8]) -> Option<&Counter> { + self.counter_map.get(item).map(|&idx| &self.counters[idx]) + } + + /// Add an item with increment 1. + fn add(&mut self, item: &[u8]) { + self.insert(item, 1, 0); + } + + /// Core insertion algorithm from Filtered Space-Saving. + fn insert(&mut self, item: &[u8], increment: u64, error: u64) { + let hash = Self::hash_item(item); + + // Fast path: item already tracked. + if let Some(counter) = self.find_counter_mut(item) { + counter.count += increment; + counter.error += error; + return; + } + + // Below capacity: add directly. + if self.counters.len() < self.requested_capacity { + self.push_counter(item.to_vec(), hash, increment, error); + return; + } + + // At capacity: use alpha map for historical frequency. 
+ let alpha_mask = self.alpha_map.len() - 1; + let alpha_idx = (hash as usize) & alpha_mask; + let alpha = self.alpha_map[alpha_idx]; + + self.push_counter(item.to_vec(), hash, alpha + increment, alpha + error); + } + + fn push_counter(&mut self, item: Vec, hash: u64, count: u64, error: u64) { + let idx = self.counters.len(); + self.counter_map.insert(item.clone(), idx); + self.counters.push(Counter { + item, + hash, + count, + error, + }); + self.truncate_if_needed(false); + } + + /// Truncate counters when `target_capacity` is reached, + /// updating the alpha map with evicted items' true counts. + fn truncate_if_needed(&mut self, force_rebuild: bool) { + let need_truncate = self.counters.len() >= self.target_capacity; + + if need_truncate { + let k = self.requested_capacity; + if k < self.counters.len() { + self.counters + .select_nth_unstable_by(k - 1, |a, b| a.cmp_by_rank(b)); + + let alpha_mask = self.alpha_map.len() - 1; + for counter in self.counters.drain(k..) { + let alpha_idx = (counter.hash as usize) & alpha_mask; + let true_count = counter.count.saturating_sub(counter.error); + self.alpha_map[alpha_idx] = std::cmp::min( + self.alpha_map[alpha_idx] + true_count, + MAX_ALPHA_VALUE, + ); + } + } + } + + if need_truncate || force_rebuild { + self.counter_map.clear(); + for (idx, counter) in self.counters.iter().enumerate() { + self.counter_map.insert(counter.item.clone(), idx); + } + } + } + + #[cfg(test)] + fn get(&self, item: &[u8]) -> Option<(u64, u64)> { + self.find_counter(item).map(|c| (c.count, c.error)) + } + + /// Get the top-k items sorted by (count - error) descending. 
+ fn top_k(&self, k: usize) -> Vec<(&[u8], u64, u64)> { + if k == 0 || self.counters.is_empty() { + return Vec::new(); + } + + let mut sorted: Vec<_> = self.counters.iter().collect(); + let return_size = std::cmp::min(sorted.len(), k); + + if return_size < sorted.len() { + sorted.select_nth_unstable_by(return_size - 1, |a, b| a.cmp_by_rank(b)); + sorted.truncate(return_size); + } + + sorted.sort_by(|a, b| a.cmp_by_rank(b)); + + sorted + .into_iter() + .map(|c| (c.item.as_slice(), c.count, c.error)) + .collect() + } + + /// Merge another summary into this one. + fn merge(&mut self, other: &SpaceSavingSummary) { + if other.is_empty() { + return; + } + + if self.is_empty() { + self.counters.clone_from(&other.counters); + self.counter_map.clone_from(&other.counter_map); + self.alpha_map.clone_from(&other.alpha_map); + self.requested_capacity = other.requested_capacity; + self.target_capacity = other.target_capacity; + return; + } + + for other_counter in &other.counters { + if let Some(idx) = self.counter_map.get(&other_counter.item).copied() { + self.counters[idx].count += other_counter.count; + self.counters[idx].error += other_counter.error; + } else { + self.counters.push(Counter { + item: other_counter.item.clone(), + hash: other_counter.hash, + count: other_counter.count, + error: other_counter.error, + }); + } + } + + // Merge alpha maps element-wise. Sizes should always match because the + // planner guarantees the same k (and thus the same capacity/alpha map size) + // across all partitions. If they differ due to a bug, we skip the merge + // which only degrades accuracy without affecting correctness. + if self.alpha_map.len() == other.alpha_map.len() { + for (i, &other_alpha) in other.alpha_map.iter().enumerate() { + self.alpha_map[i] = + std::cmp::min(self.alpha_map[i] + other_alpha, MAX_ALPHA_VALUE); + } + } + + self.truncate_if_needed(true); + } + + /// Serialize the summary to bytes. 
+ fn serialize(&mut self) -> Vec { + // Ensure counters are truncated and alpha map is up to date before + // serializing, in case to_bytes is called without a prior truncation. + self.truncate_if_needed(false); + + let counters_to_write: Vec<_> = { + let mut sorted: Vec<_> = self.counters.iter().collect(); + let return_size = std::cmp::min(sorted.len(), self.requested_capacity); + if return_size > 0 && return_size < sorted.len() { + // After select_nth_unstable_by, the top-k counters are in + // sorted[..return_size] but in arbitrary order. This is fine + // since deserialization doesn't depend on counter ordering. + sorted.select_nth_unstable_by(return_size - 1, |a, b| a.cmp_by_rank(b)); + } + sorted.truncate(return_size); + sorted + }; + + let mut bytes = Vec::new(); + bytes.extend_from_slice(&(self.requested_capacity as u64).to_le_bytes()); + bytes.extend_from_slice(&(counters_to_write.len() as u64).to_le_bytes()); + + for counter in counters_to_write { + bytes.extend_from_slice(&(counter.item.len() as u32).to_le_bytes()); + bytes.extend_from_slice(&counter.item); + bytes.extend_from_slice(&counter.count.to_le_bytes()); + bytes.extend_from_slice(&counter.error.to_le_bytes()); + } + + bytes.extend_from_slice(&(self.alpha_map.len() as u64).to_le_bytes()); + for &alpha in &self.alpha_map { + bytes.extend_from_slice(&alpha.to_le_bytes()); + } + + bytes + } + + /// Deserialize a summary from bytes. 
+ fn from_bytes(bytes: &[u8]) -> Result { + if bytes.len() < 16 { + return Err(datafusion_common::DataFusionError::Execution( + "Invalid Space-Saving summary bytes: too short".to_string(), + )); + } + + let requested_capacity = + u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize; + let num_counters = u64::from_le_bytes(bytes[8..16].try_into().unwrap()) as usize; + + let mut counters = Vec::with_capacity(num_counters); + let mut counter_map = HashMap::with_capacity(num_counters); + let mut offset = 16; + + for idx in 0..num_counters { + if offset + 4 > bytes.len() { + return Err(datafusion_common::DataFusionError::Execution( + "Invalid Space-Saving summary bytes: truncated item length" + .to_string(), + )); + } + let item_len = + u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap()) + as usize; + offset += 4; + + if offset + item_len + 16 > bytes.len() { + return Err(datafusion_common::DataFusionError::Execution( + "Invalid Space-Saving summary bytes: truncated counter".to_string(), + )); + } + + let item = bytes[offset..offset + item_len].to_vec(); + offset += item_len; + + let count = u64::from_le_bytes(bytes[offset..offset + 8].try_into().unwrap()); + offset += 8; + + let error = u64::from_le_bytes(bytes[offset..offset + 8].try_into().unwrap()); + offset += 8; + + let hash = Self::hash_item(&item); + counter_map.insert(item.clone(), idx); + counters.push(Counter { + item, + hash, + count, + error, + }); + } + + if offset + 8 > bytes.len() { + return Err(datafusion_common::DataFusionError::Execution( + "Invalid Space-Saving summary bytes: missing alpha map size".to_string(), + )); + } + let alpha_map_size = + u64::from_le_bytes(bytes[offset..offset + 8].try_into().unwrap()) as usize; + offset += 8; + + if offset + alpha_map_size.saturating_mul(8) > bytes.len() { + return Err(datafusion_common::DataFusionError::Execution( + "Invalid Space-Saving summary bytes: truncated alpha map".to_string(), + )); + } + + let mut alpha_map = 
Vec::with_capacity(alpha_map_size); + for _ in 0..alpha_map_size { + let alpha = u64::from_le_bytes(bytes[offset..offset + 8].try_into().unwrap()); + offset += 8; + alpha_map.push(alpha); + } + + let target_capacity = std::cmp::max(64, requested_capacity * 2); + + Ok(Self { + counters, + counter_map, + alpha_map, + requested_capacity, + target_capacity, + }) + } + + /// Approximate size in bytes of this summary. + fn size(&self) -> usize { + // Heap bytes owned by each counter's item Vec. + let item_heap_bytes: usize = self + .counters + .iter() + .map(|c| c.item.capacity()) + .sum::(); + + size_of::() + // Vec backing storage. + + self.counters.capacity() * size_of::() + // Heap allocations for counter item bytes (owned by counters). + + item_heap_bytes + // HashMap, usize>: bucket overhead + control bytes. + + self.counter_map.capacity() + * (size_of::<(Vec, usize)>() + size_of::()) + // HashMap keys are clones of counter items, so count their heap bytes again. + + item_heap_bytes + // Vec alpha map. + + self.alpha_map.capacity() * size_of::() + } +} + +// --------------------------------------------------------------------------- +// ApproxTopK UDAF struct +// --------------------------------------------------------------------------- + +/// Approximate top-k UDAF using the Filtered Space-Saving algorithm. +#[user_doc( + doc_section(label = "Approximate Functions"), + description = "Returns the approximate most frequent (top-k) values and their counts using the Filtered Space-Saving algorithm. 
Note: for float columns, -0.0 and +0.0 are treated as distinct values, and different NaN representations are tracked separately.", + syntax_example = "approx_top_k(expression, k)", + sql_example = r#"```sql +> SELECT approx_top_k(column_name, 3) FROM table_name; ++-------------------------------------------+ +| approx_top_k(column_name, 3) | ++-------------------------------------------+ +| [{value: foo, count: 3}, {value: bar, count: 2}, {value: baz, count: 1}] | ++-------------------------------------------+ +```"#, + standard_argument(name = "expression",), + argument( + name = "k", + description = "The number of top elements to return. Must be a literal integer between 1 and 10,000." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ApproxTopK { + signature: Signature, +} + +impl Default for ApproxTopK { + fn default() -> Self { + Self::new() + } +} + +impl ApproxTopK { + pub fn new() -> Self { + // Supported value types for the first argument. + let value_types = &[ + DataType::Utf8, + DataType::LargeUtf8, + DataType::Binary, + DataType::LargeBinary, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Float32, + DataType::Float64, + DataType::Date32, + DataType::Date64, + DataType::Timestamp(arrow::datatypes::TimeUnit::Second, None), + DataType::Timestamp(arrow::datatypes::TimeUnit::Millisecond, None), + DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None), + DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None), + ]; + // k must be a literal integer; accept any integer type for convenience. 
+ let k_types = &[ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + ]; + + let mut variants = Vec::with_capacity(value_types.len() * k_types.len()); + for vt in value_types { + for kt in k_types { + variants.push(TypeSignature::Exact(vec![vt.clone(), kt.clone()])); + } + } + + Self { + signature: Signature::one_of(variants, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for ApproxTopK { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "approx_top_k" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let value_type = if !arg_types.is_empty() { + match &arg_types[0] { + // Large variants are narrowed: the summary stores in-memory + // byte slices, not large column offsets, so i32 offsets suffice. + DataType::LargeUtf8 => DataType::Utf8, + DataType::LargeBinary => DataType::Binary, + other => other.clone(), + } + } else { + DataType::Utf8 + }; + + let struct_fields = Fields::from(vec![ + Field::new("value", value_type, true), + Field::new("count", DataType::UInt64, false), + ]); + Ok(DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(struct_fields), + true, + )))) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![ + Arc::new(Field::new( + format_state_name(args.name, "summary"), + DataType::Binary, + true, + )), + Arc::new(Field::new( + format_state_name(args.name, "k"), + DataType::UInt64, + true, + )), + Arc::new(Field::new( + format_state_name(args.name, "data_type"), + DataType::Utf8, + true, + )), + ]) + } + + fn accumulator(&self, args: AccumulatorArgs) -> Result> { + if args.exprs.len() < 2 { + return Err(datafusion_common::DataFusionError::Plan( + "approx_top_k requires two arguments: column and k".to_string(), + )); + } + + let k_expr = &args.exprs[1]; + let k = k_expr + .as_any() + 
.downcast_ref::() + .and_then(|lit| match lit.value() { + // Guard against negative values before casting to usize to + // avoid wrapping (e.g. -1i32 as usize → u64::MAX). + // Zero is allowed through and caught by the bounds check below. + ScalarValue::Int8(Some(v)) if *v >= 0 => Some(*v as usize), + ScalarValue::Int16(Some(v)) if *v >= 0 => Some(*v as usize), + ScalarValue::Int32(Some(v)) if *v >= 0 => Some(*v as usize), + ScalarValue::Int64(Some(v)) if *v >= 0 => Some(*v as usize), + ScalarValue::UInt8(Some(v)) => Some(*v as usize), + ScalarValue::UInt16(Some(v)) => Some(*v as usize), + ScalarValue::UInt32(Some(v)) => Some(*v as usize), + ScalarValue::UInt64(Some(v)) => Some(*v as usize), + _ => None, + }) + .ok_or_else(|| { + datafusion_common::DataFusionError::Plan( + "approx_top_k requires k to be a positive literal integer" + .to_string(), + ) + })?; + + if k == 0 || k > APPROX_TOP_K_MAX_K { + return Err(datafusion_common::DataFusionError::Plan(format!( + "approx_top_k requires k to be between 1 and {APPROX_TOP_K_MAX_K}, got {k}" + ))); + } + + let data_type = args.expr_fields[0].data_type().clone(); + Ok(Box::new(ApproxTopKAccumulator::new_with_data_type( + k, data_type, + ))) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +// --------------------------------------------------------------------------- +// Accumulator +// --------------------------------------------------------------------------- + +/// Accumulator for `approx_top_k` using the Filtered Space-Saving algorithm. +#[derive(Debug)] +struct ApproxTopKAccumulator { + summary: SpaceSavingSummary, + k: usize, + /// The data type of the input column. 
+ input_data_type: DataType, +} + +impl ApproxTopKAccumulator { + fn new_with_data_type(k: usize, input_data_type: DataType) -> Self { + let capacity = k * CAPACITY_MULTIPLIER; + Self { + summary: SpaceSavingSummary::new(capacity), + k, + input_data_type, + } + } + + /// Build the value array for the result based on the input data type. + fn build_value_array(&self, top_items: &[(&[u8], u64, u64)]) -> Result { + match &self.input_data_type { + DataType::Utf8 | DataType::LargeUtf8 => { + let values: Vec> = top_items + .iter() + .map(|(bytes, _, _)| String::from_utf8(bytes.to_vec()).ok()) + .collect(); + Ok(Arc::new(StringArray::from(values)) as ArrayRef) + } + DataType::Binary | DataType::LargeBinary => { + let values: Vec> = + top_items.iter().map(|(bytes, _, _)| Some(*bytes)).collect(); + Ok(Arc::new(BinaryArray::from(values)) as ArrayRef) + } + DataType::Int8 => { + let values: Vec> = top_items + .iter() + .map(|(bytes, _, _)| { + <[u8; 1]>::try_from(*bytes).ok().map(i8::from_le_bytes) + }) + .collect(); + Ok(Arc::new(Int8Array::from(values)) as ArrayRef) + } + DataType::Int16 => { + let values: Vec> = top_items + .iter() + .map(|(bytes, _, _)| { + <[u8; 2]>::try_from(*bytes).ok().map(i16::from_le_bytes) + }) + .collect(); + Ok(Arc::new(Int16Array::from(values)) as ArrayRef) + } + DataType::Int32 => { + let values: Vec> = top_items + .iter() + .map(|(bytes, _, _)| { + <[u8; 4]>::try_from(*bytes).ok().map(i32::from_le_bytes) + }) + .collect(); + Ok(Arc::new(Int32Array::from(values)) as ArrayRef) + } + DataType::Int64 => { + let values: Vec> = top_items + .iter() + .map(|(bytes, _, _)| { + <[u8; 8]>::try_from(*bytes).ok().map(i64::from_le_bytes) + }) + .collect(); + Ok(Arc::new(Int64Array::from(values)) as ArrayRef) + } + DataType::UInt8 => { + let values: Vec> = top_items + .iter() + .map(|(bytes, _, _)| { + <[u8; 1]>::try_from(*bytes).ok().map(u8::from_le_bytes) + }) + .collect(); + Ok(Arc::new(UInt8Array::from(values)) as ArrayRef) + } + DataType::UInt16 => { + 
let values: Vec> = top_items + .iter() + .map(|(bytes, _, _)| { + <[u8; 2]>::try_from(*bytes).ok().map(u16::from_le_bytes) + }) + .collect(); + Ok(Arc::new(UInt16Array::from(values)) as ArrayRef) + } + DataType::UInt32 => { + let values: Vec> = top_items + .iter() + .map(|(bytes, _, _)| { + <[u8; 4]>::try_from(*bytes).ok().map(u32::from_le_bytes) + }) + .collect(); + Ok(Arc::new(UInt32Array::from(values)) as ArrayRef) + } + DataType::UInt64 => { + let values: Vec> = top_items + .iter() + .map(|(bytes, _, _)| { + <[u8; 8]>::try_from(*bytes).ok().map(u64::from_le_bytes) + }) + .collect(); + Ok(Arc::new(UInt64Array::from(values)) as ArrayRef) + } + DataType::Float32 => { + let values: Vec> = top_items + .iter() + .map(|(bytes, _, _)| { + <[u8; 4]>::try_from(*bytes).ok().map(f32::from_le_bytes) + }) + .collect(); + Ok(Arc::new(Float32Array::from(values)) as ArrayRef) + } + DataType::Float64 => { + let values: Vec> = top_items + .iter() + .map(|(bytes, _, _)| { + <[u8; 8]>::try_from(*bytes).ok().map(f64::from_le_bytes) + }) + .collect(); + Ok(Arc::new(Float64Array::from(values)) as ArrayRef) + } + DataType::Date32 => { + let values: Vec> = top_items + .iter() + .map(|(bytes, _, _)| { + <[u8; 4]>::try_from(*bytes).ok().map(i32::from_le_bytes) + }) + .collect(); + Ok(Arc::new(Date32Array::from(values)) as ArrayRef) + } + DataType::Date64 | DataType::Timestamp(_, _) => { + let values: Vec> = top_items + .iter() + .map(|(bytes, _, _)| { + <[u8; 8]>::try_from(*bytes).ok().map(i64::from_le_bytes) + }) + .collect(); + // Date64 and all Timestamp variants share i64 storage. 
+ match &self.input_data_type { + DataType::Date64 => { + Ok(Arc::new(Date64Array::from(values)) as ArrayRef) + } + DataType::Timestamp(unit, tz) => { + use arrow::datatypes::TimeUnit; + let arr: ArrayRef = match unit { + TimeUnit::Second => { + Arc::new(TimestampSecondArray::from(values)) + } + TimeUnit::Millisecond => { + Arc::new(TimestampMillisecondArray::from(values)) + } + TimeUnit::Microsecond => { + Arc::new(TimestampMicrosecondArray::from(values)) + } + TimeUnit::Nanosecond => { + Arc::new(TimestampNanosecondArray::from(values)) + } + }; + if tz.is_some() { + // Preserve timezone in the output. + Ok(arrow::compute::cast(&arr, &self.input_data_type)?) + } else { + Ok(arr) + } + } + _ => unreachable!(), + } + } + _ => { + let values: Vec> = top_items + .iter() + .map(|(bytes, _, _)| String::from_utf8(bytes.to_vec()).ok()) + .collect(); + Ok(Arc::new(StringArray::from(values)) as ArrayRef) + } + } + } + + /// Get the output data type for the value field. + fn output_value_data_type(&self) -> DataType { + match &self.input_data_type { + // LargeUtf8 is narrowed to Utf8: the summary stores in-memory + // byte slices, not large column offsets, so i32 offsets suffice. + DataType::LargeUtf8 => DataType::Utf8, + DataType::LargeBinary => DataType::Binary, + other => other.clone(), + } + } + + /// Serialize a DataType to a string using Arrow's Display impl. + fn data_type_to_string(dt: &DataType) -> String { + dt.to_string() + } +} + +impl Accumulator for ApproxTopKAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let data_array = &values[0]; + + // Downcast once and iterate directly to avoid per-row ScalarValue allocation. + macro_rules! 
process_array { + ($array_type:ty, $data_array:expr) => {{ + let arr = $data_array.as_any().downcast_ref::<$array_type>().unwrap(); + for i in 0..arr.len() { + if !arr.is_null(i) { + self.summary.add(&arr.value(i).to_le_bytes()); + } + } + }}; + } + + match data_array.data_type() { + DataType::Utf8 => { + let arr = data_array.as_any().downcast_ref::().unwrap(); + for i in 0..arr.len() { + if !arr.is_null(i) { + self.summary.add(arr.value(i).as_bytes()); + } + } + } + DataType::LargeUtf8 => { + let arr = data_array + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..arr.len() { + if !arr.is_null(i) { + self.summary.add(arr.value(i).as_bytes()); + } + } + } + DataType::Binary => { + let arr = data_array.as_any().downcast_ref::().unwrap(); + for i in 0..arr.len() { + if !arr.is_null(i) { + self.summary.add(arr.value(i)); + } + } + } + DataType::LargeBinary => { + let arr = data_array + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..arr.len() { + if !arr.is_null(i) { + self.summary.add(arr.value(i)); + } + } + } + DataType::Int8 => process_array!(Int8Array, data_array), + DataType::Int16 => process_array!(Int16Array, data_array), + DataType::Int32 => process_array!(Int32Array, data_array), + DataType::Int64 => process_array!(Int64Array, data_array), + DataType::UInt8 => process_array!(UInt8Array, data_array), + DataType::UInt16 => process_array!(UInt16Array, data_array), + DataType::UInt32 => process_array!(UInt32Array, data_array), + DataType::UInt64 => process_array!(UInt64Array, data_array), + // Note: floats are compared by their byte representation, so -0.0 and +0.0 + // are treated as distinct values, and different NaN bit patterns are tracked + // separately. This matches ClickHouse's behavior for topK with floats. 
+ DataType::Float32 => process_array!(Float32Array, data_array), + DataType::Float64 => process_array!(Float64Array, data_array), + DataType::Date32 => process_array!(Date32Array, data_array), + DataType::Date64 => process_array!(Date64Array, data_array), + DataType::Timestamp(_, _) => { + // All timestamp variants are stored as i64 internally. + match data_array.data_type() { + DataType::Timestamp(arrow::datatypes::TimeUnit::Second, _) => { + process_array!(TimestampSecondArray, data_array) + } + DataType::Timestamp(arrow::datatypes::TimeUnit::Millisecond, _) => { + process_array!(TimestampMillisecondArray, data_array) + } + DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => { + process_array!(TimestampMicrosecondArray, data_array) + } + DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, _) => { + process_array!(TimestampNanosecondArray, data_array) + } + _ => unreachable!(), + } + } + other => { + return Err(datafusion_common::DataFusionError::Execution(format!( + "Unsupported data type for approx_top_k: {other}" + ))); + } + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // State layout: [summary (Binary), k (UInt64), data_type (Utf8)]. + // The `k` field (states[1]) is carried for completeness but not read here + // because the planner guarantees all partial accumulators use the same `k`. 
+ if states.is_empty() || states[0].is_empty() { + return Ok(()); + } + + let summary_array = states[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion_common::DataFusionError::Execution( + "Expected Binary array for summary state".to_string(), + ) + })?; + + for i in 0..summary_array.len() { + if summary_array.is_null(i) { + continue; + } + let bytes = summary_array.value(i); + let other_summary = SpaceSavingSummary::from_bytes(bytes)?; + self.summary.merge(&other_summary); + } + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let top_items = self.summary.top_k(self.k); + + let value_data_type = self.output_value_data_type(); + let struct_fields = Fields::from(vec![ + Field::new("value", value_data_type, true), + Field::new("count", DataType::UInt64, false), + ]); + + let value_array = self.build_value_array(&top_items)?; + let counts: Vec = top_items.iter().map(|(_, count, _)| *count).collect(); + let count_array = Arc::new(UInt64Array::from(counts)) as ArrayRef; + + let struct_array = + StructArray::new(struct_fields.clone(), vec![value_array, count_array], None); + + let list_field = Field::new("item", DataType::Struct(struct_fields), true); + + Ok(ScalarValue::List(Arc::new(ListArray::new( + Arc::new(list_field), + OffsetBuffer::from_lengths([top_items.len()]), + Arc::new(struct_array), + None, + )))) + } + + fn state(&mut self) -> Result> { + let summary_bytes = self.summary.serialize(); + + Ok(vec![ + ScalarValue::Binary(Some(summary_bytes)), + ScalarValue::UInt64(Some(self.k as u64)), + ScalarValue::Utf8(Some(Self::data_type_to_string(&self.input_data_type))), + ]) + } + + fn size(&self) -> usize { + size_of::() + self.summary.size() + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_space_saving_basic() { + let mut summary = 
SpaceSavingSummary::new(3); + + summary.add(b"apple"); + summary.add(b"apple"); + summary.add(b"apple"); + summary.add(b"banana"); + summary.add(b"banana"); + summary.add(b"cherry"); + + let (count, error) = summary.get(b"apple").unwrap(); + assert_eq!(count, 3); + assert_eq!(error, 0); + + let (count, error) = summary.get(b"banana").unwrap(); + assert_eq!(count, 2); + assert_eq!(error, 0); + + let (count, error) = summary.get(b"cherry").unwrap(); + assert_eq!(count, 1); + assert_eq!(error, 0); + + let top = summary.top_k(3); + assert_eq!(top.len(), 3); + assert_eq!(top[0].0, b"apple"); + assert_eq!(top[0].1, 3); + assert_eq!(top[1].0, b"banana"); + assert_eq!(top[1].1, 2); + assert_eq!(top[2].0, b"cherry"); + assert_eq!(top[2].1, 1); + } + + #[test] + fn test_space_saving_eviction() { + let mut summary = SpaceSavingSummary::new(2); + + for _ in 0..100 { + summary.add(b"frequent"); + } + + for i in 0..63u64 { + let item = format!("rare_{i}"); + summary.add(item.as_bytes()); + } + + assert_eq!(summary.len(), 2); + + let (count, error) = summary.get(b"frequent").unwrap(); + assert_eq!(count, 100); + assert_eq!(error, 0); + + let evicted_count = (0..63u64) + .filter(|i| { + let item = format!("rare_{i}"); + summary.get(item.as_bytes()).is_none() + }) + .count(); + assert!(evicted_count >= 61); + } + + #[test] + fn test_space_saving_alpha_map() { + let mut summary = SpaceSavingSummary::new(2); + + for i in 0..64u64 { + let item = format!("item_{i}"); + summary.add(item.as_bytes()); + } + + assert_eq!(summary.len(), 2); + + let alpha_sum: u64 = summary.alpha_map.iter().sum(); + assert!(alpha_sum > 0); + + summary.add(b"item_0"); + assert_eq!(summary.len(), 3); + + let (count, error) = summary.get(b"item_0").unwrap(); + assert!(count > 1); + assert_eq!(count, error + 1); + } + + #[test] + fn test_space_saving_serialization() { + let mut summary = SpaceSavingSummary::new(3); + summary.add(b"test"); + summary.add(b"test"); + summary.add(b"value"); + + let bytes = 
summary.serialize(); + let restored = SpaceSavingSummary::from_bytes(&bytes).unwrap(); + + assert_eq!(restored.capacity(), summary.capacity()); + assert_eq!(restored.len(), summary.len()); + + let (count, _) = restored.get(b"test").unwrap(); + assert_eq!(count, 2); + let (count, _) = restored.get(b"value").unwrap(); + assert_eq!(count, 1); + } + + #[test] + fn test_space_saving_merge() { + let mut summary1 = SpaceSavingSummary::new(4); + let mut summary2 = SpaceSavingSummary::new(4); + + summary1.add(b"apple"); + summary1.add(b"apple"); + summary2.add(b"apple"); + summary2.add(b"banana"); + + summary1.merge(&summary2); + + let (count, _) = summary1.get(b"apple").unwrap(); + assert_eq!(count, 3); + + let (count, _) = summary1.get(b"banana").unwrap(); + assert_eq!(count, 1); + } + + #[test] + fn test_space_saving_merge_with_eviction() { + let mut summary1 = SpaceSavingSummary::new(2); + let mut summary2 = SpaceSavingSummary::new(2); + + for i in 0..40u64 { + let item = format!("s1_item_{i}"); + summary1.add(item.as_bytes()); + } + for _ in 0..10 { + summary1.add(b"top_item"); + } + + for i in 0..40u64 { + let item = format!("s2_item_{i}"); + summary2.add(item.as_bytes()); + } + for _ in 0..5 { + summary2.add(b"second_top"); + } + + summary1.merge(&summary2); + + let top = summary1.top_k(2); + assert!(!top.is_empty()); + let top_item_result = top.iter().find(|(item, _, _)| *item == b"top_item"); + assert!(top_item_result.is_some()); + } + + /// Helper to extract top-k results from a ScalarValue::List result. 
+ fn extract_top_k_results(result: &ScalarValue) -> Vec<(String, u64)> { + if let ScalarValue::List(list_array) = result { + let struct_array = list_array + .values() + .as_any() + .downcast_ref::() + .expect("Expected StructArray"); + + let value_array = struct_array + .column(0) + .as_any() + .downcast_ref::() + .expect("Expected StringArray for values"); + let count_array = struct_array + .column(1) + .as_any() + .downcast_ref::() + .expect("Expected UInt64Array for counts"); + + (0..struct_array.len()) + .map(|i| { + let value = value_array.value(i).to_string(); + let count = count_array.value(i); + (value, count) + }) + .collect() + } else { + panic!("Expected ScalarValue::List, got {result:?}"); + } + } + + #[test] + fn test_accumulator_update_and_evaluate() { + let mut acc = ApproxTopKAccumulator::new_with_data_type(3, DataType::Utf8); + + let values: ArrayRef = Arc::new(StringArray::from(vec![ + "apple", "apple", "apple", "banana", "banana", "cherry", + ])); + + acc.update_batch(&[values]).unwrap(); + + let result = acc.evaluate().unwrap(); + let top_k = extract_top_k_results(&result); + + assert_eq!(top_k.len(), 3); + assert_eq!(top_k[0], ("apple".to_string(), 3)); + assert_eq!(top_k[1], ("banana".to_string(), 2)); + assert_eq!(top_k[2], ("cherry".to_string(), 1)); + } + + #[test] + fn test_accumulator_merge_batch() { + let mut acc1 = ApproxTopKAccumulator::new_with_data_type(3, DataType::Utf8); + let mut acc2 = ApproxTopKAccumulator::new_with_data_type(3, DataType::Utf8); + + let values1: ArrayRef = + Arc::new(StringArray::from(vec!["apple", "apple", "banana"])); + let values2: ArrayRef = + Arc::new(StringArray::from(vec!["apple", "cherry", "cherry"])); + + acc1.update_batch(&[values1]).unwrap(); + acc2.update_batch(&[values2]).unwrap(); + + let state2 = acc2.state().unwrap(); + + let summary_bytes = if let ScalarValue::Binary(Some(bytes)) = &state2[0] { + bytes.clone() + } else { + panic!("Expected Binary for summary") + }; + let k = if let 
ScalarValue::UInt64(Some(k)) = &state2[1] { + *k + } else { + panic!("Expected UInt64 for k") + }; + + let summary_array: ArrayRef = + Arc::new(BinaryArray::from(vec![Some(summary_bytes.as_slice())])); + let k_array: ArrayRef = Arc::new(UInt64Array::from(vec![k])); + + acc1.merge_batch(&[summary_array, k_array]).unwrap(); + + let result = acc1.evaluate().unwrap(); + let top_k = extract_top_k_results(&result); + + assert!(!top_k.is_empty()); + assert_eq!(top_k[0].0, "apple"); + assert_eq!(top_k[0].1, 3); + } + + #[test] + fn test_distributed_merge_simulation() { + let mut worker1_acc = + ApproxTopKAccumulator::new_with_data_type(3, DataType::Utf8); + let mut worker2_acc = + ApproxTopKAccumulator::new_with_data_type(3, DataType::Utf8); + let mut worker3_acc = + ApproxTopKAccumulator::new_with_data_type(3, DataType::Utf8); + + let values1: ArrayRef = + Arc::new(StringArray::from(vec!["apple", "apple", "apple", "banana"])); + worker1_acc.update_batch(&[values1]).unwrap(); + + let values2: ArrayRef = Arc::new(StringArray::from(vec![ + "apple", "apple", "cherry", "cherry", + ])); + worker2_acc.update_batch(&[values2]).unwrap(); + + let values3: ArrayRef = Arc::new(StringArray::from(vec![ + "banana", "banana", "banana", "durian", + ])); + worker3_acc.update_batch(&[values3]).unwrap(); + + let state1 = worker1_acc.state().unwrap(); + let state2 = worker2_acc.state().unwrap(); + let state3 = worker3_acc.state().unwrap(); + + let summary_bytes: Vec> = vec![ + if let ScalarValue::Binary(Some(ref b)) = state1[0] { + Some(b.as_slice()) + } else { + None + }, + if let ScalarValue::Binary(Some(ref b)) = state2[0] { + Some(b.as_slice()) + } else { + None + }, + if let ScalarValue::Binary(Some(ref b)) = state3[0] { + Some(b.as_slice()) + } else { + None + }, + ]; + let k_values: Vec = vec![ + if let ScalarValue::UInt64(Some(k)) = state1[1] { + k + } else { + 0 + }, + if let ScalarValue::UInt64(Some(k)) = state2[1] { + k + } else { + 0 + }, + if let ScalarValue::UInt64(Some(k)) = 
state3[1] { + k + } else { + 0 + }, + ]; + + let summary_array: ArrayRef = Arc::new(BinaryArray::from(summary_bytes)); + let k_array: ArrayRef = Arc::new(UInt64Array::from(k_values)); + + let mut coord_acc = ApproxTopKAccumulator::new_with_data_type(3, DataType::Utf8); + coord_acc.merge_batch(&[summary_array, k_array]).unwrap(); + + let result = coord_acc.evaluate().unwrap(); + let top_k = extract_top_k_results(&result); + + assert!(top_k.len() >= 2); + assert_eq!(top_k[0], ("apple".to_string(), 5)); + assert_eq!(top_k[1], ("banana".to_string(), 4)); + } + + #[test] + fn test_accumulator_multiple_update_batches() { + let mut acc = ApproxTopKAccumulator::new_with_data_type(2, DataType::Utf8); + + // First batch: a=2, b=1 + let batch1: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "a"])); + acc.update_batch(&[batch1]).unwrap(); + + // Second batch: b=2, c=1 + let batch2: ArrayRef = Arc::new(StringArray::from(vec!["b", "c", "b"])); + acc.update_batch(&[batch2]).unwrap(); + + // Combined: b=3, a=2, c=1 → top-2 should be b, a + let result = acc.evaluate().unwrap(); + let top_k = extract_top_k_results(&result); + assert_eq!(top_k.len(), 2); + assert_eq!(top_k[0], ("b".to_string(), 3)); + assert_eq!(top_k[1], ("a".to_string(), 2)); + } + + #[test] + fn test_accumulator_large_utf8_input() { + let mut acc = ApproxTopKAccumulator::new_with_data_type(2, DataType::Utf8); + + let batch: ArrayRef = Arc::new(LargeStringArray::from(vec![ + "hello", "world", "hello", "hello", "world", + ])); + acc.update_batch(&[batch]).unwrap(); + + let result = acc.evaluate().unwrap(); + let top_k = extract_top_k_results(&result); + assert_eq!(top_k.len(), 2); + assert_eq!(top_k[0], ("hello".to_string(), 3)); + assert_eq!(top_k[1], ("world".to_string(), 2)); + } +} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 1b9996220d882..7e014e5dc3580 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ 
b/datafusion/functions-aggregate/src/lib.rs @@ -69,6 +69,7 @@ pub mod approx_distinct; pub mod approx_median; pub mod approx_percentile_cont; pub mod approx_percentile_cont_with_weight; +pub mod approx_top_k; pub mod array_agg; pub mod average; pub mod bit_and_or_xor; @@ -106,6 +107,7 @@ pub mod expr_fn { pub use super::approx_median::approx_median; pub use super::approx_percentile_cont::approx_percentile_cont; pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; + pub use super::approx_top_k::approx_top_k; pub use super::array_agg::array_agg; pub use super::average::avg; pub use super::average::avg_distinct; @@ -175,6 +177,7 @@ pub fn all_default_aggregate_functions() -> Vec> { approx_distinct::approx_distinct_udaf(), approx_percentile_cont_udaf(), approx_percentile_cont_with_weight_udaf(), + approx_top_k::approx_top_k_udaf(), percentile_cont::percentile_cont_udaf(), string_agg::string_agg_udaf(), bit_and_or_xor::bit_and_udaf(), diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 6de9dd4caa9b4..61160b70317d1 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -53,9 +53,9 @@ use datafusion::execution::FunctionRegistry; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::expr_fn::{ - approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count, - count_distinct, covar_pop, covar_samp, first_value, grouping, max, median, min, - stddev, stddev_pop, sum, var_pop, var_sample, + approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, + approx_top_k, count, count_distinct, covar_pop, covar_samp, first_value, grouping, + max, median, min, stddev, stddev_pop, sum, var_pop, var_sample, }; use datafusion::functions_aggregate::min_max::max_udaf; 
use datafusion::functions_nested::map::map; @@ -1172,6 +1172,7 @@ async fn roundtrip_expr_api() -> Result<()> { stddev_pop(lit(2.2)), approx_distinct(lit(2)), approx_median(lit(2)), + approx_top_k(vec![col("a"), lit(3)]), approx_percentile_cont(lit(2).sort(true, false), lit(0.5), None), approx_percentile_cont(lit(2).sort(true, false), lit(0.5), Some(lit(50))), approx_percentile_cont_with_weight( diff --git a/datafusion/sqllogictest/test_files/approx_top_k.slt b/datafusion/sqllogictest/test_files/approx_top_k.slt new file mode 100644 index 0000000000000..b92acc77b27d8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/approx_top_k.slt @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########### +# approx_top_k tests +########### + +# approx_top_k basic usage with string column +query ? +SELECT approx_top_k(column1, 2) FROM (VALUES ('apple'), ('banana'), ('apple'), ('cherry'), ('apple'), ('banana')) AS t(column1); +---- +[{value: apple, count: 3}, {value: banana, count: 2}] + +# approx_top_k with integer column +query ? +SELECT approx_top_k(column1, 2) FROM (VALUES (1), (2), (1), (3), (1), (2)) AS t(column1); +---- +[{value: 1, count: 3}, {value: 2, count: 2}] + +# approx_top_k with GROUP BY +query T? 
+SELECT column2, approx_top_k(column1, 2) FROM (VALUES ('a', 'x'), ('b', 'x'), ('a', 'x'), ('c', 'y'), ('c', 'y'), ('d', 'y')) AS t(column1, column2) GROUP BY column2 ORDER BY column2; +---- +x [{value: a, count: 2}, {value: b, count: 1}] +y [{value: c, count: 2}, {value: d, count: 1}] + +# approx_top_k with k=1 +query ? +SELECT approx_top_k(column1, 1) FROM (VALUES ('red'), ('blue'), ('red'), ('green'), ('red')) AS t(column1); +---- +[{value: red, count: 3}] + +# approx_top_k with NULLs (should be skipped) +query ? +SELECT approx_top_k(column1, 2) FROM (VALUES ('a'), (NULL), ('a'), ('b'), (NULL), ('b'), ('b')) AS t(column1); +---- +[{value: b, count: 3}, {value: a, count: 2}] + +# approx_top_k error: missing k argument +statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_top_k' function +SELECT approx_top_k(column1) FROM (VALUES ('a')) AS t(column1); + +# approx_top_k error: k must be positive +statement error DataFusion error: Error during planning: approx_top_k requires k to be between 1 and 10000 +SELECT approx_top_k(column1, 0) FROM (VALUES ('a')) AS t(column1); + +# approx_top_k error: negative k +statement error DataFusion error: Error during planning: approx_top_k requires k to be a positive literal integer +SELECT approx_top_k(column1, -1) FROM (VALUES ('a')) AS t(column1); + +# approx_top_k with WHERE clause +query ? +SELECT approx_top_k(column1, 2) FROM (VALUES ('a', 1), ('b', 2), ('a', 3), ('c', 4), ('a', 5), ('b', 6)) AS t(column1, column2) WHERE column2 > 2; +---- +[{value: a, count: 2}, {value: c, count: 1}] + +# approx_top_k on empty input (should return an empty list) +query ? +SELECT approx_top_k(column1, 3) FROM (VALUES ('a')) AS t(column1) WHERE false; +---- +[] + +# approx_top_k with all NULLs (should return an empty list) +query ? 
+SELECT approx_top_k(column1, 2) FROM (VALUES (NULL), (NULL), (NULL)) AS t(column1); +---- +[] + +# approx_top_k where k > number of distinct values (should return all distinct values) +query ? +SELECT approx_top_k(column1, 100) FROM (VALUES ('x'), ('y'), ('x')) AS t(column1); +---- +[{value: x, count: 2}, {value: y, count: 1}] diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 0cd69ead4c33a..97c3dc95233b5 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -293,6 +293,7 @@ select log(-1), log(0), sqrt(-1); | approx_median(expr) | Calculates an approximation of the median for `expr`. | | approx_percentile_cont(expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr`. Optional `centroids` parameter controls accuracy (default: 100). | | approx_percentile_cont_with_weight(expr, weight_expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. Optional `centroids` parameter controls accuracy (default: 100). | +| approx_top_k(expr, k) | Returns the approximate most frequent (top-k) values and their counts using the Filtered Space-Saving algorithm. | | bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. | | bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. | | bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. 
| diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index ba9c6ae12477b..60d1e0980aa6d 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -1081,6 +1081,7 @@ _Alias of [stddev](#stddev)._ - [approx_median](#approx_median) - [approx_percentile_cont](#approx_percentile_cont) - [approx_percentile_cont_with_weight](#approx_percentile_cont_with_weight) +- [approx_top_k](#approx_top_k) ### `approx_distinct` @@ -1219,3 +1220,27 @@ An alternative syntax is also supported: | 78.5 | +--------------------------------------------------+ ``` + +### `approx_top_k` + +Returns the approximate most frequent (top-k) values and their counts using the Filtered Space-Saving algorithm. Note: for float columns, -0.0 and +0.0 are treated as distinct values, and different NaN representations are tracked separately. + +```sql +approx_top_k(expression, k) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **k**: The number of top elements to return. Must be a literal integer between 1 and 10,000. 
+ +#### Example + +```sql +> SELECT approx_top_k(column_name, 3) FROM table_name; ++-------------------------------------------+ +| approx_top_k(column_name, 3) | ++-------------------------------------------+ +| [{value: foo, count: 3}, {value: bar, count: 2}, {value: baz, count: 1}] | ++-------------------------------------------+ +``` From 3d1f008d066c9d410b22d7cc23abbceacb487c4c Mon Sep 17 00:00:00 2001 From: Sergio Esteves Date: Wed, 22 Apr 2026 16:23:59 +0100 Subject: [PATCH 2/3] fix: address review feedback for approx_top_k - Replace HashMap, usize> with hashbrown::HashTable<(u64, usize)> to eliminate key cloning, double-hashing, and O(n) rebuild allocations - Fix merge() to use Parallel Space-Saving m1/m2 algorithm (min counter count correction) matching ClickHouse's SpaceSaving::merge() - Fix serialize() to select top-N counters and fold evicted ones into alpha_map when requested_capacity < counters.len() < target_capacity - Harden from_bytes() with validation: requested_capacity > 0, num_counters <= requested_capacity, alpha_map_size is power-of-two, and overflow-safe bounds checking throughout - Replace all raw += on count/error/alpha with saturating_add - Track item_heap_bytes incrementally for O(1) size() accounting - Pre-allocate serialize() output buffer - Add TIMEZONE_WILDCARD timestamp variants to signature - Remove unused k and data_type from state_fields/state/merge_batch - Replace all manual DataFusionError construction with exec_err!/plan_err! 
- Replace .unwrap() downcasts in update_batch with proper error returns - Fix test_accumulator_large_utf8_input to use DataType::LargeUtf8 - Update #[user_doc] to document approximate counts, NULL handling, supported types, and return shape - Use with_timezone() instead of arrow::compute::cast for timestamps - Adapt to upstream API changes (remove as_any from AggregateUDFImpl, use direct downcast_ref for Literal) --- Cargo.lock | 1 + datafusion/functions-aggregate/Cargo.toml | 1 + .../functions-aggregate/src/approx_top_k.rs | 582 ++++++++++-------- .../user-guide/sql/aggregate_functions.md | 12 +- testing | 2 +- 5 files changed, 327 insertions(+), 271 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 336f1c24ef2bb..a967d51645018 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2309,6 +2309,7 @@ dependencies = [ "datafusion-physical-expr-common", "foldhash 0.2.0", "half", + "hashbrown 0.17.0", "log", "num-traits", "rand 0.9.4", diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 406f0a0e32cc3..3d9780c15be0a 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -50,6 +50,7 @@ datafusion-functions-aggregate-common = { workspace = true } datafusion-macros = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } +hashbrown = { workspace = true } foldhash = "0.2" half = { workspace = true } log = { workspace = true } diff --git a/datafusion/functions-aggregate/src/approx_top_k.rs b/datafusion/functions-aggregate/src/approx_top_k.rs index fdca78235ec34..3a6aec7e0b7e2 100644 --- a/datafusion/functions-aggregate/src/approx_top_k.rs +++ b/datafusion/functions-aggregate/src/approx_top_k.rs @@ -33,8 +33,8 @@ //! - Space-Saving: Metwally, Agrawal, El Abbadi. "Efficient Computation of Frequent //! 
and Top-k Elements in Data Streams" (ICDT 2005) -use std::any::Any; -use std::collections::HashMap; +use std::cmp::{Ordering, max, min}; +use std::mem::size_of; use std::sync::Arc; use arrow::array::{ @@ -45,12 +45,15 @@ use arrow::array::{ UInt8Array, UInt16Array, UInt32Array, UInt64Array, }; use arrow::buffer::OffsetBuffer; -use arrow::datatypes::{DataType, Field, FieldRef, Fields}; -use datafusion_common::{Result, ScalarValue}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields, TimeUnit}; +use hashbrown::HashTable; + +use datafusion_common::{Result, ScalarValue, exec_err, plan_err}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, Signature, TIMEZONE_WILDCARD, + TypeSignature, Volatility, }; use datafusion_macros::user_doc; @@ -113,30 +116,37 @@ impl Counter { } /// Ordering for top-k selection: highest-ranked counters sort first. - fn cmp_by_rank(&self, other: &Counter) -> std::cmp::Ordering { + fn cmp_by_rank(&self, other: &Counter) -> Ordering { if other.is_greater_than(self) { - std::cmp::Ordering::Greater + Ordering::Greater } else if self.is_greater_than(other) { - std::cmp::Ordering::Less + Ordering::Less } else { - std::cmp::Ordering::Equal + Ordering::Equal } } } /// Filtered Space-Saving algorithm summary for approximate top-k / heavy hitters. /// -/// Uses a hash map for O(1) counter lookups and maintains an alpha map (filter) -/// to remember evicted items' frequencies. +/// Uses a [`HashTable`] that stores `(hash, index)` tuples for O(1) counter +/// lookups without duplicating the key bytes. The actual item data lives in +/// `counters[index].item`. An alpha map (filter) remembers evicted items' +/// frequencies. 
#[derive(Debug, Clone)] struct SpaceSavingSummary { counters: Vec, - counter_map: HashMap, usize>, + /// Maps `(cached_hash, counter_index)`. Lookups use the cached hash for + /// the fast path and fall back to byte equality via the `counters` vec. + counter_map: HashTable<(u64, usize)>, alpha_map: Vec, requested_capacity: usize, /// Internal target capacity to avoid frequent truncations. /// Set to `max(64, requested_capacity * 2)`. target_capacity: usize, + /// Running total of heap bytes owned by counter item `Vec`s. + /// Updated on push / evict / clone so that [`size`] is O(1). + item_heap_bytes: usize, } impl SpaceSavingSummary { @@ -157,10 +167,11 @@ impl SpaceSavingSummary { fn new(capacity: usize) -> Self { Self { counters: Vec::new(), - counter_map: HashMap::new(), + counter_map: HashTable::new(), alpha_map: Vec::new(), requested_capacity: 0, target_capacity: 0, + item_heap_bytes: 0, } .resized(capacity) } @@ -171,7 +182,7 @@ impl SpaceSavingSummary { let alpha_map_size = Self::compute_alpha_map_size(new_capacity); self.alpha_map = vec![0u64; alpha_map_size]; self.requested_capacity = new_capacity; - self.target_capacity = std::cmp::max(64, new_capacity * 2); + self.target_capacity = max(64, new_capacity.saturating_mul(2)); self.counters.reserve(self.target_capacity); } self @@ -191,16 +202,21 @@ impl SpaceSavingSummary { self.requested_capacity } - fn find_counter_mut(&mut self, item: &[u8]) -> Option<&mut Counter> { + /// Find a counter by item bytes, using the pre-computed hash for fast + /// lookup and falling back to byte equality for collision resolution. 
+ fn find_counter_idx(&self, item: &[u8], hash: u64) -> Option { self.counter_map - .get(item) - .copied() - .map(|idx| &mut self.counters[idx]) + .find(hash, |&(h, idx)| { + h == hash && self.counters[idx].item == item + }) + .map(|&(_, idx)| idx) } #[cfg(test)] fn find_counter(&self, item: &[u8]) -> Option<&Counter> { - self.counter_map.get(item).map(|&idx| &self.counters[idx]) + let hash = Self::hash_item(item); + self.find_counter_idx(item, hash) + .map(|idx| &self.counters[idx]) } /// Add an item with increment 1. @@ -213,9 +229,9 @@ impl SpaceSavingSummary { let hash = Self::hash_item(item); // Fast path: item already tracked. - if let Some(counter) = self.find_counter_mut(item) { - counter.count += increment; - counter.error += error; + if let Some(idx) = self.find_counter_idx(item, hash) { + self.counters[idx].count = self.counters[idx].count.saturating_add(increment); + self.counters[idx].error = self.counters[idx].error.saturating_add(error); return; } @@ -230,12 +246,19 @@ impl SpaceSavingSummary { let alpha_idx = (hash as usize) & alpha_mask; let alpha = self.alpha_map[alpha_idx]; - self.push_counter(item.to_vec(), hash, alpha + increment, alpha + error); + self.push_counter( + item.to_vec(), + hash, + alpha.saturating_add(increment), + alpha.saturating_add(error), + ); } fn push_counter(&mut self, item: Vec, hash: u64, count: u64, error: u64) { let idx = self.counters.len(); - self.counter_map.insert(item.clone(), idx); + self.item_heap_bytes += item.capacity(); + self.counter_map + .insert_unique(hash, (hash, idx), |&(h, _)| h); self.counters.push(Counter { item, hash, @@ -260,19 +283,29 @@ impl SpaceSavingSummary { for counter in self.counters.drain(k..) 
{ let alpha_idx = (counter.hash as usize) & alpha_mask; let true_count = counter.count.saturating_sub(counter.error); - self.alpha_map[alpha_idx] = std::cmp::min( - self.alpha_map[alpha_idx] + true_count, + self.alpha_map[alpha_idx] = min( + self.alpha_map[alpha_idx].saturating_add(true_count), MAX_ALPHA_VALUE, ); + self.item_heap_bytes -= counter.item.capacity(); } } } if need_truncate || force_rebuild { - self.counter_map.clear(); - for (idx, counter) in self.counters.iter().enumerate() { - self.counter_map.insert(counter.item.clone(), idx); - } + self.rebuild_counter_map(); + } + } + + /// Rebuild the `counter_map` from the current `counters` vec. + fn rebuild_counter_map(&mut self) { + self.counter_map.clear(); + for (idx, counter) in self.counters.iter().enumerate() { + self.counter_map.insert_unique( + counter.hash, + (counter.hash, idx), + |&(h, _)| h, + ); } } @@ -288,7 +321,7 @@ impl SpaceSavingSummary { } let mut sorted: Vec<_> = self.counters.iter().collect(); - let return_size = std::cmp::min(sorted.len(), k); + let return_size = min(sorted.len(), k); if return_size < sorted.len() { sorted.select_nth_unstable_by(return_size - 1, |a, b| a.cmp_by_rank(b)); @@ -303,7 +336,9 @@ impl SpaceSavingSummary { .collect() } - /// Merge another summary into this one. + /// Merge another summary into this one using the Parallel Space-Saving + /// reduce-and-combine algorithm from the Parallel Space-Saving paper, + /// matching ClickHouse's `SpaceSaving::merge()` implementation. fn merge(&mut self, other: &SpaceSavingSummary) { if other.is_empty() { return; @@ -315,61 +350,116 @@ impl SpaceSavingSummary { self.alpha_map.clone_from(&other.alpha_map); self.requested_capacity = other.requested_capacity; self.target_capacity = other.target_capacity; + // Recompute from cloned vecs since clone may allocate exact-size + // (capacity == len), which can differ from the original's capacity.
+ self.item_heap_bytes = self.counters.iter().map(|c| c.item.capacity()).sum(); return; } + // Compute m1/m2: the minimum counter count in each summary. + // Per the Parallel Space-Saving paper (Theorem 1), if a summary has + // reached capacity, items not in its counter list could have had at + // most min(counter.count) frequency. This is the merge correction. + let m1 = if self.counters.len() >= self.requested_capacity { + self.counters.iter().map(|c| c.count).min().unwrap_or(0) + } else { + 0 + }; + let m2 = if other.counters.len() >= other.requested_capacity { + other.counters.iter().map(|c| c.count).min().unwrap_or(0) + } else { + 0 + }; + + // Step 1: Bump all self counters by m2 (upper bound of what they + // could have counted in the other partition). + if m2 > 0 { + for counter in &mut self.counters { + counter.count = counter.count.saturating_add(m2); + counter.error = counter.error.saturating_add(m2); + } + } + + // Step 2: Merge other's counters into self. for other_counter in &other.counters { - if let Some(idx) = self.counter_map.get(&other_counter.item).copied() { - self.counters[idx].count += other_counter.count; - self.counters[idx].error += other_counter.error; + if let Some(idx) = + self.find_counter_idx(&other_counter.item, other_counter.hash) + { + // Item exists in both: add other's count, subtract the m2 we + // already added in step 1 (guaranteed non-negative). + self.counters[idx].count = self.counters[idx] + .count + .saturating_add(other_counter.count.saturating_sub(m2)); + self.counters[idx].error = self.counters[idx] + .error + .saturating_add(other_counter.error.saturating_sub(m2)); } else { + // Item only in other: add with m1 (upper bound of what it + // could have counted in self's partition). 
+ let item = other_counter.item.clone(); + self.item_heap_bytes += item.capacity(); self.counters.push(Counter { - item: other_counter.item.clone(), + item, hash: other_counter.hash, - count: other_counter.count, - error: other_counter.error, + count: other_counter.count.saturating_add(m1), + error: other_counter.error.saturating_add(m1), }); } } - // Merge alpha maps element-wise. Sizes should always match because the - // planner guarantees the same k (and thus the same capacity/alpha map size) - // across all partitions. If they differ due to a bug, we skip the merge - // which only degrades accuracy without affecting correctness. + // Merge alpha maps element-wise. if self.alpha_map.len() == other.alpha_map.len() { for (i, &other_alpha) in other.alpha_map.iter().enumerate() { - self.alpha_map[i] = - std::cmp::min(self.alpha_map[i] + other_alpha, MAX_ALPHA_VALUE); + self.alpha_map[i] = min( + self.alpha_map[i].saturating_add(other_alpha), + MAX_ALPHA_VALUE, + ); } } self.truncate_if_needed(true); } - /// Serialize the summary to bytes. + /// Serialize the summary to bytes (matches ClickHouse's `write()` format). + /// + /// Only the top `requested_capacity` counters are written. The alpha map + /// carries evicted frequency information for the coordinator merge. fn serialize(&mut self) -> Vec { - // Ensure counters are truncated and alpha map is up to date before - // serializing, in case to_bytes is called without a prior truncation. self.truncate_if_needed(false); - let counters_to_write: Vec<_> = { - let mut sorted: Vec<_> = self.counters.iter().collect(); - let return_size = std::cmp::min(sorted.len(), self.requested_capacity); - if return_size > 0 && return_size < sorted.len() { - // After select_nth_unstable_by, the top-k counters are in - // sorted[..return_size] but in arbitrary order. This is fine - // since deserialization doesn't depend on counter ordering. 
- sorted.select_nth_unstable_by(return_size - 1, |a, b| a.cmp_by_rank(b)); + // If there are still more counters than requested_capacity (because + // target_capacity wasn't reached), partition out the top ones and + // fold the tail into the alpha map before serializing. + let k = self.requested_capacity; + if k > 0 && k < self.counters.len() { + self.counters + .select_nth_unstable_by(k - 1, |a, b| a.cmp_by_rank(b)); + + let alpha_mask = self.alpha_map.len() - 1; + for counter in self.counters.drain(k..) { + let alpha_idx = (counter.hash as usize) & alpha_mask; + let true_count = counter.count.saturating_sub(counter.error); + self.alpha_map[alpha_idx] = min( + self.alpha_map[alpha_idx].saturating_add(true_count), + MAX_ALPHA_VALUE, + ); + self.item_heap_bytes -= counter.item.capacity(); } - sorted.truncate(return_size); - sorted - }; + self.rebuild_counter_map(); + } + + let num_counters = self.counters.len(); + + // Pre-compute total size for a single allocation. + let counter_bytes: usize = + self.counters.iter().map(|c| 4 + c.item.len() + 16).sum(); + let total = 16 + counter_bytes + 8 + self.alpha_map.len() * 8; + let mut bytes = Vec::with_capacity(total); - let mut bytes = Vec::new(); bytes.extend_from_slice(&(self.requested_capacity as u64).to_le_bytes()); - bytes.extend_from_slice(&(counters_to_write.len() as u64).to_le_bytes()); + bytes.extend_from_slice(&(num_counters as u64).to_le_bytes()); - for counter in counters_to_write { + for counter in &self.counters { bytes.extend_from_slice(&(counter.item.len() as u32).to_le_bytes()); bytes.extend_from_slice(&counter.item); bytes.extend_from_slice(&counter.count.to_le_bytes()); @@ -387,35 +477,64 @@ impl SpaceSavingSummary { /// Deserialize a summary from bytes. 
fn from_bytes(bytes: &[u8]) -> Result { if bytes.len() < 16 { - return Err(datafusion_common::DataFusionError::Execution( - "Invalid Space-Saving summary bytes: too short".to_string(), - )); + return exec_err!("Invalid Space-Saving summary bytes: too short"); } let requested_capacity = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize; let num_counters = u64::from_le_bytes(bytes[8..16].try_into().unwrap()) as usize; + // Validate against reasonable upper bounds to prevent OOM from + // malformed state. + if requested_capacity == 0 { + return exec_err!("Invalid Space-Saving summary: requested_capacity is 0"); + } + let max_capacity = APPROX_TOP_K_MAX_K * CAPACITY_MULTIPLIER; + if requested_capacity > max_capacity { + return exec_err!( + "Invalid Space-Saving summary: requested_capacity {requested_capacity} \ + exceeds maximum {max_capacity}" + ); + } + if num_counters > requested_capacity { + return exec_err!( + "Invalid Space-Saving summary: num_counters {num_counters} exceeds \ + requested_capacity {requested_capacity}" + ); + } + // Each counter needs at least 20 bytes (4 len + 0 item + 8 count + 8 error). 
+ let max_possible = (bytes.len().saturating_sub(16)) / 20; + if num_counters > max_possible { + return exec_err!( + "Invalid Space-Saving summary: num_counters {num_counters} exceeds \ + what fits in {} bytes", + bytes.len() + ); + } + let mut counters = Vec::with_capacity(num_counters); - let mut counter_map = HashMap::with_capacity(num_counters); - let mut offset = 16; + let mut counter_map = HashTable::with_capacity(num_counters); + let mut item_heap_bytes: usize = 0; + let mut offset: usize = 16; for idx in 0..num_counters { - if offset + 4 > bytes.len() { - return Err(datafusion_common::DataFusionError::Execution( + if offset.checked_add(4).is_none_or(|end| end > bytes.len()) { + return exec_err!( "Invalid Space-Saving summary bytes: truncated item length" - .to_string(), - )); + ); } let item_len = u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap()) as usize; offset += 4; - if offset + item_len + 16 > bytes.len() { - return Err(datafusion_common::DataFusionError::Execution( - "Invalid Space-Saving summary bytes: truncated counter".to_string(), - )); + let needed = item_len.checked_add(16); + if needed + .is_none_or(|n| offset.checked_add(n).is_none_or(|end| end > bytes.len())) + { + return exec_err!( + "Invalid Space-Saving summary bytes: truncated counter" + ); } let item = bytes[offset..offset + item_len].to_vec(); @@ -428,7 +547,8 @@ impl SpaceSavingSummary { offset += 8; let hash = Self::hash_item(&item); - counter_map.insert(item.clone(), idx); + item_heap_bytes += item.capacity(); + counter_map.insert_unique(hash, (hash, idx), |&(h, _)| h); counters.push(Counter { item, hash, @@ -437,19 +557,28 @@ impl SpaceSavingSummary { }); } - if offset + 8 > bytes.len() { - return Err(datafusion_common::DataFusionError::Execution( - "Invalid Space-Saving summary bytes: missing alpha map size".to_string(), - )); + if offset.checked_add(8).is_none_or(|end| end > bytes.len()) { + return exec_err!( + "Invalid Space-Saving summary bytes: missing alpha map 
size" + ); } let alpha_map_size = u64::from_le_bytes(bytes[offset..offset + 8].try_into().unwrap()) as usize; offset += 8; - if offset + alpha_map_size.saturating_mul(8) > bytes.len() { - return Err(datafusion_common::DataFusionError::Execution( - "Invalid Space-Saving summary bytes: truncated alpha map".to_string(), - )); + // Validate alpha_map_size is a power of two (required for bitmask indexing). + if alpha_map_size == 0 || !alpha_map_size.is_power_of_two() { + return exec_err!( + "Invalid Space-Saving summary: alpha_map_size {alpha_map_size} \ + is not a positive power of two" + ); + } + + let alpha_bytes = alpha_map_size + .checked_mul(8) + .and_then(|n| offset.checked_add(n)); + if alpha_bytes.is_none_or(|end| end > bytes.len()) { + return exec_err!("Invalid Space-Saving summary bytes: truncated alpha map"); } let mut alpha_map = Vec::with_capacity(alpha_map_size); @@ -459,7 +588,7 @@ impl SpaceSavingSummary { alpha_map.push(alpha); } - let target_capacity = std::cmp::max(64, requested_capacity * 2); + let target_capacity = max(64, requested_capacity.saturating_mul(2)); Ok(Self { counters, @@ -467,29 +596,19 @@ impl SpaceSavingSummary { alpha_map, requested_capacity, target_capacity, + item_heap_bytes, }) } - /// Approximate size in bytes of this summary. + /// Approximate size in bytes of this summary. O(1) thanks to + /// incremental `item_heap_bytes` tracking. fn size(&self) -> usize { - // Heap bytes owned by each counter's item Vec. - let item_heap_bytes: usize = self - .counters - .iter() - .map(|c| c.item.capacity()) - .sum::(); - size_of::() - // Vec backing storage. + self.counters.capacity() * size_of::() - // Heap allocations for counter item bytes (owned by counters). - + item_heap_bytes - // HashMap, usize>: bucket overhead + control bytes. + + self.item_heap_bytes + // HashTable<(u64, usize)>: each bucket stores (hash, idx) + 1 control byte. 
+ self.counter_map.capacity() - * (size_of::<(Vec, usize)>() + size_of::()) - // HashMap keys are clones of counter items, so count their heap bytes again. - + item_heap_bytes - // Vec alpha map. + * (size_of::<(u64, usize)>() + size_of::()) + self.alpha_map.capacity() * size_of::() } } @@ -501,15 +620,15 @@ impl SpaceSavingSummary { /// Approximate top-k UDAF using the Filtered Space-Saving algorithm. #[user_doc( doc_section(label = "Approximate Functions"), - description = "Returns the approximate most frequent (top-k) values and their counts using the Filtered Space-Saving algorithm. Note: for float columns, -0.0 and +0.0 are treated as distinct values, and different NaN representations are tracked separately.", + description = r#"Returns the approximate most frequent (top-k) values with their estimated counts, using the Filtered Space-Saving algorithm. The returned counts are upper-bound estimates; the true frequency lies in `[count - error, count]`. NULL values are skipped; an empty or all-NULL input returns an empty list `[]`. 
For float columns, -0.0 and +0.0 are treated as distinct values, and different NaN representations are tracked separately."#, syntax_example = "approx_top_k(expression, k)", sql_example = r#"```sql > SELECT approx_top_k(column_name, 3) FROM table_name; -+-------------------------------------------+ -| approx_top_k(column_name, 3) | -+-------------------------------------------+ -| [{value: foo, count: 3}, {value: bar, count: 2}, {value: baz, count: 1}] | -+-------------------------------------------+ ++-----------------------------------------------------------------------------+ +| approx_top_k(column_name,Int64(3)) | ++-----------------------------------------------------------------------------+ +| [{value: foo, count: 3}, {value: bar, count: 2}, {value: baz, count: 1}] | ++-----------------------------------------------------------------------------+ ```"#, standard_argument(name = "expression",), argument( @@ -548,10 +667,14 @@ impl ApproxTopK { DataType::Float64, DataType::Date32, DataType::Date64, - DataType::Timestamp(arrow::datatypes::TimeUnit::Second, None), - DataType::Timestamp(arrow::datatypes::TimeUnit::Millisecond, None), - DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None), - DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Second, None), + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Second, Some(TIMEZONE_WILDCARD.into())), + DataType::Timestamp(TimeUnit::Millisecond, Some(TIMEZONE_WILDCARD.into())), + DataType::Timestamp(TimeUnit::Microsecond, Some(TIMEZONE_WILDCARD.into())), + DataType::Timestamp(TimeUnit::Nanosecond, Some(TIMEZONE_WILDCARD.into())), ]; // k must be a literal integer; accept any integer type for convenience. 
let k_types = &[ @@ -579,10 +702,6 @@ impl ApproxTopK { } impl AggregateUDFImpl for ApproxTopK { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "approx_top_k" } @@ -616,35 +735,20 @@ impl AggregateUDFImpl for ApproxTopK { } fn state_fields(&self, args: StateFieldsArgs) -> Result> { - Ok(vec![ - Arc::new(Field::new( - format_state_name(args.name, "summary"), - DataType::Binary, - true, - )), - Arc::new(Field::new( - format_state_name(args.name, "k"), - DataType::UInt64, - true, - )), - Arc::new(Field::new( - format_state_name(args.name, "data_type"), - DataType::Utf8, - true, - )), - ]) + Ok(vec![Arc::new(Field::new( + format_state_name(args.name, "summary"), + DataType::Binary, + true, + ))]) } fn accumulator(&self, args: AccumulatorArgs) -> Result> { if args.exprs.len() < 2 { - return Err(datafusion_common::DataFusionError::Plan( - "approx_top_k requires two arguments: column and k".to_string(), - )); + return plan_err!("approx_top_k requires two arguments: column and k"); } let k_expr = &args.exprs[1]; let k = k_expr - .as_any() .downcast_ref::() .and_then(|lit| match lit.value() { // Guard against negative values before casting to usize to @@ -659,18 +763,19 @@ impl AggregateUDFImpl for ApproxTopK { ScalarValue::UInt32(Some(v)) => Some(*v as usize), ScalarValue::UInt64(Some(v)) => Some(*v as usize), _ => None, - }) - .ok_or_else(|| { - datafusion_common::DataFusionError::Plan( - "approx_top_k requires k to be a positive literal integer" - .to_string(), - ) - })?; + }); + + let Some(k) = k else { + return plan_err!( + "approx_top_k requires k to be a positive literal integer \ + between 1 and 10000" + ); + }; if k == 0 || k > APPROX_TOP_K_MAX_K { - return Err(datafusion_common::DataFusionError::Plan(format!( + return plan_err!( "approx_top_k requires k to be between 1 and {APPROX_TOP_K_MAX_K}, got {k}" - ))); + ); } let data_type = args.expr_fields[0].data_type().clone(); @@ -833,29 +938,36 @@ impl ApproxTopKAccumulator { DataType::Date64 
=> { Ok(Arc::new(Date64Array::from(values)) as ArrayRef) } - DataType::Timestamp(unit, tz) => { - use arrow::datatypes::TimeUnit; - let arr: ArrayRef = match unit { - TimeUnit::Second => { - Arc::new(TimestampSecondArray::from(values)) + DataType::Timestamp(unit, tz) => match unit { + TimeUnit::Second => { + let mut arr = TimestampSecondArray::from(values); + if let Some(tz) = tz { + arr = arr.with_timezone(tz.as_ref()); } - TimeUnit::Millisecond => { - Arc::new(TimestampMillisecondArray::from(values)) + Ok(Arc::new(arr) as ArrayRef) + } + TimeUnit::Millisecond => { + let mut arr = TimestampMillisecondArray::from(values); + if let Some(tz) = tz { + arr = arr.with_timezone(tz.as_ref()); } - TimeUnit::Microsecond => { - Arc::new(TimestampMicrosecondArray::from(values)) + Ok(Arc::new(arr) as ArrayRef) + } + TimeUnit::Microsecond => { + let mut arr = TimestampMicrosecondArray::from(values); + if let Some(tz) = tz { + arr = arr.with_timezone(tz.as_ref()); } - TimeUnit::Nanosecond => { - Arc::new(TimestampNanosecondArray::from(values)) + Ok(Arc::new(arr) as ArrayRef) + } + TimeUnit::Nanosecond => { + let mut arr = TimestampNanosecondArray::from(values); + if let Some(tz) = tz { + arr = arr.with_timezone(tz.as_ref()); } - }; - if tz.is_some() { - // Preserve timezone in the output. - Ok(arrow::compute::cast(&arr, &self.input_data_type)?) - } else { - Ok(arr) + Ok(Arc::new(arr) as ArrayRef) } - } + }, _ => unreachable!(), } } @@ -879,11 +991,6 @@ impl ApproxTopKAccumulator { other => other.clone(), } } - - /// Serialize a DataType to a string using Arrow's Display impl. - fn data_type_to_string(dt: &DataType) -> String { - dt.to_string() - } } impl Accumulator for ApproxTopKAccumulator { @@ -897,7 +1004,12 @@ impl Accumulator for ApproxTopKAccumulator { // Downcast once and iterate directly to avoid per-row ScalarValue allocation. macro_rules! 
process_array { ($array_type:ty, $data_array:expr) => {{ - let arr = $data_array.as_any().downcast_ref::<$array_type>().unwrap(); + let Some(arr) = $data_array.as_any().downcast_ref::<$array_type>() else { + return exec_err!( + "approx_top_k: failed to downcast array to {}", + stringify!($array_type) + ); + }; for i in 0..arr.len() { if !arr.is_null(i) { self.summary.add(&arr.value(i).to_le_bytes()); @@ -906,44 +1018,30 @@ impl Accumulator for ApproxTopKAccumulator { }}; } - match data_array.data_type() { - DataType::Utf8 => { - let arr = data_array.as_any().downcast_ref::().unwrap(); + macro_rules! process_bytes_array { + ($array_type:ty, $data_array:expr) => {{ + let Some(arr) = $data_array.as_any().downcast_ref::<$array_type>() else { + return exec_err!( + "approx_top_k: failed to downcast array to {}", + stringify!($array_type) + ); + }; for i in 0..arr.len() { if !arr.is_null(i) { - self.summary.add(arr.value(i).as_bytes()); + self.summary.add(arr.value(i).as_ref()); } } - } + }}; + } + + match data_array.data_type() { + DataType::Utf8 => process_bytes_array!(StringArray, data_array), DataType::LargeUtf8 => { - let arr = data_array - .as_any() - .downcast_ref::() - .unwrap(); - for i in 0..arr.len() { - if !arr.is_null(i) { - self.summary.add(arr.value(i).as_bytes()); - } - } - } - DataType::Binary => { - let arr = data_array.as_any().downcast_ref::().unwrap(); - for i in 0..arr.len() { - if !arr.is_null(i) { - self.summary.add(arr.value(i)); - } - } + process_bytes_array!(LargeStringArray, data_array) } + DataType::Binary => process_bytes_array!(BinaryArray, data_array), DataType::LargeBinary => { - let arr = data_array - .as_any() - .downcast_ref::() - .unwrap(); - for i in 0..arr.len() { - if !arr.is_null(i) { - self.summary.add(arr.value(i)); - } - } + process_bytes_array!(LargeBinaryArray, data_array) } DataType::Int8 => process_array!(Int8Array, data_array), DataType::Int16 => process_array!(Int16Array, data_array), @@ -960,28 +1058,22 @@ impl 
Accumulator for ApproxTopKAccumulator { DataType::Float64 => process_array!(Float64Array, data_array), DataType::Date32 => process_array!(Date32Array, data_array), DataType::Date64 => process_array!(Date64Array, data_array), - DataType::Timestamp(_, _) => { - // All timestamp variants are stored as i64 internally. - match data_array.data_type() { - DataType::Timestamp(arrow::datatypes::TimeUnit::Second, _) => { - process_array!(TimestampSecondArray, data_array) - } - DataType::Timestamp(arrow::datatypes::TimeUnit::Millisecond, _) => { - process_array!(TimestampMillisecondArray, data_array) - } - DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => { - process_array!(TimestampMicrosecondArray, data_array) - } - DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, _) => { - process_array!(TimestampNanosecondArray, data_array) - } - _ => unreachable!(), + DataType::Timestamp(unit, _) => match unit { + TimeUnit::Second => { + process_array!(TimestampSecondArray, data_array) } - } + TimeUnit::Millisecond => { + process_array!(TimestampMillisecondArray, data_array) + } + TimeUnit::Microsecond => { + process_array!(TimestampMicrosecondArray, data_array) + } + TimeUnit::Nanosecond => { + process_array!(TimestampNanosecondArray, data_array) + } + }, other => { - return Err(datafusion_common::DataFusionError::Execution(format!( - "Unsupported data type for approx_top_k: {other}" - ))); + return exec_err!("Unsupported data type for approx_top_k: {other}"); } } @@ -989,21 +1081,13 @@ impl Accumulator for ApproxTopKAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - // State layout: [summary (Binary), k (UInt64), data_type (Utf8)]. - // The `k` field (states[1]) is carried for completeness but not read here - // because the planner guarantees all partial accumulators use the same `k`. 
if states.is_empty() || states[0].is_empty() { return Ok(()); } - let summary_array = states[0] - .as_any() - .downcast_ref::() - .ok_or_else(|| { - datafusion_common::DataFusionError::Execution( - "Expected Binary array for summary state".to_string(), - ) - })?; + let Some(summary_array) = states[0].as_any().downcast_ref::() else { + return exec_err!("Expected Binary array for summary state"); + }; for i in 0..summary_array.len() { if summary_array.is_null(i) { @@ -1044,13 +1128,7 @@ impl Accumulator for ApproxTopKAccumulator { } fn state(&mut self) -> Result> { - let summary_bytes = self.summary.serialize(); - - Ok(vec![ - ScalarValue::Binary(Some(summary_bytes)), - ScalarValue::UInt64(Some(self.k as u64)), - ScalarValue::Utf8(Some(Self::data_type_to_string(&self.input_data_type))), - ]) + Ok(vec![ScalarValue::Binary(Some(self.summary.serialize()))]) } fn size(&self) -> usize { @@ -1287,17 +1365,11 @@ mod tests { } else { panic!("Expected Binary for summary") }; - let k = if let ScalarValue::UInt64(Some(k)) = &state2[1] { - *k - } else { - panic!("Expected UInt64 for k") - }; let summary_array: ArrayRef = Arc::new(BinaryArray::from(vec![Some(summary_bytes.as_slice())])); - let k_array: ArrayRef = Arc::new(UInt64Array::from(vec![k])); - acc1.merge_batch(&[summary_array, k_array]).unwrap(); + acc1.merge_batch(&[summary_array]).unwrap(); let result = acc1.evaluate().unwrap(); let top_k = extract_top_k_results(&result); @@ -1351,29 +1423,11 @@ mod tests { None }, ]; - let k_values: Vec = vec![ - if let ScalarValue::UInt64(Some(k)) = state1[1] { - k - } else { - 0 - }, - if let ScalarValue::UInt64(Some(k)) = state2[1] { - k - } else { - 0 - }, - if let ScalarValue::UInt64(Some(k)) = state3[1] { - k - } else { - 0 - }, - ]; let summary_array: ArrayRef = Arc::new(BinaryArray::from(summary_bytes)); - let k_array: ArrayRef = Arc::new(UInt64Array::from(k_values)); let mut coord_acc = ApproxTopKAccumulator::new_with_data_type(3, DataType::Utf8); - 
coord_acc.merge_batch(&[summary_array, k_array]).unwrap(); + coord_acc.merge_batch(&[summary_array]).unwrap(); let result = coord_acc.evaluate().unwrap(); let top_k = extract_top_k_results(&result); @@ -1405,7 +1459,7 @@ mod tests { #[test] fn test_accumulator_large_utf8_input() { - let mut acc = ApproxTopKAccumulator::new_with_data_type(2, DataType::Utf8); + let mut acc = ApproxTopKAccumulator::new_with_data_type(2, DataType::LargeUtf8); let batch: ArrayRef = Arc::new(LargeStringArray::from(vec![ "hello", "world", "hello", "hello", "world", diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 60d1e0980aa6d..e5d1d092c382e 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -1223,7 +1223,7 @@ An alternative syntax is also supported: ### `approx_top_k` -Returns the approximate most frequent (top-k) values and their counts using the Filtered Space-Saving algorithm. Note: for float columns, -0.0 and +0.0 are treated as distinct values, and different NaN representations are tracked separately. +Returns the approximate most frequent (top-k) values with their estimated counts, using the Filtered Space-Saving algorithm. The returned counts are upper-bound estimates; the true frequency lies in `[count - error, count]`. NULL values are skipped; an empty or all-NULL input returns an empty list `[]`. For float columns, -0.0 and +0.0 are treated as distinct values, and different NaN representations are tracked separately. 
```sql approx_top_k(expression, k) @@ -1238,9 +1238,9 @@ approx_top_k(expression, k) ```sql > SELECT approx_top_k(column_name, 3) FROM table_name; -+-------------------------------------------+ -| approx_top_k(column_name, 3) | -+-------------------------------------------+ -| [{value: foo, count: 3}, {value: bar, count: 2}, {value: baz, count: 1}] | -+-------------------------------------------+ ++-----------------------------------------------------------------------------+ +| approx_top_k(column_name,Int64(3)) | ++-----------------------------------------------------------------------------+ +| [{value: foo, count: 3}, {value: bar, count: 2}, {value: baz, count: 1}] | ++-----------------------------------------------------------------------------+ ``` diff --git a/testing b/testing index 7df2b70baf4f0..0d60ccae40d0e 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 7df2b70baf4f081ebf8e0c6bd22745cf3cbfd824 +Subproject commit 0d60ccae40d0e8f2d22c15fafb01c5d4be8c63a6 From 8d4f0f6b6b75dd32ed4f18964e9c2cbd19b2b026 Mon Sep 17 00:00:00 2001 From: Sergio Esteves Date: Wed, 22 Apr 2026 17:24:37 +0100 Subject: [PATCH 3/3] fix: resolve CI failures for approx_top_k - Fix testing submodule pointer (reverted to match upstream main) - Fix Cargo.toml formatting (taplo) - Fix unresolved rustdoc link: [`size`] -> `size()` --- datafusion/functions-aggregate/Cargo.toml | 2 +- datafusion/functions-aggregate/src/approx_top_k.rs | 2 +- testing | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 3d9780c15be0a..d935e712c0056 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -50,9 +50,9 @@ datafusion-functions-aggregate-common = { workspace = true } datafusion-macros = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } -hashbrown = { 
workspace = true } foldhash = "0.2" half = { workspace = true } +hashbrown = { workspace = true } log = { workspace = true } num-traits = { workspace = true } diff --git a/datafusion/functions-aggregate/src/approx_top_k.rs b/datafusion/functions-aggregate/src/approx_top_k.rs index 3a6aec7e0b7e2..b374dd75dcd95 100644 --- a/datafusion/functions-aggregate/src/approx_top_k.rs +++ b/datafusion/functions-aggregate/src/approx_top_k.rs @@ -145,7 +145,7 @@ struct SpaceSavingSummary { /// Set to `max(64, requested_capacity * 2)`. target_capacity: usize, /// Running total of heap bytes owned by counter item `Vec`s. - /// Updated on push / evict / clone so that [`size`] is O(1). + /// Updated on push / evict / clone so that `size()` is O(1). item_heap_bytes: usize, } diff --git a/testing b/testing index 0d60ccae40d0e..7df2b70baf4f0 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 0d60ccae40d0e8f2d22c15fafb01c5d4be8c63a6 +Subproject commit 7df2b70baf4f081ebf8e0c6bd22745cf3cbfd824