-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathmodel.py
43 lines (31 loc) · 1.18 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import sugartensor as tf
# ---- hyper parameters ----
# NOTE(review): none of these values are referenced by generator() or
# discriminator() in this module; presumably the training/loss code imports
# them — confirm.

# Length of the noise vector (presumably the generator's input width).
z_dim = 50
# Margin used by the max-margin hinge loss.
margin = 1
# Weight of the PT regularizer (presumably EBGAN's pulling-away term).
pt_weight = 0.1
#
# create generator
#
def generator(x):
    """Build (or reuse) the generator network and return its output tensor.

    ``x`` goes through two dense layers (1024 and 7*7*128 units). The result
    is reshaped to a 7x7x128 feature map and passed through two stride-2
    transposed convolutions (64 channels, then 1 channel with a sigmoid).

    Args:
        x: input tensor; presumably a (batch, z_dim) noise batch — confirm
            against the caller.

    Returns:
        The tensor produced by the final ``sg_upconv``. It has one channel,
        and the sigmoid keeps its values in (0, 1). It is presumably 28x28
        spatially, if ``sg_upconv`` scales height/width by ``stride``
        (TODO confirm).
    """
    # If any variable whose name starts with 'generator' already exists,
    # rebuild the layers with reuse=True so repeated calls share weights.
    already_built = any(var.name.startswith('generator')
                        for var in tf.global_variables())

    # Context defaults applied to every layer below unless overridden:
    # kernel size 4, stride 2, leaky-ReLU activation, batch norm on.
    with tf.sg_context(name='generator', size=4, stride=2, act='leaky_relu',
                       bn=True, reuse=already_built):
        hidden = x.sg_dense(dim=1024, name='fc_1')
        hidden = hidden.sg_dense(dim=7 * 7 * 128, name='fc_2')
        hidden = hidden.sg_reshape(shape=(-1, 7, 7, 128))
        hidden = hidden.sg_upconv(dim=64, name='conv_1')
        # Output layer: batch norm disabled, sigmoid activation.
        image = hidden.sg_upconv(dim=1, act='sigmoid', bn=False, name='conv_2')

    return image
#
# create discriminator
#
def discriminator(x):
    """Build (or reuse) the discriminator network and return its output tensor.

    The network is encoder-decoder shaped:
    - two stride-2 convolutions (64, then 128 channels);
    - two stride-2 transposed convolutions (64, then 1 channel, linear
      activation).

    Presumably this is an EBGAN-style auto-encoder whose reconstruction is
    scored by the caller — confirm against the training code.

    Args:
        x: input image tensor; presumably (batch, 28, 28, 1) — confirm.

    Returns:
        The 1-channel tensor produced by the final ``sg_upconv``.
    """
    # If any variable whose name starts with 'discriminator' already exists,
    # rebuild the layers with reuse=True so real and fake batches share weights.
    already_built = any(var.name.startswith('discriminator')
                        for var in tf.global_variables())

    # Context defaults applied to every layer below unless overridden:
    # kernel size 4, stride 2, leaky-ReLU activation, batch norm on.
    with tf.sg_context(name='discriminator', size=4, stride=2,
                       act='leaky_relu', bn=True, reuse=already_built):
        # Encoder half: two strided convolutions.
        code = x.sg_conv(dim=64, name='conv_1')
        code = code.sg_conv(dim=128, name='conv_2')
        # Decoder half: two transposed convolutions.
        recon = code.sg_upconv(dim=64, name='conv_3')
        # NOTE(review): unlike the generator's output layer, bn is not set to
        # False here, so the context default bn=True still applies to this
        # linear output layer — confirm this is intended.
        recon = recon.sg_upconv(dim=1, act='linear', name='conv_4')

    return recon