This repository has been archived by the owner on Mar 30, 2019. It is now read-only.
forked from samarthbhargav/hackathon4good
-
Notifications
You must be signed in to change notification settings - Fork 1
/
data.py
54 lines (40 loc) · 1.83 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
class CaladriusDataset(Dataset):
    """Dataset of before/after image pairs labelled with a damage value.

    ``directory`` must contain a ``labels.txt`` file whose non-blank lines
    have the form ``<filename> <damage>``, together with ``before`` and
    ``after`` sub-directories that each hold an image named ``<filename>``.
    """

    def __init__(self,
                 directory,
                 split='train',
                 inputSize=(32, 32),
                 transforms=None):
        """Read the label index stored in ``directory``.

        Args:
            directory: Folder holding ``labels.txt``, ``before/`` and
                ``after/``.
            split: Accepted for interface compatibility; not used here.
            inputSize: Accepted for interface compatibility; not used here.
            transforms: Optional callable applied to both images of a pair.
        """
        self.directory = directory
        with open(os.path.join(directory, 'labels.txt')) as labels_file:
            stripped = (x.strip() for x in tqdm(labels_file.readlines()))
            # Blank lines (e.g. a trailing empty line) cannot be parsed into
            # a (filename, damage) pair, so drop them instead of failing
            # later inside __getitem__.
            self.datapoints = [line for line in stripped if line]
        self.transforms = transforms

    def __len__(self):
        """Return the number of labelled datapoints."""
        return len(self.datapoints)

    def __getitem__(self, idx):
        """Return ``(filename, before_image, after_image, damage)``.

        When ``transforms`` was supplied (and is truthy) it is applied to
        the before image and to the after image independently.
        """
        filename, before_image, after_image, damage = self.loadDatapoint(idx)
        if self.transforms:
            before_image = self.transforms(before_image)
            after_image = self.transforms(after_image)
        return (filename, before_image, after_image, damage)

    @staticmethod
    def _parseLine(line):
        """Split one label line into ``(filename, damage)``.

        The damage value is the last whitespace-separated field; everything
        before it is the filename. Splitting from the right on any run of
        whitespace tolerates tabs, repeated spaces and filenames that
        contain spaces.

        Raises:
            ValueError: If the line has fewer than two fields or the damage
                field is not a valid float.
        """
        filename, damage = line.rsplit(None, 1)
        return filename, float(damage)

    def _openImage(self, subdirectory, filename):
        """Open ``<directory>/<subdirectory>/<filename>`` and read its pixels."""
        image = Image.open(os.path.join(self.directory, subdirectory, filename))
        # Image.open is lazy and keeps the file open until the pixel data is
        # read. Loading right away lets Pillow release the handle (for
        # single-frame formats) instead of holding one per returned image.
        image.load()
        return image

    def loadDatapoint(self, idx):
        """Load the untransformed datapoint stored at position ``idx``.

        Returns:
            A ``(filename, before_image, after_image, damage)`` tuple where
            both images are PIL images and ``damage`` is a float.
        """
        filename, damage = self._parseLine(self.datapoints[idx])
        before_image = self._openImage('before', filename)
        after_image = self._openImage('after', filename)
        return filename, before_image, after_image, damage
class Datasets(object):
    """Builds a ``CaladriusDataset`` and matching ``DataLoader`` per split."""

    # Names accepted by load(); each is also the sub-directory of dataPath
    # that holds the split and the key used to look up its transform.
    SET_NAMES = ('train', 'validation', 'test')

    def __init__(self, args, transforms):
        """Store the loader configuration.

        Args:
            args: Object exposing ``dataPath``, ``batchSize`` and
                ``numberOfWorkers`` attributes.
            transforms: Container indexed by set name that yields the
                transform to use for that split.
        """
        self.args = args
        self.dataPath = args.dataPath
        self.batchSize = args.batchSize
        self.transforms = transforms
        self.numberOfWorkers = args.numberOfWorkers

    def load(self, set_name):
        """Create the dataset and data loader for ``set_name``.

        Args:
            set_name: One of ``'train'``, ``'validation'`` or ``'test'``.

        Returns:
            A ``(dataset, dataLoader)`` tuple. Shuffling is enabled only for
            the ``'train'`` split.

        Raises:
            ValueError: If ``set_name`` is not a recognised split name.
        """
        # Explicit check rather than ``assert`` so that validation still
        # happens when Python runs with -O (which strips assert statements).
        if set_name not in self.SET_NAMES:
            raise ValueError(
                'set_name must be one of {}, got {!r}'.format(
                    self.SET_NAMES, set_name))

        dataset = CaladriusDataset(
            os.path.join(self.dataPath, set_name),
            transforms=self.transforms[set_name])
        dataLoader = DataLoader(
            dataset,
            batch_size=self.batchSize,
            shuffle=(set_name == 'train'),
            num_workers=self.numberOfWorkers)
        return dataset, dataLoader