Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Bandoverlaps parser #3689

Merged
58 changes: 32 additions & 26 deletions pymatgen/io/lobster/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,10 +1326,12 @@ def __init__(
"""
Args:
filename: filename of the "bandOverlaps.lobster" file.
band_overlaps_dict: A dictionary containing the band overlap data of the form: {spin: {"kpoint as string":
{"max_deviation":float that describes the max deviation, "matrix": 2D array of the size number of bands
times number of bands including the overlap matrices with}}}.
max_deviation (list[float]): A list of floats describing the maximal deviation for each problematic kpoint.
band_overlaps_dict: A dictionary containing the band overlap data of the form: {spin: {
"k_points" : list of k-point array,
"max_deviations": list of max deviations associated with each k-point,
"matrices": list of the overlap matrices associated with each k-point
}}.
max_deviation (list[float]): A list of floats describing the maximal deviation for each problematic k-point.
"""
self._filename = filename
self.band_overlaps_dict = {} if band_overlaps_dict is None else band_overlaps_dict
Expand Down Expand Up @@ -1363,24 +1365,31 @@ def _read(self, contents: list, spin_numbers: list):
kpoint_array = []
for kpointel in kpoint:
if kpointel not in ["at", "k-point", ""]:
kpoint_array.append(str(kpointel))
kpoint_array += [float(kpointel)]

elif "maxDeviation" in line:
if spin not in self.band_overlaps_dict:
self.band_overlaps_dict[spin] = {}
if " ".join(kpoint_array) not in self.band_overlaps_dict[spin]:
self.band_overlaps_dict[spin][" ".join(kpoint_array)] = {}
if "k_points" not in self.band_overlaps_dict[spin]:
self.band_overlaps_dict[spin]["k_points"] = []
if "max_deviations" not in self.band_overlaps_dict[spin]:
self.band_overlaps_dict[spin]["max_deviations"] = []
if "matrices" not in self.band_overlaps_dict[spin]:
self.band_overlaps_dict[spin]["matrices"] = []
maxdev = line.split(" ")[2]
self.band_overlaps_dict[spin][" ".join(kpoint_array)]["maxDeviation"] = float(maxdev)
self.max_deviation.append(float(maxdev))
self.band_overlaps_dict[spin][" ".join(kpoint_array)]["matrix"] = []
self.band_overlaps_dict[spin]["max_deviations"] += [float(maxdev)]
self.band_overlaps_dict[spin]["k_points"] += [kpoint_array]
self.max_deviation += [float(maxdev)]
overlaps = []

else:
overlaps = []
rows = []
for el in line.split(" "):
if el != "":
overlaps.append(float(el))
self.band_overlaps_dict[spin][" ".join(kpoint_array)]["matrix"].append(overlaps)
rows += [float(el)]
overlaps += [rows]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`overlaps` is possibly unbound now that the `overlaps = []` in the `else` branch above has been renamed to `rows = []`.

We don't check for this in CI yet due to way too many legacy violations, but we really should. Related discussion in #3646.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @janosh, I did not get what you meant here. Can you please elaborate a bit more? I need the `overlaps` variable to be initialized in the preceding `elif` clause to properly store the matrices for each k-point.

Is there anything I need to do? I could not think of any other way to do this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that if the code enters the `else` case before `overlaps` has been assigned, `overlaps` won't be defined, so Python will throw a NameError.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think such a case won't happen with the current file format; it will always pass through the `elif` clause.

But I will check a few more examples that I have and see if I encounter any errors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @janosh, I checked about 20 different example files that I have; because of the file format, none of them results in a NameError, and the logic seems to work without breaking 😃

But if you have a better idea for getting the same outcome, I am happy to implement it. At this point I seem to have blanked out 😅 and cannot think of any alternatives.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could `pip install pyright` and run it on this file. It'll show you the error I mean. You just need to ensure `overlaps` is assigned on all code paths.

Copy link
Contributor Author

@naik-aakash naik-aakash Mar 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @janosh, I have now addressed the errors that I could using pyright. If you think these are okay, it could be merged.

if len(overlaps) == len(rows):
self.band_overlaps_dict[spin]["matrices"] += [np.matrix(overlaps)]

def has_good_quality_maxDeviation(self, limit_maxDeviation: float = 0.1) -> bool:
"""
Expand Down Expand Up @@ -1414,38 +1423,35 @@ def has_good_quality_check_occupied_bands(
Returns:
Boolean that will give you information about the quality of the projection
"""
for matrix in self.band_overlaps_dict[Spin.up].values():
for iband1, band1 in enumerate(matrix["matrix"]):
for matrix in self.band_overlaps_dict[Spin.up]["matrices"]:
for iband1, band1 in enumerate(matrix):
for iband2, band2 in enumerate(band1):
if iband1 < number_occ_bands_spin_up and iband2 < number_occ_bands_spin_up:
if iband1 == iband2:
if abs(band2 - 1.0) > limit_deviation:
if abs(band2 - 1.0).all() > limit_deviation:
return False
elif band2 > limit_deviation:
elif band2.all() > limit_deviation:
return False

if spin_polarized:
for matrix in self.band_overlaps_dict[Spin.down].values():
for iband1, band1 in enumerate(matrix["matrix"]):
for matrix in self.band_overlaps_dict[Spin.down]["matrices"]:
for iband1, band1 in enumerate(matrix):
for iband2, band2 in enumerate(band1):
if number_occ_bands_spin_down is not None:
if iband1 < number_occ_bands_spin_down and iband2 < number_occ_bands_spin_down:
if iband1 == iband2:
if abs(band2 - 1.0) > limit_deviation:
if abs(band2 - 1.0).all() > limit_deviation:
return False
elif band2 > limit_deviation:
elif band2.all() > limit_deviation:
return False
else:
ValueError("number_occ_bands_spin_down has to be specified")
return True

@property
def bandoverlapsdict(self):
warnings.warn(
"`bandoverlapsdict` attribute is deprecated. Use `band_overlaps_dict` instead.",
DeprecationWarning,
stacklevel=2,
)
msg = "`bandoverlapsdict` attribute is deprecated. Use `band_overlaps_dict` instead."
warnings.warn(msg, DeprecationWarning, stacklevel=2)
return self.band_overlaps_dict


Expand Down
Loading
Loading