diff --git a/pixyz/models/model.py b/pixyz/models/model.py
index c447c4a4..debbba46 100644
--- a/pixyz/models/model.py
+++ b/pixyz/models/model.py
@@ -179,3 +179,40 @@ def test(self, test_x_dict={}, **kwargs):
             loss = self.test_loss_cls.eval(test_x_dict, **kwargs)
 
         return loss
+
+    def save(self, path):
+        """Save the model.
+
+        Only the state dict of ``self.distributions`` is written, under the key ``'distributions'``.
+        Settings such as device, optimizer, placement of clip_grad, etc. are not saved.
+
+        Parameters
+        ----------
+        path : str
+            Target file path
+
+        """
+        torch.save({
+            'distributions': self.distributions.state_dict(),
+        }, path)
+
+    def load(self, path):
+        """Load the model.
+
+        Restores the ``'distributions'`` state dict written by :meth:`save` into ``self.distributions``.
+
+        Parameters
+        ----------
+        path : str
+            Target file path
+
+        Notes
+        -----
+        ``torch.load`` unpickles the file, so only load checkpoints from trusted sources.
+
+        """
+        # Deserialize onto the CPU so a checkpoint written on a CUDA machine can still be read
+        # on a CPU-only one; load_state_dict then copies the values into the existing
+        # parameters on whatever device they already live on.
+        checkpoint = torch.load(path, map_location="cpu")
+        self.distributions.load_state_dict(checkpoint['distributions'])
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
new file mode 100644
index 00000000..4f4699ae
--- /dev/null
+++ b/tests/models/test_model.py
@@ -0,0 +1,42 @@
+import os
+import torch
+import torch.nn as nn
+from pixyz.distributions import Normal
+from pixyz.losses import CrossEntropy
+from pixyz.models import Model
+
+
+class TestModel:
+    """Tests for ``Model.save`` and ``Model.load``."""
+
+    def _make_model(self, loc):
+        """Build a single-distribution model whose Normal is constructed with ``loc=loc``."""
+        class Dist(Normal):
+            def __init__(self):
+                super().__init__(loc=loc, scale=1)
+                self.module = nn.Linear(2, 2)
+
+        p = Dist()
+
+        if torch.cuda.is_available():
+            device = "cuda"
+        else:
+            device = "cpu"
+
+        loss = CrossEntropy(p, p).to(device)
+        model = Model(loss=loss, distributions=[p])
+        return model
+
+    def test_save_load(self, tmp_path):
+        """A model built with loc=1 reports loc=0 after loading a checkpoint saved with loc=0."""
+        model = self._make_model(0)
+        save_path = os.path.join(tmp_path, 'model.pth')
+        model.save(save_path)
+
+        model = self._make_model(1)
+        p: Normal = model.distributions[0]
+        assert p.get_params()['loc'] == 1
+
+        model.load(save_path)
+        p: Normal = model.distributions[0]
+        assert p.get_params()['loc'] == 0