diff --git a/docs/implementation_notes.md b/docs/implementation_notes.md index 8e6c2474..ba63e708 100644 --- a/docs/implementation_notes.md +++ b/docs/implementation_notes.md @@ -51,4 +51,18 @@ The data used for training GFlowNets can come from a variety of sources. `DataSo `DataSource` also covers validation sets, including cases such as: - Generating new trajectories (w.r.t a fixed dataset of conditioning goals) -- Evaluating the model's likelihood on trajectories from a fixed, offline dataset \ No newline at end of file +- Evaluating the model's likelihood on trajectories from a fixed, offline dataset + +## Multiprocessing + +We use the multiprocessing features of torch's `DataLoader` to parallelize data generation and featurization. This is done by setting the `num_workers` (via `cfg.num_workers`) parameter of the `DataLoader` to a value greater than 0. Because workers cannot (easily) use a CUDA handle, we have to resort to a number of tricks. + +Because training models involves sampling them, the worker processes need to be able to call the models. This is done by passing a wrapped model (and possibly wrapped replay buffer) to the workers, using `gflownet.utils.multiprocessing_proxy`. These wrappers ensure that model calls are routed to the main worker process, where the model lives (e.g. in CUDA), and that the returned values are properly serialized and sent back to the worker process. These wrappers are also designed to be API-compatible with models, e.g. `model(input)` or `model.method(input)` will work as expected, regardless of whether `model` is a torch module or a wrapper. Note that it is only possible to call methods on these wrappers, direct attribute access is not supported. + +Note that the workers do not use CUDA, therefore have to work entirely on CPU, but the code is designed to be somewhat agnostic to this fact. By using `get_worker_device`, code can be written without assuming too much; again, calls such as `model(input)` will work as expected. + +On message serialization, naively sending batches of data and results (`Batch` and `GraphActionCategorical`) through multiprocessing queues is fairly inefficient. Torch tries to be smart and will use shared memory for tensors that are sent through queues, which unfortunately is very slow because creating these shared memory files is slow, and because `Data` `Batch`es tend to contain lots of small tensors, which is not a good fit for shared memory. + +We implement two solutions to this problem (in order of preference): +- using `SharedPinnedBuffer`s, which are shared tensors of fixed size (`cfg.mp_buffer_size`), but initialized once and pinned. This is the fastest solution, but requires that the size of the largest possible batch/return value is known in advance. This should work for any message, but has only been tested with `Batch` and `GraphActionCategorical` messages. +- using `cfg.pickle_mp_messages`, which simply serializes messages with `pickle`. This prevents the creation of lots of shared memory files, but is slower than the `SharedPinnedBuffer` solution. This should work for any message that `pickle` can handle. diff --git a/pyproject.toml b/pyproject.toml index 6cfd7224..10a45299 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,8 @@ universal = "true" [tool.bandit] # B101 tests the use of assert # B301 and B403 test the use of pickle -skips = ["B101", "B301", "B403"] +# B614 tests the use of torch.load/save +skips = ["B101", "B301", "B403", "B614"] exclude_dirs = ["tests", ".tox", ".venv"] [tool.pytest.ini_options] diff --git a/setup.py b/setup.py index 26bcab4d..1418c57c 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from ast import literal_eval from subprocess import check_output # nosec - command is hard-coded, no possibility of injection -from setuptools import setup +from setuptools import Extension, setup def _get_next_version(): @@ -25,4 +25,18 @@ def _get_next_version(): return f"{major}.{minor}.{latest_patch+1}" -setup(name="gflownet", version=_get_next_version()) +ext = [ + Extension( + name="gflownet._C", + sources=[ + "src/C/main.c", + "src/C/data.c", + "src/C/graph_def.c", + "src/C/node_view.c", + "src/C/edge_view.c", + "src/C/degree_view.c", + "src/C/mol_graph_to_Data.c", + ], + ) +] +setup(name="gflownet", version=_get_next_version(), ext_modules=ext) diff --git a/src/C/data.c b/src/C/data.c new file mode 100644 index 00000000..1b8e3c3c --- /dev/null +++ b/src/C/data.c @@ -0,0 +1,612 @@ +#define PY_SSIZE_T_CLEAN +#include "main.h" +#include "structmember.h" +#include + +static void Data_dealloc(Data *self) { + Py_XDECREF(self->bytes); + Py_XDECREF(self->graph_def); + free(self->shapes); + free(self->is_float); + free(self->offsets); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static PyObject *torch_module = NULL; +static PyObject *torch_gd_module = NULL; +static void _check_torch() { + if (torch_module == NULL) { + torch_module = PyImport_ImportModule("torch"); + torch_gd_module = PyImport_ImportModule("torch_geometric.data"); + } +} + +static PyObject *gbe_module = NULL; +static PyObject *Stop, *AddNode, *SetNodeAttr, *AddEdge, *SetEdgeAttr, *RemoveNode, *RemoveNodeAttr, *RemoveEdge, + *RemoveEdgeAttr, *GraphAction, *GraphActionType; +void _check_gbe() { + if (gbe_module == NULL) { + gbe_module = PyImport_ImportModule("gflownet.envs.graph_building_env"); + if (gbe_module == NULL) { + PyErr_SetString(PyExc_ImportError, "Could not import gflownet.envs.graph_building_env"); + return; + } + GraphActionType = PyObject_GetAttrString(gbe_module, "GraphActionType"); + GraphAction = PyObject_GetAttrString(gbe_module, "GraphAction"); + Stop = PyObject_GetAttrString(GraphActionType, "Stop"); + AddNode = PyObject_GetAttrString(GraphActionType, "AddNode"); + SetNodeAttr = PyObject_GetAttrString(GraphActionType, "SetNodeAttr"); + AddEdge = PyObject_GetAttrString(GraphActionType, "AddEdge"); + SetEdgeAttr = PyObject_GetAttrString(GraphActionType, "SetEdgeAttr"); + RemoveNode = PyObject_GetAttrString(GraphActionType, "RemoveNode"); + RemoveNodeAttr = PyObject_GetAttrString(GraphActionType, "RemoveNodeAttr"); + RemoveEdge = PyObject_GetAttrString(GraphActionType, "RemoveEdge"); + RemoveEdgeAttr = PyObject_GetAttrString(GraphActionType, "RemoveEdgeAttr"); + } +} + +static PyObject *Data_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + Data *self; + self = (Data *)type->tp_alloc(type, 0); + if (self != NULL) { + self->bytes = NULL; + self->graph_def = NULL; + self->shapes = NULL; + self->is_float = NULL; + self->num_matrices = 0; + self->names = NULL; + self->offsets = NULL; + } + return (PyObject *)self; +} + +static int Data_init(Data *self, PyObject *args, PyObject *kwds) { + if (args != NULL && 0) { + PyErr_SetString(PyExc_KeyError, "Trying to initialize a Data object from Python. Use *_graph_to_Data instead."); + return -1; + } + return 0; +} + +void *Data_ptr_and_shape(Data *self, char *name, int *n, int *m) { + for (int i = 0; i < self->num_matrices; i++) { + if (strcmp(self->names[i], name) == 0) { + *n = self->shapes[i * 2 + 0]; + *m = self->shapes[i * 2 + 1]; + return PyByteArray_AsString(self->bytes) + self->offsets[i]; + } + } + return (void *)0xdeadbeef; +} + +void Data_init_C(Data *self, PyObject *bytes, PyObject *graph_def, int shapes[][2], int *is_float, int num_matrices, + const char **names) { + PyObject *tmp; + tmp = (PyObject *)self->bytes; + Py_INCREF(bytes); + self->bytes = bytes; + Py_XDECREF(tmp); + + tmp = (PyObject *)self->graph_def; + Py_INCREF(graph_def); + self->graph_def = (GraphDef *)graph_def; + Py_XDECREF(tmp); + + self->shapes = malloc(sizeof(int) * num_matrices * 2); + self->offsets = malloc(sizeof(int) * num_matrices); + int offset_bytes = 0; + for (int i = 0; i < num_matrices; i++) { + self->shapes[i * 2 + 0] = shapes[i][0]; + self->shapes[i * 2 + 1] = shapes[i][1]; + self->offsets[i] = offset_bytes; + offset_bytes += shapes[i][0] * shapes[i][1] * (is_float[i] ? 4 : 8); + } + self->is_float = malloc(sizeof(int) * num_matrices); + memcpy(self->is_float, is_float, sizeof(int) * num_matrices); + self->num_matrices = num_matrices; + self->names = names; +} + +PyObject *Data_mol_aidx_to_GraphAction(PyObject *_self, PyObject *args) { + _check_gbe(); + int _t, row, col; // _t is unused, we have gt from python + PyObject *gt = NULL; + if (!PyArg_ParseTuple(args, "(iii)O", &_t, &row, &col, >)) { + return NULL; + } + Data *self = (Data *)_self; + // These are singletons, so we can just compare pointers! + if (gt == Stop) { + PyObject *args = PyTuple_Pack(1, Stop); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(args); + return res; + } + if (gt == AddNode) { + // return GraphAction(t, source=act_row, value=self.atom_attr_values["v"][act_col]) + PyObject *source = PyLong_FromLong(row); // ref is new + PyObject *vals = PyDict_GetItemString(self->graph_def->node_values, "v"); // ref is borrowed + PyObject *value = PyList_GetItem(vals, col); // ref is borrowed + PyObject *args = PyTuple_Pack(5, gt, source, Py_None, value, Py_None); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(source); + Py_DECREF(args); + return res; + } + if (gt == SetNodeAttr) { + // attr, val = self.atom_attr_logit_map[act_col] + // return GraphAction(t, source=act_row, attr=attr, value=val) + PyObject *attr, *val; + for (int i = 0; i < self->graph_def->num_settable_node_attrs; i++) { + int start = self->graph_def->node_attr_offsets[i] - i; // - i because these are logits + int end = self->graph_def->node_attr_offsets[i + 1] - (i + 1); + if (start <= col && col < end) { + attr = PyList_GetItem(self->graph_def->node_poskey, i); + PyObject *vals = PyDict_GetItem(self->graph_def->node_values, attr); + val = PyList_GetItem(vals, col - start + 1); // + 1 because the default is 0 + break; + } + } + PyObject *source = PyLong_FromLong(row); // ref is new + PyObject *args = PyTuple_Pack(5, gt, source, Py_None, val, attr); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(source); + Py_DECREF(args); + return res; + } + if (gt == AddEdge) { + int n, m; + long *non_edge_index = Data_ptr_and_shape(self, "non_edge_index", &n, &m); + PyObject *u = PyLong_FromLong(non_edge_index[row]); + PyObject *v = PyLong_FromLong(non_edge_index[row + m]); + PyObject *args = PyTuple_Pack(3, AddEdge, u, v); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(u); + Py_DECREF(v); + Py_DECREF(args); + return res; + } + if (gt == SetEdgeAttr) { + // attr, val = self.bond_attr_logit_map[act_col] + // return GraphAction(t, source=act_row, attr=attr, value=val) + int n, m; + long *edge_index = Data_ptr_and_shape(self, "edge_index", &n, &m); + // edge_index should be (2, m), * 2 because edges are duplicated + PyObject *u = PyLong_FromLong(edge_index[row * 2]); + PyObject *v = PyLong_FromLong(edge_index[row * 2 + m]); + PyObject *attr, *val = NULL; + for (int i = 0; i < self->graph_def->num_settable_edge_attrs; i++) { + int start = self->graph_def->edge_attr_offsets[i] - i; // - i because these are logits + int end = self->graph_def->edge_attr_offsets[i + 1] - (i + 1); + if (start <= col && col < end) { + attr = PyList_GetItem(self->graph_def->edge_poskey, i); + PyObject *vals = PyDict_GetItem(self->graph_def->edge_values, attr); + val = PyList_GetItem(vals, col - start + 1); // + 1 because the default is 0 + break; + } + } + if (val == NULL) { + PyErr_SetString(PyExc_ValueError, "failed to find edge attr"); + return NULL; + } + PyObject *args = PyTuple_Pack(5, gt, u, v, val, attr); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(u); + Py_DECREF(v); + Py_DECREF(args); + return res; + } + if (gt == RemoveNode) { + PyObject *source = PyLong_FromLong(row); // ref is new + PyObject *args = PyTuple_Pack(2, RemoveNode, source); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(source); + Py_DECREF(args); + return res; + } + if (gt == RemoveNodeAttr) { + PyObject *source = PyLong_FromLong(row); // ref is new + PyObject *attr = PyList_GetItem(self->graph_def->node_poskey, col); // this works because 'v' is last + PyObject *args = PyTuple_Pack(5, RemoveNodeAttr, source, Py_None, Py_None, attr); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(source); + Py_DECREF(args); + return res; + } + if (gt == RemoveEdge) { + int n, m; + long *edge_index = Data_ptr_and_shape(self, "edge_index", &n, &m); + // edge_index should be (2, m), * 2 because edges are duplicated + PyObject *u = PyLong_FromLong(edge_index[row * 2 + 0]); + PyObject *v = PyLong_FromLong(edge_index[row * 2 + m]); + PyObject *vargs = PyTuple_Pack(3, RemoveEdge, u, v); + PyObject *res = PyObject_CallObject(GraphAction, vargs); + Py_DECREF(u); + Py_DECREF(v); + Py_DECREF(vargs); + return res; + } + if (gt == RemoveEdgeAttr) { + int n, m; + long *edge_index = Data_ptr_and_shape(self, "edge_index", &n, &m); + // edge_index should be (2, m), * 2 because edges are duplicated + PyObject *u = PyLong_FromLong(edge_index[row * 2 + 0]); + PyObject *v = PyLong_FromLong(edge_index[row * 2 + m]); + PyObject *attr = PyList_GetItem(self->graph_def->edge_poskey, col); + PyObject *args = PyTuple_Pack(5, RemoveEdgeAttr, u, v, Py_None, attr); + PyObject *res = PyObject_CallObject(GraphAction, args); + Py_DECREF(u); + Py_DECREF(v); + Py_DECREF(args); + return res; + } + PyErr_SetString(PyExc_ValueError, "Unknown action type"); + return NULL; +} + +PyObject *Data_mol_GraphAction_to_aidx(PyObject *_self, PyObject *args) { + _check_gbe(); + + PyObject *action = NULL; + if (!PyArg_ParseTuple(args, "O", &action)) { + return NULL; + } + Data *self = (Data *)_self; + PyObject *action_type = PyObject_GetAttrString(action, "action"); // new ref + // These are singletons, so we can just compare pointers! + long row = 0, col = 0; + if (action_type == Stop) { + // return (0, 0) + } else if (action_type == AddNode) { + row = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + PyObject *val = PyObject_GetAttrString(action, "value"); + PyObject *vals = PyDict_GetItemString(self->graph_def->node_values, "v"); + col = PySequence_Index(vals, val); + Py_DECREF(val); + } else if (action_type == SetNodeAttr) { + row = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + PyObject *attr = PyObject_GetAttrString(action, "attr"); + PyObject *val = PyObject_GetAttrString(action, "value"); + PyObject *vals = PyDict_GetItem(self->graph_def->node_values, attr); + col = PySequence_Index(vals, val) - 1; // -1 because the default value is at index 0 + int attr_pos = PyLong_AsLong(PyDict_GetItem(self->graph_def->node_keypos, attr)); + col += self->graph_def->node_attr_offsets[attr_pos] - attr_pos; + Py_DECREF(attr); + Py_DECREF(val); + } else if (action_type == AddEdge) { + col = 0; + int u = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + int v = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "target")); + int n, m; + long *non_edge_index = Data_ptr_and_shape(self, "non_edge_index", &n, &m); + for (int i = 0; i < m; i++) { + if ((non_edge_index[i] == u && non_edge_index[i + m] == v) || + (non_edge_index[i] == v && non_edge_index[i + m] == u)) { + row = i; + break; + } + } + } else if (action_type == SetEdgeAttr) { + int u = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + int v = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "target")); + int n, m; + long *edge_index = Data_ptr_and_shape(self, "edge_index", &n, &m); + for (int i = 0; i < m; i++) { + if ((edge_index[i] == u && edge_index[i + m] == v) || (edge_index[i] == v && edge_index[i + m] == u)) { + row = i / 2; // edges are duplicated + break; + } + } + PyObject *attr = PyObject_GetAttrString(action, "attr"); + PyObject *val = PyObject_GetAttrString(action, "value"); + PyObject *vals = PyDict_GetItem(self->graph_def->edge_values, attr); + col = PySequence_Index(vals, val) - 1; // -1 because the default value is at index 0 + int attr_pos = PyLong_AsLong(PyDict_GetItem(self->graph_def->edge_keypos, attr)); + col += self->graph_def->edge_attr_offsets[attr_pos] - attr_pos; + Py_DECREF(attr); + Py_DECREF(val); + } else if (action_type == RemoveNode) { + row = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + col = 0; + } else if (action_type == RemoveNodeAttr) { + row = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + PyObject *attr = PyObject_GetAttrString(action, "attr"); + col = PyLong_AsLong(PyDict_GetItem(self->graph_def->node_keypos, attr)); + } else if (action_type == RemoveEdge) { + col = 0; + int u = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + int v = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "target")); + int n, m; + long *edge_index = Data_ptr_and_shape(self, "edge_index", &n, &m); + for (int i = 0; i < m; i++) { + if ((edge_index[i] == u && edge_index[i + m] == v) || (edge_index[i] == v && edge_index[i + m] == u)) { + row = i / 2; // edges are duplicated + break; + } + } + } else if (action_type == RemoveEdgeAttr) { + int u = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "source")); + int v = borrow_new_and_call(PyLong_AsLong, PyObject_GetAttrString(action, "target")); + int n, m; + long *edge_index = Data_ptr_and_shape(self, "edge_index", &n, &m); + for (int i = 0; i < m; i++) { + if ((edge_index[i] == u && edge_index[i + m] == v) || (edge_index[i] == v && edge_index[i + m] == u)) { + row = i / 2; // edges are duplicated + break; + } + } + PyObject *attr = PyObject_GetAttrString(action, "attr"); + col = PyLong_AsLong(PyDict_GetItem(self->graph_def->edge_keypos, attr)); + } else { + PyErr_SetString(PyExc_ValueError, "Unknown action type"); + return NULL; + } + PyObject *py_row = PyLong_FromLong(row); + PyObject *py_col = PyLong_FromLong(col); + PyObject *res = PyTuple_Pack(2, py_row, py_col); + Py_DECREF(py_row); + Py_DECREF(py_col); + Py_DECREF(action_type); + return res; +} + +PyObject *Data_as_torch(PyObject *_self, PyObject *unused_args) { + _check_torch(); + Data *self = (Data *)_self; + PyObject *res = PyDict_New(); + PyObject *frombuffer = PyObject_GetAttrString(torch_module, "frombuffer"); + PyObject *empty = PyObject_GetAttrString(torch_module, "empty"); + PyObject *dtype_f32 = PyObject_GetAttrString(torch_module, "float32"); + PyObject *dtype_i64 = PyObject_GetAttrString(torch_module, "int64"); + PyObject *fb_args = PyTuple_Pack(1, self->bytes); + PyObject *fb_kwargs = PyDict_New(); + int do_del_kw = 0; + int offset = 0; + for (int i = 0; i < self->num_matrices; i++) { + int i_num_items = self->shapes[i * 2 + 0] * self->shapes[i * 2 + 1]; + PyObject *tensor; + PyDict_SetItemString(fb_kwargs, "dtype", self->is_float[i] ? dtype_f32 : dtype_i64); + if (i_num_items == 0) { + if (do_del_kw) { + PyDict_DelItemString(fb_kwargs, "offset"); + PyDict_DelItemString(fb_kwargs, "count"); + do_del_kw = 0; + } + PyObject *zero = PyLong_FromLong(0); + PyObject *args = PyTuple_Pack(1, zero); + tensor = PyObject_Call(empty, args, fb_kwargs); + Py_DECREF(args); + Py_DECREF(zero); + } else { + PyObject *py_offset = PyLong_FromLong(offset); + PyObject *py_numi = PyLong_FromLong(i_num_items); + PyDict_SetItemString(fb_kwargs, "offset", py_offset); + PyDict_SetItemString(fb_kwargs, "count", py_numi); + Py_DECREF(py_offset); + Py_DECREF(py_numi); + do_del_kw = 1; + tensor = PyObject_Call(frombuffer, fb_args, fb_kwargs); + } + PyObject *reshaped_tensor = + PyObject_CallMethod(tensor, "view", "ii", self->shapes[i * 2 + 0], self->shapes[i * 2 + 1]); + PyDict_SetItemString(res, self->names[i], reshaped_tensor); + Py_DECREF(tensor); + Py_DECREF(reshaped_tensor); + offset += i_num_items * (self->is_float[i] ? 4 : 8); + } + Py_DECREF(frombuffer); + Py_DECREF(dtype_f32); + Py_DECREF(dtype_i64); + Py_DECREF(fb_args); + Py_DECREF(fb_kwargs); + + PyObject *Data_cls = PyObject_GetAttrString(torch_gd_module, "Data"); // new ref + PyObject *args = PyTuple_New(0); + PyObject *Data_res = PyObject_Call(Data_cls, args, res); + Py_DECREF(args); + Py_DECREF(Data_cls); + Py_DECREF(res); + return Data_res; +} + +static PyMemberDef Data_members[] = { + {NULL} /* Sentinel */ +}; + +static PyMethodDef Data_methods[] = { + {"mol_aidx_to_GraphAction", (PyCFunction)Data_mol_aidx_to_GraphAction, METH_VARARGS, "mol_aidx_to_GraphAction"}, + {"mol_GraphAction_to_aidx", (PyCFunction)Data_mol_GraphAction_to_aidx, METH_VARARGS, "mol_GraphAction_to_aidx"}, + {"as_torch", (PyCFunction)Data_as_torch, METH_NOARGS, "to pyg data"}, + {NULL} /* Sentinel */ +}; + +PyTypeObject DataType = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "_C.Data", + .tp_doc = PyDoc_STR("Constrained Data object"), + .tp_basicsize = sizeof(Data), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_new = Data_new, + .tp_init = (initproc)Data_init, + .tp_dealloc = (destructor)Data_dealloc, + .tp_members = Data_members, + .tp_methods = Data_methods, +}; + +#include + +PyObject *Data_collate(PyObject *self, PyObject *args) { + PyObject *graphs = NULL, *follow_batch = NULL; + if (!PyArg_ParseTuple(args, "O|O", &graphs, &follow_batch)) { + return NULL; + } + if (!PyList_Check(graphs) || PyList_Size(graphs) < 1 || !PyObject_TypeCheck(PyList_GetItem(graphs, 0), &DataType)) { + PyErr_SetString(PyExc_TypeError, "Data_collate expects a non-empty list of Data objects"); + return NULL; + } + _check_torch(); + PyObject *empty_cb = PyObject_GetAttrString(torch_module, "empty"); + PyObject *empty_cb_int64_arg = PyDict_New(); + PyDict_SetItemString(empty_cb_int64_arg, "dtype", PyObject_GetAttrString(torch_module, "int64")); + int num_graphs = PyList_Size(graphs); + // PyObject *batch = PyObject_CallMethod(torch_gd_module, "Batch", NULL); + // Actually we want batch = Batch(_base_cls=graphs[0].__class__) + PyObject *base_cls = PyObject_GetAttrString(PyList_GetItem(graphs, 0), "__class__"); + PyObject *batch = PyObject_CallMethod(torch_gd_module, "Batch", "O", base_cls); + PyObject *slice_dict = PyDict_New(); + Data *first = (Data *)PyList_GetItem(graphs, 0); + int index_of_x = 0; + for (int i = 0; i < first->num_matrices; i++) { + if (strcmp(first->names[i], "x") == 0) { + index_of_x = i; + break; + } + } + for (int i = 0; i < first->num_matrices; i++) { + PyObject *tensor = NULL; + const char *name = first->names[i]; + int do_compute_batch = 0; + if (strcmp(name, "x") == 0) { + do_compute_batch = 1; + } else if (follow_batch != NULL) { + PyObject *py_name = PyUnicode_FromString(name); + do_compute_batch = PySequence_Contains(follow_batch, py_name); + Py_DECREF(py_name); + } + int item_size = (first->is_float[i] ? sizeof(float) : sizeof(long)); + // First let's count the total number of items in the batch + int num_rows = 0; + int cat_dim = 0, val_dim = 1; + if (strstr(name, "index") != NULL) { + cat_dim = 1; + val_dim = 0; + } + int val_dim_size = first->shapes[i * 2 + val_dim]; + for (int j = 0; j < num_graphs; j++) { + Data *data = (Data *)PyList_GetItem(graphs, j); + // We're concatenating along the first dimension, so we should check that the second dimension is the same + if (data->shapes[i * 2 + val_dim] != val_dim_size) { + PyErr_Format(PyExc_TypeError, + "mol_Data_collate concatenates %s along dimension %d, but tensor has shape %d along " + "dimension %d", + name, cat_dim, data->shapes[i * 2 + val_dim], val_dim); + return NULL; + } + num_rows += data->shapes[i * 2 + cat_dim]; + } + // Now we allocate the tensor itself, its batch tensor & slices + PyObject *py_num_rows = PyLong_FromLong(num_rows); + PyObject *py_val_dim_size = PyLong_FromLong(val_dim_size); + PyObject *py_shape = cat_dim == 0 ? PyTuple_Pack(2, py_num_rows, py_val_dim_size) + : PyTuple_Pack(2, py_val_dim_size, py_num_rows); + int tensor_shape_x = cat_dim == 0 ? num_rows : val_dim_size; + int tensor_shape_y = cat_dim == 0 ? val_dim_size : num_rows; + if (first->is_float[i]) { + tensor = PyObject_Call(empty_cb, py_shape, NULL); + } else { + tensor = PyObject_Call(empty_cb, py_shape, empty_cb_int64_arg); + } + int _total_ni = val_dim_size * num_rows; + Py_DECREF(py_shape); + PyObject *batch_tensor = NULL; + if (do_compute_batch) { + py_shape = PyTuple_Pack(1, py_num_rows); + batch_tensor = PyObject_Call(empty_cb, py_shape, empty_cb_int64_arg); + Py_DECREF(py_shape); + } + PyObject *py_num_graphsp1 = PyLong_FromLong(num_graphs + 1); + py_shape = PyTuple_Pack(1, py_num_graphsp1); + PyObject *slice_tensor = PyObject_Call(empty_cb, py_shape, empty_cb_int64_arg); + Py_DECREF(py_shape); + Py_DECREF(py_num_rows); + Py_DECREF(py_val_dim_size); + Py_DECREF(py_num_graphsp1); + PyObject *ptr = PyObject_CallMethod(tensor, "data_ptr", ""); + void *tensor_ptr = PyLong_AsVoidPtr(ptr); + Py_DECREF(ptr); + long *batch_tensor_ptr; + if (do_compute_batch) { + ptr = PyObject_CallMethod(batch_tensor, "data_ptr", ""); + batch_tensor_ptr = PyLong_AsVoidPtr(ptr); + Py_DECREF(ptr); + } + ptr = PyObject_CallMethod(slice_tensor, "data_ptr", ""); + long *slice_ptr = PyLong_AsVoidPtr(ptr); + Py_DECREF(ptr); + // Now we copy the data from the individual Data objects to the batch + int offset_bytes = 0; + int offset_items = 0; + int offset_rows = 0; + slice_ptr[0] = 0; + int value_increment = 0; // we need to increment edge indices across graphs + int do_increment = strstr(name, "index") != NULL; + for (int j = 0; j < num_graphs; j++) { + Data *data = (Data *)PyList_GetItem(graphs, j); // borrowed ref + int num_items_j = data->shapes[i * 2 + 0] * data->shapes[i * 2 + 1]; + if (cat_dim == 0) { + // we're given a 2d matrix m, n and we want to fit it into a bigger matrix M, n + void *dst = tensor_ptr; + void *src = PyByteArray_AsString(data->bytes) + data->offsets[i]; + for (int u = 0; u < data->shapes[i * 2 + 0]; u++) { + for (int v = 0; v < data->shapes[i * 2 + 1]; v++) { + // dst[u + offset_rows, v] = src[u, v] + if (first->is_float[i]) { + ((float *)dst)[(u + offset_rows) * tensor_shape_y + v] = + ((float *)src)[u * data->shapes[i * 2 + 1] + v]; + } else { + ((long *)dst)[(u + offset_rows) * tensor_shape_y + v] = + ((long *)src)[u * data->shapes[i * 2 + 1] + v] + value_increment; + } + } + } + } else { + // we're given a 2d matrix m, n and we want to fit it into a bigger matrix m, N + void *dst = tensor_ptr; + void *src = PyByteArray_AsString(data->bytes) + data->offsets[i]; + for (int u = 0; u < data->shapes[i * 2 + 0]; u++) { + for (int v = 0; v < data->shapes[i * 2 + 1]; v++) { + // dst[u, v + offset_rows] = src[u, v] + if (first->is_float[i]) { + ((float *)dst)[u * tensor_shape_y + v + offset_rows] = + ((float *)src)[u * data->shapes[i * 2 + 1] + v]; + } else { + ((long *)dst)[u * tensor_shape_y + v + offset_rows] = + ((long *)src)[u * data->shapes[i * 2 + 1] + v] + value_increment; + } + } + } + } + if (do_compute_batch) { + for (int k = 0; k < data->shapes[i * 2 + cat_dim]; k++) { + batch_tensor_ptr[k + offset_rows] = j; + } + } + offset_rows += data->shapes[i * 2 + cat_dim]; + offset_items += num_items_j; + offset_bytes += num_items_j * item_size; + slice_ptr[j + 1] = offset_rows; + if (do_increment) { + value_increment += data->shapes[index_of_x * 2]; // increment by num_nodes + } + } + PyObject_SetAttrString(batch, name, tensor); + // for x, the batch is just 'batch', otherwise '%s_batch' + if (strcmp(name, "x") == 0) { + PyObject_SetAttrString(batch, "batch", batch_tensor); + } else if (do_compute_batch) { + char buf[100]; + sprintf(buf, "%s_batch", name); + PyObject_SetAttrString(batch, buf, batch_tensor); + } + PyDict_SetItemString(slice_dict, name, slice_tensor); + Py_DECREF(tensor); + Py_XDECREF(batch_tensor); + Py_DECREF(slice_tensor); + } + PyObject_SetAttrString(batch, "_slice_dict", slice_dict); + Py_DECREF(empty_cb); + Py_DECREF(empty_cb_int64_arg); + Py_DECREF(slice_dict); + + return batch; +} diff --git a/src/C/degree_view.c b/src/C/degree_view.c new file mode 100644 index 00000000..dbc65bec --- /dev/null +++ b/src/C/degree_view.c @@ -0,0 +1,70 @@ +#define PY_SSIZE_T_CLEAN +#include "./main.h" +#include "structmember.h" +#include + +static void DegreeView_dealloc(DegreeView *self) { + Py_XDECREF(self->graph); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static PyObject *DegreeView_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + DegreeView *self; + self = (DegreeView *)type->tp_alloc(type, 0); + if (self != NULL) { + self->graph = NULL; + } + return (PyObject *)self; +} + +static int DegreeView_init(DegreeView *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"graph", NULL}; + PyObject *graph = NULL, *tmp; + int node_id = -1; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &graph, &node_id)) + return -1; + + if (graph) { + tmp = (PyObject *)self->graph; + Py_INCREF(graph); + self->graph = (Graph *)graph; + Py_XDECREF(tmp); + } + return 0; +} + +static PyMemberDef DegreeView_members[] = { + {NULL} /* Sentinel */ +}; + +static PyObject *DegreeView_getitem(PyObject *_self, PyObject *k) { + DegreeView *self = (DegreeView *)_self; + int index = PyLong_AsLong(k); + if (PyErr_Occurred()) { + return NULL; + } + return PyLong_FromLong(self->graph->degrees[index]); +} + +static PyMappingMethods DegreeView_mapmeth = { + .mp_subscript = DegreeView_getitem, +}; + +static PyMethodDef DegreeView_methods[] = { + {NULL} /* Sentinel */ +}; + +PyTypeObject DegreeViewType = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "_C.DegreeView", + .tp_doc = PyDoc_STR("DegreeView object"), + .tp_basicsize = sizeof(DegreeView), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_new = DegreeView_new, + .tp_init = (initproc)DegreeView_init, + .tp_dealloc = (destructor)DegreeView_dealloc, + .tp_members = DegreeView_members, + .tp_methods = DegreeView_methods, + .tp_as_mapping = &DegreeView_mapmeth, +}; diff --git a/src/C/edge_view.c b/src/C/edge_view.c new file mode 100644 index 00000000..584426e4 --- /dev/null +++ b/src/C/edge_view.c @@ -0,0 +1,274 @@ +#include "./main.h" +#include "structmember.h" + +static void EdgeView_dealloc(EdgeView *self) { + Py_XDECREF(self->graph); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static PyObject *EdgeView_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + EdgeView *self; + self = (EdgeView *)type->tp_alloc(type, 0); + if (self != NULL) { + self->graph = NULL; + } + return (PyObject *)self; +} + +static int EdgeView_init(EdgeView *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"graph", "index", NULL}; + PyObject *graph = NULL, *tmp; + int index = -1; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|i", kwlist, &graph, &index)) + return -1; + + if (graph) { + tmp = (PyObject *)self->graph; + Py_INCREF(graph); + self->graph = (Graph *)graph; + Py_XDECREF(tmp); + self->index = index; + } + return 0; +} + +static PyMemberDef EdgeView_members[] = { + {NULL} /* Sentinel */ +}; + +static int EdgeView_setitem(PyObject *_self, PyObject *k, PyObject *v) { + EdgeView *self = (EdgeView *)_self; + if (self->index < 0) { + PyErr_SetString(PyExc_KeyError, "Cannot assign to a node"); + return -1; + } + PyObject *r = Graph_setedgeattr(self->graph, self->index, k, v); + return r == NULL ? -1 : 0; +} + +int get_edge_index_from_pos(Graph *g, int u_pos, int v_pos) { + for (int i = 0; i < g->num_edges; i++) { + if ((g->edges[2 * i] == u_pos && g->edges[2 * i + 1] == v_pos) || + (g->edges[2 * i] == v_pos && g->edges[2 * i + 1] == u_pos)) { + return i; + } + } + return -2; +} + +int get_edge_index(Graph *g, int u, int v) { + if (u > v) { + int tmp = u; + u = v; + v = tmp; + } + int u_pos = -1; + int v_pos = -1; + for (int i = 0; i < g->num_nodes; i++) { + if (g->nodes[i] == u) { + u_pos = i; + } else if (g->nodes[i] == v) { + v_pos = i; + } + if (u_pos >= 0 && v_pos >= 0) { + break; + } + } + return get_edge_index_from_pos(g, u_pos, v_pos); +} + +int get_edge_index_py(Graph *g, PyObject *k) { + if (!PyTuple_Check(k) || PyTuple_Size(k) != 2) { + PyErr_SetString(PyExc_KeyError, "EdgeView key must be a tuple of length 2"); + return -1; + } + int u = PyLong_AsLong(PyTuple_GetItem(k, 0)); + int v = PyLong_AsLong(PyTuple_GetItem(k, 1)); + return get_edge_index(g, u, v); +} + +static PyObject *EdgeView_getitem(PyObject *_self, PyObject *k) { + EdgeView *self = (EdgeView *)_self; + if (self->index < 0) { + int edge_idx = get_edge_index_py(self->graph, k); + if (edge_idx < 0) { + if (edge_idx == -2) + PyErr_SetString(PyExc_KeyError, "Edge not found"); + return NULL; + } + PyObject *idx = PyLong_FromLong(edge_idx); + PyObject *args = PyTuple_Pack(2, self->graph, idx); + PyObject *res = PyObject_CallObject((PyObject *)&EdgeViewType, args); + Py_DECREF(args); + Py_DECREF(idx); + return res; + } + return Graph_getedgeattr(self->graph, self->index, k); +} + +static int EdgeView_contains(PyObject *_self, PyObject *v) { + EdgeView *self = (EdgeView *)_self; + if (self->index < 0) { + int index = get_edge_index_py(self->graph, + v); // Returns -2 if not found, -1 on error + if (index == -1) { + return -1; // There was an error + } + return index >= 0; + } + PyObject *attr = Graph_getedgeattr(self->graph, self->index, v); + if (attr == NULL) { + PyErr_Clear(); + return 0; + } + Py_DECREF(attr); + return 1; +} +PyObject *EdgeView_iter(PyObject *_self) { + EdgeView *self = (EdgeView *)_self; + if (self->index != -1) { + PyErr_SetString(PyExc_TypeError, "A bound EdgeView is not iterable"); + return NULL; + } + Py_INCREF(_self); // We have to return a new reference, not a borrowed one + return _self; +} + +PyObject *EdgeView_iternext(PyObject *_self) { + EdgeView *self = (EdgeView *)_self; + self->index++; + if (self->index >= self->graph->num_edges) { + return NULL; + } + int u = self->graph->nodes[self->graph->edges[2 * self->index]]; + int v = self->graph->nodes[self->graph->edges[2 * self->index + 1]]; + PyObject *pu = PyLong_FromLong(u); + PyObject *pv = PyLong_FromLong(v); + PyObject *res = PyTuple_Pack(2, pu, pv); + Py_DECREF(pu); + Py_DECREF(pv); + return res; +} + +Py_ssize_t EdgeView_len(PyObject *_self) { + EdgeView *self = (EdgeView *)_self; + if (self->index != -1) { + // This is inefficient, much like a lot of this codebase... but it's in C. For our graph sizes + // it's not a big deal (and unsurprisingly, it's fast) + Py_ssize_t num_attrs = 0; + for (int i = 0; i < self->graph->num_edge_attrs; i++) { + if (self->graph->edge_attrs[3 * i] == self->index) { + num_attrs++; + } + } + return num_attrs; + } + return self->graph->num_edges; +} + +PyObject *EdgeView_get(PyObject *_self, PyObject *args) { + PyObject *key; + PyObject *default_value = Py_None; + if (!PyArg_ParseTuple(args, "O|O", &key, &default_value)) { + return NULL; + } + PyObject *res = EdgeView_getitem(_self, key); + if (res == NULL) { + PyErr_Clear(); + Py_INCREF(default_value); + return default_value; + } + return res; +} +int _EdgeView_eq(EdgeView *self, EdgeView *other) { + if (self->index == -1 || other->index == -1) { + return -1; + } + + int num_settable = ((GraphDef*)self->graph->graph_def)->num_settable_edge_attrs; + int self_values[num_settable]; + int other_values[num_settable]; + for (int i = 0; i < num_settable; i++) { + self_values[i] = -1; + other_values[i] = -1; + } + for (int i = 0; i < self->graph->num_edge_attrs; i++) { + if (self->graph->edge_attrs[3 * i] == self->index) { + self_values[self->graph->edge_attrs[3 * i + 1]] = self->graph->edge_attrs[3 * i + 2]; + } + } + for (int i = 0; i < other->graph->num_edge_attrs; i++) { + if (other->graph->edge_attrs[3 * i] == other->index) { + other_values[other->graph->edge_attrs[3 * i + 1]] = other->graph->edge_attrs[3 * i + 2]; + } + } + for (int i = 0; i < num_settable; i++) { + if (self_values[i] != other_values[i]) { + return 0; + } + } + return 1; +} + +static PyObject* +EdgeView_richcompare(PyObject *_self, PyObject *other, int op) { + PyObject *result = NULL; + EdgeView *self = (EdgeView *)_self; + + if (other == NULL || other->ob_type != &EdgeViewType) { + result = Py_NotImplemented; + } + else { + if (op == Py_EQ){ + if (self->index == -1 || ((EdgeView *)other)->index == -1){ + PyErr_SetString(PyExc_TypeError, "EdgeView.__eq__ only supports equality comparison for bound EdgeViews"); + return NULL; + } + EdgeView *other_node = (EdgeView *)other; + if (_EdgeView_eq(self, other_node) == 1){ + result = Py_True; + } else { + result = Py_False; + } + }else{ + PyErr_SetString(PyExc_TypeError, "EdgeView only supports equality comparison"); + return NULL; + } + } + + Py_XINCREF(result); + return result; +} + +static PyMappingMethods EdgeView_mapmeth = { + .mp_subscript = EdgeView_getitem, + .mp_ass_subscript = EdgeView_setitem, +}; + +static PySequenceMethods EdgeView_seqmeth = { + .sq_contains = EdgeView_contains, + .sq_length = EdgeView_len, +}; + +static PyMethodDef EdgeView_methods[] = { + {"get", (PyCFunction)EdgeView_get, METH_VARARGS, "dict-like get"}, {NULL} /* Sentinel */ +}; + +PyTypeObject EdgeViewType = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "_C.EdgeView", + .tp_doc = PyDoc_STR("Constrained EdgeView object"), + .tp_basicsize = sizeof(EdgeView), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_new = EdgeView_new, + .tp_init = (initproc)EdgeView_init, + .tp_dealloc = (destructor)EdgeView_dealloc, + .tp_members = EdgeView_members, + .tp_methods = EdgeView_methods, + .tp_as_mapping = &EdgeView_mapmeth, + .tp_as_sequence = &EdgeView_seqmeth, + .tp_iter = EdgeView_iter, + .tp_iternext = EdgeView_iternext, + .tp_richcompare = EdgeView_richcompare, +}; diff --git a/src/C/graph_def.c b/src/C/graph_def.c new file mode 100644 index 00000000..9de334ae --- /dev/null +++ b/src/C/graph_def.c @@ -0,0 +1,152 @@ +#define PY_SSIZE_T_CLEAN +#include "./main.h" +#include "structmember.h" +#include + +static void GraphDef_dealloc(GraphDef *self) { + Py_XDECREF(self->node_values); + Py_XDECREF(self->node_keypos); + Py_XDECREF(self->node_poskey); + Py_XDECREF(self->edge_values); + Py_XDECREF(self->edge_keypos); + Py_XDECREF(self->edge_poskey); + free(self->node_attr_offsets); + free(self->edge_attr_offsets); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static PyObject *GraphDef_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + GraphDef *self; + self = (GraphDef *)type->tp_alloc(type, 0); + if (self != NULL) { + self->node_values = NULL; + self->node_keypos = NULL; + self->node_poskey = NULL; + self->edge_values = NULL; + self->edge_keypos = NULL; + self->edge_poskey = NULL; + self->node_attr_offsets = NULL; + self->edge_attr_offsets = NULL; + } + return (PyObject *)self; +} + +static int GraphDef_init(GraphDef *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"node_values", "edge_values", NULL}; + PyObject *node_values = NULL, *edge_values = NULL; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", kwlist, &node_values, &edge_values)) + return -1; + + if (!(node_values && edge_values)) { + return -1; + } + if (!(PyDict_Check(node_values) && PyDict_Check(edge_values))) { + PyErr_SetString(PyExc_TypeError, "GraphDef: node_values and edge_values must be dicts"); + return -1; + } + PyObject *v_list = PyDict_GetItemString(node_values, "v"); + if (!v_list) { + PyErr_SetString(PyExc_TypeError, "GraphDef: node_values must contain 'v'"); + return -1; + } + SELF_SET(node_values); + SELF_SET(edge_values); + self->node_keypos = PyDict_New(); + // We want to create a dict, node_keypos = {k: i for i, k in enumerate(sorted(node_values.keys()))} + // First get the keys + PyObject *keys = PyDict_Keys(node_values); + // Then sort them + PyObject *sorted_keys = PySequence_List(keys); + Py_DECREF(keys); + PyList_Sort(sorted_keys); + // Then create the dict + Py_ssize_t len = PyList_Size(sorted_keys); + self->node_attr_offsets = malloc(sizeof(int) * (len + 1)); + self->node_attr_offsets[0] = 0; + // TODO: here this works because 'v' is last in the list, if it's not + // node_logit_offsets[i] != node_attr_offsets[i] - i, which we're using a lot as an assumption + // throughout this and mol_graph_to_Data + Py_ssize_t offset = 0; + for (Py_ssize_t i = 0; i < len; i++) { + PyObject *k = PyList_GetItem(sorted_keys, i); + PyDict_SetItem(self->node_keypos, k, PyLong_FromSsize_t(i)); + PyObject *vals = PyDict_GetItem(node_values, k); + if (!PyList_Check(vals)) { + PyErr_SetString(PyExc_TypeError, "GraphDef: node_values must be a Dict[str, List[Any]]"); + return -1; + } + offset += PyList_Size(vals); + self->node_attr_offsets[i + 1] = offset; + if (i == len - 1 && strcmp(PyUnicode_AsUTF8(k), "v") != 0) { + PyErr_SetString(PyExc_TypeError, "GraphDef: 'v' must be the last sorted key in current implementation"); + return -1; + } + } + self->num_node_dim = offset + 1; // + 1 for the empty graph + self->num_settable_node_attrs = len - 1; // 'v' is not settable by setnodeattr but only by addnode + self->num_new_node_values = PyList_Size(v_list); + self->num_node_attr_logits = offset - (len - 1) - self->num_new_node_values; // 'v' is not settable + self->node_poskey = sorted_keys; + + // Repeat for edge_keypos + self->edge_keypos = PyDict_New(); + keys = PyDict_Keys(edge_values); + sorted_keys = PySequence_List(keys); + Py_DECREF(keys); + PyList_Sort(sorted_keys); + len = PyList_Size(sorted_keys); + self->edge_attr_offsets = malloc(sizeof(int) * (len + 1)); + self->edge_attr_offsets[0] = 0; + offset = 0; + for (Py_ssize_t i = 0; i < len; i++) { + PyObject *k = PyList_GetItem(sorted_keys, i); + PyDict_SetItem(self->edge_keypos, k, PyLong_FromSsize_t(i)); + PyObject *vals = PyDict_GetItem(edge_values, k); + if (!PyList_Check(vals)) { + PyErr_SetString(PyExc_TypeError, "GraphDef: edge_values must be a Dict[str, List[Any]]"); + return -1; + } + offset += PyList_Size(vals); + self->edge_attr_offsets[i + 1] = offset; + } + self->num_edge_dim = offset; + self->num_settable_edge_attrs = len; + self->num_edge_attr_logits = offset - len; + self->edge_poskey = sorted_keys; + + return 0; +} + +PyObject *GraphDef___getstate__(GraphDef *self, PyObject *args) { + return PyTuple_Pack(1, Py_BuildValue("OO", self->node_values, self->edge_values)); +} + +PyObject *GraphDef___setstate__(GraphDef *self, PyObject *state) { + // new() was just called on self + GraphDef_init(self, PyTuple_GetItem(state, 0), NULL); + Py_RETURN_NONE; +} + +static PyMemberDef GraphDef_members[] = { + {NULL} /* Sentinel */ +}; + +static PyMethodDef GraphDef_methods[] = { + {"__getstate__", (PyCFunction)GraphDef___getstate__, METH_NOARGS, "Pickle the Custom object"}, + {"__setstate__", (PyCFunction)GraphDef___setstate__, METH_O, "Un-pickle the Custom object"}, + {NULL} /* Sentinel */ +}; + +PyTypeObject GraphDefType = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "gflownet._C.GraphDef", + .tp_doc = PyDoc_STR("GraphDef object"), + .tp_basicsize = sizeof(GraphDef), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_new = GraphDef_new, + .tp_init = (initproc)GraphDef_init, + .tp_dealloc = (destructor)GraphDef_dealloc, + .tp_members = GraphDef_members, + .tp_methods = GraphDef_methods, +}; diff --git a/src/C/main.c b/src/C/main.c new file mode 100644 index 00000000..0cec6a47 --- /dev/null +++ b/src/C/main.c @@ -0,0 +1,1143 @@ +#include "main.h" +#include "structmember.h" +#include + +// #define GRAPH_DEBUG +#ifdef GRAPH_DEBUG +void *__real_malloc(size_t size); +void __real_free(void *ptr); +void *__real_realloc(void *ptr, size_t size); + +long total_mem_alloc = 0; +long cnt = 0; +long things_alloc = 0; +/* build with LDFLAGS="-Wl,--wrap,malloc -Wl,--wrap,free -Wl,--wrap,realloc" python setup.py build_ext --inplace + * __wrap_malloc - malloc wrapper function + */ +void *__wrap_malloc(size_t size) { + void *ptr = __real_malloc(size + sizeof(long) * 2); + void *ptr2 = ptr + size; + *(long *)ptr2 = size; + *(long *)(ptr2 + 8) = 0x14285700deadbeef; + // printf("malloc(%d) = %p\n", size, ptr); + total_mem_alloc += size; + things_alloc++; + if (cnt++ % 10000 == 0) { + printf("total_mem_alloc: %ld [%ld]\n", total_mem_alloc, things_alloc); + } + return ptr; +} + +/* + * __wrap_free - free wrapper function + */ +void __wrap_free(void *ptr) { + if (ptr == NULL) { + return; + } + int size = -8; + void *ptr2 = ptr; + while (*(long *)ptr2 != 0x14285700deadbeef) { + ptr2++; + size++; + } + if (size != *(long *)(ptr2 - 8)) { + printf("size mismatch?: %d != %ld\n", size, *(long *)(ptr2 - 8)); + abort(); + } + __real_free(ptr); + // printf("free(%p)\n", ptr); + total_mem_alloc -= size; + things_alloc--; +} + +void *__wrap_realloc(void *ptr, size_t new_size) { + if (ptr == NULL) { + return __wrap_malloc(new_size); + } + int size = -8; + void *ptr2 = ptr; + while (*(long *)ptr2 != 0x14285700deadbeef) { + ptr2++; + size++; + } + *(long *)ptr2 = 0; // erase marker + if (size != *(long *)(ptr2 - 8)) { + printf("size mismatch?: %d != %ld\n", size, *(long *)(ptr2 - 8)); + abort(); + } + void *newptr = __real_realloc(ptr, new_size + sizeof(long) * 2); + void *ptr3 = newptr + new_size; + *(long *)ptr3 = new_size; + *(long *)(ptr3 + 8) = 0x14285700deadbeef; + total_mem_alloc -= size; + total_mem_alloc += new_size; + return newptr; +} + +#endif + +static void Graph_dealloc(Graph *self) { + Py_XDECREF(self->graph_def); + free(self->nodes); + free(self->node_attrs); + free(self->edges); + free(self->edge_attrs); + free(self->degrees); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static PyObject *Graph_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + Graph *self; + self = (Graph *)type->tp_alloc(type, 0); + if (self != NULL) { + self->graph_def = Py_None; + Py_INCREF(Py_None); + self->num_nodes = 0; + self->num_edges = 0; + self->nodes = self->edges = self->node_attrs = self->edge_attrs = self->degrees = NULL; + self->num_node_attrs = 0; + self->num_edge_attrs = 0; + } + return (PyObject *)self; +} + +static int Graph_init(Graph *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"graph_def", NULL}; + PyObject *graph_def = NULL; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &graph_def)) + return -1; + if (PyObject_TypeCheck(graph_def, &GraphDefType) == 0) { + PyErr_SetString(PyExc_TypeError, "Graph: graph_def must be a GraphDef"); + return -1; + } + if (graph_def) { + SELF_SET(graph_def); + } + return 0; +} + +static PyMemberDef Graph_members[] = { + {"graph_def", T_OBJECT_EX, offsetof(Graph, graph_def), 0, "node values"}, {NULL} /* Sentinel */ +}; + +int graph_get_node_pos(Graph *g, int node_id) { + for (int i = 0; i < g->num_nodes; i++) { + if (g->nodes[i] == node_id) { + return i; + } + } + return -1; +} + +static PyObject *Graph_add_node(Graph *self, PyObject *args, PyObject *kwds) { + PyObject *node = NULL; + if (!PyArg_ParseTuple(args, "O", &node)) + return NULL; + if (node) { + if (!PyLong_Check(node)) { + PyErr_SetString(PyExc_TypeError, "node must be an int"); + return NULL; + } + int node_id = PyLong_AsLong(node); + if (graph_get_node_pos(self, node_id) >= 0) { + PyErr_SetString(PyExc_KeyError, "node already exists"); + return NULL; + } + int node_pos = self->num_nodes; + self->num_nodes++; + self->nodes = realloc(self->nodes, self->num_nodes * sizeof(int)); + self->nodes[node_pos] = node_id; + self->degrees = realloc(self->degrees, self->num_nodes * sizeof(int)); + self->degrees[node_pos] = 0; + Py_ssize_t num_attrs; + if (kwds == NULL || (num_attrs = PyDict_Size(kwds)) == 0) { + Py_RETURN_NONE; + } + // Now we need to add the node attributes + // First check if the attributes are found in the GraphDef + GraphDef *gt = (GraphDef *)self->graph_def; + // for k in kwds: + // assert k in gt.node_values and kwds[k] in gt.node_values[k] + PyObject *key, *value; + Py_ssize_t pos = 0; + int node_attr_pos = self->num_node_attrs; + self->num_node_attrs += num_attrs; + self->node_attrs = realloc(self->node_attrs, self->num_node_attrs * 3 * sizeof(int)); + + while (PyDict_Next(kwds, &pos, &key, &value)) { + PyObject *node_values = PyDict_GetItem(gt->node_values, key); + if (node_values == NULL) { + PyErr_SetString(PyExc_KeyError, "key not found in GraphDef"); + return NULL; + } + Py_ssize_t value_idx = PySequence_Index(node_values, value); + if (value_idx == -1) { + PyObject *repr = PyObject_Repr(value); + PyObject *key_repr = PyObject_Repr(key); + PyErr_Format(PyExc_KeyError, "value %s not found in GraphDef for key %s", PyUnicode_AsUTF8(repr), + PyUnicode_AsUTF8(key_repr)); + Py_DECREF(repr); + // PyErr_SetString(PyExc_KeyError, "value not found in GraphDef"); + return NULL; + } + int attr_index = PyLong_AsLong(PyDict_GetItem(gt->node_keypos, key)); + self->node_attrs[node_attr_pos * 3] = node_pos; + self->node_attrs[node_attr_pos * 3 + 1] = attr_index; + self->node_attrs[node_attr_pos * 3 + 2] = value_idx; + node_attr_pos++; + } + } + Py_RETURN_NONE; +} + +static PyObject *Graph_add_edge(Graph *self, PyObject *args, PyObject *kwds) { + PyObject *u = NULL, *v = NULL; + if (!PyArg_ParseTuple(args, "OO", &u, &v)) + return NULL; + if (u && v) { + if (!PyLong_Check(u) || !PyLong_Check(v)) { + PyErr_SetString(PyExc_TypeError, "u, v must be ints"); + return NULL; + } + int u_id = -PyLong_AsLong(u); + int v_id = -PyLong_AsLong(v); + int u_pos = -1; + int v_pos = -1; + if (u_id < v_id) { + int tmp = u_id; + u_id = v_id; + v_id = tmp; + } + for (int i = 0; i < self->num_nodes; i++) { + if (self->nodes[i] == -u_id) { + u_id = -u_id; + u_pos = i; + } else if (self->nodes[i] == -v_id) { + v_id = -v_id; + v_pos = i; + } + if (u_pos >= 0 && v_pos >= 0) { + break; + } + } + if (u_id < 0 || v_id < 0) { + PyErr_SetString(PyExc_KeyError, "u, v must refer to existing nodes"); + return NULL; + } + if (u_id == v_id){ + PyErr_SetString(PyExc_KeyError, "u, v must be different nodes"); + return NULL; + } + for (int i = 0; i < self->num_edges; i++) { + if (self->edges[i * 2] == u_pos && self->edges[i * 2 + 1] == v_pos) { + PyErr_SetString(PyExc_KeyError, "edge already exists"); + return NULL; + } + } + self->num_edges++; + self->edges = realloc(self->edges, self->num_edges * sizeof(int) * 2); + self->edges[self->num_edges * 2 - 2] = u_pos; + self->edges[self->num_edges * 2 - 1] = v_pos; + self->degrees[u_pos]++; + self->degrees[v_pos]++; + + Py_ssize_t num_attrs; + if (kwds == NULL || (num_attrs = PyDict_Size(kwds)) == 0) { + Py_RETURN_NONE; + } + int edge_id = self->num_edges - 1; + // First check if the attributes are found in the GraphDef + GraphDef *gt = (GraphDef *)self->graph_def; + PyObject *key, *value; + Py_ssize_t pos = 0; + int edge_attr_pos = self->num_edge_attrs; + self->num_edge_attrs += num_attrs; + self->edge_attrs = realloc(self->edge_attrs, self->num_edge_attrs * 3 * sizeof(int)); + + while (PyDict_Next(kwds, &pos, &key, &value)) { + PyObject *edge_values = PyDict_GetItem(gt->edge_values, key); + if (edge_values == NULL) { + PyErr_SetString(PyExc_KeyError, "key not found in GraphDef"); + return NULL; + } + Py_ssize_t value_idx = PySequence_Index(edge_values, value); + if (value_idx == -1) { + PyErr_SetString(PyExc_KeyError, "value not found in GraphDef"); + return NULL; + } + int attr_index = PyLong_AsLong(PyDict_GetItem(gt->edge_keypos, key)); + self->edge_attrs[edge_attr_pos * 3] = edge_id; + self->edge_attrs[edge_attr_pos * 3 + 1] = attr_index; + self->edge_attrs[edge_attr_pos * 3 + 2] = value_idx; + edge_attr_pos++; + } + } + Py_RETURN_NONE; +} + +PyObject *Graph_isdirected(Graph *self, PyObject *Py_UNUSED(ignored)) { Py_RETURN_FALSE; } + +void bridge_dfs(int v, int parent, int *visited, int *tin, int *low, int *timer, int **adj, int *degrees, int *result, + Graph *g) { + visited[v] = 1; + tin[v] = low[v] = (*timer)++; + for (int i = 0; i < degrees[v]; i++) { + int to = adj[v][i]; + if (to == parent) + continue; + if (visited[to]) { + low[v] = mini(low[v], tin[to]); + } else { + bridge_dfs(to, v, visited, tin, low, timer, adj, degrees, result, g); + low[v] = mini(low[v], low[to]); + if (low[to] > tin[v]) { + result[get_edge_index_from_pos(g, v, to)] = 1; + } + } + } +} +PyObject *Graph_bridges(PyObject *_self, PyObject *args) { + Graph *self = (Graph *)_self; + // from: + // https://cp-algorithms.com/graph/bridge-searching.html + const int n = self->num_nodes; // number of nodes + int *adj[n]; + for (int i = 0; i < n; i++) { + adj[i] = malloc(self->degrees[i] * sizeof(int)); + } + int is_bridge[self->num_edges]; + for (int i = 0; i < self->num_edges; i++) { + is_bridge[i] = 0; + *(adj[self->edges[i * 2]]) = self->edges[i * 2 + 1]; + *(adj[self->edges[i * 2 + 1]]) = self->edges[i * 2]; + // increase the pointer + adj[self->edges[i * 2]]++; + adj[self->edges[i * 2 + 1]]++; + } + // reset the pointers using the degrees + for (int i = 0; i < n; i++) { + adj[i] -= self->degrees[i]; + } + int visited[n]; + int tin[n]; + int low[n]; + for (int i = 0; i < n; i++) { + visited[i] = 0; + tin[i] = -1; + low[i] = -1; + } + + int timer = 0; + for (int i = 0; i < n; ++i) { + if (!visited[i]) + bridge_dfs(i, -1, visited, tin, low, &timer, adj, self->degrees, is_bridge, self); + } + if (args == NULL) { + // This function is being called from python, so we must return a list + PyObject *result = PyList_New(0); + for (int i = 0; i < self->num_edges; i++) { + if (is_bridge[i]) { + PyObject *u = PyLong_FromLong(self->nodes[self->edges[i * 2]]); + PyObject *v = PyLong_FromLong(self->nodes[self->edges[i * 2 + 1]]); + PyObject *t = PyTuple_Pack(2, u, v); + PyList_Append(result, t); + Py_DECREF(u); + Py_DECREF(v); + Py_DECREF(t); + } + } + for (int i = 0; i < n; i++) { + free(adj[i]); + } + return result; + } + // This function is being called from C, so args is actually a int* to store the result + int *result = (int *)args; + for (int i = 0; i < self->num_edges; i++) { + result[i] = is_bridge[i]; + } + for (int i = 0; i < n; i++) { + free(adj[i]); + } + return 0; // success +} + +PyObject *Graph_copy(PyObject *_self, PyObject *unused_args) { + Graph *self = (Graph *)_self; + PyObject *args = PyTuple_Pack(1, self->graph_def); + Graph *obj = (Graph *)PyObject_CallObject((PyObject *)&GraphType, args); + Py_DECREF(args); + obj->num_nodes = self->num_nodes; + obj->num_edges = self->num_edges; + obj->nodes = malloc(self->num_nodes * sizeof(int)); + memcpy(obj->nodes, self->nodes, self->num_nodes * sizeof(int)); + obj->edges = malloc(self->num_edges * 2 * sizeof(int)); + memcpy(obj->edges, self->edges, self->num_edges * 2 * sizeof(int)); + obj->num_node_attrs = self->num_node_attrs; + obj->num_edge_attrs = self->num_edge_attrs; + obj->node_attrs = malloc(self->num_node_attrs * 3 * sizeof(int)); + memcpy(obj->node_attrs, self->node_attrs, self->num_node_attrs * 3 * sizeof(int)); + obj->edge_attrs = malloc(self->num_edge_attrs * 3 * sizeof(int)); + memcpy(obj->edge_attrs, self->edge_attrs, self->num_edge_attrs * 3 * sizeof(int)); + obj->degrees = malloc(self->num_nodes * sizeof(int)); + memcpy(obj->degrees, self->degrees, self->num_nodes * sizeof(int)); + return (PyObject *)obj; +} + +PyObject *Graph_has_edge(PyObject *_self, PyObject *args) { + int u, v; + if (!PyArg_ParseTuple(args, "ii", &u, &v)) + return NULL; + if (get_edge_index((Graph *)_self, u, v) >= 0) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + +void _Graph_check(Graph *g) { +#ifdef GRAPH_DEBUG + for (int i = 0; i < g->num_node_attrs; i++) { + if (g->node_attrs[i * 3] >= g->num_nodes) { + printf("Invalid node attr pointer %d %d\n", g->node_attrs[i * 3], g->num_nodes); + abort(); + } + } + for (int i = 0; i < g->num_edge_attrs; i++) { + if (g->edge_attrs[i * 3] >= g->num_edges) { + printf("Invalid edge attr pointer %d %d\n", g->edge_attrs[i * 3], g->num_edges); + abort(); + } + } +#endif +} + +PyObject *Graph_remove_edge(PyObject *, PyObject *); // fwd decl + +PyObject *Graph_remove_node(PyObject *_self, PyObject *args) { + Graph *self = (Graph *)_self; + int u; + if (!PyArg_ParseTuple(args, "i", &u)) + return NULL; + int pos = graph_get_node_pos(self, u); + if (pos < 0) { + PyErr_SetString(PyExc_KeyError, "node not found"); + return NULL; + } + // Check if any edge has this node + for (int i = 0; i < self->num_edges; i++) { + if (self->edges[i * 2] == pos || self->edges[i * 2 + 1] == pos) { + PyObject *u = PyLong_FromLong(self->nodes[self->edges[i * 2]]); + PyObject *v = PyLong_FromLong(self->nodes[self->edges[i * 2 + 1]]); + PyObject *rm_args = PyTuple_Pack(2, u, v); + // if so remove it + PyObject *rm_res = Graph_remove_edge(_self, rm_args); + Py_DECREF(u); + Py_DECREF(v); + Py_DECREF(rm_args); + if (rm_res == NULL) { + return NULL; + } + Py_DECREF(rm_res); + i--; + } + } + // Remove the node + // self->nodes contains node "names", so we just pop one element from the array + int *old_nodes = self->nodes; + self->nodes = malloc((self->num_nodes - 1) * sizeof(int)); + memcpy(self->nodes, old_nodes, pos * sizeof(int)); + memcpy(self->nodes + pos, old_nodes + pos + 1, (self->num_nodes - pos - 1) * sizeof(int)); + free(old_nodes); + int *old_degrees = self->degrees; + self->degrees = malloc((self->num_nodes - 1) * sizeof(int)); + memcpy(self->degrees, old_degrees, pos * sizeof(int)); + memcpy(self->degrees + pos, old_degrees + pos + 1, (self->num_nodes - pos - 1) * sizeof(int)); + free(old_degrees); + self->num_nodes--; + // Remove the node attributes + int *old_node_attrs = self->node_attrs; + int num_rem = 0; + // first find out how many attributes we need to remove + for (int i = 0; i < self->num_node_attrs; i++) { + if (self->node_attrs[i * 3] == pos) { + num_rem++; + } + } + if (num_rem) { + // now remove them + self->node_attrs = malloc((self->num_node_attrs - num_rem) * 3 * sizeof(int)); + int i, j = 0; + for (i = 0; i < self->num_node_attrs; i++) { + if (old_node_attrs[i * 3] == pos) { + continue; + } + int old_node_index = old_node_attrs[i * 3]; + self->node_attrs[j * 3] = old_node_index; + self->node_attrs[j * 3 + 1] = old_node_attrs[i * 3 + 1]; + self->node_attrs[j * 3 + 2] = old_node_attrs[i * 3 + 2]; + j++; + } + self->num_node_attrs -= num_rem; + free(old_node_attrs); + } + // since we removed the node at pos, all other node indices past that must be decremented + for (int i = 0; i < self->num_node_attrs; i++) { + if (self->node_attrs[i * 3] > pos) { + self->node_attrs[i * 3]--; + } + } + // We must also relabel the edges + for (int i = 0; i < self->num_edges; i++) { + if (self->edges[i * 2] > pos) { + self->edges[i * 2]--; + } + if (self->edges[i * 2 + 1] > pos) { + self->edges[i * 2 + 1]--; + } + } + _Graph_check(self); + Py_RETURN_NONE; +} + +PyObject *Graph_remove_edge(PyObject *_self, PyObject *args) { + Graph *self = (Graph *)_self; + int u, v; + if (!PyArg_ParseTuple(args, "ii", &u, &v)) + return NULL; + int edge_index = get_edge_index((Graph *)_self, u, v); + if (edge_index < 0) { + PyErr_SetString(PyExc_KeyError, "edge not found"); + return NULL; + } + // Decrease degree + self->degrees[self->edges[edge_index * 2]]--; + self->degrees[self->edges[edge_index * 2 + 1]]--; + // Remove the edge + int *old_edges = self->edges; + self->edges = malloc((self->num_edges - 1) * 2 * sizeof(int)); + memcpy(self->edges, old_edges, edge_index * 2 * sizeof(int)); + memcpy(self->edges + edge_index * 2, old_edges + edge_index * 2 + 2, + (self->num_edges - edge_index - 1) * 2 * sizeof(int)); + free(old_edges); + self->num_edges--; + // Remove the edge attributes + int num_rem = 0; + for (int i = 0; i < self->num_edge_attrs; i++) { + if (self->edge_attrs[i * 3] == edge_index) { + num_rem++; + } + } + if (num_rem) { + int *old_edge_attrs = self->edge_attrs; + self->edge_attrs = malloc((self->num_edge_attrs - num_rem) * 3 * sizeof(int)); + int i, j = 0; + for (i = 0; i < self->num_edge_attrs; i++) { + int old_edge_index = old_edge_attrs[i * 3]; + if (old_edge_index == edge_index) { + continue; + } + self->edge_attrs[j * 3] = old_edge_index; + self->edge_attrs[j * 3 + 1] = old_edge_attrs[i * 3 + 1]; + self->edge_attrs[j * 3 + 2] = old_edge_attrs[i * 3 + 2]; + j++; + } + + self->num_edge_attrs -= num_rem; + free(old_edge_attrs); + } + // since we removed the edge at edge_index, all other edge indices past that must be decremented + for (int i = 0; i < self->num_edge_attrs; i++) { + if (self->edge_attrs[i * 3] > edge_index) { + self->edge_attrs[i * 3]--; + } + } + _Graph_check(self); + Py_RETURN_NONE; +} + +PyObject *Graph_relabel(PyObject *_self, PyObject *args) { + PyObject *mapping = NULL; + if (!PyArg_ParseTuple(args, "O", &mapping)) + return NULL; + if (!PyDict_Check(mapping)) { + PyErr_SetString(PyExc_TypeError, "mapping must be a dict"); + return NULL; + } + Graph *new = (Graph *)Graph_copy(_self, NULL); // new ref + for (int i = 0; i < new->num_nodes; i++) { + PyObject *node = PyLong_FromLong(new->nodes[i]); + PyObject *new_node = PyDict_GetItem(mapping, node); // ref is borrowed + Py_DECREF(node); + if (new_node == NULL) { + PyErr_Format(PyExc_KeyError, "node %d not found in mapping", new->nodes[i]); + return NULL; + } + if (!PyLong_Check(new_node)) { + PyErr_SetString(PyExc_TypeError, "mapping must be a dict of ints"); + return NULL; + } + new->nodes[i] = PyLong_AsLong(new_node); + } + _Graph_check((Graph *)_self); + return (PyObject *)new; +} + +PyObject *Graph_inspect(PyObject *_self, PyObject *args) { + Graph *self = (Graph *)_self; + char buffer[4096]; + memset(buffer, 0, 4096); + int cap = 4096; + int ptr = 0; + + ptr += snprintf(buffer, cap - ptr, "Node labels:\n "); + for (int i = 0; i < self->num_nodes; i++) { + printf("%d ", self->nodes[i]); + } + ptr += snprintf(buffer + ptr, cap - ptr, "\n"); + ptr += snprintf(buffer + ptr, cap - ptr, "Edges:\n"); + for (int i = 0; i < self->num_edges; i++) { + ptr += snprintf(buffer + ptr, cap - ptr, " %d %d (%d %d)\n", self->nodes[self->edges[i * 2]], self->nodes[self->edges[i * 2 + 1]], + self->edges[i * 2], self->edges[i * 2 + 1]); + } + ptr += snprintf(buffer + ptr, cap - ptr, "Node attributes:\n"); + for (int i = 0; i < self->num_node_attrs; i++) { + ptr += snprintf(buffer + ptr, cap - ptr, " %d %d %d\n", self->node_attrs[i * 3], self->node_attrs[i * 3 + 1], self->node_attrs[i * 3 + 2]); + } + ptr += snprintf(buffer + ptr, cap - ptr, "Edge attributes:\n"); + for (int i = 0; i < self->num_edge_attrs; i++) { + ptr += snprintf(buffer + ptr, cap - ptr, " %d %d %d\n", self->edge_attrs[i * 3], self->edge_attrs[i * 3 + 1], self->edge_attrs[i * 3 + 2]); + } + ptr += snprintf(buffer + ptr, cap - ptr, "Degrees:\n "); + for (int i = 0; i < self->num_nodes; i++) { + ptr += snprintf(buffer + ptr, cap - ptr, "%d ", self->degrees[i]); + } + ptr += snprintf(buffer + ptr, cap - ptr, "\n\n"); + PyObject *res = PyUnicode_FromString(buffer); + return res; +} + +PyObject *Graph_getstate(PyObject *_self, PyObject *args) { + Graph *self = (Graph *)_self; + int num_bytes = (4 + self->num_nodes * 2 + self->num_edges * 2 + self->num_node_attrs * 3 + self->num_edge_attrs * 3) * sizeof(int); + // Create a new bytes object + PyObject *state = PyBytes_FromStringAndSize(NULL, num_bytes); + char *buf = PyBytes_AS_STRING(state); + int* ptr = (int*)buf; + ptr[0] = self->num_nodes; + ptr[1] = self->num_edges; + ptr[2] = self->num_node_attrs; + ptr[3] = self->num_edge_attrs; + memcpy(ptr + 4, self->nodes, self->num_nodes * sizeof(int)); + memcpy(ptr + 4 + self->num_nodes, self->edges, self->num_edges * 2 * sizeof(int)); + memcpy(ptr + 4 + self->num_nodes + self->num_edges * 2, self->node_attrs, self->num_node_attrs * 3 * sizeof(int)); + memcpy(ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3, + self->edge_attrs, self->num_edge_attrs * 3 * sizeof(int)); + memcpy(ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3 + self->num_edge_attrs * 3, + self->degrees, self->num_nodes * sizeof(int)); + // Return (bytes, GraphDef) + PyObject *res = PyTuple_Pack(2, state, self->graph_def); + Py_DECREF(state); + return res; +} +PyObject *Graph_setstate(PyObject *_self, PyObject *args) { + PyObject* state = PyTuple_GetItem(args, 0); + PyObject* graph_def = PyTuple_GetItem(args, 1); + if (!PyBytes_Check(state)) { + PyErr_SetString(PyExc_TypeError, "state must be a bytes object"); + return NULL; + } + if (PyObject_TypeCheck(graph_def, &GraphDefType) == 0) { + PyErr_SetString(PyExc_TypeError, "graph_def must be a GraphDef"); + return NULL; + } + Graph *self = (Graph *)_self; + Py_XDECREF(self->graph_def); + self->graph_def = graph_def; + Py_XINCREF(graph_def); + char *buf = PyBytes_AS_STRING(state); + int* ptr = (int*)buf; + self->num_nodes = ptr[0]; + self->num_edges = ptr[1]; + self->num_node_attrs = ptr[2]; + self->num_edge_attrs = ptr[3]; + self->nodes = malloc(self->num_nodes * sizeof(int)); + self->edges = malloc(self->num_edges * 2 * sizeof(int)); + self->node_attrs = malloc(self->num_node_attrs * 3 * sizeof(int)); + self->edge_attrs = malloc(self->num_edge_attrs * 3 * sizeof(int)); + self->degrees = malloc(self->num_nodes * sizeof(int)); + memcpy(self->nodes, ptr + 4, self->num_nodes * sizeof(int)); + memcpy(self->edges, ptr + 4 + self->num_nodes, self->num_edges * 2 * sizeof(int)); + memcpy(self->node_attrs, ptr + 4 + self->num_nodes + self->num_edges * 2, self->num_node_attrs * 3 * sizeof(int)); + memcpy(self->edge_attrs, ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3, + self->num_edge_attrs * 3 * sizeof(int)); + memcpy(self->degrees, ptr + 4 + self->num_nodes + self->num_edges * 2 + self->num_node_attrs * 3 + self->num_edge_attrs * 3, + self->num_nodes * sizeof(int)); + Py_RETURN_NONE; +} + + +int _NodeView_eq(NodeView *self, NodeView *other); +int _EdgeView_eq(EdgeView *self, EdgeView *other); + +inline int _fast_Node_eq(int i, int j, int* nodeval_cache, int N, int M){ + for (int k = 0; k < M; k++){ + //printf("nodeeq i=%d j=%d k=%d M=%d: %d == %d\n", i, j, k, M, nodeval_cache[i * M + k], nodeval_cache[(j + N) * M + k]); + if (nodeval_cache[i * M + k] != nodeval_cache[(j + N) * M + k]) + return 0; + } + return 1; +} + +int _Graph_iso_search(Graph* a, Graph* b, int* assignment, int depth, int* nodeval_cache, int* a_edgeidx_cache){ + /* Python version: +def search(graph, subgraph, assignments, depth): + # Make sure that every edge between assigned vertices in the subgraph is also an + # edge in the graph. + for u, v in subgraph.edges: + if u < depth and v < depth: + x, y = assignments[u], assignments[v] + if not ( + graph.has_edge(x, y) and + graph.nodes[x] == subgraph.nodes[u] and graph.nodes[y] == subgraph.nodes[v] and + graph.edges[x, y] == subgraph.edges[u, v]): + return False + + # If all the vertices in the subgraph are assigned, then we are done. + if depth == len(subgraph): + return True + + for j in range(len(subgraph)): #possible_assignments[i]: + if assignments[depth] == -1: + assignments[depth] = j + if search(graph, subgraph, assignments, depth+1): + return True + + assignments[depth] = -1 */ + /* + for (int i=0;inum_nodes*2;i++){ + if (i == a->num_nodes){ + printf("|| "); + } + printf("%d ", assignment[i]); + } + puts("");*/ + for (int i = 0; i < b->num_edges; i++){ + int u = b->edges[i * 2]; + int v = b->edges[i * 2 + 1]; + if (u < depth && v < depth){ + int x = assignment[u]; + int y = assignment[v]; + int x_y = a_edgeidx_cache[x * a->num_nodes + y]; // get_edge_index(a, a->nodes[x], a->nodes[y]); + if (x_y < 0) + return 0; + //NodeView nv_x = {.graph = a, .index = x}; + //NodeView nv_y = {.graph = a, .index = y}; + //NodeView nv_u = {.graph = b, .index = u}; + //NodeView nv_v = {.graph = b, .index = v}; + //if (!_NodeView_eq(&nv_x, &nv_u) || !_NodeView_eq(&nv_y, &nv_v)) + if (!_fast_Node_eq(x, u, nodeval_cache, a->num_nodes, 1 + ((GraphDef*)a->graph_def)->num_settable_node_attrs) || + !_fast_Node_eq(y, v, nodeval_cache, a->num_nodes, 1 + ((GraphDef*)a->graph_def)->num_settable_node_attrs)) + return 0; + EdgeView ev_x_y = {.graph = a, .index = x_y}; + EdgeView ev_u_v = {.graph = b, .index = i}; + if (!_EdgeView_eq(&ev_x_y, &ev_u_v)) + return 0; + } + } + + if (depth == b->num_nodes){ + // Found a valid assignment + return 1; + } + for (int j = 0; j < b->num_nodes; j++){ + //printf("Trying %d = %d\n", depth, j); + if (assignment[b->num_nodes + j] != -1) + continue; + //NodeView nv_a = {.graph = a, .index = j}; + //NodeView nv_b = {.graph = b, .index = depth}; + if (a->degrees[j] == b->degrees[depth]){ + if (!_fast_Node_eq(j, depth, nodeval_cache, a->num_nodes, 1 + ((GraphDef*)a->graph_def)->num_settable_node_attrs)){ + continue; + } + /*if (!_NodeView_eq(&nv_a, &nv_b)){ + continue; + }*/ + assignment[depth] = j; + assignment[j + b->num_nodes] = depth; + if (_Graph_iso_search(a, b, assignment, depth + 1, nodeval_cache, a_edgeidx_cache)){ + return 1; + } + assignment[depth] = -1; + assignment[j + b->num_nodes] = -1; + } + } + return 0; +} + +PyObject *Graph_is_isomorphic(PyObject *_self, PyObject *args) { + Graph *self = (Graph *)_self; + PyObject *other = NULL; + if (!PyArg_ParseTuple(args, "O", &other)) + return NULL; + if (PyObject_TypeCheck(other, &GraphType) == 0) { + PyErr_SetString(PyExc_TypeError, "other must be a Graph"); + return NULL; + } + Graph *other_graph = (Graph *)other; + if (self->graph_def != other_graph->graph_def) { + Py_RETURN_FALSE; + } + if (self->num_nodes != other_graph->num_nodes || self->num_edges != other_graph->num_edges) { + Py_RETURN_FALSE; + } + int assignments[self->num_nodes * 2]; + for (int i = 0; i < self->num_nodes * 2; i++) { + assignments[i] = -1; + } + int num_attrs = ((GraphDef *)self->graph_def)->num_settable_node_attrs + 1; + int nodeval_cache[self->num_nodes * num_attrs * 2]; + memset(nodeval_cache, 0, self->num_nodes * num_attrs * 2 * sizeof(int)); + for (int j = 0; j < self->num_node_attrs; j++){ + int node_attr_index = self->node_attrs[j * 3]; + int attr_index = self->node_attrs[j * 3 + 1]; + int value_index = self->node_attrs[j * 3 + 2]; + nodeval_cache[node_attr_index * num_attrs + attr_index] = value_index; + } + for (int j = 0; j < other_graph->num_node_attrs; j++){ + int node_attr_index = other_graph->node_attrs[j * 3]; + int attr_index = other_graph->node_attrs[j * 3 + 1]; + int value_index = other_graph->node_attrs[j * 3 + 2]; + nodeval_cache[(node_attr_index + self->num_nodes) * num_attrs + attr_index] = value_index; + } + // Q: Is something wrong with the above? Are we setting the values correctly? + // A: Yes, we are setting the values correctly. The nodeval_cache is a 3D array, where the first dimension is the node index, the second dimension is the attribute index, and the third dimension is the value index. + // The first num_nodes * num_attrs entries are for the first graph, and the next num_nodes * num_attrs entries are for the second graph. + + int a_edgeidx_cache[self->num_nodes * self->num_nodes]; + memset(a_edgeidx_cache, -1, self->num_nodes * self->num_nodes * sizeof(int)); + for (int i = 0; i < self->num_edges; i++){ + a_edgeidx_cache[self->edges[i * 2] * self->num_nodes + self->edges[i * 2 + 1]] = i; + } + if (_Graph_iso_search(self, other_graph, assignments, 0, nodeval_cache, a_edgeidx_cache)){ + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + + +static PyMethodDef Graph_methods[] = { + {"add_node", (PyCFunction)Graph_add_node, METH_VARARGS | METH_KEYWORDS, "Add a node"}, + {"add_edge", (PyCFunction)Graph_add_edge, METH_VARARGS | METH_KEYWORDS, "Add an edge"}, + {"has_edge", (PyCFunction)Graph_has_edge, METH_VARARGS, "Check if an edge is present"}, + {"remove_node", (PyCFunction)Graph_remove_node, METH_VARARGS, "Remove a node"}, + {"remove_edge", (PyCFunction)Graph_remove_edge, METH_VARARGS, "Remove an edge"}, + {"is_directed", (PyCFunction)Graph_isdirected, METH_NOARGS, "Is the graph directed?"}, + {"is_multigraph", (PyCFunction)Graph_isdirected, METH_NOARGS, "Is the graph a multigraph?"}, + {"bridges", (PyCFunction)Graph_bridges, METH_NOARGS, "Find the bridges of the graph"}, + {"copy", (PyCFunction)Graph_copy, METH_VARARGS, "Copy the graph"}, + {"relabel_nodes", (PyCFunction)Graph_relabel, METH_VARARGS, "relabel the graph"}, + {"is_isomorphic", (PyCFunction)Graph_is_isomorphic, METH_VARARGS, "isomorphism test"}, + {"__deepcopy__", (PyCFunction)Graph_copy, METH_VARARGS, "Copy the graph"}, + {"__getstate__", (PyCFunction)Graph_getstate, METH_NOARGS, "pickling"}, + {"__setstate__", (PyCFunction)Graph_setstate, METH_O, "unpickling"}, + {"_inspect", (PyCFunction)Graph_inspect, METH_VARARGS, "Print the graph's information"}, + {NULL} /* Sentinel */ +}; + +static PyObject *Graph_getnodes(Graph *self, void *closure) { + // Return a new NodeView + PyObject *args = PyTuple_Pack(1, self); + PyObject *obj = PyObject_CallObject((PyObject *)&NodeViewType, args); // new ref + Py_DECREF(args); + return obj; +} + +static PyObject *Graph_getedges(Graph *self, void *closure) { + // Return a new EdgeView + PyObject *args = PyTuple_Pack(1, self); + PyObject *obj = PyObject_CallObject((PyObject *)&EdgeViewType, args); + Py_DECREF(args); + return obj; +} + +static PyObject *Graph_getdegree(Graph *self, void *closure) { + PyObject *args = PyTuple_Pack(1, self); + PyObject *obj = PyObject_CallObject((PyObject *)&DegreeViewType, args); + Py_DECREF(args); + return obj; +} + +static PyGetSetDef Graph_getsetters[] = { + {"nodes", (getter)Graph_getnodes, NULL, "nodes", NULL}, + {"edges", (getter)Graph_getedges, NULL, "edges", NULL}, + {"degree", (getter)Graph_getdegree, NULL, "degree", NULL}, + {NULL} /* Sentinel */ +}; + +static PyObject *Graph_contains(Graph *self, PyObject *v) { + if (!PyLong_Check(v)) { + PyErr_SetString(PyExc_TypeError, "Graph.__contains__ only accepts integers"); + return NULL; + } + int node_id = PyLong_AsLong(v); + int pos = graph_get_node_pos(self, node_id); + if (pos >= 0) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + +static Py_ssize_t Graph_len(Graph *self) { return self->num_nodes; } + +static PySequenceMethods Graph_seqmeth = { + .sq_contains = (objobjproc)Graph_contains, + .sq_length = (lenfunc)Graph_len, +}; + +static PyObject *Graph_iter(PyObject *self) { + PyObject *args = PyTuple_Pack(1, self); + PyObject *obj = PyObject_CallObject((PyObject *)&NodeViewType, args); // new ref + Py_DECREF(args); + return obj; +} + +static PyObject *Graph_getitem(PyObject *self, PyObject *key) { + PyObject *args = PyTuple_Pack(2, self, key); + PyObject *obj = PyObject_CallObject((PyObject *)&NodeViewType, args); // new ref + Py_DECREF(args); + return obj; +} +static PyMappingMethods Graph_mapmeth = { + .mp_subscript = Graph_getitem, +}; + +PyTypeObject GraphType = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "gflownet._C.Graph", + .tp_doc = PyDoc_STR("Constrained Graph object"), + .tp_basicsize = sizeof(Graph), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_new = Graph_new, + .tp_init = (initproc)Graph_init, + .tp_dealloc = (destructor)Graph_dealloc, + .tp_members = Graph_members, + .tp_methods = Graph_methods, + .tp_iter = Graph_iter, + .tp_getset = Graph_getsetters, + .tp_as_sequence = &Graph_seqmeth, + .tp_as_mapping = &Graph_mapmeth, +}; + +PyObject *Graph_getnodeattr(Graph *self, int index, PyObject *k) { + GraphDef *gt = (GraphDef *)self->graph_def; + PyObject *value_list = PyDict_GetItem(gt->node_values, k); // borrowed ref + if (value_list == NULL) { + PyErr_SetString(PyExc_KeyError, "key not found"); + return NULL; + } + long attr_index = PyLong_AsLong(PyDict_GetItem(gt->node_keypos, k)); + int true_node_index = -1; + for (int i = 0; i < self->num_nodes; i++) { + if (self->nodes[i] == index) { + true_node_index = i; + break; + } + } + if (true_node_index == -1) { + PyErr_SetString(PyExc_KeyError, "node not found"); + return NULL; + } + + for (int i = 0; i < self->num_node_attrs; i++) { + if (self->node_attrs[i * 3] == true_node_index && self->node_attrs[i * 3 + 1] == attr_index) { + // borrowed ref so we have to increase its refcnt because we are returning it + PyObject *r = PyList_GetItem(value_list, self->node_attrs[i * 3 + 2]); + Py_INCREF(r); + return r; + } + } + PyErr_SetString(PyExc_KeyError, "attribute not set for this node"); + return NULL; +} + +PyObject *Graph_setnodeattr(Graph *self, int index, PyObject *k, PyObject *v) { + GraphDef *gt = (GraphDef *)self->graph_def; + int true_node_index = -1; + for (int i = 0; i < self->num_nodes; i++) { + if (self->nodes[i] == index) { + true_node_index = i; + break; + } + } + if (true_node_index == -1) { + PyErr_SetString(PyExc_KeyError, "node not found"); + return NULL; + } + PyObject *node_values = PyDict_GetItem(gt->node_values, k); + if (node_values == NULL) { + PyErr_SetString(PyExc_KeyError, "key not found in GraphDef"); + return NULL; + } + if (v == NULL) { + // this means we have to delete g.nodes[index][k] + int attr_index = PyLong_AsLong(PyDict_GetItem(gt->node_keypos, k)); + for (int i = 0; i < self->num_node_attrs; i++) { + if (self->node_attrs[i * 3] == true_node_index && self->node_attrs[i * 3 + 1] == attr_index) { + // found the attribute, remove it + int *old_node_attrs = self->node_attrs; + self->node_attrs = malloc((self->num_node_attrs - 1) * 3 * sizeof(int)); + memcpy(self->node_attrs, old_node_attrs, i * 3 * sizeof(int)); + memcpy(self->node_attrs + i * 3, old_node_attrs + (i + 1) * 3, + (self->num_node_attrs - i - 1) * 3 * sizeof(int)); + self->num_node_attrs--; + free(old_node_attrs); + Py_RETURN_NONE; + } + } + PyErr_SetString(PyExc_KeyError, "trying to delete a key that's not set"); + return NULL; + } + Py_ssize_t value_idx = PySequence_Index(node_values, v); + if (value_idx == -1) { + PyErr_SetString(PyExc_KeyError, "value not found in GraphDef"); + return NULL; + } + int attr_index = PyLong_AsLong(PyDict_GetItem(gt->node_keypos, k)); + for (int i = 0; i < self->num_node_attrs; i++) { + if (self->node_attrs[i * 3] == true_node_index && self->node_attrs[i * 3 + 1] == attr_index) { + self->node_attrs[i * 3 + 2] = value_idx; + Py_RETURN_NONE; + } + } + // Could not find the attribute, add it + int new_idx = self->num_node_attrs; + self->num_node_attrs++; + self->node_attrs = realloc(self->node_attrs, self->num_node_attrs * 3 * sizeof(int)); + self->node_attrs[new_idx * 3] = true_node_index; + self->node_attrs[new_idx * 3 + 1] = attr_index; + self->node_attrs[new_idx * 3 + 2] = value_idx; + Py_RETURN_NONE; +} + +PyObject *Graph_getedgeattr(Graph *self, int index, PyObject *k) { + GraphDef *gt = (GraphDef *)self->graph_def; + PyObject *value_list = PyDict_GetItem(gt->edge_values, k); // borrowed ref + if (value_list == NULL) { + PyErr_SetString(PyExc_KeyError, "key not found"); + return NULL; + } + long attr_index = PyLong_AsLong(PyDict_GetItem(gt->edge_keypos, k)); + if (index > self->num_edges) { + // Should never happen, index is computed by us in EdgeView_getitem, not + // by the user! + PyErr_SetString(PyExc_KeyError, "edge not found [but this should never happen!]"); + return NULL; + } + + for (int i = 0; i < self->num_edge_attrs; i++) { + if (self->edge_attrs[i * 3] == index && self->edge_attrs[i * 3 + 1] == attr_index) { + // borrowed ref so we have to increase its refcnt because we are returning it + PyObject *r = PyList_GetItem(value_list, self->edge_attrs[i * 3 + 2]); + Py_INCREF(r); + return r; + } + } + PyErr_SetString(PyExc_KeyError, "attribute not set for this node"); + return NULL; +} +PyObject *Graph_setedgeattr(Graph *self, int index, PyObject *k, PyObject *v) { + GraphDef *gt = (GraphDef *)self->graph_def; + PyObject *edge_values = PyDict_GetItem(gt->edge_values, k); + if (edge_values == NULL) { + PyErr_SetString(PyExc_KeyError, "key not found in GraphDef"); + return NULL; + } + if (v == 0) { + // this means we have to delete g.edges[index][k] + int attr_index = PyLong_AsLong(PyDict_GetItem(gt->edge_keypos, k)); + for (int i = 0; i < self->num_edge_attrs; i++) { + if (self->edge_attrs[i * 3] == index && self->edge_attrs[i * 3 + 1] == attr_index) { + // found the attribute, remove it + int *old_edge_attrs = self->edge_attrs; + self->edge_attrs = malloc((self->num_edge_attrs - 1) * 3 * sizeof(int)); + memcpy(self->edge_attrs, old_edge_attrs, i * 3 * sizeof(int)); + memcpy(self->edge_attrs + i * 3, old_edge_attrs + (i + 1) * 3, + (self->num_edge_attrs - i - 1) * 3 * sizeof(int)); + self->num_edge_attrs--; + free(old_edge_attrs); + Py_RETURN_NONE; + } + } + PyErr_SetString(PyExc_KeyError, "trying to delete a key that's not set"); + return NULL; + } + Py_ssize_t value_idx = PySequence_Index(edge_values, v); + if (value_idx == -1) { + PyErr_SetString(PyExc_KeyError, "value not found in GraphDef"); + return NULL; + } + int attr_index = PyLong_AsLong(PyDict_GetItem(gt->edge_keypos, k)); + for (int i = 0; i < self->num_edge_attrs; i++) { + if (self->edge_attrs[i * 3] == index && self->edge_attrs[i * 3 + 1] == attr_index) { + self->edge_attrs[i * 3 + 2] = value_idx; + Py_RETURN_NONE; + } + } + // Could not find the attribute, add it + int new_idx = self->num_edge_attrs; + self->num_edge_attrs++; + self->edge_attrs = realloc(self->edge_attrs, self->num_edge_attrs * 3 * sizeof(int)); + self->edge_attrs[new_idx * 3] = index; + self->edge_attrs[new_idx * 3 + 1] = attr_index; + self->edge_attrs[new_idx * 3 + 2] = value_idx; + Py_RETURN_NONE; +} + +static PyMethodDef SpamMethods[] = { + //{"print_count", print_count, METH_VARARGS, "Execute a shell command."}, + {"mol_graph_to_Data", mol_graph_to_Data, METH_VARARGS, "Convert a mol_graph to a Data object."}, + {"Data_collate", Data_collate, METH_VARARGS, "collate Data instances"}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +static struct PyModuleDef spammodule = {PyModuleDef_HEAD_INIT, "gflownet._C", /* name of module */ + "doc", /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + SpamMethods}; +PyObject *SpamError; + +PyMODINIT_FUNC PyInit__C(void) { + PyObject *m; + + m = PyModule_Create(&spammodule); + if (m == NULL) + return NULL; + + SpamError = PyErr_NewException("gflownet._C.error", NULL, NULL); + Py_XINCREF(SpamError); + if (PyModule_AddObject(m, "error", SpamError) < 0) { + Py_XDECREF(SpamError); + Py_CLEAR(SpamError); + Py_DECREF(m); + return NULL; + } + PyTypeObject *types[] = {&GraphType, &GraphDefType, &NodeViewType, &EdgeViewType, &DegreeViewType, &DataType}; + char *names[] = {"Graph", "GraphDef", "NodeView", "EdgeView", "DegreeView", "Data"}; + for (int i = 0; i < (int)(sizeof(types) / sizeof(PyTypeObject *)); i++) { + if (PyType_Ready(types[i]) < 0) { + Py_DECREF(m); + return NULL; + } + Py_XINCREF(types[i]); + if (PyModule_AddObject(m, names[i], (PyObject *)types[i]) < 0) { + Py_XDECREF(types[i]); + Py_DECREF(m); + return NULL; + } + } + + return m; +} diff --git a/src/C/main.h b/src/C/main.h new file mode 100644 index 00000000..5e45d737 --- /dev/null +++ b/src/C/main.h @@ -0,0 +1,115 @@ +#define PY_SSIZE_T_CLEAN +#include + +extern PyObject *SpamError; + +typedef struct { + PyObject_HEAD; + PyObject *node_values; /* Dict[str, List[Any]] */ + PyObject *edge_values; /* Dict[str, List[Any]] */ + PyObject *node_keypos; /* Dict[str, int] */ + PyObject *edge_keypos; /* Dict[str, int] */ + PyObject *node_poskey; /* List[str] */ + PyObject *edge_poskey; /* List[str] */ + int *node_attr_offsets; /* List[int] */ + int *edge_attr_offsets; /* List[int] */ + int num_node_dim; + int num_settable_node_attrs; + int num_node_attr_logits; + int num_new_node_values; + int num_edge_dim; + int num_settable_edge_attrs; + int num_edge_attr_logits; +} GraphDef; + +extern PyTypeObject GraphDefType; + +typedef struct { + PyObject_HEAD; + PyObject *graph_def; + int num_nodes; + int num_edges; + int num_node_attrs; + int num_edge_attrs; + int *nodes; /* List[int] */ + int *edges; /* List[Tuple[int, int]] */ + int *node_attrs; /* List[Tuple[nodeid, attrid, attrvalueidx]] */ + int *edge_attrs; /* List[Tuple[edgeid, attrid, attrvalueidx]] */ + int *degrees; /* List[int] */ +} Graph; + +extern PyTypeObject GraphType; + +typedef struct NodeView { + PyObject_HEAD; + Graph *graph; + int index; +} NodeView; + +extern PyTypeObject NodeViewType; + +typedef struct EdgeView { + PyObject_HEAD; + Graph *graph; + int index; +} EdgeView; + +extern PyTypeObject EdgeViewType; + +typedef struct DegreeView { + PyObject_HEAD; + Graph *graph; +} DegreeView; + +extern PyTypeObject DegreeViewType; + +/* This is a specialized Data instance that only knows how to hold onto 2d tensors + that are either float32 or int64 (for reasons of compatibility with torch_geometric) */ +typedef struct Data { + PyObject_HEAD; + PyObject *bytes; + GraphDef *graph_def; + int *shapes; + int *is_float; + int *offsets; // in bytes + int num_matrices; + const char **names; +} Data; + +extern PyTypeObject DataType; + +#define SELF_SET(v) \ + { \ + PyObject *tmp = self->v; \ + Py_INCREF(v); \ + self->v = v; \ + Py_XDECREF(tmp); \ + } + +// when we want to do f(x) where x is a new reference that we want to decref after the call +// e.g. PyLong_AsLong(PyObject_GetAttrString(x, "id")), GetAttrString returns a new reference +#define borrow_new_and_call(f, x) \ + ({ \ + PyObject *tmp = x; \ + __auto_type rval = f(tmp); \ + Py_DECREF(tmp); \ + rval; \ + }) + +static inline int maxi(int a, int b) { return a > b ? a : b; } +static inline int mini(int a, int b) { return a < b ? a : b; } + +PyObject *Graph_getnodeattr(Graph *self, int index, PyObject *k); +PyObject *Graph_setnodeattr(Graph *self, int index, PyObject *k, PyObject *v); +PyObject *Graph_getedgeattr(Graph *self, int index, PyObject *k); +PyObject *Graph_setedgeattr(Graph *self, int index, PyObject *k, PyObject *v); +PyObject *Graph_bridges(PyObject *self, PyObject *args); + +int get_edge_index(Graph *g, int u, int v); +int get_edge_index_from_pos(Graph *g, int u_pos, int v_pos); + +PyObject *Data_collate(PyObject *self, PyObject *args); +void Data_init_C(Data *self, PyObject *bytes, PyObject *graph_def, int shapes[][2], int *is_float, int num_matrices, + const char **names); + +PyObject *mol_graph_to_Data(PyObject *self, PyObject *args); diff --git a/src/C/mol_graph_to_Data.c b/src/C/mol_graph_to_Data.c new file mode 100644 index 00000000..574e8aa3 --- /dev/null +++ b/src/C/mol_graph_to_Data.c @@ -0,0 +1,428 @@ +#include "main.h" +#include + +void memsetf(float *ptr, float value, size_t num) { + for (size_t i = 0; i < num; i++) { + ptr[i] = value; + } +} + +#define RETURN_C_DATA 1 +static const char *mol_Data_names[] = {"x", + "edge_index", + "edge_attr", + "non_edge_index", + "stop_mask", + "add_node_mask", + "set_node_attr_mask", + "add_edge_mask", + "set_edge_attr_mask", + "remove_node_mask", + "remove_node_attr_mask", + "remove_edge_mask", + "remove_edge_attr_mask", + "cond_info"}; + +static PyObject *torch_module = NULL; +static PyObject *torch_gd_module = NULL; +void _check_torch() { + if (torch_module == NULL) { + torch_module = PyImport_ImportModule("torch"); + torch_gd_module = PyImport_ImportModule("torch_geometric.data"); + } +} + +PyObject *mol_graph_to_Data(PyObject *self, PyObject *args) { + PyObject *mol_graph = NULL, *ctx = NULL, *_torch_module = NULL, *cond_info = NULL; + if (PyArg_ParseTuple(args, "OOO|O", &mol_graph, &ctx, &_torch_module, &cond_info) == 0) { + return NULL; + } + _check_torch(); + if (PyObject_TypeCheck(mol_graph, &GraphType) == 0) { + PyErr_SetString(PyExc_TypeError, "mol_graph must be a Graph"); + return NULL; + } + if (cond_info == Py_None) { + cond_info = NULL; + } + Graph *g = (Graph *)mol_graph; + GraphDef *gd = (GraphDef *)g->graph_def; + PyObject *atom_types = PyDict_GetItemString(gd->node_values, "v"); // borrowed ref + int num_atom_types = PySequence_Size(atom_types); + int atom_valences[num_atom_types]; + float used_valences[g->num_nodes]; + int max_valence[g->num_nodes]; + int is_hypervalent[g->num_nodes]; + PyObject *_max_atom_valence = PyObject_GetAttrString(ctx, "_max_atom_valence"); // new ref + for (int i = 0; i < num_atom_types; i++) { + atom_valences[i] = PyLong_AsLong(PyDict_GetItem(_max_atom_valence, PyList_GetItem(atom_types, i))); // borrowed + } + Py_DECREF(_max_atom_valence); + int v_val[g->num_nodes]; + int v_idx = PyLong_AsLong(PyDict_GetItemString(gd->node_keypos, "v")); + int charge_val[g->num_nodes]; + int charge_idx = PyLong_AsLong(PyDict_GetItemString(gd->node_keypos, "charge")); + int explH_val[g->num_nodes]; + int explH_idx = PyLong_AsLong(PyDict_GetItemString(gd->node_keypos, "expl_H")); + int chi_val[g->num_nodes]; + int chi_idx = PyLong_AsLong(PyDict_GetItemString(gd->node_keypos, "chi")); + int noImpl_val[g->num_nodes]; + int noImpl_idx = PyLong_AsLong(PyDict_GetItemString(gd->node_keypos, "no_impl")); + int num_set_attribs[g->num_nodes]; + PyObject *N_str = PyUnicode_FromString("N"); + Py_ssize_t nitro_attr_value = PySequence_Index(PyDict_GetItemString(gd->node_values, "v"), N_str); + Py_DECREF(N_str); + PyObject *O_str = PyUnicode_FromString("O"); + Py_ssize_t oxy_attr_value = PySequence_Index(PyDict_GetItemString(gd->node_values, "v"), O_str); + Py_DECREF(O_str); + for (int i = 0; i < g->num_nodes; i++) { + used_valences[i] = max_valence[i] = v_val[i] = charge_val[i] = explH_val[i] = chi_val[i] = noImpl_val[i] = + num_set_attribs[i] = is_hypervalent[i] = 0; + } + for (int i = 0; i < g->num_node_attrs; i++) { + int node_pos = g->node_attrs[3 * i]; + int attr_type = g->node_attrs[3 * i + 1]; + int attr_value = g->node_attrs[3 * i + 2]; + num_set_attribs[node_pos] += 1; + if (attr_type == v_idx) { + v_val[node_pos] = attr_value; + max_valence[node_pos] += atom_valences[v_val[node_pos]]; + } else if (attr_type == charge_idx) { + charge_val[node_pos] = attr_value; + } else if (attr_type == explH_idx) { + explH_val[node_pos] = attr_value; + max_valence[node_pos] -= attr_value; + } else if (attr_type == chi_idx) { + chi_val[node_pos] = attr_value; + } else if (attr_type == noImpl_idx) { + noImpl_val[node_pos] = attr_value; + } + } + + // bonds are the only edge attributes + char has_connecting_edge_attr_set[g->num_nodes]; + memset(has_connecting_edge_attr_set, 0, g->num_nodes); + float bond_valence[] = {1, 2, 3, 1.5}; // single, double, triple, aromatic + int bond_val[g->num_edges]; + memset(bond_val, 0, g->num_edges * sizeof(int)); + for (int i = 0; i < g->num_edge_attrs; i++) { + int edge_pos = g->edge_attrs[3 * i]; + int attr_type = g->edge_attrs[3 * i + 1]; + int attr_value = g->edge_attrs[3 * i + 2]; + if (attr_type == 0) { // this should always be true, but whatever + bond_val[edge_pos] = attr_value; + } + has_connecting_edge_attr_set[g->edges[2 * edge_pos]] = 1; + has_connecting_edge_attr_set[g->edges[2 * edge_pos + 1]] = 1; + } + for (int i = 0; i < g->num_edges; i++) { + int u = g->edges[2 * i]; + int v = g->edges[2 * i + 1]; + used_valences[u] += bond_valence[bond_val[i]]; + used_valences[v] += bond_valence[bond_val[i]]; + } + + int is_any_hypervalent = 0; + // Correct for the valence of charge Nitro and Oxygen atoms + for (int i = 0; i < g->num_nodes; i++) { + if (charge_val[i] != 0){ + if ((v_val[i] == nitro_attr_value || v_val[i] == oxy_attr_value) + && charge_val[i] == 1) { + max_valence[i] += 1; + }else{ + max_valence[i] -= 1; + } + } + + /* TODO: Figure out this charge thing... + wait or is the problem that I'm _removing_ valence for a + charged mol?? + else if (charge_val[i] == 1) { + // well... maybe this is true of the general case? + // if an atom is positively charged then we can just bond it to more stuff? + // why 2 for N? + max_valence[i] += 1; + }*/ + // Determine hypervalent atoms + // Hypervalent atoms are atoms that have more than the maximum valence for their atom type because of a positive + // charge, but don't have this extra charge filled with bonds, or don't have no_impl set to 1 + /*printf("%d used_valences: %f, max_valence: %d, charge_val: %d, noImpl_val: %d, atom_valences: %d -> %d\n", i, + used_valences[i], max_valence[i], charge_val[i], noImpl_val[i], atom_valences[v_val[i]], + (used_valences[i] >= atom_valences[v_val[i]] && charge_val[i] == 1 && + !(used_valences[i] == max_valence[i] || noImpl_val[i] == 1))); + fflush(stdout);*/ + /*if (used_valences[i] >= atom_valences[v_val[i]] && charge_val[i] == 1 && + !(used_valences[i] == max_valence[i] || noImpl_val[i] == 1)) { + is_hypervalent[i] = 1; + is_any_hypervalent = 1; + }*/ + } + // Compute creatable edges + char can_create_edge[g->num_nodes][g->num_nodes]; + for (int i = 0; i < g->num_nodes; i++) { + for (int j = 0; j < g->num_nodes; j++) { + can_create_edge[i][j] = + (used_valences[i] + 1 <= max_valence[i]) && (used_valences[j] + 1 <= max_valence[j]); + } + } + for (int i = 0; i < g->num_edges; i++) { + int u = g->edges[2 * i]; + int v = g->edges[2 * i + 1]; + can_create_edge[u][v] = can_create_edge[v][u] = 0; + } + int num_creatable_edges = 0; + for (int i = 0; i < g->num_nodes; i++) { + for (int j = i + 1; j < g->num_nodes; j++) { + num_creatable_edges += can_create_edge[i][j]; + } + } + PyObject *max_nodes_py = PyObject_GetAttrString(ctx, "max_nodes"); // new ref + int max_nodes = max_nodes_py == Py_None ? 1000 : PyLong_AsLong(max_nodes_py); + Py_DECREF(max_nodes_py); + + PyObject *cond_info_shape = cond_info == NULL ? NULL : PyObject_CallMethod(cond_info, "size", NULL); + int num_cond_info_dims = cond_info == NULL ? 0 : PyLong_AsLong(PyTuple_GetItem(cond_info_shape, 1)); + Py_XDECREF(cond_info_shape); + int node_feat_shape = maxi(1, g->num_nodes); + int is_float[] = {1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + int shapes[14][2] = { + {node_feat_shape, gd->num_node_dim}, // node_feat + {2, g->num_edges * 2}, // edge_index + {2 * g->num_edges, gd->num_edge_dim}, // edge_feat + {2, num_creatable_edges}, // non_edge_index + {1, 1}, // stop_mask + {node_feat_shape, gd->num_new_node_values}, // add_node_mask + {node_feat_shape, gd->num_node_attr_logits}, // set_node_attr_mask + {num_creatable_edges, 1}, // add_edge_mask + {g->num_edges, gd->num_edge_attr_logits}, // set_edge_attr_mask + {node_feat_shape, 1}, // remove_node_mask + {node_feat_shape, gd->num_settable_node_attrs}, // remove_node_attr_mask + {g->num_edges, 1}, // remove_edge_mask + {g->num_edges, gd->num_settable_edge_attrs}, // remove_edge_attr_mask + {1, num_cond_info_dims}, // cond_info + }; + int offsets[14]; + Py_ssize_t num_items = 0; + for (int i = 0; i < 14; i++) { + offsets[i] = num_items; + // we need twice the space for longs + num_items += shapes[i][0] * shapes[i][1] * (2 - is_float[i]); + } + // Allocate the memory for the Data object in a way we can return it to Python + // (and let Python free it when it's done) + PyObject *data = PyByteArray_FromStringAndSize(NULL, num_items * sizeof(float)); + int *dataptr = (int *)PyByteArray_AsString(data); + memset(dataptr, 0, num_items * sizeof(float)); + if (PyErr_Occurred()) { + raise(SIGINT); + return NULL; + } + + float *node_feat = (float *)(dataptr + offsets[0]); + long *edge_index = (long *)(dataptr + offsets[1]); + float *edge_feat = (float *)(dataptr + offsets[2]); + long *non_edge_index = (long *)(dataptr + offsets[3]); + float *stop_mask = (float *)(dataptr + offsets[4]); + float *add_node_mask = (float *)(dataptr + offsets[5]); + float *set_node_attr_mask = (float *)(dataptr + offsets[6]); + float *add_edge_mask = (float *)(dataptr + offsets[7]); + float *set_edge_attr_mask = (float *)(dataptr + offsets[8]); + float *remove_node_mask = (float *)(dataptr + offsets[9]); + float *remove_node_attr_mask = (float *)(dataptr + offsets[10]); + float *remove_edge_mask = (float *)(dataptr + offsets[11]); + float *remove_edge_attr_mask = (float *)(dataptr + offsets[12]); + float *cond_info_data = (float *)(dataptr + offsets[13]); + + int bridges[g->num_edges]; + // Graph_brigdes' second argument is expected to be NULL when called by Python + // we're using it to pass the bridges array instead + Graph_bridges((PyObject *)g, (PyObject *)bridges); + if (PyErr_Occurred()) { + raise(SIGINT); + return NULL; + } + + // sorted attrs is 'charge', 'chi', 'expl_H', 'no_impl', 'v' + int *_node_attrs[5] = {charge_val, chi_val, explH_val, noImpl_val, v_val}; + int *_edge_attrs[1] = {bond_val}; + if (g->num_nodes == 0) { + node_feat[gd->num_node_dim - 1] = 1; + memsetf(add_node_mask, 1, gd->num_new_node_values); + remove_node_mask[0] = 1; + } + /*for (int i=0;i<5;i++){ + printf("%d %d\n", gd->node_attr_offsets[i], gd->node_attr_offsets[i]-i); + }*/ + for (int i = 0; i < g->num_nodes; i++) { + if (g->degrees[i] <= 1 && !has_connecting_edge_attr_set[i] && num_set_attribs[i] == 1) { + remove_node_mask[i] = 1; + } + int is_special_valence = v_val[i] == nitro_attr_value || v_val[i] == oxy_attr_value; + for (int j = 0; j < 5; j++) { + int one_hot_idx = gd->node_attr_offsets[j] + _node_attrs[j][i]; + node_feat[i * gd->num_node_dim + one_hot_idx] = 1; + int logit_slice_start = gd->node_attr_offsets[j] - j; + int logit_slice_end = gd->node_attr_offsets[j + 1] - j - 1; + if (logit_slice_end == logit_slice_start) + continue; // we cannot actually set this attribute + if (j == v_idx) + continue; // we cannot remove nor set 'v' + // printf("Node attr %d %d = %d\n", j, i, _node_attrs[j][i]); + if (_node_attrs[j][i] > 0) { + // Special case: if the node is a positively charged Nitrogen, and the valence is maxed out (i.e. 5) + // the agent cannot remove the charge, it has to remove the bonds (or bond attrs) making this valence + // maxed out first. + if (j == charge_idx && is_special_valence && used_valences[i] >= max_valence[i] && charge_val[i] == 1) + //if (j == charge_idx && used_valences[i] >= max_valence[i] && charge_val[i] == 1) + continue; + remove_node_attr_mask[i * gd->num_settable_node_attrs + j] = 1; + } else { + // Special case for N: adding charge to N increases its valence by 2, so we don't want to prevent that + // action, even if the max_valence is "full" (for v=3) + //if (j == charge_idx && is_nitro && used_valences[i] >= (max_valence[i] + (is_nitro ? 2 : 0))) // charge + // continue; + if (j == explH_idx && used_valences[i] >= max_valence[i]) // expl_H + continue; + // Following on the special case, we made it here, now if the node is not yet charged but already + // has a valence of 3, the only charge we can add is a +1, which is the 0th logit + // (again this assumes charge ranges are [0,1,-1]) + //if (j == charge_idx && is_nitro && used_valences[i] >= max_valence[i] && charge_val[i] == 0){ + //printf("valences: %f %d %d %d\n", used_valences[i], max_valence[i], charge_val[i], j); + if (j == charge_idx && is_special_valence && used_valences[i] >= max_valence[i]){ + //printf("Setting 1: %d\n", i * gd->num_node_attr_logits + logit_slice_start); + set_node_attr_mask[i * gd->num_node_attr_logits + logit_slice_start] = 1; + continue; + } + // Next, if the node is not yet charged, we can only add a charge if the valence is not maxed out + if (j == charge_idx && used_valences[i] >= max_valence[i]) + continue; + // printf("Setting: %d %d, %d %d\n", j, i * gd->num_node_attr_logits + logit_slice_start, logit_slice_end, logit_slice_start); + memsetf(set_node_attr_mask + i * gd->num_node_attr_logits + logit_slice_start, 1, + (logit_slice_end - logit_slice_start)); + } + } + if (used_valences[i] < max_valence[i] && g->num_nodes < max_nodes) { + memsetf(add_node_mask + i * gd->num_new_node_values, 1, gd->num_new_node_values); + } + } + if (PyErr_Occurred()) { + raise(SIGINT); + return NULL; + } + for (int i = 0; i < g->num_edges; i++) { + if (bridges[i] == 0) { + remove_edge_mask[i] = 1; + } + int j = 0; // well there's only the bond attr + int one_hot_idx = gd->edge_attr_offsets[j] + _edge_attrs[j][i]; + edge_feat[2 * i * gd->num_edge_dim + one_hot_idx] = 1; + edge_feat[(2 * i + 1) * gd->num_edge_dim + one_hot_idx] = 1; + int logit_slice_start = gd->edge_attr_offsets[j] - j; + if (_edge_attrs[j][i] > 0) { + remove_edge_attr_mask[i * gd->num_settable_edge_attrs + j] = 1; + } else { + // TODO: we're not using aromatics here, instead single-double-triple is hardcoded + // k starts at 1 because the default value (0) is the single bond + for (int k = 1; k < 3; k++) { + // bond_valence - 1 because the 0th value is the single bond which has valence 1, and we'd be "removing" + // it to replace it with another bond + if ((used_valences[g->edges[2 * i]] + bond_valence[k] - 1 > max_valence[g->edges[2 * i]]) || + (used_valences[g->edges[2 * i + 1]] + bond_valence[k] - 1 > max_valence[g->edges[2 * i + 1]])) + continue; + + // use k - 1 because the default value (0) doesn't have an associated logit + set_edge_attr_mask[i * gd->num_edge_attr_logits + logit_slice_start + k - 1] = 1; + } + } + edge_index[2 * i] = g->edges[2 * i]; + edge_index[2 * i + 1] = g->edges[2 * i + 1]; + edge_index[2 * i + g->num_edges * 2] = g->edges[2 * i + 1]; + edge_index[2 * i + g->num_edges * 2 + 1] = g->edges[2 * i]; + } + // already filtered out for valence and such + memsetf(add_edge_mask, 1, num_creatable_edges); + int non_edge_idx_idx = 0; + for (int i = 0; i < g->num_nodes; i++) { + for (int j = i + 1; j < g->num_nodes; j++) { + if (!can_create_edge[i][j]) + continue; + non_edge_index[non_edge_idx_idx] = i; + non_edge_index[num_creatable_edges + non_edge_idx_idx] = j; + non_edge_idx_idx++; + } + } + if (PyErr_Occurred()) { + raise(SIGINT); + return NULL; + } + if (cond_info != NULL) { + PyObject *cpu_cond_info = PyObject_CallMethod(cond_info, "cpu", ""); + PyObject *cont_cond_info = PyObject_CallMethod(cpu_cond_info, "contiguous", ""); + PyObject *ptr = PyObject_CallMethod(cont_cond_info, "data_ptr", ""); + float *tensor_ptr = PyLong_AsVoidPtr(ptr); + + for (int i = 0; i < num_cond_info_dims; i++) { + cond_info_data[i] = tensor_ptr[i]; + } + Py_DECREF(ptr); + Py_DECREF(cpu_cond_info); + Py_DECREF(cont_cond_info); + } + + if (PyErr_Occurred()) { + PyErr_Print(); + raise(SIGINT); + return NULL; + } + *stop_mask = (g->num_nodes > 0 ? 1 : 0) && !is_any_hypervalent; + if (RETURN_C_DATA) { + // Data *data_obj = DataType.tp_new(&DataType, NULL, NULL); + Data *data_obj = (Data *)PyObject_CallObject((PyObject *)&DataType, NULL); + Data_init_C(data_obj, data, g->graph_def, shapes, is_float, 13 + (cond_info != NULL), mol_Data_names); + Py_DECREF(data); + return (PyObject *)data_obj; + } + // The following lines take about 80% of the runtime of the function on ~50 node graphs :"( + PyObject *res = PyDict_New(); + PyObject *frombuffer = PyObject_GetAttrString(torch_module, "frombuffer"); + PyObject *empty = PyObject_GetAttrString(torch_module, "empty"); + PyObject *dtype_f32 = PyObject_GetAttrString(torch_module, "float32"); + PyObject *dtype_i64 = PyObject_GetAttrString(torch_module, "int64"); + PyObject *fb_args = PyTuple_Pack(1, data); + PyObject *fb_kwargs = PyDict_New(); + int do_del_kw = 0; + for (int i = 0; i < 13; i++) { + int i_num_items = shapes[i][0] * shapes[i][1]; + PyObject *tensor; + PyDict_SetItemString(fb_kwargs, "dtype", is_float[i] ? dtype_f32 : dtype_i64); + if (i_num_items == 0) { + if (do_del_kw) { + PyDict_DelItemString(fb_kwargs, "offset"); + PyDict_DelItemString(fb_kwargs, "count"); + do_del_kw = 0; + } + PyObject *zero = PyLong_FromLong(0); + tensor = PyObject_Call(empty, PyTuple_Pack(1, zero), fb_kwargs); + Py_DECREF(zero); + } else { + PyObject *py_offset = PyLong_FromLong(offsets[i] * sizeof(float)); + PyObject *py_count = PyLong_FromLong(i_num_items); + PyDict_SetItemString(fb_kwargs, "offset", py_offset); + PyDict_SetItemString(fb_kwargs, "count", py_count); + do_del_kw = 1; + tensor = PyObject_Call(frombuffer, fb_args, fb_kwargs); + Py_DECREF(py_offset); + Py_DECREF(py_count); + } + PyObject *reshaped_tensor = PyObject_CallMethod(tensor, "view", "ii", shapes[i][0], shapes[i][1]); + PyDict_SetItemString(res, mol_Data_names[i], reshaped_tensor); + Py_DECREF(tensor); + Py_DECREF(reshaped_tensor); + } + Py_DECREF(frombuffer); + Py_DECREF(dtype_f32); + Py_DECREF(dtype_i64); + Py_DECREF(fb_args); + Py_DECREF(fb_kwargs); + Py_DECREF(data); + return res; +} diff --git a/src/C/node_view.c b/src/C/node_view.c new file mode 100644 index 00000000..d105ae2f --- /dev/null +++ b/src/C/node_view.c @@ -0,0 +1,244 @@ +#define PY_SSIZE_T_CLEAN +#include "./main.h" +#include "structmember.h" +#include + +static void NodeView_dealloc(NodeView *self) { + Py_XDECREF(self->graph); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static PyObject *NodeView_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + NodeView *self; + self = (NodeView *)type->tp_alloc(type, 0); + if (self != NULL) { + self->graph = NULL; + } + return (PyObject *)self; +} + +static int NodeView_init(NodeView *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"graph", "index", NULL}; + PyObject *graph = NULL, *tmp; + int node_id = -1; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|i", kwlist, &graph, &node_id)) + return -1; + + if (graph) { + tmp = (PyObject *)self->graph; + Py_INCREF(graph); + self->graph = (Graph *)graph; + Py_XDECREF(tmp); + self->index = node_id; + if (node_id >= 0) { + int node_exists = 0; + for (int i = 0; i < self->graph->num_nodes; i++) { + if (self->graph->nodes[i] == node_id) { + node_exists = 1; + } + } + if (!node_exists) { + PyErr_SetString(PyExc_KeyError, "Trying to create a view with a node that does not exist"); + return -1; + } + } + } + return 0; +} + +static PyMemberDef NodeView_members[] = { + {NULL} /* Sentinel */ +}; + +static int NodeView_setitem(PyObject *_self, PyObject *k, PyObject *v) { + NodeView *self = (NodeView *)_self; + if (self->index < 0) { + PyErr_SetString(PyExc_KeyError, "Cannot assign to a node"); + return -1; + } + PyObject *r = Graph_setnodeattr(self->graph, self->index, k, v); + Py_DECREF(r); + return r == NULL ? -1 : 0; +} + +static PyObject *NodeView_getitem(PyObject *_self, PyObject *k) { + NodeView *self = (NodeView *)_self; + if (self->index < 0) { + PyObject *args = PyTuple_Pack(2, self->graph, k); + PyObject *res = PyObject_CallObject((PyObject *)&NodeViewType, args); + Py_DECREF(args); + return res; + } + return Graph_getnodeattr(self->graph, self->index, k); +} + +static int NodeView_contains(PyObject *_self, PyObject *v) { + NodeView *self = (NodeView *)_self; + if (self->index < 0) { + // Graph_containsnode + if (!PyLong_Check(v)) { + PyErr_SetString(PyExc_TypeError, "NodeView.__contains__ only accepts integers"); + return -1; + } + int index = PyLong_AsLong(v); + for (int i = 0; i < self->graph->num_nodes; i++) { + if (self->graph->nodes[i] == index) { + return 1; + } + } + return 0; + } + // Graph_nodehasattr + PyObject *attr = Graph_getnodeattr(self->graph, self->index, v); + if (attr == NULL) { + PyErr_Clear(); + return 0; + } + Py_DECREF(attr); + return 1; +} + +static Py_ssize_t NodeView_len(PyObject *_self) { + NodeView *self = (NodeView *)_self; + if (self->index != -1) { + // This is inefficient, much like a lot of this codebase... but it's in C. For our graph sizes + // it's not a big deal (and unsurprisingly, it's fast) + Py_ssize_t num_attrs = 0; + for (int i = 0; i < self->graph->num_node_attrs; i++) { + if (self->graph->node_attrs[3 * i] == self->index) { + num_attrs++; + } + } + return num_attrs; + } + return self->graph->num_nodes; +} + +PyObject *NodeView_iter(PyObject *_self) { + NodeView *self = (NodeView *)_self; + if (self->index != -1) { + PyErr_SetString(PyExc_TypeError, "A bound NodeView is not iterable"); + return NULL; + } + Py_INCREF(_self); // We have to return a new reference, not a borrowed one + return _self; +} + +PyObject *NodeView_iternext(PyObject *_self) { + NodeView *self = (NodeView *)_self; + self->index++; + if (self->index >= self->graph->num_nodes) { + return NULL; + } + return PyLong_FromLong(self->graph->nodes[self->index]); +} + +PyObject *NodeView_get(PyObject *_self, PyObject *args) { + PyObject *key; + PyObject *default_value = Py_None; + if (!PyArg_ParseTuple(args, "O|O", &key, &default_value)) { + return NULL; + } + PyObject *res = NodeView_getitem(_self, key); + if (res == NULL) { + PyErr_Clear(); + Py_INCREF(default_value); + return default_value; + } + return res; +} + +int _NodeView_eq(NodeView *self, NodeView *other) { + if (self->index == -1 || other->index == -1) { + return -1; + } + + int num_settable = ((GraphDef*)self->graph->graph_def)->num_settable_node_attrs + 1; + int self_values[num_settable]; + int other_values[num_settable]; + for (int i = 0; i < num_settable; i++) { + self_values[i] = -1; + other_values[i] = -1; + } + for (int i = 0; i < self->graph->num_node_attrs; i++) { + if (self->graph->node_attrs[3 * i] == self->graph->nodes[self->index]) { + self_values[self->graph->node_attrs[3 * i + 1]] = self->graph->node_attrs[3 * i + 2]; + } + } + for (int i = 0; i < other->graph->num_node_attrs; i++) { + if (other->graph->node_attrs[3 * i] == other->graph->nodes[other->index]) { + other_values[other->graph->node_attrs[3 * i + 1]] = other->graph->node_attrs[3 * i + 2]; + } + } + for (int i = 0; i < num_settable; i++) { + // printf("%d %d %d\n", i, self_values[i], other_values[i]); + if (self_values[i] != other_values[i]) { + return 0; + } + } + return 1; +} + +static PyObject * +NodeView_richcompare(PyObject *_self, PyObject *other, int op) +{ + PyObject *result = NULL; + NodeView *self = (NodeView *)_self; + + if (other == NULL || other->ob_type != &NodeViewType) { + result = Py_NotImplemented; + } + else { + if (op == Py_EQ){ + if (self->index == -1 || ((NodeView *)other)->index == -1){ + PyErr_SetString(PyExc_TypeError, "NodeView.__eq__ only supports equality comparison for bound NodeViews"); + return NULL; + } + NodeView *other_node = (NodeView *)other; + if (_NodeView_eq(self, other_node) == 1){ + result = Py_True; + } else { + result = Py_False; + } + }else{ + PyErr_SetString(PyExc_TypeError, "NodeView only supports equality comparison"); + return NULL; + } + } + + Py_XINCREF(result); + return result; +} + +static PyMappingMethods NodeView_mapmeth = { + .mp_subscript = NodeView_getitem, + .mp_ass_subscript = NodeView_setitem, +}; + +static PySequenceMethods NodeView_seqmeth = { + .sq_contains = NodeView_contains, + .sq_length = NodeView_len, +}; + +static PyMethodDef NodeView_methods[] = { + {"get", (PyCFunction)NodeView_get, METH_VARARGS, "dict-like get"}, {NULL} /* Sentinel */ +}; + +PyTypeObject NodeViewType = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "_C.NodeView", + .tp_doc = PyDoc_STR("Constrained NodeView object"), + .tp_basicsize = sizeof(NodeView), + .tp_itemsize = 0, + //.tp_flags = Py_TPFLAGS_HAVE_RICHCOMPARE, //Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_new = NodeView_new, + .tp_init = (initproc)NodeView_init, + .tp_dealloc = (destructor)NodeView_dealloc, + .tp_members = NodeView_members, + .tp_methods = NodeView_methods, + .tp_as_mapping = &NodeView_mapmeth, + .tp_as_sequence = &NodeView_seqmeth, + .tp_iter = NodeView_iter, + .tp_iternext = NodeView_iternext, + .tp_richcompare = NodeView_richcompare, +}; diff --git a/src/gflownet/__init__.py b/src/gflownet/__init__.py index 6cb8f979..d1b8c3e6 100644 --- a/src/gflownet/__init__.py +++ b/src/gflownet/__init__.py @@ -19,10 +19,14 @@ class GFNAlgorithm: updates: int = 0 global_cfg: Config is_eval: bool = False + requires_task: bool = False def step(self): self.updates += 1 # This isn't used anywhere? + def set_is_eval(self, is_eval: bool): + self.is_eval = is_eval + def compute_batch_losses( self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 ) -> Tuple[Tensor, Dict[str, Tensor]]: @@ -74,6 +78,9 @@ def get_random_action_prob(self, it: int): return self.global_cfg.algo.train_random_action_prob return 0 + def set_task(self, task): + raise NotImplementedError() + class GFNTask: def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], obj_props: ObjectProperties) -> LogScalar: diff --git a/src/gflownet/algo/advantage_actor_critic.py b/src/gflownet/algo/advantage_actor_critic.py index 7077a9d1..7ce05a81 100644 --- a/src/gflownet/algo/advantage_actor_critic.py +++ b/src/gflownet/algo/advantage_actor_critic.py @@ -3,6 +3,7 @@ import torch_geometric.data as gd from torch import Tensor +from gflownet import GFNAlgorithm from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory from gflownet.utils.misc import get_worker_device @@ -10,7 +11,7 @@ from .graph_sampling import GraphSampler -class A2C: +class A2C(GFNAlgorithm): def __init__( self, env: GraphBuildingEnv, @@ -36,6 +37,7 @@ def __init__( The experiment configuration """ + self.global_cfg = cfg # TODO: this belongs in the base class self.ctx = ctx self.env = env self.max_len = cfg.algo.max_len @@ -149,7 +151,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Forward pass of the model, returns a GraphActionCategorical and per graph predictions # Here we will interpret the logits of the fwd_cat as Q values - policy, per_state_preds = model(batch, cond_info[batch_idx]) + batch.cond_info = cond_info[batch_idx] + policy, per_state_preds = model(batch) V = per_state_preds[:, 0] G = rewards[batch_idx] # The return is the terminal reward everywhere, we're using gamma==1 G = G + (1 - batch.is_valid[batch_idx]) * self.invalid_penalty # Add in penalty for invalid object diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index 4dd9cbfe..76095f2a 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -102,15 +102,19 @@ class TBConfig(StrictDataClass): do_parameterize_p_b: bool = False do_predict_n: bool = False do_sample_p_b: bool = False + do_sample_using_masks: bool = False do_length_normalize: bool = False subtb_max_len: int = 128 Z_learning_rate: float = 1e-4 Z_lr_decay: float = 50_000 + regularize_logZ: Optional[float] = None cum_subtb: bool = True loss_fn: LossFN = LossFN.MSE loss_fn_par: float = 1.0 n_loss: NLoss = NLoss.none n_loss_multiplier: float = 1.0 + tb_loss_multiplier: float = 1.0 + mle_loss_multiplier: float = 0.0 backward_policy: Backward = Backward.Uniform @@ -145,6 +149,14 @@ class SQLConfig(StrictDataClass): penalty: float = -10 +@dataclass +class LSTBConfig(StrictDataClass): + num_ls_steps: int = 1 + num_bck_steps: int = 1 + accept_criteria: str = "deterministic" + yield_only_accepted: bool = False + + @dataclass class AlgoConfig(StrictDataClass): """Generic configuration for algorithms @@ -196,8 +208,10 @@ class AlgoConfig(StrictDataClass): train_det_after: Optional[int] = None valid_random_action_prob: float = 0.0 sampling_tau: float = 0.0 + grad_acc_steps: int = 1 tb: TBConfig = field(default_factory=TBConfig) moql: MOQLConfig = field(default_factory=MOQLConfig) a2c: A2CConfig = field(default_factory=A2CConfig) fm: FMConfig = field(default_factory=FMConfig) sql: SQLConfig = field(default_factory=SQLConfig) + ls: LSTBConfig = field(default_factory=LSTBConfig) diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py index 9bfc3345..4798600b 100644 --- a/src/gflownet/algo/envelope_q_learning.py +++ b/src/gflownet/algo/envelope_q_learning.py @@ -5,6 +5,7 @@ from torch import Tensor from torch_scatter import scatter +from gflownet import GFNAlgorithm from gflownet.config import Config from gflownet.envs.graph_building_env import ( GraphActionCategorical, @@ -39,24 +40,24 @@ def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objective num_layers=num_layers, num_heads=num_heads, ) - num_final = num_emb * 2 + num_final = num_emb num_mlp_layers = 0 self.emb2add_node = mlp(num_final, num_emb, env_ctx.num_new_node_values * num_objectives, num_mlp_layers) # Edge attr logits are "sided", so we will compute both sides independently self.emb2set_edge_attr = mlp( num_emb + num_final, num_emb, env_ctx.num_edge_attr_logits // 2 * num_objectives, num_mlp_layers ) - self.emb2stop = mlp(num_emb * 3, num_emb, num_objectives, num_mlp_layers) - self.emb2reward = mlp(num_emb * 3, num_emb, 1, num_mlp_layers) + self.emb2stop = mlp(num_emb * 2, num_emb, num_objectives, num_mlp_layers) + self.emb2reward = mlp(num_emb * 2, num_emb, 1, num_mlp_layers) self.edge2emb = mlp(num_final, num_emb, num_emb, num_mlp_layers) self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2) self.action_type_order = env_ctx.action_type_order self.mask_value = -10 self.num_objectives = num_objectives - def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): + def forward(self, g: gd.Batch, output_Qs=False): """See `GraphTransformer` for argument values""" - node_embeddings, graph_embeddings = self.transf(g, cond) + node_embeddings, graph_embeddings = self.transf(g) # On `::2`, edges are duplicated to make graphs undirected, only take the even ones e_row, e_col = g.edge_index[:, ::2] edge_emb = self.edge2emb(node_embeddings[e_row] + node_embeddings[e_col]) @@ -86,7 +87,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): # Compute the greedy policy # See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations # TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes - w = cond[:, -self.num_objectives :] + w = g.cond_info[:, -self.num_objectives :] w_dot_Q = [ (qi.reshape((qi.shape[0], qi.shape[1] // w.shape[1], w.shape[1])) * w[b][:, None, :]).sum(2) for qi, b in zip(cat.logits, cat.batch) @@ -122,8 +123,9 @@ def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objective self.action_type_order = env_ctx.action_type_order self.num_objectives = num_objectives - def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): - node_embeddings, graph_embeddings = self.transf(g, cond) + def forward(self, g: gd.Batch, output_Qs=False): + cond = g.cond_info + node_embeddings, graph_embeddings = self.transf(g) ne_row, ne_col = g.non_edge_index # On `::2`, edges are duplicated to make graphs undirected, only take the even ones e_row, e_col = g.edge_index[:, ::2] @@ -156,7 +158,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): return cat, r_pred -class EnvelopeQLearning: +class EnvelopeQLearning(GFNAlgorithm): def __init__( self, env: GraphBuildingEnv, @@ -182,6 +184,7 @@ def __init__( cfg: Config The experiment configuration """ + self.global_cfg = cfg self.ctx = ctx self.env = env self.task = task @@ -314,7 +317,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Forward pass of the model, returns a GraphActionCategorical and per graph predictions # Here we will interpret the logits of the fwd_cat as Q values # Q(s,a,omega) - fwd_cat, per_state_preds = model(batch, cond_info[batch_idx], output_Qs=True) + batch.cond_info = cond_info[batch_idx] + fwd_cat, per_state_preds = model(batch, output_Qs=True) Q_omega = fwd_cat.logits # reshape to List[shape: (num in all graphs, num actions on T, num_objectives) | for all types T] Q_omega = [i.reshape((i.shape[0], i.shape[1] // num_objectives, num_objectives)) for i in Q_omega] @@ -323,7 +327,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: batchp = batch.batch_prime batchp_num_trajs = int(batchp.traj_lens.shape[0]) batchp_batch_idx = torch.arange(batchp_num_trajs, device=dev).repeat_interleave(batchp.traj_lens) - fwd_cat_prime, per_state_preds = model(batchp, batchp.cond_info[batchp_batch_idx], output_Qs=True) + batchp.cond_info = batchp.cond_info[batchp_batch_idx] + fwd_cat_prime, per_state_preds = model(batchp, output_Qs=True) Q_omega_prime = fwd_cat_prime.logits # We've repeated everything N_omega times, so we can reshape the same way as above but with # an extra N_omega first dimension diff --git a/src/gflownet/algo/flow_matching.py b/src/gflownet/algo/flow_matching.py index 33c436bf..c75c1ce4 100644 --- a/src/gflownet/algo/flow_matching.py +++ b/src/gflownet/algo/flow_matching.py @@ -46,7 +46,7 @@ def __init__( # in a number of settings the regular loss is more stable. self.fm_balanced_loss = cfg.algo.fm.balanced_loss self.fm_leaf_coef = cfg.algo.fm.leaf_coef - self.correct_idempotent: bool = self.correct_idempotent or cfg.algo.fm.correct_idempotent + self.correct_idempotent: bool = cfg.algo.fm.correct_idempotent def construct_batch(self, trajs, cond_info, log_rewards): """Construct a batch from a list of trajectories and their information @@ -149,7 +149,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Query the model for Fsa. The model will output a GraphActionCategorical, but we will # simply interpret the logits as F(s, a). Conveniently the policy of a GFN is the softmax of # log F(s,a) so we don't have to change anything in the sampling routines. - cat, graph_out = model(batch, batch.cond_info[torch.cat([parents_traj_idx, states_traj_idx], 0)]) + batch.cond_info = batch.cond_info[torch.cat([parents_traj_idx, states_traj_idx], 0)] + cat, graph_out = model(batch) # We compute \sum_{s,a : T(s,a)=s'} F(s,a), first we index all the parent's outputs by the # parent actions. To do so we reuse the log_prob mechanism, but specify that the logprobs # tensor is actually just the logits (which we chose to interpret as edge flows F(s,a). We diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 9cf9faeb..60b412a2 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -1,11 +1,12 @@ import copy import warnings -from typing import List, Optional +from typing import Callable, List, Optional import torch import torch.nn as nn from torch import Tensor +from gflownet.algo.config import LSTBConfig from gflownet.envs.graph_building_env import ( Graph, GraphAction, @@ -87,122 +88,65 @@ def sample_from_model(self, model: nn.Module, n: int, cond_info: Optional[Tensor data: List[Dict] A list of trajectories. Each trajectory is a dict with keys - trajs: List[Tuple[Graph, GraphAction]], the list of states and actions + - bck_a: List[GraphAction], the reverse actions + - is_sink: List[int], sink states have P_B = 1 - fwd_logprob: sum logprobs P_F - bck_logprob: sum logprobs P_B - is_valid: is the generated graph valid according to the env & ctx + - interm_rewards: List[float], intermediate rewards """ dev = get_worker_device() + rng = get_worker_rng() # This will be returned - data = [{"traj": [], "reward_pred": None, "is_valid": True, "is_sink": []} for i in range(n)] - # Let's also keep track of trajectory statistics according to the model - fwd_logprob: List[List[Tensor]] = [[] for _ in range(n)] - bck_logprob: List[List[Tensor]] = [[] for _ in range(n)] - + data = [ + { + "traj": [], + "bck_a": [GraphAction(GraphActionType.Pad)], # The reverse actions + "is_valid": True, + "is_sink": [], + "fwd_logprobs": [], + "U_bck_logprobs": [], + "interm_rewards": [], + } + for _ in range(n) + ] graphs = [self.env.new() for _ in range(n)] done = [False for _ in range(n)] - # TODO: instead of padding with Stop, we could have a virtual action whose probability - # always evaluates to 1. Presently, Stop should convert to a (0,0,0) ActionIndex, which should - # always be at least a valid index, and will be masked out anyways -- but this isn't ideal. - # Here we have to pad the backward actions with something, since the backward actions are - # evaluated at s_{t+1} not s_t. - bck_a = [[GraphAction(GraphActionType.Stop)] for _ in range(n)] - - rng = get_worker_rng() - - def not_done(lst): - return [e for i, e in enumerate(lst) if not done[i]] for t in range(self.max_len): - # Construct graphs for the trajectories that aren't yet done - torch_graphs = [self.ctx.graph_to_Data(i) for i in not_done(graphs)] - not_done_mask = torch.tensor(done, device=dev).logical_not() - # Forward pass to get GraphActionCategorical - # Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does. - # TODO: compute bck_cat.log_prob(bck_a) when relevant - ci = cond_info[not_done_mask] if cond_info is not None else None - fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), ci) - if random_action_prob > 0: - # Device which graphs in the minibatch will get their action randomized - is_random_action = torch.tensor( - rng.uniform(size=len(torch_graphs)) < random_action_prob, device=dev - ).float() - # Set the logits to some large value to have a uniform distribution - fwd_cat.logits = [ - is_random_action[b][:, None] * torch.ones_like(i) * 100 + i * (1 - is_random_action[b][:, None]) - for i, b in zip(fwd_cat.logits, fwd_cat.batch) - ] - if self.sample_temp != 1: - sample_cat = copy.copy(fwd_cat) - sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits] - actions = sample_cat.sample() - else: - actions = fwd_cat.sample() - graph_actions = [self.ctx.ActionIndex_to_GraphAction(g, a) for g, a in zip(torch_graphs, actions)] - log_probs = fwd_cat.log_prob(actions) - # Step each trajectory, and accumulate statistics - for i, j in zip(not_done(range(n)), range(n)): - fwd_logprob[i].append(log_probs[j].unsqueeze(0)) - data[i]["traj"].append((graphs[i], graph_actions[j])) - bck_a[i].append(self.env.reverse(graphs[i], graph_actions[j])) - # Check if we're done - if graph_actions[j].action is GraphActionType.Stop: - done[i] = True - bck_logprob[i].append(torch.tensor([1.0], device=dev).log()) - data[i]["is_sink"].append(1) - else: # If not done, try to step the self.environment - gp = graphs[i] - try: - # self.env.step can raise AssertionError if the action is illegal - gp = self.env.step(graphs[i], graph_actions[j]) - assert len(gp.nodes) <= self.max_nodes - except AssertionError: - done[i] = True - data[i]["is_valid"] = False - bck_logprob[i].append(torch.tensor([1.0], device=dev).log()) - data[i]["is_sink"].append(1) - continue - if t == self.max_len - 1: - done[i] = True - # If no error, add to the trajectory - # P_B = uniform backward - n_back = self.env.count_backward_transitions(gp, check_idempotent=self.correct_idempotent) - bck_logprob[i].append(torch.tensor([1 / n_back], device=dev).log()) - data[i]["is_sink"].append(0) - graphs[i] = gp - if done[i] and self.sanitize_samples and not self.ctx.is_sane(graphs[i]): - # check if the graph is sane (e.g. RDKit can construct a molecule from it) otherwise - # treat the done action as illegal - data[i]["is_valid"] = False + # This modifies `data` and `graphs` in place + self._forward_step(model, data, graphs, cond_info, t, done, rng, dev, random_action_prob) if all(done): break - + # Note on is_sink and padding: # is_sink indicates to a GFN algorithm that P_B(s) must be 1 - + # # There are 3 types of possible trajectories # A - ends with a stop action. traj = [..., (g, a), (gp, Stop)], P_B = [..., bck(gp), 1] # B - ends with an invalid action. = [..., (g, a)], = [..., 1] # C - ends at max_len. = [..., (g, a)], = [..., bck(gp)] - + # # Let's say we pad terminal states, then: # A - ends with a stop action. traj = [..., (g, a), (gp, Stop), (gp, None)], P_B = [..., bck(gp), 1, 1] # B - ends with an invalid action. = [..., (g, a), (g, None)], = [..., 1, 1] # C - ends at max_len. = [..., (g, a), (gp, None)], = [..., bck(gp), 1] # and then P_F(terminal) "must" be 1 - for i in range(n): + for i in range(len(data)): # If we're not bootstrapping, we could query the reward # model here, but this is expensive/impractical. Instead # just report forward and backward logprobs - data[i]["fwd_logprob"] = sum(fwd_logprob[i]) - data[i]["bck_logprob"] = sum(bck_logprob[i]) - data[i]["bck_logprobs"] = torch.stack(bck_logprob[i]).reshape(-1) + # TODO: stop using dicts and used typed objects + data[i]["fwd_logprobs"] = torch.stack(data[i]["fwd_logprobs"]).reshape(-1) + data[i]["U_bck_logprobs"] = torch.stack(data[i]["U_bck_logprobs"]).reshape(-1) + data[i]["fwd_logprob"] = data[i]["fwd_logprobs"].sum() # type: ignore + data[i]["U_bck_logprob"] = data[i]["U_bck_logprobs"].sum() # type: ignore data[i]["result"] = graphs[i] - data[i]["bck_a"] = bck_a[i] if self.pad_with_terminal_state: - # TODO: instead of padding with Stop, we could have a virtual action whose - # probability always evaluates to 1. - data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Stop))) - data[i]["is_sink"].append(1) + data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Pad))) # type: ignore + data[i]["U_bck_logprobs"] = torch.cat([data[i]["U_bck_logprobs"], torch.tensor([0.0], device=dev)]) + data[i]["is_sink"].append(1) # type: ignore + assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"]) # type: ignore return data def sample_backward_from_graphs( @@ -229,6 +173,7 @@ def sample_backward_from_graphs( Probability of taking a random action (only used if model parameterizes P_B) """ + starting_graphs = list(graphs) dev = get_worker_device() n = len(graphs) done = [False] * n @@ -237,64 +182,315 @@ def sample_backward_from_graphs( "traj": [(graphs[i], GraphAction(GraphActionType.Stop))], "is_valid": True, "is_sink": [1], - "bck_a": [GraphAction(GraphActionType.Stop)], + "bck_a": [GraphAction(GraphActionType.Pad)], "bck_logprobs": [0.0], + "U_bck_logprobs": [0.0], "result": graphs[i], } for i in range(n) ] - def not_done(lst): - return [e for i, e in enumerate(lst) if not done[i]] - # TODO: This should be doable. if random_action_prob > 0: warnings.warn("Random action not implemented for backward sampling") while sum(done) < n: - torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(n))] - not_done_mask = torch.tensor(done, device=dev).logical_not() - if model is not None: - ci = cond_info[not_done_mask] if cond_info is not None else None - _, bck_cat, *_ = model(self.ctx.collate(torch_graphs).to(dev), ci) - else: - gbatch = self.ctx.collate(torch_graphs) - action_types = self.ctx.bck_action_type_order - action_masks = [action_type_to_mask(t, gbatch, assert_mask_exists=True) for t in action_types] - bck_cat = GraphActionCategorical( - gbatch, - raw_logits=[torch.ones_like(m) for m in action_masks], - keys=[GraphTransformerGFN.action_type_to_key(t) for t in action_types], - action_masks=action_masks, - types=action_types, - ) - bck_actions = bck_cat.sample() - graph_bck_actions = [ - self.ctx.ActionIndex_to_GraphAction(g, a, fwd=False) for g, a in zip(torch_graphs, bck_actions) - ] - bck_logprobs = bck_cat.log_prob(bck_actions) - - for i, j in zip(not_done(range(n)), range(n)): - if not done[i]: - g = graphs[i] - b_a = graph_bck_actions[j] - gp = self.env.step(g, b_a) - f_a = self.env.reverse(g, b_a) - graphs[i], f_a = relabel(gp, f_a) - data[i]["traj"].append((graphs[i], f_a)) - data[i]["bck_a"].append(b_a) - data[i]["is_sink"].append(0) - data[i]["bck_logprobs"].append(bck_logprobs[j].item()) - if len(graphs[i]) == 0: - done[i] = True + self._backward_step(model, data, graphs, cond_info, done, dev) for i in range(n): # See comments in sample_from_model - data[i]["traj"] = data[i]["traj"][::-1] - data[i]["bck_a"] = [GraphAction(GraphActionType.Stop)] + data[i]["bck_a"][::-1] - data[i]["is_sink"] = data[i]["is_sink"][::-1] - data[i]["bck_logprobs"] = torch.tensor(data[i]["bck_logprobs"][::-1], device=dev).reshape(-1) + # TODO: stop using dicts and used typed objects + data[i]["traj"] = data[i]["traj"][::-1] # type: ignore + # I think this pad is only necessary if we're padding terminal states??? + data[i]["bck_a"] = [GraphAction(GraphActionType.Pad)] + data[i]["bck_a"][::-1] # type: ignore + data[i]["is_sink"] = data[i]["is_sink"][::-1] # type: ignore + data[i]["U_bck_logprobs"] = torch.tensor( + [0] + data[i]["U_bck_logprobs"][::-1], device=dev # type: ignore + ).reshape(-1) if self.pad_with_terminal_state: - data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Stop))) - data[i]["is_sink"].append(1) + data[i]["traj"].append((starting_graphs[i], GraphAction(GraphActionType.Pad))) # type: ignore + data[i]["U_bck_logprobs"] = torch.cat([data[i]["U_bck_logprobs"], torch.tensor([0.0], device=dev)]) + data[i]["is_sink"].append(1) # type: ignore + assert len(data[i]["U_bck_logprobs"]) == len(data[i]["bck_a"]) # type: ignore return data + + def local_search_sample_from_model( + self, + model: nn.Module, + n: int, + cond_info: Optional[Tensor], + random_action_prob: float = 0.0, + cfg: LSTBConfig = LSTBConfig(), + compute_reward: Optional[Callable] = None, + ): + dev = get_worker_device() + rng = get_worker_rng() + # First get n trajectories + current_trajs = self.sample_from_model(model, n, cond_info, random_action_prob) + compute_reward(current_trajs, cond_info) # in-place + # Then we're going to perform num_ls_steps of local search, each with num_bck_steps backward steps. + # Each local search step is a kind of Metropolis-Hastings step, where we propose a new trajectory, which may + # be accepted or rejected based on the forward and backward probabilities and reward. + # Finally, we return all the trajectories that were sampled. + + # We keep the initial trajectories to return them at the end. We need to copy 'traj' to avoid modifying it + initial_trajs = [{k: v if k != "traj" else list(v) for k, v in t.items()} for t in current_trajs] + sampled_terminals = [] + if self.pad_with_terminal_state: + for t in current_trajs: + t["traj"] = t["traj"][:-1] # Remove the padding state + num_accepts = 0 + + for mcmc_steps in range(cfg.num_ls_steps): + # First we must do a bit of accounting so that we can later prevent trajectories longer than max_len + stop = GraphActionType.Stop + num_pad = [(1 if t["traj"][-1][1].action == stop else 0) for t in current_trajs] + trunc_lens = [max(0, len(i["traj"]) - cfg.num_bck_steps - pad) for i, pad in zip(current_trajs, num_pad)] + + # Go backwards num_bck_steps steps + bck_trajs = [ + {"traj": [], "bck_a": [], "is_sink": [], "bck_logprobs": [], "fwd_logprobs": []} for _ in current_trajs + ] # type: ignore + graphs = [i["traj"][-1][0] for i in current_trajs] + done = [False] * n + fwd_a: List[GraphAction] = [] + for i in range(cfg.num_bck_steps): + # This modifies `bck_trajs` & `graphs` in place, passing fwd_a computes P_F(s|s') for the previous step + self._backward_step(model, bck_trajs, graphs, cond_info, done, dev, fwd_a) + fwd_a = [t["traj"][-1][1] for t in bck_trajs] + # Add forward logprobs for the last step + self._add_fwd_logprobs(bck_trajs, graphs, model, cond_info, [False] * n, dev, fwd_a) + log_P_B_tau_back = [sum(t["bck_logprobs"]) for t in bck_trajs] + log_P_F_tau_back = [sum(t["fwd_logprobs"]) for t in bck_trajs] + + # Go forward to get full trajectories + fwd_trajs = [ + {"traj": [], "bck_a": [], "is_sink": [], "bck_logprobs": [], "fwd_logprobs": []} for _ in current_trajs + ] # type: ignore + done = [False] * n + bck_a: List[GraphAction] = [] + while not all(done): + self._forward_step(model, fwd_trajs, graphs, cond_info, 0, done, rng, dev, random_action_prob, bck_a) + done = [d or (len(t["traj"]) + T) >= self.max_len for d, t, T in zip(done, fwd_trajs, trunc_lens)] + bck_a = [t["bck_a"][-1] for t in fwd_trajs] + # Add backward logprobs for the last step; this is only necessary if the last action is not a stop + done = [t["traj"][-1][1].action == stop for t in fwd_trajs] + if not all(done): + self._add_bck_logprobs(fwd_trajs, graphs, model, cond_info, done, dev, bck_a) + log_P_F_tau_recon = [sum(t["fwd_logprobs"]) for t in fwd_trajs] + log_P_B_tau_recon = [sum(t["bck_logprobs"]) for t in fwd_trajs] + + # We add those new terminal states to the list of terminal states + terminals = [t["traj"][-1][0] for t in fwd_trajs] + sampled_terminals.extend(terminals) + for traj, term in zip(fwd_trajs, terminals): + traj["result"] = term + traj["is_accept"] = False # type: ignore + # Compute rewards for the acceptance + if compute_reward is not None: + compute_reward(fwd_trajs, cond_info) + + # To end the iteration, we replace the current trajectories with the new ones if they are accepted by MH + for i in range(n): + if cfg.accept_criteria == "deterministic": + # Keep the highest reward + if fwd_trajs[i]["log_reward"] > current_trajs[i]["log_reward"]: + current_trajs[i] = fwd_trajs[i] + num_accepts += 1 + elif cfg.accept_criteria == "stochastic": + # Accept with probability max(1, R(x')/R(x)q(x'|x)/q(x|x')) + log_q_xprime_given_x = log_P_B_tau_back[i] + log_P_F_tau_recon[i] + log_q_x_given_xprime = log_P_B_tau_recon[i] + log_P_F_tau_back[i] + log_R_ratio = fwd_trajs[i]["log_reward"] - current_trajs[i]["log_reward"] + log_acceptance_ratio = log_R_ratio + log_q_xprime_given_x - log_q_x_given_xprime + if log_acceptance_ratio > 0 or rng.uniform() < torch.exp(log_acceptance_ratio): + current_trajs[i] = fwd_trajs[i] + num_accepts += 1 + elif cfg.accept_criteria == "always": + current_trajs[i] = fwd_trajs[i] + num_accepts += 1 + + # Finally, we resample new "P_B-on-policy" trajectories from the terminal states + # If we're only interested in the accepted trajectories, we use them as starting points instead + if cfg.yield_only_accepted: + sampled_terminals = [i["traj"][-1][0] for i in current_trajs] + stacked_ci = cond_info + + if not cfg.yield_only_accepted: + # In this scenario, the batch is n // num_ls_steps, so we do some stacking + stacked_ci = ( + {k: cond_info[k].repeat(cfg.num_ls_steps, *((1,) * (cond_info[k].ndim - 1))) for k in cond_info} + if cond_info is not None + else None + ) + returned_trajs = self.sample_backward_from_graphs(sampled_terminals, model, stacked_ci, random_action_prob) + # TODO: modify the trajs' cond_info!!! + return initial_trajs + returned_trajs, num_accepts / (cfg.num_ls_steps * n) + + def _forward_step(self, model, data, graphs, cond_info, t, done, rng, dev, random_action_prob, bck_a=[]) -> None: + def not_done(lst): + return [e for i, e in enumerate(lst) if not done[i]] if not bck_a else lst + + n = len(data) + # Construct graphs for the trajectories that aren't yet done + torch_graphs = [self.ctx.graph_to_Data(i) for i in not_done(graphs)] + if not bck_a: + not_done_mask = torch.tensor(done, device=cond_info["encoding"].device).logical_not() + else: + not_done_mask = torch.tensor([True] * n, device=cond_info["encoding"].device) + # Forward pass to get GraphActionCategorical + # Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does. + # TODO: compute bck_cat.log_prob(bck_a) when relevant + batch = self.ctx.collate(torch_graphs) + batch.cond_info = cond_info["encoding"][not_done_mask] if cond_info is not None else None + fwd_cat, bck_cat, *_ = model(batch.to(dev)) + if random_action_prob > 0: + # Device which graphs in the minibatch will get their action randomized + is_random_action = torch.tensor( + rng.uniform(size=len(torch_graphs)) < random_action_prob, device=dev + ).float() + # Set the logits to some large value to have a uniform distribution + fwd_cat.logits = [ + is_random_action[b][:, None] * torch.ones_like(i) * 100 + i * (1 - is_random_action[b][:, None]) + for i, b in zip(fwd_cat.logits, fwd_cat.batch) + ] + if self.sample_temp != 1: + sample_cat = copy.copy(fwd_cat) + sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits] + actions = sample_cat.sample() + else: + actions = fwd_cat.sample() + graph_actions = [self.ctx.ActionIndex_to_GraphAction(g, a, fwd=True) for g, a in zip(torch_graphs, actions)] + log_probs = fwd_cat.log_prob(actions) + if bck_a: + aidx_bck_a = [self.ctx.GraphAction_to_ActionIndex(g, a) for g, a in zip(torch_graphs, bck_a)] + bck_logprobs = bck_cat.log_prob(aidx_bck_a) + # Step each trajectory, and accumulate statistics + for i, j in zip(not_done(range(n)), range(n)): + if bck_a and len(data[i]["bck_logprobs"]) < len(data[i]["traj"]): + data[i]["bck_logprobs"].append(bck_logprobs[j].unsqueeze(0)) + if done[i]: + continue + data[i]["fwd_logprobs"].append(log_probs[j].unsqueeze(0)) + data[i]["traj"].append((graphs[i], graph_actions[j])) + data[i]["bck_a"].append(self.env.reverse(graphs[i], graph_actions[j])) + if "U_bck_logprobs" not in data[i]: + data[i]["U_bck_logprobs"] = [] + # Check if we're done + if graph_actions[j].action is GraphActionType.Stop: + done[i] = True + data[i]["U_bck_logprobs"].append(torch.tensor([1.0], device=dev).log()) + data[i]["is_sink"].append(1) + else: # If not done, try to step the self.environment + gp = graphs[i] + try: + # self.env.step can raise AssertionError if the action is illegal + gp = self.env.step(graphs[i], graph_actions[j]) + assert len(gp.nodes) <= self.max_nodes + except AssertionError: + done[i] = True + data[i]["is_valid"] = False + data[i]["U_bck_logprobs"].append(torch.tensor([1.0], device=dev).log()) + data[i]["is_sink"].append(1) + continue + if t == self.max_len - 1: + done[i] = True + # If no error, add to the trajectory + # P_B = uniform backward + n_back = self.env.count_backward_transitions(gp, check_idempotent=self.correct_idempotent) + data[i]["U_bck_logprobs"].append(torch.tensor([1 / n_back], device=dev).log()) + data[i]["is_sink"].append(0) + graphs[i] = gp + if done[i] and self.sanitize_samples and not self.ctx.is_sane(graphs[i]): + # check if the graph is sane (e.g. RDKit can construct a molecule from it) otherwise + # treat the done action as illegal + data[i]["is_valid"] = False + # Nothing is returned, data is modified in place + + def _backward_step(self, model, data, graphs, cond_info, done, dev, fwd_a=[]): + # fwd_a is a list of GraphActions that are the reverse of the last backwards actions we took. + # Passing them allows us to compute the forward logprobs of the actions we took. + def not_done(lst): + return [e for i, e in enumerate(lst) if not done[i]] if not fwd_a else lst + + torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(len(graphs)))] + if not fwd_a: + not_done_mask = torch.tensor(done, device=cond_info["encoding"].device).logical_not() + else: + not_done_mask = torch.tensor([True] * len(graphs), device=cond_info["encoding"].device) + if model is not None: + gbatch = self.ctx.collate(torch_graphs) + gbatch.cond_info = cond_info["encoding"][not_done_mask] if cond_info is not None else None + fwd_cat, bck_cat, *_ = model(gbatch.to(dev)) + else: + gbatch = self.ctx.collate(torch_graphs) + action_types = self.ctx.bck_action_type_order + action_masks = [action_type_to_mask(t, gbatch, assert_mask_exists=True) for t in action_types] + bck_cat = GraphActionCategorical( + gbatch, + raw_logits=[torch.ones_like(m) for m in action_masks], + keys=[GraphTransformerGFN.action_type_to_key(t) for t in action_types], + action_masks=action_masks, + types=action_types, + ) + bck_actions = bck_cat.sample() + graph_bck_actions = [ + self.ctx.ActionIndex_to_GraphAction(g, a, fwd=False) for g, a in zip(torch_graphs, bck_actions) + ] + bck_logprobs = bck_cat.log_prob(bck_actions) + if fwd_a and model is not None: + aidx_fwd_a = [self.ctx.GraphAction_to_ActionIndex(g, a) for g, a in zip(torch_graphs, fwd_a)] + fwd_logprobs = fwd_cat.log_prob(aidx_fwd_a) + + for i, j in zip(not_done(range(len(graphs))), range(len(graphs))): + if fwd_a and model is not None and len(data[i]["fwd_logprobs"]) < len(data[i]["traj"]): + data[i]["fwd_logprobs"].append(fwd_logprobs[j].item()) + if done[i]: + # This can happen when fwd_a is passed, we should optimize this though. The reason is that even + # if a graph is done, we may still want to compute its forward logprobs. + continue + g = graphs[i] + b_a = graph_bck_actions[j] + gp = self.env.step(g, b_a) + f_a = self.env.reverse(g, b_a) + graphs[i], f_a = relabel(gp, f_a) + data[i]["traj"].append((graphs[i], f_a)) + data[i]["bck_a"].append(b_a) + data[i]["is_sink"].append(0) + data[i]["bck_logprobs"].append(bck_logprobs[j].item()) + + if len(graphs[i]) == 0: + done[i] = True + if "U_bck_logprobs" not in data[i]: + data[i]["U_bck_logprobs"] = [] + if not done[i]: + n_back = self.env.count_backward_transitions(graphs[i], check_idempotent=self.correct_idempotent) + data[i]["U_bck_logprobs"].append(torch.tensor([1.0 / n_back], device=dev).log()) + + def _add_fwd_logprobs(self, data, graphs, model, cond_info, done, dev, fwd_a): + def not_done(lst): + return [e for i, e in enumerate(lst) if not done[i]] + + torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(len(graphs)))] + not_done_mask = torch.tensor(done, device=cond_info["encoding"].device).logical_not() + gbatch = self.ctx.collate(torch_graphs) + gbatch.cond_info = cond_info["encoding"][not_done_mask] if cond_info is not None else None + fwd_cat, *_ = model(gbatch.to(dev)) + fwd_actions = [self.ctx.GraphAction_to_ActionIndex(g, a) for g, a in zip(torch_graphs, fwd_a)] + log_probs = fwd_cat.log_prob(fwd_actions) + for i, j in zip(not_done(range(len(graphs))), range(len(graphs))): + data[i]["fwd_logprobs"].append(log_probs[j].item()) + + def _add_bck_logprobs(self, data, graphs, model, cond_info, done, dev, bck_a): + def not_done(lst): + return [e for i, e in enumerate(lst) if not done[i]] + + torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(len(graphs)))] + not_done_mask = torch.tensor(done, device=cond_info["encoding"].device).logical_not() + gbatch = self.ctx.collate(torch_graphs) + gbatch.cond_info = cond_info["encoding"][not_done_mask] if cond_info is not None else None + fwd_cat, bck_cat, *_ = model(gbatch.to(dev)) + bck_actions = [self.ctx.GraphAction_to_ActionIndex(g, a) for g, a in zip(torch_graphs, bck_a)] + log_probs = bck_cat.log_prob(bck_actions) + for i, j in zip(not_done(range(len(graphs))), range(len(graphs))): + data[i]["bck_logprobs"].append(log_probs[j].item()) diff --git a/src/gflownet/algo/local_search_tb.py b/src/gflownet/algo/local_search_tb.py new file mode 100644 index 00000000..6bd65fcd --- /dev/null +++ b/src/gflownet/algo/local_search_tb.py @@ -0,0 +1,51 @@ +from gflownet import GFNTask +from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.data.data_source import DataSource + + +class LocalSearchTB(TrajectoryBalance): + requires_task: bool = True + task: GFNTask + + def __init__(self, env, ctx, cfg): + super().__init__(env, ctx, cfg) + self.task = None + + def set_task(self, task): + self.task = task + assert self.cfg.do_parameterize_p_b, "LocalSearchTB requires do_parameterize_p_b to be True" + + def create_training_data_from_own_samples(self, model, n, cond_info=None, random_action_prob=0.0): + assert self.task is not None, "LocalSearchTB requires a task to be set" + if self.global_cfg.algo.ls.yield_only_accepted: + n_per_step = n // 2 + assert n % 2 == 0, "n must be divisible by 2" + else: + assert n % (1 + self.global_cfg.algo.ls.num_ls_steps) == 0, "n must be divisible by 1 + num_ls_steps" + n_per_step = n // (1 + self.global_cfg.algo.ls.num_ls_steps) + cond_info = {k: v[:n_per_step] for k, v in cond_info.items()} if cond_info is not None else None + random_action_prob = random_action_prob or 0.0 + data, accept_rate = self.graph_sampler.local_search_sample_from_model( + model, + n_per_step, + cond_info, + random_action_prob, + self.global_cfg.algo.ls, + self._compute_log_rewards, + ) + for t in data: + t["accept_rate"] = accept_rate + return data + + def _compute_log_rewards(self, trajs, cond_info): + """Sets trajs' log_reward key by querying the task.""" + cfg = self.global_cfg + self.cfg = cfg # TODO: fix TB so we don't need to do this + # Doing a bit of hijacking here to compute the log rewards, DataSource implements this for us. + # May be worth refactoring this to be more general eventually, this depends on `self` having ctx, task, and cfg + # attributes. + DataSource.set_traj_cond_info(self, trajs, cond_info) # type: ignore + DataSource.compute_properties(self, trajs) # type: ignore + DataSource.compute_log_rewards(self, trajs) # type: ignore + self.cfg = cfg.algo.tb + # trajs is modified in place, so no need to return anything diff --git a/src/gflownet/algo/multiobjective_reinforce.py b/src/gflownet/algo/multiobjective_reinforce.py index b1a636de..52314d03 100644 --- a/src/gflownet/algo/multiobjective_reinforce.py +++ b/src/gflownet/algo/multiobjective_reinforce.py @@ -34,7 +34,8 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens) # Forward pass of the model, returns a GraphActionCategorical and the optional bootstrap predictions - fwd_cat, log_reward_preds = model(batch, cond_info[batch_idx]) + batch.cond_info = cond_info[batch_idx] + fwd_cat, log_reward_preds = model(batch) # This is the log prob of each action in the trajectory log_prob = fwd_cat.log_prob(batch.actions) diff --git a/src/gflownet/algo/soft_q_learning.py b/src/gflownet/algo/soft_q_learning.py index a9d61aaa..99730279 100644 --- a/src/gflownet/algo/soft_q_learning.py +++ b/src/gflownet/algo/soft_q_learning.py @@ -4,13 +4,14 @@ from torch import Tensor from torch_scatter import scatter +from gflownet import GFNAlgorithm from gflownet.algo.graph_sampling import GraphSampler from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory from gflownet.utils.misc import get_worker_device -class SoftQLearning: +class SoftQLearning(GFNAlgorithm): def __init__( self, env: GraphBuildingEnv, @@ -33,6 +34,7 @@ def __init__( cfg: Config The experiment configuration """ + self.global_cfg = cfg self.ctx = ctx self.env = env self.max_len = cfg.algo.max_len @@ -147,7 +149,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Forward pass of the model, returns a GraphActionCategorical and per object predictions # Here we will interpret the logits of the fwd_cat as Q values - Q, per_state_preds = model(batch, cond_info[batch_idx]) + batch.cond_info = cond_info[batch_idx] + Q, per_state_preds = model(batch) if self.do_q_prime_correction: # First we need to estimate V_soft. We will use q_a' = \pi diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index eac57cc6..fd23db73 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -109,7 +109,7 @@ def __init__( """ self.ctx = ctx self.env = env - self.global_cfg = cfg + self.global_cfg = cfg # TODO: this belongs in the base class self.cfg = cfg.algo.tb self.max_len = cfg.algo.max_len self.max_nodes = cfg.algo.max_nodes @@ -147,9 +147,6 @@ def __init__( self._subtb_max_len = self.global_cfg.algo.max_len + 2 self._init_subtb(get_worker_device()) - def set_is_eval(self, is_eval: bool): - self.is_eval = is_eval - def create_training_data_from_own_samples( self, model: TrajectoryBalanceModel, @@ -177,17 +174,10 @@ def create_training_data_from_own_samples( - reward_pred: float, -100 if an illegal action is taken, predicted R(x) if bootstrapping, None otherwise - fwd_logprob: log Z + sum logprobs P_F - bck_logprob: sum logprobs P_B - - logZ: predicted log Z - loss: predicted loss (if bootstrapping) - is_valid: is the generated graph valid according to the env & ctx """ - dev = get_worker_device() - cond_info = cond_info.to(dev) if cond_info is not None else None data = self.graph_sampler.sample_from_model(model, n, cond_info, random_action_prob) - if cond_info is not None: - logZ_pred = model.logZ(cond_info) - for i in range(n): - data[i]["logZ"] = logZ_pred[i].item() return data def create_training_data_from_graphs( @@ -217,21 +207,21 @@ def create_training_data_from_graphs( """ if self.cfg.do_sample_p_b: assert model is not None and cond_info is not None and random_action_prob is not None - dev = get_worker_device() - cond_info = cond_info.to(dev) return self.graph_sampler.sample_backward_from_graphs( graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, random_action_prob ) + elif self.cfg.do_sample_using_masks: + return self.graph_sampler.sample_backward_from_graphs(graphs, None, cond_info, random_action_prob) trajs: List[Dict[str, Any]] = [{"traj": generate_forward_trajectory(i)} for i in graphs] for traj in trajs: n_back = [ self.env.count_backward_transitions(gp, check_idempotent=self.cfg.do_correct_idempotent) for gp, _ in traj["traj"][1:] ] + [1] - traj["bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(get_worker_device()) + traj["U_bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(get_worker_device()) traj["result"] = traj["traj"][-1][0] if self.cfg.do_parameterize_p_b: - traj["bck_a"] = [GraphAction(GraphActionType.Stop)] + [self.env.reverse(g, a) for g, a in traj["traj"]] + traj["bck_a"] = [GraphAction(GraphActionType.Pad)] + [self.env.reverse(g, a) for g, a in traj["traj"]] # There needs to be an additonal node when we're parameterizing P_B, # See sampling with parametrized P_B traj["traj"].append(deepcopy(traj["traj"][-1])) @@ -315,7 +305,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): ] batch = self.ctx.collate(torch_graphs) batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs]) - batch.log_p_B = torch.cat([i["bck_logprobs"] for i in trajs], 0) + batch.U_log_p_B = torch.cat([i["U_bck_logprobs"] for i in trajs], 0) batch.actions = torch.tensor(actions) if self.cfg.do_parameterize_p_b: batch.bck_actions = torch.tensor( @@ -352,7 +342,7 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch.bck_ip_lens = torch.tensor([len(i) for i in bck_ipa]) # compute_batch_losses expects these two optional values, if someone else doesn't fill them in, default to 0 - batch.num_offline = 0 + batch.num_offline = 0 # TODO: this has been half-deprecated, finish the job batch.num_online = 0 return batch @@ -402,12 +392,14 @@ def compute_batch_losses( # Forward pass of the model, returns a GraphActionCategorical representing the forward # policy P_F, optionally a backward policy P_B, and per-graph outputs (e.g. F(s) in SubTB). if self.cfg.do_parameterize_p_b: - fwd_cat, bck_cat, per_graph_out = model(batch, batched_cond_info) + batch.cond_info = batched_cond_info + fwd_cat, bck_cat, per_graph_out = model(batch) else: if self.model_is_autoregressive: - fwd_cat, per_graph_out = model(batch, cond_info, batched=True) + fwd_cat, per_graph_out = model(batch, batched=True) else: - fwd_cat, per_graph_out = model(batch, batched_cond_info) + batch.cond_info = batched_cond_info + fwd_cat, per_graph_out = model(batch) # Retreive the reward predictions for the full graphs, # i.e. the final graph of each trajectory log_reward_preds = per_graph_out[final_graph_idx, 0] @@ -472,7 +464,7 @@ def compute_batch_losses( # occasion masks out the first P_B of the "next" trajectory that we've shifted. log_p_B = torch.roll(log_p_B, -1, 0) * (1 - batch.is_sink) else: - log_p_B = batch.log_p_B + log_p_B = batch.U_log_p_B assert log_p_F.shape == log_p_B.shape if self.cfg.n_loss == NLoss.TB: @@ -503,13 +495,15 @@ def compute_batch_losses( if self.cfg.do_parameterize_p_b: # Life is pain, log_p_B is one unit too short for all trajs + log_p_B_unif = batch.U_log_p_B + assert log_p_B_unif.shape[0] == log_p_B.shape[0] - log_p_B_unif = torch.zeros_like(log_p_B) - for i, (s, e) in enumerate(zip(first_graph_idx, traj_cumlen)): - log_p_B_unif[s : e - 1] = batch.log_p_B[s - i : e - 1 - i] + # log_p_B_unif = torch.zeros_like(log_p_B) + # for i, (s, e) in enumerate(zip(first_graph_idx, traj_cumlen)): + # log_p_B_unif[s : e - 1] = batch.U_log_p_B[s - i : e - 1 - i] - if self.cfg.backward_policy == Backward.Uniform: - log_p_B = log_p_B_unif + # if self.cfg.backward_policy == Backward.Uniform: + # log_p_B = log_p_B_unif else: log_p_B_unif = log_p_B @@ -576,16 +570,23 @@ def compute_batch_losses( num_bootstrap = num_bootstrap or len(log_rewards) reward_losses = self._loss(log_rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap], self.reward_loss) - reward_loss = reward_losses.mean() * self.cfg.reward_loss_multiplier + reward_loss = reward_losses.mean() else: reward_loss = 0 + log_Z_reg_loss = (log_Z - self.cfg.regularize_logZ).pow(2).mean() if self.cfg.regularize_logZ is not None else 0 + n_loss = n_loss.mean() tb_loss = traj_losses.mean() - loss = tb_loss + reward_loss + self.cfg.n_loss_multiplier * n_loss + mle_loss = -traj_log_p_F.mean() + loss = ( + tb_loss * self.cfg.tb_loss_multiplier + + reward_loss * self.cfg.reward_loss_multiplier + + n_loss * self.cfg.n_loss_multiplier + + log_Z_reg_loss + + mle_loss * self.cfg.mle_loss_multiplier + ) info = { - "offline_loss": traj_losses[: batch.num_offline].mean() if batch.num_offline > 0 else 0, - "online_loss": traj_losses[batch.num_offline :].mean() if batch.num_online > 0 else 0, "reward_loss": reward_loss, "invalid_trajectories": invalid_mask.sum() / batch.num_online if batch.num_online > 0 else 0, "invalid_logprob": (invalid_mask * traj_log_p_F).sum() / (invalid_mask.sum() + 1e-4), @@ -595,9 +596,16 @@ def compute_batch_losses( "loss": loss.item(), "n_loss": n_loss, "tb_loss": tb_loss.item(), - "batch_entropy": -traj_log_p_F.mean(), + "batch_entropy": fwd_cat.entropy().mean(), "traj_lens": batch.traj_lens.float().mean(), + "avg_log_reward": clip_log_R.mean(), } + sources = set(batch.sources) + if len(sources) > 1: + for source in sources: + info[f"{source}_loss"] = ( + traj_losses[torch.as_tensor([i == source for i in batch.sources])].mean().item() + ) if self.ctx.has_n() and self.cfg.do_predict_n: info["n_loss_pred"] = scatter( (log_n_preds - batch.log_ns) ** 2, batch_idx, dim=0, dim_size=num_trajs, reduce="sum" @@ -609,6 +617,8 @@ def compute_batch_losses( d = d * d d[final_graph_idx] = 0 info["n_loss_maxent"] = scatter(d, batch_idx, dim=0, dim_size=num_trajs, reduce="sum").mean() + if self.cfg.mle_loss_multiplier != 0: + info["mle_loss"] = mle_loss.item() return loss, info diff --git a/src/gflownet/config.py b/src/gflownet/config.py index b66f238d..21e47aaa 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -79,6 +79,10 @@ class Config(StrictDataClass): The hostname of the machine on which the experiment is run pickle_mp_messages : bool Whether to pickle messages sent between processes (only relevant if num_workers > 0) + mp_buffer_size : Optional[int] + If specified, use a buffer of this size in bytes for passing tensors between processes. + Note that this is only relevant if num_workers > 0. + Also note that this will allocate `num_workers + 2 * number of wrapped objects` buffers. git_hash : Optional[str] The git hash of the current commit overwrite_existing_exp : bool @@ -86,12 +90,13 @@ class Config(StrictDataClass): """ desc: str = "noDesc" - log_dir: str = MISSING + log_dir: Optional[str] = MISSING device: str = "cuda" seed: int = 0 validate_every: int = 1000 checkpoint_every: Optional[int] = None store_all_checkpoints: bool = False + load_model_state: Optional[str] = None print_every: int = 100 start_at_step: int = 0 num_final_gen_steps: Optional[int] = None @@ -102,6 +107,9 @@ class Config(StrictDataClass): pickle_mp_messages: bool = False git_hash: Optional[str] = None overwrite_existing_exp: bool = False + mp_buffer_size: Optional[int] = None + world_size: int = 1 + rank: int = 0 algo: AlgoConfig = field(default_factory=AlgoConfig) model: ModelConfig = field(default_factory=ModelConfig) opt: OptimizerConfig = field(default_factory=OptimizerConfig) diff --git a/src/gflownet/data/config.py b/src/gflownet/data/config.py index 1e1b7f98..d4a51bc4 100644 --- a/src/gflownet/data/config.py +++ b/src/gflownet/data/config.py @@ -34,3 +34,5 @@ class ReplayConfig(StrictDataClass): hindsight_ratio: float = 0 num_from_replay: Optional[int] = None num_new_samples: Optional[int] = None + keep_highest_rewards: bool = False + keep_only_uniques: bool = False diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 85ede753..8b945d05 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -1,15 +1,20 @@ +import traceback import warnings from typing import Callable, Generator, List, Optional import numpy as np import torch from torch.utils.data import IterableDataset +from torch_geometric.data import Batch from gflownet import GFNAlgorithm, GFNTask from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer, detach_and_cpu -from gflownet.envs.graph_building_env import GraphBuildingEnvContext -from gflownet.utils.misc import get_worker_rng +from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnvContext, action_type_to_mask +from gflownet.envs.seq_building_env import SeqBatch +from gflownet.models.graph_transformer import GraphTransformerGFN +from gflownet.utils.misc import get_this_wid, get_worker_rng +from gflownet.utils.multiprocessing_proxy import BufferPickler, SharedPinnedBuffer def cycle_call(it): @@ -44,6 +49,8 @@ def __init__( self.global_step_count.share_memory_() self.global_step_count_lock = torch.multiprocessing.Lock() self.current_iter = start_at_step + self._err_tol = 10 + self.setup_mp_buffers() def add_sampling_hook(self, hook: Callable): """Add a hook that is called when sampling new trajectories. @@ -61,23 +68,57 @@ def __iter__(self): its = [i() for i in self.iterators] self.algo.set_is_eval(self.is_algo_eval) while True: - with self.global_step_count_lock: - self.current_iter = self.global_step_count.item() - self.global_step_count += 1 - iterator_outputs = [next(i, None) for i in its] - if any(i is None for i in iterator_outputs): - if not all(i is None for i in iterator_outputs): - warnings.warn("Some iterators are done, but not all. You may be mixing incompatible iterators.") - iterator_outputs = [i for i in iterator_outputs if i is not None] - else: - break - traj_lists, batch_infos = zip(*iterator_outputs) - trajs = sum(traj_lists, []) - # Merge all the dicts into one - batch_info = {} - for d in batch_infos: - batch_info.update(d) - yield self.create_batch(trajs, batch_info) + try: + with self.global_step_count_lock: + self.current_iter = self.global_step_count.item() + self.global_step_count += 1 + iterator_outputs = [next(i, None) for i in its] + if any(i is None for i in iterator_outputs): + if not all(i is None for i in iterator_outputs): + warnings.warn("Some iterators are done, but not all. You may be mixing incompatible iterators.") + iterator_outputs = [i for i in iterator_outputs if i is not None] + else: + break + traj_lists, batch_infos = zip(*iterator_outputs) + trajs = sum(traj_lists, []) + # Merge all the dicts into one + batch_info = {} + for d in batch_infos: + batch_info.update(d) + yield self.create_batch(trajs, batch_info) + self._err_tol = 10 # Reset the error tolerance, if we run into 10 consecutive errors, we'll break + except (Exception, RuntimeError) as e: + self._err_tol -= 1 + if self._err_tol == 0: + raise e + print(f"Error in DataSource: {e} [tol={self._err_tol}]") + # print full traceback + + traceback.print_exc() + continue + + def validate_batch(self, batch, trajs): + for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( + [(batch.bck_actions, self.ctx.bck_action_type_order)] + if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order") + else [] + ): + mask_cat = GraphActionCategorical( + batch, + [action_type_to_mask(t, batch) for t in atypes], + [GraphTransformerGFN.action_type_to_key(t) for t in atypes], + [None for _ in atypes], + ) + masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits, pad_value=1.0) + num_trajs = len(trajs) + batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens) + first_graph_idx = torch.zeros_like(batch.traj_lens) + torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) + if masked_action_is_used.sum() != 0: + invalid_idx = masked_action_is_used.argmax().item() + traj_idx = batch_idx[invalid_idx].item() + timestep = invalid_idx - first_graph_idx[traj_idx].item() + raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep]) def do_sample_model(self, model, num_from_policy, num_new_replay_samples=None): if num_new_replay_samples is not None: @@ -93,8 +134,8 @@ def iterator(): t = self.current_iter p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(num_samples, t) - # TODO: in the cond info refactor, pass the whole thing instead of just the encoding - trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info["encoding"], p) + trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info, p) + self.mark_all(trajs, source="sample") self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) @@ -123,7 +164,8 @@ def iterator(): p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(n_this_time, t) # TODO: in the cond info refactor, pass the whole thing instead of just the encoding - trajs = self.algo.create_training_data_from_own_samples(model, n_this_time, cond_info["encoding"], p) + trajs = self.algo.create_training_data_from_own_samples(model, n_this_time, cond_info, p) + self.mark_all(trajs, source="sample") self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) @@ -139,6 +181,7 @@ def do_sample_replay(self, num_samples): def iterator(): while self.active: trajs, *_ = self.replay_buffer.sample(num_samples) + self.mark_all(trajs, source="replay") self.relabel_in_hindsight(trajs) # This is a no-op if the hindsight ratio is 0 yield trajs, {} @@ -152,7 +195,8 @@ def iterator(): p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(num_samples, t) objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) - trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info["encoding"], p) + trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info, p) + self.mark_all(trajs, source="dataset") self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.set_traj_props(trajs, props) self.compute_log_rewards(trajs) @@ -170,7 +214,8 @@ def iterator(): # I'm also not a fan of encode_conditional_information, it assumes lots of things about what's passed to # it and the state of the program (e.g. validation mode) cond_info = self.task.encode_conditional_information(torch.stack([data[i] for i in idcs])) - trajs = self.algo.create_training_data_from_own_samples(model, len(idcs), cond_info["encoding"], p) + trajs = self.algo.create_training_data_from_own_samples(model, len(idcs), cond_info, p) + self.mark_all(trajs, source="dataset") self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) @@ -192,7 +237,8 @@ def iterator(): p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(num_samples, t) objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) - trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info["encoding"], p) + trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info, p) + self.mark_all(trajs, source="dataset") self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.set_traj_props(trajs, props) self.compute_log_rewards(trajs) @@ -201,6 +247,10 @@ def iterator(): self.iterators.append(iterator) return self + def mark_all(self, trajs, **kw): + for t in trajs: + t.update(kw) + def call_sampling_hooks(self, trajs): batch_info = {} # TODO: just pass trajs to the hooks and deprecate passing all those arguments @@ -218,8 +268,7 @@ def create_batch(self, trajs, batch_info): ci = torch.stack([t["cond_info"]["encoding"] for t in trajs]) log_rewards = torch.stack([t["log_reward"] for t in trajs]) batch = self.algo.construct_batch(trajs, ci, log_rewards) - batch.num_online = sum(t.get("is_online", 0) for t in trajs) - batch.num_offline = len(trajs) - batch.num_online + batch.sources = [t.get("source", "unknown") for t in trajs] batch.extra_info = batch_info if "preferences" in trajs[0]["cond_info"].keys(): batch.preferences = torch.stack([t["cond_info"]["preferences"] for t in trajs]) @@ -231,10 +280,13 @@ def create_batch(self, trajs, batch_info): batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) batch.obj_props = torch.stack([t["obj_props"] for t in trajs]) - return batch + # self.validate_batch(batch, trajs) + return self._maybe_put_in_mp_buffer(batch) def compute_properties(self, trajs, mark_as_online=False): """Sets trajs' obj_props and is_valid keys by querying the task.""" + if all("obj_props" in t for t in trajs): + return # TODO: refactor obj_props into properties valid_idcs = torch.tensor([i for i in range(len(trajs)) if trajs[i].get("is_valid", True)]).long() # fetch the valid trajectories endpoints @@ -250,13 +302,15 @@ def compute_properties(self, trajs, mark_as_online=False): all_fr[valid_idcs] = obj_props for i in range(len(trajs)): trajs[i]["obj_props"] = all_fr[i] - trajs[i]["is_online"] = mark_as_online + trajs[i]["is_online"] = mark_as_online # TODO: this is deprecated in favor of 'source'? # Override the is_valid key in case the task made some objs invalid for i in valid_idcs: trajs[i]["is_valid"] = True def compute_log_rewards(self, trajs): """Sets trajs' log_reward key by querying the task.""" + if all("log_reward" in t for t in trajs): + return obj_props = torch.stack([t["obj_props"] for t in trajs]) cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} log_rewards = self.task.cond_info_to_logreward(cond_info, obj_props) @@ -267,14 +321,26 @@ def compute_log_rewards(self, trajs): def send_to_replay(self, trajs): if self.replay_buffer is not None: for t in trajs: - self.replay_buffer.push(t, t["log_reward"], t["obj_props"], t["cond_info"], t["is_valid"]) + self.replay_buffer.push( + t, + t["log_reward"], + t["obj_props"], + t["cond_info"], + t["is_valid"], + unique_obj=self.ctx.get_unique_obj(t["result"]), + priority=t.get("priority", t["log_reward"].item()), + ) def set_traj_cond_info(self, trajs, cond_info): for i in range(len(trajs)): + if "cond_info" in trajs[i]: + continue trajs[i]["cond_info"] = {k: cond_info[k][i] for k in cond_info} def set_traj_props(self, trajs, props): for i in range(len(trajs)): + if "obj_props" in trajs[i]: + continue trajs[i]["obj_props"] = props[i] # TODO: refactor def relabel_in_hindsight(self, trajs): @@ -300,16 +366,19 @@ def sample_idcs(self, n, num_samples): def iterate_indices(self, n, num_samples): worker_info = torch.utils.data.get_worker_info() + num_workers = worker_info.num_workers if worker_info is not None else 1 + if torch.distributed.is_initialized(): + num_workers *= torch.distributed.get_world_size() + wid = get_this_wid() if n == 0: # Should we be raising an error here? warning? yield np.arange(0, 0) return - if worker_info is None: # no multi-processing + if num_workers == 1: # no multi-processing, no distributed start, end, wid = 0, n, -1 else: # split the data into chunks (per-worker) - nw = worker_info.num_workers - wid = worker_info.id + nw = num_workers start, end = int(np.round(n / nw * wid)), int(np.round(n / nw * (wid + 1))) if end - start <= num_samples: @@ -319,3 +388,20 @@ def iterate_indices(self, n, num_samples): yield np.arange(i, i + num_samples) if i + num_samples < end: yield np.arange(i + num_samples, end) + + def setup_mp_buffers(self): + if self.cfg.num_workers > 0: + self.mp_buffer_size = self.cfg.mp_buffer_size + if self.mp_buffer_size: + self.result_buffer = [SharedPinnedBuffer(self.mp_buffer_size) for _ in range(self.cfg.num_workers)] + else: + self.mp_buffer_size = None + + def _maybe_put_in_mp_buffer(self, batch): + if self.mp_buffer_size: + if not (isinstance(batch, (Batch, SeqBatch))): + warnings.warn(f"Expected a Batch object, but got {type(batch)}. Not using mp buffers.") + return batch + return (BufferPickler(self.result_buffer[self._wid]).dumps(batch), self._wid) + else: + return batch diff --git a/src/gflownet/data/replay_buffer.py b/src/gflownet/data/replay_buffer.py index 7fc95024..197258b2 100644 --- a/src/gflownet/data/replay_buffer.py +++ b/src/gflownet/data/replay_buffer.py @@ -1,4 +1,6 @@ -from typing import List +import heapq +from threading import Lock +from typing import Any, List import numpy as np import torch @@ -13,27 +15,61 @@ def __init__(self, cfg: Config): Replay buffer for storing and sampling arbitrary data (e.g. transitions or trajectories) In self.push(), the buffer detaches any torch tensor and sends it to the CPU. """ - self.capacity = cfg.replay.capacity - self.warmup = cfg.replay.warmup + self.capacity = cfg.replay.capacity or int(1e6) + self.warmup = cfg.replay.warmup or 0 assert self.warmup <= self.capacity, "ReplayBuffer warmup must be smaller than capacity" self.buffer: List[tuple] = [] self.position = 0 - def push(self, *args): + self.treat_as_heap = cfg.replay.keep_highest_rewards + self.filter_uniques = cfg.replay.keep_only_uniques + self._uniques: set[Any] = set() + + self._lock = Lock() + + def push(self, *args, unique_obj=None, priority=None): + """unique_obj must be hashable and comparable""" if len(self.buffer) == 0: self._input_size = len(args) else: assert self._input_size == len(args), "ReplayBuffer input size must be constant" - if len(self.buffer) < self.capacity: - self.buffer.append(None) + if self.filter_uniques and unique_obj in self._uniques: + return args = detach_and_cpu(args) - self.buffer[self.position] = args - self.position = (self.position + 1) % self.capacity + self._lock.acquire() + if self.treat_as_heap: + if len(self.buffer) >= self.capacity: + if priority is None or priority > self.buffer[0][0]: + # We will use self.position for tie-breaking + *_, pop_unique = heapq.heappushpop(self.buffer, (priority, self.position, args, unique_obj)) + self.position += 1 + if self.filter_uniques: + self._uniques.remove(pop_unique) + self._uniques.add(unique_obj) + else: + pass # If the priority is lower than the lowest in the heap, we don't add it + else: + heapq.heappush(self.buffer, (priority, self.position, args, unique_obj)) + self.position += 1 + if self.filter_uniques: + self._uniques.add(unique_obj) + else: + if len(self.buffer) < self.capacity: + self.buffer.append(()) + if self.filter_uniques: + if self.position == 0 and len(self.buffer) == self.capacity: + # We're about to wrap around, so remove the oldest element + self._uniques.remove(self.buffer[0][2]) + self.buffer[self.position] = (priority, args, unique_obj) + if self.filter_uniques: + self._uniques.add(unique_obj) + self.position = (self.position + 1) % self.capacity + self._lock.release() def sample(self, batch_size): idxs = get_worker_rng().choice(len(self.buffer), batch_size) - out = list(zip(*[self.buffer[idx] for idx in idxs])) + out = list(zip(*[self.buffer[idx][2] for idx in idxs])) for i in range(len(out)): # stack if all elements are numpy arrays or torch tensors # (this is much more efficient to send arrays through multiprocessing queues) diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index bac09959..238f08c6 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -80,6 +80,7 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu self.edges_are_duplicated = True self.edges_are_unordered = False self.fail_on_missing_attr = True + self.use_strict_edge_filling = False # Order in which models have to output logits self.action_type_order = [GraphActionType.Stop, GraphActionType.AddNode, GraphActionType.SetEdgeAttr] @@ -105,6 +106,9 @@ def ActionIndex_to_GraphAction(self, g: gd.Data, aidx: ActionIndex, fwd: bool = action: GraphAction A graph action whose type is one of Stop, AddNode, or SetEdgeAttr. """ + if aidx.action_type == -1: + # Pad action + return GraphAction(GraphActionType.Pad) if fwd: t = self.action_type_order[aidx.action_type] else: @@ -147,6 +151,8 @@ def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionI A triple describing the type of action, and the corresponding row and column index for the corresponding Categorical matrix. """ + if action.action is GraphActionType.Pad: + return ActionIndex(action_type=-1, row_idx=0, col_idx=0) # Find the index of the action type, privileging the forward actions for u in [self.action_type_order, self.bck_action_type_order]: if action.action in u: @@ -222,6 +228,8 @@ def graph_to_Data(self, g: Graph) -> gd.Data: # If there are unspecified attachment points, we will disallow the agent from using the stop # action. has_unfilled_attach = False + first_unfilled_attach = True + has_fully_unfilled_attach = False for i, e in enumerate(g.edges): ed = g.edges[e] if len(ed): @@ -229,22 +237,45 @@ def graph_to_Data(self, g: Graph) -> gd.Data: node_is_connected_to_edge_with_attr[e[1]] = 1 a = ed.get("src_attach", -1) b = ed.get("dst_attach", -1) + if a < 0 or b < 0: + has_unfilled_attach = True + if self.use_strict_edge_filling: + remove_edge_attr_mask *= 0 + if a < 0 and b < 0: + has_fully_unfilled_attach = True if a >= 0: attached[e[0]].append(a) - remove_edge_attr_mask[i, 0] = 1 - else: - has_unfilled_attach = True + if self.use_strict_edge_filling: + # Since the agent has to fill in the attachment points before adding a new node, or stopping, the + # reverse is that if there is a half-filled edge attachment, the only valid move is to unset the + # other half (i.e. whichever is the first edge we find like that) + # We also will only end up being able to do that to leaf nodes (degree 1) + if degrees[e[0]] == 1 or degrees[e[1]] == 1: + remove_edge_attr_mask[i, 0] = ( + 1 if not has_unfilled_attach or has_unfilled_attach and first_unfilled_attach else 0 + ) + else: + remove_edge_attr_mask[i, 0] = 1 if b >= 0: attached[e[1]].append(b) - remove_edge_attr_mask[i, 1] = 1 - else: - has_unfilled_attach = True + if self.use_strict_edge_filling: + if degrees[e[0]] == 1 or degrees[e[1]] == 1: + remove_edge_attr_mask[i, 1] = ( + 1 if not has_unfilled_attach or has_unfilled_attach and first_unfilled_attach else 0 + ) + else: + remove_edge_attr_mask[i, 1] = 1 + + if has_unfilled_attach: + first_unfilled_attach = False # The node must be connected to at most 1 other node and in the case where it is # connected to exactly one other node, the edge connecting them must not have any # attributes. if len(g): remove_node_mask = (node_is_connected * (1 - node_is_connected_to_edge_with_attr)).reshape((-1, 1)) + if self.use_strict_edge_filling: + remove_node_mask = remove_node_mask * (1 - (has_unfilled_attach and not has_fully_unfilled_attach)) # Here we encode the attached atoms in the edge features, as well as mask out attached # atoms. @@ -271,6 +302,8 @@ def graph_to_Data(self, g: Graph) -> gd.Data: else np.ones((1, 1), np.float32) ) add_node_mask = add_node_mask * np.ones((x.shape[0], self.num_new_node_values), np.float32) + if self.use_strict_edge_filling: + add_node_mask = add_node_mask * (1 - has_unfilled_attach) stop_mask = zeros((1, 1)) if has_unfilled_attach or not len(g) else ones((1, 1)) data = gd.Data( diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index b10b228d..fcd89cd9 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -67,6 +67,8 @@ class GraphActionType(enum.Enum): RemoveEdge = enum.auto() RemoveNodeAttr = enum.auto() RemoveEdgeAttr = enum.auto() + # The Pad action is always unmasked, and always has logprob 0 + Pad = -1 @cached_property def cname(self): @@ -140,7 +142,7 @@ class GraphBuildingEnv: - we can generate a legal action for any attribute that isn't a default one. """ - def __init__(self, allow_add_edge=True, allow_node_attr=True, allow_edge_attr=True): + def __init__(self, allow_add_edge=True, allow_node_attr=True, allow_edge_attr=True, graph_cls=None): """A graph building environment instance Parameters @@ -152,13 +154,16 @@ def __init__(self, allow_add_edge=True, allow_node_attr=True, allow_edge_attr=Tr if True, allows this action and computes SetNodeAttr parents allow_edge_attr: bool if True, allows this action and computes SetEdgeAttr parents + graph_cls: class + if None, defaults to Graph. Called in new(). """ self.allow_add_edge = allow_add_edge self.allow_node_attr = allow_node_attr self.allow_edge_attr = allow_edge_attr + self.graph_cls = Graph if graph_cls is None else graph_cls def new(self): - return Graph() + return self.graph_cls() def step(self, g: Graph, action: GraphAction) -> Graph: """Step forward the given graph state with an action @@ -178,12 +183,12 @@ def step(self, g: Graph, action: GraphAction) -> Graph: gp = g.copy() if action.action is GraphActionType.AddEdge: a, b = action.source, action.target - assert self.allow_add_edge - assert a in g and b in g + # assert self.allow_add_edge + # assert a in g and b in g if a > b: a, b = b, a - assert a != b - assert not g.has_edge(a, b) + # assert a != b + # assert not g.has_edge(a, b) # Ideally the FA underlying this must only be able to send # create_edge actions which respect this a Graph: elif action.action is GraphActionType.AddNode: if len(g) == 0: - assert action.source == 0 # TODO: this may not be useful + # assert action.source == 0 # TODO: this may not be useful gp.add_node(0, v=action.value) else: - assert action.source in g.nodes + # assert action.source in g.nodes e = [action.source, max(g.nodes) + 1] # if kw and 'relabel' in kw: # e[1] = kw['relabel'] # for `parent` consistency, allow relabeling - assert not g.has_edge(*e) + # assert not g.has_edge(*e) gp.add_node(e[1], v=action.value) gp.add_edge(*e) elif action.action is GraphActionType.SetNodeAttr: - assert self.allow_node_attr - assert action.source in gp.nodes + # assert self.allow_node_attr + # assert action.source in gp.nodes # For some "optional" attributes like wildcard atoms, we indicate that they haven't been # chosen by the 'None' value. Here we make sure that either the attribute doesn't # exist, or that it's an optional attribute that hasn't yet been set. - assert action.attr not in gp.nodes[action.source] or gp.nodes[action.source][action.attr] is None + # assert action.attr not in gp.nodes[action.source] or gp.nodes[action.source][action.attr] is None gp.nodes[action.source][action.attr] = action.value elif action.action is GraphActionType.SetEdgeAttr: - assert self.allow_edge_attr - assert g.has_edge(action.source, action.target) - assert action.attr not in gp.edges[(action.source, action.target)] + # assert self.allow_edge_attr + # assert g.has_edge(action.source, action.target) + # assert action.attr not in gp.edges[(action.source, action.target)] gp.edges[(action.source, action.target)][action.attr] = action.value elif action.action is GraphActionType.RemoveNode: - assert g.has_node(action.source) + # assert g.has_node(action.source) gp = graph_without_node(gp, action.source) elif action.action is GraphActionType.RemoveNodeAttr: - assert g.has_node(action.source) + # assert g.has_node(action.source) gp = graph_without_node_attr(gp, action.source, action.attr) elif action.action is GraphActionType.RemoveEdge: - assert g.has_edge(action.source, action.target) + # assert g.has_edge(action.source, action.target) gp = graph_without_edge(gp, (action.source, action.target)) elif action.action is GraphActionType.RemoveEdgeAttr: - assert g.has_edge(action.source, action.target) + # assert g.has_edge(action.source, action.target) gp = graph_without_edge_attr(gp, (action.source, action.target), action.attr) else: raise ValueError(f"Unknown action type {action.action}", action.action) - gp.clear_cache() # Invalidate cached properties since we've modified the graph + # gp.clear_cache() # Invalidate cached properties since we've modified the graph return gp def parents(self, g: Graph): @@ -315,17 +320,22 @@ def count_backward_transitions(self, g: Graph, check_idempotent: bool = False): return len(self.parents(g)) c = 0 deg = [g.degree[i] for i in range(len(g.nodes))] + has_connected_edge_attr = [False] * len(g.nodes) + bridges = g.bridges() for a, b in g.edges: if deg[a] > 1 and deg[b] > 1 and len(g.edges[(a, b)]) == 0: # Can only remove edges connected to non-leaves and without # attributes (the agent has to remove the attrs, then remove # the edge). Removal cannot disconnect the graph. - new_g = graph_without_edge(g, (a, b)) - if nx.algorithms.is_connected(new_g): + if (a, b) not in bridges and (b, a) not in bridges: c += 1 - c += len(g.edges[(a, b)]) # One action per edge attr + num_attrs = len(g.edges[(a, b)]) + c += num_attrs # One action per edge attr + if num_attrs > 0: + has_connected_edge_attr[a] = True + has_connected_edge_attr[b] = True for i in g.nodes: - if deg[i] == 1 and len(g.nodes[i]) == 1 and len(g.edges[list(g.edges(i))[0]]) == 0: + if deg[i] == 1 and len(g.nodes[i]) == 1 and not has_connected_edge_attr[i]: c += 1 c += len(g.nodes[i]) - 1 # One action per node attr, except 'v' if len(g.nodes) == 1 and len(g.nodes[i]) == 1: @@ -335,7 +345,7 @@ def count_backward_transitions(self, g: Graph, check_idempotent: bool = False): def reverse(self, g: Graph, ga: GraphAction): if ga.action == GraphActionType.Stop: - return ga + return GraphAction(GraphActionType.Pad) # because we can't reverse a Stop action elif ga.action == GraphActionType.AddNode: return GraphAction(GraphActionType.RemoveNode, source=len(g.nodes)) elif ga.action == GraphActionType.AddEdge: @@ -815,7 +825,13 @@ def argmax( # if it wants to convert these indices to env-compatible actions return argmaxes - def log_prob(self, actions: List[ActionIndex], logprobs: torch.Tensor = None, batch: torch.Tensor = None): + def log_prob( + self, + actions: List[ActionIndex], + logprobs: torch.Tensor = None, + batch: torch.Tensor = None, + pad_value: float = 0.0, + ): """The log-probability of a list of action tuples, effectively indexes `logprobs` using internal slice indices. @@ -840,12 +856,14 @@ def log_prob(self, actions: List[ActionIndex], logprobs: torch.Tensor = None, ba logprobs = self.logsoftmax() if batch is None: batch = torch.arange(N, device=self.dev) + pad = torch.tensor(pad_value, device=self.dev) # We want to do the equivalent of this: # [logprobs[t][row + self.slice[t][i], col] for i, (t, row, col) in zip(batch, actions)] # but faster. # each action is a 3-tuple ActionIndex (type, row, column), where type is the index of the action type group. - actions = torch.as_tensor(actions, device=self.dev, dtype=torch.long) + unclamped_actions = torch.as_tensor(actions, device=self.dev, dtype=torch.long) + actions = unclamped_actions.clamp(0) # Clamp to 0 to avoid the -1 Pad action assert actions.shape[0] == batch.shape[0] # Check there are as many actions as batch indices # To index the log probabilities efficiently, we will ravel the array, and compute the # indices of the raveled actions. @@ -870,7 +888,10 @@ def log_prob(self, actions: List[ActionIndex], logprobs: torch.Tensor = None, ba # This is the last index in the raveled tensor, therefore the offset is just the column value col_offsets = actions[:, 2] # Index the flattened array - return all_logprobs[t_offsets + row_offsets + col_offsets] + raw_logprobs = all_logprobs[t_offsets + row_offsets + col_offsets] + # Now we replaced the Pad actions with 0s + logprobs = raw_logprobs.where(unclamped_actions[:, 0] != GraphActionType.Pad.value, pad) + return logprobs def entropy(self, logprobs=None): """The entropy for each graph categorical in the batch @@ -887,10 +908,12 @@ def entropy(self, logprobs=None): """ if logprobs is None: logprobs = self.logsoftmax() + masks = self.action_masks if self.action_masks is not None else [None] * len(logprobs) entropy = -sum( [ - scatter(i * i.exp(), b, dim=0, dim_size=self.num_graphs, reduce="sum").sum(1) - for i, b in zip(logprobs, self.batch) + scatter(im, b, dim=0, dim_size=self.num_graphs, reduce="sum").sum(1) + for i, b, m in zip(logprobs, self.batch, masks) + for im in [i.exp() * i.masked_fill(m == 0.0, 0) if m is not None else i.exp() * i] ] ) return entropy @@ -1020,8 +1043,13 @@ def log_n(self, g) -> float: def traj_log_n(self, traj): return [self.log_n(g) for g, _ in traj] + def get_unique_obj(self, g: Graph): + return None + def action_type_to_mask(t: GraphActionType, gbatch: gd.Batch, assert_mask_exists: bool = False): + if t == GraphActionType.Pad: + return torch.ones((1, 1), device=gbatch.x.device) if assert_mask_exists: assert hasattr(gbatch, t.mask_name), f"Mask {t.mask_name} not found in graph data" return getattr(gbatch, t.mask_name) if hasattr(gbatch, t.mask_name) else torch.ones((1, 1), device=gbatch.x.device) diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index e24ff8f2..10134a11 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -13,6 +13,18 @@ DEFAULT_CHIRAL_TYPES = [ChiralType.CHI_UNSPECIFIED, ChiralType.CHI_TETRAHEDRAL_CW, ChiralType.CHI_TETRAHEDRAL_CCW] +try: + from gflownet._C import Data_collate + from gflownet._C import Graph as C_Graph + from gflownet._C import GraphDef, mol_graph_to_Data + + C_Graph_available = True +except ImportError: + import warnings + + warnings.warn("Could not import mol_graph_to_Data, Graph, GraphDef from _C, using pure python implementation") + C_Graph_available = False + class MolBuildingEnvContext(GraphBuildingEnvContext): """A specification of what is being generated for a GraphBuildingEnv @@ -61,12 +73,12 @@ def __init__( """ # idx 0 has to coincide with the default value self.atom_attr_values = { - "v": atoms + ["*"], + "v": atoms, "chi": chiral_types, "charge": charges, "expl_H": expl_H_range, "no_impl": [False, True], - "fill_wildcard": [None] + atoms, # default is, there is nothing + # "fill_wildcard": [None] + atoms, # default is, there is nothing } self.num_rw_feat = num_rw_feat self.max_nodes = max_nodes @@ -157,13 +169,29 @@ def __init__( GraphActionType.RemoveEdge, GraphActionType.RemoveEdgeAttr, ] + if C_Graph_available: + self.graph_def = GraphDef(self.atom_attr_values, self.bond_attr_values) + self.graph_cls = self._make_C_graph + assert charges == [0, 1, -1], "C impl quirk: charges must be [0, 1, -1]" + else: + self.graph_cls = Graph + + def _make_C_graph(self): + return C_Graph(self.graph_def) def ActionIndex_to_GraphAction(self, g: gd.Data, aidx: ActionIndex, fwd: bool = True): """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction""" + if aidx.action_type == -1: + # Pad action + return GraphAction(GraphActionType.Pad) if fwd: t = self.action_type_order[aidx.action_type] else: t = self.bck_action_type_order[aidx.action_type] + + if self.graph_cls is not Graph: + return g.mol_aidx_to_GraphAction((aidx.action_type, aidx.row_idx, aidx.col_idx), t) + if t is GraphActionType.Stop: return GraphAction(t) elif t is GraphActionType.AddNode: @@ -196,6 +224,8 @@ def ActionIndex_to_GraphAction(self, g: gd.Data, aidx: ActionIndex, fwd: bool = def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionIndex: """Translate a GraphAction to an index tuple""" + if action.action is GraphActionType.Pad: + return ActionIndex(action_type=-1, row_idx=0, col_idx=0) for u in [self.action_type_order, self.bck_action_type_order]: if action.action in u: type_idx = u.index(action.action) @@ -203,6 +233,9 @@ def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionI else: raise ValueError(f"Unknown action type {action.action}") + if self.graph_cls is not Graph: + return (type_idx,) + g.mol_GraphAction_to_aidx(action) + if action.action is GraphActionType.Stop: row = col = 0 elif action.action is GraphActionType.AddNode: @@ -255,6 +288,10 @@ def GraphAction_to_ActionIndex(self, g: gd.Data, action: GraphAction) -> ActionI def graph_to_Data(self, g: Graph) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance""" + if self.graph_cls is not Graph: + cond_info = None # Todo: Implement this + return mol_graph_to_Data(g, self, torch, cond_info) + x = np.zeros((max(1, len(g.nodes)), self.num_node_dim - self.num_rw_feat), dtype=np.float32) x[0, -1] = len(g.nodes) == 0 add_node_mask = np.ones((x.shape[0], self.num_new_node_values), dtype=np.float32) @@ -394,11 +431,13 @@ def graph_to_Data(self, g: Graph) -> gd.Data: def collate(self, graphs: List[gd.Data]): """Batch Data instances""" + if self.graph_cls is not Graph: + return Data_collate(graphs, ["edge_index", "non_edge_index"]) return gd.Batch.from_data_list(graphs, follow_batch=["edge_index", "non_edge_index"]) def obj_to_graph(self, mol: Mol) -> Graph: """Convert an RDMol to a Graph""" - g = Graph() + g = self.graph_cls() mol = Mol(mol) # Make a copy if not self.allow_explicitly_aromatic: # If we disallow aromatic bonds, ask rdkit to Kekulize mol and remove aromatic bond flags @@ -470,3 +509,12 @@ def object_to_log_repr(self, g: Graph): return Chem.MolToSmiles(mol) except Exception: return "" + + def get_unique_obj(self, g: Graph): + """Convert a Graph to a canonical SMILES representation""" + try: + mol = self.graph_to_obj(g) + assert mol is not None + return Chem.CanonSmiles(Chem.MolToSmiles(mol)) + except Exception: + return "" diff --git a/src/gflownet/envs/seq_building_env.py b/src/gflownet/envs/seq_building_env.py index 0e0281a9..f47fa6ca 100644 --- a/src/gflownet/envs/seq_building_env.py +++ b/src/gflownet/envs/seq_building_env.py @@ -70,6 +70,7 @@ def __init__(self, seqs: List[torch.Tensor], pad: int): # Since we're feeding this batch object to graph-based algorithms, we have to use this naming, but this # is the total number of timesteps. self.num_graphs = self.lens.sum().item() + self.cond_info: torch.Tensor # May be set later def to(self, device): for name in dir(self): diff --git a/src/gflownet/models/config.py b/src/gflownet/models/config.py index 912ff9f8..bf95c40d 100644 --- a/src/gflownet/models/config.py +++ b/src/gflownet/models/config.py @@ -8,8 +8,9 @@ class GraphTransformerConfig(StrictDataClass): num_heads: int = 2 ln_type: str = "pre" - num_mlp_layers: int = 0 + num_mlp_layers: int = 1 concat_heads: bool = True + conv_type: str = "Transformer" class SeqPosEnc(int, Enum): diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 366e4390..3d5f91b7 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -20,6 +20,61 @@ def mlp(n_in, n_hid, n_out, n_layer, act=nn.LeakyReLU): return nn.Sequential(*sum([[nn.Linear(n[i], n[i + 1]), act()] for i in range(n_layer + 1)], [])[:-1]) +class GTGENLayer(nn.Module): + def __init__(self, num_emb, num_heads, concat, num_mlp_layers, ln_type): + super().__init__() + assert ln_type in ["pre", "post"] + self.ln_type = ln_type + n_att = num_emb * num_heads if concat else num_emb + self.gen = gnn.GENConv(num_emb, num_emb, num_layers=1, aggr="add", norm=None) + self.conv = gnn.TransformerConv(num_emb * 2, n_att // num_heads, edge_dim=num_emb, heads=num_heads) + self.linear = nn.Linear(n_att, num_emb) + self.norm1 = gnn.LayerNorm(num_emb, affine=False) + self.ff = mlp(num_emb, num_emb * 4, num_emb, num_mlp_layers) + self.norm2 = gnn.LayerNorm(num_emb, affine=False) + self.cscale = nn.Linear(num_emb, num_emb * 2) + + def forward(self, x, edge_index, edge_attr, batch, c): + cs = self.cscale(c) + if self.ln_type == "post": + agg = self.gen(x, edge_index, edge_attr) + l_h = self.linear(self.conv(torch.cat([x, agg], 1), edge_index, edge_attr)) + scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] + x = self.norm1(x + l_h * scale + shift, batch) + x = self.norm2(x + self.ff(x), batch) + else: + x_norm = self.norm1(x, batch) + agg = self.gen(x_norm, edge_index, edge_attr) + l_h = self.linear(self.conv(torch.cat([x, agg], 1), edge_index, edge_attr)) + scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] + x = x + l_h * scale + shift + x = x + self.ff(self.norm2(x, batch)) + return x + + +class GPSLayer(nn.Module): + def __init__(self, num_emb, num_heads, num_mlp_layers, residual=False): + super().__init__() + self.conv = gnn.GPSConv( + num_emb, + gnn.GINEConv(mlp(num_emb, num_emb, num_emb, num_mlp_layers), edge_dim=num_emb), + num_heads, + norm="layer_norm", + ) + self.cscale = nn.Linear(num_emb, num_emb * 2) + self.residual = residual + + def forward(self, x, edge_index, edge_attr, batch, c): + cs = self.cscale(c) + l_h = self.conv(x, edge_index, batch, edge_attr=edge_attr) + scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] + if self.residual: + x = x + l_h * scale + shift + else: + x = l_h * scale + shift + return x + + class GraphTransformer(nn.Module): """An agnostic GraphTransformer class, and the main model used by other model classes @@ -34,7 +89,18 @@ class GraphTransformer(nn.Module): """ def __init__( - self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, num_noise=0, ln_type="pre", concat=True + self, + x_dim, + e_dim, + g_dim, + num_emb=64, + num_layers=3, + num_heads=2, + num_noise=0, + ln_type="pre", + concat=True, + num_mlp_layers=1, + conv_type="Transformer", ): """ Parameters @@ -65,32 +131,20 @@ def __init__( super().__init__() self.num_layers = num_layers self.num_noise = num_noise - assert ln_type in ["pre", "post"] self.ln_type = ln_type + self.conv_type = conv_type self.x2h = mlp(x_dim + num_noise, num_emb, num_emb, 2) self.e2h = mlp(e_dim, num_emb, num_emb, 2) self.c2h = mlp(max(1, g_dim), num_emb, num_emb, 2) - n_att = num_emb * num_heads if concat else num_emb - self.graph2emb = nn.ModuleList( - sum( - [ - [ - gnn.GENConv(num_emb, num_emb, num_layers=1, aggr="add", norm=None), - gnn.TransformerConv(num_emb * 2, n_att // num_heads, edge_dim=num_emb, heads=num_heads), - nn.Linear(n_att, num_emb), - gnn.LayerNorm(num_emb, affine=False), - mlp(num_emb, num_emb * 4, num_emb, 1), - gnn.LayerNorm(num_emb, affine=False), - nn.Linear(num_emb, num_emb * 2), - ] - for i in range(self.num_layers) - ], - [], + if conv_type == "Transformer": + self.gnn = nn.ModuleList( + [GTGENLayer(num_emb, num_heads, concat, num_mlp_layers, ln_type) for _ in range(num_layers)] ) - ) + elif conv_type == "GPS": + self.gnn = nn.ModuleList([GPSLayer(num_emb, num_heads, num_mlp_layers) for _ in range(num_layers)]) - def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): + def forward(self, g: gd.Batch): """Forward pass Parameters @@ -112,41 +166,37 @@ def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): x = g.x o = self.x2h(x) e = self.e2h(g.edge_attr) - c = self.c2h(cond if cond is not None else torch.ones((g.num_graphs, 1), device=g.x.device)) + c = self.c2h(g.cond_info if g.cond_info is not None else torch.ones((g.num_graphs, 1), device=g.x.device)) num_total_nodes = g.x.shape[0] - # Augment the edges with a new edge to the conditioning - # information node. This new node is connected to every node - # within its graph. - u, v = torch.arange(num_total_nodes, device=o.device), g.batch + num_total_nodes - aug_edge_index = torch.cat([g.edge_index, torch.stack([u, v]), torch.stack([v, u])], 1) - e_p = torch.zeros((num_total_nodes * 2, e.shape[1]), device=g.x.device) - e_p[:, 0] = 1 # Manually create a bias term - aug_e = torch.cat([e, e_p], 0) - aug_edge_index, aug_e = add_self_loops(aug_edge_index, aug_e, "mean") - aug_batch = torch.cat([g.batch, torch.arange(c.shape[0], device=o.device)], 0) - - # Append the conditioning information node embedding to o - o = torch.cat([o, c], 0) + if self.conv_type == "Transformer": + # Augment the edges with a new edge to the conditioning + # information node. This new node is connected to every node + # within its graph. + u, v = torch.arange(num_total_nodes, device=o.device), g.batch + num_total_nodes + aug_edge_index = torch.cat([g.edge_index, torch.stack([u, v]), torch.stack([v, u])], 1) + e_p = torch.zeros((num_total_nodes * 2, e.shape[1]), device=g.x.device) + e_p[:, 0] = 1 # Manually create a bias term + aug_e = torch.cat([e, e_p], 0) + aug_edge_index, aug_e = add_self_loops(aug_edge_index, aug_e, "mean") + aug_batch = torch.cat([g.batch, torch.arange(c.shape[0], device=o.device)], 0) + # Append the conditioning information node embedding to o + o = torch.cat([o, c], 0) + else: + # GPS doesn't really need a virtual node since it's doing global pairwise attention + o = o + c[g.batch] + aug_batch = g.batch + aug_edge_index, aug_e = g.edge_index, e + for i in range(self.num_layers): - # Run the graph transformer forward - gen, trans, linear, norm1, ff, norm2, cscale = self.graph2emb[i * 7 : (i + 1) * 7] - cs = cscale(c[aug_batch]) - if self.ln_type == "post": - agg = gen(o, aug_edge_index, aug_e) - l_h = linear(trans(torch.cat([o, agg], 1), aug_edge_index, aug_e)) - scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] - o = norm1(o + l_h * scale + shift, aug_batch) - o = norm2(o + ff(o), aug_batch) - else: - o_norm = norm1(o, aug_batch) - agg = gen(o_norm, aug_edge_index, aug_e) - l_h = linear(trans(torch.cat([o_norm, agg], 1), aug_edge_index, aug_e)) - scale, shift = cs[:, : l_h.shape[1]], cs[:, l_h.shape[1] :] - o = o + l_h * scale + shift - o = o + ff(norm2(o, aug_batch)) - - o_final = o[: -c.shape[0]] - glob = torch.cat([gnn.global_mean_pool(o_final, g.batch), o[-c.shape[0] :]], 1) + o = self.gnn[i](o, aug_edge_index, aug_e, aug_batch, c[aug_batch]) + + if self.conv_type == "Transformer": + # Remove the conditioning information node embedding + o_final = o[: -c.shape[0]] + glob = torch.cat([gnn.global_mean_pool(o_final, g.batch), o[-c.shape[0] :]], 1) + else: + o_final = o + glob = gnn.global_mean_pool(o_final, g.batch) return o_final, glob @@ -199,11 +249,13 @@ def __init__( num_heads=cfg.model.graph_transformer.num_heads, ln_type=cfg.model.graph_transformer.ln_type, concat=cfg.model.graph_transformer.concat_heads, + num_mlp_layers=cfg.model.graph_transformer.num_mlp_layers, + conv_type=cfg.model.graph_transformer.conv_type, ) self.env_ctx = env_ctx num_emb = cfg.model.num_emb num_final = num_emb - num_glob_final = num_emb * 2 + num_glob_final = num_emb * 2 if cfg.model.graph_transformer.conv_type == "Transformer" else num_emb num_edge_feat = num_emb if env_ctx.edges_are_unordered else num_emb * 2 self.edges_are_duplicated = env_ctx.edges_are_duplicated self.edges_are_unordered = env_ctx.edges_are_unordered @@ -240,6 +292,7 @@ def __init__( self.emb2graph_out = mlp(num_glob_final, num_emb, num_graph_out, cfg.model.graph_transformer.num_mlp_layers) # TODO: flag for this self._logZ = mlp(max(1, env_ctx.num_cond_dim), num_emb * 2, 1, 2) + self.logit_scaler = mlp(max(1, env_ctx.num_cond_dim), num_emb * 2, 1, 2) def logZ(self, cond_info: Optional[torch.Tensor]): if cond_info is None: @@ -247,16 +300,21 @@ def logZ(self, cond_info: Optional[torch.Tensor]): return self._logZ(cond_info) def _make_cat(self, g: gd.Batch, emb: Dict[str, Tensor], action_types: list[GraphActionType]): - return GraphActionCategorical( + cat = GraphActionCategorical( g, raw_logits=[self.mlps[t.cname](emb[self._action_type_to_graph_part[t]]) for t in action_types], keys=[self._action_type_to_key[t] for t in action_types], action_masks=[action_type_to_mask(t, g) for t in action_types], types=action_types, ) + sc = self.logit_scaler( + g.cond_info if g.cond_info is not None else torch.ones((g.num_graphs, 1), device=g.x.device) + ) + cat.logits = [lg * sc[b] for lg, b in zip(cat.raw_logits, cat.batch)] # Setting .logits masks them + return cat - def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): - node_embeddings, graph_embeddings = self.transf(g, cond) + def forward(self, g: gd.Batch): + node_embeddings, graph_embeddings = self.transf(g) # "Non-edges" are edges not currently in the graph that we could add if hasattr(g, "non_edge_index"): ne_row, ne_col = g.non_edge_index diff --git a/src/gflownet/models/seq_transformer.py b/src/gflownet/models/seq_transformer.py index 54557922..8ecb8919 100644 --- a/src/gflownet/models/seq_transformer.py +++ b/src/gflownet/models/seq_transformer.py @@ -65,7 +65,7 @@ def logZ(self, cond_info: Optional[torch.Tensor]): return self._logZ(torch.ones((1, 1), device=self._logZ.weight.device)) return self._logZ(cond_info) - def forward(self, xs: SeqBatch, cond, batched=False): + def forward(self, xs: SeqBatch, batched=False): """Returns a GraphActionCategorical and a tensor of state predictions. Parameters @@ -83,6 +83,7 @@ def forward(self, xs: SeqBatch, cond, batched=False): x = self.encoder(x, src_key_padding_mask=xs.mask, mask=generate_square_subsequent_mask(x.shape[0]).to(x.device)) pooled_x = x[xs.lens - 1, torch.arange(x.shape[1])] # (batch, nemb) + cond = xs.cond_info if self.use_cond: cond_var = self.cond_embed(cond) # (batch, nemb) cond_var = torch.tile(cond_var, (x.shape[0], 1, 1)) if batched else cond_var diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 103acc95..c8bebba0 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -1,6 +1,6 @@ import copy -import os import pathlib +from typing import Any import git import torch @@ -9,6 +9,7 @@ from gflownet.algo.advantage_actor_critic import A2C from gflownet.algo.flow_matching import FlowMatching +from gflownet.algo.local_search_tb import LocalSearchTB from gflownet.algo.soft_q_learning import SoftQLearning from gflownet.algo.trajectory_balance import TrajectoryBalance from gflownet.data.replay_buffer import ReplayBuffer @@ -38,6 +39,8 @@ def setup_algo(self): algo = self.cfg.algo.method if algo == "TB": algo = TrajectoryBalance + elif algo == "LSTB": + algo = LocalSearchTB elif algo == "FM": algo = FlowMatching elif algo == "A2C": @@ -48,6 +51,9 @@ def setup_algo(self): raise ValueError(algo) self.algo = algo(self.env, self.ctx, self.cfg) + if self.algo.requires_task: + self.algo.set_task(self.task) + def setup_data(self): self.training_data = [] self.test_data = [] @@ -65,6 +71,14 @@ def _opt(self, params, lr=None, momentum=None): weight_decay=self.cfg.opt.weight_decay, eps=self.cfg.opt.adam_eps, ) + elif self.cfg.opt.opt == "adamW": + return torch.optim.AdamW( + params, + lr, + (momentum, 0.999), + weight_decay=self.cfg.opt.weight_decay, + eps=self.cfg.opt.adam_eps, + ) raise NotImplementedError(f"{self.cfg.opt.opt} is not implemented") @@ -82,12 +96,17 @@ def setup(self): else: Z_params = [] non_Z_params = list(self.model.parameters()) + self.opt = self._opt(non_Z_params) - self.opt_Z = self._opt(Z_params, self.cfg.algo.tb.Z_learning_rate, 0.9) self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) - self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( - self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) - ) + + if Z_params: + self.opt_Z = self._opt(Z_params, self.cfg.algo.tb.Z_learning_rate, 0.9) + self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( + self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) + ) + else: + self.opt_Z = None self.sampling_tau = self.cfg.algo.sampling_tau if self.sampling_tau > 0: @@ -112,26 +131,33 @@ def setup(self): if self.print_config: print("\n\nHyperparameters:\n") print(yaml_cfg) - os.makedirs(self.cfg.log_dir, exist_ok=True) - with open(pathlib.Path(self.cfg.log_dir) / "config.yaml", "w", encoding="utf8") as f: - f.write(yaml_cfg) + if self.cfg.log_dir is not None and self.rank == 0: + with open(pathlib.Path(self.cfg.log_dir) / "config.yaml", "w", encoding="utf8") as f: + f.write(yaml_cfg) - def step(self, loss: Tensor): + def step(self, loss: Tensor, train_it: int): loss.backward() - with torch.no_grad(): - g0 = model_grad_norm(self.model) - self.clip_grad_callback(self.model.parameters()) - g1 = model_grad_norm(self.model) + info: dict[str, Any] = {} + if train_it % self.cfg.algo.grad_acc_steps != 0: + return info + if self.cfg.opt.clip_grad_type is not None: + with torch.no_grad(): + g0 = model_grad_norm(self.model) + self.clip_grad_callback(self.model.parameters()) + g1 = model_grad_norm(self.model) + info["grad_norm"] = g0.item() + info["grad_norm_clip"] = g1.item() self.opt.step() self.opt.zero_grad() - self.opt_Z.step() - self.opt_Z.zero_grad() self.lr_sched.step() - self.lr_sched_Z.step() + if self.opt_Z is not None: + self.opt_Z.step() + self.opt_Z.zero_grad() + self.lr_sched_Z.step() if self.sampling_tau > 0: for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) - return {"grad_norm": g0, "grad_norm_clip": g1} + return info class AvgRewardHook: diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index 8f2b30c9..520cde3e 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -7,6 +7,7 @@ @dataclass class SEHTaskConfig(StrictDataClass): reduced_frag: bool = False + qed_mul: bool = False @dataclass diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 2adca4f4..de4b72f0 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -129,6 +129,7 @@ class SEHFragTrainer(StandardOnlineTrainer): def set_default_hps(self, cfg: Config): cfg.hostname = socket.gethostname() cfg.pickle_mp_messages = False + cfg.mp_buffer_size = 32 * 1024**2 # 32 MB cfg.num_workers = 8 cfg.opt.learning_rate = 1e-4 cfg.opt.weight_decay = 1e-8 diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 386c0494..3321c32e 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -4,10 +4,11 @@ import pathlib import shutil import time -from typing import Any, Callable, Dict, List, Optional, Protocol +from typing import Any, Callable, Dict, List, Optional, Protocol, Union import numpy as np import torch +import torch.distributed as dist import torch.nn as nn import torch.utils.tensorboard import torch_geometric.data as gd @@ -16,6 +17,7 @@ from rdkit import RDLogger from torch import Tensor from torch.utils.data import DataLoader, Dataset +from torch_geometric.data import Batch from gflownet import GFNAlgorithm, GFNTask from gflownet.data.data_source import DataSource @@ -23,7 +25,7 @@ from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import create_logger, set_main_process_device, set_worker_rng_seed -from gflownet.utils.multiprocessing_proxy import mp_object_wrapper +from gflownet.utils.multiprocessing_proxy import BufferUnpickler, mp_object_wrapper from gflownet.utils.sqlite_log import SQLiteLogHook from .config import Config @@ -71,7 +73,7 @@ def __init__(self, config: Config, print_config=True): ) # make sure the config is a Config object, and not the Config class itself self.cfg: Config = OmegaConf.merge(self.default_cfg, config) - self.device = torch.device(self.cfg.device) + self.device = torch.device(self.cfg.rank or self.cfg.device) set_main_process_device(self.device) # Print the loss every `self.print_every` iterations self.print_every = self.cfg.print_every @@ -81,6 +83,9 @@ def __init__(self, config: Config, print_config=True): # Will check if parameters are finite at every iteration (can be costly) self._validate_parameters = False + self.world_size = self.cfg.world_size + self.rank = self.cfg.rank + self.setup() def set_default_hps(self, base: Config): @@ -101,18 +106,19 @@ def setup_algo(self): def setup_data(self): pass - def step(self, loss: Tensor): + def step(self, loss: Tensor, train_it: int): raise NotImplementedError() def setup(self): - if os.path.exists(self.cfg.log_dir): - if self.cfg.overwrite_existing_exp: - shutil.rmtree(self.cfg.log_dir) - else: - raise ValueError( - f"Log dir {self.cfg.log_dir} already exists. Set overwrite_existing_exp=True to delete it." - ) - os.makedirs(self.cfg.log_dir) + if self.cfg.log_dir and self.rank == 0: + if os.path.exists(self.cfg.log_dir): + if self.cfg.overwrite_existing_exp: + shutil.rmtree(self.cfg.log_dir) + else: + raise ValueError( + f"Log dir {self.cfg.log_dir} already exists. Set overwrite_existing_exp=True to delete it." + ) + os.makedirs(self.cfg.log_dir) RDLogger.DisableLog("rdApp.*") set_worker_rng_seed(self.cfg.seed) @@ -122,6 +128,8 @@ def setup(self): self.setup_env_context() self.setup_algo() self.setup_model() + if self.cfg.load_model_state is not None: + self.load_state(self.cfg.load_model_state) def _wrap_for_mp(self, obj): """Wraps an object in a placeholder whose reference can be sent to a @@ -132,6 +140,7 @@ def _wrap_for_mp(self, obj): self.cfg.num_workers, cast_types=(gd.Batch, GraphActionCategorical, SeqBatch), pickle_messages=self.cfg.pickle_mp_messages, + sb_size=self.cfg.mp_buffer_size, ) self.to_terminate.append(wrapper.terminate) return wrapper.placeholder @@ -181,8 +190,6 @@ def build_training_data_loader(self) -> DataLoader: def build_validation_data_loader(self) -> DataLoader: model = self._wrap_for_mp(self.model) - # TODO: we're changing the default, make sure anything that is using test data is adjusted - src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) n_drawn = self.cfg.algo.valid_num_from_policy n_from_dataset = self.cfg.algo.valid_num_from_dataset @@ -194,7 +201,7 @@ def build_validation_data_loader(self) -> DataLoader: # TODO: might be better to change total steps to total trajectories drawn src.do_sample_model_n_times(model, n_drawn, num_total=self.cfg.num_validation_gen_steps * n_drawn) - if self.cfg.log_dir: + if self.cfg.log_dir and n_drawn > 0: src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "valid"), self.ctx)) for hook in self.valid_sampling_hooks: src.add_sampling_hook(hook) @@ -219,13 +226,16 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: tick = time.time() self.model.train() try: + self.model.lock.acquire() + loss = info = None loss, info = self.algo.compute_batch_losses(self.model, batch) if not torch.isfinite(loss): raise ValueError("loss is not finite") - step_info = self.step(loss) + step_info = self.step(loss, train_it) self.algo.step() # This also isn't used anywhere? if self._validate_parameters and not all([torch.isfinite(i).all() for i in self.model.parameters()]): raise ValueError("parameters are not finite") + self.model.lock.release() except ValueError as e: os.makedirs(self.cfg.log_dir, exist_ok=True) torch.save([self.model.state_dict(), batch, loss, info], open(self.cfg.log_dir + "/dump.pkl", "wb")) @@ -241,20 +251,48 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: def evaluate_batch(self, batch: gd.Batch, epoch_idx: int = 0, batch_idx: int = 0) -> Dict[str, Any]: tick = time.time() self.model.eval() - loss, info = self.algo.compute_batch_losses(self.model, batch) + with torch.no_grad(): + loss, info = self.algo.compute_batch_losses(self.model, batch) if hasattr(batch, "extra_info"): info.update(batch.extra_info) info["eval_time"] = time.time() - tick return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()} + def _maybe_resolve_shared_buffer( + self, batch: Union[Batch, SeqBatch, tuple, list], dl: DataLoader + ) -> Union[Batch, SeqBatch]: + if dl.dataset.mp_buffer_size and isinstance(batch, (tuple, list)): + batch, wid = batch + batch = BufferUnpickler(dl.dataset.result_buffer[wid], batch, self.device).load() + elif isinstance(batch, (Batch, SeqBatch)): + batch = batch.to(self.device) + return batch + + def _send_models_to_device(self): + self.model.to(self.device) + self.sampling_model.to(self.device) + if self.world_size > 1: + self.model = nn.parallel.DistributedDataParallel( + self.model.to(self.rank), device_ids=[self.rank], output_device=self.rank + ) + if self.sampling_model is not self.model: + self.sampling_model = nn.parallel.DistributedDataParallel( + self.sampling_model.to(self.rank), device_ids=[self.rank], output_device=self.rank + ) + def run(self, logger=None): """Trains the GFN for `num_training_steps` minibatches, performing validation every `validate_every` minibatches. """ if logger is None: - logger = create_logger(logfile=self.cfg.log_dir + "/train.log") + logger = create_logger(logfile=self.cfg.log_dir + "/train.log" if self.cfg.log_dir else None) self.model.to(self.device) self.sampling_model.to(self.device) + import threading + + self.model.lock = ( + threading.Lock() + ) # This is created here because you can't pickle a lock, and model is deepcopied -> sampling_model epoch_length = max(len(self.training_data), 1) valid_freq = self.cfg.validate_every # If checkpoint_every is not specified, checkpoint at every validation epoch @@ -275,6 +313,7 @@ def run(self, logger=None): if it % 1024 == 0: gc.collect() torch.cuda.empty_cache() + batch = self._maybe_resolve_shared_buffer(batch, train_dl) epoch_idx = it // epoch_length batch_idx = it % epoch_length if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup: @@ -282,18 +321,21 @@ def run(self, logger=None): f"iteration {it} : warming up replay buffer {len(self.replay_buffer)}/{self.replay_buffer.warmup}" ) continue - info = self.train_batch(batch.to(self.device), epoch_idx, batch_idx, it) + + info = self.train_batch(batch, epoch_idx, batch_idx, it) info["time_spent"] = time.time() - start_time start_time = time.time() - self.log(info, it, "train") if it % self.print_every == 0: logger.info(f"iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) + self.log(info, it, "train") if valid_freq > 0 and it % valid_freq == 0: + logger.info("Starting validation epoch") for batch in valid_dl: + batch = self._maybe_resolve_shared_buffer(batch, valid_dl) info = self.evaluate_batch(batch.to(self.device), epoch_idx, batch_idx) - self.log(info, it, "valid") logger.info(f"validation - iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) + self.log(info, it, "valid") end_metrics = {} for c in callbacks.values(): if hasattr(c, "on_validation_end"): @@ -311,6 +353,7 @@ def run(self, logger=None): range(num_training_steps + 1, num_training_steps + num_final_gen_steps + 1), cycle(final_dl), ): + batch = self._maybe_resolve_shared_buffer(batch, final_dl) if hasattr(batch, "extra_info"): for k, v in batch.extra_info.items(): if k not in final_info: @@ -332,7 +375,7 @@ def run(self, logger=None): del final_dl def terminate(self): - logger = logging.getLogger("logger") + logger = logging.getLogger("gflownet") for handler in logger.handlers: handler.close() @@ -344,6 +387,8 @@ def terminate(self): terminate() def _save_state(self, it): + if self.rank != 0 or self.cfg.log_dir is None: + return state = { "models_state_dict": [self.model.state_dict()], "cfg": self.cfg, @@ -360,7 +405,22 @@ def _save_state(self, it): if self.cfg.store_all_checkpoints: shutil.copy(fn, pathlib.Path(self.cfg.log_dir) / f"model_state_{it}.pt") + def load_state(self, path): + state = torch.load(path) + self.model.load_state_dict(state["models_state_dict"][0]) + def log(self, info, index, key): + # First check if we need to reduce the info across processes + if self.world_size > 1: + all_info_vals = torch.zeros(len(info)).to(self.rank) + for i, k in enumerate(sorted(info.keys())): + all_info_vals[i] = info[k] + dist.all_reduce(all_info_vals, op=dist.ReduceOp.SUM) + for i, k in enumerate(sorted(info.keys())): + info[k] = all_info_vals[i].item() / self.world_size + if self.rank != 0: # Only the master process logs + return + if not hasattr(self, "_summary_writer"): self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.cfg.log_dir) for k, v in info.items(): diff --git a/src/gflownet/utils/misc.py b/src/gflownet/utils/misc.py index f9d32df3..8108c090 100644 --- a/src/gflownet/utils/misc.py +++ b/src/gflownet/utils/misc.py @@ -3,9 +3,11 @@ import numpy as np import torch +import torch.distributed +import torch.utils.data -def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHandle=True): +def create_logger(name="gflownet", loglevel=logging.INFO, logfile=None, streamHandle=True): logger = logging.getLogger(name) logger.setLevel(loglevel) while len([logger.removeHandler(i) for i in logger.handlers]): @@ -33,11 +35,19 @@ def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHand _main_process_device = [torch.device("cpu")] -def get_worker_rng(): +def get_this_wid(): worker_info = torch.utils.data.get_worker_info() wid = worker_info.id if worker_info is not None else 0 + if torch.distributed.is_initialized(): + wid = torch.distributed.get_rank() * (worker_info.num_workers if worker_info is not None else 1) + wid + return wid + + +def get_worker_rng(): + wid = get_this_wid() if wid not in _worker_rngs: _worker_rngs[wid] = np.random.RandomState(_worker_rng_seed[0] + wid) + torch.manual_seed(_worker_rng_seed[0] + wid) return _worker_rngs[wid] diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index df13b565..13559514 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -1,22 +1,146 @@ +import io import pickle import queue import threading import traceback +from pickle import Pickler, Unpickler, UnpicklingError import torch import torch.multiprocessing as mp +DO_PIN_BUFFERS = False # So actually, we don't really need to pin buffers, but it's a good idea in some cases. +# The shared memory is already a good step. + + +class SharedPinnedBuffer: + def __init__(self, size): + self.size = size + self.buffer = torch.empty(size, dtype=torch.uint8) + self.buffer.share_memory_() + self.lock = mp.Lock() + self.do_unreg = False + + if DO_PIN_BUFFERS and not self.buffer.is_pinned(): + # Sometimes torch will create an already pinned (page aligned) buffer, so we don't need to + # pin it again; doing so will raise a CUDA error + cudart = torch.cuda.cudart() + r = cudart.cudaHostRegister(self.buffer.data_ptr(), self.buffer.numel() * self.buffer.element_size(), 0) + assert r == 0 + self.do_unreg = True # But then we need to unregister it later + assert self.buffer.is_pinned() + assert self.buffer.is_shared() + + def __del__(self): + if torch.utils.data.get_worker_info() is None: + if self.do_unreg: + cudart = torch.cuda.cudart() + r = cudart.cudaHostUnregister(self.buffer.data_ptr()) + assert r == 0 + + +class _BufferPicklerSentinel: + pass + + +class BufferPickler(Pickler): + def __init__(self, buf: SharedPinnedBuffer): + self._f = io.BytesIO() + super().__init__(self._f) + self.buf = buf + # The lock will be released by the consumer (BufferUnpickler) of this buffer once + # the memory has been transferred to the device and copied + self.buf.lock.acquire() + self.buf_offset = 0 + + def persistent_id(self, v): + if not isinstance(v, torch.Tensor): + return None + numel = v.numel() * v.element_size() + if self.buf_offset + numel > self.buf.size: + raise RuntimeError( + f"Tried to allocate {self.buf_offset + numel} bytes in a buffer of size {self.buf.size}. " + "Consider increasing cfg.mp_buffer_size" + ) + start = self.buf_offset + shape = tuple(v.shape) + if v.ndim > 0 and v.stride(-1) != 1 or not v.is_contiguous(): + v = v.contiguous().reshape(-1) + if v.ndim > 0 and v.stride(-1) != 1: + # We're still not contiguous, this unfortunately happens occasionally, e.g.: + # x = torch.arange(10).reshape((10, 1)) + # y = x.T[::2].T + # y.stride(), y.is_contiguous(), y.contiguous().stride() + # -> (1, 2), True, (1, 2) + v = v.flatten() + 0 + # I don't know if this comes from my misunderstanding of strides or if it's a bug in torch + # but either way torch will refuse to view this tensor as a uint8 tensor, so we have to + 0 + # to force torch to materialize it into a new tensor (it may otherwise be lazy and not materialize) + if numel > 0: + self.buf.buffer[start : start + numel] = v.flatten().view(torch.uint8) + self.buf_offset += numel + self.buf_offset += (8 - self.buf_offset % 8) % 8 # align to 8 bytes + return (_BufferPicklerSentinel, (start, shape, v.dtype)) + + def dumps(self, obj): + self.dump(obj) + return (self._f.getvalue(), self.buf_offset) + + +class BufferUnpickler(Unpickler): + def __init__(self, buf: SharedPinnedBuffer, data, device): + self._f, total_size = io.BytesIO(data[0]), data[1] + super().__init__(self._f) + self.buf = buf + self.target_buf = buf.buffer[:total_size].to(device) + 0 + # Why the `+ 0`? Unfortunately, we have no way to know exactly when the consumer of the object we're + # unpickling will be done using the buffer underlying the tensor, so we have to create a copy. + # If we don't and another consumer starts using the buffer, and this consumer transfers this pinned + # buffer to the GPU, the first consumer's tensors will be corrupted, because (depending on the CUDA + # memory manager) the pinned buffer will transfer to the same GPU location. + # Hopefully, especially if the target device is the GPU, the copy will be fast and/or async. + # Note that this could be fixed by using one buffer for each worker, but that would be significantly + # more memory usage. + + def load_tensor(self, offset, shape, dtype): + numel = prod(shape) * dtype.itemsize + tensor: torch.Tensor = self.target_buf[offset : offset + numel].view(dtype).view(shape) + return tensor + + def persistent_load(self, pid): + if isinstance(pid, tuple): + sentinel, (offset, shape, dtype) = pid + if sentinel is _BufferPicklerSentinel: + return self.load_tensor(offset, shape, dtype) + return UnpicklingError("Invalid persistent id") + + def load(self): + r = super().load() + # We're done with this buffer, release it for the next consumer + self.buf.lock.release() + return r + + +def prod(ns): + p = 1 + for i in ns: + p *= i + return p + class MPObjectPlaceholder: """This class can be used for example as a model or dataset placeholder in a worker process, and translates calls to the object-placeholder into queries for the main process to execute on the real object.""" - def __init__(self, in_queues, out_queues, pickle_messages=False): + def __init__(self, in_queues, out_queues, pickle_messages=False, shared_buffer_size=None): self.qs = in_queues, out_queues self.device = torch.device("cpu") self.pickle_messages = pickle_messages self._is_init = False + self.shared_buffer_size = shared_buffer_size + if shared_buffer_size: + self._buffer_to_main = SharedPinnedBuffer(shared_buffer_size) + self._buffer_from_main = SharedPinnedBuffer(shared_buffer_size) def _check_init(self): if self._is_init: @@ -31,11 +155,15 @@ def _check_init(self): self._is_init = True def encode(self, m): + if self.shared_buffer_size: + return BufferPickler(self._buffer_to_main).dumps(m) if self.pickle_messages: return pickle.dumps(m) return m def decode(self, m): + if self.shared_buffer_size: + m = BufferUnpickler(self._buffer_from_main, m, self.device).load() if self.pickle_messages: m = pickle.loads(m) if isinstance(m, Exception): @@ -75,7 +203,7 @@ class MPObjectProxy: Always passes CPU tensors between processes. """ - def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False): + def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False, sb_size=None): """Construct a multiprocessing object proxy. Parameters @@ -91,11 +219,14 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo If True, pickle messages sent between processes. This reduces load on shared memory, but increases load on CPU. It is recommended to activate this flag if encountering "Too many open files"-type errors. + sb_size: Optional[int] + shared buffer size """ self.in_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.out_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.pickle_messages = pickle_messages - self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages) + self.use_shared_buffer = bool(sb_size) + self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages, sb_size) self.obj = obj if hasattr(obj, "parameters"): self.device = next(obj.parameters()).device @@ -107,11 +238,16 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo self.thread.start() def encode(self, m): + if self.use_shared_buffer: + return BufferPickler(self.placeholder._buffer_from_main).dumps(m) if self.pickle_messages: return pickle.dumps(m) return m def decode(self, m): + if self.use_shared_buffer: + return BufferUnpickler(self.placeholder._buffer_to_main, m, self.device).load() + if self.pickle_messages: return pickle.loads(m) return m @@ -121,8 +257,7 @@ def to_cpu(self, i): def run(self): timeouts = 0 - - while not self.stop.is_set() or timeouts < 500: + while not self.stop.is_set() and timeouts < 5 / 1e-5: for qi, q in enumerate(self.in_queues): try: r = self.decode(q.get(True, 1e-5)) @@ -133,6 +268,8 @@ def run(self): break timeouts = 0 attr, args, kwargs = r + if hasattr(self.obj, "lock"): # TODO: this is not used anywhere? + self.obj.lock.acquire() f = getattr(self.obj, attr) args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args] kwargs = {k: i.to(self.device) if isinstance(i, self.cuda_types) else i for k, i in kwargs.items()} @@ -143,6 +280,7 @@ def run(self): except Exception as e: result = e exc_str = traceback.format_exc() + print(exc_str) try: pickle.dumps(e) except Exception: @@ -154,39 +292,37 @@ def run(self): else: msg = self.to_cpu(result) self.out_queues[qi].put(self.encode(msg)) + if hasattr(self.obj, "lock"): + self.obj.lock.release() def terminate(self): self.stop.set() -def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False): +def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False, sb_size=None): """Construct a multiprocessing object proxy for torch DataLoaders so that it does not need to be copied in every worker's memory. For example, this can be used to wrap a model such that only the main process makes cuda calls by forwarding data through the model, or a replay buffer such that the new data is pushed in from the worker processes but only the main process has to hold the full buffer in memory. - self.out_queues[qi].put(self.encode(msg)) - elif isinstance(result, dict): - msg = {k: self.to_cpu(i) for k, i in result.items()} - self.out_queues[qi].put(self.encode(msg)) - else: - msg = self.to_cpu(result) - self.out_queues[qi].put(self.encode(msg)) Parameters ---------- obj: any python object to be proxied (typically a torch.nn.Module or ReplayBuffer) - Lives in the main process to which method calls are passed + Lives in the main process to which method calls are passed num_workers: int Number of DataLoader workers cast_types: tuple Types that will be cast to cuda when received as arguments of method calls. torch.Tensor is cast by default. pickle_messages: bool - If True, pickle messages sent between processes. This reduces load on shared - memory, but increases load on CPU. It is recommended to activate this flag if - encountering "Too many open files"-type errors. + If True, pickle messages sent between processes. This reduces load on shared + memory, but increases load on CPU. It is recommended to activate this flag if + encountering "Too many open files"-type errors. + sb_size: Optional[int] + If not None, creates a shared buffer of this size for sending tensors between processes. + Note, this will allocate two buffers of this size (one for sending, the other for receiving). Returns ------- @@ -194,4 +330,4 @@ def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = Fals A placeholder object whose method calls route arguments to the main process """ - return MPObjectProxy(obj, num_workers, cast_types, pickle_messages) + return MPObjectProxy(obj, num_workers, cast_types, pickle_messages, sb_size=sb_size) diff --git a/src/gflownet/utils/sqlite_log.py b/src/gflownet/utils/sqlite_log.py index ae544ec5..7bf99cb3 100644 --- a/src/gflownet/utils/sqlite_log.py +++ b/src/gflownet/utils/sqlite_log.py @@ -3,6 +3,10 @@ from typing import Iterable import torch +import torch.distributed +import torch.utils.data + +from gflownet.utils.misc import get_this_wid class SQLiteLogHook: @@ -14,8 +18,7 @@ def __init__(self, log_dir, ctx) -> None: def __call__(self, trajs, rewards, obj_props, cond_info): if self.log is None: - worker_info = torch.utils.data.get_worker_info() - self._wid = worker_info.id if worker_info is not None else 0 + self._wid = get_this_wid() os.makedirs(self.log_dir, exist_ok=True) self.log_path = f"{self.log_dir}/generated_objs_{self._wid}.db" self.log = SQLiteLog() diff --git a/tests/test_graph_building_env.py b/tests/test_graph_building_env.py index e9184cbd..adf41120 100644 --- a/tests/test_graph_building_env.py +++ b/tests/test_graph_building_env.py @@ -123,4 +123,24 @@ def test_log_prob(): def test_entropy(): cat = make_test_cat() - cat.entropy() + entropy = cat.entropy() + assert torch.isfinite(entropy).all() and entropy.shape == (3,) and (entropy > 0).all() + + cat.action_masks = [ + torch.tensor([[0], [1], [1.0]]), + ((torch.arange(cat.logits[1].numel()) % 2) == 0).float().reshape(cat.logits[1].shape), + torch.tensor([[1, 0, 1], [0, 1, 1.0]]), + ] + entropy = cat.entropy() + assert torch.isfinite(entropy).all() and entropy.shape == (3,) and (entropy > 0).all() + + +def test_entropy_grad(): + # Purposefully large values to test extremal behaviors + logits = torch.tensor([[100, 101, -102, 95, 10, 20, 72]]).float() + logits.requires_grad_(True) + batch = Batch.from_data_list([Data(x=torch.ones((1, 10)), y=torch.ones((2, 6)))], follow_batch=["y"]) + cat = GraphActionCategorical(batch, [logits[:, :3], logits[:, 3:].reshape(2, 2)], [None, "y"], [None, None]) + cat._epsilon = 0 + (grad_gac,) = torch.autograd.grad(cat.entropy(), logits, retain_graph=True) + assert torch.isfinite(grad_gac).all()