abeenoch/DistributedTraining

Distributed Training System

A distributed training framework providing:

  1. Baseline single-node trainer
  2. Parameter Server architecture (sync/async)
  3. Ring-AllReduce / DDP architecture
  4. Gradient compression (quantization, top-k)
  5. Fault-tolerance coordinator with checkpoint/recovery
  6. Metrics + performance reporting

Setup

Prerequisites

  1. Python 3.11+
  2. pip
  3. (Optional) Docker + Docker Compose

Install dependencies

pip install -r requirements.txt

Quick start (Windows PowerShell)

python -m venv .venv
.\.venv\Scripts\Activate.ps1
pip install -r requirements.txt
pytest -q -m "not slow"

Verify environment

pytest -q -m "not slow"

Project Layout

  1. src/config.py: config schema + validation
  2. src/trainer.py: baseline trainer
  3. src/parameter_server.py: parameter server RPC service
  4. src/worker.py: worker training/heartbeat/metric reporting
  5. src/coordinator.py: lifecycle, failure detection, shard reassignment, recovery
  6. src/ddp_trainer.py: DDP trainer
  7. src/compression.py: quantization + top-k compressors
  8. src/metrics.py: TensorBoard + performance report generation
  9. generate_performance_report.py: consolidated report CLI
  10. run_baseline_training.py: baseline runner
  11. run_ddp_training.py: DDP runner
  12. run_parameter_server_training.py: parameter-server runner
  13. docker/: container entrypoint + Dockerfile
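src/compression.py provides quantization and top-k compressors. As an illustration of the top-k idea only (the module's actual API is not documented here, so these function names are hypothetical), a gradient can be reduced to its largest-magnitude entries:

```python
def topk_compress(grad, ratio):
    """Keep only the largest-magnitude fraction `ratio` of gradient entries.

    Returns (indices, values); all other entries are treated as zero.
    Illustrative sketch only -- src/compression.py defines the real API.
    """
    k = max(1, int(len(grad) * ratio))
    # Rank indices by absolute value, descending, and keep the top k.
    idx = sorted(range(len(grad)), key=lambda i: abs(grad[i]), reverse=True)[:k]
    return idx, [grad[i] for i in idx]

def topk_decompress(indices, values, size):
    """Reconstruct a dense gradient with zeros in the dropped positions."""
    dense = [0.0] * size
    for i, v in zip(indices, values):
        dense[i] = v
    return dense
```

With --compression-ratio 0.1, only about 10% of gradient entries (plus their indices) would be transmitted per step under this scheme.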

Usage

Baseline training

python run_baseline_training.py

Outputs:

  1. TensorBoard logs under logs/test_baseline
  2. Checkpoints under checkpoints/test_baseline

Parameter Server architecture

Use run_parameter_server_training.py for direct parameter-server training on a local machine.

Single-machine CLI run:

python run_parameter_server_training.py --workers 2 --dataset mnist --num-epochs 2

Four-worker CLI run:

python run_parameter_server_training.py --workers 4 --dataset mnist --num-epochs 2

Async mode with compression:

python run_parameter_server_training.py --workers 2 --dataset mnist --num-epochs 2 --aggregation-mode async --compression-enabled --compression-type topk --compression-ratio 0.1

Common flags:

  1. --workers: number of PS workers (default 2)
  2. --dataset: mnist, fashion_mnist, cifar10
  3. --num-epochs, --batch-size, --learning-rate
  4. --aggregation-mode: sync or async
  5. --compression-enabled + --compression-type {quantization,topk} + --compression-ratio
  6. --max-train-samples / --max-test-samples for fast smoke runs
  7. --checkpoint-dir / --log-dir for output paths
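In sync mode the server presumably waits for a gradient from every worker and averages them before updating the model, while async mode applies each worker's gradient as it arrives. A minimal sketch of the sync aggregation step, with illustrative function names rather than the project's actual API:

```python
def average_gradients(worker_grads):
    """Element-wise mean of one gradient vector per worker (sync aggregation)."""
    n = len(worker_grads)
    return [sum(g[i] for g in worker_grads) / n
            for i in range(len(worker_grads[0]))]

def sgd_step(params, grad, lr):
    """Apply one SGD update using the aggregated gradient."""
    return [p - lr * g for p, g in zip(params, grad)]
```

Async mode would instead call something like sgd_step(params, grad, lr) once per arriving worker gradient, trading gradient staleness for reduced synchronization stalls.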

Example smoke test (fast):

python run_parameter_server_training.py --workers 2 --dataset mnist --num-epochs 1 --max-train-samples 512 --max-test-samples 128

Four-worker smoke test (fast):

python run_parameter_server_training.py --workers 4 --dataset mnist --num-epochs 1 --max-train-samples 512 --max-test-samples 128

Expected outputs:

  1. Final summary in terminal (Final Accuracy, Final Loss, Total Time, Throughput)
  2. Checkpoint at checkpoints/ps_cli/parameter_server_final.pt (unless overridden)
  3. Worker logs under logs/ps_cli (unless overridden)

Containerized run:

docker compose up --build

Services:

  1. parameter_server
  2. worker_0
  3. worker_1

Ring-AllReduce / DDP

Single-process MNIST smoke run:

python run_ddp_training.py --world-size 1 --dataset mnist --num-epochs 1

Two-process MNIST run:

python run_ddp_training.py --world-size 2 --dataset mnist --num-epochs 2

Synthetic-data debug run:

python run_ddp_training.py --dataset synthetic --world-size 2 --num-epochs 1 --num-samples 2048 --optimizer sgd
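Ring-AllReduce splits each gradient into world_size chunks; world_size - 1 reduce-scatter steps leave each rank holding the complete sum of one chunk, and world_size - 1 all-gather steps then circulate those completed chunks to every rank. A single-process simulation of the ring arithmetic (the actual DDP runner presumably delegates this to a communication backend):

```python
def ring_allreduce(rank_grads):
    """Simulate ring all-reduce so every rank ends with the element-wise sum.

    rank_grads: one equal-length gradient list per rank; the length must be
    divisible by the number of ranks. Single-process illustration only.
    """
    n = len(rank_grads)
    size = len(rank_grads[0])
    assert size % n == 0, "gradient length must be divisible by world size"
    w = size // n
    chunks = [[g[i * w:(i + 1) * w] for i in range(n)] for g in rank_grads]

    # Reduce-scatter: after n-1 steps, rank r holds the full sum of chunk (r+1) % n.
    for t in range(n - 1):
        for r in range(n):
            c = (r - t) % n                  # chunk rank r forwards this step
            dst = (r + 1) % n
            chunks[dst][c] = [a + b for a, b in zip(chunks[dst][c], chunks[r][c])]

    # All-gather: circulate the completed chunks until every rank has all of them.
    for t in range(n - 1):
        for r in range(n):
            c = (r + 1 - t) % n              # completed chunk rank r forwards
            dst = (r + 1) % n
            chunks[dst][c] = list(chunks[r][c])

    return [[x for ch in cs for x in ch] for cs in chunks]
```

Each rank sends and receives only size / n elements per step, so per-rank traffic stays roughly constant as the world size grows, which is the main advantage over naive all-to-all gradient exchange.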

Final Validation

Run full final checkpoint validation:

python run_final_checkpoint_validation.py --workers 4 --epochs 2 --max-train-samples 5000 --max-test-samples 1000

Notes:

  1. speedup_ok_property_28 evaluates to true only when run with --workers 4.
  2. First CIFAR-10 run may download data to ./data; later runs use cached files.

Configuration Options

The system accepts a YAML or JSON config with the following structure:

training

  1. batch_size (int > 0)
  2. learning_rate (float > 0)
  3. num_epochs (int > 0)
  4. dataset (mnist, fashion_mnist, cifar10)
  5. model_architecture (string, non-empty)
  6. checkpoint_interval (int > 0)

system

  1. num_workers (int > 0)
  2. architecture (parameter_server, ddp)
  3. aggregation_mode (sync, async)
  4. compression_enabled (bool)
  5. compression_type (quantization, topk)
  6. compression_ratio (float in (0,1])
  7. heartbeat_interval (float > 0)
  8. heartbeat_timeout (float > 0, must be greater than heartbeat_interval)
  9. checkpoint_dir (string path)
  10. log_dir (string path)
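heartbeat_timeout must exceed heartbeat_interval so that one delayed heartbeat is not immediately treated as a worker failure. A sketch of the detection rule this constraint implies (src/coordinator.py's actual failure detector may differ):

```python
def find_failed_workers(last_heartbeat, now, heartbeat_timeout):
    """Return workers whose most recent heartbeat is older than the timeout.

    last_heartbeat: mapping of worker id -> timestamp (seconds) of the last
    heartbeat received. Illustrative only; names here are hypothetical.
    """
    return sorted(w for w, t in last_heartbeat.items()
                  if now - t > heartbeat_timeout)
```

With heartbeat_interval 1.0 and heartbeat_timeout 5.0, a worker can miss several consecutive heartbeats before the coordinator reassigns its shard.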

See:

  1. configs/example_config.yaml
  2. configs/ps_sync_mnist.yaml
  3. configs/ps_async_quantization.yaml
  4. configs/ps_sync_topk.yaml
  5. configs/ps_sync_mnist_4workers.yaml
  6. configs/ps_async_quantization_4workers.yaml
  7. configs/ddp_mnist.yaml
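Matching the schema above, a config might look like the following (the values, and the model_architecture name, are illustrative; the files listed above are the authoritative examples):

```yaml
training:
  batch_size: 64
  learning_rate: 0.01
  num_epochs: 2
  dataset: mnist
  model_architecture: simple_cnn
  checkpoint_interval: 1

system:
  num_workers: 2
  architecture: parameter_server
  aggregation_mode: sync
  compression_enabled: true
  compression_type: topk
  compression_ratio: 0.1
  heartbeat_interval: 1.0
  heartbeat_timeout: 5.0
  checkpoint_dir: checkpoints/example
  log_dir: logs/example
```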

Metrics and Reporting

TensorBoard logging

Worker and trainer logs include:

  1. loss
  2. accuracy
  3. throughput
  4. gradient time
  5. communication time
  6. samples processed
  7. compression ratio (when enabled)
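Throughput and compression ratio are derived quantities; the definitions below are assumptions about what src/metrics.py records, not its documented behavior:

```python
def throughput(samples_processed, elapsed_seconds):
    """Training throughput in samples per second over a measurement window."""
    return samples_processed / elapsed_seconds

def achieved_compression_ratio(sent_elements, total_elements):
    """Fraction of gradient elements actually transmitted (1.0 = uncompressed).

    Matches the (0, 1] range of the compression_ratio config option.
    """
    return sent_elements / total_elements
```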

Consolidated performance report

python generate_performance_report.py \
  --output reports/perf_report.json \
  --baseline-json artifacts/baseline.json \
  --ps-json artifacts/ps.json \
  --ddp-rank-json artifacts/ddp_rank_0.json artifacts/ddp_rank_1.json \
  --worker-metrics-json artifacts/worker_metrics.json

Docker

Build and run

docker compose up --build

Volumes

  1. ./checkpoints:/app/checkpoints
  2. ./logs:/app/logs

Networking

  1. All services on training_net
  2. Workers reach parameter server at parameter_server:50051

Testing

Fast suite

pytest -q -m "not slow"

Slow integration/property tests

pytest -q -m "slow"

Troubleshooting

  1. TensorBoard/TensorFlow import issues on Windows:
    • the code automatically falls back to a safe path that does not require TensorFlow.
  2. DDP multiprocessing permission errors:
    • use the run_ddp_training.py subprocess launcher.
  3. gRPC stub generation errors:
    • install grpcio-tools, then regenerate the protobuf stubs.
  4. Slow test timeouts:
    • run specific test files first (pytest tests/test_<file>.py -q).
  5. Docker connectivity issues:
    • ensure the compose network is created and that workers use parameter_server:50051.

About

A PyTorch distributed-training playground that implements and compares single-node baseline, Parameter Server (sync/async), and DDP training, with optional gradient compression, coordinator-based fault tolerance, checkpoints/logging, and reproducible validation/reporting scripts.
