Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Defuser currently supports the following `transformers==5.3.0` `model_type` valu
| Model type | Defused op performed |
| --- | --- |
| `glm4_moe` | Replaces `Glm4MoeMoE` with a defused per-expert linear MoE block. |
| `glm4_moe_lite` | Replaces `Glm4MoeLiteMoE` with a defused per-expert linear MoE block. |
| `glm4v` | Replaces the fused text MLP with split `gate_proj`, `up_proj`, and `down_proj` layers. Also splits fused checkpoint `mlp.gate_up_proj.weight` into `mlp.gate_proj.weight` + `mlp.up_proj.weight`. |
| `mixtral` | Replaces `MixtralSparseMoeBlock` with `LinearMixtralSparseMoeBlock`. Also remaps legacy Mixtral checkpoint keys and splits fused expert `gate_up_proj` tensors into per-expert `gate_proj` and `up_proj`, plus per-expert `down_proj`. |
| `qwen2_moe` | Replaces `Qwen2MoeSparseMoeBlock` with a defused per-expert linear MoE block. |
Expand Down
6 changes: 6 additions & 0 deletions defuser/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ class PATCH(str, Enum):
},
"glm4_moe_lite": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
PATCH.REPLACE_MODULE: [
(
"transformers.models.glm4_moe_lite.modeling_glm4_moe_lite.Glm4MoeLiteMoE",
"defuser.modeling.unfused_moe.glm4_moe_lite.LinearGlm4MoeLiteMoE",
)
],
},
"glm4v": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
Expand Down
100 changes: 100 additions & 0 deletions defuser/modeling/unfused_moe/glm4_moe_lite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

# Adapted from intel/auto-round
# at https://github.com/intel/auto-round/blob/main/auto_round/modeling/unfused_moe/glm_moe_light.py

import torch
import torch.nn as nn

class LinearGlm4MoeLiteMoE(nn.Module):
    """Defused GLM4-MoE-Lite mixture-of-experts block.

    Drop-in replacement for the fused ``Glm4MoeLiteMoE``: routed experts are
    held as individual ``Glm4MoeLiteMLP`` modules in an ``nn.ModuleList``
    (instead of fused stacked weight tensors), while the sigmoid + grouped
    top-k routing math and the shared-expert branch are kept identical.
    """

    def __init__(self, config):
        """Build the per-expert MLPs, the top-k router, and the shared experts.

        Args:
            config: ``Glm4MoeLiteConfig``-like object supplying the MoE
                hyper-parameters read below (``num_local_experts``,
                ``moe_intermediate_size``, ``n_shared_experts``,
                ``n_routed_experts``, ``n_group``, ``topk_group``,
                ``norm_topk_prob``, ``routed_scaling_factor``,
                ``num_experts_per_tok``).
        """
        super().__init__()
        self.config = config
        # NOTE(review): the expert list is sized by `num_local_experts` while
        # the routing math below uses `n_routed_experts`; these must agree for
        # `one_hot(..., num_classes=self.num_experts)` to be valid — confirm
        # against Glm4MoeLiteConfig.
        self.num_experts = config.num_local_experts
        # Imported lazily so this module can be imported even when the
        # installed transformers build does not ship glm4_moe_lite.
        from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
            Glm4MoeLiteMLP,
            Glm4MoeLiteTopkRouter,
        )

        self.experts = nn.ModuleList(
            [Glm4MoeLiteMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
        )

        self.gate = Glm4MoeLiteTopkRouter(config)
        self.shared_experts = Glm4MoeLiteMLP(
            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
        )
        self.n_routed_experts = config.n_routed_experts
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling_factor = config.routed_scaling_factor
        self.top_k = config.num_experts_per_tok

    def experts_forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Dispatch tokens to their selected experts and sum the weighted outputs.

        Args:
            hidden_states: ``(num_tokens, hidden_size)`` flattened token states.
            top_k_index: ``(num_tokens, top_k)`` selected expert indices per token.
            top_k_weights: ``(num_tokens, top_k)`` routing weight per selection.

        Returns:
            ``(num_tokens, hidden_size)`` tensor of combined expert outputs.
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # expert_mask[e, k, t] == 1 iff token t picked expert e at rank k.
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Only visit experts that received at least one token.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            # one_hot above guarantees expert_idx < self.num_experts, so no
            # padding-index guard is needed here.
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            expert_layer = self.experts[expert_idx]
            current_hidden_states = expert_layer(current_state)
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states

    def route_tokens_to_experts(self, router_logits):
        """Grouped top-k routing over sigmoid scores (DeepSeek-V3 style).

        Expert groups are ranked by the sum of each group's top-2 corrected
        scores; only the best ``topk_group`` groups stay eligible, then the
        global ``top_k`` experts are chosen among them. Selection uses the
        bias-corrected scores, but the returned weights come from the raw
        sigmoid scores.

        Args:
            router_logits: ``(num_tokens, n_routed_experts)`` raw router logits.

        Returns:
            Tuple of ``(topk_indices, topk_weights)``, each
            ``(num_tokens, top_k)``.
        """
        router_logits = router_logits.sigmoid()
        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
        group_scores = (
            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(-1, self.n_routed_experts)
        )
        # Zeroing masked experts is safe: sigmoid scores are strictly positive,
        # so masked entries can never win the top-k below.
        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
        topk_weights = router_logits.gather(1, topk_indices)
        if self.norm_topk_prob:
            # Epsilon guards against an all-zero weight row.
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights

    def forward(self, hidden_states):
        """Route tokens through the sparse experts and add the shared-expert path.

        Args:
            hidden_states: ``(..., hidden_size)`` input activations.

        Returns:
            Tensor of the same shape as ``hidden_states``.
        """
        residuals = hidden_states
        orig_shape = hidden_states.shape
        router_logits = self.gate(hidden_states)
        topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        hidden_states = self.experts_forward(hidden_states, topk_indices, topk_weights).view(*orig_shape)
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states
34 changes: 32 additions & 2 deletions tests/test_convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
from torch import nn
from transformers.core_model_loading import WeightConverter
from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeConfig, Glm4MoeForCausalLM, Glm4MoeMoE
from transformers.models.glm4_moe_lite.configuration_glm4_moe_lite import Glm4MoeLiteConfig
from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
Glm4MoeLiteForCausalLM,
Glm4MoeLiteMoE,
)
from transformers.models.glm4v.configuration_glm4v import Glm4vConfig
from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
Expand Down Expand Up @@ -48,6 +53,7 @@
from defuser.model_registry import MODEL_CONFIG, PATCH
from defuser.modeling.replace_modules import ReplacementModuleBase, apply_replacements, materialize_model
from defuser.modeling.unfused_moe.glm4_moe import LinearGlm4MoeMoE
from defuser.modeling.unfused_moe.glm4_moe_lite import LinearGlm4MoeLiteMoE
from defuser.modeling.unfused_moe.mixtral import LinearMixtralSparseMoeBlock
from defuser.modeling.unfused_moe.qwen2_moe import LinearQwen2MoeSparseMoeBlock
from defuser.modeling.unfused_moe.qwen3_moe import LinearQwen3MoeSparseMoeBlock
Expand All @@ -61,9 +67,9 @@



def _tiny_moe_config(config_cls):
def _tiny_moe_config(config_cls, num_hidden_layers: int=1):
return config_cls(
num_hidden_layers=1,
num_hidden_layers=num_hidden_layers,
hidden_size=64,
intermediate_size=128,
moe_intermediate_size=32,
Expand All @@ -78,6 +84,7 @@ def _tiny_moe_config(config_cls):

def _tiny_qwen3_omni_config():
return Qwen3OmniMoeConfig(
initializer_range=0.02,
enable_audio_output=False,
thinker_config={
"text_config": {
Expand Down Expand Up @@ -936,6 +943,29 @@ def test_defused_models_preserve_output_router_logits_capture():
assert len(outputs.router_logits) == 1
assert outputs.router_logits[0].shape == (3, model.config.num_experts)

def test_glm4_moe_lite():
    """Replace-mode patching swaps in the defused MoE block for glm4_moe_lite."""
    replace_fused_blocks("glm4_moe_lite")

    cfg = _tiny_moe_config(Glm4MoeLiteConfig, num_hidden_layers=2)
    lite_model = Glm4MoeLiteForCausalLM(cfg)
    assert lite_model.config.model_type == "glm4_moe_lite"

    # In GLM4-MoE-Lite, the `mlp.experts` module is present only starting from the second layer.
    assert not convert_model(lite_model, max_layers=2)

    _assert_unfused_expert_module(lite_model.model.layers[1].mlp.experts)


def test_glm4_moe_lite_defused_forward_matches_fused_math():
    """Defused block reproduces the fused Glm4MoeLiteMoE forward numerics."""
    cfg = _tiny_moe_config(Glm4MoeLiteConfig)
    # Build the sample input before the modules so RNG consumption matches
    # the original ordering exactly.
    sample = torch.randn(2, 3, cfg.hidden_size, dtype=torch.float32)

    fused_block = Glm4MoeLiteMoE(cfg)
    defused_block = LinearGlm4MoeLiteMoE(cfg)
    _assert_sparse_moe_defused_matches_fused_math(fused_block, defused_block, sample)

def test_glm4v_checkpoint_mapping_splits_gate_up_proj():
from defuser.defuser import get_checkpoint_conversion_mapping
Expand Down
6 changes: 3 additions & 3 deletions tests/test_meta_model_defusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,13 @@ def _validate_defused_module(case: dict, module) -> None:
},
{
"model_type": "glm4_moe_lite",
"mode": "convert",
"mode": "replace",
"model_module": "transformers.models.glm4_moe_lite.modeling_glm4_moe_lite",
"model_class": "Glm4MoeLiteForCausalLM",
"config_module": "transformers.models.glm4_moe_lite.configuration_glm4_moe_lite",
"config_class": "Glm4MoeLiteConfig",
"target_class_paths": ("transformers.models.glm4_moe_lite.modeling_glm4_moe_lite.Glm4MoeLiteNaiveMoe",),
"validator": "experts",
"target_class_paths": ("defuser.modeling.unfused_moe.glm4_moe_lite.LinearGlm4MoeLiteMoE",),
"validator": "sparse_block",
},
{
"model_type": "glm4v",
Expand Down