Skip to content

Commit

Permalink
Merge pull request #42 from benrich37/setup_teardown_inputs
Browse files Browse the repository at this point in the history
Changing all "zopen" to just "open" due to an error raised when monty…
  • Loading branch information
benrich37 authored Jan 14, 2025
2 parents 67a44cd + 1185d95 commit 7a6b0ad
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 74 deletions.
3 changes: 1 addition & 2 deletions src/pymatgen/io/jdftx/_output_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from typing import TYPE_CHECKING, Any

import numpy as np
from monty.io import zopen

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -52,7 +51,7 @@ def read_file(file_name: str) -> list[str]:
text: list[str]
list of strings from file
"""
with zopen(file_name, "r") as f:
with open(file_name) as f:
text = f.readlines()
f.close()
return text
Expand Down
7 changes: 3 additions & 4 deletions src/pymatgen/io/jdftx/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ class is written.

import numpy as np
import scipy.constants as const
from monty.io import zopen
from monty.json import MSONable

from pymatgen.core import Structure
Expand Down Expand Up @@ -185,7 +184,7 @@ def write_file(self, filename: PathLike) -> None:
Args:
filename (PathLike): Filename to write to.
"""
with zopen(filename, mode="wt") as file:
with open(filename, mode="w") as file:
file.write(str(self))

@classmethod
Expand All @@ -211,7 +210,7 @@ def from_file(
path_parent = None
if assign_path_parent:
path_parent = Path(filename).parents[0]
with zopen(filename, mode="rt") as file:
with open(filename) as file:
return cls.from_str(
file.read(),
dont_require_structure=dont_require_structure,
Expand Down Expand Up @@ -835,7 +834,7 @@ def write_file(self, filename: PathLike, **kwargs) -> None:
filename (PathLike): Filename to write to.
**kwargs: Kwargs to pass to JDFTXStructure.get_str.
"""
with zopen(filename, mode="wt") as file:
with open(filename, mode="w") as file:
file.write(self.get_str(**kwargs))

def as_dict(self) -> dict:
Expand Down
1 change: 0 additions & 1 deletion tests/files/io/jdftx/tmp/empty.txt

This file was deleted.

13 changes: 6 additions & 7 deletions tests/io/jdftx/outputs_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@
from collections.abc import Callable


from .shared_test_utils import assert_same_value, dump_files_dir
from .shared_test_utils import assert_same_value


def write_mt_file(fname: str, write_dir: Path = dump_files_dir):
filepath = write_dir / fname
with open(filepath, "w") as f:
f.write("if you're reading this yell at ben")
f.close()
# def write_mt_file(fname: str, write_dir: Path = dump_files_dir):
# filepath = write_dir / fname
# with open(filepath, "w") as f:
# f.write("if you're reading this yell at ben")
# f.close()


def object_hasall_known_simple(obj: Any, knowndict: dict):
Expand Down
16 changes: 16 additions & 0 deletions tests/io/jdftx/shared_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from __future__ import annotations

import os
import shutil
from pathlib import Path

import pytest
Expand All @@ -31,3 +33,17 @@ def assert_same_value(testval, knownval):
assert len(testval) == len(knownval)
for i in range(len(testval)):
assert_same_value(testval[i], knownval[i])


@pytest.fixture(scope="module")
def tmp_path():
os.mkdir(dump_files_dir)
yield dump_files_dir
shutil.rmtree(dump_files_dir)


def write_mt_file(tmp_path: Path, fname: str):
filepath = tmp_path / fname
with open(filepath, "w") as f:
f.write("if you're reading this yell at ben")
f.close()
85 changes: 41 additions & 44 deletions tests/io/jdftx/test_jdftxinfile.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import os
import re
from copy import deepcopy
from typing import TYPE_CHECKING, Any
Expand All @@ -21,7 +20,7 @@
ex_infile2_fname,
ex_infile3_fname,
)
from .shared_test_utils import assert_same_value, dump_files_dir
from .shared_test_utils import assert_same_value

if TYPE_CHECKING:
from collections.abc import Callable
Expand All @@ -47,6 +46,42 @@ def test_JDFTXInfile_known_lambda(infile_fname: str, bool_func: Callable[[JDFTXI
assert bool_func(jif)


def JDFTXInfile_self_consistency_tester(jif: JDFTXInfile, tmp_path: PathLike):
"""Create an assortment of JDFTXinfile created from the same data but through different methods, and test that
they are all equivalent through "assert_idential_jif" """
dict_jif = jif.as_dict()
# # Commenting out tests with jif2 due to the list representation asserted
jif2 = JDFTXInfile.get_dict_representation(JDFTXInfile._from_dict(dict_jif))
str_list_jif = jif.get_text_list()
str_jif = "\n".join(str_list_jif)
jif3 = JDFTXInfile.from_str(str_jif)
tmp_fname = tmp_path / "tmp.in"
jif.write_file(tmp_fname)
jif4 = JDFTXInfile.from_file(tmp_fname)
jifs = [jif, jif2, jif3, jif4]
for i in range(len(jifs)):
for j in range(i + 1, len(jifs)):
print(f"{i}, {j}")
assert_idential_jif(jifs[i], jifs[j])


def test_JDFTXInfile_from_dict(tmp_path) -> None:
jif = JDFTXInfile.from_file(ex_infile1_fname)
jif_dict = jif.as_dict()
# Test that dictionary can be modified and that _from_dict will fix set values
jif_dict["elec-cutoff"] = 20
jif2 = JDFTXInfile.from_dict(jif_dict)
JDFTXInfile_self_consistency_tester(jif2, tmp_path)


@pytest.mark.parametrize("infile_fname", [ex_infile3_fname, ex_infile1_fname, ex_infile2_fname])
def test_JDFTXInfile_self_consistency_fromfile(infile_fname: PathLike, tmp_path) -> None:
"""Test that JDFTXInfile objects with different assortments of tags survive inter-conversion done within
"JDFTXInfile_self_consistency_tester"""
jif = JDFTXInfile.from_file(infile_fname)
JDFTXInfile_self_consistency_tester(jif, tmp_path)


@pytest.mark.parametrize(
("val_key", "val"),
[
Expand All @@ -63,22 +98,12 @@ def test_JDFTXInfile_known_lambda(infile_fname: str, bool_func: Callable[[JDFTXI
("elec-cutoff", 20),
],
)
def test_JDFTXInfile_set_values(val_key: str, val: Any):
def test_JDFTXInfile_set_values(val_key: str, val: Any, tmp_path) -> None:
"""Test value setting for various tags"""
jif = JDFTXInfile.from_file(ex_infile1_fname)
jif[val_key] = val
# Test that the JDFTXInfile object is still consistent
JDFTXInfile_self_consistency_tester(jif)


def test_JDFTXInfile_from_dict():
"""Test the from_dict method"""
jif = JDFTXInfile.from_file(ex_infile1_fname)
jif_dict = jif.as_dict()
# Test that dictionary can be modified and that _from_dict will fix set values
jif_dict["elec-cutoff"] = 20
jif2 = JDFTXInfile.from_dict(jif_dict)
JDFTXInfile_self_consistency_tester(jif2)
JDFTXInfile_self_consistency_tester(jif, tmp_path)


@pytest.mark.parametrize(
Expand All @@ -91,15 +116,15 @@ def test_JDFTXInfile_from_dict():
("ion", "Fe 1 1 1 0"),
],
)
def test_JDFTXInfile_append_values(val_key: str, val: Any):
def test_JDFTXInfile_append_values(val_key: str, val: Any, tmp_path) -> None:
"""Test the append_tag method"""
jif = JDFTXInfile.from_file(ex_infile1_fname)
val_old = None if val_key not in jif else deepcopy(jif[val_key])
jif.append_tag(val_key, val)
val_new = jif[val_key]
assert val_old != val_new
# Test that the append_tag does not break the JDFTXInfile object
JDFTXInfile_self_consistency_tester(jif)
JDFTXInfile_self_consistency_tester(jif, tmp_path)


def test_JDFTXInfile_expected_exceptions():
Expand Down Expand Up @@ -217,34 +242,6 @@ def test_JDFTXInfile_knowns_simple(infile_fname: PathLike, knowns: dict):
assert_same_value(jif[key], val)


@pytest.mark.parametrize("infile_fname", [ex_infile3_fname, ex_infile1_fname, ex_infile2_fname])
def test_JDFTXInfile_self_consistency(infile_fname: PathLike):
"""Test that JDFTXInfile objects with different assortments of tags survive inter-conversion done within
"JDFTXInfile_self_consistency_tester"""
jif = JDFTXInfile.from_file(infile_fname)
JDFTXInfile_self_consistency_tester(jif)


def JDFTXInfile_self_consistency_tester(jif: JDFTXInfile):
"""Create an assortment of JDFTXinfile created from the same data but through different methods, and test that
they are all equivalent through "assert_idential_jif" """
dict_jif = jif.as_dict()
# # Commenting out tests with jif2 due to the list representation asserted
jif2 = JDFTXInfile.get_dict_representation(JDFTXInfile._from_dict(dict_jif))
str_list_jif = jif.get_text_list()
str_jif = "\n".join(str_list_jif)
jif3 = JDFTXInfile.from_str(str_jif)
tmp_fname = dump_files_dir / "tmp.in"
jif.write_file(tmp_fname)
jif4 = JDFTXInfile.from_file(tmp_fname)
jifs = [jif, jif2, jif3, jif4]
for i in range(len(jifs)):
for j in range(i + 1, len(jifs)):
print(f"{i}, {j}")
assert_idential_jif(jifs[i], jifs[j])
os.remove(tmp_fname)


def test_jdftxstructure():
"""Test the JDFTXStructure object associated with the JDFTXInfile object"""
jif = JDFTXInfile.from_file(ex_infile2_fname)
Expand Down
29 changes: 13 additions & 16 deletions tests/io/jdftx/test_output_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from pymatgen.io.jdftx.joutstructures import _get_joutstructures_start_idx
from pymatgen.io.jdftx.outputs import _find_jdftx_out_file

from .outputs_test_utils import noeigstats_outfile_path, write_mt_file
from .shared_test_utils import dump_files_dir
from .outputs_test_utils import noeigstats_outfile_path
from .shared_test_utils import write_mt_file


def test_get_start_lines():
Expand Down Expand Up @@ -66,26 +66,23 @@ def test_get_joutstructures_start_idx():
assert _get_joutstructures_start_idx(["ken", "ken"], out_slice_start_flag=start_flag) is None


def test_find_jdftx_out_file():
def test_find_jdftx_out_file(tmp_path):
"""Test the _find_jdftx_out_file function.
This function is used to find the JDFTx out file in a directory.
It tests the behavior to make sure the correct errors are raised on directories without and out file
and directories with multiple out files. And out file must match "*.out" or "out" exactly.
"""
with pytest.raises(FileNotFoundError, match="No JDFTx out file found in directory."):
_find_jdftx_out_file(dump_files_dir)
write_mt_file("test.out")
assert _find_jdftx_out_file(dump_files_dir) == dump_files_dir / "test.out"
_find_jdftx_out_file(tmp_path)
write_mt_file(tmp_path, "test.out")
assert _find_jdftx_out_file(tmp_path) == tmp_path / "test.out"
# out file has to match "*.out" or "out" exactly
write_mt_file("tinyout")
assert _find_jdftx_out_file(dump_files_dir) == dump_files_dir / "test.out"
remove(_find_jdftx_out_file(dump_files_dir))
write_mt_file("out")
assert _find_jdftx_out_file(dump_files_dir) == dump_files_dir / "out"
write_mt_file("tinyout.out")
write_mt_file(tmp_path, "tinyout")
assert _find_jdftx_out_file(tmp_path) == tmp_path / "test.out"
remove(_find_jdftx_out_file(tmp_path))
write_mt_file(tmp_path, "out")
assert _find_jdftx_out_file(tmp_path) == tmp_path / "out"
write_mt_file(tmp_path, "tinyout.out")
with pytest.raises(FileNotFoundError, match="Multiple JDFTx out files found in directory."):
_find_jdftx_out_file(dump_files_dir)
# remove tmp files
for remaining in dump_files_dir.glob("*out"):
remove(remaining)
_find_jdftx_out_file(tmp_path)

0 comments on commit 7a6b0ad

Please sign in to comment.