Skip to content

Guard scatter_axis against 64-bit outputs on the GPU#3695

Open
obchain wants to merge 1 commit into
ml-explore:mainfrom
obchain:fix/scatter-axis-64bit-guard
Open

Guard scatter_axis against 64-bit outputs on the GPU#3695
obchain wants to merge 1 commit into
ml-explore:mainfrom
obchain:fix/scatter-axis-64bit-guard

Conversation

@obchain

@obchain obchain commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

Proposed changes

Fixes #3690.

mx.put_along_axis / scatter_add_axis with a 64-bit element dtype (int64/uint64) fail the Metal library JIT build instead of raising a clean "unsupported dtype" error. In mlx/backend/metal/kernels/atomic.h, packing_size<T> = sizeof(uint)/sizeof(T) is 0 for 8-byte T, so uint_or_packed<T> declares a zero-length array (hard C++ error) and offset / packing_size<T> divides by zero. The whole mlx-metallib build then fails.

The plain Scatter path already guards 8-byte outputs on the GPU — in scatter() (ops.cpp) and Scatter::eval_gpu (indexing.cpp) — but ScatterAxis was missing the equivalent guard. This adds it in both places, mirroring Scatter:

  • scatter_axis() in mlx/ops.cpp — raises on GPU for 8-byte dtypes, matching the existing scatter() guard. CPU is unaffected.
  • ScatterAxis::eval_gpu in mlx/backend/metal/indexing.cpp — same guard as Scatter::eval_gpu.

Repro before the fix (Metal device):

import mlx.core as mx
mx.set_default_device(mx.gpu)
x   = mx.zeros((4, 8), dtype=mx.int64)
idx = mx.array([[0],[1],[2],[3]])
upd = mx.ones((4, 1), dtype=mx.int64)
mx.eval(mx.put_along_axis(x, idx, upd, axis=1))  # was: Metal JIT build failure

After the fix this raises a clean ValueError on the GPU, and continues to work on the CPU.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

put_along_axis / scatter_add_axis with int64/uint64 values failed the
Metal library JIT build instead of raising a clean error: packing_size<T>
in atomic.h is sizeof(uint)/sizeof(T) == 0 for 8-byte T, producing a
zero-length array and a divide-by-zero in the fallback atomic union.

The plain Scatter path already guards this in scatter() and
Scatter::eval_gpu; ScatterAxis had no equivalent guard. Mirror it in
scatter_axis() (GPU only, matching scatter()) and ScatterAxis::eval_gpu,
and add a test.
Comment thread mlx/ops.cpp
}

// TODO, remove when scatter_axis supports 64-bit outputs
if (to_stream(s).device == Device::gpu && size_of(a.dtype()) == 8) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check in metal/indexing.cpp along is enough, and it actually works in the cuda backend, the test should also be updated for metal only

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] int64/uint64 put_along_axis / scatter_add_axis crashes the Metal JIT build

2 participants