From a4732b5894ab347b0bd829d19432df64a305458a Mon Sep 17 00:00:00 2001 From: ktaaaki Date: Tue, 15 Dec 2020 18:55:21 +0900 Subject: [PATCH 1/2] add save & load methods --- pixyz/models/model.py | 13 ++++++++ tests/models/test_model.py | 66 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 tests/models/test_model.py diff --git a/pixyz/models/model.py b/pixyz/models/model.py index c447c4a4..834d234a 100644 --- a/pixyz/models/model.py +++ b/pixyz/models/model.py @@ -179,3 +179,16 @@ def test(self, test_x_dict={}, **kwargs): loss = self.test_loss_cls.eval(test_x_dict, **kwargs) return loss + + def save(self, path): + torch.save({ + 'loss_cls': self.loss_cls.state_dict(), + 'test_loss_cls': self.test_loss_cls.state_dict(), + 'distributions': self.distributions.state_dict(), + }, path) + + def load(self, path): + checkpoint = torch.load(path) + self.loss_cls.load_state_dict(checkpoint['loss_cls']) + self.test_loss_cls.load_state_dict(checkpoint['test_loss_cls']) + self.distributions.load_state_dict(checkpoint['distributions']) diff --git a/tests/models/test_model.py b/tests/models/test_model.py new file mode 100644 index 00000000..77a64ada --- /dev/null +++ b/tests/models/test_model.py @@ -0,0 +1,66 @@ +import os +import torch +import torch.nn as nn +from pixyz.distributions import Normal +from pixyz.losses import CrossEntropy +from pixyz.models import Model + + +class TestModel: + def _make_model(self, loc): + class Dist(Normal): + def __init__(self): + super().__init__(loc=loc, scale=1) + self.module = nn.Linear(2, 2) + + p = Dist() + + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + loss = CrossEntropy(p, p).to(device) + model = Model(loss=loss, distributions=[p]) + return model + + def test_save_load(self, tmp_path): + model = self._make_model(0) + save_path = os.path.join(tmp_path, 'model.pth') + model.save(save_path) + + model = self._make_model(1) + p: Normal = model.distributions[0] + assert p.get_params()['loc'] == 1 + + model.load(save_path) + p: Normal = model.distributions[0] + assert p.get_params()['loc'] == 0 + + +class TestGraph: + def test_rename_atomdist(self): + normal = Normal(var=['x'], name='p') + graph = normal.graph + assert graph.name == 'p' + normal.name = 'q' + assert graph.name == 'q' + + def test_print(self): + normal = Normal(var=['x'], name='p') + print(normal.graph) + + +class TestDistributionBase: + def test_init_with_scalar_params(self): + normal = Normal(loc=0, scale=1, features_shape=[2]) + assert normal.sample()['x'].shape == torch.Size([1, 2]) + assert normal.features_shape == torch.Size([2]) + + normal = Normal(loc=0, scale=1) + assert normal.sample()['x'].shape == torch.Size([1]) + assert normal.features_shape == torch.Size([]) + + def test_batch_n(self): + normal = Normal(loc=0, scale=1) + assert normal.sample(batch_n=3)['x'].shape == torch.Size([3]) From 44321f76b714bf97d54aa973748be368d56ea4d4 Mon Sep 17 00:00:00 2001 From: ktaaaki Date: Mon, 21 Dec 2020 13:32:01 +0900 Subject: [PATCH 2/2] model save only support distribution's parameters --- pixyz/models/model.py | 21 +++++++++++++++++---- tests/models/test_model.py | 28 ---------------------------- 2 files changed, 17 insertions(+), 32 deletions(-) diff --git a/pixyz/models/model.py b/pixyz/models/model.py index 834d234a..debbba46 100644 --- a/pixyz/models/model.py +++ b/pixyz/models/model.py @@ -181,14 +181,27 @@ def test(self, test_x_dict={}, **kwargs): return loss def save(self, path): + """Save the model. The only parameters that are saved are those that are included in the distribution. + Parameters such as device, optimizer, placement of clip_grad, etc. are not saved. + + Parameters + ---------- + path : str + Target file path + + """ torch.save({ - 'loss_cls': self.loss_cls.state_dict(), - 'test_loss_cls': self.test_loss_cls.state_dict(), 'distributions': self.distributions.state_dict(), }, path) def load(self, path): + """Load the model. + + Parameters + ---------- + path : str + Target file path + + """ checkpoint = torch.load(path) - self.loss_cls.load_state_dict(checkpoint['loss_cls']) - self.test_loss_cls.load_state_dict(checkpoint['test_loss_cls']) self.distributions.load_state_dict(checkpoint['distributions']) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 77a64ada..4f4699ae 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -36,31 +36,3 @@ def test_save_load(self, tmp_path): model.load(save_path) p: Normal = model.distributions[0] assert p.get_params()['loc'] == 0 - - -class TestGraph: - def test_rename_atomdist(self): - normal = Normal(var=['x'], name='p') - graph = normal.graph - assert graph.name == 'p' - normal.name = 'q' - assert graph.name == 'q' - - def test_print(self): - normal = Normal(var=['x'], name='p') - print(normal.graph) - - -class TestDistributionBase: - def test_init_with_scalar_params(self): - normal = Normal(loc=0, scale=1, features_shape=[2]) - assert normal.sample()['x'].shape == torch.Size([1, 2]) - assert normal.features_shape == torch.Size([2]) - - normal = Normal(loc=0, scale=1) - assert normal.sample()['x'].shape == torch.Size([1]) - assert normal.features_shape == torch.Size([]) - - def test_batch_n(self): - normal = Normal(loc=0, scale=1) - assert normal.sample(batch_n=3)['x'].shape == torch.Size([3])