Skip to content

Commit

Permalink
test_relaxation parametrize ase_filter to test FrechetCellFilter and …
Browse files Browse the repository at this point in the history
…ExpCellFilter
  • Loading branch information
janosh committed Dec 7, 2023
1 parent 0d5f3c3 commit edff162
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
self.device = f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}"

# Move the model to the specified device
self.model = (model or CHGNet.load()).to(self.device).float()
self.model = (model or CHGNet.load()).to(self.device)
self.model.graph_converter.set_isolated_atom_response(on_isolated_atoms)
self.stress_weight = stress_weight
print(f"CHGNet will run on {self.device}")
Expand Down
9 changes: 6 additions & 3 deletions tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
import torch
from ase.filters import ExpCellFilter, Filter, FrechetCellFilter
from pymatgen.core import Structure
from pytest import approx, mark, param

Expand All @@ -15,8 +16,10 @@
structure = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")


@pytest.mark.parametrize("algorithm", ["legacy", "fast"])
def test_relaxation(algorithm: Literal["legacy", "fast"]):
@pytest.mark.parametrize(
"algorithm, ase_filter", [("legacy", FrechetCellFilter), ("fast", ExpCellFilter)]
)
def test_relaxation(algorithm: Literal["legacy", "fast"], ase_filter: Filter) -> None:
chgnet = CHGNet.load()
converter = CrystalGraphConverter(
atom_graph_cutoff=6, bond_graph_cutoff=3, algorithm=algorithm
Expand All @@ -25,7 +28,7 @@ def test_relaxation(algorithm: Literal["legacy", "fast"]):

chgnet.graph_converter = converter
relaxer = StructOptimizer(model=chgnet)
result = relaxer.relax(structure, verbose=True)
result = relaxer.relax(structure, verbose=True, ase_filter=ase_filter)
assert list(result) == ["final_structure", "trajectory"]

traj = result["trajectory"]
Expand Down

0 comments on commit edff162

Please sign in to comment.