Skip to content

Commit

Permalink
Merge pull request #161 from masa-su/feature/model_save2
Browse files Browse the repository at this point in the history
Feature/model save2
  • Loading branch information
masa-su authored Jan 6, 2021
2 parents dbbfb35 + 44321f7 commit 203ec29
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
26 changes: 26 additions & 0 deletions pixyz/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,29 @@ def test(self, test_x_dict={}, **kwargs):
loss = self.test_loss_cls.eval(test_x_dict, **kwargs)

return loss

def save(self, path):
"""Save the model. The only parameters that are saved are those that are included in the distribution.
Parameters such as device, optimizer, placement of clip_grad, etc. are not saved.
Parameters
----------
path : str
Target file path
"""
torch.save({
'distributions': self.distributions.state_dict(),
}, path)

def load(self, path):
"""Load the model.
Parameters
----------
path : str
Target file path
"""
checkpoint = torch.load(path)
self.distributions.load_state_dict(checkpoint['distributions'])
38 changes: 38 additions & 0 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
import torch
import torch.nn as nn
from pixyz.distributions import Normal
from pixyz.losses import CrossEntropy
from pixyz.models import Model


class TestModel:
def _make_model(self, loc):
class Dist(Normal):
def __init__(self):
super().__init__(loc=loc, scale=1)
self.module = nn.Linear(2, 2)

p = Dist()

if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"

loss = CrossEntropy(p, p).to(device)
model = Model(loss=loss, distributions=[p])
return model

def test_save_load(self, tmp_path):
model = self._make_model(0)
save_path = os.path.join(tmp_path, 'model.pth')
model.save(save_path)

model = self._make_model(1)
p: Normal = model.distributions[0]
assert p.get_params()['loc'] == 1

model.load(save_path)
p: Normal = model.distributions[0]
assert p.get_params()['loc'] == 0

0 comments on commit 203ec29

Please sign in to comment.