
Integrate AITER fused RoPE kernels with fallback to TE native #541

Open
suachong wants to merge 11 commits into dev from feat/aiter-fused-rope

Conversation

@suachong (Contributor)

Description

Integrate AITER's optimized HIP/ASM RoPE kernels into TE's FusedRoPEFunc on ROCm.

When aiter is installed and the input falls within the supported subset (sbhd format,
non-interleaved, no context parallelism, no packed sequences, no start_positions),
the forward and backward passes dispatch to aiter.ops.rope.rope_fwd /
rope_bwd for improved performance on AMD GPUs.

Fallback to the existing tex.fused_rope_forward / tex.fused_rope_backward
is automatic for all other configurations and whenever AITER is not available.

A new env var NVTE_USE_AITER_ROPE (default "1") allows explicit opt-out.
The AITER import is gated behind IS_HIP_EXTENSION to avoid unnecessary import
attempts on CUDA systems.
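
Roughly, that gating looks like the following (a sketch only; apart from NVTE_USE_AITER_ROPE, IS_HIP_EXTENSION, _HAVE_AITER_ROPE, and the aiter.ops.rope entry points named in this PR, the details are illustrative):

```python
import os

from torch.utils.cpp_extension import IS_HIP_EXTENSION

_HAVE_AITER_ROPE = False
if IS_HIP_EXTENSION and os.getenv("NVTE_USE_AITER_ROPE", "1") == "1":
    try:
        from aiter.ops.rope import rope_fwd, rope_bwd  # pylint: disable=import-error

        _HAVE_AITER_ROPE = True
    except Exception:  # broad on purpose: any aiter breakage falls back to TE native
        _HAVE_AITER_ROPE = False
```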

Tested in MLPerf GPT-OSS-20B MoE pretraining on MI355X (8×GPU).

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

  • Add module-level AITER rope kernel import with robust error handling (except Exception) and IS_HIP_EXTENSION gate
  • Add NVTE_USE_AITER_ROPE env var (default "1") to allow opt-out
  • Add FusedRoPEFunc._can_use_aiter() guard that restricts AITER dispatch to sbhd format, non-interleaved, no CP, no THD, no start_positions
  • Dispatch to aiter.ops.rope.rope_fwd / rope_bwd in forward/backward when the guard passes; fall back to tex.fused_rope_* otherwise (see the sketch after this list)
  • Add test_aiter_rope_matches_te_fused: verifies AITER and TE fused produce identical output and gradients (parametrized over dtype, seq_length, hidden_size, rotary_percent, loss_func)
  • Add test_aiter_rope_can_use_guard: exhaustive unit test of guard logic (6 parametrized cases)
  • Add test_aiter_rope_env_var_disable: verifies _HAVE_AITER_ROPE=False disables dispatch
  • Add test_aiter_rope_fallback_unsupported: verifies unsupported configs fall back correctly
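
A sketch of the guard and dispatch described above (parameter names and the aiter call shape are illustrative, not exact signatures):

```python
# On FusedRoPEFunc -- guard restricting the AITER path to the supported subset:
@staticmethod
def _can_use_aiter(tensor_format, interleaved, cp_size, cu_seqlens, start_positions):
    return (
        _HAVE_AITER_ROPE
        and tensor_format == "sbhd"      # no thd / packed sequences
        and not interleaved
        and cp_size == 1                 # no context parallelism
        and cu_seqlens is None
        and start_positions is None
    )

# Dispatch shape in FusedRoPEFunc.forward (backward mirrors this with rope_bwd):
if FusedRoPEFunc._can_use_aiter(tensor_format, interleaved, cp_size,
                                cu_seqlens, start_positions):
    output = rope_fwd(t, freqs)            # AITER HIP/ASM kernel
else:
    output = tex.fused_rope_forward(...)   # unchanged TE native path
```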

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@suachong force-pushed the feat/aiter-fused-rope branch from 085a9c1 to 8b05a6a on April 15, 2026 at 18:54.
Comment thread on transformer_engine/pytorch/attention/rope.py:

```python
if IS_HIP_EXTENSION:
    try:
        from aiter.ops.rope import (  # pylint: disable=import-error
```
Collaborator:
Is there any AITER versioning that can be used to constrain using of the API?

- Add AMD copyright header to rope.py
- Check IS_HIP_EXTENSION first, guard all AITER code behind it
- Use logger.warning for AITER import failures instead of logger.info
- Log AITER version (via aiter._version) on successful import
- Default NVTE_USE_AITER_ROPE to "0" (opt-in) since CI cannot test it
- Expose _HAVE_AITER_ROPE via FusedRoPEFunc.has_aiter_rope() method
- Use @pytest.mark.skipif decorator instead of inline pytest.skip()

Signed-off-by: Su Ann Chong <suachong@amd.com>
Made-with: Cursor
Provides a containerized way to test the AITER fused RoPE integration
on ROCm systems, since CI cannot test this feature.

Signed-off-by: Su Ann Chong <suachong@amd.com>
Made-with: Cursor
Local testing infrastructure, not intended for the repository.

Signed-off-by: Su Ann Chong <suachong@amd.com>
Made-with: Cursor
Follow existing convention: add AMD copyright above the NVIDIA header
with "modified for portability to AMDGPU" note, rather than replacing it.

Signed-off-by: Su Ann Chong <suachong@amd.com>
Made-with: Cursor
- Replace logger.warning with RuntimeError when NVTE_USE_AITER_ROPE=1
  but AITER import fails, making the failure explicit instead of
  silently falling back to TE native kernels
- Remove all diagnostic logging (version info, reason tracking) to
  reduce maintenance burden and stay synchronized with upstream

Signed-off-by: Su Ann Chong <suachong@amd.com>
Made-with: Cursor
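
The resulting behavior presumably looks something like this (a sketch of the described semantics, not the exact diff; it also reflects the earlier commit's opt-in default of "0"):

```python
if IS_HIP_EXTENSION:
    import os

    _HAVE_AITER_ROPE = False
    if os.getenv("NVTE_USE_AITER_ROPE", "0") == "1":
        try:
            from aiter.ops.rope import rope_fwd, rope_bwd  # pylint: disable=import-error

            _HAVE_AITER_ROPE = True
        except Exception as exc:
            # Explicit opt-in must fail loudly instead of silently degrading
            # to the TE native kernels.
            raise RuntimeError(
                "NVTE_USE_AITER_ROPE=1 but the AITER RoPE kernels could not be imported"
            ) from exc
```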
Co-authored-by: Meekail Zain <34613774+Micky774@users.noreply.github.com>
- Guard `import os` and env var check under IS_HIP_EXTENSION in rope.py
  to minimize upstream diff
- Add AMD copyright header to test_fused_rope.py
- Guard `unittest.mock` and `FusedRoPEFunc` imports behind IS_HIP_EXTENSION
- Add IS_HIP_EXTENSION skipif guard to all AITER test functions
- Use torch.device("cuda") instead of hardcoding cuda:0 in AITER test

Signed-off-by: Su Ann Chong <suachong@amd.com>
Made-with: Cursor
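
The guarded-import and skipif pattern from this commit might look like the following (a sketch; the _can_use_aiter argument shape is illustrative):

```python
import pytest
from torch.utils.cpp_extension import IS_HIP_EXTENSION

if IS_HIP_EXTENSION:
    # unittest.mock and FusedRoPEFunc are only needed (and imported) on ROCm.
    from unittest import mock

    from transformer_engine.pytorch.attention.rope import FusedRoPEFunc


@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="AITER RoPE path is ROCm-only")
def test_aiter_rope_env_var_disable():
    # With _HAVE_AITER_ROPE forced off, the guard must refuse dispatch even
    # for an otherwise supported configuration.
    with mock.patch("transformer_engine.pytorch.attention.rope._HAVE_AITER_ROPE", False):
        assert not FusedRoPEFunc._can_use_aiter("sbhd", False, 1, None, None)
```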
Comment thread on tests/pytorch/test_fused_rope.py:
```python
    apply_fused_qkv_rotary_pos_emb,
)

try:
```
Collaborator:
Please refer to other modules; importing IS_HIP_EXTENSION does not require try/catch.

Contributor (Author):
Done

Follow repo convention: import IS_HIP_EXTENSION directly from
torch.utils.cpp_extension without try/except guard, consistent
with all other test modules.

Signed-off-by: Su Ann Chong <suachong@amd.com>
Made-with: Cursor
Comment on lines +16 to +19
```python
try:
    from torch.utils.cpp_extension import IS_HIP_EXTENSION
except ImportError:
    IS_HIP_EXTENSION = False
```
Contributor:

No need to guard
