diff --git a/mlspm/image/_visualization.py b/mlspm/image/_visualization.py index f479cd3..85c9c4f 100644 --- a/mlspm/image/_visualization.py +++ b/mlspm/image/_visualization.py @@ -85,6 +85,7 @@ def make_prediction_plots( if vmax == vmin == 0: vmin = 0 vmax = 0.1 + if rows == 2: im1 = top_ax.imshow(p.T, vmax=vmax, vmin=vmin, cmap=cmap, origin="lower") im2 = bottom_ax.imshow(t.T, vmax=vmax, vmin=vmin, cmap=cmap, origin="lower") diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 3a60189..4669e85 100755 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -1,28 +1,48 @@ import shutil from pathlib import Path +import pytest import numpy as np def test_make_input_plots(): - from mlspm.visualization import make_input_plots - save_dir = Path('test_input_plots') + save_dir = Path("test_input_plots") X = [np.random.rand(2, 20, 20, 2), np.random.rand(2, 20, 20, 2)] make_input_plots(X, outdir=save_dir) - assert len(list(save_dir.glob('*.png'))) == 4 + assert len(list(save_dir.glob("*.png"))) == 4 shutil.rmtree(save_dir) -def test_plot_graphs(): +def test_make_prediction_plots(): + from mlspm.visualization import make_prediction_plots + + save_dir = Path("test_prediction_plots") + + preds = [np.random.rand(2, 20, 20), np.random.rand(2, 20, 20)] + true = [np.random.rand(2, 20, 20), np.random.rand(2, 20, 20)] + losses = np.random.rand(2, 2) + descriptors = ['a', 'ES'] + + make_prediction_plots(preds, true, losses=losses, descriptors=descriptors, outdir=save_dir) + + assert len(list(save_dir.glob("*.png"))) == 2 + + with pytest.raises(ValueError): + make_prediction_plots() + + shutil.rmtree(save_dir) + + +def test_plot_graphs(): from mlspm.graph import MoleculeGraph from mlspm.visualization import plot_graphs - save_dir = Path('test_graph_plots') + save_dir = Path("test_graph_plots") # fmt:off atoms = [ @@ -44,28 +64,23 @@ def test_plot_graphs(): ]), ] # fmt:on - bonds = [ - [(0,1), (0,2), (1,3), (2,3)], - [(0,1)], - [], - [(0,1), (1,2)] - ] + bonds = [[(0, 1), (0, 2), (1, 3), (2, 3)], [(0, 1)], [], [(0, 1), (1, 2)]] classes = [[1], [6]] mols = [MoleculeGraph(a, b, classes) for a, b in zip(atoms, bonds)] box_borders = np.array([[0.0, 0.0, 0.0], [4.0, 4.0, 4.0]]) plot_graphs(mols, mols, box_borders=box_borders, classes=classes, outdir=save_dir) - assert len(list(save_dir.glob('*.png'))) == 4 + assert len(list(save_dir.glob("*.png"))) == 4 shutil.rmtree(save_dir) -def test_plot_distribution_grid(): +def test_plot_distribution_grid(): from mlspm.graph import make_position_distribution, MoleculeGraph from mlspm.visualization import plot_distribution_grid - save_dir = Path('test_distribution_plots') + save_dir = Path("test_distribution_plots") # fmt:off atoms = np.array([ @@ -83,6 +98,6 @@ def test_plot_distribution_grid(): plot_distribution_grid(dist, dist, box_borders=box_borders, outdir=save_dir) - assert len(list(save_dir.glob('*.png'))) == 2 + assert len(list(save_dir.glob("*.png"))) == 2 shutil.rmtree(save_dir)