Skip to content

Commit

Permalink
Add option to pass Atoms structure directly
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Mar 8, 2024
1 parent b5f83b5 commit e2745ec
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 46 deletions.
11 changes: 6 additions & 5 deletions janus_core/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Set up commandline interface."""

import ast
from pathlib import Path
from typing import Annotated

import typer
Expand Down Expand Up @@ -62,8 +63,8 @@ def parse_dict_class(value: str):

@app.command()
def singlepoint(
structure: Annotated[
str, typer.Option(help="Path to structure to perform calculations")
struct_path: Annotated[
Path, typer.Option("--struct", help="Path of structure to simulate")
],
architecture: Annotated[
str, typer.Option("--arch", help="MLIP architecture to use for calculations")
Expand Down Expand Up @@ -112,8 +113,8 @@ def singlepoint(
Parameters
----------
structure : str
Structure to simulate.
struct_path : Path
Path of structure to simulate.
architecture : Optional[str]
MLIP architecture to use for single point calculations.
Default is "mace_mp".
Expand All @@ -140,7 +141,7 @@ def singlepoint(
raise ValueError("write_kwargs must be a dictionary")

s_point = SinglePoint(
structure=structure,
struct_path=struct_path,
architecture=architecture,
device=device,
read_kwargs=read_kwargs,
Expand Down
73 changes: 54 additions & 19 deletions janus_core/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,15 @@ class SinglePoint:
Parameters
----------
structure : str
Structure to simulate.
struct : Optional[MaybeSequence[Atoms]]
Structure or list of structures to simulate. Required if `struct_path` is None.
Default is None.
struct_path : Optional[str]
Path of structure to simulate. Required if `struct` is None.
Default is None.
struct_name : Optional[str]
Name of structure. Default is "struct" if `struct` is specified, else
inferred from `struct_path`.
architecture : Literal[architectures]
MLIP architecture to use for single point calculations.
Default is "mace_mp".
Expand All @@ -42,14 +49,14 @@ class SinglePoint:
----------
architecture : Architectures
MLIP architecture to use for single point calculations.
structure : str
Path of structure to simulate.
device : Devices
Device to run MLIP model on.
struct : MaybeList[Atoms]
ASE Atoms or list of Atoms structures to simulate.
structname : str
Name of structure from its filename.
device : Devices
Device to run MLIP model on.
struct_path : Optional[str]
Path of structure to simulate.
struct_name : Optional[str]
Name of structure.
Methods
-------
Expand All @@ -63,7 +70,9 @@ class SinglePoint:

def __init__(
self,
structure: str,
struct: Optional[MaybeList[Atoms]] = None,
struct_path: Optional[str] = None,
struct_name: Optional[str] = None,
architecture: Architectures = "mace_mp",
device: Devices = "cpu",
read_kwargs: Optional[ASEReadArgs] = None,
Expand All @@ -74,8 +83,15 @@ def __init__(
Parameters
----------
structure : str
Path of structure to simulate.
struct : Optional[MaybeSequence[Atoms]]
Structure or list of structures to simulate.
Required if `struct_path` is None. Default is None.
struct_path : Optional[str]
Path of structure to simulate. Required if `struct` is None.
Default is None.
struct_name : Optional[str]
Name of structure. Default is "struct" if `struct` is specified, else
inferred from `struct_path`.
architecture : Architectures
MLIP architecture to use for single point calculations.
Default is "mace_mp".
Expand All @@ -86,14 +102,29 @@ def __init__(
calc_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to the selected calculator. Default is {}.
"""
self.architecture = architecture
self.device = device
self.structure = structure
if struct and struct_path:
raise ValueError("Only one of `struct` and `struct_path` may be specified")

if not struct and not struct_path:
raise ValueError("Please specify either `struct` or `struct_path`")

# Read structure and get calculator
read_kwargs = read_kwargs if read_kwargs else {}
calc_kwargs = calc_kwargs if calc_kwargs else {}
self.read_structure(**read_kwargs)

self.architecture = architecture
self.device = device
self.struct_path = struct_path
self.struct_name = struct_name

# Read structure if given as path
if self.struct_path:
self.read_structure(**read_kwargs)
else:
self.struct = struct
if not self.struct_name:
self.struct_name = "struct"

# Configure calculator
self.set_calculator(**calc_kwargs)

def read_structure(self, **kwargs) -> None:
Expand All @@ -108,8 +139,12 @@ def read_structure(self, **kwargs) -> None:
**kwargs
Keyword arguments passed to ase.io.read.
"""
self.struct = read(self.structure, **kwargs)
self.structname = Path(self.structure).stem
if self.struct_path:
self.struct = read(self.struct_path, **kwargs)
if not self.struct_name:
self.struct_name = Path(self.struct_path).stem
else:
raise ValueError("`struct_path` must be defined")

def set_calculator(
self, read_kwargs: Optional[ASEReadArgs] = None, **kwargs
Expand Down Expand Up @@ -278,7 +313,7 @@ def run_single_point(

if write_results:
if "filename" not in write_kwargs:
filename = f"{self.structname}-results.xyz"
filename = f"{self.struct_name}-results.xyz"
write_kwargs["filename"] = Path(".").absolute() / filename
write(images=self.struct, **write_kwargs)

Expand Down
12 changes: 6 additions & 6 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_singlepoint(tmp_path):
app,
[
"singlepoint",
"--structure",
"--struct",
DATA_PATH / "NaCl.cif",
"--write-kwargs",
f"{{'filename': '{str(results_path)}'}}",
Expand All @@ -58,7 +58,7 @@ def test_singlepoint_properties(tmp_path):
app,
[
"singlepoint",
"--structure",
"--struct",
DATA_PATH / "H2O.cif",
"--property",
"energy",
Expand All @@ -76,7 +76,7 @@ def test_singlepoint_properties(tmp_path):
app,
[
"singlepoint",
"--structure",
"--struct",
DATA_PATH / "H2O.cif",
"--property",
"stress",
Expand All @@ -97,7 +97,7 @@ def test_singlepoint_read_kwargs(tmp_path):
app,
[
"singlepoint",
"--structure",
"--struct",
DATA_PATH / "benzene-traj.xyz",
"--read-kwargs",
"{'index': ':'}",
Expand All @@ -121,7 +121,7 @@ def test_singlepoint_calc_kwargs(tmp_path):
app,
[
"singlepoint",
"--structure",
"--struct",
DATA_PATH / "NaCl.cif",
"--calc-kwargs",
"{'default_dtype': 'float32'}",
Expand All @@ -144,7 +144,7 @@ def test_singlepoint_default_write():
app,
[
"singlepoint",
"--structure",
"--struct",
DATA_PATH / "NaCl.cif",
"--property",
"energy",
Expand Down
14 changes: 7 additions & 7 deletions tests/test_geom_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@
]


@pytest.mark.parametrize("architecture, structure, expected, kwargs", test_data)
def test_optimize(architecture, structure, expected, kwargs):
@pytest.mark.parametrize("architecture, struct_path, expected, kwargs", test_data)
def test_optimize(architecture, struct_path, expected, kwargs):
"""Test optimizing geometry using MACE."""
single_point = SinglePoint(
structure=DATA_PATH / structure,
struct_path=DATA_PATH / struct_path,
architecture=architecture,
calc_kwargs={"model_paths": MODEL_PATH},
)
Expand All @@ -58,7 +58,7 @@ def test_saving_struct(tmp_path):
struct_path = tmp_path / "NaCl.xyz"

single_point = SinglePoint(
structure=DATA_PATH / "NaCl.cif",
struct_path=DATA_PATH / "NaCl.cif",
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
)
Expand All @@ -77,7 +77,7 @@ def test_saving_struct(tmp_path):
def test_saving_traj(tmp_path):
"""Test saving optimization trajectory output."""
single_point = SinglePoint(
structure=DATA_PATH / "NaCl.cif",
struct_path=DATA_PATH / "NaCl.cif",
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
)
Expand All @@ -91,7 +91,7 @@ def test_saving_traj(tmp_path):
def test_traj_reformat(tmp_path):
"""Test saving optimization trajectory in different format."""
single_point = SinglePoint(
structure=DATA_PATH / "NaCl.cif",
struct_path=DATA_PATH / "NaCl.cif",
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
)
Expand All @@ -112,7 +112,7 @@ def test_traj_reformat(tmp_path):
def test_missing_traj_kwarg(tmp_path):
"""Test saving optimization trajectory in different format."""
single_point = SinglePoint(
structure=DATA_PATH / "NaCl.cif",
struct_path=DATA_PATH / "NaCl.cif",
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
)
Expand Down
Loading

0 comments on commit e2745ec

Please sign in to comment.