3 changes: 3 additions & 0 deletions .gitmodules
@@ -26,3 +26,6 @@
[submodule "3rdparty/QoLA"]
path = 3rdparty/QoLA
url = https://github.com/Micky774/QoLA.git
[submodule "3rdparty/hipkittens"]
path = 3rdparty/hipkittens
url = https://github.com/HazyResearch/HipKittens.git
1 change: 1 addition & 0 deletions 3rdparty/hipkittens
Submodule hipkittens added at 997005
1 change: 1 addition & 0 deletions ci/pytorch.sh
@@ -61,6 +61,7 @@ run_test_config(){
run 1 test_jit.py
NVTE_ROCM_ENABLE_MXFP8=1 run_default_fa 1 test_multi_tensor.py
run 1 test_numerics.py
NVTE_ROCM_ENABLE_MXFP8=1 run_default_fa_lbl "mxfp8" 1 test_numerics.py -k "recipe0 and 126m and not grouped"
run_default_fa 1 test_permutation.py
run_default_fa 1 test_recipe.py
run 1 test_sanity.py
123 changes: 84 additions & 39 deletions tests/cpp/operator/test_cublaslt_gemm.cu
@@ -30,7 +30,9 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {

std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
{32, 128, 16},
{256, 256, 256},
{768, 3072, 4096},
{4096, 16384, 4096},
};

// A, B, Bias, Gelu, D
@@ -168,6 +170,21 @@ __global__ void compute_ref_kernel(
}


static size_t align256(size_t x) {
return (x + 255) & ~(size_t)255;
}

static size_t compute_mxfp8_workspace_size(size_t m, size_t k, size_t n, bool transa, bool transb, size_t base_size) {
size_t k_iters = k / 128;
size_t scale_k = k / 32;
size_t sa_pk = align256(k_iters * m * 4);
size_t sb_pk = k_iters * n * 4;
size_t needed = align256(sa_pk) + sb_pk;
if (!transa) needed += align256(m * k) + align256(m * scale_k) + align256(sa_pk);
if (transb) needed += align256(n * k) + align256(n * scale_k) + align256(sb_pk);
return std::max(base_size, needed);
}
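// align256 rounds up to the next multiple of 256 bytes (align256(1) == 256).
// Illustrative worked example (not part of the change) for the largest MXFP8 test
// case {m, k, n} = {4096, 16384, 4096} with transa == transb == false:
//   k_iters = 16384 / 128 = 128        scale_k = 16384 / 32 = 512
//   sa_pk   = align256(128 * 4096 * 4) = 2097152
//   sb_pk   =          128 * 4096 * 4  = 2097152
//   needed  = 2097152 + 2097152                                      // 4 MiB
//           + align256(4096 * 16384) + align256(4096 * 512) + sa_pk  // !transa staging
//           = 75497472                                               // 72 MiB
// so this shape needs more than the 64 MiB (67108864 B) hipBLASLt default.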

struct TestParams {
size_t m;
size_t k;
@@ -177,6 +194,7 @@ struct TestParams {
bool transa;
bool transb;
NVTEScalingMode scaling_mode;
bool force_hipblaslt;
};


@@ -341,8 +359,10 @@ void performTest(const TestParams& params) {
const bool has_fp8 = isFp8Type(atype) || isFp8Type(btype);
const bool use_mxfp8 = params.scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING;

if (use_mxfp8)
{
if (!use_mxfp8 && params.force_hipblaslt) {
GTEST_SKIP() << "force_hipblaslt only relevant for MXFP8";
}
if (use_mxfp8) {
if (!has_fp8) {
GTEST_SKIP() << "MXFP8 scaling mode requires Float8 types";
}
@@ -352,6 +372,9 @@ void performTest(const TestParams& params) {
if (params.k % 128) {
GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128";
}
if (!params.force_hipblaslt && (params.m % 256 || params.n % 256 || params.k < 256)) {
GTEST_SKIP() << "HipKittens requires (M%256, N%256, K>=256)";
}
}

cudaDeviceProp prop;
@@ -383,23 +406,16 @@

if (has_fp8)
{
bool fp8_supported = (prop.major == 9 && prop.minor >= 4);
if (!fp8_supported) {
if (prop.major != 9 || prop.minor < 4) {
GTEST_SKIP() << "FP8 is not supported in current config";
}

if (use_mxfp8)
{
bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5);
if (!mxfp8_supported) {
GTEST_SKIP() << "MXFP8 is not supported in current config";
}
if (params.use_bias) {
GTEST_SKIP() << "MXFP8 GEMM with bias is not supported";
}
if (use_mxfp8 && prop.minor < 5) {
GTEST_SKIP() << "MXFP8 is not supported in current config";
}

if (params.use_gelu && !fp8_gelu_fusion_config) {
if (use_mxfp8 && params.use_bias && params.force_hipblaslt) {
GTEST_SKIP() << "MXFP8 GEMM with bias is not supported by hipBLASLt";
}
if (params.use_gelu && !fp8_gelu_fusion_config && (params.force_hipblaslt || !use_mxfp8)) {
GTEST_SKIP() << "FP8 GEMM with GELU is not supported in current config";
}
if (params.use_bias && dtype == DType::kFloat16) {
@@ -409,29 +425,27 @@

if (prop.major == 9 && prop.minor == 5) //gfx950 specific hipblasLt limitations
{
if (isFp8Type(dtype)){
if (isFp8Type(dtype)) {
GTEST_SKIP() << "GEMM with float8 output is not supported";
}
if (params.use_gelu && dtype == DType::kBFloat16) {
if (params.use_gelu && dtype == DType::kBFloat16 && (params.force_hipblaslt || !use_mxfp8)) {
GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config";
}
if constexpr ((std::is_same<A_Type, bf8>::value || std::is_same<B_Type, bf8>::value) &&
std::is_same<D_Type, fp32>::value)
{
//GEMM with bias and fp32 output is not supported with bf8 A/B
if constexpr ((std::is_same_v<A_Type, bf8> || std::is_same_v<B_Type, bf8>) &&
std::is_same_v<D_Type, fp32>) {
if (params.use_bias) {
GTEST_SKIP() << "FP8 GEMM with bias is not supported in current config";
}
}
}
if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations
else if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations
{
#if HIP_VERSION < 70100000
if (params.use_gelu && dtype == DType::kBFloat16 && !params.transa) {
GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config";
}
#endif
if constexpr (std::is_same<D_Type, fp8>::value && std::is_same<Bias_Type, bf16>::value) {
if constexpr (std::is_same_v<D_Type, fp8> && std::is_same_v<Bias_Type, bf16>) {
if (params.use_bias && !fp8_gelu_fusion_config) {
GTEST_SKIP() << "GEMM with BF16 bias and FP8 output is not supported in current config";
}
@@ -490,6 +504,11 @@ void performTest(const TestParams& params) {
if (prop.major == 9 && prop.minor == 5) {
workspace_size = 67108864;
}
if (use_mxfp8) {
workspace_size = compute_mxfp8_workspace_size(params.m, params.k, params.n,
params.transa, params.transb,
workspace_size);
}
#endif
Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte);

@@ -544,13 +563,14 @@ void performTest(const TestParams& params) {
}

auto [atol, rtol] = getTestTolerances(dtype, has_fp8, use_mxfp8);
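// MXFP8 results need not match the fp32 reference bit-for-bit; tolerate roughly
// one mismatched element per million outputs (and at least one) before failing.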
size_t mismatch_limit = use_mxfp8 ? std::max((size_t)1, params.m * params.n / 1'000'000) : 0;
RefD.to_cpu();
compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol);
compareResults("D", D, RefD.rowwise_cpu_dptr<D_Type>(), true, atol, rtol, true, mismatch_limit);

if(params.use_gelu){
auto [atol, rtol] = getTestTolerances(gelu_type, false, false);
auto [atol, rtol] = getTestTolerances(gelu_type, has_fp8, use_mxfp8);
RefPreGeluOut.to_cpu();
compareResults("gelu", pre_gelu_out, RefPreGeluOut.rowwise_cpu_dptr<Gelu_Type>(), true, atol, rtol);
compareResults("gelu", pre_gelu_out, RefPreGeluOut.rowwise_cpu_dptr<Gelu_Type>(), true, atol, rtol, true, mismatch_limit);
}
}

@@ -578,6 +598,15 @@ void performDqTest(const TestParams &params) {
if (!mxfp8_supported) {
GTEST_SKIP() << "MXFP8 is not supported in current config";
}
if (params.use_bias || params.use_gelu) {
if (params.force_hipblaslt) {
GTEST_SKIP() << "MXFP8 GEMM with bias/GELU is not supported by hipBLASLt";
}
GTEST_SKIP() << "DqGEMMTestSuite does not yet have reference for bias/gelu epilogues";
}
if (!params.force_hipblaslt && (params.m % 256 || params.n % 256 || params.k % 128 || params.k < 256)) {
GTEST_SKIP() << "HipKittens requires (M%256, N%256, K>=256)";
}

DType ref_type = dtype;
TShape a_shape = params.transa ? TShape{params.m, params.k} : TShape{params.k, params.m};
@@ -605,7 +634,9 @@ void performDqTest(const TestParams &params) {
Tensor bias;
Tensor pre_gelu_out;

size_t workspace_size = 67108864;
size_t workspace_size = compute_mxfp8_workspace_size(params.m, params.k, params.n,
params.transa, params.transb,
67108864); // 64 MiB required for hipBLASLt
Tensor Workspace("Workspace", TShape{workspace_size}, DType::kByte);

//perform FP8 gemm and copy the output results from GPU memory to CPU memory
@@ -635,6 +666,12 @@ void performDqTest(const TestParams &params) {
#endif // __HIP_PLATFORM_AMD__

#define MAKE_TEST_PARAMS(P_) \
bool force_hipblaslt_ = std::get<5>(GetParam()); \
if (force_hipblaslt_) { \
setenv("NVTE_ROCM_USE_HIPBLASLT_MXFP8", "1", 1); \
} else { \
unsetenv("NVTE_ROCM_USE_HIPBLASLT_MXFP8"); \
} \
TestParams P_ = {.m = std::get<0>(std::get<0>(GetParam())), \
.k = std::get<1>(std::get<0>(GetParam())), \
.n = std::get<2>(std::get<0>(GetParam())), \
@@ -643,13 +680,14 @@ void performDqTest(const TestParams &params) {
.transa = std::get<3>(GetParam()).first, \
.transb = std::get<3>(GetParam()).second, \
.scaling_mode = std::get<4>(GetParam()) \
? NVTEScalingMode::NVTE_MXFP8_1D_SCALING \
: NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING}
? NVTEScalingMode::NVTE_MXFP8_1D_SCALING \
: NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING,\
.force_hipblaslt = force_hipblaslt_}
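// Note: NVTE_ROCM_USE_HIPBLASLT_MXFP8=1 presumably forces MXFP8 GEMMs through the
// hipBLASLt backend; when unset, eligible shapes are expected to dispatch to the
// HipKittens kernels (hence the "HB"/"HK" suffixes in the generated test names).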

// <m, k, n>, use_bias, use_gelu, Layout, fp8_scalinig
// <m, k, n>, use_bias, use_gelu, Layout, fp8_scaling, force_hipblaslt
class GEMMTestSuite
: public ::testing::TestWithParam<
std::tuple<std::tuple<size_t, size_t, size_t>, bool, bool, Layout, NVTEScalingMode>> {};
std::tuple<std::tuple<size_t, size_t, size_t>, bool, bool, Layout, NVTEScalingMode, bool>> {};

#define MAKE_GEMM_TEST(NAME_, A_, B_, BIAS_, GELU_, D_) \
TEST_P(GEMMTestSuite, NAME_) { \
@@ -715,13 +753,15 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, GEMMTestSuite,
::testing::Values(false, true), //use bias
::testing::Values(false, true), //use_gelu
::testing::ValuesIn(kLayouts), //transa,transb
::testing::Values(false, true)), //use mxfp8
::testing::Values(false, true), //use mxfp8
::testing::Values(false, true)), //force hipblaslt
[](const testing::TestParamInfo<GEMMTestSuite::ParamType>& info) {
return MKN(std::get<0>(info.param)) + "x" +
std::to_string(std::get<1>(info.param)) + "x" +
std::to_string(std::get<2>(info.param)) + "x" +
TN(std::get<3>(info.param)) + "x" +
(std::get<4>(info.param) ? "M" : "S");
(std::get<4>(info.param) ? "M" : "S") + "x" +
(std::get<5>(info.param) ? "HB" : "HK");
});

#ifdef __HIP_PLATFORM_AMD__
@@ -740,12 +780,17 @@ MAKE_DQ_GEMM_TEST(Testfp8xfp8xfp16, fp8, fp8, fp16)

INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite,
::testing::Combine(::testing::ValuesIn(test_case_sizes_mxfp8),
::testing::Values(false), // bias - unused
::testing::Values(false), // gelu - unused
::testing::ValuesIn(kLayouts), //transa,transb
::testing::Values(true)), //use mxfp8
::testing::Values(false), // use bias
::testing::Values(false), // use gelu
::testing::ValuesIn(kLayouts), // transa,transb
::testing::Values(true), // use mxfp8
::testing::Values(false, true)), // force hipblaslt
[](const testing::TestParamInfo<DqGEMMTestSuite::ParamType>& info) {
return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
return MKN(std::get<0>(info.param)) + "x" +
std::to_string(std::get<1>(info.param)) + "x" +
std::to_string(std::get<2>(info.param)) + "x" +
TN(std::get<3>(info.param)) + "x" +
(std::get<5>(info.param) ? "HB" : "HK");
});

TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) {
2 changes: 1 addition & 1 deletion tests/jax/test_custom_call_compute.py
@@ -62,7 +62,7 @@
]
TEST_SHAPES = [(64, 32, 64)]
if is_hip_extension():
TEST_SHAPES += [(64, 64, 128), (128, 256, 256)]
TEST_SHAPES += [(64, 64, 128), (128, 256, 256), (256, 256, 256)]
jnp_float8_e4m3_type = get_jnp_float8_e4m3_type()
jnp_float8_e5m2_type = get_jnp_float8_e5m2_type()

5 changes: 3 additions & 2 deletions tests/jax/utils.py
@@ -61,8 +61,9 @@ def _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias=False):
pytest.skip(
f"Input shape {(m, k)} x {(k, n)} is not supported by hipblaslt MXFP8 GEMM."
)
if use_bias:
pytest.skip("hipblaslt GEMM does not yet support MXFP8 with bias.")
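# Note: shapes eligible for the HipKittens kernels are assumed to handle MXFP8
# with bias; only the hipBLASLt fallback lacks it, so skip just when that
# fallback would be taken.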
hipkittens_eligible = (m % 256 == 0) and (n % 256 == 0) and (k >= 256)
if use_bias and not hipkittens_eligible:
pytest.skip("hipblaslt GEMM does not support MXFP8 with bias.")
else:
jax_version = version.parse(jax.__version__)
if jax_version < version.parse("0.8.2"):
1 change: 0 additions & 1 deletion tests/pytorch/test_numerics.py
@@ -778,7 +778,6 @@ def test_gpt_full_activation_recompute(
if (dtype == torch.bfloat16
and not fp8
and not use_reentrant
and recipe.float8_per_tensor_scaling()
):
pytest.skip("hipBLASLt does not provide suitable algorithms on GFX950 for this config.")
if fp8 and recipe.nvfp4():
22 changes: 22 additions & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -9,6 +9,7 @@ cmake_minimum_required(VERSION 3.21)
option(USE_ROCM "Use ROCm" ON)
option(USE_FUSED_ATTN_AOTRITON "Use aotriton backend" ON)
option(USE_FUSED_ATTN_CK "Use ck backend" ON)
option(USE_HIPKITTENS_GEMM "Use HipKittens MXFP8 GEMM kernels" ON)
set(USE_CUDA OFF)

if (USE_ROCM)
@@ -452,6 +453,23 @@ else()
add_subdirectory(ck_fused_attn ${CMAKE_CURRENT_BINARY_DIR}/ck_fused_attn)
endif()

if(USE_HIPKITTENS_GEMM)
list(FIND CMAKE_HIP_ARCHITECTURES "gfx950" _gfx950_index)
if(_gfx950_index EQUAL -1)
message(STATUS "HipKittens GEMM disabled (gfx950 not in CMAKE_HIP_ARCHITECTURES)")
set(USE_HIPKITTENS_GEMM OFF)
else()
include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("-std=c++20" HAS_CXX20)
if(HAS_CXX20)
add_subdirectory(gemm/kittens ${CMAKE_CURRENT_BINARY_DIR}/kittens)
else()
message(WARNING "HipKittens GEMMs require C++20")
set(USE_HIPKITTENS_GEMM OFF)
endif()
endif()
endif()
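# Note (illustrative): the HipKittens GEMM library is built only when gfx950 is
# among the requested targets, e.g. when configuring with -DCMAKE_HIP_ARCHITECTURES=gfx950.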

find_package(hip)
list(APPEND transformer_engine_LINKER_LIBS hip::host hip::device roctx64)
find_package(hiprtc)
@@ -466,6 +484,10 @@ else()
target_compile_definitions(transformer_engine PUBLIC USE_FUSED_ATTN_CK)
list(APPEND transformer_engine_LINKER_LIBS ck_fused_attn)
endif()
if(USE_HIPKITTENS_GEMM)
target_compile_definitions(transformer_engine PUBLIC USE_HIPKITTENS_GEMM)
list(APPEND transformer_engine_LINKER_LIBS kittens_gemm)
endif()
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
endif()

35 changes: 35 additions & 0 deletions transformer_engine/common/gemm/kittens/CMakeLists.txt
@@ -0,0 +1,35 @@
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# License for AMD contributions = MIT. See LICENSE for more information

cmake_minimum_required(VERSION 3.21)
set(CMAKE_CXX_STANDARD 20)
project(kittens_gemm LANGUAGES HIP CXX)

set(HIPKITTENS_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../../../3rdparty/hipkittens/include")
if(NOT EXISTS "${HIPKITTENS_INCLUDE_DIR}/kittens.cuh")
message(FATAL_ERROR
"Could not find HipKittens headers at ${HIPKITTENS_INCLUDE_DIR}. "
"Try running 'git submodule update --init --recursive' "
"within the Transformer Engine source.")
endif()

set(kittens_gemm_SOURCES
mxfp8_gemm.hip)

add_library(kittens_gemm SHARED ${kittens_gemm_SOURCES})

set(KITTENS_GEMM_COMPILE_OPTIONS
--offload-arch=gfx950
-DKITTENS_CDNA4
-fno-gpu-rdc
-O3)

find_package(hip)
target_include_directories(kittens_gemm PRIVATE ${HIPKITTENS_INCLUDE_DIR})
target_include_directories(kittens_gemm PRIVATE ${HIP_INCLUDE_DIRS})
target_link_libraries(kittens_gemm PUBLIC hip::host hip::device)
target_compile_options(kittens_gemm PRIVATE ${KITTENS_GEMM_COMPILE_OPTIONS})

install(TARGETS kittens_gemm
DESTINATION ${CMAKE_INSTALL_PREFIX}/transformer_engine/lib)