Distributed training framework with:
- Baseline single-node trainer
- Parameter Server architecture (sync/async)
- Ring-AllReduce / DDP architecture
- Gradient compression (quantization, top-k)
- Fault-tolerance coordinator with checkpoint/recovery
- Metrics + performance reporting
Requirements:
- Python 3.11+
- (Optional) Docker + Docker Compose

Setup:

```
python -m venv .venv
.\.venv\Scripts\Activate.ps1
pip install -r requirements.txt
```
Run the fast test suite:

```
pytest -q -m "not slow"
```

Project layout:
- `src/config.py`: config schema + validation
- `src/trainer.py`: baseline trainer
- `src/parameter_server.py`: parameter server RPC service
- `src/worker.py`: worker training/heartbeat/metric reporting
- `src/coordinator.py`: lifecycle, failure detection, shard reassignment, recovery
- `src/ddp_trainer.py`: DDP trainer
- `src/compression.py`: quantization + top-k compressors
- `src/metrics.py`: TensorBoard + performance report generation
- `generate_performance_report.py`: consolidated report CLI
- `run_baseline_training.py`: baseline runner
- `run_ddp_training.py`: DDP runner
- `run_parameter_server_training.py`: parameter-server runner
- `docker/`: container entrypoint + Dockerfile
Baseline run:

```
python run_baseline_training.py
```

Outputs:
- TensorBoard logs under `logs/test_baseline`
- Checkpoints under `checkpoints/test_baseline`
Use `run_parameter_server_training.py` for direct parameter-server training on a single machine.

Two-worker CLI run:

```
python run_parameter_server_training.py --workers 2 --dataset mnist --num-epochs 2
```

Four-worker CLI run:

```
python run_parameter_server_training.py --workers 4 --dataset mnist --num-epochs 2
```

Async mode with compression:

```
python run_parameter_server_training.py --workers 2 --dataset mnist --num-epochs 2 --aggregation-mode async --compression-enabled --compression-type topk --compression-ratio 0.1
```

Common flags:
- `--workers`: number of PS workers (default 2)
- `--dataset`: `mnist`, `fashion_mnist`, `cifar10`
- `--num-epochs`, `--batch-size`, `--learning-rate`
- `--aggregation-mode`: `sync` or `async`
- `--compression-enabled` + `--compression-type {quantization,topk}` + `--compression-ratio`
- `--max-train-samples` / `--max-test-samples` for fast smoke runs
- `--checkpoint-dir` / `--log-dir` for output paths
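As a rough illustration of what `--compression-type topk` with a given `--compression-ratio` does, here is a minimal top-k gradient sparsifier in pure Python. This is a sketch of the technique, not the repo's `src/compression.py` implementation:

```python
def topk_compress(grad: list[float], ratio: float) -> tuple[list[int], list[float]]:
    """Keep only the largest-magnitude `ratio` fraction of gradient entries.

    Returns (indices, values); every other entry is treated as zero,
    so only k index/value pairs need to be communicated.
    """
    k = max(1, int(len(grad) * ratio))
    idx = sorted(range(len(grad)), key=lambda i: abs(grad[i]), reverse=True)[:k]
    idx.sort()  # ascending indices for readability
    return idx, [grad[i] for i in idx]

def topk_decompress(indices: list[int], values: list[float], size: int) -> list[float]:
    """Rebuild a dense gradient, zero-filling the dropped entries."""
    dense = [0.0] * size
    for i, v in zip(indices, values):
        dense[i] = v
    return dense

grad = [0.1, -2.0, 0.05, 3.0, -0.2, 0.0, 1.5, -0.01, 0.3, 0.02]
idx, vals = topk_compress(grad, ratio=0.3)  # keep top 3 of 10
print(idx, vals)  # → [1, 3, 6] [-2.0, 3.0, 1.5]
```

The `quantization` compressor trades precision rather than sparsity for bandwidth; the flag names above come from the CLI, while the function names here are hypothetical.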
Example smoke test (fast):

```
python run_parameter_server_training.py --workers 2 --dataset mnist --num-epochs 1 --max-train-samples 512 --max-test-samples 128
```

Four-worker smoke test (fast):

```
python run_parameter_server_training.py --workers 4 --dataset mnist --num-epochs 1 --max-train-samples 512 --max-test-samples 128
```

Expected outputs:
- Final summary in the terminal (`Final Accuracy`, `Final Loss`, `Total Time`, `Throughput`)
- Checkpoint at `checkpoints/ps_cli/parameter_server_final.pt` (unless overridden)
- Worker logs under `logs/ps_cli` (unless overridden)
Containerized run:

```
docker compose up --build
```

Services:
- `parameter_server`
- `worker_0`
- `worker_1`
Single-process MNIST smoke run:

```
python run_ddp_training.py --world-size 1 --dataset mnist --num-epochs 1
```

Two-process MNIST run:

```
python run_ddp_training.py --world-size 2 --dataset mnist --num-epochs 2
```

Synthetic-data debug run:

```
python run_ddp_training.py --dataset synthetic --world-size 2 --num-epochs 1 --num-samples 2048 --optimizer sgd
```

Full final checkpoint validation:

```
python run_final_checkpoint_validation.py --workers 4 --epochs 2 --max-train-samples 5000 --max-test-samples 1000
```

Notes:
- `speedup_ok_property_28` only evaluates to `true` with `--workers 4`.
- The first CIFAR-10 run may download data to `./data`; later runs use the cached files.
The system accepts a YAML/JSON config with:

Training settings:
- `batch_size` (int > 0)
- `learning_rate` (float > 0)
- `num_epochs` (int > 0)
- `dataset` (`mnist`, `fashion_mnist`, `cifar10`)
- `model_architecture` (string, non-empty)
- `checkpoint_interval` (int > 0)

Distributed settings:
- `num_workers` (int > 0)
- `architecture` (`parameter_server`, `ddp`)
- `aggregation_mode` (`sync`, `async`)
- `compression_enabled` (bool)
- `compression_type` (`quantization`, `topk`)
- `compression_ratio` (float in (0, 1])
- `heartbeat_interval` (float > 0)
- `heartbeat_timeout` (float > 0, must be > `heartbeat_interval`)
- `checkpoint_dir` (string path)
- `log_dir` (string path)
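The constraints above can be enforced with straightforward checks. The sketch below is illustrative only and does not mirror `src/config.py` exactly (field defaults here are assumptions):

```python
from dataclasses import dataclass

VALID_ARCHITECTURES = {"parameter_server", "ddp"}
VALID_COMPRESSION = {"quantization", "topk"}

@dataclass
class DistributedConfig:
    """Hypothetical subset of the distributed config; defaults are guesses."""
    num_workers: int = 2
    architecture: str = "parameter_server"
    aggregation_mode: str = "sync"
    compression_enabled: bool = False
    compression_type: str = "quantization"
    compression_ratio: float = 0.1
    heartbeat_interval: float = 1.0
    heartbeat_timeout: float = 5.0

    def validate(self) -> None:
        if self.num_workers <= 0:
            raise ValueError("num_workers must be > 0")
        if self.architecture not in VALID_ARCHITECTURES:
            raise ValueError(f"unknown architecture: {self.architecture}")
        if self.aggregation_mode not in {"sync", "async"}:
            raise ValueError(f"unknown aggregation_mode: {self.aggregation_mode}")
        if self.compression_type not in VALID_COMPRESSION:
            raise ValueError(f"unknown compression_type: {self.compression_type}")
        if not (0.0 < self.compression_ratio <= 1.0):
            raise ValueError("compression_ratio must be in (0, 1]")
        if self.heartbeat_timeout <= self.heartbeat_interval:
            raise ValueError("heartbeat_timeout must exceed heartbeat_interval")

# A timeout shorter than the heartbeat interval is rejected:
cfg = DistributedConfig(heartbeat_interval=1.0, heartbeat_timeout=0.5)
try:
    cfg.validate()
except ValueError as e:
    print(e)  # → heartbeat_timeout must exceed heartbeat_interval
```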
See:
- `configs/example_config.yaml`
- `configs/ps_sync_mnist.yaml`
- `configs/ps_async_quantization.yaml`
- `configs/ps_sync_topk.yaml`
- `configs/ps_sync_mnist_4workers.yaml`
- `configs/ps_async_quantization_4workers.yaml`
- `configs/ddp_mnist.yaml`
Worker and trainer logs include:
- loss
- accuracy
- throughput
- gradient time
- communication time
- samples processed
- compression ratio (when enabled)
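Throughput in these logs is samples processed per second of wall-clock training time. A minimal tracker sketch (illustrative, not the repo's `src/metrics.py`):

```python
class ThroughputTracker:
    """Accumulate samples and elapsed seconds; report samples/sec.
    Hypothetical helper for illustration."""

    def __init__(self) -> None:
        self.samples = 0
        self.seconds = 0.0

    def record_batch(self, batch_size: int, elapsed_s: float) -> None:
        self.samples += batch_size
        self.seconds += elapsed_s

    @property
    def throughput(self) -> float:
        # samples per second; 0.0 before any time has elapsed
        return self.samples / self.seconds if self.seconds > 0 else 0.0

tracker = ThroughputTracker()
tracker.record_batch(64, 0.5)  # one batch of 64 samples in 0.5 s
tracker.record_batch(64, 0.5)
print(tracker.throughput)  # → 128.0 samples/sec
```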
```
python generate_performance_report.py \
  --output reports/perf_report.json \
  --baseline-json artifacts/baseline.json \
  --ps-json artifacts/ps.json \
  --ddp-rank-json artifacts/ddp_rank_0.json artifacts/ddp_rank_1.json \
  --worker-metrics-json artifacts/worker_metrics.json
```

Containerized run:

```
docker compose up --build
```

Mounted volumes:
- `./checkpoints:/app/checkpoints`
- `./logs:/app/logs`
- All services run on the `training_net` network.
- Workers reach the parameter server at `parameter_server:50051`.
Test commands:

```
pytest -q -m "not slow"
pytest -q -m "slow"
```

Troubleshooting:
- TensorBoard/TensorFlow import issues on Windows: the code automatically falls back to a safe no-TF path.
- DDP multiprocessing permission errors: use the `run_ddp_training.py` subprocess launcher.
- gRPC stub generation errors: install `grpcio-tools`, then regenerate the protobuf stubs.
- Slow test timeouts: run specific files first (`pytest tests/test_<file>.py -q`).
- Docker connectivity issues: ensure the compose network is created and workers use `parameter_server:50051`.