Skip to content

Commit

Permalink
Partially prevent ase_calculator from failing silently.
Browse files Browse the repository at this point in the history
  • Loading branch information
orionarcher committed Aug 28, 2024
1 parent 902823d commit e0bf592
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/atomate2/forcefields/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,15 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N
"""
calculator = None

if isinstance(calculator_meta, str) and calculator_meta in map(str, MLFF):
calculator_name = MLFF(calculator_meta.split("MLFF.")[-1])
if isinstance(calculator_meta, str):
valid_calculators = [name.split("MLFF.")[-1] for name in map(str, MLFF)]

if "MLFF." in calculator_meta:
calculator_name = MLFF(calculator_meta.split("MLFF.")[-1])
elif calculator_meta in valid_calculators:
calculator_name = MLFF(calculator_meta)
else:
raise ValueError(f"Could not create calculator from {calculator_meta}.")

if calculator_name == MLFF.CHGNet:
from chgnet.model.dynamics import CHGNetCalculator
Expand Down
12 changes: 12 additions & 0 deletions tests/forcefields/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ def test_ext_load(force_field: str):
assert calc_from_decode.parameters == calc_from_preset.parameters == {}


def test_raises_error():
with pytest.raises(ValueError, match="Could not create"):
ase_calculator("not_a_calculator")


@pytest.mark.parametrize(("force_field"), ["CHGNet", "MACE"])
def test_accepts_stubs(force_field: str):
calculator1 = ase_calculator("MACE")
calculator2 = ase_calculator(str(MLFF.MACE))
assert calculator1.name == calculator2.name


def test_m3gnet_pot():
import matgl
from matgl.ext.ase import PESCalculator
Expand Down

0 comments on commit e0bf592

Please sign in to comment.