Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@ async def validate_node_data(
for cte in query_ast.ctes:
local_aliases.add(cte.alias_or_name.identifier(False))

# Lambda parameters (e.g. `c` in `c -> c.name = ...`) are also valid namespaces
# inside their lambda body and must be excluded from INVALID_COLUMN checks.
for lambda_expr in query_ast.find_all(ast.Lambda):
for ident in lambda_expr.identifiers:
local_aliases.add(ident.name)

(
dependencies_map,
missing_parents_map,
Expand Down
9 changes: 8 additions & 1 deletion datajunction-server/datajunction_server/sql/parsing/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2629,7 +2629,7 @@ class Lambda(Expression):
Represents a lambda expression
"""

identifiers: List[Named]
identifiers: List[Name]
expr: Expression

def __str__(self) -> str:
Expand Down Expand Up @@ -3151,6 +3151,13 @@ async def compile(self, ctx: CompileContext):
if self._is_compiled:
return

# A Query whose select is an InlineTable arises from (VALUES ...) AS alias(cols).
# Its columns are already set on the InlineTable; just expose them and return.
if isinstance(self.select, InlineTable):
self._columns = list(self.select._columns)
self._is_compiled = True
return

def _compile(info: Tuple[Column, List[TableExpression]]):
"""
Given a list of table sources, find a matching origin table for the column.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1045,11 +1045,17 @@ def _(ctx: sbp.AliasedQueryContext) -> ast.Select:
query = visit(ctx.query())
query.parenthesized = True
table_alias = ctx.tableAlias()
ident, _ = visit(table_alias)
ident, col_aliases = visit(table_alias)
if ident:
query = query.set_alias(ident)
if table_alias.AS():
query = query.set_as(True)
# Apply explicit column name aliases to InlineTable columns when present.
# For example: (VALUES (1, 2)) AS v(a, b) → rename col1/col2 → a/b so that
# outer references like v.a resolve correctly.
if col_aliases and isinstance(query.select, ast.InlineTable):
for col_obj, col_name in zip(query.select._columns, col_aliases):
col_obj.name = col_name
return query


Expand Down
246 changes: 243 additions & 3 deletions datajunction-server/tests/internal/node_validation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from datajunction_server.database.node import Node, NodeRevision, NodeType
from datajunction_server.database.user import OAuthProvider, User
from datajunction_server.internal.validation import (
validate_node_data,
_reparse_parent_column_types,
validate_node_data,
)
from datajunction_server.models.node import NodeRevisionBase, NodeStatus
from datajunction_server.sql.parsing import types as ct
Expand Down Expand Up @@ -472,6 +472,246 @@ async def test_metric_referencing_dimension_attr_is_valid(
)


@pytest.mark.asyncio
async def test_metric_with_lambda_parameters_is_valid(
session: AsyncSession,
user: User,
):
"""
Regression test: a transform node using higher-order functions (FILTER, AGGREGATE) with
lambda expressions should remain VALID. Lambda parameters (e.g. ``c`` in ``c -> c.name``)
are valid namespaces inside the lambda body and must not be flagged as INVALID_COLUMN.

Previously, the INVALID_COLUMN surfacing fix in PR #1961 incorrectly treated lambda
parameters as unresolved table aliases, causing nodes using FILTER/AGGREGATE/TRANSFORM
with struct field access to be rejected.
"""
from datajunction_server.sql.parsing.ast import Name

_make_source(
session,
user,
"test.lambda_source",
[
Column(
name="items",
type=ct.ListType(
element_type=ct.StructType(
ct.NestedField(name=Name("key"), field_type=ct.StringType()),
ct.NestedField(name=Name("val"), field_type=ct.DoubleType()),
),
),
order=0,
),
],
)
await session.commit()

# Lambda parameters `x` and `acc` must not be treated as unresolved table aliases
data = NodeRevisionBase(
name="test.transform_with_lambda",
display_name="Transform with lambda",
type=NodeType.TRANSFORM,
query=(
"SELECT 4.0 * AGGREGATE("
" FILTER(items, x -> x.key = 'FOO'),"
" CAST(0.0 AS DOUBLE),"
" (acc, x) -> CAST(acc + x.val AS DOUBLE)"
") AS result "
"FROM test.lambda_source"
),
mode="published",
)

validator = await validate_node_data(data, session)

from datajunction_server.errors import ErrorCode

invalid_col_errors = [
e for e in validator.errors if e.code == ErrorCode.INVALID_COLUMN
]
assert not invalid_col_errors, (
f"Unexpected INVALID_COLUMN errors from lambda params: {invalid_col_errors}"
)


def _make_source(session, user, name, columns):
"""Helper: create and persist a SOURCE node with the given columns."""
node = Node(
name=name,
type=NodeType.SOURCE,
created_by_id=user.id,
current_version="v1.0",
)
revision = NodeRevision(
name=name,
display_name=name,
type=NodeType.SOURCE,
query=None,
status=NodeStatus.VALID,
version="v1.0",
node=node,
columns=columns,
created_by_id=user.id,
)
session.add(node)
session.add(revision)
return node


@pytest.mark.asyncio
async def test_lateral_view_alias_not_flagged_as_invalid_column(
session: AsyncSession,
user: User,
):
"""
Regression: LATERAL VIEW generates a virtual table alias (e.g. ``t`` in
``LATERAL VIEW explode(arr) t AS elem``) that is referenced as ``t.elem``
in the SELECT. That alias is not an ast.Table or ast.Query node, so it was
not captured in local_aliases before the fix — causing a false INVALID_COLUMN.
"""
from datajunction_server.errors import ErrorCode

_make_source(
session,
user,
"test.lateral_source",
[Column(name="arr", type=ct.ListType(element_type=ct.StringType()), order=0)],
)
await session.commit()

data = NodeRevisionBase(
name="test.lateral_transform",
display_name="Lateral view transform",
type=NodeType.TRANSFORM,
query=(
"SELECT t.elem FROM test.lateral_source LATERAL VIEW explode(arr) t AS elem"
),
mode="published",
)
validator = await validate_node_data(data, session)
invalid_col_errors = [
e for e in validator.errors if e.code == ErrorCode.INVALID_COLUMN
]
assert not invalid_col_errors, (
f"False INVALID_COLUMN from LATERAL VIEW alias: {invalid_col_errors}"
)


@pytest.mark.asyncio
async def test_unnest_alias_not_flagged_as_invalid_column(
session: AsyncSession,
user: User,
):
"""
Regression: UNNEST with a column alias (``UNNEST(arr) AS t(elem)``) produces a
virtual table ``t`` referenced as ``t.elem``. Like LATERAL VIEW, that alias may
not be captured in local_aliases — causing a false INVALID_COLUMN.
"""
from datajunction_server.errors import ErrorCode

_make_source(
session,
user,
"test.unnest_source",
[Column(name="arr", type=ct.ListType(element_type=ct.StringType()), order=0)],
)
await session.commit()

data = NodeRevisionBase(
name="test.unnest_transform",
display_name="Unnest transform",
type=NodeType.TRANSFORM,
query=(
"SELECT t.elem FROM test.unnest_source CROSS JOIN UNNEST(arr) AS t(elem)"
),
mode="published",
)
validator = await validate_node_data(data, session)
invalid_col_errors = [
e for e in validator.errors if e.code == ErrorCode.INVALID_COLUMN
]
assert not invalid_col_errors, (
f"False INVALID_COLUMN from UNNEST alias: {invalid_col_errors}"
)


@pytest.mark.asyncio
async def test_values_clause_alias_not_flagged_as_invalid_column(
session: AsyncSession,
user: User,
):
"""
A VALUES clause with explicit column aliases ``(VALUES ...) AS v(a, b)`` should
not produce false INVALID_COLUMN errors for ``v.a`` / ``v.b``.
"""
from datajunction_server.errors import ErrorCode

data = NodeRevisionBase(
name="test.values_transform",
display_name="Values transform",
type=NodeType.TRANSFORM,
query="SELECT v.a, v.b FROM (VALUES (1, 2)) AS v(a, b)",
mode="published",
)
validator = await validate_node_data(data, session)
invalid_col_errors = [
e for e in validator.errors if e.code == ErrorCode.INVALID_COLUMN
]
assert not invalid_col_errors, (
f"False INVALID_COLUMN from VALUES alias: {invalid_col_errors}"
)


@pytest.mark.asyncio
async def test_correlated_subquery_outer_alias_not_flagged_as_invalid_column(
session: AsyncSession,
user: User,
):
"""
A correlated subquery references an alias (``o``) from the outer query scope.
``find_all(ast.Table)`` traverses into subqueries, so ``o`` should be in
local_aliases and not produce a false INVALID_COLUMN.
This test confirms that assumption holds.
"""
from datajunction_server.errors import ErrorCode

_make_source(
session,
user,
"test.outer_table",
[Column(name="id", type=ct.BigIntType(), order=0)],
)
_make_source(
session,
user,
"test.inner_table",
[Column(name="id", type=ct.BigIntType(), order=0)],
)
await session.commit()

data = NodeRevisionBase(
name="test.correlated_transform",
display_name="Correlated subquery transform",
type=NodeType.TRANSFORM,
query=(
"SELECT o.id "
"FROM test.outer_table o "
"WHERE EXISTS ("
" SELECT 1 FROM test.inner_table WHERE test.inner_table.id = o.id"
")"
),
mode="published",
)
validator = await validate_node_data(data, session)
invalid_col_errors = [
e for e in validator.errors if e.code == ErrorCode.INVALID_COLUMN
]
assert not invalid_col_errors, (
f"False INVALID_COLUMN from correlated subquery outer alias: {invalid_col_errors}"
)


class TestReparseParentColumnTypes:
"""Tests for _reparse_parent_column_types."""

Expand All @@ -491,7 +731,7 @@ def test_string_type_is_parsed(self):
assert isinstance(col.type, ct.IntegerType)

def test_unparseable_string_is_left_unchanged(self):
"""If parsing fails, the original string value is preserved (except path covered)."""
"""If parsing fails, the original string value is preserved."""
col = Column(name="bad", type="NOT_A_VALID_TYPE_$$$$", order=0)
revision = self._make_revision("test.node", [col])
_reparse_parent_column_types({revision: None})
Expand All @@ -507,4 +747,4 @@ def test_already_parsed_type_is_skipped(self):

def test_empty_map_is_noop(self):
"""An empty dependencies_map doesn't raise."""
_reparse_parent_column_types({}) # Should not raise
_reparse_parent_column_types({})
10 changes: 10 additions & 0 deletions datajunction-server/tests/sql/parsing/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,6 +1374,16 @@ def test_struct_column_name_deep_namespace():
)


def test_values_clause_explicit_column_aliases():
"""(VALUES (1, 2)) AS v(a, b) — inner Query._columns should use 'a'/'b', not 'col1'/'col2'."""
query_ast = parse("SELECT v.a, v.b FROM (VALUES (1, 2)) AS v(a, b)")
all_queries = list(query_ast.find_all(ast.Query))
inner_q = [q for q in all_queries if q.alias is not None][0]
assert inner_q.alias.name == "v"
col_names = [col.name.name for col in inner_q.select._columns]
assert col_names == ["a", "b"], f"Expected ['a', 'b'], got {col_names}"


def test_struct_column_name_two_level():
"""
The original 2-level struct case must continue to work after the deep-namespace fix.
Expand Down
Loading