diff --git a/contrib/models/gemma-3-1b-it/README.md b/contrib/models/gemma-3-1b-it/README.md index 9e992144..08ce07f1 100644 --- a/contrib/models/gemma-3-1b-it/README.md +++ b/contrib/models/gemma-3-1b-it/README.md @@ -1,115 +1,177 @@ -# Contrib Model: gemma 3 1b it +# Contrib Model: Gemma 3 1B IT -NeuronX Distributed Inference implementation of gemma 3 1b it. +NeuronX Distributed Inference support for **google/gemma-3-1b-it** (1B parameter variant). -## Model Information - -- **HuggingFace ID:** `gemma-3-1b-it` -- **Model Type:** Decoder-only transformer -- **License:** Check HuggingFace model card - -## Architecture Details - - -## Validation Results - -**Validated:** 2026-02-06 -**Configuration:** TP=1, batch_size=1, seq_len=128, bfloat16 - -### Test Results +This contrib subclasses the official `models/gemma3/` implementation and adds +the minimal overrides needed for the 1B variant's unusual architecture. -| Test | Status | Result | -|------|--------|--------| -| Smoke Test | ✅ PASS | Model loads successfully | -| Token Matching | ✅ PASS | **100% match** (best of multiple prompts) | - -**Test Prompt:** `"def fibonacci(n):"` - -**Status:** ✅ VALIDATED - -### Device Profiling Metrics - -**Configuration:** TP=1, batch_size=1, seq_len=128, bfloat16 -**Instance:** trn1.32xlarge | **Profiled:** 2026-03-18 - -| Metric | Context Encoding | Token Generation | -|--------|-----------------|------------------| -| MFU (%) | 0.20 | 0.00 | -| MBU (%) | 0.50 | 0.59 | -| HFU (%) | 0.21 | 0.00 | -| Execution Time (us) | 0.01 | 0.01 | -| HBM Read | 2.00 GB | 2.00 GB | -| HBM Write | 4.82 MB | 1.10 MB | - -**Throughput:** 87.64 tok/s | **Compile Time:** 253.40s +## Model Information -> Metrics from `neuron-profile capture` on compiled NEFFs. MFU = Model FLOPs Utilization, -> MBU = Memory Bandwidth Utilization, HFU = Hardware FLOPs Utilization. 
+- **HuggingFace ID:** `google/gemma-3-1b-it` +- **Model Type:** Decoder-only transformer (causal LM) +- **Parameters:** 1B +- **License:** Gemma license (see HuggingFace model card) + +## Why a Separate Contrib? + +The official `models/gemma3/` targets the 4B/12B/27B variants (head_dim=128). +The 1B variant has several unusual architecture parameters that require +additional handling: + +| Parameter | 1B | 4B/12B/27B | +|-----------|-----|-----------| +| head_dim | **256** | 128 | +| vocab_size | **262144** | 262208 | +| num_kv_heads | **1** | 4-16 | +| num_attention_heads | **4** | 8-32 | + +### Issues Addressed + +1. **Chunked attention for head_dim=256** -- The Neuron compiler generates DGE + scatter/gather instructions that produce out-of-bounds memory accesses when + head_dim exceeds 128. All Q@K^T and scores@V matmuls are split into + 128-wide chunks along head_dim. Mathematically identical, avoids hardware + addressing limits. + +2. **vocab_size from HF config** -- The upstream `Gemma3InferenceConfig` + hardcodes `vocab_size=262208`. This contrib reads the actual value from + the HuggingFace config (262144 for 1B). + +3. **Auto-disable NKI attention kernel** -- The NKI flash attention kernel + asserts `head_dim <= 128`. This contrib auto-disables it when head_dim + exceeds that limit. + +4. **k_cache_transposed + SWA + GQA fix** -- The base class forces + `k_cache_transposed=False` for sliding window layers, but the KV cache + manager stores K in BHDS layout for ALL layers when `k_cache_transposed=True` + in the config. This creates a layout mismatch: `repeat_kv` assumes BHSD but + receives BHDS, producing incorrect GQA expansion. The fix restores the + config value and transposes K around `repeat_kv`. + +5. **query_pre_attn_scalar weight fusion** -- NxDI uses `QK^T / sqrt(head_dim)` + for attention scaling, but Gemma 3 specifies `QK^T / sqrt(query_pre_attn_scalar)`. 
+ Rather than modifying the attention kernel (which risks breaking optimizations), + we fuse the correction factor into Q/K weight matrices at load time. Zero + runtime overhead. Pattern from Pierre Lienhart's gemma3-vision contrib. + +### Known Compiler Issue + +**CTE buckets < 512 crash at runtime** with head_dim=256 + `input_output_aliases`. +This is a Neuron compiler issue (DGE OOB), not a code issue. Workaround: +always use `context_encoding_buckets: [512]` or larger. + +| CTE Bucket | Result | +|-----------|--------| +| 128 | OOB crash | +| 256 | OOB crash | +| 384 | OOB crash | +| 512 | **PASS** | ## Usage +### Standalone (NxDI API) + ```python -from transformers import AutoTokenizer, GenerationConfig +import torch from neuronx_distributed_inference.models.config import NeuronConfig from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config -# Import model classes from src -from src.modeling_gemma_3_1b_it import Neurongemma31bitForCausalLM, gemma31bitInferenceConfig +import sys +sys.path.insert(0, "contrib/models/gemma-3-1b-it/src") +from modeling_gemma3 import NeuronGemma3_1B_ForCausalLM, Gemma3_1B_InferenceConfig -model_path = "/path/to/gemma-3-1b-it/" -compiled_model_path = "/path/to/compiled/" +model_path = "google/gemma-3-1b-it" -# Configure neuron_config = NeuronConfig( tp_degree=1, - batch_size=None, + batch_size=4, seq_len=512, torch_dtype=torch.bfloat16, + attn_kernel_enabled=False, + k_cache_transposed=True, ) -config = gemma31bitInferenceConfig( +config = Gemma3_1B_InferenceConfig( neuron_config, load_config=load_pretrained_config(model_path), ) -# Compile and load -model = Neurongemma31bitForCausalLM(model_path, config) -model.compile(compiled_model_path) -model.load(compiled_model_path) +model = NeuronGemma3_1B_ForCausalLM(model_path, config) +model.compile("/tmp/gemma3-1b-compiled") +model.load("/tmp/gemma3-1b-compiled") +``` + +### vLLM Serving + +Requires installing the NxDI fork with the `gemma3` model type registered 
in +`constants.py` (or using the fork's `fix/gemma3-1b-oob` branch). -# Generate -tokenizer = AutoTokenizer.from_pretrained(model_path) -# ... (see integration test for full example) +```bash +python -m vllm.entrypoints.openai.api_server \ + --model google/gemma-3-1b-it \ + --tensor-parallel-size 1 \ + --max-model-len 512 \ + --max-num-seqs 4 \ + --dtype bfloat16 \ + --no-enable-prefix-caching \ + --block-size 128 \ + --additional-config '{"override_neuron_config": { + "tp_degree": 1, + "batch_size": 4, + "seq_len": 512, + "n_active_tokens": 4, + "context_encoding_buckets": [512], + "token_generation_buckets": [512], + "on_device_sampling_config": null, + "attn_kernel_enabled": false, + "k_cache_transposed": true + }}' ``` -## Compatibility Matrix +## Required Configuration -| Instance/Version | 2.20+ | 2.19 and earlier | -|------------------|-------|------------------| -| Trn1 | ✅ Working | Not tested | -| Inf2 | Not tested | Not tested | +| Parameter | Value | Why | +|-----------|-------|-----| +| `attn_kernel_enabled` | `false` | NKI kernel asserts head_dim <= 128 | +| `k_cache_transposed` | `true` | Required for the SWA+GQA fix | +| `context_encoding_buckets` | `[512]` or larger | Compiler OOB for buckets < 512 | +| `on_device_sampling_config` | `null` | Required (not `false`) | -## Testing +## Compatibility -Run integration tests: +| Instance | Status | Notes | +|----------|--------|-------| +| trn2.3xlarge | Tested | TP=1, batch_size=4/16, CTE bucket 512 | +| inf2.8xlarge | Not tested with this contrib | OOB confirmed on raw official code | +| trn1.* | Not tested | Should work with same config | -```bash -pytest nxdi_contrib_models/models/gemma-3-1b-it/test/integration/test_model.py --capture=tee-sys -``` +## Architecture -Or run manually: +This contrib is structured as thin subclasses of the official implementation: -```bash -cd nxdi_contrib_models/models/gemma-3-1b-it -python3 test/integration/test_model.py +``` +models/gemma3/modeling_gemma3.py (upstream, 
unchanged)
+  |
+  +-- contrib/gemma-3-1b-it/src/modeling_gemma3.py (this file)
+      |-- Gemma3_1B_InferenceConfig <-- fixes vocab_size, auto-disables NKI
+      |-- NeuronGemma3_1B_Attention <-- chunked attn + k_cache_transposed fix
+      +-- NeuronGemma3_1B_ForCausalLM <-- query_pre_attn_scalar weight fusion
 ```

-## Example Checkpoints
+No upstream files are modified. The contrib imports from the official
+`models/gemma3/` package and overrides only what is necessary.

-* gemma-3-1b-it
+## Testing
+
+```bash
+# On a Neuron instance (trn2 or inf2):
+cd neuronx-distributed-inference
+PYTHONPATH="contrib/models/gemma-3-1b-it/src:src:$PYTHONPATH" \
+  pytest contrib/models/gemma-3-1b-it/test/integration/test_model.py -v --capture=tee-sys
+```

 ## Maintainer

-Annapurna Labs
+Jim Burtoft (jimburtoft)

-**Last Updated:** 2026-02-06
+**Last Updated:** 2026-03-27
diff --git a/contrib/models/gemma-3-1b-it/src/__init__.py b/contrib/models/gemma-3-1b-it/src/__init__.py
index 902148b0..86719897 100644
--- a/contrib/models/gemma-3-1b-it/src/__init__.py
+++ b/contrib/models/gemma-3-1b-it/src/__init__.py
@@ -1 +1,4 @@
-from .modeling_gemma3 import NeuronGemma3ForCausalLM, Gemma3InferenceConfig
+from .modeling_gemma3 import (
+    Gemma3_1B_InferenceConfig,
+    NeuronGemma3_1B_ForCausalLM,
+)
diff --git a/contrib/models/gemma-3-1b-it/src/modeling_gemma3.py b/contrib/models/gemma-3-1b-it/src/modeling_gemma3.py
index 8a89b332..4c3a8141 100644
--- a/contrib/models/gemma-3-1b-it/src/modeling_gemma3.py
+++ b/contrib/models/gemma-3-1b-it/src/modeling_gemma3.py
@@ -14,640 +14,549 @@
 # limitations under the License.
""" -PyTorch Gemma3 model for NeuronX Distributed Inference - -This implementation ports Google's Gemma3 model to NeuronX hardware. -Key architectural features: -- Q-K normalization (similar to Qwen3) -- Scaled embeddings (embed * sqrt(hidden_size)) -- Dual RoPE implementations (global and local for sliding window) -- Four normalization layers per block -- Alternating sliding window attention pattern -- MQA (num_kv_heads=1) +Gemma 3 1B IT support for NeuronX Distributed Inference. + +The official ``models/gemma3/`` implementation targets the 4B/12B/27B variants +which all have head_dim=128. The 1B variant has several unusual architecture +parameters that require additional handling: + + * head_dim=256 – exceeds the NKI kernel limit of 128, and triggers a Neuron + compiler DGE out-of-bounds issue for CTE buckets < 512. + * vocab_size=262144 – the 4B+ variants use 262208; the upstream config class + hardcodes the larger value. + * GQA with num_kv_heads=1 – interacts with k_cache_transposed + SWA to + produce a layout mismatch in repeat_kv. + +This module subclasses the official Gemma 3 NxDI classes and adds only the +minimal overrides required for the 1B variant: + + 1. **Chunked attention** – Q@K^T and scores@V matmuls are split into + 128-wide chunks along head_dim to stay within hardware addressing limits. + 2. **vocab_size from HF config** – reads the actual value instead of + hardcoding 262208. + 3. **Auto-disable NKI kernel** – when head_dim > 128. + 4. **k_cache_transposed fix** – restores the config value for SWA layers + and transposes K around repeat_kv so GQA expansion works correctly. + 5. **query_pre_attn_scalar weight fusion** – fuses the Gemma 3 attention + scaling correction into Q/K weight matrices at load time (following the + pattern from Pierre Lienhart's gemma3-vision contrib) so NxDI's default + sqrt(head_dim) scaling produces the correct result with zero runtime cost. 
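The weight-fusion identity in (5) can be checked numerically. This is an
illustrative numpy sketch, not the actual implementation: the values are
chosen so the two scalars differ, and plain arrays stand in for the torch
Q/K projection weights that the real code rescales at load time.

```python
import numpy as np

# Sketch: scaling Q by sqrt(head_dim / query_pre_attn_scalar) makes the
# engine's default QK^T / sqrt(head_dim) equal to Gemma 3's specified
# QK^T / sqrt(query_pre_attn_scalar). Shapes and scalars are illustrative.
rng = np.random.default_rng(0)
head_dim, qpas, seq = 256, 128, 4
Q = rng.standard_normal((seq, head_dim))
K = rng.standard_normal((seq, head_dim))

reference = (Q @ K.T) / np.sqrt(qpas)         # what Gemma 3 specifies
Q_fused = Q * np.sqrt(head_dim / qpas)        # correction fused into Q at load
engine = (Q_fused @ K.T) / np.sqrt(head_dim)  # what NxDI computes by default

assert np.allclose(reference, engine)
```

Because the factor is folded into the weights once, the attention kernel
itself never sees it, which is why the correction has zero runtime cost.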
+ +Required configuration knobs (via vLLM --additional-config): + + context_encoding_buckets: [512] # MUST be >= 512 (compiler issue) + attn_kernel_enabled: false # NKI kernel asserts head_dim <= 128 + k_cache_transposed: true # required for the repeat_kv fix + +See the README for full usage instructions. """ -import json -import os -from typing import List, Optional, Tuple, Type +import logging +import math +from typing import Optional, Tuple import torch import torch.nn.functional as F -from torch import nn +from torch import Tensor, nn + +import copy from neuronx_distributed.parallel_layers.layers import ( ColumnParallelLinear, ParallelEmbedding, - RowParallelLinear, ) -from neuronx_distributed.utils import cpu_mode from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig +from neuronx_distributed_inference.models.gemma3.modeling_gemma3 import ( + Gemma3InferenceConfig as _UpstreamGemma3InferenceConfig, + Gemma3NeuronConfig as _UpstreamGemma3NeuronConfig, + NeuronGemma3Attention as _UpstreamNeuronGemma3Attention, + NeuronGemma3DecoderLayer as _UpstreamNeuronGemma3DecoderLayer, + NeuronGemma3ForCausalLM as _UpstreamNeuronGemma3ForCausalLM, + NeuronGemma3TextModel as _UpstreamNeuronGemma3TextModel, + NeuronGemma3RMSNorm, + get_rmsnorm_cls, + get_updated_configs, +) +from neuronx_distributed_inference.models.llama.modeling_llama import NeuronLlamaMLP from neuronx_distributed_inference.models.model_base import ( NeuronBaseForCausalLM, NeuronBaseModel, ) -from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase -from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding -from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.modules.attention.utils import repeat_kv +logger = logging.getLogger(__name__) -# ==================================================================================== -# Configuration Classes -# 
==================================================================================== +# Maximum head dimension that the Neuron compiler can handle without DGE +# out-of-bounds errors in the standard matmul paths. +_MAX_UNCHUNKED_HEAD_DIM = 128 -class Gemma3NeuronConfig(NeuronConfig): - """ - NeuronConfig for Gemma3 model - Specifies the attention class to use for Gemma3 +# --------------------------------------------------------------------------- +# Chunked attention helpers +# --------------------------------------------------------------------------- + + +def _chunked_qk( + Q: Tensor, + K: Tensor, + scale: float, + chunk_size: int = _MAX_UNCHUNKED_HEAD_DIM, +) -> Tensor: + """Q @ K^T / scale, chunked along head_dim to avoid DGE OOB. + + Args: + Q: (B, H, S_q, D) + K: (B, H, S_k, D) – NOT transposed + scale: divisor (typically sqrt(head_dim)) + chunk_size: max inner-dim width per matmul """ + head_dim = Q.shape[-1] + if head_dim <= chunk_size: + return torch.matmul(Q, K.transpose(2, 3)) / scale + + QK = torch.matmul(Q[..., :chunk_size], K[..., :chunk_size].transpose(2, 3)) + for start in range(chunk_size, head_dim, chunk_size): + end = min(start + chunk_size, head_dim) + QK = QK + torch.matmul(Q[..., start:end], K[..., start:end].transpose(2, 3)) + return QK / scale + + +def _chunked_qk_transposed( + Q: Tensor, + K_t: Tensor, + scale: float, + chunk_size: int = _MAX_UNCHUNKED_HEAD_DIM, +) -> Tensor: + """Q @ K_t / scale where K_t is already (B, H, D, S_k).""" + head_dim = Q.shape[-1] + if head_dim <= chunk_size: + return torch.matmul(Q, K_t) / scale + + QK = torch.matmul(Q[..., :chunk_size], K_t[..., :chunk_size, :]) + for start in range(chunk_size, head_dim, chunk_size): + end = min(start + chunk_size, head_dim) + QK = QK + torch.matmul(Q[..., start:end], K_t[..., start:end, :]) + return QK / scale + + +def _chunked_v_matmul( + scores: Tensor, + V: Tensor, + chunk_size: int = _MAX_UNCHUNKED_HEAD_DIM, +) -> Tensor: + """scores @ V, chunked along V's head_dim.""" 
+ head_dim = V.shape[-1] + if head_dim <= chunk_size: + return torch.matmul(scores, V) + + chunks = [] + for start in range(0, head_dim, chunk_size): + end = min(start + chunk_size, head_dim) + chunks.append(torch.matmul(scores, V[..., start:end])) + return torch.cat(chunks, dim=-1) + + +# --------------------------------------------------------------------------- +# Config overrides +# --------------------------------------------------------------------------- + + +class Gemma3_1B_NeuronConfig(_UpstreamGemma3NeuronConfig): + """NeuronConfig that points to our 1B-specific attention class.""" def __init__(self, **kwargs): super().__init__(**kwargs) - # Use Gemma3-specific attention class - self.attn_cls = NeuronGemma3Attention + self.attn_cls = NeuronGemma3_1B_Attention -class Gemma3InferenceConfig(InferenceConfig): - """ - Configuration class for Gemma3 model inference on NeuronX - - Inherits from InferenceConfig and adds Gemma3-specific parameters. - This class handles loading configuration from HuggingFace format. +class Gemma3_1B_InferenceConfig(_UpstreamGemma3InferenceConfig): + """InferenceConfig fixes for the 1B variant. + + Changes vs upstream: + - Reads vocab_size from HF config (262144 for 1B) instead of hardcoding + 262208. + - Auto-disables NKI attention kernel when head_dim > 128. 
""" - def add_derived_config(self): - """Add derived configuration parameters""" - self.num_cores_per_group = 1 - - # Add required attributes for HF compatibility - if not hasattr(self, "output_attentions"): - self.output_attentions = False - if not hasattr(self, "output_hidden_states"): - self.output_hidden_states = False - if not hasattr(self, "use_cache"): - self.use_cache = True - - # Add Gemma3-specific parameters with defaults - if not hasattr(self, "query_pre_attn_scalar"): - self.query_pre_attn_scalar = 256 - - # NOTE: Disabling sliding window for now as the NKI kernel doesn't support head_dim > 128 - # Gemma3 uses head_dim=256 which exceeds this limit - # TODO: Re-enable when kernel support is added or use alternative implementation - if not hasattr(self, "sliding_window"): - self.sliding_window = None # Disabled for now - - if not hasattr(self, "sliding_window_pattern"): - self.sliding_window_pattern = 6 - - if not hasattr(self, "rope_local_base_freq"): - self.rope_local_base_freq = 10000 - - if not hasattr(self, "attn_logit_softcapping"): - self.attn_logit_softcapping = None - - if not hasattr(self, "final_logit_softcapping"): - self.final_logit_softcapping = None - - if not hasattr(self, "attention_bias"): - self.attention_bias = False - - if not hasattr(self, "attention_dropout"): - self.attention_dropout = 0.0 - - # Generate layer_types based on sliding_window_pattern - # NOTE: Currently all layers use global attention due to head_dim limitation - if not hasattr(self, "layer_types"): - self.layer_types = [] - for i in range(self.num_hidden_layers): - # Disabled sliding window due to head_dim > 128 limitation - self.layer_types.append("global_attention") - - def get_required_attributes(self) -> List[str]: - """List of required attributes for the configuration""" - return [ - "hidden_size", - "num_attention_heads", - "num_hidden_layers", - "num_key_value_heads", - "head_dim", - "pad_token_id", - "vocab_size", - "max_position_embeddings", - "rope_theta", 
- "rms_norm_eps", - "intermediate_size", - ] + def __init__(self, neuron_config, fused_spec_config=None, load_config=None): + # Let the parent set everything up (including load_config which + # populates vocab_size from HF). + # + # The parent unconditionally sets vocab_size=262208 *after* load_config. + # We need to capture the HF value before that happens. + self._hf_vocab_size = None - @classmethod - def get_neuron_config_cls(cls) -> Type[Gemma3NeuronConfig]: - """Return the NeuronConfig class to use""" - return Gemma3NeuronConfig + # Intercept load_config to capture vocab_size before parent overwrites it. + if load_config is not None: + original_load_config = load_config - @classmethod - def from_pretrained(cls, model_path: str, **kwargs) -> "Gemma3InferenceConfig": - """ - Load configuration from a pretrained model directory - - Args: - model_path: Path to the model directory containing config.json - **kwargs: Additional arguments (including neuron_config) - - Returns: - Gemma3InferenceConfig: Configuration object - """ - # Extract neuron_config from kwargs if it exists - neuron_config = kwargs.pop("neuron_config", None) - - # Read config.json from the model directory - config_path = os.path.join(model_path, "config.json") - if not os.path.exists(config_path): - raise FileNotFoundError(f"Configuration file not found at {config_path}") - - with open(config_path, "r") as f: - config_dict = json.load(f) - - # Override with remaining kwargs - config_dict.update(kwargs) - - # Add required attributes that might not be in HF config - if "output_attentions" not in config_dict: - config_dict["output_attentions"] = False - if "output_hidden_states" not in config_dict: - config_dict["output_hidden_states"] = False - if "use_cache" not in config_dict: - config_dict["use_cache"] = True - # Gemma3 defaults to tied embeddings - if "tie_word_embeddings" not in config_dict: - config_dict["tie_word_embeddings"] = True - - # If neuron_config is None, create a default one for 
validation - # The actual neuron_config will be loaded from the compiled model during inference - if neuron_config is None: - from neuronx_distributed_inference.models.config import NeuronConfig - neuron_config = NeuronConfig() - - # Create config object - config = cls(neuron_config=neuron_config, **config_dict) - return config - - -# ==================================================================================== -# Model Components -# ==================================================================================== - - -class Gemma3RMSNorm(nn.Module): - """ - Gemma3-specific RMSNorm implementation - - Key difference from standard RMSNorm: - - Uses (1.0 + weight) instead of just weight for scaling - - This is specific to Gemma3 architecture - - Reference: transformers/models/gemma3/modeling_gemma3.py:Gemma3RMSNorm - """ + def _capturing_load_config(self_inner): + original_load_config(self_inner) + # Capture the HF vocab_size before parent overwrites it. + self._hf_vocab_size = getattr(self_inner, "vocab_size", None) - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - # Initialize weight to zeros (Gemma3-specific) - self.weight = nn.Parameter(torch.zeros(dim)) + load_config = _capturing_load_config - def _norm(self, x): - """Root mean square normalization""" - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + super().__init__(neuron_config, fused_spec_config, load_config) - def forward(self, x): - output = self._norm(x.float()) - # Gemma3-specific: use (1.0 + weight) for scaling - output = output * (1.0 + self.weight.float()) - return output.type_as(x) + # Restore the correct vocab_size. + if self._hf_vocab_size is not None: + self.vocab_size = self._hf_vocab_size + # Auto-disable NKI kernel when head_dim > 128. 
+ head_dim = getattr( + self, "head_dim", self.hidden_size // self.num_attention_heads + ) + if ( + head_dim > _MAX_UNCHUNKED_HEAD_DIM + and self.neuron_config.attn_kernel_enabled is not False + ): + logger.warning( + "Gemma3-1B: head_dim=%d > %d, auto-disabling NKI attention kernel", + head_dim, + _MAX_UNCHUNKED_HEAD_DIM, + ) + self.neuron_config.attn_kernel_enabled = False -def get_rmsnorm_cls(): - """ - Get the appropriate RMSNorm implementation based on execution mode - - Returns: - Gemma3RMSNorm for CPU mode (CustomRMSNorm doesn't work on CPU) - CustomRMSNorm for NeuronX mode (optimized for Neuron hardware) - """ - # For Gemma3, we need to use the custom Gemma3RMSNorm which has - # the specific (1.0 + weight) scaling. However, CustomRMSNorm doesn't - # support this yet, so we'll use Gemma3RMSNorm everywhere for now. - return Gemma3RMSNorm + @classmethod + def get_neuron_config_cls(cls): + return Gemma3_1B_NeuronConfig -class Gemma3ScaledEmbedding(nn.Module): - """ - Gemma3-specific scaled embeddings - - Embeddings are multiplied by sqrt(hidden_size) as per Gemma3 architecture. 
- - Reference: transformers/models/gemma3/modeling_gemma3.py:Gemma3TextScaledWordEmbedding - """ +# --------------------------------------------------------------------------- +# Attention override +# --------------------------------------------------------------------------- - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int, - dtype: torch.dtype, - shard_across_embedding: bool = True, - pad: bool = True, - sequence_parallel_enabled: bool = False, - ): - super().__init__() - self.embed_scale = embedding_dim**0.5 - self.embedding = ParallelEmbedding( - num_embeddings, - embedding_dim, - padding_idx, - dtype=dtype, - shard_across_embedding=shard_across_embedding, - pad=pad, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - def forward(self, input_ids: torch.Tensor): - # Get embeddings and scale by sqrt(hidden_size) - embeds = self.embedding(input_ids) - return embeds * self.embed_scale +class NeuronGemma3_1B_Attention(_UpstreamNeuronGemma3Attention): + """Attention for Gemma 3 1B (head_dim=256). - -class NeuronGemma3Attention(NeuronAttentionBase): - """ - Gemma3 attention mechanism with Q-K normalization - - Key features: - - Q-K normalization after projection (similar to Qwen3) - - Support for both global and local (sliding window) attention - - query_pre_attn_scalar for attention score scaling - - Optional attention logit softcapping - - Reference: transformers/models/gemma3/modeling_gemma3.py:Gemma3Attention + Adds: + - Chunked Q@K^T and scores@V for head_dim > 128. + - Restores k_cache_transposed for SWA layers (base class forces False). + - Transposes K around repeat_kv so GQA expansion works for BHDS layout. 
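The transpose-around-repeat_kv pattern can be sketched with a toy numpy
repeat_kv. This is an illustrative model only (`toy_repeat_kv` is a
stand-in for the real helper in
`neuronx_distributed_inference.modules.attention.utils`; shapes are made up):

```python
import numpy as np

# Toy repeat_kv that assumes (B, H, S, D) input. When k_cache_transposed=True
# the K cache arrives as (B, H, D, S), so the fix transposes to BHSD, expands
# the KV heads for GQA, then transposes back to BHDS.
def toy_repeat_kv(x, n_rep):
    b, h, s, d = x.shape
    return np.broadcast_to(x[:, :, None, :, :], (b, h, n_rep, s, d)).reshape(b, h * n_rep, s, d)

B, H_kv, S, D, n_rep = 1, 1, 3, 4, 4
k_bhsd = np.arange(B * H_kv * S * D, dtype=np.float32).reshape(B, H_kv, S, D)
k_bhds = k_bhsd.transpose(0, 1, 3, 2)  # cache layout with k_cache_transposed=True

# The fix: BHDS -> BHSD, repeat heads, BHSD -> BHDS
k_fixed = toy_repeat_kv(k_bhds.transpose(0, 1, 3, 2), n_rep).transpose(0, 1, 3, 2)

# Reference: repeat in the canonical BHSD layout, then transpose once
k_ref = toy_repeat_kv(k_bhsd, n_rep).transpose(0, 1, 3, 2)
assert np.array_equal(k_fixed, k_ref)
```

Feeding the BHDS cache straight into a BHSD-assuming repeat_kv mixes up the
S and D axes of the expanded heads, which is the GQA corruption this class
guards against.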
""" - def __init__(self, config: Gemma3InferenceConfig, is_sliding: bool = False): - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + def __init__(self, config): + super().__init__(config) + self._needs_chunked_attn = self.head_dim > _MAX_UNCHUNKED_HEAD_DIM - # Determine which RoPE to use based on attention type - # Sliding window uses local RoPE with smaller base frequency - if is_sliding: - rope_theta = config.rope_local_base_freq + # The base class forces k_cache_transposed=False for SWA layers + # (attention_base.py line 316), but the KV cache manager uses the + # NeuronConfig value globally. Restore the config value so that + # SWA layers interpret the cache layout correctly. + self.k_cache_transposed = config.neuron_config.k_cache_transposed + + # -- CTE overrides (prefill) ---------------------------------------- + + def scaled_qk(self, Q, K, attention_mask): + """Override: chunk Q@K^T for large head_dim.""" + if self._needs_chunked_attn: + QK = _chunked_qk(Q, K, scale=math.sqrt(self.head_dim)) else: - rope_theta = config.rope_theta + QK = torch.matmul(Q, K.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: + QK = torch.where( + attention_mask.to(torch.bool), QK, torch.finfo(QK.dtype).min + ) + return QK - rotary_emb = RotaryEmbedding( - dim=head_dim, - max_position_embeddings=config.max_position_embeddings, - base=rope_theta, + def perform_prefill(self, Q, K, V, q_len, bsz, attention_mask) -> Tensor: + """Override: use chunked V matmul for the flat-compiler CTE path.""" + from neuronx_distributed_inference.modules.attention.attention_base import ( + FlashAttentionStrategy, ) - # Determine sliding window size - sliding_window = config.sliding_window if is_sliding else None - - super().__init__( - config=config, - hidden_size=config.hidden_size, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=config.num_key_value_heads, - head_dim=head_dim, - rotary_emb=rotary_emb, - 
sliding_window=sliding_window, - # Q-K normalization (like Qwen3) - q_layernorm=get_rmsnorm_cls()(dim=head_dim, eps=config.rms_norm_eps), - k_layernorm=get_rmsnorm_cls()(dim=head_dim, eps=config.rms_norm_eps), + flash_attn_strategy = self.get_flash_attention_strategy( + q_len, attention_mask is not None ) + if flash_attn_strategy != FlashAttentionStrategy.NONE: + return super().perform_prefill(Q, K, V, q_len, bsz, attention_mask) - # Store Gemma3-specific parameters - self.query_pre_attn_scalar = config.query_pre_attn_scalar - self.attn_logit_softcapping = config.attn_logit_softcapping + K_active = repeat_kv(K, self.num_key_value_groups) + V_active = repeat_kv(V, self.num_key_value_groups) + active_scores = self.scaled_qk(Q, K_active, attention_mask) + learned_sinks = self.get_learned_sinks() + if learned_sinks is not None: + learned_sinks = learned_sinks.reshape(1, self.num_heads, 1, 1).expand( + bsz, -1, q_len, -1 + ) + active_scores = torch.cat((active_scores, learned_sinks), dim=-1) -class NeuronGemma3MLP(nn.Module): - """ - Gemma3 MLP (feed-forward network) - - Architecture: gate_proj, up_proj, down_proj with GELU activation - Similar to LLaMA but uses gelu_pytorch_tanh instead of SiLU - - Reference: transformers/models/gemma3/modeling_gemma3.py:Gemma3MLP - """ + active_scores = nn.functional.softmax( + active_scores, dim=-1, dtype=torch.float32 + ).to(Q.dtype) - def __init__(self, config: Gemma3InferenceConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - - # Gate and up projections (column parallel) - self.gate_proj = ColumnParallelLinear( - self.hidden_size, - self.intermediate_size, - bias=False, - gather_output=False, - dtype=config.neuron_config.torch_dtype, - pad=True, - ) + if learned_sinks is not None: + active_scores = active_scores[..., :-1] - self.up_proj = ColumnParallelLinear( - self.hidden_size, - self.intermediate_size, - bias=False, - 
gather_output=False, - dtype=config.neuron_config.torch_dtype, - pad=True, + attn_output = ( + _chunked_v_matmul(active_scores, V_active) + if self._needs_chunked_attn + else torch.matmul(active_scores, V_active) + ) + return attn_output, flash_attn_strategy + + def perform_prefill_windowed_attn( + self, Q, K, V, q_len, bsz, attention_mask, window_size + ) -> Tensor: + """Override: use chunked matmuls for windowed (SWA) CTE path.""" + from neuronx_distributed_inference.modules.attention.attention_base import ( + FlashAttentionStrategy, ) - # Down projection (row parallel) - self.down_proj = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=False, - input_is_parallel=True, - dtype=config.neuron_config.torch_dtype, + flash_attn_strategy = self.get_flash_attention_strategy( + q_len, attention_mask is not None ) + if flash_attn_strategy not in ( + FlashAttentionStrategy.NONE, + FlashAttentionStrategy.SLIDING_WINDOW_KERNEL, + ): + attn_output, _ = self.perform_prefill(Q, K, V, q_len, bsz, attention_mask) + return attn_output, flash_attn_strategy + + if flash_attn_strategy == FlashAttentionStrategy.SLIDING_WINDOW_KERNEL: + return super().perform_prefill_windowed_attn( + Q, K, V, q_len, bsz, attention_mask, window_size + ) - # GELU activation (gelu_pytorch_tanh approximation) - # This is GELU with tanh approximation as used in Gemma3 - self.act_fn = nn.GELU(approximate="tanh") + K_active = repeat_kv(K, self.num_key_value_groups) + V_active = repeat_kv(V, self.num_key_value_groups) + active_scores = self.scaled_qk(Q, K_active, attention_mask) - def forward(self, x): - # Gemma3 MLP: down_proj(act(gate_proj(x)) * up_proj(x)) - gate_output = self.act_fn(self.gate_proj(x)) - up_output = self.up_proj(x) - down_output = self.down_proj(gate_output * up_output) - return down_output, None # Return None for compatibility + learned_sinks = self.get_learned_sinks() + if learned_sinks is not None: + learned_sinks = learned_sinks.reshape(1, self.num_heads, 1, 
1).expand( + bsz, -1, q_len, -1 + ) + active_scores = torch.cat((active_scores, learned_sinks), dim=-1) + active_scores = nn.functional.softmax( + active_scores, dim=-1, dtype=torch.float32 + ).to(Q.dtype) -class NeuronGemma3DecoderLayer(nn.Module): - """ - Gemma3 decoder layer - - Key architectural features: - - Four normalization layers: input, post_attention, pre_feedforward, post_feedforward - - Pre-norm architecture with residual connections - - Support for both global and sliding window attention - - Reference: transformers/models/gemma3/modeling_gemma3.py:Gemma3DecoderLayer - """ + if learned_sinks is not None: + active_scores = active_scores[..., :-1] - def __init__(self, config: Gemma3InferenceConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - self.layer_idx = layer_idx + attn_output = ( + _chunked_v_matmul(active_scores, V_active) + if self._needs_chunked_attn + else torch.matmul(active_scores, V_active) + ) + return attn_output, flash_attn_strategy - # Determine if this layer uses sliding window attention - is_sliding = config.layer_types[layer_idx] == "sliding_attention" + # -- TKG override (token generation) --------------------------------- - # Attention and MLP - self.self_attn = NeuronGemma3Attention(config, is_sliding=is_sliding) - self.mlp = NeuronGemma3MLP(config) + def compute_for_token_gen( + self, + Q, + K, + V, + position_ids, + past_key_value, + attention_mask, + active_mask, + is_prefix_caching=False, + ) -> Tensor: + """Override: chunked matmuls + k_cache_transposed repeat_kv fix.""" + if not self._needs_chunked_attn: + return super().compute_for_token_gen( + Q, + K, + V, + position_ids, + past_key_value, + attention_mask, + active_mask, + is_prefix_caching, + ) - # Four normalization layers (Gemma3-specific) - self.input_layernorm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps) - 
self.pre_feedforward_layernorm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps) - self.post_feedforward_layernorm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps) + from neuronx_distributed_inference.modules.attention.attention_base import ( + manual_softmax, + ) - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states: input tensor of shape (batch, seq_len, hidden_size) - attention_mask: attention mask tensor - position_ids: position indices tensor - past_key_value: cached key-value pairs for efficient generation - - Returns: - Tuple of (hidden_states, present_key_value, cos_cache, sin_cache, None) - """ - # Attention block with pre and post normalization - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - # Self attention - hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - **kwargs, + is_speculation = False if position_ids is None else position_ids.shape[-1] > 1 + if self.attention_chunk_size and is_speculation: + raise NotImplementedError( + "Speculative decoding not supported with chunked attention." + ) + + K_prior = past_key_value[0] + V_prior = past_key_value[1] + + # Handle k_cache_transposed: K_prior is BHDS, repeat_kv expects BHSD. 
+ if self.k_cache_transposed: + K_prior = K_prior.transpose(2, 3) # BHDS -> BHSD + K_prior = repeat_kv(K_prior, self.num_key_value_groups) + K_prior = K_prior.transpose(2, 3) # BHSD -> BHDS + V_prior = repeat_kv(V_prior, self.num_key_value_groups) + prior_scores = _chunked_qk_transposed( + Q, K_prior, scale=math.sqrt(self.head_dim) + ) + else: + K_prior = repeat_kv(K_prior, self.num_key_value_groups) + V_prior = repeat_kv(V_prior, self.num_key_value_groups) + prior_scores = _chunked_qk(Q, K_prior, scale=math.sqrt(self.head_dim)) + + # Pad attention mask if KV cache is padded. + if ( + prior_scores.shape[-1] > attention_mask.shape[-1] + and self.neuron_config.apply_seq_ids_mask + ): + attention_mask = F.pad( + attention_mask, + (0, prior_scores.shape[-1] - attention_mask.shape[-1]), + "constant", + 0, + ) + + prior_scores = torch.where( + attention_mask, prior_scores, torch.finfo(prior_scores.dtype).min ) + prior_scores = prior_scores.to(torch.float32) + + # Active (current) KV. + K_active = repeat_kv(K, self.num_key_value_groups) + V_active = repeat_kv(V, self.num_key_value_groups) + active_scores = _chunked_qk(Q, K_active, scale=math.sqrt(self.head_dim)) + if is_speculation or is_prefix_caching: + active_scores = torch.where( + active_mask, active_scores, torch.finfo(active_scores.dtype).min + ) + active_scores = active_scores.to(torch.float32) - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = residual + hidden_states + learned_sinks = self.get_learned_sinks() + if learned_sinks is not None: + bsz, _, seqlen, _ = active_scores.shape + sinks = learned_sinks.reshape(1, self.num_heads, 1, 1).expand( + bsz, -1, seqlen, -1 + ) + prior_scores = torch.cat((prior_scores, sinks), dim=-1) - # MLP block with pre and post normalization - residual = hidden_states - hidden_states = self.pre_feedforward_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states)[0] - hidden_states = self.post_feedforward_layernorm(hidden_states) - 
hidden_states = residual + hidden_states + softmax_prior, softmax_active = manual_softmax( + prior_scores, active_scores, is_speculation or is_prefix_caching + ) - outputs = (hidden_states, present_key_value, cos_cache, sin_cache, None) + if learned_sinks is not None: + softmax_prior = softmax_prior[..., :-1] - return outputs + softmax_prior, softmax_active = ( + softmax_prior.to(Q.dtype), + softmax_active.to(Q.dtype), + ) + attn_prior = _chunked_v_matmul(softmax_prior, V_prior) + attn_active = _chunked_v_matmul(softmax_active, V_active) + return attn_prior + attn_active -# ==================================================================================== -# Model Classes -# ==================================================================================== +# --------------------------------------------------------------------------- +# Decoder layer + text model overrides +# --------------------------------------------------------------------------- -class NeuronGemma3Model(NeuronBaseModel): - """ - Gemma3 base model for NeuronX inference - - This is the main transformer model without the language modeling head. - Includes embeddings, decoder layers, and final normalization. - - Reference: transformers/models/gemma3/modeling_gemma3.py:Gemma3TextModel + +class NeuronGemma3_1B_DecoderLayer(_UpstreamNeuronGemma3DecoderLayer): + """Decoder layer that uses our 1B-specific attention class. + + The upstream decoder hardcodes ``NeuronGemma3Attention(config)`` instead + of using ``config.neuron_config.attn_cls``. We override ``__init__`` + to swap in ``NeuronGemma3_1B_Attention``. 
""" - def setup_attr_for_model(self, config: Gemma3InferenceConfig): - """Setup attributes for model initialization""" - self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None - self.tp_degree = config.neuron_config.tp_degree - self.hidden_size = config.hidden_size - self.num_attention_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - self.max_batch_size = config.neuron_config.max_batch_size - self.buckets = config.neuron_config.buckets - - def init_model(self, config: Gemma3InferenceConfig): - """Initialize the model components""" - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - # Scaled embeddings (Gemma3-specific) - self.embed_tokens = Gemma3ScaledEmbedding( - config.vocab_size, - config.hidden_size, - self.padding_idx, - dtype=config.neuron_config.torch_dtype, - shard_across_embedding=True, - pad=True, - sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, - ) + def __init__(self, config, layer_idx: int): + super().__init__(config, layer_idx) + # Replace the attention module with our 1B-specific version. + self.self_attn = NeuronGemma3_1B_Attention(config) + + +class NeuronGemma3_1B_TextModel(_UpstreamNeuronGemma3TextModel): + """Text model that uses our 1B decoder layers.""" - # Decoder layers + def init_model(self, config): + super().init_model(config) + # Replace layers with our 1B-specific decoder layers. 
+ updated_configs = get_updated_configs(config) self.layers = nn.ModuleList( - [NeuronGemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [ + NeuronGemma3_1B_DecoderLayer(conf, idx) + for idx, conf in enumerate(updated_configs) + ] ) - # Final normalization - self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) - - # Language modeling head - self.lm_head = ColumnParallelLinear( - config.hidden_size, - config.vocab_size, - bias=False, - pad=True, - gather_output=not self.on_device_sampling, - dtype=config.neuron_config.torch_dtype, - ) +# --------------------------------------------------------------------------- +# CausalLM wrapper +# --------------------------------------------------------------------------- -class NeuronGemma3ForCausalLM(NeuronBaseForCausalLM): - """ - Gemma3 model for causal language modeling on NeuronX - - This class wraps NeuronGemma3Model and provides the interface for - compilation, inference, and weight loading. - - Reference: transformers/models/gemma3/modeling_gemma3.py:Gemma3ForCausalLM - """ - _model_cls = NeuronGemma3Model +class NeuronGemma3_1B_ForCausalLM(_UpstreamNeuronGemma3ForCausalLM): + """Gemma 3 1B causal LM with query_pre_attn_scalar weight fusion. - @staticmethod - def load_hf_model(model_path, **kwargs): - """ - Load the HuggingFace Gemma3 model - - Note: We import here to avoid dependency issues - """ - from transformers import AutoModelForCausalLM + Overrides convert_hf_to_neuron_state_dict to fuse the attention scaling + correction (query_pre_attn_scalar vs head_dim) into Q/K weight matrices. + This avoids any runtime change while producing mathematically identical + attention scores. - return AutoModelForCausalLM.from_pretrained(model_path, **kwargs) + Pattern credit: Pierre Lienhart (gemma3-vision contrib). 
+ """ - @staticmethod - def convert_hf_to_neuron_state_dict(state_dict: dict, config: InferenceConfig) -> dict: - """ - Convert HuggingFace Gemma3 state dict to NeuronX format - - Key mappings: - - embed_tokens.weight -> embed_tokens.embedding.weight - - layers.*.self_attn.q_norm -> layers.*.self_attn.q_layernorm - - layers.*.self_attn.k_norm -> layers.*.self_attn.k_layernorm - - norm.weight -> norm.weight - - lm_head.weight -> lm_head.weight - - Note: The input state_dict already has the "model." prefix stripped by the framework. - """ - neuron_config = config.neuron_config - neuron_state_dict = {} - - # Handle embeddings with scaling - if "embed_tokens.weight" in state_dict: - neuron_state_dict["embed_tokens.embedding.weight"] = ( - state_dict["embed_tokens.weight"].detach().clone() - ) + _model_cls = NeuronGemma3_1B_TextModel - # Handle final norm - if "norm.weight" in state_dict: - neuron_state_dict["norm.weight"] = state_dict["norm.weight"].detach().clone() - - # Handle lm_head - if "lm_head.weight" in state_dict: - neuron_state_dict["lm_head.weight"] = state_dict["lm_head.weight"].detach().clone() - - # Handle decoder layers - num_layers = config.num_hidden_layers - tp_degree = neuron_config.tp_degree - - for i in range(num_layers): - prefix = f"layers.{i}" # No "model." prefix needed - - # Attention weights (Q, K, V projections) - # NOTE: Do NOT rename to qkv_proj.q_proj - the preshard_hook will handle that! 
- # Just copy the keys as-is - for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]: - key = f"{prefix}.self_attn.{proj}.weight" - if key in state_dict: - neuron_state_dict[key] = state_dict[key].detach().clone() - - # Q-K normalization weights (Gemma3-specific) - if f"{prefix}.self_attn.q_norm.weight" in state_dict: - neuron_state_dict[f"{prefix}.self_attn.q_layernorm.weight"] = ( - state_dict[f"{prefix}.self_attn.q_norm.weight"].detach().clone() - ) - - if f"{prefix}.self_attn.k_norm.weight" in state_dict: - neuron_state_dict[f"{prefix}.self_attn.k_layernorm.weight"] = ( - state_dict[f"{prefix}.self_attn.k_norm.weight"].detach().clone() - ) - - # MLP weights - for proj in ["gate_proj", "up_proj", "down_proj"]: - key = f"{prefix}.mlp.{proj}.weight" - if key in state_dict: - neuron_state_dict[key] = state_dict[key].detach().clone() - - # Layer normalization weights (four norms per layer) - for norm_name in [ - "input_layernorm", - "post_attention_layernorm", - "pre_feedforward_layernorm", - "post_feedforward_layernorm", - ]: - key = f"{prefix}.{norm_name}.weight" - if key in state_dict: - neuron_state_dict[key] = state_dict[key].detach().clone() - - # Add rank information for tensor parallelism in attention - neuron_state_dict[f"{prefix}.self_attn.rank_util.rank"] = torch.arange( - 0, tp_degree, dtype=torch.int32 - ) + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, config: InferenceConfig + ) -> dict: + # Run the upstream conversion first (renames q_norm -> q_layernorm, etc). + state_dict = _UpstreamNeuronGemma3ForCausalLM.convert_hf_to_neuron_state_dict( + state_dict, config + ) - # Add rank information for vocabulary parallelism - if neuron_config.vocab_parallel: - neuron_state_dict["embed_tokens.embedding.rank_util.rank"] = torch.arange( - 0, neuron_config.local_ranks_size - ) + # Fuse query_pre_attn_scalar into Q and K weights. + # + # NxDI's attention base uses QK^T / sqrt(head_dim). 
+        # Gemma 3 specifies QK^T / sqrt(query_pre_attn_scalar).
+        #
+        # By scaling Q and K weights by gamma, we get:
+        #     (Q*gamma)(K*gamma)^T / sqrt(head_dim)
+        #       = Q K^T * gamma^2 / sqrt(head_dim)
+        #       = Q K^T / sqrt(query_pre_attn_scalar)
+        #
+        # which holds when gamma^2 = sqrt(head_dim) / sqrt(query_pre_attn_scalar):
+        #     gamma = sqrt( sqrt(head_dim) / sqrt(query_pre_attn_scalar) )
+        #           = (head_dim / query_pre_attn_scalar) ** 0.25
+        query_pre_attn_scalar = getattr(config, "query_pre_attn_scalar", None)
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
-        # Add rank information for base model
-        neuron_state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32)
+        if query_pre_attn_scalar is not None and query_pre_attn_scalar != head_dim:
+            gemma_qk_scaling_factor = 1.0 / math.sqrt(float(query_pre_attn_scalar))
+            nxdi_qk_scaling_factor_inv = math.sqrt(float(head_dim))
+            gamma = math.sqrt(gemma_qk_scaling_factor * nxdi_qk_scaling_factor_inv)
-        return neuron_state_dict
+            logger.info(
+                "Fusing query_pre_attn_scalar=%s into Q/K weights (gamma=%.6f)",
+                query_pre_attn_scalar,
+                gamma,
+            )
-    @staticmethod
-    def update_state_dict_for_tied_weights(state_dict):
-        """
-        Handle tied weights between embeddings and lm_head
-
-        In Gemma3, embeddings are tied by default (tie_word_embeddings=True in config)
-        Note: The embedding is nested as embed_tokens.embedding.weight due to scaling wrapper
-        """
-        # Check both possible key locations for embedding weights
-        if "embed_tokens.embedding.weight" in state_dict:
-            state_dict["lm_head.weight"] = state_dict["embed_tokens.embedding.weight"].clone()
-        elif "embed_tokens.weight" in state_dict:
-            # Fallback if the embedding hasn't been wrapped yet
-            state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone()
+        for key in list(state_dict.keys()):
+            if key.endswith(
+                (
+                    ".q_proj.weight",
+                    ".k_proj.weight",
+                    ".qkv_proj.q_proj.weight",
+                    ".qkv_proj.k_proj.weight",
+                )
+            ):
+                orig_dtype = state_dict[key].dtype
+                state_dict[key] = 
(state_dict[key].to(torch.float32) * gamma).to( + orig_dtype + ) + + return state_dict @classmethod def get_config_cls(cls): - """Return the configuration class""" - return Gemma3InferenceConfig + return Gemma3_1B_InferenceConfig diff --git a/contrib/models/gemma-3-1b-it/test/integration/test_model.py b/contrib/models/gemma-3-1b-it/test/integration/test_model.py index d0a87e0c..c894417d 100644 --- a/contrib/models/gemma-3-1b-it/test/integration/test_model.py +++ b/contrib/models/gemma-3-1b-it/test/integration/test_model.py @@ -1,358 +1,220 @@ #!/usr/bin/env python3 """ -Integration tests for Gemma-3-1b-it NeuronX implementation. +Integration test for Gemma 3 1B IT on NeuronX. -Tests model compilation, loading, and inference accuracy/performance. +This test compiles and runs the model using the contrib's subclassed +implementation, then verifies that it generates coherent text. + +Usage (on a Neuron instance): + + cd neuronx-distributed-inference + PYTHONPATH="contrib/models/gemma-3-1b-it/src:src:$PYTHONPATH" \ + python contrib/models/gemma-3-1b-it/test/integration/test_model.py + +Or with pytest: + + PYTHONPATH="contrib/models/gemma-3-1b-it/src:src:$PYTHONPATH" \ + pytest contrib/models/gemma-3-1b-it/test/integration/test_model.py -v --capture=tee-sys """ +import os +import sys +import time +from pathlib import Path + import pytest import torch -import json -from pathlib import Path -from transformers import AutoTokenizer, GenerationConfig +from transformers import AutoTokenizer +# Ensure contrib src is on the path. 
+_CONTRIB_SRC = str(Path(__file__).resolve().parent.parent.parent / "src") +if _CONTRIB_SRC not in sys.path: + sys.path.insert(0, _CONTRIB_SRC) + +from modeling_gemma3 import Gemma3_1B_InferenceConfig, NeuronGemma3_1B_ForCausalLM from neuronx_distributed_inference.models.config import NeuronConfig from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config -# Import from src directory -import sys -sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) -from modeling_gemma3 import NeuronGemma3ForCausalLM, Gemma3InferenceConfig - - -# Test configuration -MODEL_PATH = "/home/ubuntu/models/gemma-3-1b-it/" -COMPILED_MODEL_PATH = "/home/ubuntu/neuron_models/gemma-3-1b-it/" - - -def load_neuron_config_from_compiled(compiled_path: str): - """ - Load neuron configuration from compiled model's neuron_config.json. - - This matches the pattern from validate_model.py to ensure consistency. - """ - config_path = Path(compiled_path) / "neuron_config.json" - - if not config_path.exists(): - raise FileNotFoundError(f"neuron_config.json not found: {config_path}") - - with open(config_path) as f: - config_data = json.load(f) - - if "neuron_config" in config_data: - return config_data["neuron_config"] - else: - return config_data - - -def create_model_for_inference(compiled_path: str, model_path: str): - """ - Create model for inference using the exact pattern from validate_model.py. - - This loads neuron_config from the compiled model to ensure consistency. 
- """ - # Load neuron config from compiled model - neuron_config_dict = load_neuron_config_from_compiled(compiled_path) - - # Convert dtype - dtype_str = neuron_config_dict.get('torch_dtype', 'torch.bfloat16') - if isinstance(dtype_str, str): - dtype = getattr(torch, dtype_str.split('.')[1]) if dtype_str.startswith('torch.') else torch.bfloat16 - else: - dtype = dtype_str - - # Create NeuronConfig from saved values - neuron_config_kwargs = { - 'tp_degree': neuron_config_dict.get('tp_degree', 2), - 'batch_size': neuron_config_dict.get('batch_size', 1), - 'seq_len': neuron_config_dict.get('seq_len', 512), - 'torch_dtype': dtype, - 'save_sharded_checkpoint': neuron_config_dict.get('save_sharded_checkpoint', True), - 'on_cpu': neuron_config_dict.get('on_cpu', False), - } - - optional_params = ['world_size', 'max_context_length', 'enable_bucketing'] - for param in optional_params: - if param in neuron_config_dict: - neuron_config_kwargs[param] = neuron_config_dict[param] - - if 'max_context_length' not in neuron_config_kwargs: - neuron_config_kwargs['max_context_length'] = neuron_config_kwargs['seq_len'] - - neuron_config = NeuronConfig(**neuron_config_kwargs) - - # Create model config - try: - model_config = Gemma3InferenceConfig.from_pretrained( - model_path, neuron_config=neuron_config, - ) - except (TypeError, AttributeError): - model_config = Gemma3InferenceConfig( - neuron_config, load_config=load_pretrained_config(model_path), - ) - - # Create model - try: - if hasattr(NeuronGemma3ForCausalLM, 'from_pretrained'): - model = NeuronGemma3ForCausalLM.from_pretrained(compiled_path, config=model_config) - else: - raise AttributeError("No from_pretrained method") - except (TypeError, AttributeError, Exception): - model = NeuronGemma3ForCausalLM(model_path, model_config) - - return model, neuron_config - - -def generate_with_neuron_model(model, input_ids, max_new_tokens: int): - """ - Generate tokens using manual forward pass loop. 
- - Matches the pattern from validate_model.py. - """ - generated_ids = input_ids.clone() - - for _ in range(max_new_tokens): - seq_len = generated_ids.shape[1] - position_ids = torch.arange(seq_len).unsqueeze(0).expand(generated_ids.shape[0], -1) - - with torch.no_grad(): - outputs = model(generated_ids, position_ids=position_ids) - - if hasattr(outputs, 'logits'): - logits = outputs.logits - elif isinstance(outputs, tuple): - logits = outputs[0] - else: - logits = outputs - - next_token_logits = logits[:, -1, :] - next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) - generated_ids = torch.cat([generated_ids, next_token], dim=-1) - - return generated_ids +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +MODEL_PATH = os.environ.get("GEMMA3_1B_MODEL_PATH", "google/gemma-3-1b-it") +COMPILED_MODEL_PATH = os.environ.get( + "GEMMA3_1B_COMPILED_PATH", "/tmp/gemma3-1b-it-compiled" +) + +# Neuron config matching the validated working configuration. 
+NEURON_CONFIG_KWARGS = dict( + tp_degree=1, + batch_size=1, + seq_len=512, + max_context_length=512, + torch_dtype=torch.bfloat16, + attn_kernel_enabled=False, + k_cache_transposed=True, + on_device_sampling_config=None, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def tokenizer(): + tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + return tok @pytest.fixture(scope="module") def compiled_model(): - """Compile and load model using our custom pattern.""" - # Compile if needed + """Compile (if needed) and load the model.""" + neuron_config = NeuronConfig(**NEURON_CONFIG_KWARGS) + config = Gemma3_1B_InferenceConfig( + neuron_config, + load_config=load_pretrained_config(MODEL_PATH), + ) + + model = NeuronGemma3_1B_ForCausalLM(MODEL_PATH, config) + compiled_path = Path(COMPILED_MODEL_PATH) - if not (compiled_path / "model.pt").exists(): - print(f"Compiling model to {COMPILED_MODEL_PATH}...") - - neuron_config = NeuronConfig( - tp_degree=2, - batch_size=1, - seq_len=512, - max_context_length=512, - torch_dtype=torch.bfloat16, - ) - - config = Gemma3InferenceConfig( - neuron_config, - load_config=load_pretrained_config(MODEL_PATH), - ) - - model = NeuronGemma3ForCausalLM(MODEL_PATH, config) + if not compiled_path.exists() or not any(compiled_path.iterdir()): + print(f"\nCompiling model to {COMPILED_MODEL_PATH} ...") model.compile(COMPILED_MODEL_PATH) - - # Load using our custom pattern - model, neuron_config = create_model_for_inference(COMPILED_MODEL_PATH, MODEL_PATH) + print("Compilation complete.") + + print(f"Loading model from {COMPILED_MODEL_PATH} ...") model.load(COMPILED_MODEL_PATH) - + print("Model loaded.") return model -@pytest.fixture(scope="module") -def tokenizer(): - """Load tokenizer.""" - tokenizer = 
AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right", trust_remote_code=True) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - return tokenizer +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- -@pytest.fixture(scope="module") -def generation_config(): - """Load generation config.""" - return GenerationConfig.from_pretrained(MODEL_PATH, do_sample=False, top_k=1, trust_remote_code=True) +def generate(model, input_ids, max_new_tokens: int = 20) -> torch.Tensor: + """Autoregressive generation via forward-loop.""" + generated = input_ids.clone() + for _ in range(max_new_tokens): + seq_len = generated.shape[1] + position_ids = torch.arange(seq_len).unsqueeze(0).expand(generated.shape[0], -1) + with torch.no_grad(): + outputs = model(generated, position_ids=position_ids) + logits = outputs.logits if hasattr(outputs, "logits") else outputs[0] + next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True) + generated = torch.cat([generated, next_token], dim=-1) + return generated + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- -def test_model_loads(compiled_model): - """Test that model loads successfully (smoke test).""" + +def test_smoke(compiled_model): + """Model loads without errors.""" assert compiled_model is not None - assert hasattr(compiled_model, 'config') - assert hasattr(compiled_model.config, 'neuron_config') - print("✓ Smoke test passed - Model loaded successfully") - - -def test_model_generates(compiled_model, tokenizer): - """Test that model can generate text using our custom generation loop.""" - prompt = "def fibonacci(n):" - inputs = tokenizer(prompt, return_tensors="pt", padding=True) - - # Use our custom generation function - generated_ids = 
generate_with_neuron_model(compiled_model, inputs.input_ids, max_new_tokens=20) - output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - - assert len(output_text) > len(prompt), "Output should be longer than prompt" - assert "return" in output_text or "if" in output_text, "Should contain Python code" - print(f"✓ Generation test passed") - print(f" Output: {output_text}") - - -def test_output_coherence(compiled_model, tokenizer): - """Test that output is coherent (not gibberish).""" - prompt = "What is 2 + 2?" - inputs = tokenizer(prompt, return_tensors="pt", padding=True) - - generated_ids = generate_with_neuron_model(compiled_model, inputs.input_ids, max_new_tokens=30) - output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - - # Coherence checks - assert len(output_text.split()) > 5, "Output should have multiple words" - assert not _is_repetitive(output_text), "Output should not be repetitive" - assert any(c in output_text for c in '.,!?'), "Output should have punctuation" - - print(f"✓ Coherence test passed") - print(f" Output: {output_text[:100]}...") - - -def test_performance_ttft(compiled_model, tokenizer): - """Test Time To First Token (TTFT) performance.""" - import time - - prompt = "Hello, how are you?" 
- inputs = tokenizer(prompt, return_tensors="pt", padding=True) - input_ids = inputs.input_ids - - # Warmup - for _ in range(3): - seq_len = input_ids.shape[1] - position_ids = torch.arange(seq_len).unsqueeze(0).expand(input_ids.shape[0], -1) - with torch.no_grad(): - _ = compiled_model(input_ids, position_ids=position_ids) - - # Measure TTFT - times = [] - for _ in range(10): - seq_len = input_ids.shape[1] - position_ids = torch.arange(seq_len).unsqueeze(0).expand(input_ids.shape[0], -1) - - start = time.perf_counter() - with torch.no_grad(): - _ = compiled_model(input_ids, position_ids=position_ids) - end = time.perf_counter() - - times.append((end - start) * 1000) # ms - - avg_ttft = sum(times) / len(times) - - # Should be under 100ms - assert avg_ttft < 100, f"TTFT {avg_ttft:.2f}ms exceeds 100ms threshold" - print(f"✓ TTFT test passed: {avg_ttft:.2f}ms (threshold: 100ms)") - - -def test_performance_throughput(compiled_model, tokenizer): - """Test token generation throughput.""" - import time - - prompt = "Hello" - inputs = tokenizer(prompt, return_tensors="pt", padding=True) - input_ids = inputs.input_ids - num_tokens = 50 - - # Warmup - _ = generate_with_neuron_model(compiled_model, input_ids, max_new_tokens=5) - - # Measure throughput - start = time.perf_counter() - _ = generate_with_neuron_model(compiled_model, input_ids, max_new_tokens=num_tokens) - end = time.perf_counter() - - total_time = end - start - throughput = num_tokens / total_time - - # Should be above 10 tokens/s - assert throughput > 10, f"Throughput {throughput:.2f} tok/s below 10 tok/s threshold" - print(f"✓ Throughput test passed: {throughput:.2f} tok/s (threshold: 10 tok/s)") - - -def _is_repetitive(text: str, max_repeat: int = 5) -> bool: - """Check if text has excessive repetition.""" + assert hasattr(compiled_model, "config") + print("PASS: smoke test") + + +def test_generates_text(compiled_model, tokenizer): + """Model generates non-empty, non-trivial output.""" + prompt = "The capital 
of France is" + inputs = tokenizer(prompt, return_tensors="pt") + generated = generate(compiled_model, inputs.input_ids, max_new_tokens=20) + text = tokenizer.decode(generated[0], skip_special_tokens=True) + + assert len(text) > len(prompt), f"Output not longer than prompt: {text!r}" + print(f"PASS: generates text\n Output: {text}") + + +def test_coherence(compiled_model, tokenizer): + """Output is coherent (not gibberish or degenerate repetition).""" + prompt = "Explain what a neural network is in one sentence." + inputs = tokenizer(prompt, return_tensors="pt") + generated = generate(compiled_model, inputs.input_ids, max_new_tokens=40) + text = tokenizer.decode(generated[0], skip_special_tokens=True) + words = text.split() - if len(words) < 10: - return False - - for i in range(len(words) - max_repeat): - word = words[i] - if all(words[i+j] == word for j in range(max_repeat)): - return True - - return False + assert len(words) > 8, f"Too few words: {text!r}" + + # Check for degenerate repetition (same word 6+ times in a row). 
+ for i in range(len(words) - 5): + assert not all(words[i + j] == words[i] for j in range(6)), ( + f"Degenerate repetition: {text!r}" + ) + + print(f"PASS: coherence\n Output: {text[:120]}...") + + +def test_vocab_size(compiled_model): + """Config has the correct 1B vocab_size (262144, not 262208).""" + assert compiled_model.config.vocab_size == 262144, ( + f"Expected 262144, got {compiled_model.config.vocab_size}" + ) + print(f"PASS: vocab_size = {compiled_model.config.vocab_size}") +def test_head_dim(compiled_model): + """Config has head_dim=256.""" + head_dim = getattr(compiled_model.config, "head_dim", None) + assert head_dim == 256, f"Expected head_dim=256, got {head_dim}" + print(f"PASS: head_dim = {head_dim}") + + +# --------------------------------------------------------------------------- +# Standalone runner +# --------------------------------------------------------------------------- + if __name__ == "__main__": - # Run tests manually (without pytest) - print("="*80) - print("Gemma-3-1b-it Integration Tests") - print("="*80) - - # Setup - compile if needed + print("=" * 70) + print("Gemma 3 1B IT -- Integration Test") + print("=" * 70) + + tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + + neuron_config = NeuronConfig(**NEURON_CONFIG_KWARGS) + config = Gemma3_1B_InferenceConfig( + neuron_config, + load_config=load_pretrained_config(MODEL_PATH), + ) + model = NeuronGemma3_1B_ForCausalLM(MODEL_PATH, config) + compiled_path = Path(COMPILED_MODEL_PATH) - if not (compiled_path / "model.pt").exists(): - print(f"\nCompiling model to {COMPILED_MODEL_PATH}...") - - neuron_config = NeuronConfig( - tp_degree=2, - batch_size=1, - seq_len=512, - max_context_length=512, - torch_dtype=torch.bfloat16, - ) - - config = Gemma3InferenceConfig( - neuron_config, - load_config=load_pretrained_config(MODEL_PATH), - ) - - model = NeuronGemma3ForCausalLM(MODEL_PATH, config) + if not 
compiled_path.exists() or not any(compiled_path.iterdir()): + print(f"\nCompiling to {COMPILED_MODEL_PATH} ...") + t0 = time.time() model.compile(COMPILED_MODEL_PATH) - print("✓ Compilation complete") - - # Load model using our custom pattern - print(f"\nLoading compiled model from {COMPILED_MODEL_PATH}...") - model, neuron_config = create_model_for_inference(COMPILED_MODEL_PATH, MODEL_PATH) + print(f"Compiled in {time.time() - t0:.1f}s") + + print(f"\nLoading from {COMPILED_MODEL_PATH} ...") model.load(COMPILED_MODEL_PATH) - print("✓ Model loaded") - - # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right", trust_remote_code=True) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - generation_config = GenerationConfig.from_pretrained(MODEL_PATH, do_sample=False, top_k=1, trust_remote_code=True) - - # Run tests - print("\n" + "="*80) - print("Running Tests") - print("="*80) - - print("\n1. Smoke Test (Model Loading)...") - test_model_loads(model) - - print("\n2. Generation Test...") - test_model_generates(model, tokenizer) - - print("\n3. Coherence Test...") - test_output_coherence(model, tokenizer) - - print("\n4. TTFT Performance Test...") - test_performance_ttft(model, tokenizer) - - print("\n5. Throughput Performance Test...") - test_performance_throughput(model, tokenizer) - - print("\n" + "="*80) - print("✓ All tests passed!") - print("="*80) + print("Loaded.\n") + + print("1. Smoke test ...") + test_smoke(model) + + print("\n2. Vocab size ...") + test_vocab_size(model) + + print("\n3. Head dim ...") + test_head_dim(model) + + print("\n4. Generation ...") + test_generates_text(model, tok) + + print("\n5. Coherence ...") + test_coherence(model, tok) + + print("\n" + "=" * 70) + print("ALL TESTS PASSED") + print("=" * 70)
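
The correctness claim behind the chunked-attention override (issue 1 in the README) -- that splitting Q@K^T into 128-wide chunks along head_dim and summing the partial products is mathematically identical to the full matmul -- can be sanity-checked on CPU with plain PyTorch. The sketch below is illustrative only: `chunked_qk`, its shapes, and the `chunk` parameter are stand-ins, not the contrib's actual `_chunked_qk` helper.

```python
import torch


def chunked_qk(q, k, scale, chunk=128):
    # Accumulate partial Q @ K^T products over <=128-wide slices of head_dim.
    # Summing the partials equals contracting over the full head_dim at once.
    scores = torch.zeros(*q.shape[:-1], k.shape[-2], dtype=q.dtype)
    for start in range(0, q.shape[-1], chunk):
        q_c = q[..., start:start + chunk]
        k_c = k[..., start:start + chunk]
        scores = scores + torch.matmul(q_c, k_c.transpose(-2, -1))
    return scores / scale


# Gemma 3 1B-like shapes: head_dim=256, i.e. two 128-wide chunks.
bsz, heads, q_len, kv_len, head_dim = 1, 4, 8, 16, 256
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, kv_len, head_dim)
scale = head_dim ** 0.5

full = torch.matmul(q, k.transpose(-2, -1)) / scale
assert torch.allclose(chunked_qk(q, k, scale), full, atol=1e-4)
```

The same argument covers the scores@V side: there the chunking splits the output dimension rather than the contraction, so each chunk is computed independently and concatenated.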