From 7dbca12fcab491566d22e7e1d0c655500f2dc26c Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Wed, 28 Feb 2024 15:11:06 -0700 Subject: [PATCH 01/31] first throw at refactoring SamplingIterator --- src/gflownet/algo/trajectory_balance.py | 4 + src/gflownet/data/config.py | 6 + src/gflownet/data/data_source.py | 268 ++++++++++++++++++++++++ src/gflownet/data/sampling_iterator.py | 49 ++++- src/gflownet/tasks/seh_frag.py | 2 +- src/gflownet/trainer.py | 66 +++--- src/gflownet/utils/misc.py | 21 ++ 7 files changed, 390 insertions(+), 26 deletions(-) create mode 100644 src/gflownet/data/data_source.py diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 75e5471f..308983ab 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -125,6 +125,7 @@ def __init__( # instead give "ABC...Z" as a single input, but grab the logits at every timestep. Only works if using something # like a transformer with causal self-attention. self.model_is_autoregressive = False + self.random_action_prob = [cfg.algo.train_random_action_prob, cfg.algo.valid_random_action_prob] self.graph_sampler = GraphSampler( ctx, @@ -140,6 +141,9 @@ def __init__( self._subtb_max_len = self.global_cfg.algo.max_len + 2 self._init_subtb(torch.device("cuda")) # TODO: where are we getting device info? + def set_is_eval(self, is_eval: bool): + self.is_eval = is_eval + def create_training_data_from_own_samples( self, model: TrajectoryBalanceModel, n: int, cond_info: Tensor, random_action_prob: float ): diff --git a/src/gflownet/data/config.py b/src/gflownet/data/config.py index fab5d036..5c5a9c84 100644 --- a/src/gflownet/data/config.py +++ b/src/gflownet/data/config.py @@ -16,9 +16,15 @@ class ReplayConfig: The number of samples to collect before starting to sample from the replay buffer hindsight_ratio : float The ratio of hindsight samples within a batch + batch_size : Optional[int] + The batch size for sampling from the replay buffer, defaults to the online batch size + replaces_online_data : bool + Whether to replace online data with samples from the replay buffer """ use: bool = False capacity: Optional[int] = None warmup: Optional[int] = None hindsight_ratio: float = 0 + batch_size: Optional[int] = None + replaces_online_data: bool = True diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py new file mode 100644 index 00000000..e6354e40 --- /dev/null +++ b/src/gflownet/data/data_source.py @@ -0,0 +1,268 @@ +import numpy as np +import torch +from gflownet.data.replay_buffer import ReplayBuffer +from typing import Any, Callable, Dict, List, NewType, Optional, Protocol, Tuple, Generator +from torch.utils.data import Dataset, IterableDataset + +from gflownet.config import Config +from gflownet.utils.misc import get_worker_rng +from gflownet.envs.graph_building_env import GraphBuildingEnvContext +#from gflownet.trainer import GFNAlgorithm, GFNTask + + +def cycle_call(it): + while True: + for i in it(): + yield i + + +class DataSource(IterableDataset): + def __init__( + self, + cfg: Config, + ctx: GraphBuildingEnvContext, + algo, #: GFNAlgorithm, + task, #: GFNTask, # TODO: this will cause a circular import + dev: torch.device, + replay_buffer: Optional[ReplayBuffer] = None, + is_algo_eval: bool = False, + start_at_step: int = 0, + ): + """A DataSource mixes multiple iterators into one. These are created with do_* methods.""" + self.iterators: List[Generator] = [] + self.cfg = cfg + self.ctx = ctx + self.algo = algo + self.task = task + self.dev = dev + self.replay_buffer = replay_buffer + self.is_algo_eval = is_algo_eval + + self.global_step_count = torch.zeros(1, dtype=torch.int64) + start_at_step + self.global_step_count.share_memory_() + self.global_step_count_lock = torch.multiprocessing.Lock() + self.current_iter = start_at_step + self.sampling_hooks: List[Callable] = [] + self.active = True + + def add_sampling_hook(self, hook: Callable): + """Add a hook that is called when sampling new trajectories. + + The hook should take a list of trajectories as input. + The hook will not be called on trajectories that are sampled from the replay buffer or dataset. + """ + self.sampling_hooks.append(hook) + return self + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + self._wid = worker_info.id if worker_info is not None else 0 + self.rng = get_worker_rng() + its = [i() for i in self.iterators] + self.algo.set_is_eval(self.is_algo_eval) + while True: + with self.global_step_count_lock: + self.current_iter = self.global_step_count.item() + self.global_step_count += 1 + iterator_outputs = [next(i, None) for i in its] + if any(i is None for i in iterator_outputs): + if not all(i is None for i in iterator_outputs): + raise ValueError("Some iterators are done, but not all. You may be mixing incompatible iterators.") + break + traj_lists, batch_infos = zip(*iterator_outputs) + trajs = sum(traj_lists, []) + # Merge all the dicts into one + batch_info = {} + for d in batch_infos: + batch_info.update(d) + yield self.create_batch(trajs, batch_info) + + def do_sample_model(self, model, num_samples, keep_samples_in_batch=True): + if not keep_samples_in_batch: + assert self.replay_buffer is not None, "Throwing away samples without a replay buffer" + + def iterator(): + while self.active: + t = self.current_iter + p = self.algo.get_random_action_prob(t) + cond_info = self.task.sample_cond_info(num_samples, t) + trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info, p) + self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs + self.compute_properties(trajs, mark_as_online=True) + self.compute_log_rewards(trajs) + self.send_to_replay(trajs) # This is a no-op if there is no replay buffer + batch_info = self.call_sampling_hooks(trajs) + yield (trajs, batch_info) if keep_samples_in_batch else ([], {}) + + self.iterators.append(iterator) + return self + + def do_sample_replay(self, num_samples): + def iterator(): + while self.active: + trajs = self.replay_buffer.sample(num_samples) + self.relabel_in_hindsight(trajs) # This is a no-op if the hindsight ratio is 0 + yield trajs, {} + show_type(iterator) + self.iterators.append(iterator) + return self + + def do_dataset_in_order(self, data, num_samples, backwards_model): + def iterator(): + for idcs in self.iterate_indices(num_samples): + t = self.current_iter + p = self.algo.get_random_action_prob(t) + cond_info = self.task.sample_conditional_information(num_samples, t) + objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) + trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info, p) + self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs + self.set_traj_props(trajs, props) + self.compute_log_rewards(trajs) + yield trajs, {} + + self.iterators.append(iterator) + return self + + def do_conditionals_dataset_in_order(self, data, num_samples, model): + def iterator(): + for idcs in self.iterate_indices(len(data), num_samples): + t = self.current_iter + p = self.algo.get_random_action_prob(t) + cond_info = torch.stack([data[i] for i in idcs]) + trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info, p) + self.compute_properties(trajs, mark_as_online=True) + self.compute_log_rewards(trajs) + self.send_to_replay(trajs) # This is a no-op if there is no replay buffer + batch_info = self.call_sampling_hooks(trajs) + yield trajs, batch_info + + self.iterators.append(iterator) + return self + + def do_sample_dataset(self, data, num_samples, backwards_model): + def iterator(): + while self.active: + idcs = self.sample_idcs(len(data), num_samples) + t = self.current_iter + p = self.algo.get_random_action_prob(t) + cond_info = self.task.sample_conditional_information(num_samples, t) + objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) + trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info, p) + self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs + self.set_traj_props(trajs, props) + self.compute_log_rewards(trajs) + yield trajs, {} + + self.iterators.append(iterator) + return self + + def call_sampling_hooks(self, trajs): + batch_info = {} + # TODO: just pass trajs to the hooks and deprecate passing all those arguments + flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) + # convert cond_info back to a dict + cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs["cond_info"][0]} + log_rewards = torch.stack([t["log_reward"] for t in trajs]) + for hook in self.sampling_hooks: + batch_info.update(hook(trajs, log_rewards, flat_rewards, cond_info)) + + def create_batch(self, trajs, batch_info): + ci = torch.stack([t["cond_info"]["encoding"] for t in trajs]) + log_rewards = torch.stack([t["log_reward"] for t in trajs]) + batch = self.algo.construct_batch(trajs, ci, log_rewards) + batch.num_online = sum(t["is_online"] for t in trajs) + batch.num_offline = len(trajs) - batch.num_online + batch.extra_info = batch_info + batch.preferences = torch.stack([t["preference"] for t in trajs]) + batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs]) + + if self.ctx.has_n(): # Does this go somewhere else? Require a flag? Might not be cheap to compute + log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] + batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) + batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) + # TODO: find code that depends on batch.flat_rewards and deprecate it + return batch + + def compute_properties(self, trajs, mark_as_online=False): + """Sets trajs' flat_rewards and is_valid keys by querying the task.""" + # TODO: refactor flat_rewards into properties + valid_idcs = torch.tensor([i for i in range(len(trajs)) if trajs[i].get("is_valid", True)]).long() + # fetch the valid trajectories endpoints + objs = [self.ctx.graph_to_mol(trajs[i]["result"]) for i in valid_idcs] + # ask the task to compute their reward + # TODO: it's really weird that the task is responsible for this and returns a flat_rewards + # tensor whose first dimension is possibly not the same as the output??? + flat_rewards, m_is_valid = self.task.compute_flat_rewards(objs) + assert flat_rewards.ndim == 2, "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" + # The task may decide some of the objs are invalid, we have to again filter those + valid_idcs = valid_idcs[m_is_valid] + all_fr = torch.zeros((len(trajs), flat_rewards.shape[1])) + all_fr[valid_idcs] = flat_rewards + for i in range(len(trajs)): + trajs[i]["flat_rewards"] = all_fr[i] + trajs[i]["is_online"] = mark_as_online + # Override the is_valid key in case the task made some objs invalid + for i in valid_idcs: + trajs[i]["is_valid"] = True + + def compute_log_rewards(self, trajs): + """Sets trajs' log_reward key by querying the task.""" + flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) + cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} + log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) + for i in range(len(trajs)): + trajs[i]["log_reward"] = log_rewards[i] if trajs[i]["is_valid"] else self.cfg.algo.illegal_action_logreward + + def send_to_replay(self, trajs): + if self.replay_buffer is not None: + for t in trajs: + self.replay_buffer.push(t, t["log_rewards"], t["flat_rewards"], t["cond_info"], t["is_valid"]) + + def set_traj_cond_info(self, trajs, cond_info): + for i in range(len(trajs)): + trajs[i]["cond_info"] = {k: cond_info[k][i] for k in cond_info} + + def set_traj_props(self, trajs, props): + for i in range(len(trajs)): + trajs[i]["flat_rewards"] = props[i] # TODO: refactor + + def relabel_in_hindsight(self, trajs): + if self.cfg.replay.hindsight_ratio == 0: + return + assert hasattr( + self.task, "relabel_condinfo_and_logrewards" + ), "Hindsight requires the task to implement relabel_condinfo_and_logrewards" + # samples indexes of trajectories without repeats + hindsight_idxs = torch.randperm(len(trajs))[: int(len(trajs) * self.cfg.replay.hindsight_ratio)] + log_rewards = torch.stack([t["log_reward"] for t in trajs]) + flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) + cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( + cond_info, log_rewards, flat_rewards, hindsight_idxs + ) + # TODO: This seems wrong, since we haven't recomputed is_valid + # log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward + + def sample_idcs(self, n, num_samples): + return self.rng.choice(n, num_samples, replace=False) + + def iterate_indices(self, n, num_samples): + worker_info = torch.utils.data.get_worker_info() + if n == 0: + # Should we be raising an error here? warning? + yield np.arange(0, 0) + return + + if worker_info is None: # no multi-processing + start, end, wid = 0, n, -1 + else: # split the data into chunks (per-worker) + nw = worker_info.num_workers + wid = worker_info.id + start, end = int(np.round(n / nw * wid)), int(np.round(n / nw * (wid + 1))) + + if end - start <= num_samples: + yield np.arange(start, end) + return + for i in range(start, end - num_samples, num_samples): + yield np.arange(i, i + num_samples) + if i + num_samples < end: + yield np.arange(i + num_samples, end) \ No newline at end of file diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index 9795467e..cf15f66c 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -198,7 +198,7 @@ def __iter__(self): num_online = num_offline num_offline = 0 cond_info = self.task.encode_conditional_information( - steer_info=torch.stack([self.data[i] for i in idcs]) + steer_info=torch.stack([self.data[i] for i in idcs]) # This is sus, what's going on here? ) trajs, flat_rewards = [], [] @@ -400,6 +400,53 @@ def log_generated(self, trajs, rewards, flat_rewards, cond_info): self.log.insert_many(data, data_labels) +class SQLiteLogHook: + def __init__(self, log_dir, ctx) -> None: + self.log = None # Only initialized in __call__, which will occur inside the worker + self.log_dir = log_dir + self.ctx = ctx + self.data_labels = None + + def __call__(self, trajs, rewards, flat_rewards, cond_info): + if self.log is None: + worker_info = torch.utils.data.get_worker_info() + self._wid = worker_info.id if worker_info is not None else 0 + os.makedirs(self.log_dir, exist_ok=True) + self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" + self.log = SQLiteLog() + self.log.connect(self.log_path) + + if hasattr(self.ctx, "object_to_log_repr"): + mols = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] + else: + mols = [""] * len(trajs) + + flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() + rewards = rewards.data.numpy().tolist() + preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() + focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() + logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] + + data = [ + [mols[i], rewards[i]] + + flat_rewards[i] + + preferences[i] + + focus_dir[i] + + [cond_info[k][i].item() for k in logged_keys] + for i in range(len(trajs)) + ] + if self.data_labels is None: + self.data_labels = ( + ["smi", "r"] + + [f"fr_{i}" for i in range(len(flat_rewards[0]))] + + [f"pref_{i}" for i in range(len(preferences[0]))] + + [f"focus_{i}" for i in range(len(focus_dir[0]))] + + [f"ci_{k}" for k in logged_keys] + ) + + self.log.insert_many(data, self.data_labels) + + class SQLiteLog: def __init__(self, timeout=300): """Creates a log instance, but does not connect it to any db.""" diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 91d65818..fcce6ca6 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -206,7 +206,7 @@ def main(): "device": "cuda" if torch.cuda.is_available() else "cpu", "overwrite_existing_exp": True, "num_training_steps": 10_000, - "num_workers": 8, + "num_workers": 0, "opt": { "lr_decay": 20000, }, diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index e60d742e..afeceafe 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -16,8 +16,9 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset +from gflownet.data.data_source import DataSource from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.data.sampling_iterator import SamplingIterator +from gflownet.data.sampling_iterator import SamplingIterator, SQLiteLogHook from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import create_logger @@ -35,9 +36,11 @@ class GFNAlgorithm: updates: int = 0 + global_cfg: Config + is_eval: bool = False def step(self): - self.updates += 1 + self.updates += 1 # This isn't used anywhere? def compute_batch_losses( self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 @@ -62,6 +65,13 @@ def compute_batch_losses( """ raise NotImplementedError() + def get_random_action_prob(self, it: int): + if self.is_eval: + return self.global_cfg.algo.valid_random_action_prob + if it < self.global_cfg.algo.train_det_after or self.global_cfg.algo.train_det_after is None: + return self.global_cfg.algo.train_random_action_prob + return 0 + class GFNTask: def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: @@ -188,14 +198,14 @@ def _wrap_for_mp(self, obj, send_to_device=False): if send_to_device: obj.to(self.device) if self.cfg.num_workers > 0 and obj is not None: - wapper = mp_object_wrapper( + wrapper = mp_object_wrapper( obj, self.cfg.num_workers, cast_types=(gd.Batch, GraphActionCategorical, SeqBatch), pickle_messages=self.cfg.pickle_mp_messages, ) - self.to_terminate.append(wapper.terminate) - return wapper.placeholder, torch.device("cpu") + self.to_terminate.append(wrapper.terminate) + return wrapper.placeholder, torch.device("cpu") else: return obj, self.device @@ -203,28 +213,36 @@ def build_callbacks(self): return {} def build_training_data_loader(self) -> DataLoader: + # Since the model may be used by a worker in a different process, we need to wrap it. + # The device `dev` returned here is the device that the worker will use to interact with the model; + # normally, if the main process has the model on 'cuda', this will simply be 'cpu' (since workers + # don't have CUDA access). + # See implementation_nodes.md for more details. model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) replay_buffer, _ = self._wrap_for_mp(self.replay_buffer, send_to_device=False) - iterator = SamplingIterator( - self.training_data, - model, - self.ctx, - self.algo, - self.task, - dev, - batch_size=self.cfg.algo.global_batch_size, - illegal_action_logreward=self.cfg.algo.illegal_action_logreward, - replay_buffer=replay_buffer, - ratio=self.cfg.algo.offline_ratio, - log_dir=str(pathlib.Path(self.cfg.log_dir) / "train"), - random_action_prob=self.cfg.algo.train_random_action_prob, - det_after=self.cfg.algo.train_det_after, - hindsight_ratio=self.cfg.replay.hindsight_ratio, - ) + + n_drawn = int(self.cfg.algo.global_batch_size * (1 - self.cfg.algo.offline_ratio)) + n_replayed = n_drawn if self.cfg.replay.batch_size is None else self.cfg.replay.batch_size + n_from_dataset = self.cfg.algo.global_batch_size - n_drawn + + src = DataSource(self.cfg, self.ctx, self.algo, self.task, dev, replay_buffer=replay_buffer) + if n_from_dataset: + src.do_dataset_in_order(self.training_data, n_from_dataset, backwards_model=model) + if n_drawn: + # If we are using a replay buffer, we can choose to keep the new samples in the minibatch, or just + # send them to the replay and train only on replay samples. + keep_samples_in_batch = not self.cfg.replay.use or not self.cfg.replay.replaces_online_data + src.do_sample_model(model, n_drawn, keep_samples_in_batch) + if n_replayed and replay_buffer is not None: + src.do_sample_replay(n_replayed) + if self.cfg.log_dir: + src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "train"), self.ctx)) for hook in self.sampling_hooks: - iterator.add_log_hook(hook) + src.add_sampling_hook(hook) + # TODO: We could just have a build_training_data_source method that returns a DataSource + # All the other build_* methods do the same DataLoader setup return torch.utils.data.DataLoader( - iterator, + src, batch_size=None, num_workers=self.cfg.num_workers, persistent_workers=self.cfg.num_workers > 0, @@ -296,7 +314,7 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: if not torch.isfinite(loss): raise ValueError("loss is not finite") step_info = self.step(loss) - self.algo.step() + self.algo.step() # This also isn't used anywhere? if self._validate_parameters and not all([torch.isfinite(i).all() for i in self.model.parameters()]): raise ValueError("parameters are not finite") except ValueError as e: diff --git a/src/gflownet/utils/misc.py b/src/gflownet/utils/misc.py index d8b350b2..f65a83d5 100644 --- a/src/gflownet/utils/misc.py +++ b/src/gflownet/utils/misc.py @@ -1,6 +1,9 @@ import logging import sys +import numpy as np +import torch + def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHandle=True): logger = logging.getLogger(name) @@ -21,3 +24,21 @@ def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHand logger.addHandler(handler) return logger + + +_worker_rngs = {} +_worker_rng_seed = [142857] + + +def get_worker_rng(): + worker_info = torch.utils.data.get_worker_info() + wid = worker_info.id if worker_info is not None else 0 + if wid not in _worker_rngs: + _worker_rngs[wid] = np.random.RandomState(_worker_rng_seed[0] + wid) + return _worker_rngs[wid] + + +def set_worker_rng_seed(seed): + _worker_rng_seed[0] = seed + for wid in _worker_rngs: + _worker_rngs[wid].seed(seed + wid) From dfba1ca478ee8a7b29cdb54753f25bc8ca2c7000 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 29 Feb 2024 09:59:00 -0700 Subject: [PATCH 02/31] changed all iterators to DataSource --- src/gflownet/algo/soft_q_learning.py | 3 +- src/gflownet/algo/trajectory_balance.py | 9 +- src/gflownet/config.py | 1 + src/gflownet/data/data_source.py | 92 ++++++++++++++------ src/gflownet/data/sampling_iterator.py | 1 + src/gflownet/envs/frag_mol_env.py | 1 - src/gflownet/envs/mol_building_env.py | 1 - src/gflownet/online_trainer.py | 1 + src/gflownet/tasks/seh_frag.py | 11 ++- src/gflownet/trainer.py | 109 ++++++++++-------------- src/gflownet/utils/misc.py | 12 +++ 11 files changed, 142 insertions(+), 99 deletions(-) diff --git a/src/gflownet/algo/soft_q_learning.py b/src/gflownet/algo/soft_q_learning.py index 1e3f1146..378d0b7e 100644 --- a/src/gflownet/algo/soft_q_learning.py +++ b/src/gflownet/algo/soft_q_learning.py @@ -8,6 +8,7 @@ from gflownet.algo.graph_sampling import GraphSampler from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory +from gflownet.utils.misc import get_worker_device class SoftQLearning: @@ -75,7 +76,7 @@ def create_training_data_from_own_samples( - bck_logprob: sum logprobs P_B - is_valid: is the generated graph valid according to the env & ctx """ - dev = self.ctx.device + dev = get_worker_device() cond_info = cond_info.to(dev) data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) return data diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 308983ab..fcd171b4 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -22,6 +22,7 @@ generate_forward_trajectory, ) from gflownet.trainer import GFNAlgorithm +from gflownet.utils.misc import get_worker_device def shift_right(x: torch.Tensor, z=0): @@ -139,7 +140,7 @@ def __init__( ) if self.cfg.variant == TBVariant.SubTB1: self._subtb_max_len = self.global_cfg.algo.max_len + 2 - self._init_subtb(torch.device("cuda")) # TODO: where are we getting device info? + self._init_subtb(get_worker_device()) def set_is_eval(self, is_eval: bool): self.is_eval = is_eval @@ -171,7 +172,7 @@ def create_training_data_from_own_samples( - loss: predicted loss (if bootstrapping) - is_valid: is the generated graph valid according to the env & ctx """ - dev = self.ctx.device + dev = get_worker_device() cond_info = cond_info.to(dev) data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) logZ_pred = model.logZ(cond_info) @@ -206,7 +207,7 @@ def create_training_data_from_graphs( """ if self.cfg.do_sample_p_b: assert model is not None and cond_info is not None and random_action_prob is not None - dev = self.ctx.device + dev = get_worker_device() cond_info = cond_info.to(dev) return self.graph_sampler.sample_backward_from_graphs( graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, dev, random_action_prob @@ -217,7 +218,7 @@ def create_training_data_from_graphs( self.env.count_backward_transitions(gp, check_idempotent=self.cfg.do_correct_idempotent) for gp, _ in traj["traj"][1:] ] + [1] - traj["bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(self.ctx.device) + traj["bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(get_worker_device()) traj["result"] = traj["traj"][-1][0] if self.cfg.do_parameterize_p_b: traj["bck_a"] = [GraphAction(GraphActionType.Stop)] + [self.env.reverse(g, a) for g, a in traj["traj"]] diff --git a/src/gflownet/config.py b/src/gflownet/config.py index 782b4ff4..73ed6f15 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -94,6 +94,7 @@ class Config: print_every: int = 100 start_at_step: int = 0 num_final_gen_steps: Optional[int] = None + num_validation_gen_steps: Optional[int] = None num_training_steps: int = 10_000 num_workers: int = 0 hostname: Optional[str] = None diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index e6354e40..5cc79b61 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -1,13 +1,16 @@ +import warnings +from typing import Callable, Generator, List, Optional + import numpy as np import torch -from gflownet.data.replay_buffer import ReplayBuffer -from typing import Any, Callable, Dict, List, NewType, Optional, Protocol, Tuple, Generator -from torch.utils.data import Dataset, IterableDataset +from torch.utils.data import IterableDataset from gflownet.config import Config -from gflownet.utils.misc import get_worker_rng +from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.graph_building_env import GraphBuildingEnvContext -#from gflownet.trainer import GFNAlgorithm, GFNTask +from gflownet.utils.misc import get_worker_rng + +# from gflownet.trainer import GFNAlgorithm, GFNTask def cycle_call(it): @@ -21,9 +24,8 @@ def __init__( self, cfg: Config, ctx: GraphBuildingEnvContext, - algo, #: GFNAlgorithm, - task, #: GFNTask, # TODO: this will cause a circular import - dev: torch.device, + algo, #: GFNAlgorithm, + task, #: GFNTask, # TODO: this will cause a circular import replay_buffer: Optional[ReplayBuffer] = None, is_algo_eval: bool = False, start_at_step: int = 0, @@ -34,16 +36,15 @@ def __init__( self.ctx = ctx self.algo = algo self.task = task - self.dev = dev self.replay_buffer = replay_buffer self.is_algo_eval = is_algo_eval + self.sampling_hooks: List[Callable] = [] + self.active = True self.global_step_count = torch.zeros(1, dtype=torch.int64) + start_at_step self.global_step_count.share_memory_() self.global_step_count_lock = torch.multiprocessing.Lock() self.current_iter = start_at_step - self.sampling_hooks: List[Callable] = [] - self.active = True def add_sampling_hook(self, hook: Callable): """Add a hook that is called when sampling new trajectories. @@ -67,8 +68,10 @@ def __iter__(self): iterator_outputs = [next(i, None) for i in its] if any(i is None for i in iterator_outputs): if not all(i is None for i in iterator_outputs): - raise ValueError("Some iterators are done, but not all. You may be mixing incompatible iterators.") - break + warnings.warn("Some iterators are done, but not all. You may be mixing incompatible iterators.") + iterator_outputs = [i for i in iterator_outputs if i is not None] + else: + break traj_lists, batch_infos = zip(*iterator_outputs) trajs = sum(traj_lists, []) # Merge all the dicts into one @@ -85,8 +88,9 @@ def iterator(): while self.active: t = self.current_iter p = self.algo.get_random_action_prob(t) - cond_info = self.task.sample_cond_info(num_samples, t) - trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info, p) + cond_info = self.task.sample_conditional_information(num_samples, t) + # TODO: in the cond info refactor, pass the whole thing instead of just the encoding + trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info["encoding"], p) self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) @@ -97,13 +101,43 @@ def iterator(): self.iterators.append(iterator) return self + def do_sample_model_n_times(self, model, num_samples_per_batch, num_total): + total = torch.zeros(1, dtype=torch.int64) + total.share_memory_() + total_lock = torch.multiprocessing.Lock() + total_barrier = torch.multiprocessing.Barrier(max(1, self.cfg.num_workers)) + + def iterator(): + while self.active: + with total_lock: + n_so_far = total.item() + n_this_time = min(num_total - n_so_far, num_samples_per_batch) + total[:] += n_this_time + if n_this_time == 0: + break + t = self.current_iter + p = self.algo.get_random_action_prob(t) + cond_info = self.task.sample_conditional_information(n_this_time, t) + # TODO: in the cond info refactor, pass the whole thing instead of just the encoding + trajs = self.algo.create_training_data_from_own_samples(model, n_this_time, cond_info["encoding"], p) + self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs + self.compute_properties(trajs, mark_as_online=True) + self.compute_log_rewards(trajs) + batch_info = self.call_sampling_hooks(trajs) + yield trajs, batch_info + total_barrier.wait() # Wait for all workers to finish before resetting the counter + total[:] = 0 + + self.iterators.append(iterator) + return self + def do_sample_replay(self, num_samples): def iterator(): while self.active: trajs = self.replay_buffer.sample(num_samples) self.relabel_in_hindsight(trajs) # This is a no-op if the hindsight ratio is 0 yield trajs, {} - show_type(iterator) + self.iterators.append(iterator) return self @@ -114,7 +148,7 @@ def iterator(): p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(num_samples, t) objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) - trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info, p) + trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info["encoding"], p) self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.set_traj_props(trajs, props) self.compute_log_rewards(trajs) @@ -129,7 +163,7 @@ def iterator(): t = self.current_iter p = self.algo.get_random_action_prob(t) cond_info = torch.stack([data[i] for i in idcs]) - trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info, p) + trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info["encoding"], p) self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) self.send_to_replay(trajs) # This is a no-op if there is no replay buffer @@ -147,7 +181,7 @@ def iterator(): p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(num_samples, t) objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) - trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info, p) + trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info["encoding"], p) self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.set_traj_props(trajs, props) self.compute_log_rewards(trajs) @@ -161,10 +195,11 @@ def call_sampling_hooks(self, trajs): # TODO: just pass trajs to the hooks and deprecate passing all those arguments flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) # convert cond_info back to a dict - cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs["cond_info"][0]} + cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} log_rewards = torch.stack([t["log_reward"] for t in trajs]) for hook in self.sampling_hooks: batch_info.update(hook(trajs, log_rewards, flat_rewards, cond_info)) + return batch_info def create_batch(self, trajs, batch_info): ci = torch.stack([t["cond_info"]["encoding"] for t in trajs]) @@ -173,8 +208,10 @@ def create_batch(self, trajs, batch_info): batch.num_online = sum(t["is_online"] for t in trajs) batch.num_offline = len(trajs) - batch.num_online batch.extra_info = batch_info - batch.preferences = torch.stack([t["preference"] for t in trajs]) - batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs]) + if "preferences" in trajs[0]: + batch.preferences = torch.stack([t["preferences"] for t in trajs]) + if "focus_dir" in trajs[0]: + batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs]) if self.ctx.has_n(): # Does this go somewhere else? Require a flag? Might not be cheap to compute log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] @@ -230,12 +267,13 @@ def relabel_in_hindsight(self, trajs): if self.cfg.replay.hindsight_ratio == 0: return assert hasattr( - self.task, "relabel_condinfo_and_logrewards" - ), "Hindsight requires the task to implement relabel_condinfo_and_logrewards" + self.task, "relabel_condinfo_and_logrewards" + ), "Hindsight requires the task to implement relabel_condinfo_and_logrewards" # samples indexes of trajectories without repeats hindsight_idxs = torch.randperm(len(trajs))[: int(len(trajs) * self.cfg.replay.hindsight_ratio)] log_rewards = torch.stack([t["log_reward"] for t in trajs]) flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) + cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( cond_info, log_rewards, flat_rewards, hindsight_idxs ) @@ -251,18 +289,18 @@ def iterate_indices(self, n, num_samples): # Should we be raising an error here? warning? yield np.arange(0, 0) return - + if worker_info is None: # no multi-processing start, end, wid = 0, n, -1 else: # split the data into chunks (per-worker) nw = worker_info.num_workers wid = worker_info.id start, end = int(np.round(n / nw * wid)), int(np.round(n / nw * (wid + 1))) - + if end - start <= num_samples: yield np.arange(start, end) return for i in range(start, end - num_samples, num_samples): yield np.arange(i, i + num_samples) if i + num_samples < end: - yield np.arange(i + num_samples, end) \ No newline at end of file + yield np.arange(i + num_samples, end) diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index cf15f66c..04ff0ebe 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -445,6 +445,7 @@ def __call__(self, trajs, rewards, flat_rewards, cond_info): ) self.log.insert_many(data, self.data_labels) + return {} class SQLiteLog: diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index bab9506b..daf9f99f 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -87,7 +87,6 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu GraphActionType.RemoveNode, GraphActionType.RemoveEdgeAttr, ] - self.device = torch.device("cpu") self.n_counter = NCounter() self.sorted_frags = sorted(list(enumerate(self.frags_mol)), key=lambda x: -x[1].GetNumAtoms()) diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index 20c05586..5e43dd0b 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -157,7 +157,6 @@ def __init__( GraphActionType.RemoveEdge, GraphActionType.RemoveEdgeAttr, ] - self.device = torch.device("cpu") def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True): """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction""" diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index edda9c79..2e59304f 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -6,6 +6,7 @@ import torch from omegaconf import OmegaConf from torch import Tensor +from torch.utils.data import DataLoader from gflownet.algo.advantage_actor_critic import A2C from gflownet.algo.flow_matching import FlowMatching diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 2d56f213..54adcc7c 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -158,6 +158,7 @@ def set_default_hps(self, cfg: Config): cfg.algo.train_random_action_prob = 0.0 cfg.algo.valid_random_action_prob = 0.0 cfg.algo.valid_offline_ratio = 0 + cfg.num_validation_gen_steps = 10 cfg.algo.tb.epsilon = None cfg.algo.tb.bootstrap_own_reward = False cfg.algo.tb.Z_learning_rate = 1e-3 @@ -199,13 +200,17 @@ def setup(self): def main(): """Example of how this model can be run.""" + import datetime + config = init_empty(Config()) config.print_every = 1 - config.log_dir = "./logs/debug_run_seh_frag_pb" + config.log_dir = f"./logs/debug_run_seh_frag_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" config.device = "cuda" if torch.cuda.is_available() else "cpu" config.overwrite_existing_exp = True - config.num_training_steps = 10_000 - config.num_workers = 0 + config.num_training_steps = 1_00 + config.validate_every = 20 + config.num_final_gen_steps = 10 + config.num_workers = 8 config.opt.lr_decay = 20_000 config.algo.sampling_tau = 0.99 config.algo.offline_ratio = 0.0 diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index d903821b..17f38aef 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -19,10 +19,10 @@ from gflownet.data.data_source import DataSource from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.data.sampling_iterator import SamplingIterator, SQLiteLogHook +from gflownet.data.sampling_iterator import SQLiteLogHook from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch -from gflownet.utils.misc import create_logger +from gflownet.utils.misc import create_logger, set_main_process_device from gflownet.utils.multiprocessing_proxy import mp_object_wrapper from .config import Config @@ -69,7 +69,7 @@ def compute_batch_losses( def get_random_action_prob(self, it: int): if self.is_eval: return self.global_cfg.algo.valid_random_action_prob - if it < self.global_cfg.algo.train_det_after or self.global_cfg.algo.train_det_after is None: + if self.global_cfg.algo.train_det_after is None or it < self.global_cfg.algo.train_det_after: return self.global_cfg.algo.train_random_action_prob return 0 @@ -150,9 +150,10 @@ def __init__(self, config: Config, print_config=True): assert isinstance(self.default_cfg, Config) and isinstance( config, Config ) # make sure the config is a Config object, and not the Config class itself - self.cfg = OmegaConf.merge(self.default_cfg, config) + self.cfg: Config = OmegaConf.merge(self.default_cfg, config) self.device = torch.device(self.cfg.device) + set_main_process_device(self.device) # Print the loss every `self.print_every` iterations self.print_every = self.cfg.print_every # These hooks allow us to compute extra quantities when sampling data @@ -223,6 +224,15 @@ def _wrap_for_mp(self, obj, send_to_device=False): def build_callbacks(self): return {} + def _make_data_loader(self, src): + return torch.utils.data.DataLoader( + src, + batch_size=None, + num_workers=self.cfg.num_workers, + persistent_workers=self.cfg.num_workers > 0, + prefetch_factor=1 if self.cfg.num_workers else None, + ) + def build_training_data_loader(self) -> DataLoader: # Since the model may be used by a worker in a different process, we need to wrap it. # The device `dev` returned here is the device that the worker will use to interact with the model; @@ -236,7 +246,7 @@ def build_training_data_loader(self) -> DataLoader: n_replayed = n_drawn if self.cfg.replay.batch_size is None else self.cfg.replay.batch_size n_from_dataset = self.cfg.algo.global_batch_size - n_drawn - src = DataSource(self.cfg, self.ctx, self.algo, self.task, dev, replay_buffer=replay_buffer) + src = DataSource(self.cfg, self.ctx, self.algo, self.task, replay_buffer=replay_buffer) if n_from_dataset: src.do_dataset_in_order(self.training_data, n_from_dataset, backwards_model=model) if n_drawn: @@ -250,70 +260,45 @@ def build_training_data_loader(self) -> DataLoader: src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "train"), self.ctx)) for hook in self.sampling_hooks: src.add_sampling_hook(hook) - # TODO: We could just have a build_training_data_source method that returns a DataSource - # All the other build_* methods do the same DataLoader setup - return torch.utils.data.DataLoader( - src, - batch_size=None, - num_workers=self.cfg.num_workers, - persistent_workers=self.cfg.num_workers > 0, - prefetch_factor=1 if self.cfg.num_workers else None, - ) + return self._make_data_loader(src) def build_validation_data_loader(self) -> DataLoader: model, dev = self._wrap_for_mp(self.model, send_to_device=True) - iterator = SamplingIterator( - self.test_data, - model, - self.ctx, - self.algo, - self.task, - dev, - batch_size=self.cfg.algo.global_batch_size, - illegal_action_logreward=self.cfg.algo.illegal_action_logreward, - ratio=self.cfg.algo.valid_offline_ratio, - log_dir=str(pathlib.Path(self.cfg.log_dir) / "valid"), - sample_cond_info=self.cfg.cond.valid_sample_cond_info, - stream=False, - random_action_prob=self.cfg.algo.valid_random_action_prob, - ) + # TODO: we're changing the default, make sure anything that is using test data is adjusted + src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) + n_drawn = int(self.cfg.algo.global_batch_size * (1 - self.cfg.algo.valid_offline_ratio)) + n_from_dataset = self.cfg.algo.global_batch_size - n_drawn + + src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) + if n_from_dataset: + src.do_dataset_in_order(self.test_data, n_from_dataset, backwards_model=model) + if n_drawn: + assert self.cfg.num_validation_gen_steps is not None + # TODO: might be better to change total steps to total trajectories drawn + src.do_sample_model_n_times(model, n_drawn, num_total=self.cfg.num_validation_gen_steps * n_drawn) + + if self.cfg.log_dir: + src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "valid"), self.ctx)) for hook in self.valid_sampling_hooks: - iterator.add_log_hook(hook) - return torch.utils.data.DataLoader( - iterator, - batch_size=None, - num_workers=self.cfg.num_workers, - persistent_workers=self.cfg.num_workers > 0, - prefetch_factor=1 if self.cfg.num_workers else None, - ) + src.add_sampling_hook(hook) + return self._make_data_loader(src) def build_final_data_loader(self) -> DataLoader: - model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) - iterator = SamplingIterator( - self.training_data, - model, - self.ctx, - self.algo, - self.task, - dev, - batch_size=self.cfg.algo.global_batch_size, - illegal_action_logreward=self.cfg.algo.illegal_action_logreward, - replay_buffer=None, - ratio=0.0, - log_dir=os.path.join(self.cfg.log_dir, "final"), - random_action_prob=0.0, - hindsight_ratio=0.0, - init_train_iter=self.cfg.num_training_steps, - ) + model, dev = self._wrap_for_mp(self.model, send_to_device=True) + # TODO: we're changing the default, make sure anything that is using test data is adjusted + src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) + n_drawn = int(self.cfg.algo.global_batch_size) + + src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) + assert self.cfg.num_final_gen_steps is not None + # TODO: might be better to change total steps to total trajectories drawn + src.do_sample_model_n_times(model, n_drawn, num_total=self.cfg.num_final_gen_steps * n_drawn) + + if self.cfg.log_dir: + src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "final"), self.ctx)) for hook in self.sampling_hooks: - iterator.add_log_hook(hook) - return torch.utils.data.DataLoader( - iterator, - batch_size=None, - num_workers=self.cfg.num_workers, - persistent_workers=self.cfg.num_workers > 0, - prefetch_factor=1 if self.cfg.num_workers else None, - ) + src.add_sampling_hook(hook) + return self._make_data_loader(src) def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: tick = time.time() diff --git a/src/gflownet/utils/misc.py b/src/gflownet/utils/misc.py index f65a83d5..7ec5bdba 100644 --- a/src/gflownet/utils/misc.py +++ b/src/gflownet/utils/misc.py @@ -8,6 +8,8 @@ def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHandle=True): logger = logging.getLogger(name) logger.setLevel(loglevel) + while len([logger.removeHandler(i) for i in logger.handlers]): + pass # Remove all handlers (only useful when debugging) formatter = logging.Formatter( fmt="%(asctime)s - %(levelname)s - {} - %(message)s".format(name), datefmt="%d/%m/%Y %H:%M:%S", @@ -28,6 +30,7 @@ def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHand _worker_rngs = {} _worker_rng_seed = [142857] +_main_process_device = [torch.device("cpu")] def get_worker_rng(): @@ -42,3 +45,12 @@ def set_worker_rng_seed(seed): _worker_rng_seed[0] = seed for wid in _worker_rngs: _worker_rngs[wid].seed(seed + wid) + + +def set_main_process_device(device): + _main_process_device[0] = device + + +def get_worker_device(): + worker_info = torch.utils.data.get_worker_info() + return _main_process_device[0] if worker_info is None else torch.device("cpu") From e5239fb48726bd27e9fc4cfca435e5e98510a6d2 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 29 Feb 2024 15:10:00 -0700 Subject: [PATCH 03/31] lots of little fixes, tested all tasks, better device management --- src/gflownet/algo/advantage_actor_critic.py | 4 +-- src/gflownet/algo/envelope_q_learning.py | 3 +- src/gflownet/data/data_source.py | 23 +++++++++++---- src/gflownet/online_trainer.py | 7 +++++ src/gflownet/tasks/qm9_moo.py | 31 ++++++++++++++++----- src/gflownet/tasks/seh_frag_moo.py | 23 +++++++++++++-- src/gflownet/trainer.py | 2 +- 7 files changed, 73 insertions(+), 20 deletions(-) diff --git a/src/gflownet/algo/advantage_actor_critic.py b/src/gflownet/algo/advantage_actor_critic.py index 001e19d0..c6547b3d 100644 --- a/src/gflownet/algo/advantage_actor_critic.py +++ b/src/gflownet/algo/advantage_actor_critic.py @@ -6,7 +6,7 @@ from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory - +from gflownet.utils.misc import get_worker_device from .graph_sampling import GraphSampler @@ -79,7 +79,7 @@ def create_training_data_from_own_samples( - bck_logprob: sum logprobs P_B - is_valid: is the generated graph valid according to the env & ctx """ - dev = self.ctx.device + dev = get_worker_device() cond_info = cond_info.to(dev) data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) return data diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py index 4d694ae2..7adfd68c 100644 --- a/src/gflownet/algo/envelope_q_learning.py +++ b/src/gflownet/algo/envelope_q_learning.py @@ -15,6 +15,7 @@ ) from gflownet.models.graph_transformer import GraphTransformer, mlp from gflownet.trainer import GFNTask +from gflownet.utils.misc import get_worker_device from .graph_sampling import GraphSampler @@ -233,7 +234,7 @@ def create_training_data_from_own_samples( - bck_logprob: sum logprobs P_B - is_valid: is the generated graph valid according to the env & ctx """ - dev = self.ctx.device + dev = get_worker_device() cond_info = cond_info.to(dev) data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) return data diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 5cc79b61..4df5dcd0 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -134,7 +134,7 @@ def iterator(): def do_sample_replay(self, num_samples): def iterator(): while self.active: - trajs = self.replay_buffer.sample(num_samples) + trajs, *_ = self.replay_buffer.sample(num_samples) self.relabel_in_hindsight(trajs) # This is a no-op if the hindsight ratio is 0 yield trajs, {} @@ -143,7 +143,7 @@ def iterator(): def do_dataset_in_order(self, data, num_samples, backwards_model): def iterator(): - for idcs in self.iterate_indices(num_samples): + for idcs in self.iterate_indices(len(data), num_samples): t = self.current_iter p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(num_samples, t) @@ -162,11 +162,18 @@ def iterator(): for idcs in self.iterate_indices(len(data), num_samples): t = self.current_iter p = self.algo.get_random_action_prob(t) - cond_info = torch.stack([data[i] for i in idcs]) + # TODO: when we refactor cond_info, data[i] will probably be a dict? (or CondInfo objects) + # I'm also not a fan of encode_conditional_information, it assumes lots of things about what's passed to + # it and the state of the program (e.g. validation mode) + cond_info = self.task.encode_conditional_information(torch.stack([data[i] for i in idcs])) trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info["encoding"], p) + self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) self.send_to_replay(trajs) # This is a no-op if there is no replay buffer + # If we're using a dataset of preferences, the user/hooks may want to know the id of the preference + for i, j in zip(trajs, idcs): + i["data_idx"] = j batch_info = self.call_sampling_hooks(trajs) yield trajs, batch_info @@ -197,15 +204,16 @@ def call_sampling_hooks(self, trajs): # convert cond_info back to a dict cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} log_rewards = torch.stack([t["log_reward"] for t in trajs]) + rewards = torch.exp(log_rewards / (cond_info.get("beta", 1))) for hook in self.sampling_hooks: - batch_info.update(hook(trajs, log_rewards, flat_rewards, cond_info)) + batch_info.update(hook(trajs, rewards, flat_rewards, cond_info)) return batch_info def create_batch(self, trajs, batch_info): ci = torch.stack([t["cond_info"]["encoding"] for t in trajs]) log_rewards = torch.stack([t["log_reward"] for t in trajs]) batch = self.algo.construct_batch(trajs, ci, log_rewards) - batch.num_online = sum(t["is_online"] for t in trajs) + batch.num_online = sum(t.get("is_online", 0) for t in trajs) batch.num_offline = len(trajs) - batch.num_online batch.extra_info = batch_info if "preferences" in trajs[0]: @@ -247,8 +255,11 @@ def compute_log_rewards(self, trajs): flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) + min_r = torch.as_tensor(self.cfg.algo.illegal_action_logreward).float() for i in range(len(trajs)): - trajs[i]["log_reward"] = log_rewards[i] if trajs[i]["is_valid"] else self.cfg.algo.illegal_action_logreward + trajs[i]["log_reward"] = ( + log_rewards[i] if trajs[i].get("is_valid", True) else min_r + ) def send_to_replay(self, trajs): if self.replay_buffer is not None: diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 2e59304f..a73802d2 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -73,6 +73,8 @@ def setup(self): super().setup() self.offline_ratio = 0 self.replay_buffer = ReplayBuffer(self.cfg, self.rng) if self.cfg.replay.use else None + self.sampling_hooks.append(AvgRewardHook()) + self.valid_sampling_hooks.append(AvgRewardHook()) # Separate Z parameters from non-Z to allow for LR decay on the former if hasattr(self.model, "logZ"): @@ -130,3 +132,8 @@ def step(self, loss: Tensor): for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) return {"grad_norm": g0, "grad_norm_clip": g1} + + +class AvgRewardHook: + def __call__(self, trajs, rewards, flat_rewards, extra_info): + return {"sampled_reward_avg": rewards.mean().item()} diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index b1dab870..d69e67f9 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -7,13 +7,15 @@ import torch_geometric.data as gd from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from torch.utils.data import Dataset +from torch.utils.data import Dataset, DataLoader import gflownet.models.mxmnet as mxmnet from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config from gflownet.data.qm9 import QM9Dataset +from gflownet.data.data_source import DataSource +from gflownet.data.sampling_iterator import SQLiteLogHook from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.tasks.qm9 import QM9GapTask, QM9GapTrainer from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, aux_tasks @@ -45,6 +47,7 @@ def __init__( self.cfg = cfg mcfg = self.cfg.task.qm9_moo self.objectives = cfg.task.qm9_moo.objectives + cfg.cond.moo.num_objectives = len(self.objectives) self.dataset = dataset if self.cfg.cond.focus_region.focus_type is not None: self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid, rng) @@ -224,14 +227,14 @@ def setup_model(self): num_emb=self.cfg.model.num_emb, num_layers=self.cfg.model.num_layers, num_heads=self.cfg.model.graph_transformer.num_heads, - num_objectives=len(self.cfg.task.seh_moo.objectives), + num_objectives=len(self.cfg.task.qm9_moo.objectives), ) else: super().setup_model() def setup(self): super().setup() - if self.cfg.task.seh_moo.online_pareto_front: + if self.cfg.task.qm9_moo.online_pareto_front: self.sampling_hooks.append( MultiObjectiveStatsHook( 256, @@ -245,7 +248,7 @@ def setup(self): self.to_terminate.append(self.sampling_hooks[-1].terminate) # instantiate preference and focus conditioning vectors for validation - n_obj = len(self.cfg.task.seh_moo.objectives) + n_obj = len(self.cfg.task.qm9_moo.objectives) cond_cfg = self.cfg.cond # making sure hyperparameters for preferences and focus regions are consistent @@ -263,7 +266,7 @@ def setup(self): if isinstance(cond_cfg.focus_region.focus_type, list) and len(cond_cfg.focus_region.focus_type) > 1: n_valid = len(cond_cfg.focus_region.focus_type) else: - n_valid = self.cfg.task.seh_moo.n_valid + n_valid = self.cfg.task.qm9_moo.n_valid # preference vectors if cond_cfg.weighted_prefs.preference_type is None: @@ -298,8 +301,8 @@ def setup(self): else: valid_cond_vector = valid_preferences - self._top_k_hook = TopKHook(10, self.cfg.task.seh_moo.n_valid_repeats, n_valid) - self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=self.cfg.task.seh_moo.n_valid_repeats) + self._top_k_hook = TopKHook(10, self.cfg.task.qm9_moo.n_valid_repeats, n_valid) + self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=self.cfg.task.qm9_moo.n_valid_repeats) self.valid_sampling_hooks.append(self._top_k_hook) self.algo.task = self.task @@ -324,6 +327,20 @@ def on_validation_end(self, metrics: Dict[str, Any]): return {"topk": TopKMetricCB()} + def build_validation_data_loader(self) -> DataLoader: + model, dev = self._wrap_for_mp(self.model, send_to_device=True) + + n_from_dataset = self.cfg.algo.global_batch_size + src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) + src.do_conditionals_dataset_in_order(self.test_data, n_from_dataset, model) + + if self.cfg.log_dir: + src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "valid"), self.ctx)) + for hook in self.valid_sampling_hooks: + src.add_sampling_hook(hook) + + return self._make_data_loader(src) + def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: if self.task.focus_cond is not None: self.task.focus_cond.step_focus_model(batch, train_it) diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 1f8787d3..9ee69408 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -8,12 +8,14 @@ from rdkit.Chem import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from torch.utils.data import Dataset +from torch.utils.data import Dataset, DataLoader from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config, init_empty from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext +from gflownet.data.data_source import DataSource +from gflownet.data.sampling_iterator import SQLiteLogHook from gflownet.models import bengio2021flow from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask from gflownet.trainer import FlatRewards, RewardScalar @@ -70,6 +72,7 @@ def __init__( self.cfg = cfg mcfg = self.cfg.task.seh_moo self.objectives = cfg.task.seh_moo.objectives + cfg.cond.moo.num_objectives = len(self.objectives) # This value is used by the focus_cond and pref_cond self.dataset = dataset if self.cfg.cond.focus_region.focus_type is not None: self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid, rng) @@ -350,6 +353,20 @@ def on_validation_end(self, metrics: Dict[str, Any]): return callback_dict + def build_validation_data_loader(self) -> DataLoader: + model, dev = self._wrap_for_mp(self.model, send_to_device=True) + + n_from_dataset = self.cfg.algo.global_batch_size + src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) + src.do_conditionals_dataset_in_order(self.test_data, n_from_dataset, model) + + if self.cfg.log_dir: + src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "valid"), self.ctx)) + for hook in self.valid_sampling_hooks: + src.add_sampling_hook(hook) + + return self._make_data_loader(src) + def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: if self.task.focus_cond is not None: self.task.focus_cond.step_focus_model(batch, train_it) @@ -363,7 +380,7 @@ def _save_state(self, it): class RepeatedCondInfoDataset: def __init__(self, cond_info_vectors, repeat): - self.cond_info_vectors = cond_info_vectors + self.cond_info_vectors = torch.as_tensor(cond_info_vectors).float() self.repeat = repeat def __len__(self): @@ -371,7 +388,7 @@ def __len__(self): def __getitem__(self, idx): assert 0 <= idx < len(self) - return torch.tensor(self.cond_info_vectors[int(idx // self.repeat)]) + return self.cond_info_vectors[int(idx // self.repeat)] def main(): diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 17f38aef..a9409562 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -248,7 +248,7 @@ def build_training_data_loader(self) -> DataLoader: src = DataSource(self.cfg, self.ctx, self.algo, self.task, replay_buffer=replay_buffer) if n_from_dataset: - src.do_dataset_in_order(self.training_data, n_from_dataset, backwards_model=model) + src.do_sample_dataset(self.training_data, n_from_dataset, backwards_model=model) if n_drawn: # If we are using a replay buffer, we can choose to keep the new samples in the minibatch, or just # send them to the replay and train only on replay samples. From 43dfc2b0c010cca57edbf10f5d92d229ee2286ff Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 29 Feb 2024 17:51:51 -0700 Subject: [PATCH 04/31] style --- src/gflownet/algo/advantage_actor_critic.py | 1 + src/gflownet/data/data_source.py | 4 +--- src/gflownet/online_trainer.py | 1 - src/gflownet/tasks/qm9_moo.py | 4 ++-- src/gflownet/tasks/seh_frag_moo.py | 4 ++-- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/gflownet/algo/advantage_actor_critic.py b/src/gflownet/algo/advantage_actor_critic.py index c6547b3d..40c58010 100644 --- a/src/gflownet/algo/advantage_actor_critic.py +++ b/src/gflownet/algo/advantage_actor_critic.py @@ -7,6 +7,7 @@ from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory from gflownet.utils.misc import get_worker_device + from .graph_sampling import GraphSampler diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 4df5dcd0..b68bfc8c 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -257,9 +257,7 @@ def compute_log_rewards(self, trajs): log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) min_r = torch.as_tensor(self.cfg.algo.illegal_action_logreward).float() for i in range(len(trajs)): - trajs[i]["log_reward"] = ( - log_rewards[i] if trajs[i].get("is_valid", True) else min_r - ) + trajs[i]["log_reward"] = log_rewards[i] if trajs[i].get("is_valid", True) else min_r def send_to_replay(self, trajs): if self.replay_buffer is not None: diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index a73802d2..9d30d457 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -6,7 +6,6 @@ import torch from omegaconf import OmegaConf from torch import Tensor -from torch.utils.data import DataLoader from gflownet.algo.advantage_actor_critic import A2C from gflownet.algo.flow_matching import FlowMatching diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index d69e67f9..45c3576b 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -7,14 +7,14 @@ import torch_geometric.data as gd from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import DataLoader, Dataset import gflownet.models.mxmnet as mxmnet from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config -from gflownet.data.qm9 import QM9Dataset from gflownet.data.data_source import DataSource +from gflownet.data.qm9 import QM9Dataset from gflownet.data.sampling_iterator import SQLiteLogHook from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.tasks.qm9 import QM9GapTask, QM9GapTrainer diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 9ee69408..60f8092c 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -8,14 +8,14 @@ from rdkit.Chem import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import DataLoader, Dataset from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config, init_empty -from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.data.data_source import DataSource from gflownet.data.sampling_iterator import SQLiteLogHook +from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.models import bengio2021flow from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask from gflownet.trainer import FlatRewards, RewardScalar From 279ecfcb8577dd3866d883ff4d1adaf2830f2f51 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 08:25:22 -0700 Subject: [PATCH 05/31] change batch size hyperparameters + fix nested dataclasses --- src/gflownet/algo/config.py | 38 +++++++++++++++++------------ src/gflownet/algo/graph_sampling.py | 3 ++- src/gflownet/config.py | 14 +++++------ src/gflownet/data/config.py | 16 +++++++----- src/gflownet/data/data_source.py | 15 ++++++++---- src/gflownet/models/config.py | 6 ++--- src/gflownet/online_trainer.py | 1 - src/gflownet/tasks/config.py | 8 +++--- src/gflownet/tasks/make_rings.py | 3 +-- src/gflownet/tasks/qm9.py | 3 ++- src/gflownet/tasks/qm9_moo.py | 7 ++---- src/gflownet/tasks/seh_frag.py | 8 +++--- src/gflownet/tasks/seh_frag_moo.py | 11 ++++----- src/gflownet/tasks/toy_seq.py | 4 +-- src/gflownet/trainer.py | 26 ++++++++++---------- src/gflownet/utils/config.py | 8 +++--- 16 files changed, 89 insertions(+), 82 deletions(-) diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index 0ccf2e0e..e2576982 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from typing import Optional @@ -95,8 +95,18 @@ class AlgoConfig: ---------- method : str The name of the algorithm to use (e.g. "TB") - global_batch_size : int - The batch size for training + num_from_policy : int + The number of on-policy samples for a training batch. + If using a replay buffer, see `replay.num_from_replay` for the number of samples from the replay buffer, and + `replay.num_new_samples` for the number of new samples to add to the replay buffer (e.g. `num_from_policy=0`, + and `num_new_samples=N` inserts `N` new samples in the replay buffer at each step, but does not make that data + part of the training batch). + num_from_dataset : int + The number of samples from the dataset for a training batch + valid_num_from_policy : int + The number of on-policy samples for a validation batch + valid_num_from_dataset : int + The number of samples from the dataset for a validation batch max_len : int The maximum length of a trajectory max_nodes : int @@ -105,11 +115,6 @@ class AlgoConfig: The maximum number of edges in a generated graph illegal_action_logreward : float The log reward an agent gets for illegal actions - offline_ratio: float - The ratio of samples drawn from `self.training_data` during training. The rest is drawn from - `self.sampling_model` - valid_offline_ratio: float - Idem but for validation, and `self.test_data`. train_random_action_prob : float The probability of taking a random action during training train_det_after: Optional[int] @@ -121,19 +126,20 @@ class AlgoConfig: """ method: str = "TB" - global_batch_size: int = 64 + num_from_policy: int = 64 + num_from_dataset: int = 0 + valid_num_from_policy: int = 64 + valid_num_from_dataset: int = 0 max_len: int = 128 max_nodes: int = 128 max_edges: int = 128 illegal_action_logreward: float = -100 - offline_ratio: float = 0.5 - valid_offline_ratio: float = 1 train_random_action_prob: float = 0.0 train_det_after: Optional[int] = None valid_random_action_prob: float = 0.0 sampling_tau: float = 0.0 - tb: TBConfig = TBConfig() - moql: MOQLConfig = MOQLConfig() - a2c: A2CConfig = A2CConfig() - fm: FMConfig = FMConfig() - sql: SQLConfig = SQLConfig() + tb: TBConfig = field(default_factory=TBConfig) + moql: MOQLConfig = field(default_factory=MOQLConfig) + a2c: A2CConfig = field(default_factory=A2CConfig) + fm: FMConfig = field(default_factory=FMConfig) + sql: SQLConfig = field(default_factory=SQLConfig) diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 7ad4fc0a..0db5bcec 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -1,4 +1,5 @@ import copy +import warnings from typing import List, Optional import torch @@ -248,7 +249,7 @@ def not_done(lst): # TODO: This should be doable. if random_action_prob > 0: - raise NotImplementedError("Random action not implemented for backward sampling") + warnings.warn("Random action not implemented for backward sampling") while sum(done) < n: torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(n))] diff --git a/src/gflownet/config.py b/src/gflownet/config.py index 73ed6f15..e8238e97 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, fields, is_dataclass +from dataclasses import dataclass, field, fields, is_dataclass from typing import Optional from omegaconf import MISSING @@ -101,12 +101,12 @@ class Config: pickle_mp_messages: bool = False git_hash: Optional[str] = None overwrite_existing_exp: bool = False - algo: AlgoConfig = AlgoConfig() - model: ModelConfig = ModelConfig() - opt: OptimizerConfig = OptimizerConfig() - replay: ReplayConfig = ReplayConfig() - task: TasksConfig = TasksConfig() - cond: ConditionalsConfig = ConditionalsConfig() + algo: AlgoConfig = field(default_factory=AlgoConfig) + model: ModelConfig = field(default_factory=ModelConfig) + opt: OptimizerConfig = field(default_factory=OptimizerConfig) + replay: ReplayConfig = field(default_factory=ReplayConfig) + task: TasksConfig = field(default_factory=TasksConfig) + cond: ConditionalsConfig = field(default_factory=ConditionalsConfig) def init_empty(cfg): diff --git a/src/gflownet/data/config.py b/src/gflownet/data/config.py index 5c5a9c84..ce1bac7e 100644 --- a/src/gflownet/data/config.py +++ b/src/gflownet/data/config.py @@ -16,15 +16,19 @@ class ReplayConfig: The number of samples to collect before starting to sample from the replay buffer hindsight_ratio : float The ratio of hindsight samples within a batch - batch_size : Optional[int] - The batch size for sampling from the replay buffer, defaults to the online batch size - replaces_online_data : bool - Whether to replace online data with samples from the replay buffer + num_from_replay : Optional[int] + The number of replayed samples for a training batch (defaults to cfg.algo.num_from_policy, i.e. a 50/50 split) + num_new_samples : Optional[int] + The number of new samples added to the replay at every training step. Defaults to cfg.algo.num_from_policy. If + smaller than num_from_policy then not all on-policy samples will be added to the replay. If larger + than num_from_policy then the training batch will not contain all the new samples, but the buffer will. + For example, if one wishes to sample N samples every step but only add them to the buffer and not make them + part of the training batch, then one should set replay.num_new_samples=N and algo.num_from_policy=0. """ use: bool = False capacity: Optional[int] = None warmup: Optional[int] = None hindsight_ratio: float = 0 - batch_size: Optional[int] = None - replaces_online_data: bool = True + num_from_replay: Optional[int] = None + num_new_samples: Optional[int] = None diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index b68bfc8c..ca8c6d1b 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -80,9 +80,14 @@ def __iter__(self): batch_info.update(d) yield self.create_batch(trajs, batch_info) - def do_sample_model(self, model, num_samples, keep_samples_in_batch=True): - if not keep_samples_in_batch: - assert self.replay_buffer is not None, "Throwing away samples without a replay buffer" + def do_sample_model(self, model, num_from_policy, num_new_replay_samples=None): + if num_new_replay_samples is not None: + assert self.replay_buffer is not None, "num_new_replay_samples specified without a replay buffer" + if num_new_replay_samples is None: + assert self.replay_buffer is None, "num_new_replay_samples not specified with a replay buffer" + + num_new_replay_samples = num_new_replay_samples or 0 + num_samples = max(num_from_policy, num_new_replay_samples) def iterator(): while self.active: @@ -94,9 +99,9 @@ def iterator(): self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) - self.send_to_replay(trajs) # This is a no-op if there is no replay buffer + self.send_to_replay(trajs[:num_new_replay_samples]) # This is a no-op if there is no replay buffer batch_info = self.call_sampling_hooks(trajs) - yield (trajs, batch_info) if keep_samples_in_batch else ([], {}) + yield (trajs[:num_from_policy], batch_info) self.iterators.append(iterator) return self diff --git a/src/gflownet/models/config.py b/src/gflownet/models/config.py index 05b00b6e..acce656a 100644 --- a/src/gflownet/models/config.py +++ b/src/gflownet/models/config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum @@ -35,5 +35,5 @@ class ModelConfig: num_layers: int = 3 num_emb: int = 128 dropout: float = 0 - graph_transformer: GraphTransformerConfig = GraphTransformerConfig() - seq_transformer: SeqTransformerConfig = SeqTransformerConfig() + graph_transformer: GraphTransformerConfig = field(default_factory=GraphTransformerConfig) + seq_transformer: SeqTransformerConfig = field(default_factory=SeqTransformerConfig) diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 9d30d457..3ba1fde1 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -95,7 +95,6 @@ def setup(self): else: self.sampling_model = self.model - self.mb_size = self.cfg.algo.global_batch_size self.clip_grad_callback = { "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.cfg.opt.clip_grad_param), "norm": lambda params: [torch.nn.utils.clip_grad_norm_(p, self.cfg.opt.clip_grad_param) for p in params], diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index 4c29f634..3c8a0fab 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -62,7 +62,7 @@ class QM9MOOTaskConfig: @dataclass class TasksConfig: - qm9: QM9TaskConfig = QM9TaskConfig() - qm9_moo: QM9MOOTaskConfig = QM9MOOTaskConfig() - seh: SEHTaskConfig = SEHTaskConfig() - seh_moo: SEHMOOTaskConfig = SEHMOOTaskConfig() + qm9: QM9TaskConfig = field(default_factory=QM9TaskConfig) + qm9_moo: QM9MOOTaskConfig = field(default_factory=QM9MOOTaskConfig) + seh: SEHTaskConfig = field(default_factory=SEHTaskConfig) + seh_moo: SEHMOOTaskConfig = field(default_factory=SEHMOOTaskConfig) diff --git a/src/gflownet/tasks/make_rings.py b/src/gflownet/tasks/make_rings.py index 9211d038..00b2863b 100644 --- a/src/gflownet/tasks/make_rings.py +++ b/src/gflownet/tasks/make_rings.py @@ -41,8 +41,7 @@ class MakeRingsTrainer(StandardOnlineTrainer): def set_default_hps(self, cfg: Config): cfg.hostname = socket.gethostname() cfg.num_workers = 8 - cfg.algo.global_batch_size = 64 - cfg.algo.offline_ratio = 0 + cfg.algo.num_from_policy = 64 cfg.model.num_emb = 128 cfg.model.num_layers = 4 diff --git a/src/gflownet/tasks/qm9.py b/src/gflownet/tasks/qm9.py index 0e934906..4a3dd1ba 100644 --- a/src/gflownet/tasks/qm9.py +++ b/src/gflownet/tasks/qm9.py @@ -120,7 +120,8 @@ def set_default_hps(self, cfg: Config): cfg.opt.clip_grad_type = "norm" cfg.opt.clip_grad_param = 10 cfg.algo.max_nodes = 9 - cfg.algo.global_batch_size = 64 + cfg.algo.num_from_policy = 32 + cfg.algo.num_from_dataset = 32 cfg.algo.train_random_action_prob = 0.001 cfg.algo.illegal_action_logreward = -75 cfg.algo.sampling_tau = 0.0 diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index 45c3576b..ec62643c 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -198,10 +198,8 @@ class QM9MOOTrainer(QM9GapTrainer): def set_default_hps(self, cfg: Config): super().set_default_hps(cfg) cfg.algo.sampling_tau = 0.95 - # We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info) - # sampling and set the offline ratio to 1 cfg.cond.valid_sample_cond_info = False - cfg.algo.valid_offline_ratio = 1 + cfg.algo.valid_num_from_dataset = 64 def setup_algo(self): algo = self.cfg.algo.method @@ -330,9 +328,8 @@ def on_validation_end(self, metrics: Dict[str, Any]): def build_validation_data_loader(self) -> DataLoader: model, dev = self._wrap_for_mp(self.model, send_to_device=True) - n_from_dataset = self.cfg.algo.global_batch_size src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) - src.do_conditionals_dataset_in_order(self.test_data, n_from_dataset, model) + src.do_conditionals_dataset_in_order(self.test_data, self.cfg.algo.valid_num_from_dataset, model) if self.cfg.log_dir: src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "valid"), self.ctx)) diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 54adcc7c..5c9040a7 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -111,7 +111,7 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: class LittleSEHDataset(Dataset): """Note: this dataset isn't used by default, but turning it on showcases some features of this codebase. - To turn on, self `cfg.algo.offline_ratio > 0`""" + To turn on, self `cfg.algo.num_from_dataset > 0`""" def __init__(self, smis) -> None: super().__init__() @@ -146,8 +146,7 @@ def set_default_hps(self, cfg: Config): cfg.opt.lr_decay = 20_000 cfg.opt.clip_grad_type = "norm" cfg.opt.clip_grad_param = 10 - cfg.algo.global_batch_size = 64 - cfg.algo.offline_ratio = 0 + cfg.algo.num_from_policy = 64 cfg.model.num_emb = 128 cfg.model.num_layers = 4 @@ -157,7 +156,7 @@ def set_default_hps(self, cfg: Config): cfg.algo.illegal_action_logreward = -75 cfg.algo.train_random_action_prob = 0.0 cfg.algo.valid_random_action_prob = 0.0 - cfg.algo.valid_offline_ratio = 0 + cfg.algo.valid_num_from_policy = 64 cfg.num_validation_gen_steps = 10 cfg.algo.tb.epsilon = None cfg.algo.tb.bootstrap_own_reward = False @@ -213,7 +212,6 @@ def main(): config.num_workers = 8 config.opt.lr_decay = 20_000 config.algo.sampling_tau = 0.99 - config.algo.offline_ratio = 0.0 config.cond.temperature.sample_dist = "uniform" config.cond.temperature.dist_params = [0, 64.0] diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 60f8092c..ca5b7e6a 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -222,10 +222,10 @@ class SEHMOOFragTrainer(SEHFragTrainer): def set_default_hps(self, cfg: Config): super().set_default_hps(cfg) cfg.algo.sampling_tau = 0.95 - # We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info) - # sampling and set the offline ratio to 1 - cfg.cond.valid_sample_cond_info = False - cfg.algo.valid_offline_ratio = 1 + # We sample from a dataset of valid conditional information, so we set this, and override + # build_validation_data_loader to use the dataset + cfg.cond.valid_sample_cond_info = False # TODO deprecate this? + cfg.algo.valid_num_from_dataset = 64 def setup_algo(self): algo = self.cfg.algo.method @@ -356,9 +356,8 @@ def on_validation_end(self, metrics: Dict[str, Any]): def build_validation_data_loader(self) -> DataLoader: model, dev = self._wrap_for_mp(self.model, send_to_device=True) - n_from_dataset = self.cfg.algo.global_batch_size src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) - src.do_conditionals_dataset_in_order(self.test_data, n_from_dataset, model) + src.do_conditionals_dataset_in_order(self.test_data, self.cfg.algo.valid_num_from_dataset, model) if self.cfg.log_dir: src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "valid"), self.ctx)) diff --git a/src/gflownet/tasks/toy_seq.py b/src/gflownet/tasks/toy_seq.py index 901baea6..2aece1b2 100644 --- a/src/gflownet/tasks/toy_seq.py +++ b/src/gflownet/tasks/toy_seq.py @@ -54,8 +54,7 @@ def set_default_hps(self, cfg: Config): cfg.opt.lr_decay = 20_000 cfg.opt.clip_grad_type = "norm" cfg.opt.clip_grad_param = 10 - cfg.algo.global_batch_size = 64 - cfg.algo.offline_ratio = 0 + cfg.algo.num_from_policy = 64 cfg.model.num_emb = 64 cfg.model.num_layers = 4 @@ -66,7 +65,6 @@ def set_default_hps(self, cfg: Config): cfg.algo.illegal_action_logreward = -75 cfg.algo.train_random_action_prob = 0.0 cfg.algo.valid_random_action_prob = 0.0 - cfg.algo.valid_offline_ratio = 0 cfg.algo.tb.epsilon = None cfg.algo.tb.bootstrap_own_reward = False cfg.algo.tb.Z_learning_rate = 1e-2 diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index a9409562..5caed451 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -133,7 +133,6 @@ def __init__(self, config: Config, print_config=True): # the same as `model`. self.sampling_model: nn.Module self.replay_buffer: Optional[ReplayBuffer] - self.mb_size: int self.env: GraphBuildingEnv self.ctx: GraphBuildingEnvContext self.task: GFNTask @@ -242,18 +241,21 @@ def build_training_data_loader(self) -> DataLoader: model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) replay_buffer, _ = self._wrap_for_mp(self.replay_buffer, send_to_device=False) - n_drawn = int(self.cfg.algo.global_batch_size * (1 - self.cfg.algo.offline_ratio)) - n_replayed = n_drawn if self.cfg.replay.batch_size is None else self.cfg.replay.batch_size - n_from_dataset = self.cfg.algo.global_batch_size - n_drawn + if self.cfg.replay.use: + # None is fine for either value, it will be replaced by num_from_policy, but 0 is not + assert self.cfg.replay.num_from_replay != 0, "Replay is enabled but no samples are being drawn from it" + assert self.cfg.replay.num_new_samples != 0, "Replay is enabled but no new samples are being added to it" + + n_drawn = self.cfg.algo.num_from_policy + n_replayed = self.cfg.replay.num_from_replay or n_drawn if self.cfg.replay.use else 0 + n_new_replay_samples = self.cfg.replay.num_new_samples or n_drawn if self.cfg.replay.use else None + n_from_dataset = self.cfg.algo.num_from_dataset src = DataSource(self.cfg, self.ctx, self.algo, self.task, replay_buffer=replay_buffer) if n_from_dataset: src.do_sample_dataset(self.training_data, n_from_dataset, backwards_model=model) if n_drawn: - # If we are using a replay buffer, we can choose to keep the new samples in the minibatch, or just - # send them to the replay and train only on replay samples. - keep_samples_in_batch = not self.cfg.replay.use or not self.cfg.replay.replaces_online_data - src.do_sample_model(model, n_drawn, keep_samples_in_batch) + src.do_sample_model(model, n_drawn, n_new_replay_samples) if n_replayed and replay_buffer is not None: src.do_sample_replay(n_replayed) if self.cfg.log_dir: @@ -266,8 +268,8 @@ def build_validation_data_loader(self) -> DataLoader: model, dev = self._wrap_for_mp(self.model, send_to_device=True) # TODO: we're changing the default, make sure anything that is using test data is adjusted src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) - n_drawn = int(self.cfg.algo.global_batch_size * (1 - self.cfg.algo.valid_offline_ratio)) - n_from_dataset = self.cfg.algo.global_batch_size - n_drawn + n_drawn = self.cfg.algo.valid_num_from_policy + n_from_dataset = self.cfg.algo.valid_num_from_dataset src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) if n_from_dataset: @@ -285,10 +287,8 @@ def build_validation_data_loader(self) -> DataLoader: def build_final_data_loader(self) -> DataLoader: model, dev = self._wrap_for_mp(self.model, send_to_device=True) - # TODO: we're changing the default, make sure anything that is using test data is adjusted - src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) - n_drawn = int(self.cfg.algo.global_batch_size) + n_drawn = self.cfg.algo.num_from_policy src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) assert self.cfg.num_final_gen_steps is not None # TODO: might be better to change total steps to total trajectories drawn diff --git a/src/gflownet/utils/config.py b/src/gflownet/utils/config.py index 8f67af3a..5ee5369a 100644 --- a/src/gflownet/utils/config.py +++ b/src/gflownet/utils/config.py @@ -73,7 +73,7 @@ class FocusRegionConfig: @dataclass class ConditionalsConfig: valid_sample_cond_info: bool = True - temperature: TempCondConfig = TempCondConfig() - moo: MultiObjectiveConfig = MultiObjectiveConfig() - weighted_prefs: WeightedPreferencesConfig = WeightedPreferencesConfig() - focus_region: FocusRegionConfig = FocusRegionConfig() + temperature: TempCondConfig = field(default_factory=TempCondConfig) + moo: MultiObjectiveConfig = field(default_factory=MultiObjectiveConfig) + weighted_prefs: WeightedPreferencesConfig = field(default_factory=WeightedPreferencesConfig) + focus_region: FocusRegionConfig = field(default_factory=FocusRegionConfig) From 282bbfb82f3cfd444ba1dc720da261aa292a4d84 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 08:47:06 -0700 Subject: [PATCH 06/31] move things around & prevent circular import --- src/gflownet/__init__.py | 88 +++++ src/gflownet/data/data_source.py | 14 +- src/gflownet/data/sampling_iterator.py | 489 ------------------------- src/gflownet/tasks/make_rings.py | 2 +- src/gflownet/tasks/qm9.py | 2 +- src/gflownet/tasks/qm9_moo.py | 2 +- src/gflownet/tasks/seh_frag.py | 2 +- src/gflownet/tasks/seh_frag_moo.py | 2 +- src/gflownet/tasks/toy_seq.py | 2 +- src/gflownet/trainer.py | 85 +---- src/gflownet/utils/sqlite_log.py | 93 +++++ 11 files changed, 196 insertions(+), 585 deletions(-) delete mode 100644 src/gflownet/data/sampling_iterator.py create mode 100644 src/gflownet/utils/sqlite_log.py diff --git a/src/gflownet/__init__.py b/src/gflownet/__init__.py index e69de29b..9f445e01 100644 --- a/src/gflownet/__init__.py +++ b/src/gflownet/__init__.py @@ -0,0 +1,88 @@ +from typing import Dict, List, NewType, Optional, Tuple + +import torch_geometric.data as gd +from rdkit.Chem.rdchem import Mol as RDMol +from torch import Tensor, nn + +from .config import Config + +# This type represents an unprocessed list of reward signals/conditioning information +FlatRewards = NewType("FlatRewards", Tensor) # type: ignore + +# This type represents the outcome for a multi-objective task of +# converting FlatRewards to a scalar, e.g. (sum R_i omega_i) ** beta +RewardScalar = NewType("RewardScalar", Tensor) # type: ignore + + +class GFNAlgorithm: + updates: int = 0 + global_cfg: Config + is_eval: bool = False + + def step(self): + self.updates += 1 # This isn't used anywhere? + + def compute_batch_losses( + self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 + ) -> Tuple[Tensor, Dict[str, Tensor]]: + """Computes the loss for a batch of data, and proves logging informations + + Parameters + ---------- + model: nn.Module + The model being trained or evaluated + batch: gd.Batch + A batch of graphs + num_bootstrap: Optional[int] + The number of trajectories with reward targets in the batch (if applicable). + + Returns + ------- + loss: Tensor + The loss for that batch + info: Dict[str, Tensor] + Logged information about model predictions. + """ + raise NotImplementedError() + + def get_random_action_prob(self, it: int): + if self.is_eval: + return self.global_cfg.algo.valid_random_action_prob + if self.global_cfg.algo.train_det_after is None or it < self.global_cfg.algo.train_det_after: + return self.global_cfg.algo.train_random_action_prob + return 0 + + +class GFNTask: + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: + """Combines a minibatch of reward signal vectors and conditional information into a scalar reward. + + Parameters + ---------- + cond_info: Dict[str, Tensor] + A dictionary with various conditional informations (e.g. temperature) + flat_reward: FlatRewards + A 2d tensor where each row represents a series of flat rewards. + + Returns + ------- + reward: RewardScalar + A 1d tensor, a scalar log-reward for each minibatch entry. + """ + raise NotImplementedError() + + def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: + """Compute the flat rewards of mols according the the tasks' proxies + + Parameters + ---------- + mols: List[RDMol] + A list of RDKit molecules. + Returns + ------- + reward: FlatRewards + A 2d tensor, a vector of scalar reward for valid each molecule. + is_valid: Tensor + A 1d tensor, a boolean indicating whether the molecule is valid. + """ + raise NotImplementedError() diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index ca8c6d1b..1f74dc59 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -5,13 +5,12 @@ import torch from torch.utils.data import IterableDataset +from gflownet import GFNAlgorithm, GFNTask from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.graph_building_env import GraphBuildingEnvContext from gflownet.utils.misc import get_worker_rng -# from gflownet.trainer import GFNAlgorithm, GFNTask - def cycle_call(it): while True: @@ -24,8 +23,8 @@ def __init__( self, cfg: Config, ctx: GraphBuildingEnvContext, - algo, #: GFNAlgorithm, - task, #: GFNTask, # TODO: this will cause a circular import + algo: GFNAlgorithm, + task: GFNTask, replay_buffer: Optional[ReplayBuffer] = None, is_algo_eval: bool = False, start_at_step: int = 0, @@ -230,7 +229,7 @@ def create_batch(self, trajs, batch_info): log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) - # TODO: find code that depends on batch.flat_rewards and deprecate it + batch.flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) return batch def compute_properties(self, trajs, mark_as_online=False): @@ -291,8 +290,9 @@ def relabel_in_hindsight(self, trajs): cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( cond_info, log_rewards, flat_rewards, hindsight_idxs ) - # TODO: This seems wrong, since we haven't recomputed is_valid - # log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward + self.set_traj_cond_info(trajs, cond_info) + for i in range(len(trajs)): + trajs[i]["log_reward"] = log_rewards[i] def sample_idcs(self, n, num_samples): return self.rng.choice(n, num_samples, replace=False) diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py deleted file mode 100644 index 4d2a8c07..00000000 --- a/src/gflownet/data/sampling_iterator.py +++ /dev/null @@ -1,489 +0,0 @@ -import os -import sqlite3 -from collections.abc import Iterable -from typing import Callable, List, Optional - -import numpy as np -import torch -import torch.nn as nn -from rdkit import RDLogger -from torch.utils.data import Dataset, IterableDataset - -from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.envs.graph_building_env import GraphActionCategorical - - -class SamplingIterator(IterableDataset): - """This class allows us to parallelise and train faster. - - By separating sampling data/the model and building torch geometric - graphs from training the model, we can do the former in different - processes, which is much faster since much of graph construction - is CPU-bound. - - """ - - def __init__( - self, - dataset: Dataset, - model: nn.Module, - ctx, - algo, - task, - device, - batch_size: int = 1, - illegal_action_logreward: float = -50, - ratio: float = 0.5, - stream: bool = True, - replay_buffer: ReplayBuffer = None, - log_dir: str = None, - sample_cond_info: bool = True, - random_action_prob: float = 0.0, - det_after: Optional[int] = None, - hindsight_ratio: float = 0.0, - init_train_iter: int = 0, - ): - """Parameters - ---------- - dataset: Dataset - A dataset instance - model: nn.Module - The model we sample from (must be on CUDA already or share_memory() must be called so that - parameters are synchronized between each worker) - ctx: - The context for the environment, e.g. a MolBuildingEnvContext instance - algo: - The training algorithm, e.g. a TrajectoryBalance instance - task: GFNTask - A Task instance, e.g. a MakeRingsTask instance - device: torch.device - The device the model is on - replay_buffer: ReplayBuffer - The replay buffer for training on past data - batch_size: int - The number of trajectories, each trajectory will be comprised of many graphs, so this is - _not_ the batch size in terms of the number of graphs (that will depend on the task) - illegal_action_logreward: float - The logreward for invalid trajectories - ratio: float - The ratio of offline trajectories in the batch. - stream: bool - If True, data is sampled iid for every batch. Otherwise, this is a normal in-order - dataset iterator. - log_dir: str - If not None, logs each SamplingIterator worker's generated molecules to that file. - sample_cond_info: bool - If True (default), then the dataset is a dataset of points used in offline training. - If False, then the dataset is a dataset of preferences (e.g. used to validate the model) - random_action_prob: float - The probability of taking a random action, passed to the graph sampler - init_train_iter: int - The initial training iteration, incremented and passed to task.sample_conditional_information - """ - self.data = dataset - self.model = model - self.replay_buffer = replay_buffer - self.batch_size = batch_size - self.illegal_action_logreward = illegal_action_logreward - self.offline_batch_size = int(np.ceil(self.batch_size * ratio)) - self.online_batch_size = int(np.floor(self.batch_size * (1 - ratio))) - self.ratio = ratio - self.ctx = ctx - self.algo = algo - self.task = task - self.device = device - self.stream = stream - self.sample_online_once = True # TODO: deprecate this, disallow len(data) == 0 entirely - self.sample_cond_info = sample_cond_info - self.random_action_prob = random_action_prob - self.hindsight_ratio = hindsight_ratio - self.train_it = init_train_iter - self.do_validate_batch = False # Turn this on for debugging - self.iter = 0 - self.det_after = det_after - # Slightly weird semantics, but if we're sampling x given some fixed cond info (data) - # then "offline" now refers to cond info and online to x, so no duplication and we don't end - # up with 2*batch_size accidentally - if not sample_cond_info: - self.offline_batch_size = self.online_batch_size = self.batch_size - - # This SamplingIterator instance will be copied by torch DataLoaders for each worker, so we - # don't want to initialize per-worker things just yet, such as where the log the worker writes - # to. This must be done in __iter__, which is called by the DataLoader once this instance - # has been copied into a new python process. - self.log_dir = log_dir - self.log = SQLiteLog() - self.log_hooks: List[Callable] = [] - - def add_log_hook(self, hook: Callable): - self.log_hooks.append(hook) - - def _idx_iterator(self): - RDLogger.DisableLog("rdApp.*") - if self.stream: - # If we're streaming data, just sample `offline_batch_size` indices - while True: - if self.offline_batch_size == 0 or len(self.data) == 0: - yield np.arange(0, 0) - else: - yield self.rng.integers(0, len(self.data), self.offline_batch_size) - else: - # Otherwise, figure out which indices correspond to this worker - worker_info = torch.utils.data.get_worker_info() - n = len(self.data) - if n == 0: - yield np.arange(0, 0) - return - assert ( - self.offline_batch_size > 0 - ), "offline_batch_size must be > 0 if not streaming and len(data) > 0 (have you set ratio=0?)" - if worker_info is None: # no multi-processing - start, end, wid = 0, n, -1 - else: # split the data into chunks (per-worker) - nw = worker_info.num_workers - wid = worker_info.id - start, end = int(np.round(n / nw * wid)), int(np.round(n / nw * (wid + 1))) - bs = self.offline_batch_size - if end - start <= bs: - yield np.arange(start, end) - return - for i in range(start, end - bs, bs): - yield np.arange(i, i + bs) - if i + bs < end: - yield np.arange(i + bs, end) - - def __len__(self): - if self.stream: - return int(1e6) - if len(self.data) == 0 and self.sample_online_once: - return 1 - return len(self.data) - - def __iter__(self): - self.iter += 1 - if self.det_after is not None and self.iter > self.det_after: - self.random_action_prob = 0 - worker_info = torch.utils.data.get_worker_info() - self._wid = worker_info.id if worker_info is not None else 0 - # Now that we know we are in a worker instance, we can initialize per-worker things - self.rng = self.algo.rng = self.task.rng = np.random.default_rng(142857 + self._wid) - self.ctx.device = self.device - if self.log_dir is not None: - os.makedirs(self.log_dir, exist_ok=True) - self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" - self.log.connect(self.log_path) - - for idcs in self._idx_iterator(): - num_offline = idcs.shape[0] # This is in [0, self.offline_batch_size] - # Sample conditional info such as temperature, trade-off weights, etc. - - if self.sample_cond_info: - num_online = self.online_batch_size - cond_info = self.task.sample_conditional_information( - num_offline + self.online_batch_size, self.train_it - ) - - # Sample some dataset data - graphs, flat_rewards = map(list, zip(*[self.data[i] for i in idcs])) if len(idcs) else ([], []) - flat_rewards = ( - list(self.task.flat_reward_transform(torch.stack(flat_rewards))) if len(flat_rewards) else [] - ) - - trajs = self.algo.create_training_data_from_graphs( - graphs, self.model, cond_info["encoding"][:num_offline], 0 - ) - - else: # If we're not sampling the conditionals, then the idcs refer to listed preferences - num_online = num_offline - num_offline = 0 - cond_info = self.task.encode_conditional_information( - steer_info=torch.stack([self.data[i] for i in idcs]) # This is sus, what's going on here? - ) - trajs, flat_rewards = [], [] - - # Sample some on-policy data - is_valid = torch.ones(num_offline + num_online).bool() - if num_online > 0: - with torch.no_grad(): - trajs += self.algo.create_training_data_from_own_samples( - self.model, - num_online, - cond_info["encoding"][num_offline:], - random_action_prob=self.random_action_prob, - ) - if self.algo.bootstrap_own_reward: - # The model can be trained to predict its own reward, - # i.e. predict the output of cond_info_to_logreward - pred_reward = [i["reward_pred"].cpu().item() for i in trajs[num_offline:]] - flat_rewards += pred_reward - else: - # Otherwise, query the task for flat rewards - valid_idcs = torch.tensor( - [i + num_offline for i in range(num_online) if trajs[i + num_offline]["is_valid"]] - ).long() - # fetch the valid trajectories endpoints - mols = [self.ctx.graph_to_mol(trajs[i]["result"]) for i in valid_idcs] - # ask the task to compute their reward - online_flat_rew, m_is_valid = self.task.compute_flat_rewards(mols) - assert ( - online_flat_rew.ndim == 2 - ), "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" - # The task may decide some of the mols are invalid, we have to again filter those - valid_idcs = valid_idcs[m_is_valid] - pred_reward = torch.zeros((num_online, online_flat_rew.shape[1])) - pred_reward[valid_idcs - num_offline] = online_flat_rew - is_valid[num_offline:] = False - is_valid[valid_idcs] = True - flat_rewards += list(pred_reward) - # Override the is_valid key in case the task made some mols invalid - for i in range(num_online): - trajs[num_offline + i]["is_valid"] = is_valid[num_offline + i].item() - - # Compute scalar rewards from conditional information & flat rewards - flat_rewards = torch.stack(flat_rewards) - log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) - log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward - - assert len(trajs) == num_online + num_offline - # Computes some metrics - extra_info = {"random_action_prob": self.random_action_prob} - if num_online > 0: - H = sum(i["fwd_logprob"] for i in trajs[num_offline:]) - extra_info["entropy"] = -H / num_online - extra_info["length"] = np.mean([len(i["traj"]) for i in trajs[num_offline:]]) - if not self.sample_cond_info: - # If we're using a dataset of preferences, the user may want to know the id of the preference - for i, j in zip(trajs, idcs): - i["data_idx"] = j - # note: we convert back into natural rewards for logging purposes - # (allows to take averages and plot in objective space) - # TODO: implement that per-task (in case they don't apply the same beta and log transformations) - rewards = torch.exp(log_rewards / cond_info["beta"]) - if num_online > 0 and self.log_dir is not None: - self.log_generated( - trajs[num_offline:], - rewards[num_offline:], - flat_rewards[num_offline:], - {k: v[num_offline:] for k, v in cond_info.items()}, - ) - if num_online > 0: - extra_info["sampled_reward_avg"] = rewards[num_offline:].mean().item() - for hook in self.log_hooks: - extra_info.update( - hook( - trajs[num_offline:], - rewards[num_offline:], - flat_rewards[num_offline:], - {k: v[num_offline:] for k, v in cond_info.items()}, - ) - ) - - if self.replay_buffer is not None: - # If we have a replay buffer, we push the online trajectories in it - # and resample immediately such that the "online" data in the batch - # comes from a more stable distribution (try to avoid forgetting) - - # cond_info is a dict, so we need to convert it to a list of dicts - cond_info = [{k: v[i] for k, v in cond_info.items()} for i in range(num_offline + num_online)] - - # push the online trajectories in the replay buffer and sample a new 'online' batch - for i in range(num_offline, len(trajs)): - self.replay_buffer.push( - trajs[i], - log_rewards[i], - flat_rewards[i], - cond_info[i], - is_valid[i], - ) - replay_trajs, replay_logr, replay_fr, replay_condinfo, replay_valid = self.replay_buffer.sample( - num_online - ) - - # append the online trajectories to the offline ones - trajs = trajs[:num_offline] + replay_trajs - log_rewards = torch.cat([log_rewards[:num_offline], replay_logr], dim=0) - flat_rewards = torch.cat([flat_rewards[:num_offline], replay_fr], dim=0) - cond_info = cond_info[:num_offline] + replay_condinfo # list of dicts - is_valid = torch.cat([is_valid[:num_offline], replay_valid], dim=0) - - # convert cond_info back to a dict - cond_info = {k: torch.stack([d[k] for d in cond_info]) for k in cond_info[0]} - - if self.hindsight_ratio > 0.0: - # Relabels some of the online trajectories with hindsight - assert hasattr( - self.task, "relabel_condinfo_and_logrewards" - ), "Hindsight requires the task to implement relabel_condinfo_and_logrewards" - # samples indexes of trajectories without repeats - hindsight_idxs = torch.randperm(num_online)[: int(num_online * self.hindsight_ratio)] + num_offline - cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( - cond_info, log_rewards, flat_rewards, hindsight_idxs - ) - log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward - - # Construct batch - batch = self.algo.construct_batch(trajs, cond_info["encoding"], log_rewards) - batch.num_offline = num_offline - batch.num_online = num_online - batch.flat_rewards = flat_rewards - batch.preferences = cond_info.get("preferences", None) - batch.focus_dir = cond_info.get("focus_dir", None) - batch.extra_info = extra_info - if self.ctx.has_n(): - log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] - batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) - batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) - # TODO: we could very well just pass the cond_info dict to construct_batch above, - # and the algo can decide what it wants to put in the batch object - - # Only activate for debugging your environment or dataset (e.g. the dataset could be - # generating trajectories with illegal actions) - if self.do_validate_batch: - self.validate_batch(batch, trajs) - - self.train_it += worker_info.num_workers if worker_info is not None else 1 - yield batch - - def validate_batch(self, batch, trajs): - for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( - [(batch.bck_actions, self.ctx.bck_action_type_order)] - if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order") - else [] - ): - mask_cat = GraphActionCategorical( - batch, - [self.model._action_type_to_mask(t, batch) for t in atypes], - [self.model._action_type_to_key[t] for t in atypes], - [None for _ in atypes], - ) - masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits) - num_trajs = len(trajs) - batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens) - first_graph_idx = torch.zeros_like(batch.traj_lens) - torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) - if masked_action_is_used.sum() != 0: - invalid_idx = masked_action_is_used.argmax().item() - traj_idx = batch_idx[invalid_idx].item() - timestep = invalid_idx - first_graph_idx[traj_idx].item() - raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep]) - - def log_generated(self, trajs, rewards, flat_rewards, cond_info): - if hasattr(self.ctx, "object_to_log_repr"): - mols = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] - else: - mols = [""] * len(trajs) - - flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() - rewards = rewards.data.numpy().tolist() - preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() - focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() - logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] - - data = [ - [mols[i], rewards[i]] - + flat_rewards[i] - + preferences[i] - + focus_dir[i] - + [cond_info[k][i].item() for k in logged_keys] - for i in range(len(trajs)) - ] - - data_labels = ( - ["smi", "r"] - + [f"fr_{i}" for i in range(len(flat_rewards[0]))] - + [f"pref_{i}" for i in range(len(preferences[0]))] - + [f"focus_{i}" for i in range(len(focus_dir[0]))] - + [f"ci_{k}" for k in logged_keys] - ) - - self.log.insert_many(data, data_labels) - - -class SQLiteLogHook: - def __init__(self, log_dir, ctx) -> None: - self.log = None # Only initialized in __call__, which will occur inside the worker - self.log_dir = log_dir - self.ctx = ctx - self.data_labels = None - - def __call__(self, trajs, rewards, flat_rewards, cond_info): - if self.log is None: - worker_info = torch.utils.data.get_worker_info() - self._wid = worker_info.id if worker_info is not None else 0 - os.makedirs(self.log_dir, exist_ok=True) - self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" - self.log = SQLiteLog() - self.log.connect(self.log_path) - - if hasattr(self.ctx, "object_to_log_repr"): - mols = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] - else: - mols = [""] * len(trajs) - - flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() - rewards = rewards.data.numpy().tolist() - preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() - focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() - logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] - - data = [ - [mols[i], rewards[i]] - + flat_rewards[i] - + preferences[i] - + focus_dir[i] - + [cond_info[k][i].item() for k in logged_keys] - for i in range(len(trajs)) - ] - if self.data_labels is None: - self.data_labels = ( - ["smi", "r"] - + [f"fr_{i}" for i in range(len(flat_rewards[0]))] - + [f"pref_{i}" for i in range(len(preferences[0]))] - + [f"focus_{i}" for i in range(len(focus_dir[0]))] - + [f"ci_{k}" for k in logged_keys] - ) - - self.log.insert_many(data, self.data_labels) - return {} - - -class SQLiteLog: - def __init__(self, timeout=300): - """Creates a log instance, but does not connect it to any db.""" - self.is_connected = False - self.db = None - self.timeout = timeout - - def connect(self, db_path: str): - """Connects to db_path - - Parameters - ---------- - db_path: str - The sqlite3 database path. If it does not exist, it will be created. - """ - self.db = sqlite3.connect(db_path, timeout=self.timeout) - cur = self.db.cursor() - self._has_results_table = len( - cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='results'").fetchall() - ) - cur.close() - - def _make_results_table(self, types, names): - type_map = {str: "text", float: "real", int: "real"} - col_str = ", ".join(f"{name} {type_map[t]}" for t, name in zip(types, names)) - cur = self.db.cursor() - cur.execute(f"create table results ({col_str})") - self._has_results_table = True - cur.close() - - def insert_many(self, rows, column_names): - assert all( - [isinstance(x, str) or not isinstance(x, Iterable) for x in rows[0]] - ), "rows must only contain scalars" - if not self._has_results_table: - self._make_results_table([type(i) for i in rows[0]], column_names) - cur = self.db.cursor() - cur.executemany(f'insert into results values ({",".join("?"*len(rows[0]))})', rows) # nosec - cur.close() - self.db.commit() diff --git a/src/gflownet/tasks/make_rings.py b/src/gflownet/tasks/make_rings.py index 00b2863b..34f47924 100644 --- a/src/gflownet/tasks/make_rings.py +++ b/src/gflownet/tasks/make_rings.py @@ -7,10 +7,10 @@ from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor +from gflownet import FlatRewards, GFNTask, RewardScalar from gflownet.config import Config, init_empty from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.online_trainer import StandardOnlineTrainer -from gflownet.trainer import FlatRewards, GFNTask, RewardScalar class MakeRingsTask(GFNTask): diff --git a/src/gflownet/tasks/qm9.py b/src/gflownet/tasks/qm9.py index 4a3dd1ba..0bad429f 100644 --- a/src/gflownet/tasks/qm9.py +++ b/src/gflownet/tasks/qm9.py @@ -9,11 +9,11 @@ from torch.utils.data import Dataset import gflownet.models.mxmnet as mxmnet +from gflownet import FlatRewards, GFNTask, RewardScalar from gflownet.config import Config, init_empty from gflownet.data.qm9 import QM9Dataset from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.online_trainer import StandardOnlineTrainer -from gflownet.trainer import FlatRewards, GFNTask, RewardScalar from gflownet.utils.conditioning import TemperatureConditional from gflownet.utils.transforms import to_logreward diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index ec62643c..4ca55483 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -10,6 +10,7 @@ from torch.utils.data import DataLoader, Dataset import gflownet.models.mxmnet as mxmnet +from gflownet import FlatRewards, RewardScalar from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config @@ -19,7 +20,6 @@ from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.tasks.qm9 import QM9GapTask, QM9GapTrainer from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, aux_tasks -from gflownet.trainer import FlatRewards, RewardScalar from gflownet.utils import metrics from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 5c9040a7..c24a8d05 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -11,11 +11,11 @@ from torch.utils.data import Dataset from torch_geometric.data import Data +from gflownet import FlatRewards, GFNTask, RewardScalar from gflownet.config import Config, init_empty from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext, Graph from gflownet.models import bengio2021flow from gflownet.online_trainer import StandardOnlineTrainer -from gflownet.trainer import FlatRewards, GFNTask, RewardScalar from gflownet.utils.conditioning import TemperatureConditional from gflownet.utils.transforms import to_logreward diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index ca3fa0bc..8ec06d5e 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -10,6 +10,7 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset +from gflownet import FlatRewards, RewardScalar from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config, init_empty @@ -18,7 +19,6 @@ from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.models import bengio2021flow from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask -from gflownet.trainer import FlatRewards, RewardScalar from gflownet.utils import metrics, sascore from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook diff --git a/src/gflownet/tasks/toy_seq.py b/src/gflownet/tasks/toy_seq.py index 2aece1b2..f2c75f60 100644 --- a/src/gflownet/tasks/toy_seq.py +++ b/src/gflownet/tasks/toy_seq.py @@ -5,11 +5,11 @@ import torch from torch import Tensor +from gflownet import FlatRewards, GFNTask, RewardScalar from gflownet.config import Config, init_empty from gflownet.envs.seq_building_env import AutoregressiveSeqBuildingContext, SeqBuildingEnv from gflownet.models.seq_transformer import SeqTransformerGFN from gflownet.online_trainer import StandardOnlineTrainer -from gflownet.trainer import FlatRewards, GFNTask, RewardScalar from gflownet.utils.conditioning import TemperatureConditional from gflownet.utils.transforms import to_logreward diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 5caed451..92d45538 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -3,7 +3,7 @@ import pathlib import shutil import time -from typing import Any, Callable, Dict, List, NewType, Optional, Protocol, Tuple +from typing import Any, Callable, Dict, List, Optional, Protocol import numpy as np import torch @@ -13,10 +13,10 @@ import wandb from omegaconf import OmegaConf from rdkit import RDLogger -from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import DataLoader, Dataset +from gflownet import GFNAlgorithm, GFNTask from gflownet.data.data_source import DataSource from gflownet.data.replay_buffer import ReplayBuffer from gflownet.data.sampling_iterator import SQLiteLogHook @@ -27,87 +27,6 @@ from .config import Config -# This type represents an unprocessed list of reward signals/conditioning information -FlatRewards = NewType("FlatRewards", Tensor) # type: ignore - -# This type represents the outcome for a multi-objective task of -# converting FlatRewards to a scalar, e.g. (sum R_i omega_i) ** beta -RewardScalar = NewType("RewardScalar", Tensor) # type: ignore - - -class GFNAlgorithm: - updates: int = 0 - global_cfg: Config - is_eval: bool = False - - def step(self): - self.updates += 1 # This isn't used anywhere? - - def compute_batch_losses( - self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 - ) -> Tuple[Tensor, Dict[str, Tensor]]: - """Computes the loss for a batch of data, and proves logging informations - - Parameters - ---------- - model: nn.Module - The model being trained or evaluated - batch: gd.Batch - A batch of graphs - num_bootstrap: Optional[int] - The number of trajectories with reward targets in the batch (if applicable). - - Returns - ------- - loss: Tensor - The loss for that batch - info: Dict[str, Tensor] - Logged information about model predictions. - """ - raise NotImplementedError() - - def get_random_action_prob(self, it: int): - if self.is_eval: - return self.global_cfg.algo.valid_random_action_prob - if self.global_cfg.algo.train_det_after is None or it < self.global_cfg.algo.train_det_after: - return self.global_cfg.algo.train_random_action_prob - return 0 - - -class GFNTask: - def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - """Combines a minibatch of reward signal vectors and conditional information into a scalar reward. - - Parameters - ---------- - cond_info: Dict[str, Tensor] - A dictionary with various conditional informations (e.g. temperature) - flat_reward: FlatRewards - A 2d tensor where each row represents a series of flat rewards. - - Returns - ------- - reward: RewardScalar - A 1d tensor, a scalar log-reward for each minibatch entry. - """ - raise NotImplementedError() - - def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: - """Compute the flat rewards of mols according the the tasks' proxies - - Parameters - ---------- - mols: List[RDMol] - A list of RDKit molecules. - Returns - ------- - reward: FlatRewards - A 2d tensor, a vector of scalar reward for valid each molecule. - is_valid: Tensor - A 1d tensor, a boolean indicating whether the molecule is valid. - """ - raise NotImplementedError() - class Closable(Protocol): def close(self): diff --git a/src/gflownet/utils/sqlite_log.py b/src/gflownet/utils/sqlite_log.py new file mode 100644 index 00000000..0740baf8 --- /dev/null +++ b/src/gflownet/utils/sqlite_log.py @@ -0,0 +1,93 @@ +from typing import Iterable +import os +import sqlite3 +import torch + +class SQLiteLogHook: + def __init__(self, log_dir, ctx) -> None: + self.log = None # Only initialized in __call__, which will occur inside the worker + self.log_dir = log_dir + self.ctx = ctx + self.data_labels = None + + def __call__(self, trajs, rewards, flat_rewards, cond_info): + if self.log is None: + worker_info = torch.utils.data.get_worker_info() + self._wid = worker_info.id if worker_info is not None else 0 + os.makedirs(self.log_dir, exist_ok=True) + self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" + self.log = SQLiteLog() + self.log.connect(self.log_path) + + if hasattr(self.ctx, "object_to_log_repr"): + mols = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] + else: + mols = [""] * len(trajs) + + flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() + rewards = rewards.data.numpy().tolist() + preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() + focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() + logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] + + data = [ + [mols[i], rewards[i]] + + flat_rewards[i] + + preferences[i] + + focus_dir[i] + + [cond_info[k][i].item() for k in logged_keys] + for i in range(len(trajs)) + ] + if self.data_labels is None: + self.data_labels = ( + ["smi", "r"] + + [f"fr_{i}" for i in range(len(flat_rewards[0]))] + + [f"pref_{i}" for i in range(len(preferences[0]))] + + [f"focus_{i}" for i in range(len(focus_dir[0]))] + + [f"ci_{k}" for k in logged_keys] + ) + + self.log.insert_many(data, self.data_labels) + return {} + + +class SQLiteLog: + def __init__(self, timeout=300): + """Creates a log instance, but does not connect it to any db.""" + self.is_connected = False + self.db = None + self.timeout = timeout + + def connect(self, db_path: str): + """Connects to db_path + + Parameters + ---------- + db_path: str + The sqlite3 database path. If it does not exist, it will be created. + """ + self.db = sqlite3.connect(db_path, timeout=self.timeout) + cur = self.db.cursor() + self._has_results_table = len( + cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='results'").fetchall() + ) + cur.close() + + def _make_results_table(self, types, names): + type_map = {str: "text", float: "real", int: "real"} + col_str = ", ".join(f"{name} {type_map[t]}" for t, name in zip(types, names)) + cur = self.db.cursor() + cur.execute(f"create table results ({col_str})") + self._has_results_table = True + cur.close() + + def insert_many(self, rows, column_names): + assert all( + [isinstance(x, str) or not isinstance(x, Iterable) for x in rows[0]] + ), "rows must only contain scalars" + if not self._has_results_table: + self._make_results_table([type(i) for i in rows[0]], column_names) + cur = self.db.cursor() + cur.executemany(f'insert into results values ({",".join("?"*len(rows[0]))})', rows) # nosec + cur.close() + self.db.commit() From c3bc6d05087195aab4df6938a8e3105cb6ce2d66 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 08:47:56 -0700 Subject: [PATCH 07/31] tox --- src/gflownet/utils/sqlite_log.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gflownet/utils/sqlite_log.py b/src/gflownet/utils/sqlite_log.py index 0740baf8..1ac183db 100644 --- a/src/gflownet/utils/sqlite_log.py +++ b/src/gflownet/utils/sqlite_log.py @@ -1,8 +1,10 @@ -from typing import Iterable import os import sqlite3 +from typing import Iterable + import torch + class SQLiteLogHook: def __init__(self, log_dir, ctx) -> None: self.log = None # Only initialized in __call__, which will occur inside the worker From b1c5630019d569243a6987b33891cda322842772 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 08:58:32 -0700 Subject: [PATCH 08/31] fix imports --- src/gflownet/tasks/qm9_moo.py | 2 +- src/gflownet/tasks/seh_frag_moo.py | 2 +- src/gflownet/trainer.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index 4ca55483..953c48e4 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -16,13 +16,13 @@ from gflownet.config import Config from gflownet.data.data_source import DataSource from gflownet.data.qm9 import QM9Dataset -from gflownet.data.sampling_iterator import SQLiteLogHook from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.tasks.qm9 import QM9GapTask, QM9GapTrainer from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, aux_tasks from gflownet.utils import metrics from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook +from gflownet.utils.sqlite_log import SQLiteLogHook from gflownet.utils.transforms import to_logreward diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 8ec06d5e..bbab91b7 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -15,13 +15,13 @@ from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config, init_empty from gflownet.data.data_source import DataSource -from gflownet.data.sampling_iterator import SQLiteLogHook from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.models import bengio2021flow from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask from gflownet.utils import metrics, sascore from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook +from gflownet.utils.sqlite_log import SQLiteLogHook from gflownet.utils.transforms import to_logreward diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 92d45538..f7406fd7 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -19,11 +19,11 @@ from gflownet import GFNAlgorithm, GFNTask from gflownet.data.data_source import DataSource from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.data.sampling_iterator import SQLiteLogHook from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import create_logger, set_main_process_device from gflownet.utils.multiprocessing_proxy import mp_object_wrapper +from gflownet.utils.sqlite_log import SQLiteLogHook from .config import Config From a64a639bf6ff66c57ded4aa271aff9cf372d272c Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 09:11:41 -0700 Subject: [PATCH 09/31] replace device references with get_worker_device --- src/gflownet/tasks/qm9.py | 10 ++++++---- src/gflownet/tasks/qm9_moo.py | 2 +- src/gflownet/tasks/seh_frag.py | 5 ++++- src/gflownet/tasks/seh_frag_moo.py | 2 +- src/gflownet/trainer.py | 19 +++++++------------ src/gflownet/utils/conditioning.py | 4 ++-- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/gflownet/tasks/qm9.py b/src/gflownet/tasks/qm9.py index 0bad429f..266ff77b 100644 --- a/src/gflownet/tasks/qm9.py +++ b/src/gflownet/tasks/qm9.py @@ -15,6 +15,7 @@ from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.online_trainer import StandardOnlineTrainer from gflownet.utils.conditioning import TemperatureConditional +from gflownet.utils.misc import get_worker_device from gflownet.utils.transforms import to_logreward @@ -30,7 +31,8 @@ def __init__( ): self._wrap_model = wrap_model self.rng = rng - self.models = self.load_task_models(cfg.task.qm9.model_path, torch.device(cfg.device)) + self.device = get_worker_device() + self.models = self.load_task_models(cfg.task.qm9.model_path) self.dataset = dataset self.temperature_conditional = TemperatureConditional(cfg, rng) self.num_cond_dim = self.temperature_conditional.encoding_size() @@ -60,7 +62,7 @@ def inverse_flat_reward_transform(self, rp): elif self._rtrans == "unit+95p": return (1 - rp + (1 - self._percentile_95)) * self._width + self._min - def load_task_models(self, path, device): + def load_task_models(self, path): gap_model = mxmnet.MXMNet(mxmnet.Config(128, 6, 5.0)) # TODO: this path should be part of the config? try: @@ -73,8 +75,8 @@ def load_task_models(self, path, device): "https://storage.googleapis.com/emmanuel-data/models/mxmnet_gap_model.pt", ) gap_model.load_state_dict(state_dict) - gap_model.to(device) - gap_model, self.device = self._wrap_model(gap_model, send_to_device=True) + gap_model.to(self.device) + gap_model = self._wrap_model(gap_model) return {"mxmnet_gap": gap_model} def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index 953c48e4..51029e3a 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -326,7 +326,7 @@ def on_validation_end(self, metrics: Dict[str, Any]): return {"topk": TopKMetricCB()} def build_validation_data_loader(self) -> DataLoader: - model, dev = self._wrap_for_mp(self.model, send_to_device=True) + model = self._wrap_for_mp(self.model) src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) src.do_conditionals_dataset_in_order(self.test_data, self.cfg.algo.valid_num_from_dataset, model) diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index c24a8d05..a731a57c 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -17,6 +17,7 @@ from gflownet.models import bengio2021flow from gflownet.online_trainer import StandardOnlineTrainer from gflownet.utils.conditioning import TemperatureConditional +from gflownet.utils.misc import get_worker_device from gflownet.utils.transforms import to_logreward @@ -43,6 +44,7 @@ def __init__( self.dataset = dataset self.temperature_conditional = TemperatureConditional(cfg, rng) self.num_cond_dim = self.temperature_conditional.encoding_size() + self.device = get_worker_device() def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: return FlatRewards(torch.as_tensor(y) / 8) @@ -52,7 +54,8 @@ def inverse_flat_reward_transform(self, rp): def _load_task_models(self): model = bengio2021flow.load_original_model() - model, self.device = self._wrap_model(model, send_to_device=True) + model.to(self.device) + model = self._wrap_model(model, send_to_device=True) return {"seh": model} def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index bbab91b7..ef8def85 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -351,7 +351,7 @@ def on_validation_end(self, metrics: Dict[str, Any]): return callback_dict def build_validation_data_loader(self) -> DataLoader: - model, dev = self._wrap_for_mp(self.model, send_to_device=True) + model = self._wrap_for_mp(self.model) src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) src.do_conditionals_dataset_in_order(self.test_data, self.cfg.algo.valid_num_from_dataset, model) diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index f7406fd7..c8be7312 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -122,11 +122,9 @@ def setup(self): self.setup_algo() self.setup_model() - def _wrap_for_mp(self, obj, send_to_device=False): + def _wrap_for_mp(self, obj): """Wraps an object in a placeholder whose reference can be sent to a data worker process (only if the number of workers is non-zero).""" - if send_to_device: - obj.to(self.device) if self.cfg.num_workers > 0 and obj is not None: wrapper = mp_object_wrapper( obj, @@ -135,9 +133,9 @@ def _wrap_for_mp(self, obj, send_to_device=False): pickle_messages=self.cfg.pickle_mp_messages, ) self.to_terminate.append(wrapper.terminate) - return wrapper.placeholder, torch.device("cpu") + return wrapper.placeholder else: - return obj, self.device + return obj def build_callbacks(self): return {} @@ -153,12 +151,9 @@ def _make_data_loader(self, src): def build_training_data_loader(self) -> DataLoader: # Since the model may be used by a worker in a different process, we need to wrap it. - # The device `dev` returned here is the device that the worker will use to interact with the model; - # normally, if the main process has the model on 'cuda', this will simply be 'cpu' (since workers - # don't have CUDA access). # See implementation_nodes.md for more details. - model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) - replay_buffer, _ = self._wrap_for_mp(self.replay_buffer, send_to_device=False) + model = self._wrap_for_mp(self.sampling_model) + replay_buffer = self._wrap_for_mp(self.replay_buffer) if self.cfg.replay.use: # None is fine for either value, it will be replaced by num_from_policy, but 0 is not @@ -184,7 +179,7 @@ def build_training_data_loader(self) -> DataLoader: return self._make_data_loader(src) def build_validation_data_loader(self) -> DataLoader: - model, dev = self._wrap_for_mp(self.model, send_to_device=True) + model = self._wrap_for_mp(self.model) # TODO: we're changing the default, make sure anything that is using test data is adjusted src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) n_drawn = self.cfg.algo.valid_num_from_policy @@ -205,7 +200,7 @@ def build_validation_data_loader(self) -> DataLoader: return self._make_data_loader(src) def build_final_data_loader(self) -> DataLoader: - model, dev = self._wrap_for_mp(self.model, send_to_device=True) + model = self._wrap_for_mp(self.model) n_drawn = self.cfg.algo.num_from_policy src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) diff --git a/src/gflownet/utils/conditioning.py b/src/gflownet/utils/conditioning.py index aece3868..0630be55 100644 --- a/src/gflownet/utils/conditioning.py +++ b/src/gflownet/utils/conditioning.py @@ -12,6 +12,7 @@ from gflownet.config import Config from gflownet.utils import metrics from gflownet.utils.focus_model import TabularFocusModel +from gflownet.utils.misc import get_worker_device from gflownet.utils.transforms import thermometer @@ -142,8 +143,7 @@ def __init__(self, cfg: Config, n_valid: int, rng: np.random.Generator): if focus_type is not None and "learned" in focus_type: if focus_type == "learned-tabular": self.focus_model = TabularFocusModel( - # TODO: proper device propagation - device=torch.device("cpu"), + device=get_worker_device(), n_objectives=cfg.cond.moo.num_objectives, state_space_res=self.cfg.focus_model_state_space_res, ) From 28bcc5946779b89f142b5407aa8230a95aeb52ec Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 09:39:45 -0700 Subject: [PATCH 10/31] little fixes --- src/gflownet/data/data_source.py | 2 +- src/gflownet/tasks/seh_frag.py | 4 ++-- src/gflownet/utils/multiobjective_hooks.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 1f74dc59..1d4f3384 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -170,7 +170,7 @@ def iterator(): # I'm also not a fan of encode_conditional_information, it assumes lots of things about what's passed to # it and the state of the program (e.g. validation mode) cond_info = self.task.encode_conditional_information(torch.stack([data[i] for i in idcs])) - trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info["encoding"], p) + trajs = self.algo.create_training_data_from_own_samples(model, len(idcs), cond_info["encoding"], p) self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index a731a57c..f02f3a44 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -40,11 +40,11 @@ def __init__( ): self._wrap_model = wrap_model self.rng = rng + self.device = get_worker_device() self.models = self._load_task_models() self.dataset = dataset self.temperature_conditional = TemperatureConditional(cfg, rng) self.num_cond_dim = self.temperature_conditional.encoding_size() - self.device = get_worker_device() def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: return FlatRewards(torch.as_tensor(y) / 8) @@ -55,7 +55,7 @@ def inverse_flat_reward_transform(self, rp): def _load_task_models(self): model = bengio2021flow.load_original_model() model.to(self.device) - model = self._wrap_model(model, send_to_device=True) + model = self._wrap_model(model) return {"seh": model} def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py index 115bef3a..359d1f6b 100644 --- a/src/gflownet/utils/multiobjective_hooks.py +++ b/src/gflownet/utils/multiobjective_hooks.py @@ -74,7 +74,7 @@ def _hsri(self, x): def _run_pareto_accumulation(self): num_updates = 0 timeouts = 0 - while not self.stop.is_set() or timeouts < 200: + while not self.stop.is_set() and timeouts < 200: try: r, smi, owid = self.pareto_queue.get(block=True, timeout=1) except queue.Empty: From 4811e7c16eb10cd0557030ba1acdf2bda89beb65 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 12:01:45 -0700 Subject: [PATCH 11/31] a few more stragglers --- src/gflownet/data/data_source.py | 2 +- src/gflownet/tasks/qm9.py | 4 +++- src/gflownet/tasks/seh_frag.py | 5 ++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 1d4f3384..d78a2a7f 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -266,7 +266,7 @@ def compute_log_rewards(self, trajs): def send_to_replay(self, trajs): if self.replay_buffer is not None: for t in trajs: - self.replay_buffer.push(t, t["log_rewards"], t["flat_rewards"], t["cond_info"], t["is_valid"]) + self.replay_buffer.push(t, t["log_reward"], t["flat_rewards"], t["cond_info"], t["is_valid"]) def set_traj_cond_info(self, trajs, cond_info): for i in range(len(trajs)): diff --git a/src/gflownet/tasks/qm9.py b/src/gflownet/tasks/qm9.py index 266ff77b..5f489938 100644 --- a/src/gflownet/tasks/qm9.py +++ b/src/gflownet/tasks/qm9.py @@ -87,7 +87,9 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat def compute_reward_from_graph(self, graphs: List[gd.Data]) -> Tensor: batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) - batch.to(self.device) + batch.to( + self.models["mxmnet_gap"].device if hasattr(self.models["mxmnet_gap"], "device") else get_worker_device() + ) preds = self.models["mxmnet_gap"](batch).reshape((-1,)).data.cpu() / mxmnet.HAR2EV # type: ignore[attr-defined] preds[preds.isnan()] = 1 preds = ( diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index f02f3a44..e64d642d 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -40,7 +40,6 @@ def __init__( ): self._wrap_model = wrap_model self.rng = rng - self.device = get_worker_device() self.models = self._load_task_models() self.dataset = dataset self.temperature_conditional = TemperatureConditional(cfg, rng) @@ -54,7 +53,7 @@ def inverse_flat_reward_transform(self, rp): def _load_task_models(self): model = bengio2021flow.load_original_model() - model.to(self.device) + model.to(get_worker_device()) model = self._wrap_model(model) return {"seh": model} @@ -66,7 +65,7 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat def compute_reward_from_graph(self, graphs: List[Data]) -> Tensor: batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) - batch.to(self.device) + batch.to(self.models["seh"].device if hasattr(self.models["seh"], "device") else get_worker_device()) preds = self.models["seh"](batch).reshape((-1,)).data.cpu() preds[preds.isnan()] = 0 return self.flat_reward_transform(preds).clip(1e-4, 100).reshape((-1,)) From 7d32ac142c64bb7fde180b5adbe7ad7e5ddad6c2 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 23 Feb 2024 15:44:26 -0700 Subject: [PATCH 12/31] proof of concept of using shared pinned buffers --- src/gflownet/algo/graph_sampling.py | 4 +- src/gflownet/config.py | 1 + src/gflownet/data/sampling_iterator.py | 453 ++++++++++++++++++++ src/gflownet/models/graph_transformer.py | 2 + src/gflownet/tasks/seh_frag.py | 1 + src/gflownet/trainer.py | 54 ++- src/gflownet/utils/multiprocessing_proxy.py | 198 ++++++++- 7 files changed, 707 insertions(+), 6 deletions(-) create mode 100644 src/gflownet/data/sampling_iterator.py diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 0db5bcec..3cd253df 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -114,7 +114,9 @@ def not_done(lst): # Forward pass to get GraphActionCategorical # Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does. # TODO: compute bck_cat.log_prob(bck_a) when relevant - fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask]) + batch = self.ctx.collate(torch_graphs) + batch.cond_info = cond_info[not_done_mask] + fwd_cat, *_, log_reward_preds = model(batch.to(dev), None) if random_action_prob > 0: masks = [1] * len(fwd_cat.logits) if fwd_cat.masks is None else fwd_cat.masks # Device which graphs in the minibatch will get their action randomized diff --git a/src/gflownet/config.py b/src/gflownet/config.py index e8238e97..070d68d1 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -101,6 +101,7 @@ class Config: pickle_mp_messages: bool = False git_hash: Optional[str] = None overwrite_existing_exp: bool = False + mp_buffer_size: Optional[int] = None algo: AlgoConfig = field(default_factory=AlgoConfig) model: ModelConfig = field(default_factory=ModelConfig) opt: OptimizerConfig = field(default_factory=OptimizerConfig) diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py new file mode 100644 index 00000000..8daec21b --- /dev/null +++ b/src/gflownet/data/sampling_iterator.py @@ -0,0 +1,453 @@ +import os +import sqlite3 +from collections.abc import Iterable +from copy import deepcopy +from typing import Callable, List + +import numpy as np +import torch +import torch.nn as nn +import torch.multiprocessing as mp +from rdkit import RDLogger +from torch.utils.data import Dataset, IterableDataset + +from gflownet.data.replay_buffer import ReplayBuffer +from gflownet.envs.graph_building_env import GraphActionCategorical +from gflownet.utils.multiprocessing_proxy import put_into_batch_buffer, SharedPinnedBuffer + + +class SamplingIterator(IterableDataset): + """This class allows us to parallelise and train faster. + + By separating sampling data/the model and building torch geometric + graphs from training the model, we can do the former in different + processes, which is much faster since much of graph construction + is CPU-bound. + + """ + + def __init__( + self, + dataset: Dataset, + model: nn.Module, + ctx, + algo, + task, + device, + batch_size: int = 1, + illegal_action_logreward: float = -50, + ratio: float = 0.5, + stream: bool = True, + replay_buffer: ReplayBuffer = None, + log_dir: str = None, + sample_cond_info: bool = True, + random_action_prob: float = 0.0, + hindsight_ratio: float = 0.0, + init_train_iter: int = 0, + buffer_size: int = None, + num_workers: int = 1, + do_multiple_buffers = True, # If True, each worker has its own buffer; doesn't seem to have much impact either way + ): + """Parameters + ---------- + dataset: Dataset + A dataset instance + model: nn.Module + The model we sample from (must be on CUDA already or share_memory() must be called so that + parameters are synchronized between each worker) + ctx: + The context for the environment, e.g. a MolBuildingEnvContext instance + algo: + The training algorithm, e.g. a TrajectoryBalance instance + task: GFNTask + A Task instance, e.g. a MakeRingsTask instance + device: torch.device + The device the model is on + replay_buffer: ReplayBuffer + The replay buffer for training on past data + batch_size: int + The number of trajectories, each trajectory will be comprised of many graphs, so this is + _not_ the batch size in terms of the number of graphs (that will depend on the task) + illegal_action_logreward: float + The logreward for invalid trajectories + ratio: float + The ratio of offline trajectories in the batch. + stream: bool + If True, data is sampled iid for every batch. Otherwise, this is a normal in-order + dataset iterator. + log_dir: str + If not None, logs each SamplingIterator worker's generated molecules to that file. + sample_cond_info: bool + If True (default), then the dataset is a dataset of points used in offline training. + If False, then the dataset is a dataset of preferences (e.g. used to validate the model) + random_action_prob: float + The probability of taking a random action, passed to the graph sampler + init_train_iter: int + The initial training iteration, incremented and passed to task.sample_conditional_information + """ + self.data = dataset + self.model = model + self.replay_buffer = replay_buffer + self.batch_size = batch_size + self.illegal_action_logreward = illegal_action_logreward + self.offline_batch_size = int(np.ceil(self.batch_size * ratio)) + self.online_batch_size = int(np.floor(self.batch_size * (1 - ratio))) + self.ratio = ratio + self.ctx = ctx + self.algo = algo + self.task = task + self.device = device + self.stream = stream + self.sample_online_once = True # TODO: deprecate this, disallow len(data) == 0 entirely + self.sample_cond_info = sample_cond_info + self.random_action_prob = random_action_prob + self.hindsight_ratio = hindsight_ratio + self.train_it = init_train_iter + self.do_validate_batch = False # Turn this on for debugging + + self.result_buffer_size = buffer_size + self.do_multiple_buffers = do_multiple_buffers + if buffer_size and do_multiple_buffers: + self.result_buffer = [SharedPinnedBuffer(buffer_size) for _ in range(num_workers)] + elif buffer_size: + self.result_buffer = SharedPinnedBuffer(buffer_size) + self.round_robin_cond = mp.Condition() + self.round_robin_counter = torch.zeros(1) + self.round_robin_counter.share_memory_() + + # Slightly weird semantics, but if we're sampling x given some fixed cond info (data) + # then "offline" now refers to cond info and online to x, so no duplication and we don't end + # up with 2*batch_size accidentally + if not sample_cond_info: + self.offline_batch_size = self.online_batch_size = self.batch_size + + # This SamplingIterator instance will be copied by torch DataLoaders for each worker, so we + # don't want to initialize per-worker things just yet, such as where the log the worker writes + # to. This must be done in __iter__, which is called by the DataLoader once this instance + # has been copied into a new python process. + self.log_dir = log_dir + self.log = SQLiteLog() + self.log_hooks: List[Callable] = [] + + def add_log_hook(self, hook: Callable): + self.log_hooks.append(hook) + + def _idx_iterator(self): + RDLogger.DisableLog("rdApp.*") + if self.stream: + # If we're streaming data, just sample `offline_batch_size` indices + while True: + yield self.rng.integers(0, len(self.data), self.offline_batch_size) + else: + # Otherwise, figure out which indices correspond to this worker + worker_info = torch.utils.data.get_worker_info() + n = len(self.data) + if n == 0: + yield np.arange(0, 0) + return + assert ( + self.offline_batch_size > 0 + ), "offline_batch_size must be > 0 if not streaming and len(data) > 0 (have you set ratio=0?)" + if worker_info is None: # no multi-processing + start, end, wid = 0, n, -1 + else: # split the data into chunks (per-worker) + nw = worker_info.num_workers + wid = worker_info.id + start, end = int(np.round(n / nw * wid)), int(np.round(n / nw * (wid + 1))) + bs = self.offline_batch_size + if end - start <= bs: + yield np.arange(start, end) + return + for i in range(start, end - bs, bs): + yield np.arange(i, i + bs) + if i + bs < end: + yield np.arange(i + bs, end) + + def __len__(self): + if self.stream: + return int(1e6) + if len(self.data) == 0 and self.sample_online_once: + return 1 + return len(self.data) + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + self._wid = worker_info.id if worker_info is not None else 0 + # Now that we know we are in a worker instance, we can initialize per-worker things + self.rng = self.algo.rng = self.task.rng = np.random.default_rng(142857 + self._wid) + self.ctx.device = self.device + if self.log_dir is not None: + os.makedirs(self.log_dir, exist_ok=True) + self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" + self.log.connect(self.log_path) + + for idcs in self._idx_iterator(): + num_offline = idcs.shape[0] # This is in [0, self.offline_batch_size] + # Sample conditional info such as temperature, trade-off weights, etc. + + if self.sample_cond_info: + num_online = self.online_batch_size + cond_info = self.task.sample_conditional_information( + num_offline + self.online_batch_size, self.train_it + ) + + # Sample some dataset data + graphs, flat_rewards = map(list, zip(*[self.data[i] for i in idcs])) if len(idcs) else ([], []) + flat_rewards = ( + list(self.task.flat_reward_transform(torch.stack(flat_rewards))) if len(flat_rewards) else [] + ) + trajs = self.algo.create_training_data_from_graphs( + graphs, self.model, cond_info["encoding"][:num_offline], 0 + ) + + else: # If we're not sampling the conditionals, then the idcs refer to listed preferences + num_online = num_offline + num_offline = 0 + cond_info = self.task.encode_conditional_information( + steer_info=torch.stack([self.data[i] for i in idcs]) + ) + trajs, flat_rewards = [], [] + + # Sample some on-policy data + is_valid = torch.ones(num_offline + num_online).bool() + if num_online > 0: + with torch.no_grad(): + trajs += self.algo.create_training_data_from_own_samples( + self.model, + num_online, + cond_info["encoding"][num_offline:], + random_action_prob=self.random_action_prob, + ) + if self.algo.bootstrap_own_reward: + # The model can be trained to predict its own reward, + # i.e. predict the output of cond_info_to_logreward + pred_reward = [i["reward_pred"].cpu().item() for i in trajs[num_offline:]] + flat_rewards += pred_reward + else: + # Otherwise, query the task for flat rewards + valid_idcs = torch.tensor( + [i + num_offline for i in range(num_online) if trajs[i + num_offline]["is_valid"]] + ).long() + # fetch the valid trajectories endpoints + mols = [self.ctx.graph_to_mol(trajs[i]["result"]) for i in valid_idcs] + # ask the task to compute their reward + online_flat_rew, m_is_valid = self.task.compute_flat_rewards(mols) + assert ( + online_flat_rew.ndim == 2 + ), "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" + # The task may decide some of the mols are invalid, we have to again filter those + valid_idcs = valid_idcs[m_is_valid] + pred_reward = torch.zeros((num_online, online_flat_rew.shape[1])) + pred_reward[valid_idcs - num_offline] = online_flat_rew + is_valid[num_offline:] = False + is_valid[valid_idcs] = True + flat_rewards += list(pred_reward) + # Override the is_valid key in case the task made some mols invalid + for i in range(num_online): + trajs[num_offline + i]["is_valid"] = is_valid[num_offline + i].item() + + # Compute scalar rewards from conditional information & flat rewards + flat_rewards = torch.stack(flat_rewards) + log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) + log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward + + # Computes some metrics + extra_info = {} + if not self.sample_cond_info: + # If we're using a dataset of preferences, the user may want to know the id of the preference + for i, j in zip(trajs, idcs): + i["data_idx"] = j + # note: we convert back into natural rewards for logging purposes + # (allows to take averages and plot in objective space) + # TODO: implement that per-task (in case they don't apply the same beta and log transformations) + rewards = torch.exp(log_rewards / cond_info["beta"]) + if num_online > 0 and self.log_dir is not None: + self.log_generated( + deepcopy(trajs[num_offline:]), + deepcopy(rewards[num_offline:]), + deepcopy(flat_rewards[num_offline:]), + {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, + ) + if num_online > 0: + extra_info["sampled_reward_avg"] = rewards[num_offline:].mean().item() + for hook in self.log_hooks: + extra_info.update( + hook( + deepcopy(trajs[num_offline:]), + deepcopy(rewards[num_offline:]), + deepcopy(flat_rewards[num_offline:]), + {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, + ) + ) + + if self.replay_buffer is not None: + # If we have a replay buffer, we push the online trajectories in it + # and resample immediately such that the "online" data in the batch + # comes from a more stable distribution (try to avoid forgetting) + + # cond_info is a dict, so we need to convert it to a list of dicts + cond_info = [{k: v[i] for k, v in cond_info.items()} for i in range(num_offline + num_online)] + + # push the online trajectories in the replay buffer and sample a new 'online' batch + for i in range(num_offline, len(trajs)): + self.replay_buffer.push( + deepcopy(trajs[i]), + deepcopy(log_rewards[i]), + deepcopy(flat_rewards[i]), + deepcopy(cond_info[i]), + deepcopy(is_valid[i]), + ) + replay_trajs, replay_logr, replay_fr, replay_condinfo, replay_valid = self.replay_buffer.sample( + num_online + ) + + # append the online trajectories to the offline ones + trajs[num_offline:] = replay_trajs + log_rewards[num_offline:] = replay_logr + flat_rewards[num_offline:] = replay_fr + cond_info[num_offline:] = replay_condinfo + is_valid[num_offline:] = replay_valid + + # convert cond_info back to a dict + cond_info = {k: torch.stack([d[k] for d in cond_info]) for k in cond_info[0]} + + if self.hindsight_ratio > 0.0: + # Relabels some of the online trajectories with hindsight + assert hasattr( + self.task, "relabel_condinfo_and_logrewards" + ), "Hindsight requires the task to implement relabel_condinfo_and_logrewards" + # samples indexes of trajectories without repeats + hindsight_idxs = torch.randperm(num_online)[: int(num_online * self.hindsight_ratio)] + num_offline + cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( + cond_info, log_rewards, flat_rewards, hindsight_idxs + ) + log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward + + # Construct batch + batch = self.algo.construct_batch(trajs, cond_info["encoding"], log_rewards) + batch.num_offline = num_offline + batch.num_online = num_online + batch.flat_rewards = flat_rewards + batch.preferences = cond_info.get("preferences", None) + batch.focus_dir = cond_info.get("focus_dir", None) + batch.extra_info = extra_info + # TODO: we could very well just pass the cond_info dict to construct_batch above, + # and the algo can decide what it wants to put in the batch object + + # Only activate for debugging your environment or dataset (e.g. the dataset could be + # generating trajectories with illegal actions) + if self.do_validate_batch: + self.validate_batch(batch, trajs) + + self.train_it += worker_info.num_workers if worker_info is not None else 1 + if self.result_buffer_size and not self.do_multiple_buffers: + with self.round_robin_cond: + self.round_robin_cond.wait_for(lambda: self.round_robin_counter[0] == self._wid) + self.result_buffer.lock.acquire() + yield put_into_batch_buffer(batch, self.result_buffer.buffer) + self.round_robin_counter[0] = (self._wid + 1) % worker_info.num_workers + with self.round_robin_cond: + self.round_robin_cond.notify_all() + elif self.result_buffer_size: + self.result_buffer[self._wid].lock.acquire() + desc = put_into_batch_buffer(batch, self.result_buffer[self._wid].buffer) + desc.wid = self._wid + yield desc + else: + yield batch + + def validate_batch(self, batch, trajs): + for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( + [(batch.bck_actions, self.ctx.bck_action_type_order)] + if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order") + else [] + ): + mask_cat = GraphActionCategorical( + batch, + [self.model._action_type_to_mask(t, batch) for t in atypes], + [self.model._action_type_to_key[t] for t in atypes], + [None for _ in atypes], + ) + masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits) + num_trajs = len(trajs) + batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens) + first_graph_idx = torch.zeros_like(batch.traj_lens) + torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) + if masked_action_is_used.sum() != 0: + invalid_idx = masked_action_is_used.argmax().item() + traj_idx = batch_idx[invalid_idx].item() + timestep = invalid_idx - first_graph_idx[traj_idx].item() + raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep]) + + def log_generated(self, trajs, rewards, flat_rewards, cond_info): + if hasattr(self.ctx, "object_to_log_repr"): + mols = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] + else: + mols = [""] * len(trajs) + + flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() + rewards = rewards.data.numpy().tolist() + preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() + focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() + logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] + + data = [ + [mols[i], rewards[i]] + + flat_rewards[i] + + preferences[i] + + focus_dir[i] + + [cond_info[k][i].item() for k in logged_keys] + for i in range(len(trajs)) + ] + + data_labels = ( + ["smi", "r"] + + [f"fr_{i}" for i in range(len(flat_rewards[0]))] + + [f"pref_{i}" for i in range(len(preferences[0]))] + + [f"focus_{i}" for i in range(len(focus_dir[0]))] + + [f"ci_{k}" for k in logged_keys] + ) + + self.log.insert_many(data, data_labels) + + +class SQLiteLog: + def __init__(self, timeout=300): + """Creates a log instance, but does not connect it to any db.""" + self.is_connected = False + self.db = None + self.timeout = timeout + + def connect(self, db_path: str): + """Connects to db_path + + Parameters + ---------- + db_path: str + The sqlite3 database path. If it does not exist, it will be created. + """ + self.db = sqlite3.connect(db_path, timeout=self.timeout) + cur = self.db.cursor() + self._has_results_table = len( + cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='results'").fetchall() + ) + cur.close() + + def _make_results_table(self, types, names): + type_map = {str: "text", float: "real", int: "real"} + col_str = ", ".join(f"{name} {type_map[t]}" for t, name in zip(types, names)) + cur = self.db.cursor() + cur.execute(f"create table results ({col_str})") + self._has_results_table = True + cur.close() + + def insert_many(self, rows, column_names): + assert all( + [isinstance(x, str) or not isinstance(x, Iterable) for x in rows[0]] + ), "rows must only contain scalars" + if not self._has_results_table: + self._make_results_table([type(i) for i in rows[0]], column_names) + cur = self.db.cursor() + cur.executemany(f'insert into results values ({",".join("?"*len(rows[0]))})', rows) # nosec + cur.close() + self.db.commit() diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 8c3993f0..6dae55b7 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -245,6 +245,8 @@ def _make_cat(self, g, emb, action_types): ) def forward(self, g: gd.Batch, cond: torch.Tensor): + if cond is None: + cond = g.cond_info node_embeddings, graph_embeddings = self.transf(g, cond) # "Non-edges" are edges not currently in the graph that we could add if hasattr(g, "non_edge_index"): diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index e64d642d..18db5a2c 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -216,6 +216,7 @@ def main(): config.algo.sampling_tau = 0.99 config.cond.temperature.sample_dist = "uniform" config.cond.temperature.dist_params = [0, 64.0] + config.mp_buffer_size = None # 32 * 1024 ** 2, trial = SEHFragTrainer(config) trial.run() diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index c8be7312..12899958 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -22,8 +22,8 @@ from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import create_logger, set_main_process_device -from gflownet.utils.multiprocessing_proxy import mp_object_wrapper from gflownet.utils.sqlite_log import SQLiteLogHook +from gflownet.utils.multiprocessing_proxy import mp_object_wrapper, resolve_batch_buffer, BatchDescriptor from .config import Config @@ -131,6 +131,7 @@ def _wrap_for_mp(self, obj): self.cfg.num_workers, cast_types=(gd.Batch, GraphActionCategorical, SeqBatch), pickle_messages=self.cfg.pickle_mp_messages, + bb_size=self.cfg.mp_buffer_size, ) self.to_terminate.append(wrapper.terminate) return wrapper.placeholder @@ -140,7 +141,32 @@ def _wrap_for_mp(self, obj): def build_callbacks(self): return {} +<<<<<<< HEAD def _make_data_loader(self, src): +======= + def build_training_data_loader(self) -> DataLoader: + model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) + replay_buffer, _ = self._wrap_for_mp(self.replay_buffer, send_to_device=False) + iterator = SamplingIterator( + self.training_data, + model, + self.ctx, + self.algo, + self.task, + dev, + batch_size=self.cfg.algo.global_batch_size, + illegal_action_logreward=self.cfg.algo.illegal_action_logreward, + replay_buffer=replay_buffer, + ratio=self.cfg.algo.offline_ratio, + log_dir=str(pathlib.Path(self.cfg.log_dir) / "train"), + random_action_prob=self.cfg.algo.train_random_action_prob, + hindsight_ratio=self.cfg.replay.hindsight_ratio, + buffer_size=self.cfg.mp_buffer_size, + num_workers=self.cfg.num_workers, + ) + for hook in self.sampling_hooks: + iterator.add_log_hook(hook) +>>>>>>> proof of concept of using shared pinned buffers return torch.utils.data.DataLoader( src, batch_size=None, @@ -266,6 +292,7 @@ def run(self, logger=None): start = self.cfg.start_at_step + 1 num_training_steps = self.cfg.num_training_steps logger.info("Starting training") +<<<<<<< HEAD start_time = time.time() for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)): # the memory fragmentation or allocation keeps growing, how often should we clean up? @@ -274,6 +301,27 @@ def run(self, logger=None): if it % 1024 == 0: gc.collect() torch.cuda.empty_cache() +======= + import time + t0 = time.time() + times = [] + for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)): + if isinstance(batch, BatchDescriptor): + print(f"buffer size was {batch.size / 1024**2:.2f}M") + if train_dl.dataset.do_multiple_buffers: + wid = batch.wid + batch = resolve_batch_buffer(batch, train_dl.dataset.result_buffer[wid].buffer, self.device) + train_dl.dataset.result_buffer[wid].lock.release() + else: + batch = resolve_batch_buffer(batch, train_dl.dataset.result_buffer.buffer, self.device) + train_dl.dataset.result_buffer.lock.release() + else: + batch = batch.to(self.device) + t1 = time.time() + times.append(t1 - t0) + print(f"iteration {it} : {t1 - t0:.2f} s, average: {np.mean(times):.2f} s") + t0 = t1 +>>>>>>> proof of concept of using shared pinned buffers epoch_idx = it // epoch_length batch_idx = it % epoch_length if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup: @@ -281,9 +329,13 @@ def run(self, logger=None): f"iteration {it} : warming up replay buffer {len(self.replay_buffer)}/{self.replay_buffer.warmup}" ) continue +<<<<<<< HEAD info = self.train_batch(batch.to(self.device), epoch_idx, batch_idx, it) info["time_spent"] = time.time() - start_time start_time = time.time() +======= + info = self.train_batch(batch, epoch_idx, batch_idx, it) +>>>>>>> proof of concept of using shared pinned buffers self.log(info, it, "train") if it % self.print_every == 0: logger.info(f"iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index df13b565..618d98f2 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -2,9 +2,169 @@ import queue import threading import traceback +from itertools import chain +import numpy as np import torch import torch.multiprocessing as mp +from torch_geometric.data import Batch + +from gflownet.envs.graph_building_env import GraphActionCategorical + + +class SharedPinnedBuffer: + def __init__(self, size): + self.size = size + self.buffer = torch.empty(size, dtype=torch.uint8) + self.buffer.share_memory_() + self.lock = mp.Lock() + + cudart = torch.cuda.cudart() + r = cudart.cudaHostRegister(self.buffer.data_ptr(), self.buffer.numel() * self.buffer.element_size(), 0) + assert r == 0 + assert self.buffer.is_shared() + assert self.buffer.is_pinned() + + +class BatchDescriptor: + def __init__(self, names, types, shapes, size, other): + self.names = names + self.types = types + self.shapes = shapes + self.size = size + self.other = other + + +class ResultDescriptor: + def __init__(self, names, types, shapes, size, gac_attrs): + self.names = names + self.types = types + self.shapes = shapes + self.size = size + self.gac_attrs = gac_attrs + + +def prod(l): + p = 1 + for i in l: + p *= i + return p + + +def put_into_batch_buffer(batch, buffer): + names = [] + types = [] + shapes = [] + offset = 0 + others = {} + for k, v in chain(batch._store.items(), (("_slice_dict_" + k, v) for k, v in batch._slice_dict.items())): + if not isinstance(v, torch.Tensor): + try: + v = torch.as_tensor(v) + except Exception as e: + others[k] = v + continue + names.append(k) + types.append(v.dtype) + shapes.append(tuple(v.shape)) + numel = v.numel() * v.element_size() + # print('putting', k, v.shape, numel, offset) + buffer[offset : offset + numel] = v.view(-1).view(torch.uint8) + offset += numel + offset += (8 - offset % 8) % 8 # align to 8 bytes + if offset > buffer.shape[0]: + raise ValueError(f"Offset {offset} exceeds buffer size {buffer.shape[0]}") + # print(f'total size: {offset / 1024**2:.3f}M') + # print(batch.batch) + return BatchDescriptor(names, types, shapes, offset, others) + + +def resolve_batch_buffer(descriptor, buffer, device): + offset = 0 + batch = Batch() + batch._slice_dict = {} + cuda_buffer = buffer[: descriptor.size].to(device) # TODO: check if only sending `size` is faster? + for name, dtype, shape in zip(descriptor.names, descriptor.types, descriptor.shapes): + numel = prod(shape) * dtype.itemsize + # print('restoring', name, shape, numel, offset) + if name.startswith("_slice_dict_"): + batch._slice_dict[name[12:]] = cuda_buffer[offset : offset + numel].view(dtype).view(shape) + else: + setattr(batch, name, cuda_buffer[offset : offset + numel].view(dtype).view(shape)) + offset += numel + offset += (8 - offset % 8) % 8 # align to 8 bytes + # print(batch.batch) + # print(f'total size: {offset / 1024**2:.3f}M') + for k, v in descriptor.other.items(): + setattr(batch, k, v) + return batch + + +def put_into_result_buffer(result, buffer): + gac_names = ["logits", "batch", "slice", "masks"] + gac, tensor = result + buffer[: tensor.numel() * tensor.element_size()] = tensor.view(-1).view(torch.uint8) + offset = tensor.numel() * tensor.element_size() + offset += (8 - offset % 8) % 8 # align to 8 bytes + names = ["@per_graph_out"] + types = [tensor.dtype] + shapes = [tensor.shape] + for name in gac_names: + tensors = getattr(gac, name) + for i, x in enumerate(tensors): + # print(f"putting {name}@{i} with shape {x.shape}") + numel = x.numel() * x.element_size() + if numel > 0: + # We need this for a funny reason + # torch.zeros(0)[::2] has a stride of (2,), and is contiguous according to torch + # so, flattening it and then reshaping it will not change the stride, which will + # make view(uint8) complain that the strides are not compatible. + # The batch[::2] happens when creating the categorical and deduplicate_edge_index is True + buffer[offset : offset + numel] = x.flatten().view(torch.uint8) + offset += numel + offset += (8 - offset % 8) % 8 # align to 8 bytes + if offset > buffer.shape[0]: + raise ValueError(f"Offset {offset} exceeds buffer size {buffer.shape[0]}") + names.append(f"{name}@{i}") + types.append(x.dtype) + shapes.append(tuple(x.shape)) + return ResultDescriptor(names, types, shapes, offset, (gac.num_graphs, gac.keys, gac.types)) + + +def resolve_result_buffer(descriptor, buffer, device): + # TODO: models can return multiple GraphActionCategoricals, but we only support one for now + # Would be nice to have something generic (and recursive?) + offset = 0 + tensor = buffer[: descriptor.size].to(device) + if tensor.device == device: # CPU to CPU + # I think we need this? Otherwise when we release the lock, the memory might be overwritten + tensor = tensor.clone() + # Maybe make this a static method, or just overload __new__? + gac = GraphActionCategorical.__new__(GraphActionCategorical) + gac.num_graphs, gac.keys, gac.types = descriptor.gac_attrs + gac.dev = device + gac.logprobs = None + gac._epsilon = 1e-38 + + gac_names = ["logits", "batch", "slice", "masks"] + for i in gac_names: + setattr(gac, i, [None] * len(gac.types)) + + for name, dtype, shape in zip(descriptor.names, descriptor.types, descriptor.shapes): + numel = prod(shape) * dtype.itemsize + if name == "@per_graph_out": + per_graph_out = tensor[offset : offset + numel].view(dtype).view(shape) + else: + name, index = name.split("@") + index = int(index) + if name in gac_names: + getattr(gac, name)[index] = tensor[offset : offset + numel].view(dtype).view(shape) + else: + raise ValueError(f"Unknown result descriptor name: {name}") + offset += numel + offset += (8 - offset % 8) % 8 # align to 8 bytes + # print(f"restored {name} with shape {shape}") + return gac, per_graph_out class MPObjectPlaceholder: @@ -12,11 +172,15 @@ class MPObjectPlaceholder: in a worker process, and translates calls to the object-placeholder into queries for the main process to execute on the real object.""" - def __init__(self, in_queues, out_queues, pickle_messages=False): + def __init__(self, in_queues, out_queues, pickle_messages=False, batch_buffer_size=None): self.qs = in_queues, out_queues self.device = torch.device("cpu") self.pickle_messages = pickle_messages self._is_init = False + self.batch_buffer_size = batch_buffer_size + if batch_buffer_size is not None: + self._batch_buffer = SharedPinnedBuffer(batch_buffer_size) + self._result_buffer = SharedPinnedBuffer(batch_buffer_size) def _check_init(self): if self._is_init: @@ -41,6 +205,9 @@ def decode(self, m): if isinstance(m, Exception): print("Received exception from main process, reraising.") raise m + if isinstance(m, ResultDescriptor): + m = resolve_result_buffer(m, self._result_buffer.buffer, self.device) + self._result_buffer.lock.release() return m def __getattr__(self, name): @@ -53,6 +220,11 @@ def method_wrapper(*a, **kw): def __call__(self, *a, **kw): self._check_init() + if self.batch_buffer_size and len(a) and isinstance(a[0], Batch): + # The lock will be released by the consumer of this buffer once the memory has been transferred to CUDA + self._batch_buffer.lock.acquire() + batch_descriptor = put_into_batch_buffer(a[0], self._batch_buffer.buffer) + a = (batch_descriptor,) + a[1:] self.in_queue.put(self.encode(("__call__", a, kw))) return self.decode(self.out_queue.get()) @@ -75,7 +247,7 @@ class MPObjectProxy: Always passes CPU tensors between processes. """ - def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False): + def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False, bb_size=None): """Construct a multiprocessing object proxy. Parameters @@ -91,11 +263,13 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo If True, pickle messages sent between processes. This reduces load on shared memory, but increases load on CPU. It is recommended to activate this flag if encountering "Too many open files"-type errors. + bb_size: Optional[int] + batch buffer size """ self.in_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.out_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.pickle_messages = pickle_messages - self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages) + self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages, bb_size) self.obj = obj if hasattr(obj, "parameters"): self.device = next(obj.parameters()).device @@ -109,6 +283,15 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo def encode(self, m): if self.pickle_messages: return pickle.dumps(m) + if ( + self.placeholder.batch_buffer_size + and isinstance(m, (list, tuple)) + and len(m) == 2 + and isinstance(m[0], GraphActionCategorical) + and isinstance(m[1], torch.Tensor) + ): + self.placeholder._result_buffer.lock.acquire() + return put_into_result_buffer(m, self.placeholder._result_buffer.buffer) return m def decode(self, m): @@ -133,6 +316,12 @@ def run(self): break timeouts = 0 attr, args, kwargs = r + if self.placeholder.batch_buffer_size and len(args) and isinstance(args[0], BatchDescriptor): + batch = resolve_batch_buffer(args[0], self.placeholder._batch_buffer.buffer, self.device) + args = (batch,) + args[1:] + # Should this release happen after the call to f()? Are we at risk of overwriting memory that + # is still being used by CUDA? + self.placeholder._batch_buffer.lock.release() f = getattr(self.obj, attr) args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args] kwargs = {k: i.to(self.device) if isinstance(i, self.cuda_types) else i for k, i in kwargs.items()} @@ -143,6 +332,7 @@ def run(self): except Exception as e: result = e exc_str = traceback.format_exc() + print(exc_str) try: pickle.dumps(e) except Exception: @@ -159,7 +349,7 @@ def terminate(self): self.stop.set() -def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False): +def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False, bb_size=None): """Construct a multiprocessing object proxy for torch DataLoaders so that it does not need to be copied in every worker's memory. For example, this can be used to wrap a model such that only the main process makes From d4a2a7ddf512895d8c0b2e5bd67fd530395c283f Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 23 Feb 2024 15:52:41 -0700 Subject: [PATCH 13/31] 32mb buffer --- src/gflownet/tasks/seh_frag.py | 2 +- src/gflownet/trainer.py | 40 +++------------------------------- 2 files changed, 4 insertions(+), 38 deletions(-) diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 18db5a2c..41e67df4 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -216,7 +216,7 @@ def main(): config.algo.sampling_tau = 0.99 config.cond.temperature.sample_dist = "uniform" config.cond.temperature.dist_params = [0, 64.0] - config.mp_buffer_size = None # 32 * 1024 ** 2, + config.mp_buffer_size = 32 * 1024 ** 2 trial = SEHFragTrainer(config) trial.run() diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 12899958..0d758c78 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -141,32 +141,7 @@ def _wrap_for_mp(self, obj): def build_callbacks(self): return {} -<<<<<<< HEAD def _make_data_loader(self, src): -======= - def build_training_data_loader(self) -> DataLoader: - model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) - replay_buffer, _ = self._wrap_for_mp(self.replay_buffer, send_to_device=False) - iterator = SamplingIterator( - self.training_data, - model, - self.ctx, - self.algo, - self.task, - dev, - batch_size=self.cfg.algo.global_batch_size, - illegal_action_logreward=self.cfg.algo.illegal_action_logreward, - replay_buffer=replay_buffer, - ratio=self.cfg.algo.offline_ratio, - log_dir=str(pathlib.Path(self.cfg.log_dir) / "train"), - random_action_prob=self.cfg.algo.train_random_action_prob, - hindsight_ratio=self.cfg.replay.hindsight_ratio, - buffer_size=self.cfg.mp_buffer_size, - num_workers=self.cfg.num_workers, - ) - for hook in self.sampling_hooks: - iterator.add_log_hook(hook) ->>>>>>> proof of concept of using shared pinned buffers return torch.utils.data.DataLoader( src, batch_size=None, @@ -292,8 +267,9 @@ def run(self, logger=None): start = self.cfg.start_at_step + 1 num_training_steps = self.cfg.num_training_steps logger.info("Starting training") -<<<<<<< HEAD start_time = time.time() + t0 = time.time() + times = [] for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)): # the memory fragmentation or allocation keeps growing, how often should we clean up? # is changing the allocation strategy helpful? @@ -301,11 +277,6 @@ def run(self, logger=None): if it % 1024 == 0: gc.collect() torch.cuda.empty_cache() -======= - import time - t0 = time.time() - times = [] - for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)): if isinstance(batch, BatchDescriptor): print(f"buffer size was {batch.size / 1024**2:.2f}M") if train_dl.dataset.do_multiple_buffers: @@ -321,7 +292,6 @@ def run(self, logger=None): times.append(t1 - t0) print(f"iteration {it} : {t1 - t0:.2f} s, average: {np.mean(times):.2f} s") t0 = t1 ->>>>>>> proof of concept of using shared pinned buffers epoch_idx = it // epoch_length batch_idx = it % epoch_length if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup: @@ -329,13 +299,9 @@ def run(self, logger=None): f"iteration {it} : warming up replay buffer {len(self.replay_buffer)}/{self.replay_buffer.warmup}" ) continue -<<<<<<< HEAD - info = self.train_batch(batch.to(self.device), epoch_idx, batch_idx, it) + info = self.train_batch(batch, epoch_idx, batch_idx, it) info["time_spent"] = time.time() - start_time start_time = time.time() -======= - info = self.train_batch(batch, epoch_idx, batch_idx, it) ->>>>>>> proof of concept of using shared pinned buffers self.log(info, it, "train") if it % self.print_every == 0: logger.info(f"iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) From 27dfc23a52c0dfb204669b493860e624e1470d2b Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 13:07:21 -0700 Subject: [PATCH 14/31] add to DataSource --- src/gflownet/data/data_source.py | 19 +- src/gflownet/data/sampling_iterator.py | 453 -------------------- src/gflownet/models/seq_transformer.py | 3 + src/gflownet/trainer.py | 1 + src/gflownet/utils/multiprocessing_proxy.py | 15 +- 5 files changed, 27 insertions(+), 464 deletions(-) delete mode 100644 src/gflownet/data/sampling_iterator.py diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index d78a2a7f..caba7399 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -3,6 +3,7 @@ import numpy as np import torch +import torch.multiprocessing as mp from torch.utils.data import IterableDataset from gflownet import GFNAlgorithm, GFNTask @@ -10,6 +11,7 @@ from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.graph_building_env import GraphBuildingEnvContext from gflownet.utils.misc import get_worker_rng +from gflownet.utils.multiprocessing_proxy import SharedPinnedBuffer, put_into_batch_buffer def cycle_call(it): @@ -44,6 +46,7 @@ def __init__( self.global_step_count.share_memory_() self.global_step_count_lock = torch.multiprocessing.Lock() self.current_iter = start_at_step + self.setup_mp_buffers() def add_sampling_hook(self, hook: Callable): """Add a hook that is called when sampling new trajectories. @@ -230,7 +233,7 @@ def create_batch(self, trajs, batch_info): batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) batch.flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) - return batch + return self._maybe_put_in_mp_buffer(batch) def compute_properties(self, trajs, mark_as_online=False): """Sets trajs' flat_rewards and is_valid keys by querying the task.""" @@ -318,3 +321,17 @@ def iterate_indices(self, n, num_samples): yield np.arange(i, i + num_samples) if i + num_samples < end: yield np.arange(i + num_samples, end) + + def setup_mp_buffers(self): + self.result_buffer_size = self.cfg.mp_buffer_size + if self.result_buffer_size: + self.result_buffer = [SharedPinnedBuffer(self.result_buffer_size) for _ in range(self.cfg.num_workers)] + + def _maybe_put_in_mp_buffer(self, batch): + if self.result_buffer_size: + self.result_buffer[self._wid].lock.acquire() + desc = put_into_batch_buffer(batch, self.result_buffer[self._wid].buffer) + desc.wid = self._wid + return desc + else: + return batch diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py deleted file mode 100644 index 8daec21b..00000000 --- a/src/gflownet/data/sampling_iterator.py +++ /dev/null @@ -1,453 +0,0 @@ -import os -import sqlite3 -from collections.abc import Iterable -from copy import deepcopy -from typing import Callable, List - -import numpy as np -import torch -import torch.nn as nn -import torch.multiprocessing as mp -from rdkit import RDLogger -from torch.utils.data import Dataset, IterableDataset - -from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.envs.graph_building_env import GraphActionCategorical -from gflownet.utils.multiprocessing_proxy import put_into_batch_buffer, SharedPinnedBuffer - - -class SamplingIterator(IterableDataset): - """This class allows us to parallelise and train faster. - - By separating sampling data/the model and building torch geometric - graphs from training the model, we can do the former in different - processes, which is much faster since much of graph construction - is CPU-bound. - - """ - - def __init__( - self, - dataset: Dataset, - model: nn.Module, - ctx, - algo, - task, - device, - batch_size: int = 1, - illegal_action_logreward: float = -50, - ratio: float = 0.5, - stream: bool = True, - replay_buffer: ReplayBuffer = None, - log_dir: str = None, - sample_cond_info: bool = True, - random_action_prob: float = 0.0, - hindsight_ratio: float = 0.0, - init_train_iter: int = 0, - buffer_size: int = None, - num_workers: int = 1, - do_multiple_buffers = True, # If True, each worker has its own buffer; doesn't seem to have much impact either way - ): - """Parameters - ---------- - dataset: Dataset - A dataset instance - model: nn.Module - The model we sample from (must be on CUDA already or share_memory() must be called so that - parameters are synchronized between each worker) - ctx: - The context for the environment, e.g. a MolBuildingEnvContext instance - algo: - The training algorithm, e.g. a TrajectoryBalance instance - task: GFNTask - A Task instance, e.g. a MakeRingsTask instance - device: torch.device - The device the model is on - replay_buffer: ReplayBuffer - The replay buffer for training on past data - batch_size: int - The number of trajectories, each trajectory will be comprised of many graphs, so this is - _not_ the batch size in terms of the number of graphs (that will depend on the task) - illegal_action_logreward: float - The logreward for invalid trajectories - ratio: float - The ratio of offline trajectories in the batch. - stream: bool - If True, data is sampled iid for every batch. Otherwise, this is a normal in-order - dataset iterator. - log_dir: str - If not None, logs each SamplingIterator worker's generated molecules to that file. - sample_cond_info: bool - If True (default), then the dataset is a dataset of points used in offline training. - If False, then the dataset is a dataset of preferences (e.g. used to validate the model) - random_action_prob: float - The probability of taking a random action, passed to the graph sampler - init_train_iter: int - The initial training iteration, incremented and passed to task.sample_conditional_information - """ - self.data = dataset - self.model = model - self.replay_buffer = replay_buffer - self.batch_size = batch_size - self.illegal_action_logreward = illegal_action_logreward - self.offline_batch_size = int(np.ceil(self.batch_size * ratio)) - self.online_batch_size = int(np.floor(self.batch_size * (1 - ratio))) - self.ratio = ratio - self.ctx = ctx - self.algo = algo - self.task = task - self.device = device - self.stream = stream - self.sample_online_once = True # TODO: deprecate this, disallow len(data) == 0 entirely - self.sample_cond_info = sample_cond_info - self.random_action_prob = random_action_prob - self.hindsight_ratio = hindsight_ratio - self.train_it = init_train_iter - self.do_validate_batch = False # Turn this on for debugging - - self.result_buffer_size = buffer_size - self.do_multiple_buffers = do_multiple_buffers - if buffer_size and do_multiple_buffers: - self.result_buffer = [SharedPinnedBuffer(buffer_size) for _ in range(num_workers)] - elif buffer_size: - self.result_buffer = SharedPinnedBuffer(buffer_size) - self.round_robin_cond = mp.Condition() - self.round_robin_counter = torch.zeros(1) - self.round_robin_counter.share_memory_() - - # Slightly weird semantics, but if we're sampling x given some fixed cond info (data) - # then "offline" now refers to cond info and online to x, so no duplication and we don't end - # up with 2*batch_size accidentally - if not sample_cond_info: - self.offline_batch_size = self.online_batch_size = self.batch_size - - # This SamplingIterator instance will be copied by torch DataLoaders for each worker, so we - # don't want to initialize per-worker things just yet, such as where the log the worker writes - # to. This must be done in __iter__, which is called by the DataLoader once this instance - # has been copied into a new python process. - self.log_dir = log_dir - self.log = SQLiteLog() - self.log_hooks: List[Callable] = [] - - def add_log_hook(self, hook: Callable): - self.log_hooks.append(hook) - - def _idx_iterator(self): - RDLogger.DisableLog("rdApp.*") - if self.stream: - # If we're streaming data, just sample `offline_batch_size` indices - while True: - yield self.rng.integers(0, len(self.data), self.offline_batch_size) - else: - # Otherwise, figure out which indices correspond to this worker - worker_info = torch.utils.data.get_worker_info() - n = len(self.data) - if n == 0: - yield np.arange(0, 0) - return - assert ( - self.offline_batch_size > 0 - ), "offline_batch_size must be > 0 if not streaming and len(data) > 0 (have you set ratio=0?)" - if worker_info is None: # no multi-processing - start, end, wid = 0, n, -1 - else: # split the data into chunks (per-worker) - nw = worker_info.num_workers - wid = worker_info.id - start, end = int(np.round(n / nw * wid)), int(np.round(n / nw * (wid + 1))) - bs = self.offline_batch_size - if end - start <= bs: - yield np.arange(start, end) - return - for i in range(start, end - bs, bs): - yield np.arange(i, i + bs) - if i + bs < end: - yield np.arange(i + bs, end) - - def __len__(self): - if self.stream: - return int(1e6) - if len(self.data) == 0 and self.sample_online_once: - return 1 - return len(self.data) - - def __iter__(self): - worker_info = torch.utils.data.get_worker_info() - self._wid = worker_info.id if worker_info is not None else 0 - # Now that we know we are in a worker instance, we can initialize per-worker things - self.rng = self.algo.rng = self.task.rng = np.random.default_rng(142857 + self._wid) - self.ctx.device = self.device - if self.log_dir is not None: - os.makedirs(self.log_dir, exist_ok=True) - self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" - self.log.connect(self.log_path) - - for idcs in self._idx_iterator(): - num_offline = idcs.shape[0] # This is in [0, self.offline_batch_size] - # Sample conditional info such as temperature, trade-off weights, etc. - - if self.sample_cond_info: - num_online = self.online_batch_size - cond_info = self.task.sample_conditional_information( - num_offline + self.online_batch_size, self.train_it - ) - - # Sample some dataset data - graphs, flat_rewards = map(list, zip(*[self.data[i] for i in idcs])) if len(idcs) else ([], []) - flat_rewards = ( - list(self.task.flat_reward_transform(torch.stack(flat_rewards))) if len(flat_rewards) else [] - ) - trajs = self.algo.create_training_data_from_graphs( - graphs, self.model, cond_info["encoding"][:num_offline], 0 - ) - - else: # If we're not sampling the conditionals, then the idcs refer to listed preferences - num_online = num_offline - num_offline = 0 - cond_info = self.task.encode_conditional_information( - steer_info=torch.stack([self.data[i] for i in idcs]) - ) - trajs, flat_rewards = [], [] - - # Sample some on-policy data - is_valid = torch.ones(num_offline + num_online).bool() - if num_online > 0: - with torch.no_grad(): - trajs += self.algo.create_training_data_from_own_samples( - self.model, - num_online, - cond_info["encoding"][num_offline:], - random_action_prob=self.random_action_prob, - ) - if self.algo.bootstrap_own_reward: - # The model can be trained to predict its own reward, - # i.e. predict the output of cond_info_to_logreward - pred_reward = [i["reward_pred"].cpu().item() for i in trajs[num_offline:]] - flat_rewards += pred_reward - else: - # Otherwise, query the task for flat rewards - valid_idcs = torch.tensor( - [i + num_offline for i in range(num_online) if trajs[i + num_offline]["is_valid"]] - ).long() - # fetch the valid trajectories endpoints - mols = [self.ctx.graph_to_mol(trajs[i]["result"]) for i in valid_idcs] - # ask the task to compute their reward - online_flat_rew, m_is_valid = self.task.compute_flat_rewards(mols) - assert ( - online_flat_rew.ndim == 2 - ), "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" - # The task may decide some of the mols are invalid, we have to again filter those - valid_idcs = valid_idcs[m_is_valid] - pred_reward = torch.zeros((num_online, online_flat_rew.shape[1])) - pred_reward[valid_idcs - num_offline] = online_flat_rew - is_valid[num_offline:] = False - is_valid[valid_idcs] = True - flat_rewards += list(pred_reward) - # Override the is_valid key in case the task made some mols invalid - for i in range(num_online): - trajs[num_offline + i]["is_valid"] = is_valid[num_offline + i].item() - - # Compute scalar rewards from conditional information & flat rewards - flat_rewards = torch.stack(flat_rewards) - log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) - log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward - - # Computes some metrics - extra_info = {} - if not self.sample_cond_info: - # If we're using a dataset of preferences, the user may want to know the id of the preference - for i, j in zip(trajs, idcs): - i["data_idx"] = j - # note: we convert back into natural rewards for logging purposes - # (allows to take averages and plot in objective space) - # TODO: implement that per-task (in case they don't apply the same beta and log transformations) - rewards = torch.exp(log_rewards / cond_info["beta"]) - if num_online > 0 and self.log_dir is not None: - self.log_generated( - deepcopy(trajs[num_offline:]), - deepcopy(rewards[num_offline:]), - deepcopy(flat_rewards[num_offline:]), - {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, - ) - if num_online > 0: - extra_info["sampled_reward_avg"] = rewards[num_offline:].mean().item() - for hook in self.log_hooks: - extra_info.update( - hook( - deepcopy(trajs[num_offline:]), - deepcopy(rewards[num_offline:]), - deepcopy(flat_rewards[num_offline:]), - {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, - ) - ) - - if self.replay_buffer is not None: - # If we have a replay buffer, we push the online trajectories in it - # and resample immediately such that the "online" data in the batch - # comes from a more stable distribution (try to avoid forgetting) - - # cond_info is a dict, so we need to convert it to a list of dicts - cond_info = [{k: v[i] for k, v in cond_info.items()} for i in range(num_offline + num_online)] - - # push the online trajectories in the replay buffer and sample a new 'online' batch - for i in range(num_offline, len(trajs)): - self.replay_buffer.push( - deepcopy(trajs[i]), - deepcopy(log_rewards[i]), - deepcopy(flat_rewards[i]), - deepcopy(cond_info[i]), - deepcopy(is_valid[i]), - ) - replay_trajs, replay_logr, replay_fr, replay_condinfo, replay_valid = self.replay_buffer.sample( - num_online - ) - - # append the online trajectories to the offline ones - trajs[num_offline:] = replay_trajs - log_rewards[num_offline:] = replay_logr - flat_rewards[num_offline:] = replay_fr - cond_info[num_offline:] = replay_condinfo - is_valid[num_offline:] = replay_valid - - # convert cond_info back to a dict - cond_info = {k: torch.stack([d[k] for d in cond_info]) for k in cond_info[0]} - - if self.hindsight_ratio > 0.0: - # Relabels some of the online trajectories with hindsight - assert hasattr( - self.task, "relabel_condinfo_and_logrewards" - ), "Hindsight requires the task to implement relabel_condinfo_and_logrewards" - # samples indexes of trajectories without repeats - hindsight_idxs = torch.randperm(num_online)[: int(num_online * self.hindsight_ratio)] + num_offline - cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( - cond_info, log_rewards, flat_rewards, hindsight_idxs - ) - log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward - - # Construct batch - batch = self.algo.construct_batch(trajs, cond_info["encoding"], log_rewards) - batch.num_offline = num_offline - batch.num_online = num_online - batch.flat_rewards = flat_rewards - batch.preferences = cond_info.get("preferences", None) - batch.focus_dir = cond_info.get("focus_dir", None) - batch.extra_info = extra_info - # TODO: we could very well just pass the cond_info dict to construct_batch above, - # and the algo can decide what it wants to put in the batch object - - # Only activate for debugging your environment or dataset (e.g. the dataset could be - # generating trajectories with illegal actions) - if self.do_validate_batch: - self.validate_batch(batch, trajs) - - self.train_it += worker_info.num_workers if worker_info is not None else 1 - if self.result_buffer_size and not self.do_multiple_buffers: - with self.round_robin_cond: - self.round_robin_cond.wait_for(lambda: self.round_robin_counter[0] == self._wid) - self.result_buffer.lock.acquire() - yield put_into_batch_buffer(batch, self.result_buffer.buffer) - self.round_robin_counter[0] = (self._wid + 1) % worker_info.num_workers - with self.round_robin_cond: - self.round_robin_cond.notify_all() - elif self.result_buffer_size: - self.result_buffer[self._wid].lock.acquire() - desc = put_into_batch_buffer(batch, self.result_buffer[self._wid].buffer) - desc.wid = self._wid - yield desc - else: - yield batch - - def validate_batch(self, batch, trajs): - for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( - [(batch.bck_actions, self.ctx.bck_action_type_order)] - if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order") - else [] - ): - mask_cat = GraphActionCategorical( - batch, - [self.model._action_type_to_mask(t, batch) for t in atypes], - [self.model._action_type_to_key[t] for t in atypes], - [None for _ in atypes], - ) - masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits) - num_trajs = len(trajs) - batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens) - first_graph_idx = torch.zeros_like(batch.traj_lens) - torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) - if masked_action_is_used.sum() != 0: - invalid_idx = masked_action_is_used.argmax().item() - traj_idx = batch_idx[invalid_idx].item() - timestep = invalid_idx - first_graph_idx[traj_idx].item() - raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep]) - - def log_generated(self, trajs, rewards, flat_rewards, cond_info): - if hasattr(self.ctx, "object_to_log_repr"): - mols = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] - else: - mols = [""] * len(trajs) - - flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() - rewards = rewards.data.numpy().tolist() - preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() - focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() - logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] - - data = [ - [mols[i], rewards[i]] - + flat_rewards[i] - + preferences[i] - + focus_dir[i] - + [cond_info[k][i].item() for k in logged_keys] - for i in range(len(trajs)) - ] - - data_labels = ( - ["smi", "r"] - + [f"fr_{i}" for i in range(len(flat_rewards[0]))] - + [f"pref_{i}" for i in range(len(preferences[0]))] - + [f"focus_{i}" for i in range(len(focus_dir[0]))] - + [f"ci_{k}" for k in logged_keys] - ) - - self.log.insert_many(data, data_labels) - - -class SQLiteLog: - def __init__(self, timeout=300): - """Creates a log instance, but does not connect it to any db.""" - self.is_connected = False - self.db = None - self.timeout = timeout - - def connect(self, db_path: str): - """Connects to db_path - - Parameters - ---------- - db_path: str - The sqlite3 database path. If it does not exist, it will be created. - """ - self.db = sqlite3.connect(db_path, timeout=self.timeout) - cur = self.db.cursor() - self._has_results_table = len( - cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='results'").fetchall() - ) - cur.close() - - def _make_results_table(self, types, names): - type_map = {str: "text", float: "real", int: "real"} - col_str = ", ".join(f"{name} {type_map[t]}" for t, name in zip(types, names)) - cur = self.db.cursor() - cur.execute(f"create table results ({col_str})") - self._has_results_table = True - cur.close() - - def insert_many(self, rows, column_names): - assert all( - [isinstance(x, str) or not isinstance(x, Iterable) for x in rows[0]] - ), "rows must only contain scalars" - if not self._has_results_table: - self._make_results_table([type(i) for i in rows[0]], column_names) - cur = self.db.cursor() - cur.executemany(f'insert into results values ({",".join("?"*len(rows[0]))})', rows) # nosec - cur.close() - self.db.commit() diff --git a/src/gflownet/models/seq_transformer.py b/src/gflownet/models/seq_transformer.py index 6916366a..751cf290 100644 --- a/src/gflownet/models/seq_transformer.py +++ b/src/gflownet/models/seq_transformer.py @@ -77,6 +77,9 @@ def forward(self, xs: SeqBatch, cond, batched=False): x = self.encoder(x, src_key_padding_mask=xs.mask, mask=generate_square_subsequent_mask(x.shape[0]).to(x.device)) pooled_x = x[xs.lens - 1, torch.arange(x.shape[1])] # (batch, nemb) + if cond is None: + cond = xs.cond_info + if self.use_cond: cond_var = self.cond_embed(cond) # (batch, nemb) cond_var = torch.tile(cond_var, (x.shape[0], 1, 1)) if batched else cond_var diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 0d758c78..6bddf98b 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -277,6 +277,7 @@ def run(self, logger=None): if it % 1024 == 0: gc.collect() torch.cuda.empty_cache() + if isinstance(batch, BatchDescriptor): print(f"buffer size was {batch.size / 1024**2:.2f}M") if train_dl.dataset.do_multiple_buffers: diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index 618d98f2..d92bb61f 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -68,14 +68,13 @@ def put_into_batch_buffer(batch, buffer): types.append(v.dtype) shapes.append(tuple(v.shape)) numel = v.numel() * v.element_size() - # print('putting', k, v.shape, numel, offset) buffer[offset : offset + numel] = v.view(-1).view(torch.uint8) offset += numel offset += (8 - offset % 8) % 8 # align to 8 bytes if offset > buffer.shape[0]: - raise ValueError(f"Offset {offset} exceeds buffer size {buffer.shape[0]}") - # print(f'total size: {offset / 1024**2:.3f}M') - # print(batch.batch) + raise ValueError( + f"Offset {offset} exceeds buffer size {buffer.shape[0]}. Try increasing `cfg.mp_buffer_size`." + ) return BatchDescriptor(names, types, shapes, offset, others) @@ -83,18 +82,16 @@ def resolve_batch_buffer(descriptor, buffer, device): offset = 0 batch = Batch() batch._slice_dict = {} - cuda_buffer = buffer[: descriptor.size].to(device) # TODO: check if only sending `size` is faster? + cuda_buffer = buffer[: descriptor.size].to(device) for name, dtype, shape in zip(descriptor.names, descriptor.types, descriptor.shapes): numel = prod(shape) * dtype.itemsize - # print('restoring', name, shape, numel, offset) if name.startswith("_slice_dict_"): batch._slice_dict[name[12:]] = cuda_buffer[offset : offset + numel].view(dtype).view(shape) else: setattr(batch, name, cuda_buffer[offset : offset + numel].view(dtype).view(shape)) offset += numel offset += (8 - offset % 8) % 8 # align to 8 bytes - # print(batch.batch) - # print(f'total size: {offset / 1024**2:.3f}M') + for k, v in descriptor.other.items(): setattr(batch, k, v) return batch @@ -112,7 +109,6 @@ def put_into_result_buffer(result, buffer): for name in gac_names: tensors = getattr(gac, name) for i, x in enumerate(tensors): - # print(f"putting {name}@{i} with shape {x.shape}") numel = x.numel() * x.element_size() if numel > 0: # We need this for a funny reason @@ -163,7 +159,6 @@ def resolve_result_buffer(descriptor, buffer, device): raise ValueError(f"Unknown result descriptor name: {name}") offset += numel offset += (8 - offset % 8) % 8 # align to 8 bytes - # print(f"restored {name} with shape {shape}") return gac, per_graph_out From e9f1dc13824a0ba621c7f71e05cd2005bee4945d Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 8 Mar 2024 09:16:53 -0700 Subject: [PATCH 15/31] various fixes --- docs/implementation_notes.md | 14 +++++++++ src/gflownet/algo/config.py | 1 + src/gflownet/data/data_source.py | 17 +++++++---- src/gflownet/tasks/seh_frag.py | 5 ++-- src/gflownet/trainer.py | 32 +++++++++++---------- src/gflownet/utils/multiprocessing_proxy.py | 27 +++++++++++++---- 6 files changed, 68 insertions(+), 28 deletions(-) diff --git a/docs/implementation_notes.md b/docs/implementation_notes.md index 600bb1d4..696095f1 100644 --- a/docs/implementation_notes.md +++ b/docs/implementation_notes.md @@ -34,3 +34,17 @@ The code contains a specific categorical distribution type for graph actions, `G Consider for example the `AddNode` and `SetEdgeAttr` actions, one applies to nodes and one to edges. An efficient way to produce logits for these actions would be to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator on the tensor. The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution. + +## Multiprocessing + +We use the multiprocessing features of torch's `DataLoader` to parallelize data generation and featurization. This is done by setting the `num_workers` (via `cfg.num_workers`) parameter of the `DataLoader` to a value greater than 0. Because workers cannot (easily) use a CUDA handle, we have to resort to a number of tricks. + +Because training models involves sampling them, the worker processes need to be able to call the models. This is done by passing a wrapped model (and possibly wrapped replay buffer) to the workers, using `gflownet.utils.multiprocessing_proxy`. These wrappers ensure that model calls are routed to the main worker process, where the model lives (e.g. in CUDA), and that the returned values are properly serialized and sent back to the worker process. These wrappers are also designed to be API-compatible with models, e.g. `model(input)` or `model.method(input)` will work as expected, regardless of whether `model` is a torch module or a wrapper. Note that it is only possible to call methods on these wrappers, direct attribute access is not supported. + +Note that the workers do not use CUDA, therefore have to work entirely on CPU, but the code is designed to be somewhat agnostic to this fact. By using `get_worker_device`, code can be written without assuming too much; again, calls such as `model(input)` will work as expected. + +On message serialization, naively sending batches of data and results (`Batch` and `GraphActionCategorical`) through multiprocessing queues is fairly inefficient. Torch tries to be smart and will use shared memory for tensors that are sent through queues, which unfortunately is very slow because creating these shared memory files is slow, and because `Data` `Batch`es tend to contain lots of small tensors, which is not a good fit for shared memory. + +We implement two solutions to this problem (in order of preference): +- using `SharedPinnedBuffer`s, which are shared tensors of fixed size (`cfg.mp_buffer_size`), but initialized once and pinned. This is the fastest solution, but that the size of the largest possible batch/return value is known in advance. This is currently only implemented for `Batch` inputs and `(GraphActionCategorical, Tensor)` outputs. +- using `cfg.pickle_mp_messages`, which simply serializes messages with `pickle`. This prevents the creating of lots of shared memory files, but is slower than the `SharedPinnedBuffer` solution. This should work for any message that `pickle` can handle. \ No newline at end of file diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index e2576982..70780359 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -138,6 +138,7 @@ class AlgoConfig: train_det_after: Optional[int] = None valid_random_action_prob: float = 0.0 sampling_tau: float = 0.0 + compute_log_n: bool = False tb: TBConfig = field(default_factory=TBConfig) moql: MOQLConfig = field(default_factory=MOQLConfig) a2c: A2CConfig = field(default_factory=A2CConfig) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index caba7399..eabad5cd 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -5,11 +5,12 @@ import torch import torch.multiprocessing as mp from torch.utils.data import IterableDataset +from torch_geometric.data import Batch from gflownet import GFNAlgorithm, GFNTask from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.envs.graph_building_env import GraphBuildingEnvContext +from gflownet.envs.graph_building_env import GraphBuildingEnvContext, GraphActionCategorical from gflownet.utils.misc import get_worker_rng from gflownet.utils.multiprocessing_proxy import SharedPinnedBuffer, put_into_batch_buffer @@ -228,7 +229,7 @@ def create_batch(self, trajs, batch_info): if "focus_dir" in trajs[0]: batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs]) - if self.ctx.has_n(): # Does this go somewhere else? Require a flag? Might not be cheap to compute + if self.ctx.has_n() and self.cfg.algo.compute_log_n: log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) @@ -323,12 +324,18 @@ def iterate_indices(self, n, num_samples): yield np.arange(i + num_samples, end) def setup_mp_buffers(self): - self.result_buffer_size = self.cfg.mp_buffer_size - if self.result_buffer_size: - self.result_buffer = [SharedPinnedBuffer(self.result_buffer_size) for _ in range(self.cfg.num_workers)] + if self.cfg.num_workers > 0: + self.result_buffer_size = self.cfg.mp_buffer_size + if self.result_buffer_size: + self.result_buffer = [SharedPinnedBuffer(self.result_buffer_size) for _ in range(self.cfg.num_workers)] + else: + self.result_buffer_size = None def _maybe_put_in_mp_buffer(self, batch): if self.result_buffer_size: + if not (isinstance(batch, Batch)): + warnings.warn(f"Expected a Batch object, but got {type(batch)}. " "Not using mp buffers.") + return batch self.result_buffer[self._wid].lock.acquire() desc = put_into_batch_buffer(batch, self.result_buffer[self._wid].buffer) desc.wid = self._wid diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 41e67df4..b5179124 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -208,9 +208,10 @@ def main(): config.log_dir = f"./logs/debug_run_seh_frag_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" config.device = "cuda" if torch.cuda.is_available() else "cpu" config.overwrite_existing_exp = True + config.algo.num_from_policy = 64 config.num_training_steps = 1_00 - config.validate_every = 20 - config.num_final_gen_steps = 10 + config.validate_every = 2000 + config.num_final_gen_steps = 0 config.num_workers = 8 config.opt.lr_decay = 20_000 config.algo.sampling_tau = 0.99 diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 6bddf98b..a71cb835 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -3,7 +3,7 @@ import pathlib import shutil import time -from typing import Any, Callable, Dict, List, Optional, Protocol +from typing import Any, Callable, Dict, List, Optional, Protocol, Union import numpy as np import torch @@ -15,6 +15,7 @@ from rdkit import RDLogger from torch import Tensor from torch.utils.data import DataLoader, Dataset +from torch_geometric.data import Batch from gflownet import GFNAlgorithm, GFNTask from gflownet.data.data_source import DataSource @@ -181,8 +182,7 @@ def build_training_data_loader(self) -> DataLoader: def build_validation_data_loader(self) -> DataLoader: model = self._wrap_for_mp(self.model) - # TODO: we're changing the default, make sure anything that is using test data is adjusted - src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) + n_drawn = self.cfg.algo.valid_num_from_policy n_from_dataset = self.cfg.algo.valid_num_from_dataset @@ -247,6 +247,16 @@ def evaluate_batch(self, batch: gd.Batch, epoch_idx: int = 0, batch_idx: int = 0 info["eval_time"] = time.time() - tick return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()} + def _maybe_resolve_batch_buffer(self, batch: Union[Batch, BatchDescriptor], dl: DataLoader) -> Batch: + if isinstance(batch, BatchDescriptor): + print(f"buffer size was {batch.size / 1024**2:.2f}M") + wid = batch.wid + batch = resolve_batch_buffer(batch, dl.dataset.result_buffer[wid].buffer, self.device) + dl.dataset.result_buffer[wid].lock.release() + else: + batch = batch.to(self.device) + return batch + def run(self, logger=None): """Trains the GFN for `num_training_steps` minibatches, performing validation every `validate_every` minibatches. @@ -277,18 +287,8 @@ def run(self, logger=None): if it % 1024 == 0: gc.collect() torch.cuda.empty_cache() - - if isinstance(batch, BatchDescriptor): - print(f"buffer size was {batch.size / 1024**2:.2f}M") - if train_dl.dataset.do_multiple_buffers: - wid = batch.wid - batch = resolve_batch_buffer(batch, train_dl.dataset.result_buffer[wid].buffer, self.device) - train_dl.dataset.result_buffer[wid].lock.release() - else: - batch = resolve_batch_buffer(batch, train_dl.dataset.result_buffer.buffer, self.device) - train_dl.dataset.result_buffer.lock.release() - else: - batch = batch.to(self.device) + _bd = batch + batch = self._maybe_resolve_batch_buffer(batch, train_dl) t1 = time.time() times.append(t1 - t0) print(f"iteration {it} : {t1 - t0:.2f} s, average: {np.mean(times):.2f} s") @@ -309,6 +309,7 @@ def run(self, logger=None): if valid_freq > 0 and it % valid_freq == 0: for batch in valid_dl: + batch = self._maybe_resolve_batch_buffer(batch, valid_dl) info = self.evaluate_batch(batch.to(self.device), epoch_idx, batch_idx) self.log(info, it, "valid") logger.info(f"validation - iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) @@ -329,6 +330,7 @@ def run(self, logger=None): range(num_training_steps + 1, num_training_steps + num_final_gen_steps + 1), cycle(final_dl), ): + batch = self._maybe_resolve_batch_buffer(batch, final_dl) if hasattr(batch, "extra_info"): for k, v in batch.extra_info.items(): if k not in final_info: diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index d92bb61f..0f2c6aef 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -8,6 +8,7 @@ import torch import torch.multiprocessing as mp from torch_geometric.data import Batch +import warnings from gflownet.envs.graph_building_env import GraphActionCategorical @@ -18,13 +19,24 @@ def __init__(self, size): self.buffer = torch.empty(size, dtype=torch.uint8) self.buffer.share_memory_() self.lock = mp.Lock() - - cudart = torch.cuda.cudart() - r = cudart.cudaHostRegister(self.buffer.data_ptr(), self.buffer.numel() * self.buffer.element_size(), 0) - assert r == 0 + self.do_unreg = False + + if not self.buffer.is_pinned(): + # Sometimes torch will create an already pinned (page aligned) buffer, so we don't need to + # pin it again; doing so will raise a CUDA error + cudart = torch.cuda.cudart() + r = cudart.cudaHostRegister(self.buffer.data_ptr(), self.buffer.numel() * self.buffer.element_size(), 0) + assert r == 0 + self.do_unreg = True # But then we need to unregister it later assert self.buffer.is_shared() assert self.buffer.is_pinned() + def __del__(self): + if self.do_unreg and torch.utils.data.get_worker_info() is None: + cudart = torch.cuda.cudart() + r = cudart.cudaHostUnregister(self.buffer.data_ptr()) + assert r == 0 + class BatchDescriptor: def __init__(self, names, types, shapes, size, other): @@ -82,7 +94,10 @@ def resolve_batch_buffer(descriptor, buffer, device): offset = 0 batch = Batch() batch._slice_dict = {} + # Seems legit to send just a 0-starting slice, because it should be pinned as well (and timing this vs sending + # the whole buffer, it seems to be the marginally faster option) cuda_buffer = buffer[: descriptor.size].to(device) + for name, dtype, shape in zip(descriptor.names, descriptor.types, descriptor.shapes): numel = prod(shape) * dtype.itemsize if name.startswith("_slice_dict_"): @@ -300,7 +315,7 @@ def to_cpu(self, i): def run(self): timeouts = 0 - while not self.stop.is_set() or timeouts < 500: + while not self.stop.is_set() and timeouts < 5 / 1e-5: for qi, q in enumerate(self.in_queues): try: r = self.decode(q.get(True, 1e-5)) @@ -379,4 +394,4 @@ def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = Fals A placeholder object whose method calls route arguments to the main process """ - return MPObjectProxy(obj, num_workers, cast_types, pickle_messages) + return MPObjectProxy(obj, num_workers, cast_types, pickle_messages, bb_size=bb_size) From c048e77ff0c5fc0cbd11c128f9c392195ec4e091 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 8 Mar 2024 13:52:54 -0700 Subject: [PATCH 16/31] major simplification by reusing pickling mechanisms --- src/gflownet/data/data_source.py | 22 +- src/gflownet/envs/seq_building_env.py | 1 + src/gflownet/tasks/seh_frag.py | 4 +- src/gflownet/trainer.py | 24 +- src/gflownet/utils/multiprocessing_proxy.py | 266 +++++++------------- 5 files changed, 108 insertions(+), 209 deletions(-) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index eabad5cd..009b0fd4 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -3,16 +3,15 @@ import numpy as np import torch -import torch.multiprocessing as mp from torch.utils.data import IterableDataset from torch_geometric.data import Batch from gflownet import GFNAlgorithm, GFNTask from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.envs.graph_building_env import GraphBuildingEnvContext, GraphActionCategorical +from gflownet.envs.graph_building_env import GraphBuildingEnvContext from gflownet.utils.misc import get_worker_rng -from gflownet.utils.multiprocessing_proxy import SharedPinnedBuffer, put_into_batch_buffer +from gflownet.utils.multiprocessing_proxy import BufferPickler, SharedPinnedBuffer def cycle_call(it): @@ -325,20 +324,17 @@ def iterate_indices(self, n, num_samples): def setup_mp_buffers(self): if self.cfg.num_workers > 0: - self.result_buffer_size = self.cfg.mp_buffer_size - if self.result_buffer_size: - self.result_buffer = [SharedPinnedBuffer(self.result_buffer_size) for _ in range(self.cfg.num_workers)] + self.mp_buffer_size = self.cfg.mp_buffer_size + if self.mp_buffer_size: + self.result_buffer = [SharedPinnedBuffer(self.mp_buffer_size) for _ in range(self.cfg.num_workers)] else: - self.result_buffer_size = None + self.mp_buffer_size = None def _maybe_put_in_mp_buffer(self, batch): - if self.result_buffer_size: + if self.mp_buffer_size: if not (isinstance(batch, Batch)): - warnings.warn(f"Expected a Batch object, but got {type(batch)}. " "Not using mp buffers.") + warnings.warn(f"Expected a Batch object, but got {type(batch)}. Not using mp buffers.") return batch - self.result_buffer[self._wid].lock.acquire() - desc = put_into_batch_buffer(batch, self.result_buffer[self._wid].buffer) - desc.wid = self._wid - return desc + return (BufferPickler(self.result_buffer[self._wid]).dumps(batch), self._wid) else: return batch diff --git a/src/gflownet/envs/seq_building_env.py b/src/gflownet/envs/seq_building_env.py index b8189690..c5e77bab 100644 --- a/src/gflownet/envs/seq_building_env.py +++ b/src/gflownet/envs/seq_building_env.py @@ -69,6 +69,7 @@ def __init__(self, seqs: List[torch.Tensor], pad: int): # Since we're feeding this batch object to graph-based algorithms, we have to use this naming, but this # is the total number of timesteps. self.num_graphs = self.lens.sum().item() + self.cond_info: torch.Tensor # May be set later def to(self, device): for name in dir(self): diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index b5179124..74d997a7 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -214,10 +214,12 @@ def main(): config.num_final_gen_steps = 0 config.num_workers = 8 config.opt.lr_decay = 20_000 + config.opt.clip_grad_type = "total_norm" config.algo.sampling_tau = 0.99 config.cond.temperature.sample_dist = "uniform" config.cond.temperature.dist_params = [0, 64.0] - config.mp_buffer_size = 32 * 1024 ** 2 + config.mp_buffer_size = 32 * 1024**2 + # config.pickle_mp_messages = True trial = SEHFragTrainer(config) trial.run() diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index a71cb835..1e3352e6 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -23,8 +23,8 @@ from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import create_logger, set_main_process_device +from gflownet.utils.multiprocessing_proxy import BufferUnpickler, mp_object_wrapper from gflownet.utils.sqlite_log import SQLiteLogHook -from gflownet.utils.multiprocessing_proxy import mp_object_wrapper, resolve_batch_buffer, BatchDescriptor from .config import Config @@ -132,7 +132,7 @@ def _wrap_for_mp(self, obj): self.cfg.num_workers, cast_types=(gd.Batch, GraphActionCategorical, SeqBatch), pickle_messages=self.cfg.pickle_mp_messages, - bb_size=self.cfg.mp_buffer_size, + sb_size=self.cfg.mp_buffer_size, ) self.to_terminate.append(wrapper.terminate) return wrapper.placeholder @@ -182,7 +182,7 @@ def build_training_data_loader(self) -> DataLoader: def build_validation_data_loader(self) -> DataLoader: model = self._wrap_for_mp(self.model) - + n_drawn = self.cfg.algo.valid_num_from_policy n_from_dataset = self.cfg.algo.valid_num_from_dataset @@ -247,13 +247,11 @@ def evaluate_batch(self, batch: gd.Batch, epoch_idx: int = 0, batch_idx: int = 0 info["eval_time"] = time.time() - tick return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()} - def _maybe_resolve_batch_buffer(self, batch: Union[Batch, BatchDescriptor], dl: DataLoader) -> Batch: - if isinstance(batch, BatchDescriptor): - print(f"buffer size was {batch.size / 1024**2:.2f}M") - wid = batch.wid - batch = resolve_batch_buffer(batch, dl.dataset.result_buffer[wid].buffer, self.device) - dl.dataset.result_buffer[wid].lock.release() - else: + def _maybe_resolve_shared_buffer(self, batch: Union[Batch, tuple, list], dl: DataLoader) -> Batch: + if dl.dataset.mp_buffer_size > 0 and isinstance(batch, (tuple, list)): + batch, wid = batch + batch = BufferUnpickler(dl.dataset.result_buffer[wid], batch, self.device).load() + elif isinstance(batch, Batch): batch = batch.to(self.device) return batch @@ -288,7 +286,7 @@ def run(self, logger=None): gc.collect() torch.cuda.empty_cache() _bd = batch - batch = self._maybe_resolve_batch_buffer(batch, train_dl) + batch = self._maybe_resolve_shared_buffer(batch, train_dl) t1 = time.time() times.append(t1 - t0) print(f"iteration {it} : {t1 - t0:.2f} s, average: {np.mean(times):.2f} s") @@ -309,7 +307,7 @@ def run(self, logger=None): if valid_freq > 0 and it % valid_freq == 0: for batch in valid_dl: - batch = self._maybe_resolve_batch_buffer(batch, valid_dl) + batch = self._maybe_resolve_shared_buffer(batch, valid_dl) info = self.evaluate_batch(batch.to(self.device), epoch_idx, batch_idx) self.log(info, it, "valid") logger.info(f"validation - iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) @@ -330,7 +328,7 @@ def run(self, logger=None): range(num_training_steps + 1, num_training_steps + num_final_gen_steps + 1), cycle(final_dl), ): - batch = self._maybe_resolve_batch_buffer(batch, final_dl) + batch = self._maybe_resolve_shared_buffer(batch, final_dl) if hasattr(batch, "extra_info"): for k, v in batch.extra_info.items(): if k not in final_info: diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index 0f2c6aef..86a3d971 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -1,16 +1,12 @@ +import io import pickle import queue import threading import traceback -from itertools import chain +from pickle import Pickler, Unpickler, UnpicklingError -import numpy as np import torch import torch.multiprocessing as mp -from torch_geometric.data import Batch -import warnings - -from gflownet.envs.graph_building_env import GraphActionCategorical class SharedPinnedBuffer: @@ -22,7 +18,7 @@ def __init__(self, size): self.do_unreg = False if not self.buffer.is_pinned(): - # Sometimes torch will create an already pinned (page aligned) buffer, so we don't need to + # Sometimes torch will create an already pinned (page aligned) buffer, so we don't need to # pin it again; doing so will raise a CUDA error cudart = torch.cuda.cudart() r = cudart.cudaHostRegister(self.buffer.data_ptr(), self.buffer.numel() * self.buffer.element_size(), 0) @@ -38,159 +34,82 @@ def __del__(self): assert r == 0 -class BatchDescriptor: - def __init__(self, names, types, shapes, size, other): - self.names = names - self.types = types - self.shapes = shapes - self.size = size - self.other = other +class _BufferPicklerSentinel: + pass -class ResultDescriptor: - def __init__(self, names, types, shapes, size, gac_attrs): - self.names = names - self.types = types - self.shapes = shapes - self.size = size - self.gac_attrs = gac_attrs +class BufferPickler(Pickler): + def __init__(self, buf: SharedPinnedBuffer): + self._f = io.BytesIO() + super().__init__(self._f) + self.buf = buf + # The lock will be released by the consumer of this buffer once the memory has been transferred to the device + self.buf.lock.acquire() + self.buf_offset = 0 + + def persistent_id(self, v): + if not isinstance(v, torch.Tensor): + return None + numel = v.numel() * v.element_size() + start = self.buf_offset + if numel > 0: + self.buf.buffer[start : start + numel] = v.view(-1).view(torch.uint8) + self.buf_offset += numel + self.buf_offset += (8 - self.buf_offset % 8) % 8 # align to 8 bytes + return (_BufferPicklerSentinel, (start, tuple(v.shape), v.dtype)) + + def dumps(self, obj): + self.dump(obj) + return (self._f.getvalue(), self.buf_offset) + + +class BufferUnpickler(Unpickler): + def __init__(self, buf: SharedPinnedBuffer, data, device): + self._f, total_size = io.BytesIO(data[0]), data[1] + super().__init__(self._f) + self.buf = buf + self.target_buf = buf.buffer[:total_size].to(device) + + def load_tensor(self, offset, shape, dtype): + numel = prod(shape) * dtype.itemsize + tensor = self.target_buf[offset : offset + numel].view(dtype).view(shape) + return tensor + def persistent_load(self, pid): + if isinstance(pid, tuple): + sentinel, (offset, shape, dtype) = pid + if sentinel is _BufferPicklerSentinel: + return self.load_tensor(offset, shape, dtype) + return UnpicklingError("Invalid persistent id") -def prod(l): + def load(self): + r = super().load() + # We're done with this buffer, release it for the next consumer + self.buf.lock.release() + return r + + +def prod(ns): p = 1 - for i in l: + for i in ns: p *= i return p -def put_into_batch_buffer(batch, buffer): - names = [] - types = [] - shapes = [] - offset = 0 - others = {} - for k, v in chain(batch._store.items(), (("_slice_dict_" + k, v) for k, v in batch._slice_dict.items())): - if not isinstance(v, torch.Tensor): - try: - v = torch.as_tensor(v) - except Exception as e: - others[k] = v - continue - names.append(k) - types.append(v.dtype) - shapes.append(tuple(v.shape)) - numel = v.numel() * v.element_size() - buffer[offset : offset + numel] = v.view(-1).view(torch.uint8) - offset += numel - offset += (8 - offset % 8) % 8 # align to 8 bytes - if offset > buffer.shape[0]: - raise ValueError( - f"Offset {offset} exceeds buffer size {buffer.shape[0]}. Try increasing `cfg.mp_buffer_size`." - ) - return BatchDescriptor(names, types, shapes, offset, others) - - -def resolve_batch_buffer(descriptor, buffer, device): - offset = 0 - batch = Batch() - batch._slice_dict = {} - # Seems legit to send just a 0-starting slice, because it should be pinned as well (and timing this vs sending - # the whole buffer, it seems to be the marginally faster option) - cuda_buffer = buffer[: descriptor.size].to(device) - - for name, dtype, shape in zip(descriptor.names, descriptor.types, descriptor.shapes): - numel = prod(shape) * dtype.itemsize - if name.startswith("_slice_dict_"): - batch._slice_dict[name[12:]] = cuda_buffer[offset : offset + numel].view(dtype).view(shape) - else: - setattr(batch, name, cuda_buffer[offset : offset + numel].view(dtype).view(shape)) - offset += numel - offset += (8 - offset % 8) % 8 # align to 8 bytes - - for k, v in descriptor.other.items(): - setattr(batch, k, v) - return batch - - -def put_into_result_buffer(result, buffer): - gac_names = ["logits", "batch", "slice", "masks"] - gac, tensor = result - buffer[: tensor.numel() * tensor.element_size()] = tensor.view(-1).view(torch.uint8) - offset = tensor.numel() * tensor.element_size() - offset += (8 - offset % 8) % 8 # align to 8 bytes - names = ["@per_graph_out"] - types = [tensor.dtype] - shapes = [tensor.shape] - for name in gac_names: - tensors = getattr(gac, name) - for i, x in enumerate(tensors): - numel = x.numel() * x.element_size() - if numel > 0: - # We need this for a funny reason - # torch.zeros(0)[::2] has a stride of (2,), and is contiguous according to torch - # so, flattening it and then reshaping it will not change the stride, which will - # make view(uint8) complain that the strides are not compatible. - # The batch[::2] happens when creating the categorical and deduplicate_edge_index is True - buffer[offset : offset + numel] = x.flatten().view(torch.uint8) - offset += numel - offset += (8 - offset % 8) % 8 # align to 8 bytes - if offset > buffer.shape[0]: - raise ValueError(f"Offset {offset} exceeds buffer size {buffer.shape[0]}") - names.append(f"{name}@{i}") - types.append(x.dtype) - shapes.append(tuple(x.shape)) - return ResultDescriptor(names, types, shapes, offset, (gac.num_graphs, gac.keys, gac.types)) - - -def resolve_result_buffer(descriptor, buffer, device): - # TODO: models can return multiple GraphActionCategoricals, but we only support one for now - # Would be nice to have something generic (and recursive?) - offset = 0 - tensor = buffer[: descriptor.size].to(device) - if tensor.device == device: # CPU to CPU - # I think we need this? Otherwise when we release the lock, the memory might be overwritten - tensor = tensor.clone() - # Maybe make this a static method, or just overload __new__? - gac = GraphActionCategorical.__new__(GraphActionCategorical) - gac.num_graphs, gac.keys, gac.types = descriptor.gac_attrs - gac.dev = device - gac.logprobs = None - gac._epsilon = 1e-38 - - gac_names = ["logits", "batch", "slice", "masks"] - for i in gac_names: - setattr(gac, i, [None] * len(gac.types)) - - for name, dtype, shape in zip(descriptor.names, descriptor.types, descriptor.shapes): - numel = prod(shape) * dtype.itemsize - if name == "@per_graph_out": - per_graph_out = tensor[offset : offset + numel].view(dtype).view(shape) - else: - name, index = name.split("@") - index = int(index) - if name in gac_names: - getattr(gac, name)[index] = tensor[offset : offset + numel].view(dtype).view(shape) - else: - raise ValueError(f"Unknown result descriptor name: {name}") - offset += numel - offset += (8 - offset % 8) % 8 # align to 8 bytes - return gac, per_graph_out - - class MPObjectPlaceholder: """This class can be used for example as a model or dataset placeholder in a worker process, and translates calls to the object-placeholder into queries for the main process to execute on the real object.""" - def __init__(self, in_queues, out_queues, pickle_messages=False, batch_buffer_size=None): + def __init__(self, in_queues, out_queues, pickle_messages=False, shared_buffer_size=None): self.qs = in_queues, out_queues self.device = torch.device("cpu") self.pickle_messages = pickle_messages self._is_init = False - self.batch_buffer_size = batch_buffer_size - if batch_buffer_size is not None: - self._batch_buffer = SharedPinnedBuffer(batch_buffer_size) - self._result_buffer = SharedPinnedBuffer(batch_buffer_size) + self.shared_buffer_size = shared_buffer_size + if shared_buffer_size is not None: + self._buffer_to_main = SharedPinnedBuffer(shared_buffer_size) + self._buffer_from_main = SharedPinnedBuffer(shared_buffer_size) def _check_init(self): if self._is_init: @@ -205,19 +124,20 @@ def _check_init(self): self._is_init = True def encode(self, m): + if self.shared_buffer_size: + return BufferPickler(self._buffer_to_main).dumps(m) if self.pickle_messages: return pickle.dumps(m) return m def decode(self, m): + if self.shared_buffer_size: + m = BufferUnpickler(self._buffer_from_main, m, self.device).load() if self.pickle_messages: m = pickle.loads(m) if isinstance(m, Exception): print("Received exception from main process, reraising.") raise m - if isinstance(m, ResultDescriptor): - m = resolve_result_buffer(m, self._result_buffer.buffer, self.device) - self._result_buffer.lock.release() return m def __getattr__(self, name): @@ -230,11 +150,6 @@ def method_wrapper(*a, **kw): def __call__(self, *a, **kw): self._check_init() - if self.batch_buffer_size and len(a) and isinstance(a[0], Batch): - # The lock will be released by the consumer of this buffer once the memory has been transferred to CUDA - self._batch_buffer.lock.acquire() - batch_descriptor = put_into_batch_buffer(a[0], self._batch_buffer.buffer) - a = (batch_descriptor,) + a[1:] self.in_queue.put(self.encode(("__call__", a, kw))) return self.decode(self.out_queue.get()) @@ -257,7 +172,7 @@ class MPObjectProxy: Always passes CPU tensors between processes. """ - def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False, bb_size=None): + def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False, sb_size=None): """Construct a multiprocessing object proxy. Parameters @@ -273,13 +188,14 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo If True, pickle messages sent between processes. This reduces load on shared memory, but increases load on CPU. It is recommended to activate this flag if encountering "Too many open files"-type errors. - bb_size: Optional[int] - batch buffer size + sb_size: Optional[int] + shared buffer size """ self.in_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.out_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.pickle_messages = pickle_messages - self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages, bb_size) + self.use_shared_buffer = sb_size is not None + self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages, sb_size) self.obj = obj if hasattr(obj, "parameters"): self.device = next(obj.parameters()).device @@ -291,20 +207,16 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo self.thread.start() def encode(self, m): + if self.use_shared_buffer: + return BufferPickler(self.placeholder._buffer_from_main).dumps(m) if self.pickle_messages: return pickle.dumps(m) - if ( - self.placeholder.batch_buffer_size - and isinstance(m, (list, tuple)) - and len(m) == 2 - and isinstance(m[0], GraphActionCategorical) - and isinstance(m[1], torch.Tensor) - ): - self.placeholder._result_buffer.lock.acquire() - return put_into_result_buffer(m, self.placeholder._result_buffer.buffer) return m def decode(self, m): + if self.use_shared_buffer: + return BufferUnpickler(self.placeholder._buffer_to_main, m, self.device).load() + if self.pickle_messages: return pickle.loads(m) return m @@ -326,12 +238,6 @@ def run(self): break timeouts = 0 attr, args, kwargs = r - if self.placeholder.batch_buffer_size and len(args) and isinstance(args[0], BatchDescriptor): - batch = resolve_batch_buffer(args[0], self.placeholder._batch_buffer.buffer, self.device) - args = (batch,) + args[1:] - # Should this release happen after the call to f()? Are we at risk of overwriting memory that - # is still being used by CUDA? - self.placeholder._batch_buffer.lock.release() f = getattr(self.obj, attr) args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args] kwargs = {k: i.to(self.device) if isinstance(i, self.cuda_types) else i for k, i in kwargs.items()} @@ -359,34 +265,30 @@ def terminate(self): self.stop.set() -def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False, bb_size=None): +def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False, sb_size=None): """Construct a multiprocessing object proxy for torch DataLoaders so that it does not need to be copied in every worker's memory. For example, this can be used to wrap a model such that only the main process makes cuda calls by forwarding data through the model, or a replay buffer such that the new data is pushed in from the worker processes but only the main process has to hold the full buffer in memory. - self.out_queues[qi].put(self.encode(msg)) - elif isinstance(result, dict): - msg = {k: self.to_cpu(i) for k, i in result.items()} - self.out_queues[qi].put(self.encode(msg)) - else: - msg = self.to_cpu(result) - self.out_queues[qi].put(self.encode(msg)) Parameters ---------- obj: any python object to be proxied (typically a torch.nn.Module or ReplayBuffer) - Lives in the main process to which method calls are passed + Lives in the main process to which method calls are passed num_workers: int Number of DataLoader workers cast_types: tuple Types that will be cast to cuda when received as arguments of method calls. torch.Tensor is cast by default. pickle_messages: bool - If True, pickle messages sent between processes. This reduces load on shared - memory, but increases load on CPU. It is recommended to activate this flag if - encountering "Too many open files"-type errors. + If True, pickle messages sent between processes. This reduces load on shared + memory, but increases load on CPU. It is recommended to activate this flag if + encountering "Too many open files"-type errors. + sb_size: Optional[int] + If not None, creates a shared buffer of this size for sending tensors between processes. + Note, this will allocate two buffers of this size (one for sending, the other for receiving). Returns ------- @@ -394,4 +296,4 @@ def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = Fals A placeholder object whose method calls route arguments to the main process """ - return MPObjectProxy(obj, num_workers, cast_types, pickle_messages, bb_size=bb_size) + return MPObjectProxy(obj, num_workers, cast_types, pickle_messages, sb_size=sb_size) From acfe07075eec5460d8c7abbb1a48985c37bc9d22 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Mon, 11 Mar 2024 10:00:34 -0600 Subject: [PATCH 17/31] memory copy + fixes and doc --- src/gflownet/config.py | 4 ++ src/gflownet/data/data_source.py | 3 +- src/gflownet/trainer.py | 9 ++-- src/gflownet/utils/multiprocessing_proxy.py | 51 ++++++++++++++++----- 4 files changed, 50 insertions(+), 17 deletions(-) diff --git a/src/gflownet/config.py b/src/gflownet/config.py index 070d68d1..4a3035f9 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -78,6 +78,10 @@ class Config: The hostname of the machine on which the experiment is run pickle_mp_messages : bool Whether to pickle messages sent between processes (only relevant if num_workers > 0) + mp_buffer_size : Optional[int] + If specified, use a buffer of this size for passing tensors between processes. + Note that this is only relevant if num_workers > 0. + Also note that this will allocate `num_workers + 2 * number of wrapped objects` buffers. git_hash : Optional[str] The git hash of the current commit overwrite_existing_exp : bool diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 009b0fd4..c795ef32 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -10,6 +10,7 @@ from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.graph_building_env import GraphBuildingEnvContext +from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import get_worker_rng from gflownet.utils.multiprocessing_proxy import BufferPickler, SharedPinnedBuffer @@ -332,7 +333,7 @@ def setup_mp_buffers(self): def _maybe_put_in_mp_buffer(self, batch): if self.mp_buffer_size: - if not (isinstance(batch, Batch)): + if not (isinstance(batch, (Batch, SeqBatch))): warnings.warn(f"Expected a Batch object, but got {type(batch)}. Not using mp buffers.") return batch return (BufferPickler(self.result_buffer[self._wid]).dumps(batch), self._wid) diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 1e3352e6..bc1e2645 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -247,11 +247,13 @@ def evaluate_batch(self, batch: gd.Batch, epoch_idx: int = 0, batch_idx: int = 0 info["eval_time"] = time.time() - tick return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()} - def _maybe_resolve_shared_buffer(self, batch: Union[Batch, tuple, list], dl: DataLoader) -> Batch: - if dl.dataset.mp_buffer_size > 0 and isinstance(batch, (tuple, list)): + def _maybe_resolve_shared_buffer( + self, batch: Union[Batch, SeqBatch, tuple, list], dl: DataLoader + ) -> Union[Batch, SeqBatch]: + if dl.dataset.mp_buffer_size and isinstance(batch, (tuple, list)): batch, wid = batch batch = BufferUnpickler(dl.dataset.result_buffer[wid], batch, self.device).load() - elif isinstance(batch, Batch): + elif isinstance(batch, (Batch, SeqBatch)): batch = batch.to(self.device) return batch @@ -285,7 +287,6 @@ def run(self, logger=None): if it % 1024 == 0: gc.collect() torch.cuda.empty_cache() - _bd = batch batch = self._maybe_resolve_shared_buffer(batch, train_dl) t1 = time.time() times.append(t1 - t0) diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index 86a3d971..7ed734bd 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -28,10 +28,11 @@ def __init__(self, size): assert self.buffer.is_pinned() def __del__(self): - if self.do_unreg and torch.utils.data.get_worker_info() is None: - cudart = torch.cuda.cudart() - r = cudart.cudaHostUnregister(self.buffer.data_ptr()) - assert r == 0 + if torch.utils.data.get_worker_info() is None: + if self.do_unreg: + cudart = torch.cuda.cudart() + r = cudart.cudaHostUnregister(self.buffer.data_ptr()) + assert r == 0 class _BufferPicklerSentinel: @@ -43,7 +44,8 @@ def __init__(self, buf: SharedPinnedBuffer): self._f = io.BytesIO() super().__init__(self._f) self.buf = buf - # The lock will be released by the consumer of this buffer once the memory has been transferred to the device + # The lock will be released by the consumer (BufferUnpickler) of this buffer once + # the memory has been transferred to the device and copied self.buf.lock.acquire() self.buf_offset = 0 @@ -51,12 +53,30 @@ def persistent_id(self, v): if not isinstance(v, torch.Tensor): return None numel = v.numel() * v.element_size() + if self.buf_offset + numel > self.buf.size: + raise RuntimeError( + f"Tried to allocate {self.buf_offset + numel} bytes in a buffer of size {self.buf.size}. " + "Consider increasing cfg.mp_buffer_size" + ) start = self.buf_offset + shape = tuple(v.shape) + if v.ndim > 0 and v.stride(-1) != 1 or not v.is_contiguous(): + v = v.contiguous().reshape(-1) + if v.ndim > 0 and v.stride(-1) != 1: + # We're still not contiguous, this unfortunately happens occasionally, e.g.: + # x = torch.arange(10).reshape((10, 1)) + # y = x.T[::2].T + # y.stride(), y.is_contiguous(), y.contiguous().stride() + # -> (1, 2), True, (1, 2) + v = v.flatten() + 0 + # I don't know if this comes from my misunderstanding of strides or if it's a bug in torch + # but either way torch will refuse to view this tensor as a uint8 tensor, so we have to + 0 + # to force torch to materialize it into a new tensor (it may otherwise be lazy and not materialize) if numel > 0: - self.buf.buffer[start : start + numel] = v.view(-1).view(torch.uint8) + self.buf.buffer[start : start + numel] = v.flatten().view(torch.uint8) self.buf_offset += numel self.buf_offset += (8 - self.buf_offset % 8) % 8 # align to 8 bytes - return (_BufferPicklerSentinel, (start, tuple(v.shape), v.dtype)) + return (_BufferPicklerSentinel, (start, shape, v.dtype)) def dumps(self, obj): self.dump(obj) @@ -68,11 +88,19 @@ def __init__(self, buf: SharedPinnedBuffer, data, device): self._f, total_size = io.BytesIO(data[0]), data[1] super().__init__(self._f) self.buf = buf - self.target_buf = buf.buffer[:total_size].to(device) + self.target_buf = buf.buffer[:total_size].to(device) + 0 + # Why the `+ 0`? Unfortunately, we have no way to know exactly when the consumer of the object we're + # unpickling will be done using the buffer underlying the tensor, so we have to create a copy. + # If we don't and another consumer starts using the buffer, and this consumer transfers this pinned + # buffer to the GPU, the first consumer's tensors will be corrupted, because (depending on the CUDA + # memory manager) the pinned buffer will transfer to the same GPU location. + # Hopefully, especially if the target device is the GPU, the copy will be fast and/or async. + # Note that this could be fixed by using one buffer for each worker, but that would be significantly + # more memory usage. def load_tensor(self, offset, shape, dtype): numel = prod(shape) * dtype.itemsize - tensor = self.target_buf[offset : offset + numel].view(dtype).view(shape) + tensor: torch.Tensor = self.target_buf[offset : offset + numel].view(dtype).view(shape) return tensor def persistent_load(self, pid): @@ -107,7 +135,7 @@ def __init__(self, in_queues, out_queues, pickle_messages=False, shared_buffer_s self.pickle_messages = pickle_messages self._is_init = False self.shared_buffer_size = shared_buffer_size - if shared_buffer_size is not None: + if shared_buffer_size: self._buffer_to_main = SharedPinnedBuffer(shared_buffer_size) self._buffer_from_main = SharedPinnedBuffer(shared_buffer_size) @@ -194,7 +222,7 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo self.in_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.out_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.pickle_messages = pickle_messages - self.use_shared_buffer = sb_size is not None + self.use_shared_buffer = bool(sb_size) self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages, sb_size) self.obj = obj if hasattr(obj, "parameters"): @@ -226,7 +254,6 @@ def to_cpu(self, i): def run(self): timeouts = 0 - while not self.stop.is_set() and timeouts < 5 / 1e-5: for qi, q in enumerate(self.in_queues): try: From 907ffcd513b9531351fce45d56023f602f09deba Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Wed, 8 May 2024 16:58:25 -0600 Subject: [PATCH 18/31] fix global_cfg + opt_Z when there's no Z --- src/gflownet/__init__.py | 3 +++ src/gflownet/algo/advantage_actor_critic.py | 7 +++-- src/gflownet/algo/envelope_q_learning.py | 27 +++++++++++-------- src/gflownet/algo/flow_matching.py | 5 ++-- src/gflownet/algo/multiobjective_reinforce.py | 3 ++- src/gflownet/algo/soft_q_learning.py | 7 +++-- src/gflownet/algo/trajectory_balance.py | 13 +++++---- src/gflownet/models/graph_transformer.py | 10 +++---- src/gflownet/models/seq_transformer.py | 2 +- src/gflownet/online_trainer.py | 20 +++++++++----- 10 files changed, 58 insertions(+), 39 deletions(-) diff --git a/src/gflownet/__init__.py b/src/gflownet/__init__.py index 6cb8f979..5415ecd8 100644 --- a/src/gflownet/__init__.py +++ b/src/gflownet/__init__.py @@ -23,6 +23,9 @@ class GFNAlgorithm: def step(self): self.updates += 1 # This isn't used anywhere? + def set_is_eval(self, is_eval: bool): + self.is_eval = is_eval + def compute_batch_losses( self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 ) -> Tuple[Tensor, Dict[str, Tensor]]: diff --git a/src/gflownet/algo/advantage_actor_critic.py b/src/gflownet/algo/advantage_actor_critic.py index 7077a9d1..7ce05a81 100644 --- a/src/gflownet/algo/advantage_actor_critic.py +++ b/src/gflownet/algo/advantage_actor_critic.py @@ -3,6 +3,7 @@ import torch_geometric.data as gd from torch import Tensor +from gflownet import GFNAlgorithm from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory from gflownet.utils.misc import get_worker_device @@ -10,7 +11,7 @@ from .graph_sampling import GraphSampler -class A2C: +class A2C(GFNAlgorithm): def __init__( self, env: GraphBuildingEnv, @@ -36,6 +37,7 @@ def __init__( The experiment configuration """ + self.global_cfg = cfg # TODO: this belongs in the base class self.ctx = ctx self.env = env self.max_len = cfg.algo.max_len @@ -149,7 +151,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Forward pass of the model, returns a GraphActionCategorical and per graph predictions # Here we will interpret the logits of the fwd_cat as Q values - policy, per_state_preds = model(batch, cond_info[batch_idx]) + batch.cond_info = cond_info[batch_idx] + policy, per_state_preds = model(batch) V = per_state_preds[:, 0] G = rewards[batch_idx] # The return is the terminal reward everywhere, we're using gamma==1 G = G + (1 - batch.is_valid[batch_idx]) * self.invalid_penalty # Add in penalty for invalid object diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py index 9bfc3345..4798600b 100644 --- a/src/gflownet/algo/envelope_q_learning.py +++ b/src/gflownet/algo/envelope_q_learning.py @@ -5,6 +5,7 @@ from torch import Tensor from torch_scatter import scatter +from gflownet import GFNAlgorithm from gflownet.config import Config from gflownet.envs.graph_building_env import ( GraphActionCategorical, @@ -39,24 +40,24 @@ def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objective num_layers=num_layers, num_heads=num_heads, ) - num_final = num_emb * 2 + num_final = num_emb num_mlp_layers = 0 self.emb2add_node = mlp(num_final, num_emb, env_ctx.num_new_node_values * num_objectives, num_mlp_layers) # Edge attr logits are "sided", so we will compute both sides independently self.emb2set_edge_attr = mlp( num_emb + num_final, num_emb, env_ctx.num_edge_attr_logits // 2 * num_objectives, num_mlp_layers ) - self.emb2stop = mlp(num_emb * 3, num_emb, num_objectives, num_mlp_layers) - self.emb2reward = mlp(num_emb * 3, num_emb, 1, num_mlp_layers) + self.emb2stop = mlp(num_emb * 2, num_emb, num_objectives, num_mlp_layers) + self.emb2reward = mlp(num_emb * 2, num_emb, 1, num_mlp_layers) self.edge2emb = mlp(num_final, num_emb, num_emb, num_mlp_layers) self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2) self.action_type_order = env_ctx.action_type_order self.mask_value = -10 self.num_objectives = num_objectives - def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): + def forward(self, g: gd.Batch, output_Qs=False): """See `GraphTransformer` for argument values""" - node_embeddings, graph_embeddings = self.transf(g, cond) + node_embeddings, graph_embeddings = self.transf(g) # On `::2`, edges are duplicated to make graphs undirected, only take the even ones e_row, e_col = g.edge_index[:, ::2] edge_emb = self.edge2emb(node_embeddings[e_row] + node_embeddings[e_col]) @@ -86,7 +87,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): # Compute the greedy policy # See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations # TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes - w = cond[:, -self.num_objectives :] + w = g.cond_info[:, -self.num_objectives :] w_dot_Q = [ (qi.reshape((qi.shape[0], qi.shape[1] // w.shape[1], w.shape[1])) * w[b][:, None, :]).sum(2) for qi, b in zip(cat.logits, cat.batch) @@ -122,8 +123,9 @@ def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objective self.action_type_order = env_ctx.action_type_order self.num_objectives = num_objectives - def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): - node_embeddings, graph_embeddings = self.transf(g, cond) + def forward(self, g: gd.Batch, output_Qs=False): + cond = g.cond_info + node_embeddings, graph_embeddings = self.transf(g) ne_row, ne_col = g.non_edge_index # On `::2`, edges are duplicated to make graphs undirected, only take the even ones e_row, e_col = g.edge_index[:, ::2] @@ -156,7 +158,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): return cat, r_pred -class EnvelopeQLearning: +class EnvelopeQLearning(GFNAlgorithm): def __init__( self, env: GraphBuildingEnv, @@ -182,6 +184,7 @@ def __init__( cfg: Config The experiment configuration """ + self.global_cfg = cfg self.ctx = ctx self.env = env self.task = task @@ -314,7 +317,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Forward pass of the model, returns a GraphActionCategorical and per graph predictions # Here we will interpret the logits of the fwd_cat as Q values # Q(s,a,omega) - fwd_cat, per_state_preds = model(batch, cond_info[batch_idx], output_Qs=True) + batch.cond_info = cond_info[batch_idx] + fwd_cat, per_state_preds = model(batch, output_Qs=True) Q_omega = fwd_cat.logits # reshape to List[shape: (num in all graphs, num actions on T, num_objectives) | for all types T] Q_omega = [i.reshape((i.shape[0], i.shape[1] // num_objectives, num_objectives)) for i in Q_omega] @@ -323,7 +327,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: batchp = batch.batch_prime batchp_num_trajs = int(batchp.traj_lens.shape[0]) batchp_batch_idx = torch.arange(batchp_num_trajs, device=dev).repeat_interleave(batchp.traj_lens) - fwd_cat_prime, per_state_preds = model(batchp, batchp.cond_info[batchp_batch_idx], output_Qs=True) + batchp.cond_info = batchp.cond_info[batchp_batch_idx] + fwd_cat_prime, per_state_preds = model(batchp, output_Qs=True) Q_omega_prime = fwd_cat_prime.logits # We've repeated everything N_omega times, so we can reshape the same way as above but with # an extra N_omega first dimension diff --git a/src/gflownet/algo/flow_matching.py b/src/gflownet/algo/flow_matching.py index 33c436bf..c75c1ce4 100644 --- a/src/gflownet/algo/flow_matching.py +++ b/src/gflownet/algo/flow_matching.py @@ -46,7 +46,7 @@ def __init__( # in a number of settings the regular loss is more stable. self.fm_balanced_loss = cfg.algo.fm.balanced_loss self.fm_leaf_coef = cfg.algo.fm.leaf_coef - self.correct_idempotent: bool = self.correct_idempotent or cfg.algo.fm.correct_idempotent + self.correct_idempotent: bool = cfg.algo.fm.correct_idempotent def construct_batch(self, trajs, cond_info, log_rewards): """Construct a batch from a list of trajectories and their information @@ -149,7 +149,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Query the model for Fsa. The model will output a GraphActionCategorical, but we will # simply interpret the logits as F(s, a). Conveniently the policy of a GFN is the softmax of # log F(s,a) so we don't have to change anything in the sampling routines. - cat, graph_out = model(batch, batch.cond_info[torch.cat([parents_traj_idx, states_traj_idx], 0)]) + batch.cond_info = batch.cond_info[torch.cat([parents_traj_idx, states_traj_idx], 0)] + cat, graph_out = model(batch) # We compute \sum_{s,a : T(s,a)=s'} F(s,a), first we index all the parent's outputs by the # parent actions. To do so we reuse the log_prob mechanism, but specify that the logprobs # tensor is actually just the logits (which we chose to interpret as edge flows F(s,a). We diff --git a/src/gflownet/algo/multiobjective_reinforce.py b/src/gflownet/algo/multiobjective_reinforce.py index b1a636de..52314d03 100644 --- a/src/gflownet/algo/multiobjective_reinforce.py +++ b/src/gflownet/algo/multiobjective_reinforce.py @@ -34,7 +34,8 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens) # Forward pass of the model, returns a GraphActionCategorical and the optional bootstrap predictions - fwd_cat, log_reward_preds = model(batch, cond_info[batch_idx]) + batch.cond_info = cond_info[batch_idx] + fwd_cat, log_reward_preds = model(batch) # This is the log prob of each action in the trajectory log_prob = fwd_cat.log_prob(batch.actions) diff --git a/src/gflownet/algo/soft_q_learning.py b/src/gflownet/algo/soft_q_learning.py index a9d61aaa..99730279 100644 --- a/src/gflownet/algo/soft_q_learning.py +++ b/src/gflownet/algo/soft_q_learning.py @@ -4,13 +4,14 @@ from torch import Tensor from torch_scatter import scatter +from gflownet import GFNAlgorithm from gflownet.algo.graph_sampling import GraphSampler from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory from gflownet.utils.misc import get_worker_device -class SoftQLearning: +class SoftQLearning(GFNAlgorithm): def __init__( self, env: GraphBuildingEnv, @@ -33,6 +34,7 @@ def __init__( cfg: Config The experiment configuration """ + self.global_cfg = cfg self.ctx = ctx self.env = env self.max_len = cfg.algo.max_len @@ -147,7 +149,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Forward pass of the model, returns a GraphActionCategorical and per object predictions # Here we will interpret the logits of the fwd_cat as Q values - Q, per_state_preds = model(batch, cond_info[batch_idx]) + batch.cond_info = cond_info[batch_idx] + Q, per_state_preds = model(batch) if self.do_q_prime_correction: # First we need to estimate V_soft. We will use q_a' = \pi diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index eac57cc6..8713cde3 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -109,7 +109,7 @@ def __init__( """ self.ctx = ctx self.env = env - self.global_cfg = cfg + self.global_cfg = cfg # TODO: this belongs in the base class self.cfg = cfg.algo.tb self.max_len = cfg.algo.max_len self.max_nodes = cfg.algo.max_nodes @@ -147,9 +147,6 @@ def __init__( self._subtb_max_len = self.global_cfg.algo.max_len + 2 self._init_subtb(get_worker_device()) - def set_is_eval(self, is_eval: bool): - self.is_eval = is_eval - def create_training_data_from_own_samples( self, model: TrajectoryBalanceModel, @@ -402,12 +399,14 @@ def compute_batch_losses( # Forward pass of the model, returns a GraphActionCategorical representing the forward # policy P_F, optionally a backward policy P_B, and per-graph outputs (e.g. F(s) in SubTB). if self.cfg.do_parameterize_p_b: - fwd_cat, bck_cat, per_graph_out = model(batch, batched_cond_info) + batch.cond_info = batched_cond_info + fwd_cat, bck_cat, per_graph_out = model(batch) else: if self.model_is_autoregressive: - fwd_cat, per_graph_out = model(batch, cond_info, batched=True) + fwd_cat, per_graph_out = model(batch, batched=True) else: - fwd_cat, per_graph_out = model(batch, batched_cond_info) + batch.cond_info = batched_cond_info + fwd_cat, per_graph_out = model(batch) # Retreive the reward predictions for the full graphs, # i.e. the final graph of each trajectory log_reward_preds = per_graph_out[final_graph_idx, 0] diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index a010b22a..b84980dc 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -90,7 +90,7 @@ def __init__( ) ) - def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): + def forward(self, g: gd.Batch): """Forward pass Parameters @@ -112,7 +112,7 @@ def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): x = g.x o = self.x2h(x) e = self.e2h(g.edge_attr) - c = self.c2h(cond if cond is not None else torch.ones((g.num_graphs, 1), device=g.x.device)) + c = self.c2h(g.cond_info if g.cond_info is not None else torch.ones((g.num_graphs, 1), device=g.x.device)) num_total_nodes = g.x.shape[0] # Augment the edges with a new edge to the conditioning # information node. This new node is connected to every node @@ -255,10 +255,8 @@ def _make_cat(self, g: gd.Batch, emb: Dict[str, Tensor], action_types: list[Grap types=action_types, ) - def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): - if cond is None: - cond = g.cond_info - node_embeddings, graph_embeddings = self.transf(g, cond) + def forward(self, g: gd.Batch): + node_embeddings, graph_embeddings = self.transf(g) # "Non-edges" are edges not currently in the graph that we could add if hasattr(g, "non_edge_index"): ne_row, ne_col = g.non_edge_index diff --git a/src/gflownet/models/seq_transformer.py b/src/gflownet/models/seq_transformer.py index b694fad2..ce696d2d 100644 --- a/src/gflownet/models/seq_transformer.py +++ b/src/gflownet/models/seq_transformer.py @@ -65,7 +65,7 @@ def logZ(self, cond_info: Optional[torch.Tensor]): return self._logZ(torch.ones((1, 1), device=self._logZ.weight.device)) return self._logZ(cond_info) - def forward(self, xs: SeqBatch, cond, batched=False): + def forward(self, xs: SeqBatch, batched=False): """Returns a GraphActionCategorical and a tensor of state predictions. Parameters diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 103acc95..13dd0e48 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -82,12 +82,17 @@ def setup(self): else: Z_params = [] non_Z_params = list(self.model.parameters()) + self.opt = self._opt(non_Z_params) - self.opt_Z = self._opt(Z_params, self.cfg.algo.tb.Z_learning_rate, 0.9) self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) - self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( - self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) - ) + + if Z_params: + self.opt_Z = self._opt(Z_params, self.cfg.algo.tb.Z_learning_rate, 0.9) + self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( + self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) + ) + else: + self.opt_Z = None self.sampling_tau = self.cfg.algo.sampling_tau if self.sampling_tau > 0: @@ -124,10 +129,11 @@ def step(self, loss: Tensor): g1 = model_grad_norm(self.model) self.opt.step() self.opt.zero_grad() - self.opt_Z.step() - self.opt_Z.zero_grad() self.lr_sched.step() - self.lr_sched_Z.step() + if self.opt_Z is not None: + self.opt_Z.step() + self.opt_Z.zero_grad() + self.lr_sched_Z.step() if self.sampling_tau > 0: for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) From 60722a777ad27b44ea5de245acfb9118ca57d541 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 9 May 2024 08:17:28 -0600 Subject: [PATCH 19/31] fix entropy when masks are used --- src/gflownet/envs/graph_building_env.py | 6 ++++-- src/gflownet/models/seq_transformer.py | 4 +--- src/gflownet/trainer.py | 1 + tests/test_graph_building_env.py | 22 +++++++++++++++++++++- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index b10b228d..601d7bd5 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -887,10 +887,12 @@ def entropy(self, logprobs=None): """ if logprobs is None: logprobs = self.logsoftmax() + masks = self.action_masks if self.action_masks is not None else [None] * len(logprobs) entropy = -sum( [ - scatter(i * i.exp(), b, dim=0, dim_size=self.num_graphs, reduce="sum").sum(1) - for i, b in zip(logprobs, self.batch) + scatter(im, b, dim=0, dim_size=self.num_graphs, reduce="sum").sum(1) + for i, b, m in zip(logprobs, self.batch, masks) + for im in [i.masked_fill(m == 0.0, 0) if m is not None else i] ] ) return entropy diff --git a/src/gflownet/models/seq_transformer.py b/src/gflownet/models/seq_transformer.py index ce696d2d..8ecb8919 100644 --- a/src/gflownet/models/seq_transformer.py +++ b/src/gflownet/models/seq_transformer.py @@ -83,9 +83,7 @@ def forward(self, xs: SeqBatch, batched=False): x = self.encoder(x, src_key_padding_mask=xs.mask, mask=generate_square_subsequent_mask(x.shape[0]).to(x.device)) pooled_x = x[xs.lens - 1, torch.arange(x.shape[1])] # (batch, nemb) - if cond is None: - cond = xs.cond_info - + cond = xs.cond_info if self.use_cond: cond_var = self.cond_embed(cond) # (batch, nemb) cond_var = torch.tile(cond_var, (x.shape[0], 1, 1)) if batched else cond_var diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index c537b0ce..fcd078a8 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -219,6 +219,7 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: tick = time.time() self.model.train() try: + loss = info = None loss, info = self.algo.compute_batch_losses(self.model, batch) if not torch.isfinite(loss): raise ValueError("loss is not finite") diff --git a/tests/test_graph_building_env.py b/tests/test_graph_building_env.py index e9184cbd..adf41120 100644 --- a/tests/test_graph_building_env.py +++ b/tests/test_graph_building_env.py @@ -123,4 +123,24 @@ def test_log_prob(): def test_entropy(): cat = make_test_cat() - cat.entropy() + entropy = cat.entropy() + assert torch.isfinite(entropy).all() and entropy.shape == (3,) and (entropy > 0).all() + + cat.action_masks = [ + torch.tensor([[0], [1], [1.0]]), + ((torch.arange(cat.logits[1].numel()) % 2) == 0).float().reshape(cat.logits[1].shape), + torch.tensor([[1, 0, 1], [0, 1, 1.0]]), + ] + entropy = cat.entropy() + assert torch.isfinite(entropy).all() and entropy.shape == (3,) and (entropy > 0).all() + + +def test_entropy_grad(): + # Purposefully large values to test extremal behaviors + logits = torch.tensor([[100, 101, -102, 95, 10, 20, 72]]).float() + logits.requires_grad_(True) + batch = Batch.from_data_list([Data(x=torch.ones((1, 10)), y=torch.ones((2, 6)))], follow_batch=["y"]) + cat = GraphActionCategorical(batch, [logits[:, :3], logits[:, 3:].reshape(2, 2)], [None, "y"], [None, None]) + cat._epsilon = 0 + (grad_gac,) = torch.autograd.grad(cat.entropy(), logits, retain_graph=True) + assert torch.isfinite(grad_gac).all() From f859640929fa33c7542be33981631a75c387d9be Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 9 May 2024 08:36:40 -0600 Subject: [PATCH 20/31] small fixes --- docs/implementation_notes.md | 5 +---- src/gflownet/algo/config.py | 1 - src/gflownet/config.py | 2 +- src/gflownet/tasks/seh_frag.py | 5 +---- 4 files changed, 3 insertions(+), 10 deletions(-) diff --git a/docs/implementation_notes.md b/docs/implementation_notes.md index 4b85f2e4..ba63e708 100644 --- a/docs/implementation_notes.md +++ b/docs/implementation_notes.md @@ -53,7 +53,7 @@ The data used for training GFlowNets can come from a variety of sources. `DataSo - Generating new trajectories (w.r.t a fixed dataset of conditioning goals) - Evaluating the model's likelihood on trajectories from a fixed, offline dataset -## Multiprocessing +## Multiprocessing We use the multiprocessing features of torch's `DataLoader` to parallelize data generation and featurization. This is done by setting the `num_workers` (via `cfg.num_workers`) parameter of the `DataLoader` to a value greater than 0. Because workers cannot (easily) use a CUDA handle, we have to resort to a number of tricks. @@ -66,6 +66,3 @@ On message serialization, naively sending batches of data and results (`Batch` a We implement two solutions to this problem (in order of preference): - using `SharedPinnedBuffer`s, which are shared tensors of fixed size (`cfg.mp_buffer_size`), but initialized once and pinned. This is the fastest solution, but requires that the size of the largest possible batch/return value is known in advance. This should work for any message, but has only been tested with `Batch` and `GraphActionCategorical` messages. - using `cfg.pickle_mp_messages`, which simply serializes messages with `pickle`. This prevents the creation of lots of shared memory files, but is slower than the `SharedPinnedBuffer` solution. This should work for any message that `pickle` can handle. - - - diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index c5b1ea8c..4dd9cbfe 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -196,7 +196,6 @@ class AlgoConfig(StrictDataClass): train_det_after: Optional[int] = None valid_random_action_prob: float = 0.0 sampling_tau: float = 0.0 - compute_log_n: bool = False tb: TBConfig = field(default_factory=TBConfig) moql: MOQLConfig = field(default_factory=MOQLConfig) a2c: A2CConfig = field(default_factory=A2CConfig) diff --git a/src/gflownet/config.py b/src/gflownet/config.py index d72bd567..86b225f0 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -80,7 +80,7 @@ class Config(StrictDataClass): pickle_mp_messages : bool Whether to pickle messages sent between processes (only relevant if num_workers > 0) mp_buffer_size : Optional[int] - If specified, use a buffer of this size for passing tensors between processes. + If specified, use a buffer of this size in bytes for passing tensors between processes. Note that this is only relevant if num_workers > 0. Also note that this will allocate `num_workers + 2 * number of wrapped objects` buffers. git_hash : Optional[str] diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 84da11c0..de4b72f0 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -129,6 +129,7 @@ class SEHFragTrainer(StandardOnlineTrainer): def set_default_hps(self, cfg: Config): cfg.hostname = socket.gethostname() cfg.pickle_mp_messages = False + cfg.mp_buffer_size = 32 * 1024**2 # 32 MB cfg.num_workers = 8 cfg.opt.learning_rate = 1e-4 cfg.opt.weight_decay = 1e-8 @@ -195,18 +196,14 @@ def main(): config.log_dir = f"./logs/debug_run_seh_frag_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" config.device = "cuda" if torch.cuda.is_available() else "cpu" config.overwrite_existing_exp = True - config.algo.num_from_policy = 64 config.num_training_steps = 1_00 config.validate_every = 20 config.num_final_gen_steps = 10 config.num_workers = 1 config.opt.lr_decay = 20_000 - config.opt.clip_grad_type = "total_norm" config.algo.sampling_tau = 0.99 config.cond.temperature.sample_dist = "uniform" config.cond.temperature.dist_params = [0, 64.0] - config.mp_buffer_size = 32 * 1024**2 - # config.pickle_mp_messages = True trial = SEHFragTrainer(config) trial.run() From d536233ecde108e8a3dad5bb7640f9b75c2771c4 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 9 May 2024 08:59:47 -0600 Subject: [PATCH 21/31] removing timing prints --- src/gflownet/trainer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index fcd078a8..4d90fa3a 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -279,8 +279,6 @@ def run(self, logger=None): num_training_steps = self.cfg.num_training_steps logger.info("Starting training") start_time = time.time() - t0 = time.time() - times = [] for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)): # the memory fragmentation or allocation keeps growing, how often should we clean up? # is changing the allocation strategy helpful? @@ -289,10 +287,6 @@ def run(self, logger=None): gc.collect() torch.cuda.empty_cache() batch = self._maybe_resolve_shared_buffer(batch, train_dl) - t1 = time.time() - times.append(t1 - t0) - print(f"iteration {it} : {t1 - t0:.2f} s, average: {np.mean(times):.2f} s") - t0 = t1 epoch_idx = it // epoch_length batch_idx = it % epoch_length if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup: From 6c3bebaa4721dacccf0f9bcc0f44d50b0cf3f3be Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Wed, 21 Aug 2024 13:59:28 -0600 Subject: [PATCH 22/31] C graphs, DDP, logit scaling --- setup.py | 19 +- src/C/data.c | 612 ++++++++++ src/C/degree_view.c | 70 ++ src/C/edge_view.c | 274 +++++ src/C/graph_def.c | 152 +++ src/C/main.c | 1143 +++++++++++++++++++ src/C/main.h | 115 ++ src/C/mol_graph_to_Data.c | 390 +++++++ src/C/node_view.c | 244 ++++ src/gflownet/algo/graph_sampling.py | 5 +- src/gflownet/algo/trajectory_balance.py | 9 +- src/gflownet/config.py | 2 + src/gflownet/data/data_source.py | 13 +- src/gflownet/envs/graph_building_env.py | 56 +- src/gflownet/envs/mol_building_env.py | 36 +- src/gflownet/models/graph_transformer.py | 6 +- src/gflownet/trainer.py | 40 +- src/gflownet/utils/misc.py | 2 +- src/gflownet/utils/multiprocessing_proxy.py | 10 +- src/gflownet/utils/sqlite_log.py | 5 +- 20 files changed, 3159 insertions(+), 44 deletions(-) create mode 100644 src/C/data.c create mode 100644 src/C/degree_view.c create mode 100644 src/C/edge_view.c create mode 100644 src/C/graph_def.c create mode 100644 src/C/main.c create mode 100644 src/C/main.h create mode 100644 src/C/mol_graph_to_Data.c create mode 100644 src/C/node_view.c diff --git a/setup.py b/setup.py index 26bcab4d..e5cccec7 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from ast import literal_eval from subprocess import check_output # nosec - command is hard-coded, no possibility of injection -from setuptools import setup +from setuptools import Extension, setup def _get_next_version(): @@ -24,5 +24,18 @@ def _get_next_version(): latest_patch = -1 return f"{major}.{minor}.{latest_patch+1}" - -setup(name="gflownet", version=_get_next_version()) +ext = [ + Extension( + name="gflownet._C", + sources=[ + "src/C/main.c", + "src/C/data.c", + "src/C/graph_def.c", + "src/C/node_view.c", + "src/C/edge_view.c", + "src/C/degree_view.c", + "src/C/mol_graph_to_Data.c", + ], + ) +] +setup(name="gflownet", version=_get_next_version(), ext_modules=ext) diff --git a/src/C/data.c b/src/C/data.c new file mode 100644 index 00000000..1b8e3c3c --- /dev/null +++ b/src/C/data.c @@ -0,0 +1,612 @@ +#define PY_SSIZE_T_CLEAN +#include "main.h" +#include "structmember.h" +#include + +static void Data_dealloc(Data *self) { + Py_XDECREF(self->bytes); + Py_XDECREF(self->graph_def); + free(self->shapes); + free(self->is_float); + free(self->offsets); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static PyObject *torch_module = NULL; +static PyObject *torch_gd_module = NULL; +static void _check_torch() { + if (torch_module == NULL) { + torch_module = PyImport_ImportModule("torch"); + torch_gd_module = PyImport_ImportModule("torch_geometric.data"); + } +} + +static PyObject *gbe_module = NULL; +static PyObject *Stop, *AddNode, *SetNodeAttr, *AddEdge, *SetEdgeAttr, *RemoveNode, *RemoveNodeAttr, *RemoveEdge, + *RemoveEdgeAttr, *GraphAction, *GraphActionType; +void _check_gbe() { + if (gbe_module == NULL) { + gbe_module = PyImport_ImportModule("gflownet.envs.graph_building_env"); + if (gbe_module == NULL) { + PyErr_SetString(PyExc_ImportError, "Could not import gflownet.envs.graph_building_env"); + return; + } + GraphActionType = PyObject_GetAttrString(gbe_module, "GraphActionType"); + GraphAction = PyObject_GetAttrString(gbe_module, "GraphAction"); + Stop = PyObject_GetAttrString(GraphActionType, "Stop"); + AddNode = PyObject_GetAttrString(GraphActionType, "AddNode"); + SetNodeAttr = PyObject_GetAttrString(GraphActionType, "SetNodeAttr"); + AddEdge = PyObject_GetAttrString(GraphActionType, "AddEdge"); + SetEdgeAttr = PyObject_GetAttrString(GraphActionType, "SetEdgeAttr"); + RemoveNode = PyObject_GetAttrString(GraphActionType, "RemoveNode"); + RemoveNodeAttr = PyObject_GetAttrString(GraphActionType, "RemoveNodeAttr"); + RemoveEdge = PyObject_GetAttrString(GraphActionType, "RemoveEdge"); + RemoveEdgeAttr = PyObject_GetAttrString(GraphActionType, "RemoveEdgeAttr"); + } +} + +static PyObject *Data_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + Data *self; + self = (Data *)type->tp_alloc(type, 0); + if (self != NULL) { + self->bytes = NULL; + self->graph_def = NULL; + self->shapes = NULL; + self->is_float = NULL; + self->num_matrices = 0; + self->names = NULL; + self->offsets = NULL; + } + return (PyObject *)self; +} + +static int Data_init(Data *self, PyObject *args, PyObject *kwds) { + if (args != NULL && 0) { + PyErr_SetString(PyExc_KeyError, "Trying to initialize a Data object from Python. Use *_graph_to_Data instead."); + return -1; + } + return 0; +} + +void *Data_ptr_and_shape(Data *self, char *name, int *n, int *m) { + for (int i = 0; i < self->num_matrices; i++) { + if (strcmp(self->names[i], name) == 0) { + *n = self->shapes[i * 2 + 0]; + *m = self->shapes[i * 2 + 1]; + return PyByteArray_AsString(self->bytes) + self->offsets[i]; + } + } + return (void *)0xdeadbeef; +} + +void Data_init_C(Data *self, PyObject *bytes, PyObject *graph_def, int shapes[][2], int *is_float, int num_matrices, + const char **names) { + PyObject *tmp; + tmp = (PyObject *)self->bytes; + Py_INCREF(bytes); + self->bytes = bytes; + Py_XDECREF(tmp); + + tmp = (PyObject *)self->graph_def; + Py_INCREF(graph_def); + self->graph_def = (GraphDef *)graph_def; + Py_XDECREF(tmp); + + self->shapes = malloc(sizeof(int) * num_matrices * 2); + self->offsets = malloc(sizeof(int) * num_matrices); + int offset_bytes = 0; + for (int i = 0; i < num_matrices; i++) { + self->shapes[i * 2 + 0] = shapes[i][0]; + self->shapes[i * 2 + 1] = shapes[i][1]; + self->offsets[i] = offset_bytes; + offset_bytes += shapes[i][0] * shapes[i][1] * (is_float[i] ? 4 : 8); + } + self->is_float = malloc(sizeof(int) * num_matrices); + memcpy(self->is_float, is_float, sizeof(int) * num_matrices); + self->num_matrices = num_matrices; + self->names = names; +} + +PyObject *Data_mol_aidx_to_GraphAction(PyObject *_self, PyObject *args) { + _check_gbe(); + int _t, row, col; // _t is unused, we have gt from python + PyObject *gt = NULL; + if (!PyArg_ParseTuple(args, "(iii)O", &_t, &row, &col, >)) { + return NULL; + } + Data *self = (Data *)_self; + // These are singletons, so we can just compare pointers! + if (gt == Stop) { + PyObject *args = PyTuple_Pack(1, Stop); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(args); + return res; + } + if (gt == AddNode) { + // return GraphAction(t, source=act_row, value=self.atom_attr_values["v"][act_col]) + PyObject *source = PyLong_FromLong(row); // ref is new + PyObject *vals = PyDict_GetItemString(self->graph_def->node_values, "v"); // ref is borrowed + PyObject *value = PyList_GetItem(vals, col); // ref is borrowed + PyObject *args = PyTuple_Pack(5, gt, source, Py_None, value, Py_None); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(source); + Py_DECREF(args); + return res; + } + if (gt == SetNodeAttr) { + // attr, val = self.atom_attr_logit_map[act_col] + // return GraphAction(t, source=act_row, attr=attr, value=val) + PyObject *attr, *val; + for (int i = 0; i < self->graph_def->num_settable_node_attrs; i++) { + int start = self->graph_def->node_attr_offsets[i] - i; // - i because these are logits + int end = self->graph_def->node_attr_offsets[i + 1] - (i + 1); + if (start <= col && col < end) { + attr = PyList_GetItem(self->graph_def->node_poskey, i); + PyObject *vals = PyDict_GetItem(self->graph_def->node_values, attr); + val = PyList_GetItem(vals, col - start + 1); // + 1 because the default is 0 + break; + } + } + PyObject *source = PyLong_FromLong(row); // ref is new + PyObject *args = PyTuple_Pack(5, gt, source, Py_None, val, attr); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(source); + Py_DECREF(args); + return res; + } + if (gt == AddEdge) { + int n, m; + long *non_edge_index = Data_ptr_and_shape(self, "non_edge_index", &n, &m); + PyObject *u = PyLong_FromLong(non_edge_index[row]); + PyObject *v = PyLong_FromLong(non_edge_index[row + m]); + PyObject *args = PyTuple_Pack(3, AddEdge, u, v); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(u); + Py_DECREF(v); + Py_DECREF(args); + return res; + } + if (gt == SetEdgeAttr) { + // attr, val = self.bond_attr_logit_map[act_col] + // return GraphAction(t, source=act_row, attr=attr, value=val) + int n, m; + long *edge_index = Data_ptr_and_shape(self, "edge_index", &n, &m); + // edge_index should be (2, m), * 2 because edges are duplicated + PyObject *u = PyLong_FromLong(edge_index[row * 2]); + PyObject *v = PyLong_FromLong(edge_index[row * 2 + m]); + PyObject *attr, *val = NULL; + for (int i = 0; i < self->graph_def->num_settable_edge_attrs; i++) { + int start = self->graph_def->edge_attr_offsets[i] - i; // - i because these are logits + int end = self->graph_def->edge_attr_offsets[i + 1] - (i + 1); + if (start <= col && col < end) { + attr = PyList_GetItem(self->graph_def->edge_poskey, i); + PyObject *vals = PyDict_GetItem(self->graph_def->edge_values, attr); + val = PyList_GetItem(vals, col - start + 1); // + 1 because the default is 0 + break; + } + } + if (val == NULL) { + PyErr_SetString(PyExc_ValueError, "failed to find edge attr"); + return NULL; + } + PyObject *args = PyTuple_Pack(5, gt, u, v, val, attr); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(u); + Py_DECREF(v); + Py_DECREF(args); + return res; + } + if (gt == RemoveNode) { + PyObject *source = PyLong_FromLong(row); // ref is new + PyObject *args = PyTuple_Pack(2, RemoveNode, source); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(source); + Py_DECREF(args); + return res; + } + if (gt == RemoveNodeAttr) { + PyObject *source = PyLong_FromLong(row); // ref is new + PyObject *attr = PyList_GetItem(self->graph_def->node_poskey, col); // this works because 'v' is last + PyObject *args = PyTuple_Pack(5, RemoveNodeAttr, source, Py_None, Py_None, attr); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(source); + Py_DECREF(args); + return res; + } + if (gt == RemoveEdge) { + int n, m; + long *edge_index = Data_ptr_and_shape(self, "edge_index", &n, &m); + // edge_index should be (2, m), * 2 because edges are duplicated + PyObject *u = PyLong_FromLong(edge_index[row * 2 + 0]); + PyObject *v = PyLong_FromLong(edge_index[row * 2 + m]); + PyObject *vargs = PyTuple_Pack(3, RemoveEdge, u, v); + PyObject *res = PyObject_CallObject(GraphAction, vargs); + Py_DECREF(u); + Py_DECREF(v); + Py_DECREF(vargs); + return res; + } + if (gt == RemoveEdgeAttr) { + int n, m; + long *edge_index = Data_ptr_and_shape(self, "edge_index", &n, &m); + // edge_index should be (2, m), * 2 because edges are duplicated + PyObject *u = PyLong_FromLong(edge_index[row * 2 + 0]); + PyObject *v = PyLong_FromLong(edge_index[row * 2 + m]); + PyObject *attr = PyList_GetItem(self->graph_def->edge_poskey, col); + PyObject *args = PyTuple_Pack(5, RemoveEdgeAttr, u, v, Py_None, attr); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(u); + Py_DECREF(v); + Py_DECREF(args); + return res; + } + PyErr_SetString(PyExc_ValueError, "Unknown action type"); + return NULL; +} + +PyObject *Data_mol_GraphAction_to_aidx(PyObject *_self, PyObject *args) { + _check_gbe(); + + PyObject *action = NULL; + if (!PyArg_ParseTuple(args, "O", &action)) { + return NULL; + } + Data *self = (Data *)_self; + PyObject *action_type = PyObject_GetAttrString(action, "action"); // new ref + // These are singletons, so we can just compare pointers! + long row = 0, col = 0; + if (action_type == Stop) { + // return (0, 0) + } else if (action_type == AddNode) { + row = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + PyObject *val = PyObject_GetAttrString(action, "value"); + PyObject *vals = PyDict_GetItemString(self->graph_def->node_values, "v"); + col = PySequence_Index(vals, val); + Py_DECREF(val); + } else if (action_type == SetNodeAttr) { + row = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + PyObject *attr = PyObject_GetAttrString(action, "attr"); + PyObject *val = PyObject_GetAttrString(action, "value"); + PyObject *vals = PyDict_GetItem(self->graph_def->node_values, attr); + col = PySequence_Index(vals, val) - 1; // -1 because the default value is at index 0 + int attr_pos = PyLong_AsLong(PyDict_GetItem(self->graph_def->node_keypos, attr)); + col += self->graph_def->node_attr_offsets[attr_pos] - attr_pos; + Py_DECREF(attr); + Py_DECREF(val); + } else if (action_type == AddEdge) { + col = 0; + int u = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + int v = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "target")); + int n, m; + long *non_edge_index = Data_ptr_and_shape(self, "non_edge_index", &n, &m); + for (int i = 0; i < m; i++) { + if ((non_edge_index[i] == u && non_edge_index[i + m] == v) || + (non_edge_index[i] == v && non_edge_index[i + m] == u)) { + row = i; + break; + } + } + } else if (action_type == SetEdgeAttr) { + int u = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + int v = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "target")); + int n, m; + long *edge_index = Data_ptr_and_shape(self, "edge_index", &n, &m); + for (int i = 0; i < m; i++) { + if ((edge_index[i] == u && edge_index[i + m] == v) || (edge_index[i] == v && edge_index[i + m] == u)) { + row = i / 2; // edges are duplicated + break; + } + } + PyObject *attr = PyObject_GetAttrString(action, "attr"); + PyObject *val = PyObject_GetAttrString(action, "value"); + PyObject *vals = PyDict_GetItem(self->graph_def->edge_values, attr); + col = PySequence_Index(vals, val) - 1; // -1 because the default value is at index 0 + int attr_pos = PyLong_AsLong(PyDict_GetItem(self->graph_def->edge_keypos, attr)); + col += self->graph_def->edge_attr_offsets[attr_pos] - attr_pos; + Py_DECREF(attr); + Py_DECREF(val); + } else if (action_type == RemoveNode) { + row = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + col = 0; + } else if (action_type == RemoveNodeAttr) { + row = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + PyObject *attr = PyObject_GetAttrString(action, "attr"); + col = PyLong_AsLong(PyDict_GetItem(self->graph_def->node_keypos, attr)); + } else if (action_type == RemoveEdge) { + col = 0; + int u = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + int v = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "target")); + int n, m; + long *edge_index = Data_ptr_and_shape(self, "edge_index", &n, &m); + for (int i = 0; i < m; i++) { + if ((edge_index[i] == u && edge_index[i + m] == v) || (edge_index[i] == v && edge_index[i + m] == u)) { + row = i / 2; // edges are duplicated + break; + } + } + } else if (action_type == RemoveEdgeAttr) { + int u = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + int v = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "target")); + int n, m; + long *edge_index = Data_ptr_and_shape(self, "edge_index", &n, &m); + for (int i = 0; i < m; i++) { + if ((edge_index[i] == u && edge_index[i + m] == v) || (edge_index[i] == v && edge_index[i + m] == u)) { + row = i / 2; // edges are duplicated + break; + } + } + PyObject *attr = PyObject_GetAttrString(action, "attr"); + col = PyLong_AsLong(PyDict_GetItem(self->graph_def->edge_keypos, attr)); + } else { + PyErr_SetString(PyExc_ValueError, "Unknown action type"); + return NULL; + } + PyObject *py_row = PyLong_FromLong(row); + PyObject *py_col = PyLong_FromLong(col); + PyObject *res = PyTuple_Pack(2, py_row, py_col); + Py_DECREF(py_row); + Py_DECREF(py_col); + Py_DECREF(action_type); + return res; +} + +PyObject *Data_as_torch(PyObject *_self, PyObject *unused_args) { + _check_torch(); + Data *self = (Data *)_self; + PyObject *res = PyDict_New(); + PyObject *frombuffer = PyObject_GetAttrString(torch_module, "frombuffer"); + PyObject *empty = PyObject_GetAttrString(torch_module, "empty"); + PyObject *dtype_f32 = PyObject_GetAttrString(torch_module, "float32"); + PyObject *dtype_i64 = PyObject_GetAttrString(torch_module, "int64"); + PyObject *fb_args = PyTuple_Pack(1, self->bytes); + PyObject *fb_kwargs = PyDict_New(); + int do_del_kw = 0; + int offset = 0; + for (int i = 0; i < self->num_matrices; i++) { + int i_num_items = self->shapes[i * 2 + 0] * self->shapes[i * 2 + 1]; + PyObject *tensor; + PyDict_SetItemString(fb_kwargs, "dtype", self->is_float[i] ? dtype_f32 : dtype_i64); + if (i_num_items == 0) { + if (do_del_kw) { + PyDict_DelItemString(fb_kwargs, "offset"); + PyDict_DelItemString(fb_kwargs, "count"); + do_del_kw = 0; + } + PyObject *zero = PyLong_FromLong(0); + PyObject *args = PyTuple_Pack(1, zero); + tensor = PyObject_Call(empty, args, fb_kwargs); + Py_DECREF(args); + Py_DECREF(zero); + } else { + PyObject *py_offset = PyLong_FromLong(offset); + PyObject *py_numi = PyLong_FromLong(i_num_items); + PyDict_SetItemString(fb_kwargs, "offset", py_offset); + PyDict_SetItemString(fb_kwargs, "count", py_numi); + Py_DECREF(py_offset); + Py_DECREF(py_numi); + do_del_kw = 1; + tensor = PyObject_Call(frombuffer, fb_args, fb_kwargs); + } + PyObject *reshaped_tensor = + PyObject_CallMethod(tensor, "view", "ii", self->shapes[i * 2 + 0], self->shapes[i * 2 + 1]); + PyDict_SetItemString(res, self->names[i], reshaped_tensor); + Py_DECREF(tensor); + Py_DECREF(reshaped_tensor); + offset += i_num_items * (self->is_float[i] ? 4 : 8); + } + Py_DECREF(frombuffer); + Py_DECREF(dtype_f32); + Py_DECREF(dtype_i64); + Py_DECREF(fb_args); + Py_DECREF(fb_kwargs); + + PyObject *Data_cls = PyObject_GetAttrString(torch_gd_module, "Data"); // new ref + PyObject *args = PyTuple_New(0); + PyObject *Data_res = PyObject_Call(Data_cls, args, res); + Py_DECREF(args); + Py_DECREF(Data_cls); + Py_DECREF(res); + return Data_res; +} + +static PyMemberDef Data_members[] = { + {NULL} /* Sentinel */ +}; + +static PyMethodDef Data_methods[] = { + {"mol_aidx_to_GraphAction", (PyCFunction)Data_mol_aidx_to_GraphAction, METH_VARARGS, "mol_aidx_to_GraphAction"}, + {"mol_GraphAction_to_aidx", (PyCFunction)Data_mol_GraphAction_to_aidx, METH_VARARGS, "mol_GraphAction_to_aidx"}, + {"as_torch", (PyCFunction)Data_as_torch, METH_NOARGS, "to pyg data"}, + {NULL} /* Sentinel */ +}; + +PyTypeObject DataType = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "_C.Data", + .tp_doc = PyDoc_STR("Constrained Data object"), + .tp_basicsize = sizeof(Data), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_new = Data_new, + .tp_init = (initproc)Data_init, + .tp_dealloc = (destructor)Data_dealloc, + .tp_members = Data_members, + .tp_methods = Data_methods, +}; + +#include + +PyObject *Data_collate(PyObject *self, PyObject *args) { + PyObject *graphs = NULL, *follow_batch = NULL; + if (!PyArg_ParseTuple(args, "O|O", &graphs, &follow_batch)) { + return NULL; + } + if (!PyList_Check(graphs) || PyList_Size(graphs) < 1 || !PyObject_TypeCheck(PyList_GetItem(graphs, 0), &DataType)) { + PyErr_SetString(PyExc_TypeError, "Data_collate expects a non-empty list of Data objects"); + return NULL; + } + _check_torch(); + PyObject *empty_cb = PyObject_GetAttrString(torch_module, "empty"); + PyObject *empty_cb_int64_arg = PyDict_New(); + PyDict_SetItemString(empty_cb_int64_arg, "dtype", PyObject_GetAttrString(torch_module, "int64")); + int num_graphs = PyList_Size(graphs); + // PyObject *batch = PyObject_CallMethod(torch_gd_module, "Batch", NULL); + // Actually we want batch = Batch(_base_cls=graphs[0].__class__) + PyObject *base_cls = PyObject_GetAttrString(PyList_GetItem(graphs, 0), "__class__"); + PyObject *batch = PyObject_CallMethod(torch_gd_module, "Batch", "O", base_cls); + PyObject *slice_dict = PyDict_New(); + Data *first = (Data *)PyList_GetItem(graphs, 0); + int index_of_x = 0; + for (int i = 0; i < first->num_matrices; i++) { + if (strcmp(first->names[i], "x") == 0) { + index_of_x = i; + break; + } + } + for (int i = 0; i < first->num_matrices; i++) { + PyObject *tensor = NULL; + const char *name = first->names[i]; + int do_compute_batch = 0; + if (strcmp(name, "x") == 0) { + do_compute_batch = 1; + } else if (follow_batch != NULL) { + PyObject *py_name = PyUnicode_FromString(name); + do_compute_batch = PySequence_Contains(follow_batch, py_name); + Py_DECREF(py_name); + } + int item_size = (first->is_float[i] ? sizeof(float) : sizeof(long)); + // First let's count the total number of items in the batch + int num_rows = 0; + int cat_dim = 0, val_dim = 1; + if (strstr(name, "index") != NULL) { + cat_dim = 1; + val_dim = 0; + } + int val_dim_size = first->shapes[i * 2 + val_dim]; + for (int j = 0; j < num_graphs; j++) { + Data *data = (Data *)PyList_GetItem(graphs, j); + // We're concatenating along the first dimension, so we should check that the second dimension is the same + if (data->shapes[i * 2 + val_dim] != val_dim_size) { + PyErr_Format(PyExc_TypeError, + "mol_Data_collate concatenates %s along dimension %d, but tensor has shape %d along " + "dimension %d", + name, cat_dim, data->shapes[i * 2 + val_dim], val_dim); + return NULL; + } + num_rows += data->shapes[i * 2 + cat_dim]; + } + // Now we allocate the tensor itself, its batch tensor & slices + PyObject *py_num_rows = PyLong_FromLong(num_rows); + PyObject *py_val_dim_size = PyLong_FromLong(val_dim_size); + PyObject *py_shape = cat_dim == 0 ? PyTuple_Pack(2, py_num_rows, py_val_dim_size) + : PyTuple_Pack(2, py_val_dim_size, py_num_rows); + int tensor_shape_x = cat_dim == 0 ? num_rows : val_dim_size; + int tensor_shape_y = cat_dim == 0 ? val_dim_size : num_rows; + if (first->is_float[i]) { + tensor = PyObject_Call(empty_cb, py_shape, NULL); + } else { + tensor = PyObject_Call(empty_cb, py_shape, empty_cb_int64_arg); + } + int _total_ni = val_dim_size * num_rows; + Py_DECREF(py_shape); + PyObject *batch_tensor = NULL; + if (do_compute_batch) { + py_shape = PyTuple_Pack(1, py_num_rows); + batch_tensor = PyObject_Call(empty_cb, py_shape, empty_cb_int64_arg); + Py_DECREF(py_shape); + } + PyObject *py_num_graphsp1 = PyLong_FromLong(num_graphs + 1); + py_shape = PyTuple_Pack(1, py_num_graphsp1); + PyObject *slice_tensor = PyObject_Call(empty_cb, py_shape, empty_cb_int64_arg); + Py_DECREF(py_shape); + Py_DECREF(py_num_rows); + Py_DECREF(py_val_dim_size); + Py_DECREF(py_num_graphsp1); + PyObject *ptr = PyObject_CallMethod(tensor, "data_ptr", ""); + void *tensor_ptr = PyLong_AsVoidPtr(ptr); + Py_DECREF(ptr); + long *batch_tensor_ptr; + if (do_compute_batch) { + ptr = PyObject_CallMethod(batch_tensor, "data_ptr", ""); + batch_tensor_ptr = PyLong_AsVoidPtr(ptr); + Py_DECREF(ptr); + } + ptr = PyObject_CallMethod(slice_tensor, "data_ptr", ""); + long *slice_ptr = PyLong_AsVoidPtr(ptr); + Py_DECREF(ptr); + // Now we copy the data from the individual Data objects to the batch + int offset_bytes = 0; + int offset_items = 0; + int offset_rows = 0; + slice_ptr[0] = 0; + int value_increment = 0; // we need to increment edge indices across graphs + int do_increment = strstr(name, "index") != NULL; + for (int j = 0; j < num_graphs; j++) { + Data *data = (Data *)PyList_GetItem(graphs, j); // borrowed ref + int num_items_j = data->shapes[i * 2 + 0] * data->shapes[i * 2 + 1]; + if (cat_dim == 0) { + // we're given a 2d matrix m, n and we want to fit it into a bigger matrix M, n + void *dst = tensor_ptr; + void *src = PyByteArray_AsString(data->bytes) + data->offsets[i]; + for (int u = 0; u < data->shapes[i * 2 + 0]; u++) { + for (int v = 0; v < data->shapes[i * 2 + 1]; v++) { + // dst[u + offset_rows, v] = src[u, v] + if (first->is_float[i]) { + ((float *)dst)[(u + offset_rows) * tensor_shape_y + v] = + ((float *)src)[u * data->shapes[i * 2 + 1] + v]; + } else { + ((long *)dst)[(u + offset_rows) * tensor_shape_y + v] = + ((long *)src)[u * data->shapes[i * 2 + 1] + v] + value_increment; + } + } + } + } else { + // we're given a 2d matrix m, n and we want to fit it into a bigger matrix m, N + void *dst = tensor_ptr; + void *src = PyByteArray_AsString(data->bytes) + data->offsets[i]; + for (int u = 0; u < data->shapes[i * 2 + 0]; u++) { + for (int v = 0; v < data->shapes[i * 2 + 1]; v++) { + // dst[u, v + offset_rows] = src[u, v] + if (first->is_float[i]) { + ((float *)dst)[u * tensor_shape_y + v + offset_rows] = + ((float *)src)[u * data->shapes[i * 2 + 1] + v]; + } else { + ((long *)dst)[u * tensor_shape_y + v + offset_rows] = + ((long *)src)[u * data->shapes[i * 2 + 1] + v] + value_increment; + } + } + } + } + if (do_compute_batch) { + for (int k = 0; k < data->shapes[i * 2 + cat_dim]; k++) { + batch_tensor_ptr[k + offset_rows] = j; + } + } + offset_rows += data->shapes[i * 2 + cat_dim]; + offset_items += num_items_j; + offset_bytes += num_items_j * item_size; + slice_ptr[j + 1] = offset_rows; + if (do_increment) { + value_increment += data->shapes[index_of_x * 2]; // increment by num_nodes + } + } + PyObject_SetAttrString(batch, name, tensor); + // for x, the batch is just 'batch', otherwise '%s_batch' + if (strcmp(name, "x") == 0) { + PyObject_SetAttrString(batch, "batch", batch_tensor); + } else if (do_compute_batch) { + char buf[100]; + sprintf(buf, "%s_batch", name); + PyObject_SetAttrString(batch, buf, batch_tensor); + } + PyDict_SetItemString(slice_dict, name, slice_tensor); + Py_DECREF(tensor); + Py_XDECREF(batch_tensor); + Py_DECREF(slice_tensor); + } + PyObject_SetAttrString(batch, "_slice_dict", slice_dict); + Py_DECREF(empty_cb); + Py_DECREF(empty_cb_int64_arg); + Py_DECREF(slice_dict); + + return batch; +} diff --git a/src/C/degree_view.c b/src/C/degree_view.c new file mode 100644 index 00000000..7e752946 --- /dev/null +++ b/src/C/degree_view.c @@ -0,0 +1,70 @@ +#define PY_SSIZE_T_CLEAN +#include "./main.h" +#include "structmember.h" +#include + +static void DegreeView_dealloc(DegreeView *self) { + Py_XDECREF(self->graph); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static PyObject *DegreeView_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + DegreeView *self; + self = (DegreeView *)type->tp_alloc(type, 0); + if (self != NULL) { + self->graph = NULL; + } + return (PyObject *)self; +} + +static int DegreeView_init(DegreeView *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"graph", NULL}; + PyObject *graph = NULL, *tmp; + int node_id = -1; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &graph, &node_id)) + return -1; + + if (graph) { + tmp = (PyObject *)self->graph; + Py_INCREF(graph); + self->graph = (Graph *)graph; + Py_XDECREF(tmp); + } + return 0; +} + +static PyMemberDef DegreeView_members[] = { + {NULL} /* Sentinel */ +}; + +static PyObject *DegreeView_getitem(PyObject *_self, PyObject *k) { + DegreeView *self = (DegreeView *)_self; + int index = PyLong_AsLong(k); + if (PyErr_Occurred()) { + return NULL; + } + return PyLong_FromLong(self->graph->degrees[index]); +} + +static PyMappingMethods DegreeView_mapmeth = { + .mp_subscript = DegreeView_getitem, +}; + +static PyMethodDef DegreeView_methods[] = { + {NULL} /* Sentinel */ +}; + +PyTypeObject DegreeViewType = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "_C.DegreeView", + .tp_doc = PyDoc_STR("DegreeView object"), + .tp_basicsize = sizeof(DegreeView), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_new = DegreeView_new, + .tp_init = (initproc)DegreeView_init, + .tp_dealloc = (destructor)DegreeView_dealloc, + .tp_members = DegreeView_members, + .tp_methods = DegreeView_methods, + .tp_as_mapping = &DegreeView_mapmeth, +}; \ No newline at end of file diff --git a/src/C/edge_view.c b/src/C/edge_view.c new file mode 100644 index 00000000..6cc40f67 --- /dev/null +++ b/src/C/edge_view.c @@ -0,0 +1,274 @@ +#include "./main.h" +#include "structmember.h" + +static void EdgeView_dealloc(EdgeView *self) { + Py_XDECREF(self->graph); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static PyObject *EdgeView_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + EdgeView *self; + self = (EdgeView *)type->tp_alloc(type, 0); + if (self != NULL) { + self->graph = NULL; + } + return (PyObject *)self; +} + +static int EdgeView_init(EdgeView *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"graph", "index", NULL}; + PyObject *graph = NULL, *tmp; + int index = -1; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|i", kwlist, &graph, &index)) + return -1; + + if (graph) { + tmp = (PyObject *)self->graph; + Py_INCREF(graph); + self->graph = (Graph *)graph; + Py_XDECREF(tmp); + self->index = index; + } + return 0; +} + +static PyMemberDef EdgeView_members[] = { + {NULL} /* Sentinel */ +}; + +static int EdgeView_setitem(PyObject *_self, PyObject *k, PyObject *v) { + EdgeView *self = (EdgeView *)_self; + if (self->index < 0) { + PyErr_SetString(PyExc_KeyError, "Cannot assign to a node"); + return -1; + } + PyObject *r = Graph_setedgeattr(self->graph, self->index, k, v); + return r == NULL ? -1 : 0; +} + +int get_edge_index_from_pos(Graph *g, int u_pos, int v_pos) { + for (int i = 0; i < g->num_edges; i++) { + if ((g->edges[2 * i] == u_pos && g->edges[2 * i + 1] == v_pos) || + (g->edges[2 * i] == v_pos && g->edges[2 * i + 1] == u_pos)) { + return i; + } + } + return -2; +} + +int get_edge_index(Graph *g, int u, int v) { + if (u > v) { + int tmp = u; + u = v; + v = tmp; + } + int u_pos = -1; + int v_pos = -1; + for (int i = 0; i < g->num_nodes; i++) { + if (g->nodes[i] == u) { + u_pos = i; + } else if (g->nodes[i] == v) { + v_pos = i; + } + if (u_pos >= 0 && v_pos >= 0) { + break; + } + } + return get_edge_index_from_pos(g, u_pos, v_pos); +} + +int get_edge_index_py(Graph *g, PyObject *k) { + if (!PyTuple_Check(k) || PyTuple_Size(k) != 2) { + PyErr_SetString(PyExc_KeyError, "EdgeView key must be a tuple of length 2"); + return -1; + } + int u = PyLong_AsLong(PyTuple_GetItem(k, 0)); + int v = PyLong_AsLong(PyTuple_GetItem(k, 1)); + return get_edge_index(g, u, v); +} + +static PyObject *EdgeView_getitem(PyObject *_self, PyObject *k) { + EdgeView *self = (EdgeView *)_self; + if (self->index < 0) { + int edge_idx = get_edge_index_py(self->graph, k); + if (edge_idx < 0) { + if (edge_idx == -2) + PyErr_SetString(PyExc_KeyError, "Edge not found"); + return NULL; + } + PyObject *idx = PyLong_FromLong(edge_idx); + PyObject *args = PyTuple_Pack(2, self->graph, idx); + PyObject *res = PyObject_CallObject((PyObject *)&EdgeViewType, args); + Py_DECREF(args); + Py_DECREF(idx); + return res; + } + return Graph_getedgeattr(self->graph, self->index, k); +} + +static int EdgeView_contains(PyObject *_self, PyObject *v) { + EdgeView *self = (EdgeView *)_self; + if (self->index < 0) { + int index = get_edge_index_py(self->graph, + v); // Returns -2 if not found, -1 on error + if (index == -1) { + return -1; // There was an error + } + return index >= 0; + } + PyObject *attr = Graph_getedgeattr(self->graph, self->index, v); + if (attr == NULL) { + PyErr_Clear(); + return 0; + } + Py_DECREF(attr); + return 1; +} +PyObject *EdgeView_iter(PyObject *_self) { + EdgeView *self = (EdgeView *)_self; + if (self->index != -1) { + PyErr_SetString(PyExc_TypeError, "A bound EdgeView is not iterable"); + return NULL; + } + Py_INCREF(_self); // We have to return a new reference, not a borrowed one + return _self; +} + +PyObject *EdgeView_iternext(PyObject *_self) { + EdgeView *self = (EdgeView *)_self; + self->index++; + if (self->index >= self->graph->num_edges) { + return NULL; + } + int u = self->graph->nodes[self->graph->edges[2 * self->index]]; + int v = self->graph->nodes[self->graph->edges[2 * self->index + 1]]; + PyObject *pu = PyLong_FromLong(u); + PyObject *pv = PyLong_FromLong(v); + PyObject *res = PyTuple_Pack(2, pu, pv); + Py_DECREF(pu); + Py_DECREF(pv); + return res; +} + +Py_ssize_t EdgeView_len(PyObject *_self) { + EdgeView *self = (EdgeView *)_self; + if (self->index != -1) { + // This is inefficient, much like a lot of this codebase... but it's in C. For our graph sizes + // it's not a big deal (and unsurprisingly, it's fast) + Py_ssize_t num_attrs = 0; + for (int i = 0; i < self->graph->num_edge_attrs; i++) { + if (self->graph->edge_attrs[3 * i] == self->index) { + num_attrs++; + } + } + return num_attrs; + } + return self->graph->num_edges; +} + +PyObject *EdgeView_get(PyObject *_self, PyObject *args) { + PyObject *key; + PyObject *default_value = Py_None; + if (!PyArg_ParseTuple(args, "O|O", &key, &default_value)) { + return NULL; + } + PyObject *res = EdgeView_getitem(_self, key); + if (res == NULL) { + PyErr_Clear(); + Py_INCREF(default_value); + return default_value; + } + return res; +} +int _EdgeView_eq(EdgeView *self, EdgeView *other) { + if (self->index == -1 || other->index == -1) { + return -1; + } + + int num_settable = ((GraphDef*)self->graph->graph_def)->num_settable_edge_attrs; + int self_values[num_settable]; + int other_values[num_settable]; + for (int i = 0; i < num_settable; i++) { + self_values[i] = -1; + other_values[i] = -1; + } + for (int i = 0; i < self->graph->num_edge_attrs; i++) { + if (self->graph->edge_attrs[3 * i] == self->index) { + self_values[self->graph->edge_attrs[3 * i + 1]] = self->graph->edge_attrs[3 * i + 2]; + } + } + for (int i = 0; i < other->graph->num_edge_attrs; i++) { + if (other->graph->edge_attrs[3 * i] == other->index) { + other_values[other->graph->edge_attrs[3 * i + 1]] = other->graph->edge_attrs[3 * i + 2]; + } + } + for (int i = 0; i < num_settable; i++) { + if (self_values[i] != other_values[i]) { + return 0; + } + } + return 1; +} + +static PyObject* +EdgeView_richcompare(PyObject *_self, PyObject *other, int op) { + PyObject *result = NULL; + EdgeView *self = (EdgeView *)_self; + + if (other == NULL || other->ob_type != &EdgeViewType) { + result = Py_NotImplemented; + } + else { + if (op == Py_EQ){ + if (self->index == -1 || ((EdgeView *)other)->index == -1){ + PyErr_SetString(PyExc_TypeError, "EdgeView.__eq__ only supports equality comparison for bound EdgeViews"); + return NULL; + } + EdgeView *other_node = (EdgeView *)other; + if (_EdgeView_eq(self, other_node) == 1){ + result = Py_True; + } else { + result = Py_False; + } + }else{ + PyErr_SetString(PyExc_TypeError, "EdgeView only supports equality comparison"); + return NULL; + } + } + + Py_XINCREF(result); + return result; +} + +static PyMappingMethods EdgeView_mapmeth = { + .mp_subscript = EdgeView_getitem, + .mp_ass_subscript = EdgeView_setitem, +}; + +static PySequenceMethods EdgeView_seqmeth = { + .sq_contains = EdgeView_contains, + .sq_length = EdgeView_len, +}; + +static PyMethodDef EdgeView_methods[] = { + {"get", (PyCFunction)EdgeView_get, METH_VARARGS, "dict-like get"}, {NULL} /* Sentinel */ +}; + +PyTypeObject EdgeViewType = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "_C.EdgeView", + .tp_doc = PyDoc_STR("Constrained EdgeView object"), + .tp_basicsize = sizeof(EdgeView), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_new = EdgeView_new, + .tp_init = (initproc)EdgeView_init, + .tp_dealloc = (destructor)EdgeView_dealloc, + .tp_members = EdgeView_members, + .tp_methods = EdgeView_methods, + .tp_as_mapping = &EdgeView_mapmeth, + .tp_as_sequence = &EdgeView_seqmeth, + .tp_iter = EdgeView_iter, + .tp_iternext = EdgeView_iternext, + .tp_richcompare = EdgeView_richcompare, +}; \ No newline at end of file diff --git a/src/C/graph_def.c b/src/C/graph_def.c new file mode 100644 index 00000000..bea85c35 --- /dev/null +++ b/src/C/graph_def.c @@ -0,0 +1,152 @@ +#define PY_SSIZE_T_CLEAN +#include "./main.h" +#include "structmember.h" +#include + +static void GraphDef_dealloc(GraphDef *self) { + Py_XDECREF(self->node_values); + Py_XDECREF(self->node_keypos); + Py_XDECREF(self->node_poskey); + Py_XDECREF(self->edge_values); + Py_XDECREF(self->edge_keypos); + Py_XDECREF(self->edge_poskey); + free(self->node_attr_offsets); + free(self->edge_attr_offsets); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static PyObject *GraphDef_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + GraphDef *self; + self = (GraphDef *)type->tp_alloc(type, 0); + if (self != NULL) { + self->node_values = NULL; + self->node_keypos = NULL; + self->node_poskey = NULL; + self->edge_values = NULL; + self->edge_keypos = NULL; + self->edge_poskey = NULL; + self->node_attr_offsets = NULL; + self->edge_attr_offsets = NULL; + } + return (PyObject *)self; +} + +static int GraphDef_init(GraphDef *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"node_values", "edge_values", NULL}; + PyObject *node_values = NULL, *edge_values = NULL; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", kwlist, &node_values, &edge_values)) + return -1; + + if (!(node_values && edge_values)) { + return -1; + } + if (!(PyDict_Check(node_values) && PyDict_Check(edge_values))) { + PyErr_SetString(PyExc_TypeError, "GraphDef: node_values and edge_values must be dicts"); + return -1; + } + PyObject *v_list = PyDict_GetItemString(node_values, "v"); + if (!v_list) { + PyErr_SetString(PyExc_TypeError, "GraphDef: node_values must contain 'v'"); + return -1; + } + SELF_SET(node_values); + SELF_SET(edge_values); + self->node_keypos = PyDict_New(); + // We want to create a dict, node_keypos = {k: i for i, k in enumerate(sorted(node_values.keys()))} + // First get the keys + PyObject *keys = PyDict_Keys(node_values); + // Then sort them + PyObject *sorted_keys = PySequence_List(keys); + Py_DECREF(keys); + PyList_Sort(sorted_keys); + // Then create the dict + Py_ssize_t len = PyList_Size(sorted_keys); + self->node_attr_offsets = malloc(sizeof(int) * (len + 1)); + self->node_attr_offsets[0] = 0; + // TODO: here this works because 'v' is last in the list, if it's not + // node_logit_offsets[i] != node_attr_offsets[i] - i, which we're using a lot as an assumption + // throughout this and mol_graph_to_Data + Py_ssize_t offset = 0; + for (Py_ssize_t i = 0; i < len; i++) { + PyObject *k = PyList_GetItem(sorted_keys, i); + PyDict_SetItem(self->node_keypos, k, PyLong_FromSsize_t(i)); + PyObject *vals = PyDict_GetItem(node_values, k); + if (!PyList_Check(vals)) { + PyErr_SetString(PyExc_TypeError, "GraphDef: node_values must be a Dict[str, List[Any]]"); + return -1; + } + offset += PyList_Size(vals); + self->node_attr_offsets[i + 1] = offset; + if (i == len - 1 && strcmp(PyUnicode_AsUTF8(k), "v") != 0) { + PyErr_SetString(PyExc_TypeError, "GraphDef: 'v' must be the last sorted key in current implementation"); + return -1; + } + } + self->num_node_dim = offset + 1; // + 1 for the empty graph + self->num_settable_node_attrs = len - 1; // 'v' is not settable by setnodeattr but only by addnode + self->num_new_node_values = PyList_Size(v_list); + self->num_node_attr_logits = offset - (len - 1) - self->num_new_node_values; // 'v' is not settable + self->node_poskey = sorted_keys; + + // Repeat for edge_keypos + self->edge_keypos = PyDict_New(); + keys = PyDict_Keys(edge_values); + sorted_keys = PySequence_List(keys); + Py_DECREF(keys); + PyList_Sort(sorted_keys); + len = PyList_Size(sorted_keys); + self->edge_attr_offsets = malloc(sizeof(int) * (len + 1)); + self->edge_attr_offsets[0] = 0; + offset = 0; + for (Py_ssize_t i = 0; i < len; i++) { + PyObject *k = PyList_GetItem(sorted_keys, i); + PyDict_SetItem(self->edge_keypos, k, PyLong_FromSsize_t(i)); + PyObject *vals = PyDict_GetItem(edge_values, k); + if (!PyList_Check(vals)) { + PyErr_SetString(PyExc_TypeError, "GraphDef: edge_values must be a Dict[str, List[Any]]"); + return -1; + } + offset += PyList_Size(vals); + self->edge_attr_offsets[i + 1] = offset; + } + self->num_edge_dim = offset; + self->num_settable_edge_attrs = len; + self->num_edge_attr_logits = offset - len; + self->edge_poskey = sorted_keys; + + return 0; +} + +PyObject *GraphDef___getstate__(GraphDef *self, PyObject *args) { + return PyTuple_Pack(1, Py_BuildValue("OO", self->node_values, self->edge_values)); +} + +PyObject *GraphDef___setstate__(GraphDef *self, PyObject *state) { + // new() was just called on self + GraphDef_init(self, PyTuple_GetItem(state, 0), NULL); + Py_RETURN_NONE; +} + +static PyMemberDef GraphDef_members[] = { + {NULL} /* Sentinel */ +}; + +static PyMethodDef GraphDef_methods[] = { + {"__getstate__", (PyCFunction)GraphDef___getstate__, METH_NOARGS, "Pickle the Custom object"}, + {"__setstate__", (PyCFunction)GraphDef___setstate__, METH_O, "Un-pickle the Custom object"}, + {NULL} /* Sentinel */ +}; + +PyTypeObject GraphDefType = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "gflownet._C.GraphDef", + .tp_doc = PyDoc_STR("GraphDef object"), + .tp_basicsize = sizeof(GraphDef), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_new = GraphDef_new, + .tp_init = (initproc)GraphDef_init, + .tp_dealloc = (destructor)GraphDef_dealloc, + .tp_members = GraphDef_members, + .tp_methods = GraphDef_methods, +}; \ No newline at end of file diff --git a/src/C/main.c b/src/C/main.c new file mode 100644 index 00000000..d5bdcaa4 --- /dev/null +++ b/src/C/main.c @@ -0,0 +1,1143 @@ +#include "main.h" +#include "structmember.h" +#include + +// #define GRAPH_DEBUG +#ifdef GRAPH_DEBUG +void *__real_malloc(size_t size); +void __real_free(void *ptr); +void *__real_realloc(void *ptr, size_t size); + +long total_mem_alloc = 0; +long cnt = 0; +long things_alloc = 0; +/* build with LDFLAGS="-Wl,--wrap,malloc -Wl,--wrap,free -Wl,--wrap,realloc" python setup.py build_ext --inplace + * __wrap_malloc - malloc wrapper function + */ +void *__wrap_malloc(size_t size) { + void *ptr = __real_malloc(size + sizeof(long) * 2); + void *ptr2 = ptr + size; + *(long *)ptr2 = size; + *(long *)(ptr2 + 8) = 0x14285700deadbeef; + // printf("malloc(%d) = %p\n", size, ptr); + total_mem_alloc += size; + things_alloc++; + if (cnt++ % 10000 == 0) { + printf("total_mem_alloc: %ld [%ld]\n", total_mem_alloc, things_alloc); + } + return ptr; +} + +/* + * __wrap_free - free wrapper function + */ +void __wrap_free(void *ptr) { + if (ptr == NULL) { + return; + } + int size = -8; + void *ptr2 = ptr; + while (*(long *)ptr2 != 0x14285700deadbeef) { + ptr2++; + size++; + } + if (size != *(long *)(ptr2 - 8)) { + printf("size mismatch?: %d != %ld\n", size, *(long *)(ptr2 - 8)); + abort(); + } + __real_free(ptr); + // printf("free(%p)\n", ptr); + total_mem_alloc -= size; + things_alloc--; +} + +void *__wrap_realloc(void *ptr, size_t new_size) { + if (ptr == NULL) { + return __wrap_malloc(new_size); + } + int size = -8; + void *ptr2 = ptr; + while (*(long *)ptr2 != 0x14285700deadbeef) { + ptr2++; + size++; + } + *(long *)ptr2 = 0; // erase marker + if (size != *(long *)(ptr2 - 8)) { + printf("size mismatch?: %d != %ld\n", size, *(long *)(ptr2 - 8)); + abort(); + } + void *newptr = __real_realloc(ptr, new_size + sizeof(long) * 2); + void *ptr3 = newptr + new_size; + *(long *)ptr3 = new_size; + *(long *)(ptr3 + 8) = 0x14285700deadbeef; + total_mem_alloc -= size; + total_mem_alloc += new_size; + return newptr; +} + +#endif + +static void Graph_dealloc(Graph *self) { + Py_XDECREF(self->graph_def); + free(self->nodes); + free(self->node_attrs); + free(self->edges); + free(self->edge_attrs); + free(self->degrees); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static PyObject *Graph_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + Graph *self; + self = (Graph *)type->tp_alloc(type, 0); + if (self != NULL) { + self->graph_def = Py_None; + Py_INCREF(Py_None); + self->num_nodes = 0; + self->num_edges = 0; + self->nodes = self->edges = self->node_attrs = self->edge_attrs = self->degrees = NULL; + self->num_node_attrs = 0; + self->num_edge_attrs = 0; + } + return (PyObject *)self; +} + +static int Graph_init(Graph *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"graph_def", NULL}; + PyObject *graph_def = NULL; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &graph_def)) + return -1; + if (PyObject_TypeCheck(graph_def, &GraphDefType) == 0) { + PyErr_SetString(PyExc_TypeError, "Graph: graph_def must be a GraphDef"); + return -1; + } + if (graph_def) { + SELF_SET(graph_def); + } + return 0; +} + +static PyMemberDef Graph_members[] = { + {"graph_def", T_OBJECT_EX, offsetof(Graph, graph_def), 0, "node values"}, {NULL} /* Sentinel */ +}; + +int graph_get_node_pos(Graph *g, int node_id) { + for (int i = 0; i < g->num_nodes; i++) { + if (g->nodes[i] == node_id) { + return i; + } + } + return -1; +} + +static PyObject *Graph_add_node(Graph *self, PyObject *args, PyObject *kwds) { + PyObject *node = NULL; + if (!PyArg_ParseTuple(args, "O", &node)) + return NULL; + if (node) { + if (!PyLong_Check(node)) { + PyErr_SetString(PyExc_TypeError, "node must be an int"); + return NULL; + } + int node_id = PyLong_AsLong(node); + if (graph_get_node_pos(self, node_id) >= 0) { + PyErr_SetString(PyExc_KeyError, "node already exists"); + return NULL; + } + int node_pos = self->num_nodes; + self->num_nodes++; + self->nodes = realloc(self->nodes, self->num_nodes * sizeof(int)); + self->nodes[node_pos] = node_id; + self->degrees = realloc(self->degrees, self->num_nodes * sizeof(int)); + self->degrees[node_pos] = 0; + Py_ssize_t num_attrs; + if (kwds == NULL || (num_attrs = PyDict_Size(kwds)) == 0) { + Py_RETURN_NONE; + } + // Now we need to add the node attributes + // First check if the attributes are found in the GraphDef + GraphDef *gt = (GraphDef *)self->graph_def; + // for k in kwds: + // assert k in gt.node_values and kwds[k] in gt.node_values[k] + PyObject *key, *value; + Py_ssize_t pos = 0; + int node_attr_pos = self->num_node_attrs; + self->num_node_attrs += num_attrs; + self->node_attrs = realloc(self->node_attrs, self->num_node_attrs * 3 * sizeof(int)); + + while (PyDict_Next(kwds, &pos, &key, &value)) { + PyObject *node_values = PyDict_GetItem(gt->node_values, key); + if (node_values == NULL) { + PyErr_SetString(PyExc_KeyError, "key not found in GraphDef"); + return NULL; + } + Py_ssize_t value_idx = PySequence_Index(node_values, value); + if (value_idx == -1) { + PyObject *repr = PyObject_Repr(value); + PyObject *key_repr = PyObject_Repr(key); + PyErr_Format(PyExc_KeyError, "value %s not found in GraphDef for key %s", PyUnicode_AsUTF8(repr), + PyUnicode_AsUTF8(key_repr)); + Py_DECREF(repr); + // PyErr_SetString(PyExc_KeyError, "value not found in GraphDef"); + return NULL; + } + int attr_index = PyLong_AsLong(PyDict_GetItem(gt->node_keypos, key)); + self->node_attrs[node_attr_pos * 3] = node_pos; + self->node_attrs[node_attr_pos * 3 + 1] = attr_index; + self->node_attrs[node_attr_pos * 3 + 2] = value_idx; + node_attr_pos++; + } + } + Py_RETURN_NONE; +} + +static PyObject *Graph_add_edge(Graph *self, PyObject *args, PyObject *kwds) { + PyObject *u = NULL, *v = NULL; + if (!PyArg_ParseTuple(args, "OO", &u, &v)) + return NULL; + if (u && v) { + if (!PyLong_Check(u) || !PyLong_Check(v)) { + PyErr_SetString(PyExc_TypeError, "u, v must be ints"); + return NULL; + } + int u_id = -PyLong_AsLong(u); + int v_id = -PyLong_AsLong(v); + int u_pos = -1; + int v_pos = -1; + if (u_id < v_id) { + int tmp = u_id; + u_id = v_id; + v_id = tmp; + } + for (int i = 0; i < self->num_nodes; i++) { + if (self->nodes[i] == -u_id) { + u_id = -u_id; + u_pos = i; + } else if (self->nodes[i] == -v_id) { + v_id = -v_id; + v_pos = i; + } + if (u_pos >= 0 && v_pos >= 0) { + break; + } + } + if (u_id < 0 || v_id < 0) { + PyErr_SetString(PyExc_KeyError, "u, v must refer to existing nodes"); + return NULL; + } + if (u_id == v_id){ + PyErr_SetString(PyExc_KeyError, "u, v must be different nodes"); + return NULL; + } + for (int i = 0; i < self->num_edges; i++) { + if (self->edges[i * 2] == u_pos && self->edges[i * 2 + 1] == v_pos) { + PyErr_SetString(PyExc_KeyError, "edge already exists"); + return NULL; + } + } + self->num_edges++; + self->edges = realloc(self->edges, self->num_edges * sizeof(int) * 2); + self->edges[self->num_edges * 2 - 2] = u_pos; + self->edges[self->num_edges * 2 - 1] = v_pos; + self->degrees[u_pos]++; + self->degrees[v_pos]++; + + Py_ssize_t num_attrs; + if (kwds == NULL || (num_attrs = PyDict_Size(kwds)) == 0) { + Py_RETURN_NONE; + } + int edge_id = self->num_edges - 1; + // First check if the attributes are found in the GraphDef + GraphDef *gt = (GraphDef *)self->graph_def; + PyObject *key, *value; + Py_ssize_t pos = 0; + int edge_attr_pos = self->num_edge_attrs; + self->num_edge_attrs += num_attrs; + self->edge_attrs = realloc(self->edge_attrs, self->num_edge_attrs * 3 * sizeof(int)); + + while (PyDict_Next(kwds, &pos, &key, &value)) { + PyObject *edge_values = PyDict_GetItem(gt->edge_values, key); + if (edge_values == NULL) { + PyErr_SetString(PyExc_KeyError, "key not found in GraphDef"); + return NULL; + } + Py_ssize_t value_idx = PySequence_Index(edge_values, value); + if (value_idx == -1) { + PyErr_SetString(PyExc_KeyError, "value not found in GraphDef"); + return NULL; + } + int attr_index = PyLong_AsLong(PyDict_GetItem(gt->edge_keypos, key)); + self->edge_attrs[edge_attr_pos * 3] = edge_id; + self->edge_attrs[edge_attr_pos * 3 + 1] = attr_index; + self->edge_attrs[edge_attr_pos * 3 + 2] = value_idx; + edge_attr_pos++; + } + } + Py_RETURN_NONE; +} + +PyObject *Graph_isdirected(Graph *self, PyObject *Py_UNUSED(ignored)) { Py_RETURN_FALSE; } + +void bridge_dfs(int v, int parent, int *visited, int *tin, int *low, int *timer, int **adj, int *degrees, int *result, + Graph *g) { + visited[v] = 1; + tin[v] = low[v] = (*timer)++; + for (int i = 0; i < degrees[v]; i++) { + int to = adj[v][i]; + if (to == parent) + continue; + if (visited[to]) { + low[v] = mini(low[v], tin[to]); + } else { + bridge_dfs(to, v, visited, tin, low, timer, adj, degrees, result, g); + low[v] = mini(low[v], low[to]); + if (low[to] > tin[v]) { + result[get_edge_index_from_pos(g, v, to)] = 1; + } + } + } +} +PyObject *Graph_bridges(PyObject *_self, PyObject *args) { + Graph *self = (Graph *)_self; + // from: + // https://cp-algorithms.com/graph/bridge-searching.html + const int n = self->num_nodes; // number of nodes + int *adj[n]; + for (int i = 0; i < n; i++) { + adj[i] = malloc(self->degrees[i] * sizeof(int)); + } + int is_bridge[self->num_edges]; + for (int i = 0; i < self->num_edges; i++) { + is_bridge[i] = 0; + *(adj[self->edges[i * 2]]) = self->edges[i * 2 + 1]; + *(adj[self->edges[i * 2 + 1]]) = self->edges[i * 2]; + // increase the pointer + adj[self->edges[i * 2]]++; + adj[self->edges[i * 2 + 1]]++; + } + // reset the pointers using the degrees + for (int i = 0; i < n; i++) { + adj[i] -= self->degrees[i]; + } + int visited[n]; + int tin[n]; + int low[n]; + for (int i = 0; i < n; i++) { + visited[i] = 0; + tin[i] = -1; + low[i] = -1; + } + + int timer = 0; + for (int i = 0; i < n; ++i) { + if (!visited[i]) + bridge_dfs(i, -1, visited, tin, low, &timer, adj, self->degrees, is_bridge, self); + } + if (args == NULL) { + // This function is being called from python, so we must return a list + PyObject *result = PyList_New(0); + for (int i = 0; i < self->num_edges; i++) { + if (is_bridge[i]) { + PyObject *u = PyLong_FromLong(self->nodes[self->edges[i * 2]]); + PyObject *v = PyLong_FromLong(self->nodes[self->edges[i * 2 + 1]]); + PyObject *t = PyTuple_Pack(2, u, v); + PyList_Append(result, t); + Py_DECREF(u); + Py_DECREF(v); + Py_DECREF(t); + } + } + for (int i = 0; i < n; i++) { + free(adj[i]); + } + return result; + } + // This function is being called from C, so args is actually a int* to store the result + int *result = (int *)args; + for (int i = 0; i < self->num_edges; i++) { + result[i] = is_bridge[i]; + } + for (int i = 0; i < n; i++) { + free(adj[i]); + } + return 0; // success +} + +PyObject *Graph_copy(PyObject *_self, PyObject *unused_args) { + Graph *self = (Graph *)_self; + PyObject *args = PyTuple_Pack(1, self->graph_def); + Graph *obj = (Graph *)PyObject_CallObject((PyObject *)&GraphType, args); + Py_DECREF(args); + obj->num_nodes = self->num_nodes; + obj->num_edges = self->num_edges; + obj->nodes = malloc(self->num_nodes * sizeof(int)); + memcpy(obj->nodes, self->nodes, self->num_nodes * sizeof(int)); + obj->edges = malloc(self->num_edges * 2 * sizeof(int)); + memcpy(obj->edges, self->edges, self->num_edges * 2 * sizeof(int)); + obj->num_node_attrs = self->num_node_attrs; + obj->num_edge_attrs = self->num_edge_attrs; + obj->node_attrs = malloc(self->num_node_attrs * 3 * sizeof(int)); + memcpy(obj->node_attrs, self->node_attrs, self->num_node_attrs * 3 * sizeof(int)); + obj->edge_attrs = malloc(self->num_edge_attrs * 3 * sizeof(int)); + memcpy(obj->edge_attrs, self->edge_attrs, self->num_edge_attrs * 3 * sizeof(int)); + obj->degrees = malloc(self->num_nodes * sizeof(int)); + memcpy(obj->degrees, self->degrees, self->num_nodes * sizeof(int)); + return (PyObject *)obj; +} + +PyObject *Graph_has_edge(PyObject *_self, PyObject *args) { + int u, v; + if (!PyArg_ParseTuple(args, "ii", &u, &v)) + return NULL; + if (get_edge_index((Graph *)_self, u, v) >= 0) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + +void _Graph_check(Graph *g) { +#ifdef GRAPH_DEBUG + for (int i = 0; i < g->num_node_attrs; i++) { + if (g->node_attrs[i * 3] >= g->num_nodes) { + printf("Invalid node attr pointer %d %d\n", g->node_attrs[i * 3], g->num_nodes); + abort(); + } + } + for (int i = 0; i < g->num_edge_attrs; i++) { + if (g->edge_attrs[i * 3] >= g->num_edges) { + printf("Invalid edge attr pointer %d %d\n", g->edge_attrs[i * 3], g->num_edges); + abort(); + } + } +#endif +} + +PyObject *Graph_remove_edge(PyObject *, PyObject *); // fwd decl + +PyObject *Graph_remove_node(PyObject *_self, PyObject *args) { + Graph *self = (Graph *)_self; + int u; + if (!PyArg_ParseTuple(args, "i", &u)) + return NULL; + int pos = graph_get_node_pos(self, u); + if (pos < 0) { + PyErr_SetString(PyExc_KeyError, "node not found"); + return NULL; + } + // Check if any edge has this node + for (int i = 0; i < self->num_edges; i++) { + if (self->edges[i * 2] == pos || self->edges[i * 2 + 1] == pos) { + PyObject *u = PyLong_FromLong(self->nodes[self->edges[i * 2]]); + PyObject *v = PyLong_FromLong(self->nodes[self->edges[i * 2 + 1]]); + PyObject *rm_args = PyTuple_Pack(2, u, v); + // if so remove it + PyObject *rm_res = Graph_remove_edge(_self, rm_args); + Py_DECREF(u); + Py_DECREF(v); + Py_DECREF(rm_args); + if (rm_res == NULL) { + return NULL; + } + Py_DECREF(rm_res); + i--; + } + } + // Remove the node + // self->nodes contains node "names", so we just pop one element from the array + int *old_nodes = self->nodes; + self->nodes = malloc((self->num_nodes - 1) * sizeof(int)); + memcpy(self->nodes, old_nodes, pos * sizeof(int)); + memcpy(self->nodes + pos, old_nodes + pos + 1, (self->num_nodes - pos - 1) * sizeof(int)); + free(old_nodes); + int *old_degrees = self->degrees; + self->degrees = malloc((self->num_nodes - 1) * sizeof(int)); + memcpy(self->degrees, old_degrees, pos * sizeof(int)); + memcpy(self->degrees + pos, old_degrees + pos + 1, (self->num_nodes - pos - 1) * sizeof(int)); + free(old_degrees); + self->num_nodes--; + // Remove the node attributes + int *old_node_attrs = self->node_attrs; + int num_rem = 0; + // first find out how many attributes we need to remove + for (int i = 0; i < self->num_node_attrs; i++) { + if (self->node_attrs[i * 3] == pos) { + num_rem++; + } + } + if (num_rem) { + // now remove them + self->node_attrs = malloc((self->num_node_attrs - num_rem) * 3 * sizeof(int)); + int i, j = 0; + for (i = 0; i < self->num_node_attrs; i++) { + if (old_node_attrs[i * 3] == pos) { + continue; + } + int old_node_index = old_node_attrs[i * 3]; + self->node_attrs[j * 3] = old_node_index; + self->node_attrs[j * 3 + 1] = old_node_attrs[i * 3 + 1]; + self->node_attrs[j * 3 + 2] = old_node_attrs[i * 3 + 2]; + j++; + } + self->num_node_attrs -= num_rem; + free(old_node_attrs); + } + // since we removed the node at pos, all other node indices past that must be decremented + for (int i = 0; i < self->num_node_attrs; i++) { + if (self->node_attrs[i * 3] > pos) { + self->node_attrs[i * 3]--; + } + } + // We must also relabel the edges + for (int i = 0; i < self->num_edges; i++) { + if (self->edges[i * 2] > pos) { + self->edges[i * 2]--; + } + if (self->edges[i * 2 + 1] > pos) { + self->edges[i * 2 + 1]--; + } + } + _Graph_check(self); + Py_RETURN_NONE; +} + +PyObject *Graph_remove_edge(PyObject *_self, PyObject *args) { + Graph *self = (Graph *)_self; + int u, v; + if (!PyArg_ParseTuple(args, "ii", &u, &v)) + return NULL; + int edge_index = get_edge_index((Graph *)_self, u, v); + if (edge_index < 0) { + PyErr_SetString(PyExc_KeyError, "edge not found"); + return NULL; + } + // Decrease degree + self->degrees[self->edges[edge_index * 2]]--; + self->degrees[self->edges[edge_index * 2 + 1]]--; + // Remove the edge + int *old_edges = self->edges; + self->edges = malloc((self->num_edges - 1) * 2 * sizeof(int)); + memcpy(self->edges, old_edges, edge_index * 2 * sizeof(int)); + memcpy(self->edges + edge_index * 2, old_edges + edge_index * 2 + 2, + (self->num_edges - edge_index - 1) * 2 * sizeof(int)); + free(old_edges); + self->num_edges--; + // Remove the edge attributes + int num_rem = 0; + for (int i = 0; i < self->num_edge_attrs; i++) { + if (self->edge_attrs[i * 3] == edge_index) { + num_rem++; + } + } + if (num_rem) { + int *old_edge_attrs = self->edge_attrs; + self->edge_attrs = malloc((self->num_edge_attrs - num_rem) * 3 * sizeof(int)); + int i, j = 0; + for (i = 0; i < self->num_edge_attrs; i++) { + int old_edge_index = old_edge_attrs[i * 3]; + if (old_edge_index == edge_index) { + continue; + } + self->edge_attrs[j * 3] = old_edge_index; + self->edge_attrs[j * 3 + 1] = old_edge_attrs[i * 3 + 1]; + self->edge_attrs[j * 3 + 2] = old_edge_attrs[i * 3 + 2]; + j++; + } + + self->num_edge_attrs -= num_rem; + free(old_edge_attrs); + } + // since we removed the edge at edge_index, all other edge indices past that must be decremented + for (int i = 0; i < self->num_edge_attrs; i++) { + if (self->edge_attrs[i * 3] > edge_index) { + self->edge_attrs[i * 3]--; + } + } + _Graph_check(self); + Py_RETURN_NONE; +} + +PyObject *Graph_relabel(PyObject *_self, PyObject *args) { + PyObject *mapping = NULL; + if (!PyArg_ParseTuple(args, "O", &mapping)) + return NULL; + if (!PyDict_Check(mapping)) { + PyErr_SetString(PyExc_TypeError, "mapping must be a dict"); + return NULL; + } + Graph *new = (Graph *)Graph_copy(_self, NULL); // new ref + for (int i = 0; i < new->num_nodes; i++) { + PyObject *node = PyLong_FromLong(new->nodes[i]); + PyObject *new_node = PyDict_GetItem(mapping, node); // ref is borrowed + Py_DECREF(node); + if (new_node == NULL) { + PyErr_Format(PyExc_KeyError, "node %d not found in mapping", new->nodes[i]); + return NULL; + } + if (!PyLong_Check(new_node)) { + PyErr_SetString(PyExc_TypeError, "mapping must be a dict of ints"); + return NULL; + } + new->nodes[i] = PyLong_AsLong(new_node); + } + _Graph_check((Graph *)_self); + return (PyObject *)new; +} + +PyObject *Graph_inspect(PyObject *_self, PyObject *args) { + Graph *self = (Graph *)_self; + char buffer[4096]; + memset(buffer, 0, 4096); + int cap = 4096; + int ptr = 0; + + ptr += snprintf(buffer, cap - ptr, "Node labels:\n "); + for (int i = 0; i < self->num_nodes; i++) { + printf("%d ", self->nodes[i]); + } + ptr += snprintf(buffer + ptr, cap - ptr, "\n"); + ptr += snprintf(buffer + ptr, cap - ptr, "Edges:\n"); + for (int i = 0; i < self->num_edges; i++) { + ptr += snprintf(buffer + ptr, cap - ptr, " %d %d (%d %d)\n", self->nodes[self->edges[i * 2]], self->nodes[self->edges[i * 2 + 1]], + self->edges[i * 2], self->edges[i * 2 + 1]); + } + ptr += snprintf(buffer + ptr, cap - ptr, "Node attributes:\n"); + for (int i = 0; i < self->num_node_attrs; i++) { + ptr += snprintf(buffer + ptr, cap - ptr, " %d %d %d\n", self->node_attrs[i * 3], self->node_attrs[i * 3 + 1], self->node_attrs[i * 3 + 2]); + } + ptr += snprintf(buffer + ptr, cap - ptr, "Edge attributes:\n"); + for (int i = 0; i < self->num_edge_attrs; i++) { + ptr += snprintf(buffer + ptr, cap - ptr, " %d %d %d\n", self->edge_attrs[i * 3], self->edge_attrs[i * 3 + 1], self->edge_attrs[i * 3 + 2]); + } + ptr += snprintf(buffer + ptr, cap - ptr, "Degrees:\n "); + for (int i = 0; i < self->num_nodes; i++) { + ptr += snprintf(buffer + ptr, cap - ptr, "%d ", self->degrees[i]); + } + ptr += snprintf(buffer + ptr, cap - ptr, "\n\n"); + PyObject *res = PyUnicode_FromString(buffer); + return res; +} + +PyObject *Graph_getstate(PyObject *_self, PyObject *args) { + Graph *self = (Graph *)_self; + int num_bytes = (4 + self->num_nodes * 2 + self->num_edges * 2 + self->num_node_attrs * 3 + self->num_edge_attrs * 3) * sizeof(int); + // Create a new bytes object + PyObject *state = PyBytes_FromStringAndSize(NULL, num_bytes); + char *buf = PyBytes_AS_STRING(state); + int* ptr = (int*)buf; + ptr[0] = self->num_nodes; + ptr[1] = self->num_edges; + ptr[2] = self->num_node_attrs; + ptr[3] = self->num_edge_attrs; + memcpy(ptr + 4, self->nodes, self->num_nodes * sizeof(int)); + memcpy(ptr + 4 + self->num_nodes, self->edges, self->num_edges * 2 * sizeof(int)); + memcpy(ptr + 4 + self->num_nodes + self->num_edges * 2, self->node_attrs, self->num_node_attrs * 3 * sizeof(int)); + memcpy(ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3, + self->edge_attrs, self->num_edge_attrs * 3 * sizeof(int)); + memcpy(ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3 + self->num_edge_attrs * 3, + self->degrees, self->num_nodes * sizeof(int)); + // Return (bytes, GraphDef) + PyObject *res = PyTuple_Pack(2, state, self->graph_def); + Py_DECREF(state); + return res; +} +PyObject *Graph_setstate(PyObject *_self, PyObject *args) { + PyObject* state = PyTuple_GetItem(args, 0); + PyObject* graph_def = PyTuple_GetItem(args, 1); + if (!PyBytes_Check(state)) { + PyErr_SetString(PyExc_TypeError, "state must be a bytes object"); + return NULL; + } + if (PyObject_TypeCheck(graph_def, &GraphDefType) == 0) { + PyErr_SetString(PyExc_TypeError, "graph_def must be a GraphDef"); + return NULL; + } + Graph *self = (Graph *)_self; + Py_XDECREF(self->graph_def); + self->graph_def = graph_def; + Py_XINCREF(graph_def); + char *buf = PyBytes_AS_STRING(state); + int* ptr = (int*)buf; + self->num_nodes = ptr[0]; + self->num_edges = ptr[1]; + self->num_node_attrs = ptr[2]; + self->num_edge_attrs = ptr[3]; + self->nodes = malloc(self->num_nodes * sizeof(int)); + self->edges = malloc(self->num_edges * 2 * sizeof(int)); + self->node_attrs = malloc(self->num_node_attrs * 3 * sizeof(int)); + self->edge_attrs = malloc(self->num_edge_attrs * 3 * sizeof(int)); + self->degrees = malloc(self->num_nodes * sizeof(int)); + memcpy(self->nodes, ptr + 4, self->num_nodes * sizeof(int)); + memcpy(self->edges, ptr + 4 + self->num_nodes, self->num_edges * 2 * sizeof(int)); + memcpy(self->node_attrs, ptr + 4 + self->num_nodes + self->num_edges * 2, self->num_node_attrs * 3 * sizeof(int)); + memcpy(self->edge_attrs, ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3, + self->num_edge_attrs * 3 * sizeof(int)); + memcpy(self->degrees, ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3 + self->num_edge_attrs * 3, + self->num_nodes * sizeof(int)); + Py_RETURN_NONE; +} + + +int _NodeView_eq(NodeView *self, NodeView *other); +int _EdgeView_eq(EdgeView *self, EdgeView *other); + +inline int _fast_Node_eq(int i, int j, int* nodeval_cache, int N, int M){ + for (int k = 0; k < M; k++){ + //printf("nodeeq i=%d j=%d k=%d M=%d: %d == %d\n", i, j, k, M, nodeval_cache[i * M + k], nodeval_cache[(j + N) * M + k]); + if (nodeval_cache[i * M + k] != nodeval_cache[(j + N) * M + k]) + return 0; + } + return 1; +} + +int _Graph_iso_search(Graph* a, Graph* b, int* assignment, int depth, int* nodeval_cache, int* a_edgeidx_cache){ + /* Python version: +def search(graph, subgraph, assignments, depth): + # Make sure that every edge between assigned vertices in the subgraph is also an + # edge in the graph. + for u, v in subgraph.edges: + if u < depth and v < depth: + x, y = assignments[u], assignments[v] + if not ( + graph.has_edge(x, y) and + graph.nodes[x] == subgraph.nodes[u] and graph.nodes[y] == subgraph.nodes[v] and + graph.edges[x, y] == subgraph.edges[u, v]): + return False + + # If all the vertices in the subgraph are assigned, then we are done. + if depth == len(subgraph): + return True + + for j in range(len(subgraph)): #possible_assignments[i]: + if assignments[depth] == -1: + assignments[depth] = j + if search(graph, subgraph, assignments, depth+1): + return True + + assignments[depth] = -1 */ + /* + for (int i=0;inum_nodes*2;i++){ + if (i == a->num_nodes){ + printf("|| "); + } + printf("%d ", assignment[i]); + } + puts("");*/ + for (int i = 0; i < b->num_edges; i++){ + int u = b->edges[i * 2]; + int v = b->edges[i * 2 + 1]; + if (u < depth && v < depth){ + int x = assignment[u]; + int y = assignment[v]; + int x_y = a_edgeidx_cache[x * a->num_nodes + y]; // get_edge_index(a, a->nodes[x], a->nodes[y]); + if (x_y < 0) + return 0; + //NodeView nv_x = {.graph = a, .index = x}; + //NodeView nv_y = {.graph = a, .index = y}; + //NodeView nv_u = {.graph = b, .index = u}; + //NodeView nv_v = {.graph = b, .index = v}; + //if (!_NodeView_eq(&nv_x, &nv_u) || !_NodeView_eq(&nv_y, &nv_v)) + if (!_fast_Node_eq(x, u, nodeval_cache, a->num_nodes, 1 + ((GraphDef*)a->graph_def)->num_settable_node_attrs) || + !_fast_Node_eq(y, v, nodeval_cache, a->num_nodes, 1 + ((GraphDef*)a->graph_def)->num_settable_node_attrs)) + return 0; + EdgeView ev_x_y = {.graph = a, .index = x_y}; + EdgeView ev_u_v = {.graph = b, .index = i}; + if (!_EdgeView_eq(&ev_x_y, &ev_u_v)) + return 0; + } + } + + if (depth == b->num_nodes){ + // Found a valid assignment + return 1; + } + for (int j = 0; j < b->num_nodes; j++){ + //printf("Trying %d = %d\n", depth, j); + if (assignment[b->num_nodes + j] != -1) + continue; + //NodeView nv_a = {.graph = a, .index = j}; + //NodeView nv_b = {.graph = b, .index = depth}; + if (a->degrees[j] == b->degrees[depth]){ + if (!_fast_Node_eq(j, depth, nodeval_cache, a->num_nodes, 1 + ((GraphDef*)a->graph_def)->num_settable_node_attrs)){ + continue; + } + /*if (!_NodeView_eq(&nv_a, &nv_b)){ + continue; + }*/ + assignment[depth] = j; + assignment[j + b->num_nodes] = depth; + if (_Graph_iso_search(a, b, assignment, depth + 1, nodeval_cache, a_edgeidx_cache)){ + return 1; + } + assignment[depth] = -1; + assignment[j + b->num_nodes] = -1; + } + } + return 0; +} + +PyObject *Graph_is_isomorphic(PyObject *_self, PyObject *args) { + Graph *self = (Graph *)_self; + PyObject *other = NULL; + if (!PyArg_ParseTuple(args, "O", &other)) + return NULL; + if (PyObject_TypeCheck(other, &GraphType) == 0) { + PyErr_SetString(PyExc_TypeError, "other must be a Graph"); + return NULL; + } + Graph *other_graph = (Graph *)other; + if (self->graph_def != other_graph->graph_def) { + Py_RETURN_FALSE; + } + if (self->num_nodes != other_graph->num_nodes || self->num_edges != other_graph->num_edges) { + Py_RETURN_FALSE; + } + int assignments[self->num_nodes * 2]; + for (int i = 0; i < self->num_nodes * 2; i++) { + assignments[i] = -1; + } + int num_attrs = ((GraphDef *)self->graph_def)->num_settable_node_attrs + 1; + int nodeval_cache[self->num_nodes * num_attrs * 2]; + memset(nodeval_cache, 0, self->num_nodes * num_attrs * 2 * sizeof(int)); + for (int j = 0; j < self->num_node_attrs; j++){ + int node_attr_index = self->node_attrs[j * 3]; + int attr_index = self->node_attrs[j * 3 + 1]; + int value_index = self->node_attrs[j * 3 + 2]; + nodeval_cache[node_attr_index * num_attrs + attr_index] = value_index; + } + for (int j = 0; j < other_graph->num_node_attrs; j++){ + int node_attr_index = other_graph->node_attrs[j * 3]; + int attr_index = other_graph->node_attrs[j * 3 + 1]; + int value_index = other_graph->node_attrs[j * 3 + 2]; + nodeval_cache[(node_attr_index + self->num_nodes) * num_attrs + attr_index] = value_index; + } + // Q: Is something wrong with the above? Are we setting the values correctly? + // A: Yes, we are setting the values correctly. The nodeval_cache is a 3D array, where the first dimension is the node index, the second dimension is the attribute index, and the third dimension is the value index. + // The first num_nodes * num_attrs entries are for the first graph, and the next num_nodes * num_attrs entries are for the second graph. + + int a_edgeidx_cache[self->num_nodes * self->num_nodes]; + memset(a_edgeidx_cache, -1, self->num_nodes * self->num_nodes * sizeof(int)); + for (int i = 0; i < self->num_edges; i++){ + a_edgeidx_cache[self->edges[i * 2] * self->num_nodes + self->edges[i * 2 + 1]] = i; + } + if (_Graph_iso_search(self, other_graph, assignments, 0, nodeval_cache, a_edgeidx_cache)){ + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + + +static PyMethodDef Graph_methods[] = { + {"add_node", (PyCFunction)Graph_add_node, METH_VARARGS | METH_KEYWORDS, "Add a node"}, + {"add_edge", (PyCFunction)Graph_add_edge, METH_VARARGS | METH_KEYWORDS, "Add an edge"}, + {"has_edge", (PyCFunction)Graph_has_edge, METH_VARARGS, "Check if an edge is present"}, + {"remove_node", (PyCFunction)Graph_remove_node, METH_VARARGS, "Remove a node"}, + {"remove_edge", (PyCFunction)Graph_remove_edge, METH_VARARGS, "Remove an edge"}, + {"is_directed", (PyCFunction)Graph_isdirected, METH_NOARGS, "Is the graph directed?"}, + {"is_multigraph", (PyCFunction)Graph_isdirected, METH_NOARGS, "Is the graph a multigraph?"}, + {"bridges", (PyCFunction)Graph_bridges, METH_NOARGS, "Find the bridges of the graph"}, + {"copy", (PyCFunction)Graph_copy, METH_VARARGS, "Copy the graph"}, + {"relabel_nodes", (PyCFunction)Graph_relabel, METH_VARARGS, "relabel the graph"}, + {"is_isomorphic", (PyCFunction)Graph_is_isomorphic, METH_VARARGS, "isomorphism test"}, + {"__deepcopy__", (PyCFunction)Graph_copy, METH_VARARGS, "Copy the graph"}, + {"__getstate__", (PyCFunction)Graph_getstate, METH_NOARGS, "pickling"}, + {"__setstate__", (PyCFunction)Graph_setstate, METH_O, "unpickling"}, + {"_inspect", (PyCFunction)Graph_inspect, METH_VARARGS, "Print the graph's information"}, + {NULL} /* Sentinel */ +}; + +static PyObject *Graph_getnodes(Graph *self, void *closure) { + // Return a new NodeView + PyObject *args = PyTuple_Pack(1, self); + PyObject *obj = PyObject_CallObject((PyObject *)&NodeViewType, args); // new ref + Py_DECREF(args); + return obj; +} + +static PyObject *Graph_getedges(Graph *self, void *closure) { + // Return a new EdgeView + PyObject *args = PyTuple_Pack(1, self); + PyObject *obj = PyObject_CallObject((PyObject *)&EdgeViewType, args); + Py_DECREF(args); + return obj; +} + +static PyObject *Graph_getdegree(Graph *self, void *closure) { + PyObject *args = PyTuple_Pack(1, self); + PyObject *obj = PyObject_CallObject((PyObject *)&DegreeViewType, args); + Py_DECREF(args); + return obj; +} + +static PyGetSetDef Graph_getsetters[] = { + {"nodes", (getter)Graph_getnodes, NULL, "nodes", NULL}, + {"edges", (getter)Graph_getedges, NULL, "edges", NULL}, + {"degree", (getter)Graph_getdegree, NULL, "degree", NULL}, + {NULL} /* Sentinel */ +}; + +static PyObject *Graph_contains(Graph *self, PyObject *v) { + if (!PyLong_Check(v)) { + PyErr_SetString(PyExc_TypeError, "Graph.__contains__ only accepts integers"); + return NULL; + } + int node_id = PyLong_AsLong(v); + int pos = graph_get_node_pos(self, node_id); + if (pos >= 0) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + +static Py_ssize_t Graph_len(Graph *self) { return self->num_nodes; } + +static PySequenceMethods Graph_seqmeth = { + .sq_contains = (objobjproc)Graph_contains, + .sq_length = (lenfunc)Graph_len, +}; + +static PyObject *Graph_iter(PyObject *self) { + PyObject *args = PyTuple_Pack(1, self); + PyObject *obj = PyObject_CallObject((PyObject *)&NodeViewType, args); // new ref + Py_DECREF(args); + return obj; +} + +static PyObject *Graph_getitem(PyObject *self, PyObject *key) { + PyObject *args = PyTuple_Pack(2, self, key); + PyObject *obj = PyObject_CallObject((PyObject *)&NodeViewType, args); // new ref + Py_DECREF(args); + return obj; +} +static PyMappingMethods Graph_mapmeth = { + .mp_subscript = Graph_getitem, +}; + +PyTypeObject GraphType = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "gflownet._C.Graph", + .tp_doc = PyDoc_STR("Constrained Graph object"), + .tp_basicsize = sizeof(Graph), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_new = Graph_new, + .tp_init = (initproc)Graph_init, + .tp_dealloc = (destructor)Graph_dealloc, + .tp_members = Graph_members, + .tp_methods = Graph_methods, + .tp_iter = Graph_iter, + .tp_getset = Graph_getsetters, + .tp_as_sequence = &Graph_seqmeth, + .tp_as_mapping = &Graph_mapmeth, +}; + +PyObject *Graph_getnodeattr(Graph *self, int index, PyObject *k) { + GraphDef *gt = (GraphDef *)self->graph_def; + PyObject *value_list = PyDict_GetItem(gt->node_values, k); // borrowed ref + if (value_list == NULL) { + PyErr_SetString(PyExc_KeyError, "key not found"); + return NULL; + } + long attr_index = PyLong_AsLong(PyDict_GetItem(gt->node_keypos, k)); + int true_node_index = -1; + for (int i = 0; i < self->num_nodes; i++) { + if (self->nodes[i] == index) { + true_node_index = i; + break; + } + } + if (true_node_index == -1) { + PyErr_SetString(PyExc_KeyError, "node not found"); + return NULL; + } + + for (int i = 0; i < self->num_node_attrs; i++) { + if (self->node_attrs[i * 3] == true_node_index && self->node_attrs[i * 3 + 1] == attr_index) { + // borrowed ref so we have to increase its refcnt because we are returning it + PyObject *r = PyList_GetItem(value_list, self->node_attrs[i * 3 + 2]); + Py_INCREF(r); + return r; + } + } + PyErr_SetString(PyExc_KeyError, "attribute not set for this node"); + return NULL; +} + +PyObject *Graph_setnodeattr(Graph *self, int index, PyObject *k, PyObject *v) { + GraphDef *gt = (GraphDef *)self->graph_def; + int true_node_index = -1; + for (int i = 0; i < self->num_nodes; i++) { + if (self->nodes[i] == index) { + true_node_index = i; + break; + } + } + if (true_node_index == -1) { + PyErr_SetString(PyExc_KeyError, "node not found"); + return NULL; + } + PyObject *node_values = PyDict_GetItem(gt->node_values, k); + if (node_values == NULL) { + PyErr_SetString(PyExc_KeyError, "key not found in GraphDef"); + return NULL; + } + if (v == NULL) { + // this means we have to delete g.nodes[index][k] + int attr_index = PyLong_AsLong(PyDict_GetItem(gt->node_keypos, k)); + for (int i = 0; i < self->num_node_attrs; i++) { + if (self->node_attrs[i * 3] == true_node_index && self->node_attrs[i * 3 + 1] == attr_index) { + // found the attribute, remove it + int *old_node_attrs = self->node_attrs; + self->node_attrs = malloc((self->num_node_attrs - 1) * 3 * sizeof(int)); + memcpy(self->node_attrs, old_node_attrs, i * 3 * sizeof(int)); + memcpy(self->node_attrs + i * 3, old_node_attrs + (i + 1) * 3, + (self->num_node_attrs - i - 1) * 3 * sizeof(int)); + self->num_node_attrs--; + free(old_node_attrs); + Py_RETURN_NONE; + } + } + PyErr_SetString(PyExc_KeyError, "trying to delete a key that's not set"); + return NULL; + } + Py_ssize_t value_idx = PySequence_Index(node_values, v); + if (value_idx == -1) { + PyErr_SetString(PyExc_KeyError, "value not found in GraphDef"); + return NULL; + } + int attr_index = PyLong_AsLong(PyDict_GetItem(gt->node_keypos, k)); + for (int i = 0; i < self->num_node_attrs; i++) { + if (self->node_attrs[i * 3] == true_node_index && self->node_attrs[i * 3 + 1] == attr_index) { + self->node_attrs[i * 3 + 2] = value_idx; + Py_RETURN_NONE; + } + } + // Could not find the attribute, add it + int new_idx = self->num_node_attrs; + self->num_node_attrs++; + self->node_attrs = realloc(self->node_attrs, self->num_node_attrs * 3 * sizeof(int)); + self->node_attrs[new_idx * 3] = true_node_index; + self->node_attrs[new_idx * 3 + 1] = attr_index; + self->node_attrs[new_idx * 3 + 2] = value_idx; + Py_RETURN_NONE; +} + +PyObject *Graph_getedgeattr(Graph *self, int index, PyObject *k) { + GraphDef *gt = (GraphDef *)self->graph_def; + PyObject *value_list = PyDict_GetItem(gt->edge_values, k); // borrowed ref + if (value_list == NULL) { + PyErr_SetString(PyExc_KeyError, "key not found"); + return NULL; + } + long attr_index = PyLong_AsLong(PyDict_GetItem(gt->edge_keypos, k)); + if (index > self->num_edges) { + // Should never happen, index is computed by us in EdgeView_getitem, not + // by the user! + PyErr_SetString(PyExc_KeyError, "edge not found [but this should never happen!]"); + return NULL; + } + + for (int i = 0; i < self->num_edge_attrs; i++) { + if (self->edge_attrs[i * 3] == index && self->edge_attrs[i * 3 + 1] == attr_index) { + // borrowed ref so we have to increase its refcnt because we are returning it + PyObject *r = PyList_GetItem(value_list, self->edge_attrs[i * 3 + 2]); + Py_INCREF(r); + return r; + } + } + PyErr_SetString(PyExc_KeyError, "attribute not set for this node"); + return NULL; +} +PyObject *Graph_setedgeattr(Graph *self, int index, PyObject *k, PyObject *v) { + GraphDef *gt = (GraphDef *)self->graph_def; + PyObject *edge_values = PyDict_GetItem(gt->edge_values, k); + if (edge_values == NULL) { + PyErr_SetString(PyExc_KeyError, "key not found in GraphDef"); + return NULL; + } + if (v == 0) { + // this means we have to delete g.edges[index][k] + int attr_index = PyLong_AsLong(PyDict_GetItem(gt->edge_keypos, k)); + for (int i = 0; i < self->num_edge_attrs; i++) { + if (self->edge_attrs[i * 3] == index && self->edge_attrs[i * 3 + 1] == attr_index) { + // found the attribute, remove it + int *old_edge_attrs = self->edge_attrs; + self->edge_attrs = malloc((self->num_edge_attrs - 1) * 3 * sizeof(int)); + memcpy(self->edge_attrs, old_edge_attrs, i * 3 * sizeof(int)); + memcpy(self->edge_attrs + i * 3, old_edge_attrs + (i + 1) * 3, + (self->num_edge_attrs - i - 1) * 3 * sizeof(int)); + self->num_edge_attrs--; + free(old_edge_attrs); + Py_RETURN_NONE; + } + } + PyErr_SetString(PyExc_KeyError, "trying to delete a key that's not set"); + return NULL; + } + Py_ssize_t value_idx = PySequence_Index(edge_values, v); + if (value_idx == -1) { + PyErr_SetString(PyExc_KeyError, "value not found in GraphDef"); + return NULL; + } + int attr_index = PyLong_AsLong(PyDict_GetItem(gt->edge_keypos, k)); + for (int i = 0; i < self->num_edge_attrs; i++) { + if (self->edge_attrs[i * 3] == index && self->edge_attrs[i * 3 + 1] == attr_index) { + self->edge_attrs[i * 3 + 2] = value_idx; + Py_RETURN_NONE; + } + } + // Could not find the attribute, add it + int new_idx = self->num_edge_attrs; + self->num_edge_attrs++; + self->edge_attrs = realloc(self->edge_attrs, self->num_edge_attrs * 3 * sizeof(int)); + self->edge_attrs[new_idx * 3] = index; + self->edge_attrs[new_idx * 3 + 1] = attr_index; + self->edge_attrs[new_idx * 3 + 2] = value_idx; + Py_RETURN_NONE; +} + +static PyMethodDef SpamMethods[] = { + //{"print_count", print_count, METH_VARARGS, "Execute a shell command."}, + {"mol_graph_to_Data", mol_graph_to_Data, METH_VARARGS, "Convert a mol_graph to a Data object."}, + {"Data_collate", Data_collate, METH_VARARGS, "collate Data instances"}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +static struct PyModuleDef spammodule = {PyModuleDef_HEAD_INIT, "gflownet._C", /* name of module */ + "doc", /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + SpamMethods}; +PyObject *SpamError; + +PyMODINIT_FUNC PyInit__C(void) { + PyObject *m; + + m = PyModule_Create(&spammodule); + if (m == NULL) + return NULL; + + SpamError = PyErr_NewException("gflownet._C.error", NULL, NULL); + Py_XINCREF(SpamError); + if (PyModule_AddObject(m, "error", SpamError) < 0) { + Py_XDECREF(SpamError); + Py_CLEAR(SpamError); + Py_DECREF(m); + return NULL; + } + PyTypeObject *types[] = {&GraphType, &GraphDefType, &NodeViewType, &EdgeViewType, &DegreeViewType, &DataType}; + char *names[] = {"Graph", "GraphDef", "NodeView", "EdgeView", "DegreeView", "Data"}; + for (int i = 0; i < (int)(sizeof(types) / sizeof(PyTypeObject *)); i++) { + if (PyType_Ready(types[i]) < 0) { + Py_DECREF(m); + return NULL; + } + Py_XINCREF(types[i]); + if (PyModule_AddObject(m, names[i], (PyObject *)types[i]) < 0) { + Py_XDECREF(types[i]); + Py_DECREF(m); + return NULL; + } + } + + return m; +} \ No newline at end of file diff --git a/src/C/main.h b/src/C/main.h new file mode 100644 index 00000000..25162eff --- /dev/null +++ b/src/C/main.h @@ -0,0 +1,115 @@ +#define PY_SSIZE_T_CLEAN +#include + +extern PyObject *SpamError; + +typedef struct { + PyObject_HEAD; + PyObject *node_values; /* Dict[str, List[Any]] */ + PyObject *edge_values; /* Dict[str, List[Any]] */ + PyObject *node_keypos; /* Dict[str, int] */ + PyObject *edge_keypos; /* Dict[str, int] */ + PyObject *node_poskey; /* List[str] */ + PyObject *edge_poskey; /* List[str] */ + int *node_attr_offsets; /* List[int] */ + int *edge_attr_offsets; /* List[int] */ + int num_node_dim; + int num_settable_node_attrs; + int num_node_attr_logits; + int num_new_node_values; + int num_edge_dim; + int num_settable_edge_attrs; + int num_edge_attr_logits; +} GraphDef; + +extern PyTypeObject GraphDefType; + +typedef struct { + PyObject_HEAD; + PyObject *graph_def; + int num_nodes; + int num_edges; + int num_node_attrs; + int num_edge_attrs; + int *nodes; /* List[int] */ + int *edges; /* List[Tuple[int, int]] */ + int *node_attrs; /* List[Tuple[nodeid, attrid, attrvalueidx]] */ + int *edge_attrs; /* List[Tuple[edgeid, attrid, attrvalueidx]] */ + int *degrees; /* List[int] */ +} Graph; + +extern PyTypeObject GraphType; + +typedef struct NodeView { + PyObject_HEAD; + Graph *graph; + int index; +} NodeView; + +extern PyTypeObject NodeViewType; + +typedef struct EdgeView { + PyObject_HEAD; + Graph *graph; + int index; +} EdgeView; + +extern PyTypeObject EdgeViewType; + +typedef struct DegreeView { + PyObject_HEAD; + Graph *graph; +} DegreeView; + +extern PyTypeObject DegreeViewType; + +/* This is a specialized Data instance that only knows how to hold onto 2d tensors + that are either float32 or int64 (for reasons of compatibility with torch_geometric) */ +typedef struct Data { + PyObject_HEAD; + PyObject *bytes; + GraphDef *graph_def; + int *shapes; + int *is_float; + int *offsets; // in bytes + int num_matrices; + const char **names; +} Data; + +extern PyTypeObject DataType; + +#define SELF_SET(v) \ + { \ + PyObject *tmp = self->v; \ + Py_INCREF(v); \ + self->v = v; \ + Py_XDECREF(tmp); \ + } + +// when we want to do f(x) where x is a new reference that we want to decref after the call +// e.g. PyLong_AsLong(PyObject_GetAttrString(x, "id")), GetAttrString returns a new reference +#define borrow_new_and_call(f, x) \ + ({ \ + PyObject *tmp = x; \ + __auto_type rval = f(tmp); \ + Py_DECREF(tmp); \ + rval; \ + }) + +static inline int maxi(int a, int b) { return a > b ? a : b; } +static inline int mini(int a, int b) { return a < b ? a : b; } + +PyObject *Graph_getnodeattr(Graph *self, int index, PyObject *k); +PyObject *Graph_setnodeattr(Graph *self, int index, PyObject *k, PyObject *v); +PyObject *Graph_getedgeattr(Graph *self, int index, PyObject *k); +PyObject *Graph_setedgeattr(Graph *self, int index, PyObject *k, PyObject *v); +PyObject *Graph_bridges(PyObject *self, PyObject *args); + +int get_edge_index(Graph *g, int u, int v); +int get_edge_index_from_pos(Graph *g, int u_pos, int v_pos); + +PyObject *Data_collate(PyObject *self, PyObject *args); +void Data_init_C(Data *self, PyObject *bytes, PyObject *graph_def, int shapes[][2], int *is_float, int num_matrices, + const char **names); + +PyObject *mol_graph_to_Data(PyObject *self, PyObject *args); \ No newline at end of file diff --git a/src/C/mol_graph_to_Data.c b/src/C/mol_graph_to_Data.c new file mode 100644 index 00000000..e6d5fc38 --- /dev/null +++ b/src/C/mol_graph_to_Data.c @@ -0,0 +1,390 @@ +#include "main.h" +#include + +void memsetf(float *ptr, float value, size_t num) { + for (size_t i = 0; i < num; i++) { + ptr[i] = value; + } +} + +#define RETURN_C_DATA 1 +static const char *mol_Data_names[] = {"x", + "edge_index", + "edge_attr", + "non_edge_index", + "stop_mask", + "add_node_mask", + "set_node_attr_mask", + "add_edge_mask", + "set_edge_attr_mask", + "remove_node_mask", + "remove_node_attr_mask", + "remove_edge_mask", + "remove_edge_attr_mask", + "cond_info"}; + +static PyObject *torch_module = NULL; +static PyObject *torch_gd_module = NULL; +void _check_torch() { + if (torch_module == NULL) { + torch_module = PyImport_ImportModule("torch"); + torch_gd_module = PyImport_ImportModule("torch_geometric.data"); + } +} + +PyObject *mol_graph_to_Data(PyObject *self, PyObject *args) { + PyObject *mol_graph = NULL, *ctx = NULL, *_torch_module = NULL, *cond_info = NULL; + if (PyArg_ParseTuple(args, "OOO|O", &mol_graph, &ctx, &_torch_module, &cond_info) == 0) { + return NULL; + } + _check_torch(); + if (PyObject_TypeCheck(mol_graph, &GraphType) == 0) { + PyErr_SetString(PyExc_TypeError, "mol_graph must be a Graph"); + return NULL; + } + if (cond_info == Py_None) { + cond_info = NULL; + } + Graph *g = (Graph *)mol_graph; + GraphDef *gd = (GraphDef *)g->graph_def; + PyObject *atom_types = PyDict_GetItemString(gd->node_values, "v"); // borrowed ref + int num_atom_types = PySequence_Size(atom_types); + int atom_valences[num_atom_types]; + float used_valences[g->num_nodes]; + int max_valence[g->num_nodes]; + PyObject *_max_atom_valence = PyObject_GetAttrString(ctx, "_max_atom_valence"); // new ref + for (int i = 0; i < num_atom_types; i++) { + atom_valences[i] = PyLong_AsLong(PyDict_GetItem(_max_atom_valence, PyList_GetItem(atom_types, i))); // borrowed + } + Py_DECREF(_max_atom_valence); + int v_val[g->num_nodes]; + int v_idx = PyLong_AsLong(PyDict_GetItemString(gd->node_keypos, "v")); + int charge_val[g->num_nodes]; + int charge_idx = PyLong_AsLong(PyDict_GetItemString(gd->node_keypos, "charge")); + int explH_val[g->num_nodes]; + int explH_idx = PyLong_AsLong(PyDict_GetItemString(gd->node_keypos, "expl_H")); + int chi_val[g->num_nodes]; + int chi_idx = PyLong_AsLong(PyDict_GetItemString(gd->node_keypos, "chi")); + int noImpl_val[g->num_nodes]; + int noImpl_idx = PyLong_AsLong(PyDict_GetItemString(gd->node_keypos, "no_impl")); + int num_set_attribs[g->num_nodes]; + PyObject *N_str = PyUnicode_FromString("N"); + Py_ssize_t nitro_attr_value = PySequence_Index(PyDict_GetItemString(gd->node_values, "v"), N_str); + Py_DECREF(N_str); + for (int i = 0; i < g->num_nodes; i++) { + used_valences[i] = max_valence[i] = v_val[i] = charge_val[i] = explH_val[i] = chi_val[i] = noImpl_val[i] = + num_set_attribs[i] = 0; + } + for (int i = 0; i < g->num_node_attrs; i++) { + int node_pos = g->node_attrs[3 * i]; + int attr_type = g->node_attrs[3 * i + 1]; + int attr_value = g->node_attrs[3 * i + 2]; + num_set_attribs[node_pos] += 1; + if (attr_type == v_idx) { + v_val[node_pos] = attr_value; + max_valence[node_pos] += atom_valences[v_val[node_pos]]; + } else if (attr_type == charge_idx) { + charge_val[node_pos] = attr_value; + max_valence[node_pos] -= 1; // If we change the possible charge ranges from [0,1,-1] this won't work + } else if (attr_type == explH_idx) { + explH_val[node_pos] = attr_value; + max_valence[node_pos] -= attr_value; + } else if (attr_type == chi_idx) { + chi_val[node_pos] = attr_value; + } else if (attr_type == noImpl_idx) { + noImpl_val[node_pos] = attr_value; + } + } + + // bonds are the only edge attributes + char has_connecting_edge_attr_set[g->num_nodes]; + memset(has_connecting_edge_attr_set, 0, g->num_nodes); + float bond_valence[] = {1, 2, 3, 1.5}; // single, double, triple, aromatic + int bond_val[g->num_edges]; + memset(bond_val, 0, g->num_edges * sizeof(int)); + for (int i = 0; i < g->num_edge_attrs; i++) { + int edge_pos = g->edge_attrs[3 * i]; + int attr_type = g->edge_attrs[3 * i + 1]; + int attr_value = g->edge_attrs[3 * i + 2]; + if (attr_type == 0) { // this should always be true, but whatever + bond_val[edge_pos] = attr_value; + } + has_connecting_edge_attr_set[g->edges[2 * edge_pos]] = 1; + has_connecting_edge_attr_set[g->edges[2 * edge_pos + 1]] = 1; + } + for (int i = 0; i < g->num_edges; i++) { + int u = g->edges[2 * i]; + int v = g->edges[2 * i + 1]; + used_valences[u] += bond_valence[bond_val[i]]; + used_valences[v] += bond_valence[bond_val[i]]; + } + // Correct for the valence of charge Nitro atoms + for (int i = 0; i < g->num_nodes; i++) { + if (v_val[i] == nitro_attr_value && charge_val[i] == 1) { + max_valence[i] += 2; + } + } + // Compute creatable edges + char can_create_edge[g->num_nodes][g->num_nodes]; + for (int i = 0; i < g->num_nodes; i++) { + for (int j = 0; j < g->num_nodes; j++) { + can_create_edge[i][j] = + (used_valences[i] + 1 <= max_valence[i]) && (used_valences[j] + 1 <= max_valence[j]); + } + } + for (int i = 0; i < g->num_edges; i++) { + int u = g->edges[2 * i]; + int v = g->edges[2 * i + 1]; + can_create_edge[u][v] = can_create_edge[v][u] = 0; + } + int num_creatable_edges = 0; + for (int i = 0; i < g->num_nodes; i++) { + for (int j = i + 1; j < g->num_nodes; j++) { + num_creatable_edges += can_create_edge[i][j]; + } + } + PyObject *max_nodes_py = PyObject_GetAttrString(ctx, "max_nodes"); // new ref + int max_nodes = max_nodes_py == Py_None ? 1000 : PyLong_AsLong(max_nodes_py); + Py_DECREF(max_nodes_py); + + PyObject *cond_info_shape = cond_info == NULL ? NULL : PyObject_CallMethod(cond_info, "size", NULL); + int num_cond_info_dims = cond_info == NULL ? 0 : PyLong_AsLong(PyTuple_GetItem(cond_info_shape, 1)); + Py_XDECREF(cond_info_shape); + int node_feat_shape = maxi(1, g->num_nodes); + int is_float[] = {1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + int shapes[14][2] = { + {node_feat_shape, gd->num_node_dim}, // node_feat + {2, g->num_edges * 2}, // edge_index + {2 * g->num_edges, gd->num_edge_dim}, // edge_feat + {2, num_creatable_edges}, // non_edge_index + {1, 1}, // stop_mask + {node_feat_shape, gd->num_new_node_values}, // add_node_mask + {node_feat_shape, gd->num_node_attr_logits}, // set_node_attr_mask + {num_creatable_edges, 1}, // add_edge_mask + {g->num_edges, gd->num_edge_attr_logits}, // set_edge_attr_mask + {node_feat_shape, 1}, // remove_node_mask + {node_feat_shape, gd->num_settable_node_attrs}, // remove_node_attr_mask + {g->num_edges, 1}, // remove_edge_mask + {g->num_edges, gd->num_settable_edge_attrs}, // remove_edge_attr_mask + {1, num_cond_info_dims}, // cond_info + }; + int offsets[14]; + Py_ssize_t num_items = 0; + for (int i = 0; i < 14; i++) { + offsets[i] = num_items; + // we need twice the space for longs + num_items += shapes[i][0] * shapes[i][1] * (2 - is_float[i]); + } + // Allocate the memory for the Data object in a way we can return it to Python + // (and let Python free it when it's done) + PyObject *data = PyByteArray_FromStringAndSize(NULL, num_items * sizeof(float)); + int *dataptr = (int *)PyByteArray_AsString(data); + memset(dataptr, 0, num_items * sizeof(float)); + if (PyErr_Occurred()) { + raise(SIGINT); + return NULL; + } + + float *node_feat = (float *)(dataptr + offsets[0]); + long *edge_index = (long *)(dataptr + offsets[1]); + float *edge_feat = (float *)(dataptr + offsets[2]); + long *non_edge_index = (long *)(dataptr + offsets[3]); + float *stop_mask = (float *)(dataptr + offsets[4]); + float *add_node_mask = (float *)(dataptr + offsets[5]); + float *set_node_attr_mask = (float *)(dataptr + offsets[6]); + float *add_edge_mask = (float *)(dataptr + offsets[7]); + float *set_edge_attr_mask = (float *)(dataptr + offsets[8]); + float *remove_node_mask = (float *)(dataptr + offsets[9]); + float *remove_node_attr_mask = (float *)(dataptr + offsets[10]); + float *remove_edge_mask = (float *)(dataptr + offsets[11]); + float *remove_edge_attr_mask = (float *)(dataptr + offsets[12]); + float *cond_info_data = (float *)(dataptr + offsets[13]); + + int bridges[g->num_edges]; + // Graph_brigdes' second argument is expected to be NULL when called by Python + // we're using it to pass the bridges array instead + Graph_bridges((PyObject *)g, (PyObject *)bridges); + if (PyErr_Occurred()) { + raise(SIGINT); + return NULL; + } + + // sorted attrs is 'charge', 'chi', 'expl_H', 'no_impl', 'v' + int *_node_attrs[5] = {charge_val, chi_val, explH_val, noImpl_val, v_val}; + int *_edge_attrs[1] = {bond_val}; + if (g->num_nodes == 0) { + node_feat[gd->num_node_dim - 1] = 1; + memsetf(add_node_mask, 1, gd->num_new_node_values); + remove_node_mask[0] = 1; + } + /*for (int i=0;i<5;i++){ + printf("%d %d\n", gd->node_attr_offsets[i], gd->node_attr_offsets[i]-i); + }*/ + for (int i = 0; i < g->num_nodes; i++) { + if (g->degrees[i] <= 1 && !has_connecting_edge_attr_set[i] && num_set_attribs[i] == 1) { + remove_node_mask[i] = 1; + } + int is_nitro = v_val[i] == nitro_attr_value; + for (int j = 0; j < 5; j++) { + int one_hot_idx = gd->node_attr_offsets[j] + _node_attrs[j][i]; + node_feat[i * gd->num_node_dim + one_hot_idx] = 1; + int logit_slice_start = gd->node_attr_offsets[j] - j; + int logit_slice_end = gd->node_attr_offsets[j + 1] - j - 1; + if (logit_slice_end == logit_slice_start) + continue; // we cannot actually set this attribute + if (j == v_idx) + continue; // we cannot remove nor set 'v' + // printf("Node attr %d %d = %d\n", j, i, _node_attrs[j][i]); + if (_node_attrs[j][i] > 0) { + // Special case: if the node is a positively charged Nitrogen, and the valence is maxed out (i.e. 5) + // the agent cannot remove the charge, it has to remove the bonds (or bond attrs) making this valence + // maxed out first. + if (j == charge_idx && is_nitro && used_valences[i] >= max_valence[i] && charge_val[i] == 1) + continue; + remove_node_attr_mask[i * gd->num_settable_node_attrs + j] = 1; + } else { + // Special case for N: adding charge to N increases its valence by 2, so we don't want to prevent that + // action, even if the max_valence is "full" (for v=3) + if (j == charge_idx && used_valences[i] >= (max_valence[i] + (is_nitro ? 2 : 0))) // charge + continue; + if (j == explH_idx && used_valences[i] >= max_valence[i]) // expl_H + continue; + // Following on the special case, we made it here, now if the node is not yet charged but already + // has a valence of 3, the only charge we can add is a +1, which is the 0th logit + // (again this assumes charge ranges are [0,1,-1]) + if (j == charge_idx && is_nitro && used_valences[i] >= max_valence[i] && charge_val[i] == 0){ + // printf("Setting 1: %d\n", i * gd->num_node_attr_logits + logit_slice_start); + set_node_attr_mask[i * gd->num_node_attr_logits + logit_slice_start] = 1; + continue; + } + // printf("Setting: %d %d, %d %d\n", j, i * gd->num_node_attr_logits + logit_slice_start, logit_slice_end, logit_slice_start); + memsetf(set_node_attr_mask + i * gd->num_node_attr_logits + logit_slice_start, 1, + (logit_slice_end - logit_slice_start)); + } + } + if (used_valences[i] < max_valence[i] && g->num_nodes < max_nodes) { + memsetf(add_node_mask + i * gd->num_new_node_values, 1, gd->num_new_node_values); + } + } + if (PyErr_Occurred()) { + raise(SIGINT); + return NULL; + } + for (int i = 0; i < g->num_edges; i++) { + if (bridges[i] == 0) { + remove_edge_mask[i] = 1; + } + int j = 0; // well there's only the bond attr + int one_hot_idx = gd->edge_attr_offsets[j] + _edge_attrs[j][i]; + edge_feat[2 * i * gd->num_edge_dim + one_hot_idx] = 1; + edge_feat[(2 * i + 1) * gd->num_edge_dim + one_hot_idx] = 1; + int logit_slice_start = gd->edge_attr_offsets[j] - j; + if (_edge_attrs[j][i] > 0) { + remove_edge_attr_mask[i * gd->num_settable_edge_attrs + j] = 1; + } else { + // TODO: we're not using aromatics here, instead single-double-triple is hardcoded + // k starts at 1 because the default value (0) is the single bond + for (int k = 1; k < 3; k++) { + // bond_valence - 1 because the 0th value is the single bond which has valence 1, and we'd be "removing" + // it to replace it with another bond + if ((used_valences[g->edges[2 * i]] + bond_valence[k] - 1 > max_valence[g->edges[2 * i]]) || + (used_valences[g->edges[2 * i + 1]] + bond_valence[k] - 1 > max_valence[g->edges[2 * i + 1]])) + continue; + + // use k - 1 because the default value (0) doesn't have an associated logit + set_edge_attr_mask[i * gd->num_edge_attr_logits + logit_slice_start + k - 1] = 1; + } + } + edge_index[2 * i] = g->edges[2 * i]; + edge_index[2 * i + 1] = g->edges[2 * i + 1]; + edge_index[2 * i + g->num_edges * 2] = g->edges[2 * i + 1]; + edge_index[2 * i + g->num_edges * 2 + 1] = g->edges[2 * i]; + } + // already filtered out for valence and such + memsetf(add_edge_mask, 1, num_creatable_edges); + int non_edge_idx_idx = 0; + for (int i = 0; i < g->num_nodes; i++) { + for (int j = i + 1; j < g->num_nodes; j++) { + if (!can_create_edge[i][j]) + continue; + non_edge_index[non_edge_idx_idx] = i; + non_edge_index[num_creatable_edges + non_edge_idx_idx] = j; + non_edge_idx_idx++; + } + } + if (PyErr_Occurred()) { + raise(SIGINT); + return NULL; + } + if (cond_info != NULL) { + PyObject *cpu_cond_info = PyObject_CallMethod(cond_info, "cpu", ""); + PyObject *cont_cond_info = PyObject_CallMethod(cpu_cond_info, "contiguous", ""); + PyObject *ptr = PyObject_CallMethod(cont_cond_info, "data_ptr", ""); + float *tensor_ptr = PyLong_AsVoidPtr(ptr); + + for (int i = 0; i < num_cond_info_dims; i++) { + cond_info_data[i] = tensor_ptr[i]; + } + Py_DECREF(ptr); + Py_DECREF(cpu_cond_info); + Py_DECREF(cont_cond_info); + } + + if (PyErr_Occurred()) { + PyErr_Print(); + raise(SIGINT); + return NULL; + } + *stop_mask = g->num_nodes > 0 ? 1 : 0; + if (RETURN_C_DATA) { + // Data *data_obj = DataType.tp_new(&DataType, NULL, NULL); + Data *data_obj = (Data *)PyObject_CallObject((PyObject *)&DataType, NULL); + Data_init_C(data_obj, data, g->graph_def, shapes, is_float, 13 + (cond_info != NULL), mol_Data_names); + Py_DECREF(data); + return (PyObject *)data_obj; + } + // The following lines take about 80% of the runtime of the function on ~50 node graphs :"( + PyObject *res = PyDict_New(); + PyObject *frombuffer = PyObject_GetAttrString(torch_module, "frombuffer"); + PyObject *empty = PyObject_GetAttrString(torch_module, "empty"); + PyObject *dtype_f32 = PyObject_GetAttrString(torch_module, "float32"); + PyObject *dtype_i64 = PyObject_GetAttrString(torch_module, "int64"); + PyObject *fb_args = PyTuple_Pack(1, data); + PyObject *fb_kwargs = PyDict_New(); + int do_del_kw = 0; + for (int i = 0; i < 13; i++) { + int i_num_items = shapes[i][0] * shapes[i][1]; + PyObject *tensor; + PyDict_SetItemString(fb_kwargs, "dtype", is_float[i] ? dtype_f32 : dtype_i64); + if (i_num_items == 0) { + if (do_del_kw) { + PyDict_DelItemString(fb_kwargs, "offset"); + PyDict_DelItemString(fb_kwargs, "count"); + do_del_kw = 0; + } + PyObject *zero = PyLong_FromLong(0); + tensor = PyObject_Call(empty, PyTuple_Pack(1, zero), fb_kwargs); + Py_DECREF(zero); + } else { + PyObject *py_offset = PyLong_FromLong(offsets[i] * sizeof(float)); + PyObject *py_count = PyLong_FromLong(i_num_items); + PyDict_SetItemString(fb_kwargs, "offset", py_offset); + PyDict_SetItemString(fb_kwargs, "count", py_count); + do_del_kw = 1; + tensor = PyObject_Call(frombuffer, fb_args, fb_kwargs); + Py_DECREF(py_offset); + Py_DECREF(py_count); + } + PyObject *reshaped_tensor = PyObject_CallMethod(tensor, "view", "ii", shapes[i][0], shapes[i][1]); + PyDict_SetItemString(res, mol_Data_names[i], reshaped_tensor); + Py_DECREF(tensor); + Py_DECREF(reshaped_tensor); + } + Py_DECREF(frombuffer); + Py_DECREF(dtype_f32); + Py_DECREF(dtype_i64); + Py_DECREF(fb_args); + Py_DECREF(fb_kwargs); + Py_DECREF(data); + return res; +} diff --git a/src/C/node_view.c b/src/C/node_view.c new file mode 100644 index 00000000..208b4bea --- /dev/null +++ b/src/C/node_view.c @@ -0,0 +1,244 @@ +#define PY_SSIZE_T_CLEAN +#include "./main.h" +#include "structmember.h" +#include + +static void NodeView_dealloc(NodeView *self) { + Py_XDECREF(self->graph); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static PyObject *NodeView_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + NodeView *self; + self = (NodeView *)type->tp_alloc(type, 0); + if (self != NULL) { + self->graph = NULL; + } + return (PyObject *)self; +} + +static int NodeView_init(NodeView *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"graph", "index", NULL}; + PyObject *graph = NULL, *tmp; + int node_id = -1; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|i", kwlist, &graph, &node_id)) + return -1; + + if (graph) { + tmp = (PyObject *)self->graph; + Py_INCREF(graph); + self->graph = (Graph *)graph; + Py_XDECREF(tmp); + self->index = node_id; + if (node_id >= 0) { + int node_exists = 0; + for (int i = 0; i < self->graph->num_nodes; i++) { + if (self->graph->nodes[i] == node_id) { + node_exists = 1; + } + } + if (!node_exists) { + PyErr_SetString(PyExc_KeyError, "Trying to create a view with a node that does not exist"); + return -1; + } + } + } + return 0; +} + +static PyMemberDef NodeView_members[] = { + {NULL} /* Sentinel */ +}; + +static int NodeView_setitem(PyObject *_self, PyObject *k, PyObject *v) { + NodeView *self = (NodeView *)_self; + if (self->index < 0) { + PyErr_SetString(PyExc_KeyError, "Cannot assign to a node"); + return -1; + } + PyObject *r = Graph_setnodeattr(self->graph, self->index, k, v); + Py_DECREF(r); + return r == NULL ? -1 : 0; +} + +static PyObject *NodeView_getitem(PyObject *_self, PyObject *k) { + NodeView *self = (NodeView *)_self; + if (self->index < 0) { + PyObject *args = PyTuple_Pack(2, self->graph, k); + PyObject *res = PyObject_CallObject((PyObject *)&NodeViewType, args); + Py_DECREF(args); + return res; + } + return Graph_getnodeattr(self->graph, self->index, k); +} + +static int NodeView_contains(PyObject *_self, PyObject *v) { + NodeView *self = (NodeView *)_self; + if (self->index < 0) { + // Graph_containsnode + if (!PyLong_Check(v)) { + PyErr_SetString(PyExc_TypeError, "NodeView.__contains__ only accepts integers"); + return -1; + } + int index = PyLong_AsLong(v); + for (int i = 0; i < self->graph->num_nodes; i++) { + if (self->graph->nodes[i] == index) { + return 1; + } + } + return 0; + } + // Graph_nodehasattr + PyObject *attr = Graph_getnodeattr(self->graph, self->index, v); + if (attr == NULL) { + PyErr_Clear(); + return 0; + } + Py_DECREF(attr); + return 1; +} + +static Py_ssize_t NodeView_len(PyObject *_self) { + NodeView *self = (NodeView *)_self; + if (self->index != -1) { + // This is inefficient, much like a lot of this codebase... but it's in C. For our graph sizes + // it's not a big deal (and unsurprisingly, it's fast) + Py_ssize_t num_attrs = 0; + for (int i = 0; i < self->graph->num_node_attrs; i++) { + if (self->graph->node_attrs[3 * i] == self->index) { + num_attrs++; + } + } + return num_attrs; + } + return self->graph->num_nodes; +} + +PyObject *NodeView_iter(PyObject *_self) { + NodeView *self = (NodeView *)_self; + if (self->index != -1) { + PyErr_SetString(PyExc_TypeError, "A bound NodeView is not iterable"); + return NULL; + } + Py_INCREF(_self); // We have to return a new reference, not a borrowed one + return _self; +} + +PyObject *NodeView_iternext(PyObject *_self) { + NodeView *self = (NodeView *)_self; + self->index++; + if (self->index >= self->graph->num_nodes) { + return NULL; + } + return PyLong_FromLong(self->graph->nodes[self->index]); +} + +PyObject *NodeView_get(PyObject *_self, PyObject *args) { + PyObject *key; + PyObject *default_value = Py_None; + if (!PyArg_ParseTuple(args, "O|O", &key, &default_value)) { + return NULL; + } + PyObject *res = NodeView_getitem(_self, key); + if (res == NULL) { + PyErr_Clear(); + Py_INCREF(default_value); + return default_value; + } + return res; +} + +int _NodeView_eq(NodeView *self, NodeView *other) { + if (self->index == -1 || other->index == -1) { + return -1; + } + + int num_settable = ((GraphDef*)self->graph->graph_def)->num_settable_node_attrs + 1; + int self_values[num_settable]; + int other_values[num_settable]; + for (int i = 0; i < num_settable; i++) { + self_values[i] = -1; + other_values[i] = -1; + } + for (int i = 0; i < self->graph->num_node_attrs; i++) { + if (self->graph->node_attrs[3 * i] == self->graph->nodes[self->index]) { + self_values[self->graph->node_attrs[3 * i + 1]] = self->graph->node_attrs[3 * i + 2]; + } + } + for (int i = 0; i < other->graph->num_node_attrs; i++) { + if (other->graph->node_attrs[3 * i] == other->graph->nodes[other->index]) { + other_values[other->graph->node_attrs[3 * i + 1]] = other->graph->node_attrs[3 * i + 2]; + } + } + for (int i = 0; i < num_settable; i++) { + // printf("%d %d %d\n", i, self_values[i], other_values[i]); + if (self_values[i] != other_values[i]) { + return 0; + } + } + return 1; +} + +static PyObject * +NodeView_richcompare(PyObject *_self, PyObject *other, int op) +{ + PyObject *result = NULL; + NodeView *self = (NodeView *)_self; + + if (other == NULL || other->ob_type != &NodeViewType) { + result = Py_NotImplemented; + } + else { + if (op == Py_EQ){ + if (self->index == -1 || ((NodeView *)other)->index == -1){ + PyErr_SetString(PyExc_TypeError, "NodeView.__eq__ only supports equality comparison for bound NodeViews"); + return NULL; + } + NodeView *other_node = (NodeView *)other; + if (_NodeView_eq(self, other_node) == 1){ + result = Py_True; + } else { + result = Py_False; + } + }else{ + PyErr_SetString(PyExc_TypeError, "NodeView only supports equality comparison"); + return NULL; + } + } + + Py_XINCREF(result); + return result; +} + +static PyMappingMethods NodeView_mapmeth = { + .mp_subscript = NodeView_getitem, + .mp_ass_subscript = NodeView_setitem, +}; + +static PySequenceMethods NodeView_seqmeth = { + .sq_contains = NodeView_contains, + .sq_length = NodeView_len, +}; + +static PyMethodDef NodeView_methods[] = { + {"get", (PyCFunction)NodeView_get, METH_VARARGS, "dict-like get"}, {NULL} /* Sentinel */ +}; + +PyTypeObject NodeViewType = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "_C.NodeView", + .tp_doc = PyDoc_STR("Constrained NodeView object"), + .tp_basicsize = sizeof(NodeView), + .tp_itemsize = 0, + //.tp_flags = Py_TPFLAGS_HAVE_RICHCOMPARE, //Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_new = NodeView_new, + .tp_init = (initproc)NodeView_init, + .tp_dealloc = (destructor)NodeView_dealloc, + .tp_members = NodeView_members, + .tp_methods = NodeView_methods, + .tp_as_mapping = &NodeView_mapmeth, + .tp_as_sequence = &NodeView_seqmeth, + .tp_iter = NodeView_iter, + .tp_iternext = NodeView_iternext, + .tp_richcompare = NodeView_richcompare, +}; \ No newline at end of file diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 1199dd31..b79fed0d 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -256,8 +256,9 @@ def not_done(lst): torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(n))] not_done_mask = torch.tensor(done, device=dev).logical_not() if model is not None: - ci = cond_info[not_done_mask] if cond_info is not None else None - _, bck_cat, *_ = model(self.ctx.collate(torch_graphs).to(dev), ci) + gbatch = self.ctx.collate(torch_graphs).to(dev) + gbatch.cond_info = cond_info[not_done_mask] if cond_info is not None else None + _, bck_cat, *_ = model(gbatch) else: gbatch = self.ctx.collate(torch_graphs) action_types = self.ctx.bck_action_type_order diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 8713cde3..06d021e0 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -583,8 +583,6 @@ def compute_batch_losses( tb_loss = traj_losses.mean() loss = tb_loss + reward_loss + self.cfg.n_loss_multiplier * n_loss info = { - "offline_loss": traj_losses[: batch.num_offline].mean() if batch.num_offline > 0 else 0, - "online_loss": traj_losses[batch.num_offline :].mean() if batch.num_online > 0 else 0, "reward_loss": reward_loss, "invalid_trajectories": invalid_mask.sum() / batch.num_online if batch.num_online > 0 else 0, "invalid_logprob": (invalid_mask * traj_log_p_F).sum() / (invalid_mask.sum() + 1e-4), @@ -596,7 +594,12 @@ def compute_batch_losses( "tb_loss": tb_loss.item(), "batch_entropy": -traj_log_p_F.mean(), "traj_lens": batch.traj_lens.float().mean(), + 'avg_log_reward': clip_log_R.mean(), } + sources = set(batch.sources) + if len(sources) > 1: + for source in sources: + info[f'{source}_loss'] = traj_losses[torch.as_tensor([i == source for i in batch.sources])].mean().item() if self.ctx.has_n() and self.cfg.do_predict_n: info["n_loss_pred"] = scatter( (log_n_preds - batch.log_ns) ** 2, batch_idx, dim=0, dim_size=num_trajs, reduce="sum" @@ -609,6 +612,8 @@ def compute_batch_losses( d[final_graph_idx] = 0 info["n_loss_maxent"] = scatter(d, batch_idx, dim=0, dim_size=num_trajs, reduce="sum").mean() + if not torch.isfinite(loss): + import pdb; pdb.set_trace() return loss, info def analytical_maxent_backward(self, batch, first_graph_idx): diff --git a/src/gflownet/config.py b/src/gflownet/config.py index 86b225f0..65733f64 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -107,6 +107,8 @@ class Config(StrictDataClass): git_hash: Optional[str] = None overwrite_existing_exp: bool = False mp_buffer_size: Optional[int] = None + world_size: int = 1 + rank: int = 0 algo: AlgoConfig = field(default_factory=AlgoConfig) model: ModelConfig = field(default_factory=ModelConfig) opt: OptimizerConfig = field(default_factory=OptimizerConfig) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index a6fdecb7..536e9ca1 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -99,6 +99,7 @@ def iterator(): cond_info = self.task.sample_conditional_information(num_samples, t) # TODO: in the cond info refactor, pass the whole thing instead of just the encoding trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info["encoding"], p) + self.mark_all(trajs, source='sample') self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) @@ -128,6 +129,7 @@ def iterator(): cond_info = self.task.sample_conditional_information(n_this_time, t) # TODO: in the cond info refactor, pass the whole thing instead of just the encoding trajs = self.algo.create_training_data_from_own_samples(model, n_this_time, cond_info["encoding"], p) + self.mark_all(trajs, source='sample') self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) @@ -143,6 +145,7 @@ def do_sample_replay(self, num_samples): def iterator(): while self.active: trajs, *_ = self.replay_buffer.sample(num_samples) + self.mark_all(trajs, source='replay') self.relabel_in_hindsight(trajs) # This is a no-op if the hindsight ratio is 0 yield trajs, {} @@ -157,6 +160,7 @@ def iterator(): cond_info = self.task.sample_conditional_information(num_samples, t) objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info["encoding"], p) + self.mark_all(trajs, source='dataset') self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.set_traj_props(trajs, props) self.compute_log_rewards(trajs) @@ -175,6 +179,7 @@ def iterator(): # it and the state of the program (e.g. validation mode) cond_info = self.task.encode_conditional_information(torch.stack([data[i] for i in idcs])) trajs = self.algo.create_training_data_from_own_samples(model, len(idcs), cond_info["encoding"], p) + self.mark_all(trajs, source='dataset') self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) @@ -197,6 +202,7 @@ def iterator(): cond_info = self.task.sample_conditional_information(num_samples, t) objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info["encoding"], p) + self.mark_all(trajs, source='dataset') self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.set_traj_props(trajs, props) self.compute_log_rewards(trajs) @@ -205,6 +211,10 @@ def iterator(): self.iterators.append(iterator) return self + def mark_all(self, trajs, **kw): + for t in trajs: + t.update(kw) + def call_sampling_hooks(self, trajs): batch_info = {} # TODO: just pass trajs to the hooks and deprecate passing all those arguments @@ -222,8 +232,7 @@ def create_batch(self, trajs, batch_info): ci = torch.stack([t["cond_info"]["encoding"] for t in trajs]) log_rewards = torch.stack([t["log_reward"] for t in trajs]) batch = self.algo.construct_batch(trajs, ci, log_rewards) - batch.num_online = sum(t.get("is_online", 0) for t in trajs) - batch.num_offline = len(trajs) - batch.num_online + batch.sources = [t.get("source", "unknown") for t in trajs] batch.extra_info = batch_info if "preferences" in trajs[0]["cond_info"].keys(): batch.preferences = torch.stack([t["cond_info"]["preferences"] for t in trajs]) diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 601d7bd5..63681154 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -140,7 +140,7 @@ class GraphBuildingEnv: - we can generate a legal action for any attribute that isn't a default one. """ - def __init__(self, allow_add_edge=True, allow_node_attr=True, allow_edge_attr=True): + def __init__(self, allow_add_edge=True, allow_node_attr=True, allow_edge_attr=True, graph_cls=None): """A graph building environment instance Parameters @@ -152,13 +152,16 @@ def __init__(self, allow_add_edge=True, allow_node_attr=True, allow_edge_attr=Tr if True, allows this action and computes SetNodeAttr parents allow_edge_attr: bool if True, allows this action and computes SetEdgeAttr parents + graph_cls: class + if None, defaults to Graph. Called in new(). """ self.allow_add_edge = allow_add_edge self.allow_node_attr = allow_node_attr self.allow_edge_attr = allow_edge_attr + self.graph_cls = Graph if graph_cls is None else graph_cls def new(self): - return Graph() + return self.graph_cls() def step(self, g: Graph, action: GraphAction) -> Graph: """Step forward the given graph state with an action @@ -178,12 +181,12 @@ def step(self, g: Graph, action: GraphAction) -> Graph: gp = g.copy() if action.action is GraphActionType.AddEdge: a, b = action.source, action.target - assert self.allow_add_edge - assert a in g and b in g + #assert self.allow_add_edge + #assert a in g and b in g if a > b: a, b = b, a - assert a != b - assert not g.has_edge(a, b) + #assert a != b + #assert not g.has_edge(a, b) # Ideally the FA underlying this must only be able to send # create_edge actions which respect this a Graph: elif action.action is GraphActionType.AddNode: if len(g) == 0: - assert action.source == 0 # TODO: this may not be useful + #assert action.source == 0 # TODO: this may not be useful gp.add_node(0, v=action.value) else: - assert action.source in g.nodes + #assert action.source in g.nodes e = [action.source, max(g.nodes) + 1] # if kw and 'relabel' in kw: # e[1] = kw['relabel'] # for `parent` consistency, allow relabeling - assert not g.has_edge(*e) + #assert not g.has_edge(*e) gp.add_node(e[1], v=action.value) gp.add_edge(*e) elif action.action is GraphActionType.SetNodeAttr: - assert self.allow_node_attr - assert action.source in gp.nodes + #assert self.allow_node_attr + #assert action.source in gp.nodes # For some "optional" attributes like wildcard atoms, we indicate that they haven't been # chosen by the 'None' value. Here we make sure that either the attribute doesn't # exist, or that it's an optional attribute that hasn't yet been set. - assert action.attr not in gp.nodes[action.source] or gp.nodes[action.source][action.attr] is None + #assert action.attr not in gp.nodes[action.source] or gp.nodes[action.source][action.attr] is None gp.nodes[action.source][action.attr] = action.value elif action.action is GraphActionType.SetEdgeAttr: - assert self.allow_edge_attr - assert g.has_edge(action.source, action.target) - assert action.attr not in gp.edges[(action.source, action.target)] + #assert self.allow_edge_attr + #assert g.has_edge(action.source, action.target) + #assert action.attr not in gp.edges[(action.source, action.target)] gp.edges[(action.source, action.target)][action.attr] = action.value elif action.action is GraphActionType.RemoveNode: - assert g.has_node(action.source) + #assert g.has_node(action.source) gp = graph_without_node(gp, action.source) elif action.action is GraphActionType.RemoveNodeAttr: - assert g.has_node(action.source) + #assert g.has_node(action.source) gp = graph_without_node_attr(gp, action.source, action.attr) elif action.action is GraphActionType.RemoveEdge: - assert g.has_edge(action.source, action.target) + #assert g.has_edge(action.source, action.target) gp = graph_without_edge(gp, (action.source, action.target)) elif action.action is GraphActionType.RemoveEdgeAttr: - assert g.has_edge(action.source, action.target) + #assert g.has_edge(action.source, action.target) gp = graph_without_edge_attr(gp, (action.source, action.target), action.attr) else: raise ValueError(f"Unknown action type {action.action}", action.action) - gp.clear_cache() # Invalidate cached properties since we've modified the graph + # gp.clear_cache() # Invalidate cached properties since we've modified the graph return gp def parents(self, g: Graph): @@ -315,17 +318,22 @@ def count_backward_transitions(self, g: Graph, check_idempotent: bool = False): return len(self.parents(g)) c = 0 deg = [g.degree[i] for i in range(len(g.nodes))] + has_connected_edge_attr = [False] * len(g.nodes) + bridges = g.bridges() for a, b in g.edges: if deg[a] > 1 and deg[b] > 1 and len(g.edges[(a, b)]) == 0: # Can only remove edges connected to non-leaves and without # attributes (the agent has to remove the attrs, then remove # the edge). Removal cannot disconnect the graph. - new_g = graph_without_edge(g, (a, b)) - if nx.algorithms.is_connected(new_g): + if (a, b) not in bridges and (b, a) not in bridges: c += 1 - c += len(g.edges[(a, b)]) # One action per edge attr + num_attrs = len(g.edges[(a, b)]) + c += num_attrs # One action per edge attr + if num_attrs > 0: + has_connected_edge_attr[a] = True + has_connected_edge_attr[b] = True for i in g.nodes: - if deg[i] == 1 and len(g.nodes[i]) == 1 and len(g.edges[list(g.edges(i))[0]]) == 0: + if deg[i] == 1 and len(g.nodes[i]) == 1 and not has_connected_edge_attr[i]: c += 1 c += len(g.nodes[i]) - 1 # One action per node attr, except 'v' if len(g.nodes) == 1 and len(g.nodes[i]) == 1: diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index e24ff8f2..9ef30fd0 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -13,6 +13,14 @@ DEFAULT_CHIRAL_TYPES = [ChiralType.CHI_UNSPECIFIED, ChiralType.CHI_TETRAHEDRAL_CW, ChiralType.CHI_TETRAHEDRAL_CCW] +try: + from gflownet._C import mol_graph_to_Data, Graph as C_Graph, GraphDef, Data_collate + + C_Graph_available = True +except ImportError: + import warnings + warnings.warn("Could not import mol_graph_to_Data, Graph, GraphDef from _C, using pure python implementation") + C_Graph_available = False class MolBuildingEnvContext(GraphBuildingEnvContext): """A specification of what is being generated for a GraphBuildingEnv @@ -61,12 +69,12 @@ def __init__( """ # idx 0 has to coincide with the default value self.atom_attr_values = { - "v": atoms + ["*"], + "v": atoms, "chi": chiral_types, "charge": charges, "expl_H": expl_H_range, "no_impl": [False, True], - "fill_wildcard": [None] + atoms, # default is, there is nothing + # "fill_wildcard": [None] + atoms, # default is, there is nothing } self.num_rw_feat = num_rw_feat self.max_nodes = max_nodes @@ -157,6 +165,15 @@ def __init__( GraphActionType.RemoveEdge, GraphActionType.RemoveEdgeAttr, ] + if C_Graph_available: + self.graph_def = GraphDef(self.atom_attr_values, self.bond_attr_values) + self.graph_cls = self._make_C_graph + assert charges == [0, 1, -1], 'C impl quirk: charges must be [0, 1, -1]' + else: + self.graph_cls = Graph + + def _make_C_graph(self): + return C_Graph(self.graph_def) def ActionIndex_to_GraphAction(self, g: gd.Data, aidx: ActionIndex, fwd: bool = True): """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction""" @@ -164,6 +181,10 @@ def ActionIndex_to_GraphAction(self, g: gd.Data, aidx: ActionIndex, fwd: bool = t = self.action_type_order[aidx.action_type] else: t = self.bck_action_type_order[aidx.action_type] + + if self.graph_cls is not Graph: + return g.mol_aidx_to_GraphAction((aidx.action_type, aidx.row_idx, aidx.col_idx), t) + if t is GraphActionType.Stop: return GraphAction(t) elif t is GraphActionType.AddNode: @@ -203,6 +224,9 @@ def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionI else: raise ValueError(f"Unknown action type {action.action}") + if self.graph_cls is not Graph: + return (type_idx,) + g.mol_GraphAction_to_aidx(action) + if action.action is GraphActionType.Stop: row = col = 0 elif action.action is GraphActionType.AddNode: @@ -255,6 +279,10 @@ def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionI def graph_to_Data(self, g: Graph) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance""" + if self.graph_cls is not Graph: + cond_info = None # Todo: Implement this + return mol_graph_to_Data(g, self, torch, cond_info) + x = np.zeros((max(1, len(g.nodes)), self.num_node_dim - self.num_rw_feat), dtype=np.float32) x[0, -1] = len(g.nodes) == 0 add_node_mask = np.ones((x.shape[0], self.num_new_node_values), dtype=np.float32) @@ -394,11 +422,13 @@ def graph_to_Data(self, g: Graph) -> gd.Data: def collate(self, graphs: List[gd.Data]): """Batch Data instances""" + if self.graph_cls is not Graph: + return Data_collate(graphs, ["edge_index", "non_edge_index"]) return gd.Batch.from_data_list(graphs, follow_batch=["edge_index", "non_edge_index"]) def obj_to_graph(self, mol: Mol) -> Graph: """Convert an RDMol to a Graph""" - g = Graph() + g = self.graph_cls() mol = Mol(mol) # Make a copy if not self.allow_explicitly_aromatic: # If we disallow aromatic bonds, ask rdkit to Kekulize mol and remove aromatic bond flags diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index b84980dc..6807dc5b 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -240,6 +240,7 @@ def __init__( self.emb2graph_out = mlp(num_glob_final, num_emb, num_graph_out, cfg.model.graph_transformer.num_mlp_layers) # TODO: flag for this self._logZ = mlp(max(1, env_ctx.num_cond_dim), num_emb * 2, 1, 2) + self.logit_scaler = mlp(max(1, env_ctx.num_cond_dim), num_emb * 2, 1, 2) def logZ(self, cond_info: Optional[torch.Tensor]): if cond_info is None: @@ -247,13 +248,16 @@ def logZ(self, cond_info: Optional[torch.Tensor]): return self._logZ(cond_info) def _make_cat(self, g: gd.Batch, emb: Dict[str, Tensor], action_types: list[GraphActionType]): - return GraphActionCategorical( + cat = GraphActionCategorical( g, raw_logits=[self.mlps[t.cname](emb[self._action_type_to_graph_part[t]]) for t in action_types], keys=[self._action_type_to_key[t] for t in action_types], action_masks=[action_type_to_mask(t, g) for t in action_types], types=action_types, ) + sc = self.logit_scaler(g.cond_info if g.cond_info is not None else torch.ones((g.num_graphs, 1), device=g.x.device)) + cat.logits = [l * sc[b] for l, b in zip(cat.raw_logits, cat.batch)] # Setting .logits masks them + return cat def forward(self, g: gd.Batch): node_embeddings, graph_embeddings = self.transf(g) diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 4d90fa3a..73cb07a2 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -8,6 +8,7 @@ import numpy as np import torch +import torch.distributed as dist import torch.nn as nn import torch.utils.tensorboard import torch_geometric.data as gd @@ -72,7 +73,7 @@ def __init__(self, config: Config, print_config=True): ) # make sure the config is a Config object, and not the Config class itself self.cfg: Config = OmegaConf.merge(self.default_cfg, config) - self.device = torch.device(self.cfg.device) + self.device = torch.device(self.cfg.rank or self.cfg.device) set_main_process_device(self.device) # Print the loss every `self.print_every` iterations self.print_every = self.cfg.print_every @@ -82,6 +83,9 @@ def __init__(self, config: Config, print_config=True): # Will check if parameters are finite at every iteration (can be costly) self._validate_parameters = False + self.world_size = self.cfg.world_size + self.rank = self.cfg.rank + self.setup() def set_default_hps(self, base: Config): @@ -219,6 +223,7 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: tick = time.time() self.model.train() try: + self.model.lock.acquire() loss = info = None loss, info = self.algo.compute_batch_losses(self.model, batch) if not torch.isfinite(loss): @@ -227,6 +232,7 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: self.algo.step() # This also isn't used anywhere? if self._validate_parameters and not all([torch.isfinite(i).all() for i in self.model.parameters()]): raise ValueError("parameters are not finite") + self.model.lock.release() except ValueError as e: os.makedirs(self.cfg.log_dir, exist_ok=True) torch.save([self.model.state_dict(), batch, loss, info], open(self.cfg.log_dir + "/dump.pkl", "wb")) @@ -258,6 +264,22 @@ def _maybe_resolve_shared_buffer( batch = batch.to(self.device) return batch + def _send_models_to_device(self): + self.model.to(self.device) + self.sampling_model.to(self.device) + if self.world_size > 1: + self.model = DistributedDataParallel( + self.model.to(self.rank), + device_ids=[self.rank], + output_device=self.rank + ) + if self.sampling_model is not self.model: + self.sampling_model = DistributedDataParallel( + self.sampling_model.to(self.rank), + device_ids=[self.rank], + output_device=self.rank + ) + def run(self, logger=None): """Trains the GFN for `num_training_steps` minibatches, performing validation every `validate_every` minibatches. @@ -266,6 +288,8 @@ def run(self, logger=None): logger = create_logger(logfile=self.cfg.log_dir + "/train.log") self.model.to(self.device) self.sampling_model.to(self.device) + import threading + self.model.lock = threading.Lock() # This is created here because you can't pickle a lock, and model is deepcopied -> sampling_model epoch_length = max(len(self.training_data), 1) valid_freq = self.cfg.validate_every # If checkpoint_every is not specified, checkpoint at every validation epoch @@ -294,13 +318,23 @@ def run(self, logger=None): f"iteration {it} : warming up replay buffer {len(self.replay_buffer)}/{self.replay_buffer.warmup}" ) continue + info = self.train_batch(batch, epoch_idx, batch_idx, it) info["time_spent"] = time.time() - start_time start_time = time.time() - self.log(info, it, "train") if it % self.print_every == 0: logger.info(f"iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) + if self.world_size > 1: + all_info_vals = torch.zeros(len(info)).to(self.rank) + for i, k in enumerate(sorted(info.keys())): + all_info_vals[i] = info[k] + dist.all_reduce(all_info_vals, op=dist.ReduceOp.SUM) + for i, k in enumerate(sorted(info.keys())): + info[k] = all_info_vals[i].item() / self.world_size + if self.rank == 0: + self.log(info, it, 'train') + if valid_freq > 0 and it % valid_freq == 0: for batch in valid_dl: batch = self._maybe_resolve_shared_buffer(batch, valid_dl) @@ -346,7 +380,7 @@ def run(self, logger=None): del final_dl def terminate(self): - logger = logging.getLogger("logger") + logger = logging.getLogger("gflownet") for handler in logger.handlers: handler.close() diff --git a/src/gflownet/utils/misc.py b/src/gflownet/utils/misc.py index f9d32df3..b6543fe3 100644 --- a/src/gflownet/utils/misc.py +++ b/src/gflownet/utils/misc.py @@ -5,7 +5,7 @@ import torch -def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHandle=True): +def create_logger(name="gflownet", loglevel=logging.INFO, logfile=None, streamHandle=True): logger = logging.getLogger(name) logger.setLevel(loglevel) while len([logger.removeHandler(i) for i in logger.handlers]): diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index 7ed734bd..86081038 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -8,6 +8,8 @@ import torch import torch.multiprocessing as mp +DO_PIN_BUFFERS = False # So actually, we don't really need to pin buffers, but it's a good idea in some cases. + # The shared memory is already a good step. class SharedPinnedBuffer: def __init__(self, size): @@ -17,15 +19,15 @@ def __init__(self, size): self.lock = mp.Lock() self.do_unreg = False - if not self.buffer.is_pinned(): + if DO_PIN_BUFFERS and not self.buffer.is_pinned(): # Sometimes torch will create an already pinned (page aligned) buffer, so we don't need to # pin it again; doing so will raise a CUDA error cudart = torch.cuda.cudart() r = cudart.cudaHostRegister(self.buffer.data_ptr(), self.buffer.numel() * self.buffer.element_size(), 0) assert r == 0 self.do_unreg = True # But then we need to unregister it later + assert self.buffer.is_pinned() assert self.buffer.is_shared() - assert self.buffer.is_pinned() def __del__(self): if torch.utils.data.get_worker_info() is None: @@ -265,6 +267,8 @@ def run(self): break timeouts = 0 attr, args, kwargs = r + if hasattr(self.obj, 'lock'): + f.lock.acquire() f = getattr(self.obj, attr) args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args] kwargs = {k: i.to(self.device) if isinstance(i, self.cuda_types) else i for k, i in kwargs.items()} @@ -287,6 +291,8 @@ def run(self): else: msg = self.to_cpu(result) self.out_queues[qi].put(self.encode(msg)) + if hasattr(self.obj, 'lock'): + f.lock.release() def terminate(self): self.stop.set() diff --git a/src/gflownet/utils/sqlite_log.py b/src/gflownet/utils/sqlite_log.py index ae544ec5..2fc6e113 100644 --- a/src/gflownet/utils/sqlite_log.py +++ b/src/gflownet/utils/sqlite_log.py @@ -3,7 +3,8 @@ from typing import Iterable import torch - +import torch.distributed +import torch.utils.data class SQLiteLogHook: def __init__(self, log_dir, ctx) -> None: @@ -16,6 +17,8 @@ def __call__(self, trajs, rewards, obj_props, cond_info): if self.log is None: worker_info = torch.utils.data.get_worker_info() self._wid = worker_info.id if worker_info is not None else 0 + if torch.distributed.is_initialized(): + self._wid = torch.distributed.get_rank() * (worker_info.num_workers if worker_info is not None else 1) + self._wid os.makedirs(self.log_dir, exist_ok=True) self.log_path = f"{self.log_dir}/generated_objs_{self._wid}.db" self.log = SQLiteLog() From 67f4b621ddb530ff3859da7dedc15f0b87eac313 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Wed, 28 Aug 2024 14:57:04 -0600 Subject: [PATCH 23/31] C mol valence fix, mask-backwards sample, MLE in TB, priority replay, Pad action, GPS architecture, some support for distributed DDP --- src/C/mol_graph_to_Data.c | 62 ++++++++-- src/gflownet/algo/config.py | 5 + src/gflownet/algo/graph_sampling.py | 7 +- src/gflownet/algo/trajectory_balance.py | 19 ++- src/gflownet/data/config.py | 2 + src/gflownet/data/data_source.py | 81 +++++++++--- src/gflownet/data/replay_buffer.py | 48 +++++++- src/gflownet/envs/graph_building_env.py | 20 ++- src/gflownet/envs/mol_building_env.py | 14 +++ src/gflownet/models/config.py | 3 +- src/gflownet/models/graph_transformer.py | 149 +++++++++++++++-------- src/gflownet/online_trainer.py | 26 +++- src/gflownet/trainer.py | 23 ++-- src/gflownet/utils/misc.py | 12 +- src/gflownet/utils/sqlite_log.py | 7 +- 15 files changed, 355 insertions(+), 123 deletions(-) diff --git a/src/C/mol_graph_to_Data.c b/src/C/mol_graph_to_Data.c index e6d5fc38..81e673f8 100644 --- a/src/C/mol_graph_to_Data.c +++ b/src/C/mol_graph_to_Data.c @@ -52,6 +52,7 @@ PyObject *mol_graph_to_Data(PyObject *self, PyObject *args) { int atom_valences[num_atom_types]; float used_valences[g->num_nodes]; int max_valence[g->num_nodes]; + int is_hypervalent[g->num_nodes]; PyObject *_max_atom_valence = PyObject_GetAttrString(ctx, "_max_atom_valence"); // new ref for (int i = 0; i < num_atom_types; i++) { atom_valences[i] = PyLong_AsLong(PyDict_GetItem(_max_atom_valence, PyList_GetItem(atom_types, i))); // borrowed @@ -71,9 +72,12 @@ PyObject *mol_graph_to_Data(PyObject *self, PyObject *args) { PyObject *N_str = PyUnicode_FromString("N"); Py_ssize_t nitro_attr_value = PySequence_Index(PyDict_GetItemString(gd->node_values, "v"), N_str); Py_DECREF(N_str); + PyObject *O_str = PyUnicode_FromString("O"); + Py_ssize_t oxy_attr_value = PySequence_Index(PyDict_GetItemString(gd->node_values, "v"), O_str); + Py_DECREF(O_str); for (int i = 0; i < g->num_nodes; i++) { used_valences[i] = max_valence[i] = v_val[i] = charge_val[i] = explH_val[i] = chi_val[i] = noImpl_val[i] = - num_set_attribs[i] = 0; + num_set_attribs[i] = is_hypervalent[i] = 0; } for (int i = 0; i < g->num_node_attrs; i++) { int node_pos = g->node_attrs[3 * i]; @@ -85,7 +89,6 @@ PyObject *mol_graph_to_Data(PyObject *self, PyObject *args) { max_valence[node_pos] += atom_valences[v_val[node_pos]]; } else if (attr_type == charge_idx) { charge_val[node_pos] = attr_value; - max_valence[node_pos] -= 1; // If we change the possible charge ranges from [0,1,-1] this won't work } else if (attr_type == explH_idx) { explH_val[node_pos] = attr_value; max_valence[node_pos] -= attr_value; @@ -118,11 +121,40 @@ PyObject *mol_graph_to_Data(PyObject *self, PyObject *args) { used_valences[u] += bond_valence[bond_val[i]]; used_valences[v] += bond_valence[bond_val[i]]; } - // Correct for the valence of charge Nitro atoms + + int is_any_hypervalent = 0; + // Correct for the valence of charge Nitro and Oxygen atoms for (int i = 0; i < g->num_nodes; i++) { - if (v_val[i] == nitro_attr_value && charge_val[i] == 1) { - max_valence[i] += 2; + if (charge_val[i] != 0){ + if ((v_val[i] == nitro_attr_value || v_val[i] == oxy_attr_value) + && charge_val[i] == 1) { + max_valence[i] += 1; + }else{ + max_valence[i] -= 1; + } } + + /* TODO: Figure out this charge thing... + wait or is the problem that I'm _removing_ valence for a + charged mol?? + else if (charge_val[i] == 1) { + // well... maybe this is true of the general case? + // if an atom is positively charged then we can just bond it to more stuff? + // why 2 for N? + max_valence[i] += 1; + }*/ + // Determine hypervalent atoms + // Hypervalent atoms are atoms that have more than the maximum valence for their atom type because of a positive + // charge, but don't have this extra charge filled with bonds, or don't have no_impl set to 1 + /*printf("%d used_valences: %f, max_valence: %d, charge_val: %d, noImpl_val: %d, atom_valences: %d -> %d\n", i, + used_valences[i], max_valence[i], charge_val[i], noImpl_val[i], atom_valences[v_val[i]], + (used_valences[i] >= atom_valences[v_val[i]] && charge_val[i] == 1 && + !(used_valences[i] == max_valence[i] || noImpl_val[i] == 1))); + fflush(stdout);*/ + /*if (used_valences[i] >= atom_valences[v_val[i]] && charge_val[i] == 1 && + !(used_valences[i] == max_valence[i] || noImpl_val[i] == 1)) { + is_hypervalent[i] = 1; + is_any_hypervalent = 1; + }*/ } // Compute creatable edges char can_create_edge[g->num_nodes][g->num_nodes]; @@ -224,7 +256,7 @@ PyObject *mol_graph_to_Data(PyObject *self, PyObject *args) { if (g->degrees[i] <= 1 && !has_connecting_edge_attr_set[i] && num_set_attribs[i] == 1) { remove_node_mask[i] = 1; } - int is_nitro = v_val[i] == nitro_attr_value; + int is_special_valence = v_val[i] == nitro_attr_value || v_val[i] == oxy_attr_value; for (int j = 0; j < 5; j++) { int one_hot_idx = gd->node_attr_offsets[j] + _node_attrs[j][i]; node_feat[i * gd->num_node_dim + one_hot_idx] = 1; @@ -239,24 +271,30 @@ PyObject *mol_graph_to_Data(PyObject *self, PyObject *args) { // Special case: if the node is a positively charged Nitrogen, and the valence is maxed out (i.e. 5) // the agent cannot remove the charge, it has to remove the bonds (or bond attrs) making this valence // maxed out first. - if (j == charge_idx && is_nitro && used_valences[i] >= max_valence[i] && charge_val[i] == 1) + if (j == charge_idx && is_special_valence && used_valences[i] >= max_valence[i] && charge_val[i] == 1) + //if (j == charge_idx && used_valences[i] >= max_valence[i] && charge_val[i] == 1) continue; remove_node_attr_mask[i * gd->num_settable_node_attrs + j] = 1; } else { // Special case for N: adding charge to N increases its valence by 2, so we don't want to prevent that // action, even if the max_valence is "full" (for v=3) - if (j == charge_idx && used_valences[i] >= (max_valence[i] + (is_nitro ? 2 : 0))) // charge - continue; + //if (j == charge_idx && is_nitro && used_valences[i] >= (max_valence[i] + (is_nitro ? 2 : 0))) // charge + // continue; if (j == explH_idx && used_valences[i] >= max_valence[i]) // expl_H continue; // Following on the special case, we made it here, now if the node is not yet charged but already // has a valence of 3, the only charge we can add is a +1, which is the 0th logit // (again this assumes charge ranges are [0,1,-1]) - if (j == charge_idx && is_nitro && used_valences[i] >= max_valence[i] && charge_val[i] == 0){ - // printf("Setting 1: %d\n", i * gd->num_node_attr_logits + logit_slice_start); + //if (j == charge_idx && is_nitro && used_valences[i] >= max_valence[i] && charge_val[i] == 0){ + //printf("valences: %f %d %d %d\n", used_valences[i], max_valence[i], charge_val[i], j); + if (j == charge_idx && is_special_valence && used_valences[i] >= max_valence[i]){ + //printf("Setting 1: %d\n", i * gd->num_node_attr_logits + logit_slice_start); set_node_attr_mask[i * gd->num_node_attr_logits + logit_slice_start] = 1; continue; } + // Next, if the node is not yet charged, we can only add a charge if the valence is not maxed out + if (j == charge_idx && used_valences[i] >= max_valence[i]) + continue; // printf("Setting: %d %d, %d %d\n", j, i * gd->num_node_attr_logits + logit_slice_start, logit_slice_end, logit_slice_start); memsetf(set_node_attr_mask + i * gd->num_node_attr_logits + logit_slice_start, 1, (logit_slice_end - logit_slice_start)); @@ -335,7 +373,7 @@ PyObject *mol_graph_to_Data(PyObject *self, PyObject *args) { raise(SIGINT); return NULL; } - *stop_mask = g->num_nodes > 0 ? 1 : 0; + *stop_mask = (g->num_nodes > 0 ? 1 : 0) && !is_any_hypervalent; if (RETURN_C_DATA) { // Data *data_obj = DataType.tp_new(&DataType, NULL, NULL); Data *data_obj = (Data *)PyObject_CallObject((PyObject *)&DataType, NULL); diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index 4dd9cbfe..b27ecb58 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -102,15 +102,19 @@ class TBConfig(StrictDataClass): do_parameterize_p_b: bool = False do_predict_n: bool = False do_sample_p_b: bool = False + do_sample_using_masks: bool = False do_length_normalize: bool = False subtb_max_len: int = 128 Z_learning_rate: float = 1e-4 Z_lr_decay: float = 50_000 + regularize_logZ: Optional[float] = None cum_subtb: bool = True loss_fn: LossFN = LossFN.MSE loss_fn_par: float = 1.0 n_loss: NLoss = NLoss.none n_loss_multiplier: float = 1.0 + tb_loss_multiplier: float = 1.0 + mle_loss_multiplier: float = 0.0 backward_policy: Backward = Backward.Uniform @@ -196,6 +200,7 @@ class AlgoConfig(StrictDataClass): train_det_after: Optional[int] = None valid_random_action_prob: float = 0.0 sampling_tau: float = 0.0 + grad_acc_steps: int = 1 tb: TBConfig = field(default_factory=TBConfig) moql: MOQLConfig = field(default_factory=MOQLConfig) a2c: A2CConfig = field(default_factory=A2CConfig) diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index b79fed0d..110deabd 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -230,6 +230,7 @@ def sample_backward_from_graphs( Probability of taking a random action (only used if model parameterizes P_B) """ + starting_graphs = list(graphs) dev = get_worker_device() n = len(graphs) done = [False] * n @@ -238,7 +239,7 @@ def sample_backward_from_graphs( "traj": [(graphs[i], GraphAction(GraphActionType.Stop))], "is_valid": True, "is_sink": [1], - "bck_a": [GraphAction(GraphActionType.Stop)], + "bck_a": [GraphAction(GraphActionType.Pad)], "bck_logprobs": [0.0], "result": graphs[i], } @@ -293,10 +294,10 @@ def not_done(lst): for i in range(n): # See comments in sample_from_model data[i]["traj"] = data[i]["traj"][::-1] - data[i]["bck_a"] = [GraphAction(GraphActionType.Stop)] + data[i]["bck_a"][::-1] + data[i]["bck_a"] = [GraphAction(GraphActionType.Pad)] + data[i]["bck_a"][::-1] data[i]["is_sink"] = data[i]["is_sink"][::-1] data[i]["bck_logprobs"] = torch.tensor(data[i]["bck_logprobs"][::-1], device=dev).reshape(-1) if self.pad_with_terminal_state: - data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Stop))) + data[i]["traj"].append((starting_graphs[i], GraphAction(GraphActionType.Pad))) data[i]["is_sink"].append(1) return data diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 06d021e0..e8b7fe64 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -219,6 +219,10 @@ def create_training_data_from_graphs( return self.graph_sampler.sample_backward_from_graphs( graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, random_action_prob ) + elif self.cfg.do_sample_using_masks: + return self.graph_sampler.sample_backward_from_graphs( + graphs, None, cond_info, random_action_prob + ) trajs: List[Dict[str, Any]] = [{"traj": generate_forward_trajectory(i)} for i in graphs] for traj in trajs: n_back = [ @@ -575,13 +579,20 @@ def compute_batch_losses( num_bootstrap = num_bootstrap or len(log_rewards) reward_losses = self._loss(log_rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap], self.reward_loss) - reward_loss = reward_losses.mean() * self.cfg.reward_loss_multiplier + reward_loss = reward_losses.mean() else: reward_loss = 0 + + log_Z_reg_loss = (log_Z - self.cfg.regularize_logZ).pow(2).mean() if self.cfg.regularize_logZ is not None else 0 n_loss = n_loss.mean() tb_loss = traj_losses.mean() - loss = tb_loss + reward_loss + self.cfg.n_loss_multiplier * n_loss + mle_loss = -traj_log_p_F.mean() + loss = (tb_loss * self.cfg.tb_loss_multiplier + + reward_loss * self.cfg.reward_loss_multiplier + + n_loss * self.cfg.n_loss_multiplier + + log_Z_reg_loss + + mle_loss * self.cfg.mle_loss_multiplier) info = { "reward_loss": reward_loss, "invalid_trajectories": invalid_mask.sum() / batch.num_online if batch.num_online > 0 else 0, @@ -592,7 +603,7 @@ def compute_batch_losses( "loss": loss.item(), "n_loss": n_loss, "tb_loss": tb_loss.item(), - "batch_entropy": -traj_log_p_F.mean(), + "batch_entropy": fwd_cat.entropy().mean(), "traj_lens": batch.traj_lens.float().mean(), 'avg_log_reward': clip_log_R.mean(), } @@ -611,6 +622,8 @@ def compute_batch_losses( d = d * d d[final_graph_idx] = 0 info["n_loss_maxent"] = scatter(d, batch_idx, dim=0, dim_size=num_trajs, reduce="sum").mean() + if self.cfg.mle_loss_multiplier != 0: + info["mle_loss"] = mle_loss.item() if not torch.isfinite(loss): import pdb; pdb.set_trace() diff --git a/src/gflownet/data/config.py b/src/gflownet/data/config.py index 1e1b7f98..d4a51bc4 100644 --- a/src/gflownet/data/config.py +++ b/src/gflownet/data/config.py @@ -34,3 +34,5 @@ class ReplayConfig(StrictDataClass): hindsight_ratio: float = 0 num_from_replay: Optional[int] = None num_new_samples: Optional[int] = None + keep_highest_rewards: bool = False + keep_only_uniques: bool = False diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 536e9ca1..94c54534 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -9,8 +9,9 @@ from gflownet import GFNAlgorithm, GFNTask from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer, detach_and_cpu -from gflownet.envs.graph_building_env import GraphBuildingEnvContext +from gflownet.envs.graph_building_env import GraphBuildingEnvContext, GraphActionCategorical, action_type_to_mask from gflownet.envs.seq_building_env import SeqBatch +from gflownet.models.graph_transformer import GraphTransformerGFN from gflownet.utils.misc import get_worker_rng from gflownet.utils.multiprocessing_proxy import BufferPickler, SharedPinnedBuffer @@ -64,24 +65,63 @@ def __iter__(self): self.rng = get_worker_rng() its = [i() for i in self.iterators] self.algo.set_is_eval(self.is_algo_eval) + print("New iterator", its) + err_tol = 10 while True: - with self.global_step_count_lock: - self.current_iter = self.global_step_count.item() - self.global_step_count += 1 - iterator_outputs = [next(i, None) for i in its] - if any(i is None for i in iterator_outputs): - if not all(i is None for i in iterator_outputs): - warnings.warn("Some iterators are done, but not all. You may be mixing incompatible iterators.") - iterator_outputs = [i for i in iterator_outputs if i is not None] - else: - break - traj_lists, batch_infos = zip(*iterator_outputs) - trajs = sum(traj_lists, []) - # Merge all the dicts into one - batch_info = {} - for d in batch_infos: - batch_info.update(d) - yield self.create_batch(trajs, batch_info) + try: + with self.global_step_count_lock: + self.current_iter = self.global_step_count.item() + self.global_step_count += 1 + iterator_outputs = [next(i, None) for i in its] + if any(i is None for i in iterator_outputs): + if not all(i is None for i in iterator_outputs): + warnings.warn("Some iterators are done, but not all. You may be mixing incompatible iterators.") + iterator_outputs = [i for i in iterator_outputs if i is not None] + else: + break + traj_lists, batch_infos = zip(*iterator_outputs) + trajs = sum(traj_lists, []) + # Merge all the dicts into one + batch_info = {} + for d in batch_infos: + batch_info.update(d) + yield self.create_batch(trajs, batch_info) + except Exception as e: + if 1: + raise e + err_tol -= 1 + if err_tol == 0: + raise e + print(f"Error in DataSource: {e} [tol={err_tol}]") + # print full traceback + import traceback, sys + traceback.print_exc() + traceback.print_exc(file=sys.stderr) + continue + err_tol = 10 # Reset the error tolerance, if we run into 10 consecutive errors, we'll break + + def validate_batch(self, batch, trajs): + for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( + [(batch.bck_actions, self.ctx.bck_action_type_order)] + if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order") + else [] + ): + mask_cat = GraphActionCategorical( + batch, + [action_type_to_mask(t, batch) for t in atypes], + [GraphTransformerGFN.action_type_to_key(t) for t in atypes], + [None for _ in atypes], + ) + masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits, pad_value=1.0) + num_trajs = len(trajs) + batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens) + first_graph_idx = torch.zeros_like(batch.traj_lens) + torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) + if masked_action_is_used.sum() != 0: + invalid_idx = masked_action_is_used.argmax().item() + traj_idx = batch_idx[invalid_idx].item() + timestep = invalid_idx - first_graph_idx[traj_idx].item() + raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep]) def do_sample_model(self, model, num_from_policy, num_new_replay_samples=None): if num_new_replay_samples is not None: @@ -244,6 +284,7 @@ def create_batch(self, trajs, batch_info): batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) batch.obj_props = torch.stack([t["obj_props"] for t in trajs]) + self.validate_batch(batch, trajs) return self._maybe_put_in_mp_buffer(batch) def compute_properties(self, trajs, mark_as_online=False): @@ -280,7 +321,9 @@ def compute_log_rewards(self, trajs): def send_to_replay(self, trajs): if self.replay_buffer is not None: for t in trajs: - self.replay_buffer.push(t, t["log_reward"], t["obj_props"], t["cond_info"], t["is_valid"]) + self.replay_buffer.push(t, t["log_reward"], t["obj_props"], t["cond_info"], t["is_valid"], + unique_obj=self.ctx.get_unique_obj(t["result"]), + priority=t.get("priority", t['log_reward'].item())) def set_traj_cond_info(self, trajs, cond_info): for i in range(len(trajs)): diff --git a/src/gflownet/data/replay_buffer.py b/src/gflownet/data/replay_buffer.py index 7fc95024..62b791d2 100644 --- a/src/gflownet/data/replay_buffer.py +++ b/src/gflownet/data/replay_buffer.py @@ -1,3 +1,5 @@ +import heapq +from threading import Lock from typing import List import numpy as np @@ -20,20 +22,54 @@ def __init__(self, cfg: Config): self.buffer: List[tuple] = [] self.position = 0 - def push(self, *args): + self.treat_as_heap = cfg.replay.keep_highest_rewards + self.filter_uniques = cfg.replay.keep_only_uniques + self._uniques = set() + + self._lock = Lock() + + def push(self, *args, unique_obj=None, priority=None): + """unique_obj must be hashable and comparable""" if len(self.buffer) == 0: self._input_size = len(args) else: assert self._input_size == len(args), "ReplayBuffer input size must be constant" - if len(self.buffer) < self.capacity: - self.buffer.append(None) + if self.filter_uniques and unique_obj in self._uniques: + return args = detach_and_cpu(args) - self.buffer[self.position] = args - self.position = (self.position + 1) % self.capacity + self._lock.acquire() + if self.treat_as_heap: + if len(self.buffer) >= self.capacity: + if priority is None or priority > self.buffer[0][0]: + # We will use self.position for tie-breaking + *_, pop_unique = heapq.heappushpop(self.buffer, (priority, self.position, args, unique_obj)) + self.position += 1 + if self.filter_uniques: + self._uniques.remove(pop_unique) + self._uniques.add(unique_obj) + else: + pass # If the priority is lower than the lowest in the heap, we don't add it + else: + heapq.heappush(self.buffer, (priority, self.position, args, unique_obj)) + self.position += 1 + if self.filter_uniques: + self._uniques.add(unique_obj) + else: + if len(self.buffer) < self.capacity: + self.buffer.append(None) + if self.filter_uniques: + if self.position == 0 and len(self.buffer) == self.capacity: + # We're about to wrap around, so remove the oldest element + self._uniques.remove(self.buffer[0][2]) + self.buffer[self.position] = (priority, args, unique_obj) + if self.filter_uniques: + self._uniques.add(unique_obj) + self.position = (self.position + 1) % self.capacity + self._lock.release() def sample(self, batch_size): idxs = get_worker_rng().choice(len(self.buffer), batch_size) - out = list(zip(*[self.buffer[idx] for idx in idxs])) + out = list(zip(*[self.buffer[idx][2] for idx in idxs])) for i in range(len(out)): # stack if all elements are numpy arrays or torch tensors # (this is much more efficient to send arrays through multiprocessing queues) diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 63681154..2817f056 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -67,6 +67,8 @@ class GraphActionType(enum.Enum): RemoveEdge = enum.auto() RemoveNodeAttr = enum.auto() RemoveEdgeAttr = enum.auto() + # The Pad action is always unmasked, and always has logprob 0 + Pad = -1 @cached_property def cname(self): @@ -823,7 +825,7 @@ def argmax( # if it wants to convert these indices to env-compatible actions return argmaxes - def log_prob(self, actions: List[ActionIndex], logprobs: torch.Tensor = None, batch: torch.Tensor = None): + def log_prob(self, actions: List[ActionIndex], logprobs: torch.Tensor = None, batch: torch.Tensor = None, pad_value: float = 0.0): """The log-probability of a list of action tuples, effectively indexes `logprobs` using internal slice indices. @@ -848,12 +850,14 @@ def log_prob(self, actions: List[ActionIndex], logprobs: torch.Tensor = None, ba logprobs = self.logsoftmax() if batch is None: batch = torch.arange(N, device=self.dev) + pad = torch.tensor(pad_value, device=self.dev) # We want to do the equivalent of this: # [logprobs[t][row + self.slice[t][i], col] for i, (t, row, col) in zip(batch, actions)] # but faster. # each action is a 3-tuple ActionIndex (type, row, column), where type is the index of the action type group. - actions = torch.as_tensor(actions, device=self.dev, dtype=torch.long) + unclamped_actions = torch.as_tensor(actions, device=self.dev, dtype=torch.long) + actions = unclamped_actions.clamp(0) # Clamp to 0 to avoid the -1 Pad action assert actions.shape[0] == batch.shape[0] # Check there are as many actions as batch indices # To index the log probabilities efficiently, we will ravel the array, and compute the # indices of the raveled actions. @@ -878,7 +882,10 @@ def log_prob(self, actions: List[ActionIndex], logprobs: torch.Tensor = None, ba # This is the last index in the raveled tensor, therefore the offset is just the column value col_offsets = actions[:, 2] # Index the flattened array - return all_logprobs[t_offsets + row_offsets + col_offsets] + raw_logprobs = all_logprobs[t_offsets + row_offsets + col_offsets] + # Now we replaced the Pad actions with 0s + logprobs = raw_logprobs.where(unclamped_actions[:, 0] != GraphActionType.Pad.value, pad) + return logprobs def entropy(self, logprobs=None): """The entropy for each graph categorical in the batch @@ -900,7 +907,7 @@ def entropy(self, logprobs=None): [ scatter(im, b, dim=0, dim_size=self.num_graphs, reduce="sum").sum(1) for i, b, m in zip(logprobs, self.batch, masks) - for im in [i.masked_fill(m == 0.0, 0) if m is not None else i] + for im in [i.exp() * i.masked_fill(m == 0.0, 0) if m is not None else i.exp() * i] ] ) return entropy @@ -1030,8 +1037,13 @@ def log_n(self, g) -> float: def traj_log_n(self, traj): return [self.log_n(g) for g, _ in traj] + def get_unique_obj(self, g: Graph): + return None + def action_type_to_mask(t: GraphActionType, gbatch: gd.Batch, assert_mask_exists: bool = False): + if t == GraphActionType.Pad: + return torch.ones((1, 1), device=gbatch.x.device) if assert_mask_exists: assert hasattr(gbatch, t.mask_name), f"Mask {t.mask_name} not found in graph data" return getattr(gbatch, t.mask_name) if hasattr(gbatch, t.mask_name) else torch.ones((1, 1), device=gbatch.x.device) diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index 9ef30fd0..8b54e2c3 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -177,6 +177,9 @@ def _make_C_graph(self): def ActionIndex_to_GraphAction(self, g: gd.Data, aidx: ActionIndex, fwd: bool = True): """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction""" + if aidx.action_type == -1: + # Pad action + return GraphAction(GraphActionType.Pad) if fwd: t = self.action_type_order[aidx.action_type] else: @@ -217,6 +220,8 @@ def ActionIndex_to_GraphAction(self, g: gd.Data, aidx: ActionIndex, fwd: bool = def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionIndex: """Translate a GraphAction to an index tuple""" + if action.action is GraphActionType.Pad: + return ActionIndex(action_type=-1, row_idx=0, col_idx=0) for u in [self.action_type_order, self.bck_action_type_order]: if action.action in u: type_idx = u.index(action.action) @@ -500,3 +505,12 @@ def object_to_log_repr(self, g: Graph): return Chem.MolToSmiles(mol) except Exception: return "" + + def get_unique_obj(self, g: Graph): + """Convert a Graph to a canonical SMILES representation""" + try: + mol = self.graph_to_obj(g) + assert mol is not None + return Chem.CanonSmiles(Chem.MolToSmiles(mol)) + except Exception: + return "" \ No newline at end of file diff --git a/src/gflownet/models/config.py b/src/gflownet/models/config.py index 912ff9f8..c9f52c01 100644 --- a/src/gflownet/models/config.py +++ b/src/gflownet/models/config.py @@ -8,8 +8,9 @@ class GraphTransformerConfig(StrictDataClass): num_heads: int = 2 ln_type: str = "pre" - num_mlp_layers: int = 0 + num_mlp_layers: int = 1 concat_heads: bool = True + conv_type: str = 'Transformer' class SeqPosEnc(int, Enum): diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 6807dc5b..9f518504 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -20,6 +20,60 @@ def mlp(n_in, n_hid, n_out, n_layer, act=nn.LeakyReLU): return nn.Sequential(*sum([[nn.Linear(n[i], n[i + 1]), act()] for i in range(n_layer + 1)], [])[:-1]) +class GTGENLayer(nn.Module): + def __init__(self, num_emb, num_heads, concat, num_mlp_layers, ln_type): + super().__init__() + assert ln_type in ["pre", "post"] + self.ln_type = ln_type + n_att = num_emb * num_heads if concat else num_emb + self.gen = gnn.GENConv(num_emb, num_emb, num_layers=1, aggr="add", norm=None) + self.conv = gnn.TransformerConv(num_emb * 2, n_att // num_heads, edge_dim=num_emb, heads=num_heads) + self.linear = nn.Linear(n_att, num_emb) + self.norm1 = gnn.LayerNorm(num_emb, affine=False) + self.ff = mlp(num_emb, num_emb * 4, num_emb, num_mlp_layers) + self.norm2 = gnn.LayerNorm(num_emb, affine=False) + self.cscale = nn.Linear(num_emb, num_emb * 2) + + def forward(self, x, edge_index, edge_attr, batch, c): + cs = self.cscale(c) + if self.ln_type == "post": + agg = self.gen(x, edge_index, edge_attr) + l_h = self.linear(self.conv(torch.cat([x, agg], 1), edge_index, edge_attr)) + scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] + x = self.norm1(x + l_h * scale + shift, batch) + x = self.norm2(x + self.ff(x), batch) + else: + x_norm = self.norm1(x, batch) + agg = self.gen(x_norm, edge_index, edge_attr) + l_h = self.linear(self.conv(torch.cat([x, agg], 1), edge_index, edge_attr)) + scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] + x = x + l_h * scale + shift + x = x + self.ff(self.norm2(x, batch)) + return x + + +class GPSLayer(nn.Module): + def __init__(self, num_emb, num_heads, num_mlp_layers, residual=False): + super().__init__() + self.conv = gnn.GPSConv(num_emb, + gnn.GINEConv(mlp(num_emb, num_emb, num_emb, num_mlp_layers), edge_dim=num_emb), + num_heads, + norm='layer_norm') + self.cscale = nn.Linear(num_emb, num_emb * 2) + self.residual = residual + + def forward(self, x, edge_index, edge_attr, batch, c): + cs = self.cscale(c) + l_h = self.conv(x, edge_index, batch, edge_attr=edge_attr) + scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] + if self.residual: + x = x + l_h * scale + shift + else: + x = l_h * scale + shift + return x + + + class GraphTransformer(nn.Module): """An agnostic GraphTransformer class, and the main model used by other model classes @@ -34,7 +88,8 @@ class GraphTransformer(nn.Module): """ def __init__( - self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, num_noise=0, ln_type="pre", concat=True + self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, num_noise=0, ln_type="pre", concat=True, + num_mlp_layers=1, conv_type='Transformer', ): """ Parameters @@ -65,30 +120,20 @@ def __init__( super().__init__() self.num_layers = num_layers self.num_noise = num_noise - assert ln_type in ["pre", "post"] self.ln_type = ln_type + self.conv_type = conv_type self.x2h = mlp(x_dim + num_noise, num_emb, num_emb, 2) self.e2h = mlp(e_dim, num_emb, num_emb, 2) self.c2h = mlp(max(1, g_dim), num_emb, num_emb, 2) - n_att = num_emb * num_heads if concat else num_emb - self.graph2emb = nn.ModuleList( - sum( - [ - [ - gnn.GENConv(num_emb, num_emb, num_layers=1, aggr="add", norm=None), - gnn.TransformerConv(num_emb * 2, n_att // num_heads, edge_dim=num_emb, heads=num_heads), - nn.Linear(n_att, num_emb), - gnn.LayerNorm(num_emb, affine=False), - mlp(num_emb, num_emb * 4, num_emb, 1), - gnn.LayerNorm(num_emb, affine=False), - nn.Linear(num_emb, num_emb * 2), - ] - for i in range(self.num_layers) - ], - [], - ) - ) + if conv_type == 'Transformer': + self.gnn = nn.ModuleList([ + GTGENLayer(num_emb, num_heads, concat, num_mlp_layers, ln_type) for _ in range(num_layers) + ]) + elif conv_type == 'GPS': + self.gnn = nn.ModuleList([ + GPSLayer(num_emb, num_heads, num_mlp_layers) for _ in range(num_layers) + ]) def forward(self, g: gd.Batch): """Forward pass @@ -114,39 +159,35 @@ def forward(self, g: gd.Batch): e = self.e2h(g.edge_attr) c = self.c2h(g.cond_info if g.cond_info is not None else torch.ones((g.num_graphs, 1), device=g.x.device)) num_total_nodes = g.x.shape[0] - # Augment the edges with a new edge to the conditioning - # information node. This new node is connected to every node - # within its graph. - u, v = torch.arange(num_total_nodes, device=o.device), g.batch + num_total_nodes - aug_edge_index = torch.cat([g.edge_index, torch.stack([u, v]), torch.stack([v, u])], 1) - e_p = torch.zeros((num_total_nodes * 2, e.shape[1]), device=g.x.device) - e_p[:, 0] = 1 # Manually create a bias term - aug_e = torch.cat([e, e_p], 0) - aug_edge_index, aug_e = add_self_loops(aug_edge_index, aug_e, "mean") - aug_batch = torch.cat([g.batch, torch.arange(c.shape[0], device=o.device)], 0) - - # Append the conditioning information node embedding to o - o = torch.cat([o, c], 0) + if self.conv_type == 'Transformer': + # Augment the edges with a new edge to the conditioning + # information node. This new node is connected to every node + # within its graph. + u, v = torch.arange(num_total_nodes, device=o.device), g.batch + num_total_nodes + aug_edge_index = torch.cat([g.edge_index, torch.stack([u, v]), torch.stack([v, u])], 1) + e_p = torch.zeros((num_total_nodes * 2, e.shape[1]), device=g.x.device) + e_p[:, 0] = 1 # Manually create a bias term + aug_e = torch.cat([e, e_p], 0) + aug_edge_index, aug_e = add_self_loops(aug_edge_index, aug_e, "mean") + aug_batch = torch.cat([g.batch, torch.arange(c.shape[0], device=o.device)], 0) + # Append the conditioning information node embedding to o + o = torch.cat([o, c], 0) + else: + # GPS doesn't really need a virtual node since it's doing global pairwise attention + o = o + c[g.batch] + aug_batch = g.batch + aug_edge_index, aug_e = g.edge_index, e + for i in range(self.num_layers): - # Run the graph transformer forward - gen, trans, linear, norm1, ff, norm2, cscale = self.graph2emb[i * 7 : (i + 1) * 7] - cs = cscale(c[aug_batch]) - if self.ln_type == "post": - agg = gen(o, aug_edge_index, aug_e) - l_h = linear(trans(torch.cat([o, agg], 1), aug_edge_index, aug_e)) - scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] - o = norm1(o + l_h * scale + shift, aug_batch) - o = norm2(o + ff(o), aug_batch) - else: - o_norm = norm1(o, aug_batch) - agg = gen(o_norm, aug_edge_index, aug_e) - l_h = linear(trans(torch.cat([o_norm, agg], 1), aug_edge_index, aug_e)) - scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] - o = o + l_h * scale + shift - o = o + ff(norm2(o, aug_batch)) - - o_final = o[: -c.shape[0]] - glob = torch.cat([gnn.global_mean_pool(o_final, g.batch), o[-c.shape[0] :]], 1) + o = self.gnn[i](o, aug_edge_index, aug_e, aug_batch, c[aug_batch]) + + if self.conv_type == 'Transformer': + # Remove the conditioning information node embedding + o_final = o[: -c.shape[0]] + glob = torch.cat([gnn.global_mean_pool(o_final, g.batch), o[-c.shape[0] :]], 1) + else: + o_final = o + glob = gnn.global_mean_pool(o_final, g.batch) return o_final, glob @@ -199,11 +240,13 @@ def __init__( num_heads=cfg.model.graph_transformer.num_heads, ln_type=cfg.model.graph_transformer.ln_type, concat=cfg.model.graph_transformer.concat_heads, + num_mlp_layers=cfg.model.graph_transformer.num_mlp_layers, + conv_type=cfg.model.graph_transformer.conv_type, ) self.env_ctx = env_ctx num_emb = cfg.model.num_emb num_final = num_emb - num_glob_final = num_emb * 2 + num_glob_final = num_emb * 2 if cfg.model.graph_transformer.conv_type == 'Transformer' else num_emb num_edge_feat = num_emb if env_ctx.edges_are_unordered else num_emb * 2 self.edges_are_duplicated = env_ctx.edges_are_duplicated self.edges_are_unordered = env_ctx.edges_are_unordered diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 13dd0e48..17091d84 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -65,6 +65,14 @@ def _opt(self, params, lr=None, momentum=None): weight_decay=self.cfg.opt.weight_decay, eps=self.cfg.opt.adam_eps, ) + elif self.cfg.opt.opt == "adamW": + return torch.optim.AdamW( + params, + lr, + (momentum, 0.999), + weight_decay=self.cfg.opt.weight_decay, + eps=self.cfg.opt.adam_eps, + ) raise NotImplementedError(f"{self.cfg.opt.opt} is not implemented") @@ -121,12 +129,18 @@ def setup(self): with open(pathlib.Path(self.cfg.log_dir) / "config.yaml", "w", encoding="utf8") as f: f.write(yaml_cfg) - def step(self, loss: Tensor): + def step(self, loss: Tensor, train_it: int): loss.backward() - with torch.no_grad(): - g0 = model_grad_norm(self.model) - self.clip_grad_callback(self.model.parameters()) - g1 = model_grad_norm(self.model) + info = {} + if train_it % self.cfg.algo.grad_acc_steps != 0: + return info + if self.cfg.opt.clip_grad_type is not None: + with torch.no_grad(): + g0 = model_grad_norm(self.model) + self.clip_grad_callback(self.model.parameters()) + g1 = model_grad_norm(self.model) + info["grad_norm"] = g0.item() + info["grad_norm_clip"] = g1.item() self.opt.step() self.opt.zero_grad() self.lr_sched.step() @@ -137,7 +151,7 @@ def step(self, loss: Tensor): if self.sampling_tau > 0: for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) - return {"grad_norm": g0, "grad_norm_clip": g1} + return info class AvgRewardHook: diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 73cb07a2..12c413c6 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -106,18 +106,19 @@ def setup_algo(self): def setup_data(self): pass - def step(self, loss: Tensor): + def step(self, loss: Tensor, train_it: int): raise NotImplementedError() def setup(self): - if os.path.exists(self.cfg.log_dir): - if self.cfg.overwrite_existing_exp: - shutil.rmtree(self.cfg.log_dir) - else: - raise ValueError( - f"Log dir {self.cfg.log_dir} already exists. Set overwrite_existing_exp=True to delete it." - ) - os.makedirs(self.cfg.log_dir) + if self.rank == 0: + if os.path.exists(self.cfg.log_dir): + if self.cfg.overwrite_existing_exp: + shutil.rmtree(self.cfg.log_dir) + else: + raise ValueError( + f"Log dir {self.cfg.log_dir} already exists. Set overwrite_existing_exp=True to delete it." + ) + os.makedirs(self.cfg.log_dir) RDLogger.DisableLog("rdApp.*") set_worker_rng_seed(self.cfg.seed) @@ -228,7 +229,7 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: loss, info = self.algo.compute_batch_losses(self.model, batch) if not torch.isfinite(loss): raise ValueError("loss is not finite") - step_info = self.step(loss) + step_info = self.step(loss, train_it) self.algo.step() # This also isn't used anywhere? if self._validate_parameters and not all([torch.isfinite(i).all() for i in self.model.parameters()]): raise ValueError("parameters are not finite") @@ -392,6 +393,8 @@ def terminate(self): terminate() def _save_state(self, it): + if self.rank != 0: + return state = { "models_state_dict": [self.model.state_dict()], "cfg": self.cfg, diff --git a/src/gflownet/utils/misc.py b/src/gflownet/utils/misc.py index b6543fe3..37b7d022 100644 --- a/src/gflownet/utils/misc.py +++ b/src/gflownet/utils/misc.py @@ -3,6 +3,8 @@ import numpy as np import torch +import torch.distributed +import torch.utils.data def create_logger(name="gflownet", loglevel=logging.INFO, logfile=None, streamHandle=True): @@ -32,12 +34,18 @@ def create_logger(name="gflownet", loglevel=logging.INFO, logfile=None, streamHa _worker_rng_seed = [142857] _main_process_device = [torch.device("cpu")] - -def get_worker_rng(): +def get_this_wid(): worker_info = torch.utils.data.get_worker_info() wid = worker_info.id if worker_info is not None else 0 + if torch.distributed.is_initialized(): + wid = torch.distributed.get_rank() * (worker_info.num_workers if worker_info is not None else 1) + wid + return wid + +def get_worker_rng(): + wid = get_this_wid() if wid not in _worker_rngs: _worker_rngs[wid] = np.random.RandomState(_worker_rng_seed[0] + wid) + torch.manual_seed(_worker_rng_seed[0] + wid) return _worker_rngs[wid] diff --git a/src/gflownet/utils/sqlite_log.py b/src/gflownet/utils/sqlite_log.py index 2fc6e113..64ebdf5e 100644 --- a/src/gflownet/utils/sqlite_log.py +++ b/src/gflownet/utils/sqlite_log.py @@ -6,6 +6,8 @@ import torch.distributed import torch.utils.data +from gflownet.utils.misc import get_this_wid + class SQLiteLogHook: def __init__(self, log_dir, ctx) -> None: self.log = None # Only initialized in __call__, which will occur inside the worker @@ -15,10 +17,7 @@ def __init__(self, log_dir, ctx) -> None: def __call__(self, trajs, rewards, obj_props, cond_info): if self.log is None: - worker_info = torch.utils.data.get_worker_info() - self._wid = worker_info.id if worker_info is not None else 0 - if torch.distributed.is_initialized(): - self._wid = torch.distributed.get_rank() * (worker_info.num_workers if worker_info is not None else 1) + self._wid + self._wid = get_this_wid() os.makedirs(self.log_dir, exist_ok=True) self.log_path = f"{self.log_dir}/generated_objs_{self._wid}.db" self.log = SQLiteLog() From a1534be8f62683f6f5fa418587ee358656e83409 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 29 Aug 2024 08:44:25 -0600 Subject: [PATCH 24/31] first (bad) attempt --- src/gflownet/algo/graph_sampling.py | 319 +++++++++++++++++----------- 1 file changed, 191 insertions(+), 128 deletions(-) diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 110deabd..ec5cdc4a 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -87,124 +87,65 @@ def sample_from_model(self, model: nn.Module, n: int, cond_info: Optional[Tensor data: List[Dict] A list of trajectories. Each trajectory is a dict with keys - trajs: List[Tuple[Graph, GraphAction]], the list of states and actions + - bck_a: List[GraphAction], the reverse actions + - is_sink: List[int], sink states have P_B = 1 - fwd_logprob: sum logprobs P_F - bck_logprob: sum logprobs P_B - is_valid: is the generated graph valid according to the env & ctx + - interm_rewards: List[float], intermediate rewards """ dev = get_worker_device() + rng = get_worker_rng() # This will be returned - data = [{"traj": [], "reward_pred": None, "is_valid": True, "is_sink": []} for i in range(n)] - # Let's also keep track of trajectory statistics according to the model - fwd_logprob: List[List[Tensor]] = [[] for _ in range(n)] - bck_logprob: List[List[Tensor]] = [[] for _ in range(n)] - + data = [ + { + "traj": [], + "bck_a": [], # The reverse actions + "is_valid": True, + "is_sink": [], + "fwd_logprobs": [], + "bck_logprobs": [], + "interm_rewards": [], + } + for _ in range(n) + ] graphs = [self.env.new() for _ in range(n)] done = [False for _ in range(n)] - # TODO: instead of padding with Stop, we could have a virtual action whose probability - # always evaluates to 1. Presently, Stop should convert to a (0,0,0) ActionIndex, which should - # always be at least a valid index, and will be masked out anyways -- but this isn't ideal. - # Here we have to pad the backward actions with something, since the backward actions are - # evaluated at s_{t+1} not s_t. - bck_a = [[GraphAction(GraphActionType.Stop)] for _ in range(n)] - - rng = get_worker_rng() - - def not_done(lst): - return [e for i, e in enumerate(lst) if not done[i]] for t in range(self.max_len): - # Construct graphs for the trajectories that aren't yet done - torch_graphs = [self.ctx.graph_to_Data(i) for i in not_done(graphs)] - not_done_mask = torch.tensor(done, device=dev).logical_not() - # Forward pass to get GraphActionCategorical - # Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does. - # TODO: compute bck_cat.log_prob(bck_a) when relevant - batch = self.ctx.collate(torch_graphs) - batch.cond_info = cond_info[not_done_mask] if cond_info is not None else None - fwd_cat, *_, log_reward_preds = model(batch.to(dev)) - if random_action_prob > 0: - # Device which graphs in the minibatch will get their action randomized - is_random_action = torch.tensor( - rng.uniform(size=len(torch_graphs)) < random_action_prob, device=dev - ).float() - # Set the logits to some large value to have a uniform distribution - fwd_cat.logits = [ - is_random_action[b][:, None] * torch.ones_like(i) * 100 + i * (1 - is_random_action[b][:, None]) - for i, b in zip(fwd_cat.logits, fwd_cat.batch) - ] - if self.sample_temp != 1: - sample_cat = copy.copy(fwd_cat) - sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits] - actions = sample_cat.sample() - else: - actions = fwd_cat.sample() - graph_actions = [self.ctx.ActionIndex_to_GraphAction(g, a) for g, a in zip(torch_graphs, actions)] - log_probs = fwd_cat.log_prob(actions) - # Step each trajectory, and accumulate statistics - for i, j in zip(not_done(range(n)), range(n)): - fwd_logprob[i].append(log_probs[j].unsqueeze(0)) - data[i]["traj"].append((graphs[i], graph_actions[j])) - bck_a[i].append(self.env.reverse(graphs[i], graph_actions[j])) - # Check if we're done - if graph_actions[j].action is GraphActionType.Stop: - done[i] = True - bck_logprob[i].append(torch.tensor([1.0], device=dev).log()) - data[i]["is_sink"].append(1) - else: # If not done, try to step the self.environment - gp = graphs[i] - try: - # self.env.step can raise AssertionError if the action is illegal - gp = self.env.step(graphs[i], graph_actions[j]) - assert len(gp.nodes) <= self.max_nodes - except AssertionError: - done[i] = True - data[i]["is_valid"] = False - bck_logprob[i].append(torch.tensor([1.0], device=dev).log()) - data[i]["is_sink"].append(1) - continue - if t == self.max_len - 1: - done[i] = True - # If no error, add to the trajectory - # P_B = uniform backward - n_back = self.env.count_backward_transitions(gp, check_idempotent=self.correct_idempotent) - bck_logprob[i].append(torch.tensor([1 / n_back], device=dev).log()) - data[i]["is_sink"].append(0) - graphs[i] = gp - if done[i] and self.sanitize_samples and not self.ctx.is_sane(graphs[i]): - # check if the graph is sane (e.g. RDKit can construct a molecule from it) otherwise - # treat the done action as illegal - data[i]["is_valid"] = False + self._forward_step(model, data, graphs, cond_info, t, done, rng, dev, random_action_prob) if all(done): break # is_sink indicates to a GFN algorithm that P_B(s) must be 1 - + # # There are 3 types of possible trajectories # A - ends with a stop action. traj = [..., (g, a), (gp, Stop)], P_B = [..., bck(gp), 1] # B - ends with an invalid action. = [..., (g, a)], = [..., 1] # C - ends at max_len. = [..., (g, a)], = [..., bck(gp)] - + # # Let's say we pad terminal states, then: # A - ends with a stop action. traj = [..., (g, a), (gp, Stop), (gp, None)], P_B = [..., bck(gp), 1, 1] # B - ends with an invalid action. = [..., (g, a), (g, None)], = [..., 1, 1] # C - ends at max_len. = [..., (g, a), (gp, None)], = [..., bck(gp), 1] # and then P_F(terminal) "must" be 1 + + self._wrap_up_fwd_trajs(data, graphs) + return data - for i in range(n): + def _wrap_up_fwd_trajs(self, data, graphs): + for i in range(len(data)): # If we're not bootstrapping, we could query the reward # model here, but this is expensive/impractical. Instead # just report forward and backward logprobs - data[i]["fwd_logprob"] = sum(fwd_logprob[i]) - data[i]["bck_logprob"] = sum(bck_logprob[i]) - data[i]["bck_logprobs"] = torch.stack(bck_logprob[i]).reshape(-1) + data[i]["fwd_logprobs"] = torch.stack(data[i]["fwd_logprobs"]).reshape(-1) + data[i]["bck_logprobs"] = torch.stack(data[i]["bck_logprobs"]).reshape(-1) + data[i]["fwd_logprob"] = data[i]["fwd_logprobs"].sum() + data[i]["bck_logprob"] = data[i]["bck_logprobs"].sum() data[i]["result"] = graphs[i] - data[i]["bck_a"] = bck_a[i] if self.pad_with_terminal_state: - # TODO: instead of padding with Stop, we could have a virtual action whose - # probability always evaluates to 1. - data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Stop))) + data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Pad))) data[i]["is_sink"].append(1) - return data def sample_backward_from_graphs( self, @@ -246,51 +187,12 @@ def sample_backward_from_graphs( for i in range(n) ] - def not_done(lst): - return [e for i, e in enumerate(lst) if not done[i]] - # TODO: This should be doable. if random_action_prob > 0: warnings.warn("Random action not implemented for backward sampling") while sum(done) < n: - torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(n))] - not_done_mask = torch.tensor(done, device=dev).logical_not() - if model is not None: - gbatch = self.ctx.collate(torch_graphs).to(dev) - gbatch.cond_info = cond_info[not_done_mask] if cond_info is not None else None - _, bck_cat, *_ = model(gbatch) - else: - gbatch = self.ctx.collate(torch_graphs) - action_types = self.ctx.bck_action_type_order - action_masks = [action_type_to_mask(t, gbatch, assert_mask_exists=True) for t in action_types] - bck_cat = GraphActionCategorical( - gbatch, - raw_logits=[torch.ones_like(m) for m in action_masks], - keys=[GraphTransformerGFN.action_type_to_key(t) for t in action_types], - action_masks=action_masks, - types=action_types, - ) - bck_actions = bck_cat.sample() - graph_bck_actions = [ - self.ctx.ActionIndex_to_GraphAction(g, a, fwd=False) for g, a in zip(torch_graphs, bck_actions) - ] - bck_logprobs = bck_cat.log_prob(bck_actions) - - for i, j in zip(not_done(range(n)), range(n)): - if not done[i]: - g = graphs[i] - b_a = graph_bck_actions[j] - gp = self.env.step(g, b_a) - f_a = self.env.reverse(g, b_a) - graphs[i], f_a = relabel(gp, f_a) - data[i]["traj"].append((graphs[i], f_a)) - data[i]["bck_a"].append(b_a) - data[i]["is_sink"].append(0) - data[i]["bck_logprobs"].append(bck_logprobs[j].item()) - if len(graphs[i]) == 0: - done[i] = True - + self._backward_step(model, data, graphs, cond_info, done, dev) for i in range(n): # See comments in sample_from_model data[i]["traj"] = data[i]["traj"][::-1] @@ -301,3 +203,164 @@ def not_done(lst): data[i]["traj"].append((starting_graphs[i], GraphAction(GraphActionType.Pad))) data[i]["is_sink"].append(1) return data + + def local_search_sample_from_model( + self, + model: nn.Module, + n: int, + cond_info: Optional[Tensor], + random_action_prob: float = 0.0, + num_ls_steps: int = 1, + num_bck_steps: int = 1, + ): + dev = get_worker_device() + rng = get_worker_rng() + # First get n trajectories + current_trajs = self.sample_from_model(model, n, cond_info, random_action_prob) + # Then we're going to perform num_ls_steps of local search, each with num_bck_steps backward steps. + # Each local search step is a kind of Metropolis-Hastings step, where we propose a new trajectory, which may + # be accepted or rejected based on the forward and backward probabilities and reward. + # Finally, we return all the trajectories that were sampled. + returned_trajs = current_trajs + + for mcmc_steps in range(num_ls_steps): + # Create new trajectories from the initial ones + bck_trajs = self.sample_backward_from_graphs( + [i["result"] for i in current_trajs], model, cond_info, random_action_prob + ) + # Now we truncate the trajectories, we want to remove at most the last num_bck_steps steps, and also + # remove the last step(s) if it is a stop (and/or pad) action. + stop = GraphActionType.Stop + num_pad = [ + (1 if i["traj"][-1][0].action == stop else 0) + int(self.pad_with_terminal_state) for i in current_trajs + ] + trunc_lens = [max(0, len(i["traj"]) - num_bck_steps - pad) for i, pad in zip(current_trajs, num_pad)] + new_trajs = [ + { + key: list(t[key][:k]) + for key in ["traj", "bck_a", "is_sink", "fwd_logprobs", "bck_logprobs", "interm_rewards"] + } + for t, k in zip(bck_trajs, trunc_lens) + ] + # Next we sample new endings for the truncated trajectories + graphs = [i["traj"][-1][0] for i in new_trajs] + done = [False] * n + + while not all(done): + self._forward_step(model, new_trajs, graphs, cond_info, 0, done, rng, dev, random_action_prob) + done = [d or len(t["traj"]) >= self.max_len for d, t in zip(done, new_trajs)] + self._wrap_up_fwd_trajs(new_trajs, graphs) + # We add those new trajectories to the list of returned trajectories + returned_trajs += new_trajs + # Finally, we replace the current trajectories with the new ones if they are accepted by MH + for i in range(n): + if new_trajs[i]["fwd_logprob"] + new_trajs[i]["bck_logprob"] > current_trajs[i]["fwd_logprob"] + current_trajs[i]["bck_logprob"]: + current_trajs[i] = new_trajs[i] + def _forward_step(self, model, data, graphs, cond_info, t, done, rng, dev, random_action_prob) -> None: + def not_done(lst): + return [e for i, e in enumerate(lst) if not done[i]] + + n = len(data) + # Construct graphs for the trajectories that aren't yet done + torch_graphs = [self.ctx.graph_to_Data(i) for i in not_done(graphs)] + not_done_mask = torch.tensor(done, device=dev).logical_not() + # Forward pass to get GraphActionCategorical + # Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does. + # TODO: compute bck_cat.log_prob(bck_a) when relevant + batch = self.ctx.collate(torch_graphs) + batch.cond_info = cond_info[not_done_mask] if cond_info is not None else None + fwd_cat, *_ = model(batch.to(dev)) + if random_action_prob > 0: + # Device which graphs in the minibatch will get their action randomized + is_random_action = torch.tensor( + rng.uniform(size=len(torch_graphs)) < random_action_prob, device=dev + ).float() + # Set the logits to some large value to have a uniform distribution + fwd_cat.logits = [ + is_random_action[b][:, None] * torch.ones_like(i) * 100 + i * (1 - is_random_action[b][:, None]) + for i, b in zip(fwd_cat.logits, fwd_cat.batch) + ] + if self.sample_temp != 1: + sample_cat = copy.copy(fwd_cat) + sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits] + actions = sample_cat.sample() + else: + actions = fwd_cat.sample() + graph_actions = [self.ctx.ActionIndex_to_GraphAction(g, a) for g, a in zip(torch_graphs, actions)] + log_probs = fwd_cat.log_prob(actions) + # Step each trajectory, and accumulate statistics + for i, j in zip(not_done(range(n)), range(n)): + data[i]["fwd_logprob"].append(log_probs[j].unsqueeze(0)) + data[i]["traj"].append((graphs[i], graph_actions[j])) + data[i]["bck_a"].append(self.env.reverse(graphs[i], graph_actions[j])) + # Check if we're done + if graph_actions[j].action is GraphActionType.Stop: + done[i] = True + data[i]["bck_logprob"].append(torch.tensor([1.0], device=dev).log()) + data[i]["is_sink"].append(1) + else: # If not done, try to step the self.environment + gp = graphs[i] + try: + # self.env.step can raise AssertionError if the action is illegal + gp = self.env.step(graphs[i], graph_actions[j]) + assert len(gp.nodes) <= self.max_nodes + except AssertionError: + done[i] = True + data[i]["is_valid"] = False + data[i]["bck_logprob"].append(torch.tensor([1.0], device=dev).log()) + data[i]["is_sink"].append(1) + continue + if t == self.max_len - 1: + done[i] = True + # If no error, add to the trajectory + # P_B = uniform backward + n_back = self.env.count_backward_transitions(gp, check_idempotent=self.correct_idempotent) + data[i]["bck_logprob"].append(torch.tensor([1 / n_back], device=dev).log()) + data[i]["is_sink"].append(0) + graphs[i] = gp + if done[i] and self.sanitize_samples and not self.ctx.is_sane(graphs[i]): + # check if the graph is sane (e.g. RDKit can construct a molecule from it) otherwise + # treat the done action as illegal + data[i]["is_valid"] = False + # Nothing is returned, data is modified in place + + def _backward_step(self, model, data, graphs, cond_info, done, dev) -> None: + def not_done(lst): + return [e for i, e in enumerate(lst) if not done[i]] + + torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(len(graphs)))] + not_done_mask = torch.tensor(done, device=dev).logical_not() + if model is not None: + gbatch = self.ctx.collate(torch_graphs).to(dev) + gbatch.cond_info = cond_info[not_done_mask] if cond_info is not None else None + _, bck_cat, *_ = model(gbatch) + else: + gbatch = self.ctx.collate(torch_graphs) + action_types = self.ctx.bck_action_type_order + action_masks = [action_type_to_mask(t, gbatch, assert_mask_exists=True) for t in action_types] + bck_cat = GraphActionCategorical( + gbatch, + raw_logits=[torch.ones_like(m) for m in action_masks], + keys=[GraphTransformerGFN.action_type_to_key(t) for t in action_types], + action_masks=action_masks, + types=action_types, + ) + bck_actions = bck_cat.sample() + graph_bck_actions = [ + self.ctx.ActionIndex_to_GraphAction(g, a, fwd=False) for g, a in zip(torch_graphs, bck_actions) + ] + bck_logprobs = bck_cat.log_prob(bck_actions) + + for i, j in zip(not_done(range(len(graphs))), range(len(graphs))): + if not done[i]: + g = graphs[i] + b_a = graph_bck_actions[j] + gp = self.env.step(g, b_a) + f_a = self.env.reverse(g, b_a) + graphs[i], f_a = relabel(gp, f_a) + data[i]["traj"].append((graphs[i], f_a)) + data[i]["bck_a"].append(b_a) + data[i]["is_sink"].append(0) + data[i]["bck_logprobs"].append(bck_logprobs[j].item()) + if len(graphs[i]) == 0: + done[i] = True From 4491f6b820628f3f1eeffce1568d5200dae186e2 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 30 Aug 2024 07:23:10 -0600 Subject: [PATCH 25/31] working local serach, cond_info dict pass, allow no log dir, fix Pad and U PB --- src/gflownet/__init__.py | 3 + src/gflownet/algo/config.py | 7 + src/gflownet/algo/graph_sampling.py | 243 +++++++++++++++++------- src/gflownet/algo/trajectory_balance.py | 31 ++- src/gflownet/config.py | 2 +- src/gflownet/data/data_source.py | 21 +- src/gflownet/envs/frag_mol_env.py | 5 + src/gflownet/envs/graph_building_env.py | 2 +- src/gflownet/online_trainer.py | 12 +- src/gflownet/trainer.py | 6 +- 10 files changed, 233 insertions(+), 99 deletions(-) diff --git a/src/gflownet/__init__.py b/src/gflownet/__init__.py index 5415ecd8..6c790b4b 100644 --- a/src/gflownet/__init__.py +++ b/src/gflownet/__init__.py @@ -19,6 +19,7 @@ class GFNAlgorithm: updates: int = 0 global_cfg: Config is_eval: bool = False + requires_task: bool = False def step(self): self.updates += 1 # This isn't used anywhere? @@ -77,6 +78,8 @@ def get_random_action_prob(self, it: int): return self.global_cfg.algo.train_random_action_prob return 0 + def set_task(self, task): + raise NotImplementedError() class GFNTask: def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], obj_props: ObjectProperties) -> LogScalar: diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index b27ecb58..cda6f53b 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -148,6 +148,12 @@ class SQLConfig(StrictDataClass): gamma: float = 1 penalty: float = -10 +@dataclass +class LSTBConfig(StrictDataClass): + num_ls_steps: int = 1 + num_bck_steps: int = 1 + accept_criteria: str = "deterministic" + @dataclass class AlgoConfig(StrictDataClass): @@ -206,3 +212,4 @@ class AlgoConfig(StrictDataClass): a2c: A2CConfig = field(default_factory=A2CConfig) fm: FMConfig = field(default_factory=FMConfig) sql: SQLConfig = field(default_factory=SQLConfig) + ls: LSTBConfig = field(default_factory=LSTBConfig) diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index ec5cdc4a..7e71b120 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -1,6 +1,6 @@ import copy import warnings -from typing import List, Optional +from typing import Callable, List, Optional import torch import torch.nn as nn @@ -100,11 +100,11 @@ def sample_from_model(self, model: nn.Module, n: int, cond_info: Optional[Tensor data = [ { "traj": [], - "bck_a": [], # The reverse actions + "bck_a": [GraphAction(GraphActionType.Pad)], # The reverse actions "is_valid": True, "is_sink": [], "fwd_logprobs": [], - "bck_logprobs": [], + "U_bck_logprobs": [], "interm_rewards": [], } for _ in range(n) @@ -113,10 +113,11 @@ def sample_from_model(self, model: nn.Module, n: int, cond_info: Optional[Tensor done = [False for _ in range(n)] for t in range(self.max_len): + # This modifies `data` and `graphs` in place self._forward_step(model, data, graphs, cond_info, t, done, rng, dev, random_action_prob) if all(done): break - + # Note on is_sink and padding: # is_sink indicates to a GFN algorithm that P_B(s) must be 1 # # There are 3 types of possible trajectories @@ -129,23 +130,22 @@ def sample_from_model(self, model: nn.Module, n: int, cond_info: Optional[Tensor # B - ends with an invalid action. = [..., (g, a), (g, None)], = [..., 1, 1] # C - ends at max_len. = [..., (g, a), (gp, None)], = [..., bck(gp), 1] # and then P_F(terminal) "must" be 1 - - self._wrap_up_fwd_trajs(data, graphs) - return data - def _wrap_up_fwd_trajs(self, data, graphs): for i in range(len(data)): # If we're not bootstrapping, we could query the reward # model here, but this is expensive/impractical. Instead # just report forward and backward logprobs data[i]["fwd_logprobs"] = torch.stack(data[i]["fwd_logprobs"]).reshape(-1) - data[i]["bck_logprobs"] = torch.stack(data[i]["bck_logprobs"]).reshape(-1) + data[i]["U_bck_logprobs"] = torch.stack(data[i]["U_bck_logprobs"]).reshape(-1) data[i]["fwd_logprob"] = data[i]["fwd_logprobs"].sum() - data[i]["bck_logprob"] = data[i]["bck_logprobs"].sum() + data[i]["U_bck_logprob"] = data[i]["U_bck_logprobs"].sum() data[i]["result"] = graphs[i] if self.pad_with_terminal_state: data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Pad))) + data[i]["U_bck_logprobs"] = torch.cat([data[i]["U_bck_logprobs"], torch.tensor([0.0], device=dev)]) data[i]["is_sink"].append(1) + assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"]) + return data def sample_backward_from_graphs( self, @@ -182,6 +182,7 @@ def sample_backward_from_graphs( "is_sink": [1], "bck_a": [GraphAction(GraphActionType.Pad)], "bck_logprobs": [0.0], + "U_bck_logprobs": [0.0], "result": graphs[i], } for i in range(n) @@ -193,15 +194,19 @@ def sample_backward_from_graphs( while sum(done) < n: self._backward_step(model, data, graphs, cond_info, done, dev) + for i in range(n): # See comments in sample_from_model data[i]["traj"] = data[i]["traj"][::-1] + # I think this pad is only necessary if we're padding terminal states??? data[i]["bck_a"] = [GraphAction(GraphActionType.Pad)] + data[i]["bck_a"][::-1] data[i]["is_sink"] = data[i]["is_sink"][::-1] - data[i]["bck_logprobs"] = torch.tensor(data[i]["bck_logprobs"][::-1], device=dev).reshape(-1) + data[i]["U_bck_logprobs"] = torch.tensor([0] + data[i]["U_bck_logprobs"][::-1], device=dev).reshape(-1) if self.pad_with_terminal_state: data[i]["traj"].append((starting_graphs[i], GraphAction(GraphActionType.Pad))) + data[i]["U_bck_logprobs"] = torch.cat([data[i]["U_bck_logprobs"], torch.tensor([0.0], device=dev)]) data[i]["is_sink"].append(1) + assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"]) return data def local_search_sample_from_model( @@ -212,64 +217,116 @@ def local_search_sample_from_model( random_action_prob: float = 0.0, num_ls_steps: int = 1, num_bck_steps: int = 1, + compute_reward: Optional[Callable] = None, + criteria: str = "deterministic", ): dev = get_worker_device() rng = get_worker_rng() # First get n trajectories current_trajs = self.sample_from_model(model, n, cond_info, random_action_prob) + compute_reward(current_trajs, cond_info) # in-place # Then we're going to perform num_ls_steps of local search, each with num_bck_steps backward steps. # Each local search step is a kind of Metropolis-Hastings step, where we propose a new trajectory, which may # be accepted or rejected based on the forward and backward probabilities and reward. # Finally, we return all the trajectories that were sampled. - returned_trajs = current_trajs + + # We keep the initial trajectories to return them at the end. We need to copy 'traj' to avoid modifying it + initial_trajs = [{k: v if k != "traj" else list(v) for k, v in t.items()} for t in current_trajs] + sampled_terminals = [] + if self.pad_with_terminal_state: + for t in current_trajs: + t["traj"] = t["traj"][:-1] # Remove the padding state for mcmc_steps in range(num_ls_steps): - # Create new trajectories from the initial ones - bck_trajs = self.sample_backward_from_graphs( - [i["result"] for i in current_trajs], model, cond_info, random_action_prob - ) - # Now we truncate the trajectories, we want to remove at most the last num_bck_steps steps, and also - # remove the last step(s) if it is a stop (and/or pad) action. + # First we must do a bit of accounting so that we can later prevent trajectories longer than max_len stop = GraphActionType.Stop - num_pad = [ - (1 if i["traj"][-1][0].action == stop else 0) + int(self.pad_with_terminal_state) for i in current_trajs - ] + num_pad = [(1 if t["traj"][-1][1].action == stop else 0) for t in current_trajs] trunc_lens = [max(0, len(i["traj"]) - num_bck_steps - pad) for i, pad in zip(current_trajs, num_pad)] - new_trajs = [ - { - key: list(t[key][:k]) - for key in ["traj", "bck_a", "is_sink", "fwd_logprobs", "bck_logprobs", "interm_rewards"] - } - for t, k in zip(bck_trajs, trunc_lens) - ] - # Next we sample new endings for the truncated trajectories - graphs = [i["traj"][-1][0] for i in new_trajs] + + # Go backwards num_bck_steps steps + bck_trajs = [ + {"traj": [], "bck_a": [], "is_sink": [], "bck_logprobs": [], "fwd_logprobs": []} for _ in current_trajs + ] # type: ignore + graphs = [i["traj"][-1][0] for i in current_trajs] done = [False] * n + fwd_a = [] + for i in range(num_bck_steps): + # This modifies `bck_trajs` & `graphs` in place, passing fwd_a computes P_F(s|s') for the previous step + self._backward_step(model, bck_trajs, graphs, cond_info, done, dev, fwd_a) + fwd_a = [t["traj"][-1][1] for t in bck_trajs] + # Add forward logprobs for the last step + self._add_fwd_logprobs(bck_trajs, graphs, model, cond_info, done, dev, fwd_a) + log_P_B_tau_back = [sum(t["bck_logprobs"]) for t in bck_trajs] + log_P_F_tau_back = [sum(t["fwd_logprobs"]) for t in bck_trajs] + # Go forward to get full trajectories + fwd_trajs = [ + {"traj": [], "bck_a": [], "is_sink": [], "bck_logprobs": [], "fwd_logprobs": []} for _ in current_trajs + ] # type: ignore + done = [False] * n + bck_a = [] while not all(done): - self._forward_step(model, new_trajs, graphs, cond_info, 0, done, rng, dev, random_action_prob) - done = [d or len(t["traj"]) >= self.max_len for d, t in zip(done, new_trajs)] - self._wrap_up_fwd_trajs(new_trajs, graphs) - # We add those new trajectories to the list of returned trajectories - returned_trajs += new_trajs - # Finally, we replace the current trajectories with the new ones if they are accepted by MH + self._forward_step(model, fwd_trajs, graphs, cond_info, 0, done, rng, dev, random_action_prob, bck_a) + done = [d or (len(t["traj"]) + T) >= self.max_len for d, t, T in zip(done, fwd_trajs, trunc_lens)] + bck_a = [t["bck_a"][-1] for t in fwd_trajs] + # Add backward logprobs for the last step; this is only necessary if the last action is not a stop + done = [t["traj"][-1][1].action == stop for t in fwd_trajs] + if not all(done): + self._add_bck_logprobs(fwd_trajs, graphs, model, cond_info, done, dev, bck_a) + log_P_F_tau_recon = [sum(t["fwd_logprobs"]) for t in fwd_trajs] + log_P_B_tau_recon = [sum(t["bck_logprobs"]) for t in fwd_trajs] + + # We add those new terminal states to the list of terminal states + terminals = [t["traj"][-1][0] for t in fwd_trajs] + sampled_terminals.extend(terminals) + for traj, term in zip(fwd_trajs, terminals): + traj["result"] = term + # Compute rewards for the acceptance + if compute_reward is not None: + compute_reward(fwd_trajs, cond_info) + + # To end the iteration, we replace the current trajectories with the new ones if they are accepted by MH for i in range(n): - if new_trajs[i]["fwd_logprob"] + new_trajs[i]["bck_logprob"] > current_trajs[i]["fwd_logprob"] + current_trajs[i]["bck_logprob"]: - current_trajs[i] = new_trajs[i] - def _forward_step(self, model, data, graphs, cond_info, t, done, rng, dev, random_action_prob) -> None: + if criteria == "deterministic": + # Keep the highest reward + if fwd_trajs[i]["log_reward"] > current_trajs[i]["log_reward"]: + current_trajs[i] = fwd_trajs[i] + elif criteria == "stochastic": + # Accept with probability max(1, R(x')/R(x)q(x'|x)/q(x|x')) + log_q_xprime_given_x = log_P_B_tau_back[i] + log_P_F_tau_recon[i] + log_q_x_given_xprime = log_P_B_tau_recon[i] + log_P_F_tau_back[i] + log_R_ratio = fwd_trajs[i]["log_reward"] - current_trajs[i]["log_reward"] + log_acceptance_ratio = log_R_ratio + log_q_xprime_given_x - log_q_x_given_xprime + if log_acceptance_ratio > 0 or rng.uniform() < torch.exp(log_acceptance_ratio): + current_trajs[i] = fwd_trajs[i] + elif criteria == "always": + current_trajs[i] = fwd_trajs[i] + # Finally, we resample new "P_B-on-policy" trajectories from the terminal states + stacked_ci = ( + {k: cond_info[k].repeat(num_ls_steps, *((1,) * (cond_info[k].ndim - 1))) for k in cond_info} + if cond_info is not None + else None + ) + returned_trajs = self.sample_backward_from_graphs(sampled_terminals, model, stacked_ci, random_action_prob) + return returned_trajs + initial_trajs + + def _forward_step(self, model, data, graphs, cond_info, t, done, rng, dev, random_action_prob, bck_a=[]) -> None: def not_done(lst): - return [e for i, e in enumerate(lst) if not done[i]] + return [e for i, e in enumerate(lst) if not done[i]] if not bck_a else lst n = len(data) # Construct graphs for the trajectories that aren't yet done torch_graphs = [self.ctx.graph_to_Data(i) for i in not_done(graphs)] - not_done_mask = torch.tensor(done, device=dev).logical_not() + if not bck_a: + not_done_mask = torch.tensor(done, device=cond_info["encoding"].device).logical_not() + else: + not_done_mask = torch.tensor([True] * n, device=cond_info["encoding"].device) # Forward pass to get GraphActionCategorical # Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does. # TODO: compute bck_cat.log_prob(bck_a) when relevant batch = self.ctx.collate(torch_graphs) - batch.cond_info = cond_info[not_done_mask] if cond_info is not None else None - fwd_cat, *_ = model(batch.to(dev)) + batch.cond_info = cond_info["encoding"][not_done_mask] if cond_info is not None else None + fwd_cat, bck_cat, *_ = model(batch.to(dev)) if random_action_prob > 0: # Device which graphs in the minibatch will get their action randomized is_random_action = torch.tensor( @@ -286,17 +343,26 @@ def not_done(lst): actions = sample_cat.sample() else: actions = fwd_cat.sample() - graph_actions = [self.ctx.ActionIndex_to_GraphAction(g, a) for g, a in zip(torch_graphs, actions)] + graph_actions = [self.ctx.ActionIndex_to_GraphAction(g, a, fwd=True) for g, a in zip(torch_graphs, actions)] log_probs = fwd_cat.log_prob(actions) + if bck_a: + aidx_bck_a = [self.ctx.GraphAction_to_ActionIndex(g, a) for g, a in zip(torch_graphs, bck_a)] + bck_logprobs = bck_cat.log_prob(aidx_bck_a) # Step each trajectory, and accumulate statistics for i, j in zip(not_done(range(n)), range(n)): - data[i]["fwd_logprob"].append(log_probs[j].unsqueeze(0)) + if bck_a and len(data[i]["bck_logprobs"]) < len(data[i]["traj"]): + data[i]["bck_logprobs"].append(bck_logprobs[j].unsqueeze(0)) + if done[i]: + continue + data[i]["fwd_logprobs"].append(log_probs[j].unsqueeze(0)) data[i]["traj"].append((graphs[i], graph_actions[j])) data[i]["bck_a"].append(self.env.reverse(graphs[i], graph_actions[j])) + if "U_bck_logprobs" not in data[i]: + data[i]["U_bck_logprobs"] = [] # Check if we're done if graph_actions[j].action is GraphActionType.Stop: done[i] = True - data[i]["bck_logprob"].append(torch.tensor([1.0], device=dev).log()) + data[i]["U_bck_logprobs"].append(torch.tensor([1.0], device=dev).log()) data[i]["is_sink"].append(1) else: # If not done, try to step the self.environment gp = graphs[i] @@ -307,7 +373,7 @@ def not_done(lst): except AssertionError: done[i] = True data[i]["is_valid"] = False - data[i]["bck_logprob"].append(torch.tensor([1.0], device=dev).log()) + data[i]["U_bck_logprobs"].append(torch.tensor([1.0], device=dev).log()) data[i]["is_sink"].append(1) continue if t == self.max_len - 1: @@ -315,7 +381,7 @@ def not_done(lst): # If no error, add to the trajectory # P_B = uniform backward n_back = self.env.count_backward_transitions(gp, check_idempotent=self.correct_idempotent) - data[i]["bck_logprob"].append(torch.tensor([1 / n_back], device=dev).log()) + data[i]["U_bck_logprobs"].append(torch.tensor([1 / n_back], device=dev).log()) data[i]["is_sink"].append(0) graphs[i] = gp if done[i] and self.sanitize_samples and not self.ctx.is_sane(graphs[i]): @@ -324,16 +390,21 @@ def not_done(lst): data[i]["is_valid"] = False # Nothing is returned, data is modified in place - def _backward_step(self, model, data, graphs, cond_info, done, dev) -> None: + def _backward_step(self, model, data, graphs, cond_info, done, dev, fwd_a=[]): + # fwd_a is a list of GraphActions that are the reverse of the last backwards actions we took. + # Passing them allows us to compute the forward logprobs of the actions we took. def not_done(lst): - return [e for i, e in enumerate(lst) if not done[i]] + return [e for i, e in enumerate(lst) if not done[i]] if not fwd_a else lst torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(len(graphs)))] - not_done_mask = torch.tensor(done, device=dev).logical_not() + if not fwd_a: + not_done_mask = torch.tensor(done, device=cond_info["encoding"].device).logical_not() + else: + not_done_mask = torch.tensor([True] * len(graphs), device=cond_info["encoding"].device) if model is not None: - gbatch = self.ctx.collate(torch_graphs).to(dev) - gbatch.cond_info = cond_info[not_done_mask] if cond_info is not None else None - _, bck_cat, *_ = model(gbatch) + gbatch = self.ctx.collate(torch_graphs) + gbatch.cond_info = cond_info["encoding"][not_done_mask] if cond_info is not None else None + fwd_cat, bck_cat, *_ = model(gbatch.to(dev)) else: gbatch = self.ctx.collate(torch_graphs) action_types = self.ctx.bck_action_type_order @@ -350,17 +421,59 @@ def not_done(lst): self.ctx.ActionIndex_to_GraphAction(g, a, fwd=False) for g, a in zip(torch_graphs, bck_actions) ] bck_logprobs = bck_cat.log_prob(bck_actions) + if fwd_a and model is not None: + aidx_fwd_a = [self.ctx.GraphAction_to_ActionIndex(g, a) for g, a in zip(torch_graphs, fwd_a)] + fwd_logprobs = fwd_cat.log_prob(aidx_fwd_a) for i, j in zip(not_done(range(len(graphs))), range(len(graphs))): + if fwd_a and model is not None and len(data[i]["fwd_logprobs"]) < len(data[i]["traj"]): + data[i]["fwd_logprobs"].append(fwd_logprobs[j].item()) + if done[i]: + # This can happen when fwd_a is passed, we should optimize this though. The reason is that even + # if a graph is done, we may still want to compute its forward logprobs. + continue + g = graphs[i] + b_a = graph_bck_actions[j] + gp = self.env.step(g, b_a) + f_a = self.env.reverse(g, b_a) + graphs[i], f_a = relabel(gp, f_a) + data[i]["traj"].append((graphs[i], f_a)) + data[i]["bck_a"].append(b_a) + data[i]["is_sink"].append(0) + data[i]["bck_logprobs"].append(bck_logprobs[j].item()) + + if len(graphs[i]) == 0: + done[i] = True + if "U_bck_logprobs" not in data[i]: + data[i]["U_bck_logprobs"] = [] if not done[i]: - g = graphs[i] - b_a = graph_bck_actions[j] - gp = self.env.step(g, b_a) - f_a = self.env.reverse(g, b_a) - graphs[i], f_a = relabel(gp, f_a) - data[i]["traj"].append((graphs[i], f_a)) - data[i]["bck_a"].append(b_a) - data[i]["is_sink"].append(0) - data[i]["bck_logprobs"].append(bck_logprobs[j].item()) - if len(graphs[i]) == 0: - done[i] = True + n_back = self.env.count_backward_transitions(graphs[i], check_idempotent=self.correct_idempotent) + data[i]["U_bck_logprobs"].append(torch.tensor([1.0 / n_back], device=dev).log()) + + def _add_fwd_logprobs(self, data, graphs, model, cond_info, done, dev, fwd_a): + def not_done(lst): + return [e for i, e in enumerate(lst) if not done[i]] + + torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(len(graphs)))] + not_done_mask = torch.tensor(done, device=cond_info["encoding"].device).logical_not() + gbatch = self.ctx.collate(torch_graphs) + gbatch.cond_info = cond_info["encoding"][not_done_mask] if cond_info is not None else None + fwd_cat, *_ = model(gbatch.to(dev)) + fwd_actions = [self.ctx.GraphAction_to_ActionIndex(g, a) for g, a in zip(torch_graphs, fwd_a)] + log_probs = fwd_cat.log_prob(fwd_actions) + for i, j in zip(not_done(range(len(graphs))), range(len(graphs))): + data[i]["fwd_logprobs"].append(log_probs[j].item()) + + def _add_bck_logprobs(self, data, graphs, model, cond_info, done, dev, bck_a): + def not_done(lst): + return [e for i, e in enumerate(lst) if not done[i]] + + torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(len(graphs)))] + not_done_mask = torch.tensor(done, device=cond_info["encoding"].device).logical_not() + gbatch = self.ctx.collate(torch_graphs) + gbatch.cond_info = cond_info["encoding"][not_done_mask] if cond_info is not None else None + fwd_cat, bck_cat, *_ = model(gbatch.to(dev)) + bck_actions = [self.ctx.GraphAction_to_ActionIndex(g, a) for g, a in zip(torch_graphs, bck_a)] + log_probs = bck_cat.log_prob(bck_actions) + for i, j in zip(not_done(range(len(graphs))), range(len(graphs))): + data[i]["bck_logprobs"].append(log_probs[j].item()) diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index e8b7fe64..899532ad 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -174,17 +174,10 @@ def create_training_data_from_own_samples( - reward_pred: float, -100 if an illegal action is taken, predicted R(x) if bootstrapping, None otherwise - fwd_logprob: log Z + sum logprobs P_F - bck_logprob: sum logprobs P_B - - logZ: predicted log Z - loss: predicted loss (if bootstrapping) - is_valid: is the generated graph valid according to the env & ctx """ - dev = get_worker_device() - cond_info = cond_info.to(dev) if cond_info is not None else None data = self.graph_sampler.sample_from_model(model, n, cond_info, random_action_prob) - if cond_info is not None: - logZ_pred = model.logZ(cond_info) - for i in range(n): - data[i]["logZ"] = logZ_pred[i].item() return data def create_training_data_from_graphs( @@ -214,8 +207,6 @@ def create_training_data_from_graphs( """ if self.cfg.do_sample_p_b: assert model is not None and cond_info is not None and random_action_prob is not None - dev = get_worker_device() - cond_info = cond_info.to(dev) return self.graph_sampler.sample_backward_from_graphs( graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, random_action_prob ) @@ -229,10 +220,10 @@ def create_training_data_from_graphs( self.env.count_backward_transitions(gp, check_idempotent=self.cfg.do_correct_idempotent) for gp, _ in traj["traj"][1:] ] + [1] - traj["bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(get_worker_device()) + traj["U_bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(get_worker_device()) traj["result"] = traj["traj"][-1][0] if self.cfg.do_parameterize_p_b: - traj["bck_a"] = [GraphAction(GraphActionType.Stop)] + [self.env.reverse(g, a) for g, a in traj["traj"]] + traj["bck_a"] = [GraphAction(GraphActionType.Pad)] + [self.env.reverse(g, a) for g, a in traj["traj"]] # There needs to be an additonal node when we're parameterizing P_B, # See sampling with parametrized P_B traj["traj"].append(deepcopy(traj["traj"][-1])) @@ -316,7 +307,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): ] batch = self.ctx.collate(torch_graphs) batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs]) - batch.log_p_B = torch.cat([i["bck_logprobs"] for i in trajs], 0) + batch.U_log_p_B = torch.cat([i["U_bck_logprobs"] for i in trajs], 0) batch.actions = torch.tensor(actions) if self.cfg.do_parameterize_p_b: batch.bck_actions = torch.tensor( @@ -353,7 +344,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch.bck_ip_lens = torch.tensor([len(i) for i in bck_ipa]) # compute_batch_losses expects these two optional values, if someone else doesn't fill them in, default to 0 - batch.num_offline = 0 + batch.num_offline = 0 # TODO: this has been half-deprecated, finish the job batch.num_online = 0 return batch @@ -475,7 +466,7 @@ def compute_batch_losses( # occasion masks out the first P_B of the "next" trajectory that we've shifted. log_p_B = torch.roll(log_p_B, -1, 0) * (1 - batch.is_sink) else: - log_p_B = batch.log_p_B + log_p_B = batch.U_log_p_B assert log_p_F.shape == log_p_B.shape if self.cfg.n_loss == NLoss.TB: @@ -506,13 +497,15 @@ def compute_batch_losses( if self.cfg.do_parameterize_p_b: # Life is pain, log_p_B is one unit too short for all trajs + log_p_B_unif = batch.U_log_p_B + assert log_p_B_unif.shape[0] == log_p_B.shape[0] - log_p_B_unif = torch.zeros_like(log_p_B) - for i, (s, e) in enumerate(zip(first_graph_idx, traj_cumlen)): - log_p_B_unif[s : e - 1] = batch.log_p_B[s - i : e - 1 - i] + # log_p_B_unif = torch.zeros_like(log_p_B) + # for i, (s, e) in enumerate(zip(first_graph_idx, traj_cumlen)): + # log_p_B_unif[s : e - 1] = batch.U_log_p_B[s - i : e - 1 - i] - if self.cfg.backward_policy == Backward.Uniform: - log_p_B = log_p_B_unif + # if self.cfg.backward_policy == Backward.Uniform: + # log_p_B = log_p_B_unif else: log_p_B_unif = log_p_B diff --git a/src/gflownet/config.py b/src/gflownet/config.py index 65733f64..106d4d87 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -90,7 +90,7 @@ class Config(StrictDataClass): """ desc: str = "noDesc" - log_dir: str = MISSING + log_dir: Optional[str] = MISSING device: str = "cuda" seed: int = 0 validate_every: int = 1000 diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 94c54534..9226c88c 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -137,8 +137,7 @@ def iterator(): t = self.current_iter p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(num_samples, t) - # TODO: in the cond info refactor, pass the whole thing instead of just the encoding - trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info["encoding"], p) + trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info, p) self.mark_all(trajs, source='sample') self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) @@ -168,7 +167,7 @@ def iterator(): p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(n_this_time, t) # TODO: in the cond info refactor, pass the whole thing instead of just the encoding - trajs = self.algo.create_training_data_from_own_samples(model, n_this_time, cond_info["encoding"], p) + trajs = self.algo.create_training_data_from_own_samples(model, n_this_time, cond_info, p) self.mark_all(trajs, source='sample') self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) @@ -199,7 +198,7 @@ def iterator(): p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(num_samples, t) objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) - trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info["encoding"], p) + trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info, p) self.mark_all(trajs, source='dataset') self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.set_traj_props(trajs, props) @@ -218,7 +217,7 @@ def iterator(): # I'm also not a fan of encode_conditional_information, it assumes lots of things about what's passed to # it and the state of the program (e.g. validation mode) cond_info = self.task.encode_conditional_information(torch.stack([data[i] for i in idcs])) - trajs = self.algo.create_training_data_from_own_samples(model, len(idcs), cond_info["encoding"], p) + trajs = self.algo.create_training_data_from_own_samples(model, len(idcs), cond_info, p) self.mark_all(trajs, source='dataset') self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) @@ -241,7 +240,7 @@ def iterator(): p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(num_samples, t) objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) - trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info["encoding"], p) + trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info, p) self.mark_all(trajs, source='dataset') self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.set_traj_props(trajs, props) @@ -289,6 +288,8 @@ def create_batch(self, trajs, batch_info): def compute_properties(self, trajs, mark_as_online=False): """Sets trajs' obj_props and is_valid keys by querying the task.""" + if all('obj_props' in t for t in trajs): + return # TODO: refactor obj_props into properties valid_idcs = torch.tensor([i for i in range(len(trajs)) if trajs[i].get("is_valid", True)]).long() # fetch the valid trajectories endpoints @@ -304,13 +305,15 @@ def compute_properties(self, trajs, mark_as_online=False): all_fr[valid_idcs] = obj_props for i in range(len(trajs)): trajs[i]["obj_props"] = all_fr[i] - trajs[i]["is_online"] = mark_as_online + trajs[i]["is_online"] = mark_as_online # TODO: this is deprecated in favor of 'source'? # Override the is_valid key in case the task made some objs invalid for i in valid_idcs: trajs[i]["is_valid"] = True def compute_log_rewards(self, trajs): """Sets trajs' log_reward key by querying the task.""" + if all('log_reward' in t for t in trajs): + return obj_props = torch.stack([t["obj_props"] for t in trajs]) cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} log_rewards = self.task.cond_info_to_logreward(cond_info, obj_props) @@ -327,10 +330,14 @@ def send_to_replay(self, trajs): def set_traj_cond_info(self, trajs, cond_info): for i in range(len(trajs)): + if 'cond_info' in trajs[i]: + continue trajs[i]["cond_info"] = {k: cond_info[k][i] for k in cond_info} def set_traj_props(self, trajs, props): for i in range(len(trajs)): + if 'obj_props' in trajs[i]: + continue trajs[i]["obj_props"] = props[i] # TODO: refactor def relabel_in_hindsight(self, trajs): diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index bac09959..eb4ee9cf 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -105,6 +105,9 @@ def ActionIndex_to_GraphAction(self, g: gd.Data, aidx: ActionIndex, fwd: bool = action: GraphAction A graph action whose type is one of Stop, AddNode, or SetEdgeAttr. """ + if aidx.action_type == -1: + # Pad action + return GraphAction(GraphActionType.Pad) if fwd: t = self.action_type_order[aidx.action_type] else: @@ -147,6 +150,8 @@ def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionI A triple describing the type of action, and the corresponding row and column index for the corresponding Categorical matrix. """ + if action.action is GraphActionType.Pad: + return ActionIndex(action_type=-1, row_idx=0, col_idx=0) # Find the index of the action type, privileging the forward actions for u in [self.action_type_order, self.bck_action_type_order]: if action.action in u: diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 2817f056..b3dbc034 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -345,7 +345,7 @@ def count_backward_transitions(self, g: Graph, check_idempotent: bool = False): def reverse(self, g: Graph, ga: GraphAction): if ga.action == GraphActionType.Stop: - return ga + return GraphAction(GraphActionType.Pad) # because we can't reverse a Stop action elif ga.action == GraphActionType.AddNode: return GraphAction(GraphActionType.RemoveNode, source=len(g.nodes)) elif ga.action == GraphActionType.AddEdge: diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 17091d84..78519c81 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -9,6 +9,7 @@ from gflownet.algo.advantage_actor_critic import A2C from gflownet.algo.flow_matching import FlowMatching +from gflownet.algo.local_search_tb import LocalSearchTB from gflownet.algo.soft_q_learning import SoftQLearning from gflownet.algo.trajectory_balance import TrajectoryBalance from gflownet.data.replay_buffer import ReplayBuffer @@ -38,6 +39,8 @@ def setup_algo(self): algo = self.cfg.algo.method if algo == "TB": algo = TrajectoryBalance + elif algo == 'LSTB': + algo = LocalSearchTB elif algo == "FM": algo = FlowMatching elif algo == "A2C": @@ -48,6 +51,9 @@ def setup_algo(self): raise ValueError(algo) self.algo = algo(self.env, self.ctx, self.cfg) + if self.algo.requires_task: + self.algo.set_task(self.task) + def setup_data(self): self.training_data = [] self.test_data = [] @@ -125,9 +131,9 @@ def setup(self): if self.print_config: print("\n\nHyperparameters:\n") print(yaml_cfg) - os.makedirs(self.cfg.log_dir, exist_ok=True) - with open(pathlib.Path(self.cfg.log_dir) / "config.yaml", "w", encoding="utf8") as f: - f.write(yaml_cfg) + if self.cfg.log_dir is not None: + with open(pathlib.Path(self.cfg.log_dir) / "config.yaml", "w", encoding="utf8") as f: + f.write(yaml_cfg) def step(self, loss: Tensor, train_it: int): loss.backward() diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 12c413c6..fa3c2706 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -110,7 +110,7 @@ def step(self, loss: Tensor, train_it: int): raise NotImplementedError() def setup(self): - if self.rank == 0: + if self.cfg.log_dir and self.rank == 0: if os.path.exists(self.cfg.log_dir): if self.cfg.overwrite_existing_exp: shutil.rmtree(self.cfg.log_dir) @@ -286,7 +286,7 @@ def run(self, logger=None): validation every `validate_every` minibatches. """ if logger is None: - logger = create_logger(logfile=self.cfg.log_dir + "/train.log") + logger = create_logger(logfile=self.cfg.log_dir + "/train.log" if self.cfg.log_dir else None) self.model.to(self.device) self.sampling_model.to(self.device) import threading @@ -393,7 +393,7 @@ def terminate(self): terminate() def _save_state(self, it): - if self.rank != 0: + if self.rank != 0 or self.cfg.log_dir is None: return state = { "models_state_dict": [self.model.state_dict()], From c5373cf3f8a4a52f8c47c073e35b70c0b141b775 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 30 Aug 2024 07:54:37 -0600 Subject: [PATCH 26/31] lstb file --- src/gflownet/algo/local_search_tb.py | 50 ++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 src/gflownet/algo/local_search_tb.py diff --git a/src/gflownet/algo/local_search_tb.py b/src/gflownet/algo/local_search_tb.py new file mode 100644 index 00000000..5751beb5 --- /dev/null +++ b/src/gflownet/algo/local_search_tb.py @@ -0,0 +1,50 @@ +import torch + +from gflownet import GFNTask +from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.data.data_source import DataSource +from gflownet.utils.misc import get_worker_device + + +class LocalSearchTB(TrajectoryBalance): + requires_task: bool = True + task: GFNTask + + def __init__(self, env, ctx, cfg): + super().__init__(env, ctx, cfg) + self.task = None + + def set_task(self, task): + self.task = task + assert self.cfg.do_parameterize_p_b, "LocalSearchTB requires do_parameterize_p_b to be True" + + def create_training_data_from_own_samples(self, model, n, cond_info=None, random_action_prob=0.0): + assert self.task is not None, "LocalSearchTB requires a task to be set" + assert n % (1 + self.global_cfg.algo.ls.num_ls_steps) == 0, "n must be divisible by 1 + num_ls_steps" + n_per_step = n // (1 + self.global_cfg.algo.ls.num_ls_steps) + cond_info = {k: v[:n_per_step] for k, v in cond_info.items()} if cond_info is not None else None + random_action_prob = random_action_prob or 0.0 + data = self.graph_sampler.local_search_sample_from_model( + model, + n_per_step, + cond_info, + random_action_prob, + self.global_cfg.algo.ls.num_ls_steps, + self.global_cfg.algo.ls.num_bck_steps, + self._compute_log_rewards, + self.global_cfg.algo.ls.accept_criteria, + ) + return data + + def _compute_log_rewards(self, trajs, cond_info): + """Sets trajs' log_reward key by querying the task.""" + cfg = self.global_cfg + self.cfg = cfg # TODO: fix TB so we don't need to do this + # Doing a bit of hijacking here to compute the log rewards, DataSource implements this for us. + # May be worth refactoring this to be more general eventually, this depends on `self` having ctx, task, and cfg + # attributes. + DataSource.set_traj_cond_info(self, trajs, cond_info) # type: ignore + DataSource.compute_properties(self, trajs) # type: ignore + DataSource.compute_log_rewards(self, trajs) # type: ignore + self.cfg = cfg.algo.tb + # trajs is modified in place, so no need to return anything \ No newline at end of file From 7e623bd9a45094be546da448a9e8f2ae015ea260 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 5 Sep 2024 16:08:16 -0600 Subject: [PATCH 27/31] yield_only_accepted in LS + load_model_state flag --- src/gflownet/algo/config.py | 1 + src/gflownet/algo/graph_sampling.py | 36 +++++++++++++++++----------- src/gflownet/algo/local_search_tb.py | 9 +++---- src/gflownet/config.py | 1 + src/gflownet/data/data_source.py | 13 ++++++---- src/gflownet/trainer.py | 6 +++++ 6 files changed, 44 insertions(+), 22 deletions(-) diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index cda6f53b..4ae715ae 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -153,6 +153,7 @@ class LSTBConfig(StrictDataClass): num_ls_steps: int = 1 num_bck_steps: int = 1 accept_criteria: str = "deterministic" + yield_only_accepted: bool = False @dataclass diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 7e71b120..3cfe601d 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -13,6 +13,7 @@ GraphActionType, action_type_to_mask, ) +from gflownet.algo.config import LSTBConfig from gflownet.models.graph_transformer import GraphTransformerGFN from gflownet.utils.misc import get_worker_device, get_worker_rng @@ -215,10 +216,8 @@ def local_search_sample_from_model( n: int, cond_info: Optional[Tensor], random_action_prob: float = 0.0, - num_ls_steps: int = 1, - num_bck_steps: int = 1, + cfg: LSTBConfig = LSTBConfig(), compute_reward: Optional[Callable] = None, - criteria: str = "deterministic", ): dev = get_worker_device() rng = get_worker_rng() @@ -237,11 +236,11 @@ def local_search_sample_from_model( for t in current_trajs: t["traj"] = t["traj"][:-1] # Remove the padding state - for mcmc_steps in range(num_ls_steps): + for mcmc_steps in range(cfg.num_ls_steps): # First we must do a bit of accounting so that we can later prevent trajectories longer than max_len stop = GraphActionType.Stop num_pad = [(1 if t["traj"][-1][1].action == stop else 0) for t in current_trajs] - trunc_lens = [max(0, len(i["traj"]) - num_bck_steps - pad) for i, pad in zip(current_trajs, num_pad)] + trunc_lens = [max(0, len(i["traj"]) - cfg.num_bck_steps - pad) for i, pad in zip(current_trajs, num_pad)] # Go backwards num_bck_steps steps bck_trajs = [ @@ -250,7 +249,7 @@ def local_search_sample_from_model( graphs = [i["traj"][-1][0] for i in current_trajs] done = [False] * n fwd_a = [] - for i in range(num_bck_steps): + for i in range(cfg.num_bck_steps): # This modifies `bck_trajs` & `graphs` in place, passing fwd_a computes P_F(s|s') for the previous step self._backward_step(model, bck_trajs, graphs, cond_info, done, dev, fwd_a) fwd_a = [t["traj"][-1][1] for t in bck_trajs] @@ -278,6 +277,7 @@ def local_search_sample_from_model( # We add those new terminal states to the list of terminal states terminals = [t["traj"][-1][0] for t in fwd_trajs] + import pdb; pdb.set_trace() sampled_terminals.extend(terminals) for traj, term in zip(fwd_trajs, terminals): traj["result"] = term @@ -287,11 +287,11 @@ def local_search_sample_from_model( # To end the iteration, we replace the current trajectories with the new ones if they are accepted by MH for i in range(n): - if criteria == "deterministic": + if cfg.accept_criteria == "deterministic": # Keep the highest reward if fwd_trajs[i]["log_reward"] > current_trajs[i]["log_reward"]: current_trajs[i] = fwd_trajs[i] - elif criteria == "stochastic": + elif cfg.accept_criteria == "stochastic": # Accept with probability max(1, R(x')/R(x)q(x'|x)/q(x|x')) log_q_xprime_given_x = log_P_B_tau_back[i] + log_P_F_tau_recon[i] log_q_x_given_xprime = log_P_B_tau_recon[i] + log_P_F_tau_back[i] @@ -299,14 +299,22 @@ def local_search_sample_from_model( log_acceptance_ratio = log_R_ratio + log_q_xprime_given_x - log_q_x_given_xprime if log_acceptance_ratio > 0 or rng.uniform() < torch.exp(log_acceptance_ratio): current_trajs[i] = fwd_trajs[i] - elif criteria == "always": + elif cfg.accept_criteria == "always": current_trajs[i] = fwd_trajs[i] + # Finally, we resample new "P_B-on-policy" trajectories from the terminal states - stacked_ci = ( - {k: cond_info[k].repeat(num_ls_steps, *((1,) * (cond_info[k].ndim - 1))) for k in cond_info} - if cond_info is not None - else None - ) + # If we're only interested in the accepted trajectories, we use them as starting points instead + if cfg.yield_only_accepted: + sampled_terminals = [i['traj'][-1][0] for i in current_trajs] + stacked_ci = cond_info + + if not cfg.yield_only_accepted: + # In this scenario, the batch is n // num_ls_steps, so we do some stacking + stacked_ci = ( + {k: cond_info[k].repeat(cfg.num_ls_steps, *((1,) * (cond_info[k].ndim - 1))) for k in cond_info} + if cond_info is not None + else None + ) returned_trajs = self.sample_backward_from_graphs(sampled_terminals, model, stacked_ci, random_action_prob) return returned_trajs + initial_trajs diff --git a/src/gflownet/algo/local_search_tb.py b/src/gflownet/algo/local_search_tb.py index 5751beb5..a1632987 100644 --- a/src/gflownet/algo/local_search_tb.py +++ b/src/gflownet/algo/local_search_tb.py @@ -21,7 +21,10 @@ def set_task(self, task): def create_training_data_from_own_samples(self, model, n, cond_info=None, random_action_prob=0.0): assert self.task is not None, "LocalSearchTB requires a task to be set" assert n % (1 + self.global_cfg.algo.ls.num_ls_steps) == 0, "n must be divisible by 1 + num_ls_steps" - n_per_step = n // (1 + self.global_cfg.algo.ls.num_ls_steps) + if self.global_cfg.algo.ls.yield_only_accepted: + n_per_step = n + else: + n_per_step = n // (1 + self.global_cfg.algo.ls.num_ls_steps) cond_info = {k: v[:n_per_step] for k, v in cond_info.items()} if cond_info is not None else None random_action_prob = random_action_prob or 0.0 data = self.graph_sampler.local_search_sample_from_model( @@ -29,10 +32,8 @@ def create_training_data_from_own_samples(self, model, n, cond_info=None, random n_per_step, cond_info, random_action_prob, - self.global_cfg.algo.ls.num_ls_steps, - self.global_cfg.algo.ls.num_bck_steps, + self.global_cfg.algo.ls, self._compute_log_rewards, - self.global_cfg.algo.ls.accept_criteria, ) return data diff --git a/src/gflownet/config.py b/src/gflownet/config.py index 106d4d87..21e47aaa 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -96,6 +96,7 @@ class Config(StrictDataClass): validate_every: int = 1000 checkpoint_every: Optional[int] = None store_all_checkpoints: bool = False + load_model_state: Optional[str] = None print_every: int = 100 start_at_step: int = 0 num_final_gen_steps: Optional[int] = None diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 9226c88c..d3221a7b 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -86,9 +86,8 @@ def __iter__(self): for d in batch_infos: batch_info.update(d) yield self.create_batch(trajs, batch_info) - except Exception as e: - if 1: - raise e + err_tol = 10 # Reset the error tolerance, if we run into 10 consecutive errors, we'll break + except (Exception, RuntimeError) as e: err_tol -= 1 if err_tol == 0: raise e @@ -98,7 +97,13 @@ def __iter__(self): traceback.print_exc() traceback.print_exc(file=sys.stderr) continue - err_tol = 10 # Reset the error tolerance, if we run into 10 consecutive errors, we'll break + except: + print("Unknown error in DataSource") + import traceback, sys + traceback.print_exc() + traceback.print_exc(file=sys.stderr) + err_tol -= 1 + continue def validate_batch(self, batch, trajs): for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index fa3c2706..a1eca458 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -128,6 +128,8 @@ def setup(self): self.setup_env_context() self.setup_algo() self.setup_model() + if self.cfg.load_model_state is not None: + self.load_state(self.cfg.load_model_state) def _wrap_for_mp(self, obj): """Wraps an object in a placeholder whose reference can be sent to a @@ -411,6 +413,10 @@ def _save_state(self, it): if self.cfg.store_all_checkpoints: shutil.copy(fn, pathlib.Path(self.cfg.log_dir) / f"model_state_{it}.pt") + def load_state(self, path): + state = torch.load(path) + self.model.load_state_dict(state["models_state_dict"][0]) + def log(self, info, index, key): if not hasattr(self, "_summary_writer"): self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.cfg.log_dir) From 32d4cafe92f3886dae01e377a65f4d7e06688dfc Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Tue, 8 Oct 2024 08:54:51 -0600 Subject: [PATCH 28/31] many fixes, frag env options --- src/gflownet/algo/graph_sampling.py | 11 +++++-- src/gflownet/algo/local_search_tb.py | 9 ++++-- src/gflownet/data/data_source.py | 28 +++++++++--------- src/gflownet/envs/frag_mol_env.py | 36 +++++++++++++++++++---- src/gflownet/online_trainer.py | 2 +- src/gflownet/tasks/config.py | 1 + src/gflownet/trainer.py | 43 +++++++++++++++------------- 7 files changed, 83 insertions(+), 47 deletions(-) diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 3cfe601d..72ebb56b 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -235,6 +235,7 @@ def local_search_sample_from_model( if self.pad_with_terminal_state: for t in current_trajs: t["traj"] = t["traj"][:-1] # Remove the padding state + num_accepts = 0 for mcmc_steps in range(cfg.num_ls_steps): # First we must do a bit of accounting so that we can later prevent trajectories longer than max_len @@ -254,7 +255,7 @@ def local_search_sample_from_model( self._backward_step(model, bck_trajs, graphs, cond_info, done, dev, fwd_a) fwd_a = [t["traj"][-1][1] for t in bck_trajs] # Add forward logprobs for the last step - self._add_fwd_logprobs(bck_trajs, graphs, model, cond_info, done, dev, fwd_a) + self._add_fwd_logprobs(bck_trajs, graphs, model, cond_info, [False] * n, dev, fwd_a) log_P_B_tau_back = [sum(t["bck_logprobs"]) for t in bck_trajs] log_P_F_tau_back = [sum(t["fwd_logprobs"]) for t in bck_trajs] @@ -277,10 +278,10 @@ def local_search_sample_from_model( # We add those new terminal states to the list of terminal states terminals = [t["traj"][-1][0] for t in fwd_trajs] - import pdb; pdb.set_trace() sampled_terminals.extend(terminals) for traj, term in zip(fwd_trajs, terminals): traj["result"] = term + traj["is_accept"] = False # Compute rewards for the acceptance if compute_reward is not None: compute_reward(fwd_trajs, cond_info) @@ -291,6 +292,7 @@ def local_search_sample_from_model( # Keep the highest reward if fwd_trajs[i]["log_reward"] > current_trajs[i]["log_reward"]: current_trajs[i] = fwd_trajs[i] + num_accepts += 1 elif cfg.accept_criteria == "stochastic": # Accept with probability max(1, R(x')/R(x)q(x'|x)/q(x|x')) log_q_xprime_given_x = log_P_B_tau_back[i] + log_P_F_tau_recon[i] @@ -299,8 +301,10 @@ def local_search_sample_from_model( log_acceptance_ratio = log_R_ratio + log_q_xprime_given_x - log_q_x_given_xprime if log_acceptance_ratio > 0 or rng.uniform() < torch.exp(log_acceptance_ratio): current_trajs[i] = fwd_trajs[i] + num_accepts += 1 elif cfg.accept_criteria == "always": current_trajs[i] = fwd_trajs[i] + num_accepts += 1 # Finally, we resample new "P_B-on-policy" trajectories from the terminal states # If we're only interested in the accepted trajectories, we use them as starting points instead @@ -316,7 +320,8 @@ def local_search_sample_from_model( else None ) returned_trajs = self.sample_backward_from_graphs(sampled_terminals, model, stacked_ci, random_action_prob) - return returned_trajs + initial_trajs + # TODO: modify the trajs' cond_info!!! + return initial_trajs + returned_trajs, num_accepts / (cfg.num_ls_steps * n) def _forward_step(self, model, data, graphs, cond_info, t, done, rng, dev, random_action_prob, bck_a=[]) -> None: def not_done(lst): diff --git a/src/gflownet/algo/local_search_tb.py b/src/gflownet/algo/local_search_tb.py index a1632987..405a7db0 100644 --- a/src/gflownet/algo/local_search_tb.py +++ b/src/gflownet/algo/local_search_tb.py @@ -20,14 +20,15 @@ def set_task(self, task): def create_training_data_from_own_samples(self, model, n, cond_info=None, random_action_prob=0.0): assert self.task is not None, "LocalSearchTB requires a task to be set" - assert n % (1 + self.global_cfg.algo.ls.num_ls_steps) == 0, "n must be divisible by 1 + num_ls_steps" if self.global_cfg.algo.ls.yield_only_accepted: - n_per_step = n + n_per_step = n // 2 + assert n % 2 == 0, "n must be divisible by 2" else: + assert n % (1 + self.global_cfg.algo.ls.num_ls_steps) == 0, "n must be divisible by 1 + num_ls_steps" n_per_step = n // (1 + self.global_cfg.algo.ls.num_ls_steps) cond_info = {k: v[:n_per_step] for k, v in cond_info.items()} if cond_info is not None else None random_action_prob = random_action_prob or 0.0 - data = self.graph_sampler.local_search_sample_from_model( + data, accept_rate = self.graph_sampler.local_search_sample_from_model( model, n_per_step, cond_info, @@ -35,6 +36,8 @@ def create_training_data_from_own_samples(self, model, n, cond_info=None, random self.global_cfg.algo.ls, self._compute_log_rewards, ) + for t in data: + t['accept_rate'] = accept_rate return data def _compute_log_rewards(self, trajs, cond_info): diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index d3221a7b..a6734959 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -12,7 +12,7 @@ from gflownet.envs.graph_building_env import GraphBuildingEnvContext, GraphActionCategorical, action_type_to_mask from gflownet.envs.seq_building_env import SeqBatch from gflownet.models.graph_transformer import GraphTransformerGFN -from gflownet.utils.misc import get_worker_rng +from gflownet.utils.misc import get_worker_rng, get_this_wid from gflownet.utils.multiprocessing_proxy import BufferPickler, SharedPinnedBuffer @@ -48,6 +48,7 @@ def __init__( self.global_step_count.share_memory_() self.global_step_count_lock = torch.multiprocessing.Lock() self.current_iter = start_at_step + self._err_tol = 10 self.setup_mp_buffers() def add_sampling_hook(self, hook: Callable): @@ -65,8 +66,6 @@ def __iter__(self): self.rng = get_worker_rng() its = [i() for i in self.iterators] self.algo.set_is_eval(self.is_algo_eval) - print("New iterator", its) - err_tol = 10 while True: try: with self.global_step_count_lock: @@ -86,23 +85,21 @@ def __iter__(self): for d in batch_infos: batch_info.update(d) yield self.create_batch(trajs, batch_info) - err_tol = 10 # Reset the error tolerance, if we run into 10 consecutive errors, we'll break + self._err_tol = 10 # Reset the error tolerance, if we run into 10 consecutive errors, we'll break except (Exception, RuntimeError) as e: - err_tol -= 1 - if err_tol == 0: + self._err_tol -= 1 + if self._err_tol == 0: raise e - print(f"Error in DataSource: {e} [tol={err_tol}]") + print(f"Error in DataSource: {e} [tol={self._err_tol}]") # print full traceback import traceback, sys traceback.print_exc() - traceback.print_exc(file=sys.stderr) continue except: print("Unknown error in DataSource") import traceback, sys traceback.print_exc() - traceback.print_exc(file=sys.stderr) - err_tol -= 1 + self._err_tol -= 1 continue def validate_batch(self, batch, trajs): @@ -288,7 +285,7 @@ def create_batch(self, trajs, batch_info): batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) batch.obj_props = torch.stack([t["obj_props"] for t in trajs]) - self.validate_batch(batch, trajs) + # self.validate_batch(batch, trajs) return self._maybe_put_in_mp_buffer(batch) def compute_properties(self, trajs, mark_as_online=False): @@ -368,16 +365,19 @@ def sample_idcs(self, n, num_samples): def iterate_indices(self, n, num_samples): worker_info = torch.utils.data.get_worker_info() + num_workers = worker_info.num_workers if worker_info is not None else 1 + if torch.distributed.is_initialized(): + num_workers *= torch.distributed.get_world_size() + wid = get_this_wid() if n == 0: # Should we be raising an error here? warning? yield np.arange(0, 0) return - if worker_info is None: # no multi-processing + if num_workers == 1: # no multi-processing, no distributed start, end, wid = 0, n, -1 else: # split the data into chunks (per-worker) - nw = worker_info.num_workers - wid = worker_info.id + nw = num_workers start, end = int(np.round(n / nw * wid)), int(np.round(n / nw * (wid + 1))) if end - start <= num_samples: diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index eb4ee9cf..8c42de12 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -80,6 +80,7 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu self.edges_are_duplicated = True self.edges_are_unordered = False self.fail_on_missing_attr = True + self.use_strict_edge_filling = False # Order in which models have to output logits self.action_type_order = [GraphActionType.Stop, GraphActionType.AddNode, GraphActionType.SetEdgeAttr] @@ -227,6 +228,8 @@ def graph_to_Data(self, g: Graph) -> gd.Data: # If there are unspecified attachment points, we will disallow the agent from using the stop # action. has_unfilled_attach = False + first_unfilled_attach = True + has_fully_unfilled_attach = False for i, e in enumerate(g.edges): ed = g.edges[e] if len(ed): @@ -234,22 +237,41 @@ def graph_to_Data(self, g: Graph) -> gd.Data: node_is_connected_to_edge_with_attr[e[1]] = 1 a = ed.get("src_attach", -1) b = ed.get("dst_attach", -1) + if a < 0 or b < 0: + has_unfilled_attach = True + if self.use_strict_edge_filling: + remove_edge_attr_mask *= 0 + if a < 0 and b < 0: + has_fully_unfilled_attach = True if a >= 0: attached[e[0]].append(a) - remove_edge_attr_mask[i, 0] = 1 - else: - has_unfilled_attach = True + if self.use_strict_edge_filling: + # Since the agent has to fill in the attachment points before adding a new node, or stopping, the + # reverse is that if there is a half-filled edge attachment, the only valid move is to unset the + # other half (i.e. whichever is the first edge we find like that) + # We also will only end up being able to do that to leaf nodes (degree 1) + if degrees[e[0]] == 1 or degrees[e[1]] == 1: + remove_edge_attr_mask[i, 0] = 1 if not has_unfilled_attach or has_unfilled_attach and first_unfilled_attach else 0 + else: + remove_edge_attr_mask[i, 0] = 1 if b >= 0: attached[e[1]].append(b) - remove_edge_attr_mask[i, 1] = 1 - else: - has_unfilled_attach = True + if self.use_strict_edge_filling: + if degrees[e[0]] == 1 or degrees[e[1]] == 1: + remove_edge_attr_mask[i, 1] = 1 if not has_unfilled_attach or has_unfilled_attach and first_unfilled_attach else 0 + else: + remove_edge_attr_mask[i, 1] = 1 + + if has_unfilled_attach: + first_unfilled_attach = False # The node must be connected to at most 1 other node and in the case where it is # connected to exactly one other node, the edge connecting them must not have any # attributes. if len(g): remove_node_mask = (node_is_connected * (1 - node_is_connected_to_edge_with_attr)).reshape((-1, 1)) + if self.use_strict_edge_filling: + remove_node_mask = remove_node_mask * (1 - (has_unfilled_attach and not has_fully_unfilled_attach)) # Here we encode the attached atoms in the edge features, as well as mask out attached # atoms. @@ -276,6 +298,8 @@ def graph_to_Data(self, g: Graph) -> gd.Data: else np.ones((1, 1), np.float32) ) add_node_mask = add_node_mask * np.ones((x.shape[0], self.num_new_node_values), np.float32) + if self.use_strict_edge_filling: + add_node_mask = add_node_mask * (1 - has_unfilled_attach) stop_mask = zeros((1, 1)) if has_unfilled_attach or not len(g) else ones((1, 1)) data = gd.Data( diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 78519c81..605ddb17 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -131,7 +131,7 @@ def setup(self): if self.print_config: print("\n\nHyperparameters:\n") print(yaml_cfg) - if self.cfg.log_dir is not None: + if self.cfg.log_dir is not None and self.rank == 0: with open(pathlib.Path(self.cfg.log_dir) / "config.yaml", "w", encoding="utf8") as f: f.write(yaml_cfg) diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index 8f2b30c9..520cde3e 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -7,6 +7,7 @@ @dataclass class SEHTaskConfig(StrictDataClass): reduced_frag: bool = False + qed_mul: bool = False @dataclass diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index a1eca458..777e1a0f 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -201,7 +201,7 @@ def build_validation_data_loader(self) -> DataLoader: # TODO: might be better to change total steps to total trajectories drawn src.do_sample_model_n_times(model, n_drawn, num_total=self.cfg.num_validation_gen_steps * n_drawn) - if self.cfg.log_dir: + if self.cfg.log_dir and n_drawn > 0: src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "valid"), self.ctx)) for hook in self.valid_sampling_hooks: src.add_sampling_hook(hook) @@ -251,7 +251,8 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: def evaluate_batch(self, batch: gd.Batch, epoch_idx: int = 0, batch_idx: int = 0) -> Dict[str, Any]: tick = time.time() self.model.eval() - loss, info = self.algo.compute_batch_losses(self.model, batch) + with torch.no_grad(): + loss, info = self.algo.compute_batch_losses(self.model, batch) if hasattr(batch, "extra_info"): info.update(batch.extra_info) info["eval_time"] = time.time() - tick @@ -272,15 +273,11 @@ def _send_models_to_device(self): self.sampling_model.to(self.device) if self.world_size > 1: self.model = DistributedDataParallel( - self.model.to(self.rank), - device_ids=[self.rank], - output_device=self.rank + self.model.to(self.rank), device_ids=[self.rank], output_device=self.rank ) if self.sampling_model is not self.model: self.sampling_model = DistributedDataParallel( - self.sampling_model.to(self.rank), - device_ids=[self.rank], - output_device=self.rank + self.sampling_model.to(self.rank), device_ids=[self.rank], output_device=self.rank ) def run(self, logger=None): @@ -292,7 +289,10 @@ def run(self, logger=None): self.model.to(self.device) self.sampling_model.to(self.device) import threading - self.model.lock = threading.Lock() # This is created here because you can't pickle a lock, and model is deepcopied -> sampling_model + + self.model.lock = ( + threading.Lock() + ) # This is created here because you can't pickle a lock, and model is deepcopied -> sampling_model epoch_length = max(len(self.training_data), 1) valid_freq = self.cfg.validate_every # If checkpoint_every is not specified, checkpoint at every validation epoch @@ -327,23 +327,15 @@ def run(self, logger=None): start_time = time.time() if it % self.print_every == 0: logger.info(f"iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) - - if self.world_size > 1: - all_info_vals = torch.zeros(len(info)).to(self.rank) - for i, k in enumerate(sorted(info.keys())): - all_info_vals[i] = info[k] - dist.all_reduce(all_info_vals, op=dist.ReduceOp.SUM) - for i, k in enumerate(sorted(info.keys())): - info[k] = all_info_vals[i].item() / self.world_size - if self.rank == 0: - self.log(info, it, 'train') + self.log(info, it, "train") if valid_freq > 0 and it % valid_freq == 0: + logger.info("Starting validation epoch") for batch in valid_dl: batch = self._maybe_resolve_shared_buffer(batch, valid_dl) info = self.evaluate_batch(batch.to(self.device), epoch_idx, batch_idx) - self.log(info, it, "valid") logger.info(f"validation - iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) + self.log(info, it, "valid") end_metrics = {} for c in callbacks.values(): if hasattr(c, "on_validation_end"): @@ -418,6 +410,17 @@ def load_state(self, path): self.model.load_state_dict(state["models_state_dict"][0]) def log(self, info, index, key): + # First check if we need to reduce the info across processes + if self.world_size > 1: + all_info_vals = torch.zeros(len(info)).to(self.rank) + for i, k in enumerate(sorted(info.keys())): + all_info_vals[i] = info[k] + dist.all_reduce(all_info_vals, op=dist.ReduceOp.SUM) + for i, k in enumerate(sorted(info.keys())): + info[k] = all_info_vals[i].item() / self.world_size + if self.rank != 0: # Only the master process logs + return + if not hasattr(self, "_summary_writer"): self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.cfg.log_dir) for k, v in info.items(): From 30bd2e32b9ed3fafc5ef7985b0221fd9593b370a Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Tue, 8 Oct 2024 09:01:01 -0600 Subject: [PATCH 29/31] tox --- setup.py | 1 + src/C/degree_view.c | 2 +- src/C/edge_view.c | 2 +- src/C/graph_def.c | 2 +- src/C/main.c | 10 ++-- src/C/main.h | 2 +- src/C/mol_graph_to_Data.c | 8 +-- src/C/node_view.c | 2 +- src/gflownet/__init__.py | 1 + src/gflownet/algo/config.py | 1 + src/gflownet/algo/graph_sampling.py | 4 +- src/gflownet/algo/local_search_tb.py | 4 +- src/gflownet/algo/trajectory_balance.py | 30 ++++++----- src/gflownet/data/data_source.py | 46 ++++++++++------- src/gflownet/envs/frag_mol_env.py | 12 +++-- src/gflownet/envs/graph_building_env.py | 42 ++++++++------- src/gflownet/envs/mol_building_env.py | 10 ++-- src/gflownet/models/config.py | 2 +- src/gflownet/models/graph_transformer.py | 57 ++++++++++++--------- src/gflownet/online_trainer.py | 2 +- src/gflownet/utils/misc.py | 2 + src/gflownet/utils/multiprocessing_proxy.py | 7 +-- src/gflownet/utils/sqlite_log.py | 1 + 23 files changed, 148 insertions(+), 102 deletions(-) diff --git a/setup.py b/setup.py index e5cccec7..1418c57c 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ def _get_next_version(): latest_patch = -1 return f"{major}.{minor}.{latest_patch+1}" + ext = [ Extension( name="gflownet._C", diff --git a/src/C/degree_view.c b/src/C/degree_view.c index 7e752946..dbc65bec 100644 --- a/src/C/degree_view.c +++ b/src/C/degree_view.c @@ -67,4 +67,4 @@ PyTypeObject DegreeViewType = { .tp_members = DegreeView_members, .tp_methods = DegreeView_methods, .tp_as_mapping = &DegreeView_mapmeth, -}; \ No newline at end of file +}; diff --git a/src/C/edge_view.c b/src/C/edge_view.c index 6cc40f67..584426e4 100644 --- a/src/C/edge_view.c +++ b/src/C/edge_view.c @@ -271,4 +271,4 @@ PyTypeObject EdgeViewType = { .tp_iter = EdgeView_iter, .tp_iternext = EdgeView_iternext, .tp_richcompare = EdgeView_richcompare, -}; \ No newline at end of file +}; diff --git a/src/C/graph_def.c b/src/C/graph_def.c index bea85c35..9de334ae 100644 --- a/src/C/graph_def.c +++ b/src/C/graph_def.c @@ -149,4 +149,4 @@ PyTypeObject GraphDefType = { .tp_dealloc = (destructor)GraphDef_dealloc, .tp_members = GraphDef_members, .tp_methods = GraphDef_methods, -}; \ No newline at end of file +}; diff --git a/src/C/main.c b/src/C/main.c index d5bdcaa4..0cec6a47 100644 --- a/src/C/main.c +++ b/src/C/main.c @@ -632,9 +632,9 @@ PyObject *Graph_getstate(PyObject *_self, PyObject *args) { memcpy(ptr + 4, self->nodes, self->num_nodes * sizeof(int)); memcpy(ptr + 4 + self->num_nodes, self->edges, self->num_edges * 2 * sizeof(int)); memcpy(ptr + 4 + self->num_nodes + self->num_edges * 2, self->node_attrs, self->num_node_attrs * 3 * sizeof(int)); - memcpy(ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3, + memcpy(ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3, self->edge_attrs, self->num_edge_attrs * 3 * sizeof(int)); - memcpy(ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3 + self->num_edge_attrs * 3, + memcpy(ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3 + self->num_edge_attrs * 3, self->degrees, self->num_nodes * sizeof(int)); // Return (bytes, GraphDef) PyObject *res = PyTuple_Pack(2, state, self->graph_def); @@ -670,7 +670,7 @@ PyObject *Graph_setstate(PyObject *_self, PyObject *args) { memcpy(self->nodes, ptr + 4, self->num_nodes * sizeof(int)); memcpy(self->edges, ptr + 4 + self->num_nodes, self->num_edges * 2 * sizeof(int)); memcpy(self->node_attrs, ptr + 4 + self->num_nodes + self->num_edges * 2, self->num_node_attrs * 3 * sizeof(int)); - memcpy(self->edge_attrs, ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3, + memcpy(self->edge_attrs, ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3, self->num_edge_attrs * 3 * sizeof(int)); memcpy(self->degrees, ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3 + self->num_edge_attrs * 3, self->num_nodes * sizeof(int)); @@ -699,7 +699,7 @@ def search(graph, subgraph, assignments, depth): if u < depth and v < depth: x, y = assignments[u], assignments[v] if not ( - graph.has_edge(x, y) and + graph.has_edge(x, y) and graph.nodes[x] == subgraph.nodes[u] and graph.nodes[y] == subgraph.nodes[v] and graph.edges[x, y] == subgraph.edges[u, v]): return False @@ -1140,4 +1140,4 @@ PyMODINIT_FUNC PyInit__C(void) { } return m; -} \ No newline at end of file +} diff --git a/src/C/main.h b/src/C/main.h index 25162eff..5e45d737 100644 --- a/src/C/main.h +++ b/src/C/main.h @@ -112,4 +112,4 @@ PyObject *Data_collate(PyObject *self, PyObject *args); void Data_init_C(Data *self, PyObject *bytes, PyObject *graph_def, int shapes[][2], int *is_float, int num_matrices, const char **names); -PyObject *mol_graph_to_Data(PyObject *self, PyObject *args); \ No newline at end of file +PyObject *mol_graph_to_Data(PyObject *self, PyObject *args); diff --git a/src/C/mol_graph_to_Data.c b/src/C/mol_graph_to_Data.c index 81e673f8..574e8aa3 100644 --- a/src/C/mol_graph_to_Data.c +++ b/src/C/mol_graph_to_Data.c @@ -133,10 +133,10 @@ PyObject *mol_graph_to_Data(PyObject *self, PyObject *args) { max_valence[i] -= 1; } } - + /* TODO: Figure out this charge thing... wait or is the problem that I'm _removing_ valence for a + charged mol?? - else if (charge_val[i] == 1) { + else if (charge_val[i] == 1) { // well... maybe this is true of the general case? // if an atom is positively charged then we can just bond it to more stuff? // why 2 for N? @@ -145,8 +145,8 @@ PyObject *mol_graph_to_Data(PyObject *self, PyObject *args) { // Determine hypervalent atoms // Hypervalent atoms are atoms that have more than the maximum valence for their atom type because of a positive // charge, but don't have this extra charge filled with bonds, or don't have no_impl set to 1 - /*printf("%d used_valences: %f, max_valence: %d, charge_val: %d, noImpl_val: %d, atom_valences: %d -> %d\n", i, - used_valences[i], max_valence[i], charge_val[i], noImpl_val[i], atom_valences[v_val[i]], + /*printf("%d used_valences: %f, max_valence: %d, charge_val: %d, noImpl_val: %d, atom_valences: %d -> %d\n", i, + used_valences[i], max_valence[i], charge_val[i], noImpl_val[i], atom_valences[v_val[i]], (used_valences[i] >= atom_valences[v_val[i]] && charge_val[i] == 1 && !(used_valences[i] == max_valence[i] || noImpl_val[i] == 1))); fflush(stdout);*/ diff --git a/src/C/node_view.c b/src/C/node_view.c index 208b4bea..d105ae2f 100644 --- a/src/C/node_view.c +++ b/src/C/node_view.c @@ -241,4 +241,4 @@ PyTypeObject NodeViewType = { .tp_iter = NodeView_iter, .tp_iternext = NodeView_iternext, .tp_richcompare = NodeView_richcompare, -}; \ No newline at end of file +}; diff --git a/src/gflownet/__init__.py b/src/gflownet/__init__.py index 6c790b4b..d1b8c3e6 100644 --- a/src/gflownet/__init__.py +++ b/src/gflownet/__init__.py @@ -81,6 +81,7 @@ def get_random_action_prob(self, it: int): def set_task(self, task): raise NotImplementedError() + class GFNTask: def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], obj_props: ObjectProperties) -> LogScalar: """Combines a minibatch of reward signal vectors and conditional information into a scalar reward. diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index 4ae715ae..76095f2a 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -148,6 +148,7 @@ class SQLConfig(StrictDataClass): gamma: float = 1 penalty: float = -10 + @dataclass class LSTBConfig(StrictDataClass): num_ls_steps: int = 1 diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 72ebb56b..62b592f0 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -6,6 +6,7 @@ import torch.nn as nn from torch import Tensor +from gflownet.algo.config import LSTBConfig from gflownet.envs.graph_building_env import ( Graph, GraphAction, @@ -13,7 +14,6 @@ GraphActionType, action_type_to_mask, ) -from gflownet.algo.config import LSTBConfig from gflownet.models.graph_transformer import GraphTransformerGFN from gflownet.utils.misc import get_worker_device, get_worker_rng @@ -309,7 +309,7 @@ def local_search_sample_from_model( # Finally, we resample new "P_B-on-policy" trajectories from the terminal states # If we're only interested in the accepted trajectories, we use them as starting points instead if cfg.yield_only_accepted: - sampled_terminals = [i['traj'][-1][0] for i in current_trajs] + sampled_terminals = [i["traj"][-1][0] for i in current_trajs] stacked_ci = cond_info if not cfg.yield_only_accepted: diff --git a/src/gflownet/algo/local_search_tb.py b/src/gflownet/algo/local_search_tb.py index 405a7db0..4c8f38c9 100644 --- a/src/gflownet/algo/local_search_tb.py +++ b/src/gflownet/algo/local_search_tb.py @@ -37,7 +37,7 @@ def create_training_data_from_own_samples(self, model, n, cond_info=None, random self._compute_log_rewards, ) for t in data: - t['accept_rate'] = accept_rate + t["accept_rate"] = accept_rate return data def _compute_log_rewards(self, trajs, cond_info): @@ -51,4 +51,4 @@ def _compute_log_rewards(self, trajs, cond_info): DataSource.compute_properties(self, trajs) # type: ignore DataSource.compute_log_rewards(self, trajs) # type: ignore self.cfg = cfg.algo.tb - # trajs is modified in place, so no need to return anything \ No newline at end of file + # trajs is modified in place, so no need to return anything diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 899532ad..8e9d3679 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -211,9 +211,7 @@ def create_training_data_from_graphs( graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, random_action_prob ) elif self.cfg.do_sample_using_masks: - return self.graph_sampler.sample_backward_from_graphs( - graphs, None, cond_info, random_action_prob - ) + return self.graph_sampler.sample_backward_from_graphs(graphs, None, cond_info, random_action_prob) trajs: List[Dict[str, Any]] = [{"traj": generate_forward_trajectory(i)} for i in graphs] for traj in trajs: n_back = [ @@ -498,7 +496,7 @@ def compute_batch_losses( if self.cfg.do_parameterize_p_b: # Life is pain, log_p_B is one unit too short for all trajs log_p_B_unif = batch.U_log_p_B - assert log_p_B_unif.shape[0] == log_p_B.shape[0] + assert log_p_B_unif.shape[0] == log_p_B.shape[0] # log_p_B_unif = torch.zeros_like(log_p_B) # for i, (s, e) in enumerate(zip(first_graph_idx, traj_cumlen)): @@ -575,17 +573,19 @@ def compute_batch_losses( reward_loss = reward_losses.mean() else: reward_loss = 0 - + log_Z_reg_loss = (log_Z - self.cfg.regularize_logZ).pow(2).mean() if self.cfg.regularize_logZ is not None else 0 n_loss = n_loss.mean() tb_loss = traj_losses.mean() mle_loss = -traj_log_p_F.mean() - loss = (tb_loss * self.cfg.tb_loss_multiplier - + reward_loss * self.cfg.reward_loss_multiplier - + n_loss * self.cfg.n_loss_multiplier - + log_Z_reg_loss - + mle_loss * self.cfg.mle_loss_multiplier) + loss = ( + tb_loss * self.cfg.tb_loss_multiplier + + reward_loss * self.cfg.reward_loss_multiplier + + n_loss * self.cfg.n_loss_multiplier + + log_Z_reg_loss + + mle_loss * self.cfg.mle_loss_multiplier + ) info = { "reward_loss": reward_loss, "invalid_trajectories": invalid_mask.sum() / batch.num_online if batch.num_online > 0 else 0, @@ -598,12 +598,14 @@ def compute_batch_losses( "tb_loss": tb_loss.item(), "batch_entropy": fwd_cat.entropy().mean(), "traj_lens": batch.traj_lens.float().mean(), - 'avg_log_reward': clip_log_R.mean(), + "avg_log_reward": clip_log_R.mean(), } sources = set(batch.sources) if len(sources) > 1: for source in sources: - info[f'{source}_loss'] = traj_losses[torch.as_tensor([i == source for i in batch.sources])].mean().item() + info[f"{source}_loss"] = ( + traj_losses[torch.as_tensor([i == source for i in batch.sources])].mean().item() + ) if self.ctx.has_n() and self.cfg.do_predict_n: info["n_loss_pred"] = scatter( (log_n_preds - batch.log_ns) ** 2, batch_idx, dim=0, dim_size=num_trajs, reduce="sum" @@ -619,7 +621,9 @@ def compute_batch_losses( info["mle_loss"] = mle_loss.item() if not torch.isfinite(loss): - import pdb; pdb.set_trace() + import pdb + + pdb.set_trace() return loss, info def analytical_maxent_backward(self, batch, first_graph_idx): diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index a6734959..bad06268 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -9,10 +9,10 @@ from gflownet import GFNAlgorithm, GFNTask from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer, detach_and_cpu -from gflownet.envs.graph_building_env import GraphBuildingEnvContext, GraphActionCategorical, action_type_to_mask +from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnvContext, action_type_to_mask from gflownet.envs.seq_building_env import SeqBatch from gflownet.models.graph_transformer import GraphTransformerGFN -from gflownet.utils.misc import get_worker_rng, get_this_wid +from gflownet.utils.misc import get_this_wid, get_worker_rng from gflownet.utils.multiprocessing_proxy import BufferPickler, SharedPinnedBuffer @@ -85,19 +85,23 @@ def __iter__(self): for d in batch_infos: batch_info.update(d) yield self.create_batch(trajs, batch_info) - self._err_tol = 10 # Reset the error tolerance, if we run into 10 consecutive errors, we'll break + self._err_tol = 10 # Reset the error tolerance, if we run into 10 consecutive errors, we'll break except (Exception, RuntimeError) as e: self._err_tol -= 1 if self._err_tol == 0: raise e print(f"Error in DataSource: {e} [tol={self._err_tol}]") # print full traceback - import traceback, sys + import sys + import traceback + traceback.print_exc() continue except: print("Unknown error in DataSource") - import traceback, sys + import sys + import traceback + traceback.print_exc() self._err_tol -= 1 continue @@ -140,7 +144,7 @@ def iterator(): p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(num_samples, t) trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info, p) - self.mark_all(trajs, source='sample') + self.mark_all(trajs, source="sample") self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) @@ -170,7 +174,7 @@ def iterator(): cond_info = self.task.sample_conditional_information(n_this_time, t) # TODO: in the cond info refactor, pass the whole thing instead of just the encoding trajs = self.algo.create_training_data_from_own_samples(model, n_this_time, cond_info, p) - self.mark_all(trajs, source='sample') + self.mark_all(trajs, source="sample") self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) @@ -186,7 +190,7 @@ def do_sample_replay(self, num_samples): def iterator(): while self.active: trajs, *_ = self.replay_buffer.sample(num_samples) - self.mark_all(trajs, source='replay') + self.mark_all(trajs, source="replay") self.relabel_in_hindsight(trajs) # This is a no-op if the hindsight ratio is 0 yield trajs, {} @@ -201,7 +205,7 @@ def iterator(): cond_info = self.task.sample_conditional_information(num_samples, t) objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info, p) - self.mark_all(trajs, source='dataset') + self.mark_all(trajs, source="dataset") self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.set_traj_props(trajs, props) self.compute_log_rewards(trajs) @@ -220,7 +224,7 @@ def iterator(): # it and the state of the program (e.g. validation mode) cond_info = self.task.encode_conditional_information(torch.stack([data[i] for i in idcs])) trajs = self.algo.create_training_data_from_own_samples(model, len(idcs), cond_info, p) - self.mark_all(trajs, source='dataset') + self.mark_all(trajs, source="dataset") self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) @@ -243,7 +247,7 @@ def iterator(): cond_info = self.task.sample_conditional_information(num_samples, t) objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info, p) - self.mark_all(trajs, source='dataset') + self.mark_all(trajs, source="dataset") self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.set_traj_props(trajs, props) self.compute_log_rewards(trajs) @@ -290,7 +294,7 @@ def create_batch(self, trajs, batch_info): def compute_properties(self, trajs, mark_as_online=False): """Sets trajs' obj_props and is_valid keys by querying the task.""" - if all('obj_props' in t for t in trajs): + if all("obj_props" in t for t in trajs): return # TODO: refactor obj_props into properties valid_idcs = torch.tensor([i for i in range(len(trajs)) if trajs[i].get("is_valid", True)]).long() @@ -314,7 +318,7 @@ def compute_properties(self, trajs, mark_as_online=False): def compute_log_rewards(self, trajs): """Sets trajs' log_reward key by querying the task.""" - if all('log_reward' in t for t in trajs): + if all("log_reward" in t for t in trajs): return obj_props = torch.stack([t["obj_props"] for t in trajs]) cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} @@ -326,19 +330,25 @@ def compute_log_rewards(self, trajs): def send_to_replay(self, trajs): if self.replay_buffer is not None: for t in trajs: - self.replay_buffer.push(t, t["log_reward"], t["obj_props"], t["cond_info"], t["is_valid"], - unique_obj=self.ctx.get_unique_obj(t["result"]), - priority=t.get("priority", t['log_reward'].item())) + self.replay_buffer.push( + t, + t["log_reward"], + t["obj_props"], + t["cond_info"], + t["is_valid"], + unique_obj=self.ctx.get_unique_obj(t["result"]), + priority=t.get("priority", t["log_reward"].item()), + ) def set_traj_cond_info(self, trajs, cond_info): for i in range(len(trajs)): - if 'cond_info' in trajs[i]: + if "cond_info" in trajs[i]: continue trajs[i]["cond_info"] = {k: cond_info[k][i] for k in cond_info} def set_traj_props(self, trajs, props): for i in range(len(trajs)): - if 'obj_props' in trajs[i]: + if "obj_props" in trajs[i]: continue trajs[i]["obj_props"] = props[i] # TODO: refactor diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index 8c42de12..238f08c6 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -246,19 +246,23 @@ def graph_to_Data(self, g: Graph) -> gd.Data: if a >= 0: attached[e[0]].append(a) if self.use_strict_edge_filling: - # Since the agent has to fill in the attachment points before adding a new node, or stopping, the - # reverse is that if there is a half-filled edge attachment, the only valid move is to unset the + # Since the agent has to fill in the attachment points before adding a new node, or stopping, the + # reverse is that if there is a half-filled edge attachment, the only valid move is to unset the # other half (i.e. whichever is the first edge we find like that) # We also will only end up being able to do that to leaf nodes (degree 1) if degrees[e[0]] == 1 or degrees[e[1]] == 1: - remove_edge_attr_mask[i, 0] = 1 if not has_unfilled_attach or has_unfilled_attach and first_unfilled_attach else 0 + remove_edge_attr_mask[i, 0] = ( + 1 if not has_unfilled_attach or has_unfilled_attach and first_unfilled_attach else 0 + ) else: remove_edge_attr_mask[i, 0] = 1 if b >= 0: attached[e[1]].append(b) if self.use_strict_edge_filling: if degrees[e[0]] == 1 or degrees[e[1]] == 1: - remove_edge_attr_mask[i, 1] = 1 if not has_unfilled_attach or has_unfilled_attach and first_unfilled_attach else 0 + remove_edge_attr_mask[i, 1] = ( + 1 if not has_unfilled_attach or has_unfilled_attach and first_unfilled_attach else 0 + ) else: remove_edge_attr_mask[i, 1] = 1 diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index b3dbc034..fcd89cd9 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -183,12 +183,12 @@ def step(self, g: Graph, action: GraphAction) -> Graph: gp = g.copy() if action.action is GraphActionType.AddEdge: a, b = action.source, action.target - #assert self.allow_add_edge - #assert a in g and b in g + # assert self.allow_add_edge + # assert a in g and b in g if a > b: a, b = b, a - #assert a != b - #assert not g.has_edge(a, b) + # assert a != b + # assert not g.has_edge(a, b) # Ideally the FA underlying this must only be able to send # create_edge actions which respect this a Graph: elif action.action is GraphActionType.AddNode: if len(g) == 0: - #assert action.source == 0 # TODO: this may not be useful + # assert action.source == 0 # TODO: this may not be useful gp.add_node(0, v=action.value) else: - #assert action.source in g.nodes + # assert action.source in g.nodes e = [action.source, max(g.nodes) + 1] # if kw and 'relabel' in kw: # e[1] = kw['relabel'] # for `parent` consistency, allow relabeling - #assert not g.has_edge(*e) + # assert not g.has_edge(*e) gp.add_node(e[1], v=action.value) gp.add_edge(*e) elif action.action is GraphActionType.SetNodeAttr: - #assert self.allow_node_attr - #assert action.source in gp.nodes + # assert self.allow_node_attr + # assert action.source in gp.nodes # For some "optional" attributes like wildcard atoms, we indicate that they haven't been # chosen by the 'None' value. Here we make sure that either the attribute doesn't # exist, or that it's an optional attribute that hasn't yet been set. - #assert action.attr not in gp.nodes[action.source] or gp.nodes[action.source][action.attr] is None + # assert action.attr not in gp.nodes[action.source] or gp.nodes[action.source][action.attr] is None gp.nodes[action.source][action.attr] = action.value elif action.action is GraphActionType.SetEdgeAttr: - #assert self.allow_edge_attr - #assert g.has_edge(action.source, action.target) - #assert action.attr not in gp.edges[(action.source, action.target)] + # assert self.allow_edge_attr + # assert g.has_edge(action.source, action.target) + # assert action.attr not in gp.edges[(action.source, action.target)] gp.edges[(action.source, action.target)][action.attr] = action.value elif action.action is GraphActionType.RemoveNode: - #assert g.has_node(action.source) + # assert g.has_node(action.source) gp = graph_without_node(gp, action.source) elif action.action is GraphActionType.RemoveNodeAttr: - #assert g.has_node(action.source) + # assert g.has_node(action.source) gp = graph_without_node_attr(gp, action.source, action.attr) elif action.action is GraphActionType.RemoveEdge: - #assert g.has_edge(action.source, action.target) + # assert g.has_edge(action.source, action.target) gp = graph_without_edge(gp, (action.source, action.target)) elif action.action is GraphActionType.RemoveEdgeAttr: - #assert g.has_edge(action.source, action.target) + # assert g.has_edge(action.source, action.target) gp = graph_without_edge_attr(gp, (action.source, action.target), action.attr) else: raise ValueError(f"Unknown action type {action.action}", action.action) @@ -825,7 +825,13 @@ def argmax( # if it wants to convert these indices to env-compatible actions return argmaxes - def log_prob(self, actions: List[ActionIndex], logprobs: torch.Tensor = None, batch: torch.Tensor = None, pad_value: float = 0.0): + def log_prob( + self, + actions: List[ActionIndex], + logprobs: torch.Tensor = None, + batch: torch.Tensor = None, + pad_value: float = 0.0, + ): """The log-probability of a list of action tuples, effectively indexes `logprobs` using internal slice indices. diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index 8b54e2c3..10134a11 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -14,14 +14,18 @@ DEFAULT_CHIRAL_TYPES = [ChiralType.CHI_UNSPECIFIED, ChiralType.CHI_TETRAHEDRAL_CW, ChiralType.CHI_TETRAHEDRAL_CCW] try: - from gflownet._C import mol_graph_to_Data, Graph as C_Graph, GraphDef, Data_collate + from gflownet._C import Data_collate + from gflownet._C import Graph as C_Graph + from gflownet._C import GraphDef, mol_graph_to_Data C_Graph_available = True except ImportError: import warnings + warnings.warn("Could not import mol_graph_to_Data, Graph, GraphDef from _C, using pure python implementation") C_Graph_available = False + class MolBuildingEnvContext(GraphBuildingEnvContext): """A specification of what is being generated for a GraphBuildingEnv @@ -168,7 +172,7 @@ def __init__( if C_Graph_available: self.graph_def = GraphDef(self.atom_attr_values, self.bond_attr_values) self.graph_cls = self._make_C_graph - assert charges == [0, 1, -1], 'C impl quirk: charges must be [0, 1, -1]' + assert charges == [0, 1, -1], "C impl quirk: charges must be [0, 1, -1]" else: self.graph_cls = Graph @@ -513,4 +517,4 @@ def get_unique_obj(self, g: Graph): assert mol is not None return Chem.CanonSmiles(Chem.MolToSmiles(mol)) except Exception: - return "" \ No newline at end of file + return "" diff --git a/src/gflownet/models/config.py b/src/gflownet/models/config.py index c9f52c01..bf95c40d 100644 --- a/src/gflownet/models/config.py +++ b/src/gflownet/models/config.py @@ -10,7 +10,7 @@ class GraphTransformerConfig(StrictDataClass): ln_type: str = "pre" num_mlp_layers: int = 1 concat_heads: bool = True - conv_type: str = 'Transformer' + conv_type: str = "Transformer" class SeqPosEnc(int, Enum): diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 9f518504..e5a35996 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -39,14 +39,14 @@ def forward(self, x, edge_index, edge_attr, batch, c): if self.ln_type == "post": agg = self.gen(x, edge_index, edge_attr) l_h = self.linear(self.conv(torch.cat([x, agg], 1), edge_index, edge_attr)) - scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] + scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] x = self.norm1(x + l_h * scale + shift, batch) x = self.norm2(x + self.ff(x), batch) else: x_norm = self.norm1(x, batch) agg = self.gen(x_norm, edge_index, edge_attr) l_h = self.linear(self.conv(torch.cat([x, agg], 1), edge_index, edge_attr)) - scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] + scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] x = x + l_h * scale + shift x = x + self.ff(self.norm2(x, batch)) return x @@ -55,17 +55,19 @@ def forward(self, x, edge_index, edge_attr, batch, c): class GPSLayer(nn.Module): def __init__(self, num_emb, num_heads, num_mlp_layers, residual=False): super().__init__() - self.conv = gnn.GPSConv(num_emb, - gnn.GINEConv(mlp(num_emb, num_emb, num_emb, num_mlp_layers), edge_dim=num_emb), - num_heads, - norm='layer_norm') + self.conv = gnn.GPSConv( + num_emb, + gnn.GINEConv(mlp(num_emb, num_emb, num_emb, num_mlp_layers), edge_dim=num_emb), + num_heads, + norm="layer_norm", + ) self.cscale = nn.Linear(num_emb, num_emb * 2) self.residual = residual - + def forward(self, x, edge_index, edge_attr, batch, c): cs = self.cscale(c) l_h = self.conv(x, edge_index, batch, edge_attr=edge_attr) - scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] + scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] if self.residual: x = x + l_h * scale + shift else: @@ -73,7 +75,6 @@ def forward(self, x, edge_index, edge_attr, batch, c): return x - class GraphTransformer(nn.Module): """An agnostic GraphTransformer class, and the main model used by other model classes @@ -88,8 +89,18 @@ class GraphTransformer(nn.Module): """ def __init__( - self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, num_noise=0, ln_type="pre", concat=True, - num_mlp_layers=1, conv_type='Transformer', + self, + x_dim, + e_dim, + g_dim, + num_emb=64, + num_layers=3, + num_heads=2, + num_noise=0, + ln_type="pre", + concat=True, + num_mlp_layers=1, + conv_type="Transformer", ): """ Parameters @@ -126,14 +137,12 @@ def __init__( self.x2h = mlp(x_dim + num_noise, num_emb, num_emb, 2) self.e2h = mlp(e_dim, num_emb, num_emb, 2) self.c2h = mlp(max(1, g_dim), num_emb, num_emb, 2) - if conv_type == 'Transformer': - self.gnn = nn.ModuleList([ - GTGENLayer(num_emb, num_heads, concat, num_mlp_layers, ln_type) for _ in range(num_layers) - ]) - elif conv_type == 'GPS': - self.gnn = nn.ModuleList([ - GPSLayer(num_emb, num_heads, num_mlp_layers) for _ in range(num_layers) - ]) + if conv_type == "Transformer": + self.gnn = nn.ModuleList( + [GTGENLayer(num_emb, num_heads, concat, num_mlp_layers, ln_type) for _ in range(num_layers)] + ) + elif conv_type == "GPS": + self.gnn = nn.ModuleList([GPSLayer(num_emb, num_heads, num_mlp_layers) for _ in range(num_layers)]) def forward(self, g: gd.Batch): """Forward pass @@ -159,7 +168,7 @@ def forward(self, g: gd.Batch): e = self.e2h(g.edge_attr) c = self.c2h(g.cond_info if g.cond_info is not None else torch.ones((g.num_graphs, 1), device=g.x.device)) num_total_nodes = g.x.shape[0] - if self.conv_type == 'Transformer': + if self.conv_type == "Transformer": # Augment the edges with a new edge to the conditioning # information node. This new node is connected to every node # within its graph. @@ -181,7 +190,7 @@ def forward(self, g: gd.Batch): for i in range(self.num_layers): o = self.gnn[i](o, aug_edge_index, aug_e, aug_batch, c[aug_batch]) - if self.conv_type == 'Transformer': + if self.conv_type == "Transformer": # Remove the conditioning information node embedding o_final = o[: -c.shape[0]] glob = torch.cat([gnn.global_mean_pool(o_final, g.batch), o[-c.shape[0] :]], 1) @@ -246,7 +255,7 @@ def __init__( self.env_ctx = env_ctx num_emb = cfg.model.num_emb num_final = num_emb - num_glob_final = num_emb * 2 if cfg.model.graph_transformer.conv_type == 'Transformer' else num_emb + num_glob_final = num_emb * 2 if cfg.model.graph_transformer.conv_type == "Transformer" else num_emb num_edge_feat = num_emb if env_ctx.edges_are_unordered else num_emb * 2 self.edges_are_duplicated = env_ctx.edges_are_duplicated self.edges_are_unordered = env_ctx.edges_are_unordered @@ -298,7 +307,9 @@ def _make_cat(self, g: gd.Batch, emb: Dict[str, Tensor], action_types: list[Grap action_masks=[action_type_to_mask(t, g) for t in action_types], types=action_types, ) - sc = self.logit_scaler(g.cond_info if g.cond_info is not None else torch.ones((g.num_graphs, 1), device=g.x.device)) + sc = self.logit_scaler( + g.cond_info if g.cond_info is not None else torch.ones((g.num_graphs, 1), device=g.x.device) + ) cat.logits = [l * sc[b] for l, b in zip(cat.raw_logits, cat.batch)] # Setting .logits masks them return cat diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 605ddb17..52636d9a 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -39,7 +39,7 @@ def setup_algo(self): algo = self.cfg.algo.method if algo == "TB": algo = TrajectoryBalance - elif algo == 'LSTB': + elif algo == "LSTB": algo = LocalSearchTB elif algo == "FM": algo = FlowMatching diff --git a/src/gflownet/utils/misc.py b/src/gflownet/utils/misc.py index 37b7d022..8108c090 100644 --- a/src/gflownet/utils/misc.py +++ b/src/gflownet/utils/misc.py @@ -34,6 +34,7 @@ def create_logger(name="gflownet", loglevel=logging.INFO, logfile=None, streamHa _worker_rng_seed = [142857] _main_process_device = [torch.device("cpu")] + def get_this_wid(): worker_info = torch.utils.data.get_worker_info() wid = worker_info.id if worker_info is not None else 0 @@ -41,6 +42,7 @@ def get_this_wid(): wid = torch.distributed.get_rank() * (worker_info.num_workers if worker_info is not None else 1) + wid return wid + def get_worker_rng(): wid = get_this_wid() if wid not in _worker_rngs: diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index 86081038..d011b64a 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -9,7 +9,8 @@ import torch.multiprocessing as mp DO_PIN_BUFFERS = False # So actually, we don't really need to pin buffers, but it's a good idea in some cases. - # The shared memory is already a good step. +# The shared memory is already a good step. + class SharedPinnedBuffer: def __init__(self, size): @@ -267,7 +268,7 @@ def run(self): break timeouts = 0 attr, args, kwargs = r - if hasattr(self.obj, 'lock'): + if hasattr(self.obj, "lock"): f.lock.acquire() f = getattr(self.obj, attr) args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args] @@ -291,7 +292,7 @@ def run(self): else: msg = self.to_cpu(result) self.out_queues[qi].put(self.encode(msg)) - if hasattr(self.obj, 'lock'): + if hasattr(self.obj, "lock"): f.lock.release() def terminate(self): diff --git a/src/gflownet/utils/sqlite_log.py b/src/gflownet/utils/sqlite_log.py index 64ebdf5e..7bf99cb3 100644 --- a/src/gflownet/utils/sqlite_log.py +++ b/src/gflownet/utils/sqlite_log.py @@ -8,6 +8,7 @@ from gflownet.utils.misc import get_this_wid + class SQLiteLogHook: def __init__(self, log_dir, ctx) -> None: self.log = None # Only initialized in __call__, which will occur inside the worker From ccefd863f38c474f045eb1c70ea37e5728cc6aeb Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Tue, 8 Oct 2024 09:21:32 -0600 Subject: [PATCH 30/31] ruff & mypy --- src/gflownet/algo/graph_sampling.py | 34 ++++++++++++--------- src/gflownet/algo/local_search_tb.py | 3 -- src/gflownet/algo/trajectory_balance.py | 4 --- src/gflownet/data/data_source.py | 11 +------ src/gflownet/data/replay_buffer.py | 10 +++--- src/gflownet/models/graph_transformer.py | 2 +- src/gflownet/online_trainer.py | 4 +-- src/gflownet/trainer.py | 4 +-- src/gflownet/utils/multiprocessing_proxy.py | 6 ++-- 9 files changed, 33 insertions(+), 45 deletions(-) diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 62b592f0..60b412a2 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -136,16 +136,17 @@ def sample_from_model(self, model: nn.Module, n: int, cond_info: Optional[Tensor # If we're not bootstrapping, we could query the reward # model here, but this is expensive/impractical. Instead # just report forward and backward logprobs + # TODO: stop using dicts and used typed objects data[i]["fwd_logprobs"] = torch.stack(data[i]["fwd_logprobs"]).reshape(-1) data[i]["U_bck_logprobs"] = torch.stack(data[i]["U_bck_logprobs"]).reshape(-1) - data[i]["fwd_logprob"] = data[i]["fwd_logprobs"].sum() - data[i]["U_bck_logprob"] = data[i]["U_bck_logprobs"].sum() + data[i]["fwd_logprob"] = data[i]["fwd_logprobs"].sum() # type: ignore + data[i]["U_bck_logprob"] = data[i]["U_bck_logprobs"].sum() # type: ignore data[i]["result"] = graphs[i] if self.pad_with_terminal_state: - data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Pad))) + data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Pad))) # type: ignore data[i]["U_bck_logprobs"] = torch.cat([data[i]["U_bck_logprobs"], torch.tensor([0.0], device=dev)]) - data[i]["is_sink"].append(1) - assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"]) + data[i]["is_sink"].append(1) # type: ignore + assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"]) # type: ignore return data def sample_backward_from_graphs( @@ -198,16 +199,19 @@ def sample_backward_from_graphs( for i in range(n): # See comments in sample_from_model - data[i]["traj"] = data[i]["traj"][::-1] + # TODO: stop using dicts and used typed objects + data[i]["traj"] = data[i]["traj"][::-1] # type: ignore # I think this pad is only necessary if we're padding terminal states??? - data[i]["bck_a"] = [GraphAction(GraphActionType.Pad)] + data[i]["bck_a"][::-1] - data[i]["is_sink"] = data[i]["is_sink"][::-1] - data[i]["U_bck_logprobs"] = torch.tensor([0] + data[i]["U_bck_logprobs"][::-1], device=dev).reshape(-1) + data[i]["bck_a"] = [GraphAction(GraphActionType.Pad)] + data[i]["bck_a"][::-1] # type: ignore + data[i]["is_sink"] = data[i]["is_sink"][::-1] # type: ignore + data[i]["U_bck_logprobs"] = torch.tensor( + [0] + data[i]["U_bck_logprobs"][::-1], device=dev # type: ignore + ).reshape(-1) if self.pad_with_terminal_state: - data[i]["traj"].append((starting_graphs[i], GraphAction(GraphActionType.Pad))) + data[i]["traj"].append((starting_graphs[i], GraphAction(GraphActionType.Pad))) # type: ignore data[i]["U_bck_logprobs"] = torch.cat([data[i]["U_bck_logprobs"], torch.tensor([0.0], device=dev)]) - data[i]["is_sink"].append(1) - assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"]) + data[i]["is_sink"].append(1) # type: ignore + assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"]) # type: ignore return data def local_search_sample_from_model( @@ -249,7 +253,7 @@ def local_search_sample_from_model( ] # type: ignore graphs = [i["traj"][-1][0] for i in current_trajs] done = [False] * n - fwd_a = [] + fwd_a: List[GraphAction] = [] for i in range(cfg.num_bck_steps): # This modifies `bck_trajs` & `graphs` in place, passing fwd_a computes P_F(s|s') for the previous step self._backward_step(model, bck_trajs, graphs, cond_info, done, dev, fwd_a) @@ -264,7 +268,7 @@ def local_search_sample_from_model( {"traj": [], "bck_a": [], "is_sink": [], "bck_logprobs": [], "fwd_logprobs": []} for _ in current_trajs ] # type: ignore done = [False] * n - bck_a = [] + bck_a: List[GraphAction] = [] while not all(done): self._forward_step(model, fwd_trajs, graphs, cond_info, 0, done, rng, dev, random_action_prob, bck_a) done = [d or (len(t["traj"]) + T) >= self.max_len for d, t, T in zip(done, fwd_trajs, trunc_lens)] @@ -281,7 +285,7 @@ def local_search_sample_from_model( sampled_terminals.extend(terminals) for traj, term in zip(fwd_trajs, terminals): traj["result"] = term - traj["is_accept"] = False + traj["is_accept"] = False # type: ignore # Compute rewards for the acceptance if compute_reward is not None: compute_reward(fwd_trajs, cond_info) diff --git a/src/gflownet/algo/local_search_tb.py b/src/gflownet/algo/local_search_tb.py index 4c8f38c9..6bd65fcd 100644 --- a/src/gflownet/algo/local_search_tb.py +++ b/src/gflownet/algo/local_search_tb.py @@ -1,9 +1,6 @@ -import torch - from gflownet import GFNTask from gflownet.algo.trajectory_balance import TrajectoryBalance from gflownet.data.data_source import DataSource -from gflownet.utils.misc import get_worker_device class LocalSearchTB(TrajectoryBalance): diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 8e9d3679..fd23db73 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -620,10 +620,6 @@ def compute_batch_losses( if self.cfg.mle_loss_multiplier != 0: info["mle_loss"] = mle_loss.item() - if not torch.isfinite(loss): - import pdb - - pdb.set_trace() return loss, info def analytical_maxent_backward(self, batch, first_graph_idx): diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index bad06268..8b945d05 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -1,3 +1,4 @@ +import traceback import warnings from typing import Callable, Generator, List, Optional @@ -92,19 +93,9 @@ def __iter__(self): raise e print(f"Error in DataSource: {e} [tol={self._err_tol}]") # print full traceback - import sys - import traceback traceback.print_exc() continue - except: - print("Unknown error in DataSource") - import sys - import traceback - - traceback.print_exc() - self._err_tol -= 1 - continue def validate_batch(self, batch, trajs): for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( diff --git a/src/gflownet/data/replay_buffer.py b/src/gflownet/data/replay_buffer.py index 62b791d2..197258b2 100644 --- a/src/gflownet/data/replay_buffer.py +++ b/src/gflownet/data/replay_buffer.py @@ -1,6 +1,6 @@ import heapq from threading import Lock -from typing import List +from typing import Any, List import numpy as np import torch @@ -15,8 +15,8 @@ def __init__(self, cfg: Config): Replay buffer for storing and sampling arbitrary data (e.g. transitions or trajectories) In self.push(), the buffer detaches any torch tensor and sends it to the CPU. """ - self.capacity = cfg.replay.capacity - self.warmup = cfg.replay.warmup + self.capacity = cfg.replay.capacity or int(1e6) + self.warmup = cfg.replay.warmup or 0 assert self.warmup <= self.capacity, "ReplayBuffer warmup must be smaller than capacity" self.buffer: List[tuple] = [] @@ -24,7 +24,7 @@ def __init__(self, cfg: Config): self.treat_as_heap = cfg.replay.keep_highest_rewards self.filter_uniques = cfg.replay.keep_only_uniques - self._uniques = set() + self._uniques: set[Any] = set() self._lock = Lock() @@ -56,7 +56,7 @@ def push(self, *args, unique_obj=None, priority=None): self._uniques.add(unique_obj) else: if len(self.buffer) < self.capacity: - self.buffer.append(None) + self.buffer.append(()) if self.filter_uniques: if self.position == 0 and len(self.buffer) == self.capacity: # We're about to wrap around, so remove the oldest element diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index e5a35996..3d5f91b7 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -310,7 +310,7 @@ def _make_cat(self, g: gd.Batch, emb: Dict[str, Tensor], action_types: list[Grap sc = self.logit_scaler( g.cond_info if g.cond_info is not None else torch.ones((g.num_graphs, 1), device=g.x.device) ) - cat.logits = [l * sc[b] for l, b in zip(cat.raw_logits, cat.batch)] # Setting .logits masks them + cat.logits = [lg * sc[b] for lg, b in zip(cat.raw_logits, cat.batch)] # Setting .logits masks them return cat def forward(self, g: gd.Batch): diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 52636d9a..c8bebba0 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -1,6 +1,6 @@ import copy -import os import pathlib +from typing import Any import git import torch @@ -137,7 +137,7 @@ def setup(self): def step(self, loss: Tensor, train_it: int): loss.backward() - info = {} + info: dict[str, Any] = {} if train_it % self.cfg.algo.grad_acc_steps != 0: return info if self.cfg.opt.clip_grad_type is not None: diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 777e1a0f..3321c32e 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -272,11 +272,11 @@ def _send_models_to_device(self): self.model.to(self.device) self.sampling_model.to(self.device) if self.world_size > 1: - self.model = DistributedDataParallel( + self.model = nn.parallel.DistributedDataParallel( self.model.to(self.rank), device_ids=[self.rank], output_device=self.rank ) if self.sampling_model is not self.model: - self.sampling_model = DistributedDataParallel( + self.sampling_model = nn.parallel.DistributedDataParallel( self.sampling_model.to(self.rank), device_ids=[self.rank], output_device=self.rank ) diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index d011b64a..13559514 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -268,8 +268,8 @@ def run(self): break timeouts = 0 attr, args, kwargs = r - if hasattr(self.obj, "lock"): - f.lock.acquire() + if hasattr(self.obj, "lock"): # TODO: this is not used anywhere? + self.obj.lock.acquire() f = getattr(self.obj, attr) args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args] kwargs = {k: i.to(self.device) if isinstance(i, self.cuda_types) else i for k, i in kwargs.items()} @@ -293,7 +293,7 @@ def run(self): msg = self.to_cpu(result) self.out_queues[qi].put(self.encode(msg)) if hasattr(self.obj, "lock"): - f.lock.release() + self.obj.lock.release() def terminate(self): self.stop.set() From 1669c28b1e4e0f98281f3bb0639f5783b03e272d Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Tue, 8 Oct 2024 09:45:52 -0600 Subject: [PATCH 31/31] bandit --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6cfd7224..10a45299 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,8 @@ universal = "true" [tool.bandit] # B101 tests the use of assert # B301 and B403 test the use of pickle -skips = ["B101", "B301", "B403"] +# B614 tests the use of torch.load/save +skips = ["B101", "B301", "B403", "B614"] exclude_dirs = ["tests", ".tox", ".venv"] [tool.pytest.ini_options]