Skip to content

Commit

Permalink
Test displacement kwargs for phonons
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Oct 14, 2024
1 parent 0a339c1 commit 6f491ca
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions tests/test_phonons_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,3 +421,54 @@ def test_invalid_traj_input(tmp_path):
)
assert result.exit_code == 1
assert isinstance(result.exception, ValueError)


def test_displacement_kwargs(tmp_path):
"""Test displacment_kwargs can be set."""
file_prefix_1 = tmp_path / "NaCl_1"
file_prefix_2 = tmp_path / "NaCl_2"
displacement_file_1 = tmp_path / "NaCl_1-phonopy.yml"
displacement_file_2 = tmp_path / "NaCl_2-phonopy.yml"

result = runner.invoke(
app,
[
"phonons",
"--struct",
DATA_PATH / "NaCl.cif",
"--no-hdf5",
"--displacement-kwargs",
"{'is_plusminus': True}",
"--file-prefix",
file_prefix_1,
],
)
assert result.exit_code == 0

result = runner.invoke(
app,
[
"phonons",
"--struct",
DATA_PATH / "NaCl.cif",
"--no-hdf5",
"--displacement-kwargs",
"{'is_plusminus': False}",
"--file-prefix",
file_prefix_2,
],
)
assert result.exit_code == 0

# Check parameters
with open(displacement_file_1, encoding="utf8") as file:
params = yaml.safe_load(file)
n_displacments_1 = len(params["displacements"])

assert n_displacments_1 == 4

with open(displacement_file_2, encoding="utf8") as file:
params = yaml.safe_load(file)
n_displacments_2 = len(params["displacements"])

assert n_displacments_2 == 2

0 comments on commit 6f491ca

Please sign in to comment.