
[WIP] TDM porting#558

Draft
wangye805 wants to merge 32 commits into npi_gfx1250 from yewang12/tdm_port_npi

Conversation

@wangye805
Collaborator

Description

Do not review, I'm just trying to play with claude for the iterations. Tracking https://github.com/ROCm/frameworks-internal/issues/16226

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Initial TDM porting

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@wangye805 wangye805 marked this pull request as draft April 22, 2026 21:49
Collaborator Author

@wangye805 wangye805 left a comment


As I mentioned, also port TDM to NV upstream's flow, which was previously guarded out because we don't have TMA.

Comment thread transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip
Comment thread transformer_engine/common/util/rocm_cast_gated_kernels.cuh Outdated
Collaborator Author

@wangye805 wangye805 left a comment


Please forget about the previous ROCm-specific cast-transpose kernel logic. Here I want you to closely follow NV upstream's behavior and do a TDM port as their TMA equivalent.

Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
wangye805 and others added 2 commits April 23, 2026 14:50
…ory comments

- Remove input_act_stride/output_stride as kernel params in gated kernels;
  compute them inside the kernel from cols and IS_DGATED template param
- Add comments explaining why TDM does not need in_transaction_size
  (uses s_wait_tensorcnt counting ops, not mbarrier counting bytes)

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
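Moving the stride computation into the kernel can be sketched as below. The formulas are an assumption for illustration only (a gated tensor storing act and gate side by side along columns); the real kernels may use a different layout, and `input_act_stride`/`output_stride` here are stand-ins for whatever expressions the kernel actually computes.

```cpp
#include <cassert>
#include <cstddef>

// Illustrative sketch: derive strides from cols and the IS_DGATED
// template parameter instead of passing them as kernel params.
template <bool IS_DGATED>
constexpr size_t input_act_stride(size_t cols) {
  // gated input holds act and gate concatenated along columns
  return 2 * cols;
}

template <bool IS_DGATED>
constexpr size_t output_stride(size_t cols) {
  // dgated backward writes both dact and dgate; forward writes act only
  return IS_DGATED ? 2 * cols : cols;
}
```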
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
…ests

- Fix `) {` placement to minimize diff in gated kernel signatures
- Fix MXFP8 gated kernel: remove unnecessary pre-loop wait, make
  in-loop wait conditional to preserve double-buffering prefetch
- Add comments explaining TDM does not need mbarrier destroy
- Add NVTE_USE_NV_UPSTREAM_FLOW=1 ctest run in ci/core.sh to exercise
  TDM kernel paths for MXFP8 quantize, gated, and dequantize

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
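The double-buffering shape this commit preserves can be modeled with a small host-side trace; `schedule`, the buffer rotation, and the load/wait/compute events are illustrative stand-ins for the kernel's real primitives, not the actual API.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Host-side sketch of the prefetch pattern described above: no pre-loop
// wait, and each iteration kicks off the next tile's load into the other
// buffer before waiting, so a load is always in flight during compute.
std::vector<std::string> schedule(int iterations) {
  std::vector<std::string> trace;
  trace.push_back("load buf0 <- tile0");  // initial prefetch, no wait yet
  for (int it = 0; it < iterations; ++it) {
    const int buf = it % 2;
    if (it + 1 < iterations) {
      // start the next tile before waiting on the current one
      trace.push_back("load buf" + std::to_string(1 - buf) +
                      " <- tile" + std::to_string(it + 1));
    }
    trace.push_back("wait buf" + std::to_string(buf));
    trace.push_back("compute tile" + std::to_string(it));
  }
  return trace;
}
```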
@wenchenvincent wenchenvincent changed the title from "[WIP] TDM porting using claude" to "[WIP] TDM porting" on Apr 23, 2026
wangye805 and others added 4 commits April 24, 2026 12:16
When a double-buffered prefetch tile origin falls past the tensor boundary
(non-tile-aligned rows/cols), tensor_h - tile_row and tensor_w - tile_col
would underflow as uint32_t to ~4 billion, causing the TDM hardware to
attempt a DMA of billions of rows and trigger a GPU page fault.

Clamp the remaining extent to 0 when tile_row >= tensor_h or
tile_col >= tensor_w. Unlike NV TMA (which encodes full tensor shape in a
host-side CUtensorMap and clamps automatically), TDM computes the remaining
extent per-call, so the caller must guard against out-of-bounds origins.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
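The underflow and its clamp can be demonstrated on the host; `remaining_rows` below is an illustrative helper, not the actual TDM call, but the arithmetic matches the bug described above.

```cpp
#include <cassert>
#include <cstdint>

// Sketch of the guard described above: clamp the remaining extent to 0
// when the tile origin is already past the tensor boundary, instead of
// letting unsigned subtraction wrap to ~4 billion.
uint32_t remaining_rows(uint32_t tensor_h, uint32_t tile_row) {
  return tile_row >= tensor_h ? 0u : tensor_h - tile_row;
}
```

Without the guard, `tensor_h - tile_row` for `tensor_h = 100`, `tile_row = 128` wraps to 4294967268 as `uint32_t`, which is exactly the billions-of-rows DMA the commit describes.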
Introduce HIPTensorMap/HIPTensorMapOut structs in tdm.cuh as the AMD
analog of CUtensorMap. Callers in cast_kernels.cuh and
cast_gated_kernels.cuh now construct one descriptor per tensor at kernel
entry and pass it to TDM helper calls instead of repeating 6+ raw
scalars at every call site.

Revert TDM usage in rocm_cast_kernels.cuh, rocm_cast_gated_kernels.cuh,
and rocm_dequantize_kernels.cuh back to the original HIP vectorized
copy_2d_to_shared / bulk_tensor_2d_shared_to_global path. The rocm_*
kernels are the legacy non-TDM path; TDM is used only in the NV-upstream
ported kernels (cast_kernels.cuh / cast_gated_kernels.cuh) behind
NVTE_USE_NV_UPSTREAM_FLOW.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
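A minimal sketch of what a HIPTensorMap-style descriptor bundles is shown below. The field names and layout are assumptions for illustration, not the real struct; the point is replacing 6+ raw scalars per call site with one descriptor built once at kernel entry.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative 2D tensor descriptor in the spirit of HIPTensorMap,
// the AMD analog of CUtensorMap described above (fields are assumed).
struct TensorMap2D {
  void *globalAddress;    // base pointer in global memory
  uint32_t globalDim[2];  // tensor height, width in elements
  uint32_t globalStride;  // row stride in elements
  uint32_t boxDim[2];     // shared-memory tile height, width
  uint32_t elementSize;   // bytes per element
};
```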
Remove the tdm.cuh include, TDM_SHMEM_ALIGNMENT usage, and any
whitespace changes introduced in the previous commit, so rocm_cast_kernels.cuh,
rocm_cast_gated_kernels.cuh, and rocm_dequantize_kernels.cuh are
byte-for-byte identical to 5e8d61e (Ilya's branch point).

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…ain kernel functions

On AMD, each mxfp8 quantize/dequantize/gated function previously
dispatched between TDM and ROCm kernels via an inline env-var check.
This refactor separates the two flows cleanly:

- cast_gated_kernels.cuh / rocm_cast_gated_kernels.cuh:
  rocm_cast_mxfp8_gated() hosts the ROCm HIP gated kernel dispatch.
  cast_mxfp8_gated() is now TDM-only on AMD.
  quantize_gated() dispatches via NVTE_USE_NV_UPSTREAM_FLOW env var.

- cast_kernels.cuh / rocm_cast_kernels.cuh:
  rocm_mxfp8_quantize() hosts the ROCm HIP cast kernel dispatch.
  mxfp8_quantize() is now TDM-only on AMD.
  fp8_quantize_rocm() dispatches via NVTE_USE_NV_UPSTREAM_FLOW env var.

- dequantize_kernels.cuh / rocm_dequantize_kernels.cuh:
  rocm_mxfp8_dequantize() hosts the ROCm HIP dequantize dispatch.
  mxfp8_dequantize() is now TDM-only on AMD.
  dequantize_helper() dispatches via NVTE_USE_NV_UPSTREAM_FLOW env var.

NV upstream path (no AMD) is unchanged throughout.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh
wangye805 and others added 2 commits April 26, 2026 09:41
- Rename namespace nv_flow -> tma_flow (more accurate: both TMA on NV
  and TDM on AMD use this path)
- Rename env-var NVTE_USE_NV_UPSTREAM_FLOW -> NVTE_USE_TDM_FLOW with
  inverted default: 0 = ROCm flow (default), 1 = TDM flow
- Apply same env-var dispatch to fp8 gated path (was missing)
- Remove dead AMD-specific guards around ScalingType, BUFF_DIM,
  blocks, THREADS_PER_CHUNK, grid, block_size in cast_mxfp8_gated
- Remove AMD-specific {} wrapper and duplicate shmem computation block;
  TMA_SHMEM_ALIGNMENT == TDM_SHMEM_ALIGNMENT == 128 so NV upstream
  formula works on both platforms

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
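The renamed dispatch with its inverted default can be sketched as a tiny host helper; the `getenv`-based parsing is an assumption about how the real check is implemented, but the rule matches the commit: unset or 0 means ROCm flow, nonzero means TDM flow.

```cpp
#include <cassert>
#include <cstdlib>

// Sketch of the NVTE_USE_TDM_FLOW dispatch described above:
// default 0 = ROCm flow, 1 = TDM flow.
bool use_tdm_flow() {
  const char *v = std::getenv("NVTE_USE_TDM_FLOW");
  return v != nullptr && std::atoi(v) != 0;
}
```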
Move swizzled_group_idx/swizzled_idx/shmem_offset_rowwise back to just
before out_act.store_to(), matching NV upstream cast_gated_kernels.cuh
lines 831-834, to minimize diff.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_gated_kernels.cuh Outdated
wangye805 and others added 5 commits April 26, 2026 10:00
Two differences found vs NV upstream (line 968):

1. out_gate_mem: AMD TDM kernel always needs a gate shmem buffer
   regardless of IS_DGATED (kernel signature always includes gate
   output pointers), so restore AMD-specific:
     out_gate_mem = buff_size_aligned_out  (always)
   vs NV:
     out_gate_mem = IS_DGATED ? buff_size_aligned_out : 0

2. in_mem: split into in_act_mem + in_gate_mem intermediate vars
   to match NV upstream style exactly.

3. AMD TDM dispatch: restore TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH
   dispatch (was accidentally dropped when removing the {} wrapper),
   guarded under #ifdef __HIP_PLATFORM_AMD__. NV uses switch(scaling_type).

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Replace TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH with the same
switch(scaling_type) structure as NV upstream. The TDM kernel shares
the same ROWWISE_SCALING/COLWISE_SCALING/THREADS_PER_CHUNK template
params as the NV kernel — SCALE_DIM_Y/X/IS_ALIGNED were ROCm-flow
params that don't apply here.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…ated_kernel

next_buff, next_stage_offset_Y, global_offset_Y, global_offset_X,
next_buff_offset are identical in both the TMA and TDM branches —
declare them once above the #ifndef __HIP_PLATFORM_AMD__ guard.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
The shmem size calculation is identical for TDM and TMA paths
(TDM_SHMEM_ALIGNMENT == TMA_SHMEM_ALIGNMENT == 128), so declare it
once above the #ifdef __HIP_PLATFORM_AMD__ guard. Only the pointer
setup and kernel launch remain platform-specific.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Both AMD and NV call mxfp8_kernel::cast_mxfp8_gated_kernel — the TDM
kernel at line 435 is also inside namespace mxfp8_kernel. The two
switch blocks were identical except for the namespace qualifier, so
remove the #ifdef and keep one unified switch block.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
wangye805 and others added 2 commits April 26, 2026 10:51
…OCM_

- Remove namespace tma_flow from cast_gated_kernels.cuh; constants and
  cast_fp8_gated_kernel now live directly in gated_kernels namespace,
  consistent with mxfp8_kernel::cast_mxfp8_gated_kernel organization
- In rocm_cast_gated_kernels.cuh, prefix all constants that conflict
  with the now-unnamespaced tma_flow constants with ROCM_:
  ROCM_CHUNK_DIM_Y/X, ROCM_THREADS_PER_CHUNK, ROCM_THREADS_PER_CHUNK_X/Y,
  ROCM_BUFFERS_NUM, ROCM_BUFFER_DIM_Y/X, ROCM_SHMEM_DIM_Y/X,
  ROCM_BUFFER_STAGES_NUM, ROCM_ITERATIONS
- Remove duplicate sigmoidf definition from rocm_cast_gated_kernels.cuh
  (already defined in cast_gated_kernels.cuh which includes it)

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…e kernels

Same cleanup as cast_gated_kernels.cuh: prefix ROCm-flow constants with
ROCM_ in rocm_cast_kernels.cuh and rocm_dequantize_kernels.cuh to
disambiguate from TDM-flow constants sharing the same namespace.

- rocm_cast_kernels.cuh: rename ELEMS_PER_THREAD, THREADS_PER_CHUNK_X_ROWWISE,
  THREADS_PER_CHUNK_Y_ROWWISE, THREADS_PER_CHUNK_X_COLWISE, TILE_DIM → ROCM_*
- rocm_dequantize_kernels.cuh: rename all constants in dequantization namespace
  (CHUNK_DIM_Y/X, THREADS_PER_CHUNK, BUFFERS_NUM, ELEMS_PER_THREAD,
  BUFFER_DIM_Y/X, SHMEM_DIM_Y/X, THREADS_PER_CHUNK_X_*, ITERATIONS) → ROCM_*
- dequantize_kernels.cuh: add TDM-flow constants directly into dequantization
  namespace (CHUNK_DIM_Y/X, THREADS_PER_CHUNK, BUFFERS_NUM, ELEMS_PER_THREAD,
  BUFFER_DIM_Y, SHMEM_DIM_Y/X, THREADS_PER_CHUNK_X_*, ITERATIONS) so the
  TDM/NV kernel is self-contained and no longer depends on the ROCm include

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Comment thread transformer_engine/common/util/cast_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_kernels.cuh
Comment thread transformer_engine/common/util/cast_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_kernels.cuh Outdated
Comment thread transformer_engine/common/util/dequantize_kernels.cuh Outdated
…ernels.cuh

- Remove AMD-only template params (SCALE_DIM_Y_TMPL, SCALE_DIM_X_TMPL, IS_ALIGNED)
  from cast_mxfp8_2D_kernel signature; pass raw pointers from launcher instead
- Revert f16 amax computation to __hmax/__habs with thread_amax_f16 (NV upstream pattern)
- Hoist next-stage offset variables above #ifndef to avoid duplicate declarations
- Remove #ifndef guard around ptx::floatx2 block_scale_inverse_2x (works on HIP)
- Fix __shared__ alignas(...) order for AMD shared memory declarations
- Replace AMD TDM launcher with cast_gated_kernels.cuh raw-pointer pattern
- Remove unnecessary TRANSFORMER_ENGINE_SWITCH_CONDITION IS_ALIGNED wrapper
  from AMD TDM dequantize launcher to match NV upstream structure

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Comment thread transformer_engine/common/util/cast_kernels.cuh Outdated
Comment thread transformer_engine/common/util/cast_kernels.cuh
Comment thread transformer_engine/common/util/dequantize_kernels.cuh Outdated
Comment thread transformer_engine/common/util/dequantize_kernels.cuh Outdated
- Remove #ifdef guard from shmem declarations in cast_mxfp8_2D_kernel:
  TDM_SHMEM_ALIGNMENT == TMA_SHMEM_ALIGNMENT == 128, use TMA_SHMEM_ALIGNMENT
  throughout to minimize diff with NV upstream
- Update fp8_quantize AMD path to check NVTE_USE_TDM_FLOW: if set, call
  fp8_quantize_arch_ge_100 (TDM kernel); otherwise fall back to fp8_quantize_rocm
- Remove #ifdef and alignas swap in dequantize_mxfp8_kernel shmem declarations:
  use __shared__ alignas(TMA_SHMEM_ALIGNMENT) for both platforms
- Replace NVTE_USE_NV_UPSTREAM_FLOW with NVTE_USE_TDM_FLOW in dequantize_helper
  AMD path to match cast_gated_kernels.cuh pattern

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Comment on lines +797 to +799
const size_t next_buff = next_iter % FP8_BUFFERS_NUM;
const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
Collaborator Author


Consolidate with lines 783 to 785?

wangye805 and others added 11 commits April 26, 2026 16:08
…2D_kernel

next_buff, chunk_it_offset_y, chunk_it_offset_x were duplicated in both
the TMA and TDM prefetch branches. Hoist above #ifndef to declare once,
matching the pattern from cast_gated_kernels.cuh.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…ls.cuh

ROCM_ rename was applied twice to BUFFER_DIM_Y, SHMEM_DIM_Y, SHMEM_DIM_X,
creating ROCM_ROCM_* definitions while usages only had single ROCM_ prefix.
Also BUFFERS_NUM and BUFFER_DIM_X at the shmem size calculation were never
renamed to their ROCM_ equivalents.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…tion

- common.h: add TMA_SHMEM_ALIGNMENT as alias for TDM_SHMEM_ALIGNMENT in AMD
  block so cast_gated_kernels.cuh launcher code compiles without ifdefs
- rocm_cast_gated_kernels.cuh: define sigmoidf device inline since HIP
  runtime does not provide it (mirrors the CUDA definition in cast_gated_kernels.cuh)

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…e_arch_ge_100 call

fp8_quantize_arch_ge_100 is guarded by #ifndef __HIP_PLATFORM_AMD__ (NV TMA only).
AMD branch should delegate entirely to fp8_quantize_rocm, which internally
dispatches to mxfp8_quantize (TDM path) or rocm_mxfp8_quantize based on
NVTE_USE_TDM_FLOW. Also rename NVTE_USE_NV_UPSTREAM_FLOW to NVTE_USE_TDM_FLOW
in rocm_cast_kernels.cuh to match the unified env var.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
The AMD section of mxfp8_quantize only sets up raw pointers and never
launches a kernel, so calling it directly from quantize_helper left the
scale buffer zero-initialized. fp8_quantize_rocm already has the correct
TDM/plain-ROCm dispatch logic; route AMD through it instead.

Fixes 1110 FusedCastMXFP8TestSuite failures on gfx950 (NVTE_USE_TDM_FLOW=0).
The rowwise scale tensor is allocated with stride padded to
scale_tensor_alignment_X_rowwise (4), but rocm_mxfp8_dequantize was
computing scales_stride = DIVUP(cols, 32) (unpadded). From row 1
onward the kernel reads the wrong scale, producing inf/garbage output.

Fix: use DIVUP_TO_MULTIPLE(..., scale_tensor_alignment_X_rowwise),
matching the allocation in the test harness and the NV dequantize path.

Fixes 6 DequantizeMXFP8TestSuite failures (65x96, block_size=(1,32))
on gfx950.
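The stride fix can be checked with a small host-side model. The `DIVUP`/`DIVUP_TO_MULTIPLE` definitions below are the usual ceil-division helpers and are assumed, as is applying the padding to the per-row scale count; the 96-column, alignment-4 numbers come from the failing case above.

```cpp
#include <cassert>
#include <cstddef>

// Assumed ceil-division helpers mirroring the names in the commit.
size_t DIVUP(size_t a, size_t b) { return (a + b - 1) / b; }
size_t DIVUP_TO_MULTIPLE(size_t a, size_t b) { return DIVUP(a, b) * b; }

// Buggy vs fixed scales_stride for 1x32 MXFP8 blocks with a rowwise
// scale alignment of 4, per the commit message.
size_t stride_unpadded(size_t cols) { return DIVUP(cols, 32); }
size_t stride_padded(size_t cols, size_t align_x) {
  return DIVUP_TO_MULTIPLE(DIVUP(cols, 32), align_x);
}
```

For the failing 65x96 case the unpadded stride is 3 while the allocation uses 4, so from row 1 onward every read is offset by one scale per row.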
The NVTE_USE_TDM_FLOW=1 branches in rocm_cast_kernels.cuh,
cast_gated_kernels.cuh, and dequantize_kernels.cuh called TDM/TMA
kernel paths (mxfp8_quantize, cast_mxfp8_gated, mxfp8_dequantize) that
are no-ops on non-gfx1250 AMD — their device code is wrapped in
#if defined(__gfx1250__) so nothing executes, leaving scales at zero.

Wrap the TDM flow selection in #if defined(__HIP_PLATFORM_AMD__) &&
defined(__gfx1250__), falling back to the plain ROCm kernels
(rocm_mxfp8_quantize, rocm_cast_mxfp8_gated, rocm_mxfp8_dequantize)
on all other AMD architectures.

All 2748 tests now pass with NVTE_USE_TDM_FLOW=1 on gfx950.
The AMD section of mxfp8_quantize set up raw pointers but never launched
the kernel — all TDM quantize calls were silent no-ops. Add the kernel
launch switch (ROWWISE/COLWISE/BIDIMENSIONAL) mirroring the NV path, using
raw pointers and TDM shared-memory sizing.

Also fix host-side TDM dispatch guards: replace device-only __gfx1250__
with CMake-injected NVTE_ARCH_HAS_TDM (visible to host compilation) plus
a runtime cuda::sm_arch_name() check, matching the ARCH_HAS_STOCHASTIC_ROUNDING
pattern from PR #472. This ensures gfx942/950-only builds compile cleanly
and multi-arch builds running on non-gfx1250 hardware fall back to the
ROCm path even when NVTE_USE_TDM_FLOW=1.

Add debug fprintf/printf traces across all dispatch and kernel entry points
to confirm which code path executes at runtime.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
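The two-level guard can be sketched as follows. `NVTE_ARCH_HAS_TDM` stands in for the CMake-injected macro and the gfx-name comparison stands in for the `cuda::sm_arch_name()` runtime check; both are assumptions for illustration of the compile-time-plus-runtime pattern described above.

```cpp
#include <cassert>
#include <string>

// Would come from CMake in the real build; defined here for the sketch.
#define NVTE_ARCH_HAS_TDM 1

bool tdm_dispatch_allowed(const std::string &arch) {
#if NVTE_ARCH_HAS_TDM
  // compile-time: the build contains TDM device code;
  // runtime: only gfx1250 hardware actually executes it
  return arch == "gfx1250";
#else
  (void)arch;
  return false;  // TDM code not compiled in; always take the ROCm path
#endif
}
```

This keeps gfx942/950-only builds compiling cleanly (the `#if` branch is gone entirely) while multi-arch builds running on non-gfx1250 hardware fall back at runtime.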
The switch-case for cast_mxfp8_2D_kernel is identical on AMD (TDM) and NV
(TMA) — only the first four args differ (raw pointers vs CUtensorMap). Move
the shared dshmem sizing and switch-case after the #ifdef block so there is
a single launch path. The #ifdef now only covers platform-specific setup:
raw pointer casts on AMD, create_2D_tensor_map descriptors on NV.

TMA_SHMEM_ALIGNMENT is aliased to TDM_SHMEM_ALIGNMENT (both 128) so the
shmem calculation is correct on both platforms without a separate formula.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…or TDM path

- Move `using namespace mxfp8_kernel` outside `#ifndef __HIP_PLATFORM_AMD__` so
  tiling constants (CHUNK_DIM_Y/X, SCALE_DIM_X, BUFFS_NUM) are in scope on AMD
- Guard all three `cudaFuncSetAttribute` calls with `#ifndef __HIP_PLATFORM_AMD__`
  since HIP cannot take the address of a templated kernel function the same way;
  dynamic shmem size is still correctly passed via <<<grid, block, dshmem, stream>>>
- Add `__device__ __forceinline__` overloads of `__habs` and `__hmax` for
  `hip_bfloat16` (TE's bf16 alias) because ROCm only defines them for
  `__hip_bfloat16`, a distinct type on this ROCm version

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
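The overload pattern from the last bullet can be illustrated on the host with a stand-in type. `bf16` below is a plain float wrapper, not `hip_bfloat16`, and the helper names carry a trailing underscore to avoid any clash with real intrinsics; on device the real overloads would be `__device__ __forceinline__` and operate on the actual bf16 type.

```cpp
#include <cassert>
#include <cmath>

// Stand-in for hip_bfloat16, used only to show the overload shape.
struct bf16 { float v; };

inline bf16 habs_(bf16 x) { return bf16{std::fabs(x.v)}; }
inline bf16 hmax_(bf16 a, bf16 b) { return bf16{a.v > b.v ? a : b}; }
```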