
FEAT: ChargE3Net Fine-Tuning Pipeline on LeMatRho #8

Open
speckhard wants to merge 5 commits into LeMaterial:main from speckhard:feat/charge3net

Conversation

@speckhard
Collaborator

feat/charge3net — ChargE3Net Fine-Tuning Pipeline

Overview

Adds a complete PyTorch training pipeline to fine-tune ChargE3Net ("Higher-Order Equivariant Neural Networks for Charge Density Prediction in Materials", Koker et al., npj Computational Materials 2024) on LeMatRho charge density data stored in Parquet format. The model is initialized from a pre-trained Materials Project checkpoint (1.9M parameters) and fine-tuned on 65,239 materials with 10x10x10 charge density grids.

New Files

| File | Lines | Purpose |
|---|---|---|
| `data.py` | 281 | Lazy-loading PyTorch `Dataset`. Reads LeMatRho Parquet chunks, converts rows to `ase.Atoms` + density grids, and builds charge3net-compatible graph dicts via `KdTreeGraphConstructor`. Only a lightweight index (~2 MB) is held in memory; rows are read on the fly. |
| `model.py` | 132 | Thin wrapper around charge3net's `E3DensityModel`. Handles instantiation with MP checkpoint hyperparameters and loading from three checkpoint formats (legacy PyTorch Lightning, new charge3net, raw `state_dict`). |
| `train.py` | 392 | Full training script with CLI. Supports normal training, smoke test, single-batch overfit test, W&B logging (online/offline), and checkpoint resumption across SLURM jobs. Metrics: L1, NMAPE, RMSE, NRMSE. |
| `submit_charge3net.sh` | 47 | SLURM GPU submission script for Jean Zay HPC (A100, 20 h wall time). Auto-resumes from `latest.pt` if present. |
| `.gitignore` | 6 | Excludes `.env`, `.venv/`, `__pycache__/`, `checkpoints/`, `wandb/`. |
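To illustrate the lazy-index pattern described for `data.py`, here is a minimal stdlib-only sketch: only cumulative chunk sizes live in memory, and a global row index is mapped to a (chunk, row) pair on demand. The class name and exact mapping are hypothetical — the real implementation subclasses `torch.utils.data.Dataset` and reads Parquet via pyarrow.

```python
import bisect

class LazyChunkIndex:
    """Lightweight index over row-chunked files: only the cumulative row
    counts are kept in memory; actual rows are fetched on demand."""

    def __init__(self, chunk_sizes):
        # Cumulative offsets, e.g. [100, 250, 300] for chunks of 100/150/50 rows.
        self.offsets = []
        total = 0
        for n in chunk_sizes:
            total += n
            self.offsets.append(total)

    def __len__(self):
        return self.offsets[-1] if self.offsets else 0

    def locate(self, idx):
        """Map a global row index to (chunk_id, row_within_chunk)."""
        chunk = bisect.bisect_right(self.offsets, idx)
        prev = self.offsets[chunk - 1] if chunk else 0
        return chunk, idx - prev
```

In a `Dataset.__getitem__`, `locate()` would pick which Parquet chunk to open and which row to decode, so no density grid is materialized before it is requested.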

Modified Files

| File | Change |
|---|---|
| `pyproject.toml` | Added dependencies: `e3nn`, `scipy`, `lz4`, `pyarrow`, `pandas`, `wandb`, `python-dotenv` |

Verification Tests

1. Smoke test (--smoke-test): Loads the pre-trained MP checkpoint, runs a single forward pass on one batch, and verifies output shape and loss computation. Confirms the data pipeline produces tensors compatible with the model.

2. Single-batch overfit test (--overfit-single-batch): Fetches one batch, trains on it repeatedly for N epochs without validation. Proves the model can learn from this data pipeline — L1 dropped from 34.2 to 7.7 (77% reduction) over 150 epochs on CPU without pre-trained weights.

3. Full training run (Jean Zay A100 GPU): Fine-tuning from the MP checkpoint on 61,978 training samples, validating on 3,261 samples. Results after 8 epochs:

| Epoch | train L1 (e/Å³) | val L1 (e/Å³) | val NMAPE (%) | val RMSE (e/Å³) | val NRMSE (%) |
|---|---|---|---|---|---|
| 1 | 5.27 | 3.94 | 14.15 | 8.99 | 31.89 |
| 2 | 3.56 | 3.50 | 12.75 | 7.79 | 27.96 |
| 3 | 3.21 | 3.15 | 11.35 | 6.84 | 24.35 |
| 4 | 3.00 | 3.06 | 10.84 | 6.73 | 23.35 |
| 5 | 2.85 | 3.03 | 10.66 | 6.48 | 22.45 |
| 6 | 2.75 | 2.84 | 10.08 | 6.15 | 21.53 |
| 7 | 2.66 | 2.81 | 9.91 | 6.05 | 21.04 |
| 8 | 2.59 | 2.71 | 9.57 | 5.84 | 20.29 |

Both train and val L1 are monotonically decreasing. Training is ongoing (50 epochs target, ~2h/epoch on A100).
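For reference, the table's metrics can be computed as in this plain-Python sketch. It uses one common set of definitions — NMAPE and NRMSE as the error normalized by the mean reference magnitude — and the exact normalization used in `train.py` may differ.

```python
def l1(pred, true):
    """Mean absolute error."""
    return sum(abs(p - t) for p, t in zip(pred, true)) / len(true)

def rmse(pred, true):
    """Root mean squared error."""
    return (sum((p - t) ** 2 for p, t in zip(pred, true)) / len(true)) ** 0.5

def nmape(pred, true):
    """L1 error as a percentage of the mean reference magnitude."""
    return 100.0 * l1(pred, true) / (sum(abs(t) for t in true) / len(true))

def nrmse(pred, true):
    """RMSE as a percentage of the mean reference magnitude."""
    return 100.0 * rmse(pred, true) / (sum(abs(t) for t in true) / len(true))
```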

Git Log

f3ffde4 Add --resume-from for checkpoint resumption across SLURM jobs
252acac Fix wandb on air-gapped clusters: add offline mode and init timeout
99d391a Add W&B logging, RMSE/NRMSE metrics, and Jean Zay SLURM script
f72fa52 Training pipeline to train Charge3net on LeMat-Rho

Pre-trained Weights

The Materials Project checkpoint (charge3net_mp.pt, 23 MB, 1.9M params) is checked into the AIforGreatGood/charge3net repo under models/. It was trained for 245 epochs / 407,868 steps on Materials Project charge densities. We cloned the repo and used the checkpoint directly — no separate download required.
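The three checkpoint formats `model.py` accepts can be normalized into a plain `state_dict` along these lines. The key names (`"state_dict"`, `"model"`, the `"model."` prefix) are illustrative guesses at typical layouts, not the exact keys `model.py` handles.

```python
def extract_state_dict(ckpt):
    """Normalize three common checkpoint layouts into a raw state_dict
    (key names here are assumptions, not the verified model.py logic)."""
    if "state_dict" in ckpt:
        # Legacy PyTorch Lightning export: weights nested under "state_dict",
        # usually prefixed with the LightningModule attribute name.
        return {k.removeprefix("model."): v for k, v in ckpt["state_dict"].items()}
    if "model" in ckpt:
        # Newer charge3net-style checkpoint with a "model" sub-dict.
        return ckpt["model"]
    # Otherwise assume the file already is a raw state_dict.
    return ckpt
```

The normalized dict would then be passed to `model.load_state_dict(...)` regardless of which format the file was saved in.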

External Dependency

Requires AIforGreatGood/charge3net cloned as a sibling directory (for E3DensityModel, KdTreeGraphConstructor, collate_list_of_dicts, PowerDecayScheduler).
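One way to make the sibling clone importable is a small `sys.path` shim; the directory names below are hypothetical (in `train.py` the repo root would come from `Path(__file__).resolve().parent` rather than the working directory).

```python
import sys
from pathlib import Path

# Assumed layout (directory names illustrative):
#   work/
#     lemat-rho-finetune/   <- this repo (train.py, data.py, model.py)
#     charge3net/           <- clone of AIforGreatGood/charge3net
repo_root = Path.cwd()
charge3net_root = repo_root.parent / "charge3net"

# Put the sibling clone first on the import path so
# E3DensityModel, KdTreeGraphConstructor, etc. resolve from it.
sys.path.insert(0, str(charge3net_root))
```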

Commit Details

- Add wandb integration with `--wandb-project`/`--wandb-entity`/`--no-wandb` flags
- Add compute_rmse() and compute_nrmse() validation metrics
- Log per-step train loss and per-epoch train/val metrics to W&B
- Load WANDB_API_KEY from .env via python-dotenv
- Add submit_charge3net.sh for Jean Zay A100 GPU jobs
- Add .gitignore (excludes .env, checkpoints, wandb, etc.)
- save_checkpoint now includes global_step
- load_checkpoint restores model, optimizer, scheduler, epoch, best_nmape, global_step
- SLURM script auto-detects latest.pt and passes --resume-from
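The auto-resume logic in `submit_charge3net.sh` amounts to a checkpoint-presence check before launching training. A dry-run sketch (paths and the final `echo` are illustrative; the real script launches the job instead of printing the command):

```shell
#!/bin/bash
# If a latest.pt checkpoint exists, resume from it; otherwise start fresh.
CKPT_DIR="checkpoints"
RESUME_ARGS=""
if [ -f "${CKPT_DIR}/latest.pt" ]; then
    RESUME_ARGS="--resume-from ${CKPT_DIR}/latest.pt"
fi
CMD="python train.py ${RESUME_ARGS}"
echo "$CMD"
```

Because the check runs at submission time, requeued SLURM jobs pick up from the most recent checkpoint without any manual intervention.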
