From edff1624b21db67e0b70bfd914f114bad0d4d784 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Thu, 7 Dec 2023 14:04:50 -0800 Subject: [PATCH] test_relaxation parametrize ase_filter to test FrechetCellFilter and ExpCellFilter --- chgnet/model/dynamics.py | 2 +- tests/test_relaxation.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index 8e299d8d..2729d1ab 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -89,7 +89,7 @@ def __init__( self.device = f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}" # Move the model to the specified device - self.model = (model or CHGNet.load()).to(self.device).float() + self.model = (model or CHGNet.load()).to(self.device) self.model.graph_converter.set_isolated_atom_response(on_isolated_atoms) self.stress_weight = stress_weight print(f"CHGNet will run on {self.device}") diff --git a/tests/test_relaxation.py b/tests/test_relaxation.py index 477b3b44..534df862 100644 --- a/tests/test_relaxation.py +++ b/tests/test_relaxation.py @@ -5,6 +5,7 @@ import pytest import torch +from ase.filters import ExpCellFilter, Filter, FrechetCellFilter from pymatgen.core import Structure from pytest import approx, mark, param @@ -15,8 +16,10 @@ structure = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif") -@pytest.mark.parametrize("algorithm", ["legacy", "fast"]) -def test_relaxation(algorithm: Literal["legacy", "fast"]): +@pytest.mark.parametrize( + "algorithm, ase_filter", [("legacy", FrechetCellFilter), ("fast", ExpCellFilter)] +) +def test_relaxation(algorithm: Literal["legacy", "fast"], ase_filter: Filter) -> None: chgnet = CHGNet.load() converter = CrystalGraphConverter( atom_graph_cutoff=6, bond_graph_cutoff=3, algorithm=algorithm @@ -25,7 +28,7 @@ def test_relaxation(algorithm: Literal["legacy", "fast"]): chgnet.graph_converter = converter relaxer = StructOptimizer(model=chgnet) - result = relaxer.relax(structure, verbose=True) + result = relaxer.relax(structure, verbose=True, ase_filter=ase_filter) assert list(result) == ["final_structure", "trajectory"] traj = result["trajectory"]