Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions sparsediffpy/_bindings/atoms/diag_mat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef ATOM_DIAG_MAT_H
#define ATOM_DIAG_MAT_H

#include "common.h"

static PyObject *py_make_diag_mat(PyObject *self, PyObject *args)
{
PyObject *child_capsule;

if (!PyArg_ParseTuple(args, "O", &child_capsule))
{
return NULL;
}

expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
if (!child)
{
return NULL;
}

expr *node = new_diag_mat(child);
if (!node)
{
PyErr_SetString(PyExc_RuntimeError, "failed to create diag_mat node");
return NULL;
}

expr_retain(node);
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
}

#endif /* ATOM_DIAG_MAT_H */
72 changes: 72 additions & 0 deletions sparsediffpy/_bindings/atoms/kron_left.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#ifndef ATOM_KRON_LEFT_H
#define ATOM_KRON_LEFT_H

#include "bivariate.h"
#include "common.h"

/*
* Python signature:
* make_kron_left(child_capsule, C_data, C_indices, C_indptr, m, n, p, q)
*
* Creates kron(C, X) where C is (m x n) constant CSR matrix and X is the
* child expression of shape (p x q).
*/
static PyObject *py_make_kron_left(PyObject *self, PyObject *args)
{
(void) self;
PyObject *child_capsule;
PyObject *data_obj, *indices_obj, *indptr_obj;
int m, n, p, q;

if (!PyArg_ParseTuple(args, "OOOOiiii", &child_capsule, &data_obj,
&indices_obj, &indptr_obj, &m, &n, &p, &q))
{
return NULL;
}

expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
if (!child)
{
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
return NULL;
}

PyArrayObject *data_array =
(PyArrayObject *) PyArray_FROM_OTF(data_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
PyArrayObject *indices_array =
(PyArrayObject *) PyArray_FROM_OTF(indices_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY);
PyArrayObject *indptr_array =
(PyArrayObject *) PyArray_FROM_OTF(indptr_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY);

if (!data_array || !indices_array || !indptr_array)
{
Py_XDECREF(data_array);
Py_XDECREF(indices_array);
Py_XDECREF(indptr_array);
return NULL;
}

int nnz = (int) PyArray_SIZE(data_array);
CSR_Matrix *C = new_csr_matrix(m, n, nnz);
memcpy(C->x, PyArray_DATA(data_array), (size_t) nnz * sizeof(double));
memcpy(C->i, PyArray_DATA(indices_array), (size_t) nnz * sizeof(int));
memcpy(C->p, PyArray_DATA(indptr_array), (size_t)(m + 1) * sizeof(int));

Py_DECREF(data_array);
Py_DECREF(indices_array);
Py_DECREF(indptr_array);

expr *node = new_kron_left(child, C, p, q);
free_csr_matrix(C);

if (!node)
{
PyErr_SetString(PyExc_RuntimeError, "failed to create kron_left node");
return NULL;
}

expr_retain(node);
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
}

#endif /* ATOM_KRON_LEFT_H */
32 changes: 32 additions & 0 deletions sparsediffpy/_bindings/atoms/upper_tri.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef ATOM_UPPER_TRI_H
#define ATOM_UPPER_TRI_H

#include "common.h"

static PyObject *py_make_upper_tri(PyObject *self, PyObject *args)
{
PyObject *child_capsule;

if (!PyArg_ParseTuple(args, "O", &child_capsule))
{
return NULL;
}

expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
if (!child)
{
return NULL;
}

expr *node = new_upper_tri(child);
if (!node)
{
PyErr_SetString(PyExc_RuntimeError, "failed to create upper_tri node");
return NULL;
}

expr_retain(node);
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
}

#endif /* ATOM_UPPER_TRI_H */
51 changes: 51 additions & 0 deletions sparsediffpy/_bindings/atoms/vstack.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#ifndef ATOM_VSTACK_H
#define ATOM_VSTACK_H

#include "common.h"

static PyObject *py_make_vstack(PyObject *self, PyObject *args)
{
(void) self;
PyObject *list_obj;
if (!PyArg_ParseTuple(args, "O", &list_obj))
{
return NULL;
}
if (!PyList_Check(list_obj))
{
PyErr_SetString(PyExc_TypeError,
"First argument must be a list of expr capsules");
return NULL;
}
Py_ssize_t n_args = PyList_Size(list_obj);
if (n_args == 0)
{
PyErr_SetString(PyExc_ValueError, "List of expr capsules cannot be empty");
return NULL;
}
expr **expr_args = (expr **) calloc(n_args, sizeof(expr *));
for (Py_ssize_t i = 0; i < n_args; ++i)
{
PyObject *item = PyList_GetItem(list_obj, i);
expr *e = (expr *) PyCapsule_GetPointer(item, EXPR_CAPSULE_NAME);
if (!e)
{
free(expr_args);
PyErr_SetString(PyExc_ValueError, "Invalid expr capsule in list");
return NULL;
}
expr_args[i] = e;
}
int n_vars = expr_args[0]->n_vars;
expr *node = new_vstack(expr_args, (int) n_args, n_vars);
free(expr_args);
if (!node)
{
PyErr_SetString(PyExc_RuntimeError, "failed to create vstack node");
return NULL;
}
expr_retain(node);
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
}

#endif // ATOM_VSTACK_H
10 changes: 10 additions & 0 deletions sparsediffpy/_bindings/bindings.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
#include "atoms/constant.h"
#include "atoms/cos.h"
#include "atoms/dense_matmul.h"
#include "atoms/diag_mat.h"
#include "atoms/diag_vec.h"
#include "atoms/entr.h"
#include "atoms/exp.h"
#include "atoms/getters.h"
#include "atoms/hstack.h"
#include "atoms/index.h"
#include "atoms/kron_left.h"
#include "atoms/left_matmul.h"
#include "atoms/linear.h"
#include "atoms/log.h"
Expand Down Expand Up @@ -45,7 +47,9 @@
#include "atoms/tanh.h"
#include "atoms/trace.h"
#include "atoms/transpose.h"
#include "atoms/upper_tri.h"
#include "atoms/variable.h"
#include "atoms/vstack.h"
#include "atoms/xexp.h"

/* Include problem bindings */
Expand Down Expand Up @@ -82,6 +86,8 @@ static PyMethodDef DNLPMethods[] = {
{"make_hstack", py_make_hstack, METH_VARARGS,
"Create hstack node from list of expr capsules and n_vars (make_hstack([e1, "
"e2, ...], n_vars))"},
{"make_vstack", py_make_vstack, METH_VARARGS,
"Create vstack node from list of expr capsules (make_vstack([e1, e2, ...]))"},
{"make_sum", py_make_sum, METH_VARARGS, "Create sum node"},
{"make_neg", py_make_neg, METH_VARARGS, "Create neg node"},
{"make_normal_cdf", py_make_normal_cdf, METH_VARARGS, "Create normal_cdf node"},
Expand All @@ -102,16 +108,20 @@ static PyMethodDef DNLPMethods[] = {
"Create prod_axis_one node"},
{"make_sin", py_make_sin, METH_VARARGS, "Create sin node"},
{"make_cos", py_make_cos, METH_VARARGS, "Create cos node"},
{"make_diag_mat", py_make_diag_mat, METH_VARARGS, "Create diag_mat node"},
{"make_diag_vec", py_make_diag_vec, METH_VARARGS, "Create diag_vec node"},
{"make_tan", py_make_tan, METH_VARARGS, "Create tan node"},
{"make_sinh", py_make_sinh, METH_VARARGS, "Create sinh node"},
{"make_tanh", py_make_tanh, METH_VARARGS, "Create tanh node"},
{"make_asinh", py_make_asinh, METH_VARARGS, "Create asinh node"},
{"make_atanh", py_make_atanh, METH_VARARGS, "Create atanh node"},
{"make_upper_tri", py_make_upper_tri, METH_VARARGS, "Create upper_tri node"},
{"make_broadcast", py_make_broadcast, METH_VARARGS, "Create broadcast node"},
{"make_entr", py_make_entr, METH_VARARGS, "Create entr node"},
{"make_logistic", py_make_logistic, METH_VARARGS, "Create logistic node"},
{"make_xexp", py_make_xexp, METH_VARARGS, "Create xexp node"},
{"make_kron_left", py_make_kron_left, METH_VARARGS,
"Create kron(C, X) node where C is constant sparse matrix"},
{"make_sparse_left_matmul", py_make_sparse_left_matmul, METH_VARARGS,
"Create sparse left matmul node (A @ f(x))"},
{"make_dense_left_matmul", py_make_dense_left_matmul, METH_VARARGS,
Expand Down
Loading