diff --git a/include/affine.h b/include/affine.h index 379c1cc..77ac14b 100644 --- a/include/affine.h +++ b/include/affine.h @@ -33,7 +33,7 @@ expr *new_vstack(expr **args, int n_args, int n_vars); expr *new_promote(expr *child, int d1, int d2); expr *new_trace(expr *child); -expr *new_constant(int d1, int d2, int n_vars, const double *values); +expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values); expr *new_variable(int d1, int d2, int var_id, int n_vars); expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs); @@ -42,26 +42,27 @@ expr *new_broadcast(expr *child, int target_d1, int target_d2); expr *new_diag_vec(expr *child); expr *new_transpose(expr *child); -/* Left matrix multiplication: A @ f(x) where A is a constant sparse - * matrix */ -expr *new_left_matmul(expr *u, const CSR_Matrix *A); +/* Left matrix multiplication: A @ f(x) where A is a constant or parameter + * sparse matrix. param_node is NULL for fixed constants. */ +expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A); -/* Left matrix multiplication: A @ f(x) where A is a constant dense - * matrix (row-major, m x n). Uses CBLAS for efficient computation. */ -expr *new_left_matmul_dense(expr *u, int m, int n, const double *data); +/* Left matrix multiplication: A @ f(x) where A is a constant or parameter + * dense matrix (row-major, m x n). Uses CBLAS for efficient computation. */ +expr *new_left_matmul_dense(expr *param_node, expr *u, int m, int n, + const double *data); -/* Right matrix multiplication: f(x) @ A where A is a constant - * matrix */ -expr *new_right_matmul(expr *u, const CSR_Matrix *A); +/* Right matrix multiplication: f(x) @ A where A is a constant or parameter + * matrix. */ +expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A); -expr *new_right_matmul_dense(expr *u, int m, int n, const double *data); +expr *new_right_matmul_dense(expr *param_node, expr *u, int m, int n, + const double *data); -/* Constant scalar multiplication: a * f(x) where a is a constant - * double */ -expr *new_const_scalar_mult(double a, expr *child); +/* Scalar multiplication: a * f(x) where a comes from param_node */ +expr *new_scalar_mult(expr *param_node, expr *child); -/* Constant vector elementwise multiplication: a . f(x) where a is - * constant */ -expr *new_const_vector_mult(const double *a, expr *child); +/* Vector elementwise multiplication: a . f(x) where a comes from + * param_node */ +expr *new_vector_mult(expr *param_node, expr *child); #endif /* AFFINE_H */ diff --git a/include/problem.h b/include/problem.h index 9d7a446..3dc448e 100644 --- a/include/problem.h +++ b/include/problem.h @@ -46,6 +46,11 @@ typedef struct problem int n_vars; int total_constraint_size; + /* parameter support */ + expr **param_nodes; + int n_param_nodes; + int total_parameter_size; + /* allocated by new_problem */ double *constraint_values; double *gradient_values; @@ -76,6 +81,9 @@ void problem_init_jacobian_coo(problem *prob); void problem_init_hessian_coo_lower_triangular(problem *prob); void free_problem(problem *prob); +void problem_register_params(problem *prob, expr **param_nodes, int n_param_nodes); +void problem_update_params(problem *prob, const double *theta); + double problem_objective_forward(problem *prob, const double *u); void problem_constraint_forward(problem *prob, const double *u); void problem_gradient(problem *prob); diff --git a/include/subexpr.h b/include/subexpr.h index b4b427c..7e5504b 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -26,8 +26,20 @@ /* Forward declaration */ struct int_double_pair; +/* Parameter ID for fixed constants (not updatable) */ +#define PARAM_FIXED -1 + /* Type-specific expression structures that "inherit" from expr */ +/* Unified constant/parameter node. Constants use param_id == PARAM_FIXED. + * Updatable parameters use param_id >= 0 (offset into global theta). */ +typedef struct parameter_expr +{ + expr base; + int param_id; + bool has_been_refreshed; +} parameter_expr; + /* Linear operator: y = A * x + b */ typedef struct linear_op_expr { @@ -114,32 +126,24 @@ typedef struct left_matmul_expr CSC_Matrix *Jchild_CSC; CSC_Matrix *J_CSC; int *csc_to_csr_work; + expr *param_source; + void (*refresh_param_values)(struct left_matmul_expr *); } left_matmul_expr; -/* Right matrix multiplication: y = f(x) * A where f(x) is an expression. - * f(x) has shape p x n, A has shape n x q, output y has shape p x q. - * Uses vec(y) = B * vec(f(x)) where B = A^T kron I_p. */ -typedef struct right_matmul_expr -{ - expr base; - CSR_Matrix *B; /* B = A^T kron I_p */ - CSR_Matrix *BT; /* B^T for backpropagating Hessian weights */ - CSC_Matrix *CSC_work; -} right_matmul_expr; - -/* Constant scalar multiplication: y = a * child where a is a constant double */ -typedef struct const_scalar_mult_expr +/* Scalar multiplication: y = a * child where a comes from param_source */ +typedef struct scalar_mult_expr { expr base; - double a; -} const_scalar_mult_expr; + expr *param_source; +} scalar_mult_expr; -/* Constant vector elementwise multiplication: y = a \circ child for constant a */ -typedef struct const_vector_mult_expr +/* Vector elementwise multiplication: y = a \circ child where a comes from + * param_source */ +typedef struct vector_mult_expr { expr base; - double *a; /* length equals node->size */ -} const_vector_mult_expr; + expr *param_source; +} vector_mult_expr; /* Index/slicing: y = child[indices] where indices is a list of flat positions */ typedef struct index_expr diff --git a/include/utils/matrix.h b/include/utils/matrix.h index fe7db5f..49c55a5 100644 --- a/include/utils/matrix.h +++ b/include/utils/matrix.h @@ -31,6 +31,7 @@ typedef struct Matrix const CSC_Matrix *J, int p); void (*block_left_mult_values)(const struct Matrix *self, const CSC_Matrix *J, CSC_Matrix *C); + void (*update_values)(struct Matrix *self, const double *new_values); void (*free_fn)(struct Matrix *self); } Matrix; diff --git a/src/affine/left_matmul.c b/src/affine/left_matmul.c index be62f28..be87553 100644 --- a/src/affine/left_matmul.c +++ b/src/affine/left_matmul.c @@ -48,16 +48,34 @@ #include "utils/utils.h" +static void refresh_param_values(left_matmul_expr *lnode) +{ + if (lnode->param_source == NULL) + { + return; + } + parameter_expr *param = (parameter_expr *) lnode->param_source; + if (param->has_been_refreshed) + { + return; + } + param->has_been_refreshed = true; + lnode->refresh_param_values(lnode); +} + static void forward(expr *node, const double *u) { + left_matmul_expr *lnode = (left_matmul_expr *) node; + refresh_param_values(lnode); + expr *x = node->left; /* child's forward pass */ node->left->forward(node->left, u); /* y = A_kron @ vec(f(x)) */ - Matrix *A = ((left_matmul_expr *) node)->A; - int n_blocks = ((left_matmul_expr *) node)->n_blocks; + Matrix *A = lnode->A; + int n_blocks = lnode->n_blocks; A->block_left_mult_vec(A, x->value, node->value, n_blocks); } @@ -74,11 +92,16 @@ static void free_type_data(expr *node) free_csc_matrix(lnode->Jchild_CSC); free_csc_matrix(lnode->J_CSC); free(lnode->csc_to_csr_work); + if (lnode->param_source != NULL) + { + free_expr(lnode->param_source); + } lnode->A = NULL; lnode->AT = NULL; lnode->Jchild_CSC = NULL; lnode->J_CSC = NULL; lnode->csc_to_csr_work = NULL; + lnode->param_source = NULL; } static void jacobian_init_impl(expr *node) @@ -98,8 +121,8 @@ static void jacobian_init_impl(expr *node) static void eval_jacobian(expr *node) { - expr *x = node->left; left_matmul_expr *lnode = (left_matmul_expr *) node; + expr *x = node->left; CSC_Matrix *Jchild_CSC = lnode->Jchild_CSC; CSC_Matrix *J_CSC = lnode->J_CSC; @@ -130,9 +153,11 @@ static void wsum_hess_init_impl(expr *node) static void eval_wsum_hess(expr *node, const double *w) { + left_matmul_expr *lnode = (left_matmul_expr *) node; + /* compute A^T w*/ - Matrix *AT = ((left_matmul_expr *) node)->AT; - int n_blocks = ((left_matmul_expr *) node)->n_blocks; + Matrix *AT = lnode->AT; + int n_blocks = lnode->n_blocks; AT->block_left_mult_vec(AT, w, node->work->dwork, n_blocks); node->left->eval_wsum_hess(node->left, node->work->dwork); @@ -140,7 +165,33 @@ static void eval_wsum_hess(expr *node, const double *w) node->wsum_hess->nnz * sizeof(double)); } -expr *new_left_matmul(expr *u, const CSR_Matrix *A) +static void refresh_sparse_left(left_matmul_expr *lnode) +{ + Sparse_Matrix *sm_A = (Sparse_Matrix *) lnode->A; + Sparse_Matrix *sm_AT = (Sparse_Matrix *) lnode->AT; + lnode->A->update_values(lnode->A, lnode->param_source->value); + /* Recompute AT values from A */ + AT_fill_values(sm_A->csr, sm_AT->csr, lnode->base.work->iwork); +} + +static void refresh_dense_left(left_matmul_expr *lnode) +{ + Dense_Matrix *dm_A = (Dense_Matrix *) lnode->A; + int m = dm_A->base.m; + int n = dm_A->base.n; + lnode->A->update_values(lnode->A, lnode->param_source->value); + /* Recompute AT data (transpose of row-major A) */ + Dense_Matrix *dm_AT = (Dense_Matrix *) lnode->AT; + for (int i = 0; i < m; i++) + { + for (int j = 0; j < n; j++) + { + dm_AT->x[j * m + i] = dm_A->x[i * n + j]; + } + } +} + +expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A) { /* We expect u->d1 == A->n. However, numpy's broadcasting rules allow users to do A @ u where u is (n, ) which in C is actually (1, n). In that case @@ -187,10 +238,19 @@ expr *new_left_matmul(expr *u, const CSR_Matrix *A) lnode->AT = sparse_matrix_trans((const Sparse_Matrix *) lnode->A, node->work->iwork); + /* parameter support */ + lnode->param_source = param_node; + if (param_node != NULL) + { + expr_retain(param_node); + lnode->refresh_param_values = refresh_sparse_left; + } + return node; } -expr *new_left_matmul_dense(expr *u, int m, int n, const double *data) +expr *new_left_matmul_dense(expr *param_node, expr *u, int m, int n, + const double *data) { int d1, d2, n_blocks; if (u->d1 == n) @@ -226,5 +286,13 @@ expr *new_left_matmul_dense(expr *u, int m, int n, const double *data) lnode->A = new_dense_matrix(m, n, data); lnode->AT = dense_matrix_trans((const Dense_Matrix *) lnode->A); + /* parameter support */ + lnode->param_source = param_node; + if (param_node != NULL) + { + expr_retain(param_node); + lnode->refresh_param_values = refresh_dense_left; + } + return node; } diff --git a/src/affine/constant.c b/src/affine/parameter.c similarity index 71% rename from src/affine/constant.c rename to src/affine/parameter.c index b2764f3..7b7f8d9 100644 --- a/src/affine/constant.c +++ b/src/affine/parameter.c @@ -16,38 +16,36 @@ * limitations under the License. */ #include "affine.h" +#include "subexpr.h" #include #include static void forward(expr *node, const double *u) { - /* Constants don't depend on u; values are already set */ + /* Parameters/constants don't depend on u; values are already set */ (void) node; (void) u; } static void jacobian_init_impl(expr *node) { - /* Constant jacobian is all zeros: size x n_vars with 0 nonzeros. - * new_csr_matrix uses calloc for row pointers, so they're already 0. */ + /* Zero jacobian: size x n_vars with 0 nonzeros. */ node->jacobian = new_csr_matrix(node->size, node->n_vars, 0); } static void eval_jacobian(expr *node) { - /* Constant jacobian never changes - nothing to evaluate */ (void) node; } static void wsum_hess_init_impl(expr *node) { - /* Constant Hessian is all zeros: n_vars x n_vars with 0 nonzeros. */ + /* Zero Hessian: n_vars x n_vars with 0 nonzeros. */ node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 0); } static void eval_wsum_hess(expr *node, const double *w) { - /* Constant Hessian is always zero - nothing to compute */ (void) node; (void) w; } @@ -58,12 +56,20 @@ static bool is_affine(const expr *node) return true; } -expr *new_constant(int d1, int d2, int n_vars, const double *values) +expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values) { - expr *node = (expr *) calloc(1, sizeof(expr)); + parameter_expr *pnode = (parameter_expr *) calloc(1, sizeof(parameter_expr)); + expr *node = &pnode->base; init_expr(node, d1, d2, n_vars, forward, jacobian_init_impl, eval_jacobian, is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL); - memcpy(node->value, values, node->size * sizeof(double)); + + pnode->param_id = param_id; + pnode->has_been_refreshed = false; + + if (values != NULL) + { + memcpy(node->value, values, node->size * sizeof(double)); + } return node; } diff --git a/src/affine/right_matmul.c b/src/affine/right_matmul.c index f3740a9..73c050b 100644 --- a/src/affine/right_matmul.c +++ b/src/affine/right_matmul.c @@ -16,15 +16,53 @@ * limitations under the License. */ #include "affine.h" - +#include "subexpr.h" #include "utils/CSR_Matrix.h" #include /* This file implements the atom 'right_matmul' corresponding to the operation y = f(x) @ A, where A is a given matrix and f(x) is an arbitrary expression. We implement this by expressing right matmul in terms of left matmul and - transpose: f(x) @ A = (A^T @ f(x)^T)^T. */ -expr *new_right_matmul(expr *u, const CSR_Matrix *A) + transpose: f(x) @ A = (A^T @ f(x)^T)^T. + + For the parameter case: + - param_source stores A values in CSR data order + - inner left_matmul stores AT as its A-matrix and A as its AT-matrix + - on refresh: update AT (inner's AT, the original A) from param_source, + then recompute A^T (inner's A) from the updated A. */ + +/* Refresh for sparse right_matmul: param stores A in CSR data order. + Inner left_matmul: lnode->A = AT (transposed), lnode->AT = A (original). + So: update lnode->AT from param values, then recompute lnode->A. */ +static void refresh_sparse_right(left_matmul_expr *lnode) +{ + Sparse_Matrix *sm_AT_inner = (Sparse_Matrix *) lnode->A; + Sparse_Matrix *sm_A_inner = (Sparse_Matrix *) lnode->AT; + /* lnode->AT holds the original A; update its values from param */ + lnode->AT->update_values(lnode->AT, lnode->param_source->value); + /* Recompute A^T (lnode->A) from A (lnode->AT) */ + AT_fill_values(sm_A_inner->csr, sm_AT_inner->csr, lnode->base.work->iwork); +} + +static void refresh_dense_right(left_matmul_expr *lnode) +{ + Dense_Matrix *dm_AT_inner = (Dense_Matrix *) lnode->A; + Dense_Matrix *dm_A_inner = (Dense_Matrix *) lnode->AT; + int m_orig = dm_A_inner->base.m; /* original A is m x n */ + int n_orig = dm_A_inner->base.n; + /* Update original A (inner's AT) from param values */ + lnode->AT->update_values(lnode->AT, lnode->param_source->value); + /* Recompute A^T (inner's A) from A */ + for (int i = 0; i < m_orig; i++) + { + for (int j = 0; j < n_orig; j++) + { + dm_AT_inner->x[j * m_orig + i] = dm_A_inner->x[i * n_orig + j]; + } + } +} + +expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A) { /* We can express right matmul using left matmul and transpose: u @ A = (A^T @ u^T)^T. */ @@ -32,7 +70,18 @@ expr *new_right_matmul(expr *u, const CSR_Matrix *A) CSR_Matrix *AT = transpose(A, work_transpose); expr *u_transpose = new_transpose(u); - expr *left_matmul = new_left_matmul(u_transpose, AT); + expr *left_matmul = new_left_matmul(NULL, u_transpose, AT); + + /* If parameterized, attach param_source and custom refresh to inner + left_matmul */ + if (param_node != NULL) + { + left_matmul_expr *lnode = (left_matmul_expr *) left_matmul; + lnode->param_source = param_node; + expr_retain(param_node); + lnode->refresh_param_values = refresh_sparse_right; + } + expr *node = new_transpose(left_matmul); free_csr_matrix(AT); @@ -40,7 +89,8 @@ expr *new_right_matmul(expr *u, const CSR_Matrix *A) return node; } -expr *new_right_matmul_dense(expr *u, int m, int n, const double *data) +expr *new_right_matmul_dense(expr *param_node, expr *u, int m, int n, + const double *data) { /* We express: u @ A = (A^T @ u^T)^T A is m x n, so A^T is n x m. */ @@ -54,7 +104,17 @@ expr *new_right_matmul_dense(expr *u, int m, int n, const double *data) } expr *u_transpose = new_transpose(u); - expr *left_matmul_node = new_left_matmul_dense(u_transpose, n, m, AT); + expr *left_matmul_node = new_left_matmul_dense(NULL, u_transpose, n, m, AT); + + /* If parameterized, attach param_source and custom refresh */ + if (param_node != NULL) + { + left_matmul_expr *lnode = (left_matmul_expr *) left_matmul_node; + lnode->param_source = param_node; + expr_retain(param_node); + lnode->refresh_param_values = refresh_dense_right; + } + expr *node = new_transpose(left_matmul_node); free(AT); diff --git a/src/affine/const_scalar_mult.c b/src/affine/scalar_mult.c similarity index 75% rename from src/affine/const_scalar_mult.c rename to src/affine/scalar_mult.c index 0c81aff..8b9c93e 100644 --- a/src/affine/const_scalar_mult.c +++ b/src/affine/scalar_mult.c @@ -15,24 +15,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "affine.h" +#include "bivariate.h" #include "subexpr.h" #include #include #include #include -/* Constant scalar multiplication: y = a * child where a is a constant double */ +/* Scalar multiplication: y = a * child where a comes from param_source */ static void forward(expr *node, const double *u) { expr *child = node->left; + double a = ((scalar_mult_expr *) node)->param_source->value[0]; /* child's forward pass */ child->forward(child, u); /* local forward pass: multiply each element by scalar a */ - double a = ((const_scalar_mult_expr *) node)->a; for (int i = 0; i < node->size; i++) { node->value[i] = a * child->value[i]; @@ -53,7 +53,7 @@ static void jacobian_init_impl(expr *node) static void eval_jacobian(expr *node) { expr *child = node->left; - double a = ((const_scalar_mult_expr *) node)->a; + double a = ((scalar_mult_expr *) node)->param_source->value[0]; /* evaluate child */ child->eval_jacobian(child); @@ -81,7 +81,7 @@ static void eval_wsum_hess(expr *node, const double *w) expr *x = node->left; x->eval_wsum_hess(x, w); - double a = ((const_scalar_mult_expr *) node)->a; + double a = ((scalar_mult_expr *) node)->param_source->value[0]; for (int j = 0; j < x->wsum_hess->nnz; j++) { node->wsum_hess->x[j] = a * x->wsum_hess->x[j]; @@ -90,21 +90,33 @@ static void eval_wsum_hess(expr *node, const double *w) static bool is_affine(const expr *node) { - /* Affine iff the child is affine */ return node->left->is_affine(node->left); } -expr *new_const_scalar_mult(double a, expr *child) +static void free_type_data(expr *node) { - const_scalar_mult_expr *mult_node = - (const_scalar_mult_expr *) calloc(1, sizeof(const_scalar_mult_expr)); + scalar_mult_expr *snode = (scalar_mult_expr *) node; + if (snode->param_source != NULL) + { + free_expr(snode->param_source); + snode->param_source = NULL; + } +} + +expr *new_scalar_mult(expr *param_node, expr *child) +{ + scalar_mult_expr *mult_node = + (scalar_mult_expr *) calloc(1, sizeof(scalar_mult_expr)); expr *node = &mult_node->base; init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init_impl, - eval_jacobian, is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL); + eval_jacobian, is_affine, wsum_hess_init_impl, eval_wsum_hess, + free_type_data); node->left = child; - mult_node->a = a; expr_retain(child); + mult_node->param_source = param_node; + expr_retain(param_node); + return node; } diff --git a/src/affine/const_vector_mult.c b/src/affine/vector_mult.c similarity index 79% rename from src/affine/const_vector_mult.c rename to src/affine/vector_mult.c index ad7c81e..4b1ec08 100644 --- a/src/affine/const_vector_mult.c +++ b/src/affine/vector_mult.c @@ -15,18 +15,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "affine.h" +#include "bivariate.h" #include "subexpr.h" #include #include #include -/* Constant vector elementwise multiplication: y = a \circ child */ +/* Vector elementwise multiplication: y = a \circ child + * where a comes from param_source */ static void forward(expr *node, const double *u) { expr *child = node->left; - const double *a = ((const_vector_mult_expr *) node)->a; + const double *a = ((vector_mult_expr *) node)->param_source->value; /* child's forward pass */ child->forward(child, u); @@ -52,7 +53,7 @@ static void jacobian_init_impl(expr *node) static void eval_jacobian(expr *node) { expr *x = node->left; - const double *a = ((const_vector_mult_expr *) node)->a; + const double *a = ((vector_mult_expr *) node)->param_source->value; /* evaluate x */ x->eval_jacobian(x); @@ -83,7 +84,7 @@ static void wsum_hess_init_impl(expr *node) static void eval_wsum_hess(expr *node, const double *w) { expr *x = node->left; - const double *a = ((const_vector_mult_expr *) node)->a; + const double *a = ((vector_mult_expr *) node)->param_source->value; /* scale weights w by a */ for (int i = 0; i < node->size; i++) @@ -99,20 +100,23 @@ static void eval_wsum_hess(expr *node, const double *w) static void free_type_data(expr *node) { - const_vector_mult_expr *vnode = (const_vector_mult_expr *) node; - free(vnode->a); + vector_mult_expr *vnode = (vector_mult_expr *) node; + if (vnode->param_source != NULL) + { + free_expr(vnode->param_source); + vnode->param_source = NULL; + } } static bool is_affine(const expr *node) { - /* Affine iff the child is affine */ return node->left->is_affine(node->left); } -expr *new_const_vector_mult(const double *a, expr *child) +expr *new_vector_mult(expr *param_node, expr *child) { - const_vector_mult_expr *vnode = - (const_vector_mult_expr *) calloc(1, sizeof(const_vector_mult_expr)); + vector_mult_expr *vnode = + (vector_mult_expr *) calloc(1, sizeof(vector_mult_expr)); expr *node = &vnode->base; init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init_impl, @@ -121,9 +125,8 @@ expr *new_const_vector_mult(const double *a, expr *child) node->left = child; expr_retain(child); - /* copy a vector */ - vnode->a = (double *) malloc(child->size * sizeof(double)); - memcpy(vnode->a, a, child->size * sizeof(double)); + vnode->param_source = param_node; + expr_retain(param_node); return node; } diff --git a/src/problem.c b/src/problem.c index 406f734..f663193 100644 --- a/src/problem.c +++ b/src/problem.c @@ -16,6 +16,7 @@ * limitations under the License. */ #include "problem.h" +#include "subexpr.h" #include "utils/CSR_sum.h" #include "utils/utils.h" #include @@ -302,6 +303,9 @@ void free_problem(problem *prob) { if (prob == NULL) return; + /* Free param_nodes array (weak refs, don't free the nodes) */ + free(prob->param_nodes); + /* Free allocated arrays */ free(prob->constraint_values); free(prob->gradient_values); @@ -327,6 +331,35 @@ void free_problem(problem *prob) free(prob); } +void problem_register_params(problem *prob, expr **param_nodes, int n_param_nodes) +{ + prob->n_param_nodes = n_param_nodes; + prob->param_nodes = (expr **) malloc(n_param_nodes * sizeof(expr *)); + memcpy(prob->param_nodes, param_nodes, n_param_nodes * sizeof(expr *)); + + prob->total_parameter_size = 0; + for (int i = 0; i < n_param_nodes; i++) + { + prob->total_parameter_size += param_nodes[i]->size; + } +} + +void problem_update_params(problem *prob, const double *theta) +{ + for (int i = 0; i < prob->n_param_nodes; i++) + { + expr *pnode = prob->param_nodes[i]; + parameter_expr *param = (parameter_expr *) pnode; + if (param->param_id == PARAM_FIXED) continue; + int offset = param->param_id; + memcpy(pnode->value, theta + offset, pnode->size * sizeof(double)); + param->has_been_refreshed = false; + } + + /* Force re-evaluation of affine Jacobians on next call */ + prob->jacobian_called = false; +} + double problem_objective_forward(problem *prob, const double *u) { Timer timer; diff --git a/src/utils/dense_matrix.c b/src/utils/dense_matrix.c index 63f3442..05ba1a8 100644 --- a/src/utils/dense_matrix.c +++ b/src/utils/dense_matrix.c @@ -171,6 +171,12 @@ static void dense_block_left_mult_values(const Matrix *A, const CSC_Matrix *J, } } +static void dense_update_values(Matrix *self, const double *new_values) +{ + Dense_Matrix *dm = (Dense_Matrix *) self; + memcpy(dm->x, new_values, dm->base.m * dm->base.n * sizeof(double)); +} + static void dense_free(Matrix *A) { Dense_Matrix *dm = (Dense_Matrix *) A; @@ -187,6 +193,7 @@ Matrix *new_dense_matrix(int m, int n, const double *data) dm->base.block_left_mult_vec = dense_block_left_mult_vec; dm->base.block_left_mult_sparsity = dense_block_left_mult_sparsity; dm->base.block_left_mult_values = dense_block_left_mult_values; + dm->base.update_values = dense_update_values; dm->base.free_fn = dense_free; dm->x = (double *) malloc(m * n * sizeof(double)); memcpy(dm->x, data, m * n * sizeof(double)); diff --git a/src/utils/sparse_matrix.c b/src/utils/sparse_matrix.c index 24ed539..076890a 100644 --- a/src/utils/sparse_matrix.c +++ b/src/utils/sparse_matrix.c @@ -18,6 +18,7 @@ #include "utils/linalg_sparse_matmuls.h" #include "utils/matrix.h" #include +#include static void sparse_block_left_mult_vec(const Matrix *self, const double *x, double *y, int p) @@ -40,6 +41,12 @@ static void sparse_block_left_mult_values(const Matrix *self, const CSC_Matrix * block_left_multiply_fill_values(sm->csr, J, C); } +static void sparse_update_values(Matrix *self, const double *new_values) +{ + Sparse_Matrix *sm = (Sparse_Matrix *) self; + memcpy(sm->csr->x, new_values, sm->csr->nnz * sizeof(double)); +} + static void sparse_free(Matrix *self) { Sparse_Matrix *sm = (Sparse_Matrix *) self; @@ -55,6 +62,7 @@ Matrix *new_sparse_matrix(const CSR_Matrix *A) sm->base.block_left_mult_vec = sparse_block_left_mult_vec; sm->base.block_left_mult_sparsity = sparse_block_left_mult_sparsity; sm->base.block_left_mult_values = sparse_block_left_mult_values; + sm->base.update_values = sparse_update_values; sm->base.free_fn = sparse_free; sm->csr = new_csr(A); return &sm->base; @@ -69,6 +77,7 @@ Matrix *sparse_matrix_trans(const Sparse_Matrix *self, int *iwork) sm->base.block_left_mult_vec = sparse_block_left_mult_vec; sm->base.block_left_mult_sparsity = sparse_block_left_mult_sparsity; sm->base.block_left_mult_values = sparse_block_left_mult_values; + sm->base.update_values = sparse_update_values; sm->base.free_fn = sparse_free; sm->csr = AT; return &sm->base; diff --git a/tests/all_tests.c b/tests/all_tests.c index d8ba5b4..eed08ac 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -12,7 +12,7 @@ #include "forward_pass/affine/test_neg.h" #include "forward_pass/affine/test_promote.h" #include "forward_pass/affine/test_sum.h" -#include "forward_pass/affine/test_variable_constant.h" +#include "forward_pass/affine/test_variable_parameter.h" #include "forward_pass/affine/test_vstack.h" #include "forward_pass/bivariate_full_dom/test_matmul.h" #include "forward_pass/composite/test_composite.h" @@ -22,17 +22,17 @@ #include "forward_pass/other/test_prod_axis_one.h" #include "forward_pass/other/test_prod_axis_zero.h" #include "jacobian_tests/affine/test_broadcast.h" -#include "jacobian_tests/affine/test_const_scalar_mult.h" -#include "jacobian_tests/affine/test_const_vector_mult.h" #include "jacobian_tests/affine/test_hstack.h" #include "jacobian_tests/affine/test_index.h" #include "jacobian_tests/affine/test_left_matmul.h" #include "jacobian_tests/affine/test_neg.h" #include "jacobian_tests/affine/test_promote.h" #include "jacobian_tests/affine/test_right_matmul.h" +#include "jacobian_tests/affine/test_scalar_mult.h" #include "jacobian_tests/affine/test_sum.h" #include "jacobian_tests/affine/test_trace.h" #include "jacobian_tests/affine/test_transpose.h" +#include "jacobian_tests/affine/test_vector_mult.h" #include "jacobian_tests/affine/test_vstack.h" #include "jacobian_tests/bivariate_full_dom/test_elementwise_mult.h" #include "jacobian_tests/bivariate_full_dom/test_matmul.h" @@ -48,6 +48,7 @@ #include "jacobian_tests/other/test_prod_axis_zero.h" #include "jacobian_tests/other/test_quad_form.h" #include "numerical_diff/test_numerical_diff.h" +#include "problem/test_param_prob.h" #include "problem/test_problem.h" #include "utils/test_cblas.h" #include "utils/test_coo_matrix.h" @@ -57,15 +58,15 @@ #include "utils/test_linalg_sparse_matmuls.h" #include "utils/test_matrix.h" #include "wsum_hess/affine/test_broadcast.h" -#include "wsum_hess/affine/test_const_scalar_mult.h" -#include "wsum_hess/affine/test_const_vector_mult.h" #include "wsum_hess/affine/test_hstack.h" #include "wsum_hess/affine/test_index.h" #include "wsum_hess/affine/test_left_matmul.h" #include "wsum_hess/affine/test_right_matmul.h" +#include "wsum_hess/affine/test_scalar_mult.h" #include "wsum_hess/affine/test_sum.h" #include "wsum_hess/affine/test_trace.h" #include "wsum_hess/affine/test_transpose.h" +#include "wsum_hess/affine/test_vector_mult.h" #include "wsum_hess/affine/test_vstack.h" #include "wsum_hess/bivariate_full_dom/test_matmul.h" #include "wsum_hess/bivariate_full_dom/test_multiply.h" @@ -138,10 +139,10 @@ int main(void) mu_run_test(test_jacobian_Ax_Bx_multiply, tests_run); mu_run_test(test_jacobian_AX_BX_multiply, tests_run); mu_run_test(test_jacobian_composite_exp_add, tests_run); - mu_run_test(test_jacobian_const_scalar_mult_log_vector, tests_run); - mu_run_test(test_jacobian_const_scalar_mult_log_matrix, tests_run); - mu_run_test(test_jacobian_const_vector_mult_log_vector, tests_run); - mu_run_test(test_jacobian_const_vector_mult_log_matrix, tests_run); + mu_run_test(test_jacobian_scalar_mult_log_vector, tests_run); + mu_run_test(test_jacobian_scalar_mult_log_matrix, tests_run); + mu_run_test(test_jacobian_vector_mult_log_vector, tests_run); + mu_run_test(test_jacobian_vector_mult_log_matrix, tests_run); mu_run_test(test_jacobian_rel_entr_vector_args_1, tests_run); mu_run_test(test_jacobian_rel_entr_vector_args_2, tests_run); mu_run_test(test_jacobian_rel_entr_matrix_args, tests_run); @@ -242,10 +243,10 @@ int main(void) mu_run_test(test_wsum_hess_quad_over_lin_xy, tests_run); mu_run_test(test_wsum_hess_quad_over_lin_yx, tests_run); mu_run_test(test_wsum_hess_quad_form, tests_run); - mu_run_test(test_wsum_hess_const_scalar_mult_log_vector, tests_run); - mu_run_test(test_wsum_hess_const_scalar_mult_log_matrix, tests_run); - mu_run_test(test_wsum_hess_const_vector_mult_log_vector, tests_run); - mu_run_test(test_wsum_hess_const_vector_mult_log_matrix, tests_run); + mu_run_test(test_wsum_hess_scalar_mult_log_vector, tests_run); + mu_run_test(test_wsum_hess_scalar_mult_log_matrix, tests_run); + mu_run_test(test_wsum_hess_vector_mult_log_vector, tests_run); + mu_run_test(test_wsum_hess_vector_mult_log_matrix, tests_run); mu_run_test(test_wsum_hess_multiply_linear_ops, tests_run); mu_run_test(test_wsum_hess_multiply_sparse_random, tests_run); mu_run_test(test_wsum_hess_multiply_1, tests_run); @@ -324,6 +325,13 @@ int main(void) mu_run_test(test_problem_jacobian_multi, tests_run); mu_run_test(test_problem_constraint_forward, tests_run); mu_run_test(test_problem_hessian, tests_run); + + printf("\n--- Parameter Tests ---\n"); + mu_run_test(test_param_scalar_mult_problem, tests_run); + mu_run_test(test_param_vector_mult_problem, tests_run); + mu_run_test(test_param_left_matmul_problem, tests_run); + mu_run_test(test_param_right_matmul_problem, tests_run); + mu_run_test(test_param_fixed_skip_in_update, tests_run); #endif /* PROFILE_ONLY */ #ifdef PROFILE_ONLY diff --git a/tests/forward_pass/affine/test_add.h b/tests/forward_pass/affine/test_add.h index 11fb35c..12f8cb9 100644 --- a/tests/forward_pass/affine/test_add.h +++ b/tests/forward_pass/affine/test_add.h @@ -12,7 +12,7 @@ const char *test_addition(void) double u[2] = {3.0, 4.0}; double c[2] = {1.0, 2.0}; expr *var = new_variable(2, 1, 0, 2); - expr *const_node = new_constant(2, 1, 0, c); + expr *const_node = new_parameter(2, 1, PARAM_FIXED, 0, c); expr *sum = new_add(var, const_node); sum->forward(sum, u); double expected[2] = {4.0, 6.0}; diff --git a/tests/forward_pass/affine/test_left_matmul_dense.h b/tests/forward_pass/affine/test_left_matmul_dense.h index 5cd9c75..5b1f91d 100644 --- a/tests/forward_pass/affine/test_left_matmul_dense.h +++ b/tests/forward_pass/affine/test_left_matmul_dense.h @@ -24,7 +24,7 @@ const char *test_left_matmul_dense(void) double A_data[9] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; /* Build expression Z = A @ X */ - expr *Z = new_left_matmul_dense(X, 3, 3, A_data); + expr *Z = new_left_matmul_dense(NULL, X, 3, 3, A_data); /* Variable values in column-major order */ double u[9] = {1.0, 2.0, 3.0, /* first column */ diff --git a/tests/forward_pass/affine/test_sum.h b/tests/forward_pass/affine/test_sum.h index bf95c84..a4a0780 100644 --- a/tests/forward_pass/affine/test_sum.h +++ b/tests/forward_pass/affine/test_sum.h @@ -18,7 +18,7 @@ const char *test_sum_axis_neg1(void) Stored as: [1, 2, 3, 4, 5, 6] */ double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - expr *const_node = new_constant(3, 2, 0, values); + expr *const_node = new_parameter(3, 2, PARAM_FIXED, 0, values); expr *log_node = new_log(const_node); expr *sum_node = new_sum(log_node, -1); sum_node->forward(sum_node, NULL); @@ -43,7 +43,7 @@ const char *test_sum_axis_0(void) Stored as: [1, 2, 3, 4, 5, 6] */ double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - expr *const_node = new_constant(3, 2, 0, values); + expr *const_node = new_parameter(3, 2, PARAM_FIXED, 0, values); expr *log_node = new_log(const_node); expr *sum_node = new_sum(log_node, 0); sum_node->forward(sum_node, NULL); @@ -70,7 +70,7 @@ const char *test_sum_axis_1(void) Stored as: [1, 2, 3, 4, 5, 6] */ double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - expr *const_node = new_constant(3, 2, 0, values); + expr *const_node = new_parameter(3, 2, PARAM_FIXED, 0, values); expr *log_node = new_log(const_node); expr *sum_node = new_sum(log_node, 1); sum_node->forward(sum_node, NULL); diff --git a/tests/forward_pass/affine/test_variable_constant.h b/tests/forward_pass/affine/test_variable_parameter.h similarity index 91% rename from tests/forward_pass/affine/test_variable_constant.h rename to tests/forward_pass/affine/test_variable_parameter.h index 9df0f1a..a2ed535 100644 --- a/tests/forward_pass/affine/test_variable_constant.h +++ b/tests/forward_pass/affine/test_variable_parameter.h @@ -21,7 +21,7 @@ const char *test_constant(void) { double c[2] = {5.0, 10.0}; double u[2] = {0.0, 0.0}; - expr *const_node = new_constant(2, 1, 0, c); + expr *const_node = new_parameter(2, 1, PARAM_FIXED, 0, c); const_node->forward(const_node, u); mu_assert("Constant test failed", cmp_double_array(const_node->value, c, 2)); free_expr(const_node); diff --git a/tests/forward_pass/composite/test_composite.h b/tests/forward_pass/composite/test_composite.h index f45a190..87ba10f 100644 --- a/tests/forward_pass/composite/test_composite.h +++ b/tests/forward_pass/composite/test_composite.h @@ -17,7 +17,7 @@ const char *test_composite(void) /* Build tree: log(exp(x) + c) */ expr *var = new_variable(2, 1, 0, 2); expr *exp_node = new_exp(var); - expr *const_node = new_constant(2, 1, 0, c); + expr *const_node = new_parameter(2, 1, PARAM_FIXED, 0, c); expr *sum = new_add(exp_node, const_node); expr *log_node = new_log(sum); diff --git a/tests/forward_pass/other/test_prod_axis_one.h b/tests/forward_pass/other/test_prod_axis_one.h index 7cf74e0..2d74b0a 100644 --- a/tests/forward_pass/other/test_prod_axis_one.h +++ b/tests/forward_pass/other/test_prod_axis_one.h @@ -16,7 +16,7 @@ const char *test_forward_prod_axis_one(void) Stored as: [1, 2, 3, 4, 5, 6] */ double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - expr *const_node = new_constant(2, 3, 0, values); + expr *const_node = new_parameter(2, 3, PARAM_FIXED, 0, values); expr *prod_node = new_prod_axis_one(const_node); prod_node->forward(prod_node, NULL); diff --git a/tests/forward_pass/other/test_prod_axis_zero.h b/tests/forward_pass/other/test_prod_axis_zero.h index f87782a..aec2cb7 100644 --- a/tests/forward_pass/other/test_prod_axis_zero.h +++ b/tests/forward_pass/other/test_prod_axis_zero.h @@ -16,7 +16,7 @@ const char *test_forward_prod_axis_zero(void) Stored as: [1, 2, 3, 4, 5, 6] */ double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - expr *const_node = new_constant(2, 3, 0, values); + expr *const_node = new_parameter(2, 3, PARAM_FIXED, 0, values); expr *prod_node = new_prod_axis_zero(const_node); prod_node->forward(prod_node, NULL); diff --git a/tests/jacobian_tests/affine/test_broadcast.h b/tests/jacobian_tests/affine/test_broadcast.h index a66e761..c850fb3 100644 --- a/tests/jacobian_tests/affine/test_broadcast.h +++ b/tests/jacobian_tests/affine/test_broadcast.h @@ -141,7 +141,7 @@ const char *test_double_broadcast(void) /* form the expression x + b */ expr *x = new_variable(5, 1, 0, 5); - expr *b = new_constant(1, 5, 5, b_vals); + expr *b = new_parameter(1, 5, PARAM_FIXED, 5, b_vals); expr *bcast_x = new_broadcast(x, 5, 5); expr *bcast_b = new_broadcast(b, 5, 5); expr *sum = new_add(bcast_x, bcast_b); diff --git a/tests/jacobian_tests/affine/test_left_matmul.h b/tests/jacobian_tests/affine/test_left_matmul.h index 270a2be..9e14dfb 100644 --- a/tests/jacobian_tests/affine/test_left_matmul.h +++ b/tests/jacobian_tests/affine/test_left_matmul.h @@ -42,7 +42,7 @@ const char *test_jacobian_left_matmul_log(void) memcpy(A->x, A_x, 7 * sizeof(double)); expr *log_x = new_log(x); - expr *A_log_x = new_left_matmul(log_x, A); + expr *A_log_x = new_left_matmul(NULL, log_x, A); A_log_x->forward(A_log_x, x_vals); jacobian_init(A_log_x); @@ -86,7 +86,7 @@ const char *test_jacobian_left_matmul_log_matrix(void) memcpy(A->x, A_x, 7 * sizeof(double)); expr *log_x = new_log(x); - expr *A_log_x = new_left_matmul(log_x, A); + expr *A_log_x = new_left_matmul(NULL, log_x, A); A_log_x->forward(A_log_x, x_vals); jacobian_init(A_log_x); @@ -135,7 +135,7 @@ const char *test_jacobian_left_matmul_exp_composite(void) expr *Bx = new_linear(x, B, NULL); expr *exp_Bx = new_exp(Bx); - expr *A_exp_Bx = new_left_matmul(exp_Bx, A); + expr *A_exp_Bx = new_left_matmul(NULL, exp_Bx, A); mu_assert("check_jacobian failed", check_jacobian(A_exp_Bx, x_vals, NUMERICAL_DIFF_DEFAULT_H)); diff --git a/tests/jacobian_tests/affine/test_right_matmul.h b/tests/jacobian_tests/affine/test_right_matmul.h index e3af665..625e494 100644 --- a/tests/jacobian_tests/affine/test_right_matmul.h +++ b/tests/jacobian_tests/affine/test_right_matmul.h @@ -27,7 +27,7 @@ const char *test_jacobian_right_matmul_log(void) memcpy(A->x, A_x, 4 * sizeof(double)); expr *log_x = new_log(x); - expr *log_x_A = new_right_matmul(log_x, A); + expr *log_x_A = new_right_matmul(NULL, log_x, A); log_x_A->forward(log_x_A, x_vals); jacobian_init(log_x_A); @@ -76,7 +76,7 @@ const char *test_jacobian_right_matmul_log_vector(void) memcpy(A->x, A_x, 4 * sizeof(double)); expr *log_x = new_log(x); - expr *log_x_A = new_right_matmul(log_x, A); + expr *log_x_A = new_right_matmul(NULL, log_x, A); log_x_A->forward(log_x_A, x_vals); jacobian_init(log_x_A); diff --git a/tests/jacobian_tests/affine/test_const_scalar_mult.h b/tests/jacobian_tests/affine/test_scalar_mult.h similarity index 81% rename from tests/jacobian_tests/affine/test_const_scalar_mult.h rename to tests/jacobian_tests/affine/test_scalar_mult.h index b459e77..db4670d 100644 --- a/tests/jacobian_tests/affine/test_const_scalar_mult.h +++ b/tests/jacobian_tests/affine/test_scalar_mult.h @@ -1,15 +1,17 @@ #include #include "affine.h" +#include "bivariate.h" #include "elementwise_full_dom.h" #include "elementwise_restricted_dom.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" -/* Test: y = a * log(x) where a is a scalar constant */ +/* Test: y = a * log(x) where a is a scalar parameter */ -const char *test_jacobian_const_scalar_mult_log_vector(void) +const char *test_jacobian_scalar_mult_log_vector(void) { /* Create variable x: [1.0, 2.0, 4.0] with 3 elements */ double u_vals[3] = {1.0, 2.0, 4.0}; @@ -19,8 +21,9 @@ const char *test_jacobian_const_scalar_mult_log_vector(void) expr *log_node = new_log(x); /* Create scalar mult node: y = 2.5 * log(x) */ - double a = 2.5; - expr *y = new_const_scalar_mult(a, log_node); + double a_val = 2.5; + expr *a_param = new_parameter(1, 1, PARAM_FIXED, 3, &a_val); + expr *y = new_scalar_mult(a_param, log_node); /* Forward pass */ y->forward(y, u_vals); @@ -45,7 +48,7 @@ const char *test_jacobian_const_scalar_mult_log_vector(void) return 0; } -const char *test_jacobian_const_scalar_mult_log_matrix(void) +const char *test_jacobian_scalar_mult_log_matrix(void) { /* Create variable x as 2x2 matrix: [[1.0, 2.0], [4.0, 8.0]] */ double u_vals[4] = {1.0, 2.0, 4.0, 8.0}; @@ -55,8 +58,9 @@ const char *test_jacobian_const_scalar_mult_log_matrix(void) expr *log_node = new_log(x); /* Create scalar mult node: y = 3.0 * log(x) */ - double a = 3.0; - expr *y = new_const_scalar_mult(a, log_node); + double a_val = 3.0; + expr *a_param = new_parameter(1, 1, PARAM_FIXED, 4, &a_val); + expr *y = new_scalar_mult(a_param, log_node); /* Forward pass */ y->forward(y, u_vals); diff --git a/tests/jacobian_tests/affine/test_transpose.h b/tests/jacobian_tests/affine/test_transpose.h index 4cf6e7e..7d268b8 100644 --- a/tests/jacobian_tests/affine/test_transpose.h +++ b/tests/jacobian_tests/affine/test_transpose.h @@ -21,7 +21,7 @@ const char *test_jacobian_transpose(void) // X = [1 2; 3 4] (columnwise: x = [1 3 2 4]) expr *X = new_variable(2, 2, 0, 4); - expr *AX = new_left_matmul(X, A); + expr *AX = new_left_matmul(NULL, X, A); expr *transpose_AX = new_transpose(AX); double u[4] = {1, 3, 2, 4}; transpose_AX->forward(transpose_AX, u); diff --git a/tests/jacobian_tests/affine/test_const_vector_mult.h b/tests/jacobian_tests/affine/test_vector_mult.h similarity index 77% rename from tests/jacobian_tests/affine/test_const_vector_mult.h rename to tests/jacobian_tests/affine/test_vector_mult.h index 45ca3fd..87b9a77 100644 --- a/tests/jacobian_tests/affine/test_const_vector_mult.h +++ b/tests/jacobian_tests/affine/test_vector_mult.h @@ -1,15 +1,17 @@ #include #include "affine.h" +#include "bivariate.h" #include "elementwise_full_dom.h" #include "elementwise_restricted_dom.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" -/* Test: y = a ∘ log(x) where a is a constant vector */ +/* Test: y = a ∘ log(x) where a is a parameter vector */ -const char *test_jacobian_const_vector_mult_log_vector(void) +const char *test_jacobian_vector_mult_log_vector(void) { /* Create variable x: [1.0, 2.0, 4.0] */ double u_vals[3] = {1.0, 2.0, 4.0}; @@ -20,7 +22,8 @@ const char *test_jacobian_const_vector_mult_log_vector(void) /* Create vector mult node: y = [2.0, 3.0, 4.0] ∘ log(x) */ double a[3] = {2.0, 3.0, 4.0}; - expr *y = new_const_vector_mult(a, log_node); + expr *a_param = new_parameter(3, 1, PARAM_FIXED, 3, a); + expr *y = new_vector_mult(a_param, log_node); /* Forward pass */ y->forward(y, u_vals); @@ -29,11 +32,6 @@ const char *test_jacobian_const_vector_mult_log_vector(void) jacobian_init(y); y->eval_jacobian(y); - /* Expected jacobian (row-wise scaling): - * Row 0: 2.0 * [1/1] = [2.0] - * Row 1: 3.0 * [1/2] = [1.5] - * Row 2: 4.0 * [1/4] = [1.0] - */ double expected_x[3] = {2.0, 1.5, 1.0}; int expected_p[4] = {0, 1, 2, 3}; int expected_i[3] = {0, 1, 2}; @@ -49,7 +47,7 @@ const char *test_jacobian_const_vector_mult_log_vector(void) return 0; } -const char *test_jacobian_const_vector_mult_log_matrix(void) +const char *test_jacobian_vector_mult_log_matrix(void) { /* Create variable x as 2x2 matrix: [[1.0, 2.0], [4.0, 8.0]] */ double u_vals[4] = {1.0, 2.0, 4.0, 8.0}; @@ -60,7 +58,8 @@ const char *test_jacobian_const_vector_mult_log_matrix(void) /* Create vector mult node: y = [1.5, 2.5, 3.5, 4.5] ∘ log(x) */ double a[4] = {1.5, 2.5, 3.5, 4.5}; - expr *y = new_const_vector_mult(a, log_node); + expr *a_param = new_parameter(4, 1, PARAM_FIXED, 4, a); + expr *y = new_vector_mult(a_param, log_node); /* Forward pass */ y->forward(y, u_vals); @@ -69,12 +68,6 @@ const char *test_jacobian_const_vector_mult_log_matrix(void) jacobian_init(y); y->eval_jacobian(y); - /* Expected jacobian (row-wise scaling): - * Row 0: 1.5 * [1/1] = [1.5] - * Row 1: 2.5 * [1/2] = [1.25] - * Row 2: 3.5 * [1/4] = [0.875] - * Row 3: 4.5 * [1/8] = [0.5625] - */ double expected_x[4] = {1.5, 1.25, 0.875, 0.5625}; int expected_p[5] = {0, 1, 2, 3, 4}; int expected_i[4] = {0, 1, 2, 3}; diff --git a/tests/jacobian_tests/composite/test_chain_rule_jacobian.h b/tests/jacobian_tests/composite/test_chain_rule_jacobian.h index 4a6dde9..012cca3 100644 --- a/tests/jacobian_tests/composite/test_chain_rule_jacobian.h +++ b/tests/jacobian_tests/composite/test_chain_rule_jacobian.h @@ -81,8 +81,8 @@ const char *test_jacobian_Ax_Bx_multiply(void) CSR_Matrix *A = new_csr_random(2, 2, 1.0); CSR_Matrix *B = new_csr_random(2, 2, 1.0); expr *x = new_variable(2, 1, 1, 4); - expr *Ax = new_left_matmul(x, A); - expr *Bx = new_left_matmul(x, B); + expr *Ax = new_left_matmul(NULL, x, A); + expr *Bx = new_left_matmul(NULL, x, B); expr *multiply = new_elementwise_mult(Ax, Bx); mu_assert("check_jacobian failed", @@ -101,8 +101,8 @@ const char *test_jacobian_AX_BX_multiply(void) CSR_Matrix *A = new_csr_random(2, 2, 1.0); CSR_Matrix *B = new_csr_random(2, 2, 1.0); expr *X = new_variable(2, 2, 0, 4); - expr *AX = new_left_matmul(X, A); - expr *BX = new_left_matmul(X, B); + expr *AX = new_left_matmul(NULL, X, A); + expr *BX = new_left_matmul(NULL, X, B); expr *multiply = new_elementwise_mult(new_sin(AX), new_cos(BX)); mu_assert("check_jacobian failed", diff --git a/tests/problem/test_param_prob.h b/tests/problem/test_param_prob.h new file mode 100644 index 0000000..d76c84e --- /dev/null +++ b/tests/problem/test_param_prob.h @@ -0,0 +1,429 @@ +#ifndef TEST_PARAM_PROB_H +#define TEST_PARAM_PROB_H + +#include +#include +#include + +#include "affine.h" +#include "bivariate.h" +#include "elementwise_restricted_dom.h" +#include "expr.h" +#include "minunit.h" +#include "problem.h" +#include "subexpr.h" +#include "test_helpers.h" + +/* + * Test 1: param_scalar_mult in objective + * + * Problem: minimize a * sum(log(x)), no constraints, x size 2 + * a is a scalar parameter (param_id=0) + * + * At x=[1,2], a=3: + * obj = 3*(log(1)+log(2)) = 3*log(2) + * gradient = [3/1, 3/2] = [3.0, 1.5] + * + * After update a=5: + * obj = 5*log(2) + * gradient = [5.0, 2.5] + */ +const char *test_param_scalar_mult_problem(void) +{ + int n_vars = 2; + + /* Build tree: sum(a * log(x)) */ + expr *x = new_variable(2, 1, 0, n_vars); + expr *log_x = new_log(x); + expr *a_param = new_parameter(1, 1, 0, n_vars, NULL); + expr *scaled = new_scalar_mult(a_param, log_x); + expr *objective = new_sum(scaled, -1); + + /* Create problem (no constraints) */ + problem *prob = new_problem(objective, NULL, 0, false); + + /* Register parameter */ + expr *param_nodes[1] = {a_param}; + problem_register_params(prob, param_nodes, 1); + problem_init_derivatives(prob); + + /* Set a=3 and evaluate at x=[1,2] */ + double theta[1] = {3.0}; + problem_update_params(prob, theta); + + double u[2] = {1.0, 2.0}; + double obj_val = problem_objective_forward(prob, u); + problem_gradient(prob); + + double expected_obj = 3.0 * log(2.0); + mu_assert("obj wrong (a=3)", fabs(obj_val - expected_obj) < 1e-10); + + double expected_grad[2] = {3.0, 1.5}; + mu_assert("gradient wrong (a=3)", + cmp_double_array(prob->gradient_values, expected_grad, 2)); + + /* Update a=5 and re-evaluate */ + theta[0] = 5.0; + problem_update_params(prob, theta); + + obj_val = problem_objective_forward(prob, u); + problem_gradient(prob); + + expected_obj = 5.0 * log(2.0); + mu_assert("obj wrong (a=5)", fabs(obj_val - expected_obj) < 1e-10); + + double expected_grad2[2] = {5.0, 2.5}; + mu_assert("gradient wrong (a=5)", + cmp_double_array(prob->gradient_values, expected_grad2, 2)); + + free_problem(prob); + + return 0; +} + +/* + * Test 2: param_vector_mult in constraint + * + * Problem: minimize sum(x), subject to p ∘ x, x size 2 + * p is a vector parameter of size 2 (param_id=0) + * + * At x=[1,2], p=[3,4]: + * constraint_values = [3, 8] + * jacobian = diag([3, 4]) + * + * After update p=[5,6]: + * constraint_values = [5, 12] + * jacobian = diag([5, 6]) + */ +const char *test_param_vector_mult_problem(void) +{ + int n_vars = 2; + + /* Objective: sum(x) */ + expr *x_obj = new_variable(2, 1, 0, n_vars); + expr *objective = new_sum(x_obj, -1); + + /* Constraint: p ∘ x */ + expr *x_con = new_variable(2, 1, 0, n_vars); + expr *p_param = new_parameter(2, 1, 0, n_vars, NULL); + expr *constraint = new_vector_mult(p_param, x_con); + + expr *constraints[1] = {constraint}; + + /* Create problem */ + problem *prob = new_problem(objective, constraints, 1, false); + + expr *param_nodes[1] = {p_param}; + problem_register_params(prob, param_nodes, 1); + problem_init_derivatives(prob); + + /* Set p=[3,4] and evaluate at x=[1,2] */ + double theta[2] = {3.0, 4.0}; + problem_update_params(prob, theta); + + double u[2] = {1.0, 2.0}; + problem_constraint_forward(prob, u); + problem_jacobian(prob); + + double expected_cv[2] = {3.0, 8.0}; + mu_assert("constraint values wrong (p=[3,4])", + cmp_double_array(prob->constraint_values, expected_cv, 2)); + + CSR_Matrix *jac = prob->jacobian; + mu_assert("jac rows wrong", jac->m == 2); + mu_assert("jac cols wrong", jac->n == 2); + + int expected_p[3] = {0, 1, 2}; + mu_assert("jac->p wrong (p=[3,4])", cmp_int_array(jac->p, expected_p, 3)); + + int expected_i[2] = {0, 1}; + mu_assert("jac->i wrong (p=[3,4])", cmp_int_array(jac->i, expected_i, 2)); + + double expected_x[2] = {3.0, 4.0}; + mu_assert("jac->x wrong (p=[3,4])", cmp_double_array(jac->x, expected_x, 2)); + + /* Update p=[5,6] and re-evaluate */ + double theta2[2] = {5.0, 6.0}; + problem_update_params(prob, theta2); + + problem_constraint_forward(prob, u); + problem_jacobian(prob); + + double expected_cv2[2] = {5.0, 12.0}; + mu_assert("constraint values wrong (p=[5,6])", + cmp_double_array(prob->constraint_values, expected_cv2, 2)); + + double expected_x2[2] = {5.0, 6.0}; + mu_assert("jac->x wrong (p=[5,6])", cmp_double_array(jac->x, expected_x2, 2)); + + free_problem(prob); + + return 0; +} + +/* + * Test 3: left_param_matmul in constraint + * + * Problem: minimize sum(x), subject to A @ x, x size 2, A is 2x2 + * A is a 2x2 matrix parameter (param_id=0, size=4, CSR data order) + * A = [[1,2],[3,4]] → CSR data order theta = [1,2,3,4] + * + * At x=[1,2]: + * constraint_values = [1*1+2*2, 3*1+4*2] = [5, 11] + * jacobian = [[1,2],[3,4]] + * + * After update A = [[5,6],[7,8]] → theta = [5,6,7,8]: + * constraint_values = [5*1+6*2, 7*1+8*2] = [17, 23] + * jacobian = [[5,6],[7,8]] + */ +const char *test_param_left_matmul_problem(void) +{ + int n_vars = 2; + + /* Objective: sum(x) */ + expr *x_obj = new_variable(2, 1, 0, n_vars); + expr *objective = new_sum(x_obj, -1); + + /* Constraint: A @ x */ + expr *x_con = new_variable(2, 1, 0, n_vars); + expr *A_param = new_parameter(2, 2, 0, n_vars, NULL); + + /* Dense 2x2 CSR with placeholder zeros */ + CSR_Matrix *A = new_csr_matrix(2, 2, 4); + int Ap[3] = {0, 2, 4}; + int Ai[4] = {0, 1, 0, 1}; + double Ax[4] = {0.0, 0.0, 0.0, 0.0}; + memcpy(A->p, Ap, 3 * sizeof(int)); + memcpy(A->i, Ai, 4 * sizeof(int)); + memcpy(A->x, Ax, 4 * sizeof(double)); + + expr *constraint = new_left_matmul(A_param, x_con, A); + free_csr_matrix(A); + + expr *constraints[1] = {constraint}; + + /* Create problem */ + problem *prob = new_problem(objective, constraints, 1, false); + + expr *param_nodes[1] = {A_param}; + problem_register_params(prob, param_nodes, 1); + problem_init_derivatives(prob); + + /* Set A = [[1,2],[3,4]], CSR data order: [1,2,3,4] */ + double theta[4] = {1.0, 2.0, 3.0, 4.0}; + problem_update_params(prob, theta); + + double u[2] = {1.0, 2.0}; + problem_constraint_forward(prob, u); + problem_jacobian(prob); + + double expected_cv[2] = {5.0, 11.0}; + mu_assert("constraint values wrong (A1)", + cmp_double_array(prob->constraint_values, expected_cv, 2)); + + CSR_Matrix *jac = prob->jacobian; + mu_assert("jac rows wrong", jac->m == 2); + mu_assert("jac cols wrong", jac->n == 2); + + int expected_p[3] = {0, 2, 4}; + mu_assert("jac->p wrong (A1)", cmp_int_array(jac->p, expected_p, 3)); + + int expected_i[4] = {0, 1, 0, 1}; + mu_assert("jac->i wrong (A1)", cmp_int_array(jac->i, expected_i, 4)); + + double expected_x[4] = {1.0, 2.0, 3.0, 4.0}; + mu_assert("jac->x wrong (A1)", cmp_double_array(jac->x, expected_x, 4)); + + /* Update A = [[5,6],[7,8]], CSR data order: [5,6,7,8] */ + double theta2[4] = {5.0, 6.0, 7.0, 8.0}; + problem_update_params(prob, theta2); + + problem_constraint_forward(prob, u); + problem_jacobian(prob); + + double expected_cv2[2] = {17.0, 23.0}; + mu_assert("constraint values wrong (A2)", + cmp_double_array(prob->constraint_values, expected_cv2, 2)); + + double expected_x2[4] = {5.0, 6.0, 7.0, 8.0}; + mu_assert("jac->x wrong (A2)", cmp_double_array(jac->x, expected_x2, 4)); + + free_problem(prob); + + return 0; +} + +/* + * Test 4: right_param_matmul in constraint + * + * Problem: minimize sum(x), subject to x @ A, x size 1x2, A is 2x2 + * A is a 2x2 matrix parameter (param_id=0, size=4, CSR data order) + * A = [[1,2],[3,4]] → CSR data order theta = [1,2,3,4] + * + * At x=[1,2]: + * constraint_values = [1*1+2*3, 1*2+2*4] = [7, 10] + * jacobian = [[1,3],[2,4]] = A^T + * + * After update A = [[5,6],[7,8]] → theta = [5,6,7,8]: + * constraint_values = [1*5+2*7, 1*6+2*8] = [19, 22] + * jacobian = [[5,7],[6,8]] = A^T + */ +const char *test_param_right_matmul_problem(void) +{ + int n_vars = 2; + + /* Objective: sum(x) */ + expr *x_obj = new_variable(1, 2, 0, n_vars); + expr *objective = new_sum(x_obj, -1); + + /* Constraint: x @ A */ + expr *x_con = new_variable(1, 2, 0, n_vars); + expr *A_param = new_parameter(2, 2, 0, n_vars, NULL); + + /* Dense 2x2 CSR with placeholder zeros */ + CSR_Matrix *A = new_csr_matrix(2, 2, 4); + int Ap[3] = {0, 2, 4}; + int Ai[4] = {0, 1, 0, 1}; + double Ax[4] = {0.0, 0.0, 0.0, 0.0}; + memcpy(A->p, Ap, 3 * sizeof(int)); + memcpy(A->i, Ai, 4 * sizeof(int)); + memcpy(A->x, Ax, 4 * sizeof(double)); + + expr *constraint = new_right_matmul(A_param, x_con, A); + free_csr_matrix(A); + + expr *constraints[1] = {constraint}; + + /* Create problem */ + problem *prob = new_problem(objective, constraints, 1, false); + + expr *param_nodes[1] = {A_param}; + problem_register_params(prob, param_nodes, 1); + problem_init_derivatives(prob); + + /* Set A = [[1,2],[3,4]], CSR data order: [1,2,3,4] */ + double theta[4] = {1.0, 2.0, 3.0, 4.0}; + problem_update_params(prob, theta); + + double u[2] = {1.0, 2.0}; + problem_constraint_forward(prob, u); + problem_jacobian(prob); + + double expected_cv[2] = {7.0, 10.0}; + mu_assert("constraint values wrong (A1)", + cmp_double_array(prob->constraint_values, expected_cv, 2)); + + CSR_Matrix *jac = prob->jacobian; + mu_assert("jac rows wrong", jac->m == 2); + mu_assert("jac cols wrong", jac->n == 2); + + int expected_p[3] = {0, 2, 4}; + mu_assert("jac->p wrong (A1)", cmp_int_array(jac->p, expected_p, 3)); + + int expected_i[4] = {0, 1, 0, 1}; + mu_assert("jac->i wrong (A1)", cmp_int_array(jac->i, expected_i, 4)); + + double expected_x[4] = {1.0, 3.0, 2.0, 4.0}; + mu_assert("jac->x wrong (A1)", cmp_double_array(jac->x, expected_x, 4)); + + /* Update A = [[5,6],[7,8]], CSR data order: [5,6,7,8] */ + double theta2[4] = {5.0, 6.0, 7.0, 8.0}; + problem_update_params(prob, theta2); + + problem_constraint_forward(prob, u); + problem_jacobian(prob); + + double expected_cv2[2] = {19.0, 22.0}; + mu_assert("constraint values wrong (A2)", + cmp_double_array(prob->constraint_values, expected_cv2, 2)); + + double expected_x2[4] = {5.0, 7.0, 6.0, 8.0}; + mu_assert("jac->x wrong (A2)", cmp_double_array(jac->x, expected_x2, 4)); + + free_problem(prob); + + return 0; +} + +/* + * Test 5: PARAM_FIXED params are skipped by problem_update_params + * + * Problem: minimize a * sum(log(x)) + b * sum(x), no constraints, x size 2 + * a is a FIXED scalar parameter (param_id=PARAM_FIXED, value=2.0) + * b is an updatable scalar parameter (param_id=0) + * + * At x=[1,2], a=2, b=3: + * obj = 2*(log(1)+log(2)) + 3*(1+2) = 2*log(2) + 9 + * gradient = [2/1 + 3, 2/2 + 3] = [5.0, 4.0] + * + * After update theta={5.0} (only b changes to 5, a stays 2): + * obj = 2*log(2) + 5*3 = 2*log(2) + 15 + * gradient = [2/1 + 5, 2/2 + 5] = [7.0, 6.0] + */ +const char *test_param_fixed_skip_in_update(void) +{ + int n_vars = 2; + + /* Build tree: a * sum(log(x)) + b * sum(x) */ + expr *x1 = new_variable(2, 1, 0, n_vars); + expr *log_x = new_log(x1); + double a_val = 2.0; + expr *a_param = new_parameter(1, 1, PARAM_FIXED, n_vars, &a_val); + expr *a_log = new_scalar_mult(a_param, log_x); + expr *sum_a_log = new_sum(a_log, -1); + + expr *x2 = new_variable(2, 1, 0, n_vars); + expr *b_param = new_parameter(1, 1, 0, n_vars, NULL); + expr *b_x = new_scalar_mult(b_param, x2); + expr *sum_b_x = new_sum(b_x, -1); + + expr *objective = new_add(sum_a_log, sum_b_x); + + /* Create problem and register BOTH params */ + problem *prob = new_problem(objective, NULL, 0, false); + + expr *param_nodes[2] = {a_param, b_param}; + problem_register_params(prob, param_nodes, 2); + problem_init_derivatives(prob); + + /* Set b=3 and evaluate at x=[1,2] */ + double theta[1] = {3.0}; + problem_update_params(prob, theta); + + /* Verify a is still 2.0 (not overwritten) */ + mu_assert("a_param changed after update", fabs(a_param->value[0] - 2.0) < 1e-10); + + double u[2] = {1.0, 2.0}; + double obj_val = problem_objective_forward(prob, u); + problem_gradient(prob); + + double expected_obj = 2.0 * log(2.0) + 9.0; + mu_assert("obj wrong (b=3)", fabs(obj_val - expected_obj) < 1e-10); + + double expected_grad[2] = {5.0, 4.0}; + mu_assert("gradient wrong (b=3)", + cmp_double_array(prob->gradient_values, expected_grad, 2)); + + /* Update b=5, a should stay 2 */ + theta[0] = 5.0; + problem_update_params(prob, theta); + + mu_assert("a_param changed after second update", + fabs(a_param->value[0] - 2.0) < 1e-10); + + obj_val = problem_objective_forward(prob, u); + problem_gradient(prob); + + double expected_obj2 = 2.0 * log(2.0) + 15.0; + mu_assert("obj wrong (b=5)", fabs(obj_val - expected_obj2) < 1e-10); + + double expected_grad2[2] = {7.0, 6.0}; + mu_assert("gradient wrong (b=5)", + cmp_double_array(prob->gradient_values, expected_grad2, 2)); + + free_problem(prob); + + return 0; +} + +#endif /* TEST_PARAM_PROB_H */ diff --git a/tests/profiling/profile_left_matmul.h b/tests/profiling/profile_left_matmul.h index 5ff98be..aabaece 100644 --- a/tests/profiling/profile_left_matmul.h +++ b/tests/profiling/profile_left_matmul.h @@ -31,7 +31,7 @@ const char *profile_left_matmul(void) } A->p[n] = n * n; - expr *AX = new_left_matmul(X, A); + expr *AX = new_left_matmul(NULL, X, A); double *x_vals = (double *) malloc(n * n * sizeof(double)); for (int i = 0; i < n * n; i++) diff --git a/tests/wsum_hess/affine/test_left_matmul.h b/tests/wsum_hess/affine/test_left_matmul.h index 3ccaff1..84ef5e1 100644 --- a/tests/wsum_hess/affine/test_left_matmul.h +++ b/tests/wsum_hess/affine/test_left_matmul.h @@ -63,7 +63,7 @@ const char *test_wsum_hess_left_matmul(void) memcpy(A->x, A_x, 7 * sizeof(double)); expr *log_x = new_log(x); - expr *A_log_x = new_left_matmul(log_x, A); + expr *A_log_x = new_left_matmul(NULL, log_x, A); A_log_x->forward(A_log_x, x_vals); jacobian_init(A_log_x); @@ -118,7 +118,7 @@ const char *test_wsum_hess_left_matmul_exp_composite(void) expr *Bx = new_linear(x, B, NULL); expr *exp_Bx = new_exp(Bx); - expr *A_exp_Bx = new_left_matmul(exp_Bx, A); + expr *A_exp_Bx = new_left_matmul(NULL, exp_Bx, A); mu_assert("check_wsum_hess failed", check_wsum_hess(A_exp_Bx, x_vals, w, NUMERICAL_DIFF_DEFAULT_H)); @@ -170,7 +170,7 @@ const char *test_wsum_hess_left_matmul_matrix(void) memcpy(A->x, A_x, 7 * sizeof(double)); expr *log_x = new_log(x); - expr *A_log_x = new_left_matmul(log_x, A); + expr *A_log_x = new_left_matmul(NULL, log_x, A); A_log_x->forward(A_log_x, x_vals); jacobian_init(A_log_x); diff --git a/tests/wsum_hess/affine/test_right_matmul.h b/tests/wsum_hess/affine/test_right_matmul.h index e697b58..1757a9a 100644 --- a/tests/wsum_hess/affine/test_right_matmul.h +++ b/tests/wsum_hess/affine/test_right_matmul.h @@ -33,7 +33,7 @@ const char *test_wsum_hess_right_matmul(void) memcpy(A->x, A_x, 4 * sizeof(double)); expr *log_x = new_log(x); - expr *log_x_A = new_right_matmul(log_x, A); + expr *log_x_A = new_right_matmul(NULL, log_x, A); log_x_A->forward(log_x_A, x_vals); jacobian_init(log_x_A); @@ -83,7 +83,7 @@ const char *test_wsum_hess_right_matmul_vector(void) memcpy(A->x, A_x, 4 * sizeof(double)); expr *log_x = new_log(x); - expr *log_x_A = new_right_matmul(log_x, A); + expr *log_x_A = new_right_matmul(NULL, log_x, A); log_x_A->forward(log_x_A, x_vals); jacobian_init(log_x_A); diff --git a/tests/wsum_hess/affine/test_const_scalar_mult.h b/tests/wsum_hess/affine/test_scalar_mult.h similarity index 72% rename from tests/wsum_hess/affine/test_const_scalar_mult.h rename to tests/wsum_hess/affine/test_scalar_mult.h index 31fb6a6..7bdd20a 100644 --- a/tests/wsum_hess/affine/test_const_scalar_mult.h +++ b/tests/wsum_hess/affine/test_scalar_mult.h @@ -2,15 +2,17 @@ #include #include "affine.h" +#include "bivariate.h" #include "elementwise_full_dom.h" #include "elementwise_restricted_dom.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" -/* Test: y = a * log(x) where a is a scalar constant */ +/* Test: y = a * log(x) where a is a scalar parameter */ -const char *test_wsum_hess_const_scalar_mult_log_vector(void) +const char *test_wsum_hess_scalar_mult_log_vector(void) { /* Create variable x: [1.0, 2.0, 4.0] */ double u_vals[3] = {1.0, 2.0, 4.0}; @@ -20,8 +22,9 @@ const char *test_wsum_hess_const_scalar_mult_log_vector(void) expr *log_node = new_log(x); /* Create scalar mult node: y = 2.5 * log(x) */ - double a = 2.5; - expr *y = new_const_scalar_mult(a, log_node); + double a_val = 2.5; + expr *a_param = new_parameter(1, 1, PARAM_FIXED, 3, &a_val); + expr *y = new_scalar_mult(a_param, log_node); /* Forward pass */ y->forward(y, u_vals); @@ -31,15 +34,6 @@ const char *test_wsum_hess_const_scalar_mult_log_vector(void) double w[3] = {1.0, 0.5, 0.25}; y->eval_wsum_hess(y, w); - /* For y = a * log(x), the Hessian is: - * H = a * H_log = a * diag([-1/x_i^2]) - * With weights w and scalar a: - * H_weighted = a * diag([-w_i / x_i^2]) - * - * Expected diagonal: 2.5 * [-1/1^2, -0.5/2^2, -0.25/4^2] - * = 2.5 * [-1, -0.125, -0.015625] - * = [-2.5, -0.3125, -0.0390625] - */ double expected_x[3] = {-2.5, -0.3125, -0.0390625}; int expected_p[4] = {0, 1, 2, 3}; int expected_i[3] = {0, 1, 2}; @@ -55,7 +49,7 @@ const char *test_wsum_hess_const_scalar_mult_log_vector(void) return 0; } -const char *test_wsum_hess_const_scalar_mult_log_matrix(void) +const char *test_wsum_hess_scalar_mult_log_matrix(void) { /* Create variable x as 2x2 matrix: [[1.0, 2.0], [4.0, 8.0]] */ double u_vals[4] = {1.0, 2.0, 4.0, 8.0}; @@ -65,8 +59,9 @@ const char *test_wsum_hess_const_scalar_mult_log_matrix(void) expr *log_node = new_log(x); /* Create scalar mult node: y = 3.0 * log(x) */ - double a = 3.0; - expr *y = new_const_scalar_mult(a, log_node); + double a_val = 3.0; + expr *a_param = new_parameter(1, 1, PARAM_FIXED, 4, &a_val); + expr *y = new_scalar_mult(a_param, log_node); /* Forward pass */ y->forward(y, u_vals); @@ -76,10 +71,6 @@ const char *test_wsum_hess_const_scalar_mult_log_matrix(void) double w[4] = {1.0, 1.0, 1.0, 1.0}; y->eval_wsum_hess(y, w); - /* Expected diagonal: 3.0 * [-1/1^2, -1/2^2, -1/4^2, -1/8^2] - * = 3.0 * [-1, -0.25, -0.0625, -0.015625] - * = [-3.0, -0.75, -0.1875, -0.046875] - */ double expected_x[4] = {-3.0, -0.75, -0.1875, -0.046875}; int expected_p[5] = {0, 1, 2, 3, 4}; int expected_i[4] = {0, 1, 2, 3}; diff --git a/tests/wsum_hess/affine/test_const_vector_mult.h b/tests/wsum_hess/affine/test_vector_mult.h similarity index 72% rename from tests/wsum_hess/affine/test_const_vector_mult.h rename to tests/wsum_hess/affine/test_vector_mult.h index 0e02a4f..d527199 100644 --- a/tests/wsum_hess/affine/test_const_vector_mult.h +++ b/tests/wsum_hess/affine/test_vector_mult.h @@ -2,15 +2,17 @@ #include #include "affine.h" +#include "bivariate.h" #include "elementwise_full_dom.h" #include "elementwise_restricted_dom.h" #include "expr.h" #include "minunit.h" +#include "subexpr.h" #include "test_helpers.h" -/* Test: y = a ∘ log(x) where a is a constant vector */ +/* Test: y = a ∘ log(x) where a is a parameter vector */ -const char *test_wsum_hess_const_vector_mult_log_vector(void) +const char *test_wsum_hess_vector_mult_log_vector(void) { /* Create variable x: [1.0, 2.0, 4.0] */ double u_vals[3] = {1.0, 2.0, 4.0}; @@ -21,7 +23,8 @@ const char *test_wsum_hess_const_vector_mult_log_vector(void) /* Create vector mult node: y = [2.0, 3.0, 4.0] ∘ log(x) */ double a[3] = {2.0, 3.0, 4.0}; - expr *y = new_const_vector_mult(a, log_node); + expr *a_param = new_parameter(3, 1, PARAM_FIXED, 3, a); + expr *y = new_vector_mult(a_param, log_node); /* Forward pass */ y->forward(y, u_vals); @@ -31,13 +34,6 @@ const char *test_wsum_hess_const_vector_mult_log_vector(void) double w[3] = {1.0, 0.5, 0.25}; y->eval_wsum_hess(y, w); - /* For y = a ∘ log(x), the weighted Hessian is: - * H = diag([a_i * (-w_i / x_i^2)]) - * - * Expected diagonal: [2.0 * (-1 / 1^2), 3.0 * (-0.5 / 2^2), 4.0 * (-0.25 / 4^2)] - * = [2.0 * (-1), 3.0 * (-0.125), 4.0 * (-0.015625)] - * = [-2.0, -0.375, -0.0625] - */ double expected_x[3] = {-2.0, -0.375, -0.0625}; int expected_p[4] = {0, 1, 2, 3}; int expected_i[3] = {0, 1, 2}; @@ -53,7 +49,7 @@ const char *test_wsum_hess_const_vector_mult_log_vector(void) return 0; } -const char *test_wsum_hess_const_vector_mult_log_matrix(void) +const char *test_wsum_hess_vector_mult_log_matrix(void) { /* Create variable x as 2x2 matrix: [[1.0, 2.0], [4.0, 8.0]] */ double u_vals[4] = {1.0, 2.0, 4.0, 8.0}; @@ -64,7 +60,8 @@ const char *test_wsum_hess_const_vector_mult_log_matrix(void) /* Create vector mult node: y = [1.5, 2.5, 3.5, 4.5] ∘ log(x) */ double a[4] = {1.5, 2.5, 3.5, 4.5}; - expr *y = new_const_vector_mult(a, log_node); + expr *a_param = new_parameter(4, 1, PARAM_FIXED, 4, a); + expr *y = new_vector_mult(a_param, log_node); /* Forward pass */ y->forward(y, u_vals); @@ -74,11 +71,6 @@ const char *test_wsum_hess_const_vector_mult_log_matrix(void) double w[4] = {1.0, 1.0, 1.0, 1.0}; y->eval_wsum_hess(y, w); - /* Expected diagonal: [1.5 * (-1 / 1^2), 2.5 * (-1 / 2^2), - * 3.5 * (-1 / 4^2), 4.5 * (-1 / 8^2)] - * = [1.5 * (-1), 2.5 * (-0.25), 3.5 * (-0.0625), 4.5 * - * (-0.015625)] = [-1.5, -0.625, -0.21875, -0.0703125] - */ double expected_x[4] = {-1.5, -0.625, -0.21875, -0.0703125}; int expected_p[5] = {0, 1, 2, 3, 4}; int expected_i[4] = {0, 1, 2, 3};