From 339a4cf9c4b7a1e6a57273a388f357391bd0f86e Mon Sep 17 00:00:00 2001 From: Transurgeon Date: Sun, 28 Jun 2026 18:55:34 -0400 Subject: [PATCH 1/2] Add native kron (Kronecker product) affine atom cvxpy's kron(A, B) always has one variable-free operand, so every output entry depends on a single child entry: Z[OUT] = coeff[OUT] * child[child_row[OUT]]. The output Jacobian is therefore the child Jacobian's rows gathered (with repetition) and scaled by the variable-free operand -- no coefficient matrix, no matmul, no CSC conversion; O(nnz(result)). child_row[] and coeff_idx[] depend only on the operand shapes and are precomputed once in new_kron. Handles kron(param/const, var) and kron(var, param/const), parametric or constant, with column-major (Fortran) flattening, and re-evaluates the variable-free operand each solve. forward, Jacobian and the affine Hessian backprop are all scaled gathers. Adds forward/Jacobian/wsum_hess unit tests (both forms, scalar operand, and numerical Jacobian/Hessian checks on a composite arg); all_tests now 405. Co-Authored-By: Claude Opus 4.8 (1M context) --- include/atoms/affine.h | 8 + include/subexpr.h | 16 ++ src/atoms/affine/kron.c | 217 ++++++++++++++++++++++++ tests/all_tests.c | 11 ++ tests/forward_pass/affine/test_kron.h | 78 +++++++++ tests/jacobian_tests/affine/test_kron.h | 91 ++++++++++ tests/wsum_hess/affine/test_kron.h | 53 ++++++ 7 files changed, 474 insertions(+) create mode 100644 src/atoms/affine/kron.c create mode 100644 tests/forward_pass/affine/test_kron.h create mode 100644 tests/jacobian_tests/affine/test_kron.h create mode 100644 tests/wsum_hess/affine/test_kron.h diff --git a/include/atoms/affine.h b/include/atoms/affine.h index 49b4637..831b7ca 100644 --- a/include/atoms/affine.h +++ b/include/atoms/affine.h @@ -80,4 +80,12 @@ expr *new_vector_mult(expr *param_node, expr *child); kernel and may either represent a constant or an updatable parameter */ expr *new_convolve(expr *param_node, expr *child); +/* Kronecker product Z = kron(A, B). Exactly one operand is variable-free and is + passed as param_node (constant or updatable parameter); the other carries the + variables and is passed as child. const_is_left selects which operand is the + parameter: 1 -> A=param_node, B=child; 0 -> A=child, B=param_node. (p, q) are + A's dims and (r, s) are B's dims. */ +expr *new_kron(expr *param_node, expr *child, int const_is_left, int p, int q, + int r, int s); + #endif /* AFFINE_H */ diff --git a/include/subexpr.h b/include/subexpr.h index 26aaaf6..15f627e 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -173,6 +173,22 @@ typedef struct convolve_expr CSC_matrix *Jchild_CSC; } convolve_expr; +/* Kronecker product Z = kron(A, B) where exactly one operand is variable-free + * (held by param_source) and the other (child = node->left) carries the + * variables. Every output entry depends on a single child entry, so the output + * Jacobian is the child Jacobian's rows gathered (with repetition) and scaled -- + * no coefficient matrix or matmul. child_row[OUT] and coeff_idx[OUT] depend only + * on the operand shapes and are precomputed once at construction. */ +typedef struct kron_expr +{ + expr base; + expr *param_source; /* the constant/parameter operand */ + int p, q, r, s; /* A is p x q, B is r x s */ + int const_is_left; /* 1: A=param, B=child; 0: A=child, B=param */ + int *child_row; /* size_out: child entry each output row gathers */ + int *coeff_idx; /* size_out: index into param_source->value (the scale) */ +} kron_expr; + /* Bivariate matrix multiplication: Z = f(u) @ g(u) where both children * may be composite expressions. */ typedef struct matmul_expr diff --git a/src/atoms/affine/kron.c b/src/atoms/affine/kron.c new file mode 100644 index 0000000..ddc4562 --- /dev/null +++ b/src/atoms/affine/kron.c @@ -0,0 +1,217 @@ +/* + * Copyright 2026 Daniel Cederberg and William Zhang + * + * This file is part of the SparseDiffEngine project. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "atoms/affine.h" +#include "subexpr.h" +#include "utils/CSR_matrix.h" +#include "utils/sparse_matrix.h" +#include "utils/tracked_alloc.h" +#include +#include +#include + +/* Kronecker product Z = kron(A, B), where exactly one operand is variable-free + * (param_source) and the other (child = node->left) carries the variables. + * + * With column-major (Fortran) flattening, an output index OUT = I + J*(p*r) + * decomposes as I = i*r + k and J = j*s + l (i in [0,p), k in [0,r), j in [0,q), + * l in [0,s)). The output block (i, j) inner (k, l) equals A[i,j] * B[k,l], so + * every output entry depends on a single child entry: + * + * Z[OUT] = coeff[OUT] * vec(child)[child_row[OUT]] + * J_kron[OUT,:] = coeff[OUT] * J_child[child_row[OUT], :] + * + * where coeff[OUT] = param_source->value[coeff_idx[OUT]]. child_row[] and + * coeff_idx[] depend only on the shapes and are filled once in new_kron, so + * forward, Jacobian and (affine) Hessian are all scaled gathers -- no + * size_out x size_child coefficient matrix and no sparse matmul. */ + +static void forward(expr *node, const double *u) +{ + expr *child = node->left; + kron_expr *knode = (kron_expr *) node; + + /* Pull current parameter values through any broadcast/promote wrappers. */ + if (knode->base.needs_parameter_refresh) + { + knode->param_source->forward(knode->param_source, NULL); + knode->base.needs_parameter_refresh = false; + } + + child->forward(child, u); + + const double *a = knode->param_source->value; + const double *x = child->value; + double *y = node->value; + for (int out = 0; out < node->size; out++) + { + y[out] = a[knode->coeff_idx[out]] * x[knode->child_row[out]]; + } +} + +static void jacobian_init_impl(expr *node) +{ + expr *child = node->left; + kron_expr *knode = (kron_expr *) node; + + jacobian_init(child); + + /* Output row OUT shares the column set of child row child_row[OUT]. Build + the result CSR sparsity by copying those child rows (with repetition). */ + CSR_matrix *Jc = child->jacobian->to_csr(child->jacobian); + + int total = 0; + for (int out = 0; out < node->size; out++) + { + int cc = knode->child_row[out]; + total += Jc->p[cc + 1] - Jc->p[cc]; + } + + CSR_matrix *Jk = new_CSR_matrix(node->size, node->n_vars, total); + int idx = 0; + Jk->p[0] = 0; + for (int out = 0; out < node->size; out++) + { + int cc = knode->child_row[out]; + for (int t = Jc->p[cc]; t < Jc->p[cc + 1]; t++) + { + Jk->i[idx++] = Jc->i[t]; + } + Jk->p[out + 1] = idx; + } + node->jacobian = new_sparse_matrix(Jk); +} + +static void eval_jacobian(expr *node) +{ + expr *child = node->left; + kron_expr *knode = (kron_expr *) node; + + child->eval_jacobian(child); + + /* Child sparsity is fixed after jacobian_init, so the result row offsets + still align; refill values as scale * child-row-values. */ + CSR_matrix *Jc = child->jacobian->to_csr(child->jacobian); + CSR_matrix *Jk = node->jacobian->to_csr(node->jacobian); + const double *a = knode->param_source->value; + + int idx = 0; + for (int out = 0; out < node->size; out++) + { + int cc = knode->child_row[out]; + double scale = a[knode->coeff_idx[out]]; + for (int t = Jc->p[cc]; t < Jc->p[cc + 1]; t++) + { + Jk->x[idx++] = scale * Jc->x[t]; + } + } +} + +static void wsum_hess_init_impl(expr *node) +{ + expr *child = node->left; + + wsum_hess_init(child); + node->wsum_hess = child->wsum_hess->copy_sparsity(child->wsum_hess); + /* backprop workspace: one weight per child entry */ + node->work->dwork = (double *) sp_malloc(child->size * sizeof(double)); +} + +static void eval_wsum_hess(expr *node, const double *w) +{ + expr *child = node->left; + kron_expr *knode = (kron_expr *) node; + const double *a = knode->param_source->value; + double *w_prime = node->work->dwork; + + /* kron is affine in child, so the Hessian is the child's with weights pushed + back through the linear gather: w'[child_row] += coeff * w[OUT]. Many + output rows map to one child entry, hence the accumulation. */ + memset(w_prime, 0, child->size * sizeof(double)); + for (int out = 0; out < node->size; out++) + { + w_prime[knode->child_row[out]] += a[knode->coeff_idx[out]] * w[out]; + } + + child->eval_wsum_hess(child, w_prime); + memcpy(node->wsum_hess->x, child->wsum_hess->x, + node->wsum_hess->nnz * sizeof(double)); +} + +static bool is_affine(const expr *node) +{ + return node->left->is_affine(node->left); +} + +static void free_type_data(expr *node) +{ + kron_expr *knode = (kron_expr *) node; + sp_free(knode->child_row); + sp_free(knode->coeff_idx); + free_expr(knode->param_source); +} + +expr *new_kron(expr *param_node, expr *child, int const_is_left, int p, int q, + int r, int s) +{ + int d1 = p * r; + int d2 = q * s; + int size_out = d1 * d2; + + kron_expr *knode = (kron_expr *) sp_calloc(1, sizeof(kron_expr)); + expr *node = &knode->base; + init_expr(node, d1, d2, child->n_vars, forward, jacobian_init_impl, + eval_jacobian, is_affine, wsum_hess_init_impl, eval_wsum_hess, + free_type_data); + node->left = child; + expr_retain(child); + + knode->param_source = param_node; + expr_retain(param_node); + knode->p = p; + knode->q = q; + knode->r = r; + knode->s = s; + knode->const_is_left = const_is_left; + + knode->child_row = (int *) sp_malloc(size_out * sizeof(int)); + knode->coeff_idx = (int *) sp_malloc(size_out * sizeof(int)); + + int n_rows = p * r; /* number of output rows */ + for (int out = 0; out < size_out; out++) + { + int I = out % n_rows; + int J = out / n_rows; + int i = I / r, k = I % r; + int j = J / s, l = J % s; + if (const_is_left) + { + /* A = param (p x q), B = child (r x s) */ + knode->child_row[out] = k + l * r; /* col-major into B */ + knode->coeff_idx[out] = i + j * p; /* col-major into A */ + } + else + { + /* A = child (p x q), B = param (r x s) */ + knode->child_row[out] = i + j * p; /* col-major into A */ + knode->coeff_idx[out] = k + l * r; /* col-major into B */ + } + } + + knode->base.needs_parameter_refresh = true; + return node; +} diff --git a/tests/all_tests.c b/tests/all_tests.c index b9dff88..9c1effc 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -8,6 +8,7 @@ #include "forward_pass/affine/test_add.h" #include "forward_pass/affine/test_broadcast.h" #include "forward_pass/affine/test_convolve.h" +#include "forward_pass/affine/test_kron.h" #include "forward_pass/affine/test_diag_mat.h" #include "forward_pass/affine/test_hstack.h" #include "forward_pass/affine/test_left_matmul_dense.h" @@ -27,6 +28,7 @@ #include "forward_pass/other/test_prod_axis_zero.h" #include "jacobian_tests/affine/test_broadcast.h" #include "jacobian_tests/affine/test_convolve.h" +#include "jacobian_tests/affine/test_kron.h" #include "jacobian_tests/affine/test_diag_mat.h" #include "jacobian_tests/affine/test_hstack.h" #include "jacobian_tests/affine/test_index.h" @@ -73,6 +75,7 @@ #include "utils/test_stacked_pd.h" #include "wsum_hess/affine/test_broadcast.h" #include "wsum_hess/affine/test_convolve.h" +#include "wsum_hess/affine/test_kron.h" #include "wsum_hess/affine/test_diag_mat.h" #include "wsum_hess/affine/test_hstack.h" #include "wsum_hess/affine/test_index.h" @@ -151,6 +154,9 @@ int main(void) mu_run_test(test_convolve_forward, tests_run); mu_run_test(test_convolve_forward_row, tests_run); mu_run_test(test_convolve_forward_param, tests_run); + mu_run_test(test_kron_forward_const_left, tests_run); + mu_run_test(test_kron_forward_const_right, tests_run); + mu_run_test(test_kron_forward_scalar, tests_run); mu_run_test(test_diag_mat_forward, tests_run); mu_run_test(test_upper_tri_forward_4x4, tests_run); @@ -244,6 +250,9 @@ int main(void) mu_run_test(test_jacobian_matmul, tests_run); mu_run_test(test_jacobian_convolve, tests_run); mu_run_test(test_jacobian_convolve_composite, tests_run); + mu_run_test(test_jacobian_kron_const_left, tests_run); + mu_run_test(test_jacobian_kron_const_right, tests_run); + mu_run_test(test_jacobian_kron_composite, tests_run); mu_run_test(test_jacobian_transpose, tests_run); mu_run_test(test_jacobian_transpose_pd_preserved, tests_run); mu_run_test(test_diag_mat_jacobian_variable, tests_run); @@ -317,6 +326,8 @@ int main(void) mu_run_test(test_wsum_hess_right_matmul_vector, tests_run); mu_run_test(test_wsum_hess_convolve, tests_run); mu_run_test(test_wsum_hess_convolve_composite, tests_run); + mu_run_test(test_wsum_hess_kron, tests_run); + mu_run_test(test_wsum_hess_kron_composite, tests_run); mu_run_test(test_wsum_hess_broadcast_row, tests_run); mu_run_test(test_wsum_hess_broadcast_col, tests_run); mu_run_test(test_wsum_hess_broadcast_scalar_to_matrix, tests_run); diff --git a/tests/forward_pass/affine/test_kron.h b/tests/forward_pass/affine/test_kron.h new file mode 100644 index 0000000..3a04df3 --- /dev/null +++ b/tests/forward_pass/affine/test_kron.h @@ -0,0 +1,78 @@ +#include +#include + +#include "atoms/affine.h" +#include "expr.h" +#include "minunit.h" +#include "subexpr.h" +#include "test_helpers.h" + +const char *test_kron_forward_const_left(void) +{ + /* Z = kron(A, B), A = [[1,2],[3,4]] constant, B = [[5,6],[7,8]] variable. + * np.kron gives a 4x4; check its column-major flatten. */ + double A[4] = {1.0, 3.0, 2.0, 4.0}; /* col-major [[1,2],[3,4]] */ + expr *A_param = new_parameter(2, 2, PARAM_FIXED, 4, A); + expr *B = new_variable(2, 2, 0, 4); + expr *Z = new_kron(A_param, B, 1, 2, 2, 2, 2); + + double u[4] = {5.0, 7.0, 6.0, 8.0}; /* col-major [[5,6],[7,8]] */ + Z->forward(Z, u); + + double expected[16] = {5, 7, 15, 21, 6, 8, 18, 24, + 10, 14, 20, 28, 12, 16, 24, 32}; + + mu_assert("kron const-left d1=4", Z->d1 == 4); + mu_assert("kron const-left d2=4", Z->d2 == 4); + mu_assert("kron const-left forward failed", + cmp_double_array(Z->value, expected, 16)); + + free_expr(Z); + return 0; +} + +const char *test_kron_forward_const_right(void) +{ + /* Z = kron(A, B), A = [[5,6],[7,8]] variable, B = [[1,2],[3,4]] constant. */ + double B[4] = {1.0, 3.0, 2.0, 4.0}; /* col-major [[1,2],[3,4]] */ + expr *B_param = new_parameter(2, 2, PARAM_FIXED, 4, B); + expr *A = new_variable(2, 2, 0, 4); + expr *Z = new_kron(B_param, A, 0, 2, 2, 2, 2); + + double u[4] = {5.0, 7.0, 6.0, 8.0}; /* col-major [[5,6],[7,8]] */ + Z->forward(Z, u); + + double expected[16] = {5, 15, 7, 21, 10, 20, 14, 28, + 6, 18, 8, 24, 12, 24, 16, 32}; + + mu_assert("kron const-right d1=4", Z->d1 == 4); + mu_assert("kron const-right d2=4", Z->d2 == 4); + mu_assert("kron const-right forward failed", + cmp_double_array(Z->value, expected, 16)); + + free_expr(Z); + return 0; +} + +const char *test_kron_forward_scalar(void) +{ + /* Z = kron(y, B), y a scalar variable, B = [[1,2],[3,4]] constant. + * Z = y * B; for y = 3 expect [[3,6],[9,12]] (col-major 3*B). */ + double B[4] = {1.0, 3.0, 2.0, 4.0}; + expr *B_param = new_parameter(2, 2, PARAM_FIXED, 1, B); + expr *y = new_variable(1, 1, 0, 1); + expr *Z = new_kron(B_param, y, 0, 1, 1, 2, 2); + + double u[1] = {3.0}; + Z->forward(Z, u); + + double expected[4] = {3.0, 9.0, 6.0, 12.0}; + + mu_assert("kron scalar d1=2", Z->d1 == 2); + mu_assert("kron scalar d2=2", Z->d2 == 2); + mu_assert("kron scalar forward failed", + cmp_double_array(Z->value, expected, 4)); + + free_expr(Z); + return 0; +} diff --git a/tests/jacobian_tests/affine/test_kron.h b/tests/jacobian_tests/affine/test_kron.h new file mode 100644 index 0000000..114b11f --- /dev/null +++ b/tests/jacobian_tests/affine/test_kron.h @@ -0,0 +1,91 @@ +#include +#include + +#include "atoms/affine.h" +#include "atoms/elementwise_restricted_dom.h" +#include "expr.h" +#include "minunit.h" +#include "numerical_diff.h" +#include "subexpr.h" +#include "test_helpers.h" + +const char *test_jacobian_kron_const_left(void) +{ + /* Z = kron([[1,2],[3,4]], B), B a 2x2 leaf variable. Each output row has a + * single nonzero: column = the B entry it gathers, value = the A entry. */ + double A[4] = {1.0, 3.0, 2.0, 4.0}; + expr *A_param = new_parameter(2, 2, PARAM_FIXED, 4, A); + expr *B = new_variable(2, 2, 0, 4); + expr *Z = new_kron(A_param, B, 1, 2, 2, 2, 2); + + double u[4] = {5.0, 7.0, 6.0, 8.0}; + Z->forward(Z, u); + jacobian_init(Z); + Z->eval_jacobian(Z); + + mu_assert("kron J rows", Z->jacobian->m == 16); + mu_assert("kron J cols", Z->jacobian->n == 4); + mu_assert("kron J nnz", Z->jacobian->nnz == 16); + + int expected_p[17] = {0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + int expected_i[16] = {0, 1, 0, 1, 2, 3, 2, 3, 0, 1, 0, 1, 2, 3, 2, 3}; + double expected_x[16] = {1, 1, 3, 3, 1, 1, 3, 3, 2, 2, 4, 4, 2, 2, 4, 4}; + + mu_assert("kron const-left J sparsity", + cmp_sparsity(Z->jacobian, expected_p, expected_i, 16, 16)); + mu_assert("kron const-left J values", + cmp_values(Z->jacobian, expected_x, 16)); + + free_expr(Z); + return 0; +} + +const char *test_jacobian_kron_const_right(void) +{ + /* Z = kron(A, [[1,2],[3,4]]), A a 2x2 leaf variable (const_is_left = 0). */ + double B[4] = {1.0, 3.0, 2.0, 4.0}; + expr *B_param = new_parameter(2, 2, PARAM_FIXED, 4, B); + expr *A = new_variable(2, 2, 0, 4); + expr *Z = new_kron(B_param, A, 0, 2, 2, 2, 2); + + double u[4] = {5.0, 7.0, 6.0, 8.0}; + Z->forward(Z, u); + jacobian_init(Z); + Z->eval_jacobian(Z); + + mu_assert("kron J rows", Z->jacobian->m == 16); + mu_assert("kron J nnz", Z->jacobian->nnz == 16); + + int expected_p[17] = {0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + int expected_i[16] = {0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 3, 3, 2, 2, 3, 3}; + double expected_x[16] = {1, 3, 1, 3, 2, 4, 2, 4, 1, 3, 1, 3, 2, 4, 2, 4}; + + mu_assert("kron const-right J sparsity", + cmp_sparsity(Z->jacobian, expected_p, expected_i, 16, 16)); + mu_assert("kron const-right J values", + cmp_values(Z->jacobian, expected_x, 16)); + + free_expr(Z); + return 0; +} + +const char *test_jacobian_kron_composite(void) +{ + /* Z = kron([[1,2],[3,4]], log(X)) — composite variable operand; check the + * gathered/scaled Jacobian against finite differences. */ + double A[4] = {1.0, 3.0, 2.0, 4.0}; + double x_vals[4] = {1.0, 2.0, 3.0, 4.0}; + + expr *A_param = new_parameter(2, 2, PARAM_FIXED, 4, A); + expr *X = new_variable(2, 2, 0, 4); + expr *log_X = new_log(X); + expr *Z = new_kron(A_param, log_X, 1, 2, 2, 2, 2); + + mu_assert("kron composite Jacobian check failed", + check_jacobian_num(Z, x_vals, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + return 0; +} diff --git a/tests/wsum_hess/affine/test_kron.h b/tests/wsum_hess/affine/test_kron.h new file mode 100644 index 0000000..cf0fa2c --- /dev/null +++ b/tests/wsum_hess/affine/test_kron.h @@ -0,0 +1,53 @@ +#include + +#include "atoms/affine.h" +#include "atoms/elementwise_restricted_dom.h" +#include "expr.h" +#include "minunit.h" +#include "numerical_diff.h" +#include "subexpr.h" +#include "test_helpers.h" + +const char *test_wsum_hess_kron(void) +{ + /* kron(A, B) is linear in a leaf variable B, so its weighted Hessian is + * zero for any weights. */ + double A[4] = {1.0, 3.0, 2.0, 4.0}; + expr *A_param = new_parameter(2, 2, PARAM_FIXED, 4, A); + expr *B = new_variable(2, 2, 0, 4); + expr *Z = new_kron(A_param, B, 1, 2, 2, 2, 2); + + double u[4] = {5.0, 7.0, 6.0, 8.0}; + double w[16] = {1, -1, 2, 3, -2, 1, 0, 2, -1, 1, 3, -3, 2, 1, -2, 1}; + + Z->forward(Z, u); + jacobian_init(Z); + wsum_hess_init(Z); + Z->eval_wsum_hess(Z, w); + + mu_assert("kron wsum_hess square", Z->wsum_hess->m == 4 && Z->wsum_hess->n == 4); + mu_assert("kron wsum_hess zero for linear arg", Z->wsum_hess->nnz == 0); + + free_expr(Z); + return 0; +} + +const char *test_wsum_hess_kron_composite(void) +{ + /* Z = kron([[1,2],[3,4]], log(X)) — nonlinear in X; backprop of weights + * through the linear gather must match a numerical second-derivative. */ + double A[4] = {1.0, 3.0, 2.0, 4.0}; + double x_vals[4] = {1.0, 2.0, 3.0, 4.0}; + double w[16] = {1, -1, 2, 3, -2, 1, 0, 2, -1, 1, 3, -3, 2, 1, -2, 1}; + + expr *A_param = new_parameter(2, 2, PARAM_FIXED, 4, A); + expr *X = new_variable(2, 2, 0, 4); + expr *log_X = new_log(X); + expr *Z = new_kron(A_param, log_X, 1, 2, 2, 2, 2); + + mu_assert("kron composite wsum_hess check failed", + check_wsum_hess(Z, x_vals, w, NUMERICAL_DIFF_DEFAULT_H)); + + free_expr(Z); + return 0; +} From 0168aee72b28ac49acef8570379f9b4b0714ad8a Mon Sep 17 00:00:00 2001 From: Transurgeon Date: Sun, 28 Jun 2026 19:18:56 -0400 Subject: [PATCH 2/2] Format kron atom and mirror left_matmul structure Apply clang-format and tidy the native kron atom: factor parameter refresh into a helper, null freed pointers, and trim doc duplicated by the kron_expr definition in subexpr.h. Co-Authored-By: Claude Opus 4.8 (1M context) --- include/atoms/affine.h | 4 ++-- src/atoms/affine/kron.c | 32 ++++++++++++++++--------- tests/all_tests.c | 6 ++--- tests/forward_pass/affine/test_kron.h | 7 +++--- tests/jacobian_tests/affine/test_kron.h | 12 ++++------ 5 files changed, 33 insertions(+), 28 deletions(-) diff --git a/include/atoms/affine.h b/include/atoms/affine.h index 831b7ca..7c49411 100644 --- a/include/atoms/affine.h +++ b/include/atoms/affine.h @@ -85,7 +85,7 @@ expr *new_convolve(expr *param_node, expr *child); variables and is passed as child. const_is_left selects which operand is the parameter: 1 -> A=param_node, B=child; 0 -> A=child, B=param_node. (p, q) are A's dims and (r, s) are B's dims. */ -expr *new_kron(expr *param_node, expr *child, int const_is_left, int p, int q, - int r, int s); +expr *new_kron(expr *param_node, expr *child, int const_is_left, int p, int q, int r, + int s); #endif /* AFFINE_H */ diff --git a/src/atoms/affine/kron.c b/src/atoms/affine/kron.c index ddc4562..56b5089 100644 --- a/src/atoms/affine/kron.c +++ b/src/atoms/affine/kron.c @@ -24,8 +24,8 @@ #include #include -/* Kronecker product Z = kron(A, B), where exactly one operand is variable-free - * (param_source) and the other (child = node->left) carries the variables. +/* Kronecker product Z = kron(A, B). See the kron_expr definition in subexpr.h + * for the operand layout; the index math behind the scaled gather is below. * * With column-major (Fortran) flattening, an output index OUT = I + J*(p*r) * decomposes as I = i*r + k and J = j*s + l (i in [0,p), k in [0,r), j in [0,q), @@ -40,18 +40,24 @@ * forward, Jacobian and (affine) Hessian are all scaled gathers -- no * size_out x size_child coefficient matrix and no sparse matmul. */ +/* Pull current parameter values through any broadcast/promote wrappers. */ +static void refresh_param_values(kron_expr *knode) +{ + if (!knode->base.needs_parameter_refresh) + { + return; + } + + knode->param_source->forward(knode->param_source, NULL); + knode->base.needs_parameter_refresh = false; +} + static void forward(expr *node, const double *u) { expr *child = node->left; kron_expr *knode = (kron_expr *) node; - /* Pull current parameter values through any broadcast/promote wrappers. */ - if (knode->base.needs_parameter_refresh) - { - knode->param_source->forward(knode->param_source, NULL); - knode->base.needs_parameter_refresh = false; - } - + refresh_param_values(knode); child->forward(child, u); const double *a = knode->param_source->value; @@ -163,10 +169,14 @@ static void free_type_data(expr *node) sp_free(knode->child_row); sp_free(knode->coeff_idx); free_expr(knode->param_source); + + knode->child_row = NULL; + knode->coeff_idx = NULL; + knode->param_source = NULL; } -expr *new_kron(expr *param_node, expr *child, int const_is_left, int p, int q, - int r, int s) +expr *new_kron(expr *param_node, expr *child, int const_is_left, int p, int q, int r, + int s) { int d1 = p * r; int d2 = q * s; diff --git a/tests/all_tests.c b/tests/all_tests.c index 9c1effc..ee3474b 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -8,9 +8,9 @@ #include "forward_pass/affine/test_add.h" #include "forward_pass/affine/test_broadcast.h" #include "forward_pass/affine/test_convolve.h" -#include "forward_pass/affine/test_kron.h" #include "forward_pass/affine/test_diag_mat.h" #include "forward_pass/affine/test_hstack.h" +#include "forward_pass/affine/test_kron.h" #include "forward_pass/affine/test_left_matmul_dense.h" #include "forward_pass/affine/test_linear_op.h" #include "forward_pass/affine/test_neg.h" @@ -28,10 +28,10 @@ #include "forward_pass/other/test_prod_axis_zero.h" #include "jacobian_tests/affine/test_broadcast.h" #include "jacobian_tests/affine/test_convolve.h" -#include "jacobian_tests/affine/test_kron.h" #include "jacobian_tests/affine/test_diag_mat.h" #include "jacobian_tests/affine/test_hstack.h" #include "jacobian_tests/affine/test_index.h" +#include "jacobian_tests/affine/test_kron.h" #include "jacobian_tests/affine/test_left_matmul.h" #include "jacobian_tests/affine/test_neg.h" #include "jacobian_tests/affine/test_promote.h" @@ -75,10 +75,10 @@ #include "utils/test_stacked_pd.h" #include "wsum_hess/affine/test_broadcast.h" #include "wsum_hess/affine/test_convolve.h" -#include "wsum_hess/affine/test_kron.h" #include "wsum_hess/affine/test_diag_mat.h" #include "wsum_hess/affine/test_hstack.h" #include "wsum_hess/affine/test_index.h" +#include "wsum_hess/affine/test_kron.h" #include "wsum_hess/affine/test_left_matmul.h" #include "wsum_hess/affine/test_right_matmul.h" #include "wsum_hess/affine/test_scalar_mult.h" diff --git a/tests/forward_pass/affine/test_kron.h b/tests/forward_pass/affine/test_kron.h index 3a04df3..e38c725 100644 --- a/tests/forward_pass/affine/test_kron.h +++ b/tests/forward_pass/affine/test_kron.h @@ -42,8 +42,8 @@ const char *test_kron_forward_const_right(void) double u[4] = {5.0, 7.0, 6.0, 8.0}; /* col-major [[5,6],[7,8]] */ Z->forward(Z, u); - double expected[16] = {5, 15, 7, 21, 10, 20, 14, 28, - 6, 18, 8, 24, 12, 24, 16, 32}; + double expected[16] = {5, 15, 7, 21, 10, 20, 14, 28, + 6, 18, 8, 24, 12, 24, 16, 32}; mu_assert("kron const-right d1=4", Z->d1 == 4); mu_assert("kron const-right d2=4", Z->d2 == 4); @@ -70,8 +70,7 @@ const char *test_kron_forward_scalar(void) mu_assert("kron scalar d1=2", Z->d1 == 2); mu_assert("kron scalar d2=2", Z->d2 == 2); - mu_assert("kron scalar forward failed", - cmp_double_array(Z->value, expected, 4)); + mu_assert("kron scalar forward failed", cmp_double_array(Z->value, expected, 4)); free_expr(Z); return 0; diff --git a/tests/jacobian_tests/affine/test_kron.h b/tests/jacobian_tests/affine/test_kron.h index 114b11f..1170ae3 100644 --- a/tests/jacobian_tests/affine/test_kron.h +++ b/tests/jacobian_tests/affine/test_kron.h @@ -27,15 +27,13 @@ const char *test_jacobian_kron_const_left(void) mu_assert("kron J cols", Z->jacobian->n == 4); mu_assert("kron J nnz", Z->jacobian->nnz == 16); - int expected_p[17] = {0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16}; + int expected_p[17] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; int expected_i[16] = {0, 1, 0, 1, 2, 3, 2, 3, 0, 1, 0, 1, 2, 3, 2, 3}; double expected_x[16] = {1, 1, 3, 3, 1, 1, 3, 3, 2, 2, 4, 4, 2, 2, 4, 4}; mu_assert("kron const-left J sparsity", cmp_sparsity(Z->jacobian, expected_p, expected_i, 16, 16)); - mu_assert("kron const-left J values", - cmp_values(Z->jacobian, expected_x, 16)); + mu_assert("kron const-left J values", cmp_values(Z->jacobian, expected_x, 16)); free_expr(Z); return 0; @@ -57,15 +55,13 @@ const char *test_jacobian_kron_const_right(void) mu_assert("kron J rows", Z->jacobian->m == 16); mu_assert("kron J nnz", Z->jacobian->nnz == 16); - int expected_p[17] = {0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16}; + int expected_p[17] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; int expected_i[16] = {0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 3, 3, 2, 2, 3, 3}; double expected_x[16] = {1, 3, 1, 3, 2, 4, 2, 4, 1, 3, 1, 3, 2, 4, 2, 4}; mu_assert("kron const-right J sparsity", cmp_sparsity(Z->jacobian, expected_p, expected_i, 16, 16)); - mu_assert("kron const-right J values", - cmp_values(Z->jacobian, expected_x, 16)); + mu_assert("kron const-right J values", cmp_values(Z->jacobian, expected_x, 16)); free_expr(Z); return 0;