Fix Triton WNA16 MoE fallback for CompressedTensorsWNA16MoEMethod #875
Draft
mgehre-amd wants to merge 2 commits into gfx11 from
Conversation
The Triton WNA16 MoE path (triggered by VLLM_MOE_GPTQ_EXLLAMA=false) crashed with AttributeError because moe_mk and moe_quant_config were never initialized in __init__. These attributes are only set in the exllama and AWQ GEMV code paths, but the apply() method checks self.moe_mk and falls through to the Triton fused_experts() path, which requires self.moe_quant_config.

On Qwen3-Omni-30B-A3B-Instruct-AWQ-4bit (128 experts, top-8, w4a16 group_size=32), switching from ExllamaExperts to Triton WNA16 MoE reduces TTFT from 1345ms to 929ms (31% improvement) because:

- Triton fused MoE handles gather/scatter + GEMM in a single kernel
- Eliminates atomicAdd K-tiling overhead from the exllama contiguous kernel
- Router gate GEMM drops from 8.2ms to 56us per layer

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
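The shape of the fix, as a minimal hedged sketch: the stand-in class and placeholder types below are illustrative only and are not vLLM's real CompressedTensorsWNA16MoEMethod, FusedMoEKernel, or FusedMoEQuantConfig definitions.

```python
from typing import Any, Optional


class WNA16MoEMethodSketch:
    """Simplified stand-in for CompressedTensorsWNA16MoEMethod (illustrative only)."""

    def __init__(self) -> None:
        # The fix: create both attributes up front so every code path can read
        # them. Previously they were only assigned in the exllama / AWQ GEMV
        # setup paths, so the Triton fallback crashed with AttributeError on
        # the first `self.moe_mk` access.
        self.moe_mk: Optional[Any] = None
        self.moe_quant_config: Optional[Any] = None

    def apply(self, hidden_states: Any) -> Any:
        if self.moe_mk is not None:
            # exllama / AWQ GEMV kernel path (elided in this sketch).
            return hidden_states
        # Triton fused_experts() fallback: reads self.moe_quant_config, which
        # now always exists as an attribute instead of being missing entirely.
        _ = self.moe_quant_config
        return hidden_states
```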
075bf94 to 0e3cf4c
These attributes are initialized to None in __init__ but later assigned FusedMoEKernel / FusedMoEQuantConfig values. Without explicit type annotations, mypy inferred the type as None and flagged the assignments as incompatible. Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
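A small sketch of the annotation issue the second commit describes; ConfigPlaceholder stands in for FusedMoEQuantConfig / FusedMoEKernel and is not the real type.

```python
from typing import Optional


class ConfigPlaceholder:  # stands in for FusedMoEQuantConfig / FusedMoEKernel
    pass


class Method:
    def __init__(self) -> None:
        # Without the explicit Optional[...] annotation, mypy infers the
        # attribute type as None from this assignment and rejects the later
        # assignment of a ConfigPlaceholder as incompatible.
        self.moe_quant_config: Optional[ConfigPlaceholder] = None

    def configure(self) -> None:
        self.moe_quant_config = ConfigPlaceholder()  # accepted with the annotation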
mgehre-amd commented Apr 15, 2026
logger = logging.getLogger(__name__)


def _is_rdna_for_moe_default() -> bool:
Author
Remove this and the env variable, and do proper is_compatible checks in compressed_tensor_moe.py
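A hypothetical sketch of the suggested direction, replacing the env-variable switch with an explicit compatibility check; the function name, arguments, and criteria are invented for illustration and are not vLLM's actual API.

```python
def exllama_moe_is_compatible(is_rdna: bool, num_bits: int, group_size: int) -> bool:
    """Hypothetical compatibility check colocated with the MoE method.

    The conditions below are placeholders; a real check would encode the
    exllama kernel's actual hardware and quantization constraints.
    """
    if is_rdna:
        return False  # e.g. prefer the Triton WNA16 path on RDNA GPUs
    return num_bits == 4 and group_size in (32, 64, 128)
```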
mgehre-amd commented Apr 15, 2026
@pytest.mark.parametrize("e,topk", [(8, 2), (16, 4)])
def test_triton_wna16_moe(m: int, n: int, k: int, e: int, topk: int):
    """Test the Triton WNA16 MoE fallback path (VLLM_MOE_GPTQ_EXLLAMA=false).
Author
Also test CompressedTensorsWNA16MoEMethod.apply()
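A rough, self-contained sketch of what apply() coverage could look like, written against a stand-in object rather than the real CompressedTensorsWNA16MoEMethod; a real test would construct the method through vLLM's compressed-tensors config.

```python
import pytest


class _PreFixMethod:
    """Stand-in reproducing the pre-fix behaviour: attributes never created."""

    def apply(self, x):
        if self.moe_mk is not None:  # raised AttributeError before the fix
            return x
        return x


def test_apply_raised_attributeerror_before_fix():
    with pytest.raises(AttributeError):
        _PreFixMethod().apply([1.0, 2.0])
```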
I probably broke this when integrating the new exllama kernel; this change fixes it.