-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
234 lines (190 loc) · 7.25 KB
/
train.py
File metadata and controls
234 lines (190 loc) · 7.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
"""
BioLab YOLO Training Script for Backend.AI GPU Cluster
======================================================
This script is designed to run on Backend.AI (SKKU Supercomputing Center).
It trains a YOLOv8 model for pad detection.
Usage:
python train.py --data /home/work/data/data.yaml --epochs 100
Arguments:
--data: Path to data.yaml file (YOLO format)
--epochs: Number of training epochs (default: 100)
--batch: Batch size (default: 16)
--imgsz: Image size (default: 640)
--model: Base model to use (default: yolov8n.pt)
--project: Output directory (default: /home/work/output)
--name: Run name (default: train)
--augment: Enable data augmentation (default: True)
"""
import argparse
import os
import sys
import time
import json
from datetime import datetime
from pathlib import Path
def parse_args():
    """Parse command-line arguments for YOLO training.

    Returns:
        argparse.Namespace with fields: data, epochs, batch, imgsz,
        model, project, name, augment.
    """
    parser = argparse.ArgumentParser(description="BioLab YOLO Training")
    # Required
    parser.add_argument("--data", type=str, required=True,
                        help="Path to data.yaml file")
    # Training parameters
    parser.add_argument("--epochs", type=int, default=100,
                        help="Number of epochs")
    parser.add_argument("--batch", type=int, default=16,
                        help="Batch size")
    parser.add_argument("--imgsz", type=int, default=640,
                        help="Image size")
    parser.add_argument("--model", type=str, default="yolov8n.pt",
                        help="Base model (yolov8n.pt, yolov8s.pt, etc.)")
    # Output
    parser.add_argument("--project", type=str, default="/home/work/output",
                        help="Output directory")
    parser.add_argument("--name", type=str, default="train",
                        help="Run name")
    # Augmentation flags.
    # BUGFIX: the original used `type=bool`, which is broken in argparse —
    # bool() of any non-empty string is True, so `--augment False` silently
    # ENABLED augmentation. A store_true/store_false pair keeps the same
    # default (True) and the same `--no-augment` off-switch, and makes
    # `--augment` an explicit on-switch instead of a value-taking option.
    parser.add_argument("--augment", dest="augment", action="store_true",
                        default=True,
                        help="Enable data augmentation (default)")
    parser.add_argument("--no-augment", dest="augment", action="store_false",
                        help="Disable data augmentation")
    return parser.parse_args()
def print_system_info():
    """Print a startup banner with Python, PyTorch and GPU details.

    Exits with status 1 when PyTorch cannot be imported, since training
    cannot proceed without it.
    """
    banner = "=" * 60
    print(banner)
    print("BioLab YOLO Training Script")
    print(banner)
    print(f"Start Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Python: {sys.version}")
    # PyTorch is a hard requirement; bail out early if it is missing.
    try:
        import torch
    except ImportError:
        print("PyTorch not installed!")
        sys.exit(1)
    print(f"PyTorch: {torch.__version__}")
    cuda_ok = torch.cuda.is_available()
    print(f"CUDA Available: {cuda_ok}")
    if cuda_ok:
        props = torch.cuda.get_device_properties(0)
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {props.total_memory / 1e9:.1f} GB")
    print(banner)
def validate_data_yaml(data_path: str) -> bool:
    """Validate that *data_path* is an existing YOLO-format data.yaml.

    Checks for the required keys (path, train, val, names, nc), prints a
    short dataset summary on success or an ERROR line on failure.

    Args:
        data_path: Filesystem path to the data.yaml file.

    Returns:
        True when the file exists and contains all required fields.
    """
    import yaml
    if not os.path.exists(data_path):
        print(f"ERROR: data.yaml not found at {data_path}")
        return False
    with open(data_path, 'r') as f:
        data = yaml.safe_load(f)
    # ROBUSTNESS: safe_load returns None for an empty file and may return
    # a scalar for malformed YAML; without this guard, `field not in data`
    # below would raise TypeError instead of reporting a clean error.
    if not isinstance(data, dict):
        print(f"ERROR: data.yaml is empty or malformed at {data_path}")
        return False
    required_fields = ['path', 'train', 'val', 'names', 'nc']
    for field in required_fields:
        if field not in data:
            print(f"ERROR: Missing required field '{field}' in data.yaml")
            return False
    print(f"Dataset path: {data['path']}")
    print(f"Train: {data['train']}")
    print(f"Val: {data['val']}")
    print(f"Classes: {data['nc']} - {data['names']}")
    return True
def train(args):
    """Run YOLO training and save a JSON summary next to the weights.

    Args:
        args: Parsed CLI namespace from parse_args() — uses data, epochs,
            batch, imgsz, model, project, name and augment.

    Returns:
        The ultralytics results object produced by model.train().
    """
    from ultralytics import YOLO
    print("\n[1/3] Loading model...")
    model = YOLO(args.model)
    print(f"\n[2/3] Starting training...")
    print(f"  - Epochs: {args.epochs}")
    print(f"  - Batch size: {args.batch}")
    print(f"  - Image size: {args.imgsz}")
    print(f"  - Augmentation: {args.augment}")
    # Training arguments
    train_args = {
        "data": args.data,
        "epochs": args.epochs,
        "imgsz": args.imgsz,
        "batch": args.batch,
        "project": args.project,
        "name": args.name,
        "exist_ok": True,
        "device": 0,      # Use first GPU
        "verbose": True,
        "save": True,
        "amp": True,      # Mixed precision for speed
    }
    # Augmentation settings (optimized for colorchip/pad detection):
    # geometry is kept fixed because pads are axis-aligned, while color
    # jitter is kept because color is the signal being detected.
    if args.augment:
        train_args.update({
            "degrees": 0.0,      # No rotation (pads are axis-aligned)
            "translate": 0.0,    # No translation
            "scale": 0.05,       # Slight scale variation
            "shear": 0.0,        # No shear
            "perspective": 0.0,  # No perspective
            "hsv_h": 0.01,       # Slight hue variation
            "hsv_s": 0.3,        # Saturation variation (important for color)
            "hsv_v": 0.4,        # Value/brightness variation
            "flipud": 0.0,       # No vertical flip
            "fliplr": 0.5,       # Horizontal flip OK
            "mosaic": 0.0,       # No mosaic (preserve pad structure)
        })
    else:
        # Explicitly zero everything so ultralytics defaults don't leak in.
        train_args.update({
            "hsv_h": 0.0, "hsv_s": 0.0, "hsv_v": 0.0,
            "degrees": 0.0, "translate": 0.0, "scale": 0.0,
            "shear": 0.0, "perspective": 0.0,
            "flipud": 0.0, "fliplr": 0.0, "mosaic": 0.0
        })
    # Train
    start_time = time.time()
    results = model.train(**train_args)
    duration = time.time() - start_time
    print(f"\n[3/3] Training completed!")
    print(f"  - Duration: {duration/60:.1f} minutes")
    # Check results
    output_dir = Path(args.project) / args.name
    best_pt = output_dir / "weights" / "best.pt"
    if best_pt.exists():
        print(f"  - Best model: {best_pt}")
        print(f"  - Model size: {best_pt.stat().st_size / 1e6:.1f} MB")
    # Extract metrics — best effort only; the results_dict layout has
    # varied between ultralytics versions, so failures are ignored.
    # BUGFIX: was a bare `except:`, which also swallowed KeyboardInterrupt
    # and SystemExit; narrowed to what this lookup can actually raise.
    try:
        val_map = results.results_dict.get("metrics/mAP50-95(B)", 0)
        val_map50 = results.results_dict.get("metrics/mAP50(B)", 0)
        print(f"  - Val mAP50-95: {float(val_map):.4f}")
        print(f"  - Val mAP50: {float(val_map50):.4f}")
    except (AttributeError, KeyError, TypeError, ValueError):
        pass
    # Save training summary
    summary = {
        "model": args.model,
        "epochs": args.epochs,
        "batch_size": args.batch,
        "image_size": args.imgsz,
        "augmentation": args.augment,
        "duration_minutes": round(duration / 60, 1),
        "best_model_path": str(best_pt),
        "completed_at": datetime.now().isoformat(),
    }
    summary_path = output_dir / "training_summary.json"
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)
    print(f"  - Summary saved: {summary_path}")
    print("\n" + "=" * 60)
    print("Training Complete!")
    print("=" * 60)
    return results
def main():
    """CLI entry point: show system info, validate the dataset, train.

    Exits with status 1 on a missing/invalid data.yaml or any training
    failure; the failure type and message are printed first.
    """
    args = parse_args()
    # Print system info
    print_system_info()
    # Validate data.yaml before touching the GPU.
    print("\nValidating dataset...")
    if not validate_data_yaml(args.data):
        sys.exit(1)
    # Run training. Broad catch is deliberate at this top-level boundary:
    # report the error and map it to a nonzero exit code for the scheduler.
    try:
        # IDIOM: return value was previously bound to an unused local.
        train(args)
    except Exception as e:
        print(f"\nERROR: Training failed!")
        print(f"  {type(e).__name__}: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()