[ROCm][DSv4] Share AITER decode dequant + fp8-cast buffers across layers (rebased, stacked on #902) #903
Conversation
Parity check on MI355X (cherry-picked branch)
Re-ran cosine-vs-PyTorch-reference parity on the cherry-picked branch (threshold: cosine_mean > 0.999 AND cosine_min > 0.99).
TP2 NaN — not from the cherry-picked code
A drill-down traced the TP2 NaN to the AITER qh64 kernel rather than to the cherry-picked code (see the head-split workaround commit below).
Cherry-pick verification
Force-pushed 4564d3b to b790bdb
Parity check rerun on MI355X — all six configs PASS
After landing the head-split workaround for the AITER qh64 kernel, all six configs pass the threshold (cosine_mean > 0.999 AND cosine_min > 0.99). Runtime kernel trace confirms only the working kernel is loaded; the buggy qh64_qseqlen4 path is bypassed. Test setup: MI355X gfx950, GPU 0 inside the container.
Force-pushed b790bdb to eb39ee6
Rebased onto the latest base (7b2a8671); new vs old heads:

| PR | New head | Old head |
|---|---|---|
| #901 | 779a9f5c | f51f25f |
| #902 | 2833da39 | ea14ac0 |
| #903 | eb39ee63 | b790bdb |
Re-ran _parity_test_aiter_decode.py on top of the rebased stack (HEAD
eb39ee63 on base 7b2a8671) — all six configs PASS, and the numbers match
the previous run within float roundoff (purely a sanity check confirming
the rebase did not break anything):
| Config | h_q | cosine min / mean / max | result |
|---|---|---|---|
| TP2 swa_only | 64 | 0.998151 / 0.999135 / 0.999492 | PASS |
| TP2 swa+extra | 64 | 0.998568 / 0.999278 / 0.999538 | PASS |
| TP4 swa_only | 32 | 0.998532 / 0.999131 / 0.999474 | PASS |
| TP4 swa+extra | 32 | 0.998528 / 0.999280 / 0.999494 | PASS |
| TP8 swa_only | 16 | 0.998682 / 0.999152 / 0.999455 | PASS |
| TP8 swa+extra | 16 | 0.998927 / 0.999292 / 0.999505 | PASS |
Threshold: cosine_mean > 0.999 AND cosine_min > 0.99.
Only mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps is loaded at runtime; the
buggy qh64_qseqlen4 kernel path is bypassed by the head-split workaround.
End-to-end gsm8k on top of the rebased stack is queued — will follow up.
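For reference, the cosine parity metric reported in these tables can be computed as follows (a sketch; the actual `_parity_test_aiter_decode.py` harness is not part of this thread):

```python
import torch

def cosine_stats(out: torch.Tensor, ref: torch.Tensor):
    """Per-batch-row cosine similarity between kernel output and the
    PyTorch reference, reduced to the (min, mean, max) shown above."""
    cos = torch.nn.functional.cosine_similarity(
        out.flatten(1).float(), ref.flatten(1).float(), dim=1)
    return cos.min().item(), cos.mean().item(), cos.max().item()

# PASS iff mean > 0.999 and min > 0.99
```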
Replace the PyTorch reference sparse MLA decode with AITER's
persistent-mode ASM kernel (aiter.mla.mla_decode_fwd) on gfx950.
This gives ~2-3x decode speedup at high batch sizes.
Key changes:
- New module: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py
- AiterSparseScratch: lazy-init persistent-mode metadata buffers,
keyed by (batch, nhead, topk, dtype) so 61 layers share one
allocation per decode step
- aiter_sparse_attn_decode: drop-in replacement handling dual-scope
attention (SWA + extra), LSE-based merging, and attn_sink correction
- Uses the FP8/FP8 path only (gfx950 persistent-mode + return_lse
  require FP8)
- Fixed-stride kv_indices layout with -1 sentinels (required by
  AITER persistent-mode kernels); see the sketch after this list
- deepseek_v4_attention.py:
- Add _aiter_scratch / _aiter_extra_scratch fields to __init__
- Gate ROCm decode path: VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE=1
routes to _forward_decode_aiter, otherwise falls back to the
existing PyTorch reference
- Fix missing RoutingMethodType import in fused_moe/oracle/mxfp4.py
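A minimal sketch of the fixed-stride layout from the kv_indices bullet above, with a toy lazy-init cache standing in for the keyed `AiterSparseScratch` buffers (illustrative helper and stand-in block ids, not the module's actual internals):

```python
import torch

# Toy keyed cache: repeated calls with the same decode shape reuse one
# allocation, mimicking how 61 layers share a single scratch per step.
_scratch: dict[tuple, torch.Tensor] = {}

def fixed_stride_kv_indices(kv_lens: torch.Tensor, topk: int) -> torch.Tensor:
    """Every batch row owns exactly `topk` slots; tail slots beyond the
    row's real kv length hold -1, which the AITER persistent-mode kernels
    treat as 'no entry'."""
    b = kv_lens.numel()
    key = (b, topk, kv_lens.device)
    if key not in _scratch:  # lazy init, keyed by shape
        _scratch[key] = torch.empty(b, topk, dtype=torch.int32,
                                    device=kv_lens.device)
    buf = _scratch[key]
    col = torch.arange(topk, device=kv_lens.device)
    valid = col.unsqueeze(0) < kv_lens.unsqueeze(1)      # (b, topk)
    buf.copy_(col.unsqueeze(0).expand(b, topk))          # stand-in block ids
    return buf.masked_fill_(~valid, -1)

print(fixed_stride_kv_indices(torch.tensor([3, 1]), topk=4))
# tensor([[ 0,  1,  2, -1],
#         [ 0, -1, -1, -1]], dtype=torch.int32)
```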
Validated numerically (cosine > 0.999) across TP2/TP4/TP8 configs
on MI355X. Micro-benchmarked at 2.4x speedup (b=128, dual-scope).
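For context on the dual-scope handling above, the standard numerically stable way to merge two partial attention results via their log-sum-exps looks like this (a generic sketch, not this module's exact code; the attn_sink correction is omitted):

```python
import torch

def merge_two_scopes(o1, lse1, o2, lse2):
    """Merge partial outputs o* (b, h, d_v) using per-(batch, head)
    log-sum-exps lse* (b, h); subtracting the running max keeps the
    exponentials finite."""
    m = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - m).unsqueeze(-1)
    w2 = torch.exp(lse2 - m).unsqueeze(-1)
    return (w1 * o1 + w2 * o2) / (w1 + w2)
```

This is also why a +inf lse from one scope poisons the merged output with NaN, as seen with the qh64 kernel below.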
Co-authored-by: Claude
Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
Address review feedback (tjtanaa, vllm-project#40889): on ROCm, DeepSeek sparse attention can only run through AITER, so gating the op with VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE adds no value.

- Remove is_aiter_dsv4_decode_enabled() and the env-var lookup from vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py.
- Simplify the ROCm branch in DeepseekV4MLAAttention._forward_decode to dispatch unconditionally to _forward_decode_aiter.
- Drop the now-unused os import and the env-var mention in the _forward_decode_aiter docstring.

Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
Force-pushed eb39ee6 to 304f95f
The AITER persistent ASM kernel mla_a8w8_qh64_qseqlen4_gqaratio16_lse_ps (selected when h_q == 64, e.g. TP=2 on a 128-head model or TP=4 on a 256-head model) returns lse=+inf for some (batch, head) rows when called with return_lse=True on gfx950, which propagates NaN through the dual-scope LSE merge. The qh16_qseqlen1 kernel selected for h_q <= 32 is numerically clean (cosine > 0.999 vs the PyTorch reference on TP4/TP8).

Workaround: recurse at the outer level in aiter_sparse_attn_decode — split q + attn_sink in half, run a complete sparse_attn_decode per half, and concatenate the final (b, h_q, d_v) outputs (sketched below). Doing the split at the outer level (instead of inside _aiter_decode_one_scope) keeps each recursive call self-contained: LSE merging and sink correction never cross the h_q boundary, so the result does not depend on the kernel's lse shape convention (which has differed across aiter versions). An earlier inner-level split surfaced as a shape mismatch on aiter v0.1.10.post3 in the upstream nightly docker.

Cost: 2x kernel launches per decode call in the split case (TP=2 with 128 heads or TP=4 with 256 heads), but those configs are rarely used in production; correctness first. TODO: drop once AITER fixes qh64_qseqlen4_gqaratio16_lse_ps.

Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
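A sketch of that outer-level split, treating the full per-half decode as an opaque `decode_fn` and assuming a per-head `attn_sink` of shape `(h_q,)` (names and threshold are illustrative):

```python
import torch

H_Q_SAFE = 32  # qh16_qseqlen1 is selected (and numerically clean) at h_q <= 32

def decode_with_head_split(q, attn_sink, decode_fn):
    """Halve along h_q until the safe kernel is selected; each half runs a
    complete decode (LSE merge + sink correction inside decode_fn), so
    nothing ever crosses the h_q boundary."""
    h_q = q.shape[1]                      # q: (b, h_q, d_qk)
    if h_q <= H_Q_SAFE:
        return decode_fn(q, attn_sink)
    half = h_q // 2
    lo = decode_with_head_split(q[:, :half], attn_sink[:half], decode_fn)
    hi = decode_with_head_split(q[:, half:], attn_sink[half:], decode_fn)
    return torch.cat([lo, hi], dim=1)     # final (b, h_q, d_v)

# toy check with a fake decode returning zeros of the value shape
fake = lambda q, sink: torch.zeros(q.shape[0], q.shape[1], 8)
assert decode_with_head_split(torch.randn(2, 64, 16),
                              torch.zeros(64), fake).shape == (2, 64, 8)
```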
Force-pushed 304f95f to 29ef9be
@tjtanaa thanks for trying it on your end.

Root cause: our previous head-split lived inside _aiter_decode_one_scope, so it depended on the kernel's lse shape convention, which differs across aiter versions.

Fix (in the force-pushed head above): the split now happens at the outer level in aiter_sparse_attn_decode, so each recursive call is fully self-contained.
| Config | h_q | cosine min / mean / max | result |
|---|---|---|---|
| TP2 swa_only | 64 | 0.998151 / 0.999135 / 0.999492 | PASS |
| TP2 swa+extra | 64 | 0.998568 / 0.999278 / 0.999538 | PASS |
| TP4 swa_only | 32 | 0.998532 / 0.999131 / 0.999474 | PASS |
| TP4 swa+extra | 32 | 0.998528 / 0.999280 / 0.999494 | PASS |
| TP8 swa_only | 16 | 0.998682 / 0.999152 / 0.999455 | PASS |
| TP8 swa+extra | 16 | 0.998927 / 0.999292 / 0.999505 | PASS |
Only mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps is loaded at runtime; the buggy qh64_qseqlen4 path is bypassed.
Note on PR base
We also re-pointed all three PRs (#901 / #902 / #903) at hexwang/dsv4_adapt_upstream per the team's decision that this is the accuracy baseline going forward. Re-rebase had 0 manual conflicts since neither hexwang/dsv4_adapt_upstream nor tj/dsv4prrebase touches vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py or our _forward_decode_aiter additions.
Could you re-pull chuali/aiter-mla-dsv4-decode-rebased (or whichever you're cherry-picking) and try the V4-Flash + TP=4 + --tokenizer-mode "deepseek_v4" recipe again? If it still trips, please grab the q.shape and lse.shape at the failure point — that pins down whether your aiter build returns lse in a fundamentally different layout, and we can guard for that explicitly.
select_mxfp4_moe_backend in mxfp4.py references RoutingMethodType.DeepseekV4, but the symbol is not imported on the hexwang/dsv4_adapt_upstream base, raising NameError: name 'RoutingMethodType' is not defined during model init. Hexiang's recipe happens to skip this code path via the --moe-backend triton_unfused CLI flag, but the LLM offline API takes the same path and trips on it.

Minimal one-line fix: pull RoutingMethodType into the same multi-import block that already imports FusedMoEQuantConfig / FusedMoEQuantDesc / mxfp4_*_moe_quant_config from fused_moe.config (sketched below). This fix was originally in vllm-project#40889 but got auto-merged-out during the cherry-pick onto tj/dsv4prrebase (which already had it); reintroducing it when rebasing onto hexwang/dsv4_adapt_upstream, which does not.

Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
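The fix itself, sketched (the sibling imports follow the commit text; the exact import block in `mxfp4.py` may differ):

```python
# Add RoutingMethodType to the multi-import block that already pulls the
# quant-config symbols from fused_moe.config (module path per the commit
# text; surrounding imports are illustrative).
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    FusedMoEQuantDesc,
    RoutingMethodType,  # previously missing -> NameError during model init
)
```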
Hoist all per-step allocations on the AITER sparse-MLA decode path into
the existing `AiterSparseScratch` cache so cudagraph capture sees stable
memory layouts. Previously each layer's call site freshly allocated
`qo_indptr`, `kv_indptr`, `kv_indices`, `kv_last_page_lens`, `q_scale`,
`kv_scale`, `q_fp8`, the bf16 output buffer, and intermediate boolean
masks every step, which was incompatible with HIP-graph capture and
generated unnecessary allocator pressure with 61 DSv4 attention layers.
Changes
-------
* `AiterSparseScratch` now caches:
* Static buffers: `qo_indptr` (arange), `kv_last_page_lens` (ones),
`col_arange`, `q_scale`, `kv_scale`.
* Per-step write buffers: `kv_indptr`, `kv_indices_2d`, `valid_mask`,
`valid_lens`, `q_fp8`, `out_buf`.
* `rebuild()` allocates and (where applicable) initialises every buffer
once per `(total_q, h_q, topk, d_qk, d_v, dtype, kvtype)` key and runs
`aiter.get_mla_metadata_v1` against the persistent qo/kv/last-page
tensors.
* `_aiter_decode_one_scope` rewrites the per-step buffers in-place via
  `torch.lt(out=)`, `tensor.copy_`, `masked_fill_`, and
  `torch.cumsum(out=)` instead of fresh allocations (sketched after this list).
* The public `aiter_sparse_attn_decode` signature is unchanged, so
`DeepseekV4MLAAttention._forward_decode_aiter` keeps working as-is.
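A hypothetical sketch of that in-place rewrite pattern (the scratch object `s` carries the persistent buffers listed above; int32/bool dtypes assumed):

```python
import torch

def rewrite_step_buffers(s, new_indices: torch.Tensor, lens: torch.Tensor):
    """Per-step refresh: every persistent buffer is rewritten in place, so
    its data_ptr never changes across graph-captured steps."""
    # (b, topk) mask of real entries for the current per-batch kv lengths
    torch.lt(s.col_arange.unsqueeze(0), lens.unsqueeze(1), out=s.valid_mask)
    s.kv_indices_2d.copy_(new_indices)               # overwrite persistent ids
    s.kv_indices_2d.masked_fill_(~s.valid_mask, -1)  # restore -1 sentinels
    torch.sum(s.valid_mask, dim=1, dtype=torch.int32, out=s.valid_lens)
    s.kv_indptr[0] = 0                               # prefix-sum layout
    torch.cumsum(s.valid_lens, dim=0, out=s.kv_indptr[1:])
```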
Follow-up to PR vllm-project#40889; remaining cudagraph blocker is the per-step
`blocked_k.to(fp8_e4m3fn)` cast on the dequantised KV cache, which is
tracked separately together with the dequantise-into-FP8 fast path
suggested in code review.
Test plan
---------
* `python -c "import ast; ast.parse(open('vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py').read())"` passes.
* On MI355X with `chuali_glm51` container (rocm/vllm-dev:nightly +
torch 2.10/HIP 7.2): manual smoke test pending; will follow up with
cudagraph-mode (non-eager) startup log + decode parity vs eager.
AI assistance: drafted with Cursor agent; human-reviewed before
submission.
Signed-off-by: ChuanLi1101 <chuanli1101@gmail.com>
`aiter.get_mla_metadata_v1` produces a `work_*`/`reduce_*` plan that is
keyed on the *actual* per-batch kv lengths, not just on shapes. The
persistent ASM `mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps` kernel reads
out of bounds (causing a GPU memory access fault) if those buffers are
left stale across steps with different kv lengths.
Fix the cudagraph-clean refactor so the metadata is rewritten in-place
on every per-step call against the current `kv_indptr`. The buffer
sizes returned by `get_mla_metadata_info_v1` are determined by shapes
+ `max_split_per_batch` only, so they remain large enough for any kv
length distribution and the data pointers stay stable for graph capture.
* `AiterSparseScratch.rebuild()` now only allocates buffers and stores
the static gqa/topk/dtype parameters; it no longer requires a
`kv_indptr_seed` and no longer runs the metadata builder itself.
* New `AiterSparseScratch.refresh_metadata()` reruns
  `get_mla_metadata_v1`, writing into the same `work_*`/`reduce_*` slots
  (see the stand-in sketch after this list).
* `_aiter_decode_one_scope` writes `valid_mask`/`valid_lens`/
`kv_indptr`/`kv_indices_2d`/`q_fp8` directly into scratch every
step, then calls `refresh_metadata()` and `mla.mla_decode_fwd`.
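A minimal, self-contained illustration of the refresh-in-place property — a stand-in planner replaces the real `aiter.get_mla_metadata_v1` (whose signature is not reproduced here) but keeps what matters for graph capture, a stable data pointer:

```python
import torch

def plan_into(kv_indptr: torch.Tensor, work: torch.Tensor) -> None:
    """Stand-in for the metadata builder: derive the plan from the *current*
    per-batch kv lengths and write it into a preallocated buffer."""
    lens = kv_indptr[1:] - kv_indptr[:-1]
    work.zero_()
    work[: lens.numel()].copy_(lens)   # real builder writes work/reduce plans

work = torch.zeros(64, dtype=torch.int32)       # sized once, by shapes only
ptr = work.data_ptr()
plan_into(torch.tensor([0, 3, 5], dtype=torch.int32), work)  # lens = [3, 2]
plan_into(torch.tensor([0, 4, 5], dtype=torch.int32), work)  # lens = [4, 1]
assert work.data_ptr() == ptr                    # stable pointer for capture
```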
Validated with the standalone `bench_remote/_unit_test_cudagraph.py`
harness on MI355X:
- Call 1 (lens=[3,2]): success, scratch key set.
- Call 2 (same lens): rebuild skipped, all data_ptrs stable, output
bit-identical to call 1.
- Call 3 (lens=[4,1]): all data_ptrs still stable, output differs as
expected (max abs diff = 2.39 vs identical-input call), no fault.
- Parity check vs the original non-cudagraph implementation:
max abs diff = 0.000000.
Signed-off-by: Chuan Li <chuanli1101@gmail.com>
Co-authored-by: Cursor
Signed-off-by: Chuan Li <chuali@amd.com>
Routes the per-layer-sized intermediates inside the AITER sparse MLA decode
path through the existing `current_workspace_manager()` so all 61 DSv4
attention layers reuse the same bf16 + fp8 buffers per step instead of
each layer allocating two fresh ~kv-cache-sized tensors.
Concretely:
* `_dequantize_blocked_k_cache` accepts an optional `out=` bf16 buffer.
* `aiter_sparse_attn_decode` and `_aiter_decode_one_scope` accept
  optional `kv_fp8_buf` / `extra_kv_fp8_buf` fp8 buffers and copy the
  bf16->fp8 cast into them in place (sketched after this list).
* `_forward_decode_aiter` (ROCm path) pulls the 2-or-4 buffers from
`current_workspace_manager().get_simultaneous(...)` so they share a
single workspace allocation, mirroring how prefill already does it.
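A sketch of the fp8 half of that contract, assuming `kv_fp8_buf` comes from the shared workspace (the `None` branch keeps the old allocating path bit-for-bit, matching the parity result below):

```python
import torch

def cast_kv_to_fp8(blocked_k: torch.Tensor,
                   kv_fp8_buf: torch.Tensor | None = None) -> torch.Tensor:
    """bf16 -> fp8 cast for one scope: written into the shared workspace
    slot when provided, freshly allocated otherwise."""
    if kv_fp8_buf is None:
        return blocked_k.to(torch.float8_e4m3fn)   # old per-call allocation
    assert kv_fp8_buf.shape == blocked_k.shape     # shape guard fires here
    kv_fp8_buf.copy_(blocked_k)                    # dtype-converting copy
    return kv_fp8_buf
```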
Without this, every layer per step allocates two fresh per-kv-cache-sized
tensors that go into the cudagraph memory pool, multiplying that pool by
~60x worth of redundant slots on a 61-layer DSv4 model. The buffer sizes
depend only on static kv-cache shape (num_blocks, block_size, head_dim),
so the workspace reaches its max during warmup and stays stable through
capture and `lock_workspace()`.
Validated on MI355X with a standalone microbench:
* bit-exact parity with the un-buffered path (`max abs diff = 0.0`)
* `kv_fp8_buf.data_ptr()` stable across 61 simulated "layer" calls
* pointer stable across varying per-step `lens` patterns
* shape / dtype mismatch raises as expected
Stacks on top of vllm-project#40892 (cudagraph-clean AITER decode).
Co-authored-by: Claude
Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
Force-pushed 29ef9be to c9a8bd0
Summary
Rebased version of vllm-project#40909 for review against the `tj/dsv4prrebase` integration branch. Stacks on top of the rebased vllm-project#40892 (which itself stacks on the rebased vllm-project#40889) — see the parallel `chuali/aiter-mla-dsv4-decode-cudagraph-rebased` PR. Per Tun Jian's note that vLLM upstream is now planning to merge DSv4 via the rebased PR vllm-project#40860, the AITER decode stack needs to land on top of that newer base. This PR cherry-picks the 1 commit from vllm-project#40909 on top of the rebased vllm-project#40889 + vllm-project#40892 commits, all on `ROCm/vllm:tj/dsv4prrebase`.

This does not duplicate vllm-project#40909 — that PR will stay open as a draft against `vllm-project/vllm:main` until vllm-project#40860 lands; this PR exists so reviewers can look at the workspace sharing on top of the new base today.

Stacking

This PR includes 5 commits in its diff: 2 from the rebased vllm-project#40889 + 2 from the rebased vllm-project#40892 + 1 from the rebased vllm-project#40909. Once the prior PRs land, this one will reduce to just the 1 workspace-sharing commit. Happy to rebase on request.
What changed during rebase
- The `_forward_decode_aiter` extension applied cleanly on top of the new base.
- The `_dequantize_blocked_k_cache(out=)` signature is still consistent with both the AITER path (passes `out=`) and the original `_forward_decode_fallback` (uses the default `None`).

Original PR description
Stacks on top of vllm-project#40892 ("Make AITER sparse decode cudagraph-clean").
The cudagraph-clean refactor in vllm-project#40892 covers all the per-step control tensors (qo_indptr / kv_indptr / kv_indices / kv_last_page_lens / q_scale / kv_scale / q_fp8 / out_buf / work_meta_data) via `AiterSparseScratch`. What it does not cover, by design, is the two big per-layer-sized data-bearing intermediates inside the decode hot path:

- the bf16 output of `_dequantize_blocked_k_cache` (one per scope: SWA + extra), and
- `kv_fp8 = blocked_k.to(fp8_dtype)` inside `_aiter_decode_one_scope`.

On a 61-layer DSv4 model that means 122–244 fresh ~kv-cache-sized allocations per step that all land in the cudagraph memory pool and never get reused, multiplying that pool by ~60x worth of redundant slots. This PR fixes that.
What changed
- `_dequantize_blocked_k_cache` accepts an optional `out=` bf16 buffer and writes into it in place when provided.
- `aiter_sparse_attn_decode` and `_aiter_decode_one_scope` accept optional `kv_fp8_buf` / `extra_kv_fp8_buf` fp8 buffers and copy the bf16->fp8 cast into them in place.
- `_forward_decode_aiter` (ROCm path) pulls the 2-or-4 buffers from `current_workspace_manager().get_simultaneous(...)` so a single workspace allocation is shared across all 61 layers, mirroring the prefill path that already does this.
- `AiterSparseScratch` docstring updated to make the workspace-vs-scratch ownership boundary explicit.

The buffer sizes depend only on static kv-cache shape (`num_blocks`, `block_size`, `head_dim`), so the workspace reaches its max during the first warmup call and stays stable through cudagraph capture and `lock_workspace()`. No new env var, no behavior change on the bf16 / un-buffered fallback path (passing `kv_fp8_buf=None` preserves the old code path bit-for-bit).

Test plan
- Standalone unit test (`bench_remote/_unit_test_workspace.py`): `buf=None` vs `buf=workspace` paths agree to max abs diff = 0.0; `kv_fp8_buf.data_ptr()` stable across 61 simulated layer calls; pointer stable under varying `lens`; shape/dtype guards fire as expected.
- Parity sweep on `tj/dsv4prrebase` — running now alongside the rebased [ROCm] Add AITER-accelerated MLA decode for DeepSeek V4 on MI355X vllm-project/vllm#40889 / [ROCm][DSv4] Make AITER sparse MLA decode cudagraph-clean (follow-up to #40889) vllm-project/vllm#40892

AI assistance disclosure

AI assistance (Claude / Cursor agent) was used for this rebase. The submitter has reviewed every cherry-picked diff line. The microbench above was run on MI355X against the original base; will re-run on `tj/dsv4prrebase` once the parity sweep completes.

Related