
Update gemma-3-1b-it contrib: fix head_dim=256 issues, add chunked attention and SWA support#129

Open
jimburtoft wants to merge 1 commit into aws-neuron:main from jimburtoft:contrib/gemma-3-1b-it

Conversation

@jimburtoft
Contributor

Summary

  • Replaces the standalone gemma-3-1b-it implementation with a thin subclass of the official models/gemma3/ that fixes 5 issues specific to the 1B variant's unusual head_dim=256 architecture
  • Future NxDI improvements to models/gemma3/ flow through automatically

Why This Update?

The 1B variant has unusual architecture parameters compared with the 4B/12B/27B variants:

| Parameter | 1B | 4B/12B/27B |
| --- | --- | --- |
| head_dim | 256 | 128 |
| vocab_size | 262144 | 262208 |
| num_kv_heads | 1 | 4-16 |
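With num_kv_heads=1, the 1B variant uses grouped-query attention in which every query head shares a single KV head. A standard `repeat_kv` expansion (a sketch, not the NxDI implementation) assumes the BHSD layout `[batch, kv_heads, seq, head_dim]`:

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand KV heads for GQA: [B, H_kv, S, D] -> [B, H_kv * n_rep, S, D].

    Note this assumes BHSD layout. A transposed K cache stores BHDS
    ([B, H, D, S]), so feeding it through a BHSD-shaped expansion
    mislabels the last two dims -- the layout mismatch the PR fixes.
    """
    b, h_kv, s, d = x.shape
    return (
        x[:, :, None, :, :]
        .expand(b, h_kv, n_rep, s, d)
        .reshape(b, h_kv * n_rep, s, d)
    )
```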

The previous implementation disabled sliding window attention entirely and reimplemented all components from scratch. This update subclasses the official code and fixes only what's needed.

Issues Fixed

  1. Chunked attention for head_dim=256 -- Q@K^T and scores@V split into 128-wide chunks to avoid Neuron compiler DGE out-of-bounds errors
  2. vocab_size from HF config -- Reads actual value (262144) instead of hardcoded 262208
  3. Auto-disable NKI attention kernel -- NKI kernel asserts head_dim <= 128
  4. k_cache_transposed + SWA + GQA fix -- Fixes layout mismatch where repeat_kv assumes BHSD but receives BHDS
  5. query_pre_attn_scalar weight fusion -- Fuses Gemma3's attention scaling correction into Q/K weights at load time (zero runtime overhead)

Architecture

Thin subclasses of official implementation (no upstream files modified):

models/gemma3/modeling_gemma3.py  (upstream, unchanged)
  +-- contrib/gemma-3-1b-it/src/modeling_gemma3.py
        |-- Gemma3_1B_InferenceConfig
        |-- NeuronGemma3_1B_Attention (chunked attn + k_cache fix)
        +-- NeuronGemma3_1B_ForCausalLM (weight fusion)
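The thin-subclass pattern can be illustrated with the vocab_size fix. The stand-in base class below is an assumption about the upstream shape, not the real NxDI code; only the override pattern matches the PR:

```python
class Gemma3InferenceConfig:
    """Stand-in for the upstream models/gemma3/ config (assumed to
    hardcode the 4B/12B/27B vocab size)."""
    vocab_size = 262208

    def describe(self) -> str:
        return f"vocab_size={self.vocab_size}"


class Gemma3_1B_InferenceConfig(Gemma3InferenceConfig):
    """Thin subclass: override only the 1B-specific value; everything
    else -- including future upstream improvements -- is inherited."""

    def __init__(self, hf_config: dict):
        # Read the real value (262144 for the 1B) from the HF config
        # instead of trusting the hardcoded 262208.
        self.vocab_size = hf_config["vocab_size"]
```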

Required Configuration

| Parameter | Value | Why |
| --- | --- | --- |
| attn_kernel_enabled | false | NKI kernel asserts head_dim <= 128 |
| k_cache_transposed | true | Required for SWA+GQA fix |
| context_encoding_buckets | [512]+ | Compiler OOB for buckets < 512 |
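As a plain dict, the three requirements look like the following (keys mirror the table above; the dict form is an illustration, not the actual NxDI config API):

```python
neuron_config = {
    # NKI attention kernel asserts head_dim <= 128, so it must be off.
    "attn_kernel_enabled": False,
    # Transposed K-cache layout is required by the SWA + GQA fix.
    "k_cache_transposed": True,
    # Context-encoding buckets below 512 trigger a compiler OOB error.
    "context_encoding_buckets": [512],
}
```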

Validation

  • SDK 2.28: Tested on trn2.3xlarge, TP=1, batch=4/16
  • SDK 2.29: Tests pass with the attn_kernel_enabled=False workaround
  • Multi-prompt accuracy validation (not just code completion)
  • vLLM serving validated with near-linear throughput scaling

Compatibility

| Instance | SDK 2.28 | SDK 2.29 |
| --- | --- | --- |
| trn2.3xlarge | Validated | Validated |
| inf2 | Expected compatible | NxDI drops inf2 in 2.29 |

…emma3/

Replaces the standalone Annapurna Labs implementation with subclasses
of the official NeuronGemma3* classes, adding only the overrides needed
for the 1B variant (head_dim=256, vocab_size=262144, GQA 4:1):

- Chunked Q@K^T and scores@V for head_dim>128 (compiler DGE OOB fix)
- k_cache_transposed restored for SWA layers (GQA repeat_kv layout fix)
- vocab_size read from HF config instead of hardcoded 262208
- NKI attention kernel auto-disabled when head_dim>128
- query_pre_attn_scalar fused into Q/K weights at load time (zero cost)
- DecoderLayer and TextModel overrides to swap in correct attention class
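The query_pre_attn_scalar fusion can be sketched as follows. A generic kernel divides scores by sqrt(head_dim), while Gemma3 specifies 1/sqrt(query_pre_attn_scalar); scaling both projection weights by the fourth root of the ratio corrects this once at load time. This is an illustrative sketch under that assumption, not the PR's actual helper:

```python
import math
import torch

def fuse_attn_scale(w_q: torch.Tensor, w_k: torch.Tensor,
                    head_dim: int, query_pre_attn_scalar: int):
    """Fold Gemma3's query_pre_attn_scalar correction into Q/K weights.

    Scaling W_q and W_k each by (head_dim / query_pre_attn_scalar)**0.25
    makes the Q.K products come out sqrt(head_dim / query_pre_attn_scalar)
    larger, which exactly cancels the kernel's default 1/sqrt(head_dim)
    divisor down to 1/sqrt(query_pre_attn_scalar) -- zero per-step cost.
    """
    factor = (head_dim / query_pre_attn_scalar) ** 0.25
    return w_q * factor, w_k * factor
```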

Tested on trn2.3xlarge with upstream NxDI main (0.8.0+26b1fcf5.dev):
compile 27s, load 12.5s, all 5 integration tests pass.
