Skip to content
Open
79 changes: 76 additions & 3 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,14 @@ impl Unparser<'_> {
/// Reconstructs a SELECT SQL statement from a logical plan by unprojecting column expressions
/// found in a [Projection] node. This requires scanning the plan tree for relevant Aggregate
/// and Window nodes and matching column expressions to the appropriate agg or window expressions.
///
/// Returns `true` if an Aggregate node was found and claimed for this SELECT.
fn reconstruct_select_statement(
&self,
plan: &LogicalPlan,
p: &Projection,
select: &mut SelectBuilder,
) -> Result<()> {
) -> Result<bool> {
let mut exprs = p.expr.clone();

// If an Unnest node is found within the select, find and unproject the unnest column
Expand Down Expand Up @@ -299,6 +301,7 @@ impl Unparser<'_> {
.collect::<Result<Vec<_>>>()?,
vec![],
));
Ok(true)
}
(None, Some(window)) => {
let items = exprs
Expand All @@ -310,6 +313,7 @@ impl Unparser<'_> {
.collect::<Result<Vec<_>>>()?;

select.projection(items);
Ok(false)
}
_ => {
let items = exprs
Expand All @@ -328,9 +332,9 @@ impl Unparser<'_> {
})
.collect::<Result<Vec<_>>>()?;
select.projection(items);
Ok(false)
}
}
Ok(())
}

fn derive(
Expand Down Expand Up @@ -605,7 +609,76 @@ impl Unparser<'_> {
if self.dialect.unnest_as_lateral_flatten() {
Self::collect_flatten_aliases(p.input.as_ref(), select);
}
self.reconstruct_select_statement(plan, p, select)?;
let found_agg = self.reconstruct_select_statement(plan, p, select)?;

// If the Projection claimed an Aggregate by reaching through
// a Limit or Sort, fold those clauses into the current query
// and skip the node during recursion. Otherwise the Limit/Sort
// arm would see `already_projected` and wrap everything in a
// spurious derived subquery.
if found_agg {
if let LogicalPlan::Limit(limit) = p.input.as_ref() {
if let Some(fetch) = &limit.fetch {
let Some(query) = query.as_mut() else {
return internal_err!(
"Limit operator only valid in a statement context."
);
};
query.limit(Some(self.expr_to_sql(fetch)?));
}
if let Some(skip) = &limit.skip {
let Some(query) = query.as_mut() else {
return internal_err!(
"Offset operator only valid in a statement context."
);
};
query.offset(Some(ast::Offset {
rows: ast::OffsetRows::None,
value: self.expr_to_sql(skip)?,
}));
}
return self.select_to_sql_recursively(
limit.input.as_ref(),
query,
select,
relation,
);
}
if let LogicalPlan::Sort(sort) = p.input.as_ref() {
let Some(query_ref) = query.as_mut() else {
return internal_err!(
"Sort operator only valid in a statement context."
);
};
if let Some(fetch) = sort.fetch {
query_ref.limit(Some(ast::Expr::value(ast::Value::Number(
fetch.to_string(),
false,
))));
}
let agg =
find_agg_node_within_select(plan, select.already_projected());
let sort_exprs: Vec<SortExpr> = sort
.expr
.iter()
.map(|sort_expr| {
unproject_sort_expr(
sort_expr.clone(),
agg,
sort.input.as_ref(),
)
})
.collect::<Result<Vec<_>>>()?;
query_ref.order_by(self.sorts_to_sql(&sort_exprs)?);
return self.select_to_sql_recursively(
sort.input.as_ref(),
query,
select,
relation,
);
}
}

self.select_to_sql_recursively(p.input.as_ref(), query, select, relation)
}
LogicalPlan::Filter(filter) => {
Expand Down
106 changes: 106 additions & 0 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2953,6 +2953,112 @@ fn roundtrip_subquery_aggregate_with_column_alias() -> Result<(), DataFusionErro
Ok(())
}

/// Roundtrip: aggregate over a subquery projection with limit.
#[test]
fn roundtrip_aggregate_over_subquery() -> Result<(), DataFusionError> {
roundtrip_statement_with_dialect_helper!(
sql: r#"SELECT __agg_0 AS "min(j1_id)", __agg_1 AS "max(j1_id)" FROM (SELECT min(j1_rename) AS __agg_0, max(j1_rename) AS __agg_1 FROM (SELECT j1_id AS j1_rename FROM j1) AS bla LIMIT 20)"#,
parser_dialect: GenericDialect {},
unparser_dialect: UnparserDefaultDialect {},
expected: @r#"SELECT __agg_0 AS "min(j1_id)", __agg_1 AS "max(j1_id)" FROM (SELECT min(bla.j1_rename) AS __agg_0, max(bla.j1_rename) AS __agg_1 FROM (SELECT j1.j1_id AS j1_rename FROM j1) AS bla LIMIT 20)"#,
);
Ok(())
}

/// Projection → Limit → Aggregate (aliases inlined into Aggregate, no
/// intermediate Projection). Verifies the Limit is folded into the outer
/// SELECT rather than creating a spurious derived subquery.
#[test]
fn test_unparse_aggregate_over_subquery_no_inner_proj() -> Result<()> {
let context = MockContextProvider {
state: MockSessionState::default(),
};
let j1_schema = context
.get_table_source(TableReference::bare("j1"))?
.schema();

let scan = table_scan(Some("j1"), &j1_schema, None)?.build()?;
let plan = LogicalPlanBuilder::from(scan)
.project(vec![col("j1.j1_id").alias("j1_rename")])?
.alias("bla")?
.aggregate(
vec![] as Vec<Expr>,
vec![
max(col("bla.j1_rename")).alias("__agg_0"),
max(col("bla.j1_rename")).alias("__agg_1"),
],
)?
.limit(0, Some(20))?
.project(vec![
col("__agg_0").alias("max1(j1_id)"),
col("__agg_1").alias("max2(j1_id)"),
])?
.build()?;

let sql = Unparser::default().plan_to_sql(&plan)?.to_string();
insta::assert_snapshot!(sql, @r#"SELECT max(bla.j1_rename) AS "max1(j1_id)", max(bla.j1_rename) AS "max2(j1_id)" FROM (SELECT j1.j1_id AS j1_rename FROM j1) AS bla LIMIT 20"#);
Ok(())
}

/// Projection → Aggregate (aliases inlined, no rename in outer Projection).
/// Verifies the aggregate aliases are preserved as output column names.
#[test]
fn test_unparse_aggregate_no_outer_rename() -> Result<()> {
let context = MockContextProvider {
state: MockSessionState::default(),
};
let j1_schema = context
.get_table_source(TableReference::bare("j1"))?
.schema();

let scan = table_scan(Some("j1"), &j1_schema, None)?.build()?;
let plan = LogicalPlanBuilder::from(scan)
.project(vec![col("j1.j1_id").alias("j1_rename")])?
.alias("bla")?
.aggregate(
vec![] as Vec<Expr>,
vec![
max(col("bla.j1_rename")).alias("__agg_0"),
max(col("bla.j1_rename")).alias("__agg_1"),
],
)?
.project(vec![col("__agg_0"), col("__agg_1")])?
.build()?;

let sql = Unparser::default().plan_to_sql(&plan)?.to_string();
insta::assert_snapshot!(sql, @"SELECT max(bla.j1_rename) AS __agg_0, max(bla.j1_rename) AS __agg_1 FROM (SELECT j1.j1_id AS j1_rename FROM j1) AS bla");
Ok(())
}

/// Projection → Sort → Aggregate (aliases inlined into Aggregate).
/// Verifies the Sort is folded into the outer SELECT rather than creating
/// a spurious derived subquery.
#[test]
fn test_unparse_aggregate_with_sort_no_inner_proj() -> Result<()> {
let context = MockContextProvider {
state: MockSessionState::default(),
};
let j1_schema = context
.get_table_source(TableReference::bare("j1"))?
.schema();

let scan = table_scan(Some("j1"), &j1_schema, None)?.build()?;
let plan = LogicalPlanBuilder::from(scan)
.project(vec![col("j1.j1_id").alias("j1_rename")])?
.alias("bla")?
.aggregate(
vec![] as Vec<Expr>,
vec![max(col("bla.j1_rename")).alias("__agg_0")],
)?
.sort(vec![col("__agg_0").sort(true, true)])?
.project(vec![col("__agg_0").alias("max1(j1_id)")])?
.build()?;

let sql = Unparser::default().plan_to_sql(&plan)?.to_string();
insta::assert_snapshot!(sql, @r#"SELECT max(bla.j1_rename) AS "max1(j1_id)" FROM (SELECT j1.j1_id AS j1_rename FROM j1) AS bla ORDER BY max(bla.j1_rename) ASC NULLS FIRST"#);
Ok(())
}

/// Test that unparsing a manually constructed join with a subquery aggregate
/// preserves the MAX aggregate function.
///
Expand Down
Loading