Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ cc_library(
deps = [
":allocator",
":basics",
":gilbert",
":mat",
":threading",
":threading_context",
Expand All @@ -359,6 +360,7 @@ cc_library(
deps = [
":allocator",
":basics",
":gilbert",
":mat",
":matmul_env",
":threading",
Expand Down Expand Up @@ -521,6 +523,12 @@ cc_test(
],
)

# Generalized Hilbert ("gilbert") curve: maps a 1D curve index to 2D (x, y)
# coordinates. Standalone target; no deps are listed.
cc_library(
name = "gilbert",
srcs = ["ops/gilbert.cc"],
hdrs = ["ops/gilbert.h"],
)

cc_test(
name = "bench_matmul",
size = "small",
Expand Down
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ set(SOURCES
io/io.cc
io/io.h
ops/dot-inl.h
ops/gilbert.cc
ops/gilbert.h
ops/matmul_static_bf16.cc
ops/matmul_static_f32.cc
ops/matmul_static_nuq.cc
Expand Down
117 changes: 117 additions & 0 deletions ops/gilbert.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright 2026 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Adapted from code by abetusk (BSD-2-Clause) in
// https://github.com/jakubcerveny/gilbert.

#include "ops/gilbert.h"

#include <math.h>
#include <stdlib.h>

namespace gcpp {

namespace {

// Returns the sign of `x`: -1 if negative, +1 if positive, 0 otherwise.
int sgn(int x) { return (x > 0) - (x < 0); }

// Recursive worker for `gilbert_d2xy`.
//
// The current sub-rectangle starts at (`*xres`, `*yres`) and is spanned by two
// axis-aligned vectors: the major axis (`ax`, `ay`) and the orthogonal axis
// (`bx`, `by`). `cur_idx` is the curve index of the rectangle's first cell and
// `dst_idx` is the index being looked up. On return, `*xres`/`*yres` hold the
// coordinates of `dst_idx`. Always returns 0.
//
// Each level splits the rectangle into two ("long" case) or three ("standard"
// case) sub-rectangles and descends into the one whose index range contains
// `dst_idx`; if none of the earlier ranges match, it descends into the last.
int gilbert_d2xy_r(int dst_idx, int cur_idx, int* xres, int* yres, int ax,
                   int ay, int bx, int by) {
  // One component of each vector is zero, so the sum is the signed extent.
  const int w = abs(ax + ay);
  const int h = abs(bx + by);

  const int x = *xres;
  const int y = *yres;

  // Unit steps along the major and orthogonal directions.
  const int dax = sgn(ax);
  const int day = sgn(ay);
  const int dbx = sgn(bx);
  const int dby = sgn(by);

  const int di = dst_idx - cur_idx;

  // Base cases: a single row or column is walked linearly.
  if (h == 1) {
    *xres = x + dax * di;
    *yres = y + day * di;
    return 0;
  }
  if (w == 1) {
    *xres = x + dbx * di;
    *yres = y + dby * di;
    return 0;
  }

  // Floor-halve both vectors. The operands may be negative, so this relies on
  // `>>` being an arithmetic shift (guaranteed since C++20,
  // implementation-defined before that).
  int ax2 = ax >> 1;
  int ay2 = ay >> 1;
  int bx2 = bx >> 1;
  int by2 = by >> 1;
  const int w2 = abs(ax2 + ay2);
  const int h2 = abs(bx2 + by2);

  if ((2 * w) > (3 * h)) {
    // Long case: split only along the major axis, into two parts.
    if ((w2 & 1) && (w > 2)) {
      // Make the first part's extent even.
      ax2 += dax;
      ay2 += day;
    }
    const int split_idx = cur_idx + abs((ax2 + ay2) * (bx + by));
    if ((cur_idx <= dst_idx) && (dst_idx < split_idx)) {
      *xres = x;
      *yres = y;
      return gilbert_d2xy_r(dst_idx, cur_idx, xres, yres, ax2, ay2, bx, by);
    }
    *xres = x + ax2;
    *yres = y + ay2;
    return gilbert_d2xy_r(dst_idx, split_idx, xres, yres, ax - ax2, ay - ay2,
                          bx, by);
  }

  // Standard case: three parts.
  if ((h2 & 1) && (h > 2)) {
    // Make the first part's extent even.
    bx2 += dbx;
    by2 += dby;
  }

  // Part 1: half-size rectangle at the origin with the axes swapped.
  int nxt_idx = cur_idx + abs((bx2 + by2) * (ax2 + ay2));
  if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) {
    *xres = x;
    *yres = y;
    return gilbert_d2xy_r(dst_idx, cur_idx, xres, yres, bx2, by2, ax2, ay2);
  }
  cur_idx = nxt_idx;

  // Part 2: full-width strip covering the remainder of the orthogonal axis.
  nxt_idx = cur_idx + abs((ax + ay) * ((bx - bx2) + (by - by2)));
  if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) {
    *xres = x + bx2;
    *yres = y + by2;
    return gilbert_d2xy_r(dst_idx, cur_idx, xres, yres, ax, ay, bx - bx2,
                          by - by2);
  }
  cur_idx = nxt_idx;

  // Part 3: the remaining rectangle, started from its far corner and walked
  // with both axes negated and swapped.
  *xres = x + (ax - dax) + (bx2 - dbx);
  *yres = y + (ay - day) + (by2 - dby);
  return gilbert_d2xy_r(dst_idx, cur_idx, xres, yres, -bx2, -by2, -(ax - ax2),
                        -(ay - ay2));
}

}  // namespace

// Maps index `idx` along a generalized Hilbert ("gilbert") curve covering a
// `w` x `h` grid to 2D coordinates, written to `*x` and `*y`.
//
// Returns 0 on success. Returns -1 without recursing if `x` or `y` is null,
// if `w` or `h` is not positive, or if `idx` is outside [0, w * h); in the
// latter two cases `*x` and `*y` are set to 0. Without these checks an empty
// grid would recurse forever and an out-of-range index would silently yield
// coordinates outside the grid.
//
// NOTE(review): `w * h` is assumed to fit in `int`; the recursion computes
// sub-rectangle areas in `int`.
int gilbert_d2xy(int* x, int* y, int idx, int w, int h) {
  if (x == nullptr || y == nullptr) {
    return -1;
  }
  *x = 0;
  *y = 0;
  if (w <= 0 || h <= 0) {
    return -1;
  }
  // 64-bit product so the range check itself cannot overflow.
  const long long num_cells = static_cast<long long>(w) * h;
  if (idx < 0 || static_cast<long long>(idx) >= num_cells) {
    return -1;
  }
  // The longer side becomes the major axis.
  if (w >= h) {
    return gilbert_d2xy_r(idx, 0, x, y, w, 0, 0, h);
  }
  return gilbert_d2xy_r(idx, 0, x, y, 0, h, w, 0);
}

}  // namespace gcpp
17 changes: 17 additions & 0 deletions ops/gilbert.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright 2026 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Adapted from code by abetusk (BSD-2-Clause) in
// https://github.com/jakubcerveny/gilbert.

#ifndef THIRD_PARTY_GEMMA_CPP_OPS_GILBERT_H_
#define THIRD_PARTY_GEMMA_CPP_OPS_GILBERT_H_

namespace gcpp {

// Maps the 1D index `idx` along a generalized Hilbert ("gilbert") curve
// covering a `w` x `h` grid to 2D coordinates, which are written to `*x` and
// `*y`. `x` and `y` must be non-null.
// NOTE(review): `idx` is presumably expected to lie in [0, w * h) with
// `w` > 0 and `h` > 0 - confirm against callers.
// Returns 0 on success.
int gilbert_d2xy(int* x, int* y, int idx, int w, int h);

} // namespace gcpp

#endif // THIRD_PARTY_GEMMA_CPP_OPS_GILBERT_H_
65 changes: 65 additions & 0 deletions ops/matmul-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,71 @@ class MMLoops {
}
});
}

// MMOrderSFC: hands every (mc, nc) block pair from `args.ranges_mc` x
// `args.ranges_nc` to `parallel.ForRangesSFC` (SFC presumably = space-filling
// curve) and computes each block with a single kernel call over all of K.
template <typename TB, typename TC, class Parallel>
static HWY_INLINE void Loop(MMOrderSFC, Parallel parallel,
                            const StridedViewBF A, const MatPtrT<TB>& B,
                            const MatPtrT<TB>* B2, RowPtrs<TC> C,
                            const MMArgs& args) {
  const auto profiler_zone = args.env.ctx.profiler_zones.Get(Zones::kMMSFC);
  // This order requires K to be covered by exactly one range.
  HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
  const IndexRange& all_k = args.ranges_kc.Range(0);

  // Work for one (rows, cols) block of C, executed by `worker`.
  const auto per_block = [&](const IndexRange& rows, const IndexRange& cols,
                             size_t worker) HWY_ATTR {
    MMZone mmz;
    mmz.MaybeEnter(worker, profiler_zone, args.env, &args.autotune);

    // A * B for this block goes into the matching view of C.
    MMKernel::B3A2C0(A, B, rows, all_k, cols, args, MMSetC(),
                     C.View(rows.begin(), cols.begin(), cols.Num()));

    // Per-worker tile that receives the optional A * B2 product.
    const StridedViewBF C2 =
        args.env.C_tiles.C(Extents2D(rows.Num(), cols.Num()), worker);
    if (B2 != nullptr) {
      MMKernel::B3A2C0(A, *B2, rows, all_k, cols, args, MMSetC(), C2);
    }

    // The optional user callback is only invoked for BF16 output.
    if constexpr (IsBF16<TC>()) {
      args.options.MaybeCallFunc(C, rows, cols, C2, worker);
    }
  };

  parallel.ForRangesSFC(args.env.ctx, args.ranges_mc, args.ranges_nc,
                        args.options.cluster_idx, per_block);
}

// MMOrderSFC_K: hands every (mc, nc) block pair from `args.ranges_mc` x
// `args.ranges_nc` to `parallel.ForRangesSFC` (SFC presumably = space-filling
// curve); within each block, `ForeachKC` walks all ranges in `args.ranges_kc`.
template <typename TB, typename TC, class Parallel>
static HWY_INLINE void Loop(MMOrderSFC_K, Parallel parallel,
                            const StridedViewBF A, const MatPtrT<TB>& B,
                            const MatPtrT<TB>* B2, RowPtrs<TC> C,
                            const MMArgs& args) {
  const auto profiler_zone = args.env.ctx.profiler_zones.Get(Zones::kMMSFC_K);

  // Work for one (rows, cols) block of C, executed by `worker`.
  const auto per_block = [&](const IndexRange& rows, const IndexRange& cols,
                             size_t worker) HWY_ATTR {
    MMZone mmz;
    mmz.MaybeEnter(worker, profiler_zone, args.env, &args.autotune);

    // A * B for this block goes into the matching view of C.
    MMKernel::ForeachKC(A, B, rows, args.ranges_kc, cols, args,
                        C.View(rows.begin(), cols.begin(), cols.Num()));

    // Per-worker tile that receives the optional A * B2 product.
    const StridedViewBF C2 =
        args.env.C_tiles.C(Extents2D(rows.Num(), cols.Num()), worker);
    if (B2 != nullptr) {
      MMKernel::ForeachKC(A, *B2, rows, args.ranges_kc, cols, args, C2);
    }

    // The optional user callback is only invoked for BF16 output.
    if constexpr (IsBF16<TC>()) {
      args.options.MaybeCallFunc(C, rows, cols, C2, worker);
    }
  };

  parallel.ForRangesSFC(args.env.ctx, args.ranges_mc, args.ranges_nc,
                        args.options.cluster_idx, per_block);
}
}; // MMLoops

// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
Expand Down
Loading
Loading