gfx11 MoE experiments #908

Draft

mgehre-amd wants to merge 23 commits into gfx11 from matthias.moe-perf-opt-experiments

Conversation


mgehre-amd commented Apr 28, 2026

vllm-bench.py \
    --rocm-dir /scratch/$USER \
    --model cyankiwi/Qwen3-Omni-30B-A3B-Thinking-AWQ-4bit \
    --input-len 128 \
    --output-len 128 \
    --max-model-len 4096 \
    --target-gpu-memory-gb 55 \
    --ready-check-timeout-sec 2700 \
    --trust-remote-code \
    --dtype float16

gives a decode TPOT of 12.78 ms (and 13.50 ms without --dtype float16).

Adds fused_moe_wvSplitK_int4_gemm that dispatches expert blocks via
blockIdx.y on-device, eliminating host-side loops and GPU-CPU sync.
Weights are in skinny layout [E, N, K//8] int32 (ExLlama shuffle).
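
As a CPU-side sketch, the density of that layout can be illustrated by packing eight 4-bit values into each int32 along K (the ExLlama nibble shuffle applied by the real loader is omitted here, and the function names are illustrative):

```python
def pack_int4_skinny(weights):
    """weights: nested list [E][N][K] of ints in [0, 15].
    Returns the skinny [E][N][K//8] layout, 8 nibbles per int32."""
    E, N, K = len(weights), len(weights[0]), len(weights[0][0])
    assert K % 8 == 0
    packed = [[[0] * (K // 8) for _ in range(N)] for _ in range(E)]
    for e in range(E):
        for n in range(N):
            for k in range(K):
                # nibble k % 8 of word k // 8 holds weight k
                packed[e][n][k // 8] |= (weights[e][n][k] & 0xF) << (4 * (k % 8))
    return packed

def unpack_nibble(word, idx):
    """Extract the idx-th 4-bit value from a packed int32 word."""
    return (word >> (4 * idx)) & 0xF
```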

Key optimizations for RDNA 3.5 decode (batch=1):
- Use all CUs per expert block for maximum bandwidth
- YTILE=2 for N=1 decode (better occupancy than YTILE=1 or 4)
- Reduced LDS allocation (16KB vs 64KB) for higher occupancy
- Non-temporal weight loads to avoid L1 pollution
- Scattered mode with sorted_token_ids for decode without
  pre-permutation

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

Dispatch MoE INT4 GEMM based on batch size: Triton for prefill (M>5),
HIP wvSplitK for decode (M<=5). Both read from the same shuffle-packed
[E, N, K//8] int32 weights — no duplication.

The Triton path adds use_shuffle_w4a16 to fused_moe_kernel_gptq_awq
which unpacks ExLlama-shuffled int32 via tl.interleave, then extracts
nibbles with shift+mask. Scales are [E, N, K//G], symmetric only.

Weight processing converts GPTQ [E, K/8, N] to skinny [E, N, K//8]
with ExLlama shuffle packing at load time. Enabled by default on ROCm
via VLLM_MOE_HYBRID_W4A16=true.
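
The M-based routing reduces to a one-line policy; a hedged sketch (the function name is illustrative, not the actual vLLM symbol):

```python
def select_moe_int4_impl(m, triton_impl, hip_impl, prefill_threshold=5):
    """Triton handles prefill (M > 5), HIP wvSplitK handles decode (M <= 5);
    both implementations read the same shuffle-packed [E, N, K//8] weights."""
    return triton_impl if m > prefill_threshold else hip_impl
```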

Qwen3-Omni-30B-A3B AWQ on Strix Halo (vs exllama baseline):
  TPOT: 14.51ms → 13.73ms (-5.4%)
  TTFT: 996ms → 841ms (-15.6%)

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

Build self.moe_kernel directly in process_weights_after_loading via
maybe_make_prepare_finalize(allow_new_interface=True) so HybridW4A16MoEExperts
runs on single-GPU deployments (no DP/EP), where the legacy select_gemm_impl
path is bypassed by the upstream MoE refactor.

Route apply() through self.moe_kernel.apply when the hybrid path is active;
the legacy fused_experts call is preserved as the non-hybrid fallback. The
dead select_gemm_impl branch for hybrid is removed.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

The MoERunnerBase.forward_dispatch already calls self.gate when it owns the
gate (gate=self.gate is passed to SharedFusedMoE.__init__). The model-level
explicit self.gate(hidden_states) call was redundant after the MoE-runner
refactor (commits 93bada4, 809d83c, 19ec9a0) and produced an extra
GEMV per layer per decode step.

Replaces the unconditional gate call with the canonical is_internal_router
branch already used by qwen3_next.py and deepseek_v2.py. When the runner
holds the gate, pass router_logits=hidden_states as a placeholder (the
runner's forward_dispatch ignores it and recomputes via self.gate).

Bench (Qwen3-Omni-30B-A3B-Thinking-AWQ-4bit, Strix Halo): TPOT 15.20 -> 15.11 ms
(-0.6%, decode 65.8 -> 66.2 tok/s). Eliminates 48 router GEMVs per decode step.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

The bf16 branch of wvSplitK_int4_compute_sml_/wvSplitK_int4_compute_ used a
scalar fallback that clang lowered to ~227 v_bfe_u32 + cvt instructions per
inner loop iteration. The fp16 branch already used the magic-number trick
(0x64006400 + __hsub2/__hfma2); bf16 was the laggard.

Port marlin's dequant<nv_bfloat162, kU4B8> pattern: single mask 0x000F000F,
shift between 4 extractions, 4x __hsub2. Critical: bf16 has only 7 mantissa
bits (vs fp16's 10), so the original two-mask 0x000F000F + 0x00F000F0 pattern
leaks into the bf16 exponent and corrupts output. The shift+single-mask
pattern is correct.
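
The magic-number construction is easy to check on the CPU: ORing a nibble into the low mantissa bits of the bf16 pattern 0x4300 (the per-lane half of the 0x43004300 packed constant used later in this series) yields exactly 128 + nibble, so a finite subtract recovers the signed int4 value.

```python
import struct

def bf16_bits_to_float(bits16):
    """Interpret a 16-bit bf16 pattern as a float by widening to fp32."""
    return struct.unpack("<f", struct.pack("<I", bits16 << 16))[0]

for nibble in range(16):
    magic = 0x4300 | nibble  # a 4-bit nibble fits in bf16's 7 mantissa bits
    assert bf16_bits_to_float(magic) == 128.0 + nibble
    # subtracting the bias 136 (= 128 magic offset + 8 zero-point bias)
    # recovers the signed value in [-8, 7]
    assert bf16_bits_to_float(magic) - 136.0 == nibble - 8
```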

ASM: 227 -> 67 v_bfe_u32 (-71%), 32 v_and_or_b32 introduced.

Bench (Qwen3-Omni-30B-A3B-Thinking-AWQ-4bit, Strix Halo): TPOT 15.11 -> 14.74 ms
(-2.4%, decode 66.2 -> 67.8 tok/s, vs O28-fix baseline). Reproducible across
2 consecutive runs.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

Adds a tunable launch macro for the gfx1x N=1 fused-MoE w4a16 path that
exposes WvPrGrp and A_CHUNK as template parameters (in addition to the
existing YTILE/UNRL knobs), driven by VLLM_MOE_INT4_TUNE_IDX. The default
case is hard-coded to C2 (THRDS=32, YTILE=2, WvPrGrp=4, A_CHUNK=16,
UNRL=4); the env var override is retained for further sweeps.

The baseline kernel launched __launch_bounds__(WvPrGrp_template * THRDS)
with WvPrGrp_template hard-coded to 16, but the runtime mindiv_int4()
clamp resolved __wvPrGrp to ~2 for the active Qwen3-Omni MoE shapes
(M=1536, CU=20, YTILE=2). 14 of 16 launched warps therefore returned at
the early-exit (compute body line 173) without doing any math, wasting
launch-bounds budget and occupancy.

12-config kernel-bench sweep (gemm1 1536x2048 + gemm2 2048x768, 200
iters each, two runs):
  baseline (WvPrGrp=16)           : gemm1 79.7us  gemm2 47.6us  total 127.5us
  C1       (WvPrGrp=8)            : gemm1 79.0us  gemm2 43.0us  total 122.0us
  C2       (WvPrGrp=4)            : gemm1 76.0us  gemm2 43.0us  total 119.0us  <- winner
  C3       (WvPrGrp=8, UNRL=8)    : gemm1 76.0us  gemm2 44.5us  total 120.5us
  C5       (YTILE=4, WvPrGrp=8)   : gemm1 76.1us  gemm2 45.3us  total 121.4us
  others   (C4/C6/C7/C8/C10/C11)  : totals 124..138us

End-to-end on cyankiwi/Qwen3-Omni-30B-A3B-Thinking-AWQ-4bit
(input/output=128, max-model-len=4096): TPOT 14.61 ms vs prior
14.74 ms baseline (-0.13 ms / -0.88%). Decode 68.4 tok/s. Sanity
check passed.

Changes:
- get_moe_int4_tune_idx(): cached env-var read; default 2 (C2).
- MOE_WVSPLITK_INT4G_LAUNCH_TUNED: 6-axis launch macro mirroring the
  fixed MOE_WVSPLITK_INT4G_LAUNCH but with WvPrGrp+A_CHUNK exposed.
- MOE_WVSPLITK_INT4G_TUNED_DISPATCH_GFX1X / MOE_WVSPLIT_INT4G_TUNED_N1:
  switch on the env idx; gfx9 keeps the original baseline path.
- The N=1 (else) branch of MOE_WVSPLIT_INT4G_TILE now routes through
  the tuned dispatcher; all N>=2 branches are unchanged.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

In Qwen3-Omni (and other Qwen-VL/MoE models), all 48 attention layers share
a single MRotaryEmbedding instance via the _ROPE_DICT cache and pass the
same `positions` tensor through the model. The previous code re-evaluated
`cos_sin_cache[positions]` (an `aten::index` gather) plus a chunk + two
contiguous() clones for every layer in every step.

Cache the (cos, sin) pair on the instance, keyed by positions data_ptr,
shape, and cos_sin_cache identity. The cache is bypassed during
torch.compile tracing so the captured graph stays consistent; in eager
and HIP-graph capture/replay, the Python-level guard resolves at capture
time and the gather collapses to one call per forward.

Also force cos/sin contiguous at fill so triton_mrope's internal
.contiguous() is a no-op for all 48 layers, killing 96 redundant clones
per decode step.

Changes:
- forward_native and forward_cuda both go through the cached lookup, so
  the eager fallback path benefits identically.
- Cache key intentionally avoids any tensor read so it stays GPU-sync free.
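
A minimal sketch of that lookup, with illustrative names (the real code lives on the MRotaryEmbedding instance and is bypassed during torch.compile tracing):

```python
class CosSinLookupCache:
    """Memoizes one (cos, sin) pair, keyed only by tensor metadata."""

    def __init__(self):
        self._key = None
        self._value = None

    def lookup(self, positions, cos_sin_cache, compute_fn):
        # The key reads only metadata (pointer, shape, table identity),
        # never tensor contents, so building it triggers no GPU->CPU sync.
        key = (positions.data_ptr(), tuple(positions.shape), id(cos_sin_cache))
        if key != self._key:
            self._key = key
            self._value = compute_fn(positions, cos_sin_cache)
        return self._value
```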

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

The int4 dequant inner loop uses `__hsub2(bf16x2, bf16x2)` to subtract a
constant bias from values constructed via the marlin-style magic-number
trick. On gfx1151 (RDNA 3.5) there is no native v_pk_sub_bf16, so LLVM
lowers each call to a fp32 round-trip with full IEEE-754 RTNE-with-NaN-
quieting packing. That packing emits `v_cmp_u_f32 / v_or 0x400000 /
v_add3 0x7fff / v_cndmask` per lane.

The dequant inputs are constructed in [128.0, 143.0] and the bias is
128.0 or 136.0, so the result is provably in [-8, 7] and cannot be NaN
or Inf. The NaN-quieting branch is dead code.

Replace the bf16 `__hsub2` calls with a small helper that:
  1) expands each bf16 to fp32 via shift,
  2) does the subtract in fp32,
  3) packs back to bf16 with RTNE only (no NaN canonicalisation).

Bit-exact for all finite inputs; saves ~7 vector-ALU ops per call out
of ~12. The fp16 path uses native `v_pk_add_f16` and is unchanged.
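
A CPU model of that helper (pure-Python bit manipulation standing in for the device code; finite inputs assumed, as the commit argues):

```python
import struct

def _f32_to_bits(x):
    return struct.unpack("<I", struct.pack("<f", x))[0]

def _bits_to_f32(b):
    return struct.unpack("<f", struct.pack("<I", b))[0]

def bf16_sub_finite(a_bits16, b_bits16):
    """Widen both bf16 operands to fp32 by a 16-bit shift, subtract in
    fp32, pack back with round-to-nearest-even only (no NaN quieting)."""
    diff = _bits_to_f32(a_bits16 << 16) - _bits_to_f32(b_bits16 << 16)
    bits = _f32_to_bits(diff)
    # RTNE: add 0x7FFF plus the parity of the bit that survives truncation
    return ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) & 0xFFFF
```

For example, 143.0 - 136.0: bf16_sub_finite(0x430F, 0x4308) returns 0x40E0, the bf16 pattern for 7.0.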

ASM (moe_wvSplitK_int4_hf_sml<bf16,32,2,4,16,4,1,128,false>):
  v_cmp_u_f32:    249 -> 2
  v_or 0x400000:  130 -> 2
  v_cndmask_b32:  many -> 7

Kernel-bench harness (median, gfx1151):
  moe/gemm1   75.0us -> 72.0us  (-4%)
  moe/gemm2   47.0us -> 41.4us  (-12%)
  dense/qkv   38.0us -> 33.6us  (-12%)
  dense/o_pj  37.0us -> 33.3us  (-10%)

Qwen3-Omni-30B-A3B-Thinking-AWQ-4bit decode (input=128/output=128,
2 runs, identical, vs clean rebuilt baseline at HEAD~1):
  Baseline (O75 only):  TPOT 14.63 ms,  decode 68.4 tok/s
  A7:                   TPOT 14.21 ms,  decode 70.4 tok/s
  Delta:                -0.42 ms (-2.9%), +2.0 tok/s
  Sanity: 1+1=2, 2+3=5

Marlin's bf16 dequant in csrc/quantization/marlin/dequant.h uses the
same `__hsub2` pattern and would benefit from the same change on gfx11.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

Adds a tunable launch macro for the gfx1x N=1 dense bf16/fp16 wvSplitK_hf_sml
path that exposes WvPrGrp, A_CHUNK and UNRL as template parameters (in
addition to the existing YTILE knob), driven by VLLM_BF16_SML_TUNE_IDX. The
default case is hard-coded to the sweep winner (THRDS=32, YTILE=2, WvPrGrp=4,
A_CHUNK=16, UNRL=4); the env var override is retained for further sweeps.

Mirrors the C2 pattern committed in 24ad556717 for the int4 MoE path. The
same WvPrGrp=16 launch_bounds-budget waste analysis applies: the dense N=1
kernel hits both the 152064x2048 lm_head GEMV (top-of-step) and 96x per-step
128x2048 router GEMVs in Qwen3-Omni decode, and at WvPrGrp=16 most warps
return at the early-exit without doing math.

12-config kernel-bench sweep (lm_head 152064x2048 + router 128x2048, 200
iters each):
  baseline (WvPrGrp=16, A_CHUNK=8,  UNRL=2) : lm_head 2829us  router 9.58us
  idx=10   (WvPrGrp=4,  A_CHUNK=16, UNRL=4) : lm_head 2568us  router 8.70us  <- winner (-9.2%)
  idx=7    (WvPrGrp=4,  A_CHUNK=16, UNRL=2) : lm_head 2622us  router 9.06us
  idx=9    (WvPrGrp=8,  A_CHUNK=16, UNRL=4) : lm_head 2652us  router 8.70us
  idx=4    (WvPrGrp=8,  A_CHUNK=8,  UNRL=8) : lm_head 2658us  router 8.78us

End-to-end on cyankiwi/Qwen3-Omni-30B-A3B-Thinking-AWQ-4bit
(input/output=128, max-model-len=4096): TPOT 13.88 ms vs prior 14.20 ms
baseline (-0.32 ms / -2.25%). Decode 72.1 tok/s (+1.7). Sanity check passed.

Changes:
- get_bf16_sml_tune_idx(): cached env-var read; default 10 (winner).
- WVSPLITK_BF16_TUNED_LAUNCH: 5-axis launch macro mirroring WVSPLITK_CFG but
  with WvPrGrp+A_CHUNK exposed; preserves the sml/_/big trio routing.
- WVSPLITK_BF16_TUNED_DISPATCH: switches on the env idx; case 0 falls back
  to the original WVSPLITK_CFG(32,16,2,2,1) line.
- The N=1 + use_wave32 + sYT > 1 branch of wvSplitK now routes through the
  tuned dispatcher; sYT == 1 (YTILE=1 branch) and N >= 2 are unchanged.
- mindiv(): cap iteration count at min(div2, 13) to avoid div-by-zero when
  div2 < 13 (mirrors mindiv_int4's pre-existing fix). Required for any
  WvPrGrp < 13 dispatch path; was a latent bug exposed by this sweep.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

The bf16 branch of DOT2C(V0, V2, V3) was lowering to a scalar f32
chain (4 ops per dot pair: mul.x + mul.y + add + acc). gfx11 supports
dot12-insts which provides v_dot2_f32_bf16 as a single packed dot
instruction; use the clang builtin __builtin_amdgcn_fdot2_f32_bf16 to
emit it directly. Bit-equivalent to the prior code for finite operands.

Local typedef bf16x2_t (clang ext_vector_type) is needed because the
intrinsic expects an ext-vector argument, not HIP's short2 struct.
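
For reference, the dot-pair semantics the builtin provides, modeled here in Python: v_dot2_f32_bf16 multiplies two bf16 pairs elementwise and accumulates into fp32.

```python
import struct

def _bf16_to_f32(bits16):
    return struct.unpack("<f", struct.pack("<I", bits16 << 16))[0]

def fdot2_f32_bf16(a, b, acc):
    """a, b: (lo, hi) pairs of 16-bit bf16 patterns; acc: fp32 scalar.
    Returns acc + a.lo*b.lo + a.hi*b.hi with bf16 widened to fp32."""
    return (acc
            + _bf16_to_f32(a[0]) * _bf16_to_f32(b[0])
            + _bf16_to_f32(a[1]) * _bf16_to_f32(b[1]))
```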

ASM delta on void_moe_wvSplitK_int4_hf_sml_<bf16,32,2,16,16,2,1,128,
false> (production K1 instantiation, UNRL=2, N=1, GS=128, no zp):

  v_mul_f32        16  ->   0  (eliminated)
  v_dot2_f32_bf16   0  ->  32  (one per dot pair)
  v_add_f32        64  ->  40  (-24, accumulators folded into dot)

Kernel-bench harness (median, 200 iters):

  moe/gemm1   75.0 us -> 66.4 us  (-11.5%)
  moe/gemm2   47.0 us -> 40.4 us  (-14.0%)

End-to-end on cyankiwi/Qwen3-Omni-30B-A3B-Thinking-AWQ-4bit
(in/out=128, max_model_len=4096, target_gpu_mem=55GB), 3 runs:

  TPOT median 14.10 ms -> 13.76 ms  (-0.34 ms / -2.4%)
  Decode      70.9 t/s -> 72.7 t/s  (+1.8)

Sanity check passed (1+1=2, 2+3=5).

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

The shared-experts/router overlap path (`SharedExpertsOrder.MULTI_STREAM_OVERLAPPED`)
in `_determine_shared_experts_order` was gated on `current_platform.is_cuda()`,
which returns False on ROCm. As a result, the existing aux-stream overlap
infrastructure - already initialized via `aux_stream()` (which itself uses
`is_cuda_alike()`) - was never actually engaged on ROCm.

Switch the gate to `is_cuda_alike()` so ROCm can use the same aux-stream overlap
between the router/gate and the shared-experts MLP that CUDA already enjoys.

Note: this is a no-op for Qwen3-Omni-30B (shared_expert_intermediate_size=0,
no shared experts to overlap), but benefits other Qwen3-MoE / DeepSeek-style
models that have shared experts.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

The hybrid W4A16 MoE prefill path (`HybridW4A16MoEExperts`) used a single
static Triton config (`BLOCK_M=64, BLOCK_N=64, BLOCK_K=min(gs,32)=32`,
default `num_warps=4 num_stages=2`) for every prefill GEMM, regardless of
shape.  On Qwen3-Omni-30B-A3B (E=128, top_k=8, group_size=128) this left
the GEMM1/gate_up shape (M=128, N=1536, K=2048) running at only 24% of
the memory roofline.

A 144-config microbench sweep over (BLOCK_M, BLOCK_N, BLOCK_K, num_warps,
num_stages) at the two production prefill shapes shows the optimal config
differs between gate_up (narrow N=1536, prefers BLOCK_N=32 nw=2 ns=2) and
down (wider N=2048, prefers BLOCK_N=128 nw=4 ns=1).  Bumping BLOCK_K from
32 to 64 (legal since group_size=128) halves K-loop iterations.  Both
GEMMs share `BLOCK_M=64` because the same `moe_align_block_size` output
feeds both kernels; changing kernel-BM independently of align-BM would
mis-index `expert_ids`.

Changes:
- `_triton_config` now takes `(K, N)` and dispatches on N at threshold
  1792 (gate_up N=1536 vs down N=2048 for Qwen3-Omni; will also work for
  other MoE topologies where down_proj N > gate_up N//2).
- BLOCK_K capped at min(group_size, 64) instead of min(group_size, 32).
- The two GEMM call sites compute their config once each and pass the
  appropriate one.
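
A hedged sketch of the selector (parameter names follow this commit message; the body is illustrative rather than the actual vLLM code):

```python
def triton_prefill_config(K, N, group_size):
    """Dispatch on N at the 1792 threshold; BLOCK_K is capped at
    min(group_size, 64) instead of the old min(group_size, 32)."""
    block_k = min(group_size, 64)
    if N < 1792:  # gate_up-style narrow N (1536 for Qwen3-Omni)
        return dict(BLOCK_M=64, BLOCK_N=32, BLOCK_K=block_k,
                    num_warps=2, num_stages=2)
    # down-style wide N (2048 for Qwen3-Omni)
    return dict(BLOCK_M=64, BLOCK_N=128, BLOCK_K=block_k,
                num_warps=4, num_stages=1)
```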

Measured on Strix Halo (gfx1151), in/out=128/1, 3 prompts:
  Median TTFT 278 ms -> 251 ms (-9.7%, -27 ms) across 3 runs.
  Decode TPOT 13.76 ms -> 13.78 ms (+0.02 ms, within noise).

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

Add (N,K)-exact carve-outs at the top of the M<=128 gfx1x branch in
the dense `_triton_w4a16_skinny_fmt_kernel` heuristic. The existing
generic branches were tuned on Qwen3-4B (gs=128) and used suboptimal
configs for Qwen3-Omni-30B's qkv_proj (5120x2048) and o_proj (2048x4096)
shapes at gs=32:

  qkv (5120,2048):    generic wide-N (BM=64,BN=64,BK=64,nw=4)  -> 0.439 ms
                      new            (BM=64,BN=32,BK=32,nw=4)  -> 0.342 ms (1.29x)
  o_proj (2048,4096): generic tall-K (BM=64,BN=16,BK=64,nw=1)  -> 0.461 ms
                      new            (BM=32,BN=128,BK=32,nw=2) -> 0.376 ms (1.23x)

End-to-end on Qwen3-Omni-30B-A3B-Thinking-AWQ-4bit (in=128, out=1):
5-run median TTFT 250 -> 246 ms (-4 ms / -1.6%). Decode TPOT (in=128,
out=128) unchanged at 13.78 ms.

Carve-outs are exact (N,K) matches and fall through to the existing
generic branches for any other shape, so Qwen3-4B and Qwen3-7B M=128
configs are bit-equivalent (verified: q/o, qkv, gate_up, down all hit
the same heuristic line as before).

Note: the kernel clamps `BLOCK_K = min(BLOCK_K, group_size)`, so for
gs=32 the practical max BLOCK_K is 32. The carve-outs intentionally
specify BK=32 to make this explicit.
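
The carve-out-then-fallback structure can be sketched as follows (only the two shapes above come from this commit; the function and table names are illustrative):

```python
_CARVE_OUTS = {
    (5120, 2048): dict(BM=64, BN=32,  BK=32, nw=4),  # Qwen3-Omni qkv_proj
    (2048, 4096): dict(BM=32, BN=128, BK=32, nw=2),  # Qwen3-Omni o_proj
}

def skinny_w4a16_config(N, K, group_size, generic_heuristic):
    cfg = _CARVE_OUTS.get((N, K))
    if cfg is None:
        cfg = generic_heuristic(N, K)  # unchanged path for all other shapes
    # the kernel clamps BLOCK_K to the quantization group size
    return dict(cfg, BK=min(cfg["BK"], group_size))
```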

Sweep harness: tune_triton_w4a16_skinny_fmt.py (not committed).

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

Adds a Triton-based fused MoE megakernel for the M=1 decode case of
AWQ-int4 quantized models, gated by the env var VLLM_MOE_MEGAKERNEL=1
(off by default).  The kernel chains router GEMV + topk_softmax +
per-expert (W1 + silu_and_mul + W2) + topk-weighted reduce in three
Triton launches that share a small HBM scratch (TOPK*N2 fp32) instead
of round-tripping through workspace tensors.

Numerical correctness verified (tests/kernels/moe/test_moe_megakernel.py)
against a pure-fp32 PyTorch reference: max relative error ~0.27% (bf16
storage rounding).

Empirical perf on Strix Halo (gfx1151) with
cyankiwi/Qwen3-Omni-30B-A3B-Thinking-AWQ-4bit, M=1, group_size=32:
  - baseline (HIP wvSplitK_int4 path):    TPOT 13.74 ms, 72.8 tok/s
  - megakernel enabled:                   TPOT 37.43 ms, 26.7 tok/s

The megakernel regresses because the default decode path uses a heavily
hand-tuned HIP wvSplitK_int4 GEMV, while the Triton W4A16 dequant inside
the megakernel is slower for M=1 GEMV shapes.  Kept in-tree gated as a
reference implementation and a base for future fusion work.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

Adds a HIP fused MoE kernel for the M=1 decode case of AWQ-int4
quantized models, gated by env var VLLM_MOE_HIP_MEGAKERNEL=1 (off
by default).  Replaces the (GEMM2 launch + moe_unpermute launch)
pair with a single HIP kernel that does:

  for slot in [0, top_k):
    expert_id = topk_ids[slot]
    for m in [0, K_hidden) striped across CuCount workgroups:
      acc = sum_k act[slot, k] * dequant(W2[expert_id, m, k])
      partial[slot, m] = topk_weights[slot] * acc
  # grid-wide barrier (atomic counter + threadfence)
  reduce partial[:, m] -> out[m]

Reuses the existing wvSplitK_int4 inner loop (DOT2C macro,
marlin-style bf16 dequant via bf16x2_dequant_sub_finite).  The
slot loop runs inside one persistent kernel; the barrier waits
on a HBM uint32 atomic counter and is safe because we launch
exactly CuCount blocks (all guaranteed resident).
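
A sequential Python reference for the math in the pseudocode above (workgroup striping, the barrier, and the int4 dequant are elided; w2 holds already-dequantized floats so the slot/reduce structure stays in focus):

```python
def moe_gemm2_reduce(act, w2, topk_ids, topk_weights):
    """act: [top_k][K] activations, w2: [E][M][K] expert weights,
    returns out: [M], the topk-weighted sum over selected experts."""
    top_k = len(act)
    M = len(w2[0])
    out = [0.0] * M
    for slot in range(top_k):
        e = topk_ids[slot]
        for m in range(M):
            acc = sum(a * w for a, w in zip(act[slot], w2[e][m]))
            out[m] += topk_weights[slot] * acc
    return out
```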

Numerical correctness: tests/kernels/moe/test_moe_megakernel_rocm.py
verifies bit-exact (max_abs=0.0) match against a pure-fp32 reference
on random AWQ weights at production-like shapes (K_hidden=2048,
INTERMEDIATE=768, top_k=8) for both group_size=32 and 128.  Three
test cases pass; barrier reset across calls is also verified.

Empirical perf on Strix Halo (gfx1151) with
cyankiwi/Qwen3-Omni-30B-A3B-Thinking-AWQ-4bit, M=1, group_size=32,
3 runs:
  - baseline (default HybridW4A16MoEExperts.apply path):  TPOT 13.70 ms
  - HIP megakernel enabled (3 runs): TPOT 17.26, 17.28, 17.25 ms

Regression: +3.55 ms (~26% slower).  Root cause analysis:
  1. The default path's GEMM2 + moe_unpermute uses the heavily
     hand-tuned wvSplitK_int4 kernel for GEMM2 (production-tuned
     C2 config: WvPrGrp=4, YTILE=2, A_CHUNK=16, UNRL=4) with grid
     dim (CuCount, top_k=8) = 160 blocks at THRDS*WvPrGrp=128
     threads each, then a tiny scatter-reduce.
  2. The HIP megakernel collapses the two launches to one but
     reduces the available parallelism: only CuCount=20 blocks
     run (one per CU), each iterating top_k=8 slots serially.
     This linearizes work that the original chain pipelines
     across slots, costing ~3.5 ms / decode.
  3. Going back to (CuCount, top_k) grid would risk barrier
     deadlock because Strix Halo's resident-WG limit at this
     register pressure is < 160 blocks.  A true cooperative
     launch API (hipLaunchCooperativeKernel) would be needed
     for a safe 160-block grid; not yet wired up here.

Kept in-tree gated as a reference for the persistent-grid
approach with HBM-mediated atomic-counter barrier on RDNA3.
The barrier mechanism itself works (verified by the bit-exact
correctness across 3 test cases) and could be reused for a
larger fusion scope (e.g. router + GEMM1 + silu + GEMM2 in one
kernel) once cooperative grid launch is plumbed through.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

The original GEMM2+reduce megakernel staged each slot's topk-weighted
partial through HBM scratch (`partial[top_k, M]`), then used an
atomic-counter grid-wide barrier so that the last block could reload
all slots and emit `out[:, M]`.

Each workgroup actually owns a disjoint m-stripe (the per-iter step is
`CuCount * WvPrGrp * YTILE`, blocks differ only in `blockIdx.x`), so
no cross-block reduction was ever required. The HBM round-trip plus
the cross-block barrier were pure overhead.

This change keeps the slot partials in LDS and reduces them in-place at
the end of the slot loop, writing `out[m_global]` directly. The HBM
scratch and barrier args remain in the signature (still passed by the
host wrapper) but are unused; removing them would require an ABI break
across the Python op registration.

Changes:
- `MEGA_PARTIAL_LDS_FLOATS = 2048`: per-block partial slab capacity
  (top_k x per-block m-elements). Fits comfortably alongside the
  existing 8192-scalar activation LDS within the 64 KB gfx11 budget.
- Each block tracks `local_iter` and writes
  `lds_partials[slot * per_block_m + local_m]` from lane THRDS-1 (or
  lane 63 on gfx9 wave64) instead of `Cp[m + i + n*M]`.
- Post-slot loop: a single `__syncthreads()` then `THRDS * WvPrGrp`
  threads sum across `top_k` from LDS into `out`.
- Host wrapper TORCH_CHECKs the LDS budget against `M_in x top_k`.

End-to-end (Qwen3-Omni-30B-A3B-Thinking-AWQ-4bit, M=1 decode,
VLLM_MOE_HIP_MEGAKERNEL=1, 3-run median):
  TPOT 15.13 ms (was 15.65 ms; default non-mega path is 13.70 ms)

Unit tests `tests/kernels/moe/test_moe_megakernel_rocm.py` (group_size
32 / 128 + repeat-call) all pass.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

Doubles the workgroup width so each block now runs 8 wave32s instead of 4,
giving the hardware more per-block arithmetic to hide the per-slot
LDS-staging latency (the megakernel re-stages activations once per top_k
slot, so per-slot startup overhead matters).

Changes:
- WvPrGrp constexpr 4 -> 8 in both the host launcher and the LDS-budget
  pre-check. The kernel template was already parameterized on WvPrGrp;
  no kernel-side changes needed.
- LDS footprint unchanged (24 KB activations + partials, well under the
  64 KB CU budget). per_block_m budget recomputed: with CuCount=80,
  WvPrGrp=8, YTILE=2, M=2048, top_k=8, needed=256 floats, still fits in
  MEGA_PARTIAL_LDS_FLOATS=2048.
- launch_bounds(WvPrGrp*THRDS) tightens the VGPR budget: clang now
  spills nothing and bf16 kernel uses 115 VGPR (was 163 with 4 waves
  / UNRL=8 attempt; HEAD with 4 waves / UNRL=4 was lower still but at
  half the per-block parallelism). 256 threads x 115 VGPR fits one wave
  per SIMD with room for multiple resident blocks.

Bench (Strix Halo, Qwen3-Omni-30B-AWQ-4bit, --max-num-seqs 1, 3-run median):
- A only (HEAD)         15.13 ms TPOT
- A+B (this commit)     14.33 ms TPOT  (-0.80 ms, -5.3%)
- A+C (UNRL=8 attempt)  15.14 ms TPOT  (no change; rejected)

Both A+B and A+C also keep the test_moe_megakernel_rocm.py suite passing
(3 tests, group_size 32 and 128, plus the repeat-call test).

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

Wider WG hides more per-slot LDS staging latency. Compiler keeps
115 VGPR/wave (under 128 limit at 12 waves). Median TPOT 14.33 ms
-> 14.15 ms on Qwen3-Omni-30B-A3B-AWQ-4bit decode (Strix Halo).

Changes:
- WvPrGrp 8 -> 12 in both the host wrapper LDS budget calc and the
  kernel launch config; LDS partial budget recomputed and still fits
  the 2048-float reservation.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

Wider WG exposes more thread-level parallelism per block; despite
115 VGPR/wave exceeding the 96-VGPR full-occupancy limit at 16
waves, throughput improves because the driver still launches one
block per CU and the larger TLP wins. Median TPOT 14.15 ms ->
13.79 ms on Qwen3-Omni-30B-A3B-AWQ-4bit decode (Strix Halo).

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

…roll)

Each thread now consumes 32 K-elements per UNRL tick, halving the
loop overhead and clustering 16-byte VMEM loads back-to-back. Median
TPOT 13.79 ms -> 13.61 ms on Qwen3-Omni-30B-A3B-AWQ-4bit decode
(Strix Halo). Now beats the default chain (13.70 ms).

Changes:
- A_CHUNK 16 -> 32; bigTypeA/bigTypeW unions auto-resize, sum/bigA/
  bigB register footprint unchanged because YTILE/UNRL/N stayed put.
- s_clause hint experiment (Y.a) was tried first and regressed by
  +0.04 ms, so was reverted before this commit.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

Adds a FUSE_SILU template parameter that folds silu(gate)*up into the
megakernel's per-slot LDS staging loop, eliminating one kernel launch
and one [top_k, INTER] HBM round-trip per MoE layer.

A naive translation of the standalone silu kernel regressed +1.03% TPOT
(13.74 vs 13.59 ms). Two fixes applied:

1. Redistribute silu work across all 512 block threads. The original
   A_CHUNK-strided layout left 488/512 threads idle on the
   if (k_in >= K) break and serialized expf+IEEE-divide on a single
   wave's critical path (~6 us per slot). The fused branch now uses a
   flat for (k_in = tid; k_in < K; k_in += blockDim) loop so all 512
   threads contribute to silu evaluation.

2. Replace g / (1.0f + __expf(-g)) with g * __builtin_amdgcn_rcpf(...).
   HIP compiles the IEEE divide as a 9 vector-ALU sequence (rcp + 4x
   Newton-Raphson + fixup); the rcpf intrinsic emits a single v_rcp_f32.
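
The two silu formulations are algebraically identical; a CPU model for reference (math.exp stands in for __expf and 1.0/x for the rcpf intrinsic, so the GPU's reduced-precision behaviour is not reproduced here):

```python
import math

def silu_div(g):
    """IEEE-divide form: HIP compiles the divide to a 9-op ALU sequence."""
    return g / (1.0 + math.exp(-g))

def silu_rcp(g):
    """Reciprocal form: the multiply by a reciprocal maps to a single
    v_rcp_f32 on the GPU (approximate there, exact in CPU doubles)."""
    return g * (1.0 / (1.0 + math.exp(-g)))
```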

Result on Strix Halo Qwen3-Omni-30B-A3B-Thinking-AWQ-4bit
(VLLM_MOE_HIP_MEGAKERNEL=1, --max-num-seqs 1, 3-run median):
  baseline (silu unfused)              13.59 ms
  v8 (naive fuse)                      13.74 ms (+1.03%)
  v8b (parallel staging + rcpf, this) 13.50 ms (-0.66%)

Bit-exact unit tests added for both code paths (fuse_silu={false, true}).

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

In the bf16 inner loop, replace per-pair bf16x2_dequant_sub_finite (4
vector-ALU ops per 32-bit pair: bf16->fp32 widen, fp32 subtract, manual
round-half-even back to bf16) with the magic-encoded value directly
((qa & 0xF000F) | 0x43004300 = bf16(128 + nibble) packed pair).

The pre-existing scale step then folds in a single -136*sum_act
correction so the math reduces to:
  sum += scale * sum_pairs((nibble - 8) * activation)
       = scale * (partial_dot - 136 * sum_act)
where 136 = 128 (magic offset) + 8 (zero-point bias) and sum_act is the
horizontal sum of the chunk's bf16 activations, computed once per
(k2, n) via fdot2_f32_bf16 against a (1.0, 1.0) bf16 pair and reused
across all y of the YTILE loop.

Per chunk this trades ~16 dequant vector-ALU ops for ~16 fdot2 ops to
build sum_act + 1 fma per (n, y) at the scale step. Net reduction in
inner-loop ALU work, no change to fp16 path.
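
The folding identity is exact, since (128 + nibble) - 136 = nibble - 8 for every nibble; a quick numeric check (sample chunk values are hypothetical):

```python
# Hypothetical sample chunk: nibbles from the packed weights, one group scale.
nibbles = [0, 3, 7, 8, 12, 15]
acts = [0.5, -1.25, 2.0, 0.75, -0.5, 1.5]
scale = 0.037

# Direct dequantized dot: scale * sum((nibble - 8) * act)
direct = scale * sum((n - 8) * a for n, a in zip(nibbles, acts))

# Magic-encoded dot plus the single -136 * sum_act correction
partial_dot = sum((128 + n) * a for n, a in zip(nibbles, acts))
sum_act = sum(acts)
folded = scale * (partial_dot - 136.0 * sum_act)

assert abs(direct - folded) < 1e-9
```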

Result on Strix Halo Qwen3-Omni-30B-A3B-Thinking-AWQ-4bit
(VLLM_MOE_HIP_MEGAKERNEL=1, --max-num-seqs 1, 3-run median):
  v8b baseline 13.48 ms
  v8e (this)   13.36 ms  (-0.89%)

Recovers ~17% of the bf16-vs-fp16 gap (full fp16 path: 12.78 ms).
Bit-exact unit tests still pass.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>

…uivalent)

bf16x2_dequant_sub_finite drops the per-pair round-half-to-even chain
(bits + 0x7FFF + ((bits>>16)&1)) >> 16 and just truncates the fp32
result via (bits >> 16) | (other_bits & 0xFFFF0000u).

Math: every possible result of this function for the AWQ-int4 magic-
encoded dequant inputs is an integer in [-8, 7]. Such integers round-
trip through bf16 with zero low 16 bits, so dropping those bits is
bit-equivalent to RTNE rounding here.

Saves ~4 vector-ALU ops per pair (2 add + 2 mask+shift). gfx11 has no
v_cvt_pk_bf16_f32 instruction (gfx12+ only), so per-lane shuffle is
already the floor; this is the cheapest possible per-lane sequence.

Applied to both csrc/rocm/skinny_gemms_int4.cu (called from the dense
and MoE wvSplitK paths -- the residual ~0.6 ms bf16/fp16 gap lives
here) and csrc/rocm/moe_megakernel.cu (the helper exists in both for
header-locality reasons; the megakernel itself bypasses the call via
the v8e sum_act correction).

Validated end-to-end on cyankiwi/Qwen3-Omni-30B-A3B-Thinking-AWQ-4bit
with VLLM_MOE_HIP_MEGAKERNEL=1:

  gsm8k --limit 1000 --batch_size 1 (52 min):
    flexible-extract: 85.7% +/- 1.11%
    strict-match    : 84.6% +/- 1.14%

Within expected range for AWQ-int4 of this model; no measurable
accuracy loss vs the prior RTNE rounding.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>