Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
d954c6d
Typo fix (#397)
Micky774 Dec 6, 2025
7b5cf20
ROCm UserBuffers for Comm Overlap
alextmagro Oct 30, 2025
640f7e8
Copyrights and cleanup
alextmagro Jan 27, 2026
82faeec
test guards
alextmagro Feb 24, 2026
b6a3ae4
Cleanup and RS flag race condition fix
alextmagro Mar 2, 2026
9e32d3a
Debugging midpoint
alextmagro Mar 12, 2026
84209ad
Cleanup and workspace fix
alextmagro Mar 17, 2026
c669bd2
Guard layer registration in UB
alextmagro Mar 17, 2026
8040909
Cleanup of profiling example for rocm
alextmagro Mar 17, 2026
e375923
Readd example script and update custom_map
alextmagro Mar 17, 2026
c6bd974
fix typo
alextmagro Mar 17, 2026
d76aa06
MI300 test skips due to jittery results
alextmagro Mar 18, 2026
ae979d0
Comment regarding sm_margin performance
alextmagro Mar 18, 2026
b58cbd1
Variable renamed, pybind fix, tolerance tightening
alextmagro Mar 23, 2026
e5d7446
Remove git conflict
alextmagro Mar 24, 2026
7734ce5
Address style and hip/cu specific paths
alextmagro Mar 26, 2026
c169c75
HIP guards
alextmagro Mar 27, 2026
80e0aab
initial impl
matthiasdiener Mar 27, 2026
de7863a
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Mar 27, 2026
bda7b13
test update
matthiasdiener Mar 30, 2026
7ddb539
Update extensions.h
alextmagro Mar 30, 2026
63c7a48
amax opt
matthiasdiener Mar 30, 2026
a260459
simplify
matthiasdiener Mar 30, 2026
3dd8af9
Merge pull request #367 from ROCm/userbuffer_epic
alextmagro Mar 31, 2026
ab217cb
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Mar 31, 2026
26c5fb7
simplify pt 2
matthiasdiener Mar 31, 2026
2087f24
expand test
matthiasdiener Mar 31, 2026
05cedb7
compute amax from BF16-rounded outputs
matthiasdiener Mar 31, 2026
67b93a8
TE building over TheRock (#511)
ipanfilo Apr 1, 2026
465d547
Typo fix (#397)
Micky774 Dec 6, 2025
9fb21f9
Add NVTE_UB_WITH_MPI to rocm build path
alextmagro Apr 1, 2026
2f66594
Merge pull request #513 from ROCm/ub_mpi_hotfix
alextmagro Apr 1, 2026
986d8ba
NVFP4: hadamard_transform_cast_fusion_columnwise
matthiasdiener Apr 1, 2026
b339c86
unify hadamard_transform_cast_fusion_columnwise
matthiasdiener Apr 1, 2026
f74a0ab
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Apr 1, 2026
e9426cd
Merge remote-tracking branch 'origin/dev' into mdiener/nvfp4-cast_fusion
matthiasdiener Apr 1, 2026
1d0a70e
Rebase onto dev
aris134 Apr 2, 2026
6e3eea5
Enable NVFP4 recipe
matthiasdiener Apr 2, 2026
9c3dc2f
NVFP4 GEMM via BF16 dequant
matthiasdiener Apr 2, 2026
e3a2502
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Apr 2, 2026
3a63f32
add explanation to wht16
matthiasdiener Apr 2, 2026
35ef81c
comment and test
matthiasdiener Apr 2, 2026
9559131
Merge branch 'amartin/nvfp4-dequant' into mdiener/nvfp4-gemm
matthiasdiener Apr 3, 2026
e8ff6bd
enable use_fused_bulk_alloc
matthiasdiener Apr 3, 2026
e26ffc8
compute random sign mask on device
matthiasdiener Apr 3, 2026
c7cc488
CI: enable CI runs on every PR
matthiasdiener Apr 6, 2026
7c68bd8
Avoid duplicate entry when opening PR
matthiasdiener Apr 6, 2026
a19dd60
fix stream capture error in GEMM
matthiasdiener Apr 6, 2026
17d50ee
Merge branch 'dev' into mdiener/fp4_hadamard
matthiasdiener Apr 6, 2026
e32a758
merge errors
matthiasdiener Apr 6, 2026
4857721
Merge branch 'dev' into mdiener/fp4_hadamard
matthiasdiener Apr 6, 2026
b243b4c
merge
matthiasdiener Apr 6, 2026
5d39b27
Merge remote-tracking branch 'origin/dev' into mdiener/nvfp4-gemm
matthiasdiener Apr 7, 2026
141eadc
Merge remote-tracking branch 'origin/dev' into mdiener/nvfp4-gemm
matthiasdiener Apr 7, 2026
6527004
Merge branch 'dev' into mdiener/fp4_hadamard
matthiasdiener Apr 7, 2026
ca1aacf
change to __builtin_bit_cast
matthiasdiener Apr 7, 2026
c8e6c72
more fixes
matthiasdiener Apr 7, 2026
1b0fe3e
fix dequant buffer
matthiasdiener Apr 7, 2026
bc9f0a3
remove copyright header
matthiasdiener Apr 8, 2026
167c311
fix triton rmsnorm
matthiasdiener Apr 8, 2026
203ef86
Merge remote-tracking branch 'origin/dev' into mdiener/nvfp4-gemm
matthiasdiener Apr 8, 2026
4b0550d
mi300 fixes
matthiasdiener Apr 8, 2026
5da621a
software fallbacks for SR on gfx942
matthiasdiener Apr 8, 2026
9f1851d
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Apr 9, 2026
287708d
more gfx942 fixes
matthiasdiener Apr 9, 2026
81c45c9
ensure columnwise data for dgrad GEMM
matthiasdiener Apr 9, 2026
b75e066
fix mi350
matthiasdiener Apr 9, 2026
5095971
Merge remote-tracking branch 'origin/dev' into mdiener/nvfp4-gemm
matthiasdiener Apr 9, 2026
739a20d
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Apr 13, 2026
2225c72
replace dequant allocation with workspace
matthiasdiener Apr 13, 2026
42bf230
Merge remote-tracking branch 'origin/dev' into mdiener/nvfp4-gemm
matthiasdiener Apr 13, 2026
f269097
enable tests
matthiasdiener Apr 13, 2026
346beb1
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Apr 16, 2026
cf2c8f6
address reviewer comments
matthiasdiener Apr 16, 2026
2772834
minor fixes
matthiasdiener Apr 16, 2026
26c5cb1
PreRhtAmax optimizations
matthiasdiener Apr 16, 2026
071aa4b
Merge branch 'mdiener/fp4_hadamard' into mdiener/nvfp4-cast_fusion
matthiasdiener Apr 17, 2026
018d24f
use ZeroAmaxKernel
matthiasdiener Apr 17, 2026
3efd532
Merge remote-tracking branch 'origin/dev' into mdiener/fp4_hadamard
matthiasdiener Apr 17, 2026
b835818
Merge branch 'mdiener/fp4_hadamard' into mdiener/nvfp4-cast_fusion
matthiasdiener Apr 17, 2026
f2caca7
Merge branch 'mdiener/nvfp4-cast_fusion' into mdiener/nvfp4-gemm
matthiasdiener Apr 17, 2026
95518ea
Merge branch 'dev' into mdiener/nvfp4-gemm
matthiasdiener Apr 20, 2026
feff829
undo hadamard_fusion
matthiasdiener Apr 20, 2026
1ba9474
fixes
matthiasdiener Apr 20, 2026
e1ba512
cleanups, cleaner mi300 LDS workaround
matthiasdiener Apr 20, 2026
da093bd
more cleanups
matthiasdiener Apr 20, 2026
75a4738
re-fix null tensor_amax
matthiasdiener Apr 20, 2026
5315506
minor cleanups
matthiasdiener Apr 21, 2026
7254ba4
Merge remote-tracking branch 'origin/dev' into mdiener/nvfp4-gemm
matthiasdiener Apr 23, 2026
a91eaf0
address review comments
matthiasdiener Apr 23, 2026
29cacef
Merge remote-tracking branch 'origin/dev' into mdiener/nvfp4-gemm
matthiasdiener Apr 23, 2026
0f53a9d
address review comments
matthiasdiener Apr 24, 2026
2851bd9
Merge remote-tracking branch 'origin/dev' into mdiener/nvfp4-gemm
matthiasdiener Apr 24, 2026
3deb3ff
Merge remote-tracking branch 'origin/dev' into mdiener/nvfp4-gemm
matthiasdiener Apr 26, 2026
a08e8c5
use maxNorm
matthiasdiener Apr 27, 2026
fae76d3
factor out FP4 staging
matthiasdiener Apr 27, 2026
81d3cbd
Merge remote-tracking branch 'origin/dev' into mdiener/nvfp4-gemm
matthiasdiener Apr 28, 2026
0afd821
Merge remote-tracking branch 'upstream/dev' into mdiener/nvfp4-gemm
matthiasdiener Apr 28, 2026
a6f4787
address review comments
matthiasdiener Apr 28, 2026
b1575bd
Merge remote-tracking branch 'upstream/dev' into mdiener/nvfp4-gemm
matthiasdiener Apr 28, 2026
eae6e95
Merge remote-tracking branch 'upstream/dev' into mdiener/nvfp4-gemm
matthiasdiener Apr 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tests/cpp/operator/test_cast_nvfp4_transpose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ std::vector<InputType> create_transpose(const InputType* const input, const size

// Compute the global encode scale factor for a given global amax
float compute_global_encode_scaling_factor_FP4(const float global_amax) {
#ifdef __HIP_PLATFORM_AMD__
const float fp8_max = Numeric_Traits<fp8e4m3>::maxNorm;
#else
Comment thread
aris134 marked this conversation as resolved.
constexpr float fp8_max = 448.0f; // 448.0f;
#endif
constexpr float fp4_max = 6.0f; // 6.0f;
float global_encode_scale = fp8_max * fp4_max / global_amax;
// If scale is infinity, return max value of float32
Expand Down
5 changes: 5 additions & 0 deletions tests/cpp/operator/test_dequantize_nvfp4.cu
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,12 @@ void compute_ref(const fp4e2m1* input,
const size_t rows,
const size_t cols,
const size_t scale_stride) {
#ifdef __HIP_PLATFORM_AMD__
const float fp8_max = Numeric_Traits<fp8e4m3>::maxNorm;
const float factor_inv = 1.0f / (6.0f * fp8_max);
#else
constexpr float factor_inv = 1.0f / (6.0f * 448.0f);
#endif

const size_t blocks_per_row = cols / kFP4BlockSize1D;

Expand Down
23 changes: 20 additions & 3 deletions tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand All @@ -10,7 +12,10 @@
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from torch.utils.cpp_extension import IS_HIP_EXTENSION

if IS_HIP_EXTENSION:
from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)

Expand Down Expand Up @@ -108,8 +113,13 @@ def check_nvfp4_gemm_versus_reference(

# Native scales are stored as uint8 but need to be interpreted as float8_e4m3fn
# for the reference GEMM to work correctly
sx_trimmed = sx_trimmed.view(torch.float8_e4m3fn)
sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn)
if IS_HIP_EXTENSION:
fp8_dtype = get_torch_float8_e4m3_type()
sx_trimmed = sx_trimmed.view(fp8_dtype)
sw_trimmed = sw_trimmed.view(fp8_dtype)
else:
sx_trimmed = sx_trimmed.view(torch.float8_e4m3fn)
sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn)

# Create reference quantizer for reference GEMM
ref_quantizer = NVFP4QuantizerRef(
Expand Down Expand Up @@ -150,7 +160,14 @@ def check_nvfp4_gemm_versus_reference(

# Native TE GEMM using tex.generic_gemm (cuBLAS GEMM)
# Allocate cuBLAS workspace
workspace = torch.empty(4, dtype=torch.uint8, device=device)
if IS_HIP_EXTENSION:
# On ROCm, FP4 is dequantized to BF16 in workspace before GEMM, so allocate enough space.
Comment thread
aris134 marked this conversation as resolved.
from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes
bf16_size = torch.bfloat16.itemsize
ws_bytes = M * K * bf16_size + K * N * bf16_size + get_cublas_workspace_size_bytes()
workspace = torch.empty(ws_bytes, dtype=torch.uint8, device=device)
Comment thread
ipanfilo marked this conversation as resolved.
else:
workspace = torch.empty(4, dtype=torch.uint8, device=device)

transa = True if not w_columnwise else False
transb = False if not x_columnwise else True
Expand Down
17 changes: 15 additions & 2 deletions tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand All @@ -11,6 +13,9 @@
import transformer_engine_torch as tex

from transformer_engine.pytorch import NVFP4Quantizer
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, is_fp8_fnuz

recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)

Expand Down Expand Up @@ -58,10 +63,18 @@ def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor:


def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32)
if IS_HIP_EXTENSION:
fp8_dtype = get_torch_float8_e4m3_type()
fp8_max = 240.0 if is_fp8_fnuz() else 448.0
sf = sx.repeat_interleave(16, dim=1).view(fp8_dtype).to(torch.float32)
else:
sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32)
dqx = fp4_to_fp32(unpack_fp4(qx))
sf = sf[: dqx.shape[0], : dqx.shape[1]]
dequant = dqx * sf * (amax / (6.0 * 448))
if IS_HIP_EXTENSION:
dequant = dqx * sf * (amax / (6.0 * fp8_max))
else:
dequant = dqx * sf * (amax / (6.0 * 448))
return dequant


Expand Down
7 changes: 6 additions & 1 deletion tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1773,7 +1773,10 @@ def test_clamped_swiglu(
quantized_compute = quantization is not None
if not quantized_compute and (quantize_forward or quantize_backward):
pytest.skip("Quantization scheme has not been provided")
maybe_skip_quantization(quantization, dims=in_shape, device=device)
if IS_HIP_EXTENSION:
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
else:
maybe_skip_quantization(quantization, dims=in_shape, device=device)

# Random data
x_ref, x_test = make_reference_and_test_tensors(
Expand Down Expand Up @@ -2937,6 +2940,8 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:

# Check values
tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking
if IS_HIP_EXTENSION:
tols["atol"] = 0.54
torch.testing.assert_close(to_cpu(y_test), y_ref, **tols)
torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols)
torch.testing.assert_close(to_cpu(norm.weight.grad), norm_w_ref.grad, **tols)
Expand Down
1 change: 0 additions & 1 deletion tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,6 @@ def test_gpt_full_activation_recompute(
if (dtype == torch.bfloat16
and not fp8
and not use_reentrant
and recipe.float8_per_tensor_scaling()
):
pytest.skip("hipBLASLt does not provide suitable algorithms on GFX950 for this config.")
if fp8 and recipe.nvfp4():
Expand Down
1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ list(APPEND transformer_engine_cuda_sources
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
recipe/nvfp4.cu
swizzle/swizzle.cu)

list(APPEND transformer_engine_cuda_arch_specific_sources
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ __global__ void __launch_bounds__(512)
#else
float amax = (tensor_amax != nullptr) ? *tensor_amax : 1.0f;
#endif
#if defined(__HIP_DEVICE_COMPILE__)
constexpr float factor_inv = 1.0f / (detail::TypeExtrema<fp4e2m1>::max * detail::TypeExtrema<fp8e4m3>::max);
#else
constexpr float factor_inv = 1.0 / (6.0 * 448.0);
#endif
float final_scale = static_cast<float>(scale) * amax * factor_inv;
#pragma unroll
for (int i = 0; i < 4; i++) {
Expand Down
Loading
Loading