diff --git a/tests/test_phonons_cli.py b/tests/test_phonons_cli.py index c3cbe4df..06a870cf 100644 --- a/tests/test_phonons_cli.py +++ b/tests/test_phonons_cli.py @@ -421,3 +421,54 @@ def test_invalid_traj_input(tmp_path): ) assert result.exit_code == 1 assert isinstance(result.exception, ValueError) + + +def test_displacement_kwargs(tmp_path): + """Test displacment_kwargs can be set.""" + file_prefix_1 = tmp_path / "NaCl_1" + file_prefix_2 = tmp_path / "NaCl_2" + displacement_file_1 = tmp_path / "NaCl_1-phonopy.yml" + displacement_file_2 = tmp_path / "NaCl_2-phonopy.yml" + + result = runner.invoke( + app, + [ + "phonons", + "--struct", + DATA_PATH / "NaCl.cif", + "--no-hdf5", + "--displacement-kwargs", + "{'is_plusminus': True}", + "--file-prefix", + file_prefix_1, + ], + ) + assert result.exit_code == 0 + + result = runner.invoke( + app, + [ + "phonons", + "--struct", + DATA_PATH / "NaCl.cif", + "--no-hdf5", + "--displacement-kwargs", + "{'is_plusminus': False}", + "--file-prefix", + file_prefix_2, + ], + ) + assert result.exit_code == 0 + + # Check parameters + with open(displacement_file_1, encoding="utf8") as file: + params = yaml.safe_load(file) + n_displacments_1 = len(params["displacements"]) + + assert n_displacments_1 == 4 + + with open(displacement_file_2, encoding="utf8") as file: + params = yaml.safe_load(file) + n_displacments_2 = len(params["displacements"]) + + assert n_displacments_2 == 2