-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnotMNIST_p2_vizualization.py
executable file
·78 lines (64 loc) · 2.18 KB
/
notMNIST_p2_vizualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# -*- coding: utf-8 -*-
from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import tarfile
from IPython.display import display, Image
from scipy import ndimage
from sklearn.linear_model import LogisticRegression
from six.moves.urllib.request import urlretrieve
from six.moves import cPickle as pickle
import re
import IPython
# Turn on pylab (interactive matplotlib) support when running inside IPython.
# get_ipython() returns None when no interactive shell is registered (e.g.
# when this file is run with a plain `python` interpreter), so guard the call
# instead of failing with AttributeError on None.
ip = IPython.get_ipython()
if ip is not None:
    ip.enable_pylab()

# Directory containing the notMNIST pickle files read below
# (machine-specific absolute path).
root = r'C:\Users\YoelS\Desktop\Udacity'
#folders = ['notMNIST_small', 'notMNIST_large']
#
#for fld in folders:
#
# candidates = os.listdir(os.path.join(root, fld))
#
# for cnd in candidates:
# if re.search('.pickle', cnd) is not None:
# try:
# with open(os.path.join(root, fld, cnd), 'rb') as f:
# dataset = pickle.load(f)
# except Exception as e:
# print('Unable to read ', cnd, ':', e)
#
# nimages = dataset.shape[0]
# nrows = 5
# ncols = 8
# fig, axes = plt.subplots(
# num=cnd[0] + ' ' + fld[-5:] + ' : ' +
# str(nimages) + ' samples',
# nrows=nrows, ncols=ncols)
#
# for r in range(nrows):
# for c in range(ncols):
#
# index = np.int64(nimages * np.random.rand(1)[0])
# axes[r, c].imshow(dataset[index, :, :], cmap='Greys')
# Pickle files to visualize. Each one is read as a mapping that provides
# 'train_dataset', 'valid_dataset' and 'test_dataset' entries, and each entry
# is indexed below as an image stack: data[i, :, :].
files = ['notMNIST.pickle', 'notMNIST_sanitized.pickle']

# Grid of randomly chosen sample images shown per figure.
nrows = 5
ncols = 8

for fl in files:
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    try:
        with open(os.path.join(root, fl), 'rb') as f:
            dataset = pickle.load(f)
    except Exception as e:
        print('Unable to read ', fl, ':', e)
        # Skip this file. Falling through would either raise NameError
        # (first file) or silently re-plot the previous file's data under
        # this file's name.
        continue

    for ds in ['train', 'valid', 'test']:
        data = dataset[ds + '_dataset']
        nimages = data.shape[0]

        # One figure per (file, split); the window title reports the size.
        fig, axes = plt.subplots(
            num=fl + ' ' + ds + ' : ' + str(nimages) + ' samples',
            nrows=nrows, ncols=ncols)

        if nimages == 0:
            # Nothing to sample from; leave the grid empty rather than
            # failing on an out-of-range index.
            continue

        # Draw one uniformly random image index in [0, nimages) per cell.
        picks = np.random.randint(nimages, size=nrows * ncols)
        for ax, index in zip(axes.flat, picks):
            ax.imshow(data[index, :, :], cmap='Greys')

# Returns immediately when matplotlib is in interactive mode; otherwise it
# keeps the figure windows open until they are closed.
plt.show()