diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 8fc66793ae61..45b2996f7ead 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -611,7 +611,7 @@ steps: - "bash tools/vllm-rocm/generate-rocm-wheels-root-index.sh" env: S3_BUCKET: "vllm-wheels" - VARIANT: "rocm700" + VARIANT: "rocm721" # ROCm Job 6: Build ROCm Release Docker Image - label: ":docker: Build release image - x86_64 - ROCm" @@ -681,6 +681,7 @@ steps: - label: "Publish nightly ROCm image to DockerHub" depends_on: - build-rocm-release-image + if: build.env("NIGHTLY") == "1" agents: queue: small_cpu_queue_release commands: diff --git a/CMakeLists.txt b/CMakeLists.txt index afc02f7fbbbe..cf59f18eb7e7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -363,7 +363,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # - sm80 doesn't support fp8 computation # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) - cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0;12.1" "${CUDA_ARCHS}") # marlin arches for other files cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}") @@ -523,12 +523,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() - # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require + # The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require # CUDA 12.8 or later if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}") else() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) set(SRCS @@ -616,12 +616,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() - # The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require + # The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require # CUDA 12.8 or later if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}") else() - cuda_archs_loose_intersection(FP4_ARCHS "12.0a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS @@ -1050,7 +1050,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # - sm80 doesn't support fp8 computation # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) - cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0;12.1" "${CUDA_ARCHS}") # moe marlin arches for other files cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}") if (MARLIN_MOE_OTHER_ARCHS) diff --git a/cmake/external_projects/qutlass.cmake b/cmake/external_projects/qutlass.cmake index 84bb1b00c1bb..273fe754bed1 100644 --- a/cmake/external_projects/qutlass.cmake +++ b/cmake/external_projects/qutlass.cmake @@ -32,16 +32,16 @@ endif() message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0f" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(QUTLASS_ARCHS "10.0f;12.0f" "${CUDA_ARCHS}") else() - cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a;10.3a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;12.1a;10.0a;10.3a" "${CUDA_ARCHS}") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND QUTLASS_ARCHS) if(QUTLASS_ARCHS MATCHES "10\\.(0a|3a|0f)") set(QUTLASS_TARGET_CC 100) - elseif(QUTLASS_ARCHS MATCHES "12\\.0a") + elseif(QUTLASS_ARCHS MATCHES "12\\.[01][af]?") set(QUTLASS_TARGET_CC 120) else() message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.") @@ -96,7 +96,7 @@ else() "[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).") else() message(STATUS - "[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in " + "[QUTLASS] Skipping build: no supported arch (12.0f / 10.0f) found in " "CUDA_ARCHS='${CUDA_ARCHS}'.") endif() endif() diff --git a/cmake/utils.cmake b/cmake/utils.cmake index fd3d7e0ae8b0..e95333457b57 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -355,8 +355,11 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR list(REMOVE_DUPLICATES _PTX_ARCHS) list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS) - # If x.0a or x.0f is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should - # remove x.0a or x.0f from SRC_CUDA_ARCHS and add x.0a or x.0f to _CUDA_ARCHS + # Handle architecture-specific suffixes (a/f) for SRC entries. + # First try exact base match (x.y), then cross-suffix match (x.ya / x.yf). + # For 'f' (family) suffix: if no exact/cross match, fall back to major-version + # match — e.g. SRC="12.0f" matches TGT="12.1a" since SM121 is in the SM12x + # family. The output uses TGT's value to preserve the user's compilation flags. set(_CUDA_ARCHS) foreach(_arch ${_SRC_CUDA_ARCHS}) if(_arch MATCHES "[af]$") @@ -365,6 +368,38 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR if ("${_base}" IN_LIST TGT_CUDA_ARCHS) list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}") list(APPEND _CUDA_ARCHS "${_arch}") + elseif("${_base}a" IN_LIST _TGT_CUDA_ARCHS) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}a") + list(APPEND _CUDA_ARCHS "${_base}a") + elseif("${_base}f" IN_LIST _TGT_CUDA_ARCHS) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}f") + list(APPEND _CUDA_ARCHS "${_base}f") + elseif(_arch MATCHES "f$") + # Family suffix: match any TGT entry in the same major version family. + string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" _src_major "${_base}") + foreach(_tgt ${_TGT_CUDA_ARCHS}) + string(REGEX REPLACE "[af]$" "" _tgt_base "${_tgt}") + string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" _tgt_major "${_tgt_base}") + if(_tgt_major STREQUAL _src_major) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_tgt}") + list(APPEND _CUDA_ARCHS "${_tgt}") + break() + endif() + endforeach() + endif() + endif() + endforeach() + + # Symmetric handling: if TGT has x.ya/f and SRC has x.y (without suffix), + # preserve TGT's suffix in the output. + set(_tgt_copy ${_TGT_CUDA_ARCHS}) + foreach(_arch ${_tgt_copy}) + if(_arch MATCHES "[af]$") + string(REGEX REPLACE "[af]$" "" _base "${_arch}") + if ("${_base}" IN_LIST _SRC_CUDA_ARCHS) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_arch}") + list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_base}") + list(APPEND _CUDA_ARCHS "${_arch}") endif() endif() endforeach() diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index e5a3a0b9c945..09ed1a470bd6 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -13,7 +13,7 @@ const int4 *__restrict__ b_bias_ptr, \ const float *__restrict__ a_scales_ptr, \ const int4 *__restrict__ scales_ptr, \ - const uint16_t *__restrict__ global_scale_ptr, \ + const float *__restrict__ global_scale_ptr, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ const int32_t *__restrict__ sorted_token_ids_ptr, \ const int32_t *__restrict__ expert_ids_ptr, \ diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index cddc42643c4c..f5685b898036 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -260,7 +260,7 @@ __global__ void Marlin( // fp16 quantization scales. shape (k/groupsize, n) const int4* __restrict__ scales_ptr, // fp16 global scale (for nvfp4// only) - const uint16_t* __restrict__ global_scale_ptr, + const float* __restrict__ global_scale_ptr, // 4bit packed zero-points of shape // (k/groupsize, n/pack_factor) const int4* __restrict__ zp_ptr, @@ -308,7 +308,14 @@ __global__ void Marlin( constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 - constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id(); + static constexpr auto num_bits = + vllm::ScalarType::from_id(b_type_id).size_bits(); + // Disable use_fp16_accum for NVFP4 and cases when group_size == -1 && + // num_bits == 4 + constexpr bool use_fp16_accum = + a_type_id == vllm::kFloat16.id() && + (!(b_type_id == vllm::kFE2M1f.id() && s_type_id == vllm::kFE4M3fn.id()) && + !(group_blocks == -1 && num_bits == 4)); #else constexpr bool use_fp16_accum = false; #endif @@ -357,7 +364,7 @@ __global__ void Marlin( has_zp && !is_zp_float && !std::is_same::value || has_zp && !is_zp_float && !(b_type == vllm::kU8); - c_scalar_t2 global_scale; + float global_scale_f32 = 1.0f; constexpr bool has_act_order = group_blocks == 0; @@ -507,11 +514,12 @@ __global__ void Marlin( if (mul_topk_weights) { idx = idx < prob_m_top_k ? idx : 0; - c_scalar_t2 topk_weight_val = - Cdtype::num2num2(Cdtype::float2num(topk_weights_ptr[idx])); + float topk_weight_tmp = topk_weights_ptr[idx]; if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - topk_weight_val = __hmul2(topk_weight_val, global_scale); + topk_weight_tmp *= global_scale_f32; } + c_scalar_t2 topk_weight_val = + Cdtype::num2num2(Cdtype::float2num(topk_weight_tmp)); sh_block_topk_weights[threadIdx.x] = topk_weight_val; } } @@ -532,8 +540,7 @@ __global__ void Marlin( expert_id = expert_ids_ptr[block_id]; if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - uint16_t val = global_scale_ptr[expert_id]; - global_scale = Cdtype::num2num2(*reinterpret_cast(&val)); + global_scale_f32 = global_scale_ptr[expert_id]; } B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); @@ -1784,6 +1791,13 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + if (!mul_topk_weights) { + c0 *= global_scale_f32; + c1 *= global_scale_f32; + } + } + c_scalar_t2 res = Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1)); @@ -1800,11 +1814,6 @@ __global__ void Marlin( res = __hmul2(res, tmp_scale); } - if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - if (!mul_topk_weights) { - res = __hmul2(res, global_scale); - } - } if (has_bias && last) { c_scalar_t2 tmp_bias = b_bias[0]; if constexpr (m_block_size_8) { diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index e3f3b4175b92..60681ad930ff 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -382,7 +382,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, const int4* bias_ptr = (const int4*)b_bias; const float* a_s_ptr = (const float*)a_s; const int4* b_s_ptr = (const int4*)b_s; - const uint16_t* g_s_ptr = (const uint16_t*)g_s; + const float* g_s_ptr = (const float*)g_s; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; @@ -759,7 +759,7 @@ torch::Tensor moe_wna16_marlin_gemm( TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn, "global_scale can only be used for nvfp4 format."); } else { - global_scale = torch::empty({0}, options); + global_scale = torch::empty({0}, options_fp32); TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn), "the global_scale parameter must be passed for nvfp4 format."); } @@ -842,8 +842,8 @@ torch::Tensor moe_wna16_marlin_gemm( TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float, "scalar type of a_scales must be float"); - TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(), - "scalar type of global_scale must be the same with c"); + TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float, + "scalar type of global_scale must be float"); if (a_type.size_bits() == 16) { TORCH_CHECK( a.scalar_type() == c.scalar_type(), diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index c0153bb41b4d..8cc645c33e2f 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -189,10 +189,7 @@ __device__ __forceinline__ void cp_async_wait<0>() { } __device__ __forceinline__ float clip(float v, float mmin, float mmax) { -#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 return fminf(mmax, fmaxf(v, mmin)); -#else -#endif } __device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v, diff --git a/csrc/quantization/marlin/kernel.h b/csrc/quantization/marlin/kernel.h index b3b79c8aec45..8c9cec88b6ad 100644 --- a/csrc/quantization/marlin/kernel.h +++ b/csrc/quantization/marlin/kernel.h @@ -13,7 +13,7 @@ const int4 *__restrict__ b_bias_ptr, \ const float *__restrict__ a_scales_ptr, \ const int4 *__restrict__ scales_ptr, \ - const uint16_t *__restrict__ global_scale_ptr, \ + const float *__restrict__ global_scale_ptr, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \ diff --git a/csrc/quantization/marlin/marlin.cu b/csrc/quantization/marlin/marlin.cu index 62826128c394..fbdb619c27f0 100644 --- a/csrc/quantization/marlin/marlin.cu +++ b/csrc/quantization/marlin/marlin.cu @@ -57,7 +57,7 @@ torch::Tensor marlin_gemm( int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { TORCH_CHECK_NOT_IMPLEMENTED(false, - "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + "marlin_gemm(..) requires CUDA_ARCH >= 7.5"); return torch::empty({1, 1}); } @@ -356,7 +356,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, const int4* bias_ptr = (const int4*)b_bias; const float* a_s_ptr = (const float*)a_s; const int4* b_s_ptr = (const int4*)b_s; - const uint16_t* g_s_ptr = (const uint16_t*)g_s; + const float* g_s_ptr = (const float*)g_s; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; @@ -751,7 +751,7 @@ torch::Tensor marlin_gemm( TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn, "global_scale can only be used for nvfp4 format."); } else { - global_scale = torch::empty({0}, options); + global_scale = torch::empty({0}, options_fp32); TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn), "the global_scale parameter must be passed for nvfp4 format."); } @@ -832,8 +832,8 @@ torch::Tensor marlin_gemm( TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float, "scalar type of a_scales must be float"); - TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(), - "scalar type of global_scale must be the same with c"); + TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float, + "scalar type of global_scale must be float"); if (a_type.size_bits() == 16) { TORCH_CHECK( a.scalar_type() == c.scalar_type(), diff --git a/csrc/quantization/marlin/marlin_template.h b/csrc/quantization/marlin/marlin_template.h index c7b53696c122..9e625b645ee6 100644 --- a/csrc/quantization/marlin/marlin_template.h +++ b/csrc/quantization/marlin/marlin_template.h @@ -251,8 +251,8 @@ __global__ void Marlin( const float* __restrict__ a_scales_ptr, // fp16 quantization scales. shape (k/groupsize, n) const int4* __restrict__ scales_ptr, - // fp16 global scale (for nvfp4// only) - const uint16_t* __restrict__ global_scale_ptr, + // float global scale (for nvfp4// only) + const float* __restrict__ global_scale_ptr, // 4bit packed zero-points of shape // (k/groupsize, n/pack_factor) const int4* __restrict__ zp_ptr, @@ -292,7 +292,13 @@ __global__ void Marlin( #endif #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 - constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id(); + constexpr auto num_bits = vllm::ScalarType::from_id(b_type_id).size_bits(); + // Disable use_fp16_accum for NVFP4 and cases when group_size == -1 && + // num_bits == 4 + constexpr bool use_fp16_accum = + a_type_id == vllm::kFloat16.id() && + (!(b_type_id == vllm::kFE2M1f.id() && s_type_id == vllm::kFE4M3fn.id()) && + !(group_blocks == -1 && num_bits == 4)); #else constexpr bool use_fp16_accum = false; #endif @@ -342,11 +348,10 @@ __global__ void Marlin( has_zp && !is_zp_float && !std::is_same::value || has_zp && !is_zp_float && !(b_type == vllm::kU8); - c_scalar_t2 global_scale; + float global_scale_f32 = 1.0f; if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - uint16_t val = global_scale_ptr[0]; - global_scale = Cdtype::num2num2(*reinterpret_cast(&val)); + global_scale_f32 = global_scale_ptr[0]; } constexpr bool has_act_order = group_blocks == 0; @@ -1644,6 +1649,10 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + c0 *= global_scale_f32; + c1 *= global_scale_f32; + } c_scalar_t2 res = Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1)); @@ -1659,10 +1668,6 @@ __global__ void Marlin( } res = __hmul2(res, tmp_scale); } - - if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - res = __hmul2(res, global_scale); - } if (has_bias && last) { c_scalar_t2 tmp_bias = b_bias[0]; if constexpr (m_block_size_8) { diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 60bf8c314e61..fe2b1882da0e 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -386,6 +386,9 @@ ENV MIOPEN_DEBUG_CONV_GEMM=0 # will not be imported by other tests RUN mkdir src && mv vllm src/vllm +# This is a workaround to ensure pytest exits with the correct status code in CI tests. +RUN echo "import os\n\ndef pytest_sessionfinish(session, exitstatus):\n os._exit(int(exitstatus))" > /vllm-workspace/conftest.py + # ----------------------- # Final vLLM image FROM base AS final diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index e5a216c77ba6..e77406728cb4 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -1,7 +1,7 @@ -ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete -ARG TRITON_BRANCH="57c693b6" +ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.2.1-complete +ARG TRITON_BRANCH="ba5c1517" ARG TRITON_REPO="https://github.com/ROCm/triton.git" -ARG PYTORCH_BRANCH="89075173" +ARG PYTORCH_BRANCH="8514f051" # release/2.10 as of 3/17 ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git" ARG PYTORCH_VISION_BRANCH="v0.24.1" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" @@ -114,6 +114,8 @@ ARG TRITON_REPO RUN git clone ${TRITON_REPO} RUN cd triton \ && git checkout ${TRITON_BRANCH} \ + && git config --global user.email "you@example.com" && git config --global user.name "Your Name" \ + && git cherry-pick 555d04f \ && if [ ! -f setup.py ]; then cd python; fi \ && python3 setup.py bdist_wheel --dist-dir=dist \ && mkdir -p /app/install && cp dist/*.whl /app/install @@ -142,10 +144,14 @@ ARG PYTORCH_VISION_REPO ARG PYTORCH_AUDIO_REPO ARG USE_SCCACHE +RUN apt-get update && apt-get install -y pkg-config liblzma-dev RUN git clone ${PYTORCH_REPO} pytorch -RUN cd pytorch && git checkout ${PYTORCH_BRANCH} \ - && pip install -r requirements.txt && git submodule update --init --recursive \ - && python3 tools/amd_build/build_amd.py \ +RUN cd pytorch && git checkout ${PYTORCH_BRANCH} +RUN cd pytorch \ + && pip install -r requirements.txt && git submodule update --init --recursive +RUN cd pytorch/third_party/kineto \ + && git remote add rocm https://github.com/ROCm/kineto && git fetch rocm && git checkout 2d73be3 +RUN cd pytorch && python3 tools/amd_build/build_amd.py \ && if [ "$USE_SCCACHE" = "1" ]; then \ export HIP_CLANG_PATH=/opt/sccache-wrappers \ && export CMAKE_C_COMPILER_LAUNCHER=sccache \ @@ -239,7 +245,7 @@ RUN pip install pyyaml && cd aiter \ export HIP_CLANG_PATH=/opt/sccache-wrappers \ && sccache --show-stats; \ fi \ - && GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist \ + && PREBUILD_KERNELS=1 GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist \ && if [ "$USE_SCCACHE" = "1" ]; then sccache --show-stats; fi \ && ls /app/aiter/dist/*.whl RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install diff --git a/docs/getting_started/installation/gpu.rocm.inc.md b/docs/getting_started/installation/gpu.rocm.inc.md index 1f36ceba617a..101ab9d56119 100644 --- a/docs/getting_started/installation/gpu.rocm.inc.md +++ b/docs/getting_started/installation/gpu.rocm.inc.md @@ -172,8 +172,11 @@ uv pip install vllm --extra-index-url https://wheels.vllm.ai/rocm/0.15.0/rocm700 --8<-- [end:build-wheel-from-source] --8<-- [start:pre-built-images] -vLLM offers an official Docker image for deployment. -The image can be used to run OpenAI compatible server and is available on Docker Hub as [vllm/vllm-openai-rocm](https://hub.docker.com/r/vllm/vllm-openai-rocm/tags). +vLLM offers official Docker images for deployment. +The images can be used to run OpenAI compatible server and are available on Docker Hub as [vllm/vllm-openai-rocm](https://hub.docker.com/r/vllm/vllm-openai-rocm/tags). + +- `vllm/vllm-openai-rocm:latest` — stable release +- `vllm/vllm-openai-rocm:nightly` — preview build from the latest development branch, use this if you want the latest features and fixes ```bash docker run --rm \ @@ -186,30 +189,18 @@ docker run --rm \ --env "HF_TOKEN=$HF_TOKEN" \ -p 8000:8000 \ --ipc=host \ - vllm/vllm-openai-rocm:latest \ + vllm/vllm-openai-rocm: \ --model Qwen/Qwen3-0.6B ``` -#### Use AMD's Docker Images +#### Use AMD's Docker Images (Deprecated) -Prior to January 20th, 2026 when the official docker images are available on [upstream vLLM docker hub](https://hub.docker.com/v2/repositories/vllm/vllm-openai-rocm/tags/), the [AMD Infinity hub for vLLM](https://hub.docker.com/r/rocm/vllm/tags) offers a prebuilt, optimized -docker image designed for validating inference performance on the AMD Instinct MI300X™ accelerator. -AMD also offers nightly prebuilt docker image from [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev), which has vLLM and all its dependencies installed. The entrypoint of this docker image is `/bin/bash` (different from the vLLM's Official Docker Image). +!!! warning "Deprecated" + AMD's Docker images (`rocm/vllm` and `rocm/vllm-dev`) are deprecated in favor of the official vLLM Docker images above (`vllm/vllm-openai-rocm`). Please migrate to the official images. -```bash -docker pull rocm/vllm-dev:nightly # to get the latest image -docker run -it --rm \ ---network=host \ ---group-add=video \ ---ipc=host \ ---cap-add=SYS_PTRACE \ ---security-opt seccomp=unconfined \ ---device /dev/kfd \ ---device /dev/dri \ --v :/app/models \ --e HF_HOME="/app/models" \ -rocm/vllm-dev:nightly -``` +Prior to January 20th, 2026 when the official docker images became available on [upstream vLLM docker hub](https://hub.docker.com/v2/repositories/vllm/vllm-openai-rocm/tags/), the [AMD Infinity hub for vLLM](https://hub.docker.com/r/rocm/vllm/tags) offered a prebuilt, optimized +docker image designed for validating inference performance on the AMD Instinct MI300X™ accelerator. +AMD also offered nightly prebuilt docker image from [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev), which has vLLM and all its dependencies installed. The entrypoint of this docker image is `/bin/bash` (different from the vLLM's Official Docker Image). !!! tip Please check [LLM inference performance validation on AMD Instinct MI300X](https://rocm.docs.amd.com/en/latest/how-to/performance-validation/mi300x/vllm-benchmark.html) diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index dff86b7d91bc..015514def33f 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -56,9 +56,12 @@ This guide will help you quickly get started with vLLM to perform: !!! note It currently supports Python 3.12, ROCm 7.0 and `glibc >= 2.35`. - !!! note + !!! note Note that, previously, docker images were published using AMD's docker release pipeline and were located `rocm/vllm-dev`. This is being deprecated by using vLLM's docker release pipeline. + !!! tip + A nightly Docker image is also available as [vllm/vllm-openai-rocm:nightly](https://hub.docker.com/r/vllm/vllm-openai-rocm/tags) for testing the latest development builds. + === "Google TPU" To run vLLM on Google TPUs, you need to install the `vllm-tpu` package. diff --git a/tests/model_executor/model_loader/test_reload.py b/tests/model_executor/model_loader/test_reload.py index 6fcb077c1c73..d031eafe8087 100644 --- a/tests/model_executor/model_loader/test_reload.py +++ b/tests/model_executor/model_loader/test_reload.py @@ -148,3 +148,60 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner): mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0] add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0] assert add_perp < mul_perp + + +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize( + "base_model,mul_model,add_model,quantization", + [ + ( + "Qwen/Qwen3-0.6B", + "inference-optimization/Qwen3-0.6B-debug-multiply", + "inference-optimization/Qwen3-0.6B-debug-add", + "fp8", + ), + ( + "inference-optimization/DeepSeek-V3-debug-empty", + "inference-optimization/DeepSeek-V3-debug-multiply", + "inference-optimization/DeepSeek-V3-debug-add", + "fp8", + ), + ( + "Qwen/Qwen3-0.6B", + "inference-optimization/Qwen3-0.6B-debug-multiply", + "inference-optimization/Qwen3-0.6B-debug-add", + "mxfp8", + ), + # ( TODO: support mxfp4 & mla + # "inference-optimization/DeepSeek-V3-debug-empty", + # "inference-optimization/DeepSeek-V3-debug-multiply", + # "inference-optimization/DeepSeek-V3-debug-add", + # "mxfp8", + # ), + ], +) +def test_online_quantize_reload( + base_model, mul_model, add_model, quantization, tp_size, vllm_runner +): + if cuda_device_count_stateless() < tp_size: + pytest.skip(reason="Not enough CUDA devices") + + if quantization == "fp8" and not current_platform.supports_fp8(): + pytest.skip(reason="Requires FP8 support") + + with vllm_runner( + model_name=base_model, + quantization=quantization, + tensor_parallel_size=tp_size, + enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model), + enable_prefix_caching=False, + ) as llm: + llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model}) + mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0] + add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0] + assert mul_perp < add_perp + + llm.collective_rpc("reload_weights", kwargs={"weights_path": add_model}) + mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0] + add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0] + assert add_perp < mul_perp diff --git a/tests/models/registry.py b/tests/models/registry.py index feb074f117dd..829f559d2ae6 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -791,6 +791,7 @@ def check_available_online( "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo( "baidu/ERNIE-4.5-VL-28B-A3B-PT", trust_remote_code=True, + revision="refs/pr/14", ), "FireRedASR2ForConditionalGeneration": _HfExamplesInfo( "allendou/FireRedASR2-LLM-vllm", diff --git a/tools/vllm-rocm/generate-rocm-wheels-root-index.sh b/tools/vllm-rocm/generate-rocm-wheels-root-index.sh index 87b5c3228f7f..650a71937899 100755 --- a/tools/vllm-rocm/generate-rocm-wheels-root-index.sh +++ b/tools/vllm-rocm/generate-rocm-wheels-root-index.sh @@ -17,14 +17,14 @@ # # Environment variables: # S3_BUCKET - Bucket name (default: vllm-wheels) -# VARIANT - ROCm variant (default: rocm700) +# VARIANT - ROCm variant (default: rocm721) # DRY_RUN - Set to 1 for preview mode (same as --dry-run) set -euo pipefail # ======== Configuration ======== BUCKET="${S3_BUCKET:-vllm-wheels}" -VARIANT="${VARIANT:-rocm700}" +VARIANT="${VARIANT:-rocm721}" DRY_RUN="${DRY_RUN:-0}" FORCE_VERSION="" diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 50fe82eb1d4e..dcc93d987eda 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -66,22 +66,21 @@ class CacheConfig: enable_prefix_caching: bool = True """Whether to enable prefix caching.""" prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256" - """Set the hash algorithm for prefix caching:\n - - "sha256" uses Pickle for object serialization before hashing. This is the - current default, as SHA256 is the most secure choice to avoid potential - hash collisions.\n + """Set the hash algorithm for prefix caching: + + - "sha256" uses Pickle for object serialization before hashing. This is the current + default, as SHA256 is the most secure choice to avoid potential hash collisions. - "sha256_cbor" provides a reproducible, cross-language compatible hash. It - serializes objects using canonical CBOR and hashes them with SHA-256.\n + serializes objects using canonical CBOR and hashes them with SHA-256. - "xxhash" uses Pickle serialization with xxHash (128-bit) for faster, - non-cryptographic hashing. Requires the optional ``xxhash`` package. - IMPORTANT: Use of a hashing algorithm that is not considered - cryptographically secure theoretically increases the risk of hash collisions, - which can cause undefined behavior or even leak private information in - multi-tenant environments. Even if collisions are still very unlikely, it is - important to consider your security risk tolerance against the performance - benefits before turning this on.\n + non-cryptographic hashing. Requires the optional ``xxhash`` package. + IMPORTANT: Use of a hashing algorithm that is not considered cryptographically + secure theoretically increases the risk of hash collisions, which can cause + undefined behavior or even leak private information in multi-tenant environments. + Even if collisions are still very unlikely, it is important to consider your + security risk tolerance against the performance benefits before turning this on. - "xxhash_cbor" combines canonical CBOR serialization with xxHash for - reproducible hashing. Requires the optional ``xxhash`` package.""" + reproducible hashing. Requires the optional ``xxhash`` package.""" calculate_kv_scales: bool = False """Deprecated: This option is deprecated and will be removed in v0.19. It enables dynamic calculation of `k_scale` and `v_scale` when diff --git a/vllm/config/kernel.py b/vllm/config/kernel.py index 2ec18289d68b..4476cd125265 100644 --- a/vllm/config/kernel.py +++ b/vllm/config/kernel.py @@ -32,14 +32,14 @@ class KernelConfig: moe_backend: MoEBackend = "auto" """Backend for MoE expert computation kernels. Available options: - - "auto": Automatically select the best backend based on model and hardware\n - - "triton": Use Triton-based fused MoE kernels\n - - "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only)\n - - "cutlass": Use vLLM CUTLASS kernels\n - - "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels\n - - "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels\n - - "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)\n - - "marlin": Use Marlin kernels (weight-only quantization)\n + - "auto": Automatically select the best backend based on model and hardware + - "triton": Use Triton-based fused MoE kernels + - "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only) + - "cutlass": Use vLLM CUTLASS kernels + - "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels + - "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels + - "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only) + - "marlin": Use Marlin kernels (weight-only quantization) - "aiter": Use AMD AITer kernels (ROCm only)""" @field_validator("moe_backend", mode="before") diff --git a/vllm/config/load.py b/vllm/config/load.py index e77d9b37830e..93240ec5fc0f 100644 --- a/vllm/config/load.py +++ b/vllm/config/load.py @@ -51,7 +51,7 @@ class LoadConfig: - "gguf" will load weights from GGUF format files (details specified in https://github.com/ggml-org/ggml/blob/master/docs/gguf.md). - "mistral" will load weights from consolidated safetensors files used by - Mistral models.\n + Mistral models. - Other custom values can be supported via plugins. """ download_dir: str | None = None diff --git a/vllm/config/model.py b/vllm/config/model.py index 225ee119a6c9..acb43a04b157 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -125,26 +125,28 @@ class ModelConfig: """Name or path of the Hugging Face tokenizer to use. If unspecified, model name or path will be used.""" tokenizer_mode: TokenizerMode | str = "auto" - """Tokenizer mode:\n + """Tokenizer mode: + - "auto" will use the tokenizer from `mistral_common` for Mistral models - if available, otherwise it will use the "hf" tokenizer.\n - - "hf" will use the fast tokenizer if available.\n - - "slow" will always use the slow tokenizer.\n - - "mistral" will always use the tokenizer from `mistral_common`.\n - - "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n - - "qwen_vl" will always use the tokenizer from `qwen_vl`.\n + if available, otherwise it will use the "hf" tokenizer. + - "hf" will use the fast tokenizer if available. + - "slow" will always use the slow tokenizer. + - "mistral" will always use the tokenizer from `mistral_common`. + - "deepseek_v32" will always use the tokenizer from `deepseek_v32`. + - "qwen_vl" will always use the tokenizer from `qwen_vl`. - Other custom values can be supported via plugins.""" trust_remote_code: bool = False """Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.""" dtype: ModelDType | torch.dtype = "auto" - """Data type for model weights and activations:\n + """Data type for model weights and activations: + - "auto" will use FP16 precision for FP32 and FP16 models, and BF16 - precision for BF16 models.\n - - "half" for FP16. Recommended for AWQ quantization.\n - - "float16" is the same as "half".\n - - "bfloat16" for a balance between precision and range.\n - - "float" is shorthand for FP32 precision.\n + precision for BF16 models. + - "half" for FP16. Recommended for AWQ quantization. + - "float16" is the same as "half". + - "bfloat16" for a balance between precision and range. + - "float" is shorthand for FP32 precision. - "float32" for FP32 precision.""" seed: int = 0 """Random seed for reproducibility. @@ -182,13 +184,14 @@ class ModelConfig: automatically derived from the model config. When passing via `--max-model-len`, supports k/m/g/K/M/G in human-readable - format. Examples:\n - - 1k -> 1000\n - - 1K -> 1024\n - - 25.6k -> 25,600\n + format. Examples: + + - 1k -> 1000 + - 1K -> 1024 + - 25.6k -> 25,600 - -1 or 'auto' -> Automatically choose the maximum model length that fits in - GPU memory. This will use the model's maximum context length if it fits, - otherwise it will find the largest length that can be accommodated.""" + GPU memory. This will use the model's maximum context length if it fits, + otherwise it will find the largest length that can be accommodated.""" spec_target_max_model_len: int | None = None """Specify the maximum length for spec decoding draft models.""" quantization: QuantizationMethods | str | None = None @@ -248,10 +251,11 @@ class ModelConfig: prometheus metrics, if multiple names provided, metrics tag will take the first one.""" config_format: str | ConfigFormat = "auto" - """The format of the model config to load:\n + """The format of the model config to load: + - "auto" will try to load the config in hf format if available after trying - to load in mistral format.\n - - "hf" will load the config in hf format.\n + to load in mistral format. + - "hf" will load the config in hf format. - "mistral" will load the config in mistral format.""" hf_token: bool | str | None = None """The token to use as HTTP bearer authorization for remote files . If @@ -276,12 +280,12 @@ class ModelConfig: """Enable sleep mode for the engine (only cuda and hip platforms are supported).""" model_impl: str | ModelImpl = "auto" - """Which implementation of the model to use:\n - - "auto" will try to use the vLLM implementation, if it exists, and fall - back to the Transformers implementation if no vLLM implementation is - available.\n - - "vllm" will use the vLLM model implementation.\n - - "transformers" will use the Transformers model implementation.\n + """Which implementation of the model to use: + + - "auto" will try to use the vLLM implementation, if it exists, and fall back to the + Transformers implementation if no vLLM implementation is available. + - "vllm" will use the vLLM model implementation. + - "transformers" will use the Transformers model implementation. - "terratorch" will use the TerraTorch model implementation. """ override_attention_dtype: str | None = None @@ -1512,10 +1516,11 @@ def requires_raw_input_tokens(self) -> bool: @property def score_type(self) -> ScoreType: """ - Scoring API handles score/rerank for:\n - - "classify" task (score_type: cross-encoder models)\n - - "embed" task (score_type: bi-encoder models)\n - - "token_embed" task (score_type: late interaction models)\n + Scoring API handles score/rerank for: + + - "classify" task (score_type: cross-encoder models) + - "embed" task (score_type: bi-encoder models) + - "token_embed" task (score_type: late interaction models) """ # fixme: self._model_info.score_type is the score type before # as_seq_cls_model, which is "bi-encoder", rather than the @@ -1593,9 +1598,10 @@ def head_dtype(self) -> torch.dtype: such as the lm_head in a generation model, or the score or classifier in a classification model. - `head_dtype` currently only supports pooling models.\n - - The pooling model defaults to using fp32 head, - you can use --hf-overrides '{"head_dtype": "model"}' to disable it. + `head_dtype` currently only supports pooling models. + + - The pooling model defaults to using fp32 head, you can use + --hf-overrides '{"head_dtype": "model"}' to disable it. """ head_dtype = _get_head_dtype( diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index 1c9bc43b01ca..e66511c92ab2 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -146,14 +146,14 @@ class MultiModalConfig: parallelism (TP). - `"weights"`: Within the same vLLM engine, split the weights of - each layer across TP ranks. (default TP behavior)\n + each layer across TP ranks. (default TP behavior) - `"data"`: Within the same vLLM engine, split the batched input data - across TP ranks to process the data in parallel, while hosting - the full weights on each TP rank. - This batch-level DP is not to be confused with API request-level - DP (which is controlled by `--data-parallel-size`). - This is only supported on a per-model basis and falls back to - `"weights"` if the encoder does not support DP.""" + across TP ranks to process the data in parallel, while hosting + the full weights on each TP rank. + This batch-level DP is not to be confused with API request-level + DP (which is controlled by `--data-parallel-size`). + This is only supported on a per-model basis and falls back to + `"weights"` if the encoder does not support DP.""" mm_encoder_attn_backend: AttentionBackendEnum | None = None """Optional override for the multi-modal encoder attention backend when using vision transformers. Accepts any value from diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 7dd9c5bb516c..8afff3af258e 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -148,10 +148,11 @@ class ParallelConfig: eplb_config: EPLBConfig = Field(default_factory=EPLBConfig) """Expert parallelism configuration.""" expert_placement_strategy: ExpertPlacementStrategy = "linear" - """The expert placement strategy for MoE layers:\n + """The expert placement strategy for MoE layers: + - "linear": Experts are placed in a contiguous manner. For example, with 4 experts and 2 ranks, rank 0 will have experts [0, 1] and rank 1 will have - experts [2, 3].\n + experts [2, 3]. - "round_robin": Experts are placed in a round-robin manner. For example, with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1 will have experts [1, 3]. This strategy can help improve load balancing @@ -159,11 +160,11 @@ class ParallelConfig: all2all_backend: All2AllBackend = "allgather_reducescatter" """All2All backend for MoE expert parallel communication. Available options: - - "allgather_reducescatter": All2all based on allgather and reducescatter\n - - "deepep_high_throughput": Use deepep high-throughput kernels\n - - "deepep_low_latency": Use deepep low-latency kernels\n - - "mori": Use mori kernels\n - - "nixl_ep": Use nixl-ep kernels\n + - "allgather_reducescatter": All2all based on allgather and reducescatter + - "deepep_high_throughput": Use deepep high-throughput kernels + - "deepep_low_latency": Use deepep low-latency kernels + - "mori": Use mori kernels + - "nixl_ep": Use nixl-ep kernels - "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl - "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels""" diff --git a/vllm/config/profiler.py b/vllm/config/profiler.py index e79e213106db..68fa78854b45 100644 --- a/vllm/config/profiler.py +++ b/vllm/config/profiler.py @@ -37,7 +37,7 @@ class ProfilerConfig: profiler: ProfilerKind | None = None """Which profiler to use. Defaults to None. Options are: - - 'torch': Use PyTorch profiler.\n + - 'torch': Use PyTorch profiler. - 'cuda': Use CUDA profiler.""" torch_profiler_dir: str = "" diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index f988c1086abb..3cd99bb082eb 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -106,11 +106,12 @@ class SchedulerConfig: max_num_batched_tokens in case max multimodal embedding size is larger.""" policy: SchedulerPolicy = "fcfs" - """The scheduling policy to use:\n - - "fcfs" means first come first served, i.e. requests are handled in order - of arrival.\n + """The scheduling policy to use: + + - "fcfs" means first come first served, i.e. requests are handled in order + of arrival. - "priority" means requests are handled based on given priority (lower - value means earlier handling) and time of arrival deciding any ties).""" + value means earlier handling) and time of arrival deciding any ties).""" disable_chunked_mm_input: bool = False """If set to true and chunked prefill is enabled, we do not want to diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 73abd7865642..a953fcb46e42 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -37,10 +37,12 @@ @overload +@dataclass_transform(field_specifiers=(PydanticField,)) def config(cls: type[ConfigT]) -> type[ConfigT]: ... @overload +@dataclass_transform(field_specifiers=(PydanticField,)) def config( *, config: ConfigDict | None = None, **kwargs: Any ) -> Callable[[type[ConfigT]], type[ConfigT]]: ... diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 382b66a70111..b6be7f10bdb0 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -152,13 +152,11 @@ def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool: def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: - """Enable if using AITER RMSNorm and AITER Triton GEMMs - and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion.""" + """Enable if using AITER RMSNorm and hidden size is 2880 i.e. gpt-oss.""" from vllm._aiter_ops import rocm_aiter_ops return ( rocm_aiter_ops.is_rmsnorm_enabled() - and not rocm_aiter_ops.is_triton_gemm_enabled() and cfg.model_config is not None and cfg.model_config.get_hidden_size() == 2880 ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3ba00225fd4e..e1772ab1d427 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1301,7 +1301,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: # delay the Pydantic validation that comes with SpeculativeConfig. vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads) vllm_group.add_argument( - "--speculative-config", **vllm_kwargs["speculative_config"] + "--speculative-config", "-sc", **vllm_kwargs["speculative_config"] ) vllm_group.add_argument( "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"] diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H800,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H800,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..e1ac98ea5d82 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H800,dtype=fp8_w8a8.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000000..f8d56b7ee2da --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.1", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index 671435a88c06..9a6f67b421f9 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( activation_to_flashinfer_int, ) @@ -152,11 +153,8 @@ def apply( import flashinfer from flashinfer.fused_moe import Fp8QuantizationType - # Pack topk_ids and topk_weights into single tensor - # Format: (expert_id << 16) | (weight_bf16.view(int16)) - packed_topk_ids = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view( - torch.int16 - ) + # Pack topk ids and weights into format expected by the kernel. + packed_topk_ids = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights) # trtllm_fp8_block_scale_routed_moe does not support autotuning # so skip this kernel during dummy run for autotuning. diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py index 7960bdf44792..84beb6abb553 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( activation_to_flashinfer_int, ) @@ -183,9 +184,7 @@ def apply( assert self.quant_config.w2_scale is not None # Pack topk ids and weights into format expected by the kernel. - packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( - torch.bfloat16 - ).view(torch.int16) + packed_tensor = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights) # trtllm_fp4_block_scale_routed_moe does not support autotuning # so skip this kernel during dummy run for autotuning. diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index 77df6edf9e94..9008bdeeca7e 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -20,10 +20,7 @@ mxfp4_w4a16_moe_quant_config, ocp_mx_moe_quant_config, ) -from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - _swizzle_mxfp4, - get_padding_alignment, -) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4 from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kMxfp4Static, @@ -396,9 +393,8 @@ def mxfp4_round_up_hidden_size_and_intermediate_size( intermediate_size = round_up(intermediate_size, 128) hidden_size = round_up(hidden_size, 128) elif current_platform.is_rocm(): - pad_align = get_padding_alignment() - intermediate_size = round_up(intermediate_size, pad_align) - hidden_size = round_up(hidden_size, pad_align) + intermediate_size = round_up(intermediate_size, 256) + hidden_size = round_up(hidden_size, 256) else: intermediate_size = round_up(intermediate_size, 64) return hidden_size, intermediate_size diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index ba4494f6cdc3..c576b0a25c28 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -323,3 +323,16 @@ def normalize_batched_scales_shape( @functools.cache def disable_inplace() -> bool: return is_torch_equal_or_newer("2.9") + + +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def trtllm_moe_pack_topk_ids_weights( + topk_ids: torch.Tensor, topk_weights: torch.Tensor +) -> torch.Tensor: + """ + Pack topk_ids and topk_weights into a single int32 tensor. + Format: (expert_id << 16) | weight_bf16.view(int16) + """ + return (topk_ids.to(torch.int32) << 16) | topk_weights.to(torch.bfloat16).view( + torch.int16 + ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e01148313eb7..9e717da43fbd 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -73,7 +73,9 @@ cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, ) -from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight +from vllm.model_executor.model_loader.reload.layerwise import ( + initialize_online_processing, +) from vllm.model_executor.parameter import ( BlockQuantScaleParameter, ModelWeightParameter, @@ -496,8 +498,8 @@ def apply( class Fp8OnlineLinearMethod(Fp8LinearMethod): - """Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint - and quantized the weights during loading.""" + """Online version of Fp8LinearMethod which loads a full precision checkpoint + and quantizes weights during loading.""" uses_meta_device: bool = True @@ -519,84 +521,25 @@ def create_weights( layer.orig_dtype = params_dtype layer.weight_block_size = None - # WEIGHT - def patched_weight_loader(param, loaded_weight, *args, **kwargs): - # track how many elements we have updated - if not hasattr(layer, "_loaded_numel"): - layer._loaded_numel = 0 - - # when the first `loaded_weight` is about to be - # loaded to `param`, materialize `param` just-in-time - weight = ModelWeightParameter( - data=torch.empty_like(layer.weight, device=layer._load_device), - input_dim=1, - output_dim=0, - weight_loader=patched_weight_loader, - ) - _copy_missing_attrs(layer.weight, weight) - layer.register_parameter("weight", weight) - del layer._load_device - - # refresh the reference to `param` to reflect just-in-time - # materialization - param = layer.weight - - # load the current weight chunk - copy_numel_counter = CopyNumelCounter() - with copy_numel_counter: - res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] - layer._loaded_numel += copy_numel_counter.copied_numel - - # if we have loaded all of the elements, call - # process_weights_after_loading - target_loaded_numel = layer.weight.numel() - if layer._loaded_numel == target_loaded_numel: - self.process_weights_after_loading(layer) - - # Prevent the usual `process_weights_after_loading` call from doing - # anything - layer._already_called_process_weights_after_loading = True - - # Note that we keep `layer._loaded_numel` around just in case - # there is logic added to vllm in the future which calls a - # weight loader twice - we do not want to re-initialize in - # that case. - - return res - weight = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition, - # materialized just-in-time in `patched_weight_loader` - device="meta", + device="meta", # materialized and processed during loading dtype=params_dtype, ), input_dim=1, output_dim=0, - weight_loader=patched_weight_loader, + weight_loader=weight_loader, ) - # stash the correct device for `patched_weight_loader` - layer._load_device = torch.get_default_device() layer.register_parameter("weight", weight) + initialize_online_processing(layer) + def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): return - # deferred initialization of randomly initialized weights for the - # `--load_format dummy` feature - if layer.weight.device == torch.device("meta"): - weight = ModelWeightParameter( - data=torch.empty_like(layer.weight, device=layer._load_device), - input_dim=1, - output_dim=0, - weight_loader=layer.weight.weight_loader, - ) - _copy_missing_attrs(layer.weight, weight) - layer.register_parameter("weight", weight) - initialize_single_dummy_weight(layer.weight) - # TODO(future): support block_quant in online quant path assert not self.block_quant @@ -845,9 +788,6 @@ def _setup_kernel( ) def process_weights_after_loading(self, layer: Module) -> None: - if getattr(layer, "_already_called_process_weights_after_loading", False): - return - # Allow for accessing weights and scales in standard way. w13 = layer.w13_weight w2 = layer.w2_weight @@ -892,9 +832,6 @@ def process_weights_after_loading(self, layer: Module) -> None: layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale ) - # Prevent duplicate processing (e.g., during weight reload) - layer._already_called_process_weights_after_loading = True - def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, @@ -1013,86 +950,12 @@ def create_weights( layer.orig_dtype = params_dtype layer.weight_block_size = None - # We are doing online quantization, patch the weight loaded - # to call `process_weights_after_loading` in a streaming fashion - # as soon as the last weight chunk is loaded. - weight_loader = extra_weight_attrs["weight_loader"] - # create a new holder to prevent modifying behavior of any other - # objects which might depend on the old one - new_extra_weight_attrs = extra_weight_attrs - - def patched_weight_loader(param, loaded_weight, *args, **kwargs): - # add a counter to track how many elements we have updated - if not hasattr(layer, "_loaded_numel"): - layer._loaded_numel = 0 - - # save the ids of original w13 and w2 so that we can - # distinguish which one `param` should map to further - # down in this file - layer._w13_weight_orig_id = id(layer.w13_weight) - layer._w2_weight_orig_id = id(layer.w2_weight) - - # when the first `loaded_weight` is about to be - # loaded to `param`, materialize `param` just-in-time - - w13_weight = torch.nn.Parameter( - torch.empty_like(layer.w13_weight, device=layer._load_device), - requires_grad=False, - ) - set_weight_attrs(w13_weight, extra_weight_attrs) - _copy_missing_attrs(layer.w13_weight, w13_weight) - layer.register_parameter("w13_weight", w13_weight) - - w2_weight = torch.nn.Parameter( - torch.empty_like(layer.w2_weight, device=layer._load_device), - requires_grad=False, - ) - set_weight_attrs(w2_weight, extra_weight_attrs) - _copy_missing_attrs(layer.w2_weight, w2_weight) - layer.register_parameter("w2_weight", w2_weight) - del layer._load_device - - # refresh the reference to `param` to reflect just-in-time - # materialization - if id(param) == layer._w13_weight_orig_id: - param = layer.w13_weight - elif id(param) == layer._w2_weight_orig_id: - param = layer.w2_weight - - # load the current weight chunk - copy_numel_counter = CopyNumelCounter() - with copy_numel_counter: - res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] - layer._loaded_numel += copy_numel_counter.copied_numel - - # if we have loaded all of the elements, call - # process_weights_after_loading - target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel() - if layer._loaded_numel == target_loaded_numel: - self.process_weights_after_loading(layer) - - # Prevent the usual `process_weights_after_loading` call - # from doing anything - layer._already_called_process_weights_after_loading = True - - # Note that we keep `layer._loaded_numel`, - # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id` - # around because if EP is on, weight loaders for non-local - # experts will run but not actually copy any elements, and we - # need to not re-initialize in that case. - - return res - - new_extra_weight_attrs["weight_loader"] = patched_weight_loader - extra_weight_attrs = new_extra_weight_attrs - # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( num_experts, 2 * intermediate_size_per_partition, hidden_size, - # materialized just-in-time in `patched_weight_loader` device="meta", dtype=params_dtype, ), @@ -1106,91 +969,53 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): num_experts, hidden_size, intermediate_size_per_partition, - # materialized just-in-time in `patched_weight_loader` - device="meta", + device="meta", # materialized and processed during loading dtype=params_dtype, ), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - # stash the correct device for `patched_weight_loader` - layer._load_device = torch.get_default_device() # BIASES (for models like GPT-OSS that have biased MoE) if self.moe.has_bias: - # Use the original weight_loader (not patched) for biases - orig_extra_weight_attrs = dict(extra_weight_attrs) - orig_extra_weight_attrs["weight_loader"] = weight_loader w13_bias = torch.nn.Parameter( torch.zeros( num_experts, 2 * intermediate_size_per_partition, + device="meta", # materialized and processed during loading dtype=layer.orig_dtype, ), requires_grad=False, ) layer.register_parameter("w13_bias", w13_bias) - set_weight_attrs(w13_bias, orig_extra_weight_attrs) + set_weight_attrs(w13_bias, extra_weight_attrs) + w2_bias = torch.nn.Parameter( - torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype), + torch.zeros( + num_experts, + hidden_size, + device="meta", # materialized and processed during loading + dtype=layer.orig_dtype, + ), requires_grad=False, ) layer.register_parameter("w2_bias", w2_bias) - set_weight_attrs(w2_bias, orig_extra_weight_attrs) - - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_bias, extra_weight_attrs) - layer.w13_input_scale = None - layer.w2_input_scale = None + initialize_online_processing(layer) def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): return - # deferred initialization of randomly initialized weights for the - # `--load_format dummy` feature - if layer.w13_weight.device == torch.device("meta"): - w13_weight = torch.nn.Parameter( - torch.empty_like(layer.w13_weight, device=layer._load_device), - requires_grad=False, - ) - set_weight_attrs( - w13_weight, {"weight_loader": layer.w13_weight.weight_loader} - ) - _copy_missing_attrs(layer.w13_weight, w13_weight) - layer.register_parameter("w13_weight", w13_weight) - initialize_single_dummy_weight(layer.w13_weight) - if layer.w2_weight.device == torch.device("meta"): - w2_weight = torch.nn.Parameter( - torch.empty_like(layer.w2_weight, device=layer._load_device), - requires_grad=False, - ) - set_weight_attrs( - w2_weight, {"weight_loader": layer.w2_weight.weight_loader} - ) - _copy_missing_attrs(layer.w2_weight, w2_weight) - layer.register_parameter("w2_weight", w2_weight) - initialize_single_dummy_weight(layer.w2_weight) - - # If checkpoint is fp16, quantize in place. fp8_dtype = current_platform.fp8_dtype() w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) - w13_scale = layer.w13_weight_scale - w2_scale = layer.w2_weight_scale + w13_scale = torch.ones(layer.num_experts, dtype=torch.float32) + w2_scale = torch.ones(layer.num_experts, dtype=torch.float32) + layer.w13_input_scale = None + layer.w2_input_scale = None for expert in range(layer.local_num_experts): w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant( @@ -1207,8 +1032,8 @@ def process_weights_after_loading(self, layer: Module) -> None: w2, w13_scale, w2_scale, - layer.w13_input_scale, - layer.w2_input_scale, + w13_input_scale=layer.w13_input_scale, + w2_input_scale=layer.w2_input_scale, ) # Prevent duplicate processing (e.g., during weight reload) diff --git a/vllm/model_executor/layers/quantization/mxfp8.py b/vllm/model_executor/layers/quantization/mxfp8.py index 5b4564bea31c..bd29f272bd10 100644 --- a/vllm/model_executor/layers/quantization/mxfp8.py +++ b/vllm/model_executor/layers/quantization/mxfp8.py @@ -337,6 +337,8 @@ def process_weights_after_loading(self, layer: Module) -> None: w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) w13_scale = layer.w13_weight_scale w2_scale = layer.w2_weight_scale + layer.w13_input_scale = None + layer.w2_input_scale = None w13, w13_scale = self._quantize_mxfp8_moe_weight(layer.w13_weight) w2, w2_scale = self._quantize_mxfp8_moe_weight(layer.w2_weight) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 9bc58d2f302d..4fd484edeb30 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -27,10 +27,19 @@ def is_fp4_marlin_supported(): return current_platform.has_device_capability(75) -def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float: +def _nvfp4_compute_scale_factor( + marlin_scales: torch.Tensor, + a_dtype: torch.dtype | None = None, +) -> float: """Compute the power-of-2 scale_factor needed so that all non-zero values in marlin_scales * 2^7 are >= 2 after rescaling. Returns a Python float (power of 2, >= 1.0).""" + + # Since half has a smaller dynamic range compared to bfloat16, + # no rescaling is applied here if active dtype is half. + if a_dtype is not None and a_dtype == torch.half: + return 1.0 + ws_float = marlin_scales.float() * (2**7) nonzero_mask = ws_float > 0 if nonzero_mask.any(): @@ -44,6 +53,7 @@ def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float: def nvfp4_marlin_process_scales( marlin_scales: torch.Tensor, scale_factor: float | None = None, + a_dtype: torch.dtype | None = None, ) -> tuple[torch.Tensor, float]: """Process NVFP4 weight scales into the special S0E5M3 format for Marlin. @@ -91,7 +101,7 @@ def nvfp4_marlin_process_scales( # to fully utilize the E4M3 dynamic range (e.g., global_scale=1). # The caller must compensate by dividing global_scale by scale_factor. if scale_factor is None: - scale_factor = _nvfp4_compute_scale_factor(marlin_scales) + scale_factor = _nvfp4_compute_scale_factor(marlin_scales, a_dtype) if scale_factor > 1.0: marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half) @@ -119,12 +129,14 @@ def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None): return marlin_scales -def nvfp4_marlin_process_global_scale(global_scale): - assert global_scale.dtype in [torch.half, torch.bfloat16] +def nvfp4_marlin_process_global_scale(global_scale, a_dtype: torch.dtype | None = None): + if a_dtype is None: + a_dtype = global_scale.dtype + assert a_dtype in [torch.half, torch.bfloat16] fp4_exponent = 2 - if global_scale.dtype == torch.half: + if a_dtype == torch.half: target_exponent = 5 - elif global_scale.dtype == torch.bfloat16: + elif a_dtype == torch.bfloat16: target_exponent = 8 # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 @@ -244,11 +256,15 @@ def prepare_fp4_layer_for_marlin( ) if is_nvfp4: - weight_scale, scale_factor = nvfp4_marlin_process_scales(weight_scale) + weight_scale, scale_factor = nvfp4_marlin_process_scales( + weight_scale, a_dtype=param_dtype + ) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) - weight_global_scale = layer.weight_global_scale.to(param_dtype) - weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale) + weight_global_scale = layer.weight_global_scale.to(torch.float32) + weight_global_scale = nvfp4_marlin_process_global_scale( + weight_global_scale, param_dtype + ) weight_global_scale = weight_global_scale / scale_factor layer.weight_global_scale = torch.nn.Parameter( weight_global_scale, requires_grad=False @@ -339,7 +355,6 @@ def premute_scales( scales: torch.Tensor, g_scales: torch.Tensor, name: str ) -> tuple[torch.Tensor, torch.Tensor]: scales = scales.to(param_dtype) - g_scales = g_scales.to(param_dtype) tensor_list = [] num_shards = 2 if is_act_and_mul else 1 @@ -350,7 +365,7 @@ def premute_scales( # All experts share one global_scale, so compute the max # scale_factor across all experts first, then apply uniformly. - combined_scale_factor = _nvfp4_compute_scale_factor(scales) + combined_scale_factor = _nvfp4_compute_scale_factor(scales, param_dtype) for i in range(E): scale = scales[i].T @@ -362,12 +377,12 @@ def premute_scales( is_a_8bit=is_a_8bit, ) marlin_scales, _ = nvfp4_marlin_process_scales( - marlin_scales, scale_factor=combined_scale_factor + marlin_scales, scale_factor=combined_scale_factor, a_dtype=param_dtype ) tensor_list.append(marlin_scales) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) - g_scales = nvfp4_marlin_process_global_scale(g_scales) + g_scales = nvfp4_marlin_process_global_scale(g_scales, param_dtype) g_scales = g_scales / combined_scale_factor return scales, g_scales @@ -438,7 +453,7 @@ def prepare_moe_fp4_layer_for_marlin( scales = scales.view(torch.float8_e8m0fnu) scales = scales.to(param_dtype) if is_nvfp4: - global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2") tensor_list = [] if "w13" in name: @@ -449,7 +464,7 @@ def prepare_moe_fp4_layer_for_marlin( # For NVFP4: compute unified scale_factor across all experts combined_scale_factor = None if is_nvfp4: - combined_scale_factor = _nvfp4_compute_scale_factor(scales) + combined_scale_factor = _nvfp4_compute_scale_factor(scales, param_dtype) for i in range(e): scale = scales[i].T @@ -463,7 +478,9 @@ def prepare_moe_fp4_layer_for_marlin( ) if is_nvfp4: marlin_scales, _ = nvfp4_marlin_process_scales( - marlin_scales, scale_factor=combined_scale_factor + marlin_scales, + scale_factor=combined_scale_factor, + a_dtype=param_dtype, ) else: marlin_scales = mxfp4_marlin_process_scales( @@ -477,7 +494,7 @@ def prepare_moe_fp4_layer_for_marlin( if is_nvfp4: assert combined_scale_factor is not None - global_scale = nvfp4_marlin_process_global_scale(global_scale) + global_scale = nvfp4_marlin_process_global_scale(global_scale, param_dtype) global_scale = global_scale / combined_scale_factor global_scale = torch.nn.Parameter(global_scale, requires_grad=False) setattr(layer, name + "_weight_scale_2", global_scale) @@ -665,7 +682,7 @@ def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None): ) marlin_scales, scale_factor = nvfp4_marlin_process_scales(marlin_scales) - global_scale = nvfp4_marlin_process_global_scale(global_scale) + global_scale = nvfp4_marlin_process_global_scale(global_scale).to(torch.float32) global_scale = global_scale / scale_factor return weight_ref.T, marlin_qweight, marlin_scales, global_scale diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 49ddc8accc29..21c8aba1d56c 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -6,7 +6,6 @@ from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.triton_utils import triton from vllm.utils.import_utils import has_triton_kernels from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer @@ -49,9 +48,16 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps=8): value_layout = StridedLayout if on_gfx950(): - from triton_kernels.tensor_details.layout import GFX950MXScaleLayout + try: + # triton < 3.6 + from triton_kernels.tensor_details.layout import GFX950MXScaleLayout - scale_layout = GFX950MXScaleLayout + scale_layout = GFX950MXScaleLayout + except ImportError: + # triton >= 3.6 + from triton_kernels.tensor_details.layout import CDNA4MXScaleLayout + + scale_layout = CDNA4MXScaleLayout else: scale_layout = StridedLayout else: @@ -85,14 +91,6 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps=8): return quant_tensor, InFlexData(), scale -def get_padding_alignment(): - return ( - 256 - if triton.runtime.driver.active.get_current_target().arch in ("gfx950",) - else 128 - ) - - def _dequant_mxfp4( x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype ) -> torch.Tensor: diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index e3b965db8aaf..f68405d05f87 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -9,6 +9,7 @@ from vllm.config import ModelConfig, VllmConfig from vllm.config.load import LoadConfig from vllm.logger import init_logger +from vllm.model_executor.model_loader.reload import finalize_layerwise_processing from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, @@ -49,16 +50,13 @@ def load_model( device_config.device if load_config.device is None else load_config.device ) target_device = torch.device(load_device) - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = initialize_model( - vllm_config=vllm_config, model_config=model_config, prefix=prefix - ) - + with set_default_torch_dtype(model_config.dtype), target_device: + model = initialize_model( + vllm_config=vllm_config, model_config=model_config, prefix=prefix + ) log_model_inspection(model) logger.debug("Loading weights on %s ...", load_device) - # Quantization does not happen in `load_weights` but after it self.load_weights(model, model_config) # Log peak GPU memory after loading weights. This is needed @@ -71,6 +69,11 @@ def load_model( scope="local", ) + # Process weights into kernel format. Note that when using online + # quantization, weights are (typically) quantized as they are loaded. + if _has_online_quant(model): + finalize_layerwise_processing(model, model_config) + process_weights_after_loading(model, model_config, target_device) return model.eval() @@ -84,3 +87,12 @@ def log_model_inspection(model: nn.Module) -> None: from vllm.model_inspection import format_model_inspection logger.info("vLLM model structure:\n%s", format_model_inspection(model)) + + +def _has_online_quant(model: nn.Module): + for module in model.modules(): + quant_method = getattr(module, "quant_method", None) + if getattr(quant_method, "uses_meta_device", False): + return True + + return False diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py index 156071f1dae3..5a8b5de6f553 100644 --- a/vllm/model_executor/model_loader/dummy_loader.py +++ b/vllm/model_executor/model_loader/dummy_loader.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch import torch.nn as nn from vllm.config import ModelConfig from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.reload.meta import materialize_meta_tensor +from vllm.model_executor.model_loader.reload.utils import get_layer_tensors from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights @@ -23,6 +26,12 @@ def download_model(self, model_config: ModelConfig) -> None: pass # Nothing to download def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: + # materialize meta tensors as part of online quantization lifecycle + for layer in model.modules(): + for name, param in get_layer_tensors(layer).items(): + if param.device == torch.device("meta"): + setattr(layer, name, materialize_meta_tensor(param)) + # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model, model_config) diff --git a/vllm/model_executor/model_loader/reload/__init__.py b/vllm/model_executor/model_loader/reload/__init__.py index ea0b0bc06ad9..56a9d88ac4e4 100644 --- a/vllm/model_executor/model_loader/reload/__init__.py +++ b/vllm/model_executor/model_loader/reload/__init__.py @@ -21,12 +21,14 @@ __all__ = [ "record_metadata_for_reloading", "initialize_layerwise_reload", + "finalize_layerwise_processing", "finalize_layerwise_reload", "set_torchao_reload_attrs", "support_quantized_model_reload_from_hp_weights", ] from .layerwise import ( + finalize_layerwise_processing, finalize_layerwise_reload, initialize_layerwise_reload, record_metadata_for_reloading, diff --git a/vllm/model_executor/model_loader/reload/layerwise.py b/vllm/model_executor/model_loader/reload/layerwise.py index 21795e63995e..2a174673b91b 100644 --- a/vllm/model_executor/model_loader/reload/layerwise.py +++ b/vllm/model_executor/model_loader/reload/layerwise.py @@ -28,6 +28,7 @@ "get_layerwise_info", "record_metadata_for_reloading", "initialize_layerwise_reload", + "finalize_layerwise_processing", "finalize_layerwise_reload", ] @@ -89,7 +90,7 @@ def initialize_layerwise_reload(model: torch.nn.Module): info = get_layerwise_info(layer) # Skip if the layer has already been initialized - if info.can_process(): + if info.can_load(): continue # Save current tensors for later copying @@ -98,15 +99,21 @@ def initialize_layerwise_reload(model: torch.nn.Module): # Restore layer parameters/buffers onto meta device restore_layer_on_meta(layer, info) - # Track loading progress to determine when to process/copy - info.load_numel = 0 - info.load_numel_total = get_layer_size(layer) + initialize_online_processing(layer) - # Wrap each parameter's weight loader - # Note that nested wrapping will occur for shared tensors - for name, tensor in get_layer_tensors(layer).items(): - if _get_weight_loader(tensor).__name__ != "online_process_loader": - tensor.weight_loader = make_online_process_loader(layer, name) + +def initialize_online_processing(layer: torch.nn.Module): + info = get_layerwise_info(layer) + + # Track loading progress to determine when to process/copy + info.load_numel = 0 + info.load_numel_total = get_layer_size(layer) + + # Wrap each parameter's weight loader + # Note that nested wrapping will occur for shared tensors + for name, tensor in get_layer_tensors(layer).items(): + if _get_weight_loader(tensor).__name__ != "online_process_loader": + tensor.weight_loader = make_online_process_loader(layer, name) def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable: @@ -118,7 +125,7 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla @wraps(original_loader, assigned=("__doc__", "__annotations__")) def online_process_loader(*args, **kwargs): - if not info.can_process(): + if not info.can_load(): # Unfortunately, some qconfigs are set up to load the same weight # multiple times. For example, CT_WNA16 loads `weight_shape` for # each of the qkv partitions. This results in layers loading extra @@ -140,7 +147,7 @@ def online_process_loader(*args, **kwargs): bound_args = loader_signature.bind(*args, **kwargs) bound_args.apply_defaults() - # Cache loaded weights, track loading progress + # Buffer loaded weights, track loading progress info.loaded_weights.append((param_name, bound_args)) num_loaded, ret = get_numel_loaded(original_loader, bound_args) info.load_numel += num_loaded @@ -163,19 +170,26 @@ def online_process_loader(*args, **kwargs): return online_process_loader -def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig): +def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelConfig): """ - Remove the outermost layer of weight loading wrappers. + Apply processing to any layers which were not layerwise processed during loading. + This includes attention layers and layers which have weight elements which are not + loaded (due to padding). This function should be applied after `initialize_layerwise_reload` is applied unwrap the layerwise weight loaders. - Also processes Attention/MLA layers, which must be processed after all other layers + :param model: model to finalize processing for + :param model_config: config needed for applying processing to attention layers """ - model._do_torchao_reload = model._original_do_torchao_reload + if hasattr(model, "_original_do_torchao_reload"): + model._do_torchao_reload = model._original_do_torchao_reload for layer in model.modules(): info = get_layerwise_info(layer) + if not info.can_load(): + info.reset() + continue # Attention/MLA layers are processed after all other layers if isinstance(layer, (Attention, MLAAttention)): @@ -184,17 +198,29 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig) "Layerwise reloading of Q/K/V scale weights is not implemented yet" ) + elif info.kernel_tensors is None: + raise NotImplementedError( + "Layerwise loading of Q/K/V scale weights is not implemented yet" + ) + else: _place_kernel_tensors(layer, info) layer.process_weights_after_loading(model_config.dtype) - # No weights were loaded, place kernel tensors back - elif info.can_process() and info.load_numel <= 0: - _place_kernel_tensors(layer, info) + # No weights were loaded + elif info.load_numel <= 0: + # first load but received no weights. This happens on dummy load + if info.kernel_tensors is None: + materialize_layer(layer) + + # reloading: place kernel tensors back as a fallback + else: + logger.warning("%s: Failed to load weights", layer.__class__.__name__) + _place_kernel_tensors(layer, info) # Process non-attention layers which did not load all elements. This can happen # if the created weight has extra padding elements which are not loaded - # Having too many of these delayed layers can lead to execess memory usage + # Having too many of these delayed layers can lead to excess memory usage # see Limitations(4) elif info.load_numel > 0 and info.load_numel < info.load_numel_total: # type: ignore[operator] logger.debug("%s: Delayed processing", layer.__class__.__name__) @@ -203,20 +229,24 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig) info.reset() +def finalize_layerwise_reload(*args, **kwargs): + finalize_layerwise_processing(*args, **kwargs) + + def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): """ - Finalize layer loading after all weights have been cached. + Finalize layer loading after all weights have been buffered. This function: 1. Materializes the layer onto the target device - 2. Loads all cached weights + 2. Loads all buffered weights 3. Runs quantization processing if applicable 4. Copies processed values back to original tensor storage """ # Materialize layer tensors onto device materialize_layer(layer) - # Reset FP8 online quantization flag so process_weights_after_loading + # Reset online quantization flag so process_weights_after_loading # will run again during reload if hasattr(layer, "_already_called_process_weights_after_loading"): delattr(layer, "_already_called_process_weights_after_loading") @@ -225,7 +255,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): for param in get_layer_tensors(layer).values(): param.weight_loader = _get_original_loader(param) - # Load all cached weights into materialized layer (using original loaders) + # Load all buffered weights into materialized layer (using original loaders) for name, args in info.loaded_weights: param = getattr(layer, name) args.arguments["param"] = param @@ -239,13 +269,14 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): # Copy processed values into original tensor storage (preserves cudagraph refs) # this code is a no-op if not reloading (because kernel tensors is empty) - parameters, buffers = info.kernel_tensors - for name, param in parameters.items(): - param.data.copy_(getattr(layer, name)) - for name, buffer in buffers.items(): - buffer.data.copy_(getattr(layer, name)) + if info.kernel_tensors is not None: + parameters, buffers = info.kernel_tensors + for name, param in parameters.items(): + param.data.copy_(getattr(layer, name)) + for name, buffer in buffers.items(): + buffer.data.copy_(getattr(layer, name)) - _place_kernel_tensors(layer, info) + _place_kernel_tensors(layer, info) info.reset() logger.debug("%s: Processed", layer.__class__.__name__) @@ -268,6 +299,7 @@ def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo): for name in get_layer_tensors(layer): delattr(layer, name) + assert info.kernel_tensors is not None parameters, buffers = info.kernel_tensors for name, param in parameters.items(): layer.register_parameter(name, param) diff --git a/vllm/model_executor/model_loader/reload/meta.py b/vllm/model_executor/model_loader/reload/meta.py index af20236d1c9d..138b9f01d69b 100644 --- a/vllm/model_executor/model_loader/reload/meta.py +++ b/vllm/model_executor/model_loader/reload/meta.py @@ -104,7 +104,7 @@ def materialize_layer(layer: torch.nn.Module) -> None: setattr(layer, name, materialize_meta_tensor(tensor)) -class MetaCopyCounter(TorchDispatchMode): +class CopyCounter(TorchDispatchMode): """ Tracks total number of elements modified with `copy_`. @@ -122,7 +122,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - if func is torch.ops.aten.copy_.default and args[0].device.type == "meta": + if func is torch.ops.aten.copy_.default: assert args[0].numel() == args[1].numel() self.copied_numel += args[0].numel() @@ -140,7 +140,6 @@ def get_numel_loaded( :return: number of elements loaded by the weight loader, the return value of the weight loader """ - assert args.arguments["param"].device.type == "meta" - with MetaCopyCounter() as counter: + with CopyCounter() as counter: return_value = weight_loader(*args.args, **args.kwargs) return counter.copied_numel, return_value diff --git a/vllm/model_executor/model_loader/reload/types.py b/vllm/model_executor/model_loader/reload/types.py index a7edbe79a75e..b1506fadcc71 100644 --- a/vllm/model_executor/model_loader/reload/types.py +++ b/vllm/model_executor/model_loader/reload/types.py @@ -16,8 +16,8 @@ class LayerReloadingInfo: # model format (meta), populated by `record_metadata_for_reloading` restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {})) - # kernel format (device) - kernel_tensors: LayerTensors = field(default_factory=lambda: ({}, {})) + # kernel format (device), used to copy into when reloading only + kernel_tensors: LayerTensors | None = None # track how many restored elements are ready for loading load_numel: int = 0 @@ -29,5 +29,5 @@ class LayerReloadingInfo: def reset(self): self.__init__(restore_metadata=self.restore_metadata) # type: ignore[misc] - def can_process(self) -> bool: + def can_load(self) -> bool: return self.load_numel_total is not None diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index fbaaef59de0b..37023d3f1f5c 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -1323,25 +1323,11 @@ def initialize_dummy_weights( is fixed, the random values generated by this function only depends on the parameter's number of elements and its data type. """ - - # Check if any module uses online quantization with meta device weights. - # If so, we'll skip initializing params on meta device since they'll be - # handled in `process_weights_after_loading`. - def uses_meta_device(module: torch.nn.Module) -> bool: - quant_method = getattr(module, "quant_method", None) - return getattr(quant_method, "uses_meta_device", False) - - has_online_quant = any(uses_meta_device(m) for m in model.modules()) - for param in model.state_dict().values(): - if has_online_quant and param.device == torch.device("meta"): - # For online quantization, weights are created on meta device and - # dummy weight init will happen in `process_weights_after_loading`. - continue - initialize_single_dummy_weight(param, low, high, seed) +@torch.no_grad() def initialize_single_dummy_weight( param: torch.Tensor, low: float = -1e-3, diff --git a/vllm/utils/argparse_utils.py b/vllm/utils/argparse_utils.py index e4482d4fb63f..c48edb68f20a 100644 --- a/vllm/utils/argparse_utils.py +++ b/vllm/utils/argparse_utils.py @@ -31,14 +31,12 @@ class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpForma def _split_lines(self, text, width): """ 1. Sentences split across lines have their single newlines removed. - 2. Paragraphs and explicit newlines are split into separate lines. + 2. Paragraphs and lists are split into separate lines. 3. Each line is wrapped to the specified width (width of terminal). """ - # The patterns also include whitespace after the newline - single_newline = re.compile(r"(? tuple[BatchExecutionDescriptor, torch.Tensor | None]: + batch_desc = cudagraph_manager.dispatch( + num_reqs, num_tokens, uniform_token_count + ) + num_tokens_across_dp = None + if self.dp_size > 1: + batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding( + cudagraph_manager, + batch_desc, + num_tokens, + num_reqs, + uniform_token_count, + self.dp_size, + self.dp_rank, + ) + return batch_desc, num_tokens_across_dp + + def _build_draft_attn_metadata( + self, + num_reqs: int, + num_reqs_padded: int, + num_tokens_padded: int, + max_query_len: int, + ) -> dict[str, Any] | None: + if not self.draft_attn_layer_names: + return None + + query_start_loc_cpu = ( + torch.arange(num_reqs_padded + 1, dtype=torch.int32, device="cpu").clamp_( + max=num_reqs + ) + * max_query_len + ) + block_tables = [ + x[:num_reqs_padded] for x in self.block_tables.input_block_tables + ] + slot_mappings = self.block_tables.slot_mappings[:, :num_tokens_padded] + attn_metadata = build_attn_metadata( + attn_groups=self.attn_groups, + num_reqs=num_reqs_padded, + num_tokens=num_tokens_padded, + query_start_loc_gpu=self.input_buffers.query_start_loc[ + : num_reqs_padded + 1 + ], + query_start_loc_cpu=query_start_loc_cpu, + max_query_len=max_query_len, + seq_lens=self.input_buffers.seq_lens[:num_reqs_padded], + max_seq_len=self.max_model_len, + block_tables=block_tables, + slot_mappings=slot_mappings, + kv_cache_config=self.kv_cache_config, + ) + return attn_metadata + def capture_model(self) -> None: if self.num_speculative_steps == 1: return @@ -319,7 +382,6 @@ def propose( logits = self.model.compute_logits(sample_hidden_states) num_reqs = input_batch.num_reqs - num_reqs_padded = input_batch.num_reqs_after_padding # NOTE(woosuk): For draft sampling, we only consider the temperature # and ignore the other sampling parameters such as top_k and top_p, # for simplicity and performance. @@ -366,69 +428,49 @@ def propose( self.max_num_reqs, ) - # Get batch descriptor and sync across DP ranks. - # Eagle uses FULL-only mode, dispatch with uniform_token_count=1 for decode - - batch_desc = self.cudagraph_manager.dispatch(num_reqs, num_reqs, 1) - num_tokens_across_dp = None - - if self.dp_size > 1: - batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding( - self.cudagraph_manager, - batch_desc, - num_reqs, - num_reqs, - 1, # uniform_token_count - self.dp_size, - self.dp_rank, - ) - - if not (dummy_run and skip_attn_for_dummy_run): - query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] - slot_mappings = self.block_tables.compute_slot_mappings( - idx_mapping, query_start_loc, pos, batch_desc.num_tokens - ) - - if batch_desc.cg_mode == CUDAGraphMode.FULL: - return self.cudagraph_manager.run_fullgraph(batch_desc)[:num_reqs] + # Each request produces exactly 1 token per draft decode step, + # enabling FULL cudagraph. + decode_batch_desc, num_tokens_across_dp = self._dispatch_and_sync_dp( + self.cudagraph_manager, + num_reqs, + num_reqs, + uniform_token_count=1, + ) - # Run eager or piecewise CUDA graph. attn_metadata_updated = None slot_mappings_updated = None if not (dummy_run and skip_attn_for_dummy_run): - query_start_loc_cpu = torch.arange( - num_reqs_padded + 1, dtype=torch.int32, device="cpu" - ) - block_tables = [ - x[:num_reqs_padded] for x in self.block_tables.input_block_tables - ] - - # FIXME(woosuk): This is UNSAFE!! - attn_metadata_updated = build_attn_metadata( - attn_groups=self.attn_groups, - num_reqs=num_reqs_padded, - num_tokens=num_reqs_padded, - query_start_loc_gpu=query_start_loc, - query_start_loc_cpu=query_start_loc_cpu, - max_query_len=1, - seq_lens=self.input_buffers.seq_lens[:num_reqs_padded], - max_seq_len=self.max_model_len, - block_tables=block_tables, - slot_mappings=slot_mappings, - kv_cache_config=self.kv_cache_config, + # Build attention metadata and slot mappings for the draft + # decode steps. It is necessary to rebuild the attention + # metadata even when replaying the FULL cudagraph so that + # any attention metadata builder state is updated. + slot_mappings = self.block_tables.compute_slot_mappings( + idx_mapping, + self.input_buffers.query_start_loc[: num_reqs + 1], + pos, + decode_batch_desc.num_tokens, ) slot_mappings_updated = build_slot_mappings_by_layer( slot_mappings, self.kv_cache_config ) + attn_metadata_updated = self._build_draft_attn_metadata( + num_reqs=num_reqs, + num_reqs_padded=decode_batch_desc.num_reqs or num_reqs, + num_tokens_padded=decode_batch_desc.num_tokens, + max_query_len=1, + ) - self.generate_draft( - num_reqs, - batch_desc.num_tokens, - attn_metadata_updated, - slot_mappings_updated, - num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=batch_desc.cg_mode, - ) + if decode_batch_desc.cg_mode == CUDAGraphMode.FULL: + self.cudagraph_manager.run_fullgraph(decode_batch_desc) + else: + self.generate_draft( + num_reqs, + decode_batch_desc.num_tokens, + attn_metadata_updated, + slot_mappings_updated, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=decode_batch_desc.cg_mode, + ) return self.draft_tokens[:num_reqs]