diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000..a04be0df0 --- /dev/null +++ b/conftest.py @@ -0,0 +1,3 @@ +"""Root pytest configuration.""" + +collect_ignore_glob = ["logs/*"] diff --git a/mypy.ini b/mypy.ini index 17dba5d4d..2204ebf83 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,7 +2,7 @@ strict_equality = True disallow_untyped_calls = True warn_unreachable = True -exclude = (predicators/envs/assets|venv|prompts) +exclude = (predicators/envs/assets|venv|prompts|logs) [mypy-predicators.*] disallow_untyped_defs = True diff --git a/predicators/approaches/agent_bilevel_approach.py b/predicators/approaches/agent_bilevel_approach.py index f2125fdc4..7dbb7819c 100644 --- a/predicators/approaches/agent_bilevel_approach.py +++ b/predicators/approaches/agent_bilevel_approach.py @@ -34,6 +34,8 @@ class _SketchStep: option: ParameterizedOption objects: Sequence[Object] subgoal_atoms: Optional[Set[GroundAtom]] # None = no subgoal constraint + # Atoms that must be FALSE after this step. + subgoal_neg_atoms: Optional[Set[GroundAtom]] = None class AgentBilevelApproach(AgentPlannerApproach): @@ -75,7 +77,11 @@ def _get_agent_system_prompt(self) -> str: " OptionName(obj1:type1, obj2:type2) -> {Pred(obj1:type1), " "Pred2(obj1:type1, obj2:type2)}\n" "Always use typed references (obj:type) in subgoal atoms.\n" - "Subgoal annotations are optional but improve search efficiency.") + "Subgoal annotations are optional but improve search efficiency.\n" + "For Wait steps, the annotation also specifies exactly when the " + "Wait should terminate. Use `NOT Pred(...)` for atoms that should " + "become false (e.g. `Wait(robot:Robot) -> " + "{Boiled(water:water_type)}`).") # ------------------------------------------------------------------ # # Solve prompt (no continuous params, subgoal format) @@ -173,11 +179,16 @@ def _build_solve_prompt(self, task: Task) -> str: helps the search verify progress. Use `-> {{atoms}}` after each step. After any action whose desired subgoal depends on a delayed process (e.g. \ -water filling, dominoes cascading, heating), insert a Wait action. +water filling, dominoes cascading, heating), insert a Wait action. For Wait \ +steps, annotate with the atoms the process should produce — this tells the \ +system exactly when the Wait should end rather than terminating on any \ +incidental atom change. Use `NOT Pred(...)` for atoms that should become false. Output the plan sketch with one option per line in this format: OptionName(obj1:type1, obj2:type2) -> \ {{Pred(obj1:type1), Pred2(obj1:type1, obj2:type2)}} + Wait(robot:Robot) -> {{Boiled(water:water_type)}} + Wait(robot:Robot) -> {{NOT Touching(a:block, b:block)}} Always use typed references (obj:type) in both option arguments AND subgoal \ atoms. The `-> {{atoms}}` part is optional. If you omit it, the search will \ @@ -286,8 +297,18 @@ def _query_agent_for_plan_sketch(self, task: Task) -> List[_SketchStep]: sketch = [] for i, (option, objs, _) in enumerate(parsed): sg = subgoals[i] if i < len(subgoals) else None - sketch.append( - _SketchStep(option=option, objects=objs, subgoal_atoms=sg)) + if sg is not None: + pos, neg = sg + sketch.append( + _SketchStep(option=option, + objects=objs, + subgoal_atoms=pos if pos else None, + subgoal_neg_atoms=neg if neg else None)) + else: + sketch.append( + _SketchStep(option=option, + objects=objs, + subgoal_atoms=None)) logging.info(f"[{self._run_id}] Agent produced sketch with " f"{len(sketch)} steps, " @@ -300,21 +321,22 @@ def _parse_subgoal_annotations( text: str, predicates: Set[Predicate], objects: Sequence[Object], - ) -> List[Optional[Set[GroundAtom]]]: - """Parse ``-> {Pred(obj1, obj2), ...}`` annotations from plan text. + ) -> List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]]: + """Parse ``-> {Pred(...), NOT Pred(...)}`` annotations from plan text. Returns a list parallel to the option lines. Entries are None - for lines without annotations. + for lines without annotations. Each non-None entry is + ``(positive_atoms, negative_atoms)``. """ pred_map = {p.name: p for p in predicates} obj_map = {o.name: o for o in objects} # Regex: match -> { ... } after the option line subgoal_re = re.compile(r'->\s*\{([^}]*)\}') - # Regex: match individual atoms like Pred(obj1, obj2) - atom_re = re.compile(r'(\w+)\(([^)]*)\)') + # Regex: match individual atoms, optionally prefixed with NOT + atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)') - results: List[Optional[Set[GroundAtom]]] = [] + results: List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]] = [] option_names = {o.name for o in self._get_all_options()} for line in text.split('\n'): @@ -333,13 +355,15 @@ def _parse_subgoal_annotations( continue atoms_text = sg_match.group(1) - atoms: Set[GroundAtom] = set() + pos_atoms: Set[GroundAtom] = set() + neg_atoms: Set[GroundAtom] = set() for atom_match in atom_re.finditer(atoms_text): - pred_name = atom_match.group(1) + is_neg = atom_match.group(1) is not None + pred_name = atom_match.group(2) # Handle both "obj" and "obj:type" formats obj_names = [ n.strip().split(':')[0] - for n in atom_match.group(2).split(',') + for n in atom_match.group(3).split(',') ] if pred_name not in pred_map: @@ -357,9 +381,16 @@ def _parse_subgoal_annotations( f"Arity mismatch for {pred_name}: expected " f"{len(pred.types)}, got {len(objs)}") continue - atoms.add(GroundAtom(pred, objs)) + atom = GroundAtom(pred, objs) + if is_neg: + neg_atoms.add(atom) + else: + pos_atoms.add(atom) - results.append(atoms if atoms else None) + if pos_atoms or neg_atoms: + results.append((pos_atoms, neg_atoms)) + else: + results.append(None) return results @@ -424,6 +455,13 @@ def _refine_sketch( # Sample continuous parameters and ground option params = self._sample_params(step.option, cur_state, rng) grounded = step.option.ground(step.objects, params) + # Inject Wait target atoms from sketch annotations + if grounded.name == "Wait": + if step.subgoal_atoms is not None: + grounded.memory["wait_target_atoms"] = step.subgoal_atoms + if step.subgoal_neg_atoms is not None: + grounded.memory["wait_target_neg_atoms"] = \ + step.subgoal_neg_atoms plan[cur_idx] = grounded state = cur_state @@ -547,6 +585,10 @@ def _validate_plan_forward( grounded.parent) option_copy = env_param_opt.ground(grounded.objects, grounded.params.copy()) + # Propagate Wait target atoms through re-grounding + for key in ("wait_target_atoms", "wait_target_neg_atoms"): + if key in grounded.memory: + option_copy.memory[key] = grounded.memory[key] if not option_copy.initiable(state): logging.info(f"Forward validation: step {i} " @@ -558,10 +600,12 @@ def _validate_plan_forward( # 2. terminate_on_repeat (stuck detection) # 3. wait_option_terminate_on_atom_change last_state_ref: List[Optional[State]] = [None] + abstract_fn = lambda s, _p=predicates: utils.abstract(s, _p) def _terminal( # pylint: disable=cell-var-from-loop s: State, - oc: _Option = option_copy) -> bool: + oc: _Option = option_copy, + _abs: Callable = abstract_fn) -> bool: if oc.terminal(s): return True prev = last_state_ref[0] @@ -572,11 +616,16 @@ def _terminal( # pylint: disable=cell-var-from-loop f"Option '{oc.name}' got stuck.") if (CFG.wait_option_terminate_on_atom_change and oc.name == "Wait"): - cur_atoms = utils.abstract(s, predicates) - prev_atoms = utils.abstract(prev, predicates) - if cur_atoms != prev_atoms: + result = utils.check_wait_target_atoms(oc, s, _abs) + if result is True: last_state_ref[0] = s return True + if result is None: + cur_atoms = _abs(s) + prev_atoms = _abs(prev) + if cur_atoms != prev_atoms: + last_state_ref[0] = s + return True last_state_ref[0] = s return False diff --git a/predicators/approaches/agent_option_learning_approach.py b/predicators/approaches/agent_option_learning_approach.py index 9473332d0..f9a3f54ab 100644 --- a/predicators/approaches/agent_option_learning_approach.py +++ b/predicators/approaches/agent_option_learning_approach.py @@ -99,7 +99,10 @@ def _get_agent_system_prompt(self) -> str: No continuous params. - `create_move_to_skill(name, types, params_space, config, \ get_target_pose_fn)` — move end-effector to a target pose -- `create_wait_option(name, config, robot_type)` — hold current pose +- `create_wait_option(name, config, robot_type)` — hold current pose; \ +annotate with `-> {{atoms}}` in the plan to specify when it should \ +terminate (e.g. `Wait(robot:Robot) -> {{Boiled(water:water_type)}}`). \ +Use `NOT Pred(...)` for atoms that should become false All factories (except `create_place_skill` and `create_wait_option`) \ take a `SkillConfig` (available as `skill_config` in the exec \ diff --git a/predicators/approaches/agent_planner_approach.py b/predicators/approaches/agent_planner_approach.py index 0745a24e3..f178fc76b 100644 --- a/predicators/approaches/agent_planner_approach.py +++ b/predicators/approaches/agent_planner_approach.py @@ -14,7 +14,8 @@ import inspect as _inspect import logging import os -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, cast +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, \ + cast import dill as pkl import numpy as np @@ -28,9 +29,9 @@ from predicators.explorers.base_explorer import BaseExplorer from predicators.option_model import _OptionModelBase, create_option_model from predicators.settings import CFG -from predicators.structs import Action, Dataset, InteractionRequest, \ - InteractionResult, LowLevelTrajectory, ParameterizedOption, Predicate, \ - State, Task, Type +from predicators.structs import Action, Dataset, GroundAtom, \ + InteractionRequest, InteractionResult, LowLevelTrajectory, Object, \ + ParameterizedOption, Predicate, State, Task, Type class AgentPlannerApproach(AgentSessionMixin, BaseApproach): @@ -122,8 +123,12 @@ def _get_all_trajectories(self) -> List[LowLevelTrajectory]: "delayed process (e.g. water filling, dominoes cascading, " "heating), insert a Wait after it so the effect has time " "to occur before the next action. The Wait action holds the " - "robot's current pose and terminates once the abstract state " - "changes. Without a Wait, the robot will proceed to the next " + "robot's current pose. You can annotate Wait with target atoms " + "using `-> {atoms}` to specify exactly when it should terminate " + "(e.g. `Wait(robot:Robot)[] -> {Boiled(water:water_type)}`). " + "Use `NOT Pred(...)` for atoms that should become false. " + "If no annotation is provided, the Wait terminates on any atom " + "change. Without a Wait, the robot will proceed to the next " "action before the delayed effect has occurred, which might " "cause the plan to fail.") @@ -511,10 +516,14 @@ def _build_solve_prompt(self, task: Task) -> str: After any action whose desired subgoal depends on a delayed process (e.g. water \ filling, dominoes cascading, heating), insert a Wait action to let the process \ -complete before proceeding. The Wait terminates once the abstract state changes. \ -Without it, the plan will move on before the effect has occurred and fail. Only use \ -Wait when there is a genuine delayed effect; do not insert it between actions with \ -immediate effects (e.g. Pick, Place). +complete before proceeding. You can annotate Wait with target atoms using \ +`-> {{atoms}}` to specify exactly when it should terminate. Use `NOT Pred(...)` for \ +atoms that should become false. If no annotation is provided, the Wait terminates on \ +any atom change. Only use Wait when there is a genuine delayed effect; do not insert \ +it between actions with immediate effects (e.g. Pick, Place). + +For Wait with target atoms: `Wait(robot:Robot)[] -> {{Boiled(water:water_type)}}` +For negated targets: `Wait(robot:Robot)[] -> {{NOT Touching(a:block, b:block)}}` **Important — parameter tuning workflow:** - When a step fails or produces unexpected results, inspect the rendered images \ @@ -607,6 +616,37 @@ def _strip_code_fences(text: str) -> str: lines.pop() return '\n'.join(lines) + def _parse_wait_annotations( + self, + text: str, + predicates: Set[Predicate], + objects: Sequence[Object], + ) -> List[Tuple[Set[GroundAtom], Set[GroundAtom]]]: + """Parse ``-> {atoms}`` annotations from plan lines. + + Returns a list parallel to the option lines in the text. Each + entry is ``(positive_atoms, negative_atoms)`` for Wait lines + with annotations, or ``(set(), set())`` otherwise. + """ + results: List[Tuple[Set[GroundAtom], Set[GroundAtom]]] = [] + option_names = {o.name for o in self._get_all_options()} + for line in text.split('\n'): + stripped = line.strip() + if not stripped: + continue + first_token = stripped.split('(')[0] + if first_token not in option_names: + if results: + break + continue + if first_token == "Wait" and '->' in stripped: + pos, neg = utils.parse_wait_target_annotations( + stripped, predicates, objects) + results.append((pos, neg)) + else: + results.append((set(), set())) + return results + def _parse_and_ground_plan(self, plan_text: str, task: Task) -> list: """Parse option plan text and ground into executable options.""" objects = list(task.init) @@ -616,8 +656,15 @@ def _parse_and_ground_plan(self, plan_text: str, task: Task) -> list: # Strip markdown code fences that agents often wrap plans in. cleaned_text = self._strip_code_fences(plan_text) + # Extract Wait target annotations before stripping them + wait_annotations = self._parse_wait_annotations( + cleaned_text, self._get_all_predicates(), objects) + + # Strip annotations so the option plan parser doesn't choke + parseable_text = utils.strip_wait_annotations(cleaned_text) + parsed = utils.parse_model_output_into_option_plan( - cleaned_text, + parseable_text, objects, self._types, all_options, @@ -628,10 +675,17 @@ def _parse_and_ground_plan(self, plan_text: str, task: Task) -> list: f" Available option names: {option_names}") grounded = [] - for option, objs, params in parsed: + for i, (option, objs, params) in enumerate(parsed): try: params_arr = np.array(params, dtype=np.float32) ground_opt = option.ground(objs, params_arr) + # Inject Wait target atoms from annotations + if (ground_opt.name == "Wait" and i < len(wait_annotations)): + pos, neg = wait_annotations[i] + if pos: + ground_opt.memory["wait_target_atoms"] = pos + if neg: + ground_opt.memory["wait_target_neg_atoms"] = neg grounded.append(ground_opt) except Exception as e: # pylint: disable=broad-except logging.warning( diff --git a/predicators/approaches/process_planning_approach.py b/predicators/approaches/process_planning_approach.py index 627e48c65..40a8e644f 100644 --- a/predicators/approaches/process_planning_approach.py +++ b/predicators/approaches/process_planning_approach.py @@ -99,7 +99,8 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: process_plan, task.goal, self._rng, - abstract_function=lambda s: utils.abstract(s, preds)) + abstract_function=lambda s: utils.abstract(s, preds), + atoms_seq=atoms_seq) logging.debug("Current Task Plan:") for process in process_plan: logging.debug(process.name) diff --git a/predicators/option_model.py b/predicators/option_model.py index 5b35c0c30..9af23cd51 100644 --- a/predicators/option_model.py +++ b/predicators/option_model.py @@ -104,6 +104,10 @@ def get_next_state_and_num_actions(self, state: State, assert np.allclose(env_param_opt.params_space.high, param_opt.params_space.high) option_copy = env_param_opt.ground(option.objects, option.params) + # Propagate Wait target atoms through re-grounding + for key in ("wait_target_atoms", "wait_target_neg_atoms"): + if key in option.memory: + option_copy.memory[key] = option.memory[key] del option # unused after this assert option_copy.initiable(state) @@ -128,20 +132,28 @@ def _terminal(s: State) -> bool: f"produced a no-op (e.g. IK returned current " f"joints, or finger command matched current " f"finger state).") - # Terminate Wait on atom change, mirroring - # option_policy_to_policy in utils.py. + # Terminate Wait on target atoms or any atom change. if (CFG.wait_option_terminate_on_atom_change and option_copy.name == "Wait" and last_state is not DefaultState and self._abstract_function is not None): - cur_atoms = self._abstract_function(s) - prev_atoms = self._abstract_function(last_state) - if cur_atoms != prev_atoms: - logging.info(f"Wait terminating due to atom change: " - f"Add: {sorted(cur_atoms - prev_atoms)} " - f"Del: {sorted(prev_atoms - cur_atoms)}") + result = utils.check_wait_target_atoms( + option_copy, s, self._abstract_function) + if result is True: + logging.info( + "Wait terminating: target atoms satisfied") last_state = s return True + if result is None: + cur_atoms = self._abstract_function(s) + prev_atoms = self._abstract_function(last_state) + if cur_atoms != prev_atoms: + logging.info( + f"Wait terminating due to atom change: " + f"Add: {sorted(cur_atoms - prev_atoms)} " + f"Del: {sorted(prev_atoms - cur_atoms)}") + last_state = s + return True last_state = s return False else: @@ -155,14 +167,21 @@ def _terminal(s: State) -> bool: if option_copy.terminal(s): return True if last_state_ref[0] is not DefaultState: - cur_atoms = abstract_fn(s) - prev_atoms = abstract_fn(last_state_ref[0]) - if cur_atoms != prev_atoms: + result = utils.check_wait_target_atoms( + option_copy, s, abstract_fn) + if result is True: logging.info( - f"Wait terminating due to atom change: " - f"Add: {sorted(cur_atoms - prev_atoms)} " - f"Del: {sorted(prev_atoms - cur_atoms)}") + "Wait terminating: target atoms satisfied") return True + if result is None: + cur_atoms = abstract_fn(s) + prev_atoms = abstract_fn(last_state_ref[0]) + if cur_atoms != prev_atoms: + logging.info( + f"Wait terminating due to atom change: " + f"Add: {sorted(cur_atoms - prev_atoms)} " + f"Del: {sorted(prev_atoms - cur_atoms)}") + return True last_state_ref[0] = s return False else: diff --git a/predicators/planning_with_processes.py b/predicators/planning_with_processes.py index a2c2a839f..e23f325b9 100644 --- a/predicators/planning_with_processes.py +++ b/predicators/planning_with_processes.py @@ -12,7 +12,7 @@ from itertools import islice from pprint import pformat from typing import Any, Callable, Collection, Dict, Iterator, List, Optional, \ - Set, Tuple + Sequence, Set, Tuple import numpy as np @@ -1061,6 +1061,9 @@ def sesame_plan_with_processes( metrics["plan_length"] = len(plan) metrics["refinement_time"] = (time.perf_counter() - refinement_start_time) + # Inject Wait target atoms from atoms_sequence so + # execution terminates on specific atoms, not noise. + _inject_wait_targets(plan, skeleton, atoms_sequence) return plan, skeleton, metrics partial_refinements.append((skeleton, plan)) @@ -1073,6 +1076,16 @@ def sesame_plan_with_processes( info={"partial_refinements": partial_refinements}) +def _inject_wait_targets( + plan: List[_Option], + _skeleton: List[_GroundEndogenousProcess], + atoms_sequence: Sequence[Set[GroundAtom]], +) -> None: + """Inject Wait target atoms into all Wait options in a plan.""" + for i, option in enumerate(plan): + utils.inject_wait_targets_for_option(option, i, atoms_sequence) + + def create_ff_heuristic( goal: Set[GroundAtom], ground_processes: List[_GroundCausalProcess], diff --git a/predicators/utils.py b/predicators/utils.py index 94284e9c7..7181522b0 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -1584,6 +1584,106 @@ def __str__(self) -> str: return repr(self) +def check_wait_target_atoms( + option: _Option, + state: State, + abstract_function: Callable[[State], Set[GroundAtom]], +) -> Optional[bool]: + """Check if a Wait option's target atoms are satisfied. + + Returns True if targets are met (Wait should terminate), False if + not yet met, or None if no targets were specified (caller should + fall back to any-atom-change behaviour). + """ + pos = option.memory.get("wait_target_atoms", set()) + neg = option.memory.get("wait_target_neg_atoms", set()) + if not pos and not neg: + return None + cur_atoms = abstract_function(state) + return pos.issubset(cur_atoms) and neg.isdisjoint(cur_atoms) + + +def parse_wait_target_annotations( + line: str, + predicates: Collection[Predicate], + objects: Collection[Object], +) -> Tuple[Set[GroundAtom], Set[GroundAtom]]: + """Parse ``-> {Pred(...), NOT Pred(...)}`` from a plan line. + + Returns ``(positive_atoms, negative_atoms)`` where positive atoms + must become TRUE and negative atoms must become FALSE for the Wait + to terminate. + """ + pred_map = {p.name: p for p in predicates} + obj_map = {o.name: o for o in objects} + + sg_match = re.search(r'->\s*\{([^}]*)\}', line) + if not sg_match: + return set(), set() + + pos_atoms: Set[GroundAtom] = set() + neg_atoms: Set[GroundAtom] = set() + atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)') + + for m in atom_re.finditer(sg_match.group(1)): + is_neg = m.group(1) is not None + pred_name = m.group(2) + obj_names = [n.strip().split(':')[0] for n in m.group(3).split(',')] + + if pred_name not in pred_map: + logging.warning("Unknown predicate in Wait target: %s", pred_name) + continue + pred = pred_map[pred_name] + try: + objs = [obj_map[n] for n in obj_names] + except KeyError as e: + logging.warning("Unknown object in Wait target: %s", e) + continue + if len(objs) != len(pred.types): + logging.warning("Arity mismatch for %s: expected %d, got %d", + pred_name, len(pred.types), len(objs)) + continue + atom = GroundAtom(pred, objs) + if is_neg: + neg_atoms.add(atom) + else: + pos_atoms.add(atom) + + return pos_atoms, neg_atoms + + +def inject_wait_targets_for_option( + option: _Option, + step_idx: int, + atoms_sequence: Sequence[Set[GroundAtom]], +) -> None: + """Inject Wait target atoms into a single option from atoms_sequence. + + Computes the expected atom delta from ``atoms_sequence[step_idx]`` + to ``atoms_sequence[step_idx + 1]`` and stores it in the option's + memory so that execution terminates on specific atoms rather than + any noisy change. No-op for non-Wait options or out-of-bounds + indices. + """ + if option.name != "Wait": + return + if step_idx + 1 >= len(atoms_sequence): + return + before = atoms_sequence[step_idx] + after = atoms_sequence[step_idx + 1] + target_pos = after - before + target_neg = before - after + if target_pos: + option.memory["wait_target_atoms"] = target_pos + if target_neg: + option.memory["wait_target_neg_atoms"] = target_neg + + +def strip_wait_annotations(text: str) -> str: + """Remove ``-> {...}`` annotations from plan text lines.""" + return re.sub(r'\s*->\s*\{[^}]*\}', '', text) + + def option_policy_to_policy( option_policy: Callable[[State], _Option], max_option_steps: Optional[int] = None, @@ -1628,13 +1728,20 @@ def _policy(state: State) -> Action: and cur_option.name == "Wait": assert abstract_function is not None assert last_state is not None - cur_atoms = abstract_function(state) - prev_atoms = abstract_function(last_state) - if cur_atoms != prev_atoms: - logging.debug(f"Wait terminating due to atom change: " - f"Add: {sorted(cur_atoms-prev_atoms)} " - f"Del: {sorted(prev_atoms-cur_atoms)}") + result = check_wait_target_atoms(cur_option, state, + abstract_function) + if result is True: + logging.debug("Wait terminating: target atoms satisfied") wait_terminate = True + elif result is None: + # No targets specified: fall back to any-atom-change + cur_atoms = abstract_function(state) + prev_atoms = abstract_function(last_state) + if cur_atoms != prev_atoms: + logging.debug(f"Wait terminating due to atom change: " + f"Add: {sorted(cur_atoms-prev_atoms)} " + f"Del: {sorted(prev_atoms-cur_atoms)}") + wait_terminate = True last_state = state @@ -1753,6 +1860,7 @@ def process_plan_to_greedy_option_policy( goal: Set[GroundAtom], rng: np.random.Generator, necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, + atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, ) -> Callable[[State], _Option]: """Greedily execute a process plan, assuming downward refinability and that any sample will work. @@ -1769,9 +1877,10 @@ def process_plan_to_greedy_option_policy( ] assert len(necessary_atoms_seq) == len(process_plan) + 1 necessary_atoms_queue = list(necessary_atoms_seq) + step_idx = 0 def _option_policy(state: State) -> _Option: - nonlocal cur_process + nonlocal cur_process, step_idx if not process_queue: raise OptionExecutionFailure("Process plan exhausted.") expected_atoms = necessary_atoms_queue.pop(0) @@ -1780,6 +1889,9 @@ def _option_policy(state: State) -> _Option: "Executing the process failed to achieve the necessary atoms.") cur_process = process_queue.pop(0) cur_option = cur_process.sample_option(state, goal, rng) + if atoms_seq is not None: + inject_wait_targets_for_option(cur_option, step_idx, atoms_seq) + step_idx += 1 logging.debug(f"Using option {cur_option.name}{cur_option.objects}" f"{cur_option.params} from process plan.") return cur_option @@ -1792,11 +1904,16 @@ def process_plan_to_greedy_policy( goal: Set[GroundAtom], rng: np.random.Generator, necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, - abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None + abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None, + atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, ) -> Callable[[State], Action]: """Convert a process plan to a greedy policy.""" option_policy = process_plan_to_greedy_option_policy( - process_plan, goal, rng, necessary_atoms_seq=necessary_atoms_seq) + process_plan, + goal, + rng, + necessary_atoms_seq=necessary_atoms_seq, + atoms_seq=atoms_seq) return option_policy_to_policy(option_policy, abstract_function=abstract_function) diff --git a/tests/approaches/test_agent_bilevel_approach.py b/tests/approaches/test_agent_bilevel_approach.py index 8d5023842..4d399883d 100644 --- a/tests/approaches/test_agent_bilevel_approach.py +++ b/tests/approaches/test_agent_bilevel_approach.py @@ -141,10 +141,14 @@ def test_basic_subgoals(self): assert len(result) == 2 # First step: Holding(block0) assert result[0] is not None - assert GroundAtom(_Holding, [_block0]) in result[0] + pos, neg = result[0] + assert GroundAtom(_Holding, [_block0]) in pos + assert len(neg) == 0 # Second step: On(block0, block1) assert result[1] is not None - assert GroundAtom(_On, [_block0, _block1]) in result[1] + pos2, neg2 = result[1] + assert GroundAtom(_On, [_block0, _block1]) in pos2 + assert len(neg2) == 0 def test_no_subgoals(self): """Test no subgoals.""" @@ -184,9 +188,11 @@ def test_multiple_atoms_in_subgoal(self): assert len(result) == 1 assert result[0] is not None - assert len(result[0]) == 2 - assert GroundAtom(_On, [_block0, _block1]) in result[0] - assert GroundAtom(_HandEmpty, [_robot]) in result[0] + pos, neg = result[0] + assert len(pos) == 2 + assert len(neg) == 0 + assert GroundAtom(_On, [_block0, _block1]) in pos + assert GroundAtom(_HandEmpty, [_robot]) in pos def test_unknown_predicate_skipped(self): """Test unknown predicate skipped.""" @@ -230,9 +236,11 @@ def test_typed_object_refs_in_subgoals(self): assert len(result) == 2 assert result[0] is not None - assert GroundAtom(_Holding, [_block0]) in result[0] + pos, _ = result[0] + assert GroundAtom(_Holding, [_block0]) in pos assert result[1] is not None - assert GroundAtom(_On, [_block0, _block1]) in result[1] + pos2, _ = result[1] + assert GroundAtom(_On, [_block0, _block1]) in pos2 def test_preamble_ignored(self): """Non-option lines should be ignored.""" @@ -257,7 +265,205 @@ def test_whitespace_in_atoms(self): assert len(result) == 1 assert result[0] is not None - assert GroundAtom(_On, [_block0, _block1]) in result[0] + pos, _ = result[0] + assert GroundAtom(_On, [_block0, _block1]) in pos + + def test_not_atoms_in_subgoals(self): + """Test NOT prefix for negative target atoms.""" + approach, _, _ = _make_approach() + text = ( + "Wait(robot0:robot) -> " + "{Holding(block0:block), NOT On(block0:block, block1:block)}\n") + result = approach._parse_subgoal_annotations(text, _ALL_PREDICATES, + _ALL_OBJECTS) + + assert len(result) == 1 + assert result[0] is not None + pos, neg = result[0] + assert GroundAtom(_Holding, [_block0]) in pos + assert GroundAtom(_On, [_block0, _block1]) in neg + + +# --------------------------------------------------------------------------- +# Tests: check_wait_target_atoms +# --------------------------------------------------------------------------- + + +class TestCheckWaitTargetAtoms: + """Tests that Wait terminates on target atoms, not noisy changes.""" + + def test_no_targets_returns_none(self): + """No targets in memory -> returns None (fall back to any-change).""" + opt = _Wait.ground([_robot], np.array([], dtype=np.float32)) + # No targets in memory + state = _make_state({_block0: [0.0, 0.0, 0.0]}) + abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) + result = utils.check_wait_target_atoms(opt, state, abstract_fn) + assert result is None + + def test_positive_target_met(self): + """Wait terminates when positive target atom holds.""" + opt = _Wait.ground([_robot], np.array([], dtype=np.float32)) + target_atom = GroundAtom(_Holding, [_block0]) + opt.memory["wait_target_atoms"] = {target_atom} + + # State where Holding(block0) is true (held > 0.5) + state_held = _make_state({_block0: [0.0, 0.0, 1.0]}) + abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) + assert utils.check_wait_target_atoms(opt, state_held, abstract_fn) \ + is True + + def test_positive_target_not_met(self): + """Wait does NOT terminate when target atom doesn't hold yet.""" + opt = _Wait.ground([_robot], np.array([], dtype=np.float32)) + target_atom = GroundAtom(_Holding, [_block0]) + opt.memory["wait_target_atoms"] = {target_atom} + + # State where Holding(block0) is false (held <= 0.5) + state_not_held = _make_state({_block0: [0.0, 0.0, 0.0]}) + abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) + assert utils.check_wait_target_atoms(opt, state_not_held, + abstract_fn) is False + + def test_noisy_atom_change_ignored_with_targets(self): + """Wait ignores noisy atom changes when specific targets are set. + + This is the key test: if the Wait is parameterized with a target + atom (e.g. Holding(block0)), it should NOT terminate when a + different atom changes (e.g. On(block0, block1)). + """ + opt = _Wait.ground([_robot], np.array([], dtype=np.float32)) + # Only waiting for Holding(block0) + target_atom = GroundAtom(_Holding, [_block0]) + opt.memory["wait_target_atoms"] = {target_atom} + + # State where On(block0, block1) is true (noisy change) but + # Holding(block0) is still false + state_noisy = _make_state({ + _block0: [0.5, 0.0, 0.0], + _block1: [0.5, 0.0, 0.0] + }) + abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) + atoms = abstract_fn(state_noisy) + # On is true (positions are close), but Holding is false + assert GroundAtom(_On, [_block0, _block1]) in atoms + assert GroundAtom(_Holding, [_block0]) not in atoms + + # Wait should NOT terminate (target not met, despite On changing) + assert utils.check_wait_target_atoms(opt, state_noisy, + abstract_fn) is False + + def test_negative_target_met(self): + """Wait terminates when negative target atom is false.""" + opt = _Wait.ground([_robot], np.array([], dtype=np.float32)) + neg_atom = GroundAtom(_On, [_block0, _block1]) + opt.memory["wait_target_neg_atoms"] = {neg_atom} + + # State where On(block0, block1) is false (positions far apart) + state = _make_state({ + _block0: [0.0, 0.0, 0.0], + _block1: [5.0, 0.0, 0.0] + }) + abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) + assert utils.check_wait_target_atoms(opt, state, abstract_fn) is True + + def test_negative_target_not_met(self): + """Wait does NOT terminate when negative target atom is still true.""" + opt = _Wait.ground([_robot], np.array([], dtype=np.float32)) + neg_atom = GroundAtom(_On, [_block0, _block1]) + opt.memory["wait_target_neg_atoms"] = {neg_atom} + + # State where On(block0, block1) is true (positions close) + state = _make_state({ + _block0: [0.5, 0.0, 0.0], + _block1: [0.5, 0.0, 0.0] + }) + abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) + assert utils.check_wait_target_atoms(opt, state, abstract_fn) is False + + def test_mixed_positive_and_negative_targets(self): + """Both positive and negative targets must be satisfied.""" + opt = _Wait.ground([_robot], np.array([], dtype=np.float32)) + opt.memory["wait_target_atoms"] = {GroundAtom(_Holding, [_block0])} + opt.memory["wait_target_neg_atoms"] = { + GroundAtom(_On, [_block0, _block1]) + } + + abstract_fn = lambda s: utils.abstract(s, _ALL_PREDICATES) + + # Only positive met (Holding true, On still true) + state1 = _make_state({ + _block0: [0.5, 0.0, 1.0], + _block1: [0.5, 0.0, 0.0] + }) + assert utils.check_wait_target_atoms(opt, state1, abstract_fn) is False + + # Only negative met (On false, Holding false) + state2 = _make_state({ + _block0: [0.0, 0.0, 0.0], + _block1: [5.0, 0.0, 0.0] + }) + assert utils.check_wait_target_atoms(opt, state2, abstract_fn) is False + + # Both met (Holding true, On false) + state3 = _make_state({ + _block0: [0.0, 0.0, 1.0], + _block1: [5.0, 0.0, 0.0] + }) + assert utils.check_wait_target_atoms(opt, state3, abstract_fn) is True + + +# --------------------------------------------------------------------------- +# Tests: parse_wait_target_annotations and strip_wait_annotations +# --------------------------------------------------------------------------- + + +class TestWaitTargetParsing: + """Tests for parse_wait_target_annotations and strip_wait_annotations.""" + + def test_parse_positive_target(self): + """Parse a positive target atom.""" + line = "Wait(robot0:robot) -> {Holding(block0:block)}" + pos, neg = utils.parse_wait_target_annotations(line, _ALL_PREDICATES, + _ALL_OBJECTS) + assert GroundAtom(_Holding, [_block0]) in pos + assert len(neg) == 0 + + def test_parse_negative_target(self): + """Parse a NOT-prefixed target atom.""" + line = "Wait(robot0:robot) -> {NOT On(block0:block, block1:block)}" + pos, neg = utils.parse_wait_target_annotations(line, _ALL_PREDICATES, + _ALL_OBJECTS) + assert len(pos) == 0 + assert GroundAtom(_On, [_block0, _block1]) in neg + + def test_parse_mixed_targets(self): + """Parse both positive and negative target atoms.""" + line = ("Wait(robot0:robot) -> " + "{Holding(block0:block), NOT On(block0:block, block1:block)}") + pos, neg = utils.parse_wait_target_annotations(line, _ALL_PREDICATES, + _ALL_OBJECTS) + assert GroundAtom(_Holding, [_block0]) in pos + assert GroundAtom(_On, [_block0, _block1]) in neg + + def test_parse_no_annotation(self): + """Line without -> returns empty sets.""" + line = "Wait(robot0:robot)[]" + pos, neg = utils.parse_wait_target_annotations(line, _ALL_PREDICATES, + _ALL_OBJECTS) + assert len(pos) == 0 + assert len(neg) == 0 + + def test_strip_annotations(self): + """strip_wait_annotations removes -> {...} suffixes.""" + text = ("Pick(block0:block)[0.5]\n" + "Wait(robot0:robot)[] -> {Holding(block0:block)}\n" + "Place(block0:block, block1:block)[0.1, 0.2]\n") + stripped = utils.strip_wait_annotations(text) + assert "-> {" not in stripped + assert "Pick(block0:block)[0.5]" in stripped + assert "Wait(robot0:robot)[]" in stripped + assert "Place(block0:block, block1:block)[0.1, 0.2]" in stripped # ---------------------------------------------------------------------------