Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions include/subexpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@ struct int_double_pair;

/* Type-specific expression structures that "inherit" from expr */

/* Linear operator: y = A * x + b */
/* Linear operator: y = A * x + b
* The matrix A is stored as node->jacobian (CSR). */
typedef struct linear_op_expr
{
expr base;
CSC_Matrix *A_csc;
CSR_Matrix *A_csr;
double *b; /* constant offset vector (NULL if no offset) */
} linear_op_expr;

Expand Down Expand Up @@ -98,8 +97,12 @@ typedef struct hstack_expr
typedef struct elementwise_mult_expr
{
expr base;
CSR_Matrix *CSR_work1;
CSR_Matrix *CSR_work2;
CSR_Matrix *CSR_work1; /* C = Jg2^T diag(w) Jg1 */
CSR_Matrix *CSR_work2; /* CT = C^T */
int *idx_map_C; /* C[j] -> wsum_hess pos */
int *idx_map_CT; /* CT[j] -> wsum_hess pos */
int *idx_map_Hx; /* x->wsum_hess[j] -> pos */
int *idx_map_Hy; /* y->wsum_hess[j] -> pos */
} elementwise_mult_expr;

/* Left matrix multiplication: y = A * f(x) where f(x) is an expression. Note that
Expand Down
11 changes: 11 additions & 0 deletions include/utils/CSR_sum.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,15 @@ void sum_spaced_rows_into_row_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A,
int spacing, int *iwork,
int *idx_map);

/* 4-way sorted merge of the CSR matrices A, B, C and D, which must all have
 * the same dimensions.
 *
 * Returns a newly allocated CSR matrix whose sparsity pattern is the union of
 * the sparsity patterns of the four inputs.
 *
 * idx_maps is an output argument: on return, idx_maps[0], idx_maps[1],
 * idx_maps[2] and idx_maps[3] are newly allocated arrays belonging to A, B, C
 * and D respectively. Each map has one entry per stored entry of its input
 * (input->nnz entries) and gives the position of that input entry in the
 * output matrix.
 *
 * Ownership: the caller owns the returned CSR matrix and all 4 idx_map arrays
 * and is responsible for releasing them. */
CSR_Matrix *sum_4_csr_fill_sparsity_and_idx_maps(const CSR_Matrix *A,
const CSR_Matrix *B,
const CSR_Matrix *C,
const CSR_Matrix *D,
int *idx_maps[4]);

#endif /* CSR_SUM_H */
43 changes: 22 additions & 21 deletions src/affine/linear_op.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* limitations under the License.
*/
#include "affine.h"
#include "utils/CSR_Matrix.h"
#include <assert.h>
#include <stdlib.h>
#include <string.h>
Expand All @@ -28,8 +29,8 @@ static void forward(expr *node, const double *u)
/* child's forward pass */
node->left->forward(node->left, u);

/* y = A * x */
csr_matvec(lin_node->A_csr, x->value, node->value, x->var_id);
/* y = A * x (A is stored as node->jacobian) */
csr_matvec(node->jacobian, x->value, node->value, x->var_id);

/* y += b (if offset exists) */
if (lin_node->b != NULL)
Expand All @@ -49,29 +50,17 @@ static bool is_affine(const expr *node)
static void free_type_data(expr *node)
{
linear_op_expr *lin_node = (linear_op_expr *) node;
/* the memory pointed to by A_csr will be freed when the jacobian is freed,
so if the jacobian is not null we must not free A_csr. */

if (!node->jacobian)
{
free_csr_matrix(lin_node->A_csr);
}

free_csc_matrix(lin_node->A_csc);

if (lin_node->b != NULL)
{
free(lin_node->b);
lin_node->b = NULL;
}

lin_node->A_csr = NULL;
lin_node->A_csc = NULL;
}

static void jacobian_init_impl(expr *node)
{
node->jacobian = ((linear_op_expr *) node)->A_csr;
/* jacobian is set at construction time — nothing to do */
(void) node;
}

static void eval_jacobian(expr *node)
Expand All @@ -80,21 +69,33 @@ static void eval_jacobian(expr *node)
(void) node;
}

static void wsum_hess_init_impl(expr *node)
{
    /* The Hessian of a linear operator is identically zero, so the weighted
     * sum of Hessians is an n_vars x n_vars matrix created with 0 nonzeros. */
    int dim = node->n_vars;

    node->wsum_hess = new_csr_matrix(dim, dim, 0);
}

static void eval_wsum_hess(expr *node, const double *w)
{
    /* Intentionally empty: a linear operator has a zero Hessian, so there are
     * no values to compute. Both arguments are deliberately unused. */
    (void) w;
    (void) node;
}

expr *new_linear(expr *u, const CSR_Matrix *A, const double *b)
{
assert(u->d2 == 1);
/* Allocate the type-specific struct */
linear_op_expr *lin_node = (linear_op_expr *) calloc(1, sizeof(linear_op_expr));
expr *node = &lin_node->base;
init_expr(node, A->m, 1, u->n_vars, forward, jacobian_init_impl, eval_jacobian,
is_affine, NULL, NULL, free_type_data);
is_affine, wsum_hess_init_impl, eval_wsum_hess, free_type_data);
node->left = u;
expr_retain(u);

/* Initialize type-specific fields */
lin_node->A_csr = new_csr_matrix(A->m, A->n, A->nnz);
copy_csr_matrix(A, lin_node->A_csr);
lin_node->A_csc = csr_to_csc(A);
/* Store A directly as the jacobian (linear op jacobian is constant) */
node->jacobian = new_csr_matrix(A->m, A->n, A->nnz);
copy_csr_matrix(A, node->jacobian);

/* Initialize offset (copy b if provided, otherwise NULL) */
if (b != NULL)
Expand Down
187 changes: 134 additions & 53 deletions src/bivariate_full_dom/multiply.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,20 @@
#include "subexpr.h"
#include "utils/CSR_sum.h"
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Scatter-add: for k = 0, ..., src->nnz - 1 (in increasing order), add the
 * k-th stored value of src to dest at the slot given by the precomputed
 * index map, i.e. dest[idx_map[k]] += src->x[k]. */
static void accumulate_mapped(double *dest, const CSR_Matrix *src,
                              const int *idx_map)
{
    int k = 0;

    while (k < src->nnz)
    {
        int slot = idx_map[k];

        dest[slot] += src->x[k];
        k++;
    }
}

// ------------------------------------------------------------------------------
// Implementation of elementwise multiplication when both arguments are vectors.
// If one argument is a scalar variable, the broadcasting should be represented
Expand All @@ -49,7 +58,6 @@ static void jacobian_init_impl(expr *node)
{
jacobian_init(node->left);
jacobian_init(node->right);
node->work->dwork = (double *) malloc(2 * node->size * sizeof(double));
int nnz_max = node->left->jacobian->nnz + node->right->jacobian->nnz;
node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz_max);

Expand All @@ -66,7 +74,8 @@ static void eval_jacobian(expr *node)
x->eval_jacobian(x);
y->eval_jacobian(y);

/* chain rule */
/* chain rule: the Jacobian of h(x) = f(g1(x), g2(x)) is
 * Jh = J_{f, 1} J_{g1} + J_{f, 2} J_{g2} */
sum_scaled_csr_matrices_fill_values(x->jacobian, y->jacobian, node->jacobian,
y->value, x->value);
}
Expand All @@ -76,8 +85,9 @@ static void wsum_hess_init_impl(expr *node)
expr *x = node->left;
expr *y = node->right;

/* both x and y are variables*/
if (x->var_id != NOT_A_VARIABLE)
/* both x and y are variables, and not the same */
if (x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE &&
x->var_id != y->var_id)
{
assert(y->var_id != NOT_A_VARIABLE);
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 2 * node->size);
Expand Down Expand Up @@ -126,27 +136,54 @@ static void wsum_hess_init_impl(expr *node)
}
else
{
/* both are linear operators */
CSC_Matrix *A = ((linear_op_expr *) x)->A_csc;
CSC_Matrix *B = ((linear_op_expr *) y)->A_csc;
/* chain rule: the Hessian is in this case given by
wsum_hess = C + C^T + term2 + term3 where
/* Allocate workspace for Hessian computation */
elementwise_mult_expr *mul_node = (elementwise_mult_expr *) node;
CSR_Matrix *C; /* C = B^T diag(w) A */
C = BTA_alloc(A, B);
node->work->iwork = (int *) malloc(C->m * sizeof(int));
* C = J_{g2}^T diag(w) J_{g1}
* term2 = sum_k w_k g2_k H_{g1_k}
* term3 = sum_k w_k g1_k H_{g2_k}
The two last terms are nonzero only if g1 and g2 are nonlinear.
Here, we view multiply as the composition h(x) = f(g1(x), g2(x)) where f
is the elementwise multiplication operator, and g1 and g2 are the left and
right child nodes.
*/

/* used for computing weights to wsum_hess of children */
if (!x->is_affine(x) || !y->is_affine(y))
{
node->work->dwork = (double *) malloc(node->size * sizeof(double));
}

/* prepare sparsity pattern of csc conversion */
jacobian_csc_init(x);
jacobian_csc_init(y);
CSC_Matrix *Jg1 = x->work->jacobian_csc;
CSC_Matrix *Jg2 = y->work->jacobian_csc;

/* compute sparsity of C and prepare CT */
CSR_Matrix *C = BTA_alloc(Jg1, Jg2);
node->work->iwork = (int *) malloc(C->m * sizeof(int));
CSR_Matrix *CT = AT_alloc(C, node->work->iwork);

/* initialize wsum_hessians of children */
wsum_hess_init(x);
wsum_hess_init(y);

elementwise_mult_expr *mul_node = (elementwise_mult_expr *) node;
mul_node->CSR_work1 = C;
mul_node->CSR_work2 = CT;

/* Hessian is H = C + C^T where both are B->n x A->n, and can't be more than
* 2 * nnz(C) */
assert(C->m == node->n_vars && C->n == node->n_vars);
node->wsum_hess = new_csr_matrix(C->m, C->n, 2 * C->nnz);

/* fill sparsity pattern Hessian = C + C^T */
sum_csr_matrices_fill_sparsity(C, CT, node->wsum_hess);
/* compute sparsity pattern of H = C + C^T + term2 + term3 (we also
fill index maps telling us where to accumulate each element of each
matrix in the sum) */
int *maps[4];
node->wsum_hess = sum_4_csr_fill_sparsity_and_idx_maps(C, CT, x->wsum_hess,
y->wsum_hess, maps);
mul_node->idx_map_C = maps[0];
mul_node->idx_map_CT = maps[1];
mul_node->idx_map_Hx = maps[2];
mul_node->idx_map_Hy = maps[3];
}
}

Expand All @@ -155,35 +192,98 @@ static void eval_wsum_hess(expr *node, const double *w)
expr *x = node->left;
expr *y = node->right;

/* both x and y are variables*/
if (x->var_id != NOT_A_VARIABLE)
/* both x and y are variables, and not the same */
if (x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE &&
x->var_id != y->var_id)
{
memcpy(node->wsum_hess->x, w, node->size * sizeof(double));
memcpy(node->wsum_hess->x + node->size, w, node->size * sizeof(double));
}
else
{
/* both are linear operators */
CSC_Matrix *A = ((linear_op_expr *) x)->A_csc;
CSC_Matrix *B = ((linear_op_expr *) y)->A_csc;
CSR_Matrix *C = ((elementwise_mult_expr *) node)->CSR_work1;
CSR_Matrix *CT = ((elementwise_mult_expr *) node)->CSR_work2;
bool is_x_affine = x->is_affine(x);
bool is_y_affine = y->is_affine(y);
// ----------------------------------------------------------------------
// convert Jacobians of children to CSC format
// (we only need to do this once if the child is affine)
// TODO: what if we have parameters? Should we set jacobian_csc_filled
// to false whenever parameters change value?
// ----------------------------------------------------------------------
if (!x->work->jacobian_csc_filled)
{
csr_to_csc_fill_values(x->jacobian, x->work->jacobian_csc,
x->work->csc_work);

/* Compute C = B^T diag(w) A */
BTDA_fill_values(A, B, w, C);
if (is_x_affine)
{
x->work->jacobian_csc_filled = true;
}
}

/* Compute CT = C^T = A^T diag(w) B */
if (!y->work->jacobian_csc_filled)
{
csr_to_csc_fill_values(y->jacobian, y->work->jacobian_csc,
y->work->csc_work);

if (is_y_affine)
{
y->work->jacobian_csc_filled = true;
}
}

CSC_Matrix *Jg1 = x->work->jacobian_csc;
CSC_Matrix *Jg2 = y->work->jacobian_csc;

// ---------------------------------------------------------------
// compute C and CT
// ---------------------------------------------------------------
elementwise_mult_expr *mul_node = (elementwise_mult_expr *) node;
CSR_Matrix *C = mul_node->CSR_work1;
CSR_Matrix *CT = mul_node->CSR_work2;
BTDA_fill_values(Jg1, Jg2, w, C);
AT_fill_values(C, CT, node->work->iwork);

/* Hessian = C + CT = B^T diag(w) A + A^T diag(w) B */
sum_csr_matrices_fill_values(C, CT, node->wsum_hess);
// ---------------------------------------------------------------
// compute term2 and term3
// ---------------------------------------------------------------
if (!is_x_affine)
{
for (int i = 0; i < node->size; i++)
{
node->work->dwork[i] = w[i] * y->value[i];
}
x->eval_wsum_hess(x, node->work->dwork);
}

if (!is_y_affine)
{
for (int i = 0; i < node->size; i++)
{
node->work->dwork[i] = w[i] * x->value[i];
}
y->eval_wsum_hess(y, node->work->dwork);
}

// ---------------------------------------------------------------
// compute H = C + C^T + term2 + term3
// ---------------------------------------------------------------
memset(node->wsum_hess->x, 0, node->wsum_hess->nnz * sizeof(double));
accumulate_mapped(node->wsum_hess->x, C, mul_node->idx_map_C);
accumulate_mapped(node->wsum_hess->x, CT, mul_node->idx_map_CT);
accumulate_mapped(node->wsum_hess->x, x->wsum_hess, mul_node->idx_map_Hx);
accumulate_mapped(node->wsum_hess->x, y->wsum_hess, mul_node->idx_map_Hy);
}
}

static void free_type_data(expr *node)
{
free_csr_matrix(((elementwise_mult_expr *) node)->CSR_work1);
free_csr_matrix(((elementwise_mult_expr *) node)->CSR_work2);
elementwise_mult_expr *mul_node = (elementwise_mult_expr *) node;
free_csr_matrix(mul_node->CSR_work1);
free_csr_matrix(mul_node->CSR_work2);
free(mul_node->idx_map_C);
free(mul_node->idx_map_CT);
free(mul_node->idx_map_Hx);
free(mul_node->idx_map_Hy);
}

static bool is_affine(const expr *node)
Expand All @@ -194,25 +294,6 @@ static bool is_affine(const expr *node)

expr *new_elementwise_mult(expr *left, expr *right)
{

/* for correctness x and y must be (1) different variables, or (2) both must be
* linear operators */
if (left->var_id != NOT_A_VARIABLE && right->var_id != NOT_A_VARIABLE &&
left->var_id == right->var_id)
{
fprintf(stderr, "Error: elementwise multiplication of a variable by itself "
"not supported.\n");
exit(EXIT_FAILURE);
}
else if ((left->var_id != NOT_A_VARIABLE && right->var_id == NOT_A_VARIABLE) ||
(left->var_id == NOT_A_VARIABLE && right->var_id != NOT_A_VARIABLE))
{
fprintf(stderr, "Error: elementwise multiplication of a variable by a "
"non-variable is not supported. (Both must be inserted "
"as linear operators)\n");
exit(EXIT_FAILURE);
}

elementwise_mult_expr *mul_node =
(elementwise_mult_expr *) calloc(1, sizeof(elementwise_mult_expr));
expr *node = &mul_node->base;
Expand Down
Loading
Loading