wnet.py
from keras.applications.vgg16 import VGG16
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, multiply, BatchNormalization, Activation
from keras.models import Model
from keras.initializers import RandomNormal


def wnet(input_shape=(None, None, 3), BN=False):
    """Build a W-Net style model: a VGG16 encoder shared by two decoders.

    Decoder 1 produces density features; decoder 2 produces a sigmoid
    attention mask that gates those features before the final 1x1
    density-map output. The 13 encoder conv layers are initialised with
    ImageNet-pretrained VGG16 weights.
    """
    # Differs from the original paper: 'same' padding is used instead of 'valid'.
    conv_kernel_initializer = RandomNormal(stddev=0.01)
    input_flow = Input(input_shape)

    # Encoder (VGG16-style: 13 conv layers, 4 max-pooling stages)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(input_flow)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2))(x)

    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x_1 = Conv2D(128, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x_1 = BatchNormalization()(x_1) if BN else x_1
    x_1 = Activation('relu')(x_1)  # skip connection at 1/2 resolution
    x = MaxPooling2D((2, 2))(x_1)

    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x_2 = Conv2D(256, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x_2 = BatchNormalization()(x_2) if BN else x_2
    x_2 = Activation('relu')(x_2)  # skip connection at 1/4 resolution
    x = MaxPooling2D((2, 2))(x_2)

    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x_3 = Conv2D(512, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x_3 = BatchNormalization()(x_3) if BN else x_3
    x_3 = Activation('relu')(x_3)  # skip connection at 1/8 resolution
    x = MaxPooling2D((2, 2))(x_3)

    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x_4 = Conv2D(512, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x_4 = BatchNormalization()(x_4) if BN else x_4
    x_4 = Activation('relu')(x_4)  # bottleneck at 1/16 resolution

    # Decoder 1: density feature branch
    x = UpSampling2D((2, 2))(x_4)
    x = concatenate([x_3, x])
    x = Conv2D(256, (1, 1), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x = UpSampling2D((2, 2))(x)
    x = concatenate([x_2, x])
    x = Conv2D(128, (1, 1), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x = UpSampling2D((2, 2))(x)
    x = concatenate([x_1, x])
    x = Conv2D(64, (1, 1), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)
    x = Conv2D(32, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x)
    x = BatchNormalization()(x) if BN else x
    x = Activation('relu')(x)

    # Decoder 2: attention (reinforcement) branch, ends in a sigmoid mask
    x_rb = UpSampling2D((2, 2))(x_4)
    x_rb = concatenate([x_3, x_rb])
    x_rb = Conv2D(256, (1, 1), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x_rb)
    x_rb = BatchNormalization()(x_rb) if BN else x_rb
    x_rb = Activation('relu')(x_rb)
    x_rb = Conv2D(256, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x_rb)
    x_rb = BatchNormalization()(x_rb) if BN else x_rb
    x_rb = Activation('relu')(x_rb)
    x_rb = UpSampling2D((2, 2))(x_rb)
    x_rb = concatenate([x_2, x_rb])
    x_rb = Conv2D(128, (1, 1), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x_rb)
    x_rb = BatchNormalization()(x_rb) if BN else x_rb
    x_rb = Activation('relu')(x_rb)
    x_rb = Conv2D(128, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x_rb)
    x_rb = BatchNormalization()(x_rb) if BN else x_rb
    x_rb = Activation('relu')(x_rb)
    x_rb = UpSampling2D((2, 2))(x_rb)
    x_rb = concatenate([x_1, x_rb])
    x_rb = Conv2D(64, (1, 1), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x_rb)
    x_rb = BatchNormalization()(x_rb) if BN else x_rb
    x_rb = Activation('relu')(x_rb)
    x_rb = Conv2D(64, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x_rb)
    x_rb = BatchNormalization()(x_rb) if BN else x_rb
    x_rb = Activation('relu')(x_rb)
    x_rb = Conv2D(32, (3, 3), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer)(x_rb)
    x_rb = BatchNormalization()(x_rb) if BN else x_rb
    x_rb = Activation('relu')(x_rb)
    x_rb = Conv2D(1, (1, 1), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer, activation='sigmoid')(x_rb)  # sigmoid attention mask

    # Gate the density features with the attention mask, then map to a
    # single-channel density map (output is at 1/2 the input resolution).
    x = multiply([x, x_rb])
    x = Conv2D(1, (1, 1), strides=(1, 1), padding='same', kernel_initializer=conv_kernel_initializer, activation='relu')(x)

    model = Model(inputs=input_flow, outputs=x)

    # Copy the 13 ImageNet-pretrained VGG16 conv-layer weights into the encoder;
    # this assumes the first 13 conv layers in model.layers are the encoder's.
    front_end = VGG16(weights='imagenet', include_top=False)
    weights_front_end = []
    for layer in front_end.layers:
        if 'conv' in layer.name:
            weights_front_end.append(layer.get_weights())
    counter_conv = 0
    for i in range(len(model.layers)):
        if counter_conv >= 13:
            break
        if 'conv' in model.layers[i].name:
            model.layers[i].set_weights(weights_front_end[counter_conv])
            counter_conv += 1
    return model
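

# Minimal usage sketch (an illustration, not part of the original file): builds the
# model and runs a dummy forward pass. Spatial dimensions should be multiples of 16
# so the skip-connection shapes line up; the output density map is at 1/2 the input
# resolution, and (assuming a crowd-counting setup) its sum gives the predicted count.
if __name__ == '__main__':
    import numpy as np

    model = wnet(input_shape=(None, None, 3), BN=False)
    model.summary()

    dummy = np.random.rand(1, 224, 224, 3).astype('float32')  # any H x W divisible by 16
    density_map = model.predict(dummy)                         # shape (1, 112, 112, 1)
    print('predicted count:', float(density_map.sum()))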