Conversation
| if is_hip_extension(): | ||
| # only GFX9.4+ and MI200 machines support bf16 | ||
| return 100 > gpu_arch >= 94 or is_mi200() | ||
| # only GFX9.4+, GFX12+ and MI200 machines support bf16, that excludes GFX10 and GFX11 |
Are gfx10* and gfx11* Navi cards? Is it that the HW does not support bf16, or that our TE ROCm does not support those cards?
No BF16 support in HW
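A minimal sketch of the check that the updated comment describes, for reference. The function name `supports_bf16` and the encoding of `gpu_arch` as major * 10 + minor (94 for gfx94x, 125 for gfx125x) are assumptions for illustration; the actual new condition is not shown in this thread.

```python
def supports_bf16(gpu_arch: int, is_mi200: bool = False) -> bool:
    """Hypothetical helper mirroring the comment in the diff.

    gpu_arch is assumed to be major * 10 + minor, so gfx10 maps to 100..109
    and gfx11 maps to 110..119.
    """
    if is_mi200:
        # MI200 is accepted regardless of the numeric arch value.
        return True
    # gfx9.4+ within the gfx9 family, or gfx12+. gfx10 and gfx11 fall in the
    # gap between the two ranges because the HW has no bf16 support.
    return 94 <= gpu_arch < 100 or gpu_arch >= 120
```

With this shape, `supports_bf16(110)` is rejected while `supports_bf16(125)` is accepted, which matches the intent stated in the comment.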
| } | ||
|
|
||
| constexpr int THREADS_PER_WAVEFRONT = 64; | ||
| constexpr int THREADS_PER_WAVEFRONT = 32; |
@wangye805 Did you check if there is any perf decrease on gfx942 or gfx950 if we use THREADS_PER_WAVEFRONT = 32?
Well, it will need testing on real HW. But will you need this change in gfx942/gfx950?
For real gfx1250 HW it should be 32. It has also been verified to work on gfx9*. I changed it unconditionally because otherwise it would require a runtime check on the host side.
| if get_device_compute_capability(0) == 95: | ||
| """Return 64 MiB for gfx50x+, 32 MiB for all other architectures.""" | ||
| """TODO: gfx1250 WS size requirements""" | ||
| if get_device_compute_capability(0) in (95, 125): |
Shall we check with hipblasLt? What's the observation on gfx1250?
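For context, the selection in the diff above can be sketched as below. The names `workspace_size_bytes` and `LARGE_WORKSPACE_ARCHS` are hypothetical, and whether 64 MiB is actually the right size for gfx1250 is the open question raised here (the diff itself carries a TODO for it).

```python
MIB = 1024 * 1024

# Architectures given the larger workspace in the diff: gfx950 (95) and,
# provisionally, gfx1250 (125). The gfx1250 entry is still marked TODO.
LARGE_WORKSPACE_ARCHS = (95, 125)


def workspace_size_bytes(compute_capability: int) -> int:
    """Hypothetical helper: 64 MiB for the listed archs, 32 MiB otherwise."""
    if compute_capability in LARGE_WORKSPACE_ARCHS:
        return 64 * MIB
    return 32 * MIB
```

Keeping the arch list in one tuple would make it a one-line change if hipblasLt turns out to need a different size on gfx1250.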
| if IS_HIP_EXTENSION: | ||
| # only MI200 and newer machines support bf16 | ||
| if get_device_compute_capability() in [(9, 4), (9, 5)] or is_mi200(): | ||
| if get_device_compute_capability() in ((9, 4), (9, 5), (12, 5)) or is_mi200(): |
Shall we rename is_mi200 to is_cdna2?
It has historically been is_mi200 and it checks specifically for MI2.0, so the current name better reflects what the method does.
| message(WARNING "Python interpreter not found; skipping AITER API validation.") | ||
| endif() | ||
|
|
||
| # so far, there are only gfx942 and gfx950 v3 kernels |
What's the change in this file for?
The AITER/CK build bug that limited them to V3 archs only was fixed a while ago, so there is no longer any need to perform this filtering on the TE side. Since AITER might have no V3 kernels for gfx1250, the filtering is removed so that TE can be built regardless of V3 support in AITER.
Please update this file, as TE dev has now switched to QoLA.
| else: | ||
| fa_utils.use_aiter_triton = True | ||
| # Setup Flash attention utils | ||
| fa_utils.version = PkgVersion("2.7.1") # masquerade as FA 2.7.1 |
Two questions:
1. Do we plan to switch to aiter flash-attn for gfx942/gfx950/gfx1250 altogether?
2. If we do, should fa_utils.version still be 2.7.1?
This code runs on both the emulator and real HW, so the answer is yes. For extra safety it might make sense to disable this by default (make the env default 0).
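A minimal sketch of the "env default 0" suggestion. The variable name `EXAMPLE_USE_AITER_FA` is a placeholder, not the name used in TE, and the helper `aiter_fa_enabled` is hypothetical.

```python
import os


def aiter_fa_enabled(env=None, var_name="EXAMPLE_USE_AITER_FA"):
    """Hypothetical opt-in switch: the feature stays off unless the env var is "1".

    `env` can be any mapping, which makes the default behaviour easy to test;
    when omitted, the real process environment is used.
    """
    env = os.environ if env is None else env
    # Defaulting to "0" means an unset variable leaves the feature disabled.
    return env.get(var_name, "0") == "1"
```

With an opt-in default, existing gfx942/gfx950 setups keep their current attention path unless the variable is set explicitly.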
| if(NOT EXISTS "${__AITER_SOURCE_DIR}/hsa") | ||
| message(FATAL_ERROR "Cannot find AITER v3 kernels location at ${__AITER_SOURCE_DIR}/hsa.") | ||
| continue() | ||
| endif() |
The message(FATAL_ERROR ...) call terminates the CMake configuration, so the continue() is unreachable. Moreover, continue() is only valid inside a foreach / while loop, and here it is at top level, which would make it a CMake error if it were reached.
| endif() | ||
|
|
||
| # so far, there are only gfx942 and gfx950 v3 kernels | ||
| SET(V3_ASM_ARCHS_SUPPORTED "gfx942;gfx950") |
Do we know whether AITER actually generates v3 kernels for gfx1250 without V3 support being specified directly?
It is not there to instruct AITER to build V3 kernels; it reflects what AITER supports. There used to be a bug in AITER where it could only compile for V3-supported archs,
| return False, "MXFP8 support is not enabled." | ||
| gpu_arch = get_device_compute_capability() | ||
| if gpu_arch == (9, 5): | ||
| if gpu_arch in ((9, 5), (12, 5)): |
gfx1250 detection is inconsistent across files:
- gpu_arch in ((9, 5), (12, 5)) here is an exact match
- examples/jax/encoder/common.py: gpu_arch >= 120 matches any gfx12x
- tests/cpp/operator/test_cublaslt_gemm.cu: prop.major >= 12 has the same broad match
If only gfx1250 supports MXFP8/FP8 here, the JAX example and the C++ test will incorrectly enable tests on gfx1200, for example. Maybe we could pick one canonical predicate and use it everywhere, so that the C++ and Python runtimes accept the same set of architectures.
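One possible shape for such a canonical predicate on the Python side, as a sketch only. The names `MXFP8_ARCHS`, `normalize_arch` and `is_mxfp8_arch` are hypothetical, and the arch set is copied from the exact-match check in this diff rather than from any hardware documentation.

```python
# Exact (major, minor) pairs taken from the check in this diff.
MXFP8_ARCHS = frozenset({(9, 5), (12, 5)})


def normalize_arch(arch):
    """Accept either a (major, minor) pair or a major * 10 + minor integer."""
    if isinstance(arch, int):
        # 125 -> (12, 5), 95 -> (9, 5), 120 -> (12, 0)
        return divmod(arch, 10)
    return tuple(arch)


def is_mxfp8_arch(arch) -> bool:
    """Exact-match predicate; it does not accept every gfx12x device."""
    return normalize_arch(arch) in MXFP8_ARCHS
```

Both call styles seen in this PR then give the same answer: `is_mxfp8_arch((12, 5))` and `is_mxfp8_arch(125)` are accepted, while `is_mxfp8_arch(120)` is rejected. The C++ test would need an equivalent exact comparison on prop.major and prop.minor.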
| except ImportError as e: | ||
| pass |
Maybe it would be better to log the error instead of silently falling through with no diagnostics? Something like attn_log.fa_logger.debug(...)
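A sketch of what that could look like, using only the standard library. The logger name `example.attention` and the helper `try_import` are placeholders; in TE the existing attn_log.fa_logger would presumably be used instead.

```python
import importlib
import logging

# Placeholder logger; stands in for the project's own attention logger.
logger = logging.getLogger("example.attention")


def try_import(module_name: str):
    """Import an optional module, logging the reason if it is unavailable."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        # Debug level keeps normal runs quiet but leaves a trace when
        # someone is investigating why a backend was not picked up.
        logger.debug("Optional module %s is unavailable: %s", module_name, exc)
        return None
```

The fall-through behaviour is unchanged (the caller still gets None), but the import failure becomes visible once debug logging is enabled.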
| return True, "" | ||
| else: | ||
| return False, "Device arch gfx94x or gfx95x required for FP8 execution." | ||
| return False, "Device arch gfx94x or newer required for FP8 execution." |
Nit: Technically, gfx10xx and gfx11xx are newer, but they are explicitly excluded in the if statement. Maybe we could just say "Device arch gfx94x, gfx95x or gfx125x required for FP8 execution."
Description
Port gfx1250 support changes from NPI branch.
It only includes changes in TE itself, with no attention backend updates.