diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 157b4ab45..ea1234212 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -65,7 +65,11 @@ std::vector create_transpose(const InputType* const input, const size // Compute the global encode scale factor for a given global amax float compute_global_encode_scaling_factor_FP4(const float global_amax) { +#ifdef __HIP_PLATFORM_AMD__ + const float fp8_max = Numeric_Traits::maxNorm; +#else constexpr float fp8_max = 448.0f; // 448.0f; +#endif constexpr float fp4_max = 6.0f; // 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return max value of float32 diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 00f1eabbc..1afa7eef9 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -108,7 +108,12 @@ void compute_ref(const fp4e2m1* input, const size_t rows, const size_t cols, const size_t scale_stride) { +#ifdef __HIP_PLATFORM_AMD__ + const float fp8_max = Numeric_Traits::maxNorm; + const float factor_inv = 1.0f / (6.0f * fp8_max); +#else constexpr float factor_inv = 1.0f / (6.0f * 448.0f); +#endif const size_t blocks_per_row = cols / kFP4BlockSize1D; diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 911b7660d..2bdfbed19 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -10,7 +12,10 @@ from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils +from torch.utils.cpp_extension import IS_HIP_EXTENSION +if IS_HIP_EXTENSION: + from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) @@ -108,8 +113,13 @@ def check_nvfp4_gemm_versus_reference( # Native scales are stored as uint8 but need to be interpreted as float8_e4m3fn # for the reference GEMM to work correctly - sx_trimmed = sx_trimmed.view(torch.float8_e4m3fn) - sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn) + if IS_HIP_EXTENSION: + fp8_dtype = get_torch_float8_e4m3_type() + sx_trimmed = sx_trimmed.view(fp8_dtype) + sw_trimmed = sw_trimmed.view(fp8_dtype) + else: + sx_trimmed = sx_trimmed.view(torch.float8_e4m3fn) + sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn) # Create reference quantizer for reference GEMM ref_quantizer = NVFP4QuantizerRef( @@ -150,7 +160,14 @@ def check_nvfp4_gemm_versus_reference( # Native TE GEMM using tex.generic_gemm (cuBLAS GEMM) # Allocate cuBLAS workspace - workspace = torch.empty(4, dtype=torch.uint8, device=device) + if IS_HIP_EXTENSION: + # On ROCm, FP4 is dequantized to BF16 in workspace before GEMM, so allocate enough space. 
+ from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes + bf16_size = torch.bfloat16.itemsize + ws_bytes = M * K * bf16_size + K * N * bf16_size + get_cublas_workspace_size_bytes() + workspace = torch.empty(ws_bytes, dtype=torch.uint8, device=device) + else: + workspace = torch.empty(4, dtype=torch.uint8, device=device) transa = True if not w_columnwise else False transb = False if not x_columnwise else True diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py index b14eeb815..11777a715 100755 --- a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,6 +13,9 @@ import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer +from torch.utils.cpp_extension import IS_HIP_EXTENSION +if IS_HIP_EXTENSION: + from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, is_fp8_fnuz recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) @@ -58,10 +63,18 @@ def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor: def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor: - sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32) + if IS_HIP_EXTENSION: + fp8_dtype = get_torch_float8_e4m3_type() + fp8_max = 240.0 if is_fp8_fnuz() else 448.0 + sf = sx.repeat_interleave(16, dim=1).view(fp8_dtype).to(torch.float32) + else: + sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32) dqx = fp4_to_fp32(unpack_fp4(qx)) sf = sf[: dqx.shape[0], : dqx.shape[1]] - dequant = dqx * sf * (amax / (6.0 * 448)) + if IS_HIP_EXTENSION: + dequant = dqx * sf * (amax / (6.0 * fp8_max)) + else: + dequant = dqx * sf * (amax / (6.0 * 448)) return dequant diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index e4647ac82..cc3c6f341 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1773,7 +1773,10 @@ def test_clamped_swiglu( quantized_compute = quantization is not None if not quantized_compute and (quantize_forward or quantize_backward): pytest.skip("Quantization scheme has not been provided") - maybe_skip_quantization(quantization, dims=in_shape, device=device) + if IS_HIP_EXTENSION: + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + else: + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -2937,6 +2940,8 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: # Check values tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking + if IS_HIP_EXTENSION: + tols["atol"] = 0.54 torch.testing.assert_close(to_cpu(y_test), y_ref, **tols) torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols) torch.testing.assert_close(to_cpu(norm.weight.grad), norm_w_ref.grad, **tols) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index d1e9b341e..4a768377e 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -778,7 +778,6 @@ def test_gpt_full_activation_recompute( if (dtype == torch.bfloat16 and not fp8 and not 
use_reentrant - and recipe.float8_per_tensor_scaling() ): pytest.skip("hipBLASLt does not provide suitable algorithms on GFX950 for this config.") if fp8 and recipe.nvfp4(): diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 2ec51746d..b057951c2 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -214,6 +214,7 @@ list(APPEND transformer_engine_cuda_sources recipe/current_scaling.cu recipe/delayed_scaling.cu recipe/fp8_block_scaling.cu + recipe/nvfp4.cu swizzle/swizzle.cu) list(APPEND transformer_engine_cuda_arch_specific_sources diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 94fc16b03..bd01acefe 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -66,7 +66,11 @@ __global__ void __launch_bounds__(512) #else float amax = (tensor_amax != nullptr) ? *tensor_amax : 1.0f; #endif +#if defined(__HIP_DEVICE_COMPILE__) + constexpr float factor_inv = 1.0f / (detail::TypeExtrema::max * detail::TypeExtrema::max); +#else constexpr float factor_inv = 1.0 / (6.0 * 448.0); +#endif float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll for (int i = 0; i < 4; i++) { diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 3bc8d9bc8..c0e82b8ff 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -197,6 +197,79 @@ struct GemmParam { int ldb = 0; // B column strides }; +// FP4 e2m1 lookup table +__device__ constexpr float kFP4E2M1Table[16] = { + 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, + -0.0f,-0.5f,-1.0f,-1.5f,-2.0f,-3.0f,-4.0f,-6.0f +}; + +// Dequantize FP4 (e2m1) packed data with FP8 e4m3 block scales to BF16. +// Only applies block scales: output = fp4_value * block_scale. +// The per-tensor amax correction is applied separately via the GEMM alpha scalar. +// +// Scale layout: 2D tensor of shape {num_rows_padded, scale_stride} where +// scale_stride = roundup(num_cols / 16, 4). Each scale covers a block of 16 +// consecutive elements along the fast (column) dimension. 
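+// Worked example (hypothetical sizes, for illustration only): with num_cols = 64,
+// scale_stride = roundup(64 / 16, 4) = 4, so the element at (row = 2, col = 37)
+// reads its block scale from scale_inv[2 * 4 + 37 / 16] = scale_inv[10].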
+__global__ void dequant_fp4_to_bf16_kernel( + const uint8_t* __restrict__ data, + const fp8e4m3* __restrict__ scale_inv, + hip_bfloat16* __restrict__ output, + int64_t total_elements, + int64_t num_cols, + int64_t scale_stride) +{ + // Process 2 elements (1 byte) per iteration for coalesced access + const int64_t pair_idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const int64_t total_pairs = total_elements / 2; + if (pair_idx >= total_pairs) return; + + const uint8_t byte = data[pair_idx]; + const uint8_t lo_nibble = byte & 0xF; + const uint8_t hi_nibble = byte >> 4; + + const int64_t elem_base = pair_idx * 2; + const int64_t row0 = elem_base / num_cols; + const int64_t col0 = elem_base % num_cols; + const int64_t row1 = (elem_base + 1) / num_cols; + const int64_t col1 = (elem_base + 1) % num_cols; + const float s0 = static_cast(scale_inv[row0 * scale_stride + col0 / 16]); + const float s1 = static_cast(scale_inv[row1 * scale_stride + col1 / 16]); + + output[elem_base] = static_cast(kFP4E2M1Table[lo_nibble] * s0); + output[elem_base + 1] = static_cast(kFP4E2M1Table[hi_nibble] * s1); +} + +// Launch helper for dequant kernel +static void launch_dequant_fp4_to_bf16( + const void* data, const void* scale_inv, + void* output, int64_t total_elements, + int64_t num_cols, int64_t scale_stride, + hipStream_t stream) +{ + constexpr int kBlockSize = 256; + const int64_t total_pairs = total_elements / 2; + const int64_t num_blocks = (total_pairs + kBlockSize - 1) / kBlockSize; + + dequant_fp4_to_bf16_kernel<<>>( + reinterpret_cast(data), + reinterpret_cast(scale_inv), + reinterpret_cast(output), + total_elements, num_cols, scale_stride); +} + +// Compute per-row alpha vector on device for NVFP4 GEMM: +// alpha_out[i] = alpha_in * amax_A * amax_B / (fp4_max^2 * fp8_max^2) for i in [0, m) +// Used with HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST, which +// expects a device vector of length m for alpha, while beta stays on the host. +__global__ void compute_fp4_alpha_vector_kernel(float alpha_in, const float* __restrict__ amax_A, + const float* __restrict__ amax_B, float factor_inv, + float* __restrict__ alpha_out, int m) { + const float alpha_val = alpha_in * (*amax_A) * (*amax_B) * factor_inv; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < m; i += blockDim.x * gridDim.x) { + alpha_out[i] = alpha_val; + } +} + GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, const int m, const int n, const int k) { @@ -245,6 +318,14 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = is_A_transposed ? k : m; + } else if (is_nvfp_scaling(A.scaling_mode)) { + // NVFP4: dequant path always produces TN layout for the BF16 GEMM, + // but the source data may come from either rowwise or columnwise buffers. + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; // NVFP4 gemm is always TN layout + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? 
A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = k; } else { NVTE_ERROR("A has unsupported scaling mode"); } @@ -283,6 +364,14 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; + } else if (is_nvfp_scaling(B.scaling_mode)) { + // NVFP4: dequant path always produces TN layout for the BF16 GEMM, + // but the source data may come from either rowwise or columnwise buffers. + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = CUBLAS_OP_N; // NVFP4 gemm is always TN layout + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; } else { NVTE_ERROR("B has unsupported scaling mode"); } @@ -290,6 +379,90 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla return ret; } +// Dequantize FP4 inputs to BF16 in-place within the workspace and set up +// the alpha device vector for the subsequent hipBLASLt GEMM. +// After this call, param.A/B point to BF16 buffers within workspace, +// param.Atype/Btype are kBFloat16, and *alpha_ptr_out points to a device vector. +static void dequant_fp4_gemm_inputs( + GemmParam& param, + const transformer_engine::Tensor& inputA, cublasOperation_t transa, + const transformer_engine::Tensor& inputB, cublasOperation_t transb, + int m, int n, int k, float alpha, + void* workspace, size_t& workspaceSize, + const void** alpha_ptr_out, hipStream_t stream) { + + const float fp4_max = 6.0f; + const float fp8_max = te_fp8_fnuz() ? 240.0f : 448.0f; + const float factor_inv = 1.0f / (fp4_max * fp4_max * fp8_max * fp8_max); + + const float* amax_A = (transa == CUBLAS_OP_T) + ? reinterpret_cast(inputA.amax.dptr) + : reinterpret_cast(inputA.columnwise_amax.dptr); + const float* amax_B = (transb == CUBLAS_OP_N) + ? reinterpret_cast(inputB.amax.dptr) + : reinterpret_cast(inputB.columnwise_amax.dptr); + + // Compute total extra bytes needed from the workspace: + // alpha vector: m * sizeof(float) + // dequant A: m * k * sizeof(bf16) (if A is FP4) + // dequant B: k * n * sizeof(bf16) (if B is FP4) + const size_t alpha_vec_bytes = static_cast(m) * sizeof(float); + const size_t a_bf16_bytes = is_fp4_dtype(param.Atype) + ? static_cast(m) * k * sizeof(hip_bfloat16) : 0; + const size_t b_bf16_bytes = is_fp4_dtype(param.Btype) + ? static_cast(k) * n * sizeof(hip_bfloat16) : 0; + const size_t fp4_total_bytes = alpha_vec_bytes + a_bf16_bytes + b_bf16_bytes; + NVTE_CHECK(workspaceSize >= fp4_total_bytes, + "NVFP4 GEMM requires at least ", fp4_total_bytes, " bytes workspace (", + fp4_total_bytes / (1024 * 1024), " MiB) for alpha vector + BF16 dequant buffers, " + "but only ", workspaceSize, " bytes (", workspaceSize / (1024 * 1024), + " MiB) available. Increase the cuBLAS workspace size."); + + // Carve regions from the end of the workspace. + // Layout: [cublas workspace ... 
| alpha_vec | dequant_a | dequant_b] + workspaceSize = (workspaceSize / sizeof(float)) * sizeof(float) - fp4_total_bytes; + uint8_t* ws_ptr = reinterpret_cast(workspace) + workspaceSize; + + float* device_alpha_vec = reinterpret_cast(ws_ptr); + ws_ptr += alpha_vec_bytes; + + NVTE_CHECK(amax_A != nullptr, "FP4 GEMM requires amax_A"); + NVTE_CHECK(amax_B != nullptr, "FP4 GEMM requires amax_B"); + constexpr int kBlockSize = 256; + const int num_blocks = (m + kBlockSize - 1) / kBlockSize; + compute_fp4_alpha_vector_kernel<<>>( + alpha, amax_A, amax_B, factor_inv, device_alpha_vec, m); + *alpha_ptr_out = static_cast(device_alpha_vec); + + // Stage FP4 operand: dequantize to BF16 in workspace and update GEMM param. + auto stage_fp4_operand = [&](DType& op_type, void*& op_data, + void*& op_scale_inv, + const transformer_engine::Tensor& input, + bool use_rowwise, int64_t rows, int64_t cols, + size_t bf16_bytes) { + if (!is_fp4_dtype(op_type)) + return; + + hip_bfloat16* bf16_buf = reinterpret_cast(ws_ptr); + ws_ptr += bf16_bytes; + const auto& sinv = use_rowwise ? input.scale_inv : input.columnwise_scale_inv; + const int64_t num_cols = use_rowwise ? input.data.shape.back() + : input.columnwise_data.shape.back(); + const int64_t scale_stride = (sinv.shape.size() >= 2) ? sinv.shape[1] : (num_cols / 16); + launch_dequant_fp4_to_bf16(op_data, op_scale_inv, bf16_buf, + rows * cols, num_cols, scale_stride, stream); + op_data = bf16_buf; + op_type = DType::kBFloat16; + op_scale_inv = nullptr; + }; + + // Dequantize FP4 -> BF16 (block scales only, no amax folded in) + stage_fp4_operand(param.Atype, param.A, param.A_scale_inv, + inputA, transa == CUBLAS_OP_T, m, k, a_bf16_bytes); + stage_fp4_operand(param.Btype, param.B, param.B_scale_inv, + inputB, transb == CUBLAS_OP_N, k, n, b_bf16_bytes); +} + static class HandlePool { public: @@ -521,19 +694,22 @@ public: //Make it int instead of hipblasLtMatmulMatrixScale_t for compatibility with old hipblasLt int scaling_mode; hipblasLtEpilogue_t epilogue; + bool fp4_alpha_device_vector; // FP4 uses ALPHA_DEVICE_VECTOR pointer mode Key(int deviceCap_, hipDataType a_type_, hipDataType b_type_, hipDataType d_type_, hipDataType bias_type_, hipDataType aux_type_, int m_, int n_, int k_, int lda_, int ldb_, int ldd_, hipblasOperation_t transa_, hipblasOperation_t transb_, - int scaling_mode_, hipblasLtEpilogue_t epilogue_): + int scaling_mode_, hipblasLtEpilogue_t epilogue_, + bool fp4_alpha_device_vector_): deviceCap(deviceCap_), a_type(a_type_), b_type(b_type_), d_type(d_type_), bias_type(bias_type_), aux_type(aux_type_), m(m_), n(n_), k(k_), lda(lda_), ldb(ldb_), ldd(ldd_), transa(transa_), transb(transb_), - scaling_mode(scaling_mode_), epilogue(epilogue_) {} + scaling_mode(scaling_mode_), epilogue(epilogue_), + fp4_alpha_device_vector(fp4_alpha_device_vector_) {} Key() {} @@ -546,7 +722,8 @@ public: && (m == val.m) && (n == val.n) && (k == val.k) && (lda == val.lda) && (ldb == val.ldb) && (ldd == val.ldd) && (transa == val.transa) && (transb == val.transb) - && (scaling_mode == val.scaling_mode) && (epilogue == val.epilogue) ); + && (scaling_mode == val.scaling_mode) && (epilogue == val.epilogue) + && (fp4_alpha_device_vector == val.fp4_alpha_device_vector) ); } struct Comp @@ -676,7 +853,7 @@ protected: fs << "dev_cap" << "m" << "n" << "k" << "trans_a" << "trans_b" << "type_a" << "type_b" << "type_d" << "bias_type" << "aux_type" << "lda" << "ldb" << "ldd" << "scale_mode" << "epi" << "comp" << "scale_type" - << "ws_min" << "ws_max" << "algo_id" << "aidx"; + << 
"fp4_alpha" << "ws_min" << "ws_max" << "algo_id" << "aidx"; } void load_() @@ -747,7 +924,9 @@ protected: std::getline(is, epi, csv_sep); std::getline(is, comp, csv_sep); std::getline(is, scale, csv_sep); - is >> ws_min >> c >> ws_max >> c >> algo_id >> c >> algo_idx; + int fp4_alpha = 0; + is >> fp4_alpha >> c >> ws_min >> c >> ws_max >> c >> algo_id >> c >> algo_idx; + cfg.fp4_alpha_device_vector = (fp4_alpha != 0); if (is.bad()) { @@ -882,7 +1061,9 @@ protected: << ((cfg.aux_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.aux_type)) << cfg.lda << cfg.ldb << cfg.ldd << cfg.scaling_mode << epilogueNameMapper.getName(cfg.epilogue) << computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F) - << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end() << "\n"; + << (cfg.fp4_alpha_device_vector ? 1 : 0) + << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index + << csv_helper::end() << "\n"; } private: @@ -951,7 +1132,23 @@ void hipblaslt_gemm(const Tensor *inputA, } NVTE_CHECK(k > 0); - const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k); + GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k); + + // FP4 dequant path: hipBLASLt does not support FP4 natively, + // so we dequantize FP4 -> BF16 (block scales only) and run a standard BF16 GEMM. + // + // The per-tensor amax correction is computed on-device as a per-row alpha vector: + // alpha'[i] = alpha * amax_A * amax_B / (fp4_max^2 * fp8_max^2) + // Alpha is passed as a device vector of length m via + // HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST. Beta stays on host. + const bool use_fp4 = is_fp4_dtype(param.Atype) || is_fp4_dtype(param.Btype); + const void* alpha_ptr = static_cast(&alpha); + const void* beta_ptr = static_cast(&beta); + if (use_fp4) { + dequant_fp4_gemm_inputs(param, *inputA, transa, *inputB, transb, + m, n, k, alpha, workspace, workspaceSize, + &alpha_ptr, stream); + } bool nvte_log_gemm_config = false; if (const char* env_p = std::getenv("NVTE_LOG_GEMM_CONFIG") ) { @@ -1172,10 +1369,18 @@ void hipblaslt_gemm(const Tensor *inputA, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + if (use_fp4) { + int32_t pointer_mode = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; + NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute( + operationDesc, HIPBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, sizeof(pointer_mode))); + } + GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type, use_fp8 ? bias_type : (hipDataType)-1, (use_fp8 && gelu) ? 
aux_type : (hipDataType)-1, - m, n, k, param.lda, param.ldb, ldd, param.transA, param.transB, scaling_mode, epilogue ); + m, n, k, param.lda, param.ldb, ldd, param.transA, param.transB, scaling_mode, epilogue, + use_fp4); GemmAlgoCache::Algo cached_algo; if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value()) { @@ -1193,10 +1398,10 @@ void hipblaslt_gemm(const Tensor *inputA, if (HIPBLAS_STATUS_SUCCESS == hipblaslt_ext::matmulIsAlgoSupported( handle, operationDesc, - static_cast(&alpha), + alpha_ptr, Adesc, Bdesc, - static_cast(&beta), + beta_ptr, Ddesc, Ddesc, algo_arr[0].algo, @@ -1273,12 +1478,12 @@ void hipblaslt_gemm(const Tensor *inputA, // Warm-up call NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&alpha), /* alpha */ + alpha_ptr, /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ Bdesc, - static_cast(&beta), /* beta */ + beta_ptr, /* beta */ C, /* C */ Cdesc, D, /* D */ @@ -1295,12 +1500,12 @@ void hipblaslt_gemm(const Tensor *inputA, { NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&alpha), /* alpha */ + alpha_ptr, /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ Bdesc, - static_cast(&beta), /* beta */ + beta_ptr, /* beta */ C, /* C */ Cdesc, D, /* D */ @@ -1356,12 +1561,12 @@ void hipblaslt_gemm(const Tensor *inputA, // D = alpha * (A * B) + beta * C NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&alpha), /* alpha */ + alpha_ptr, /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ Bdesc, - static_cast(&beta), /* beta */ + beta_ptr, /* beta */ C, /* C */ Cdesc, D, /* D */ diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index ef60a13be..0631f326b 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -166,6 +166,7 @@ __device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float } } } +#endif __launch_bounds__(1) __global__ void ZeroAmaxKernel(float* __restrict__ output_pre_rht_amax_ptr, float* __restrict__ output_identity_amax_ptr, @@ -181,6 +182,7 @@ __launch_bounds__(1) __global__ void ZeroAmaxKernel(float* __restrict__ output_p } } +#ifndef __HIP_PLATFORM_AMD__ template @@ -978,15 +980,8 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran auto* in_ptr = reinterpret_cast(input.dptr); - if (pre_amax_ptr) { - NVTE_CHECK_CUDA(cudaMemsetAsync(pre_amax_ptr, 0, sizeof(float), stream)); - } - if (id_amax_ptr) { - NVTE_CHECK_CUDA(cudaMemsetAsync(id_amax_ptr, 0, sizeof(float), stream)); - } - if (tr_amax_ptr) { - NVTE_CHECK_CUDA(cudaMemsetAsync(tr_amax_ptr, 0, sizeof(float), stream)); - } + ZeroAmaxKernel<<<1, 1, 0, stream>>>(pre_amax_ptr, id_amax_ptr, tr_amax_ptr); + NVTE_CHECK_CUDA(cudaGetLastError()); #else constexpr int kHadamardDimension = 16; NVTE_CHECK(row_length % kHadamardDimension == 0, diff --git a/transformer_engine/common/hadamard_transform/wht16.cuh b/transformer_engine/common/hadamard_transform/wht16.cuh new file mode 100644 index 000000000..490ebbb6d --- /dev/null +++ b/transformer_engine/common/hadamard_transform/wht16.cuh @@ -0,0 +1,126 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +// Shared 16-point Walsh-Hadamard transform primitives for AMDGPU. + +#ifndef TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_WHT16_CUH_ +#define TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_WHT16_CUH_ + +#include "hip/hip_runtime.h" + +#ifdef __HIP_PLATFORM_AMD__ + +static constexpr int kHadamardDim = 16; +static constexpr int kWarpSize = 64; +static constexpr int kThreadsPerWHT = 4; +static constexpr int kElemsPerThread = 4; +static constexpr int kRowsPerWarp = kWarpSize / kThreadsPerWHT; // 16 +static constexpr int kWarpsPerBlock = 4; +static constexpr int kRowsPerBlock = kRowsPerWarp * kWarpsPerBlock; // 64 +static constexpr int kThreadsPerBlock = kWarpSize * kWarpsPerBlock; // 256 +static constexpr float kHadamardScale = 0.25f; + +// ds_swizzle: sub-wavefront exchange without LDS. +__device__ __forceinline__ float ds_swizzle_xor1(float v) { + float r; + asm volatile("ds_swizzle_b32 %0, %1 offset:0x041F\n\t" + "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v)); + return r; +} + +__device__ __forceinline__ float ds_swizzle_xor2(float v) { + float r; + asm volatile("ds_swizzle_b32 %0, %1 offset:0x081F\n\t" + "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v)); + return r; +} + +// ----------------------------------------------------------------------- +// 16-point WHT via the Kronecker trick (no shared memory) +// ----------------------------------------------------------------------- +// +// 1. The vec operator +// vec() flattens a matrix into a column vector by stacking its +// columns one on top of the other: +// +// X = |a c| vec(X) = |a| +// |b d| |b| +// |c| +// |d| +// +// 2. The "Kronecker trick" for 1D -> 2D +// The fundamental identity that connects these concepts is: +// +// vec(B . X . A^T) = (A (x) B) . vec(X) +// +// For a 16-point Hadamard transform (H16 = H4 (x) H4), +// set A = H4 and B = H4. The formula becomes: +// +// H16 . x = vec(H4 . X . H4^T) +// +// 3. Data layout (column-major, one column per thread) +// Reshape the 16-element 1D vector x into a 4x4 matrix X +// by filling columns first: +// +// X = | x0 x4 x8 x12 | thread 0 holds col 0: v0..v3 = x0 ..x3 +// | x1 x5 x9 x13 | thread 1 holds col 1: v0..v3 = x4 ..x7 +// | x2 x6 x10 x14 | thread 2 holds col 2: v0..v3 = x8 ..x11 +// | x3 x7 x11 x15 | thread 3 holds col 3: v0..v3 = x12..x15 +// +// 4. Three-stage computation +// Stage 1 (local H4) : left-multiply H4 . X (within each thread) +// Stage 2 (xor-1 swap) : \ (across 4 threads) +// Stage 3 (xor-2 swap) : / right-multiply . H4^T together these two butterfly stages = H4^T +// +// Result: vec(H4 . X . H4^T) = H16 . x +// +// 5. Randomised Hadamard Transform (RHT) +// A diagonal sign matrix D (from sign_mask) is applied either +// before the WHT (apply_pre=true, forward) or after (inverse). +// +// Adapted from cast_transpose_mxfp4_kernel_shuffled.cu::hadamard16_inplace, +// extended with NV random_sign_mask (uint16_t bitmask). +// thread_in_group [0,3]: drives ds_swizzle polarity (identical to MLPerf tid & 3). +// apply_pre=true -> D before WHT (forward); false -> D after WHT (inverse). +__device__ __forceinline__ void wht16( + float& v0, float& v1, float& v2, float& v3, + int thread_in_group, uint16_t sign_mask, bool apply_pre) { + auto sgn = [&](int k) -> float { + return ((sign_mask >> (thread_in_group * kElemsPerThread + k)) & 1u) ? 
-1.f : 1.f; + }; + + if (apply_pre) { + v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); + } + + // Stage 1: local H4 + float a0=v0+v1, a1=v0-v1, a2=v2+v3, a3=v2-v3; + v0=a0+a2; v2=a0-a2; v1=a1+a3; v3=a1-a3; + + // Stage 2: cross-thread XOR-1 + { float p0=ds_swizzle_xor1(v0), p1=ds_swizzle_xor1(v1), + p2=ds_swizzle_xor1(v2), p3=ds_swizzle_xor1(v3); + bool up=(thread_in_group&1); + v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1); + v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } + + // Stage 3: cross-thread XOR-2 + { float p0=ds_swizzle_xor2(v0), p1=ds_swizzle_xor2(v1), + p2=ds_swizzle_xor2(v2), p3=ds_swizzle_xor2(v3); + bool up=(thread_in_group>>1)&1; + v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1); + v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } + + v0*=kHadamardScale; v1*=kHadamardScale; v2*=kHadamardScale; v3*=kHadamardScale; + + if (!apply_pre) { + v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); + } +} + +#endif // __HIP_PLATFORM_AMD__ + +#endif // TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_WHT16_CUH_ diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 682d8b53f..cfe95d92a 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -14,14 +16,27 @@ namespace transformer_engine { namespace nvfp4_recipe { +#ifndef __HIP_PLATFORM_AMD__ // constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0; constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0); +#endif // Kernel to compute alpha *= amax_A * amax_B / factor __global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A, const float *amax_B, float *alpha_out) { +#ifdef __HIP_PLATFORM_AMD__ + constexpr float fp4_max = detail::TypeExtrema::max; +#if defined(__HIP_DEVICE_COMPILE__) + constexpr float fp8_max = detail::TypeExtrema::max; +#else + constexpr float fp8_max = 240.0f; // host placeholder; only device path executes +#endif + const float fi = 1.0f / (fp4_max * fp4_max * fp8_max * fp8_max); + *alpha_out = alpha_in * (*amax_A) * (*amax_B) * fi; +#else // factor is defined in the enclosing namespace *alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv; +#endif } } // namespace nvfp4_recipe diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 59742d1e7..3f0c0fe84 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -22,6 +22,9 @@ #include "common/util/curanddx.hpp" #include "common/util/ptx.cuh" #include "common/utils.cuh" +#ifdef __HIP_PLATFORM_AMD__ +#include "common/util/cuda_runtime.h" +#endif namespace transformer_engine { @@ -137,12 +140,32 @@ constexpr int kThreadsPerWarp = 32; constexpr int kNFP4PerContainer = 2; // Hyperparameters for performance tuning +#ifndef __HIP_PLATFORM_AMD__ constexpr int kTileDim = 128; +#endif // constexpr int kScaleDim = 32; constexpr int kNVecIn = 8; // The number of elements each LDG touches constexpr int kNVecOut = 16; // The number of elements each STG touches constexpr int 
kNVecSMem = 2; // The number of elements each LDS/STS touches +#ifndef __HIP_PLATFORM_AMD__ constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total +#endif + +// Tile dimension and thread block size: +// gfx942: kTileDim=64 (64 KB LDS, kThreadsPerBlock=128, 4 warps) +// gfx950 / NVIDIA: kTileDim=128 (128 KB LDS, kThreadsPerBlock=256, 8 warps) +// On AMD, __gfx950__ is only defined during device compilation, so the host +// must select tile_dim at runtime via cuda::sm_arch() using the constants below. +#ifdef __HIP_PLATFORM_AMD__ +constexpr int kTileDimGfx950 = 128; +constexpr int kTileDimGfx942 = 64; +#if !defined(__gfx950__) +constexpr int kTileDim = kTileDimGfx942; +#else +constexpr int kTileDim = kTileDimGfx950; +#endif +constexpr int kThreadsPerBlock = 2 * kTileDim; +#endif // Auto-calculated constants, do not modify directly) static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); @@ -156,6 +179,14 @@ constexpr int kNumThreadsStore = kTileDim / kNVecOut; // 8 static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); +#ifdef __HIP_PLATFORM_AMD__ +// Host-side helper: computes shared memory size for a runtime tile dimension. +// Needed because the host determines tile_dim at runtime via cuda::sm_arch(). +constexpr int smem_size_for_tile(int tile_dim) { + return tile_dim * ((tile_dim / kNVecSMem) + 1) * kNVecSMem; +} +#endif + // for 2D block scaling, we need to reduce amax in warp static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = { 0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080}; @@ -301,7 +332,45 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_ro result.__x = static_cast<__hip_fp4x4_storage_t>(lo | (static_cast<__hip_fp4x4_storage_t>(hi) << 8)); return result; #else - NVTE_DEVICE_ERROR("FP4 stochastic rounding on AMDGPU requires gfx950 or later."); + // Stochastic rounding fallback for AMD GPUs without native + // FP4 SR instructions (e.g. gfx942). + // + // FP4 E2M1 has 8 non-negative magnitudes whose 3-bit codes happen to + // be sorted: {0->0.0, 1->0.5, 2->1.0, 3->1.5, 4->2.0, 5->3.0, + // 6->4.0, 7->6.0}. + // + // For each value we: + // 1. Clamp |x| into [0, 6] (the FP4 representable range). + // 2. Find the floor index fi in the FP4 grid via branchless + // comparisons (sum of (|x| >= threshold) for each level). + // 3. Compute the fractional position within [kV[fi], kV[ci]] + // where ci = min(fi+1, 7) is the ceiling index. + // 4. Draw a uniform random value r in [0,1) from 8 bits of rbits. + // 5. Round up to ci if r < frac, otherwise keep fi. + // This gives E[round(x)] = x (unbiased). + // 6. Set the sign bit (bit 3) if the original value was negative. + { + constexpr float kV[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f}; + const float vals[4] = {in01.x, in01.y, in23.x, in23.y}; + __hip_fp4_storage_t q[4]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + const float av = fminf(fabsf(vals[i]), 6.0f); + const int fi = int(av >= 0.5f) + int(av >= 1.0f) + int(av >= 1.5f) + + int(av >= 2.0f) + int(av >= 3.0f) + int(av >= 4.0f) + int(av >= 6.0f); + const int ci = min(fi + 1, 7); + const float gap = kV[ci] - kV[fi]; + const float frac = (gap > 0.0f) ? (av - kV[fi]) / gap : 0.0f; + const float r = static_cast((rbits >> (8 * i)) & 0xFFu) * (1.0f / 256.0f); + const int ri = (r < frac) ? 
ci : fi; + q[i] = static_cast<__hip_fp4_storage_t>((vals[i] < 0.0f) ? (ri | 0x8) : ri); + } + __nv_fp4x4_e2m1 result; + result.__x = static_cast<__hip_fp4x4_storage_t>( + (q[0] & 0xFu) | ((q[1] & 0xFu) << 4) | + ((q[2] & 0xFu) << 8) | ((q[3] & 0xFu) << 12)); + return result; + } #endif // ARCH_HAS_STOCHASTIC_ROUNDING #endif // !__HIP_PLATFORM_AMD__ uint16_t dummy = 0; @@ -803,8 +872,16 @@ void quantize_transpose_vector_blockwise_fp4( using namespace transformer_engine::quantize_transpose_nvfp4; +#ifdef __HIP_PLATFORM_AMD__ + // Tile dimension is selected at compile time based on the target architecture. + // The host still needs the runtime value for grid/smem computation. + const int tile_dim = (cuda::sm_arch() >= 95) ? kTileDimGfx950 : kTileDimGfx942; + const size_t num_blocks_x = DIVUP(row_length, static_cast(tile_dim)); + const size_t num_blocks_y = DIVUP(num_rows, static_cast(tile_dim)); +#else const size_t num_blocks_x = DIVUP(row_length, static_cast(kTileDim)); const size_t num_blocks_y = DIVUP(num_rows, static_cast(kTileDim)); +#endif // noop tensor for cuda graph const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); @@ -830,7 +907,11 @@ void quantize_transpose_vector_blockwise_fp4( using ScaleType = fp8e4m3; constexpr int kScaleBlockDim = 16; constexpr bool kPow2Scale = false; +#ifdef __HIP_PLATFORM_AMD__ + const bool full_tile = row_length % tile_dim == 0 && num_rows % tile_dim == 0; +#else const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; +#endif TRANSFORMER_ENGINE_SWITCH_CONDITION( return_identity, kReturnIdentity, @@ -850,7 +931,11 @@ void quantize_transpose_vector_blockwise_fp4( TRANSFORMER_ENGINE_SWITCH_CONDITION( use_2d_quantization, kIs2DBlockScaling, +#ifdef __HIP_PLATFORM_AMD__ + size_t smem_bytes = smem_size_for_tile(tile_dim) * sizeof(InputType); +#else size_t smem_bytes = kSMemSize * sizeof(InputType); +#endif auto kernel = block_scaled_1d_cast_transpose_kernel< kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, float, InputType, OutputType, ScaleType, kSwizzledScale, @@ -861,8 +946,13 @@ void quantize_transpose_vector_blockwise_fp4( smem_bytes); NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size."); +#ifdef __HIP_PLATFORM_AMD__ + } kernel<<>>( +#else } kernel<<>>( +#endif reinterpret_cast(input.dptr), reinterpret_cast(global_amax.dptr), reinterpret_cast(output.dptr), diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index d4f5b0fa8..b22b50c70 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -20,6 +20,7 @@ from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from ..tensor.utils import is_custom from ..custom_recipes.gemm import custom_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -288,6 +289,24 @@ def general_gemm( beta = validate_gemm_scale(beta, accumulate) workspace = get_cublas_workspace(get_tensor_device(A), ub is not None, False) + # On ROCm, FP4 is dequantized to BF16 in the workspace before GEMM. + # Compute the required extra space and extend the workspace if needed. 
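+ # Rough sizing sketch (hypothetical shapes, for illustration only): if A has shape
+ # (M, K) and B has shape (K, N) and both are NVFP4, the extra space computed below
+ # is M*K*2 + K*N*2 bytes for the BF16 staging copies plus M*4 bytes for the
+ # per-row alpha vector, on top of get_cublas_workspace_size_bytes().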
+ if IS_HIP_EXTENSION and ( + isinstance(A, NVFP4TensorStorage) or isinstance(B, NVFP4TensorStorage) + ): + assert ub is None, "User buffers (comm overlap) are not supported with NVFP4" + import math + bf16_size = torch.bfloat16.itemsize + fp4_extra = 0 + if isinstance(A, NVFP4TensorStorage): + fp4_extra += math.prod(A.size()) * bf16_size + fp4_extra += A.size(0) * 4 # alpha vector (m floats) + if isinstance(B, NVFP4TensorStorage): + fp4_extra += math.prod(B.size()) * bf16_size + total_needed = fp4_extra + get_cublas_workspace_size_bytes() + if workspace.numel() < total_needed: + workspace = torch.empty(total_needed, dtype=torch.uint8, device=workspace.device) + if ub_type is not None: assert ub is not None, ( f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires" diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 5a6a98442..acbe4753b 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -498,7 +498,6 @@ std::tuple, std::vector> bulk_allocate_mx return retval; } -#ifndef USE_ROCM // allocate fp4 data, fp8 scalings, and amax values // layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] // amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate @@ -802,6 +801,7 @@ static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( return res; } +#ifndef USE_ROCM // Implements split-quantize NVFP4 with Row/Column-wise Hadamard Transform (RHT) void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, const std::vector &input_list, @@ -964,6 +964,7 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, } } } +#endif // #ifndef USE_ROCM void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, const std::vector &input_list, @@ -1020,8 +1021,16 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); } +#ifndef USE_ROCM nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), split_sections.data(), num_tensors, stream); +#else + // nvte_group_amax is not available on ROCm; compute amax individually + for (size_t i = 0; i < num_tensors; i++) { + if (input_list[i].numel() == 0) continue; + nvte_compute_amax(input_list[i].data(), output_list[i].data(), stream); + } +#endif for (size_t i = 0; i < num_tensors; i++) { output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); } @@ -1086,6 +1095,7 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, // Perform multi-tensor quantization NVTE_SCOPED_GIL_RELEASE({ +#ifndef USE_ROCM if (quantizer.with_rht) { // Quantize row-wise data, RHT+quantize column-wise data // Check that config is supported NVTE_CHECK(input.dtype() == DType::kBFloat16, "RHT is only supported for bfloat16 input"); @@ -1097,9 +1107,16 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, split_quantize_nvfp4_impl_helper(input, input_list, output_list, split_sections, quantizers, stream); } +#else + // ROCm: group hadamard kernels are not available, fall back to per-tensor quantize + // which handles both RHT and non-RHT paths via NVFP4Quantizer::quantize_impl. 
+ for (size_t i = 0; i < num_tensors; i++) { + if (input_list[i].numel() == 0) continue; + quantizers[i]->quantize(input_list[i], output_list[i]); + } +#endif }); } -#endif // #ifndef USE_ROCM } // namespace @@ -1169,14 +1186,12 @@ std::vector split_quantize(const at::Tensor &tensor, return detail::IsMXFP8Quantizers(quantizer.ptr()); })) { allocation_method = AllocationMethod::BULK_MXFP8; -#ifndef USE_ROCM } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), [](const py::handle &quantizer) -> bool { return detail::IsNVFP4Quantizers(quantizer.ptr()); })) { allocation_method = AllocationMethod::BULK_NVFP4; quantization_method = QuantizationMethod::FUSED_NVFP4; -#endif } } @@ -1204,7 +1219,6 @@ std::vector split_quantize(const at::Tensor &tensor, bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); break; } -#ifndef USE_ROCM case AllocationMethod::BULK_NVFP4: { // Bulk allocation for NVFP4 tensors std::vector nvfp4_quantizers; @@ -1220,7 +1234,6 @@ std::vector split_quantize(const at::Tensor &tensor, } break; } -#endif default: { // Allocate output tensors individually for (size_t i = 0; i < num_splits; ++i) { @@ -1234,7 +1247,6 @@ std::vector split_quantize(const at::Tensor &tensor, // Quantize into output tensors switch (quantization_method) { -#ifndef USE_ROCM case QuantizationMethod::FUSED_NVFP4: { // Fused NVFP4 quantize kernel auto input_nvte = makeTransformerEngineTensor(input_dptr, input_shape, input_dtype); @@ -1246,7 +1258,6 @@ std::vector split_quantize(const at::Tensor &tensor, nvfp4_quantizers); break; } -#endif default: // General multi-tensor quantization multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1bd2b39c1..bb960406d 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1668,25 +1668,18 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou need_separate_columnwise_rng ? quant_config_columnwise : quant_config; if (!eligible_for_rht_cast_fusion) { -#ifdef USE_ROCM +#ifndef USE_ROCM + at::Tensor rht_output_t; +#endif // If rht_output_t was already produced by the fused amax+transform kernel above, // skip the separate hadamard_transform call. 
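+ // Note: on ROCm the allocation below only runs when rht_output_t is still
+ // undefined; on CUDA it always runs, since rht_output_t was only just declared.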
+#ifdef USE_ROCM if (!rht_output_t.defined()) { - rht_output_t = - allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); - TensorWrapper rht_output_t_cpp; - rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), - std::vector{cols, rows}); - NVTE_SCOPED_GIL_RELEASE({ - nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, - this->rht_matrix_random_sign_mask_t, stream); - }); - } #else - at::Tensor rht_output_t; - rht_output_t = - allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); { +#endif + rht_output_t = + allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); TensorWrapper rht_output_t_cpp; rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), std::vector{cols, rows}); @@ -1695,7 +1688,6 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou this->rht_matrix_random_sign_mask_t, stream); }); } -#endif TensorWrapper rht_output_t_cpp; rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ebc00ea05..f5f0b47d4 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -53,6 +53,7 @@ from ..triton_kernels.cast import te_quantize_triton from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage +from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from ..utils import get_device_compute_capability, is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe @@ -89,7 +90,7 @@ def get_cublas_workspace_size_bytes() -> None: """Return 64 MiB for gfx50x, 32 MiB for all other architectures.""" if get_device_compute_capability() == (9, 5): return 67_108_864 - return 33_554_432 + return 33_554_432 """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales @@ -1489,6 +1490,11 @@ def get_weight_workspace( reset_cache = True elif quantizer.columnwise_usage and out._columnwise_data is None: reset_cache = True + elif isinstance(out, NVFP4TensorStorage): + if quantizer.rowwise_usage and out._rowwise_data is None: + reset_cache = True + elif quantizer.columnwise_usage and out._columnwise_data is None: + reset_cache = True if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer): reset_cache = True if reset_cache: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7347fc138..a2817b79b 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -312,6 +312,11 @@ def forward( weight_quantizer = weight._quantizer elif weight_quantizer is not None: weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) + # NVFP4 must produce columnwise data at quantization time + # (no lazy transpose like Float8Tensor) + from ..tensor.nvfp4_tensor import NVFP4Quantizer + if isinstance(weight_quantizer, NVFP4Quantizer) and is_grad_enabled: + weight_quantizer.set_usage(columnwise=True) # Get quantized 
weight update_workspace = is_first_microbatch is None or is_first_microbatch @@ -369,7 +374,9 @@ def forward( # Forward GEMM # Note: y = x * w^T # ------------------------------------------------------ - if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache: + # NVFP4TensorStorage doesn't have _transpose (no lazy transpose like Float8Tensor), + # so guard with hasattr. + if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache and hasattr(weightmat, '_transpose'): assert weightmat._transpose is None or weightmat._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled." nvtx_range_push(f"{nvtx_label}.gemm") gemm_out, *_, reduce_scatter_out = general_gemm( @@ -1861,5 +1868,9 @@ def _get_weight_quantizers(self) -> List[Quantizer]: weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True if IS_HIP_EXTENSION: - weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) + # NVFP4 must always produce columnwise data at quantization time + # (no lazy transpose like Float8Tensor), so force columnwise=True. + from ..tensor.nvfp4_tensor import NVFP4Quantizer + is_nvfp4 = isinstance(weight_quantizer, NVFP4Quantizer) + weight_quantizer.set_usage(columnwise=True if is_nvfp4 else self.keep_fp8_weight_transpose_cache) return [weight_quantizer] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 01d07d91a..27e2648e7 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -68,6 +68,7 @@ ) from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.utils import is_custom from ..export import is_in_onnx_export_mode, assert_warmed_up from ..cpu_offload import ( @@ -265,6 +266,10 @@ def forward( is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase() ) + # NVFP4 must produce columnwise data at quantization time + # (no lazy transpose like Float8Tensor) + if not columnwise_usage and isinstance(weight_quantizer, NVFP4Quantizer): + columnwise_usage = is_grad_enabled and inp.requires_grad weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) elif isinstance(weight, QuantizedTensor): # If weight is already quantized, no need to set quantizer states @@ -325,7 +330,9 @@ def forward( # Forward GEMM # Note: y = x * w^T # ------------------------------------------------------ - if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache: + # NVFP4TensorStorage doesn't have _transpose (no lazy transpose like Float8Tensor), + # so guard with hasattr. + if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache and hasattr(weightmat, '_transpose'): assert weightmat._transpose is None or weightmat._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled." nvtx_range_push(f"{nvtx_label}.gemm") @@ -1712,5 +1719,8 @@ def _get_weight_quantizers(self) -> List[Quantizer]: weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True if IS_HIP_EXTENSION: - weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) + # NVFP4 must always produce columnwise data at quantization time + # (no lazy transpose like Float8Tensor), so force columnwise=True. 
+ is_nvfp4 = isinstance(weight_quantizer, NVFP4Quantizer) + weight_quantizer.set_usage(columnwise=True if is_nvfp4 else self.keep_fp8_weight_transpose_cache) return [weight_quantizer] diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index ba8a2aec5..b08b10da9 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -89,7 +89,7 @@ def check_mxfp8_support() -> Tuple[bool, str]: @functools.lru_cache(maxsize=None) def check_nvfp4_support() -> Tuple[bool, str]: if IS_HIP_EXTENSION: - return False, "ROCm TE currently not supporting NVFP4" + return True, "" """Return if nvfp4 support is available""" if get_device_compute_capability() >= (10, 0): # blackwell and above return True, "" diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index 87cfa722e..ed4002f2c 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -8,6 +8,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.pytorch.triton_kernels.common import ( te_dtype_to_torch_dtype, te_dtype_to_triton_dtype, @@ -222,7 +223,7 @@ def _te_norm_fwd_triton( quantizer.amax, N, ATOMIC_REDUCTION_BLOCK_SIZE, ) - elif IS_MXFP8 or IS_FP8_CURRENT_SCALING: + elif IS_MXFP8 or IS_FP8_CURRENT_SCALING or isinstance(quantizer, NVFP4Quantizer): _out = quantizer.make_empty( input_tensor.shape, dtype=te_dtype_to_torch_dtype(otype), diff --git a/transformer_engine/pytorch/triton_kernels/utils.py b/transformer_engine/pytorch/triton_kernels/utils.py index 15a733ce9..884fab5e3 100644 --- a/transformer_engine/pytorch/triton_kernels/utils.py +++ b/transformer_engine/pytorch/triton_kernels/utils.py @@ -6,6 +6,7 @@ import triton from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8CurrentScalingQuantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from .common import te_dtype_to_torch_dtype def get_ln_sm_margin(sm_margin_type): @@ -59,7 +60,7 @@ def make_ln_out(ln_out, quantizer=None, input_shape=None, out_dtype=torch.float3 if ln_out is None: # TODO(micky774): Remove corresponding FP8Quantizer check when kernels properly support MXFP8/float8_current_scaling as a fused operation - if quantizer is None or isinstance(quantizer, MXFP8Quantizer) or isinstance(quantizer, Float8CurrentScalingQuantizer): + if quantizer is None or isinstance(quantizer, (MXFP8Quantizer, Float8CurrentScalingQuantizer, NVFP4Quantizer)): return torch.empty(input_shape, dtype=out_dtype, device='cuda') return quantizer.make_empty(input_shape, dtype=out_dtype)