From fe9a66c26a0b119a25ef0a0f0fb746945a50237b Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 22 Apr 2026 18:12:28 +0000 Subject: [PATCH 01/43] [ROCm] port tdm to npi_gfx1250 --- .../rocshmem_api/rocshmem_waitkernel.hip | 115 ------ .../common/util/rocm_cast_gated_kernels.cuh | 56 ++- .../common/util/rocm_cast_kernels.cuh | 41 ++- .../common/util/rocm_dequantize_kernels.cuh | 25 +- transformer_engine/common/util/tdm.cuh | 336 ++++++++++++++++++ 5 files changed, 452 insertions(+), 121 deletions(-) delete mode 100644 transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip create mode 100644 transformer_engine/common/util/tdm.cuh diff --git a/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip deleted file mode 100644 index 9f7fe0e2e..000000000 --- a/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip +++ /dev/null @@ -1,115 +0,0 @@ -/************************************************************************* - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - * License for AMD contributions = MIT. See LICENSE for more information -*************************************************************************/ - -#include -#include - -#include "../util/logging_hip.h" -#include "rocshmem_waitkernel.hpp" - -using namespace rocshmem; - -__global__ void wait_until_on_stream_and_reset(uint64_t *wait_flag, - uint64_t wait_value, - uint64_t signal_reset) { - rocshmem_ulonglong_wait_until((unsigned long long*)wait_flag, - ROCSHMEM_CMP_EQ, - (unsigned long long)wait_value); -} - -__global__ void rocshmem_putmem_signal_kernel(void* dst_ptr, const void* src_ptr, - size_t nelement, uint64_t* sig_addr, - uint64_t sigval, int peer) { - if (threadIdx.x == 0 && blockIdx.x == 0) { - rocshmem_putmem(dst_ptr, src_ptr, nelement, peer); - rocshmem_fence(); - rocshmem_ulonglong_p((unsigned long long*)sig_addr, - (unsigned long long)sigval, - peer); - } -} - -void te_rocshmem_putmem_signal(void* dst_ptr, const void* src_ptr, size_t nelement, - uint64_t* sig_addr, uint64_t sigval, int peer, - hipStream_t cur_stream) { - hipLaunchKernelGGL(rocshmem_putmem_signal_kernel, - dim3(1), dim3(1), 0, cur_stream, - dst_ptr, src_ptr, nelement, sig_addr, - sigval, peer); -} - -void te_rocshmem_wait_on_stream(uint64_t* sig_addr, - WaitKind wait_kind, - hipStream_t cur_stream) { - uint64_t wait_value = 1; - uint64_t signal_reset = 0; - - NVTE_CHECK(wait_kind >= WaitKind::KERNEL_WAIT && - wait_kind <= WaitKind::STREAM_WAIT, - "Invalid wait kind"); - - switch (wait_kind) { -// ### wait_until_on_stream not yet implemented for rocshmem ### -// ### KernelWait is robust but slightly slower due to launch ### - case WaitKind::ROCSHMEM_WAIT: - printf("WARNING: rocshmem wait is not implemented yet, defaulting to kernel wait.\n"); - // rocshmem__ulonglong_wait_until_on_stream(sig_addr, - // ROCSHMEM_CMP_EQ, - // wait_value, - // cur_stream); - // hipStreamWriteValue64(cur_stream, - // reinterpret_cast(sig_addr), - // signal_reset, 0); - // break; - case WaitKind::KERNEL_WAIT: - hipLaunchKernelGGL(wait_until_on_stream_and_reset, - dim3(1), dim3(1), 0, cur_stream, - sig_addr, wait_value, signal_reset); - hipStreamWriteValue64(cur_stream, - reinterpret_cast(sig_addr), - signal_reset, 0); - break; - case WaitKind::STREAM_WAIT: - hipStreamWaitValue64(cur_stream, - reinterpret_cast(sig_addr), - wait_value, hipStreamWaitValueGte); - hipStreamWriteValue64(cur_stream, - reinterpret_cast(sig_addr), - signal_reset, 0); - break; - } -} - 
-int te_rocshmem_init_thread(int required, int* provided) { - if (required == 0 && provided == nullptr) { - rocshmem_init(); - return 0; - } else { - return rocshmem_init_thread(required, provided); - } -} - -void te_rocshmem_finalize() { - rocshmem_finalize(); -} - -int te_rocshmem_my_pe() { - return rocshmem_my_pe(); -} - -int te_rocshmem_n_pes() { - return rocshmem_n_pes(); -} - -void* te_rocshmem_malloc(size_t size) { - return rocshmem_malloc(size); -} - -void te_rocshmem_free(void* ptr) { - rocshmem_free(ptr); -} - -void te_rocshmem_wait_until(uint64_t* signal_addr, uint64_t expected_value, - hipStream_t stream); \ No newline at end of file diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh index 387445a78..cafec34e1 100644 --- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh @@ -13,6 +13,7 @@ #include "math.h" #include "ptx.cuh" #include "rocm_vectorized_2d.cuh" +#include "tdm.cuh" #include "transformer_engine/activation.h" #include "transformer_engine/cast.h" #include "vectorized_pointwise.h" @@ -134,6 +135,28 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t row_base = chunk_it_offset_y; // Initiate bulk tensor copy +#if defined(__gfx1250__) + { + constexpr uint32_t data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); + if constexpr (IS_DGATED) { + // grad uses stride=cols, act/gate use stride=2*cols -- issue separately + tdm::copy_2d_to_shared( + &in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, cols, data_sz); + tdm::copy_2d_to_shared_x2( + &in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, 2*cols, data_sz); + } else { + tdm::copy_2d_to_shared_x2( + &in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, 2*cols, data_sz); + } + tdm::wait_tensorcnt_0(); + __syncthreads(); + } +#else if constexpr (IS_DGATED) { copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); @@ -142,12 +165,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Act copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - + // Gate copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); +#endif const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y; @@ -353,6 +377,33 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); +#if defined(__gfx1250__) + { + constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + if constexpr (USE_ROWWISE_SCALING) { + tdm::store_2d_to_global(&out_act_rowwise_sh[0], output_act_rowwise, + chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz); + if constexpr (IS_DGATED) { + tdm::store_2d_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, + chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz); + } + } + if 
constexpr (USE_COLWISE_SCALING) { + tdm::store_2d_to_global(&out_act_colwise_sh[0], output_act_colwise, + chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz); + if constexpr (IS_DGATED) { + tdm::store_2d_to_global(&out_gate_colwise_sh[0], output_gate_colwise, + chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz); + } + } + tdm::wait_tensorcnt_0(); + __syncthreads(); + } +#else if constexpr (USE_ROWWISE_SCALING) { bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); @@ -361,7 +412,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } } - + if constexpr (USE_COLWISE_SCALING) { bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); @@ -371,6 +422,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } __syncthreads(); +#endif } } } // namespace gated_kernels diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index 33c53e8e8..b512c3718 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -12,6 +12,7 @@ #include "math.h" #include "ptx.cuh" #include "rocm_vectorized_2d.cuh" +#include "tdm.cuh" #include "transformer_engine/cast.h" #include "../transpose/cast_transpose.h" #include "vectorized_pointwise.h" @@ -161,15 +162,31 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; const size_t row_base = chunk_it_offset_y; +#if defined(__gfx1250__) + constexpr uint32_t data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); if constexpr (IS_DACT) { - copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, - chunk_it_offset_x, chunk_it_offset_y, cols, + tdm::copy_2d_to_shared_x2( + &in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, + &act_in_sh[0][0], act_input_ptr, chunk_it_offset_x, chunk_it_offset_y, + MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, cols, rows, cols, data_sz); + } else { + tdm::copy_2d_to_shared( + &in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, + MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, cols, rows, cols, data_sz); + } + tdm::wait_tensorcnt_0(); + __syncthreads(); +#else + if constexpr (IS_DACT) { + copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, + chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } - copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, + copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); __syncthreads(); +#endif if constexpr (USE_ROWWISE_SCALING) { Vec in; @@ -312,6 +329,23 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __syncthreads(); +#if defined(__gfx1250__) + constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + if constexpr (USE_ROWWISE_SCALING) { + tdm::store_2d_to_global(&out_rowwise_sh[0][0], output_rowwise, + chunk_it_offset_x, chunk_it_offset_y, + MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, + cols, rows, cols, out_data_sz); + } + if constexpr (USE_COLWISE_SCALING) { + 
tdm::store_2d_to_global(&out_colwise_sh[0][0], output_colwise, + chunk_it_offset_x, chunk_it_offset_y, + MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, + cols, rows, cols, out_data_sz); + } + tdm::wait_tensorcnt_0(); + __syncthreads(); +#else if constexpr (USE_ROWWISE_SCALING) { bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, @@ -324,6 +358,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) } __syncthreads(); +#endif } } diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh index 0d020b5eb..5aae8ede1 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh @@ -14,6 +14,7 @@ #include "math.h" #include "ptx.cuh" #include "rocm_vectorized_2d.cuh" +#include "tdm.cuh" #include "transformer_engine/activation.h" #include "transformer_engine/cast.h" #include "../transpose/cast_transpose.h" @@ -85,10 +86,21 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; - copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, +#if defined(__gfx1250__) + { + constexpr uint32_t data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); + tdm::copy_2d_to_shared(&in_sh[0][0], input_ptr, + chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, cols, data_sz); + tdm::wait_tensorcnt_0(); + __syncthreads(); + } +#else + copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); +#endif const int scale_offset_Y = USE_ROWWISE_SCALING ? (scales_rowwise_chunk_offset_Y + iter * BUFFER_DIM_Y + tid_rowwise_Y) @@ -126,11 +138,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); +#if defined(__gfx1250__) + { + constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + tdm::store_2d_to_global(&out_sh[0][0], output_ptr, + chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, cols, out_data_sz); + tdm::wait_tensorcnt_0(); + __syncthreads(); + } +#else bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); +#endif } } } // namespace dequantization diff --git a/transformer_engine/common/util/tdm.cuh b/transformer_engine/common/util/tdm.cuh new file mode 100644 index 000000000..d3b5a4a44 --- /dev/null +++ b/transformer_engine/common/util/tdm.cuh @@ -0,0 +1,336 @@ +/************************************************************************* + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +/*! \file tdm.cuh + * \brief TDM (Tensor Data Mover) wrappers for gfx1250. + * + * AMD's TDM is the gfx1250 equivalent of NVIDIA's TMA (Tensor Memory + * Accelerator). It provides asynchronous bulk copies between global memory + * and LDS (shared memory) using hardware descriptors. + * + * Key differences from TMA: + * - Descriptors are constructed on-device (no host-side CUtensorMap). + * - The instruction is wave-level (EXEC mask ignored). + * - Synchronization uses TENSORcnt + s_wait_tensorcnt (not mbarrier). 
+ * - A single counter tracks both loads and stores in issue order. + */ + +#ifndef TRANSFORMER_ENGINE_TDM_CUH_ +#define TRANSFORMER_ENGINE_TDM_CUH_ + +#ifdef __HIP_PLATFORM_AMD__ + +#if defined(__gfx1250__) +#include +#endif + +namespace transformer_engine { +namespace tdm { + +#if defined(__gfx1250__) + +// --------------------------------------------------------------------------- +// Data type mapping +// --------------------------------------------------------------------------- + +//! Returns log2(sizeof(element_in_bytes)) for the TDM dataSize field. +//! 8-bit -> 0, 16-bit -> 1, 32-bit -> 2, 64-bit -> 3. +//! For sub-byte types (NVFP4), treat as packed uint8 and pass 0. +__device__ __forceinline__ constexpr uint32_t get_data_size_from_bits(size_t type_num_bits) { + // type_num_bits: 8 -> 0, 16 -> 1, 32 -> 2, 64 -> 3 + return (type_num_bits <= 8) ? 0 : (type_num_bits <= 16) ? 1 : (type_num_bits <= 32) ? 2 : 3; +} + +// --------------------------------------------------------------------------- +// Wave guard +// --------------------------------------------------------------------------- + +//! Returns true for threads in the first wavefront (wave 0) of the block. +//! TDM instructions are wave-level -- only wave 0 should issue them. +__device__ __forceinline__ bool is_tdm_wave() { + const int linear_tid = threadIdx.x + threadIdx.y * blockDim.x; + return (linear_tid < 32); +} + +// --------------------------------------------------------------------------- +// Core 2D load: global memory -> LDS +// --------------------------------------------------------------------------- + +//! Set up a 2D D# descriptor (groups 0+1) and issue a TDM load. +//! +//! @param global_base Raw device pointer to tensor base. +//! @param lds_byte_offset LDS destination byte offset (from shared ptr cast). +//! @param tensor_w Full tensor width in elements. +//! @param tensor_h Full tensor height in elements. +//! @param tile_dim_x Tile width to load (elements, inner/columns). +//! @param tile_dim_y Tile height to load (elements, outer/rows). +//! @param stride_elements Row stride in elements. +//! @param data_size log2(sizeof(element)): 0=1B, 1=2B, 2=4B, 3=8B. +//! @param tile_col Tile start column offset in elements. +//! @param tile_row Tile start row offset in elements. +__device__ __forceinline__ +void load_2d_to_lds(const void* global_base, + uint32_t lds_byte_offset, + uint32_t tensor_w, + uint32_t tensor_h, + uint32_t tile_dim_x, + uint32_t tile_dim_y, + uint32_t stride_elements, + uint32_t data_size, + uint32_t tile_col, + uint32_t tile_row) { + gfx1250_TDM_GROUP0 g0; + gfx1250_TDM_GROUP1 g1; + + g0.ldsAddr(lds_byte_offset); + + // Compute global address of the tile's top-left element. + const size_t elem_bytes = 1u << data_size; + const char* base = reinterpret_cast(global_base); + const char* tile_start = base + + (static_cast(tile_row) * stride_elements + tile_col) * elem_bytes; + g0.globalAddr(reinterpret_cast(tile_start)); + + g1.dataSize(data_size); + // tensorDim = remaining extent from tile start to tensor edge (for OOB clamping). 
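+  // The tile offset is folded into globalAddr above, so the descriptor only
+  // carries the remaining extent. Edge tiles that extend past tensorDim0/1 are
+  // assumed to be clamped by the TDM unit, so no extra caller-side bounds
+  // predication is done here.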
+ g1.tensorDim0(tensor_w - tile_col); + g1.tensorDim1(tensor_h - tile_row); + g1.tileDim0(tile_dim_x); + g1.tileDim1(tile_dim_y); + g1.tensorDim0Stride(stride_elements); + + __builtin_amdgcn_tensor_load_to_lds_d2(g0.m_bitfield, g1.m_bitfield, /*cachepolicy=*/0); +} + +// --------------------------------------------------------------------------- +// Core 2D store: LDS -> global memory +// --------------------------------------------------------------------------- + +//! Set up a 2D D# descriptor and issue a TDM store. +//! Parameters mirror load_2d_to_lds but direction is LDS->global. +__device__ __forceinline__ +void store_2d_from_lds(void* global_base, + uint32_t lds_byte_offset, + uint32_t tensor_w, + uint32_t tensor_h, + uint32_t tile_dim_x, + uint32_t tile_dim_y, + uint32_t stride_elements, + uint32_t data_size, + uint32_t tile_col, + uint32_t tile_row) { + gfx1250_TDM_GROUP0 g0; + gfx1250_TDM_GROUP1 g1; + + g0.ldsAddr(lds_byte_offset); + + const size_t elem_bytes = 1u << data_size; + char* base = reinterpret_cast(global_base); + char* tile_start = base + + (static_cast(tile_row) * stride_elements + tile_col) * elem_bytes; + g0.globalAddr(reinterpret_cast(tile_start)); + + g1.dataSize(data_size); + g1.tensorDim0(tensor_w - tile_col); + g1.tensorDim1(tensor_h - tile_row); + g1.tileDim0(tile_dim_x); + g1.tileDim1(tile_dim_y); + g1.tensorDim0Stride(stride_elements); + + __builtin_amdgcn_tensor_store_from_lds_d2(g0.m_bitfield, g1.m_bitfield, /*cachepolicy=*/0); +} + +// --------------------------------------------------------------------------- +// Wait helpers (argument must be compile-time immediate) +// --------------------------------------------------------------------------- + +__device__ __forceinline__ void wait_tensorcnt_0() { + __builtin_amdgcn_s_wait_tensorcnt(0); +} + +__device__ __forceinline__ void wait_tensorcnt_1() { + __builtin_amdgcn_s_wait_tensorcnt(1); +} + +__device__ __forceinline__ void wait_tensorcnt_2() { + __builtin_amdgcn_s_wait_tensorcnt(2); +} + +__device__ __forceinline__ void wait_tensorcnt_3() { + __builtin_amdgcn_s_wait_tensorcnt(3); +} + +__device__ __forceinline__ void wait_tensorcnt_4() { + __builtin_amdgcn_s_wait_tensorcnt(4); +} + +// --------------------------------------------------------------------------- +// Higher-level helpers (matching ptx.cuh copy_2d_to_shared interface) +// --------------------------------------------------------------------------- +// These handle the is_tdm_wave() guard internally. +// The caller is responsible for __syncthreads() AFTER calling these, +// matching the TMA pattern where mbarrier_wait + syncthreads follows. + +//! Load a single 2D tile from global to shared via TDM. +//! Only wave 0 issues the instruction; other waves are no-ops. +//! +//! @param lds_dst Shared memory destination pointer. +//! @param global_base Raw device pointer to tensor base. +//! @param chunk_x Tile column offset (elements). +//! @param chunk_y Tile row offset (elements). +//! @param tile_dim_x Tile width (elements). +//! @param tile_dim_y Tile height (elements). +//! @param tensor_w Full tensor width (elements). +//! @param tensor_h Full tensor height (elements). +//! @param stride Row stride (elements). +//! @param data_size log2(sizeof(element)). 
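+//!
+//! Typical call sequence (placeholder arguments), as used by the kernels in this patch:
+//!   tdm::copy_2d_to_shared(dst_sh, src_ptr, x, y, TILE_X, TILE_Y, cols, rows, stride, sz);
+//!   tdm::wait_tensorcnt_0();  // wait for the TDM transfer to land in LDS
+//!   __syncthreads();          // make the loaded tile visible to all waves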
+__device__ __forceinline__
+void copy_2d_to_shared(void* lds_dst,
+                       const void* global_base,
+                       uint32_t chunk_x,
+                       uint32_t chunk_y,
+                       uint32_t tile_dim_x,
+                       uint32_t tile_dim_y,
+                       uint32_t tensor_w,
+                       uint32_t tensor_h,
+                       uint32_t stride,
+                       uint32_t data_size) {
+  if (is_tdm_wave()) {
+    uint32_t lds_off = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(lds_dst));
+    load_2d_to_lds(global_base, lds_off,
+                   tensor_w, tensor_h,
+                   tile_dim_x, tile_dim_y,
+                   stride, data_size,
+                   chunk_x, chunk_y);
+  }
+}
+
+//! Load two 2D tiles from (possibly different) tensors into shared via TDM.
+__device__ __forceinline__
+void copy_2d_to_shared_x2(void* dst1, const void* src1, uint32_t cx1, uint32_t cy1,
+                          void* dst2, const void* src2, uint32_t cx2, uint32_t cy2,
+                          uint32_t tile_dim_x, uint32_t tile_dim_y,
+                          uint32_t tensor_w, uint32_t tensor_h,
+                          uint32_t stride, uint32_t data_size) {
+  if (is_tdm_wave()) {
+    uint32_t lds_off1 = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(dst1));
+    load_2d_to_lds(src1, lds_off1,
+                   tensor_w, tensor_h,
+                   tile_dim_x, tile_dim_y,
+                   stride, data_size,
+                   cx1, cy1);
+
+    uint32_t lds_off2 = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(dst2));
+    load_2d_to_lds(src2, lds_off2,
+                   tensor_w, tensor_h,
+                   tile_dim_x, tile_dim_y,
+                   stride, data_size,
+                   cx2, cy2);
+  }
+}
+
+//! Load three 2D tiles from (possibly different) tensors into shared via TDM.
+__device__ __forceinline__
+void copy_2d_to_shared_x3(void* dst1, const void* src1, uint32_t cx1, uint32_t cy1,
+                          void* dst2, const void* src2, uint32_t cx2, uint32_t cy2,
+                          void* dst3, const void* src3, uint32_t cx3, uint32_t cy3,
+                          uint32_t tile_dim_x, uint32_t tile_dim_y,
+                          uint32_t tensor_w, uint32_t tensor_h,
+                          uint32_t stride, uint32_t data_size) {
+  if (is_tdm_wave()) {
+    uint32_t lds_off1 = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(dst1));
+    load_2d_to_lds(src1, lds_off1,
+                   tensor_w, tensor_h,
+                   tile_dim_x, tile_dim_y,
+                   stride, data_size,
+                   cx1, cy1);
+
+    uint32_t lds_off2 = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(dst2));
+    load_2d_to_lds(src2, lds_off2,
+                   tensor_w, tensor_h,
+                   tile_dim_x, tile_dim_y,
+                   stride, data_size,
+                   cx2, cy2);
+
+    uint32_t lds_off3 = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(dst3));
+    load_2d_to_lds(src3, lds_off3,
+                   tensor_w, tensor_h,
+                   tile_dim_x, tile_dim_y,
+                   stride, data_size,
+                   cx3, cy3);
+  }
+}
+
+//! Store a 2D tile from shared to global via TDM.
+//! Only wave 0 issues the instruction.
+//! Caller must ensure all threads have finished writing to LDS (via __syncthreads())
+//! BEFORE calling this.
+__device__ __forceinline__
+void store_2d_to_global(const void* lds_src,
+                        void* global_base,
+                        uint32_t chunk_x,
+                        uint32_t chunk_y,
+                        uint32_t tile_dim_x,
+                        uint32_t tile_dim_y,
+                        uint32_t tensor_w,
+                        uint32_t tensor_h,
+                        uint32_t stride,
+                        uint32_t data_size) {
+  if (is_tdm_wave()) {
+    uint32_t lds_off = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(lds_src));
+    store_2d_from_lds(global_base, lds_off,
+                      tensor_w, tensor_h,
+                      tile_dim_x, tile_dim_y,
+                      stride, data_size,
+                      chunk_x, chunk_y);
+  }
+}
+
+#else // !defined(__gfx1250__)
+
+// Stubs for non-gfx1250 AMD targets -- these should never be called.
+__device__ __forceinline__ bool is_tdm_wave() { return false; }
+__device__ __forceinline__ void wait_tensorcnt_0() {}
+__device__ __forceinline__ void wait_tensorcnt_1() {}
+__device__ __forceinline__ void wait_tensorcnt_2() {}
+__device__ __forceinline__ void wait_tensorcnt_3() {}
+__device__ __forceinline__ void wait_tensorcnt_4() {}
+
+__device__ __forceinline__ constexpr uint32_t get_data_size_from_bits(size_t type_num_bits) {
+  return (type_num_bits <= 8) ? 0 : (type_num_bits <= 16) ? 
1 : (type_num_bits <= 32) ? 2 : 3; +} + +__device__ __forceinline__ +void copy_2d_to_shared(void*, const void*, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t) {} + +__device__ __forceinline__ +void copy_2d_to_shared_x2(void*, const void*, uint32_t, uint32_t, + void*, const void*, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t) {} + +__device__ __forceinline__ +void copy_2d_to_shared_x3(void*, const void*, uint32_t, uint32_t, + void*, const void*, uint32_t, uint32_t, + void*, const void*, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t) {} + +__device__ __forceinline__ +void store_2d_to_global(const void*, void*, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t) {} + +#endif // defined(__gfx1250__) + +} // namespace tdm +} // namespace transformer_engine + +#endif // __HIP_PLATFORM_AMD__ + +#endif // TRANSFORMER_ENGINE_TDM_CUH_ From 40f3902e9769d48d5ff760075ecc0a5f0dc1a3ea Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Thu, 23 Apr 2026 00:54:47 +0000 Subject: [PATCH 02/43] [ROCm] address first round reviewer comments --- .../rocshmem_api/rocshmem_waitkernel.hip | 115 +++ .../common/util/cast_gated_kernels.cuh | 915 +++++++++++++++++- 2 files changed, 1029 insertions(+), 1 deletion(-) create mode 100644 transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip diff --git a/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip new file mode 100644 index 000000000..9f7fe0e2e --- /dev/null +++ b/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip @@ -0,0 +1,115 @@ +/************************************************************************* + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * License for AMD contributions = MIT. 
See LICENSE for more information +*************************************************************************/ + +#include +#include + +#include "../util/logging_hip.h" +#include "rocshmem_waitkernel.hpp" + +using namespace rocshmem; + +__global__ void wait_until_on_stream_and_reset(uint64_t *wait_flag, + uint64_t wait_value, + uint64_t signal_reset) { + rocshmem_ulonglong_wait_until((unsigned long long*)wait_flag, + ROCSHMEM_CMP_EQ, + (unsigned long long)wait_value); +} + +__global__ void rocshmem_putmem_signal_kernel(void* dst_ptr, const void* src_ptr, + size_t nelement, uint64_t* sig_addr, + uint64_t sigval, int peer) { + if (threadIdx.x == 0 && blockIdx.x == 0) { + rocshmem_putmem(dst_ptr, src_ptr, nelement, peer); + rocshmem_fence(); + rocshmem_ulonglong_p((unsigned long long*)sig_addr, + (unsigned long long)sigval, + peer); + } +} + +void te_rocshmem_putmem_signal(void* dst_ptr, const void* src_ptr, size_t nelement, + uint64_t* sig_addr, uint64_t sigval, int peer, + hipStream_t cur_stream) { + hipLaunchKernelGGL(rocshmem_putmem_signal_kernel, + dim3(1), dim3(1), 0, cur_stream, + dst_ptr, src_ptr, nelement, sig_addr, + sigval, peer); +} + +void te_rocshmem_wait_on_stream(uint64_t* sig_addr, + WaitKind wait_kind, + hipStream_t cur_stream) { + uint64_t wait_value = 1; + uint64_t signal_reset = 0; + + NVTE_CHECK(wait_kind >= WaitKind::KERNEL_WAIT && + wait_kind <= WaitKind::STREAM_WAIT, + "Invalid wait kind"); + + switch (wait_kind) { +// ### wait_until_on_stream not yet implemented for rocshmem ### +// ### KernelWait is robust but slightly slower due to launch ### + case WaitKind::ROCSHMEM_WAIT: + printf("WARNING: rocshmem wait is not implemented yet, defaulting to kernel wait.\n"); + // rocshmem__ulonglong_wait_until_on_stream(sig_addr, + // ROCSHMEM_CMP_EQ, + // wait_value, + // cur_stream); + // hipStreamWriteValue64(cur_stream, + // reinterpret_cast(sig_addr), + // signal_reset, 0); + // break; + case WaitKind::KERNEL_WAIT: + hipLaunchKernelGGL(wait_until_on_stream_and_reset, + dim3(1), dim3(1), 0, cur_stream, + sig_addr, wait_value, signal_reset); + hipStreamWriteValue64(cur_stream, + reinterpret_cast(sig_addr), + signal_reset, 0); + break; + case WaitKind::STREAM_WAIT: + hipStreamWaitValue64(cur_stream, + reinterpret_cast(sig_addr), + wait_value, hipStreamWaitValueGte); + hipStreamWriteValue64(cur_stream, + reinterpret_cast(sig_addr), + signal_reset, 0); + break; + } +} + +int te_rocshmem_init_thread(int required, int* provided) { + if (required == 0 && provided == nullptr) { + rocshmem_init(); + return 0; + } else { + return rocshmem_init_thread(required, provided); + } +} + +void te_rocshmem_finalize() { + rocshmem_finalize(); +} + +int te_rocshmem_my_pe() { + return rocshmem_my_pe(); +} + +int te_rocshmem_n_pes() { + return rocshmem_n_pes(); +} + +void* te_rocshmem_malloc(size_t size) { + return rocshmem_malloc(size); +} + +void te_rocshmem_free(void* ptr) { + rocshmem_free(ptr); +} + +void te_rocshmem_wait_until(uint64_t* signal_addr, uint64_t expected_value, + hipStream_t stream); \ No newline at end of file diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index dcb3aa42d..7718f91f6 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -30,6 +30,7 @@ #include "ptx.cuh" #ifdef __HIP_PLATFORM_AMD__ #include "rocm_cast_gated_kernels.cuh" +#include "tdm.cuh" #endif namespace transformer_engine { @@ -972,6 +973,892 @@ void 
cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu } #endif //#ifdef __HIP_PLATFORM_AMD__ +// ===================================================================================== +// NV-upstream-flow TDM kernels for gfx1250 +// These mirror the NV upstream kernel structure (128x128 chunks, 512 threads) but +// replace CUtensorMap/TMA with raw-pointer/TDM. They are compiled only on AMD and +// selectable via NVTE_USE_NV_UPSTREAM_FLOW env var for performance comparison. +// ===================================================================================== +#ifdef __HIP_PLATFORM_AMD__ + +namespace nv_upstream_tdm { + +// NV upstream constants (same as NV upstream, not ROCm flow) +constexpr size_t NV_CHUNK_DIM_Y = 128; +constexpr size_t NV_CHUNK_DIM_X = 128; +constexpr size_t NV_THREADS_PER_CHUNK = 512; +constexpr size_t NV_THREADS_PER_CHUNK_X = NV_CHUNK_DIM_X; +constexpr size_t NV_THREADS_PER_CHUNK_Y = NV_THREADS_PER_CHUNK / NV_THREADS_PER_CHUNK_X; // 4 +constexpr size_t NV_BUFFER_DIM_Y = 32; +constexpr size_t NV_BUFFER_DIM_X = NV_CHUNK_DIM_X; // 128 +constexpr size_t NV_SHMEM_DIM_Y = NV_BUFFER_DIM_Y; // 32 +constexpr size_t NV_SHMEM_DIM_X = NV_BUFFER_DIM_X; // 128 +constexpr size_t NV_BUFFER_STAGES_NUM = NV_BUFFER_DIM_Y / NV_THREADS_PER_CHUNK_Y; // 8 +constexpr size_t NV_ITERATIONS = NV_CHUNK_DIM_Y / NV_BUFFER_DIM_Y; // 4 +static_assert(NV_ITERATIONS >= 1); + +__device__ inline float nv_sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } + +// --------------------------------------------------------------------------- +// FP8 gated kernel — NV upstream flow ported to TDM +// --------------------------------------------------------------------------- +template +__global__ void __launch_bounds__(NV_THREADS_PER_CHUNK) + cast_fp8_gated_kernel_tdm(const IType *__restrict__ grad_ptr, + const IType *__restrict__ input_act_ptr, + const IType *__restrict__ input_gate_ptr, + OType *__restrict__ output_act_ptr, + OType *__restrict__ output_gate_ptr, + float *const amax_ptr, float *const scale_inv_ptr, + const float *const scale_ptr, + const size_t rows, const size_t cols, + const size_t input_act_stride, + const size_t input_gate_stride, + const size_t grad_stride, + const size_t output_stride) { +#if defined(__gfx1250__) + const size_t chunk_offset_Y = blockIdx.y * NV_CHUNK_DIM_Y; + const size_t chunk_offset_X = blockIdx.x * NV_CHUNK_DIM_X; + + const size_t tid_Y = threadIdx.x / NV_THREADS_PER_CHUNK_X; + const size_t tid_X = threadIdx.x % NV_THREADS_PER_CHUNK_X; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + constexpr size_t ALIGNMENT = 128; + constexpr size_t buff_elems = NV_SHMEM_DIM_Y * NV_SHMEM_DIM_X; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + uintptr_t dshmem = (base_shmem_ptr + ALIGNMENT - 1) & ~(static_cast(ALIGNMENT - 1)); + + constexpr size_t buff_size_aligned_in = + ((buff_elems * sizeof(IType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + constexpr size_t buff_size_aligned_out = + ((buff_elems * sizeof(OType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + + constexpr size_t grad_mem = IS_DGATED ? 
buff_size_aligned_in : 0; + constexpr size_t in_act_mem = buff_size_aligned_in; + constexpr size_t in_gate_mem = buff_size_aligned_in; + constexpr size_t in_mem = in_act_mem + in_gate_mem; + constexpr size_t out_act_mem = buff_size_aligned_out; + + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); + OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + + constexpr uint32_t in_data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); + constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + +#pragma unroll + for (int it = 0; it < NV_ITERATIONS; ++it) { + const size_t chunk_it_offset_y = chunk_offset_Y + it * NV_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + + // TDM load + if constexpr (IS_DGATED) { + tdm::copy_2d_to_shared(in_grad_sh, grad_ptr, + chunk_it_offset_x, chunk_it_offset_y, + NV_SHMEM_DIM_X, NV_SHMEM_DIM_Y, + cols, rows, grad_stride, in_data_sz); + tdm::copy_2d_to_shared_x2( + in_act_sh, input_act_ptr, chunk_it_offset_x, chunk_it_offset_y, + in_gate_sh, input_gate_ptr, chunk_it_offset_x, chunk_it_offset_y, + NV_SHMEM_DIM_X, NV_SHMEM_DIM_Y, + cols, rows, input_act_stride, in_data_sz); + } else { + tdm::copy_2d_to_shared_x2( + in_act_sh, input_act_ptr, chunk_it_offset_x, chunk_it_offset_y, + in_gate_sh, input_gate_ptr, chunk_it_offset_x, chunk_it_offset_y, + NV_SHMEM_DIM_X, NV_SHMEM_DIM_Y, + cols, rows, input_act_stride, in_data_sz); + } + tdm::wait_tensorcnt_0(); + __syncthreads(); + + // Compute +#pragma unroll + for (int stage = 0; stage < NV_BUFFER_STAGES_NUM; ++stage) { + const size_t stage_offset_Y = stage * NV_THREADS_PER_CHUNK_Y; + const size_t shmem_offset_y = tid_Y + stage_offset_Y; + const size_t shmem_offset_x = tid_X; + const size_t shmem_idx = shmem_offset_y * NV_SHMEM_DIM_X + shmem_offset_x; + + float act_elt = static_cast(in_act_sh[shmem_idx]); + float gate_elt = static_cast(in_gate_sh[shmem_idx]); + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh[shmem_idx]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = nv_sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + + float after_dact = dact_x * grad_elt * gate_elt; + float after_dgate = act_x * grad_elt; + + out_act_sh[shmem_idx] = static_cast(scale * after_dact); + out_gate_sh[shmem_idx] = static_cast(scale * after_dgate); + + amax = fmaxf(amax, fabsf(after_dact)); + amax = fmaxf(amax, fabsf(after_dgate)); + } else { + const float after_act = ActOP(act_elt, {}) * gate_elt; + out_act_sh[shmem_idx] = static_cast(scale * after_act); + amax = fmaxf(amax, fabsf(after_act)); + } + } + + __syncthreads(); + + // TDM store + tdm::store_2d_to_global(out_act_sh, output_act_ptr, + chunk_it_offset_x, chunk_it_offset_y, + NV_SHMEM_DIM_X, NV_SHMEM_DIM_Y, + cols, rows, output_stride, out_data_sz); + if constexpr (IS_DGATED) { + tdm::store_2d_to_global(out_gate_sh, output_gate_ptr, + chunk_it_offset_x, chunk_it_offset_y, + NV_SHMEM_DIM_X, NV_SHMEM_DIM_Y, + cols, rows, output_stride, out_data_sz); + } + tdm::wait_tensorcnt_0(); + __syncthreads(); + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + amax = reduce_max(amax, warp_id); + if 
(threadIdx.x == 0) { + atomicMaxFloat(amax_ptr, amax); + } + } + + if (threadIdx.x == 0 && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } +#endif // defined(__gfx1250__) +} + +// --------------------------------------------------------------------------- +// MXFP8 gated kernel — NV upstream flow ported to TDM +// --------------------------------------------------------------------------- +namespace nv_mxfp8 { + +constexpr size_t MX_CHUNK_DIM_Y = 64; +constexpr size_t MX_CHUNK_DIM_X = 64; +constexpr size_t MX_THREADS_PER_CHUNK_COLWISE = 128; +constexpr size_t MX_THREADS_PER_CHUNK_NON_COLWISE = MX_CHUNK_DIM_X; + +constexpr size_t MX_SCALE_DIM_Y = 32; +constexpr size_t MX_SCALE_DIM_X = 32; + +constexpr size_t MX_BUFF_DIM_Y = 32; +constexpr size_t MX_BUFF_DIM_X = MX_CHUNK_DIM_X; +constexpr size_t MX_BUFF_DIM = MX_BUFF_DIM_Y * MX_BUFF_DIM_X; +static_assert(MX_BUFF_DIM_Y == 32); + +constexpr size_t MX_PACK_SIZE = 4; +constexpr size_t MX_WAVES = MX_SCALE_DIM_X / MX_PACK_SIZE; + +constexpr size_t MX_TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 +constexpr size_t MX_THREADS_PER_BANK = MX_TOTAL_BANKS_WIDTH / MX_SCALE_DIM_X; // 4 + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_mxfp8_gated_kernel_tdm(const IType *__restrict__ grad_ptr, + const IType *__restrict__ input_act_ptr, + const IType *__restrict__ input_gate_ptr, + OType *__restrict__ output_act_rowwise_ptr, + OType *__restrict__ output_gate_rowwise_ptr, + OType *__restrict__ output_act_colwise_ptr, + OType *__restrict__ output_gate_colwise_ptr, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const size_t rows, const size_t cols, + const size_t scale_stride_rowwise, + const size_t scale_stride_colwise, + const size_t input_act_stride, + const size_t input_gate_stride, + const size_t grad_stride, + const size_t output_stride) { +#if defined(__gfx1250__) + constexpr size_t STAGES = MX_CHUNK_DIM_Y / MX_BUFF_DIM_Y; + static_assert(STAGES >= 1); + + constexpr bool IS_CACHED_ACT_OP = ROWWISE_SCALING && COLWISE_SCALING; + constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING); + constexpr size_t COLWISE_WAVEFRONT_SIZE = (THREADS_PER_CHUNK + MX_CHUNK_DIM_X - 1) / MX_CHUNK_DIM_X; + + const size_t block_offset_Y = blockIdx.y * MX_CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * MX_CHUNK_DIM_X; + const size_t scales_block_offset_Y_rowwise = blockIdx.y * MX_CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * MX_CHUNK_DIM_X / MX_SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = blockIdx.y * MX_CHUNK_DIM_Y / MX_SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = blockIdx.x * MX_CHUNK_DIM_X; + + constexpr size_t THREADS_X_ROWWISE = MX_CHUNK_DIM_X / MX_SCALE_DIM_X; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_Y_colwise = threadIdx.x / MX_CHUNK_DIM_X; + const size_t tid_X_colwise = threadIdx.x % MX_CHUNK_DIM_X; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * MX_SCALE_DIM_X; + const size_t thread_offset_Y_colwise = tid_Y_colwise; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t col_base_rowwise = block_offset_X + thread_offset_X_rowwise; + const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const size_t col_base_colwise = 
block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols); + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const size_t gate_scale_idx_offset_rowwise = (cols + MX_SCALE_DIM_X - 1) / MX_SCALE_DIM_X; + const size_t gate_scale_idx_offset_colwise = cols; + + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / MX_THREADS_PER_BANK; + + constexpr size_t SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1; + __shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][MX_CHUNK_DIM_X]; + + constexpr size_t ALIGNMENT = 128; + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + uintptr_t dshmem = (base_shmem_ptr + ALIGNMENT - 1) & ~(static_cast(ALIGNMENT - 1)); + + constexpr size_t buff_elems = MX_BUFF_DIM_Y * MX_BUFF_DIM_X; + constexpr size_t buff_size_aligned_in = + ((buff_elems * sizeof(IType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + constexpr size_t buff_size_aligned_out = + ((buff_elems * sizeof(OType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = in_act_mem + in_gate_mem; + + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = (IS_DGATED ? 
buff_size_aligned_out : 0); + const size_t out_mem = out_act_mem + out_gate_mem; + + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); + + OType *out_act_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + + OType *out_act_colwise_sh = out_act_rowwise_sh; + OType *out_gate_colwise_sh = out_gate_rowwise_sh; + + if constexpr (ROWWISE_SCALING && COLWISE_SCALING) { + out_act_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); + out_gate_colwise_sh = + reinterpret_cast(dshmem + grad_mem + in_mem + out_mem + out_act_mem); + } + + IType *cached_act_sh = in_act_sh; + IType *cached_gate_sh = in_gate_sh; + + constexpr uint32_t in_data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); + constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const size_t stage_offset_Y = stage * MX_BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + + // TDM load + if constexpr (IS_DGATED) { + tdm::copy_2d_to_shared(in_grad_sh, grad_ptr, + global_offset_X, global_offset_Y, + MX_BUFF_DIM_X, MX_BUFF_DIM_Y, + cols, rows, grad_stride, in_data_sz); + tdm::copy_2d_to_shared_x2( + in_act_sh, input_act_ptr, global_offset_X, global_offset_Y, + in_gate_sh, input_gate_ptr, global_offset_X, global_offset_Y, + MX_BUFF_DIM_X, MX_BUFF_DIM_Y, + cols, rows, input_act_stride, in_data_sz); + } else { + tdm::copy_2d_to_shared_x2( + in_act_sh, input_act_ptr, global_offset_X, global_offset_Y, + in_gate_sh, input_gate_ptr, global_offset_X, global_offset_Y, + MX_BUFF_DIM_X, MX_BUFF_DIM_Y, + cols, rows, input_act_stride, in_data_sz); + } + tdm::wait_tensorcnt_0(); + __syncthreads(); + + // ---- Colwise scaling pass ---- + if constexpr (COLWISE_SCALING) { + const size_t shmem_offset_base_colwise = tid_Y_colwise * MX_BUFF_DIM_X + tid_X_colwise; + float thread_amax_act = 0.0f; + float thread_amax_gate = 0.0f; + float after_act_colwise[MX_BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE]; + float after_gate_colwise[MX_BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE]; + +#pragma unroll + for (int i = 0; i < MX_SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { + const size_t shmem_offset_colwise = + shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * MX_BUFF_DIM_X; + + float act_elt = static_cast(in_act_sh[shmem_offset_colwise]); + float gate_elt = static_cast(in_gate_sh[shmem_offset_colwise]); + float after_act_elt; + float after_gate_elt; + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = nv_sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + after_act_elt = dact_x * grad_elt * gate_elt; + after_gate_elt = act_x * grad_elt; + } else { + after_act_elt = ActOP(act_elt, {}) * gate_elt; + } + + if constexpr (!std::is_same_v) { + after_act_elt = static_cast(static_cast(after_act_elt)); + if constexpr (IS_DGATED) { + after_gate_elt = static_cast(static_cast(after_gate_elt)); + } + } + + after_act_colwise[i] = after_act_elt; + if constexpr (IS_DGATED) { + after_gate_colwise[i] = after_gate_elt; + } + + if 
constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(after_act_elt); + if constexpr (IS_DGATED) { + cached_gate_sh[shmem_offset_colwise] = static_cast(after_gate_elt); + } + } + + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + + if (!out_of_bounds) { + thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); + if constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); + } + } + } + + if constexpr (ONLY_COLWISE_SCALING) { + if (tid_Y_colwise > 0) { + subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_act; + } + __syncthreads(); + if (tid_Y_colwise == 0) { +#pragma unroll + for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) { + const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise]; + __builtin_assume(thread_amax_act >= 0); + __builtin_assume(other_thread_amax >= 0); + thread_amax_act = fmaxf(thread_amax_act, other_thread_amax); + } + subamax_colwise_buff[0][tid_X_colwise] = thread_amax_act; + } + __syncthreads(); + thread_amax_act = subamax_colwise_buff[0][tid_X_colwise]; + + if constexpr (IS_DGATED) { + __syncthreads(); + if (tid_Y_colwise > 0) { + subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_gate; + } + __syncthreads(); + if (tid_Y_colwise == 0) { +#pragma unroll + for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) { + const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise]; + __builtin_assume(thread_amax_gate >= 0); + __builtin_assume(other_thread_amax >= 0); + thread_amax_gate = fmaxf(thread_amax_gate, other_thread_amax); + } + subamax_colwise_buff[0][tid_X_colwise] = thread_amax_gate; + } + __syncthreads(); + thread_amax_gate = subamax_colwise_buff[0][tid_X_colwise]; + } + } + + const e8m0_t biased_exponent_act = + ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + const size_t scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows; + const bool out_of_bounds_colwise_flag = row_out_of_bounds_colwise || col_out_of_bounds_colwise; + + if (tid_Y_colwise == 0 && (!out_of_bounds_colwise_flag)) { + scales_colwise[scale_idx] = biased_exponent_act; + } + + float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); + float block_scale_inverse_gate; + + if constexpr (IS_DGATED) { + const e8m0_t biased_exponent_gate = + ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); + const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; + if (tid_Y_colwise == 0 && (!out_of_bounds_colwise_flag)) { + scales_colwise[scale_idx_gate] = biased_exponent_gate; + } + block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate); + } + +#pragma unroll + for (int i = 0; i < MX_SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { + const size_t shmem_offset_elt = + shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * MX_BUFF_DIM_X; + if constexpr (IS_DGATED) { + out_act_colwise_sh[shmem_offset_elt] = + static_cast(block_scale_inverse_act * after_act_colwise[i]); + out_gate_colwise_sh[shmem_offset_elt] = + static_cast(block_scale_inverse_gate * after_gate_colwise[i]); + } else { + const float scaled_out_act = block_scale_inverse_act * 
after_act_colwise[i]; + out_act_colwise_sh[shmem_offset_elt] = static_cast(scaled_out_act); + } + } + } + + // ---- Rowwise scaling pass ---- + if constexpr (ROWWISE_SCALING) { + const size_t shmem_offset_base_rowwise = thread_offset_Y_rowwise * MX_BUFF_DIM_X; + + float thread_amax_act = 0.0f; + float thread_amax_gate = 0.0f; + + Vec in_cached_act[MX_WAVES]; + Vec in_cached_gate[MX_WAVES]; + + float after_act_rowwise[MX_SCALE_DIM_X]; + float after_gate_rowwise[MX_SCALE_DIM_X]; + + if constexpr (IS_CACHED_ACT_OP) { + __syncthreads(); +#pragma unroll + for (int w = 0; w < MX_WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * MX_PACK_SIZE) % MX_SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + in_cached_act[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + if constexpr (IS_DGATED) { + in_cached_gate[w].load_from(&cached_gate_sh[shmem_offset_rowwise]); + } + if (!out_of_bounds) { +#pragma unroll + for (int e = 0; e < MX_PACK_SIZE; ++e) { + thread_amax_act = fmaxf(thread_amax_act, + fabsf(static_cast(in_cached_act[w].data.elt[e]))); + if constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, + fabsf(static_cast(in_cached_gate[w].data.elt[e]))); + } + } + } + } + } else { +#pragma unroll + for (int w = 0; w < MX_WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * MX_PACK_SIZE) % MX_SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in_grad; + Vec in_act; + Vec in_gate; + + in_act.load_from(&in_act_sh[shmem_offset_rowwise]); + in_gate.load_from(&in_gate_sh[shmem_offset_rowwise]); + if constexpr (IS_DGATED) { + in_grad.load_from(&in_grad_sh[shmem_offset_rowwise]); + } + +#pragma unroll + for (int e = 0; e < MX_PACK_SIZE; ++e) { + const int j = w * MX_PACK_SIZE + e; + + float act_elt = static_cast(in_act.data.elt[e]); + float gate_elt = static_cast(in_gate.data.elt[e]); + float after_act_elt; + float after_gate_elt; + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad.data.elt[e]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = nv_sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + after_act_elt = dact_x * grad_elt * gate_elt; + after_gate_elt = act_x * grad_elt; + after_act_rowwise[j] = after_act_elt; + after_gate_rowwise[j] = after_gate_elt; + } else { + after_act_elt = ActOP(act_elt, {}) * gate_elt; + after_act_rowwise[j] = after_act_elt; + } + + if constexpr (!std::is_same_v) { + after_act_elt = static_cast(static_cast(after_act_elt)); + if constexpr (IS_DGATED) { + after_gate_elt = static_cast(static_cast(after_gate_elt)); + } + } + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + 
thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); + if constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); + } + } + } + } + } + + // Compute E8M0 scaling factor + const e8m0_t biased_exponent_act = + ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); + const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const size_t stage_scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows; + const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise; + if (!out_of_bounds_rowwise) { + scales_rowwise[scale_idx] = biased_exponent_act; + } + + const float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); + + float block_scale_inverse_gate; + if constexpr (IS_DGATED) { + const e8m0_t biased_exponent_gate = + ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); + const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; + if (!out_of_bounds_rowwise) { + scales_rowwise[scale_idx_gate] = biased_exponent_gate; + } + block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate); + } + + // Scale elements +#pragma unroll + for (int w = 0; w < MX_WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * MX_PACK_SIZE) % MX_SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + +#pragma unroll + for (int e = 0; e < MX_PACK_SIZE; ++e) { + const int j = w * MX_PACK_SIZE + e; + float in_act_val; + + if constexpr (IS_CACHED_ACT_OP) { + in_act_val = static_cast(in_cached_act[w].data.elt[e]); + } else { + in_act_val = after_act_rowwise[j]; + } + out_act_rowwise_sh[shmem_offset_rowwise + e] = + static_cast(block_scale_inverse_act * in_act_val); + + if constexpr (IS_DGATED) { + float in_gate_val; + if constexpr (IS_CACHED_ACT_OP) { + in_gate_val = static_cast(in_cached_gate[w].data.elt[e]); + } else { + in_gate_val = after_gate_rowwise[j]; + } + out_gate_rowwise_sh[shmem_offset_rowwise + e] = + static_cast(block_scale_inverse_gate * in_gate_val); + } + } + } + } + + __syncthreads(); + + // TDM store + if constexpr (ROWWISE_SCALING) { + tdm::store_2d_to_global(out_act_rowwise_sh, output_act_rowwise_ptr, + global_offset_X, global_offset_Y, + MX_BUFF_DIM_X, MX_BUFF_DIM_Y, + cols, rows, output_stride, out_data_sz); + if constexpr (IS_DGATED) { + tdm::store_2d_to_global(out_gate_rowwise_sh, output_gate_rowwise_ptr, + global_offset_X, global_offset_Y, + MX_BUFF_DIM_X, MX_BUFF_DIM_Y, + cols, rows, output_stride, out_data_sz); + } + } + if constexpr (COLWISE_SCALING) { + tdm::store_2d_to_global(out_act_colwise_sh, output_act_colwise_ptr, + global_offset_X, global_offset_Y, + MX_BUFF_DIM_X, MX_BUFF_DIM_Y, + cols, rows, output_stride, out_data_sz); + if constexpr (IS_DGATED) { + tdm::store_2d_to_global(out_gate_colwise_sh, output_gate_colwise_ptr, + global_offset_X, global_offset_Y, + MX_BUFF_DIM_X, MX_BUFF_DIM_Y, + cols, rows, output_stride, out_data_sz); + } + } + tdm::wait_tensorcnt_0(); + __syncthreads(); + } +#endif // defined(__gfx1250__) +} + +} // namespace nv_mxfp8 +} // namespace nv_upstream_tdm + +// --------------------------------------------------------------------------- +// NV upstream TDM launcher: FP8 gated +// 
--------------------------------------------------------------------------- +template +void cast_fp8_gated_nv_upstream_tdm(const Tensor &grad, const Tensor &gated_input, + Tensor *output, cudaStream_t stream) { + using namespace nv_upstream_tdm; + + if (output->has_data()) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, NV_CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, NV_CHUNK_DIM_X); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block_dim(NV_THREADS_PER_CHUNK); + const dim3 grid_dim(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + const IType *grad_ptr = IS_DGATED + ? reinterpret_cast(grad.data.dptr) : nullptr; + const IType *input_act_ptr = reinterpret_cast(gated_input.data.dptr); + const IType *input_gate_ptr = reinterpret_cast(gated_input.data.dptr) + cols; + OType *output_act_ptr = reinterpret_cast(output->data.dptr); + OType *output_gate_ptr = IS_DGATED + ? reinterpret_cast(output->data.dptr) + cols : nullptr; + + constexpr size_t ALIGNMENT = 128; + constexpr size_t buff_elems = NV_SHMEM_DIM_Y * NV_SHMEM_DIM_X; + const size_t buff_size_aligned_in = + ((buff_elems * sizeof(IType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + const size_t buff_size_aligned_out = + ((buff_elems * sizeof(OType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + + (out_act_mem + out_gate_mem) + ALIGNMENT; + + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_fp8_gated_kernel_tdm, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + cast_fp8_gated_kernel_tdm + <<>>( + grad_ptr, input_act_ptr, input_gate_ptr, + output_act_ptr, output_gate_ptr, + amax_ptr, scale_inv_ptr, scale_ptr, + rows, cols, + cols * 2, cols * 2, cols, output_cols); + NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) +} + +// --------------------------------------------------------------------------- +// NV upstream TDM launcher: MXFP8 gated +// --------------------------------------------------------------------------- +template +void cast_mxfp8_gated_nv_upstream_tdm(const Tensor &grad, const Tensor &gated_input, + Tensor *output, cudaStream_t stream) { + using namespace nv_upstream_tdm; + using namespace nv_upstream_tdm::nv_mxfp8; + + const bool USE_ROWWISE_SCALING = output->has_data(); + const bool USE_COLWISE_SCALING = output->has_columnwise_data(); + + if (USE_ROWWISE_SCALING) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + if (USE_COLWISE_SCALING) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, MX_CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, MX_CHUNK_DIM_X); + + const size_t THREADS_PER_CHUNK = USE_COLWISE_SCALING + ? MX_THREADS_PER_CHUNK_COLWISE : MX_THREADS_PER_CHUNK_NON_COLWISE; + + const dim3 grid(blocks_X, blocks_Y); + const dim3 block_size(THREADS_PER_CHUNK); + + size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; + size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1; + + e8m0_t *const scales_rowwise_ptr = + USE_ROWWISE_SCALING ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + const IType *grad_ptr = IS_DGATED + ? reinterpret_cast(grad.data.dptr) : nullptr; + const IType *input_act_ptr = reinterpret_cast(gated_input.data.dptr); + const IType *input_gate_ptr = reinterpret_cast(gated_input.data.dptr) + cols; + OType *output_act_rowwise = USE_ROWWISE_SCALING + ? reinterpret_cast(output->data.dptr) : nullptr; + OType *output_gate_rowwise = USE_ROWWISE_SCALING + ? reinterpret_cast(output->data.dptr) + cols : nullptr; + OType *output_act_colwise = USE_COLWISE_SCALING + ? reinterpret_cast(output->columnwise_data.dptr) : nullptr; + OType *output_gate_colwise = USE_COLWISE_SCALING + ? 
reinterpret_cast(output->columnwise_data.dptr) + cols : nullptr; + + constexpr size_t ALIGNMENT = 128; + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + constexpr size_t buff_elems = MX_BUFF_DIM_Y * MX_BUFF_DIM_X; + const size_t input_buff_size = (buff_elems * input_type_bit_size) / 8; + const size_t output_buff_size = (buff_elems * output_type_bit_size) / 8; + const size_t buff_size_aligned_in = + ((input_buff_size + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + const size_t buff_size_aligned_out = + ((output_buff_size + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; + + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); + size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } + + const size_t shmem_size = in_mem + out_mem + ALIGNMENT; + + auto launch_kernel = [&](auto rowwise_tag, auto colwise_tag, auto threads_tag) { + constexpr bool RW = decltype(rowwise_tag)::value; + constexpr bool CW = decltype(colwise_tag)::value; + constexpr size_t TPC = decltype(threads_tag)::value; + + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_gated_kernel_tdm, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + cast_mxfp8_gated_kernel_tdm + <<>>( + grad_ptr, input_act_ptr, input_gate_ptr, + output_act_rowwise, output_gate_rowwise, + output_act_colwise, output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise, + cols * 2, cols * 2, cols, output_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) { + launch_kernel(std::true_type{}, std::false_type{}, + std::integral_constant{}); + } else if ((!USE_ROWWISE_SCALING) && USE_COLWISE_SCALING) { + launch_kernel(std::false_type{}, std::true_type{}, + std::integral_constant{}); + } else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + launch_kernel(std::true_type{}, std::true_type{}, + std::integral_constant{}); + } + ); // NOLINT(*) + ); // NOLINT(*) +} + +#endif // __HIP_PLATFORM_AMD__ + template void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, @@ -1316,11 +2203,21 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu if (is_delayed_tensor_scaling(output->scaling_mode)) { #ifdef __HIP_PLATFORM_AMD__ - if constexpr (IS_DGATED) { + // Check env var: NVTE_USE_NV_UPSTREAM_FLOW=1 selects NV upstream TDM kernel path + static const bool use_nv_upstream = [] { + const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); + return env != nullptr && env[0] == '1' && env[1] == '\0'; + }(); + if (use_nv_upstream && use_tma_kernels) { + cast_fp8_gated_nv_upstream_tdm( + grad, gated_input, output, stream); + } else { + if constexpr (IS_DGATED) { cast_dgated(grad, gated_input, output, stream); } else { cast_gated(gated_input, output, stream); } + } #else if (use_tma_kernels) { cast_fp8_gated(grad, gated_input, output, stream); @@ -1333,12 +2230,28 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu } #endif } else if (is_mxfp_scaling(output->scaling_mode)) { +#ifdef __HIP_PLATFORM_AMD__ + static const bool use_nv_upstream_mx = [] { + 
const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); + return env != nullptr && env[0] == '1' && env[1] == '\0'; + }(); + if (use_nv_upstream_mx && use_tma_kernels) { + cast_mxfp8_gated_nv_upstream_tdm( + grad, gated_input, output, stream); + } else if (use_tma_kernels) { + cast_mxfp8_gated(grad, gated_input, output, stream); + } else { + NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ", + "by 32, got input of shape ", gated_input.data.shape); + } +#else if (use_tma_kernels) { cast_mxfp8_gated(grad, gated_input, output, stream); } else { NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ", "by 32, got input of shape ", gated_input.data.shape); } +#endif } else { NVTE_ERROR("Not supported scaling mode"); } From cb421f8eb86ee3f371905e71a858d7f811645292 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Thu, 23 Apr 2026 13:11:51 +0000 Subject: [PATCH 03/43] [ROCm] address reviewer comments --- .../common/util/cast_gated_kernels.cuh | 1537 ++++++----------- .../common/util/cast_kernels.cuh | 374 +++- .../common/util/dequantize_kernels.cuh | 94 +- 3 files changed, 927 insertions(+), 1078 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 7718f91f6..80a29b3f2 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -37,7 +37,11 @@ namespace transformer_engine { namespace gated_kernels { -#ifndef __HIP_PLATFORM_AMD__ +// NV upstream flow constants — used on both NVIDIA (TMA) and AMD gfx1250 (TDM). +// On AMD, the ROCm flow constants live in rocm_cast_gated_kernels.cuh; +// these are in the nv_flow namespace to avoid collision. +namespace nv_flow { + constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; constexpr size_t THREADS_PER_CHUNK = 512; @@ -58,14 +62,24 @@ __device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf template __global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + cast_fp8_gated_kernel( +#ifdef __HIP_PLATFORM_AMD__ + const IType *__restrict__ grad_ptr, + const IType *__restrict__ input_act_ptr, + const IType *__restrict__ input_gate_ptr, + OType *__restrict__ output_act_ptr, + OType *__restrict__ output_gate_ptr, +#else + const __grid_constant__ CUtensorMap tensor_map_grad, const __grid_constant__ CUtensorMap tensor_map_input_act, const __grid_constant__ CUtensorMap tensor_map_input_gate, const __grid_constant__ CUtensorMap tensor_map_output_act, const __grid_constant__ CUtensorMap tensor_map_output_gate, +#endif float *const amax_ptr, float *const scale_inv_ptr, - const float *const scale_ptr, const size_t rows, const size_t cols) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + const float *const scale_ptr, const size_t rows, const size_t cols, + const size_t input_act_stride, const size_t output_stride) { +#if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X; @@ -81,17 +95,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) extern __shared__ char dynamic_shmem[]; uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! 
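// Illustrative sketch only (this helper does not exist in the patch): the
// pointer rounding used here for both the TMA (TMA_SHMEM_ALIGNMENT) and the
// TDM (128 B) paths can be expressed as a tiny align-up helper.
__device__ inline uintptr_t align_up(uintptr_t ptr, size_t alignment) {
  // Round `ptr` up to the next multiple of `alignment` (alignment must be a power of two).
  return (ptr + alignment - 1) & ~static_cast<uintptr_t>(alignment - 1);
}
// Usage, matching the computation of dshmem above:
//   uintptr_t dshmem = align_up(base_shmem_ptr, SHMEM_ALIGNMENT);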
- uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); +#ifdef __HIP_PLATFORM_AMD__ + constexpr size_t SHMEM_ALIGNMENT = 128; +#else + constexpr size_t SHMEM_ALIGNMENT = TMA_SHMEM_ALIGNMENT; +#endif + uintptr_t dshmem = (base_shmem_ptr + SHMEM_ALIGNMENT - 1) & + ~(static_cast(SHMEM_ALIGNMENT - 1)); constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + ((buff_elems_total * sizeof(IType) + SHMEM_ALIGNMENT - 1) / SHMEM_ALIGNMENT) * SHMEM_ALIGNMENT; constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + ((buff_elems_total * sizeof(OType) + SHMEM_ALIGNMENT - 1) / SHMEM_ALIGNMENT) * SHMEM_ALIGNMENT; constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; @@ -100,23 +117,27 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t in_mem = in_act_mem + in_gate_mem; constexpr size_t out_act_mem = buff_size_aligned_out; +#ifndef __HIP_PLATFORM_AMD__ constexpr size_t in_transaction_size = buff_elems * sizeof(IType); +#endif - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned IType *in_grad_sh = reinterpret_cast(dshmem); IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); +#ifndef __HIP_PLATFORM_AMD__ const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); +#endif const bool is_master_thread = (threadIdx.x == 0); +#ifndef __HIP_PLATFORM_AMD__ // Initialize shared memory barrier with the number of threads participating in the barrier. 
#pragma nv_diag_suppress static_var_with_dynamic_init __shared__ alignas(8) uint64_t mbar[ITERATIONS]; @@ -126,7 +147,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) int parity = 0; // Prefetch data of the first stage - if constexpr (IS_DGATED) { copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, @@ -137,15 +157,43 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], is_master_thread); } +#else // __HIP_PLATFORM_AMD__ — TDM prefetch + constexpr uint32_t in_data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); + constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + + // Prefetch data of the first stage + if constexpr (IS_DGATED) { + tdm::copy_2d_to_shared(in_grad_sh, grad_ptr, + chunk_offset_X, chunk_offset_Y, + SHMEM_DIM_X, SHMEM_DIM_Y, + cols, rows, cols, in_data_sz); + tdm::copy_2d_to_shared_x2( + in_act_sh, input_act_ptr, chunk_offset_X, chunk_offset_Y, + in_gate_sh, input_gate_ptr, chunk_offset_X, chunk_offset_Y, + SHMEM_DIM_X, SHMEM_DIM_Y, + cols, rows, input_act_stride, in_data_sz); + } else { + tdm::copy_2d_to_shared_x2( + in_act_sh, input_act_ptr, chunk_offset_X, chunk_offset_Y, + in_gate_sh, input_gate_ptr, chunk_offset_X, chunk_offset_Y, + SHMEM_DIM_X, SHMEM_DIM_Y, + cols, rows, input_act_stride, in_data_sz); + } + tdm::wait_tensorcnt_0(); + __syncthreads(); +#endif // __HIP_PLATFORM_AMD__ #pragma unroll for (int it = 0; it < ITERATIONS; ++it) { const size_t buff = it % BUFFERS_NUM; const size_t next_it = it + 1; + + // Prefetch next iteration's data if (next_it < ITERATIONS) { const size_t next_buff = next_it % BUFFERS_NUM; const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; const size_t chunk_it_offset_x = chunk_offset_X; +#ifndef __HIP_PLATFORM_AMD__ if constexpr (IS_DGATED) { copy_2d_to_sharedx3( &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, @@ -158,12 +206,40 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, &mbar[next_it], is_master_thread); } +#else // __HIP_PLATFORM_AMD__ — TDM prefetch + if constexpr (IS_DGATED) { + tdm::copy_2d_to_shared(&in_grad_sh[next_buff * buff_elems], grad_ptr, + chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, + cols, rows, cols, in_data_sz); + tdm::copy_2d_to_shared_x2( + &in_act_sh[next_buff * buff_elems], input_act_ptr, chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[next_buff * buff_elems], input_gate_ptr, chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, + cols, rows, input_act_stride, in_data_sz); + } else { + tdm::copy_2d_to_shared_x2( + &in_act_sh[next_buff * buff_elems], input_act_ptr, chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[next_buff * buff_elems], input_gate_ptr, chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, + cols, rows, input_act_stride, in_data_sz); + } + // TDM is async — wait for the prefetch of the NEXT buffer in the NEXT iteration. + // For now, we use a simple wait-all pattern; double-buffering optimization can be added later. 
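// Sketch of the wait-all synchronization pattern this branch relies on
// (simplified; only the tdm:: helpers already used in this file are assumed,
// and compute() stands in for the per-stage math):
//
//   for (it = 0; it < ITERATIONS; ++it) {
//     if (it + 1 < ITERATIONS)
//       tdm::copy_2d_to_shared(buff[(it + 1) % BUFFERS_NUM], ...);  // async prefetch
//     tdm::wait_tensorcnt_0();   // drains ALL outstanding TDM ops, including the
//     __syncthreads();           // prefetch just issued, so there is no overlap yet
//     compute(buff[it % BUFFERS_NUM]);
//     tdm::store_2d_to_global(buff[it % BUFFERS_NUM], ...);         // async store
//   }
//   tdm::wait_tensorcnt_0();     // drain the last store before the block exits
//
// A true double-buffered variant would require a finer-grained wait that
// targets only the buffer about to be consumed (a hypothetical per-counter
// wait, not provided by wait_tensorcnt_0), letting the next prefetch stay in
// flight while the current buffer is processed.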
+#endif // __HIP_PLATFORM_AMD__ } +#ifndef __HIP_PLATFORM_AMD__ ptx::fence_proxy_async_shared_cta(); // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[it], parity); +#else + // On TDM: wait for prefetch of current buffer (issued in previous iteration or before the loop) + // and for any stores from the previous iteration to complete. + tdm::wait_tensorcnt_0(); + __syncthreads(); +#endif IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; IType *in_act_sh_curr = in_act_sh + buff * buff_elems; @@ -171,6 +247,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) OType *out_act_sh_curr = out_act_sh + buff * buff_elems; OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; + // Compute — identical for TMA and TDM #pragma unroll for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; @@ -212,64 +289,83 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } + // Store computed results from shared memory to global memory +#ifndef __HIP_PLATFORM_AMD__ // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) ptx::fence_proxy_async_shared_cta(); __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. // Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; const size_t chunk_it_offset_x = chunk_offset_X; - // dGeLU ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, chunk_it_offset_y, reinterpret_cast(out_act_sh_curr)); if constexpr (IS_DGATED) { - // dGate ptx::cp_async_bulk_tensor_2d_shared_to_global( TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, reinterpret_cast(out_gate_sh_curr)); } - // Create a "bulk async-group" out of the previous bulk copy operation. ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. ptx::cp_async_bulk_wait_group_read(); } +#else // __HIP_PLATFORM_AMD__ — TDM store + __syncthreads(); + { + const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + + tdm::store_2d_to_global(out_act_sh_curr, output_act_ptr, + chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, + cols, rows, output_stride, out_data_sz); + if constexpr (IS_DGATED) { + tdm::store_2d_to_global(out_gate_sh_curr, output_gate_ptr, + chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, + cols, rows, output_stride, out_data_sz); + } + // TDM stores are async — they will be drained at the top of the next iteration + // (or after the loop for the last iteration). + } +#endif // __HIP_PLATFORM_AMD__ } + +#ifndef __HIP_PLATFORM_AMD__ ptx::cp_async_bulk_wait_group_read<0>(); +#else + tdm::wait_tensorcnt_0(); +#endif __syncthreads(); if (amax_ptr != nullptr) { const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block amax = reduce_max(amax, warp_id); - // Update the global amax if (is_master_thread) { atomicMaxFloat(amax_ptr, amax); } } - // Update scale-inverse if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { reciprocal(scale_inv_ptr, scale); } - // Destroy the barriers. This invalidates the memory region of the barrier. - // If further computations were to take place in the kernel, this allows the - // memory location of the shared memory barrier to be reused. +#ifndef __HIP_PLATFORM_AMD__ + // Destroy the barriers. 
if (is_master_thread) { #pragma unroll for (int it = 0; it < ITERATIONS; ++it) { ptx::mbarrier_invalid(&mbar[it]); } } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#endif +#endif // #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) } +} // namespace nv_flow namespace mxfp8_kernel { @@ -300,19 +396,33 @@ template __global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + cast_mxfp8_gated_kernel( +#ifdef __HIP_PLATFORM_AMD__ + const IType *__restrict__ grad_ptr, + const IType *__restrict__ input_act_ptr, + const IType *__restrict__ input_gate_ptr, + OType *__restrict__ output_act_rowwise_ptr, + OType *__restrict__ output_gate_rowwise_ptr, + OType *__restrict__ output_act_colwise_ptr, + OType *__restrict__ output_gate_colwise_ptr, +#else + const __grid_constant__ CUtensorMap tensor_map_grad, const __grid_constant__ CUtensorMap tensor_map_input_act, const __grid_constant__ CUtensorMap tensor_map_input_gate, const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise, const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise, const __grid_constant__ CUtensorMap tensor_map_output_act_colwise, const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, +#endif e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + const size_t scale_stride_colwise, + const size_t input_act_stride, const size_t output_stride) { +#if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) +#ifndef __HIP_PLATFORM_AMD__ using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; +#endif constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; static_assert(STAGES >= 1); @@ -367,17 +477,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) extern __shared__ char dynamic_shmem[]; uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); +#ifdef __HIP_PLATFORM_AMD__ + constexpr size_t MX_SHMEM_ALIGNMENT = 128; +#else + constexpr size_t MX_SHMEM_ALIGNMENT = TMA_SHMEM_ALIGNMENT; +#endif + uintptr_t dshmem = (base_shmem_ptr + MX_SHMEM_ALIGNMENT - 1) & + ~(static_cast(MX_SHMEM_ALIGNMENT - 1)); constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + ((buff_elems_total * sizeof(IType) + MX_SHMEM_ALIGNMENT - 1) / MX_SHMEM_ALIGNMENT) * MX_SHMEM_ALIGNMENT; constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + ((buff_elems_total * sizeof(OType) + MX_SHMEM_ALIGNMENT - 1) / MX_SHMEM_ALIGNMENT) * MX_SHMEM_ALIGNMENT; const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); @@ -413,6 +526,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const bool is_master_thread = (threadIdx.x == 0); +#ifndef __HIP_PLATFORM_AMD__ // Initialize shared memory barrier with the number of threads participating in the barrier. 
#pragma nv_diag_suppress static_var_with_dynamic_init __shared__ alignas(8) uint64_t mbar[STAGES]; @@ -421,6 +535,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) int parity = 0; + // TMA prefetch if constexpr (IS_DGATED) { copy_2d_to_sharedx3(&in_grad_sh[0], &tensor_map_grad, block_offset_X, block_offset_Y, &in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y, @@ -431,6 +546,31 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y, shmem_buff_size, &mbar[0], is_master_thread); } +#else // __HIP_PLATFORM_AMD__ — TDM + constexpr uint32_t mx_in_data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); + constexpr uint32_t mx_out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + + // TDM prefetch + if constexpr (IS_DGATED) { + tdm::copy_2d_to_shared(&in_grad_sh[0], grad_ptr, + block_offset_X, block_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, + cols, rows, cols, mx_in_data_sz); + tdm::copy_2d_to_shared_x2( + &in_act_sh[0], input_act_ptr, block_offset_X, block_offset_Y, + &in_gate_sh[0], input_gate_ptr, block_offset_X, block_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, + cols, rows, input_act_stride, mx_in_data_sz); + } else { + tdm::copy_2d_to_shared_x2( + &in_act_sh[0], input_act_ptr, block_offset_X, block_offset_Y, + &in_gate_sh[0], input_gate_ptr, block_offset_X, block_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, + cols, rows, input_act_stride, mx_in_data_sz); + } + tdm::wait_tensorcnt_0(); + __syncthreads(); +#endif // __HIP_PLATFORM_AMD__ #pragma unroll for (int stage = 0; stage < STAGES; ++stage) { @@ -439,8 +579,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t stage_offset_Y = stage * BUFF_DIM_Y; if (next_stage < STAGES) { +#ifndef __HIP_PLATFORM_AMD__ // Wait for TMA transfer to have finished reading shared memory. - // I.e. 
the buffer is ready to be written to ptx::cp_async_bulk_wait_group_read<1>(); const size_t next_buff = next_stage % BUFFS_NUM; @@ -460,12 +600,41 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); } +#else // __HIP_PLATFORM_AMD__ — TDM prefetch next stage + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_DIM; + if constexpr (IS_DGATED) { + tdm::copy_2d_to_shared(&in_grad_sh[next_buff_offset], grad_ptr, + global_offset_X, global_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, + cols, rows, cols, mx_in_data_sz); + tdm::copy_2d_to_shared_x2( + &in_act_sh[next_buff_offset], input_act_ptr, global_offset_X, global_offset_Y, + &in_gate_sh[next_buff_offset], input_gate_ptr, global_offset_X, global_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, + cols, rows, input_act_stride, mx_in_data_sz); + } else { + tdm::copy_2d_to_shared_x2( + &in_act_sh[next_buff_offset], input_act_ptr, global_offset_X, global_offset_Y, + &in_gate_sh[next_buff_offset], input_gate_ptr, global_offset_X, global_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, + cols, rows, input_act_stride, mx_in_data_sz); + } +#endif // __HIP_PLATFORM_AMD__ } +#ifndef __HIP_PLATFORM_AMD__ ptx::fence_proxy_async_shared_cta(); // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], parity); +#else + tdm::wait_tensorcnt_0(); + __syncthreads(); +#endif if constexpr (COLWISE_SCALING) { const size_t shmem_offset_base_colwise = @@ -621,6 +790,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t shmem_offset_elt = shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; if constexpr (IS_DGATED) { +#ifdef __HIP_PLATFORM_AMD__ + // Scalar fallback for mul_cvt_2x (PTX intrinsic not available on AMD) + out_act_colwise_sh[shmem_offset_elt] = + static_cast(block_scale_inverse_act * after_act_colwise[i]); + out_gate_colwise_sh[shmem_offset_elt] = + static_cast(block_scale_inverse_gate * after_gate_colwise[i]); +#else OType2 out_pair; ptx::floatx2 in_pair = {after_act_colwise[i], after_gate_colwise[i]}; const ptx::floatx2 block_scale_inverse_2x_pair = {block_scale_inverse_act, @@ -628,6 +804,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::mul_cvt_2x(out_pair, in_pair, block_scale_inverse_2x_pair); out_act_colwise_sh[shmem_offset_elt] = out_pair.x; out_gate_colwise_sh[shmem_offset_elt] = out_pair.y; +#endif } else { const float scaled_out_act = block_scale_inverse_act * after_act_colwise[i]; out_act_colwise_sh[shmem_offset_elt] = static_cast(scaled_out_act); @@ -652,8 +829,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_CACHED_ACT_OP) { // ensures that all writes to cache made in the section above are visible to all threads __syncthreads(); +#ifndef __HIP_PLATFORM_AMD__ IType2 thread_amax_2x_act = {static_cast(0.0f), static_cast(0.0f)}; IType2 thread_amax_2x_gate = {static_cast(0.0f), static_cast(0.0f)}; +#endif #pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; @@ -672,6 +851,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries if (!out_of_bounds) { +#ifdef __HIP_PLATFORM_AMD__ + // Scalar fallback for abs_max_2x (PTX intrinsic not available on AMD) +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax_act = fmaxf(thread_amax_act, + fabsf(static_cast(in_cached_act[w].data.elt[e]))); + if constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, + fabsf(static_cast(in_cached_gate[w].data.elt[e]))); + } + } +#else if constexpr (std::is_same_v) { #pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { @@ -693,8 +884,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } } +#endif // __HIP_PLATFORM_AMD__ } } +#ifndef __HIP_PLATFORM_AMD__ if constexpr (!std::is_same_v) { thread_amax_act = static_cast( __hmax(__habs(thread_amax_2x_act.x), __habs(thread_amax_2x_act.y))); @@ -703,6 +896,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __hmax(__habs(thread_amax_2x_gate.x), __habs(thread_amax_2x_gate.y))); } } +#endif } else { #pragma unroll for (int w = 0; w < WAVES; ++w) { @@ -786,11 +980,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } const float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); +#ifndef __HIP_PLATFORM_AMD__ const ptx::floatx2 block_scale_inverse_2x_act = {block_scale_inverse_act, block_scale_inverse_act}; +#endif float block_scale_inverse_gate; +#ifndef __HIP_PLATFORM_AMD__ ptx::floatx2 block_scale_inverse_2x_gate; +#endif if constexpr (IS_DGATED) { const e8m0_t biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); @@ -799,12 +997,43 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) scales_rowwise[scale_idx_gate] = biased_exponent_gate; } block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate); +#ifndef __HIP_PLATFORM_AMD__ block_scale_inverse_2x_gate = {block_scale_inverse_gate, block_scale_inverse_gate}; +#endif } // 3. 
Scale elements #pragma unroll for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + +#ifdef __HIP_PLATFORM_AMD__ + // Scalar fallback for mul_cvt_2x (PTX intrinsic not available on AMD) +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + float in_act_val; + if constexpr (IS_CACHED_ACT_OP) { + in_act_val = static_cast(in_cached_act[w].data.elt[e]); + } else { + in_act_val = after_act_rowwise[w * PACK_SIZE + e]; + } + out_act_rowwise_sh[shmem_offset_rowwise + e] = + static_cast(block_scale_inverse_act * in_act_val); + + if constexpr (IS_DGATED) { + float in_gate_val; + if constexpr (IS_CACHED_ACT_OP) { + in_gate_val = static_cast(in_cached_gate[w].data.elt[e]); + } else { + in_gate_val = after_gate_rowwise[w * PACK_SIZE + e]; + } + out_gate_rowwise_sh[shmem_offset_rowwise + e] = + static_cast(block_scale_inverse_gate * in_gate_val); + } + } +#else Vec out_act; Vec out_gate; #pragma unroll @@ -837,16 +1066,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate); } } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]); if constexpr (IS_DGATED) { out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]); } +#endif // __HIP_PLATFORM_AMD__ } } + // Store computed results from shared memory to global memory +#ifndef __HIP_PLATFORM_AMD__ // Wait for shared memory writes to be visible to TMA engine. ptx::fence_proxy_async_shared_cta(); __syncthreads(); @@ -882,11 +1111,51 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Create a "bulk async-group" out of the previous bulk copy operation. ptx::cp_async_bulk_commit_group(); } +#else // __HIP_PLATFORM_AMD__ — TDM store + __syncthreads(); + { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t buff_offset = buff * BUFF_DIM; + + if constexpr (ROWWISE_SCALING) { + tdm::store_2d_to_global(&out_act_rowwise_sh[buff_offset], output_act_rowwise_ptr, + global_offset_X, global_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, + cols, rows, output_stride, mx_out_data_sz); + if constexpr (IS_DGATED) { + tdm::store_2d_to_global(&out_gate_rowwise_sh[buff_offset], output_gate_rowwise_ptr, + global_offset_X, global_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, + cols, rows, output_stride, mx_out_data_sz); + } + } + if constexpr (COLWISE_SCALING) { + tdm::store_2d_to_global(&out_act_colwise_sh[buff_offset], output_act_colwise_ptr, + global_offset_X, global_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, + cols, rows, output_stride, mx_out_data_sz); + if constexpr (IS_DGATED) { + tdm::store_2d_to_global(&out_gate_colwise_sh[buff_offset], output_gate_colwise_ptr, + global_offset_X, global_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, + cols, rows, output_stride, mx_out_data_sz); + } + } + // TDM stores are async — they will be drained at the top of the next iteration + // (or after the loop for the last iteration). 
+ } +#endif // __HIP_PLATFORM_AMD__ } +#ifndef __HIP_PLATFORM_AMD__ parity ^= 1; destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#else + tdm::wait_tensorcnt_0(); + __syncthreads(); +#endif +#endif // #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) } } // namespace mxfp8_kernel @@ -908,14 +1177,14 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(rows, nv_flow::CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, nv_flow::CHUNK_DIM_X); float *const amax_ptr = reinterpret_cast(output->amax.dptr); float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); float *const scale_ptr = reinterpret_cast(output->scale.dptr); - const dim3 block_dim(THREADS_PER_CHUNK); + const dim3 block_dim(nv_flow::THREADS_PER_CHUNK); const dim3 grid_dim(blocks_X, blocks_Y); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( @@ -923,6 +1192,41 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, +#ifdef __HIP_PLATFORM_AMD__ + const IType *grad_ptr = IS_DGATED + ? reinterpret_cast(grad.data.dptr) : nullptr; + const IType *input_act_ptr = reinterpret_cast(gated_input.data.dptr); + const IType *input_gate_ptr = reinterpret_cast(gated_input.data.dptr) + cols; + OType *output_act_ptr = reinterpret_cast(output->data.dptr); + OType *output_gate_ptr = IS_DGATED + ? reinterpret_cast(output->data.dptr) + cols : nullptr; + + constexpr size_t ALIGNMENT = 128; + constexpr size_t buff_elems_total = nv_flow::BUFFERS_NUM * nv_flow::SHMEM_DIM_Y * nv_flow::SHMEM_DIM_X; + const size_t buff_size_aligned_in = + ((buff_elems_total * sizeof(IType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + const size_t buff_size_aligned_out = + ((buff_elems_total * sizeof(OType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + + (out_act_mem + out_gate_mem) + ALIGNMENT; + + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + nv_flow::cast_fp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + nv_flow::cast_fp8_gated_kernel + <<>>( + grad_ptr, input_act_ptr, input_gate_ptr, + output_act_ptr, output_gate_ptr, + amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, + cols * 2, output_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); +#else alignas(64) CUtensorMap tensor_map_grad{}; alignas(64) CUtensorMap tensor_map_input_act{}; alignas(64) CUtensorMap tensor_map_input_gate{}; @@ -930,23 +1234,23 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu alignas(64) CUtensorMap tensor_map_output_gate{}; if constexpr (IS_DGATED) { - create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, - cols, 0, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, nv_flow::SHMEM_DIM_Y, + nv_flow::SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype())); } const uint32_t tensor_stride_elems = output_cols; - create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, cols, + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, nv_flow::SHMEM_DIM_Y, + nv_flow::SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, nv_flow::SHMEM_DIM_Y, + nv_flow::SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, nv_flow::SHMEM_DIM_Y, + nv_flow::SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); + create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, nv_flow::SHMEM_DIM_Y, + nv_flow::SHMEM_DIM_X, tensor_stride_elems, cols, typeToNumBits(output->dtype())); - const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_elems_total = nv_flow::BUFFERS_NUM * nv_flow::SHMEM_DIM_Y * nv_flow::SHMEM_DIM_X; const size_t buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); const size_t buff_size_aligned_out = @@ -960,950 +1264,66 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_fp8_gated_kernel, + nv_flow::cast_fp8_gated_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - cast_fp8_gated_kernel + nv_flow::cast_fp8_gated_kernel <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols); - 
NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) + cols, cols * 2, output_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); +#endif // __HIP_PLATFORM_AMD__ + ); // NOLINT(*) + ); // NOLINT(*) } -#endif //#ifdef __HIP_PLATFORM_AMD__ - -// ===================================================================================== -// NV-upstream-flow TDM kernels for gfx1250 -// These mirror the NV upstream kernel structure (128x128 chunks, 512 threads) but -// replace CUtensorMap/TMA with raw-pointer/TDM. They are compiled only on AMD and -// selectable via NVTE_USE_NV_UPSTREAM_FLOW env var for performance comparison. -// ===================================================================================== -#ifdef __HIP_PLATFORM_AMD__ -namespace nv_upstream_tdm { - -// NV upstream constants (same as NV upstream, not ROCm flow) -constexpr size_t NV_CHUNK_DIM_Y = 128; -constexpr size_t NV_CHUNK_DIM_X = 128; -constexpr size_t NV_THREADS_PER_CHUNK = 512; -constexpr size_t NV_THREADS_PER_CHUNK_X = NV_CHUNK_DIM_X; -constexpr size_t NV_THREADS_PER_CHUNK_Y = NV_THREADS_PER_CHUNK / NV_THREADS_PER_CHUNK_X; // 4 -constexpr size_t NV_BUFFER_DIM_Y = 32; -constexpr size_t NV_BUFFER_DIM_X = NV_CHUNK_DIM_X; // 128 -constexpr size_t NV_SHMEM_DIM_Y = NV_BUFFER_DIM_Y; // 32 -constexpr size_t NV_SHMEM_DIM_X = NV_BUFFER_DIM_X; // 128 -constexpr size_t NV_BUFFER_STAGES_NUM = NV_BUFFER_DIM_Y / NV_THREADS_PER_CHUNK_Y; // 8 -constexpr size_t NV_ITERATIONS = NV_CHUNK_DIM_Y / NV_BUFFER_DIM_Y; // 4 -static_assert(NV_ITERATIONS >= 1); - -__device__ inline float nv_sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } - -// --------------------------------------------------------------------------- -// FP8 gated kernel — NV upstream flow ported to TDM -// --------------------------------------------------------------------------- + template -__global__ void __launch_bounds__(NV_THREADS_PER_CHUNK) - cast_fp8_gated_kernel_tdm(const IType *__restrict__ grad_ptr, - const IType *__restrict__ input_act_ptr, - const IType *__restrict__ input_gate_ptr, - OType *__restrict__ output_act_ptr, - OType *__restrict__ output_gate_ptr, - float *const amax_ptr, float *const scale_inv_ptr, - const float *const scale_ptr, - const size_t rows, const size_t cols, - const size_t input_act_stride, - const size_t input_gate_stride, - const size_t grad_stride, - const size_t output_stride) { -#if defined(__gfx1250__) - const size_t chunk_offset_Y = blockIdx.y * NV_CHUNK_DIM_Y; - const size_t chunk_offset_X = blockIdx.x * NV_CHUNK_DIM_X; - - const size_t tid_Y = threadIdx.x / NV_THREADS_PER_CHUNK_X; - const size_t tid_X = threadIdx.x % NV_THREADS_PER_CHUNK_X; + float (*DActOP)(float, const ParamOP &)> +void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + cudaStream_t stream) { + checkCuDriverContext(stream); - float amax = 0; - const float scale = (scale_ptr != nullptr) ? 
*scale_ptr : 1; + const bool USE_ROWWISE_SCALING = output->has_data(); + const bool USE_COLWISE_SCALING = output->has_columnwise_data(); - constexpr size_t ALIGNMENT = 128; - constexpr size_t buff_elems = NV_SHMEM_DIM_Y * NV_SHMEM_DIM_X; + if (USE_ROWWISE_SCALING) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + if (USE_COLWISE_SCALING) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - uintptr_t dshmem = (base_shmem_ptr + ALIGNMENT - 1) & ~(static_cast(ALIGNMENT - 1)); +#ifndef __HIP_PLATFORM_AMD__ + ScalingType scaling_type; + if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!USE_ROWWISE_SCALING) && USE_COLWISE_SCALING) { + scaling_type = ScalingType::COLWISE; + } else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + scaling_type = ScalingType::BIDIMENSIONAL; + } +#endif - constexpr size_t buff_size_aligned_in = - ((buff_elems * sizeof(IType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; - constexpr size_t buff_size_aligned_out = - ((buff_elems * sizeof(OType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; - constexpr size_t in_act_mem = buff_size_aligned_in; - constexpr size_t in_gate_mem = buff_size_aligned_in; - constexpr size_t in_mem = in_act_mem + in_gate_mem; - constexpr size_t out_act_mem = buff_size_aligned_out; +#ifdef __HIP_PLATFORM_AMD__ + constexpr size_t TMA_SHMEM_ALIGNMENT = ALIGNMENT_SIZE; - IType *in_grad_sh = reinterpret_cast(dshmem); - IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); - IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); - OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); - OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + constexpr size_t BUFF_DIM_Y = BUFFER_DIM_Y; + constexpr size_t BUFF_DIM_X = BUFFER_DIM_X; + constexpr size_t BUFFS_NUM = BUFFERS_NUM; - constexpr uint32_t in_data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); - constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); +#else -#pragma unroll - for (int it = 0; it < NV_ITERATIONS; ++it) { - const size_t chunk_it_offset_y = chunk_offset_Y + it * NV_BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - - // TDM load - if constexpr (IS_DGATED) { - tdm::copy_2d_to_shared(in_grad_sh, grad_ptr, - chunk_it_offset_x, chunk_it_offset_y, - NV_SHMEM_DIM_X, NV_SHMEM_DIM_Y, - cols, rows, grad_stride, in_data_sz); - tdm::copy_2d_to_shared_x2( - in_act_sh, input_act_ptr, chunk_it_offset_x, chunk_it_offset_y, - in_gate_sh, input_gate_ptr, chunk_it_offset_x, chunk_it_offset_y, - NV_SHMEM_DIM_X, NV_SHMEM_DIM_Y, - cols, rows, input_act_stride, in_data_sz); - } else { - tdm::copy_2d_to_shared_x2( - in_act_sh, input_act_ptr, chunk_it_offset_x, chunk_it_offset_y, - in_gate_sh, input_gate_ptr, chunk_it_offset_x, chunk_it_offset_y, - NV_SHMEM_DIM_X, NV_SHMEM_DIM_Y, - cols, rows, input_act_stride, in_data_sz); - } - tdm::wait_tensorcnt_0(); - __syncthreads(); - - // Compute -#pragma unroll - for (int stage = 0; 
stage < NV_BUFFER_STAGES_NUM; ++stage) { - const size_t stage_offset_Y = stage * NV_THREADS_PER_CHUNK_Y; - const size_t shmem_offset_y = tid_Y + stage_offset_Y; - const size_t shmem_offset_x = tid_X; - const size_t shmem_idx = shmem_offset_y * NV_SHMEM_DIM_X + shmem_offset_x; - - float act_elt = static_cast(in_act_sh[shmem_idx]); - float gate_elt = static_cast(in_gate_sh[shmem_idx]); - - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad_sh[shmem_idx]); - const float x = act_elt; - float act_x; - float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = nv_sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; - } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); - } - - float after_dact = dact_x * grad_elt * gate_elt; - float after_dgate = act_x * grad_elt; - - out_act_sh[shmem_idx] = static_cast(scale * after_dact); - out_gate_sh[shmem_idx] = static_cast(scale * after_dgate); - - amax = fmaxf(amax, fabsf(after_dact)); - amax = fmaxf(amax, fabsf(after_dgate)); - } else { - const float after_act = ActOP(act_elt, {}) * gate_elt; - out_act_sh[shmem_idx] = static_cast(scale * after_act); - amax = fmaxf(amax, fabsf(after_act)); - } - } - - __syncthreads(); - - // TDM store - tdm::store_2d_to_global(out_act_sh, output_act_ptr, - chunk_it_offset_x, chunk_it_offset_y, - NV_SHMEM_DIM_X, NV_SHMEM_DIM_Y, - cols, rows, output_stride, out_data_sz); - if constexpr (IS_DGATED) { - tdm::store_2d_to_global(out_gate_sh, output_gate_ptr, - chunk_it_offset_x, chunk_it_offset_y, - NV_SHMEM_DIM_X, NV_SHMEM_DIM_Y, - cols, rows, output_stride, out_data_sz); - } - tdm::wait_tensorcnt_0(); - __syncthreads(); - } - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - amax = reduce_max(amax, warp_id); - if (threadIdx.x == 0) { - atomicMaxFloat(amax_ptr, amax); - } - } - - if (threadIdx.x == 0 && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } -#endif // defined(__gfx1250__) -} - -// --------------------------------------------------------------------------- -// MXFP8 gated kernel — NV upstream flow ported to TDM -// --------------------------------------------------------------------------- -namespace nv_mxfp8 { - -constexpr size_t MX_CHUNK_DIM_Y = 64; -constexpr size_t MX_CHUNK_DIM_X = 64; -constexpr size_t MX_THREADS_PER_CHUNK_COLWISE = 128; -constexpr size_t MX_THREADS_PER_CHUNK_NON_COLWISE = MX_CHUNK_DIM_X; - -constexpr size_t MX_SCALE_DIM_Y = 32; -constexpr size_t MX_SCALE_DIM_X = 32; - -constexpr size_t MX_BUFF_DIM_Y = 32; -constexpr size_t MX_BUFF_DIM_X = MX_CHUNK_DIM_X; -constexpr size_t MX_BUFF_DIM = MX_BUFF_DIM_Y * MX_BUFF_DIM_X; -static_assert(MX_BUFF_DIM_Y == 32); - -constexpr size_t MX_PACK_SIZE = 4; -constexpr size_t MX_WAVES = MX_SCALE_DIM_X / MX_PACK_SIZE; - -constexpr size_t MX_TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 -constexpr size_t MX_THREADS_PER_BANK = MX_TOTAL_BANKS_WIDTH / MX_SCALE_DIM_X; // 4 - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_gated_kernel_tdm(const IType *__restrict__ grad_ptr, - const IType *__restrict__ input_act_ptr, - const IType *__restrict__ input_gate_ptr, - OType *__restrict__ output_act_rowwise_ptr, - OType *__restrict__ output_gate_rowwise_ptr, - OType *__restrict__ output_act_colwise_ptr, - OType *__restrict__ output_gate_colwise_ptr, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const size_t rows, const size_t cols, - const size_t scale_stride_rowwise, - const size_t 
scale_stride_colwise, - const size_t input_act_stride, - const size_t input_gate_stride, - const size_t grad_stride, - const size_t output_stride) { -#if defined(__gfx1250__) - constexpr size_t STAGES = MX_CHUNK_DIM_Y / MX_BUFF_DIM_Y; - static_assert(STAGES >= 1); - - constexpr bool IS_CACHED_ACT_OP = ROWWISE_SCALING && COLWISE_SCALING; - constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING); - constexpr size_t COLWISE_WAVEFRONT_SIZE = (THREADS_PER_CHUNK + MX_CHUNK_DIM_X - 1) / MX_CHUNK_DIM_X; - - const size_t block_offset_Y = blockIdx.y * MX_CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * MX_CHUNK_DIM_X; - const size_t scales_block_offset_Y_rowwise = blockIdx.y * MX_CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = blockIdx.x * MX_CHUNK_DIM_X / MX_SCALE_DIM_X; - const size_t scales_block_offset_Y_colwise = blockIdx.y * MX_CHUNK_DIM_Y / MX_SCALE_DIM_Y; - const size_t scales_block_offset_X_colwise = blockIdx.x * MX_CHUNK_DIM_X; - - constexpr size_t THREADS_X_ROWWISE = MX_CHUNK_DIM_X / MX_SCALE_DIM_X; - - const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; - const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; - const size_t tid_Y_colwise = threadIdx.x / MX_CHUNK_DIM_X; - const size_t tid_X_colwise = threadIdx.x % MX_CHUNK_DIM_X; - - const size_t thread_offset_Y_rowwise = tid_Y_rowwise; - const size_t thread_offset_X_rowwise = tid_X_rowwise * MX_SCALE_DIM_X; - const size_t thread_offset_Y_colwise = tid_Y_colwise; - const size_t thread_offset_X_colwise = tid_X_colwise; - - const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; - const size_t col_base_rowwise = block_offset_X + thread_offset_X_rowwise; - const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; - const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; - - const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols); - const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - - const size_t gate_scale_idx_offset_rowwise = (cols + MX_SCALE_DIM_X - 1) / MX_SCALE_DIM_X; - const size_t gate_scale_idx_offset_colwise = cols; - - const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int bank_group = thread_lane / MX_THREADS_PER_BANK; - - constexpr size_t SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1; - __shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][MX_CHUNK_DIM_X]; - - constexpr size_t ALIGNMENT = 128; - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - uintptr_t dshmem = (base_shmem_ptr + ALIGNMENT - 1) & ~(static_cast(ALIGNMENT - 1)); - - constexpr size_t buff_elems = MX_BUFF_DIM_Y * MX_BUFF_DIM_X; - constexpr size_t buff_size_aligned_in = - ((buff_elems * sizeof(IType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; - constexpr size_t buff_size_aligned_out = - ((buff_elems * sizeof(OType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; - - const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t in_mem = in_act_mem + in_gate_mem; - - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); - const size_t out_mem = out_act_mem + out_gate_mem; - - IType *in_grad_sh = reinterpret_cast(dshmem); - IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); - IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); - - OType *out_act_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem); - OType *out_gate_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); - - OType *out_act_colwise_sh = out_act_rowwise_sh; - OType *out_gate_colwise_sh = out_gate_rowwise_sh; - - if constexpr (ROWWISE_SCALING && COLWISE_SCALING) { - out_act_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); - out_gate_colwise_sh = - reinterpret_cast(dshmem + grad_mem + in_mem + out_mem + out_act_mem); - } - - IType *cached_act_sh = in_act_sh; - IType *cached_gate_sh = in_gate_sh; - - constexpr uint32_t in_data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); - constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); - -#pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const size_t stage_offset_Y = stage * MX_BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + stage_offset_Y; - const size_t global_offset_X = block_offset_X; - - // TDM load - if constexpr (IS_DGATED) { - tdm::copy_2d_to_shared(in_grad_sh, grad_ptr, - global_offset_X, global_offset_Y, - MX_BUFF_DIM_X, MX_BUFF_DIM_Y, - cols, rows, grad_stride, in_data_sz); - tdm::copy_2d_to_shared_x2( - in_act_sh, input_act_ptr, global_offset_X, global_offset_Y, - in_gate_sh, input_gate_ptr, global_offset_X, global_offset_Y, - MX_BUFF_DIM_X, MX_BUFF_DIM_Y, - cols, rows, input_act_stride, in_data_sz); - } else { - tdm::copy_2d_to_shared_x2( - in_act_sh, input_act_ptr, global_offset_X, global_offset_Y, - in_gate_sh, input_gate_ptr, global_offset_X, global_offset_Y, - MX_BUFF_DIM_X, MX_BUFF_DIM_Y, - cols, rows, input_act_stride, in_data_sz); - } - tdm::wait_tensorcnt_0(); - __syncthreads(); - - // ---- Colwise scaling pass ---- - if constexpr (COLWISE_SCALING) { - const size_t shmem_offset_base_colwise = tid_Y_colwise * MX_BUFF_DIM_X + tid_X_colwise; - float thread_amax_act = 0.0f; - float thread_amax_gate = 0.0f; - float after_act_colwise[MX_BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE]; - float after_gate_colwise[MX_BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE]; - -#pragma unroll - for (int i = 0; i < MX_SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { - const size_t shmem_offset_colwise = - shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * MX_BUFF_DIM_X; - - float act_elt = static_cast(in_act_sh[shmem_offset_colwise]); - float gate_elt = static_cast(in_gate_sh[shmem_offset_colwise]); - float after_act_elt; - float after_gate_elt; - - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); - const float x = act_elt; - float act_x; - float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = nv_sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; - } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); - } - after_act_elt = dact_x * grad_elt * gate_elt; - after_gate_elt = act_x * grad_elt; - } else { - after_act_elt = ActOP(act_elt, {}) * gate_elt; - } - - if constexpr (!std::is_same_v) { - 
after_act_elt = static_cast(static_cast(after_act_elt)); - if constexpr (IS_DGATED) { - after_gate_elt = static_cast(static_cast(after_gate_elt)); - } - } - - after_act_colwise[i] = after_act_elt; - if constexpr (IS_DGATED) { - after_gate_colwise[i] = after_gate_elt; - } - - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(after_act_elt); - if constexpr (IS_DGATED) { - cached_gate_sh[shmem_offset_colwise] = static_cast(after_gate_elt); - } - } - - const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); - const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); - - if (!out_of_bounds) { - thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); - if constexpr (IS_DGATED) { - thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); - } - } - } - - if constexpr (ONLY_COLWISE_SCALING) { - if (tid_Y_colwise > 0) { - subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_act; - } - __syncthreads(); - if (tid_Y_colwise == 0) { -#pragma unroll - for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) { - const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise]; - __builtin_assume(thread_amax_act >= 0); - __builtin_assume(other_thread_amax >= 0); - thread_amax_act = fmaxf(thread_amax_act, other_thread_amax); - } - subamax_colwise_buff[0][tid_X_colwise] = thread_amax_act; - } - __syncthreads(); - thread_amax_act = subamax_colwise_buff[0][tid_X_colwise]; - - if constexpr (IS_DGATED) { - __syncthreads(); - if (tid_Y_colwise > 0) { - subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_gate; - } - __syncthreads(); - if (tid_Y_colwise == 0) { -#pragma unroll - for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) { - const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise]; - __builtin_assume(thread_amax_gate >= 0); - __builtin_assume(other_thread_amax >= 0); - thread_amax_gate = fmaxf(thread_amax_gate, other_thread_amax); - } - subamax_colwise_buff[0][tid_X_colwise] = thread_amax_gate; - } - __syncthreads(); - thread_amax_gate = subamax_colwise_buff[0][tid_X_colwise]; - } - } - - const e8m0_t biased_exponent_act = - ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); - - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; - const size_t global_scales_offset_X = scales_offset_X_colwise; - const size_t scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows; - const bool out_of_bounds_colwise_flag = row_out_of_bounds_colwise || col_out_of_bounds_colwise; - - if (tid_Y_colwise == 0 && (!out_of_bounds_colwise_flag)) { - scales_colwise[scale_idx] = biased_exponent_act; - } - - float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); - float block_scale_inverse_gate; - - if constexpr (IS_DGATED) { - const e8m0_t biased_exponent_gate = - ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); - const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; - if (tid_Y_colwise == 0 && (!out_of_bounds_colwise_flag)) { - scales_colwise[scale_idx_gate] = biased_exponent_gate; - } - block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate); - } - -#pragma unroll - for (int i = 0; i < MX_SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { - const size_t shmem_offset_elt = - shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * MX_BUFF_DIM_X; - if constexpr (IS_DGATED) { - 
out_act_colwise_sh[shmem_offset_elt] = - static_cast(block_scale_inverse_act * after_act_colwise[i]); - out_gate_colwise_sh[shmem_offset_elt] = - static_cast(block_scale_inverse_gate * after_gate_colwise[i]); - } else { - const float scaled_out_act = block_scale_inverse_act * after_act_colwise[i]; - out_act_colwise_sh[shmem_offset_elt] = static_cast(scaled_out_act); - } - } - } - - // ---- Rowwise scaling pass ---- - if constexpr (ROWWISE_SCALING) { - const size_t shmem_offset_base_rowwise = thread_offset_Y_rowwise * MX_BUFF_DIM_X; - - float thread_amax_act = 0.0f; - float thread_amax_gate = 0.0f; - - Vec in_cached_act[MX_WAVES]; - Vec in_cached_gate[MX_WAVES]; - - float after_act_rowwise[MX_SCALE_DIM_X]; - float after_gate_rowwise[MX_SCALE_DIM_X]; - - if constexpr (IS_CACHED_ACT_OP) { - __syncthreads(); -#pragma unroll - for (int w = 0; w < MX_WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * MX_PACK_SIZE) % MX_SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - - in_cached_act[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - if constexpr (IS_DGATED) { - in_cached_gate[w].load_from(&cached_gate_sh[shmem_offset_rowwise]); - } - if (!out_of_bounds) { -#pragma unroll - for (int e = 0; e < MX_PACK_SIZE; ++e) { - thread_amax_act = fmaxf(thread_amax_act, - fabsf(static_cast(in_cached_act[w].data.elt[e]))); - if constexpr (IS_DGATED) { - thread_amax_gate = fmaxf(thread_amax_gate, - fabsf(static_cast(in_cached_gate[w].data.elt[e]))); - } - } - } - } - } else { -#pragma unroll - for (int w = 0; w < MX_WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * MX_PACK_SIZE) % MX_SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - Vec in_grad; - Vec in_act; - Vec in_gate; - - in_act.load_from(&in_act_sh[shmem_offset_rowwise]); - in_gate.load_from(&in_gate_sh[shmem_offset_rowwise]); - if constexpr (IS_DGATED) { - in_grad.load_from(&in_grad_sh[shmem_offset_rowwise]); - } - -#pragma unroll - for (int e = 0; e < MX_PACK_SIZE; ++e) { - const int j = w * MX_PACK_SIZE + e; - - float act_elt = static_cast(in_act.data.elt[e]); - float gate_elt = static_cast(in_gate.data.elt[e]); - float after_act_elt; - float after_gate_elt; - - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad.data.elt[e]); - const float x = act_elt; - float act_x; - float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = nv_sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; - } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); - } - after_act_elt = dact_x * grad_elt * gate_elt; - after_gate_elt = act_x * grad_elt; - after_act_rowwise[j] = after_act_elt; - after_gate_rowwise[j] = after_gate_elt; - } else { - after_act_elt = ActOP(act_elt, {}) * gate_elt; - after_act_rowwise[j] = after_act_elt; - } - - if constexpr (!std::is_same_v) { - after_act_elt = static_cast(static_cast(after_act_elt)); - if constexpr (IS_DGATED) { - after_gate_elt = static_cast(static_cast(after_gate_elt)); - } - } - - const bool 
row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - if (!out_of_bounds) { - thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); - if constexpr (IS_DGATED) { - thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); - } - } - } - } - } - - // Compute E8M0 scaling factor - const e8m0_t biased_exponent_act = - ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); - const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const size_t stage_scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows; - const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise; - if (!out_of_bounds_rowwise) { - scales_rowwise[scale_idx] = biased_exponent_act; - } - - const float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); - - float block_scale_inverse_gate; - if constexpr (IS_DGATED) { - const e8m0_t biased_exponent_gate = - ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); - const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; - if (!out_of_bounds_rowwise) { - scales_rowwise[scale_idx_gate] = biased_exponent_gate; - } - block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate); - } - - // Scale elements -#pragma unroll - for (int w = 0; w < MX_WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * MX_PACK_SIZE) % MX_SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - -#pragma unroll - for (int e = 0; e < MX_PACK_SIZE; ++e) { - const int j = w * MX_PACK_SIZE + e; - float in_act_val; - - if constexpr (IS_CACHED_ACT_OP) { - in_act_val = static_cast(in_cached_act[w].data.elt[e]); - } else { - in_act_val = after_act_rowwise[j]; - } - out_act_rowwise_sh[shmem_offset_rowwise + e] = - static_cast(block_scale_inverse_act * in_act_val); - - if constexpr (IS_DGATED) { - float in_gate_val; - if constexpr (IS_CACHED_ACT_OP) { - in_gate_val = static_cast(in_cached_gate[w].data.elt[e]); - } else { - in_gate_val = after_gate_rowwise[j]; - } - out_gate_rowwise_sh[shmem_offset_rowwise + e] = - static_cast(block_scale_inverse_gate * in_gate_val); - } - } - } - } - - __syncthreads(); - - // TDM store - if constexpr (ROWWISE_SCALING) { - tdm::store_2d_to_global(out_act_rowwise_sh, output_act_rowwise_ptr, - global_offset_X, global_offset_Y, - MX_BUFF_DIM_X, MX_BUFF_DIM_Y, - cols, rows, output_stride, out_data_sz); - if constexpr (IS_DGATED) { - tdm::store_2d_to_global(out_gate_rowwise_sh, output_gate_rowwise_ptr, - global_offset_X, global_offset_Y, - MX_BUFF_DIM_X, MX_BUFF_DIM_Y, - cols, rows, output_stride, out_data_sz); - } - } - if constexpr (COLWISE_SCALING) { - tdm::store_2d_to_global(out_act_colwise_sh, output_act_colwise_ptr, - global_offset_X, global_offset_Y, - MX_BUFF_DIM_X, MX_BUFF_DIM_Y, - cols, rows, output_stride, out_data_sz); - if constexpr (IS_DGATED) { - tdm::store_2d_to_global(out_gate_colwise_sh, output_gate_colwise_ptr, - global_offset_X, global_offset_Y, - MX_BUFF_DIM_X, MX_BUFF_DIM_Y, - cols, rows, output_stride, out_data_sz); - } - } - tdm::wait_tensorcnt_0(); 
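// For reference, a minimal host-side sketch of the E8M0 block-scaling math that the
// rowwise/colwise passes above rely on (float_to_e8m0 / exp2f_rcp). This is an
// illustration under stated assumptions, not the ptx.cuh implementation: it assumes
// the scale amax/max_norm is rounded up to the next power of two and ignores the
// amax == 0 edge case. float_to_e8m0_ref, exp2f_rcp_ref and the 448.0f e4m3
// max-norm are illustrative names/values only.
#include <cmath>
#include <cstdint>

using e8m0_ref_t = uint8_t;

inline e8m0_ref_t float_to_e8m0_ref(float scale) {
  // Biased exponent of the smallest power of two that is >= scale.
  const int unbiased = static_cast<int>(std::ceil(std::log2(scale)));
  return static_cast<e8m0_ref_t>(unbiased + 127);
}

inline float exp2f_rcp_ref(e8m0_ref_t biased_exponent) {
  // Inverse of the decoded scale 2^(biased_exponent - 127).
  return std::exp2f(-(static_cast<int>(biased_exponent) - 127));
}

// Usage sketch for one 32-element MXFP8 block:
//   e8m0_ref_t e    = float_to_e8m0_ref(block_amax * (1.0f / 448.0f));
//   float scale_inv = exp2f_rcp_ref(e);   // multiply every block element by this
// so that each scaled element fits the FP8 range before the narrowing cast, and the
// stored exponent e reconstructs the per-block scale on dequantization.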
- __syncthreads(); - } -#endif // defined(__gfx1250__) -} - -} // namespace nv_mxfp8 -} // namespace nv_upstream_tdm - -// --------------------------------------------------------------------------- -// NV upstream TDM launcher: FP8 gated -// --------------------------------------------------------------------------- -template -void cast_fp8_gated_nv_upstream_tdm(const Tensor &grad, const Tensor &gated_input, - Tensor *output, cudaStream_t stream) { - using namespace nv_upstream_tdm; - - if (output->has_data()) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); - - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - - const size_t blocks_Y = DIVUP(rows, NV_CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, NV_CHUNK_DIM_X); - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block_dim(NV_THREADS_PER_CHUNK); - const dim3 grid_dim(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - gated_input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - - const IType *grad_ptr = IS_DGATED - ? reinterpret_cast(grad.data.dptr) : nullptr; - const IType *input_act_ptr = reinterpret_cast(gated_input.data.dptr); - const IType *input_gate_ptr = reinterpret_cast(gated_input.data.dptr) + cols; - OType *output_act_ptr = reinterpret_cast(output->data.dptr); - OType *output_gate_ptr = IS_DGATED - ? reinterpret_cast(output->data.dptr) + cols : nullptr; - - constexpr size_t ALIGNMENT = 128; - constexpr size_t buff_elems = NV_SHMEM_DIM_Y * NV_SHMEM_DIM_X; - const size_t buff_size_aligned_in = - ((buff_elems * sizeof(IType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; - const size_t buff_size_aligned_out = - ((buff_elems * sizeof(OType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; - const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + - (out_act_mem + out_gate_mem) + ALIGNMENT; - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_fp8_gated_kernel_tdm, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - cast_fp8_gated_kernel_tdm - <<>>( - grad_ptr, input_act_ptr, input_gate_ptr, - output_act_ptr, output_gate_ptr, - amax_ptr, scale_inv_ptr, scale_ptr, - rows, cols, - cols * 2, cols * 2, cols, output_cols); - NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) -} - -// --------------------------------------------------------------------------- -// NV upstream TDM launcher: MXFP8 gated -// --------------------------------------------------------------------------- -template -void cast_mxfp8_gated_nv_upstream_tdm(const Tensor &grad, const Tensor &gated_input, - Tensor *output, cudaStream_t stream) { - using namespace nv_upstream_tdm; - using namespace nv_upstream_tdm::nv_mxfp8; - - const bool USE_ROWWISE_SCALING = output->has_data(); - const bool USE_COLWISE_SCALING = output->has_columnwise_data(); - - if (USE_ROWWISE_SCALING) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - if (USE_COLWISE_SCALING) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - - const size_t blocks_Y = DIVUP(rows, MX_CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, MX_CHUNK_DIM_X); - - const size_t THREADS_PER_CHUNK = USE_COLWISE_SCALING - ? MX_THREADS_PER_CHUNK_COLWISE : MX_THREADS_PER_CHUNK_NON_COLWISE; - - const dim3 grid(blocks_X, blocks_Y); - const dim3 block_size(THREADS_PER_CHUNK); - - size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; - size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1; - - e8m0_t *const scales_rowwise_ptr = - USE_ROWWISE_SCALING ? reinterpret_cast(output->scale_inv.dptr) : nullptr; - e8m0_t *const scales_colwise_ptr = - USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - gated_input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - - const IType *grad_ptr = IS_DGATED - ? reinterpret_cast(grad.data.dptr) : nullptr; - const IType *input_act_ptr = reinterpret_cast(gated_input.data.dptr); - const IType *input_gate_ptr = reinterpret_cast(gated_input.data.dptr) + cols; - OType *output_act_rowwise = USE_ROWWISE_SCALING - ? reinterpret_cast(output->data.dptr) : nullptr; - OType *output_gate_rowwise = USE_ROWWISE_SCALING - ? reinterpret_cast(output->data.dptr) + cols : nullptr; - OType *output_act_colwise = USE_COLWISE_SCALING - ? reinterpret_cast(output->columnwise_data.dptr) : nullptr; - OType *output_gate_colwise = USE_COLWISE_SCALING - ? 
reinterpret_cast(output->columnwise_data.dptr) + cols : nullptr; - - constexpr size_t ALIGNMENT = 128; - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; - - constexpr size_t buff_elems = MX_BUFF_DIM_Y * MX_BUFF_DIM_X; - const size_t input_buff_size = (buff_elems * input_type_bit_size) / 8; - const size_t output_buff_size = (buff_elems * output_type_bit_size) / 8; - const size_t buff_size_aligned_in = - ((input_buff_size + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; - const size_t buff_size_aligned_out = - ((output_buff_size + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; - - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; - - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); - size_t out_mem = out_act_mem + out_gate_mem; - if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } - - const size_t shmem_size = in_mem + out_mem + ALIGNMENT; - - auto launch_kernel = [&](auto rowwise_tag, auto colwise_tag, auto threads_tag) { - constexpr bool RW = decltype(rowwise_tag)::value; - constexpr bool CW = decltype(colwise_tag)::value; - constexpr size_t TPC = decltype(threads_tag)::value; - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_gated_kernel_tdm, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - cast_mxfp8_gated_kernel_tdm - <<>>( - grad_ptr, input_act_ptr, input_gate_ptr, - output_act_rowwise, output_gate_rowwise, - output_act_colwise, output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise, - cols * 2, cols * 2, cols, output_cols); - NVTE_CHECK_CUDA(cudaGetLastError()); - }; - - if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) { - launch_kernel(std::true_type{}, std::false_type{}, - std::integral_constant{}); - } else if ((!USE_ROWWISE_SCALING) && USE_COLWISE_SCALING) { - launch_kernel(std::false_type{}, std::true_type{}, - std::integral_constant{}); - } else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { - launch_kernel(std::true_type{}, std::true_type{}, - std::integral_constant{}); - } - ); // NOLINT(*) - ); // NOLINT(*) -} - -#endif // __HIP_PLATFORM_AMD__ - -template -void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, - cudaStream_t stream) { - checkCuDriverContext(stream); - - const bool USE_ROWWISE_SCALING = output->has_data(); - const bool USE_COLWISE_SCALING = output->has_columnwise_data(); - - if (USE_ROWWISE_SCALING) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - if (USE_COLWISE_SCALING) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - -#ifndef __HIP_PLATFORM_AMD__ - ScalingType scaling_type; - if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) { - scaling_type = ScalingType::ROWWISE; - } else if ((!USE_ROWWISE_SCALING) && USE_COLWISE_SCALING) { - scaling_type = ScalingType::COLWISE; - } else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { - scaling_type = ScalingType::BIDIMENSIONAL; - } -#endif - - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 
2 : 1) * cols; - -#ifdef __HIP_PLATFORM_AMD__ - constexpr size_t TMA_SHMEM_ALIGNMENT = ALIGNMENT_SIZE; - - constexpr size_t BUFF_DIM_Y = BUFFER_DIM_Y; - constexpr size_t BUFF_DIM_X = BUFFER_DIM_X; - constexpr size_t BUFFS_NUM = BUFFERS_NUM; - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); -#else - - constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; - constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; - constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; + constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; + constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; + constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; const size_t blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X); @@ -2009,26 +1429,101 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; #ifdef __HIP_PLATFORM_AMD__ - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (USE_COLWISE_SCALING ? 32 : 1), SCALE_DIM_Y, + // Check env var: NVTE_USE_NV_UPSTREAM_FLOW=1 selects NV upstream (TDM) kernel path + static const bool use_nv_upstream_mx = [] { + const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); + return env != nullptr && env[0] == '1' && env[1] == '\0'; + }(); + if (use_nv_upstream_mx) { + // NV upstream flow with TDM — uses mxfp8_kernel::cast_mxfp8_gated_kernel + constexpr size_t NV_THREADS_PER_CHUNK_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_COLWISE; + constexpr size_t NV_THREADS_PER_CHUNK_NON_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_NON_COLWISE; + + // Recompute shmem size with NV upstream constants + constexpr size_t NV_BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; + constexpr size_t NV_BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; + constexpr size_t NV_BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; + constexpr size_t NV_ALIGNMENT = 128; + + const size_t nv_buff_elems_total = NV_BUFFS_NUM * NV_BUFF_DIM_Y * NV_BUFF_DIM_X; + const size_t nv_input_buff_size = (nv_buff_elems_total * input_type_bit_size) / 8; + const size_t nv_output_buff_size = (nv_buff_elems_total * output_type_bit_size) / 8; + const size_t nv_buff_size_aligned_in = + ((nv_input_buff_size + NV_ALIGNMENT - 1) / NV_ALIGNMENT) * NV_ALIGNMENT; + const size_t nv_buff_size_aligned_out = + ((nv_output_buff_size + NV_ALIGNMENT - 1) / NV_ALIGNMENT) * NV_ALIGNMENT; + const size_t nv_grad_mem = (IS_DGATED ? nv_buff_size_aligned_in : 0); + const size_t nv_in_act_mem = nv_buff_size_aligned_in; + const size_t nv_in_gate_mem = nv_buff_size_aligned_in; + const size_t nv_in_mem = nv_grad_mem + nv_in_act_mem + nv_in_gate_mem; + const size_t nv_out_act_mem = nv_buff_size_aligned_out; + const size_t nv_out_gate_mem = (IS_DGATED ? nv_buff_size_aligned_out : 0); + size_t nv_out_mem = nv_out_act_mem + nv_out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { nv_out_mem *= 2; } + const size_t nv_shmem_size = nv_in_mem + nv_out_mem + NV_ALIGNMENT; + + const size_t nv_blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y); + const size_t nv_blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X); + const size_t NV_THREADS_PER_CHUNK = USE_COLWISE_SCALING + ? 
NV_THREADS_PER_CHUNK_COLWISE : NV_THREADS_PER_CHUNK_NON_COLWISE; + const dim3 nv_grid(nv_blocks_X, nv_blocks_Y); + const dim3 nv_block(NV_THREADS_PER_CHUNK); + + auto nv_launch = [&](auto rowwise_tag, auto colwise_tag, auto threads_tag) { + constexpr bool RW = decltype(rowwise_tag)::value; + constexpr bool CW = decltype(colwise_tag)::value; + constexpr size_t TPC = decltype(threads_tag)::value; + + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, nv_shmem_size)); + + mxfp8_kernel::cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise, + cols * 2, output_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) { + nv_launch(std::true_type{}, std::false_type{}, + std::integral_constant{}); + } else if ((!USE_ROWWISE_SCALING) && USE_COLWISE_SCALING) { + nv_launch(std::false_type{}, std::true_type{}, + std::integral_constant{}); + } else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + nv_launch(std::true_type{}, std::true_type{}, + std::integral_constant{}); + } + } else { + // ROCm flow kernel (default on AMD) TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (USE_ROWWISE_SCALING ? 32 : 1), SCALE_DIM_X, - TRANSFORMER_ENGINE_SWITCH_CONDITION(!(cols % (32 * sizeof(IType))), IS_ALIGNED, { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - }))); // NOLINT(*) + (USE_COLWISE_SCALING ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (USE_ROWWISE_SCALING ? 
32 : 1), SCALE_DIM_X, + TRANSFORMER_ENGINE_SWITCH_CONDITION(!(cols % (32 * sizeof(IType))), IS_ALIGNED, { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + }))); // NOLINT(*) + } #else switch (scaling_type) { case ScalingType::ROWWISE: @@ -2045,7 +1540,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + scale_stride_colwise, cols * 2, output_cols); NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::COLWISE: @@ -2062,7 +1557,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + scale_stride_colwise, cols * 2, output_cols); NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::BIDIMENSIONAL: @@ -2079,7 +1574,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + scale_stride_colwise, cols * 2, output_cols); NVTE_CHECK_CUDA(cudaGetLastError()); break; } @@ -2202,24 +1697,8 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && cols % 32 == 0; if (is_delayed_tensor_scaling(output->scaling_mode)) { -#ifdef __HIP_PLATFORM_AMD__ - // Check env var: NVTE_USE_NV_UPSTREAM_FLOW=1 selects NV upstream TDM kernel path - static const bool use_nv_upstream = [] { - const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); - return env != nullptr && env[0] == '1' && env[1] == '\0'; - }(); - if (use_nv_upstream && use_tma_kernels) { - cast_fp8_gated_nv_upstream_tdm( - grad, gated_input, output, stream); - } else { - if constexpr (IS_DGATED) { - cast_dgated(grad, gated_input, output, stream); - } else { - cast_gated(gated_input, output, stream); - } - } -#else if (use_tma_kernels) { + // cast_fp8_gated handles both NVIDIA (TMA) and AMD (TDM) internally via #ifdef cast_fp8_gated(grad, gated_input, output, stream); } else { if constexpr (IS_DGATED) { @@ -2228,30 +1707,14 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cast_gated(gated_input, output, stream); } } -#endif } else if (is_mxfp_scaling(output->scaling_mode)) { -#ifdef __HIP_PLATFORM_AMD__ - static const bool use_nv_upstream_mx = [] { - const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); - return env != nullptr && env[0] == '1' && env[1] == '\0'; - }(); - if (use_nv_upstream_mx && use_tma_kernels) { - cast_mxfp8_gated_nv_upstream_tdm( - grad, gated_input, output, stream); - } 
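// Standalone sketch of the tag-dispatch idiom the launchers above use to turn the
// runtime USE_ROWWISE_SCALING / USE_COLWISE_SCALING flags into compile-time template
// parameters via std::integral_constant tags and a generic lambda. dispatch_sketch,
// run_kernel_sketch and the thread counts are illustrative names only, not part of
// the patch.
#include <cstddef>
#include <cstdio>
#include <type_traits>

template <bool ROWWISE, bool COLWISE, std::size_t THREADS>
void run_kernel_sketch() {
  std::printf("ROWWISE=%d COLWISE=%d THREADS=%zu\n", ROWWISE, COLWISE, THREADS);
}

void dispatch_sketch(bool use_rowwise, bool use_colwise) {
  auto launch = [&](auto rowwise_tag, auto colwise_tag, auto threads_tag) {
    constexpr bool RW = decltype(rowwise_tag)::value;
    constexpr bool CW = decltype(colwise_tag)::value;
    constexpr std::size_t TPC = decltype(threads_tag)::value;
    run_kernel_sketch<RW, CW, TPC>();  // template arguments are now compile-time constants
  };
  if (use_rowwise && !use_colwise) {
    launch(std::true_type{}, std::false_type{}, std::integral_constant<std::size_t, 128>{});
  } else if (!use_rowwise && use_colwise) {
    launch(std::false_type{}, std::true_type{}, std::integral_constant<std::size_t, 512>{});
  } else if (use_rowwise && use_colwise) {
    launch(std::true_type{}, std::true_type{}, std::integral_constant<std::size_t, 512>{});
  }
}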
else if (use_tma_kernels) { - cast_mxfp8_gated(grad, gated_input, output, stream); - } else { - NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ", - "by 32, got input of shape ", gated_input.data.shape); - } -#else if (use_tma_kernels) { + // cast_mxfp8_gated handles both NVIDIA (TMA) and AMD (TDM) internally via #ifdef cast_mxfp8_gated(grad, gated_input, output, stream); } else { NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ", "by 32, got input of shape ", gated_input.data.shape); } -#endif } else { NVTE_ERROR("Not supported scaling mode"); } diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index b7c4cf837..63ef28cec 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -31,11 +31,11 @@ #include "transformer_engine/transformer_engine.h" #ifdef __HIP_PLATFORM_AMD__ #include "rocm_cast_kernels.cuh" +#include "tdm.cuh" #endif namespace transformer_engine { -#ifndef __HIP_PLATFORM_AMD__ namespace mxfp8_kernel { constexpr size_t SCALE_DIM_Y = 32; @@ -53,22 +53,36 @@ constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 template + bool COLWISE_SCALING, +#ifdef __HIP_PLATFORM_AMD__ + size_t SCALE_DIM_Y_TMPL, size_t SCALE_DIM_X_TMPL, bool IS_ALIGNED, +#endif + size_t CHUNK_DIM_Y, size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK> __global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + cast_mxfp8_2D_kernel( +#ifdef __HIP_PLATFORM_AMD__ + const IType *__restrict__ input_ptr, + const IType *__restrict__ act_input_ptr, + OType *__restrict__ output_rowwise_ptr, + OType *__restrict__ output_colwise_ptr, +#else + const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_act_input, const __grid_constant__ CUtensorMap tensor_map_output_rowwise, const __grid_constant__ CUtensorMap tensor_map_output_colwise, +#endif e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, const float *noop, float *const dbias_workspace, float *const amax_ptr, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; +#ifndef __HIP_PLATFORM_AMD__ using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; +#endif if constexpr (NO_ACTIVATIONS) { if (noop != nullptr && noop[0] == 1.0f) { @@ -120,12 +134,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; +#ifdef __HIP_PLATFORM_AMD__ + constexpr size_t MX_SHMEM_ALIGNMENT = 128; +#else + constexpr size_t MX_SHMEM_ALIGNMENT = TMA_SHMEM_ALIGNMENT; +#endif + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), MX_SHMEM_ALIGNMENT); constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), 
MX_SHMEM_ALIGNMENT); constexpr size_t elt_input_mem = buff_size_aligned_in; constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); @@ -135,10 +155,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) extern __shared__ char dynamic_shmem[]; uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + // Manually align dynamic SHMEM per TMA/TDM requirements using padding + uintptr_t dshmem = (base_shmem_ptr + MX_SHMEM_ALIGNMENT - 1) & + ~(static_cast(MX_SHMEM_ALIGNMENT - 1)); // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned IType *in_sh = reinterpret_cast(dshmem); @@ -147,7 +166,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) OType *out_colwise_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer +#ifndef __HIP_PLATFORM_AMD__ constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; +#endif const bool is_master_thread = (threadIdx.x == 0); @@ -162,6 +183,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float block_amax = 0.0f; +#ifndef __HIP_PLATFORM_AMD__ // Initialize shared memory barrier with the number of threads participating in the barrier. #pragma nv_diag_suppress static_var_with_dynamic_init __shared__ alignas(8) uint64_t mbar[STAGES]; @@ -178,6 +200,25 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, &mbar[0], is_master_thread); } +#else // __HIP_PLATFORM_AMD__ — TDM prefetch + constexpr uint32_t mx_in_data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); + constexpr uint32_t mx_out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + + // Prefetch first stage + if constexpr (IS_DACT) { + tdm::copy_2d_to_shared_x2( + &in_sh[0], input_ptr, block_offset_X, block_offset_Y, + &act_in_sh[0], act_input_ptr, block_offset_X, block_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, cols, rows, cols, mx_in_data_sz); + } else { + tdm::copy_2d_to_shared(&in_sh[0], input_ptr, + block_offset_X, block_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, + cols, rows, cols, mx_in_data_sz); + } + tdm::wait_tensorcnt_0(); + __syncthreads(); +#endif // __HIP_PLATFORM_AMD__ #pragma unroll for (int stage = 0; stage < STAGES; ++stage) { @@ -186,6 +227,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t stage_offset_Y = stage * BUFF_DIM_Y; if (next_stage < STAGES) { +#ifndef __HIP_PLATFORM_AMD__ // Wait for TMA transfer to have finished reading shared memory. // I.e. 
the buffer is ready to be written to ptx::cp_async_bulk_wait_group_read<1>(); @@ -204,12 +246,37 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); } +#else // __HIP_PLATFORM_AMD__ — TDM prefetch next stage + { + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_DIM; + if constexpr (IS_DACT) { + tdm::copy_2d_to_shared_x2( + &in_sh[next_buff_offset], input_ptr, global_offset_X, global_offset_Y, + &act_in_sh[next_buff_offset], act_input_ptr, global_offset_X, global_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, cols, rows, cols, mx_in_data_sz); + } else { + tdm::copy_2d_to_shared(&in_sh[next_buff_offset], input_ptr, + global_offset_X, global_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, + cols, rows, cols, mx_in_data_sz); + } + } +#endif // __HIP_PLATFORM_AMD__ } +#ifndef __HIP_PLATFORM_AMD__ ptx::fence_proxy_async_shared_cta(); // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], parity); +#else + tdm::wait_tensorcnt_0(); + __syncthreads(); +#endif float thread_amax = 0.0f; if constexpr (COLWISE_SCALING) { @@ -220,6 +287,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 1. Read/Compute elements. Find MXFP8-block AMAX if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { +#ifdef __HIP_PLATFORM_AMD__ + // Scalar fallback for __hmax/__habs (not available on AMD) +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax = fmaxf(thread_amax, fabsf(static_cast(in_colwise_IType[i]))); + } +#else IType thread_amax_f16 = static_cast(0.0f); #pragma unroll for (int i = 0; i < BUFF_DIM_Y; ++i) { @@ -228,6 +304,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); } thread_amax = static_cast(thread_amax_f16); +#endif } else { #pragma unroll for (int i = 0; i < BUFF_DIM_Y; ++i) { @@ -278,7 +355,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) scales_colwise[scale_idx] = biased_exponent; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); +#ifndef __HIP_PLATFORM_AMD__ const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; +#endif // 3. Scale elements #pragma unroll @@ -303,11 +382,29 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float in_compute_rowwise[SCALE_DIM_X]; Vec in_cached[WAVES]; +#ifndef __HIP_PLATFORM_AMD__ // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY Vec in_IType[WAVES]; +#endif // 1. Read/Compute elements. 
Find MXFP8-block AMAX if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { +#ifdef __HIP_PLATFORM_AMD__ + // Scalar fallback for abs_max_2x (PTX intrinsic not available on AMD) +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + Vec in_vec; + in_vec.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + in_cached[w].data.elt[e] = in_vec.data.elt[e]; + thread_amax = fmaxf(thread_amax, fabsf(static_cast(in_vec.data.elt[e]))); + } + } +#else IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll for (int w = 0; w < WAVES; ++w) { @@ -323,10 +420,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); +#endif } else if constexpr (IS_CACHED_ACT_OP) { // ensures that all writes to cache made in the section above are visible to all threads __syncthreads(); +#ifndef __HIP_PLATFORM_AMD__ IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#endif #pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; @@ -339,9 +439,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Load cached elements in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries if (!out_of_bounds) { +#ifdef __HIP_PLATFORM_AMD__ + // Scalar fallback for abs_max_2x (PTX intrinsic not available on AMD) +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(static_cast(in_cached[w].data.elt[e]))); + } +#else if constexpr (std::is_same_v) { #pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { @@ -355,12 +460,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); } } +#endif } } +#ifndef __HIP_PLATFORM_AMD__ if constexpr (!std::is_same_v) { thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); } +#endif } else { #pragma unroll for (int w = 0; w < WAVES; ++w) { @@ -422,11 +530,30 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) scales_rowwise[scale_idx] = biased_exponent; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); +#ifndef __HIP_PLATFORM_AMD__ const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; +#endif // 3. 
Scale elements #pragma unroll for (int w = 0; w < WAVES; ++w) { +#ifdef __HIP_PLATFORM_AMD__ + // Scalar fallback for mul_cvt_2x (PTX intrinsic not available on AMD) + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + float value; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + value = static_cast(in_cached[w].data.elt[e]); + } else if constexpr (IS_CACHED_ACT_OP) { + value = static_cast(in_cached[w].data.elt[e]); + } else { + const int j = w * PACK_SIZE + e; + value = in_compute_rowwise[j]; + } + out.data.elt[e] = static_cast(value * block_scale_inverse); + } +#else Vec out; #pragma unroll for (int e = 0; e < PACK_SIZE / 2; ++e) { @@ -444,6 +571,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); } +#endif const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; @@ -455,6 +583,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __builtin_assume(thread_amax >= 0); block_amax = fmaxf(block_amax, thread_amax); +#ifndef __HIP_PLATFORM_AMD__ // Wait for shared memory writes to be visible to TMA engine. ptx::fence_proxy_async_shared_cta(); __syncthreads(); @@ -480,9 +609,32 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Create a "bulk async-group" out of the previous bulk copy operation. ptx::cp_async_bulk_commit_group(); } +#else // __HIP_PLATFORM_AMD__ — TDM store + __syncthreads(); + { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t buff_offset = buff * BUFF_DIM; + + if constexpr (ROWWISE_SCALING) { + tdm::store_2d_to_global(&out_rowwise_sh[buff_offset], output_rowwise_ptr, + global_offset_X, global_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, cols, rows, cols, mx_out_data_sz); + } + if constexpr (COLWISE_SCALING) { + tdm::store_2d_to_global(&out_colwise_sh[buff_offset], output_colwise_ptr, + global_offset_X, global_offset_Y, + BUFF_DIM_X, BUFF_DIM_Y, cols, rows, cols, mx_out_data_sz); + } + tdm::wait_tensorcnt_0(); + __syncthreads(); + } +#endif // __HIP_PLATFORM_AMD__ } +#ifndef __HIP_PLATFORM_AMD__ parity ^= 1; +#endif if constexpr (IS_DBIAS) { float thread_partial_dbias = 0.0f; @@ -539,8 +691,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) atomicMaxFloat(amax_ptr, block_amax); } +#ifndef __HIP_PLATFORM_AMD__ destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#else + tdm::wait_tensorcnt_0(); +#endif +#endif // #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) } } // namespace mxfp8_kernel @@ -563,13 +719,20 @@ static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); template __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) - cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + cast_fp8_2D_kernel( +#ifdef __HIP_PLATFORM_AMD__ + const IType *__restrict__ input_ptr, + const IType *__restrict__ act_input_ptr, + OType *__restrict__ output_ptr, +#else + const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_act_input, const __grid_constant__ CUtensorMap tensor_map_output, +#endif float *const dbias_workspace, float *const amax_ptr, float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, const size_t cols) { -#if (defined __CUDA_ARCH__) 
&& (__CUDA_ARCH__ >= 1000) +#if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; @@ -591,17 +754,32 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned +#ifdef __HIP_PLATFORM_AMD__ + alignas(128) __shared__ + IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + alignas(128) __shared__ + IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + alignas(128) __shared__ + OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; +#else __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; +#endif +#ifndef __HIP_PLATFORM_AMD__ constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; +#endif const bool is_master_thread = (threadIdx.x == 0); + const size_t chunk_offset_Y = block_offset_Y; + const size_t chunk_offset_X = block_offset_X; + +#ifndef __HIP_PLATFORM_AMD__ // Initialize shared memory barrier with the number of threads participating in the barrier. #pragma nv_diag_suppress static_var_with_dynamic_init __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; @@ -610,9 +788,6 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) int parity = 0; - const size_t chunk_offset_Y = block_offset_Y; - const size_t chunk_offset_X = block_offset_X; - #pragma unroll for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; @@ -628,6 +803,25 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) is_master_thread); } } +#else // __HIP_PLATFORM_AMD__ — TDM prefetch + constexpr uint32_t fp8_in_data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); + constexpr uint32_t fp8_out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + + // Prefetch first buffer + if constexpr (IS_DACT) { + tdm::copy_2d_to_shared_x2( + &in_sh[0][0][0], input_ptr, chunk_offset_X, chunk_offset_Y, + &act_in_sh[0][0][0], act_input_ptr, chunk_offset_X, chunk_offset_Y, + FP8_SHMEM_DIM_X, FP8_SHMEM_DIM_Y, cols, rows, cols, fp8_in_data_sz); + } else { + tdm::copy_2d_to_shared(&in_sh[0][0][0], input_ptr, + chunk_offset_X, chunk_offset_Y, + FP8_SHMEM_DIM_X, FP8_SHMEM_DIM_Y, + cols, rows, cols, fp8_in_data_sz); + } + tdm::wait_tensorcnt_0(); + __syncthreads(); +#endif // __HIP_PLATFORM_AMD__ #pragma unroll for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { @@ -635,6 +829,7 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; if (next_iter < FP8_ITERATIONS) { +#ifndef __HIP_PLATFORM_AMD__ const size_t next_buff = next_iter % FP8_BUFFERS_NUM; const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; const size_t chunk_it_offset_x = chunk_offset_X; @@ -647,10 +842,33 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, 
&mbar[next_iter], is_master_thread); } +#else // __HIP_PLATFORM_AMD__ — TDM prefetch next iteration + { + const size_t next_buff = next_iter % FP8_BUFFERS_NUM; + const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + tdm::copy_2d_to_shared_x2( + &in_sh[next_buff][0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, + &act_in_sh[next_buff][0][0], act_input_ptr, chunk_it_offset_x, chunk_it_offset_y, + FP8_SHMEM_DIM_X, FP8_SHMEM_DIM_Y, cols, rows, cols, fp8_in_data_sz); + } else { + tdm::copy_2d_to_shared(&in_sh[next_buff][0][0], input_ptr, + chunk_it_offset_x, chunk_it_offset_y, + FP8_SHMEM_DIM_X, FP8_SHMEM_DIM_Y, + cols, rows, cols, fp8_in_data_sz); + } + } +#endif // __HIP_PLATFORM_AMD__ } +#ifndef __HIP_PLATFORM_AMD__ // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[iter], parity); +#else + tdm::wait_tensorcnt_0(); + __syncthreads(); +#endif #pragma unroll for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { @@ -688,6 +906,7 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); } +#ifndef __HIP_PLATFORM_AMD__ // Wait for shared memory writes to be visible to TMA engine. ptx::fence_proxy_async_shared_cta(); __syncthreads(); @@ -707,11 +926,27 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) // Wait for TMA transfer to have finished reading shared memory. ptx::cp_async_bulk_wait_group_read(); } +#else // __HIP_PLATFORM_AMD__ — TDM store + __syncthreads(); + { + const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + tdm::store_2d_to_global(&out_sh[buff][0][0], output_ptr, + chunk_it_offset_x, chunk_it_offset_y, + FP8_SHMEM_DIM_X, FP8_SHMEM_DIM_Y, + cols, rows, cols, fp8_out_data_sz); + tdm::wait_tensorcnt_0(); + __syncthreads(); + } +#endif // __HIP_PLATFORM_AMD__ } + +#ifndef __HIP_PLATFORM_AMD__ ptx::cp_async_bulk_wait_group_read<0>(); __syncthreads(); parity ^= 1; +#endif if constexpr (IS_DBIAS) { const size_t dbias_offset_X = my_column; @@ -736,10 +971,16 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) reciprocal(scale_inv_ptr, scale); } +#ifndef __HIP_PLATFORM_AMD__ destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#else + tdm::wait_tensorcnt_0(); +#endif +#endif // #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) } +#ifndef __HIP_PLATFORM_AMD__ +// 1D FP8 kernel uses 1D TMA — no TDM equivalent, NV-only constexpr size_t CHUNKS_PER_BLOCK = 128; constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; @@ -907,6 +1148,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, } #ifndef __HIP_PLATFORM_AMD__ +// 1D FP8 kernel uses 1D TMA — no TDM equivalent, NV-only template static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { const size_t N = product(input.data.shape); @@ -938,11 +1180,14 @@ static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); } +#endif // #ifndef __HIP_PLATFORM_AMD__ (cast_fp8_1D) template void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { +#ifndef __HIP_PLATFORM_AMD__ checkCuDriverContext(stream); +#endif 
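// Minimal sketch of the double-buffered prefetch pipeline both the TMA and TDM paths
// above follow: prefetch buffer 0 before the loop, then on iteration i issue the load
// for iteration i+1 before waiting on and consuming buffer i. async_load_sketch,
// async_store_sketch, wait_all_sketch and compute_sketch are hypothetical stand-ins
// for the real primitives (copy_2d_to_shared / tdm::store_2d_to_global /
// tdm::wait_tensorcnt_0 / the cast math), used here only to show the control flow.
__device__ inline void async_load_sketch(float *dst, const float *src, int iter) {}
__device__ inline void async_store_sketch(float *dst, const float *src, int iter) {}
__device__ inline void wait_all_sketch() {}       // block until no transfers are in flight
__device__ inline void compute_sketch(float *buf) {}

constexpr int PIPE_ITERS = 8;
constexpr int PIPE_BUFFERS = 2;

__global__ void pipeline_sketch(const float *in, float *out) {
  __shared__ float buf[PIPE_BUFFERS][256];

  async_load_sketch(buf[0], in, /*iter=*/0);       // prefetch the first buffer
  for (int iter = 0; iter < PIPE_ITERS; ++iter) {
    const int cur = iter % PIPE_BUFFERS;
    if (iter + 1 < PIPE_ITERS) {
      async_load_sketch(buf[(iter + 1) % PIPE_BUFFERS], in, iter + 1);  // prefetch next
    }
    // Waiting for *all* outstanding transfers also drains the prefetch that was just
    // issued, so this simple variant serializes load and compute; a fully overlapped
    // kernel instead waits until at most one transfer remains in flight so the next
    // buffer keeps loading while compute_sketch() runs.
    wait_all_sketch();
    __syncthreads();
    compute_sketch(buf[cur]);
    __syncthreads();
    async_store_sketch(out, buf[cur], iter);
    wait_all_sketch();                              // store must retire before buffer reuse
    __syncthreads();
  }
}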
const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); @@ -981,6 +1226,14 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->data.dtype, OType, +#ifdef __HIP_PLATFORM_AMD__ + cast_fp8_2D_kernel + <<>>( + reinterpret_cast(input.data.dptr), + (IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr, + reinterpret_cast(output->data.dptr), + workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols); +#else alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_act_input{}; alignas(64) CUtensorMap tensor_map_output{}; @@ -1000,6 +1253,7 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols); +#endif NVTE_CHECK_CUDA(cudaGetLastError()); if constexpr (IS_DBIAS) { @@ -1007,7 +1261,6 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T }); // NOLINT(*) ); // NOLINT(*) } -#endif // #ifndef __HIP_PLATFORM_AMD__ template @@ -1107,18 +1360,75 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, TRANSFORMER_ENGINE_SWITCH_CONDITION( !(cols % (32 * sizeof(IType))), IS_ALIGNED, - cast_mxfp8_2D_kernel - <<>>( - reinterpret_cast(input.data.dptr), - (IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr, - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->columnwise_data.dptr), - scales_rowwise_ptr, scales_colwise_ptr, - reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - ))); // NOLINT(*) + { + const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); + if (env && std::string(env) == "1") { + // NV upstream kernel with TDM + constexpr bool NV_ROWWISE = (SCALE_DIM_X > 1); + constexpr bool NV_COLWISE = (SCALE_DIM_Y > 1); + constexpr size_t NV_CAST_DBIAS_ONLY_Y = (IS_DBIAS && (!IS_DACT) && (!IS_ACT)) ? 128 : 64; + constexpr size_t NV_CAST_DBIAS_ONLY_X = NV_CAST_DBIAS_ONLY_Y; + constexpr size_t NV_CAST_DBIAS_ONLY_T = NV_CAST_DBIAS_ONLY_Y; + + constexpr size_t NV_THREADS_X = NV_CAST_DBIAS_ONLY_X / mxfp8_kernel::SCALE_DIM_X; + constexpr size_t NV_THREADS_Y = NV_CAST_DBIAS_ONLY_T / NV_THREADS_X; + constexpr size_t NV_BUFF_DIM_Y = NV_THREADS_Y; + constexpr size_t NV_BUFF_DIM_X = NV_CAST_DBIAS_ONLY_X; + + constexpr size_t NV_SHMEM_ALIGNMENT = 128; + constexpr size_t nv_buff_elems = NV_BUFF_DIM_Y * NV_BUFF_DIM_X; + constexpr size_t nv_buff_elems_total = mxfp8_kernel::BUFFS_NUM * nv_buff_elems; + constexpr size_t nv_input_type_bit_size = TypeInfo::size; + constexpr size_t nv_output_type_bit_size = TypeInfo::size; + constexpr size_t nv_input_buff_size = (nv_buff_elems_total * nv_input_type_bit_size) / 8; + constexpr size_t nv_output_buff_size = (nv_buff_elems_total * nv_output_type_bit_size) / 8; + constexpr size_t nv_buff_size_aligned_in = + DIVUP_TO_MULTIPLE(nv_input_buff_size, NV_SHMEM_ALIGNMENT); + constexpr size_t nv_buff_size_aligned_out = + DIVUP_TO_MULTIPLE(nv_output_buff_size, NV_SHMEM_ALIGNMENT); + + constexpr size_t nv_elt_input_mem = nv_buff_size_aligned_in; + constexpr size_t nv_act_input_mem = (IS_DACT ? nv_buff_size_aligned_in : 0); + constexpr size_t nv_in_mem = nv_elt_input_mem + nv_act_input_mem; + + const size_t nv_out_rowwise_mem = (use_rowwise_scaling ? 
nv_buff_size_aligned_out : 0); + const size_t nv_out_colwise_mem = (use_colwise_scaling ? nv_buff_size_aligned_out : 0); + const size_t nv_out_mem = nv_out_rowwise_mem + nv_out_colwise_mem; + + const size_t nv_dshmem_size = nv_in_mem + nv_out_mem + NV_SHMEM_ALIGNMENT; + + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, nv_dshmem_size)); + + mxfp8_kernel::cast_mxfp8_2D_kernel + <<>>( + reinterpret_cast(input.data.dptr), + (IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + scales_rowwise_ptr, scales_colwise_ptr, + reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + } else { + // Default ROCm flow + cast_mxfp8_2D_kernel + <<>>( + reinterpret_cast(input.data.dptr), + (IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + scales_rowwise_ptr, scales_colwise_ptr, + reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + } + NVTE_CHECK_CUDA(cudaGetLastError()); + }))); // NOLINT(*) #else // #ifdef __HIP_PLATFORM_AMD__ alignas(64) CUtensorMap tensor_map_input{}; diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index aaeb169b1..c046d7fea 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -33,20 +33,26 @@ #include "transformer_engine/transpose.h" #ifdef __HIP_PLATFORM_AMD__ #include "rocm_dequantize_kernels.cuh" +#include "tdm.cuh" #endif namespace transformer_engine { namespace dequantization { -#ifndef __HIP_PLATFORM_AMD__ template __global__ void __launch_bounds__(THREADS_PER_CHUNK) - dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + dequantize_mxfp8_kernel( +#ifdef __HIP_PLATFORM_AMD__ + const IType *__restrict__ input_ptr, + OType *__restrict__ output_ptr, +#else + const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, +#endif const e8m0_t *const scales_ptr, const size_t rows, const size_t cols, const size_t scales_stride) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 @@ -75,15 +81,23 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; // const int thread_offset_X_colwise = tid_colwise_X; - // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned +#ifdef __HIP_PLATFORM_AMD__ + alignas(128) __shared__ IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + alignas(128) __shared__ OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; +#else __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; +#endif constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM; +#ifndef __HIP_PLATFORM_AMD__ constexpr int transaction_size = 
shmem_buff_size; +#endif const bool is_master_thread = (threadIdx.x == 0); +#ifndef __HIP_PLATFORM_AMD__ // Initialize shared memory barrier with the number of threads participating in the barrier. #pragma nv_diag_suppress static_var_with_dynamic_init __shared__ alignas(8) uint64_t mbar[ITERATIONS]; @@ -118,12 +132,25 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Other threads just arrive ptx::mbarrier_arrive(&mbar[iteration_zero]); } +#else // __HIP_PLATFORM_AMD__ — TDM prefetch + constexpr uint32_t deq_in_data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); + constexpr uint32_t deq_out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + + // Prefetch first iteration + tdm::copy_2d_to_shared(&in_sh[0][0][0], input_ptr, + chunk_offset_X, chunk_offset_Y, + SHMEM_DIM_X, SHMEM_DIM_Y, + cols, rows, cols, deq_in_data_sz); + tdm::wait_tensorcnt_0(); + __syncthreads(); +#endif // __HIP_PLATFORM_AMD__ #pragma unroll for (int iter = 0; iter < ITERATIONS; ++iter) { const int buff = iter % BUFFERS_NUM; const int next_iter = iter + 1; if (next_iter < ITERATIONS) { +#ifndef __HIP_PLATFORM_AMD__ if (is_master_thread) { const int next_buff = next_iter % BUFFERS_NUM; const int chunk_it_offset_y = chunk_offset_Y + next_iter * BUFFER_DIM_Y; @@ -140,12 +167,28 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Other threads just arrive ptx::mbarrier_arrive(&mbar[next_iter]); } +#else // __HIP_PLATFORM_AMD__ — TDM prefetch next iteration + { + const int next_buff = next_iter % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_iter * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + tdm::copy_2d_to_shared(&in_sh[next_buff][0][0], input_ptr, + chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, + cols, rows, cols, deq_in_data_sz); + } +#endif // __HIP_PLATFORM_AMD__ } +#ifndef __HIP_PLATFORM_AMD__ ptx::fence_proxy_async_shared_cta(); // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[iter], parity); +#else + tdm::wait_tensorcnt_0(); + __syncthreads(); +#endif const int scale_offset_Y = USE_ROWWISE_SCALING ? (scales_rowwise_chunk_offset_Y + iter * BUFFER_DIM_Y + tid_rowwise_Y) @@ -181,6 +224,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } +#ifndef __HIP_PLATFORM_AMD__ // Wait for shared memory writes to be visible to TMA engine. ptx::fence_proxy_async_shared_cta(); __syncthreads(); @@ -200,7 +244,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Wait for TMA transfer to have finished reading shared memory. 
ptx::cp_async_bulk_wait_group_read<1>(); } +#else // __HIP_PLATFORM_AMD__ — TDM store + __syncthreads(); + { + const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + tdm::store_2d_to_global(&out_sh[buff][0][0], output_ptr, + chunk_it_offset_x, chunk_it_offset_y, + SHMEM_DIM_X, SHMEM_DIM_Y, + cols, rows, cols, deq_out_data_sz); + tdm::wait_tensorcnt_0(); + __syncthreads(); + } +#endif // __HIP_PLATFORM_AMD__ } + +#ifndef __HIP_PLATFORM_AMD__ ptx::cp_async_bulk_wait_group_read<0>(); __syncthreads(); @@ -215,9 +274,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::mbarrier_invalid(&mbar[iter]); } } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#else + tdm::wait_tensorcnt_0(); +#endif +#endif // #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) } -#endif // #ifndef __HIP_PLATFORM_AMD__ static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); @@ -312,9 +373,24 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s #ifdef __HIP_PLATFORM_AMD__ TRANSFORMER_ENGINE_SWITCH_CONDITION( !(cols % (32 * sizeof(OType))), IS_ALIGNED, - dequantize_mxfp8_kernel - <<>>(reinterpret_cast(input_data.dptr), reinterpret_cast(output->data.dptr), scales_ptr, - rows, cols, scales_stride);); // NOLINT(*) + { + const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); + if (env && std::string(env) == "1") { + // NV upstream kernel with TDM + dequantization::dequantize_mxfp8_kernel + <<>>( + reinterpret_cast(input_data.dptr), + reinterpret_cast(output->data.dptr), + scales_ptr, rows, cols, scales_stride); + } else { + // Default ROCm flow + dequantize_mxfp8_kernel + <<>>( + reinterpret_cast(input_data.dptr), + reinterpret_cast(output->data.dptr), + scales_ptr, rows, cols, scales_stride); + } + }); // NOLINT(*) #else // #ifdef __HIP_PLATFORM_AMD__ alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_output{}; From 0e2d24d30b5fa1e93dde3f0d91fc1fd169d715ac Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Thu, 23 Apr 2026 14:50:01 +0000 Subject: [PATCH 04/43] [ROCm] address more reviewer comments --- transformer_engine/common/common.h | 3 + .../common/util/cast_gated_kernels.cuh | 155 +++++++----------- .../common/util/cast_kernels.cuh | 109 ++++-------- .../common/util/dequantize_kernels.cuh | 26 +-- transformer_engine/common/util/ptx.cuh | 66 ++++++-- .../common/util/rocm_cast_gated_kernels.cuh | 2 +- .../common/util/rocm_cast_kernels.cuh | 8 +- .../common/util/rocm_dequantize_kernels.cuh | 4 +- 8 files changed, 163 insertions(+), 210 deletions(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 0015e9155..0fc26a05e 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -731,6 +731,9 @@ constexpr size_t scale_tensor_alignment_X_rowwise = 1; constexpr size_t scale_tensor_alignment_Y_rowwise = 1; constexpr size_t scale_tensor_alignment_X_colwise = 1; constexpr size_t scale_tensor_alignment_Y_colwise = 1; + +// Alignment requirements for the Tensor Data Mover (TDM) on gfx1250 +constexpr size_t TDM_SHMEM_ALIGNMENT = 128; // shared memory address alignment #endif inline bool is_aligned_ptr(const void *ptr, size_t alignment) { diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 
80a29b3f2..bb669f8f2 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -77,8 +77,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const __grid_constant__ CUtensorMap tensor_map_output_gate, #endif float *const amax_ptr, float *const scale_inv_ptr, - const float *const scale_ptr, const size_t rows, const size_t cols, - const size_t input_act_stride, const size_t output_stride) { + const float *const scale_ptr, const size_t rows, const size_t cols +#ifdef __HIP_PLATFORM_AMD__ + , const size_t input_act_stride, const size_t output_stride +#endif + ) { #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; @@ -95,8 +98,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) extern __shared__ char dynamic_shmem[]; uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA/TDM requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! #ifdef __HIP_PLATFORM_AMD__ - constexpr size_t SHMEM_ALIGNMENT = 128; + constexpr size_t SHMEM_ALIGNMENT = TDM_SHMEM_ALIGNMENT; #else constexpr size_t SHMEM_ALIGNMENT = TMA_SHMEM_ALIGNMENT; #endif @@ -106,9 +111,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; constexpr size_t buff_size_aligned_in = - ((buff_elems_total * sizeof(IType) + SHMEM_ALIGNMENT - 1) / SHMEM_ALIGNMENT) * SHMEM_ALIGNMENT; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), SHMEM_ALIGNMENT); constexpr size_t buff_size_aligned_out = - ((buff_elems_total * sizeof(OType) + SHMEM_ALIGNMENT - 1) / SHMEM_ALIGNMENT) * SHMEM_ALIGNMENT; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), SHMEM_ALIGNMENT); constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; @@ -121,6 +126,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t in_transaction_size = buff_elems * sizeof(IType); #endif + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned IType *in_grad_sh = reinterpret_cast(dshmem); IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); @@ -179,8 +185,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, input_act_stride, in_data_sz); } - tdm::wait_tensorcnt_0(); - __syncthreads(); #endif // __HIP_PLATFORM_AMD__ #pragma unroll @@ -224,8 +228,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, input_act_stride, in_data_sz); } - // TDM is async — wait for the prefetch of the NEXT buffer in the NEXT iteration. - // For now, we use a simple wait-all pattern; double-buffering optimization can be added later. #endif // __HIP_PLATFORM_AMD__ } @@ -235,9 +237,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[it], parity); #else - // On TDM: wait for prefetch of current buffer (issued in previous iteration or before the loop) - // and for any stores from the previous iteration to complete. - tdm::wait_tensorcnt_0(); + // Wait for current buffer's loads (and any prior stores) to complete, + // but keep the just-issued prefetch for the next buffer alive. 
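+      // (tdm::wait_tensorcnt_N() returns once at most N TDM operations remain outstanding,
+      //  so waiting on the count of the just-issued prefetch ops drains everything older.)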
+ if (next_it < ITERATIONS) { + // Prefetch in flight: IS_DGATED issued 3 ops (1+2), non-dgated issued 2 ops + if constexpr (IS_DGATED) { + tdm::wait_tensorcnt_3(); + } else { + tdm::wait_tensorcnt_2(); + } + } else { + // Last iteration — drain all outstanding TDM ops + tdm::wait_tensorcnt_0(); + } __syncthreads(); #endif @@ -294,23 +306,29 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) ptx::fence_proxy_async_shared_cta(); __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. // Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; const size_t chunk_it_offset_x = chunk_offset_X; + // dGeLU ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, chunk_it_offset_y, reinterpret_cast(out_act_sh_curr)); if constexpr (IS_DGATED) { + // dGate ptx::cp_async_bulk_tensor_2d_shared_to_global( TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, reinterpret_cast(out_gate_sh_curr)); } + // Create a "bulk async-group" out of the previous bulk copy operation. ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. ptx::cp_async_bulk_wait_group_read(); } #else // __HIP_PLATFORM_AMD__ — TDM store @@ -344,18 +362,23 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if (amax_ptr != nullptr) { const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block amax = reduce_max(amax, warp_id); + // Update the global amax if (is_master_thread) { atomicMaxFloat(amax_ptr, amax); } } + // Update scale-inverse if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { reciprocal(scale_inv_ptr, scale); } #ifndef __HIP_PLATFORM_AMD__ - // Destroy the barriers. + // Destroy the barriers. This invalidates the memory region of the barrier. + // If further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. if (is_master_thread) { #pragma unroll for (int it = 0; it < ITERATIONS; ++it) { @@ -416,13 +439,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #endif e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise, - const size_t input_act_stride, const size_t output_stride) { + const size_t scale_stride_colwise +#ifdef __HIP_PLATFORM_AMD__ + , const size_t input_act_stride, const size_t output_stride +#endif + ) { #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) -#ifndef __HIP_PLATFORM_AMD__ using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; -#endif constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; static_assert(STAGES >= 1); @@ -477,8 +501,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) extern __shared__ char dynamic_shmem[]; uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA/TDM requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
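+  // (Illustration only, not the verbatim code below: the padding typically rounds the
+  //  base pointer up to the next aligned address, e.g.
+  //    char *dshmem = dynamic_shmem +
+  //        (MX_SHMEM_ALIGNMENT - base_shmem_ptr % MX_SHMEM_ALIGNMENT) % MX_SHMEM_ALIGNMENT;
+  //  so both the TMA and TDM paths see a 128-byte-aligned buffer start.)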
#ifdef __HIP_PLATFORM_AMD__ - constexpr size_t MX_SHMEM_ALIGNMENT = 128; + constexpr size_t MX_SHMEM_ALIGNMENT = TDM_SHMEM_ALIGNMENT; #else constexpr size_t MX_SHMEM_ALIGNMENT = TMA_SHMEM_ALIGNMENT; #endif @@ -488,9 +514,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; constexpr size_t buff_size_aligned_in = - ((buff_elems_total * sizeof(IType) + MX_SHMEM_ALIGNMENT - 1) / MX_SHMEM_ALIGNMENT) * MX_SHMEM_ALIGNMENT; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), MX_SHMEM_ALIGNMENT); constexpr size_t buff_size_aligned_out = - ((buff_elems_total * sizeof(OType) + MX_SHMEM_ALIGNMENT - 1) / MX_SHMEM_ALIGNMENT) * MX_SHMEM_ALIGNMENT; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), MX_SHMEM_ALIGNMENT); const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); @@ -581,6 +607,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if (next_stage < STAGES) { #ifndef __HIP_PLATFORM_AMD__ // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to ptx::cp_async_bulk_wait_group_read<1>(); const size_t next_buff = next_stage % BUFFS_NUM; @@ -790,13 +817,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t shmem_offset_elt = shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; if constexpr (IS_DGATED) { -#ifdef __HIP_PLATFORM_AMD__ - // Scalar fallback for mul_cvt_2x (PTX intrinsic not available on AMD) - out_act_colwise_sh[shmem_offset_elt] = - static_cast(block_scale_inverse_act * after_act_colwise[i]); - out_gate_colwise_sh[shmem_offset_elt] = - static_cast(block_scale_inverse_gate * after_gate_colwise[i]); -#else OType2 out_pair; ptx::floatx2 in_pair = {after_act_colwise[i], after_gate_colwise[i]}; const ptx::floatx2 block_scale_inverse_2x_pair = {block_scale_inverse_act, @@ -804,7 +824,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::mul_cvt_2x(out_pair, in_pair, block_scale_inverse_2x_pair); out_act_colwise_sh[shmem_offset_elt] = out_pair.x; out_gate_colwise_sh[shmem_offset_elt] = out_pair.y; -#endif } else { const float scaled_out_act = block_scale_inverse_act * after_act_colwise[i]; out_act_colwise_sh[shmem_offset_elt] = static_cast(scaled_out_act); @@ -829,10 +848,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_CACHED_ACT_OP) { // ensures that all writes to cache made in the section above are visible to all threads __syncthreads(); -#ifndef __HIP_PLATFORM_AMD__ IType2 thread_amax_2x_act = {static_cast(0.0f), static_cast(0.0f)}; IType2 thread_amax_2x_gate = {static_cast(0.0f), static_cast(0.0f)}; -#endif #pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; @@ -851,18 +868,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries if (!out_of_bounds) { -#ifdef __HIP_PLATFORM_AMD__ - // Scalar fallback for abs_max_2x (PTX intrinsic not available on AMD) -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax_act = fmaxf(thread_amax_act, - fabsf(static_cast(in_cached_act[w].data.elt[e]))); - if constexpr (IS_DGATED) { - thread_amax_gate = fmaxf(thread_amax_gate, - fabsf(static_cast(in_cached_gate[w].data.elt[e]))); - } - } -#else if constexpr (std::is_same_v) { #pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { @@ -884,19 +889,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } } -#endif // __HIP_PLATFORM_AMD__ } } -#ifndef __HIP_PLATFORM_AMD__ if constexpr (!std::is_same_v) { - thread_amax_act = static_cast( - __hmax(__habs(thread_amax_2x_act.x), __habs(thread_amax_2x_act.y))); + thread_amax_act = fmaxf(fabsf(static_cast(thread_amax_2x_act.x)), + fabsf(static_cast(thread_amax_2x_act.y))); if constexpr (IS_DGATED) { - thread_amax_gate = static_cast( - __hmax(__habs(thread_amax_2x_gate.x), __habs(thread_amax_2x_gate.y))); + thread_amax_gate = fmaxf(fabsf(static_cast(thread_amax_2x_gate.x)), + fabsf(static_cast(thread_amax_2x_gate.y))); } } -#endif } else { #pragma unroll for (int w = 0; w < WAVES; ++w) { @@ -980,15 +982,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } const float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); -#ifndef __HIP_PLATFORM_AMD__ const ptx::floatx2 block_scale_inverse_2x_act = {block_scale_inverse_act, block_scale_inverse_act}; -#endif float block_scale_inverse_gate; -#ifndef __HIP_PLATFORM_AMD__ ptx::floatx2 block_scale_inverse_2x_gate; -#endif if constexpr (IS_DGATED) { const e8m0_t biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); @@ -997,9 +995,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) scales_rowwise[scale_idx_gate] = biased_exponent_gate; } block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate); -#ifndef __HIP_PLATFORM_AMD__ block_scale_inverse_2x_gate = {block_scale_inverse_gate, block_scale_inverse_gate}; -#endif } // 3. 
Scale elements @@ -1009,31 +1005,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; -#ifdef __HIP_PLATFORM_AMD__ - // Scalar fallback for mul_cvt_2x (PTX intrinsic not available on AMD) -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - float in_act_val; - if constexpr (IS_CACHED_ACT_OP) { - in_act_val = static_cast(in_cached_act[w].data.elt[e]); - } else { - in_act_val = after_act_rowwise[w * PACK_SIZE + e]; - } - out_act_rowwise_sh[shmem_offset_rowwise + e] = - static_cast(block_scale_inverse_act * in_act_val); - - if constexpr (IS_DGATED) { - float in_gate_val; - if constexpr (IS_CACHED_ACT_OP) { - in_gate_val = static_cast(in_cached_gate[w].data.elt[e]); - } else { - in_gate_val = after_gate_rowwise[w * PACK_SIZE + e]; - } - out_gate_rowwise_sh[shmem_offset_rowwise + e] = - static_cast(block_scale_inverse_gate * in_gate_val); - } - } -#else Vec out_act; Vec out_gate; #pragma unroll @@ -1070,7 +1041,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_DGATED) { out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]); } -#endif // __HIP_PLATFORM_AMD__ } } @@ -1201,19 +1171,18 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu OType *output_gate_ptr = IS_DGATED ? reinterpret_cast(output->data.dptr) + cols : nullptr; - constexpr size_t ALIGNMENT = 128; constexpr size_t buff_elems_total = nv_flow::BUFFERS_NUM * nv_flow::SHMEM_DIM_Y * nv_flow::SHMEM_DIM_X; const size_t buff_size_aligned_in = - ((buff_elems_total * sizeof(IType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TDM_SHMEM_ALIGNMENT); const size_t buff_size_aligned_out = - ((buff_elems_total * sizeof(OType) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TDM_SHMEM_ALIGNMENT); const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); const size_t in_act_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in; const size_t out_act_mem = buff_size_aligned_out; const size_t out_gate_mem = buff_size_aligned_out; const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + - (out_act_mem + out_gate_mem) + ALIGNMENT; + (out_act_mem + out_gate_mem) + TDM_SHMEM_ALIGNMENT; NVTE_CHECK_CUDA(cudaFuncSetAttribute( nv_flow::cast_fp8_gated_kernel, @@ -1271,7 +1240,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols, cols * 2, output_cols); + cols); NVTE_CHECK_CUDA(cudaGetLastError()); #endif // __HIP_PLATFORM_AMD__ ); // NOLINT(*) @@ -1443,15 +1412,13 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out constexpr size_t NV_BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; constexpr size_t NV_BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; constexpr size_t NV_BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; - constexpr size_t NV_ALIGNMENT = 128; - const size_t nv_buff_elems_total = NV_BUFFS_NUM * NV_BUFF_DIM_Y * NV_BUFF_DIM_X; const size_t nv_input_buff_size = (nv_buff_elems_total * input_type_bit_size) / 8; const size_t nv_output_buff_size = (nv_buff_elems_total * output_type_bit_size) / 8; const size_t nv_buff_size_aligned_in = - ((nv_input_buff_size + NV_ALIGNMENT - 1) / NV_ALIGNMENT) * NV_ALIGNMENT; + DIVUP_TO_MULTIPLE(nv_input_buff_size, TDM_SHMEM_ALIGNMENT); const size_t nv_buff_size_aligned_out = - ((nv_output_buff_size + NV_ALIGNMENT - 1) / NV_ALIGNMENT) * NV_ALIGNMENT; + DIVUP_TO_MULTIPLE(nv_output_buff_size, TDM_SHMEM_ALIGNMENT); const size_t nv_grad_mem = (IS_DGATED ? nv_buff_size_aligned_in : 0); const size_t nv_in_act_mem = nv_buff_size_aligned_in; const size_t nv_in_gate_mem = nv_buff_size_aligned_in; @@ -1460,7 +1427,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t nv_out_gate_mem = (IS_DGATED ? 
nv_buff_size_aligned_out : 0); size_t nv_out_mem = nv_out_act_mem + nv_out_gate_mem; if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { nv_out_mem *= 2; } - const size_t nv_shmem_size = nv_in_mem + nv_out_mem + NV_ALIGNMENT; + const size_t nv_shmem_size = nv_in_mem + nv_out_mem + TDM_SHMEM_ALIGNMENT; const size_t nv_blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y); const size_t nv_blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X); @@ -1540,7 +1507,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, cols * 2, output_cols); + scale_stride_colwise); NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::COLWISE: @@ -1557,7 +1524,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, cols * 2, output_cols); + scale_stride_colwise); NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::BIDIMENSIONAL: @@ -1574,7 +1541,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, cols * 2, output_cols); + scale_stride_colwise); NVTE_CHECK_CUDA(cudaGetLastError()); break; } diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 63ef28cec..92debf9c6 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -79,10 +79,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; -#ifndef __HIP_PLATFORM_AMD__ using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; -#endif if constexpr (NO_ACTIVATIONS) { if (noop != nullptr && noop[0] == 1.0f) { @@ -135,7 +133,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int bank_group = thread_lane / THREADS_PER_BANK; #ifdef __HIP_PLATFORM_AMD__ - constexpr size_t MX_SHMEM_ALIGNMENT = 128; + constexpr size_t MX_SHMEM_ALIGNMENT = TDM_SHMEM_ALIGNMENT; #else constexpr size_t MX_SHMEM_ALIGNMENT = TMA_SHMEM_ALIGNMENT; #endif @@ -216,8 +214,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) BUFF_DIM_X, BUFF_DIM_Y, cols, rows, cols, mx_in_data_sz); } - tdm::wait_tensorcnt_0(); - __syncthreads(); #endif // __HIP_PLATFORM_AMD__ #pragma unroll @@ -274,7 +270,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], parity); #else - tdm::wait_tensorcnt_0(); + // Wait for current buffer's loads (and any prior stores) to complete, + // but keep the just-issued prefetch for the next buffer alive. 
+ if (next_stage < STAGES) { + // Prefetch in flight: IS_DACT issued 2 ops, non-DACT issued 1 op + if constexpr (IS_DACT) { + tdm::wait_tensorcnt_2(); + } else { + tdm::wait_tensorcnt_1(); + } + } else { + // Last iteration — drain all outstanding TDM ops + tdm::wait_tensorcnt_0(); + } __syncthreads(); #endif @@ -287,24 +295,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 1. Read/Compute elements. Find MXFP8-block AMAX if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { -#ifdef __HIP_PLATFORM_AMD__ - // Scalar fallback for __hmax/__habs (not available on AMD) #pragma unroll for (int i = 0; i < BUFF_DIM_Y; ++i) { const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; in_colwise_IType[i] = in_sh[shmem_offset_colwise]; thread_amax = fmaxf(thread_amax, fabsf(static_cast(in_colwise_IType[i]))); } -#else - IType thread_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); - } - thread_amax = static_cast(thread_amax_f16); -#endif } else { #pragma unroll for (int i = 0; i < BUFF_DIM_Y; ++i) { @@ -382,29 +378,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float in_compute_rowwise[SCALE_DIM_X]; Vec in_cached[WAVES]; -#ifndef __HIP_PLATFORM_AMD__ // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY Vec in_IType[WAVES]; -#endif // 1. Read/Compute elements. Find MXFP8-block AMAX if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { -#ifdef __HIP_PLATFORM_AMD__ - // Scalar fallback for abs_max_2x (PTX intrinsic not available on AMD) -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - Vec in_vec; - in_vec.load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - in_cached[w].data.elt[e] = in_vec.data.elt[e]; - thread_amax = fmaxf(thread_amax, fabsf(static_cast(in_vec.data.elt[e]))); - } - } -#else IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll for (int w = 0; w < WAVES; ++w) { @@ -419,14 +397,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); -#endif + fmaxf(fabsf(static_cast(thread_amax_2x.x)), fabsf(static_cast(thread_amax_2x.y))); } else if constexpr (IS_CACHED_ACT_OP) { // ensures that all writes to cache made in the section above are visible to all threads __syncthreads(); -#ifndef __HIP_PLATFORM_AMD__ IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#endif #pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; @@ -440,13 +415,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Load cached elements in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); if (!out_of_bounds) { -#ifdef __HIP_PLATFORM_AMD__ - // Scalar fallback for abs_max_2x (PTX intrinsic not available on AMD) -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(static_cast(in_cached[w].data.elt[e]))); - } -#else if constexpr 
(std::is_same_v) { #pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { @@ -460,15 +428,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); } } -#endif } } -#ifndef __HIP_PLATFORM_AMD__ if constexpr (!std::is_same_v) { thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + fmaxf(fabsf(static_cast(thread_amax_2x.x)), fabsf(static_cast(thread_amax_2x.y))); } -#endif } else { #pragma unroll for (int w = 0; w < WAVES; ++w) { @@ -530,30 +495,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) scales_rowwise[scale_idx] = biased_exponent; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); -#ifndef __HIP_PLATFORM_AMD__ const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; -#endif // 3. Scale elements #pragma unroll for (int w = 0; w < WAVES; ++w) { -#ifdef __HIP_PLATFORM_AMD__ - // Scalar fallback for mul_cvt_2x (PTX intrinsic not available on AMD) - Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - float value; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - value = static_cast(in_cached[w].data.elt[e]); - } else if constexpr (IS_CACHED_ACT_OP) { - value = static_cast(in_cached[w].data.elt[e]); - } else { - const int j = w * PACK_SIZE + e; - value = in_compute_rowwise[j]; - } - out.data.elt[e] = static_cast(value * block_scale_inverse); - } -#else Vec out; #pragma unroll for (int e = 0; e < PACK_SIZE / 2; ++e) { @@ -571,7 +517,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); } -#endif const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; @@ -755,11 +700,11 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned #ifdef __HIP_PLATFORM_AMD__ - alignas(128) __shared__ + alignas(TDM_SHMEM_ALIGNMENT) __shared__ IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - alignas(128) __shared__ + alignas(TDM_SHMEM_ALIGNMENT) __shared__ IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - alignas(128) __shared__ + alignas(TDM_SHMEM_ALIGNMENT) __shared__ OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; #else __shared__ alignas(TMA_SHMEM_ALIGNMENT) @@ -819,8 +764,6 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) FP8_SHMEM_DIM_X, FP8_SHMEM_DIM_Y, cols, rows, cols, fp8_in_data_sz); } - tdm::wait_tensorcnt_0(); - __syncthreads(); #endif // __HIP_PLATFORM_AMD__ #pragma unroll @@ -866,7 +809,17 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[iter], parity); #else - tdm::wait_tensorcnt_0(); + // Wait for current buffer's loads (and any prior stores) to complete, + // but keep the just-issued prefetch for the next buffer alive. 
+ if (next_iter < FP8_ITERATIONS) { + if constexpr (IS_DACT) { + tdm::wait_tensorcnt_2(); + } else { + tdm::wait_tensorcnt_1(); + } + } else { + tdm::wait_tensorcnt_0(); + } __syncthreads(); #endif @@ -1375,7 +1328,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, constexpr size_t NV_BUFF_DIM_Y = NV_THREADS_Y; constexpr size_t NV_BUFF_DIM_X = NV_CAST_DBIAS_ONLY_X; - constexpr size_t NV_SHMEM_ALIGNMENT = 128; + constexpr size_t NV_SHMEM_ALIGNMENT = TDM_SHMEM_ALIGNMENT; constexpr size_t nv_buff_elems = NV_BUFF_DIM_Y * NV_BUFF_DIM_X; constexpr size_t nv_buff_elems_total = mxfp8_kernel::BUFFS_NUM * nv_buff_elems; constexpr size_t nv_input_type_bit_size = TypeInfo::size; diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index c046d7fea..2a01d1c1e 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -83,8 +83,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned #ifdef __HIP_PLATFORM_AMD__ - alignas(128) __shared__ IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; - alignas(128) __shared__ OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + alignas(TDM_SHMEM_ALIGNMENT) __shared__ IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + alignas(TDM_SHMEM_ALIGNMENT) __shared__ OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; #else __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; @@ -141,8 +141,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) chunk_offset_X, chunk_offset_Y, SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, cols, deq_in_data_sz); - tdm::wait_tensorcnt_0(); - __syncthreads(); #endif // __HIP_PLATFORM_AMD__ #pragma unroll @@ -186,7 +184,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[iter], parity); #else - tdm::wait_tensorcnt_0(); + // Wait for current buffer's loads (and any prior stores) to complete, + // but keep the just-issued prefetch for the next buffer alive. 
+ if (next_iter < ITERATIONS) { + tdm::wait_tensorcnt_1(); // 1 prefetch load in flight + } else { + tdm::wait_tensorcnt_0(); // Last iteration — drain all + } __syncthreads(); #endif @@ -339,17 +343,13 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s const size_t unpadded_scales_X_colwise = cols; const size_t scales_Y_rowwise = - DIVUP(unpadded_scales_Y_rowwise, scale_tensor_alignment_Y_rowwise) * - scale_tensor_alignment_Y_rowwise; + DIVUP_TO_MULTIPLE(unpadded_scales_Y_rowwise, scale_tensor_alignment_Y_rowwise); const size_t scales_X_rowwise = - DIVUP(unpadded_scales_X_rowwise, scale_tensor_alignment_X_rowwise) * - scale_tensor_alignment_X_rowwise; + DIVUP_TO_MULTIPLE(unpadded_scales_X_rowwise, scale_tensor_alignment_X_rowwise); const size_t scales_Y_colwise = - DIVUP(unpadded_scales_Y_colwise, scale_tensor_alignment_Y_colwise) * - scale_tensor_alignment_Y_colwise; + DIVUP_TO_MULTIPLE(unpadded_scales_Y_colwise, scale_tensor_alignment_Y_colwise); const size_t scales_X_colwise = - DIVUP(unpadded_scales_X_colwise, scale_tensor_alignment_X_colwise) * - scale_tensor_alignment_X_colwise; + DIVUP_TO_MULTIPLE(unpadded_scales_X_colwise, scale_tensor_alignment_X_colwise); const e8m0_t *const scales_ptr = use_rowwise_scaling ? reinterpret_cast(input.scale_inv.dptr) diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 7c38a337b..84cc75007 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -156,6 +156,24 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { #endif } +template +struct alignas(2 * sizeof(T)) FPx2 { + T x; + T y; +}; + +using floatx2 = FPx2; +using bf16x2 = FPx2; +using fp16x2 = FPx2; +using fp8e4m3x2 = FPx2; +using fp8e5m2x2 = FPx2; + +static_assert(sizeof(floatx2) == 8); +static_assert(sizeof(bf16x2) == 4); +static_assert(sizeof(fp16x2) == 4); +static_assert(sizeof(fp8e4m3x2) == 2); +static_assert(sizeof(fp8e5m2x2) == 2); + #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor @@ -221,24 +239,6 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() { asm volatile("fence.proxy.async.shared::cta;"); } -template -struct alignas(2 * sizeof(T)) FPx2 { - T x; - T y; -}; - -using floatx2 = FPx2; -using bf16x2 = FPx2; -using fp16x2 = FPx2; -using fp8e4m3x2 = FPx2; -using fp8e5m2x2 = FPx2; - -static_assert(sizeof(floatx2) == 8); -static_assert(sizeof(bf16x2) == 4); -static_assert(sizeof(fp16x2) == 4); -static_assert(sizeof(fp8e4m3x2) == 2); -static_assert(sizeof(fp8e5m2x2) == 2); - // SIMD like "Fused" cast + multiplication (x2) __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, const floatx2 &scale) { @@ -376,6 +376,36 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +#ifdef __HIP_PLATFORM_AMD__ +// AMD scalar fallbacks for PTX SIMD intrinsics. +// These provide the same interface as the CUDA PTX versions but use +// scalar float operations, allowing kernel code to avoid #ifdef guards. 
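+// Usage sketch only (mirrors the call sites in cast_kernels.cuh / cast_gated_kernels.cuh;
+// variable names here are illustrative):
+//   FPx2<IType> amax2 = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
+//   ptx::abs_max_2x(amax2, amax2, packed_vals);             // running per-lane max of |x|
+//   float amax = fmaxf(fabsf(static_cast<float>(amax2.x)),
+//                      fabsf(static_cast<float>(amax2.y)));  // reduce the pair to a scalar
+//   FPx2<OType> out2;
+//   ptx::mul_cvt_2x(out2, in2, scale2);                      // scale and cast two elements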
+ +// abs_max_2x: dst = max(|dst|, max(|p1|, |p2|)) per element +template +__device__ __forceinline__ void abs_max_2x(FPx2 &dst, const FPx2 &p1, const FPx2 &p2) { + float ax = fmaxf(fabsf(static_cast(p1.x)), fabsf(static_cast(p2.x))); + float ay = fmaxf(fabsf(static_cast(p1.y)), fabsf(static_cast(p2.y))); + dst.x = static_cast(ax); + dst.y = static_cast(ay); +} + +// mul_cvt_2x: out = (OType)(in * scale) per element +// float input version +template +__device__ __forceinline__ void mul_cvt_2x(OType2T &out, const floatx2 &in, const floatx2 &scale) { + out.x = static_cast(in.x * scale.x); + out.y = static_cast(in.y * scale.y); +} + +// bf16/fp16 input version +template +__device__ __forceinline__ void mul_cvt_2x(OType2T &out, const IType2T &in, const floatx2 &scale) { + out.x = static_cast(static_cast(in.x) * scale.x); + out.y = static_cast(static_cast(in.y) * scale.y); +} +#endif // __HIP_PLATFORM_AMD__ + } // namespace ptx namespace { diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh index cafec34e1..e0df26c1d 100644 --- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh @@ -22,7 +22,7 @@ namespace transformer_engine { namespace gated_kernels { -constexpr size_t ALIGNMENT_SIZE = 128; +constexpr size_t ALIGNMENT_SIZE = TDM_SHMEM_ALIGNMENT; // TODO: Identify optimal chunk/thread size for MI350+ constexpr size_t CHUNK_DIM_Y = 64; constexpr size_t CHUNK_DIM_X = 64; diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index b512c3718..808b5e59d 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -128,10 +128,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) } // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - alignas(128) __shared__ IType in_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - alignas(128) __shared__ IType act_in_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - alignas(128) __shared__ OType out_rowwise_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - alignas(128) __shared__ OType out_colwise_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + alignas(TDM_SHMEM_ALIGNMENT) __shared__ IType in_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + alignas(TDM_SHMEM_ALIGNMENT) __shared__ IType act_in_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + alignas(TDM_SHMEM_ALIGNMENT) __shared__ OType out_rowwise_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + alignas(TDM_SHMEM_ALIGNMENT) __shared__ OType out_colwise_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; float block_amax = 0; diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh index 5aae8ede1..5b541062e 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh @@ -79,8 +79,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // const int thread_offset_X_colwise = tid_colwise_X; // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - alignas(128) __shared__ IType in_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; - alignas(128) __shared__ OType out_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; + alignas(TDM_SHMEM_ALIGNMENT) __shared__ IType in_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; + alignas(TDM_SHMEM_ALIGNMENT) __shared__ OType out_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; 
for (int iter = 0; iter < ITERATIONS; iter++) { const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; From 0a13bbc054dc4e047dee9c92a6725f20a8b733c8 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Thu, 23 Apr 2026 15:09:19 +0000 Subject: [PATCH 05/43] [ROCm] Address TDM review comments: remove extra params, add explanatory comments - Remove input_act_stride/output_stride as kernel params in gated kernels; compute them inside the kernel from cols and IS_DGATED template param - Add comments explaining why TDM does not need in_transaction_size (uses s_wait_tensorcnt counting ops, not mbarrier counting bytes) Co-Authored-By: Claude Opus 4 --- .../common/util/cast_gated_kernels.cuh | 28 ++++++++++++------- .../common/util/cast_kernels.cuh | 2 ++ .../common/util/dequantize_kernels.cuh | 2 ++ 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index bb669f8f2..b6cd3cff5 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -78,12 +78,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #endif float *const amax_ptr, float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, const size_t cols -#ifdef __HIP_PLATFORM_AMD__ - , const size_t input_act_stride, const size_t output_stride -#endif ) { #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) +#ifdef __HIP_PLATFORM_AMD__ + // TDM needs explicit strides. For gated inputs, act and gate are interleaved → stride = 2*cols. + // For outputs, IS_DGATED interleaves dact/dgate → stride = 2*cols; otherwise stride = cols. + const size_t input_act_stride = cols * 2; + const size_t output_stride = IS_DGATED ? cols * 2 : cols; +#endif + const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X; @@ -123,6 +127,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t out_act_mem = buff_size_aligned_out; #ifndef __HIP_PLATFORM_AMD__ + // TMA mbarriers require the expected byte count to know when the async copy is done. + // TDM does not need this — it uses s_wait_tensorcnt which counts outstanding ops, not bytes. constexpr size_t in_transaction_size = buff_elems * sizeof(IType); #endif @@ -440,14 +446,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise -#ifdef __HIP_PLATFORM_AMD__ - , const size_t input_act_stride, const size_t output_stride -#endif ) { #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; +#ifdef __HIP_PLATFORM_AMD__ + // TDM needs explicit strides. For gated inputs, act and gate are interleaved → stride = 2*cols. + // For outputs, IS_DGATED interleaves dact/dgate → stride = 2*cols; otherwise stride = cols. + const size_t input_act_stride = cols * 2; + const size_t output_stride = IS_DGATED ? 
cols * 2 : cols; +#endif + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; static_assert(STAGES >= 1); @@ -1192,8 +1202,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu <<>>( grad_ptr, input_act_ptr, input_gate_ptr, output_act_ptr, output_gate_ptr, - amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, - cols * 2, output_cols); + amax_ptr, scale_inv_ptr, scale_ptr, rows, cols); NVTE_CHECK_CUDA(cudaGetLastError()); #else alignas(64) CUtensorMap tensor_map_grad{}; @@ -1453,8 +1462,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise, - cols * 2, output_cols); + rows, cols, scale_stride_rowwise, scale_stride_colwise); NVTE_CHECK_CUDA(cudaGetLastError()); }; diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 92debf9c6..334333023 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -165,6 +165,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer #ifndef __HIP_PLATFORM_AMD__ + // TMA mbarriers require the expected byte count to know when the async copy is done. + // TDM does not need this — it uses s_wait_tensorcnt which counts outstanding ops, not bytes. constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; #endif diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index 2a01d1c1e..721f05de2 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -92,6 +92,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM; #ifndef __HIP_PLATFORM_AMD__ + // TMA mbarriers require the expected byte count to know when the async copy is done. + // TDM does not need this — it uses s_wait_tensorcnt which counts outstanding ops, not bytes. constexpr int transaction_size = shmem_buff_size; #endif From bfb7199c02f71f62cb46034a9e5889d04c4c7de0 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Thu, 23 Apr 2026 16:11:25 +0000 Subject: [PATCH 06/43] [ROCm] Address remaining review comments and enable TDM flow in CI gtests - Fix `) {` placement to minimize diff in gated kernel signatures - Fix MXFP8 gated kernel: remove unnecessary pre-loop wait, make in-loop wait conditional to preserve double-buffering prefetch - Add comments explaining TDM does not need mbarrier destroy - Add NVTE_USE_NV_UPSTREAM_FLOW=1 ctest run in ci/core.sh to exercise TDM kernel paths for MXFP8 quantize, gated, and dequantize Co-Authored-By: Claude Opus 4 --- ci/core.sh | 5 ++++ .../common/util/cast_gated_kernels.cuh | 24 +++++++++++++------ .../common/util/cast_kernels.cuh | 2 ++ .../common/util/dequantize_kernels.cuh | 1 + 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/ci/core.sh b/ci/core.sh index cf08c2185..ea3515fee 100755 --- a/ci/core.sh +++ b/ci/core.sh @@ -33,6 +33,11 @@ if [ $? -eq 0 ]; then echo ===== Run non GEMM tests ===== ctest --test-dir build -j"$n_parallel_jobs" -V --output-on-failure -E "GEMMTestSuite" test $? 
-eq 0 || test_run_error "non-GEMM" + + echo ===== Run non GEMM tests with NV upstream TDM flow ===== + NVTE_USE_NV_UPSTREAM_FLOW=1 ctest --test-dir build -j"$n_parallel_jobs" -V --output-on-failure \ + -R "FusedCastMXFP8TestSuite|CastMXFP8_GatedActTestSuite|DequantizeMXFP8TestSuite" + test $? -eq 0 || test_run_error "non-GEMM NV upstream flow" fi check_test_filter "gemm" diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index b6cd3cff5..a48ef7508 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -77,8 +77,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const __grid_constant__ CUtensorMap tensor_map_output_gate, #endif float *const amax_ptr, float *const scale_inv_ptr, - const float *const scale_ptr, const size_t rows, const size_t cols - ) { + const float *const scale_ptr, const size_t rows, const size_t cols) { #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) #ifdef __HIP_PLATFORM_AMD__ @@ -385,6 +384,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Destroy the barriers. This invalidates the memory region of the barrier. // If further computations were to take place in the kernel, this allows the // memory location of the shared memory barrier to be reused. + // TDM does not use mbarriers — it uses s_wait_tensorcnt, so no barrier destroy is needed. if (is_master_thread) { #pragma unroll for (int it = 0; it < ITERATIONS; ++it) { @@ -445,8 +445,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #endif e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise - ) { + const size_t scale_stride_colwise) { #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; @@ -604,8 +603,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) BUFF_DIM_X, BUFF_DIM_Y, cols, rows, input_act_stride, mx_in_data_sz); } - tdm::wait_tensorcnt_0(); - __syncthreads(); #endif // __HIP_PLATFORM_AMD__ #pragma unroll @@ -669,7 +666,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], parity); #else - tdm::wait_tensorcnt_0(); + // Wait for current buffer's loads (and any prior stores) to complete, + // but keep the just-issued prefetch for the next buffer alive. + if (next_stage < STAGES) { + // Prefetch in flight: IS_DGATED issued 3 ops (1+2), non-dgated issued 2 ops + if constexpr (IS_DGATED) { + tdm::wait_tensorcnt_3(); + } else { + tdm::wait_tensorcnt_2(); + } + } else { + // Last stage — drain all outstanding TDM ops + tdm::wait_tensorcnt_0(); + } __syncthreads(); #endif @@ -1130,6 +1139,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #ifndef __HIP_PLATFORM_AMD__ parity ^= 1; + // TDM does not use mbarriers — it uses s_wait_tensorcnt, so no barrier destroy is needed. 
destroy_barriers(mbar, is_master_thread); #else tdm::wait_tensorcnt_0(); diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 334333023..6cacfd2ca 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -639,6 +639,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } #ifndef __HIP_PLATFORM_AMD__ + // TDM does not use mbarriers — it uses s_wait_tensorcnt, so no barrier destroy is needed. destroy_barriers(mbar, is_master_thread); #else tdm::wait_tensorcnt_0(); @@ -927,6 +928,7 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) } #ifndef __HIP_PLATFORM_AMD__ + // TDM does not use mbarriers — it uses s_wait_tensorcnt, so no barrier destroy is needed. destroy_barriers(mbar, is_master_thread); #else tdm::wait_tensorcnt_0(); diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index 721f05de2..5cbf08d4f 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -274,6 +274,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Destroy barrier. This invalidates the memory region of the barrier. If // further computations were to take place in the kernel, this allows the // memory location of the shared memory barrier to be reused. + // TDM does not use mbarriers — it uses s_wait_tensorcnt, so no barrier destroy is needed. if (is_master_thread) { #pragma unroll for (int iter = 0; iter < ITERATIONS; ++iter) { From ba4bbb72adccba172439f391f4a2bee22091db26 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Fri, 24 Apr 2026 12:16:29 -0500 Subject: [PATCH 07/43] tdm: clamp tensorDim to avoid uint32_t underflow on OOB prefetch tiles When a double-buffered prefetch tile origin falls past the tensor boundary (non-tile-aligned rows/cols), tensor_h - tile_row and tensor_w - tile_col would underflow as uint32_t to ~4 billion, causing the TDM hardware to attempt a DMA of billions of rows and trigger a GPU page fault. Clamp the remaining extent to 0 when tile_row >= tensor_h or tile_col >= tensor_w. Unlike NV TMA (which encodes full tensor shape in a host-side CUtensorMap and clamps automatically), TDM computes the remaining extent per-call, so the caller must guard against out-of-bounds origins. Co-Authored-By: Claude Sonnet 4 --- transformer_engine/common/util/tdm.cuh | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/util/tdm.cuh b/transformer_engine/common/util/tdm.cuh index d3b5a4a44..b618e9808 100644 --- a/transformer_engine/common/util/tdm.cuh +++ b/transformer_engine/common/util/tdm.cuh @@ -95,9 +95,10 @@ void load_2d_to_lds(const void* global_base, g0.globalAddr(reinterpret_cast(tile_start)); g1.dataSize(data_size); - // tensorDim = remaining extent from tile start to tensor edge (for OOB clamping). - g1.tensorDim0(tensor_w - tile_col); - g1.tensorDim1(tensor_h - tile_row); + // Clamp remaining extent to avoid uint32_t underflow when a prefetch tile origin + // falls past the tensor boundary (e.g. last block in a non-tile-aligned dimension). + g1.tensorDim0(tile_col < tensor_w ? tensor_w - tile_col : 0u); + g1.tensorDim1(tile_row < tensor_h ? 
tensor_h - tile_row : 0u); g1.tileDim0(tile_dim_x); g1.tileDim1(tile_dim_y); g1.tensorDim0Stride(stride_elements); @@ -134,8 +135,8 @@ void store_2d_from_lds(void* global_base, g0.globalAddr(reinterpret_cast(tile_start)); g1.dataSize(data_size); - g1.tensorDim0(tensor_w - tile_col); - g1.tensorDim1(tensor_h - tile_row); + g1.tensorDim0(tile_col < tensor_w ? tensor_w - tile_col : 0u); + g1.tensorDim1(tile_row < tensor_h ? tensor_h - tile_row : 0u); g1.tileDim0(tile_dim_x); g1.tileDim1(tile_dim_y); g1.tensorDim0Stride(stride_elements); From ab77fbf201dba84701bd8bbcc35aebb3b5a74fdf Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Fri, 24 Apr 2026 22:41:44 -0500 Subject: [PATCH 08/43] tdm: add HIPTensorMap descriptor struct; revert TDM from rocm_* kernels Introduce HIPTensorMap/HIPTensorMapOut structs in tdm.cuh as the AMD analog of CUtensorMap. Callers in cast_kernels.cuh and cast_gated_kernels.cuh now construct one descriptor per tensor at kernel entry and pass it to TDM helper calls instead of repeating 6+ raw scalars at every call site. Revert TDM usage in rocm_cast_kernels.cuh, rocm_cast_gated_kernels.cuh, and rocm_dequantize_kernels.cuh back to the original HIP vectorized copy_2d_to_shared / bulk_tensor_2d_shared_to_global path. The rocm_* kernels are the legacy non-TDM path; TDM is used only in the NV-upstream ported kernels (cast_kernels.cuh / cast_gated_kernels.cuh) behind NVTE_USE_NV_UPSTREAM_FLOW. Co-Authored-By: Claude Sonnet 4 --- .../common/util/cast_gated_kernels.cuh | 172 ++++++++++-------- .../common/util/cast_kernels.cuh | 104 ++++++----- .../common/util/rocm_cast_gated_kernels.cuh | 51 ------ .../common/util/rocm_cast_kernels.cuh | 34 ---- .../common/util/rocm_dequantize_kernels.cuh | 22 --- transformer_engine/common/util/tdm.cuh | 130 +++++++++++++ 6 files changed, 288 insertions(+), 225 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index a48ef7508..304eff7cd 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -172,23 +172,40 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr uint32_t in_data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + const tdm::HIPTensorMap tmap_grad{grad_ptr, + static_cast(cols), + static_cast(rows), + static_cast(cols), + SHMEM_DIM_X, SHMEM_DIM_Y, in_data_sz}; + const tdm::HIPTensorMap tmap_act{input_act_ptr, + static_cast(cols), + static_cast(rows), + static_cast(input_act_stride), + SHMEM_DIM_X, SHMEM_DIM_Y, in_data_sz}; + const tdm::HIPTensorMap tmap_gate{input_gate_ptr, + static_cast(cols), + static_cast(rows), + static_cast(input_act_stride), + SHMEM_DIM_X, SHMEM_DIM_Y, in_data_sz}; + const tdm::HIPTensorMapOut tmap_out_act{output_act_ptr, + static_cast(cols), + static_cast(rows), + static_cast(output_stride), + SHMEM_DIM_X, SHMEM_DIM_Y, out_data_sz}; + const tdm::HIPTensorMapOut tmap_out_gate{output_gate_ptr, + static_cast(cols), + static_cast(rows), + static_cast(output_stride), + SHMEM_DIM_X, SHMEM_DIM_Y, out_data_sz}; + // Prefetch data of the first stage if constexpr (IS_DGATED) { - tdm::copy_2d_to_shared(in_grad_sh, grad_ptr, - chunk_offset_X, chunk_offset_Y, - SHMEM_DIM_X, SHMEM_DIM_Y, - cols, rows, cols, in_data_sz); - tdm::copy_2d_to_shared_x2( - in_act_sh, input_act_ptr, chunk_offset_X, chunk_offset_Y, - in_gate_sh, input_gate_ptr, chunk_offset_X, chunk_offset_Y, - 
SHMEM_DIM_X, SHMEM_DIM_Y, - cols, rows, input_act_stride, in_data_sz); + tdm::copy_2d_to_shared(in_grad_sh, tmap_grad, chunk_offset_X, chunk_offset_Y); + tdm::copy_2d_to_shared_x2(in_act_sh, tmap_act, chunk_offset_X, chunk_offset_Y, + in_gate_sh, tmap_gate, chunk_offset_X, chunk_offset_Y); } else { - tdm::copy_2d_to_shared_x2( - in_act_sh, input_act_ptr, chunk_offset_X, chunk_offset_Y, - in_gate_sh, input_gate_ptr, chunk_offset_X, chunk_offset_Y, - SHMEM_DIM_X, SHMEM_DIM_Y, - cols, rows, input_act_stride, in_data_sz); + tdm::copy_2d_to_shared_x2(in_act_sh, tmap_act, chunk_offset_X, chunk_offset_Y, + in_gate_sh, tmap_gate, chunk_offset_X, chunk_offset_Y); } #endif // __HIP_PLATFORM_AMD__ @@ -217,21 +234,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } #else // __HIP_PLATFORM_AMD__ — TDM prefetch if constexpr (IS_DGATED) { - tdm::copy_2d_to_shared(&in_grad_sh[next_buff * buff_elems], grad_ptr, - chunk_it_offset_x, chunk_it_offset_y, - SHMEM_DIM_X, SHMEM_DIM_Y, - cols, rows, cols, in_data_sz); + tdm::copy_2d_to_shared(&in_grad_sh[next_buff * buff_elems], tmap_grad, + chunk_it_offset_x, chunk_it_offset_y); tdm::copy_2d_to_shared_x2( - &in_act_sh[next_buff * buff_elems], input_act_ptr, chunk_it_offset_x, chunk_it_offset_y, - &in_gate_sh[next_buff * buff_elems], input_gate_ptr, chunk_it_offset_x, chunk_it_offset_y, - SHMEM_DIM_X, SHMEM_DIM_Y, - cols, rows, input_act_stride, in_data_sz); + &in_act_sh[next_buff * buff_elems], tmap_act, chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[next_buff * buff_elems], tmap_gate, chunk_it_offset_x, chunk_it_offset_y); } else { tdm::copy_2d_to_shared_x2( - &in_act_sh[next_buff * buff_elems], input_act_ptr, chunk_it_offset_x, chunk_it_offset_y, - &in_gate_sh[next_buff * buff_elems], input_gate_ptr, chunk_it_offset_x, chunk_it_offset_y, - SHMEM_DIM_X, SHMEM_DIM_Y, - cols, rows, input_act_stride, in_data_sz); + &in_act_sh[next_buff * buff_elems], tmap_act, chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[next_buff * buff_elems], tmap_gate, chunk_it_offset_x, chunk_it_offset_y); } #endif // __HIP_PLATFORM_AMD__ } @@ -342,15 +353,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; const size_t chunk_it_offset_x = chunk_offset_X; - tdm::store_2d_to_global(out_act_sh_curr, output_act_ptr, - chunk_it_offset_x, chunk_it_offset_y, - SHMEM_DIM_X, SHMEM_DIM_Y, - cols, rows, output_stride, out_data_sz); + tdm::store_2d_to_global(out_act_sh_curr, tmap_out_act, + chunk_it_offset_x, chunk_it_offset_y); if constexpr (IS_DGATED) { - tdm::store_2d_to_global(out_gate_sh_curr, output_gate_ptr, - chunk_it_offset_x, chunk_it_offset_y, - SHMEM_DIM_X, SHMEM_DIM_Y, - cols, rows, output_stride, out_data_sz); + tdm::store_2d_to_global(out_gate_sh_curr, tmap_out_gate, + chunk_it_offset_x, chunk_it_offset_y); } // TDM stores are async — they will be drained at the top of the next iteration // (or after the loop for the last iteration). 
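+      // Note on the descriptors used above: HIPTensorMap / HIPTensorMapOut are defined in
+      // tdm.cuh (not part of this hunk). Inferred from the brace-init call sites at kernel
+      // entry, each one bundles roughly the following fields; this is an illustrative
+      // sketch, not the verbatim definition:
+      //   struct HIPTensorMap {
+      //     const void *global_base;   // tensor start address in global memory
+      //     uint32_t tensor_w;         // tensor width in elements
+      //     uint32_t tensor_h;         // tensor height in rows
+      //     uint32_t stride_elements;  // row stride in elements
+      //     uint32_t tile_dim_x;       // LDS tile width  (SHMEM_DIM_X / BUFF_DIM_X)
+      //     uint32_t tile_dim_y;       // LDS tile height (SHMEM_DIM_Y / BUFF_DIM_Y)
+      //     uint32_t data_size;        // TDM dataSize encoding of the element width
+      //   };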
@@ -585,23 +592,50 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr uint32_t mx_in_data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); constexpr uint32_t mx_out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + const tdm::HIPTensorMap mx_tmap_grad{grad_ptr, + static_cast(cols), + static_cast(rows), + static_cast(cols), + BUFF_DIM_X, BUFF_DIM_Y, mx_in_data_sz}; + const tdm::HIPTensorMap mx_tmap_act{input_act_ptr, + static_cast(cols), + static_cast(rows), + static_cast(input_act_stride), + BUFF_DIM_X, BUFF_DIM_Y, mx_in_data_sz}; + const tdm::HIPTensorMap mx_tmap_gate{input_gate_ptr, + static_cast(cols), + static_cast(rows), + static_cast(input_act_stride), + BUFF_DIM_X, BUFF_DIM_Y, mx_in_data_sz}; + const tdm::HIPTensorMapOut mx_tmap_out_act_rw{output_act_rowwise_ptr, + static_cast(cols), + static_cast(rows), + static_cast(output_stride), + BUFF_DIM_X, BUFF_DIM_Y, mx_out_data_sz}; + const tdm::HIPTensorMapOut mx_tmap_out_gate_rw{output_gate_rowwise_ptr, + static_cast(cols), + static_cast(rows), + static_cast(output_stride), + BUFF_DIM_X, BUFF_DIM_Y, mx_out_data_sz}; + const tdm::HIPTensorMapOut mx_tmap_out_act_cw{output_act_colwise_ptr, + static_cast(cols), + static_cast(rows), + static_cast(output_stride), + BUFF_DIM_X, BUFF_DIM_Y, mx_out_data_sz}; + const tdm::HIPTensorMapOut mx_tmap_out_gate_cw{output_gate_colwise_ptr, + static_cast(cols), + static_cast(rows), + static_cast(output_stride), + BUFF_DIM_X, BUFF_DIM_Y, mx_out_data_sz}; + // TDM prefetch if constexpr (IS_DGATED) { - tdm::copy_2d_to_shared(&in_grad_sh[0], grad_ptr, - block_offset_X, block_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, - cols, rows, cols, mx_in_data_sz); - tdm::copy_2d_to_shared_x2( - &in_act_sh[0], input_act_ptr, block_offset_X, block_offset_Y, - &in_gate_sh[0], input_gate_ptr, block_offset_X, block_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, - cols, rows, input_act_stride, mx_in_data_sz); + tdm::copy_2d_to_shared(&in_grad_sh[0], mx_tmap_grad, block_offset_X, block_offset_Y); + tdm::copy_2d_to_shared_x2(&in_act_sh[0], mx_tmap_act, block_offset_X, block_offset_Y, + &in_gate_sh[0], mx_tmap_gate, block_offset_X, block_offset_Y); } else { - tdm::copy_2d_to_shared_x2( - &in_act_sh[0], input_act_ptr, block_offset_X, block_offset_Y, - &in_gate_sh[0], input_gate_ptr, block_offset_X, block_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, - cols, rows, input_act_stride, mx_in_data_sz); + tdm::copy_2d_to_shared_x2(&in_act_sh[0], mx_tmap_act, block_offset_X, block_offset_Y, + &in_gate_sh[0], mx_tmap_gate, block_offset_X, block_offset_Y); } #endif // __HIP_PLATFORM_AMD__ @@ -641,21 +675,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t global_offset_X = block_offset_X; const size_t next_buff_offset = next_buff * BUFF_DIM; if constexpr (IS_DGATED) { - tdm::copy_2d_to_shared(&in_grad_sh[next_buff_offset], grad_ptr, - global_offset_X, global_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, - cols, rows, cols, mx_in_data_sz); + tdm::copy_2d_to_shared(&in_grad_sh[next_buff_offset], mx_tmap_grad, + global_offset_X, global_offset_Y); tdm::copy_2d_to_shared_x2( - &in_act_sh[next_buff_offset], input_act_ptr, global_offset_X, global_offset_Y, - &in_gate_sh[next_buff_offset], input_gate_ptr, global_offset_X, global_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, - cols, rows, input_act_stride, mx_in_data_sz); + &in_act_sh[next_buff_offset], mx_tmap_act, global_offset_X, global_offset_Y, + &in_gate_sh[next_buff_offset], mx_tmap_gate, global_offset_X, global_offset_Y); } else { tdm::copy_2d_to_shared_x2( - &in_act_sh[next_buff_offset], 
input_act_ptr, global_offset_X, global_offset_Y, - &in_gate_sh[next_buff_offset], input_gate_ptr, global_offset_X, global_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, - cols, rows, input_act_stride, mx_in_data_sz); + &in_act_sh[next_buff_offset], mx_tmap_act, global_offset_X, global_offset_Y, + &in_gate_sh[next_buff_offset], mx_tmap_gate, global_offset_X, global_offset_Y); } #endif // __HIP_PLATFORM_AMD__ } @@ -1108,27 +1136,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t buff_offset = buff * BUFF_DIM; if constexpr (ROWWISE_SCALING) { - tdm::store_2d_to_global(&out_act_rowwise_sh[buff_offset], output_act_rowwise_ptr, - global_offset_X, global_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, - cols, rows, output_stride, mx_out_data_sz); + tdm::store_2d_to_global(&out_act_rowwise_sh[buff_offset], mx_tmap_out_act_rw, + global_offset_X, global_offset_Y); if constexpr (IS_DGATED) { - tdm::store_2d_to_global(&out_gate_rowwise_sh[buff_offset], output_gate_rowwise_ptr, - global_offset_X, global_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, - cols, rows, output_stride, mx_out_data_sz); + tdm::store_2d_to_global(&out_gate_rowwise_sh[buff_offset], mx_tmap_out_gate_rw, + global_offset_X, global_offset_Y); } } if constexpr (COLWISE_SCALING) { - tdm::store_2d_to_global(&out_act_colwise_sh[buff_offset], output_act_colwise_ptr, - global_offset_X, global_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, - cols, rows, output_stride, mx_out_data_sz); + tdm::store_2d_to_global(&out_act_colwise_sh[buff_offset], mx_tmap_out_act_cw, + global_offset_X, global_offset_Y); if constexpr (IS_DGATED) { - tdm::store_2d_to_global(&out_gate_colwise_sh[buff_offset], output_gate_colwise_ptr, - global_offset_X, global_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, - cols, rows, output_stride, mx_out_data_sz); + tdm::store_2d_to_global(&out_gate_colwise_sh[buff_offset], mx_tmap_out_gate_cw, + global_offset_X, global_offset_Y); } } // TDM stores are async — they will be drained at the top of the next iteration diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 6cacfd2ca..de3baf851 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -204,17 +204,33 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr uint32_t mx_in_data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); constexpr uint32_t mx_out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + const tdm::HIPTensorMap tmap_in{input_ptr, + static_cast(cols), + static_cast(rows), + static_cast(cols), + BUFF_DIM_X, BUFF_DIM_Y, mx_in_data_sz}; + const tdm::HIPTensorMap tmap_act_in{act_input_ptr, + static_cast(cols), + static_cast(rows), + static_cast(cols), + BUFF_DIM_X, BUFF_DIM_Y, mx_in_data_sz}; + const tdm::HIPTensorMapOut tmap_rowwise{output_rowwise_ptr, + static_cast(cols), + static_cast(rows), + static_cast(cols), + BUFF_DIM_X, BUFF_DIM_Y, mx_out_data_sz}; + const tdm::HIPTensorMapOut tmap_colwise{output_colwise_ptr, + static_cast(cols), + static_cast(rows), + static_cast(cols), + BUFF_DIM_X, BUFF_DIM_Y, mx_out_data_sz}; + // Prefetch first stage if constexpr (IS_DACT) { - tdm::copy_2d_to_shared_x2( - &in_sh[0], input_ptr, block_offset_X, block_offset_Y, - &act_in_sh[0], act_input_ptr, block_offset_X, block_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, cols, rows, cols, mx_in_data_sz); + tdm::copy_2d_to_shared_x2(&in_sh[0], tmap_in, block_offset_X, block_offset_Y, + &act_in_sh[0], tmap_act_in, block_offset_X, block_offset_Y); } else { - 
tdm::copy_2d_to_shared(&in_sh[0], input_ptr, - block_offset_X, block_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, - cols, rows, cols, mx_in_data_sz); + tdm::copy_2d_to_shared(&in_sh[0], tmap_in, block_offset_X, block_offset_Y); } #endif // __HIP_PLATFORM_AMD__ @@ -252,15 +268,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t global_offset_X = block_offset_X; const size_t next_buff_offset = next_buff * BUFF_DIM; if constexpr (IS_DACT) { - tdm::copy_2d_to_shared_x2( - &in_sh[next_buff_offset], input_ptr, global_offset_X, global_offset_Y, - &act_in_sh[next_buff_offset], act_input_ptr, global_offset_X, global_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, cols, rows, cols, mx_in_data_sz); + tdm::copy_2d_to_shared_x2(&in_sh[next_buff_offset], tmap_in, + global_offset_X, global_offset_Y, + &act_in_sh[next_buff_offset], tmap_act_in, + global_offset_X, global_offset_Y); } else { - tdm::copy_2d_to_shared(&in_sh[next_buff_offset], input_ptr, - global_offset_X, global_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, - cols, rows, cols, mx_in_data_sz); + tdm::copy_2d_to_shared(&in_sh[next_buff_offset], tmap_in, + global_offset_X, global_offset_Y); } } #endif // __HIP_PLATFORM_AMD__ @@ -564,14 +578,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t buff_offset = buff * BUFF_DIM; if constexpr (ROWWISE_SCALING) { - tdm::store_2d_to_global(&out_rowwise_sh[buff_offset], output_rowwise_ptr, - global_offset_X, global_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, cols, rows, cols, mx_out_data_sz); + tdm::store_2d_to_global(&out_rowwise_sh[buff_offset], tmap_rowwise, + global_offset_X, global_offset_Y); } if constexpr (COLWISE_SCALING) { - tdm::store_2d_to_global(&out_colwise_sh[buff_offset], output_colwise_ptr, - global_offset_X, global_offset_Y, - BUFF_DIM_X, BUFF_DIM_Y, cols, rows, cols, mx_out_data_sz); + tdm::store_2d_to_global(&out_colwise_sh[buff_offset], tmap_colwise, + global_offset_X, global_offset_Y); } tdm::wait_tensorcnt_0(); __syncthreads(); @@ -755,17 +767,29 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) constexpr uint32_t fp8_in_data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); constexpr uint32_t fp8_out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); + const tdm::HIPTensorMap fp8_tmap_in{input_ptr, + static_cast(cols), + static_cast(rows), + static_cast(cols), + FP8_SHMEM_DIM_X, FP8_SHMEM_DIM_Y, fp8_in_data_sz}; + const tdm::HIPTensorMap fp8_tmap_act_in{act_input_ptr, + static_cast(cols), + static_cast(rows), + static_cast(cols), + FP8_SHMEM_DIM_X, FP8_SHMEM_DIM_Y, fp8_in_data_sz}; + const tdm::HIPTensorMapOut fp8_tmap_out{output_ptr, + static_cast(cols), + static_cast(rows), + static_cast(cols), + FP8_SHMEM_DIM_X, FP8_SHMEM_DIM_Y, fp8_out_data_sz}; + // Prefetch first buffer if constexpr (IS_DACT) { - tdm::copy_2d_to_shared_x2( - &in_sh[0][0][0], input_ptr, chunk_offset_X, chunk_offset_Y, - &act_in_sh[0][0][0], act_input_ptr, chunk_offset_X, chunk_offset_Y, - FP8_SHMEM_DIM_X, FP8_SHMEM_DIM_Y, cols, rows, cols, fp8_in_data_sz); + tdm::copy_2d_to_shared_x2(&in_sh[0][0][0], fp8_tmap_in, chunk_offset_X, chunk_offset_Y, + &act_in_sh[0][0][0], fp8_tmap_act_in, + chunk_offset_X, chunk_offset_Y); } else { - tdm::copy_2d_to_shared(&in_sh[0][0][0], input_ptr, - chunk_offset_X, chunk_offset_Y, - FP8_SHMEM_DIM_X, FP8_SHMEM_DIM_Y, - cols, rows, cols, fp8_in_data_sz); + tdm::copy_2d_to_shared(&in_sh[0][0][0], fp8_tmap_in, chunk_offset_X, chunk_offset_Y); } #endif // __HIP_PLATFORM_AMD__ @@ -794,15 +818,13 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) const 
size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; const size_t chunk_it_offset_x = chunk_offset_X; if constexpr (IS_DACT) { - tdm::copy_2d_to_shared_x2( - &in_sh[next_buff][0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, - &act_in_sh[next_buff][0][0], act_input_ptr, chunk_it_offset_x, chunk_it_offset_y, - FP8_SHMEM_DIM_X, FP8_SHMEM_DIM_Y, cols, rows, cols, fp8_in_data_sz); + tdm::copy_2d_to_shared_x2(&in_sh[next_buff][0][0], fp8_tmap_in, + chunk_it_offset_x, chunk_it_offset_y, + &act_in_sh[next_buff][0][0], fp8_tmap_act_in, + chunk_it_offset_x, chunk_it_offset_y); } else { - tdm::copy_2d_to_shared(&in_sh[next_buff][0][0], input_ptr, - chunk_it_offset_x, chunk_it_offset_y, - FP8_SHMEM_DIM_X, FP8_SHMEM_DIM_Y, - cols, rows, cols, fp8_in_data_sz); + tdm::copy_2d_to_shared(&in_sh[next_buff][0][0], fp8_tmap_in, + chunk_it_offset_x, chunk_it_offset_y); } } #endif // __HIP_PLATFORM_AMD__ @@ -887,10 +909,8 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) { const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; const size_t chunk_it_offset_x = chunk_offset_X; - tdm::store_2d_to_global(&out_sh[buff][0][0], output_ptr, - chunk_it_offset_x, chunk_it_offset_y, - FP8_SHMEM_DIM_X, FP8_SHMEM_DIM_Y, - cols, rows, cols, fp8_out_data_sz); + tdm::store_2d_to_global(&out_sh[buff][0][0], fp8_tmap_out, + chunk_it_offset_x, chunk_it_offset_y); tdm::wait_tensorcnt_0(); __syncthreads(); } diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh index e0df26c1d..c1da033e3 100644 --- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh @@ -135,28 +135,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t row_base = chunk_it_offset_y; // Initiate bulk tensor copy -#if defined(__gfx1250__) - { - constexpr uint32_t data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); - if constexpr (IS_DGATED) { - // grad uses stride=cols, act/gate use stride=2*cols -- issue separately - tdm::copy_2d_to_shared( - &in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, - SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, cols, data_sz); - tdm::copy_2d_to_shared_x2( - &in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, - &in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, - SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, 2*cols, data_sz); - } else { - tdm::copy_2d_to_shared_x2( - &in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, - &in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, - SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, 2*cols, data_sz); - } - tdm::wait_tensorcnt_0(); - __syncthreads(); - } -#else if constexpr (IS_DGATED) { copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); @@ -171,7 +149,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); -#endif const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y; @@ -377,33 +354,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); -#if defined(__gfx1250__) - { - constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); - if constexpr (USE_ROWWISE_SCALING) { - tdm::store_2d_to_global(&out_act_rowwise_sh[0], 
output_act_rowwise, - chunk_it_offset_x, chunk_it_offset_y, - SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz); - if constexpr (IS_DGATED) { - tdm::store_2d_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, - chunk_it_offset_x, chunk_it_offset_y, - SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz); - } - } - if constexpr (USE_COLWISE_SCALING) { - tdm::store_2d_to_global(&out_act_colwise_sh[0], output_act_colwise, - chunk_it_offset_x, chunk_it_offset_y, - SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz); - if constexpr (IS_DGATED) { - tdm::store_2d_to_global(&out_gate_colwise_sh[0], output_gate_colwise, - chunk_it_offset_x, chunk_it_offset_y, - SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, output_cols, out_data_sz); - } - } - tdm::wait_tensorcnt_0(); - __syncthreads(); - } -#else if constexpr (USE_ROWWISE_SCALING) { bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); @@ -422,7 +372,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } __syncthreads(); -#endif } } } // namespace gated_kernels diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index 808b5e59d..0b52997a8 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -162,21 +162,6 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; const size_t row_base = chunk_it_offset_y; -#if defined(__gfx1250__) - constexpr uint32_t data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); - if constexpr (IS_DACT) { - tdm::copy_2d_to_shared_x2( - &in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, - &act_in_sh[0][0], act_input_ptr, chunk_it_offset_x, chunk_it_offset_y, - MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, cols, rows, cols, data_sz); - } else { - tdm::copy_2d_to_shared( - &in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, - MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, cols, rows, cols, data_sz); - } - tdm::wait_tensorcnt_0(); - __syncthreads(); -#else if constexpr (IS_DACT) { copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, @@ -186,7 +171,6 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); __syncthreads(); -#endif if constexpr (USE_ROWWISE_SCALING) { Vec in; @@ -329,23 +313,6 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __syncthreads(); -#if defined(__gfx1250__) - constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); - if constexpr (USE_ROWWISE_SCALING) { - tdm::store_2d_to_global(&out_rowwise_sh[0][0], output_rowwise, - chunk_it_offset_x, chunk_it_offset_y, - MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, - cols, rows, cols, out_data_sz); - } - if constexpr (USE_COLWISE_SCALING) { - tdm::store_2d_to_global(&out_colwise_sh[0][0], output_colwise, - chunk_it_offset_x, chunk_it_offset_y, - MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, - cols, rows, cols, out_data_sz); - } - tdm::wait_tensorcnt_0(); - __syncthreads(); -#else if constexpr (USE_ROWWISE_SCALING) { bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, @@ -358,7 +325,6 @@ __global__ void 
__launch_bounds__(MXFP8_THREADS_PER_CHUNK) } __syncthreads(); -#endif } } diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh index 5b541062e..00c81ac08 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh @@ -86,21 +86,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; -#if defined(__gfx1250__) - { - constexpr uint32_t data_sz = tdm::get_data_size_from_bits(sizeof(IType) * 8); - tdm::copy_2d_to_shared(&in_sh[0][0], input_ptr, - chunk_it_offset_x, chunk_it_offset_y, - SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, cols, data_sz); - tdm::wait_tensorcnt_0(); - __syncthreads(); - } -#else copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); -#endif const int scale_offset_Y = USE_ROWWISE_SCALING ? (scales_rowwise_chunk_offset_Y + iter * BUFFER_DIM_Y + tid_rowwise_Y) @@ -138,22 +127,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); -#if defined(__gfx1250__) - { - constexpr uint32_t out_data_sz = tdm::get_data_size_from_bits(sizeof(OType) * 8); - tdm::store_2d_to_global(&out_sh[0][0], output_ptr, - chunk_it_offset_x, chunk_it_offset_y, - SHMEM_DIM_X, SHMEM_DIM_Y, cols, rows, cols, out_data_sz); - tdm::wait_tensorcnt_0(); - __syncthreads(); - } -#else bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); -#endif } } } // namespace dequantization diff --git a/transformer_engine/common/util/tdm.cuh b/transformer_engine/common/util/tdm.cuh index b618e9808..61c5255e4 100644 --- a/transformer_engine/common/util/tdm.cuh +++ b/transformer_engine/common/util/tdm.cuh @@ -55,6 +55,35 @@ __device__ __forceinline__ bool is_tdm_wave() { return (linear_tid < 32); } +// --------------------------------------------------------------------------- +// HIPTensorMap: device-side tensor descriptor (AMD analog of CUtensorMap) +// --------------------------------------------------------------------------- +// On NV, CUtensorMap is built on the host and encodes both full tensor shape +// and tile shape; hardware auto-clamps at boundaries. TDM has no host-side +// descriptor — the device supplies the tile origin, remaining extent, and tile +// dims at each instruction. HIPTensorMap centralises that metadata so callers +// only pass a single descriptor + tile coordinates instead of 6+ scalars. 
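+//
+// Illustrative use (a sketch with a hypothetical 64x32 tile; the real call
+// sites in cast_kernels.cuh build one descriptor per tensor at kernel entry):
+//
+//   const tdm::HIPTensorMap tmap{input_ptr,
+//                                static_cast<uint32_t>(cols),   // tensor_w
+//                                static_cast<uint32_t>(rows),   // tensor_h
+//                                static_cast<uint32_t>(cols),   // stride
+//                                /*tile_dim_x=*/64, /*tile_dim_y=*/32,
+//                                tdm::get_data_size_from_bits(sizeof(IType) * 8)};
+//   tdm::copy_2d_to_shared(&in_sh[0], tmap, tile_x, tile_y);
+//   tdm::wait_tensorcnt_0();  // loads are async; drain before reading in_sh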
+
+struct HIPTensorMap {
+  const void* base_ptr;  // pointer to tensor base (global memory)
+  uint32_t tensor_w;     // full tensor width in elements
+  uint32_t tensor_h;     // full tensor height in elements
+  uint32_t stride;       // row stride in elements (may differ from tensor_w)
+  uint32_t tile_dim_x;   // tile width to transfer per call
+  uint32_t tile_dim_y;   // tile height to transfer per call
+  uint32_t data_size;    // log2(sizeof(element)): 0=1B,1=2B,2=4B,3=8B
+};
+
+struct HIPTensorMapOut {
+  void* base_ptr;
+  uint32_t tensor_w;
+  uint32_t tensor_h;
+  uint32_t stride;
+  uint32_t tile_dim_x;
+  uint32_t tile_dim_y;
+  uint32_t data_size;
+};
+
 // ---------------------------------------------------------------------------
 // Core 2D load: global memory -> LDS
 // ---------------------------------------------------------------------------
@@ -290,6 +319,80 @@ void store_2d_to_global(const void* lds_src,
   }
 }
 
+// ---------------------------------------------------------------------------
+// HIPTensorMap-based overloads (single descriptor + tile coords)
+// ---------------------------------------------------------------------------
+
+__device__ __forceinline__
+void copy_2d_to_shared(void* lds_dst,
+                       const HIPTensorMap& tmap,
+                       uint32_t chunk_x,
+                       uint32_t chunk_y) {
+  copy_2d_to_shared(lds_dst, tmap.base_ptr, chunk_x, chunk_y,
+                    tmap.tile_dim_x, tmap.tile_dim_y,
+                    tmap.tensor_w, tmap.tensor_h,
+                    tmap.stride, tmap.data_size);
+}
+
+__device__ __forceinline__
+void copy_2d_to_shared_x2(void* dst1, const HIPTensorMap& tmap1, uint32_t cx1, uint32_t cy1,
+                          void* dst2, const HIPTensorMap& tmap2, uint32_t cx2, uint32_t cy2) {
+  if (is_tdm_wave()) {
+    uint32_t lds_off1 = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(dst1));
+    load_2d_to_lds(tmap1.base_ptr, lds_off1,
+                   tmap1.tensor_w, tmap1.tensor_h,
+                   tmap1.tile_dim_x, tmap1.tile_dim_y,
+                   tmap1.stride, tmap1.data_size,
+                   cx1, cy1);
+
+    uint32_t lds_off2 = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(dst2));
+    load_2d_to_lds(tmap2.base_ptr, lds_off2,
+                   tmap2.tensor_w, tmap2.tensor_h,
+                   tmap2.tile_dim_x, tmap2.tile_dim_y,
+                   tmap2.stride, tmap2.data_size,
+                   cx2, cy2);
+  }
+}
+
+__device__ __forceinline__
+void copy_2d_to_shared_x3(void* dst1, const HIPTensorMap& tmap1, uint32_t cx1, uint32_t cy1,
+                          void* dst2, const HIPTensorMap& tmap2, uint32_t cx2, uint32_t cy2,
+                          void* dst3, const HIPTensorMap& tmap3, uint32_t cx3, uint32_t cy3) {
+  if (is_tdm_wave()) {
+    uint32_t lds_off1 = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(dst1));
+    load_2d_to_lds(tmap1.base_ptr, lds_off1,
+                   tmap1.tensor_w, tmap1.tensor_h,
+                   tmap1.tile_dim_x, tmap1.tile_dim_y,
+                   tmap1.stride, tmap1.data_size,
+                   cx1, cy1);
+
+    uint32_t lds_off2 = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(dst2));
+    load_2d_to_lds(tmap2.base_ptr, lds_off2,
+                   tmap2.tensor_w, tmap2.tensor_h,
+                   tmap2.tile_dim_x, tmap2.tile_dim_y,
+                   tmap2.stride, tmap2.data_size,
+                   cx2, cy2);
+
+    uint32_t lds_off3 = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(dst3));
+    load_2d_to_lds(tmap3.base_ptr, lds_off3,
+                   tmap3.tensor_w, tmap3.tensor_h,
+                   tmap3.tile_dim_x, tmap3.tile_dim_y,
+                   tmap3.stride, tmap3.data_size,
+                   cx3, cy3);
+  }
+}
+
+__device__
__forceinline__ +void store_2d_to_global(const void* lds_src, + const HIPTensorMapOut& tmap, + uint32_t chunk_x, + uint32_t chunk_y) { + store_2d_to_global(lds_src, tmap.base_ptr, chunk_x, chunk_y, + tmap.tile_dim_x, tmap.tile_dim_y, + tmap.tensor_w, tmap.tensor_h, + tmap.stride, tmap.data_size); +} + #else // !defined(__gfx1250__) // Stubs for non-gfx1250 AMD targets -- these should never be called. @@ -304,17 +410,33 @@ __device__ __forceinline__ constexpr uint32_t get_data_size_from_bits(size_t typ return (type_num_bits <= 8) ? 0 : (type_num_bits <= 16) ? 1 : (type_num_bits <= 32) ? 2 : 3; } +struct HIPTensorMap { + const void* base_ptr; + uint32_t tensor_w, tensor_h, stride, tile_dim_x, tile_dim_y, data_size; +}; +struct HIPTensorMapOut { + void* base_ptr; + uint32_t tensor_w, tensor_h, stride, tile_dim_x, tile_dim_y, data_size; +}; + __device__ __forceinline__ void copy_2d_to_shared(void*, const void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t) {} +__device__ __forceinline__ +void copy_2d_to_shared(void*, const HIPTensorMap&, uint32_t, uint32_t) {} + __device__ __forceinline__ void copy_2d_to_shared_x2(void*, const void*, uint32_t, uint32_t, void*, const void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t) {} +__device__ __forceinline__ +void copy_2d_to_shared_x2(void*, const HIPTensorMap&, uint32_t, uint32_t, + void*, const HIPTensorMap&, uint32_t, uint32_t) {} + __device__ __forceinline__ void copy_2d_to_shared_x3(void*, const void*, uint32_t, uint32_t, void*, const void*, uint32_t, uint32_t, @@ -322,11 +444,19 @@ void copy_2d_to_shared_x3(void*, const void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t) {} +__device__ __forceinline__ +void copy_2d_to_shared_x3(void*, const HIPTensorMap&, uint32_t, uint32_t, + void*, const HIPTensorMap&, uint32_t, uint32_t, + void*, const HIPTensorMap&, uint32_t, uint32_t) {} + __device__ __forceinline__ void store_2d_to_global(const void*, void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t) {} +__device__ __forceinline__ +void store_2d_to_global(const void*, const HIPTensorMapOut&, uint32_t, uint32_t) {} + #endif // defined(__gfx1250__) } // namespace tdm From a0a60fed4bf190727057bd88b7de9e112626ac02 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Fri, 24 Apr 2026 22:49:20 -0500 Subject: [PATCH 09/43] tdm: fully revert rocm_*.cuh to branch-point state Remove the tdm.cuh include, TDM_SHMEM_ALIGNMENT usage, and any whitespace changes introduced in the previous commit, so rocm_cast_kernels.cuh, rocm_cast_gated_kernels.cuh, and rocm_dequantize_kernels.cuh are byte-for-byte identical to 5e8d61edd377f (Ilya's branch point). 
Co-Authored-By: Claude Sonnet 4 --- .../common/util/rocm_cast_gated_kernels.cuh | 7 +++---- .../common/util/rocm_cast_kernels.cuh | 15 +++++++-------- .../common/util/rocm_dequantize_kernels.cuh | 7 +++---- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh index c1da033e3..387445a78 100644 --- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh @@ -13,7 +13,6 @@ #include "math.h" #include "ptx.cuh" #include "rocm_vectorized_2d.cuh" -#include "tdm.cuh" #include "transformer_engine/activation.h" #include "transformer_engine/cast.h" #include "vectorized_pointwise.h" @@ -22,7 +21,7 @@ namespace transformer_engine { namespace gated_kernels { -constexpr size_t ALIGNMENT_SIZE = TDM_SHMEM_ALIGNMENT; +constexpr size_t ALIGNMENT_SIZE = 128; // TODO: Identify optimal chunk/thread size for MI350+ constexpr size_t CHUNK_DIM_Y = 64; constexpr size_t CHUNK_DIM_X = 64; @@ -143,7 +142,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Act copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - + // Gate copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); @@ -362,7 +361,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } } - + if constexpr (USE_COLWISE_SCALING) { bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index 0b52997a8..33c53e8e8 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -12,7 +12,6 @@ #include "math.h" #include "ptx.cuh" #include "rocm_vectorized_2d.cuh" -#include "tdm.cuh" #include "transformer_engine/cast.h" #include "../transpose/cast_transpose.h" #include "vectorized_pointwise.h" @@ -128,10 +127,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) } // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - alignas(TDM_SHMEM_ALIGNMENT) __shared__ IType in_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - alignas(TDM_SHMEM_ALIGNMENT) __shared__ IType act_in_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - alignas(TDM_SHMEM_ALIGNMENT) __shared__ OType out_rowwise_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - alignas(TDM_SHMEM_ALIGNMENT) __shared__ OType out_colwise_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + alignas(128) __shared__ IType in_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + alignas(128) __shared__ IType act_in_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + alignas(128) __shared__ OType out_rowwise_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + alignas(128) __shared__ OType out_colwise_sh[MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; float block_amax = 0; @@ -163,11 +162,11 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const int chunk_it_offset_x = chunk_offset_X; const size_t row_base = chunk_it_offset_y; if constexpr (IS_DACT) { - copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, - chunk_it_offset_x, chunk_it_offset_y, cols, + copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, + 
chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } - copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, + copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); __syncthreads(); diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh index 00c81ac08..0d020b5eb 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh @@ -14,7 +14,6 @@ #include "math.h" #include "ptx.cuh" #include "rocm_vectorized_2d.cuh" -#include "tdm.cuh" #include "transformer_engine/activation.h" #include "transformer_engine/cast.h" #include "../transpose/cast_transpose.h" @@ -79,14 +78,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // const int thread_offset_X_colwise = tid_colwise_X; // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - alignas(TDM_SHMEM_ALIGNMENT) __shared__ IType in_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; - alignas(TDM_SHMEM_ALIGNMENT) __shared__ OType out_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; + alignas(128) __shared__ IType in_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; + alignas(128) __shared__ OType out_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; for (int iter = 0; iter < ITERATIONS; iter++) { const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; - copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, + copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); From 506d78c8896653b543c45fa009021392fae18931 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sat, 25 Apr 2026 00:18:05 -0500 Subject: [PATCH 10/43] tdm: extract ROCm flow into separate rocm_* launchers; TDM stays in main kernel functions On AMD, each mxfp8 quantize/dequantize/gated function previously dispatched between TDM and ROCm kernels via an inline env-var check. This refactor separates the two flows cleanly: - cast_gated_kernels.cuh / rocm_cast_gated_kernels.cuh: rocm_cast_mxfp8_gated() hosts the ROCm HIP gated kernel dispatch. cast_mxfp8_gated() is now TDM-only on AMD. quantize_gated() dispatches via NVTE_USE_NV_UPSTREAM_FLOW env var. - cast_kernels.cuh / rocm_cast_kernels.cuh: rocm_mxfp8_quantize() hosts the ROCm HIP cast kernel dispatch. mxfp8_quantize() is now TDM-only on AMD. fp8_quantize_rocm() dispatches via NVTE_USE_NV_UPSTREAM_FLOW env var. - dequantize_kernels.cuh / rocm_dequantize_kernels.cuh: rocm_mxfp8_dequantize() hosts the ROCm HIP dequantize dispatch. mxfp8_dequantize() is now TDM-only on AMD. dequantize_helper() dispatches via NVTE_USE_NV_UPSTREAM_FLOW env var. NV upstream path (no AMD) is unchanged throughout. 
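
For reference, every dispatch site uses the same memoized check (a sketch
assembled from the diffs below; the gated pair stands in for each
quantize/dequantize/gated pair):

    static const bool use_nv_upstream_flow = [] {
      const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW");
      return env != nullptr && env[0] == '1' && env[1] == '\0';
    }();
    if (use_nv_upstream_flow) {
      cast_mxfp8_gated(grad, gated_input, output, stream);        // TDM flow
    } else {
      rocm_cast_mxfp8_gated(grad, gated_input, output, stream);   // ROCm flow
    }

The immediately-invoked lambda initializes a function-local static, so the
getenv check runs once per process rather than on every call.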
Co-Authored-By: Claude Sonnet 4 --- .../common/util/cast_gated_kernels.cuh | 94 +++++-------- .../common/util/cast_kernels.cuh | 123 ++++++++---------- .../common/util/dequantize_kernels.cuh | 36 ++--- .../common/util/rocm_cast_gated_kernels.cuh | 83 ++++++++++++ .../common/util/rocm_cast_kernels.cuh | 87 ++++++++++++- .../common/util/rocm_dequantize_kernels.cuh | 53 ++++++++ 6 files changed, 326 insertions(+), 150 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 304eff7cd..e3e9c6f19 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -1321,12 +1321,12 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out #ifdef __HIP_PLATFORM_AMD__ constexpr size_t TMA_SHMEM_ALIGNMENT = ALIGNMENT_SIZE; - constexpr size_t BUFF_DIM_Y = BUFFER_DIM_Y; - constexpr size_t BUFF_DIM_X = BUFFER_DIM_X; - constexpr size_t BUFFS_NUM = BUFFERS_NUM; + constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; + constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; + constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X); #else constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; @@ -1343,8 +1343,10 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out : THREADS_PER_CHUNK_NON_COLWISE; #endif +#ifndef __HIP_PLATFORM_AMD__ const dim3 grid(blocks_X, blocks_Y); const dim3 block_size(THREADS_PER_CHUNK); +#endif size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1; @@ -1412,38 +1414,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out } #endif // #ifdef __HIP_PLATFORM_AMD__ - const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X; - const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - const size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - const size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; - - const size_t out_act_mem = buff_size_aligned_out; #ifdef __HIP_PLATFORM_AMD__ - const size_t out_gate_mem = buff_size_aligned_out; -#else - const size_t out_gate_mem = (IS_DGATED ? 
buff_size_aligned_out : 0); -#endif - size_t out_mem = out_act_mem + out_gate_mem; - if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } - - const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - -#ifdef __HIP_PLATFORM_AMD__ - // Check env var: NVTE_USE_NV_UPSTREAM_FLOW=1 selects NV upstream (TDM) kernel path - static const bool use_nv_upstream_mx = [] { - const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); - return env != nullptr && env[0] == '1' && env[1] == '\0'; - }(); - if (use_nv_upstream_mx) { - // NV upstream flow with TDM — uses mxfp8_kernel::cast_mxfp8_gated_kernel + { + // TDM flow — uses mxfp8_kernel::cast_mxfp8_gated_kernel constexpr size_t NV_THREADS_PER_CHUNK_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_COLWISE; constexpr size_t NV_THREADS_PER_CHUNK_NON_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_NON_COLWISE; @@ -1506,30 +1479,22 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out nv_launch(std::true_type{}, std::true_type{}, std::integral_constant{}); } - } else { - // ROCm flow kernel (default on AMD) - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (USE_COLWISE_SCALING ? 32 : 1), SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (USE_ROWWISE_SCALING ? 32 : 1), SCALE_DIM_X, - TRANSFORMER_ENGINE_SWITCH_CONDITION(!(cols % (32 * sizeof(IType))), IS_ALIGNED, { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - }))); // NOLINT(*) } #else + const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X; + const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + const size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + const size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t in_mem = grad_mem + buff_size_aligned_in + buff_size_aligned_in; + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = (IS_DGATED ? 
buff_size_aligned_out : 0); + size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } + const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; switch (scaling_type) { case ScalingType::ROWWISE: NVTE_CHECK_CUDA(cudaFuncSetAttribute( @@ -1714,8 +1679,19 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu } } else if (is_mxfp_scaling(output->scaling_mode)) { if (use_tma_kernels) { - // cast_mxfp8_gated handles both NVIDIA (TMA) and AMD (TDM) internally via #ifdef +#ifdef __HIP_PLATFORM_AMD__ + static const bool use_nv_upstream_flow = [] { + const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); + return env != nullptr && env[0] == '1' && env[1] == '\0'; + }(); + if (use_nv_upstream_flow) { + cast_mxfp8_gated(grad, gated_input, output, stream); + } else { + rocm_cast_mxfp8_gated(grad, gated_input, output, stream); + } +#else cast_mxfp8_gated(grad, gated_input, output, stream); +#endif } else { NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ", "by 32, got input of shape ", gated_input.data.shape); diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index de3baf851..3e29ea609 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1266,17 +1266,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); -#ifdef __HIP_PLATFORM_AMD__ - constexpr size_t CHUNK_DIM_Y = MXFP8_CHUNK_DIM_Y; - constexpr size_t CHUNK_DIM_X = MXFP8_CHUNK_DIM_X; - constexpr size_t THREADS_PER_CHUNK = MXFP8_THREADS_PER_CHUNK; -#else constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64; +#ifndef __HIP_PLATFORM_AMD__ constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; constexpr size_t BUFF_DIM_Y = THREADS_Y; @@ -1338,72 +1334,57 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, TRANSFORMER_ENGINE_SWITCH_CONDITION( !(cols % (32 * sizeof(IType))), IS_ALIGNED, { - const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); - if (env && std::string(env) == "1") { - // NV upstream kernel with TDM - constexpr bool NV_ROWWISE = (SCALE_DIM_X > 1); - constexpr bool NV_COLWISE = (SCALE_DIM_Y > 1); - constexpr size_t NV_CAST_DBIAS_ONLY_Y = (IS_DBIAS && (!IS_DACT) && (!IS_ACT)) ? 
128 : 64; - constexpr size_t NV_CAST_DBIAS_ONLY_X = NV_CAST_DBIAS_ONLY_Y; - constexpr size_t NV_CAST_DBIAS_ONLY_T = NV_CAST_DBIAS_ONLY_Y; - - constexpr size_t NV_THREADS_X = NV_CAST_DBIAS_ONLY_X / mxfp8_kernel::SCALE_DIM_X; - constexpr size_t NV_THREADS_Y = NV_CAST_DBIAS_ONLY_T / NV_THREADS_X; - constexpr size_t NV_BUFF_DIM_Y = NV_THREADS_Y; - constexpr size_t NV_BUFF_DIM_X = NV_CAST_DBIAS_ONLY_X; - - constexpr size_t NV_SHMEM_ALIGNMENT = TDM_SHMEM_ALIGNMENT; - constexpr size_t nv_buff_elems = NV_BUFF_DIM_Y * NV_BUFF_DIM_X; - constexpr size_t nv_buff_elems_total = mxfp8_kernel::BUFFS_NUM * nv_buff_elems; - constexpr size_t nv_input_type_bit_size = TypeInfo::size; - constexpr size_t nv_output_type_bit_size = TypeInfo::size; - constexpr size_t nv_input_buff_size = (nv_buff_elems_total * nv_input_type_bit_size) / 8; - constexpr size_t nv_output_buff_size = (nv_buff_elems_total * nv_output_type_bit_size) / 8; - constexpr size_t nv_buff_size_aligned_in = - DIVUP_TO_MULTIPLE(nv_input_buff_size, NV_SHMEM_ALIGNMENT); - constexpr size_t nv_buff_size_aligned_out = - DIVUP_TO_MULTIPLE(nv_output_buff_size, NV_SHMEM_ALIGNMENT); - - constexpr size_t nv_elt_input_mem = nv_buff_size_aligned_in; - constexpr size_t nv_act_input_mem = (IS_DACT ? nv_buff_size_aligned_in : 0); - constexpr size_t nv_in_mem = nv_elt_input_mem + nv_act_input_mem; - - const size_t nv_out_rowwise_mem = (use_rowwise_scaling ? nv_buff_size_aligned_out : 0); - const size_t nv_out_colwise_mem = (use_colwise_scaling ? nv_buff_size_aligned_out : 0); - const size_t nv_out_mem = nv_out_rowwise_mem + nv_out_colwise_mem; - - const size_t nv_dshmem_size = nv_in_mem + nv_out_mem + NV_SHMEM_ALIGNMENT; - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, nv_dshmem_size)); - - mxfp8_kernel::cast_mxfp8_2D_kernel - <<>>( - reinterpret_cast(input.data.dptr), - (IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr, - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->columnwise_data.dptr), - scales_rowwise_ptr, scales_colwise_ptr, - reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); - } else { - // Default ROCm flow - cast_mxfp8_2D_kernel - <<>>( - reinterpret_cast(input.data.dptr), - (IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr, - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->columnwise_data.dptr), - scales_rowwise_ptr, scales_colwise_ptr, - reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); - } + // TDM flow — uses mxfp8_kernel::cast_mxfp8_2D_kernel + constexpr bool NV_ROWWISE = (SCALE_DIM_X > 1); + constexpr bool NV_COLWISE = (SCALE_DIM_Y > 1); + constexpr size_t NV_CAST_DBIAS_ONLY_Y = (IS_DBIAS && (!IS_DACT) && (!IS_ACT)) ? 
128 : 64; + constexpr size_t NV_CAST_DBIAS_ONLY_X = NV_CAST_DBIAS_ONLY_Y; + constexpr size_t NV_CAST_DBIAS_ONLY_T = NV_CAST_DBIAS_ONLY_Y; + + constexpr size_t NV_THREADS_X = NV_CAST_DBIAS_ONLY_X / mxfp8_kernel::SCALE_DIM_X; + constexpr size_t NV_THREADS_Y = NV_CAST_DBIAS_ONLY_T / NV_THREADS_X; + constexpr size_t NV_BUFF_DIM_Y = NV_THREADS_Y; + constexpr size_t NV_BUFF_DIM_X = NV_CAST_DBIAS_ONLY_X; + + constexpr size_t NV_SHMEM_ALIGNMENT = TDM_SHMEM_ALIGNMENT; + constexpr size_t nv_buff_elems = NV_BUFF_DIM_Y * NV_BUFF_DIM_X; + constexpr size_t nv_buff_elems_total = mxfp8_kernel::BUFFS_NUM * nv_buff_elems; + constexpr size_t nv_input_type_bit_size = TypeInfo::size; + constexpr size_t nv_output_type_bit_size = TypeInfo::size; + constexpr size_t nv_input_buff_size = (nv_buff_elems_total * nv_input_type_bit_size) / 8; + constexpr size_t nv_output_buff_size = (nv_buff_elems_total * nv_output_type_bit_size) / 8; + constexpr size_t nv_buff_size_aligned_in = + DIVUP_TO_MULTIPLE(nv_input_buff_size, NV_SHMEM_ALIGNMENT); + constexpr size_t nv_buff_size_aligned_out = + DIVUP_TO_MULTIPLE(nv_output_buff_size, NV_SHMEM_ALIGNMENT); + + constexpr size_t nv_elt_input_mem = nv_buff_size_aligned_in; + constexpr size_t nv_act_input_mem = (IS_DACT ? nv_buff_size_aligned_in : 0); + constexpr size_t nv_in_mem = nv_elt_input_mem + nv_act_input_mem; + + const size_t nv_out_rowwise_mem = (use_rowwise_scaling ? nv_buff_size_aligned_out : 0); + const size_t nv_out_colwise_mem = (use_colwise_scaling ? nv_buff_size_aligned_out : 0); + const size_t nv_out_mem = nv_out_rowwise_mem + nv_out_colwise_mem; + + const size_t nv_dshmem_size = nv_in_mem + nv_out_mem + NV_SHMEM_ALIGNMENT; + + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, nv_dshmem_size)); + + mxfp8_kernel::cast_mxfp8_2D_kernel + <<>>( + reinterpret_cast(input.data.dptr), + (IS_DACT) ? 
reinterpret_cast(act_input->data.dptr) : nullptr, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + scales_rowwise_ptr, scales_colwise_ptr, + reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); NVTE_CHECK_CUDA(cudaGetLastError()); }))); // NOLINT(*) #else // #ifdef __HIP_PLATFORM_AMD__ diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index 5cbf08d4f..f543f2936 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -377,22 +377,12 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s TRANSFORMER_ENGINE_SWITCH_CONDITION( !(cols % (32 * sizeof(OType))), IS_ALIGNED, { - const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); - if (env && std::string(env) == "1") { - // NV upstream kernel with TDM - dequantization::dequantize_mxfp8_kernel - <<>>( - reinterpret_cast(input_data.dptr), - reinterpret_cast(output->data.dptr), - scales_ptr, rows, cols, scales_stride); - } else { - // Default ROCm flow - dequantize_mxfp8_kernel - <<>>( - reinterpret_cast(input_data.dptr), - reinterpret_cast(output->data.dptr), - scales_ptr, rows, cols, scales_stride); - } + // TDM flow — uses dequantization::dequantize_mxfp8_kernel + dequantization::dequantize_mxfp8_kernel + <<>>( + reinterpret_cast(input_data.dptr), + reinterpret_cast(output->data.dptr), + scales_ptr, rows, cols, scales_stride); }); // NOLINT(*) #else // #ifdef __HIP_PLATFORM_AMD__ alignas(64) CUtensorMap tensor_map_input{}; @@ -427,14 +417,24 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) dequantization::fp8_dequantize(input, output, stream); } else if (is_mxfp_scaling(input.scaling_mode)) { #ifdef __HIP_PLATFORM_AMD__ - if (1) { + { + static const bool use_nv_upstream_flow = [] { + const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); + return env != nullptr && env[0] == '1' && env[1] == '\0'; + }(); + if (use_nv_upstream_flow) { + dequantization::mxfp8_dequantize(input, output, stream); + } else { + rocm_mxfp8_dequantize(input, output, stream); + } + } #else if (is_supported_by_CC_100()) { -#endif dequantization::mxfp8_dequantize(input, output, stream); } else { NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); } +#endif } else { // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh index 387445a78..910fe9356 100644 --- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh @@ -374,4 +374,87 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } } // namespace gated_kernels + +template +void rocm_cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + cudaStream_t stream) { + using namespace gated_kernels; + + const bool USE_ROWWISE_SCALING = output->has_data(); + const bool USE_COLWISE_SCALING = output->has_columnwise_data(); + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 
2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const dim3 block_size(THREADS_PER_CHUNK); + + size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; + size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1; + + e8m0_t *const scales_rowwise_ptr = + USE_ROWWISE_SCALING ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + const IType *grad_ptr = IS_DGATED ? reinterpret_cast(grad.data.dptr) : nullptr; + const IType *input_act_ptr = reinterpret_cast(gated_input.data.dptr); + const IType *input_gate_ptr = reinterpret_cast(gated_input.data.dptr) + cols; + OType *output_act_rowwise_ptr = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) : nullptr; + OType *output_gate_rowwise_ptr = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) + cols : nullptr; + OType *output_act_colwise_ptr = USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_data.dptr) : nullptr; + OType *output_gate_colwise_ptr = USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_data.dptr) + cols : nullptr; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + const size_t buff_elems_total = BUFFERS_NUM * BUFFER_DIM_Y * BUFFER_DIM_X; + const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + const size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, ALIGNMENT_SIZE); + const size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, ALIGNMENT_SIZE); + + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t in_mem = grad_mem + buff_size_aligned_in + buff_size_aligned_in; + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } + const size_t shmem_size = in_mem + out_mem + ALIGNMENT_SIZE; + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (USE_COLWISE_SCALING ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (USE_ROWWISE_SCALING ? 
32 : 1), SCALE_DIM_X, + TRANSFORMER_ENGINE_SWITCH_CONDITION(!(cols % (32 * sizeof(IType))), IS_ALIGNED, { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + cast_mxfp8_gated_kernel + <<>>( + grad_ptr, input_act_ptr, input_gate_ptr, + output_act_rowwise_ptr, output_gate_rowwise_ptr, + output_act_colwise_ptr, output_gate_colwise_ptr, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, + scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + }))); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +} + } // namespace transformer_engine diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index 33c53e8e8..3ede0c74c 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -459,6 +459,12 @@ void reduce_dbias_rocm(const DTypeReduce *workspace_ptr, Tensor *dbias, const si reduce_dbias(partial_workspace, dbias, partial_rows, cols, stream); } +// Forward declaration +template +void rocm_mxfp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, + Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream); + template void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tensor *noop, @@ -547,8 +553,17 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso break; } case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize(input, act_input, noop, output, dbias, - workspace, stream); + static const bool use_nv_upstream_flow = [] { + const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); + return env != nullptr && env[0] == '1' && env[1] == '\0'; + }(); + if (use_nv_upstream_flow) { + mxfp8_quantize(input, act_input, noop, output, + dbias, workspace, stream); + } else { + rocm_mxfp8_quantize(input, act_input, noop, output, + dbias, workspace, stream); + } break; } default: @@ -557,4 +572,72 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso } +template +void rocm_mxfp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, + Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + bool use_rowwise_scaling = output->has_data(); + bool use_colwise_scaling = output->has_columnwise_data(); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + const size_t blocks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = MXFP8_THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + e8m0_t *const scales_rowwise_ptr = + use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + if constexpr (IS_DBIAS) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + float *const workspace_ptr = IS_DBIAS ? 
reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + !(cols % (32 * sizeof(IType))), IS_ALIGNED, + { + cast_mxfp8_2D_kernel + <<>>( + reinterpret_cast(input.data.dptr), + (IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + scales_rowwise_ptr, scales_colwise_ptr, + reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + }))); // NOLINT(*) closes: {} SWITCH_CONDITION inner_MX outer_MX + + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) closes FP8ONLY + ); // NOLINT(*) closes NON_FP8ONLY +} + } // namespace transformer_engine diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh index 0d020b5eb..7e320dc82 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh @@ -134,4 +134,57 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } } // namespace dequantization + +static void rocm_mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + using namespace dequantization; + + bool use_rowwise_scaling = input.has_data(); + bool use_colwise_scaling = input.has_columnwise_data(); + + const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; + const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X); + + const size_t scales_X_rowwise = DIVUP(cols, scale_dim_X_rowwise); + const size_t scales_X_colwise = cols; + + const e8m0_t *const scales_ptr = + use_rowwise_scaling ? reinterpret_cast(input.scale_inv.dptr) + : reinterpret_cast(input.columnwise_scale_inv.dptr); + + const size_t scales_stride = use_rowwise_scaling ? scales_X_rowwise : scales_X_colwise; + + const SimpleTensor &input_data = use_rowwise_scaling ? 
input.data : input.columnwise_data; + + const dim3 block(THREADS_PER_CHUNK); + const dim3 grid(chunks_X, chunks_Y); + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->dtype(), OType, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + !(cols % (32 * sizeof(OType))), IS_ALIGNED, + { + dequantize_mxfp8_kernel + <<>>( + reinterpret_cast(input_data.dptr), + reinterpret_cast(output->data.dptr), + scales_ptr, rows, cols, scales_stride); + NVTE_CHECK_CUDA(cudaGetLastError()); + }); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +} + } // namespace transformer_engine From acc7e4f0ef817367cbee60b8576d49129aed7f5c Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 09:41:18 -0500 Subject: [PATCH 11/43] tdm: address review comments for cast_gated_kernels.cuh - Rename namespace nv_flow -> tma_flow (more accurate: both TMA on NV and TDM on AMD use this path) - Rename env-var NVTE_USE_NV_UPSTREAM_FLOW -> NVTE_USE_TDM_FLOW with inverted default: 0 = ROCm flow (default), 1 = TDM flow - Apply same env-var dispatch to fp8 gated path (was missing) - Remove dead AMD-specific guards around ScalingType, BUFF_DIM, blocks, THREADS_PER_CHUNK, grid, block_size in cast_mxfp8_gated - Remove AMD-specific {} wrapper and duplicate shmem computation block; TMA_SHMEM_ALIGNMENT == TDM_SHMEM_ALIGNMENT == 128 so NV upstream formula works on both platforms Co-Authored-By: Claude Sonnet 4 --- .../common/util/cast_gated_kernels.cuh | 154 +++++------------- 1 file changed, 44 insertions(+), 110 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index e3e9c6f19..dae53c36b 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -37,10 +37,10 @@ namespace transformer_engine { namespace gated_kernels { -// NV upstream flow constants — used on both NVIDIA (TMA) and AMD gfx1250 (TDM). +// TMA/TDM flow constants — used on both NVIDIA (TMA) and AMD gfx1250 (TDM). // On AMD, the ROCm flow constants live in rocm_cast_gated_kernels.cuh; -// these are in the nv_flow namespace to avoid collision. -namespace nv_flow { +// these are in the tma_flow namespace to avoid collision. +namespace tma_flow { constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; @@ -401,7 +401,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #endif #endif // #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) } -} // namespace nv_flow +} // namespace tma_flow namespace mxfp8_kernel { @@ -1187,14 +1187,14 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_DGATED ? 
2 : 1) * cols; - const size_t blocks_Y = DIVUP(rows, nv_flow::CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, nv_flow::CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(rows, tma_flow::CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, tma_flow::CHUNK_DIM_X); float *const amax_ptr = reinterpret_cast(output->amax.dptr); float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); float *const scale_ptr = reinterpret_cast(output->scale.dptr); - const dim3 block_dim(nv_flow::THREADS_PER_CHUNK); + const dim3 block_dim(tma_flow::THREADS_PER_CHUNK); const dim3 grid_dim(blocks_X, blocks_Y); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( @@ -1211,7 +1211,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu OType *output_gate_ptr = IS_DGATED ? reinterpret_cast(output->data.dptr) + cols : nullptr; - constexpr size_t buff_elems_total = nv_flow::BUFFERS_NUM * nv_flow::SHMEM_DIM_Y * nv_flow::SHMEM_DIM_X; + constexpr size_t buff_elems_total = tma_flow::BUFFERS_NUM * tma_flow::SHMEM_DIM_Y * tma_flow::SHMEM_DIM_X; const size_t buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TDM_SHMEM_ALIGNMENT); const size_t buff_size_aligned_out = @@ -1225,10 +1225,10 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu (out_act_mem + out_gate_mem) + TDM_SHMEM_ALIGNMENT; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - nv_flow::cast_fp8_gated_kernel, + tma_flow::cast_fp8_gated_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - nv_flow::cast_fp8_gated_kernel + tma_flow::cast_fp8_gated_kernel <<>>( grad_ptr, input_act_ptr, input_gate_ptr, output_act_ptr, output_gate_ptr, @@ -1242,23 +1242,23 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu alignas(64) CUtensorMap tensor_map_output_gate{}; if constexpr (IS_DGATED) { - create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, nv_flow::SHMEM_DIM_Y, - nv_flow::SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, tma_flow::SHMEM_DIM_Y, + tma_flow::SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype())); } const uint32_t tensor_stride_elems = output_cols; - create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, nv_flow::SHMEM_DIM_Y, - nv_flow::SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, nv_flow::SHMEM_DIM_Y, - nv_flow::SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, nv_flow::SHMEM_DIM_Y, - nv_flow::SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, nv_flow::SHMEM_DIM_Y, - nv_flow::SHMEM_DIM_X, tensor_stride_elems, cols, + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, tma_flow::SHMEM_DIM_Y, + tma_flow::SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, tma_flow::SHMEM_DIM_Y, + tma_flow::SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, tma_flow::SHMEM_DIM_Y, + tma_flow::SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); + create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, tma_flow::SHMEM_DIM_Y, + tma_flow::SHMEM_DIM_X, tensor_stride_elems, cols, 
typeToNumBits(output->dtype())); - const size_t buff_elems_total = nv_flow::BUFFERS_NUM * nv_flow::SHMEM_DIM_Y * nv_flow::SHMEM_DIM_X; + const size_t buff_elems_total = tma_flow::BUFFERS_NUM * tma_flow::SHMEM_DIM_Y * tma_flow::SHMEM_DIM_X; const size_t buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); const size_t buff_size_aligned_out = @@ -1272,10 +1272,10 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - nv_flow::cast_fp8_gated_kernel, + tma_flow::cast_fp8_gated_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - nv_flow::cast_fp8_gated_kernel + tma_flow::cast_fp8_gated_kernel <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, @@ -1303,7 +1303,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); } -#ifndef __HIP_PLATFORM_AMD__ ScalingType scaling_type; if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) { scaling_type = ScalingType::ROWWISE; @@ -1312,23 +1311,11 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out } else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { scaling_type = ScalingType::BIDIMENSIONAL; } -#endif const size_t rows = gated_input.flat_first_dim(); const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; -#ifdef __HIP_PLATFORM_AMD__ - constexpr size_t TMA_SHMEM_ALIGNMENT = ALIGNMENT_SIZE; - - constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; - constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; - constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; - - const size_t blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X); -#else - constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; @@ -1341,12 +1328,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE) ? THREADS_PER_CHUNK_COLWISE : THREADS_PER_CHUNK_NON_COLWISE; -#endif -#ifndef __HIP_PLATFORM_AMD__ const dim3 grid(blocks_X, blocks_Y); const dim3 block_size(THREADS_PER_CHUNK); -#endif size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; size_t scale_stride_colwise = USE_COLWISE_SCALING ? 
output->columnwise_scale_inv.shape[1] : 1; @@ -1414,73 +1398,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out } #endif // #ifdef __HIP_PLATFORM_AMD__ -#ifdef __HIP_PLATFORM_AMD__ - { - // TDM flow — uses mxfp8_kernel::cast_mxfp8_gated_kernel - constexpr size_t NV_THREADS_PER_CHUNK_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_COLWISE; - constexpr size_t NV_THREADS_PER_CHUNK_NON_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_NON_COLWISE; - - // Recompute shmem size with NV upstream constants - constexpr size_t NV_BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; - constexpr size_t NV_BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; - constexpr size_t NV_BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; - const size_t nv_buff_elems_total = NV_BUFFS_NUM * NV_BUFF_DIM_Y * NV_BUFF_DIM_X; - const size_t nv_input_buff_size = (nv_buff_elems_total * input_type_bit_size) / 8; - const size_t nv_output_buff_size = (nv_buff_elems_total * output_type_bit_size) / 8; - const size_t nv_buff_size_aligned_in = - DIVUP_TO_MULTIPLE(nv_input_buff_size, TDM_SHMEM_ALIGNMENT); - const size_t nv_buff_size_aligned_out = - DIVUP_TO_MULTIPLE(nv_output_buff_size, TDM_SHMEM_ALIGNMENT); - const size_t nv_grad_mem = (IS_DGATED ? nv_buff_size_aligned_in : 0); - const size_t nv_in_act_mem = nv_buff_size_aligned_in; - const size_t nv_in_gate_mem = nv_buff_size_aligned_in; - const size_t nv_in_mem = nv_grad_mem + nv_in_act_mem + nv_in_gate_mem; - const size_t nv_out_act_mem = nv_buff_size_aligned_out; - const size_t nv_out_gate_mem = (IS_DGATED ? nv_buff_size_aligned_out : 0); - size_t nv_out_mem = nv_out_act_mem + nv_out_gate_mem; - if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { nv_out_mem *= 2; } - const size_t nv_shmem_size = nv_in_mem + nv_out_mem + TDM_SHMEM_ALIGNMENT; - - const size_t nv_blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y); - const size_t nv_blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X); - const size_t NV_THREADS_PER_CHUNK = USE_COLWISE_SCALING - ? 
NV_THREADS_PER_CHUNK_COLWISE : NV_THREADS_PER_CHUNK_NON_COLWISE; - const dim3 nv_grid(nv_blocks_X, nv_blocks_Y); - const dim3 nv_block(NV_THREADS_PER_CHUNK); - - auto nv_launch = [&](auto rowwise_tag, auto colwise_tag, auto threads_tag) { - constexpr bool RW = decltype(rowwise_tag)::value; - constexpr bool CW = decltype(colwise_tag)::value; - constexpr size_t TPC = decltype(threads_tag)::value; - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, nv_shmem_size)); - - mxfp8_kernel::cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - }; - - if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) { - nv_launch(std::true_type{}, std::false_type{}, - std::integral_constant{}); - } else if ((!USE_ROWWISE_SCALING) && USE_COLWISE_SCALING) { - nv_launch(std::false_type{}, std::true_type{}, - std::integral_constant{}); - } else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { - nv_launch(std::true_type{}, std::true_type{}, - std::integral_constant{}); - } - } -#else const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X; const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; @@ -1668,8 +1585,24 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu if (is_delayed_tensor_scaling(output->scaling_mode)) { if (use_tma_kernels) { - // cast_fp8_gated handles both NVIDIA (TMA) and AMD (TDM) internally via #ifdef +#ifdef __HIP_PLATFORM_AMD__ + // On AMD gfx1250: NVTE_USE_TDM_FLOW=1 selects TDM kernel; default (0) uses ROCm flow. + static const bool use_tdm_flow_fp8 = [] { + const char *env = std::getenv("NVTE_USE_TDM_FLOW"); + return env != nullptr && env[0] == '1' && env[1] == '\0'; + }(); + if (use_tdm_flow_fp8) { + cast_fp8_gated(grad, gated_input, output, stream); + } else { + if constexpr (IS_DGATED) { + cast_dgated(grad, gated_input, output, stream); + } else { + cast_gated(gated_input, output, stream); + } + } +#else cast_fp8_gated(grad, gated_input, output, stream); +#endif } else { if constexpr (IS_DGATED) { cast_dgated(grad, gated_input, output, stream); @@ -1680,11 +1613,12 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu } else if (is_mxfp_scaling(output->scaling_mode)) { if (use_tma_kernels) { #ifdef __HIP_PLATFORM_AMD__ - static const bool use_nv_upstream_flow = [] { - const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); + // On AMD gfx1250: NVTE_USE_TDM_FLOW=1 selects TDM kernel; default (0) uses ROCm flow. 
+ static const bool use_tdm_flow = [] { + const char *env = std::getenv("NVTE_USE_TDM_FLOW"); return env != nullptr && env[0] == '1' && env[1] == '\0'; }(); - if (use_nv_upstream_flow) { + if (use_tdm_flow) { cast_mxfp8_gated(grad, gated_input, output, stream); } else { rocm_cast_mxfp8_gated(grad, gated_input, output, stream); From 3c851017c95b7cac37747f40aa628b53d79e7865 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 09:47:53 -0500 Subject: [PATCH 12/43] tdm: revert swizzled_* lines to NV upstream position Move swizzled_group_idx/swizzled_idx/shmem_offset_rowwise back to just before out_act.store_to(), matching NV upstream cast_gated_kernels.cuh lines 831-834, to minimize diff. Co-Authored-By: Claude Sonnet 4 --- transformer_engine/common/util/cast_gated_kernels.cuh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index dae53c36b..980c9d6c8 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -1048,10 +1048,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 3. Scale elements #pragma unroll for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - Vec out_act; Vec out_gate; #pragma unroll @@ -1084,6 +1080,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate); } } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]); if constexpr (IS_DGATED) { out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]); From 0007c883cc56e91bdc3e87e0657f777b926dcead Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 10:00:38 -0500 Subject: [PATCH 13/43] tdm: fix cast_mxfp8_gated to match NV upstream structure Two differences found vs NV upstream (line 968): 1. out_gate_mem: AMD TDM kernel always needs a gate shmem buffer regardless of IS_DGATED (kernel signature always includes gate output pointers), so restore AMD-specific: out_gate_mem = buff_size_aligned_out (always) vs NV: out_gate_mem = IS_DGATED ? buff_size_aligned_out : 0 2. in_mem: split into in_act_mem + in_gate_mem intermediate vars to match NV upstream style exactly. 3. AMD TDM dispatch: restore TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH dispatch (was accidentally dropped when removing the {} wrapper), guarded under #ifdef __HIP_PLATFORM_AMD__. NV uses switch(scaling_type). 
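For reference, the restored shared-memory budget can be summarised as a small standalone
sketch (helper and parameter names here are illustrative, not symbols from this patch; the
128-byte figure assumes TMA_SHMEM_ALIGNMENT == TDM_SHMEM_ALIGNMENT == 128):

    // Hedged sketch of the out_gate_mem difference described in point 1: on the
    // AMD TDM path the gate output buffer is reserved even in the non-dgated
    // (forward) case because the kernel signature always takes gate pointers.
    #include <cstddef>

    constexpr std::size_t kAlign = 128;  // assumed shared-memory alignment
    constexpr std::size_t align_up(std::size_t n) {
      return (n + kAlign - 1) / kAlign * kAlign;
    }

    constexpr std::size_t shmem_bytes(std::size_t buff_elems, std::size_t in_elt_bytes,
                                      std::size_t out_elt_bytes, bool is_dgated,
                                      bool rowwise, bool colwise, bool amd_tdm) {
      const std::size_t in_buf   = align_up(buff_elems * in_elt_bytes);
      const std::size_t out_buf  = align_up(buff_elems * out_elt_bytes);
      const std::size_t grad_mem = is_dgated ? in_buf : 0;
      const std::size_t in_mem   = grad_mem + /*act*/ in_buf + /*gate*/ in_buf;
      const std::size_t out_gate = (amd_tdm || is_dgated) ? out_buf : 0;
      std::size_t out_mem = out_buf + out_gate;
      if (rowwise && colwise) { out_mem *= 2; }
      return in_mem + out_mem + kAlign;
    }

Only the out_gate term is platform-dependent; every other term matches the NV upstream formula.
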
Co-Authored-By: Claude Sonnet 4 --- .../common/util/cast_gated_kernels.cuh | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 980c9d6c8..c514346ea 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -1405,12 +1405,41 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); - const size_t in_mem = grad_mem + buff_size_aligned_in + buff_size_aligned_in; + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; + const size_t out_act_mem = buff_size_aligned_out; +#ifdef __HIP_PLATFORM_AMD__ + const size_t out_gate_mem = buff_size_aligned_out; +#else const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); +#endif size_t out_mem = out_act_mem + out_gate_mem; if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; +#ifdef __HIP_PLATFORM_AMD__ + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (USE_COLWISE_SCALING ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (USE_ROWWISE_SCALING ? 32 : 1), SCALE_DIM_X, + TRANSFORMER_ENGINE_SWITCH_CONDITION(!(cols % (32 * sizeof(IType))), IS_ALIGNED, { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + }))); // NOLINT(*) +#else switch (scaling_type) { case ScalingType::ROWWISE: NVTE_CHECK_CUDA(cudaFuncSetAttribute( From bacd2269480a9456def3adc21d21872c3a89cdad Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 10:16:29 -0500 Subject: [PATCH 14/43] tdm: use switch(scaling_type) for AMD TDM mxfp8 gated dispatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH with the same switch(scaling_type) structure as NV upstream. The TDM kernel shares the same ROWWISE_SCALING/COLWISE_SCALING/THREADS_PER_CHUNK template params as the NV kernel — SCALE_DIM_Y/X/IS_ALIGNED were ROCm-flow params that don't apply here. 
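In shape, the dispatch now looks like the following minimal sketch (names are illustrative;
the real launch additionally selects THREADS_PER_CHUNK and sets
cudaFuncAttributeMaxDynamicSharedMemorySize before each launch):

    // Runtime enum mapped onto compile-time template parameters via an
    // explicit switch -- one case per scaling mode.
    #include <cstdio>

    enum class ScalingType { ROWWISE, COLWISE, BIDIMENSIONAL };

    template <bool USE_ROWWISE, bool USE_COLWISE>
    void launch_gated() {  // stand-in for the templated kernel launch
      std::printf("rowwise=%d colwise=%d\n", USE_ROWWISE, USE_COLWISE);
    }

    void dispatch(ScalingType s) {
      switch (s) {
        case ScalingType::ROWWISE:       launch_gated<true,  false>(); break;
        case ScalingType::COLWISE:       launch_gated<false, true>();  break;
        case ScalingType::BIDIMENSIONAL: launch_gated<true,  true>();  break;
      }
    }

Calling dispatch(ScalingType::BIDIMENSIONAL) picks the <true, true> instantiation, which is
the case whose output shared-memory buffers are doubled.
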
Co-Authored-By: Claude Sonnet 4 --- .../common/util/cast_gated_kernels.cuh | 70 +++++++++++++------ 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index c514346ea..6ae90ad6c 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -1419,26 +1419,56 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; #ifdef __HIP_PLATFORM_AMD__ - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (USE_COLWISE_SCALING ? 32 : 1), SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (USE_ROWWISE_SCALING ? 32 : 1), SCALE_DIM_X, - TRANSFORMER_ENGINE_SWITCH_CONDITION(!(cols % (32 * sizeof(IType))), IS_ALIGNED, { - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - }))); // NOLINT(*) + switch (scaling_type) { + case ScalingType::ROWWISE: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + case ScalingType::COLWISE: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + case ScalingType::BIDIMENSIONAL: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } #else switch (scaling_type) { case ScalingType::ROWWISE: From 4ba588356eedc1c25e127e7129ad1da7cd1b1f7c Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 10:18:17 -0500 Subject: [PATCH 15/43] tdm: hoist shared next-stage offset vars above #ifdef in cast_mxfp8_gated_kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit next_buff, next_stage_offset_Y, global_offset_Y, 
global_offset_X, next_buff_offset are identical in both the TMA and TDM branches — declare them once above the #ifndef __HIP_PLATFORM_AMD__ guard. Co-Authored-By: Claude Sonnet 4 --- .../common/util/cast_gated_kernels.cuh | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 6ae90ad6c..8e725b873 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -646,16 +646,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t stage_offset_Y = stage * BUFF_DIM_Y; if (next_stage < STAGES) { -#ifndef __HIP_PLATFORM_AMD__ - // Wait for TMA transfer to have finished reading shared memory. - // I.e. the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - const size_t next_buff = next_stage % BUFFS_NUM; const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; const size_t global_offset_X = block_offset_X; const size_t next_buff_offset = next_buff * BUFF_DIM; +#ifndef __HIP_PLATFORM_AMD__ + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + if constexpr (IS_DGATED) { copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X, global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act, @@ -669,11 +669,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) is_master_thread); } #else // __HIP_PLATFORM_AMD__ — TDM prefetch next stage - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t next_buff_offset = next_buff * BUFF_DIM; if constexpr (IS_DGATED) { tdm::copy_2d_to_shared(&in_grad_sh[next_buff_offset], mx_tmap_grad, global_offset_X, global_offset_Y); From 7dbf21866243771373a49e89d8dc3829afa3b0f9 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 10:19:34 -0500 Subject: [PATCH 16/43] tdm: hoist shared shmem computation above #ifdef in cast_fp8_gated The shmem size calculation is identical for TDM and TMA paths (TDM_SHMEM_ALIGNMENT == TMA_SHMEM_ALIGNMENT == 128), so declare it once above the #ifdef __HIP_PLATFORM_AMD__ guard. Only the pointer setup and kernel launch remain platform-specific. Co-Authored-By: Claude Sonnet 4 --- .../common/util/cast_gated_kernels.cuh | 37 ++++++------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 8e725b873..c6260756f 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -1196,27 +1196,27 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, -#ifdef __HIP_PLATFORM_AMD__ - const IType *grad_ptr = IS_DGATED - ? reinterpret_cast(grad.data.dptr) : nullptr; - const IType *input_act_ptr = reinterpret_cast(gated_input.data.dptr); - const IType *input_gate_ptr = reinterpret_cast(gated_input.data.dptr) + cols; - OType *output_act_ptr = reinterpret_cast(output->data.dptr); - OType *output_gate_ptr = IS_DGATED - ? 
reinterpret_cast(output->data.dptr) + cols : nullptr; - constexpr size_t buff_elems_total = tma_flow::BUFFERS_NUM * tma_flow::SHMEM_DIM_Y * tma_flow::SHMEM_DIM_X; const size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TDM_SHMEM_ALIGNMENT); + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); const size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TDM_SHMEM_ALIGNMENT); + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); const size_t in_act_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in; const size_t out_act_mem = buff_size_aligned_out; const size_t out_gate_mem = buff_size_aligned_out; const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + - (out_act_mem + out_gate_mem) + TDM_SHMEM_ALIGNMENT; + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; + +#ifdef __HIP_PLATFORM_AMD__ + const IType *grad_ptr = IS_DGATED + ? reinterpret_cast(grad.data.dptr) : nullptr; + const IType *input_act_ptr = reinterpret_cast(gated_input.data.dptr); + const IType *input_gate_ptr = reinterpret_cast(gated_input.data.dptr) + cols; + OType *output_act_ptr = reinterpret_cast(output->data.dptr); + OType *output_gate_ptr = IS_DGATED + ? reinterpret_cast(output->data.dptr) + cols : nullptr; NVTE_CHECK_CUDA(cudaFuncSetAttribute( tma_flow::cast_fp8_gated_kernel, @@ -1252,19 +1252,6 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu tma_flow::SHMEM_DIM_X, tensor_stride_elems, cols, typeToNumBits(output->dtype())); - const size_t buff_elems_total = tma_flow::BUFFERS_NUM * tma_flow::SHMEM_DIM_Y * tma_flow::SHMEM_DIM_X; - const size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - const size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + - (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; - NVTE_CHECK_CUDA(cudaFuncSetAttribute( tma_flow::cast_fp8_gated_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); From 293d970460a49549927b4da28aa07675549b5ddb Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 10:29:49 -0500 Subject: [PATCH 17/43] tdm: collapse duplicate switch(scaling_type) blocks in cast_mxfp8_gated MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Both AMD and NV call mxfp8_kernel::cast_mxfp8_gated_kernel — the TDM kernel at line 435 is also inside namespace mxfp8_kernel. The two switch blocks were identical except for the namespace qualifier, so remove the #ifdef and keep one unified switch block. 
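The structure this relies on, in miniature (illustrative names; CUDA spelling shown, which
the ROCm build maps through its hipify layer):

    // One launch site names one symbol; the platform split lives inside the
    // kernel definition, so no duplicated switch is needed at the call site.
    #include <cuda_runtime.h>

    namespace mx {
    __global__ void cast_kernel(float *out, const float *in) {
    #ifdef __HIP_PLATFORM_AMD__
      // TDM-specific body would go here (gfx1250)
    #else
      // TMA-specific body would go here (sm_100+)
    #endif
      out[threadIdx.x] = in[threadIdx.x];
    }
    }  // namespace mx

    void launch(float *out, const float *in, cudaStream_t stream) {
      mx::cast_kernel<<<1, 32, 0, stream>>>(out, in);
    }
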
Co-Authored-By: Claude Sonnet 4 --- .../common/util/cast_gated_kernels.cuh | 53 ------------------- 1 file changed, 53 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index c6260756f..24ead840e 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -1400,58 +1400,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out size_t out_mem = out_act_mem + out_gate_mem; if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; -#ifdef __HIP_PLATFORM_AMD__ - switch (scaling_type) { - case ScalingType::ROWWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::COLWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::BIDIMENSIONAL: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - } -#else switch (scaling_type) { case ScalingType::ROWWISE: NVTE_CHECK_CUDA(cudaFuncSetAttribute( @@ -1505,7 +1453,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out NVTE_CHECK_CUDA(cudaGetLastError()); break; } -#endif ); // NOLINT(*) ); // NOLINT(*) } From 53b5d226e88ae6ead529ddf7487d12938f217631 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 10:51:05 -0500 Subject: [PATCH 18/43] tdm: remove tma_flow namespace; prefix ROCm-specific constants with ROCM_ - Remove namespace tma_flow from cast_gated_kernels.cuh; constants and cast_fp8_gated_kernel now live directly in gated_kernels namespace, consistent with mxfp8_kernel::cast_mxfp8_gated_kernel organization - In rocm_cast_gated_kernels.cuh, prefix all constants that conflict with the now-unnamespaced tma_flow constants with ROCM_: ROCM_CHUNK_DIM_Y/X, ROCM_THREADS_PER_CHUNK, ROCM_THREADS_PER_CHUNK_X/Y, ROCM_BUFFERS_NUM, ROCM_BUFFER_DIM_Y/X, ROCM_SHMEM_DIM_Y/X, ROCM_BUFFER_STAGES_NUM, ROCM_ITERATIONS - Remove duplicate sigmoidf definition from rocm_cast_gated_kernels.cuh (already defined in cast_gated_kernels.cuh which includes it) Co-Authored-By: Claude Sonnet 4 --- 
.../common/util/cast_gated_kernels.cuh | 42 +++---- .../common/util/rocm_cast_gated_kernels.cuh | 110 +++++++++--------- 2 files changed, 72 insertions(+), 80 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 24ead840e..cbdc64a18 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -37,10 +37,6 @@ namespace transformer_engine { namespace gated_kernels { -// TMA/TDM flow constants — used on both NVIDIA (TMA) and AMD gfx1250 (TDM). -// On AMD, the ROCm flow constants live in rocm_cast_gated_kernels.cuh; -// these are in the tma_flow namespace to avoid collision. -namespace tma_flow { constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; @@ -401,8 +397,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #endif #endif // #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) } -} // namespace tma_flow - namespace mxfp8_kernel { constexpr size_t CHUNK_DIM_Y = 64; @@ -1181,14 +1175,14 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - const size_t blocks_Y = DIVUP(rows, tma_flow::CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, tma_flow::CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); float *const amax_ptr = reinterpret_cast(output->amax.dptr); float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); float *const scale_ptr = reinterpret_cast(output->scale.dptr); - const dim3 block_dim(tma_flow::THREADS_PER_CHUNK); + const dim3 block_dim(THREADS_PER_CHUNK); const dim3 grid_dim(blocks_X, blocks_Y); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( @@ -1196,7 +1190,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, - constexpr size_t buff_elems_total = tma_flow::BUFFERS_NUM * tma_flow::SHMEM_DIM_Y * tma_flow::SHMEM_DIM_X; + constexpr size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; const size_t buff_size_aligned_in = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); const size_t buff_size_aligned_out = @@ -1219,10 +1213,10 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu ? 
reinterpret_cast(output->data.dptr) + cols : nullptr; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - tma_flow::cast_fp8_gated_kernel, + cast_fp8_gated_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - tma_flow::cast_fp8_gated_kernel + cast_fp8_gated_kernel <<>>( grad_ptr, input_act_ptr, input_gate_ptr, output_act_ptr, output_gate_ptr, @@ -1236,27 +1230,27 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu alignas(64) CUtensorMap tensor_map_output_gate{}; if constexpr (IS_DGATED) { - create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, tma_flow::SHMEM_DIM_Y, - tma_flow::SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype())); } const uint32_t tensor_stride_elems = output_cols; - create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, tma_flow::SHMEM_DIM_Y, - tma_flow::SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, tma_flow::SHMEM_DIM_Y, - tma_flow::SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, tma_flow::SHMEM_DIM_Y, - tma_flow::SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, tma_flow::SHMEM_DIM_Y, - tma_flow::SHMEM_DIM_X, tensor_stride_elems, cols, + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); + create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, cols, typeToNumBits(output->dtype())); NVTE_CHECK_CUDA(cudaFuncSetAttribute( - tma_flow::cast_fp8_gated_kernel, + cast_fp8_gated_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - tma_flow::cast_fp8_gated_kernel + cast_fp8_gated_kernel <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh index 910fe9356..bc75d6e97 100644 --- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh @@ -23,27 +23,25 @@ namespace gated_kernels { constexpr size_t ALIGNMENT_SIZE = 128; // TODO: Identify optimal chunk/thread size for MI350+ -constexpr size_t CHUNK_DIM_Y = 64; -constexpr size_t CHUNK_DIM_X = 64; -constexpr size_t THREADS_PER_CHUNK = 256; -constexpr size_t THREADS_PER_CHUNK_X = 64; -constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 256 / 64 -constexpr size_t BUFFERS_NUM = 1; // No async load for HIP -constexpr size_t BUFFER_DIM_Y = 32; -constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 -constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 -constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 - -constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / 
THREADS_PER_CHUNK_Y; // 8 = 32 / 4 -constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 -static_assert(ITERATIONS >= 1); - -__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } +constexpr size_t ROCM_CHUNK_DIM_Y = 64; +constexpr size_t ROCM_CHUNK_DIM_X = 64; +constexpr size_t ROCM_THREADS_PER_CHUNK = 256; +constexpr size_t ROCM_THREADS_PER_CHUNK_X = 64; +constexpr size_t ROCM_THREADS_PER_CHUNK_Y = ROCM_THREADS_PER_CHUNK / ROCM_THREADS_PER_CHUNK_X; // 4 = 256 / 64 +constexpr size_t ROCM_BUFFERS_NUM = 1; // No async load for HIP +constexpr size_t ROCM_ROCM_BUFFER_DIM_Y = 32; +constexpr size_t ROCM_BUFFER_DIM_X = ROCM_CHUNK_DIM_X; // 64 +constexpr size_t ROCM_ROCM_SHMEM_DIM_Y = ROCM_ROCM_BUFFER_DIM_Y; // 32 +constexpr size_t ROCM_ROCM_SHMEM_DIM_X = ROCM_BUFFER_DIM_X; // 64 + +constexpr size_t ROCM_BUFFER_STAGES_NUM = ROCM_ROCM_BUFFER_DIM_Y / ROCM_THREADS_PER_CHUNK_Y; // 8 = 32 / 4 +constexpr size_t ROCM_ITERATIONS = ROCM_CHUNK_DIM_Y / ROCM_ROCM_BUFFER_DIM_Y; // 2 = 64 / 32 +static_assert(ROCM_ITERATIONS >= 1); template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) +__global__ void __launch_bounds__(ROCM_THREADS_PER_CHUNK) cast_mxfp8_gated_kernel(const IType *grad_ptr, const IType *input_act, const IType *input_gate, @@ -58,22 +56,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; constexpr bool COMPUTE_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; - constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = ROCM_CHUNK_DIM_Y; // 64 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = ROCM_CHUNK_DIM_X / SCALE_DIM_X; // 2 = 64 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = ROCM_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = ROCM_CHUNK_DIM_X; // 64 const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; - const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + const int chunk_offset_Y = blockIdx.y * ROCM_CHUNK_DIM_Y; + const int chunk_offset_X = blockIdx.x * ROCM_CHUNK_DIM_X; - const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + const int tid_Y = threadIdx.x / ROCM_THREADS_PER_CHUNK_X; + const int tid_X = threadIdx.x % ROCM_THREADS_PER_CHUNK_X; constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ?: 2) * 8 / sizeof(OType); @@ -88,8 +86,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); - const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_elems_total = BUFFERS_NUM * buff_elems; + const size_t buff_elems = ROCM_ROCM_SHMEM_DIM_Y * ROCM_ROCM_SHMEM_DIM_X; + const size_t buff_elems_total = ROCM_BUFFERS_NUM * buff_elems; const size_t buff_size_aligned_in = DIVUP(buff_elems_total * sizeof(IType), 
ALIGNMENT_SIZE) * ALIGNMENT_SIZE; const size_t buff_size_aligned_out = @@ -124,44 +122,44 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) reinterpret_cast(dshmem + grad_mem + in_mem + out_mem + out_act_mem); } - __shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y][CHUNK_DIM_X]; + __shared__ float stage_amax_sh[ROCM_THREADS_PER_CHUNK_Y][ROCM_CHUNK_DIM_X]; __syncthreads(); - for (int it = 0; it < ITERATIONS; it++) { - const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + for (int it = 0; it < ROCM_ITERATIONS; it++) { + const int chunk_it_offset_y = chunk_offset_Y + it * ROCM_ROCM_BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; const size_t row_base = chunk_it_offset_y; // Initiate bulk tensor copy if constexpr (IS_DGATED) { copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, - cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + cols, ROCM_ROCM_SHMEM_DIM_Y, ROCM_ROCM_SHMEM_DIM_X, rows, cols); } // Act copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, - 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); - + 2*cols, ROCM_ROCM_SHMEM_DIM_Y, ROCM_ROCM_SHMEM_DIM_X, rows, cols); + // Gate copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, - 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + 2*cols, ROCM_ROCM_SHMEM_DIM_Y, ROCM_ROCM_SHMEM_DIM_X, rows, cols); __syncthreads(); const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; - const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y; + const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * ROCM_BUFFER_DIM_Y; - float after_dact_reg[BUFFER_STAGES_NUM]; - float after_dgate_reg[BUFFER_STAGES_NUM]; + float after_dact_reg[ROCM_BUFFER_STAGES_NUM]; + float after_dgate_reg[ROCM_BUFFER_STAGES_NUM]; float thread_Y_mx_block_amax = 0.0f; float thread_Y_mx_block_amax_gate = 0.0f; - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + for (int stage = 0; stage < ROCM_BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * ROCM_THREADS_PER_CHUNK_Y; const int shmem_offset_y = thread_offset_Y + stage_offset_Y; const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + const int shmem_idx = shmem_offset_y * ROCM_SHMEM_DIM_X + shmem_offset_x; const size_t row = row_base + shmem_offset_y; const bool row_out_of_bounds = (row >= rows); @@ -264,7 +262,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); if (tid_Y == 0) { #pragma unroll - for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { + for (int y = 1; y < ROCM_THREADS_PER_CHUNK_Y; ++y) { thread_Y_mx_block_amax_gate = fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_X]); } @@ -294,11 +292,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + for (int stage = 0; stage < ROCM_BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * ROCM_THREADS_PER_CHUNK_Y; const int shmem_offset_y = thread_offset_Y + stage_offset_Y; const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + const int shmem_idx = shmem_offset_y * ROCM_ROCM_SHMEM_DIM_X + shmem_offset_x; out_gate_colwise_sh[shmem_idx] = static_cast(scale_reciprocal * after_dgate_reg[stage]); @@ -311,7 
+309,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); if (tid_Y == 0) { #pragma unroll - for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { + for (int y = 1; y < ROCM_THREADS_PER_CHUNK_Y; ++y) { thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_X]); } stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax; // write mx column-block amax @@ -340,11 +338,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + for (int stage = 0; stage < ROCM_BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * ROCM_THREADS_PER_CHUNK_Y; const int shmem_offset_y = thread_offset_Y + stage_offset_Y; const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + const int shmem_idx = shmem_offset_y * ROCM_ROCM_SHMEM_DIM_X + shmem_offset_x; out_act_colwise_sh[shmem_idx] = static_cast(scale_reciprocal * after_dact_reg[stage]); @@ -355,19 +353,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (USE_ROWWISE_SCALING) { bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, - chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + chunk_it_offset_y, output_cols, ROCM_SHMEM_DIM_Y, ROCM_SHMEM_DIM_X, rows, cols); if constexpr (IS_DGATED) { bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, - chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + chunk_it_offset_y, output_cols, ROCM_SHMEM_DIM_Y, ROCM_SHMEM_DIM_X, rows, cols); } } if constexpr (USE_COLWISE_SCALING) { bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, - chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + chunk_it_offset_y, output_cols, ROCM_SHMEM_DIM_Y, ROCM_SHMEM_DIM_X, rows, cols); if constexpr (IS_DGATED) { bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, - chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); + chunk_it_offset_y, output_cols, ROCM_SHMEM_DIM_Y, ROCM_SHMEM_DIM_X, rows, cols); } } __syncthreads(); @@ -388,10 +386,10 @@ void rocm_cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(rows, ROCM_CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, ROCM_CHUNK_DIM_X); const dim3 grid(blocks_X, blocks_Y); - const dim3 block_size(THREADS_PER_CHUNK); + const dim3 block_size(ROCM_THREADS_PER_CHUNK); size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; size_t scale_stride_colwise = USE_COLWISE_SCALING ? 
output->columnwise_scale_inv.shape[1] : 1; @@ -417,7 +415,7 @@ void rocm_cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor constexpr size_t input_type_bit_size = TypeInfo::size; constexpr size_t output_type_bit_size = TypeInfo::size; - const size_t buff_elems_total = BUFFERS_NUM * BUFFER_DIM_Y * BUFFER_DIM_X; + const size_t buff_elems_total = BUFFERS_NUM * ROCM_BUFFER_DIM_Y * BUFFER_DIM_X; const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; const size_t buff_size_aligned_in = From a0a9ab6a8ecc1a3d37a1f4dec18cb46d80588c29 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 11:13:17 -0500 Subject: [PATCH 19/43] util: apply ROCM_ prefix to ROCm-specific constants in cast/dequantize kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Same cleanup as cast_gated_kernels.cuh: prefix ROCm-flow constants with ROCM_ in rocm_cast_kernels.cuh and rocm_dequantize_kernels.cuh to disambiguate from TDM-flow constants sharing the same namespace. - rocm_cast_kernels.cuh: rename ELEMS_PER_THREAD, THREADS_PER_CHUNK_X_ROWWISE, THREADS_PER_CHUNK_Y_ROWWISE, THREADS_PER_CHUNK_X_COLWISE, TILE_DIM → ROCM_* - rocm_dequantize_kernels.cuh: rename all constants in dequantization namespace (CHUNK_DIM_Y/X, THREADS_PER_CHUNK, BUFFERS_NUM, ELEMS_PER_THREAD, BUFFER_DIM_Y/X, SHMEM_DIM_Y/X, THREADS_PER_CHUNK_X_*, ITERATIONS) → ROCM_* - dequantize_kernels.cuh: add TDM-flow constants directly into dequantization namespace (CHUNK_DIM_Y/X, THREADS_PER_CHUNK, BUFFERS_NUM, ELEMS_PER_THREAD, BUFFER_DIM_Y, SHMEM_DIM_Y/X, THREADS_PER_CHUNK_X_*, ITERATIONS) so the TDM/NV kernel is self-contained and no longer depends on the ROCm include Co-Authored-By: Claude Sonnet 4 --- .../common/util/dequantize_kernels.cuh | 14 +++ .../common/util/rocm_cast_kernels.cuh | 72 +++++++-------- .../common/util/rocm_dequantize_kernels.cuh | 91 ++++++++++--------- 3 files changed, 96 insertions(+), 81 deletions(-) diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index f543f2936..3d03bbaee 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -40,6 +40,20 @@ namespace transformer_engine { namespace dequantization { +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 128; +constexpr size_t BUFFERS_NUM = 2; + +constexpr size_t ELEMS_PER_THREAD = 16; +constexpr size_t BUFFER_DIM_Y = 16; +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 16 +constexpr size_t SHMEM_DIM_X = CHUNK_DIM_X; // 128 + +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; // 8 = 128 / 16 +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; // 128 +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16 + template __global__ void __launch_bounds__(THREADS_PER_CHUNK) dequantize_mxfp8_kernel( diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index 3ede0c74c..4b0927133 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -33,19 +33,19 @@ constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; constexpr size_t 
MXFP8_THREADS_PER_CHUNK = 64; -constexpr size_t ELEMS_PER_THREAD = 16; +constexpr size_t ROCM_ELEMS_PER_THREAD = 16; constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64 -constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = - MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 4 = 64 / 16 -constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = - MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4 -constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t ROCM_THREADS_PER_CHUNK_X_ROWWISE = + MXFP8_CHUNK_DIM_X / ROCM_ELEMS_PER_THREAD; // 4 = 64 / 16 +constexpr size_t ROCM_THREADS_PER_CHUNK_Y_ROWWISE = + MXFP8_THREADS_PER_CHUNK / ROCM_THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4 +constexpr size_t ROCM_THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 constexpr size_t MXFP8_BUFF_STAGES_NUM = - MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 + MXFP8_BUFFER_DIM_Y / ROCM_THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 template partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X]; + Vec partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X]; float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X]; if constexpr (IS_DBIAS) { if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { @@ -172,16 +172,16 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __syncthreads(); if constexpr (USE_ROWWISE_SCALING) { - Vec in; - Vec act_in; - Vec out_c; + Vec in; + Vec act_in; + Vec out_c; const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; #pragma unroll for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; stage++) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; + const int stage_offset_Y = stage * ROCM_THREADS_PER_CHUNK_Y_ROWWISE; const int shmem_offset_y = thread_offset_Y + stage_offset_Y; const int shmem_offset_x = thread_offset_X_rowwise; @@ -194,10 +194,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) } float thread_amax = 0; - float in_compute[ELEMS_PER_THREAD]; + float in_compute[ROCM_ELEMS_PER_THREAD]; #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { + for (int j = 0; j < ROCM_ELEMS_PER_THREAD; j++) { const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols); const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); @@ -246,7 +246,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { + for (int j = 0; j < ROCM_ELEMS_PER_THREAD; j++) { out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); } out_c.store_to(&out_rowwise_sh[shmem_offset_y][shmem_offset_x]); @@ -330,9 +330,9 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) if constexpr (IS_DBIAS) { if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X; - constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; - constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; - __shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD]; + constexpr size_t Y = ROCM_THREADS_PER_CHUNK_Y_ROWWISE - 1; + constexpr size_t X = ROCM_THREADS_PER_CHUNK_X_ROWWISE; + __shared__ float 
shmem_partial_dbias_rowwise[CZ][Y][X][ROCM_ELEMS_PER_THREAD]; if (tid_rowwise_Y > 0) { #pragma unroll @@ -346,18 +346,18 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) if (tid_rowwise_Y == 0) { #pragma unroll for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; c++) { - Vec other_row_dbias; + Vec other_row_dbias; const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X; const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X; const int left_bound = dbias_rowwise_offset_X; - const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1; + const int right_bound = dbias_rowwise_offset_X + ROCM_ELEMS_PER_THREAD - 1; #pragma unroll for (int i = 0; i < Y; i++) { other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]); #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { + for (int j = 0; j < ROCM_ELEMS_PER_THREAD; j++) { partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j]; } } @@ -410,13 +410,13 @@ template void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, hipStream_t stream); -constexpr size_t TILE_DIM = 32; +constexpr size_t ROCM_TILE_DIM = 32; template __global__ void partial_reduce_kernel(const DTypeReduce* input, float* partial_output, int rows, int cols) { - __shared__ float tile[TILE_DIM][TILE_DIM]; + __shared__ float tile[ROCM_TILE_DIM][ROCM_TILE_DIM]; - int tile_start_col = blockIdx.x * TILE_DIM; - int tile_start_row = blockIdx.y * TILE_DIM; + int tile_start_col = blockIdx.x * ROCM_TILE_DIM; + int tile_start_row = blockIdx.y * ROCM_TILE_DIM; int thread_col_in_tile = threadIdx.x; int thread_row_in_tile = threadIdx.y; @@ -430,7 +430,7 @@ __global__ void partial_reduce_kernel(const DTypeReduce* input, float* partial_o } __syncthreads(); - for (int stride = TILE_DIM / 2; stride > 0; stride /= 2) { + for (int stride = ROCM_TILE_DIM / 2; stride > 0; stride /= 2) { if (thread_row_in_tile < stride) { tile[thread_row_in_tile][thread_col_in_tile] += tile[thread_row_in_tile + stride][thread_col_in_tile]; } @@ -445,8 +445,8 @@ __global__ void partial_reduce_kernel(const DTypeReduce* input, float* partial_o template void reduce_dbias_rocm(const DTypeReduce *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, hipStream_t stream, Tensor* partial_sum_workspace) { - dim3 block_dim_partial(TILE_DIM, TILE_DIM); - dim3 grid_dim_partial(DIVUP(cols, TILE_DIM), DIVUP(rows, TILE_DIM)); + dim3 block_dim_partial(ROCM_TILE_DIM, ROCM_TILE_DIM); + dim3 grid_dim_partial(DIVUP(cols, ROCM_TILE_DIM), DIVUP(rows, ROCM_TILE_DIM)); const size_t partial_rows = grid_dim_partial.y; float* partial_workspace = reinterpret_cast(partial_sum_workspace->data.dptr); @@ -480,7 +480,7 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso NVTE_CHECK(workspace, "Workspace must be provided when IS_DBIAS is true."); if (workspace->data.dptr == nullptr) { if constexpr (IS_DACT) { - const size_t partial_rows = DIVUP(rows, TILE_DIM); + const size_t partial_rows = DIVUP(rows, ROCM_TILE_DIM); size_t total_elements = (rows * cols) + (partial_rows * cols); workspace->data.shape = {total_elements}; workspace->data.dtype = DType::kFloat32; @@ -505,7 +505,7 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso // The values to reduce are the result of the dAct function. 
NVTE_CHECK(act_input, "Gradient tensor must be provided for DBias + DACT."); - const size_t partial_rows = DIVUP(rows, TILE_DIM); + const size_t partial_rows = DIVUP(rows, ROCM_TILE_DIM); const size_t full_size_bytes = rows * cols * sizeof(float); workspace_buffer = *workspace; workspace_buffer.data.shape = {rows, cols}; diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh index 7e320dc82..5bb5f4671 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh @@ -25,24 +25,25 @@ namespace transformer_engine { namespace dequantization { -constexpr size_t CHUNK_DIM_Y = 128; -constexpr size_t CHUNK_DIM_X = 128; -constexpr size_t THREADS_PER_CHUNK = 128; -constexpr size_t BUFFERS_NUM = 2; - -constexpr size_t ELEMS_PER_THREAD = 16; -constexpr size_t BUFFER_DIM_Y = 16; // only 32 is supported -constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 -constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 16 -constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 - -constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; // 8 = 128 / 16 -constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; // 128 -constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16 -static_assert(ITERATIONS >= 1); +constexpr size_t ROCM_CHUNK_DIM_Y = 128; +constexpr size_t ROCM_CHUNK_DIM_X = 128; +constexpr size_t ROCM_THREADS_PER_CHUNK = 128; +constexpr size_t ROCM_BUFFERS_NUM = 2; + +constexpr size_t ROCM_ELEMS_PER_THREAD = 16; +constexpr size_t ROCM_BUFFER_DIM_Y = 16; +constexpr size_t ROCM_BUFFER_DIM_X = ROCM_CHUNK_DIM_X; // 128 +constexpr size_t ROCM_SHMEM_DIM_Y = ROCM_BUFFER_DIM_Y; // 16 +constexpr size_t ROCM_SHMEM_DIM_X = ROCM_BUFFER_DIM_X; // 128 + +constexpr size_t ROCM_THREADS_PER_CHUNK_X_ROWWISE = + ROCM_CHUNK_DIM_X / ROCM_ELEMS_PER_THREAD; // 8 = 128 / 16 +constexpr size_t ROCM_THREADS_PER_CHUNK_X_COLWISE = ROCM_CHUNK_DIM_X; // 128 +constexpr size_t ROCM_ITERATIONS = ROCM_CHUNK_DIM_Y / ROCM_BUFFER_DIM_Y; // 8 = 128 / 16 +static_assert(ROCM_ITERATIONS >= 1); template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) +__global__ void __launch_bounds__(ROCM_THREADS_PER_CHUNK) dequantize_mxfp8_kernel(const IType *input_ptr, OType *output_ptr, const e8m0_t *const scales_ptr, const size_t rows, const size_t cols, @@ -50,49 +51,49 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; - constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = ROCM_CHUNK_DIM_Y; // 128 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = ROCM_CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = ROCM_CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = ROCM_CHUNK_DIM_X; // 128 constexpr size_t THREADS_PER_SCALE_X_ROWWISE = - DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 + DIVUP(SCALE_DIM_X, ROCM_ELEMS_PER_THREAD); // 2 = 32 / 16 constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ?: 2) * 8 / sizeof(IType); - const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int 
chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + const int chunk_offset_Y = blockIdx.y * ROCM_CHUNK_DIM_Y; + const int chunk_offset_X = blockIdx.x * ROCM_CHUNK_DIM_X; const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; - const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; - const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; - // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; - const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + const int tid_rowwise_Y = threadIdx.x / ROCM_THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % ROCM_THREADS_PER_CHUNK_X_ROWWISE; + // const int tid_colwise_Y = threadIdx.x / ROCM_THREADS_PER_CHUNK_X_COLWISE; + const int tid_colwise_X = threadIdx.x % ROCM_THREADS_PER_CHUNK_X_COLWISE; const int thread_offset_Y = tid_rowwise_Y; - const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + const int thread_offset_X_rowwise = tid_rowwise_X * ROCM_ELEMS_PER_THREAD; // const int thread_offset_X_colwise = tid_colwise_X; // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - alignas(128) __shared__ IType in_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; - alignas(128) __shared__ OType out_sh[SHMEM_DIM_Y][SHMEM_DIM_X]; + alignas(128) __shared__ IType in_sh[ROCM_SHMEM_DIM_Y][ROCM_SHMEM_DIM_X]; + alignas(128) __shared__ OType out_sh[ROCM_SHMEM_DIM_Y][ROCM_SHMEM_DIM_X]; - for (int iter = 0; iter < ITERATIONS; iter++) { - const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; + for (int iter = 0; iter < ROCM_ITERATIONS; iter++) { + const int chunk_it_offset_y = chunk_offset_Y + iter * ROCM_BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, - chunk_it_offset_y, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, rows, cols); + chunk_it_offset_y, cols, ROCM_SHMEM_DIM_Y, + ROCM_SHMEM_DIM_X, rows, cols); __syncthreads(); const int scale_offset_Y = - USE_ROWWISE_SCALING ? (scales_rowwise_chunk_offset_Y + iter * BUFFER_DIM_Y + tid_rowwise_Y) - : (scales_colwise_chunk_offset_Y + (iter * BUFFER_DIM_Y) / SCALE_DIM_Y); + USE_ROWWISE_SCALING ? 
(scales_rowwise_chunk_offset_Y + iter * ROCM_BUFFER_DIM_Y + tid_rowwise_Y) + : (scales_colwise_chunk_offset_Y + (iter * ROCM_BUFFER_DIM_Y) / SCALE_DIM_Y); const int scale_offset_X = USE_ROWWISE_SCALING @@ -104,21 +105,21 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float block_scale = ptx::exp2f(biased_exponent); if constexpr (USE_ROWWISE_SCALING) { - Vec in; - Vec out; + Vec in; + Vec out; const int shmem_offset_y = thread_offset_Y; const int shmem_offset_x = thread_offset_X_rowwise; in.load_from(&in_sh[shmem_offset_y][shmem_offset_x]); #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; j++) { + for (int j = 0; j < ROCM_ELEMS_PER_THREAD; j++) { out.data.elt[j] = static_cast(block_scale * static_cast(in.data.elt[j])); } out.store_to(&out_sh[shmem_offset_y][shmem_offset_x]); } else { #pragma unroll - for (int i = 0; i < BUFFER_DIM_Y; i++) { + for (int i = 0; i < ROCM_BUFFER_DIM_Y; i++) { const float elt = static_cast(in_sh[i][tid_colwise_X]); out_sh[i][tid_colwise_X] = static_cast(block_scale * elt); } @@ -127,8 +128,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, - chunk_it_offset_y, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, rows, cols); + chunk_it_offset_y, cols, ROCM_SHMEM_DIM_Y, + ROCM_SHMEM_DIM_X, rows, cols); __syncthreads(); } @@ -146,8 +147,8 @@ static void rocm_mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStrea const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X); + const size_t chunks_Y = DIVUP(rows, ROCM_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, ROCM_CHUNK_DIM_X); const size_t scales_X_rowwise = DIVUP(cols, scale_dim_X_rowwise); const size_t scales_X_colwise = cols; @@ -160,7 +161,7 @@ static void rocm_mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStrea const SimpleTensor &input_data = use_rowwise_scaling ? input.data : input.columnwise_data; - const dim3 block(THREADS_PER_CHUNK); + const dim3 block(ROCM_THREADS_PER_CHUNK); const dim3 grid(chunks_X, chunks_Y); TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( From fdb4b1a21ad15dce7c05c1dc2a185952abb8ec30 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 11:43:14 -0500 Subject: [PATCH 20/43] util: address PR review comments on cast_kernels.cuh and dequantize_kernels.cuh - Remove AMD-only template params (SCALE_DIM_Y_TMPL, SCALE_DIM_X_TMPL, IS_ALIGNED) from cast_mxfp8_2D_kernel signature; pass raw pointers from launcher instead - Revert f16 amax computation to __hmax/__habs with thread_amax_f16 (NV upstream pattern) - Hoist next-stage offset variables above #ifndef to avoid duplicate declarations - Remove #ifndef guard around ptx::floatx2 block_scale_inverse_2x (works on HIP) - Fix __shared__ alignas(...) 
order for AMD shared memory declarations - Replace AMD TDM launcher with cast_gated_kernels.cuh raw-pointer pattern - Remove unnecessary TRANSFORMER_ENGINE_SWITCH_CONDITION IS_ALIGNED wrapper from AMD TDM dequantize launcher to match NV upstream structure Co-Authored-By: Claude Sonnet 4 --- .../common/util/cast_kernels.cuh | 125 +++++------------- .../common/util/dequantize_kernels.cuh | 20 +-- 2 files changed, 39 insertions(+), 106 deletions(-) diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 3e29ea609..38affbdff 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -53,11 +53,7 @@ constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 template + bool COLWISE_SCALING, size_t CHUNK_DIM_Y, size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK> __global__ void __launch_bounds__(THREADS_PER_CHUNK) cast_mxfp8_2D_kernel( #ifdef __HIP_PLATFORM_AMD__ @@ -241,16 +237,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t stage_offset_Y = stage * BUFF_DIM_Y; if (next_stage < STAGES) { -#ifndef __HIP_PLATFORM_AMD__ - // Wait for TMA transfer to have finished reading shared memory. - // I.e. the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - const size_t next_buff = next_stage % BUFFS_NUM; const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; const size_t global_offset_X = block_offset_X; const size_t next_buff_offset = next_buff * BUFF_DIM; +#ifndef __HIP_PLATFORM_AMD__ + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + if constexpr (IS_DACT) { copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, @@ -261,21 +257,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); } #else // __HIP_PLATFORM_AMD__ — TDM prefetch next stage - { - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t next_buff_offset = next_buff * BUFF_DIM; - if constexpr (IS_DACT) { - tdm::copy_2d_to_shared_x2(&in_sh[next_buff_offset], tmap_in, - global_offset_X, global_offset_Y, - &act_in_sh[next_buff_offset], tmap_act_in, - global_offset_X, global_offset_Y); - } else { - tdm::copy_2d_to_shared(&in_sh[next_buff_offset], tmap_in, - global_offset_X, global_offset_Y); - } + if constexpr (IS_DACT) { + tdm::copy_2d_to_shared_x2(&in_sh[next_buff_offset], tmap_in, + global_offset_X, global_offset_Y, + &act_in_sh[next_buff_offset], tmap_act_in, + global_offset_X, global_offset_Y); + } else { + tdm::copy_2d_to_shared(&in_sh[next_buff_offset], tmap_in, + global_offset_X, global_offset_Y); } #endif // __HIP_PLATFORM_AMD__ } @@ -311,12 +300,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 1. Read/Compute elements. 
Find MXFP8-block AMAX if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); #pragma unroll for (int i = 0; i < BUFF_DIM_Y; ++i) { const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - thread_amax = fmaxf(thread_amax, fabsf(static_cast(in_colwise_IType[i]))); + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); } + thread_amax = static_cast(thread_amax_f16); } else { #pragma unroll for (int i = 0; i < BUFF_DIM_Y; ++i) { @@ -367,9 +358,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) scales_colwise[scale_idx] = biased_exponent; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); -#ifndef __HIP_PLATFORM_AMD__ const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; -#endif // 3. Scale elements #pragma unroll @@ -413,7 +402,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } thread_amax = - fmaxf(fabsf(static_cast(thread_amax_2x.x)), fabsf(static_cast(thread_amax_2x.y))); + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); } else if constexpr (IS_CACHED_ACT_OP) { // ensures that all writes to cache made in the section above are visible to all threads __syncthreads(); @@ -448,7 +437,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } if constexpr (!std::is_same_v) { thread_amax = - fmaxf(fabsf(static_cast(thread_amax_2x.x)), fabsf(static_cast(thread_amax_2x.y))); + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); } } else { #pragma unroll @@ -715,11 +704,11 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned #ifdef __HIP_PLATFORM_AMD__ - alignas(TDM_SHMEM_ALIGNMENT) __shared__ + __shared__ alignas(TDM_SHMEM_ALIGNMENT) IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - alignas(TDM_SHMEM_ALIGNMENT) __shared__ + __shared__ alignas(TDM_SHMEM_ALIGNMENT) IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - alignas(TDM_SHMEM_ALIGNMENT) __shared__ + __shared__ alignas(TDM_SHMEM_ALIGNMENT) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; #else __shared__ alignas(TMA_SHMEM_ALIGNMENT) @@ -1327,66 +1316,16 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, #ifdef __HIP_PLATFORM_AMD__ - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - !(cols % (32 * sizeof(IType))), IS_ALIGNED, - { - // TDM flow — uses mxfp8_kernel::cast_mxfp8_2D_kernel - constexpr bool NV_ROWWISE = (SCALE_DIM_X > 1); - constexpr bool NV_COLWISE = (SCALE_DIM_Y > 1); - constexpr size_t NV_CAST_DBIAS_ONLY_Y = (IS_DBIAS && (!IS_DACT) && (!IS_ACT)) ? 
128 : 64; - constexpr size_t NV_CAST_DBIAS_ONLY_X = NV_CAST_DBIAS_ONLY_Y; - constexpr size_t NV_CAST_DBIAS_ONLY_T = NV_CAST_DBIAS_ONLY_Y; - - constexpr size_t NV_THREADS_X = NV_CAST_DBIAS_ONLY_X / mxfp8_kernel::SCALE_DIM_X; - constexpr size_t NV_THREADS_Y = NV_CAST_DBIAS_ONLY_T / NV_THREADS_X; - constexpr size_t NV_BUFF_DIM_Y = NV_THREADS_Y; - constexpr size_t NV_BUFF_DIM_X = NV_CAST_DBIAS_ONLY_X; - - constexpr size_t NV_SHMEM_ALIGNMENT = TDM_SHMEM_ALIGNMENT; - constexpr size_t nv_buff_elems = NV_BUFF_DIM_Y * NV_BUFF_DIM_X; - constexpr size_t nv_buff_elems_total = mxfp8_kernel::BUFFS_NUM * nv_buff_elems; - constexpr size_t nv_input_type_bit_size = TypeInfo::size; - constexpr size_t nv_output_type_bit_size = TypeInfo::size; - constexpr size_t nv_input_buff_size = (nv_buff_elems_total * nv_input_type_bit_size) / 8; - constexpr size_t nv_output_buff_size = (nv_buff_elems_total * nv_output_type_bit_size) / 8; - constexpr size_t nv_buff_size_aligned_in = - DIVUP_TO_MULTIPLE(nv_input_buff_size, NV_SHMEM_ALIGNMENT); - constexpr size_t nv_buff_size_aligned_out = - DIVUP_TO_MULTIPLE(nv_output_buff_size, NV_SHMEM_ALIGNMENT); - - constexpr size_t nv_elt_input_mem = nv_buff_size_aligned_in; - constexpr size_t nv_act_input_mem = (IS_DACT ? nv_buff_size_aligned_in : 0); - constexpr size_t nv_in_mem = nv_elt_input_mem + nv_act_input_mem; - - const size_t nv_out_rowwise_mem = (use_rowwise_scaling ? nv_buff_size_aligned_out : 0); - const size_t nv_out_colwise_mem = (use_colwise_scaling ? nv_buff_size_aligned_out : 0); - const size_t nv_out_mem = nv_out_rowwise_mem + nv_out_colwise_mem; - - const size_t nv_dshmem_size = nv_in_mem + nv_out_mem + NV_SHMEM_ALIGNMENT; - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, nv_dshmem_size)); - - mxfp8_kernel::cast_mxfp8_2D_kernel - <<>>( - reinterpret_cast(input.data.dptr), - (IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr, - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->columnwise_data.dptr), - scales_rowwise_ptr, scales_colwise_ptr, - reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - }))); // NOLINT(*) + const IType *tensor_map_input = reinterpret_cast(input.data.dptr); + const IType *tensor_map_act_input = + IS_DACT ? reinterpret_cast(act_input->data.dptr) : nullptr; + OType *tensor_map_output_rowwise = + use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; + OType *tensor_map_output_colwise = + use_colwise_scaling ? 
reinterpret_cast(output->columnwise_data.dptr) : nullptr; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; #else // #ifdef __HIP_PLATFORM_AMD__ alignas(64) CUtensorMap tensor_map_input{}; diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index 3d03bbaee..be13ce738 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -388,16 +388,12 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( output->dtype(), OType, #ifdef __HIP_PLATFORM_AMD__ - TRANSFORMER_ENGINE_SWITCH_CONDITION( - !(cols % (32 * sizeof(OType))), IS_ALIGNED, - { - // TDM flow — uses dequantization::dequantize_mxfp8_kernel - dequantization::dequantize_mxfp8_kernel - <<>>( - reinterpret_cast(input_data.dptr), - reinterpret_cast(output->data.dptr), - scales_ptr, rows, cols, scales_stride); - }); // NOLINT(*) + // TDM flow — uses dequantization::dequantize_mxfp8_kernel + dequantization::dequantize_mxfp8_kernel + <<>>( + reinterpret_cast(input_data.dptr), + reinterpret_cast(output->data.dptr), + scales_ptr, rows, cols, scales_stride); #else // #ifdef __HIP_PLATFORM_AMD__ alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_output{}; @@ -409,14 +405,12 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s dequantize_mxfp8_kernel <<>>(tensor_map_input, tensor_map_output, scales_ptr, - rows, cols, scales_stride);); // NOLINT(*) + rows, cols, scales_stride); #endif // #ifdef __HIP_PLATFORM_AMD__ ); // NOLINT(*) ); // NOLINT(*) ); // NOLINT(*) -#ifdef __HIP_PLATFORM_AMD__ ); // NOLINT(*) -#endif NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace dequantization From a732d3598e21d8ae6f650447a3aa8bc49033e308 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 12:08:23 -0500 Subject: [PATCH 21/43] util: address 4 more PR review comments on cast/dequantize kernels - Remove #ifdef guard from shmem declarations in cast_mxfp8_2D_kernel: TDM_SHMEM_ALIGNMENT == TMA_SHMEM_ALIGNMENT == 128, use TMA_SHMEM_ALIGNMENT throughout to minimize diff with NV upstream - Update fp8_quantize AMD path to check NVTE_USE_TDM_FLOW: if set, call fp8_quantize_arch_ge_100 (TDM kernel); otherwise fall back to fp8_quantize_rocm - Remove #ifdef and alignas swap in dequantize_mxfp8_kernel shmem declarations: use __shared__ alignas(TMA_SHMEM_ALIGNMENT) for both platforms - Replace NVTE_USE_NV_UPSTREAM_FLOW with NVTE_USE_TDM_FLOW in dequantize_helper AMD path to match cast_gated_kernels.cuh pattern Co-Authored-By: Claude Sonnet 4 --- .../common/util/cast_kernels.cuh | 25 ++++++++++--------- .../common/util/dequantize_kernels.cuh | 24 +++++++----------- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 38affbdff..5aa3fcec0 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -703,21 +703,12 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) const float scale = (scale_ptr != nullptr) ? 
*scale_ptr : 1; // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned -#ifdef __HIP_PLATFORM_AMD__ - __shared__ alignas(TDM_SHMEM_ALIGNMENT) - IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(TDM_SHMEM_ALIGNMENT) - IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(TDM_SHMEM_ALIGNMENT) - OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; -#else __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; -#endif #ifndef __HIP_PLATFORM_AMD__ constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; @@ -1636,9 +1627,19 @@ void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *no dbias, workspace, stream); } #else - // AMD - fp8_quantize_rocm(input, act_input, noop, output, - dbias, workspace, stream); + // On AMD gfx1250: NVTE_USE_TDM_FLOW=1 selects TDM kernel; default (0) uses ROCm flow. + static const bool use_tdm_flow = [] { + const char *env = std::getenv("NVTE_USE_TDM_FLOW"); + return env != nullptr && env[0] == '1' && env[1] == '\0'; + }(); + if (use_tdm_flow) { + fp8_quantize_arch_ge_100(input, act_input, noop, + output, dbias, workspace, + stream); + } else { + fp8_quantize_rocm(input, act_input, noop, output, + dbias, workspace, stream); + } #endif //#ifndef __HIP_PLATFORM_AMD__ } diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index be13ce738..d55dde265 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -96,13 +96,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // const int thread_offset_X_colwise = tid_colwise_X; // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned -#ifdef __HIP_PLATFORM_AMD__ - alignas(TDM_SHMEM_ALIGNMENT) __shared__ IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; - alignas(TDM_SHMEM_ALIGNMENT) __shared__ OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; -#else __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; -#endif constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM; #ifndef __HIP_PLATFORM_AMD__ @@ -425,16 +420,15 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) dequantization::fp8_dequantize(input, output, stream); } else if (is_mxfp_scaling(input.scaling_mode)) { #ifdef __HIP_PLATFORM_AMD__ - { - static const bool use_nv_upstream_flow = [] { - const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); - return env != nullptr && env[0] == '1' && env[1] == '\0'; - }(); - if (use_nv_upstream_flow) { - dequantization::mxfp8_dequantize(input, output, stream); - } else { - rocm_mxfp8_dequantize(input, output, stream); - } + // On AMD gfx1250: NVTE_USE_TDM_FLOW=1 selects TDM kernel; default (0) uses ROCm flow. 
+ static const bool use_tdm_flow = [] { + const char *env = std::getenv("NVTE_USE_TDM_FLOW"); + return env != nullptr && env[0] == '1' && env[1] == '\0'; + }(); + if (use_tdm_flow) { + dequantization::mxfp8_dequantize(input, output, stream); + } else { + rocm_mxfp8_dequantize(input, output, stream); } #else if (is_supported_by_CC_100()) { From 1a004dfc29c2aaa8cf0d27b3900cf0ad4c780ee4 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 16:08:56 -0500 Subject: [PATCH 22/43] util: hoist shared next-iter offset vars above #ifndef in cast_mxfp8_2D_kernel next_buff, chunk_it_offset_y, chunk_it_offset_x were duplicated in both the TMA and TDM prefetch branches. Hoist above #ifndef to declare once, matching the pattern from cast_gated_kernels.cuh. Co-Authored-By: Claude Sonnet 4 --- .../common/util/cast_kernels.cuh | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 5aa3fcec0..19004841b 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -779,10 +779,10 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; if (next_iter < FP8_ITERATIONS) { -#ifndef __HIP_PLATFORM_AMD__ const size_t next_buff = next_iter % FP8_BUFFERS_NUM; const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; const size_t chunk_it_offset_x = chunk_offset_X; +#ifndef __HIP_PLATFORM_AMD__ if constexpr (IS_DACT) { copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, @@ -793,19 +793,14 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); } #else // __HIP_PLATFORM_AMD__ — TDM prefetch next iteration - { - const size_t next_buff = next_iter % FP8_BUFFERS_NUM; - const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DACT) { - tdm::copy_2d_to_shared_x2(&in_sh[next_buff][0][0], fp8_tmap_in, - chunk_it_offset_x, chunk_it_offset_y, - &act_in_sh[next_buff][0][0], fp8_tmap_act_in, - chunk_it_offset_x, chunk_it_offset_y); - } else { - tdm::copy_2d_to_shared(&in_sh[next_buff][0][0], fp8_tmap_in, - chunk_it_offset_x, chunk_it_offset_y); - } + if constexpr (IS_DACT) { + tdm::copy_2d_to_shared_x2(&in_sh[next_buff][0][0], fp8_tmap_in, + chunk_it_offset_x, chunk_it_offset_y, + &act_in_sh[next_buff][0][0], fp8_tmap_act_in, + chunk_it_offset_x, chunk_it_offset_y); + } else { + tdm::copy_2d_to_shared(&in_sh[next_buff][0][0], fp8_tmap_in, + chunk_it_offset_x, chunk_it_offset_y); } #endif // __HIP_PLATFORM_AMD__ } From 573f8d7ae92b9fec45fc64ac06cb2b9324298f0e Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 16:13:05 -0500 Subject: [PATCH 23/43] Revert " Remove padding from scales for hipBLASlt calls (#442)" This reverts commit 8eafaa408601b2fd8f5a84173f8c3ff60b4d2ae9. 
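For reference, a minimal sketch of the padded MXFP8 scale-inv shapes this
revert restores (1x32 blocks; roundup here is illustrative shorthand for
TE's round_up_to_nearest_multiple / DIVUP_TO_MULTIPLE helpers, and the
65x96 numbers come from the dequantize test case fixed later in this
series):

    #include <cstddef>

    // Round n up to the nearest multiple of m (m > 0).
    static size_t roundup(size_t n, size_t m) { return ((n + m - 1) / m) * m; }

    // Rowwise scale-inv of an (M, K) tensor: one e8m0 scale per 1x32 block,
    // rows padded to a multiple of 128, scale columns (K/32) to a multiple of 4.
    void rowwise_scale_shape(size_t M, size_t K, size_t *s0, size_t *s1) {
      *s0 = roundup(M, 128);     // e.g. M = 65 -> 128
      *s1 = roundup(K / 32, 4);  // e.g. K = 96 -> 3 -> 4
    }

    // Columnwise counterpart: one scale per 32x1 block, (4, 128) alignment.
    void colwise_scale_shape(size_t M, size_t K, size_t *s0, size_t *s1) {
      *s0 = roundup(M / 32, 4);  // e.g. M = 65 -> 2 -> 4
      *s1 = roundup(K, 128);     // e.g. K = 96 -> 128
    }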
--- tests/cpp/operator/test_cublaslt_gemm.cu | 15 ++----- tests/cpp/test_common.cu | 20 ---------- tests/cpp/test_common.h | 7 ---- .../blockwise_quantizer_reference.py | 5 --- tests/pytorch/test_sanity.py | 2 +- transformer_engine/common/common.h | 10 +---- .../common/transformer_engine.cpp | 7 +--- .../common/util/rocm_cast_kernels.cuh | 8 +--- transformer_engine/jax/csrc/extensions/misc.h | 6 --- .../jax/quantize/scaling_modes.py | 9 +---- transformer_engine/pytorch/csrc/quantizer.cpp | 24 +---------- transformer_engine/pytorch/module/base.py | 4 +- .../pytorch/tensor/float8_blockwise_tensor.py | 18 ++------- .../pytorch/tensor/mxfp8_tensor.py | 40 ++++++------------- 14 files changed, 31 insertions(+), 144 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 4b0aabc4c..7c900570f 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -29,7 +29,6 @@ std::vector> test_case_sizes = { }; std::vector> test_case_sizes_mxfp8 = { - {32, 128, 16}, {768, 3072, 4096}, }; @@ -346,11 +345,8 @@ void performTest(const TestParams& params) { if (!has_fp8) { GTEST_SKIP() << "MXFP8 scaling mode requires Float8 types"; } - if (params.m % 16 || params.n % 16) { - GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16"; - } - if (params.k % 128) { - GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128"; + if (params.m % 32 != 0 || params.n % 32 != 0 || params.k % 32 != 0) { + GTEST_SKIP() << "MXFP8 requires M, N, K to be multiples of 32"; } } @@ -571,11 +567,8 @@ void performDqTest(const TestParams ¶ms) { GTEST_ASSERT_TRUE(isFp8Type(atype) && isFp8Type(btype)) << "FP8/BF8 input datatype is expected"; GTEST_ASSERT_FALSE(isFp8Type(dtype)) << "Non FP8/BF8 output datatype is expected"; - if (params.m % 16 || params.n % 16) { - GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16"; - } - if (params.k % 128) { - GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128"; + if (params.m % 32 != 0 || params.n % 32 != 0 || params.k % 32 != 0) { + GTEST_SKIP() << "MXFP8 requires M, N, K to be multiples of 32"; } cudaDeviceProp prop; diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 3ddd9047d..7a89148fd 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -147,11 +147,7 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; -#ifdef __HIP_PLATFORM_AMD__ - auto block_alignment = std::vector{1ul, 1ul}; -#else auto block_alignment = std::vector{128ul, 4ul}; -#endif { auto alignment = block_alignment[0]; auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(1)), alignment) * alignment; @@ -185,20 +181,12 @@ std::pair get_scales(const NVTEShape& shape, { auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); -#ifdef __HIP_PLATFORM_AMD__ - auto scale_dim_1 = DIVUP(last_dim, static_cast(128)); -#else auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(128)), 4) * 4; -#endif ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); -#ifdef __HIP_PLATFORM_AMD__ - auto scale_dim_1 = DIVUP(first_dim, static_cast(128)); -#else auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast(128)), 4) * 4; -#endif ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat32; @@ -219,20 +207,12 @@ std::pair get_scales(const NVTEShape& shape, { auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); -#ifdef __HIP_PLATFORM_AMD__ - auto scale_dim_1 = first_dim; -#else auto 
scale_dim_1 = DIVUP(first_dim, 4) * 4; -#endif ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); -#ifdef __HIP_PLATFORM_AMD__ - auto scale_dim_1 = last_dim; -#else auto scale_dim_1 = DIVUP(last_dim, 4) * 4; -#endif ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat32; diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 9e3c0f2a4..b824f8d4d 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -330,17 +330,10 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127; constexpr uint32_t FP32_MANTISSA_BITS = 23; // [128,4] rowwise and [4,128] colwise alignment requirement -#ifdef __HIP_PLATFORM_AMD__ -constexpr size_t scale_tensor_alignment_X_rowwise = 1; -constexpr size_t scale_tensor_alignment_Y_rowwise = 1; -constexpr size_t scale_tensor_alignment_X_colwise = 1; -constexpr size_t scale_tensor_alignment_Y_colwise = 1; -#else constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_rowwise = 128; constexpr size_t scale_tensor_alignment_X_colwise = 128; constexpr size_t scale_tensor_alignment_Y_colwise = 4; -#endif inline size_t divide_round_up(const size_t N, const size_t M) { return (N - 1 + M) / M; diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py index 9ffbf9452..1ce7d3e42 100644 --- a/tests/pytorch/references/blockwise_quantizer_reference.py +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -1,5 +1,3 @@ -# This file was modified for portability to AMDGPU -# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -10,7 +8,6 @@ from typing import Optional, Protocol, Tuple from references.quantize_scale_calc import scale_from_amax_tensor -from torch.utils.cpp_extension import IS_HIP_EXTENSION @dataclasses.dataclass() class QuantizeResult: @@ -39,8 +36,6 @@ def munge_scale_shapes_for_backend( def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor: if transpose: s = s.transpose(-1, -2).contiguous() - if IS_HIP_EXTENSION: # HIP does not use scale padding - return s M, K = s.shape if K % 4 == 0: return s diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 0c6b65329..a7d762c3d 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -41,7 +41,7 @@ Float8Quantizer, Float8Tensor, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.distributed import checkpoint from utils import ModelConfig diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 0fc26a05e..3c1a1ebed 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -715,23 +715,17 @@ template <> struct is_fp4 : std::true_type {}; #endif -#ifndef __HIP_PLATFORM_AMD__ // [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_rowwise = 128; constexpr size_t scale_tensor_alignment_X_colwise = 128; constexpr size_t scale_tensor_alignment_Y_colwise = 4; +#ifndef __HIP_PLATFORM_AMD__ // Alignment requirements for the Tensor Memory Accelerator (TMA) constexpr size_t TMA_GMEM_ALIGNMENT = 16; // global memory address alignment constexpr size_t TMA_SHMEM_ALIGNMENT = 128; // shared memory address alignment #else -// HIP does not use scale padding -constexpr size_t scale_tensor_alignment_X_rowwise = 1; -constexpr size_t scale_tensor_alignment_Y_rowwise = 1; -constexpr size_t scale_tensor_alignment_X_colwise = 1; -constexpr size_t scale_tensor_alignment_Y_colwise = 1; - // Alignment requirements for the Tensor Data Mover (TDM) on gfx1250 constexpr size_t TDM_SHMEM_ALIGNMENT = 128; // shared memory address alignment #endif diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 78af6061f..68d1f0ec5 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -98,13 +98,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { } else { if (t.scaling_mode == NVTE_MXFP8_1D_SCALING || t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) { -#ifndef __HIP_PLATFORM_AMD__ // Need (4, 128) alignment even for e8 scaling factor auto block_alignment = std::vector{128ul, 4ul}; -#else - // HIP does not use scale padding - auto block_alignment = std::vector{1ul, 1ul}; -#endif size_t expected_x, expected_y, alignment; const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 
32 : 16; const size_t block_size_colwise = 32; diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index 4b0927133..8eef22424 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -232,8 +232,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const e8m0_t biased_exponent = ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); // Only single thread writes the computed scaling factor - const bool col_out_of_bounds = dbias_rowwise_offset_X >= cols; - if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0 && !(row_out_of_bounds || col_out_of_bounds)) { + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { const int global_scales_offset_Y = iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y; const int global_scales_offset_X = @@ -297,10 +296,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - const bool row_out_of_bounds = row_base >= rows; - if (!(row_out_of_bounds || col_out_of_bounds)) { - scales_colwise[scale_idx] = biased_exponent; - } + scales_colwise[scale_idx] = biased_exponent; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); #pragma unroll diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index c71bb1306..af7f54feb 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -1,6 +1,4 @@ /************************************************************************* - * This file was modified for portability to AMDGPU - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -85,11 +83,7 @@ constexpr struct BlockSize { constexpr struct Alignment { size_t x; size_t y; -#ifndef __HIP_PLATFORM_AMD__ } MXFP8_ALIGNMENT{128, 4}; -#else -} MXFP8_ALIGNMENT{1, 1}; -#endif std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise); diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index deb2320eb..e81a614f0 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -1,5 +1,3 @@ -# This file was modified for portability to AMDGPU -# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -25,7 +23,6 @@ from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout from .device_utils import is_fp8_gemm_with_all_layouts_supported -from ..util import is_hip_extension __all__ = [ @@ -369,11 +366,7 @@ def __init__(self, block_dims: Tuple[int]): block_dims: Dimensions of the scaling blocks """ self._block_dims = block_dims - if is_hip_extension(): - self._block_alignment = (1, 1) - else: - self._block_alignment = (128, 4) - + self._block_alignment = (128, 4) def get_scale_dtype(self) -> jnp.dtype: """Get the data type for scale tensors in block scaling. 
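The 4-element inner-dim padding reinstated above also has a data-movement
side, shown by _pad_inner_to_align in the blockwise_quantizer_reference.py
hunk earlier in this patch. A rough C++ sketch of that transform (names
illustrative, not TE's; TE performs this on torch tensors):

    #include <cstddef>
    #include <vector>

    // Zero-pad the inner (last) dim of an M x K row-major scale matrix up to
    // a multiple of 4, producing the GEMM-ready layout the hunks below expect.
    std::vector<float> pad_inner_to_align(const std::vector<float> &s,
                                          size_t M, size_t K) {
      const size_t K_pad = ((K + 3) / 4) * 4;
      std::vector<float> out(M * K_pad, 0.0f);  // padded tail stays zero
      for (size_t i = 0; i < M; ++i)
        for (size_t j = 0; j < K; ++j)
          out[i * K_pad + j] = s[i * K + j];
      return out;
    }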
diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 9935f3c90..37c13362c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -842,21 +842,13 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vector Float8BlockQuantizer::get_scale_shape(const std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); -#ifdef __HIP_PLATFORM_AMD__ - return !columnwise - ? std::vector{numel / last_dim, last_dim / MXFP8_BLOCK_SIZE} - : std::vector{numel / last_dim / MXFP8_BLOCK_SIZE, last_dim}; -#else + std::vector scale_shape; bool rowwise_usage = !columnwise; + if (rowwise_usage) { // rowwise scaling factor shape size_t sinv0 = roundup(numel / last_dim, 128); @@ -1143,7 +1124,6 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s scale_shape = {sinv0, sinv1}; } return scale_shape; -#endif } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 745c61cad..3fb50af99 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -605,7 +605,7 @@ def fill_userbuffers_buffer_for_all_gather( comm.copy_into_buffer(local_data, local_chunk=True) # Gather scaling-inverses - if math.prod(local_shape[:-1]) % 128 != 0 and not IS_HIP_EXTENSION: + if math.prod(local_shape[:-1]) % 128 != 0: raise ValueError( "Userbuffers requires MXFP8 tensor dims that are divisible by 128, " f"but got MXFP8 tensor with shape={tuple(local_shape)}" diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index e51da3223..0e41fc9c5 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -1,5 +1,3 @@ -# This file was modified for portability to AMDGPU -# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -19,8 +17,6 @@ from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc from ..utils import devices_match, round_up_to_nearest_multiple -from torch.utils.cpp_extension import IS_HIP_EXTENSION - aten = torch.ops.aten @@ -141,17 +137,11 @@ def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, if self.block_scaling_dim == 2: if columnwise: outer = math.ceil(K / self.block_len) - if IS_HIP_EXTENSION: - inner = math.ceil(M / self.block_len) - else: - inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4) + inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4) return (outer, inner) # rowwise outer = math.ceil(M / self.block_len) - if IS_HIP_EXTENSION: - inner = math.ceil(K / self.block_len) - else: - inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4) + inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4) return (outer, inner) # 1D 1x128 quantization block scaling # CuBLAS requries 1x128 scaling factor to be padded and transposed @@ -159,7 +149,7 @@ def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, if columnwise: columnwise_compact = self.all_gather_usage outer = math.ceil(M / self.block_len) - inner = round_up_to_nearest_multiple(K, 4) if not IS_HIP_EXTENSION or not columnwise_compact else K + inner = round_up_to_nearest_multiple(K, 4) if not columnwise_compact else K # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS # for COMPACT case, since we apply 1x128 scaling here without transposing columnwise data, scaling factor is also [outer, inner] # so no need to swap inner outer here @@ -167,7 +157,7 @@ def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, # rowwise rowwise_compact = self.all_gather_usage outer = math.ceil(K / self.block_len) - inner = round_up_to_nearest_multiple(M, 4) if not IS_HIP_EXTENSION or not rowwise_compact else M + inner = round_up_to_nearest_multiple(M, 4) if not rowwise_compact else M # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS need # for COMPACT case, since we apply 128x1 scaling, scaling block applies to inner dim, so we need to swap outer and inner here return (outer, inner) if not rowwise_compact else (inner, outer) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 485787bc5..16b1568cb 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -113,20 +113,12 @@ def make_empty( # Allocate FP8 data data = torch.empty(shape, dtype=torch.uint8, device=device) # ROCm TE does not implement fuse padding zeros so use zero tensor here - if IS_HIP_EXTENSION: - scale_inv = torch.zeros( - math.prod(shape[:-1]), - math.ceil(shape[-1] / MXFP8_BLOCK_SCALING_SIZE), - dtype=torch.uint8, - device=device, - ) - else: - scale_inv = torch.empty( - round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), - round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), - dtype=torch.uint8, - device=device, - ) + scale_inv = torch.zeros( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + dtype=torch.uint8, + device=device, + ) # Allocate FP8 data transpose if needed columnwise_data = None @@ -134,20 +126,12 @@ def make_empty( if self.columnwise_usage: columnwise_data = torch.empty_like(data) # ROCm TE does not 
implement fuse padding zeros so use zero tensor here - if IS_HIP_EXTENSION: - columnwise_scale_inv = torch.zeros( - math.ceil(math.prod(shape[:-1]) / MXFP8_BLOCK_SCALING_SIZE), - shape[-1], - dtype=torch.uint8, - device=device, - ) - else: - columnwise_scale_inv = torch.empty( - round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), - round_up_to_nearest_multiple(shape[-1], 128), - dtype=torch.uint8, - device=device, - ) + columnwise_scale_inv = torch.zeros( + round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(shape[-1], 128), + dtype=torch.uint8, + device=device, + ) # Construct FP8 tensor return MXFP8Tensor( From fec2de59b788eae28a8e1416258ab04265489f7f Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 16:55:09 -0500 Subject: [PATCH 24/43] fix(rocm): correct double-prefixed constants in rocm_cast_gated_kernels.cuh ROCM_ rename was applied twice to BUFFER_DIM_Y, SHMEM_DIM_Y, SHMEM_DIM_X, creating ROCM_ROCM_* definitions while usages only had single ROCM_ prefix. Also BUFFERS_NUM and BUFFER_DIM_X at the shmem size calculation were never renamed to their ROCM_ equivalents. Co-Authored-By: Claude Sonnet 4 --- .../common/util/rocm_cast_gated_kernels.cuh | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh index bc75d6e97..d552059d5 100644 --- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh @@ -29,13 +29,13 @@ constexpr size_t ROCM_THREADS_PER_CHUNK = 256; constexpr size_t ROCM_THREADS_PER_CHUNK_X = 64; constexpr size_t ROCM_THREADS_PER_CHUNK_Y = ROCM_THREADS_PER_CHUNK / ROCM_THREADS_PER_CHUNK_X; // 4 = 256 / 64 constexpr size_t ROCM_BUFFERS_NUM = 1; // No async load for HIP -constexpr size_t ROCM_ROCM_BUFFER_DIM_Y = 32; +constexpr size_t ROCM_BUFFER_DIM_Y = 32; constexpr size_t ROCM_BUFFER_DIM_X = ROCM_CHUNK_DIM_X; // 64 -constexpr size_t ROCM_ROCM_SHMEM_DIM_Y = ROCM_ROCM_BUFFER_DIM_Y; // 32 -constexpr size_t ROCM_ROCM_SHMEM_DIM_X = ROCM_BUFFER_DIM_X; // 64 +constexpr size_t ROCM_SHMEM_DIM_Y = ROCM_BUFFER_DIM_Y; // 32 +constexpr size_t ROCM_SHMEM_DIM_X = ROCM_BUFFER_DIM_X; // 64 -constexpr size_t ROCM_BUFFER_STAGES_NUM = ROCM_ROCM_BUFFER_DIM_Y / ROCM_THREADS_PER_CHUNK_Y; // 8 = 32 / 4 -constexpr size_t ROCM_ITERATIONS = ROCM_CHUNK_DIM_Y / ROCM_ROCM_BUFFER_DIM_Y; // 2 = 64 / 32 +constexpr size_t ROCM_BUFFER_STAGES_NUM = ROCM_BUFFER_DIM_Y / ROCM_THREADS_PER_CHUNK_Y; // 8 = 32 / 4 +constexpr size_t ROCM_ITERATIONS = ROCM_CHUNK_DIM_Y / ROCM_BUFFER_DIM_Y; // 2 = 64 / 32 static_assert(ROCM_ITERATIONS >= 1); template (ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); - const size_t buff_elems = ROCM_ROCM_SHMEM_DIM_Y * ROCM_ROCM_SHMEM_DIM_X; + const size_t buff_elems = ROCM_SHMEM_DIM_Y * ROCM_SHMEM_DIM_X; const size_t buff_elems_total = ROCM_BUFFERS_NUM * buff_elems; const size_t buff_size_aligned_in = DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; @@ -127,23 +127,23 @@ __global__ void __launch_bounds__(ROCM_THREADS_PER_CHUNK) __syncthreads(); for (int it = 0; it < ROCM_ITERATIONS; it++) { - const int chunk_it_offset_y = chunk_offset_Y + it * ROCM_ROCM_BUFFER_DIM_Y; + const int chunk_it_offset_y = chunk_offset_Y + it * ROCM_BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; const size_t 
row_base = chunk_it_offset_y; // Initiate bulk tensor copy if constexpr (IS_DGATED) { copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, - cols, ROCM_ROCM_SHMEM_DIM_Y, ROCM_ROCM_SHMEM_DIM_X, rows, cols); + cols, ROCM_SHMEM_DIM_Y, ROCM_SHMEM_DIM_X, rows, cols); } // Act copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, - 2*cols, ROCM_ROCM_SHMEM_DIM_Y, ROCM_ROCM_SHMEM_DIM_X, rows, cols); + 2*cols, ROCM_SHMEM_DIM_Y, ROCM_SHMEM_DIM_X, rows, cols); // Gate copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, - 2*cols, ROCM_ROCM_SHMEM_DIM_Y, ROCM_ROCM_SHMEM_DIM_X, rows, cols); + 2*cols, ROCM_SHMEM_DIM_Y, ROCM_SHMEM_DIM_X, rows, cols); __syncthreads(); @@ -296,7 +296,7 @@ __global__ void __launch_bounds__(ROCM_THREADS_PER_CHUNK) const int stage_offset_Y = stage * ROCM_THREADS_PER_CHUNK_Y; const int shmem_offset_y = thread_offset_Y + stage_offset_Y; const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * ROCM_ROCM_SHMEM_DIM_X + shmem_offset_x; + const int shmem_idx = shmem_offset_y * ROCM_SHMEM_DIM_X + shmem_offset_x; out_gate_colwise_sh[shmem_idx] = static_cast(scale_reciprocal * after_dgate_reg[stage]); @@ -342,7 +342,7 @@ __global__ void __launch_bounds__(ROCM_THREADS_PER_CHUNK) const int stage_offset_Y = stage * ROCM_THREADS_PER_CHUNK_Y; const int shmem_offset_y = thread_offset_Y + stage_offset_Y; const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * ROCM_ROCM_SHMEM_DIM_X + shmem_offset_x; + const int shmem_idx = shmem_offset_y * ROCM_SHMEM_DIM_X + shmem_offset_x; out_act_colwise_sh[shmem_idx] = static_cast(scale_reciprocal * after_dact_reg[stage]); @@ -415,7 +415,7 @@ void rocm_cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor constexpr size_t input_type_bit_size = TypeInfo::size; constexpr size_t output_type_bit_size = TypeInfo::size; - const size_t buff_elems_total = BUFFERS_NUM * ROCM_BUFFER_DIM_Y * BUFFER_DIM_X; + const size_t buff_elems_total = ROCM_BUFFERS_NUM * ROCM_BUFFER_DIM_Y * ROCM_BUFFER_DIM_X; const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; const size_t buff_size_aligned_in = From 004d59f828fdb34d48945871d703d8035999026d Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 16:56:27 -0500 Subject: [PATCH 25/43] fix(rocm): add TMA_SHMEM_ALIGNMENT alias and sigmoidf for AMD compilation - common.h: add TMA_SHMEM_ALIGNMENT as alias for TDM_SHMEM_ALIGNMENT in AMD block so cast_gated_kernels.cuh launcher code compiles without ifdefs - rocm_cast_gated_kernels.cuh: define sigmoidf device inline since HIP runtime does not provide it (mirrors the CUDA definition in cast_gated_kernels.cuh) Co-Authored-By: Claude Sonnet 4 --- transformer_engine/common/common.h | 1 + transformer_engine/common/util/rocm_cast_gated_kernels.cuh | 3 +++ 2 files changed, 4 insertions(+) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 3c1a1ebed..1b9b68ea6 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -728,6 +728,7 @@ constexpr size_t TMA_SHMEM_ALIGNMENT = 128; // shared memory address alignment #else // Alignment requirements for the Tensor Data Mover (TDM) on gfx1250 constexpr size_t TDM_SHMEM_ALIGNMENT = 128; // shared memory address alignment +constexpr size_t TMA_SHMEM_ALIGNMENT = TDM_SHMEM_ALIGNMENT; // alias so shared launchers compile 
 #endif
 
 inline bool is_aligned_ptr(const void *ptr, size_t alignment) {
diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh
index d552059d5..ea1c32df6 100644
--- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh
+++ b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh
@@ -19,6 +19,9 @@
 #include "../utils.cuh"
 
 namespace transformer_engine {
+
+__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); }
+
 namespace gated_kernels {
 
 constexpr size_t ALIGNMENT_SIZE = 128;

From 7c86c98a93046e17bc7cf10a410ad5e0d5b925af Mon Sep 17 00:00:00 2001
From: Ye Wang
Date: Sun, 26 Apr 2026 16:59:44 -0500
Subject: [PATCH 26/43] fix(rocm): fix fp8_quantize AMD flow — remove
 unavailable fp8_quantize_arch_ge_100 call
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

fp8_quantize_arch_ge_100 is guarded by #ifndef __HIP_PLATFORM_AMD__ (NV
TMA only). The AMD branch should delegate entirely to fp8_quantize_rocm,
which internally dispatches to mxfp8_quantize (TDM path) or
rocm_mxfp8_quantize based on NVTE_USE_TDM_FLOW.

Also rename NVTE_USE_NV_UPSTREAM_FLOW to NVTE_USE_TDM_FLOW in
rocm_cast_kernels.cuh to match the unified env var.

Co-Authored-By: Claude Sonnet 4
---
 transformer_engine/common/util/cast_kernels.cuh | 16 +++-------------
 .../common/util/rocm_cast_kernels.cuh           |  6 +++---
 2 files changed, 6 insertions(+), 16 deletions(-)

diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh
index 19004841b..aafc6847a 100644
--- a/transformer_engine/common/util/cast_kernels.cuh
+++ b/transformer_engine/common/util/cast_kernels.cuh
@@ -1622,19 +1622,9 @@ void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *no
                                                        dbias, workspace, stream);
   }
 #else
-  // On AMD gfx1250: NVTE_USE_TDM_FLOW=1 selects TDM kernel; default (0) uses ROCm flow.
-  static const bool use_tdm_flow = [] {
-    const char *env = std::getenv("NVTE_USE_TDM_FLOW");
-    return env != nullptr && env[0] == '1' && env[1] == '\0';
-  }();
-  if (use_tdm_flow) {
-    fp8_quantize_arch_ge_100(input, act_input, noop,
-                             output, dbias, workspace,
-                             stream);
-  } else {
-    fp8_quantize_rocm(input, act_input, noop, output,
-                      dbias, workspace, stream);
-  }
+  // AMD: fp8_quantize_rocm internally checks NVTE_USE_TDM_FLOW to select TDM vs ROCm path.
+ fp8_quantize_rocm(input, act_input, noop, output, + dbias, workspace, stream); #endif //#ifndef __HIP_PLATFORM_AMD__ } diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index 8eef22424..b5cf592bb 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -549,11 +549,11 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso break; } case NVTE_MXFP8_1D_SCALING: { - static const bool use_nv_upstream_flow = [] { - const char *env = std::getenv("NVTE_USE_NV_UPSTREAM_FLOW"); + static const bool use_tdm_flow = [] { + const char *env = std::getenv("NVTE_USE_TDM_FLOW"); return env != nullptr && env[0] == '1' && env[1] == '\0'; }(); - if (use_nv_upstream_flow) { + if (use_tdm_flow) { mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); } else { From 0b40533d026572f6549a628e11d840a0a9d2df0d Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 21:42:41 -0500 Subject: [PATCH 27/43] fix(rocm): route NVTE_MXFP8_1D_SCALING through fp8_quantize_rocm on AMD The AMD section of mxfp8_quantize only sets up raw pointers and never launches a kernel, so calling it directly from quantize_helper left the scale buffer zero-initialized. fp8_quantize_rocm already has the correct TDM/plain-ROCm dispatch logic; route AMD through it instead. Fixes 1110 FusedCastMXFP8TestSuite failures on gfx950 (NVTE_USE_TDM_FLOW=0). --- transformer_engine/common/util/cast_kernels.cuh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index aafc6847a..4b18819ac 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1677,9 +1677,15 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o break; } case NVTE_MXFP8_1D_SCALING: { +#ifdef __HIP_PLATFORM_AMD__ + fp8_quantize_rocm( + *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); +#else mxfp8_quantize( *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); +#endif break; } #ifndef __HIP_PLATFORM_AMD__ From c89b5ff2f20341db2ac6e3925345bc04880d0c57 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Sun, 26 Apr 2026 21:52:26 -0500 Subject: [PATCH 28/43] fix(rocm): use padded scales_stride in rocm_mxfp8_dequantize The rowwise scale tensor is allocated with stride padded to scale_tensor_alignment_X_rowwise (4), but rocm_mxfp8_dequantize was computing scales_stride = DIVUP(cols, 32) (unpadded). From row 1 onward the kernel reads the wrong scale, producing inf/garbage output. Fix: use DIVUP_TO_MULTIPLE(..., scale_tensor_alignment_X_rowwise), matching the allocation in the test harness and the NV dequantize path. Fixes 6 DequantizeMXFP8TestSuite failures (65x96, block_size=(1,32)) on gfx950. 
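To make the stride mismatch concrete, a small standalone sketch with the
failing case's numbers (divup/divup_to_multiple stand in for TE's DIVUP
and DIVUP_TO_MULTIPLE macros):

    #include <cstddef>
    #include <cstdio>

    static size_t divup(size_t a, size_t b) { return (a + b - 1) / b; }
    static size_t divup_to_multiple(size_t a, size_t m) { return divup(a, m) * m; }

    int main() {
      const size_t cols = 96, block = 32, align = 4;             // the 65x96 case
      const size_t unpadded = divup(cols, block);                // 3
      const size_t padded = divup_to_multiple(unpadded, align);  // 4
      // The scale tensor is allocated with the padded stride, so indexing
      // row r with the unpadded stride reads the wrong scale from row 1 on.
      const size_t r = 1, c = 0;
      printf("buggy idx = %zu, fixed idx = %zu\n", r * unpadded + c, r * padded + c);
      return 0;
    }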
---
 transformer_engine/common/util/rocm_dequantize_kernels.cuh | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh
index 5bb5f4671..17596a67b 100644
--- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh
+++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh
@@ -150,7 +150,8 @@ static void rocm_mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStrea
   const size_t chunks_Y = DIVUP(rows, ROCM_CHUNK_DIM_Y);
   const size_t chunks_X = DIVUP(cols, ROCM_CHUNK_DIM_X);
 
-  const size_t scales_X_rowwise = DIVUP(cols, scale_dim_X_rowwise);
+  const size_t scales_X_rowwise =
+      DIVUP_TO_MULTIPLE(DIVUP(cols, scale_dim_X_rowwise), scale_tensor_alignment_X_rowwise);
   const size_t scales_X_colwise = cols;
 
   const e8m0_t *const scales_ptr =

From 9f55a8baf3eb548740db8bc69cf32fef7d9321fc Mon Sep 17 00:00:00 2001
From: Ye Wang
Date: Sun, 26 Apr 2026 22:52:15 -0500
Subject: [PATCH 29/43] fix(rocm): guard TDM flow dispatch behind __gfx1250__ on AMD
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The NVTE_USE_TDM_FLOW=1 branches in rocm_cast_kernels.cuh,
cast_gated_kernels.cuh, and dequantize_kernels.cuh called TDM/TMA kernel
paths (mxfp8_quantize, cast_mxfp8_gated, mxfp8_dequantize) that are no-ops
on non-gfx1250 AMD — their device code is wrapped in
#if defined(__gfx1250__) so nothing executes, leaving scales at zero.

Wrap the TDM flow selection in
#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx1250__), falling back to
the plain ROCm kernels (rocm_mxfp8_quantize, rocm_cast_mxfp8_gated,
rocm_mxfp8_dequantize) on all other AMD architectures.

With this change, all 2748 tests pass with NVTE_USE_TDM_FLOW=1 on gfx950.
---
 .../common/util/cast_gated_kernels.cuh | 16 ++++++++++++----
 .../common/util/dequantize_kernels.cuh |  6 ++++--
 .../common/util/rocm_cast_kernels.cuh  |  5 +++++
 3 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh
index cbdc64a18..33e4c73c8 100644
--- a/transformer_engine/common/util/cast_gated_kernels.cuh
+++ b/transformer_engine/common/util/cast_gated_kernels.cuh
@@ -1566,8 +1566,8 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
 
   if (is_delayed_tensor_scaling(output->scaling_mode)) {
     if (use_tma_kernels) {
-#ifdef __HIP_PLATFORM_AMD__
-      // On AMD gfx1250: NVTE_USE_TDM_FLOW=1 selects TDM kernel; default (0) uses ROCm flow.
+#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx1250__)
+      // On gfx1250: NVTE_USE_TDM_FLOW=1 selects TDM kernel; default (0) uses ROCm flow.
static const bool use_tdm_flow_fp8 = [] { const char *env = std::getenv("NVTE_USE_TDM_FLOW"); return env != nullptr && env[0] == '1' && env[1] == '\0'; @@ -1581,6 +1581,12 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cast_gated(gated_input, output, stream); } } +#elif defined(__HIP_PLATFORM_AMD__) + if constexpr (IS_DGATED) { + cast_dgated(grad, gated_input, output, stream); + } else { + cast_gated(gated_input, output, stream); + } #else cast_fp8_gated(grad, gated_input, output, stream); #endif @@ -1593,8 +1599,8 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu } } else if (is_mxfp_scaling(output->scaling_mode)) { if (use_tma_kernels) { -#ifdef __HIP_PLATFORM_AMD__ - // On AMD gfx1250: NVTE_USE_TDM_FLOW=1 selects TDM kernel; default (0) uses ROCm flow. +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx1250__) + // On gfx1250: NVTE_USE_TDM_FLOW=1 selects TDM kernel; default (0) uses ROCm flow. static const bool use_tdm_flow = [] { const char *env = std::getenv("NVTE_USE_TDM_FLOW"); return env != nullptr && env[0] == '1' && env[1] == '\0'; @@ -1604,6 +1610,8 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu } else { rocm_cast_mxfp8_gated(grad, gated_input, output, stream); } +#elif defined(__HIP_PLATFORM_AMD__) + rocm_cast_mxfp8_gated(grad, gated_input, output, stream); #else cast_mxfp8_gated(grad, gated_input, output, stream); #endif diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index d55dde265..b741ae06f 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -419,8 +419,8 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) if (is_tensor_scaling(input.scaling_mode)) { dequantization::fp8_dequantize(input, output, stream); } else if (is_mxfp_scaling(input.scaling_mode)) { -#ifdef __HIP_PLATFORM_AMD__ - // On AMD gfx1250: NVTE_USE_TDM_FLOW=1 selects TDM kernel; default (0) uses ROCm flow. +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx1250__) + // On gfx1250: NVTE_USE_TDM_FLOW=1 selects TDM kernel; default (0) uses ROCm flow. 
static const bool use_tdm_flow = [] { const char *env = std::getenv("NVTE_USE_TDM_FLOW"); return env != nullptr && env[0] == '1' && env[1] == '\0'; @@ -430,6 +430,8 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) } else { rocm_mxfp8_dequantize(input, output, stream); } +#elif defined(__HIP_PLATFORM_AMD__) + rocm_mxfp8_dequantize(input, output, stream); #else if (is_supported_by_CC_100()) { dequantization::mxfp8_dequantize(input, output, stream); diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index b5cf592bb..4940bb9ec 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -549,6 +549,7 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso break; } case NVTE_MXFP8_1D_SCALING: { +#if defined(__gfx1250__) static const bool use_tdm_flow = [] { const char *env = std::getenv("NVTE_USE_TDM_FLOW"); return env != nullptr && env[0] == '1' && env[1] == '\0'; @@ -560,6 +561,10 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso rocm_mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); } +#else + rocm_mxfp8_quantize(input, act_input, noop, output, + dbias, workspace, stream); +#endif break; } default: From 833872567d3853d0191c196a60bc34869077e6ea Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Mon, 27 Apr 2026 12:36:08 -0500 Subject: [PATCH 30/43] fix(rocm): wire up cast_mxfp8_2D_kernel launch on gfx1250 TDM path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The AMD section of mxfp8_quantize set up raw pointers but never launched the kernel — all TDM quantize calls were silent no-ops. Add the kernel launch switch (ROWWISE/COLWISE/BIDIMENSIONAL) mirroring the NV path, using raw pointers and TDM shared-memory sizing. Also fix host-side TDM dispatch guards: replace device-only __gfx1250__ with CMake-injected NVTE_ARCH_HAS_TDM (visible to host compilation) plus a runtime cuda::sm_arch_name() check, matching the ARCH_HAS_STOCHASTIC_ROUNDING pattern from PR #472. This ensures gfx942/950-only builds compile cleanly and multi-arch builds running on non-gfx1250 hardware fall back to the ROCm path even when NVTE_USE_TDM_FLOW=1. Add debug fprintf/printf traces across all dispatch and kernel entry points to confirm which code path executes at runtime. Co-Authored-By: Claude Sonnet 4 --- transformer_engine/common/CMakeLists.txt | 7 ++ .../common/util/cast_gated_kernels.cuh | 27 ++++++-- .../common/util/cast_kernels.cuh | 67 +++++++++++++++++-- .../common/util/dequantize_kernels.cuh | 14 +++- .../common/util/rocm_cast_kernels.cuh | 15 ++++- .../common/util/rocm_dequantize_kernels.cuh | 6 ++ 6 files changed, 119 insertions(+), 17 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 58f5365e5..43c512c6a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -86,6 +86,13 @@ else() SET(CMAKE_HIP_ARCHITECTURES $ENV{NVTE_ROCM_ARCH}) endif() + # Set NVTE_ARCH_HAS_TDM when building for gfx1250, which has the Tensor Data Mover (TDM). + # This define is visible to both host and device compilation, unlike __gfx1250__ which is + # device-only. Used to gate TDM dispatch in host functions. 
+ if("gfx1250" IN_LIST CMAKE_HIP_ARCHITECTURES) + add_definitions(-DNVTE_ARCH_HAS_TDM) + endif() + set(CMAKE_CXX_STANDARD 17) project(transformer_engine LANGUAGES HIP CXX) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 33e4c73c8..3262d2118 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -24,6 +24,7 @@ #include #include "../common.h" +#include "cuda_runtime.h" #include "../util/vectorized_pointwise.h" #include "../utils.cuh" #include "math.h" @@ -75,6 +76,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float *const amax_ptr, float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, const size_t cols) { #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) + if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { + printf("[DBG cast_fp8_gated_kernel] TDM kernel executing rows=%zu cols=%zu\n", + (size_t)rows, (size_t)cols); + } #ifdef __HIP_PLATFORM_AMD__ // TDM needs explicit strides. For gated inputs, act and gate are interleaved → stride = 2*cols. @@ -448,6 +453,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) + if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { + printf("[DBG cast_mxfp8_gated_kernel] TDM kernel executing rows=%zu cols=%zu\n", + (size_t)rows, (size_t)cols); + } using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; @@ -1566,15 +1575,17 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu if (is_delayed_tensor_scaling(output->scaling_mode)) { if (use_tma_kernels) { -#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx1250__) - // On gfx1250: NVTE_USE_TDM_FLOW=1 selects TDM kernel; default (0) uses ROCm flow. +#if defined(__HIP_PLATFORM_AMD__) && defined(NVTE_ARCH_HAS_TDM) static const bool use_tdm_flow_fp8 = [] { const char *env = std::getenv("NVTE_USE_TDM_FLOW"); - return env != nullptr && env[0] == '1' && env[1] == '\0'; + return env != nullptr && env[0] == '1' && env[1] == '\0' && + cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow_fp8) { + fprintf(stderr, "[DBG gated delayed_scaling] gfx1250 TDM -> cast_fp8_gated\n"); cast_fp8_gated(grad, gated_input, output, stream); } else { + fprintf(stderr, "[DBG gated delayed_scaling] gfx1250 ROCm -> cast_gated/cast_dgated\n"); if constexpr (IS_DGATED) { cast_dgated(grad, gated_input, output, stream); } else { @@ -1582,6 +1593,7 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu } } #elif defined(__HIP_PLATFORM_AMD__) + fprintf(stderr, "[DBG gated delayed_scaling] non-gfx1250 AMD -> cast_gated/cast_dgated\n"); if constexpr (IS_DGATED) { cast_dgated(grad, gated_input, output, stream); } else { @@ -1599,18 +1611,21 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu } } else if (is_mxfp_scaling(output->scaling_mode)) { if (use_tma_kernels) { -#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx1250__) - // On gfx1250: NVTE_USE_TDM_FLOW=1 selects TDM kernel; default (0) uses ROCm flow. 
+#if defined(__HIP_PLATFORM_AMD__) && defined(NVTE_ARCH_HAS_TDM) static const bool use_tdm_flow = [] { const char *env = std::getenv("NVTE_USE_TDM_FLOW"); - return env != nullptr && env[0] == '1' && env[1] == '\0'; + return env != nullptr && env[0] == '1' && env[1] == '\0' && + cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { + fprintf(stderr, "[DBG gated mxfp_scaling] gfx1250 TDM -> cast_mxfp8_gated\n"); cast_mxfp8_gated(grad, gated_input, output, stream); } else { + fprintf(stderr, "[DBG gated mxfp_scaling] gfx1250 ROCm -> rocm_cast_mxfp8_gated\n"); rocm_cast_mxfp8_gated(grad, gated_input, output, stream); } #elif defined(__HIP_PLATFORM_AMD__) + fprintf(stderr, "[DBG gated mxfp_scaling] non-gfx1250 AMD -> rocm_cast_mxfp8_gated\n"); rocm_cast_mxfp8_gated(grad, gated_input, output, stream); #else cast_mxfp8_gated(grad, gated_input, output, stream); diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 4b18819ac..a195b2465 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -72,6 +72,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) + if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { + printf("[DBG cast_mxfp8_2D_kernel] TDM kernel executing rows=%zu cols=%zu\n", + (size_t)rows, (size_t)cols); + } constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -1247,12 +1251,10 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64; -#ifndef __HIP_PLATFORM_AMD__ constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; constexpr size_t BUFF_DIM_Y = THREADS_Y; constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; -#endif const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); @@ -1270,7 +1272,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const size_t dbias_rows = blocks_Y; const size_t dbias_cols = cols; -#ifndef __HIP_PLATFORM_AMD__ ScalingType scaling_type; if (use_rowwise_scaling && (!use_colwise_scaling)) { scaling_type = ScalingType::ROWWISE; @@ -1279,7 +1280,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, } else if (use_rowwise_scaling && use_colwise_scaling) { scaling_type = ScalingType::BIDIMENSIONAL; } -#endif if constexpr (IS_DBIAS) { NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); @@ -1310,8 +1310,63 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, OType *tensor_map_output_colwise = use_colwise_scaling ? 
reinterpret_cast(output->columnwise_data.dptr) : nullptr; - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TDM_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TDM_SHMEM_ALIGNMENT); + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t dshmem_size = in_mem + out_rowwise_mem + out_colwise_mem + TDM_SHMEM_ALIGNMENT; + + switch (scaling_type) { + case ScalingType::ROWWISE: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + case ScalingType::COLWISE: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + case ScalingType::BIDIMENSIONAL: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } #else // #ifdef __HIP_PLATFORM_AMD__ alignas(64) CUtensorMap tensor_map_input{}; diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index b741ae06f..f5d77d44d 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -24,6 +24,7 @@ #include #include "../common.h" +#include "cuda_runtime.h" #include "../transpose/cast_transpose.h" #include "../util/vectorized_pointwise.h" #include "../utils.cuh" @@ -67,6 +68,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const e8m0_t *const scales_ptr, const size_t rows, const size_t cols, const size_t scales_stride) { #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) + if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { + printf("[DBG mxfp8_dequantize TDM kernel] executing rows=%zu cols=%zu\n", + (size_t)rows, (size_t)cols); + } constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 @@ -419,18 +424,21 @@ void 
dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) if (is_tensor_scaling(input.scaling_mode)) { dequantization::fp8_dequantize(input, output, stream); } else if (is_mxfp_scaling(input.scaling_mode)) { -#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx1250__) - // On gfx1250: NVTE_USE_TDM_FLOW=1 selects TDM kernel; default (0) uses ROCm flow. +#if defined(__HIP_PLATFORM_AMD__) && defined(NVTE_ARCH_HAS_TDM) static const bool use_tdm_flow = [] { const char *env = std::getenv("NVTE_USE_TDM_FLOW"); - return env != nullptr && env[0] == '1' && env[1] == '\0'; + return env != nullptr && env[0] == '1' && env[1] == '\0' && + cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { + fprintf(stderr, "[DBG dequantize_helper] gfx1250 TDM -> mxfp8_dequantize\n"); dequantization::mxfp8_dequantize(input, output, stream); } else { + fprintf(stderr, "[DBG dequantize_helper] gfx1250 ROCm -> rocm_mxfp8_dequantize\n"); rocm_mxfp8_dequantize(input, output, stream); } #elif defined(__HIP_PLATFORM_AMD__) + fprintf(stderr, "[DBG dequantize_helper] non-gfx1250 AMD -> rocm_mxfp8_dequantize\n"); rocm_mxfp8_dequantize(input, output, stream); #else if (is_supported_by_CC_100()) { diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index 4940bb9ec..aa164e916 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -9,6 +9,7 @@ #include #include "../common.h" +#include "cuda_runtime.h" #include "math.h" #include "ptx.cuh" #include "rocm_vectorized_2d.cuh" @@ -63,6 +64,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { if (noop != nullptr && noop[0] == 1.0f) return; } + if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { + printf("[DBG cast_mxfp8_2D_kernel ROCm] plain ROCm kernel executing rows=%zu cols=%zu\n", + (size_t)rows, (size_t)cols); + } constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; @@ -549,19 +554,23 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso break; } case NVTE_MXFP8_1D_SCALING: { -#if defined(__gfx1250__) +#ifdef NVTE_ARCH_HAS_TDM static const bool use_tdm_flow = [] { const char *env = std::getenv("NVTE_USE_TDM_FLOW"); - return env != nullptr && env[0] == '1' && env[1] == '\0'; + return env != nullptr && env[0] == '1' && env[1] == '\0' && + cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { + fprintf(stderr, "[DBG fp8_quantize_rocm] gfx1250 TDM branch -> mxfp8_quantize\n"); mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); } else { + fprintf(stderr, "[DBG fp8_quantize_rocm] gfx1250 ROCm branch -> rocm_mxfp8_quantize\n"); rocm_mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); } #else + fprintf(stderr, "[DBG fp8_quantize_rocm] non-gfx1250 AMD -> rocm_mxfp8_quantize\n"); rocm_mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); #endif @@ -582,6 +591,8 @@ void rocm_mxfp8_quantize(const Tensor &input, const Tensor *act_input, const Ten const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); + fprintf(stderr, "[DBG rocm_mxfp8_quantize] rows=%zu cols=%zu — launching cast_mxfp8_2D_kernel\n", + rows, cols); const size_t blocks_Y = DIVUP(rows, 
MXFP8_CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh index 17596a67b..29351b502 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh @@ -48,6 +48,10 @@ __global__ void __launch_bounds__(ROCM_THREADS_PER_CHUNK) OType *output_ptr, const e8m0_t *const scales_ptr, const size_t rows, const size_t cols, const size_t scales_stride) { + if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { + printf("[DBG dequantize_mxfp8_kernel ROCm] plain ROCm kernel executing rows=%zu cols=%zu\n", + (size_t)rows, (size_t)cols); + } constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; @@ -147,6 +151,8 @@ static void rocm_mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStrea const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); + fprintf(stderr, "[DBG rocm_mxfp8_dequantize] rows=%zu cols=%zu — launching dequantize_mxfp8_kernel\n", + rows, cols); const size_t chunks_Y = DIVUP(rows, ROCM_CHUNK_DIM_Y); const size_t chunks_X = DIVUP(cols, ROCM_CHUNK_DIM_X); From 14329d5d5c6d74bf26bd69f5054872347e5db65c Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Mon, 27 Apr 2026 12:42:33 -0500 Subject: [PATCH 31/43] refactor(rocm): consolidate mxfp8_quantize kernel launch for TDM and TMA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The switch-case for cast_mxfp8_2D_kernel is identical on AMD (TDM) and NV (TMA) — only the first four args differ (raw pointers vs CUtensorMap). Move the shared dshmem sizing and switch-case after the #ifdef block so there is a single launch path. The #ifdef now only covers platform-specific setup: raw pointer casts on AMD, create_2D_tensor_map descriptors on NV. TMA_SHMEM_ALIGNMENT is aliased to TDM_SHMEM_ALIGNMENT (both 128) so the shmem calculation is correct on both platforms without a separate formula. Co-Authored-By: Claude Sonnet 4 --- .../common/util/cast_kernels.cuh | 85 +++---------------- 1 file changed, 10 insertions(+), 75 deletions(-) diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index a195b2465..0a02083e7 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1302,6 +1302,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, #ifdef __HIP_PLATFORM_AMD__ + // AMD (TDM): pass raw pointers directly to the kernel const IType *tensor_map_input = reinterpret_cast(input.data.dptr); const IType *tensor_map_act_input = IS_DACT ? reinterpret_cast(act_input->data.dptr) : nullptr; @@ -1309,66 +1310,8 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; OType *tensor_map_output_colwise = use_colwise_scaling ? 
reinterpret_cast(output->columnwise_data.dptr) : nullptr; - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TDM_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TDM_SHMEM_ALIGNMENT); - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); - const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); - const size_t dshmem_size = in_mem + out_rowwise_mem + out_colwise_mem + TDM_SHMEM_ALIGNMENT; - - switch (scaling_type) { - case ScalingType::ROWWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::COLWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::BIDIMENSIONAL: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - } -#else // #ifdef __HIP_PLATFORM_AMD__ - +#else + // NV (TMA): build descriptor objects and register tensor layouts alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_act_input{}; alignas(64) CUtensorMap tensor_map_output_rowwise{}; @@ -1394,25 +1337,21 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); } +#endif // __HIP_PLATFORM_AMD__ + // Shared launch: TMA_SHMEM_ALIGNMENT == TDM_SHMEM_ALIGNMENT == 128 constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems; - constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), 
TMA_SHMEM_ALIGNMENT); constexpr size_t elt_input_mem = buff_size_aligned_in; constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); constexpr size_t in_mem = elt_input_mem + act_input_mem; - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); - const size_t out_mem = out_rowwise_mem + out_colwise_mem; - - const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + const size_t dshmem_size = in_mem + out_rowwise_mem + out_colwise_mem + TMA_SHMEM_ALIGNMENT; switch (scaling_type) { case ScalingType::ROWWISE: @@ -1420,7 +1359,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, cast_mxfp8_2D_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - cast_mxfp8_2D_kernel <<>>( @@ -1435,7 +1373,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, cast_mxfp8_2D_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - cast_mxfp8_2D_kernel <<>>( @@ -1450,9 +1387,8 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, cast_mxfp8_2D_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel + cast_mxfp8_2D_kernel <<>>( tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, @@ -1461,7 +1397,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, NVTE_CHECK_CUDA(cudaGetLastError()); break; } -#endif // #ifdef __HIP_PLATFORM_AMD__ if constexpr (IS_DBIAS) { reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); From d38c6bdfd3767c9c494ea0329c710957eb923f50 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Mon, 27 Apr 2026 13:49:20 -0500 Subject: [PATCH 32/43] fix(amd): guard cudaFuncSetAttribute and add hip_bfloat16 overloads for TDM path - Move `using namespace mxfp8_kernel` outside `#ifndef __HIP_PLATFORM_AMD__` so tiling constants (CHUNK_DIM_Y/X, SCALE_DIM_X, BUFFS_NUM) are in scope on AMD - Guard all three `cudaFuncSetAttribute` calls with `#ifndef __HIP_PLATFORM_AMD__` since HIP cannot take the address of a templated kernel function the same way; dynamic shmem size is still correctly passed via <<>> - Add `__device__ __forceinline__` overloads of `__habs` and `__hmax` for `hip_bfloat16` (TE's bf16 alias) because ROCm only defines them for `__hip_bfloat16`, a distinct type on this ROCm version Co-Authored-By: Claude Sonnet 4 --- transformer_engine/common/util/cast_kernels.cuh | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 0a02083e7..07797be32 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -32,6 +32,14 @@ #ifdef __HIP_PLATFORM_AMD__ #include "rocm_cast_kernels.cuh" #include "tdm.cuh" +// ROCm defines __habs/__hmax only for __hip_bfloat16, not hip_bfloat16 (TE's bf16 alias). +// Provide the missing overloads so the TDM kernel compiles for bfloat16 inputs. +__device__ __forceinline__ hip_bfloat16 __habs(hip_bfloat16 x) { + return static_cast(fabsf(static_cast(x))); +} +__device__ __forceinline__ hip_bfloat16 __hmax(hip_bfloat16 x, hip_bfloat16 y) { + return static_cast(x) >= static_cast(y) ? 
x : y; +} #endif namespace transformer_engine { @@ -1223,8 +1231,8 @@ template , cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); +#endif cast_mxfp8_2D_kernel <<>>( @@ -1369,10 +1379,12 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::COLWISE: +#ifndef __HIP_PLATFORM_AMD__ NVTE_CHECK_CUDA(cudaFuncSetAttribute( cast_mxfp8_2D_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); +#endif cast_mxfp8_2D_kernel <<>>( @@ -1383,10 +1395,12 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::BIDIMENSIONAL: +#ifndef __HIP_PLATFORM_AMD__ NVTE_CHECK_CUDA(cudaFuncSetAttribute( cast_mxfp8_2D_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); +#endif cast_mxfp8_2D_kernel <<>>( From 14a1dab4c5dc44e2bb146efd5a28809e6537a044 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 29 Apr 2026 14:25:34 -0500 Subject: [PATCH 33/43] chore: remove debug print statements from MXFP8 cast/dequantize kernels Co-Authored-By: Claude Sonnet 4 --- .../common/util/cast_gated_kernels.cuh | 15 --------------- transformer_engine/common/util/cast_kernels.cuh | 4 ---- .../common/util/dequantize_kernels.cuh | 7 ------- .../common/util/rocm_cast_kernels.cuh | 9 --------- .../common/util/rocm_dequantize_kernels.cuh | 6 ------ 5 files changed, 41 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 3262d2118..83993b631 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -76,11 +76,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float *const amax_ptr, float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, const size_t cols) { #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) - if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { - printf("[DBG cast_fp8_gated_kernel] TDM kernel executing rows=%zu cols=%zu\n", - (size_t)rows, (size_t)cols); - } - #ifdef __HIP_PLATFORM_AMD__ // TDM needs explicit strides. For gated inputs, act and gate are interleaved → stride = 2*cols. // For outputs, IS_DGATED interleaves dact/dgate → stride = 2*cols; otherwise stride = cols. 
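(Note on the stride comment above: the gated input keeps the act and gate
halves side by side within each row, so a kernel that walks only one half
must still advance by the full row width. A rough sketch of the assumed
layout, with illustrative pointer names, not taken verbatim from the kernel:)

    // One row of the gated input, hidden size = cols:
    //   [ act_0 ... act_{cols-1} | gate_0 ... gate_{cols-1} ]
    // Element (r, c) of the act half  -> gated_input[r * 2 * cols + c]
    // Element (r, c) of the gate half -> gated_input[r * 2 * cols + cols + c]
    const IType *act_ptr  = gated_input;         // 2D view: rows x cols, leading dim 2 * cols
    const IType *gate_ptr = gated_input + cols;  // same leading dim, offset by cols
    // grad (dgated case only) is a plain rows x cols tensor, so its stride is just cols.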
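Side note on the hip_bfloat16 helpers added in PATCH 32 above: they simply
round-trip through float. Written out in full they amount to roughly the
following (a sketch reconstructed from that commit's description, not copied
verbatim from the tree):

    // __habs/__hmax exist for __hip_bfloat16 but not for hip_bfloat16 on this
    // ROCm version, so the TDM kernels need these thin float-based shims.
    __device__ __forceinline__ hip_bfloat16 __habs(hip_bfloat16 x) {
      return static_cast<hip_bfloat16>(fabsf(static_cast<float>(x)));
    }
    __device__ __forceinline__ hip_bfloat16 __hmax(hip_bfloat16 x, hip_bfloat16 y) {
      return static_cast<float>(x) >= static_cast<float>(y) ? x : y;
    }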
@@ -453,10 +448,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) - if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { - printf("[DBG cast_mxfp8_gated_kernel] TDM kernel executing rows=%zu cols=%zu\n", - (size_t)rows, (size_t)cols); - } using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; @@ -1582,10 +1573,8 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow_fp8) { - fprintf(stderr, "[DBG gated delayed_scaling] gfx1250 TDM -> cast_fp8_gated\n"); cast_fp8_gated(grad, gated_input, output, stream); } else { - fprintf(stderr, "[DBG gated delayed_scaling] gfx1250 ROCm -> cast_gated/cast_dgated\n"); if constexpr (IS_DGATED) { cast_dgated(grad, gated_input, output, stream); } else { @@ -1593,7 +1582,6 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu } } #elif defined(__HIP_PLATFORM_AMD__) - fprintf(stderr, "[DBG gated delayed_scaling] non-gfx1250 AMD -> cast_gated/cast_dgated\n"); if constexpr (IS_DGATED) { cast_dgated(grad, gated_input, output, stream); } else { @@ -1618,14 +1606,11 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { - fprintf(stderr, "[DBG gated mxfp_scaling] gfx1250 TDM -> cast_mxfp8_gated\n"); cast_mxfp8_gated(grad, gated_input, output, stream); } else { - fprintf(stderr, "[DBG gated mxfp_scaling] gfx1250 ROCm -> rocm_cast_mxfp8_gated\n"); rocm_cast_mxfp8_gated(grad, gated_input, output, stream); } #elif defined(__HIP_PLATFORM_AMD__) - fprintf(stderr, "[DBG gated mxfp_scaling] non-gfx1250 AMD -> rocm_cast_mxfp8_gated\n"); rocm_cast_mxfp8_gated(grad, gated_input, output, stream); #else cast_mxfp8_gated(grad, gated_input, output, stream); diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 07797be32..85d70a64a 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -80,10 +80,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) - if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { - printf("[DBG cast_mxfp8_2D_kernel] TDM kernel executing rows=%zu cols=%zu\n", - (size_t)rows, (size_t)cols); - } constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index f5d77d44d..09cc32dfc 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -68,10 +68,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const e8m0_t *const scales_ptr, const size_t rows, const size_t cols, const size_t scales_stride) { #if defined(__gfx1250__) || ((defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) - if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { - printf("[DBG mxfp8_dequantize TDM kernel] executing rows=%zu cols=%zu\n", - 
(size_t)rows, (size_t)cols); - } constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 @@ -431,14 +427,11 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { - fprintf(stderr, "[DBG dequantize_helper] gfx1250 TDM -> mxfp8_dequantize\n"); dequantization::mxfp8_dequantize(input, output, stream); } else { - fprintf(stderr, "[DBG dequantize_helper] gfx1250 ROCm -> rocm_mxfp8_dequantize\n"); rocm_mxfp8_dequantize(input, output, stream); } #elif defined(__HIP_PLATFORM_AMD__) - fprintf(stderr, "[DBG dequantize_helper] non-gfx1250 AMD -> rocm_mxfp8_dequantize\n"); rocm_mxfp8_dequantize(input, output, stream); #else if (is_supported_by_CC_100()) { diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index aa164e916..cf182c298 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -64,10 +64,6 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { if (noop != nullptr && noop[0] == 1.0f) return; } - if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { - printf("[DBG cast_mxfp8_2D_kernel ROCm] plain ROCm kernel executing rows=%zu cols=%zu\n", - (size_t)rows, (size_t)cols); - } constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; @@ -561,16 +557,13 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { - fprintf(stderr, "[DBG fp8_quantize_rocm] gfx1250 TDM branch -> mxfp8_quantize\n"); mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); } else { - fprintf(stderr, "[DBG fp8_quantize_rocm] gfx1250 ROCm branch -> rocm_mxfp8_quantize\n"); rocm_mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); } #else - fprintf(stderr, "[DBG fp8_quantize_rocm] non-gfx1250 AMD -> rocm_mxfp8_quantize\n"); rocm_mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); #endif @@ -591,8 +584,6 @@ void rocm_mxfp8_quantize(const Tensor &input, const Tensor *act_input, const Ten const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - fprintf(stderr, "[DBG rocm_mxfp8_quantize] rows=%zu cols=%zu — launching cast_mxfp8_2D_kernel\n", - rows, cols); const size_t blocks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh index 29351b502..17596a67b 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh @@ -48,10 +48,6 @@ __global__ void __launch_bounds__(ROCM_THREADS_PER_CHUNK) OType *output_ptr, const e8m0_t *const scales_ptr, const size_t rows, const size_t cols, const size_t scales_stride) { - if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { - printf("[DBG dequantize_mxfp8_kernel ROCm] plain ROCm kernel executing rows=%zu cols=%zu\n", - (size_t)rows, (size_t)cols); - } constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr bool USE_COLWISE_SCALING = 
SCALE_DIM_Y > 1; @@ -151,8 +147,6 @@ static void rocm_mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStrea const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - fprintf(stderr, "[DBG rocm_mxfp8_dequantize] rows=%zu cols=%zu — launching dequantize_mxfp8_kernel\n", - rows, cols); const size_t chunks_Y = DIVUP(rows, ROCM_CHUNK_DIM_Y); const size_t chunks_X = DIVUP(cols, ROCM_CHUNK_DIM_X); From 198495ad300320aa41907aedce4342539a87066f Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 29 Apr 2026 14:34:24 -0500 Subject: [PATCH 34/43] chore: restore launcher debug prints, remove only in-kernel printf statements Co-Authored-By: Claude Sonnet 4 --- transformer_engine/common/util/cast_gated_kernels.cuh | 6 ++++++ transformer_engine/common/util/dequantize_kernels.cuh | 3 +++ transformer_engine/common/util/rocm_cast_kernels.cuh | 5 +++++ transformer_engine/common/util/rocm_dequantize_kernels.cuh | 2 ++ 4 files changed, 16 insertions(+) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 83993b631..1afdafdf6 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -1573,8 +1573,10 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow_fp8) { + fprintf(stderr, "[DBG gated delayed_scaling] gfx1250 TDM -> cast_fp8_gated\n"); cast_fp8_gated(grad, gated_input, output, stream); } else { + fprintf(stderr, "[DBG gated delayed_scaling] gfx1250 ROCm -> cast_gated/cast_dgated\n"); if constexpr (IS_DGATED) { cast_dgated(grad, gated_input, output, stream); } else { @@ -1582,6 +1584,7 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu } } #elif defined(__HIP_PLATFORM_AMD__) + fprintf(stderr, "[DBG gated delayed_scaling] non-gfx1250 AMD -> cast_gated/cast_dgated\n"); if constexpr (IS_DGATED) { cast_dgated(grad, gated_input, output, stream); } else { @@ -1606,11 +1609,14 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { + fprintf(stderr, "[DBG gated mxfp_scaling] gfx1250 TDM -> cast_mxfp8_gated\n"); cast_mxfp8_gated(grad, gated_input, output, stream); } else { + fprintf(stderr, "[DBG gated mxfp_scaling] gfx1250 ROCm -> rocm_cast_mxfp8_gated\n"); rocm_cast_mxfp8_gated(grad, gated_input, output, stream); } #elif defined(__HIP_PLATFORM_AMD__) + fprintf(stderr, "[DBG gated mxfp_scaling] non-gfx1250 AMD -> rocm_cast_mxfp8_gated\n"); rocm_cast_mxfp8_gated(grad, gated_input, output, stream); #else cast_mxfp8_gated(grad, gated_input, output, stream); diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index 09cc32dfc..df1199db4 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -427,11 +427,14 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { + fprintf(stderr, "[DBG dequantize_helper] gfx1250 TDM -> mxfp8_dequantize\n"); dequantization::mxfp8_dequantize(input, output, stream); } else { + fprintf(stderr, "[DBG dequantize_helper] gfx1250 ROCm -> rocm_mxfp8_dequantize\n"); rocm_mxfp8_dequantize(input, output, 
stream); } #elif defined(__HIP_PLATFORM_AMD__) + fprintf(stderr, "[DBG dequantize_helper] non-gfx1250 AMD -> rocm_mxfp8_dequantize\n"); rocm_mxfp8_dequantize(input, output, stream); #else if (is_supported_by_CC_100()) { diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index cf182c298..9d48ffb15 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -557,13 +557,16 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { + fprintf(stderr, "[DBG fp8_quantize_rocm] gfx1250 TDM branch -> mxfp8_quantize\n"); mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); } else { + fprintf(stderr, "[DBG fp8_quantize_rocm] gfx1250 ROCm branch -> rocm_mxfp8_quantize\n"); rocm_mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); } #else + fprintf(stderr, "[DBG fp8_quantize_rocm] non-gfx1250 AMD -> rocm_mxfp8_quantize\n"); rocm_mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); #endif @@ -584,6 +587,8 @@ void rocm_mxfp8_quantize(const Tensor &input, const Tensor *act_input, const Ten const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); + fprintf(stderr, "[DBG rocm_mxfp8_quantize] rows=%zu cols=%zu — launching cast_mxfp8_2D_kernel\n", + rows, cols); const size_t blocks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh index 17596a67b..abcb2a0d6 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh @@ -147,6 +147,8 @@ static void rocm_mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStrea const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); + fprintf(stderr, "[DBG rocm_mxfp8_dequantize] rows=%zu cols=%zu — launching dequantize_mxfp8_kernel\n", + rows, cols); const size_t chunks_Y = DIVUP(rows, ROCM_CHUNK_DIM_Y); const size_t chunks_X = DIVUP(cols, ROCM_CHUNK_DIM_X); From 362ae532923508178cf5701efb68dc373cba3ee5 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 29 Apr 2026 16:22:16 -0500 Subject: [PATCH 35/43] test: add 16384x16384 matrix size to CastMXFP8_GatedAct benchmark run Co-Authored-By: Claude Sonnet 4 --- tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 52180786d..6bddf9c49 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -446,6 +446,7 @@ std::vector> matrix_sizes = { {768, 1024}, {8192, 128}, {577, 1632}, + {16384, 16384}, }; std::vector> block_sizes = { From 04564928695cbc7ec3a0cd6df1a190e3d669c1ea Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 29 Apr 2026 16:48:21 -0500 Subject: [PATCH 36/43] feat: migrate benchmarks/cpp/cast from dev branch Co-Authored-By: Claude Sonnet 4 --- benchmarks/cpp/CMakeLists.txt | 88 ++++++ benchmarks/cpp/cast/bench_casttranspose.cpp | 269 ++++++++++++++++++ .../cpp/cast/bench_dequantize_mxfp8.cpp | 130 +++++++++ benchmarks/cpp/cast/bench_gated_mxfp8.cpp | 214 ++++++++++++++ 
.../cpp/cast/bench_quantize_mxfp8_fused.cpp | 182 ++++++++++++ benchmarks/cpp/utils/benchmark_utils.h | 223 +++++++++++++++ 6 files changed, 1106 insertions(+) create mode 100644 benchmarks/cpp/CMakeLists.txt create mode 100644 benchmarks/cpp/cast/bench_casttranspose.cpp create mode 100644 benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp create mode 100644 benchmarks/cpp/cast/bench_gated_mxfp8.cpp create mode 100644 benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp create mode 100644 benchmarks/cpp/utils/benchmark_utils.h diff --git a/benchmarks/cpp/CMakeLists.txt b/benchmarks/cpp/CMakeLists.txt new file mode 100644 index 000000000..6071f9083 --- /dev/null +++ b/benchmarks/cpp/CMakeLists.txt @@ -0,0 +1,88 @@ +cmake_minimum_required(VERSION 3.18) + +if(NOT DEFINED CMAKE_CXX_COMPILER) + set(CMAKE_CXX_COMPILER hipcc) +endif() + +include("${CMAKE_CURRENT_SOURCE_DIR}/../../build_tools/rocm_utils.cmake") + +project(transformer_engine_benchmarks LANGUAGES CXX HIP) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package(HIP REQUIRED) + +include(FetchContent) +FetchContent_Declare( + benchmark + GIT_REPOSITORY https://github.com/google/benchmark.git + GIT_TAG v1.8.3 +) +set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark tests" FORCE) +set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "Disable gtest in benchmark" FORCE) +FetchContent_MakeAvailable(benchmark) + +set(TESTS_CPP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../tests/cpp) + +include("${CMAKE_CURRENT_SOURCE_DIR}/../../build_tools/hipify/hipify.cmake") +TE_Hipify(${TESTS_CPP_DIR}) + +include_directories( + ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine/common/include + ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine/common + ${CMAKE_CURRENT_SOURCE_DIR}/../../transformer_engine + ${CMAKE_CURRENT_SOURCE_DIR}/utils + ${TESTS_CPP_DIR} +) + +set(COMMON_COMPILE_OPTIONS + -Wall + -Wextra + -O3 + -DNDEBUG + -DUSE_ROCM +) + +find_library(TRANSFORMER_ENGINE_LIB + NAMES transformer_engine + PATHS ${CMAKE_CURRENT_SOURCE_DIR}/../.. + ${CMAKE_CURRENT_SOURCE_DIR}/../../build/cmake + ${CMAKE_CURRENT_SOURCE_DIR}/../../build/lib + /usr/local/lib + $ENV{HOME}/.local/lib + NO_DEFAULT_PATH +) + +if(NOT TRANSFORMER_ENGINE_LIB) + message(WARNING "TransformerEngine library not found in expected paths. Trying system paths...") + find_library(TRANSFORMER_ENGINE_LIB NAMES transformer_engine) +endif() + +if(TRANSFORMER_ENGINE_LIB) + message(STATUS "Found TransformerEngine library: ${TRANSFORMER_ENGINE_LIB}") +else() + message(FATAL_ERROR "TransformerEngine library not found. Please build TransformerEngine first:\n" + " cd ${CMAKE_CURRENT_SOURCE_DIR}/../..\n" + " pip install -e . 
--no-build-isolation\n" + "Searched paths:\n" + " ${CMAKE_CURRENT_SOURCE_DIR}/../..\n" + " ${CMAKE_CURRENT_SOURCE_DIR}/../../build/cmake\n" + " ${CMAKE_CURRENT_SOURCE_DIR}/../../build/lib") +endif() + +function(add_te_benchmark TARGET_NAME SOURCE_FILE) + add_executable(${TARGET_NAME} ${SOURCE_FILE} ${TESTS_CPP_DIR}/test_common.hip) + target_compile_options(${TARGET_NAME} PRIVATE ${COMMON_COMPILE_OPTIONS}) + target_compile_definitions(${TARGET_NAME} PRIVATE NVTE_ROCM_BENCHMARK) + target_link_libraries(${TARGET_NAME} PRIVATE + benchmark::benchmark + ${TRANSFORMER_ENGINE_LIB} + hiprand + ) +endfunction() + +add_te_benchmark(bench_quantize_mxfp8_fused cast/bench_quantize_mxfp8_fused.cpp) +add_te_benchmark(bench_dequantize_mxfp8 cast/bench_dequantize_mxfp8.cpp) +add_te_benchmark(bench_gated_mxfp8 cast/bench_gated_mxfp8.cpp) +add_te_benchmark(bench_casttranspose cast/bench_casttranspose.cpp) diff --git a/benchmarks/cpp/cast/bench_casttranspose.cpp b/benchmarks/cpp/cast/bench_casttranspose.cpp new file mode 100644 index 000000000..3a5f1fdd1 --- /dev/null +++ b/benchmarks/cpp/cast/bench_casttranspose.cpp @@ -0,0 +1,269 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include +#include "amd_detail/hip_float8.h" + +#include "benchmark_utils.h" + +#include "transformer_engine/cast_hip.h" +#include "transformer_engine/transpose_hip.h" +#include "transformer_engine/transformer_engine_hip.h" + +// #define NVTE_ROCM_EXTENDED_BENCHMARKS 1 + +using namespace te_bench; +using namespace transformer_engine; +using fp8_e4m3 = test::fp8e4m3; + +#define GPT_OSS_COMMON_SHAPES \ + ->Args({2880, 2880}) \ + ->Args({2880, 4096}) \ + ->Args({5120, 2880}) \ + ->Args({5760, 2880}) \ + ->Args({16384, 2880}) \ + ->Args({16384, 4096}) \ + ->Args({16384, 5120}) + +// GPT-OSS MoE per-expert shapes (hidden=2880, intermediate=5760) +#define GPT_OSS_MOE \ + ->Args({64, 2880}) \ + ->Args({256, 2880}) \ + ->Args({320, 2880}) \ + ->Args({496, 2880}) \ + ->Args({1792, 2880}) \ + ->Args({64, 5760}) \ + ->Args({256, 5760}) \ + ->Args({320, 5760}) \ + ->Args({496, 5760}) \ + ->Args({1792, 5760}) + +// Tensor shapes from LLaMA (8B, 70B, 405B) and Qwen (7B, 72B) +#define COMMON_SHAPES \ + ->Args({1024, 3584}) \ + ->Args({1024, 4096}) \ + ->Args({1024, 8192}) \ + ->Args({1024, 14336}) \ + ->Args({1024, 18944}) \ + ->Args({2048, 4096}) \ + ->Args({2048, 8192}) \ + ->Args({2048, 14336}) \ + ->Args({2048, 28672}) \ + ->Args({2048, 29568}) \ + ->Args({4096, 4096}) \ + ->Args({4096, 8192}) \ + ->Args({4096, 16384}) \ + ->Args({4096, 14336}) \ + ->Args({4096, 28672}) \ + ->Args({8192, 8192}) \ + ->Args({8192, 16384}) \ + ->Args({8192, 28672}) \ + ->Args({8192, 29568}) \ + ->Args({8192, 53248}) \ + ->Args({16384, 8192}) \ + ->Args({16384, 16384}) \ + ->Args({16384, 28672}) \ + ->Args({32768, 8192}) \ + ->Args({32768, 16384}) + +// Only used for specific benchmarks (older models, special cases, etc) +#define EXTENDED_SHAPES \ + ->Args({2048, 12288}) \ + ->Args({256, 65536}) \ + ->Args({65536, 128}) \ + ->Args({1600, 1600}) \ + ->Args({1600, 6400}) \ + ->Args({4800, 1600}) \ + ->Args({56320 , 1600}) \ + ->Args({6400, 1600}) \ + ->Args({128256, 4096}) \ + ->Args({24576, 128256}) \ + ->Args({24576, 4096}) \ + ->Args({24576, 5120}) \ + ->Args({28672, 4096}) \ + 
->Args({4096, 12288}) \ + ->Args({5120, 4096}) \ + ->Args({10240, 8192}) \ + ->Args({128256, 8192}) \ + ->Args({57344, 10240}) \ + ->Args({57344, 128256}) \ + ->Args({57344, 8192}) \ + ->Args({32000, 4096}) \ + ->Args({32768, 32000}) \ + ->Args({32768, 4096}) \ + ->Args({32768, 5120}) \ + ->Args({3072, 1024}) \ + ->Args({24576, 1024}) \ + ->Args({4096, 1024}) + + + + + +template +static void BM_CastOnly(benchmark::State &state) { + const size_t rows = state.range(0); + const size_t cols = state.range(1); + std::vector shape = {rows, cols}; + + DType itype = std::is_same_v ? DType::kFloat32 : + std::is_same_v ? DType::kBFloat16 : + DType::kFloat16; + + test::Tensor &input = TensorCache::get_or_create( + "cast_input", shape, itype, true, false, NVTE_DELAYED_TENSOR_SCALING, true); + test::Tensor &output = TensorCache::get_or_create( + "cast_output", shape, DType::kFloat8E4M3, true, false, NVTE_DELAYED_TENSOR_SCALING, false); + + output.set_scale(1.0f); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + // Untimed call to trigger any RTC compilation before measurement + nvte_quantize(input.data(), output.data(), stream); + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + nvte_quantize(input.data(), output.data(), stream); + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + const size_t bytes_read = rows * cols * sizeof(IType); + const size_t bytes_write = rows * cols * sizeof(fp8_e4m3); + set_bytes_processed(state, bytes_read + bytes_write); + + HIP_CHECK(hipStreamDestroy(stream)); +} + +template +static void BM_CastTranspose(benchmark::State &state) { + const size_t rows = state.range(0); + const size_t cols = state.range(1); + std::vector shape = {rows, cols}; + + DType itype = std::is_same_v ? DType::kFloat32 : + std::is_same_v ? 
DType::kBFloat16 : + DType::kFloat16; + + test::Tensor &input = TensorCache::get_or_create( + "ct_input", shape, itype, true, false, NVTE_DELAYED_TENSOR_SCALING, true); + test::Tensor &output = TensorCache::get_or_create( + "ct_output", shape, DType::kFloat8E4M3, true, true, NVTE_DELAYED_TENSOR_SCALING, false); + + output.set_scale(1.0f); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + // Untimed call to trigger any RTC compilation before measurement + nvte_quantize(input.data(), output.data(), stream); + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + nvte_quantize(input.data(), output.data(), stream); + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + const size_t bytes_read = rows * cols * sizeof(IType); + const size_t bytes_write = rows * cols * sizeof(fp8_e4m3) * 2; + set_bytes_processed(state, bytes_read + bytes_write); + + HIP_CHECK(hipStreamDestroy(stream)); +} + +#define REGISTER_CAST_ONLY(ITYPE, INAME) \ + BENCHMARK_TEMPLATE(BM_CastOnly, ITYPE) \ + ->Name("BM_CastOnly/" INAME "_E4M3/gpt_oss") \ + GPT_OSS_COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_CastOnly, ITYPE) \ + ->Name("BM_CastOnly/" INAME "_E4M3/gpt_oss_moe") \ + GPT_OSS_MOE \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_CastOnly, ITYPE) \ + ->Name("BM_CastOnly/" INAME "_E4M3/llm") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +#define REGISTER_CAST_TRANSPOSE(ITYPE, INAME) \ + BENCHMARK_TEMPLATE(BM_CastTranspose, ITYPE) \ + ->Name("BM_CastTranspose/" INAME "_E4M3/gpt_oss") \ + GPT_OSS_COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_CastTranspose, ITYPE) \ + ->Name("BM_CastTranspose/" INAME "_E4M3/gpt_oss_moe") \ + GPT_OSS_MOE \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_CastTranspose, ITYPE) \ + ->Name("BM_CastTranspose/" INAME "_E4M3/llm") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +#ifdef NVTE_ROCM_EXTENDED_BENCHMARKS +#define REGISTER_EXTENDED_CAST_ONLY(ITYPE, INAME) \ + BENCHMARK_TEMPLATE(BM_CastOnly, ITYPE) \ + ->Name("BM_CastOnlyExtended/" INAME "_E4M3/llm") \ + EXTENDED_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +#define REGISTER_EXTENDED_CAST_TRANSPOSE(ITYPE, INAME) \ + BENCHMARK_TEMPLATE(BM_CastTranspose, ITYPE) \ + ->Name("BM_CastTransposeExtended/" INAME "_E4M3/llm") \ + EXTENDED_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +REGISTER_EXTENDED_CAST_ONLY(float, "FP32") +REGISTER_EXTENDED_CAST_ONLY(hip_bfloat16, "BF16") + +REGISTER_EXTENDED_CAST_TRANSPOSE(float, "FP32") +REGISTER_EXTENDED_CAST_TRANSPOSE(hip_bfloat16, "BF16") +#endif // #ifdef NVTE_ROCM_EXTENDED_BENCHMARKS + +REGISTER_CAST_ONLY(float, "FP32") +REGISTER_CAST_ONLY(hip_bfloat16, "BF16") + +REGISTER_CAST_TRANSPOSE(float, "FP32") +REGISTER_CAST_TRANSPOSE(hip_bfloat16, "BF16") + +BENCHMARK_MAIN(); diff --git a/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp b/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp new file mode 100644 index 000000000..bd6b4c652 --- /dev/null +++ 
b/benchmarks/cpp/cast/bench_dequantize_mxfp8.cpp @@ -0,0 +1,130 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include +#include "amd_detail/hip_float8.h" + +#include "benchmark_utils.h" + +#include "transformer_engine/cast_hip.h" +#include "transformer_engine/transformer_engine_hip.h" + +using namespace te_bench; +using namespace transformer_engine; +using fp8_e4m3 = test::fp8e4m3; + +// Tensor shapes from LLaMA (8B, 70B, 405B) and Qwen (7B, 72B) +#define COMMON_SHAPES \ + ->Args({1024, 3584}) \ + ->Args({1024, 4096}) \ + ->Args({1024, 8192}) \ + ->Args({1024, 14336}) \ + ->Args({2048, 4096}) \ + ->Args({2048, 8192}) \ + ->Args({2048, 14336}) \ + ->Args({2048, 28672}) \ + ->Args({4096, 4096}) \ + ->Args({4096, 8192}) \ + ->Args({4096, 16384}) \ + ->Args({4096, 28672}) \ + ->Args({8192, 8192}) \ + ->Args({8192, 16384}) \ + ->Args({8192, 28672}) \ + ->Args({8192, 53248}) \ + ->Args({16384, 8192}) \ + ->Args({16384, 16384})\ + ->Args({32768, 8192}) + +template +static void BM_DequantizeMXFP8(benchmark::State &state) { + const size_t rows = state.range(0); + const size_t cols = state.range(1); + + constexpr bool USE_ROWWISE = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE = SCALE_DIM_Y > 1; + + const size_t scale_cols_row = USE_ROWWISE ? (cols + 31) / 32 : 0; + const size_t scale_rows_col = USE_COLWISE ? (rows + 31) / 32 : 0; + const size_t scale_cols_col = USE_COLWISE ? cols : 0; + + std::vector shape = {rows, cols}; + DType itype = std::is_same_v ? DType::kFloat8E4M3 : DType::kFloat8E5M2; + DType otype = std::is_same_v ? DType::kFloat16 : + (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); + + test::Tensor &input_tensor = TensorCache::get_or_create("input", shape, itype, USE_ROWWISE, USE_COLWISE, + NVTE_MXFP8_1D_SCALING, false); + test::Tensor &output_tensor = TensorCache::get_or_create("output", shape, otype, true, false, + NVTE_DELAYED_TENSOR_SCALING, false); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + DeviceBuffer temp_fp32(rows * cols); + fill_random_uniform_gpu(temp_fp32.get(), rows * cols, -2.0f, 1.0f, stream); + + void *input_data_ptr = USE_ROWWISE ? input_tensor.rowwise_dptr() : input_tensor.columnwise_dptr(); + size_t threads = 256; + size_t blocks = (rows * cols + threads - 1) / threads; + cast_fp32_kernel<<>>(temp_fp32.get(), static_cast(input_data_ptr), rows * cols); + + HIP_CHECK(hipStreamSynchronize(stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + nvte_dequantize(input_tensor.data(), output_tensor.data(), stream); + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + const size_t bytes_read_data = rows * cols * sizeof(IType) * + ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + // Scales are single byte, E8M0 type + const size_t bytes_read_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + + (USE_COLWISE ? 
scale_rows_col * scale_cols_col : 0); + const size_t bytes_write = rows * cols * sizeof(OType); + const size_t total_bytes = bytes_read_data + bytes_read_scales + bytes_write; + + set_bytes_processed(state, total_bytes); + + HIP_CHECK(hipStreamDestroy(stream)); +} + +#define REGISTER_DEQUANTIZE_ALL_CONFIGS(ITYPE, OTYPE, INAME, ONAME) \ + BENCHMARK_TEMPLATE(BM_DequantizeMXFP8, ITYPE, OTYPE, 1, 32) \ + ->Name("BM_DequantizeMXFP8/" INAME "_" ONAME "/rowwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_DequantizeMXFP8, ITYPE, OTYPE, 32, 1) \ + ->Name("BM_DequantizeMXFP8/" INAME "_" ONAME "/colwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, __half, "E4M3", "FP16") +REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, hip_bfloat16, "E4M3", "BF16") +REGISTER_DEQUANTIZE_ALL_CONFIGS(fp8_e4m3, float, "E4M3", "FP32") + +BENCHMARK_MAIN(); diff --git a/benchmarks/cpp/cast/bench_gated_mxfp8.cpp b/benchmarks/cpp/cast/bench_gated_mxfp8.cpp new file mode 100644 index 000000000..c715c5920 --- /dev/null +++ b/benchmarks/cpp/cast/bench_gated_mxfp8.cpp @@ -0,0 +1,214 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include +#include "amd_detail/hip_float8.h" + +#include "benchmark_utils.h" + +#include "transformer_engine/cast_hip.h" +#include "transformer_engine/activation_hip.h" +#include "transformer_engine/transformer_engine_hip.h" + +using namespace te_bench; +using namespace transformer_engine; +using fp8_e4m3 = test::fp8e4m3; + +// SwiGLU shapes from LLaMA (8B, 70B, 405B) and Qwen (7B, 72B) +#define COMMON_SHAPES \ + ->Args({1024, 14336}) \ + ->Args({1024, 18944}) \ + ->Args({1024, 28672}) \ + ->Args({2048, 14336}) \ + ->Args({2048, 28672}) \ + ->Args({2048, 29568}) \ + ->Args({4096, 14336}) \ + ->Args({4096, 28672}) \ + ->Args({4096, 53248}) \ + ->Args({8192, 14336}) \ + ->Args({8192, 28672}) \ + ->Args({8192, 29568}) \ + ->Args({8192, 53248}) \ + ->Args({16384, 28672}) \ + ->Args({16384, 53248}) \ + ->Args({32768, 28672}) \ + ->Args({32768, 53248}) + +template +static void BM_GatedMXFP8_Forward(benchmark::State &state) { + const size_t rows = state.range(0); + const size_t cols = state.range(1); + + constexpr bool USE_ROWWISE = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE = SCALE_DIM_Y > 1; + + const size_t input_cols = cols * 2; + const size_t output_cols = cols; + + const size_t scale_cols_row = USE_ROWWISE ? (output_cols + 31) / 32 : 0; + const size_t scale_rows_col = USE_COLWISE ? (rows + 31) / 32 : 0; + const size_t scale_cols_col = USE_COLWISE ? output_cols : 0; + + std::vector input_shape = {rows, input_cols}; + std::vector output_shape = {rows, output_cols}; + + DType itype = std::is_same_v ? DType::kFloat16 : + (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); + DType otype = std::is_same_v ? 
DType::kFloat8E4M3 : DType::kFloat8E5M2; + + test::Tensor &input_tensor = TensorCache::get_or_create("input", input_shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + test::Tensor &output_tensor = TensorCache::get_or_create("output", output_shape, otype, USE_ROWWISE, USE_COLWISE, + NVTE_MXFP8_1D_SCALING, false); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + nvte_swiglu(input_tensor.data(), output_tensor.data(), stream); + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + const size_t bytes_write_data = rows * output_cols * sizeof(OType) * + ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + // Scales are single byte, E8M0 type + const size_t bytes_write_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + + (USE_COLWISE ? scale_rows_col * scale_cols_col : 0); + + const size_t bytes_read = rows * cols * sizeof(IType) * 2; + const size_t total_bytes = bytes_read + bytes_write_data + bytes_write_scales; + + set_bytes_processed(state, total_bytes); + + HIP_CHECK(hipStreamDestroy(stream)); +} + +template +static void BM_GatedMXFP8_Backward(benchmark::State &state) { + const size_t rows = state.range(0); + const size_t cols = state.range(1); + + constexpr bool USE_ROWWISE = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE = SCALE_DIM_Y > 1; + + const size_t input_cols = cols * 2; + const size_t output_cols = cols * 2; + + const size_t scale_cols_row = USE_ROWWISE ? (output_cols + 31) / 32 : 0; + const size_t scale_rows_col = USE_COLWISE ? (rows + 31) / 32 : 0; + const size_t scale_cols_col = USE_COLWISE ? output_cols : 0; + + std::vector grad_shape = {rows, cols}; + std::vector input_shape = {rows, input_cols}; + std::vector output_shape = {rows, output_cols}; + + DType itype = std::is_same_v ? DType::kFloat16 : + (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); + DType otype = std::is_same_v ? 
DType::kFloat8E4M3 : DType::kFloat8E5M2; + + test::Tensor &grad_tensor = TensorCache::get_or_create("grad", grad_shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + test::Tensor &input_tensor = TensorCache::get_or_create("input", input_shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + test::Tensor &output_tensor = TensorCache::get_or_create("output", output_shape, otype, USE_ROWWISE, USE_COLWISE, + NVTE_MXFP8_1D_SCALING, false); + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + nvte_dswiglu(grad_tensor.data(), input_tensor.data(), output_tensor.data(), stream); + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + const size_t bytes_write_data = rows * output_cols * sizeof(OType) * + ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + // Scales are single byte, E8M0 type + const size_t bytes_write_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + + (USE_COLWISE ? scale_rows_col * scale_cols_col : 0); + + const size_t bytes_read = rows * cols * sizeof(IType) * 3; + const size_t total_bytes = bytes_read + bytes_write_data + bytes_write_scales; + + set_bytes_processed(state, total_bytes); + + HIP_CHECK(hipStreamDestroy(stream)); +} + +#define REGISTER_GATED_ALL_CONFIGS(ITYPE, OTYPE, INAME, ONAME) \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Forward, ITYPE, OTYPE, 1, 32) \ + ->Name("BM_GatedMXFP8_Forward/" INAME "_" ONAME "/rowwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Forward, ITYPE, OTYPE, 32, 1) \ + ->Name("BM_GatedMXFP8_Forward/" INAME "_" ONAME "/colwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Forward, ITYPE, OTYPE, 32, 32) \ + ->Name("BM_GatedMXFP8_Forward/" INAME "_" ONAME "/both") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Backward, ITYPE, OTYPE, 1, 32) \ + ->Name("BM_GatedMXFP8_Backward/" INAME "_" ONAME "/rowwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Backward, ITYPE, OTYPE, 32, 1) \ + ->Name("BM_GatedMXFP8_Backward/" INAME "_" ONAME "/colwise") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_GatedMXFP8_Backward, ITYPE, OTYPE, 32, 32) \ + ->Name("BM_GatedMXFP8_Backward/" INAME "_" ONAME "/both") \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +REGISTER_GATED_ALL_CONFIGS(__half, fp8_e4m3, "FP16", "E4M3") +REGISTER_GATED_ALL_CONFIGS(hip_bfloat16, fp8_e4m3, "BF16", "E4M3") +REGISTER_GATED_ALL_CONFIGS(float, fp8_e4m3, "FP32", "E4M3") + +BENCHMARK_MAIN(); diff --git a/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp b/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp new file mode 100644 index 000000000..6f5540714 --- /dev/null +++ b/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp @@ -0,0 +1,182 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include +#include "amd_detail/hip_float8.h" + +#include "benchmark_utils.h" + +#include "transformer_engine/cast_hip.h" +#include "transformer_engine/activation_hip.h" +#include "transformer_engine/transformer_engine_hip.h" + +using namespace te_bench; +using namespace transformer_engine; +using fp8_e4m3 = test::fp8e4m3; + +enum ProcessingMethod { + CAST_ONLY, + CAST_DBIAS, + CAST_DBIAS_DACT, + CAST_DACT, + CAST_ACT +}; + +// Tensor shapes from LLaMA (8B, 70B, 405B) and Qwen (7B, 72B) +#define COMMON_SHAPES \ + ->Args({1024, 3584}) \ + ->Args({1024, 4096}) \ + ->Args({1024, 8192}) \ + ->Args({1024, 14336}) \ + ->Args({1024, 18944}) \ + ->Args({2048, 4096}) \ + ->Args({2048, 8192}) \ + ->Args({2048, 14336}) \ + ->Args({2048, 28672}) \ + ->Args({2048, 29568}) \ + ->Args({4096, 4096}) \ + ->Args({4096, 8192}) \ + ->Args({4096, 16384}) \ + ->Args({4096, 14336}) \ + ->Args({4096, 28672}) \ + ->Args({8192, 8192}) \ + ->Args({8192, 16384}) \ + ->Args({8192, 28672}) \ + ->Args({8192, 29568}) \ + ->Args({8192, 53248}) \ + ->Args({16384, 8192}) \ + ->Args({16384, 16384})\ + ->Args({16384, 28672})\ + ->Args({32768, 8192}) \ + ->Args({32768, 16384}) + +template +static void BM_QuantizeMXFP8_Fused(benchmark::State &state) { + const size_t rows = state.range(0); + const size_t cols = state.range(1); + + constexpr bool USE_ROWWISE = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE = SCALE_DIM_Y > 1; + + const size_t scale_cols_row = USE_ROWWISE ? (cols + 31) / 32 : 0; + const size_t scale_rows_col = USE_COLWISE ? (rows + 31) / 32 : 0; + const size_t scale_cols_col = USE_COLWISE ? cols : 0; + + std::vector shape = {rows, cols}; + + DType itype = std::is_same_v ? DType::kFloat16 : + (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); + DType otype = std::is_same_v ? 
DType::kFloat8E4M3 : DType::kFloat8E5M2; + + test::Tensor &input_tensor = TensorCache::get_or_create("input", shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + test::Tensor &output_tensor = TensorCache::get_or_create("output", shape, otype, USE_ROWWISE, USE_COLWISE, + NVTE_MXFP8_1D_SCALING, false); + + test::Tensor *grad_tensor_ptr = nullptr, *dbias_tensor_ptr = nullptr, *workspace_tensor_ptr = nullptr; + + if constexpr (PROC_METHOD == CAST_DBIAS || PROC_METHOD == CAST_DBIAS_DACT) { + std::vector bias_shape = {cols}; + dbias_tensor_ptr = &TensorCache::get_or_create("dbias", bias_shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, false); + workspace_tensor_ptr = &TensorCache::get_or_create("workspace", shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, false); + } + + if constexpr (PROC_METHOD == CAST_DBIAS_DACT || PROC_METHOD == CAST_DACT) { + grad_tensor_ptr = &TensorCache::get_or_create("grad", shape, itype, true, false, + NVTE_DELAYED_TENSOR_SCALING, true); + } + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + + if constexpr (PROC_METHOD == CAST_ONLY) { + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); + } else if constexpr (PROC_METHOD == CAST_DBIAS) { + nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor_ptr->data(), workspace_tensor_ptr->data(), stream); + } else if constexpr (PROC_METHOD == CAST_DBIAS_DACT) { + nvte_quantize_dbias_dgelu(grad_tensor_ptr->data(), input_tensor.data(), output_tensor.data(), dbias_tensor_ptr->data(), workspace_tensor_ptr->data(), stream); + } else if constexpr (PROC_METHOD == CAST_DACT) { + nvte_dgelu(grad_tensor_ptr->data(), input_tensor.data(), output_tensor.data(), stream); + } else if constexpr (PROC_METHOD == CAST_ACT) { + nvte_gelu(input_tensor.data(), output_tensor.data(), stream); + } + + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + size_t bytes_write_data = rows * cols * sizeof(OType) * + ((USE_ROWWISE ?: 0) + (USE_COLWISE ?: 0)); + // Scales are single byte, E8M0 type + size_t bytes_write_scales = (USE_ROWWISE ? rows * scale_cols_row : 0) + + (USE_COLWISE ? 
scale_rows_col * scale_cols_col : 0); + + size_t bytes_read = rows * cols * sizeof(IType); + if constexpr (PROC_METHOD == CAST_DBIAS_DACT || PROC_METHOD == CAST_DACT) { + bytes_read += rows * cols * sizeof(IType); + } + if constexpr (PROC_METHOD == CAST_DBIAS || PROC_METHOD == CAST_DBIAS_DACT) { + bytes_write_data += cols * sizeof(IType); + } + + const size_t total_bytes = bytes_read + bytes_write_data + bytes_write_scales; + + set_bytes_processed(state, total_bytes); + + HIP_CHECK(hipStreamDestroy(stream)); +} + +#define REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, METHOD, METHOD_NAME) \ + BENCHMARK_TEMPLATE(BM_QuantizeMXFP8_Fused, ITYPE, OTYPE, 1, 32, METHOD) \ + ->Name("BM_QuantizeMXFP8_" METHOD_NAME "/rowwise/" INAME "_" ONAME) \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_QuantizeMXFP8_Fused, ITYPE, OTYPE, 32, 1, METHOD) \ + ->Name("BM_QuantizeMXFP8_" METHOD_NAME "/colwise/" INAME "_" ONAME) \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_QuantizeMXFP8_Fused, ITYPE, OTYPE, 32, 32, METHOD) \ + ->Name("BM_QuantizeMXFP8_" METHOD_NAME "/both/" INAME "_" ONAME) \ + COMMON_SHAPES \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +#define REGISTER_ALL_METHODS(ITYPE, OTYPE, INAME, ONAME) \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_ONLY, "CastOnly") \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_DBIAS, "CastDBias") \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_DBIAS_DACT, "CastDBiasDACT") \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_DACT, "CastDACT") \ + REGISTER_QUANTIZE_FUSED(ITYPE, OTYPE, INAME, ONAME, CAST_ACT, "CastACT") + +REGISTER_ALL_METHODS(__half, fp8_e4m3, "FP16", "E4M3") +REGISTER_ALL_METHODS(hip_bfloat16, fp8_e4m3, "BF16", "E4M3") +REGISTER_ALL_METHODS(float, fp8_e4m3, "FP32", "E4M3") + +BENCHMARK_MAIN(); diff --git a/benchmarks/cpp/utils/benchmark_utils.h b/benchmarks/cpp/utils/benchmark_utils.h new file mode 100644 index 000000000..bd2906b30 --- /dev/null +++ b/benchmarks/cpp/utils/benchmark_utils.h @@ -0,0 +1,223 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "test_common_hip.h" + +namespace te_bench { + +#define HIP_CHECK(call) \ + do { \ + hipError_t err = call; \ + if (err != hipSuccess) { \ + fprintf(stderr, "HIP error at %s:%d: %s\n", __FILE__, __LINE__, \ + hipGetErrorString(err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +template +class DeviceBuffer { + public: + DeviceBuffer(size_t count) : count_(count) { + HIP_CHECK(hipMalloc(&ptr_, count * sizeof(T))); + } + + ~DeviceBuffer() { + if (ptr_) { + hipError_t err = hipFree(ptr_); + (void)err; + } + } + + DeviceBuffer(const DeviceBuffer &) = delete; + DeviceBuffer &operator=(const DeviceBuffer &) = delete; + + DeviceBuffer(DeviceBuffer &&other) noexcept : ptr_(other.ptr_), count_(other.count_) { + other.ptr_ = nullptr; + other.count_ = 0; + } + + T *get() { return ptr_; } + const T *get() const { return ptr_; } + size_t count() const { return count_; } + size_t bytes() const { return count_ * sizeof(T); } + + void upload(const std::vector &host_data) { + if (host_data.size() != count_) { + throw std::runtime_error("Size mismatch in upload"); + } + HIP_CHECK(hipMemcpy(ptr_, host_data.data(), bytes(), hipMemcpyHostToDevice)); + } + + void download(std::vector &host_data) const { + host_data.resize(count_); + HIP_CHECK(hipMemcpy(host_data.data(), ptr_, bytes(), hipMemcpyDeviceToHost)); + } + + private: + T *ptr_ = nullptr; + size_t count_ = 0; +}; + +template +std::vector generate_random_data(size_t count, T min_val = -1.0, T max_val = 1.0) { + std::vector data(count); + std::mt19937 gen(42); + + if constexpr (std::is_floating_point_v) { + std::uniform_real_distribution dist(min_val, max_val); + for (auto &val : data) { + val = dist(gen); + } + } else { + std::uniform_int_distribution dist(static_cast(min_val), static_cast(max_val)); + for (auto &val : data) { + val = static_cast(dist(gen)); + } + } + + return data; +} + +__global__ void scale_shift_kernel(float *data, size_t count, float scale, float offset) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < count) { + data[idx] = data[idx] * scale + offset; + } +} + +inline void fill_random_uniform_gpu(float *dptr, size_t count, float min_val = -2.0f, float max_val = 1.0f, hipStream_t stream = 0) { + hiprandGenerator_t gen; + hiprandCreateGenerator(&gen, HIPRAND_RNG_PSEUDO_DEFAULT); + hiprandSetPseudoRandomGeneratorSeed(gen, 42); + if (stream != 0) { + hiprandSetStream(gen, stream); + } + hiprandGenerateUniform(gen, dptr, count); + float scale = max_val - min_val; + float offset = min_val; + + size_t threads = 256; + size_t blocks = (count + threads - 1) / threads; + scale_shift_kernel<<>>(dptr, count, scale, offset); + + hiprandDestroyGenerator(gen); +} + +template +__global__ void cast_fp32_kernel(const float *in, T *out, size_t count) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < count) { + out[idx] = static_cast(in[idx]); + } +} + +template +inline void fill_random_uniform_gpu_typed(T *dptr, size_t count, float min_val = -2.0f, float max_val = 1.0f, hipStream_t stream = 0) { + if constexpr (std::is_same_v) { + fill_random_uniform_gpu(dptr, count, min_val, max_val, stream); + } else { + DeviceBuffer temp_fp32(count); + fill_random_uniform_gpu(temp_fp32.get(), count, min_val, max_val, stream); + + size_t threads = 256; + size_t blocks = (count + threads - 1) / 
threads; + cast_fp32_kernel<<>>(temp_fp32.get(), dptr, count); + } +} + +inline void warmup_gpu(int iterations = 10) { + DeviceBuffer dummy(1024); + for (int i = 0; i < iterations; ++i) { + HIP_CHECK(hipMemset(dummy.get(), 0, dummy.bytes())); + } + HIP_CHECK(hipDeviceSynchronize()); +} + +inline double calculate_bandwidth_gbps(size_t bytes, double time_ns) { + return (bytes / 1e9) / (time_ns / 1e9); +} + +inline void set_items_processed(benchmark::State &state, size_t items_per_iter) { + state.SetItemsProcessed(state.iterations() * items_per_iter); +} + +inline void set_bytes_processed(benchmark::State &state, size_t bytes_per_iter) { + state.SetBytesProcessed(state.iterations() * bytes_per_iter); +} + +class TensorCache { + public: + struct CacheKey { + std::string name; + size_t rows; + size_t cols; + transformer_engine::DType dtype; + bool rowwise; + bool colwise; + NVTEScalingMode scaling_mode; + + bool operator<(const CacheKey &other) const { + return std::tie(name, rows, cols, dtype, rowwise, colwise, scaling_mode) < + std::tie(other.name, other.rows, other.cols, other.dtype, other.rowwise, other.colwise, other.scaling_mode); + } + }; + + static test::Tensor &get_or_create(const std::string &name, + const std::vector &shape, + transformer_engine::DType dtype, + bool rowwise = true, + bool colwise = false, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING, + bool initialize_random = false) { + CacheKey key{name, shape[0], shape[1], dtype, rowwise, colwise, scaling_mode}; + + static auto* cache = new std::map>(); + + auto it = cache->find(key); + if (it == cache->end()) { + auto tensor_ptr = std::make_unique(name, shape, dtype, rowwise, colwise, scaling_mode); + + if (initialize_random && dtype != transformer_engine::DType::kFloat8E4M3 && + dtype != transformer_engine::DType::kFloat8E5M2) { + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + size_t count = shape[0] * shape[1]; + void *data_ptr = tensor_ptr->rowwise_dptr(); + + if (dtype == transformer_engine::DType::kFloat32) { + fill_random_uniform_gpu(static_cast(data_ptr), count, -2.0f, 1.0f, stream); + } else if (dtype == transformer_engine::DType::kFloat16) { + fill_random_uniform_gpu_typed<__half>(static_cast<__half*>(data_ptr), count, -2.0f, 1.0f, stream); + } else if (dtype == transformer_engine::DType::kBFloat16) { + fill_random_uniform_gpu_typed(static_cast(data_ptr), count, -2.0f, 1.0f, stream); + } + + HIP_CHECK(hipStreamSynchronize(stream)); + HIP_CHECK(hipStreamDestroy(stream)); + } + + (*cache)[key] = std::move(tensor_ptr); + it = cache->find(key); + } + + return *(it->second); + } +}; +} // namespace te_bench From 573f6ea681ff11c337c4a2395332f9450880f1a7 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 29 Apr 2026 16:49:18 -0500 Subject: [PATCH 37/43] build: add rocm_utils.cmake needed by benchmarks/cpp CMakeLists Co-Authored-By: Claude Sonnet 4 --- build_tools/rocm_utils.cmake | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 build_tools/rocm_utils.cmake diff --git a/build_tools/rocm_utils.cmake b/build_tools/rocm_utils.cmake new file mode 100644 index 000000000..dca794c0c --- /dev/null +++ b/build_tools/rocm_utils.cmake @@ -0,0 +1,20 @@ +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. +# License for AMD contributions = MIT. 
See LICENSE for more information + +#Determine ROCM_PATH +if(NOT "$ENV{ROCM_PATH}" STREQUAL "") + set(ROCM_PATH "$ENV{ROCM_PATH}") +elseif(EXISTS "/opt/rocm/core") + set(ROCM_PATH "/opt/rocm/core") +else() + set(ROCM_PATH "/opt/rocm") +endif() + +#Configure target GPU architectures +if(NOT DEFINED ENV{NVTE_ROCM_ARCH}) + SET(CMAKE_HIP_ARCHITECTURES gfx942 gfx950) +else() + # Accept comma separated list for NVTE_ROCM_ARCH + string(REPLACE "," ";" HIP_ARCH_LIST "$ENV{NVTE_ROCM_ARCH}") + SET(CMAKE_HIP_ARCHITECTURES ${HIP_ARCH_LIST}) +endif() From 9f340a14a0563cdaab9469c2a5d29edead8682f0 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 29 Apr 2026 22:19:46 +0000 Subject: [PATCH 38/43] fix: suppress clang warnings in Google Benchmark for gfx1250 toolchain Newer ROCm clang (required for gfx1250) classifies __COUNTER__ as a C2y extension and emits -Wc2y-extensions. Combined with benchmark's own -pedantic-errors -Werror flags this causes a build failure. Also suppress -Wunused-const-variable for benchmark.h compiled as a standalone TU. Also fetch googletest v1.14.0 via FetchContent since test_common.hip (included via benchmark_utils.h) depends on gtest/gtest.h. --- benchmarks/cpp/CMakeLists.txt | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/benchmarks/cpp/CMakeLists.txt b/benchmarks/cpp/CMakeLists.txt index 6071f9083..f53e37ab3 100644 --- a/benchmarks/cpp/CMakeLists.txt +++ b/benchmarks/cpp/CMakeLists.txt @@ -14,6 +14,15 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(HIP REQUIRED) include(FetchContent) +FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG v1.14.0 +) +set(BUILD_GMOCK OFF CACHE BOOL "" FORCE) +set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) + FetchContent_Declare( benchmark GIT_REPOSITORY https://github.com/google/benchmark.git @@ -22,6 +31,17 @@ FetchContent_Declare( set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark tests" FORCE) set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "Disable gtest in benchmark" FORCE) FetchContent_MakeAvailable(benchmark) +# Suppress clang warnings from benchmark headers that fire under -pedantic-errors +# on newer ROCm toolchains (gfx1250): __COUNTER__ classified as C2y extension, +# and kDefaultMinTimeStr triggers -Wunused-const-variable when benchmark.h is +# compiled as a standalone TU by hipcc. 
+set_target_properties(benchmark benchmark_main PROPERTIES LINKER_LANGUAGE CXX) +foreach(_bench_target benchmark benchmark_main) + target_compile_options(${_bench_target} PRIVATE + -Wno-c2y-extensions + -Wno-unused-const-variable + ) +endforeach() set(TESTS_CPP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../tests/cpp) @@ -77,6 +97,7 @@ function(add_te_benchmark TARGET_NAME SOURCE_FILE) target_compile_definitions(${TARGET_NAME} PRIVATE NVTE_ROCM_BENCHMARK) target_link_libraries(${TARGET_NAME} PRIVATE benchmark::benchmark + GTest::gtest ${TRANSFORMER_ENGINE_LIB} hiprand ) From 09ed78cbab7c2c2e97c61e28a33601289356e0ca Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 29 Apr 2026 17:40:21 -0500 Subject: [PATCH 39/43] test: remove 16384x16384 from gated swiglu test (causes CPU ref hang) --- tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 6bddf9c49..52180786d 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -446,7 +446,6 @@ std::vector> matrix_sizes = { {768, 1024}, {8192, 128}, {577, 1632}, - {16384, 16384}, }; std::vector> block_sizes = { From b02fe7639850ba3485cbaf25c9386122bfd32087 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 29 Apr 2026 21:53:27 -0500 Subject: [PATCH 40/43] fix(rocm): remove TDM debug prints and fix NVTE_ROCM_BENCHMARK guards for benchmark build --- benchmarks/cpp/CMakeLists.txt | 10 ------- .../cpp/cast/bench_quantize_mxfp8_fused.cpp | 29 ++++++++++++++----- benchmarks/cpp/utils/benchmark_utils.h | 5 ++-- tests/cpp/test_common.cu | 27 +++++++++++------ .../common/util/cast_gated_kernels.cuh | 18 ++++-------- .../common/util/dequantize_kernels.cuh | 9 ++---- .../common/util/rocm_cast_kernels.cuh | 14 +++------ .../common/util/rocm_dequantize_kernels.cuh | 4 +-- 8 files changed, 56 insertions(+), 60 deletions(-) diff --git a/benchmarks/cpp/CMakeLists.txt b/benchmarks/cpp/CMakeLists.txt index f53e37ab3..dfb9ca3ec 100644 --- a/benchmarks/cpp/CMakeLists.txt +++ b/benchmarks/cpp/CMakeLists.txt @@ -14,15 +14,6 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(HIP REQUIRED) include(FetchContent) -FetchContent_Declare( - googletest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG v1.14.0 -) -set(BUILD_GMOCK OFF CACHE BOOL "" FORCE) -set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) -FetchContent_MakeAvailable(googletest) - FetchContent_Declare( benchmark GIT_REPOSITORY https://github.com/google/benchmark.git @@ -97,7 +88,6 @@ function(add_te_benchmark TARGET_NAME SOURCE_FILE) target_compile_definitions(${TARGET_NAME} PRIVATE NVTE_ROCM_BENCHMARK) target_link_libraries(${TARGET_NAME} PRIVATE benchmark::benchmark - GTest::gtest ${TRANSFORMER_ENGINE_LIB} hiprand ) diff --git a/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp b/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp index 6f5540714..7bdb2ac10 100644 --- a/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp +++ b/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp @@ -6,6 +6,23 @@ #include #include +#include +#include +#include +static void segv_handler(int sig) { + void *buf[64]; + int n = backtrace(buf, 64); + fprintf(stderr, "[BACKTRACE] signal %d:\n", sig); + backtrace_symbols_fd(buf, n, STDERR_FILENO); + _exit(1); +} +__attribute__((constructor)) static void install_handler() { + struct sigaction sa{}; + sa.sa_handler = segv_handler; + sigemptyset(&sa.sa_mask); + sa.sa_flags = 
SA_RESETHAND; + sigaction(SIGSEGV, &sa, nullptr); +} #include #include #include "amd_detail/hip_float8.h" @@ -75,12 +92,11 @@ static void BM_QuantizeMXFP8_Fused(benchmark::State &state) { (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); DType otype = std::is_same_v ? DType::kFloat8E4M3 : DType::kFloat8E5M2; - test::Tensor &input_tensor = TensorCache::get_or_create("input", shape, itype, true, false, +test::Tensor &input_tensor = TensorCache::get_or_create("input", shape, itype, true, false, NVTE_DELAYED_TENSOR_SCALING, true); - test::Tensor &output_tensor = TensorCache::get_or_create("output", shape, otype, USE_ROWWISE, USE_COLWISE, +test::Tensor &output_tensor = TensorCache::get_or_create("output", shape, otype, USE_ROWWISE, USE_COLWISE, NVTE_MXFP8_1D_SCALING, false); - - test::Tensor *grad_tensor_ptr = nullptr, *dbias_tensor_ptr = nullptr, *workspace_tensor_ptr = nullptr; +test::Tensor *grad_tensor_ptr = nullptr, *dbias_tensor_ptr = nullptr, *workspace_tensor_ptr = nullptr; if constexpr (PROC_METHOD == CAST_DBIAS || PROC_METHOD == CAST_DBIAS_DACT) { std::vector bias_shape = {cols}; @@ -102,9 +118,8 @@ static void BM_QuantizeMXFP8_Fused(benchmark::State &state) { HIP_CHECK(hipEventCreate(&start)); HIP_CHECK(hipEventCreate(&stop)); - warmup_gpu(); - - for (auto _ : state) { +warmup_gpu(); +for (auto _ : state) { HIP_CHECK(hipEventRecord(start, stream)); if constexpr (PROC_METHOD == CAST_ONLY) { diff --git a/benchmarks/cpp/utils/benchmark_utils.h b/benchmarks/cpp/utils/benchmark_utils.h index bd2906b30..7473280f0 100644 --- a/benchmarks/cpp/utils/benchmark_utils.h +++ b/benchmarks/cpp/utils/benchmark_utils.h @@ -191,9 +191,8 @@ class TensorCache { auto it = cache->find(key); if (it == cache->end()) { - auto tensor_ptr = std::make_unique(name, shape, dtype, rowwise, colwise, scaling_mode); - - if (initialize_random && dtype != transformer_engine::DType::kFloat8E4M3 && +auto tensor_ptr = std::make_unique(name, shape, dtype, rowwise, colwise, scaling_mode); +if (initialize_random && dtype != transformer_engine::DType::kFloat8E4M3 && dtype != transformer_engine::DType::kFloat8E5M2) { hipStream_t stream; HIP_CHECK(hipStreamCreate(&stream)); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 7a89148fd..20b023367 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -17,7 +17,9 @@ #include #include +#ifndef NVTE_ROCM_BENCHMARK #include +#endif #include #include @@ -28,9 +30,13 @@ namespace test { size_t create_seed_from_tensor_name(const std::string& tensor_name) { +#ifndef NVTE_ROCM_BENCHMARK auto full_name = std::string(testing::UnitTest::GetInstance()->current_test_info()->name()) + "/" + tensor_name; return std::hash{}(full_name); +#else + return std::hash{}(tensor_name); +#endif } std::vector all_fp_types = {DType::kFloat32, @@ -229,13 +235,13 @@ Tensor::Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise, const bool columnwise, const NVTEScalingMode &scaling_mode) { - name_ = name; +name_ = name; const size_t seed = create_seed_from_tensor_name(name); gen_.seed(seed); rowwise_ = rowwise; columnwise_ = columnwise; - size_t total_size = bytes(shape, type); - void *dptr_rowwise = nullptr; +size_t total_size = bytes(shape, type); +void *dptr_rowwise = nullptr; void *dptr_columnwise = nullptr; cpu_data_rowwise_ = nullptr; cpu_data_columnwise_ = nullptr; @@ -251,7 +257,7 @@ Tensor::Tensor(const std::string& name, std::vector normalized_shape_v = {product(shape, 0, shape.ndim - 1), shape.data[shape.ndim - 1]}; 
NVTEShape normalized_shape = convertShape(normalized_shape_v); - NVTEShape columnwise_shape = {}; +NVTEShape columnwise_shape = {}; std::vector columnwise_shape_vec; if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { @@ -271,12 +277,11 @@ Tensor::Tensor(const std::string& name, columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size()); } - tensor_ = TensorWrapper(scaling_mode); - - if (total_size != 0) { +tensor_ = TensorWrapper(scaling_mode); +if (total_size != 0) { if (rowwise) { - (void)cudaMalloc((void**)&dptr_rowwise, total_size); // NOLINT(*) - (void)cudaMemset(dptr_rowwise, 0, total_size); +(void)cudaMalloc((void**)&dptr_rowwise, total_size); // NOLINT(*) +(void)cudaMemset(dptr_rowwise, 0, total_size); cpu_data_rowwise_ = std::make_unique(total_size); std::fill_n(cpu_data_rowwise_.get(), total_size, 0); } @@ -528,6 +533,8 @@ std::vector unravel(const size_t i, const NVTEShape &shape) { return ret; } +#ifndef NVTE_ROCM_BENCHMARK + void compareResults_sequential(const std::string &name, const Tensor &test, const void *ref, const bool rowwise, double atol, double rtol, bool if_on_gpus, @@ -770,6 +777,8 @@ void adjust_ref_for_e8m0_scale_error(const std::string &name, } #endif // #ifdef __HIP_PLATFORM_AMD__ +#endif // NVTE_ROCM_BENCHMARK + std::pair getTolerances(const DType type) { switch(type) { case DType::kFloat32: diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 1afdafdf6..3786bb76b 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -1573,19 +1573,16 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow_fp8) { - fprintf(stderr, "[DBG gated delayed_scaling] gfx1250 TDM -> cast_fp8_gated\n"); - cast_fp8_gated(grad, gated_input, output, stream); +cast_fp8_gated(grad, gated_input, output, stream); } else { - fprintf(stderr, "[DBG gated delayed_scaling] gfx1250 ROCm -> cast_gated/cast_dgated\n"); - if constexpr (IS_DGATED) { +if constexpr (IS_DGATED) { cast_dgated(grad, gated_input, output, stream); } else { cast_gated(gated_input, output, stream); } } #elif defined(__HIP_PLATFORM_AMD__) - fprintf(stderr, "[DBG gated delayed_scaling] non-gfx1250 AMD -> cast_gated/cast_dgated\n"); - if constexpr (IS_DGATED) { +if constexpr (IS_DGATED) { cast_dgated(grad, gated_input, output, stream); } else { cast_gated(gated_input, output, stream); @@ -1609,15 +1606,12 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { - fprintf(stderr, "[DBG gated mxfp_scaling] gfx1250 TDM -> cast_mxfp8_gated\n"); - cast_mxfp8_gated(grad, gated_input, output, stream); +cast_mxfp8_gated(grad, gated_input, output, stream); } else { - fprintf(stderr, "[DBG gated mxfp_scaling] gfx1250 ROCm -> rocm_cast_mxfp8_gated\n"); - rocm_cast_mxfp8_gated(grad, gated_input, output, stream); +rocm_cast_mxfp8_gated(grad, gated_input, output, stream); } #elif defined(__HIP_PLATFORM_AMD__) - fprintf(stderr, "[DBG gated mxfp_scaling] non-gfx1250 AMD -> rocm_cast_mxfp8_gated\n"); - rocm_cast_mxfp8_gated(grad, gated_input, output, stream); +rocm_cast_mxfp8_gated(grad, gated_input, output, stream); #else cast_mxfp8_gated(grad, gated_input, output, 
stream); #endif diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index df1199db4..fc89bfef6 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -427,15 +427,12 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { - fprintf(stderr, "[DBG dequantize_helper] gfx1250 TDM -> mxfp8_dequantize\n"); - dequantization::mxfp8_dequantize(input, output, stream); +dequantization::mxfp8_dequantize(input, output, stream); } else { - fprintf(stderr, "[DBG dequantize_helper] gfx1250 ROCm -> rocm_mxfp8_dequantize\n"); - rocm_mxfp8_dequantize(input, output, stream); +rocm_mxfp8_dequantize(input, output, stream); } #elif defined(__HIP_PLATFORM_AMD__) - fprintf(stderr, "[DBG dequantize_helper] non-gfx1250 AMD -> rocm_mxfp8_dequantize\n"); - rocm_mxfp8_dequantize(input, output, stream); +rocm_mxfp8_dequantize(input, output, stream); #else if (is_supported_by_CC_100()) { dequantization::mxfp8_dequantize(input, output, stream); diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index 9d48ffb15..808569121 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -557,17 +557,14 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { - fprintf(stderr, "[DBG fp8_quantize_rocm] gfx1250 TDM branch -> mxfp8_quantize\n"); - mxfp8_quantize(input, act_input, noop, output, +mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); } else { - fprintf(stderr, "[DBG fp8_quantize_rocm] gfx1250 ROCm branch -> rocm_mxfp8_quantize\n"); - rocm_mxfp8_quantize(input, act_input, noop, output, +rocm_mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); } #else - fprintf(stderr, "[DBG fp8_quantize_rocm] non-gfx1250 AMD -> rocm_mxfp8_quantize\n"); - rocm_mxfp8_quantize(input, act_input, noop, output, +rocm_mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); #endif break; @@ -587,10 +584,7 @@ void rocm_mxfp8_quantize(const Tensor &input, const Tensor *act_input, const Ten const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - fprintf(stderr, "[DBG rocm_mxfp8_quantize] rows=%zu cols=%zu — launching cast_mxfp8_2D_kernel\n", - rows, cols); - - const size_t blocks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); +const size_t blocks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); const dim3 grid(blocks_X, blocks_Y); const size_t block_size = MXFP8_THREADS_PER_CHUNK; diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh index abcb2a0d6..71dc115ed 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh @@ -147,9 +147,7 @@ static void rocm_mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStrea const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - fprintf(stderr, "[DBG rocm_mxfp8_dequantize] rows=%zu cols=%zu — launching dequantize_mxfp8_kernel\n", - rows, cols); - const size_t chunks_Y = DIVUP(rows, 
ROCM_CHUNK_DIM_Y); +const size_t chunks_Y = DIVUP(rows, ROCM_CHUNK_DIM_Y); const size_t chunks_X = DIVUP(cols, ROCM_CHUNK_DIM_X); const size_t scales_X_rowwise = From 186d793c244df7412e6a880c745b633b8cb1388a Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 29 Apr 2026 22:01:16 -0500 Subject: [PATCH 41/43] fix(rocm): restore indentation lost during debug print removal --- .../common/util/cast_gated_kernels.cuh | 12 ++++++------ .../common/util/dequantize_kernels.cuh | 6 +++--- transformer_engine/common/util/rocm_cast_kernels.cuh | 2 +- .../common/util/rocm_dequantize_kernels.cuh | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 3786bb76b..83993b631 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -1573,16 +1573,16 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow_fp8) { -cast_fp8_gated(grad, gated_input, output, stream); + cast_fp8_gated(grad, gated_input, output, stream); } else { -if constexpr (IS_DGATED) { + if constexpr (IS_DGATED) { cast_dgated(grad, gated_input, output, stream); } else { cast_gated(gated_input, output, stream); } } #elif defined(__HIP_PLATFORM_AMD__) -if constexpr (IS_DGATED) { + if constexpr (IS_DGATED) { cast_dgated(grad, gated_input, output, stream); } else { cast_gated(gated_input, output, stream); @@ -1606,12 +1606,12 @@ if constexpr (IS_DGATED) { cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { -cast_mxfp8_gated(grad, gated_input, output, stream); + cast_mxfp8_gated(grad, gated_input, output, stream); } else { -rocm_cast_mxfp8_gated(grad, gated_input, output, stream); + rocm_cast_mxfp8_gated(grad, gated_input, output, stream); } #elif defined(__HIP_PLATFORM_AMD__) -rocm_cast_mxfp8_gated(grad, gated_input, output, stream); + rocm_cast_mxfp8_gated(grad, gated_input, output, stream); #else cast_mxfp8_gated(grad, gated_input, output, stream); #endif diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index fc89bfef6..09cc32dfc 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -427,12 +427,12 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { -dequantization::mxfp8_dequantize(input, output, stream); + dequantization::mxfp8_dequantize(input, output, stream); } else { -rocm_mxfp8_dequantize(input, output, stream); + rocm_mxfp8_dequantize(input, output, stream); } #elif defined(__HIP_PLATFORM_AMD__) -rocm_mxfp8_dequantize(input, output, stream); + rocm_mxfp8_dequantize(input, output, stream); #else if (is_supported_by_CC_100()) { dequantization::mxfp8_dequantize(input, output, stream); diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index 808569121..05180d5b8 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -584,7 +584,7 @@ void rocm_mxfp8_quantize(const Tensor &input, const Tensor *act_input, const Ten const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); -const 
size_t blocks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); + const size_t blocks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); const dim3 grid(blocks_X, blocks_Y); const size_t block_size = MXFP8_THREADS_PER_CHUNK; diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh index 71dc115ed..17596a67b 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh @@ -147,7 +147,7 @@ static void rocm_mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStrea const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); -const size_t chunks_Y = DIVUP(rows, ROCM_CHUNK_DIM_Y); + const size_t chunks_Y = DIVUP(rows, ROCM_CHUNK_DIM_Y); const size_t chunks_X = DIVUP(cols, ROCM_CHUNK_DIM_X); const size_t scales_X_rowwise = From 98627456ae7d1769aaac44fc9a29c71d191bbf00 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 29 Apr 2026 22:03:46 -0500 Subject: [PATCH 42/43] fix(rocm): remove segfault debug handler and restore indentation in benchmark/test files --- .../cpp/cast/bench_quantize_mxfp8_fused.cpp | 27 ++++--------------- benchmarks/cpp/utils/benchmark_utils.h | 4 +-- tests/cpp/test_common.cu | 16 +++++------ 3 files changed, 15 insertions(+), 32 deletions(-) diff --git a/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp b/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp index 7bdb2ac10..926ec32d3 100644 --- a/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp +++ b/benchmarks/cpp/cast/bench_quantize_mxfp8_fused.cpp @@ -6,23 +6,6 @@ #include #include -#include -#include -#include -static void segv_handler(int sig) { - void *buf[64]; - int n = backtrace(buf, 64); - fprintf(stderr, "[BACKTRACE] signal %d:\n", sig); - backtrace_symbols_fd(buf, n, STDERR_FILENO); - _exit(1); -} -__attribute__((constructor)) static void install_handler() { - struct sigaction sa{}; - sa.sa_handler = segv_handler; - sigemptyset(&sa.sa_mask); - sa.sa_flags = SA_RESETHAND; - sigaction(SIGSEGV, &sa, nullptr); -} #include #include #include "amd_detail/hip_float8.h" @@ -92,11 +75,11 @@ static void BM_QuantizeMXFP8_Fused(benchmark::State &state) { (std::is_same_v ? DType::kBFloat16 : DType::kFloat32); DType otype = std::is_same_v ? 
DType::kFloat8E4M3 : DType::kFloat8E5M2; -test::Tensor &input_tensor = TensorCache::get_or_create("input", shape, itype, true, false, + test::Tensor &input_tensor = TensorCache::get_or_create("input", shape, itype, true, false, NVTE_DELAYED_TENSOR_SCALING, true); -test::Tensor &output_tensor = TensorCache::get_or_create("output", shape, otype, USE_ROWWISE, USE_COLWISE, + test::Tensor &output_tensor = TensorCache::get_or_create("output", shape, otype, USE_ROWWISE, USE_COLWISE, NVTE_MXFP8_1D_SCALING, false); -test::Tensor *grad_tensor_ptr = nullptr, *dbias_tensor_ptr = nullptr, *workspace_tensor_ptr = nullptr; + test::Tensor *grad_tensor_ptr = nullptr, *dbias_tensor_ptr = nullptr, *workspace_tensor_ptr = nullptr; if constexpr (PROC_METHOD == CAST_DBIAS || PROC_METHOD == CAST_DBIAS_DACT) { std::vector bias_shape = {cols}; @@ -118,8 +101,8 @@ test::Tensor *grad_tensor_ptr = nullptr, *dbias_tensor_ptr = nullptr, *workspac HIP_CHECK(hipEventCreate(&start)); HIP_CHECK(hipEventCreate(&stop)); -warmup_gpu(); -for (auto _ : state) { + warmup_gpu(); + for (auto _ : state) { HIP_CHECK(hipEventRecord(start, stream)); if constexpr (PROC_METHOD == CAST_ONLY) { diff --git a/benchmarks/cpp/utils/benchmark_utils.h b/benchmarks/cpp/utils/benchmark_utils.h index 7473280f0..483d978ca 100644 --- a/benchmarks/cpp/utils/benchmark_utils.h +++ b/benchmarks/cpp/utils/benchmark_utils.h @@ -191,8 +191,8 @@ class TensorCache { auto it = cache->find(key); if (it == cache->end()) { -auto tensor_ptr = std::make_unique(name, shape, dtype, rowwise, colwise, scaling_mode); -if (initialize_random && dtype != transformer_engine::DType::kFloat8E4M3 && + auto tensor_ptr = std::make_unique(name, shape, dtype, rowwise, colwise, scaling_mode); + if (initialize_random && dtype != transformer_engine::DType::kFloat8E4M3 && dtype != transformer_engine::DType::kFloat8E5M2) { hipStream_t stream; HIP_CHECK(hipStreamCreate(&stream)); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 20b023367..bee129e85 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -235,13 +235,13 @@ Tensor::Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise, const bool columnwise, const NVTEScalingMode &scaling_mode) { -name_ = name; + name_ = name; const size_t seed = create_seed_from_tensor_name(name); gen_.seed(seed); rowwise_ = rowwise; columnwise_ = columnwise; -size_t total_size = bytes(shape, type); -void *dptr_rowwise = nullptr; + size_t total_size = bytes(shape, type); + void *dptr_rowwise = nullptr; void *dptr_columnwise = nullptr; cpu_data_rowwise_ = nullptr; cpu_data_columnwise_ = nullptr; @@ -257,7 +257,7 @@ void *dptr_rowwise = nullptr; std::vector normalized_shape_v = {product(shape, 0, shape.ndim - 1), shape.data[shape.ndim - 1]}; NVTEShape normalized_shape = convertShape(normalized_shape_v); -NVTEShape columnwise_shape = {}; + NVTEShape columnwise_shape = {}; std::vector columnwise_shape_vec; if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { @@ -277,11 +277,11 @@ NVTEShape columnwise_shape = {}; columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size()); } -tensor_ = TensorWrapper(scaling_mode); -if (total_size != 0) { + tensor_ = TensorWrapper(scaling_mode); + if (total_size != 0) { if (rowwise) { -(void)cudaMalloc((void**)&dptr_rowwise, total_size); // NOLINT(*) -(void)cudaMemset(dptr_rowwise, 0, total_size); + (void)cudaMalloc((void**)&dptr_rowwise, 
total_size); // NOLINT(*) + (void)cudaMemset(dptr_rowwise, 0, total_size); cpu_data_rowwise_ = std::make_unique(total_size); std::fill_n(cpu_data_rowwise_.get(), total_size, 0); } From 1afc5b2cef2a061576fb0e1e8cb78d6b68a2f4e5 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 29 Apr 2026 22:07:25 -0500 Subject: [PATCH 43/43] fix(rocm): restore indentation in fp8_quantize_rocm TDM branch --- transformer_engine/common/util/rocm_cast_kernels.cuh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index 05180d5b8..08a76d46a 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -557,14 +557,14 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso cuda::sm_arch_name().find("gfx1250") != std::string::npos; }(); if (use_tdm_flow) { -mxfp8_quantize(input, act_input, noop, output, - dbias, workspace, stream); + mxfp8_quantize(input, act_input, noop, output, + dbias, workspace, stream); } else { -rocm_mxfp8_quantize(input, act_input, noop, output, - dbias, workspace, stream); + rocm_mxfp8_quantize(input, act_input, noop, output, + dbias, workspace, stream); } #else -rocm_mxfp8_quantize(input, act_input, noop, output, + rocm_mxfp8_quantize(input, act_input, noop, output, dbias, workspace, stream); #endif break;
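
Note on the dispatch idiom used in patches 40-43 above: the quantize, gated, and dequantize helpers all pick between the TDM (gfx1250) kernel path and the generic ROCm path by probing the device architecture string once and caching the result in a function-local static. The sketch below is illustrative only and is not the library code: sm_arch_name() is a hypothetical stand-in returning a fixed string so the snippet compiles on its own, whereas the real helper in transformer_engine queries the active HIP device.

// Minimal, self-contained sketch of the gfx1250 arch-dispatch pattern (assumptions noted above).
#include <cstdio>
#include <string>

namespace sketch {

// Hypothetical stand-in for cuda::sm_arch_name(); the real helper reads the device arch.
inline const std::string &sm_arch_name() {
  static const std::string arch = "gfx1250";
  return arch;
}

inline bool use_tdm_flow() {
  // Evaluated once per process, mirroring the `static const bool flag = []{...}()` idiom
  // used in cast_gated_kernels.cuh, dequantize_kernels.cuh, and rocm_cast_kernels.cuh.
  static const bool flag = sm_arch_name().find("gfx1250") != std::string::npos;
  return flag;
}

inline void quantize_dispatch() {
  if (use_tdm_flow()) {
    std::puts("TDM (tensor-descriptor) kernel path");  // e.g. mxfp8_quantize(...)
  } else {
    std::puts("generic ROCm kernel path");             // e.g. rocm_mxfp8_quantize(...)
  }
}

}  // namespace sketch

int main() { sketch::quantize_dispatch(); }

Because the flag is a function-local static, the architecture string is only inspected on the first call; subsequent kernel launches pay no lookup cost, which matches why the patches wrap the check in an immediately-invoked lambda rather than re-querying per call.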