Gfx1250 changes #527

Open
ipanfilo wants to merge 3 commits into dev from ipanfilo/gfx1250_pr

Conversation

Collaborator

ipanfilo commented Apr 7, 2026

Description

Port gfx1250 support changes from the NPI branch.
This includes only the changes in TE itself, with no attention backend updates.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Modify wavefront size to 32 where 64 was used
  • Add gfx1250 to TE GPU specific logic
  • Add gfx1250 to GEMM tests
  • Update AITER build to build for both V3 and V2 GPU architectures simultaneously
  • Add mode to use AITER Triton kernels for Flash Attention (controlled by NVTE_FLASH_ATTN_AITER env)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

ipanfilo added the ci-level 2 (CI test level 2) label Apr 7, 2026
ipanfilo requested a review from AllenFarcas April 10, 2026 16:19
ipanfilo marked this pull request as ready for review April 10, 2026 16:19
if is_hip_extension():
# only GFX9.4+ and MI200 machines support bf16
return 100 > gpu_arch >= 94 or is_mi200()
# only GFX9.4+, GFX12+ and MI200 machines support bf16, that excludes GFX10 and GFX11

Collaborator

gfx10* and gfx11* are Navi cards? Is it that the HW does not support bf16, or that our TE ROCm does not support those cards?

Collaborator Author

No BF16 support in HW
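
The rule stated in the updated comment (GFX9.4+, GFX12+ and MI200 support bf16; GFX10 and GFX11 do not) can be sketched as a small predicate. This is only an illustration: `supports_bf16` is a hypothetical helper, `gpu_arch` is assumed to be encoded as major * 10 + minor (as the values 94 and 100 in the hunk suggest), and the actual new `return` expression is not shown in this thread.

```python
def supports_bf16(gpu_arch: int, is_mi200: bool = False) -> bool:
    """Hypothetical sketch; gpu_arch is assumed to be major * 10 + minor."""
    is_gfx94_plus = 94 <= gpu_arch < 100  # gfx94x and gfx95x
    is_gfx12_plus = gpu_arch >= 120       # gfx12x, which includes gfx1250 (125)
    # gfx10x and gfx11x fall between the two ranges, so they are rejected.
    return is_gfx94_plus or is_gfx12_plus or is_mi200
```

With this encoding, 94, 95 and 125 are accepted, while 103 and 110 are rejected unless the MI200 flag is set.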

}

- constexpr int THREADS_PER_WAVEFRONT = 64;
+ constexpr int THREADS_PER_WAVEFRONT = 32;

Collaborator

@wangye805 Did you check if there is any perf decrease on gfx942 or gfx950 if we use THREADS_PER_WAVEFRONT = 32?

Collaborator

Well, it will need testing on real HW. But will you need this change in gfx942/gfx950?

Collaborator Author

For real Gfx1250 HW it should be 32. It is also verified to work on Gfx9*. I changed it unconditionally because otherwise it would require a runtime check on the host side.
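
For readers less familiar with what this constant influences, here is a rough sketch of two quantities that typically depend on a wavefront-size constant: how many wavefront-sized groups cover a thread block, and which lane offsets a tree reduction walks through. Both helper functions are hypothetical and are not taken from the TE sources; they only illustrate the arithmetic.

```python
from typing import List

THREADS_PER_WAVEFRONT = 32  # value after this PR; it was 64 before


def wavefronts_per_block(block_threads: int,
                         wavefront: int = THREADS_PER_WAVEFRONT) -> int:
    """Number of wavefront-sized groups needed to cover a thread block."""
    return (block_threads + wavefront - 1) // wavefront


def shuffle_offsets(wavefront: int = THREADS_PER_WAVEFRONT) -> List[int]:
    """Lane offsets visited by a tree reduction over one group."""
    offsets = []
    offset = wavefront // 2
    while offset > 0:
        offsets.append(offset)
        offset //= 2
    return offsets
```

For a 256-thread block, a group size of 32 gives 8 groups and five reduction steps per group, while 64 gives 4 groups and six steps, which is the kind of difference the perf question above is about.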

if get_device_compute_capability(0) == 95:
"""Return 64 MiB for gfx50x+, 32 MiB for all other architectures."""
"""TODO: gfx1250 WS size requirements"""
if get_device_compute_capability(0) in (95, 125):

Collaborator

Shall we check with hipblasLt? What's the observation on gfx1250?
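
As a point of reference for this question, the selection in the hunk amounts to the following sketch. `workspace_size_bytes` is a hypothetical stand-in for the real function, whose name is not visible here, and the 64 MiB value for gfx1250 is simply what the new condition implies; the TODO in the diff notes that the actual gfx1250 requirement is still open.

```python
MIB = 1024 * 1024


def workspace_size_bytes(compute_capability: int) -> int:
    """Hypothetical sketch of the workspace-size selection in the hunk."""
    # 95 and 125 are the capability values tested by the new condition.
    if compute_capability in (95, 125):
        return 64 * MIB
    return 32 * MIB
```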

if IS_HIP_EXTENSION:
# only MI200 and newer machines support bf16
- if get_device_compute_capability() in [(9, 4), (9, 5)] or is_mi200():
+ if get_device_compute_capability() in ((9, 4), (9, 5), (12, 5)) or is_mi200():

Collaborator

Shall we rename is_mi200 to is_cdna2?

Collaborator Author

Historically it has been is_mi200, and it checks specifically for MI2.0, so the current name better reflects what the method does.

message(WARNING "Python interpreter not found; skipping AITER API validation.")
endif()

# so far, there are only gfx942 and gfx950 v3 kernels

Collaborator

What's the change in this file for?

Collaborator Author

The AITER/CK build bug that limited them to V3 archs only was fixed a while ago, so there is no longer any need to perform this filtering on the TE side. Since AITER might have no V3 kernels for GFX1250, the filtering is removed so that it can be built regardless of V3 support in AITER.

Collaborator

Please update this file, as TE dev has now switched to QoLA.

else:
fa_utils.use_aiter_triton = True
# Setup Flash attention utils
fa_utils.version = PkgVersion("2.7.1") #masqurade as FA 2.7.1

Collaborator

Two questions:
1) Do we plan to switch to aiter flash-attn for gfx942, gfx950 and gfx1250 altogether?
2) If we do, should fa_utils.version still be 2.7.1?

Collaborator Author

This code is run on emu and real HW, so the answer is yes. For extra safety, it might make sense to disable this by default (make the env default 0).
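
The "default 0" suggestion could look roughly like the sketch below. The variable name NVTE_FLASH_ATTN_AITER comes from the PR description; the helper name and the rule that only the value "1" enables the mode are assumptions made for illustration, not the PR's actual parsing.

```python
import os


def use_aiter_triton_flash_attn(env=None) -> bool:
    """Hypothetical opt-in check: disabled unless the variable is set to 1."""
    env = os.environ if env is None else env
    # A default of "0" keeps the AITER Triton path off when the variable is unset.
    return env.get("NVTE_FLASH_ATTN_AITER", "0").strip() == "1"
```

Passing a plain dict as `env` keeps the check easy to exercise in tests without touching the process environment.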

Comment on lines +117 to +120
if(NOT EXISTS "${__AITER_SOURCE_DIR}/hsa")
message(FATAL_ERROR "Cannot find AITER v3 kernels location at ${__AITER_SOURCE_DIR}/hsa.")
continue()
endif()

Contributor

The message(FATAL_ERROR ...) call terminates the CMake configuration, so the continue() is unreachable. Moreover, continue() is only valid inside a foreach / while loop, and here it is at top level, which would make it a CMake error if it were ever reached.

endif()

# so far, there are only gfx942 and gfx950 v3 kernels
SET(V3_ASM_ARCHS_SUPPORTED "gfx942;gfx950")

Contributor

Do we know whether AITER actually generates v3 kernels for gfx1250 without having the V3 support specified directly?

Collaborator Author

It is not there to instruct AITER to build V3 kernels; it reflects what AITER supports. There used to be a bug in AITER where it could only compile for V3-supported archs.

return False, "MXFP8 support is not enabled."
gpu_arch = get_device_compute_capability()
- if gpu_arch == (9, 5):
+ if gpu_arch in ((9, 5), (12, 5)):

Contributor

gfx1250 detection is inconsistent across files:

  • gpu_arch in ((9, 5), (12, 5)) is an exact match
  • examples/jax/encoder/common.py uses gpu_arch >= 120, which matches any gfx12x
  • tests/cpp/operator/test_cublaslt_gemm.cu uses prop.major >= 12, which is the same broad match

If only gfx1250 supports MXFP8/FP8 here, the JAX example and the C++ test will incorrectly enable tests on, for example, gfx1200. Maybe we could pick one canonical predicate and use it everywhere, so that the C++ and Python runtimes accept the same devices.
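
One possible shape for such a canonical predicate on the Python side is sketched below. The names `MXFP8_ARCHS` and `supports_mxfp8` are hypothetical; the tuple itself is the one used in this hunk, and treating (12, 0) as gfx1200 is an assumption about the capability encoding.

```python
from typing import Tuple

# Exact-match list taken from the check in this hunk.
MXFP8_ARCHS = ((9, 5), (12, 5))


def supports_mxfp8(compute_capability: Tuple[int, int]) -> bool:
    """Hypothetical shared predicate: exact match only, no range checks."""
    return tuple(compute_capability) in MXFP8_ARCHS
```

Unlike gpu_arch >= 120 or prop.major >= 12, this rejects (12, 0). The C++ test would need an equivalent exact-match check, since it cannot share the Python helper.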

Comment on lines +108 to +109
except ImportError as e:
pass

Contributor

Maybe it would be better to log the error instead of silently falling through with no diagnostics? Something like attn_log.fa_logger.debug(...)
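
A minimal sketch of that suggestion, using only the standard library: the logger name and the `try_import` helper are hypothetical stand-ins, since the real code would presumably go through attn_log.fa_logger and import a specific module.

```python
import importlib
import logging

# Hypothetical logger; the PR code would use its own attention logger instead.
fa_logger = logging.getLogger("fa_logger_sketch")


def try_import(module_name: str):
    """Import an optional module, logging the failure instead of hiding it."""
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        fa_logger.debug("Optional module %s is unavailable: %s", module_name, e)
        return None
```

Logging at debug level keeps normal runs quiet while still leaving a trace when someone investigates why a backend was not picked up.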

return True, ""
else:
- return False, "Device arch gfx94x or gfx95x required for FP8 execution."
+ return False, "Device arch gfx94x or newer required for FP8 execution."

Contributor

Nit: Technically, gfx10xx and gfx11xx are newer, but they are explicitly excluded in the if statement. Maybe we could just say "Device arch gfx94x, gfx95x or gfx125x required for FP8 execution."
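
One way to keep such a message from drifting away from the supported list is to build it from that list. The sketch below is hypothetical: `FP8_ARCH_FAMILIES` and `fp8_unsupported_message` are illustrative names, and the family names are the ones proposed in this comment.

```python
# Family names as proposed in the review comment above.
FP8_ARCH_FAMILIES = ("gfx94x", "gfx95x", "gfx125x")


def fp8_unsupported_message(families=FP8_ARCH_FAMILIES) -> str:
    """Build the error text from the list of supported arch families."""
    if len(families) == 1:
        listed = families[0]
    else:
        listed = ", ".join(families[:-1]) + " or " + families[-1]
    return f"Device arch {listed} required for FP8 execution."
```

With the default tuple this produces exactly the wording suggested above, and adding a family to the tuple updates the message automatically.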



4 participants