Skip to content

Commit 0168aee

Browse files
Transurgeonclaude
andcommitted
Format kron atom and mirror left_matmul structure
Apply clang-format and tidy the native kron atom: factor parameter refresh into a helper, null freed pointers, and trim doc duplicated by the kron_expr definition in subexpr.h. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 339a4cf commit 0168aee

5 files changed

Lines changed: 33 additions & 28 deletions

File tree

include/atoms/affine.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ expr *new_convolve(expr *param_node, expr *child);
8585
variables and is passed as child. const_is_left selects which operand is the
8686
parameter: 1 -> A=param_node, B=child; 0 -> A=child, B=param_node. (p, q) are
8787
A's dims and (r, s) are B's dims. */
88-
expr *new_kron(expr *param_node, expr *child, int const_is_left, int p, int q,
89-
int r, int s);
88+
expr *new_kron(expr *param_node, expr *child, int const_is_left, int p, int q, int r,
89+
int s);
9090

9191
#endif /* AFFINE_H */

src/atoms/affine/kron.c

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
#include <stdlib.h>
2525
#include <string.h>
2626

27-
/* Kronecker product Z = kron(A, B), where exactly one operand is variable-free
28-
* (param_source) and the other (child = node->left) carries the variables.
27+
/* Kronecker product Z = kron(A, B). See the kron_expr definition in subexpr.h
28+
* for the operand layout; the index math behind the scaled gather is below.
2929
*
3030
* With column-major (Fortran) flattening, an output index OUT = I + J*(p*r)
3131
* decomposes as I = i*r + k and J = j*s + l (i in [0,p), k in [0,r), j in [0,q),
@@ -40,18 +40,24 @@
4040
* forward, Jacobian and (affine) Hessian are all scaled gathers -- no
4141
* size_out x size_child coefficient matrix and no sparse matmul. */
4242

43+
/* Pull current parameter values through any broadcast/promote wrappers. */
44+
static void refresh_param_values(kron_expr *knode)
45+
{
46+
if (!knode->base.needs_parameter_refresh)
47+
{
48+
return;
49+
}
50+
51+
knode->param_source->forward(knode->param_source, NULL);
52+
knode->base.needs_parameter_refresh = false;
53+
}
54+
4355
static void forward(expr *node, const double *u)
4456
{
4557
expr *child = node->left;
4658
kron_expr *knode = (kron_expr *) node;
4759

48-
/* Pull current parameter values through any broadcast/promote wrappers. */
49-
if (knode->base.needs_parameter_refresh)
50-
{
51-
knode->param_source->forward(knode->param_source, NULL);
52-
knode->base.needs_parameter_refresh = false;
53-
}
54-
60+
refresh_param_values(knode);
5561
child->forward(child, u);
5662

5763
const double *a = knode->param_source->value;
@@ -163,10 +169,14 @@ static void free_type_data(expr *node)
163169
sp_free(knode->child_row);
164170
sp_free(knode->coeff_idx);
165171
free_expr(knode->param_source);
172+
173+
knode->child_row = NULL;
174+
knode->coeff_idx = NULL;
175+
knode->param_source = NULL;
166176
}
167177

168-
expr *new_kron(expr *param_node, expr *child, int const_is_left, int p, int q,
169-
int r, int s)
178+
expr *new_kron(expr *param_node, expr *child, int const_is_left, int p, int q, int r,
179+
int s)
170180
{
171181
int d1 = p * r;
172182
int d2 = q * s;

tests/all_tests.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
#include "forward_pass/affine/test_add.h"
99
#include "forward_pass/affine/test_broadcast.h"
1010
#include "forward_pass/affine/test_convolve.h"
11-
#include "forward_pass/affine/test_kron.h"
1211
#include "forward_pass/affine/test_diag_mat.h"
1312
#include "forward_pass/affine/test_hstack.h"
13+
#include "forward_pass/affine/test_kron.h"
1414
#include "forward_pass/affine/test_left_matmul_dense.h"
1515
#include "forward_pass/affine/test_linear_op.h"
1616
#include "forward_pass/affine/test_neg.h"
@@ -28,10 +28,10 @@
2828
#include "forward_pass/other/test_prod_axis_zero.h"
2929
#include "jacobian_tests/affine/test_broadcast.h"
3030
#include "jacobian_tests/affine/test_convolve.h"
31-
#include "jacobian_tests/affine/test_kron.h"
3231
#include "jacobian_tests/affine/test_diag_mat.h"
3332
#include "jacobian_tests/affine/test_hstack.h"
3433
#include "jacobian_tests/affine/test_index.h"
34+
#include "jacobian_tests/affine/test_kron.h"
3535
#include "jacobian_tests/affine/test_left_matmul.h"
3636
#include "jacobian_tests/affine/test_neg.h"
3737
#include "jacobian_tests/affine/test_promote.h"
@@ -75,10 +75,10 @@
7575
#include "utils/test_stacked_pd.h"
7676
#include "wsum_hess/affine/test_broadcast.h"
7777
#include "wsum_hess/affine/test_convolve.h"
78-
#include "wsum_hess/affine/test_kron.h"
7978
#include "wsum_hess/affine/test_diag_mat.h"
8079
#include "wsum_hess/affine/test_hstack.h"
8180
#include "wsum_hess/affine/test_index.h"
81+
#include "wsum_hess/affine/test_kron.h"
8282
#include "wsum_hess/affine/test_left_matmul.h"
8383
#include "wsum_hess/affine/test_right_matmul.h"
8484
#include "wsum_hess/affine/test_scalar_mult.h"

tests/forward_pass/affine/test_kron.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ const char *test_kron_forward_const_right(void)
4242
double u[4] = {5.0, 7.0, 6.0, 8.0}; /* col-major [[5,6],[7,8]] */
4343
Z->forward(Z, u);
4444

45-
double expected[16] = {5, 15, 7, 21, 10, 20, 14, 28,
46-
6, 18, 8, 24, 12, 24, 16, 32};
45+
double expected[16] = {5, 15, 7, 21, 10, 20, 14, 28,
46+
6, 18, 8, 24, 12, 24, 16, 32};
4747

4848
mu_assert("kron const-right d1=4", Z->d1 == 4);
4949
mu_assert("kron const-right d2=4", Z->d2 == 4);
@@ -70,8 +70,7 @@ const char *test_kron_forward_scalar(void)
7070

7171
mu_assert("kron scalar d1=2", Z->d1 == 2);
7272
mu_assert("kron scalar d2=2", Z->d2 == 2);
73-
mu_assert("kron scalar forward failed",
74-
cmp_double_array(Z->value, expected, 4));
73+
mu_assert("kron scalar forward failed", cmp_double_array(Z->value, expected, 4));
7574

7675
free_expr(Z);
7776
return 0;

tests/jacobian_tests/affine/test_kron.h

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,13 @@ const char *test_jacobian_kron_const_left(void)
2727
mu_assert("kron J cols", Z->jacobian->n == 4);
2828
mu_assert("kron J nnz", Z->jacobian->nnz == 16);
2929

30-
int expected_p[17] = {0, 1, 2, 3, 4, 5, 6, 7, 8,
31-
9, 10, 11, 12, 13, 14, 15, 16};
30+
int expected_p[17] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
3231
int expected_i[16] = {0, 1, 0, 1, 2, 3, 2, 3, 0, 1, 0, 1, 2, 3, 2, 3};
3332
double expected_x[16] = {1, 1, 3, 3, 1, 1, 3, 3, 2, 2, 4, 4, 2, 2, 4, 4};
3433

3534
mu_assert("kron const-left J sparsity",
3635
cmp_sparsity(Z->jacobian, expected_p, expected_i, 16, 16));
37-
mu_assert("kron const-left J values",
38-
cmp_values(Z->jacobian, expected_x, 16));
36+
mu_assert("kron const-left J values", cmp_values(Z->jacobian, expected_x, 16));
3937

4038
free_expr(Z);
4139
return 0;
@@ -57,15 +55,13 @@ const char *test_jacobian_kron_const_right(void)
5755
mu_assert("kron J rows", Z->jacobian->m == 16);
5856
mu_assert("kron J nnz", Z->jacobian->nnz == 16);
5957

60-
int expected_p[17] = {0, 1, 2, 3, 4, 5, 6, 7, 8,
61-
9, 10, 11, 12, 13, 14, 15, 16};
58+
int expected_p[17] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
6259
int expected_i[16] = {0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 3, 3, 2, 2, 3, 3};
6360
double expected_x[16] = {1, 3, 1, 3, 2, 4, 2, 4, 1, 3, 1, 3, 2, 4, 2, 4};
6461

6562
mu_assert("kron const-right J sparsity",
6663
cmp_sparsity(Z->jacobian, expected_p, expected_i, 16, 16));
67-
mu_assert("kron const-right J values",
68-
cmp_values(Z->jacobian, expected_x, 16));
64+
mu_assert("kron const-right J values", cmp_values(Z->jacobian, expected_x, 16));
6965

7066
free_expr(Z);
7167
return 0;

0 commit comments

Comments
 (0)