Commit 7345ad3

Added tests for checkpointing.

NikoOinonen committed Nov 27, 2023
1 parent 6896179, commit 7345ad3
Showing 1 changed file with 30 additions and 23 deletions.

tests/test_utils.py: 53 changes (30 additions, 23 deletions)
@@ -1,5 +1,9 @@
 import os
+from pathlib import Path
 
 import numpy as np
+from torch import nn, optim
+import torch
+
 
 def test_xyz_read_write():
@@ -22,33 +26,36 @@ def test_xyz_read_write():
     assert np.allclose(xyz, xyz_read)
 
 
-def test_loss_log_plot():
-    from mlspm.logging import LossLogPlot
-
-    loss_log_path = "loss_log.csv"
-    plot_path = "test_plot.png"
-    log_path = "test_log.txt"
+def test_checkpoints():
+    from mlspm.utils import load_checkpoint, save_checkpoint
 
-    info_log = open(log_path, 'w')
+    save_dir = Path("test_checkpoints")
 
-    if os.path.exists(loss_log_path):
-        os.remove(loss_log_path)
-    loss_log = LossLogPlot(loss_log_path, plot_path, loss_labels=["1", "2"], loss_weights=["", ""], stream=info_log)
+    model = nn.Linear(10, 10)
+    optimizer = optim.Adam(model.parameters())
+    lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda b: 1.0 / (1.0 + 1e-5 * b))
+    additional_data = {"test_data": 3, "lr_scheduler": lr_scheduler}
 
-    losses = [[[0, 1, 2], [1, 2, 3]], [[1.5, 2.0, 3.0], [0.1, 0.4, 0.7]]]
-    for loss in losses:
-        loss_log.add_train_loss(loss[0])
-        loss_log.add_val_loss(loss[1])
-        loss_log.next_epoch()
+    x, y = np.random.rand(2, 1, 10, 10)
+    x = torch.from_numpy(x).float()
+    y = torch.from_numpy(y).float()
+    pred = model(x)
+    loss = ((y - pred) ** 2).mean()
+    loss.backward()
+    optimizer.step()
+    lr_scheduler.step()
+    print(loss)
 
-    new_log = LossLogPlot(loss_log_path, "plot.png", loss_labels=["1", "2"], loss_weights=["", ""], stream=info_log)
+    save_checkpoint(model, optimizer, epoch=1, save_dir=save_dir, additional_data=additional_data)
 
-    info_log.close()
+    model_new = nn.Linear(10, 10)
+    optimizer_new = optim.Adam(model_new.parameters())
+    lr_scheduler_new = optim.lr_scheduler.LambdaLR(optimizer_new, lambda b: 1.0 / (1.0 + 1e-5 * b))
+    additional_data = {"test_data": 0, "lr_scheduler": lr_scheduler_new}

-    os.remove(loss_log_path)
-    os.remove(plot_path)
-    os.remove(log_path)
+    load_checkpoint(model_new, optimizer_new, file_name=save_dir / "model_1.pth", additional_data=additional_data)

-    assert new_log.epoch == 2
-    assert np.allclose(new_log.train_losses, np.array([[0.75, 1.5, 2.5]])), new_log.train_losses
-    assert np.allclose(new_log.val_losses, np.array([[0.55, 1.2, 1.85]])), new_log.val_losses
+    assert np.allclose(model.state_dict()["weight"], model_new.state_dict()["weight"])
+    assert np.allclose(optimizer.state_dict()["state"][0]["exp_avg"], optimizer_new.state_dict()["state"][0]["exp_avg"])
+    assert np.allclose(lr_scheduler.state_dict()["_last_lr"], lr_scheduler_new.state_dict()["_last_lr"])
+    assert np.allclose(additional_data["test_data"], 3)
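
The mlspm.utils checkpoint helpers themselves are not part of this diff. For readers following along, here is a minimal sketch of a save/load pair that would satisfy the test above, inferred from the call sites only; the model_{epoch}.pth file name and the in-place update of additional_data are assumptions, not the actual mlspm implementation.

# Hedged sketch only: a checkpoint pair shaped like the calls in the test.
# Assumptions (inferred from the test, not from mlspm source): checkpoints
# are written to save_dir / f"model_{epoch}.pth", and load_checkpoint
# restores state dicts and updates additional_data in place, which is why
# the test expects additional_data["test_data"] == 3 after loading.
from pathlib import Path

import torch

def save_checkpoint(model, optimizer, epoch, save_dir, additional_data=None):
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True)
    state = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
    if additional_data:
        # Anything with a state_dict (e.g. an lr scheduler) is serialized
        # via it; plain values (e.g. test_data) are stored as-is.
        state["additional"] = {
            k: v.state_dict() if hasattr(v, "state_dict") else v
            for k, v in additional_data.items()
        }
    torch.save(state, save_dir / f"model_{epoch}.pth")

def load_checkpoint(model, optimizer, file_name, additional_data=None):
    state = torch.load(file_name)
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    if additional_data:
        for k, v in additional_data.items():
            if hasattr(v, "load_state_dict"):
                # Restore stateful objects (e.g. the scheduler) in place.
                v.load_state_dict(state["additional"][k])
            else:
                # Overwrite plain values with the saved ones.
                additional_data[k] = state["additional"][k]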
