-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathdata_loader.py
102 lines (82 loc) · 3.08 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 10 14:12:10 2019
@author: chxy
"""
import numpy as np
import torch
from torchvision import datasets
from torchvision import transforms
def get_train_loader(data_dir,
                     batch_size,
                     random_seed,
                     shuffle=True,
                     num_workers=4,
                     pin_memory=True):
    """
    Build a multi-process training iterator over the CIFAR-100 train split.

    Every sample is augmented with a padded random crop, a random horizontal
    flip and a random rotation, then converted to a tensor and normalized.

    Args
    ----
    - data_dir: root directory of the CIFAR-100 data. The dataset is opened
      with ``download=False``, so it must already be present there.
    - batch_size: how many samples per batch to load.
    - random_seed: seed for the loader's own torch RNG generator, which drives
      the shuffle order and the base seed handed to worker processes. Pass
      ``None`` to fall back to torch's global RNG (non-reproducible).
    - shuffle: whether to reshuffle the training set every epoch.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - train_loader: train set iterator (a ``torch.utils.data.DataLoader``).
    """
    # define transforms
    trans = transforms.Compose([
        # Pad 4 pixels on every side, then take a random 32x32 crop.
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.RandomRotation(degrees=15),  # random rotation in [-15, 15] degrees
        transforms.ToTensor(),  # convert the image to a float tensor
        # NOTE(review): these are the ImageNet channel mean/std, not CIFAR-100's
        # own statistics; kept unchanged so existing trained models still match.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # load dataset
    dataset = datasets.CIFAR100(root=data_dir,
                                transform=trans,
                                download=False,
                                train=True)

    # DataLoader takes its shuffle permutation and worker seeds from torch's
    # RNG (or from `generator=`), never from numpy, so seeding numpy alone
    # left the shuffle non-reproducible. A dedicated seeded generator fixes it.
    generator = None
    if random_seed is not None:
        generator = torch.Generator()
        generator.manual_seed(random_seed)

    if shuffle:
        # Kept for backward compatibility: callers may rely on the global numpy
        # seed having been set as a side effect of building the loader.
        np.random.seed(random_seed)

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        generator=generator,
    )
    return train_loader
def get_test_loader(data_dir,
                    batch_size,
                    num_workers=4,
                    pin_memory=True):
    """
    Build a multi-process evaluation iterator over the CIFAR-100 test split.

    No augmentation is applied: each image is only converted to a tensor and
    normalized. Batches come out in a fixed order (``shuffle=False``).

    Args
    ----
    - data_dir: root directory of the CIFAR-100 data. The dataset is opened
      with ``download=False``, so it must already be present there.
    - batch_size: how many samples per batch to load.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - data_loader: test set iterator (a ``torch.utils.data.DataLoader``).
    """
    to_tensor = transforms.ToTensor()  # convert the image to a float tensor
    # Same (ImageNet) channel statistics as the training pipeline.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    eval_pipeline = transforms.Compose([to_tensor, normalize])

    test_set = datasets.CIFAR100(
        root=data_dir,
        train=False,
        download=False,
        transform=eval_pipeline,
    )

    return torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )