From d954c6d6946f1b4e9b92b7c5caed70f147805a5a Mon Sep 17 00:00:00 2001 From: Meekail Zain <34613774+Micky774@users.noreply.github.com> Date: Fri, 5 Dec 2025 23:40:54 -0500 Subject: [PATCH 01/69] Typo fix (#397) From 7b5cf2048beca91312344f1a6438c58d7d40b0af Mon Sep 17 00:00:00 2001 From: alextmagro Date: Thu, 30 Oct 2025 15:57:29 -0500 Subject: [PATCH 02/69] ROCm UserBuffers for Comm Overlap --- build_tools/hipify/custom_map.json | 8 +- build_tools/pytorch.py | 23 +- ci/pytorch.sh | 1 + .../te_layer_with_overlap_profile.py | 504 +++++++++++++ .../pytorch/comm_gemm_overlap/ub_config.json | 15 + setup.py | 11 +- .../cpp/operator/test_normalization_mxfp8.cu | 6 +- .../distributed/run_layer_with_overlap.py | 7 + .../distributed/test_comm_gemm_overlap.py | 7 +- transformer_engine/common/CMakeLists.txt | 16 +- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 63 +- .../rocm_comm_gemm_overlap.cpp | 664 ++++++++++++++++++ .../userbuffers/userbuffers-host.cpp | 13 +- .../userbuffers/userbuffers.cu | 105 ++- .../userbuffers/userbuffers.h | 27 +- .../transformer_engine/comm_gemm_overlap.h | 70 +- .../include/transformer_engine/multi_stream.h | 2 +- .../common/util/cuda_runtime.cpp | 8 +- transformer_engine/common/util/cuda_runtime.h | 2 - .../common/util/pybind_helper.h | 18 +- transformer_engine/pytorch/csrc/common.h | 2 - transformer_engine/pytorch/csrc/extensions.h | 10 - .../csrc/extensions/comm_gemm_overlap.cpp | 5 +- .../pytorch/csrc/extensions/gemm.cpp | 22 +- .../pytorch/csrc/extensions/pybind.cpp | 6 - transformer_engine/pytorch/module/base.py | 10 +- .../pytorch/ops/fused/__init__.py | 19 +- transformer_engine/pytorch/transformer.py | 6 +- 28 files changed, 1530 insertions(+), 120 deletions(-) create mode 100644 examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py create mode 100644 examples/pytorch/comm_gemm_overlap/ub_config.json create mode 100644 transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index 872d38efa..72bf8a383 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -15,7 +15,13 @@ "__nv_fp4x4_e2m1" : "__hip_fp4x4_e2m1", "__nv_fp4x2_storage_t" : "__hip_fp4x2_storage_t", "#include " : "", - "#include " : "" + "#include " : "", + "cudaLaunchKernel": "hipLaunchKernel", + "CUmemGenericAllocationHandle": "hipMemGenericAllocationHandle_t", + "cudaLaunchConfig_t": "hipLaunchConfig_t", + "cudaLaunchAttribute": "hipLaunchAttribute", + "cudaLaunchAttributeCooperative": "hipLaunchAttributeCooperative", + "CUdeviceptr": "hipDeviceptr_t" } } diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index d8eb9a81e..db7b61d02 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -81,14 +81,6 @@ def setup_pytorch_extension( if version < (12, 0): raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer") - if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): - assert ( - os.getenv("MPI_HOME") is not None - ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" 
- mpi_path = Path(os.getenv("MPI_HOME")) - include_dirs.append(mpi_path / "include") - cxx_flags.append("-DNVTE_UB_WITH_MPI") - library_dirs = [] libraries = [] if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))): @@ -102,12 +94,22 @@ def setup_pytorch_extension( cxx_flags.append("-DNVTE_ENABLE_NVSHMEM") if bool(int(os.getenv("NVTE_ENABLE_ROCSHMEM", 0))): - cxx_flags.append("-DNVTE_ENABLE_ROCSHMEM") mpi_home = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi")) include_dirs.append(mpi_home / "include") library_dirs.append(mpi_home / "lib") - libraries.append("mpi_cxx") + libraries.append("mpi") + cxx_flags.extend(["-DNVTE_ENABLE_ROCSHMEM", "-DOMPI_SKIP_MPICXX"]) + extra_link_args = [] + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): + assert ( + os.getenv("MPI_HOME") is not None + ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" + mpi_path = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi")) + include_dirs.append(mpi_path / "include") + library_dirs.append(mpi_path / "lib") + libraries.append("mpi") + cxx_flags.extend(["-DNVTE_UB_WITH_MPI", "-DOMPI_SKIP_MPICXX"]) # Construct PyTorch CUDA extension sources = [str(path) for path in sources] @@ -121,4 +123,5 @@ def setup_pytorch_extension( extra_compile_args={"cxx": cxx_flags}, libraries=[str(lib) for lib in libraries], library_dirs=[str(lib_dir) for lib_dir in library_dirs], + extra_link_args=[str(arg) for arg in extra_link_args], ) diff --git a/ci/pytorch.sh b/ci/pytorch.sh index 1d21e2450..a7259ebf6 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -95,6 +95,7 @@ run_test_config_mgpu(){ run_default_fa 1 test_gemm_sm_count.py run_default_fa 3 test_sanity_import.py run_default_fa 3 distributed/test_cast_master_weights_to_fp8.py + run_default_fa 3 distributed/test_comm_gemm_overlap.py run_default_fa 2 distributed/test_fusible_ops.py run_default_fa 2 distributed/test_numerics.py run_default_fa 1 distributed/test_torch_fsdp2.py diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py new file mode 100644 index 000000000..ba5afd2b6 --- /dev/null +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py @@ -0,0 +1,504 @@ +#!/usr/bin/python3 + +# This file was modified for portability to AMDGPU +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import os +import sys +import socket +import fcntl +import struct +import argparse +import warnings + +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel + +import torch.profiler + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.cpp_extensions as tex +from transformer_engine.common.recipe import Format, DelayedScaling + +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +if not tex.device_supports_multicast(): + os.environ["UB_SKIPMC"] = "1" + + +def _te_layer_argtype(name): + te_layers = [ + te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, + te.MultiheadAttention, + te.TransformerLayer, + ] + layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers)) + if name.lower() not in layer_map.keys(): + raise argparse.ArgumentTypeError( + f"Invalid TE layer name! Please choose from: {layer_map.keys()}" + ) + return layer_map[name.lower()] + + +def _parse_args(argv=None, namespace=None): + parser = argparse.ArgumentParser( + description="Train a Transformer Engine module with GEMM+comm overlap via Userbuffers." + ) + parser.add_argument( + "-i", "--num-iters", type=int, default=10, help="Number of dummy 'training' iterations." + ) + parser.add_argument("-b", "--batch-size", type=int, default=8, help="Input batch size.") + parser.add_argument("-s", "--seq-length", type=int, default=16384, help="Input sequence length.") + parser.add_argument( + "-n", "--num-heads", type=int, default=64, help="Number of attention heads." + ) + parser.add_argument( + "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." + ) + parser.add_argument( + "--layer-type", + type=_te_layer_argtype, + default=te.TransformerLayer, + help="Transformer Engine layer to train with comm+GEMM overlap.", + ) + parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") + parser.add_argument( + "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." + ) + parser.add_argument( + "--no-comm-overlap", + action="store_true", + default=False, + help="Disable the comm+GEMM overlap.", + ) + parser.add_argument( + "--num-replicas", type=int, default=1, help="Number of data-parallel model replicas." 
+ ) + parser.add_argument( + "--tcp-init", + action="store_true", + default=False, + help="Initialize torch.distributed with TcpStore.", + ) + parser.add_argument( + "--bind-to-device", + action="store_true", + default=False, + help="Initialize torch.distributed with `device_id` to bind each rank to a single device.", + ) + parser.add_argument( + "--bootstrap-backend", + type=str.lower, + default="nccl", + choices=["gloo", "mpi", "nccl"], + help="Communications backend for host tensor collectives during Userbuffers bootstrapping.", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + help="Print out from every rank instead of just the root rank of relevant process groups.", + ) + parser.add_argument( + "--debug", + action="store_true", + default=False, + help="Print out additional debug information.", + ) + parser.add_argument( + "--profile", + action="store_true", + default=False, + help="Enable PyTorch profiler.", + ) + parser.add_argument( + "--profile-dir", + type=str, + default="./logs/profiler_traces", + help="Directory to save PyTorch profiler traces.", + ) + parser.add_argument( + "--ub_config", + type=str, + default="./ub_config.json", + help="Userbuffer configuration file.", + ) + + args = parser.parse_args(argv, namespace) + if args.bootstrap_backend == "nccl": + args.bind_to_device = True + return args + + +def _get_layer_args(config, tp_group, tp_size, reference=False): + hidden_size = config.num_heads * config.head_dim + input_shape = [config.seq_length, config.batch_size, hidden_size] + args = [hidden_size] + kwargs = { + "params_dtype": torch.float32, + "device": "cuda", + "tp_group": tp_group, + "tp_size": tp_size, + "sequence_parallel": True, + } + kwargs["ub_overlap_ag"] = not config.no_comm_overlap + + if config.layer_type is te.Linear: + input_shape[2] = hidden_size // tp_size + args.append(hidden_size) + kwargs["parallel_mode"] = "row" + kwargs["ub_overlap_rs"] = not config.no_comm_overlap + kwargs["ub_name"] = "proj" + else: + input_shape[0] = config.seq_length // tp_size + if config.layer_type is te.LayerNormLinear: + args.append(3 * hidden_size) + kwargs["parallel_mode"] = "column" + kwargs["ub_name"] = "qkv" + else: + kwargs["set_parallel_mode"] = True + kwargs["ub_overlap_rs"] = not config.no_comm_overlap + if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: + # args.append(4 * hidden_size) + args.append(int(3.5 * hidden_size)) + + kwargs["seq_length"] = config.seq_length + if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + args.append(config.num_heads) + kwargs["attention_dropout"] = 0.0 + kwargs["fuse_qkv_params"] = True + if config.layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + else: + kwargs["ub_tp_comm_overlap"] = not config.no_comm_overlap + kwargs["hidden_dropout"] = 0.0 + + return args, kwargs, input_shape + +def create_ub_cfgs(config_file: str, tp_size: int = 8): + import json + with open(config_file, 'r') as f: + data = json.load(f) + cfgs = {} + _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None + layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] + layers_all_gather_overlap = [ + "qkv_fprop", + "qkv_dgrad", + "proj_dgrad", + "proj_wgrad", + "fc1_fprop", + "fc1_dgrad", + "fc2_dgrad", + "fc2_wgrad", + ] + + for name, method in data.items(): + if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None: + _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range() + + cfg = { + "method": method, + 
"is_reduce_scatter": name in layers_reduce_scatter_overlap, + "num_sm": 1 if method == "ring_exchange" else 16, + "cga_size": 1 if method == "ring_exchange" else 2, + "set_sm_margin": False, + "num_splits": tp_size if method == "ring_exchange" else 4, + "aggregate": False, + "atomic_gemm": False, + "use_ce": True, + "fp8_buf": name in layers_all_gather_overlap, + "comm_priority": _MAX_STREAM_PRIORITY, + "gemm_priority": _MIN_STREAM_PRIORITY, + } + + cfgs[name] = cfg + + return cfgs + +def _train(opts): + if "OMPI_COMM_WORLD_SIZE" in os.environ: + # Execution with `mpirun -np N` + WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0")) + WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1")) + opts.tcp_init = True + opts.bind_to_device = True + opts.bootstrap_backend = "mpi" + elif "TORCHELASTIC_RUN_ID" in os.environ: + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + else: + raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") + NUM_NODES = WORLD_SIZE // LOCAL_SIZE + + # Initialize torch.distributed global process group and get DP/TP groups + torch.cuda.set_device(LOCAL_RANK) + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + if opts.tcp_init or NUM_NODES > 1: + if NUM_NODES > 1: + assert ( + "MASTER_ADDR" in os.environ + ), "Multi-node run requires MASTER_ADDR to be set in the environment." + MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname())) + MASTER_PORT = os.getenv("MASTER_PORT", "1234") + dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}" + if opts.bind_to_device or opts.bootstrap_backend == "nccl": + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + dist.init_process_group(**dist_init_kwargs) + nccl_world = dist.new_group(backend="nccl") + + def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False): + if debug and not opts.debug: + return + group_rank = dist.get_rank(group) + stream = sys.stderr if error else sys.stdout + if group_rank == src: + stream.write(f"[rank{WORLD_RANK}] {msg}{end}") + dist.barrier(group) + + dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") + + # Figure out process groups for tensor- and data-parallelism (if any) + if NUM_NODES > 1: + # Create a list of world ranks on this node + hostname = socket.gethostname() + ifname = os.getenv( + "NVTE_UB_SOCKET_IFNAME", + os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), + ) + + if ifname is not None: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + hostname = socket.inet_ntoa( + fcntl.ioctl( + s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) + )[20:24] + ) + except OSError as err: + raise OSError(f"Invalid network interface: {ifname}") from err + + hostnames = [None for _ in range(WORLD_SIZE)] + dist.all_gather_object(hostnames, hostname) + unique_hosts = [] + for host in hostnames: + if host not in unique_hosts: + unique_hosts.append(host) + assert len(unique_hosts) == NUM_NODES + + ranks_per_node_list = [[] for _ in range(NUM_NODES)] + self_node_idx = -1 + for i, host in enumerate(hostnames): + node_idx = unique_hosts.index(host) + 
ranks_per_node_list[node_idx].append(i) + if host == hostname: + self_node_idx = node_idx + assert self_node_idx >= 0 + self_node_ranks = ranks_per_node_list[self_node_idx] + + if opts.num_replicas > 1: + # Split node ranks into multiple replicas + assert len(self_node_ranks) % opts.num_replicas == 0 + tp_size = len(self_node_ranks) // opts.num_replicas + ranks_per_replica_list = [] + for node_ranks in ranks_per_node_list: + for i in range(opts.num_replicas): + start = i * tp_size + end = start + tp_size + ranks_per_replica_list.append(node_ranks[start:end]) + + self_replica_idx = -1 + for i, replica_ranks in enumerate(ranks_per_replica_list): + if WORLD_RANK in replica_ranks: + self_replica_idx = i + break + assert self_replica_idx >= 0 + + else: + # The entire node is the tensor-parallel group + ranks_per_replica_list = ranks_per_node_list + self_replica_idx = self_node_idx + + tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl") + ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32) + dp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" + ) + + else: + if opts.num_replicas > 1: + # Mixed data- and tensor-parallelism on a single node + # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions + all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu") + ranks_per_replica_tensor = all_ranks.reshape( + (opts.num_replicas, LOCAL_SIZE // opts.num_replicas) + ) + tp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.tolist(), backend="nccl" + ) + dp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" + ) + else: + dp_group = None + tp_group = nccl_world + + tp_rank = dist.get_rank(tp_group) + tp_size = dist.get_world_size(tp_group) + dist_print( + f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}", + group=tp_group, + ) + if dp_group is not None: + dp_rank = dist.get_rank(dp_group) + dist_print( + f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}", + group=dp_group, + ) + else: + dp_rank = 0 + + # Intialize userbuffers + hidden_size = opts.num_heads * opts.head_dim + batched_size = opts.seq_length * opts.batch_size + if not opts.no_comm_overlap: + te.module.base.initialize_ub( + [batched_size, hidden_size], + tp_size, + use_fp8=opts.fp8, + dtype=torch.bfloat16, + bootstrap_backend=opts.bootstrap_backend, + ub_cfgs=create_ub_cfgs(opts.ub_config, tp_size) + ) + # Initialize the fused LayerNorm + Multi-layer Perceptron module + torch.manual_seed(opts.seed + dp_rank) + torch.cuda.manual_seed(opts.seed + tp_rank) + layer_args, layer_kwargs, input_size = _get_layer_args(opts, tp_group, tp_size) + model = opts.layer_type(*layer_args, **layer_kwargs) + if dp_group is not None: + model = DistributedDataParallel(model, dim=1, process_group=dp_group) + + # Initialize optimizer with model parameters + optim = torch.optim.Adam(model.parameters(), lr=0.0001) + + # Fp8 recipe setup + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + + if opts.profile: + log_dir = os.path.join(opts.profile_dir, f"rank_{WORLD_RANK}") + os.makedirs(log_dir, exist_ok=True) + dist_print(f"Profiler traces will be saved to: {log_dir}", group=nccl_world) + + schedule = torch.profiler.schedule(wait=1, warmup=2, active=5, repeat=1) + + on_trace_ready = 
torch.profiler.tensorboard_trace_handler( + log_dir, worker_name=f"rank_{WORLD_RANK}" + ) + + profiler_activities = [ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + import time + + start_time = time.time() + with torch.profiler.profile( + schedule=schedule, + # record_shapes=True, + # with_stack=True, + # with_flops=True, + # with_modules=True, + on_trace_ready=on_trace_ready, + profile_memory=True, + activities=profiler_activities, + ) as prof: + dist_print("Starting training iterations...") + for i in range(opts.num_iters): + dist_print(f" Iter {i+1}", group=tp_group, debug=True) + + dist_print(" |-- Generate random input batch", group=tp_group, debug=True) + x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True) + + dist_print(" |-- Forward pass", group=tp_group, debug=True) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + y = model(x) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + dist_print(" |-- Compute loss", group=tp_group, debug=True) + loss = out.sum() + + dist_print(" |-- Backward pass", group=tp_group, debug=True) + loss.backward() + + dist_print(" |-- Optimizer step", group=tp_group, debug=True) + optim.step() + + prof.step() + torch.cuda.synchronize() + end_time = time.time() + total_wall_clock_time = end_time - start_time + print(f"Total Wall Clock Time: {total_wall_clock_time:.4f} seconds") + # total_flops = sum([item.flops for item in prof.key_averages()]) + # print(f"Total FLOPs: {total_flops}") + else: + dist_print("Starting training iterations...") + for i in range(opts.num_iters): + dist_print(f" Iter {i+1}", group=tp_group, debug=True) + + dist_print(" |-- Generate random input batch", group=tp_group, debug=True) + x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True) + + dist_print(" |-- Forward pass", group=tp_group, debug=True) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + y = model(x) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + dist_print(" |-- Compute loss", group=tp_group, debug=True) + loss = out.sum() + + dist_print(" |-- Backward pass", group=tp_group, debug=True) + loss.backward() + + dist_print(" |-- Optimizer step", group=tp_group, debug=True) + optim.step() + + + dist_print("Finished training!") + te.module.base.destroy_ub() + + dist_print("Destroying all process groups...", debug=True) + dist.destroy_process_group() + if opts.debug and WORLD_RANK == 0: + print("Exiting...\n", end="", flush=True) + + return 0 + + +if __name__ == "__main__": + sys.exit(_train(_parse_args())) \ No newline at end of file diff --git a/examples/pytorch/comm_gemm_overlap/ub_config.json b/examples/pytorch/comm_gemm_overlap/ub_config.json new file mode 100644 index 000000000..a26c7f9f1 --- /dev/null +++ b/examples/pytorch/comm_gemm_overlap/ub_config.json @@ -0,0 +1,15 @@ +{ + "qkv_fprop": "ring_exchange", + "fc1_fprop": "ring_exchange", + "fc2_dgrad": "ring_exchange", + "proj_wgrad": "ring_exchange", + "fc2_wgrad": "ring_exchange", + + + "proj_fprop": "ring_exchange", + "fc2_fprop": "ring_exchange", + + "qkv_dgrad": "ring_exchange", + "fc1_dgrad": "ring_exchange" + +} \ No newline at end of file diff --git a/setup.py b/setup.py index d201641f3..04a5befc8 100644 --- a/setup.py +++ b/setup.py @@ -71,6 +71,12 @@ def run(self): def setup_common_extension() -> 
CMakeExtension: """Setup CMake extension for common library""" cmake_flags = [] + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): + assert ( + os.getenv("MPI_HOME") is not None + ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" + cmake_flags.append("-DNVTE_UB_WITH_MPI=ON") + if rocm_build(): cmake_flags.append("-DUSE_ROCM=ON") if os.getenv("NVTE_AOTRITON_PATH"): @@ -101,11 +107,6 @@ def setup_common_extension() -> CMakeExtension: else: cmake_flags.append("-DUSE_ROCM=OFF") cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)] - if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): - assert ( - os.getenv("MPI_HOME") is not None - ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" - cmake_flags.append("-DNVTE_UB_WITH_MPI=ON") if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))): assert ( diff --git a/tests/cpp/operator/test_normalization_mxfp8.cu b/tests/cpp/operator/test_normalization_mxfp8.cu index f38a9e695..b3b94ea01 100644 --- a/tests/cpp/operator/test_normalization_mxfp8.cu +++ b/tests/cpp/operator/test_normalization_mxfp8.cu @@ -131,10 +131,10 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, DType wtype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input("input", std::vector{ N, H }, itype); + Tensor input("input2", std::vector{ N, H }, itype); Tensor z("z", std::vector{ N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING); - Tensor gamma("gamma", std::vector{ H }, wtype); - Tensor beta("beta", std::vector{ H }, wtype); + Tensor gamma("gamma2", std::vector{ H }, wtype); + Tensor beta("beta2", std::vector{ H }, wtype); Tensor mu("mu", std::vector{ N }, DType::kFloat32); Tensor rsigma("rsigma", std::vector{ N }, DType::kFloat32); Tensor workspace; diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 53c7a5e7c..f732e9a44 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -1,5 +1,7 @@ #!/usr/bin/python3 +# This file was modified for portability to AMDGPU +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -30,6 +32,11 @@ warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) +import transformer_engine.pytorch.cpp_extensions as tex +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +if not tex.device_supports_multicast(): + os.environ["UB_SKIPMC"] = "1" + class multi_module_model(torch.nn.Module): def __init__(self, module, num_layers, *args, **kwargs): diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 4523cb13e..d1c03de70 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -10,6 +12,9 @@ import transformer_engine.pytorch as te import transformer_engine.pytorch.cpp_extensions as tex +from transformer_engine.jax.cpp_extensions.misc import is_hip_extension + + if torch.cuda.device_count() < 2: pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.") @@ -179,7 +184,7 @@ def test_bulk_overlaps(comm_type, quantization, connections): Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ if connections == 8: - if torch.cuda.get_device_properties(0).major != 9: + if is_hip_extension() or torch.cuda.get_device_properties(0).major != 9: pytest.skip( "CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability" " 9.0 (HOPPER ARCH)." diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 8d5537368..69b2af81c 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -185,6 +185,7 @@ list(APPEND transformer_engine_cpp_sources list(APPEND transformer_engine_cuda_sources common.cu + comm_gemm_overlap/userbuffers/userbuffers.cu multi_tensor/adam.cu multi_tensor/compute_scale.cu multi_tensor/l2norm.cu @@ -242,8 +243,7 @@ if(USE_CUDA) fused_attn/utils.cu swizzle/swizzle.cu swizzle/swizzle_block_scaling.cu - recipe/nvfp4.cu - comm_gemm_overlap/userbuffers/userbuffers.cu) + recipe/nvfp4.cu) list(APPEND transformer_engine_cuda_arch_specific_sources gemm/cutlass_grouped_gemm.cu transpose/quantize_transpose_square_blockwise.cu @@ -356,11 +356,15 @@ target_include_directories(transformer_engine PRIVATE ${CUTLASS_TOOLS_INCLUDE_DIR}) # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI +# Changed option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) if (NVTE_UB_WITH_MPI) - find_package(MPI REQUIRED) - target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) - target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) + # OpenMPI C++ headers are deprecated -- flag unused w/ MPICH + add_definitions(-DOMPI_SKIP_MPICXX) + + target_include_directories(transformer_engine PRIVATE "$ENV{MPI_HOME}/include") + target_link_directories(transformer_engine PRIVATE "$ENV{MPI_HOME}/lib") + target_link_libraries(transformer_engine PUBLIC mpi) target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI) endif() @@ -543,7 +547,7 @@ endif() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads ${BUILD_THREADS_PER_JOB}") message(STATUS "Threads per parallel build job: ${BUILD_THREADS_PER_JOB}") else() - set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3") + set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3 -fopenmp") set(HIP_HCC_FLAGS "${CMAKE_HIP_FLAGS} -mavx2 -mf16c -mfma -std=c++17") # Ask hcc to generate device code during compilation so we can use # host linker to link. 
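Note on the NVTE_UB_WITH_MPI build path above: when the option is ON, the common library now
picks up MPI headers from "$ENV{MPI_HOME}/include" and links the `mpi` library from
"$ENV{MPI_HOME}/lib" directly, rather than relying on CMake's find_package(MPI) and the
MPI::MPI_CXX target, so MPI_HOME has to point at a working MPI installation at build time.
A minimal sketch of such a build, assuming an OpenMPI install under /opt/ompi and a
source-tree pip build (both are assumptions, not taken from this patch):

    NVTE_UB_WITH_MPI=1 MPI_HOME=/opt/ompi pip install -v .
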
diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 2a3c64e8d..00d8dc322 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -21,6 +21,12 @@ #define HALF_BYTES 2 #define UB_MAX_SM 32 +#ifdef __HIP_PLATFORM_AMD__ +#define half_dtype hip_bfloat16 +#define __nv_fp8_e5m2 te_hip_fp8_e5m2 +#define __nv_fp8_e4m3 te_hip_fp8_e4m3 +#endif + using namespace std::placeholders; namespace transformer_engine { @@ -83,7 +89,7 @@ void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_stream _gemm_priority = gemm_priority; _comm_priority = comm_priority; } - for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { + for (int i = 0; i < std::max(num_max_streams, num_splits); i++) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority)); _stream_compute.push_back(std::move(stream)); @@ -353,7 +359,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0, comm_elements, _ub_comm, _stream_comm, - (cudaEvent_t)_comm_launch_event); + (cudaEvent_t)_comm_launch_event); } else { reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, (cudaEvent_t)_comm_launch_event); @@ -471,7 +477,7 @@ void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); -} // split_overlap_rs +} // atomic_gemm_overlap_rs /* ** Split FPROP GEMM + ReduceScatter @@ -619,6 +625,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, cudaStream_t stream_main) { + int comm_bytes = _ubuf.bytes(); int comm_bytes_per_rank = comm_bytes / _tp_size; @@ -717,10 +724,23 @@ void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DTy NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); _stream_send.push_back(std::move(stream)); } + for (int i = 0; i < 7; i++) { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); + l_stream_send.push_back(std::move(stream)); + } + for (int i = 0; i < 7; i++) { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); + l_stream_recv.push_back(std::move(stream)); + } NVTE_CHECK_CUDA( cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); + for (int i = 0; i < 7; i++) { + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&l_stop_recv[i], 0)); + } } CommOverlapP2PBase::~CommOverlapP2PBase() { @@ -730,6 +750,43 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { for (size_t i = 0; i < _stream_send.size(); i++) { cudaStreamDestroy(_stream_send[i]); } + for (int i = 0; i < 7; i++) { + cudaStreamDestroy(l_stream_recv[i]); + cudaStreamDestroy(l_stream_send[i]); + cudaEventDestroy(l_stop_recv[i]); 
+ } +} + +void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, + bool local_chunk, bool rowwise) { + // Check element size + const size_t element_size = source.element_size(); + NVTE_CHECK(_ubuf.element_size() == element_size, + "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", + "(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), + " bytes)"); + + // Input data + const size_t source_size = source.numel(); + const void *src_ptr = (rowwise) ? source.dptr() : source.columnwise_dptr(); + + // Userbuffers data + void *dst_ptr; + if (local_chunk) { + NVTE_CHECK(_ubufs[_tp_id].numel() == source_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + dst_ptr = _ubufs[_tp_id].dptr(); + } else { + NVTE_CHECK(_ubuf.numel() == source_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")"); + dst_ptr = _ubuf.dptr(); + } + + // Copy data + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size, + cudaMemcpyDeviceToDevice, stream)); } void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, diff --git a/transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp new file mode 100644 index 000000000..ea05ea95f --- /dev/null +++ b/transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp @@ -0,0 +1,664 @@ +/************************************************************************* + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include + +#include "common/common.h" +#include "common/util/cuda_driver.h" +#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" +#include "common/util/system.h" +#include "userbuffers/userbuffers.h" + +namespace transformer_engine { +#if 0 +// Recursive doubling AG code for future reference +void CommOverlapP2PBase::rocm_split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? A.size(0) : A.size(1); + const size_t k = (transa) ? 
A.size(1) : A.size(0); + const size_t n_chunk = _ubufs[0].size(0); + const int comm_bytes = _ubufs[0].bytes(); + const bool do_gelu = pre_gelu_out.numel() > 0; + const size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Check B copy sizing + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ", + _ubuf.numel(), " elements but got ", B_copy.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(), + "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8, + "-bit data type but got ", B_copy.element_size() * 8, "-bit"); + } + + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + + int steps = 31 - __builtin_clz(_tp_size); + + // Chunk dims + std::vector input_b_chunk_shape = + (transb ? std::vector{k, n_chunk} : std::vector{n_chunk, k}); + std::vector output_chunk_shape = {n_chunk, m}; + size_t input_b_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; + + // GEMM + auto input_b_chunk = + get_buffer_chunk_like(B, input_b_chunk_size * _tp_id, input_b_chunk_shape); + auto output_chunk = + get_tensor_chunk(D, output_chunk_size * _tp_id, output_chunk_shape); + auto aux_chunk = + (do_gelu) + ? get_tensor_chunk(pre_gelu_out, output_chunk_size * _tp_id, {n_chunk, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (_tp_id % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[_tp_id % _stream_compute.size()]); + + std::vector owned_chunks; + owned_chunks.reserve(_tp_size); + owned_chunks.push_back(_tp_id); + size_t offset = 1; + + for (int step = 0; step < steps; step++) { + int send_rank = (_tp_id + offset) % _tp_size; + int recv_rank = (_tp_id - offset + _tp_size) % _tp_size; + + for (int i = 0; i < owned_chunks.size(); i++) { + size_t send_offset = owned_chunks[i] * comm_bytes; + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, + comm_bytes, _ub_comm, send_rank, _stream_send[i % _stream_send.size()]); + } + + std::vector new_chunks; + for (size_t i = 0; i < owned_chunks.size(); i++) { + size_t new_chunk_id = (recv_rank + i * offset) % _tp_size; + if (new_chunk_id >= _tp_size || + std::find(owned_chunks.begin(), owned_chunks.end(), new_chunk_id) != owned_chunks.end()) continue; + size_t recv_offset = new_chunk_id * comm_bytes; + size_t stream_id = new_chunks.size() % _stream_compute.size(); + + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, + comm_bytes, _ub_comm, recv_rank, _stream_recv); + + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[stream_id], _stop_recv, 0)); + + auto input_b_chunk = get_buffer_chunk_like(B, input_b_chunk_size * new_chunk_id, input_b_chunk_shape); + output_chunk = get_tensor_chunk(D, output_chunk_size * new_chunk_id, output_chunk_shape); + aux_chunk = (do_gelu) ? 
get_tensor_chunk(pre_gelu_out, output_chunk_size * new_chunk_id, {n_chunk, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + workspace_chunk = get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[stream_id]); + + new_chunks.push_back(new_chunk_id); + } + owned_chunks.insert(owned_chunks.end(), new_chunks.begin(), new_chunks.end()); + offset <<= 1; + } + + if (B_copy.numel() > 0) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), + cudaMemcpyDeviceToDevice, _stream_send[0])); + } + + _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); +} // rocm_split_overlap_ag_rd +#endif // #if 0 + +// TODO: Generalize for TP other than 2,4,8 using Walecki construction +constexpr int tp_next_8[7][8] = { + {1, 5, 4, 6, 3, 2, 7, 0}, + {2, 6, 1, 0, 5, 7, 4, 3}, + {3, 7, 0, 5, 6, 4, 1, 2}, + {4, 3, 6, 2, 7, 0, 5, 1}, + {5, 2, 7, 4, 1, 3, 0, 6}, + {6, 0, 5, 7, 2, 1, 3, 4}, + {7, 4, 3, 1, 0, 6, 2, 5}, +}; + +constexpr int tp_prev_8[7][8] = { + {7, 0, 5, 4, 2, 1, 3, 6}, + {3, 2, 0, 7, 6, 4, 1, 5}, + {2, 6, 7, 0, 5, 3, 4, 1}, + {5, 7, 3, 1, 0, 6, 2, 4}, + {6, 4, 1, 5, 3, 0, 7, 2}, + {1, 5, 4, 6, 7, 2, 0, 3}, + {4, 3, 6, 2, 1, 7, 5, 0}, +}; + +// No full Hamiltonian decomposition for TP=4 TP=6 (Tillson’s Theorem) +// Further optimization for these cases may be multiring w/ RD for example +constexpr int tp_next_4[2][4] = { + {1, 2, 3, 0}, + {3, 0, 1, 2}, +}; + +constexpr int tp_prev_4[2][4] = { + {3, 0, 1, 2}, + {1, 2, 3, 0} +}; + +template +constexpr bool multiring_hamiltonian_check(const int (&next)[NUM_RINGS][TP_SIZE]) { + for (int r = 0; r < NUM_RINGS; ++r) { + bool visited[TP_SIZE] = {}; + + int curr = 0; + for (int step = 0; step < TP_SIZE; ++step) { + if (visited[curr]) return false; + visited[curr] = true; + curr = next[r][curr]; + } + + if (curr != 0) return false; + + for (int i = 0; i < TP_SIZE; ++i) { + if (!visited[i]) return false; + } + } + return true; +} + +template +constexpr bool rings_are_unique( + const int next[NUM_RINGS][TP_SIZE]) +{ + for (int src = 0; src < TP_SIZE; ++src) { + bool seen[TP_SIZE] = {}; + + for (int r = 0; r < NUM_RINGS; ++r) { + int dst = next[r][src]; + + // No self-send + if (dst == src) + return false; + + if (seen[dst]) + return false; + + seen[dst] = true; + } + } + return true; +} + +template +constexpr bool prev_is_inverse_of_next( + const int next[NUM_RINGS][TP_SIZE], + const int prev[NUM_RINGS][TP_SIZE]) +{ + for (int r = 0; r < NUM_RINGS; ++r) { + for (int i = 0; i < TP_SIZE; ++i) { + int n = next[r][i]; + int p = prev[r][i]; + + if (n < 0 || n >= TP_SIZE) return false; + if (p < 0 || p >= TP_SIZE) return false; + + if (prev[r][n] != i) return false; + if (next[r][p] != i) return false; + } + } + return true; +} + +static_assert(multiring_hamiltonian_check<2,4>(tp_next_4), "Non-Hamiltonian ring present!"); 
+static_assert(multiring_hamiltonian_check<7,8>(tp_next_8), "Non-Hamiltonian ring present!"); + +static_assert(rings_are_unique<2,4>(tp_next_4), "Rings overlap"); +static_assert(rings_are_unique<7,8>(tp_next_8), "Rings overlap"); + +static_assert(prev_is_inverse_of_next<2,4>(tp_next_4, tp_prev_4), "tp_prev_4 is not inverse of tp_next_4"); +static_assert(prev_is_inverse_of_next<7,8>(tp_next_8, tp_prev_8), "tp_prev_8 is not inverse of tp_next_8"); + +// TODO: Introduce HIPGraphs for dependency management. +void CommOverlapP2PBase::rocm_split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? A.size(0) : A.size(1); + const size_t k = (transa) ? A.size(1) : A.size(0); + const size_t n_chunk = _ubufs[0].size(0); + // Get communication and GEMM output chunk sizes + const int comm_bytes = _ubufs[0].bytes(); + const bool do_gelu = pre_gelu_out.numel() > 0; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + const int max_rings = (_tp_size == 4) ? 2 : + (_tp_size == 6) ? 4 : + _tp_size - 1; + const int num_rings = std::min({ + transformer_engine::getenv("GPU_MAX_HW_QUEUES", 4), + _tp_size - 1, + max_rings + }); + + const int *next, *prev; + switch (_tp_size) { + case 8: + next = reinterpret_cast(tp_next_8); + prev = reinterpret_cast(tp_prev_8); + break; + case 4: + next = reinterpret_cast(tp_next_4); + prev = reinterpret_cast(tp_prev_4); + break; + case 2: + return this->split_overlap_ag(A, transa, B, transb, D, bias, pre_gelu_out, workspace, grad, + accumulate, use_split_accumulator, B_copy, stream_main); + default: + NVTE_ERROR("ROCm supports TP sizes of 2, 4, 8 only."); + } + + const int alignment = 256; + const int base_slice_bytes = (comm_bytes / num_rings) & ~(alignment - 1); + const int total_base_bytes = base_slice_bytes * num_rings; + const int remainder_bytes = comm_bytes - total_base_bytes; + + const size_t base_n_slice = n_chunk / num_rings; + const size_t remainder_n = n_chunk - (base_n_slice * num_rings); + + // Check B copy sizing + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size()); + } + + auto get_slice_info = [&](int ring) -> std::pair { + size_t offset = ring * base_slice_bytes; + int size = base_slice_bytes; + if (ring == num_rings - 1) + size += remainder_bytes; + return {offset, size}; + }; + + auto get_slice_n = [&](int ring) -> size_t { + return base_n_slice + (ring == num_rings - 1 ? 
remainder_n : 0); + }; + + auto get_chunk_id = [&](int ring, int step) { + int owner = _tp_id; + for (int s = 0; s < step; ++s) + owner = prev[ring * _tp_size + owner]; + return owner; + }; + + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + + for (int r = 0; r < num_rings; ++r) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[r], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_recv[r], _start_compute, 0)); + } + + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + + const int total_slices = _tp_size * num_rings; + std::vector slice_events(total_slices); + + for (int i = 0; i < total_slices; i++) { + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&slice_events[i], cudaEventDisableTiming)); + } + + auto get_event = [&](int chunk, int ring) { + return slice_events[chunk * num_rings + ring]; + }; + + for (int r = 0; r < num_rings; r++) { + NVTE_CHECK_CUDA(cudaEventRecord(get_event(_tp_id, r), stream_main)); + } + + auto get_slice_offset = [&](int chunk, int ring) { + auto [ring_offset, _] = get_slice_info(ring); + return chunk * comm_bytes + ring_offset; + }; + + auto launch_slice_gemm = [&](int ring_id, int step) { + int chunk_id = get_chunk_id(ring_id, step); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[ring_id], + get_event(chunk_id, ring_id), 0)); + size_t n_slice = get_slice_n(ring_id); + + size_t input_b_slice_elems = n_slice * k; + size_t output_slice_elems = n_slice * m; + + size_t b_elem_offset = chunk_id * n_chunk * k; + size_t d_elem_offset = chunk_id * n_chunk * m; + + for (int r = 0; r < ring_id; r++) { + size_t prev_n = get_slice_n(r); + b_elem_offset += prev_n * k; + d_elem_offset += prev_n * m; + } + + std::vector input_b_slice_shape = + (transb ? std::vector{k, n_slice} : std::vector{n_slice, k}); + std::vector output_slice_shape = {n_slice, m}; + + auto input_b_slice = get_buffer_chunk_like(B, b_elem_offset, input_b_slice_shape); + auto output_slice = get_tensor_chunk(D, d_elem_offset, output_slice_shape); + + auto aux_slice = (do_gelu) + ? 
get_tensor_chunk(pre_gelu_out, d_elem_offset, {n_slice, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + + auto workspace_chunk = get_tensor_chunk(workspace, ring_id * workspace_size_chunk, + {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_slice.data(), output_slice.data(), bias.data(), + aux_slice.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, + _stream_compute[ring_id]); + }; + + for (int step = 0; step < _tp_size; step++) { + for (int r = 0; r < num_rings; r++) { + if (step < _tp_size - 1) { + int curr_chunk_id = get_chunk_id(r, step); + int next_recv_chunk_id = get_chunk_id(r, step + 1); + + int next_rank = next[r * _tp_size + _tp_id]; + int prev_rank = prev[r * _tp_size + _tp_id]; + + size_t send_off = get_slice_offset(curr_chunk_id, r); + size_t recv_off = get_slice_offset(next_recv_chunk_id, r); + + auto [_, slice_bytes] = get_slice_info(r); + + if (step > 0) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[r], get_event(curr_chunk_id, r), 0)); + } + + { + int peerlocal = next_rank % _ub_comm->nvsize; + void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, _ub_comm, _ub_reg, r); + void *srcptr = reinterpret_cast(_ub_comm->mem_ptr[_ub_reg]) + send_off; + void *dstptr = reinterpret_cast(_ub_comm->peer_ptr[_ub_reg][peerlocal]) + send_off; + + NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, slice_bytes, cudaMemcpyDeviceToDevice, l_stream_send[r])); + uint32_t signal_val = step + 1; + hipStreamWriteValue32(l_stream_send[r], flagptr, signal_val, 0); + } + + { + int peerlocal = prev_rank % _ub_comm->nvsize; + void *flagptr = GET_RECV_PTR_BY_INDEX(prev_rank, _ub_comm, _ub_reg, r); + + uint32_t signal_val = step + 1; + hipStreamWaitValue32(l_stream_recv[r], flagptr, signal_val, hipStreamWaitValueGte, 0xFFFFFFFF); + } + + NVTE_CHECK_CUDA(cudaEventRecord(get_event(next_recv_chunk_id, r), l_stream_recv[r])); + } + } + + for (int r = 0; r < num_rings; r++) { + launch_slice_gemm(r, step); + } + } + + if (B_copy.numel() > 0) { + for (int r = 0; r < num_rings; r++) { + int last_chunk = get_chunk_id(r, _tp_size - 1); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[0], get_event(last_chunk, r), 0)); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), + cudaMemcpyDeviceToDevice, l_stream_send[0])); + } + + _ub_comm->sms = ori_sms; + + for (auto& s : _stream_compute) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, s)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + + for (int r = 0; r < num_rings; r++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, l_stream_send[r])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, l_stream_recv[r])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); + } + + for (auto& ev : slice_events) { + NVTE_CHECK_CUDA(cudaEventDestroy(ev)); + } +} // CommOverlapP2PBase::rocm_split_overlap_ag + +void CommOverlapP2PBase::rocm_split_overlap_rs(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // GEMM dimensions + const size_t m = transa ? 
A.size(0) : A.size(1); + const size_t k = transa ? A.size(1) : A.size(0); + const size_t n_chunk = _ubufs[0].size(0); + const int comm_bytes = _ubufs[0].bytes(); + + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + const int max_rings = (_tp_size == 4) ? 2 : + (_tp_size == 6) ? 4 : + _tp_size - 1; + + const int num_rings = std::min({ + transformer_engine::getenv("GPU_MAX_HW_QUEUES", 4), + _tp_size - 1, + max_rings + }); + + const int *next, *prev; + switch (_tp_size) { + case 8: + next = reinterpret_cast(tp_next_8); + prev = reinterpret_cast(tp_prev_8); + break; + case 4: + next = reinterpret_cast(tp_next_4); + prev = reinterpret_cast(tp_prev_4); + break; + case 2: + return this->split_overlap_rs(A, transa, B, transb, D, bias, pre_gelu_out, workspace, grad, + accumulate, use_split_accumulator, rs_output, stream_main); + default: + NVTE_ERROR("ROCm supports TP sizes of 2, 4, 8 only."); + } + + const int alignment = 256; + const int base_slice_bytes = (comm_bytes / num_rings) & ~(alignment - 1); + const int total_base_bytes = base_slice_bytes * num_rings; + const int remainder_bytes = comm_bytes - total_base_bytes; + + const size_t base_n_slice = n_chunk / num_rings; + const size_t remainder_n = n_chunk - base_n_slice * num_rings; + + auto get_slice_info = [&](int ring) -> std::pair { + size_t offset = ring * base_slice_bytes; + int size = base_slice_bytes; + if (ring == num_rings - 1) + size += remainder_bytes; + return {offset, size}; + }; + + auto get_slice_n = [&](int ring) -> size_t { + return base_n_slice + (ring == num_rings - 1 ? remainder_n : 0); + }; + + auto get_chunk_id = [&](int ring, int step) { + int owner = _tp_id; + for (int s = 0; s < step; ++s) + owner = prev[ring * _tp_size + owner]; + return owner; + }; + + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + for (int r = 0; r < num_rings; r++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[r], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_recv[r], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[r], _start_compute, 0)); + } + + for (auto &s : _stream_compute) + NVTE_CHECK_CUDA(cudaStreamWaitEvent(s, _start_compute, 0)); + + const int total_slices = _tp_size * num_rings; + std::vector slice_events(total_slices); + + for (auto &e : slice_events) + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&e, cudaEventDisableTiming)); + + auto get_event = [&](int chunk, int ring) { + return slice_events[chunk * num_rings + ring]; + }; + + for (int r = 0; r < num_rings; ++r) + NVTE_CHECK_CUDA(cudaEventRecord(get_event(_tp_id, r), stream_main)); + + auto get_slice_offset = [&](int chunk, int ring) { + auto [ring_offset, _] = get_slice_info(ring); + return chunk * comm_bytes + ring_offset; + }; + + auto launch_slice_gemm = [&](int chunk_id, int ring_id, int step) { + size_t n_slice = get_slice_n(ring_id); + + size_t b_elem_offset = chunk_id * n_chunk * k; + size_t d_elem_offset = chunk_id * n_chunk * m; + + for (int r = 0; r < ring_id; ++r) { + b_elem_offset += get_slice_n(r) * k; + d_elem_offset += get_slice_n(r) * m; + } + + auto input_b_slice = get_tensor_chunk(B, b_elem_offset, transb ? 
std::vector{k, n_slice} : std::vector{n_slice, k}); + auto output_slice = get_tensor_chunk(D, d_elem_offset, {n_slice, m}); // D acts as the accumulation buffer + auto workspace_chunk = get_tensor_chunk(workspace, ring_id * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_slice.data(), output_slice.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, + _stream_compute[ring_id]); + NVTE_CHECK_CUDA(cudaEventRecord(get_event(chunk_id, ring_id), _stream_compute[ring_id])); + }; + + for (int step = 0; step < _tp_size; ++step) { + for (int r = 0; r < num_rings; ++r) { + int curr_chunk = get_chunk_id(r, step); + launch_slice_gemm(curr_chunk, r, step); + } + + if (step > 0) { + int prev_step = step - 1; + + for (int r = 0; r < num_rings; ++r) { + int chunk_to_send = get_chunk_id(r, prev_step); + + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[r], get_event(chunk_to_send, r), 0)); + + size_t send_off = get_slice_offset(chunk_to_send, r); + auto [_, slice_bytes] = get_slice_info(r); + + int next_rank = next[r * _tp_size + _tp_id]; + int prev_rank = prev[r * _tp_size + _tp_id]; + + { + int peerlocal = next_rank % _ub_comm->nvsize; + void *srcptr = reinterpret_cast(_ub_comm->mem_ptr[_ub_reg]) + send_off; + void *dstptr = reinterpret_cast(_ub_comm->peer_ptr[_ub_reg][peerlocal]) + send_off; + void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, _ub_comm, _ub_reg, r); + + NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, slice_bytes, cudaMemcpyDeviceToDevice, l_stream_send[r])); + uint32_t signal_val = prev_step + 1; // Use step count as signal + hipStreamWriteValue32(l_stream_send[r], flagptr, signal_val, 0); + } + + { + int peerlocal = prev_rank % _ub_comm->nvsize; + void *flagptr = GET_RECV_PTR_BY_INDEX(prev_rank, _ub_comm, _ub_reg, r); + uint32_t signal_val = prev_step + 1; + hipStreamWaitValue32(l_stream_recv[r], flagptr, signal_val, hipStreamWaitValueGte, 0xFFFFFFFF); + } + } + } + } + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + for (int r = 0; r < num_rings; r++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, l_stream_send[r])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, l_stream_recv[r])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); + } + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, + _ubufs[0].numel(), stream_main);); + } else { + reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); + } + + _ub_comm->sms = ori_sms; + + // Cleanup events + for (auto &e : slice_events) NVTE_CHECK_CUDA(cudaEventDestroy(e)); +} // rocm_split_overlap_rs + +} // namespace transformer_engine diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 9c597be30..c118d3281 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ 
b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -357,12 +357,12 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, NVTE_CHECK_CUDA(cudaDeviceSynchronize()); register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true); NVTE_CHECK_CUDA( - cudaMalloc(reinterpret_cast(&(*comm)->send_id), (*comm)->nranks * sizeof(int))); + cudaMalloc(reinterpret_cast(&(*comm)->send_id), (*comm)->nranks * NVTE_MAX_RINGS * sizeof(int))); NVTE_CHECK_CUDA(cudaMalloc(reinterpret_cast(&(*comm)->recv_id), - NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); - NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int))); + NVTE_MAX_REGIONS * (*comm)->nranks * NVTE_MAX_RINGS * sizeof(int))); + NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * NVTE_MAX_RINGS * sizeof(int))); NVTE_CHECK_CUDA( - cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); + cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * NVTE_MAX_RINGS * sizeof(int))); (*comm)->sms = 16; (*comm)->threads = 1024; @@ -375,8 +375,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, cudaMalloc(reinterpret_cast(&(*comm)->flags_baseptr), 2 * GPU_PAGE_SIZE)); NVTE_CHECK_CUDA(cudaMemset((*comm)->flags_baseptr, 0, 2 * GPU_PAGE_SIZE)); (*comm)->flags = reinterpret_cast( +#ifdef __HIP_PLATFORM_AMD__ + (reinterpret_cast((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); +#else ((CUdeviceptr)(*comm)->flags_baseptr + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); - +#endif using namespace std; sched_param param; diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 3d8848d95..fe4514b15 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -5,6 +5,15 @@ ************************************************************************/ #include + +#ifdef __HIP_PLATFORM_AMD__ +#include +#include +#include "amd_detail/hip_float8.h" +#define half_dtype hip_bfloat16 +#define __nv_fp8_e5m2 te_hip_fp8_e5m2 +#define __nv_fp8_e4m3 te_hip_fp8_e4m3 +#else #include #include @@ -13,6 +22,7 @@ #else #define half_dtype half #endif +#endif #include #include @@ -24,6 +34,7 @@ #define MAX_THREADS 1024 +#if !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) #define ATOMIC_CONSUMER(chunk) \ if (counters) { \ if (threadIdx.x == 0 && blockIdx.x == 0) { \ @@ -34,6 +45,18 @@ } \ if (blockIdx.x == 0) __syncthreads(); \ } +#else +#define ATOMIC_CONSUMER(chunk) \ + if (counters) { \ + if (threadIdx.x == 0 && blockIdx.x == 0) { \ + while (0 != (atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \ + } \ + ((unsigned int *)counters)[chunk] = 1; \ + __threadfence(); \ + } \ + if (blockIdx.x == 0) __syncthreads(); \ + } +#endif #define ATOMIC_PRODUCER(chunk) \ if (counters) { \ @@ -62,7 +85,7 @@ printf("[%s:%s:%d] " message "\n", FILENAME(__FILE__), __FUNCTION__, __LINE__, __VA_ARGS__) // Report and error on timeout -#define CHECK_TIMEOUT(t, timeout) ((clock64() - (t)) > timeout) +#define CHECK_TIMEOUT(t, timeout) (((uint64_t)clock64() - (t)) > timeout) template __global__ void __launch_bounds__(MAX_THREADS) @@ -132,7 +155,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) __threadfence_system(); + if (threadIdx.x == 0) __threadfence(); 
__syncthreads(); if (threadIdx.x < RANKS) { @@ -477,7 +500,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) __threadfence_system(); + if (threadIdx.x == 0) __threadfence(); __syncthreads(); if (threadIdx.x < RANKS) { @@ -708,7 +731,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) __threadfence_system(); + if (threadIdx.x == 0) __threadfence(); __syncthreads(); __shared__ int lastSM; @@ -1025,7 +1048,11 @@ __global__ void __launch_bounds__(MAX_THREADS) // reset counter for next producer. ((unsigned int *)counters)[0] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } __syncthreads(); @@ -1116,7 +1143,11 @@ __global__ void __launch_bounds__(MAX_THREADS) // reset counter for next producer. ((unsigned int *)counters)[chunk_i] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } __syncthreads(); @@ -1329,7 +1360,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) __threadfence_system(); + if (threadIdx.x == 0) __threadfence(); __syncthreads(); __shared__ int lastSM; @@ -1357,6 +1388,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } } // fp16 inplace allgather kernel (Volta,Hopper) +#ifndef __HIP_PLATFORM_AMD__ #define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ cudaLaunchAttribute attribute_ub[2]; \ @@ -1367,6 +1399,15 @@ __global__ void __launch_bounds__(MAX_THREADS) attribute_ub[0].id = cudaLaunchAttributeCooperative; \ cfg.attrs = attribute_ub; \ cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1; +#else +#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ + cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + cudaLaunchAttribute attribute_ub[1]; \ + attribute_ub[0].id = cudaLaunchAttributeCooperative; \ + attribute_ub[0].value.cooperative = 1; \ + cfg.attrs = attribute_ub; \ + cfg.numAttrs = 1; +#endif #if (CUDART_VERSION >= 12030) #define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \ @@ -1378,6 +1419,11 @@ __global__ void __launch_bounds__(MAX_THREADS) #define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2 #endif +#ifdef __HIP_PLATFORM_AMD__ +#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \ + cudaLaunchConfig_t cfg; \ + NVTE_ERROR("SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT is not supported for AMD GPUs") +#else #define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \ cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ cudaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {}; \ @@ -1389,6 +1435,7 @@ __global__ void __launch_bounds__(MAX_THREADS) attribute_ub[0].id = cudaLaunchAttributeCooperative; \ cfg.attrs = attribute_ub; \ cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH; +#endif #define callranks_ag(x) \ if (ar_nvsize == x) { \ @@ -2049,7 +2096,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); if (threadIdx.x) return; - __threadfence_system(); + __threadfence(); atomicAdd_system(flagptr, 1); // otherwise need local SM sync before sending flag } else { // 0 bytes and 1 SM only @@ -2111,7 +2158,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); if (threadIdx.x) return; - __threadfence_system(); + __threadfence(); atomicAdd_system(send_flagptr, 1); // otherwise need local SM sync before sending flag } else { // 0 bytes and 1 
SM only @@ -2169,7 +2216,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); if (threadIdx.x) return; - __threadfence_system(); + __threadfence(); atomicAdd_system(send_flagptr, 1); // otherwise need local SM sync before sending flag } else { // 0 bytes and 1 SM only @@ -2196,7 +2243,11 @@ __global__ void __launch_bounds__(MAX_THREADS) // Decrement atomic val to signal current output tile finish if (counters) { ((unsigned int *)counters)[0] = 0; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } } @@ -2236,7 +2287,7 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat } __syncthreads(); if (!threadIdx.x) { - __threadfence_system(); + __threadfence(); atomicAdd_system(send_flagptr, 1); // otherwise need local SM sync before sending flag } @@ -2267,7 +2318,11 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat // Decrement atomic val to signal current output tile finish if (counters) { ((unsigned int *)counters)[recv_chunk_id /*chunk_i+1*/] = 0; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } @@ -2284,6 +2339,7 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat // Return TRUE if two ranks share the same NV domain #define INTRANODE(peer) ((peer / comm->nvsize) == (comm->myrank / comm->nvsize)) +#ifndef __HIP_PLATFORM_AMD__ // Moved to header for visibility // Index corresponds to the type of flag: // 0 - Send index counter // 1 - CE start index counter @@ -2303,12 +2359,13 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + (recv_peer) * NVTE_MAX_REGIONS + (dsth) + \ (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ sizeof(int))) +#endif // #ifndef __HIP_PLATFORM_AMD__ void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes, communicator *comm, - const int peer, cudaStream_t stream) { + const int peer, cudaStream_t stream, int ring_id) { int peerlocal = peer % comm->nvsize; - void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 0); + void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, ring_id); // void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 1); // void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 2); bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); @@ -2317,7 +2374,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return; if (comm->push == 0) { - kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]), + kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer * NVTE_MAX_RINGS + ring_id]), reinterpret_cast(flagptr)); NVTE_CHECK_CUDA(cudaGetLastError()); } else { @@ -2330,7 +2387,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); } SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); - int *arg1 = &comm->send_id[peer], *arg2 = reinterpret_cast(flagptr); + int *arg1 = &comm->send_id[peer * NVTE_MAX_RINGS + ring_id], *arg2 = reinterpret_cast(flagptr); int4 *arg3 = reinterpret_cast(srcptr), *arg4 = reinterpret_cast(dstptr); int arg5 = signalonly ? 
0 : bytes / 16; void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), @@ -2500,9 +2557,9 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes, communicator *comm, - const int peer, cudaStream_t stream) { + const int peer, cudaStream_t stream, int ring_id) { int peerlocal = peer % comm->nvsize; - void *flagptr = GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 0); + void *flagptr = GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, ring_id); bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); assert(INTRANODE(peer)); @@ -2514,12 +2571,12 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds kuserbuffers_pullrecv<<sms, signalonly ? 1 : 1024, 0, stream>>>( comm->myrank, peer, comm->nvrank, peerlocal, - &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), reinterpret_cast(flagptr), + &(comm->recv_id[(peer * NVTE_MAX_REGIONS + dsthandler) * NVTE_MAX_RINGS + ring_id]), reinterpret_cast(flagptr), reinterpret_cast(srcptr), reinterpret_cast(dstptr), signalonly ? 0 : bytes / 16, comm->ub_timeout); NVTE_CHECK_CUDA(cudaGetLastError()); if (!signalonly) { - kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler])); + kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[(peer * NVTE_MAX_REGIONS + dsthandler) * NVTE_MAX_RINGS + ring_id])); NVTE_CHECK_CUDA(cudaGetLastError()); } if (comm->use_ce) { @@ -2528,7 +2585,7 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds } else { kuserbuffers_pushrecv<<<1, 1, 0, stream>>>( comm->myrank, peer, comm->nvrank, peerlocal, - &comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler], reinterpret_cast(flagptr), + &comm->recv_id[(peer * NVTE_MAX_REGIONS + dsthandler) * NVTE_MAX_RINGS + ring_id], reinterpret_cast(flagptr), signalonly || comm->sms, comm->ub_timeout, reinterpret_cast(0 ? // temporary disable GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 1) @@ -2576,7 +2633,11 @@ static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) { // COMM kernel need to explicitely flash gmem. 
// GEMM kernel already executed, and can not see gmem // change without COMM kernel explicitely make change +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } // consumer @@ -2586,7 +2647,11 @@ static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) { while (0 != (atomicCAS((unsigned int *)atomic_ptr + chunk_i, 0, 0))) { } ((unsigned int *)atomic_ptr)[chunk_i] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } @@ -2598,7 +2663,11 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i while (0 != (atomicCAS((unsigned int *)atomic_ptr + i, 0, 0))) { } ((unsigned int *)atomic_ptr)[i] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index c8d7c8731..dfcc13450 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -35,6 +35,7 @@ using ExtBarrierOp = std::function; #define NVTE_LAUNCH_GPU 1 #define NVTE_LAUNCH_CPU 2 #define NVTE_MAX_NVLINK 32 +#define NVTE_MAX_RINGS 7 #define NVTE_UB_MEM_UC_CONTIG 1 #define NVTE_UB_MEM_MC_CREATED 2 @@ -63,6 +64,28 @@ using ExtBarrierOp = std::function; #define NVTE_HF_NVREDUCEDONE (userbuffers_op_types + 3) #define NVTE_MAX_SHARP 16 +#ifdef __HIP_PLATFORM_AMD__ // Moved to header for visibility +// Index corresponds to the type of flag: +// 0 - Send index counter +// 1 - CE start index counter +// 2 - CE end index counter +#define GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsth, index) \ + ((reinterpret_cast((comm)->peer_ptr[0][(peerlocal)])) + \ + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + (comm)->myrank * NVTE_MAX_REGIONS + (dsth) + \ + (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ + sizeof(int))) + +// Index corresponds to the type of flag: +// 0 - Receive index counter +// 1 - CE start index counter +// 2 - CE end index counter +#define GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsth, index) \ + ((reinterpret_cast((comm)->mem_ptr[0])) + \ + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + (recv_peer) * NVTE_MAX_REGIONS + (dsth) + \ + (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ + sizeof(int))) +#endif // #ifdef __HIP_PLATFORM_AMD__ + typedef struct ub_request { int optype; int blocksize; @@ -268,10 +291,10 @@ output is strided: row starts separated by stride elements*/ void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes, communicator *comm, - const int peer, cudaStream_t stream = 0); + const int peer, cudaStream_t stream = 0, int ring_id = 0); void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes, communicator *comm, - const int peer, cudaStream_t stream = 0); + const int peer, cudaStream_t stream = 0, int ring_id = 0); void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset, const size_t bytes, communicator *comm, const int send_peer, const int recv_peer, cudaStream_t stream = 0); diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 6307eab14..75b9c6f5a 100644 --- 
a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -15,7 +17,7 @@ #include "common/comm_gemm_overlap/userbuffers/userbuffers.h" -#define NVTE_COMM_OVERLAP_MAX_STREAMS 3 +#define NVTE_COMM_OVERLAP_MAX_STREAMS 7 namespace transformer_engine { @@ -37,7 +39,7 @@ enum class CommOverlapAlgo { ATOMIC_GEMM_RS = 5, ATOMIC_GEMM_AG_P2P = 6, ATOMIC_GEMM_RS_P2P = 7, - EXTERNAL_BULK_OVERLAP_AG = 8, + EXTERNAL_BULK_OVERLAP_AG = 8 }; class CommOverlapCore { @@ -107,8 +109,13 @@ class CommOverlapCore { bool is_p2p_overlap() { return _is_p2p; } + bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + virtual bool is_aggregate() { + NVTE_ERROR("Operation is not implemented."); + } + virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, @@ -133,6 +140,14 @@ class CommOverlapCore { NVTE_ERROR("Operation is not implemented."); } + virtual void rocm_split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + virtual void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, @@ -153,6 +168,14 @@ class CommOverlapCore { cudaStream_t stream_main) { NVTE_ERROR("Operation is not implemented."); } + + virtual void rocm_split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } }; // CommOverlapCore class CommOverlapBase : public CommOverlapCore { @@ -225,6 +248,22 @@ class CommOverlapBase : public CommOverlapCore { void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, cudaStream_t stream_main) override; + + void rocm_split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } + + void rocm_split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } }; // CommOverlapBase class CommOverlapP2PBase : public 
CommOverlapCore { @@ -238,9 +277,13 @@ class CommOverlapP2PBase : public CommOverlapCore { int _num_ubuf_chunks; int _self_chunk_id; std::vector _ubufs; - std::vector _stream_send; + std::vector _stream_send, l_stream_send, l_stream_recv; cudaStream_t _stream_recv; - cudaEvent_t _stop_send, _stop_recv; + cudaEvent_t _stop_send, _stop_recv, l_stop_recv[7]; + + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate); private: void initialize(const std::vector &buffer_shape, DType buffer_dtype, @@ -311,6 +354,25 @@ class CommOverlapP2PBase : public CommOverlapCore { TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) override; + /* + ** ROCm Multiring ReduceScatter + GEMM + */ + void rocm_split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; + + /* + ** ROCm Multiring AllGather + GEMM + */ + void rocm_split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override; + + bool is_aggregate() { return _aggregate; } // needed for rocm pathing /* ** This function overlaps the AG for the current communicator object with the GEMM for the overlap_gemm object. diff --git a/transformer_engine/common/include/transformer_engine/multi_stream.h b/transformer_engine/common/include/transformer_engine/multi_stream.h index e30e7c0e5..d570c87e6 100644 --- a/transformer_engine/common/include/transformer_engine/multi_stream.h +++ b/transformer_engine/common/include/transformer_engine/multi_stream.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
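The rocm_split_overlap_rs/rocm_split_overlap_ag overrides declared above drive the new per-ring l_stream_send/l_stream_recv streams with device-visible flag words instead of in-kernel spin loops. Below is a minimal sketch of one step of that handshake, assuming hypothetical stand-in names (ring_flag_ptr, slice_src, slice_dst, slice_bytes, step_signal) for the pointers and sizes that rocm_comm_gemm_overlap.cpp derives from GET_SEND_PTR_BY_INDEX / GET_RECV_PTR_BY_INDEX and get_slice_info; it is illustrative only, not the implementation carried by this patch.

// Sketch (not part of the patch): one ring step of the copy-engine handshake
// used by the ROCm split overlaps. In the real code the write lands in the
// peer's mapped flag region while the wait polls the local mirror; both are
// collapsed into a single ring_flag_ptr here for brevity.
#include <hip/hip_runtime.h>
#include <cstddef>
#include <cstdint>

void ring_step(hipStream_t send_stream, hipStream_t recv_stream,
               void *slice_dst, const void *slice_src, size_t slice_bytes,
               void *ring_flag_ptr, uint64_t step_signal) {
  // Producer: copy this rank's slice into the peer's registered buffer,
  // then publish a monotonically increasing step counter behind the copy.
  hipMemcpyAsync(slice_dst, slice_src, slice_bytes,
                 hipMemcpyDeviceToDevice, send_stream);
  hipStreamWriteValue64(send_stream, ring_flag_ptr, step_signal, 0);

  // Consumer: stall the receive stream until the peer's counter reaches this
  // step, so the GEMM consuming the slice is ordered after the copy landed.
  hipStreamWaitValue64(recv_stream, ring_flag_ptr, step_signal,
                       hipStreamWaitValueGte, 0xFFFFFFFFFFFFFFFFULL);
}

The initial implementation in this patch uses the 32-bit write/wait variants with a per-call step counter; the later "Cleanup and RS flag race condition fix" patch in this series switches to 64-bit signals carried across calls via _rs_signal_base/_ag_signal_base, so a Gte wait cannot be satisfied by a stale flag value left over from a previous invocation.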
diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 748c0f1db..515131b1d 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -27,7 +27,7 @@ namespace { #include "string_path_cuda_include.h" } // namespace -#endif // __HIP_PLATFORM_AMD__ +#endif // #ifndef __HIP_PLATFORM_AMD__ int num_devices() { auto query_num_devices = []() -> int { @@ -103,7 +103,6 @@ int sm_count(int device_id) { return cache[device_id]; } -#ifndef __HIP_PLATFORM_AMD__ void stream_priority_range(int *low_priority, int *high_priority, int device_id) { static std::vector> cache(num_devices()); static std::vector flags(num_devices()); @@ -124,6 +123,11 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id) *high_priority = cache[device_id].second; } +#ifdef __HIP_PLATFORM_AMD__ +bool supports_multicast(int _) { + return false; +} +#else bool supports_multicast(int device_id) { #if CUDART_VERSION >= 12010 // NOTE: This needs to be guarded at compile-time and run-time because the diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index 70250e079..1cccb492f 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -50,7 +50,6 @@ const std::string &sm_arch_name(int device_id = -1); */ int sm_count(int device_id = -1); -#ifndef __HIP_PLATFORM_AMD__ /* \brief Minimum and maximum stream priorities supported on device * * \param[in] device_id CUDA device (default is current device) @@ -68,7 +67,6 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id * \return CUDA multicast support flag */ bool supports_multicast(int device_id = -1); -#endif /* \brief Path to CUDA/ROCm Toolkit headers * diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 0cd24b668..cb2f17f9e 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -10,10 +10,7 @@ #define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ #include -//TODO: rocm does not support comm gemm overlap yet -#ifndef USE_ROCM #include -#endif #include #include @@ -35,9 +32,6 @@ .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); #endif -// Define comm overlap handles if not using ROCm -#ifndef USE_ROCM - #define NVTE_DECLARE_COMM_OVERLAP_HANDLES(m) \ pybind11::enum_(m, "CommOverlapType", \ pybind11::module_local()) \ @@ -56,7 +50,9 @@ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \ .value("EXTERNAL_BULK_OVERLAP_AG", \ - transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG); \ + transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG) \ + .value("SPLIT_PIPELINED_AG_RD_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_RD_P2P); \ py::class_>(m, "CommOverlapCore", \ pybind11::module_local()) \ @@ -91,14 +87,6 @@ py::call_guard(), py::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ py::call_guard()); -#else -#define NVTE_DECLARE_COMM_OVERLAP_HANDLES(m) \ - pybind11::class_(m, "CommOverlapType", \ - pybind11::module_local()); \ - py::class_>(m, "CommOverlapCore", \ - pybind11::module_local()); -#endif #define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ 
pybind11::enum_(m, "DType", pybind11::module_local()) \ diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 27f24a961..992ffb15b 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -30,9 +30,7 @@ #include #include #include -#ifndef USE_ROCM #include -#endif #include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index cbdc63dc2..78f0134d5 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -17,14 +17,6 @@ class CommOverlapHelper; class CommOverlap; class CommOverlapP2P; -#ifdef USE_ROCM -namespace transformer_engine { -//dummy CommOverlapCore, CommOverlapType in rocm -class CommOverlapCore{}; -class CommOverlapType{}; -} -#endif - namespace transformer_engine::pytorch { /*************************************************************************************************** @@ -517,7 +509,6 @@ void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at: } // namespace transformer_engine::pytorch -#ifndef USE_ROCM /*************************************************************************************************** * Comm+GEMM Overlap Wrappers **************************************************************************************************/ @@ -589,6 +580,5 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm std::pair get_communication_stream(); }; // CommOverlapP2P -#endif // !USE_ROCM #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 0844bbdc0..e33734aa5 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -5,7 +5,6 @@ * * See LICENSE for license information. 
************************************************************************/ -#ifndef USE_ROCM #include "../extensions.h" #include "transformer_engine/transformer_engine.h" @@ -314,10 +313,12 @@ std::pair CommOverlapP2P::get_communication_stream() { at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device())}; } +#ifndef USE_ROCM void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm( CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream) { auto main_stream = at::cuda::getCurrentCUDAStream(); allgather_communicator.bulk_overlap_external_ag(at::cuda::CUDAStream(send_stream), at::cuda::CUDAStream(recv_stream), main_stream); } -#endif // !USE_ROCM +#endif + \ No newline at end of file diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index eb9c2bf72..6709a8b86 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -259,7 +259,6 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans #endif if (comm_overlap) { -#ifndef USE_ROCM // Prepare extra output tensor TensorWrapper extra_output_tensor; if (extra_output.has_value()) { @@ -268,7 +267,6 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans extra_output_tensor = makeTransformerEngineTensor(nullptr, std::vector{0}, DType::kByte); } - // Direct GEMM call to the correct overlap if (bulk_overlap) { NVTE_SCOPED_GIL_RELEASE({ @@ -285,6 +283,15 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); +#ifdef __HIP_PLATFORM_AMD__ + } else if (!comm_overlap->is_aggregate()) { + NVTE_SCOPED_GIL_RELEASE({ + comm_overlap->rocm_split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, + extra_output_tensor, main_stream); + }); +#endif // #ifdef __HIP_PLATFORM_AMD } else { NVTE_SCOPED_GIL_RELEASE({ comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, @@ -302,17 +309,22 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans extra_output_tensor, main_stream); }); } else { +#ifdef __HIP_PLATFORM_AMD__ + NVTE_SCOPED_GIL_RELEASE({ + comm_overlap->rocm_split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, extra_output_tensor, + main_stream); +#else NVTE_SCOPED_GIL_RELEASE({ comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); +#endif }); } } -#else - NVTE_ERROR("ROCm TE does not support comm_overlap\n"); -#endif //!USE_ROCM } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index db70dfbf1..40cc16c23 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -490,7 +490,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("GRAD_OUTPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT3) .value("GRAD_INPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT3); -#ifndef USE_ROCM py::class_(m, "CommOverlapHelper") .def(py::init<>(), py::call_guard()) .def(py::init>(), @@ 
-530,9 +529,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) .def("get_communication_stream", &CommOverlapP2P::get_communication_stream); -#else - m.def("CommOverlapHelper", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); - m.def("CommOverlap", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); - m.def("CommOverlapP2P", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); -#endif //USE_ROCM } diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index e74fd9d17..29a523604 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -16,6 +16,7 @@ from contextlib import contextmanager import logging from types import MethodType +from itertools import chain import torch import torch.nn.functional as F @@ -327,7 +328,7 @@ def initialize_ub( # AG-RS overlap pairs of layers forming a tensor-parallel block ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()} - external_gemm_to_overlap = {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"} + external_gemm_to_overlap = {} if IS_HIP_EXTENSION else {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"} global layers_atomic_ring_exchange layers_atomic_ring_exchange = [] @@ -348,7 +349,7 @@ def get_default_config(name): "is_reduce_scatter": is_reduce_scatter, "num_sm": 1 if method == "ring_exchange" else 16, "cga_size": 1 if method == "ring_exchange" else 2, - "set_sm_margin": not method == "ring_exchange", + "set_sm_margin": not method == "ring_exchange" and not IS_HIP_EXTENSION, "num_splits": tp_size if method == "ring_exchange" else 4, "aggregate": False, "atomic_gemm": False, @@ -428,6 +429,7 @@ def add_ub( if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf) else dtype ) + if method == "ring_exchange": ub_obj = tex.CommOverlapP2P( shape, # Communication buffer shape @@ -480,9 +482,7 @@ def add_ub( new_method = user_ub_cfg[name]["method"] methods[new_method].append(name) - for name in ( - methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"] - ): + for name in chain.from_iterable(methods.values()): ub_cfg = get_default_config(name) if user_ub_cfg is not None and name in user_ub_cfg: fp8_buf = (name in layers_all_gather_overlap) or ( diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 9827a916f..a2a01b728 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -34,13 +34,12 @@ ForwardLinearScaleAdd, fuse_forward_linear_scale_add, ) -from torch.utils.cpp_extension import IS_HIP_EXTENSION -if not IS_HIP_EXTENSION: - from .userbuffers_backward_linear import ( - UserbuffersBackwardLinear, - fuse_userbuffers_backward_linear, - ) - from .userbuffers_forward_linear import ( - UserbuffersForwardLinear, - fuse_userbuffers_forward_linear, - ) + +from .userbuffers_backward_linear import ( + UserbuffersBackwardLinear, + fuse_userbuffers_backward_linear, +) +from .userbuffers_forward_linear import ( + UserbuffersForwardLinear, + fuse_userbuffers_forward_linear, +) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 96c28c48b..7e926b337 100644 --- 
a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -36,6 +36,8 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.module.base import TransformerEngineBaseModule +from torch.utils.cpp_extension import IS_HIP_EXTENSION + warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") @@ -311,8 +313,8 @@ def __init__( ub_overlap_ag: bool = True, ub_overlap_rs: bool = True, ub_overlap_rs_dgrad: bool = False, - ub_bulk_dgrad: bool = True, - ub_bulk_wgrad: bool = True, + ub_bulk_dgrad: bool = not IS_HIP_EXTENSION, + ub_bulk_wgrad: bool = not IS_HIP_EXTENSION, bias: bool = True, activation: str = "gelu", activation_params: Optional[dict] = None, From 640f7e809334ba36139719b5335bc65127f65154 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 27 Jan 2026 15:52:53 +0000 Subject: [PATCH 03/69] Copyrights and cleanup --- .../comm_gemm_overlap/te_layer_with_overlap_profile.py | 2 +- tests/cpp/operator/test_normalization_mxfp8.cu | 6 +++--- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 2 ++ .../comm_gemm_overlap/userbuffers/userbuffers-host.cpp | 2 ++ .../common/comm_gemm_overlap/userbuffers/userbuffers.cu | 2 ++ .../common/comm_gemm_overlap/userbuffers/userbuffers.h | 2 ++ transformer_engine/common/util/pybind_helper.h | 4 +--- .../pytorch/csrc/extensions/comm_gemm_overlap.cpp | 1 - transformer_engine/pytorch/transformer.py | 2 ++ 9 files changed, 15 insertions(+), 8 deletions(-) diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py index ba5afd2b6..02b9b9696 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py @@ -1,7 +1,7 @@ #!/usr/bin/python3 # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
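The module-level changes above leave the bulk dgrad/wgrad overlaps (ub_bulk_dgrad/ub_bulk_wgrad) off by default under IS_HIP_EXTENSION while keeping the ring-exchange AG/RS paths on. The following is a hedged usage sketch, assuming the public transformer_engine.pytorch entry points initialize_ub and TransformerLayer with the keyword names shown in these diffs; the buffer shape and layer sizes are arbitrary illustrative values, not values taken from this patch series.

# Usage sketch (assumptions noted above): tensor-parallel comm+GEMM overlap
# with the ROCm defaults from this series, i.e. bulk overlaps left disabled.
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
tp_size = dist.get_world_size()

# Assumed entry point: initialize_ub(shape, tp_size, use_fp8=..., dtype=...)
# as defined in transformer_engine/pytorch/module/base.py; shape is
# [tokens, hidden] for the registered communication buffer.
te.initialize_ub([4096, 8192], tp_size, use_fp8=False, dtype=torch.bfloat16)

layer = te.TransformerLayer(
    hidden_size=8192,
    ffn_hidden_size=32768,
    num_attention_heads=64,
    set_parallel_mode=True,
    tp_group=dist.group.WORLD,
    ub_tp_comm_overlap=True,  # ring-exchange AG/RS overlaps
    ub_overlap_ag=True,
    ub_overlap_rs=True,
    # ub_bulk_dgrad / ub_bulk_wgrad default to (not IS_HIP_EXTENSION) here.
)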
diff --git a/tests/cpp/operator/test_normalization_mxfp8.cu b/tests/cpp/operator/test_normalization_mxfp8.cu index b3b94ea01..f38a9e695 100644 --- a/tests/cpp/operator/test_normalization_mxfp8.cu +++ b/tests/cpp/operator/test_normalization_mxfp8.cu @@ -131,10 +131,10 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, DType wtype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input("input2", std::vector{ N, H }, itype); + Tensor input("input", std::vector{ N, H }, itype); Tensor z("z", std::vector{ N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING); - Tensor gamma("gamma2", std::vector{ H }, wtype); - Tensor beta("beta2", std::vector{ H }, wtype); + Tensor gamma("gamma", std::vector{ H }, wtype); + Tensor beta("beta", std::vector{ H }, wtype); Tensor mu("mu", std::vector{ N }, DType::kFloat32); Tensor rsigma("rsigma", std::vector{ N }, DType::kFloat32); Tensor workspace; diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 00d8dc322..1dc340023 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index c118d3281..fe7f839b8 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index fe4514b15..55b4f5229 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index dfcc13450..7f087dadc 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. 
All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index cb2f17f9e..456d2ea50 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -50,9 +50,7 @@ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \ .value("EXTERNAL_BULK_OVERLAP_AG", \ - transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG) \ - .value("SPLIT_PIPELINED_AG_RD_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_RD_P2P); \ + transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG); \ py::class_>(m, "CommOverlapCore", \ pybind11::module_local()) \ diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index e33734aa5..0613f084b 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -321,4 +321,3 @@ void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm( at::cuda::CUDAStream(recv_stream), main_stream); } #endif - \ No newline at end of file diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 7e926b337..5c899465d 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. From 82faeec99fc58d9bee278f76cb8ec666f11a7629 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 24 Feb 2026 13:28:31 -0600 Subject: [PATCH 04/69] test guards --- tests/pytorch/distributed/test_comm_gemm_overlap.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index d1c03de70..62010c84e 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -12,7 +12,7 @@ import transformer_engine.pytorch as te import transformer_engine.pytorch.cpp_extensions as tex -from transformer_engine.jax.cpp_extensions.misc import is_hip_extension +from torch.utils.cpp_extension import IS_HIP_EXTENSION if torch.cuda.device_count() < 2: @@ -71,6 +71,8 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization if bulk: test_cmd.append("--bulk-overlap") else: + if IS_HIP_EXTENSION and not p2p: + pytest.skip("HIP only supports A2A operations.") if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: @@ -184,11 +186,13 @@ def test_bulk_overlaps(comm_type, quantization, connections): Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. 
""" if connections == 8: - if is_hip_extension() or torch.cuda.get_device_properties(0).major != 9: + if torch.cuda.get_device_properties(0).major != 9: pytest.skip( "CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability" " 9.0 (HOPPER ARCH)." ) + if IS_HIP_EXTENSION: + pytest.skip("HIP Does not support bulk overlaps with 8 connections.") os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8" _run_gemm_with_overlap(comm_type, True, False, False, False, quantization) os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" From b6a3ae424b5e56a52134f108f8200bfc2babe45d Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Mon, 2 Mar 2026 16:30:31 -0600 Subject: [PATCH 05/69] Cleanup and RS flag race condition fix --- build_tools/pytorch.py | 20 +- ci/pytorch.sh | 3 + .../distributed/run_layer_with_overlap.py | 16 +- .../distributed/test_comm_gemm_overlap.py | 10 +- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 10 +- .../rocm_comm_gemm_overlap.cpp | 230 ++++++------------ .../userbuffers/userbuffers-host.cpp | 9 +- .../userbuffers/userbuffers.cu | 14 +- .../userbuffers/userbuffers.h | 9 +- .../transformer_engine/comm_gemm_overlap.h | 3 + .../pytorch/csrc/extensions/gemm.cpp | 24 +- transformer_engine/pytorch/module/base.py | 51 +++- 12 files changed, 181 insertions(+), 218 deletions(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index db7b61d02..ee0b0bc2a 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -81,6 +81,14 @@ def setup_pytorch_extension( if version < (12, 0): raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer") + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): + assert ( + os.getenv("MPI_HOME") is not None + ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" + mpi_path = Path(os.getenv("MPI_HOME")) + include_dirs.append(mpi_path / "include") + cxx_flags.append("-DNVTE_UB_WITH_MPI") + library_dirs = [] libraries = [] if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))): @@ -100,17 +108,6 @@ def setup_pytorch_extension( libraries.append("mpi") cxx_flags.extend(["-DNVTE_ENABLE_ROCSHMEM", "-DOMPI_SKIP_MPICXX"]) - extra_link_args = [] - if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): - assert ( - os.getenv("MPI_HOME") is not None - ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" 
- mpi_path = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi")) - include_dirs.append(mpi_path / "include") - library_dirs.append(mpi_path / "lib") - libraries.append("mpi") - cxx_flags.extend(["-DNVTE_UB_WITH_MPI", "-DOMPI_SKIP_MPICXX"]) - # Construct PyTorch CUDA extension sources = [str(path) for path in sources] include_dirs = [str(path) for path in include_dirs] @@ -123,5 +120,4 @@ def setup_pytorch_extension( extra_compile_args={"cxx": cxx_flags}, libraries=[str(lib) for lib in libraries], library_dirs=[str(lib_dir) for lib_dir in library_dirs], - extra_link_args=[str(arg) for arg in extra_link_args], ) diff --git a/ci/pytorch.sh b/ci/pytorch.sh index a7259ebf6..38ab019da 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -94,7 +94,10 @@ run_test_config_mgpu(){ #run in parallel on CI and it affects timing run_default_fa 1 test_gemm_sm_count.py run_default_fa 3 test_sanity_import.py +<<<<<<< HEAD run_default_fa 3 distributed/test_cast_master_weights_to_fp8.py +======= +>>>>>>> 16b62493 (Cleanup and RS flag race condition fix) run_default_fa 3 distributed/test_comm_gemm_overlap.py run_default_fa 2 distributed/test_fusible_ops.py run_default_fa 2 distributed/test_numerics.py diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index f732e9a44..bd431db58 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -28,14 +28,17 @@ MXFP8BlockScaling, ) +from torch.utils.cpp_extension import IS_HIP_EXTENSION + warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) -import transformer_engine.pytorch.cpp_extensions as tex -os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" -if not tex.device_supports_multicast(): - os.environ["UB_SKIPMC"] = "1" +if IS_HIP_EXTENSION: + import transformer_engine.pytorch.cpp_extensions as tex + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + if not tex.device_supports_multicast(): + os.environ["UB_SKIPMC"] = "1" class multi_module_model(torch.nn.Module): @@ -118,6 +121,7 @@ def _get_layer_args(config, tp_group, tp_size, num_layers, reference=False): kwargs["input_layernorm"] = True else: kwargs["ub_tp_comm_overlap"] = not reference + # Disable forward pass overlaps on HIP to isolate backward RS overlap kwargs["hidden_dropout"] = 0.0 kwargs["set_parallel_mode"] = True kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference @@ -559,8 +563,8 @@ def run_fwd_bwd(model, x): # Now validate accuracy if not bool(numerics_failed.item()): for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): - rtol = 0.125 if opts.fp8 else 0.025 - atol = 0.0625 if opts.fp8 else 0.00125 + rtol = 0.125 if opts.fp8 else 0.025 if not IS_HIP_EXTENSION else 5e-2 + atol = 0.0625 if opts.fp8 else 0.00125 if not IS_HIP_EXTENSION else 1e-2 grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) numerics_failed[0] = int(grad_failed) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 62010c84e..9f79fe5ad 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -72,7 +72,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization test_cmd.append("--bulk-overlap") else: 
if IS_HIP_EXTENSION and not p2p: - pytest.skip("HIP only supports A2A operations.") + pytest.skip("HIP only supports P2P operations.") if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: @@ -99,6 +99,9 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization def _run_layer_with_overlap( layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers=1 ): + # Skip BULK overlap tests on HIP (column parallel or None with overlap_rs_dgrad=False) + if IS_HIP_EXTENSION and not overlap_rs_dgrad and linear_parallel_mode in ("column", None): + pytest.skip("Bulk overlap is not yet supported on HIP/ROCm.") test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -162,6 +165,7 @@ def test_split_reduce_scatter_overlaps(quantization, p2p): _run_gemm_with_overlap("RS", False, p2p, False, False, quantization) +@pytest.mark.skipif(IS_HIP_EXTENSION, reason="Bulk overlap is not yet supported on ROCm.") @pytest.mark.parametrize( "comm_type, quantization, connections", [ @@ -191,8 +195,6 @@ def test_bulk_overlaps(comm_type, quantization, connections): "CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability" " 9.0 (HOPPER ARCH)." ) - if IS_HIP_EXTENSION: - pytest.skip("HIP Does not support bulk overlaps with 8 connections.") os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8" _run_gemm_with_overlap(comm_type, True, False, False, False, quantization) os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" @@ -227,7 +229,7 @@ def test_bulk_overlaps(comm_type, quantization, connections): ids=[ f" {te.Linear.__name__} - ROW-PARALLEL ", f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", - f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ", + f" {te.Linear.__name__} - COL-PARALLEL - DGRAD+RS ", f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ", f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ", diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 1dc340023..1a0410898 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -206,6 +206,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk"); NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width, "Attempted to get out-of-bounds tensor chunk"); +#ifndef __HIP_PLATFORM_AMD__ if (scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING) { // MXFP8 scale-inverses are padded to a 2D matrix with dims that // are divisible by 128. UB doesn't handle this padding yet. 
@@ -214,6 +215,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz NVTE_DIM_CHECK(chunk_height % 128 == 0 && chunk_width % 128 == 0, "Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128"); } +#endif #undef NVTE_DIM_CHECK // Construct tensor chunk @@ -726,12 +728,12 @@ void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DTy NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); _stream_send.push_back(std::move(stream)); } - for (int i = 0; i < 7; i++) { + for (int i = 0; i < NVTE_ROCM_MAX_RINGS; i++) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); l_stream_send.push_back(std::move(stream)); } - for (int i = 0; i < 7; i++) { + for (int i = 0; i < NVTE_ROCM_MAX_RINGS; i++) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); l_stream_recv.push_back(std::move(stream)); @@ -740,7 +742,7 @@ void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DTy cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); - for (int i = 0; i < 7; i++) { + for (int i = 0; i < NVTE_ROCM_MAX_RINGS; i++) { NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&l_stop_recv[i], 0)); } } @@ -752,7 +754,7 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { for (size_t i = 0; i < _stream_send.size(); i++) { cudaStreamDestroy(_stream_send[i]); } - for (int i = 0; i < 7; i++) { + for (int i = 0; i < NVTE_ROCM_MAX_RINGS; i++) { cudaStreamDestroy(l_stream_recv[i]); cudaStreamDestroy(l_stream_send[i]); cudaEventDestroy(l_stop_recv[i]); diff --git a/transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp index ea05ea95f..da430bec3 100644 --- a/transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp @@ -305,6 +305,9 @@ void CommOverlapP2PBase::rocm_split_overlap_ag(const TensorWrapper &A, bool tran NVTE_CHECK(B_copy.element_size() == _ubuf.element_size()); } + const uint64_t ag_signal_base = _ag_signal_base + _tp_size; + uint64_t signal_val; + auto get_slice_info = [&](int ring) -> std::pair { size_t offset = ring * base_slice_bytes; int size = base_slice_bytes; @@ -403,7 +406,6 @@ void CommOverlapP2PBase::rocm_split_overlap_ag(const TensorWrapper &A, bool tran int prev_rank = prev[r * _tp_size + _tp_id]; size_t send_off = get_slice_offset(curr_chunk_id, r); - size_t recv_off = get_slice_offset(next_recv_chunk_id, r); auto [_, slice_bytes] = get_slice_info(r); @@ -416,18 +418,16 @@ void CommOverlapP2PBase::rocm_split_overlap_ag(const TensorWrapper &A, bool tran void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, _ub_comm, _ub_reg, r); void *srcptr = reinterpret_cast(_ub_comm->mem_ptr[_ub_reg]) + send_off; void *dstptr = reinterpret_cast(_ub_comm->peer_ptr[_ub_reg][peerlocal]) + send_off; - + NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, slice_bytes, cudaMemcpyDeviceToDevice, l_stream_send[r])); - uint32_t signal_val = step + 1; - hipStreamWriteValue32(l_stream_send[r], flagptr, signal_val, 0); + signal_val = ag_signal_base + step + 1; + hipStreamWriteValue64(l_stream_send[r], flagptr, signal_val, 0); } { - int peerlocal = prev_rank % _ub_comm->nvsize; void *flagptr = 
GET_RECV_PTR_BY_INDEX(prev_rank, _ub_comm, _ub_reg, r); - - uint32_t signal_val = step + 1; - hipStreamWaitValue32(l_stream_recv[r], flagptr, signal_val, hipStreamWaitValueGte, 0xFFFFFFFF); + signal_val = ag_signal_base + step + 1; + hipStreamWaitValue64(l_stream_recv[r], flagptr, signal_val, hipStreamWaitValueGte, 0xFFFFFFFFFFFFFFFF); } NVTE_CHECK_CUDA(cudaEventRecord(get_event(next_recv_chunk_id, r), l_stream_recv[r])); @@ -439,6 +439,8 @@ void CommOverlapP2PBase::rocm_split_overlap_ag(const TensorWrapper &A, bool tran } } + _ag_signal_base = signal_val; + if (B_copy.numel() > 0) { for (int r = 0; r < num_rings; r++) { int last_chunk = get_chunk_id(r, _tp_size - 1); @@ -479,172 +481,93 @@ void CommOverlapP2PBase::rocm_split_overlap_rs(const TensorWrapper &A, bool tran _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; - // GEMM dimensions const size_t m = transa ? A.size(0) : A.size(1); const size_t k = transa ? A.size(1) : A.size(0); const size_t n_chunk = _ubufs[0].size(0); const int comm_bytes = _ubufs[0].bytes(); - size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - - const int max_rings = (_tp_size == 4) ? 2 : - (_tp_size == 6) ? 4 : - _tp_size - 1; - - const int num_rings = std::min({ - transformer_engine::getenv("GPU_MAX_HW_QUEUES", 4), - _tp_size - 1, - max_rings - }); - - const int *next, *prev; - switch (_tp_size) { - case 8: - next = reinterpret_cast(tp_next_8); - prev = reinterpret_cast(tp_prev_8); - break; - case 4: - next = reinterpret_cast(tp_next_4); - prev = reinterpret_cast(tp_prev_4); - break; - case 2: - return this->split_overlap_rs(A, transa, B, transb, D, bias, pre_gelu_out, workspace, grad, - accumulate, use_split_accumulator, rs_output, stream_main); - default: - NVTE_ERROR("ROCm supports TP sizes of 2, 4, 8 only."); - } - - const int alignment = 256; - const int base_slice_bytes = (comm_bytes / num_rings) & ~(alignment - 1); - const int total_base_bytes = base_slice_bytes * num_rings; - const int remainder_bytes = comm_bytes - total_base_bytes; - - const size_t base_n_slice = n_chunk / num_rings; - const size_t remainder_n = n_chunk - base_n_slice * num_rings; - - auto get_slice_info = [&](int ring) -> std::pair { - size_t offset = ring * base_slice_bytes; - int size = base_slice_bytes; - if (ring == num_rings - 1) - size += remainder_bytes; - return {offset, size}; - }; - - auto get_slice_n = [&](int ring) -> size_t { - return base_n_slice + (ring == num_rings - 1 ? 
remainder_n : 0); - }; + const size_t input_chunk_size = n_chunk * k; + const size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - auto get_chunk_id = [&](int ring, int step) { - int owner = _tp_id; - for (int s = 0; s < step; ++s) - owner = prev[ring * _tp_size + owner]; - return owner; - }; + const uint64_t rs_signal_base = _rs_signal_base + _tp_size; + int64_t signal_val; + // Catch up all streams to main NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - for (int r = 0; r < num_rings; r++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[r], _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_recv[r], _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[r], _start_compute, 0)); - } - - for (auto &s : _stream_compute) - NVTE_CHECK_CUDA(cudaStreamWaitEvent(s, _start_compute, 0)); - - const int total_slices = _tp_size * num_rings; - std::vector slice_events(total_slices); - - for (auto &e : slice_events) - NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&e, cudaEventDisableTiming)); - - auto get_event = [&](int chunk, int ring) { - return slice_events[chunk * num_rings + ring]; - }; - - for (int r = 0; r < num_rings; ++r) - NVTE_CHECK_CUDA(cudaEventRecord(get_event(_tp_id, r), stream_main)); - - auto get_slice_offset = [&](int chunk, int ring) { - auto [ring_offset, _] = get_slice_info(ring); - return chunk * comm_bytes + ring_offset; - }; - - auto launch_slice_gemm = [&](int chunk_id, int ring_id, int step) { - size_t n_slice = get_slice_n(ring_id); - - size_t b_elem_offset = chunk_id * n_chunk * k; - size_t d_elem_offset = chunk_id * n_chunk * m; + for (size_t i = 0; i < l_stream_send.size(); i++) + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[i], _start_compute, 0)); + for (size_t i = 0; i < l_stream_recv.size(); i++) + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_recv[i], _start_compute, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); - for (int r = 0; r < ring_id; ++r) { - b_elem_offset += get_slice_n(r) * k; - d_elem_offset += get_slice_n(r) * m; - } + for (int i = 0; i < _tp_size; i++) { + int stream_id = i % _stream_compute.size(); + int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; - auto input_b_slice = get_tensor_chunk(B, b_elem_offset, transb ? 
std::vector{k, n_slice} : std::vector{n_slice, k}); - auto output_slice = get_tensor_chunk(D, d_elem_offset, {n_slice, m}); // D acts as the accumulation buffer - auto workspace_chunk = get_tensor_chunk(workspace, ring_id * workspace_size_chunk, {workspace_size_chunk}); + auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k}); + auto output_chunk = get_buffer_chunk_by_id(D, i); + auto workspace_chunk = get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk}); - nvte_cublas_gemm(A.data(), input_b_slice.data(), output_slice.data(), bias.data(), + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), - accumulate, use_split_accumulator, _math_sms, - _stream_compute[ring_id]); - NVTE_CHECK_CUDA(cudaEventRecord(get_event(chunk_id, ring_id), _stream_compute[ring_id])); - }; - - for (int step = 0; step < _tp_size; ++step) { - for (int r = 0; r < num_rings; ++r) { - int curr_chunk = get_chunk_id(r, step); - launch_slice_gemm(curr_chunk, r, step); - } - - if (step > 0) { - int prev_step = step - 1; - - for (int r = 0; r < num_rings; ++r) { - int chunk_to_send = get_chunk_id(r, prev_step); - - NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[r], get_event(chunk_to_send, r), 0)); - - size_t send_off = get_slice_offset(chunk_to_send, r); - auto [_, slice_bytes] = get_slice_info(r); - - int next_rank = next[r * _tp_size + _tp_id]; - int prev_rank = prev[r * _tp_size + _tp_id]; - - { - int peerlocal = next_rank % _ub_comm->nvsize; - void *srcptr = reinterpret_cast(_ub_comm->mem_ptr[_ub_reg]) + send_off; - void *dstptr = reinterpret_cast(_ub_comm->peer_ptr[_ub_reg][peerlocal]) + send_off; - void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, _ub_comm, _ub_reg, r); - - NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, slice_bytes, cudaMemcpyDeviceToDevice, l_stream_send[r])); - uint32_t signal_val = prev_step + 1; // Use step count as signal - hipStreamWriteValue32(l_stream_send[r], flagptr, signal_val, 0); - } + accumulate, use_split_accumulator, _math_sms, _stream_compute[stream_id]); + + if (i > 0) { + // Each step uses its own send/recv stream — fully parallel since each + // send goes to a unique destination rank (the chunk owner) + int comm_stream_id = i - 1; + int prev_stream_id = (i - 1) % _stream_compute.size(); + + const int send_offset = comm_bytes * (i - 1); + const int recv_offset = comm_bytes * (i - 1 + _tp_size); + const int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + const int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + signal_val = rs_signal_base + i; + + // Wait for GEMM of previous chunk before sending + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[prev_stream_id])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[comm_stream_id], _start_comm, 0)); + + // Send partial to chunk owner + { + int peerlocal = send_rank % _ub_comm->nvsize; + void *srcptr = reinterpret_cast(_ub_comm->mem_ptr[_ub_reg]) + send_offset; + void *dstptr = reinterpret_cast(_ub_comm->peer_ptr[_ub_reg][peerlocal]) + recv_offset; + void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, _ub_comm, _ub_reg, comm_stream_id); + + NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, comm_bytes, + cudaMemcpyDeviceToDevice, l_stream_send[comm_stream_id])); + hipStreamWriteValue64(l_stream_send[comm_stream_id], flagptr, signal_val, 0); + } - { - int peerlocal = prev_rank % _ub_comm->nvsize; - void *flagptr = 
GET_RECV_PTR_BY_INDEX(prev_rank, _ub_comm, _ub_reg, r);
-          uint32_t signal_val = prev_step + 1;
-          hipStreamWaitValue32(l_stream_recv[r], flagptr, signal_val, hipStreamWaitValueGte, 0xFFFFFFFF);
-        }
+      // Wait for incoming partial from chunk contributor
+      {
+        void *flagptr = GET_RECV_PTR_BY_INDEX(recv_rank, _ub_comm, _ub_reg, comm_stream_id);
+        hipStreamWaitValue64(l_stream_recv[comm_stream_id], flagptr, signal_val,
+                             hipStreamWaitValueGte, 0xFFFFFFFFFFFFFFFF);
       }
     }
   }
+
+  _rs_signal_base = signal_val;
+
+  // Sync all streams back to main
   for (size_t i = 0; i < _stream_compute.size(); i++) {
     NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
     NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
   }
-  for (int r = 0; r < num_rings; r++) {
-    NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, l_stream_send[r]));
-    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
-    NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, l_stream_recv[r]));
-    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
+  for (int i = 0; i < _tp_size - 1; i++) {
+    NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, l_stream_send[i]));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
+    NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, l_stream_recv[i]));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
   }
-  // Reduce GEMM output chunks
+  // Reduce: received partials live at _ubufs[_tp_size] through _ubufs[2*_tp_size-2],
+  // plus the local partial at _ubufs[_tp_size-1], matching the single ring layout exactly
   char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr());
-  char *rs_output_ptr = reinterpret_cast(rs_output.dptr());
+  char *rs_output_ptr = reinterpret_cast(rs_output.dptr());
   if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
@@ -656,9 +579,6 @@ void CommOverlapP2PBase::rocm_split_overlap_rs(const TensorWrapper &A, bool tran
   }
   _ub_comm->sms = ori_sms;
-
-  // Cleanup events
-  for (auto &e : slice_events) NVTE_CHECK_CUDA(cudaEventDestroy(e));
-}  // rocm_split_overlap_rs
+}
 }  // namespace transformer_engine
diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp
index fe7f839b8..fcccd396a 100644
--- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp
+++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp
@@ -359,12 +359,12 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
   NVTE_CHECK_CUDA(cudaDeviceSynchronize());
   register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true);
   NVTE_CHECK_CUDA(
-      cudaMalloc(reinterpret_cast(&(*comm)->send_id), (*comm)->nranks * NVTE_MAX_RINGS * sizeof(int)));
+      cudaMalloc(reinterpret_cast(&(*comm)->send_id), (*comm)->nranks * NVTE_ROCM_MAX_RINGS * sizeof(int)));
   NVTE_CHECK_CUDA(cudaMalloc(reinterpret_cast(&(*comm)->recv_id),
-                             NVTE_MAX_REGIONS * (*comm)->nranks * NVTE_MAX_RINGS * sizeof(int)));
-  NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * NVTE_MAX_RINGS * sizeof(int)));
+                             NVTE_MAX_REGIONS * (*comm)->nranks * NVTE_ROCM_MAX_RINGS * sizeof(int)));
+  NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * NVTE_ROCM_MAX_RINGS * sizeof(int)));
   NVTE_CHECK_CUDA(
-      cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks
* NVTE_ROCM_MAX_RINGS * sizeof(int))); (*comm)->sms = 16; (*comm)->threads = 1024; @@ -725,5 +725,4 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * comm->mem_ptr[hndl] = *gpubuff; return comm->free_region++; - printf("***** Returning *****\n"); } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 55b4f5229..1eb958ef7 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -36,7 +36,7 @@ #define MAX_THREADS 1024 -#if !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) +#ifndef __HIP_PLATFORM_AMD__ #define ATOMIC_CONSUMER(chunk) \ if (counters) { \ if (threadIdx.x == 0 && blockIdx.x == 0) { \ @@ -87,7 +87,7 @@ printf("[%s:%s:%d] " message "\n", FILENAME(__FILE__), __FUNCTION__, __LINE__, __VA_ARGS__) // Report and error on timeout -#define CHECK_TIMEOUT(t, timeout) (((uint64_t)clock64() - (t)) > timeout) +#define CHECK_TIMEOUT(t, timeout) ((clock64() - (t)) > timeout) template __global__ void __launch_bounds__(MAX_THREADS) @@ -2376,7 +2376,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return; if (comm->push == 0) { - kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer * NVTE_MAX_RINGS + ring_id]), + kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer * NVTE_ROCM_MAX_RINGS + ring_id]), reinterpret_cast(flagptr)); NVTE_CHECK_CUDA(cudaGetLastError()); } else { @@ -2389,7 +2389,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); } SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); - int *arg1 = &comm->send_id[peer * NVTE_MAX_RINGS + ring_id], *arg2 = reinterpret_cast(flagptr); + int *arg1 = &comm->send_id[peer * NVTE_ROCM_MAX_RINGS + ring_id], *arg2 = reinterpret_cast(flagptr); int4 *arg3 = reinterpret_cast(srcptr), *arg4 = reinterpret_cast(dstptr); int arg5 = signalonly ? 0 : bytes / 16; void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), @@ -2573,12 +2573,12 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds kuserbuffers_pullrecv<<sms, signalonly ? 1 : 1024, 0, stream>>>( comm->myrank, peer, comm->nvrank, peerlocal, - &(comm->recv_id[(peer * NVTE_MAX_REGIONS + dsthandler) * NVTE_MAX_RINGS + ring_id]), reinterpret_cast(flagptr), + &(comm->recv_id[(peer * NVTE_MAX_REGIONS + dsthandler) * NVTE_ROCM_MAX_RINGS + ring_id]), reinterpret_cast(flagptr), reinterpret_cast(srcptr), reinterpret_cast(dstptr), signalonly ? 
0 : bytes / 16, comm->ub_timeout); NVTE_CHECK_CUDA(cudaGetLastError()); if (!signalonly) { - kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[(peer * NVTE_MAX_REGIONS + dsthandler) * NVTE_MAX_RINGS + ring_id])); + kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[(peer * NVTE_MAX_REGIONS + dsthandler) * NVTE_ROCM_MAX_RINGS + ring_id])); NVTE_CHECK_CUDA(cudaGetLastError()); } if (comm->use_ce) { @@ -2587,7 +2587,7 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds } else { kuserbuffers_pushrecv<<<1, 1, 0, stream>>>( comm->myrank, peer, comm->nvrank, peerlocal, - &comm->recv_id[(peer * NVTE_MAX_REGIONS + dsthandler) * NVTE_MAX_RINGS + ring_id], reinterpret_cast(flagptr), + &comm->recv_id[(peer * NVTE_MAX_REGIONS + dsthandler) * NVTE_ROCM_MAX_RINGS + ring_id], reinterpret_cast(flagptr), signalonly || comm->sms, comm->ub_timeout, reinterpret_cast(0 ? // temporary disable GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 1) diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 7f087dadc..d897a0f37 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -37,7 +37,10 @@ using ExtBarrierOp = std::function; #define NVTE_LAUNCH_GPU 1 #define NVTE_LAUNCH_CPU 2 #define NVTE_MAX_NVLINK 32 -#define NVTE_MAX_RINGS 7 + +#define NVTE_ROCM_MAX_TP_SIZE 8 +// Maximum # of rings possible for ring_exchange +#define NVTE_ROCM_MAX_RINGS (NVTE_ROCM_MAX_TP_SIZE - 1) #define NVTE_UB_MEM_UC_CONTIG 1 #define NVTE_UB_MEM_MC_CREATED 2 @@ -75,7 +78,7 @@ using ExtBarrierOp = std::function; ((reinterpret_cast((comm)->peer_ptr[0][(peerlocal)])) + \ ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + (comm)->myrank * NVTE_MAX_REGIONS + (dsth) + \ (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ - sizeof(int))) + sizeof(uint64_t))) // Index corresponds to the type of flag: // 0 - Receive index counter @@ -85,7 +88,7 @@ using ExtBarrierOp = std::function; ((reinterpret_cast((comm)->mem_ptr[0])) + \ ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + (recv_peer) * NVTE_MAX_REGIONS + (dsth) + \ (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ - sizeof(int))) + sizeof(uint64_t))) #endif // #ifdef __HIP_PLATFORM_AMD__ typedef struct ub_request { diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 75b9c6f5a..3cda89ca1 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -281,6 +281,9 @@ class CommOverlapP2PBase : public CommOverlapCore { cudaStream_t _stream_recv; cudaEvent_t _stop_send, _stop_recv, l_stop_recv[7]; + uint64_t _ag_signal_base = 0; + uint64_t _rs_signal_base = 0; + private: void initialize(const std::vector &buffer_shape, DType buffer_dtype, CommOverlapType comm_type, bool aggregate); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 6709a8b86..88883a721 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -310,19 +310,23 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans }); } else { #ifdef __HIP_PLATFORM_AMD__ - NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->rocm_split_overlap_rs(A_tensor, 
transa, B_tensor, transb, out_tensor, - bias_tensor, te_pre_gelu_out, te_workspace, grad, - accumulate, use_split_accumulator, extra_output_tensor, - main_stream); -#else + if (comm_overlap->is_p2p_overlap()) { + NVTE_SCOPED_GIL_RELEASE({ + comm_overlap->rocm_split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, extra_output_tensor, + main_stream); + }); + } else +#endif + { NVTE_SCOPED_GIL_RELEASE({ comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, - bias_tensor, te_pre_gelu_out, te_workspace, grad, - accumulate, use_split_accumulator, extra_output_tensor, - main_stream); -#endif + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, extra_output_tensor, + main_stream); }); + } } } } else { diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 29a523604..b83547e8d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -313,17 +313,36 @@ def initialize_ub( layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"] # Default overlap methods for layers - methods = { - "ring_exchange": [ - "qkv_fprop", - "fc1_fprop", - "proj_dgrad", - "fc2_dgrad", - ], - "pipeline": ["proj_fprop", "fc2_fprop"], - "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], - "external": ["proj_wgrad", "fc2_wgrad"], - } + if IS_HIP_EXTENSION: + methods = { + "ring_exchange": [ + "qkv_fprop", + "fc1_fprop", + "proj_dgrad", + "fc2_dgrad", + "proj_wgrad", + "fc2_wgrad", + "proj_fprop", + "fc2_fprop", + "qkv_wgrad", + "fc1_wgrad" + ], + "pipeline": [], + # TODO: Investigate issues with qkv_dgrad and fc1_dgrad overlap on ROCm + "bulk": [], + } + else: + methods = { + "ring_exchange": [ + "qkv_fprop", + "fc1_fprop", + "proj_dgrad", + "fc2_dgrad", + ], + "pipeline": ["proj_fprop", "fc2_fprop"], + "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], + "external": ["proj_wgrad", "fc2_wgrad"], + } # AG-RS overlap pairs of layers forming a tensor-parallel block ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} @@ -478,10 +497,18 @@ def add_ub( layers_reduce_scatter_overlap.remove(wgrad_name) layers_all_gather_overlap.remove(name) layers_reduce_scatter_overlap.append(name) - methods["bulk"].remove(name) + if name in methods["bulk"]: + methods["bulk"].remove(name) new_method = user_ub_cfg[name]["method"] methods[new_method].append(name) + if IS_HIP_EXTENSION and user_ub_cfg is not None: + for name, cfg in user_ub_cfg.items(): + assert cfg.get("method") != "bulk", ( + f"Bulk overlap method for '{name}' is not supported on HIP/ROCm. " + "Please use 'ring_exchange' method instead." 
+ ) + for name in chain.from_iterable(methods.values()): ub_cfg = get_default_config(name) if user_ub_cfg is not None and name in user_ub_cfg: From 9e32d3a25a5441e3164706c78116bb7f7eefb485 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Thu, 12 Mar 2026 18:55:10 -0500 Subject: [PATCH 06/69] Debugging midpoint --- build_tools/hipify/custom_map.json | 1 + .../distributed/run_layer_with_overlap.py | 10 +++-- .../distributed/test_comm_gemm_overlap.py | 7 ++-- transformer_engine/common/CMakeLists.txt | 11 ++--- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 32 --------------- .../userbuffers/userbuffers-host.cpp | 2 +- .../userbuffers/userbuffers.cu | 40 +++++++++---------- .../transformer_engine/comm_gemm_overlap.h | 17 +++----- transformer_engine/pytorch/csrc/extensions.h | 4 +- .../pytorch/csrc/extensions/pybind.cpp | 4 +- transformer_engine/pytorch/module/base.py | 8 ++++ 11 files changed, 55 insertions(+), 81 deletions(-) diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index 72bf8a383..a38306eb1 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -1,6 +1,7 @@ { "custom_map" : { "" : "", + "util/cuda_runtime.h" : "util/hip_runtime.h", "" : "\"common/amd_detail/hip_float8.h\"", "" : "", "cuda_runtime.h\"" : "hip_runtime.h\"", diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index bd431db58..0c6df7f2e 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -248,7 +248,7 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "--debug", action="store_true", - default=False, + default=True, help="Print out additional debug information.", ) parser.add_argument( @@ -388,7 +388,7 @@ def _train(opts): " layers. Use --num-heads or --head-dim for other cases." 
) - def dist_print(msg, src=None, end="\n", debug=False, error=False): + def dist_print(msg, src=None, end="\n", debug=True, error=False): if debug and not opts.debug: return stream = sys.stderr if error else sys.stdout @@ -439,6 +439,8 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): ] if opts.first_last_layers_bf16 and opts.fp8: quantization_modes.append(te.module.base.UserBufferQuantizationMode.NONE) + + dist_print(f"Before initialize_ub, opts: {opts}") te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], @@ -563,8 +565,8 @@ def run_fwd_bwd(model, x): # Now validate accuracy if not bool(numerics_failed.item()): for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): - rtol = 0.125 if opts.fp8 else 0.025 if not IS_HIP_EXTENSION else 5e-2 - atol = 0.0625 if opts.fp8 else 0.00125 if not IS_HIP_EXTENSION else 1e-2 + rtol = 0.125 if opts.fp8 else 0.025 + atol = 0.0625 if opts.fp8 else 0.00125 grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) numerics_failed[0] = int(grad_failed) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 9f79fe5ad..ffa40f0c4 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -100,8 +100,8 @@ def _run_layer_with_overlap( layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers=1 ): # Skip BULK overlap tests on HIP (column parallel or None with overlap_rs_dgrad=False) - if IS_HIP_EXTENSION and not overlap_rs_dgrad and linear_parallel_mode in ("column", None): - pytest.skip("Bulk overlap is not yet supported on HIP/ROCm.") + #if IS_HIP_EXTENSION and not overlap_rs_dgrad and linear_parallel_mode in ("column", None): + # pytest.skip("Bulk overlap is not yet supported on HIP/ROCm.") test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -131,6 +131,7 @@ def _run_layer_with_overlap( os.environ["NVTE_TORCH_COMPILE"] = "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + print("test_cmd: ", test_cmd) result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) os.unsetenv("PYTORCH_JIT") @@ -165,7 +166,7 @@ def test_split_reduce_scatter_overlaps(quantization, p2p): _run_gemm_with_overlap("RS", False, p2p, False, False, quantization) -@pytest.mark.skipif(IS_HIP_EXTENSION, reason="Bulk overlap is not yet supported on ROCm.") +#@pytest.mark.skipif(IS_HIP_EXTENSION, reason="Bulk overlap is not yet supported on ROCm.") @pytest.mark.parametrize( "comm_type, quantization, connections", [ diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 69b2af81c..05ffeccfc 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -173,7 +173,10 @@ set(transformer_engine_cuda_arch_specific_sources) # Source files in both cuda and rocm list(APPEND transformer_engine_cpp_sources - transformer_engine.cpp + transformer_engine.cpp + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/comm_gemm_overlap.cpp gemm/config.cpp normalization/common.cpp normalization/layernorm/ln_api.cpp @@ -231,10 +234,7 @@ if(USE_CUDA) list(APPEND transformer_engine_cpp_sources cudnn_utils.cpp fused_attn/fused_attn.cpp - util/cuda_nvml.cpp - 
comm_gemm_overlap/userbuffers/ipcsocket.cc - comm_gemm_overlap/userbuffers/userbuffers-host.cpp - comm_gemm_overlap/comm_gemm_overlap.cpp) + util/cuda_nvml.cpp) list(APPEND transformer_engine_cuda_sources transpose/quantize_transpose_vector_blockwise.cu fused_attn/fused_attn_f16_max512_seqlen.cu @@ -252,6 +252,7 @@ if(USE_CUDA) else() #ROCm specific source codes list(APPEND transformer_engine_cpp_sources + comm_gemm_overlap/rocm_comm_gemm_overlap.cpp fused_attn_rocm/fused_attn.cpp gemm/rocm_gemm.cu gemm/ck_grouped_gemm.cpp diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 1a0410898..a4e5c12ca 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -793,38 +793,6 @@ void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapp cudaMemcpyDeviceToDevice, stream)); } -void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, - bool local_chunk, bool rowwise) { - // Check element size - const size_t element_size = source.element_size(); - NVTE_CHECK(_ubuf.element_size() == element_size, - "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", - "(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), - " bytes)"); - - // Input data - const size_t source_size = source.numel(); - const void *src_ptr = (rowwise) ? source.dptr() : source.columnwise_dptr(); - - // Userbuffers data - void *dst_ptr; - if (local_chunk) { - NVTE_CHECK(_ubufs[_tp_id].numel() == source_size, - "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", - "(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); - dst_ptr = _ubufs[_tp_id].dptr(); - } else { - NVTE_CHECK(_ubuf.numel() == source_size, - "Tried to copy an invalid tensor into a Userbuffers buffer ", - "(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")"); - dst_ptr = _ubuf.dptr(); - } - - // Copy data - NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size, - cudaMemcpyDeviceToDevice, stream)); -} - TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, size_t chunk_id) { // Start with a chunk of the source tensor diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index fcccd396a..eb6b86e07 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -378,7 +378,7 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, NVTE_CHECK_CUDA(cudaMemset((*comm)->flags_baseptr, 0, 2 * GPU_PAGE_SIZE)); (*comm)->flags = reinterpret_cast( #ifdef __HIP_PLATFORM_AMD__ - (reinterpret_cast((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); + (reinterpret_cast((*comm)->flags_baseptr) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); #else ((CUdeviceptr)(*comm)->flags_baseptr + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); #endif diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 1eb958ef7..2d5d0cbd9 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ 
b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -7,17 +7,15 @@ ************************************************************************/ #include +#include +#include #ifdef __HIP_PLATFORM_AMD__ -#include #include -#include "amd_detail/hip_float8.h" #define half_dtype hip_bfloat16 #define __nv_fp8_e5m2 te_hip_fp8_e5m2 #define __nv_fp8_e4m3 te_hip_fp8_e4m3 #else -#include -#include #if __CUDA_ARCH__ >= 800 #define half_dtype nv_bfloat16 @@ -54,7 +52,7 @@ while (0 != (atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \ } \ ((unsigned int *)counters)[chunk] = 1; \ - __threadfence(); \ + __threadfence_system(); \ } \ if (blockIdx.x == 0) __syncthreads(); \ } @@ -157,7 +155,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) __threadfence(); + if (threadIdx.x == 0) __threadfence_system(); __syncthreads(); if (threadIdx.x < RANKS) { @@ -237,7 +235,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userptr[myrank][lineoffset + line] = sum; } __syncthreads(); - if (threadIdx.x == 0) __threadfence(); + if (threadIdx.x == 0) __threadfence_system(); __syncthreads(); if (threadIdx.x < RANKS) { @@ -502,7 +500,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) __threadfence(); + if (threadIdx.x == 0) __threadfence_system(); __syncthreads(); if (threadIdx.x < RANKS) { @@ -733,7 +731,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) __threadfence(); + if (threadIdx.x == 0) __threadfence_system(); __syncthreads(); __shared__ int lastSM; @@ -1053,7 +1051,7 @@ __global__ void __launch_bounds__(MAX_THREADS) #ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); #else - __threadfence(); + __threadfence_system(); #endif } } @@ -1148,7 +1146,7 @@ __global__ void __launch_bounds__(MAX_THREADS) #ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); #else - __threadfence(); + __threadfence_system(); #endif } } @@ -1362,7 +1360,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) __threadfence(); + if (threadIdx.x == 0) __threadfence_system(); __syncthreads(); __shared__ int lastSM; @@ -2098,7 +2096,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); if (threadIdx.x) return; - __threadfence(); + __threadfence_system(); atomicAdd_system(flagptr, 1); // otherwise need local SM sync before sending flag } else { // 0 bytes and 1 SM only @@ -2160,7 +2158,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); if (threadIdx.x) return; - __threadfence(); + __threadfence_system(); atomicAdd_system(send_flagptr, 1); // otherwise need local SM sync before sending flag } else { // 0 bytes and 1 SM only @@ -2218,7 +2216,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); if (threadIdx.x) return; - __threadfence(); + __threadfence_system(); atomicAdd_system(send_flagptr, 1); // otherwise need local SM sync before sending flag } else { // 0 bytes and 1 SM only @@ -2248,7 +2246,7 @@ __global__ void __launch_bounds__(MAX_THREADS) #ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); #else - __threadfence(); + __threadfence_system(); #endif } } @@ -2289,7 +2287,7 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat } __syncthreads(); if (!threadIdx.x) { - __threadfence(); + __threadfence_system(); atomicAdd_system(send_flagptr, 1); // otherwise need local SM sync before sending flag } @@ -2323,7 +2321,7 @@ __global__ void 
__launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat #ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); #else - __threadfence(); + __threadfence_system(); #endif } } @@ -2638,7 +2636,7 @@ static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) { #ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); #else - __threadfence(); + __threadfence_system(); #endif } @@ -2652,7 +2650,7 @@ static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) { #ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); #else - __threadfence(); + __threadfence_system(); #endif } } @@ -2668,7 +2666,7 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i #ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); #else - __threadfence(); + __threadfence_system(); #endif } } diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 3cda89ca1..3b280c98f 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -17,7 +17,6 @@ #include "common/comm_gemm_overlap/userbuffers/userbuffers.h" -#define NVTE_COMM_OVERLAP_MAX_STREAMS 7 namespace transformer_engine { @@ -39,7 +38,7 @@ enum class CommOverlapAlgo { ATOMIC_GEMM_RS = 5, ATOMIC_GEMM_AG_P2P = 6, ATOMIC_GEMM_RS_P2P = 7, - EXTERNAL_BULK_OVERLAP_AG = 8 + EXTERNAL_BULK_OVERLAP_AG = 8, }; class CommOverlapCore { @@ -195,7 +194,7 @@ class CommOverlapBase : public CommOverlapCore { CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, - int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int num_max_streams = NVTE_ROCM_MAX_RINGS, int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false, bool rs_overlap_first_gemm = false); @@ -262,7 +261,7 @@ class CommOverlapBase : public CommOverlapCore { TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) override { - NVTE_ERROR("Operation not supported."); + NVTE_ERROR("Operation not supported."); } }; // CommOverlapBase @@ -288,17 +287,13 @@ class CommOverlapP2PBase : public CommOverlapCore { void initialize(const std::vector &buffer_shape, DType buffer_dtype, CommOverlapType comm_type, bool aggregate); - private: - void initialize(const std::vector &buffer_shape, DType buffer_dtype, - CommOverlapType comm_type, bool aggregate); - public: CommOverlapP2PBase() {} // dummy constructor for exposing type to Python CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, - CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, + CommOverlapType comm_type, int num_max_streams = NVTE_ROCM_MAX_RINGS, int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); @@ -359,7 +354,7 @@ class CommOverlapP2PBase : public CommOverlapCore { cudaStream_t 
stream_main) override; /* ** ROCm Multiring ReduceScatter + GEMM - */ + */ void rocm_split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, @@ -368,7 +363,7 @@ class CommOverlapP2PBase : public CommOverlapCore { /* ** ROCm Multiring AllGather + GEMM - */ + */ void rocm_split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 78f0134d5..451d34adc 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -544,7 +544,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve public: CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, int num_splits = 3, - int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int num_max_streams = NVTE_ROCM_MAX_RINGS, int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false, bool rs_overlap_first_gemm = false); @@ -565,7 +565,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, transformer_engine::CommOverlapType comm_type, - int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int num_max_streams = NVTE_ROCM_MAX_RINGS, int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, bool aggregate = false); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 40cc16c23..8017ae17a 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -502,7 +502,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { int, int, int, int, bool, bool, bool>(), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), - py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, + py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_ROCM_MAX_RINGS, py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false) @@ -520,7 +520,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { bool>(), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"), - py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1, + py::arg("num_max_streams") = NVTE_ROCM_MAX_RINGS, py::arg("comm_cga_size") = 1, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) diff --git a/transformer_engine/pytorch/module/base.py 
b/transformer_engine/pytorch/module/base.py index b83547e8d..beb12995e 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1,3 +1,4 @@ +import time, random # This file was modified for portability to AMDGPU # Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. @@ -324,6 +325,8 @@ def initialize_ub( "fc2_wgrad", "proj_fprop", "fc2_fprop", + #"qkv_dgrad", + #"fc1_dgrad", "qkv_wgrad", "fc1_wgrad" ], @@ -482,9 +485,12 @@ def add_ub( comm_priority=comm_priority, rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, ) + time.sleep(random.uniform(0,1)) _ub_communicators[(name, quantization_mode)] = ub_obj + WORLD_RANK = int(os.getenv("RANK", "0")) for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs): + print(f"[rank{WORLD_RANK}] at the beginning of for-loop, user_ub_cfg: {user_ub_cfg}") if user_ub_cfg is not None: for name in dgrad_reduce_scatter_overlap: if ( @@ -509,6 +515,7 @@ def add_ub( "Please use 'ring_exchange' method instead." ) + print(f"[rank{WORLD_RANK}] before add_ub for-loop, methods: {methods}") for name in chain.from_iterable(methods.values()): ub_cfg = get_default_config(name) if user_ub_cfg is not None and name in user_ub_cfg: @@ -517,6 +524,7 @@ def add_ub( ) ub_cfg.update(user_ub_cfg[name]) ub_cfg["fp8_buf"] = fp8_buf + print(f"[rank{WORLD_RANK}] before add_ub, ub_cfg: {ub_cfg}") add_ub(name, quantization_mode, **ub_cfg) From 84209ad0affbd4d84d2a298f1bd19bb7dce759e7 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Mon, 16 Mar 2026 22:02:29 -0500 Subject: [PATCH 07/69] Cleanup and workspace fix --- .../distributed/run_layer_with_overlap.py | 10 ++++----- .../distributed/test_comm_gemm_overlap.py | 7 +++---- .../userbuffers/userbuffers.cu | 2 +- .../transformer_engine/comm_gemm_overlap.h | 2 ++ transformer_engine/pytorch/module/base.py | 21 ++++++++----------- 5 files changed, 19 insertions(+), 23 deletions(-) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 0c6df7f2e..bd431db58 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -248,7 +248,7 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "--debug", action="store_true", - default=True, + default=False, help="Print out additional debug information.", ) parser.add_argument( @@ -388,7 +388,7 @@ def _train(opts): " layers. Use --num-heads or --head-dim for other cases." 
) - def dist_print(msg, src=None, end="\n", debug=True, error=False): + def dist_print(msg, src=None, end="\n", debug=False, error=False): if debug and not opts.debug: return stream = sys.stderr if error else sys.stdout @@ -439,8 +439,6 @@ def dist_print(msg, src=None, end="\n", debug=True, error=False): ] if opts.first_last_layers_bf16 and opts.fp8: quantization_modes.append(te.module.base.UserBufferQuantizationMode.NONE) - - dist_print(f"Before initialize_ub, opts: {opts}") te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], @@ -565,8 +563,8 @@ def run_fwd_bwd(model, x): # Now validate accuracy if not bool(numerics_failed.item()): for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): - rtol = 0.125 if opts.fp8 else 0.025 - atol = 0.0625 if opts.fp8 else 0.00125 + rtol = 0.125 if opts.fp8 else 0.025 if not IS_HIP_EXTENSION else 5e-2 + atol = 0.0625 if opts.fp8 else 0.00125 if not IS_HIP_EXTENSION else 1e-2 grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) numerics_failed[0] = int(grad_failed) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index ffa40f0c4..9f79fe5ad 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -100,8 +100,8 @@ def _run_layer_with_overlap( layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers=1 ): # Skip BULK overlap tests on HIP (column parallel or None with overlap_rs_dgrad=False) - #if IS_HIP_EXTENSION and not overlap_rs_dgrad and linear_parallel_mode in ("column", None): - # pytest.skip("Bulk overlap is not yet supported on HIP/ROCm.") + if IS_HIP_EXTENSION and not overlap_rs_dgrad and linear_parallel_mode in ("column", None): + pytest.skip("Bulk overlap is not yet supported on HIP/ROCm.") test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -131,7 +131,6 @@ def _run_layer_with_overlap( os.environ["NVTE_TORCH_COMPILE"] = "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" - print("test_cmd: ", test_cmd) result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) os.unsetenv("PYTORCH_JIT") @@ -166,7 +165,7 @@ def test_split_reduce_scatter_overlaps(quantization, p2p): _run_gemm_with_overlap("RS", False, p2p, False, False, quantization) -#@pytest.mark.skipif(IS_HIP_EXTENSION, reason="Bulk overlap is not yet supported on ROCm.") +@pytest.mark.skipif(IS_HIP_EXTENSION, reason="Bulk overlap is not yet supported on ROCm.") @pytest.mark.parametrize( "comm_type, quantization, connections", [ diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 2d5d0cbd9..83e0859de 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -52,7 +52,7 @@ while (0 != (atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \ } \ ((unsigned int *)counters)[chunk] = 1; \ - __threadfence_system(); \ + __threadfence_system(); \ } \ if (blockIdx.x == 0) __syncthreads(); \ } diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 3b280c98f..96d104772 100644 --- 
a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -17,6 +17,8 @@ #include "common/comm_gemm_overlap/userbuffers/userbuffers.h" +#define NVTE_COMM_OVERLAP_MAX_STREAMS 3 + namespace transformer_engine { diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index beb12995e..61e812ac7 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1,4 +1,3 @@ -import time, random # This file was modified for portability to AMDGPU # Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. @@ -290,15 +289,18 @@ def initialize_ub( flush=True, ) - # Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls + # Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls. + # The workspace must have enough copies for max(num_max_streams, tp_size) compute streams, + # since CommOverlapCore creates that many streams and divides the workspace among them. + num_workspace_copies = max(_NUM_MAX_UB_STREAMS, tp_size) global _cublas_workspace if _cublas_workspace is None: - _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS) - elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS: + _cublas_workspace = get_workspace().repeat((num_workspace_copies if IS_HIP_EXTENSION else _NUM_MAX_UB_STREAMS)) + elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * (num_workspace_copies if IS_HIP_EXTENSION else _NUM_MAX_UB_STREAMS): # This ensures we don't do `.repeat()` on an already expanded workspace _cublas_workspace = torch.empty( get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda" - ).repeat(_NUM_MAX_UB_STREAMS) + ).repeat((num_workspace_copies if IS_HIP_EXTENSION else _NUM_MAX_UB_STREAMS)) # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe layers_all_gather_overlap = [ @@ -325,8 +327,8 @@ def initialize_ub( "fc2_wgrad", "proj_fprop", "fc2_fprop", - #"qkv_dgrad", - #"fc1_dgrad", + "qkv_dgrad", + "fc1_dgrad", "qkv_wgrad", "fc1_wgrad" ], @@ -485,12 +487,9 @@ def add_ub( comm_priority=comm_priority, rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, ) - time.sleep(random.uniform(0,1)) _ub_communicators[(name, quantization_mode)] = ub_obj - WORLD_RANK = int(os.getenv("RANK", "0")) for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs): - print(f"[rank{WORLD_RANK}] at the beginning of for-loop, user_ub_cfg: {user_ub_cfg}") if user_ub_cfg is not None: for name in dgrad_reduce_scatter_overlap: if ( @@ -515,7 +514,6 @@ def add_ub( "Please use 'ring_exchange' method instead." 
) - print(f"[rank{WORLD_RANK}] before add_ub for-loop, methods: {methods}") for name in chain.from_iterable(methods.values()): ub_cfg = get_default_config(name) if user_ub_cfg is not None and name in user_ub_cfg: @@ -524,7 +522,6 @@ def add_ub( ) ub_cfg.update(user_ub_cfg[name]) ub_cfg["fp8_buf"] = fp8_buf - print(f"[rank{WORLD_RANK}] before add_ub, ub_cfg: {ub_cfg}") add_ub(name, quantization_mode, **ub_cfg) From c669bd2f9225a5411def91e9d15f974a199da437 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Mon, 16 Mar 2026 22:34:03 -0500 Subject: [PATCH 08/69] Guard layer registration in UB --- transformer_engine/pytorch/module/base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 61e812ac7..64a094944 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -499,13 +499,17 @@ def add_ub( ): wgrad_name = name.replace("dgrad", "wgrad") assert wgrad_name not in user_ub_cfg - layers_reduce_scatter_overlap.remove(wgrad_name) - layers_all_gather_overlap.remove(name) - layers_reduce_scatter_overlap.append(name) + if wgrad_name in layers_reduce_scatter_overlap: + layers_reduce_scatter_overlap.remove(wgrad_name) + if name in layers_all_gather_overlap: + layers_all_gather_overlap.remove(name) + if name not in layers_reduce_scatter_overlap: + layers_reduce_scatter_overlap.append(name) if name in methods["bulk"]: methods["bulk"].remove(name) new_method = user_ub_cfg[name]["method"] - methods[new_method].append(name) + if name not in methods[new_method]: + methods[new_method].append(name) if IS_HIP_EXTENSION and user_ub_cfg is not None: for name, cfg in user_ub_cfg.items(): From 804090986f702a18db98da3398814313e49234b3 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 17 Mar 2026 00:55:59 -0500 Subject: [PATCH 09/69] Cleanup of profiling example for rocm --- .../te_layer_with_overlap.py | 6 + .../te_layer_with_overlap_profile.py | 504 ------------------ .../pytorch/comm_gemm_overlap/ub_config.json | 15 - 3 files changed, 6 insertions(+), 519 deletions(-) delete mode 100644 examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py delete mode 100644 examples/pytorch/comm_gemm_overlap/ub_config.json diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index 8b3fe542a..3a9b94fa7 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -1,5 +1,7 @@ #!/usr/bin/python3 +# This file was modified for portability to AMDGPU +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -20,6 +22,10 @@ import transformer_engine.pytorch.cpp_extensions as tex from transformer_engine.common.recipe import Format, DelayedScaling +from torch.utils.cpp_extension import IS_HIP_EXTENSION + +assert (not IS_HIP_EXTENSION), "Please use rocm_te_layer_overlap_profile.py with HIP." 
+ warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py deleted file mode 100644 index 02b9b9696..000000000 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py +++ /dev/null @@ -1,504 +0,0 @@ -#!/usr/bin/python3 - -# This file was modified for portability to AMDGPU -# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -import os -import sys -import socket -import fcntl -import struct -import argparse -import warnings - -import torch -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel - -import torch.profiler - -import transformer_engine.pytorch as te -import transformer_engine.pytorch.cpp_extensions as tex -from transformer_engine.common.recipe import Format, DelayedScaling - -warnings.filterwarnings("ignore", category=DeprecationWarning) -warnings.filterwarnings("ignore", category=FutureWarning) -warnings.filterwarnings("ignore", category=UserWarning) - -os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" -if not tex.device_supports_multicast(): - os.environ["UB_SKIPMC"] = "1" - - -def _te_layer_argtype(name): - te_layers = [ - te.Linear, - te.LayerNormLinear, - te.LayerNormMLP, - te.MultiheadAttention, - te.TransformerLayer, - ] - layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers)) - if name.lower() not in layer_map.keys(): - raise argparse.ArgumentTypeError( - f"Invalid TE layer name! Please choose from: {layer_map.keys()}" - ) - return layer_map[name.lower()] - - -def _parse_args(argv=None, namespace=None): - parser = argparse.ArgumentParser( - description="Train a Transformer Engine module with GEMM+comm overlap via Userbuffers." - ) - parser.add_argument( - "-i", "--num-iters", type=int, default=10, help="Number of dummy 'training' iterations." - ) - parser.add_argument("-b", "--batch-size", type=int, default=8, help="Input batch size.") - parser.add_argument("-s", "--seq-length", type=int, default=16384, help="Input sequence length.") - parser.add_argument( - "-n", "--num-heads", type=int, default=64, help="Number of attention heads." - ) - parser.add_argument( - "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." - ) - parser.add_argument( - "--layer-type", - type=_te_layer_argtype, - default=te.TransformerLayer, - help="Transformer Engine layer to train with comm+GEMM overlap.", - ) - parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") - parser.add_argument( - "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." - ) - parser.add_argument( - "--no-comm-overlap", - action="store_true", - default=False, - help="Disable the comm+GEMM overlap.", - ) - parser.add_argument( - "--num-replicas", type=int, default=1, help="Number of data-parallel model replicas." 
- ) - parser.add_argument( - "--tcp-init", - action="store_true", - default=False, - help="Initialize torch.distributed with TcpStore.", - ) - parser.add_argument( - "--bind-to-device", - action="store_true", - default=False, - help="Initialize torch.distributed with `device_id` to bind each rank to a single device.", - ) - parser.add_argument( - "--bootstrap-backend", - type=str.lower, - default="nccl", - choices=["gloo", "mpi", "nccl"], - help="Communications backend for host tensor collectives during Userbuffers bootstrapping.", - ) - parser.add_argument( - "-v", - "--verbose", - action="store_true", - default=False, - help="Print out from every rank instead of just the root rank of relevant process groups.", - ) - parser.add_argument( - "--debug", - action="store_true", - default=False, - help="Print out additional debug information.", - ) - parser.add_argument( - "--profile", - action="store_true", - default=False, - help="Enable PyTorch profiler.", - ) - parser.add_argument( - "--profile-dir", - type=str, - default="./logs/profiler_traces", - help="Directory to save PyTorch profiler traces.", - ) - parser.add_argument( - "--ub_config", - type=str, - default="./ub_config.json", - help="Userbuffer configuration file.", - ) - - args = parser.parse_args(argv, namespace) - if args.bootstrap_backend == "nccl": - args.bind_to_device = True - return args - - -def _get_layer_args(config, tp_group, tp_size, reference=False): - hidden_size = config.num_heads * config.head_dim - input_shape = [config.seq_length, config.batch_size, hidden_size] - args = [hidden_size] - kwargs = { - "params_dtype": torch.float32, - "device": "cuda", - "tp_group": tp_group, - "tp_size": tp_size, - "sequence_parallel": True, - } - kwargs["ub_overlap_ag"] = not config.no_comm_overlap - - if config.layer_type is te.Linear: - input_shape[2] = hidden_size // tp_size - args.append(hidden_size) - kwargs["parallel_mode"] = "row" - kwargs["ub_overlap_rs"] = not config.no_comm_overlap - kwargs["ub_name"] = "proj" - else: - input_shape[0] = config.seq_length // tp_size - if config.layer_type is te.LayerNormLinear: - args.append(3 * hidden_size) - kwargs["parallel_mode"] = "column" - kwargs["ub_name"] = "qkv" - else: - kwargs["set_parallel_mode"] = True - kwargs["ub_overlap_rs"] = not config.no_comm_overlap - if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: - # args.append(4 * hidden_size) - args.append(int(3.5 * hidden_size)) - - kwargs["seq_length"] = config.seq_length - if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: - args.append(config.num_heads) - kwargs["attention_dropout"] = 0.0 - kwargs["fuse_qkv_params"] = True - if config.layer_type is te.MultiheadAttention: - kwargs["input_layernorm"] = True - else: - kwargs["ub_tp_comm_overlap"] = not config.no_comm_overlap - kwargs["hidden_dropout"] = 0.0 - - return args, kwargs, input_shape - -def create_ub_cfgs(config_file: str, tp_size: int = 8): - import json - with open(config_file, 'r') as f: - data = json.load(f) - cfgs = {} - _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None - layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] - layers_all_gather_overlap = [ - "qkv_fprop", - "qkv_dgrad", - "proj_dgrad", - "proj_wgrad", - "fc1_fprop", - "fc1_dgrad", - "fc2_dgrad", - "fc2_wgrad", - ] - - for name, method in data.items(): - if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None: - _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range() - - cfg = { - "method": method, - 
"is_reduce_scatter": name in layers_reduce_scatter_overlap, - "num_sm": 1 if method == "ring_exchange" else 16, - "cga_size": 1 if method == "ring_exchange" else 2, - "set_sm_margin": False, - "num_splits": tp_size if method == "ring_exchange" else 4, - "aggregate": False, - "atomic_gemm": False, - "use_ce": True, - "fp8_buf": name in layers_all_gather_overlap, - "comm_priority": _MAX_STREAM_PRIORITY, - "gemm_priority": _MIN_STREAM_PRIORITY, - } - - cfgs[name] = cfg - - return cfgs - -def _train(opts): - if "OMPI_COMM_WORLD_SIZE" in os.environ: - # Execution with `mpirun -np N` - WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0")) - WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) - LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0")) - LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1")) - opts.tcp_init = True - opts.bind_to_device = True - opts.bootstrap_backend = "mpi" - elif "TORCHELASTIC_RUN_ID" in os.environ: - WORLD_RANK = int(os.getenv("RANK", "0")) - WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) - LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) - LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - else: - raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") - NUM_NODES = WORLD_SIZE // LOCAL_SIZE - - # Initialize torch.distributed global process group and get DP/TP groups - torch.cuda.set_device(LOCAL_RANK) - dist_init_kwargs = { - "backend": "nccl", - "rank": WORLD_RANK, - "world_size": WORLD_SIZE, - } - if opts.tcp_init or NUM_NODES > 1: - if NUM_NODES > 1: - assert ( - "MASTER_ADDR" in os.environ - ), "Multi-node run requires MASTER_ADDR to be set in the environment." - MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname())) - MASTER_PORT = os.getenv("MASTER_PORT", "1234") - dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}" - if opts.bind_to_device or opts.bootstrap_backend == "nccl": - dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") - assert dist.is_nccl_available() - dist.init_process_group(**dist_init_kwargs) - nccl_world = dist.new_group(backend="nccl") - - def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False): - if debug and not opts.debug: - return - group_rank = dist.get_rank(group) - stream = sys.stderr if error else sys.stdout - if group_rank == src: - stream.write(f"[rank{WORLD_RANK}] {msg}{end}") - dist.barrier(group) - - dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") - - # Figure out process groups for tensor- and data-parallelism (if any) - if NUM_NODES > 1: - # Create a list of world ranks on this node - hostname = socket.gethostname() - ifname = os.getenv( - "NVTE_UB_SOCKET_IFNAME", - os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), - ) - - if ifname is not None: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - hostname = socket.inet_ntoa( - fcntl.ioctl( - s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) - )[20:24] - ) - except OSError as err: - raise OSError(f"Invalid network interface: {ifname}") from err - - hostnames = [None for _ in range(WORLD_SIZE)] - dist.all_gather_object(hostnames, hostname) - unique_hosts = [] - for host in hostnames: - if host not in unique_hosts: - unique_hosts.append(host) - assert len(unique_hosts) == NUM_NODES - - ranks_per_node_list = [[] for _ in range(NUM_NODES)] - self_node_idx = -1 - for i, host in enumerate(hostnames): - node_idx = unique_hosts.index(host) - 
ranks_per_node_list[node_idx].append(i) - if host == hostname: - self_node_idx = node_idx - assert self_node_idx >= 0 - self_node_ranks = ranks_per_node_list[self_node_idx] - - if opts.num_replicas > 1: - # Split node ranks into multiple replicas - assert len(self_node_ranks) % opts.num_replicas == 0 - tp_size = len(self_node_ranks) // opts.num_replicas - ranks_per_replica_list = [] - for node_ranks in ranks_per_node_list: - for i in range(opts.num_replicas): - start = i * tp_size - end = start + tp_size - ranks_per_replica_list.append(node_ranks[start:end]) - - self_replica_idx = -1 - for i, replica_ranks in enumerate(ranks_per_replica_list): - if WORLD_RANK in replica_ranks: - self_replica_idx = i - break - assert self_replica_idx >= 0 - - else: - # The entire node is the tensor-parallel group - ranks_per_replica_list = ranks_per_node_list - self_replica_idx = self_node_idx - - tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl") - ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32) - dp_group, _ = dist.new_subgroups_by_enumeration( - ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" - ) - - else: - if opts.num_replicas > 1: - # Mixed data- and tensor-parallelism on a single node - # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions - all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu") - ranks_per_replica_tensor = all_ranks.reshape( - (opts.num_replicas, LOCAL_SIZE // opts.num_replicas) - ) - tp_group, _ = dist.new_subgroups_by_enumeration( - ranks_per_replica_tensor.tolist(), backend="nccl" - ) - dp_group, _ = dist.new_subgroups_by_enumeration( - ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" - ) - else: - dp_group = None - tp_group = nccl_world - - tp_rank = dist.get_rank(tp_group) - tp_size = dist.get_world_size(tp_group) - dist_print( - f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}", - group=tp_group, - ) - if dp_group is not None: - dp_rank = dist.get_rank(dp_group) - dist_print( - f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}", - group=dp_group, - ) - else: - dp_rank = 0 - - # Intialize userbuffers - hidden_size = opts.num_heads * opts.head_dim - batched_size = opts.seq_length * opts.batch_size - if not opts.no_comm_overlap: - te.module.base.initialize_ub( - [batched_size, hidden_size], - tp_size, - use_fp8=opts.fp8, - dtype=torch.bfloat16, - bootstrap_backend=opts.bootstrap_backend, - ub_cfgs=create_ub_cfgs(opts.ub_config, tp_size) - ) - # Initialize the fused LayerNorm + Multi-layer Perceptron module - torch.manual_seed(opts.seed + dp_rank) - torch.cuda.manual_seed(opts.seed + tp_rank) - layer_args, layer_kwargs, input_size = _get_layer_args(opts, tp_group, tp_size) - model = opts.layer_type(*layer_args, **layer_kwargs) - if dp_group is not None: - model = DistributedDataParallel(model, dim=1, process_group=dp_group) - - # Initialize optimizer with model parameters - optim = torch.optim.Adam(model.parameters(), lr=0.0001) - - # Fp8 recipe setup - fp8_format = Format.HYBRID - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") - - if opts.profile: - log_dir = os.path.join(opts.profile_dir, f"rank_{WORLD_RANK}") - os.makedirs(log_dir, exist_ok=True) - dist_print(f"Profiler traces will be saved to: {log_dir}", group=nccl_world) - - schedule = torch.profiler.schedule(wait=1, warmup=2, active=5, repeat=1) - - on_trace_ready = 
torch.profiler.tensorboard_trace_handler( - log_dir, worker_name=f"rank_{WORLD_RANK}" - ) - - profiler_activities = [ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ] - import time - - start_time = time.time() - with torch.profiler.profile( - schedule=schedule, - # record_shapes=True, - # with_stack=True, - # with_flops=True, - # with_modules=True, - on_trace_ready=on_trace_ready, - profile_memory=True, - activities=profiler_activities, - ) as prof: - dist_print("Starting training iterations...") - for i in range(opts.num_iters): - dist_print(f" Iter {i+1}", group=tp_group, debug=True) - - dist_print(" |-- Generate random input batch", group=tp_group, debug=True) - x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True) - - dist_print(" |-- Forward pass", group=tp_group, debug=True) - with torch.amp.autocast("cuda", dtype=torch.bfloat16): - with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): - y = model(x) - if isinstance(y, tuple): - out, *_ = y - else: - out = y - dist_print(" |-- Compute loss", group=tp_group, debug=True) - loss = out.sum() - - dist_print(" |-- Backward pass", group=tp_group, debug=True) - loss.backward() - - dist_print(" |-- Optimizer step", group=tp_group, debug=True) - optim.step() - - prof.step() - torch.cuda.synchronize() - end_time = time.time() - total_wall_clock_time = end_time - start_time - print(f"Total Wall Clock Time: {total_wall_clock_time:.4f} seconds") - # total_flops = sum([item.flops for item in prof.key_averages()]) - # print(f"Total FLOPs: {total_flops}") - else: - dist_print("Starting training iterations...") - for i in range(opts.num_iters): - dist_print(f" Iter {i+1}", group=tp_group, debug=True) - - dist_print(" |-- Generate random input batch", group=tp_group, debug=True) - x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True) - - dist_print(" |-- Forward pass", group=tp_group, debug=True) - with torch.amp.autocast("cuda", dtype=torch.bfloat16): - with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): - y = model(x) - if isinstance(y, tuple): - out, *_ = y - else: - out = y - dist_print(" |-- Compute loss", group=tp_group, debug=True) - loss = out.sum() - - dist_print(" |-- Backward pass", group=tp_group, debug=True) - loss.backward() - - dist_print(" |-- Optimizer step", group=tp_group, debug=True) - optim.step() - - - dist_print("Finished training!") - te.module.base.destroy_ub() - - dist_print("Destroying all process groups...", debug=True) - dist.destroy_process_group() - if opts.debug and WORLD_RANK == 0: - print("Exiting...\n", end="", flush=True) - - return 0 - - -if __name__ == "__main__": - sys.exit(_train(_parse_args())) \ No newline at end of file diff --git a/examples/pytorch/comm_gemm_overlap/ub_config.json b/examples/pytorch/comm_gemm_overlap/ub_config.json deleted file mode 100644 index a26c7f9f1..000000000 --- a/examples/pytorch/comm_gemm_overlap/ub_config.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "qkv_fprop": "ring_exchange", - "fc1_fprop": "ring_exchange", - "fc2_dgrad": "ring_exchange", - "proj_wgrad": "ring_exchange", - "fc2_wgrad": "ring_exchange", - - - "proj_fprop": "ring_exchange", - "fc2_fprop": "ring_exchange", - - "qkv_dgrad": "ring_exchange", - "fc1_dgrad": "ring_exchange" - -} \ No newline at end of file From e37592309411a999734429eb4fde46366c26a57a Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 17 Mar 2026 10:50:49 -0500 Subject: [PATCH 10/69] Readd example 
script and update custom_map --- build_tools/hipify/custom_map.json | 3 +- .../rocm_te_layer_with_overlap.py | 453 ++++++++++++++++++ 2 files changed, 454 insertions(+), 2 deletions(-) create mode 100644 examples/pytorch/comm_gemm_overlap/rocm_te_layer_with_overlap.py diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index a38306eb1..13d9d5d7b 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -1,7 +1,6 @@ { "custom_map" : { "" : "", - "util/cuda_runtime.h" : "util/hip_runtime.h", "" : "\"common/amd_detail/hip_float8.h\"", "" : "", "cuda_runtime.h\"" : "hip_runtime.h\"", @@ -17,7 +16,7 @@ "__nv_fp4x2_storage_t" : "__hip_fp4x2_storage_t", "#include " : "", "#include " : "", - "cudaLaunchKernel": "hipLaunchKernel", + "cudaLaunchKernel": "hipLaunchKernelExC", "CUmemGenericAllocationHandle": "hipMemGenericAllocationHandle_t", "cudaLaunchConfig_t": "hipLaunchConfig_t", "cudaLaunchAttribute": "hipLaunchAttribute", diff --git a/examples/pytorch/comm_gemm_overlap/rocm_te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/rocm_te_layer_with_overlap.py new file mode 100644 index 000000000..b8c6d5cbd --- /dev/null +++ b/examples/pytorch/comm_gemm_overlap/rocm_te_layer_with_overlap.py @@ -0,0 +1,453 @@ +#!/usr/bin/python3 + +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. +# License for AMD contributions = MIT. See LICENSE for more information + +import os +import sys +import socket +import fcntl +import struct +import argparse +import warnings + +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel + +import torch.profiler + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.cpp_extensions as tex +from transformer_engine.common.recipe import Format, DelayedScaling + +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +if not tex.device_supports_multicast(): + os.environ["UB_SKIPMC"] = "1" + + +def _te_layer_argtype(name): + te_layers = [ + te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, + te.MultiheadAttention, + te.TransformerLayer, + ] + layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers)) + if name.lower() not in layer_map.keys(): + raise argparse.ArgumentTypeError( + f"Invalid TE layer name! Please choose from: {layer_map.keys()}" + ) + return layer_map[name.lower()] + + +def _parse_args(argv=None, namespace=None): + parser = argparse.ArgumentParser( + description="Train a Transformer Engine module with GEMM+comm overlap via Userbuffers." + ) + parser.add_argument( + "-i", "--num-iters", type=int, default=10, help="Number of dummy 'training' iterations." + ) + parser.add_argument("-b", "--batch-size", type=int, default=8, help="Input batch size.") + parser.add_argument("-s", "--seq-length", type=int, default=16384, help="Input sequence length.") + parser.add_argument( + "-n", "--num-heads", type=int, default=64, help="Number of attention heads." + ) + parser.add_argument( + "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." 
+ ) + parser.add_argument( + "--layer-type", + type=_te_layer_argtype, + default=te.TransformerLayer, + help="Transformer Engine layer to train with comm+GEMM overlap.", + ) + parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") + parser.add_argument( + "--fp8", action="store_true", default=False, help="Enables the te.autocast() context." + ) + parser.add_argument( + "--no-comm-overlap", + action="store_true", + default=False, + help="Disable the comm+GEMM overlap.", + ) + parser.add_argument( + "--num-replicas", type=int, default=1, help="Number of data-parallel model replicas." + ) + parser.add_argument( + "--tcp-init", + action="store_true", + default=False, + help="Initialize torch.distributed with TcpStore.", + ) + parser.add_argument( + "--bind-to-device", + action="store_true", + default=False, + help="Initialize torch.distributed with `device_id` to bind each rank to a single device.", + ) + parser.add_argument( + "--bootstrap-backend", + type=str.lower, + default="nccl", + choices=["gloo", "mpi", "nccl"], + help="Communications backend for host tensor collectives during Userbuffers bootstrapping.", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + help="Print out from every rank instead of just the root rank of relevant process groups.", + ) + parser.add_argument( + "--debug", + action="store_true", + default=False, + help="Print out additional debug information.", + ) + parser.add_argument( + "--profile", + action="store_true", + default=False, + help="Enable PyTorch profiler.", + ) + parser.add_argument( + "--profile-dir", + type=str, + default="./logs/profiler_traces", + help="Directory to save PyTorch profiler traces.", + ) + args = parser.parse_args(argv, namespace) + if args.bootstrap_backend == "nccl": + args.bind_to_device = True + return args + + +def _get_layer_args(config, tp_group, tp_size, reference=False): + hidden_size = config.num_heads * config.head_dim + input_shape = [config.seq_length, config.batch_size, hidden_size] + args = [hidden_size] + kwargs = { + "params_dtype": torch.float32, + "device": "cuda", + "tp_group": tp_group, + "tp_size": tp_size, + "sequence_parallel": True, + } + kwargs["ub_overlap_ag"] = not config.no_comm_overlap + + if config.layer_type is te.Linear: + input_shape[2] = hidden_size // tp_size + args.append(hidden_size) + kwargs["parallel_mode"] = "row" + kwargs["ub_overlap_rs"] = not config.no_comm_overlap + kwargs["ub_name"] = "proj" + else: + input_shape[0] = config.seq_length // tp_size + if config.layer_type is te.LayerNormLinear: + args.append(3 * hidden_size) + kwargs["parallel_mode"] = "column" + kwargs["ub_name"] = "qkv" + else: + kwargs["set_parallel_mode"] = True + kwargs["ub_overlap_rs"] = not config.no_comm_overlap + if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: + # args.append(4 * hidden_size) + args.append(int(3.5 * hidden_size)) + + kwargs["seq_length"] = config.seq_length + if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + args.append(config.num_heads) + kwargs["attention_dropout"] = 0.0 + kwargs["fuse_qkv_params"] = True + if config.layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + else: + kwargs["ub_tp_comm_overlap"] = not config.no_comm_overlap + kwargs["hidden_dropout"] = 0.0 + + return args, kwargs, input_shape + + +def _train(opts): + if "OMPI_COMM_WORLD_SIZE" in os.environ: + # Execution with `mpirun -np N` + WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0")) + WORLD_SIZE = 
int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1")) + opts.tcp_init = True + opts.bind_to_device = True + opts.bootstrap_backend = "mpi" + elif "TORCHELASTIC_RUN_ID" in os.environ: + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + else: + raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") + NUM_NODES = WORLD_SIZE // LOCAL_SIZE + + # Initialize torch.distributed global process group and get DP/TP groups + torch.cuda.set_device(LOCAL_RANK) + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + if opts.tcp_init or NUM_NODES > 1: + if NUM_NODES > 1: + assert ( + "MASTER_ADDR" in os.environ + ), "Multi-node run requires MASTER_ADDR to be set in the environment." + MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname())) + MASTER_PORT = os.getenv("MASTER_PORT", "1234") + dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}" + if opts.bind_to_device or opts.bootstrap_backend == "nccl": + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + dist.init_process_group(**dist_init_kwargs) + nccl_world = dist.new_group(backend="nccl") + + def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False): + if debug and not opts.debug: + return + group_rank = dist.get_rank(group) + stream = sys.stderr if error else sys.stdout + if group_rank == src: + stream.write(f"[rank{WORLD_RANK}] {msg}{end}") + dist.barrier(group) + + dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") + + # Figure out process groups for tensor- and data-parallelism (if any) + if NUM_NODES > 1: + # Create a list of world ranks on this node + hostname = socket.gethostname() + ifname = os.getenv( + "NVTE_UB_SOCKET_IFNAME", + os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), + ) + + if ifname is not None: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + hostname = socket.inet_ntoa( + fcntl.ioctl( + s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) + )[20:24] + ) + except OSError as err: + raise OSError(f"Invalid network interface: {ifname}") from err + + hostnames = [None for _ in range(WORLD_SIZE)] + dist.all_gather_object(hostnames, hostname) + unique_hosts = [] + for host in hostnames: + if host not in unique_hosts: + unique_hosts.append(host) + assert len(unique_hosts) == NUM_NODES + + ranks_per_node_list = [[] for _ in range(NUM_NODES)] + self_node_idx = -1 + for i, host in enumerate(hostnames): + node_idx = unique_hosts.index(host) + ranks_per_node_list[node_idx].append(i) + if host == hostname: + self_node_idx = node_idx + assert self_node_idx >= 0 + self_node_ranks = ranks_per_node_list[self_node_idx] + + if opts.num_replicas > 1: + # Split node ranks into multiple replicas + assert len(self_node_ranks) % opts.num_replicas == 0 + tp_size = len(self_node_ranks) // opts.num_replicas + ranks_per_replica_list = [] + for node_ranks in ranks_per_node_list: + for i in range(opts.num_replicas): + start = i * tp_size + end = start + tp_size + ranks_per_replica_list.append(node_ranks[start:end]) + + self_replica_idx = -1 + for i, replica_ranks in enumerate(ranks_per_replica_list): + if WORLD_RANK in replica_ranks: 
+ self_replica_idx = i + break + assert self_replica_idx >= 0 + + else: + # The entire node is the tensor-parallel group + ranks_per_replica_list = ranks_per_node_list + self_replica_idx = self_node_idx + + tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl") + ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32) + dp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" + ) + + else: + if opts.num_replicas > 1: + # Mixed data- and tensor-parallelism on a single node + # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions + all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu") + ranks_per_replica_tensor = all_ranks.reshape( + (opts.num_replicas, LOCAL_SIZE // opts.num_replicas) + ) + tp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.tolist(), backend="nccl" + ) + dp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" + ) + else: + dp_group = None + tp_group = nccl_world + + tp_rank = dist.get_rank(tp_group) + tp_size = dist.get_world_size(tp_group) + dist_print( + f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}", + group=tp_group, + ) + if dp_group is not None: + dp_rank = dist.get_rank(dp_group) + dist_print( + f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}", + group=dp_group, + ) + else: + dp_rank = 0 + + # Intialize userbuffers + hidden_size = opts.num_heads * opts.head_dim + batched_size = opts.seq_length * opts.batch_size + if not opts.no_comm_overlap: + te.module.base.initialize_ub( + [batched_size, hidden_size], + tp_size, + use_fp8=opts.fp8, + dtype=torch.bfloat16, + bootstrap_backend=opts.bootstrap_backend, + ) + # Initialize the fused LayerNorm + Multi-layer Perceptron module + torch.manual_seed(opts.seed + dp_rank) + torch.cuda.manual_seed(opts.seed + tp_rank) + layer_args, layer_kwargs, input_size = _get_layer_args(opts, tp_group, tp_size) + model = opts.layer_type(*layer_args, **layer_kwargs) + if dp_group is not None: + model = DistributedDataParallel(model, dim=1, process_group=dp_group) + + # Initialize optimizer with model parameters + optim = torch.optim.Adam(model.parameters(), lr=0.0001) + + # Fp8 recipe setup + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + + if opts.profile: + log_dir = os.path.join(opts.profile_dir, f"rank_{WORLD_RANK}") + os.makedirs(log_dir, exist_ok=True) + dist_print(f"Profiler traces will be saved to: {log_dir}", group=nccl_world) + + schedule = torch.profiler.schedule(wait=1, warmup=2, active=5, repeat=1) + + on_trace_ready = torch.profiler.tensorboard_trace_handler( + log_dir, worker_name=f"rank_{WORLD_RANK}" + ) + + profiler_activities = [ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + import time + + start_time = time.time() + with torch.profiler.profile( + schedule=schedule, + # record_shapes=True, + # with_stack=True, + # with_flops=True, + # with_modules=True, + on_trace_ready=on_trace_ready, + profile_memory=True, + activities=profiler_activities, + ) as prof: + dist_print("Starting training iterations...") + for i in range(opts.num_iters): + dist_print(f" Iter {i+1}", group=tp_group, debug=True) + + dist_print(" |-- Generate random input batch", group=tp_group, debug=True) + x = torch.randn(input_size, 
dtype=torch.float32, device="cuda", requires_grad=True) + + dist_print(" |-- Forward pass", group=tp_group, debug=True) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + with te.autocast(enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world): + y = model(x) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + dist_print(" |-- Compute loss", group=tp_group, debug=True) + loss = out.sum() + + dist_print(" |-- Backward pass", group=tp_group, debug=True) + loss.backward() + + dist_print(" |-- Optimizer step", group=tp_group, debug=True) + optim.step() + + prof.step() + torch.cuda.synchronize() + end_time = time.time() + total_wall_clock_time = end_time - start_time + print(f"Total Wall Clock Time: {total_wall_clock_time:.4f} seconds") + # total_flops = sum([item.flops for item in prof.key_averages()]) + # print(f"Total FLOPs: {total_flops}") + else: + dist_print("Starting training iterations...") + for i in range(opts.num_iters): + dist_print(f" Iter {i+1}", group=tp_group, debug=True) + + dist_print(" |-- Generate random input batch", group=tp_group, debug=True) + x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True) + + dist_print(" |-- Forward pass", group=tp_group, debug=True) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + with te.autocast(enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world): + y = model(x) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + dist_print(" |-- Compute loss", group=tp_group, debug=True) + loss = out.sum() + + dist_print(" |-- Backward pass", group=tp_group, debug=True) + loss.backward() + + dist_print(" |-- Optimizer step", group=tp_group, debug=True) + optim.step() + + + dist_print("Finished training!") + te.module.base.destroy_ub() + + dist_print("Destroying all process groups...", debug=True) + dist.destroy_process_group() + if opts.debug and WORLD_RANK == 0: + print("Exiting...\n", end="", flush=True) + + return 0 + + +if __name__ == "__main__": + sys.exit(_train(_parse_args())) \ No newline at end of file From c6bd974515abc966e4439b5b4b49e10076f61875 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 17 Mar 2026 10:51:56 -0500 Subject: [PATCH 11/69] fix typo --- build_tools/hipify/custom_map.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index 13d9d5d7b..74c8e0dd4 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -16,7 +16,7 @@ "__nv_fp4x2_storage_t" : "__hip_fp4x2_storage_t", "#include " : "", "#include " : "", - "cudaLaunchKernel": "hipLaunchKernelExC", + "cudaLaunchKernelExC": "hipLaunchKernelExC", "CUmemGenericAllocationHandle": "hipMemGenericAllocationHandle_t", "cudaLaunchConfig_t": "hipLaunchConfig_t", "cudaLaunchAttribute": "hipLaunchAttribute", From d76aa06f0c512add7f9d543f6b786116845f1ffc Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Wed, 18 Mar 2026 00:52:05 -0500 Subject: [PATCH 12/69] MI300 test skips due to jittery results --- tests/pytorch/distributed/test_comm_gemm_overlap.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 9f79fe5ad..23f51f56f 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -13,6 +13,7 @@ import transformer_engine.pytorch.cpp_extensions as tex from torch.utils.cpp_extension import 
IS_HIP_EXTENSION +from transformer_engine.pytorch.utils import get_device_compute_capability if torch.cuda.device_count() < 2: @@ -102,6 +103,12 @@ def _run_layer_with_overlap( # Skip BULK overlap tests on HIP (column parallel or None with overlap_rs_dgrad=False) if IS_HIP_EXTENSION and not overlap_rs_dgrad and linear_parallel_mode in ("column", None): pytest.skip("Bulk overlap is not yet supported on HIP/ROCm.") + # On gfx942, non-determinism across the 8 XCDs causes small jitter that compounds + # This should not affect training convergence, but creates larger numerical differences. + if (IS_HIP_EXTENSION + and get_device_compute_capability() < (9, 5) + and layer_type == te.TransformerLayer.__name__): + pytest.skip("TransformerLayer overlap can exceed numerical tolerance on pre-MI350 due to jitter.") test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), From ae979d05c9b50a16b45edd25e73d14d473cda9f3 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Wed, 18 Mar 2026 00:54:48 -0500 Subject: [PATCH 13/69] Comment regarding sm_margin performance --- transformer_engine/pytorch/module/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 64a094944..67d2b87c6 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -373,7 +373,7 @@ def get_default_config(name): "is_reduce_scatter": is_reduce_scatter, "num_sm": 1 if method == "ring_exchange" else 16, "cga_size": 1 if method == "ring_exchange" else 2, - "set_sm_margin": not method == "ring_exchange" and not IS_HIP_EXTENSION, + "set_sm_margin": not method == "ring_exchange" and not IS_HIP_EXTENSION, # Default set to False for HIP for performance "num_splits": tp_size if method == "ring_exchange" else 4, "aggregate": False, "atomic_gemm": False, From b58cbd19e7b5753f9edb27993ac1ae23dbc698a6 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Mon, 23 Mar 2026 13:43:48 -0500 Subject: [PATCH 14/69] Variable renamed, pybind fix, tolerance tightening --- build_tools/hipify/custom_map.json | 4 +- .../distributed/run_layer_with_overlap.py | 6 +- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 6 - .../userbuffers/userbuffers.cu | 4 +- .../userbuffers/userbuffers.h | 6 +- .../common/util/pybind_helper.h | 107 +++++++++--------- transformer_engine/pytorch/module/base.py | 11 +- .../pytorch/ops/fused/__init__.py | 2 - 8 files changed, 73 insertions(+), 73 deletions(-) diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index 74c8e0dd4..1244d2e09 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -8,7 +8,9 @@ "CUfunc_cache" : "hipFuncCache_t", "" : "", "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)", - "__nv_bfloat162":"__hip_bfloat162", + "__nv_bfloat162" : "__hip_bfloat162", + "__nv_fp8_e5m2" : "te_hip_fp8_e5m2", + "__nv_fp8_e4m3" : "te_hip_fp8_e4m3", "cuda::getCurrentCUDAStream" : "hip::getCurrentHIPStreamMasqueradingAsCUDA", "__nv_fp4_e2m1" : "__hip_fp4_e2m1", "__nv_fp4x2_e2m1" : "__hip_fp4x2_e2m1", diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index bd431db58..067c6ba63 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -330,7 +330,11 @@ def _compare_tensors(name, test, ref, rtol, atol): ) if abs_err <= atol: numerics_info += f" abs. 
error = {abs_err} (tol = {atol})" + rel_diffs = diff / torch.clamp(torch.abs(ref.flatten()), min=1e-5) + failed_mask = (diff > atol) & (rel_diffs > rtol) + num_actually_failing = failed_mask.sum().item() + numerics_info += f"\nElements violating both atol and rtol: {num_actually_failing} out of" return numerics_failed, numerics_info @@ -563,7 +567,7 @@ def run_fwd_bwd(model, x): # Now validate accuracy if not bool(numerics_failed.item()): for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): - rtol = 0.125 if opts.fp8 else 0.025 if not IS_HIP_EXTENSION else 5e-2 + rtol = 0.125 if opts.fp8 else 0.025 if not IS_HIP_EXTENSION else 3e-2 atol = 0.0625 if opts.fp8 else 0.00125 if not IS_HIP_EXTENSION else 1e-2 grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index a4e5c12ca..241f26af6 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -23,12 +23,6 @@ #define HALF_BYTES 2 #define UB_MAX_SM 32 -#ifdef __HIP_PLATFORM_AMD__ -#define half_dtype hip_bfloat16 -#define __nv_fp8_e5m2 te_hip_fp8_e5m2 -#define __nv_fp8_e4m3 te_hip_fp8_e4m3 -#endif - using namespace std::placeholders; namespace transformer_engine { diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 83e0859de..b569378d9 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -13,8 +13,6 @@ #ifdef __HIP_PLATFORM_AMD__ #include #define half_dtype hip_bfloat16 -#define __nv_fp8_e5m2 te_hip_fp8_e5m2 -#define __nv_fp8_e4m3 te_hip_fp8_e4m3 #else #if __CUDA_ARCH__ >= 800 @@ -2339,7 +2337,7 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat // Return TRUE if two ranks share the same NV domain #define INTRANODE(peer) ((peer / comm->nvsize) == (comm->myrank / comm->nvsize)) -#ifndef __HIP_PLATFORM_AMD__ // Moved to header for visibility +#ifndef __HIP_PLATFORM_AMD__ // Moved to userbuffers.h for visibility // Index corresponds to the type of flag: // 0 - Send index counter // 1 - CE start index counter diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index d897a0f37..ddcd6689a 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -38,9 +38,13 @@ using ExtBarrierOp = std::function; #define NVTE_LAUNCH_CPU 2 #define NVTE_MAX_NVLINK 32 +#ifdef __HIP_PLATFORM_AMD__ #define NVTE_ROCM_MAX_TP_SIZE 8 // Maximum # of rings possible for ring_exchange #define NVTE_ROCM_MAX_RINGS (NVTE_ROCM_MAX_TP_SIZE - 1) +#else +#define NVTE_ROCM_MAX_RINGS 1 +#endif #define NVTE_UB_MEM_UC_CONTIG 1 #define NVTE_UB_MEM_MC_CREATED 2 @@ -69,7 +73,7 @@ using ExtBarrierOp = std::function; #define NVTE_HF_NVREDUCEDONE (userbuffers_op_types + 3) #define NVTE_MAX_SHARP 16 -#ifdef __HIP_PLATFORM_AMD__ // Moved to header for visibility +#ifdef __HIP_PLATFORM_AMD__ // Moved from userbuffers.cu for visibility // Index corresponds to the type of flag: // 0 - Send index counter // 1 - CE 
start index counter diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 456d2ea50..68a8d963d 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -32,60 +32,6 @@ .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); #endif -#define NVTE_DECLARE_COMM_OVERLAP_HANDLES(m) \ - pybind11::enum_(m, "CommOverlapType", \ - pybind11::module_local()) \ - .value("RS", transformer_engine::CommOverlapType::RS) \ - .value("AG", transformer_engine::CommOverlapType::AG); \ - pybind11::enum_(m, "CommOverlapAlgo", \ - pybind11::module_local()) \ - .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ - .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ - .value("SPLIT_PIPELINED_AG_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ - .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ - .value("SPLIT_PIPELINED_RS_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ - .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ - .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ - .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \ - .value("EXTERNAL_BULK_OVERLAP_AG", \ - transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG); \ - py::class_>(m, "CommOverlapCore", \ - pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \ - py::call_guard()) \ - .def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \ - py::call_guard()) \ - .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ - py::call_guard()) \ - .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ - py::call_guard()); \ - py::class_, \ - transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ - py::call_guard()); \ - py::class_, \ - transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \ - pybind11::module_local()) \ - .def(py::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ - py::call_guard()); \ - m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ - py::call_guard(), py::arg("device_id") = -1); \ - m.def( \ - "get_stream_priority_range", \ - [](int device_id = -1) { \ - int low_pri, high_pri; \ - transformer_engine::cuda::stream_priority_range(&low_pri, &high_pri, device_id); \ - return std::make_pair(low_pri, high_pri); \ - }, \ - py::call_guard(), py::arg("device_id") = -1); \ - m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ - py::call_guard()); - #define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ pybind11::enum_(m, "DType", pybind11::module_local()) \ .value("kByte", transformer_engine::DType::kByte) \ @@ -147,11 +93,62 @@ .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \ + pybind11::enum_(m, "CommOverlapType", \ + pybind11::module_local()) \ + .value("RS", transformer_engine::CommOverlapType::RS) \ + .value("AG", transformer_engine::CommOverlapType::AG); \ + 
pybind11::enum_(m, "CommOverlapAlgo", \ + pybind11::module_local()) \ + .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ + .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ + .value("SPLIT_PIPELINED_AG_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ + .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ + .value("SPLIT_PIPELINED_RS_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ + .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ + .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \ + .value("EXTERNAL_BULK_OVERLAP_AG", \ + transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG); \ + py::class_>(m, "CommOverlapCore", \ + pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \ + py::call_guard()) \ + .def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \ + py::call_guard()) \ + .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ + py::call_guard()) \ + .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ + py::call_guard()); \ + py::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ + py::call_guard()); \ + py::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \ + pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ + py::call_guard()); \ + m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ + py::call_guard(), py::arg("device_id") = -1); \ + m.def( \ + "get_stream_priority_range", \ + [](int device_id = -1) { \ + int low_pri, high_pri; \ + transformer_engine::cuda::stream_priority_range(&low_pri, &high_pri, device_id); \ + return std::make_pair(low_pri, high_pri); \ + }, \ + py::call_guard(), py::arg("device_id") = -1); \ + m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + py::call_guard()); \ NVTE_DECLARE_FUSED_ATTENTION_HANDLES(m) \ pybind11::enum_( \ m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \ .value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \ .value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \ - NVTE_DECLARE_COMM_OVERLAP_HANDLES(m) #endif diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 67d2b87c6..a25dc0009 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -292,15 +292,18 @@ def initialize_ub( # Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls. # The workspace must have enough copies for max(num_max_streams, tp_size) compute streams, # since CommOverlapCore creates that many streams and divides the workspace among them. 
- num_workspace_copies = max(_NUM_MAX_UB_STREAMS, tp_size) + if IS_HIP_EXTENSION: + num_workspace_copies = max(_NUM_MAX_UB_STREAMS, tp_size) + else: + num_workspace_copies = _NUM_MAX_UB_STREAMS global _cublas_workspace if _cublas_workspace is None: - _cublas_workspace = get_workspace().repeat((num_workspace_copies if IS_HIP_EXTENSION else _NUM_MAX_UB_STREAMS)) - elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * (num_workspace_copies if IS_HIP_EXTENSION else _NUM_MAX_UB_STREAMS): + _cublas_workspace = get_workspace().repeat(num_workspace_copies) + elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * (num_workspace_copies): # This ensures we don't do `.repeat()` on an already expanded workspace _cublas_workspace = torch.empty( get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda" - ).repeat((num_workspace_copies if IS_HIP_EXTENSION else _NUM_MAX_UB_STREAMS)) + ).repeat(num_workspace_copies) # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe layers_all_gather_overlap = [ diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index a2a01b728..e1a51197d 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -1,5 +1,3 @@ -# This file was modified for portability to AMDGPU -# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. From e5d744603bb4576bf332ab286211cfa31427b79f Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 24 Mar 2026 12:54:55 -0500 Subject: [PATCH 15/69] Remove git conflict --- ci/pytorch.sh | 3 --- 1 file changed, 3 deletions(-) diff --git a/ci/pytorch.sh b/ci/pytorch.sh index 38ab019da..a7259ebf6 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -94,10 +94,7 @@ run_test_config_mgpu(){ #run in parallel on CI and it affects timing run_default_fa 1 test_gemm_sm_count.py run_default_fa 3 test_sanity_import.py -<<<<<<< HEAD run_default_fa 3 distributed/test_cast_master_weights_to_fp8.py -======= ->>>>>>> 16b62493 (Cleanup and RS flag race condition fix) run_default_fa 3 distributed/test_comm_gemm_overlap.py run_default_fa 2 distributed/test_fusible_ops.py run_default_fa 2 distributed/test_numerics.py From 7734ce5cce6acd34a78519e3de964640a6a7daec Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Thu, 26 Mar 2026 17:01:17 -0500 Subject: [PATCH 16/69] Address style and hip/cu specific paths --- build_tools/hipify/custom_map.json | 1 - .../rocm_te_layer_with_overlap.py | 3 +-- .../distributed/run_layer_with_overlap.py | 4 ++-- transformer_engine/common/CMakeLists.txt | 2 +- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 2 +- .../userbuffers/userbuffers.cu | 20 +++++++++++++++++++ .../transformer_engine/comm_gemm_overlap.h | 10 +++++++--- transformer_engine/pytorch/csrc/extensions.h | 4 ++-- .../pytorch/csrc/extensions/pybind.cpp | 4 ++-- 9 files changed, 36 insertions(+), 14 deletions(-) diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index 1244d2e09..6525731f5 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -8,7 +8,6 @@ "CUfunc_cache" : "hipFuncCache_t", "" : "", "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)", - "__nv_bfloat162" : "__hip_bfloat162", "__nv_fp8_e5m2" : "te_hip_fp8_e5m2", "__nv_fp8_e4m3" : "te_hip_fp8_e4m3", "cuda::getCurrentCUDAStream" : 
"hip::getCurrentHIPStreamMasqueradingAsCUDA", diff --git a/examples/pytorch/comm_gemm_overlap/rocm_te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/rocm_te_layer_with_overlap.py index b8c6d5cbd..7cfdbf09e 100644 --- a/examples/pytorch/comm_gemm_overlap/rocm_te_layer_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/rocm_te_layer_with_overlap.py @@ -448,6 +448,5 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False) return 0 - if __name__ == "__main__": - sys.exit(_train(_parse_args())) \ No newline at end of file + sys.exit(_train(_parse_args())) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 067c6ba63..8da4a4c34 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -567,8 +567,8 @@ def run_fwd_bwd(model, x): # Now validate accuracy if not bool(numerics_failed.item()): for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): - rtol = 0.125 if opts.fp8 else 0.025 if not IS_HIP_EXTENSION else 3e-2 - atol = 0.0625 if opts.fp8 else 0.00125 if not IS_HIP_EXTENSION else 1e-2 + rtol = 0.125 if opts.fp8 else 0.025 if not IS_HIP_EXTENSION else .03 + atol = 0.0625 if opts.fp8 else 0.00125 if not IS_HIP_EXTENSION else .01 grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) numerics_failed[0] = int(grad_failed) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 05ffeccfc..65b6c3997 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -173,7 +173,7 @@ set(transformer_engine_cuda_arch_specific_sources) # Source files in both cuda and rocm list(APPEND transformer_engine_cpp_sources - transformer_engine.cpp + transformer_engine.cpp comm_gemm_overlap/userbuffers/ipcsocket.cc comm_gemm_overlap/userbuffers/userbuffers-host.cpp comm_gemm_overlap/comm_gemm_overlap.cpp diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 241f26af6..4be1ffb65 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -357,7 +357,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0, comm_elements, _ub_comm, _stream_comm, - (cudaEvent_t)_comm_launch_event); + (cudaEvent_t)_comm_launch_event); } else { reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, (cudaEvent_t)_comm_launch_event); diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index b569378d9..781760411 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -153,7 +153,11 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); +#ifdef __HIP_PLATFORM_AMD__ if (threadIdx.x == 0) __threadfence_system(); +#else + if (threadIdx.x == 0) __threadfence(); +#endif __syncthreads(); if (threadIdx.x < RANKS) { @@ -233,7 +237,11 @@ __global__ void __launch_bounds__(MAX_THREADS) 
userptr[myrank][lineoffset + line] = sum; } __syncthreads(); +#ifdef __HIP_PLATFORM_AMD__ if (threadIdx.x == 0) __threadfence_system(); +#else + if (threadIdx.x == 0) __threadfence(); +#endif __syncthreads(); if (threadIdx.x < RANKS) { @@ -498,7 +506,11 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); +#ifdef __HIP_PLATFORM_AMD__ if (threadIdx.x == 0) __threadfence_system(); +#else + if (threadIdx.x == 0) __threadfence(); +#endif __syncthreads(); if (threadIdx.x < RANKS) { @@ -729,7 +741,11 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); +#ifdef __HIP_PLATFORM_AMD__ if (threadIdx.x == 0) __threadfence_system(); +#else + if (threadIdx.x == 0) __threadfence(); +#endif __syncthreads(); __shared__ int lastSM; @@ -1358,7 +1374,11 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); +#ifdef __HIP_PLATFORM_AMD__ if (threadIdx.x == 0) __threadfence_system(); +#else + if (threadIdx.x == 0) __threadfence(); +#endif __syncthreads(); __shared__ int lastSM; diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 96d104772..1b7300c2c 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -17,7 +17,11 @@ #include "common/comm_gemm_overlap/userbuffers/userbuffers.h" +#ifdef __HIP_PLATFORM_AMD__ +#define NVTE_COMM_OVERLAP_MAX_STREAMS NVTE_ROCM_MAX_RINGS +#else #define NVTE_COMM_OVERLAP_MAX_STREAMS 3 +#endif namespace transformer_engine { @@ -196,7 +200,7 @@ class CommOverlapBase : public CommOverlapCore { CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, - int num_max_streams = NVTE_ROCM_MAX_RINGS, int comm_cga_size = 2, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false, bool rs_overlap_first_gemm = false); @@ -280,7 +284,7 @@ class CommOverlapP2PBase : public CommOverlapCore { std::vector _ubufs; std::vector _stream_send, l_stream_send, l_stream_recv; cudaStream_t _stream_recv; - cudaEvent_t _stop_send, _stop_recv, l_stop_recv[7]; + cudaEvent_t _stop_send, _stop_recv, l_stop_recv[NVTE_ROCM_MAX_RINGS]; uint64_t _ag_signal_base = 0; uint64_t _rs_signal_base = 0; @@ -295,7 +299,7 @@ class CommOverlapP2PBase : public CommOverlapCore { CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, - CommOverlapType comm_type, int num_max_streams = NVTE_ROCM_MAX_RINGS, + CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 451d34adc..78f0134d5 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -544,7 +544,7 @@ class CommOverlap : torch::CustomClassHolder, 
public transformer_engine::CommOve public: CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, int num_splits = 3, - int num_max_streams = NVTE_ROCM_MAX_RINGS, int comm_cga_size = 2, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false, bool rs_overlap_first_gemm = false); @@ -565,7 +565,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, transformer_engine::CommOverlapType comm_type, - int num_max_streams = NVTE_ROCM_MAX_RINGS, int comm_cga_size = 2, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, bool aggregate = false); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 8017ae17a..40cc16c23 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -502,7 +502,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { int, int, int, int, bool, bool, bool>(), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), - py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_ROCM_MAX_RINGS, + py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false) @@ -520,7 +520,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { bool>(), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"), - py::arg("num_max_streams") = NVTE_ROCM_MAX_RINGS, py::arg("comm_cga_size") = 1, + py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) From c169c7529655e86a0d37a8ba412153646e756023 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Fri, 27 Mar 2026 09:27:34 -0500 Subject: [PATCH 17/69] HIP guards --- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 28 +++++++++++-------- .../userbuffers/userbuffers.cu | 16 ----------- .../transformer_engine/comm_gemm_overlap.h | 7 +++-- 3 files changed, 21 insertions(+), 30 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 4be1ffb65..ae24291ca 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -722,23 +722,26 @@ void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DTy NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); _stream_send.push_back(std::move(stream)); } + +#ifdef __HIP_PLATFORM_AMD__ for (int i = 0; i < NVTE_ROCM_MAX_RINGS; i++) { - 
cudaStream_t stream; - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); - l_stream_send.push_back(std::move(stream)); - } - for (int i = 0; i < NVTE_ROCM_MAX_RINGS; i++) { - cudaStream_t stream; - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); - l_stream_recv.push_back(std::move(stream)); + { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); + l_stream_send.push_back(std::move(stream)); + } + { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); + l_stream_recv.push_back(std::move(stream)); + } } +#endif + NVTE_CHECK_CUDA( cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); - for (int i = 0; i < NVTE_ROCM_MAX_RINGS; i++) { - NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&l_stop_recv[i], 0)); - } } CommOverlapP2PBase::~CommOverlapP2PBase() { @@ -748,11 +751,12 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { for (size_t i = 0; i < _stream_send.size(); i++) { cudaStreamDestroy(_stream_send[i]); } +#ifdef __HIP_PLATFORM_AMD__ for (int i = 0; i < NVTE_ROCM_MAX_RINGS; i++) { cudaStreamDestroy(l_stream_recv[i]); cudaStreamDestroy(l_stream_send[i]); - cudaEventDestroy(l_stop_recv[i]); } +#endif } void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 781760411..ee75f9d1a 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -153,11 +153,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); -#ifdef __HIP_PLATFORM_AMD__ if (threadIdx.x == 0) __threadfence_system(); -#else - if (threadIdx.x == 0) __threadfence(); -#endif __syncthreads(); if (threadIdx.x < RANKS) { @@ -506,11 +502,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); -#ifdef __HIP_PLATFORM_AMD__ if (threadIdx.x == 0) __threadfence_system(); -#else - if (threadIdx.x == 0) __threadfence(); -#endif __syncthreads(); if (threadIdx.x < RANKS) { @@ -741,11 +733,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); -#ifdef __HIP_PLATFORM_AMD__ if (threadIdx.x == 0) __threadfence_system(); -#else - if (threadIdx.x == 0) __threadfence(); -#endif __syncthreads(); __shared__ int lastSM; @@ -1374,11 +1362,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); -#ifdef __HIP_PLATFORM_AMD__ if (threadIdx.x == 0) __threadfence_system(); -#else - if (threadIdx.x == 0) __threadfence(); -#endif __syncthreads(); __shared__ int lastSM; diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 1b7300c2c..7c71468de 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -282,9 +282,12 @@ class CommOverlapP2PBase : public CommOverlapCore { int _num_ubuf_chunks; int _self_chunk_id; std::vector _ubufs; - std::vector _stream_send, l_stream_send, l_stream_recv; + std::vector _stream_send; +#ifdef 
__HIP_PLATFORM_AMD__ + std::vector l_stream_send, l_stream_recv; +#endif cudaStream_t _stream_recv; - cudaEvent_t _stop_send, _stop_recv, l_stop_recv[NVTE_ROCM_MAX_RINGS]; + cudaEvent_t _stop_send, _stop_recv; uint64_t _ag_signal_base = 0; uint64_t _rs_signal_base = 0; From 80e0aab6cab6c1cefbc1b0e5abfc5fd0aa6dcbbd Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 27 Mar 2026 13:58:07 -0500 Subject: [PATCH 18/69] initial impl --- build_tools/hipify/custom_map.json | 3 +- .../nvfp4/test_nvfp4_rht_quantize_exact.py | 66 ++++ transformer_engine/common/CMakeLists.txt | 2 +- .../hadamard_transform/hadamard_transform.cu | 318 ++++++++++++++++++ .../transformer_engine/hadamard_transform.h | 4 - transformer_engine/pytorch/csrc/common.h | 2 - transformer_engine/pytorch/csrc/pybind.h | 4 - transformer_engine/pytorch/csrc/quantizer.cpp | 8 +- 8 files changed, 393 insertions(+), 14 deletions(-) diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index 872d38efa..92b3f0a44 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -15,7 +15,8 @@ "__nv_fp4x4_e2m1" : "__hip_fp4x4_e2m1", "__nv_fp4x2_storage_t" : "__hip_fp4x2_storage_t", "#include " : "", - "#include " : "" + "#include " : "", + "#include " : "" } } diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index 98be9a4f5..b3b5ad8df 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -246,3 +246,69 @@ def test_nvfp4_quantization_noncontiguous_inputs( use_cpp_allocator=use_cpp_allocator, with_random_sign_mask=with_random_sign_mask, ) + + +def _ref_wht16_tiled(x: torch.Tensor, sign_mask: int) -> torch.Tensor: + """Pure-Python reference WHT: tiled 16-point butterfly, normalised by 0.25.""" + import numpy as np + x_np = x.float().cpu().numpy().copy() + rows, cols = x_np.shape + d = np.array([((-1) ** ((sign_mask >> i) & 1)) for i in range(16)], dtype=np.float32) + for c in range(0, cols, 16): + tile = x_np[:, c:c+16] * d + h = 1 + while h < 16: + for i in range(0, 16, h * 2): + for j in range(i, i + h): + a, b = tile[:, j].copy(), tile[:, j + h].copy() + tile[:, j], tile[:, j + h] = a + b, a - b + h *= 2 + x_np[:, c:c+16] = tile * 0.25 + return torch.from_numpy(x_np) + + +@pytest.mark.parametrize("rows,cols", [(64, 64), (128, 128)]) +def test_hadamard_transform_amax(rows, cols): + """ + Tests nvte_hadamard_transform_amax via NVFP4Quantizer (with_rht=True). + Exercises the WHT kernel without requiring a full NVFP4 recipe. 
+ Checks: + - amax_rowwise == max|x| (pre-RHT amax of raw input) + - amax_colwise == max|WHT(x.T)| (post-RHT amax of transposed input) + """ + torch.manual_seed(42) + x = torch.randn((rows, cols), dtype=torch.bfloat16, device="cuda").contiguous() + + quantizer = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=True, + ) + out = quantizer(x) + + # amax_rowwise: pre-RHT, should equal max|x| + expected_rowwise_amax = x.float().abs().max() + torch.testing.assert_close( + out._amax_rowwise.float().squeeze(), + expected_rowwise_amax, + rtol=1e-3, atol=1e-3, + msg=f"pre-RHT amax mismatch rows={rows} cols={cols}", + ) + + # amax_colwise: post-RHT of x.T, should equal max|WHT(x.T)| + sign_mask_t = quantizer.rht_matrix_random_sign_mask_t + x_t = x.t().contiguous() # (cols, rows) + wht_x_t = _ref_wht16_tiled(x_t, sign_mask=sign_mask_t) + expected_colwise_amax = wht_x_t.float().abs().max() + + torch.testing.assert_close( + out._amax_columnwise.float().squeeze().item(), + float(expected_colwise_amax), + rtol=2e-2, atol=2e-2, + msg=f"post-RHT amax mismatch rows={rows} cols={cols}", + ) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 8d5537368..847fbcb8e 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -223,6 +223,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources activation/gelu.cu activation/relu.cu activation/swiglu.cu + hadamard_transform/hadamard_transform.cu transpose/quantize_transpose_vector_blockwise_fp4.cu) if(USE_CUDA) @@ -247,7 +248,6 @@ if(USE_CUDA) list(APPEND transformer_engine_cuda_arch_specific_sources gemm/cutlass_grouped_gemm.cu transpose/quantize_transpose_square_blockwise.cu - hadamard_transform/hadamard_transform.cu hadamard_transform/hadamard_transform_cast_fusion.cu) else() #ROCm specific source codes diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index b901f9023..e2f449673 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -23,6 +23,8 @@ namespace { constexpr int kThreadsPerWarp = 32; constexpr float k16x16HadamardScale = 0.25f; +#ifndef __HIP_PLATFORM_AMD__ + template __device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, @@ -658,12 +660,261 @@ __global__ void HadamardTransformKernel(const T* __restrict__ input, T* __restri #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 } +#endif // __HIP_PLATFORM_AMD__ } // namespace +#ifdef __HIP_PLATFORM_AMD__ + +namespace { + +static constexpr int kHadamardDim = 16; +static constexpr int kWarpSize = 64; +static constexpr int kThreadsPerWHT = 4; +static constexpr int kElemsPerThread = 4; +static constexpr int kRowsPerWarp = kWarpSize / kThreadsPerWHT; // 16 +static constexpr int kWarpsPerBlock = 4; +static constexpr int kRowsPerBlock = kRowsPerWarp * kWarpsPerBlock; // 64 +static constexpr int kThreadsPerBlock = kWarpSize * kWarpsPerBlock; // 256 +static constexpr float kHadamardScale = 0.25f; + +// ds_swizzle: sub-wavefront exchange without LDS. +// Same instructions as cast_transpose_mxfp4_kernel_shuffled.cu. 
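+// The offset values use ds_swizzle's bit-mask mode (offset[15] = 0): and_mask = offset[4:0],
+// or_mask = offset[9:5], xor_mask = offset[14:10], applied within each group of 32 lanes.
+// Hence 0x041F (and = 0x1F, xor = 1) exchanges lane <-> lane^1 and 0x081F (and = 0x1F,
+// xor = 2) exchanges lane <-> lane^2, the pairings the butterfly stages in wht16 rely on.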
+__device__ __forceinline__ float ds_swizzle_xor1(float v) { + float r; + asm volatile("ds_swizzle_b32 %0, %1 offset:0x041F\n\t" + "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v)); + return r; +} + +__device__ __forceinline__ float ds_swizzle_xor2(float v) { + float r; + asm volatile("ds_swizzle_b32 %0, %1 offset:0x081F\n\t" + "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v)); + return r; +} + +// BF16 helpers +__device__ __forceinline__ float to_f32 (__hip_bfloat16 v) { return static_cast(v); } +__device__ __forceinline__ __hip_bfloat16 to_bf16(float v) { return static_cast<__hip_bfloat16>(v); } + +// Bit-cast __hip_bfloat16->uint16_t without address-of-temporary. +__device__ __forceinline__ uint16_t bf16_to_bits(__hip_bfloat16 v) { + uint16_t bits; __builtin_memcpy(&bits, &v, sizeof(uint16_t)); return bits; +} + +// Unpack/pack 4 BF16 values as uint64_t (vectorised global load/store). +// Same trick as cast_transpose_mxfp4_kernel_shuffled.cu::bf16x4_to_float4. +__device__ __forceinline__ void unpack_bf16x4(uint64_t p, + float& v0, float& v1, float& v2, float& v3) { + v0 = __uint_as_float(((uint32_t)( p & 0xFFFF)) << 16); + v1 = __uint_as_float(((uint32_t)((p >> 16) & 0xFFFF)) << 16); + v2 = __uint_as_float(((uint32_t)((p >> 32) & 0xFFFF)) << 16); + v3 = __uint_as_float(((uint32_t)((p >> 48) & 0xFFFF)) << 16); +} + +__device__ __forceinline__ uint64_t pack_bf16x4(float v0, float v1, float v2, float v3) { + return (uint64_t)bf16_to_bits(to_bf16(v0)) + | ((uint64_t)bf16_to_bits(to_bf16(v1)) << 16) + | ((uint64_t)bf16_to_bits(to_bf16(v2)) << 32) + | ((uint64_t)bf16_to_bits(to_bf16(v3)) << 48); +} + +// 16-point WHT: in-register, no shared memory. +// Adapted from cast_transpose_mxfp4_kernel_shuffled.cu::hadamard16_inplace, +// extended with NV random_sign_mask (uint16_t bitmask). +// thread_in_group [0,3]: drives ds_swizzle polarity (identical to MLPerf tid & 3). +// apply_pre=true -> D before WHT (forward); false -> D after WHT (inverse). +__device__ __forceinline__ void wht16( + float& v0, float& v1, float& v2, float& v3, + int thread_in_group, uint16_t sign_mask, bool apply_pre) { + auto sgn = [&](int k) -> float { + return ((sign_mask >> (thread_in_group * kElemsPerThread + k)) & 1u) ? -1.f : 1.f; + }; + if (apply_pre) { + v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); + } + + // Stage 1: local H4 + float a0=v0+v1, a1=v0-v1, a2=v2+v3, a3=v2-v3; + v0=a0+a2; v2=a0-a2; v1=a1+a3; v3=a1-a3; + + // Stage 2: cross-thread XOR-1 + { float p0=ds_swizzle_xor1(v0), p1=ds_swizzle_xor1(v1), + p2=ds_swizzle_xor1(v2), p3=ds_swizzle_xor1(v3); + bool up=(thread_in_group&1); + v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1); + v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } + + // Stage 3: cross-thread XOR-2 + { float p0=ds_swizzle_xor2(v0), p1=ds_swizzle_xor2(v1), + p2=ds_swizzle_xor2(v2), p3=ds_swizzle_xor2(v3); + bool up=(thread_in_group>>1)&1; + v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1); + v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } + + v0*=kHadamardScale; v1*=kHadamardScale; v2*=kHadamardScale; v3*=kHadamardScale; + if (!apply_pre) { v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); } +} + +// Grid: blockIdx.x = col tile [0, row_length/16) +// blockIdx.y = row batch [0, ceil(num_rows/64)) +// Block: 256 threads = 4 wavefronts of 64 lanes. 
+// lane/4 = row_in_warp (0..15), lane%4 = thread_in_grp (0..3) +template +__global__ __launch_bounds__(kThreadsPerBlock, 4) +void HadamardTransformKernel( + const __hip_bfloat16* __restrict__ input, + __hip_bfloat16* __restrict__ output, + __hip_bfloat16* __restrict__ output_t, + uint16_t random_sign_mask, uint16_t random_sign_mask_t, + uint64_t num_rows, uint64_t row_length, + float* __restrict__ amax, float* __restrict__ amax_t, + bool inverse_hadamard) { + const int tid = threadIdx.x; + const int warp_id = tid / kWarpSize; + const int lane_id = tid % kWarpSize; + const int row_in_warp = lane_id / kThreadsPerWHT; + const int thread_in_grp = lane_id % kThreadsPerWHT; + const uint64_t col_tile_base = (uint64_t)blockIdx.x * kHadamardDim; + const uint64_t row_batch = (uint64_t)blockIdx.y * kRowsPerBlock; + const uint64_t global_row = row_batch + (uint64_t)warp_id*kRowsPerWarp + row_in_warp; + const uint64_t col_base = col_tile_base + (uint64_t)thread_in_grp * kElemsPerThread; + + const bool apply_pre = !inverse_hadamard; + const bool in_bounds = (global_row < num_rows) && (col_base + kElemsPerThread - 1 < row_length); + + // Smem for transposed path: 64Ă—(16+1) BF16; +1 avoids LDS bank conflict. + __shared__ __hip_bfloat16 smem[kRowsPerBlock][kHadamardDim + 1]; + float v0=0.f, v1=0.f, v2=0.f, v3=0.f; + if (in_bounds) { + unpack_bf16x4(*reinterpret_cast( + &input[global_row * row_length + col_base]), v0, v1, v2, v3); + } + + // Identity path: WHT along row dimension + if constexpr (kComputeIdentity || kUpdateAmax) { + float r0=v0, r1=v1, r2=v2, r3=v3; + if (global_row < num_rows) { + wht16(r0, r1, r2, r3, thread_in_grp, random_sign_mask, apply_pre); + if constexpr (kUpdateAmax) { + float lam = fmaxf(fmaxf(fabsf(r0),fabsf(r1)),fmaxf(fabsf(r2),fabsf(r3))); + for (int off=kWarpSize/2; off>=1; off>>=1) lam=fmaxf(lam,__shfl_xor(lam,off)); + if (lane_id == 0) atomicMaxFloat(amax, lam); + } + if constexpr (kComputeIdentity) + if (output && in_bounds) + *reinterpret_cast(&output[global_row*row_length+col_base]) = + pack_bf16x4(r0,r1,r2,r3); + } + } + + // Transposed path: WHT along column dimension via smem transpose + if constexpr (kComputeTransposed || kUpdateAmaxT) { + const int local_row = warp_id * kRowsPerWarp + row_in_warp; + const int col_offset = thread_in_grp * kElemsPerThread; + smem[local_row][col_offset+0] = to_bf16(global_row < num_rows ? v0 : 0.f); + smem[local_row][col_offset+1] = to_bf16(global_row < num_rows ? v1 : 0.f); + smem[local_row][col_offset+2] = to_bf16(global_row < num_rows ? v2 : 0.f); + smem[local_row][col_offset+3] = to_bf16(global_row < num_rows ? 
v3 : 0.f); + __syncthreads(); + + // Re-read: row_in_warp -> column index, thread_in_grp -> 4 rows + const int t_col = row_in_warp; + const int smem_rbase = warp_id*kRowsPerWarp + thread_in_grp*kElemsPerThread; + + float c0=to_f32(smem[smem_rbase+0][t_col]), c1=to_f32(smem[smem_rbase+1][t_col]); + float c2=to_f32(smem[smem_rbase+2][t_col]), c3=to_f32(smem[smem_rbase+3][t_col]); + + wht16(c0, c1, c2, c3, thread_in_grp, random_sign_mask_t, apply_pre); + + if constexpr (kUpdateAmaxT) { + float lam = fmaxf(fmaxf(fabsf(c0),fabsf(c1)),fmaxf(fabsf(c2),fabsf(c3))); + for (int off=kWarpSize/2; off>=1; off>>=1) lam=fmaxf(lam,__shfl_xor(lam,off)); + if (lane_id == 0) atomicMaxFloat(amax_t, lam); + } + + if constexpr (kComputeTransposed) { + if (output_t) { + const uint64_t global_col = col_tile_base + t_col; + const uint64_t out_row_base = row_batch + (uint64_t)warp_id*kRowsPerWarp + + (uint64_t)thread_in_grp*kElemsPerThread; + if (global_col < row_length && out_row_base+kElemsPerThread-1 < num_rows) + *reinterpret_cast( + &output_t[global_col*num_rows+out_row_base]) = + pack_bf16x4(c0,c1,c2,c3); + } + } + } +} + +// Pre-RHT amax: max|input| before any transform. +__global__ void PreRhtAmaxKernel(const __hip_bfloat16* __restrict__ input, + float* __restrict__ amax_out, uint64_t num_elems) { + float lam = 0.f; + for (uint64_t i = (uint64_t)blockIdx.x*blockDim.x+threadIdx.x; + i < num_elems; i += (uint64_t)gridDim.x*blockDim.x) + lam = fmaxf(lam, fabsf(to_f32(input[i]))); + for (int off=kWarpSize/2; off>=1; off>>=1) lam=fmaxf(lam,__shfl_xor(lam,off)); + if (threadIdx.x % kWarpSize == 0) atomicMaxFloat(amax_out, lam); +} + +static inline dim3 transform_grid(uint64_t num_rows, uint64_t row_length) { + return dim3((uint32_t)(row_length / kHadamardDim), + (uint32_t)DIVUP(num_rows, (uint64_t)kRowsPerBlock)); +} + +} // namespace + +#endif // __HIP_PLATFORM_AMD__ + void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, uint16_t random_sign_mask_t, cudaStream_t stream) { NVTE_API_CALL(hadamard_transform); +#ifdef __HIP_PLATFORM_AMD__ + NVTE_CHECK(input_.dtype() == DType::kBFloat16, "Input must be BF16."); + NVTE_CHECK(input_.dim() >= 2, "Input must be >=2D."); + + const SimpleTensor& input = input_.data; + SimpleTensor identity_out; + SimpleTensor& transposed_out = output_.data; + + const bool want_identity = (identity_out.dptr != nullptr); + const bool want_transposed = (transposed_out.dptr != nullptr); + + if (!want_identity && !want_transposed) + return; + + const size_t ndim = input.shape.size(); + const size_t row_length = input.shape[ndim - 1]; + size_t num_rows = 1; + + for (size_t i = 0; i < ndim - 1; ++i) + num_rows *= input.shape[i]; + NVTE_CHECK(row_length % kHadamardDim == 0, "row_length must be divisible by 16."); + NVTE_CHECK(num_rows % kHadamardDim == 0, "num_rows must be divisible by 16."); + + auto* in_ptr = reinterpret_cast(input.dptr); + auto* id_ptr = reinterpret_cast<__hip_bfloat16*>(identity_out.dptr); + auto* tr_ptr = reinterpret_cast<__hip_bfloat16*>(transposed_out.dptr); + dim3 grid = transform_grid(num_rows, row_length), block(kThreadsPerBlock); + +#define LAUNCH_T(IDENT, TRANS) \ + HadamardTransformKernel \ + <<>>(in_ptr,id_ptr,tr_ptr, \ + random_sign_mask,random_sign_mask_t, \ + (uint64_t)num_rows,(uint64_t)row_length,nullptr,nullptr,false) + + if (want_identity && want_transposed) + LAUNCH_T(true, true); + else if (want_identity) + LAUNCH_T(true, false); + else + LAUNCH_T(false, true); + NVTE_CHECK_CUDA(cudaGetLastError()); +#undef LAUNCH_T 
+#else // CUDA // Check tensors // NOTE (frsun): This is non-intuitive, we are writing the result of // transposed RHT to the output of rowwise. @@ -736,6 +987,7 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s num_rows, row_length, nullptr, nullptr, false););); NVTE_CHECK_CUDA(cudaGetLastError()); +#endif // __HIP_PLATFORM_AMD__ } // Kernel that will apply the 16x16 hadamard transform the input and input.T, and then @@ -743,6 +995,71 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, uint16_t random_sign_mask_t, cudaStream_t stream) { NVTE_API_CALL(hadamard_transform_amax); +#ifdef __HIP_PLATFORM_AMD__ + NVTE_CHECK(input_.dtype() == DType::kBFloat16, "Input must be BF16."); + NVTE_CHECK(input_.dim() >= 2, "Input must be >=2D."); + + const SimpleTensor& input = input_.data; + SimpleTensor& pre_rht_tensor = output_.amax; + SimpleTensor identity_tensor; + SimpleTensor& transpose_tensor = output_.columnwise_amax; + + const bool want_pre_rht = (pre_rht_tensor.dptr != nullptr); + const bool want_identity = (identity_tensor.dptr != nullptr); + const bool want_trans = (transpose_tensor.dptr != nullptr); + + if (!want_pre_rht && !want_identity && !want_trans) + return; + + const size_t ndim = input.shape.size(); + const size_t row_length = input.shape[ndim - 1]; + size_t num_rows = 1; + + for (size_t i = 0; i < ndim - 1; ++i) + num_rows *= input.shape[i]; + + NVTE_CHECK(row_length % kHadamardDim == 0, "row_length must be divisible by 16."); + NVTE_CHECK(num_rows % kHadamardDim == 0, "num_rows must be divisible by 16."); + + auto* in_ptr = reinterpret_cast(input.dptr); + auto* pre_amax_ptr = reinterpret_cast(pre_rht_tensor.dptr); + auto* id_amax_ptr = reinterpret_cast(identity_tensor.dptr); + auto* tr_amax_ptr = reinterpret_cast(transpose_tensor.dptr); + + if (pre_amax_ptr) + NVTE_CHECK_CUDA(cudaMemsetAsync(pre_amax_ptr, 0, sizeof(float), stream)); + if (id_amax_ptr) + NVTE_CHECK_CUDA(cudaMemsetAsync(id_amax_ptr, 0, sizeof(float), stream)); + if (tr_amax_ptr) + NVTE_CHECK_CUDA(cudaMemsetAsync(tr_amax_ptr, 0, sizeof(float), stream)); + + if (want_pre_rht) { + const uint64_t num_elems = (uint64_t)num_rows * row_length; + dim3 g(DIVUP(num_elems, (uint64_t)kThreadsPerBlock)); + PreRhtAmaxKernel<<>>(in_ptr,pre_amax_ptr,num_elems); + NVTE_CHECK_CUDA(cudaGetLastError()); + } + + if (want_identity || want_trans) { + dim3 grid = transform_grid(num_rows, row_length), block(kThreadsPerBlock); +#define LAUNCH_A(IDENT,TRANS,UA,UAT) \ + HadamardTransformKernel \ + <<>>(in_ptr,nullptr,nullptr, \ + random_sign_mask,random_sign_mask_t, \ + (uint64_t)num_rows,(uint64_t)row_length, \ + id_amax_ptr,tr_amax_ptr,false) + + if (want_identity && want_trans) + LAUNCH_A(true, true, true, true); + else if (want_identity) + LAUNCH_A(true, false, true, false); + else + LAUNCH_A(false,true, false, true); + + NVTE_CHECK_CUDA(cudaGetLastError()); +#undef LAUNCH_A + } +#else // __HIP_PLATFORM_AMD__ #if CUDA_VERSION >= 12080 // Check input tensor @@ -853,6 +1170,7 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif // CUDA_VERSION >= 12080 +#endif // __HIP_PLATFORM_AMD__ } } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h 
b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index 6785040df..90a722e66 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -13,8 +13,6 @@ #ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ #define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ -#ifndef __HIP_PLATFORM_AMD__ - #include "transformer_engine.h" #ifdef __cplusplus @@ -69,6 +67,4 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE } // extern "C" #endif -#endif - #endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 27f24a961..6c19bae13 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -293,7 +293,6 @@ class MXFP8Quantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; -#ifndef USE_ROCM class NVFP4Quantizer : public Quantizer { public: // fp4 dtype @@ -347,7 +346,6 @@ class NVFP4Quantizer : public Quantizer { void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); }; -#endif // #ifndef USE_ROCM std::unique_ptr convert_quantizer(py::handle quantizer); diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index d5fd4a4fe..b924e8a77 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -110,12 +110,8 @@ constexpr std::array custom_types_converters = { CreateQuantizer), std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer), -#ifdef USE_ROCM -}; -#else std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor, CreateQuantizer)}; -#endif } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 50c6bc810..7ca0ab3bf 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1163,7 +1163,6 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s #endif } -#ifndef USE_ROCM NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); this->with_rht = quantizer.attr("with_rht").cast(); @@ -1516,6 +1515,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou bool eligible_for_rht_cast_fusion = input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; +#ifdef USE_ROCM + eligible_for_rht_cast_fusion = false; +#endif + // Compute amax. if (this->with_rht) { if (input.dtype() != DType::kBFloat16) { @@ -1663,6 +1666,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config, stream); }); } else { +#ifndef USE_ROCM // RHT cast fusion kernel. 
NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0, "RHT matrix is not set"); @@ -1671,6 +1675,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou nvte_hadamard_transform_cast_fusion_columnwise( input.data(), out_transpose.data(), rht_matrix_nvte.data(), quant_config, stream); }); +#endif } } } else { @@ -1740,6 +1745,5 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s } return scale_shape; } -#endif } // namespace transformer_engine::pytorch From bda7b13b9e1ed4af99457c83bfc13ac51cede85e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 30 Mar 2026 12:12:10 -0500 Subject: [PATCH 19/69] test update --- .../nvfp4/test_nvfp4_rht_quantize_exact.py | 28 +++++++++++-------- .../hadamard_transform/hadamard_transform.cu | 2 ++ 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index b3b5ad8df..5fd55e059 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -249,22 +251,26 @@ def test_nvfp4_quantization_noncontiguous_inputs( def _ref_wht16_tiled(x: torch.Tensor, sign_mask: int) -> torch.Tensor: - """Pure-Python reference WHT: tiled 16-point butterfly, normalised by 0.25.""" - import numpy as np - x_np = x.float().cpu().numpy().copy() - rows, cols = x_np.shape - d = np.array([((-1) ** ((sign_mask >> i) & 1)) for i in range(16)], dtype=np.float32) + """Reference 16-point WHT tiled along last dim, normalised by 0.25.""" + x = x.float() + _rows, cols = x.shape + d = torch.tensor( + [((-1) ** ((sign_mask >> i) & 1)) for i in range(16)], + dtype=torch.float32, device=x.device, + ) + out = x.clone() for c in range(0, cols, 16): - tile = x_np[:, c:c+16] * d + tile = out[:, c:c+16] * d # apply sign h = 1 while h < 16: for i in range(0, 16, h * 2): - for j in range(i, i + h): - a, b = tile[:, j].copy(), tile[:, j + h].copy() - tile[:, j], tile[:, j + h] = a + b, a - b + a = tile[:, i:i+h].clone() + b = tile[:, i+h:i+2*h].clone() + tile[:, i:i+h] = a + b + tile[:, i+h:i+2*h] = a - b h *= 2 - x_np[:, c:c+16] = tile * 0.25 - return torch.from_numpy(x_np) + out[:, c:c+16] = tile * 0.25 + return out @pytest.mark.parametrize("rows,cols", [(64, 64), (128, 128)]) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index e2f449673..64aa243e1 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
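Side note on the reference above: the tiled 16-point butterfly in _ref_wht16_tiled is the same operation as flipping signs per column and then right-multiplying each 16-column tile by the 16x16 Sylvester Hadamard matrix, scaled by 0.25 (= 1/sqrt(16)). A minimal PyTorch sketch of that equivalence (the function name is illustrative):

import torch

def wht16_tiled_matmul(x: torch.Tensor, sign_mask: int) -> torch.Tensor:
    """Matmul form of the tiled 16-point WHT used by _ref_wht16_tiled."""
    x = x.float()
    # Sylvester construction: H_{2n} = [[H_n, H_n], [H_n, -H_n]]; H_16 is symmetric.
    h = torch.ones(1, 1, device=x.device)
    while h.shape[0] < 16:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    d = torch.tensor(
        [(-1.0) ** ((sign_mask >> i) & 1) for i in range(16)],
        dtype=torch.float32, device=x.device,
    )
    out = torch.empty_like(x)
    for c in range(0, x.shape[1], 16):
        # Per-column sign flip, then Hadamard along the 16-wide tile, normalised by 0.25.
        out[:, c:c + 16] = (x[:, c:c + 16] * d) @ h * 0.25
    return out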
From 7ddb53938199ec81cab9342d76de8275175425ef Mon Sep 17 00:00:00 2001 From: alextmagro Date: Mon, 30 Mar 2026 14:03:14 -0500 Subject: [PATCH 20/69] Update extensions.h Remove TODO regarding userbuffers --- transformer_engine/pytorch/csrc/extensions.h | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 78f0134d5..67c4c9ac9 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -374,7 +374,6 @@ size_t get_cublasLt_version(); size_t get_cudnn_version(); #endif -//TODO: support user buffer for ROCm void placeholder(); /*************************************************************************************************** From 63c7a48730f6ba44aae9c831ad5be4a1f8a82f11 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 30 Mar 2026 15:42:04 -0500 Subject: [PATCH 21/69] amax opt --- .../hadamard_transform/hadamard_transform.cu | 100 ++++++++++++++---- 1 file changed, 82 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 64aa243e1..3c8a7e53d 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -732,6 +732,7 @@ __device__ __forceinline__ void wht16( auto sgn = [&](int k) -> float { return ((sign_mask >> (thread_in_group * kElemsPerThread + k)) & 1u) ? -1.f : 1.f; }; + if (apply_pre) { v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); } @@ -755,7 +756,10 @@ __device__ __forceinline__ void wht16( v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } v0*=kHadamardScale; v1*=kHadamardScale; v2*=kHadamardScale; v3*=kHadamardScale; - if (!apply_pre) { v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); } + + if (!apply_pre) { + v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); + } } // Grid: blockIdx.x = col tile [0, row_length/16) @@ -786,8 +790,12 @@ void HadamardTransformKernel( const bool apply_pre = !inverse_hadamard; const bool in_bounds = (global_row < num_rows) && (col_base + kElemsPerThread - 1 < row_length); - // Smem for transposed path: 64Ă—(16+1) BF16; +1 avoids LDS bank conflict. + // Smem for transposed path: 64*(16+1) BF16; +1 avoids LDS bank conflict. 
__shared__ __hip_bfloat16 smem[kRowsPerBlock][kHadamardDim + 1]; + + __shared__ float block_amax[kWarpsPerBlock]; + __shared__ float block_amax_t[kWarpsPerBlock]; + float v0=0.f, v1=0.f, v2=0.f, v3=0.f; if (in_bounds) { unpack_bf16x4(*reinterpret_cast( @@ -797,24 +805,30 @@ void HadamardTransformKernel( // Identity path: WHT along row dimension if constexpr (kComputeIdentity || kUpdateAmax) { float r0=v0, r1=v1, r2=v2, r3=v3; + float lam = 0.f; if (global_row < num_rows) { wht16(r0, r1, r2, r3, thread_in_grp, random_sign_mask, apply_pre); if constexpr (kUpdateAmax) { - float lam = fmaxf(fmaxf(fabsf(r0),fabsf(r1)),fmaxf(fabsf(r2),fabsf(r3))); - for (int off=kWarpSize/2; off>=1; off>>=1) lam=fmaxf(lam,__shfl_xor(lam,off)); - if (lane_id == 0) atomicMaxFloat(amax, lam); + lam = fmaxf(fmaxf(fabsf(r0),fabsf(r1)),fmaxf(fabsf(r2),fabsf(r3))); + for (int off=kWarpSize/2; off>=1; off>>=1) + lam=fmaxf(lam,__shfl_xor(lam,off)); } if constexpr (kComputeIdentity) if (output && in_bounds) *reinterpret_cast(&output[global_row*row_length+col_base]) = pack_bf16x4(r0,r1,r2,r3); } + if constexpr (kUpdateAmax) { + if (lane_id == 0) + block_amax[warp_id] = lam; + } } // Transposed path: WHT along column dimension via smem transpose if constexpr (kComputeTransposed || kUpdateAmaxT) { const int local_row = warp_id * kRowsPerWarp + row_in_warp; const int col_offset = thread_in_grp * kElemsPerThread; + float lam = 0.f; smem[local_row][col_offset+0] = to_bf16(global_row < num_rows ? v0 : 0.f); smem[local_row][col_offset+1] = to_bf16(global_row < num_rows ? v1 : 0.f); smem[local_row][col_offset+2] = to_bf16(global_row < num_rows ? v2 : 0.f); @@ -831,9 +845,10 @@ void HadamardTransformKernel( wht16(c0, c1, c2, c3, thread_in_grp, random_sign_mask_t, apply_pre); if constexpr (kUpdateAmaxT) { - float lam = fmaxf(fmaxf(fabsf(c0),fabsf(c1)),fmaxf(fabsf(c2),fabsf(c3))); - for (int off=kWarpSize/2; off>=1; off>>=1) lam=fmaxf(lam,__shfl_xor(lam,off)); - if (lane_id == 0) atomicMaxFloat(amax_t, lam); + lam = fmaxf(fmaxf(fabsf(c0),fabsf(c1)),fmaxf(fabsf(c2),fabsf(c3))); + + for (int off=kWarpSize/2; off>=1; off>>=1) + lam=fmaxf(lam,__shfl_xor(lam,off)); } if constexpr (kComputeTransposed) { @@ -847,18 +862,67 @@ void HadamardTransformKernel( pack_bf16x4(c0,c1,c2,c3); } } + + if constexpr (kUpdateAmaxT) { + if (lane_id == 0) + block_amax_t[warp_id] = lam; + } + } + + if constexpr (kUpdateAmax || kUpdateAmaxT) { + __syncthreads(); + + if (warp_id == 0) { + if constexpr (kUpdateAmax) { + float block_lam = (lane_id < kWarpsPerBlock) ? block_amax[lane_id] : 0.f; + + for (int off=kWarpSize/2; off>=1; off>>=1) + block_lam = fmaxf(block_lam, __shfl_xor(block_lam, off)); + + if (lane_id == 0) + atomicMaxFloat(amax, block_lam); + } + + if constexpr (kUpdateAmaxT) { + float block_lam_t = (lane_id < kWarpsPerBlock) ? block_amax_t[lane_id] : 0.f; + + for (int off=kWarpSize/2; off>=1; off>>=1) + block_lam_t = fmaxf(block_lam_t, __shfl_xor(block_lam_t, off)); + + if (lane_id == 0) + atomicMaxFloat(amax_t, block_lam_t); + } + } } } // Pre-RHT amax: max|input| before any transform. 
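+// Reduction layout below: each wavefront reduces its running max with __shfl_xor,
+// lane 0 stages the wavefront max in LDS (block_amax), and wavefront 0 combines those
+// and issues a single atomicMaxFloat per block rather than one per wavefront.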
__global__ void PreRhtAmaxKernel(const __hip_bfloat16* __restrict__ input, float* __restrict__ amax_out, uint64_t num_elems) { + __shared__ float block_amax[kWarpsPerBlock]; float lam = 0.f; for (uint64_t i = (uint64_t)blockIdx.x*blockDim.x+threadIdx.x; i < num_elems; i += (uint64_t)gridDim.x*blockDim.x) - lam = fmaxf(lam, fabsf(to_f32(input[i]))); - for (int off=kWarpSize/2; off>=1; off>>=1) lam=fmaxf(lam,__shfl_xor(lam,off)); - if (threadIdx.x % kWarpSize == 0) atomicMaxFloat(amax_out, lam); + lam = fmaxf(lam, fabsf(to_f32(input[i]))); + + for (int off=kWarpSize/2; off>=1; off>>=1) + lam=fmaxf(lam,__shfl_xor(lam,off)); + + const int warp_id = threadIdx.x / kWarpSize; + const int lane_id = threadIdx.x % kWarpSize; + if (lane_id == 0) + block_amax[warp_id] = lam; + + __syncthreads(); + + if (warp_id == 0) { + float block_lam = (lane_id < kWarpsPerBlock) ? block_amax[lane_id] : 0.f; + for (int off=kWarpSize/2; off>=1; off>>=1) + block_lam=fmaxf(block_lam,__shfl_xor(block_lam,off)); + + if (lane_id == 0) + atomicMaxFloat(amax_out, block_lam); + } } static inline dim3 transform_grid(uint64_t num_rows, uint64_t row_length) { @@ -878,7 +942,7 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s NVTE_CHECK(input_.dim() >= 2, "Input must be >=2D."); const SimpleTensor& input = input_.data; - SimpleTensor identity_out; + SimpleTensor identity_out; // Unused SimpleTensor& transposed_out = output_.data; const bool want_identity = (identity_out.dptr != nullptr); @@ -904,9 +968,9 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s #define LAUNCH_T(IDENT, TRANS) \ HadamardTransformKernel \ - <<>>(in_ptr,id_ptr,tr_ptr, \ - random_sign_mask,random_sign_mask_t, \ - (uint64_t)num_rows,(uint64_t)row_length,nullptr,nullptr,false) + <<>>(in_ptr, id_ptr, tr_ptr, \ + random_sign_mask, random_sign_mask_t, \ + (uint64_t)num_rows, (uint64_t)row_length, nullptr, nullptr, false) if (want_identity && want_transposed) LAUNCH_T(true, true); @@ -992,8 +1056,8 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s #endif // __HIP_PLATFORM_AMD__ } -// Kernel that will apply the 16x16 hadamard transform the input and input.T, and then -// get the absolute max value of the result. +// Kernel that applies the 16x16 hadamard transform the input and input.T, and then +// gets the absolute max value of the result. 
void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, uint16_t random_sign_mask_t, cudaStream_t stream) { NVTE_API_CALL(hadamard_transform_amax); @@ -1003,7 +1067,7 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran const SimpleTensor& input = input_.data; SimpleTensor& pre_rht_tensor = output_.amax; - SimpleTensor identity_tensor; + SimpleTensor identity_tensor; // Unused SimpleTensor& transpose_tensor = output_.columnwise_amax; const bool want_pre_rht = (pre_rht_tensor.dptr != nullptr); From a2604591481f0417a14bb42ae7d1ee867a6fd076 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 30 Mar 2026 17:07:54 -0500 Subject: [PATCH 22/69] simplify --- .../hadamard_transform/hadamard_transform.cu | 133 +++++++----------- 1 file changed, 53 insertions(+), 80 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 3c8a7e53d..d574d9ea2 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -1056,77 +1056,12 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s #endif // __HIP_PLATFORM_AMD__ } -// Kernel that applies the 16x16 hadamard transform the input and input.T, and then -// gets the absolute max value of the result. +// Kernel that will apply the 16x16 hadamard transform the input and input.T, and then +// get the absolute max value of the result. void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, uint16_t random_sign_mask_t, cudaStream_t stream) { NVTE_API_CALL(hadamard_transform_amax); -#ifdef __HIP_PLATFORM_AMD__ - NVTE_CHECK(input_.dtype() == DType::kBFloat16, "Input must be BF16."); - NVTE_CHECK(input_.dim() >= 2, "Input must be >=2D."); - - const SimpleTensor& input = input_.data; - SimpleTensor& pre_rht_tensor = output_.amax; - SimpleTensor identity_tensor; // Unused - SimpleTensor& transpose_tensor = output_.columnwise_amax; - - const bool want_pre_rht = (pre_rht_tensor.dptr != nullptr); - const bool want_identity = (identity_tensor.dptr != nullptr); - const bool want_trans = (transpose_tensor.dptr != nullptr); - - if (!want_pre_rht && !want_identity && !want_trans) - return; - - const size_t ndim = input.shape.size(); - const size_t row_length = input.shape[ndim - 1]; - size_t num_rows = 1; - - for (size_t i = 0; i < ndim - 1; ++i) - num_rows *= input.shape[i]; - - NVTE_CHECK(row_length % kHadamardDim == 0, "row_length must be divisible by 16."); - NVTE_CHECK(num_rows % kHadamardDim == 0, "num_rows must be divisible by 16."); - - auto* in_ptr = reinterpret_cast(input.dptr); - auto* pre_amax_ptr = reinterpret_cast(pre_rht_tensor.dptr); - auto* id_amax_ptr = reinterpret_cast(identity_tensor.dptr); - auto* tr_amax_ptr = reinterpret_cast(transpose_tensor.dptr); - - if (pre_amax_ptr) - NVTE_CHECK_CUDA(cudaMemsetAsync(pre_amax_ptr, 0, sizeof(float), stream)); - if (id_amax_ptr) - NVTE_CHECK_CUDA(cudaMemsetAsync(id_amax_ptr, 0, sizeof(float), stream)); - if (tr_amax_ptr) - NVTE_CHECK_CUDA(cudaMemsetAsync(tr_amax_ptr, 0, sizeof(float), stream)); - - if (want_pre_rht) { - const uint64_t num_elems = (uint64_t)num_rows * row_length; - dim3 g(DIVUP(num_elems, (uint64_t)kThreadsPerBlock)); - PreRhtAmaxKernel<<>>(in_ptr,pre_amax_ptr,num_elems); - NVTE_CHECK_CUDA(cudaGetLastError()); - } - - if (want_identity || want_trans) { - dim3 grid 
= transform_grid(num_rows, row_length), block(kThreadsPerBlock); -#define LAUNCH_A(IDENT,TRANS,UA,UAT) \ - HadamardTransformKernel \ - <<>>(in_ptr,nullptr,nullptr, \ - random_sign_mask,random_sign_mask_t, \ - (uint64_t)num_rows,(uint64_t)row_length, \ - id_amax_ptr,tr_amax_ptr,false) - - if (want_identity && want_trans) - LAUNCH_A(true, true, true, true); - else if (want_identity) - LAUNCH_A(true, false, true, false); - else - LAUNCH_A(false,true, false, true); - - NVTE_CHECK_CUDA(cudaGetLastError()); -#undef LAUNCH_A - } -#else // __HIP_PLATFORM_AMD__ -#if CUDA_VERSION >= 12080 +#if CUDA_VERSION >= 12080 || defined(__HIP_PLATFORM_AMD__) // Check input tensor NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, @@ -1151,16 +1086,6 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran return; } - // Zero out amaxes if needed - ZeroAmaxKernel<<<1, 1, 0, stream>>>(reinterpret_cast(output_pre_rht_amax.dptr), - reinterpret_cast(output_identity_amax.dptr), - reinterpret_cast(output_transpose_amax.dptr)); - NVTE_CHECK_CUDA(cudaGetLastError()); - - checkCuDriverContext(stream); - - using IType = bf16; - const size_t ndim = input.shape.size(); const size_t row_length = input.shape[ndim - 1]; size_t num_rows = 1; @@ -1168,6 +1093,41 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran num_rows *= input.shape[i]; } +#ifdef __HIP_PLATFORM_AMD__ + auto* pre_amax_ptr = reinterpret_cast(output_pre_rht_amax.dptr); + auto* id_amax_ptr = reinterpret_cast(output_identity_amax.dptr); + auto* tr_amax_ptr = reinterpret_cast(output_transpose_amax.dptr); + + NVTE_CHECK(row_length % kHadamardDim == 0, "row_length must be divisible by 16."); + NVTE_CHECK(num_rows % kHadamardDim == 0, "num_rows must be divisible by 16."); + + auto* in_ptr = reinterpret_cast(input.dptr); + + if (pre_amax_ptr) { + NVTE_CHECK_CUDA(cudaMemsetAsync(pre_amax_ptr, 0, sizeof(float), stream)); + } + if (id_amax_ptr) { + NVTE_CHECK_CUDA(cudaMemsetAsync(id_amax_ptr, 0, sizeof(float), stream)); + } + if (tr_amax_ptr) { + NVTE_CHECK_CUDA(cudaMemsetAsync(tr_amax_ptr, 0, sizeof(float), stream)); + } + + if (return_pre_rht_amax) { + const uint64_t num_elems = static_cast(num_rows) * row_length; + dim3 grid(DIVUP(num_elems, static_cast(kThreadsPerBlock))); + PreRhtAmaxKernel<<>>(in_ptr, pre_amax_ptr, num_elems); + NVTE_CHECK_CUDA(cudaGetLastError()); + } +#else + // Zero out amaxes if needed + ZeroAmaxKernel<<<1, 1, 0, stream>>>(pre_amax_ptr, id_amax_ptr, tr_amax_ptr); + NVTE_CHECK_CUDA(cudaGetLastError()); + + checkCuDriverContext(stream); + + using IType = bf16; + constexpr int kHadamardDimension = 16; NVTE_CHECK(row_length % kHadamardDimension == 0, "row_length must be divisible by hadamard_dimension."); @@ -1200,6 +1160,7 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY); dim3 grid(DIVUP(row_length, kChunkBlockXSmall), DIVUP(num_rows, kChunkBlockYSmall)); +#endif TRANSFORMER_ENGINE_SWITCH_CONDITION( return_transposed_amax, kReturnTransposedAmax, @@ -1207,6 +1168,17 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran TRANSFORMER_ENGINE_SWITCH_CONDITION( return_identity_amax, kReturnIdentityAmax, +#ifdef __HIP_PLATFORM_AMD__ + if (kReturnIdentityAmax || kReturnTransposedAmax) { + dim3 grid = transform_grid(num_rows, row_length), block(kThreadsPerBlock); + HadamardTransformKernel + <<>>(in_ptr, nullptr, nullptr, random_sign_mask, + random_sign_mask_t, 
static_cast(num_rows), + static_cast(row_length), id_amax_ptr, + tr_amax_ptr, false); + } +#else TRANSFORMER_ENGINE_SWITCH_CONDITION( return_pre_rht_amax, kReturnPreRhtAmax, @@ -1230,13 +1202,14 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran reinterpret_cast(output_identity_amax.dptr), reinterpret_cast(output_transpose_amax.dptr), random_sign_mask, random_sign_mask_t, num_rows, row_length);))); +#endif + )); NVTE_CHECK_CUDA(cudaGetLastError()); #else NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); -#endif // CUDA_VERSION >= 12080 -#endif // __HIP_PLATFORM_AMD__ +#endif // CUDA_VERSION >= 12080 || __HIP_PLATFORM_AMD__ } } // namespace transformer_engine From 26c5fb7061cfed05714a44eb683bd41799d8cfbc Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 31 Mar 2026 10:46:00 -0500 Subject: [PATCH 23/69] simplify pt 2 --- .../hadamard_transform/hadamard_transform.cu | 88 ++++++++----------- 1 file changed, 35 insertions(+), 53 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index d574d9ea2..407d4ed2e 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -937,50 +937,7 @@ static inline dim3 transform_grid(uint64_t num_rows, uint64_t row_length) { void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, uint16_t random_sign_mask_t, cudaStream_t stream) { NVTE_API_CALL(hadamard_transform); -#ifdef __HIP_PLATFORM_AMD__ - NVTE_CHECK(input_.dtype() == DType::kBFloat16, "Input must be BF16."); - NVTE_CHECK(input_.dim() >= 2, "Input must be >=2D."); - - const SimpleTensor& input = input_.data; - SimpleTensor identity_out; // Unused - SimpleTensor& transposed_out = output_.data; - - const bool want_identity = (identity_out.dptr != nullptr); - const bool want_transposed = (transposed_out.dptr != nullptr); - - if (!want_identity && !want_transposed) - return; - - const size_t ndim = input.shape.size(); - const size_t row_length = input.shape[ndim - 1]; - size_t num_rows = 1; - - for (size_t i = 0; i < ndim - 1; ++i) - num_rows *= input.shape[i]; - - NVTE_CHECK(row_length % kHadamardDim == 0, "row_length must be divisible by 16."); - NVTE_CHECK(num_rows % kHadamardDim == 0, "num_rows must be divisible by 16."); - auto* in_ptr = reinterpret_cast(input.dptr); - auto* id_ptr = reinterpret_cast<__hip_bfloat16*>(identity_out.dptr); - auto* tr_ptr = reinterpret_cast<__hip_bfloat16*>(transposed_out.dptr); - dim3 grid = transform_grid(num_rows, row_length), block(kThreadsPerBlock); - -#define LAUNCH_T(IDENT, TRANS) \ - HadamardTransformKernel \ - <<>>(in_ptr, id_ptr, tr_ptr, \ - random_sign_mask, random_sign_mask_t, \ - (uint64_t)num_rows, (uint64_t)row_length, nullptr, nullptr, false) - - if (want_identity && want_transposed) - LAUNCH_T(true, true); - else if (want_identity) - LAUNCH_T(true, false); - else - LAUNCH_T(false, true); - NVTE_CHECK_CUDA(cudaGetLastError()); -#undef LAUNCH_T -#else // CUDA // Check tensors // NOTE (frsun): This is non-intuitive, we are writing the result of // transposed RHT to the output of rowwise. 
@@ -1004,7 +961,9 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s return; } +#ifndef __HIP_PLATFORM_AMD__ checkCuDriverContext(stream); +#endif const size_t ndim = input.shape.size(); const size_t row_length = input.shape[ndim - 1]; @@ -1013,14 +972,23 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s num_rows *= input.shape[i]; } +#ifdef __HIP_PLATFORM_AMD__ + using IType = __hip_bfloat16; + constexpr int kHadamardDimension = kHadamardDim; +#else using IType = bf16; constexpr int kHadamardDimension = 16; +#endif NVTE_CHECK(row_length % kHadamardDimension == 0, "row_length must be divisible by hadamard_dimension."); NVTE_CHECK(num_rows % kHadamardDimension == 0, "num_rows must be divisible by hadamard_dimension"); +#ifdef __HIP_PLATFORM_AMD__ + dim3 block(kThreadsPerBlock); + dim3 grid = transform_grid(num_rows, row_length); +#else constexpr uint64_t kThreadBlockX = 4; // Configure 4 is used for Hopper, 8 is used for Blackwell for extra memory bandwidth. constexpr uint64_t kThreadBlockY = 4; @@ -1034,6 +1002,7 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s dim3 grid(DIVUP(row_length / kHadamardDimension, kThreadBlockX), DIVUP(num_rows / kHadamardDimension, kThreadBlockY)); +#endif TRANSFORMER_ENGINE_SWITCH_CONDITION( return_transposed, kReturnTransposed, @@ -1041,6 +1010,15 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s TRANSFORMER_ENGINE_SWITCH_CONDITION( return_identity, kReturnIdentity, +#ifdef __HIP_PLATFORM_AMD__ + HadamardTransformKernel + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), random_sign_mask, + random_sign_mask_t, static_cast(num_rows), + static_cast(row_length), nullptr, nullptr, false););); +#else auto kernel = HadamardTransformKernel; @@ -1051,9 +1029,9 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s reinterpret_cast(input.dptr), reinterpret_cast(output.dptr), reinterpret_cast(output_t.dptr), random_sign_mask, random_sign_mask_t, num_rows, row_length, nullptr, nullptr, false););); +#endif NVTE_CHECK_CUDA(cudaGetLastError()); -#endif // __HIP_PLATFORM_AMD__ } // Kernel that will apply the 16x16 hadamard transform the input and input.T, and then @@ -1086,6 +1064,18 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran return; } +#ifndef __HIP_PLATFORM_AMD__ + // Zero out amaxes if needed + ZeroAmaxKernel<<<1, 1, 0, stream>>>(reinterpret_cast(output_pre_rht_amax.dptr), + reinterpret_cast(output_identity_amax.dptr), + reinterpret_cast(output_transpose_amax.dptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); + + checkCuDriverContext(stream); + + using IType = bf16; +#endif + const size_t ndim = input.shape.size(); const size_t row_length = input.shape[ndim - 1]; size_t num_rows = 1; @@ -1120,14 +1110,6 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran NVTE_CHECK_CUDA(cudaGetLastError()); } #else - // Zero out amaxes if needed - ZeroAmaxKernel<<<1, 1, 0, stream>>>(pre_amax_ptr, id_amax_ptr, tr_amax_ptr); - NVTE_CHECK_CUDA(cudaGetLastError()); - - checkCuDriverContext(stream); - - using IType = bf16; - constexpr int kHadamardDimension = 16; NVTE_CHECK(row_length % kHadamardDimension == 0, "row_length must be divisible by hadamard_dimension."); @@ -1178,6 +1160,7 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran static_cast(row_length), id_amax_ptr, 
tr_amax_ptr, false); } + )); #else TRANSFORMER_ENGINE_SWITCH_CONDITION( return_pre_rht_amax, kReturnPreRhtAmax, @@ -1203,7 +1186,6 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran reinterpret_cast(output_transpose_amax.dptr), random_sign_mask, random_sign_mask_t, num_rows, row_length);))); #endif - )); NVTE_CHECK_CUDA(cudaGetLastError()); #else From 2087f24d9d9fb76906fb4fc890c5df0a6665b0e6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 31 Mar 2026 15:15:30 -0500 Subject: [PATCH 24/69] expand test --- .../nvfp4/test_nvfp4_rht_quantize_exact.py | 74 +++++++++++++++++-- 1 file changed, 68 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index 5fd55e059..dfeecd7c2 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -273,14 +273,54 @@ def _ref_wht16_tiled(x: torch.Tensor, sign_mask: int) -> torch.Tensor: return out +def _ref_quantize_wht16_tiled( + x: torch.Tensor, sign_mask: int, global_amax: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + # Mirror the non-fused TE RHT path by quantizing the BF16-rounded WHT(x.T) + # with the same global amax used by the TE columnwise output. + + x_t_rht = _ref_wht16_tiled(x.t().contiguous(), sign_mask=sign_mask).to(dtype=x.dtype) + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=False, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + with_rht=False, + with_random_sign_mask=False, + ) + + x_t_rht_padded = ref_quantizer._pad_tensor( + x_t_rht, + row_divisor=ref_quantizer.quant_tile_shape[0], + col_divisor=ref_quantizer.quant_tile_shape[1], + ) + + qx_t_ref, sx_t_ref = ref_quantizer._quantize_blockwise_reference( + x_t_rht_padded, + global_amax, + ref_quantizer.quant_tile_shape[1], + ref_quantizer.quant_tile_shape[0], + pow_2_scales=ref_quantizer.pow_2_scales, + eps=ref_quantizer.eps, + ) + + qx_t_ref = ref_quantizer._rm_pad_tensor(qx_t_ref, (x_t_rht.shape[0], x_t_rht.shape[1] // 2)) + + return qx_t_ref, sx_t_ref + + @pytest.mark.parametrize("rows,cols", [(64, 64), (128, 128)]) def test_hadamard_transform_amax(rows, cols): """ - Tests nvte_hadamard_transform_amax via NVFP4Quantizer (with_rht=True). - Exercises the WHT kernel without requiring a full NVFP4 recipe. + Tests hadamard_transform_amax() and hadamard_transform() via NVFP4Quantizer + (with_rht=True), without requiring a full NVFP4 recipe. 
Checks: - amax_rowwise == max|x| (pre-RHT amax of raw input) - amax_colwise == max|WHT(x.T)| (post-RHT amax of transposed input) + - packed columnwise output == quantized WHT(x.T) derived from hadamard_transform() + in the non-fused TE path """ torch.manual_seed(42) x = torch.randn((rows, cols), dtype=torch.bfloat16, device="cuda").contiguous() @@ -302,8 +342,7 @@ def test_hadamard_transform_amax(rows, cols): torch.testing.assert_close( out._amax_rowwise.float().squeeze(), expected_rowwise_amax, - rtol=1e-3, atol=1e-3, - msg=f"pre-RHT amax mismatch rows={rows} cols={cols}", + rtol=0, atol=0, ) # amax_colwise: post-RHT of x.T, should equal max|WHT(x.T)| @@ -315,6 +354,29 @@ def test_hadamard_transform_amax(rows, cols): torch.testing.assert_close( out._amax_columnwise.float().squeeze().item(), float(expected_colwise_amax), - rtol=2e-2, atol=2e-2, - msg=f"post-RHT amax mismatch rows={rows} cols={cols}", + rtol=0, atol=0, + ) + + assert out._columnwise_data is not None + assert out._columnwise_scale_inv is not None + + qx_t_ref, sx_t_ref = _ref_quantize_wht16_tiled(x, sign_mask_t, out._amax_columnwise) + + qx_t = unpack_fp4(out._columnwise_data.view(torch.uint8)) + qx_t_ref = unpack_fp4(qx_t_ref.view(torch.uint8)) + torch.testing.assert_close( + qx_t, + qx_t_ref, + atol=0.0, + rtol=0.0, + ) + + sx_t = out._columnwise_scale_inv + sx_t_ref = sx_t_ref.view(dtype=torch.uint8) + sx_t_valid = sx_t[: sx_t_ref.shape[0], : sx_t_ref.shape[1]] + torch.testing.assert_close( + sx_t_valid, + sx_t_ref, + atol=0.0, + rtol=0.0, ) From 05cedb7371b542247c31cca60fd85854819a7e79 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 31 Mar 2026 18:51:33 -0500 Subject: [PATCH 25/69] compute amax from BF16-rounded outputs --- .../nvfp4/test_nvfp4_rht_quantize_exact.py | 27 ++++++------------- .../hadamard_transform/hadamard_transform.cu | 16 +++++++++-- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index dfeecd7c2..de9495edf 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -276,8 +276,8 @@ def _ref_wht16_tiled(x: torch.Tensor, sign_mask: int) -> torch.Tensor: def _ref_quantize_wht16_tiled( x: torch.Tensor, sign_mask: int, global_amax: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - # Mirror the non-fused TE RHT path by quantizing the BF16-rounded WHT(x.T) - # with the same global amax used by the TE columnwise output. + # Mirror the TE columnwise RHT path by BF16-rounding WHT(x.T) + # before applying NVFP4 reference quantization with the TE global amax. x_t_rht = _ref_wht16_tiled(x.t().contiguous(), sign_mask=sign_mask).to(dtype=x.dtype) ref_quantizer = NVFP4QuantizerRef( @@ -314,13 +314,12 @@ def _ref_quantize_wht16_tiled( @pytest.mark.parametrize("rows,cols", [(64, 64), (128, 128)]) def test_hadamard_transform_amax(rows, cols): """ - Tests hadamard_transform_amax() and hadamard_transform() via NVFP4Quantizer - (with_rht=True), without requiring a full NVFP4 recipe. + Tests hadamard_transform_amax() via NVFP4Quantizer (with_rht=True), + without requiring a full NVFP4 recipe. 
Checks: - amax_rowwise == max|x| (pre-RHT amax of raw input) - amax_colwise == max|WHT(x.T)| (post-RHT amax of transposed input) - - packed columnwise output == quantized WHT(x.T) derived from hadamard_transform() - in the non-fused TE path + - packed columnwise output == quantized BF16-rounded WHT(x.T) """ torch.manual_seed(42) x = torch.randn((rows, cols), dtype=torch.bfloat16, device="cuda").contiguous() @@ -348,7 +347,7 @@ def test_hadamard_transform_amax(rows, cols): # amax_colwise: post-RHT of x.T, should equal max|WHT(x.T)| sign_mask_t = quantizer.rht_matrix_random_sign_mask_t x_t = x.t().contiguous() # (cols, rows) - wht_x_t = _ref_wht16_tiled(x_t, sign_mask=sign_mask_t) + wht_x_t = _ref_wht16_tiled(x_t, sign_mask=sign_mask_t).to(torch.bfloat16).float() expected_colwise_amax = wht_x_t.float().abs().max() torch.testing.assert_close( @@ -364,19 +363,9 @@ def test_hadamard_transform_amax(rows, cols): qx_t = unpack_fp4(out._columnwise_data.view(torch.uint8)) qx_t_ref = unpack_fp4(qx_t_ref.view(torch.uint8)) - torch.testing.assert_close( - qx_t, - qx_t_ref, - atol=0.0, - rtol=0.0, - ) + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) sx_t = out._columnwise_scale_inv sx_t_ref = sx_t_ref.view(dtype=torch.uint8) sx_t_valid = sx_t[: sx_t_ref.shape[0], : sx_t_ref.shape[1]] - torch.testing.assert_close( - sx_t_valid, - sx_t_ref, - atol=0.0, - rtol=0.0, - ) + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 407d4ed2e..5df865751 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -809,7 +809,13 @@ void HadamardTransformKernel( if (global_row < num_rows) { wht16(r0, r1, r2, r3, thread_in_grp, random_sign_mask, apply_pre); if constexpr (kUpdateAmax) { - lam = fmaxf(fmaxf(fabsf(r0),fabsf(r1)),fmaxf(fabsf(r2),fabsf(r3))); + // Match the stored/output precision when reporting amax. + const float r0_bf16 = to_f32(to_bf16(r0)); + const float r1_bf16 = to_f32(to_bf16(r1)); + const float r2_bf16 = to_f32(to_bf16(r2)); + const float r3_bf16 = to_f32(to_bf16(r3)); + lam = fmaxf(fmaxf(fabsf(r0_bf16), fabsf(r1_bf16)), + fmaxf(fabsf(r2_bf16), fabsf(r3_bf16))); for (int off=kWarpSize/2; off>=1; off>>=1) lam=fmaxf(lam,__shfl_xor(lam,off)); } @@ -845,7 +851,13 @@ void HadamardTransformKernel( wht16(c0, c1, c2, c3, thread_in_grp, random_sign_mask_t, apply_pre); if constexpr (kUpdateAmaxT) { - lam = fmaxf(fmaxf(fabsf(c0),fabsf(c1)),fmaxf(fabsf(c2),fabsf(c3))); + // Match the stored/output precision when reporting amax. 
+ const float c0_bf16 = to_f32(to_bf16(c0)); + const float c1_bf16 = to_f32(to_bf16(c1)); + const float c2_bf16 = to_f32(to_bf16(c2)); + const float c3_bf16 = to_f32(to_bf16(c3)); + lam = fmaxf(fmaxf(fabsf(c0_bf16), fabsf(c1_bf16)), + fmaxf(fabsf(c2_bf16), fabsf(c3_bf16))); for (int off=kWarpSize/2; off>=1; off>>=1) lam=fmaxf(lam,__shfl_xor(lam,off)); From 67b93a800eb1f6f11ca140b072db0b35ded9e55a Mon Sep 17 00:00:00 2001 From: ipanfilo <145064111+ipanfilo@users.noreply.github.com> Date: Tue, 31 Mar 2026 21:52:26 -0400 Subject: [PATCH 26/69] TE building over TheRock (#511) * Update Dockerfile to use ROCm TheRock * Update wheels building script to work with ROCm TheRock and the latest Manylinux image * Support default ROCm location /opt/rocm/core * Fix UB code build on TheRock * Support comma separated list of target GPU architectures * Guess ROCm build from HIP_PLATFORM --- .github/scripts/aiter_prebuild_upload.sh | 9 ++- .github/workflows/rocm-ci.yml | 7 +-- build_tools/rocm_utils.cmake | 20 +++++++ build_tools/utils.py | 56 +++++++++---------- .../wheel_utils/Dockerfile.rocm.manylinux.x86 | 17 +++--- build_tools/wheel_utils/build_wheels.sh | 45 ++++++++------- ci/_utils.sh | 10 +++- tests/cpp/CMakeLists.txt | 6 +- transformer_engine/common/CMakeLists.txt | 19 ++----- transformer_engine/common/__init__.py | 5 +- .../common/ck_fused_attn/aiter_prebuilt.cmake | 6 +- .../common/rocshmem_api/CMakeLists.txt | 4 +- .../common/util/cuda_runtime.cpp | 1 + 13 files changed, 115 insertions(+), 90 deletions(-) create mode 100644 build_tools/rocm_utils.cmake diff --git a/.github/scripts/aiter_prebuild_upload.sh b/.github/scripts/aiter_prebuild_upload.sh index fc13d12ee..ebbb29a1b 100755 --- a/.github/scripts/aiter_prebuild_upload.sh +++ b/.github/scripts/aiter_prebuild_upload.sh @@ -12,7 +12,14 @@ set -euo pipefail # Derive ROCm version and aiter commit -> cache key ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" -ROCM_PATH="${ROCM_PATH:-/opt/rocm}" +if [ -n "${ROCM_PATH:-}" ]; then + true # Use provided ROCM_PATH +elif [ -d "/opt/rocm/core" ]; then + ROCM_PATH="/opt/rocm/core" +else + ROCM_PATH="/opt/rocm" +fi +export ROCM_PATH ROCM_VER=`head -n1 "${ROCM_PATH}/.info/version" | cut -d. -f1` AITER_DIR="${ROOT_DIR}/3rdparty/aiter" diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 158dab761..a52d42b17 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -70,7 +70,7 @@ jobs: docker ps -a echo ">>> ROCm Installation:" - ls -d /opt/rocm* || echo "No /opt/rocm found" + (ls -d /opt/rocm/core-* || ls -d /opt/rocm-* || echo "No default ROCm path found") 2>/dev/null || true echo ">>> GPU info:" ls -l /dev/dri ls -l /dev/kfd @@ -247,11 +247,6 @@ jobs: set -x -o pipefail ulimit -c 0 # Disable core dumps - # debug output - ls -d /opt/rocm* - python --version - pip list | egrep "transformer_e|torch|jax|numpy|ml_dtypes|typing_ext" - HIP_VISIBLE_DEVICES=1 ci/pytorch.sh > /workspace/torch_sgpu.log 2>&1 & torch_pid=$!; echo Pytorch test pid $! diff --git a/build_tools/rocm_utils.cmake b/build_tools/rocm_utils.cmake new file mode 100644 index 000000000..dca794c0c --- /dev/null +++ b/build_tools/rocm_utils.cmake @@ -0,0 +1,20 @@ +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. +# License for AMD contributions = MIT. 
See LICENSE for more information + +#Determine ROCM_PATH +if(NOT "$ENV{ROCM_PATH}" STREQUAL "") + set(ROCM_PATH "$ENV{ROCM_PATH}") +elseif(EXISTS "/opt/rocm/core") + set(ROCM_PATH "/opt/rocm/core") +else() + set(ROCM_PATH "/opt/rocm") +endif() + +#Configure target GPU architectures +if(NOT DEFINED ENV{NVTE_ROCM_ARCH}) + SET(CMAKE_HIP_ARCHITECTURES gfx942 gfx950) +else() + # Accept comma separated list for NVTE_ROCM_ARCH + string(REPLACE "," ";" HIP_ARCH_LIST "$ENV{NVTE_ROCM_ARCH}") + SET(CMAKE_HIP_ARCHITECTURES ${HIP_ARCH_LIST}) +endif() diff --git a/build_tools/utils.py b/build_tools/utils.py index 2cb7a3768..29722034b 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -184,42 +184,30 @@ def rocm_build() -> bool: Determines which build platform to use: - If `NVTE_USE_ROCM` is set: - - Non-zero value: Use ROCm, if hipcc is detected. + - Any value except "0": Use ROCm, if hipcc is detected. - Zero value: Use CUDA, if nvcc is detected. - - If `NVTE_USE_ROCM` is not set: - - Attempt to auto-detect: Check for ROCm first, then CUDA. + - If `NVTE_USE_ROCM` is not set, guess if from `HIP_PLATFORM` if it is set. + - Otherwise auto-detect trying ROCm first. Returns: bool: `True` for ROCm, `False` for CUDA. Raises: - ValueError: If NVTE_USE_ROCM is set to invalid value. FileNotFoundError: If required tools (hipcc or nvcc) are not found. """ - nvte_use_rocm = os.getenv("NVTE_USE_ROCM") - if nvte_use_rocm: + nvte_use_rocm = os.getenv("NVTE_USE_ROCM", "") + if not nvte_use_rocm: + match os.getenv("HIP_PLATFORM", ""): + case "amd": nvte_use_rocm = "1" + case "nvidia": nvte_use_rocm = "0" + if nvte_use_rocm != "0": try: - nvte_use_rocm = bool(int(nvte_use_rocm)) - except ValueError: - raise ValueError( - f"Invalid value for NVTE_USE_ROCM: '{nvte_use_rocm}'.") - - if nvte_use_rocm: - _, hipcc_bin = rocm_path() - if hipcc_bin.is_file(): - return True - else: - raise FileNotFoundError(f"Could not find hipcc at {hipcc_bin}") - else: - nvcc_path() - return False - - # Try to detect ROCm - _, hipcc_bin = rocm_path() - if hipcc_bin.is_file(): - return True - - # Try to detect CUDA + rocm_path() + return True + except FileNotFoundError: + if nvte_use_rocm: + raise FileNotFoundError("Could not find ROCm installation.") + # Try to detect CUDA if NVTE_USE_ROCM is set to "0" or ROCm is not found try: nvcc_path() return False @@ -230,8 +218,10 @@ def rocm_build() -> bool: @functools.lru_cache(maxsize=None) def rocm_path() -> Tuple[str, str]: - """ROCm root path and HIPCC binary path as a tuple""" - """If ROCm installation is not specified, use default /opt/rocm path""" + """ + ROCm root path and HIPCC binary path as a tuple + If ROCm installation is not specified, use default ROCm path + """ hipcc_bin = None if os.getenv("ROCM_PATH"): rocm_home = Path(os.getenv("ROCM_PATH")) @@ -239,11 +229,15 @@ def rocm_path() -> Tuple[str, str]: if hipcc_bin is None: hipcc_bin = shutil.which("hipcc") if hipcc_bin is not None: - hipcc_bin = Path(hipcc_bin) + hipcc_bin = Path(hipcc_bin).resolve() rocm_home = hipcc_bin.parent.parent if hipcc_bin is None: - rocm_home = Path("/opt/rocm/") + rocm_home = Path("/opt/rocm/core") + if not rocm_home.is_dir(): + rocm_home = Path("/opt/rocm/") hipcc_bin = rocm_home / "bin" / "hipcc" + if not hipcc_bin.is_file(): + raise FileNotFoundError(f"Could not find hipcc at {hipcc_bin}") return rocm_home, hipcc_bin diff --git a/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 b/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 index 6b908f9bc..cb6782ebe 100644 --- 
a/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 +++ b/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 @@ -7,7 +7,7 @@ # Build args: # BASE_IMAGE - Base manylinux image to use. Default: quay.io/pypa/manylinux_2_28_x86_64 -# ROCM_REPO_URL - ROCm repository URL. Default: https://repo.radeon.com/rocm/rhel8/latest/main/ +# ROCM_REPO_URL - ROCm TheRock repository URL. Default: https://repo.amd.com/rocm/packages/rhel8/x86_64/ # GPU_TARGETS - Semicolon separated list of target GPU architectures. Default: "gfx942;gfx950" # TARGET_BRANCH - Target branch for TransformerEngine. Default: none (use git default) # GPU_TARGETS and TARGET_BRANCH can be overriden when start a container with NVTE_ROCM_ARCH and TARGET_BRANCH environment variables. @@ -16,22 +16,25 @@ ARG BASE_IMAGE=quay.io/pypa/manylinux_2_28_x86_64 FROM $BASE_IMAGE -ARG ROCM_REPO_URL=https://repo.radeon.com/rocm/rhel8/latest/main/ +ARG ROCM_REPO_URL=https://repo.amd.com/rocm/packages/rhel8/x86_64/ # Set up ROCm repo RUN echo -e "[rocm]\nname=ROCm\nbaseurl=${ROCM_REPO_URL}\nenabled=1\ngpgcheck=0" > /etc/yum.repos.d/rocm.repo # Setup packages -RUN dnf install -y --disablerepo=epel rocm-dev hipblaslt hipblaslt-devel hipcub hipcub-devel hiprand-devel rocrand-devel -RUN dnf group install -y "Development Tools" && dnf install -y git cmake llvm-toolset gcc-toolset-12 +# ROCm TheRock packages are built for single GPU family. However, nothing from HW libraries is statically linked to TE, +# so they can be used to build for multiple GPU targets. Using gfx950 here. +RUN dnf install -y --disablerepo=epel amdrocm-core-devel-gfx950 -#Uncomment the next line for ROCm 6.4 cmake workaround: remove newer incomnpatible cmake preinstalled on base image -#RUN rm /usr/local/bin/cmake || true +# xz-devel installs lzma needed by AOTriton +RUN dnf install -y gcc-toolset-13 xz-devel RUN dnf clean all RUN rm -rf /var/cache/dnf/* -ENV CMAKE_PREFIX_PATH=/opt/rocm/lib/cmake +ENV HIP_PLATFORM=amd +ENV ROCM_PATH=/opt/rocm/core + ENV NVTE_RELEASE_BUILD=1 ARG GPU_TARGETS="gfx942;gfx950" diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index ec5a1d0c4..c1ecb409a 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -22,17 +22,21 @@ mkdir -p /wheelhouse/logs git config --global --add safe.directory /TransformerEngine cd /TransformerEngine -#If there is default Python installation, use it -PYTHON=`which python || true` -if [ -z "$PYTHON" ]; then - PYBINDIR=/opt/python/cp310-cp310/bin/ - #hipify expects python in PATH, also ninja may be installed to python bindir - PATH="$PYBINDIR:$PATH" -else - PYBINDIR="" #python bindir is already in PATHs -fi - -ROCM_BUILD=$(${PYBINDIR}python -c "import build_tools.utils as u; print('true' if u.rocm_build() else 'false')") +#hipify and aiter expect python in PATH, also ninja may be installed to python bindir +#set it first because system python may be too old +PATH="/opt/python/cp310-cp310/bin/:$PATH" + +case "$HIP_PLATFORM" in + amd) + ROCM_BUILD=true + ;; + nvidia) + ROCM_BUILD=false + ;; + *) + ROCM_BUILD=$(python -c "import build_tools.utils as u; print('true' if u.rocm_build() else 'false')") + ;; +esac if [ "$LOCAL_TREE_BUILD" != "1" ]; then if $ROCM_BUILD ; then @@ -45,22 +49,23 @@ else fi if $ROCM_BUILD ; then - #dataclasses, psutil are needed for AITER - ${PYBINDIR}pip install pybind11[global] ninja setuptools dataclasses psutil + pip install pybind11[global] ninja setuptools wheel + #modules needed to build AITER + pip install 
dataclasses psutil numpy pandas + export PATH=$PATH:$ROCM_PATH/bin else - PYBINDIR=/opt/python/cp310-cp310/bin/ /opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel fi if $BUILD_METAPACKAGE ; then cd /TransformerEngine - NVTE_BUILD_METAPACKAGE=1 ${PYBINDIR}python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt + NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt mv dist/* /wheelhouse/ fi if $BUILD_COMMON -a $ROCM_BUILD ; then # Create the wheel. - ${PYBINDIR}python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt + python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt # Rename the wheel to make it python version agnostic. whl_name=$(basename dist/*) @@ -93,8 +98,8 @@ fi if $BUILD_PYTORCH -a $ROCM_BUILD ; then cd /TransformerEngine/transformer_engine/pytorch #Only need torch for creating sdist, install CPU version to avoid installing CUDA/ROCm dependencies - ${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/cpu - ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt + pip install torch --index-url https://download.pytorch.org/whl/cpu + python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt mv dist/* /wheelhouse/ elif $BUILD_PYTORCH ; then cd /TransformerEngine/transformer_engine/pytorch @@ -105,8 +110,8 @@ fi if $BUILD_JAX -a $ROCM_BUILD ; then cd /TransformerEngine/transformer_engine/jax - ${PYBINDIR}pip install jax - ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt + pip install jax + python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt mv dist/* /wheelhouse/ elif $BUILD_JAX ; then cd /TransformerEngine/transformer_engine/jax diff --git a/ci/_utils.sh b/ci/_utils.sh index 966efbcad..b4aae9cc7 100644 --- a/ci/_utils.sh +++ b/ci/_utils.sh @@ -225,7 +225,15 @@ check_test_filter() { start_message() { echo "Started with TEST_LEVEL=$TEST_LEVEL sGPU='$TEST_SGPU' mGPU='$TEST_MGPU' at `date`" - echo "ROCm: `ls -d /opt/rocm-*`" + if [ -n "$ROCM_PATH" ]; then + _rocm_path="$ROCM_PATH" + elif [ -d "/opt/rocm/core" ]; then + _rocm_path="/opt/rocm/core" + else + _rocm_path="/opt/rocm" + fi + _rocm_path=`$REALPATH "$_rocm_path"` + test -d "$_rocm_path" && echo "ROCm: $_rocm_path" || echo "ROCm path not found" python --version } diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index c69fce7b1..599f74599 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -33,11 +33,7 @@ if(USE_ROCM) # Disable Asserts In Code (Can't use asserts on HIP stack.) add_definitions(-DNDEBUG) add_definitions(-DUSE_ROCM) - if(NOT DEFINED ENV{NVTE_ROCM_ARCH}) - SET(CMAKE_HIP_ARCHITECTURES gfx942 gfx950) - else() - SET(CMAKE_HIP_ARCHITECTURES $ENV{NVTE_ROCM_ARCH}) - endif() + include("${CMAKE_CURRENT_SOURCE_DIR}/../../build_tools/rocm_utils.cmake") else() if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 65b6c3997..5fd869cf2 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -133,23 +133,17 @@ else() message(FATAL_ERROR "This project requires the Ninja build system. Install it using 'pip install ninja'.") endif() + include("${CMAKE_CURRENT_SOURCE_DIR}/../../build_tools/rocm_utils.cmake") + # Disable Asserts In Code (Can't use asserts on HIP stack.) 
add_definitions(-DNDEBUG) add_definitions(-DUSE_ROCM) - if(NOT DEFINED ENV{NVTE_ROCM_ARCH}) - SET(CMAKE_HIP_ARCHITECTURES gfx942 gfx950) - else() - SET(CMAKE_HIP_ARCHITECTURES $ENV{NVTE_ROCM_ARCH}) - endif() - # build error will be dup-ed parallel-jobs times # set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -parallel-jobs=4") if(CMAKE_BUILD_TYPE STREQUAL "Debug") set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -g") endif() - - list(APPEND CMAKE_MODULE_PATH "/opt/rocm") endif() set(message_line "-------------------------------------------------------------") @@ -323,6 +317,10 @@ else() message(STATUS "nvte hipified sources: ${te_hip_sources}") add_library(transformer_engine SHARED ${te_hip_sources}) + + # Workaround for TheRock installation that moved some headers from system-wide location + # to rocm_sysdeps but missing it in the default include path. + target_include_directories(transformer_engine SYSTEM PRIVATE "${ROCM_PATH}/lib/rocm_sysdeps/include") endif() target_include_directories(transformer_engine PUBLIC @@ -562,11 +560,6 @@ endif() install(TARGETS transformer_engine DESTINATION .) if (USE_ROCM) set_target_properties(transformer_engine PROPERTIES INSTALL_RPATH "$ORIGIN/lib;$ORIGIN/transformer_engine/lib") - if("$ENV{ROCM_PATH}" STREQUAL "") - set(ROCM_PATH "/opt/rocm") - else() - set(ROCM_PATH "$ENV{ROCM_PATH}") - endif() file(READ "${ROCM_PATH}/.info/version" ROCM_VER) string(STRIP "${ROCM_VER}" ROCM_VER) string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER}") diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 95719e188..fbc67910b 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -421,7 +421,10 @@ def _load_core_library(): if te_rocm_build: try: # Get installed ROCm version - with open(os.getenv("ROCM_PATH", "/opt/rocm") + "/.info/version", "r") as f: + for rocm_path in (os.getenv("ROCM_PATH"), "/opt/rocm/core", "/opt/rocm"): + if rocm_path and os.path.exists(os.path.join(rocm_path, ".info/version")): + break + with open(os.path.join(rocm_path, ".info/version"), "r") as f: rocm_version= f.read().strip().split('.')[:2] # Get ROCm version from the build info file diff --git a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake index e52e2c948..9b97114e5 100644 --- a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake +++ b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake @@ -10,11 +10,9 @@ if(POLICY CMP0135) cmake_policy(SET CMP0135 NEW) endif() +include("${CMAKE_CURRENT_SOURCE_DIR}/../../../build_tools/rocm_utils.cmake") + # Extract ROCm version -set(ROCM_PATH "$ENV{ROCM_PATH}") -if("${ROCM_PATH}" STREQUAL "") - set(ROCM_PATH "/opt/rocm") -endif() file(READ "${ROCM_PATH}/.info/version" ROCM_VER_CONTENT) string(STRIP "${ROCM_VER_CONTENT}" ROCM_VER_CONTENT) string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER_CONTENT}") diff --git a/transformer_engine/common/rocshmem_api/CMakeLists.txt b/transformer_engine/common/rocshmem_api/CMakeLists.txt index 124f30c09..dfa9bcbfa 100644 --- a/transformer_engine/common/rocshmem_api/CMakeLists.txt +++ b/transformer_engine/common/rocshmem_api/CMakeLists.txt @@ -19,10 +19,12 @@ set_target_properties(rocshmemapi PROPERTIES POSITION_INDEPENDENT_CODE ON ) +include("${CMAKE_CURRENT_SOURCE_DIR}/../../../build_tools/rocm_utils.cmake") + if(DEFINED ENV{ROCSHMEM_HOME}) set(ROCSHMEM_HOME "$ENV{ROCSHMEM_HOME}" CACHE STRING "Location of ROCSHMEM installation") else() - 
set(ROCSHMEM_HOME "/opt/rocm" CACHE STRING "Location of ROCSHMEM installation (default)") + set(ROCSHMEM_HOME "${ROCM_PATH}" CACHE STRING "Location of ROCSHMEM installation (default)") endif() set(ROCSHMEM_LIBRARY_PATH "${ROCSHMEM_HOME}/lib/librocshmem.a") if (EXISTS ${ROCSHMEM_LIBRARY_PATH}) diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 515131b1d..3505516ba 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -178,6 +178,7 @@ const std::string &include_directory(bool required) { #ifdef __HIP_PLATFORM_AMD__ std::vector> search_paths = {{"ROCM_PATH", ""}, {"HIP_PATH", ""}, + {"", "/opt/rocm/core"}, {"", "/opt/rocm"}}; #else std::vector> search_paths = {{"NVTE_CUDA_INCLUDE_DIR", ""}, From 465d5473aaf76cf8474481a120fb4b1fa439b5ef Mon Sep 17 00:00:00 2001 From: Meekail Zain <34613774+Micky774@users.noreply.github.com> Date: Fri, 5 Dec 2025 23:40:54 -0500 Subject: [PATCH 27/69] Typo fix (#397) From 9fb21f9ee233e2044dd56da9be93477ccdf086d2 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 31 Mar 2026 21:36:43 -0500 Subject: [PATCH 28/69] Add NVTE_UB_WITH_MPI to rocm build path --- transformer_engine/common/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 5fd869cf2..42e5c0449 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -353,6 +353,7 @@ target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_ target_include_directories(transformer_engine PRIVATE ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_INCLUDE_DIR}) +endif() # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI # Changed @@ -367,6 +368,7 @@ if (NVTE_UB_WITH_MPI) target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI) endif() +if(USE_CUDA) option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF) if (NVTE_ENABLE_NVSHMEM) add_subdirectory(nvshmem_api) From 986d8ba8732bfdce4a4881efa8f978998bf54891 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 1 Apr 2026 10:51:51 -0500 Subject: [PATCH 29/69] NVFP4: hadamard_transform_cast_fusion_columnwise --- transformer_engine/common/CMakeLists.txt | 4 +- .../hadamard_transform/hadamard_transform.cu | 69 +---- .../hadamard_transform_cast_fusion.cu | 281 ++++++++++++++++++ .../common/hadamard_transform/wht16.cuh | 82 +++++ transformer_engine/pytorch/csrc/quantizer.cpp | 6 - 5 files changed, 367 insertions(+), 75 deletions(-) create mode 100644 transformer_engine/common/hadamard_transform/wht16.cuh diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 059aefbca..511de348c 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -228,6 +228,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources activation/relu.cu activation/swiglu.cu hadamard_transform/hadamard_transform.cu + hadamard_transform/hadamard_transform_cast_fusion.cu transpose/quantize_transpose_vector_blockwise_fp4.cu) if(USE_CUDA) @@ -247,8 +248,7 @@ if(USE_CUDA) recipe/nvfp4.cu) list(APPEND transformer_engine_cuda_arch_specific_sources gemm/cutlass_grouped_gemm.cu - transpose/quantize_transpose_square_blockwise.cu - hadamard_transform/hadamard_transform_cast_fusion.cu) + transpose/quantize_transpose_square_blockwise.cu) else() #ROCm specific source codes list(APPEND 
transformer_engine_cpp_sources diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 5df865751..8808f9289 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -667,33 +667,9 @@ __global__ void HadamardTransformKernel(const T* __restrict__ input, T* __restri #ifdef __HIP_PLATFORM_AMD__ -namespace { - -static constexpr int kHadamardDim = 16; -static constexpr int kWarpSize = 64; -static constexpr int kThreadsPerWHT = 4; -static constexpr int kElemsPerThread = 4; -static constexpr int kRowsPerWarp = kWarpSize / kThreadsPerWHT; // 16 -static constexpr int kWarpsPerBlock = 4; -static constexpr int kRowsPerBlock = kRowsPerWarp * kWarpsPerBlock; // 64 -static constexpr int kThreadsPerBlock = kWarpSize * kWarpsPerBlock; // 256 -static constexpr float kHadamardScale = 0.25f; - -// ds_swizzle: sub-wavefront exchange without LDS. -// Same instructions as cast_transpose_mxfp4_kernel_shuffled.cu. -__device__ __forceinline__ float ds_swizzle_xor1(float v) { - float r; - asm volatile("ds_swizzle_b32 %0, %1 offset:0x041F\n\t" - "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v)); - return r; -} +#include "wht16.cuh" -__device__ __forceinline__ float ds_swizzle_xor2(float v) { - float r; - asm volatile("ds_swizzle_b32 %0, %1 offset:0x081F\n\t" - "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v)); - return r; -} +namespace { // BF16 helpers __device__ __forceinline__ float to_f32 (__hip_bfloat16 v) { return static_cast(v); } @@ -721,47 +697,6 @@ __device__ __forceinline__ uint64_t pack_bf16x4(float v0, float v1, float v2, fl | ((uint64_t)bf16_to_bits(to_bf16(v3)) << 48); } -// 16-point WHT: in-register, no shared memory. -// Adapted from cast_transpose_mxfp4_kernel_shuffled.cu::hadamard16_inplace, -// extended with NV random_sign_mask (uint16_t bitmask). -// thread_in_group [0,3]: drives ds_swizzle polarity (identical to MLPerf tid & 3). -// apply_pre=true -> D before WHT (forward); false -> D after WHT (inverse). -__device__ __forceinline__ void wht16( - float& v0, float& v1, float& v2, float& v3, - int thread_in_group, uint16_t sign_mask, bool apply_pre) { - auto sgn = [&](int k) -> float { - return ((sign_mask >> (thread_in_group * kElemsPerThread + k)) & 1u) ? -1.f : 1.f; - }; - - if (apply_pre) { - v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); - } - - // Stage 1: local H4 - float a0=v0+v1, a1=v0-v1, a2=v2+v3, a3=v2-v3; - v0=a0+a2; v2=a0-a2; v1=a1+a3; v3=a1-a3; - - // Stage 2: cross-thread XOR-1 - { float p0=ds_swizzle_xor1(v0), p1=ds_swizzle_xor1(v1), - p2=ds_swizzle_xor1(v2), p3=ds_swizzle_xor1(v3); - bool up=(thread_in_group&1); - v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1); - v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } - - // Stage 3: cross-thread XOR-2 - { float p0=ds_swizzle_xor2(v0), p1=ds_swizzle_xor2(v1), - p2=ds_swizzle_xor2(v2), p3=ds_swizzle_xor2(v3); - bool up=(thread_in_group>>1)&1; - v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1); - v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } - - v0*=kHadamardScale; v1*=kHadamardScale; v2*=kHadamardScale; v3*=kHadamardScale; - - if (!apply_pre) { - v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); - } -} - // Grid: blockIdx.x = col tile [0, row_length/16) // blockIdx.y = row batch [0, ceil(num_rows/64)) // Block: 256 threads = 4 wavefronts of 64 lanes. 
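[Editor's aside, not part of the patch] Patch 29 factors the in-register 16-point Walsh-Hadamard helper out of hadamard_transform.cu into the shared wht16.cuh header. For readers reviewing the kernels, here is a minimal scalar sketch of the math that helper implements, y = 0.25 * H16 * diag(+-1) * x, handy for sanity-checking the ds_swizzle version on the host. It is a sketch under assumptions: the element ordering of the sign mask and the names wht16_reference/kDim are illustrative only and do not match the kernel's lane mapping bit-for-bit.

#include <array>
#include <cstdint>
#include <cstdio>

constexpr int kDim = 16;
constexpr float kScale = 0.25f;  // 1/sqrt(16) keeps the 16-point transform orthonormal

// Scalar reference: apply the random sign flips (bit k of sign_mask set -> negate
// element k), then the natural-order Hadamard butterflies, then the 0.25 scale.
std::array<float, kDim> wht16_reference(std::array<float, kDim> v, uint16_t sign_mask) {
  for (int k = 0; k < kDim; ++k) {
    if ((sign_mask >> k) & 1u) v[k] = -v[k];
  }
  // Four butterfly stages (stride 1, 2, 4, 8) of pairwise sums/differences.
  for (int stride = 1; stride < kDim; stride <<= 1) {
    for (int base = 0; base < kDim; base += 2 * stride) {
      for (int i = base; i < base + stride; ++i) {
        const float a = v[i];
        const float b = v[i + stride];
        v[i] = a + b;
        v[i + stride] = a - b;
      }
    }
  }
  for (int k = 0; k < kDim; ++k) v[k] *= kScale;
  return v;
}

int main() {
  // A unit impulse spreads evenly: every output should be 0.25.
  std::array<float, kDim> x{};
  x[0] = 1.0f;
  const auto y = wht16_reference(x, /*sign_mask=*/0);
  for (float val : y) std::printf("%.2f ", val);
  std::printf("\n");
  return 0;
}

The in-kernel wht16() realizes the same butterflies across four lanes per 16-element group (local H4 stage plus two ds_swizzle XOR stages), so agreement with a reference like this is only up to the index permutation implied by the lane layout.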
diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 653de6206..bc3c044f9 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -9,19 +11,24 @@ #include #include #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif #include #include +#ifndef __HIP_PLATFORM_AMD__ #include #include #include +#endif #include "common/common.h" #include "common/util/cuda_runtime.h" #include "common/util/curanddx.hpp" #include "common/util/ptx.cuh" #include "common/utils.cuh" +#ifndef __HIP_PLATFORM_AMD__ #include "cutlass/arch/barrier.h" #include "cutlass/cutlass.h" #include "cutlass/gemm/collective/builders/sm100_common.inl" @@ -31,9 +38,11 @@ #include "cutlass/util/command_line.h" #include "cutlass/util/helper_cuda.hpp" #include "cutlass/util/print_error.hpp" +#endif // clang-format off +#ifndef __HIP_PLATFORM_AMD__ namespace transformer_engine { namespace detail { namespace { @@ -817,6 +826,278 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out } } // namespace transformer_engine +#else + +#include "wht16.cuh" + +namespace transformer_engine { + +namespace { + +__device__ __forceinline__ float to_f32(__hip_bfloat16 v) { return static_cast(v); } + +__device__ __forceinline__ float group_max_4(float v) { + v = fmaxf(v, ds_swizzle_xor1(v)); + v = fmaxf(v, ds_swizzle_xor2(v)); + return v; +} + +__device__ __forceinline__ float compute_global_encode_scale_fp4(const float global_amax) { +#if !defined(__HIP_DEVICE_COMPILE__) + const float fp8_max = detail::TypeExtrema::max; +#else + constexpr float fp8_max = detail::TypeExtrema::max; +#endif + constexpr float fp4_max = detail::TypeExtrema::max; + float global_encode_scale = fp8_max * fp4_max / global_amax; + global_encode_scale = fminf(global_encode_scale, detail::TypeExtrema::max); + return (global_amax == 0.f || global_encode_scale == 0.f) ? 
1.f : global_encode_scale; +} + +template +__device__ __forceinline__ ScaleType compute_decode_scale_fp4(const float amax, + const float global_encode_scale) { + float decode_scale = amax / detail::TypeExtrema::max; + decode_scale *= global_encode_scale; + decode_scale = fminf(decode_scale, detail::TypeExtrema::max); + return static_cast(decode_scale); +} + +template +__device__ __forceinline__ float compute_encode_scale_fp4(ScaleType decode_scale, + const float global_decode_scale) { + return fminf(1.0f / (static_cast(decode_scale) * global_decode_scale), + detail::TypeExtrema::max); +} + +__device__ __forceinline__ uint32_t get_rbits( + transformer_engine::curanddx::detail::philox4x32_native_state<10>& rng, uint4& random_uint4, + int& rnd_idx) { + if (rnd_idx == 4) { + rnd_idx = 0; + random_uint4 = rng.generate4(); + } + const uint32_t* const rbits_arr = reinterpret_cast(&random_uint4); + return rbits_arr[rnd_idx++]; +} + +template +__device__ __forceinline__ fp4e2m1x4 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, + const uint32_t rbits) { + if constexpr (kUseStochasticRounding) { +#if ARCH_HAS_STOCHASTIC_ROUNDING + union { + uint32_t ui32; + __hip_fp4x2_storage_t fp4x2[4]; + } packed{0}; + __amd_floatx2_storage_t packed01{in01.x, in01.y}; + __amd_floatx2_storage_t packed23{in23.x, in23.y}; + packed.ui32 = + __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(packed.ui32, packed01, rbits, 1.0f, 1); + const __hip_fp4x2_storage_t lo = packed.fp4x2[1]; + packed.ui32 = + __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(packed.ui32, packed23, rbits, 1.0f, 1); + const __hip_fp4x2_storage_t hi = packed.fp4x2[1]; + + fp4e2m1x4 result; + result.__x = static_cast<__hip_fp4x4_storage_t>( + lo | (static_cast<__hip_fp4x4_storage_t>(hi) << 8)); + return result; +#else + NVTE_DEVICE_ERROR("FP4 stochastic rounding on AMDGPU requires gfx950 or later."); + return fp4e2m1x4{}; +#endif + } else { + const __hip_fp4_storage_t q0 = + __hip_cvt_float_to_fp4(in01.x, __HIP_E2M1, hipRoundNearest); + const __hip_fp4_storage_t q1 = + __hip_cvt_float_to_fp4(in01.y, __HIP_E2M1, hipRoundNearest); + const __hip_fp4_storage_t q2 = + __hip_cvt_float_to_fp4(in23.x, __HIP_E2M1, hipRoundNearest); + const __hip_fp4_storage_t q3 = + __hip_cvt_float_to_fp4(in23.y, __HIP_E2M1, hipRoundNearest); + + fp4e2m1x4 result; + result.__x = static_cast<__hip_fp4x4_storage_t>((q0 & 0xFu) | ((q1 & 0xFu) << 4) | + ((q2 & 0xFu) << 8) | ((q3 & 0xFu) << 12)); + return result; + } +} + +__device__ __forceinline__ uint16_t fp4x4_to_bits(fp4e2m1x4 v) { + uint16_t bits; + __builtin_memcpy(&bits, &v, sizeof(bits)); + return bits; +} + +template +__global__ __launch_bounds__(kThreadsPerBlock, 4) void HadamardTransformCastFusionKernel( + const __hip_bfloat16* __restrict__ input, uint8_t* __restrict__ output_t, + fp8e4m3* __restrict__ scale_inv_t, const float global_amax, + const uint16_t random_sign_mask_t, const uint64_t num_rows, const uint64_t row_length, + const size_t scale_stride, const size_t* rng_state) { + const int tid = threadIdx.x; + const int warp_id = tid / kWarpSize; + const int lane_id = tid % kWarpSize; + const int row_in_warp = lane_id / kThreadsPerWHT; + const int thread_in_grp = lane_id % kThreadsPerWHT; + + const uint64_t output_row = static_cast(blockIdx.x) * kHadamardDim + row_in_warp; + const uint64_t block_row_base = + static_cast(blockIdx.y) * kRowsPerBlock + warp_id * kHadamardDim; + + if (block_row_base + kHadamardDim > num_rows) { + return; + } + + const uint64_t input_row_base = block_row_base + thread_in_grp * 
kElemsPerThread; + const uint64_t input_col = output_row; + + float c0 = to_f32(input[(input_row_base + 0) * row_length + input_col]); + float c1 = to_f32(input[(input_row_base + 1) * row_length + input_col]); + float c2 = to_f32(input[(input_row_base + 2) * row_length + input_col]); + float c3 = to_f32(input[(input_row_base + 3) * row_length + input_col]); + + wht16(c0, c1, c2, c3, thread_in_grp, random_sign_mask_t, /*apply_pre=*/true); + + // Truncate to BF16 precision to match the reference BF16 matmul path. + // Without this, FP32 WHT results at FP4 quantization boundaries round + // differently than the BF16-precision reference, causing off-by-one errors. + c0 = to_f32(static_cast<__hip_bfloat16>(c0)); + c1 = to_f32(static_cast<__hip_bfloat16>(c1)); + c2 = to_f32(static_cast<__hip_bfloat16>(c2)); + c3 = to_f32(static_cast<__hip_bfloat16>(c3)); + + const float local_block_amax = + fmaxf(fmaxf(fabsf(c0), fabsf(c1)), fmaxf(fabsf(c2), fabsf(c3))); + const float block_amax = group_max_4(local_block_amax); + + const float global_encode_scale = compute_global_encode_scale_fp4(global_amax); + const float global_decode_scale = 1.0f / global_encode_scale; + const fp8e4m3 scale_inv = compute_decode_scale_fp4(block_amax, global_encode_scale); + const float encode_scale = compute_encode_scale_fp4(scale_inv, global_decode_scale); + + if (thread_in_grp == 0) { + const uint64_t scale_col = block_row_base / kHadamardDim; + scale_inv_t[output_row * scale_stride + scale_col] = scale_inv; + } + + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + uint4 random_uint4{0, 0, 0, 0}; + int rnd_idx = 0; + if constexpr (kUseStochasticRounding) { + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + const size_t rng_sequence = static_cast(threadIdx.x) + + static_cast(blockIdx.x) * blockDim.x + + static_cast(blockIdx.y) * gridDim.x * blockDim.x; + rng.init(rng_seed, rng_sequence, rng_offset); + random_uint4 = rng.generate4(); + } + + const float2 scaled01{c0 * encode_scale, c1 * encode_scale}; + const float2 scaled23{c2 * encode_scale, c3 * encode_scale}; + const uint32_t rbits = kUseStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; + const uint16_t packed = fp4x4_to_bits(cvt_fp32_to_fp4_4x( + scaled01, scaled23, rbits)); + + const uint64_t output_col_base = input_row_base; + const uint64_t output_byte_offset = output_row * (num_rows / 2) + output_col_base / 2; + *reinterpret_cast(&output_t[output_byte_offset]) = packed; +} + +uint16_t random_sign_mask_from_rht_matrix(const SimpleTensor& hadamard_matrix, cudaStream_t stream) { + std::array host_matrix{}; + + NVTE_CHECK_CUDA(cudaMemcpyAsync(host_matrix.data(), hadamard_matrix.dptr, + host_matrix.size() * sizeof(uint16_t), + cudaMemcpyDeviceToHost, stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + uint16_t random_sign_mask = 0; + for (size_t row = 0; row < kHadamardDim; ++row) { + // The first column of diag(sign) @ H16 is sign[row] * 0.25. 
+ random_sign_mask |= static_cast(((host_matrix[row * kHadamardDim] >> 15) & 1) << row); + } + return random_sign_mask; +} + +} // namespace + +void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &output_, + const Tensor &hadamard_matrix_, + QuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(hadamard_transform_cast_fusion_columnwise); + + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + NVTE_CHECK(output_.scaling_mode == NVTE_NVFP4_1D_SCALING, + "Output tensor must use NVFP4 scaling, but scaling mode is ", + to_string(output_.scaling_mode), "."); + NVTE_CHECK(output_.data.dptr != nullptr, "Output rowwise data must be allocated."); + NVTE_CHECK(output_.scale_inv.dptr != nullptr, "Output rowwise scale_inv must be allocated."); + NVTE_CHECK(output_.amax.dptr != nullptr, "Output rowwise amax must be allocated."); + + NVTE_CHECK(hadamard_matrix_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Hadamard matrix must be BF16 tensor, but scaling mode is ", + to_string(hadamard_matrix_.scaling_mode), "."); + NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, + "Hadamard matrix must be BF16 tensor, but dtype is ", + to_string(hadamard_matrix_.dtype()), "."); + const auto expected_hadamard_shape = std::vector{kHadamardDim, kHadamardDim}; + NVTE_CHECK(hadamard_matrix_.shape() == expected_hadamard_shape, + "Hadamard matrix must have shape=", + expected_hadamard_shape, + ", but got shape=", hadamard_matrix_.shape(), "."); + + const SimpleTensor& input = input_.data; + const size_t ndim = input.shape.size(); + const size_t n = input.shape[ndim - 1]; + size_t m = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + m *= input.shape[i]; + } + + NVTE_CHECK(n % kHadamardDim == 0, "row_length must be divisible by hadamard_dimension."); + NVTE_CHECK(m % kHadamardDim == 0, "num_rows must be divisible by hadamard_dimension."); + + const size_t* rng_state = nullptr; + if (quant_config.rng_state != nullptr) { + Tensor& rng_state_tensor = *convertNVTETensor(quant_config.rng_state); + NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_tensor.data.dptr); + } + + const uint16_t random_sign_mask_t = + random_sign_mask_from_rht_matrix(hadamard_matrix_.data, stream); + + const dim3 block(kThreadsPerBlock); + const dim3 grid(DIVUP(n, static_cast(kHadamardDim)), + DIVUP(m, static_cast(kRowsPerBlock))); + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + quant_config.stochastic_rounding, kUseStochasticRounding, + HadamardTransformCastFusionKernel<<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output_.data.dptr), + reinterpret_cast(output_.scale_inv.dptr), + *reinterpret_cast(output_.amax.dptr), random_sign_mask_t, + static_cast(m), static_cast(n), output_.scale_inv.shape[1], + rng_state);); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace transformer_engine +#endif void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output, const NVTETensor 
hadamard_matrix, diff --git a/transformer_engine/common/hadamard_transform/wht16.cuh b/transformer_engine/common/hadamard_transform/wht16.cuh new file mode 100644 index 000000000..b9a1a51b7 --- /dev/null +++ b/transformer_engine/common/hadamard_transform/wht16.cuh @@ -0,0 +1,82 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +// Shared 16-point Walsh-Hadamard transform primitives for AMDGPU. + +#ifndef TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_WHT16_CUH_ +#define TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_WHT16_CUH_ + +#ifdef __HIP_PLATFORM_AMD__ + +static constexpr int kHadamardDim = 16; +static constexpr int kWarpSize = 64; +static constexpr int kThreadsPerWHT = 4; +static constexpr int kElemsPerThread = 4; +static constexpr int kRowsPerWarp = kWarpSize / kThreadsPerWHT; // 16 +static constexpr int kWarpsPerBlock = 4; +static constexpr int kRowsPerBlock = kRowsPerWarp * kWarpsPerBlock; // 64 +static constexpr int kThreadsPerBlock = kWarpSize * kWarpsPerBlock; // 256 +static constexpr float kHadamardScale = 0.25f; + +// ds_swizzle: sub-wavefront exchange without LDS. +__device__ __forceinline__ float ds_swizzle_xor1(float v) { + float r; + asm volatile("ds_swizzle_b32 %0, %1 offset:0x041F\n\t" + "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v)); + return r; +} + +__device__ __forceinline__ float ds_swizzle_xor2(float v) { + float r; + asm volatile("ds_swizzle_b32 %0, %1 offset:0x081F\n\t" + "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v)); + return r; +} + +// 16-point WHT: in-register, no shared memory. +// Adapted from cast_transpose_mxfp4_kernel_shuffled.cu::hadamard16_inplace, +// extended with NV random_sign_mask (uint16_t bitmask). +// thread_in_group [0,3]: drives ds_swizzle polarity (identical to MLPerf tid & 3). +// apply_pre=true -> D before WHT (forward); false -> D after WHT (inverse). +__device__ __forceinline__ void wht16( + float& v0, float& v1, float& v2, float& v3, + int thread_in_group, uint16_t sign_mask, bool apply_pre) { + auto sgn = [&](int k) -> float { + return ((sign_mask >> (thread_in_group * kElemsPerThread + k)) & 1u) ? 
-1.f : 1.f; + }; + + if (apply_pre) { + v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); + } + + // Stage 1: local H4 + float a0=v0+v1, a1=v0-v1, a2=v2+v3, a3=v2-v3; + v0=a0+a2; v2=a0-a2; v1=a1+a3; v3=a1-a3; + + // Stage 2: cross-thread XOR-1 + { float p0=ds_swizzle_xor1(v0), p1=ds_swizzle_xor1(v1), + p2=ds_swizzle_xor1(v2), p3=ds_swizzle_xor1(v3); + bool up=(thread_in_group&1); + v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1); + v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } + + // Stage 3: cross-thread XOR-2 + { float p0=ds_swizzle_xor2(v0), p1=ds_swizzle_xor2(v1), + p2=ds_swizzle_xor2(v2), p3=ds_swizzle_xor2(v3); + bool up=(thread_in_group>>1)&1; + v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1); + v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } + + v0*=kHadamardScale; v1*=kHadamardScale; v2*=kHadamardScale; v3*=kHadamardScale; + + if (!apply_pre) { + v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); + } +} + +#endif // __HIP_PLATFORM_AMD__ + +#endif // TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_WHT16_CUH_ diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 7ca0ab3bf..91321469b 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1515,10 +1515,6 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou bool eligible_for_rht_cast_fusion = input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; -#ifdef USE_ROCM - eligible_for_rht_cast_fusion = false; -#endif - // Compute amax. if (this->with_rht) { if (input.dtype() != DType::kBFloat16) { @@ -1666,7 +1662,6 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config, stream); }); } else { -#ifndef USE_ROCM // RHT cast fusion kernel. 
NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0, "RHT matrix is not set"); @@ -1675,7 +1670,6 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou nvte_hadamard_transform_cast_fusion_columnwise( input.data(), out_transpose.data(), rht_matrix_nvte.data(), quant_config, stream); }); -#endif } } } else { From b339c8670bbdacbcf0243247c96a1f3bf4ec25e2 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 1 Apr 2026 15:39:59 -0500 Subject: [PATCH 30/69] unify hadamard_transform_cast_fusion_columnwise --- .../hadamard_transform_cast_fusion.cu | 241 +++++++----------- 1 file changed, 99 insertions(+), 142 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index bc3c044f9..e6374a1e7 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -713,118 +713,6 @@ rht_gemm_ttt_wrapper(int m, int n, // clang-format on -void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &output_, - const Tensor &hadamard_matrix_, - QuantizationConfig quant_config, - cudaStream_t stream) { - NVTE_API_CALL(hadamard_transform_cast_fusion_columnwise); - - // Check input and output tensors - NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, - "Input tensor must be BF16 tensor, but scaling mode is ", - to_string(input_.scaling_mode), "."); - NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, - "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); - NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); - const SimpleTensor &input = input_.data; - SimpleTensor &global_amax = output_.amax; - SimpleTensor &output_t = output_.data; - SimpleTensor &scale_inv_t = output_.scale_inv; - - // Stochastic rounding config - const bool use_stochastic_rounding = quant_config.stochastic_rounding; - const size_t *rng_state = nullptr; - if (quant_config.rng_state != nullptr) { - Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state); - NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, - "RNG state should contain 2 64-bit values."); - NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, - "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); - rng_state = reinterpret_cast(rng_state_tensor.data.dptr); - } - - // Template arguments - using TA = cute::bfloat16_t; - using TB = cute::bfloat16_t; - using TC = cutlass::float_e2m1_t; - using TSFC = cutlass::float_ue4m3_t; - - checkCuDriverContext(stream); - - // Check Hadamard matrix - constexpr int kHadamardDimension = 16; - NVTE_CHECK(hadamard_matrix_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, - "Hadamard matrix must be BF16 tensor, but scaling mode is ", - to_string(hadamard_matrix_.scaling_mode), "."); - NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, - "Hadamard matrix must be BF16 tensor, but dtype is ", - to_string(hadamard_matrix_.dtype()), "."); - const SimpleTensor &hadamard_matrix = hadamard_matrix_.data; - NVTE_CHECK( - (hadamard_matrix_.shape() == std::vector{kHadamardDimension, kHadamardDimension}), - "Hadamard matrix must have shape=", - std::vector{kHadamardDimension, kHadamardDimension}, - ", but got shape=", hadamard_matrix_.shape(), "."); - const size_t hadamard_dimension = hadamard_matrix.shape[0]; - - const 
size_t ndim = input.shape.size(); - const size_t n = input.shape[ndim - 1]; - size_t m = 1; - for (size_t i = 0; i < ndim - 1; ++i) { - m *= input.shape[i]; - } - - auto sm_count = transformer_engine::cuda::sm_count(); - - NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension."); - - NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension"); - - int k_tile_size = 1024; - - if (m == 8192 && n == 5120) { - k_tile_size = 512; - } else if (m == 8192 && n == 10240) { - k_tile_size = 1024; - } else if (m == 8192 && n == 2560) { - k_tile_size = 1280; - } else if (m == 8192 && n == 11328) { - k_tile_size = 1024; - } else if (m == 8192 && n == 512) { - k_tile_size = 256; - } else if (m == 8192 && n == 3584) { - k_tile_size = 512; - } else if (m == 11328 && n == 8192) { - k_tile_size = 1024; - } else if (m == 5120 && n == 8192) { - k_tile_size = 512; - } else if (m == 10240 && n == 8192) { - k_tile_size = 1024; - } else if (m == 2560 && n == 8192) { - k_tile_size = 1280; - } else if (m == 512 && n == 8192) { - k_tile_size = 256; - } else if (m == 3584 && n == 8192) { - k_tile_size = 512; - } else if (m < 1024 || n < 1024) { - k_tile_size = 512; - } - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_stochastic_rounding, kUseStochasticRounding, - detail::rht_gemm_ttt_wrapper( - /*m=*/m, - /*n=*/n, - /*A=*/reinterpret_cast(input.dptr), - /*B=*/reinterpret_cast(hadamard_matrix.dptr), - /*C=*/reinterpret_cast(output_t.dptr), - /*SFC=*/reinterpret_cast(scale_inv_t.dptr), - /*global_amax=*/reinterpret_cast(global_amax.dptr), - /*rng_state=*/rng_state, - /*sm_count=*/sm_count, - /*stream=*/stream, - /*k_tile_size=*/k_tile_size);); -} - } // namespace transformer_engine #else @@ -1024,38 +912,67 @@ uint16_t random_sign_mask_from_rht_matrix(const SimpleTensor& hadamard_matrix, c } // namespace +} // namespace transformer_engine +#endif + +namespace transformer_engine { + void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &output_, const Tensor &hadamard_matrix_, QuantizationConfig quant_config, cudaStream_t stream) { NVTE_API_CALL(hadamard_transform_cast_fusion_columnwise); + // Check input and output tensors NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "Input tensor must be BF16 tensor, but scaling mode is ", to_string(input_.scaling_mode), "."); NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); - NVTE_CHECK(output_.scaling_mode == NVTE_NVFP4_1D_SCALING, - "Output tensor must use NVFP4 scaling, but scaling mode is ", - to_string(output_.scaling_mode), "."); - NVTE_CHECK(output_.data.dptr != nullptr, "Output rowwise data must be allocated."); - NVTE_CHECK(output_.scale_inv.dptr != nullptr, "Output rowwise scale_inv must be allocated."); - NVTE_CHECK(output_.amax.dptr != nullptr, "Output rowwise amax must be allocated."); + const SimpleTensor &input = input_.data; + SimpleTensor &global_amax = output_.amax; + SimpleTensor &output_t = output_.data; + SimpleTensor &scale_inv_t = output_.scale_inv; + + // Stochastic rounding config + const bool use_stochastic_rounding = quant_config.stochastic_rounding; + const size_t *rng_state = nullptr; + if (quant_config.rng_state != nullptr) { + Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state); + NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit 
values."); + NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_tensor.data.dptr); + } +#ifndef __HIP_PLATFORM_AMD__ + // Template arguments + using TA = cute::bfloat16_t; + using TB = cute::bfloat16_t; + using TC = cutlass::float_e2m1_t; + using TSFC = cutlass::float_ue4m3_t; + + checkCuDriverContext(stream); +#endif + + // Check Hadamard matrix + constexpr int kHadamardDimension = 16; NVTE_CHECK(hadamard_matrix_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "Hadamard matrix must be BF16 tensor, but scaling mode is ", to_string(hadamard_matrix_.scaling_mode), "."); NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, "Hadamard matrix must be BF16 tensor, but dtype is ", to_string(hadamard_matrix_.dtype()), "."); - const auto expected_hadamard_shape = std::vector{kHadamardDim, kHadamardDim}; - NVTE_CHECK(hadamard_matrix_.shape() == expected_hadamard_shape, - "Hadamard matrix must have shape=", - expected_hadamard_shape, - ", but got shape=", hadamard_matrix_.shape(), "."); + const SimpleTensor &hadamard_matrix = hadamard_matrix_.data; + NVTE_CHECK( + (hadamard_matrix_.shape() == std::vector{kHadamardDimension, kHadamardDimension}), + "Hadamard matrix must have shape=", + std::vector{kHadamardDimension, kHadamardDimension}, + ", but got shape=", hadamard_matrix_.shape(), "."); + const size_t hadamard_dimension = hadamard_matrix.shape[0]; - const SimpleTensor& input = input_.data; const size_t ndim = input.shape.size(); const size_t n = input.shape[ndim - 1]; size_t m = 1; @@ -1063,41 +980,81 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out m *= input.shape[i]; } - NVTE_CHECK(n % kHadamardDim == 0, "row_length must be divisible by hadamard_dimension."); - NVTE_CHECK(m % kHadamardDim == 0, "num_rows must be divisible by hadamard_dimension."); +#ifndef __HIP_PLATFORM_AMD__ + auto sm_count = transformer_engine::cuda::sm_count(); +#endif - const size_t* rng_state = nullptr; - if (quant_config.rng_state != nullptr) { - Tensor& rng_state_tensor = *convertNVTETensor(quant_config.rng_state); - NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, - "RNG state should contain 2 64-bit values."); - NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, - "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); - rng_state = reinterpret_cast(rng_state_tensor.data.dptr); - } + NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension."); + + NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension"); + +#ifndef __HIP_PLATFORM_AMD__ + int k_tile_size = 1024; + if (m == 8192 && n == 5120) { + k_tile_size = 512; + } else if (m == 8192 && n == 10240) { + k_tile_size = 1024; + } else if (m == 8192 && n == 2560) { + k_tile_size = 1280; + } else if (m == 8192 && n == 11328) { + k_tile_size = 1024; + } else if (m == 8192 && n == 512) { + k_tile_size = 256; + } else if (m == 8192 && n == 3584) { + k_tile_size = 512; + } else if (m == 11328 && n == 8192) { + k_tile_size = 1024; + } else if (m == 5120 && n == 8192) { + k_tile_size = 512; + } else if (m == 10240 && n == 8192) { + k_tile_size = 1024; + } else if (m == 2560 && n == 8192) { + k_tile_size = 1280; + } else if (m == 512 && n == 8192) { + k_tile_size = 256; + } else if (m == 3584 && n == 8192) { + k_tile_size = 512; + } else if (m < 1024 || n < 1024) { + k_tile_size = 512; + } + 
TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kUseStochasticRounding, + detail::rht_gemm_ttt_wrapper( + /*m=*/m, + /*n=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*C=*/reinterpret_cast(output_t.dptr), + /*SFC=*/reinterpret_cast(scale_inv_t.dptr), + /*global_amax=*/reinterpret_cast(global_amax.dptr), + /*rng_state=*/rng_state, + /*sm_count=*/sm_count, + /*stream=*/stream, + /*k_tile_size=*/k_tile_size);); +#else const uint16_t random_sign_mask_t = - random_sign_mask_from_rht_matrix(hadamard_matrix_.data, stream); + random_sign_mask_from_rht_matrix(hadamard_matrix, stream); const dim3 block(kThreadsPerBlock); const dim3 grid(DIVUP(n, static_cast(kHadamardDim)), DIVUP(m, static_cast(kRowsPerBlock))); TRANSFORMER_ENGINE_SWITCH_CONDITION( - quant_config.stochastic_rounding, kUseStochasticRounding, + use_stochastic_rounding, kUseStochasticRounding, HadamardTransformCastFusionKernel<<>>( reinterpret_cast(input.dptr), - reinterpret_cast(output_.data.dptr), - reinterpret_cast(output_.scale_inv.dptr), - *reinterpret_cast(output_.amax.dptr), random_sign_mask_t, - static_cast(m), static_cast(n), output_.scale_inv.shape[1], + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv_t.dptr), + *reinterpret_cast(global_amax.dptr), random_sign_mask_t, + static_cast(m), static_cast(n), scale_inv_t.shape[1], rng_state);); NVTE_CHECK_CUDA(cudaGetLastError()); +#endif } } // namespace transformer_engine -#endif void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output, const NVTETensor hadamard_matrix, From 1d0a70eb69ff84bdfd4f32f308e6b4873a1639bd Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Thu, 2 Apr 2026 15:12:12 +0000 Subject: [PATCH 31/69] Rebase onto dev --- tests/cpp/operator/CMakeLists.txt | 1 + .../cpp/operator/test_cast_nvfp4_transpose.cu | 7 - tests/cpp/operator/test_dequantize_nvfp4.cu | 410 ++++++++++++++++++ tests/cpp/test_common.h | 22 +- .../common/cast/dispatch/dequantize.cuh | 4 - ...quantize_transpose_vector_blockwise_fp4.cu | 15 +- 6 files changed, 446 insertions(+), 13 deletions(-) create mode 100644 tests/cpp/operator/test_dequantize_nvfp4.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index dfd8fba29..8a19e84f5 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -14,6 +14,7 @@ list(APPEND test_cuda_sources test_qdq.cu test_cast_mxfp8.cu test_dequantize_mxfp8.cu + test_dequantize_nvfp4.cu test_cast_nvfp4_transpose.cu test_transpose.cu test_cast_transpose.cu diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 4e42fad92..50f2d36fe 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -32,13 +32,6 @@ enum ActivationType { SReLU }; -#ifdef __HIP_PLATFORM_AMD__ -static constexpr float E2M1_LUT[16] = { - 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, - -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f, -}; -#endif - double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) { #ifdef __HIP_PLATFORM_AMD__ uint8_t raw = *reinterpret_cast(&fp4_pair); diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu new file mode 100644 index 000000000..1da70923a --- /dev/null +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -0,0 +1,410 @@ +/************************************************************************* + * Copyright (c) 2026, 
Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +constexpr size_t kFP4BlockSize1D = 16; +constexpr size_t kFP4BlockSize2DY = 16; +constexpr size_t kFP4BlockSize2DX = 16; + +size_t divide_round_up(size_t x, size_t y) { + return (x + y - 1) / y; +} + +// Generates random FP8 (E4M3) scale values by sampling raw 8-bit patterns +// and rejects non-finite values (i.e., NaN) before storing. +// Values are written using memcpy to preserve exact +// bit patterns rather than relying on numeric conversion. +void generate_1d_scales(fp8e4m3* scale_buffer, + const size_t mathematical_rows, + const size_t mathematical_blocks_per_row, + const size_t physical_row_stride, + std::mt19937& gen, + std::uniform_int_distribution& dis) { + const size_t total_elems = mathematical_rows * physical_row_stride; + std::memset(scale_buffer, 0, total_elems * sizeof(fp8e4m3)); + + for (size_t row = 0; row < mathematical_rows; ++row) { + for (size_t block = 0; block < mathematical_blocks_per_row; ++block) { + const size_t idx = row * physical_row_stride + block; + + while (true) { + const uint8_t bits = static_cast(dis(gen)); + + fp8e4m3 candidate; + std::memcpy(&candidate, &bits, sizeof(bits)); + + const float decoded = static_cast(candidate); + if (std::isfinite(decoded)) { + scale_buffer[idx] = candidate; + break; + } + } + } + } +} + +// Generate compact 2D scales over 16x16 tiles, then replicate them row-wise +// into the physical scale layout expected by the existing 1D dequant kernel. +// +// replicated[row][block_x] = compact_2d[row / 16][block_x] +void generate_2d_scales_with_replication(fp8e4m3* scale_buffer, + const size_t rows, + const size_t cols, + const size_t mathematical_scale_rows, + const size_t mathematical_scale_blocks_per_row, + const size_t physical_row_stride, + std::mt19937& gen, + std::uniform_int_distribution& dis) { + const size_t total_elems = mathematical_scale_rows * physical_row_stride; + std::memset(scale_buffer, 0, total_elems * sizeof(fp8e4m3)); + + const size_t blocks_y = divide_round_up(rows, kFP4BlockSize2DY); + const size_t blocks_x = divide_round_up(cols, kFP4BlockSize2DX); + + std::vector compact_2d(blocks_y * blocks_x); + + for (size_t by = 0; by < blocks_y; ++by) { + for (size_t bx = 0; bx < blocks_x; ++bx) { + while (true) { + const uint8_t bits = static_cast(dis(gen)); + + fp8e4m3 candidate; + std::memcpy(&candidate, &bits, sizeof(bits)); + + const float decoded = static_cast(candidate); + if (std::isfinite(decoded)) { + compact_2d[by * blocks_x + bx] = candidate; + break; + } + } + } + } + + for (size_t row = 0; row < mathematical_scale_rows; ++row) { + const size_t by = row / kFP4BlockSize2DY; + for (size_t bx = 0; bx < mathematical_scale_blocks_per_row; ++bx) { + const size_t dst_idx = row * physical_row_stride + bx; + scale_buffer[dst_idx] = compact_2d[by * blocks_x + bx]; + } + } +} + +// Write one mathematical FP4 E2M1 value, represented as a raw nibble [0, 15], +// into packed storage. Two mathematical FP4 values are packed per byte: +// even mathematical index -> low nibble, odd mathematical index -> high nibble. 
+void set_fp4_nibble(fp4e2m1* data, const size_t mathematical_idx, const uint8_t nibble) { + ASSERT_TRUE(nibble < 16); + auto* raw = reinterpret_cast(data); + const size_t byte_idx = mathematical_idx / 2; + const uint8_t val = nibble; + + if ((mathematical_idx % 2) == 0) { + // set low nibble + raw[byte_idx] = static_cast((raw[byte_idx] & 0xF0) | val); + } else { + // set high nibble + raw[byte_idx] = static_cast((raw[byte_idx] & 0x0F) | (val << 4)); + } +} + +// Populate FP4 (E2M1) tensor using packed 4-bit encoding, and simultaneously +// populate its mathematical transpose in packed storage. +// +// data has mathematical shape [rows, cols] +// data_t has mathematical shape [cols, rows] +void generate_data_and_transpose(fp4e2m1* data, + fp4e2m1* data_t, + const size_t rows, + const size_t cols, + std::mt19937& gen, + std::uniform_int_distribution& dis) { + const size_t packed_bytes = (rows * cols * BitsNumber::num_bits) / 8; + + std::memset(data, 0, packed_bytes); + std::memset(data_t, 0, packed_bytes); + + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const uint8_t nibble = static_cast(dis(gen)) & 0xF; + + const size_t idx = i * cols + j; + set_fp4_nibble(data, idx, nibble); + + const size_t idx_t = j * rows + i; + set_fp4_nibble(data_t, idx_t, nibble); + } + } +} + +// Decode a single FP4 (E2M1) value from packed storage. +float get_fp4_value(const fp4e2m1* data, const size_t mathematical_idx) { + const auto* raw = reinterpret_cast(data); + const size_t byte_idx = mathematical_idx / 2; + const uint8_t packed = raw[byte_idx]; + const uint8_t nibble = (mathematical_idx % 2 == 0) ? (packed & 0xF) : ((packed >> 4) & 0xF); + return E2M1_LUT[nibble]; +} + +// Reference implementation: dequantize packed FP4 (E2M1) input using per-block FP8_E4M3 scales. +// Each block of 1x16 elements shares one scale; values are decoded to float and scaled. 
+template +void compute_ref(const fp4e2m1* input, + OutputType* output, + const fp8e4m3* scales, + const float amax, + const size_t rows, + const size_t cols, + const size_t scale_stride) { + constexpr float factor_inv = 1.0f / (6.0f * 448.0f); + + const size_t blocks_per_row = cols / kFP4BlockSize1D; + + for (size_t i = 0; i < rows; ++i) { + for (size_t b = 0; b < blocks_per_row; ++b) { + const float scale = + static_cast(scales[i * scale_stride + b]) * amax * factor_inv; + + for (size_t k = 0; k < kFP4BlockSize1D; ++k) { + const size_t col = b * kFP4BlockSize1D + k; + const size_t idx = i * cols + col; + const float x = get_fp4_value(input, idx); + output[idx] = static_cast(x * scale); + } + } + } +} + +template +void run_single_case(const std::string& case_name, + const fp4e2m1* host_input, + const fp8e4m3* host_scales, + const size_t rows, + const size_t cols, + const size_t blocks_y, + const size_t blocks_x, + const size_t scale_stride, + const float amax, + DType otype) { + const DType itype = DType::kFloat4E2M1; + + Tensor input(case_name + "_input", std::vector{rows, cols}, itype, + true, false, NVTE_NVFP4_1D_SCALING); + Tensor output(case_name + "_output", std::vector{rows, cols}, otype, true, false); + + std::unique_ptr ref_output = + std::make_unique(rows * cols); + + const size_t data_bytes = (rows * cols * BitsNumber::num_bits) / 8; + const size_t scale_bytes = blocks_y * blocks_x * sizeof(fp8e4m3); + + auto err = cudaMemcpy(input.rowwise_dptr(), + host_input, + data_bytes, + cudaMemcpyHostToDevice); + ASSERT_EQ(err, cudaSuccess) << case_name << ": " << cudaGetErrorString(err); + + err = cudaMemcpy(input.rowwise_scale_inv_dptr(), + host_scales, + scale_bytes, + cudaMemcpyHostToDevice); + ASSERT_EQ(err, cudaSuccess) << case_name << ": " << cudaGetErrorString(err); + + input.set_tensor_amax(amax); + + nvte_dequantize(input.data(), output.data(), 0); + + cudaDeviceSynchronize(); + err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << case_name << ": " << cudaGetErrorString(err); + + output.to_cpu(); + + compute_ref(host_input, + ref_output.get(), + host_scales, + amax, + rows, + cols, + scale_stride); + + auto [atol, rtol] = getTolerances(otype); + compareResults(case_name, output, ref_output.get(), true, atol, rtol); +} + +// End-to-end test: generate random FP4 input and FP8 scales, then exercise +// 1) row-wise 1D dequant +// 2) col-wise 1D dequant (by running the same dequant kernel on transposed data) +// 3) 2D dequant semantics using row-wise replicated scales +template +void performTest(const size_t rows, const size_t cols, DType otype) { + const std::array scale_dims = get_scale_tensor_dims(rows, cols, 1, 16); + const std::array scale_dims_t = get_scale_tensor_dims(cols, rows, 1, 16); + + const size_t unpadded_blocks_Y = scale_dims[0]; + const size_t unpadded_blocks_X = scale_dims[1]; + const size_t blocks_Y = scale_dims[2]; + const size_t blocks_X = scale_dims[3]; + const size_t scales_stride = blocks_X; + + const size_t unpadded_blocks_Y_t = scale_dims_t[0]; + const size_t unpadded_blocks_X_t = scale_dims_t[1]; + const size_t blocks_Y_t = scale_dims_t[2]; + const size_t blocks_X_t = scale_dims_t[3]; + const size_t scales_stride_t = blocks_X_t; + + std::unique_ptr host_input = + std::make_unique(rows * cols); + + std::unique_ptr host_input_t = + std::make_unique(rows * cols); + + std::unique_ptr host_scales_rowwise_1d = + std::make_unique(blocks_Y * blocks_X); + + std::unique_ptr host_scales_colwise_1d = + std::make_unique(blocks_Y_t * blocks_X_t); + + 
std::unique_ptr host_scales_2d_replicated = + std::make_unique(blocks_Y * blocks_X); + + static std::mt19937 gen(42); + std::uniform_int_distribution fp4_dis(0, 15); + std::uniform_int_distribution fp8_dis(0, 255); + + generate_data_and_transpose(host_input.get(), + host_input_t.get(), + rows, + cols, + gen, + fp4_dis); + + // Row-wise 1D scales on [rows, cols] + generate_1d_scales(host_scales_rowwise_1d.get(), + unpadded_blocks_Y, + unpadded_blocks_X, + scales_stride, + gen, + fp8_dis); + + // Col-wise 1D scales on [cols, rows] + generate_1d_scales(host_scales_colwise_1d.get(), + unpadded_blocks_Y_t, + unpadded_blocks_X_t, + scales_stride_t, + gen, + fp8_dis); + + // 2D scales replicated row-wise + generate_2d_scales_with_replication(host_scales_2d_replicated.get(), + rows, + cols, + unpadded_blocks_Y, + unpadded_blocks_X, + scales_stride, + gen, + fp8_dis); + + const float amax = 1.0f; + + run_single_case("rowwise_1d_dequant", + host_input.get(), + host_scales_rowwise_1d.get(), + rows, + cols, + blocks_Y, + blocks_X, + scales_stride, + amax, + otype); + + run_single_case("colwise_1d_dequant", + host_input_t.get(), + host_scales_colwise_1d.get(), + cols, + rows, + blocks_Y_t, + blocks_X_t, + scales_stride_t, + amax, + otype); + + run_single_case("replicated_2d_dequant", + host_input.get(), + host_scales_2d_replicated.get(), + rows, + cols, + blocks_Y, + blocks_X, + scales_stride, + amax, + otype); +} + +std::vector> tensor_dims = { + {32, 32}, + {32, 64}, + {64, 32}, + {64, 96}, + {128, 128}, + {256, 256}, + {512, 512}, + {1024, 1024}, + {2048, 2048}, +}; + +} // namespace + +class DequantizeNVFP4TestSuite + : public ::testing::TestWithParam< + std::tuple, transformer_engine::DType>> {}; + +TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) { + const auto tensor_size = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + + const size_t rows = tensor_size.first; + const size_t cols = tensor_size.second; + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY( + output_type, OutputType, + performTest(rows, cols, output_type);); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + DequantizeNVFP4TestSuite, + ::testing::Combine( + ::testing::ValuesIn(tensor_dims), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)), + [](const testing::TestParamInfo& info) { + std::string name = + std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "X" + + test::typeName(std::get<1>(info.param)); + return name; + }); diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index e2bfdfd57..9dfc2e1b5 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -301,6 +301,21 @@ class Tensor { tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); } + void set_tensor_amax(float amax) { + if (!amax_cpu_data_) { + amax_cpu_data_ = std::make_shared(amax); + } else { + *amax_cpu_data_ = amax; + } + + float *amax_gpu = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&amax_gpu, sizeof(float))); + NVTE_CHECK_CUDA(cudaMemcpy(amax_gpu, amax_cpu_data_.get(), + sizeof(float), cudaMemcpyHostToDevice)); + + tensor_.set_amax(amax_gpu, DType::kFloat32, tensor_.defaultShape); + } + void to_cpu() const; void from_cpu() const; void set_scale(float scale); @@ -356,6 +371,11 @@ constexpr size_t scale_tensor_alignment_Y_colwise = 4; constexpr size_t scale_tensor_alignment_X_colwise = 128; #endif +static constexpr float E2M1_LUT[16] = { + 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, + -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, 
-3.0f, -4.0f, -6.0f, +}; + inline size_t divide_round_up(const size_t N, const size_t M) { return (N - 1 + M) / M; } @@ -519,7 +539,7 @@ template void compare_scaling_factors(const std::string &name, const T *test, const T *ref, const size_t row_blocks, const size_t col_blocks, const size_t stride, #ifdef USE_ROCM - std::vector& mismatch_indices, + std::vector& mismatch_indices, #endif //#ifdef USE_ROCM size_t& mismatches_num, const size_t scale_diff_abs_tolerance = 0, diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index 1bdd7e218..6b70c7582 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -18,9 +18,7 @@ #include "../../common.h" #include "../fp8/dequantize_fp8.cuh" #include "../mxfp8/dequantize_mxfp8.cuh" -#ifndef __HIP_PLATFORM_AMD__ #include "../nvfp4/dequantize_nvfp4.cuh" -#endif //#ifndef __HIP_PLATFORM_AMD__ namespace transformer_engine { namespace dispatch { @@ -49,12 +47,10 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t #endif //#ifndef __HIP_PLATFORM_AMD__ break; } -#ifndef __HIP_PLATFORM_AMD__ case NVTE_NVFP4_1D_SCALING: { nvfp4::dequantize(input, output, stream); break; } -#endif //#ifndef __HIP_PLATFORM_AMD__ default: NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); } diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 59742d1e7..bdcfd8d0b 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -157,12 +157,25 @@ static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kT static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); // for 2D block scaling, we need to reduce amax in warp +#ifdef __HIP_PLATFORM_AMD__ +static __device__ constexpr uint64_t WARP_REDUCE_AMAX_GROUP_MASKS[8] = { + 0x0101010101010101ULL, 0x0202020202020202ULL, + 0x0404040404040404ULL, 0x0808080808080808ULL, + 0x1010101010101010ULL, 0x2020202020202020ULL, + 0x4040404040404040ULL, 0x8080808080808080ULL}; +#else static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = { 0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080}; +#endif // max for every group_size elements in warp template -__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) { +__device__ __forceinline__ float groupMax(float val, +#ifdef __HIP_PLATFORM_AMD__ + uint64_t groupMask) { +#else + unsigned int groupMask) { +#endif for (int offset = group_size / 2; offset > 0; offset /= 2) { #ifdef __HIP_PLATFORM_AMD__ (void)groupMask; // unused on AMD, __shfl_down does not take a mask From 6e3eea50e5be7f4a0c1b81899a84b6356cd98e22 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 2 Apr 2026 11:27:27 -0500 Subject: [PATCH 32/69] Enable NVFP4 recipe --- transformer_engine/common/CMakeLists.txt | 6 +++--- transformer_engine/pytorch/quantization.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 27c26f7a8..9a23efc25 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -214,7 
+214,8 @@ list(APPEND transformer_engine_cuda_sources fused_router/fused_topk_with_score_function.cu recipe/current_scaling.cu recipe/delayed_scaling.cu - recipe/fp8_block_scaling.cu) + recipe/fp8_block_scaling.cu + recipe/nvfp4.cu) list(APPEND transformer_engine_cuda_arch_specific_sources cast/cast.cu @@ -238,8 +239,7 @@ if(USE_CUDA) fused_attn/fused_attn_fp8.cu fused_attn/utils.cu swizzle/swizzle.cu - swizzle/swizzle_block_scaling.cu - recipe/nvfp4.cu) + swizzle/swizzle_block_scaling.cu) list(APPEND transformer_engine_cuda_arch_specific_sources gemm/cutlass_grouped_gemm.cu transpose/quantize_transpose_square_blockwise.cu) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 37766f5ce..8eded3622 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -87,7 +87,7 @@ def check_mxfp8_support() -> Tuple[bool, str]: @functools.lru_cache(maxsize=None) def check_nvfp4_support() -> Tuple[bool, str]: if IS_HIP_EXTENSION: - return False, "ROCm TE currently not supporting NVFP4" + return True, "" """Return if nvfp4 support is available""" if get_device_compute_capability() >= (10, 0): # blackwell and above return True, "" From 9c3dc2f4b75ed1c90ca725a3e3116f97a2b4a6a7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 2 Apr 2026 11:43:44 -0500 Subject: [PATCH 33/69] NVFP4 GEMM via BF16 dequant --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 11 +- transformer_engine/common/gemm/rocm_gemm.cu | 127 ++++++++++++++++++- 2 files changed, 136 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 4c5ff59fc..03a2080f0 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,6 +13,8 @@ from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils +from torch.utils.cpp_extension import IS_HIP_EXTENSION + recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) @@ -143,7 +147,12 @@ def check_nvfp4_gemm_versus_reference( # Native TE GEMM using tex.generic_gemm (cuBLAS GEMM) # Allocate cuBLAS workspace - workspace = torch.empty(4, dtype=torch.uint8, device=device) + if IS_HIP_EXTENSION: + # On ROCm, FP4 is dequantized to BF16 in workspace before GEMM, so allocate enough space. 
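+        # Sizing breakdown (illustrative): M*K + K*N dequantized BF16 elements at
+        # 2 bytes each, plus 32 MiB of headroom for the regular hipBLASLt workspace.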
+ ws_bytes = M * K * 2 + K * N * 2 + 32 * 1024 * 1024 + workspace = torch.empty(ws_bytes, dtype=torch.uint8, device=device) + else: + workspace = torch.empty(4, dtype=torch.uint8, device=device) transa = True if not w_columnwise else False transb = False if not x_columnwise else True diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 3bc8d9bc8..3fcf12bdd 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -197,6 +197,54 @@ struct GemmParam { int ldb = 0; // B column strides }; +// FP4 e2m1 lookup table +__device__ constexpr float kFP4E2M1Table[16] = { + 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, + -0.0f,-0.5f,-1.0f,-1.5f,-2.0f,-3.0f,-4.0f,-6.0f +}; + +// Dequantize FP4 (e2m1) packed data with FP8 e4m3 block scales to BF16. +// NOTE: The per-tensor amax factor is NOT applied here — it is folded into the GEMM alpha instead, +// matching how cuBLASLt handles NVFP4 on CUDA. This avoids BF16 precision loss from the extra scaling. +__global__ void dequant_fp4_to_bf16_kernel( + const uint8_t* __restrict__ data, + const fp8e4m3* __restrict__ scale_inv, + hip_bfloat16* __restrict__ output, + int64_t total_elements) +{ + // Process 2 elements (1 byte) per iteration for coalesced access + const int64_t pair_idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const int64_t total_pairs = total_elements / 2; + if (pair_idx >= total_pairs) return; + + const uint8_t byte = data[pair_idx]; + const uint8_t lo_nibble = byte & 0xF; + const uint8_t hi_nibble = byte >> 4; + + const int64_t elem_base = pair_idx * 2; + const float s0 = static_cast(scale_inv[elem_base / 16]); + const float s1 = static_cast(scale_inv[(elem_base + 1) / 16]); + + output[elem_base] = static_cast(kFP4E2M1Table[lo_nibble] * s0); + output[elem_base + 1] = static_cast(kFP4E2M1Table[hi_nibble] * s1); +} + +// Launch helper for dequant kernel +static void launch_dequant_fp4_to_bf16( + const void* data, const void* scale_inv, + void* output, int64_t total_elements, hipStream_t stream) +{ + constexpr int kBlockSize = 256; + const int64_t total_pairs = total_elements / 2; + const int64_t num_blocks = (total_pairs + kBlockSize - 1) / kBlockSize; + + dequant_fp4_to_bf16_kernel<<>>( + reinterpret_cast(data), + reinterpret_cast(scale_inv), + reinterpret_cast(output), + total_elements); +} + GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, const int m, const int n, const int k) { @@ -245,6 +293,13 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = is_A_transposed ? k : m; + } else if (is_nvfp_scaling(A.scaling_mode)) { + // NVFP4 + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; // NVFP4 gemm is always TN layout + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = k; } else { NVTE_ERROR("A has unsupported scaling mode"); } @@ -283,6 +338,13 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; ret.B_scale_inv = is_B_transposed ? 
B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; + } else if (is_nvfp_scaling(B.scaling_mode)) { + // NVFP4 + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = CUBLAS_OP_N; // NVFP4 gemm is always TN layout + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; } else { NVTE_ERROR("B has unsupported scaling mode"); } @@ -951,7 +1013,70 @@ void hipblaslt_gemm(const Tensor *inputA, } NVTE_CHECK(k > 0); - const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k); + GemmParam param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, m, n, k); + + // FP4 dequant path: hipBLASLt does not support FP4 natively, + // so we dequantize FP4 -> BF16 (block scales only) and run a standard BF16 GEMM. + // The per-tensor amax factor is folded into the GEMM alpha, matching the cuBLASLt approach. + const bool use_fp4 = is_fp4_dtype(param.Atype) || is_fp4_dtype(param.Btype); + if (use_fp4) { + const float* amax_A = (transa == CUBLAS_OP_T) + ? reinterpret_cast(inputA->amax.dptr) + : reinterpret_cast(inputA->columnwise_amax.dptr); + const float* amax_B = (transb == CUBLAS_OP_N) + ? reinterpret_cast(inputB->amax.dptr) + : reinterpret_cast(inputB->columnwise_amax.dptr); + + // Fold per-tensor amax into alpha: alpha *= amax_A * amax_B / (fp4_max * fp8_max)^2 + // This matches the CUDA path (nvfp4.cu: compute_nvfp4_per_tensor_scale_kernel). + const float fp4_max = 6.0f; + const float fp8_max = te_fp8_fnuz() ? 240.0f : 448.0f; + const float factor_inv = 1.0f / (fp4_max * fp4_max * fp8_max * fp8_max); + float h_amax_A = 1.0f, h_amax_B = 1.0f; + if (amax_A != nullptr) { + NVTE_CHECK_CUDA(hipMemcpyAsync(&h_amax_A, amax_A, sizeof(float), hipMemcpyDeviceToHost, stream)); + } + if (amax_B != nullptr) { + NVTE_CHECK_CUDA(hipMemcpyAsync(&h_amax_B, amax_B, sizeof(float), hipMemcpyDeviceToHost, stream)); + } + NVTE_CHECK_CUDA(hipStreamSynchronize(stream)); + alpha *= h_amax_A * h_amax_B * factor_inv; + + // Compute workspace needed for dequantized BF16 buffers + const size_t a_bf16_bytes = is_fp4_dtype(param.Atype) ? static_cast(m) * k * sizeof(hip_bfloat16) : 0; + const size_t b_bf16_bytes = is_fp4_dtype(param.Btype) ? 
static_cast(k) * n * sizeof(hip_bfloat16) : 0; + const size_t dequant_ws = (a_bf16_bytes + b_bf16_bytes + 255) & ~size_t(255); + NVTE_CHECK(workspaceSize >= dequant_ws, + "NVFP4 GEMM requires at least ", dequant_ws, " bytes workspace for FP4->BF16 dequant, " + "but only ", workspaceSize, " bytes available."); + + uint8_t* ws_ptr = reinterpret_cast(workspace); + size_t ws_offset = 0; + + if (is_fp4_dtype(param.Atype)) { + hip_bfloat16* a_bf16 = reinterpret_cast(ws_ptr + ws_offset); + const int64_t total_a = static_cast(m) * k; + launch_dequant_fp4_to_bf16(param.A, param.A_scale_inv, a_bf16, total_a, stream); + param.A = a_bf16; + param.Atype = DType::kBFloat16; + param.A_scale_inv = nullptr; + ws_offset += a_bf16_bytes; + } + + if (is_fp4_dtype(param.Btype)) { + hip_bfloat16* b_bf16 = reinterpret_cast(ws_ptr + ws_offset); + const int64_t total_b = static_cast(k) * n; + launch_dequant_fp4_to_bf16(param.B, param.B_scale_inv, b_bf16, total_b, stream); + param.B = b_bf16; + param.Btype = DType::kBFloat16; + param.B_scale_inv = nullptr; + ws_offset += b_bf16_bytes; + } + + // Advance workspace past dequant buffers + workspace = ws_ptr + ((ws_offset + 255) & ~size_t(255)); + workspaceSize -= dequant_ws; + } bool nvte_log_gemm_config = false; if (const char* env_p = std::getenv("NVTE_LOG_GEMM_CONFIG") ) { From 3a63f32672aa9f7500bd76a29db1b11c4cf16536 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 2 Apr 2026 13:40:59 -0500 Subject: [PATCH 34/69] add explanation to wht16 Co-authored-by: Ye Wang --- .../hadamard_transform/hadamard_transform.cu | 44 ++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 5df865751..57704cc1e 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -721,7 +721,49 @@ __device__ __forceinline__ uint64_t pack_bf16x4(float v0, float v1, float v2, fl | ((uint64_t)bf16_to_bits(to_bf16(v3)) << 48); } -// 16-point WHT: in-register, no shared memory. +// ----------------------------------------------------------------------- +// 16-point WHT via the Kronecker trick (no shared memory) +// ----------------------------------------------------------------------- +// +// 1. The vec operator +// vec() flattens a matrix into a column vector by stacking its +// columns one on top of the other: +// +// X = |a c| vec(X) = |a| +// |b d| |b| +// |c| +// |d| +// +// 2. The "Kronecker trick" for 1D -> 2D +// The fundamental identity that connects these concepts is: +// +// vec(B . X . A^T) = (A (x) B) . vec(X) +// +// For a 16-point Hadamard transform (H16 = H4 (x) H4), +// set A = H4 and B = H4. The formula becomes: +// +// H16 . x = vec(H4 . X . H4^T) +// +// 3. Data layout (column-major, one column per thread) +// Reshape the 16-element 1D vector x into a 4x4 matrix X +// by filling columns first: +// +// X = | x0 x4 x8 x12 | thread 0 holds col 0: v0..v3 = x0 ..x3 +// | x1 x5 x9 x13 | thread 1 holds col 1: v0..v3 = x4 ..x7 +// | x2 x6 x10 x14 | thread 2 holds col 2: v0..v3 = x8 ..x11 +// | x3 x7 x11 x15 | thread 3 holds col 3: v0..v3 = x12..x15 +// +// 4. Three-stage computation +// Stage 1 (local H4) : left-multiply H4 . X (within each thread) +// Stage 2 (xor-1 swap) : \ (across 4 threads) +// Stage 3 (xor-2 swap) : / right-multiply . H4^T together these two butterfly stages = H4^T +// +// Result: vec(H4 . X . 
H4^T) = H16 . x +// +// 5. Randomised Hadamard Transform (RHT) +// A diagonal sign matrix D (from sign_mask) is applied either +// before the WHT (apply_pre=true, forward) or after (inverse). +// // Adapted from cast_transpose_mxfp4_kernel_shuffled.cu::hadamard16_inplace, // extended with NV random_sign_mask (uint16_t bitmask). // thread_in_group [0,3]: drives ds_swizzle polarity (identical to MLPerf tid & 3). From 35ef81c3db91718d6b11722234a59479736cb33e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 2 Apr 2026 14:01:56 -0500 Subject: [PATCH 35/69] comment and test --- ci/pytorch.sh | 2 + .../common/hadamard_transform/wht16.cuh | 44 ++++++++++++++++++- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/ci/pytorch.sh b/ci/pytorch.sh index a7259ebf6..74a120505 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -84,6 +84,8 @@ run_test_config(){ NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "amax+triton" 3 test_numerics.py NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "amax+triton" 3 test_fusible_ops.py NVTE_USE_ATOMIC_AMAX=1 run_default_fa_lbl "amax" 3 triton_kernels/test_cast.py + + run_default_fa 1 nvfp4/ } run_test_config_mgpu(){ diff --git a/transformer_engine/common/hadamard_transform/wht16.cuh b/transformer_engine/common/hadamard_transform/wht16.cuh index b9a1a51b7..751ecba68 100644 --- a/transformer_engine/common/hadamard_transform/wht16.cuh +++ b/transformer_engine/common/hadamard_transform/wht16.cuh @@ -36,7 +36,49 @@ __device__ __forceinline__ float ds_swizzle_xor2(float v) { return r; } -// 16-point WHT: in-register, no shared memory. +// ----------------------------------------------------------------------- +// 16-point WHT via the Kronecker trick (no shared memory) +// ----------------------------------------------------------------------- +// +// 1. The vec operator +// vec() flattens a matrix into a column vector by stacking its +// columns one on top of the other: +// +// X = |a c| vec(X) = |a| +// |b d| |b| +// |c| +// |d| +// +// 2. The "Kronecker trick" for 1D -> 2D +// The fundamental identity that connects these concepts is: +// +// vec(B . X . A^T) = (A (x) B) . vec(X) +// +// For a 16-point Hadamard transform (H16 = H4 (x) H4), +// set A = H4 and B = H4. The formula becomes: +// +// H16 . x = vec(H4 . X . H4^T) +// +// 3. Data layout (column-major, one column per thread) +// Reshape the 16-element 1D vector x into a 4x4 matrix X +// by filling columns first: +// +// X = | x0 x4 x8 x12 | thread 0 holds col 0: v0..v3 = x0 ..x3 +// | x1 x5 x9 x13 | thread 1 holds col 1: v0..v3 = x4 ..x7 +// | x2 x6 x10 x14 | thread 2 holds col 2: v0..v3 = x8 ..x11 +// | x3 x7 x11 x15 | thread 3 holds col 3: v0..v3 = x12..x15 +// +// 4. Three-stage computation +// Stage 1 (local H4) : left-multiply H4 . X (within each thread) +// Stage 2 (xor-1 swap) : \ (across 4 threads) +// Stage 3 (xor-2 swap) : / right-multiply . H4^T together these two butterfly stages = H4^T +// +// Result: vec(H4 . X . H4^T) = H16 . x +// +// 5. Randomised Hadamard Transform (RHT) +// A diagonal sign matrix D (from sign_mask) is applied either +// before the WHT (apply_pre=true, forward) or after (inverse). +// // Adapted from cast_transpose_mxfp4_kernel_shuffled.cu::hadamard16_inplace, // extended with NV random_sign_mask (uint16_t bitmask). // thread_in_group [0,3]: drives ds_swizzle polarity (identical to MLPerf tid & 3). 
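The comment block above describes the Kronecker identity informally; the following standalone host-side sketch (not part of any patch in this series, all names local to the sketch) checks the identity numerically: applying H16 = H4 (x) H4 to a 16-vector gives the same result as reshaping the vector column-major into a 4x4 matrix X and computing vec(H4 . X . H4^T). Normalization factors are omitted since they appear identically on both sides.

// kronecker_wht16_check.cpp -- verify vec(H4 . X . H4^T) == (H4 (x) H4) . vec(X)
#include <array>
#include <cassert>
#include <cstdio>

int main() {
  // 4x4 Hadamard matrix with +1/-1 entries (H2 (x) H2).
  const int H4[4][4] = {{1, 1, 1, 1}, {1, -1, 1, -1}, {1, 1, -1, -1}, {1, -1, -1, 1}};
  std::array<int, 16> x{};
  for (int i = 0; i < 16; ++i) x[i] = i + 1;

  // Reference: y = H16 . x with H16[a][b] = H4[a/4][b/4] * H4[a%4][b%4].
  std::array<int, 16> y_ref{};
  for (int a = 0; a < 16; ++a)
    for (int b = 0; b < 16; ++b)
      y_ref[a] += H4[a / 4][b / 4] * H4[a % 4][b % 4] * x[b];

  // Kronecker trick: reshape x column-major into X, compute H4 . X . H4^T.
  int X[4][4], T[4][4], Y[4][4];
  for (int c = 0; c < 4; ++c)
    for (int r = 0; r < 4; ++r) X[r][c] = x[c * 4 + r];
  for (int r = 0; r < 4; ++r)     // T = H4 . X
    for (int c = 0; c < 4; ++c) {
      T[r][c] = 0;
      for (int k = 0; k < 4; ++k) T[r][c] += H4[r][k] * X[k][c];
    }
  for (int r = 0; r < 4; ++r)     // Y = T . H4^T
    for (int c = 0; c < 4; ++c) {
      Y[r][c] = 0;
      for (int k = 0; k < 4; ++k) Y[r][c] += T[r][k] * H4[c][k];
    }

  // vec(Y) taken column-major must equal the reference.
  for (int c = 0; c < 4; ++c)
    for (int r = 0; r < 4; ++r) assert(Y[r][c] == y_ref[c * 4 + r]);
  std::puts("vec(H4 . X . H4^T) matches H16 . x");
  return 0;
}

In the kernel, thread t of a 4-thread group holds column t of X in registers, so stage 1 (left-multiply by H4) is purely thread-local, and stages 2 and 3 (right-multiply by H4^T) only exchange data across the four threads via the xor swizzles.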
From e8ff6bd1ca2e46aecca54aadb709ba9c28285672 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 3 Apr 2026 15:22:18 -0500 Subject: [PATCH 36/69] enable use_fused_bulk_alloc --- transformer_engine/pytorch/csrc/extensions/cast.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 16e800c48..79dec9af8 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -493,7 +493,6 @@ std::tuple, std::vector> bulk_allocate_mx return retval; } -#ifndef USE_ROCM // allocate fp4 data, fp8 scalings, and amax values // layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] // amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate @@ -694,7 +693,6 @@ std::tuple, std::vector> bulk_allocate_nv return retval; } -#endif // #ifndef USE_ROCM } // namespace @@ -793,7 +791,6 @@ std::vector split_quantize(const at::Tensor &tensor, } std::tie(output_py_list, output_cpp_list) = bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); -#ifndef USE_ROCM } else if (is_nvfp4) { // NVFP4: construct output tensors with bulk allocations std::vector nvfp4_quantizers; @@ -802,7 +799,6 @@ std::vector split_quantize(const at::Tensor &tensor, } std::tie(output_py_list, output_cpp_list) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); -#endif } else { NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer"); } From e26ffc83e879d3e4d91c613e93cbd6a4270c8b43 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 3 Apr 2026 17:43:57 -0500 Subject: [PATCH 37/69] compute random sign mask on device --- .../hadamard_transform_cast_fusion.cu | 45 +++++++++---------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index e6374a1e7..6a766f864 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -821,9 +821,24 @@ __device__ __forceinline__ uint16_t fp4x4_to_bits(fp4e2m1x4 v) { template __global__ __launch_bounds__(kThreadsPerBlock, 4) void HadamardTransformCastFusionKernel( const __hip_bfloat16* __restrict__ input, uint8_t* __restrict__ output_t, - fp8e4m3* __restrict__ scale_inv_t, const float global_amax, - const uint16_t random_sign_mask_t, const uint64_t num_rows, const uint64_t row_length, - const size_t scale_stride, const size_t* rng_state) { + fp8e4m3* __restrict__ scale_inv_t, const float* __restrict__ global_amax_ptr, + const __hip_bfloat16* __restrict__ hadamard_matrix, const uint64_t num_rows, + const uint64_t row_length, const size_t scale_stride, const size_t* rng_state) { + + // Thread 0 loads global_amax and computes random sign mask + __shared__ uint16_t s_random_sign_mask; + __shared__ float s_global_amax; + if (threadIdx.x == 0) { + s_global_amax = *global_amax_ptr; + uint16_t mask = 0; + for (int row = 0; row < kHadamardDim; ++row) { + mask |= static_cast((to_f32(hadamard_matrix[row * kHadamardDim]) < 0.0f ? 
1u : 0u) << row); + } + s_random_sign_mask = mask; + } + __syncthreads(); + const float global_amax = s_global_amax; + const int tid = threadIdx.x; const int warp_id = tid / kWarpSize; const int lane_id = tid % kWarpSize; @@ -846,7 +861,7 @@ __global__ __launch_bounds__(kThreadsPerBlock, 4) void HadamardTransformCastFusi float c2 = to_f32(input[(input_row_base + 2) * row_length + input_col]); float c3 = to_f32(input[(input_row_base + 3) * row_length + input_col]); - wht16(c0, c1, c2, c3, thread_in_grp, random_sign_mask_t, /*apply_pre=*/true); + wht16(c0, c1, c2, c3, thread_in_grp, s_random_sign_mask, /*apply_pre=*/true); // Truncate to BF16 precision to match the reference BF16 matmul path. // Without this, FP32 WHT results at FP4 quantization boundaries round @@ -894,22 +909,6 @@ __global__ __launch_bounds__(kThreadsPerBlock, 4) void HadamardTransformCastFusi *reinterpret_cast(&output_t[output_byte_offset]) = packed; } -uint16_t random_sign_mask_from_rht_matrix(const SimpleTensor& hadamard_matrix, cudaStream_t stream) { - std::array host_matrix{}; - - NVTE_CHECK_CUDA(cudaMemcpyAsync(host_matrix.data(), hadamard_matrix.dptr, - host_matrix.size() * sizeof(uint16_t), - cudaMemcpyDeviceToHost, stream)); - NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); - - uint16_t random_sign_mask = 0; - for (size_t row = 0; row < kHadamardDim; ++row) { - // The first column of diag(sign) @ H16 is sign[row] * 0.25. - random_sign_mask |= static_cast(((host_matrix[row * kHadamardDim] >> 15) & 1) << row); - } - return random_sign_mask; -} - } // namespace } // namespace transformer_engine @@ -1033,9 +1032,6 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out /*stream=*/stream, /*k_tile_size=*/k_tile_size);); #else - const uint16_t random_sign_mask_t = - random_sign_mask_from_rht_matrix(hadamard_matrix, stream); - const dim3 block(kThreadsPerBlock); const dim3 grid(DIVUP(n, static_cast(kHadamardDim)), DIVUP(m, static_cast(kRowsPerBlock))); @@ -1046,7 +1042,8 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out reinterpret_cast(input.dptr), reinterpret_cast(output_t.dptr), reinterpret_cast(scale_inv_t.dptr), - *reinterpret_cast(global_amax.dptr), random_sign_mask_t, + reinterpret_cast(global_amax.dptr), + reinterpret_cast(hadamard_matrix.dptr), static_cast(m), static_cast(n), scale_inv_t.shape[1], rng_state);); From c7cc488c24ee6cfb61aa229621684e783bb6d3fd Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 6 Apr 2026 11:22:48 -0500 Subject: [PATCH 38/69] CI: enable CI runs on every PR --- .github/workflows/rocm-ci-dispatch.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/rocm-ci-dispatch.yml b/.github/workflows/rocm-ci-dispatch.yml index 192f8bb0c..34ee8eb48 100644 --- a/.github/workflows/rocm-ci-dispatch.yml +++ b/.github/workflows/rocm-ci-dispatch.yml @@ -6,9 +6,6 @@ name: PR Automatic CI on: pull_request: - branches: - - 'dev' - - 'release_v2.*_rocm' types: [ opened, labeled, synchronize, reopened ] permissions: From 7c68bd860e5d806e46e408eea415816f7aef8202 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 6 Apr 2026 11:35:31 -0500 Subject: [PATCH 39/69] Avoid duplicate entry when opening PR We need a label to run CI. 
--- .github/workflows/rocm-ci-dispatch.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rocm-ci-dispatch.yml b/.github/workflows/rocm-ci-dispatch.yml index 34ee8eb48..f6524e77b 100644 --- a/.github/workflows/rocm-ci-dispatch.yml +++ b/.github/workflows/rocm-ci-dispatch.yml @@ -6,7 +6,7 @@ name: PR Automatic CI on: pull_request: - types: [ opened, labeled, synchronize, reopened ] + types: [ labeled, synchronize, reopened ] permissions: contents: read From a19dd6081103de0b97d461abd586731374d659a7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 6 Apr 2026 14:29:09 -0500 Subject: [PATCH 40/69] fix stream capture error in GEMM --- transformer_engine/common/gemm/rocm_gemm.cu | 88 ++++++++++++++------- 1 file changed, 60 insertions(+), 28 deletions(-) diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 3fcf12bdd..9811582d3 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -204,8 +204,8 @@ __device__ constexpr float kFP4E2M1Table[16] = { }; // Dequantize FP4 (e2m1) packed data with FP8 e4m3 block scales to BF16. -// NOTE: The per-tensor amax factor is NOT applied here — it is folded into the GEMM alpha instead, -// matching how cuBLASLt handles NVFP4 on CUDA. This avoids BF16 precision loss from the extra scaling. +// Only applies block scales: output = fp4_value * block_scale. +// The per-tensor amax correction is applied separately via the GEMM alpha scalar. __global__ void dequant_fp4_to_bf16_kernel( const uint8_t* __restrict__ data, const fp8e4m3* __restrict__ scale_inv, @@ -245,6 +245,19 @@ static void launch_dequant_fp4_to_bf16( total_elements); } +// Compute per-row alpha vector on device for NVFP4 GEMM: +// alpha_out[i] = alpha_in * amax_A * amax_B / (fp4_max^2 * fp8_max^2) for i in [0, m) +// Used with HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST, which +// expects a device vector of length m for alpha, while beta stays on the host. +__global__ void compute_fp4_alpha_vector_kernel(float alpha_in, const float* __restrict__ amax_A, + const float* __restrict__ amax_B, float factor_inv, + float* __restrict__ alpha_out, int m) { + const float alpha_val = alpha_in * (*amax_A) * (*amax_B) * factor_inv; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < m; i += blockDim.x * gridDim.x) { + alpha_out[i] = alpha_val; + } +} + GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, const transformer_engine::Tensor &B, const cublasOperation_t transB, const int m, const int n, const int k) { @@ -1017,9 +1030,19 @@ void hipblaslt_gemm(const Tensor *inputA, // FP4 dequant path: hipBLASLt does not support FP4 natively, // so we dequantize FP4 -> BF16 (block scales only) and run a standard BF16 GEMM. - // The per-tensor amax factor is folded into the GEMM alpha, matching the cuBLASLt approach. + // + // The per-tensor amax correction is computed on-device as a per-row alpha vector: + // alpha'[i] = alpha * amax_A * amax_B / (fp4_max^2 * fp8_max^2) + // Alpha is passed as a device vector of length m via + // HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST. Beta stays on host. const bool use_fp4 = is_fp4_dtype(param.Atype) || is_fp4_dtype(param.Btype); + const void* alpha_ptr = static_cast(&alpha); + const void* beta_ptr = static_cast(&beta); if (use_fp4) { + const float fp4_max = 6.0f; + const float fp8_max = te_fp8_fnuz() ? 
240.0f : 448.0f; + const float factor_inv = 1.0f / (fp4_max * fp4_max * fp8_max * fp8_max); + const float* amax_A = (transa == CUBLAS_OP_T) ? reinterpret_cast(inputA->amax.dptr) : reinterpret_cast(inputA->columnwise_amax.dptr); @@ -1027,22 +1050,24 @@ void hipblaslt_gemm(const Tensor *inputA, ? reinterpret_cast(inputB->amax.dptr) : reinterpret_cast(inputB->columnwise_amax.dptr); - // Fold per-tensor amax into alpha: alpha *= amax_A * amax_B / (fp4_max * fp8_max)^2 - // This matches the CUDA path (nvfp4.cu: compute_nvfp4_per_tensor_scale_kernel). - const float fp4_max = 6.0f; - const float fp8_max = te_fp8_fnuz() ? 240.0f : 448.0f; - const float factor_inv = 1.0f / (fp4_max * fp4_max * fp8_max * fp8_max); - float h_amax_A = 1.0f, h_amax_B = 1.0f; - if (amax_A != nullptr) { - NVTE_CHECK_CUDA(hipMemcpyAsync(&h_amax_A, amax_A, sizeof(float), hipMemcpyDeviceToHost, stream)); - } - if (amax_B != nullptr) { - NVTE_CHECK_CUDA(hipMemcpyAsync(&h_amax_B, amax_B, sizeof(float), hipMemcpyDeviceToHost, stream)); - } - NVTE_CHECK_CUDA(hipStreamSynchronize(stream)); - alpha *= h_amax_A * h_amax_B * factor_inv; - - // Compute workspace needed for dequantized BF16 buffers + // Reserve m floats from end of workspace for the device alpha vector. + const size_t alpha_vec_bytes = static_cast(m) * sizeof(float); + NVTE_CHECK(workspaceSize >= alpha_vec_bytes, + "NVFP4 GEMM requires at least ", alpha_vec_bytes, " bytes workspace for alpha vector."); + workspaceSize = (workspaceSize / sizeof(float)) * sizeof(float) - alpha_vec_bytes; + float* device_alpha_vec = reinterpret_cast( + reinterpret_cast(workspace) + workspaceSize); + + NVTE_CHECK(amax_A != nullptr, "FP4 GEMM requires amax_A"); + NVTE_CHECK(amax_B != nullptr, "FP4 GEMM requires amax_B"); + constexpr int kBlockSize = 256; + const int num_blocks = (m + kBlockSize - 1) / kBlockSize; + compute_fp4_alpha_vector_kernel<<>>( + alpha, amax_A, amax_B, factor_inv, device_alpha_vec, m); + alpha_ptr = static_cast(device_alpha_vec); + // beta_ptr stays as host &beta + + // Dequantize FP4 -> BF16 (block scales only, no amax folded in) const size_t a_bf16_bytes = is_fp4_dtype(param.Atype) ? static_cast(m) * k * sizeof(hip_bfloat16) : 0; const size_t b_bf16_bytes = is_fp4_dtype(param.Btype) ? static_cast(k) * n * sizeof(hip_bfloat16) : 0; const size_t dequant_ws = (a_bf16_bytes + b_bf16_bytes + 255) & ~size_t(255); @@ -1073,7 +1098,7 @@ void hipblaslt_gemm(const Tensor *inputA, ws_offset += b_bf16_bytes; } - // Advance workspace past dequant buffers + // Advance workspace past dequant buffers (device alpha vector is safe at the end) workspace = ws_ptr + ((ws_offset + 255) & ~size_t(255)); workspaceSize -= dequant_ws; } @@ -1297,6 +1322,13 @@ void hipblaslt_gemm(const Tensor *inputA, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + if (use_fp4) { + int32_t pointer_mode = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; + NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute( + operationDesc, HIPBLASLT_MATMUL_DESC_POINTER_MODE, + &pointer_mode, sizeof(pointer_mode))); + } + GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type, use_fp8 ? bias_type : (hipDataType)-1, (use_fp8 && gelu) ? 
aux_type : (hipDataType)-1, @@ -1318,10 +1350,10 @@ void hipblaslt_gemm(const Tensor *inputA, if (HIPBLAS_STATUS_SUCCESS == hipblaslt_ext::matmulIsAlgoSupported( handle, operationDesc, - static_cast(&alpha), + alpha_ptr, Adesc, Bdesc, - static_cast(&beta), + beta_ptr, Ddesc, Ddesc, algo_arr[0].algo, @@ -1398,12 +1430,12 @@ void hipblaslt_gemm(const Tensor *inputA, // Warm-up call NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&alpha), /* alpha */ + alpha_ptr, /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ Bdesc, - static_cast(&beta), /* beta */ + beta_ptr, /* beta */ C, /* C */ Cdesc, D, /* D */ @@ -1420,12 +1452,12 @@ void hipblaslt_gemm(const Tensor *inputA, { NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&alpha), /* alpha */ + alpha_ptr, /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ Bdesc, - static_cast(&beta), /* beta */ + beta_ptr, /* beta */ C, /* C */ Cdesc, D, /* D */ @@ -1481,12 +1513,12 @@ void hipblaslt_gemm(const Tensor *inputA, // D = alpha * (A * B) + beta * C NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&alpha), /* alpha */ + alpha_ptr, /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ Bdesc, - static_cast(&beta), /* beta */ + beta_ptr, /* beta */ C, /* C */ Cdesc, D, /* D */ From e32a75828353b206a466373386fe18f287793439 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 6 Apr 2026 15:01:33 -0500 Subject: [PATCH 41/69] merge errors --- transformer_engine/common/CMakeLists.txt | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 3f3071eb2..a362045de 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -220,13 +220,10 @@ list(APPEND transformer_engine_cuda_arch_specific_sources activation/gelu.cu activation/relu.cu activation/swiglu.cu -<<<<<<< mdiener/fp4_hadamard - hadamard_transform/hadamard_transform.cu -======= cast/cast.cu + hadamard_transform/hadamard_transform.cu multi_tensor/compute_scale.cu recipe/mxfp8_scaling.cu ->>>>>>> dev transpose/quantize_transpose_vector_blockwise_fp4.cu) if(USE_CUDA) @@ -247,16 +244,11 @@ if(USE_CUDA) gemm/cutlass_grouped_gemm.cu hadamard_transform/group_hadamard_transform.cu transpose/quantize_transpose_square_blockwise.cu -<<<<<<< mdiener/fp4_hadamard - hadamard_transform/hadamard_transform_cast_fusion.cu) -======= - hadamard_transform/hadamard_transform.cu hadamard_transform/hadamard_transform_cast_fusion.cu hadamard_transform/group_hadamard_transform_cast_fusion.cu hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_vector_blockwise_fp4.cu) ->>>>>>> dev else() #ROCm specific source codes list(APPEND transformer_engine_cpp_sources From b243b4c6804b3bf05cc85c21d9ed6ed085db62d2 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 6 Apr 2026 15:15:18 -0500 Subject: [PATCH 42/69] merge --- .../common/hadamard_transform/hadamard_transform.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 33196847b..992e61046 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -18,7 +18,9 @@ #include "common/common.h" #include "common/util/ptx.cuh" 
#include "common/utils.cuh" +#ifndef __HIP_PLATFORM_AMD__ #include "hadamard_transform_utils.cuh" +#endif namespace transformer_engine { namespace { From ca1aacf7027337b34d7ac9e641b69d039e3a7404 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 7 Apr 2026 10:48:24 -0500 Subject: [PATCH 43/69] change to __builtin_bit_cast --- .../common/hadamard_transform/hadamard_transform.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 992e61046..8472fea18 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -529,7 +529,7 @@ __device__ __forceinline__ __hip_bfloat16 to_bf16(float v) { return sta // Bit-cast __hip_bfloat16->uint16_t without address-of-temporary. __device__ __forceinline__ uint16_t bf16_to_bits(__hip_bfloat16 v) { - uint16_t bits; __builtin_memcpy(&bits, &v, sizeof(uint16_t)); return bits; + return __builtin_bit_cast(uint16_t, v); } // Unpack/pack 4 BF16 values as uint64_t (vectorised global load/store). From c8e6c72f6cf19651bca4aa996914b1c94a901816 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 7 Apr 2026 12:37:33 -0500 Subject: [PATCH 44/69] more fixes --- tests/pytorch/test_fusible_ops.py | 5 ++- tests/pytorch/test_numerics.py | 1 - transformer_engine/common/CMakeLists.txt | 5 +-- transformer_engine/common/gemm/rocm_gemm.cu | 42 +++++++++++++++---- .../pytorch/csrc/extensions/cast.cpp | 26 +++++++++--- 5 files changed, 61 insertions(+), 18 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index e4647ac82..a41877da3 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2936,7 +2936,10 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: return out # Check values - tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking + if IS_HIP_EXTENSION: + tols = {"rtol": 0.25, "atol": 0.54} + else: + tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking torch.testing.assert_close(to_cpu(y_test), y_ref, **tols) torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols) torch.testing.assert_close(to_cpu(norm.weight.grad), norm_w_ref.grad, **tols) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index d1e9b341e..4a768377e 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -778,7 +778,6 @@ def test_gpt_full_activation_recompute( if (dtype == torch.bfloat16 and not fp8 and not use_reentrant - and recipe.float8_per_tensor_scaling() ): pytest.skip("hipBLASLt does not provide suitable algorithms on GFX950 for this config.") if fp8 and recipe.nvfp4(): diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 8d1b26896..09d2df48c 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -247,12 +247,9 @@ if(USE_CUDA) transpose/quantize_transpose_square_blockwise.cu hadamard_transform/group_hadamard_transform.cu transpose/quantize_transpose_square_blockwise.cu - hadamard_transform/hadamard_transform.cu - hadamard_transform/hadamard_transform_cast_fusion.cu hadamard_transform/group_hadamard_transform_cast_fusion.cu hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu - transpose/quantize_transpose_square_blockwise.cu - 
transpose/quantize_transpose_vector_blockwise_fp4.cu) + transpose/quantize_transpose_square_blockwise.cu) else() #ROCm specific source codes list(APPEND transformer_engine_cpp_sources diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 9811582d3..4b182395f 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -206,11 +206,17 @@ __device__ constexpr float kFP4E2M1Table[16] = { // Dequantize FP4 (e2m1) packed data with FP8 e4m3 block scales to BF16. // Only applies block scales: output = fp4_value * block_scale. // The per-tensor amax correction is applied separately via the GEMM alpha scalar. +// +// Scale layout: 2D tensor of shape {num_rows_padded, scale_stride} where +// scale_stride = roundup(num_cols / 16, 4). Each scale covers a block of 16 +// consecutive elements along the fast (column) dimension. __global__ void dequant_fp4_to_bf16_kernel( const uint8_t* __restrict__ data, const fp8e4m3* __restrict__ scale_inv, hip_bfloat16* __restrict__ output, - int64_t total_elements) + int64_t total_elements, + int64_t num_cols, + int64_t scale_stride) { // Process 2 elements (1 byte) per iteration for coalesced access const int64_t pair_idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; @@ -222,8 +228,12 @@ __global__ void dequant_fp4_to_bf16_kernel( const uint8_t hi_nibble = byte >> 4; const int64_t elem_base = pair_idx * 2; - const float s0 = static_cast(scale_inv[elem_base / 16]); - const float s1 = static_cast(scale_inv[(elem_base + 1) / 16]); + const int64_t row0 = elem_base / num_cols; + const int64_t col0 = elem_base % num_cols; + const int64_t row1 = (elem_base + 1) / num_cols; + const int64_t col1 = (elem_base + 1) % num_cols; + const float s0 = static_cast(scale_inv[row0 * scale_stride + col0 / 16]); + const float s1 = static_cast(scale_inv[row1 * scale_stride + col1 / 16]); output[elem_base] = static_cast(kFP4E2M1Table[lo_nibble] * s0); output[elem_base + 1] = static_cast(kFP4E2M1Table[hi_nibble] * s1); @@ -232,7 +242,9 @@ __global__ void dequant_fp4_to_bf16_kernel( // Launch helper for dequant kernel static void launch_dequant_fp4_to_bf16( const void* data, const void* scale_inv, - void* output, int64_t total_elements, hipStream_t stream) + void* output, int64_t total_elements, + int64_t num_cols, int64_t scale_stride, + hipStream_t stream) { constexpr int kBlockSize = 256; const int64_t total_pairs = total_elements / 2; @@ -242,7 +254,7 @@ static void launch_dequant_fp4_to_bf16( reinterpret_cast(data), reinterpret_cast(scale_inv), reinterpret_cast(output), - total_elements); + total_elements, num_cols, scale_stride); } // Compute per-row alpha vector on device for NVFP4 GEMM: @@ -1081,7 +1093,15 @@ void hipblaslt_gemm(const Tensor *inputA, if (is_fp4_dtype(param.Atype)) { hip_bfloat16* a_bf16 = reinterpret_cast(ws_ptr + ws_offset); const int64_t total_a = static_cast(m) * k; - launch_dequant_fp4_to_bf16(param.A, param.A_scale_inv, a_bf16, total_a, stream); + // Determine scale stride from scale tensor shape + const auto& a_sinv = (transa == CUBLAS_OP_T) ? inputA->scale_inv + : inputA->columnwise_scale_inv; + const int64_t a_num_cols = (transa == CUBLAS_OP_T) + ? inputA->data.shape.back() + : inputA->columnwise_data.shape.back(); + const int64_t a_scale_stride = (a_sinv.shape.size() >= 2) ? 
a_sinv.shape[1] : (a_num_cols / 16); + launch_dequant_fp4_to_bf16(param.A, param.A_scale_inv, a_bf16, total_a, + a_num_cols, a_scale_stride, stream); param.A = a_bf16; param.Atype = DType::kBFloat16; param.A_scale_inv = nullptr; @@ -1091,7 +1111,15 @@ void hipblaslt_gemm(const Tensor *inputA, if (is_fp4_dtype(param.Btype)) { hip_bfloat16* b_bf16 = reinterpret_cast(ws_ptr + ws_offset); const int64_t total_b = static_cast(k) * n; - launch_dequant_fp4_to_bf16(param.B, param.B_scale_inv, b_bf16, total_b, stream); + // Determine scale stride from scale tensor shape + const auto& b_sinv = (transb == CUBLAS_OP_N) ? inputB->scale_inv + : inputB->columnwise_scale_inv; + const int64_t b_num_cols = (transb == CUBLAS_OP_N) + ? inputB->data.shape.back() + : inputB->columnwise_data.shape.back(); + const int64_t b_scale_stride = (b_sinv.shape.size() >= 2) ? b_sinv.shape[1] : (b_num_cols / 16); + launch_dequant_fp4_to_bf16(param.B, param.B_scale_inv, b_bf16, total_b, + b_num_cols, b_scale_stride, stream); param.B = b_bf16; param.Btype = DType::kBFloat16; param.B_scale_inv = nullptr; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 6b647abe1..2b35e0700 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -801,6 +801,7 @@ static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( return res; } +#ifndef USE_ROCM // Implements split-quantize NVFP4 with Row/Column-wise Hadamard Transform (RHT) void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, const std::vector &input_list, @@ -963,6 +964,7 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, } } } +#endif // #ifndef USE_ROCM void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, const std::vector &input_list, @@ -1019,8 +1021,16 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); } +#ifndef USE_ROCM nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), split_sections.data(), num_tensors, stream); +#else + // nvte_group_amax is not available on ROCm; compute amax individually + for (size_t i = 0; i < num_tensors; i++) { + if (input_list[i].numel() == 0) continue; + nvte_compute_amax(input_list[i].data(), output_list[i].data(), stream); + } +#endif for (size_t i = 0; i < num_tensors; i++) { output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); } @@ -1085,20 +1095,28 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, // Perform multi-tensor quantization NVTE_SCOPED_GIL_RELEASE({ +#ifndef USE_ROCM if (quantizer.with_rht) { // Quantize row-wise data, RHT+quantize column-wise data // Check that config is supported NVTE_CHECK(input.dtype() == DType::kBFloat16, "RHT is only supported for bfloat16 input"); // Fuse the rowwise and colwise into one when the kernel is ready split_quantize_nvfp4_impl_with_rht_helper(input, input_list, output_list, split_sections, quantizers, stream); - } else { // NVFP4 quantize - // Fuse the rowwise and colwise into one when the kernel is ready + } else { + // NVFP4 quantize without RHT split_quantize_nvfp4_impl_helper(input, input_list, output_list, split_sections, quantizers, stream); } +#else + // ROCm: group hadamard kernels are not available, fall back to per-tensor quantize + // which handles both 
RHT and non-RHT paths via NVFP4Quantizer::quantize_impl. + for (size_t i = 0; i < num_tensors; i++) { + if (input_list[i].numel() == 0) continue; + quantizers[i]->quantize(input_list[i], output_list[i]); + } +#endif }); } -// #endif // #ifndef USE_ROCM } // namespace @@ -1168,14 +1186,12 @@ std::vector split_quantize(const at::Tensor &tensor, return detail::IsMXFP8Quantizers(quantizer.ptr()); })) { allocation_method = AllocationMethod::BULK_MXFP8; -#ifndef USE_ROCM } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), [](const py::handle &quantizer) -> bool { return detail::IsNVFP4Quantizers(quantizer.ptr()); })) { allocation_method = AllocationMethod::BULK_NVFP4; quantization_method = QuantizationMethod::FUSED_NVFP4; -#endif } } From 1b0fe3eb3d8df9833600f230173e46ac26ce0d72 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 7 Apr 2026 16:52:47 -0500 Subject: [PATCH 45/69] fix dequant buffer --- transformer_engine/common/gemm/rocm_gemm.cu | 31 +++++++++------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 4b182395f..1f3bce94b 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -1048,6 +1048,8 @@ void hipblaslt_gemm(const Tensor *inputA, // Alpha is passed as a device vector of length m via // HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST. Beta stays on host. const bool use_fp4 = is_fp4_dtype(param.Atype) || is_fp4_dtype(param.Btype); + void* fp4_dequant_a_buf = nullptr; + void* fp4_dequant_b_buf = nullptr; const void* alpha_ptr = static_cast(&alpha); const void* beta_ptr = static_cast(&beta); if (use_fp4) { @@ -1080,18 +1082,10 @@ void hipblaslt_gemm(const Tensor *inputA, // beta_ptr stays as host &beta // Dequantize FP4 -> BF16 (block scales only, no amax folded in) - const size_t a_bf16_bytes = is_fp4_dtype(param.Atype) ? static_cast(m) * k * sizeof(hip_bfloat16) : 0; - const size_t b_bf16_bytes = is_fp4_dtype(param.Btype) ? static_cast(k) * n * sizeof(hip_bfloat16) : 0; - const size_t dequant_ws = (a_bf16_bytes + b_bf16_bytes + 255) & ~size_t(255); - NVTE_CHECK(workspaceSize >= dequant_ws, - "NVFP4 GEMM requires at least ", dequant_ws, " bytes workspace for FP4->BF16 dequant, " - "but only ", workspaceSize, " bytes available."); - - uint8_t* ws_ptr = reinterpret_cast(workspace); - size_t ws_offset = 0; - if (is_fp4_dtype(param.Atype)) { - hip_bfloat16* a_bf16 = reinterpret_cast(ws_ptr + ws_offset); + const size_t a_bf16_bytes = static_cast(m) * k * sizeof(hip_bfloat16); + NVTE_CHECK_CUDA(hipMallocAsync(&fp4_dequant_a_buf, a_bf16_bytes, stream)); + hip_bfloat16* a_bf16 = reinterpret_cast(fp4_dequant_a_buf); const int64_t total_a = static_cast(m) * k; // Determine scale stride from scale tensor shape const auto& a_sinv = (transa == CUBLAS_OP_T) ? inputA->scale_inv @@ -1105,11 +1099,12 @@ void hipblaslt_gemm(const Tensor *inputA, param.A = a_bf16; param.Atype = DType::kBFloat16; param.A_scale_inv = nullptr; - ws_offset += a_bf16_bytes; } if (is_fp4_dtype(param.Btype)) { - hip_bfloat16* b_bf16 = reinterpret_cast(ws_ptr + ws_offset); + const size_t b_bf16_bytes = static_cast(k) * n * sizeof(hip_bfloat16); + NVTE_CHECK_CUDA(hipMallocAsync(&fp4_dequant_b_buf, b_bf16_bytes, stream)); + hip_bfloat16* b_bf16 = reinterpret_cast(fp4_dequant_b_buf); const int64_t total_b = static_cast(k) * n; // Determine scale stride from scale tensor shape const auto& b_sinv = (transb == CUBLAS_OP_N) ? 
inputB->scale_inv @@ -1123,12 +1118,7 @@ void hipblaslt_gemm(const Tensor *inputA, param.B = b_bf16; param.Btype = DType::kBFloat16; param.B_scale_inv = nullptr; - ws_offset += b_bf16_bytes; } - - // Advance workspace past dequant buffers (device alpha vector is safe at the end) - workspace = ws_ptr + ((ws_offset + 255) & ~size_t(255)); - workspaceSize -= dequant_ws; } bool nvte_log_gemm_config = false; @@ -1564,6 +1554,11 @@ void hipblaslt_gemm(const Tensor *inputA, update_tensor_scale_inv(outputD, stream); } + if (fp4_dequant_a_buf) + NVTE_CHECK_CUDA(hipFreeAsync(fp4_dequant_a_buf, stream)); + if (fp4_dequant_b_buf) + NVTE_CHECK_CUDA(hipFreeAsync(fp4_dequant_b_buf, stream)); + NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc)); NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc)); NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc)); From bc9f0a392b1af6f11e5e8414f582a1e8228599f2 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 8 Apr 2026 10:04:07 -0500 Subject: [PATCH 46/69] remove copyright header --- .../common/include/transformer_engine/hadamard_transform.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index 0201b3c7d..13103cc38 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -1,6 +1,4 @@ /************************************************************************* - * This file was modified for portability to AMDGPU - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
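The rocm_gemm.cu hunks above hand the FP4->BF16 dequantization both the logical column count and a scale stride, because NVFP4 stores one E4M3 block scale per 16 row elements and the scale-inverse tensor rows may be padded beyond num_cols / 16. The snippet below is a minimal host-side sketch of that blockwise dequantization for illustration only; the function and parameter names (dequant_fp4_blockwise, block_scales, scale_stride) are hypothetical and do not appear in this series, and the global-amax factor that the GEMM applies separately through the device alpha vector is deliberately left out, matching the "block scales only, no amax folded in" comment above.

import torch

# Hypothetical reference sketch (not part of this patch series): decode 4-bit
# E2M1 codes and apply one block scale per 16 consecutive row elements.
_FP4_E2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                          -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

def dequant_fp4_blockwise(fp4_codes, block_scales, num_cols, scale_stride):
    # fp4_codes:    (rows, num_cols) integer tensor of E2M1 codes in [0, 15].
    # block_scales: flat float tensor of decoded per-block scales, laid out with
    #               row stride `scale_stride` (may exceed num_cols // 16 when padded).
    rows = fp4_codes.shape[0]
    values = _FP4_E2M1[fp4_codes.long()]                      # decode E2M1 codes
    scales = block_scales.view(rows, scale_stride)[:, : num_cols // 16]
    scales = scales.repeat_interleave(16, dim=1)              # one scale per element
    return (values * scales).to(torch.bfloat16)

On the FNUZ FP8 variants used by some AMD GPUs the E4M3 maximum is 240 rather than 448, which is why later patches in this series thread is_fp8_fnuz() through the reference quantizers and tests.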
From 167c3112660739557c3f0799eca7b5e0c066a422 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 8 Apr 2026 11:04:05 -0500 Subject: [PATCH 47/69] fix triton rmsnorm --- transformer_engine/pytorch/triton_kernels/norms_common.py | 3 ++- transformer_engine/pytorch/triton_kernels/utils.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms_common.py b/transformer_engine/pytorch/triton_kernels/norms_common.py index 87cfa722e..ed4002f2c 100644 --- a/transformer_engine/pytorch/triton_kernels/norms_common.py +++ b/transformer_engine/pytorch/triton_kernels/norms_common.py @@ -8,6 +8,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.pytorch.triton_kernels.common import ( te_dtype_to_torch_dtype, te_dtype_to_triton_dtype, @@ -222,7 +223,7 @@ def _te_norm_fwd_triton( quantizer.amax, N, ATOMIC_REDUCTION_BLOCK_SIZE, ) - elif IS_MXFP8 or IS_FP8_CURRENT_SCALING: + elif IS_MXFP8 or IS_FP8_CURRENT_SCALING or isinstance(quantizer, NVFP4Quantizer): _out = quantizer.make_empty( input_tensor.shape, dtype=te_dtype_to_torch_dtype(otype), diff --git a/transformer_engine/pytorch/triton_kernels/utils.py b/transformer_engine/pytorch/triton_kernels/utils.py index 15a733ce9..884fab5e3 100644 --- a/transformer_engine/pytorch/triton_kernels/utils.py +++ b/transformer_engine/pytorch/triton_kernels/utils.py @@ -6,6 +6,7 @@ import triton from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8CurrentScalingQuantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from .common import te_dtype_to_torch_dtype def get_ln_sm_margin(sm_margin_type): @@ -59,7 +60,7 @@ def make_ln_out(ln_out, quantizer=None, input_shape=None, out_dtype=torch.float3 if ln_out is None: # TODO(micky774): Remove corresponding FP8Quantizer check when kernels properly support MXFP8/float8_current_scaling as a fused operation - if quantizer is None or isinstance(quantizer, MXFP8Quantizer) or isinstance(quantizer, Float8CurrentScalingQuantizer): + if quantizer is None or isinstance(quantizer, (MXFP8Quantizer, Float8CurrentScalingQuantizer, NVFP4Quantizer)): return torch.empty(input_shape, dtype=out_dtype, device='cuda') return quantizer.make_empty(input_shape, dtype=out_dtype) From 4b0550d15a1b053707e8cd9dbb2ebc1b873d85a2 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 8 Apr 2026 15:31:37 -0500 Subject: [PATCH 48/69] mi300 fixes --- tests/cpp/operator/test_cast_nvfp4_transpose.cu | 8 ++++++-- tests/cpp/operator/test_dequantize_nvfp4.cu | 7 ++++++- tests/pytorch/test_fusible_ops.py | 2 +- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 7 ++++++- transformer_engine/common/recipe/nvfp4.cu | 9 +++++++++ 5 files changed, 28 insertions(+), 5 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 50f2d36fe..a24a9c6f4 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -65,8 +65,12 @@ std::vector create_transpose(const InputType* const input, const size // Compute the global encode scale factor for a given global amax float compute_global_encode_scaling_factor_FP4(const float 
global_amax) { - constexpr float fp8_max = 448.0f; // 448.0f; - constexpr float fp4_max = 6.0f; // 6.0f; +#ifdef __HIP_PLATFORM_AMD__ + const float fp8_max = te_fp8_fnuz() ? 240.0f : 448.0f; +#else + constexpr float fp8_max = 448.0f; +#endif + constexpr float fp4_max = 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return max value of float32 global_encode_scale = fminf(global_encode_scale, Numeric_Traits::maxNorm); diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 1da70923a..ce6826b7c 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -180,7 +180,12 @@ void compute_ref(const fp4e2m1* input, const size_t rows, const size_t cols, const size_t scale_stride) { - constexpr float factor_inv = 1.0f / (6.0f * 448.0f); +#ifdef __HIP_PLATFORM_AMD__ + const float fp8_max = te_fp8_fnuz() ? 240.0f : 448.0f; +#else + constexpr float fp8_max = 448.0f; +#endif + const float factor_inv = 1.0f / (6.0f * fp8_max); const size_t blocks_per_row = cols / kFP4BlockSize1D; diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index a41877da3..3b4e4f2d3 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1773,7 +1773,7 @@ def test_clamped_swiglu( quantized_compute = quantization is not None if not quantized_compute and (quantize_forward or quantize_backward): pytest.skip("Quantization scheme has not been provided") - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index ccdc4c93e..b75947e91 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -58,7 +58,12 @@ __global__ void __launch_bounds__(512) value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; float amax = *tensor_amax; - constexpr float factor_inv = 1.0 / (6.0 * 448.0); +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_DEVICE_COMPILE__) + // On AMD host, TypeExtrema::max is non-constexpr (runtime FNUZ detection) + const float factor_inv = 1.0f / (detail::TypeExtrema::max * detail::TypeExtrema::max); +#else + constexpr float factor_inv = 1.0f / (detail::TypeExtrema::max * detail::TypeExtrema::max); +#endif float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll for (int i = 0; i < 4; i++) { diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 682d8b53f..e1d30f3af 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -14,14 +14,23 @@ namespace transformer_engine { namespace nvfp4_recipe { +#ifndef __HIP_PLATFORM_AMD__ // constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0; constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0); +#endif // Kernel to compute alpha *= amax_A * amax_B / factor __global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A, const float *amax_B, float *alpha_out) { +#ifdef __HIP_PLATFORM_AMD__ + const float fp4_max = detail::TypeExtrema::max; + const float fp8_max = detail::TypeExtrema::max; + const float fi = 1.0f / (fp4_max * fp4_max * fp8_max * 
fp8_max); + *alpha_out = alpha_in * (*amax_A) * (*amax_B) * fi; +#else // factor is defined in the enclosing namespace *alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv; +#endif } } // namespace nvfp4_recipe From 5da621a658db466a911e3440c56780bdc3d9d9bd Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 8 Apr 2026 16:56:17 -0500 Subject: [PATCH 49/69] software fallbacks for SR on gfx942 --- .../hadamard_transform_cast_fusion.cu | 41 ++++++++++++++++++- ...quantize_transpose_vector_blockwise_fp4.cu | 40 +++++++++++++++++- 2 files changed, 78 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 999c03245..c579c64e8 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -814,8 +814,45 @@ __device__ __forceinline__ fp4e2m1x4 cvt_fp32_to_fp4_4x(const float2 in01, const lo | (static_cast<__hip_fp4x4_storage_t>(hi) << 8)); return result; #else - NVTE_DEVICE_ERROR("FP4 stochastic rounding on AMDGPU requires gfx950 or later."); - return fp4e2m1x4{}; + // Software stochastic rounding fallback for AMD GPUs without native + // FP4 SR instructions (e.g. gfx942). + // + // FP4 E2M1 has 8 non-negative magnitudes whose 3-bit codes happen to + // be sorted: {0->0.0, 1->0.5, 2->1.0, 3->1.5, 4->2.0, 5->3.0, + // 6->4.0, 7->6.0}. + // + // For each value we: + // 1. Clamp |x| into [0, 6] (the FP4 representable range). + // 2. Find the floor index fi in the FP4 grid via branchless + // comparisons (sum of (|x| >= threshold) for each level). + // 3. Compute the fractional position within [kV[fi], kV[ci]] + // where ci = min(fi+1, 7) is the ceiling index. + // 4. Draw a uniform random value r in [0,1) from 8 bits of rbits. + // 5. Round up to ci if r < frac, otherwise keep fi. + // This gives E[round(x)] = x (unbiased). + // 6. Set the sign bit (bit 3) if the original value was negative. + { + constexpr float kV[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f}; + const float vals[4] = {in01.x, in01.y, in23.x, in23.y}; + __hip_fp4_storage_t q[4]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + const float av = fminf(fabsf(vals[i]), 6.0f); + const int fi = int(av >= 0.5f) + int(av >= 1.0f) + int(av >= 1.5f) + + int(av >= 2.0f) + int(av >= 3.0f) + int(av >= 4.0f) + int(av >= 6.0f); + const int ci = min(fi + 1, 7); + const float gap = kV[ci] - kV[fi]; + const float frac = (gap > 0.0f) ? (av - kV[fi]) / gap : 0.0f; + const float r = static_cast((rbits >> (8 * i)) & 0xFFu) * (1.0f / 256.0f); + const int ri = (r < frac) ? ci : fi; + q[i] = static_cast<__hip_fp4_storage_t>((vals[i] < 0.0f) ? 
(ri | 0x8) : ri); + } + fp4e2m1x4 result; + result.__x = static_cast<__hip_fp4x4_storage_t>( + (q[0] & 0xFu) | ((q[1] & 0xFu) << 4) | + ((q[2] & 0xFu) << 8) | ((q[3] & 0xFu) << 12)); + return result; + } #endif } else { const __hip_fp4_storage_t q0 = diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index bdcfd8d0b..306ae0e30 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -314,7 +314,45 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_ro result.__x = static_cast<__hip_fp4x4_storage_t>(lo | (static_cast<__hip_fp4x4_storage_t>(hi) << 8)); return result; #else - NVTE_DEVICE_ERROR("FP4 stochastic rounding on AMDGPU requires gfx950 or later."); + // Stochastic rounding fallback for AMD GPUs without native + // FP4 SR instructions (e.g. gfx942). + // + // FP4 E2M1 has 8 non-negative magnitudes whose 3-bit codes happen to + // be sorted: {0->0.0, 1->0.5, 2->1.0, 3->1.5, 4->2.0, 5->3.0, + // 6->4.0, 7->6.0}. + // + // For each value we: + // 1. Clamp |x| into [0, 6] (the FP4 representable range). + // 2. Find the floor index fi in the FP4 grid via branchless + // comparisons (sum of (|x| >= threshold) for each level). + // 3. Compute the fractional position within [kV[fi], kV[ci]] + // where ci = min(fi+1, 7) is the ceiling index. + // 4. Draw a uniform random value r in [0,1) from 8 bits of rbits. + // 5. Round up to ci if r < frac, otherwise keep fi. + // This gives E[round(x)] = x (unbiased). + // 6. Set the sign bit (bit 3) if the original value was negative. + { + constexpr float kV[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f}; + const float vals[4] = {in01.x, in01.y, in23.x, in23.y}; + __hip_fp4_storage_t q[4]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + const float av = fminf(fabsf(vals[i]), 6.0f); + const int fi = int(av >= 0.5f) + int(av >= 1.0f) + int(av >= 1.5f) + + int(av >= 2.0f) + int(av >= 3.0f) + int(av >= 4.0f) + int(av >= 6.0f); + const int ci = min(fi + 1, 7); + const float gap = kV[ci] - kV[fi]; + const float frac = (gap > 0.0f) ? (av - kV[fi]) / gap : 0.0f; + const float r = static_cast((rbits >> (8 * i)) & 0xFFu) * (1.0f / 256.0f); + const int ri = (r < frac) ? ci : fi; + q[i] = static_cast<__hip_fp4_storage_t>((vals[i] < 0.0f) ? 
(ri | 0x8) : ri); + } + __nv_fp4x4_e2m1 result; + result.__x = static_cast<__hip_fp4x4_storage_t>( + (q[0] & 0xFu) | ((q[1] & 0xFu) << 4) | + ((q[2] & 0xFu) << 8) | ((q[3] & 0xFu) << 12)); + return result; + } #endif // ARCH_HAS_STOCHASTIC_ROUNDING #endif // !__HIP_PLATFORM_AMD__ uint16_t dummy = 0; From 287708d427da94bc1e4bb75524ca38300b159959 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 9 Apr 2026 12:17:39 -0500 Subject: [PATCH 50/69] more gfx942 fixes --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 8 +++++--- tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py | 7 +++++-- ...quantize_transpose_vector_blockwise_fp4.cu | 9 ++++++++- .../custom_recipes/quantization_nvfp4.py | 20 ++++++++++++++----- 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 44a87c83f..d40e44423 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -12,6 +12,7 @@ from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils +from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type from torch.utils.cpp_extension import IS_HIP_EXTENSION @@ -110,10 +111,11 @@ def check_nvfp4_gemm_versus_reference( sx_trimmed = sx_native[:M, :expected_sx_cols] sw_trimmed = sw_native[:N, :expected_sw_cols] - # Native scales are stored as uint8 but need to be interpreted as float8_e4m3fn + # Native scales are stored as uint8 but need to be interpreted as float8_e4m3 # for the reference GEMM to work correctly - sx_trimmed = sx_trimmed.view(torch.float8_e4m3fn) - sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn) + fp8_dtype = get_torch_float8_e4m3_type() + sx_trimmed = sx_trimmed.view(fp8_dtype) + sw_trimmed = sw_trimmed.view(fp8_dtype) # Create reference quantizer for reference GEMM ref_quantizer = NVFP4QuantizerRef( diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py index b14eeb815..98c0474a5 100755 --- a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py @@ -11,6 +11,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, is_fp8_fnuz recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) @@ -58,10 +59,12 @@ def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor: def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor: - sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32) + fp8_dtype = get_torch_float8_e4m3_type() + fp8_max = 240.0 if is_fp8_fnuz() else 448.0 + sf = sx.repeat_interleave(16, dim=1).view(fp8_dtype).to(torch.float32) dqx = fp4_to_fp32(unpack_fp4(qx)) sf = sf[: dqx.shape[0], : dqx.shape[1]] - dequant = dqx * sf * (amax / (6.0 * 448)) + dequant = dqx * sf * (amax / (6.0 * fp8_max)) return dequant diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 306ae0e30..8b814b53a 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -137,12 +137,19 @@ 
constexpr int kThreadsPerWarp = 32; constexpr int kNFP4PerContainer = 2; // Hyperparameters for performance tuning +// gfx942 has 64 KB LDS per workgroup. With kTileDim=128 and float32 input, +// shared memory exceeds 64 KB. Use kTileDim=64 and kThreadsPerBlock=128 on gfx942. +#if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx950__) +constexpr int kTileDim = 64; +constexpr int kThreadsPerBlock = 128; // Thread block size, 4 warps in total +#else constexpr int kTileDim = 128; +constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total +#endif // constexpr int kScaleDim = 32; constexpr int kNVecIn = 8; // The number of elements each LDG touches constexpr int kNVecOut = 16; // The number of elements each STG touches constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches -constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total // Auto-calculated constants, do not modify directly) static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index d00d0c8b9..11535b564 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -12,6 +12,15 @@ from transformer_engine.pytorch.custom_recipes import quantization from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer +from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, is_fp8_fnuz + + +def _fp8_e4m3_max(): + return 240.0 if is_fp8_fnuz() else 448.0 + + +def _fp8_e4m3_dtype(): + return get_torch_float8_e4m3_type() def nvfp4_ref_rht_2d_quantizer_factory(role): @@ -137,9 +146,9 @@ def cast_to_e4m3(decode_scale, global_amax): TODO(etsykunov): Make less unintuitive. 
""" decode_scale = decode_scale * global_amax - FLOAT8_E4M3_MAX = torch.tensor(448.0, device=decode_scale.device, dtype=torch.float32) + FLOAT8_E4M3_MAX = torch.tensor(_fp8_e4m3_max(), device=decode_scale.device, dtype=torch.float32) decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) - return decode_scale.to(torch.float8_e4m3fn) + return decode_scale.to(_fp8_e4m3_dtype()) def high_precision_gemm_ref( @@ -470,7 +479,7 @@ def _quantize_blockwise_reference( ) # (128, 8, 1) x = x.view(m, n // tile_len_x, tile_len_x) FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) - FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) + FLOAT8_E4M3_MAX = torch.tensor(_fp8_e4m3_max(), device=x.device, dtype=torch.float32) decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX) if pow_2_scales: @@ -503,7 +512,7 @@ def _quantize_blockwise_reference( ), ) decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) - decode_scale = decode_scale.to(torch.float8_e4m3fn) + decode_scale = decode_scale.to(_fp8_e4m3_dtype()) encode_scale = torch.min( torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale), @@ -823,7 +832,8 @@ def qgemm( sx = sx.to(torch.float32) sw = sw.to(torch.float32) - factor = 6.0 * 6.0 * 448.0 * 448.0 + _e4m3_max = _fp8_e4m3_max() + factor = 6.0 * 6.0 * _e4m3_max * _e4m3_max if gemm_type == quantization.GEMMType.WGRAD: partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col From 81c45c96f3b86d3cc321746f046bbd8e76946ee3 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 9 Apr 2026 15:02:02 -0500 Subject: [PATCH 51/69] ensure columnwise data for dgrad GEMM --- transformer_engine/pytorch/module/base.py | 6 ++++++ .../pytorch/module/layernorm_linear.py | 13 +++++++++++-- transformer_engine/pytorch/module/linear.py | 12 ++++++++++-- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 2d8563729..3aac641fa 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -52,6 +52,7 @@ from ..triton_kernels.cast import te_quantize_triton from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage +from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from ..utils import get_device_compute_capability, is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe @@ -1486,6 +1487,11 @@ def get_weight_workspace( reset_cache = True elif quantizer.columnwise_usage and out._columnwise_data is None: reset_cache = True + elif isinstance(out, NVFP4TensorStorage): + if quantizer.rowwise_usage and out._rowwise_data is None: + reset_cache = True + elif quantizer.columnwise_usage and out._columnwise_data is None: + reset_cache = True if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer): reset_cache = True if reset_cache: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7347fc138..2a4f74d17 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -312,6 +312,11 @@ def forward( weight_quantizer = weight._quantizer elif weight_quantizer is 
not None: weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) + # NVFP4 must produce columnwise data at quantization time + # (no lazy transpose like Float8Tensor) + from ..tensor.nvfp4_tensor import NVFP4Quantizer + if isinstance(weight_quantizer, NVFP4Quantizer) and is_grad_enabled: + weight_quantizer.set_usage(columnwise=True) # Get quantized weight update_workspace = is_first_microbatch is None or is_first_microbatch @@ -369,7 +374,7 @@ def forward( # Forward GEMM # Note: y = x * w^T # ------------------------------------------------------ - if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache: + if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache and hasattr(weightmat, '_transpose'): assert weightmat._transpose is None or weightmat._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled." nvtx_range_push(f"{nvtx_label}.gemm") gemm_out, *_, reduce_scatter_out = general_gemm( @@ -1861,5 +1866,9 @@ def _get_weight_quantizers(self) -> List[Quantizer]: weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True if IS_HIP_EXTENSION: - weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) + # NVFP4 must always produce columnwise data at quantization time + # (no lazy transpose like Float8Tensor), so force columnwise=True. + from ..tensor.nvfp4_tensor import NVFP4Quantizer + is_nvfp4 = isinstance(weight_quantizer, NVFP4Quantizer) + weight_quantizer.set_usage(columnwise=True if is_nvfp4 else self.keep_fp8_weight_transpose_cache) return [weight_quantizer] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 01d07d91a..dd17c7023 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -68,6 +68,7 @@ ) from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.utils import is_custom from ..export import is_in_onnx_export_mode, assert_warmed_up from ..cpu_offload import ( @@ -265,6 +266,10 @@ def forward( is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase() ) + # NVFP4 must produce columnwise data at quantization time + # (no lazy transpose like Float8Tensor) + if not columnwise_usage and isinstance(weight_quantizer, NVFP4Quantizer): + columnwise_usage = is_grad_enabled and inp.requires_grad weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) elif isinstance(weight, QuantizedTensor): # If weight is already quantized, no need to set quantizer states @@ -325,7 +330,7 @@ def forward( # Forward GEMM # Note: y = x * w^T # ------------------------------------------------------ - if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache: + if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache and hasattr(weightmat, '_transpose'): assert weightmat._transpose is None or weightmat._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled." 
nvtx_range_push(f"{nvtx_label}.gemm") @@ -1712,5 +1717,8 @@ def _get_weight_quantizers(self) -> List[Quantizer]: weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True if IS_HIP_EXTENSION: - weight_quantizer.set_usage(columnwise = self.keep_fp8_weight_transpose_cache) + # NVFP4 must always produce columnwise data at quantization time + # (no lazy transpose like Float8Tensor), so force columnwise=True. + is_nvfp4 = isinstance(weight_quantizer, NVFP4Quantizer) + weight_quantizer.set_usage(columnwise=True if is_nvfp4 else self.keep_fp8_weight_transpose_cache) return [weight_quantizer] From b75e066ed2f517eb198c43c1a9c7b41f62a834c0 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 9 Apr 2026 18:00:24 -0500 Subject: [PATCH 52/69] fix mi350 --- .../transpose/quantize_transpose_vector_blockwise_fp4.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 8b814b53a..674599009 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -138,8 +138,10 @@ constexpr int kNFP4PerContainer = 2; // Hyperparameters for performance tuning // gfx942 has 64 KB LDS per workgroup. With kTileDim=128 and float32 input, -// shared memory exceeds 64 KB. Use kTileDim=64 and kThreadsPerBlock=128 on gfx942. -#if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx950__) +// shared memory exceeds 64 KB. Use kTileDim=64 and kThreadsPerBlock=128 on AMD. +// TODO: For optimal gfx950 performance (128 KB LDS), implement runtime dispatch +// with two kernel instantiations (kTileDim=64 and kTileDim=128). +#if defined(__HIP_PLATFORM_AMD__) constexpr int kTileDim = 64; constexpr int kThreadsPerBlock = 128; // Thread block size, 4 warps in total #else From 2225c72a51e8e9eb764a71a3cb38a357b84fecde Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 13 Apr 2026 11:59:27 -0500 Subject: [PATCH 53/69] replace dequant allocation with workspace --- tests/pytorch/test_cpu_offloading.py | 9 +++- transformer_engine/common/gemm/rocm_gemm.cu | 45 +++++++++++-------- .../pytorch/cpp_extensions/gemm.py | 6 +-- transformer_engine/pytorch/module/base.py | 6 +-- 4 files changed, 38 insertions(+), 28 deletions(-) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 7d1b8c716..f9feee514 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -204,7 +204,14 @@ def memory_leak_check(): # Only cublas workspaces and some global tensors are allowed to be allocated. # All other allocations should be released. # This is a simple check to catch memory leaks. 
- if Utils.get_cuda_memory_mb() > 1000: + if IS_HIP_EXTENSION: + from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes + # workspaces are larger for AMDGPU + addl_space = get_cublas_workspace_size_bytes() / (1024**2) * 4 + else: + addl_space = 0 + + if Utils.get_cuda_memory_mb() > 1000 + addl_space: memory_num = Utils.get_cuda_memory_mb() import gc diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 1f3bce94b..2e9f197f4 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -1048,8 +1048,6 @@ void hipblaslt_gemm(const Tensor *inputA, // Alpha is passed as a device vector of length m via // HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST. Beta stays on host. const bool use_fp4 = is_fp4_dtype(param.Atype) || is_fp4_dtype(param.Btype); - void* fp4_dequant_a_buf = nullptr; - void* fp4_dequant_b_buf = nullptr; const void* alpha_ptr = static_cast(&alpha); const void* beta_ptr = static_cast(&beta); if (use_fp4) { @@ -1064,13 +1062,29 @@ void hipblaslt_gemm(const Tensor *inputA, ? reinterpret_cast(inputB->amax.dptr) : reinterpret_cast(inputB->columnwise_amax.dptr); - // Reserve m floats from end of workspace for the device alpha vector. + // Compute total extra bytes needed from the workspace: + // alpha vector: m * sizeof(float) + // dequant A: m * k * sizeof(bf16) (if A is FP4) + // dequant B: k * n * sizeof(bf16) (if B is FP4) const size_t alpha_vec_bytes = static_cast(m) * sizeof(float); - NVTE_CHECK(workspaceSize >= alpha_vec_bytes, - "NVFP4 GEMM requires at least ", alpha_vec_bytes, " bytes workspace for alpha vector."); - workspaceSize = (workspaceSize / sizeof(float)) * sizeof(float) - alpha_vec_bytes; - float* device_alpha_vec = reinterpret_cast( - reinterpret_cast(workspace) + workspaceSize); + const size_t a_bf16_bytes = is_fp4_dtype(param.Atype) + ? static_cast(m) * k * sizeof(hip_bfloat16) : 0; + const size_t b_bf16_bytes = is_fp4_dtype(param.Btype) + ? static_cast(k) * n * sizeof(hip_bfloat16) : 0; + const size_t fp4_total_bytes = alpha_vec_bytes + a_bf16_bytes + b_bf16_bytes; + NVTE_CHECK(workspaceSize >= fp4_total_bytes, + "NVFP4 GEMM requires at least ", fp4_total_bytes, " bytes workspace (", + fp4_total_bytes / (1024 * 1024), " MiB) for alpha vector + BF16 dequant buffers, " + "but only ", workspaceSize, " bytes (", workspaceSize / (1024 * 1024), + " MiB) available. Increase the cuBLAS workspace size."); + + // Carve regions from the end of the workspace. + // Layout: [cublas workspace ... 
| alpha_vec | dequant_a | dequant_b] + workspaceSize = (workspaceSize / sizeof(float)) * sizeof(float) - fp4_total_bytes; + uint8_t* ws_ptr = reinterpret_cast(workspace) + workspaceSize; + + float* device_alpha_vec = reinterpret_cast(ws_ptr); + ws_ptr += alpha_vec_bytes; NVTE_CHECK(amax_A != nullptr, "FP4 GEMM requires amax_A"); NVTE_CHECK(amax_B != nullptr, "FP4 GEMM requires amax_B"); @@ -1083,9 +1097,8 @@ void hipblaslt_gemm(const Tensor *inputA, // Dequantize FP4 -> BF16 (block scales only, no amax folded in) if (is_fp4_dtype(param.Atype)) { - const size_t a_bf16_bytes = static_cast(m) * k * sizeof(hip_bfloat16); - NVTE_CHECK_CUDA(hipMallocAsync(&fp4_dequant_a_buf, a_bf16_bytes, stream)); - hip_bfloat16* a_bf16 = reinterpret_cast(fp4_dequant_a_buf); + hip_bfloat16* a_bf16 = reinterpret_cast(ws_ptr); + ws_ptr += a_bf16_bytes; const int64_t total_a = static_cast(m) * k; // Determine scale stride from scale tensor shape const auto& a_sinv = (transa == CUBLAS_OP_T) ? inputA->scale_inv @@ -1102,9 +1115,8 @@ void hipblaslt_gemm(const Tensor *inputA, } if (is_fp4_dtype(param.Btype)) { - const size_t b_bf16_bytes = static_cast(k) * n * sizeof(hip_bfloat16); - NVTE_CHECK_CUDA(hipMallocAsync(&fp4_dequant_b_buf, b_bf16_bytes, stream)); - hip_bfloat16* b_bf16 = reinterpret_cast(fp4_dequant_b_buf); + hip_bfloat16* b_bf16 = reinterpret_cast(ws_ptr); + ws_ptr += b_bf16_bytes; const int64_t total_b = static_cast(k) * n; // Determine scale stride from scale tensor shape const auto& b_sinv = (transb == CUBLAS_OP_N) ? inputB->scale_inv @@ -1554,11 +1566,6 @@ void hipblaslt_gemm(const Tensor *inputA, update_tensor_scale_inv(outputD, stream); } - if (fp4_dequant_a_buf) - NVTE_CHECK_CUDA(hipFreeAsync(fp4_dequant_a_buf, stream)); - if (fp4_dequant_b_buf) - NVTE_CHECK_CUDA(hipFreeAsync(fp4_dequant_b_buf, stream)); - NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc)); NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc)); NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc)); diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 35fae5ac1..c61d1b887 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -36,10 +36,8 @@ def get_cublas_workspace_size_bytes() -> None: """Return workspace size needed for current architecture.""" if IS_HIP_EXTENSION: - """Return 64 MiB for gfx50x, 32 MiB for all other architectures.""" - if get_device_compute_capability() == (9, 5): - return 67_108_864 - return 33_554_432 + """Return 512 MiB (FP4 dequant buffers).""" + return 512 * 1024 * 1024 """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3aac641fa..7954c3e72 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -86,10 +86,8 @@ class UserBufferQuantizationMode(Enum): def get_cublas_workspace_size_bytes() -> None: """Return workspace size needed for current architecture""" if IS_HIP_EXTENSION: - """Return 64 MiB for gfx50x, 32 MiB for all other architectures.""" - if get_device_compute_capability() == (9, 5): - return 67_108_864 - return 33_554_432 + """Return 512 MiB (FP4 dequant buffers).""" + return 512 * 1024 * 1024 """Return 32 MiB if using hopper, 4 
MiB for all other architectures.""" if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales From f2690977ee5193375c1b3c781fcf12297de3ffc3 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 13 Apr 2026 15:30:57 -0500 Subject: [PATCH 54/69] enable tests --- ci/pytorch.sh | 2 ++ .../custom_recipes/quantization_nvfp4.py | 26 +++++++++++++++---- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/ci/pytorch.sh b/ci/pytorch.sh index d69009f1a..b9488da11 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -84,6 +84,8 @@ run_test_config(){ NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "amax+triton" 3 test_numerics.py NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "amax+triton" 3 test_fusible_ops.py NVTE_USE_ATOMIC_AMAX=1 run_default_fa_lbl "amax" 3 triton_kernels/test_cast.py + + run_default_fa 1 nvfp4/ } run_test_config_mgpu(){ diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index d00d0c8b9..9be40418b 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -12,6 +12,10 @@ from transformer_engine.pytorch.custom_recipes import quantization from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer +from torch.utils.cpp_extension import IS_HIP_EXTENSION + +if IS_HIP_EXTENSION: + from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, is_fp8_fnuz def nvfp4_ref_rht_2d_quantizer_factory(role): @@ -137,9 +141,15 @@ def cast_to_e4m3(decode_scale, global_amax): TODO(etsykunov): Make less unintuitive. 
""" decode_scale = decode_scale * global_amax - FLOAT8_E4M3_MAX = torch.tensor(448.0, device=decode_scale.device, dtype=torch.float32) + if IS_HIP_EXTENSION: + FLOAT8_E4M3_MAX = torch.tensor(240.0 if is_fp8_fnuz() else 448.0, device=decode_scale.device, dtype=torch.float32) + else: + FLOAT8_E4M3_MAX = torch.tensor(448.0, device=decode_scale.device, dtype=torch.float32) decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) - return decode_scale.to(torch.float8_e4m3fn) + if IS_HIP_EXTENSION: + return decode_scale.to(get_torch_float8_e4m3_type()) + else: + return decode_scale.to(torch.float8_e4m3fn) def high_precision_gemm_ref( @@ -470,7 +480,7 @@ def _quantize_blockwise_reference( ) # (128, 8, 1) x = x.view(m, n // tile_len_x, tile_len_x) FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) - FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) + FLOAT8_E4M3_MAX = torch.tensor(240.0 if is_fp8_fnuz() else 448.0, device=x.device, dtype=torch.float32) decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX) if pow_2_scales: @@ -503,7 +513,10 @@ def _quantize_blockwise_reference( ), ) decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) - decode_scale = decode_scale.to(torch.float8_e4m3fn) + if IS_HIP_EXTENSION: + decode_scale = decode_scale.to(get_torch_float8_e4m3_type()) + else: + decode_scale = decode_scale.to(torch.float8_e4m3fn) encode_scale = torch.min( torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale), @@ -823,7 +836,10 @@ def qgemm( sx = sx.to(torch.float32) sw = sw.to(torch.float32) - factor = 6.0 * 6.0 * 448.0 * 448.0 + if IS_HIP_EXTENSION: + factor = 6.0 * 6.0 * 240.0 * 240.0 if is_fp8_fnuz() else 6.0 * 6.0 * 448.0 * 448.0 + else: + factor = 6.0 * 6.0 * 448.0 * 448.0 if gemm_type == quantization.GEMMType.WGRAD: partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col From cf2c8f632803df838ed435990358becd7b665b12 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 16 Apr 2026 14:04:50 -0500 Subject: [PATCH 55/69] address reviewer comments --- .../nvfp4/test_nvfp4_rht_quantize_exact.py | 205 ++++++++---------- .../hadamard_transform/hadamard_transform.cu | 106 +++++---- transformer_engine/pytorch/csrc/pybind.h | 2 - 3 files changed, 144 insertions(+), 169 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index de9495edf..ef8f30773 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -18,6 +18,7 @@ from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.constants import TE_DType from transformer_engine.common.recipe import NVFP4BlockScaling +from torch.utils.cpp_extension import IS_HIP_EXTENSION import pytest import torch @@ -249,123 +250,103 @@ def test_nvfp4_quantization_noncontiguous_inputs( with_random_sign_mask=with_random_sign_mask, ) +if IS_HIP_EXTENSION: + def _ref_rht(x: torch.Tensor) -> torch.Tensor: + """Apply reference RHT using NVFP4QuantizerRef._apply_rht.""" + ref = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=False, + quant_tile_shape=(1, 16), + with_rht=True, + with_random_sign_mask=True, + ) + return ref._apply_rht(x) + + + def _ref_quantize_rht( + x: torch.Tensor, global_amax: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize BF16-rounded RHT(x.T) with the given global amax.""" + x_t_rht = 
_ref_rht(x.t().contiguous()).to(dtype=x.dtype) + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=False, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + with_rht=False, + with_random_sign_mask=False, + ) -def _ref_wht16_tiled(x: torch.Tensor, sign_mask: int) -> torch.Tensor: - """Reference 16-point WHT tiled along last dim, normalised by 0.25.""" - x = x.float() - _rows, cols = x.shape - d = torch.tensor( - [((-1) ** ((sign_mask >> i) & 1)) for i in range(16)], - dtype=torch.float32, device=x.device, - ) - out = x.clone() - for c in range(0, cols, 16): - tile = out[:, c:c+16] * d # apply sign - h = 1 - while h < 16: - for i in range(0, 16, h * 2): - a = tile[:, i:i+h].clone() - b = tile[:, i+h:i+2*h].clone() - tile[:, i:i+h] = a + b - tile[:, i+h:i+2*h] = a - b - h *= 2 - out[:, c:c+16] = tile * 0.25 - return out - - -def _ref_quantize_wht16_tiled( - x: torch.Tensor, sign_mask: int, global_amax: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - # Mirror the TE columnwise RHT path by BF16-rounding WHT(x.T) - # before applying NVFP4 reference quantization with the TE global amax. - - x_t_rht = _ref_wht16_tiled(x.t().contiguous(), sign_mask=sign_mask).to(dtype=x.dtype) - ref_quantizer = NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - rowwise=True, - columnwise=False, - pow_2_scales=False, - eps=0.0, - quant_tile_shape=(1, 16), - with_rht=False, - with_random_sign_mask=False, - ) - - x_t_rht_padded = ref_quantizer._pad_tensor( - x_t_rht, - row_divisor=ref_quantizer.quant_tile_shape[0], - col_divisor=ref_quantizer.quant_tile_shape[1], - ) - - qx_t_ref, sx_t_ref = ref_quantizer._quantize_blockwise_reference( - x_t_rht_padded, - global_amax, - ref_quantizer.quant_tile_shape[1], - ref_quantizer.quant_tile_shape[0], - pow_2_scales=ref_quantizer.pow_2_scales, - eps=ref_quantizer.eps, - ) - - qx_t_ref = ref_quantizer._rm_pad_tensor(qx_t_ref, (x_t_rht.shape[0], x_t_rht.shape[1] // 2)) - - return qx_t_ref, sx_t_ref - - -@pytest.mark.parametrize("rows,cols", [(64, 64), (128, 128)]) -def test_hadamard_transform_amax(rows, cols): - """ - Tests hadamard_transform_amax() via NVFP4Quantizer (with_rht=True), - without requiring a full NVFP4 recipe. - Checks: - - amax_rowwise == max|x| (pre-RHT amax of raw input) - - amax_colwise == max|WHT(x.T)| (post-RHT amax of transposed input) - - packed columnwise output == quantized BF16-rounded WHT(x.T) - """ - torch.manual_seed(42) - x = torch.randn((rows, cols), dtype=torch.bfloat16, device="cuda").contiguous() + qx_t_ref, sx_t_ref = ref_quantizer._quantize_blockwise_reference( + x_t_rht, + global_amax, + ref_quantizer.quant_tile_shape[1], + ref_quantizer.quant_tile_shape[0], + pow_2_scales=ref_quantizer.pow_2_scales, + eps=ref_quantizer.eps, + ) - quantizer = NVFP4Quantizer( - fp4_dtype=tex.DType.kFloat4E2M1, - rowwise=True, - columnwise=True, - with_amax_reduction=False, - amax_reduction_group=None, - with_rht=True, - with_post_rht_amax=True, - with_random_sign_mask=True, - ) - out = quantizer(x) - - # amax_rowwise: pre-RHT, should equal max|x| - expected_rowwise_amax = x.float().abs().max() - torch.testing.assert_close( - out._amax_rowwise.float().squeeze(), - expected_rowwise_amax, - rtol=0, atol=0, - ) + return qx_t_ref, sx_t_ref + + + @pytest.mark.parametrize("rows,cols", [(64, 64), (128, 128)]) + def test_hadamard_transform_amax(rows, cols): + """ + Tests hadamard_transform_amax() via NVFP4Quantizer (with_rht=True), + without requiring a full NVFP4 recipe. 
+ + Checks: + - amax_rowwise == max|x| (pre-RHT amax of raw input) + - amax_colwise == max|WHT(x.T)| (post-RHT amax of transposed input) + - packed columnwise output == quantized BF16-rounded WHT(x.T) + """ + torch.manual_seed(42) + x = torch.randn((rows, cols), dtype=torch.bfloat16, device="cuda").contiguous() + + quantizer = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=True, + ) + out = quantizer(x) + + # amax_rowwise: pre-RHT, should equal max|x| + expected_rowwise_amax = x.float().abs().max() + torch.testing.assert_close( + out._amax_rowwise.float().squeeze(), + expected_rowwise_amax, + rtol=0, atol=0, + ) - # amax_colwise: post-RHT of x.T, should equal max|WHT(x.T)| - sign_mask_t = quantizer.rht_matrix_random_sign_mask_t - x_t = x.t().contiguous() # (cols, rows) - wht_x_t = _ref_wht16_tiled(x_t, sign_mask=sign_mask_t).to(torch.bfloat16).float() - expected_colwise_amax = wht_x_t.float().abs().max() + # amax_colwise: post-RHT of x.T, should equal max|WHT(x.T)| + x_t = x.t().contiguous() # (cols, rows) + wht_x_t = _ref_rht(x_t).to(torch.bfloat16).float() + expected_colwise_amax = wht_x_t.float().abs().max() - torch.testing.assert_close( - out._amax_columnwise.float().squeeze().item(), - float(expected_colwise_amax), - rtol=0, atol=0, - ) + torch.testing.assert_close( + out._amax_columnwise.float().squeeze().item(), + float(expected_colwise_amax), + rtol=0, atol=0, + ) - assert out._columnwise_data is not None - assert out._columnwise_scale_inv is not None + assert out._columnwise_data is not None + assert out._columnwise_scale_inv is not None - qx_t_ref, sx_t_ref = _ref_quantize_wht16_tiled(x, sign_mask_t, out._amax_columnwise) + qx_t_ref, sx_t_ref = _ref_quantize_rht(x, out._amax_columnwise) - qx_t = unpack_fp4(out._columnwise_data.view(torch.uint8)) - qx_t_ref = unpack_fp4(qx_t_ref.view(torch.uint8)) - torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + qx_t = unpack_fp4(out._columnwise_data.view(torch.uint8)) + qx_t_ref = unpack_fp4(qx_t_ref.view(torch.uint8)) + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) - sx_t = out._columnwise_scale_inv - sx_t_ref = sx_t_ref.view(dtype=torch.uint8) - sx_t_valid = sx_t[: sx_t_ref.shape[0], : sx_t_ref.shape[1]] - torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + sx_t = out._columnwise_scale_inv + sx_t_ref = sx_t_ref.view(dtype=torch.uint8) + sx_t_valid = sx_t[: sx_t_ref.shape[0], : sx_t_ref.shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 8472fea18..d5d403553 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -25,9 +25,8 @@ namespace transformer_engine { namespace { -constexpr int kThreadsPerWarp = 32; - #ifndef __HIP_PLATFORM_AMD__ +constexpr int kThreadsPerWarp = 32; template __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], @@ -491,21 +490,37 @@ __global__ void HadamardTransformKernel(const T* __restrict__ input, T* __restri } #endif // __HIP_PLATFORM_AMD__ -} // namespace #ifdef __HIP_PLATFORM_AMD__ -namespace { - -static constexpr int kHadamardDim = 16; -static constexpr int kWarpSize = 64; -static 
constexpr int kThreadsPerWHT = 4; -static constexpr int kElemsPerThread = 4; +// Tiling / layout constants +// +// A 16-point WHT operates on tiles of kHadamardDim (16) elements. +// Each tile is processed by kThreadsPerWHT (4) threads, each holding +// kElemsPerThread (4) values, so one wavefront of kWarpSize (64) lanes +// handles kRowsPerWarp (16) independent tiles (= rows) simultaneously. +// kWarpsPerBlock wavefronts are combined into a thread-block that covers +// kRowsPerBlock (64) consecutive rows. +static constexpr int kHadamardDim = 16; // WHT dimension (H16) +static constexpr int kWarpSize = 64; // Wavefront width +static constexpr int kThreadsPerWHT = 4; // threads per 16-pt WHT +static constexpr int kElemsPerThread = 4; // elements each thread owns static constexpr int kRowsPerWarp = kWarpSize / kThreadsPerWHT; // 16 static constexpr int kWarpsPerBlock = 4; static constexpr int kRowsPerBlock = kRowsPerWarp * kWarpsPerBlock; // 64 static constexpr int kThreadsPerBlock = kWarpSize * kWarpsPerBlock; // 256 -static constexpr float kHadamardScale = 0.25f; +static constexpr float kHadamardScale = 0.25f; // 1/sqrt(16) + +// Reduce per-warp amax values in warp 0 and atomically update a global amax. +__device__ __forceinline__ void reduce_block_amax( + const float* __restrict__ warp_amax, int lane_id, + float* __restrict__ global_amax) { + float val = (lane_id < kWarpsPerBlock) ? warp_amax[lane_id] : 0.f; + for (int off = kWarpSize / 2; off >= 1; off >>= 1) + val = fmaxf(val, __shfl_xor(val, off)); + if (lane_id == 0) + atomicMaxFloat(global_amax, val); +} // ds_swizzle: sub-wavefront exchange without LDS. // Same instructions as cast_transpose_mxfp4_kernel_shuffled.cu. @@ -675,19 +690,20 @@ void HadamardTransformKernel( // Identity path: WHT along row dimension if constexpr (kComputeIdentity || kUpdateAmax) { float r0=v0, r1=v1, r2=v2, r3=v3; - float lam = 0.f; + float local_amax = 0.f; if (global_row < num_rows) { wht16(r0, r1, r2, r3, thread_in_grp, random_sign_mask, apply_pre); if constexpr (kUpdateAmax) { - // Match the stored/output precision when reporting amax. + // Down-cast to BF16 and back so the amax matches the + // stored/output precision (matches upstream NV behaviour). const float r0_bf16 = to_f32(to_bf16(r0)); const float r1_bf16 = to_f32(to_bf16(r1)); const float r2_bf16 = to_f32(to_bf16(r2)); const float r3_bf16 = to_f32(to_bf16(r3)); - lam = fmaxf(fmaxf(fabsf(r0_bf16), fabsf(r1_bf16)), - fmaxf(fabsf(r2_bf16), fabsf(r3_bf16))); + local_amax = fmaxf(fmaxf(fabsf(r0_bf16), fabsf(r1_bf16)), + fmaxf(fabsf(r2_bf16), fabsf(r3_bf16))); for (int off=kWarpSize/2; off>=1; off>>=1) - lam=fmaxf(lam,__shfl_xor(lam,off)); + local_amax=fmaxf(local_amax,__shfl_xor(local_amax,off)); } if constexpr (kComputeIdentity) if (output && in_bounds) @@ -696,7 +712,7 @@ void HadamardTransformKernel( } if constexpr (kUpdateAmax) { if (lane_id == 0) - block_amax[warp_id] = lam; + block_amax[warp_id] = local_amax; } } @@ -704,7 +720,7 @@ void HadamardTransformKernel( if constexpr (kComputeTransposed || kUpdateAmaxT) { const int local_row = warp_id * kRowsPerWarp + row_in_warp; const int col_offset = thread_in_grp * kElemsPerThread; - float lam = 0.f; + float local_amax = 0.f; smem[local_row][col_offset+0] = to_bf16(global_row < num_rows ? v0 : 0.f); smem[local_row][col_offset+1] = to_bf16(global_row < num_rows ? v1 : 0.f); smem[local_row][col_offset+2] = to_bf16(global_row < num_rows ? 
v2 : 0.f); @@ -721,16 +737,17 @@ void HadamardTransformKernel( wht16(c0, c1, c2, c3, thread_in_grp, random_sign_mask_t, apply_pre); if constexpr (kUpdateAmaxT) { - // Match the stored/output precision when reporting amax. + // Down-cast to BF16 and back so the amax matches the + // stored/output precision (matches upstream NV behaviour). const float c0_bf16 = to_f32(to_bf16(c0)); const float c1_bf16 = to_f32(to_bf16(c1)); const float c2_bf16 = to_f32(to_bf16(c2)); const float c3_bf16 = to_f32(to_bf16(c3)); - lam = fmaxf(fmaxf(fabsf(c0_bf16), fabsf(c1_bf16)), - fmaxf(fabsf(c2_bf16), fabsf(c3_bf16))); + local_amax = fmaxf(fmaxf(fabsf(c0_bf16), fabsf(c1_bf16)), + fmaxf(fabsf(c2_bf16), fabsf(c3_bf16))); for (int off=kWarpSize/2; off>=1; off>>=1) - lam=fmaxf(lam,__shfl_xor(lam,off)); + local_amax=fmaxf(local_amax,__shfl_xor(local_amax,off)); } if constexpr (kComputeTransposed) { @@ -747,7 +764,7 @@ void HadamardTransformKernel( if constexpr (kUpdateAmaxT) { if (lane_id == 0) - block_amax_t[warp_id] = lam; + block_amax_t[warp_id] = local_amax; } } @@ -755,25 +772,10 @@ void HadamardTransformKernel( __syncthreads(); if (warp_id == 0) { - if constexpr (kUpdateAmax) { - float block_lam = (lane_id < kWarpsPerBlock) ? block_amax[lane_id] : 0.f; - - for (int off=kWarpSize/2; off>=1; off>>=1) - block_lam = fmaxf(block_lam, __shfl_xor(block_lam, off)); - - if (lane_id == 0) - atomicMaxFloat(amax, block_lam); - } - - if constexpr (kUpdateAmaxT) { - float block_lam_t = (lane_id < kWarpsPerBlock) ? block_amax_t[lane_id] : 0.f; - - for (int off=kWarpSize/2; off>=1; off>>=1) - block_lam_t = fmaxf(block_lam_t, __shfl_xor(block_lam_t, off)); - - if (lane_id == 0) - atomicMaxFloat(amax_t, block_lam_t); - } + if constexpr (kUpdateAmax) + reduce_block_amax(block_amax, lane_id, amax); + if constexpr (kUpdateAmaxT) + reduce_block_amax(block_amax_t, lane_id, amax_t); } } } @@ -782,29 +784,23 @@ void HadamardTransformKernel( __global__ void PreRhtAmaxKernel(const __hip_bfloat16* __restrict__ input, float* __restrict__ amax_out, uint64_t num_elems) { __shared__ float block_amax[kWarpsPerBlock]; - float lam = 0.f; + float local_amax = 0.f; for (uint64_t i = (uint64_t)blockIdx.x*blockDim.x+threadIdx.x; i < num_elems; i += (uint64_t)gridDim.x*blockDim.x) - lam = fmaxf(lam, fabsf(to_f32(input[i]))); + local_amax = fmaxf(local_amax, fabsf(to_f32(input[i]))); for (int off=kWarpSize/2; off>=1; off>>=1) - lam=fmaxf(lam,__shfl_xor(lam,off)); + local_amax=fmaxf(local_amax,__shfl_xor(local_amax,off)); const int warp_id = threadIdx.x / kWarpSize; const int lane_id = threadIdx.x % kWarpSize; if (lane_id == 0) - block_amax[warp_id] = lam; + block_amax[warp_id] = local_amax; __syncthreads(); - if (warp_id == 0) { - float block_lam = (lane_id < kWarpsPerBlock) ? 
block_amax[lane_id] : 0.f; - for (int off=kWarpSize/2; off>=1; off>>=1) - block_lam=fmaxf(block_lam,__shfl_xor(block_lam,off)); - - if (lane_id == 0) - atomicMaxFloat(amax_out, block_lam); - } + if (warp_id == 0) + reduce_block_amax(block_amax, lane_id, amax_out); } static inline dim3 transform_grid(uint64_t num_rows, uint64_t row_length) { @@ -812,10 +808,10 @@ static inline dim3 transform_grid(uint64_t num_rows, uint64_t row_length) { (uint32_t)DIVUP(num_rows, (uint64_t)kRowsPerBlock)); } -} // namespace - #endif // __HIP_PLATFORM_AMD__ +} // namespace + void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, uint16_t random_sign_mask_t, cudaStream_t stream) { NVTE_API_CALL(hadamard_transform); diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index b924e8a77..25ffef058 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -1,6 +1,4 @@ /************************************************************************* - * This file was modified for portability to AMDGPU - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. From 27728348a47a6795766191a31fb22be988026733 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 16 Apr 2026 14:11:11 -0500 Subject: [PATCH 56/69] minor fixes --- .../common/hadamard_transform/hadamard_transform.cu | 1 + .../pytorch/custom_recipes/quantization_nvfp4.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index d5d403553..451ddf969 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -27,6 +27,7 @@ namespace { #ifndef __HIP_PLATFORM_AMD__ constexpr int kThreadsPerWarp = 32; + template __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index 9be40418b..73a10308d 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -480,7 +482,10 @@ def _quantize_blockwise_reference( ) # (128, 8, 1) x = x.view(m, n // tile_len_x, tile_len_x) FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) - FLOAT8_E4M3_MAX = torch.tensor(240.0 if is_fp8_fnuz() else 448.0, device=x.device, dtype=torch.float32) + if IS_HIP_EXTENSION: + FLOAT8_E4M3_MAX = torch.tensor(240.0 if is_fp8_fnuz() else 448.0, device=x. 
device, dtype=torch.float32) + else: + FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX) if pow_2_scales: From 26c5cb14efc7983c326a1c3c811b5705ae524afa Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 16 Apr 2026 15:48:11 -0500 Subject: [PATCH 57/69] PreRhtAmax optimizations --- .../hadamard_transform/hadamard_transform.cu | 90 ++++++++++--------- transformer_engine/pytorch/csrc/quantizer.cpp | 87 ++++++++++++++---- 2 files changed, 118 insertions(+), 59 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 451ddf969..ef60a13be 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -653,7 +653,8 @@ __device__ __forceinline__ void wht16( // Block: 256 threads = 4 wavefronts of 64 lanes. // lane/4 = row_in_warp (0..15), lane%4 = thread_in_grp (0..3) template + bool kUpdateAmax, bool kUpdateAmaxT, + bool kUpdatePreRhtAmax = false> __global__ __launch_bounds__(kThreadsPerBlock, 4) void HadamardTransformKernel( const __hip_bfloat16* __restrict__ input, @@ -662,6 +663,7 @@ void HadamardTransformKernel( uint16_t random_sign_mask, uint16_t random_sign_mask_t, uint64_t num_rows, uint64_t row_length, float* __restrict__ amax, float* __restrict__ amax_t, + float* __restrict__ pre_rht_amax, bool inverse_hadamard) { const int tid = threadIdx.x; const int warp_id = tid / kWarpSize; @@ -681,6 +683,7 @@ void HadamardTransformKernel( __shared__ float block_amax[kWarpsPerBlock]; __shared__ float block_amax_t[kWarpsPerBlock]; + __shared__ float block_pre_rht_amax[kWarpsPerBlock]; float v0=0.f, v1=0.f, v2=0.f, v3=0.f; if (in_bounds) { @@ -688,6 +691,19 @@ void HadamardTransformKernel( &input[global_row * row_length + col_base]), v0, v1, v2, v3); } + // Pre-RHT amax: max|input| before any transform + if constexpr (kUpdatePreRhtAmax) { + float local_pre = 0.f; + if (in_bounds) { + local_pre = fmaxf(fmaxf(fabsf(v0), fabsf(v1)), + fmaxf(fabsf(v2), fabsf(v3))); + } + for (int off = kWarpSize / 2; off >= 1; off >>= 1) + local_pre = fmaxf(local_pre, __shfl_xor(local_pre, off)); + if (lane_id == 0) + block_pre_rht_amax[warp_id] = local_pre; + } + // Identity path: WHT along row dimension if constexpr (kComputeIdentity || kUpdateAmax) { float r0=v0, r1=v1, r2=v2, r3=v3; @@ -769,7 +785,7 @@ void HadamardTransformKernel( } } - if constexpr (kUpdateAmax || kUpdateAmaxT) { + if constexpr (kUpdateAmax || kUpdateAmaxT || kUpdatePreRhtAmax) { __syncthreads(); if (warp_id == 0) { @@ -777,33 +793,12 @@ void HadamardTransformKernel( reduce_block_amax(block_amax, lane_id, amax); if constexpr (kUpdateAmaxT) reduce_block_amax(block_amax_t, lane_id, amax_t); + if constexpr (kUpdatePreRhtAmax) + reduce_block_amax(block_pre_rht_amax, lane_id, pre_rht_amax); } } } -// Pre-RHT amax: max|input| before any transform. 
-__global__ void PreRhtAmaxKernel(const __hip_bfloat16* __restrict__ input, - float* __restrict__ amax_out, uint64_t num_elems) { - __shared__ float block_amax[kWarpsPerBlock]; - float local_amax = 0.f; - for (uint64_t i = (uint64_t)blockIdx.x*blockDim.x+threadIdx.x; - i < num_elems; i += (uint64_t)gridDim.x*blockDim.x) - local_amax = fmaxf(local_amax, fabsf(to_f32(input[i]))); - - for (int off=kWarpSize/2; off>=1; off>>=1) - local_amax=fmaxf(local_amax,__shfl_xor(local_amax,off)); - - const int warp_id = threadIdx.x / kWarpSize; - const int lane_id = threadIdx.x % kWarpSize; - if (lane_id == 0) - block_amax[warp_id] = local_amax; - - __syncthreads(); - - if (warp_id == 0) - reduce_block_amax(block_amax, lane_id, amax_out); -} - static inline dim3 transform_grid(uint64_t num_rows, uint64_t row_length) { return dim3((uint32_t)(row_length / kHadamardDim), (uint32_t)DIVUP(num_rows, (uint64_t)kRowsPerBlock)); @@ -896,7 +891,7 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s reinterpret_cast(output.dptr), reinterpret_cast(output_t.dptr), random_sign_mask, random_sign_mask_t, static_cast(num_rows), - static_cast(row_length), nullptr, nullptr, false););); + static_cast(row_length), nullptr, nullptr, nullptr, false););); #else auto kernel = HadamardTransformKernel(output_pre_rht_amax.dptr); auto* id_amax_ptr = reinterpret_cast(output_identity_amax.dptr); auto* tr_amax_ptr = reinterpret_cast(output_transpose_amax.dptr); + // Only use output_.data as transposed output buffer if it's actually BF16. + // When called from the fused path, output_.data is a BF16 buffer for RHT output. + // When called from the non-fused path, output_.data is an FP4 buffer (wrong type/size). + auto* out_t_ptr = (output_t.dtype == transformer_engine::DType::kBFloat16) + ? reinterpret_cast<__hip_bfloat16*>(output_t.dptr) + : static_cast<__hip_bfloat16*>(nullptr); NVTE_CHECK(row_length % kHadamardDim == 0, "row_length must be divisible by 16."); NVTE_CHECK(num_rows % kHadamardDim == 0, "num_rows must be divisible by 16."); @@ -981,13 +987,6 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran if (tr_amax_ptr) { NVTE_CHECK_CUDA(cudaMemsetAsync(tr_amax_ptr, 0, sizeof(float), stream)); } - - if (return_pre_rht_amax) { - const uint64_t num_elems = static_cast(num_rows) * row_length; - dim3 grid(DIVUP(num_elems, static_cast(kThreadsPerBlock))); - PreRhtAmaxKernel<<>>(in_ptr, pre_amax_ptr, num_elems); - NVTE_CHECK_CUDA(cudaGetLastError()); - } #else constexpr int kHadamardDimension = 16; NVTE_CHECK(row_length % kHadamardDimension == 0, @@ -1030,14 +1029,25 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran return_identity_amax, kReturnIdentityAmax, #ifdef __HIP_PLATFORM_AMD__ - if (kReturnIdentityAmax || kReturnTransposedAmax) { + { + // Compute transposed path if we need transposed amax or have an output buffer. 
+ const bool compute_transposed = kReturnTransposedAmax || (out_t_ptr != nullptr); dim3 grid = transform_grid(num_rows, row_length), block(kThreadsPerBlock); - HadamardTransformKernel - <<>>(in_ptr, nullptr, nullptr, random_sign_mask, - random_sign_mask_t, static_cast(num_rows), - static_cast(row_length), id_amax_ptr, - tr_amax_ptr, false); + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_pre_rht_amax, kReturnPreRhtAmax, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + compute_transposed, kComputeTransposed, + if (kReturnIdentityAmax || kComputeTransposed || kReturnPreRhtAmax) { + HadamardTransformKernel + <<>>( + in_ptr, nullptr, out_t_ptr, random_sign_mask, + random_sign_mask_t, static_cast(num_rows), + static_cast(row_length), id_amax_ptr, + tr_amax_ptr, pre_amax_ptr, false); + } + )); } )); #else diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index f8b575971..be6474c1d 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1507,18 +1507,54 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou #endif // Compute amax. +#ifdef USE_ROCM + // Allocate rht_output_t early so that the amax kernel can also write the + // transposed RHT output in the same kernel launch (fused amax + transform). + at::Tensor rht_output_t; +#endif if (this->with_rht) { if (input.dtype() != DType::kBFloat16) { NVTE_CHECK(false, "RHT is only supported for bfloat16 input"); } + +#ifdef USE_ROCM + // Pre-allocate if we'll need the transposed RHT output later + if (this->columnwise_usage && !eligible_for_rht_cast_fusion) { + rht_output_t = + allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); + } +#endif + if (this->with_post_rht_amax) { // We need: // 1. Rowwise amax = amax for input // 2. Columnwise amax = amax for RHT(input.t) - NVTE_SCOPED_GIL_RELEASE({ - nvte_hadamard_transform_amax(input.data(), out.data(), 0, - this->rht_matrix_random_sign_mask_t, stream); - }); +#ifdef USE_ROCM + if (rht_output_t.defined()) { + // Fused path: compute amax AND write transposed RHT output in one kernel. + // Create a wrapper with amax fields from out + data pointing to rht_output_t. + TensorWrapper amax_and_transform(out.scaling_mode()); + auto out_amax = out.get_amax(); + auto out_col_amax = out.get_columnwise_amax(); + amax_and_transform.set_amax(out_amax.data_ptr, static_cast(out_amax.dtype), + out_amax.shape); + amax_and_transform.set_columnwise_amax(out_col_amax.data_ptr, + static_cast(out_col_amax.dtype), + out_col_amax.shape); + amax_and_transform.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), + std::vector{cols, rows}); + NVTE_SCOPED_GIL_RELEASE({ + nvte_hadamard_transform_amax(input.data(), amax_and_transform.data(), 0, + this->rht_matrix_random_sign_mask_t, stream); + }); + } else +#endif + { + NVTE_SCOPED_GIL_RELEASE({ + nvte_hadamard_transform_amax(input.data(), out.data(), 0, + this->rht_matrix_random_sign_mask_t, stream); + }); + } } else { // raise error since it's not supported yet NVTE_CHECK(false, "Pre-RHT amax is not supported yet"); @@ -1631,26 +1667,39 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou need_separate_columnwise_rng ? quant_config_columnwise : quant_config; if (!eligible_for_rht_cast_fusion) { - // Invoking fallback RHT kernel. 
- - // If using RHT, then amax will be computed in the RHT step - // If not using RHT, then amax will be computed based on input x - at::Tensor rht_output_t; // The RHT(x_t) output, in columnwise layout - // This wrapper is going to be passed as input to the quantization kernel. - TensorWrapper rht_output_t_cpp; // Wrapper to contain the RHT(x) and RHT(x_t) outputs +#ifdef USE_ROCM + // If rht_output_t was already produced by the fused amax+transform kernel above, + // skip the separate hadamard_transform call. + if (!rht_output_t.defined()) { + rht_output_t = + allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); + TensorWrapper rht_output_t_cpp; + rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), + std::vector{cols, rows}); + NVTE_SCOPED_GIL_RELEASE({ + nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, + this->rht_matrix_random_sign_mask_t, stream); + }); + } +#else + at::Tensor rht_output_t; rht_output_t = allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); - // NOTE (frsun): This is non-intuitive, we are writing the - // result of transposed RHT to the output of rowwise. + { + TensorWrapper rht_output_t_cpp; + rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), + std::vector{cols, rows}); + NVTE_SCOPED_GIL_RELEASE({ + nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, + this->rht_matrix_random_sign_mask_t, stream); + }); + } +#endif + + TensorWrapper rht_output_t_cpp; rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), std::vector{cols, rows}); - NVTE_SCOPED_GIL_RELEASE({ - // Perform the RHT(input.t), and write to rht_output_cpp.columnwise. - nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, - this->rht_matrix_random_sign_mask_t, stream); - }); - // Quantize kernel will treat everything as rowwise input/output, which is // intended. 
NVTE_SCOPED_GIL_RELEASE({ From 018d24fd74315b27998b53595457207e32fb4852 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 17 Apr 2026 10:12:28 -0500 Subject: [PATCH 58/69] use ZeroAmaxKernel --- .../common/hadamard_transform/hadamard_transform.cu | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index ef60a13be..0631f326b 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -166,6 +166,7 @@ __device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float } } } +#endif __launch_bounds__(1) __global__ void ZeroAmaxKernel(float* __restrict__ output_pre_rht_amax_ptr, float* __restrict__ output_identity_amax_ptr, @@ -181,6 +182,7 @@ __launch_bounds__(1) __global__ void ZeroAmaxKernel(float* __restrict__ output_p } } +#ifndef __HIP_PLATFORM_AMD__ template @@ -978,15 +980,8 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran auto* in_ptr = reinterpret_cast(input.dptr); - if (pre_amax_ptr) { - NVTE_CHECK_CUDA(cudaMemsetAsync(pre_amax_ptr, 0, sizeof(float), stream)); - } - if (id_amax_ptr) { - NVTE_CHECK_CUDA(cudaMemsetAsync(id_amax_ptr, 0, sizeof(float), stream)); - } - if (tr_amax_ptr) { - NVTE_CHECK_CUDA(cudaMemsetAsync(tr_amax_ptr, 0, sizeof(float), stream)); - } + ZeroAmaxKernel<<<1, 1, 0, stream>>>(pre_amax_ptr, id_amax_ptr, tr_amax_ptr); + NVTE_CHECK_CUDA(cudaGetLastError()); #else constexpr int kHadamardDimension = 16; NVTE_CHECK(row_length % kHadamardDimension == 0, From feff829378980c6fa98613f3ed9374e5a7df46e0 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 20 Apr 2026 12:50:16 -0500 Subject: [PATCH 59/69] undo hadamard_fusion --- .../hadamard_transform_cast_fusion.cu | 275 +----------------- 1 file changed, 2 insertions(+), 273 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index cc970aab5..0696deaaa 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -1,6 +1,4 @@ /************************************************************************* - * This file was modified for portability to AMDGPU - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -11,24 +9,19 @@ #include #include #include -#ifndef __HIP_PLATFORM_AMD__ #include -#endif #include #include -#ifndef __HIP_PLATFORM_AMD__ #include #include #include -#endif #include "common/common.h" #include "common/util/cuda_runtime.h" #include "common/util/curanddx.hpp" #include "common/util/ptx.cuh" #include "common/utils.cuh" -#ifndef __HIP_PLATFORM_AMD__ #include "cutlass/arch/barrier.h" #include "cutlass/cutlass.h" #include "cutlass/gemm/collective/builders/sm100_common.inl" @@ -37,11 +30,9 @@ #include "cutlass/util/GPU_Clock.hpp" #include "cutlass/util/command_line.h" #include "cutlass/util/print_error.hpp" -#endif // clang-format off -#ifndef __HIP_PLATFORM_AMD__ namespace transformer_engine { namespace detail { namespace { @@ -735,246 +726,6 @@ rht_gemm_ttt_wrapper(int m, int n, // clang-format on -} // namespace transformer_engine -#else - -#include "wht16.cuh" - -namespace transformer_engine { - -namespace { - -__device__ __forceinline__ float to_f32(__hip_bfloat16 v) { return static_cast(v); } - -__device__ __forceinline__ float group_max_4(float v) { - v = fmaxf(v, ds_swizzle_xor1(v)); - v = fmaxf(v, ds_swizzle_xor2(v)); - return v; -} - -__device__ __forceinline__ float compute_global_encode_scale_fp4(const float global_amax) { -#if !defined(__HIP_DEVICE_COMPILE__) - const float fp8_max = detail::TypeExtrema::max; -#else - constexpr float fp8_max = detail::TypeExtrema::max; -#endif - constexpr float fp4_max = detail::TypeExtrema::max; - float global_encode_scale = fp8_max * fp4_max / global_amax; - global_encode_scale = fminf(global_encode_scale, detail::TypeExtrema::max); - return (global_amax == 0.f || global_encode_scale == 0.f) ? 1.f : global_encode_scale; -} - -template -__device__ __forceinline__ ScaleType compute_decode_scale_fp4(const float amax, - const float global_encode_scale) { - float decode_scale = amax / detail::TypeExtrema::max; - decode_scale *= global_encode_scale; - decode_scale = fminf(decode_scale, detail::TypeExtrema::max); - return static_cast(decode_scale); -} - -template -__device__ __forceinline__ float compute_encode_scale_fp4(ScaleType decode_scale, - const float global_decode_scale) { - return fminf(1.0f / (static_cast(decode_scale) * global_decode_scale), - detail::TypeExtrema::max); -} - -__device__ __forceinline__ uint32_t get_rbits( - transformer_engine::curanddx::detail::philox4x32_native_state<10>& rng, uint4& random_uint4, - int& rnd_idx) { - if (rnd_idx == 4) { - rnd_idx = 0; - random_uint4 = rng.generate4(); - } - const uint32_t* const rbits_arr = reinterpret_cast(&random_uint4); - return rbits_arr[rnd_idx++]; -} - -template -__device__ __forceinline__ fp4e2m1x4 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, - const uint32_t rbits) { - if constexpr (kUseStochasticRounding) { -#if ARCH_HAS_STOCHASTIC_ROUNDING - union { - uint32_t ui32; - __hip_fp4x2_storage_t fp4x2[4]; - } packed{0}; - __amd_floatx2_storage_t packed01{in01.x, in01.y}; - __amd_floatx2_storage_t packed23{in23.x, in23.y}; - packed.ui32 = - __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(packed.ui32, packed01, rbits, 1.0f, 1); - const __hip_fp4x2_storage_t lo = packed.fp4x2[1]; - packed.ui32 = - __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(packed.ui32, packed23, rbits, 1.0f, 1); - const __hip_fp4x2_storage_t hi = packed.fp4x2[1]; - - fp4e2m1x4 result; - result.__x = static_cast<__hip_fp4x4_storage_t>( - lo | (static_cast<__hip_fp4x4_storage_t>(hi) << 8)); - return result; -#else - // Software stochastic rounding fallback for AMD GPUs without native - // FP4 SR 
instructions (e.g. gfx942). - // - // FP4 E2M1 has 8 non-negative magnitudes whose 3-bit codes happen to - // be sorted: {0->0.0, 1->0.5, 2->1.0, 3->1.5, 4->2.0, 5->3.0, - // 6->4.0, 7->6.0}. - // - // For each value we: - // 1. Clamp |x| into [0, 6] (the FP4 representable range). - // 2. Find the floor index fi in the FP4 grid via branchless - // comparisons (sum of (|x| >= threshold) for each level). - // 3. Compute the fractional position within [kV[fi], kV[ci]] - // where ci = min(fi+1, 7) is the ceiling index. - // 4. Draw a uniform random value r in [0,1) from 8 bits of rbits. - // 5. Round up to ci if r < frac, otherwise keep fi. - // This gives E[round(x)] = x (unbiased). - // 6. Set the sign bit (bit 3) if the original value was negative. - { - constexpr float kV[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f}; - const float vals[4] = {in01.x, in01.y, in23.x, in23.y}; - __hip_fp4_storage_t q[4]; -#pragma unroll - for (int i = 0; i < 4; ++i) { - const float av = fminf(fabsf(vals[i]), 6.0f); - const int fi = int(av >= 0.5f) + int(av >= 1.0f) + int(av >= 1.5f) + - int(av >= 2.0f) + int(av >= 3.0f) + int(av >= 4.0f) + int(av >= 6.0f); - const int ci = min(fi + 1, 7); - const float gap = kV[ci] - kV[fi]; - const float frac = (gap > 0.0f) ? (av - kV[fi]) / gap : 0.0f; - const float r = static_cast((rbits >> (8 * i)) & 0xFFu) * (1.0f / 256.0f); - const int ri = (r < frac) ? ci : fi; - q[i] = static_cast<__hip_fp4_storage_t>((vals[i] < 0.0f) ? (ri | 0x8) : ri); - } - fp4e2m1x4 result; - result.__x = static_cast<__hip_fp4x4_storage_t>( - (q[0] & 0xFu) | ((q[1] & 0xFu) << 4) | - ((q[2] & 0xFu) << 8) | ((q[3] & 0xFu) << 12)); - return result; - } -#endif - } else { - const __hip_fp4_storage_t q0 = - __hip_cvt_float_to_fp4(in01.x, __HIP_E2M1, hipRoundNearest); - const __hip_fp4_storage_t q1 = - __hip_cvt_float_to_fp4(in01.y, __HIP_E2M1, hipRoundNearest); - const __hip_fp4_storage_t q2 = - __hip_cvt_float_to_fp4(in23.x, __HIP_E2M1, hipRoundNearest); - const __hip_fp4_storage_t q3 = - __hip_cvt_float_to_fp4(in23.y, __HIP_E2M1, hipRoundNearest); - - fp4e2m1x4 result; - result.__x = static_cast<__hip_fp4x4_storage_t>((q0 & 0xFu) | ((q1 & 0xFu) << 4) | - ((q2 & 0xFu) << 8) | ((q3 & 0xFu) << 12)); - return result; - } -} - -__device__ __forceinline__ uint16_t fp4x4_to_bits(fp4e2m1x4 v) { - uint16_t bits; - __builtin_memcpy(&bits, &v, sizeof(bits)); - return bits; -} - -template -__global__ __launch_bounds__(kThreadsPerBlock, 4) void HadamardTransformCastFusionKernel( - const __hip_bfloat16* __restrict__ input, uint8_t* __restrict__ output_t, - fp8e4m3* __restrict__ scale_inv_t, const float* __restrict__ global_amax_ptr, - const __hip_bfloat16* __restrict__ hadamard_matrix, const uint64_t num_rows, - const uint64_t row_length, const size_t scale_stride, const size_t* rng_state) { - - // Thread 0 loads global_amax and computes random sign mask - __shared__ uint16_t s_random_sign_mask; - __shared__ float s_global_amax; - if (threadIdx.x == 0) { - s_global_amax = *global_amax_ptr; - uint16_t mask = 0; - for (int row = 0; row < kHadamardDim; ++row) { - mask |= static_cast((to_f32(hadamard_matrix[row * kHadamardDim]) < 0.0f ? 
1u : 0u) << row); - } - s_random_sign_mask = mask; - } - __syncthreads(); - const float global_amax = s_global_amax; - - const int tid = threadIdx.x; - const int warp_id = tid / kWarpSize; - const int lane_id = tid % kWarpSize; - const int row_in_warp = lane_id / kThreadsPerWHT; - const int thread_in_grp = lane_id % kThreadsPerWHT; - - const uint64_t output_row = static_cast(blockIdx.x) * kHadamardDim + row_in_warp; - const uint64_t block_row_base = - static_cast(blockIdx.y) * kRowsPerBlock + warp_id * kHadamardDim; - - if (block_row_base + kHadamardDim > num_rows) { - return; - } - - const uint64_t input_row_base = block_row_base + thread_in_grp * kElemsPerThread; - const uint64_t input_col = output_row; - - float c0 = to_f32(input[(input_row_base + 0) * row_length + input_col]); - float c1 = to_f32(input[(input_row_base + 1) * row_length + input_col]); - float c2 = to_f32(input[(input_row_base + 2) * row_length + input_col]); - float c3 = to_f32(input[(input_row_base + 3) * row_length + input_col]); - - wht16(c0, c1, c2, c3, thread_in_grp, s_random_sign_mask, /*apply_pre=*/true); - - // Truncate to BF16 precision to match the reference BF16 matmul path. - // Without this, FP32 WHT results at FP4 quantization boundaries round - // differently than the BF16-precision reference, causing off-by-one errors. - c0 = to_f32(static_cast<__hip_bfloat16>(c0)); - c1 = to_f32(static_cast<__hip_bfloat16>(c1)); - c2 = to_f32(static_cast<__hip_bfloat16>(c2)); - c3 = to_f32(static_cast<__hip_bfloat16>(c3)); - - const float local_block_amax = - fmaxf(fmaxf(fabsf(c0), fabsf(c1)), fmaxf(fabsf(c2), fabsf(c3))); - const float block_amax = group_max_4(local_block_amax); - - const float global_encode_scale = compute_global_encode_scale_fp4(global_amax); - const float global_decode_scale = 1.0f / global_encode_scale; - const fp8e4m3 scale_inv = compute_decode_scale_fp4(block_amax, global_encode_scale); - const float encode_scale = compute_encode_scale_fp4(scale_inv, global_decode_scale); - - if (thread_in_grp == 0) { - const uint64_t scale_col = block_row_base / kHadamardDim; - scale_inv_t[output_row * scale_stride + scale_col] = scale_inv; - } - - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; - uint4 random_uint4{0, 0, 0, 0}; - int rnd_idx = 0; - if constexpr (kUseStochasticRounding) { - const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; - const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; - const size_t rng_sequence = static_cast(threadIdx.x) + - static_cast(blockIdx.x) * blockDim.x + - static_cast(blockIdx.y) * gridDim.x * blockDim.x; - rng.init(rng_seed, rng_sequence, rng_offset); - random_uint4 = rng.generate4(); - } - - const float2 scaled01{c0 * encode_scale, c1 * encode_scale}; - const float2 scaled23{c2 * encode_scale, c3 * encode_scale}; - const uint32_t rbits = kUseStochasticRounding ? 
get_rbits(rng, random_uint4, rnd_idx) : 0; - const uint16_t packed = fp4x4_to_bits(cvt_fp32_to_fp4_4x( - scaled01, scaled23, rbits)); - - const uint64_t output_col_base = input_row_base; - const uint64_t output_byte_offset = output_row * (num_rows / 2) + output_col_base / 2; - *reinterpret_cast(&output_t[output_byte_offset]) = packed; -} - -} // namespace - -} // namespace transformer_engine -#endif - -namespace transformer_engine { - void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &output_, const Tensor &hadamard_matrix_, QuantizationConfig quant_config, @@ -1006,7 +757,6 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out rng_state = reinterpret_cast(rng_state_tensor.data.dptr); } -#ifndef __HIP_PLATFORM_AMD__ // Template arguments using TA = cute::bfloat16_t; using TB = cute::bfloat16_t; @@ -1014,7 +764,6 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out using TSFC = cutlass::float_ue4m3_t; checkCuDriverContext(stream); -#endif // Check Hadamard matrix constexpr int kHadamardDimension = 16; @@ -1039,15 +788,12 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out m *= input.shape[i]; } -#ifndef __HIP_PLATFORM_AMD__ auto sm_count = transformer_engine::cuda::sm_count(); -#endif NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension."); NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension"); -#ifndef __HIP_PLATFORM_AMD__ int k_tile_size = 1024; if (m == 8192 && n == 5120) { @@ -1079,6 +825,8 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out } TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kUseStochasticRounding, + TRANSFORMER_ENGINE_SWITCH_CONDITION( quant_config.use_fast_math, kUseFastMath, detail::rht_gemm_ttt_wrapper( /*m=*/m, @@ -1092,25 +840,6 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out /*sm_count=*/sm_count, /*stream=*/stream, /*k_tile_size=*/k_tile_size););); -#else - const dim3 block(kThreadsPerBlock); - const dim3 grid(DIVUP(n, static_cast(kHadamardDim)), - DIVUP(m, static_cast(kRowsPerBlock))); - const size_t scale_stride = m / kHadamardDim; - - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_stochastic_rounding, kUseStochasticRounding, - HadamardTransformCastFusionKernel<<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv_t.dptr), - reinterpret_cast(global_amax.dptr), - reinterpret_cast(hadamard_matrix.dptr), - static_cast(m), - static_cast(n), - scale_stride, - rng_state);); -#endif } } // namespace transformer_engine From 1ba94746728f61dcfff49738be016e2bf0849734 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 20 Apr 2026 13:35:42 -0500 Subject: [PATCH 60/69] fixes --- tests/cpp/test_common.h | 15 ------------ transformer_engine/pytorch/csrc/quantizer.cpp | 23 ++++++------------- 2 files changed, 7 insertions(+), 31 deletions(-) diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 3c5642b86..796e66999 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -301,21 +301,6 @@ class Tensor { tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); } - void set_tensor_amax(float amax) { - if (!amax_cpu_data_) { - amax_cpu_data_ = std::make_shared(amax); - } else { - *amax_cpu_data_ = amax; - } - - float *amax_gpu = nullptr; - NVTE_CHECK_CUDA(cudaMalloc(&amax_gpu, sizeof(float))); - 
NVTE_CHECK_CUDA(cudaMemcpy(amax_gpu, amax_cpu_data_.get(), - sizeof(float), cudaMemcpyHostToDevice)); - - tensor_.set_amax(amax_gpu, DType::kFloat32, tensor_.defaultShape); - } - void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales){ tensor_.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index be6474c1d..88f6c1f79 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1667,10 +1667,15 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou need_separate_columnwise_rng ? quant_config_columnwise : quant_config; if (!eligible_for_rht_cast_fusion) { -#ifdef USE_ROCM +#ifndef USE_ROCM + at::Tensor rht_output_t; +#endif // If rht_output_t was already produced by the fused amax+transform kernel above, // skip the separate hadamard_transform call. - if (!rht_output_t.defined()) { +#ifdef USE_ROCM + if (!rht_output_t.defined()) +#endif + { rht_output_t = allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); TensorWrapper rht_output_t_cpp; @@ -1681,20 +1686,6 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou this->rht_matrix_random_sign_mask_t, stream); }); } -#else - at::Tensor rht_output_t; - rht_output_t = - allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); - { - TensorWrapper rht_output_t_cpp; - rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), - std::vector{cols, rows}); - NVTE_SCOPED_GIL_RELEASE({ - nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, - this->rht_matrix_random_sign_mask_t, stream); - }); - } -#endif TensorWrapper rht_output_t_cpp; rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), From e1ba512688a6b43e1cca8a19fc3a4b5ec8c3b275 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 20 Apr 2026 14:46:12 -0500 Subject: [PATCH 61/69] cleanups, cleaner mi300 LDS workaround --- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 1 - transformer_engine/common/recipe/nvfp4.cu | 2 +- ...quantize_transpose_vector_blockwise_fp4.cu | 130 ++++++++++++++---- 3 files changed, 103 insertions(+), 30 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index c39b57742..8901f2351 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -61,7 +61,6 @@ __global__ void __launch_bounds__(512) fp8e4m3 scale = scales[my_scale_index]; // NVFP4 may reach this path with scale present but no separate amax buffer. // Use 1.0f as the neutral fallback when tensor_amax is not provided on HIP. 
-// #ifndef __HIP_PLATFORM_AMD__ float amax = *tensor_amax; #if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_DEVICE_COMPILE__) // On AMD host, TypeExtrema::max is non-constexpr (runtime FNUZ detection) diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index e1d30f3af..af4847e4b 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -23,7 +23,7 @@ constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0); __global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A, const float *amax_B, float *alpha_out) { #ifdef __HIP_PLATFORM_AMD__ - const float fp4_max = detail::TypeExtrema::max; + constexpr float fp4_max = detail::TypeExtrema::max; const float fp8_max = detail::TypeExtrema::max; const float fi = 1.0f / (fp4_max * fp4_max * fp8_max * fp8_max); *alpha_out = alpha_in * (*amax_A) * (*amax_B) * fi; diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 674599009..173e9b18c 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -22,6 +22,9 @@ #include "common/util/curanddx.hpp" #include "common/util/ptx.cuh" #include "common/utils.cuh" +#ifdef __HIP_PLATFORM_AMD__ +#include "common/util/cuda_runtime.h" +#endif namespace transformer_engine { @@ -137,25 +140,27 @@ constexpr int kThreadsPerWarp = 32; constexpr int kNFP4PerContainer = 2; // Hyperparameters for performance tuning -// gfx942 has 64 KB LDS per workgroup. With kTileDim=128 and float32 input, -// shared memory exceeds 64 KB. Use kTileDim=64 and kThreadsPerBlock=128 on AMD. -// TODO: For optimal gfx950 performance (128 KB LDS), implement runtime dispatch -// with two kernel instantiations (kTileDim=64 and kTileDim=128). 
-#if defined(__HIP_PLATFORM_AMD__) -constexpr int kTileDim = 64; -constexpr int kThreadsPerBlock = 128; // Thread block size, 4 warps in total -#else constexpr int kTileDim = 128; -constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total -#endif // constexpr int kScaleDim = 32; constexpr int kNVecIn = 8; // The number of elements each LDG touches constexpr int kNVecOut = 16; // The number of elements each STG touches constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches +constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total // Auto-calculated constants, do not modify directly) static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem"); + +#ifdef __HIP_PLATFORM_AMD__ +// On AMD, kTileDim_ is a template parameter of the kernel for runtime dispatch: +// gfx942: kTileDim_=64 (64 KB LDS, kThreadsPerBlock=128, 4 warps) +// gfx950: kTileDim_=128 (128 KB LDS, kThreadsPerBlock=256, 8 warps) +constexpr int smem_size_for_tile(int tile_dim) { + return tile_dim * ((tile_dim / kNVecSMem) + 1) * kNVecSMem; +} +#else +constexpr int kTileDim = 128; +constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total constexpr int kSMemRow = kTileDim; constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1; constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem; @@ -164,27 +169,15 @@ constexpr int kNumThreadsStore = kTileDim / kNVecOut; // 8 // constexpr int kNumThreadsReduce = kScaleDim / kNVecOut; static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); +#endif //__HIP_PLATFORM_AMD__ // for 2D block scaling, we need to reduce amax in warp -#ifdef __HIP_PLATFORM_AMD__ -static __device__ constexpr uint64_t WARP_REDUCE_AMAX_GROUP_MASKS[8] = { - 0x0101010101010101ULL, 0x0202020202020202ULL, - 0x0404040404040404ULL, 0x0808080808080808ULL, - 0x1010101010101010ULL, 0x2020202020202020ULL, - 0x4040404040404040ULL, 0x8080808080808080ULL}; -#else static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = { 0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080}; -#endif // max for every group_size elements in warp template -__device__ __forceinline__ float groupMax(float val, -#ifdef __HIP_PLATFORM_AMD__ - uint64_t groupMask) { -#else - unsigned int groupMask) { -#endif +__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) { for (int offset = group_size / 2; offset > 0; offset /= 2) { #ifdef __HIP_PLATFORM_AMD__ (void)groupMask; // unused on AMD, __shfl_down does not take a mask @@ -421,16 +414,38 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, } } -template -__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( +template < +#ifdef __HIP_PLATFORM_AMD__ + int kTileDim_, +#endif + bool kReturnIdentity, bool kReturnTranspose, bool kIsE8Scaling, + bool kAligned, typename CType, typename IType, typename OType, typename ScaleType, + bool kSwizzledScale, bool kApplyStochasticRounding, bool kIs2DBlockScaling> +__global__ void __launch_bounds__( +#ifdef __HIP_PLATFORM_AMD__ + kTileDim_ * 2 +#else + kThreadsPerBlock +#endif +) block_scaled_1d_cast_transpose_kernel( const IType* const input, const float* global_amax, OType* const output_c, OType* const output_t, 
ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, const size_t row_length, const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const size_t kScaleBlockDim, const float epsilon, const size_t* rng_state, const float* noop_ptr) { + // Tile-dependent constants +#ifdef __HIP_PLATFORM_AMD__ + // Redirect kTileDim to the template parameter for runtime gfx942/gfx950 dispatch +#define kTileDim kTileDim_ + constexpr int kThreadsPerBlock = kTileDim * 2; + constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1; + constexpr int kNumThreadsLoad = kTileDim / kNVecIn; + constexpr int kNumThreadsStore = kTileDim / kNVecOut; + static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); + static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); +#endif + constexpr int kNVecContainer = kNVecOut / kNFP4PerContainer; using SMemVec = Vec; using OVec = Vec; @@ -804,6 +819,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } } +#ifdef __HIP_PLATFORM_AMD__ +#undef kTileDim +#endif } } // namespace @@ -863,8 +881,17 @@ void quantize_transpose_vector_blockwise_fp4( using namespace transformer_engine::quantize_transpose_nvfp4; +#ifdef __HIP_PLATFORM_AMD__ + // Runtime tile dimension selection: + // gfx942 (64 KB LDS): tile_dim=64, threads=128 + // gfx950 (128 KB LDS): tile_dim=128, threads=256 + const int tile_dim = (cuda::sm_arch() >= 95) ? 128 : 64; + const size_t num_blocks_x = DIVUP(row_length, static_cast(tile_dim)); + const size_t num_blocks_y = DIVUP(num_rows, static_cast(tile_dim)); +#else const size_t num_blocks_x = DIVUP(row_length, static_cast(kTileDim)); const size_t num_blocks_y = DIVUP(num_rows, static_cast(kTileDim)); +#endif // noop tensor for cuda graph const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); @@ -879,6 +906,36 @@ void quantize_transpose_vector_blockwise_fp4( rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); } +#ifdef __HIP_PLATFORM_AMD__ + // Macro to instantiate and launch the kernel for a given tile dimension. + // TILE_DIM must be a compile-time constant (64 or 128). 
+#define LAUNCH_FP4_CT_KERNEL(TILE_DIM) \ + do { \ + constexpr int kTD = TILE_DIM; \ + size_t smem_bytes = smem_size_for_tile(kTD) * sizeof(InputType); \ + auto kernel = block_scaled_1d_cast_transpose_kernel< \ + kTD, kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, \ + float, InputType, OutputType, ScaleType, kSwizzledScale, \ + kApplyStochasticRounding, kIs2DBlockScaling>; \ + if (smem_bytes >= 48 * 1024) { \ + cudaError_t err = cudaFuncSetAttribute( \ + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); \ + NVTE_CHECK(err == cudaSuccess, \ + "Failed to set dynamic shared memory size."); \ + } \ + kernel<<>>( \ + reinterpret_cast(input.dptr), \ + reinterpret_cast(global_amax.dptr), \ + reinterpret_cast(output.dptr), \ + reinterpret_cast(output_t.dptr), \ + reinterpret_cast(scale_inv.dptr), \ + reinterpret_cast(scale_inv_t.dptr), row_length, \ + num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, \ + scale_t_stride_y, kScaleBlockDim, epsilon, rng_state, \ + noop_ptr); \ + } while (0) +#endif // __HIP_PLATFORM_AMD__ + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.dtype, InputType, @@ -890,7 +947,11 @@ void quantize_transpose_vector_blockwise_fp4( using ScaleType = fp8e4m3; constexpr int kScaleBlockDim = 16; constexpr bool kPow2Scale = false; +#ifdef __HIP_PLATFORM_AMD__ + const bool full_tile = row_length % tile_dim == 0 && num_rows % tile_dim == 0; +#else const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; +#endif TRANSFORMER_ENGINE_SWITCH_CONDITION( return_identity, kReturnIdentity, @@ -910,6 +971,14 @@ void quantize_transpose_vector_blockwise_fp4( TRANSFORMER_ENGINE_SWITCH_CONDITION( use_2d_quantization, kIs2DBlockScaling, +#ifdef __HIP_PLATFORM_AMD__ + if (tile_dim == 64) { + LAUNCH_FP4_CT_KERNEL(64); + } else { + LAUNCH_FP4_CT_KERNEL(128); + } + ) // kIs2DBlockScaling +#else size_t smem_bytes = kSMemSize * sizeof(InputType); auto kernel = block_scaled_1d_cast_transpose_kernel< kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, @@ -932,6 +1001,7 @@ void quantize_transpose_vector_blockwise_fp4( num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, kScaleBlockDim, epsilon, rng_state, noop_ptr);) // kIs2DBlockScaling +#endif //__HIP_PLATFORM_AMD__ ) // kApplyStochasticRounding ) // kSwizzledScale ) // kAligned @@ -940,6 +1010,10 @@ void quantize_transpose_vector_blockwise_fp4( ) // OutputType ) // InputType +#ifdef __HIP_PLATFORM_AMD__ +#undef LAUNCH_FP4_CT_KERNEL +#endif + NVTE_CHECK_CUDA(cudaGetLastError()); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); From da093bd03d357a11ed26d4202af0b8b074a18283 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 20 Apr 2026 15:14:20 -0500 Subject: [PATCH 62/69] more cleanups --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 4 ++-- tests/cpp/operator/test_dequantize_nvfp4.cu | 4 ++-- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 12 +++++++---- tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py | 20 ++++++++++++++----- tests/pytorch/test_cpu_offloading.py | 6 +++--- tests/pytorch/test_fusible_ops.py | 5 ++++- .../hadamard_transform/hadamard_transform.cu | 11 ++-------- 7 files changed, 36 insertions(+), 26 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index a24a9c6f4..9f7c21d91 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -68,9 +68,9 @@ float 
compute_global_encode_scaling_factor_FP4(const float global_amax) { #ifdef __HIP_PLATFORM_AMD__ const float fp8_max = te_fp8_fnuz() ? 240.0f : 448.0f; #else - constexpr float fp8_max = 448.0f; + constexpr float fp8_max = 448.0f; // 448.0f; #endif - constexpr float fp4_max = 6.0f; + constexpr float fp4_max = 6.0f; // 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return max value of float32 global_encode_scale = fminf(global_encode_scale, Numeric_Traits::maxNorm); diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 9fd551d6a..39ea2eed3 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -110,10 +110,10 @@ void compute_ref(const fp4e2m1* input, const size_t scale_stride) { #ifdef __HIP_PLATFORM_AMD__ const float fp8_max = te_fp8_fnuz() ? 240.0f : 448.0f; + const float factor_inv = 1.0f / (6.0f * fp8_max); #else - constexpr float fp8_max = 448.0f; + constexpr float factor_inv = 1.0f / (6.0f * 448.0f); #endif - const float factor_inv = 1.0f / (6.0f * fp8_max); const size_t blocks_per_row = cols / kFP4BlockSize1D; diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index d40e44423..22b5cda40 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -111,11 +111,15 @@ def check_nvfp4_gemm_versus_reference( sx_trimmed = sx_native[:M, :expected_sx_cols] sw_trimmed = sw_native[:N, :expected_sw_cols] - # Native scales are stored as uint8 but need to be interpreted as float8_e4m3 + # Native scales are stored as uint8 but need to be interpreted as float8_e4m3fn # for the reference GEMM to work correctly - fp8_dtype = get_torch_float8_e4m3_type() - sx_trimmed = sx_trimmed.view(fp8_dtype) - sw_trimmed = sw_trimmed.view(fp8_dtype) + if IS_HIP_EXTENSION: + fp8_dtype = get_torch_float8_e4m3_type() + sx_trimmed = sx_trimmed.view(fp8_dtype) + sw_trimmed = sw_trimmed.view(fp8_dtype) + else: + sx_trimmed = sx_trimmed.view(torch.float8_e4m3fn) + sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn) # Create reference quantizer for reference GEMM ref_quantizer = NVFP4QuantizerRef( diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py index 98c0474a5..11777a715 100755 --- a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
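For reference, a minimal Python sketch of the global encode-scale math the test changes above exercise (illustrative only, not taken from any file in this series; the 240.0/448.0 split follows the FNUZ/OCP handling in the hunks above, and the zero-amax fallback to 1.0 mirrors compute_global_encode_scale_fp4):

    # NVFP4 global encode scale: fp8_max * fp4_max / global_amax,
    # falling back to 1.0 when the amax (or the resulting scale) is zero.
    def global_encode_scale_fp4(global_amax: float, fp8_fnuz: bool) -> float:
        fp8_max = 240.0 if fp8_fnuz else 448.0  # E4M3 max: FNUZ vs OCP
        fp4_max = 6.0                           # E2M1 max
        if global_amax == 0.0:
            return 1.0
        scale = fp8_max * fp4_max / global_amax
        return scale if scale != 0.0 else 1.0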
@@ -11,7 +13,9 @@ import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, is_fp8_fnuz +from torch.utils.cpp_extension import IS_HIP_EXTENSION +if IS_HIP_EXTENSION: + from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, is_fp8_fnuz recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) @@ -59,12 +63,18 @@ def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor: def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor: - fp8_dtype = get_torch_float8_e4m3_type() - fp8_max = 240.0 if is_fp8_fnuz() else 448.0 - sf = sx.repeat_interleave(16, dim=1).view(fp8_dtype).to(torch.float32) + if IS_HIP_EXTENSION: + fp8_dtype = get_torch_float8_e4m3_type() + fp8_max = 240.0 if is_fp8_fnuz() else 448.0 + sf = sx.repeat_interleave(16, dim=1).view(fp8_dtype).to(torch.float32) + else: + sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32) dqx = fp4_to_fp32(unpack_fp4(qx)) sf = sf[: dqx.shape[0], : dqx.shape[1]] - dequant = dqx * sf * (amax / (6.0 * fp8_max)) + if IS_HIP_EXTENSION: + dequant = dqx * sf * (amax / (6.0 * fp8_max)) + else: + dequant = dqx * sf * (amax / (6.0 * 448)) return dequant diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index f9feee514..76c53f428 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -208,10 +208,10 @@ def memory_leak_check(): from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes # workspaces are larger for AMDGPU addl_space = get_cublas_workspace_size_bytes() / (1024**2) * 4 - else: - addl_space = 0 - if Utils.get_cuda_memory_mb() > 1000 + addl_space: + if Utils.get_cuda_memory_mb() > 1000: + if IS_HIP_EXTENSION and Utils.get_cuda_memory_mb() <= 1000 + addl_space: + return memory_num = Utils.get_cuda_memory_mb() import gc diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 3b4e4f2d3..159826d11 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1773,7 +1773,10 @@ def test_clamped_swiglu( quantized_compute = quantization is not None if not quantized_compute and (quantize_forward or quantize_backward): pytest.skip("Quantization scheme has not been provided") - maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + if IS_HIP_EXTENSION: + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + else: + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index d1453ed59..0631f326b 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -980,15 +980,8 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran auto* in_ptr = reinterpret_cast(input.dptr); - if (pre_amax_ptr) { - NVTE_CHECK_CUDA(cudaMemsetAsync(pre_amax_ptr, 0, sizeof(float), stream)); - } - if (id_amax_ptr) { - NVTE_CHECK_CUDA(cudaMemsetAsync(id_amax_ptr, 0, sizeof(float), stream)); - } - if (tr_amax_ptr) { - NVTE_CHECK_CUDA(cudaMemsetAsync(tr_amax_ptr, 0, sizeof(float), stream)); - } + ZeroAmaxKernel<<<1, 1, 
0, stream>>>(pre_amax_ptr, id_amax_ptr, tr_amax_ptr); + NVTE_CHECK_CUDA(cudaGetLastError()); #else constexpr int kHadamardDimension = 16; NVTE_CHECK(row_length % kHadamardDimension == 0, From 75a4738bbb9f9989035557efcd68346f977d6a17 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 20 Apr 2026 21:34:29 +0000 Subject: [PATCH 63/69] re-fix null tensor_amax --- transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 8901f2351..dd98feb66 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -61,7 +61,11 @@ __global__ void __launch_bounds__(512) fp8e4m3 scale = scales[my_scale_index]; // NVFP4 may reach this path with scale present but no separate amax buffer. // Use 1.0f as the neutral fallback when tensor_amax is not provided on HIP. +#ifndef __HIP_PLATFORM_AMD__ float amax = *tensor_amax; +#else + float amax = (tensor_amax != nullptr) ? *tensor_amax : 1.0f; +#endif #if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_DEVICE_COMPILE__) // On AMD host, TypeExtrema::max is non-constexpr (runtime FNUZ detection) const float factor_inv = 1.0f / (detail::TypeExtrema::max * detail::TypeExtrema::max); From 531550679079c28bf9c4b9b8989da0c88aa407fc Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 21 Apr 2026 13:53:02 -0500 Subject: [PATCH 64/69] minor cleanups --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 4 ++-- transformer_engine/common/CMakeLists.txt | 1 - transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh | 8 ++++++-- .../hadamard_transform/hadamard_transform_utils.cuh | 3 --- transformer_engine/common/recipe/nvfp4.cu | 2 ++ transformer_engine/pytorch/csrc/extensions/cast.cpp | 4 ++-- 6 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 22b5cda40..d813ecf9c 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -12,10 +12,10 @@ from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils -from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type - from torch.utils.cpp_extension import IS_HIP_EXTENSION +if IS_HIP_EXTENSION: + from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 2f4873692..54df188d2 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -243,7 +243,6 @@ if(USE_CUDA) recipe/nvfp4.cu) list(APPEND transformer_engine_cuda_arch_specific_sources gemm/cutlass_grouped_gemm.cu - transpose/quantize_transpose_square_blockwise.cu hadamard_transform/group_hadamard_transform.cu transpose/quantize_transpose_square_blockwise.cu hadamard_transform/hadamard_transform_cast_fusion.cu diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index dd98feb66..aba4f8789 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ 
b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -66,12 +66,16 @@ __global__ void __launch_bounds__(512) #else float amax = (tensor_amax != nullptr) ? *tensor_amax : 1.0f; #endif -#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_DEVICE_COMPILE__) +#if defined(__HIP_PLATFORM_AMD__) +#if !defined(__HIP_DEVICE_COMPILE__) // On AMD host, TypeExtrema::max is non-constexpr (runtime FNUZ detection) const float factor_inv = 1.0f / (detail::TypeExtrema::max * detail::TypeExtrema::max); #else constexpr float factor_inv = 1.0f / (detail::TypeExtrema::max * detail::TypeExtrema::max); -#endif +#endif // !__HIP_DEVICE_COMPILE__ +#else + constexpr float factor_inv = 1.0 / (6.0 * 448.0); +#endif // __HIP_PLATFORM_AMD__ float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll for (int i = 0; i < 4; i++) { diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_utils.cuh b/transformer_engine/common/hadamard_transform/hadamard_transform_utils.cuh index 7d96152b0..f86061abb 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_utils.cuh +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_utils.cuh @@ -7,8 +7,6 @@ #ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_UTILS_CUH_ #define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_UTILS_CUH_ -#ifndef __HIP_PLATFORM_AMD__ - #include #include #include @@ -197,5 +195,4 @@ __device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx, } // namespace transformer_engine -#endif // __HIP_PLATFORM_AMD__ #endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_UTILS_CUH_ diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index af4847e4b..8f6124461 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
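For reference, a minimal Python sketch of the per-block decode scale computed in the dequantize_nvfp4.cuh hunk above (illustrative only; fp8_max is 240.0 on FP8-FNUZ devices and 448.0 otherwise, and tensor_amax falls back to 1.0 on HIP when no amax buffer is provided, as in the code above):

    # NVFP4 dequantization: value = fp4_value * block_scale * amax / (fp4_max * fp8_max)
    def nvfp4_block_decode_scale(block_scale: float, tensor_amax: float, fp8_fnuz: bool) -> float:
        fp8_max = 240.0 if fp8_fnuz else 448.0  # E4M3 max: FNUZ vs OCP
        factor_inv = 1.0 / (6.0 * fp8_max)      # 1 / (fp4_max * fp8_max)
        return block_scale * tensor_amax * factor_inv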
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2b35e0700..acbe4753b 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1102,8 +1102,8 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, // Fuse the rowwise and colwise into one when the kernel is ready split_quantize_nvfp4_impl_with_rht_helper(input, input_list, output_list, split_sections, quantizers, stream); - } else { - // NVFP4 quantize without RHT + } else { // NVFP4 quantize + // Fuse the rowwise and colwise into one when the kernel is ready split_quantize_nvfp4_impl_helper(input, input_list, output_list, split_sections, quantizers, stream); } From a91eaf06e3b7d48a002e0a1f3912d93d99019a36 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 23 Apr 2026 15:01:34 -0500 Subject: [PATCH 65/69] address review comments --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 4 +- tests/pytorch/test_cpu_offloading.py | 7 - tests/pytorch/test_fusible_ops.py | 5 +- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 9 +- transformer_engine/common/gemm/rocm_gemm.cu | 193 ++++++++++-------- .../common/hadamard_transform/wht16.cuh | 2 + transformer_engine/common/recipe/nvfp4.cu | 6 +- .../pytorch/cpp_extensions/gemm.py | 24 ++- transformer_engine/pytorch/csrc/quantizer.cpp | 5 +- transformer_engine/pytorch/module/base.py | 6 +- .../pytorch/module/layernorm_linear.py | 2 + transformer_engine/pytorch/module/linear.py | 2 + 12 files changed, 154 insertions(+), 111 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index d813ecf9c..47d61743e 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -162,7 +162,9 @@ def check_nvfp4_gemm_versus_reference( # Allocate cuBLAS workspace if IS_HIP_EXTENSION: # On ROCm, FP4 is dequantized to BF16 in workspace before GEMM, so allocate enough space. - ws_bytes = M * K * 2 + K * N * 2 + 32 * 1024 * 1024 + # Extra 32 MiB for hipBLASLt internal workspace + alpha vector + bf16_size = torch.bfloat16.itemsize + ws_bytes = M * K * bf16_size + K * N * bf16_size + 32 * 1024 * 1024 workspace = torch.empty(ws_bytes, dtype=torch.uint8, device=device) else: workspace = torch.empty(4, dtype=torch.uint8, device=device) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 76c53f428..7d1b8c716 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -204,14 +204,7 @@ def memory_leak_check(): # Only cublas workspaces and some global tensors are allowed to be allocated. # All other allocations should be released. # This is a simple check to catch memory leaks. 
- if IS_HIP_EXTENSION: - from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes - # workspaces are larger for AMDGPU - addl_space = get_cublas_workspace_size_bytes() / (1024**2) * 4 - if Utils.get_cuda_memory_mb() > 1000: - if IS_HIP_EXTENSION and Utils.get_cuda_memory_mb() <= 1000 + addl_space: - return memory_num = Utils.get_cuda_memory_mb() import gc diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 159826d11..cc3c6f341 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2939,10 +2939,9 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: return out # Check values + tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking if IS_HIP_EXTENSION: - tols = {"rtol": 0.25, "atol": 0.54} - else: - tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking + tols["atol"] = 0.54 torch.testing.assert_close(to_cpu(y_test), y_ref, **tols) torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols) torch.testing.assert_close(to_cpu(norm.weight.grad), norm_w_ref.grad, **tols) diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index aba4f8789..bd01acefe 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -66,16 +66,11 @@ __global__ void __launch_bounds__(512) #else float amax = (tensor_amax != nullptr) ? *tensor_amax : 1.0f; #endif -#if defined(__HIP_PLATFORM_AMD__) -#if !defined(__HIP_DEVICE_COMPILE__) - // On AMD host, TypeExtrema::max is non-constexpr (runtime FNUZ detection) - const float factor_inv = 1.0f / (detail::TypeExtrema::max * detail::TypeExtrema::max); -#else +#if defined(__HIP_DEVICE_COMPILE__) constexpr float factor_inv = 1.0f / (detail::TypeExtrema::max * detail::TypeExtrema::max); -#endif // !__HIP_DEVICE_COMPILE__ #else constexpr float factor_inv = 1.0 / (6.0 * 448.0); -#endif // __HIP_PLATFORM_AMD__ +#endif float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll for (int i = 0; i < 4; i++) { diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 2e9f197f4..810a2557b 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -319,7 +319,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; ret.lda = is_A_transposed ? k : m; } else if (is_nvfp_scaling(A.scaling_mode)) { - // NVFP4 + // NVFP4: dequant path always produces TN layout for the BF16 GEMM, + // but the source data may come from either rowwise or columnwise buffers. ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; ret.transA = CUBLAS_OP_T; // NVFP4 gemm is always TN layout ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; @@ -364,7 +365,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; } else if (is_nvfp_scaling(B.scaling_mode)) { - // NVFP4 + // NVFP4: dequant path always produces TN layout for the BF16 GEMM, + // but the source data may come from either rowwise or columnwise buffers. ret.B = is_B_transposed ? 
B.columnwise_data.dptr : B.data.dptr; ret.transB = CUBLAS_OP_N; // NVFP4 gemm is always TN layout ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; @@ -377,6 +379,97 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla return ret; } +// Dequantize FP4 inputs to BF16 in-place within the workspace and set up +// the alpha device vector for the subsequent hipBLASLt GEMM. +// After this call, param.A/B point to BF16 buffers within workspace, +// param.Atype/Btype are kBFloat16, and *alpha_ptr_out points to a device vector. +static void dequant_fp4_gemm_inputs( + GemmParam& param, + const transformer_engine::Tensor& inputA, cublasOperation_t transa, + const transformer_engine::Tensor& inputB, cublasOperation_t transb, + int m, int n, int k, float alpha, + void* workspace, size_t& workspaceSize, + const void** alpha_ptr_out, hipStream_t stream) { + + const float fp4_max = 6.0f; + const float fp8_max = te_fp8_fnuz() ? 240.0f : 448.0f; + const float factor_inv = 1.0f / (fp4_max * fp4_max * fp8_max * fp8_max); + + const float* amax_A = (transa == CUBLAS_OP_T) + ? reinterpret_cast(inputA.amax.dptr) + : reinterpret_cast(inputA.columnwise_amax.dptr); + const float* amax_B = (transb == CUBLAS_OP_N) + ? reinterpret_cast(inputB.amax.dptr) + : reinterpret_cast(inputB.columnwise_amax.dptr); + + // Compute total extra bytes needed from the workspace: + // alpha vector: m * sizeof(float) + // dequant A: m * k * sizeof(bf16) (if A is FP4) + // dequant B: k * n * sizeof(bf16) (if B is FP4) + const size_t alpha_vec_bytes = static_cast(m) * sizeof(float); + const size_t a_bf16_bytes = is_fp4_dtype(param.Atype) + ? static_cast(m) * k * sizeof(hip_bfloat16) : 0; + const size_t b_bf16_bytes = is_fp4_dtype(param.Btype) + ? static_cast(k) * n * sizeof(hip_bfloat16) : 0; + const size_t fp4_total_bytes = alpha_vec_bytes + a_bf16_bytes + b_bf16_bytes; + NVTE_CHECK(workspaceSize >= fp4_total_bytes, + "NVFP4 GEMM requires at least ", fp4_total_bytes, " bytes workspace (", + fp4_total_bytes / (1024 * 1024), " MiB) for alpha vector + BF16 dequant buffers, " + "but only ", workspaceSize, " bytes (", workspaceSize / (1024 * 1024), + " MiB) available. Increase the cuBLAS workspace size."); + + // Carve regions from the end of the workspace. + // Layout: [cublas workspace ... | alpha_vec | dequant_a | dequant_b] + workspaceSize = (workspaceSize / sizeof(float)) * sizeof(float) - fp4_total_bytes; + uint8_t* ws_ptr = reinterpret_cast(workspace) + workspaceSize; + + float* device_alpha_vec = reinterpret_cast(ws_ptr); + ws_ptr += alpha_vec_bytes; + + NVTE_CHECK(amax_A != nullptr, "FP4 GEMM requires amax_A"); + NVTE_CHECK(amax_B != nullptr, "FP4 GEMM requires amax_B"); + constexpr int kBlockSize = 256; + const int num_blocks = (m + kBlockSize - 1) / kBlockSize; + compute_fp4_alpha_vector_kernel<<>>( + alpha, amax_A, amax_B, factor_inv, device_alpha_vec, m); + *alpha_ptr_out = static_cast(device_alpha_vec); + + // Dequantize FP4 -> BF16 (block scales only, no amax folded in) + if (is_fp4_dtype(param.Atype)) { + hip_bfloat16* a_bf16 = reinterpret_cast(ws_ptr); + ws_ptr += a_bf16_bytes; + const int64_t total_a = static_cast(m) * k; + const auto& a_sinv = (transa == CUBLAS_OP_T) ? inputA.scale_inv + : inputA.columnwise_scale_inv; + const int64_t a_num_cols = (transa == CUBLAS_OP_T) + ? inputA.data.shape.back() + : inputA.columnwise_data.shape.back(); + const int64_t a_scale_stride = (a_sinv.shape.size() >= 2) ? 
a_sinv.shape[1] : (a_num_cols / 16); + launch_dequant_fp4_to_bf16(param.A, param.A_scale_inv, a_bf16, total_a, + a_num_cols, a_scale_stride, stream); + param.A = a_bf16; + param.Atype = DType::kBFloat16; + param.A_scale_inv = nullptr; + } + + if (is_fp4_dtype(param.Btype)) { + hip_bfloat16* b_bf16 = reinterpret_cast(ws_ptr); + ws_ptr += b_bf16_bytes; + const int64_t total_b = static_cast(k) * n; + const auto& b_sinv = (transb == CUBLAS_OP_N) ? inputB.scale_inv + : inputB.columnwise_scale_inv; + const int64_t b_num_cols = (transb == CUBLAS_OP_N) + ? inputB.data.shape.back() + : inputB.columnwise_data.shape.back(); + const int64_t b_scale_stride = (b_sinv.shape.size() >= 2) ? b_sinv.shape[1] : (b_num_cols / 16); + launch_dequant_fp4_to_bf16(param.B, param.B_scale_inv, b_bf16, total_b, + b_num_cols, b_scale_stride, stream); + param.B = b_bf16; + param.Btype = DType::kBFloat16; + param.B_scale_inv = nullptr; + } +} + static class HandlePool { public: @@ -608,19 +701,22 @@ public: //Make it int instead of hipblasLtMatmulMatrixScale_t for compatibility with old hipblasLt int scaling_mode; hipblasLtEpilogue_t epilogue; + bool fp4_alpha_device_vector; // FP4 uses ALPHA_DEVICE_VECTOR pointer mode Key(int deviceCap_, hipDataType a_type_, hipDataType b_type_, hipDataType d_type_, hipDataType bias_type_, hipDataType aux_type_, int m_, int n_, int k_, int lda_, int ldb_, int ldd_, hipblasOperation_t transa_, hipblasOperation_t transb_, - int scaling_mode_, hipblasLtEpilogue_t epilogue_): + int scaling_mode_, hipblasLtEpilogue_t epilogue_, + bool fp4_alpha_device_vector_ = false): deviceCap(deviceCap_), a_type(a_type_), b_type(b_type_), d_type(d_type_), bias_type(bias_type_), aux_type(aux_type_), m(m_), n(n_), k(k_), lda(lda_), ldb(ldb_), ldd(ldd_), transa(transa_), transb(transb_), - scaling_mode(scaling_mode_), epilogue(epilogue_) {} + scaling_mode(scaling_mode_), epilogue(epilogue_), + fp4_alpha_device_vector(fp4_alpha_device_vector_) {} Key() {} @@ -633,7 +729,8 @@ public: && (m == val.m) && (n == val.n) && (k == val.k) && (lda == val.lda) && (ldb == val.ldb) && (ldd == val.ldd) && (transa == val.transa) && (transb == val.transb) - && (scaling_mode == val.scaling_mode) && (epilogue == val.epilogue) ); + && (scaling_mode == val.scaling_mode) && (epilogue == val.epilogue) + && (fp4_alpha_device_vector == val.fp4_alpha_device_vector) ); } struct Comp @@ -1051,86 +1148,9 @@ void hipblaslt_gemm(const Tensor *inputA, const void* alpha_ptr = static_cast(&alpha); const void* beta_ptr = static_cast(&beta); if (use_fp4) { - const float fp4_max = 6.0f; - const float fp8_max = te_fp8_fnuz() ? 240.0f : 448.0f; - const float factor_inv = 1.0f / (fp4_max * fp4_max * fp8_max * fp8_max); - - const float* amax_A = (transa == CUBLAS_OP_T) - ? reinterpret_cast(inputA->amax.dptr) - : reinterpret_cast(inputA->columnwise_amax.dptr); - const float* amax_B = (transb == CUBLAS_OP_N) - ? reinterpret_cast(inputB->amax.dptr) - : reinterpret_cast(inputB->columnwise_amax.dptr); - - // Compute total extra bytes needed from the workspace: - // alpha vector: m * sizeof(float) - // dequant A: m * k * sizeof(bf16) (if A is FP4) - // dequant B: k * n * sizeof(bf16) (if B is FP4) - const size_t alpha_vec_bytes = static_cast(m) * sizeof(float); - const size_t a_bf16_bytes = is_fp4_dtype(param.Atype) - ? static_cast(m) * k * sizeof(hip_bfloat16) : 0; - const size_t b_bf16_bytes = is_fp4_dtype(param.Btype) - ? 
static_cast(k) * n * sizeof(hip_bfloat16) : 0; - const size_t fp4_total_bytes = alpha_vec_bytes + a_bf16_bytes + b_bf16_bytes; - NVTE_CHECK(workspaceSize >= fp4_total_bytes, - "NVFP4 GEMM requires at least ", fp4_total_bytes, " bytes workspace (", - fp4_total_bytes / (1024 * 1024), " MiB) for alpha vector + BF16 dequant buffers, " - "but only ", workspaceSize, " bytes (", workspaceSize / (1024 * 1024), - " MiB) available. Increase the cuBLAS workspace size."); - - // Carve regions from the end of the workspace. - // Layout: [cublas workspace ... | alpha_vec | dequant_a | dequant_b] - workspaceSize = (workspaceSize / sizeof(float)) * sizeof(float) - fp4_total_bytes; - uint8_t* ws_ptr = reinterpret_cast(workspace) + workspaceSize; - - float* device_alpha_vec = reinterpret_cast(ws_ptr); - ws_ptr += alpha_vec_bytes; - - NVTE_CHECK(amax_A != nullptr, "FP4 GEMM requires amax_A"); - NVTE_CHECK(amax_B != nullptr, "FP4 GEMM requires amax_B"); - constexpr int kBlockSize = 256; - const int num_blocks = (m + kBlockSize - 1) / kBlockSize; - compute_fp4_alpha_vector_kernel<<>>( - alpha, amax_A, amax_B, factor_inv, device_alpha_vec, m); - alpha_ptr = static_cast(device_alpha_vec); - // beta_ptr stays as host &beta - - // Dequantize FP4 -> BF16 (block scales only, no amax folded in) - if (is_fp4_dtype(param.Atype)) { - hip_bfloat16* a_bf16 = reinterpret_cast(ws_ptr); - ws_ptr += a_bf16_bytes; - const int64_t total_a = static_cast(m) * k; - // Determine scale stride from scale tensor shape - const auto& a_sinv = (transa == CUBLAS_OP_T) ? inputA->scale_inv - : inputA->columnwise_scale_inv; - const int64_t a_num_cols = (transa == CUBLAS_OP_T) - ? inputA->data.shape.back() - : inputA->columnwise_data.shape.back(); - const int64_t a_scale_stride = (a_sinv.shape.size() >= 2) ? a_sinv.shape[1] : (a_num_cols / 16); - launch_dequant_fp4_to_bf16(param.A, param.A_scale_inv, a_bf16, total_a, - a_num_cols, a_scale_stride, stream); - param.A = a_bf16; - param.Atype = DType::kBFloat16; - param.A_scale_inv = nullptr; - } - - if (is_fp4_dtype(param.Btype)) { - hip_bfloat16* b_bf16 = reinterpret_cast(ws_ptr); - ws_ptr += b_bf16_bytes; - const int64_t total_b = static_cast(k) * n; - // Determine scale stride from scale tensor shape - const auto& b_sinv = (transb == CUBLAS_OP_N) ? inputB->scale_inv - : inputB->columnwise_scale_inv; - const int64_t b_num_cols = (transb == CUBLAS_OP_N) - ? inputB->data.shape.back() - : inputB->columnwise_data.shape.back(); - const int64_t b_scale_stride = (b_sinv.shape.size() >= 2) ? b_sinv.shape[1] : (b_num_cols / 16); - launch_dequant_fp4_to_bf16(param.B, param.B_scale_inv, b_bf16, total_b, - b_num_cols, b_scale_stride, stream); - param.B = b_bf16; - param.Btype = DType::kBFloat16; - param.B_scale_inv = nullptr; - } + dequant_fp4_gemm_inputs(param, *inputA, transa, *inputB, transb, + m, n, k, alpha, workspace, workspaceSize, + &alpha_ptr, stream); } bool nvte_log_gemm_config = false; @@ -1362,7 +1382,8 @@ void hipblaslt_gemm(const Tensor *inputA, GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type, use_fp8 ? bias_type : (hipDataType)-1, (use_fp8 && gelu) ? 
aux_type : (hipDataType)-1, - m, n, k, param.lda, param.ldb, ldd, param.transA, param.transB, scaling_mode, epilogue ); + m, n, k, param.lda, param.ldb, ldd, param.transA, param.transB, scaling_mode, epilogue, + use_fp4); GemmAlgoCache::Algo cached_algo; if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value()) { diff --git a/transformer_engine/common/hadamard_transform/wht16.cuh b/transformer_engine/common/hadamard_transform/wht16.cuh index 751ecba68..490ebbb6d 100644 --- a/transformer_engine/common/hadamard_transform/wht16.cuh +++ b/transformer_engine/common/hadamard_transform/wht16.cuh @@ -9,6 +9,8 @@ #ifndef TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_WHT16_CUH_ #define TRANSFORMER_ENGINE_COMMON_HADAMARD_TRANSFORM_WHT16_CUH_ +#include "hip/hip_runtime.h" + #ifdef __HIP_PLATFORM_AMD__ static constexpr int kHadamardDim = 16; diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 8f6124461..cfe95d92a 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -26,7 +26,11 @@ __global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const floa const float *amax_B, float *alpha_out) { #ifdef __HIP_PLATFORM_AMD__ constexpr float fp4_max = detail::TypeExtrema::max; - const float fp8_max = detail::TypeExtrema::max; +#if defined(__HIP_DEVICE_COMPILE__) + constexpr float fp8_max = detail::TypeExtrema::max; +#else + constexpr float fp8_max = 240.0f; // host placeholder; only device path executes +#endif const float fi = 1.0f / (fp4_max * fp4_max * fp8_max * fp8_max); *alpha_out = alpha_in * (*amax_A) * (*amax_B) * fi; #else diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index c61d1b887..f66026529 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -19,6 +19,7 @@ from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from ..tensor.utils import is_custom from ..custom_recipes.gemm import custom_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -36,8 +37,10 @@ def get_cublas_workspace_size_bytes() -> None: """Return workspace size needed for current architecture.""" if IS_HIP_EXTENSION: - """Return 512 MiB (FP4 dequant buffers).""" - return 512 * 1024 * 1024 + """Return 64 MiB for gfx50x, 32 MiB for all other architectures.""" + if get_device_compute_capability() == (9, 5): + return 67_108_864 + return 33_554_432 """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales @@ -128,6 +131,23 @@ def general_gemm( beta = validate_gemm_scale(beta, accumulate) workspace = get_cublas_workspace(get_tensor_device(A), ub is not None, False) + # On ROCm, FP4 is dequantized to BF16 in the workspace before GEMM. + # Compute the required extra space and extend the workspace if needed. 
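+    # Worked example (illustrative numbers only, assuming A and B are both
+    # NVFP4 with logical shapes (M, K) and (K, N)): for M = N = K = 8192,
+    # each BF16 staging buffer needs 8192 * 8192 * 2 B = 128 MiB and the
+    # alpha vector needs 8192 * 4 B = 32 KiB, so total_needed comes to
+    # roughly 288 MiB with the default 32 MiB base workspace (64 MiB where
+    # compute capability (9, 5) is reported).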
+ if IS_HIP_EXTENSION and ( + isinstance(A, NVFP4TensorStorage) or isinstance(B, NVFP4TensorStorage) + ): + import math + bf16_size = 2 # sizeof(bfloat16) + fp4_extra = 0 + if isinstance(A, NVFP4TensorStorage): + fp4_extra += math.prod(A.size()) * bf16_size + fp4_extra += A.size(0) * 4 # alpha vector (m floats) + if isinstance(B, NVFP4TensorStorage): + fp4_extra += math.prod(B.size()) * bf16_size + total_needed = fp4_extra + get_cublas_workspace_size_bytes() + if workspace.numel() < total_needed: + workspace = torch.empty(total_needed, dtype=torch.uint8, device=workspace.device) + if ub_type is not None: assert ub is not None, ( f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires" diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 88f6c1f79..10a86726c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1673,9 +1673,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // If rht_output_t was already produced by the fused amax+transform kernel above, // skip the separate hadamard_transform call. #ifdef USE_ROCM - if (!rht_output_t.defined()) -#endif + if (!rht_output_t.defined()) { +#else { +#endif rht_output_t = allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); TensorWrapper rht_output_t_cpp; diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 7954c3e72..af8910525 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -86,8 +86,10 @@ class UserBufferQuantizationMode(Enum): def get_cublas_workspace_size_bytes() -> None: """Return workspace size needed for current architecture""" if IS_HIP_EXTENSION: - """Return 512 MiB (FP4 dequant buffers).""" - return 512 * 1024 * 1024 + """Return 64 MiB for gfx50x, 32 MiB for all other architectures.""" + if get_device_compute_capability() == (9, 5): + return 67_108_864 + return 33_554_432 """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 2a4f74d17..a2817b79b 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -374,6 +374,8 @@ def forward( # Forward GEMM # Note: y = x * w^T # ------------------------------------------------------ + # NVFP4TensorStorage doesn't have _transpose (no lazy transpose like Float8Tensor), + # so guard with hasattr. if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache and hasattr(weightmat, '_transpose'): assert weightmat._transpose is None or weightmat._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled." 
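+            # Only Float8Tensor-style weights reach this assert; the hasattr
+            # guard above skips NVFP4 weights, which never carry a lazily
+            # materialized _transpose cache.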
nvtx_range_push(f"{nvtx_label}.gemm") diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index dd17c7023..27e2648e7 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -330,6 +330,8 @@ def forward( # Forward GEMM # Note: y = x * w^T # ------------------------------------------------------ + # NVFP4TensorStorage doesn't have _transpose (no lazy transpose like Float8Tensor), + # so guard with hasattr. if IS_HIP_EXTENSION and fp8 and not keep_fp8_weight_transpose_cache and hasattr(weightmat, '_transpose'): assert weightmat._transpose is None or weightmat._transpose.numel() == 0, "Expected _transpose to be None or an empty tensor when transpose cache is disabled." From 0f53a9ddd7b59a639fc3e9f69f14bca2d4549127 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 23 Apr 2026 22:34:37 -0500 Subject: [PATCH 66/69] address review comments --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 4 ++-- transformer_engine/common/gemm/rocm_gemm.cu | 12 +++++++++--- transformer_engine/pytorch/cpp_extensions/gemm.py | 3 ++- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 47d61743e..2bdfbed19 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -162,9 +162,9 @@ def check_nvfp4_gemm_versus_reference( # Allocate cuBLAS workspace if IS_HIP_EXTENSION: # On ROCm, FP4 is dequantized to BF16 in workspace before GEMM, so allocate enough space. - # Extra 32 MiB for hipBLASLt internal workspace + alpha vector + from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes bf16_size = torch.bfloat16.itemsize - ws_bytes = M * K * bf16_size + K * N * bf16_size + 32 * 1024 * 1024 + ws_bytes = M * K * bf16_size + K * N * bf16_size + get_cublas_workspace_size_bytes() workspace = torch.empty(ws_bytes, dtype=torch.uint8, device=device) else: workspace = torch.empty(4, dtype=torch.uint8, device=device) diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 810a2557b..e17fa7ae9 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -709,7 +709,7 @@ public: int m_, int n_, int k_, int lda_, int ldb_, int ldd_, hipblasOperation_t transa_, hipblasOperation_t transb_, int scaling_mode_, hipblasLtEpilogue_t epilogue_, - bool fp4_alpha_device_vector_ = false): + bool fp4_alpha_device_vector_): deviceCap(deviceCap_), a_type(a_type_), b_type(b_type_), d_type(d_type_), bias_type(bias_type_), aux_type(aux_type_), @@ -860,7 +860,7 @@ protected: fs << "dev_cap" << "m" << "n" << "k" << "trans_a" << "trans_b" << "type_a" << "type_b" << "type_d" << "bias_type" << "aux_type" << "lda" << "ldb" << "ldd" << "scale_mode" << "epi" << "comp" << "scale_type" - << "ws_min" << "ws_max" << "algo_id" << "aidx"; + << "ws_min" << "ws_max" << "algo_id" << "aidx" << "fp4_alpha"; } void load_() @@ -932,6 +932,11 @@ protected: std::getline(is, comp, csv_sep); std::getline(is, scale, csv_sep); is >> ws_min >> c >> ws_max >> c >> algo_id >> c >> algo_idx; + int fp4_alpha = 0; + if (is.peek() == csv_sep) { + is >> c >> fp4_alpha; + } + cfg.fp4_alpha_device_vector = (fp4_alpha != 0); if (is.bad()) { @@ -1066,7 +1071,8 @@ protected: << ((cfg.aux_type == (hipDataType)-1) ? 
"-" : typeNameMapper.getName(cfg.aux_type)) << cfg.lda << cfg.ldb << cfg.ldd << cfg.scaling_mode << epilogueNameMapper.getName(cfg.epilogue) << computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F) - << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end() << "\n"; + << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index + << (cfg.fp4_alpha_device_vector ? 1 : 0) << csv_helper::end() << "\n"; } private: diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index f66026529..27d4a78b6 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -136,8 +136,9 @@ def general_gemm( if IS_HIP_EXTENSION and ( isinstance(A, NVFP4TensorStorage) or isinstance(B, NVFP4TensorStorage) ): + assert ub is None, "User buffers (comm overlap) are not supported with NVFP4" import math - bf16_size = 2 # sizeof(bfloat16) + bf16_size = torch.bfloat16.itemsize fp4_extra = 0 if isinstance(A, NVFP4TensorStorage): fp4_extra += math.prod(A.size()) * bf16_size From a08e8c5e285d4c02eacd0961c4fe600aecb4fdcd Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 27 Apr 2026 08:03:02 -0500 Subject: [PATCH 67/69] use maxNorm --- tests/cpp/operator/test_cast_nvfp4_transpose.cu | 2 +- tests/cpp/operator/test_dequantize_nvfp4.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 9f7c21d91..6391c9058 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -66,7 +66,7 @@ std::vector create_transpose(const InputType* const input, const size // Compute the global encode scale factor for a given global amax float compute_global_encode_scaling_factor_FP4(const float global_amax) { #ifdef __HIP_PLATFORM_AMD__ - const float fp8_max = te_fp8_fnuz() ? 240.0f : 448.0f; + const float fp8_max = Numeric_Traits::maxNorm; #else constexpr float fp8_max = 448.0f; // 448.0f; #endif diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 39ea2eed3..71ebed027 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -109,7 +109,7 @@ void compute_ref(const fp4e2m1* input, const size_t cols, const size_t scale_stride) { #ifdef __HIP_PLATFORM_AMD__ - const float fp8_max = te_fp8_fnuz() ? 
240.0f : 448.0f; + const float fp8_max = Numeric_Traits::maxNorm; const float factor_inv = 1.0f / (6.0f * fp8_max); #else constexpr float factor_inv = 1.0f / (6.0f * 448.0f); From fae76d3b7f4362af53621440f50af4678c5c2d2f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 27 Apr 2026 09:11:00 -0500 Subject: [PATCH 68/69] factor out FP4 staging --- transformer_engine/common/gemm/rocm_gemm.cu | 59 +++++++++------------ 1 file changed, 26 insertions(+), 33 deletions(-) diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index e17fa7ae9..a6a2f42af 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -434,40 +434,33 @@ static void dequant_fp4_gemm_inputs( alpha, amax_A, amax_B, factor_inv, device_alpha_vec, m); *alpha_ptr_out = static_cast(device_alpha_vec); - // Dequantize FP4 -> BF16 (block scales only, no amax folded in) - if (is_fp4_dtype(param.Atype)) { - hip_bfloat16* a_bf16 = reinterpret_cast(ws_ptr); - ws_ptr += a_bf16_bytes; - const int64_t total_a = static_cast(m) * k; - const auto& a_sinv = (transa == CUBLAS_OP_T) ? inputA.scale_inv - : inputA.columnwise_scale_inv; - const int64_t a_num_cols = (transa == CUBLAS_OP_T) - ? inputA.data.shape.back() - : inputA.columnwise_data.shape.back(); - const int64_t a_scale_stride = (a_sinv.shape.size() >= 2) ? a_sinv.shape[1] : (a_num_cols / 16); - launch_dequant_fp4_to_bf16(param.A, param.A_scale_inv, a_bf16, total_a, - a_num_cols, a_scale_stride, stream); - param.A = a_bf16; - param.Atype = DType::kBFloat16; - param.A_scale_inv = nullptr; - } + // Stage FP4 operand: dequantize to BF16 in workspace and update GEMM param. + auto stage_fp4_operand = [&](DType& op_type, void*& op_data, + void*& op_scale_inv, + const transformer_engine::Tensor& input, + bool use_rowwise, int64_t rows, int64_t cols, + size_t bf16_bytes) { + if (!is_fp4_dtype(op_type)) + return; - if (is_fp4_dtype(param.Btype)) { - hip_bfloat16* b_bf16 = reinterpret_cast(ws_ptr); - ws_ptr += b_bf16_bytes; - const int64_t total_b = static_cast(k) * n; - const auto& b_sinv = (transb == CUBLAS_OP_N) ? inputB.scale_inv - : inputB.columnwise_scale_inv; - const int64_t b_num_cols = (transb == CUBLAS_OP_N) - ? inputB.data.shape.back() - : inputB.columnwise_data.shape.back(); - const int64_t b_scale_stride = (b_sinv.shape.size() >= 2) ? b_sinv.shape[1] : (b_num_cols / 16); - launch_dequant_fp4_to_bf16(param.B, param.B_scale_inv, b_bf16, total_b, - b_num_cols, b_scale_stride, stream); - param.B = b_bf16; - param.Btype = DType::kBFloat16; - param.B_scale_inv = nullptr; - } + hip_bfloat16* bf16_buf = reinterpret_cast(ws_ptr); + ws_ptr += bf16_bytes; + const auto& sinv = use_rowwise ? input.scale_inv : input.columnwise_scale_inv; + const int64_t num_cols = use_rowwise ? input.data.shape.back() + : input.columnwise_data.shape.back(); + const int64_t scale_stride = (sinv.shape.size() >= 2) ? 
sinv.shape[1] : (num_cols / 16); + launch_dequant_fp4_to_bf16(op_data, op_scale_inv, bf16_buf, + rows * cols, num_cols, scale_stride, stream); + op_data = bf16_buf; + op_type = DType::kBFloat16; + op_scale_inv = nullptr; + }; + + // Dequantize FP4 -> BF16 (block scales only, no amax folded in) + stage_fp4_operand(param.Atype, param.A, param.A_scale_inv, + inputA, transa == CUBLAS_OP_T, m, k, a_bf16_bytes); + stage_fp4_operand(param.Btype, param.B, param.B_scale_inv, + inputB, transb == CUBLAS_OP_N, k, n, b_bf16_bytes); } From a6f478779df5df14e89ffa8b075e509ab8535d06 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 28 Apr 2026 13:16:30 -0500 Subject: [PATCH 69/69] address review comments --- transformer_engine/common/gemm/rocm_gemm.cu | 10 +- ...quantize_transpose_vector_blockwise_fp4.cu | 128 ++++++------------ 2 files changed, 46 insertions(+), 92 deletions(-) diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index a6a2f42af..c0e82b8ff 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -853,7 +853,7 @@ protected: fs << "dev_cap" << "m" << "n" << "k" << "trans_a" << "trans_b" << "type_a" << "type_b" << "type_d" << "bias_type" << "aux_type" << "lda" << "ldb" << "ldd" << "scale_mode" << "epi" << "comp" << "scale_type" - << "ws_min" << "ws_max" << "algo_id" << "aidx" << "fp4_alpha"; + << "fp4_alpha" << "ws_min" << "ws_max" << "algo_id" << "aidx"; } void load_() @@ -924,11 +924,8 @@ protected: std::getline(is, epi, csv_sep); std::getline(is, comp, csv_sep); std::getline(is, scale, csv_sep); - is >> ws_min >> c >> ws_max >> c >> algo_id >> c >> algo_idx; int fp4_alpha = 0; - if (is.peek() == csv_sep) { - is >> c >> fp4_alpha; - } + is >> fp4_alpha >> c >> ws_min >> c >> ws_max >> c >> algo_id >> c >> algo_idx; cfg.fp4_alpha_device_vector = (fp4_alpha != 0); if (is.bad()) @@ -1064,8 +1061,9 @@ protected: << ((cfg.aux_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.aux_type)) << cfg.lda << cfg.ldb << cfg.ldd << cfg.scaling_mode << epilogueNameMapper.getName(cfg.epilogue) << computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F) + << (cfg.fp4_alpha_device_vector ? 1 : 0) << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index - << (cfg.fp4_alpha_device_vector ? 
1 : 0) << csv_helper::end() << "\n"; + << csv_helper::end() << "\n"; } private: diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 173e9b18c..3f0c0fe84 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -140,27 +140,36 @@ constexpr int kThreadsPerWarp = 32; constexpr int kNFP4PerContainer = 2; // Hyperparameters for performance tuning +#ifndef __HIP_PLATFORM_AMD__ constexpr int kTileDim = 128; +#endif // constexpr int kScaleDim = 32; constexpr int kNVecIn = 8; // The number of elements each LDG touches constexpr int kNVecOut = 16; // The number of elements each STG touches constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches +#ifndef __HIP_PLATFORM_AMD__ constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total +#endif + +// Tile dimension and thread block size: +// gfx942: kTileDim=64 (64 KB LDS, kThreadsPerBlock=128, 4 warps) +// gfx950 / NVIDIA: kTileDim=128 (128 KB LDS, kThreadsPerBlock=256, 8 warps) +// On AMD, __gfx950__ is only defined during device compilation, so the host +// must select tile_dim at runtime via cuda::sm_arch() using the constants below. +#ifdef __HIP_PLATFORM_AMD__ +constexpr int kTileDimGfx950 = 128; +constexpr int kTileDimGfx942 = 64; +#if !defined(__gfx950__) +constexpr int kTileDim = kTileDimGfx942; +#else +constexpr int kTileDim = kTileDimGfx950; +#endif +constexpr int kThreadsPerBlock = 2 * kTileDim; +#endif // Auto-calculated constants, do not modify directly) static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem"); - -#ifdef __HIP_PLATFORM_AMD__ -// On AMD, kTileDim_ is a template parameter of the kernel for runtime dispatch: -// gfx942: kTileDim_=64 (64 KB LDS, kThreadsPerBlock=128, 4 warps) -// gfx950: kTileDim_=128 (128 KB LDS, kThreadsPerBlock=256, 8 warps) -constexpr int smem_size_for_tile(int tile_dim) { - return tile_dim * ((tile_dim / kNVecSMem) + 1) * kNVecSMem; -} -#else -constexpr int kTileDim = 128; -constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total constexpr int kSMemRow = kTileDim; constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1; constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem; @@ -169,7 +178,14 @@ constexpr int kNumThreadsStore = kTileDim / kNVecOut; // 8 // constexpr int kNumThreadsReduce = kScaleDim / kNVecOut; static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); -#endif //__HIP_PLATFORM_AMD__ + +#ifdef __HIP_PLATFORM_AMD__ +// Host-side helper: computes shared memory size for a runtime tile dimension. +// Needed because the host determines tile_dim at runtime via cuda::sm_arch(). 
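+// Illustrative arithmetic (kNVecSMem == 2): smem_size_for_tile(64) =
+// 64 * 33 * 2 = 4224 elements and smem_size_for_tile(128) = 128 * 65 * 2 =
+// 16640 elements. Multiplied by sizeof(InputType) at the launch site this is
+// roughly 8 KiB / 33 KiB for BF16 inputs and 17 KiB / 65 KiB for FP32, hence
+// the >= 48 KiB dynamic shared memory opt-in guarded further below.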
+constexpr int smem_size_for_tile(int tile_dim) { + return tile_dim * ((tile_dim / kNVecSMem) + 1) * kNVecSMem; +} +#endif // for 2D block scaling, we need to reduce amax in warp static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = { @@ -414,38 +430,16 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, } } -template < -#ifdef __HIP_PLATFORM_AMD__ - int kTileDim_, -#endif - bool kReturnIdentity, bool kReturnTranspose, bool kIsE8Scaling, - bool kAligned, typename CType, typename IType, typename OType, typename ScaleType, - bool kSwizzledScale, bool kApplyStochasticRounding, bool kIs2DBlockScaling> -__global__ void __launch_bounds__( -#ifdef __HIP_PLATFORM_AMD__ - kTileDim_ * 2 -#else - kThreadsPerBlock -#endif -) block_scaled_1d_cast_transpose_kernel( +template +__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( const IType* const input, const float* global_amax, OType* const output_c, OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, const size_t row_length, const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const size_t kScaleBlockDim, const float epsilon, const size_t* rng_state, const float* noop_ptr) { - // Tile-dependent constants -#ifdef __HIP_PLATFORM_AMD__ - // Redirect kTileDim to the template parameter for runtime gfx942/gfx950 dispatch -#define kTileDim kTileDim_ - constexpr int kThreadsPerBlock = kTileDim * 2; - constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1; - constexpr int kNumThreadsLoad = kTileDim / kNVecIn; - constexpr int kNumThreadsStore = kTileDim / kNVecOut; - static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); - static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); -#endif - constexpr int kNVecContainer = kNVecOut / kNFP4PerContainer; using SMemVec = Vec; using OVec = Vec; @@ -819,9 +813,6 @@ __global__ void __launch_bounds__( } } } -#ifdef __HIP_PLATFORM_AMD__ -#undef kTileDim -#endif } } // namespace @@ -882,10 +873,9 @@ void quantize_transpose_vector_blockwise_fp4( using namespace transformer_engine::quantize_transpose_nvfp4; #ifdef __HIP_PLATFORM_AMD__ - // Runtime tile dimension selection: - // gfx942 (64 KB LDS): tile_dim=64, threads=128 - // gfx950 (128 KB LDS): tile_dim=128, threads=256 - const int tile_dim = (cuda::sm_arch() >= 95) ? 128 : 64; + // Tile dimension is selected at compile time based on the target architecture. + // The host still needs the runtime value for grid/smem computation. + const int tile_dim = (cuda::sm_arch() >= 95) ? kTileDimGfx950 : kTileDimGfx942; const size_t num_blocks_x = DIVUP(row_length, static_cast(tile_dim)); const size_t num_blocks_y = DIVUP(num_rows, static_cast(tile_dim)); #else @@ -906,36 +896,6 @@ void quantize_transpose_vector_blockwise_fp4( rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); } -#ifdef __HIP_PLATFORM_AMD__ - // Macro to instantiate and launch the kernel for a given tile dimension. - // TILE_DIM must be a compile-time constant (64 or 128). 
-#define LAUNCH_FP4_CT_KERNEL(TILE_DIM) \ - do { \ - constexpr int kTD = TILE_DIM; \ - size_t smem_bytes = smem_size_for_tile(kTD) * sizeof(InputType); \ - auto kernel = block_scaled_1d_cast_transpose_kernel< \ - kTD, kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, \ - float, InputType, OutputType, ScaleType, kSwizzledScale, \ - kApplyStochasticRounding, kIs2DBlockScaling>; \ - if (smem_bytes >= 48 * 1024) { \ - cudaError_t err = cudaFuncSetAttribute( \ - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); \ - NVTE_CHECK(err == cudaSuccess, \ - "Failed to set dynamic shared memory size."); \ - } \ - kernel<<>>( \ - reinterpret_cast(input.dptr), \ - reinterpret_cast(global_amax.dptr), \ - reinterpret_cast(output.dptr), \ - reinterpret_cast(output_t.dptr), \ - reinterpret_cast(scale_inv.dptr), \ - reinterpret_cast(scale_inv_t.dptr), row_length, \ - num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, \ - scale_t_stride_y, kScaleBlockDim, epsilon, rng_state, \ - noop_ptr); \ - } while (0) -#endif // __HIP_PLATFORM_AMD__ - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.dtype, InputType, @@ -972,14 +932,10 @@ void quantize_transpose_vector_blockwise_fp4( use_2d_quantization, kIs2DBlockScaling, #ifdef __HIP_PLATFORM_AMD__ - if (tile_dim == 64) { - LAUNCH_FP4_CT_KERNEL(64); - } else { - LAUNCH_FP4_CT_KERNEL(128); - } - ) // kIs2DBlockScaling + size_t smem_bytes = smem_size_for_tile(tile_dim) * sizeof(InputType); #else size_t smem_bytes = kSMemSize * sizeof(InputType); +#endif auto kernel = block_scaled_1d_cast_transpose_kernel< kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, float, InputType, OutputType, ScaleType, kSwizzledScale, @@ -990,8 +946,13 @@ void quantize_transpose_vector_blockwise_fp4( smem_bytes); NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size."); +#ifdef __HIP_PLATFORM_AMD__ + } kernel<<>>( +#else } kernel<<>>( +#endif reinterpret_cast(input.dptr), reinterpret_cast(global_amax.dptr), reinterpret_cast(output.dptr), @@ -1001,7 +962,6 @@ void quantize_transpose_vector_blockwise_fp4( num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, kScaleBlockDim, epsilon, rng_state, noop_ptr);) // kIs2DBlockScaling -#endif //__HIP_PLATFORM_AMD__ ) // kApplyStochasticRounding ) // kSwizzledScale ) // kAligned @@ -1010,10 +970,6 @@ void quantize_transpose_vector_blockwise_fp4( ) // OutputType ) // InputType -#ifdef __HIP_PLATFORM_AMD__ -#undef LAUNCH_FP4_CT_KERNEL -#endif - NVTE_CHECK_CUDA(cudaGetLastError()); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);