133 changes: 107 additions & 26 deletions src/lmflow/pipeline/utils/lisa_trainer.py
@@ -1,9 +1,111 @@
import logging
from typing import Optional

import numpy as np
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.trainer_callback import TrainerCallback

logger = logging.getLogger(__name__)

# Mapping from model class name to the dot-separated path of transformer layers.
# Add new entries here as new model families are released.
CLASS_TO_LAYERS_MAP = {
    # LLaMA family (1, 2, 3, 3.1, 3.2, 3.3)
    "LlamaForCausalLM": "model.model.layers",
    # Qwen family
    "Qwen2ForCausalLM": "model.model.layers",
    "Qwen2MoeForCausalLM": "model.model.layers",
    # Mistral / Mixtral
    "MistralForCausalLM": "model.model.layers",
    "MixtralForCausalLM": "model.model.layers",
    # Gemma family
    "GemmaForCausalLM": "model.model.layers",
    "Gemma2ForCausalLM": "model.model.layers",
    "Gemma3ForCausalLM": "model.model.layers",
    # Phi family (Microsoft)
    "Phi3ForCausalLM": "model.model.layers",
    "PhiForCausalLM": "model.model.layers",
    # DeepSeek
    "DeepseekV2ForCausalLM": "model.model.layers",
    "DeepseekV3ForCausalLM": "model.model.layers",
    # Cohere (Command R)
    "CohereForCausalLM": "model.model.layers",
    # OLMo (Allen AI)
    "OlmoForCausalLM": "model.model.layers",
    "Olmo2ForCausalLM": "model.model.layers",
    # Falcon
    "FalconForCausalLM": "model.transformer.h",
    # GPT-2
    "GPT2LMHeadModel": "model.transformer.h",
    # GPT-NeoX / Pythia
    "GPTNeoXForCausalLM": "model.gpt_neox.layers",
    # Hymba
    "HymbaForCausalLM": "model.model.layers",
}

# Common layer paths tried in order during dynamic fallback.
_FALLBACK_LAYER_PATHS = [
"model.model.layers",
"model.transformer.h",
"model.gpt_neox.layers",
"model.layers",
]


def _resolve_layers(model: PreTrainedModel, layers_attribute: str):
"""Walk the dot-separated layers_attribute path on model and return the layer list."""
obj = model
for attr in layers_attribute.split(".")[1:]: # skip leading "model"
obj = getattr(obj, attr)
return obj


def _get_layers_attribute(model: PreTrainedModel, lisa_layers_attribute: Optional[str] = None) -> str:
"""Resolve the dot-separated path to the model's transformer layers.

Resolution order:
1. User-supplied lisa_layers_attribute override (highest priority).
2. CLASS_TO_LAYERS_MAP lookup by model class name.
3. Dynamic introspection across known common paths.

Raises ValueError if no path can be found.
"""
unwrapped = model.module if hasattr(model, "module") else model
model_class_name = type(unwrapped).__name__

# 1. User override takes highest priority
if lisa_layers_attribute is not None:
return lisa_layers_attribute

# 2. Known architecture map
if model_class_name in CLASS_TO_LAYERS_MAP:
return CLASS_TO_LAYERS_MAP[model_class_name]

    # 3. Dynamic fallback — inspect the actual model object
    for path in _FALLBACK_LAYER_PATHS:
        try:
            obj = unwrapped
            # Skip the leading "model" segment: it denotes the model object itself,
            # matching how _resolve_layers walks the same path at training time.
            for attr in path.split(".")[1:]:
                obj = getattr(obj, attr)
            if isinstance(obj, (list, nn.ModuleList)):
                logger.warning(
                    "Model class '%s' not in CLASS_TO_LAYERS_MAP. "
                    "Dynamically detected layers at '%s'. "
                    "Consider adding '%s' to CLASS_TO_LAYERS_MAP in lisa_trainer.py.",
                    model_class_name, path, model_class_name,
                )
                return path
        except AttributeError:
            continue

    raise ValueError(
        f"Cannot locate transformer layers for model class '{model_class_name}'. "
        f"Set lisa_layers_attribute in FinetunerArguments to the dot-separated "
        f"path (e.g. 'model.model.layers'), or add '{model_class_name}' to "
        f"CLASS_TO_LAYERS_MAP in src/lmflow/pipeline/utils/lisa_trainer.py."
    )


class DynamicLayerActivationCallback(TrainerCallback):
    def __init__(
@@ -18,49 +120,28 @@ def __init__(
        self.interval_steps = interval_steps
        self.model = model

        # Determine the way to access layers based on the model type
        class_to_layers_map = {
            "LlamaForCausalLM": "model.model.layers",
            "Qwen2ForCausalLM": "model.model.layers",
            "MistralForCausalLM": "model.model.layers",
            "MixtralForCausalLM": "model.model.layers",
            "GemmaForCausalLM": "model.model.layers",
            "GPT2LMHeadModel": "model.transformer.h",
            "HymbaForCausalLM": "model.model.layers",
        }
        model_class_name = self.model.__class__.__name__
        if model_class_name in class_to_layers_map:
            self.layers_attribute = class_to_layers_map[model_class_name]
        else:
            assert lisa_layers_attribute is not None, "Please provide the attribute to access the layers of the model."
            self.layers_attribute = lisa_layers_attribute
        self.total_layers = len(
            eval("self." + self.layers_attribute)
        )  # Dynamically execute to get the number of layers
        self.layers_attribute = _get_layers_attribute(model, lisa_layers_attribute)
        self.total_layers = len(_resolve_layers(self.model, self.layers_attribute))

        self.active_layers_indices = []

    def freeze_all_layers(self):
        layers = eval("self." + self.layers_attribute)  # Dynamically execute to get layers
        layers = _resolve_layers(self.model, self.layers_attribute)
        for layer in layers:
            for param in layer.parameters():
                param.requires_grad = False

    def on_step_begin(self, args, state, control, **kwargs):
        # Check if it's time to switch active layers, including at step 0
        if state.global_step % self.interval_steps == 0:
            self.switch_active_layers()

    def switch_active_layers(self):
        # First, disable gradients for all layers
        self.freeze_all_layers()

        # Randomly select n_layers to activate
        layers = eval("self." + self.layers_attribute)  # Re-fetch layer references
        layers = _resolve_layers(self.model, self.layers_attribute)
        self.active_layers_indices = np.random.choice(range(self.total_layers), self.n_layers, replace=False)
        print(f"Activating layers at indices: {self.active_layers_indices} for the next steps.", flush=True)
        logger.info("Activating layers at indices: %s for the next steps.", self.active_layers_indices)

        # Enable gradients only for the selected layers
        for idx in self.active_layers_indices:
            for param in layers[idx].parameters():
                param.requires_grad = True
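
For reviewers who want to try the change locally, here is a minimal usage sketch (not part of this PR). The keyword names n_layers, interval_steps, model and lisa_layers_attribute are inferred from the assignments visible in the hunk above; the full __init__ signature is collapsed in this view, so treat them as assumptions.

# Illustrative sketch only; not part of this diff.
from transformers import AutoModelForCausalLM

from lmflow.pipeline.utils.lisa_trainer import DynamicLayerActivationCallback

model = AutoModelForCausalLM.from_pretrained("gpt2")  # GPT2LMHeadModel maps to "model.transformer.h"

callback = DynamicLayerActivationCallback(
    n_layers=2,          # transformer blocks left trainable at any one time
    interval_steps=20,   # re-sample the active blocks every 20 training steps
    model=model,
    # lisa_layers_attribute="model.custom.blocks" is only needed for an architecture
    # that is neither in CLASS_TO_LAYERS_MAP nor detectable by the dynamic fallback.
)

# switch_active_layers() is what on_step_begin() triggers every interval_steps:
# it freezes every transformer block, then re-enables gradients for a random subset.
callback.switch_active_layers()
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters after the layer switch: {trainable:,}")

# During training the callback is simply passed to the HF Trainer, e.g.
# Trainer(model=model, args=..., train_dataset=..., callbacks=[callback]).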
89 changes: 89 additions & 0 deletions tests/pipeline/test_lisa_trainer.py
@@ -0,0 +1,89 @@
import types
import pytest
import torch.nn as nn

from lmflow.pipeline.utils.lisa_trainer import (
    CLASS_TO_LAYERS_MAP,
    _get_layers_attribute,
)


def make_mock_model(class_name: str, layers_path: str = "model.model.layers", num_layers: int = 4):
"""Build a minimal mock model with layers at the given dot-separated path.

Uses SimpleNamespace for nested attributes so we avoid nn.Module.__init__
complexity. The top-level object is given the requested class name so that
type(model).__name__ returns it correctly.

For example, layers_path="model.model.layers" creates:
mock.model.model.layers = ModuleList([...])
"""
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(num_layers)])

current = layers
for part in reversed(layers_path.split(".")):
parent = types.SimpleNamespace()
setattr(parent, part, current)
current = parent

MockClass = type(class_name, (object,), {})
instance = object.__new__(MockClass)
instance.__dict__.update(vars(current))
return instance


class TestGetLayersAttribute:

    def test_known_architecture_uses_map(self):
        """LLaMA is in CLASS_TO_LAYERS_MAP — should return the mapped path directly."""
        model = make_mock_model("LlamaForCausalLM", "model.model.layers")
        result = _get_layers_attribute(model)
        assert result == "model.model.layers"

    def test_newly_added_architecture_gemma2(self):
        """Gemma2 was added to the expanded map — should resolve without fallback."""
        model = make_mock_model("Gemma2ForCausalLM", "model.model.layers")
        result = _get_layers_attribute(model)
        assert result == "model.model.layers"
        assert "Gemma2ForCausalLM" in CLASS_TO_LAYERS_MAP

    def test_falcon_maps_to_transformer_h(self):
        """FalconForCausalLM maps to model.transformer.h — verifies non-default path entries."""
        model = make_mock_model("FalconForCausalLM", "model.transformer.h")
        result = _get_layers_attribute(model)
        assert result == "model.transformer.h"

    def test_user_override_takes_precedence_over_map(self):
        """User-supplied lisa_layers_attribute must win even for known architectures.

        Uses a custom path that differs from both the map entry and all fallback
        paths, so the only way the test passes is if the override is truly used.
        """
        model = make_mock_model("LlamaForCausalLM", "model.model.layers")
        result = _get_layers_attribute(model, lisa_layers_attribute="model.custom.blocks")
        assert result == "model.custom.blocks"

    def test_dynamic_fallback_finds_transformer_h(self):
        """Unknown model with layers at model.transformer.h — fallback iterates past first entry."""
        model = make_mock_model("BrandNewGPTModel", "model.transformer.h")
        result = _get_layers_attribute(model)
        assert result == "model.transformer.h"

    def test_completely_unknown_model_raises_valueerror(self):
        """Unknown model with no recognizable layer path should raise a clear ValueError."""
        model = make_mock_model("WeirdModelWithNoLayers", "model.model.layers")
        model.__dict__.clear()
        with pytest.raises(ValueError, match="Cannot locate transformer layers"):
            _get_layers_attribute(model)

    def test_dataparallel_wrapped_model_unwrapped(self):
        """Model wrapped in DataParallel (.module) should be unwrapped before class lookup."""
        inner = make_mock_model("LlamaForCausalLM", "model.model.layers")

        # Simulate DataParallel wrapping: outer object has a .module attribute
        WrapperClass = type("DataParallel", (object,), {})
        wrapper = object.__new__(WrapperClass)
        wrapper.module = inner

        result = _get_layers_attribute(wrapper)
        assert result == "model.model.layers"
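
A natural follow-up (sketch only, not part of this PR) is to assert that the dynamic fallback actually emits the CLASS_TO_LAYERS_MAP warning added in lisa_trainer.py; pytest's built-in caplog fixture covers this if the test below is appended to this module.

import logging  # only needed for the caplog level constant; other names are already imported above

def test_dynamic_fallback_logs_warning(caplog):
    """Sketch: the fallback path should warn that the class is missing from the map."""
    model = make_mock_model("BrandNewGPTModel", "model.transformer.h")
    with caplog.at_level(logging.WARNING, logger="lmflow.pipeline.utils.lisa_trainer"):
        path = _get_layers_attribute(model)
    assert path == "model.transformer.h"
    assert any("CLASS_TO_LAYERS_MAP" in record.getMessage() for record in caplog.records)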