trainer.py
from dataset import Dataset
from SimplE import SimplE
import torch
import torch.nn as nn
import torch.nn.functional as F
import os


class Trainer:
    def __init__(self, dataset, args):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = SimplE(dataset.num_ent(), dataset.num_rel(), args.emb_dim, self.device)
        self.dataset = dataset
        self.args = args

    def train(self):
        self.model.train()

        optimizer = torch.optim.Adagrad(
            self.model.parameters(),
            lr=self.args.lr,
            weight_decay=0,
            initial_accumulator_value=0.1,  # matches the original TensorFlow implementation
        )

        for epoch in range(1, self.args.ne + 1):
            last_batch = False
            total_loss = 0.0

            while not last_batch:
                # Each batch holds positive triples plus neg_ratio negative samples;
                # l contains the labels (+1 for positives, -1 for negatives).
                h, r, t, l = self.dataset.next_batch(
                    self.args.batch_size, neg_ratio=self.args.neg_ratio, device=self.device
                )
                last_batch = self.dataset.was_last_batch()

                optimizer.zero_grad()
                scores = self.model(h, r, t)
                # Softplus (negative log-likelihood) loss plus L2 regularization,
                # with the regularizer spread evenly over the batches of an epoch.
                loss = torch.sum(F.softplus(-l * scores)) + (
                    self.args.reg_lambda * self.model.l2_loss() / self.dataset.num_batch(self.args.batch_size)
                )
                loss.backward()
                optimizer.step()
                total_loss += loss.cpu().item()

            print(f"Loss in iteration {epoch}: {total_loss} ({self.dataset.name})")

            if epoch % self.args.save_each == 0:
                self.save_model(epoch)

    def save_model(self, chkpnt):
        print("Saving the model")
        directory = "models/" + self.dataset.name + "/"
        if not os.path.exists(directory):
            os.makedirs(directory)
        # Save the full model object as an epoch-numbered checkpoint.
        torch.save(self.model, directory + str(chkpnt) + ".chkpnt")
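
For context, a minimal usage sketch of the Trainer class (not part of trainer.py itself): it assumes the repository's Dataset can be constructed from a dataset name and that the hyperparameters are passed as a simple namespace carrying the fields the Trainer reads (emb_dim, lr, ne, batch_size, neg_ratio, reg_lambda, save_each); the concrete values below are illustrative placeholders, not tuned settings.

# Hypothetical driver script; adjust names and values to your setup.
from argparse import Namespace

from dataset import Dataset
from trainer import Trainer

args = Namespace(
    emb_dim=200,      # embedding dimension
    lr=0.05,          # Adagrad learning rate
    ne=1000,          # number of training epochs
    batch_size=1000,  # triples per batch
    neg_ratio=10,     # negative samples per positive triple
    reg_lambda=0.03,  # L2 regularization coefficient
    save_each=50,     # checkpoint frequency (in epochs)
)

dataset = Dataset("WN18")  # hypothetical dataset name
trainer = Trainer(dataset, args)
trainer.train()            # checkpoints are written to models/WN18/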