Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions include/affine.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ expr *new_vstack(expr **args, int n_args, int n_vars);
expr *new_promote(expr *child, int d1, int d2);
expr *new_trace(expr *child);

expr *new_constant(int d1, int d2, int n_vars, const double *values);
expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values);
expr *new_variable(int d1, int d2, int var_id, int n_vars);

expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs);
Expand All @@ -42,26 +42,27 @@ expr *new_broadcast(expr *child, int target_d1, int target_d2);
expr *new_diag_vec(expr *child);
expr *new_transpose(expr *child);

/* Left matrix multiplication: A @ f(x) where A is a constant sparse
* matrix */
expr *new_left_matmul(expr *u, const CSR_Matrix *A);
/* Left matrix multiplication: A @ f(x) where A is a constant or parameter
* sparse matrix. param_node is NULL for fixed constants. */
expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A);

/* Left matrix multiplication: A @ f(x) where A is a constant dense
* matrix (row-major, m x n). Uses CBLAS for efficient computation. */
expr *new_left_matmul_dense(expr *u, int m, int n, const double *data);
/* Left matrix multiplication: A @ f(x) where A is a constant or parameter
* dense matrix (row-major, m x n). Uses CBLAS for efficient computation. */
expr *new_left_matmul_dense(expr *param_node, expr *u, int m, int n,
const double *data);

/* Right matrix multiplication: f(x) @ A where A is a constant
* matrix */
expr *new_right_matmul(expr *u, const CSR_Matrix *A);
/* Right matrix multiplication: f(x) @ A where A is a constant or parameter
* matrix. */
expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A);

expr *new_right_matmul_dense(expr *u, int m, int n, const double *data);
expr *new_right_matmul_dense(expr *param_node, expr *u, int m, int n,
const double *data);

/* Constant scalar multiplication: a * f(x) where a is a constant
* double */
expr *new_const_scalar_mult(double a, expr *child);
/* Scalar multiplication: a * f(x) where a comes from param_node */
expr *new_scalar_mult(expr *param_node, expr *child);

/* Constant vector elementwise multiplication: a . f(x) where a is
* constant */
expr *new_const_vector_mult(const double *a, expr *child);
/* Vector elementwise multiplication: a . f(x) where a comes from
* param_node */
expr *new_vector_mult(expr *param_node, expr *child);

#endif /* AFFINE_H */
8 changes: 8 additions & 0 deletions include/problem.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ typedef struct problem
int n_vars;
int total_constraint_size;

/* parameter support */
expr **param_nodes;
int n_param_nodes;
int total_parameter_size;

/* allocated by new_problem */
double *constraint_values;
double *gradient_values;
Expand Down Expand Up @@ -76,6 +81,9 @@ void problem_init_jacobian_coo(problem *prob);
void problem_init_hessian_coo_lower_triangular(problem *prob);
void free_problem(problem *prob);

void problem_register_params(problem *prob, expr **param_nodes, int n_param_nodes);
void problem_update_params(problem *prob, const double *theta);

double problem_objective_forward(problem *prob, const double *u);
void problem_constraint_forward(problem *prob, const double *u);
void problem_gradient(problem *prob);
Expand Down
42 changes: 23 additions & 19 deletions include/subexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,20 @@
/* Forward declaration */
struct int_double_pair;

/* Parameter ID for fixed constants (not updatable) */
#define PARAM_FIXED -1

/* Type-specific expression structures that "inherit" from expr */

/* Unified constant/parameter node. Constants use param_id == PARAM_FIXED.
* Updatable parameters use param_id >= 0 (offset into global theta). */
typedef struct parameter_expr
{
expr base;
int param_id;
bool has_been_refreshed;
} parameter_expr;

/* Linear operator: y = A * x + b */
typedef struct linear_op_expr
{
Expand Down Expand Up @@ -114,32 +126,24 @@ typedef struct left_matmul_expr
CSC_Matrix *Jchild_CSC;
CSC_Matrix *J_CSC;
int *csc_to_csr_work;
expr *param_source;
void (*refresh_param_values)(struct left_matmul_expr *);
} left_matmul_expr;

/* Right matrix multiplication: y = f(x) * A where f(x) is an expression.
* f(x) has shape p x n, A has shape n x q, output y has shape p x q.
* Uses vec(y) = B * vec(f(x)) where B = A^T kron I_p. */
typedef struct right_matmul_expr
{
expr base;
CSR_Matrix *B; /* B = A^T kron I_p */
CSR_Matrix *BT; /* B^T for backpropagating Hessian weights */
CSC_Matrix *CSC_work;
} right_matmul_expr;

/* Constant scalar multiplication: y = a * child where a is a constant double */
typedef struct const_scalar_mult_expr
/* Scalar multiplication: y = a * child where a comes from param_source */
typedef struct scalar_mult_expr
{
expr base;
double a;
} const_scalar_mult_expr;
expr *param_source;
} scalar_mult_expr;

/* Constant vector elementwise multiplication: y = a \circ child for constant a */
typedef struct const_vector_mult_expr
/* Vector elementwise multiplication: y = a \circ child where a comes from
* param_source */
typedef struct vector_mult_expr
{
expr base;
double *a; /* length equals node->size */
} const_vector_mult_expr;
expr *param_source;
} vector_mult_expr;

/* Index/slicing: y = child[indices] where indices is a list of flat positions */
typedef struct index_expr
Expand Down
1 change: 1 addition & 0 deletions include/utils/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ typedef struct Matrix
const CSC_Matrix *J, int p);
void (*block_left_mult_values)(const struct Matrix *self, const CSC_Matrix *J,
CSC_Matrix *C);
void (*update_values)(struct Matrix *self, const double *new_values);
void (*free_fn)(struct Matrix *self);
} Matrix;

Expand Down
82 changes: 75 additions & 7 deletions src/affine/left_matmul.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,34 @@

#include "utils/utils.h"

static void refresh_param_values(left_matmul_expr *lnode)
{
if (lnode->param_source == NULL)
{
return;
}
parameter_expr *param = (parameter_expr *) lnode->param_source;
if (param->has_been_refreshed)
{
return;
}
param->has_been_refreshed = true;
lnode->refresh_param_values(lnode);
}

static void forward(expr *node, const double *u)
{
left_matmul_expr *lnode = (left_matmul_expr *) node;
refresh_param_values(lnode);

expr *x = node->left;

/* child's forward pass */
node->left->forward(node->left, u);

/* y = A_kron @ vec(f(x)) */
Matrix *A = ((left_matmul_expr *) node)->A;
int n_blocks = ((left_matmul_expr *) node)->n_blocks;
Matrix *A = lnode->A;
int n_blocks = lnode->n_blocks;
A->block_left_mult_vec(A, x->value, node->value, n_blocks);
}

Expand All @@ -74,11 +92,16 @@ static void free_type_data(expr *node)
free_csc_matrix(lnode->Jchild_CSC);
free_csc_matrix(lnode->J_CSC);
free(lnode->csc_to_csr_work);
if (lnode->param_source != NULL)
{
free_expr(lnode->param_source);
}
lnode->A = NULL;
lnode->AT = NULL;
lnode->Jchild_CSC = NULL;
lnode->J_CSC = NULL;
lnode->csc_to_csr_work = NULL;
lnode->param_source = NULL;
}

static void jacobian_init_impl(expr *node)
Expand All @@ -98,8 +121,8 @@ static void jacobian_init_impl(expr *node)

static void eval_jacobian(expr *node)
{
expr *x = node->left;
left_matmul_expr *lnode = (left_matmul_expr *) node;
expr *x = node->left;

CSC_Matrix *Jchild_CSC = lnode->Jchild_CSC;
CSC_Matrix *J_CSC = lnode->J_CSC;
Expand Down Expand Up @@ -130,17 +153,45 @@ static void wsum_hess_init_impl(expr *node)

static void eval_wsum_hess(expr *node, const double *w)
{
left_matmul_expr *lnode = (left_matmul_expr *) node;

/* compute A^T w*/
Matrix *AT = ((left_matmul_expr *) node)->AT;
int n_blocks = ((left_matmul_expr *) node)->n_blocks;
Matrix *AT = lnode->AT;
int n_blocks = lnode->n_blocks;
AT->block_left_mult_vec(AT, w, node->work->dwork, n_blocks);

node->left->eval_wsum_hess(node->left, node->work->dwork);
memcpy(node->wsum_hess->x, node->left->wsum_hess->x,
node->wsum_hess->nnz * sizeof(double));
}

expr *new_left_matmul(expr *u, const CSR_Matrix *A)
static void refresh_sparse_left(left_matmul_expr *lnode)
{
Sparse_Matrix *sm_A = (Sparse_Matrix *) lnode->A;
Sparse_Matrix *sm_AT = (Sparse_Matrix *) lnode->AT;
lnode->A->update_values(lnode->A, lnode->param_source->value);
/* Recompute AT values from A */
AT_fill_values(sm_A->csr, sm_AT->csr, lnode->base.work->iwork);
}

static void refresh_dense_left(left_matmul_expr *lnode)
{
Dense_Matrix *dm_A = (Dense_Matrix *) lnode->A;
int m = dm_A->base.m;
int n = dm_A->base.n;
lnode->A->update_values(lnode->A, lnode->param_source->value);
/* Recompute AT data (transpose of row-major A) */
Dense_Matrix *dm_AT = (Dense_Matrix *) lnode->AT;
for (int i = 0; i < m; i++)
{
for (int j = 0; j < n; j++)
{
dm_AT->x[j * m + i] = dm_A->x[i * n + j];
}
}
}

expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
{
/* We expect u->d1 == A->n. However, numpy's broadcasting rules allow users
to do A @ u where u is (n, ) which in C is actually (1, n). In that case
Expand Down Expand Up @@ -187,10 +238,19 @@ expr *new_left_matmul(expr *u, const CSR_Matrix *A)
lnode->AT =
sparse_matrix_trans((const Sparse_Matrix *) lnode->A, node->work->iwork);

/* parameter support */
lnode->param_source = param_node;
if (param_node != NULL)
{
expr_retain(param_node);
lnode->refresh_param_values = refresh_sparse_left;
}

return node;
}

expr *new_left_matmul_dense(expr *u, int m, int n, const double *data)
expr *new_left_matmul_dense(expr *param_node, expr *u, int m, int n,
const double *data)
{
int d1, d2, n_blocks;
if (u->d1 == n)
Expand Down Expand Up @@ -226,5 +286,13 @@ expr *new_left_matmul_dense(expr *u, int m, int n, const double *data)
lnode->A = new_dense_matrix(m, n, data);
lnode->AT = dense_matrix_trans((const Dense_Matrix *) lnode->A);

/* parameter support */
lnode->param_source = param_node;
if (param_node != NULL)
{
expr_retain(param_node);
lnode->refresh_param_values = refresh_dense_left;
}

return node;
}
24 changes: 15 additions & 9 deletions src/affine/constant.c → src/affine/parameter.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,36 @@
* limitations under the License.
*/
#include "affine.h"
#include "subexpr.h"
#include <stdlib.h>
#include <string.h>

static void forward(expr *node, const double *u)
{
/* Constants don't depend on u; values are already set */
/* Parameters/constants don't depend on u; values are already set */
(void) node;
(void) u;
}

static void jacobian_init_impl(expr *node)
{
/* Constant jacobian is all zeros: size x n_vars with 0 nonzeros.
* new_csr_matrix uses calloc for row pointers, so they're already 0. */
/* Zero jacobian: size x n_vars with 0 nonzeros. */
node->jacobian = new_csr_matrix(node->size, node->n_vars, 0);
}

static void eval_jacobian(expr *node)
{
/* Constant jacobian never changes - nothing to evaluate */
(void) node;
}

static void wsum_hess_init_impl(expr *node)
{
/* Constant Hessian is all zeros: n_vars x n_vars with 0 nonzeros. */
/* Zero Hessian: n_vars x n_vars with 0 nonzeros. */
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 0);
}

static void eval_wsum_hess(expr *node, const double *w)
{
/* Constant Hessian is always zero - nothing to compute */
(void) node;
(void) w;
}
Expand All @@ -58,12 +56,20 @@ static bool is_affine(const expr *node)
return true;
}

expr *new_constant(int d1, int d2, int n_vars, const double *values)
expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values)
{
expr *node = (expr *) calloc(1, sizeof(expr));
parameter_expr *pnode = (parameter_expr *) calloc(1, sizeof(parameter_expr));
expr *node = &pnode->base;
init_expr(node, d1, d2, n_vars, forward, jacobian_init_impl, eval_jacobian,
is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);
memcpy(node->value, values, node->size * sizeof(double));

pnode->param_id = param_id;
pnode->has_been_refreshed = false;

if (values != NULL)
{
memcpy(node->value, values, node->size * sizeof(double));
}

return node;
}
Loading
Loading