-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdata.py
More file actions
executable file
·53 lines (40 loc) · 1.93 KB
/
data.py
File metadata and controls
executable file
·53 lines (40 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import colorful
from torch.utils.data import ConcatDataset, DataLoader
from dataloaders import DATASET
# =========
# Scheduler
# =========
class DataScheduler:
    """Builds the task dataset(s) from config and serves per-task DataLoaders.

    Iterating a DataScheduler yields one ``(dataloader, epoch, task_index)``
    tuple per registered task; ``len(scheduler)`` is the total number of
    training steps across all tasks.
    """

    def __init__(self, config, writer):
        """Load the configured training dataset and register it as one task.

        Args:
            config: dict-like with at least 'data_name', 'batch_size',
                and 'num_workers'.
            writer: summary writer, passed through to the dataset constructor.
        """
        self.config = config
        self.datasets = {}
        self.total_step = 0  # total optimizer steps over all tasks

        # Prepare datasets: a single noisy training split of the configured dataset.
        dataset_name = config['data_name']
        self.datasets[dataset_name] = DATASET[dataset_name](self.config, writer, split='train', noise=True)
        self.total_step += len(self.datasets[dataset_name]) // self.config['batch_size']

        # Each entry is (num_epochs, ConcatDataset); only one task is built here.
        self.task_datasets = []
        dataset = ConcatDataset([self.datasets[dataset_name]])
        self.task_datasets.append((1, dataset))

    def __iter__(self):
        """Yield ``(dataloader, epoch, task_index)`` for each registered task."""
        for t_i, (epoch, task) in enumerate(self.task_datasets):
            print(colorful.bold_green('\nProgress to Task %d' % t_i).styled_string)
            # NOTE(review): assumes the dataset wrapped by the ConcatDataset
            # itself exposes an inner `.dataset` with a `collate_fn` -- confirm
            # against the DATASET implementations.
            collate_fn = task.datasets[0].dataset.collate_fn
            task_loader = DataLoader(task, batch_size=self.config['batch_size'],
                                     num_workers=self.config['num_workers'],
                                     collate_fn=collate_fn,
                                     drop_last=False,
                                     pin_memory=True,  # better when training on GPU.
                                     shuffle=True)
            yield task_loader, epoch, t_i

    def get_task(self, t):
        """Return the (Concat)Dataset registered for task index ``t``."""
        return self.task_datasets[t][1]

    def get_dataloader(self, dataset):
        """Build a shuffled DataLoader for ``dataset``.

        NOTE(review): the chained access below assumes ``dataset`` is a
        Subset-like wrapper around a ConcatDataset of wrapped datasets --
        verify against the callers before relying on it.
        """
        collate_fn = dataset.dataset.datasets[0].dataset.collate_fn
        return DataLoader(dataset, self.config['batch_size'], shuffle=True, collate_fn=collate_fn)

    def __len__(self):
        """Total number of training steps across all tasks."""
        return self.total_step