-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdnn_utils.py
97 lines (71 loc) · 1.88 KB
/
dnn_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
import matplotlib.pyplot as plt
def sigmoid(Z):
    """Compute the element-wise logistic sigmoid of Z.

    The value is evaluated as exp(min(Z, 0)) / (1 + exp(-|Z|)), which is
    algebraically equal to 1 / (1 + exp(-Z)) but only ever exponentiates
    non-positive numbers, so large-magnitude inputs cannot overflow.

    Arguments:
        Z -- np.array (or scalar) of pre-activations, e.g. Wx + b

    Returns:
        A -- sigmoid(Z) = 1 / (1 + exp(-Z)), element-wise
        cache -- Z itself, returned unchanged alongside the activation
    """
    # exp(min(Z, 0)) is 1 for Z >= 0 and exp(Z) for Z < 0, so both branches
    # of the stable formula are covered without a conditional.
    A = np.exp(np.minimum(Z, 0)) / (1 + np.exp(-np.abs(Z)))
    cache = Z
    return A, cache
def relu(Z):
    """Apply the rectified linear unit element-wise.

    Arguments:
        Z -- np.array of pre-activations, e.g. Wx + b

    Returns:
        A tuple (A, cache) where A is max(0, Z) element-wise and cache is
        the input Z itself, returned unchanged.
    """
    activation = np.maximum(0, Z)
    # Sanity check: the activation must not change the layout of its input.
    assert activation.shape == Z.shape
    return activation, Z
def relu_backward(dA, cache):
    """Backward pass for a single RELU unit.

    Arguments:
        dA -- gradient with respect to the activation output
        cache -- the Z that was fed to the forward RELU

    Returns:
        dZ -- gradient of the cost with respect to Z: a copy of dA with
              entries zeroed wherever Z <= 0
    """
    pre_activation = cache
    # Work on a fresh array so the caller's dA is never modified.
    grad = np.array(dA, copy=True)
    # RELU is flat for non-positive inputs, so no gradient flows there.
    blocked = pre_activation <= 0
    grad[blocked] = 0
    assert grad.shape == pre_activation.shape
    return grad
def sigmoid_backward(dA, cache):
    """Backward pass for a single SIGMOID unit.

    The sigmoid is recomputed from Z as exp(min(Z, 0)) / (1 + exp(-|Z|)),
    which equals 1 / (1 + exp(-Z)) but only exponentiates non-positive
    numbers, so large-magnitude Z cannot overflow.

    Arguments:
        dA -- gradient with respect to the activation output
        cache -- the Z that was fed to the forward sigmoid

    Returns:
        dZ -- gradient of the cost with respect to Z, dA * s * (1 - s)
              where s = sigmoid(Z)
    """
    Z = cache
    s = np.exp(np.minimum(Z, 0)) / (1 + np.exp(-np.abs(Z)))
    # Chain rule: d(sigmoid)/dZ = s * (1 - s).
    dZ = dA * s * (1 - s)
    # Sanity check: broadcasting against dA must not have changed the shape.
    assert dZ.shape == Z.shape
    return dZ
def print_mislabeled_images(classes, X, y, p, image_shape=(64, 64, 3)):
    """Plot the images whose prediction differs from the true label.

    Arguments:
        classes -- sequence of byte strings, indexed by label, giving the
                   class names (each is decoded as UTF-8 for the title)
        X -- dataset; column j is one flattened image
        y -- true labels, indexed as y[0, j]
        p -- predictions, indexed as p[0, j]
        image_shape -- shape each column of X is reshaped to before display;
                       defaults to (64, 64, 3)

    Prints the indices of the mislabeled examples, then shows one subplot
    per mislabeled image titled with the predicted and true class names.
    """
    # Compare directly rather than testing p + y == 1, which is only
    # equivalent to "different" when both labels are restricted to 0/1.
    mislabeled_indices = np.asarray(np.where(p != y))
    print(mislabeled_indices)
    # Row 1 of the np.where result holds the column (example) indices.
    columns = mislabeled_indices[1]
    num_images = len(columns)
    for i, index in enumerate(columns):
        plt.subplot(2, num_images, i + 1)
        plt.imshow(X[:, index].reshape(image_shape))
        # Cast both labels to int so float-typed label arrays can index
        # into classes as well.
        predicted = classes[int(p[0, index])].decode("utf-8")
        actual = classes[int(y[0, index])].decode("utf-8")
        plt.title("Prediction: " + predicted + " \n Class: " + actual)
    # All subplots share one figure, so display it once after the loop.
    plt.show()