From 4675c43305cc3b4f98f067713271176903a5ff6c Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Tue, 14 Apr 2026 19:57:04 +0530 Subject: [PATCH 01/12] Push TopK (Sort with fetch) through outer joins --- datafusion/optimizer/src/lib.rs | 1 + datafusion/optimizer/src/optimizer.rs | 2 + .../src/push_down_topk_through_join.rs | 405 ++++++++++++++++++ .../push_down_topk_through_join.slt | 176 ++++++++ 4 files changed, 584 insertions(+) create mode 100644 datafusion/optimizer/src/push_down_topk_through_join.rs create mode 100644 datafusion/sqllogictest/test_files/push_down_topk_through_join.slt diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index e610091824092..e8309a3ceb028 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -65,6 +65,7 @@ pub mod optimizer; pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; +pub mod push_down_topk_through_join; pub mod replace_distinct_aggregate; pub mod rewrite_set_comparison; pub mod scalar_subquery_to_join; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index bdea6a83072cd..1f9d1de863239 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -51,6 +51,7 @@ use crate::plan_signature::LogicalPlanSignature; use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; +use crate::push_down_topk_through_join::PushDownTopKThroughJoin; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use crate::rewrite_set_comparison::RewriteSetComparison; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; @@ -296,6 +297,7 @@ impl Optimizer { Arc::new(EliminateOuterJoin::new()), // Filters can't be pushed down past Limits, we should do PushDownFilter after PushDownLimit Arc::new(PushDownLimit::new()), + Arc::new(PushDownTopKThroughJoin::new()), Arc::new(PushDownFilter::new()), Arc::new(SingleDistinctToGroupBy::new()), // The previous optimizations added expressions and projections, diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs new file mode 100644 index 0000000000000..d8f18d9a9ec30 --- /dev/null +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -0,0 +1,405 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`PushDownTopKThroughJoin`] pushes TopK (Sort with fetch) through outer joins +//! +//! When a `Sort` with a fetch limit sits above an outer join and all sort +//! expressions come from the **preserved** side, this rule inserts a copy +//! of the `Sort(fetch)` on that input to reduce the number of rows +//! entering the join. +//! +//! This is correct because: +//! - A LEFT JOIN preserves every left row (each appears at least once in the +//! output). The final top-N by left-side columns must come from the top-N +//! left rows. +//! - The same reasoning applies symmetrically for RIGHT JOIN and right-side +//! columns. +//! +//! The top-level sort is kept for correctness since a 1-to-many join can +//! produce more than N output rows from N input rows. +//! +//! ## Example +//! +//! Before: +//! ```text +//! Sort: t1.b ASC, fetch=3 +//! Left Join: t1.a = t2.c +//! Scan: t1 ← scans ALL rows +//! Scan: t2 +//! ``` +//! +//! After: +//! ```text +//! Sort: t1.b ASC, fetch=3 +//! Left Join: t1.a = t2.c +//! Sort: t1.b ASC, fetch=3 ← pushed down +//! Scan: t1 +//! Scan: t2 +//! ``` + +use std::sync::Arc; + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use crate::utils::{has_all_column_refs, schema_columns}; +use datafusion_common::Result; +use datafusion_common::tree_node::Transformed; +use datafusion_expr::logical_plan::{JoinType, LogicalPlan, Sort as SortPlan}; + +/// Optimization rule that pushes TopK (Sort with fetch) through +/// LEFT / RIGHT outer joins when all sort expressions come from +/// the preserved side. +/// +/// See module-level documentation for details. +#[derive(Default, Debug)] +pub struct PushDownTopKThroughJoin; + +impl PushDownTopKThroughJoin { + #[expect(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for PushDownTopKThroughJoin { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + // Match Sort with fetch (TopK) + let LogicalPlan::Sort(sort) = &plan else { + return Ok(Transformed::no(plan)); + }; + let Some(fetch) = sort.fetch else { + return Ok(Transformed::no(plan)); + }; + + // Check if the child is a Join (look through Projection) + let (has_projection, join) = match sort.input.as_ref() { + LogicalPlan::Join(join) => (false, join), + LogicalPlan::Projection(proj) => match proj.input.as_ref() { + LogicalPlan::Join(join) => (true, join), + _ => return Ok(Transformed::no(plan)), + }, + _ => return Ok(Transformed::no(plan)), + }; + + // Only LEFT or RIGHT, no non-equijoin filter + let preserved_is_left = match join.join_type { + JoinType::Left => true, + JoinType::Right => false, + _ => return Ok(Transformed::no(plan)), + }; + if join.filter.is_some() { + return Ok(Transformed::no(plan)); + } + + // Check all sort expression columns come from the preserved side + let preserved_schema = if preserved_is_left { + join.left.schema() + } else { + join.right.schema() + }; + let preserved_cols = schema_columns(preserved_schema); + + let all_from_preserved = sort + .expr + .iter() + .all(|sort_expr| has_all_column_refs(&sort_expr.expr, &preserved_cols)); + if !all_from_preserved { + return Ok(Transformed::no(plan)); + } + + // Don't push if preserved child is already a Sort (redundant) + let preserved_child = if preserved_is_left { + &join.left + } else { + &join.right + }; + if matches!(preserved_child.as_ref(), LogicalPlan::Sort(_)) { + return Ok(Transformed::no(plan)); + } + + // Create the new Sort(fetch) on the preserved child + let new_child_sort = Arc::new(LogicalPlan::Sort(SortPlan { + expr: sort.expr.clone(), + input: Arc::clone(preserved_child), + fetch: Some(fetch), + })); + + // Reconstruct the join with the new child + let mut new_join = join.clone(); + if preserved_is_left { + new_join.left = new_child_sort; + } else { + new_join.right = new_child_sort; + } + + // Rebuild the tree: join → optional projection → top-level sort + let new_join_plan = LogicalPlan::Join(new_join); + let new_sort_input = if has_projection { + // Reconstruct the Projection with the new join + let LogicalPlan::Projection(proj) = sort.input.as_ref() else { + unreachable!() + }; + let mut new_proj = proj.clone(); + new_proj.input = Arc::new(new_join_plan); + Arc::new(LogicalPlan::Projection(new_proj)) + } else { + Arc::new(new_join_plan) + }; + + Ok(Transformed::yes(LogicalPlan::Sort(SortPlan { + expr: sort.expr.clone(), + input: new_sort_input, + fetch: sort.fetch, + }))) + } + + fn name(&self) -> &str { + "push_down_topk_through_join" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::OptimizerContext; + use crate::assert_optimized_plan_eq_snapshot; + use crate::test::*; + + use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; + + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(PushDownTopKThroughJoin::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; + } + + /// TopK on left-side columns above a LEFT JOIN → pushed to left child. + #[test] + fn topk_pushed_to_left_of_left_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + Left Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=3 + TableScan: t1 + TableScan: t2 + " + ) + } + + /// TopK on right-side columns above a RIGHT JOIN → pushed to right child. + #[test] + fn topk_pushed_to_right_of_right_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Right, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t2.b").sort(true, false)], Some(5))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t2.b ASC NULLS LAST, fetch=5 + Right Join: t1.a = t2.a + TableScan: t1 + Sort: t2.b ASC NULLS LAST, fetch=5 + TableScan: t2 + " + ) + } + + /// TopK pushed through a Projection between Sort and Join. + #[test] + fn topk_pushed_through_projection() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .project(vec![col("t1.a"), col("t1.b"), col("t2.c")])? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + Projection: t1.a, t1.b, t2.c + Left Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=3 + TableScan: t1 + TableScan: t2 + " + ) + } + + /// INNER JOIN → no pushdown. + #[test] + fn topk_not_pushed_for_inner_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Inner, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + " + ) + } + + /// LEFT JOIN but sort on right-side columns → no pushdown. + #[test] + fn topk_not_pushed_for_wrong_side() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t2.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t2.b ASC NULLS LAST, fetch=3 + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + " + ) + } + + /// Join with a non-equijoin filter → no pushdown (conservative). + #[test] + fn topk_not_pushed_with_join_filter() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join_on( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + vec![col("t1.a").eq(col("t2.a"))], + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + Left Join: Filter: t1.a = t2.a + TableScan: t1 + TableScan: t2 + " + ) + } + + /// Sort without fetch (unbounded) → no pushdown. + #[test] + fn topk_not_pushed_without_fetch() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .sort(vec![col("t1.b").sort(true, false)])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + " + ) + } +} \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt new file mode 100644 index 0000000000000..ef6858c406b8f --- /dev/null +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for pushing TopK (Sort with fetch) through outer joins + +statement ok +set datafusion.execution.target_partitions = 1; + +statement ok +set datafusion.explain.logical_plan_only = true; + +# Create test tables +statement ok +CREATE TABLE t1 (a INT, b INT, c VARCHAR) AS VALUES + (1, 10, 'one'), + (2, 20, 'two'), + (3, 30, 'three'), + (4, 40, 'four'), + (5, 50, 'five'); + +statement ok +CREATE TABLE t2 (x INT, y INT, z VARCHAR) AS VALUES + (1, 100, 'alpha'), + (2, 200, 'beta'), + (3, 300, 'gamma'), + (6, 600, 'delta'), + (7, 700, 'epsilon'); + +### +### Positive cases — TopK should be pushed down +### + +# LEFT JOIN: TopK on left-side columns pushed to left child +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--Left Join: t1.a = t2.x +03)----Sort: t1.b ASC NULLS LAST, fetch=3 +04)------TableScan: t1 projection=[a, b] +05)----TableScan: t2 projection=[x] + +# Verify correctness of the above query +query III +SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 3; +---- +1 10 1 +2 20 2 +3 30 3 + +# RIGHT JOIN: TopK on right-side columns pushed to right child +query TT +EXPLAIN SELECT t1.a, t2.x, t2.y +FROM t1 RIGHT JOIN t2 ON t1.a = t2.x +ORDER BY t2.y ASC LIMIT 3; +---- +logical_plan +01)Sort: t2.y ASC NULLS LAST, fetch=3 +02)--Right Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----Sort: t2.y ASC NULLS LAST, fetch=3 +05)------TableScan: t2 projection=[x, y] + +# Verify correctness +query III +SELECT t1.a, t2.x, t2.y +FROM t1 RIGHT JOIN t2 ON t1.a = t2.x +ORDER BY t2.y ASC LIMIT 3; +---- +1 1 100 +2 2 200 +3 3 300 + +### +### Negative cases — TopK should NOT be pushed down +### + +# INNER JOIN: no pushdown +query TT +EXPLAIN SELECT t1.a, t2.x +FROM t1 INNER JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Projection: t1.a, t2.x +02)--Sort: t1.b ASC NULLS LAST, fetch=3 +03)----Projection: t1.a, t2.x, t1.b +04)------Inner Join: t1.a = t2.x +05)--------TableScan: t1 projection=[a, b] +06)--------TableScan: t2 projection=[x] + +# LEFT JOIN but sort on right-side columns: no pushdown +query TT +EXPLAIN SELECT t1.a, t2.x, t2.y +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t2.y ASC LIMIT 3; +---- +logical_plan +01)Sort: t2.y ASC NULLS LAST, fetch=3 +02)--Left Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----TableScan: t2 projection=[x, y] + +# FULL OUTER JOIN: no pushdown +query TT +EXPLAIN SELECT t1.a, t2.x +FROM t1 FULL OUTER JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Projection: t1.a, t2.x +02)--Sort: t1.b ASC NULLS LAST, fetch=3 +03)----Projection: t1.a, t2.x, t1.b +04)------Full Join: t1.a = t2.x +05)--------TableScan: t1 projection=[a, b] +06)--------TableScan: t2 projection=[x] + +# LEFT JOIN with non-equijoin filter: no pushdown (conservative) +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x AND t1.b > t2.y +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--Projection: t1.a, t1.b, t2.x +03)----Left Join: t1.a = t2.x Filter: t1.b > t2.y +04)------TableScan: t1 projection=[a, b] +05)------TableScan: t2 projection=[x, y] + +# Sort without LIMIT: no pushdown +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST +02)--Left Join: t1.a = t2.x +03)----TableScan: t1 projection=[a, b] +04)----TableScan: t2 projection=[x] + +### +### Config reset +### + +statement ok +set datafusion.execution.target_partitions = 4; + +statement ok +reset datafusion.explain.logical_plan_only; + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; \ No newline at end of file From 9aede677a207a38d56eacb942573d61629546313 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Tue, 14 Apr 2026 19:58:49 +0530 Subject: [PATCH 02/12] lint fix --- datafusion/optimizer/src/push_down_topk_through_join.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index d8f18d9a9ec30..24977b215c400 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -402,4 +402,4 @@ mod test { " ) } -} \ No newline at end of file +} From 19b0edc4e4bfe188924e7c14cdc27065202007eb Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Tue, 14 Apr 2026 20:48:00 +0530 Subject: [PATCH 03/12] fix build failure --- datafusion/sqllogictest/test_files/explain.slt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 467afe7b6c2ba..3628f6a70ccd1 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -193,6 +193,7 @@ logical_plan after propagate_empty_relation SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE @@ -217,6 +218,7 @@ logical_plan after propagate_empty_relation SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE @@ -565,6 +567,7 @@ logical_plan after propagate_empty_relation SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE @@ -589,6 +592,7 @@ logical_plan after propagate_empty_relation SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE From baf25ef47f6339eef88f33c559f0cefcc0367327 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Fri, 17 Apr 2026 15:08:48 +0530 Subject: [PATCH 04/12] Handle edge cases --- .../src/push_down_topk_through_join.rs | 348 +++++++++++++++++- .../push_down_topk_through_join.slt | 219 ++++++++++- 2 files changed, 551 insertions(+), 16 deletions(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index 24977b215c400..cd42cfd00797b 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -57,9 +57,12 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use crate::utils::{has_all_column_refs, schema_columns}; -use datafusion_common::Result; -use datafusion_common::tree_node::Transformed; -use datafusion_expr::logical_plan::{JoinType, LogicalPlan, Sort as SortPlan}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{Column, Result}; +use datafusion_expr::logical_plan::{ + JoinType, LogicalPlan, Projection, Sort as SortPlan, +}; +use datafusion_expr::{Expr, SortExpr}; /// Optimization rule that pushes TopK (Sort with fetch) through /// LEFT / RIGHT outer joins when all sort expressions come from @@ -104,17 +107,29 @@ impl OptimizerRule for PushDownTopKThroughJoin { _ => return Ok(Transformed::no(plan)), }; - // Only LEFT or RIGHT, no non-equijoin filter + // Only outer/semi/anti joins where the preserved side is known. + // No non-equijoin filter (conservative — filter may change row count). let preserved_is_left = match join.join_type { - JoinType::Left => true, - JoinType::Right => false, + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => true, + JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => false, _ => return Ok(Transformed::no(plan)), }; if join.filter.is_some() { return Ok(Transformed::no(plan)); } - // Check all sort expression columns come from the preserved side + // Check all sort expression columns come from the preserved side. + // When there's a projection, resolve sort expressions through it first + // since the sort references post-projection columns. + let resolved_sort_exprs = if has_projection { + let LogicalPlan::Projection(proj) = sort.input.as_ref() else { + unreachable!() + }; + resolve_sort_exprs_through_projection(&sort.expr, proj)? + } else { + sort.expr.clone() + }; + let preserved_schema = if preserved_is_left { join.left.schema() } else { @@ -122,28 +137,65 @@ impl OptimizerRule for PushDownTopKThroughJoin { }; let preserved_cols = schema_columns(preserved_schema); - let all_from_preserved = sort - .expr + let all_from_preserved = resolved_sort_exprs .iter() .all(|sort_expr| has_all_column_refs(&sort_expr.expr, &preserved_cols)); if !all_from_preserved { return Ok(Transformed::no(plan)); } - // Don't push if preserved child is already a Sort (redundant) + // Push through when the preserved child has no Sort, or has a Sort + // with a larger/no fetch limit (our tighter limit reduces data further). + // + // Example (push): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=10) + // Child limits to 10, our tighter fetch=5 reduces data further. + // + // Example (push): Sort(a ASC, fetch=5) → Join → Sort(a ASC) + // Child has no fetch (full sort), adding fetch=5 limits early. + // + // Example (skip): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=3) + // Child already limits to 3 rows, pushing fetch=5 won't help. let preserved_child = if preserved_is_left { &join.left } else { &join.right }; - if matches!(preserved_child.as_ref(), LogicalPlan::Sort(_)) { + if let LogicalPlan::Sort(child_sort) = preserved_child.as_ref() { + // Compare using resolved expressions since the parent sort may + // reference post-projection column names while the child uses + // pre-projection expressions. + let same_exprs = child_sort.expr == resolved_sort_exprs; + let child_fetch_tighter = match child_sort.fetch { + Some(child_fetch) => child_fetch <= fetch, + None => false, + }; + if same_exprs && child_fetch_tighter { + return Ok(Transformed::no(plan)); + } + } + + // Don't push if any sort expression is non-deterministic (e.g. random()). + // Duplicating such expressions would produce different values at each + // evaluation point, potentially changing the result. + if resolved_sort_exprs.iter().any(|se| se.expr.is_volatile()) { return Ok(Transformed::no(plan)); } - // Create the new Sort(fetch) on the preserved child + // Create the new Sort(fetch) on the preserved child. + // Use the resolved expressions (pre-projection) for the pushed Sort. + // + // If the child is already a Sort with the same expressions but a larger + // fetch, tighten its fetch in-place instead of stacking a redundant Sort + // on top. + let (sort_input, sort_exprs) = match preserved_child.as_ref() { + LogicalPlan::Sort(child_sort) if child_sort.expr == resolved_sort_exprs => { + (Arc::clone(&child_sort.input), child_sort.expr.clone()) + } + _ => (Arc::clone(preserved_child), resolved_sort_exprs), + }; let new_child_sort = Arc::new(LogicalPlan::Sort(SortPlan { - expr: sort.expr.clone(), - input: Arc::clone(preserved_child), + expr: sort_exprs, + input: sort_input, fetch: Some(fetch), })); @@ -185,6 +237,63 @@ impl OptimizerRule for PushDownTopKThroughJoin { } } +/// Resolve sort expressions through a projection by replacing column +/// references with the underlying projection expressions. +/// +/// For example, if sort expr is `b ASC` and projection has `-t1.b AS b`, +/// the resolved sort expr becomes `-t1.b ASC`. +/// +/// Before: +/// ```text +/// Sort: b ASC, fetch=3 +/// Projection: -t1.b AS b +/// Join +/// t1 +/// t2 +/// ``` +/// +/// After resolving, the pushed Sort uses pre-projection expressions: +/// ```text +/// Sort: b ASC, fetch=3 +/// Projection: -t1.b AS b +/// Join +/// Sort: -t1.b ASC, fetch=3 ← resolved through projection +/// t1 +/// t2 +/// ``` +fn resolve_sort_exprs_through_projection( + sort_exprs: &[SortExpr], + projection: &Projection, +) -> Result> { + // Build map: output column name → underlying expression + let replace_map: std::collections::HashMap = projection + .schema + .iter() + .zip(projection.expr.iter()) + .map(|((qualifier, field), expr)| { + let key = Column::from((qualifier, field)).flat_name(); + (key, expr.clone().unalias()) + }) + .collect(); + + sort_exprs + .iter() + .map(|sort_expr| { + let new_expr = sort_expr.expr.clone().transform(|expr| { + let replacement = match &expr { + Expr::Column(col) => replace_map.get(&col.flat_name()).cloned(), + _ => None, + }; + Ok(replacement.map_or_else(|| Transformed::no(expr), Transformed::yes)) + })?; + Ok(SortExpr { + expr: new_expr.data, + ..*sort_expr + }) + }) + .collect() +} + #[cfg(test)] mod test { use super::*; @@ -192,7 +301,8 @@ mod test { use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; - use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; + use datafusion_expr::col; + use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; macro_rules! assert_optimized_plan_equal { ( @@ -402,4 +512,212 @@ mod test { " ) } + + /// LEFT SEMI JOIN: TopK on left-side columns → pushed to left child. + #[test] + fn topk_pushed_for_left_semi_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::LeftSemi, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + LeftSemi Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=3 + TableScan: t1 + TableScan: t2 + " + ) + } + + /// LEFT ANTI JOIN: TopK on left-side columns → pushed to left child. + #[test] + fn topk_pushed_for_left_anti_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::LeftAnti, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + LeftAnti Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=3 + TableScan: t1 + TableScan: t2 + " + ) + } + + /// RIGHT SEMI JOIN: TopK on right-side columns → pushed to right child. + #[test] + fn topk_pushed_for_right_semi_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::RightSemi, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t2.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t2.b ASC NULLS LAST, fetch=3 + RightSemi Join: t1.a = t2.a + TableScan: t1 + Sort: t2.b ASC NULLS LAST, fetch=3 + TableScan: t2 + " + ) + } + + /// RIGHT ANTI JOIN: TopK on right-side columns → pushed to right child. + #[test] + fn topk_pushed_for_right_anti_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::RightAnti, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t2.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t2.b ASC NULLS LAST, fetch=3 + RightAnti Join: t1.a = t2.a + TableScan: t1 + Sort: t2.b ASC NULLS LAST, fetch=3 + TableScan: t2 + " + ) + } + + /// Multi-column sort with columns from both sides → no pushdown. + #[test] + fn topk_not_pushed_for_mixed_side_sort() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit( + vec![col("t1.b").sort(true, false), col("t2.b").sort(true, false)], + Some(3), + )? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, t2.b ASC NULLS LAST, fetch=3 + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + " + ) + } + + /// Preserved child has a larger fetch → push our tighter limit. + #[test] + fn topk_pushed_when_child_has_larger_fetch() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Child already has Sort(b ASC, fetch=10); our outer Sort has fetch=3 (tighter). + let t1_with_sort = LogicalPlanBuilder::from(t1) + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(10))? + .build()?; + + let plan = LogicalPlanBuilder::from(t1_with_sort) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + Left Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=3 + TableScan: t1 + TableScan: t2 + " + ) + } + + /// Preserved child already has a tighter fetch → skip pushdown. + #[test] + fn topk_not_pushed_when_child_has_smaller_fetch() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Child already has Sort(b ASC, fetch=2); our outer Sort has fetch=5 (looser). + let t1_with_sort = LogicalPlanBuilder::from(t1) + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(2))? + .build()?; + + let plan = LogicalPlanBuilder::from(t1_with_sort) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(5))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=5 + Left Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=2 + TableScan: t1 + TableScan: t2 + " + ) + } } diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index ef6858c406b8f..b3b8f987aa2e6 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -159,6 +159,223 @@ logical_plan 03)----TableScan: t1 projection=[a, b] 04)----TableScan: t2 projection=[x] +### +### Sort child cases — push vs skip based on existing child Sort +### + +# Child has larger fetch: push our tighter limit +# The inner Sort(fetch=5) has a larger limit than our outer Sort(fetch=2), +# so pushing fetch=2 to the preserved child reduces data further. +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1 ORDER BY b ASC LIMIT 5) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY b ASC LIMIT 2; +---- +logical_plan +01)Sort: sub.b ASC NULLS LAST, fetch=2 +02)--SubqueryAlias: sub +03)----Left Join: t1.a = t2.x +04)------SubqueryAlias: t1 +05)--------Sort: t1.b ASC NULLS LAST, fetch=5 +06)----------TableScan: t1 projection=[a, b] +07)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1 ORDER BY b ASC LIMIT 5) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY b ASC LIMIT 2; +---- +1 10 1 +2 20 2 + +# Child has smaller fetch with same sort: skip (already tighter) +# The inner Sort(fetch=2) already has a tighter limit than our outer Sort(fetch=5), +# so pushing fetch=5 would be redundant. +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1 ORDER BY b ASC LIMIT 2) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY b ASC LIMIT 5; +---- +logical_plan +01)Sort: sub.b ASC NULLS LAST, fetch=5 +02)--SubqueryAlias: sub +03)----Left Join: t1.a = t2.x +04)------SubqueryAlias: t1 +05)--------Sort: t1.b ASC NULLS LAST, fetch=2 +06)----------TableScan: t1 projection=[a, b] +07)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1 ORDER BY b ASC LIMIT 2) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY b ASC LIMIT 5; +---- +1 10 1 +2 20 2 + +### +### Semi/Anti join cases — pushdown supported +### + +# LEFT SEMI JOIN: push to left child +query TT +EXPLAIN SELECT t1.a, t1.b +FROM t1 LEFT SEMI JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--LeftSemi Join: t1.a = t2.x +03)----Sort: t1.b ASC NULLS LAST, fetch=3 +04)------TableScan: t1 projection=[a, b] +05)----TableScan: t2 projection=[x] + +# LEFT ANTI JOIN: push to left child +query TT +EXPLAIN SELECT t1.a, t1.b +FROM t1 LEFT ANTI JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--LeftAnti Join: t1.a = t2.x +03)----Sort: t1.b ASC NULLS LAST, fetch=3 +04)------TableScan: t1 projection=[a, b] +05)----TableScan: t2 projection=[x] + +# RIGHT SEMI JOIN: push to right child +query TT +EXPLAIN SELECT t2.x, t2.y +FROM t1 RIGHT SEMI JOIN t2 ON t1.a = t2.x +ORDER BY t2.y ASC LIMIT 3; +---- +logical_plan +01)Sort: t2.y ASC NULLS LAST, fetch=3 +02)--RightSemi Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----Sort: t2.y ASC NULLS LAST, fetch=3 +05)------TableScan: t2 projection=[x, y] + +# RIGHT ANTI JOIN: push to right child +query TT +EXPLAIN SELECT t2.x, t2.y +FROM t1 RIGHT ANTI JOIN t2 ON t1.a = t2.x +ORDER BY t2.y ASC LIMIT 3; +---- +logical_plan +01)Sort: t2.y ASC NULLS LAST, fetch=3 +02)--RightAnti Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----Sort: t2.y ASC NULLS LAST, fetch=3 +05)------TableScan: t2 projection=[x, y] + +### +### Multi-column sort and OFFSET cases +### + +# ORDER BY columns from both sides: no pushdown +# Sort uses t1.b (left) and t2.y (right) — not all from preserved side +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x, t2.y +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC, t2.y ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, t2.y ASC NULLS LAST, fetch=3 +02)--Left Join: t1.a = t2.x +03)----TableScan: t1 projection=[a, b] +04)----TableScan: t2 projection=[x, y] + +# Verify correctness +query IIII +SELECT t1.a, t1.b, t2.x, t2.y +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC, t2.y ASC LIMIT 3; +---- +1 10 1 100 +2 20 2 200 +3 30 3 300 + +# LIMIT with OFFSET: pushdown still applies (Sort fetch = limit + offset = 3) +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 2 OFFSET 1; +---- +logical_plan +01)Limit: skip=1, fetch=2 +02)--Sort: t1.b ASC NULLS LAST, fetch=3 +03)----Left Join: t1.a = t2.x +04)------Sort: t1.b ASC NULLS LAST, fetch=3 +05)--------TableScan: t1 projection=[a, b] +06)------TableScan: t2 projection=[x] + +# Verify correctness: skip 1, take 2 +query III +SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 2 OFFSET 1; +---- +2 20 2 +3 30 3 + +### +### Projection expression resolution cases +### + +# Sort on a projected expression: the pushed Sort should use the +# pre-projection expression, not the aliased column name. +# ORDER BY neg_b (which is -t1.b) should push Sort(-t1.b) below the join. +query TT +EXPLAIN SELECT -t1.b AS neg_b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY neg_b ASC LIMIT 3; +---- +logical_plan +01)Sort: neg_b ASC NULLS LAST, fetch=3 +02)--Projection: (- t1.b) AS neg_b, t2.x +03)----Left Join: t1.a = t2.x +04)------Sort: (- t1.b) ASC NULLS LAST, fetch=3 +05)--------TableScan: t1 projection=[a, b] +06)------TableScan: t2 projection=[x] + +# Verify correctness: -b ascending means largest b first +query II +SELECT -t1.b AS neg_b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY neg_b ASC LIMIT 3; +---- +-50 NULL +-40 NULL +-30 3 + +# Non-deterministic sort expression (random()): no pushdown +# Duplicating random() would produce different values at each evaluation point. +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b + random() ASC LIMIT 3; +---- +logical_plan +01)Sort: CAST(t1.b AS Float64) + random() ASC NULLS LAST, fetch=3 +02)--Left Join: t1.a = t2.x +03)----TableScan: t1 projection=[a, b] +04)----TableScan: t2 projection=[x] + ### ### Config reset ### @@ -173,4 +390,4 @@ statement ok DROP TABLE t1; statement ok -DROP TABLE t2; \ No newline at end of file +DROP TABLE t2; From 67f92658b9f21740cf90938a9bd89ff1fc7dd661 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Fri, 17 Apr 2026 15:16:04 +0530 Subject: [PATCH 05/12] Handle volatile expr early --- .../optimizer/src/push_down_topk_through_join.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index cd42cfd00797b..22711f5aba54b 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -97,6 +97,13 @@ impl OptimizerRule for PushDownTopKThroughJoin { return Ok(Transformed::no(plan)); }; + // Don't push if any sort expression is non-deterministic (e.g. random()). + // Duplicating such expressions would produce different values at each + // evaluation point, potentially changing the result. + if sort.expr.iter().any(|se| se.expr.is_volatile()) { + return Ok(Transformed::no(plan)); + } + // Check if the child is a Join (look through Projection) let (has_projection, join) = match sort.input.as_ref() { LogicalPlan::Join(join) => (false, join), @@ -174,13 +181,6 @@ impl OptimizerRule for PushDownTopKThroughJoin { } } - // Don't push if any sort expression is non-deterministic (e.g. random()). - // Duplicating such expressions would produce different values at each - // evaluation point, potentially changing the result. - if resolved_sort_exprs.iter().any(|se| se.expr.is_volatile()) { - return Ok(Transformed::no(plan)); - } - // Create the new Sort(fetch) on the preserved child. // Use the resolved expressions (pre-projection) for the pushed Sort. // From d12aefa983f4af9d2b62520b06bd0cf87546e09d Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Fri, 17 Apr 2026 17:25:49 +0530 Subject: [PATCH 06/12] Fix build failure --- .../src/push_down_topk_through_join.rs | 37 +++++++++---------- .../push_down_topk_through_join.slt | 28 +++++++------- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index 22711f5aba54b..fd13f864390c8 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -114,11 +114,14 @@ impl OptimizerRule for PushDownTopKThroughJoin { _ => return Ok(Transformed::no(plan)), }; - // Only outer/semi/anti joins where the preserved side is known. + // Only outer joins where the preserved side is known. + // Semi/Anti joins are excluded: not all preserved-side rows appear in + // the output (only matched/unmatched rows do), so pushing fetch=N to + // the preserved child can drop rows that would have survived the filter. // No non-equijoin filter (conservative — filter may change row count). let preserved_is_left = match join.join_type { - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => true, - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => false, + JoinType::Left => true, + JoinType::Right => false, _ => return Ok(Transformed::no(plan)), }; if join.filter.is_some() { @@ -513,9 +516,9 @@ mod test { ) } - /// LEFT SEMI JOIN: TopK on left-side columns → pushed to left child. + /// LEFT SEMI JOIN: pushing fetch is unsafe (not all left rows appear in output). #[test] - fn topk_pushed_for_left_semi_join() -> Result<()> { + fn topk_not_pushed_for_left_semi_join() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; @@ -534,16 +537,15 @@ mod test { @r" Sort: t1.b ASC NULLS LAST, fetch=3 LeftSemi Join: t1.a = t2.a - Sort: t1.b ASC NULLS LAST, fetch=3 - TableScan: t1 + TableScan: t1 TableScan: t2 " ) } - /// LEFT ANTI JOIN: TopK on left-side columns → pushed to left child. + /// LEFT ANTI JOIN: pushing fetch is unsafe (not all left rows appear in output). #[test] - fn topk_pushed_for_left_anti_join() -> Result<()> { + fn topk_not_pushed_for_left_anti_join() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; @@ -562,16 +564,15 @@ mod test { @r" Sort: t1.b ASC NULLS LAST, fetch=3 LeftAnti Join: t1.a = t2.a - Sort: t1.b ASC NULLS LAST, fetch=3 - TableScan: t1 + TableScan: t1 TableScan: t2 " ) } - /// RIGHT SEMI JOIN: TopK on right-side columns → pushed to right child. + /// RIGHT SEMI JOIN: pushing fetch is unsafe (not all right rows appear in output). #[test] - fn topk_pushed_for_right_semi_join() -> Result<()> { + fn topk_not_pushed_for_right_semi_join() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; @@ -591,15 +592,14 @@ mod test { Sort: t2.b ASC NULLS LAST, fetch=3 RightSemi Join: t1.a = t2.a TableScan: t1 - Sort: t2.b ASC NULLS LAST, fetch=3 - TableScan: t2 + TableScan: t2 " ) } - /// RIGHT ANTI JOIN: TopK on right-side columns → pushed to right child. + /// RIGHT ANTI JOIN: pushing fetch is unsafe (not all right rows appear in output). #[test] - fn topk_pushed_for_right_anti_join() -> Result<()> { + fn topk_not_pushed_for_right_anti_join() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; @@ -619,8 +619,7 @@ mod test { Sort: t2.b ASC NULLS LAST, fetch=3 RightAnti Join: t1.a = t2.a TableScan: t1 - Sort: t2.b ASC NULLS LAST, fetch=3 - TableScan: t2 + TableScan: t2 " ) } diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index b3b8f987aa2e6..1b1aebeec4355 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -228,10 +228,12 @@ ORDER BY b ASC LIMIT 5; 2 20 2 ### -### Semi/Anti join cases — pushdown supported +### Semi/Anti join cases — pushdown NOT supported +### (not all preserved-side rows appear in output, so pushing fetch +### could drop rows that would have survived the semi/anti filter) ### -# LEFT SEMI JOIN: push to left child +# LEFT SEMI JOIN: no pushdown query TT EXPLAIN SELECT t1.a, t1.b FROM t1 LEFT SEMI JOIN t2 ON t1.a = t2.x @@ -240,11 +242,10 @@ ORDER BY t1.b ASC LIMIT 3; logical_plan 01)Sort: t1.b ASC NULLS LAST, fetch=3 02)--LeftSemi Join: t1.a = t2.x -03)----Sort: t1.b ASC NULLS LAST, fetch=3 -04)------TableScan: t1 projection=[a, b] -05)----TableScan: t2 projection=[x] +03)----TableScan: t1 projection=[a, b] +04)----TableScan: t2 projection=[x] -# LEFT ANTI JOIN: push to left child +# LEFT ANTI JOIN: no pushdown query TT EXPLAIN SELECT t1.a, t1.b FROM t1 LEFT ANTI JOIN t2 ON t1.a = t2.x @@ -253,11 +254,10 @@ ORDER BY t1.b ASC LIMIT 3; logical_plan 01)Sort: t1.b ASC NULLS LAST, fetch=3 02)--LeftAnti Join: t1.a = t2.x -03)----Sort: t1.b ASC NULLS LAST, fetch=3 -04)------TableScan: t1 projection=[a, b] -05)----TableScan: t2 projection=[x] +03)----TableScan: t1 projection=[a, b] +04)----TableScan: t2 projection=[x] -# RIGHT SEMI JOIN: push to right child +# RIGHT SEMI JOIN: no pushdown query TT EXPLAIN SELECT t2.x, t2.y FROM t1 RIGHT SEMI JOIN t2 ON t1.a = t2.x @@ -267,10 +267,9 @@ logical_plan 01)Sort: t2.y ASC NULLS LAST, fetch=3 02)--RightSemi Join: t1.a = t2.x 03)----TableScan: t1 projection=[a] -04)----Sort: t2.y ASC NULLS LAST, fetch=3 -05)------TableScan: t2 projection=[x, y] +04)----TableScan: t2 projection=[x, y] -# RIGHT ANTI JOIN: push to right child +# RIGHT ANTI JOIN: no pushdown query TT EXPLAIN SELECT t2.x, t2.y FROM t1 RIGHT ANTI JOIN t2 ON t1.a = t2.x @@ -280,8 +279,7 @@ logical_plan 01)Sort: t2.y ASC NULLS LAST, fetch=3 02)--RightAnti Join: t1.a = t2.x 03)----TableScan: t1 projection=[a] -04)----Sort: t2.y ASC NULLS LAST, fetch=3 -05)------TableScan: t2 projection=[x, y] +04)----TableScan: t2 projection=[x, y] ### ### Multi-column sort and OFFSET cases From 902ef770174196a2cf99dc454aba3fda68f7438a Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Sat, 18 Apr 2026 19:12:57 +0530 Subject: [PATCH 07/12] Handle subquery alias --- .../src/push_down_topk_through_join.rs | 297 +++++++++++++----- .../push_down_topk_through_join.slt | 274 +++++++++++++++- 2 files changed, 486 insertions(+), 85 deletions(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index fd13f864390c8..d1c4d9c32e9f6 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -60,7 +60,7 @@ use crate::utils::{has_all_column_refs, schema_columns}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, Result}; use datafusion_expr::logical_plan::{ - JoinType, LogicalPlan, Projection, Sort as SortPlan, + JoinType, LogicalPlan, Projection, Sort as SortPlan, SubqueryAlias, }; use datafusion_expr::{Expr, SortExpr}; @@ -104,41 +104,61 @@ impl OptimizerRule for PushDownTopKThroughJoin { return Ok(Transformed::no(plan)); } - // Check if the child is a Join (look through Projection) - let (has_projection, join) = match sort.input.as_ref() { - LogicalPlan::Join(join) => (false, join), - LogicalPlan::Projection(proj) => match proj.input.as_ref() { - LogicalPlan::Join(join) => (true, join), + // Peel through transparent nodes (SubqueryAlias, Projection) to find + // the Join. Track intermediate nodes so we can reconstruct the tree + // and resolve sort expressions through them. + let mut current = sort.input.as_ref(); + let mut intermediates: Vec<&LogicalPlan> = Vec::new(); + let join = loop { + match current { + LogicalPlan::Join(join) => break join, + LogicalPlan::Projection(proj) => { + intermediates.push(current); + current = proj.input.as_ref(); + } + LogicalPlan::SubqueryAlias(sq) => { + intermediates.push(current); + current = sq.input.as_ref(); + } _ => return Ok(Transformed::no(plan)), - }, - _ => return Ok(Transformed::no(plan)), + } }; // Only outer joins where the preserved side is known. // Semi/Anti joins are excluded: not all preserved-side rows appear in // the output (only matched/unmatched rows do), so pushing fetch=N to // the preserved child can drop rows that would have survived the filter. - // No non-equijoin filter (conservative — filter may change row count). + // + // Non-equijoin filters in the ON clause are safe: outer joins guarantee + // all preserved-side rows appear in the output regardless of the filter. + // The filter only controls matching (which non-preserved rows pair up), + // not which preserved rows survive. let preserved_is_left = match join.join_type { JoinType::Left => true, JoinType::Right => false, _ => return Ok(Transformed::no(plan)), }; - if join.filter.is_some() { - return Ok(Transformed::no(plan)); - } - // Check all sort expression columns come from the preserved side. - // When there's a projection, resolve sort expressions through it first - // since the sort references post-projection columns. - let resolved_sort_exprs = if has_projection { - let LogicalPlan::Projection(proj) = sort.input.as_ref() else { - unreachable!() - }; - resolve_sort_exprs_through_projection(&sort.expr, proj)? - } else { - sort.expr.clone() - }; + // Resolve sort expressions through all intermediate nodes (Projection, + // SubqueryAlias) so that column references match the join's schema. + let mut resolved_sort_exprs = sort.expr.clone(); + for node in &intermediates { + match node { + LogicalPlan::Projection(proj) => { + resolved_sort_exprs = resolve_sort_exprs_through_projection( + &resolved_sort_exprs, + proj, + )?; + } + LogicalPlan::SubqueryAlias(sq) => { + resolved_sort_exprs = resolve_sort_exprs_through_subquery_alias( + &resolved_sort_exprs, + sq, + )?; + } + _ => unreachable!(), + } + } let preserved_schema = if preserved_is_left { join.left.schema() @@ -154,6 +174,42 @@ impl OptimizerRule for PushDownTopKThroughJoin { return Ok(Transformed::no(plan)); } + let preserved_child = if preserved_is_left { + &join.left + } else { + &join.right + }; + + // Resolve sort exprs further through any SubqueryAlias wrapping the + // preserved child, so we can compare with the inner Sort's expressions. + // + // After intermediate resolution, resolved_sort_exprs = [t1.b ASC]. + // The inner Sort uses [orders.b ASC]. This step maps t1.b → orders.b. + // + // ```text + // Sort(sub.b ASC, fetch=2) + // SubqueryAlias(sub) ← intermediate, already resolved + // Left Join + // SubqueryAlias(t1) ← preserved child, resolve here + // Sort(orders.b ASC, fetch=5) + // TableScan: orders + // ``` + let (inner_child, child_resolved_exprs) = match preserved_child.as_ref() { + LogicalPlan::SubqueryAlias(sq) => { + let exprs = + resolve_sort_exprs_through_subquery_alias(&resolved_sort_exprs, sq)?; + (sq.input.as_ref(), exprs) + } + _ => (preserved_child.as_ref(), resolved_sort_exprs.clone()), + }; + + // If the inner child is a Limit (PushDownLimit hasn't merged it with + // the Sort yet), skip this iteration. PushDownLimit will merge + // Limit → Sort in the next pass, then our rule will tighten the Sort. + if matches!(inner_child, LogicalPlan::Limit(_)) { + return Ok(Transformed::no(plan)); + } + // Push through when the preserved child has no Sort, or has a Sort // with a larger/no fetch limit (our tighter limit reduces data further). // @@ -165,16 +221,8 @@ impl OptimizerRule for PushDownTopKThroughJoin { // // Example (skip): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=3) // Child already limits to 3 rows, pushing fetch=5 won't help. - let preserved_child = if preserved_is_left { - &join.left - } else { - &join.right - }; - if let LogicalPlan::Sort(child_sort) = preserved_child.as_ref() { - // Compare using resolved expressions since the parent sort may - // reference post-projection column names while the child uses - // pre-projection expressions. - let same_exprs = child_sort.expr == resolved_sort_exprs; + if let LogicalPlan::Sort(child_sort) = inner_child { + let same_exprs = sort_exprs_equal(&child_sort.expr, &child_resolved_exprs); let child_fetch_tighter = match child_sort.fetch { Some(child_fetch) => child_fetch <= fetch, None => false, @@ -185,44 +233,67 @@ impl OptimizerRule for PushDownTopKThroughJoin { } // Create the new Sort(fetch) on the preserved child. - // Use the resolved expressions (pre-projection) for the pushed Sort. + // Use the resolved expressions for the pushed Sort. + // + // If the inner child is already a Sort with the same expressions but a + // larger fetch, tighten its fetch in-place instead of stacking a + // redundant Sort on top. // - // If the child is already a Sort with the same expressions but a larger - // fetch, tighten its fetch in-place instead of stacking a redundant Sort - // on top. - let (sort_input, sort_exprs) = match preserved_child.as_ref() { - LogicalPlan::Sort(child_sort) if child_sort.expr == resolved_sort_exprs => { - (Arc::clone(&child_sort.input), child_sort.expr.clone()) + // When the preserved child is wrapped in SubqueryAlias, the new Sort + // must sit INSIDE the SubqueryAlias (between it and its input), using + // inner-schema column names. + let inner_input: &Arc = match preserved_child.as_ref() { + LogicalPlan::SubqueryAlias(sq) => &sq.input, + _ => preserved_child, + }; + let new_inner_child = match inner_child { + LogicalPlan::Sort(child_sort) + if sort_exprs_equal(&child_sort.expr, &child_resolved_exprs) => + { + Arc::new(LogicalPlan::Sort(SortPlan { + expr: child_sort.expr.clone(), + input: Arc::clone(&child_sort.input), + fetch: Some(fetch), + })) } - _ => (Arc::clone(preserved_child), resolved_sort_exprs), + _ => Arc::new(LogicalPlan::Sort(SortPlan { + expr: child_resolved_exprs, + input: Arc::clone(inner_input), + fetch: Some(fetch), + })), + }; + + // Wrap the new Sort back in SubqueryAlias if the preserved child had one. + let new_preserved_child = match preserved_child.as_ref() { + LogicalPlan::SubqueryAlias(sq) => Arc::new(LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(new_inner_child, sq.alias.clone())?, + )), + _ => new_inner_child, }; - let new_child_sort = Arc::new(LogicalPlan::Sort(SortPlan { - expr: sort_exprs, - input: sort_input, - fetch: Some(fetch), - })); // Reconstruct the join with the new child let mut new_join = join.clone(); if preserved_is_left { - new_join.left = new_child_sort; + new_join.left = new_preserved_child; } else { - new_join.right = new_child_sort; + new_join.right = new_preserved_child; } - // Rebuild the tree: join → optional projection → top-level sort - let new_join_plan = LogicalPlan::Join(new_join); - let new_sort_input = if has_projection { - // Reconstruct the Projection with the new join - let LogicalPlan::Projection(proj) = sort.input.as_ref() else { - unreachable!() - }; - let mut new_proj = proj.clone(); - new_proj.input = Arc::new(new_join_plan); - Arc::new(LogicalPlan::Projection(new_proj)) - } else { - Arc::new(new_join_plan) - }; + // Rebuild the tree: join → intermediate nodes → top-level sort + let mut new_sort_input = Arc::new(LogicalPlan::Join(new_join)); + for node in intermediates.into_iter().rev() { + new_sort_input = Arc::new(match node { + LogicalPlan::Projection(proj) => { + let mut new_proj = proj.clone(); + new_proj.input = new_sort_input; + LogicalPlan::Projection(new_proj) + } + LogicalPlan::SubqueryAlias(sq) => LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(new_sort_input, sq.alias.clone())?, + ), + _ => unreachable!(), + }); + } Ok(Transformed::yes(LogicalPlan::Sort(SortPlan { expr: sort.expr.clone(), @@ -264,21 +335,21 @@ impl OptimizerRule for PushDownTopKThroughJoin { /// t1 /// t2 /// ``` -fn resolve_sort_exprs_through_projection( +/// Replace column references in sort expressions using a name→expr map. +/// Uses `transform()` for deep replacement (handles nested expressions +/// like `-t1.b` where the column is inside a Negative wrapper). +/// +/// Example with `replace_map = {"sub.b" → Column(t1.b)}`: +/// +/// ```text +/// Input: [sub.b ASC] → Output: [t1.b ASC] (simple column) +/// Input: [(- sub.b) ASC] → Output: [(- t1.b) ASC] (nested column) +/// Input: [sub.a ASC, sub.b ASC] → Output: [t1.a ASC, t1.b ASC] (multiple) +/// ``` +fn replace_columns_in_sort_exprs( sort_exprs: &[SortExpr], - projection: &Projection, + replace_map: &std::collections::HashMap, ) -> Result> { - // Build map: output column name → underlying expression - let replace_map: std::collections::HashMap = projection - .schema - .iter() - .zip(projection.expr.iter()) - .map(|((qualifier, field), expr)| { - let key = Column::from((qualifier, field)).flat_name(); - (key, expr.clone().unalias()) - }) - .collect(); - sort_exprs .iter() .map(|sort_expr| { @@ -297,6 +368,75 @@ fn resolve_sort_exprs_through_projection( .collect() } +/// Resolve sort expressions through a projection by replacing column +/// references with the underlying projection expressions. +/// +/// Example: sort expr is `neg_b ASC` and projection has `-t1.b AS neg_b`: +/// +/// ```text +/// Input sort exprs: [neg_b ASC] +/// Output sort exprs: [(- t1.b) ASC] +/// ``` +fn resolve_sort_exprs_through_projection( + sort_exprs: &[SortExpr], + projection: &Projection, +) -> Result> { + let replace_map: std::collections::HashMap = projection + .schema + .iter() + .zip(projection.expr.iter()) + .map(|((qualifier, field), expr)| { + let key = Column::from((qualifier, field)).flat_name(); + (key, expr.clone().unalias()) + }) + .collect(); + + replace_columns_in_sort_exprs(sort_exprs, &replace_map) +} + +/// Compare two slices of `SortExpr` using `flat_name()` for column identity. +/// +/// `Column` derives `PartialEq` which compares the `relation` field +/// (`Option`) structurally. A `Bare("t1")` and +/// `Full { catalog, schema, table: "t1" }` are NOT equal even though they +/// refer to the same column. After resolving through SubqueryAlias the +/// variant may differ, so we compare by flat_name() instead. +fn sort_exprs_equal(a: &[SortExpr], b: &[SortExpr]) -> bool { + a.len() == b.len() + && a.iter().zip(b.iter()).all(|(left, right)| { + left.asc == right.asc + && left.nulls_first == right.nulls_first + && left.expr.to_string() == right.expr.to_string() + }) +} + +/// Resolve sort expressions through a SubqueryAlias by replacing the alias +/// qualifier with the input schema's qualifier. +/// +/// Example: SubqueryAlias is `sub` wrapping a join whose left input is `t1`: +/// +/// ```text +/// Input sort exprs: [sub.b ASC] +/// Output sort exprs: [t1.b ASC] +/// ``` +fn resolve_sort_exprs_through_subquery_alias( + sort_exprs: &[SortExpr], + subquery_alias: &SubqueryAlias, +) -> Result> { + let replace_map: std::collections::HashMap = subquery_alias + .schema + .iter() + .zip(subquery_alias.input.schema().iter()) + .map(|((alias_qual, alias_field), (input_qual, input_field))| { + let alias_col = Column::from((alias_qual, alias_field)); + let input_col = Column::from((input_qual, input_field)); + (alias_col.flat_name(), Expr::Column(input_col)) + }) + .collect(); + + replace_columns_in_sort_exprs(sort_exprs, &replace_map) +} + #[cfg(test)] mod test { use super::*; @@ -463,9 +603,11 @@ mod test { ) } - /// Join with a non-equijoin filter → no pushdown (conservative). + /// Join with a non-equijoin filter → pushdown still happens. + /// Outer joins preserve all rows from the preserved side regardless + /// of the ON filter. #[test] - fn topk_not_pushed_with_join_filter() -> Result<()> { + fn topk_pushed_with_join_filter() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; @@ -483,7 +625,8 @@ mod test { @r" Sort: t1.b ASC NULLS LAST, fetch=3 Left Join: Filter: t1.a = t2.a - TableScan: t1 + Sort: t1.b ASC NULLS LAST, fetch=3 + TableScan: t1 TableScan: t2 " ) diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index 1b1aebeec4355..630fa85472328 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -134,7 +134,9 @@ logical_plan 05)--------TableScan: t1 projection=[a, b] 06)--------TableScan: t2 projection=[x] -# LEFT JOIN with non-equijoin filter: no pushdown (conservative) +# LEFT JOIN with non-equijoin filter on BOTH sides: pushdown OK +# Filter t1.b > t2.y is in the ON clause — it only controls matching, not +# which preserved (left) rows appear. All left rows are preserved. query TT EXPLAIN SELECT t1.a, t1.b, t2.x FROM t1 LEFT JOIN t2 ON t1.a = t2.x AND t1.b > t2.y @@ -144,8 +146,69 @@ logical_plan 01)Sort: t1.b ASC NULLS LAST, fetch=3 02)--Projection: t1.a, t1.b, t2.x 03)----Left Join: t1.a = t2.x Filter: t1.b > t2.y +04)------Sort: t1.b ASC NULLS LAST, fetch=3 +05)--------TableScan: t1 projection=[a, b] +06)------TableScan: t2 projection=[x, y] + +# Verify correctness: all left rows appear, filter only affects matching +query III +SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x AND t1.b > t2.y +ORDER BY t1.b ASC LIMIT 3; +---- +1 10 NULL +2 20 NULL +3 30 NULL + +# LEFT JOIN with non-equijoin filter on non-preserved side only: pushdown OK +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x AND t2.y > 100 +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--Left Join: t1.a = t2.x +03)----Sort: t1.b ASC NULLS LAST, fetch=3 04)------TableScan: t1 projection=[a, b] -05)------TableScan: t2 projection=[x, y] +05)----Projection: t2.x +06)------Filter: t2.y > Int32(100) +07)--------TableScan: t2 projection=[x, y] + +# LEFT JOIN with preserved-side-only filter: pushdown OK +# Filter t1.b > 20 prevents matching for left rows with b <= 20, +# but those rows still appear with NULL-filled right columns. +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x AND t1.b > 20 +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--Left Join: t1.a = t2.x Filter: t1.b > Int32(20) +03)----Sort: t1.b ASC NULLS LAST, fetch=3 +04)------TableScan: t1 projection=[a, b] +05)----TableScan: t2 projection=[x] + +# Verify correctness: rows with b <= 20 get NULL right columns +query III +SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x AND t1.b > 20 +ORDER BY t1.b ASC LIMIT 3; +---- +1 10 NULL +2 20 NULL +3 30 3 + +# Verify correctness: non-preserved side filter +query III +SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x AND t2.y > 100 +ORDER BY t1.b ASC LIMIT 3; +---- +1 10 NULL +2 20 2 +3 30 3 # Sort without LIMIT: no pushdown query TT @@ -164,8 +227,7 @@ logical_plan ### # Child has larger fetch: push our tighter limit -# The inner Sort(fetch=5) has a larger limit than our outer Sort(fetch=2), -# so pushing fetch=2 to the preserved child reduces data further. +# The inner Sort(fetch=5) is tightened to fetch=2 in-place. query TT EXPLAIN SELECT * FROM ( SELECT t1.a, t1.b, t2.x @@ -179,7 +241,7 @@ logical_plan 02)--SubqueryAlias: sub 03)----Left Join: t1.a = t2.x 04)------SubqueryAlias: t1 -05)--------Sort: t1.b ASC NULLS LAST, fetch=5 +05)--------Sort: t1.b ASC NULLS LAST, fetch=2 06)----------TableScan: t1 projection=[a, b] 07)------TableScan: t2 projection=[x] @@ -195,9 +257,9 @@ ORDER BY b ASC LIMIT 2; 1 10 1 2 20 2 -# Child has smaller fetch with same sort: skip (already tighter) -# The inner Sort(fetch=2) already has a tighter limit than our outer Sort(fetch=5), -# so pushing fetch=5 would be redundant. +# Child has smaller fetch with same sort: our rule skips (already tighter). +# PushDownLimit inserts a Sort(fetch=5) that gets collapsed with the inner +# Sort(fetch=2) to Sort(fetch=2) by stacked-sort merging. query TT EXPLAIN SELECT * FROM ( SELECT t1.a, t1.b, t2.x @@ -374,6 +436,202 @@ logical_plan 03)----TableScan: t1 projection=[a, b] 04)----TableScan: t2 projection=[x] +### +### SubqueryAlias edge cases +### + +# SubqueryAlias without inner Sort: push new Sort through SubqueryAlias +# Preserved child is SubqueryAlias(t1, TableScan) — no existing Sort to tighten, +# so a new Sort(fetch=2) is inserted inside the SubqueryAlias. +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY b ASC LIMIT 2; +---- +logical_plan +01)Sort: sub.b ASC NULLS LAST, fetch=2 +02)--SubqueryAlias: sub +03)----Left Join: t1.a = t2.x +04)------SubqueryAlias: t1 +05)--------Sort: t1.b ASC NULLS LAST, fetch=2 +06)----------TableScan: t1 projection=[a, b] +07)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY b ASC LIMIT 2; +---- +1 10 1 +2 20 2 + +# RIGHT JOIN with SubqueryAlias on preserved (right) side +# Inner Sort(fetch=10) is tightened to fetch=3 via stacked-sort merging. +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t2.x, t2.y + FROM t1 + RIGHT JOIN (SELECT * FROM t2 ORDER BY y ASC LIMIT 10) t2 + ON t1.a = t2.x +) sub +ORDER BY y ASC LIMIT 3; +---- +logical_plan +01)Sort: sub.y ASC NULLS LAST, fetch=3 +02)--SubqueryAlias: sub +03)----Right Join: t1.a = t2.x +04)------TableScan: t1 projection=[a] +05)------SubqueryAlias: t2 +06)--------Sort: t2.y ASC NULLS LAST, fetch=3 +07)----------TableScan: t2 projection=[x, y] + +# Verify correctness +query III +SELECT * FROM ( + SELECT t1.a, t2.x, t2.y + FROM t1 + RIGHT JOIN (SELECT * FROM t2 ORDER BY y ASC LIMIT 10) t2 + ON t1.a = t2.x +) sub +ORDER BY y ASC LIMIT 3; +---- +1 1 100 +2 2 200 +3 3 300 + +# SubqueryAlias with different alias name (foo ≠ t1) +# Column resolution: foo.b → t1.b through SubqueryAlias renaming. +query TT +EXPLAIN SELECT * FROM ( + SELECT foo.a, foo.b, t2.x + FROM (SELECT * FROM t1) foo + LEFT JOIN t2 ON foo.a = t2.x +) sub +ORDER BY b ASC LIMIT 3; +---- +logical_plan +01)Sort: sub.b ASC NULLS LAST, fetch=3 +02)--SubqueryAlias: sub +03)----Left Join: foo.a = t2.x +04)------SubqueryAlias: foo +05)--------Sort: t1.b ASC NULLS LAST, fetch=3 +06)----------TableScan: t1 projection=[a, b] +07)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT foo.a, foo.b, t2.x + FROM (SELECT * FROM t1) foo + LEFT JOIN t2 ON foo.a = t2.x +) sub +ORDER BY b ASC LIMIT 3; +---- +1 10 1 +2 20 2 +3 30 3 + +# Sort on non-preserved side column through SubqueryAlias: no pushdown +# ORDER BY t2.x is from the non-preserved (right) side of a LEFT JOIN. +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY x ASC LIMIT 3; +---- +logical_plan +01)Sort: sub.x ASC NULLS LAST, fetch=3 +02)--SubqueryAlias: sub +03)----Left Join: t1.a = t2.x +04)------SubqueryAlias: t1 +05)--------TableScan: t1 projection=[a, b] +06)------TableScan: t2 projection=[x] + +# INNER JOIN through SubqueryAlias: no pushdown (only LEFT/RIGHT) +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1) t1 + INNER JOIN t2 ON t1.a = t2.x +) sub +ORDER BY b ASC LIMIT 3; +---- +logical_plan +01)Sort: sub.b ASC NULLS LAST, fetch=3 +02)--SubqueryAlias: sub +03)----Inner Join: t1.a = t2.x +04)------SubqueryAlias: t1 +05)--------TableScan: t1 projection=[a, b] +06)------TableScan: t2 projection=[x] + +# Multiple sort columns from preserved side through SubqueryAlias +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY a ASC, b ASC LIMIT 3; +---- +logical_plan +01)Sort: sub.a ASC NULLS LAST, sub.b ASC NULLS LAST, fetch=3 +02)--SubqueryAlias: sub +03)----Left Join: t1.a = t2.x +04)------SubqueryAlias: t1 +05)--------Sort: t1.a ASC NULLS LAST, t1.b ASC NULLS LAST, fetch=3 +06)----------TableScan: t1 projection=[a, b] +07)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY a ASC, b ASC LIMIT 3; +---- +1 10 1 +2 20 2 +3 30 3 + +# WHERE filter on preserved side: pushdown still happens +# PushDownFilter pushes the WHERE filter below the Join first, +# then our rule sees Sort → Join (no Filter in between) and pushes TopK. +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +WHERE t1.b > 10 +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--Left Join: t1.a = t2.x +03)----Sort: t1.b ASC NULLS LAST, fetch=3 +04)------Filter: t1.b > Int32(10) +05)--------TableScan: t1 projection=[a, b] +06)----TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +WHERE t1.b > 10 +ORDER BY t1.b ASC LIMIT 3; +---- +2 20 2 +3 30 3 +4 40 NULL + ### ### Config reset ### From 254b224b5b4c5192a645f2f778a87b899984fbd7 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Sat, 18 Apr 2026 19:28:54 +0530 Subject: [PATCH 08/12] Update comment --- .../sqllogictest/test_files/push_down_topk_through_join.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index 630fa85472328..00a1f74b743c1 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -473,7 +473,7 @@ ORDER BY b ASC LIMIT 2; 2 20 2 # RIGHT JOIN with SubqueryAlias on preserved (right) side -# Inner Sort(fetch=10) is tightened to fetch=3 via stacked-sort merging. +# Inner Sort(fetch=10) is tightened to fetch=3 query TT EXPLAIN SELECT * FROM ( SELECT t1.a, t2.x, t2.y From c648e71de92eb60a30dd9dd3d1b996458ba62c9f Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Sat, 18 Apr 2026 20:09:40 +0530 Subject: [PATCH 09/12] Doc fix --- .../sqllogictest/test_files/push_down_topk_through_join.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index 00a1f74b743c1..153da8a7d3054 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -259,7 +259,7 @@ ORDER BY b ASC LIMIT 2; # Child has smaller fetch with same sort: our rule skips (already tighter). # PushDownLimit inserts a Sort(fetch=5) that gets collapsed with the inner -# Sort(fetch=2) to Sort(fetch=2) by stacked-sort merging. +# Sort(fetch=2) to Sort(fetch=2) query TT EXPLAIN SELECT * FROM ( SELECT t1.a, t1.b, t2.x From 03f64999e467319789d7e31e77f24e73294bc539 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Tue, 21 Apr 2026 10:29:36 +0530 Subject: [PATCH 10/12] Handle volatile expr in projection --- .../src/push_down_topk_through_join.rs | 12 ++++++++++-- .../test_files/push_down_topk_through_join.slt | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index d1c4d9c32e9f6..3286319c1de3e 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -160,6 +160,14 @@ impl OptimizerRule for PushDownTopKThroughJoin { } } + // After resolving through projections, the sort expressions may now + // contain volatile functions (e.g. `random() AS col`). Duplicating + // volatile expressions in the pushed Sort would produce different + // values, changing results. + if resolved_sort_exprs.iter().any(|se| se.expr.is_volatile()) { + return Ok(Transformed::no(plan)); + } + let preserved_schema = if preserved_is_left { join.left.schema() } else { @@ -394,13 +402,13 @@ fn resolve_sort_exprs_through_projection( replace_columns_in_sort_exprs(sort_exprs, &replace_map) } -/// Compare two slices of `SortExpr` using `flat_name()` for column identity. +/// Compare two slices of `SortExpr` using `Expr::to_string()` for column identity. /// /// `Column` derives `PartialEq` which compares the `relation` field /// (`Option`) structurally. A `Bare("t1")` and /// `Full { catalog, schema, table: "t1" }` are NOT equal even though they /// refer to the same column. After resolving through SubqueryAlias the -/// variant may differ, so we compare by flat_name() instead. +/// variant may differ, so we compare by display string instead. fn sort_exprs_equal(a: &[SortExpr], b: &[SortExpr]) -> bool { a.len() == b.len() && a.iter().zip(b.iter()).all(|(left, right)| { diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index 153da8a7d3054..ee52c59124a20 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -436,6 +436,23 @@ logical_plan 03)----TableScan: t1 projection=[a, b] 04)----TableScan: t2 projection=[x] +# Non-deterministic projected expression (random() AS col): no pushdown +# Sort references a column that resolves to random() through the projection. +query TT +EXPLAIN SELECT rand_col, t2.x +FROM ( + SELECT random() AS rand_col, t1.a, t2.x + FROM t1 LEFT JOIN t2 ON t1.a = t2.x +) +ORDER BY rand_col ASC LIMIT 3; +---- +logical_plan +01)Sort: rand_col ASC NULLS LAST, fetch=3 +02)--Projection: random() AS rand_col, t2.x +03)----Left Join: t1.a = t2.x +04)------TableScan: t1 projection=[a] +05)------TableScan: t2 projection=[x] + ### ### SubqueryAlias edge cases ### From 051868a28332926e9bd47da08c34181c28cbc9fa Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Tue, 21 Apr 2026 17:02:41 +0530 Subject: [PATCH 11/12] use structural equality --- .../optimizer/src/push_down_topk_through_join.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index 3286319c1de3e..a9e4995437950 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -402,19 +402,14 @@ fn resolve_sort_exprs_through_projection( replace_columns_in_sort_exprs(sort_exprs, &replace_map) } -/// Compare two slices of `SortExpr` using `Expr::to_string()` for column identity. -/// -/// `Column` derives `PartialEq` which compares the `relation` field -/// (`Option`) structurally. A `Bare("t1")` and -/// `Full { catalog, schema, table: "t1" }` are NOT equal even though they -/// refer to the same column. After resolving through SubqueryAlias the -/// variant may differ, so we compare by display string instead. +/// Compare two slices of `SortExpr` for equality. +/// Uses structural equality on the sort expressions fn sort_exprs_equal(a: &[SortExpr], b: &[SortExpr]) -> bool { a.len() == b.len() && a.iter().zip(b.iter()).all(|(left, right)| { left.asc == right.asc && left.nulls_first == right.nulls_first - && left.expr.to_string() == right.expr.to_string() + && left.expr == right.expr }) } From 1cfeb76c07ca316446845e85cef46237ae5b37c4 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Tue, 21 Apr 2026 18:53:40 +0530 Subject: [PATCH 12/12] Adds UT --- .../src/push_down_topk_through_join.rs | 204 +++++++++++++++++- 1 file changed, 203 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index a9e4995437950..4a024731f1899 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -403,7 +403,9 @@ fn resolve_sort_exprs_through_projection( } /// Compare two slices of `SortExpr` for equality. -/// Uses structural equality on the sort expressions +/// +/// Uses structural equality on the sort expressions (direction, nulls_first, +/// and the expression tree). fn sort_exprs_equal(a: &[SortExpr], b: &[SortExpr]) -> bool { a.len() == b.len() && a.iter().zip(b.iter()).all(|(left, right)| { @@ -865,4 +867,204 @@ mod test { " ) } + + // --------------------------------------------------------------- + // Unit tests for resolve_sort_exprs_through_projection + // --------------------------------------------------------------- + + /// Simple passthrough: sort on a column that projection passes through. + /// Projection: [t1.a, t1.b] → sort on t1.b resolves to t1.b + #[test] + fn resolve_through_projection_passthrough() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1) + .project(vec![col("t1.a"), col("t1.b")])? + .build()?; + let LogicalPlan::Projection(proj) = &plan else { + panic!("expected Projection"); + }; + + let sort_exprs = vec![col("t1.b").sort(true, false)]; + let resolved = resolve_sort_exprs_through_projection(&sort_exprs, proj)?; + + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].expr.to_string(), "t1.b"); + assert!(resolved[0].asc); + Ok(()) + } + + /// Aliased expression: sort on neg_b resolves to (- t1.b) + #[test] + fn resolve_through_projection_alias() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1) + .project(vec![ + col("t1.a"), + (Expr::Negative(Box::new(col("t1.b")))).alias("neg_b"), + ])? + .build()?; + let LogicalPlan::Projection(proj) = &plan else { + panic!("expected Projection"); + }; + + let sort_exprs = vec![col("neg_b").sort(true, false)]; + let resolved = resolve_sort_exprs_through_projection(&sort_exprs, proj)?; + + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].expr.to_string(), "(- t1.b)"); + Ok(()) + } + + /// Multiple columns through projection: sort on (t1.a, t1.b) + #[test] + fn resolve_through_projection_multi_column() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1) + .project(vec![col("t1.a"), col("t1.b"), col("t1.c")])? + .build()?; + let LogicalPlan::Projection(proj) = &plan else { + panic!("expected Projection"); + }; + + let sort_exprs = + vec![col("t1.a").sort(true, false), col("t1.b").sort(false, true)]; + let resolved = resolve_sort_exprs_through_projection(&sort_exprs, proj)?; + + assert_eq!(resolved.len(), 2); + assert_eq!(resolved[0].expr.to_string(), "t1.a"); + assert!(resolved[0].asc); + assert_eq!(resolved[1].expr.to_string(), "t1.b"); + assert!(!resolved[1].asc); + assert!(resolved[1].nulls_first); + Ok(()) + } + + /// Projection + SubqueryAlias stacked: sort resolves through both. + /// neg_b → (- sub.b) through Projection → (- t1.b) through SubqueryAlias + #[test] + fn resolve_through_projection_and_subquery_alias() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1) + .alias("sub")? + .project(vec![ + col("sub.a"), + (Expr::Negative(Box::new(col("sub.b")))).alias("neg_b"), + ])? + .build()?; + + // Peel: Projection then SubqueryAlias + let LogicalPlan::Projection(proj) = &plan else { + panic!("expected Projection"); + }; + let LogicalPlan::SubqueryAlias(sq) = proj.input.as_ref() else { + panic!("expected SubqueryAlias"); + }; + + let sort_exprs = vec![col("neg_b").sort(true, false)]; + + // Resolve through Projection: neg_b → (- sub.b) + let after_proj = resolve_sort_exprs_through_projection(&sort_exprs, proj)?; + assert_eq!(after_proj[0].expr.to_string(), "(- sub.b)"); + + // Resolve through SubqueryAlias: (- sub.b) → (- t1.b) + let after_sq = resolve_sort_exprs_through_subquery_alias(&after_proj, sq)?; + assert_eq!(after_sq[0].expr.to_string(), "(- t1.b)"); + assert!(after_sq[0].asc); + assert!(!after_sq[0].nulls_first); + + Ok(()) + } + + // --------------------------------------------------------------- + // Unit tests for resolve_sort_exprs_through_subquery_alias + // --------------------------------------------------------------- + + /// Simple column rename: sub.b → t1.b + #[test] + fn resolve_through_subquery_alias_simple() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1).alias("sub")?.build()?; + let LogicalPlan::SubqueryAlias(sq) = &plan else { + panic!("expected SubqueryAlias"); + }; + + let sort_exprs = vec![col("sub.b").sort(true, false)]; + let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; + + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].expr.to_string(), "t1.b"); + assert!(resolved[0].asc); + assert!(!resolved[0].nulls_first); + Ok(()) + } + + /// Multiple sort columns: sub.a ASC, sub.b DESC → t1.a ASC, t1.b DESC + #[test] + fn resolve_through_subquery_alias_multi_column() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1).alias("sub")?.build()?; + let LogicalPlan::SubqueryAlias(sq) = &plan else { + panic!("expected SubqueryAlias"); + }; + + let sort_exprs = vec![ + col("sub.a").sort(true, false), + col("sub.b").sort(false, true), + ]; + let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; + + assert_eq!(resolved.len(), 2); + assert_eq!(resolved[0].expr.to_string(), "t1.a"); + assert!(resolved[0].asc); + assert_eq!(resolved[1].expr.to_string(), "t1.b"); + assert!(!resolved[1].asc); + assert!(resolved[1].nulls_first); + Ok(()) + } + + /// Alias name differs from table name: foo.b → t1.b + #[test] + fn resolve_through_subquery_alias_different_name() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1).alias("foo")?.build()?; + let LogicalPlan::SubqueryAlias(sq) = &plan else { + panic!("expected SubqueryAlias"); + }; + + let sort_exprs = vec![col("foo.b").sort(true, false)]; + let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; + + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].expr.to_string(), "t1.b"); + Ok(()) + } + + /// Nested expression: (- sub.b) ASC → (- t1.b) ASC + #[test] + fn resolve_through_subquery_alias_nested_expr() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1).alias("sub")?.build()?; + let LogicalPlan::SubqueryAlias(sq) = &plan else { + panic!("expected SubqueryAlias"); + }; + + let sort_exprs = vec![SortExpr { + expr: Expr::Negative(Box::new(col("sub.b"))), + asc: true, + nulls_first: false, + }]; + let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; + + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].expr.to_string(), "(- t1.b)"); + Ok(()) + } }