-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathabstract_dataset.py
41 lines (35 loc) · 1.35 KB
/
abstract_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import cv2
import torch
import numpy as np
from torchvision.datasets import VisionDataset
import albumentations
from albumentations import Compose
from albumentations.pytorch.transforms import ToTensorV2
class AbstractDataset(VisionDataset):
    """Base dataset holding (image path, target) pairs, with on-demand
    image loading and albumentations-based augmentation.

    Subclasses are expected to populate ``self.images`` (file paths) and
    ``self.targets`` (labels) in their own ``__init__``.

    Args:
        cfg: dict-like config; must provide ``'root'`` and ``'split'``, and —
            when no ``transforms`` callable is passed — a ``'transforms'``
            list of ``{'name': ..., 'params': ...}`` albumentations specs.
        seed: seed for the global NumPy RNG, for reproducibility.
        transforms, transform, target_transform: standard VisionDataset hooks.
    """

    def __init__(self, cfg, seed=2022, transforms=None, transform=None, target_transform=None):
        super().__init__(cfg['root'], transforms=transforms,
                         transform=transform, target_transform=target_transform)
        # Fix the global NumPy RNG so augmentation/sampling is reproducible.
        np.random.seed(seed)
        self.images = list()   # image file paths; filled by subclasses
        self.targets = list()  # targets aligned index-for-index with self.images
        self.split = cfg['split']
        if self.transforms is None:
            # Build the augmentation pipeline from the config spec, ending with
            # ToTensorV2 so load_item can torch.stack the results.
            self.transforms = Compose(
                [getattr(albumentations, spec['name'])(**spec['params'])
                 for spec in cfg['transforms']] +
                [ToTensorV2()]
            )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Return the raw (path, target) pair; image decoding is deferred
        # to load_item so collation can batch paths cheaply.
        path = self.images[index]
        tgt = self.targets[index]
        return path, tgt

    def load_item(self, items):
        """Read, augment, and batch the images at the given paths.

        Args:
            items: iterable of image file paths.

        Returns:
            torch.Tensor of shape (len(items), C, H, W).

        Raises:
            FileNotFoundError: if any path cannot be decoded by OpenCV.
        """
        images = list()
        for item in items:
            img = cv2.imread(item)
            if img is None:
                # cv2.imread silently returns None for missing/corrupt files,
                # which would otherwise surface as a cryptic cvtColor error.
                raise FileNotFoundError(f"Failed to read image: {item}")
            # OpenCV decodes BGR; the transform pipeline expects RGB.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            image = self.transforms(image=img)['image']
            images.append(image)
        return torch.stack(images, dim=0)