Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Root pytest configuration."""

# pytest reads this conftest-level variable during collection and skips any
# path matching one of these glob patterns, so nothing under logs/ is
# collected as a test.
collect_ignore_glob = ["logs/*"]
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
strict_equality = True
disallow_untyped_calls = True
warn_unreachable = True
exclude = (predicators/envs/assets|venv|prompts)
exclude = (predicators/envs/assets|venv|prompts|logs)

[mypy-predicators.*]
disallow_untyped_defs = True
Expand Down
87 changes: 68 additions & 19 deletions predicators/approaches/agent_bilevel_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class _SketchStep:
option: ParameterizedOption
objects: Sequence[Object]
subgoal_atoms: Optional[Set[GroundAtom]] # None = no subgoal constraint
# Atoms that must be FALSE after this step.
subgoal_neg_atoms: Optional[Set[GroundAtom]] = None


class AgentBilevelApproach(AgentPlannerApproach):
Expand Down Expand Up @@ -75,7 +77,11 @@ def _get_agent_system_prompt(self) -> str:
" OptionName(obj1:type1, obj2:type2) -> {Pred(obj1:type1), "
"Pred2(obj1:type1, obj2:type2)}\n"
"Always use typed references (obj:type) in subgoal atoms.\n"
"Subgoal annotations are optional but improve search efficiency.")
"Subgoal annotations are optional but improve search efficiency.\n"
"For Wait steps, the annotation also specifies exactly when the "
"Wait should terminate. Use `NOT Pred(...)` for atoms that should "
"become false (e.g. `Wait(robot:Robot) -> "
"{Boiled(water:water_type)}`).")

# ------------------------------------------------------------------ #
# Solve prompt (no continuous params, subgoal format)
Expand Down Expand Up @@ -173,11 +179,16 @@ def _build_solve_prompt(self, task: Task) -> str:
helps the search verify progress. Use `-> {{atoms}}` after each step.

After any action whose desired subgoal depends on a delayed process (e.g. \
water filling, dominoes cascading, heating), insert a Wait action.
water filling, dominoes cascading, heating), insert a Wait action. For Wait \
steps, annotate with the atoms the process should produce — this tells the \
system exactly when the Wait should end rather than terminating on any \
incidental atom change. Use `NOT Pred(...)` for atoms that should become false.

Output the plan sketch with one option per line in this format:
OptionName(obj1:type1, obj2:type2) -> \
{{Pred(obj1:type1), Pred2(obj1:type1, obj2:type2)}}
Wait(robot:Robot) -> {{Boiled(water:water_type)}}
Wait(robot:Robot) -> {{NOT Touching(a:block, b:block)}}

Always use typed references (obj:type) in both option arguments AND subgoal \
atoms. The `-> {{atoms}}` part is optional. If you omit it, the search will \
Expand Down Expand Up @@ -286,8 +297,18 @@ def _query_agent_for_plan_sketch(self, task: Task) -> List[_SketchStep]:
sketch = []
for i, (option, objs, _) in enumerate(parsed):
sg = subgoals[i] if i < len(subgoals) else None
sketch.append(
_SketchStep(option=option, objects=objs, subgoal_atoms=sg))
if sg is not None:
pos, neg = sg
sketch.append(
_SketchStep(option=option,
objects=objs,
subgoal_atoms=pos if pos else None,
subgoal_neg_atoms=neg if neg else None))
else:
sketch.append(
_SketchStep(option=option,
objects=objs,
subgoal_atoms=None))

logging.info(f"[{self._run_id}] Agent produced sketch with "
f"{len(sketch)} steps, "
Expand All @@ -300,21 +321,22 @@ def _parse_subgoal_annotations(
text: str,
predicates: Set[Predicate],
objects: Sequence[Object],
) -> List[Optional[Set[GroundAtom]]]:
"""Parse ``-> {Pred(obj1, obj2), ...}`` annotations from plan text.
) -> List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]]:
"""Parse ``-> {Pred(...), NOT Pred(...)}`` annotations from plan text.

Returns a list parallel to the option lines. Entries are None
for lines without annotations.
for lines without annotations. Each non-None entry is
``(positive_atoms, negative_atoms)``.
"""
pred_map = {p.name: p for p in predicates}
obj_map = {o.name: o for o in objects}

# Regex: match -> { ... } after the option line
subgoal_re = re.compile(r'->\s*\{([^}]*)\}')
# Regex: match individual atoms like Pred(obj1, obj2)
atom_re = re.compile(r'(\w+)\(([^)]*)\)')
# Regex: match individual atoms, optionally prefixed with NOT
atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)')

results: List[Optional[Set[GroundAtom]]] = []
results: List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]] = []
option_names = {o.name for o in self._get_all_options()}

for line in text.split('\n'):
Expand All @@ -333,13 +355,15 @@ def _parse_subgoal_annotations(
continue

atoms_text = sg_match.group(1)
atoms: Set[GroundAtom] = set()
pos_atoms: Set[GroundAtom] = set()
neg_atoms: Set[GroundAtom] = set()
for atom_match in atom_re.finditer(atoms_text):
pred_name = atom_match.group(1)
is_neg = atom_match.group(1) is not None
pred_name = atom_match.group(2)
# Handle both "obj" and "obj:type" formats
obj_names = [
n.strip().split(':')[0]
for n in atom_match.group(2).split(',')
for n in atom_match.group(3).split(',')
]

if pred_name not in pred_map:
Expand All @@ -357,9 +381,16 @@ def _parse_subgoal_annotations(
f"Arity mismatch for {pred_name}: expected "
f"{len(pred.types)}, got {len(objs)}")
continue
atoms.add(GroundAtom(pred, objs))
atom = GroundAtom(pred, objs)
if is_neg:
neg_atoms.add(atom)
else:
pos_atoms.add(atom)

results.append(atoms if atoms else None)
if pos_atoms or neg_atoms:
results.append((pos_atoms, neg_atoms))
else:
results.append(None)

return results

Expand Down Expand Up @@ -424,6 +455,13 @@ def _refine_sketch(
# Sample continuous parameters and ground option
params = self._sample_params(step.option, cur_state, rng)
grounded = step.option.ground(step.objects, params)
# Inject Wait target atoms from sketch annotations
if grounded.name == "Wait":
if step.subgoal_atoms is not None:
grounded.memory["wait_target_atoms"] = step.subgoal_atoms
if step.subgoal_neg_atoms is not None:
grounded.memory["wait_target_neg_atoms"] = \
step.subgoal_neg_atoms
plan[cur_idx] = grounded

state = cur_state
Expand Down Expand Up @@ -547,6 +585,10 @@ def _validate_plan_forward(
grounded.parent)
option_copy = env_param_opt.ground(grounded.objects,
grounded.params.copy())
# Propagate Wait target atoms through re-grounding
for key in ("wait_target_atoms", "wait_target_neg_atoms"):
if key in grounded.memory:
option_copy.memory[key] = grounded.memory[key]

if not option_copy.initiable(state):
logging.info(f"Forward validation: step {i} "
Expand All @@ -558,10 +600,12 @@ def _validate_plan_forward(
# 2. terminate_on_repeat (stuck detection)
# 3. wait_option_terminate_on_atom_change
last_state_ref: List[Optional[State]] = [None]
abstract_fn = lambda s, _p=predicates: utils.abstract(s, _p)

def _terminal( # pylint: disable=cell-var-from-loop
s: State,
oc: _Option = option_copy) -> bool:
oc: _Option = option_copy,
_abs: Callable = abstract_fn) -> bool:
if oc.terminal(s):
return True
prev = last_state_ref[0]
Expand All @@ -572,11 +616,16 @@ def _terminal( # pylint: disable=cell-var-from-loop
f"Option '{oc.name}' got stuck.")
if (CFG.wait_option_terminate_on_atom_change
and oc.name == "Wait"):
cur_atoms = utils.abstract(s, predicates)
prev_atoms = utils.abstract(prev, predicates)
if cur_atoms != prev_atoms:
result = utils.check_wait_target_atoms(oc, s, _abs)
if result is True:
last_state_ref[0] = s
return True
if result is None:
cur_atoms = _abs(s)
prev_atoms = _abs(prev)
if cur_atoms != prev_atoms:
last_state_ref[0] = s
return True
last_state_ref[0] = s
return False

Expand Down
5 changes: 4 additions & 1 deletion predicators/approaches/agent_option_learning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ def _get_agent_system_prompt(self) -> str:
No continuous params.
- `create_move_to_skill(name, types, params_space, config, \
get_target_pose_fn)` — move end-effector to a target pose
- `create_wait_option(name, config, robot_type)` — hold current pose
- `create_wait_option(name, config, robot_type)` — hold current pose; \
annotate with `-> {{atoms}}` in the plan to specify when it should \
terminate (e.g. `Wait(robot:Robot) -> {{Boiled(water:water_type)}}`). \
Use `NOT Pred(...)` for atoms that should become false

All factories (except `create_place_skill` and `create_wait_option`) \
take a `SkillConfig` (available as `skill_config` in the exec \
Expand Down
78 changes: 66 additions & 12 deletions predicators/approaches/agent_planner_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import inspect as _inspect
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, cast
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, \
cast

import dill as pkl
import numpy as np
Expand All @@ -28,9 +29,9 @@
from predicators.explorers.base_explorer import BaseExplorer
from predicators.option_model import _OptionModelBase, create_option_model
from predicators.settings import CFG
from predicators.structs import Action, Dataset, InteractionRequest, \
InteractionResult, LowLevelTrajectory, ParameterizedOption, Predicate, \
State, Task, Type
from predicators.structs import Action, Dataset, GroundAtom, \
InteractionRequest, InteractionResult, LowLevelTrajectory, Object, \
ParameterizedOption, Predicate, State, Task, Type


class AgentPlannerApproach(AgentSessionMixin, BaseApproach):
Expand Down Expand Up @@ -122,8 +123,12 @@ def _get_all_trajectories(self) -> List[LowLevelTrajectory]:
"delayed process (e.g. water filling, dominoes cascading, "
"heating), insert a Wait after it so the effect has time "
"to occur before the next action. The Wait action holds the "
"robot's current pose and terminates once the abstract state "
"changes. Without a Wait, the robot will proceed to the next "
"robot's current pose. You can annotate Wait with target atoms "
"using `-> {atoms}` to specify exactly when it should terminate "
"(e.g. `Wait(robot:Robot)[] -> {Boiled(water:water_type)}`). "
"Use `NOT Pred(...)` for atoms that should become false. "
"If no annotation is provided, the Wait terminates on any atom "
"change. Without a Wait, the robot will proceed to the next "
"action before the delayed effect has occurred, which might "
"cause the plan to fail.")

Expand Down Expand Up @@ -511,10 +516,14 @@ def _build_solve_prompt(self, task: Task) -> str:

After any action whose desired subgoal depends on a delayed process (e.g. water \
filling, dominoes cascading, heating), insert a Wait action to let the process \
complete before proceeding. The Wait terminates once the abstract state changes. \
Without it, the plan will move on before the effect has occurred and fail. Only use \
Wait when there is a genuine delayed effect; do not insert it between actions with \
immediate effects (e.g. Pick, Place).
complete before proceeding. You can annotate Wait with target atoms using \
`-> {{atoms}}` to specify exactly when it should terminate. Use `NOT Pred(...)` for \
atoms that should become false. If no annotation is provided, the Wait terminates on \
any atom change. Only use Wait when there is a genuine delayed effect; do not insert \
it between actions with immediate effects (e.g. Pick, Place).

For Wait with target atoms: `Wait(robot:Robot)[] -> {{Boiled(water:water_type)}}`
For negated targets: `Wait(robot:Robot)[] -> {{NOT Touching(a:block, b:block)}}`

**Important — parameter tuning workflow:**
- When a step fails or produces unexpected results, inspect the rendered images \
Expand Down Expand Up @@ -607,6 +616,37 @@ def _strip_code_fences(text: str) -> str:
lines.pop()
return '\n'.join(lines)

def _parse_wait_annotations(
    self,
    text: str,
    predicates: Set[Predicate],
    objects: Sequence[Object],
) -> List[Tuple[Set[GroundAtom], Set[GroundAtom]]]:
    """Collect the ``-> {atoms}`` targets attached to Wait lines of a plan.

    The returned list has one entry per recognised option line, in the
    order the lines appear in ``text``. A ``Wait`` line that contains
    ``->`` contributes the ``(positive_atoms, negative_atoms)`` pair
    produced by ``utils.parse_wait_target_annotations``; every other
    option line contributes ``(set(), set())``.

    Blank lines are ignored. Lines that do not start with a known option
    name are skipped while no option line has been seen yet, and end the
    scan once at least one option line has been recorded.
    """
    known_names = {opt.name for opt in self._get_all_options()}
    annotations: List[Tuple[Set[GroundAtom], Set[GroundAtom]]] = []

    for raw_line in text.split('\n'):
        candidate = raw_line.strip()
        if candidate == "":
            continue

        # Everything before the first "(" is treated as the option name.
        head = candidate.split('(')[0]

        if head in known_names:
            if head != "Wait" or '->' not in candidate:
                # Non-Wait options and un-annotated Waits carry no targets.
                annotations.append((set(), set()))
                continue
            pos_atoms, neg_atoms = utils.parse_wait_target_annotations(
                candidate, predicates, objects)
            annotations.append((pos_atoms, neg_atoms))
        elif annotations:
            # A non-option line after the first option line ends the plan.
            break

    return annotations

def _parse_and_ground_plan(self, plan_text: str, task: Task) -> list:
"""Parse option plan text and ground into executable options."""
objects = list(task.init)
Expand All @@ -616,8 +656,15 @@ def _parse_and_ground_plan(self, plan_text: str, task: Task) -> list:
# Strip markdown code fences that agents often wrap plans in.
cleaned_text = self._strip_code_fences(plan_text)

# Extract Wait target annotations before stripping them
wait_annotations = self._parse_wait_annotations(
cleaned_text, self._get_all_predicates(), objects)

# Strip annotations so the option plan parser doesn't choke
parseable_text = utils.strip_wait_annotations(cleaned_text)

parsed = utils.parse_model_output_into_option_plan(
cleaned_text,
parseable_text,
objects,
self._types,
all_options,
Expand All @@ -628,10 +675,17 @@ def _parse_and_ground_plan(self, plan_text: str, task: Task) -> list:
f" Available option names: {option_names}")

grounded = []
for option, objs, params in parsed:
for i, (option, objs, params) in enumerate(parsed):
try:
params_arr = np.array(params, dtype=np.float32)
ground_opt = option.ground(objs, params_arr)
# Inject Wait target atoms from annotations
if (ground_opt.name == "Wait" and i < len(wait_annotations)):
pos, neg = wait_annotations[i]
if pos:
ground_opt.memory["wait_target_atoms"] = pos
if neg:
ground_opt.memory["wait_target_neg_atoms"] = neg
grounded.append(ground_opt)
except Exception as e: # pylint: disable=broad-except
logging.warning(
Expand Down
3 changes: 2 additions & 1 deletion predicators/approaches/process_planning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
process_plan,
task.goal,
self._rng,
abstract_function=lambda s: utils.abstract(s, preds))
abstract_function=lambda s: utils.abstract(s, preds),
atoms_seq=atoms_seq)
logging.debug("Current Task Plan:")
for process in process_plan:
logging.debug(process.name)
Expand Down
Loading
Loading