206 changes: 130 additions & 76 deletions csrc/rocm/skinny_gemms_int4.cu
@@ -263,54 +263,81 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
*(const half2*)&BIAS_HI);
}
} else {
// ExLlama shuffle: 8 unsigned int4 values packed per uint32
// Bit layout: [v0,v2,v4,v6] in low 16 bits,
// [v1,v3,v5,v7] in high 16 bits
// Symmetric uses a fixed zero point of -8.
constexpr int ZP_BIAS = HAS_ZERO_POINTS ? 0 : 8;
// BF16 path: dequant int4 → fp16 via bit manipulation
// (same as fp16 path), convert bf16 activations → fp16,
// then use native fdot2 for the dot product.
// This needs ~2.5× fewer ALU ops than the scalar bf16→f32 path.
constexpr uint32_t BF16_FP16_MAGIC = 0x64006400u;
constexpr uint32_t BF16_BIAS_LO =
HAS_ZERO_POINTS ? 0x64006400u : 0x64086408u;
constexpr uint32_t BF16_SCALE16 = 0x2C002C00u;
constexpr uint32_t BF16_BIAS_HI =
HAS_ZERO_POINTS ? 0xD400D400u : 0xD480D480u;

// Temporary fp16 storage for dequantized weights
half cvtW_h[A_CHUNK];

#pragma unroll
for (uint32_t w = 0; w < A_CHUNK / 8; w++) {
uint32_t qa = bigB[y][k2].u32[w];
cvtB.h[w * 8 + 0] = (scalar_t)((int)(qa & 0xF) - ZP_BIAS);
cvtB.h[w * 8 + 1] =
(scalar_t)((int)((qa >> 16) & 0xF) - ZP_BIAS);
cvtB.h[w * 8 + 2] =
(scalar_t)((int)((qa >> 4) & 0xF) - ZP_BIAS);
cvtB.h[w * 8 + 3] =
(scalar_t)((int)((qa >> 20) & 0xF) - ZP_BIAS);
cvtB.h[w * 8 + 4] =
(scalar_t)((int)((qa >> 8) & 0xF) - ZP_BIAS);
cvtB.h[w * 8 + 5] =
(scalar_t)((int)((qa >> 24) & 0xF) - ZP_BIAS);
cvtB.h[w * 8 + 6] =
(scalar_t)((int)((qa >> 12) & 0xF) - ZP_BIAS);
cvtB.h[w * 8 + 7] =
(scalar_t)((int)((qa >> 28) & 0xF) - ZP_BIAS);
uint32_t lo0 = (qa & 0x000F000Fu) | BF16_FP16_MAGIC;
uint32_t hi0 = (qa & 0x00F000F0u) | BF16_FP16_MAGIC;
qa >>= 8;
uint32_t lo1 = (qa & 0x000F000Fu) | BF16_FP16_MAGIC;
uint32_t hi1 = (qa & 0x00F000F0u) | BF16_FP16_MAGIC;

*(half2*)&cvtW_h[w * 8 + 0] =
__hsub2(*(half2*)&lo0, *(const half2*)&BF16_BIAS_LO);
*(half2*)&cvtW_h[w * 8 + 2] =
__hfma2(*(half2*)&hi0, *(const half2*)&BF16_SCALE16,
*(const half2*)&BF16_BIAS_HI);
*(half2*)&cvtW_h[w * 8 + 4] =
__hsub2(*(half2*)&lo1, *(const half2*)&BF16_BIAS_LO);
*(half2*)&cvtW_h[w * 8 + 6] =
__hfma2(*(half2*)&hi1, *(const half2*)&BF16_SCALE16,
*(const half2*)&BF16_BIAS_HI);
}
}

if constexpr (HAS_ZERO_POINTS && GROUP_SIZE > 0) {
uint32_t group_idx = k_ / GROUP_SIZE;
scalar_t zp = zero_points[(m + y) * num_groups + group_idx];
if constexpr (HAS_ZERO_POINTS && GROUP_SIZE > 0) {
uint32_t group_idx = k_ / GROUP_SIZE;
float zp_f =
__s2float(zero_points[(m + y) * num_groups + group_idx]);
half zp_h = __float2half(zp_f);
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK; b++) {
cvtB.h[b] = cvtB.h[b] - zp;
for (uint32_t b = 0; b < A_CHUNK; b++) {
cvtW_h[b] = cvtW_h[b] - zp_h;
}
}
}

if constexpr (GROUP_SIZE > 0) {
float partial = 0;
// Convert bf16 activations to fp16 and take the dot product via fdot2
if constexpr (GROUP_SIZE > 0) {
float partial = 0;
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(partial, bigA[n][k2].f[b], cvtB.f[b])
}
uint32_t group_idx = k_ / GROUP_SIZE;
sum[n][y] +=
partial * __s2float(scale[(m + y) * num_groups + group_idx]);
} else {
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
half a_h0 =
__float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 0]));
half a_h1 =
__float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 1]));
half2 a_pair = {a_h0, a_h1};
half2 w_pair = *(half2*)&cvtW_h[b * 2];
partial =
__builtin_amdgcn_fdot2(a_pair, w_pair, partial, false);
}
uint32_t group_idx = k_ / GROUP_SIZE;
sum[n][y] += partial *
__s2float(scale[(m + y) * num_groups + group_idx]);
} else {
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], cvtB.f[b])
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
half a_h0 =
__float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 0]));
half a_h1 =
__float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 1]));
half2 a_pair = {a_h0, a_h1};
half2 w_pair = *(half2*)&cvtW_h[b * 2];
sum[n][y] =
__builtin_amdgcn_fdot2(a_pair, w_pair, sum[n][y], false);
}
}
}
}
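
The magic constants above follow directly from the fp16 bit layout: 0x6400 is 1024.0, and at that exponent the mantissa LSB is worth exactly 1.0, so OR-ing a 4-bit value into the low mantissa bits yields 1024 + q. A minimal host-side C++ check of that arithmetic, assuming a hypothetical half_bits_to_double helper written only for this sketch (the constants are the ones defined above: 0x6400/0x6408 low-nibble biases, 0x2C00 scale, 0xD400/0xD480 high-nibble addends):

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode an IEEE-754 binary16 bit pattern into a double (normal values only,
// which is all the dequant trick ever produces).
static double half_bits_to_double(uint16_t h) {
  int sign = (h >> 15) & 1;
  int exp = (h >> 10) & 0x1F;
  int mant = h & 0x3FF;
  double v = std::ldexp(1.0 + mant / 1024.0, exp - 15);
  return sign ? -v : v;
}

int main() {
  for (int q = 0; q < 16; ++q) {
    // Low nibble: OR-ing q into 0x6400 gives the half value 1024 + q.
    double lo = half_bits_to_double(uint16_t(0x6400 | q));
    assert(lo - half_bits_to_double(0x6400) == q);      // asymmetric bias
    assert(lo - half_bits_to_double(0x6408) == q - 8);  // symmetric bias (zp = 8)

    // High nibble: OR-ing (q << 4) into 0x6400 gives 1024 + 16*q, so an fma
    // by 2^-4 (0x2C00) with addend -64 (0xD400) or -72 (0xD480) recovers it.
    double hi = half_bits_to_double(uint16_t(0x6400 | (q << 4)));
    assert(hi * half_bits_to_double(0x2C00) + half_bits_to_double(0xD400) == q);
    assert(hi * half_bits_to_double(0x2C00) + half_bits_to_double(0xD480) == q - 8);
  }
  std::printf("int4 -> fp16 magic-number dequant checks pass\n");
  return 0;
}
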
@@ -527,54 +527,81 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
*(const half2*)&BIAS_HI);
}
} else {
// ExLlama shuffle: 8 unsigned int4 values packed per uint32
// Bit layout: [v0,v2,v4,v6] in low 16 bits,
// [v1,v3,v5,v7] in high 16 bits
// Symmetric uses a fixed zero point of -8.
constexpr int ZP_BIAS = HAS_ZERO_POINTS ? 0 : 8;
// BF16 path: dequant int4 → fp16 via bit manipulation
// (same as fp16 path), convert bf16 activations → fp16,
// then use native fdot2 for the dot product.
// This needs ~2.5× fewer ALU ops than the scalar bf16→f32 path.
constexpr uint32_t BF16_FP16_MAGIC = 0x64006400u;
constexpr uint32_t BF16_BIAS_LO =
HAS_ZERO_POINTS ? 0x64006400u : 0x64086408u;
constexpr uint32_t BF16_SCALE16 = 0x2C002C00u;
constexpr uint32_t BF16_BIAS_HI =
HAS_ZERO_POINTS ? 0xD400D400u : 0xD480D480u;

// Temporary fp16 storage for dequantized weights
half cvtW_h[A_CHUNK];

#pragma unroll
for (uint32_t w = 0; w < A_CHUNK / 8; w++) {
uint32_t qa = bigB[y][k2].u32[w];
cvtB.h[w * 8 + 0] = (scalar_t)((int)(qa & 0xF) - ZP_BIAS);
cvtB.h[w * 8 + 1] =
(scalar_t)((int)((qa >> 16) & 0xF) - ZP_BIAS);
cvtB.h[w * 8 + 2] =
(scalar_t)((int)((qa >> 4) & 0xF) - ZP_BIAS);
cvtB.h[w * 8 + 3] =
(scalar_t)((int)((qa >> 20) & 0xF) - ZP_BIAS);
cvtB.h[w * 8 + 4] =
(scalar_t)((int)((qa >> 8) & 0xF) - ZP_BIAS);
cvtB.h[w * 8 + 5] =
(scalar_t)((int)((qa >> 24) & 0xF) - ZP_BIAS);
cvtB.h[w * 8 + 6] =
(scalar_t)((int)((qa >> 12) & 0xF) - ZP_BIAS);
cvtB.h[w * 8 + 7] =
(scalar_t)((int)((qa >> 28) & 0xF) - ZP_BIAS);
uint32_t lo0 = (qa & 0x000F000Fu) | BF16_FP16_MAGIC;
uint32_t hi0 = (qa & 0x00F000F0u) | BF16_FP16_MAGIC;
qa >>= 8;
uint32_t lo1 = (qa & 0x000F000Fu) | BF16_FP16_MAGIC;
uint32_t hi1 = (qa & 0x00F000F0u) | BF16_FP16_MAGIC;

*(half2*)&cvtW_h[w * 8 + 0] =
__hsub2(*(half2*)&lo0, *(const half2*)&BF16_BIAS_LO);
*(half2*)&cvtW_h[w * 8 + 2] =
__hfma2(*(half2*)&hi0, *(const half2*)&BF16_SCALE16,
*(const half2*)&BF16_BIAS_HI);
*(half2*)&cvtW_h[w * 8 + 4] =
__hsub2(*(half2*)&lo1, *(const half2*)&BF16_BIAS_LO);
*(half2*)&cvtW_h[w * 8 + 6] =
__hfma2(*(half2*)&hi1, *(const half2*)&BF16_SCALE16,
*(const half2*)&BF16_BIAS_HI);
}
}

if constexpr (HAS_ZERO_POINTS && GROUP_SIZE > 0) {
uint32_t group_idx = k_ / GROUP_SIZE;
scalar_t zp = zero_points[(m + y) * num_groups + group_idx];
if constexpr (HAS_ZERO_POINTS && GROUP_SIZE > 0) {
uint32_t group_idx = k_ / GROUP_SIZE;
float zp_f =
__s2float(zero_points[(m + y) * num_groups + group_idx]);
half zp_h = __float2half(zp_f);
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK; b++) {
cvtB.h[b] = cvtB.h[b] - zp;
for (uint32_t b = 0; b < A_CHUNK; b++) {
cvtW_h[b] = cvtW_h[b] - zp_h;
}
}
}

if constexpr (GROUP_SIZE > 0) {
float partial = 0;
// Convert bf16 activations to fp16 and take the dot product via fdot2
if constexpr (GROUP_SIZE > 0) {
float partial = 0;
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(partial, bigA[n][k2].f[b], cvtB.f[b])
}
uint32_t group_idx = k_ / GROUP_SIZE;
sum[n][y] +=
partial * __s2float(scale[(m + y) * num_groups + group_idx]);
} else {
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
half a_h0 =
__float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 0]));
half a_h1 =
__float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 1]));
half2 a_pair = {a_h0, a_h1};
half2 w_pair = *(half2*)&cvtW_h[b * 2];
partial =
__builtin_amdgcn_fdot2(a_pair, w_pair, partial, false);
}
uint32_t group_idx = k_ / GROUP_SIZE;
sum[n][y] += partial *
__s2float(scale[(m + y) * num_groups + group_idx]);
} else {
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], cvtB.f[b])
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
half a_h0 =
__float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 0]));
half a_h1 =
__float2half(__bfloat162float(bigA[n][k2].h[b * 2 + 1]));
half2 a_pair = {a_h0, a_h1};
half2 w_pair = *(half2*)&cvtW_h[b * 2];
sum[n][y] =
__builtin_amdgcn_fdot2(a_pair, w_pair, sum[n][y], false);
}
}
}
}
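
As a scalar reference for what the GROUP_SIZE > 0 path above computes (an illustrative sketch, not code from this PR): __builtin_amdgcn_fdot2 adds a.x*b.x + a.y*b.y of two half2 operands into an f32 accumulator, so per group the kernel forms an fp32 partial dot product of fp16 activations and dequantized fp16 weights, subtracts the group zero point from the weights when one is present (the symmetric case bakes a fixed 8 into the bias constants instead), and scales each group's partial by its per-group scale. The names below (ref_int4_dot, an already-unpacked q4 array) are hypothetical, and the reference ignores fp16 rounding of the individual products.

#include <cstdint>
#include <cstdio>
#include <vector>

// Scalar model of one output element on the GROUP_SIZE > 0 path.
//   q4     : K unpacked uint4 weights (0..15) for one output row
//   act    : K activations, already widened to float for this reference
//   scales : K / group_size per-group scales
//   zps    : per-group zero points; empty models the symmetric case (zp = 8)
static float ref_int4_dot(const std::vector<uint8_t>& q4,
                          const std::vector<float>& act,
                          const std::vector<float>& scales,
                          const std::vector<float>& zps,
                          size_t group_size) {
  float out = 0.0f;
  for (size_t g = 0; g < scales.size(); ++g) {
    const float zp = zps.empty() ? 8.0f : zps[g];
    float partial = 0.0f;  // the kernel builds this with fdot2 on half2 pairs
    for (size_t i = 0; i < group_size; ++i) {
      const size_t k = g * group_size + i;
      partial += act[k] * (float(q4[k]) - zp);
    }
    out += partial * scales[g];  // per-group scale applied once, as above
  }
  return out;
}

int main() {
  // Tiny usage example: K = 4, one group, scale 0.5, symmetric (zp = 8).
  std::vector<uint8_t> q4 = {9, 7, 8, 12};  // dequantizes to {1, -1, 0, 4}
  std::vector<float> act = {1.f, 2.f, 3.f, 4.f};
  std::vector<float> scales = {0.5f};
  float y = ref_int4_dot(q4, act, scales, {}, 4);
  std::printf("%f\n", y);  // (1*1 + 2*-1 + 3*0 + 4*4) * 0.5 = 7.5
  return 0;
}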