diff --git a/.github/workflows/predicators.yml b/.github/workflows/predicators.yml index 1739fbb188..3dbbd14027 100644 --- a/.github/workflows/predicators.yml +++ b/.github/workflows/predicators.yml @@ -14,8 +14,6 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/setup.py' - run: | pip install -e . pip install pytest-cov==2.12.1 @@ -35,8 +33,6 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/setup.py' - name: Install dependencies run: | pip install -e . @@ -55,8 +51,6 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/setup.py' - name: Install dependencies run: | pip install -e . @@ -77,8 +71,6 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/setup.py' - name: Install dependencies run: | pip install yapf==0.32.0 @@ -101,8 +93,6 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/setup.py' - name: Install dependencies run: | pip install isort==5.10.1 @@ -122,8 +112,6 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/setup.py' - name: Install dependencies run: | pip install docformatter==1.4 diff --git a/.gitignore b/.gitignore index 74eeb654f3..5ee0f47369 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ __pycache__ *.pyc .DS_Store +CLAUDE.md .vscode *.egg-info *.pkl @@ -33,5 +34,7 @@ Gymnasium-Robotics/ predicators/datasets/vlm_input_data_prompts/vision_api/prompt.txt predicators/datasets/vlm_input_data_prompts/vision_api/response.txt +.mypy_cache/ + # Jetbrains IDEs .idea/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000000..65ce0393ac --- 
/dev/null +++ b/CLAUDE.md @@ -0,0 +1,3 @@ +# Instructions + +- Never add Claude as a co-author on git commits (no `Co-Authored-By` lines) diff --git a/README.md b/README.md index 136f5a636a..b2317dfda8 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,12 @@ A simple implementation of search-then-sample bilevel planning is provided in `p Our code assumes that python hashing is deterministic between processes, which is [not true by default](https://stackoverflow.com/questions/30585108/disable-hash-randomization-from-within-python-program). Please make sure to `export PYTHONHASHSEED=0` when running the code. You can add this line to your bash profile, or prepend `export PYTHONHASHSEED=0` to any command line call, e.g., `export PYTHONHASHSEED=0 python predicators/main.py --env ...`. +### Weights & Biases (wandb) Setup +If you want to use wandb for experiment tracking: +* Set `use_wandb = True` in your config or pass `--use_wandb` as a command line argument +* Set your wandb API key: `export WANDB_API_KEY=your_api_key_here` +* Or run `wandb login` to authenticate interactively + ### Locally * (recommended) Make a new virtual env or conda env. * Run, e.g., `python predicators/main.py --env cover --approach oracle --seed 0` to run the system. 
diff --git a/mypy.ini b/mypy.ini index adea295d11..17dba5d4df 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,7 +2,7 @@ strict_equality = True disallow_untyped_calls = True warn_unreachable = True -exclude = (predicators/envs/assets|venv) +exclude = (predicators/envs/assets|venv|prompts) [mypy-predicators.*] disallow_untyped_defs = True diff --git a/predicators/args.py b/predicators/args.py index 537c45fffd..30feaf983a 100644 --- a/predicators/args.py +++ b/predicators/args.py @@ -26,10 +26,13 @@ def create_arg_parser(env_required: bool = True, parser.add_argument("--timeout", default=10, type=float) parser.add_argument("--make_test_videos", action="store_true") parser.add_argument("--make_failure_videos", action="store_true") + parser.add_argument("--make_test_images", action="store_true") + parser.add_argument("--make_failure_images", action="store_true") parser.add_argument("--make_interaction_videos", action="store_true") parser.add_argument("--make_demo_videos", action="store_true") parser.add_argument("--make_demo_images", action="store_true") parser.add_argument("--make_cogman_videos", action="store_true") + parser.add_argument("--video_not_break_on_exception", action="store_true") parser.add_argument("--load_approach", action="store_true") # In the case of online learning approaches, load_approach by itself # will try to load an approach on *every* online learning cycle. @@ -51,4 +54,7 @@ def create_arg_parser(env_required: bool = True, const=logging.DEBUG, default=logging.INFO) parser.add_argument("--crash_on_failure", action="store_true") + parser.add_argument("--excluded_objects_in_state_str", + default="", + type=str) return parser diff --git a/predicators/settings.py b/predicators/settings.py index 8bf8974ad6..d519078ab8 100644 --- a/predicators/settings.py +++ b/predicators/settings.py @@ -20,6 +20,9 @@ class GlobalSettings: # transitions have been collected, whichever happens first. 
num_online_learning_cycles = 10 online_learning_max_transitions = float("inf") + online_learning_early_stopping = False + skip_test_until_last_ite_or_early_stopping = False + online_learning_early_stopping_by_test_solve_rate = False # just for plotting # Maximum number of training tasks to give a demonstration for, if the # offline_data_method is demo-based. max_initial_demos = float("inf") @@ -55,6 +58,7 @@ class GlobalSettings: # either of its arguments is not None. allow_state_allclose_comparison_despite_simulator_state = False + env_include_bbox_features = False # cover_multistep_options env parameters cover_multistep_action_limits = [-np.inf, np.inf] cover_multistep_degenerate_oracle_samplers = False @@ -64,6 +68,7 @@ class GlobalSettings: cover_multistep_bhr_percent = 0.4 # block hand region percent of width cover_multistep_bimodal_goal = False cover_multistep_goal_conditioned_sampling = False # assumes one goal + cover_blocks_change_color_when_cover = False # bumpy cover env parameters bumpy_cover_num_bumps = 2 @@ -80,6 +85,32 @@ class GlobalSettings: blocks_num_blocks_test = [5, 6] blocks_holding_goals = False blocks_block_size = 0.045 # use 0.0505 for real with panda + blocks_high_towers_are_unstable = False + + # balance env parameters + balance_num_blocks_train = [2, 4] + balance_num_blocks_test = [4, 6] + # balance_num_blocks_test = [2] + balance_holding_goals = False + balance_block_size = 0.045 # use 0.0505 for real with panda + balance_wierd_balance = False + + # grow env parameters + grow_use_skill_factories = True # Use skill-factory-based option implementations + grow_plant_same_color_as_cup = False + grow_weak_pour_terminate_condition = False + grow_place_option_no_sampler = False + grow_num_cups_train = [2] + grow_num_cups_test = [2, 3] + grow_num_jugs_train = [2] + grow_num_jugs_test = [2] + + # laser env parameters + laser_zero_reflection_angle = False + laser_use_debug_line_for_beams = False + + # ants env params + 
ants_ants_attracted_to_points = False # playroom env parameters playroom_num_blocks_train = [3] @@ -150,6 +181,7 @@ class GlobalSettings: pybullet_birrt_num_iters = 100 pybullet_birrt_smooth_amt = 50 pybullet_birrt_extend_num_interp = 10 + pybullet_birrt_path_subsample_ratio = 1 pybullet_control_mode = "position" pybullet_max_vel_norm = 0.05 # env -> robot -> quaternion @@ -157,6 +189,7 @@ class GlobalSettings: # Fetch and Panda gripper down and parallel to x-axis by default. lambda: { "fetch": (0.5, -0.5, -0.5, -0.5), + "mobile_fetch": (0.5, -0.5, -0.5, -0.5), "panda": (0.7071, 0.7071, 0.0, 0.0), }, # In Blocks, Fetch gripper down since it's thin we don't need to @@ -164,9 +197,16 @@ class GlobalSettings: { "pybullet_blocks": { "fetch": (0.7071, 0.0, -0.7071, 0.0), + "mobile_fetch": (0.7071, 0.0, -0.7071, 0.0), + "panda": (0.7071, 0.7071, 0.0, 0.0), + }, + "pybullet_balance": { + "fetch": (0.7071, 0.0, -0.7071, 0.0), + "mobile_fetch": (0.7071, 0.0, -0.7071, 0.0), "panda": (0.7071, 0.7071, 0.0, 0.0), } }) + pybullet_ik_validate = True # IKFast parameters ikfast_max_time = 0.05 @@ -315,10 +355,26 @@ class GlobalSettings: exit_garage_motion_planning_ignore_obstacles = False exit_garage_raise_environment_failure = False + # skill phase parameters + skill_phase_use_motion_planning = False + # coffee env parameters coffee_num_cups_train = [1, 2] coffee_num_cups_test = [2, 3] coffee_jug_init_rot_amt = 2 * np.pi / 3 + coffee_rotated_jug_ratio = 0.5 + coffee_twist_sampler = True + coffee_combined_move_and_twist_policy = False + coffee_move_back_after_place_and_push = False + coffee_jug_pickable_pred = False + coffee_render_grid_world = False + coffee_simple_tasks = False + coffee_machine_have_light_bar = True + coffee_machine_has_plug = False + coffee_use_pixelated_jug = False + coffee_plug_break_after_plugged_in = False + coffee_fill_jug_gradually = False + coffee_use_skill_factories = True # Use skill-factory-based option implementations # satellites env parameters 
satellites_num_sat_train = [2, 3] @@ -354,6 +410,34 @@ class GlobalSettings: # grid row env parameters grid_row_num_cells = 100 + # float + float_water_level_doesnt_raise = False + + # domino + domino_debug_layout = False + domino_some_dominoes_are_connected = False + domino_initialize_at_finished_state = True + domino_use_domino_blocks_as_target = False + domino_use_grid = False + domino_include_connected_predicate = False + domino_has_glued_dominos = True + domino_prune_actions = False # Set to True to enable action pruning + domino_only_straight_sequence_in_training = True # Generate only straight sequences during training + domino_train_num_dominos = [2] + domino_test_num_dominos = [3] + domino_train_num_targets = [1] + domino_test_num_targets = [1, 2] + domino_train_num_pivots = [0] + domino_test_num_pivots = [0] + domino_train_num_pos_x = 3 + domino_train_num_pos_y = 2 + domino_test_num_pos_x = 4 # 5 is too large for robot to reach sometimes + domino_test_num_pos_y = 3 + domino_oracle_knows_glued_dominos = False + domino_use_continuous_place = False # Use PlaceContinuous option instead of Place + domino_restricted_push = False # When True, Push only targets the start block (no domino arg) + domino_use_skill_factories = True # Use skill_factories-based option implementations + # burger env parameters burger_render_set_of_marks = True # Which type of train/test tasks to generate. Options are "more_stacks", @@ -365,6 +449,58 @@ class GlobalSettings: # Number of test tasks where you start out holding a patty. 
burger_num_test_start_holding = 5 + # circuit + circuit_light_doesnt_need_battery = False + circuit_battery_in_box = False + + # fan env + fan_use_skill_factories = True # Use skill-factory-based option implementations + fan_fans_blow_opposite_direction = False + fan_known_controls_relation = True + fan_combine_switch_on_off = False + fan_use_kinematic = False + fan_train_num_pos_x = 3 + fan_train_num_pos_y = 3 + fan_test_num_pos_x = 6 # can do 9 + fan_test_num_pos_y = 4 + fan_train_num_walls_per_task = [1] + fan_test_num_walls_per_task = [2, 3] # can do 4 + + # domino_fan env (combined domino + fan environment) + domino_domino_on_stairs = False + domino_fan_use_grid = True + domino_fan_train_num_dominos = [3, 4] + domino_fan_test_num_dominos = [5, 6] + domino_fan_train_num_targets = [1] + domino_fan_test_num_targets = [1, 2] + domino_fan_train_num_walls = [2, 3] + domino_fan_test_num_walls = [3, 4] + domino_fan_train_grid_size = (5, 5) + domino_fan_test_grid_size = (6, 6) + domino_fan_ball_task_ratio = 0.5 # Fraction of tasks with ball goals vs domino goals + domino_fan_include_ball_in_domino_tasks = True # Include ball in domino tasks (as obstacle) + domino_fan_include_dominoes_in_ball_tasks = False # Include dominoes in ball tasks + domino_fan_ball_position_tolerance = 0.04 # Tolerance for ball reaching target + domino_fan_use_kinematic = True # Use kinematic ball movement (vs dynamic forces) + domino_fan_has_glued_dominoes = False # Include immovable glued dominoes + + # boil env + boil_use_skill_factories = True # Use skill-factory-based option implementations + boil_use_constant_delay = False + boil_use_normal_delay = True + boil_use_cmp_delay = False + boil_goal = "simple" # Can also be "task_completed", "human_happy" + boil_goal_simple_human_happy = False # Require a simpler condition for human happy + boil_use_derived_predicates = True + boil_require_jug_full_to_heatup = False + boil_goal_require_burner_off = True + boil_add_jug_reached_capacity_predicate 
= False + boil_num_jugs_train = [1] + boil_num_jugs_test = [1, 2] + boil_num_burner_train = [1] + boil_num_burner_test = [1] + boil_water_fill_speed = 0.002 + # parameters for random options approach random_options_max_tries = 100 @@ -420,6 +556,8 @@ class GlobalSettings: llm_model_name = "text-curie-001" # "text-davinci-002" llm_temperature = 0.5 llm_num_completions = 1 + # supported provider: "google", "openai", or "openrouter" + pretrained_model_service_provider = "openai" # parameters for vision language models # gemini-1.5-pro-latest, gpt-4-turbo, gpt-4o @@ -432,11 +570,24 @@ class GlobalSettings: # parameters for the vlm_open_loop planning approach vlm_open_loop_use_training_demos = False + vlm_open_loop_no_image = False # Use object-centric state + + # parameters for the human_interaction_approach + human_interaction_approach_use_scripted_option = False + human_interaction_approach_use_all_options = False + scripted_option_dir = "scripted_options" + script_option_file_name = "scripted_plan.txt" + + # parameters for the human_low_level_control_approach + # Note: actual movement is limited by pybullet_max_vel_norm (default 0.05) + # For faster response, also increase pybullet_max_vel_norm + human_control_move_speed = 0.15 # meters per step (target delta) + human_control_rot_speed = 0.2 # radians per step # SeSamE parameters sesame_task_planner = "astar" # "astar" or "fdopt" or "fdsat" sesame_task_planning_heuristic = "lmcut" - sesame_allow_noops = True # recommended to keep this False if using replays + sesame_allow_waits = True # recommended to keep this False if using replays sesame_check_expected_atoms = True sesame_use_necessary_atoms = True sesame_use_visited_state_set = False @@ -459,6 +610,9 @@ class GlobalSettings: # observed states match (at the abstract level) the expected states, and # replan if not. But for now, we just execute each step without checking. 
bilevel_plan_without_sim = False + planning_filter_unreachable_nsrt = True + planning_check_dr_reachable = True + no_repeated_arguments_in_grounding = False # evaluation parameters log_dir = "logs" @@ -470,6 +624,10 @@ class GlobalSettings: image_dir = "images" video_fps = 2 failure_video_mode = "longest_only" + terminate_on_goal_reached = True + keep_failed_demos = False # For saving videos + terminate_on_goal_reached_and_option_terminated = False + env_has_impossible_goals = False # dataset parameters # For learning-based approaches, the data collection timeout for planning. @@ -495,6 +653,7 @@ class GlobalSettings: # STRIPS learning algorithm. See get_name() functions in the directory # nsrt_learning/strips_learning/ for valid settings. strips_learner = "cluster_and_intersect" + clustering_learner_check_effect_equality = True disable_harmlessness_check = False # some methods may want this to be True enable_harmless_op_pruning = False # some methods may want this to be True precondition_soft_intersection_threshold_percent = 0.8 # between 0 and 1 @@ -510,6 +669,7 @@ class GlobalSettings: cluster_and_search_score_func_max_groundings = 10000 cluster_and_search_var_count_weight = 0.1 cluster_and_search_precon_size_weight = 0.01 + cluster_and_search_llm_propose_batch_size = 4 cluster_and_intersect_prune_low_data_pnads = False # If cluster_and_intersect_prune_low_data_pnads is set to True, PNADs must # have at least this fraction of the segments produced by the option that is @@ -517,10 +677,61 @@ class GlobalSettings: # learning. 
cluster_and_intersect_min_datastore_fraction = 0.0 cluster_and_intersect_soft_intersection_for_preconditions = False + find_best_matching_pnad_skip_if_effect_not_subset = True + exogenous_process_learner = "cluster_and_intersect" + exogenous_process_learner_do_intersect = False + only_learn_exogenous_processes = False + learn_process_parameters = False + use_empirical_init_for_vi_params = False + pause_after_process_learning_for_inspection = False + learnable_delay_distribution = "cmp" # "constant", "cmp", "normal" + process_learner_check_false_positives = False + cluster_and_search_process_learner_parallel_condition = True + cluster_and_search_process_learner_parallel_pnad = False + process_learner_ablate_bayes = False + cluster_and_search_process_learner_llm_select_condition = False + cluster_and_search_process_learner_llm_rank_atoms = False + cluster_and_search_process_learner_llm_propose_top_conditions = False + process_learner_llm_atom_ranking_max_atoms = 10 + process_learner_llm_propose_conditions_k = 5 + cluster_and_search_vi_steps = 200 + cluster_search_max_workers = -1 + cluster_and_inverse_planning_candidates = "top_consistent" # "all", "top_consistent" + cluster_and_inverse_planning_top_consistent_method = "percentage" # "number", "percentage", "cost", "percentage_cost" + cluster_and_inverse_planning_top_consistent_num = -1 + cluster_and_inverse_planning_top_p_percent = 3 # percentage of top consistent candidates to use + cluster_and_inverse_planning_top_consistent_max_cost = 3 + cluster_process_learner_top_n_conditions = -1 + process_scoring_method = "data_likelihood" # "count_fp", "data_likelihood" + process_condition_search_complexity_weight = 1e-4 + process_param_learning_num_steps = 200 + process_param_learning_use_empirical = False + process_param_learning_patience = None + process_param_learning_batch_size = 16 + process_learning_use_empirical = False + process_condition_search_prune_with_fp_count = False + process_learning_learn_strength = True + 
process_learning_process_per_physical_core = True # Physical core vs logical core + process_learning_init_at_previous_results = False # Loading hasn't been very helpful + predicate_invent_neural_symbolic_predicates = False + predicate_invent_invent_derived_predicates = False + cluster_learning_one_effect_per_process = False + use_derived_predicate_in_heuristic = True + process_planning_heuristic_weight = 1.0 + build_exogenous_process_index_for_planning = True + process_planning_use_abstract_policy = False + process_planning_max_policy_guided_rollout = 10 + process_planning_set_parameters_one = False + process_task_planning_heuristic = 'h_ff' + wait_option_terminate_on_atom_change = True + running_no_invent_baseline = False # torch GPU usage setting use_torch_gpu = False + # wandb logging setting + use_wandb = False + # torch model parameters learning_rate = 1e-3 weight_decay = 0 @@ -581,6 +792,9 @@ class GlobalSettings: # online NSRT learning parameters online_nsrt_learning_requests_per_cycle = 10 online_learning_max_novelty_count = 0 + online_nsrt_learning_number_of_tasks_to_try = 1 + online_nsrt_learning_requests_per_task = 3 + online_learning_assert_no_exclude_pred = True # active sampler learning parameters active_sampler_learning_model = "myopic_classifier_mlp" @@ -605,6 +819,7 @@ class GlobalSettings: # maple q function parameters use_epsilon_annealing = True min_epsilon = 0.05 + maple_q_same_hla_option_param_space = True # skill competence model parameters skill_competence_model = "optimistic" @@ -657,11 +872,18 @@ class GlobalSettings: active_sampler_explorer_skip_perfect = True active_sampler_learning_init_cycles_to_pursue_goal = 1 + bilevel_planning_explorer_enumerate_plans = False + + exploit_bilevel_planning_explorer_fallback_explorer = "RandomOptions" + # grammar search invention parameters + grammar_search_grammar_use_single_feature = True grammar_search_grammar_includes_givens = True + grammar_search_grammar_includes_negation = True 
grammar_search_grammar_includes_foralls = True grammar_search_grammar_use_diff_features = False grammar_search_grammar_use_euclidean_dist = False + grammar_search_grammar_use_skip_grammar = True grammar_search_use_handcoded_debug_grammar = False grammar_search_forall_penalty = 1 grammar_search_pred_selection_approach = "score_optimization" @@ -675,6 +897,7 @@ class GlobalSettings: grammar_search_predicate_cost_upper_bound = 6 grammar_search_prune_redundant_preds = True grammar_search_score_function = "expected_nodes_created" + grammar_search_additional_bonus_for_matching_plan = 0 grammar_search_heuristic_based_weight = 10. grammar_search_max_demos = float("inf") grammar_search_max_nondemos = 50 @@ -690,8 +913,8 @@ class GlobalSettings: grammar_search_expected_nodes_upper_bound = 1e5 grammar_search_expected_nodes_optimal_demo_prob = 1 - 1e-5 grammar_search_expected_nodes_backtracking_cost = 1e3 - grammar_search_expected_nodes_allow_noops = True - grammar_search_classifier_pretty_str_names = ["?x", "?y", "?z"] + grammar_search_expected_nodes_allow_waits = True + grammar_search_classifier_pretty_str_names = ["?x", "?y", "?z", "?w"] grammar_search_vlm_atom_proposal_prompt_type = \ "options_labels_whole_traj_diverse" grammar_search_vlm_atom_label_prompt_type = "per_scene_naive" @@ -700,6 +923,7 @@ class GlobalSettings: grammar_search_select_all_debug = False grammar_search_invent_geo_predicates_only = False grammar_search_early_termination_heuristic_thresh = 0.0 + grammar_search_recognizing_unsolvable_goals_bonus = 1000 # grammar search clustering algorithm parameters grammar_search_clustering_gmm_num_components = 10 @@ -716,6 +940,50 @@ class GlobalSettings: vlm_test_time_atom_label_prompt_type = "per_scene_naive" # Whether or not to save eval trajectories save_eval_trajs = True + rgb_observation = False + render_init_state = False + use_counterfactual_dataset_path_name = False + use_classification_problem_setting = False + classification_has_counterfactual_support = 
True + + # dino similarity approach + dino_model_name = "dinov2_vits14" + distance_function = "dtw" + + # vlm predicate invention parameters + vlm_predicator_oracle_base_predicates = False + vlm_predicator_oracle_learned_predicates = False + vlm_predicator_use_grammar = True + vlm_predicator_num_proposal_batches = 1 + + # agent SDK online abstraction learning parameters + agent_sdk_model_name = "claude-sonnet-4-6" + agent_sdk_max_agent_turns_per_iteration = 20 + agent_sdk_agent_timeout = 300 # seconds per iteration + agent_sdk_resume_session = True # resume previous session if available + agent_sdk_propose_types = True + agent_sdk_propose_predicates = True + agent_sdk_propose_objects = True + agent_sdk_propose_processes = True + agent_sdk_propose_options = True + agent_sdk_auto_select_predicates = True # run hill-climbing after proposals + agent_sdk_max_trajectories_in_context = 3 + agent_sdk_log_agent_responses = True + + # Sandbox settings for agent SDK + agent_sdk_use_docker_sandbox = False # run agent inside Docker container + agent_sdk_docker_image = "predicators-sandbox" # Docker image name + agent_sdk_use_local_sandbox = False # sandbox dir with built-in tools, no Docker + + # Agent explorer settings + agent_explorer_max_turns = 5 # max agent turns per exploration query + agent_explorer_fallback_to_random = True # fall back to random on failure + + # Agent planner approach settings + agent_planner_isolate_test_session = True + agent_planner_use_scratchpad = False # include notes.md scratchpad + agent_planner_use_visualize_state = False # include visualize_state tool + agent_planner_use_annotate_scene = False # include annotate_scene tool @classmethod def get_arg_specific_settings(cls, args: Dict[str, Any]) -> Dict[str, Any]: @@ -737,6 +1005,17 @@ def get_arg_specific_settings(cls, args: Dict[str, Any]) -> Dict[str, Any]: # tasks take more actions to complete. 
"pybullet_cover": 1000, "pybullet_blocks": 1000, + "pybullet_coffee": 2000, + "pybullet_balance": 2000, + "pybullet_grow": 2000, + "pybullet_circuit": 2000, + "pybullet_float": 2000, + "pybullet_domino_grid": 2000, + "pybullet_laser": 2000, + "pybullet_ants": 2000, + "pybullet_fan": 2000, + "pybullet_switch": 2000, + "pybullet_barrier": 2000, "doors": 1000, "coffee": 1000, "kitchen": 1000, @@ -754,6 +1033,8 @@ def get_arg_specific_settings(cls, args: Dict[str, Any]) -> Dict[str, Any]: # For the stick button environment, limit the per-option # horizon. "stick_button": 50, + "pybullet_switch": 2000, + "pybullet_barrier": 2000, })[args.get("env", "")], # In SeSamE, when to propagate failures back up to the high level @@ -794,6 +1075,8 @@ def get_arg_specific_settings(cls, args: Dict[str, Any]) -> Dict[str, Any]: { # For these environments, allow more skeletons. "coffee": 1000, + "pybullet_coffee": 100, + "pybullet_coffee_pixel": 100, "exit_garage": 1000, "tools": 1000, "stick_button": 1000, diff --git a/predicators/structs.py b/predicators/structs.py index d7a107e501..2ff2ed4931 100644 --- a/predicators/structs.py +++ b/predicators/structs.py @@ -5,16 +5,19 @@ import abc import copy import itertools -from dataclasses import dataclass, field +import random +from dataclasses import dataclass, field, replace from functools import cached_property, lru_cache from typing import Any, Callable, Collection, DefaultDict, Dict, Iterator, \ List, Optional, Sequence, Set, Tuple, TypeVar, Union, cast import numpy as np import PIL.Image +import torch from gym.spaces import Box from numpy.typing import NDArray from tabulate import tabulate +from torch import Tensor import predicators.pretrained_model_interface import predicators.utils as utils # pylint: disable=consider-using-from-import @@ -23,10 +26,27 @@ @dataclass(frozen=True, order=True) class Type: - """Struct defining a type.""" + """Struct defining a type. 
+ + sim_feature_names are features stored in an object, and usually + won't change throughout and across tasks. An example is the object's + pybullet id. + This is convenient for variables that are not easily extractable from the + sim state -- whether a food block attracts ants, or the joint id for a + switch -- but are nonetheless for running the simulation. + + Why not store all features here instead of storing in the State object? + They can only store one value per feature, so if we generate 10 tasks where + the blocks are at different locations, it won't be able to store all 10 + locations. One might think they could reset any feature at when reset is + called. But this would require the information is first stored in the State + object. + """ name: str feature_names: Sequence[str] = field(repr=False) parent: Optional[Type] = field(default=None, repr=False) + sim_features: Sequence[str] = field(default_factory=lambda: ["id"], + repr=False) @property def dim(self) -> int: @@ -43,6 +63,17 @@ def get_ancestors(self) -> Set[Type]: curr_type = curr_type.parent return ancestors_set + def pretty_str(self) -> str: + """Display the type in a nice human-readable format.""" + formatted_features = [f"'{name}'" for name in self.feature_names] + return f"{self.name}: {{{', '.join(formatted_features)}}}" + + def python_definition_str(self) -> str: + """Display in a format similar to how a type is instantiated.""" + formatted_features = [f"'{name}'" for name in self.feature_names] + return f"_{self.name}_type = Type('{self.name}', "+\ + f"[{', '.join(formatted_features)}])" + def __call__(self, name: str) -> _TypedEntity: """Convenience method for generating _TypedEntities.""" if name.startswith("?"): @@ -53,7 +84,7 @@ def __hash__(self) -> int: return hash((self.name, tuple(self.feature_names))) -@dataclass(frozen=True, order=True, repr=False) +@dataclass(frozen=False, order=True, repr=False) class _TypedEntity: """Struct defining an entity with some type, either an object 
(e.g., block3) or a variable (e.g., ?block). @@ -88,21 +119,68 @@ def is_instance(self, t: Type) -> bool: return False -@dataclass(frozen=True, order=True, repr=False) +@dataclass(frozen=False, order=True, repr=False) class Object(_TypedEntity): """Struct defining an Object, which is just a _TypedEntity whose name does not start with "?".""" + sim_data: Dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: assert not self.name.startswith("?") + # Initialize sim_data from the Type's sim_features + for sim_feature in self.type.sim_features: + self.sim_data[sim_feature] = None # Default to None + # Keep track of allowed attributes + self._allowed_attributes = {"sim_data"}.union(self.sim_data.keys()) + + def __getattr__(self, name: str) -> Any: + # Bypass custom logic for internal attributes + # Use object.__getattribute__(...) instead of self.sim_data + sim_data = object.__getattribute__(self, "sim_data") + if name in sim_data: + return sim_data[name] + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'") + + def __setattr__(self, name: str, value: Any) -> None: + # Always allow the dataclass fields (e.g., "name", "type", "sim_data"). + if name in {"name", "type", "sim_data", "_allowed_attributes"}: + super().__setattr__(name, value) + return + + # For anything else, check _allowed_attributes. + allowed_attrs = object.__getattribute__(self, "_allowed_attributes") \ + if object.__getattribute__(self, "__dict__").get( + "_allowed_attributes") else set() + if name in allowed_attrs: + sim_data = object.__getattribute__(self, "sim_data") + if name in sim_data: + sim_data[name] = value + else: + super().__setattr__(name, value) + else: + raise AttributeError(f"Cannot set unknown attribute '{name}'") def __hash__(self) -> int: # By default, the dataclass generates a new __hash__ method when # frozen=True and eq=True, so we need to override it. 
return self._hash + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Object): + return False + return self.name == other.name and self.type == other.type + + @cached_property + def id_name(self) -> str: + try: + assert self.id is not None, "Object must have an id set to use id_name" + except: + breakpoint() + return f"{self.type.name}{self.id}" -@dataclass(frozen=True, order=True, repr=False) + +@dataclass(frozen=False, order=True, repr=False) class Variable(_TypedEntity): """Struct defining a Variable, which is just a _TypedEntity whose name starts with "?".""" @@ -129,6 +207,20 @@ def __post_init__(self) -> None: for obj in self: assert len(self[obj]) == obj.type.dim + def __hash__(self) -> int: + # Hash object keys and array contents using numpy's built-in hashing + items = [] + for obj in sorted(self.data.keys()): + arr = self.data[obj] + if hasattr(arr, 'tobytes'): + # For numpy arrays, hash the bytes representation + items.append((obj, hash(arr.tobytes()))) + else: + items.append((obj, hash(tuple(arr)))) + + data_hash = hash(tuple(items)) + return data_hash + def __iter__(self) -> Iterator[Object]: """An iterator over the state's objects, in sorted order.""" return iter(sorted(self.data)) @@ -215,16 +307,45 @@ def pretty_str(self) -> str: suffix = "\n" + "#" * ll + "\n" return prefix + "\n\n".join(table_strs) + suffix - def dict_str(self, indent: int = 0, object_features: bool = True) -> str: + def dict_str( + self, + indent: int = 0, + object_features: bool = True, + num_decimal_points: int = 2, + use_object_id: bool = False, + ignored_features: List[str] = ["capacity_liquid", + "target_liquid"]) -> str: """Return a dictionary representation of the state.""" + excluded_objects = [] + if CFG.excluded_objects_in_state_str: + excluded_objects = CFG.excluded_objects_in_state_str.split(",") state_dict = {} + + # Collect all unique types from objects in the state + object_types = set() for obj in self: - obj_dict = {} - if obj.type.name == "robot" 
or object_features: - for attribute, value in zip(obj.type.feature_names, self[obj]): - obj_dict[attribute] = value - obj_name = obj.name - state_dict[f"{obj_name}:{obj.type.name}"] = obj_dict + object_types.add(obj.type) + + # Iterate through types and add all objects of each type + for obj_type in sorted(object_types, key=lambda t: t.name): + obj_type_name = obj_type.name + if obj_type_name not in excluded_objects: + # Get all objects of this type + objects_of_type = self.get_objects(obj_type) + + # Process each object of this type + for obj in objects_of_type: + obj_dict = {} + if obj_type_name == "robot" or object_features: + for attribute, value in zip(obj.type.feature_names, + self[obj]): + if attribute not in ignored_features: + obj_dict[attribute] = value + if use_object_id: + obj_name = obj.id_name + else: + obj_name = obj.name + state_dict[f"{obj_name}:{obj.type.name}"] = obj_dict # Create a string of n_space spaces spaces = " " * indent @@ -233,7 +354,16 @@ def dict_str(self, indent: int = 0, object_features: bool = True) -> str: dict_str = spaces + "{" n_keys = len(state_dict.keys()) for i, (key, value) in enumerate(state_dict.items()): - value_str = ', '.join(f"'{k}': {v}" for k, v in value.items()) + # Format values in the string representation + formatted_items = [] + for k, v in value.items(): + if isinstance(v, (float, np.floating)): + formatted_items.append( + f"'{k}': {v:.{num_decimal_points}f}") + else: + formatted_items.append(f"'{k}': {v}") + value_str = ', '.join(formatted_items) + if i == 0: dict_str += f"'{key}': {{{value_str}}},\n" elif i == n_keys - 1: @@ -257,6 +387,9 @@ class Predicate: # treated "specially" by the classifier. 
_classifier: Callable[[State, Sequence[Object]], bool] = field(compare=False) + natural_language_assertion: Optional[Callable[[List[str]], + str]] = field(default=None, + compare=False) def __call__(self, entities: Sequence[_TypedEntity]) -> _Atom: """Convenience method for generating Atoms.""" @@ -279,6 +412,18 @@ def _hash(self) -> int: def __hash__(self) -> int: return self._hash + def __eq__(self, other: Predicate) -> bool: # type: ignore[override] + # equal by name + assert isinstance(other, Predicate) + if self.name != other.name: + return False + if len(self.types) != len(other.types): + return False + for self_type, other_type in zip(self.types, other.types): + if self_type != other_type: + return False + return True + @cached_property def arity(self) -> int: """The arity of this predicate (number of arguments).""" @@ -320,6 +465,23 @@ def pretty_str(self) -> Tuple[str, str]: body_str = f"{self.name}({vars_str_no_types})" return vars_str, body_str + def pretty_str_with_assertion(self) -> str: + var_names = [] + vars_str = [] + for i, t in enumerate(self.types): + vars_str.append( + f"{CFG.grammar_search_classifier_pretty_str_names[i]}:{t.name}" + ) + var_names.append( + f"{CFG.grammar_search_classifier_pretty_str_names[i]}") + vars_str = ", ".join(vars_str) # type: ignore[assignment] + + body_str = f"{self.name}({vars_str})" + if hasattr(self, "natural_language_assertion") and\ + self.natural_language_assertion is not None: + body_str += f": {self.natural_language_assertion(var_names)}" + return body_str + def pddl_str(self) -> str: """Get a string representation suitable for writing out to a PDDL file.""" @@ -342,6 +504,89 @@ def _negated_classifier(self, state: State, def __lt__(self, other: Predicate) -> bool: return str(self) < str(other) + def __reduce__(self) -> Tuple: + """Tell pickle/dill how to re-create a Predicate: + + (constructor, (name, types, classifier)) + """ + # • `tuple(self.types)` ensures the sequence itself is picklable + # • 
`_classifier` must be a top-level def or otherwise dill-pickleable + return (self.__class__, (self.name, tuple(self.types), + self._classifier)) + + +@dataclass(frozen=True, order=False, repr=False) +class DerivedPredicate(Predicate): + """Struct defining a concept predicate.""" + name: str + types: Sequence[Type] + # The classifier takes in a complete state and a sequence of objects + # representing the arguments. These objects should be the only ones + # treated "specially" by the classifier. + _classifier: Callable[[Set[GroundAtom], Sequence[Object]], + bool] = field(compare=False) + untransformed_predicate: Optional[Predicate] = field(default=None, + compare=False) + auxiliary_predicates: Optional[Set[Predicate]] = field(default=None, + compare=False) + + def update_auxiliary_concepts( + self, + auxiliary_predicates: Set[DerivedPredicate]) -> DerivedPredicate: + """Create a new ConceptPredicate with updated auxiliary_concepts.""" + return replace(self, auxiliary_predicates=auxiliary_predicates + ) # type: ignore[arg-type] + + @cached_property + def _hash(self) -> int: + # Make the hash the same regardless types is a list or tuple. + return hash(self.name + " ".join(t.name for t in self.types)) + + def __hash__(self) -> int: + return self._hash + + def __eq__(self, other: Predicate) -> bool: # type: ignore[override] + # equal by name + assert isinstance(other, Predicate) + if self.name != other.name: + return False + if len(self.types) != len(other.types): + return False + for self_type, other_type in zip(self.types, other.types): + if self_type != other_type: + return False + return True + + def holds(self, state: Set[GroundAtom], + objects: Sequence[Object]) -> bool: # type: ignore[override] + """Public method for calling the classifier. + + Performs type checking first. 
+ """ + assert len(objects) == self.arity + for obj, pred_type in zip(objects, self.types): + assert isinstance(obj, Object) + assert obj.is_instance(pred_type) + return self._classifier(state, objects) + + def _negated_classifier( + self, + state: Set[GroundAtom], # type: ignore[override] + objects: Sequence[Object]) -> bool: + # Separate this into a named function for pickling reasons. + return not self._classifier(state, objects) + + def __reduce__(self) -> Tuple: + """Tell pickle/dill how to re-create a DerivedPredicate: + + (constructor, (name, types, classifier)) + """ + # • `tuple(self.types)` ensures the sequence itself is picklable + # • `_classifier` must be a top-level def or otherwise dill-pickleable + return (self.__class__, + (self.name, tuple(self.types), self._classifier, + self.untransformed_predicate, self.auxiliary_predicates)) + @dataclass(frozen=True, order=False, repr=False, eq=False) class VLMPredicate(Predicate): @@ -352,7 +597,117 @@ class VLMPredicate(Predicate): classifier (i.e., one that returns simply raises some kind of error instead of actually outputting a value of any kind). 
""" - get_vlm_query_str: Callable[[Sequence[Object]], str] + get_vlm_query_str: Optional[Callable[[Sequence[Object]], + str]] = field(default=None) + + +class NSPredicate(Predicate): + """Neuro-Symbolic Predicate.""" + + def __init__( + self, name: str, types: Sequence[Type], + _classifier: Callable[[RawState, Sequence[Object]], bool] + ) -> None: # type: ignore[name-defined] + self._original_classifier = _classifier + super().__init__(name, types, _MemoizedClassifier(_classifier)) + + @cached_property + def _hash(self) -> int: + # return hash(str(self)) + return hash(self.name + str(self.types)) + + def __hash__(self) -> int: + return self._hash + + def classifier_str(self) -> str: + """Get a string representation of the classifier.""" + clf_str = getsource( + self._original_classifier) # type: ignore[name-defined] + clf_str = textwrap.dedent(clf_str) # type: ignore[name-defined] + clf_str = clf_str.replace("@staticmethod\n", "") + return clf_str + + +@dataclass +class _MemoizedClassifier(): + classifier: Callable[[State, Sequence[Object]], + Union[bool, VLMQuery]] # type: ignore[name-defined] + cache: Dict = field(default_factory=dict) + + def cache_truth_value(self, state: State, objects: Sequence[Object], + truth_value: bool) -> None: + """Cache the boolean value after querying the VLM and obtaining the + result.""" + combined_hash = self.hash_state_objs(state, objects) + self.cache[combined_hash] = truth_value + + def hash_state_objs(self, state: State, objects: Sequence[Object]) -> int: + objects_tuple_hash = hash(tuple(objects)) + state_hash = state.__hash__() + return hash((state_hash, objects_tuple_hash)) + + def has_classified(self, state: State, objects: Sequence[Object]) -> bool: + """Check if the state, object pair has been stored in the cache.""" + combined_hash = self.hash_state_objs(state, objects) + return combined_hash in self.cache + + def __call__(self, state: State, objects: Sequence[Object]) -> \ + Union[bool, VLMQuery]: # type: 
ignore[name-defined] + """When the classifier is called, return the cached value if it exists + otherwise call self.classifier.""" + # if state, object exist in cache, return the value + # else compute the truth value using the classifier + combined_hash = self.hash_state_objs(state, objects) + return self.cache.get(combined_hash, self.classifier(state, objects)) + + +@dataclass(frozen=True, order=False, repr=False) +class ConceptPredicate(Predicate): + """Struct defining a concept predicate.""" + name: str + types: Sequence[Type] + # The classifier takes in a complete state and a sequence of objects + # representing the arguments. These objects should be the only ones + # treated "specially" by the classifier. + _classifier: Callable[[Set[GroundAtom], Sequence[Object]], + bool] = field(compare=False) + untransformed_predicate: Optional[Predicate] = field(default=None, + compare=False) + auxiliary_concepts: Optional[Set[ConceptPredicate]] = field(default=None, + compare=False) + + def update_auxiliary_concepts( + self, + auxiliary_concepts: Set[ConceptPredicate]) -> ConceptPredicate: + """Create a new ConceptPredicate with updated auxiliary_concepts.""" + return replace(self, auxiliary_concepts=auxiliary_concepts) + + @cached_property + def _hash(self) -> int: + # return hash(str(self)) + return hash(self.name + str(self.types)) + + def __hash__(self) -> int: + return self._hash + + def holds(self, state: Set[GroundAtom], + objects: Sequence[Object]) -> bool: # type: ignore[override] + """Public method for calling the classifier. + + Performs type checking first. + """ + assert len(objects) == self.arity + for obj, pred_type in zip(objects, self.types): + assert isinstance(obj, Object) + assert obj.is_instance(pred_type) + return self._classifier(state, objects) + + def _negated_classifier( + self, + state: Set[GroundAtom], # type: ignore[override] + objects: Sequence[Object]) -> bool: + # Separate this into a named function for pickling reasons. 
+ return not self._classifier(state, objects) @dataclass(frozen=True, repr=False, eq=False) @@ -406,6 +761,18 @@ def __lt__(self, other: object) -> bool: assert isinstance(other, _Atom) return str(self) < str(other) + def __reduce__(self) -> Tuple: + """Return a pickling recipe: call the class with (predicate, entities). + + - This ensures that when the object is unpickled, all dataclass fields + (predicate, entities) are set before anything like hashing or + stringification is triggered. + - This prevents errors where e.g. self.predicate does not exist yet at + the time __hash__ or __str__ is called during deserialization (which is + exactly what caused crash during parallel pnad learning). + """ + return (self.__class__, (self.predicate, tuple(self.entities))) + @dataclass(frozen=True, repr=False, eq=False) class LiftedAtom(_Atom): @@ -465,7 +832,20 @@ def get_vlm_query_str(self) -> str: """If this GroundAtom is associated with a VLMPredicate, then get the string that will be used to query the VLM.""" assert isinstance(self.predicate, VLMPredicate) - return self.predicate.get_vlm_query_str(self.objects) # pylint:disable=no-member + return self.predicate.get_vlm_query_str(self.objects) # type: ignore[misc] # pylint:disable=no-member + + def get_negated_atom(self) -> GroundAtom: + """Get the negated atom of this GroundAtom.""" + from predicators.approaches.grammar_search_invention_approach import \ + _NegationClassifier + if isinstance(self.predicate._classifier, _NegationClassifier): + return GroundAtom(self.predicate._classifier.body, self.objects) + else: + # classifier = _NegationClassifier(self.predicate) + # negated_predicate = Predicate(str(classifier), self.predicate.types, + # classifier) + # return GroundAtom(negated_predicate, self.objects) + return GroundAtom(self.predicate.get_negation(), self.objects) @dataclass(frozen=True, eq=False) @@ -479,6 +859,11 @@ class Task: # an "alternative goal" in this field and replace the goal with the # alternative 
goal before giving the task to the agent. alt_goal: Optional[Set[GroundAtom]] = field(default_factory=set) + # Optional natural language description of the goal. When present, + # approaches can surface this to an LLM agent so it understands the + # *intent* behind the goal atoms (e.g. "arrange dominoes so the chain + # reaction topples the targets" rather than just Toppled(target0)). + goal_nl: Optional[str] = None def __post_init__(self) -> None: # Verify types. @@ -508,7 +893,7 @@ def replace_goal_with_alt_goal(self) -> Task: # demonstrator. To prevent leakage of this information, we discard the # original goal. if self.alt_goal: - return Task(self.init, goal=self.alt_goal) + return Task(self.init, goal=self.alt_goal, goal_nl=self.goal_nl) return self @@ -530,6 +915,8 @@ class EnvironmentTask: goal_description: GoalDescription # See Task._alt_goal for the reason for this field. alt_goal_desc: Optional[GoalDescription] = field(default=None) + # Optional natural language goal description (passed through to Task). + goal_nl: Optional[str] = None @cached_property def task(self) -> Task: @@ -539,7 +926,7 @@ def task(self) -> Task: # goal exists, then there's nothing particular to set the task's # alt_goal field to. if self.alt_goal_desc is None: - return Task(self.init, self.goal) + return Task(self.init, self.goal, goal_nl=self.goal_nl) # If we turn the environment task into a task before replacing the goal # with the alternative goal, we have to set the task's alt_goal field # accordingly to leave open the possibility of doing that replacement @@ -549,7 +936,10 @@ def task(self) -> Task: assert isinstance(self.alt_goal_desc, set) for atom in self.alt_goal_desc: assert isinstance(atom, GroundAtom) - return Task(self.init, self.goal, alt_goal=self.alt_goal_desc) + return Task(self.init, + self.goal, + alt_goal=self.alt_goal_desc, + goal_nl=self.goal_nl) @cached_property def init(self) -> State: @@ -608,6 +998,7 @@ class ParameterizedOption: # terminate now. 
The objects' types will match those in # self.types. The parameters will be contained in params_space. terminal: ParameterizedTerminal = field(repr=False) + params_description: Optional[Tuple[str, ...]] = None @cached_property def _hash(self) -> int: @@ -630,11 +1021,24 @@ def __hash__(self) -> int: def ground(self, objects: Sequence[Object], params: Array) -> _Option: """Ground into an Option, given objects and parameter values.""" - assert len(objects) == len(self.types) - for obj, t in zip(objects, self.types): - assert obj.is_instance(t) + if len(objects) != len(self.types): + expected = [t.name for t in self.types] + got = [f"{o.name}:{o.type.name}" for o in objects] + raise ValueError( + f"Cannot ground '{self.name}': expected {len(self.types)} " + f"objects {expected}, got {len(objects)} {got}") + for i, (obj, t) in enumerate(zip(objects, self.types)): + if not obj.is_instance(t): + raise TypeError( + f"Cannot ground '{self.name}': object '{obj.name}' at " + f"position {i} has type '{obj.type.name}', " + f"expected '{t.name}'") params = np.array(params, dtype=self.params_space.dtype) - assert self.params_space.contains(params) + if not self.params_space.contains(params): + raise ValueError( + f"Cannot ground '{self.name}': params {params.tolist()} " + f"outside bounds low={self.params_space.low.tolist()}, " + f"high={self.params_space.high.tolist()}") memory: Dict = {} # each option has its own memory dict return _Option( self.name, @@ -684,6 +1088,26 @@ def policy(self, state: State) -> Action: action.set_option(self) return action + def __str__(self) -> str: + """Full spec including objects and parameters.""" + objects = ", ".join(o.name for o in self.objects) + params = ", ".join(str(round(p, 2)) for p in self.params) + return f"{self.name}({objects}, {params})" + + def simple_str(self, use_object_id: bool = False) -> str: + """Simple spec without parameters.""" + if use_object_id: + objects = ", ".join( + [o.id_name + ":" + o.type.name for o in 
self.objects]) + else: + objects = ", ".join(o.name for o in self.objects) + return f"{self.name}({objects})" + + +DummyParameterizedOption: ParameterizedOption = ParameterizedOption( + "DummyParameterizedOption", [], Box(0, 1, (0, )), + lambda s, m, o, p: Action(np.array([0.0])), lambda s, m, o, p: False, + lambda s, m, o, p: True) DummyOption: _Option = ParameterizedOption( "DummyOption", [], Box(0, 1, @@ -720,6 +1144,70 @@ def make_nsrt( self.add_effects, self.delete_effects, self.ignore_effects, option, option_vars, sampler) + def make_endogenous_process( + self, + option: Optional[ParameterizedOption], + option_vars: Optional[Sequence[Variable]], + sampler: Optional[NSRTSampler], + process_strength: Optional[float] = None, + process_delay_params: Optional[Sequence[float]] = None, + process_rng: Optional[np.random.Generator] = None, + ) -> EndogenousProcess: + """Make a CausalProcess out of this STRIPSOperator object.""" + assert option is not None and option_vars is not None and \ + sampler is not None + if process_delay_params is None: + process_delay_params = [5, 1] + if process_strength is None: + process_strength = 1.0 + if process_rng is None: + process_rng = np.random.default_rng(CFG.seed) + + proc = EndogenousProcess( + self.name, + self.parameters, + condition_at_start=self.preconditions + if option.name != "Wait" else set(), + condition_overall=set(), + condition_at_end=set(), + add_effects=self.add_effects if option.name != "Wait" else set(), + delete_effects=self.delete_effects + if option.name != "Wait" else set(), + delay_distribution=utils.CMPDelay( + *process_delay_params, # type: ignore[attr-defined] + rng=process_rng), + strength=process_strength, # type: ignore[arg-type] + option=option, + option_vars=option_vars, + _sampler=sampler) + return proc + + def make_exogenous_process( + self, + process_strength: Optional[float] = None, + process_delay_params: Optional[Sequence[float]] = None, + process_rng: Optional[np.random.Generator] = None + ) 
-> ExogenousProcess: + """Make an ExogenousProcess out of this STRIPSOperator object.""" + if process_delay_params is None: + process_delay_params = torch.tensor([1, 1 + ]) # type: ignore[assignment] + if process_strength is None: + process_strength = torch.tensor(1.0) # type: ignore[assignment] + dist = utils.DiscreteGaussianDelay(torch.tensor(1), torch.tensor(1)) + + proc = ExogenousProcess( + self.name, + self.parameters, + condition_at_start=self.preconditions, + condition_overall=self.preconditions, + condition_at_end=set(), + add_effects=self.add_effects, + delete_effects=self.delete_effects, + delay_distribution=dist, + strength=process_strength) # type: ignore[arg-type] + return proc + @lru_cache(maxsize=None) def ground(self, objects: Tuple[Object]) -> _GroundSTRIPSOperator: """Ground into a _GroundSTRIPSOperator, given objects. @@ -1235,6 +1723,44 @@ def train_task_idx(self) -> int: return self._train_task_idx +@dataclass(frozen=True, repr=False, eq=False) +class AtomOptionTrajectory: + """A structure similar to a LowLevelTrajectory but save atoms at every + state, as well as the option that was executed.""" + _low_level_states: List[State] + _states: List[Set[GroundAtom]] + _actions: List[_Option] + _is_demo: bool = field(default=False) + _train_task_idx: Optional[int] = field(default=None) + + def __post_init__(self) -> None: + assert len(self._states) == len(self._actions) + 1 + if self._is_demo: + assert self._train_task_idx is not None + + @property + def states(self) -> List[Set[GroundAtom]]: + """States in the trajectory.""" + return self._states + + @property + def actions(self) -> List[_Option]: + """Actions in the trajectory.""" + return self._actions + + @property + def is_demo(self) -> bool: + """Whether this trajectory is a demonstration.""" + return self._is_demo + + @property + def train_task_idx(self) -> int: + """The index of the train task.""" + assert self._train_task_idx is not None, \ + "This trajectory doesn't contain a train task 
idx!" + return self._train_task_idx + + @dataclass(frozen=True, repr=False, eq=False) class ImageOptionTrajectory: """A structure similar to a LowLevelTrajectory where we record images at @@ -1338,6 +1864,66 @@ def append(self, self._trajectories.append(trajectory) +@dataclass(repr=False, eq=False) +class ClassificationDataset: + """Maybe ultimately a collection of LowLevelTrajectory objects, and a list + of labels, one per trajectory. + + There is List[Video] for each episode + """ + task_names: List[str] + support_videos: List[List[Video]] + support_labels: List[List[int]] + query_videos: List[List[Video]] + query_labels: List[List[int]] + seed: int + + def __post_init__(self) -> None: + assert len(self.support_videos) == len(self.support_labels) == \ + len(self.query_videos) == len(self.query_labels) == \ + len(self.task_names) + self._current_idx: int = 0 + self._rng = random.Random(self.seed) # Create a local random generator + + def __iter__(self) -> "Iterator[ClassificationEpisode]": + self._current_idx = 0 + return self + + def __next__(self) -> ClassificationEpisode: + if self._current_idx >= len(self.support_videos): + raise StopIteration + + episode_name = self.task_names[self._current_idx] + episode_support_videos = self.support_videos[self._current_idx] + episode_support_labels = self.support_labels[self._current_idx] + episode_query_videos = self.query_videos[self._current_idx] + episode_query_labels = self.query_labels[self._current_idx] + + assert len(episode_support_videos) == len(episode_support_labels) + assert len(episode_query_videos) == len(episode_query_labels) + + # Generate a permutation index for shuffling + perm = list(range(len(episode_query_videos))) + perm.reverse() + # self._rng.shuffle(perm) + + # Apply shuffle to query videos and labels + episode_query_videos = [episode_query_videos[i] for i in perm] + episode_query_labels = [episode_query_labels[i] for i in perm] + + episode: ClassificationEpisode = (episode_name, 
episode_support_videos, + episode_support_labels, + episode_query_videos, + episode_query_labels) + + self._current_idx += 1 + return episode + + def __len__(self) -> int: + """The number of episodes in the dataset.""" + return len(self.support_labels) + + @dataclass(eq=False) class Segment: """A segment represents a low-level trajectory that is the result of @@ -1442,26 +2028,37 @@ class PNAD: def add_to_datastore(self, member: Tuple[Segment, VarToObjSub], - check_effect_equality: bool = True) -> None: + check_effect_equality: bool = True, + check_option_equality: bool = True) -> None: """Add a new member to self.datastore.""" seg, var_obj_sub = member if len(self.datastore) > 0: # All variables should have a corresponding object. - assert set(var_obj_sub) == set(self.op.parameters) + if CFG.exogenous_process_learner_do_intersect: + # When we don't assume preconditions contain only atoms with + # variables present in the effect, we would first include + # all the variables in the op.parameters, and the var_obj_sub + # only contain parameters that can be unified with the last + # segment. So it can be a subset of the op.parameters. + assert set(var_obj_sub).issubset(set(self.op.parameters)) + else: + assert set(var_obj_sub) == set(self.op.parameters) # The effects should match. if check_effect_equality: obj_var_sub = {o: v for (v, o) in var_obj_sub.items()} lifted_add_effects = { a.lift(obj_var_sub) for a in seg.add_effects + if not isinstance(a.predicate, DerivedPredicate) } lifted_del_effects = { a.lift(obj_var_sub) for a in seg.delete_effects + if not isinstance(a.predicate, DerivedPredicate) } assert lifted_add_effects == self.op.add_effects assert lifted_del_effects == self.op.delete_effects - if seg.has_option(): + if seg.has_option() and check_option_equality: # The option should match. 
option = seg.get_option() part_param_option, part_option_args = self.option_spec @@ -1477,6 +2074,26 @@ def make_nsrt(self) -> NSRT: param_option, option_vars = self.option_spec return self.op.make_nsrt(param_option, option_vars, self.sampler) + def make_endogenous_process(self) -> EndogenousProcess: + """Make an EndogenousProcess from this PNAD.""" + assert self.sampler is not None + param_option, option_vars = self.option_spec + return self.op.make_endogenous_process(param_option, option_vars, + self.sampler) + + def make_exogenous_process( + self, + process_strength: Optional[float] = None, + process_delay_params: Optional[Sequence[float]] = None, + process_rng: Optional[np.random.Generator] = None + ) -> ExogenousProcess: + """Make an ExogenousProcess from this PNAD.""" + return self.op.make_exogenous_process( + process_strength=process_strength, + process_delay_params=process_delay_params, + process_rng=process_rng, + ) + def copy(self) -> PNAD: """Make a copy of this PNAD object, taking care to ensure that modifying the original will not affect the copy.""" @@ -1506,6 +2123,13 @@ def __lt__(self, other: PNAD) -> bool: return repr(self) < repr(other) +@dataclass(eq=False, repr=False) +class PAPAD: + """Partial Process and Datastore.""" + # The non option and sampler part of the CausalProcess + pprocess: PartialProcess + + @dataclass(frozen=True, eq=False, repr=False) class InteractionRequest: """A request for interacting with a training task during online learning. 
@@ -1983,6 +2607,498 @@ def __len__(self) -> int: return len(self.ground_nsrts) +@dataclass(frozen=False, repr=False, eq=False) +class DelayDistribution: + + def set_parameters(self, parameters: Sequence[torch.Tensor], + **kwargs: Any) -> None: + raise NotImplementedError + + def get_parameters(self) -> Sequence[float]: + raise NotImplementedError + + def sample(self) -> int: + raise NotImplementedError + + def log_prob(self, k: Union[int, torch.Tensor]) -> torch.Tensor: + raise NotImplementedError + + def __str__(self) -> str: + return self._str + + @cached_property + def _str(self) -> str: + raise NotImplementedError + + +@dataclass(frozen=False, repr=False, eq=False) +class PartialProcess: + pass + + +@dataclass(frozen=False, repr=False, eq=False) +class CausalProcess(abc.ABC): + name: str + parameters: Sequence[Variable] + condition_at_start: Set[LiftedAtom] + condition_overall: Set[LiftedAtom] + condition_at_end: Set[LiftedAtom] + add_effects: Set[LiftedAtom] + delete_effects: Set[LiftedAtom] + delay_distribution: DelayDistribution + strength: torch.Tensor + + @abc.abstractmethod + def ground(self, objects: Sequence[Object]) -> _GroundCausalProcess: + pass + + @abc.abstractmethod + def copy(self) -> CausalProcess: + """Create a deep copy of this causal process.""" + pass + + @abc.abstractmethod + def filter_predicates(self, kept: Collection[Predicate]) -> CausalProcess: + """Keep only the given predicates in the preconditions, add effects, + delete effects, and ignore effects. + + Note that the parameters must stay the same for the sake of the + sampler inputs. 
+ """ + pass + + def _set_parameters(self, parameters: Sequence[float], + **kwargs: Any) -> None: + self.strength = parameters[0] # type: ignore[assignment] + self.delay_distribution.set_parameters( + parameters[1:], **kwargs) # type: ignore[arg-type] + # Invalidate cached properties + if '_str' in self.__dict__: + del self.__dict__['_str'] + if '_hash' in self.__dict__: + del self.__dict__['_hash'] + + def _get_parameters(self) -> Sequence[float]: + """Get the parameters of this CausalProcess. + + The first parameter is the strength, and the rest are the delay + distribution parameters. + """ + return [ + self.strength + ] + self.delay_distribution.get_parameters() # type: ignore[operator] + + def delay_probability(self, delay: int) -> float: + return self.delay_distribution.probability( + delay) # type: ignore[attr-defined] + + @cached_property + def _hash(self) -> int: + return hash(str(self)) + + def __hash__(self) -> int: + return self._hash + + @cached_property + def _str(self) -> str: + ignore_effects_str = "" + if hasattr(self, 'ignore_effects') and isinstance( + self.ignore_effects, set): + ignore_effects_str = f"\n Ignore Effects: {sorted(self.ignore_effects, key=str)}" + return f""" Parameters: {self.parameters} + Conditions at start: {sorted(self.condition_at_start, key=str)} + Conditions overall: {sorted(self.condition_overall, key=str)} + Conditions at end: {sorted(self.condition_at_end, key=str)} + Add Effects: {sorted(self.add_effects, key=str)} + Delete Effects: {sorted(self.delete_effects, key=str)}{ignore_effects_str} + Log Strength: {self.strength:.4f} + Delay Distribution: {self.delay_distribution}""" + + @cached_property + def _str_wo_params(self) -> str: + return f""" Parameters: {self.parameters} + Conditions at start: {sorted(self.condition_at_start, key=str)} + Conditions overall: {sorted(self.condition_overall, key=str)} + Conditions at end: {sorted(self.condition_at_end, key=str)} + Add Effects: {sorted(self.add_effects, key=str)} + 
Delete Effects: {sorted(self.delete_effects, key=str)}""" + + def __str__(self) -> str: + return self._str + + def __repr__(self) -> str: + return str(self) + + def __eq__(self, other: object) -> bool: + assert isinstance(other, CausalProcess) + return str(self) == str(other) + + def __lt__(self, other: object) -> bool: + assert isinstance(other, CausalProcess) + return str(self) < str(other) + + def __gt__(self, other: object) -> bool: + assert isinstance(other, CausalProcess) + return str(self) > str(other) + + def get_complexity(self) -> float: + """Get the complexity of this operator. + + We only care about the arity of the operator, since that is what + affects grounding. We'll use 2^arity as a measure of grounding + effort. + """ + return float(2**len(self.parameters)) + + +@dataclass(frozen=False, repr=False, eq=False) +class ExogenousProcess(CausalProcess): + + def copy(self) -> ExogenousProcess: + """Create a deep copy of this exogenous process.""" + return ExogenousProcess( + name=self.name, + parameters=list(self.parameters), + condition_at_start=self.condition_at_start.copy(), + condition_overall=self.condition_overall.copy(), + condition_at_end=self.condition_at_end.copy(), + add_effects=self.add_effects.copy(), + delete_effects=self.delete_effects.copy(), + delay_distribution=self.delay_distribution.copy( + ), # type: ignore[attr-defined] + strength=self.strength.clone()) + + def filter_predicates(self, + kept: Collection[Predicate]) -> ExogenousProcess: + condition_at_start = {a for a in self.condition_at_start if a.predicate\ + in kept} + condition_overall = {a for a in self.condition_overall if a.predicate\ + in kept} + condition_at_end = {a for a in self.condition_at_end if a.predicate\ + in kept} + add_effects = {a for a in self.add_effects if a.predicate in kept} + delete_effects = { + a + for a in self.delete_effects if a.predicate in kept + } + + return ExogenousProcess(self.name, self.parameters, condition_at_start, + condition_overall, 
condition_at_end, + add_effects, delete_effects, + self.delay_distribution, self.strength) + + @cached_property + def _str(self) -> str: + process_str = super()._str + return f"""ExogenousProcess-{self.name}: +{process_str}""" + + @cached_property + def _str_wo_params(self) -> str: + process_str = super()._str_wo_params + return f"""ExogenousProcess-{self.name}: +{process_str}""" + + def ground(self, objects: Sequence[Object]) -> _GroundExogenousProcess: + assert len(objects) == len(self.parameters) + assert all( + o.is_instance(p.type) for o, p in zip(objects, self.parameters)) + sub = dict(zip(self.parameters, objects)) + condition_at_start = {a.ground(sub) for a in self.condition_at_start} + condition_overall = {a.ground(sub) for a in self.condition_overall} + condition_at_end = {a.ground(sub) for a in self.condition_at_end} + add_effects = {a.ground(sub) for a in self.add_effects} + delete_effects = {a.ground(sub) for a in self.delete_effects} + return _GroundExogenousProcess(self, objects, condition_at_start, + condition_overall, condition_at_end, + add_effects, delete_effects) + + +@dataclass(frozen=False, repr=False, eq=False) +class EndogenousProcess(CausalProcess): + option: ParameterizedOption + option_vars: Sequence[Variable] + _sampler: NSRTSampler = field(repr=False) + ignore_effects: Set[Predicate] = field(default_factory=set) + + def copy(self) -> EndogenousProcess: + """Create a deep copy of this endogenous process.""" + return EndogenousProcess( + name=self.name, + parameters=list(self.parameters), + condition_at_start=self.condition_at_start.copy(), + condition_overall=self.condition_overall.copy(), + condition_at_end=self.condition_at_end.copy(), + add_effects=self.add_effects.copy(), + delete_effects=self.delete_effects.copy(), + delay_distribution=self.delay_distribution.copy( + ), # type: ignore[attr-defined] + strength=self.strength.clone(), + option=self.option.copy(), # type: ignore[attr-defined] + option_vars=self.option_vars.copy(), # 
type: ignore[attr-defined] + _sampler=self._sampler.copy(), # type: ignore[attr-defined] + ignore_effects=self.ignore_effects.copy(), + ) + + def filter_predicates(self, + kept: Collection[Predicate]) -> EndogenousProcess: + """Keep only the given predicates in the preconditions, add effects, + delete effects, and ignore effects. + + Note that the parameters must stay the same for the sake of the + sampler inputs. + """ + condition_at_start = {a for a in self.condition_at_start if a.predicate\ + in kept} + condition_overall = {a for a in self.condition_overall if a.predicate\ + in kept} + condition_at_env = {a for a in self.condition_at_end if a.predicate\ + in kept} + add_effects = {a for a in self.add_effects if a.predicate in kept} + delete_effects = { + a + for a in self.delete_effects if a.predicate in kept + } + ignore_effects = {a for a in self.ignore_effects if a in kept} + + return EndogenousProcess(self.name, self.parameters, + condition_at_start, condition_overall, + condition_at_env, add_effects, delete_effects, + self.delay_distribution, self.strength, + self.option, self.option_vars, self._sampler, + ignore_effects) + + def ground(self, objects: Sequence[Object]) -> _GroundEndogenousProcess: + assert len(objects) == len(self.parameters) + assert all( + o.is_instance(p.type) for o, p in zip(objects, self.parameters)) + sub = dict(zip(self.parameters, objects)) + condition_at_start = {a.ground(sub) for a in self.condition_at_start} + condition_overall = {a.ground(sub) for a in self.condition_overall} + condition_at_end = {a.ground(sub) for a in self.condition_at_end} + add_effects = {a.ground(sub) for a in self.add_effects} + delete_effects = {a.ground(sub) for a in self.delete_effects} + option_objs = [sub[v] for v in self.option_vars] + return _GroundEndogenousProcess(self, objects, condition_at_start, + condition_overall, condition_at_end, + add_effects, delete_effects, + self.option, option_objs, + self._sampler) + + @cached_property + def 
_str(self) -> str: + option_var_str = ", ".join([str(v) for v in self.option_vars]) + process_str = super()._str + return f"""EndogenousProcess-{self.name}: +{process_str} + Option Spec: {self.option.name}({option_var_str})""" + + +@dataclass(frozen=False, repr=False, eq=False) +class _GroundCausalProcess: + parent: CausalProcess + objects: Sequence[Object] + condition_at_start: Set[GroundAtom] + condition_overall: Set[GroundAtom] + condition_at_end: Set[GroundAtom] + add_effects: Set[GroundAtom] + delete_effects: Set[GroundAtom] + + @property + def delay_distribution(self) -> DelayDistribution: + """The delay distribution of the parent CausalProcess.""" + return self.parent.delay_distribution + + @property + def strength(self) -> float: + """The strength of the parent CausalProcess.""" + return self.parent.strength # type: ignore[return-value] + + @abc.abstractmethod + def cause_triggered(self, state_history: List[Set[GroundAtom]], + action_history: List[_Option]) -> bool: + raise NotImplementedError + + def effect_factor(self, state: Set[GroundAtom]) -> float: + """Compute the effect factor of this ground causal process on the + state.""" + return int( + self.add_effects.issubset(state) + and not self.delete_effects.issubset(state)) * self.strength + + def factored_effect_factor(self, y_tj: bool, factor_atom: GroundAtom, + prev_val: bool) -> Tensor: + """If x_tj is True, we say that x_tj would get the effect factor of a + process if at this time step, factor_atom is in the add effects and not + in the delete effects of the process. + + If x_tj is False in the current step t, then we say that x_tj + would get effect from the effect factor of a process if at this + time step, x_tj is in the delete effects and not in the add + effects of the process. 
+ """ + # match1 requires in the x_tj = False case because match1 requires that + # (atom in not add_effects or in delete_effects) simply be true, + # whereas match2 requires specifically that + # (atom in delete_effects and not in add_effects) be true. + # match1 = (factor_atom in self.add_effects and + # factor_atom not in self.delete_effects) == x_tj + if y_tj: + match = int(y_tj != prev_val and factor_atom in self.add_effects + and factor_atom not in self.delete_effects) + else: + match = int(y_tj != prev_val and factor_atom in self.delete_effects + and factor_atom not in self.add_effects) + return match * self.strength # type: ignore[return-value] + + @property + def name(self) -> str: + """Name of this ground causal process.""" + return self.parent.name + + @cached_property + def _str(self) -> str: + return f"""GroundProcess-{self.name}: + Parameters: {self.objects} + Conditions at start: {sorted(self.condition_at_start, key=str)} + Conditions overall: {sorted(self.condition_overall, key=str)} + Conditions at end: {sorted(self.condition_at_end, key=str)} + Add Effects: {sorted(self.add_effects, key=str)} + Delete Effects: {sorted(self.delete_effects, key=str)}""" + + @cached_property + def _hash(self) -> int: + return hash(str(self)) + + def __str__(self) -> str: + return self._str + + def __repr__(self) -> str: + return str(self) + + def __hash__(self) -> int: + return self._hash + + def __eq__(self, other: object) -> bool: + assert isinstance(other, _GroundCausalProcess) + return str(self) == str(other) + + def __lt__(self, other: object) -> bool: + assert isinstance(other, _GroundCausalProcess) + return str(self) < str(other) + + def __gt__(self, other: object) -> bool: + assert isinstance(other, _GroundCausalProcess) + return str(self) > str(other) + + def name_and_objects_str(self) -> str: + return f"{self.name}({', '.join([str(o) for o in self.objects])})" + + +@dataclass(frozen=False, repr=False, eq=False) +class 
_GroundEndogenousProcess(_GroundCausalProcess): + option: ParameterizedOption + option_objs: Sequence[Object] + _sampler: NSRTSampler = field(repr=False) + + @property + def ignore_effects(self) -> Set[Predicate]: + """Ignore effects from the parent.""" + return self.parent.ignore_effects # type: ignore[attr-defined] + + @cached_property + def _str(self) -> str: + return f"""Process-{self.name}: + Parameters: {self.objects} + Conditions at start: {sorted(self.condition_at_start, key=str)} + Conditions overall: {sorted(self.condition_overall, key=str)} + Conditions at end: {sorted(self.condition_at_end, key=str)} + Add Effects: {sorted(self.add_effects, key=str)} + Delete Effects: {sorted(self.delete_effects, key=str)} + Ignore Effects: {sorted(self.ignore_effects, key=str)} + Option: {self.option} + Option Objects: {self.option_objs}""" + + def cause_triggered(self, state_history: List[Set[GroundAtom]], + action_history: List[_Option]) -> bool: + """Check if this endogenous process was triggered by the last + action.""" + + def check_wo_s(state: Set[GroundAtom], action: _Option) -> bool: + return (action.parent == self.option + and action.objects == self.option_objs) + + def check_w_s(state: Set[GroundAtom], action: _Option) -> bool: + return (action.parent == self.option + and action.objects == self.option_objs + and self.condition_at_start.issubset(state)) + + # if self.name == "SwitchFaucetOff" and check_wo_s(state_history[-1], + # action_history[-1]): + # breakpoint() + return check_w_s(state_history[-1], action_history[-1]) and ( + len(state_history) == 1 + or not check_wo_s(state_history[-2], action_history[-2])) + + def copy(self) -> _GroundEndogenousProcess: + """Make a copy of this _GroundEndogenousProcess object.""" + new_condition_at_start = set(self.condition_at_start) + new_condition_overall = set(self.condition_overall) + new_condition_at_end = set(self.condition_at_end) + new_add_effects = set(self.add_effects) + new_delete_effects = 
set(self.delete_effects) + return _GroundEndogenousProcess(self.parent, self.objects, + new_condition_at_start, + new_condition_overall, + new_condition_at_end, new_add_effects, + new_delete_effects, self.option, + self.option_objs, self._sampler) + + def sample_option(self, state: State, goal: Set[GroundAtom], + rng: np.random.Generator) -> _Option: + """Sample an _Option for this ground NSRT, by invoking the contained + sampler. + + On the Option that is returned, one can call, e.g., + policy(state). + """ + # Note that the sampler takes in ALL self.objects, not just the subset + # self.option_objs of objects that are passed into the option. + params = self._sampler(state, goal, rng, self.objects) + # Clip the params into the params_space of self.option, for safety. + low = self.option.params_space.low + high = self.option.params_space.high + params = np.clip(params, low, high) + return self.option.ground(self.option_objs, params) + + +@dataclass(frozen=False, repr=False, eq=False) +class _GroundExogenousProcess(_GroundCausalProcess): + + def cause_triggered(self, state_history: List[Set[GroundAtom]], + action_history: List[_Option]) -> bool: + """Check if this exogenous process was triggered by the last action.""" + + def check(state: Set[GroundAtom]) -> bool: + return self.condition_at_start.issubset(state) + + return check(state_history[-1]) and (len(state_history) == 1 + or not check(state_history[-2])) + + def copy(self) -> _GroundExogenousProcess: + """Make a copy of this _GroundExogenousProcess object.""" + new_condition_at_start = set(self.condition_at_start) + new_condition_overall = set(self.condition_overall) + new_condition_at_end = set(self.condition_at_end) + new_add_effects = set(self.add_effects) + new_delete_effects = set(self.delete_effects) + return _GroundExogenousProcess(self.parent, self.objects, + new_condition_at_start, + new_condition_overall, + new_condition_at_end, new_add_effects, + new_delete_effects) + + # Convenience higher-order 
types useful throughout the code Observation = Any GoalDescription = Any @@ -2029,6 +3145,12 @@ def __len__(self) -> int: ParameterizedTerminal = Callable[[State, Dict, Sequence[Object], Array], bool] AbstractPolicy = Callable[[Set[GroundAtom], Set[Object], Set[GroundAtom]], Optional[_GroundNSRT]] +AbstractProcessPolicy = Callable[ + [Set[GroundAtom], Set[Object], Set[GroundAtom]], + Optional[_GroundEndogenousProcess]] RGBA = Tuple[float, float, float, float] BridgePolicy = Callable[[State, Set[GroundAtom], List[_Option]], _Option] BridgeDataset = List[Tuple[Set[_Option], _GroundNSRT, Set[GroundAtom], State]] +Mask = NDArray[np.bool_] +ClassificationEpisode = Tuple[str, List[Video], List[int], List[Video], + List[int]] diff --git a/predicators/utils.py b/predicators/utils.py index 1493e5f40a..292d7c596b 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -4,6 +4,8 @@ import abc import contextlib +import copy +import datetime import functools import gc import heapq as hq @@ -18,15 +20,19 @@ import sys import time from argparse import ArgumentParser -from collections import defaultdict +from collections import defaultdict, namedtuple +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field +from functools import cached_property from pathlib import Path +from pprint import pformat from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Dict, \ - FrozenSet, Generator, Generic, Hashable, Iterator, List, Optional, \ - Sequence, Set, Tuple + FrozenSet, Generator, Generic, Hashable, Iterable, Iterator, List, \ + Optional, Sequence, Set, Tuple from typing import Type as TypingType from typing import TypeVar, Union, cast +import colorlog import dill as pkl import imageio import matplotlib @@ -34,6 +40,7 @@ import numpy as np import pathos.multiprocessing as mp import PIL.Image +import torch from gym.spaces import Box from matplotlib import patches from numpy.typing import NDArray @@ -44,18 +51,21 @@ from 
scipy.stats import beta as BetaRV from predicators.args import create_arg_parser +from predicators.image_patch_wrapper import ImagePatch from predicators.pretrained_model_interface import GoogleGeminiLLM, \ - GoogleGeminiVLM, LargeLanguageModel, OpenAILLM, OpenAIVLM, \ - VisionLanguageModel + GoogleGeminiVLM, LargeLanguageModel, OpenAILLM, OpenAIVLM, OpenRouterLLM, \ + OpenRouterVLM, VisionLanguageModel from predicators.pybullet_helpers.joint import JointPositions from predicators.settings import CFG, GlobalSettings -from predicators.structs import NSRT, Action, Array, DummyOption, \ - EntToEntSub, GroundAtom, GroundAtomTrajectory, \ +from predicators.structs import NSRT, Action, Array, AtomOptionTrajectory, \ + CausalProcess, DelayDistribution, DerivedPredicate, DummyOption, \ + EntToEntSub, ExogenousProcess, GroundAtom, GroundAtomTrajectory, \ GroundNSRTOrSTRIPSOperator, Image, LDLRule, LiftedAtom, \ - LiftedDecisionList, LiftedOrGroundAtom, LowLevelTrajectory, Metrics, \ - NSRTOrSTRIPSOperator, Object, ObjectOrVariable, Observation, OptionSpec, \ - ParameterizedOption, Predicate, Segment, State, STRIPSOperator, Task, \ - Type, Variable, VarToObjSub, Video, VLMPredicate, _GroundLDLRule, \ + LiftedDecisionList, LiftedOrGroundAtom, LowLevelTrajectory, Mask, \ + Metrics, NSRTOrSTRIPSOperator, Object, ObjectOrVariable, Observation, \ + OptionSpec, ParameterizedOption, Predicate, Segment, State, \ + STRIPSOperator, Task, Type, Variable, VarToObjSub, Video, VLMPredicate, \ + _GroundEndogenousProcess, _GroundExogenousProcess, _GroundLDLRule, \ _GroundNSRT, _GroundSTRIPSOperator, _Option, _TypedEntity from predicators.third_party.fast_downward_translator.translate import \ main as downward_translate @@ -962,8 +972,12 @@ def _policy(self, state: State, memory: Dict, objects: Sequence[Object], memory["current_child_index"] = current_index current_child = self._children[current_index] child_memory = memory["child_memory"][current_index] - assert 
current_child.initiable(state, child_memory, objects, - params) + try: + assert current_child.initiable(state, child_memory, objects, + params) + except: + breakpoint() + # logging.debug(f"Executing {current_child.name}") return current_child.policy(state, child_memory, objects, params) def _terminal(self, state: State, memory: Dict, objects: Sequence[Object], @@ -1035,17 +1049,363 @@ class PyBulletState(State): @property def joint_positions(self) -> JointPositions: """Expose the current joints state in the simulator_state.""" - return cast(JointPositions, self.simulator_state) + # if the simulator state is an array + if isinstance(self.simulator_state, Dict): + jp = self.simulator_state["joint_positions"] + else: + jp = self.simulator_state + return cast(JointPositions, jp) + + @property + def state_image(self) -> Image: + """Expose the current image state in the simulator_state.""" + assert isinstance(self.simulator_state, Dict) + return self.simulator_state["unlabeled_image"] + + @property + def labeled_image(self) -> Optional[Image]: + """Expose the current image state in the simulator_state.""" + assert isinstance(self.simulator_state, Dict) + return self.simulator_state.get("images") + + @property + def obj_mask_dict(self) -> Optional[Dict[Object, Mask]]: + """Expose the current object masks in the simulator_state.""" + assert isinstance(self.simulator_state, Dict) + return self.simulator_state.get("obj_mask_dict") def allclose(self, other: State) -> bool: # Ignores the simulator state. 
return State(self.data).allclose(State(other.data)) - def copy(self) -> State: - state_dict_copy = super().copy().data - simulator_state_copy = list(self.joint_positions) + def copy(self) -> PyBulletState: + copy = super().copy() + state_dict_copy = copy.data + # simulator_state_copy = list(self.joint_positions) + simulator_state_copy = copy.simulator_state return PyBulletState(state_dict_copy, simulator_state_copy) + def get_obj_mask(self, obj: Object) -> Mask: + """Return the mask for the object.""" + assert self.obj_mask_dict is not None + mask = self.obj_mask_dict.get(obj) + assert mask is not None + return mask + + def label_all_objects(self) -> None: + """Label all objects in the simulator state.""" + state_ip = ImagePatch(self) + obj_mask_dict = self.obj_mask_dict + assert obj_mask_dict is not None + state_ip.label_all_objects(obj_mask_dict) + assert isinstance(self.simulator_state, Dict) + self.simulator_state["images"] = state_ip.cropped_image_in_PIL + + def add_images_and_masks(self, unlabeled_image: PIL.Image.Image, + masks: Dict[Object, Mask]) -> None: + """Add the unlabeled image and object masks to the simulator state.""" + assert isinstance(self.simulator_state, Dict) + self.simulator_state["unlabeled_image"] = unlabeled_image + self.simulator_state["obj_mask_dict"] = masks + self.label_all_objects() + + +BoundingBox = namedtuple('BoundingBox', 'left lower right upper') + + +@dataclass +class RawState(PyBulletState): + state_image: PIL.Image.Image = None # type: ignore[assignment] + obj_mask_dict: Dict[Object, Mask] = field(default_factory=dict) + labeled_image: Optional[PIL.Image.Image] = None # type: ignore[assignment] + option_history: Optional[List[str]] = None + bbox_features: Dict[Object, np.ndarray] = field( + default_factory=lambda: defaultdict(lambda: np.zeros(4))) + prev_state: Optional[RawState] = None + next_state: Optional[RawState] = None + + def __hash__(self) -> int: + # Convert the dictionary to a tuple of key-value pairs and hash it 
+ # data_hash = hash(tuple(sorted(self.data.items()))) + data_tuple = tuple((k, tuple(v)) for k, v in sorted(self.data.items())) + if self.simulator_state is not None: + data_tuple += tuple(self.simulator_state) + data_hash = hash(data_tuple) + # # Hash the simulator_state + # simulator_state_hash = hash(self.simulator_state) + # Combine the two hashes + # return hash((data_hash, simulator_state_hash)) + return data_hash + + def evaluate_simple_assertion( + self, assertion: str, image: Tuple[BoundingBox, + Sequence[Object]]) -> VLMQuery: + """Given an assertion and an image, queries a VLM and returns whether + the assertion is true or false.""" + bbox, objs = image + return VLMQuery(assertion, bbox, list(objs)) + + def generate_previous_option_message(self) -> str: + """Generate the message for the previous option.""" + assert self.option_history is not None + msg = "Evaluate the truth value of the following assertions in the "\ + "current state as depicted by the image" + if CFG.nsp_pred_include_prev_image_in_prompt and \ + self.prev_state is not None: + msg += " labeled with 'curr. state'" + if CFG.nsp_pred_include_state_str_in_prompt: + msg += " and the information below" + + msg += ".\n" + + if CFG.nsp_pred_include_state_str_in_prompt: + msg += f"We have the object positions and the robot's "\ + "proprioception:\n" + msg += self.dict_str(indent=2, + object_features=False, + use_object_id=True, + position_proprio_features=True) + msg += "\n" + + if len(self.option_history) == 0: + msg += "For context, this is at the beginning of a task, before "\ + "the robot has done anything.\n" + else: + # return f"For context, this is right after the robot has "\ + # f"successfully executed its [{', '.join(self.option_history[-2:])}]"\ + # f" option sequence." + # msg = f"For context, this state is right after the robot has "\ + # f"successfully executed its {self.option_history[-1]} action." 
+ msg += "For context, the state is right after the robot has"\ + " successfully executed the action "\ + f"{self.option_history[-1]}." + if CFG.nsp_pred_include_state_str_in_prompt: + if self.prev_state is not None: + msg += " The object position and robot proprioception "\ + "before executing the action is:\n" + msg += self.prev_state.dict_str( + indent=2, + object_features=False, + use_object_id=True, + position_proprio_features=True) + msg += "\n" + if CFG.nsp_pred_include_prev_image_in_prompt: + msg += " The state before executing the action is depicted"\ + " by the image labeled with 'prev. state'." + msg += " Please carefully examine the images depicting the "\ + "'prev. state' and 'curr. state' before making a judgment." + msg += "\n" + msg += "The assertions to evaluate are:" + return msg + + def add_bbox_features(self) -> None: + """Add the features about the bounding box to the objects.""" + for obj, mask in self.obj_mask_dict.items(): + bbox = mask_to_bbox(mask) + for name, value in bbox._asdict().items(): + self.set(obj, f"bbox_{name}", value) + + def set(self, obj: Object, feature_name: str, feature_val: Any) -> None: + """Set the value of an object feature by name.""" + try: + idx = obj.type.feature_names.index(feature_name) + except: + breakpoint() + standard_feature_len = len(self.data[obj]) + if idx >= standard_feature_len: + # When setting the bounding box features for the first time + # So we'd first append 4 dimension and try to set again + self.bbox_features[obj][idx - standard_feature_len] = feature_val + else: + self.data[obj][idx] = feature_val + + def get(self, obj: Object, feature_name: str) -> Any: + idx = obj.type.feature_names.index(feature_name) + standard_feature_len = len(self.data[obj]) + if idx >= standard_feature_len: + return self.bbox_features[obj][idx - standard_feature_len] + else: + return self.data[obj][idx] + + def dict_str( + self, # type: ignore[override] + indent: int = 0, + object_features: bool = True, + 
use_object_id: bool = False, + position_proprio_features: bool = False) -> str: + """Return a dictionary representation of the state.""" + state_dict = {} + for obj in self: + obj_dict = {} + for attribute, value in zip( + obj.type.feature_names, + np.concatenate([self[obj], self.bbox_features[obj]]) + if self.bbox_features else self[obj]): + # include if it's proprioception feature, or position/bbox + # feature, or object_features is True + # if (obj.type.name == "robot" and \ + # attribute not in ["bbox_left", "bbox_right", "bbox_upper", + # "pose_x", "pose_y", "pose_z", "pose_y_norm", + # "bbox_lower"]) or object_features: + # # attribute in ["pose_x", "pose_y", "pose_z", "bbox_left", + # # "bbox_right", "bbox_upper", "bbox_lower"] or\ + # if isinstance(value, (float, int, np.float32)): + # value = round(float(value), 1) + # obj_dict[attribute] = value + if (position_proprio_features and attribute in [ + # "pose_x", "pose_y", "pose_z", "x", "y", "z", + "rot", + "fingers" + ]) or (object_features and attribute not in [ + "is_heavy", + # "grasp", + # "held", + # "is_held", + ]): + if isinstance(value, (float, int, np.float32)): + value = round(float(value), 1) + obj_dict[attribute] = value + + if use_object_id: obj_name = obj.id_name + else: obj_name = obj.name + state_dict[f"{obj_name}:{obj.type.name}"] = obj_dict + + # Create a string of n_space spaces + spaces = " " * indent + + # Create a PrettyPrinter with a large width + dict_str = spaces + "{" + n_keys = len(state_dict.keys()) + for i, (key, value) in enumerate(state_dict.items()): + value_str = ', '.join(f"'{k}': {v}" for k, v in value.items()) + if value_str == "": + content_str = f"'{key}'" + else: + content_str = f"'{key}': {{{value_str}}}" + if i == 0: + dict_str += f"{content_str},\n" + elif i == n_keys - 1: + dict_str += spaces + f" {content_str}" + else: + dict_str += spaces + f" {content_str},\n" + dict_str += "}" + return dict_str + + def __eq__(self, other: object) -> bool: + # Compare the data 
and simulator_state + assert isinstance(other, RawState) + + if len(self.data) != len(other.data): + return False + + for key, value in self.data.items(): + if key not in other.data or not np.array_equal( + value, other.data[key]): + return False + + return self.simulator_state == other.simulator_state + + def label_all_objects(self) -> None: + state_ip = ImagePatch(self) + # state_ip.cropped_image_in_PIL.save(f"images/obs_before_label_all.png") + # labels = [obj.id for obj in self.obj_mask_dict.keys()] + # masks = self.obj_mask_dict.values() + # state_ip.label_all_objects(masks, labels) + state_ip.label_all_objects(self.obj_mask_dict) + # state_ip.label_object(mask, obj.id) + # state_ip.cropped_image_in_PIL.save(f"images/obs_after_label_all.png") + self.labeled_image = state_ip.cropped_image_in_PIL + + def copy(self) -> RawState: + pybullet_state_copy = super().copy() + # simulator_state_copy = list(self.joint_positions) + state_image_copy = copy.copy(self.state_image) + obj_mask_copy = copy.deepcopy(self.obj_mask_dict) + labeled_image_copy = copy.copy(self.labeled_image) + option_history_copy = copy.copy(self.option_history) + bbox_features_copy = copy.deepcopy(self.bbox_features) + prev_state_copy = self.prev_state.copy() if self.prev_state else None + return RawState(pybullet_state_copy.data, + pybullet_state_copy.simulator_state, state_image_copy, + obj_mask_copy, labeled_image_copy, option_history_copy, + bbox_features_copy, prev_state_copy) + + def get_obj_mask(self, object: Object) -> Mask: + """Return the mask for the object.""" + return self.obj_mask_dict[object] + + def get_obj_bbox(self, object: Object) -> BoundingBox: + """Get the bounding box of the object in the state image The origin is + bottom left corner--(0, 0)""" + mask = self.get_obj_mask(object) + return mask_to_bbox(mask) + + def crop_to_objects( + self, + objects: Sequence[Object], + # left_margin: int = 15, + # lower_margin: int = 15, + # right_margin: int = 15, + # top_margin: int = 20 + 
left_margin: int = 30, + lower_margin: int = 30, + right_margin: int = 30, + top_margin: int = 30) -> Tuple[BoundingBox, Sequence[Object]]: + + bboxes = [self.get_obj_bbox(obj) for obj in objects] + bbox = smallest_bbox_from_bboxes(bboxes) + return (BoundingBox( + max(bbox.left - left_margin, 0), max(bbox.lower - lower_margin, 0), + min(bbox.right + right_margin, self.state_image.width), + min(bbox.upper + top_margin, self.state_image.height)), objects) + + # state_ip = ImagePatch(self, attn_objects=objects) + # return state_ip.crop_to_objects(objects, left_margin, lower_margin, + # right_margin, top_margin) + + +@dataclass +class VLMQuery: + """A class to represent a query to a VLM.""" + query_str: str + attention_box: BoundingBox + attn_objects: Optional[List[Object]] = None + ground_atom: Optional[GroundAtom] = None + + +def mask_to_bbox(mask: Mask) -> BoundingBox: + """Return the bounding box of the mask.""" + y_indices, x_indices = np.where(mask) + height = mask.shape[0] + + # Get the bounding box + try: + left = x_indices.min() + right = x_indices.max() + lower = height - (y_indices.max() + 1) + upper = height - (y_indices.min() + 1) + except ValueError: + left, lower, right, upper = 0, 0, 0, 0 + # If the mask is empty, return a bounding box with all zeros + + return BoundingBox(left, lower, right, upper) + + +def smallest_bbox_from_bboxes(bboxes: Sequence[BoundingBox]) -> BoundingBox: + """Return the smallest bounding box that contains all the given + bounding.""" + + # Initialize the bounding box coordinates + left, lower, right, upper = np.inf, np.inf, -np.inf, -np.inf + # Iterate over all masks + for bbox in bboxes: + # Update the bounding box + left = min(left, bbox.left) + lower = min(lower, bbox.lower) + right = max(right, bbox.right) + upper = max(upper, bbox.upper) + return BoundingBox(left, lower, right, upper) + class StateWithCache(State): """A state with a cache stored in the simulator state that is ignored for @@ -1135,12 +1495,26 @@ def 
run_policy( start_time = time.perf_counter() act = policy(state) metrics["policy_call_time"] += time.perf_counter() - start_time + except Exception as e: + if not CFG.video_not_break_on_exception: + if exceptions_to_break_on is not None and \ + type(e) in exceptions_to_break_on: + if monitor_observed: + exception_raised_in_step = True + break + raise e + if monitor is not None and not monitor_observed: + monitor.observe(state, None) + monitor_observed = True + else: + if monitor is not None and not monitor_observed: + monitor.observe(state, act) + monitor_observed = True + + try: # Note: it's important to call monitor.observe() before # env.step(), because the monitor may use the environment's # internal state. - if monitor is not None: - monitor.observe(state, act) - monitor_observed = True state = env.step(act) actions.append(act) states.append(state) @@ -1150,8 +1524,6 @@ def run_policy( if monitor_observed: exception_raised_in_step = True break - if monitor is not None and not monitor_observed: - monitor.observe(state, None) raise e if termination_function(state): break @@ -1195,11 +1567,15 @@ def run_policy_with_simulator( actions: List[Action] = [] exception_raised_in_step = False if not termination_function(state): - for _ in range(max_num_steps): + for i in range(max_num_steps): + if i % 15 == 0: + logging.debug(f"Step {i}") + # logging.debug(f"State: {state.pretty_str()}") monitor_observed = False exception_raised_in_step = False try: act = policy(state) + # logging.debug(f"Action: {act}") if monitor is not None: monitor.observe(state, act) monitor_observed = True @@ -1207,6 +1583,7 @@ def run_policy_with_simulator( actions.append(act) states.append(state) except Exception as e: + logging.debug(f"Exception during running policy: {e}") if exceptions_to_break_on is not None and \ type(e) in exceptions_to_break_on: if monitor_observed: @@ -1271,6 +1648,7 @@ def option_policy_to_policy( option_policy: Callable[[State], _Option], max_option_steps: Optional[int] 
= None, raise_error_on_repeated_state: bool = False, + abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None ) -> Callable[[State], Action]: """Create a policy that executes a policy over options.""" cur_option = DummyOption @@ -1296,9 +1674,31 @@ def _policy(state: State) -> Action: raise OptionTimeoutFailure( "Encountered repeated state.", info={"last_failed_option": last_option}) + # logging for debugging + # if last_state is not None: + # cur_atoms = abstract_function(state) + # prev_atoms = abstract_function(last_state) + # logging.debug(f"Prev atoms: {sorted(prev_atoms)}") + # logging.info(f"Add atoms: {sorted(cur_atoms-prev_atoms)} " + # f"Del atoms: {sorted(prev_atoms-cur_atoms)}") + + # whether the noop option should terminate + wait_terminate = False + if CFG.wait_option_terminate_on_atom_change and cur_option.name == "Wait": + assert abstract_function is not None + assert last_state is not None + cur_atoms = abstract_function(state) + prev_atoms = abstract_function(last_state) + if cur_atoms != prev_atoms: + # logging.debug(f"Prev atoms: {sorted(prev_atoms)}") + # logging.info(f"Add atoms: {sorted(cur_atoms-prev_atoms)} " + # f"Del atoms: {sorted(prev_atoms-cur_atoms)}") + wait_terminate = True + last_state = state - if cur_option is DummyOption or cur_option.terminal(state): + if wait_terminate or cur_option is DummyOption or cur_option.terminal( + state): try: cur_option = option_policy(state) except OptionExecutionFailure as e: @@ -1318,9 +1718,10 @@ def _policy(state: State) -> Action: def option_plan_to_policy( - plan: Sequence[_Option], - max_option_steps: Optional[int] = None, - raise_error_on_repeated_state: bool = False + plan: Sequence[_Option], + max_option_steps: Optional[int] = None, + raise_error_on_repeated_state: bool = False, + abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None ) -> Callable[[State], Action]: """Create a policy that executes a sequence of options in order.""" queue = list(plan) # 
don't modify plan, just in case @@ -1329,12 +1730,15 @@ def _option_policy(state: State) -> _Option: del state # not used if not queue: raise OptionExecutionFailure("Option plan exhausted!") - return queue.pop(0) + option = queue.pop(0) + logging.debug(f"Executing option {option.simple_str()}") + return option return option_policy_to_policy( _option_policy, max_option_steps=max_option_steps, - raise_error_on_repeated_state=raise_error_on_repeated_state) + raise_error_on_repeated_state=raise_error_on_repeated_state, + abstract_function=abstract_function) def nsrt_plan_to_greedy_option_policy( @@ -1378,7 +1782,8 @@ def nsrt_plan_to_greedy_policy( nsrt_plan: Sequence[_GroundNSRT], goal: Set[GroundAtom], rng: np.random.Generator, - necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None + necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, + abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None ) -> Callable[[State], Action]: """Greedily execute an NSRT plan, assuming downward refinability and that any sample will work. @@ -1388,7 +1793,60 @@ def nsrt_plan_to_greedy_policy( """ option_policy = nsrt_plan_to_greedy_option_policy( nsrt_plan, goal, rng, necessary_atoms_seq=necessary_atoms_seq) - return option_policy_to_policy(option_policy) + return option_policy_to_policy(option_policy, + abstract_function=abstract_function) + + +def process_plan_to_greedy_option_policy( + process_plan: Sequence[_GroundEndogenousProcess], + goal: Set[GroundAtom], + rng: np.random.Generator, + necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, +) -> Callable[[State], _Option]: + """Greedily execute a process plan, assuming downward refinability and that + any sample will work. + + If an option is not initiable or if the plan runs out, an + OptionExecutionFailure is raised. 
+ """ + cur_process: Optional[_GroundEndogenousProcess] = None + process_queue = list(process_plan) + if necessary_atoms_seq is None: + empty_atoms: Set[GroundAtom] = set() + necessary_atoms_seq = [ + empty_atoms for _ in range(len(process_plan) + 1) + ] + assert len(necessary_atoms_seq) == len(process_plan) + 1 + necessary_atoms_queue = list(necessary_atoms_seq) + + def _option_policy(state: State) -> _Option: + nonlocal cur_process + if not process_queue: + raise OptionExecutionFailure("Process plan exhausted.") + expected_atoms = necessary_atoms_queue.pop(0) + if not all(a.holds(state) for a in expected_atoms): + raise OptionExecutionFailure( + "Executing the process failed to achieve the necessary atoms.") + cur_process = process_queue.pop(0) + cur_option = cur_process.sample_option(state, goal, rng) + logging.debug(f"Using option {cur_option.name}{cur_option.objects}" + f"{cur_option.params} from process plan.") + return cur_option + + return _option_policy + + +def process_plan_to_greedy_policy( + process_plan: Sequence[_GroundEndogenousProcess], + goal: Set[GroundAtom], + rng: np.random.Generator, + necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, + abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None +) -> Callable[[State], Action]: + option_policy = process_plan_to_greedy_option_policy( + process_plan, goal, rng, necessary_atoms_seq=necessary_atoms_seq) + return option_policy_to_policy(option_policy, + abstract_function=abstract_function) def sample_applicable_option(param_options: List[ParameterizedOption], @@ -1444,7 +1902,7 @@ def sample_applicable_ground_nsrt( if len(applicable_nsrts) == 0: return None idx = rng.choice(len(applicable_nsrts)) - return applicable_nsrts[idx] + return applicable_nsrts[idx] # type: ignore[return-value] def action_arrs_to_policy( @@ -1803,18 +2261,25 @@ def run_hill_climbing( heuristic: Callable[[_S], float], early_termination_heuristic_thresh: Optional[float] = None, enforced_depth: int = 
0,
+    exhaustive_lookahead: bool = False,
     parallelize: bool = False,
     verbose: bool = True,
     timeout: float = float('inf')
 ) -> Tuple[List[_S], List[_A], List[float]]:
     """Enforced hill climbing local search.
 
-    For each node, the best child node is always selected, if that child is
+    For each node, this search looks for an improvement up to `enforced_depth`.
+    If `exhaustive_lookahead` is False (default), for each node, the best child
+    node is always selected, if that child is
     an improvement over the node. If no children improve on the node, look
     at the children's children, etc., up to enforced_depth, where
     enforced_depth 0 corresponds to simple hill climbing. Terminate when no
     improvement can be found. early_termination_heuristic_thresh allows
     for searching until heuristic reaches a specified value.
+    Let b be the branching factor and d be the enforced_depth; this has time
+    complexity of O(b^{d+1}).
+    If `exhaustive_lookahead` is True, it searches the entire horizon up to the
+    enforced depth and picks the best overall improvement.
 
     Lower heuristic is better.
     """
@@ -1823,12 +2288,13 @@ def run_hill_climbing(
                                                initial_state, 0, 0)
     last_heuristic = heuristic(cur_node.state)
     heuristics = [last_heuristic]
-    visited = {initial_state}
+    # visited = {initial_state} # <--- deleted for exhaustive_lookahead
     if verbose:
         logging.info(f"\n\nStarting hill climbing at state {cur_node.state} "
                      f"with heuristic {last_heuristic}")
     start_time = time.perf_counter()
     while True:
+        visited = {cur_node.state} # <--- added for exhaustive_lookahead
 
         # Stops when heuristic reaches specified value.
         if early_termination_heuristic_thresh is not None \
@@ -1882,15 +2348,27 @@ def run_hill_climbing(
                         best_heuristic = child_heuristic
                         best_child_node = child_node
                 all_best_heuristics.append(best_heuristic)
-            if last_heuristic > best_heuristic:
+
+            if not exhaustive_lookahead and last_heuristic > best_heuristic:
                 # Some improvement found.
if verbose: logging.info(f"Found an improvement at depth {depth}") break # Continue on to the next depth. current_depth_nodes = successors_at_depth + if not current_depth_nodes: + if verbose: + logging.info( + f"No more successors to explore at depth {depth}.") + break # No need to search deeper if there are no more nodes. + if verbose: - logging.info(f"No improvement found at depth {depth}") + if exhaustive_lookahead: + logging.info(f"Finished depth {depth}. " + f"Best heuristic so far: {best_heuristic}") + elif last_heuristic <= best_heuristic: + logging.info(f"No improvement found at depth {depth}") + if best_child_node is None: if verbose: logging.info("\nTerminating hill climbing, no more successors") @@ -1906,9 +2384,13 @@ def run_hill_climbing( if verbose: logging.info(f"\nHill climbing reached new state {cur_node.state} " f"with heuristic {last_heuristic}") + states, actions = _finish_plan(cur_node) - assert len(states) == len(heuristics) - return states, actions, heuristics + # The number of heuristics might not match the plan length perfectly now, + # so we should regenerate them from the final plan. + final_heuristics = [heuristic(s) for s in states] + assert len(states) == len(final_heuristics) + return states, actions, final_heuristics def run_policy_guided_astar( @@ -2224,23 +2706,36 @@ def _stripped_classifier( objects: Sequence[Object]) -> bool: # pragma: no cover. 
raise Exception("VLM predicate classifier should never be called!") - return VLMPredicate(name, types, _stripped_classifier, get_vlm_query_str) + return VLMPredicate(name, types, _stripped_classifier, + get_vlm_query_str) # type: ignore[arg-type] def create_llm_by_name( model_name: str) -> LargeLanguageModel: # pragma: no cover """Create particular llm using a provided name.""" - if "gemini" in model_name: + if CFG.pretrained_model_service_provider == "openai": + return OpenAILLM(model_name) + elif CFG.pretrained_model_service_provider == "google": return GoogleGeminiLLM(model_name) - return OpenAILLM(model_name) + elif CFG.pretrained_model_service_provider == "openrouter": + return OpenRouterLLM(model_name) + else: + raise ValueError(f"Unknown pretrained model service provider: " + f"{CFG.pretrained_model_service_provider}") def create_vlm_by_name( model_name: str) -> VisionLanguageModel: # pragma: no cover """Create particular vlm using a provided name.""" - if "gemini" in model_name: + if CFG.pretrained_model_service_provider == "openai": + return OpenAIVLM(model_name) + elif CFG.pretrained_model_service_provider == "google": return GoogleGeminiVLM(model_name) - return OpenAIVLM(model_name) + elif CFG.pretrained_model_service_provider == "openrouter": + return OpenRouterVLM(model_name) + else: + raise ValueError(f"Unknown pretrained model service provider: " + f"{CFG.pretrained_model_service_provider}") def parse_model_output_into_option_plan( @@ -2273,11 +2768,15 @@ def parse_model_output_into_option_plan( continue if option_name not in option_name_to_option.keys() or \ "(" not in option_str: - logging.info( - f"Line {option_str} output by model doesn't " - "contain a valid option name. Terminating option plan " - "parsing.") - break + if option_plan: + # Already found some options; stop on first non-option line. + logging.info( + f"Line {option_str} output by model doesn't " + "contain a valid option name. 
Terminating option plan " + "parsing.") + break + # Skip preamble lines (analysis text before the plan starts). + continue if parse_continuous_params and "[" not in option_str: logging.info( f"Line {option_str} output by model doesn't contain a " @@ -2473,8 +2972,11 @@ def query_vlm_for_atom_vals( # Query VLM. if vlm is None: vlm = create_vlm_by_name(CFG.vlm_model_name) # pragma: no cover. - vlm_input_imgs = \ - [PIL.Image.fromarray(img_arr) for img_arr in imgs] # type: ignore + if CFG.env in ["pybullet_coffee"]: + vlm_input_imgs = list(imgs) # type: ignore + else: + vlm_input_imgs = \ + [PIL.Image.fromarray(img_arr) for img_arr in imgs] # type: ignore vlm_output = vlm.sample_completions(vlm_query_str, vlm_input_imgs, 0.0, @@ -2510,9 +3012,16 @@ def abstract(state: State, """ # Start by pulling out all VLM predicates. vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate)) + derived_preds, primitive_preds = set(), set() + for pred in preds: + if isinstance(pred, DerivedPredicate): + derived_preds.add(pred) + else: + primitive_preds.add(pred) + # Next, classify all non-VLM predicates. atoms = set() - for pred in preds: + for pred in primitive_preds: if pred not in vlm_preds: for choice in get_object_combinations(list(state), pred.types): if pred.holds(state, choice): @@ -2526,6 +3035,20 @@ def abstract(state: State, vlm_atoms.add(GroundAtom(pred, choice)) true_vlm_atoms = query_vlm_for_atom_vals(vlm_atoms, state, vlm) atoms |= true_vlm_atoms + + # Evaluate derived predicates. 
+    if len(derived_preds) > 0:
+        try:
+            atoms |= abstract_with_derived_predicates(atoms, derived_preds,
+                                                      list(state))
+        except PredicateEvaluationError as e:
+            raise e
+            # buggy_pred = e.pred
+            # # logging.debug(f"preds before {buggy_pred} is removed: {preds}")
+            # cnpt_preds.remove(buggy_pred)
+            # # logging.debug(f"preds after {buggy_pred} is removed: {preds}")
+            # return abstract(state, prim_preds | cnpt_preds, vlm,
+            #                 return_valid_preds)
     return atoms
 
 
@@ -2561,12 +3084,17 @@ def all_ground_operators_given_partial(
             yield ground_op
 
 
-def all_ground_nsrts(nsrt: NSRT,
+def all_ground_nsrts(nsrt: Union[NSRT, CausalProcess],
                      objects: Collection[Object]) -> Iterator[_GroundNSRT]:
     """Get all possible groundings of the given NSRT with the given objects."""
     types = [p.type for p in nsrt.parameters]
     for choice in get_object_combinations(objects, types):
-        yield nsrt.ground(tuple(choice))
+        # only return if there are no repeated arguments
+        if CFG.no_repeated_arguments_in_grounding:
+            if len(choice) == len(set(choice)):
+                yield nsrt.ground(tuple(choice))  # type: ignore[misc]
+        else:
+            yield nsrt.ground(tuple(choice))  # type: ignore[misc]
 
 
 def all_ground_nsrts_fd_translator(
@@ -2865,6 +3393,24 @@ def create_ground_atom_dataset(
     return ground_atom_dataset
 
 
+def create_ground_atom_option_dataset(
+        trajectories: List[LowLevelTrajectory],
+        predicates: Set[Predicate]) -> List[AtomOptionTrajectory]:
+    """Apply all predicates to all trajectories in the dataset and also
+    annotate with options (HLA)."""
+    ground_atom_option_dataset = []
+    for traj in trajectories:
+        # TODO: this is currently just based on the current states. We would
+        # probably want to extend this to state history.
+ atoms = [abstract(s, predicates) for s in traj.states] + options = [a.get_option() for a in traj.actions] + ground_atom_option_dataset.append( + AtomOptionTrajectory( + traj.states, atoms, options, traj.is_demo, + traj.train_task_idx if traj.is_demo else None)) + return ground_atom_option_dataset + + def prune_ground_atom_dataset( ground_atom_dataset: List[GroundAtomTrajectory], kept_predicates: Collection[Predicate]) -> List[GroundAtomTrajectory]: @@ -3017,14 +3563,21 @@ def get_reachable_atoms(ground_ops: Collection[GroundNSRTOrSTRIPSOperator], def get_applicable_operators( - ground_ops: Collection[GroundNSRTOrSTRIPSOperator], - atoms: Collection[GroundAtom]) -> Iterator[GroundNSRTOrSTRIPSOperator]: + ground_ops: Collection[Union[GroundNSRTOrSTRIPSOperator, + _GroundEndogenousProcess]], + atoms: Collection[GroundAtom] +) -> Iterator[Union[GroundNSRTOrSTRIPSOperator, _GroundEndogenousProcess]]: """Iterate over ground operators whose preconditions are satisfied. Note: the order may be nondeterministic. Users should be invariant. 
""" for op in ground_ops: - applicable = op.preconditions.issubset(atoms) + if isinstance(op, _GroundNSRT) or isinstance(op, + _GroundSTRIPSOperator): + applicable = op.preconditions.issubset(atoms) + elif isinstance(op, _GroundEndogenousProcess): + applicable = op.condition_at_start.issubset(atoms) + if applicable: yield op @@ -3070,7 +3623,7 @@ def get_successors_from_ground_ops( """ seen_successors = set() for ground_op in get_applicable_operators(ground_ops, atoms): - next_atoms = apply_operator(ground_op, atoms) + next_atoms = apply_operator(ground_op, atoms) # type: ignore[type-var] if unique: frozen_next_atoms = frozenset(next_atoms) if frozen_next_atoms in seen_successors: @@ -3461,14 +4014,21 @@ def create_video_from_partial_refinements( _, plan = max(partial_refinements, key=lambda x: len(x[1])) policy = option_plan_to_policy(plan) video: Video = [] + logging.debug("reset env for create video") state = env.reset(train_or_test, task_idx) - for _ in range(max_num_steps): + # logging.debug(f"{pformat(state.pretty_str())}") + for i in range(max_num_steps): + # logging.debug(f"state: {state.pretty_str()}") try: act = policy(state) + # logging.debug(f"act: {act}") except OptionExecutionFailure: video.extend(env.render()) - break - video.extend(env.render(act)) + if not CFG.video_not_break_on_exception: + break + else: + video.extend(env.render(act)) + # logging.debug("Finished rendering.") try: state = env.step(act) except EnvironmentFailure: @@ -3493,22 +4053,37 @@ def save_video(outfile: str, video: Video) -> None: outdir = CFG.video_dir os.makedirs(outdir, exist_ok=True) outpath = os.path.join(outdir, outfile) - imageio.mimwrite(outpath, video, fps=CFG.video_fps) # type: ignore + video_uint8 = [np.array(frame).astype(np.uint8) for frame in video] + imageio.mimwrite(outpath, video_uint8, fps=CFG.video_fps) # type: ignore logging.info(f"Wrote out to {outpath}") -def save_images(outfile_prefix: str, video: Video) -> None: - """Save the video as individual 
images to image_dir.""" +def save_images_parallel(outfile_prefix: str, video: Video) -> None: + """Save the video as individual images in parallel.""" outdir = CFG.image_dir + outdir = os.path.join(outdir, os.path.dirname(outfile_prefix)) + outfile_prefix = os.path.basename(outfile_prefix) + os.makedirs(outdir, exist_ok=True) width = len(str(len(video))) - for i, image in enumerate(video): + + def _write_frame(i: int, image: Any) -> None: image_number = str(i).zfill(width) outfile = outfile_prefix + f"_image_{image_number}.png" outpath = os.path.join(outdir, outfile) - imageio.imwrite(outpath, image) + image_array = np.array(image) + imageio.imwrite(outpath, image_array.astype(np.uint8)) logging.info(f"Wrote out to {outpath}") + with ThreadPoolExecutor() as executor: + for i, frame in enumerate(video): + executor.submit(_write_frame, i, frame) + + +def save_images(outfile_prefix: str, video: Video) -> None: + """Save the video as individual images to image_dir.""" + return save_images_parallel(outfile_prefix, video) + def get_env_asset_path(asset_name: str, assert_exists: bool = True) -> str: """Return the absolute path to env asset.""" @@ -3628,8 +4203,13 @@ def get_config_path_str(experiment_id: Optional[str] = None) -> str: """ if experiment_id is None: experiment_id = CFG.experiment_id - return (f"{CFG.env}__{CFG.approach}__{CFG.seed}__{CFG.excluded_predicates}" - f"__{CFG.included_options}__{experiment_id}") + if CFG.use_counterfactual_dataset_path_name: + return (f"{CFG.env}__{CFG.seed}__{CFG.experiment_id}__query") + else: + return ( + f"{CFG.env}__{CFG.approach}__{CFG.seed}__" + f"{CFG.excluded_predicates}__{CFG.included_options}__{experiment_id}" + ) def get_approach_save_path_str() -> str: @@ -3769,6 +4349,187 @@ def null_sampler(state: State, goal: Set[GroundAtom], rng: np.random.Generator, return np.array([], dtype=np.float32) # no continuous parameters +class ConstantDelay(DelayDistribution): + + def __init__(self, delay: Union[int, float, 
torch.Tensor]): + # keep dtype consistent with the rest of the model + self.delay = torch.as_tensor(delay, dtype=torch.get_default_dtype()) + # reusable – matches self.delay’s dtype/device + self._neg_inf = torch.tensor(float("-inf"), + dtype=self.delay.dtype, + device=self.delay.device) + + def copy(self) -> ConstantDelay: + """Return a copy of this distribution.""" + return ConstantDelay(self.delay.clone()) + + def sample(self) -> int: + return int(self.delay.item()) + + def set_parameters(self, parameters: Sequence[torch.Tensor], + **kwargs: Any) -> None: + self.delay = parameters[0] + # Invalidate cached properties + self.__dict__.pop("_str", None) + self.__dict__.pop("_hash", None) + + def get_parameters(self) -> Sequence[float]: + """Return the parameters of the distribution.""" + return [self.delay.item()] + + def probability(self, k: int) -> float: + return 1.0 if k == int(self.delay.item()) else 0.0 + + def log_prob(self, k: Union[int, torch.Tensor]) -> torch.Tensor: + """Vectorised log-prob; differentiable w.r.t. + + self.delay. + """ + if not isinstance(k, torch.Tensor): + k_tensor = torch.tensor(k, + dtype=torch.long, + device=self.delay.device) + else: + k_tensor = k.long().to(self.delay.device) + + zeros = torch.zeros_like(k_tensor, dtype=torch.get_default_dtype()) + neg_inf = torch.full_like(k_tensor, + float("-inf"), + dtype=torch.get_default_dtype()) + return torch.where(k_tensor == self.delay.long(), zeros, neg_inf) + + @cached_property + def _str(self) -> str: + return f"ConstantDelay({self.delay:.4f})" + + +class DiscreteGaussianDelay(DelayDistribution): + r"""Truncated discrete Gaussian distribution (a.k.a. Discrete Normal). + + Parameters + ---------- + mu : float or Tensor + Location parameter (can be any real number). + sigma : float or Tensor + Scale (> 0). Smaller values → tighter mass around ``mu``. + max_k : int, optional + Build / cache the PMF on the support k = 0 … max_k-1 (default 300). 
+ """ + + def __init__(self, + mu: torch.Tensor, + sigma: torch.Tensor, + max_k: int = 300) -> None: + if not torch.all(sigma > 0): + raise ValueError("Initial sigma must be positive.") + + self.log_mu = torch.log(mu) + self.log_sigma = torch.log(sigma) + self._max_k = max_k + self._update_cache() + + def copy(self) -> DiscreteGaussianDelay: + """Return a copy of this distribution.""" + return DiscreteGaussianDelay(self.mu.clone(), self.sigma.clone(), + self._max_k) + + @property + def sigma(self) -> torch.Tensor: + """The actual standard deviation, derived from the optimized + log_sigma.""" + return torch.exp(self.log_sigma) + + @property + def mu(self) -> torch.Tensor: + """The mean of the discrete Gaussian.""" + return torch.exp(self.log_mu) + + # ------------------------------------------------------------------ # + # Internals + # ------------------------------------------------------------------ # + def _update_cache(self) -> None: + """Rebuild cached log-PMF / PMF / CDF using safe numerics.""" + EPS = 1e-8 + + mu = self.mu + sigma_val = self.sigma + sigma = torch.clamp(sigma_val, min=EPS) # ensure positivity + if not torch.all(sigma > 0): + raise ValueError("Initial sigma must be positive.") + + assert isinstance(self._max_k, int) + ks = torch.arange(self._max_k, dtype=mu.dtype, + device=mu.device) # k = 0 … max_k-1 + + # Unnormalised log-probability of a discrete Gaussian + # p̃(k) = exp( −(k−μ)² / (2σ²) ) + # Work in log-space for stability: + log_p_unnorm = -0.5 * ((ks - mu)**2) / (sigma**2) + + # Remove any accidental NaNs / ±Inf + log_p_unnorm = torch.nan_to_num(log_p_unnorm, + nan=-torch.inf, + posinf=-torch.inf, + neginf=-torch.inf) + + # Normalise on the bounded support 0 … max_k-1 + log_norm = torch.logsumexp(log_p_unnorm, dim=0) + self._log_pmf = log_p_unnorm - log_norm + + self._pmf = self._log_pmf.exp() + self._cdf = torch.cumsum(self._pmf, dim=0) + + # ------------------------------------------------------------------ # + # Public interface 
(identical to DoublePoissonDelay) + # ------------------------------------------------------------------ # + def set_parameters(self, parameters: Sequence[torch.Tensor], + **kwargs: Any) -> None: + self.log_mu, self.log_sigma = parameters + if "max_k" in kwargs and kwargs["max_k"] is not None: + self._max_k = kwargs["max_k"] + self._update_cache() + # Invalidate cached repr/hash if present + self.__dict__.pop('_str', None) + self.__dict__.pop('_hash', None) + + def get_parameters(self) -> Sequence[float]: + """Return the parameters of the distribution.""" + return [self.mu.item(), self.sigma.item()] + + def probability(self, k: int) -> float: + if 0 <= k < self._max_k: + return float(self._pmf[k]) + return 0.0 + + def log_prob(self, k: Union[int, torch.Tensor]) -> torch.Tensor: + if not isinstance(k, torch.Tensor): + k_tensor = torch.tensor(k, dtype=torch.long) + else: + k_tensor = k.long() + + k_flat = k_tensor.flatten() + log_probs_flat = torch.full_like(k_flat, + float('-inf'), + dtype=self._log_pmf.dtype) + + mask = (k_flat >= 0) & (k_flat < self._max_k) + if mask.any(): + log_probs_flat[mask] = self._log_pmf[k_flat[mask]] + + return log_probs_flat.reshape(k_tensor.shape) + + def sample(self, sample_mode: bool = True) -> int: + if sample_mode: + return int(self.mu.item()) + else: + u = torch.rand(1).item() + return int(torch.searchsorted(self._cdf, torch.tensor(u))) + + @cached_property + def _str(self) -> str: + return f"DiscreteGaussianDelay({self.mu:.4f}, {self.sigma:.4f})" + + @functools.lru_cache(maxsize=None) def get_git_commit_hash() -> str: """Return the hash of the current git commit.""" @@ -4085,3 +4846,314 @@ def add_text_to_draw_img( # Add the text to the image draw.text(position, text, fill="red", font=font) return draw + + +def wrap_angle(angle: float) -> float: + """Wrap an angle in radians to [-pi, pi].""" + return np.arctan2(np.sin(angle), np.cos(angle)) + + +def get_parameterized_option_by_name( + options: Set[ParameterizedOption], + 
option_name: str) -> Optional[ParameterizedOption]: + """Retrieve an option by its name from a set of options.""" + return next((option for option in options if option.name == option_name), + None) + + +def get_object_by_name(objects: Collection[Object], + name: str) -> Optional[Object]: + """Get an object by its name from a collection of objects. + + Args: + objects: Collection of objects to search through + name: Name of the object to find + + Returns: + The object if found, None otherwise + """ + return next((obj for obj in objects if obj.name == name), None) + + +def configure_logging() -> None: + # Create a single formatter instance to be reused + colored_formatter = colorlog.ColoredFormatter( + '%(log_color)s%(levelname)s: %(message)s', + log_colors={ + 'DEBUG': 'cyan', + 'INFO': 'green', + 'WARNING': 'yellow', + 'ERROR': 'red', + 'CRITICAL': 'red,bg_white', + }, + reset=True, + style='%') + # Log to stderr. + colorlog_handler = colorlog.StreamHandler() + colorlog_handler.setFormatter(colored_formatter) + handlers: List[logging.Handler] = [colorlog_handler] + if CFG.log_file: + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + CFG.log_file += (f"{CFG.approach}/{CFG.experiment_id}/" + f"seed{CFG.seed}/run_{timestamp}/") + os.makedirs(CFG.log_file, exist_ok=True) + + # Handler for DEBUG level messages + debug_handler = logging.FileHandler(os.path.join( + CFG.log_file, "debug.log"), + mode='w') + debug_handler.setLevel(logging.DEBUG) + debug_handler.setFormatter(colored_formatter) + handlers.append(debug_handler) + + # Handler for INFO level messages + info_handler = logging.FileHandler(os.path.join( + CFG.log_file, "info.log"), + mode='w') + info_handler.setLevel(logging.INFO) + info_handler.setFormatter(colored_formatter) + handlers.append(info_handler) + + logging.basicConfig(level=CFG.loglevel, + format="%(message)s", + handlers=handlers, + force=True) + logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR) + 
logging.getLogger('libpng').setLevel(logging.ERROR) + logging.getLogger('PIL').setLevel(logging.ERROR) + logging.getLogger('openai').setLevel(logging.INFO) + # Used by openai package + logging.getLogger("httpx").setLevel(logging.INFO) + logging.getLogger("httpcore").setLevel(logging.INFO) + + +def log_initial_info(str_args: str) -> None: + """Log initial configuration and setup information.""" + if CFG.log_file: + logging.info(f"Logging to {CFG.log_file}") + logging.info(f"Running command: python {str_args}") + logging.info("Full config:") + logging.info(CFG) + logging.info(f"Git commit hash: {get_git_commit_hash()}") + + +def add_label_to_video(video: Video, + prefix: str, + imgs_dir: str, + save: bool = True) -> Video: + """Add a label to each frame of the video and save the images.""" + os.makedirs(imgs_dir, exist_ok=True) + new_video: Video = [] + for i, img in enumerate(video): + img_name = prefix + f"frame_{i+1}" + labeled_img = add_label_to_image(img, img_name, imgs_dir, + save=save) # type: ignore[arg-type] + new_video.append(labeled_img) # type: ignore[arg-type] + return new_video + + +def add_label_to_image(img: PIL.Image.Image, + s_name: str, + obs_dir: str, + f_suffix: str = ".png", + save: bool = True) -> PIL.Image.Image: + """Add a label to an image and potentially save.""" + img_copy = img.copy() + draw = ImageDraw.Draw(img_copy) + font = ImageFont.load_default().font_variant( + size=50) # type: ignore[union-attr] + + # Get text dimensions + bbox = draw.textbbox((0, 0), s_name, font=font) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] + + # Calculate position (bottom right with padding) + padding = 10 + x = img_copy.width - text_width - padding + y = img_copy.height - text_height - 2 * padding + + text_color = (0, 0, 0) # black + draw.text((x, y), s_name, fill=text_color, font=font) + + if save: + os.makedirs(obs_dir, exist_ok=True) + img_copy.save(os.path.join(obs_dir, s_name + f_suffix)) + logging.debug(f"Saved Image {s_name}") 
+    return img_copy
+
+
+def load_all_images_from_dir(dir_path: str) -> List[PIL.Image.Image]:
+    """Load all images from a directory."""
+    images = []
+    img_paths = sorted(os.listdir(dir_path))
+    for file in img_paths:
+        if file.endswith(('.png', '.jpg')):
+            images.append(PIL.Image.open(os.path.join(dir_path, file)))
+    return images
+
+
+def all_subsets(input_set: Iterable[Any]) -> Iterator[Set[Any]]:
+    """Generates all subsets of a given set.
+
+    Args:
+        input_set: An iterable (e.g., a list, set, tuple)
+            from which to generate subsets.
+
+    Yields:
+        set: Each subset as a set.
+    """
+    s = list(input_set)  # Convert to list to handle various iterable inputs
+    n = len(s)
+    for i in range(n + 1):  # Iterate from subset size 0 up to n
+        for subset in itertools.combinations(s, i):
+            yield set(subset)
+
+
+def add_in_auxiliary_predicates(predicates: Set[Predicate]) -> Set[Predicate]:
+    # If a predicate is a derived predicate, check its auxiliary predicates
+    # attribute, and add them and all their derived predicates to the set
+    # recursively.
+    def add_auxiliary(pred: Predicate, preds: Set[Predicate]) -> None:
+        if isinstance(pred, DerivedPredicate):
+            if pred.auxiliary_predicates:
+                preds.update(pred.auxiliary_predicates)
+                for aux_pred in pred.auxiliary_predicates:
+                    add_auxiliary(aux_pred, preds)
+
+    new_preds = predicates.copy()
+    for pred in predicates:
+        add_auxiliary(pred, new_preds)
+    return new_preds
+
+
+def get_derived_predicates(
+        predicates: Set[Predicate]) -> Set[DerivedPredicate]:
+    """Get all derived predicates from a set of predicates."""
+    return {pred for pred in predicates if isinstance(pred, DerivedPredicate)}
+
+
+# def abstract_with_derived_predicates(atoms, derived_preds, objects):
+#     """Compute all derived atoms via layered evaluation (fewer passes).
+# Potentially faster than the current implementation.""" +# # Build dependency graph over derived preds +# is_derived = {p for p in derived_preds} +# indeg = {p: 0 for p in derived_preds} +# edges = {p: set() for p in derived_preds} +# for p in derived_preds: +# for aux in getattr(p, "auxiliary_predicates", []): +# # only count deps on other derived preds +# q = next((dp for dp in derived_preds if dp.name == aux.name), None) +# if q: +# edges[q].add(p); indeg[p] += 1 + +# # Kahn’s algorithm => layers +# frontier = [p for p in derived_preds if indeg[p] == 0] +# layers: list[list] = [] +# while frontier: +# layer = list(frontier); layers.append(layer); frontier = [] +# for u in layer: +# for v in edges[u]: +# indeg[v] -= 1 +# if indeg[v] == 0: +# frontier.append(v) + +# # Evaluate per layer; state grows monotonically +# state = set(atoms) +# derived_all = set() +# # (Optional) cache object choices per predicate once +# by_type = {} +# for o in objects: +# by_type.setdefault(o.type, []).append(o) +# choices_cache = { +# p: list(itertools.product(*(by_type[t] for t in p.types))) +# for p in derived_preds +# } + +# for layer in layers: +# for p in layer: +# for choice in choices_cache[p]: +# if p.holds(state, choice): +# derived_all.add(GroundAtom(p, choice)) +# state |= derived_all # grow state for next layer + +# return derived_all + + +def abstract_with_derived_predicates( + atoms: Set[GroundAtom], derived_preds: Collection[DerivedPredicate], + objects: Collection[Object]) -> Set[GroundAtom]: + """Compute the fixed point of concept predicate atoms.""" + primitive_atoms = atoms + new_concept_atoms: Set[GroundAtom] = set() + prev_new_concept_atoms: Set[GroundAtom] = set() + counter = 0 + while True: + # All the concept atoms that holds; all the previous atoms + atoms = primitive_atoms | new_concept_atoms + new_concept_atoms = _abstract_with_derived_predicates( + atoms, derived_preds, objects) + # logging.debug(f"ite {counter} concept atoms: {new_concept_atoms}") + 
converged = new_concept_atoms == prev_new_concept_atoms + if converged: + # logging.debug("converged") + break + prev_new_concept_atoms = new_concept_atoms + counter += 1 + return new_concept_atoms + + +def _abstract_with_derived_predicates( + abs_state: Set[GroundAtom], + derived_preds: Collection[DerivedPredicate], + objects: Collection[Object]) -> Set[GroundAtom]: + """Get the atoms based on the existing atomic state and concept + predicates.""" + atoms: Set[GroundAtom] = set() + for pred in derived_preds: + for choice in get_object_combinations(objects, pred.types): + try: + if pred.holds(abs_state, choice): + atoms.add(GroundAtom(pred, choice)) + except Exception as e: + logging.error(f"Error in evaluating concept predicate {pred}: " + f"{e}") + # raise e + raise PredicateEvaluationError( + f"Error in evaluating concept predicate {pred}: {e}", pred) + return atoms + + +def get_base_supporter_predicates( + root_predicate: DerivedPredicate) -> Set[Predicate]: + """Finds all primitive (non-derived) supporter predicates for a given root + derived predicate by traversing its dependency graph.""" + base_predicates: Set[Predicate] = set() + + # Use a worklist to process predicates in a breadth-first manner. + predicates_to_process: List[Predicate] = list( + root_predicate.auxiliary_predicates or []) + processed_predicates: Set[Predicate] = {root_predicate} + + while predicates_to_process: + pred = predicates_to_process.pop(0) + + if pred in processed_predicates: + continue + processed_predicates.add(pred) + + # If the predicate is derived, add its auxiliaries to the worklist. + if isinstance(pred, DerivedPredicate): + predicates_to_process.extend(pred.auxiliary_predicates or []) + # If it's a primitive predicate, we've found a base supporter. 
+ else: + base_predicates.add(pred) + + return base_predicates + + +class PredicateEvaluationError(Exception): + + def __init__(self, message: str, pred: Any) -> None: + super().__init__(message) + self.pred = pred diff --git a/setup.py b/setup.py index 12326074f1..010bffe90f 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,8 @@ "imageio==2.22.2", "imageio-ffmpeg", "pandas==1.5.1", - "torch==2.0.1", + "torch>=2.2.0", + "torchvision>=0.17.0", "scipy==1.9.3", "tabulate==0.9.0", "dill==0.3.5.1", @@ -24,7 +25,7 @@ "requests", "slack_bolt", "pybullet>=3.2.0", - "scikit-learn==1.1.2", + "scikit-learn>=1.1.3", "graphlib-backport", "openai==1.19.0", "pyyaml==6.0", @@ -38,7 +39,9 @@ "ImageHash", "google-generativeai", "tenacity", - "httpx==0.27.0" + "httpx==0.27.0", + "opencv-python>=4.5.0", + "colorlog", ], include_package_data=True, extras_require={