diff --git a/pymatgen/analysis/cost.py b/pymatgen/analysis/cost.py index a9b0698eabb..bf372d5d813 100644 --- a/pymatgen/analysis/cost.py +++ b/pymatgen/analysis/cost.py @@ -16,7 +16,6 @@ import scipy.constants as const from monty.design_patterns import singleton -from monty.string import unicode2str from pymatgen.analysis.phase_diagram import PDEntry, PhaseDiagram from pymatgen.core.composition import Composition @@ -94,7 +93,7 @@ def __init__(self, filename): self._chemsys_entries = defaultdict(list) filename = os.path.join(os.path.dirname(__file__), filename) with open(filename) as f: - reader = csv.reader(f, quotechar=unicode2str("|")) + reader = csv.reader(f, quotechar="|") for row in reader: comp = Composition(row[0]) cost_per_mol = float(row[1]) * comp.weight.to("kg") * const.N_A diff --git a/pymatgen/core/xcfunc.py b/pymatgen/core/xcfunc.py index 61f03ecb77c..76e742d4725 100644 --- a/pymatgen/core/xcfunc.py +++ b/pymatgen/core/xcfunc.py @@ -6,7 +6,6 @@ from monty.functools import lazy_property from monty.json import MSONable -from monty.string import is_string from pymatgen.core.libxcfunc import LibxcFunc @@ -122,7 +121,7 @@ def asxc(cls, obj): """Convert object into Xcfunc.""" if isinstance(obj, cls): return obj - if is_string(obj): + if isinstance(obj, str): return cls.from_name(obj) raise TypeError(f"Don't know how to convert <{type(obj)}:{obj}> to Xcfunc") diff --git a/pymatgen/entries/entry_tools.py b/pymatgen/entries/entry_tools.py index 1b928ff773b..aca414d3453 100644 --- a/pymatgen/entries/entry_tools.py +++ b/pymatgen/entries/entry_tools.py @@ -14,7 +14,6 @@ from typing import TYPE_CHECKING, Literal from monty.json import MontyDecoder, MontyEncoder, MSONable -from monty.string import unicode2str from pymatgen.analysis.phase_diagram import PDEntry from pymatgen.analysis.structure_matcher import SpeciesComparator, StructureMatcher @@ -302,8 +301,8 @@ def to_csv(self, filename: str, latexify_names: bool = False) -> None: with open(filename, "w") as f: writer = csv.writer( f, - delimiter=unicode2str(","), - quotechar=unicode2str('"'), + delimiter=",", + quotechar='"', quoting=csv.QUOTE_MINIMAL, ) writer.writerow(["Name"] + [el.symbol for el in elements] + ["Energy"]) @@ -326,8 +325,8 @@ def from_csv(cls, filename: str): with open(filename, encoding="utf-8") as f: reader = csv.reader( f, - delimiter=unicode2str(","), - quotechar=unicode2str('"'), + delimiter=",", + quotechar='"', quoting=csv.QUOTE_MINIMAL, ) entries = [] diff --git a/pymatgen/io/abinit/abitimer.py b/pymatgen/io/abinit/abitimer.py index 3bb9c05c2f7..708e7a6e67d 100644 --- a/pymatgen/io/abinit/abitimer.py +++ b/pymatgen/io/abinit/abitimer.py @@ -12,7 +12,6 @@ import matplotlib.pyplot as plt import numpy as np -from monty.string import is_string, list_strings from pymatgen.io.core import ParseError from pymatgen.util.plotting import add_fig_kwargs, get_ax_fig @@ -107,7 +106,8 @@ def parse(self, filenames): Return: list of successfully read files. """ - filenames = list_strings(filenames) + if isinstance(filenames, str): + filenames = [filenames] read_ok = [] for fname in filenames: @@ -667,16 +667,16 @@ def get_section(self, section_name): def to_csv(self, fileobj=sys.stdout): """Write data on file fileobj using CSV format.""" - openclose = is_string(fileobj) + is_str = isinstance(fileobj, str) - if openclose: + if is_str: fileobj = open(fileobj, "w") # noqa: SIM115 for idx, section in enumerate(self.sections): fileobj.write(section.to_csvline(with_header=(idx == 0))) fileobj.flush() - if openclose: + if is_str: fileobj.close() def to_table(self, sort_key="wall_time", stop=None): @@ -718,7 +718,7 @@ def get_dataframe(self, sort_key="wall_time", **kwargs): def get_values(self, keys): """Return a list of values associated to a particular list of keys.""" - if is_string(keys): + if isinstance(keys, str): return [s.__dict__[keys] for s in self.sections] values = [] for k in keys: diff --git a/pymatgen/io/abinit/inputs.py b/pymatgen/io/abinit/inputs.py index ee836ecc91f..f0306095f38 100644 --- a/pymatgen/io/abinit/inputs.py +++ b/pymatgen/io/abinit/inputs.py @@ -18,7 +18,6 @@ import numpy as np from monty.collections import AttrDict from monty.json import MSONable -from monty.string import is_string, list_strings from pymatgen.core.structure import Structure from pymatgen.io.abinit import abiobjects as aobj @@ -99,9 +98,7 @@ # Default values used if user does not specify them -_DEFAULTS = { - "kppa": 1000, -} +_DEFAULTS = {"kppa": 1000} def as_structure(obj): @@ -115,7 +112,7 @@ def as_structure(obj): if isinstance(obj, Structure): return obj - if is_string(obj): + if isinstance(obj, str): return Structure.from_file(obj) if isinstance(obj, Mapping): @@ -151,7 +148,7 @@ def from_object(cls, obj): """ if isinstance(obj, cls): return obj - if is_string(obj): + if isinstance(obj, str): return cls(obj[0].upper()) raise TypeError(f"The object provided is not handled: type {type(obj).__name__}") @@ -676,8 +673,10 @@ def remove_vars(self, keys, strict=True): keys: string or list of strings with variable names. strict: If True, KeyError is raised if at least one variable is not present. """ + if isinstance(keys, str): + keys = [keys] removed = {} - for key in list_strings(keys): + for key in keys: if strict and key not in self: raise KeyError(f"{key=} not in self:\n {list(self)}") if key in self: @@ -710,7 +709,7 @@ class BasicAbinitInput(AbstractInput, MSONable): def __init__( self, structure, - pseudos, + pseudos: str | list[str] | list[Pseudo] | PseudoTable, pseudo_dir=None, comment=None, abi_args=None, @@ -745,11 +744,14 @@ def __init__( self._vars = dict(args) self.set_structure(structure) + if isinstance(pseudos, str): + pseudos = [pseudos] + if pseudo_dir is not None: pseudo_dir = os.path.abspath(pseudo_dir) if not os.path.exists(pseudo_dir): raise self.Error(f"Directory {pseudo_dir} does not exist") - pseudos = [os.path.join(pseudo_dir, p) for p in list_strings(pseudos)] + pseudos = [os.path.join(pseudo_dir, p) for p in pseudos] try: self._pseudos = PseudoTable.as_table(pseudos).get_pseudos_for_structure(self.structure) @@ -1092,8 +1094,10 @@ def __init__(self, structure: Structure, pseudos, pseudo_dir="", ndtset=1): else: # String(s) + if isinstance(pseudos, str): + pseudos = [pseudos] pseudo_dir = os.path.abspath(pseudo_dir) - pseudo_paths = [os.path.join(pseudo_dir, p) for p in list_strings(pseudos)] + pseudo_paths = [os.path.join(pseudo_dir, p) for p in pseudos] missing = [p for p in pseudo_paths if not os.path.exists(p)] if missing: diff --git a/pymatgen/io/abinit/pseudos.py b/pymatgen/io/abinit/pseudos.py index c0605c2f0d8..649603928ad 100644 --- a/pymatgen/io/abinit/pseudos.py +++ b/pymatgen/io/abinit/pseudos.py @@ -19,7 +19,6 @@ from monty.itertools import iterator_from_slice from monty.json import MontyDecoder, MSONable from monty.os.path import find_exts -from monty.string import is_string, list_strings from tabulate import tabulate from pymatgen.core.periodic_table import Element @@ -612,7 +611,7 @@ def _dict_from_lines(lines, key_nums, sep=None): Raises: ValueError if parsing fails. """ - if is_string(lines): + if isinstance(lines, str): lines = [lines] if not isinstance(key_nums, collections.abc.Iterable): @@ -1611,8 +1610,8 @@ def __init__(self, pseudos: Sequence[Pseudo]) -> None: if not isinstance(pseudos, collections.abc.Iterable): pseudos = [pseudos] - if len(pseudos) and is_string(pseudos[0]): - pseudos = list_strings(pseudos) + if isinstance(pseudos, str): + pseudos = [pseudos] self._pseudos_with_z = defaultdict(list) @@ -1772,7 +1771,9 @@ def select_symbols(self, symbols, ret_list=False): Prepend the symbol string with "-", to exclude pseudos. ret_list: if True a list of pseudos is returned instead of a PseudoTable """ - symbols = list_strings(symbols) + if isinstance(symbols, str): + symbols = [symbols] + exclude = symbols[0].startswith("-") if exclude: diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index 40a903664dd..55208a33630 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -18,7 +18,6 @@ import numpy as np from monty.io import zopen from monty.serialization import loadfn -from monty.string import remove_non_ascii from pymatgen.core.composition import Composition from pymatgen.core.lattice import Lattice @@ -137,7 +136,7 @@ def _process_string(cls, string): # remove empty lines string = re.sub(r"^\s*\n", "", string, flags=re.MULTILINE) # remove non_ascii - string = remove_non_ascii(string) + string = string.encode("ascii", "ignore").decode("ascii") # since line breaks in .cif files are mostly meaningless, # break up into a stream of tokens to parse, rejoining multiline # strings (between semicolons) diff --git a/pymatgen/io/vasp/sets.py b/pymatgen/io/vasp/sets.py index 4f060d40b5e..e9620b4ad28 100644 --- a/pymatgen/io/vasp/sets.py +++ b/pymatgen/io/vasp/sets.py @@ -974,7 +974,7 @@ def __init__( updates: dict[str, float] = {} # select the KSPACING and smearing parameters based on the bandgap - if self.bandgap < 1e-4: + if self.bandgap < bandgap_tol: updates.update(KSPACING=0.22, SIGMA=0.2, ISMEAR=2) else: rmin = max(1.5, 25.22 - 2.87 * bandgap) # Eq. 25 diff --git a/pymatgen/util/provenance.py b/pymatgen/util/provenance.py index 75a3edb95d8..fb4e457d74f 100644 --- a/pymatgen/util/provenance.py +++ b/pymatgen/util/provenance.py @@ -10,7 +10,6 @@ from io import StringIO from monty.json import MontyDecoder, MontyEncoder -from monty.string import remove_non_ascii try: from pybtex import errors @@ -29,10 +28,10 @@ __credits__ = "Dan Gunter" -MAX_HNODE_SIZE = 64000 # maximum size (bytes) of SNL HistoryNode -MAX_DATA_SIZE = 256000 # maximum size (bytes) of SNL data field +MAX_HNODE_SIZE = 64_000 # maximum size (bytes) of SNL HistoryNode +MAX_DATA_SIZE = 256_000 # maximum size (bytes) of SNL data field MAX_HNODES = 100 # maximum number of HistoryNodes in SNL file -MAX_BIBTEX_CHARS = 20000 # maximum number of characters for BibTeX reference +MAX_BIBTEX_CHARS = 20_000 # maximum number of characters for BibTeX reference def is_valid_bibtex(reference: str) -> bool: @@ -46,7 +45,7 @@ def is_valid_bibtex(reference: str) -> bool: """ # str is necessary since pybtex seems to have an issue with unicode. The # filter expression removes all non-ASCII characters. - sio = StringIO(remove_non_ascii(reference)) + sio = StringIO(reference.encode("ascii", "ignore").decode("ascii")) parser = bibtex.Parser() errors.set_strict_mode(enable=False) bib_data = parser.parse_stream(sio) diff --git a/tests/core/test_periodic_table.py b/tests/core/test_periodic_table.py index ceeaf8b9ec3..6aba53b1e99 100644 --- a/tests/core/test_periodic_table.py +++ b/tests/core/test_periodic_table.py @@ -1,7 +1,6 @@ from __future__ import annotations import math -import os import pickle import unittest from copy import deepcopy @@ -405,13 +404,12 @@ def test_pickle(self): cs = Species("Cs1+") cl = Species("Cl1+") - with open("cscl.pickle", "wb") as file: + with open(f"{self.tmp_path}/cscl.pickle", "wb") as file: pickle.dump((cs, cl), file) - with open("cscl.pickle", "rb") as file: + with open(f"{self.tmp_path}/cscl.pickle", "rb") as file: tup = pickle.load(file) assert tup == (cs, cl) - os.remove("cscl.pickle") def test_get_crystal_field_spin(self): assert Species("Fe", 2).get_crystal_field_spin() == 4 diff --git a/tests/core/test_structure.py b/tests/core/test_structure.py index 0b84737e1cc..f81df3a6280 100644 --- a/tests/core/test_structure.py +++ b/tests/core/test_structure.py @@ -2085,9 +2085,8 @@ def test_to_from_file_string(self): assert m == self.mol assert isinstance(m, Molecule) - self.mol.to(filename="CH4_testing.xyz") - assert os.path.isfile("CH4_testing.xyz") - os.remove("CH4_testing.xyz") + self.mol.to(filename=f"{self.tmp_path}/CH4_testing.xyz") + assert os.path.isfile(f"{self.tmp_path}/CH4_testing.xyz") def test_extract_cluster(self): species = self.mol.species * 2 diff --git a/tests/core/test_trajectory.py b/tests/core/test_trajectory.py index 2f7291a25ab..ed793f99366 100644 --- a/tests/core/test_trajectory.py +++ b/tests/core/test_trajectory.py @@ -1,7 +1,6 @@ from __future__ import annotations import copy -import os import numpy as np from numpy.testing import assert_allclose @@ -448,12 +447,11 @@ def test_variable_lattice(self): assert all(np.allclose(struct.lattice.matrix, structures[i].lattice.matrix) for i, struct in enumerate(traj)) # Check if the file is written correctly when lattice is not constant. - traj.write_Xdatcar(filename="traj_test_XDATCAR") + traj.write_Xdatcar(filename=f"{self.tmp_path}/traj_test_XDATCAR") - # Load trajectory from written xdatcar and compare to original - written_traj = Trajectory.from_file("traj_test_XDATCAR", constant_lattice=False) + # Load trajectory from written XDATCAR and compare to original + written_traj = Trajectory.from_file(f"{self.tmp_path}/traj_test_XDATCAR", constant_lattice=False) self._check_traj_equality(traj, written_traj) - os.remove("traj_test_XDATCAR") def test_as_from_dict(self): d = self.traj.as_dict() @@ -465,9 +463,8 @@ def test_as_from_dict(self): assert isinstance(traj, Trajectory) def test_xdatcar_write(self): - self.traj.write_Xdatcar(filename="traj_test_XDATCAR") + self.traj.write_Xdatcar(filename=f"{self.tmp_path}/traj_test_XDATCAR") - # Load trajectory from written xdatcar and compare to original - written_traj = Trajectory.from_file("traj_test_XDATCAR") + # Load trajectory from written XDATCAR and compare to original + written_traj = Trajectory.from_file(f"{self.tmp_path}/traj_test_XDATCAR") self._check_traj_equality(self.traj, written_traj) - os.remove("traj_test_XDATCAR") diff --git a/tests/electronic_structure/test_plotter.py b/tests/electronic_structure/test_plotter.py index 60921b31d6a..ae512f123f0 100644 --- a/tests/electronic_structure/test_plotter.py +++ b/tests/electronic_structure/test_plotter.py @@ -89,7 +89,7 @@ def get_plot_attributes(ax: plt.Axes): return {"xaxis_limits": list(ax.get_xlim()), "yaxis_limits": list(ax.get_ylim())} -class TestBSPlotter(unittest.TestCase): +class TestBSPlotter(PymatgenTest): def setUp(self): with open(f"{TEST_FILES_DIR}/CaO_2605_bandstructure.json") as f: d = json.loads(f.read()) @@ -171,9 +171,8 @@ def test_get_plot(self): assert ax.get_ylim() == (-4.0, 7.6348), "wrong ylim" ax = self.plotter.get_plot(smooth=True) ax = self.plotter.get_plot(vbm_cbm_marker=True) - self.plotter.save_plot("bsplot.png") - assert os.path.isfile("bsplot.png") - os.remove("bsplot.png") + self.plotter.save_plot(f"{self.tmp_path}/bsplot.png") + assert os.path.isfile(f"{self.tmp_path}/bsplot.png") plt.close("all") # test plotter with 2 bandstructures @@ -183,9 +182,8 @@ def test_get_plot(self): ax = self.plotter_multi.get_plot(zero_to_efermi=False) assert ax.get_ylim() == (-15.2379, 12.67141266), "wrong ylim" ax = self.plotter_multi.get_plot(smooth=True) - self.plotter_multi.save_plot("bsplot.png") - assert os.path.isfile("bsplot.png") - os.remove("bsplot.png") + self.plotter_multi.save_plot(f"{self.tmp_path}/bsplot.png") + assert os.path.isfile(f"{self.tmp_path}/bsplot.png") plt.close("all") @@ -489,14 +487,14 @@ def test_get_plot(self): assert ax_cohp.lines[1].get_linestyle() == "--" for label in legend_labels: assert label in self.cohp_plot._cohps - linesindex = legend_labels.index("1") - linestyles = {Spin.up: "-", Spin.down: "--"} + lines_index = legend_labels.index("1") + line_styles = {Spin.up: "-", Spin.down: "--"} cohp_fe_fe = self.cohp.all_cohps["1"] for s, spin in enumerate([Spin.up, Spin.down]): - lines = ax_cohp.lines[2 * linesindex + s] + lines = ax_cohp.lines[2 * lines_index + s] assert_allclose(lines.get_xdata(), -cohp_fe_fe.cohp[spin]) assert_allclose(lines.get_ydata(), self.cohp.energies) - assert lines.get_linestyle() == linestyles[spin] + assert lines.get_linestyle() == line_styles[spin] plt.close() ax_cohp = self.cohp_plot.get_plot(invert_axes=False, plot_negative=False) @@ -504,7 +502,7 @@ def test_get_plot(self): assert ax_cohp.get_xlabel() == "$E$ (eV)" assert ax_cohp.get_ylabel() == "COHP" for s, spin in enumerate([Spin.up, Spin.down]): - lines = ax_cohp.lines[2 * linesindex + s] + lines = ax_cohp.lines[2 * lines_index + s] assert_allclose(lines.get_xdata(), self.cohp.energies) assert_allclose(lines.get_ydata(), cohp_fe_fe.cohp[spin]) plt.close() @@ -513,7 +511,7 @@ def test_get_plot(self): assert ax_cohp.get_xlabel() == "-ICOHP (eV)" for s, spin in enumerate([Spin.up, Spin.down]): - lines = ax_cohp.lines[2 * linesindex + s] + lines = ax_cohp.lines[2 * lines_index + s] assert_allclose(lines.get_xdata(), -cohp_fe_fe.icohp[spin]) coop_dict = {"Bi5-Bi6": self.coop.all_cohps["10"]} @@ -526,14 +524,13 @@ def test_get_plot(self): coop_bi_bi = self.coop.all_cohps["10"].cohp[Spin.up] assert_allclose(lines_coop.get_xdata(), coop_bi_bi) - # Cleanup. + # cleanup plt.close("all") def test_save_plot(self): self.cohp_plot.add_cohp_dict(self.cohp.all_cohps) ax = self.cohp_plot.get_plot() assert isinstance(ax, plt.Axes) - self.cohp_plot.save_plot("cohpplot.png") - assert os.path.isfile("cohpplot.png") - os.remove("cohpplot.png") + self.cohp_plot.save_plot(f"{self.tmp_path}/cohpplot.png") + assert os.path.isfile(f"{self.tmp_path}/cohpplot.png") plt.close("all") diff --git a/tests/entries/test_entry_tools.py b/tests/entries/test_entry_tools.py index d12a3ae505c..457f0d09b14 100644 --- a/tests/entries/test_entry_tools.py +++ b/tests/entries/test_entry_tools.py @@ -1,18 +1,15 @@ from __future__ import annotations -import os -import unittest - import pytest from monty.serialization import dumpfn, loadfn from pymatgen.core.periodic_table import Element from pymatgen.entries.computed_entries import ComputedEntry from pymatgen.entries.entry_tools import EntrySet, group_entries_by_composition, group_entries_by_structure -from pymatgen.util.testing import TEST_FILES_DIR +from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -class TestFunc(unittest.TestCase): +class TestFunc(PymatgenTest): def test_group_entries_by_structure(self): entries = loadfn(f"{TEST_FILES_DIR}/TiO2_entries.json") groups = group_entries_by_structure(entries) @@ -42,7 +39,7 @@ def test_group_entries_by_composition(self): assert g == sorted(g, key=lambda e: e.energy_per_atom) -class TestEntrySet(unittest.TestCase): +class TestEntrySet(PymatgenTest): def setUp(self): entries = loadfn(f"{TEST_FILES_DIR}/Li-Fe-P-O_entries.json") self.entry_set = EntrySet(entries) @@ -64,7 +61,6 @@ def test_remove_non_ground_states(self): assert len(self.entry_set) < length def test_as_dict(self): - dumpfn(self.entry_set, "temp_entry_set.json") - entry_set = loadfn("temp_entry_set.json") + dumpfn(self.entry_set, f"{self.tmp_path}/temp_entry_set.json") + entry_set = loadfn(f"{self.tmp_path}/temp_entry_set.json") assert len(entry_set) == len(self.entry_set) - os.remove("temp_entry_set.json") diff --git a/tests/io/feff/test_inputs.py b/tests/io/feff/test_inputs.py index dc48d00ca51..205a8307144 100644 --- a/tests/io/feff/test_inputs.py +++ b/tests/io/feff/test_inputs.py @@ -116,8 +116,8 @@ def test_get_str(self): header = Header.from_str(header_string) struct = header.struct central_atom = "O" - a = Atoms(struct, central_atom, radius=10.0) - atoms = str(a) + atoms = Atoms(struct, central_atom, radius=10.0) + atoms = str(atoms) assert atoms.splitlines()[3].split()[4] == central_atom, "failed to create ATOMS string" def test_as_dict_and_from_dict(self): diff --git a/tests/io/feff/test_sets.py b/tests/io/feff/test_sets.py index 980dd333ece..8e437e0eca6 100644 --- a/tests/io/feff/test_sets.py +++ b/tests/io/feff/test_sets.py @@ -141,13 +141,10 @@ def test_reciprocal_tags_and_input(self): all_input = elnes.all_input() assert "ATOMS" not in all_input assert "POTENTIALS" not in all_input - elnes.write_input() + elnes.write_input(output_dir=self.tmp_path) structure = Structure.from_file("Co2O2.cif") assert self.structure.matches(structure) - os.remove("HEADER") - os.remove("PARAMETERS") - os.remove("feff.inp") - os.remove("Co2O2.cif") + assert {*os.listdir()} == {"Co2O2.cif", "HEADER", "PARAMETERS", "feff.inp"} def test_small_system_exafs(self): exafs_settings = MPEXAFSSet(self.absorbing_atom, self.structure) diff --git a/tests/io/lobster/test_inputs.py b/tests/io/lobster/test_inputs.py index 20c51cbbd99..db9063e2762 100644 --- a/tests/io/lobster/test_inputs.py +++ b/tests/io/lobster/test_inputs.py @@ -2302,15 +2302,13 @@ def test_write_file(self): filename=f"{TEST_FILES_DIR}/cohp/LCAOWaveFunctionAfterLSO1PlotOfSpin1Kpoint1band1.gz", structure=Structure.from_file(f"{TEST_FILES_DIR}/cohp/POSCAR_O.gz"), ) - wave1.write_file(filename="wavecar_test.vasp", part="real") + wave1.write_file(filename=f"{self.tmp_path}/wavecar_test.vasp", part="real") assert os.path.isfile("wavecar_test.vasp") - wave1.write_file(filename="wavecar_test.vasp", part="imaginary") + wave1.write_file(filename=f"{self.tmp_path}/wavecar_test.vasp", part="imaginary") assert os.path.isfile("wavecar_test.vasp") - os.remove("wavecar_test.vasp") - wave1.write_file(filename="density.vasp", part="density") + wave1.write_file(filename=f"{self.tmp_path}/density.vasp", part="density") assert os.path.isfile("density.vasp") - os.remove("density.vasp") class TestSitePotentials(PymatgenTest): diff --git a/tests/io/qchem/test_sets.py b/tests/io/qchem/test_sets.py index 0b3a17b7985..cfcdf4221e7 100644 --- a/tests/io/qchem/test_sets.py +++ b/tests/io/qchem/test_sets.py @@ -188,8 +188,8 @@ def test_pcm_write(self): pcm_dielectric=10.0, max_scf_cycles=35, ) - dict_set.write("mol.qin") - test_dict = QCInput.from_file("mol.qin").as_dict() + dict_set.write(f"{self.tmp_path}/mol.qin") + test_dict = QCInput.from_file(f"{self.tmp_path}/mol.qin").as_dict() rem = { "job_type": "opt", "basis": "6-31G*", @@ -216,7 +216,6 @@ def test_pcm_write(self): qc_input = QCInput(molecule=test_molecule, rem=rem, pcm=pcm, solvent={"dielectric": 10.0}) for k, v in qc_input.as_dict().items(): assert v == test_dict[k] - os.remove("mol.qin") def test_isosvp_write(self): """Also tests overwrite_inputs with a RHOISO value.""" @@ -231,8 +230,8 @@ def test_isosvp_write(self): max_scf_cycles=35, overwrite_inputs={"svp": {"RHOISO": 0.0009}}, ) - dict_set.write("mol.qin") - test_dict = QCInput.from_file("mol.qin").as_dict() + dict_set.write(f"{self.tmp_path}/mol.qin") + test_dict = QCInput.from_file(f"{self.tmp_path}/mol.qin").as_dict() rem = { "job_type": "opt", "basis": "def2-SVPD", @@ -258,7 +257,6 @@ def test_isosvp_write(self): ) for k, v in qc_input.as_dict().items(): assert v == test_dict[k] - os.remove("mol.qin") def test_smd_write(self): test_molecule = QCInput.from_file(f"{TEST_DIR}/pcm.qin").molecule @@ -271,8 +269,8 @@ def test_smd_write(self): smd_solvent="water", max_scf_cycles=35, ) - dict_set.write("mol.qin") - test_dict = QCInput.from_file("mol.qin").as_dict() + dict_set.write(f"{self.tmp_path}/mol.qin") + test_dict = QCInput.from_file(f"{self.tmp_path}/mol.qin").as_dict() rem = { "job_type": "opt", "basis": "6-31G*", @@ -293,7 +291,6 @@ def test_smd_write(self): qc_input = QCInput(molecule=test_molecule, rem=rem, smx={"solvent": "water"}) for k, v in qc_input.as_dict().items(): assert v == test_dict[k] - os.remove("mol.qin") def test_cmirs_write(self): """Also tests overwrite_inputs with a RHOISO value.""" @@ -308,8 +305,8 @@ def test_cmirs_write(self): max_scf_cycles=35, overwrite_inputs={"svp": {"RHOISO": 0.0005}}, ) - dict_set.write("mol.qin") - test_dict = QCInput.from_file("mol.qin").as_dict() + dict_set.write(f"{self.tmp_path}/mol.qin") + test_dict = QCInput.from_file(f"{self.tmp_path}/mol.qin").as_dict() rem = { "job_type": "opt", "basis": "def2-SVPD", @@ -345,7 +342,6 @@ def test_cmirs_write(self): ) for k, v in qc_input.as_dict().items(): assert v == test_dict[k] - os.remove("mol.qin") def test_custom_smd_write(self): test_molecule = QCInput.from_file(f"{TEST_DIR}/pcm.qin").molecule @@ -359,8 +355,8 @@ def test_custom_smd_write(self): custom_smd="90.00,1.415,0.00,0.735,20.2,0.00,0.00", max_scf_cycles=35, ) - dict_set.write("mol.qin") - test_dict = QCInput.from_file("mol.qin").as_dict() + dict_set.write(f"{self.tmp_path}/mol.qin") + test_dict = QCInput.from_file(f"{self.tmp_path}/mol.qin").as_dict() rem = { "job_type": "opt", "basis": "6-31G*", @@ -381,7 +377,6 @@ def test_custom_smd_write(self): qc_input = QCInput(molecule=test_molecule, rem=rem, smx={"solvent": "other"}) for k, v in qc_input.as_dict().items(): assert v == test_dict[k] - os.remove("mol.qin") with open("solvent_data") as sd: lines = sd.readlines() assert lines[0] == "90.00,1.415,0.00,0.735,20.2,0.00,0.00" diff --git a/tests/io/test_shengbte.py b/tests/io/test_shengbte.py index 209165c1773..da7109736c2 100644 --- a/tests/io/test_shengbte.py +++ b/tests/io/test_shengbte.py @@ -1,26 +1,22 @@ from __future__ import annotations import os -import unittest +import pytest from numpy.testing import assert_array_equal from pymatgen.io.shengbte import Control from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -try: - import f90nml -except ImportError: - f90nml = None - -test_dir = f"{TEST_FILES_DIR}/shengbte" +f90nml = pytest.importorskip("f90nml") +TEST_DIR = f"{TEST_FILES_DIR}/shengbte" module_dir = os.path.dirname(os.path.abspath(__file__)) class TestShengBTE(PymatgenTest): def setUp(self): - self.filename = f"{test_dir}/CONTROL-CSLD_Si" + self.filename = f"{TEST_DIR}/CONTROL-CSLD_Si" self.test_dict = { "nelements": 1, "natoms": 2, @@ -44,7 +40,6 @@ def setUp(self): "nanowires": False, } - @unittest.skipIf(f90nml is None, "No f90nml") def test_from_file(self): io = Control.from_file(self.filename) assert io["nelements"] == 1 @@ -72,31 +67,23 @@ def test_from_file(self): assert not io["nonanalytic"] assert not io["nanowires"] - if os.path.exists(f"{test_dir}/test_control"): - os.remove(f"{test_dir}/test_control") - io.to_file(filename=f"{test_dir}/test_control") + io.to_file(filename=f"{self.tmp_path}/test_control") - with open(f"{test_dir}/test_control") as file: + with open(f"{self.tmp_path}/test_control") as file: test_string = file.read() - with open(f"{test_dir}/CONTROL-CSLD_Si") as reference_file: + with open(f"{TEST_DIR}/CONTROL-CSLD_Si") as reference_file: reference_string = reference_file.read() assert test_string == reference_string - os.remove(f"{test_dir}/test_control") - @unittest.skipIf(f90nml is None, "No f90nml") def test_from_dict(self): io = Control.from_dict(self.test_dict) - if os.path.exists(f"{test_dir}/test_control"): - os.remove(f"{test_dir}/test_control") - io.to_file(filename=f"{test_dir}/test_control") - with open(f"{test_dir}/test_control") as file: + io.to_file(filename=f"{self.tmp_path}/test_control") + with open(f"{self.tmp_path}/test_control") as file: test_string = file.read() - with open(f"{test_dir}/CONTROL-CSLD_Si") as reference_file: + with open(f"{TEST_DIR}/CONTROL-CSLD_Si") as reference_file: reference_string = reference_file.read() assert test_string == reference_string - os.remove(f"{test_dir}/test_control") - @unittest.skipIf(f90nml is None, "No f90nml") def test_msonable_implementation(self): # tests as dict and from dict methods ctrl_from_file = Control.from_file(self.filename) diff --git a/tests/io/vasp/test_outputs.py b/tests/io/vasp/test_outputs.py index 5db25cc0d85..a372ec94078 100644 --- a/tests/io/vasp/test_outputs.py +++ b/tests/io/vasp/test_outputs.py @@ -1345,19 +1345,16 @@ def test_init(self): chgcar = self.chgcar_spin - self.chgcar_spin assert chgcar.get_integrated_diff(0, 1)[0, 1] == approx(0) - ans = [1.56472768, 3.25985108, 3.49205728, 3.66275028, 3.8045896, 5.10813352] + expected = [1.56472768, 3.25985108, 3.49205728, 3.66275028, 3.8045896, 5.10813352] actual = self.chgcar_fe3o4.get_integrated_diff(0, 3, 6) - assert_allclose(actual[:, 1], ans) + assert_allclose(actual[:, 1], expected) def test_write(self): - self.chgcar_spin.write_file("CHGCAR_pmg") - with open("CHGCAR_pmg") as f: - for i, line in enumerate(f): - if i == 22130: + self.chgcar_spin.write_file(out_path := f"{self.tmp_path}/CHGCAR_pmg") + with open(out_path) as file: + for idx, line in enumerate(file): + if idx in (22130, 44255): assert line == "augmentation occupancies 1 15\n" - if i == 44255: - assert line == "augmentation occupancies 1 15\n" - os.remove("CHGCAR_pmg") def test_soc_chgcar(self): assert set(self.chgcar_NiO_SOC.data) == {"total", "diff_x", "diff_y", "diff_z", "diff"} @@ -1367,11 +1364,7 @@ def test_soc_chgcar(self): # check our construction of chg.data['diff'] makes sense # this has been checked visually too and seems reasonable assert abs(self.chgcar_NiO_SOC.data["diff"][0][0][0]) == np.linalg.norm( - [ - self.chgcar_NiO_SOC.data["diff_x"][0][0][0], - self.chgcar_NiO_SOC.data["diff_y"][0][0][0], - self.chgcar_NiO_SOC.data["diff_z"][0][0][0], - ] + [self.chgcar_NiO_SOC.data[f"diff_{key}"][0][0][0] for key in "xyz"] ) # and that the net magnetization is about zero @@ -1379,18 +1372,17 @@ def test_soc_chgcar(self): # vasp output, but might be due to chgcar limitations? assert self.chgcar_NiO_SOC.net_magnetization == approx(0.0, abs=1e-0) - self.chgcar_NiO_SOC.write_file("CHGCAR_pmg_soc") - chg_from_file = Chgcar.from_file("CHGCAR_pmg_soc") + self.chgcar_NiO_SOC.write_file(out_path := f"{self.tmp_path}/CHGCAR_pmg_soc") + chg_from_file = Chgcar.from_file(out_path) assert chg_from_file.is_soc - os.remove("CHGCAR_pmg_soc") @unittest.skipIf(h5py is None, "h5py required for HDF5 support.") def test_hdf5(self): chgcar = Chgcar.from_file(f"{TEST_FILES_DIR}/CHGCAR.NiO_SOC.gz") - chgcar.to_hdf5("chgcar_test.hdf5") + chgcar.to_hdf5(out_path := f"{self.tmp_path}/chgcar_test.hdf5") import h5py - with h5py.File("chgcar_test.hdf5", "r") as f: + with h5py.File(out_path, "r") as f: assert_allclose(f["vdata"]["total"], chgcar.data["total"]) assert_allclose(f["vdata"]["diff"], chgcar.data["diff"]) assert_allclose(f["lattice"], chgcar.structure.lattice.matrix) @@ -1401,9 +1393,8 @@ def test_hdf5(self): for sp in f["species"]: assert sp in [b"Ni", b"O"] - chgcar2 = Chgcar.from_hdf5("chgcar_test.hdf5") + chgcar2 = Chgcar.from_hdf5(out_path) assert_allclose(chgcar2.data["total"], chgcar.data["total"]) - os.remove("chgcar_test.hdf5") def test_spin_data(self): for v in self.chgcar_spin.spin_data.values(): diff --git a/tests/io/vasp/test_sets.py b/tests/io/vasp/test_sets.py index c1f7ac51380..5d4fb26be4c 100644 --- a/tests/io/vasp/test_sets.py +++ b/tests/io/vasp/test_sets.py @@ -911,7 +911,7 @@ def test_init(self): assert vis.incar["ISMEAR"] == 0 vis.write_input(self.tmp_path) assert os.path.isfile(f"{self.tmp_path}/CHGCAR") - os.remove(f"{self.tmp_path}/CHGCAR") + os.remove(f"{self.tmp_path}/CHGCAR") # needed for next assert vis = self.set.from_prev_calc(prev_calc_dir=prev_run, standardize=True, mode="Line", copy_chgcar=True) vis.write_input(self.tmp_path) @@ -954,7 +954,7 @@ def test_override_from_prev(self): assert vis.incar["ISMEAR"] == 0 vis.write_input(self.tmp_path) assert os.path.isfile(f"{self.tmp_path}/CHGCAR") - os.remove(f"{self.tmp_path}/CHGCAR") + os.remove(f"{self.tmp_path}/CHGCAR") # needed for next assert vis = self.set(dummy_structure, standardize=True, mode="Line", copy_chgcar=True) vis.override_from_prev_calc(prev_calc_dir=prev_run) @@ -1538,11 +1538,11 @@ def test_scan_substitute(self): def test_bandgap_tol(self): # Test that the bandgap tolerance is applied correctly bandgap = 0.01 - for bandgap_tol, expected_kspacing in ((0.001, 0.2668137888), (0.02, 0.26681378884)): + for bandgap_tol, expected_kspacing in ((0.001, 0.2668137888), (0.02, 0.22)): incar = MPScanRelaxSet(self.struct, bandgap=0.01, bandgap_tol=bandgap_tol).incar assert incar["KSPACING"] == approx(expected_kspacing, abs=1e-5), f"{bandgap_tol=}, {bandgap=}" - assert incar["ISMEAR"] == -5 - assert incar["SIGMA"] == 0.05 + assert incar["ISMEAR"] == -5 if bandgap > bandgap_tol else 2 + assert incar["SIGMA"] == 0.05 if bandgap > bandgap_tol else 0.2 def test_kspacing(self): # Test that KSPACING is capped at 0.44 for insulators diff --git a/tests/phonon/test_bandstructure.py b/tests/phonon/test_bandstructure.py index 3b94a643d5f..ca600a73421 100644 --- a/tests/phonon/test_bandstructure.py +++ b/tests/phonon/test_bandstructure.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -import os from numpy.testing import assert_allclose, assert_array_equal from pytest import approx @@ -76,8 +75,4 @@ def test_dict_methods(self): self.assert_msonable(self.bs2) def test_write_methods(self): - self.bs2.write_phononwebsite("test.json") - - def tearDown(self): - if os.path.isfile("test.json"): - os.remove("test.json") + self.bs2.write_phononwebsite(f"{self.tmp_path}/test.json") diff --git a/tests/phonon/test_ir_spectra.py b/tests/phonon/test_ir_spectra.py index ab167ed5d13..648c3426ee6 100644 --- a/tests/phonon/test_ir_spectra.py +++ b/tests/phonon/test_ir_spectra.py @@ -1,7 +1,5 @@ from __future__ import annotations -import os - from monty.serialization import loadfn from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest @@ -12,12 +10,8 @@ def setUp(self): self.ir_spectra = loadfn(f"{TEST_FILES_DIR}/ir_spectra_mp-991652_DDB.json") def test_basic(self): - self.ir_spectra.write_json("test.json") - ir_spectra = loadfn("test.json") + self.ir_spectra.write_json(f"{self.tmp_path}/test.json") + ir_spectra = loadfn(f"{self.tmp_path}/test.json") irdict = ir_spectra.as_dict() ir_spectra.from_dict(irdict) ir_spectra.plot(show=False) - - def tearDown(self): - if os.path.isfile("test.json"): - os.remove("test.json")