57 commits
d954c6d
Typo fix (#397)
Micky774 Dec 6, 2025
7b5cf20
ROCm UserBuffers for Comm Overlap
alextmagro Oct 30, 2025
640f7e8
Copyrights and cleanup
alextmagro Jan 27, 2026
82faeec
test guards
alextmagro Feb 24, 2026
b6a3ae4
Cleanup and RS flag race condition fix
alextmagro Mar 2, 2026
9e32d3a
Debugging midpoint
alextmagro Mar 12, 2026
84209ad
Cleanup and workspace fix
alextmagro Mar 17, 2026
c669bd2
Guard layer registration in UB
alextmagro Mar 17, 2026
8040909
Cleanup of profiling example for rocm
alextmagro Mar 17, 2026
e375923
Readd example script and update custom_map
alextmagro Mar 17, 2026
c6bd974
fix typo
alextmagro Mar 17, 2026
d76aa06
MI300 test skips due to jittery results
alextmagro Mar 18, 2026
ae979d0
Comment regarding sm_margin performance
alextmagro Mar 18, 2026
b58cbd1
Variable renamed, pybind fix, tolerance tightening
alextmagro Mar 23, 2026
e5d7446
Remove git conflict
alextmagro Mar 24, 2026
7734ce5
Address style and hip/cu specific paths
alextmagro Mar 26, 2026
c169c75
HIP guards
alextmagro Mar 27, 2026
80e0aab
initial impl
matthiasdiener Mar 27, 2026
de7863a
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Mar 27, 2026
bda7b13
test update
matthiasdiener Mar 30, 2026
7ddb539
Update extensions.h
alextmagro Mar 30, 2026
63c7a48
amax opt
matthiasdiener Mar 30, 2026
a260459
simplify
matthiasdiener Mar 30, 2026
3dd8af9
Merge pull request #367 from ROCm/userbuffer_epic
alextmagro Mar 31, 2026
ab217cb
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Mar 31, 2026
26c5fb7
simplify pt 2
matthiasdiener Mar 31, 2026
2087f24
expand test
matthiasdiener Mar 31, 2026
05cedb7
compute amax from BF16-rounded outputs
matthiasdiener Mar 31, 2026
67b93a8
TE building over TheRock (#511)
ipanfilo Apr 1, 2026
465d547
Typo fix (#397)
Micky774 Dec 6, 2025
9fb21f9
Add NVTE_UB_WITH_MPI to rocm build path
alextmagro Apr 1, 2026
2f66594
Merge pull request #513 from ROCm/ub_mpi_hotfix
alextmagro Apr 1, 2026
986d8ba
NVFP4: hadamard_transform_cast_fusion_columnwise
matthiasdiener Apr 1, 2026
b339c86
unify hadamard_transform_cast_fusion_columnwise
matthiasdiener Apr 1, 2026
f74a0ab
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Apr 1, 2026
e9426cd
Merge remote-tracking branch 'origin/dev' into mdiener/nvfp4-cast_fusion
matthiasdiener Apr 1, 2026
e3a2502
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Apr 2, 2026
3a63f32
add explanation to wht16
matthiasdiener Apr 2, 2026
17d50ee
Merge branch 'dev' into mdiener/fp4_hadamard
matthiasdiener Apr 6, 2026
e32a758
merge errors
matthiasdiener Apr 6, 2026
4857721
Merge branch 'dev' into mdiener/fp4_hadamard
matthiasdiener Apr 6, 2026
b243b4c
merge
matthiasdiener Apr 6, 2026
6527004
Merge branch 'dev' into mdiener/fp4_hadamard
matthiasdiener Apr 7, 2026
ca1aacf
change to __builtin_bit_cast
matthiasdiener Apr 7, 2026
bc9f0a3
remove copyright header
matthiasdiener Apr 8, 2026
9f1851d
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Apr 9, 2026
739a20d
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Apr 13, 2026
f269097
enable tests
matthiasdiener Apr 13, 2026
346beb1
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Apr 16, 2026
cf2c8f6
address reviewer comments
matthiasdiener Apr 16, 2026
2772834
minor fixes
matthiasdiener Apr 16, 2026
26c5cb1
PreRhtAmax optimizations
matthiasdiener Apr 16, 2026
071aa4b
Merge branch 'mdiener/fp4_hadamard' into mdiener/nvfp4-cast_fusion
matthiasdiener Apr 17, 2026
018d24f
use ZeroAmaxKernel
matthiasdiener Apr 17, 2026
3efd532
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Apr 17, 2026
b835818
Merge branch 'mdiener/fp4_hadamard' into mdiener/nvfp4-cast_fusion
matthiasdiener Apr 17, 2026
c723ccc
Merge remote-tracking branch 'origin/dev' into mdiener/nvfp4-cast_fusion
matthiasdiener Apr 17, 2026
2 changes: 1 addition & 1 deletion transformer_engine/common/CMakeLists.txt
@@ -222,6 +222,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
activation/swiglu.cu
cast/cast.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu)
@@ -244,7 +245,6 @@ if(USE_CUDA)
gemm/cutlass_grouped_gemm.cu
hadamard_transform/group_hadamard_transform.cu
transpose/quantize_transpose_square_blockwise.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
hadamard_transform/group_hadamard_transform_cast_fusion.cu
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
transpose/quantize_transpose_square_blockwise.cu
@@ -166,6 +166,7 @@ __device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float
}
}
}
#endif

__launch_bounds__(1) __global__ void ZeroAmaxKernel(float* __restrict__ output_pre_rht_amax_ptr,
float* __restrict__ output_identity_amax_ptr,
@@ -181,6 +182,7 @@ __launch_bounds__(1) __global__ void ZeroAmaxKernel(float* __restrict__ output_p
}
}

#ifndef __HIP_PLATFORM_AMD__
template <typename IType, int kHadamardDimension, int CHUNK_DIM_Y, int CHUNK_DIM_X, int BUFF_DIM_Y,
int BUFF_DIM_X, int THREADS_PER_CHUNK, int THREADS_PER_Y, bool kReturnPreRhtAmax,
bool kReturnIdentityAmax, bool kReturnTransposedAmax>
@@ -978,15 +980,8 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran

auto* in_ptr = reinterpret_cast<const __hip_bfloat16*>(input.dptr);

if (pre_amax_ptr) {
NVTE_CHECK_CUDA(cudaMemsetAsync(pre_amax_ptr, 0, sizeof(float), stream));
}
if (id_amax_ptr) {
NVTE_CHECK_CUDA(cudaMemsetAsync(id_amax_ptr, 0, sizeof(float), stream));
}
if (tr_amax_ptr) {
NVTE_CHECK_CUDA(cudaMemsetAsync(tr_amax_ptr, 0, sizeof(float), stream));
}
ZeroAmaxKernel<<<1, 1, 0, stream>>>(pre_amax_ptr, id_amax_ptr, tr_amax_ptr);
NVTE_CHECK_CUDA(cudaGetLastError());
#else
constexpr int kHadamardDimension = 16;
NVTE_CHECK(row_length % kHadamardDimension == 0,
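The hunk above replaces three separate cudaMemsetAsync calls with a single one-thread kernel launch, so all requested amax outputs are zeroed in one stream operation instead of three. A minimal sketch of that pattern, assuming parameter roles from the visible signature and null-pointer guards for outputs the caller did not request (the real body is collapsed above, and the name here is hypothetical):

__launch_bounds__(1) __global__ void ZeroAmaxKernelSketch(float* __restrict__ pre_rht_amax,
                                                          float* __restrict__ identity_amax,
                                                          float* __restrict__ transposed_amax) {
  // Zero only the amax outputs the caller actually requested.
  if (pre_rht_amax != nullptr) *pre_rht_amax = 0.f;
  if (identity_amax != nullptr) *identity_amax = 0.f;
  if (transposed_amax != nullptr) *transposed_amax = 0.f;
}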
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
@@ -9,19 +11,24 @@
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cutlass/arch/barrier.h>
#endif
#include <transformer_engine/hadamard_transform.h>

#include <cuda/barrier>
#ifndef __HIP_PLATFORM_AMD__
#include <cute/algorithm/gemm.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#endif

#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/curanddx.hpp"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#ifndef __HIP_PLATFORM_AMD__
#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/builders/sm100_common.inl"
@@ -30,9 +37,11 @@
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/print_error.hpp"
#endif

// clang-format off

#ifndef __HIP_PLATFORM_AMD__
namespace transformer_engine {
namespace detail {
namespace {
@@ -726,6 +735,210 @@ rht_gemm_ttt_wrapper(int m, int n,

// clang-format on

} // namespace transformer_engine
#else

#include "wht16.cuh"

namespace transformer_engine {

namespace {

__device__ __forceinline__ float to_f32(__hip_bfloat16 v) { return static_cast<float>(v); }

__device__ __forceinline__ float group_max_4(float v) {
v = fmaxf(v, ds_swizzle_xor1(v));
v = fmaxf(v, ds_swizzle_xor2(v));
return v;
}
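
// Butterfly reduction over each aligned group of four lanes, assuming the
// ds_swizzle_xorN helpers from wht16.cuh exchange a value with lane ^ N:
// xor1 merges lane pairs (0,1) and (2,3), xor2 then merges the two pairs,
// so every lane in the group ends up holding the group-wide maximum.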

__device__ __forceinline__ float compute_global_encode_scale_fp4(const float global_amax) {
#if !defined(__HIP_DEVICE_COMPILE__)
const float fp8_max = detail::TypeExtrema<fp8e4m3>::max;
#else
constexpr float fp8_max = detail::TypeExtrema<fp8e4m3>::max;
#endif
constexpr float fp4_max = detail::TypeExtrema<fp4e2m1>::max;
float global_encode_scale = fp8_max * fp4_max / global_amax;
global_encode_scale = fminf(global_encode_scale, detail::TypeExtrema<float>::max);
return (global_amax == 0.f || global_encode_scale == 0.f) ? 1.f : global_encode_scale;
}

template <typename ScaleType>
__device__ __forceinline__ ScaleType compute_decode_scale_fp4(const float amax,
const float global_encode_scale) {
float decode_scale = amax / detail::TypeExtrema<fp4e2m1>::max;
decode_scale *= global_encode_scale;
decode_scale = fminf(decode_scale, detail::TypeExtrema<float>::max);
return static_cast<ScaleType>(decode_scale);
}

template <typename ScaleType>
__device__ __forceinline__ float compute_encode_scale_fp4(ScaleType decode_scale,
const float global_decode_scale) {
return fminf(1.0f / (static_cast<float>(decode_scale) * global_decode_scale),
detail::TypeExtrema<float>::max);
}
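
// Worked example of the two-level scale chain above, using the extrema the
// code relies on (fp8e4m3 max = 448, fp4e2m1 max = 6); the concrete values
// are illustrative only:
//   global_amax = 448 -> global_encode_scale = 448 * 6 / 448 = 6
//   block amax  = 3   -> decode_scale = (3 / 6) * 6 = 3, stored as fp8e4m3
//   encode_scale = 1 / (3 * (1 / 6)) = 2
// A block element equal to the block amax (3) encodes to 3 * 2 = 6, i.e.
// exactly fp4_max, and decodes back via 6 * 3 * (1 / 6) = 3.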

__device__ __forceinline__ uint32_t get_rbits(
transformer_engine::curanddx::detail::philox4x32_native_state<10>& rng, uint4& random_uint4,
int& rnd_idx) {
if (rnd_idx == 4) {
rnd_idx = 0;
random_uint4 = rng.generate4();
}
const uint32_t* const rbits_arr = reinterpret_cast<uint32_t*>(&random_uint4);
return rbits_arr[rnd_idx++];
}

template <bool kUseStochasticRounding>
__device__ __forceinline__ fp4e2m1x4 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
const uint32_t rbits) {
if constexpr (kUseStochasticRounding) {
#if ARCH_HAS_STOCHASTIC_ROUNDING
union {
uint32_t ui32;
__hip_fp4x2_storage_t fp4x2[4];
} packed{0};
__amd_floatx2_storage_t packed01{in01.x, in01.y};
__amd_floatx2_storage_t packed23{in23.x, in23.y};
packed.ui32 =
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(packed.ui32, packed01, rbits, 1.0f, 1);
const __hip_fp4x2_storage_t lo = packed.fp4x2[1];
packed.ui32 =
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(packed.ui32, packed23, rbits, 1.0f, 1);
const __hip_fp4x2_storage_t hi = packed.fp4x2[1];

fp4e2m1x4 result;
result.__x = static_cast<__hip_fp4x4_storage_t>(
lo | (static_cast<__hip_fp4x4_storage_t>(hi) << 8));
return result;
#else
NVTE_DEVICE_ERROR("FP4 stochastic rounding on AMDGPU requires gfx950 or later.");
return fp4e2m1x4{};
#endif
} else {
const __hip_fp4_storage_t q0 =
__hip_cvt_float_to_fp4(in01.x, __HIP_E2M1, hipRoundNearest);
const __hip_fp4_storage_t q1 =
__hip_cvt_float_to_fp4(in01.y, __HIP_E2M1, hipRoundNearest);
const __hip_fp4_storage_t q2 =
__hip_cvt_float_to_fp4(in23.x, __HIP_E2M1, hipRoundNearest);
const __hip_fp4_storage_t q3 =
__hip_cvt_float_to_fp4(in23.y, __HIP_E2M1, hipRoundNearest);

fp4e2m1x4 result;
result.__x = static_cast<__hip_fp4x4_storage_t>((q0 & 0xFu) | ((q1 & 0xFu) << 4) |
((q2 & 0xFu) << 8) | ((q3 & 0xFu) << 12));
return result;
}
}

__device__ __forceinline__ uint16_t fp4x4_to_bits(fp4e2m1x4 v) {
uint16_t bits;
__builtin_memcpy(&bits, &v, sizeof(bits));
return bits;
}
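
// Nibble layout of the round-to-nearest path above: q0 occupies bits 0-3 of
// the packed uint16_t and q3 bits 12-15, so q0..q3 = 0x1,0x2,0x3,0x4 packs
// to 0x4321. fp4x4_to_bits then recovers the same 16 bits via memcpy for the
// single uint16_t store in the kernel below.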

template <bool kUseStochasticRounding>
__global__ __launch_bounds__(kThreadsPerBlock, 4) void HadamardTransformCastFusionKernel(
const __hip_bfloat16* __restrict__ input, uint8_t* __restrict__ output_t,
fp8e4m3* __restrict__ scale_inv_t, const float global_amax,
const uint16_t random_sign_mask_t, const uint64_t num_rows, const uint64_t row_length,
const size_t scale_stride, const size_t* rng_state) {
const int tid = threadIdx.x;
const int warp_id = tid / kWarpSize;
const int lane_id = tid % kWarpSize;
const int row_in_warp = lane_id / kThreadsPerWHT;
const int thread_in_grp = lane_id % kThreadsPerWHT;

const uint64_t output_row = static_cast<uint64_t>(blockIdx.x) * kHadamardDim + row_in_warp;
const uint64_t block_row_base =
static_cast<uint64_t>(blockIdx.y) * kRowsPerBlock + warp_id * kHadamardDim;

if (block_row_base + kHadamardDim > num_rows) {
return;
}

const uint64_t input_row_base = block_row_base + thread_in_grp * kElemsPerThread;
const uint64_t input_col = output_row;

float c0 = to_f32(input[(input_row_base + 0) * row_length + input_col]);
float c1 = to_f32(input[(input_row_base + 1) * row_length + input_col]);
float c2 = to_f32(input[(input_row_base + 2) * row_length + input_col]);
float c3 = to_f32(input[(input_row_base + 3) * row_length + input_col]);

wht16(c0, c1, c2, c3, thread_in_grp, random_sign_mask_t, /*apply_pre=*/true);

// Truncate to BF16 precision to match the reference BF16 matmul path.
// Without this, FP32 WHT results at FP4 quantization boundaries round
// differently than the BF16-precision reference, causing off-by-one errors.
c0 = to_f32(static_cast<__hip_bfloat16>(c0));
c1 = to_f32(static_cast<__hip_bfloat16>(c1));
c2 = to_f32(static_cast<__hip_bfloat16>(c2));
c3 = to_f32(static_cast<__hip_bfloat16>(c3));

const float local_block_amax =
fmaxf(fmaxf(fabsf(c0), fabsf(c1)), fmaxf(fabsf(c2), fabsf(c3)));
const float block_amax = group_max_4(local_block_amax);

const float global_encode_scale = compute_global_encode_scale_fp4(global_amax);
const float global_decode_scale = 1.0f / global_encode_scale;
const fp8e4m3 scale_inv = compute_decode_scale_fp4<fp8e4m3>(block_amax, global_encode_scale);
const float encode_scale = compute_encode_scale_fp4(scale_inv, global_decode_scale);

if (thread_in_grp == 0) {
const uint64_t scale_col = block_row_base / kHadamardDim;
scale_inv_t[output_row * scale_stride + scale_col] = scale_inv;
}

transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
uint4 random_uint4{0, 0, 0, 0};
int rnd_idx = 0;
if constexpr (kUseStochasticRounding) {
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
const size_t rng_sequence = static_cast<size_t>(threadIdx.x) +
static_cast<size_t>(blockIdx.x) * blockDim.x +
static_cast<size_t>(blockIdx.y) * gridDim.x * blockDim.x;
rng.init(rng_seed, rng_sequence, rng_offset);
random_uint4 = rng.generate4();
}

const float2 scaled01{c0 * encode_scale, c1 * encode_scale};
const float2 scaled23{c2 * encode_scale, c3 * encode_scale};
const uint32_t rbits = kUseStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
const uint16_t packed = fp4x4_to_bits(cvt_fp32_to_fp4_4x<kUseStochasticRounding>(
scaled01, scaled23, rbits));

const uint64_t output_col_base = input_row_base;
const uint64_t output_byte_offset = output_row * (num_rows / 2) + output_col_base / 2;
*reinterpret_cast<uint16_t*>(&output_t[output_byte_offset]) = packed;
}
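
// Note on the store above: each thread quantizes kElemsPerThread = 4 values
// into 16 bits, and the transposed FP4 output packs 2 values per byte
// (num_rows / 2 bytes per output row), so one aligned uint16_t store per
// thread writes its whole result.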

uint16_t random_sign_mask_from_rht_matrix(const SimpleTensor& hadamard_matrix, cudaStream_t stream) {
std::array<uint16_t, kHadamardDim * kHadamardDim> host_matrix{};

NVTE_CHECK_CUDA(cudaMemcpyAsync(host_matrix.data(), hadamard_matrix.dptr,
host_matrix.size() * sizeof(uint16_t),
cudaMemcpyDeviceToHost, stream));
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));

uint16_t random_sign_mask = 0;
for (size_t row = 0; row < kHadamardDim; ++row) {
// The first column of diag(sign) @ H16 is sign[row] * 0.25.
random_sign_mask |= static_cast<uint16_t>(((host_matrix[row * kHadamardDim] >> 15) & 1) << row);
}
return random_sign_mask;
}
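
// Illustrative example, assuming host_matrix holds raw bf16 bits (bit 15 is
// the sign bit): for a random diagonal sign = (-1, +1, +1, -1, +1, ...),
// column 0 of diag(sign) @ H16 is (-0.25, +0.25, +0.25, -0.25, +0.25, ...),
// so bits 0 and 3 of random_sign_mask are set and the rest of this prefix
// stay clear.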

} // namespace

} // namespace transformer_engine
#endif

namespace transformer_engine {

void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &output_,
const Tensor &hadamard_matrix_,
QuantizationConfig quant_config,
@@ -757,13 +970,15 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr);
}

#ifndef __HIP_PLATFORM_AMD__
// Template arguments
using TA = cute::bfloat16_t;
using TB = cute::bfloat16_t;
using TC = cutlass::float_e2m1_t;
using TSFC = cutlass::float_ue4m3_t;

checkCuDriverContext(stream);
#endif

// Check Hadamard matrix
constexpr int kHadamardDimension = 16;
@@ -788,12 +1003,15 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
m *= input.shape[i];
}

#ifndef __HIP_PLATFORM_AMD__
auto sm_count = transformer_engine::cuda::sm_count();
#endif

NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension.");

NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension");

#ifndef __HIP_PLATFORM_AMD__
int k_tile_size = 1024;

if (m == 8192 && n == 5120) {
@@ -825,8 +1043,6 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
}

TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kUseStochasticRounding,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
quant_config.use_fast_math, kUseFastMath,
detail::rht_gemm_ttt_wrapper<TA, TB, TC, TSFC, kUseStochasticRounding, kUseFastMath>(
/*m=*/m,
@@ -840,6 +1056,26 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
/*sm_count=*/sm_count,
/*stream=*/stream,
/*k_tile_size=*/k_tile_size);););
#else
const uint16_t random_sign_mask_t =
random_sign_mask_from_rht_matrix(hadamard_matrix, stream);

const dim3 block(kThreadsPerBlock);
const dim3 grid(DIVUP(n, static_cast<size_t>(kHadamardDim)),
DIVUP(m, static_cast<size_t>(kRowsPerBlock)));
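
// Launch geometry sketch: blockIdx.x picks a 16-column stripe of the input
// (one group of transposed output rows), blockIdx.y a band of kRowsPerBlock
// input rows. Assuming 64-lane wavefronts and kThreadsPerWHT = 4, as the
// lane indexing in the kernel suggests, each wavefront covers a 16 x 16
// tile with 4 elements per thread.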

TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kUseStochasticRounding,
HadamardTransformCastFusionKernel<kUseStochasticRounding><<<grid, block, 0, stream>>>(
reinterpret_cast<const __hip_bfloat16*>(input.dptr),
reinterpret_cast<uint8_t*>(output_t.dptr),
reinterpret_cast<fp8e4m3*>(scale_inv_t.dptr),
*reinterpret_cast<const float*>(global_amax.dptr), random_sign_mask_t,
static_cast<uint64_t>(m), static_cast<uint64_t>(n), scale_inv_t.shape[1],
rng_state););

NVTE_CHECK_CUDA(cudaGetLastError());
#endif
}

} // namespace transformer_engine