From 7345ad3121d4fe81fc858cb588784241ea5a47f7 Mon Sep 17 00:00:00 2001 From: NikoOinonen Date: Mon, 27 Nov 2023 12:21:18 +0200 Subject: [PATCH] Added tests for checkpointing. --- tests/test_utils.py | 53 +++++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 3cbd30d..bf201cc 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,9 @@ import os +from pathlib import Path + import numpy as np +from torch import nn, optim +import torch def test_xyz_read_write(): @@ -22,33 +26,36 @@ def test_xyz_read_write(): assert np.allclose(xyz, xyz_read) -def test_loss_log_plot(): - from mlspm.logging import LossLogPlot - - loss_log_path = "loss_log.csv" - plot_path = "test_plot.png" - log_path = "test_log.txt" +def test_checkpoints(): + from mlspm.utils import load_checkpoint, save_checkpoint - info_log = open(log_path, 'w') + save_dir = Path("test_checkpoints") - if os.path.exists(loss_log_path): - os.remove(loss_log_path) - loss_log = LossLogPlot(loss_log_path, plot_path, loss_labels=["1", "2"], loss_weights=["", ""], stream=info_log) + model = nn.Linear(10, 10) + optimizer = optim.Adam(model.parameters()) + lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda b: 1.0 / (1.0 + 1e-5 * b)) + additional_data = {"test_data": 3, "lr_scheduler": lr_scheduler} - losses = [[[0, 1, 2], [1, 2, 3]], [[1.5, 2.0, 3.0], [0.1, 0.4, 0.7]]] - for loss in losses: - loss_log.add_train_loss(loss[0]) - loss_log.add_val_loss(loss[1]) - loss_log.next_epoch() + x, y = np.random.rand(2, 1, 10, 10) + x = torch.from_numpy(x).float() + y = torch.from_numpy(y).float() + pred = model(x) + loss = ((y - pred) ** 2).mean() + loss.backward() + optimizer.step() + lr_scheduler.step() + print(loss) - new_log = LossLogPlot(loss_log_path, "plot.png", loss_labels=["1", "2"], loss_weights=["", ""], stream=info_log) + save_checkpoint(model, optimizer, epoch=1, save_dir=save_dir, additional_data=additional_data) - info_log.close() + model_new = nn.Linear(10, 10) + optimizer_new = optim.Adam(model.parameters()) + lr_scheduler_new = optim.lr_scheduler.LambdaLR(optimizer, lambda b: 1.0 / (1.0 + 1e-5 * b)) + additional_data = {"test_data": 0, "lr_scheduler": lr_scheduler_new} - os.remove(loss_log_path) - os.remove(plot_path) - os.remove(log_path) + load_checkpoint(model_new, optimizer_new, file_name=save_dir / "model_1.pth", additional_data=additional_data) - assert new_log.epoch == 2 - assert np.allclose(new_log.train_losses, np.array([[0.75, 1.5, 2.5]])), new_log.train_losses - assert np.allclose(new_log.val_losses, np.array([[0.55, 1.2, 1.85]])), new_log.val_losses + assert np.allclose(model.state_dict()["weight"], model_new.state_dict()["weight"]) + assert np.allclose(optimizer.state_dict()["state"][0]["exp_avg"], optimizer_new.state_dict()["state"][0]["exp_avg"]) + assert np.allclose(lr_scheduler.state_dict()["_last_lr"], lr_scheduler_new.state_dict()["_last_lr"]) + assert np.allclose(additional_data["test_data"], 3)