-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathread_fashion_mnist_win.py
88 lines (71 loc) · 3.15 KB
/
read_fashion_mnist_win.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gzip
import os
import shutil
import urllib.request

import numpy as np
import torch
from mlxtend.data import loadlocal_mnist
from sklearn.model_selection import train_test_split
class FashionMnistLoader:
    """Download, unpack and load the four Fashion-MNIST archive files.

    The gzip archives named by the ``url_*`` class attributes are fetched
    into ``dir_name`` and decompressed next to the archive. The resulting
    local paths are stored on the instance and later handed to
    ``loadlocal_mnist``.
    """

    # Directory (relative to the working directory) holding the downloads.
    dir_name = "data/fashion"
    url_train_imgs = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz"
    url_train_labels = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz"
    url_test_imgs = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz"
    url_test_labels = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz"

    def __init__(self):
        """Create a loader whose file paths are not resolved yet."""
        # Paths of the decompressed files; filled in by get_all_data() or
        # lazily by load_train() / load_test().
        self.train_imgs_fn = None
        self.train_labels_fn = None
        self.test_imgs_fn = None
        self.test_labels_fn = None

    def get_data(self, url):
        """Fetch one gzip archive and return the path of its unpacked copy.

        The archive name is the last ``/``-separated component of *url*;
        the unpacked file name is everything before the first ``.`` in it.
        If the unpacked file already exists, nothing is downloaded.

        Args:
            url: Location of the ``.gz`` archive to retrieve.

        Returns:
            Path of the decompressed file inside ``dir_name``.

        Raises:
            OSError: If the download fails or the archive cannot be
                decompressed. No unpacked file is left behind in that case.
        """
        gz_file_name = url.split("/")[-1]
        gz_file_path = os.path.join(self.dir_name, gz_file_name)
        file_name = gz_file_name.split(".")[0]
        file_path = os.path.join(self.dir_name, file_name)
        os.makedirs(self.dir_name, exist_ok=True)
        if os.path.exists(file_path):
            return file_path

        urllib.request.urlretrieve(url, gz_file_path)

        # Decompress into a sibling temp file and rename at the end, so a
        # corrupt or interrupted archive never leaves a truncated file that
        # the existence check above would later accept as complete.
        part_path = file_path + ".part"
        try:
            with gzip.open(gz_file_path, "rb") as src, open(part_path, "wb") as dst:
                shutil.copyfileobj(src, dst)
            os.replace(part_path, file_path)
        finally:
            if os.path.exists(part_path):
                os.remove(part_path)
        return file_path

    def get_all_data(self):
        """Fetch all four archives and record their unpacked paths.

        Returns:
            This loader, so calls can be chained.
        """
        self.train_imgs_fn = self.get_data(self.url_train_imgs)
        self.train_labels_fn = self.get_data(self.url_train_labels)
        self.test_imgs_fn = self.get_data(self.url_test_imgs)
        self.test_labels_fn = self.get_data(self.url_test_labels)
        return self

    def load_train(self):
        """Load the training images and labels via ``loadlocal_mnist``.

        Files that have not been fetched yet are fetched first, so this
        works even when ``get_all_data`` was not called beforehand.

        Returns:
            The ``(X, y)`` pair returned by ``loadlocal_mnist``.
        """
        if self.train_imgs_fn is None:
            self.train_imgs_fn = self.get_data(self.url_train_imgs)
        if self.train_labels_fn is None:
            self.train_labels_fn = self.get_data(self.url_train_labels)
        X, y = loadlocal_mnist(
            images_path=self.train_imgs_fn,
            labels_path=self.train_labels_fn)
        return X, y

    def load_test(self):
        """Load the test images and labels via ``loadlocal_mnist``.

        Files that have not been fetched yet are fetched first, so this
        works even when ``get_all_data`` was not called beforehand.

        Returns:
            The ``(X, y)`` pair returned by ``loadlocal_mnist``.
        """
        if self.test_imgs_fn is None:
            self.test_imgs_fn = self.get_data(self.url_test_imgs)
        if self.test_labels_fn is None:
            self.test_labels_fn = self.get_data(self.url_test_labels)
        X, y = loadlocal_mnist(
            images_path=self.test_imgs_fn,
            labels_path=self.test_labels_fn)
        return X, y

    def _split(self, X, y, test_size):
        """Split *X* and *y* with ``train_test_split`` and a fixed seed.

        Args:
            X: Samples to split.
            y: Labels matching *X*.
            test_size: Forwarded unchanged to ``train_test_split``.

        Returns:
            ``(X_train, X_test, y_train, y_test)``.
        """
        # The fixed random_state makes the split reproducible across runs.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=666)
        return X_train, X_test, y_train, y_test

    def train_split(self, test_size):
        """Split the training set itself into train and held-out parts.

        Args:
            test_size: Forwarded unchanged to ``train_test_split``.

        Returns:
            ``(X_train, X_test, y_train, y_test)`` drawn from the training
            files only; the test files are not read.
        """
        X, y = self.load_train()
        X_train, X_test, y_train, y_test = self._split(X, y, test_size)
        return X_train, X_test, y_train, y_test

    def standard_split(self):
        """Return the train and test sets exactly as stored in the files.

        Returns:
            ``(X_train, X_test, y_train, y_test)``.
        """
        X_train, y_train = self.load_train()
        X_test, y_test = self.load_test()
        return X_train, X_test, y_train, y_test
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class FashionMnist(Dataset):
    """Torch ``Dataset`` serving 28x28 images with integer targets."""

    def __init__(self, X, y, transform=None):
        """Convert the NumPy inputs to tensors.

        Args:
            X: NumPy array holding 28 * 28 values per sample (presumably
                raw pixel values in 0..255; verify against the caller).
                It is cast to float and divided by 255.
            y: NumPy array of targets, one per sample; cast to ``long``.
            transform: Optional callable applied to each image on access.
        """
        # Reshape straight to (N, 28, 28). The earlier
        # reshape(-1, 1, 28, 28).squeeze() gave the same result for N > 1
        # but also dropped the batch axis when N == 1, so indexing returned
        # a single image row instead of an image.
        self.data = (torch.from_numpy(X).float() / 255).reshape(-1, 28, 28)
        self.target = torch.from_numpy(y).long()
        self.transform = transform

    def __len__(self):
        """Return the number of samples (length of the target tensor)."""
        return len(self.target)

    def __getitem__(self, index):
        """Return ``(image, target)`` for *index*, applying the transform."""
        img, tar = self.data[index], self.target[index]
        # Compare against None so a callable that happens to be falsy
        # (e.g. an empty container-like module) is still applied.
        if self.transform is not None:
            img = self.transform(img)
        return img, tar