
[ROCm][DSv4] Make AITER sparse decode cudagraph-clean (rebased, stacked on #901)#902

Draft
ChuanLi1101 wants to merge 8 commits into ROCm:hexwang/dsv4_adapt_upstream from ChuanLi1101:chuali/aiter-mla-dsv4-decode-cudagraph-rebased

Conversation


@ChuanLi1101 ChuanLi1101 commented Apr 27, 2026

Summary

Rebased version of vllm-project#40892 for review against the tj/dsv4prrebase integration branch. Cudagraph-clean follow-up to vllm-project#40889 (which is rebased in the parallel chuali/aiter-mla-dsv4-decode-rebased PR against this same branch).

Per Tun Jian's note that vLLM upstream is now planning to merge DSv4 via the rebased PR vllm-project#40860, the AITER decode stack needs to land on top of that newer base. This PR cherry-picks the 2 commits from vllm-project#40892 on top of the rebased vllm-project#40889 commits, all on ROCm/vllm:tj/dsv4prrebase.

This does not duplicate vllm-project#40892 — that PR will stay open as a draft against vllm-project/vllm:main until vllm-project#40860 lands; this PR exists so reviewers can look at the cudagraph-clean refactor on top of the new base today.

Stacking

This PR includes 4 commits in its diff: 2 from rebased vllm-project#40889 + 2 from rebased vllm-project#40892. Once the prior PR lands, this one will reduce to just the 2 cudagraph-clean commits. Happy to rebase on request.

What changed during rebase

Original PR description

The persistent-mode AITER sparse MLA decode kernel (mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps) is the only LSE-returning sparse attention path on MI355X / gfx950, but the original integration in vllm-project#40889 allocates fresh per-step indexing tensors (qo_indptr, kv_indptr, kv_indices, kv_last_page_lens, q_scale, kv_scale) and a fresh FP8 query / output buffer on every layer call. That is incompatible with HIP/CUDA-graph capture, so DSv4 currently has to run in eager mode on ROCm.

This PR makes the per-step decode path completely allocation-free and pointer-stable so the decode loop can be wrapped in a HIP graph once the model wires it up.

Changes

  • AiterSparseScratch now owns every per-step buffer in addition to the AITER work-plan / reduce buffers:
    • qo_indptr, kv_indptr, kv_indices_2d, kv_last_page_lens
    • valid_mask, valid_lens, col_arange
    • q_fp8 (FP8 query buffer), out_buf (BF16 output)
    • q_scale, kv_scale (constant 1.0 tensors)
  • _aiter_decode_one_scope rewrites all of those in-place every step (torch.lt(out=), copy_, masked_fill_, torch.cumsum(out=)).
  • AiterSparseScratch.refresh_metadata() re-runs aiter.get_mla_metadata_v1 against the current kv_indptr every step and writes the new work plan into the same work_* / reduce_* slots. The persistent ASM kernel encodes per-batch lengths into that plan, so leaving it stale across steps with different kv lengths causes a GPU memory-access fault.
  • rebuild() now only allocates buffers and stores static gqa/topk/dtype parameters; it no longer re-runs the metadata builder itself.

Net effect: across decode steps with the same (total_q, nhead, topk, d_qk, d_v, dtype, kvtype) key, every data_ptr() is stable, so a graph captured on step N can be replayed for any step N+k.
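The buffer-ownership pattern above can be sketched in a few lines (hypothetical class and field names, not the actual vLLM code; the point is that `rebuild()` only allocates, while the per-step path rewrites the same storage in place with `out=` variants, so `data_ptr()` never changes):

```python
import torch

# Hypothetical minimal sketch of the cudagraph-safe scratch pattern
# (illustrative names, not the real AiterSparseScratch). Buffers are
# allocated once per shape key; every later step rewrites them in place,
# so data_ptr() stays stable and a captured graph can be replayed.
class SparseScratch:
    def rebuild(self, total_q: int, topk: int, device: str = "cpu") -> None:
        # Allocate-only: no metadata is computed here.
        self.key = (total_q, topk)
        self.kv_indptr = torch.zeros(total_q + 1, dtype=torch.int64, device=device)
        self.valid_mask = torch.empty(total_q, topk, dtype=torch.bool, device=device)
        self.valid_lens = torch.empty(total_q, dtype=torch.int64, device=device)
        self.col_arange = torch.arange(topk, dtype=torch.int64, device=device)

    def step(self, kv_lens: torch.Tensor) -> None:
        # Per-step in-place rewrites; no fresh allocations.
        torch.lt(self.col_arange.unsqueeze(0), kv_lens.unsqueeze(1), out=self.valid_mask)
        self.valid_lens.copy_(kv_lens)
        torch.cumsum(self.valid_lens, dim=0, out=self.kv_indptr[1:])

scratch = SparseScratch()
scratch.rebuild(total_q=4, topk=8)
ptr = scratch.kv_indptr.data_ptr()
scratch.step(torch.tensor([3, 5, 2, 8]))
scratch.step(torch.tensor([8, 1, 7, 4]))
assert scratch.kv_indptr.data_ptr() == ptr  # pointer-stable across steps
```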

Test plan

AI assistance disclosure

AI assistance (Cursor agent) was used for this rebase. The submitter has reviewed every cherry-picked diff line. The microbench above was run on MI355X against the original base; will re-run on tj/dsv4prrebase once the parity sweep completes.

Related

@ChuanLi1101
Author

Parity check on MI355X (cherry-picked branch)

Re-ran cosine-vs-PyTorch-reference parity on the cherry-picked branch
(tj/dsv4prrebase + 5 commits, HEAD 4564d3b) using
bench_remote/_parity_test_aiter_decode.py on a single MI355X / gfx950
GPU inside chuali_glm51. Synthetic inputs: b=8, topk=2048,
d_qk=576, d_v=512.

| Config | h_q | cosine min / mean / max | max abs diff | result |
| --- | --- | --- | --- | --- |
| TP4 swa_only | 32 | 0.998532 / 0.999131 / 0.999474 | 1.66e-02 | PASS |
| TP4 swa+extra | 32 | 0.998528 / 0.999280 / 0.999494 | 1.56e-02 | PASS |
| TP8 swa_only | 16 | 0.998682 / 0.999152 / 0.999455 | 1.46e-02 | PASS |
| TP8 swa+extra | 16 | 0.998927 / 0.999292 / 0.999505 | 1.01e-02 | PASS |
| TP2 swa_only | 64 | 0.994775 / 0.998978 / 0.999511 | 2.55e-02 | borderline |
| TP2 swa+extra | 64 | NaN | NaN | FAIL |

(Threshold: cosine_mean > 0.999 AND cosine_min > 0.99)
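For reference, that pass criterion can be sketched as follows (a sketch only; it assumes the parity script compares per-row cosine similarity over flattened outputs, which is not spelled out above):

```python
import torch
import torch.nn.functional as F

def parity_verdict(out: torch.Tensor, ref: torch.Tensor) -> bool:
    # Per-row cosine similarity between kernel output and reference,
    # with each sample flattened to a single feature vector.
    cos = F.cosine_similarity(out.flatten(1).float(), ref.flatten(1).float(), dim=1)
    # PASS requires cosine_mean > 0.999 AND cosine_min > 0.99.
    return bool((cos.mean() > 0.999) and (cos.min() > 0.99))

x = torch.randn(8, 16)
```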

TP2 NaN — not from the cherry-picked code

A drill-down (_parity_tp2_diag.py) shows the NaN comes from the AITER
persistent ASM kernel mla_a8w8_qh64_qseqlen4_gqaratio16_lse_ps itself,
not from the integration code:

TP2 swa+extra b=8:
  out_swa  : nans=0     infs=0   ok
  out_ext  : nans=8192  infs=0   <-- NaN appears in scope #2's output
  lse_ext  : nans=0     infs=16  <-- kernel returns +inf for 16 rows

(lse_ext - lse_total).exp() with both operands +inf yields NaN, which
then propagates through the LSE merge. Reproducible across batch sizes
(1/2/4/8/16) and topk (256/512/1024/2048) at TP2; completely clean
on TP4 and TP8 (which use the smaller qh16_qseqlen1 kernel). Worth a
ping to the AITER team but doesn't block landing the decode path on
TP4/TP8 (the production deployment shapes on MI355X).
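The inf propagation described above can be reproduced in isolation with a few lines of torch (values illustrative):

```python
import torch

# Two-scope LSE merge: lse_total = log(exp(lse_swa) + exp(lse_ext)),
# each scope weighted by exp(lse_scope - lse_total).
lse_swa = torch.tensor([2.0, 3.0])
lse_ext = torch.tensor([1.5, float("inf")])    # buggy kernel row: +inf
lse_total = torch.logaddexp(lse_swa, lse_ext)  # +inf wherever lse_ext is +inf
w_ext = (lse_ext - lse_total).exp()            # inf - inf -> nan
# w_ext[0] is finite; w_ext[1] is NaN and poisons the merged output.
```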

Cherry-pick verification

  • All 5 commits applied with zero manual conflict resolutions.
  • One hunk was dropped during the cherry-pick: the RoutingMethodType
    import in mxfp4.py, which already exists on tj/dsv4prrebase.
  • All cross-file references verified: _forward_decode_aiter /
    aiter_sparse_attn_decode signatures, _dequantize_blocked_k_cache(out=),
    current_workspace_manager, AiterSparseScratch.

@ChuanLi1101 ChuanLi1101 force-pushed the chuali/aiter-mla-dsv4-decode-cudagraph-rebased branch from bb248d3 to ea14ac0 Compare April 27, 2026 05:21
@ChuanLi1101
Author

Parity check rerun on MI355X — all six configs PASS

After landing the head-split workaround for the AITER qh64_qseqlen4_gqaratio16_lse_ps
kernel lse=+inf issue (commit f51f25f on this stack's bottom branch), re-ran
bench_remote/_parity_test_aiter_decode.py against the top of the stack
(b790bdb):

| Config | h_q | cosine min / mean / max | max abs diff | result |
| --- | --- | --- | --- | --- |
| TP2 swa_only | 64 | 0.998151 / 0.999135 / 0.999492 | 2.15e-02 | PASS |
| TP2 swa+extra | 64 | 0.998568 / 0.999278 / 0.999538 | 1.46e-02 | PASS |
| TP4 swa_only | 32 | 0.998532 / 0.999131 / 0.999474 | 1.66e-02 | PASS |
| TP4 swa+extra | 32 | 0.998528 / 0.999280 / 0.999494 | 1.56e-02 | PASS |
| TP8 swa_only | 16 | 0.998682 / 0.999152 / 0.999455 | 1.46e-02 | PASS |
| TP8 swa+extra | 16 | 0.998927 / 0.999292 / 0.999505 | 1.01e-02 | PASS |

Threshold: cosine_mean > 0.999 AND cosine_min > 0.99. All six configs PASS.

Runtime kernel trace confirms only the working kernel is loaded:

[aiter] hipModuleLoad: ... mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps.co

The buggy qh64_qseqlen4_gqaratio16_lse_ps kernel is no longer invoked.

Test setup: MI355X gfx950, GPU 0 inside chuali_glm51,
b=8 s_q=1 d_qk=576 d_v=512 n_blk=4096 blk_sz=1, topk=2048 per scope,
valid lengths sampled in [topk//2, topk). Reference is
_ref_sparse_attn_decode ported verbatim from DeepseekV4MLAAttention
(bf16). Full log + summary in bench_remote/logs/parity_run_v2.log and
bench_remote/logs/PARITY_SUMMARY.md.

@tjtanaavllm

@ChuanLi1101 I cherry-picked this PR into tj/dsv4prrebase and got this error:

odule.py", line 1787, in _call_impl
(Worker_TP1 pid=222260) ERROR 04-27 14:58:08 [multiproc_executor.py:962]     return forward_call(*args, **kwargs)
(Worker_TP1 pid=222260) ERROR 04-27 14:58:08 [multiproc_executor.py:962]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=222260) ERROR 04-27 14:58:08 [multiproc_executor.py:962]   File "/app/dsv4cache/dsv4prrebasesync3/vllm/model_executor/layers/deepseek_v4_attention.py", line 856, in forward
(Worker_TP1 pid=222260) ERROR 04-27 14:58:08 [multiproc_executor.py:962]     self._forward_decode(
(Worker_TP1 pid=222260) ERROR 04-27 14:58:08 [multiproc_executor.py:962]   File "/app/dsv4cache/dsv4prrebasesync3/vllm/model_executor/layers/deepseek_v4_attention.py", line 906, in _forward_decode
(Worker_TP1 pid=222260) ERROR 04-27 14:58:08 [multiproc_executor.py:962]     self._forward_decode_aiter(
(Worker_TP1 pid=222260) ERROR 04-27 14:58:08 [multiproc_executor.py:962]   File "/app/dsv4cache/dsv4prrebasesync3/vllm/model_executor/layers/deepseek_v4_attention.py", line 1218, in _forward_decode_aiter
(Worker_TP1 pid=222260) ERROR 04-27 14:58:08 [multiproc_executor.py:962]     attn_out = aiter_sparse_attn_decode(
(Worker_TP1 pid=222260) ERROR 04-27 14:58:08 [multiproc_executor.py:962]                ^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=222260) ERROR 04-27 14:58:08 [multiproc_executor.py:962]   File "/app/dsv4cache/dsv4prrebasesync3/vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py", line 330, in aiter_sparse_attn_decode
(Worker_TP1 pid=222260) ERROR 04-27 14:58:08 [multiproc_executor.py:962]     correction = 1.0 / (1.0 + (sink - lse).exp())
(Worker_TP1 pid=222260) ERROR 04-27 14:58:08 [multiproc_executor.py:962]                                ~~~~~^~~~~
(Worker_TP1 pid=222260) ERROR 04-27 14:58:08 [multiproc_executor.py:962] RuntimeError: The size of tensor a (64) must match the size of tensor b (32) at non-singleton dimension 2

I am using aiter v0.1.10.post3 from upstream docker vllm/vllm-openai-rocm:nightly

My server command is

max_num_seqs=64
max_num_batched_tokens=131072
tensor_parallel_size=4
export VLLM_TORCH_PROFILER_DIR="/app/vllm_profile"
export HF_HOME=/data/huggingface-cache
export VLLM_ROCM_USE_AITER=1
export VLLM_ROCM_USE_AITER_LINEAR=1
export VLLM_DSV4_WO_A_FP8=0

unset FLATMM_HIP_CLANG_PATH
#MODEL=/data/deepseek-ai/DeepSeek-R1-0528
MODEL=deepseek-ai/DeepSeek-V4-Flash
vllm serve ${MODEL} \
    --host localhost \
    --port 8000 \
    --dtype auto \
    --tensor-parallel-size ${tensor_parallel_size} \
    --max-num-seqs ${max_num_seqs} \
    --distributed-executor-backend mp \
    --trust-remote-code \
    --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}' \
    --gpu-memory-utilization 0.40 \
    --moe-backend "triton_unfused" \
    --enforce-eager \
    --tokenizer-mode "deepseek_v4" \
    --async-scheduling

The lm_eval command is:

MODEL=deepseek-ai/DeepSeek-V4-Flash
lm_eval --model local-completions --model_args model=$MODEL,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=8,max_retries=10,max_gen_toks=2048 --batch_size auto --tasks gsm8k --num_fewshot 5 --limit 16  --output_path . 2>&1 | tee -a eval-oldforward.log

Signed-off-by: ganyi <ygan@amd.com>
Made-with: Cursor
@ChuanLi1101
Author

Rebased onto latest tj/dsv4prrebase (HEAD 7b2a8671)

tj/dsv4prrebase was force-updated since the original cherry-pick (vllm-project#40860 base plus subsequent fixes including _old_use_hadamard, the silu clamp for the shared expert, topk_softplus, and an upstream main merge). Rebased the stack on top of the new HEAD — 0 manual conflicts (the new base does not touch vllm/model_executor/layers/deepseek_v4_attention.py or our new rocm_aiter_dsv4_decode.py).

New commit hashes:

| PR | New head | Old head |
| --- | --- | --- |
| #901 | 779a9f5c | f51f25f |
| #902 | 2833da39 | ea14ac0 |
| #903 | eb39ee63 | b790bdb |

Re-ran _parity_test_aiter_decode.py on top of the rebased stack (HEAD
eb39ee63 on base 7b2a8671) — all six configs PASS, numbers match the
previous run within float-roundoff (purely a sanity-check confirming the
rebase did not break anything):

| Config | h_q | cosine min / mean / max | result |
| --- | --- | --- | --- |
| TP2 swa_only | 64 | 0.998151 / 0.999135 / 0.999492 | PASS |
| TP2 swa+extra | 64 | 0.998568 / 0.999278 / 0.999538 | PASS |
| TP4 swa_only | 32 | 0.998532 / 0.999131 / 0.999474 | PASS |
| TP4 swa+extra | 32 | 0.998528 / 0.999280 / 0.999494 | PASS |
| TP8 swa_only | 16 | 0.998682 / 0.999152 / 0.999455 | PASS |
| TP8 swa+extra | 16 | 0.998927 / 0.999292 / 0.999505 | PASS |

Threshold: cosine_mean > 0.999 AND cosine_min > 0.99.
Only mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps is loaded at runtime; the
buggy qh64_qseqlen4 kernel path is bypassed by the head-split workaround.

End-to-end gsm8k on top of the rebased stack is queued — will follow up.

whx-sjtu and others added 3 commits April 28, 2026 07:12
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Replace the PyTorch reference sparse MLA decode with AITER's
persistent-mode ASM kernel (aiter.mla.mla_decode_fwd) on gfx950.
This gives ~2-3x decode speedup at high batch sizes.

Key changes:
- New module: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py
  - AiterSparseScratch: lazy-init persistent-mode metadata buffers,
    keyed by (batch, nhead, topk, dtype) so 61 layers share one
    allocation per decode step
  - aiter_sparse_attn_decode: drop-in replacement handling dual-scope
    attention (SWA + extra), LSE-based merging, and attn_sink correction
  - Uses FP8/FP8 path only (gfx950 persistent-mode + return_lse
    requires FP8)
  - Fixed-stride kv_indices layout with -1 sentinels (required by
    AITER persistent-mode kernels)

- deepseek_v4_attention.py:
  - Add _aiter_scratch / _aiter_extra_scratch fields to __init__
  - Gate ROCm decode path: VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE=1
    routes to _forward_decode_aiter, otherwise falls back to the
    existing PyTorch reference

- Fix missing RoutingMethodType import in fused_moe/oracle/mxfp4.py

Validated numerically (cosine > 0.999) across TP2/TP4/TP8 configs
on MI355X. Micro-benchmarked at 2.4x speedup (b=128, dual-scope).

Co-authored-by: Claude
Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
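The fixed-stride kv_indices layout with -1 sentinels mentioned in the commit message above can be sketched like this (a hypothetical helper; `build_kv_indices`, its shapes, and the -1-padding convention are assumptions drawn from the description, not the actual AITER API):

```python
import torch

# Build a (batch, topk) kv_indices table with a fixed stride: each row holds
# that request's valid block ids, padded out to topk with -1 sentinels so
# every row has the same length (assumed persistent-mode kernel convention).
def build_kv_indices(block_ids: list[list[int]], topk: int) -> torch.Tensor:
    out = torch.full((len(block_ids), topk), -1, dtype=torch.int32)
    for i, ids in enumerate(block_ids):
        n = min(len(ids), topk)
        out[i, :n] = torch.tensor(ids[:n], dtype=torch.int32)
    return out

table = build_kv_indices([[1, 2, 3], [7]], topk=4)
```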
Address review feedback (tjtanaa, vllm-project#40889): on ROCm,
DeepSeek sparse attention can only run through AITER, so gating the
op with VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE adds no value.

- Remove is_aiter_dsv4_decode_enabled() and the env-var lookup from
  vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py.
- Simplify the ROCm branch in DeepseekV4MLAAttention._forward_decode
  to dispatch unconditionally to _forward_decode_aiter.
- Drop the now-unused os import and the env-var mention in the
  _forward_decode_aiter docstring.

Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
@ChuanLi1101 ChuanLi1101 force-pushed the chuali/aiter-mla-dsv4-decode-cudagraph-rebased branch from 2833da3 to 143c695 Compare April 28, 2026 08:49
@ChuanLi1101 ChuanLi1101 changed the base branch from tj/dsv4prrebase to hexwang/dsv4_adapt_upstream April 28, 2026 08:50
The AITER persistent ASM kernel mla_a8w8_qh64_qseqlen4_gqaratio16_lse_ps
(selected when h_q==64, e.g. TP=2 on a 128-head model or TP=4 on a
256-head model) returns lse=+inf for some (batch, head) rows when called
with return_lse=True on gfx950, which propagates NaN through the
dual-scope LSE merge.

The qh16_qseqlen1 kernel selected for h_q<=32 is numerically clean
(cosine > 0.999 vs PyTorch ref on TP4/TP8). Recurse at the outer level
in aiter_sparse_attn_decode — split q + attn_sink in half, run a
complete sparse_attn_decode per half, concatenate the final (b, h_q, d_v)
outputs.

Doing the split at the outer level (instead of inside
_aiter_decode_one_scope) keeps each recursive call self-contained:
LSE merging and sink correction never cross the h_q boundary, so the
result does not depend on the kernel's lse shape convention (which has
differed across aiter versions). Earlier inner-level split surfaced as
a shape mismatch on aiter v0.1.10.post3 in the upstream nightly docker.

2x kernel launches per decode call in the split case (TP=2 with 128
heads or TP=4 with 256 heads), but those configs are rarely used in
production; correctness first.

TODO: drop once AITER fixes qh64_qseqlen4_gqaratio16_lse_ps.

Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
@ChuanLi1101 ChuanLi1101 force-pushed the chuali/aiter-mla-dsv4-decode-cudagraph-rebased branch from 143c695 to 450b838 Compare April 28, 2026 09:06
@ChuanLi1101
Author

@tjtanaavllm thanks for trying it on tj/dsv4prrebase — that error is a real bug in our head-split workaround, fixed with the latest force-push.

Root cause

Our previous head-split lived inside _aiter_decode_one_scope: when h_q > 32 it recursed with q split in half, ran the kernel twice, and concatenated the two lse tensors back at dim=-1. That assumes lse from aiter.mla.mla_decode_fwd is shape (total_q, h_q) — which is what we get on our gfx950 / aiter build, so the parity sweep on our box passed. On your aiter v0.1.10.post3 (upstream nightly docker) the kernel returns lse with a different shape, so the post-recursion (sink - lse).exp() failed to broadcast. Hence the 64 vs 32 mismatch you saw.

Fix (commit e9605007ea on chuali/aiter-mla-dsv4-decode-rebased)

Moved the head-split up to the outer aiter_sparse_attn_decode level. When h_q > 32 we now split q and attn_sink in half, call aiter_sparse_attn_decode recursively (a complete sparse attention per half — its own LSE merge, its own sink correction), and concatenate the final (b, h_q, d_v) outputs at dim=-2.

Each recursive call is fully self-contained: LSE merging and sink correction never cross the h_q boundary, so the result no longer depends on the kernel's lse shape convention.
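A minimal sketch of that outer-level recursion (illustrative names and shapes; the leaf call here is a stand-in for one complete kernel + LSE-merge + sink-correction pass, and `MAX_HQ`/the `attn_sink` layout are assumptions):

```python
import torch

MAX_HQ = 32  # widest head count the numerically clean qh16 kernel handles

def sparse_attn_decode(q: torch.Tensor, attn_sink: torch.Tensor) -> torch.Tensor:
    """q: (b, h_q, d); attn_sink: (h_q,). Recurse at the outer level."""
    b, h_q, d = q.shape
    if h_q > MAX_HQ:
        half = h_q // 2
        # Each half is a complete, self-contained sparse decode.
        out_lo = sparse_attn_decode(q[:, :half], attn_sink[:half])
        out_hi = sparse_attn_decode(q[:, half:], attn_sink[half:])
        return torch.cat([out_lo, out_hi], dim=-2)  # concat (b, h_q, d_v) halves
    # Placeholder for the real kernel call; shape-preserving for the sketch.
    return q.clone()

q = torch.randn(2, 64, 16)
out = sparse_attn_decode(q, torch.zeros(64))
assert out.shape == (2, 64, 16)
```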

Verified

Re-ran bench_remote/_parity_test_aiter_decode.py on top of the rebased stack (HEAD 29ef9be512 on hexwang/dsv4_adapt_upstream 66ed64f5):

| Config | h_q | cosine min / mean / max | result |
| --- | --- | --- | --- |
| TP2 swa_only | 64 | 0.998151 / 0.999135 / 0.999492 | PASS |
| TP2 swa+extra | 64 | 0.998568 / 0.999278 / 0.999538 | PASS |
| TP4 swa_only | 32 | 0.998532 / 0.999131 / 0.999474 | PASS |
| TP4 swa+extra | 32 | 0.998528 / 0.999280 / 0.999494 | PASS |
| TP8 swa_only | 16 | 0.998682 / 0.999152 / 0.999455 | PASS |
| TP8 swa+extra | 16 | 0.998927 / 0.999292 / 0.999505 | PASS |

Only mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps is loaded at runtime; the buggy qh64_qseqlen4 path is bypassed.

Note on PR base

We also re-pointed all three PRs (#901 / #902 / #903) at hexwang/dsv4_adapt_upstream per the team's decision that this is the accuracy baseline going forward. Re-rebase had 0 manual conflicts since neither hexwang/dsv4_adapt_upstream nor tj/dsv4prrebase touches vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py or our _forward_decode_aiter additions.

Could you re-pull chuali/aiter-mla-dsv4-decode-rebased (or whichever you're cherry-picking) and try the V4-Flash + TP=4 + --tokenizer-mode "deepseek_v4" recipe again? If it still trips, please grab the q.shape and lse.shape at the failure point — that pins down whether your aiter build returns lse in a fundamentally different layout, and we can guard for that explicitly.

ChuanLi1101 and others added 3 commits April 28, 2026 02:34
select_mxfp4_moe_backend in mxfp4.py references
RoutingMethodType.DeepseekV4 but the symbol is not imported on the
hexwang/dsv4_adapt_upstream base, raising NameError: name
'RoutingMethodType' is not defined during model init. Hexiang's recipe
happens to skip this code path via --moe-backend triton_unfused CLI
flag, but the LLM offline API takes the same path and trips on it.

Minimal one-line fix: pull RoutingMethodType into the same multi-import
block that already imports FusedMoEQuantConfig / FusedMoEQuantDesc /
mxfp4_*_moe_quant_config from fused_moe.config.

This fix was originally in vllm-project#40889 but got auto-merged-out during the
cherry-pick onto tj/dsv4prrebase (which already had it); reintroducing
when rebasing onto hexwang/dsv4_adapt_upstream which does not.

Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
Hoist all per-step allocations on the AITER sparse-MLA decode path into
the existing `AiterSparseScratch` cache so cudagraph capture sees stable
memory layouts. Previously each layer's call site freshly allocated
`qo_indptr`, `kv_indptr`, `kv_indices`, `kv_last_page_lens`, `q_scale`,
`kv_scale`, `q_fp8`, the bf16 output buffer, and intermediate boolean
masks every step, which was incompatible with HIP-graph capture and
generated unnecessary allocator pressure with 61 DSv4 attention layers.

Changes
-------

* `AiterSparseScratch` now caches:
  * Static buffers: `qo_indptr` (arange), `kv_last_page_lens` (ones),
    `col_arange`, `q_scale`, `kv_scale`.
  * Per-step write buffers: `kv_indptr`, `kv_indices_2d`, `valid_mask`,
    `valid_lens`, `q_fp8`, `out_buf`.
* `rebuild()` allocates and (where applicable) initialises every buffer
  once per `(total_q, h_q, topk, d_qk, d_v, dtype, kvtype)` key and runs
  `aiter.get_mla_metadata_v1` against the persistent qo/kv/last-page
  tensors.
* `_aiter_decode_one_scope` rewrites the per-step buffers in-place via
  `torch.lt(out=)`, `tensor.copy_`, `masked_fill_`, and
  `torch.cumsum(out=)` instead of fresh allocations.
* The public `aiter_sparse_attn_decode` signature is unchanged, so
  `DeepseekV4MLAAttention._forward_decode_aiter` keeps working as-is.

Follow-up to PR vllm-project#40889; remaining cudagraph blocker is the per-step
`blocked_k.to(fp8_e4m3fn)` cast on the dequantised KV cache, which is
tracked separately together with the dequantise-into-FP8 fast path
suggested in code review.

Test plan
---------

* `python -c "import ast; ast.parse(open('vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py').read())"` passes.
* On MI355X with `chuali_glm51` container (rocm/vllm-dev:nightly +
  torch 2.10/HIP 7.2): manual smoke test pending; will follow up with
  cudagraph-mode (non-eager) startup log + decode parity vs eager.

AI assistance: drafted with Cursor agent; human-reviewed before
submission.

Signed-off-by: ChuanLi1101 <chuanli1101@gmail.com>
`aiter.get_mla_metadata_v1` produces a `work_*`/`reduce_*` plan that is
keyed on the *actual* per-batch kv lengths, not just on shapes. The
persistent ASM `mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps` kernel reads
out of bounds (causing a GPU memory access fault) if those buffers are
left stale across steps with different kv lengths.

Fix the cudagraph-clean refactor so the metadata is rewritten in-place
on every per-step call against the current `kv_indptr`. The buffer
sizes returned by `get_mla_metadata_info_v1` are determined by shapes
+ `max_split_per_batch` only, so they remain large enough for any kv
length distribution and the data pointers stay stable for graph capture.

  * `AiterSparseScratch.rebuild()` now only allocates buffers and stores
    the static gqa/topk/dtype parameters; it no longer requires a
    `kv_indptr_seed` and no longer runs the metadata builder itself.
  * New `AiterSparseScratch.refresh_metadata()` reruns
    `get_mla_metadata_v1` writing into the same `work_*`/`reduce_*` slots.
  * `_aiter_decode_one_scope` writes `valid_mask`/`valid_lens`/
    `kv_indptr`/`kv_indices_2d`/`q_fp8` directly into scratch every
    step, then calls `refresh_metadata()` and `mla.mla_decode_fwd`.

Validated with the standalone `bench_remote/_unit_test_cudagraph.py`
harness on MI355X:
  - Call 1 (lens=[3,2]): success, scratch key set.
  - Call 2 (same lens): rebuild skipped, all data_ptrs stable, output
    bit-identical to call 1.
  - Call 3 (lens=[4,1]): all data_ptrs still stable, output differs as
    expected (max abs diff = 2.39 vs identical-input call), no fault.
  - Parity check vs the original non-cudagraph implementation:
    max abs diff = 0.000000.

Signed-off-by: Chuan Li <chuanli1101@gmail.com>
Co-authored-by: Cursor
Signed-off-by: Li <chuali@amd.com>
@ChuanLi1101 ChuanLi1101 force-pushed the chuali/aiter-mla-dsv4-decode-cudagraph-rebased branch from 450b838 to f4778f8 Compare April 28, 2026 09:36
@whx-sjtu whx-sjtu force-pushed the hexwang/dsv4_adapt_upstream branch from 9a8d252 to 14a3f64 Compare April 30, 2026 12:15