Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit. Hold shift + click to select a range.
fe9a66c
[ROCm] port tdm to npi_gfx1250
wangye805 Apr 22, 2026
40f3902
[ROCm] address first round reviewer comments
wangye805 Apr 23, 2026
cb421f8
[ROCm] address reviewer comments
wangye805 Apr 23, 2026
0e2d24d
[ROCm] address more reviewer comments
wangye805 Apr 23, 2026
0a13bbc
[ROCm] Address TDM review comments: remove extra params, add explanat…
wangye805 Apr 23, 2026
bfb7199
[ROCm] Address remaining review comments and enable TDM flow in CI gt…
wangye805 Apr 23, 2026
ba4bbb7
tdm: clamp tensorDim to avoid uint32_t underflow on OOB prefetch tiles
wangye805 Apr 24, 2026
ab77fbf
tdm: add HIPTensorMap descriptor struct; revert TDM from rocm_* kernels
wangye805 Apr 25, 2026
a0a60fe
tdm: fully revert rocm_*.cuh to branch-point state
wangye805 Apr 25, 2026
506d78c
tdm: extract ROCm flow into separate rocm_* launchers; TDM stays in m…
wangye805 Apr 25, 2026
acc7e4f
tdm: address review comments for cast_gated_kernels.cuh
wangye805 Apr 26, 2026
3c85101
tdm: revert swizzled_* lines to NV upstream position
wangye805 Apr 26, 2026
0007c88
tdm: fix cast_mxfp8_gated to match NV upstream structure
wangye805 Apr 26, 2026
bacd226
tdm: use switch(scaling_type) for AMD TDM mxfp8 gated dispatch
wangye805 Apr 26, 2026
4ba5883
tdm: hoist shared next-stage offset vars above #ifdef in cast_mxfp8_g…
wangye805 Apr 26, 2026
7dbf218
tdm: hoist shared shmem computation above #ifdef in cast_fp8_gated
wangye805 Apr 26, 2026
293d970
tdm: collapse duplicate switch(scaling_type) blocks in cast_mxfp8_gated
wangye805 Apr 26, 2026
53b5d22
tdm: remove tma_flow namespace; prefix ROCm-specific constants with R…
wangye805 Apr 26, 2026
a0a9ab6
util: apply ROCM_ prefix to ROCm-specific constants in cast/dequantiz…
wangye805 Apr 26, 2026
fdb4b1a
util: address PR review comments on cast_kernels.cuh and dequantize_k…
wangye805 Apr 26, 2026
a732d35
util: address 4 more PR review comments on cast/dequantize kernels
wangye805 Apr 26, 2026
1a004df
util: hoist shared next-iter offset vars above #ifndef in cast_mxfp8_…
wangye805 Apr 26, 2026
573f8d7
Revert " Remove padding from scales for hipBLASlt calls (#442)"
wangye805 Apr 26, 2026
fec2de5
fix(rocm): correct double-prefixed constants in rocm_cast_gated_kerne…
wangye805 Apr 26, 2026
004d59f
fix(rocm): add TMA_SHMEM_ALIGNMENT alias and sigmoidf for AMD compila…
wangye805 Apr 26, 2026
7c86c98
fix(rocm): fix fp8_quantize AMD flow — remove unavailable fp8_quantiz…
wangye805 Apr 26, 2026
0b40533
fix(rocm): route NVTE_MXFP8_1D_SCALING through fp8_quantize_rocm on AMD
wangye805 Apr 27, 2026
c89b5ff
fix(rocm): use padded scales_stride in rocm_mxfp8_dequantize
wangye805 Apr 27, 2026
9f55a8b
fix(rocm): guard TDM flow dispatch behind __gfx1250__ on AMD
wangye805 Apr 27, 2026
8338725
fix(rocm): wire up cast_mxfp8_2D_kernel launch on gfx1250 TDM path
wangye805 Apr 27, 2026
14329d5
refactor(rocm): consolidate mxfp8_quantize kernel launch for TDM and TMA
wangye805 Apr 27, 2026
d38c6bd
fix(amd): guard cudaFuncSetAttribute and add hip_bfloat16 overloads f…
wangye805 Apr 27, 2026
14a1dab
chore: remove debug print statements from MXFP8 cast/dequantize kernels
wangye805 Apr 29, 2026
198495a
chore: restore launcher debug prints, remove only in-kernel printf st…
wangye805 Apr 29, 2026
362ae53
test: add 16384x16384 matrix size to CastMXFP8_GatedAct benchmark run
wangye805 Apr 29, 2026
0456492
feat: migrate benchmarks/cpp/cast from dev branch
wangye805 Apr 29, 2026
573f6ea
build: add rocm_utils.cmake needed by benchmarks/cpp CMakeLists
wangye805 Apr 29, 2026
9f340a1
fix: suppress clang warnings in Google Benchmark for gfx1250 toolchain
wangye805 Apr 29, 2026
09ed78c
test: remove 16384x16384 from gated swiglu test (causes CPU ref hang)
wangye805 Apr 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions benchmarks/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
cmake_minimum_required(VERSION 3.18)

# The compiler must be selected before project() enables the CXX/HIP languages;
# default to hipcc, but respect an explicit -DCMAKE_CXX_COMPILER from the user.
if(NOT DEFINED CMAKE_CXX_COMPILER)
set(CMAKE_CXX_COMPILER hipcc)
endif()

# Shared ROCm build helpers from the main project tree (pulled in before
# project() — NOTE(review): presumably it configures toolchain/arch settings
# that must exist at language-enable time; confirm in rocm_utils.cmake).
include("${CMAKE_CURRENT_SOURCE_DIR}/../../build_tools/rocm_utils.cmake")

project(transformer_engine_benchmarks LANGUAGES CXX HIP)

# C++17, strictly (no fallback to an older standard).
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(HIP REQUIRED)

include(FetchContent)

# Third-party dependencies fetched at configure time:
#   - GoogleTest (assertion helpers used by the shared tests/cpp sources)
#   - Google Benchmark (the timing harness itself)
FetchContent_Declare(
    googletest
    GIT_REPOSITORY https://github.com/google/googletest.git
    GIT_TAG v1.14.0
)
FetchContent_Declare(
    benchmark
    GIT_REPOSITORY https://github.com/google/benchmark.git
    GIT_TAG v1.8.3
)

# Trim what the subprojects build: no gmock, no gtest install rules, and none
# of benchmark's own test suites.
set(BUILD_GMOCK OFF CACHE BOOL "" FORCE)
set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark tests" FORCE)
set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "Disable gtest in benchmark" FORCE)

FetchContent_MakeAvailable(googletest benchmark)

# Suppress clang warnings from benchmark headers that fire under -pedantic-errors
# on newer ROCm toolchains (gfx1250): __COUNTER__ classified as C2y extension,
# and kDefaultMinTimeStr triggers -Wunused-const-variable when benchmark.h is
# compiled as a standalone TU by hipcc.
set(_bench_targets benchmark benchmark_main)
set_target_properties(${_bench_targets} PROPERTIES LINKER_LANGUAGE CXX)
foreach(_tgt IN LISTS _bench_targets)
    target_compile_options(${_tgt} PRIVATE
        -Wno-c2y-extensions
        -Wno-unused-const-variable
    )
endforeach()

# Shared C++ test sources (test_common etc.) are reused from tests/cpp.
set(TESTS_CPP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../tests/cpp)

# Run the project's hipify step over the test sources — presumably generates
# the .hip translations referenced below (e.g. test_common.hip); confirm in
# build_tools/hipify/hipify.cmake.
include("${CMAKE_CURRENT_SOURCE_DIR}/../../build_tools/hipify/hipify.cmake")
TE_Hipify(${TESTS_CPP_DIR})

# NOTE(review): directory-scoped include_directories leaks into every target
# defined below (including the FetchContent subprojects if re-ordered);
# target_include_directories inside add_te_benchmark would be the modern,
# target-scoped equivalent.
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine/common/include
${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine/common
${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine
${CMAKE_CURRENT_SOURCE_DIR}/utils
${TESTS_CPP_DIR}
)

# Compile options applied per-target by add_te_benchmark (not globally).
set(COMMON_COMPILE_OPTIONS
-Wall
-Wextra
-O3
-DNDEBUG
-DUSE_ROCM
)

# Locate the prebuilt TransformerEngine shared library. First pass searches
# only the expected in-tree/build/install locations (NO_DEFAULT_PATH keeps
# system paths out of this pass).
find_library(TRANSFORMER_ENGINE_LIB
NAMES transformer_engine
PATHS ${CMAKE_CURRENT_SOURCE_DIR}/../..
${CMAKE_CURRENT_SOURCE_DIR}/../../build/cmake
${CMAKE_CURRENT_SOURCE_DIR}/../../build/lib
/usr/local/lib
$ENV{HOME}/.local/lib
NO_DEFAULT_PATH
)

# Fallback: if the first pass failed, the cache entry holds *-NOTFOUND, and
# find_library repeats the search in that case — so this second call does
# re-search, now including the default system paths.
if(NOT TRANSFORMER_ENGINE_LIB)
message(WARNING "TransformerEngine library not found in expected paths. Trying system paths...")
find_library(TRANSFORMER_ENGINE_LIB NAMES transformer_engine)
endif()

# Hard-fail with build instructions when the library is missing entirely.
if(TRANSFORMER_ENGINE_LIB)
message(STATUS "Found TransformerEngine library: ${TRANSFORMER_ENGINE_LIB}")
else()
message(FATAL_ERROR "TransformerEngine library not found. Please build TransformerEngine first:\n"
" cd ${CMAKE_CURRENT_SOURCE_DIR}/../..\n"
" pip install -e . --no-build-isolation\n"
"Searched paths:\n"
" ${CMAKE_CURRENT_SOURCE_DIR}/../..\n"
" ${CMAKE_CURRENT_SOURCE_DIR}/../../build/cmake\n"
" ${CMAKE_CURRENT_SOURCE_DIR}/../../build/lib")
endif()

# add_te_benchmark(<target> <source>)
# Defines one benchmark executable built from <source> plus the shared
# tests/cpp/test_common.hip helpers, with the common flags, the
# NVTE_ROCM_BENCHMARK define, and the benchmark/gtest/TE/hiprand libraries.
function(add_te_benchmark TARGET_NAME SOURCE_FILE)
    add_executable(${TARGET_NAME}
        ${SOURCE_FILE}
        ${TESTS_CPP_DIR}/test_common.hip
    )
    target_compile_options(${TARGET_NAME} PRIVATE ${COMMON_COMPILE_OPTIONS})
    target_compile_definitions(${TARGET_NAME} PRIVATE NVTE_ROCM_BENCHMARK)
    target_link_libraries(${TARGET_NAME}
        PRIVATE
            benchmark::benchmark
            GTest::gtest
            ${TRANSFORMER_ENGINE_LIB}
            hiprand
    )
endfunction()

# One executable per benchmark suite under cast/.
add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp)
add_te_benchmark(bench_dequantize_mxfp8 cast/bench_dequantize_mxfp8.cpp)
add_te_benchmark(bench_gated_mxfp8 cast/bench_gated_mxfp8.cpp)
add_te_benchmark(bench_casttranspose cast/bench_casttranspose.cpp)
269 changes: 269 additions & 0 deletions benchmarks/cpp/cast/bench_casttranspose.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
/*************************************************************************
* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/

#include <benchmark/benchmark.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#include "amd_detail/hip_float8.h"

#include "benchmark_utils.h"

#include "transformer_engine/cast_hip.h"
#include "transformer_engine/transpose_hip.h"
#include "transformer_engine/transformer_engine_hip.h"

// #define NVTE_ROCM_EXTENDED_BENCHMARKS 1

using namespace te_bench;
using namespace transformer_engine;
// FP8 E4M3 element type used for the output-byte accounting below.
using fp8_e4m3 = test::fp8e4m3;

// GPT-OSS dense-layer shapes ({rows, cols} passed as benchmark args).
#define GPT_OSS_COMMON_SHAPES \
->Args({2880, 2880}) \
->Args({2880, 4096}) \
->Args({5120, 2880}) \
->Args({5760, 2880}) \
->Args({16384, 2880}) \
->Args({16384, 4096}) \
->Args({16384, 5120})

// GPT-OSS MoE per-expert shapes (hidden=2880, intermediate=5760)
#define GPT_OSS_MOE \
->Args({64, 2880}) \
->Args({256, 2880}) \
->Args({320, 2880}) \
->Args({496, 2880}) \
->Args({1792, 2880}) \
->Args({64, 5760}) \
->Args({256, 5760}) \
->Args({320, 5760}) \
->Args({496, 5760}) \
->Args({1792, 5760})

// Tensor shapes from LLaMA (8B, 70B, 405B) and Qwen (7B, 72B)
#define COMMON_SHAPES \
->Args({1024, 3584}) \
->Args({1024, 4096}) \
->Args({1024, 8192}) \
->Args({1024, 14336}) \
->Args({1024, 18944}) \
->Args({2048, 4096}) \
->Args({2048, 8192}) \
->Args({2048, 14336}) \
->Args({2048, 28672}) \
->Args({2048, 29568}) \
->Args({4096, 4096}) \
->Args({4096, 8192}) \
->Args({4096, 16384}) \
->Args({4096, 14336}) \
->Args({4096, 28672}) \
->Args({8192, 8192}) \
->Args({8192, 16384}) \
->Args({8192, 28672}) \
->Args({8192, 29568}) \
->Args({8192, 53248}) \
->Args({16384, 8192}) \
->Args({16384, 16384}) \
->Args({16384, 28672}) \
->Args({32768, 8192}) \
->Args({32768, 16384})

// Only used for specific benchmarks (older models, special cases, etc);
// registered only when NVTE_ROCM_EXTENDED_BENCHMARKS is defined.
#define EXTENDED_SHAPES \
->Args({2048, 12288}) \
->Args({256, 65536}) \
->Args({65536, 128}) \
->Args({1600, 1600}) \
->Args({1600, 6400}) \
->Args({4800, 1600}) \
->Args({56320 , 1600}) \
->Args({6400, 1600}) \
->Args({128256, 4096}) \
->Args({24576, 128256}) \
->Args({24576, 4096}) \
->Args({24576, 5120}) \
->Args({28672, 4096}) \
->Args({4096, 12288}) \
->Args({5120, 4096}) \
->Args({10240, 8192}) \
->Args({128256, 8192}) \
->Args({57344, 10240}) \
->Args({57344, 128256}) \
->Args({57344, 8192}) \
->Args({32000, 4096}) \
->Args({32768, 32000}) \
->Args({32768, 4096}) \
->Args({32768, 5120}) \
->Args({3072, 1024}) \
->Args({24576, 1024}) \
->Args({4096, 1024})





// Benchmarks nvte_quantize as a pure cast to FP8 E4M3 (no transpose output).
// Each iteration is timed with HIP events on a dedicated stream and reported
// via SetIterationTime (benchmarks are registered with UseManualTime()).
template <typename IType>
static void BM_CastOnly(benchmark::State &state) {
  const size_t n_rows = state.range(0);
  const size_t n_cols = state.range(1);
  std::vector<size_t> dims = {n_rows, n_cols};

  // Map the template input type onto the TE runtime dtype enum.
  DType in_dtype;
  if (std::is_same_v<IType, float>) {
    in_dtype = DType::kFloat32;
  } else if (std::is_same_v<IType, hip_bfloat16>) {
    in_dtype = DType::kBFloat16;
  } else {
    in_dtype = DType::kFloat16;
  }

  test::Tensor &input = TensorCache::get_or_create(
      "cast_input", dims, in_dtype, true, false, NVTE_DELAYED_TENSOR_SCALING, true);
  test::Tensor &output = TensorCache::get_or_create(
      "cast_output", dims, DType::kFloat8E4M3, true, false, NVTE_DELAYED_TENSOR_SCALING, false);
  output.set_scale(1.0f);

  hipStream_t bench_stream;
  HIP_CHECK(hipStreamCreate(&bench_stream));
  hipEvent_t ev_begin, ev_end;
  HIP_CHECK(hipEventCreate(&ev_begin));
  HIP_CHECK(hipEventCreate(&ev_end));

  // Untimed call to trigger any RTC compilation before measurement.
  nvte_quantize(input.data(), output.data(), bench_stream);
  warmup_gpu();

  for (auto _ : state) {
    HIP_CHECK(hipEventRecord(ev_begin, bench_stream));
    nvte_quantize(input.data(), output.data(), bench_stream);
    HIP_CHECK(hipEventRecord(ev_end, bench_stream));
    HIP_CHECK(hipEventSynchronize(ev_end));

    float elapsed_ms = 0;
    HIP_CHECK(hipEventElapsedTime(&elapsed_ms, ev_begin, ev_end));
    state.SetIterationTime(elapsed_ms / 1000.0);  // event time is ms; SetIterationTime wants seconds
  }

  HIP_CHECK(hipEventDestroy(ev_begin));
  HIP_CHECK(hipEventDestroy(ev_end));

  // Traffic model: one full read of the input plus one FP8 write of the output.
  const size_t read_bytes = n_rows * n_cols * sizeof(IType);
  const size_t write_bytes = n_rows * n_cols * sizeof(fp8_e4m3);
  set_bytes_processed(state, read_bytes + write_bytes);

  HIP_CHECK(hipStreamDestroy(bench_stream));
}

// Benchmarks nvte_quantize producing both the FP8 cast and its transpose
// (the cached output tensor is created with an extra "true" flag vs.
// BM_CastOnly — presumably "allocate columnwise/transpose data"; confirm
// against TensorCache::get_or_create). Timing uses HIP events on a private
// stream; results feed SetIterationTime for UseManualTime() benchmarks.
template <typename IType>
static void BM_CastTranspose(benchmark::State &state) {
const size_t rows = state.range(0);
const size_t cols = state.range(1);
std::vector<size_t> shape = {rows, cols};

// Map the template input type onto the TE runtime dtype enum;
// anything that is not float/hip_bfloat16 falls back to FP16.
DType itype = std::is_same_v<IType, float> ? DType::kFloat32 :
std::is_same_v<IType, hip_bfloat16> ? DType::kBFloat16 :
DType::kFloat16;

test::Tensor &input = TensorCache::get_or_create(
"ct_input", shape, itype, true, false, NVTE_DELAYED_TENSOR_SCALING, true);
test::Tensor &output = TensorCache::get_or_create(
"ct_output", shape, DType::kFloat8E4M3, true, true, NVTE_DELAYED_TENSOR_SCALING, false);

output.set_scale(1.0f);

hipStream_t stream;
HIP_CHECK(hipStreamCreate(&stream));

hipEvent_t start, stop;
HIP_CHECK(hipEventCreate(&start));
HIP_CHECK(hipEventCreate(&stop));

// Untimed call to trigger any RTC compilation before measurement
nvte_quantize(input.data(), output.data(), stream);
warmup_gpu();

for (auto _ : state) {
HIP_CHECK(hipEventRecord(start, stream));
nvte_quantize(input.data(), output.data(), stream);
HIP_CHECK(hipEventRecord(stop, stream));
HIP_CHECK(hipEventSynchronize(stop));

float ms = 0;
HIP_CHECK(hipEventElapsedTime(&ms, start, stop));
// Event time is milliseconds; SetIterationTime expects seconds.
state.SetIterationTime(ms / 1000.0);
}

HIP_CHECK(hipEventDestroy(start));
HIP_CHECK(hipEventDestroy(stop));

// Traffic model: one read of the input, two FP8 writes
// (cast output plus its transpose).
const size_t bytes_read = rows * cols * sizeof(IType);
const size_t bytes_write = rows * cols * sizeof(fp8_e4m3) * 2;
set_bytes_processed(state, bytes_read + bytes_write);

HIP_CHECK(hipStreamDestroy(stream));
}

// Registers BM_CastOnly for one input type across the three standard shape
// suites (gpt_oss, gpt_oss_moe, llm). INAME is the human-readable type tag
// embedded in the benchmark name.
#define REGISTER_CAST_ONLY(ITYPE, INAME) \
BENCHMARK_TEMPLATE(BM_CastOnly, ITYPE) \
->Name("BM_CastOnly/" INAME "_E4M3/gpt_oss") \
GPT_OSS_COMMON_SHAPES \
->Unit(benchmark::kMicrosecond) \
->UseManualTime(); \
BENCHMARK_TEMPLATE(BM_CastOnly, ITYPE) \
->Name("BM_CastOnly/" INAME "_E4M3/gpt_oss_moe") \
GPT_OSS_MOE \
->Unit(benchmark::kMicrosecond) \
->UseManualTime(); \
BENCHMARK_TEMPLATE(BM_CastOnly, ITYPE) \
->Name("BM_CastOnly/" INAME "_E4M3/llm") \
COMMON_SHAPES \
->Unit(benchmark::kMicrosecond) \
->UseManualTime();

// Same registration pattern for the cast+transpose variant.
#define REGISTER_CAST_TRANSPOSE(ITYPE, INAME) \
BENCHMARK_TEMPLATE(BM_CastTranspose, ITYPE) \
->Name("BM_CastTranspose/" INAME "_E4M3/gpt_oss") \
GPT_OSS_COMMON_SHAPES \
->Unit(benchmark::kMicrosecond) \
->UseManualTime(); \
BENCHMARK_TEMPLATE(BM_CastTranspose, ITYPE) \
->Name("BM_CastTranspose/" INAME "_E4M3/gpt_oss_moe") \
GPT_OSS_MOE \
->Unit(benchmark::kMicrosecond) \
->UseManualTime(); \
BENCHMARK_TEMPLATE(BM_CastTranspose, ITYPE) \
->Name("BM_CastTranspose/" INAME "_E4M3/llm") \
COMMON_SHAPES \
->Unit(benchmark::kMicrosecond) \
->UseManualTime();

// Opt-in extended suites over EXTENDED_SHAPES; enable by uncommenting the
// NVTE_ROCM_EXTENDED_BENCHMARKS define near the top of this file.
#ifdef NVTE_ROCM_EXTENDED_BENCHMARKS
#define REGISTER_EXTENDED_CAST_ONLY(ITYPE, INAME) \
BENCHMARK_TEMPLATE(BM_CastOnly, ITYPE) \
->Name("BM_CastOnlyExtended/" INAME "_E4M3/llm") \
EXTENDED_SHAPES \
->Unit(benchmark::kMicrosecond) \
->UseManualTime();

#define REGISTER_EXTENDED_CAST_TRANSPOSE(ITYPE, INAME) \
BENCHMARK_TEMPLATE(BM_CastTranspose, ITYPE) \
->Name("BM_CastTransposeExtended/" INAME "_E4M3/llm") \
EXTENDED_SHAPES \
->Unit(benchmark::kMicrosecond) \
->UseManualTime();

REGISTER_EXTENDED_CAST_ONLY(float, "FP32")
REGISTER_EXTENDED_CAST_ONLY(hip_bfloat16, "BF16")

REGISTER_EXTENDED_CAST_TRANSPOSE(float, "FP32")
REGISTER_EXTENDED_CAST_TRANSPOSE(hip_bfloat16, "BF16")
#endif // #ifdef NVTE_ROCM_EXTENDED_BENCHMARKS

// Standard registrations: FP32 and BF16 inputs, E4M3 output.
REGISTER_CAST_ONLY(float, "FP32")
REGISTER_CAST_ONLY(hip_bfloat16, "BF16")

REGISTER_CAST_TRANSPOSE(float, "FP32")
REGISTER_CAST_TRANSPOSE(hip_bfloat16, "BF16")

BENCHMARK_MAIN();
Loading