diff --git a/include/subexpr.h b/include/subexpr.h index b4b427c..4324991 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -28,12 +28,11 @@ struct int_double_pair; /* Type-specific expression structures that "inherit" from expr */ -/* Linear operator: y = A * x + b */ +/* Linear operator: y = A * x + b + * The matrix A is stored as node->jacobian (CSR). */ typedef struct linear_op_expr { expr base; - CSC_Matrix *A_csc; - CSR_Matrix *A_csr; double *b; /* constant offset vector (NULL if no offset) */ } linear_op_expr; @@ -98,8 +97,12 @@ typedef struct hstack_expr typedef struct elementwise_mult_expr { expr base; - CSR_Matrix *CSR_work1; - CSR_Matrix *CSR_work2; + CSR_Matrix *CSR_work1; /* C = Jg2^T diag(w) Jg1 */ + CSR_Matrix *CSR_work2; /* CT = C^T */ + int *idx_map_C; /* C[j] -> wsum_hess pos */ + int *idx_map_CT; /* CT[j] -> wsum_hess pos */ + int *idx_map_Hx; /* x->wsum_hess[j] -> pos */ + int *idx_map_Hy; /* y->wsum_hess[j] -> pos */ } elementwise_mult_expr; /* Left matrix multiplication: y = A * f(x) where f(x) is an expression. Note that diff --git a/include/utils/CSR_sum.h b/include/utils/CSR_sum.h index fcd8038..c04c7a8 100644 --- a/include/utils/CSR_sum.h +++ b/include/utils/CSR_sum.h @@ -86,4 +86,15 @@ void sum_spaced_rows_into_row_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A, int spacing, int *iwork, int *idx_map); +/* 4-way sorted merge of CSR matrices A, B, C, D (same dimensions). + * Allocates and returns the output CSR with the union sparsity pattern. + * Allocates and fills idx_maps[0..3] (one per input, size input->nnz + * each) mapping each input entry to its position in the output. + * Caller owns the returned CSR and all 4 idx_map arrays. 
*/ +CSR_Matrix *sum_4_csr_fill_sparsity_and_idx_maps(const CSR_Matrix *A, + const CSR_Matrix *B, + const CSR_Matrix *C, + const CSR_Matrix *D, + int *idx_maps[4]); + #endif /* CSR_SUM_H */ diff --git a/src/affine/linear_op.c b/src/affine/linear_op.c index 1dee37a..f2bc846 100644 --- a/src/affine/linear_op.c +++ b/src/affine/linear_op.c @@ -16,6 +16,7 @@ * limitations under the License. */ #include "affine.h" +#include "utils/CSR_Matrix.h" #include #include #include @@ -28,8 +29,8 @@ static void forward(expr *node, const double *u) /* child's forward pass */ node->left->forward(node->left, u); - /* y = A * x */ - csr_matvec(lin_node->A_csr, x->value, node->value, x->var_id); + /* y = A * x (A is stored as node->jacobian) */ + csr_matvec(node->jacobian, x->value, node->value, x->var_id); /* y += b (if offset exists) */ if (lin_node->b != NULL) @@ -49,29 +50,17 @@ static bool is_affine(const expr *node) static void free_type_data(expr *node) { linear_op_expr *lin_node = (linear_op_expr *) node; - /* memory pointing to by A_csr will be freed when the jacobian is freed, - so if the jacobian is not null we must not free A_csr. 
*/ - - if (!node->jacobian) - { - free_csr_matrix(lin_node->A_csr); - } - - free_csc_matrix(lin_node->A_csc); - if (lin_node->b != NULL) { free(lin_node->b); lin_node->b = NULL; } - - lin_node->A_csr = NULL; - lin_node->A_csc = NULL; } static void jacobian_init_impl(expr *node) { - node->jacobian = ((linear_op_expr *) node)->A_csr; + /* jacobian is set at construction time — nothing to do */ + (void) node; } static void eval_jacobian(expr *node) @@ -80,6 +69,19 @@ static void eval_jacobian(expr *node) (void) node; } +static void wsum_hess_init_impl(expr *node) +{ + /* Linear operator Hessian is always zero */ + node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 0); +} + +static void eval_wsum_hess(expr *node, const double *w) +{ + /* Linear operator Hessian is always zero - nothing to evaluate */ + (void) node; + (void) w; +} + expr *new_linear(expr *u, const CSR_Matrix *A, const double *b) { assert(u->d2 == 1); @@ -87,14 +89,13 @@ expr *new_linear(expr *u, const CSR_Matrix *A, const double *b) linear_op_expr *lin_node = (linear_op_expr *) calloc(1, sizeof(linear_op_expr)); expr *node = &lin_node->base; init_expr(node, A->m, 1, u->n_vars, forward, jacobian_init_impl, eval_jacobian, - is_affine, NULL, NULL, free_type_data); + is_affine, wsum_hess_init_impl, eval_wsum_hess, free_type_data); node->left = u; expr_retain(u); - /* Initialize type-specific fields */ - lin_node->A_csr = new_csr_matrix(A->m, A->n, A->nnz); - copy_csr_matrix(A, lin_node->A_csr); - lin_node->A_csc = csr_to_csc(A); + /* Store A directly as the jacobian (linear op jacobian is constant) */ + node->jacobian = new_csr_matrix(A->m, A->n, A->nnz); + copy_csr_matrix(A, node->jacobian); /* Initialize offset (copy b if provided, otherwise NULL) */ if (b != NULL) diff --git a/src/bivariate_full_dom/multiply.c b/src/bivariate_full_dom/multiply.c index 4e38f87..bd13bfb 100644 --- a/src/bivariate_full_dom/multiply.c +++ b/src/bivariate_full_dom/multiply.c @@ -19,11 +19,20 @@ #include "subexpr.h" 
#include "utils/CSR_sum.h" #include -#include #include #include #include +/* Scatter-add src->x into dest using precomputed index map */ +static void accumulate_mapped(double *dest, const CSR_Matrix *src, + const int *idx_map) +{ + for (int j = 0; j < src->nnz; j++) + { + dest[idx_map[j]] += src->x[j]; + } +} + // ------------------------------------------------------------------------------ // Implementation of elementwise multiplication when both arguments are vectors. // If one argument is a scalar variable, the broadcasting should be represented @@ -49,7 +58,6 @@ static void jacobian_init_impl(expr *node) { jacobian_init(node->left); jacobian_init(node->right); - node->work->dwork = (double *) malloc(2 * node->size * sizeof(double)); int nnz_max = node->left->jacobian->nnz + node->right->jacobian->nnz; node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz_max); @@ -66,7 +74,8 @@ static void eval_jacobian(expr *node) x->eval_jacobian(x); y->eval_jacobian(y); - /* chain rule */ + /* chain rule: the jacobian of h(x) = f(g1(x), g2(x))) is Jh = J_{f, 1} J_{g1} + + * J_{f, 2} J_{g2} */ sum_scaled_csr_matrices_fill_values(x->jacobian, y->jacobian, node->jacobian, y->value, x->value); } @@ -76,8 +85,9 @@ static void wsum_hess_init_impl(expr *node) expr *x = node->left; expr *y = node->right; - /* both x and y are variables*/ - if (x->var_id != NOT_A_VARIABLE) + /* both x and y are variables, and not the same */ + if (x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE && + x->var_id != y->var_id) { assert(y->var_id != NOT_A_VARIABLE); node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 2 * node->size); @@ -126,27 +136,54 @@ static void wsum_hess_init_impl(expr *node) } else { - /* both are linear operators */ - CSC_Matrix *A = ((linear_op_expr *) x)->A_csc; - CSC_Matrix *B = ((linear_op_expr *) y)->A_csc; + /* chain rule: the Hessian is in this case given by + wsum_hess = C + C^T + term2 + term3 where - /* Allocate workspace for Hessian 
computation */ - elementwise_mult_expr *mul_node = (elementwise_mult_expr *) node; - CSR_Matrix *C; /* C = B^T diag(w) A */ - C = BTA_alloc(A, B); - node->work->iwork = (int *) malloc(C->m * sizeof(int)); + * C = J_{g2}^T diag(w) J_{g1} + * term2 = sum_k w_k g2_k H_{g1_k} + * term3 = sum_k w_k g1_k H_{g2_k} + + term2 vanishes when g1 is affine, and term3 vanishes when g2 is affine. + Here, we view multiply as the composition h(x) = f(g1(x), g2(x)) where f + is the elementwise multiplication operator, and g1 and g2 are the left and + right child nodes. + */ + /* used for computing weights to wsum_hess of children */ + if (!x->is_affine(x) || !y->is_affine(y)) + { + node->work->dwork = (double *) malloc(node->size * sizeof(double)); + } + + /* prepare sparsity pattern of csc conversion */ + jacobian_csc_init(x); + jacobian_csc_init(y); + CSC_Matrix *Jg1 = x->work->jacobian_csc; + CSC_Matrix *Jg2 = y->work->jacobian_csc; + + /* compute sparsity of C and prepare CT */ + CSR_Matrix *C = BTA_alloc(Jg1, Jg2); + node->work->iwork = (int *) malloc(C->m * sizeof(int)); CSR_Matrix *CT = AT_alloc(C, node->work->iwork); + + /* initialize wsum_hessians of children */ + wsum_hess_init(x); + wsum_hess_init(y); + + elementwise_mult_expr *mul_node = (elementwise_mult_expr *) node; mul_node->CSR_work1 = C; mul_node->CSR_work2 = CT; - /* Hessian is H = C + C^T where both are B->n x A->n, and can't be more than - * 2 * nnz(C) */ - assert(C->m == node->n_vars && C->n == node->n_vars); - node->wsum_hess = new_csr_matrix(C->m, C->n, 2 * C->nnz); - - /* fill sparsity pattern Hessian = C + C^T */ - sum_csr_matrices_fill_sparsity(C, CT, node->wsum_hess); + /* compute sparsity pattern of H = C + C^T + term2 + term3 (we also + fill index maps telling us where to accumulate each element of each + matrix in the sum) */ + int *maps[4]; + node->wsum_hess = sum_4_csr_fill_sparsity_and_idx_maps(C, CT, x->wsum_hess, + y->wsum_hess, maps); + mul_node->idx_map_C = maps[0]; + mul_node->idx_map_CT = maps[1]; + 
mul_node->idx_map_Hx = maps[2]; + mul_node->idx_map_Hy = maps[3]; } } @@ -155,35 +192,98 @@ static void eval_wsum_hess(expr *node, const double *w) expr *x = node->left; expr *y = node->right; - /* both x and y are variables*/ - if (x->var_id != NOT_A_VARIABLE) + /* both x and y are variables, and not the same */ + if (x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE && + x->var_id != y->var_id) { memcpy(node->wsum_hess->x, w, node->size * sizeof(double)); memcpy(node->wsum_hess->x + node->size, w, node->size * sizeof(double)); } else { - /* both are linear operators */ - CSC_Matrix *A = ((linear_op_expr *) x)->A_csc; - CSC_Matrix *B = ((linear_op_expr *) y)->A_csc; - CSR_Matrix *C = ((elementwise_mult_expr *) node)->CSR_work1; - CSR_Matrix *CT = ((elementwise_mult_expr *) node)->CSR_work2; + bool is_x_affine = x->is_affine(x); + bool is_y_affine = y->is_affine(y); + // ---------------------------------------------------------------------- + // convert Jacobians of children to CSC format + // (we only need to do this once if the child is affine) + // TODO: what if we have parameters? Should we set jacobian_csc_filled + // to false whenever parameters change value? 
+ // ---------------------------------------------------------------------- + if (!x->work->jacobian_csc_filled) + { + csr_to_csc_fill_values(x->jacobian, x->work->jacobian_csc, + x->work->csc_work); - /* Compute C = B^T diag(w) A */ - BTDA_fill_values(A, B, w, C); + if (is_x_affine) + { + x->work->jacobian_csc_filled = true; + } + } - /* Compute CT = C^T = A^T diag(w) B */ + if (!y->work->jacobian_csc_filled) + { + csr_to_csc_fill_values(y->jacobian, y->work->jacobian_csc, + y->work->csc_work); + + if (is_y_affine) + { + y->work->jacobian_csc_filled = true; + } + } + + CSC_Matrix *Jg1 = x->work->jacobian_csc; + CSC_Matrix *Jg2 = y->work->jacobian_csc; + + // --------------------------------------------------------------- + // compute C and CT + // --------------------------------------------------------------- + elementwise_mult_expr *mul_node = (elementwise_mult_expr *) node; + CSR_Matrix *C = mul_node->CSR_work1; + CSR_Matrix *CT = mul_node->CSR_work2; + BTDA_fill_values(Jg1, Jg2, w, C); AT_fill_values(C, CT, node->work->iwork); - /* Hessian = C + CT = B^T diag(w) A + A^T diag(w) B */ - sum_csr_matrices_fill_values(C, CT, node->wsum_hess); + // --------------------------------------------------------------- + // compute term2 and term 3 + // --------------------------------------------------------------- + if (!is_x_affine) + { + for (int i = 0; i < node->size; i++) + { + node->work->dwork[i] = w[i] * y->value[i]; + } + x->eval_wsum_hess(x, node->work->dwork); + } + + if (!is_y_affine) + { + for (int i = 0; i < node->size; i++) + { + node->work->dwork[i] = w[i] * x->value[i]; + } + y->eval_wsum_hess(y, node->work->dwork); + } + + // --------------------------------------------------------------- + // compute H = C + C^T + term2 + term3 + // --------------------------------------------------------------- + memset(node->wsum_hess->x, 0, node->wsum_hess->nnz * sizeof(double)); + accumulate_mapped(node->wsum_hess->x, C, mul_node->idx_map_C); + 
accumulate_mapped(node->wsum_hess->x, CT, mul_node->idx_map_CT); + accumulate_mapped(node->wsum_hess->x, x->wsum_hess, mul_node->idx_map_Hx); + accumulate_mapped(node->wsum_hess->x, y->wsum_hess, mul_node->idx_map_Hy); } } static void free_type_data(expr *node) { - free_csr_matrix(((elementwise_mult_expr *) node)->CSR_work1); - free_csr_matrix(((elementwise_mult_expr *) node)->CSR_work2); + elementwise_mult_expr *mul_node = (elementwise_mult_expr *) node; + free_csr_matrix(mul_node->CSR_work1); + free_csr_matrix(mul_node->CSR_work2); + free(mul_node->idx_map_C); + free(mul_node->idx_map_CT); + free(mul_node->idx_map_Hx); + free(mul_node->idx_map_Hy); } static bool is_affine(const expr *node) @@ -194,25 +294,6 @@ static bool is_affine(const expr *node) expr *new_elementwise_mult(expr *left, expr *right) { - - /* for correctness x and y must be (1) different variables, or (2) both must be - * linear operators */ - if (left->var_id != NOT_A_VARIABLE && right->var_id != NOT_A_VARIABLE && - left->var_id == right->var_id) - { - fprintf(stderr, "Error: elementwise multiplication of a variable by itself " - "not supported.\n"); - exit(EXIT_FAILURE); - } - else if ((left->var_id != NOT_A_VARIABLE && right->var_id == NOT_A_VARIABLE) || - (left->var_id == NOT_A_VARIABLE && right->var_id != NOT_A_VARIABLE)) - { - fprintf(stderr, "Error: elementwise multiplication of a variable by a " - "non-variable is not supported. 
(Both must be inserted " - "as linear operators)\n"); - exit(EXIT_FAILURE); - } - elementwise_mult_expr *mul_node = (elementwise_mult_expr *) calloc(1, sizeof(elementwise_mult_expr)); expr *node = &mul_node->base; diff --git a/src/bivariate_restricted_dom/quad_over_lin.c b/src/bivariate_restricted_dom/quad_over_lin.c index 28b6e9a..ac4f85c 100644 --- a/src/bivariate_restricted_dom/quad_over_lin.c +++ b/src/bivariate_restricted_dom/quad_over_lin.c @@ -81,13 +81,12 @@ static void jacobian_init_impl(expr *node) } else /* left node is not a variable (guaranteed to be a linear operator) */ { - linear_op_expr *lin_x = (linear_op_expr *) x; node->work->dwork = (double *) malloc(x->size * sizeof(double)); /* compute required allocation and allocate jacobian */ bool *col_nz = (bool *) calloc( node->n_vars, sizeof(bool)); /* TODO: could use iwork here instead*/ - int nonzero_cols = count_nonzero_cols(lin_x->A_csr, col_nz); + int nonzero_cols = count_nonzero_cols(x->jacobian, col_nz); node->jacobian = new_csr_matrix(1, node->n_vars, nonzero_cols + 1); /* precompute column indices */ @@ -120,6 +119,13 @@ static void jacobian_init_impl(expr *node) break; } } + + /* prepare CSC form of child jacobian for chain rule. + * For a linear operator the values are constant, so fill + * them once here. 
*/ + jacobian_csc_init(x); + csr_to_csc_fill_values(x->jacobian, x->work->jacobian_csc, + x->work->csc_work); } } @@ -151,8 +157,6 @@ static void eval_jacobian(expr *node) } else /* x is not a variable */ { - CSC_Matrix *A_csc = ((linear_op_expr *) x)->A_csc; - /* local jacobian */ for (int j = 0; j < x->size; j++) { @@ -160,7 +164,8 @@ static void eval_jacobian(expr *node) } /* chain rule (no derivative wrt y) using CSC format */ - csc_matvec_fill_values(A_csc, node->work->dwork, node->jacobian); + csc_matvec_fill_values(x->work->jacobian_csc, node->work->dwork, + node->jacobian); /* insert derivative wrt y at right place (for correctness this assumes that y does not appear in the numerator, but this will always be diff --git a/src/expr.c b/src/expr.c index fccffeb..7754d1c 100644 --- a/src/expr.c +++ b/src/expr.c @@ -45,6 +45,10 @@ void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward, void jacobian_csc_init(expr *node) { + if (node->work->jacobian_csc != NULL) + { + return; + } node->work->csc_work = (int *) malloc(node->n_vars * sizeof(int)); node->work->jacobian_csc = csr_to_csc_fill_sparsity(node->jacobian, node->work->csc_work); diff --git a/src/utils/CSR_sum.c b/src/utils/CSR_sum.c index fd0a810..e024a0d 100644 --- a/src/utils/CSR_sum.c +++ b/src/utils/CSR_sum.c @@ -20,6 +20,7 @@ #include "utils/int_double_pair.h" #include "utils/utils.h" #include +#include #include void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C) @@ -710,6 +711,78 @@ void sum_all_rows_csr_fill_values(const CSR_Matrix *A, CSR_Matrix *C, * iwork: workspace of size at least max(A->n, A->nnz) * idx_map: output index map, size at least A->nnz */ +CSR_Matrix *sum_4_csr_fill_sparsity_and_idx_maps(const CSR_Matrix *A, + const CSR_Matrix *B, + const CSR_Matrix *C, + const CSR_Matrix *D, + int *idx_maps[4]) +{ + const CSR_Matrix *inputs[4] = {A, B, C, D}; + int m = A->m; + int n = A->n; + int nnz_ub = A->nnz + B->nnz + C->nnz + D->nnz; + + /* 
allocate output and index maps */ + CSR_Matrix *out = new_csr_matrix(m, n, nnz_ub); + for (int k = 0; k < 4; k++) + { + idx_maps[k] = (int *) malloc(inputs[k]->nnz * sizeof(int)); + } + + /* 4-way sorted merge per row */ + int ptrs[4], ends[4]; + int nnz = 0; + + for (int row = 0; row < m; row++) + { + out->p[row] = nnz; + for (int k = 0; k < 4; k++) + { + ptrs[k] = inputs[k]->p[row]; + ends[k] = inputs[k]->p[row + 1]; + } + + for (;;) + { + + /* find minimum column in the current row among the 4 inputs */ + int min_col = -1; + for (int k = 0; k < 4; k++) + { + if (ptrs[k] < ends[k]) + { + int col = inputs[k]->i[ptrs[k]]; + if (min_col == -1 || col < min_col) + { + min_col = col; + } + } + } + + /* all four inputs are exhausted for this row: the row is complete */ + if (min_col == -1) + { + break; + } + + /* insert min_col into output and update idx_maps */ + out->i[nnz] = min_col; + for (int k = 0; k < 4; k++) + { + if (ptrs[k] < ends[k] && inputs[k]->i[ptrs[k]] == min_col) + { + idx_maps[k][ptrs[k]++] = nnz; + } + } + nnz++; + } + } + + out->p[m] = nnz; + out->nnz = nnz; + return out; +} + void sum_spaced_rows_into_row_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A, CSR_Matrix *C, int spacing, int *iwork, diff --git a/tests/all_tests.c b/tests/all_tests.c index d8ba5b4..37312c0 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -270,6 +270,10 @@ int main(void) mu_run_test(test_wsum_hess_sin_sum_axis0_matmul, tests_run); mu_run_test(test_wsum_hess_logistic_sum_axis0_matmul, tests_run); mu_run_test(test_wsum_hess_sin_cos, tests_run); + mu_run_test(test_wsum_hess_Ax_Bx_multiply, tests_run); + mu_run_test(test_wsum_hess_x_x_multiply, tests_run); + mu_run_test(test_wsum_hess_AX_BX_multiply, tests_run); + mu_run_test(test_wsum_hess_multiply_deep_composite, tests_run); printf("\n--- Utility Tests ---\n"); mu_run_test(test_cblas_ddot, tests_run); diff --git a/tests/wsum_hess/affine/test_const_scalar_mult.h b/tests/wsum_hess/affine/test_const_scalar_mult.h index 
31fb6a6..1756f3f 100644 --- a/tests/wsum_hess/affine/test_const_scalar_mult.h +++ b/tests/wsum_hess/affine/test_const_scalar_mult.h @@ -27,6 +27,7 @@ const char *test_wsum_hess_const_scalar_mult_log_vector(void) y->forward(y, u_vals); /* Initialize and evaluate weighted Hessian with w = [1.0, 0.5, 0.25] */ + jacobian_init(y); wsum_hess_init(y); double w[3] = {1.0, 0.5, 0.25}; y->eval_wsum_hess(y, w); @@ -72,6 +73,7 @@ const char *test_wsum_hess_const_scalar_mult_log_matrix(void) y->forward(y, u_vals); /* Initialize and evaluate weighted Hessian with w = [1.0, 1.0, 1.0, 1.0] */ + jacobian_init(y); wsum_hess_init(y); double w[4] = {1.0, 1.0, 1.0, 1.0}; y->eval_wsum_hess(y, w); diff --git a/tests/wsum_hess/affine/test_const_vector_mult.h b/tests/wsum_hess/affine/test_const_vector_mult.h index 0e02a4f..febb3ba 100644 --- a/tests/wsum_hess/affine/test_const_vector_mult.h +++ b/tests/wsum_hess/affine/test_const_vector_mult.h @@ -27,6 +27,7 @@ const char *test_wsum_hess_const_vector_mult_log_vector(void) y->forward(y, u_vals); /* Initialize and evaluate weighted Hessian with w = [1.0, 0.5, 0.25] */ + jacobian_init(y); wsum_hess_init(y); double w[3] = {1.0, 0.5, 0.25}; y->eval_wsum_hess(y, w); @@ -70,6 +71,7 @@ const char *test_wsum_hess_const_vector_mult_log_matrix(void) y->forward(y, u_vals); /* Initialize and evaluate weighted Hessian with w = [1.0, 1.0, 1.0, 1.0] */ + jacobian_init(y); wsum_hess_init(y); double w[4] = {1.0, 1.0, 1.0, 1.0}; y->eval_wsum_hess(y, w); diff --git a/tests/wsum_hess/affine/test_hstack.h b/tests/wsum_hess/affine/test_hstack.h index 1ff662d..3b53474 100644 --- a/tests/wsum_hess/affine/test_hstack.h +++ b/tests/wsum_hess/affine/test_hstack.h @@ -135,6 +135,7 @@ const char *test_wsum_hess_hstack_matrix(void) expr *hstack_node = new_hstack(args, 4, 18); hstack_node->forward(hstack_node, u_vals); + jacobian_init(hstack_node); wsum_hess_init(hstack_node); hstack_node->eval_wsum_hess(hstack_node, w); diff --git 
a/tests/wsum_hess/affine/test_transpose.h b/tests/wsum_hess/affine/test_transpose.h index 8f518bf..f41b5f9 100644 --- a/tests/wsum_hess/affine/test_transpose.h +++ b/tests/wsum_hess/affine/test_transpose.h @@ -18,6 +18,7 @@ const char *test_wsum_hess_transpose(void) double u[8] = {1, 3, 2, 4, 5, 7, 6, 8}; XYT->forward(XYT, u); + jacobian_init(XYT); wsum_hess_init(XYT); double w[4] = {1, 2, 3, 4}; XYT->eval_wsum_hess(XYT, w); diff --git a/tests/wsum_hess/bivariate_full_dom/test_matmul.h b/tests/wsum_hess/bivariate_full_dom/test_matmul.h index c18396b..7ca03e4 100644 --- a/tests/wsum_hess/bivariate_full_dom/test_matmul.h +++ b/tests/wsum_hess/bivariate_full_dom/test_matmul.h @@ -44,6 +44,7 @@ const char *test_wsum_hess_matmul(void) /* Forward pass and Hessian initialization */ Z->forward(Z, u_vals); + jacobian_init(Z); wsum_hess_init(Z); Z->eval_wsum_hess(Z, w); @@ -144,6 +145,7 @@ const char *test_wsum_hess_matmul_yx(void) /* Forward pass and Hessian initialization */ Z->forward(Z, u_vals); + jacobian_init(Z); wsum_hess_init(Z); Z->eval_wsum_hess(Z, w); diff --git a/tests/wsum_hess/bivariate_full_dom/test_multiply.h b/tests/wsum_hess/bivariate_full_dom/test_multiply.h index 4089502..8b447f1 100644 --- a/tests/wsum_hess/bivariate_full_dom/test_multiply.h +++ b/tests/wsum_hess/bivariate_full_dom/test_multiply.h @@ -22,6 +22,7 @@ const char *test_wsum_hess_multiply_1(void) expr *node = new_elementwise_mult(x, y); node->forward(node, u_vals); + jacobian_init(node); wsum_hess_init(node); node->eval_wsum_hess(node, w); @@ -79,6 +80,7 @@ const char *test_wsum_hess_multiply_sparse_random(void) mult_node->forward(mult_node, u_vals); /* Initialize and evaluate Hessian */ + jacobian_init(mult_node); wsum_hess_init(mult_node); double w[5] = {0.50646339, 0.44756224, 0.67295241, 0.16424956, 0.03031469}; mult_node->eval_wsum_hess(mult_node, w); @@ -160,6 +162,7 @@ const char *test_wsum_hess_multiply_linear_ops(void) mult_node->forward(mult_node, u_vals); /* Initialize Hessian 
structure */ + jacobian_init(mult_node); wsum_hess_init(mult_node); /* Evaluate Hessian with weights */ @@ -207,8 +210,9 @@ const char *test_wsum_hess_multiply_2(void) expr *y = new_variable(3, 1, 3, 12); expr *node = new_elementwise_mult(x, y); - node->forward(node, u_vals); + jacobian_init(node); wsum_hess_init(node); + node->forward(node, u_vals); node->eval_wsum_hess(node, w); int expected_p[13] = {0, 0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 6, 6}; diff --git a/tests/wsum_hess/bivariate_restricted_dom/test_quad_over_lin.h b/tests/wsum_hess/bivariate_restricted_dom/test_quad_over_lin.h index 04e39c9..c428432 100644 --- a/tests/wsum_hess/bivariate_restricted_dom/test_quad_over_lin.h +++ b/tests/wsum_hess/bivariate_restricted_dom/test_quad_over_lin.h @@ -17,6 +17,7 @@ const char *test_wsum_hess_quad_over_lin_xy(void) expr *node = new_quad_over_lin(x, y); node->forward(node, u_vals); + jacobian_init(node); wsum_hess_init(node); node->eval_wsum_hess(node, &w); @@ -46,6 +47,7 @@ const char *test_wsum_hess_quad_over_lin_yx(void) expr *node = new_quad_over_lin(x, y); node->forward(node, u_vals); + jacobian_init(node); wsum_hess_init(node); node->eval_wsum_hess(node, &w); diff --git a/tests/wsum_hess/bivariate_restricted_dom/test_rel_entr.h b/tests/wsum_hess/bivariate_restricted_dom/test_rel_entr.h index 6a8a31f..a825a59 100644 --- a/tests/wsum_hess/bivariate_restricted_dom/test_rel_entr.h +++ b/tests/wsum_hess/bivariate_restricted_dom/test_rel_entr.h @@ -21,6 +21,7 @@ const char *test_wsum_hess_rel_entr_1(void) expr *node = new_rel_entr_vector_args(x, y); node->forward(node, u_vals); + jacobian_init(node); wsum_hess_init(node); node->eval_wsum_hess(node, w); @@ -52,6 +53,7 @@ const char *test_wsum_hess_rel_entr_2(void) expr *node = new_rel_entr_vector_args(x, y); node->forward(node, u_vals); + jacobian_init(node); wsum_hess_init(node); node->eval_wsum_hess(node, w); @@ -83,6 +85,7 @@ const char *test_wsum_hess_rel_entr_matrix(void) expr *node = new_rel_entr_vector_args(x, y); 
node->forward(node, u_vals); + jacobian_init(node); wsum_hess_init(node); node->eval_wsum_hess(node, w); diff --git a/tests/wsum_hess/bivariate_restricted_dom/test_rel_entr_scalar_vector.h b/tests/wsum_hess/bivariate_restricted_dom/test_rel_entr_scalar_vector.h index 2d0d59d..d305b97 100644 --- a/tests/wsum_hess/bivariate_restricted_dom/test_rel_entr_scalar_vector.h +++ b/tests/wsum_hess/bivariate_restricted_dom/test_rel_entr_scalar_vector.h @@ -16,6 +16,7 @@ const char *test_wsum_hess_rel_entr_scalar_vector(void) expr *node = new_rel_entr_first_arg_scalar(x, y); node->forward(node, u_vals); + jacobian_init(node); wsum_hess_init(node); node->eval_wsum_hess(node, w); diff --git a/tests/wsum_hess/bivariate_restricted_dom/test_rel_entr_vector_scalar.h b/tests/wsum_hess/bivariate_restricted_dom/test_rel_entr_vector_scalar.h index 968e5d5..700e163 100644 --- a/tests/wsum_hess/bivariate_restricted_dom/test_rel_entr_vector_scalar.h +++ b/tests/wsum_hess/bivariate_restricted_dom/test_rel_entr_vector_scalar.h @@ -16,6 +16,7 @@ const char *test_wsum_hess_rel_entr_vector_scalar(void) expr *node = new_rel_entr_second_arg_scalar(x, y); node->forward(node, u_vals); + jacobian_init(node); wsum_hess_init(node); node->eval_wsum_hess(node, w); diff --git a/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h b/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h index 475db80..6187f1d 100644 --- a/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h +++ b/tests/wsum_hess/composite/test_chain_rule_wsum_hess.h @@ -3,6 +3,7 @@ #include "elementwise_full_dom.h" #include "minunit.h" #include "numerical_diff.h" +#include "test_helpers.h" const char *test_wsum_hess_exp_sum(void) { @@ -114,3 +115,86 @@ const char *test_wsum_hess_sin_cos(void) free_expr(sin_cos_x); return 0; } + +const char *test_wsum_hess_Ax_Bx_multiply(void) +{ + /* the first and last values are not used, but good to include them in test */ + double u_vals[4] = {1.0, 2.0, 3.0, 4.0}; + double w[2] = {1.33, 2.1}; + + 
CSR_Matrix *A = new_csr_random(2, 2, 1.0); + CSR_Matrix *B = new_csr_random(2, 2, 1.0); + expr *x = new_variable(2, 1, 1, 4); + expr *Ax = new_left_matmul(x, A); + expr *Bx = new_left_matmul(x, B); + expr *multiply = new_elementwise_mult(Ax, Bx); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(multiply, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(multiply); + free_csr_matrix(A); + free_csr_matrix(B); + return 0; +} + +const char *test_wsum_hess_x_x_multiply(void) +{ + /* the first and last values are not used, but good to include them in test */ + double u_vals[4] = {1.0, 2.0, 3.0, 4.0}; + double w[2] = {1.33, 2.1}; + expr *x = new_variable(2, 1, 1, 4); + expr *multiply = new_elementwise_mult(x, x); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(multiply, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(multiply); + return 0; +} + +const char *test_wsum_hess_AX_BX_multiply(void) +{ + double u_vals[4] = {1.0, 2.0, 3.0, 4.0}; + double w[4] = {1.1, 2.2, 3.3, 4.4}; + + CSR_Matrix *A = new_csr_random(2, 2, 1.0); + CSR_Matrix *B = new_csr_random(2, 2, 1.0); + expr *X = new_variable(2, 2, 0, 4); + expr *AX = new_left_matmul(X, A); + expr *BX = new_left_matmul(X, B); + expr *multiply = new_elementwise_mult(new_sin(AX), new_cos(BX)); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(multiply, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(multiply); + free_csr_matrix(A); + free_csr_matrix(B); + return 0; +} + +const char *test_wsum_hess_multiply_deep_composite(void) +{ + double u_vals[4] = {1.0, 2.0, 3.0, 4.0}; + double w[4] = {1.1, 2.2, 3.3, 4.4}; + + CSR_Matrix *A = new_csr_random(2, 2, 1.0); + CSR_Matrix *B = new_csr_random(2, 2, 1.0); + expr *X = new_variable(2, 2, 0, 8); + expr *Y = new_variable(2, 2, 0, 8); + expr *AX = new_left_matmul(X, A); + expr *BY = new_left_matmul(Y, B); + expr *sin_AX = new_sin(AX); + expr *cos_BY = new_cos(BY); + expr *sin_AX_mult_sin_AX = new_elementwise_mult(sin_AX, sin_AX); + expr *multiply = 
new_elementwise_mult(sin_AX_mult_sin_AX, cos_BY); + + mu_assert("check_wsum_hess failed", + check_wsum_hess(multiply, u_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(multiply); + free_csr_matrix(A); + free_csr_matrix(B); + return 0; +} diff --git a/tests/wsum_hess/elementwise_full_dom/test_exp.h b/tests/wsum_hess/elementwise_full_dom/test_exp.h index 848d7c0..a7eb37f 100644 --- a/tests/wsum_hess/elementwise_full_dom/test_exp.h +++ b/tests/wsum_hess/elementwise_full_dom/test_exp.h @@ -18,6 +18,7 @@ const char *test_wsum_hess_exp(void) expr *x = new_variable(3, 1, 0, 3); expr *exp_node = new_exp(x); exp_node->forward(exp_node, u_vals); + jacobian_init(exp_node); wsum_hess_init(exp_node); exp_node->eval_wsum_hess(exp_node, w); diff --git a/tests/wsum_hess/elementwise_full_dom/test_hyperbolic.h b/tests/wsum_hess/elementwise_full_dom/test_hyperbolic.h index 8a26ffc..6acab84 100644 --- a/tests/wsum_hess/elementwise_full_dom/test_hyperbolic.h +++ b/tests/wsum_hess/elementwise_full_dom/test_hyperbolic.h @@ -25,6 +25,7 @@ const char *test_wsum_hess_sinh(void) expr *x = new_variable(3, 1, 0, 3); expr *sinh_node = new_sinh(x); sinh_node->forward(sinh_node, u_vals); + jacobian_init(sinh_node); wsum_hess_init(sinh_node); sinh_node->eval_wsum_hess(sinh_node, w); @@ -60,6 +61,7 @@ const char *test_wsum_hess_tanh(void) expr *x = new_variable(3, 1, 0, 3); expr *tanh_node = new_tanh(x); tanh_node->forward(tanh_node, u_vals); + jacobian_init(tanh_node); wsum_hess_init(tanh_node); tanh_node->eval_wsum_hess(tanh_node, w); @@ -97,6 +99,7 @@ const char *test_wsum_hess_asinh(void) expr *x = new_variable(3, 1, 0, 3); expr *asinh_node = new_asinh(x); asinh_node->forward(asinh_node, u_vals); + jacobian_init(asinh_node); wsum_hess_init(asinh_node); asinh_node->eval_wsum_hess(asinh_node, w); @@ -135,6 +138,7 @@ const char *test_wsum_hess_atanh(void) expr *x = new_variable(3, 1, 0, 3); expr *atanh_node = new_atanh(x); atanh_node->forward(atanh_node, u_vals); + jacobian_init(atanh_node); 
wsum_hess_init(atanh_node); atanh_node->eval_wsum_hess(atanh_node, w); diff --git a/tests/wsum_hess/elementwise_full_dom/test_logistic.h b/tests/wsum_hess/elementwise_full_dom/test_logistic.h index 67d2c8f..3dd22a1 100644 --- a/tests/wsum_hess/elementwise_full_dom/test_logistic.h +++ b/tests/wsum_hess/elementwise_full_dom/test_logistic.h @@ -28,6 +28,7 @@ const char *test_wsum_hess_logistic(void) logistic_node->forward(logistic_node, u_vals); jacobian_init(logistic_node); logistic_node->eval_jacobian(logistic_node); + jacobian_init(logistic_node); wsum_hess_init(logistic_node); logistic_node->eval_wsum_hess(logistic_node, w); diff --git a/tests/wsum_hess/elementwise_full_dom/test_power.h b/tests/wsum_hess/elementwise_full_dom/test_power.h index 915c224..f104a94 100644 --- a/tests/wsum_hess/elementwise_full_dom/test_power.h +++ b/tests/wsum_hess/elementwise_full_dom/test_power.h @@ -18,6 +18,7 @@ const char *test_wsum_hess_power(void) expr *x = new_variable(3, 1, 0, 3); expr *power_node = new_power(x, 3.0); power_node->forward(power_node, u_vals); + jacobian_init(power_node); wsum_hess_init(power_node); power_node->eval_wsum_hess(power_node, w); diff --git a/tests/wsum_hess/elementwise_full_dom/test_trig.h b/tests/wsum_hess/elementwise_full_dom/test_trig.h index 3add144..19460b6 100644 --- a/tests/wsum_hess/elementwise_full_dom/test_trig.h +++ b/tests/wsum_hess/elementwise_full_dom/test_trig.h @@ -18,6 +18,7 @@ const char *test_wsum_hess_sin(void) expr *x = new_variable(3, 1, 0, 3); expr *sin_node = new_sin(x); sin_node->forward(sin_node, u_vals); + jacobian_init(sin_node); wsum_hess_init(sin_node); sin_node->eval_wsum_hess(sin_node, w); @@ -46,6 +47,7 @@ const char *test_wsum_hess_cos(void) expr *x = new_variable(3, 1, 0, 3); expr *cos_node = new_cos(x); cos_node->forward(cos_node, u_vals); + jacobian_init(cos_node); wsum_hess_init(cos_node); cos_node->eval_wsum_hess(cos_node, w); @@ -74,6 +76,7 @@ const char *test_wsum_hess_tan(void) expr *x = new_variable(3, 1, 
0, 3); expr *tan_node = new_tan(x); tan_node->forward(tan_node, u_vals); + jacobian_init(tan_node); wsum_hess_init(tan_node); tan_node->eval_wsum_hess(tan_node, w); diff --git a/tests/wsum_hess/elementwise_full_dom/test_xexp.h b/tests/wsum_hess/elementwise_full_dom/test_xexp.h index 4b6ad57..5630b65 100644 --- a/tests/wsum_hess/elementwise_full_dom/test_xexp.h +++ b/tests/wsum_hess/elementwise_full_dom/test_xexp.h @@ -18,6 +18,7 @@ const char *test_wsum_hess_xexp(void) expr *x = new_variable(3, 1, 0, 3); expr *xexp_node = new_xexp(x); xexp_node->forward(xexp_node, u_vals); + jacobian_init(xexp_node); wsum_hess_init(xexp_node); xexp_node->eval_wsum_hess(xexp_node, w); diff --git a/tests/wsum_hess/elementwise_restricted_dom/test_entr.h b/tests/wsum_hess/elementwise_restricted_dom/test_entr.h index 75b92b9..1c3b433 100644 --- a/tests/wsum_hess/elementwise_restricted_dom/test_entr.h +++ b/tests/wsum_hess/elementwise_restricted_dom/test_entr.h @@ -18,6 +18,7 @@ const char *test_wsum_hess_entr(void) expr *x = new_variable(3, 1, 0, 3); expr *entr_node = new_entr(x); entr_node->forward(entr_node, u_vals); + jacobian_init(entr_node); wsum_hess_init(entr_node); entr_node->eval_wsum_hess(entr_node, w); diff --git a/tests/wsum_hess/elementwise_restricted_dom/test_log.h b/tests/wsum_hess/elementwise_restricted_dom/test_log.h index ef9c1e1..3632ff1 100644 --- a/tests/wsum_hess/elementwise_restricted_dom/test_log.h +++ b/tests/wsum_hess/elementwise_restricted_dom/test_log.h @@ -30,6 +30,7 @@ const char *test_wsum_hess_log(void) expr *x = new_variable(3, 1, 2, 7); expr *log_node = new_log(x); log_node->forward(log_node, u_vals); + jacobian_init(log_node); wsum_hess_init(log_node); log_node->eval_wsum_hess(log_node, w); diff --git a/tests/wsum_hess/other/test_prod.h b/tests/wsum_hess/other/test_prod.h index 6e8ff22..f932366 100644 --- a/tests/wsum_hess/other/test_prod.h +++ b/tests/wsum_hess/other/test_prod.h @@ -17,6 +17,7 @@ const char *test_wsum_hess_prod_no_zero(void) expr *p = 
new_prod(x); p->forward(p, u_vals); + jacobian_init(p); wsum_hess_init(p); p->eval_wsum_hess(p, &w); @@ -44,6 +45,7 @@ const char *test_wsum_hess_prod_one_zero(void) expr *p = new_prod(x); p->forward(p, u_vals); + jacobian_init(p); wsum_hess_init(p); p->eval_wsum_hess(p, &w); @@ -77,6 +79,7 @@ const char *test_wsum_hess_prod_two_zeros(void) expr *p = new_prod(x); p->forward(p, u_vals); + jacobian_init(p); wsum_hess_init(p); p->eval_wsum_hess(p, &w); @@ -105,6 +108,7 @@ const char *test_wsum_hess_prod_many_zeros(void) expr *p = new_prod(x); p->forward(p, u_vals); + jacobian_init(p); wsum_hess_init(p); p->eval_wsum_hess(p, &w); diff --git a/tests/wsum_hess/other/test_prod_axis_one.h b/tests/wsum_hess/other/test_prod_axis_one.h index feb2134..b13bbae 100644 --- a/tests/wsum_hess/other/test_prod_axis_one.h +++ b/tests/wsum_hess/other/test_prod_axis_one.h @@ -30,6 +30,7 @@ const char *test_wsum_hess_prod_axis_one_no_zeros(void) expr *p = new_prod_axis_one(x); p->forward(p, u_vals); + jacobian_init(p); wsum_hess_init(p); p->eval_wsum_hess(p, w_vals); @@ -97,6 +98,7 @@ const char *test_wsum_hess_prod_axis_one_one_zero(void) expr *p = new_prod_axis_one(x); p->forward(p, u_vals); + jacobian_init(p); wsum_hess_init(p); p->eval_wsum_hess(p, w_vals); @@ -200,6 +202,7 @@ const char *test_wsum_hess_prod_axis_one_mixed_zeros(void) expr *p = new_prod_axis_one(x); p->forward(p, u_vals); + jacobian_init(p); wsum_hess_init(p); p->eval_wsum_hess(p, w_vals); @@ -350,6 +353,7 @@ const char *test_wsum_hess_prod_axis_one_2x2(void) expr *p = new_prod_axis_one(x); p->forward(p, u_vals); + jacobian_init(p); wsum_hess_init(p); p->eval_wsum_hess(p, w_vals); diff --git a/tests/wsum_hess/other/test_prod_axis_zero.h b/tests/wsum_hess/other/test_prod_axis_zero.h index 3e6ccc5..9e7fe02 100644 --- a/tests/wsum_hess/other/test_prod_axis_zero.h +++ b/tests/wsum_hess/other/test_prod_axis_zero.h @@ -32,6 +32,7 @@ const char *test_wsum_hess_prod_axis_zero_no_zeros(void) expr *p = new_prod_axis_zero(x); 
p->forward(p, u_vals); + jacobian_init(p); wsum_hess_init(p); p->eval_wsum_hess(p, w_vals); @@ -107,6 +108,7 @@ const char *test_wsum_hess_prod_axis_zero_mixed_zeros(void) expr *p = new_prod_axis_zero(x); p->forward(p, u_vals); + jacobian_init(p); wsum_hess_init(p); p->eval_wsum_hess(p, w_vals); @@ -209,6 +211,7 @@ const char *test_wsum_hess_prod_axis_zero_one_zero(void) expr *p = new_prod_axis_zero(x); p->forward(p, u_vals); + jacobian_init(p); wsum_hess_init(p); p->eval_wsum_hess(p, w_vals); diff --git a/tests/wsum_hess/other/test_quad_form.h b/tests/wsum_hess/other/test_quad_form.h index 0f70792..6387e45 100644 --- a/tests/wsum_hess/other/test_quad_form.h +++ b/tests/wsum_hess/other/test_quad_form.h @@ -30,6 +30,7 @@ const char *test_wsum_hess_quad_form(void) jacobian_init(node); node->forward(node, u_vals); + jacobian_init(node); wsum_hess_init(node); node->eval_wsum_hess(node, &w);