-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathcustom_dataset.py
50 lines (38 loc) · 1.62 KB
/
custom_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from torch.utils.data.dataset import Dataset
import os
from PIL import Image
class MultiViewDataSet(Dataset):
    """Dataset whose samples are groups of view images stored on disk.

    Expected directory layout::

        root / <label> / <data_type> / <item> / <view image file>

    Every ``<item>`` directory is one sample. ``self.x[i]`` holds the list
    of view file paths of sample ``i`` and ``self.y[i]`` the integer class
    index of its ``<label>`` directory.
    """

    def find_classes(self, dir):
        """Return the class names found under *dir* and their indices.

        Only sub-directories of *dir* count as classes; plain files are
        ignored.

        Args:
            dir: Directory whose immediate sub-directories are the classes.
                (The name shadows the builtin but is kept so existing
                keyword callers keep working.)

        Returns:
            A tuple ``(classes, class_to_idx)``: the alphabetically sorted
            list of class names and a dict mapping each name to its position
            in that list.
        """
        classes = sorted(
            d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))
        )
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        return classes, class_to_idx

    def __init__(self, root, data_type, transform=None, target_transform=None):
        """Index every item of the requested split under *root*.

        Args:
            root: Top-level dataset directory (one sub-directory per label).
            data_type: Name of the split directory inside each label
                directory (the layout comment in the original code gives
                ``train``/``test`` as examples).
            transform: Optional callable applied to every loaded view image.
            target_transform: Optional callable applied to the class index.

        Raises:
            OSError: If a label directory has no ``data_type`` sub-directory
                (raised by ``os.listdir``).
        """
        super().__init__()
        self.x = []
        self.y = []
        self.root = root
        self.classes, self.class_to_idx = self.find_classes(root)
        self.transform = transform
        self.target_transform = target_transform

        # root / <label> / <train/test> / <item> / <view>.png
        # Walk the filtered, sorted class list instead of the raw directory
        # listing: stray files in ``root`` are not classes, and sorting keeps
        # the sample order reproducible across file systems.
        for label in self.classes:
            split_dir = os.path.join(root, label, data_type)
            for item in sorted(os.listdir(split_dir)):
                item_dir = os.path.join(split_dir, item)
                if not os.path.isdir(item_dir):
                    # A plain file next to the item directories is not a sample.
                    continue
                # Sorted so the view order of a sample is deterministic.
                views = [
                    os.path.join(item_dir, view)
                    for view in sorted(os.listdir(item_dir))
                ]
                self.x.append(views)
                self.y.append(self.class_to_idx[label])

    def __getitem__(self, index):
        """Load all views of sample *index*.

        Override to give PyTorch access to any image on the dataset.

        Args:
            index: Position of the sample in ``self.x`` / ``self.y``.

        Returns:
            A tuple ``(views, target)``. ``views`` is a list with one entry
            per view file: the image converted to RGB and passed through
            ``transform`` when one is set. ``target`` is the class index,
            passed through ``target_transform`` when one is set.
        """
        views = []
        for path in self.x[index]:
            # ``convert`` returns a new in-memory image, so the file handle
            # opened by ``Image.open`` can be released right away.
            with Image.open(path) as raw:
                im = raw.convert('RGB')
            if self.transform is not None:
                im = self.transform(im)
            views.append(im)

        target = self.y[index]
        if self.target_transform is not None:
            target = self.target_transform(target)
        return views, target

    def __len__(self):
        """Return the number of samples.

        Override to give PyTorch size of dataset.
        """
        return len(self.x)