Add BioReason-Pro contrib model #120

Open
jimburtoft wants to merge 6 commits into aws-neuron:main from jimburtoft:contrib/bioreason-pro

Conversation

@jimburtoft
Contributor

@jimburtoft jimburtoft commented Apr 9, 2026

Summary

  • Adds BioReason-Pro multimodal protein function prediction model as a contrib model
  • Pipeline: ESM3 protein encoder (CPU) → projection layers → Qwen3-4B (NxDI on NeuronCore)
  • 8 runtime patches enable inputs_embeds passthrough through NxDI compiled model
  • Process-level data parallelism with DataParallelRunner for multi-core inference
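
The three-stage pipeline can be sketched structurally as follows. All names here are stand-ins (the real implementation is `BioReasonPipeline` in `src/modeling_bioreason.py`); the point is the data flow: residue embeddings computed on CPU, projected, then handed to the compiled LM as `inputs_embeds` rather than token ids:

```python
def esm3_encode(sequence):
    """Stand-in for the CPU-side ESM3 encoder: one vector per residue."""
    return [[float(ord(aa))] * 4 for aa in sequence]

def project(embeddings, scale=0.01):
    """Stand-in for the projection layers mapping ESM3 dims to Qwen3 dims."""
    return [[x * scale for x in row] for row in embeddings]

def qwen3_generate(inputs_embeds):
    """Stand-in for the NxDI-compiled Qwen3-4B consuming inputs_embeds."""
    return f"prediction over {len(inputs_embeds)} residue embeddings"

def predict(sequence):
    return qwen3_generate(project(esm3_encode(sequence)))

print(predict("MKTAYIAK"))  # prediction over 8 residue embeddings
```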

Benchmark Results (trn2.3xlarge, SDK 2.28)

Single-Core Performance

| Batch Size | tok/s | Speedup |
|------------|-------|---------|
| BS=1       | 44.0  | 1.0x    |
| BS=4       | 98.6  | 2.2x    |
| BS=8       | 128.8 | 2.9x    |
| BS=16      | 134.7 | 3.1x    |
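
The Speedup column is each batch size's throughput divided by the BS=1 baseline of 44.0 tok/s. A quick arithmetic check reproduces the table:

```python
# Recompute the single-core speedup column from the measured tok/s values.
baseline = 44.0  # BS=1 throughput
for bs, toks in [(4, 98.6), (8, 128.8), (16, 134.7)]:
    print(f"BS={bs}: {toks / baseline:.1f}x")
# BS=4: 2.2x, BS=8: 2.9x, BS=16: 3.1x
```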

Data Parallelism + Batching

| Config             | Workers | Aggregate tok/s | Wall Time (40 proteins) |
|--------------------|---------|-----------------|-------------------------|
| DP=4, BS=1 (LNC=2) | 4       | 129.8           | 271.0s                  |
| DP=4, BS=4 (LNC=2) | 4       | 299.9           | 215.3s                  |

  • 2.1x the A10G GPU's peak throughput (299.9 vs 142.3 tok/s)
  • Zero contention between NeuronCores (per-worker tok/s identical across all 4 workers)
  • 0 errors across 40 proteins (64,568 tokens total)
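
The aggregate-throughput figure is consistent with the raw counts reported above (64,568 tokens over 215.3 s of wall time for the DP=4, BS=4 run):

```python
# Cross-check the DP=4 BS=4 aggregate throughput from the raw measurements.
tokens, wall_s = 64_568, 215.3
print(f"{tokens / wall_s:.1f} tok/s")  # 299.9 tok/s, matching the table
```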

GPU Comparison (A10G)

| Batch Size | GPU tok/s | Neuron tok/s | Neuron vs GPU |
|------------|-----------|--------------|---------------|
| BS=1       | 42.9      | 44.0         | 1.03x         |
| BS=4       | 142.3     | 98.6         | 0.69x         |
| DP=4, BS=4 | 142.3     | 299.9        | 2.1x          |

Files

  • src/modeling_bioreason.py — BioReasonPipeline with ESM3 encoder, embedding injection, predict/predict_batch
  • src/patch_nxdi_embeds.py — 8 V3b patches for inputs_embeds passthrough
  • src/dp_launcher.py — DataParallelRunner for multi-core inference with batching
  • test/integration/test_model.py — 5 integration tests
  • test/integration/benchmark_dp.py — Data-parallel benchmark script

Key Technical Details

  • ESM3 tokenizer compatibility patch for transformers >= 4.47 (mask_token fix)
  • Process-level DP via NEURON_RT_VISIBLE_CORES (NxDI built-in DP requires TP>1)
  • Compiled model shared across all DP workers (no per-core recompilation)
  • Supports BS=1,4,8,16 (compiled per batch size)
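
A minimal sketch of the process-level DP approach: each worker process gets `NEURON_RT_VISIBLE_CORES` set to its own core before launch, so the workers never contend for a NeuronCore. This is illustrative only (the actual runner lives in `src/dp_launcher.py`), and for clarity it runs workers sequentially rather than concurrently:

```python
import os
import subprocess
import sys

def launch_workers(num_workers):
    """Pin one worker process per NeuronCore via NEURON_RT_VISIBLE_CORES.

    Sketch: each 'worker' just echoes the core it was assigned. The real
    DataParallelRunner launches all workers concurrently and runs inference.
    """
    outputs = []
    for core in range(num_workers):
        env = dict(os.environ, NEURON_RT_VISIBLE_CORES=str(core))
        out = subprocess.check_output(
            [sys.executable, "-c",
             "import os; print(os.environ['NEURON_RT_VISIBLE_CORES'])"],
            env=env,
        )
        outputs.append(out.decode().strip())
    return outputs

print(launch_workers(2))  # ['0', '1']
```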

…on on Neuron

BioReason-Pro combines ESM3-small (~1.4B) protein encoder with Qwen3-4B
backbone via NxDI for protein function prediction. Includes 8 runtime
patches enabling inputs_embeds passthrough in NxDI 0.8.0, plus full
integration test suite validated on trn2.3xlarge (SDK 2.28).

When sweeping multiple batch sizes, compiled artifacts (model.pt, weights/,
neuron_config.json) were saved to the model weights directory, causing
subsequent batch sizes to load the first compiled model instead of
recompiling. The new compiled_model_path parameter directs compiled
artifacts to a separate directory per batch size while still reading
HF weights from model_path.
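
The fix amounts to keying the artifact directory on batch size. A sketch of the idea (the directory naming here is illustrative, not the PR's exact scheme):

```python
from pathlib import Path

def compiled_dir(base: str, batch_size: int) -> Path:
    """Route compiled artifacts (model.pt, weights/, neuron_config.json)
    to a per-batch-size directory, so a sweep over BS=1,4,8,16 never
    reloads a stale compile from the shared model weights directory."""
    return Path(base) / f"compiled_bs{batch_size}"

print(compiled_dir("/tmp/bioreason", 4))  # .../compiled_bs4
```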
…tok/s at DP=4 BS=4

- Add predict_batch() method for batched inputs_embeds generation
- Fix ESM3 tokenizer mask_token incompatibility with transformers >= 4.47
  (mask_token returns None in ESM3 >= 3.0.4, breaking tokenize_sequence)
- Update dp_launcher worker to collect batches when batch_size > 1
- DP=4 BS=4 benchmark: 299.9 tok/s aggregate (2.1x GPU A10G peak)
  40 proteins, 64,568 tokens, 215.3s wall time, 0 errors
- Update README with DP+batching results and compatibility matrix
- Add measured DP sweep: LNC=2 DP=4 BS=4/8/16 and LNC=1 DP=8 BS=1/4/8
- Peak: 537.3 tok/s (LNC=2 DP=4 BS=16) vs H100 best 522.6 tok/s = 1.03x
- LNC=2 DP=4 confirmed optimal (1.52x faster than LNC=1 DP=8)
- Update compatibility matrix with all validated configs
- All numbers measured, no projections
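
The mask_token incompatibility mentioned above follows a common compatibility-shim pattern: when the tokenizer reports `mask_token` as `None` (ESM3 >= 3.0.4 under transformers >= 4.47), restore a usable value before tokenizing. The class below is a stand-in, not the ESM3 tokenizer, and the fallback string is an assumption:

```python
class DummyTokenizer:
    """Stand-in for a tokenizer whose mask_token property returns None."""
    mask_token = None

def patch_mask_token(tokenizer, fallback="<mask>"):
    """If mask_token is missing, set a fallback so downstream
    sequence-tokenization code that formats the mask token keeps working."""
    if getattr(tokenizer, "mask_token", None) is None:
        tokenizer.mask_token = fallback
    return tokenizer

tok = patch_mask_token(DummyTokenizer())
print(tok.mask_token)  # <mask>
```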
