-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
55 lines (44 loc) · 1.83 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import time
import math
import torch
from trainer import count_parameters,epoch_time,train,evaluate,init_weights
from data import SRC,TGT,iterator
from model import Encoder,Decoder,Seq2Seq,Attention
if __name__ == '__main__':
    # Train an attention-based seq2seq model, checkpointing the weights with
    # the lowest validation loss to "tut4-model.pt" and resuming from that
    # file if it already exists.

    # Vocabulary sizes come from the preprocessed fields defined in data.py.
    INPUT_DIM = len(SRC.vocab)
    OUTPUT_DIM = len(TGT.vocab)

    # Model / training hyperparameters.
    ENC_EMB_DIM = 256
    DEC_EMB_DIM = 256
    HID_DIM = 512
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = 128
    n_epochs = 10
    CLIP = 1          # gradient-clipping norm forwarded to train()
    dropout = 0.5

    train_iter, valid_iter, _ = iterator(batch, device)

    # Assemble the model: encoder -> attention -> decoder, wrapped in Seq2Seq.
    src_pad_id = SRC.vocab.stoi[SRC.pad_token]
    atten = Attention(HID_DIM, HID_DIM)
    enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, HID_DIM, dropout)
    dec = Decoder(OUTPUT_DIM, atten, DEC_EMB_DIM, HID_DIM, HID_DIM, dropout)
    model = Seq2Seq(enc, dec, src_pad_id, device).to(device)
    model.apply(init_weights)

    # Resume from a previous checkpoint if one exists.  map_location remaps
    # the stored tensors to the current device, so a GPU-saved checkpoint
    # still loads on a CPU-only machine (and vice versa).
    if os.path.exists("tut4-model.pt"):
        model.load_state_dict(torch.load("tut4-model.pt", map_location=device))

    count_parameters(model)

    # Padding positions must not contribute to the loss.
    pad_idx = TGT.vocab.stoi[TGT.pad_token]
    criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = torch.optim.Adam(model.parameters())

    best_valid_loss = float("inf")
    for epoch in range(n_epochs):
        start_time = time.time()
        train_loss = train(model, train_iter, optimizer, criterion, CLIP)
        valid_loss = evaluate(model, valid_iter, criterion)
        end_time = time.time()
        epoch_min, epoch_sec = epoch_time(start_time, end_time)

        # Keep only the checkpoint with the best (lowest) validation loss.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), "tut4-model.pt")

        # Perplexity is exp(cross-entropy loss).
        print(f"Epoch:{epoch+1:02} | Time: {epoch_min}m {epoch_sec}s")
        print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
        print(f'\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')