3 changes: 2 additions & 1 deletion .buildkite/release-pipeline.yaml
@@ -611,7 +611,7 @@ steps:
- "bash tools/vllm-rocm/generate-rocm-wheels-root-index.sh"
env:
S3_BUCKET: "vllm-wheels"
VARIANT: "rocm700"
VARIANT: "rocm721"

# ROCm Job 6: Build ROCm Release Docker Image
- label: ":docker: Build release image - x86_64 - ROCm"
@@ -681,6 +681,7 @@ steps:
- label: "Publish nightly ROCm image to DockerHub"
depends_on:
- build-rocm-release-image
+if: build.env("NIGHTLY") == "1"
agents:
queue: small_cpu_queue_release
commands:
12 changes: 6 additions & 6 deletions CMakeLists.txt
@@ -363,7 +363,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SASS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
-cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
+cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0;12.1" "${CUDA_ARCHS}")
# marlin arches for other files
cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")

@@ -523,12 +523,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()


-# The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require
+# The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
-cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a" "${CUDA_ARCHS}")
+cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS
@@ -616,12 +616,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()

-# The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require
+# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
-cuda_archs_loose_intersection(FP4_ARCHS "12.0a" "${CUDA_ARCHS}")
+cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
@@ -1050,7 +1050,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# - sm80 doesn't support fp8 computation
# - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SASS instruction
# so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0)
-cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}")
+cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0;12.1" "${CUDA_ARCHS}")
# moe marlin arches for other files
cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_MOE_OTHER_ARCHS)
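The practical effect of the widened fp8 arch lists can be sketched by calling the helper directly. A minimal hypothetical check, not part of the diff; it assumes it is run from the vLLM repo root so `cmake/utils.cmake` can be included, and the output variable names are made up:

```cmake
# Run with `cmake -P check_fp8_archs.cmake` (hypothetical file name).
include(cmake/utils.cmake)

# Before this PR, a build targeting SM121 found no fp8 Marlin arch:
cuda_archs_loose_intersection(OLD_FP8_ARCHS "8.9;12.0" "12.1")
message(STATUS "old: '${OLD_FP8_ARCHS}'")  # expected: ''

# After, 12.1 survives the intersection, so the fp8 Marlin kernels are built:
cuda_archs_loose_intersection(NEW_FP8_ARCHS "8.9;12.0;12.1" "12.1")
message(STATUS "new: '${NEW_FP8_ARCHS}'")  # expected: '12.1'
```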
8 changes: 4 additions & 4 deletions cmake/external_projects/qutlass.cmake
@@ -32,16 +32,16 @@ endif()
message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}")

if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0f" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(QUTLASS_ARCHS "10.0f;12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a;10.3a" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;12.1a;10.0a;10.3a" "${CUDA_ARCHS}")
endif()

if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND QUTLASS_ARCHS)

if(QUTLASS_ARCHS MATCHES "10\\.(0a|3a|0f)")
set(QUTLASS_TARGET_CC 100)
elseif(QUTLASS_ARCHS MATCHES "12\\.0a")
elseif(QUTLASS_ARCHS MATCHES "12\\.[01][af]?")
set(QUTLASS_TARGET_CC 120)
else()
message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.")
@@ -96,7 +96,7 @@ else()
"[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).")
else()
message(STATUS
"[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in "
"[QUTLASS] Skipping build: no supported arch (12.0f / 10.0f) found in "
"CUDA_ARCHS='${CUDA_ARCHS}'.")
endif()
endif()
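The relaxed pattern now has to classify every SM12x spelling as TARGET_CC 120. A quick hypothetical spot-check of the two regexes, runnable standalone with `cmake -P` (not from the PR; arch strings chosen for illustration):

```cmake
# Mirror the two branches of the updated dispatch above.
foreach(_arch "12.0a" "12.1a" "12.0f" "12.1f" "10.0a" "10.3a" "10.0f")
  if(_arch MATCHES "10\\.(0a|3a|0f)")
    message(STATUS "${_arch} -> QUTLASS_TARGET_CC 100")
  elseif(_arch MATCHES "12\\.[01][af]?")
    message(STATUS "${_arch} -> QUTLASS_TARGET_CC 120")
  else()
    message(STATUS "${_arch} -> unsupported")
  endif()
endforeach()
```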
39 changes: 37 additions & 2 deletions cmake/utils.cmake
@@ -355,8 +355,11 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
list(REMOVE_DUPLICATES _PTX_ARCHS)
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)

-# If x.0a or x.0f is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
-# remove x.0a or x.0f from SRC_CUDA_ARCHS and add x.0a or x.0f to _CUDA_ARCHS
+# Handle architecture-specific suffixes (a/f) for SRC entries.
+# First try exact base match (x.y), then cross-suffix match (x.ya / x.yf).
+# For 'f' (family) suffix: if no exact/cross match, fall back to major-version
+# match — e.g. SRC="12.0f" matches TGT="12.1a" since SM121 is in the SM12x
+# family. The output uses TGT's value to preserve the user's compilation flags.
set(_CUDA_ARCHS)
foreach(_arch ${_SRC_CUDA_ARCHS})
if(_arch MATCHES "[af]$")
@@ -365,6 +368,38 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
string(REGEX REPLACE "[af]$" "" _base "${_arch}")
if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
list(APPEND _CUDA_ARCHS "${_arch}")
elseif("${_base}a" IN_LIST _TGT_CUDA_ARCHS)
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}a")
list(APPEND _CUDA_ARCHS "${_base}a")
elseif("${_base}f" IN_LIST _TGT_CUDA_ARCHS)
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}f")
list(APPEND _CUDA_ARCHS "${_base}f")
elseif(_arch MATCHES "f$")
# Family suffix: match any TGT entry in the same major version family.
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" _src_major "${_base}")
foreach(_tgt ${_TGT_CUDA_ARCHS})
string(REGEX REPLACE "[af]$" "" _tgt_base "${_tgt}")
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" _tgt_major "${_tgt_base}")
if(_tgt_major STREQUAL _src_major)
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_tgt}")
list(APPEND _CUDA_ARCHS "${_tgt}")
break()
endif()
endforeach()
endif()
endif()
endforeach()

+# Symmetric handling: if TGT has x.ya/f and SRC has x.y (without suffix),
+# preserve TGT's suffix in the output.
+set(_tgt_copy ${_TGT_CUDA_ARCHS})
+foreach(_arch ${_tgt_copy})
+if(_arch MATCHES "[af]$")
+string(REGEX REPLACE "[af]$" "" _base "${_arch}")
+if ("${_base}" IN_LIST _SRC_CUDA_ARCHS)
+list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_arch}")
+list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_base}")
+list(APPEND _CUDA_ARCHS "${_arch}")
+endif()
+endif()
+endforeach()
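A hypothetical spot-check of the new matching rules (not from the PR; assumes `include(cmake/utils.cmake)` from the repo root, and the output variable names are made up):

```cmake
include(cmake/utils.cmake)

# Cross-suffix: SRC "12.0a" vs TGT "12.0f" -> TGT's spelling wins.
cuda_archs_loose_intersection(OUT1 "12.0a" "12.0f")
message(STATUS "OUT1='${OUT1}'")  # expected: 12.0f

# Family fallback: "12.0f" covers the whole SM12x family, so TGT "12.1a" is kept.
cuda_archs_loose_intersection(OUT2 "12.0f" "12.1a")
message(STATUS "OUT2='${OUT2}'")  # expected: 12.1a

# Symmetric case: suffix-less SRC "12.0" picks up TGT's "12.0a".
cuda_archs_loose_intersection(OUT3 "12.0" "12.0a")
message(STATUS "OUT3='${OUT3}'")  # expected: 12.0a
```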
2 changes: 1 addition & 1 deletion csrc/moe/marlin_moe_wna16/kernel.h
@@ -13,7 +13,7 @@
const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \
-const uint16_t *__restrict__ global_scale_ptr, \
+const float *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
35 changes: 22 additions & 13 deletions csrc/moe/marlin_moe_wna16/marlin_template.h
@@ -260,7 +260,7 @@ __global__ void Marlin(
// fp16 quantization scales. shape (k/groupsize, n)
const int4* __restrict__ scales_ptr,
// fp16 global scale (for nvfp4 only)
-const uint16_t* __restrict__ global_scale_ptr,
+const float* __restrict__ global_scale_ptr,
// 4bit packed zero-points of shape
// (k/groupsize, n/pack_factor)
const int4* __restrict__ zp_ptr,
@@ -308,7 +308,14 @@ __global__ void Marlin(
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
-constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
+static constexpr auto num_bits =
+vllm::ScalarType::from_id(b_type_id).size_bits();
+// Disable use_fp16_accum for NVFP4 and cases when group_size == -1 &&
+// num_bits == 4
+constexpr bool use_fp16_accum =
+a_type_id == vllm::kFloat16.id() &&
+(!(b_type_id == vllm::kFE2M1f.id() && s_type_id == vllm::kFE4M3fn.id()) &&
+!(group_blocks == -1 && num_bits == 4));
#else
constexpr bool use_fp16_accum = false;
#endif
@@ -357,7 +364,7 @@ __global__ void Marlin(
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(b_type == vllm::kU8);

-c_scalar_t2 global_scale;
+float global_scale_f32 = 1.0f;

constexpr bool has_act_order = group_blocks == 0;

@@ -507,11 +514,12 @@ __global__ void Marlin(

if (mul_topk_weights) {
idx = idx < prob_m_top_k ? idx : 0;
-c_scalar_t2 topk_weight_val =
-Cdtype::num2num2(Cdtype::float2num(topk_weights_ptr[idx]));
+float topk_weight_tmp = topk_weights_ptr[idx];
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
-topk_weight_val = __hmul2(topk_weight_val, global_scale);
+topk_weight_tmp *= global_scale_f32;
}
+c_scalar_t2 topk_weight_val =
+Cdtype::num2num2(Cdtype::float2num(topk_weight_tmp));
sh_block_topk_weights[threadIdx.x] = topk_weight_val;
}
}
@@ -532,8 +540,7 @@ __global__ void Marlin(
expert_id = expert_ids_ptr[block_id];

if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
-uint16_t val = global_scale_ptr[expert_id];
-global_scale = Cdtype::num2num2(*reinterpret_cast<c_scalar_t*>(&val));
+global_scale_f32 = global_scale_ptr[expert_id];
}

B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
@@ -1784,6 +1791,13 @@ __global__ void Marlin(
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
+if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
+if (!mul_topk_weights) {
+c0 *= global_scale_f32;
+c1 *= global_scale_f32;
+}
+}

c_scalar_t2 res =
Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));

@@ -1800,11 +1814,6 @@ __global__ void Marlin(
res = __hmul2(res, tmp_scale);
}

-if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
-if (!mul_topk_weights) {
-res = __hmul2(res, global_scale);
-}
-}
if (has_bias && last) {
c_scalar_t2 tmp_bias = b_bias[0];
if constexpr (m_block_size_8) {
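The numerics behind the global-scale change, as I read it (a sketch, not text from the PR): for NVFP4 the output is conceptually

$$\mathrm{out} \;=\; g \cdot \sum_k a_k \,(q_k\, s_k),$$

where the $q_k$ are fp4 weights, the $s_k$ fp8 block scales, and $g$ the per-tensor (per-expert, in the MoE path) global scale. Before this change $g$ was stored as fp16 and applied only after the fp32 accumulator $c$ had been rounded to the output type,

$$\mathrm{out}_{\text{old}} = g_{\text{fp16}} \otimes \mathrm{rnd}_{16}(c),$$

which rounds twice and can overflow fp16 when $|c|$ is large and $g < 1$. The new code keeps $g$ in fp32 and folds it into $c_0, c_1$ (or into the top-k weight) before the single final conversion:

$$\mathrm{out}_{\text{new}} = \mathrm{rnd}_{16}\!\big(g_{\text{fp32}} \cdot c\big).$$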
8 changes: 4 additions & 4 deletions csrc/moe/marlin_moe_wna16/ops.cu
@@ -382,7 +382,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
const int4* bias_ptr = (const int4*)b_bias;
const float* a_s_ptr = (const float*)a_s;
const int4* b_s_ptr = (const int4*)b_s;
-const uint16_t* g_s_ptr = (const uint16_t*)g_s;
+const float* g_s_ptr = (const float*)g_s;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
@@ -759,7 +759,7 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
"global_scale can only be used for nvfp4 format.");
} else {
-global_scale = torch::empty({0}, options);
+global_scale = torch::empty({0}, options_fp32);
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
"the global_scale parameter must be passed for nvfp4 format.");
}
@@ -842,8 +842,8 @@ torch::Tensor moe_wna16_marlin_gemm(

TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
"scalar type of a_scales must be float");
-TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
-"scalar type of global_scale must be the same with c");
+TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float,
+"scalar type of global_scale must be float");
if (a_type.size_bits() == 16) {
TORCH_CHECK(
a.scalar_type() == c.scalar_type(),
3 changes: 0 additions & 3 deletions csrc/quantization/activation_kernels.cu
@@ -189,10 +189,7 @@ __device__ __forceinline__ void cp_async_wait<0>() {
}

__device__ __forceinline__ float clip(float v, float mmin, float mmax) {
-#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
return fminf(mmax, fmaxf(v, mmin));
-#else
-#endif
}

__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v,
2 changes: 1 addition & 1 deletion csrc/quantization/marlin/kernel.h
@@ -13,7 +13,7 @@
const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \
-const uint16_t *__restrict__ global_scale_ptr, \
+const float *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
10 changes: 5 additions & 5 deletions csrc/quantization/marlin/marlin.cu
@@ -57,7 +57,7 @@ torch::Tensor marlin_gemm(
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
"marlin_gemm(..) requires CUDA_ARCH >= 7.5");
return torch::empty({1, 1});
}

@@ -356,7 +356,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
const int4* bias_ptr = (const int4*)b_bias;
const float* a_s_ptr = (const float*)a_s;
const int4* b_s_ptr = (const int4*)b_s;
-const uint16_t* g_s_ptr = (const uint16_t*)g_s;
+const float* g_s_ptr = (const float*)g_s;

const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
@@ -751,7 +751,7 @@ torch::Tensor marlin_gemm(
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
"global_scale can only be used for nvfp4 format.");
} else {
-global_scale = torch::empty({0}, options);
+global_scale = torch::empty({0}, options_fp32);
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
"the global_scale parameter must be passed for nvfp4 format.");
}
@@ -832,8 +832,8 @@ torch::Tensor marlin_gemm(

TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
"scalar type of a_scales must be float");
-TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
-"scalar type of global_scale must be the same with c");
+TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float,
+"scalar type of global_scale must be float");
if (a_type.size_bits() == 16) {
TORCH_CHECK(
a.scalar_type() == c.scalar_type(),
25 changes: 15 additions & 10 deletions csrc/quantization/marlin/marlin_template.h
@@ -251,8 +251,8 @@ __global__ void Marlin(
const float* __restrict__ a_scales_ptr,
// fp16 quantization scales. shape (k/groupsize, n)
const int4* __restrict__ scales_ptr,
-// fp16 global scale (for nvfp4// only)
-const uint16_t* __restrict__ global_scale_ptr,
+// float global scale (for nvfp4 only)
+const float* __restrict__ global_scale_ptr,
// 4bit packed zero-points of shape
// (k/groupsize, n/pack_factor)
const int4* __restrict__ zp_ptr,
@@ -292,7 +292,13 @@
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
-constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
+constexpr auto num_bits = vllm::ScalarType::from_id(b_type_id).size_bits();
+// Disable use_fp16_accum for NVFP4 and cases when group_size == -1 &&
+// num_bits == 4
+constexpr bool use_fp16_accum =
+a_type_id == vllm::kFloat16.id() &&
+(!(b_type_id == vllm::kFE2M1f.id() && s_type_id == vllm::kFE4M3fn.id()) &&
+!(group_blocks == -1 && num_bits == 4));
#else
constexpr bool use_fp16_accum = false;
#endif
@@ -342,11 +348,10 @@ __global__ void Marlin(
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(b_type == vllm::kU8);

-c_scalar_t2 global_scale;
+float global_scale_f32 = 1.0f;

if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
-uint16_t val = global_scale_ptr[0];
-global_scale = Cdtype::num2num2(*reinterpret_cast<c_scalar_t*>(&val));
+global_scale_f32 = global_scale_ptr[0];
}

constexpr bool has_act_order = group_blocks == 0;
@@ -1644,6 +1649,10 @@ __global__ void Marlin(
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
+if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
+c0 *= global_scale_f32;
+c1 *= global_scale_f32;
+}
c_scalar_t2 res =
Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));

@@ -1659,10 +1668,6 @@ __global__ void Marlin(
}
res = __hmul2(res, tmp_scale);
}

-if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
-res = __hmul2(res, global_scale);
-}
if (has_bias && last) {
c_scalar_t2 tmp_bias = b_bias[0];
if constexpr (m_block_size_8) {