diff --git a/include/bivariate.h b/include/bivariate.h index 7260f33..82561e3 100644 --- a/include/bivariate.h +++ b/include/bivariate.h @@ -42,6 +42,10 @@ expr *new_right_matmul(expr *u, const CSR_Matrix *A); expr *new_right_matmul_dense(expr *u, int m, int n, const double *data); +/* Kronecker product: kron(C, X) where C is a constant sparse matrix and X is + * an expression of shape (p x q). Output has shape (m*p, n*q). */ +expr *new_kron_left(expr *child, const CSR_Matrix *C, int p, int q); + /* Constant scalar multiplication: a * f(x) where a is a constant double */ expr *new_const_scalar_mult(double a, expr *child); diff --git a/include/subexpr.h b/include/subexpr.h index b4b427c..45184bc 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -116,6 +116,16 @@ typedef struct left_matmul_expr int *csc_to_csr_work; } left_matmul_expr; +/* Kronecker product: Z = kron(C, X) where C is a constant matrix */ +typedef struct kron_left_expr +{ + expr base; + int m, n, p, q; /* C is (m x n), child X is (p x q) */ + CSR_Matrix *C; /* constant matrix, stored as CSR */ + int *row_map; /* output row -> child Jacobian row */ + int *row_scale_idx; /* output row -> index into C->x for the scale factor */ +} kron_left_expr; + /* Right matrix multiplication: y = f(x) * A where f(x) is an expression. * f(x) has shape p x n, A has shape n x q, output y has shape p x q. * Uses vec(y) = B * vec(f(x)) where B = A^T kron I_p. */ diff --git a/src/bivariate/kron_left.c b/src/bivariate/kron_left.c new file mode 100644 index 0000000..7923241 --- /dev/null +++ b/src/bivariate/kron_left.c @@ -0,0 +1,338 @@ +/* + * Copyright 2026 Daniel Cederberg and William Zhang + * + * This file is part of the DNLP-differentiation-engine project. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "bivariate.h" +#include "subexpr.h" +#include +#include +#include +#include + +/* + * Kronecker product: Z = kron(C, X) where C is a constant (m x n) matrix + * and X is a variable expression of shape (p x q). + * + * Output Z has shape (m*p, n*q), stored column-major as vec(Z) of length m*p*n*q. + * + * Key identity: Z[i*p+k, j*q+l] = C[i,j] * X[k,l] + * In column-major: vec(Z)[r] where r = (j*q+l)*(m*p) + i*p + k + * depends on vec(X)[s] where s = l*p + k, with coefficient C[i,j]. + * + * Jacobian structure: each row r of J_Z is a scaled copy of row s of J_X. + * Only rows where C[i,j] != 0 are non-trivial. + */ + +/* ------------------------------------------------------------------ */ +/* Forward pass */ +/* ------------------------------------------------------------------ */ +static void forward(expr *node, const double *u) +{ + kron_left_expr *kn = (kron_left_expr *) node; + expr *child = node->left; + CSR_Matrix *C = kn->C; + int m = kn->m, p = kn->p, q = kn->q; + int mp = m * p; + + child->forward(child, u); + + /* Zero output first */ + memset(node->value, 0, (size_t) node->size * sizeof(double)); + + /* For each nonzero C[i,j], fill block: Z[i*p+k, j*q+l] = C[i,j] * X[k,l] */ + for (int i = 0; i < m; i++) + { + for (int idx = C->p[i]; idx < C->p[i + 1]; idx++) + { + int j = C->i[idx]; + double cij = C->x[idx]; + + for (int l = 0; l < q; l++) + { + int z_col_start = (j * q + l) * mp + i * p; + int x_col_start = l * p; + for (int k = 0; k < p; k++) + { + node->value[z_col_start + k] = + cij * child->value[x_col_start + k]; + } + } + } + } +} + +/* ------------------------------------------------------------------ */ +/* Affine check */ +/* ------------------------------------------------------------------ */ +static bool is_affine(const expr *node) +{ + return node->left->is_affine(node->left); +} + +/* ------------------------------------------------------------------ */ +/* Jacobian initialization */ +/* ------------------------------------------------------------------ */ +static void jacobian_init(expr *node) +{ + kron_left_expr *kn = (kron_left_expr *) node; + expr *child = node->left; + CSR_Matrix *C = kn->C; + int m = kn->m, p = kn->p, q = kn->q; + int mp = m * p; + int out_size = node->size; /* m * p * n * q */ + + /* Initialize child's Jacobian */ + child->jacobian_init(child); + CSR_Matrix *Jchild = child->jacobian; + + /* + * Count total nnz: for each output row r corresponding to nonzero C[i,j], + * the nnz equals the nnz of child's Jacobian row s = l*p + k. + */ + int total_nnz = 0; + for (int i = 0; i < m; i++) + { + int row_nnz_C = C->p[i + 1] - C->p[i]; + if (row_nnz_C == 0) continue; + + for (int l = 0; l < q; l++) + { + for (int k = 0; k < p; k++) + { + int s = l * p + k; /* child row */ + int child_row_nnz = Jchild->p[s + 1] - Jchild->p[s]; + total_nnz += row_nnz_C * child_row_nnz; + } + } + } + + /* Allocate Jacobian */ + node->jacobian = new_csr_matrix(out_size, node->n_vars, total_nnz); + + /* Allocate row_map and row_scale arrays for eval_jacobian */ + kn->row_map = (int *) malloc((size_t) out_size * sizeof(int)); + kn->row_scale_idx = (int *) malloc((size_t) out_size * sizeof(int)); + + /* + * Fill Jacobian sparsity pattern. + * Iterate in column-major order over Z: + * r = (j*q + l) * mp + i*p + k + * maps to child row s = l*p + k, with scale C[i,j] + */ + int nnz_idx = 0; + for (int r = 0; r < out_size; r++) + { + node->jacobian->p[r] = nnz_idx; + + /* Decode r: r = col_Z * mp + row_Z */ + int row_Z = r % mp; /* i*p + k */ + int col_Z = r / mp; /* j*q + l */ + int i = row_Z / p; + int k = row_Z % p; + int j = col_Z / q; + int l = col_Z % q; + int s = l * p + k; /* child Jacobian row */ + + kn->row_map[r] = s; + + /* Find C[i,j] in CSR */ + double cij = 0.0; + int cij_found = 0; + for (int idx = C->p[i]; idx < C->p[i + 1]; idx++) + { + if (C->i[idx] == j) + { + cij = C->x[idx]; + kn->row_scale_idx[r] = idx; + cij_found = 1; + break; + } + } + + if (!cij_found || cij == 0.0) + { + kn->row_scale_idx[r] = -1; /* mark as zero row */ + continue; + } + + /* Copy child's column indices for row s */ + int child_start = Jchild->p[s]; + int child_end = Jchild->p[s + 1]; + int child_row_nnz = child_end - child_start; + + memcpy(node->jacobian->i + nnz_idx, Jchild->i + child_start, + (size_t) child_row_nnz * sizeof(int)); + nnz_idx += child_row_nnz; + } + node->jacobian->p[out_size] = nnz_idx; + node->jacobian->nnz = nnz_idx; + assert(nnz_idx == total_nnz); +} + +/* ------------------------------------------------------------------ */ +/* Jacobian evaluation */ +/* ------------------------------------------------------------------ */ +static void eval_jacobian(expr *node) +{ + kron_left_expr *kn = (kron_left_expr *) node; + expr *child = node->left; + CSR_Matrix *C = kn->C; + CSR_Matrix *Jchild = child->jacobian; + CSR_Matrix *J = node->jacobian; + int out_size = node->size; + + /* Evaluate child's Jacobian */ + child->eval_jacobian(child); + + /* Fill values: each active row r copies child row s scaled by C[i,j] */ + for (int r = 0; r < out_size; r++) + { + int j_start = J->p[r]; + int j_end = J->p[r + 1]; + int row_nnz = j_end - j_start; + if (row_nnz == 0) continue; + + int s = kn->row_map[r]; + int c_idx = kn->row_scale_idx[r]; + double cij = C->x[c_idx]; + + int child_start = Jchild->p[s]; + + for (int t = 0; t < row_nnz; t++) + { + J->x[j_start + t] = cij * Jchild->x[child_start + t]; + } + } +} + +/* ------------------------------------------------------------------ */ +/* Weighted-sum Hessian initialization */ +/* ------------------------------------------------------------------ */ +static void wsum_hess_init(expr *node) +{ + expr *child = node->left; + + /* Initialize child's Hessian */ + child->wsum_hess_init(child); + + /* kron_left is linear in X, so Hessian has same sparsity as child */ + node->wsum_hess = + new_csr_matrix(node->n_vars, node->n_vars, child->wsum_hess->nnz); + memcpy(node->wsum_hess->p, child->wsum_hess->p, + (size_t) (node->n_vars + 1) * sizeof(int)); + memcpy(node->wsum_hess->i, child->wsum_hess->i, + (size_t) child->wsum_hess->nnz * sizeof(int)); + + /* Allocate workspace for reverse-mode weight accumulation */ + node->dwork = (double *) calloc((size_t) child->size, sizeof(double)); +} + +/* ------------------------------------------------------------------ */ +/* Weighted-sum Hessian evaluation */ +/* ------------------------------------------------------------------ */ +static void eval_wsum_hess(expr *node, const double *w) +{ + kron_left_expr *kn = (kron_left_expr *) node; + expr *child = node->left; + CSR_Matrix *C = kn->C; + int m = kn->m, p = kn->p, q = kn->q; + int mp = m * p; + int child_size = child->size; + + /* + * Reverse mode: w_child[s] = sum over active r mapping to s of (C[i,j] * w[r]) + * This is the adjoint of the forward pass. + */ + memset(node->dwork, 0, (size_t) child_size * sizeof(double)); + + for (int i = 0; i < m; i++) + { + for (int idx = C->p[i]; idx < C->p[i + 1]; idx++) + { + int j = C->i[idx]; + double cij = C->x[idx]; + + for (int l = 0; l < q; l++) + { + for (int k = 0; k < p; k++) + { + int r = (j * q + l) * mp + i * p + k; + int s = l * p + k; + node->dwork[s] += cij * w[r]; + } + } + } + } + + /* Delegate to child */ + child->eval_wsum_hess(child, node->dwork); + memcpy(node->wsum_hess->x, child->wsum_hess->x, + (size_t) node->wsum_hess->nnz * sizeof(double)); +} + +/* ------------------------------------------------------------------ */ +/* Cleanup */ +/* ------------------------------------------------------------------ */ +static void free_type_data(expr *node) +{ + kron_left_expr *kn = (kron_left_expr *) node; + free_csr_matrix(kn->C); + free(kn->row_map); + free(kn->row_scale_idx); + kn->C = NULL; + kn->row_map = NULL; + kn->row_scale_idx = NULL; +} + +/* ------------------------------------------------------------------ */ +/* Constructor */ +/* ------------------------------------------------------------------ */ +expr *new_kron_left(expr *child, const CSR_Matrix *C, int p, int q) +{ + int m = C->m; + int n = C->n; + + /* Verify child dimensions */ + if (child->size != p * q) + { + fprintf(stderr, + "Error in new_kron_left: child size %d != p*q = %d*%d = %d\n", + child->size, p, q, p * q); + exit(1); + } + + /* Output: kron(C, X) has shape (m*p, n*q) */ + int d1 = m * p; + int d2 = n * q; + + kron_left_expr *kn = (kron_left_expr *) calloc(1, sizeof(kron_left_expr)); + expr *node = &kn->base; + init_expr(node, d1, d2, child->n_vars, forward, jacobian_init, eval_jacobian, + is_affine, wsum_hess_init, eval_wsum_hess, free_type_data); + + node->left = child; + expr_retain(child); + + kn->m = m; + kn->n = n; + kn->p = p; + kn->q = q; + kn->C = new_csr(C); + kn->row_map = NULL; + kn->row_scale_idx = NULL; + + return node; +} diff --git a/tests/all_tests.c b/tests/all_tests.c index 8f839f4..8743b4c 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -16,6 +16,7 @@ #include "forward_pass/elementwise/test_exp.h" #include "forward_pass/elementwise/test_log.h" #include "forward_pass/elementwise/test_normal_cdf.h" +#include "forward_pass/test_kron_left.h" #include "forward_pass/test_left_matmul_dense.h" #include "forward_pass/test_matmul.h" #include "forward_pass/test_prod_axis_one.h" @@ -27,6 +28,7 @@ #include "jacobian_tests/test_elementwise_mult.h" #include "jacobian_tests/test_hstack.h" #include "jacobian_tests/test_index.h" +#include "jacobian_tests/test_kron_left.h" #include "jacobian_tests/test_left_matmul.h" #include "jacobian_tests/test_log.h" #include "jacobian_tests/test_matmul.h" @@ -65,6 +67,7 @@ #include "wsum_hess/test_const_vector_mult.h" #include "wsum_hess/test_hstack.h" #include "wsum_hess/test_index.h" +#include "wsum_hess/test_kron_left.h" #include "wsum_hess/test_left_matmul.h" #include "wsum_hess/test_matmul.h" #include "wsum_hess/test_multiply.h" @@ -116,6 +119,7 @@ int main(void) mu_run_test(test_forward_prod_axis_one, tests_run); mu_run_test(test_matmul, tests_run); mu_run_test(test_left_matmul_dense, tests_run); + mu_run_test(test_kron_left_forward, tests_run); printf("\n--- Jacobian Tests ---\n"); mu_run_test(test_neg_jacobian, tests_run); @@ -174,6 +178,8 @@ int main(void) mu_run_test(test_wsum_hess_multiply_2, tests_run); mu_run_test(test_jacobian_trace_variable, tests_run); mu_run_test(test_jacobian_trace_composite, tests_run); + mu_run_test(test_jacobian_kron_left_log, tests_run); + mu_run_test(test_jacobian_kron_left_log_matrix, tests_run); mu_run_test(test_jacobian_left_matmul_log, tests_run); mu_run_test(test_jacobian_left_matmul_log_matrix, tests_run); mu_run_test(test_jacobian_left_matmul_log_composite, tests_run); @@ -232,6 +238,8 @@ int main(void) mu_run_test(test_wsum_hess_multiply_sparse_random, tests_run); mu_run_test(test_wsum_hess_multiply_1, tests_run); mu_run_test(test_wsum_hess_multiply_2, tests_run); + mu_run_test(test_wsum_hess_kron_left, tests_run); + mu_run_test(test_wsum_hess_kron_left_composite, tests_run); mu_run_test(test_wsum_hess_left_matmul, tests_run); mu_run_test(test_wsum_hess_left_matmul_matrix, tests_run); mu_run_test(test_wsum_hess_left_matmul_composite, tests_run); diff --git a/tests/forward_pass/test_kron_left.h b/tests/forward_pass/test_kron_left.h new file mode 100644 index 0000000..ea9b5b7 --- /dev/null +++ b/tests/forward_pass/test_kron_left.h @@ -0,0 +1,72 @@ +#include +#include +#include + +#include "bivariate.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" + +const char *test_kron_left_forward(void) +{ + /* Test: Z = kron(C, X) where + * C is 2x2 sparse: [[1, 2], [0, 3]] + * X is 2x2 variable (col-major): [[1, 3], [2, 4]] + * + * kron(C, X) = [[1*X, 2*X], [0*X, 3*X]] + * = [[1, 3, 2, 6], + * [2, 4, 4, 8], + * [0, 0, 3, 9], + * [0, 0, 6, 12]] + * + * Output is 4x4, stored column-major: + * col 0: [1, 2, 0, 0] + * col 1: [3, 4, 0, 0] + * col 2: [2, 4, 3, 6] + * col 3: [6, 8, 9, 12] + */ + + /* Create X variable (2 x 2) */ + expr *X = new_variable(2, 2, 0, 4); + + /* Create sparse matrix C in CSR format: + * row 0: C[0,0]=1, C[0,1]=2 + * row 1: C[1,1]=3 */ + CSR_Matrix *C = new_csr_matrix(2, 2, 3); + int C_p[3] = {0, 2, 3}; + int C_i[3] = {0, 1, 1}; + double C_x[3] = {1.0, 2.0, 3.0}; + memcpy(C->p, C_p, 3 * sizeof(int)); + memcpy(C->i, C_i, 3 * sizeof(int)); + memcpy(C->x, C_x, 3 * sizeof(double)); + + /* Build expression Z = kron(C, X), X is 2x2 */ + expr *Z = new_kron_left(X, C, 2, 2); + + /* Variable values in column-major order: X = [[1,3],[2,4]] */ + double u[4] = {1.0, 2.0, 3.0, 4.0}; + + /* Evaluate forward pass */ + Z->forward(Z, u); + + /* Expected result (4 x 4) in column-major order */ + double expected[16] = { + 1.0, 2.0, 0.0, 0.0, /* col 0 */ + 3.0, 4.0, 0.0, 0.0, /* col 1 */ + 2.0, 4.0, 3.0, 6.0, /* col 2 */ + 6.0, 8.0, 9.0, 12.0 /* col 3 */ + }; + + /* Verify dimensions */ + mu_assert("kron_left result should have d1=4", Z->d1 == 4); + mu_assert("kron_left result should have d2=4", Z->d2 == 4); + mu_assert("kron_left result should have size=16", Z->size == 16); + + /* Verify values */ + mu_assert("kron_left forward pass test failed", + cmp_double_array(Z->value, expected, 16)); + + free_csr_matrix(C); + free_expr(Z); + return 0; +} diff --git a/tests/jacobian_tests/test_kron_left.h b/tests/jacobian_tests/test_kron_left.h new file mode 100644 index 0000000..6fb4321 --- /dev/null +++ b/tests/jacobian_tests/test_kron_left.h @@ -0,0 +1,148 @@ +#include +#include +#include + +#include "bivariate.h" +#include "elementwise_univariate.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" + +const char *test_jacobian_kron_left_log(void) +{ + /* Test Jacobian of kron(C, log(x)) where: + * x is 2x1 variable at x = [1, 2] + * C is 2x2 sparse: [[1, 2], [0, 3]] + * Output: kron(C, log(x)) is 4x2, vectorized column-major to 8x1 + * + * log(x) = [0, log(2)] + * d(log(x))/dx = diag([1, 1/2]) + * + * kron(C, log(x)) output rows (column-major, mp=4): + * r=0: (i=0,k=0,j=0,l=0) -> C[0,0]*log(x)[0] = 1*log(x[0]) + * r=1: (i=0,k=1,j=0,l=0) -> C[0,0]*log(x)[1] = 1*log(x[1]) + * r=2: (i=1,k=0,j=0,l=0) -> C[1,0]*log(x)[0] = 0 (C[1,0]=0) + * r=3: (i=1,k=1,j=0,l=0) -> C[1,0]*log(x)[1] = 0 (C[1,0]=0) + * r=4: (i=0,k=0,j=1,l=0) -> C[0,1]*log(x)[0] = 2*log(x[0]) + * r=5: (i=0,k=1,j=1,l=0) -> C[0,1]*log(x)[1] = 2*log(x[1]) + * r=6: (i=1,k=0,j=1,l=0) -> C[1,1]*log(x)[0] = 3*log(x[0]) + * r=7: (i=1,k=1,j=1,l=0) -> C[1,1]*log(x)[1] = 3*log(x[1]) + * + * Jacobian (8x2): J[r, var] = C[i,j] * d(log(x[k]))/d(x[var]) + * Since d(log(x[k]))/d(x[var]) = delta(k,var)/x[k]: + * r=0: [1/1, 0] = [1, 0] + * r=1: [0, 1/2] = [0, 0.5] + * r=2: [0, 0] (zero row) + * r=3: [0, 0] (zero row) + * r=4: [2/1, 0] = [2, 0] + * r=5: [0, 2/2] = [0, 1] + * r=6: [3/1, 0] = [3, 0] + * r=7: [0, 3/2] = [0, 1.5] + * + * CSR format (8x2): + * p = [0, 1, 2, 2, 2, 3, 4, 5, 6] + * i = [0, 1, 0, 1, 0, 1] + * x = [1.0, 0.5, 2.0, 1.0, 3.0, 1.5] + */ + double x_vals[2] = {1.0, 2.0}; + expr *x = new_variable(2, 1, 0, 2); + + /* Create sparse matrix C in CSR format */ + CSR_Matrix *C = new_csr_matrix(2, 2, 3); + int C_p[3] = {0, 2, 3}; + int C_i[3] = {0, 1, 1}; + double C_x[3] = {1.0, 2.0, 3.0}; + memcpy(C->p, C_p, 3 * sizeof(int)); + memcpy(C->i, C_i, 3 * sizeof(int)); + memcpy(C->x, C_x, 3 * sizeof(double)); + + expr *log_x = new_log(x); + expr *Z = new_kron_left(log_x, C, 2, 1); + + Z->forward(Z, x_vals); + Z->jacobian_init(Z); + Z->eval_jacobian(Z); + + double expected_x[6] = {1.0, 0.5, 2.0, 1.0, 3.0, 1.5}; + int expected_i[6] = {0, 1, 0, 1, 0, 1}; + int expected_p[9] = {0, 1, 2, 2, 2, 3, 4, 5, 6}; + + mu_assert("kron_left jac vals fail", + cmp_double_array(Z->jacobian->x, expected_x, 6)); + mu_assert("kron_left jac cols fail", + cmp_int_array(Z->jacobian->i, expected_i, 6)); + mu_assert("kron_left jac rows fail", + cmp_int_array(Z->jacobian->p, expected_p, 9)); + + free_csr_matrix(C); + free_expr(Z); + return 0; +} + +const char *test_jacobian_kron_left_log_matrix(void) +{ + /* Test Jacobian of kron(C, log(x)) where: + * x is 2x2 variable at x = [[1, 3], [2, 4]] (col-major: [1,2,3,4]) + * C is 2x1 sparse: [[1], [2]] + * p=2, q=2, m=2, n=1 + * Output: kron(C, log(x)) is (2*2)x(1*2) = 4x2, vectorized to 8x1 + * + * Output rows (column-major, mp=4): + * r=0: col_Z=0,row_Z=0 -> i=0,k=0,j=0,l=0 -> C[0,0]*log(x)[s=0] + * r=1: col_Z=0,row_Z=1 -> i=0,k=1,j=0,l=0 -> C[0,0]*log(x)[s=1] + * r=2: col_Z=0,row_Z=2 -> i=1,k=0,j=0,l=0 -> C[1,0]*log(x)[s=0] + * r=3: col_Z=0,row_Z=3 -> i=1,k=1,j=0,l=0 -> C[1,0]*log(x)[s=1] + * r=4: col_Z=1,row_Z=0 -> i=0,k=0,j=0,l=1 -> C[0,0]*log(x)[s=2] + * r=5: col_Z=1,row_Z=1 -> i=0,k=1,j=0,l=1 -> C[0,0]*log(x)[s=3] + * r=6: col_Z=1,row_Z=2 -> i=1,k=0,j=0,l=1 -> C[1,0]*log(x)[s=2] + * r=7: col_Z=1,row_Z=3 -> i=1,k=1,j=0,l=1 -> C[1,0]*log(x)[s=3] + * + * Jacobian (8x4): J[r, var] = C[i,j] * d(log(x[s]))/d(x[var]) + * r=0: C[0,0]/x[0] at col 0 = 1/1 = 1.0 + * r=1: C[0,0]/x[1] at col 1 = 1/2 = 0.5 + * r=2: C[1,0]/x[0] at col 0 = 2/1 = 2.0 + * r=3: C[1,0]/x[1] at col 1 = 2/2 = 1.0 + * r=4: C[0,0]/x[2] at col 2 = 1/3 + * r=5: C[0,0]/x[3] at col 3 = 1/4 = 0.25 + * r=6: C[1,0]/x[2] at col 2 = 2/3 + * r=7: C[1,0]/x[3] at col 3 = 2/4 = 0.5 + * + * CSR (8x4): + * p = [0, 1, 2, 3, 4, 5, 6, 7, 8] + * i = [0, 1, 0, 1, 2, 3, 2, 3] + * x = [1.0, 0.5, 2.0, 1.0, 1.0/3.0, 0.25, 2.0/3.0, 0.5] + */ + double x_vals[4] = {1.0, 2.0, 3.0, 4.0}; + expr *x = new_variable(2, 2, 0, 4); + + /* Create sparse matrix C in CSR: 2x1, C=[[1],[2]] */ + CSR_Matrix *C = new_csr_matrix(2, 1, 2); + int C_p[3] = {0, 1, 2}; + int C_i[2] = {0, 0}; + double C_x[2] = {1.0, 2.0}; + memcpy(C->p, C_p, 3 * sizeof(int)); + memcpy(C->i, C_i, 2 * sizeof(int)); + memcpy(C->x, C_x, 2 * sizeof(double)); + + expr *log_x = new_log(x); + expr *Z = new_kron_left(log_x, C, 2, 2); + + Z->forward(Z, x_vals); + Z->jacobian_init(Z); + Z->eval_jacobian(Z); + + double expected_x[8] = {1.0, 0.5, 2.0, 1.0, 1.0 / 3.0, 0.25, 2.0 / 3.0, 0.5}; + int expected_i[8] = {0, 1, 0, 1, 2, 3, 2, 3}; + int expected_p[9] = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + + mu_assert("kron_left matrix jac vals fail", + cmp_double_array(Z->jacobian->x, expected_x, 8)); + mu_assert("kron_left matrix jac cols fail", + cmp_int_array(Z->jacobian->i, expected_i, 8)); + mu_assert("kron_left matrix jac rows fail", + cmp_int_array(Z->jacobian->p, expected_p, 9)); + + free_csr_matrix(C); + free_expr(Z); + return 0; +} diff --git a/tests/wsum_hess/test_kron_left.h b/tests/wsum_hess/test_kron_left.h new file mode 100644 index 0000000..2b0a758 --- /dev/null +++ b/tests/wsum_hess/test_kron_left.h @@ -0,0 +1,172 @@ +#include +#include +#include +#include + +#include "affine.h" +#include "bivariate.h" +#include "elementwise_univariate.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" + +const char *test_wsum_hess_kron_left(void) +{ + /* Test weighted sum of Hessian of kron(C, log(x)) where: + * x is 2x1 variable at x = [1, 2] + * C is 2x2 sparse: [[1, 2], [0, 3]] + * Output: kron(C, log(x)) is 4x2, vectorized to 8x1 + * Weights w = [1, 2, 3, 4, 5, 6, 7, 8] + * + * From kron_left eval_wsum_hess, the child weights are: + * w_child[s] = sum over active (i,j,k,l) of C[i,j] * w[r] + * where r = (j*q+l)*mp + i*p + k, s = l*p + k + * + * With p=2, q=1, m=2, n=2, mp=4: + * For s=0 (l=0, k=0): + * C[0,0]=1: r=(0*1+0)*4+0*2+0=0, w[0]=1 -> 1*1=1 + * C[0,1]=2: r=(1*1+0)*4+0*2+0=4, w[4]=5 -> 2*5=10 + * C[1,1]=3: r=(1*1+0)*4+1*2+0=6, w[6]=7 -> 3*7=21 + * w_child[0] = 1 + 10 + 21 = 32 + * + * For s=1 (l=0, k=1): + * C[0,0]=1: r=(0*1+0)*4+0*2+1=1, w[1]=2 -> 1*2=2 + * C[0,1]=2: r=(1*1+0)*4+0*2+1=5, w[5]=6 -> 2*6=12 + * C[1,1]=3: r=(1*1+0)*4+1*2+1=7, w[7]=8 -> 3*8=24 + * w_child[1] = 2 + 12 + 24 = 38 + * + * For log(x), the Hessian is diagonal: H[k,k] = -1/x[k]^2 + * wsum_hess[k,k] = w_child[k] * (-1/x[k]^2) + * [0,0] = 32 * (-1/1) = -32 + * [1,1] = 38 * (-1/4) = -9.5 + */ + double x_vals[2] = {1.0, 2.0}; + double w[8] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}; + + expr *x = new_variable(2, 1, 0, 2); + + CSR_Matrix *C = new_csr_matrix(2, 2, 3); + int C_p[3] = {0, 2, 3}; + int C_i[3] = {0, 1, 1}; + double C_x[3] = {1.0, 2.0, 3.0}; + memcpy(C->p, C_p, 3 * sizeof(int)); + memcpy(C->i, C_i, 3 * sizeof(int)); + memcpy(C->x, C_x, 3 * sizeof(double)); + + expr *log_x = new_log(x); + expr *Z = new_kron_left(log_x, C, 2, 1); + + Z->forward(Z, x_vals); + Z->jacobian_init(Z); + Z->wsum_hess_init(Z); + Z->eval_wsum_hess(Z, w); + + double expected_x[2] = {-32.0, -9.5}; + int expected_i[2] = {0, 1}; + int expected_p[3] = {0, 1, 2}; + + mu_assert("kron_left hess vals fail", + cmp_double_array(Z->wsum_hess->x, expected_x, 2)); + mu_assert("kron_left hess cols fail", + cmp_int_array(Z->wsum_hess->i, expected_i, 2)); + mu_assert("kron_left hess rows fail", + cmp_int_array(Z->wsum_hess->p, expected_p, 3)); + + free_csr_matrix(C); + free_expr(Z); + return 0; +} + +const char *test_wsum_hess_kron_left_composite(void) +{ + /* Test weighted sum of Hessian of kron(C, log(B @ x)) where: + * x is 2x1 variable at x = [1, 2] + * B is 2x2 dense matrix of all ones + * C is 2x1 sparse: [[1], [2]] + * p=2, q=1, m=2, n=1 + * Output: kron(C, log(B@x)) is 4x1 + * Weights w = [1, 2, 3, 4] + * + * B @ x = [3, 3] + * log(B @ x) = [log(3), log(3)] + * + * Child weights for log(Bx): + * w_child[s] = sum over active (i,j) of C[i,j] * w[r] + * With p=2, q=1, m=2, n=1, mp=4: + * For s=0 (l=0, k=0): + * C[0,0]=1: r=(0)*4+0*2+0=0, w[0]=1 -> 1*1=1 + * C[1,0]=2: r=(0)*4+1*2+0=2, w[2]=3 -> 2*3=6 + * w_child[0] = 7 + * + * For s=1 (l=0, k=1): + * C[0,0]=1: r=(0)*4+0*2+1=1, w[1]=2 -> 1*2=2 + * C[1,0]=2: r=(0)*4+1*2+1=3, w[3]=4 -> 2*4=8 + * w_child[1] = 10 + * + * Now for log(Bx) with child weights [7, 10]: + * d^2(log(y))/dy^2 = -1/y^2, and y = Bx = [3, 3] + * The wsum_hess of log(Bx) w.r.t. x is: + * B^T @ diag(w_child * (-1/y^2)) @ B + * = B^T @ diag([7*(-1/9), 10*(-1/9)]) @ B + * = B^T @ diag([-7/9, -10/9]) @ B + * + * B^T = [[1,1],[1,1]], so: + * B^T @ diag([-7/9, -10/9]) = [[-7/9, -10/9], [-7/9, -10/9]] + * ... @ B = [[-7/9 - 10/9, -7/9 - 10/9], + * [-7/9 - 10/9, -7/9 - 10/9]] + * = [[-17/9, -17/9], [-17/9, -17/9]] + * + * CSR (2x2 dense): + * nnz = 4 + * p = [0, 2, 4] + * i = [0, 1, 0, 1] + * x = [-17/9, -17/9, -17/9, -17/9] + */ + double x_vals[2] = {1.0, 2.0}; + double w[4] = {1.0, 2.0, 3.0, 4.0}; + + expr *x = new_variable(2, 1, 0, 2); + + /* Create B matrix (2x2 all ones) */ + CSR_Matrix *B = new_csr_matrix(2, 2, 4); + int B_p[3] = {0, 2, 4}; + int B_i[4] = {0, 1, 0, 1}; + double B_x[4] = {1.0, 1.0, 1.0, 1.0}; + memcpy(B->p, B_p, 3 * sizeof(int)); + memcpy(B->i, B_i, 4 * sizeof(int)); + memcpy(B->x, B_x, 4 * sizeof(double)); + + /* Create C matrix: 2x1 sparse, [[1],[2]] */ + CSR_Matrix *C = new_csr_matrix(2, 1, 2); + int C_p[3] = {0, 1, 2}; + int C_i[2] = {0, 0}; + double C_x[2] = {1.0, 2.0}; + memcpy(C->p, C_p, 3 * sizeof(int)); + memcpy(C->i, C_i, 2 * sizeof(int)); + memcpy(C->x, C_x, 2 * sizeof(double)); + + expr *Bx = new_linear(x, B, NULL); + expr *log_Bx = new_log(Bx); + expr *Z = new_kron_left(log_Bx, C, 2, 1); + + Z->forward(Z, x_vals); + Z->jacobian_init(Z); + Z->wsum_hess_init(Z); + Z->eval_wsum_hess(Z, w); + + double expected_x[4] = {-17.0 / 9.0, -17.0 / 9.0, -17.0 / 9.0, -17.0 / 9.0}; + int expected_i[4] = {0, 1, 0, 1}; + int expected_p[3] = {0, 2, 4}; + + mu_assert("kron_left composite hess vals fail", + cmp_double_array(Z->wsum_hess->x, expected_x, 4)); + mu_assert("kron_left composite hess cols fail", + cmp_int_array(Z->wsum_hess->i, expected_i, 4)); + mu_assert("kron_left composite hess rows fail", + cmp_int_array(Z->wsum_hess->p, expected_p, 3)); + + free_csr_matrix(B); + free_csr_matrix(C); + free_expr(Z); + return 0; +}