Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lobsterin dict inheritance and treat \t in lobsterins correctly #3439

Merged
merged 16 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 31 additions & 26 deletions pymatgen/io/lobster/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

import itertools
import os
import re
import warnings
from collections import UserDict
from typing import TYPE_CHECKING, Any

import numpy as np
Expand Down Expand Up @@ -49,7 +51,7 @@
)


class Lobsterin(dict, MSONable):
class Lobsterin(UserDict, MSONable):
"""
This class can handle and generate lobsterin files
Furthermore, it can also modify INCAR files for lobster, generate KPOINT files for fatband calculations in Lobster,
Expand Down Expand Up @@ -159,7 +161,10 @@ def __getitem__(self, item):
if not found:
new_key = item

return dict.__getitem__(self, new_key)
return super().__getitem__(new_key)

def __delitem__(self, key):
del self.data[key.lower()]

def diff(self, other):
"""
Expand Down Expand Up @@ -580,31 +585,31 @@ def from_file(cls, lobsterin: str):
Lobsterindict: dict[str, Any] = {}

for datum in data:
# will remove all comments to avoid complications
raw_datum = datum.split("!")[0]
raw_datum = raw_datum.split("//")[0]
raw_datum = raw_datum.split("#")[0]
raw_datum = raw_datum.split(" ")
while "" in raw_datum:
raw_datum.remove("")
if len(raw_datum) > 1:
# check which type of keyword this is, handle accordingly
if raw_datum[0].lower() not in [datum2.lower() for datum2 in Lobsterin.LISTKEYWORDS]:
if raw_datum[0].lower() not in [datum2.lower() for datum2 in Lobsterin.FLOAT_KEYWORDS]:
if raw_datum[0].lower() not in Lobsterindict:
Lobsterindict[raw_datum[0].lower()] = " ".join(raw_datum[1:])
# Remove all comments
if not datum.startswith(("!", "#", "//")):
pattern = r"\b[^!#//]+" # exclude comments after commands
matched_pattern = re.findall(pattern, datum)
if matched_pattern:
raw_datum = matched_pattern[0].replace("\t", " ") # handle tab in between and end of command
key_word = raw_datum.strip().split(" ") # extract keyword
if len(key_word) > 1:
# check which type of keyword this is, handle accordingly
if key_word[0].lower() not in [datum2.lower() for datum2 in Lobsterin.LISTKEYWORDS]:
if key_word[0].lower() not in [datum2.lower() for datum2 in Lobsterin.FLOAT_KEYWORDS]:
if key_word[0].lower() not in Lobsterindict:
Lobsterindict[key_word[0].lower()] = " ".join(key_word[1:])
else:
raise ValueError(f"Same keyword {key_word[0].lower()} twice!")
elif key_word[0].lower() not in Lobsterindict:
Lobsterindict[key_word[0].lower()] = float(key_word[1])
else:
raise ValueError(f"Same keyword {key_word[0].lower()} twice!")
elif key_word[0].lower() not in Lobsterindict:
Lobsterindict[key_word[0].lower()] = [" ".join(key_word[1:])]
else:
raise ValueError(f"Same keyword {raw_datum[0].lower()} twice!")
elif raw_datum[0].lower() not in Lobsterindict:
Lobsterindict[raw_datum[0].lower()] = float(raw_datum[1])
else:
raise ValueError(f"Same keyword {raw_datum[0].lower()} twice!")
elif raw_datum[0].lower() not in Lobsterindict:
Lobsterindict[raw_datum[0].lower()] = [" ".join(raw_datum[1:])]
else:
Lobsterindict[raw_datum[0].lower()].append(" ".join(raw_datum[1:]))
elif len(raw_datum) > 0:
Lobsterindict[raw_datum[0].lower()] = True
Lobsterindict[key_word[0].lower()].append(" ".join(key_word[1:]))
elif len(key_word) > 0:
Lobsterindict[key_word[0].lower()] = True

return cls(Lobsterindict)

Expand Down
5 changes: 3 additions & 2 deletions tests/files/cohp/lobsterin.2
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
COHPstartEnergy -15.0
COHPendEnergy 5.0
basisSet pbeVaspFit2015
basisSet pbeVaspFit2015
gaussianSmearingWidth 0.1
basisfunctions Fe 3d 4p 4s ! This is a comment
basisfunctions Fe 3d 4p 4s ! This is a comment
basisfunctions Co 3d 4p 4s # This is another comment
skipdos // Here, we comment again
skipcohp
skipcoop
skipPopulationAnalysis
skipGrossPopulation
! cohpsteps
26 changes: 26 additions & 0 deletions tests/io/lobster/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1651,6 +1651,32 @@ def test_diff(self):
== self.Lobsterinfromfile3.diff(self.Lobsterinfromfile)["Different"]["SKIPCOHP"]["lobsterin2"]
)

def test_dict_functionality(self):
assert self.Lobsterinfromfile.get("COHPstartEnergy") == -15.0
assert self.Lobsterinfromfile.get("COHPstartEnergy") == -15.0
assert self.Lobsterinfromfile.get("COhPstartenergy") == -15.0
lobsterincopy = self.Lobsterinfromfile.copy()
lobsterincopy.update({"cohpstarteNergy": -10.00})
assert lobsterincopy["cohpstartenergy"] == -10.0
lobsterincopy.pop("cohpstarteNergy")
assert "cohpstartenergy" not in lobsterincopy
lobsterincopy.pop("cohpendenergY")
lobsterincopy["cohpsteps"] = 100
assert lobsterincopy["cohpsteps"] == 100
before = len(lobsterincopy.items())
lobsterincopy.popitem()
after = len(lobsterincopy.items())
assert before != after

def test_read_write_lobsterin(self):
outfile_path = tempfile.mkstemp()[1]
lobsterin1 = Lobsterin.from_file(f"{TEST_FILES_DIR}/cohp/lobsterin.1")
lobsterin1.write_lobsterin(outfile_path)
lobsterin2 = Lobsterin.from_file(outfile_path)
assert lobsterin1.diff(lobsterin2)["Different"] == {}

# TODO: will integer vs float break cohpsteps?

def test_get_basis(self):
# get basis functions
lobsterin1 = Lobsterin({})
Expand Down