-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
30 lines (20 loc) · 901 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from two_stage_model import *
from data_transform import *
import torch
from utils import train_en_de_C
from config import base_lr,epoches,lr_step,latent_variable_dim
# Build the two networks that are trained jointly below. Both classes are
# expected to come from the wildcard import of `two_stage_model`.
# NOTE(review): the leading `1` passed to encoder_C is presumably the number
# of input channels (MNIST images are grayscale) - confirm in two_stage_model.
net1 = encoder_C(1, latent_variable_dim)
net2 = decoder_C(latent_variable_dim)

# Each network gets its own SGD optimizer with identical hyper-parameters.
# The `lr` given here is only the starting value; adjust_lr() overwrites it
# at the beginning of every epoch.
optimizer_encoder = torch.optim.SGD(
    net1.parameters(),
    lr=base_lr,
    momentum=0.9,
    weight_decay=0.0005,
)
optimizer_decoder = torch.optim.SGD(
    net2.parameters(),
    lr=base_lr,
    momentum=0.9,
    weight_decay=0.0005,
)

# Mean-squared-error loss handed to the training routine. Referenced through
# the explicitly imported `torch` package rather than a bare `nn`, which this
# file never imports itself and would only exist if a wildcard import leaked it.
criterion = torch.nn.MSELoss()
def adjust_lr(optimizer, epoch):
    """Apply the step-decay learning rate for *epoch* to *optimizer*.

    The rate is ``base_lr * 0.1 ** (epoch // lr_step)``: it starts at
    ``base_lr`` and is multiplied by 0.1 each time ``epoch`` crosses another
    multiple of ``lr_step``. Both constants are read from module scope.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dict-like
            groups; the ``'lr'`` entry of every group is overwritten in place.
        epoch: Current epoch number used to select the decay stage.
    """
    completed_steps = epoch // lr_step
    decayed_lr = base_lr * (0.1 ** completed_steps)
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr
print(" #### Start training ####")

# Epochs are numbered 1..epoches inclusive.
for epoch in range(1, epoches + 1):
    # Refresh the learning rate of both optimizers before the epoch runs so
    # the encoder and decoder always share the same decay stage.
    adjust_lr(optimizer_encoder, epoch)
    adjust_lr(optimizer_decoder, epoch)

    # NOTE(review): presumably performs one epoch of training on the MNIST
    # training set and evaluates on the test set - confirm in utils.
    train_en_de_C(
        net1,
        net2,
        MNIST_train_data,
        MNIST_test_data,
        epoch,
        optimizer_encoder,
        optimizer_decoder,
        criterion,
    )

print("Done!")