diff --git a/.gitignore b/.gitignore index 896b38a12..056853376 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,7 @@ save* .log *.pid *.ipynb* +model/ +output_* +HiFloat4/ +datasets/ \ No newline at end of file diff --git a/configs/quantization/backend/vllm/fp8/industrialcoder_rtn_fp8_wikitext.yml b/configs/quantization/backend/vllm/fp8/industrialcoder_rtn_fp8_wikitext.yml new file mode 100644 index 000000000..d7a02d3c9 --- /dev/null +++ b/configs/quantization/backend/vllm/fp8/industrialcoder_rtn_fp8_wikitext.yml @@ -0,0 +1,33 @@ +base: + seed: &seed 42 +model: + type: IndustrialCoder + path: model/IndustrialCoder-32B + tokenizer_mode: slow + torch_dtype: auto + # Reduce peak memory in catcher stage for large models. + use_cpu_to_save_cuda_mem_for_catcher: False +eval: + eval_pos: [fake_quant] + name: wikitext2 + download: True + seq_len: 2048 + bs: 1 + inference_per_block: False +quant: + method: RTN + weight: + quant_type: float-quant + bit: e4m3 + symmetric: True + granularity: per_channel + use_qtorch: True + act: + quant_type: float-quant + bit: e4m3 + symmetric: True + granularity: per_token + use_qtorch: True +save: + save_vllm: True + save_path: ./save_for_vllm/industrialcoder_rtn_fp8_wikitext/ diff --git a/configs/quantization/backend/vllm/fp8/industrialcoder_rtn_int_awq_wikitext.yml b/configs/quantization/backend/vllm/fp8/industrialcoder_rtn_int_awq_wikitext.yml new file mode 100644 index 000000000..7e19446af --- /dev/null +++ b/configs/quantization/backend/vllm/fp8/industrialcoder_rtn_int_awq_wikitext.yml @@ -0,0 +1,41 @@ +base: + seed: &seed 42 +model: + type: IndustrialCoder + path: model/IndustrialCoder-32B + tokenizer_mode: slow + torch_dtype: auto + # Reduce peak memory in catcher stage for large models. 
+ use_cpu_to_save_cuda_mem_for_catcher: False +calib: + name: pileval + download: True + # path: calib data path + n_samples: 128 + bs: -1 + seq_len: 512 + preproc: txt_general_preproc + seed: *seed +eval: + eval_pos: [fake_quant] + name: wikitext2 + download: True + seq_len: 2048 + bs: 20 + inference_per_block: True +quant: + method: Awq + weight: + bit: 4 + symmetric: True + granularity: per_group + group_size: 128 + need_pack: True + special: + trans: True + trans_version: v2 + weight_clip: True + quant_out: True +save: + save_vllm: True + save_path: ./save_for_vllm/industrialcoder_rtn_int_awq_wikitext/ diff --git a/configs/quantization/backend/vllm/fp8/industrialcoder_rtn_int_gptq_wikitext.yml b/configs/quantization/backend/vllm/fp8/industrialcoder_rtn_int_gptq_wikitext.yml new file mode 100644 index 000000000..fec3c68fe --- /dev/null +++ b/configs/quantization/backend/vllm/fp8/industrialcoder_rtn_int_gptq_wikitext.yml @@ -0,0 +1,43 @@ +base: + seed: &seed 42 +model: + type: IndustrialCoder + path: model/IndustrialCoder-32B + tokenizer_mode: slow + torch_dtype: auto + # Reduce peak memory in catcher stage for large models. 
+ use_cpu_to_save_cuda_mem_for_catcher: False +calib: + name: wikitext2 + download: True + n_samples: 128 + # path: calib data path + bs: 1 + seq_len: 2048 + preproc: wikitext2_gptq + seed: *seed +eval: + eval_pos: [fake_quant] + name: wikitext2 + download: True + seq_len: 2048 + bs: 20 + inference_per_block: True +quant: + method: GPTQ + weight: + bit: 4 + symmetric: True + granularity: per_group + group_size: 128 + need_pack: True + special: + actorder: True + static_groups: True + percdamp: 0.01 + blocksize: 128 + true_sequential: True + quant_out: True +save: + save_vllm: True + save_path: ./save_for_vllm/industrialcoder_rtn_int_gptq_wikitext/ diff --git a/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml new file mode 100644 index 000000000..bf66b40bc --- /dev/null +++ b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml @@ -0,0 +1,54 @@ +base: + seed: &seed 42 +model: + type: Wan2T2V + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/model/Wan2.2-T2V-A14B-Diffusers + torch_dtype: auto +calib: + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + sample_steps: 20 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed +eval: + eval_pos: [transformed, fake_quant] + type: video_gen + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_awq/ + inference_per_block: True +quant: + video_gen: + method: Awq + weight: + # quant_type: int-quant + quant_type: hif4 + bit: 4 + symmetric: True + granularity: per_channel + group_size: -1 + act: + # quant_type: int-qu + quant_type: hif4 + bit: 4 + symmetric: True + granularity: per_token + special: + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True +save: + save_lightx2v: True + save_path: ./save_for_lightx2v/wan2_2_t2v/awq_w_a/original/ diff --git 
a/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml b/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml index 680fab43b..1b1097ad7 100755 --- a/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml +++ b/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml @@ -2,7 +2,7 @@ base: seed: &seed 42 model: type: WanI2V - path: /path/to/model + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.2-T2V-A14B/ torch_dtype: auto calib: name: i2v @@ -31,12 +31,12 @@ quant: video_gen: method: Awq weight: - bit: 8 + bit: 4 symmetric: True granularity: per_channel group_size: -1 act: - bit: 8 + bit: 4 symmetric: True granularity: per_token special: @@ -46,4 +46,4 @@ quant: clip_sym: True save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: ../lightx2v/wan_i2v_awq_w_a/x2v/ diff --git a/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8_example.yaml b/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8_example.yaml new file mode 100644 index 000000000..adba728d0 --- /dev/null +++ b/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8_example.yaml @@ -0,0 +1,57 @@ +# Wan2.1 I2V FP8 量化配置示例 +# 这是一个快速开始的配置文件,请根据实际情况修改路径 + +base: + seed: &seed 42 + +model: + type: WanI2V + path: /path/to/wan2.1-i2v-model # 修改为你的 Wan2.1 I2V 模型路径 + torch_dtype: auto + +calib: + name: i2v + download: False + path: /path/to/calibration/data # 修改为你的校准数据路径 + sample_steps: 40 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed + +eval: + eval_pos: [fake_quant] + type: video_gen + name: i2v + download: False + path: /path/to/eval/data # 修改为你的评估数据路径 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_fp8/ + +quant: + video_gen: + method: SmoothQuant + weight: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_channel + use_qtorch: True + act: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: 
True + granularity: per_token + use_qtorch: True + special: + alpha: 0.75 # SmoothQuant 平衡参数,范围 0.5-1.0 + +save: + save_lightx2v: True # 保存为 lightx2v 兼容格式 + save_path: /path/to/save/quantized/model # 修改为你的保存路径 diff --git a/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml b/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml index 14d05479d..ec6d8714e 100755 --- a/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml +++ b/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml @@ -2,12 +2,12 @@ base: seed: &seed 42 model: type: WanT2V - path: /path/to/wan_t2v + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.1-T2V-14B-Diffusers torch_dtype: auto calib: name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ sample_steps: 20 bs: 1 target_height: 480 @@ -20,7 +20,7 @@ eval: type: video_gen name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ bs: 1 target_height: 480 target_width: 832 @@ -31,12 +31,12 @@ quant: video_gen: method: Awq weight: - bit: 6 + bit: 4 symmetric: True granularity: per_channel group_size: -1 act: - bit: 6 + bit: 4 symmetric: True granularity: per_token special: @@ -46,4 +46,4 @@ quant: clip_sym: True save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: ../lightx2v/wan_t2v_awq_w_a/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/awq_w_a_s.yaml b/configs/quantization/video_gen/wan_t2v/awq_w_a_s.yaml new file mode 100755 index 000000000..f140839e3 --- /dev/null +++ b/configs/quantization/video_gen/wan_t2v/awq_w_a_s.yaml @@ -0,0 +1,49 @@ +base: + seed: &seed 42 +model: + type: WanT2V + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.1-T2V-1.3B-Diffusers + torch_dtype: auto +calib: + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + sample_steps: 20 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed +eval: + eval_pos: [transformed, fake_quant] + type: video_gen + name: 
t2v + download: False + path: ./assets/wan_t2v/calib/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_awq/ +quant: + video_gen: + method: Awq + weight: + bit: 4 + symmetric: True + granularity: per_channel + group_size: -1 + act: + bit: 4 + symmetric: True + granularity: per_token + special: + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True +save: + save_lightx2v: True + save_path: ../lightx2v/wan_t2v_awq_w_a_s/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml b/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml deleted file mode 100755 index b6a53b0e0..000000000 --- a/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml +++ /dev/null @@ -1,32 +0,0 @@ -base: - seed: &seed 42 -model: - type: WanT2V - path: /path/to/wan_t2v - torch_dtype: auto -eval: - eval_pos: [transformed, fake_quant] - type: video_gen - name: t2v - download: False - path: ../assets/wan_t2v/eval/ - bs: 1 - target_height: 480 - target_width: 832 - num_frames: 81 - guidance_scale: 5.0 - output_video_path: ./output_videos_rtn/ -quant: - video_gen: - method: RTN - weight: - bit: 6 - symmetric: True - granularity: per_channel - act: - bit: 6 - symmetric: True - granularity: per_token -save: - save_lightx2v: True - save_path: /path/to/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml b/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml index 7d65f31fc..f76edd294 100755 --- a/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml +++ b/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml @@ -2,12 +2,12 @@ base: seed: &seed 42 model: type: WanT2V - path: /path/to/wan_t2v + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.2-T2V-14B-Diffusers torch_dtype: auto calib: name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ sample_steps: 20 bs: 1 target_height: 480 @@ -20,26 +20,30 @@ eval: type: 
video_gen name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ bs: 1 target_height: 480 target_width: 832 num_frames: 81 guidance_scale: 5.0 - output_video_path: ./output_videos_sq/ + output_video_path: ./output_videos_awq/ quant: video_gen: - method: SmoothQuant + method: Awq weight: - bit: 6 + bit: 4 symmetric: True granularity: per_channel + group_size: -1 act: - bit: 6 + bit: 4 symmetric: True granularity: per_token special: - alpha: 0.7 + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: ../lightx2v/wan_t2v_awq_w_a/x2v/ diff --git a/docs/wan2.1_quantization_guide.md b/docs/wan2.1_quantization_guide.md new file mode 100644 index 000000000..eeef5ac63 --- /dev/null +++ b/docs/wan2.1_quantization_guide.md @@ -0,0 +1,288 @@ +# Wan2.1 视频生成模型量化指南 + +## 概述 + +llmc 框架现已全面支持 Wan2.1 系列视频生成模型的量化,并提供真正量化的 INT8/FP8 权重导出,与 lightx2v 推理框架兼容。 + +## 支持的模型类型 + +- **WanI2V**: Image-to-Video (图像到视频) +- **WanT2V**: Text-to-Video (文本到视频) + +## 支持的量化方法 + +### FP8 量化 (推荐) + +**配置文件**: `configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8.yaml` + +**特点**: +- 使用 E4M3 FP8 格式 (8-bit 浮点数,4位指数,3位尾数) +- SmoothQuant 算法,平衡权重和激活的量化难度 +- 适合 GPU 推理,性能损失小 + +**量化配置**: +```yaml +quant: + video_gen: + method: SmoothQuant + weight: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_channel + use_qtorch: True + act: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_token + use_qtorch: True + special: + alpha: 0.75 # SmoothQuant 平衡参数 +``` + +### INT8 量化 + +#### 1. RTN (Round-to-Nearest) +**配置文件**: `configs/quantization/video_gen/wan_i2v/rtn_w_a.yaml` + +**特点**: +- 最简单的量化方法 +- 直接四舍五入到最近的量化级别 +- 速度快,精度略低 + +#### 2. AWQ (Activation-aware Weight Quantization) +**配置文件**: `configs/quantization/video_gen/wan_i2v/awq_w_a.yaml` + +**特点**: +- 基于激活分布优化权重量化 +- 保护重要通道,减少精度损失 +- 需要校准数据 + +#### 3. 
SmoothQuant +**配置文件**: `configs/quantization/video_gen/wan_i2v/smoothquant_w_a.yaml` + +**特点**: +- 平衡权重和激活的量化难度 +- 数学上等价于平滑激活异常值 +- 通常提供最佳精度 + +### LoRA 模型量化 + +支持对 LoRA 适配器模型的量化: +- `smoothquant_w_a_int8_lora.yaml` +- `rtn_w_a_lora.yaml` + +## 运行步骤 + +### 1. 准备环境 + +```bash +# 设置 llmc 路径 +export llmc=/path/to/llmc +export PYTHONPATH=$llmc:$PYTHONPATH + +# 设置 GPU +export CUDA_VISIBLE_DEVICES=0 +``` + +### 2. 准备校准数据 + +为 I2V 模型准备校准数据: +``` +assets/wan_i2v/calib/ +├── image_1.jpg +├── image_2.jpg +└── ... +``` + +为 T2V 模型准备校准数据: +``` +assets/wan_t2v/calib/ +├── prompt_1.txt +├── prompt_2.txt +└── ... +``` + +### 3. 修改配置文件 + +编辑对应的 YAML 配置文件,设置: +- `model.path`: Wan2.1 模型路径 +- `calib.path`: 校准数据路径 +- `save.save_path`: 量化模型保存路径 + +**示例 (FP8 量化)**: +```yaml +base: + seed: 42 +model: + type: WanI2V + path: /path/to/wan2.1-i2v-model # 修改为你的模型路径 + torch_dtype: auto +calib: + name: i2v + download: False + path: /path/to/calibration/data # 修改为校准数据路径 + sample_steps: 40 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 +save: + save_lightx2v: True + save_path: /path/to/save/quantized/model # 修改为保存路径 +``` + +### 4. 运行量化 + +#### 使用脚本运行 (推荐) + +```bash +# 运行 FP8 量化 (I2V) +./run_llmc.sh wan_i2v_fp8 + +# 运行 INT8 RTN 量化 (I2V) +./run_llmc.sh wan_i2v_int8_rtn + +# 运行 INT8 AWQ 量化 (I2V) +./run_llmc.sh wan_i2v_int8_awq + +# 运行 INT8 SmoothQuant 量化 (I2V) +./run_llmc.sh wan_i2v_int8_smoothquant + +# 运行 T2V 模型量化 +./run_llmc.sh wan_t2v_int8_rtn +./run_llmc.sh wan_t2v_int8_awq +./run_llmc.sh wan_t2v_int8_smoothquant +``` + +#### 直接运行命令 + +```bash +torchrun \ +--nnodes 1 \ +--nproc_per_node 1 \ +--rdzv_id $RANDOM \ +--rdzv_backend c10d \ +--rdzv_endpoint 127.0.0.1:29500 \ +${llmc}/llmc/__main__.py \ +--config configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8.yaml \ +--task_id my_quant_task +``` + +### 5. 监控进度 + +```bash +# 查看日志 +tail -f wan_i2v_fp8.log + +# 查看进程 +ps aux | grep __main__.py +``` + +### 6. 
停止任务 + +```bash +# 使用保存的 PID 文件 +xargs kill -9 < wan_i2v_fp8.pid +``` + +## 配置参数说明 + +### 模型配置 +- `type`: 模型类型 (`WanI2V` 或 `WanT2V`) +- `path`: 模型权重路径 +- `torch_dtype`: 数据类型 (`auto`, `bfloat16`, `float32`) + +### 校准配置 +- `sample_steps`: 采样步数 (通常 20-40) +- `bs`: 批大小 (通常 1,视频生成显存占用大) +- `target_height`: 目标视频高度 (默认 480) +- `target_width`: 目标视频宽度 (默认 832) +- `num_frames`: 视频帧数 (默认 81) +- `guidance_scale`: CFG 引导强度 (默认 5.0) + +### 量化配置 +- `method`: 量化方法 (`RTN`, `Awq`, `SmoothQuant`) +- `weight.bit`: 权重位宽 (8, e4m3) +- `act.bit`: 激活位宽 (8, e4m3) +- `granularity`: 量化粒度 (`per_channel`, `per_token`) +- `special.alpha`: SmoothQuant 平衡参数 (0.5-1.0) + +## 在 lightx2v 中使用量化模型 + +### 1. 配置 lightx2v + +编辑 `lightx2v/configs/quantization/wan_i2v.json`: +```json +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "dit_quantized_ckpt": "/path/to/quantized/model", + "dit_quantized": true, + "dit_quant_scheme": "int8-vllm" +} +``` + +对于 FP8 模型,设置 `"dit_quant_scheme": "fp8"`。 + +### 2. 运行推理 + +```bash +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path /path/to/original/model \ +--config_json configs/quantization/wan_i2v.json \ +--prompt "Your prompt here" \ +--image_path /path/to/input/image.jpg \ +--save_result_path output.mp4 +``` + +## 性能建议 + +1. **FP8 vs INT8**: + - FP8: 精度更高,适合对质量要求高的场景 + - INT8: 压缩率更高,适合对速度要求高的场景 + +2. **量化方法选择**: + - 快速原型: RTN + - 平衡精度和速度: SmoothQuant + - 最高精度: AWQ + +3. **校准数据**: + - 使用 10-50 个样本 + - 覆盖典型使用场景 + - I2V: 使用多样化图像 + - T2V: 使用多样化文本描述 + +4. 
**资源需求**: + - GPU: 建议 24GB+ 显存 + - 校准时间: 30分钟 - 2小时 (取决于数据量) + - 存储空间: 量化后模型约原模型 25-50% 大小 + +## 故障排除 + +### 显存不足 +- 减小 `bs` 到 1 +- 减小 `num_frames` +- 减小 `target_height` 和 `target_width` + +### 量化精度损失过大 +- 尝试 SmoothQuant 方法 +- 增加校准数据数量 +- 调整 `alpha` 参数 (0.5-1.0) + +### lightx2v 兼容性问题 +- 确保使用 `save_lightx2v: True` +- 检查 `dit_quant_scheme` 设置 +- 确认量化模型路径正确 + +## 参考 + +- lightx2v 文档: [lightx2v 项目地址] +- llmc 框架: [llmc 项目地址] +- Wan2.1 模型: [模型地址] diff --git a/llmc/compression/blockwise_optimization.py b/llmc/compression/blockwise_optimization.py index 72823d1bd..380e8f42c 100644 --- a/llmc/compression/blockwise_optimization.py +++ b/llmc/compression/blockwise_optimization.py @@ -31,11 +31,15 @@ def __init__(self, model, compress_config, input, padding_mask, config): def run_block_loop(self): for i in range(len(self.blocks)): self.block_idx = i + if self.input and hasattr(self.model, 'get_blockwise_input'): + self.input = self.model.get_blockwise_input(self.block_idx, self.input) logger.info( f'\nblock index: {self.block_idx}/{len(self.blocks)} ' f'\nblock: {self.blocks[self.block_idx]}' ) self.block_opt(self.blocks[self.block_idx]) + if self.input and hasattr(self.model, 'set_blockwise_input'): + self.model.set_blockwise_input(self.block_idx, self.input) if hasattr(self, 'save_scale') and self.save_scale: os.makedirs(self.scale_path, exist_ok=True) diff --git a/llmc/compression/quantization/__init__.py b/llmc/compression/quantization/__init__.py index 2c08343e2..07b4f5967 100644 --- a/llmc/compression/quantization/__init__.py +++ b/llmc/compression/quantization/__init__.py @@ -10,7 +10,7 @@ from .ntweak import NormTweaking from .omniq import OmniQuant from .osplus import OsPlus -from .quant import FloatQuantizer, IntegerQuantizer +from .quant import FloatQuantizer, HiFloat4Quantizer, IntegerQuantizer from .quarot import Quarot from .quik import QUIK from .rtn import RTN diff --git a/llmc/compression/quantization/base_blockwise_quantization.py 
b/llmc/compression/quantization/base_blockwise_quantization.py index 5a2232699..0c3d5474f 100755 --- a/llmc/compression/quantization/base_blockwise_quantization.py +++ b/llmc/compression/quantization/base_blockwise_quantization.py @@ -35,7 +35,12 @@ _TRANSFORMERS_LN_TYPES_, EffcientFakeQuantLinear, FakeQuantLinear, LlmcActFn, OriginFloatLinear, RotateLinear) -from .quant import FloatQuantizer, IntegerQuantizer, Weight48IntegerQuantizer +from .quant import ( + FloatQuantizer, + HiFloat4Quantizer, + IntegerQuantizer, + Weight48IntegerQuantizer, +) class BaseBlockwiseQuantization(BlockwiseOpt): @@ -157,6 +162,8 @@ def set_quant_config(self): self.weight_quant_module = IntegerQuantizer elif quant_type == 'float-quant': self.weight_quant_module = FloatQuantizer + elif quant_type == 'hif4': + self.weight_quant_module = HiFloat4Quantizer logger.info(f'The used Weight Quant Module is {self.weight_quant_module}') self.wquantizer = self.weight_quant_module(**self.quant_config['weight']) @@ -175,6 +182,13 @@ def set_quant_config(self): self.act_quant_module = IntegerQuantizer elif quant_type == 'float-quant': self.act_quant_module = FloatQuantizer + elif quant_type == 'hif4': + self.act_quant_module = HiFloat4Quantizer + else: + raise ValueError( + f"Unsupported act quant_type: {quant_type}. " + "Supported: int-quant, float-quant, hif4." 
def _get_hif4_quant_cy():
    """Lazily import HiFloat4's ``quant_cy`` module and return its entry points.

    The module is looked up in ``<repo_root>/HiFloat4/hif4_gpu``, where
    ``<repo_root>`` is four directory levels above this file.  That directory
    is placed at the front of ``sys.path`` for the import.  If the import
    fails, the entry added here is removed again so that a missing or unbuilt
    HiFloat4 checkout does not leave a stray, highest-priority search path
    behind for the rest of the process.

    Returns:
        tuple: ``(QType, quant_dequant_float)`` as exported by ``quant_cy``.

    Raises:
        ImportError: If ``quant_cy`` cannot be imported for any reason.  The
            original exception is chained as ``__cause__``.
    """
    repo_root = os.path.abspath(__file__)
    for _ in range(4):  # quant.py -> quantization -> compression -> llmc -> repo root
        repo_root = os.path.dirname(repo_root)
    hif4_gpu = os.path.join(repo_root, 'HiFloat4', 'hif4_gpu')

    inserted = hif4_gpu not in sys.path
    if inserted:
        sys.path.insert(0, hif4_gpu)
    try:
        from quant_cy import QType, quant_dequant_float
    except Exception as e:
        # Broad on purpose: loading a compiled extension can fail with more
        # than ImportError; everything is re-raised as ImportError below.
        if inserted and hif4_gpu in sys.path:
            sys.path.remove(hif4_gpu)
        raise ImportError(
            'HiFloat4 4-bit quantization requires the HiFloat4/hif4_gpu package. '
            'Ensure HiFloat4 is available at repo_root/HiFloat4/hif4_gpu and built. '
            f'(searched: {hif4_gpu})'
        ) from e
    return QType, quant_dequant_float


class HiFloat4Quantizer(BaseQuantizer):
    """4-bit HiFloat (hif4) simulation quantizer built on ``quant_dequant_float``.

    Every ``fake_quant_*`` entry point runs the tensor through HiFloat4's
    ``quant_dequant_float`` with a ``QType('hifx4')`` configured along
    ``q_dim`` and casts the result back to the input dtype.  No scales or
    zero-points are produced by this class.  Only fake (simulation)
    quantization is supported; the ``real_quant_*`` methods raise.

    Recognised ``kwargs`` (all optional):
        hif4_qdim: dimension passed to ``QType.dim`` (default ``-1``).
        force_py: forwarded to ``quant_dequant_float`` (default ``False``).
            NOTE(review): presumably selects a pure-Python path - confirm
            against the HiFloat4 sources.
        force_fp32: forwarded to ``quant_dequant_float`` (default ``True``).
    """

    def __init__(self, bit=4, symmetric=None, granularity=None, **kwargs):
        super().__init__(bit, symmetric, granularity, **kwargs)
        self.quant_type = 'hif4'
        self.q_dim = kwargs.get('hif4_qdim', -1)
        self.force_py = kwargs.get('force_py', False)
        self.force_fp32 = kwargs.get('force_fp32', True)
        # Resolved on first use so that merely constructing the quantizer
        # does not require HiFloat4 to be installed.
        self._QType = None
        self._quant_dequant_float = None

    def _ensure_hif4(self):
        """Import the HiFloat4 entry points once and cache them on the instance."""
        if self._quant_dequant_float is None:
            self._QType, self._quant_dequant_float = _get_hif4_quant_cy()

    def _fake_quant(self, tensor):
        """Quantize-dequantize ``tensor`` with hif4 and return it in its original dtype."""
        self._ensure_hif4()
        org_dtype = tensor.dtype
        qtype = self._QType('hifx4').dim(self.q_dim)
        out = self._quant_dequant_float(
            tensor, qtype, force_py=self.force_py, force_fp32=self.force_fp32
        )
        return out.to(org_dtype)

    # ``args`` is accepted for signature compatibility with the other
    # quantizers and is not used: hif4 has no scales/zeros to pass in.
    # ``None`` replaces the former mutable ``{}`` default.

    def fake_quant_act_static(self, act, args=None):
        """Fake-quantize an activation tensor (static path)."""
        return self._fake_quant(act)

    def fake_quant_act_dynamic(self, act, args=None):
        """Fake-quantize an activation tensor (dynamic path)."""
        return self._fake_quant(act)

    def fake_quant_weight_static(self, weight, args):
        """Fake-quantize a weight tensor (static path)."""
        return self._fake_quant(weight)

    def fake_quant_weight_dynamic(self, weight, args=None):
        """Fake-quantize a weight tensor (dynamic path)."""
        return self._fake_quant(weight)

    def real_quant_weight_static(self, weight, args):
        """Not supported for hif4.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            'HiFloat4 quantizer is simulation-only (fake quant). '
            'real_quant_weight is not supported for hif4.'
        )

    def real_quant_weight_dynamic(self, weight, args=None):
        """Not supported for hif4.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            'HiFloat4 quantizer is simulation-only (fake quant). '
            'real_quant_weight is not supported for hif4.'
        )

    def __repr__(self):
        return (
            f'HiFloat4Quantizer(quant_type=hif4, q_dim={self.q_dim}, '
            f'force_py={self.force_py}, force_fp32={self.force_fp32})'
        )
self.guidance_scale = self.eval_cfg.get('guidance_scale', 5.0) + self.guidance_scale_2 = self.eval_cfg.get('guidance_scale_2', None) self.fps = self.eval_cfg.get('fps', 15) @torch.no_grad() @@ -56,14 +57,17 @@ def t2v_eval(self, model, testenc, bs, eval_pos): assert bs == 1, 'Only support eval bs=1' for i, data in enumerate(testenc): - output = model.Pipeline( - prompt=data['prompt'], - negative_prompt=data['negative_prompt'], - height=self.target_height, - width=self.target_width, - num_frames=self.num_frames, - guidance_scale=self.guidance_scale, - ).frames[0] + pipe_kw = { + 'prompt': data['prompt'], + 'negative_prompt': data['negative_prompt'], + 'height': self.target_height, + 'width': self.target_width, + 'num_frames': self.num_frames, + 'guidance_scale': self.guidance_scale, + } + if self.guidance_scale_2 is not None: + pipe_kw['guidance_scale_2'] = self.guidance_scale_2 + output = model.Pipeline(**pipe_kw).frames[0] export_to_video( output, os.path.join(self.output_video_path, f'{eval_pos}_output_{i}.mp4'), @@ -77,15 +81,18 @@ def i2v_eval(self, model, testenc, bs, eval_pos): for i, data in enumerate(testenc): image, width, height = self.pre_process(model, data['image']) - output = model.Pipeline( - image=image, - prompt=data['prompt'], - negative_prompt=data['negative_prompt'], - height=height, - width=width, - num_frames=self.num_frames, - guidance_scale=self.guidance_scale, - ).frames[0] + pipe_kw = { + 'image': image, + 'prompt': data['prompt'], + 'negative_prompt': data['negative_prompt'], + 'height': height, + 'width': width, + 'num_frames': self.num_frames, + 'guidance_scale': self.guidance_scale, + } + if self.guidance_scale_2 is not None: + pipe_kw['guidance_scale_2'] = self.guidance_scale_2 + output = model.Pipeline(**pipe_kw).frames[0] export_to_video( output, @@ -98,9 +105,9 @@ def i2v_eval(self, model, testenc, bs, eval_pos): @torch.no_grad() def eval_func(self, model, testenc, bs, eval_pos): assert bs == 1, 'Evaluation only supports batch size 
@MODEL_REGISTRY
class IndustrialCoder(BaseModel):
    """IndustrialCoder (IQuestCoder) adapter for blockwise quantization.

    The adapter addresses the wrapped causal-LM as ``self.model`` with the
    decoder stack under ``self.model.model`` (``embed_tokens``, ``layers``,
    ``norm`` and, when present, ``rotary_emb``) and the output projection at
    ``self.model.lm_head``.  Each decoder layer is expected to expose
    ``input_layernorm``, ``self_attn`` (q/k/v/o projections),
    ``post_attention_layernorm`` and ``mlp`` (gate/up/down projections).
    """

    # The model-level rotary embedding is only picked up from this
    # transformers release onwards.
    _ROTARY_MIN_TRANSFORMERS = '4.45.0'

    def __init__(self, config, device_map=None, use_cache=False):
        super().__init__(config, device_map, use_cache)

    @classmethod
    def _has_model_level_rotary(cls):
        """Return True when the installed transformers meets the rotary version gate."""
        # Import the submodule explicitly: a bare ``import packaging`` does not
        # guarantee that ``packaging.version`` has been loaded.
        from packaging.version import parse as parse_version

        return parse_version(version('transformers')) >= parse_version(
            cls._ROTARY_MIN_TRANSFORMERS
        )

    def _rotary_layers(self):
        """Return ``[self.rotary_emb]`` when available and version-gated in, else ``[]``."""
        if not self._has_model_level_rotary():
            return []
        rotary = getattr(self, 'rotary_emb', None)
        return [] if rotary is None else [rotary]

    def find_blocks(self):
        """Store the decoder layers as the blocks to optimise."""
        self.blocks = self.model.model.layers

    def find_embed_layers(self):
        """Locate the token embedding and, if gated in, the model-level rotary embedding."""
        base = self.model.model
        self.embed_tokens = base.embed_tokens
        if hasattr(base, 'rotary_emb') and self._has_model_level_rotary():
            self.rotary_emb = base.rotary_emb

    def find_block_name(self):
        """Set the prefix under which decoder layers are named."""
        self.block_name_prefix = 'model.layers'

    def get_embed_layers(self):
        """Return the embedding layers as a list."""
        return [self.embed_tokens]

    def get_attn_in_block(self, block):
        """Return the attention module of ``block`` keyed by its attribute name."""
        return {'self_attn': block.self_attn}

    def get_attention_rotary_layers(self):
        """Return the rotary embedding layers (possibly empty)."""
        return self._rotary_layers()

    def get_head_layers(self):
        """Return the output head layers as a list."""
        return [self.model.lm_head]

    def get_pre_head_layernorm_layers(self):
        """Return the final normalisation layer that precedes the head."""
        return [self.model.model.norm]

    def get_layers_except_blocks(self):
        """Return every top-level layer that is not part of the decoder blocks."""
        return (
            [self.embed_tokens]
            + self._rotary_layers()
            + [self.model.model.norm, self.model.lm_head]
        )

    def skip_layer_name(self):
        """Return the layer names that are excluded from quantization."""
        return ['lm_head']

    def has_bias(self):
        """Return True if the config enables ``attention_bias`` or ``mlp_bias``."""
        cfg = self.model_config
        return bool(
            getattr(cfg, 'attention_bias', False) or getattr(cfg, 'mlp_bias', False)
        )

    def get_layernorms_in_block(self, block):
        """Return the two layer norms of ``block`` keyed by attribute name."""
        return {
            'input_layernorm': block.input_layernorm,
            'post_attention_layernorm': block.post_attention_layernorm,
        }

    def get_subsets_in_block(self, block):
        """Describe the four linear-layer groups of ``block`` in forward order.

        Each entry lists the target ``layers``, the ``prev_op`` feeding them,
        the layer whose ``input`` is captured, the module to ``inspect`` and
        whether that module takes keyword arguments.
        """
        attn = block.self_attn
        mlp = block.mlp
        return [
            {
                'layers': {
                    'self_attn.q_proj': attn.q_proj,
                    'self_attn.k_proj': attn.k_proj,
                    'self_attn.v_proj': attn.v_proj,
                },
                'prev_op': [block.input_layernorm],
                'input': ['self_attn.q_proj'],
                'inspect': attn,
                'has_kwargs': True,
            },
            {
                'layers': {'self_attn.o_proj': attn.o_proj},
                'prev_op': [attn.v_proj],
                'input': ['self_attn.o_proj'],
                'inspect': attn.o_proj,
                'has_kwargs': False,
            },
            {
                'layers': {
                    'mlp.gate_proj': mlp.gate_proj,
                    'mlp.up_proj': mlp.up_proj,
                },
                'prev_op': [block.post_attention_layernorm],
                'input': ['mlp.gate_proj'],
                'inspect': mlp,
                'has_kwargs': False,
                'is_mlp': True,
            },
            {
                'layers': {'mlp.down_proj': mlp.down_proj},
                'prev_op': [mlp.up_proj],
                'input': ['mlp.down_proj'],
                'inspect': mlp.down_proj,
                'has_kwargs': False,
                'is_mlp': True,
            },
        ]
@MODEL_REGISTRY
class Wan2T2V(BaseModel):
    """Wan2.2-T2V with MoE: two experts (high-noise + low-noise), same block structure as Wan2.1.

    ``self.blocks`` is the concatenation of the blocks of ``transformer`` and,
    when present, ``transformer_2``; ``self.num_transformer_blocks`` marks the
    boundary between the two experts inside that list.
    """

    def __init__(self, config, device_map=None, use_cache=False):
        """Initialise the base model and read calibration settings from ``config.calib``."""
        super().__init__(config, device_map, use_cache)
        if 'calib' in config:
            self.calib_bs = config.calib.bs
            self.sample_steps = config.calib.sample_steps
            self.target_height = config.calib.get('target_height', 480)
            self.target_width = config.calib.get('target_width', 832)
            self.num_frames = config.calib.get('num_frames', 81)
            self.guidance_scale = config.calib.get('guidance_scale', 5.0)
            self.guidance_scale_2 = config.calib.get('guidance_scale_2', 3.0)
        else:
            self.sample_steps = None

    def _has_second_expert(self):
        """Return True when the pipeline carries a non-None ``transformer_2``."""
        return getattr(self.Pipeline, 'transformer_2', None) is not None

    def _experts(self):
        """Return ``(name, transformer)`` pairs for every expert the pipeline has."""
        experts = [('transformer', self.Pipeline.transformer)]
        if self._has_second_expert():
            experts.append(('transformer_2', self.Pipeline.transformer_2))
        return experts

    def build_model(self):
        """Load the VAE and pipeline, then wrap every expert block for llmc."""
        vae = AutoencoderKLWan.from_pretrained(
            self.model_path,
            subfolder='vae',
            torch_dtype=torch.float32,
            use_safetensors=True,
        )
        # Wan2.2: one pipeline, two transformer experts (transformer + transformer_2).
        # Pipeline switches by SNR; both use WanTransformer3DModel with same block layout as Wan2.1.
        self.Pipeline = WanPipeline.from_pretrained(
            self.model_path,
            vae=vae,
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
        )
        self.find_llmc_model()
        # Wrap both experts with LlmcWanTransformerBlock (same as Wan2.1 per-block layout).
        for _, transformer in self._experts():
            for block_idx, block in enumerate(transformer.blocks):
                transformer.blocks[block_idx] = LlmcWanTransformerBlock.new(block)
        # Single source of truth for self.blocks / self.num_transformer_blocks.
        self.find_blocks()
        # loguru formats with str.format-style '{}' placeholders, not '%s'.
        # Block counts are reported from the model rather than hard-coded.
        if self._has_second_expert():
            logger.info(
                'Wan2.2 MoE: both experts wrapped (high-noise + low-noise, {} blocks total).',
                len(self.blocks),
            )
        else:
            logger.info('Wan2.2: single transformer wrapped ({} blocks).', len(self.blocks))
        logger.info('Model: {}', self.model)

    def find_llmc_model(self):
        """Expose the first expert as ``self.model``."""
        self.model = self.Pipeline.transformer

    def find_blocks(self):
        """Collect the blocks of all experts into ``self.blocks``."""
        self.blocks = list(self.Pipeline.transformer.blocks)
        self.num_transformer_blocks = len(self.blocks)
        if self._has_second_expert():
            self.blocks += list(self.Pipeline.transformer_2.blocks)

    def _expert_name_from_block_idx(self, block_idx):
        """Map an index into ``self.blocks`` to the expert owning that block."""
        if block_idx < self.num_transformer_blocks:
            return 'transformer'
        return 'transformer_2'

    def get_blockwise_input(self, block_idx, fallback_input):
        """Return the stored input of the expert owning ``block_idx``.

        ``fallback_input`` is returned when no calibration inputs were collected yet.
        """
        if not hasattr(self, 'blockwise_inputs'):
            return fallback_input
        return self.blockwise_inputs[self._expert_name_from_block_idx(block_idx)]

    def set_blockwise_input(self, block_idx, block_input):
        """Store ``block_input`` for the expert owning ``block_idx`` (no-op before calibration)."""
        if not hasattr(self, 'blockwise_inputs'):
            return
        self.blockwise_inputs[self._expert_name_from_block_idx(block_idx)] = block_input

    def get_catcher(self, first_block_input):
        """Return a ``Catcher`` class that records the inputs of the module it wraps.

        The catcher appends the first positional argument to
        ``first_block_input['data']`` and the remaining arguments (by parameter
        name) to ``first_block_input['kwargs']``. It raises ``ValueError`` once
        ``self.sample_steps`` calls have been recorded.
        """
        sample_steps = self.sample_steps

        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module
                self.signature = inspect.signature(module.forward)
                self.step = 0

            def forward(self, *args, **kwargs):
                # Keep the caller's keyword arguments so they can be forwarded
                # unchanged; the original dropped them when calling the module.
                passthrough = dict(kwargs)
                params = list(self.signature.parameters.keys())
                for i, arg in enumerate(args):
                    if i > 0:
                        kwargs[params[i]] = arg
                first_block_input['data'].append(args[0])
                first_block_input['kwargs'].append(kwargs)
                self.step += 1
                if self.step == sample_steps:
                    raise ValueError
                return self.module(*args, **passthrough)

        return Catcher

    @torch.no_grad()
    def collect_first_block_input(self, calib_data, padding_mask=None):
        """Run the pipeline on ``calib_data`` and record each expert's first-block inputs.

        Up to ``self.sample_steps`` calls are recorded per expert, copied to
        CPU. The results are stored in ``self.blockwise_inputs``,
        ``self.first_block_input`` and ``self.n_samples``. ``padding_mask`` is
        accepted for interface compatibility and not used.
        """
        first_block_input = {
            'transformer': defaultdict(list),
            'transformer_2': defaultdict(list),
        }
        sample_steps = self.sample_steps
        experts = self._experts()
        # Only experts that exist can fill up; waiting on a missing
        # transformer_2 would make the early stop below unreachable.
        active_names = [name for name, _ in experts]

        class Catcher(nn.Module):
            def __init__(self, module, expert_name):
                super().__init__()
                self.module = module
                self.signature = inspect.signature(module.forward)
                self.expert_name = expert_name

            def _to_cpu(self, x):
                if torch.is_tensor(x):
                    return x.detach().cpu()
                if isinstance(x, tuple):
                    return tuple(self._to_cpu(t) for t in x)
                return x

            def forward(self, *args, **kwargs):
                # Forward the caller's own keyword arguments untouched.
                passthrough = dict(kwargs)
                params = list(self.signature.parameters.keys())
                for i, arg in enumerate(args):
                    if i > 0:
                        kwargs[params[i]] = arg
                store = first_block_input[self.expert_name]
                if len(store['data']) < sample_steps:
                    store['data'].append(self._to_cpu(args[0]))
                    store['kwargs'].append(
                        {k: self._to_cpu(v) for k, v in kwargs.items()}
                    )
                if all(
                    len(first_block_input[name]['data']) >= sample_steps
                    for name in active_names
                ):
                    # Every present expert is full: abort the pipeline run.
                    raise ValueError
                return self.module(*args, **passthrough)

        originals = {}
        for name, transformer in experts:
            originals[name] = transformer.blocks[0]
            transformer.blocks[0] = Catcher(originals[name], name)

        try:
            self.Pipeline.to('cuda')
            for data in calib_data:
                try:
                    pipe_kw = {
                        'prompt': data['prompt'],
                        'negative_prompt': data['negative_prompt'],
                        'height': self.target_height,
                        'width': self.target_width,
                        'num_frames': self.num_frames,
                        'guidance_scale': self.guidance_scale,
                    }
                    if hasattr(self, 'guidance_scale_2'):
                        pipe_kw['guidance_scale_2'] = self.guidance_scale_2
                    self.Pipeline(**pipe_kw)
                except ValueError:
                    # Stop signal raised by Catcher once enough inputs are recorded.
                    pass
                gc.collect()
                torch.cuda.empty_cache()
        finally:
            # Always put the real first blocks back, even if the pipeline failed.
            for name, transformer in experts:
                transformer.blocks[0] = originals[name]
            self.Pipeline.to('cpu')

        for name in active_names:
            assert len(first_block_input[name]['data']) > 0, \
                f'Catch {name} input data failed.'

        self.blockwise_inputs = first_block_input
        self.first_block_input = self.blockwise_inputs['transformer']
        self.n_samples = sum(len(v['data']) for v in self.blockwise_inputs.values())
        # loguru uses '{}' placeholders; '%s' would be printed literally.
        logger.info(
            'Retrieved Wan2.2 calibration samples: transformer={}, transformer_2={}.',
            len(self.blockwise_inputs['transformer']['data']),
            len(self.blockwise_inputs['transformer_2']['data']),
        )

    def get_padding_mask(self):
        """Return None; this model does not use a padding mask."""
        return None

    def has_bias(self):
        """Return True unconditionally."""
        return True

    def __str__(self):
        """Return a printable description built from ``str(self.model)``."""
        return '\nWan2.2 MoE Model:\n%s\nTotal params: ~27B (14B active per step)' % (
            str(self.model),
        )

    def get_layernorms_in_block(self, block):
        """Return the three norm modules of ``block`` keyed by attribute name."""
        return {
            'affine_norm1': block.affine_norm1,
            'norm2': block.norm2,
            'affine_norm3': block.affine_norm3,
        }

    def get_subsets_in_block(self, block):
        """Describe the three layer groups handled inside one wrapped block.

        ``sub_keys`` maps a keyword of the inspected module to the key under
        which the matching value is looked up in the recorded block kwargs.
        """
        return [
            {
                'layers': {
                    'attn1.to_q': block.attn1.to_q,
                    'attn1.to_k': block.attn1.to_k,
                    'attn1.to_v': block.attn1.to_v,
                },
                'prev_op': [block.affine_norm1],
                'input': ['attn1.to_q'],
                'inspect': block.attn1,
                'has_kwargs': True,
                'sub_keys': {'rotary_emb': 'rotary_emb'},
            },
            {
                'layers': {
                    'attn2.to_q': block.attn2.to_q,
                },
                'prev_op': [block.norm2],
                'input': ['attn2.to_q'],
                'inspect': block.attn2,
                'has_kwargs': True,
                'sub_keys': {'encoder_hidden_states': 'encoder_hidden_states'},
            },
            {
                'layers': {
                    'ffn.net.0.proj': block.ffn.net[0].proj,
                },
                'prev_op': [block.affine_norm3],
                'input': ['ffn.net.0.proj'],
                'inspect': block.ffn,
                'has_kwargs': True,
            },
        ]

    def find_embed_layers(self):
        """No-op: this adapter does not locate embedding layers."""
        pass

    def get_embed_layers(self):
        """No-op: returns None."""
        pass

    def get_layers_except_blocks(self):
        """No-op: returns None."""
        pass

    def skip_layer_name(self):
        """No-op: returns None."""
        pass
self.model_path, subfolder='vae', torch_dtype=torch.float32 + self.model_path, subfolder='vae', torch_dtype=torch.float32, use_safetensors=True ) + # self.Pipeline = WanPipeline.from_pretrained( + # self.model_path, vae=vae, torch_dtype=torch.bfloat16 + # ) self.Pipeline = WanPipeline.from_pretrained( - self.model_path, vae=vae, torch_dtype=torch.bfloat16 + self.model_path, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True ) self.find_llmc_model() self.find_blocks() diff --git a/llmc/utils/export_vllm.py b/llmc/utils/export_vllm.py index 1128c3df9..87271daf4 100755 --- a/llmc/utils/export_vllm.py +++ b/llmc/utils/export_vllm.py @@ -31,7 +31,8 @@ def update_vllm_quant_config( with open(config_file, 'w') as file: json.dump(config_vllm, file, indent=4) return - elif config.quant.weight.get('granularity', 'per_block'): + # elif config.quant.weight.get('granularity', 'per_block'): + elif config.quant.weight.get('granularity') == 'per_block': quant_config = { 'activation_scheme': 'dynamic', 'fmt': 'e4m3', diff --git a/scripts/run_llmc.sh b/scripts/run_llmc.sh index d90877f69..0bd994f9e 100755 --- a/scripts/run_llmc.sh +++ b/scripts/run_llmc.sh @@ -1,17 +1,20 @@ -#!/bin/bash - -# export CUDA_VISIBLE_DEVICES=0,1 - -llmc=/path/to/llmc +export PATH=/mnt/lm_data_afs/wangzining/charles/miniconda3/envs/llmc/bin:$PATH +export PYTHON=/mnt/lm_data_afs/wangzining/charles/miniconda3/envs/llmc/bin/python +export PIP=/mnt/lm_data_afs/wangzining/charles/miniconda3/envs/llmc/bin/pip +export HF_ENDPOINT=https://hf-mirror.com +cd /mnt/lm_data_afs/wangzining/charles/lab/llmc +# model_name=wan_t2v +model_name=wan2_2_t2v +task_name=awq_w_a +# task_name=awq_w_a_s +log_name=${model_name}_${task_name} +rm -rf ./save_for_lightx2v/${model_name}/${task_name}/original +llmc=. 
export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=awq_w_only -config=${llmc}/configs/quantization/methods/Awq/awq_w_only.yml - +config=${llmc}/configs/quantization/video_gen/${model_name}/${task_name}.yaml nnodes=1 nproc_per_node=1 - find_unused_port() { while true; do port=$(shuf -i 10000-60000 -n 1) @@ -22,25 +25,15 @@ find_unused_port() { done } UNUSED_PORT=$(find_unused_port) - - MASTER_ADDR=127.0.0.1 MASTER_PORT=$UNUSED_PORT task_id=$UNUSED_PORT -nohup \ + torchrun \ --nnodes $nnodes \ --nproc_per_node $nproc_per_node \ --rdzv_id $task_id \ --rdzv_backend c10d \ --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ -${llmc}/llmc/__main__.py --config $config --task_id $task_id \ -> ${task_name}.log 2>&1 & - -sleep 2 -ps aux | grep '__main__.py' | grep $task_id | awk '{print $2}' > ${task_name}.pid - -# You can kill this program by -# xargs kill -9 < xxx.pid -# xxx.pid is ${task_name}.pid file \ No newline at end of file +${llmc}/llmc/__main__.py --config $config --task_id $task_id |tee ${log_name}.log \ No newline at end of file diff --git a/scripts/run_llmc_industrialcoder_fp8.sh b/scripts/run_llmc_industrialcoder_fp8.sh new file mode 100755 index 000000000..6dad8c006 --- /dev/null +++ b/scripts/run_llmc_industrialcoder_fp8.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +model_name=industrialcoder +method_name=rtn_fp8 +dataset_name=wikitext + +log_name=${model_name}_${method_name}_${dataset_name} +rm -rf ./save_for_vllm/${log_name}/ +llmc=. +export PYTHONPATH=$llmc:$PYTHONPATH +config=${llmc}/configs/quantization/backend/vllm/fp8/${log_name}.yml +nnodes=1 +nproc_per_node=8 + +find_unused_port() { + while true; do + port=$(shuf -i 10000-60000 -n 1) + if ! 
def _resolve_model_dir(path):
    """Return the directory holding ``config.json`` for *path*, or None.

    *path* itself is tried first, then its ``vllm_quant_model`` subdirectory,
    so passing the save directory works as well as passing the model
    directory.
    """
    for candidate in (path, os.path.join(path, "vllm_quant_model")):
        if os.path.isfile(os.path.join(candidate, "config.json")):
            return candidate
    return None


def _keys_with_suffix(keys, suffix):
    """Return the keys equal to *suffix* or ending with ``'.' + suffix``."""
    dotted = "." + suffix
    return [k for k in keys if k == suffix or k.endswith(dotted)]


def _print_index_summary(model_dir):
    """Print tensor names listed in ``model.safetensors.index.json``, if present."""
    import json  # only needed on the --no-load-weights path

    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    if not os.path.isfile(index_path):
        print(f"No model.safetensors.index.json found in {model_dir}")
        return
    with open(index_path, encoding="utf-8") as f:
        weight_map = json.load(f).get("weight_map", {})
    ordered = sorted(weight_map)
    print(f"Total tensors in index: {len(ordered)}")
    print("\nFirst 20 keys in weight_map:")
    for k in ordered[:20]:
        print(f" {k}")
    if len(ordered) > 20:
        print(" ...")
    weight_scale_keys = [k for k in ordered if "weight_scale" in k]
    print(f"\nKeys containing 'weight_scale': {len(weight_scale_keys)}")
    for k in weight_scale_keys[:30]:
        print(f" {k}")
    if len(weight_scale_keys) > 30:
        print(f" ... and {len(weight_scale_keys) - 30} more")


def _print_state_dict_summary(state_dict, list_all):
    """Print state_dict keys, the ``weight_scale`` entries and the target-key check."""
    keys = sorted(state_dict.keys())
    print(f"Total keys in state_dict: {len(keys)}\n")

    if list_all:
        print("All state_dict keys:")
        shown = keys
    else:
        print("Sample keys (first 30):")
        shown = keys[:30]
    for k in shown:
        t = state_dict[k]
        print(f" {k} shape={tuple(t.shape)} dtype={t.dtype}")
    if not list_all and len(keys) > 30:
        print(" ...")

    weight_scale_keys = [k for k in keys if "weight_scale" in k]
    print(f"\nKeys containing 'weight_scale': {len(weight_scale_keys)}")
    for k in weight_scale_keys:
        t = state_dict[k]
        print(f" {k} shape={tuple(t.shape)} dtype={t.dtype}")

    # Check for the key that was missing in the error. state_dict keys may
    # carry a module prefix (e.g. "model."), so match on the suffix instead
    # of requiring an exact key.
    target = "layers.0.mlp.down_proj.weight_scale"
    matches = _keys_with_suffix(keys, target)
    if matches:
        print(f"\nKey '{target}' present in state_dict.")
        for k in matches:
            print(f" {k}")
    else:
        print(f"\nKey '{target}' NOT in state_dict.")
        # Show similar keys
        similar = [k for k in keys if "down_proj" in k and "weight_scale" in k]
        if similar:
            print("Similar keys (down_proj + weight_scale):")
            for k in similar[:10]:
                print(f" {k}")


def main():
    """Load a saved vLLM quant model and print a summary of its state_dict.

    Exits with status 1 when the model directory does not exist or no
    ``config.json`` can be found in it (or in its ``vllm_quant_model``
    subdirectory).
    """
    parser = argparse.ArgumentParser(description="Load vLLM quant model and print state_dict")
    parser.add_argument(
        "model_dir",
        nargs="?",
        default="save_for_vllm/industrialcoder_rtn_int_awq_wikitext/",
        help="Path to vllm_quant_model directory",
    )
    parser.add_argument(
        "--list-keys",
        action="store_true",
        help="Print all state_dict keys (default: only summary and weight_scale keys)",
    )
    parser.add_argument(
        "--no-load-weights",
        action="store_true",
        help="Only load config and print expected keys from index (no full model load)",
    )
    parser.add_argument(
        "--cpu",
        action="store_true",
        help="Load model on CPU (default: load on GPU)",
    )
    args = parser.parse_args()

    requested_dir = os.path.abspath(args.model_dir)
    if not os.path.isdir(requested_dir):
        print(f"Error: not a directory: {requested_dir}", file=sys.stderr)
        sys.exit(1)

    model_dir = _resolve_model_dir(requested_dir)
    if model_dir is None:
        print(f"Error: config.json not found in {requested_dir}", file=sys.stderr)
        sys.exit(1)

    print(f"Loading from: {model_dir}\n")

    if args.no_load_weights:
        # Only inspect index / config without loading full model
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        print("Config model_type:", getattr(config, "model_type", "?"))
        _print_index_summary(model_dir)
        return

    device_map = "cpu" if args.cpu else "cuda:0"
    if device_map != "cpu" and not torch.cuda.is_available():
        # Fall back instead of failing deep inside from_pretrained.
        print("Warning: CUDA is not available, loading on CPU instead.", file=sys.stderr)
        device_map = "cpu"
    print(f"Loading full model on {device_map} (may take a while and use significant memory)...")
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map=device_map,
    )

    _print_state_dict_summary(model.state_dict(), args.list_keys)


if __name__ == "__main__":
    main()