Skip to content

Commit

Permalink
Added test for make_prediction_plots
Browse files Browse the repository at this point in the history
  • Loading branch information
NikoOinonen committed Jan 30, 2024
1 parent 4d4b67a commit 4a088a8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 15 deletions.
1 change: 1 addition & 0 deletions mlspm/image/_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def make_prediction_plots(
if vmax == vmin == 0:
vmin = 0
vmax = 0.1

if rows == 2:
im1 = top_ax.imshow(p.T, vmax=vmax, vmin=vmin, cmap=cmap, origin="lower")
im2 = bottom_ax.imshow(t.T, vmax=vmax, vmin=vmin, cmap=cmap, origin="lower")
Expand Down
45 changes: 30 additions & 15 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,48 @@
import shutil
from pathlib import Path
import pytest

import numpy as np


def test_make_input_plots():

from mlspm.visualization import make_input_plots

save_dir = Path('test_input_plots')
save_dir = Path("test_input_plots")

X = [np.random.rand(2, 20, 20, 2), np.random.rand(2, 20, 20, 2)]
make_input_plots(X, outdir=save_dir)

assert len(list(save_dir.glob('*.png'))) == 4
assert len(list(save_dir.glob("*.png"))) == 4

shutil.rmtree(save_dir)

def test_plot_graphs():

def test_make_prediction_plots():
from mlspm.visualization import make_prediction_plots

save_dir = Path("test_prediction_plots")

preds = [np.random.rand(2, 20, 20), np.random.rand(2, 20, 20)]
true = [np.random.rand(2, 20, 20), np.random.rand(2, 20, 20)]
losses = np.random.rand(2, 2)
descriptors = ['a', 'ES']

make_prediction_plots(preds, true, losses=losses, descriptors=descriptors, outdir=save_dir)

assert len(list(save_dir.glob("*.png"))) == 2

with pytest.raises(ValueError):
make_prediction_plots()

shutil.rmtree(save_dir)


def test_plot_graphs():
from mlspm.graph import MoleculeGraph
from mlspm.visualization import plot_graphs

save_dir = Path('test_graph_plots')
save_dir = Path("test_graph_plots")

# fmt:off
atoms = [
Expand All @@ -44,28 +64,23 @@ def test_plot_graphs():
]),
]
# fmt:on
bonds = [
[(0,1), (0,2), (1,3), (2,3)],
[(0,1)],
[],
[(0,1), (1,2)]
]
bonds = [[(0, 1), (0, 2), (1, 3), (2, 3)], [(0, 1)], [], [(0, 1), (1, 2)]]
classes = [[1], [6]]
mols = [MoleculeGraph(a, b, classes) for a, b in zip(atoms, bonds)]
box_borders = np.array([[0.0, 0.0, 0.0], [4.0, 4.0, 4.0]])

plot_graphs(mols, mols, box_borders=box_borders, classes=classes, outdir=save_dir)

assert len(list(save_dir.glob('*.png'))) == 4
assert len(list(save_dir.glob("*.png"))) == 4

shutil.rmtree(save_dir)

def test_plot_distribution_grid():

def test_plot_distribution_grid():
from mlspm.graph import make_position_distribution, MoleculeGraph
from mlspm.visualization import plot_distribution_grid

save_dir = Path('test_distribution_plots')
save_dir = Path("test_distribution_plots")

# fmt:off
atoms = np.array([
Expand All @@ -83,6 +98,6 @@ def test_plot_distribution_grid():

plot_distribution_grid(dist, dist, box_borders=box_borders, outdir=save_dir)

assert len(list(save_dir.glob('*.png'))) == 2
assert len(list(save_dir.glob("*.png"))) == 2

shutil.rmtree(save_dir)

0 comments on commit 4a088a8

Please sign in to comment.