diff --git a/csrc/rocm/skinny_gemms_int4.cu b/csrc/rocm/skinny_gemms_int4.cu
index 823d2822035c..3fb1ced9661f 100644
--- a/csrc/rocm/skinny_gemms_int4.cu
+++ b/csrc/rocm/skinny_gemms_int4.cu
@@ -263,54 +263,81 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
                         *(const half2*)&BIAS_HI);
           }
         } else {
-          // ExLlama shuffle: 8 unsigned int4 values packed per uint32
-          // Bit layout: [v0,v2,v4,v6] in low 16 bits,
-          //             [v1,v3,v5,v7] in high 16 bits
-          // Symmetric uses a fixed zero point of -8.
-          constexpr int ZP_BIAS = HAS_ZERO_POINTS ? 0 : 8;
+          // BF16 path: dequant int4 → fp16 via bit manipulation
+          // (same as fp16 path), convert bf16 activations → fp16,
+          // then use native fdot2 for the dot product.
+          // This is ~2.5× fewer ALU ops than scalar bf16→f32 path.
+          constexpr uint32_t BF16_FP16_MAGIC = 0x64006400u;
+          constexpr uint32_t BF16_BIAS_LO =
+              HAS_ZERO_POINTS ? 0x64006400u : 0x64086408u;
+          constexpr uint32_t BF16_SCALE16 = 0x2C002C00u;
+          constexpr uint32_t BF16_BIAS_HI =
+              HAS_ZERO_POINTS ? 0xD400D400u : 0xD480D480u;
+
+          // Temporary fp16 storage for dequantized weights
+          half cvtW_h[A_CHUNK];
+
   #pragma unroll
           for (uint32_t w = 0; w < A_CHUNK / 8; w++) {
             uint32_t qa = bigB[y][k2].u32[w];
-            cvtB.h[w * 8 + 0] = (scalar_t)((int)(qa & 0xF) - ZP_BIAS);
-            cvtB.h[w * 8 + 1] =
-                (scalar_t)((int)((qa >> 16) & 0xF) - ZP_BIAS);
-            cvtB.h[w * 8 + 2] =
-                (scalar_t)((int)((qa >> 4) & 0xF) - ZP_BIAS);
-            cvtB.h[w * 8 + 3] =
-                (scalar_t)((int)((qa >> 20) & 0xF) - ZP_BIAS);
-            cvtB.h[w * 8 + 4] =
-                (scalar_t)((int)((qa >> 8) & 0xF) - ZP_BIAS);
-            cvtB.h[w * 8 + 5] =
-                (scalar_t)((int)((qa >> 24) & 0xF) - ZP_BIAS);
-            cvtB.h[w * 8 + 6] =
-                (scalar_t)((int)((qa >> 12) & 0xF) - ZP_BIAS);
-            cvtB.h[w * 8 + 7] =
-                (scalar_t)((int)((qa >> 28) & 0xF) - ZP_BIAS);
+            uint32_t lo0 = (qa & 0x000F000Fu) | BF16_FP16_MAGIC;
+            uint32_t hi0 = (qa & 0x00F000F0u) | BF16_FP16_MAGIC;
+            qa >>= 8;
+            uint32_t lo1 = (qa & 0x000F000Fu) | BF16_FP16_MAGIC;
+            uint32_t hi1 = (qa & 0x00F000F0u) | BF16_FP16_MAGIC;
+
+            *(half2*)&cvtW_h[w * 8 + 0] =
+                __hsub2(*(half2*)&lo0, *(const half2*)&BF16_BIAS_LO);
+            *(half2*)&cvtW_h[w * 8 + 2] =
+                __hfma2(*(half2*)&hi0, *(const half2*)&BF16_SCALE16,
+                        *(const half2*)&BF16_BIAS_HI);
+            *(half2*)&cvtW_h[w * 8 + 4] =
+                __hsub2(*(half2*)&lo1, *(const half2*)&BF16_BIAS_LO);
+            *(half2*)&cvtW_h[w * 8 + 6] =
+                __hfma2(*(half2*)&hi1, *(const half2*)&BF16_SCALE16,
+                        *(const half2*)&BF16_BIAS_HI);
           }
-        }
-        if constexpr (HAS_ZERO_POINTS && GROUP_SIZE > 0) {
-          uint32_t group_idx = k_ / GROUP_SIZE;
-          scalar_t zp = zero_points[(m + y) * num_groups + group_idx];
+          if constexpr (HAS_ZERO_POINTS && GROUP_SIZE > 0) {
+            uint32_t group_idx = k_ / GROUP_SIZE;
+            float zp_f =
+                __s2float(zero_points[(m + y) * num_groups + group_idx]);
+            half zp_h = __float2half(zp_f);
   #pragma unroll
-          for (uint32_t b = 0; b < A_CHUNK; b++) {
-            cvtB.h[b] = cvtB.h[b] - zp;
+            for (uint32_t b = 0; b < A_CHUNK; b++) {
+              cvtW_h[b] = cvtW_h[b] - zp_h;
+            }
           }
-        }
-        if constexpr (GROUP_SIZE > 0) {
-          float partial = 0;
+          // Convert bf16 activations to fp16 and dot product via fdot2
+          if constexpr (GROUP_SIZE > 0) {
+            float partial = 0;
   #pragma unroll
-          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-            DOT2C(partial, bigA[n][k2].f[b], cvtB.f[b])
-          }
-          uint32_t group_idx = k_ / GROUP_SIZE;
-          sum[n][y] +=
-              partial * __s2float(scale[(m + y) * num_groups + group_idx]);
-        } else {
+            for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+              half a_h0 =
+                  __float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 0]));
+              half a_h1 =
+                  __float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 1]));
+              half2 a_pair = {a_h0, a_h1};
+              half2 w_pair = *(half2*)&cvtW_h[b * 2];
+              partial =
+                  __builtin_amdgcn_fdot2(a_pair, w_pair, partial, false);
+            }
+            uint32_t group_idx = k_ / GROUP_SIZE;
+            sum[n][y] += partial *
+                         __s2float(scale[(m + y) * num_groups + group_idx]);
+          } else {
   #pragma unroll
-          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-            DOT2C(sum[n][y], bigA[n][k2].f[b], cvtB.f[b])
+            for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+              half a_h0 =
+                  __float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 0]));
+              half a_h1 =
+                  __float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 1]));
+              half2 a_pair = {a_h0, a_h1};
+              half2 w_pair = *(half2*)&cvtW_h[b * 2];
+              sum[n][y] =
+                  __builtin_amdgcn_fdot2(a_pair, w_pair, sum[n][y], false);
+            }
           }
         }
       }
@@ -527,54 +527,81 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
                         *(const half2*)&BIAS_HI);
           }
         } else {
-          // ExLlama shuffle: 8 unsigned int4 values packed per uint32
-          // Bit layout: [v0,v2,v4,v6] in low 16 bits,
-          //             [v1,v3,v5,v7] in high 16 bits
-          // Symmetric uses a fixed zero point of -8.
-          constexpr int ZP_BIAS = HAS_ZERO_POINTS ? 0 : 8;
+          // BF16 path: dequant int4 → fp16 via bit manipulation
+          // (same as fp16 path), convert bf16 activations → fp16,
+          // then use native fdot2 for the dot product.
+          // This is ~2.5× fewer ALU ops than scalar bf16→f32 path.
+          constexpr uint32_t BF16_FP16_MAGIC = 0x64006400u;
+          constexpr uint32_t BF16_BIAS_LO =
+              HAS_ZERO_POINTS ? 0x64006400u : 0x64086408u;
+          constexpr uint32_t BF16_SCALE16 = 0x2C002C00u;
+          constexpr uint32_t BF16_BIAS_HI =
+              HAS_ZERO_POINTS ? 0xD400D400u : 0xD480D480u;
+
+          // Temporary fp16 storage for dequantized weights
+          half cvtW_h[A_CHUNK];
+
   #pragma unroll
           for (uint32_t w = 0; w < A_CHUNK / 8; w++) {
             uint32_t qa = bigB[y][k2].u32[w];
-            cvtB.h[w * 8 + 0] = (scalar_t)((int)(qa & 0xF) - ZP_BIAS);
-            cvtB.h[w * 8 + 1] =
-                (scalar_t)((int)((qa >> 16) & 0xF) - ZP_BIAS);
-            cvtB.h[w * 8 + 2] =
-                (scalar_t)((int)((qa >> 4) & 0xF) - ZP_BIAS);
-            cvtB.h[w * 8 + 3] =
-                (scalar_t)((int)((qa >> 20) & 0xF) - ZP_BIAS);
-            cvtB.h[w * 8 + 4] =
-                (scalar_t)((int)((qa >> 8) & 0xF) - ZP_BIAS);
-            cvtB.h[w * 8 + 5] =
-                (scalar_t)((int)((qa >> 24) & 0xF) - ZP_BIAS);
-            cvtB.h[w * 8 + 6] =
-                (scalar_t)((int)((qa >> 12) & 0xF) - ZP_BIAS);
-            cvtB.h[w * 8 + 7] =
-                (scalar_t)((int)((qa >> 28) & 0xF) - ZP_BIAS);
+            uint32_t lo0 = (qa & 0x000F000Fu) | BF16_FP16_MAGIC;
+            uint32_t hi0 = (qa & 0x00F000F0u) | BF16_FP16_MAGIC;
+            qa >>= 8;
+            uint32_t lo1 = (qa & 0x000F000Fu) | BF16_FP16_MAGIC;
+            uint32_t hi1 = (qa & 0x00F000F0u) | BF16_FP16_MAGIC;
+
+            *(half2*)&cvtW_h[w * 8 + 0] =
+                __hsub2(*(half2*)&lo0, *(const half2*)&BF16_BIAS_LO);
+            *(half2*)&cvtW_h[w * 8 + 2] =
+                __hfma2(*(half2*)&hi0, *(const half2*)&BF16_SCALE16,
+                        *(const half2*)&BF16_BIAS_HI);
+            *(half2*)&cvtW_h[w * 8 + 4] =
+                __hsub2(*(half2*)&lo1, *(const half2*)&BF16_BIAS_LO);
+            *(half2*)&cvtW_h[w * 8 + 6] =
+                __hfma2(*(half2*)&hi1, *(const half2*)&BF16_SCALE16,
+                        *(const half2*)&BF16_BIAS_HI);
           }
-        }
-        if constexpr (HAS_ZERO_POINTS && GROUP_SIZE > 0) {
-          uint32_t group_idx = k_ / GROUP_SIZE;
-          scalar_t zp = zero_points[(m + y) * num_groups + group_idx];
+          if constexpr (HAS_ZERO_POINTS && GROUP_SIZE > 0) {
+            uint32_t group_idx = k_ / GROUP_SIZE;
+            float zp_f =
+                __s2float(zero_points[(m + y) * num_groups + group_idx]);
+            half zp_h = __float2half(zp_f);
   #pragma unroll
-          for (uint32_t b = 0; b < A_CHUNK; b++) {
-            cvtB.h[b] = cvtB.h[b] - zp;
+            for (uint32_t b = 0; b < A_CHUNK; b++) {
+              cvtW_h[b] = cvtW_h[b] - zp_h;
+            }
           }
-        }
-        if constexpr (GROUP_SIZE > 0) {
-          float partial = 0;
+          // Convert bf16 activations to fp16 and dot product via fdot2
+          if constexpr (GROUP_SIZE > 0) {
+            float partial = 0;
   #pragma unroll
-          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-            DOT2C(partial, bigA[n][k2].f[b], cvtB.f[b])
-          }
-          uint32_t group_idx = k_ / GROUP_SIZE;
-          sum[n][y] +=
-              partial * __s2float(scale[(m + y) * num_groups + group_idx]);
-        } else {
+            for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+              half a_h0 =
+                  __float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 0]));
+              half a_h1 =
+                  __float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 1]));
+              half2 a_pair = {a_h0, a_h1};
+              half2 w_pair = *(half2*)&cvtW_h[b * 2];
+              partial =
+                  __builtin_amdgcn_fdot2(a_pair, w_pair, partial, false);
+            }
+            uint32_t group_idx = k_ / GROUP_SIZE;
+            sum[n][y] += partial *
+                         __s2float(scale[(m + y) * num_groups + group_idx]);
+          } else {
   #pragma unroll
-          for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
-            DOT2C(sum[n][y], bigA[n][k2].f[b], cvtB.f[b])
+            for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
+              half a_h0 =
+                  __float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 0]));
+              half a_h1 =
+                  __float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 1]));
+              half2 a_pair = {a_h0, a_h1};
+              half2 w_pair = *(half2*)&cvtW_h[b * 2];
+              sum[n][y] =
+                  __builtin_amdgcn_fdot2(a_pair, w_pair, sum[n][y], false);
+            }
          }
        }
      }
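For reference, here is a standalone sketch (not part of the patch) of the magic-number dequantization the new path relies on: OR-ing a 4-bit nibble `v` into the mantissa of fp16 `0x6400` (= 1024.0, whose mantissa ulp is exactly 1.0) yields the fp16 value `1024 + v`, so one `__hsub2` (or one `__hfma2` for the pre-shifted high nibbles) recovers `v - 8` with no int→float conversion. The helper names (`dequant8_int4_to_half`, `check`) are hypothetical; only the constants and the `__hsub2`/`__hfma2` sequence mirror the symmetric (no-zero-point) branch of the diff.

```cpp
// Sketch of the int4 -> fp16 magic-number dequant, symmetric case
// (fixed zero point 8), with the ExLlama nibble order described above:
// [v0,v2,v4,v6] in the low 16 bits of each uint32, [v1,v3,v5,v7] high.
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdio>

constexpr uint32_t FP16_MAGIC = 0x64006400u;  // {1024.0, 1024.0}
constexpr uint32_t BIAS_LO = 0x64086408u;  // {1032.0}: (1024+v) - 1032 = v-8
constexpr uint32_t SCALE16 = 0x2C002C00u;  // {0.0625} = 1/16
constexpr uint32_t BIAS_HI = 0xD480D480u;  // {-72.0}: (1024+16v)/16 - 72 = v-8

// Dequantize 8 packed int4 values into out[0..7] as fp16 (value - 8).
__device__ void dequant8_int4_to_half(uint32_t qa, half* out) {
  uint32_t lo0 = (qa & 0x000F000Fu) | FP16_MAGIC;  // lanes {v0, v1}
  uint32_t hi0 = (qa & 0x00F000F0u) | FP16_MAGIC;  // lanes {16*v2, 16*v3}
  qa >>= 8;
  uint32_t lo1 = (qa & 0x000F000Fu) | FP16_MAGIC;  // lanes {v4, v5}
  uint32_t hi1 = (qa & 0x00F000F0u) | FP16_MAGIC;  // lanes {16*v6, 16*v7}
  *(half2*)&out[0] = __hsub2(*(half2*)&lo0, *(const half2*)&BIAS_LO);
  *(half2*)&out[2] = __hfma2(*(half2*)&hi0, *(const half2*)&SCALE16,
                             *(const half2*)&BIAS_HI);
  *(half2*)&out[4] = __hsub2(*(half2*)&lo1, *(const half2*)&BIAS_LO);
  *(half2*)&out[6] = __hfma2(*(half2*)&hi1, *(const half2*)&SCALE16,
                             *(const half2*)&BIAS_HI);
}

__global__ void check(uint32_t qa, float* out) {
  half w[8];
  dequant8_int4_to_half(qa, w);
  for (int i = 0; i < 8; ++i) out[i] = __half2float(w[i]);
}

int main() {
  // Pack v0..v7 = 0..7 in ExLlama order; the dequant should print -8 .. -1.
  const uint32_t pos[8] = {0, 16, 4, 20, 8, 24, 12, 28};
  uint32_t qa = 0;
  for (uint32_t v = 0; v < 8; ++v) qa |= v << pos[v];
  float* d = nullptr;
  float h[8];
  hipMalloc(&d, sizeof(h));
  check<<<1, 1>>>(qa, d);
  hipMemcpy(h, d, sizeof(h), hipMemcpyDeviceToHost);
  for (int i = 0; i < 8; ++i) printf("%g ", h[i]);
  printf("\n");
  hipFree(d);
  return 0;
}
```

The `__builtin_amdgcn_fdot2(a, b, c, false)` calls in the diff then accumulate each pair of fp16 products into an f32 in a single instruction (`v_dot2_f32_f16` on gfx9+ parts that support it; the trailing `false` disables output clamping), which is what makes converting the bf16 activations to fp16 pay off relative to widening everything to f32.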