From 510d88aefcc4bcc9db435ac87286d519e1688230 Mon Sep 17 00:00:00 2001 From: Ben Rich Date: Tue, 14 Jan 2025 18:13:20 -0700 Subject: [PATCH] Storing a list of the orbital labels instead of a dictionary of orbitals available per ion-type, adding a value check for the bandProjections, improving docstring for JDFTXOutputs --- src/pymatgen/io/jdftx/_output_utils.py | 98 +++++++++++++++++++++++++- src/pymatgen/io/jdftx/outputs.py | 20 ++++-- tests/io/jdftx/outputs_test_utils.py | 8 +-- tests/io/jdftx/test_jdftxoutput.py | 3 +- 4 files changed, 118 insertions(+), 11 deletions(-) diff --git a/src/pymatgen/io/jdftx/_output_utils.py b/src/pymatgen/io/jdftx/_output_utils.py index b0a392c15cc..85c117b6da3 100644 --- a/src/pymatgen/io/jdftx/_output_utils.py +++ b/src/pymatgen/io/jdftx/_output_utils.py @@ -501,7 +501,7 @@ def _is_complex_bandfile_filepath(bandfile_filepath: str | Path) -> bool: ] -def _get_atom_orb_labels_dict(bandfile_filepath: Path) -> dict[str, list[str]]: +def _get_atom_orb_labels_ref_dict(bandfile_filepath: Path) -> dict[str, list[str]]: """ Return a dictionary mapping each atom symbol to all atomic orbital projection string representations. @@ -546,3 +546,99 @@ def _get_atom_orb_labels_dict(bandfile_filepath: Path) -> dict[str, list[str]]: else: labels_dict[sym] += mls return labels_dict + + +def _get_atom_count_list(bandfile_filepath: Path) -> list[tuple[str, int]]: + """ + Return a list of tuples of atom symbols and counts. + + Return a list of tuples of atom symbols and counts. This is superior to a dictionary as it maintains the order of + the atoms in the bandfile. + + Args: + bandfile_filepath (str | Path): The path to the bandfile. + + Returns: + list[tuple[str, int]]: A list of tuples of atom symbols and counts. 
+ """ + bandfile = read_file(bandfile_filepath) + atom_count_list = [] + + for i, line in enumerate(bandfile): + if i > 1: + if "#" in line: + break + lsplit = line.strip().split() + sym = lsplit[0].strip() + count = int(lsplit[1].strip()) + atom_count_list.append((sym, count)) + return atom_count_list + + +def _get_orb_label_list_expected_len(labels_dict: dict[str, list[str]], atom_count_list: list[tuple[str, int]]) -> int: + """ + Return the expected length of the atomic orbital projection string representation list. + + Return the expected length of the atomic orbital projection string representation list. + + Args: + labels_dict (dict[str, list[str]]): A dictionary mapping each atom symbol to all atomic orbital projection + string representations. + atom_count_list (list[tuple[str, int]]): A list of tuples of atom symbols and counts. + + Returns: + int: The expected length of the atomic orbital projection string representation list. + """ + expected_len = 0 + for ion_tuple in atom_count_list: + ion = ion_tuple[0] + count = ion_tuple[1] + orbs = labels_dict[ion] + expected_len += count * len(orbs) + return expected_len + + +def _get_orb_label(ion: str, idx: int, orb: str) -> str: + """ + Return the string representation for an orbital projection. + + Return the string representation for an orbital projection. + + Args: + ion (str): The symbol of the atom. + idx (int): The index of the atom. + orb (str): The atomic orbital projection string representation. + + Returns: + str: The atomic orbital projection string representation for the atom. + """ + return f"{ion}#{idx + 1}({orb})" + + +def _get_orb_label_list(bandfile_filepath: Path) -> tuple[str, ...]: + """ + Return a tuple of all atomic orbital projection string representations. + + Return a tuple of all atomic orbital projection string representations. + + Args: + bandfile_filepath (str | Path): The path to the bandfile. 
+ + Returns: + tuple[str, ...]: A tuple of all atomic orbital projection string representations. + """ + labels_dict = _get_atom_orb_labels_ref_dict(bandfile_filepath) + atom_count_list = _get_atom_count_list(bandfile_filepath) + read_file(bandfile_filepath) + labels_list: list[str] = [] + for ion_tuple in atom_count_list: + ion = ion_tuple[0] + orbs = labels_dict[ion] + count = ion_tuple[1] + for i in range(count): + for orb in orbs: + labels_list.append(_get_orb_label(ion, i, orb)) + # This is most likely unnecessary, but it is a good check to have. + if len(labels_list) != _get_orb_label_list_expected_len(labels_dict, atom_count_list): + raise RuntimeError("Number of atomic orbital projections does not match expected length.") + return tuple(labels_list) diff --git a/src/pymatgen/io/jdftx/outputs.py b/src/pymatgen/io/jdftx/outputs.py index 486c73d122a..14663ebc69f 100644 --- a/src/pymatgen/io/jdftx/outputs.py +++ b/src/pymatgen/io/jdftx/outputs.py @@ -16,8 +16,8 @@ class is written. import numpy as np from pymatgen.io.jdftx._output_utils import ( - _get_atom_orb_labels_dict, _get_nbands_from_bandfile_filepath, + _get_orb_label_list, get_proj_tju_from_file, read_outfile_slices, ) @@ -90,6 +90,18 @@ class JDFTXOutputs: (nspin, nkpt, nbands, nion, nionproj) to save on memory as nonionproj is different depending on the ion type. This array may also be complex if specified in 'band-projections-params' in the JDFTx input, allowing for pCOHP analysis. + eigenvals (np.ndarray): The eigenvalues. Stored in shape (nstates, nbands) where nstates is nspin*nkpts (nkpts + may not equal prod(kfolding) if symmetry reduction occurred) and nbands is the number of bands. + orb_label_list (tuple[str, ...]): A tuple of the orbital labels for the bandProjections file, where the i'th + element describes the i'th orbital. 
Orbital labels are formatted as "<ion>#<ion-number>(<projection>)", + where <ion> is the element symbol of the ion, <ion-number> is the 1-based index of the ion-type in the + structure (ie C#2 would be the second carbon atom, but not necessarily the second ion in the structure), + and <projection> is a string describing "l" and "ml" quantum numbers (ie "p_x" or "d_yz"). Note that while "z" + corresponds to the "z" axis, "x" and "y" are arbitrary and may not correspond to the actual x and y axes of + the structure. In the case where multiple shells of a given "l" are available within the projections, a + 0-based index will appear mimicking a principal quantum number (ie "0px" for first shell and "1px" for + second shell). The actual principal quantum number is not stored in the JDFTx output files and must be + inferred by the user. """ calc_dir: str | Path = field(init=True) @@ -99,7 +111,7 @@ class JDFTXOutputs: bandProjections: np.ndarray | None = field(init=False) eigenvals: np.ndarray | None = field(init=False) # Misc metadata for interacting with the data - atom_orb_labels_dict: dict[int, str] | None = field(init=False) + orb_label_list: tuple[str, ...] | None = field(init=False) @classmethod def from_calc_dir(cls, calc_dir: str | Path, store_vars: list[str] | None = None) -> JDFTXOutputs: @@ -167,15 +179,13 @@ def _check_bandProjections(self): def _store_bandProjections(self): if "bandProjections" in self.paths: self.bandProjections = get_proj_tju_from_file(self.paths["bandProjections"]) - self.atom_orb_labels_dict = _get_atom_orb_labels_dict(self.paths["bandProjections"]) + self.orb_label_list = _get_orb_label_list(self.paths["bandProjections"]) def _check_eigenvals(self): """Check for misaligned data within eigenvals file.""" if "eigenvals" in self.paths: if not self.paths["eigenvals"].exists(): raise RuntimeError("Allocated path for eigenvals does not exist.") - # TODO: We should not have to load the entire file to find its length - replace with something more - # efficient once Claude lets me create an account. 
tj = len(np.fromfile(self.paths["eigenvals"])) nstates_float = tj / self.outfile.nbands if not np.isclose(nstates_float, int(nstates_float)): diff --git a/tests/io/jdftx/outputs_test_utils.py b/tests/io/jdftx/outputs_test_utils.py index 005729e8e49..081ffc7052f 100644 --- a/tests/io/jdftx/outputs_test_utils.py +++ b/tests/io/jdftx/outputs_test_utils.py @@ -137,10 +137,9 @@ def jdftxoutfile_matches_known(joutfile: JDFTXOutfile, known: dict): "eigenvals": n2_ex_calc_dir / Path("eigenvals"), } n2_ex_calc_dir_bandprojections_metadata = { - "atom_orb_labels_dict": { - "N": ["s", "px", "py", "pz"], - }, + "orb_label_list": ["N#1(s)", "N#1(px)", "N#1(py)", "N#1(pz)", "N#2(s)", "N#2(px)", "N#2(py)", "N#2(pz)"], "shape": (54, 15, 8), + "first val": -0.1331527 + 0.5655596j, } @@ -150,8 +149,9 @@ def jdftxoutfile_matches_known(joutfile: JDFTXOutfile, known: dict): "eigenvals": nh3_ex_calc_dir / Path("eigenvals"), } nh3_ex_calc_dir_bandprojections_metadata = { - "atom_orb_labels_dict": {"N": ["s", "px", "py", "pz"], "H": ["s"]}, + "orb_label_list": ["N#1(s)", "N#1(px)", "N#1(py)", "N#1(pz)", "H#1(s)", "H#2(s)", "H#3(s)"], "shape": (16, 14, 7), + "first val": -0.0688767 + 0.9503786j, } example_sp_outfile_path = ex_out_files_dir / Path("example_sp.out") diff --git a/tests/io/jdftx/test_jdftxoutput.py b/tests/io/jdftx/test_jdftxoutput.py index 0da5b41e5e9..5fe308c1a05 100644 --- a/tests/io/jdftx/test_jdftxoutput.py +++ b/tests/io/jdftx/test_jdftxoutput.py @@ -60,10 +60,11 @@ def test_store_vars(calc_dir: Path, store_vars: list[str]): def test_store_bandprojections(calc_dir: Path, known_metadata: dict): """Test that the stored band projections are correct.""" jo = JDFTXOutputs.from_calc_dir(calc_dir, store_vars=["bandProjections"]) - for var in ["atom_orb_labels_dict"]: + for var in ["orb_label_list"]: assert hasattr(jo, var) assert_same_value(getattr(jo, var), known_metadata[var]) assert_same_value(jo.bandProjections.shape, known_metadata["shape"]) + assert 
pytest.approx(jo.bandProjections[0, 0, 0]) == known_metadata["first val"] @pytest.mark.parametrize("calc_dir", [n2_ex_calc_dir, nh3_ex_calc_dir])