-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.py
107 lines (85 loc) · 3.55 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from sklearn.model_selection import train_test_split
import os
import cv2 as cv
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from loguru import logger
class VideoDataset(torch.utils.data.Dataset):
'''
Custom Dataset for loading videos and their class labels
'''
def __init__(self, data_dir, num_classes=10, num_frames=20, transform=None, target_transform=None):
super().__init__()
self.data_dir = data_dir
self.transform = transform
self.target_transform = target_transform
self.num_classes = num_classes
self.num_frames = num_frames
self.video_filename_list = []
self.classesIdx_list = []
self.class_dict = {class_label: idx for idx, class_label in enumerate(
sorted(os.listdir(self.data_dir)))}
for class_label, class_idx in self.class_dict.items():
class_dir = os.path.join(self.data_dir, class_label)
for video_filename in sorted(os.listdir(class_dir)):
self.video_filename_list.append(
os.path.join(class_label, video_filename))
self.classesIdx_list.append(class_idx)
def __len__(self):
return len(self.video_filename_list)
def read_video(self, video_path):
frames = []
cap = cv.VideoCapture(video_path)
count_frames = 0
while True:
ret, frame = cap.read()
if ret:
if self.transform:
transformed = self.transform(image=frame)
frame = transformed['image']
frames.append(frame)
count_frames += 1
else:
break
stride = count_frames // self.num_frames
new_frames = []
count = 0
for i in range(0, count_frames, stride):
if count >= self.num_frames:
break
new_frames.append(frames[i])
count += 1
cap.release()
return torch.stack(new_frames, dim=0)
def __getitem__(self, idx):
classIdx = self.classesIdx_list[idx]
video_filename = self.video_filename_list[idx]
video_path = os.path.join(self.data_dir, video_filename)
frames = self.read_video(video_path)
return frames, classIdx
if __name__ == '__main__':
num_classes = 10
batch_size = 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_frames = 20 # You can adjust this to balance speed and accuracy
img_size = (128, 128) # You can adjust this to balance speed and accuracy
num_workers = 0
transform = A.Compose(
[
A.Resize(height=img_size[0], width=img_size[1]),
A.Normalize(),
ToTensorV2()
]
)
logger.info('Loading dataset')
full_dataset = VideoDataset(data_dir="data", num_frames=num_frames, num_classes=num_classes, transform=transform)
train_dataset, test_dataset = train_test_split(full_dataset, test_size=0.2, random_state=42)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
logger.info('Dataset loaded')
for batch_idx, (data, target) in enumerate(train_loader):
print(data.shape)
print(target.shape)
break
# Hint: refer to https://www.kaggle.com/code/nguyenmanhcuongg/pytorch-video-classification-with-conv2d-lstm to implement your model and other functions