-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
37 lines (32 loc) · 1.32 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from torch.utils.data import Dataset
import os
from path import Path
from .utils import *
class PointCloudData(Dataset):
    """Dataset over ``.off`` files stored as ``root_dir/<class>/<folder>/*.off``.

    Every sub-directory of ``root_dir`` is treated as one class; class names
    are sorted and mapped to consecutive integer labels starting at 0.
    """

    def __init__(self, root_dir, valid=False, folder="train", transforms=None):
        """Index every ``.off`` file of the requested split.

        Args:
            root_dir: Directory containing one sub-directory per class. May be
                a ``Path`` or a plain string.
            valid: Stored on the instance as ``self.valid``; not otherwise
                used by this class.
            folder: Name of the split sub-directory inside each class
                directory (for example ``"train"``).
            transforms: Optional callable applied to the ``(verts, faces)``
                tuple of each sample.

        Raises:
            OSError: If ``root_dir`` or a class's ``folder`` sub-directory
                cannot be listed.
        """
        # Normalise so that the ``/`` joins below also work for plain strings.
        root_dir = Path(root_dir)
        self.root_dir = root_dir
        class_names = [
            name
            for name in sorted(os.listdir(root_dir))
            if os.path.isdir(root_dir / name)
        ]
        self.classes = {name: label for label, name in enumerate(class_names)}
        self.transforms = transforms
        self.valid = valid
        self.files = []
        for category in self.classes:
            split_dir = root_dir / category / folder
            # os.listdir order is arbitrary; sort so that sample indices are
            # reproducible across runs and filesystems.
            for file_name in sorted(os.listdir(split_dir)):
                if file_name.endswith('.off'):
                    sample = {'pcd_path': split_dir / file_name, 'category': category}
                    self.files.append(sample)

    def __len__(self):
        """Return the number of indexed ``.off`` files."""
        return len(self.files)

    def __preproc__(self, file):
        """Parse an open ``.off`` file and apply the configured transforms.

        Args:
            file: Open file object handed to ``read_off``.

        Returns:
            The result of ``self.transforms((verts, faces))`` when transforms
            are configured, otherwise the untransformed ``(verts, faces)``
            tuple.
        """
        verts, faces = read_off(file)
        if not self.transforms:
            # Without this branch the method used to fail with
            # UnboundLocalError whenever no transform was supplied.
            return verts, faces
        return self.transforms((verts, faces))

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict with ``'pointcloud'`` (output of ``__preproc__``) and
            ``'category'`` (integer class label).
        """
        entry = self.files[idx]
        with open(entry['pcd_path'], 'r') as f:
            pointcloud = self.__preproc__(f)
        return {'pointcloud': pointcloud,
                'category': self.classes[entry['category']]}