From e2be07d089ee1788bccec6d780eeb2005087d282 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Wed, 1 Apr 2026 13:01:41 +0100 Subject: [PATCH 1/3] Parameterize Wait option with target atoms for noise-robust termination Wait previously terminated on any atom change, making it sensitive to incidental physics noise. Now Wait can be parameterized with specific target atoms (positive or negative) that must be satisfied before termination. Falls back to any-atom-change when no targets are specified. - Add check_wait_target_atoms, parse_wait_target_annotations, strip_wait_annotations, inject_wait_targets_for_option to utils.py - Update option_model.py termination + memory propagation through re-grounding - Update LLM prompts in agent_planner, agent_bilevel, agent_option_learning approaches to document -> {atoms, NOT atoms} annotation syntax - Parse and inject Wait targets in agent_planner and agent_bilevel approaches - Inject Wait targets from atoms_sequence in process planning paths - Add tests for target atom termination, including noisy-atom-ignored test --- .../approaches/agent_bilevel_approach.py | 85 +++++-- .../agent_option_learning_approach.py | 5 +- .../approaches/agent_planner_approach.py | 79 ++++++- .../approaches/process_planning_approach.py | 3 +- predicators/option_model.py | 46 ++-- predicators/planning_with_processes.py | 15 +- predicators/utils.py | 135 ++++++++++- .../approaches/test_agent_bilevel_approach.py | 209 +++++++++++++++++- 8 files changed, 512 insertions(+), 65 deletions(-) diff --git a/predicators/approaches/agent_bilevel_approach.py b/predicators/approaches/agent_bilevel_approach.py index f2125fdc4..df2882744 100644 --- a/predicators/approaches/agent_bilevel_approach.py +++ b/predicators/approaches/agent_bilevel_approach.py @@ -34,6 +34,8 @@ class _SketchStep: option: ParameterizedOption objects: Sequence[Object] subgoal_atoms: Optional[Set[GroundAtom]] # None = no subgoal constraint + # Atoms that must be FALSE after this step. + subgoal_neg_atoms: Optional[Set[GroundAtom]] = None class AgentBilevelApproach(AgentPlannerApproach): @@ -75,7 +77,11 @@ def _get_agent_system_prompt(self) -> str: " OptionName(obj1:type1, obj2:type2) -> {Pred(obj1:type1), " "Pred2(obj1:type1, obj2:type2)}\n" "Always use typed references (obj:type) in subgoal atoms.\n" - "Subgoal annotations are optional but improve search efficiency.") + "Subgoal annotations are optional but improve search efficiency.\n" + "For Wait steps, the annotation also specifies exactly when the " + "Wait should terminate. Use `NOT Pred(...)` for atoms that should " + "become false (e.g. `Wait(robot:Robot) -> " + "{Boiled(water:water_type)}`).") # ------------------------------------------------------------------ # # Solve prompt (no continuous params, subgoal format) @@ -173,11 +179,16 @@ def _build_solve_prompt(self, task: Task) -> str: helps the search verify progress. Use `-> {{atoms}}` after each step. After any action whose desired subgoal depends on a delayed process (e.g. \ -water filling, dominoes cascading, heating), insert a Wait action. +water filling, dominoes cascading, heating), insert a Wait action. For Wait \ +steps, annotate with the atoms the process should produce — this tells the \ +system exactly when the Wait should end rather than terminating on any \ +incidental atom change. Use `NOT Pred(...)` for atoms that should become false. Output the plan sketch with one option per line in this format: OptionName(obj1:type1, obj2:type2) -> \ {{Pred(obj1:type1), Pred2(obj1:type1, obj2:type2)}} + Wait(robot:Robot) -> {{Boiled(water:water_type)}} + Wait(robot:Robot) -> {{NOT Touching(a:block, b:block)}} Always use typed references (obj:type) in both option arguments AND subgoal \ atoms. The `-> {{atoms}}` part is optional. If you omit it, the search will \ @@ -286,8 +297,15 @@ def _query_agent_for_plan_sketch(self, task: Task) -> List[_SketchStep]: sketch = [] for i, (option, objs, _) in enumerate(parsed): sg = subgoals[i] if i < len(subgoals) else None - sketch.append( - _SketchStep(option=option, objects=objs, subgoal_atoms=sg)) + if sg is not None: + pos, neg = sg + sketch.append(_SketchStep( + option=option, objects=objs, + subgoal_atoms=pos if pos else None, + subgoal_neg_atoms=neg if neg else None)) + else: + sketch.append(_SketchStep( + option=option, objects=objs, subgoal_atoms=None)) logging.info(f"[{self._run_id}] Agent produced sketch with " f"{len(sketch)} steps, " @@ -300,21 +318,22 @@ def _parse_subgoal_annotations( text: str, predicates: Set[Predicate], objects: Sequence[Object], - ) -> List[Optional[Set[GroundAtom]]]: - """Parse ``-> {Pred(obj1, obj2), ...}`` annotations from plan text. + ) -> List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]]: + """Parse ``-> {Pred(...), NOT Pred(...)}`` annotations from plan text. Returns a list parallel to the option lines. Entries are None - for lines without annotations. + for lines without annotations. Each non-None entry is + ``(positive_atoms, negative_atoms)``. """ pred_map = {p.name: p for p in predicates} obj_map = {o.name: o for o in objects} # Regex: match -> { ... } after the option line subgoal_re = re.compile(r'->\s*\{([^}]*)\}') - # Regex: match individual atoms like Pred(obj1, obj2) - atom_re = re.compile(r'(\w+)\(([^)]*)\)') + # Regex: match individual atoms, optionally prefixed with NOT + atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)') - results: List[Optional[Set[GroundAtom]]] = [] + results: List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]] = [] option_names = {o.name for o in self._get_all_options()} for line in text.split('\n'): @@ -333,13 +352,15 @@ def _parse_subgoal_annotations( continue atoms_text = sg_match.group(1) - atoms: Set[GroundAtom] = set() + pos_atoms: Set[GroundAtom] = set() + neg_atoms: Set[GroundAtom] = set() for atom_match in atom_re.finditer(atoms_text): - pred_name = atom_match.group(1) + is_neg = atom_match.group(1) is not None + pred_name = atom_match.group(2) # Handle both "obj" and "obj:type" formats obj_names = [ n.strip().split(':')[0] - for n in atom_match.group(2).split(',') + for n in atom_match.group(3).split(',') ] if pred_name not in pred_map: @@ -357,9 +378,16 @@ def _parse_subgoal_annotations( f"Arity mismatch for {pred_name}: expected " f"{len(pred.types)}, got {len(objs)}") continue - atoms.add(GroundAtom(pred, objs)) + atom = GroundAtom(pred, objs) + if is_neg: + neg_atoms.add(atom) + else: + pos_atoms.add(atom) - results.append(atoms if atoms else None) + if pos_atoms or neg_atoms: + results.append((pos_atoms, neg_atoms)) + else: + results.append(None) return results @@ -424,6 +452,13 @@ def _refine_sketch( # Sample continuous parameters and ground option params = self._sample_params(step.option, cur_state, rng) grounded = step.option.ground(step.objects, params) + # Inject Wait target atoms from sketch annotations + if grounded.name == "Wait": + if step.subgoal_atoms is not None: + grounded.memory["wait_target_atoms"] = step.subgoal_atoms + if step.subgoal_neg_atoms is not None: + grounded.memory["wait_target_neg_atoms"] = \ + step.subgoal_neg_atoms plan[cur_idx] = grounded state = cur_state @@ -547,6 +582,10 @@ def _validate_plan_forward( grounded.parent) option_copy = env_param_opt.ground(grounded.objects, grounded.params.copy()) + # Propagate Wait target atoms through re-grounding + for key in ("wait_target_atoms", "wait_target_neg_atoms"): + if key in grounded.memory: + option_copy.memory[key] = grounded.memory[key] if not option_copy.initiable(state): logging.info(f"Forward validation: step {i} " @@ -558,10 +597,12 @@ def _validate_plan_forward( # 2. terminate_on_repeat (stuck detection) # 3. wait_option_terminate_on_atom_change last_state_ref: List[Optional[State]] = [None] + abstract_fn = lambda s, _p=predicates: utils.abstract(s, _p) def _terminal( # pylint: disable=cell-var-from-loop s: State, - oc: _Option = option_copy) -> bool: + oc: _Option = option_copy, + _abs: Callable = abstract_fn) -> bool: if oc.terminal(s): return True prev = last_state_ref[0] @@ -572,11 +613,17 @@ def _terminal( # pylint: disable=cell-var-from-loop f"Option '{oc.name}' got stuck.") if (CFG.wait_option_terminate_on_atom_change and oc.name == "Wait"): - cur_atoms = utils.abstract(s, predicates) - prev_atoms = utils.abstract(prev, predicates) - if cur_atoms != prev_atoms: + result = utils.check_wait_target_atoms( + oc, s, _abs) + if result is True: last_state_ref[0] = s return True + if result is None: + cur_atoms = _abs(s) + prev_atoms = _abs(prev) + if cur_atoms != prev_atoms: + last_state_ref[0] = s + return True last_state_ref[0] = s return False diff --git a/predicators/approaches/agent_option_learning_approach.py b/predicators/approaches/agent_option_learning_approach.py index 9473332d0..f9a3f54ab 100644 --- a/predicators/approaches/agent_option_learning_approach.py +++ b/predicators/approaches/agent_option_learning_approach.py @@ -99,7 +99,10 @@ def _get_agent_system_prompt(self) -> str: No continuous params. - `create_move_to_skill(name, types, params_space, config, \ get_target_pose_fn)` — move end-effector to a target pose -- `create_wait_option(name, config, robot_type)` — hold current pose +- `create_wait_option(name, config, robot_type)` — hold current pose; \ +annotate with `-> {{atoms}}` in the plan to specify when it should \ +terminate (e.g. `Wait(robot:Robot) -> {{Boiled(water:water_type)}}`). \ +Use `NOT Pred(...)` for atoms that should become false All factories (except `create_place_skill` and `create_wait_option`) \ take a `SkillConfig` (available as `skill_config` in the exec \ diff --git a/predicators/approaches/agent_planner_approach.py b/predicators/approaches/agent_planner_approach.py index 0745a24e3..ac0c7e7a9 100644 --- a/predicators/approaches/agent_planner_approach.py +++ b/predicators/approaches/agent_planner_approach.py @@ -14,7 +14,8 @@ import inspect as _inspect import logging import os -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, cast +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, \ + cast import dill as pkl import numpy as np @@ -28,9 +29,9 @@ from predicators.explorers.base_explorer import BaseExplorer from predicators.option_model import _OptionModelBase, create_option_model from predicators.settings import CFG -from predicators.structs import Action, Dataset, InteractionRequest, \ - InteractionResult, LowLevelTrajectory, ParameterizedOption, Predicate, \ - State, Task, Type +from predicators.structs import Action, Dataset, GroundAtom, \ + InteractionRequest, InteractionResult, LowLevelTrajectory, Object, \ + ParameterizedOption, Predicate, State, Task, Type class AgentPlannerApproach(AgentSessionMixin, BaseApproach): @@ -122,8 +123,12 @@ def _get_all_trajectories(self) -> List[LowLevelTrajectory]: "delayed process (e.g. water filling, dominoes cascading, " "heating), insert a Wait after it so the effect has time " "to occur before the next action. The Wait action holds the " - "robot's current pose and terminates once the abstract state " - "changes. Without a Wait, the robot will proceed to the next " + "robot's current pose. You can annotate Wait with target atoms " + "using `-> {atoms}` to specify exactly when it should terminate " + "(e.g. `Wait(robot:Robot)[] -> {Boiled(water:water_type)}`). " + "Use `NOT Pred(...)` for atoms that should become false. " + "If no annotation is provided, the Wait terminates on any atom " + "change. Without a Wait, the robot will proceed to the next " "action before the delayed effect has occurred, which might " "cause the plan to fail.") @@ -511,10 +516,14 @@ def _build_solve_prompt(self, task: Task) -> str: After any action whose desired subgoal depends on a delayed process (e.g. water \ filling, dominoes cascading, heating), insert a Wait action to let the process \ -complete before proceeding. The Wait terminates once the abstract state changes. \ -Without it, the plan will move on before the effect has occurred and fail. Only use \ -Wait when there is a genuine delayed effect; do not insert it between actions with \ -immediate effects (e.g. Pick, Place). +complete before proceeding. You can annotate Wait with target atoms using \ +`-> {{atoms}}` to specify exactly when it should terminate. Use `NOT Pred(...)` for \ +atoms that should become false. If no annotation is provided, the Wait terminates on \ +any atom change. Only use Wait when there is a genuine delayed effect; do not insert \ +it between actions with immediate effects (e.g. Pick, Place). + +For Wait with target atoms: `Wait(robot:Robot)[] -> {{Boiled(water:water_type)}}` +For negated targets: `Wait(robot:Robot)[] -> {{NOT Touching(a:block, b:block)}}` **Important — parameter tuning workflow:** - When a step fails or produces unexpected results, inspect the rendered images \ @@ -607,6 +616,37 @@ def _strip_code_fences(text: str) -> str: lines.pop() return '\n'.join(lines) + def _parse_wait_annotations( + self, + text: str, + predicates: Set[Predicate], + objects: Sequence[Object], + ) -> List[Tuple[Set[GroundAtom], Set[GroundAtom]]]: + """Parse ``-> {atoms}`` annotations from plan lines. + + Returns a list parallel to the option lines in the text. Each + entry is ``(positive_atoms, negative_atoms)`` for Wait lines + with annotations, or ``(set(), set())`` otherwise. + """ + results: List[Tuple[Set[GroundAtom], Set[GroundAtom]]] = [] + option_names = {o.name for o in self._get_all_options()} + for line in text.split('\n'): + stripped = line.strip() + if not stripped: + continue + first_token = stripped.split('(')[0] + if first_token not in option_names: + if results: + break + continue + if first_token == "Wait" and '->' in stripped: + pos, neg = utils.parse_wait_target_annotations( + stripped, predicates, objects) + results.append((pos, neg)) + else: + results.append((set(), set())) + return results + def _parse_and_ground_plan(self, plan_text: str, task: Task) -> list: """Parse option plan text and ground into executable options.""" objects = list(task.init) @@ -616,8 +656,15 @@ def _parse_and_ground_plan(self, plan_text: str, task: Task) -> list: # Strip markdown code fences that agents often wrap plans in. cleaned_text = self._strip_code_fences(plan_text) + # Extract Wait target annotations before stripping them + wait_annotations = self._parse_wait_annotations( + cleaned_text, self._get_all_predicates(), objects) + + # Strip annotations so the option plan parser doesn't choke + parseable_text = utils.strip_wait_annotations(cleaned_text) + parsed = utils.parse_model_output_into_option_plan( - cleaned_text, + parseable_text, objects, self._types, all_options, @@ -628,10 +675,18 @@ def _parse_and_ground_plan(self, plan_text: str, task: Task) -> list: f" Available option names: {option_names}") grounded = [] - for option, objs, params in parsed: + for i, (option, objs, params) in enumerate(parsed): try: params_arr = np.array(params, dtype=np.float32) ground_opt = option.ground(objs, params_arr) + # Inject Wait target atoms from annotations + if (ground_opt.name == "Wait" + and i < len(wait_annotations)): + pos, neg = wait_annotations[i] + if pos: + ground_opt.memory["wait_target_atoms"] = pos + if neg: + ground_opt.memory["wait_target_neg_atoms"] = neg grounded.append(ground_opt) except Exception as e: # pylint: disable=broad-except logging.warning( diff --git a/predicators/approaches/process_planning_approach.py b/predicators/approaches/process_planning_approach.py index 627e48c65..40a8e644f 100644 --- a/predicators/approaches/process_planning_approach.py +++ b/predicators/approaches/process_planning_approach.py @@ -99,7 +99,8 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: process_plan, task.goal, self._rng, - abstract_function=lambda s: utils.abstract(s, preds)) + abstract_function=lambda s: utils.abstract(s, preds), + atoms_seq=atoms_seq) logging.debug("Current Task Plan:") for process in process_plan: logging.debug(process.name) diff --git a/predicators/option_model.py b/predicators/option_model.py index 5b35c0c30..2a3251deb 100644 --- a/predicators/option_model.py +++ b/predicators/option_model.py @@ -104,6 +104,10 @@ def get_next_state_and_num_actions(self, state: State, assert np.allclose(env_param_opt.params_space.high, param_opt.params_space.high) option_copy = env_param_opt.ground(option.objects, option.params) + # Propagate Wait target atoms through re-grounding + for key in ("wait_target_atoms", "wait_target_neg_atoms"): + if key in option.memory: + option_copy.memory[key] = option.memory[key] del option # unused after this assert option_copy.initiable(state) @@ -128,20 +132,27 @@ def _terminal(s: State) -> bool: f"produced a no-op (e.g. IK returned current " f"joints, or finger command matched current " f"finger state).") - # Terminate Wait on atom change, mirroring - # option_policy_to_policy in utils.py. + # Terminate Wait on target atoms or any atom change. if (CFG.wait_option_terminate_on_atom_change and option_copy.name == "Wait" and last_state is not DefaultState and self._abstract_function is not None): - cur_atoms = self._abstract_function(s) - prev_atoms = self._abstract_function(last_state) - if cur_atoms != prev_atoms: - logging.info(f"Wait terminating due to atom change: " - f"Add: {sorted(cur_atoms - prev_atoms)} " - f"Del: {sorted(prev_atoms - cur_atoms)}") + result = utils.check_wait_target_atoms( + option_copy, s, self._abstract_function) + if result is True: + logging.info("Wait terminating: target atoms satisfied") last_state = s return True + if result is None: + cur_atoms = self._abstract_function(s) + prev_atoms = self._abstract_function(last_state) + if cur_atoms != prev_atoms: + logging.info( + f"Wait terminating due to atom change: " + f"Add: {sorted(cur_atoms - prev_atoms)} " + f"Del: {sorted(prev_atoms - cur_atoms)}") + last_state = s + return True last_state = s return False else: @@ -155,14 +166,21 @@ def _terminal(s: State) -> bool: if option_copy.terminal(s): return True if last_state_ref[0] is not DefaultState: - cur_atoms = abstract_fn(s) - prev_atoms = abstract_fn(last_state_ref[0]) - if cur_atoms != prev_atoms: + result = utils.check_wait_target_atoms( + option_copy, s, abstract_fn) + if result is True: logging.info( - f"Wait terminating due to atom change: " - f"Add: {sorted(cur_atoms - prev_atoms)} " - f"Del: {sorted(prev_atoms - cur_atoms)}") + "Wait terminating: target atoms satisfied") return True + if result is None: + cur_atoms = abstract_fn(s) + prev_atoms = abstract_fn(last_state_ref[0]) + if cur_atoms != prev_atoms: + logging.info( + f"Wait terminating due to atom change: " + f"Add: {sorted(cur_atoms - prev_atoms)} " + f"Del: {sorted(prev_atoms - cur_atoms)}") + return True last_state_ref[0] = s return False else: diff --git a/predicators/planning_with_processes.py b/predicators/planning_with_processes.py index a2c2a839f..74b057e60 100644 --- a/predicators/planning_with_processes.py +++ b/predicators/planning_with_processes.py @@ -12,7 +12,7 @@ from itertools import islice from pprint import pformat from typing import Any, Callable, Collection, Dict, Iterator, List, Optional, \ - Set, Tuple + Sequence, Set, Tuple import numpy as np @@ -1061,6 +1061,9 @@ def sesame_plan_with_processes( metrics["plan_length"] = len(plan) metrics["refinement_time"] = (time.perf_counter() - refinement_start_time) + # Inject Wait target atoms from atoms_sequence so + # execution terminates on specific atoms, not noise. + _inject_wait_targets(plan, skeleton, atoms_sequence) return plan, skeleton, metrics partial_refinements.append((skeleton, plan)) @@ -1073,6 +1076,16 @@ def sesame_plan_with_processes( info={"partial_refinements": partial_refinements}) +def _inject_wait_targets( + plan: List[_Option], + skeleton: List[_GroundEndogenousProcess], + atoms_sequence: Sequence[Set[GroundAtom]], +) -> None: + """Inject Wait target atoms into all Wait options in a plan.""" + for i, option in enumerate(plan): + utils.inject_wait_targets_for_option(option, i, atoms_sequence) + + def create_ff_heuristic( goal: Set[GroundAtom], ground_processes: List[_GroundCausalProcess], diff --git a/predicators/utils.py b/predicators/utils.py index 94284e9c7..6ba2fbe9d 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -1584,6 +1584,107 @@ def __str__(self) -> str: return repr(self) +def check_wait_target_atoms( + option: _Option, + state: State, + abstract_function: Callable[[State], Set[GroundAtom]], +) -> Optional[bool]: + """Check if a Wait option's target atoms are satisfied. + + Returns True if targets are met (Wait should terminate), False if + not yet met, or None if no targets were specified (caller should + fall back to any-atom-change behaviour). + """ + pos = option.memory.get("wait_target_atoms", set()) + neg = option.memory.get("wait_target_neg_atoms", set()) + if not pos and not neg: + return None + cur_atoms = abstract_function(state) + return pos.issubset(cur_atoms) and neg.isdisjoint(cur_atoms) + + +def parse_wait_target_annotations( + line: str, + predicates: Collection[Predicate], + objects: Collection[Object], +) -> Tuple[Set[GroundAtom], Set[GroundAtom]]: + """Parse ``-> {Pred(...), NOT Pred(...)}`` from a plan line. + + Returns ``(positive_atoms, negative_atoms)`` where positive atoms + must become TRUE and negative atoms must become FALSE for the Wait + to terminate. + """ + pred_map = {p.name: p for p in predicates} + obj_map = {o.name: o for o in objects} + + sg_match = re.search(r'->\s*\{([^}]*)\}', line) + if not sg_match: + return set(), set() + + pos_atoms: Set[GroundAtom] = set() + neg_atoms: Set[GroundAtom] = set() + atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)') + + for m in atom_re.finditer(sg_match.group(1)): + is_neg = m.group(1) is not None + pred_name = m.group(2) + obj_names = [n.strip().split(':')[0] for n in m.group(3).split(',')] + + if pred_name not in pred_map: + logging.warning("Unknown predicate in Wait target: %s", pred_name) + continue + pred = pred_map[pred_name] + try: + objs = [obj_map[n] for n in obj_names] + except KeyError as e: + logging.warning("Unknown object in Wait target: %s", e) + continue + if len(objs) != len(pred.types): + logging.warning( + "Arity mismatch for %s: expected %d, got %d", + pred_name, len(pred.types), len(objs)) + continue + atom = GroundAtom(pred, objs) + if is_neg: + neg_atoms.add(atom) + else: + pos_atoms.add(atom) + + return pos_atoms, neg_atoms + + +def inject_wait_targets_for_option( + option: _Option, + step_idx: int, + atoms_sequence: Sequence[Set[GroundAtom]], +) -> None: + """Inject Wait target atoms into a single option from atoms_sequence. + + Computes the expected atom delta from ``atoms_sequence[step_idx]`` + to ``atoms_sequence[step_idx + 1]`` and stores it in the option's + memory so that execution terminates on specific atoms rather than + any noisy change. No-op for non-Wait options or out-of-bounds + indices. + """ + if option.name != "Wait": + return + if step_idx + 1 >= len(atoms_sequence): + return + before = atoms_sequence[step_idx] + after = atoms_sequence[step_idx + 1] + target_pos = after - before + target_neg = before - after + if target_pos: + option.memory["wait_target_atoms"] = target_pos + if target_neg: + option.memory["wait_target_neg_atoms"] = target_neg + + +def strip_wait_annotations(text: str) -> str: + """Remove ``-> {...}`` annotations from plan text lines.""" + return re.sub(r'\s*->\s*\{[^}]*\}', '', text) + + def option_policy_to_policy( option_policy: Callable[[State], _Option], max_option_steps: Optional[int] = None, @@ -1628,13 +1729,21 @@ def _policy(state: State) -> Action: and cur_option.name == "Wait": assert abstract_function is not None assert last_state is not None - cur_atoms = abstract_function(state) - prev_atoms = abstract_function(last_state) - if cur_atoms != prev_atoms: - logging.debug(f"Wait terminating due to atom change: " - f"Add: {sorted(cur_atoms-prev_atoms)} " - f"Del: {sorted(prev_atoms-cur_atoms)}") + result = check_wait_target_atoms( + cur_option, state, abstract_function) + if result is True: + logging.debug("Wait terminating: target atoms satisfied") wait_terminate = True + elif result is None: + # No targets specified: fall back to any-atom-change + cur_atoms = abstract_function(state) + prev_atoms = abstract_function(last_state) + if cur_atoms != prev_atoms: + logging.debug( + f"Wait terminating due to atom change: " + f"Add: {sorted(cur_atoms-prev_atoms)} " + f"Del: {sorted(prev_atoms-cur_atoms)}") + wait_terminate = True last_state = state @@ -1753,6 +1862,7 @@ def process_plan_to_greedy_option_policy( goal: Set[GroundAtom], rng: np.random.Generator, necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, + atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, ) -> Callable[[State], _Option]: """Greedily execute a process plan, assuming downward refinability and that any sample will work. @@ -1769,9 +1879,10 @@ def process_plan_to_greedy_option_policy( ] assert len(necessary_atoms_seq) == len(process_plan) + 1 necessary_atoms_queue = list(necessary_atoms_seq) + step_idx = 0 def _option_policy(state: State) -> _Option: - nonlocal cur_process + nonlocal cur_process, step_idx if not process_queue: raise OptionExecutionFailure("Process plan exhausted.") expected_atoms = necessary_atoms_queue.pop(0) @@ -1780,6 +1891,10 @@ def _option_policy(state: State) -> _Option: "Executing the process failed to achieve the necessary atoms.") cur_process = process_queue.pop(0) cur_option = cur_process.sample_option(state, goal, rng) + if atoms_seq is not None: + inject_wait_targets_for_option( + cur_option, step_idx, atoms_seq) + step_idx += 1 logging.debug(f"Using option {cur_option.name}{cur_option.objects}" f"{cur_option.params} from process plan.") return cur_option @@ -1792,11 +1907,13 @@ def process_plan_to_greedy_policy( goal: Set[GroundAtom], rng: np.random.Generator, necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, - abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None + abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None, + atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, ) -> Callable[[State], Action]: """Convert a process plan to a greedy policy.""" option_policy = process_plan_to_greedy_option_policy( - process_plan, goal, rng, necessary_atoms_seq=necessary_atoms_seq) + process_plan, goal, rng, necessary_atoms_seq=necessary_atoms_seq, + atoms_seq=atoms_seq) return option_policy_to_policy(option_policy, abstract_function=abstract_function) diff --git a/tests/approaches/test_agent_bilevel_approach.py b/tests/approaches/test_agent_bilevel_approach.py index 8d5023842..8eb756203 100644 --- a/tests/approaches/test_agent_bilevel_approach.py +++ b/tests/approaches/test_agent_bilevel_approach.py @@ -141,10 +141,14 @@ def test_basic_subgoals(self): assert len(result) == 2 # First step: Holding(block0) assert result[0] is not None - assert GroundAtom(_Holding, [_block0]) in result[0] + pos, neg = result[0] + assert GroundAtom(_Holding, [_block0]) in pos + assert len(neg) == 0 # Second step: On(block0, block1) assert result[1] is not None - assert GroundAtom(_On, [_block0, _block1]) in result[1] + pos2, neg2 = result[1] + assert GroundAtom(_On, [_block0, _block1]) in pos2 + assert len(neg2) == 0 def test_no_subgoals(self): """Test no subgoals.""" @@ -184,9 +188,11 @@ def test_multiple_atoms_in_subgoal(self): assert len(result) == 1 assert result[0] is not None - assert len(result[0]) == 2 - assert GroundAtom(_On, [_block0, _block1]) in result[0] - assert GroundAtom(_HandEmpty, [_robot]) in result[0] + pos, neg = result[0] + assert len(pos) == 2 + assert len(neg) == 0 + assert GroundAtom(_On, [_block0, _block1]) in pos + assert GroundAtom(_HandEmpty, [_robot]) in pos def test_unknown_predicate_skipped(self): """Test unknown predicate skipped.""" @@ -230,9 +236,11 @@ def test_typed_object_refs_in_subgoals(self): assert len(result) == 2 assert result[0] is not None - assert GroundAtom(_Holding, [_block0]) in result[0] + pos, _ = result[0] + assert GroundAtom(_Holding, [_block0]) in pos assert result[1] is not None - assert GroundAtom(_On, [_block0, _block1]) in result[1] + pos2, _ = result[1] + assert GroundAtom(_On, [_block0, _block1]) in pos2 def test_preamble_ignored(self): """Non-option lines should be ignored.""" @@ -257,7 +265,192 @@ def test_whitespace_in_atoms(self): assert len(result) == 1 assert result[0] is not None - assert GroundAtom(_On, [_block0, _block1]) in result[0] + pos, _ = result[0] + assert GroundAtom(_On, [_block0, _block1]) in pos + + def test_not_atoms_in_subgoals(self): + """Test NOT prefix for negative target atoms.""" + approach, _, _ = _make_approach() + text = ("Wait(robot0:robot) -> " + "{Holding(block0:block), NOT On(block0:block, block1:block)}\n") + result = approach._parse_subgoal_annotations(text, _ALL_PREDICATES, + _ALL_OBJECTS) + + assert len(result) == 1 + assert result[0] is not None + pos, neg = result[0] + assert GroundAtom(_Holding, [_block0]) in pos + assert GroundAtom(_On, [_block0, _block1]) in neg + + +# --------------------------------------------------------------------------- +# Tests: check_wait_target_atoms +# --------------------------------------------------------------------------- + + +class TestCheckWaitTargetAtoms: + """Tests that Wait terminates on target atoms, not noisy changes.""" + + def test_no_targets_returns_none(self): + """No targets in memory -> returns None (fall back to any-change).""" + opt = _Wait.ground([_robot], np.array([], dtype=np.float32)) + # No targets in memory + state = _make_state({_block0: [0.0, 0.0, 0.0]}) + abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) + result = utils.check_wait_target_atoms(opt, state, abstract_fn) + assert result is None + + def test_positive_target_met(self): + """Wait terminates when positive target atom holds.""" + opt = _Wait.ground([_robot], np.array([], dtype=np.float32)) + target_atom = GroundAtom(_Holding, [_block0]) + opt.memory["wait_target_atoms"] = {target_atom} + + # State where Holding(block0) is true (held > 0.5) + state_held = _make_state({_block0: [0.0, 0.0, 1.0]}) + abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) + assert utils.check_wait_target_atoms(opt, state_held, abstract_fn) \ + is True + + def test_positive_target_not_met(self): + """Wait does NOT terminate when target atom doesn't hold yet.""" + opt = _Wait.ground([_robot], np.array([], dtype=np.float32)) + target_atom = GroundAtom(_Holding, [_block0]) + opt.memory["wait_target_atoms"] = {target_atom} + + # State where Holding(block0) is false (held <= 0.5) + state_not_held = _make_state({_block0: [0.0, 0.0, 0.0]}) + abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) + assert utils.check_wait_target_atoms( + opt, state_not_held, abstract_fn) is False + + def test_noisy_atom_change_ignored_with_targets(self): + """Wait ignores noisy atom changes when specific targets are set. + + This is the key test: if the Wait is parameterized with a target + atom (e.g. Holding(block0)), it should NOT terminate when a + different atom changes (e.g. On(block0, block1)). + """ + opt = _Wait.ground([_robot], np.array([], dtype=np.float32)) + # Only waiting for Holding(block0) + target_atom = GroundAtom(_Holding, [_block0]) + opt.memory["wait_target_atoms"] = {target_atom} + + # State where On(block0, block1) is true (noisy change) but + # Holding(block0) is still false + state_noisy = _make_state({_block0: [0.5, 0.0, 0.0], + _block1: [0.5, 0.0, 0.0]}) + abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) + atoms = abstract_fn(state_noisy) + # On is true (positions are close), but Holding is false + assert GroundAtom(_On, [_block0, _block1]) in atoms + assert GroundAtom(_Holding, [_block0]) not in atoms + + # Wait should NOT terminate (target not met, despite On changing) + assert utils.check_wait_target_atoms( + opt, state_noisy, abstract_fn) is False + + def test_negative_target_met(self): + """Wait terminates when negative target atom is false.""" + opt = _Wait.ground([_robot], np.array([], dtype=np.float32)) + neg_atom = GroundAtom(_On, [_block0, _block1]) + opt.memory["wait_target_neg_atoms"] = {neg_atom} + + # State where On(block0, block1) is false (positions far apart) + state = _make_state({_block0: [0.0, 0.0, 0.0], + _block1: [5.0, 0.0, 0.0]}) + abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) + assert utils.check_wait_target_atoms(opt, state, abstract_fn) is True + + def test_negative_target_not_met(self): + """Wait does NOT terminate when negative target atom is still true.""" + opt = _Wait.ground([_robot], np.array([], dtype=np.float32)) + neg_atom = GroundAtom(_On, [_block0, _block1]) + opt.memory["wait_target_neg_atoms"] = {neg_atom} + + # State where On(block0, block1) is true (positions close) + state = _make_state({_block0: [0.5, 0.0, 0.0], + _block1: [0.5, 0.0, 0.0]}) + abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) + assert utils.check_wait_target_atoms(opt, state, abstract_fn) is False + + def test_mixed_positive_and_negative_targets(self): + """Both positive and negative targets must be satisfied.""" + opt = _Wait.ground([_robot], np.array([], dtype=np.float32)) + opt.memory["wait_target_atoms"] = {GroundAtom(_Holding, [_block0])} + opt.memory["wait_target_neg_atoms"] = { + GroundAtom(_On, [_block0, _block1]) + } + + abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) + + # Only positive met (Holding true, On still true) + state1 = _make_state({_block0: [0.5, 0.0, 1.0], + _block1: [0.5, 0.0, 0.0]}) + assert utils.check_wait_target_atoms(opt, state1, abstract_fn) is False + + # Only negative met (On false, Holding false) + state2 = _make_state({_block0: [0.0, 0.0, 0.0], + _block1: [5.0, 0.0, 0.0]}) + assert utils.check_wait_target_atoms(opt, state2, abstract_fn) is False + + # Both met (Holding true, On false) + state3 = _make_state({_block0: [0.0, 0.0, 1.0], + _block1: [5.0, 0.0, 0.0]}) + assert utils.check_wait_target_atoms(opt, state3, abstract_fn) is True + + +# --------------------------------------------------------------------------- +# Tests: parse_wait_target_annotations and strip_wait_annotations +# --------------------------------------------------------------------------- + + +class TestWaitTargetParsing: + """Tests for parse_wait_target_annotations and strip_wait_annotations.""" + + def test_parse_positive_target(self): + """Parse a positive target atom.""" + line = "Wait(robot0:robot) -> {Holding(block0:block)}" + pos, neg = utils.parse_wait_target_annotations( + line, _ALL_PREDICATES, _ALL_OBJECTS) + assert GroundAtom(_Holding, [_block0]) in pos + assert len(neg) == 0 + + def test_parse_negative_target(self): + """Parse a NOT-prefixed target atom.""" + line = "Wait(robot0:robot) -> {NOT On(block0:block, block1:block)}" + pos, neg = utils.parse_wait_target_annotations( + line, _ALL_PREDICATES, _ALL_OBJECTS) + assert len(pos) == 0 + assert GroundAtom(_On, [_block0, _block1]) in neg + + def test_parse_mixed_targets(self): + """Parse both positive and negative target atoms.""" + line = ("Wait(robot0:robot) -> " + "{Holding(block0:block), NOT On(block0:block, block1:block)}") + pos, neg = utils.parse_wait_target_annotations( + line, _ALL_PREDICATES, _ALL_OBJECTS) + assert GroundAtom(_Holding, [_block0]) in pos + assert GroundAtom(_On, [_block0, _block1]) in neg + + def test_parse_no_annotation(self): + """Line without -> returns empty sets.""" + line = "Wait(robot0:robot)[]" + pos, neg = utils.parse_wait_target_annotations( + line, _ALL_PREDICATES, _ALL_OBJECTS) + assert len(pos) == 0 + assert len(neg) == 0 + + def test_strip_annotations(self): + """strip_wait_annotations removes -> {...} suffixes.""" + text = ("Pick(block0:block)[0.5]\n" + "Wait(robot0:robot)[] -> {Holding(block0:block)}\n" + "Place(block0:block, block1:block)[0.1, 0.2]\n") + stripped = utils.strip_wait_annotations(text) + assert "-> {" not in stripped + assert "Pick(block0:block)[0.5]" in stripped + assert "Wait(robot0:robot)[]" in stripped + assert "Place(block0:block, block1:block)[0.1, 0.2]" in stripped # --------------------------------------------------------------------------- From 6290ff09511c9e990d3229d2c82257aeb17f4a75 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Wed, 1 Apr 2026 13:32:50 +0100 Subject: [PATCH 2/3] Fix formatting and exclude logs dir from mypy --- conftest.py | 3 + mypy.ini | 2 +- .../approaches/agent_bilevel_approach.py | 18 ++--- .../approaches/agent_planner_approach.py | 3 +- predicators/option_model.py | 3 +- predicators/utils.py | 24 +++---- .../approaches/test_agent_bilevel_approach.py | 65 +++++++++++-------- 7 files changed, 68 insertions(+), 50 deletions(-) create mode 100644 conftest.py diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000..a04be0df0 --- /dev/null +++ b/conftest.py @@ -0,0 +1,3 @@ +"""Root pytest configuration.""" + +collect_ignore_glob = ["logs/*"] diff --git a/mypy.ini b/mypy.ini index 17dba5d4d..2204ebf83 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,7 +2,7 @@ strict_equality = True disallow_untyped_calls = True warn_unreachable = True -exclude = (predicators/envs/assets|venv|prompts) +exclude = (predicators/envs/assets|venv|prompts|logs) [mypy-predicators.*] disallow_untyped_defs = True diff --git a/predicators/approaches/agent_bilevel_approach.py b/predicators/approaches/agent_bilevel_approach.py index df2882744..7dbb7819c 100644 --- a/predicators/approaches/agent_bilevel_approach.py +++ b/predicators/approaches/agent_bilevel_approach.py @@ -299,13 +299,16 @@ def _query_agent_for_plan_sketch(self, task: Task) -> List[_SketchStep]: sg = subgoals[i] if i < len(subgoals) else None if sg is not None: pos, neg = sg - sketch.append(_SketchStep( - option=option, objects=objs, - subgoal_atoms=pos if pos else None, - subgoal_neg_atoms=neg if neg else None)) + sketch.append( + _SketchStep(option=option, + objects=objs, + subgoal_atoms=pos if pos else None, + subgoal_neg_atoms=neg if neg else None)) else: - sketch.append(_SketchStep( - option=option, objects=objs, subgoal_atoms=None)) + sketch.append( + _SketchStep(option=option, + objects=objs, + subgoal_atoms=None)) logging.info(f"[{self._run_id}] Agent produced sketch with " f"{len(sketch)} steps, " @@ -613,8 +616,7 @@ def _terminal( # pylint: disable=cell-var-from-loop f"Option '{oc.name}' got stuck.") if (CFG.wait_option_terminate_on_atom_change and oc.name == "Wait"): - result = utils.check_wait_target_atoms( - oc, s, _abs) + result = utils.check_wait_target_atoms(oc, s, _abs) if result is True: last_state_ref[0] = s return True diff --git a/predicators/approaches/agent_planner_approach.py b/predicators/approaches/agent_planner_approach.py index ac0c7e7a9..f178fc76b 100644 --- a/predicators/approaches/agent_planner_approach.py +++ b/predicators/approaches/agent_planner_approach.py @@ -680,8 +680,7 @@ def _parse_and_ground_plan(self, plan_text: str, task: Task) -> list: params_arr = np.array(params, dtype=np.float32) ground_opt = option.ground(objs, params_arr) # Inject Wait target atoms from annotations - if (ground_opt.name == "Wait" - and i < len(wait_annotations)): + if (ground_opt.name == "Wait" and i < len(wait_annotations)): pos, neg = wait_annotations[i] if pos: ground_opt.memory["wait_target_atoms"] = pos diff --git a/predicators/option_model.py b/predicators/option_model.py index 2a3251deb..9af23cd51 100644 --- a/predicators/option_model.py +++ b/predicators/option_model.py @@ -140,7 +140,8 @@ def _terminal(s: State) -> bool: result = utils.check_wait_target_atoms( option_copy, s, self._abstract_function) if result is True: - logging.info("Wait terminating: target atoms satisfied") + logging.info( + "Wait terminating: target atoms satisfied") last_state = s return True if result is None: diff --git a/predicators/utils.py b/predicators/utils.py index 6ba2fbe9d..7181522b0 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -1640,9 +1640,8 @@ def parse_wait_target_annotations( logging.warning("Unknown object in Wait target: %s", e) continue if len(objs) != len(pred.types): - logging.warning( - "Arity mismatch for %s: expected %d, got %d", - pred_name, len(pred.types), len(objs)) + logging.warning("Arity mismatch for %s: expected %d, got %d", + pred_name, len(pred.types), len(objs)) continue atom = GroundAtom(pred, objs) if is_neg: @@ -1729,8 +1728,8 @@ def _policy(state: State) -> Action: and cur_option.name == "Wait": assert abstract_function is not None assert last_state is not None - result = check_wait_target_atoms( - cur_option, state, abstract_function) + result = check_wait_target_atoms(cur_option, state, + abstract_function) if result is True: logging.debug("Wait terminating: target atoms satisfied") wait_terminate = True @@ -1739,10 +1738,9 @@ def _policy(state: State) -> Action: cur_atoms = abstract_function(state) prev_atoms = abstract_function(last_state) if cur_atoms != prev_atoms: - logging.debug( - f"Wait terminating due to atom change: " - f"Add: {sorted(cur_atoms-prev_atoms)} " - f"Del: {sorted(prev_atoms-cur_atoms)}") + logging.debug(f"Wait terminating due to atom change: " + f"Add: {sorted(cur_atoms-prev_atoms)} " + f"Del: {sorted(prev_atoms-cur_atoms)}") wait_terminate = True last_state = state @@ -1892,8 +1890,7 @@ def _option_policy(state: State) -> _Option: cur_process = process_queue.pop(0) cur_option = cur_process.sample_option(state, goal, rng) if atoms_seq is not None: - inject_wait_targets_for_option( - cur_option, step_idx, atoms_seq) + inject_wait_targets_for_option(cur_option, step_idx, atoms_seq) step_idx += 1 logging.debug(f"Using option {cur_option.name}{cur_option.objects}" f"{cur_option.params} from process plan.") @@ -1912,7 +1909,10 @@ def process_plan_to_greedy_policy( ) -> Callable[[State], Action]: """Convert a process plan to a greedy policy.""" option_policy = process_plan_to_greedy_option_policy( - process_plan, goal, rng, necessary_atoms_seq=necessary_atoms_seq, + process_plan, + goal, + rng, + necessary_atoms_seq=necessary_atoms_seq, atoms_seq=atoms_seq) return option_policy_to_policy(option_policy, abstract_function=abstract_function) diff --git a/tests/approaches/test_agent_bilevel_approach.py b/tests/approaches/test_agent_bilevel_approach.py index 8eb756203..4d399883d 100644 --- a/tests/approaches/test_agent_bilevel_approach.py +++ b/tests/approaches/test_agent_bilevel_approach.py @@ -271,8 +271,9 @@ def test_whitespace_in_atoms(self): def test_not_atoms_in_subgoals(self): """Test NOT prefix for negative target atoms.""" approach, _, _ = _make_approach() - text = ("Wait(robot0:robot) -> " - "{Holding(block0:block), NOT On(block0:block, block1:block)}\n") + text = ( + "Wait(robot0:robot) -> " + "{Holding(block0:block), NOT On(block0:block, block1:block)}\n") result = approach._parse_subgoal_annotations(text, _ALL_PREDICATES, _ALL_OBJECTS) @@ -321,8 +322,8 @@ def test_positive_target_not_met(self): # State where Holding(block0) is false (held <= 0.5) state_not_held = _make_state({_block0: [0.0, 0.0, 0.0]}) abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) - assert utils.check_wait_target_atoms( - opt, state_not_held, abstract_fn) is False + assert utils.check_wait_target_atoms(opt, state_not_held, + abstract_fn) is False def test_noisy_atom_change_ignored_with_targets(self): """Wait ignores noisy atom changes when specific targets are set. @@ -338,8 +339,10 @@ def test_noisy_atom_change_ignored_with_targets(self): # State where On(block0, block1) is true (noisy change) but # Holding(block0) is still false - state_noisy = _make_state({_block0: [0.5, 0.0, 0.0], - _block1: [0.5, 0.0, 0.0]}) + state_noisy = _make_state({ + _block0: [0.5, 0.0, 0.0], + _block1: [0.5, 0.0, 0.0] + }) abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) atoms = abstract_fn(state_noisy) # On is true (positions are close), but Holding is false @@ -347,8 +350,8 @@ def test_noisy_atom_change_ignored_with_targets(self): assert GroundAtom(_Holding, [_block0]) not in atoms # Wait should NOT terminate (target not met, despite On changing) - assert utils.check_wait_target_atoms( - opt, state_noisy, abstract_fn) is False + assert utils.check_wait_target_atoms(opt, state_noisy, + abstract_fn) is False def test_negative_target_met(self): """Wait terminates when negative target atom is false.""" @@ -357,8 +360,10 @@ def test_negative_target_met(self): opt.memory["wait_target_neg_atoms"] = {neg_atom} # State where On(block0, block1) is false (positions far apart) - state = _make_state({_block0: [0.0, 0.0, 0.0], - _block1: [5.0, 0.0, 0.0]}) + state = _make_state({ + _block0: [0.0, 0.0, 0.0], + _block1: [5.0, 0.0, 0.0] + }) abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) assert utils.check_wait_target_atoms(opt, state, abstract_fn) is True @@ -369,8 +374,10 @@ def test_negative_target_not_met(self): opt.memory["wait_target_neg_atoms"] = {neg_atom} # State where On(block0, block1) is true (positions close) - state = _make_state({_block0: [0.5, 0.0, 0.0], - _block1: [0.5, 0.0, 0.0]}) + state = _make_state({ + _block0: [0.5, 0.0, 0.0], + _block1: [0.5, 0.0, 0.0] + }) abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) assert utils.check_wait_target_atoms(opt, state, abstract_fn) is False @@ -385,18 +392,24 @@ def test_mixed_positive_and_negative_targets(self): abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) # Only positive met (Holding true, On still true) - state1 = _make_state({_block0: [0.5, 0.0, 1.0], - _block1: [0.5, 0.0, 0.0]}) + state1 = _make_state({ + _block0: [0.5, 0.0, 1.0], + _block1: [0.5, 0.0, 0.0] + }) assert utils.check_wait_target_atoms(opt, state1, abstract_fn) is False # Only negative met (On false, Holding false) - state2 = _make_state({_block0: [0.0, 0.0, 0.0], - _block1: [5.0, 0.0, 0.0]}) + state2 = _make_state({ + _block0: [0.0, 0.0, 0.0], + _block1: [5.0, 0.0, 0.0] + }) assert utils.check_wait_target_atoms(opt, state2, abstract_fn) is False # Both met (Holding true, On false) - state3 = _make_state({_block0: [0.0, 0.0, 1.0], - _block1: [5.0, 0.0, 0.0]}) + state3 = _make_state({ + _block0: [0.0, 0.0, 1.0], + _block1: [5.0, 0.0, 0.0] + }) assert utils.check_wait_target_atoms(opt, state3, abstract_fn) is True @@ -411,16 +424,16 @@ class TestWaitTargetParsing: def test_parse_positive_target(self): """Parse a positive target atom.""" line = "Wait(robot0:robot) -> {Holding(block0:block)}" - pos, neg = utils.parse_wait_target_annotations( - line, _ALL_PREDICATES, _ALL_OBJECTS) + pos, neg = utils.parse_wait_target_annotations(line, _ALL_PREDICATES, + _ALL_OBJECTS) assert GroundAtom(_Holding, [_block0]) in pos assert len(neg) == 0 def test_parse_negative_target(self): """Parse a NOT-prefixed target atom.""" line = "Wait(robot0:robot) -> {NOT On(block0:block, block1:block)}" - pos, neg = utils.parse_wait_target_annotations( - line, _ALL_PREDICATES, _ALL_OBJECTS) + pos, neg = utils.parse_wait_target_annotations(line, _ALL_PREDICATES, + _ALL_OBJECTS) assert len(pos) == 0 assert GroundAtom(_On, [_block0, _block1]) in neg @@ -428,16 +441,16 @@ def test_parse_mixed_targets(self): """Parse both positive and negative target atoms.""" line = ("Wait(robot0:robot) -> " "{Holding(block0:block), NOT On(block0:block, block1:block)}") - pos, neg = utils.parse_wait_target_annotations( - line, _ALL_PREDICATES, _ALL_OBJECTS) + pos, neg = utils.parse_wait_target_annotations(line, _ALL_PREDICATES, + _ALL_OBJECTS) assert GroundAtom(_Holding, [_block0]) in pos assert GroundAtom(_On, [_block0, _block1]) in neg def test_parse_no_annotation(self): """Line without -> returns empty sets.""" line = "Wait(robot0:robot)[]" - pos, neg = utils.parse_wait_target_annotations( - line, _ALL_PREDICATES, _ALL_OBJECTS) + pos, neg = utils.parse_wait_target_annotations(line, _ALL_PREDICATES, + _ALL_OBJECTS) assert len(pos) == 0 assert len(neg) == 0 From d40b5ed1564ffea31dd02bf2dcacc224656d0318 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Wed, 1 Apr 2026 14:53:40 +0100 Subject: [PATCH 3/3] Fix pylint unused-argument warning in _inject_wait_targets --- predicators/planning_with_processes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/predicators/planning_with_processes.py b/predicators/planning_with_processes.py index 74b057e60..e23f325b9 100644 --- a/predicators/planning_with_processes.py +++ b/predicators/planning_with_processes.py @@ -1078,7 +1078,7 @@ def sesame_plan_with_processes( def _inject_wait_targets( plan: List[_Option], - skeleton: List[_GroundEndogenousProcess], + _skeleton: List[_GroundEndogenousProcess], atoms_sequence: Sequence[Set[GroundAtom]], ) -> None: """Inject Wait target atoms into all Wait options in a plan."""