NVFP4 recipe with GEMM via BF16 dequant #518
Remove TODO regarding userbuffers
Userbuffer Enablement for ROCm
* Update Dockerfile to use ROCm TheRock
* Update wheels building script to work with ROCm TheRock and the latest Manylinux image
* Support default ROCm location /opt/rocm/core
* Fix UB code build on TheRock
* Support comma separated list of target GPU architectures
* Guess ROCm build from HIP_PLATFORM
ada428f to a91eaf0
const float fp8_max = te_fp8_fnuz() ? 240.0f : 448.0f;
const float factor_inv = 1.0f / (6.0f * fp8_max);
Same comment as above regarding using Numeric_Traits_fp8e4m3 here
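A minimal sketch of what this suggestion could look like, assuming a traits type that exposes the largest finite E4M3 value. `Fp8E4M3Traits`, `kFp4E2M1Max`, and `fp4_factor_inv` are hypothetical names used only for illustration; they are not Transformer Engine's actual `Numeric_Traits_fp8e4m3` definition.

```cpp
#include <cassert>

// Hypothetical stand-in for a traits type; not the library's real definition.
template <bool kIsFnuz>
struct Fp8E4M3Traits {
  // Largest finite E4M3 magnitude: 240 for the FNUZ encoding, 448 for OCP.
  static constexpr float kMaxNorm = kIsFnuz ? 240.0f : 448.0f;
};

// Largest finite FP4 E2M1 magnitude (the 6.0f literal in the diff).
constexpr float kFp4E2M1Max = 6.0f;

// Same expression as in the diff, 1 / (6 * fp8_max), with the FP8 maximum
// taken from the traits type instead of being hardcoded at the call site.
constexpr float fp4_factor_inv(bool is_fnuz) {
  return is_fnuz ? 1.0f / (kFp4E2M1Max * Fp8E4M3Traits<true>::kMaxNorm)
                 : 1.0f / (kFp4E2M1Max * Fp8E4M3Traits<false>::kMaxNorm);
}
```

With this shape, the call site would reduce to something like `fp4_factor_inv(te_fp8_fnuz())`, and the 240/448 constants live in one place.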
if (is_fp4_dtype(param.Atype)) {
  hip_bfloat16* a_bf16 = reinterpret_cast<hip_bfloat16*>(ws_ptr);
  ws_ptr += a_bf16_bytes;
  const int64_t total_a = static_cast<int64_t>(m) * k;
  const auto& a_sinv = (transa == CUBLAS_OP_T) ? inputA.scale_inv
                                               : inputA.columnwise_scale_inv;
  const int64_t a_num_cols = (transa == CUBLAS_OP_T)
                                 ? inputA.data.shape.back()
                                 : inputA.columnwise_data.shape.back();
  const int64_t a_scale_stride = (a_sinv.shape.size() >= 2) ? a_sinv.shape[1] : (a_num_cols / 16);
  launch_dequant_fp4_to_bf16(param.A, param.A_scale_inv, a_bf16, total_a,
                             a_num_cols, a_scale_stride, stream);
  param.A = a_bf16;
  param.Atype = DType::kBFloat16;
  param.A_scale_inv = nullptr;
}

if (is_fp4_dtype(param.Btype)) {
  hip_bfloat16* b_bf16 = reinterpret_cast<hip_bfloat16*>(ws_ptr);
  ws_ptr += b_bf16_bytes;
  const int64_t total_b = static_cast<int64_t>(k) * n;
  const auto& b_sinv = (transb == CUBLAS_OP_N) ? inputB.scale_inv
                                               : inputB.columnwise_scale_inv;
  const int64_t b_num_cols = (transb == CUBLAS_OP_N)
                                 ? inputB.data.shape.back()
                                 : inputB.columnwise_data.shape.back();
  const int64_t b_scale_stride = (b_sinv.shape.size() >= 2) ? b_sinv.shape[1] : (b_num_cols / 16);
  launch_dequant_fp4_to_bf16(param.B, param.B_scale_inv, b_bf16, total_b,
                             b_num_cols, b_scale_stride, stream);
  param.B = b_bf16;
  param.Btype = DType::kBFloat16;
  param.B_scale_inv = nullptr;
}
Minor comment: would it make sense to factor the repeated FP4→BF16 staging logic for A/B into a small helper? The two blocks look structurally similar, aside from the operand-specific shape/layout details.
Thanks, I factored this out into a lambda function in fae76d3
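For illustration, the shared staging steps could be factored roughly as follows. This is a host-only sketch under stated assumptions, not the code from fae76d3: `Operand`, `ScaleView`, `scale_stride`, and `stage_fp4_as_bf16` are hypothetical names, the dequantization launch is passed in as a callable so the example needs no GPU, and a function template is used instead of a lambda only to keep the block self-contained.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-ins for the operand state used in the diff.
enum class DType { kFloat4E2M1, kBFloat16 };

struct Operand {
  const void* data;
  const void* scale_inv;
  DType dtype;
};

struct ScaleView {
  std::vector<int64_t> shape;
};

// Stride of the scale tensor: its second dimension when it is at least 2-D,
// otherwise one scale per 16 elements of a row, as in the diff.
inline int64_t scale_stride(const ScaleView& sinv, int64_t num_cols) {
  return sinv.shape.size() >= 2 ? sinv.shape[1] : num_cols / 16;
}

// Stage one FP4 operand: carve a BF16 buffer out of the workspace, run the
// dequantization, then repoint the operand at the BF16 copy.
template <typename LaunchFn>
inline void stage_fp4_as_bf16(Operand& op, uint8_t*& ws_ptr, int64_t bf16_bytes,
                              int64_t total, int64_t num_cols,
                              const ScaleView& sinv, LaunchFn&& launch) {
  if (op.dtype != DType::kFloat4E2M1) return;  // nothing to stage
  void* bf16 = ws_ptr;
  ws_ptr += bf16_bytes;  // advance so the next operand gets its own region
  launch(op.data, op.scale_inv, bf16, total, num_cols,
         scale_stride(sinv, num_cols));
  op.data = bf16;
  op.dtype = DType::kBFloat16;
  op.scale_inv = nullptr;  // scales are already folded into the BF16 values
}
```

Only the operand-specific inputs (element count, column count, and which scale tensor applies for the given transpose flag) remain at the two call sites.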
aris134 left a comment
LGTM. I left one minor non-blocking suggestion, but this looks good to me overall.
  << "type_a" << "type_b" << "type_d" << "bias_type" << "aux_type"
  << "lda" << "ldb" << "ldd" << "scale_mode" << "epi" << "comp" << "scale_type"
- << "ws_min" << "ws_max" << "algo_id" << "aidx";
+ << "ws_min" << "ws_max" << "algo_id" << "aidx" << "fp4_alpha";
nit: please move it before ws_min. The last four parameters do not participate in the key.
std::getline(is, scale, csv_sep);
is >> ws_min >> c >> ws_max >> c >> algo_id >> c >> algo_idx;
int fp4_alpha = 0;
if (is.peek() == csv_sep) {
Not needed: by contract, the cache should be rebuilt with a new TE, so no backward compatibility is required.
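Taken together, the two comments on the cache format suggest a record tail in which `fp4_alpha` sits with the key fields, ahead of `ws_min`, and is always present. A minimal sketch of such a parser, assuming a comma separator and integer fields; `CacheTail` and `parse_cache_tail` are hypothetical names and the field types are guesses, not the PR's actual code.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>

// Hypothetical tail of one cache record. fp4_alpha is part of the key, so it
// precedes the four trailing non-key fields.
struct CacheTail {
  int fp4_alpha = 0;
  std::size_t ws_min = 0;
  std::size_t ws_max = 0;
  int algo_id = 0;
  int algo_idx = 0;
};

// Parse "fp4_alpha,ws_min,ws_max,algo_id,algo_idx". Every field is required:
// there is no peek() fallback for records written by an older version.
inline bool parse_cache_tail(const std::string& text, CacheTail& out,
                             char csv_sep = ',') {
  std::istringstream is(text);
  char c1 = 0, c2 = 0, c3 = 0, c4 = 0;
  is >> out.fp4_alpha >> c1 >> out.ws_min >> c2 >> out.ws_max >> c3 >>
      out.algo_id >> c4 >> out.algo_idx;
  return !is.fail() && c1 == csv_sep && c2 == csv_sep && c3 == csv_sep &&
         c4 == csv_sep;
}
```

A record in the old four-field layout fails to parse here, which matches the stated contract that the cache is rebuilt rather than migrated.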
  return tile_dim * ((tile_dim / kNVecSMem) + 1) * kNVecSMem;
}
#else
constexpr int kTileDim = 128;
kTileDim is already declared at line 143, and kThreadsPerBlock at line 148.
#ifdef __HIP_PLATFORM_AMD__
// On AMD, kTileDim_ is a template parameter of the kernel for runtime dispatch:
// gfx942: kTileDim_=64 (64 KB LDS, kThreadsPerBlock=128, 4 warps)
// gfx950: kTileDim_=128 (128 KB LDS, kThreadsPerBlock=256, 8 warps)
If the values are hardcoded per platform, they do not need to be template parameters; they can be constexpr values guarded by platform-specific ifdefs.
Replaced templates with constexpr parameters in a6f4787
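The pattern the reviewer describes could look roughly like this. It is a sketch, not the code from a6f4787: `EXAMPLE_TARGET_GFX950` is a hypothetical build-time macro standing in for whatever architecture guard the build actually uses, the value of `kNVecSMem` is assumed, and `padded_tile_elems` is a made-up name for the padded shared-memory size expression quoted earlier in the review.

```cpp
#include <cassert>

// EXAMPLE_TARGET_GFX950 is a hypothetical macro for this sketch; the real
// build would use its own architecture guard.
#if defined(EXAMPLE_TARGET_GFX950)
constexpr int kTileDim = 128;          // gfx950 values from the source comment
constexpr int kThreadsPerBlock = 256;  // 8 warps
#else
constexpr int kTileDim = 64;           // gfx942 values from the source comment
constexpr int kThreadsPerBlock = 128;  // 4 warps
#endif

constexpr int kNVecSMem = 2;  // assumed vector width; not taken from the PR

// Same expression as in the diff: each tile row gets one extra vector of
// kNVecSMem elements of padding.
constexpr int padded_tile_elems(int tile_dim) {
  return tile_dim * ((tile_dim / kNVecSMem) + 1) * kNVecSMem;
}

// With plain constexpr values, configuration mistakes surface at compile time.
static_assert(kThreadsPerBlock == 2 * kTileDim,
              "both configurations in the source comment use 2 threads per tile row");
```

Compared with a template parameter, this keeps a single kernel instantiation per build target and removes the runtime dispatch, at the cost of fixing the choice at compile time.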
0f240ad to 9ed88ff
9ed88ff to a6f4787
Description
Part of https://github.com/ROCm/frameworks-internal/issues/15682