diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index 5fefb8c64..7c569ce00 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -222,6 +222,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
      activation/swiglu.cu
      cast/cast.cu
      hadamard_transform/hadamard_transform.cu
+     hadamard_transform/hadamard_transform_cast_fusion.cu
      multi_tensor/compute_scale.cu
      recipe/mxfp8_scaling.cu
      transpose/quantize_transpose_vector_blockwise_fp4.cu)
@@ -244,7 +245,6 @@ if(USE_CUDA)
       gemm/cutlass_grouped_gemm.cu
       hadamard_transform/group_hadamard_transform.cu
       transpose/quantize_transpose_square_blockwise.cu
-      hadamard_transform/hadamard_transform_cast_fusion.cu
       hadamard_transform/group_hadamard_transform_cast_fusion.cu
       hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
       transpose/quantize_transpose_square_blockwise.cu
diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu
index ef60a13be..0631f326b 100644
--- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu
+++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu
@@ -166,6 +166,7 @@ __device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float
     }
   }
 }
+#endif
 
 __launch_bounds__(1) __global__ void ZeroAmaxKernel(float* __restrict__ output_pre_rht_amax_ptr,
                                                     float* __restrict__ output_identity_amax_ptr,
@@ -181,6 +182,7 @@
   }
 }
 
+#ifndef __HIP_PLATFORM_AMD__
 template
@@ -978,15 +980,8 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran
   auto* in_ptr = reinterpret_cast(input.dptr);
 
-  if (pre_amax_ptr) {
-    NVTE_CHECK_CUDA(cudaMemsetAsync(pre_amax_ptr, 0, sizeof(float), stream));
-  }
-  if (id_amax_ptr) {
-    NVTE_CHECK_CUDA(cudaMemsetAsync(id_amax_ptr, 0, sizeof(float), stream));
-  }
-  if (tr_amax_ptr) {
-    NVTE_CHECK_CUDA(cudaMemsetAsync(tr_amax_ptr, 0, sizeof(float), stream));
-  }
+  ZeroAmaxKernel<<<1, 1, 0, stream>>>(pre_amax_ptr, id_amax_ptr, tr_amax_ptr);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 #else
   constexpr int kHadamardDimension = 16;
   NVTE_CHECK(row_length % kHadamardDimension == 0,
diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
index 0696deaaa..f8c963390 100644
--- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
+++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
@@ -1,4 +1,6 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -9,19 +11,24 @@
 #include
 #include
 #include
+#ifndef __HIP_PLATFORM_AMD__
 #include
+#endif
 #include
 #include
+#ifndef __HIP_PLATFORM_AMD__
 #include
 #include
 #include
+#endif
 
 #include "common/common.h"
 #include "common/util/cuda_runtime.h"
 #include "common/util/curanddx.hpp"
 #include "common/util/ptx.cuh"
 #include "common/utils.cuh"
+#ifndef __HIP_PLATFORM_AMD__
 #include "cutlass/arch/barrier.h"
 #include "cutlass/cutlass.h"
 #include "cutlass/gemm/collective/builders/sm100_common.inl"
@@ -30,9 +37,11 @@
 #include "cutlass/util/GPU_Clock.hpp"
 #include "cutlass/util/command_line.h"
 #include "cutlass/util/print_error.hpp"
+#endif
 
 // clang-format off
 
+#ifndef __HIP_PLATFORM_AMD__
 namespace transformer_engine {
 namespace detail {
 namespace {
@@ -726,6 +735,210 @@ rht_gemm_ttt_wrapper(int m, int n,
 // clang-format on
 
+}  // namespace transformer_engine
+#else
+
+#include "wht16.cuh"
+
+namespace transformer_engine {
+
+namespace {
+
+__device__ __forceinline__ float to_f32(__hip_bfloat16 v) { return static_cast<float>(v); }
+
+__device__ __forceinline__ float group_max_4(float v) {
+  v = fmaxf(v, ds_swizzle_xor1(v));
+  v = fmaxf(v, ds_swizzle_xor2(v));
+  return v;
+}
+
+__device__ __forceinline__ float compute_global_encode_scale_fp4(const float global_amax) {
+#if !defined(__HIP_DEVICE_COMPILE__)
+  const float fp8_max = detail::TypeExtrema<fp8e4m3>::max;
+#else
+  constexpr float fp8_max = detail::TypeExtrema<fp8e4m3>::max;
+#endif
+  constexpr float fp4_max = detail::TypeExtrema<fp4e2m1>::max;
+  float global_encode_scale = fp8_max * fp4_max / global_amax;
+  global_encode_scale = fminf(global_encode_scale, detail::TypeExtrema<float>::max);
+  return (global_amax == 0.f || global_encode_scale == 0.f) ? 1.f : global_encode_scale;
+}
+
+template <typename ScaleType>
+__device__ __forceinline__ ScaleType compute_decode_scale_fp4(const float amax,
+                                                              const float global_encode_scale) {
+  float decode_scale = amax / detail::TypeExtrema<fp4e2m1>::max;
+  decode_scale *= global_encode_scale;
+  decode_scale = fminf(decode_scale, detail::TypeExtrema<float>::max);
+  return static_cast<ScaleType>(decode_scale);
+}
+
+template <typename ScaleType>
+__device__ __forceinline__ float compute_encode_scale_fp4(ScaleType decode_scale,
+                                                          const float global_decode_scale) {
+  return fminf(1.0f / (static_cast<float>(decode_scale) * global_decode_scale),
+               detail::TypeExtrema<float>::max);
+}
+
+__device__ __forceinline__ uint32_t get_rbits(
+    transformer_engine::curanddx::detail::philox4x32_native_state<10>& rng, uint4& random_uint4,
+    int& rnd_idx) {
+  if (rnd_idx == 4) {
+    rnd_idx = 0;
+    random_uint4 = rng.generate4();
+  }
+  const uint32_t* const rbits_arr = reinterpret_cast<const uint32_t*>(&random_uint4);
+  return rbits_arr[rnd_idx++];
+}
+
+template <bool kUseStochasticRounding>
+__device__ __forceinline__ fp4e2m1x4 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
                                                         const uint32_t rbits) {
+  if constexpr (kUseStochasticRounding) {
+#if ARCH_HAS_STOCHASTIC_ROUNDING
+    union {
+      uint32_t ui32;
+      __hip_fp4x2_storage_t fp4x2[4];
+    } packed{0};
+    __amd_floatx2_storage_t packed01{in01.x, in01.y};
+    __amd_floatx2_storage_t packed23{in23.x, in23.y};
+    packed.ui32 =
+        __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(packed.ui32, packed01, rbits, 1.0f, 1);
+    const __hip_fp4x2_storage_t lo = packed.fp4x2[1];
+    packed.ui32 =
+        __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(packed.ui32, packed23, rbits, 1.0f, 1);
+    const __hip_fp4x2_storage_t hi = packed.fp4x2[1];
+
+    fp4e2m1x4 result;
+    result.__x = static_cast<__hip_fp4x4_storage_t>(
+        lo | (static_cast<__hip_fp4x4_storage_t>(hi) << 8));
+    return result;
+#else
+    NVTE_DEVICE_ERROR("FP4 stochastic rounding on AMDGPU requires gfx950 or later.");
+    return fp4e2m1x4{};
+#endif
+  } else {
+    const __hip_fp4_storage_t q0 =
+        __hip_cvt_float_to_fp4(in01.x, __HIP_E2M1, hipRoundNearest);
+    const __hip_fp4_storage_t q1 =
+        __hip_cvt_float_to_fp4(in01.y, __HIP_E2M1, hipRoundNearest);
+    const __hip_fp4_storage_t q2 =
+        __hip_cvt_float_to_fp4(in23.x, __HIP_E2M1, hipRoundNearest);
+    const __hip_fp4_storage_t q3 =
+        __hip_cvt_float_to_fp4(in23.y, __HIP_E2M1, hipRoundNearest);
+
+    fp4e2m1x4 result;
+    result.__x = static_cast<__hip_fp4x4_storage_t>((q0 & 0xFu) | ((q1 & 0xFu) << 4) |
+                                                    ((q2 & 0xFu) << 8) | ((q3 & 0xFu) << 12));
+    return result;
+  }
+}
+
+__device__ __forceinline__ uint16_t fp4x4_to_bits(fp4e2m1x4 v) {
+  uint16_t bits;
+  __builtin_memcpy(&bits, &v, sizeof(bits));
+  return bits;
+}
+
+template <bool kUseStochasticRounding>
+__global__ __launch_bounds__(kThreadsPerBlock, 4) void HadamardTransformCastFusionKernel(
+    const __hip_bfloat16* __restrict__ input, uint8_t* __restrict__ output_t,
+    fp8e4m3* __restrict__ scale_inv_t, const float global_amax,
+    const uint16_t random_sign_mask_t, const uint64_t num_rows, const uint64_t row_length,
+    const size_t scale_stride, const size_t* rng_state) {
+  const int tid = threadIdx.x;
+  const int warp_id = tid / kWarpSize;
+  const int lane_id = tid % kWarpSize;
+  const int row_in_warp = lane_id / kThreadsPerWHT;
+  const int thread_in_grp = lane_id % kThreadsPerWHT;
+
+  const uint64_t output_row = static_cast<uint64_t>(blockIdx.x) * kHadamardDim + row_in_warp;
+  const uint64_t block_row_base =
+      static_cast<uint64_t>(blockIdx.y) * kRowsPerBlock + warp_id * kHadamardDim;
+
+  if (block_row_base + kHadamardDim > num_rows) {
+    return;
+  }
+
+  const uint64_t input_row_base = block_row_base + thread_in_grp * kElemsPerThread;
+  const uint64_t input_col = output_row;
+
+  float c0 = to_f32(input[(input_row_base + 0) * row_length + input_col]);
+  float c1 = to_f32(input[(input_row_base + 1) * row_length + input_col]);
+  float c2 = to_f32(input[(input_row_base + 2) * row_length + input_col]);
+  float c3 = to_f32(input[(input_row_base + 3) * row_length + input_col]);
+
+  wht16(c0, c1, c2, c3, thread_in_grp, random_sign_mask_t, /*apply_pre=*/true);
+
+  // Truncate to BF16 precision to match the reference BF16 matmul path.
+  // Without this, FP32 WHT results at FP4 quantization boundaries round
+  // differently than the BF16-precision reference, causing off-by-one errors.
+  c0 = to_f32(static_cast<__hip_bfloat16>(c0));
+  c1 = to_f32(static_cast<__hip_bfloat16>(c1));
+  c2 = to_f32(static_cast<__hip_bfloat16>(c2));
+  c3 = to_f32(static_cast<__hip_bfloat16>(c3));
+
+  const float local_block_amax =
+      fmaxf(fmaxf(fabsf(c0), fabsf(c1)), fmaxf(fabsf(c2), fabsf(c3)));
+  const float block_amax = group_max_4(local_block_amax);
+
+  const float global_encode_scale = compute_global_encode_scale_fp4(global_amax);
+  const float global_decode_scale = 1.0f / global_encode_scale;
+  const fp8e4m3 scale_inv = compute_decode_scale_fp4<fp8e4m3>(block_amax, global_encode_scale);
+  const float encode_scale = compute_encode_scale_fp4(scale_inv, global_decode_scale);
+
+  if (thread_in_grp == 0) {
+    const uint64_t scale_col = block_row_base / kHadamardDim;
+    scale_inv_t[output_row * scale_stride + scale_col] = scale_inv;
+  }
+
+  transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
+  uint4 random_uint4{0, 0, 0, 0};
+  int rnd_idx = 0;
+  if constexpr (kUseStochasticRounding) {
+    const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
+    const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
+    const size_t rng_sequence = static_cast<size_t>(threadIdx.x) +
+                                static_cast<size_t>(blockIdx.x) * blockDim.x +
+                                static_cast<size_t>(blockIdx.y) * gridDim.x * blockDim.x;
+    rng.init(rng_seed, rng_sequence, rng_offset);
+    random_uint4 = rng.generate4();
+  }
+
+  const float2 scaled01{c0 * encode_scale, c1 * encode_scale};
+  const float2 scaled23{c2 * encode_scale, c3 * encode_scale};
+  const uint32_t rbits = kUseStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
+  const uint16_t packed = fp4x4_to_bits(cvt_fp32_to_fp4_4x<kUseStochasticRounding>(
+      scaled01, scaled23, rbits));
+
+  const uint64_t output_col_base = input_row_base;
+  const uint64_t output_byte_offset = output_row * (num_rows / 2) + output_col_base / 2;
+  *reinterpret_cast<uint16_t*>(&output_t[output_byte_offset]) = packed;
+}
+
+uint16_t random_sign_mask_from_rht_matrix(const SimpleTensor& hadamard_matrix, cudaStream_t stream) {
+  std::array<uint16_t, kHadamardDim * kHadamardDim> host_matrix{};
+
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(host_matrix.data(), hadamard_matrix.dptr,
+                                  host_matrix.size() * sizeof(uint16_t),
+                                  cudaMemcpyDeviceToHost, stream));
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+
+  uint16_t random_sign_mask = 0;
+  for (size_t row = 0; row < kHadamardDim; ++row) {
+    // The first column of diag(sign) @ H16 is sign[row] * 0.25.
+    random_sign_mask |= static_cast<uint16_t>(((host_matrix[row * kHadamardDim] >> 15) & 1) << row);
+  }
+  return random_sign_mask;
+}
+
+}  // namespace
+
+}  // namespace transformer_engine
+#endif
+
+namespace transformer_engine {
+
 void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &output_,
                                                const Tensor &hadamard_matrix_,
                                                QuantizationConfig quant_config,
@@ -757,6 +970,7 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
     rng_state = reinterpret_cast<const size_t*>(rng_state_tensor.data.dptr);
   }
 
+#ifndef __HIP_PLATFORM_AMD__
   // Template arguments
   using TA = cute::bfloat16_t;
   using TB = cute::bfloat16_t;
@@ -764,6 +978,7 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
   using TSFC = cutlass::float_ue4m3_t;
 
   checkCuDriverContext(stream);
+#endif
 
   // Check Hadamard matrix
   constexpr int kHadamardDimension = 16;
@@ -788,12 +1003,15 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
     m *= input.shape[i];
   }
 
+#ifndef __HIP_PLATFORM_AMD__
   auto sm_count = transformer_engine::cuda::sm_count();
+#endif
 
   NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension.");
   NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension");
 
+#ifndef __HIP_PLATFORM_AMD__
   int k_tile_size = 1024;
 
   if (m == 8192 && n == 5120) {
@@ -825,8 +1043,6 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
   }
 
   TRANSFORMER_ENGINE_SWITCH_CONDITION(
-      use_stochastic_rounding, kUseStochasticRounding,
-      TRANSFORMER_ENGINE_SWITCH_CONDITION(
       quant_config.use_fast_math, kUseFastMath,
       detail::rht_gemm_ttt_wrapper(
           /*m=*/m,
@@ -840,6 +1056,26 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
           /*sm_count=*/sm_count,
          /*stream=*/stream,
          /*k_tile_size=*/k_tile_size);););
+#else
+  const uint16_t random_sign_mask_t =
+      random_sign_mask_from_rht_matrix(hadamard_matrix, stream);
+
+  const dim3 block(kThreadsPerBlock);
+  const dim3 grid(DIVUP(n, static_cast<size_t>(kHadamardDim)),
+                  DIVUP(m, static_cast<size_t>(kRowsPerBlock)));
+
+  TRANSFORMER_ENGINE_SWITCH_CONDITION(
+      use_stochastic_rounding, kUseStochasticRounding,
+      HadamardTransformCastFusionKernel<kUseStochasticRounding><<<grid, block, 0, stream>>>(
+          reinterpret_cast<const __hip_bfloat16*>(input.dptr),
+          reinterpret_cast<uint8_t*>(output_t.dptr),
+          reinterpret_cast<fp8e4m3*>(scale_inv_t.dptr),
+          *reinterpret_cast<const float*>(global_amax.dptr), random_sign_mask_t,
+          static_cast<uint64_t>(m), static_cast<uint64_t>(n), scale_inv_t.shape[1],
+          rng_state););
+
+  NVTE_CHECK_CUDA(cudaGetLastError());
+#endif
 }
 
 }  // namespace transformer_engine
diff --git a/transformer_engine/common/hadamard_transform/wht16.cuh b/transformer_engine/common/hadamard_transform/wht16.cuh
new file mode 100644
index 000000000..b9a1a51b7
--- /dev/null
+++ b/transformer_engine/common/hadamard_transform/wht16.cuh
@@ -0,0 +1,82 @@
+/*************************************************************************
+ * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+// Shared 16-point Walsh-Hadamard transform primitives for AMDGPU.
+
+#ifndef TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_WHT16_CUH_
+#define TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_WHT16_CUH_
+
+#ifdef __HIP_PLATFORM_AMD__
+
+static constexpr int kHadamardDim = 16;
+static constexpr int kWarpSize = 64;
+static constexpr int kThreadsPerWHT = 4;
+static constexpr int kElemsPerThread = 4;
+static constexpr int kRowsPerWarp = kWarpSize / kThreadsPerWHT;      // 16
+static constexpr int kWarpsPerBlock = 4;
+static constexpr int kRowsPerBlock = kRowsPerWarp * kWarpsPerBlock;  // 64
+static constexpr int kThreadsPerBlock = kWarpSize * kWarpsPerBlock;  // 256
+static constexpr float kHadamardScale = 0.25f;
+
+// ds_swizzle: sub-wavefront exchange without LDS.
+__device__ __forceinline__ float ds_swizzle_xor1(float v) {
+  float r;
+  asm volatile("ds_swizzle_b32 %0, %1 offset:0x041F\n\t"
+               "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v));
+  return r;
+}
+
+__device__ __forceinline__ float ds_swizzle_xor2(float v) {
+  float r;
+  asm volatile("ds_swizzle_b32 %0, %1 offset:0x081F\n\t"
+               "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v));
+  return r;
+}
+
+// 16-point WHT: in-register, no shared memory.
+// Adapted from cast_transpose_mxfp4_kernel_shuffled.cu::hadamard16_inplace,
+// extended with NV random_sign_mask (uint16_t bitmask).
+// thread_in_group [0,3]: drives ds_swizzle polarity (identical to MLPerf tid & 3).
+// apply_pre=true -> D before WHT (forward); false -> D after WHT (inverse).
+__device__ __forceinline__ void wht16(
+    float& v0, float& v1, float& v2, float& v3,
+    int thread_in_group, uint16_t sign_mask, bool apply_pre) {
+  auto sgn = [&](int k) -> float {
+    return ((sign_mask >> (thread_in_group * kElemsPerThread + k)) & 1u) ? -1.f : 1.f;
+  };
+
+  if (apply_pre) {
+    v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3);
+  }
+
+  // Stage 1: local H4
+  float a0=v0+v1, a1=v0-v1, a2=v2+v3, a3=v2-v3;
+  v0=a0+a2; v2=a0-a2; v1=a1+a3; v3=a1-a3;
+
+  // Stage 2: cross-thread XOR-1
+  { float p0=ds_swizzle_xor1(v0), p1=ds_swizzle_xor1(v1),
+          p2=ds_swizzle_xor1(v2), p3=ds_swizzle_xor1(v3);
+    bool up=(thread_in_group&1);
+    v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1);
+    v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); }
+
+  // Stage 3: cross-thread XOR-2
+  { float p0=ds_swizzle_xor2(v0), p1=ds_swizzle_xor2(v1),
+          p2=ds_swizzle_xor2(v2), p3=ds_swizzle_xor2(v3);
+    bool up=(thread_in_group>>1)&1;
+    v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1);
+    v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); }
+
+  v0*=kHadamardScale; v1*=kHadamardScale; v2*=kHadamardScale; v3*=kHadamardScale;
+
+  if (!apply_pre) {
+    v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3);
+  }
+}
+
+#endif  // __HIP_PLATFORM_AMD__
+
+#endif  // TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_WHT16_CUH_
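
As a reference for the new header, here is a minimal host-side sketch of the transform wht16() is intended to compute: an optional pre-applied sign mask diag(sign), a natural-order 16-point fast Walsh-Hadamard transform, and the 0.25 normalization. It is illustrative only and not part of the patch; wht16_reference and the impulse check in main() are assumed names for this sketch.

// Scalar model of the sign-masked 16-point WHT (illustrative sketch, not patch code).
#include <array>
#include <cstdint>
#include <cstdio>

static std::array<float, 16> wht16_reference(std::array<float, 16> v, uint16_t sign_mask,
                                             bool apply_pre) {
  auto apply_sign = [&](std::array<float, 16>& x) {
    for (int k = 0; k < 16; ++k) {
      if ((sign_mask >> k) & 1u) x[k] = -x[k];
    }
  };
  if (apply_pre) apply_sign(v);  // D before WHT (forward direction)

  // Natural-order fast Walsh-Hadamard transform: butterflies of span 1, 2, 4, 8.
  // Spans 1 and 2 match wht16()'s in-thread Stage 1; spans 4 and 8 match the
  // ds_swizzle XOR-1 and XOR-2 stages.
  for (int len = 1; len < 16; len <<= 1) {
    for (int i = 0; i < 16; i += 2 * len) {
      for (int j = i; j < i + len; ++j) {
        const float a = v[j];
        const float b = v[j + len];
        v[j] = a + b;
        v[j + len] = a - b;
      }
    }
  }
  for (float& x : v) x *= 0.25f;  // kHadamardScale = 1/sqrt(16)

  if (!apply_pre) apply_sign(v);  // D after WHT (inverse direction)
  return v;
}

int main() {
  // Sanity check: with no sign flips, transforming the unit impulse e0 yields the
  // first column of H16 scaled by 0.25, i.e. 0.25 in every position.
  std::array<float, 16> impulse{};
  impulse[0] = 1.0f;
  for (float x : wht16_reference(impulse, /*sign_mask=*/0, /*apply_pre=*/true)) {
    std::printf("% .2f ", x);
  }
  std::printf("\n");
  return 0;
}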