From 11962e66899a2913dd59931a3f6ca72ea68204d2 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Wed, 8 Apr 2026 04:08:24 -0700 Subject: [PATCH 1/2] Actually validate SQL when propagating validity --- .../datajunction_server/internal/impact.py | 291 ++++++++++++- .../datajunction_server/models/impact.py | 1 + .../tests/internal/impact_test.py | 412 ++++++++++++++++++ 3 files changed, 693 insertions(+), 11 deletions(-) diff --git a/datajunction-server/datajunction_server/internal/impact.py b/datajunction-server/datajunction_server/internal/impact.py index c3e702fc8..564c2e634 100644 --- a/datajunction-server/datajunction_server/internal/impact.py +++ b/datajunction-server/datajunction_server/internal/impact.py @@ -2,18 +2,22 @@ Downstream impact propagation for deployments. """ +import asyncio import logging import time +from collections import defaultdict from sqlalchemy import select from sqlalchemy.sql.operators import is_ from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import joinedload, selectinload from datajunction_server.database.node import Node, NodeRevision, NodeRelationship from datajunction_server.instrumentation.provider import get_metrics_provider +from datajunction_server.internal.validation import validate_node_data from datajunction_server.models.impact import DownstreamImpact, ImpactType from datajunction_server.models.node import NodeStatus +from datajunction_server.models.node_type import NodeType logger = logging.getLogger(__name__) @@ -24,13 +28,17 @@ async def propagate_impact( changed_node_names: set[str], deleted_node_names: frozenset[str] = frozenset(), ) -> list[DownstreamImpact]: - """BFS downstream impact analysis with INVALID propagation. + """BFS downstream impact analysis with INVALID propagation and validity recovery. Must be called inside the caller's active transaction (inside a SAVEPOINT for dry-runs). For dry-runs the caller rolls back the SAVEPOINT, undoing both the - node changes and any INVALID markers written here. For wet-runs the caller + node changes and any status changes written here. For wet-runs the caller commits, persisting everything. + Validity recovery: If a downstream node was INVALID and none of its changed + parents are invalidating (INVALID or deleted), we check if ALL of its parents + are now VALID. If so, the node is marked VALID (recovered). + Args: session: Active async session (inside a SAVEPOINT for dry-runs). namespace: Deployment namespace used to flag external impacts. @@ -78,6 +86,13 @@ async def propagate_impact( if n.current and n.current.status == NodeStatus.INVALID } | set(deleted_by_id) + # Track root nodes that became VALID (for validity recovery) + validating_root_ids: set[int] = { + nid + for nid, n in changed_by_id.items() + if n.current and n.current.status == NodeStatus.VALID + } + # Causality tracking: node_id → set of root node IDs responsible cause_map: dict[int, set[int]] = {nid: {nid} for nid in all_root_ids} root_id_to_name: dict[int, str] = { @@ -90,6 +105,13 @@ async def propagate_impact( results: list[DownstreamImpact] = [] depth = 1 + # Track recovery candidates: (node_id, node, depth, result_index) + # These are INVALID nodes with no invalidating parents in the traversal + recovery_candidates: list[tuple[int, Node, int, int]] = [] + + # Track all visited nodes by ID for later parent status lookup + visited_nodes_by_id: dict[int, Node] = {**changed_by_id, **deleted_by_id} + while frontier_ids: # Each row: (child_node_id, parent_node_id) for all frontier parents rows = ( @@ -126,6 +148,7 @@ async def propagate_impact( next_frontier: set[int] = set() for node in child_nodes: visited_node_ids.add(node.id) + visited_nodes_by_id[node.id] = node parent_ids = child_to_parents.get(node.id, set()) # Propagate causality: union of causes from all triggering parents @@ -135,16 +158,32 @@ async def propagate_impact( cause_map[node.id] = node_causes will_invalidate = bool(parent_ids & invalidating_ids) - impact_type = ( - ImpactType.WILL_INVALIDATE if will_invalidate else ImpactType.MAY_AFFECT - ) current_status = node.current.status if node.current else NodeStatus.INVALID - predicted_status = NodeStatus.INVALID if will_invalidate else current_status - if will_invalidate and current_status != NodeStatus.INVALID: - node.current.status = NodeStatus.INVALID - session.add(node.current) + # Check if this is a recovery candidate: + # - Currently INVALID + # - No invalidating parents in this traversal + # - Has at least one validating root in its causes + is_recovery_candidate = ( + not will_invalidate + and current_status == NodeStatus.INVALID + and bool(node_causes & validating_root_ids) + ) + + if will_invalidate: + impact_type = ImpactType.WILL_INVALIDATE + predicted_status = NodeStatus.INVALID + if current_status != NodeStatus.INVALID: + node.current.status = NodeStatus.INVALID + session.add(node.current) invalidating_ids.add(node.id) + elif is_recovery_candidate: + # Tentatively mark as MAY_AFFECT; will update after batch parent check + impact_type = ImpactType.MAY_AFFECT + predicted_status = current_status + else: + impact_type = ImpactType.MAY_AFFECT + predicted_status = current_status cause_names = sorted( root_id_to_name[cid] for cid in node_causes if cid in root_id_to_name @@ -154,6 +193,8 @@ async def propagate_impact( if will_invalidate else f"Upstream node(s) changed: {', '.join(cause_names)}" ) + + result_index = len(results) results.append( DownstreamImpact( name=node.name, @@ -167,19 +208,35 @@ async def propagate_impact( is_external=not node.name.startswith(namespace + "."), ), ) + + if is_recovery_candidate: + recovery_candidates.append((node.id, node, depth, result_index)) + next_frontier.add(node.id) frontier_ids = next_frontier depth += 1 + # Phase 2: Validity recovery via batched parent check + if recovery_candidates: + recovered_count = await _process_validity_recovery( + session=session, + recovery_candidates=recovery_candidates, + visited_nodes_by_id=visited_nodes_by_id, + results=results, + ) + else: + recovered_count = 0 + elapsed_ms = (time.perf_counter() - start) * 1000 will_invalidate_count = sum( 1 for r in results if r.impact_type == ImpactType.WILL_INVALIDATE ) logger.info( - "Impact analysis: %d downstream nodes (%d will_invalidate)", + "Impact analysis: %d downstream nodes (%d will_invalidate, %d will_recover)", len(results), will_invalidate_count, + recovered_count, ) get_metrics_provider().timer( "dj.deployment.propagate_impact_ms", @@ -194,3 +251,215 @@ async def propagate_impact( will_invalidate_count, ) return results + + +async def _process_validity_recovery( + session: AsyncSession, + recovery_candidates: list[tuple[int, Node, int, int]], + visited_nodes_by_id: dict[int, Node], + results: list[DownstreamImpact], +) -> int: + """ + Process validity recovery for candidate nodes. + + For each candidate: + 1. Check if ALL parents are VALID + 2. If so, validate the node's query to confirm it's actually valid + 3. Only mark as recovered if validation passes + + Args: + session: Active async session. + recovery_candidates: List of (node_id, node, depth, result_index). + visited_nodes_by_id: Map of node_id → Node for nodes in the traversal. + results: The results list to update in place. + + Returns: + Number of nodes that were recovered. + """ + candidate_ids = [c[0] for c in recovery_candidates] + + # Batch query: get all parent node IDs for each candidate + parent_query = ( + select( + NodeRevision.node_id.label("child_node_id"), + NodeRelationship.parent_id.label("parent_node_id"), + ) + .select_from(NodeRelationship) + .join(NodeRevision, NodeRelationship.child_id == NodeRevision.id) + .where(NodeRevision.node_id.in_(candidate_ids)) + ) + parent_rows = (await session.execute(parent_query)).all() + + # Group by child: child_node_id → set of parent_node_ids + parents_by_child: dict[int, set[int]] = defaultdict(set) + all_parent_ids: set[int] = set() + for child_node_id, parent_node_id in parent_rows: + parents_by_child[child_node_id].add(parent_node_id) + all_parent_ids.add(parent_node_id) + + # Load parent nodes that aren't already in visited_nodes_by_id + missing_parent_ids = all_parent_ids - set(visited_nodes_by_id.keys()) + if missing_parent_ids: + missing_parents = ( + ( + await session.execute( + select(Node) + .where(Node.id.in_(missing_parent_ids)) + .options(joinedload(Node.current)), + ) + ) + .unique() + .scalars() + .all() + ) + for p in missing_parents: + visited_nodes_by_id[p.id] = p + + # Sort candidates by depth (ascending) so cascading recovery works + recovery_candidates_sorted = sorted(recovery_candidates, key=lambda c: c[2]) + + # Track which nodes we've recovered (they become VALID for later checks) + recovered_ids: set[int] = set() + recovered_count = 0 + + # Process candidates level by level for cascading recovery + # Group by depth so we can validate each level in parallel + candidates_by_depth: dict[int, list[tuple[int, Node, int, int]]] = defaultdict(list) + for candidate in recovery_candidates_sorted: + _, _, depth, _ = candidate + candidates_by_depth[depth].append(candidate) + + for depth in sorted(candidates_by_depth.keys()): + level_candidates = candidates_by_depth[depth] + + # First pass: filter to candidates with all parents VALID + valid_parent_candidates: list[tuple[int, Node, int, int]] = [] + for node_id, node, d, result_index in level_candidates: + parent_ids = parents_by_child.get(node_id, set()) + if not parent_ids: + continue + + all_parents_valid = True + for pid in parent_ids: + parent_node = visited_nodes_by_id.get(pid) + if parent_node is None: + all_parents_valid = False + break + parent_status = ( + parent_node.current.status + if parent_node.current + else NodeStatus.INVALID + ) + if pid in recovered_ids: + parent_status = NodeStatus.VALID + if parent_status != NodeStatus.VALID: + all_parents_valid = False + break + + if all_parents_valid: + valid_parent_candidates.append((node_id, node, d, result_index)) + + if not valid_parent_candidates: + continue + + # Second pass: validate candidates in parallel + # Source and cube nodes don't need SQL validation - they just need valid parents + to_validate: list[tuple[int, Node, int, int]] = [] + auto_recover: list[tuple[int, Node, int, int]] = [] + + for candidate in valid_parent_candidates: + _, node, _, _ = candidate + if node.type == NodeType.SOURCE: + # Source nodes have no SQL to validate + auto_recover.append(candidate) + elif node.type == NodeType.CUBE: + # Cubes are VALID if all their metric/dimension elements are VALID + # Since we've already verified all parents are VALID, auto-recover + auto_recover.append(candidate) + else: + to_validate.append(candidate) + + # Auto-recover source nodes (they don't have queries to validate) + for node_id, node, d, result_index in auto_recover: + node.current.status = NodeStatus.VALID + session.add(node.current) + recovered_ids.add(node_id) + recovered_count += 1 + _update_result_to_recovered(results, result_index) + + # Validate query nodes in parallel + if to_validate: + # Ensure nodes have their revisions loaded with necessary relationships + node_ids_to_load = [c[0] for c in to_validate] + loaded_nodes = ( + ( + await session.execute( + select(Node) + .where(Node.id.in_(node_ids_to_load)) + .options( + selectinload(Node.current).options( + selectinload(NodeRevision.columns), + selectinload(NodeRevision.parents), + ), + ), + ) + ) + .unique() + .scalars() + .all() + ) + loaded_by_id = {n.id: n for n in loaded_nodes} + + # Validate in parallel + validation_tasks = [ + validate_node_data(loaded_by_id[node_id].current, session) + for node_id, _, _, _ in to_validate + ] + validation_results = await asyncio.gather( + *validation_tasks, + return_exceptions=True, + ) + + # Process validation results + for (node_id, node, d, result_index), validation_result in zip( + to_validate, + validation_results, + ): + if isinstance(validation_result, Exception): + logger.warning( + "Validation failed for recovery candidate %s: %s", + node.name, + validation_result, + ) + continue + + if validation_result.status == NodeStatus.VALID: + node.current.status = NodeStatus.VALID + session.add(node.current) + recovered_ids.add(node_id) + recovered_count += 1 + _update_result_to_recovered(results, result_index) + else: + logger.info( + "Recovery candidate %s failed validation: %s", + node.name, + [str(e) for e in validation_result.errors], + ) + + return recovered_count + + +def _update_result_to_recovered(results: list[DownstreamImpact], result_index: int): + """Update a result entry to reflect successful recovery.""" + old_result = results[result_index] + results[result_index] = DownstreamImpact( + name=old_result.name, + node_type=old_result.node_type, + current_status=old_result.current_status, + predicted_status=NodeStatus.VALID, + impact_type=ImpactType.WILL_RECOVER, + impact_reason=f"Validated and recovered - upstream nodes now valid: {', '.join(old_result.caused_by)}", + depth=old_result.depth, + caused_by=old_result.caused_by, + is_external=old_result.is_external, + ) diff --git a/datajunction-server/datajunction_server/models/impact.py b/datajunction-server/datajunction_server/models/impact.py index 367e747de..1e38e8b99 100644 --- a/datajunction-server/datajunction_server/models/impact.py +++ b/datajunction-server/datajunction_server/models/impact.py @@ -13,6 +13,7 @@ class ImpactType(str, Enum): """Type of impact on a downstream node""" WILL_INVALIDATE = "will_invalidate" # Certain to break + WILL_RECOVER = "will_recover" # Was INVALID, will become VALID MAY_AFFECT = "may_affect" # Might need revalidation UNCHANGED = "unchanged" # No predicted impact diff --git a/datajunction-server/tests/internal/impact_test.py b/datajunction-server/tests/internal/impact_test.py index 005549a46..8a91548f6 100644 --- a/datajunction-server/tests/internal/impact_test.py +++ b/datajunction-server/tests/internal/impact_test.py @@ -367,3 +367,415 @@ async def test_propagate_impact_cause_names_in_result(session, current_user: Use # Both should trace back to ns.source assert by_name["ns.transform"].caused_by == ["ns.source"] assert by_name["ns.metric"].caused_by == ["ns.source"] + + +# --------------------------------------------------------------------------- +# Validity Recovery Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_propagate_impact_validity_recovery_basic(session, current_user: User): + """INVALID child with VALID parent → child recovers to VALID.""" + session.add(NodeNamespace(namespace="ns")) + + # Parent is now VALID (simulating a fix in deployment) + parent, parent_rev = _make_node( + "ns.source", + NodeType.SOURCE, + NodeStatus.VALID, + current_user.id, + ) + # Child was INVALID (presumably because parent was previously INVALID) + child, child_rev = _make_node( + "ns.transform", + NodeType.TRANSFORM, + NodeStatus.INVALID, + current_user.id, + ) + await _persist(session, parent, parent_rev, child, child_rev) + await _persist(session, _link(parent, child_rev)) + + result = await propagate_impact(session, "ns", {"ns.source"}) + + assert len(result) == 1 + impact = result[0] + assert impact.name == "ns.transform" + assert impact.impact_type == ImpactType.WILL_RECOVER + assert impact.current_status == NodeStatus.INVALID + assert impact.predicted_status == NodeStatus.VALID + # The ORM object should be mutated to VALID + assert child_rev.status == NodeStatus.VALID + + +@pytest.mark.asyncio +async def test_propagate_impact_validity_recovery_cascading( + session, + current_user: User, +): + """INVALID chain recovers when root becomes VALID: A→B→C all recover.""" + session.add(NodeNamespace(namespace="ns")) + + # Root is now VALID + root, root_rev = _make_node( + "ns.source", + NodeType.SOURCE, + NodeStatus.VALID, + current_user.id, + ) + # Children were INVALID + child1, child1_rev = _make_node( + "ns.transform", + NodeType.TRANSFORM, + NodeStatus.INVALID, + current_user.id, + ) + child2, child2_rev = _make_node( + "ns.metric", + NodeType.METRIC, + NodeStatus.INVALID, + current_user.id, + ) + await _persist(session, root, root_rev, child1, child1_rev, child2, child2_rev) + await _persist(session, _link(root, child1_rev), _link(child1, child2_rev)) + + result = await propagate_impact(session, "ns", {"ns.source"}) + + assert len(result) == 2 + by_name = {r.name: r for r in result} + + # Both should recover + assert by_name["ns.transform"].impact_type == ImpactType.WILL_RECOVER + assert by_name["ns.transform"].predicted_status == NodeStatus.VALID + assert by_name["ns.metric"].impact_type == ImpactType.WILL_RECOVER + assert by_name["ns.metric"].predicted_status == NodeStatus.VALID + + # ORM objects mutated + assert child1_rev.status == NodeStatus.VALID + assert child2_rev.status == NodeStatus.VALID + + +@pytest.mark.asyncio +async def test_propagate_impact_no_recovery_if_other_parent_invalid( + session, + current_user: User, +): + """INVALID child with one VALID parent but another INVALID parent → no recovery.""" + session.add(NodeNamespace(namespace="ns")) + + # parent1 is VALID (changed) + parent1, parent1_rev = _make_node( + "ns.source1", + NodeType.SOURCE, + NodeStatus.VALID, + current_user.id, + ) + # parent2 is INVALID (not changed, external) + parent2, parent2_rev = _make_node( + "ns.source2", + NodeType.SOURCE, + NodeStatus.INVALID, + current_user.id, + ) + # Child is INVALID + child, child_rev = _make_node( + "ns.transform", + NodeType.TRANSFORM, + NodeStatus.INVALID, + current_user.id, + ) + await _persist( + session, + parent1, + parent1_rev, + parent2, + parent2_rev, + child, + child_rev, + ) + await _persist(session, _link(parent1, child_rev), _link(parent2, child_rev)) + + # Only parent1 is in the changed set + result = await propagate_impact(session, "ns", {"ns.source1"}) + + assert len(result) == 1 + impact = result[0] + assert impact.name == "ns.transform" + # Should NOT recover because parent2 is still INVALID + assert impact.impact_type == ImpactType.MAY_AFFECT + assert impact.current_status == NodeStatus.INVALID + assert impact.predicted_status == NodeStatus.INVALID + # ORM object should NOT be changed + assert child_rev.status == NodeStatus.INVALID + + +@pytest.mark.asyncio +async def test_propagate_impact_no_recovery_if_parent_deleted( + session, + current_user: User, +): + """INVALID child whose parent is being deleted → invalidates, doesn't recover.""" + session.add(NodeNamespace(namespace="ns")) + + parent, parent_rev = _make_node( + "ns.source", + NodeType.SOURCE, + NodeStatus.VALID, + current_user.id, + ) + child, child_rev = _make_node( + "ns.transform", + NodeType.TRANSFORM, + NodeStatus.INVALID, + current_user.id, + ) + await _persist(session, parent, parent_rev, child, child_rev) + await _persist(session, _link(parent, child_rev)) + + # Parent is deleted, not changed + result = await propagate_impact( + session, + "ns", + set(), + deleted_node_names=frozenset(["ns.source"]), + ) + + assert len(result) == 1 + impact = result[0] + # Deletion invalidates, doesn't recover + assert impact.impact_type == ImpactType.WILL_INVALIDATE + assert impact.predicted_status == NodeStatus.INVALID + + +@pytest.mark.asyncio +async def test_propagate_impact_mixed_invalidation_and_recovery( + session, + current_user: User, +): + """Multiple changes: one fixes a node (recovery), another breaks a node (invalidation).""" + session.add(NodeNamespace(namespace="ns")) + + # source1 is now VALID (was fixed) + source1, source1_rev = _make_node( + "ns.source1", + NodeType.SOURCE, + NodeStatus.VALID, + current_user.id, + ) + # source2 is now INVALID (was broken) + source2, source2_rev = _make_node( + "ns.source2", + NodeType.SOURCE, + NodeStatus.INVALID, + current_user.id, + ) + # child1 depends on source1, was INVALID → should recover + child1, child1_rev = _make_node( + "ns.child1", + NodeType.TRANSFORM, + NodeStatus.INVALID, + current_user.id, + ) + # child2 depends on source2, was VALID → should invalidate + child2, child2_rev = _make_node( + "ns.child2", + NodeType.TRANSFORM, + NodeStatus.VALID, + current_user.id, + ) + await _persist( + session, + source1, + source1_rev, + source2, + source2_rev, + child1, + child1_rev, + child2, + child2_rev, + ) + await _persist(session, _link(source1, child1_rev), _link(source2, child2_rev)) + + result = await propagate_impact(session, "ns", {"ns.source1", "ns.source2"}) + + assert len(result) == 2 + by_name = {r.name: r for r in result} + + # child1 should recover + assert by_name["ns.child1"].impact_type == ImpactType.WILL_RECOVER + assert by_name["ns.child1"].predicted_status == NodeStatus.VALID + assert child1_rev.status == NodeStatus.VALID + + # child2 should invalidate + assert by_name["ns.child2"].impact_type == ImpactType.WILL_INVALIDATE + assert by_name["ns.child2"].predicted_status == NodeStatus.INVALID + assert child2_rev.status == NodeStatus.INVALID + + +@pytest.mark.asyncio +async def test_propagate_impact_valid_child_stays_valid(session, current_user: User): + """VALID child with VALID parent → stays VALID (MAY_AFFECT, not recovery).""" + session.add(NodeNamespace(namespace="ns")) + + parent, parent_rev = _make_node( + "ns.source", + NodeType.SOURCE, + NodeStatus.VALID, + current_user.id, + ) + child, child_rev = _make_node( + "ns.transform", + NodeType.TRANSFORM, + NodeStatus.VALID, + current_user.id, + ) + await _persist(session, parent, parent_rev, child, child_rev) + await _persist(session, _link(parent, child_rev)) + + result = await propagate_impact(session, "ns", {"ns.source"}) + + assert len(result) == 1 + impact = result[0] + # Not a recovery because it was already VALID + assert impact.impact_type == ImpactType.MAY_AFFECT + assert impact.current_status == NodeStatus.VALID + assert impact.predicted_status == NodeStatus.VALID + assert child_rev.status == NodeStatus.VALID + + +@pytest.mark.asyncio +async def test_propagate_impact_no_recovery_if_node_has_sql_error( + session, + current_user: User, +): + """INVALID child with VALID parent but bad SQL → no recovery (validated).""" + session.add(NodeNamespace(namespace="ns")) + + # Parent is VALID + parent, parent_rev = _make_node( + "ns.source", + NodeType.SOURCE, + NodeStatus.VALID, + current_user.id, + ) + + # Child is INVALID and has broken SQL (syntax error) + child = Node( + name="ns.transform", + type=NodeType.TRANSFORM, + current_version="v1.0", + created_by_id=current_user.id, + namespace="ns", + ) + child_rev = NodeRevision( + name="ns.transform", + type=NodeType.TRANSFORM, + node=child, + version="v1.0", + status=NodeStatus.INVALID, + query="SELECT * FORM broken_syntax", # Intentional typo: FORM instead of FROM + created_by_id=current_user.id, + ) + + await _persist(session, parent, parent_rev, child, child_rev) + await _persist(session, _link(parent, child_rev)) + + result = await propagate_impact(session, "ns", {"ns.source"}) + + assert len(result) == 1 + impact = result[0] + # Should NOT recover because validation fails on the SQL syntax error + assert impact.impact_type == ImpactType.MAY_AFFECT + assert impact.current_status == NodeStatus.INVALID + assert impact.predicted_status == NodeStatus.INVALID + # Status should remain INVALID + assert child_rev.status == NodeStatus.INVALID + + +@pytest.mark.asyncio +async def test_propagate_impact_source_node_recovery(session, current_user: User): + """INVALID source node with VALID parent → auto-recovers (no SQL to validate).""" + session.add(NodeNamespace(namespace="ns")) + + # Parent source is VALID + parent, parent_rev = _make_node( + "ns.source1", + NodeType.SOURCE, + NodeStatus.VALID, + current_user.id, + ) + + # Child is a source node that was INVALID (unusual but possible) + child = Node( + name="ns.source2", + type=NodeType.SOURCE, + current_version="v1.0", + created_by_id=current_user.id, + namespace="ns", + ) + child_rev = NodeRevision( + name="ns.source2", + type=NodeType.SOURCE, + node=child, + version="v1.0", + status=NodeStatus.INVALID, + query=None, # Source nodes don't have queries + created_by_id=current_user.id, + ) + + await _persist(session, parent, parent_rev, child, child_rev) + await _persist(session, _link(parent, child_rev)) + + result = await propagate_impact(session, "ns", {"ns.source1"}) + + assert len(result) == 1 + impact = result[0] + # Source nodes auto-recover without validation + assert impact.impact_type == ImpactType.WILL_RECOVER + assert impact.current_status == NodeStatus.INVALID + assert impact.predicted_status == NodeStatus.VALID + + +@pytest.mark.asyncio +async def test_propagate_impact_cube_node_recovery(session, current_user: User): + """INVALID cube with all VALID metric parents → auto-recovers (no SQL parsing needed).""" + session.add(NodeNamespace(namespace="ns")) + + # Parent metric is VALID + parent, parent_rev = _make_node( + "ns.metric1", + NodeType.METRIC, + NodeStatus.VALID, + current_user.id, + ) + + # Child is a cube node that was INVALID (e.g., because its metric parent was invalid) + child = Node( + name="ns.cube1", + type=NodeType.CUBE, + current_version="v1.0", + created_by_id=current_user.id, + namespace="ns", + ) + child_rev = NodeRevision( + name="ns.cube1", + type=NodeType.CUBE, + node=child, + version="v1.0", + status=NodeStatus.INVALID, + query=None, # Cube nodes don't have queries + created_by_id=current_user.id, + ) + + await _persist(session, parent, parent_rev, child, child_rev) + await _persist(session, _link(parent, child_rev)) + + result = await propagate_impact(session, "ns", {"ns.metric1"}) + + assert len(result) == 1 + impact = result[0] + # Cubes auto-recover when all their metric parents are VALID (no SQL parsing needed) + assert impact.impact_type == ImpactType.WILL_RECOVER + assert impact.current_status == NodeStatus.INVALID + assert impact.predicted_status == NodeStatus.VALID + assert child_rev.status == NodeStatus.VALID From 349c59bc4a30bd3fcdce81d94b19e74b0a942338 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Sat, 11 Apr 2026 16:09:55 -0700 Subject: [PATCH 2/2] Fix tests --- .../tests/internal/impact_test.py | 201 ++++++++++++++++++ 1 file changed, 201 insertions(+) diff --git a/datajunction-server/tests/internal/impact_test.py b/datajunction-server/tests/internal/impact_test.py index 8a91548f6..df5155cf0 100644 --- a/datajunction-server/tests/internal/impact_test.py +++ b/datajunction-server/tests/internal/impact_test.py @@ -2,6 +2,8 @@ Unit tests for datajunction_server.internal.impact.propagate_impact """ +from unittest import mock + import pytest from datajunction_server.database.node import Node, NodeRevision, NodeRelationship @@ -779,3 +781,202 @@ async def test_propagate_impact_cube_node_recovery(session, current_user: User): assert impact.current_status == NodeStatus.INVALID assert impact.predicted_status == NodeStatus.VALID assert child_rev.status == NodeStatus.VALID + + +@pytest.mark.asyncio +async def test_process_validity_recovery_skips_candidate_with_no_parents( + session, + current_user: User, +): + """Directly test _process_validity_recovery: candidate with no parent edges is skipped.""" + from datajunction_server.internal.impact import _process_validity_recovery + from datajunction_server.models.impact import DownstreamImpact + + session.add(NodeNamespace(namespace="ns")) + + # Create a node that will be passed as a recovery candidate + orphan = Node( + name="ns.orphan", + type=NodeType.SOURCE, + current_version="v1.0", + created_by_id=current_user.id, + namespace="ns", + ) + orphan_rev = NodeRevision( + name="ns.orphan", + type=NodeType.SOURCE, + node=orphan, + version="v1.0", + status=NodeStatus.INVALID, + query=None, + created_by_id=current_user.id, + ) + await _persist(session, orphan, orphan_rev) + + # Create a placeholder result that _process_validity_recovery can update + results = [ + DownstreamImpact( + name="ns.orphan", + node_type=NodeType.SOURCE, + current_status=NodeStatus.INVALID, + predicted_status=NodeStatus.INVALID, + impact_type=ImpactType.MAY_AFFECT, + impact_reason="test", + depth=1, + ), + ] + + # Pass orphan as a recovery candidate — it has NO NodeRelationship rows + candidates = [(orphan.id, orphan, 1, 0)] + recovered = await _process_validity_recovery( + session, + candidates, + visited_nodes_by_id={orphan.id: orphan}, + results=results, + ) + + # Should not recover (no parents found → skipped) + assert recovered == 0 + assert orphan_rev.status == NodeStatus.INVALID + + +@pytest.mark.asyncio +async def test_propagate_impact_recovery_skips_candidate_with_missing_parent_node( + session, + current_user: User, +): + """Recovery candidate whose parent Node row can't be loaded is skipped.""" + from datajunction_server.internal.impact import _process_validity_recovery + from datajunction_server.models.impact import DownstreamImpact + + session.add(NodeNamespace(namespace="ns")) + + source, source_rev = _make_node( + "ns.source", + NodeType.SOURCE, + NodeStatus.VALID, + current_user.id, + ) + # Create a second parent that is linked to child + other_parent, other_parent_rev = _make_node( + "ns.other_parent", + NodeType.SOURCE, + NodeStatus.VALID, + current_user.id, + ) + child, child_rev = _make_node( + "ns.transform", + NodeType.TRANSFORM, + NodeStatus.INVALID, + current_user.id, + ) + + await _persist( + session, + source, + source_rev, + other_parent, + other_parent_rev, + child, + child_rev, + ) + await _persist(session, _link(source, child_rev)) + await _persist(session, _link(other_parent, child_rev)) + + results = [ + DownstreamImpact( + name="ns.transform", + node_type=NodeType.TRANSFORM, + current_status=NodeStatus.INVALID, + predicted_status=NodeStatus.INVALID, + impact_type=ImpactType.MAY_AFFECT, + impact_reason="test", + depth=1, + ), + ] + + # Call _process_validity_recovery directly. + # visited_nodes_by_id includes source but NOT other_parent. + # The function will load missing parents (other_parent) from DB, so it + # would normally find it. We mock the load query to return empty so + # other_parent stays missing from visited_nodes_by_id. + original_execute = session.execute + call_count = 0 + + async def _intercept_execute(stmt, *args, **kwargs): + nonlocal call_count + call_count += 1 + # The 2nd query is the missing-parent load (select Node where id in ...) + # Return an empty result to simulate a missing parent + if call_count == 2: + empty_result = mock.MagicMock() + empty_result.unique.return_value = empty_result + empty_result.scalars.return_value = empty_result + empty_result.all.return_value = [] + return empty_result + return await original_execute(stmt, *args, **kwargs) + + # Pre-load source with its .current relationship to avoid lazy loads + from sqlalchemy import select + from sqlalchemy.orm import joinedload as jl + + source_loaded = ( + ( + await session.execute( + select(Node).where(Node.id == source.id).options(jl(Node.current)), + ) + ) + .unique() + .scalar_one() + ) + + candidates = [(child.id, child, 1, 0)] + with mock.patch.object(session, "execute", side_effect=_intercept_execute): + recovered = await _process_validity_recovery( + session, + candidates, + visited_nodes_by_id={child.id: child, source.id: source_loaded}, + results=results, + ) + + # Should not recover because other_parent can't be found + assert recovered == 0 + assert child_rev.status == NodeStatus.INVALID + + +@pytest.mark.asyncio +async def test_propagate_impact_recovery_handles_validation_exception( + session, + current_user: User, +): + """If validate_node_data raises, the candidate is skipped (not recovered).""" + session.add(NodeNamespace(namespace="ns")) + + source, source_rev = _make_node( + "ns.source", + NodeType.SOURCE, + NodeStatus.VALID, + current_user.id, + ) + child, child_rev = _make_node( + "ns.transform", + NodeType.TRANSFORM, + NodeStatus.INVALID, + current_user.id, + ) + + await _persist(session, source, source_rev, child, child_rev) + await _persist(session, _link(source, child_rev)) + + with mock.patch( + "datajunction_server.internal.impact.validate_node_data", + side_effect=RuntimeError("boom"), + ): + result = await propagate_impact(session, "ns", {"ns.source"}) + + # child should NOT recover because validation raised an exception + child_impacts = [r for r in result if r.name == "ns.transform"] + assert len(child_impacts) == 1 + assert child_impacts[0].impact_type != ImpactType.WILL_RECOVER + # Status should remain INVALID + assert child_rev.status == NodeStatus.INVALID