Skip to content

fix _reduce_split_kernel for triton 3.5.1#4696

Open
irexyc wants to merge 1 commit into
InternLM:mainfrom
irexyc:fix-reduce
Open

fix _reduce_split_kernel for triton 3.5.1#4696
irexyc wants to merge 1 commit into
InternLM:mainfrom
irexyc:fix-reduce

Conversation

@irexyc

@irexyc irexyc commented Jun 22, 2026

Copy link
Copy Markdown
Collaborator

Motivation

fix mask for acc_k

reproduce code with triton 3.5.1

#!/usr/bin/env python3
import torch
import triton
import triton.language as tl


SPLIT_K = 128
D = 128
STRIDE = D + 2
TOKENS = 1
HEADS = 32


@triton.jit
def bad_inline_mask(
    acc_ptr,
    out_ptr,
    stride_ak,
    stride_abs,
    stride_ah,
    stride_ad,
    stride_obs,
    stride_oh,
    stride_od,
    SPLIT_K: tl.constexpr,
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    cur_batch = tl.program_id(1)
    cur_head = tl.program_id(0)
    offs_k = tl.arange(0, SPLIT_K)
    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < D

    offs_v = cur_batch * stride_abs + cur_head * stride_ah + offs_k[:, None] * stride_ak + offs_d[None, :] * stride_ad
    offs_m = cur_batch * stride_abs + cur_head * stride_ah + stride_ak * offs_k + D

    m = tl.load(acc_ptr + offs_m)
    l = tl.load(acc_ptr + offs_m + 1)
    v = tl.load(
        acc_ptr + offs_v,
        mask=mask_d[None, :] & (m[:, None] > -float('inf')),
        other=0.0,
    )

    m_max = tl.max(m, 0)
    alpha = tl.exp2(m - m_max)
    v = v * alpha[:, None]
    l = l * alpha
    out = tl.sum(v, 0)
    l_sum = tl.sum(l, 0)
    out = out / (l_sum + 1e-10)
    tl.store(out_ptr + cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od, out, mask=mask_d)


@triton.jit
def good_materialized_mask(
    acc_ptr,
    out_ptr,
    stride_ak,
    stride_abs,
    stride_ah,
    stride_ad,
    stride_obs,
    stride_oh,
    stride_od,
    SPLIT_K: tl.constexpr,
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    cur_batch = tl.program_id(1)
    cur_head = tl.program_id(0)
    offs_k = tl.arange(0, SPLIT_K)
    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < D

    offs_v = cur_batch * stride_abs + cur_head * stride_ah + offs_k[:, None] * stride_ak + offs_d[None, :] * stride_ad
    offs_m = cur_batch * stride_abs + cur_head * stride_ah + stride_ak * offs_k + D

    m = tl.load(acc_ptr + offs_m)
    active = m > -float('inf')
    l = tl.load(acc_ptr + offs_m + 1)
    v = tl.load(
        acc_ptr + offs_v,
        mask=mask_d[None, :] & active[:, None],
        other=0.0,
    )

    m_max = tl.max(m, 0)
    alpha = tl.exp2(m - m_max)
    v = v * alpha[:, None]
    l = l * alpha
    out = tl.sum(v, 0)
    l_sum = tl.sum(l, 0)
    out = out / (l_sum + 1e-10)
    tl.store(out_ptr + cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od, out, mask=mask_d)


def reference(acc):
    m = acc[..., D]
    l = acc[..., D + 1]
    active = m > -float('inf')
    alpha = torch.exp2(m - m.max(dim=2).values[..., None])
    v = torch.where(active[..., None], acc[..., :D], torch.zeros_like(acc[..., :D]))
    alpha = torch.where(active, alpha, torch.zeros_like(alpha))
    return (v * alpha[..., None]).sum(2) / ((l * alpha).sum(2)[..., None] + 1e-10)


def main():
    print(f'torch={torch.__version__} triton={triton.__version__}')
    print(f'device={torch.cuda.get_device_name()}')

    torch.manual_seed(0)
    acc = torch.empty((TOKENS, HEADS, SPLIT_K, STRIDE), device='cuda', dtype=torch.float32)
    acc.fill_(float('nan'))
    acc[..., D] = -float('inf')
    acc[..., D + 1] = 0.0

    # Two valid splits followed by many empty splits.
    acc[:, :, :2, :D] = torch.randn((TOKENS, HEADS, 2, D), device='cuda') * 0.01
    acc[:, :, 0, D], acc[:, :, 1, D] = 10.5, 5.25
    acc[:, :, 0, D + 1], acc[:, :, 1, D + 1] = 1.04, 2.55

    bad = torch.empty((TOKENS, HEADS, D), device='cuda', dtype=torch.bfloat16)
    good = torch.empty_like(bad)

    bad_inline_mask[(HEADS, TOKENS)](
        acc,
        bad,
        stride_ak=acc.stride(2),
        stride_abs=acc.stride(0),
        stride_ah=acc.stride(1),
        stride_ad=acc.stride(3),
        stride_obs=bad.stride(0),
        stride_oh=bad.stride(1),
        stride_od=bad.stride(2),
        SPLIT_K=SPLIT_K,
        D=D,
        BLOCK_D=D,
        num_warps=2,
        num_stages=1,
    )
    good_materialized_mask[(HEADS, TOKENS)](
        acc,
        good,
        stride_ak=acc.stride(2),
        stride_abs=acc.stride(0),
        stride_ah=acc.stride(1),
        stride_ad=acc.stride(3),
        stride_obs=good.stride(0),
        stride_oh=good.stride(1),
        stride_od=good.stride(2),
        SPLIT_K=SPLIT_K,
        D=D,
        BLOCK_D=D,
        num_warps=2,
        num_stages=1,
    )
    torch.cuda.synchronize()

    ref = reference(acc).to(torch.bfloat16)
    print('bad finite:', torch.isfinite(bad).all().item())
    print('good finite:', torch.isfinite(good).all().item())
    print('bad max diff:', (bad.float() - ref.float()).abs().nan_to_num(float('inf')).max().item())
    print('good max diff:', (good.float() - ref.float()).abs().max().item())
    bad_idx = torch.nonzero(~torch.isfinite(bad) | ((bad.float() - good.float()).abs() > 1e-2))
    if bad_idx.numel() > 0:
        t, h, d = bad_idx[0].tolist()
    else:
        t, h, d = 0, 0, 64
    print('first bad/good/ref:', bad[t, h, d].item(), good[t, h, d].item(), ref[t, h, d].item())


if __name__ == '__main__':
    main()

Copilot AI review requested due to automatic review settings June 22, 2026 12:44

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Fixes an incorrect Triton mask construction in the split-K reduction kernels used by paged attention, addressing incorrect/NaN outputs observed with Triton 3.5.1.

Changes:

  • Adjust mask expression in _reduce_split_kernel to avoid (m_k[:, None] > -inf) form that miscompiles/behaves incorrectly in Triton 3.5.1.
  • Apply the same mask construction change in _fused_reduce_hadamard_kernel to keep behavior consistent.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

m_k = tl.load(acc_ptr + offs_mi)
l_k = tl.load(acc_ptr + offs_mi + 1)
acc_k = tl.load(acc_ptr + offs_acc, mask=mask_dv[None, :] & (m_k[:, None] > -float('inf')), other=0.0)
# (m_k[:, None] > -float('inf')) produce invalid mask for triton 3.5.1
@lvhan028 lvhan028 requested a review from grimoire June 22, 2026 13:11

@grimoire grimoire left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants