From cb7cbacaac0863939b3a15ece6ab88130d4de48c Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Wed, 1 Apr 2026 15:35:38 +0800
Subject: [PATCH 1/2] add glm4_moe_lite support

Signed-off-by: ZX-ModelCloud
---
 README.md                                     |   1 +
 defuser/model_registry.py                     |   6 ++
 defuser/modeling/unfused_moe/glm4_moe_lite.py | 107 ++++++++++++++++++
 tests/test_convert_model.py                   |  37 +++++-
 tests/test_meta_model_defusion.py             |   6 +-
 5 files changed, 152 insertions(+), 5 deletions(-)
 create mode 100644 defuser/modeling/unfused_moe/glm4_moe_lite.py

diff --git a/README.md b/README.md
index 0851a76..dd529b7 100644
--- a/README.md
+++ b/README.md
@@ -53,6 +53,7 @@ Defuser currently supports the following `transformers==5.3.0` `model_type` valu
 | Model type | Defused op performed |
 | --- | --- |
 | `glm4_moe` | Replaces `Glm4MoeMoE` with a defused per-expert linear MoE block. |
+| `glm4_moe_lite` | Replaces `Glm4MoeLiteMoE` with a defused per-expert linear MoE block. |
 | `glm4v` | Replaces the fused text MLP with split `gate_proj`, `up_proj`, and `down_proj` layers. Also splits fused checkpoint `mlp.gate_up_proj.weight` into `mlp.gate_proj.weight` + `mlp.up_proj.weight`. |
 | `mixtral` | Replaces `MixtralSparseMoeBlock` with `LinearMixtralSparseMoeBlock`. Also remaps legacy Mixtral checkpoint keys and splits fused expert `gate_up_proj` tensors into per-expert `gate_proj` and `up_proj`, plus per-expert `down_proj`. |
 | `qwen2_moe` | Replaces `Qwen2MoeSparseMoeBlock` with a defused per-expert linear MoE block. |
diff --git a/defuser/model_registry.py b/defuser/model_registry.py
index 0cfa5bb..697e7bf 100644
--- a/defuser/model_registry.py
+++ b/defuser/model_registry.py
@@ -135,6 +135,12 @@ class PATCH(str, Enum):
     },
     "glm4_moe_lite": {
         "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
+        PATCH.REPLACE_MODULE: [
+            (
+                "transformers.models.glm4_moe_lite.modeling_glm4_moe_lite.Glm4MoeLiteMoE",
+                "defuser.modeling.unfused_moe.glm4_moe_lite.LinearGlm4MoeLiteMoE",
+            )
+        ],
     },
     "glm4v": {
         "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
diff --git a/defuser/modeling/unfused_moe/glm4_moe_lite.py b/defuser/modeling/unfused_moe/glm4_moe_lite.py
new file mode 100644
index 0000000..91f287a
--- /dev/null
+++ b/defuser/modeling/unfused_moe/glm4_moe_lite.py
@@ -0,0 +1,107 @@
+# SPDX-FileCopyrightText: 2026 ModelCloud.ai
+# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+# Adapted from intel/auto-round
+# at https://github.com/intel/auto-round/blob/main/auto_round/modeling/unfused_moe/glm_moe_light.py
+
+import torch
+import torch.nn as nn
+
+
+class LinearGlm4MoeLiteMoE(nn.Module):
+    """
+    A mixed expert module containing shared experts.
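+
+    Mirrors the fused `Glm4MoeLiteMoE` computation while exposing each routed
+    expert as a standalone `Glm4MoeLiteMLP` built from per-expert Linear layers.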
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.num_experts = config.num_local_experts
+        from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import Glm4MoeLiteMLP, Glm4MoeLiteTopkRouter
+
+        self.experts = nn.ModuleList(
+            [Glm4MoeLiteMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
+        )
+
+        self.gate = Glm4MoeLiteTopkRouter(config)
+        self.shared_experts = Glm4MoeLiteMLP(
+            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
+        )
+        self.n_routed_experts = config.n_routed_experts
+        self.n_group = config.n_group
+        self.topk_group = config.topk_group
+        self.norm_topk_prob = config.norm_topk_prob
+        self.routed_scaling_factor = config.routed_scaling_factor
+        self.top_k = config.num_experts_per_tok
+
+    def experts_forward(
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
+    ) -> torch.Tensor:
+        """Dispatch tokens to their routed experts and accumulate the weighted expert outputs."""
+        final_hidden_states = torch.zeros_like(hidden_states)
+        with torch.no_grad():
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
+            expert_mask = expert_mask.permute(2, 1, 0)
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+        for expert_idx in expert_hit:
+            expert_idx = expert_idx[0]
+            if expert_idx == self.num_experts:
+                continue
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+            current_state = hidden_states[token_idx]
+            expert_layer = self.experts[expert_idx]
+            current_hidden_states = expert_layer(current_state)
+            # Fused reference path, kept for comparison with the defused per-expert call above:
+            # gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
+            # current_hidden_states = self.act_fn(gate) * up
+            # current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
+            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
+
+        return final_hidden_states
+
+    def route_tokens_to_experts(self, router_logits):
+        router_logits = router_logits.sigmoid()
+        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
+        group_scores = (
+            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
+        )
+        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
+        group_mask = torch.zeros_like(group_scores)
+        group_mask.scatter_(1, group_idx, 1)
+        score_mask = (
+            group_mask.unsqueeze(-1)
+            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .reshape(-1, self.n_routed_experts)
+        )
+        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
+        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
+        topk_weights = router_logits.gather(1, topk_indices)
+        if self.norm_topk_prob:
+            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+            topk_weights /= denominator
+        topk_weights = topk_weights * self.routed_scaling_factor
+        return topk_indices, topk_weights
+
+    def forward(self, hidden_states):
+        residuals = hidden_states
+        orig_shape = hidden_states.shape
+        router_logits = self.gate(hidden_states)
+        topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        hidden_states = self.experts_forward(hidden_states, topk_indices, topk_weights).view(*orig_shape)
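+        # Shared experts always run on the pre-routing activations and their
+        # output is added to the sparse routed-expert output.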
+        hidden_states = hidden_states + self.shared_experts(residuals)
+        return hidden_states
diff --git a/tests/test_convert_model.py b/tests/test_convert_model.py
index e7920db..f2daa39 100644
--- a/tests/test_convert_model.py
+++ b/tests/test_convert_model.py
@@ -10,6 +10,11 @@
 from torch import nn
 from transformers.core_model_loading import WeightConverter
 from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeConfig, Glm4MoeForCausalLM, Glm4MoeMoE
+from transformers.models.glm4_moe_lite.configuration_glm4_moe_lite import Glm4MoeLiteConfig
+from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
+    Glm4MoeLiteForCausalLM,
+    Glm4MoeLiteMoE,
+)
 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig
 from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
 from transformers.models.mixtral.configuration_mixtral import MixtralConfig
@@ -48,6 +53,7 @@
 from defuser.model_registry import MODEL_CONFIG, PATCH
 from defuser.modeling.replace_modules import ReplacementModuleBase, apply_replacements, materialize_model
 from defuser.modeling.unfused_moe.glm4_moe import LinearGlm4MoeMoE
+from defuser.modeling.unfused_moe.glm4_moe_lite import LinearGlm4MoeLiteMoE
 from defuser.modeling.unfused_moe.mixtral import LinearMixtralSparseMoeBlock
 from defuser.modeling.unfused_moe.qwen2_moe import LinearQwen2MoeSparseMoeBlock
 from defuser.modeling.unfused_moe.qwen3_moe import LinearQwen3MoeSparseMoeBlock
@@ -61,9 +67,9 @@
 
 
 
-def _tiny_moe_config(config_cls):
+def _tiny_moe_config(config_cls, num_hidden_layers: int = 1):
     return config_cls(
-        num_hidden_layers=1,
+        num_hidden_layers=num_hidden_layers,
         hidden_size=64,
         intermediate_size=128,
         moe_intermediate_size=32,
@@ -936,6 +942,33 @@ def test_defused_models_preserve_output_router_logits_capture():
     assert len(outputs.router_logits) == 1
     assert outputs.router_logits[0].shape == (3, model.config.num_experts)
 
+
+def test_glm4_moe_lite():
+    model_type = "glm4_moe_lite"
+    replace_fused_blocks(model_type)
+
+    model = Glm4MoeLiteForCausalLM(_tiny_moe_config(Glm4MoeLiteConfig, num_hidden_layers=2))
+    assert model.config.model_type == model_type
+
+    # In GLM4-MoE-Lite, the `mlp.experts` module is present only starting from the second layer.
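+    # This model type is defused via module replacement only, so convert_model
+    # should report no checkpoint-key conversions.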
+    converted = convert_model(model, max_layers=2)
+    assert not converted
+
+    _assert_unfused_expert_module(model.model.layers[1].mlp.experts)
+
+
+def test_glm4_moe_lite_defused_forward_matches_fused_math():
+    config = _tiny_moe_config(Glm4MoeLiteConfig)
+    hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)
+
+    _assert_sparse_moe_defused_matches_fused_math(
+        Glm4MoeLiteMoE(config),
+        LinearGlm4MoeLiteMoE(config),
+        hidden_states,
+    )
+
 
 def test_glm4v_checkpoint_mapping_splits_gate_up_proj():
     from defuser.defuser import get_checkpoint_conversion_mapping
diff --git a/tests/test_meta_model_defusion.py b/tests/test_meta_model_defusion.py
index ccd0a6a..2db5d2e 100644
--- a/tests/test_meta_model_defusion.py
+++ b/tests/test_meta_model_defusion.py
@@ -382,13 +382,13 @@ def _validate_defused_module(case: dict, module) -> None:
     },
     {
         "model_type": "glm4_moe_lite",
-        "mode": "convert",
+        "mode": "replace",
        "model_module": "transformers.models.glm4_moe_lite.modeling_glm4_moe_lite",
         "model_class": "Glm4MoeLiteForCausalLM",
         "config_module": "transformers.models.glm4_moe_lite.configuration_glm4_moe_lite",
         "config_class": "Glm4MoeLiteConfig",
-        "target_class_paths": ("transformers.models.glm4_moe_lite.modeling_glm4_moe_lite.Glm4MoeLiteNaiveMoe",),
-        "validator": "experts",
+        "target_class_paths": ("defuser.modeling.unfused_moe.glm4_moe_lite.LinearGlm4MoeLiteMoE",),
+        "validator": "sparse_block",
     },
     {
         "model_type": "glm4v",

From e4d05bffc4d6bed27d5769190ce73b8cc2f1f75e Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Wed, 1 Apr 2026 15:37:20 +0800
Subject: [PATCH 2/2] fix test_qwen3_omni

Signed-off-by: ZX-ModelCloud
---
 tests/test_convert_model.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_convert_model.py b/tests/test_convert_model.py
index f2daa39..475c779 100644
--- a/tests/test_convert_model.py
+++ b/tests/test_convert_model.py
@@ -84,6 +84,7 @@ def _tiny_moe_config(config_cls, num_hidden_layers: int = 1):
 
 def _tiny_qwen3_omni_config():
     return Qwen3OmniMoeConfig(
+        initializer_range=0.02,
         enable_audio_output=False,
         thinker_config={
             "text_config": {