Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Storing a list of the orbital labels instead of a dictionary of orbit… #43

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 97 additions & 1 deletion src/pymatgen/io/jdftx/_output_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def _is_complex_bandfile_filepath(bandfile_filepath: str | Path) -> bool:
]


def _get_atom_orb_labels_dict(bandfile_filepath: Path) -> dict[str, list[str]]:
def _get_atom_orb_labels_ref_dict(bandfile_filepath: Path) -> dict[str, list[str]]:
"""
Return a dictionary mapping each atom symbol to all atomic orbital projection string representations.

Expand Down Expand Up @@ -546,3 +546,99 @@ def _get_atom_orb_labels_dict(bandfile_filepath: Path) -> dict[str, list[str]]:
else:
labels_dict[sym] += mls
return labels_dict


def _get_atom_count_list(bandfile_filepath: Path) -> list[tuple[str, int]]:
"""
Return a list of tuples of atom symbols and counts.

Return a list of tuples of atom symbols and counts. This is superior to a dictionary as it maintains the order of
the atoms in the bandfile.

Args:
bandfile_filepath (str | Path): The path to the bandfile.

Returns:
list[tuple[str, int]]: A list of tuples of atom symbols and counts.
"""
bandfile = read_file(bandfile_filepath)
atom_count_list = []

for i, line in enumerate(bandfile):
if i > 1:
if "#" in line:
break
lsplit = line.strip().split()
sym = lsplit[0].strip()
count = int(lsplit[1].strip())
atom_count_list.append((sym, count))
return atom_count_list


def _get_orb_label_list_expected_len(labels_dict: dict[str, list[str]], atom_count_list: list[tuple[str, int]]) -> int:
"""
Return the expected length of the atomic orbital projection string representation list.

Return the expected length of the atomic orbital projection string representation list.

Args:
labels_dict (dict[str, list[str]]): A dictionary mapping each atom symbol to all atomic orbital projection
string representations.
atom_count_list (list[tuple[str, int]]): A list of tuples of atom symbols and counts.

Returns:
int: The expected length of the atomic orbital projection string representation list.
"""
expected_len = 0
for ion_tuple in atom_count_list:
ion = ion_tuple[0]
count = ion_tuple[1]
orbs = labels_dict[ion]
expected_len += count * len(orbs)
return expected_len


def _get_orb_label(ion: str, idx: int, orb: str) -> str:
"""
Return the string representation for an orbital projection.

Return the string representation for an orbital projection.

Args:
ion (str): The symbol of the atom.
idx (int): The index of the atom.
orb (str): The atomic orbital projection string representation.

Returns:
str: The atomic orbital projection string representation for the atom.
"""
return f"{ion}#{idx + 1}({orb})"


def _get_orb_label_list(bandfile_filepath: Path) -> tuple[str, ...]:
"""
Return a tuple of all atomic orbital projection string representations.

Return a tuple of all atomic orbital projection string representations.

Args:
bandfile_filepath (str | Path): The path to the bandfile.

Returns:
tuple[str]: A list of all atomic orbital projection string representations.
"""
labels_dict = _get_atom_orb_labels_ref_dict(bandfile_filepath)
atom_count_list = _get_atom_count_list(bandfile_filepath)
read_file(bandfile_filepath)
labels_list: list[str] = []
for ion_tuple in atom_count_list:
ion = ion_tuple[0]
orbs = labels_dict[ion]
count = ion_tuple[1]
for i in range(count):
for orb in orbs:
labels_list.append(_get_orb_label(ion, i, orb))
# This is most likely unnecessary, but it is a good check to have.
if len(labels_list) != _get_orb_label_list_expected_len(labels_dict, atom_count_list):
raise RuntimeError("Number of atomic orbital projections does not match expected length.")
return tuple(labels_list)
20 changes: 15 additions & 5 deletions src/pymatgen/io/jdftx/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class is written.
import numpy as np

from pymatgen.io.jdftx._output_utils import (
_get_atom_orb_labels_dict,
_get_nbands_from_bandfile_filepath,
_get_orb_label_list,
get_proj_tju_from_file,
read_outfile_slices,
)
Expand Down Expand Up @@ -90,6 +90,18 @@ class JDFTXOutputs:
(nspin, nkpt, nbands, nion, nionproj) to save on memory as nonionproj is different depending on the ion
type. This array may also be complex if specified in 'band-projections-params' in the JDFTx input, allowing
for pCOHP analysis.
eigenvals (np.ndarray): The eigenvalues. Stored in shape (nstates, nbands) where nstates is nspin*nkpts (nkpts
may not equal prod(kfolding) if symmetry reduction occurred) and nbands is the number of bands.
orb_label_list (tuple[str, ...]): A tuple of the orbital labels for the bandProjections file, where the i'th
element describes the i'th orbital. Orbital labels are formatted as "<ion>#<ion-number>(<orbital>)",
where <ion> is the element symbol of the ion, <ion-number> is the 1-based index of the ion-type in the
structure (ie C#2 would be the second carbon atom, but not necessarily the second ion in the structure),
and <orbital> is a string describing "l" and "ml" quantum numbers (ie "p_x" or "d_yz"). Note that while "z"
corresponds to the "z" axis, "x" and "y" are arbitrary and may not correspond to the actual x and y axes of
the structure. In the case where multiple shells of a given "l" are available within the projections, a
0-based index will appear mimicking a principle quantum number (ie "0px" for first shell and "1px" for
second shell). The actual principal quantum number is not stored in the JDFTx output files and must be
inferred by the user.
"""

calc_dir: str | Path = field(init=True)
Expand All @@ -99,7 +111,7 @@ class JDFTXOutputs:
bandProjections: np.ndarray | None = field(init=False)
eigenvals: np.ndarray | None = field(init=False)
# Misc metadata for interacting with the data
atom_orb_labels_dict: dict[int, str] | None = field(init=False)
orb_label_list: tuple[str, ...] | None = field(init=False)

@classmethod
def from_calc_dir(cls, calc_dir: str | Path, store_vars: list[str] | None = None) -> JDFTXOutputs:
Expand Down Expand Up @@ -167,15 +179,13 @@ def _check_bandProjections(self):
def _store_bandProjections(self):
if "bandProjections" in self.paths:
self.bandProjections = get_proj_tju_from_file(self.paths["bandProjections"])
self.atom_orb_labels_dict = _get_atom_orb_labels_dict(self.paths["bandProjections"])
self.orb_label_list = _get_orb_label_list(self.paths["bandProjections"])

def _check_eigenvals(self):
"""Check for misaligned data within eigenvals file."""
if "eigenvals" in self.paths:
if not self.paths["eigenvals"].exists():
raise RuntimeError("Allocated path for eigenvals does not exist.")
# TODO: We should not have to load the entire file to find its length - replace with something more
# efficient once Claude lets me create an account.
tj = len(np.fromfile(self.paths["eigenvals"]))
nstates_float = tj / self.outfile.nbands
if not np.isclose(nstates_float, int(nstates_float)):
Expand Down
8 changes: 4 additions & 4 deletions tests/io/jdftx/outputs_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,9 @@ def jdftxoutfile_matches_known(joutfile: JDFTXOutfile, known: dict):
"eigenvals": n2_ex_calc_dir / Path("eigenvals"),
}
n2_ex_calc_dir_bandprojections_metadata = {
"atom_orb_labels_dict": {
"N": ["s", "px", "py", "pz"],
},
"orb_label_list": ["N#1(s)", "N#1(px)", "N#1(py)", "N#1(pz)", "N#2(s)", "N#2(px)", "N#2(py)", "N#2(pz)"],
"shape": (54, 15, 8),
"first val": -0.1331527 + 0.5655596j,
}


Expand All @@ -150,8 +149,9 @@ def jdftxoutfile_matches_known(joutfile: JDFTXOutfile, known: dict):
"eigenvals": nh3_ex_calc_dir / Path("eigenvals"),
}
nh3_ex_calc_dir_bandprojections_metadata = {
"atom_orb_labels_dict": {"N": ["s", "px", "py", "pz"], "H": ["s"]},
"orb_label_list": ["N#1(s)", "N#1(px)", "N#1(py)", "N#1(pz)", "H#1(s)", "H#2(s)", "H#3(s)"],
"shape": (16, 14, 7),
"first val": -0.0688767 + 0.9503786j,
}

example_sp_outfile_path = ex_out_files_dir / Path("example_sp.out")
Expand Down
3 changes: 2 additions & 1 deletion tests/io/jdftx/test_jdftxoutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ def test_store_vars(calc_dir: Path, store_vars: list[str]):
def test_store_bandprojections(calc_dir: Path, known_metadata: dict):
"""Test that the stored band projections are correct."""
jo = JDFTXOutputs.from_calc_dir(calc_dir, store_vars=["bandProjections"])
for var in ["atom_orb_labels_dict"]:
for var in ["orb_label_list"]:
assert hasattr(jo, var)
assert_same_value(getattr(jo, var), known_metadata[var])
assert_same_value(jo.bandProjections.shape, known_metadata["shape"])
assert pytest.approx(jo.bandProjections[0, 0, 0]) == known_metadata["first val"]


@pytest.mark.parametrize("calc_dir", [n2_ex_calc_dir, nh3_ex_calc_dir])
Expand Down
Loading