Skip to content

feat: add AMD ROCm support via AITER CK flash attention backend#67

Open
ZJLi2013 wants to merge 1 commit intoSkyworkAI:mainfrom
ZJLi2013:feat/amd-rocm-support
Open

feat: add AMD ROCm support via AITER CK flash attention backend#67
ZJLi2013 wants to merge 1 commit intoSkyworkAI:mainfrom
ZJLi2013:feat/amd-rocm-support

Conversation

@ZJLi2013
Copy link
Copy Markdown

Summary

Add AMD Instinct MI300X GPU support to Matrix-Game-3.0 using AITER (AI Tensor Engine for ROCm) Composable Kernel (CK) flash attention backend.

  • attention.py: Add AITER CK dispatch path between FA3 and FA2. Dispatch priority: FA3 → AITER CK → FA2 → RuntimeError. No changes to existing NVIDIA code paths.
  • README.md: Add AMD ROCm installation section with ROCm 7.x + AITER setup instructions.

Motivation

The original code hard-asserts flash_attn availability, which fails on AMD GPUs. This PR adds native AMD support via AITER's CK backend, which compiles flash attention kernels to native AMD ISA (hipcc → LLVM → GCN).

Test Results (MI300X, gfx942, ROCm 7.2)

Backend Iter 2 (steady-state) Step time vs FA2 Triton
FA2 Triton 14s ~5.0s baseline
AITER CK 11s ~4.0s -25%
  • Tested with: 1× MI300X, 2 iterations, 3 inference steps, 704×1280, seed=42
  • Output: identical 1.8MB MP4 video across all backends
  • Environment: rocm/pytorch:rocm7.2.1_ubuntu22.04_py3.10_pytorch_release_2.9.1 + amd-aiter 0.1.13

Changes

File Lines Description
Matrix-Game-3/wan/modules/attention.py +32/-4 AITER CK import, dispatch path, error handling
Matrix-Game-3/README.md +19/-2 AMD MI300X support note, ROCm install section

Compatibility

  • NVIDIA: Zero impact — existing FA3/FA2 paths unchanged
  • AMD ROCm 7.x: AITER CK backend (recommended, best performance)
  • AMD ROCm 6.x: FA2 Triton backend still works via flash-attn with FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE
  • No AITER or flash-attn: Clear RuntimeError with install instructions (replaces silent AssertionError)

Add AMD Instinct MI300X GPU support using AITER Composable Kernel (CK) flash attention backend, which compiles to native AMD ISA and delivers ~25% better steady-state performance than Triton kernels.

Dispatch priority: FA3 > AITER CK > FA2 > RuntimeError.

Tested on MI300X (gfx942) with ROCm 7.2 + AITER 0.1.13, producing identical outputs to the NVIDIA FA3 path.

Made-with: Cursor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant