Skip to content

Commit

Permalink
fix IonEntry doc str + type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Aug 6, 2023
1 parent 9455830 commit e734bf3
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 65 deletions.
13 changes: 7 additions & 6 deletions pymatgen/analysis/pourbaix_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def __init__(self, entry_list, weights=None):
entry_list ([PourbaixEntry]): List of component PourbaixEntries
weights ([float]): Weights associated with each entry. Default is None
"""
self.weights = [1.0] * len(entry_list) if weights is None else weights
self.weights = weights or [1.0] * len(entry_list)
self.entry_list = entry_list

def __getattr__(self, attr):
Expand Down Expand Up @@ -320,16 +320,16 @@ def as_dict(self):
}

@classmethod
def from_dict(cls, d):
def from_dict(cls, dct):
"""
Args:
d (): Dict representation.
d (dict): Dict representation.
Returns:
MultiEntry
"""
entry_list = [PourbaixEntry.from_dict(e) for e in d.get("entry_list")]
return cls(entry_list, d.get("weights"))
entry_list = [PourbaixEntry.from_dict(entry) for entry in dct.get("entry_list")]
return cls(entry_list, dct.get("weights"))


# TODO: this class isn't particularly useful in its current form, could be
Expand All @@ -346,13 +346,14 @@ class IonEntry(PDEntry):
set to some other string for display purposes.
"""

def __init__(self, ion, energy, name=None, attribute=None):
def __init__(self, ion: Ion, energy: float, name: str | None = None, attribute=None):
"""
Args:
ion: Ion object
energy: Energy for composition.
name: Optional parameter to name the entry. Defaults to the
chemical formula.
attribute: Optional attribute of the entry, e.g., band gap.
"""
self.ion = ion
# Auto-assign name
Expand Down
57 changes: 29 additions & 28 deletions pymatgen/io/abinit/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,6 @@
"ird1wf",
}

# Name of the (default) tolerance used by the runlevels.
_runl2tolname = {
"scf": "tolvrs",
"nscf": "tolwfr",
"dfpt": "toldfe", # ?
"screening": "toldfe", # dummy
"sigma": "toldfe", # dummy
"bse": "toldfe", # ?
"relax": "tolrff",
}

# Tolerances for the different levels of accuracy.

Expand Down Expand Up @@ -155,7 +145,7 @@ class ShiftMode(Enum):
def from_object(cls, obj):
"""
Returns an instance of ShiftMode based on the type of object passed. Converts strings to ShiftMode depending
on the iniital letter of the string. G for GammaCenterd, M for MonkhorstPack,
on the initial letter of the string. G for GammaCentered, M for MonkhorstPack,
S for Symmetric, O for OneSymmetric.
Case insensitive.
"""
Expand All @@ -166,14 +156,25 @@ def from_object(cls, obj):
raise TypeError(f"The object provided is not handled: type {type(obj).__name__}")


def _stopping_criterion(runlevel, accuracy):
"""Return the stopping criterion for this runlevel with the given accuracy."""
tolname = _runl2tolname[runlevel]
return {tolname: getattr(_tolerances[tolname], accuracy)}
def _stopping_criterion(run_level, accuracy):
"""Return the stopping criterion for this run_level with the given accuracy."""

# Name of the (default) tolerance used by the run levels.
_run_level_tolname_map = {
"scf": "tolvrs",
"nscf": "tolwfr",
"dfpt": "toldfe", # ?
"screening": "toldfe", # dummy
"sigma": "toldfe", # dummy
"bse": "toldfe", # ?
"relax": "tolrff",
}
tol_name = _run_level_tolname_map[run_level]
return {tol_name: getattr(_tolerances[tol_name], accuracy)}


def _find_ecut_pawecutdg(ecut, pawecutdg, pseudos, accuracy):
"""Return a |AttrDict| with the value of ``ecut`` and ``pawecutdg``."""
"""Return a |AttrDict| with the value of ecut and pawecutdg."""
# Get ecut and pawecutdg from the pseudo hints.
if ecut is None or (pawecutdg is None and any(p.ispaw for p in pseudos)):
has_hints = all(p.has_hints for p in pseudos)
Expand All @@ -194,18 +195,18 @@ def _find_ecut_pawecutdg(ecut, pawecutdg, pseudos, accuracy):


def _find_scf_nband(structure, pseudos, electrons, spinat=None):
"""Find the value of ``nband``."""
"""Find the value of nband."""
if electrons.nband is not None:
return electrons.nband

nsppol, smearing = electrons.nsppol, electrons.smearing

# Number of valence electrons including possible extra charge
nval = num_valence_electrons(structure, pseudos)
nval -= electrons.charge
n_val_elec = num_valence_electrons(structure, pseudos)
n_val_elec -= electrons.charge

# First guess (semiconductors)
nband = nval // 2
nband = n_val_elec // 2

# TODO: Find better algorithm
# If nband is too small we may kill the job, increase nband and restart
Expand Down Expand Up @@ -234,7 +235,7 @@ def _get_shifts(shift_mode, structure):
centered otherwise.
Note: for some cases (e.g. body centered tetragonal), both the Symmetric and OneSymmetric may fail to satisfy the
``chksymbreak`` condition (Abinit input variable).
chksymbreak condition (Abinit input variable).
"""
if shift_mode == ShiftMode.GammaCentered:
return ((0, 0, 0),)
Expand Down Expand Up @@ -475,7 +476,7 @@ def ion_ioncell_relax_input(

def calc_shiftk(structure, symprec: float = 0.01, angle_tolerance=5):
"""
Find the values of ``shiftk`` and ``nshiftk`` appropriated for the sampling of the Brillouin zone.
Find the values of shiftk and nshiftk appropriated for the sampling of the Brillouin zone.
When the primitive vectors of the lattice do NOT form a FCC or a BCC lattice,
the usual (shifted) Monkhorst-Pack grids are formed by using nshiftk=1 and shiftk 0.5 0.5 0.5 .
Expand Down Expand Up @@ -611,7 +612,7 @@ def __str__(self):
return self.to_str()

def write(self, filepath="run.abi"):
"""Write the input file to file to ``filepath``."""
"""Write the input file to file to filepath."""
dirname = os.path.dirname(os.path.abspath(filepath))
if not os.path.exists(dirname):
os.makedirs(dirname)
Expand Down Expand Up @@ -699,7 +700,7 @@ def to_str(self):


class BasicAbinitInputError(Exception):
"""Base error class for exceptions raised by ``BasicAbinitInput``."""
"""Base error class for exceptions raised by BasicAbinitInput."""


class BasicAbinitInput(AbstractInput, MSONable):
Expand Down Expand Up @@ -790,7 +791,7 @@ def from_dict(cls, d):

def add_abiobjects(self, *abi_objects):
"""
This function receive a list of ``AbiVarable`` objects and add
This function receive a list of AbiVarable objects and add
the corresponding variables to the input.
"""
dct = {}
Expand Down Expand Up @@ -1010,7 +1011,7 @@ class BasicMultiDataset:
that provides an easy-to-use interface to apply global changes to the
the inputs stored in the objects.
Let's assume for example that multi contains two ``BasicAbinitInput`` objects and we
Let's assume for example that multi contains two BasicAbinitInput objects and we
want to set `ecut` to 1 in both dictionaries. The direct approach would be:
for inp in multi:
Expand Down Expand Up @@ -1204,7 +1205,7 @@ def extend(self, abinit_inputs):
self._inputs.extend(abinit_inputs)

def addnew_from(self, dtindex):
"""Add a new entry in the multidataset by copying the input with index ``dtindex``."""
"""Add a new entry in the multidataset by copying the input with index dtindex."""
self.append(self[dtindex].deepcopy())

def split_datasets(self):
Expand Down Expand Up @@ -1301,7 +1302,7 @@ def has_same_variable(kref, vref, other_inp):

def write(self, filepath="run.abi"):
"""
Write ``ndset`` input files to disk. The name of the file
Write ndset input files to disk. The name of the file
is constructed from the dataset index e.g. run0.abi.
"""
root, ext = os.path.splitext(filepath)
Expand Down
44 changes: 13 additions & 31 deletions pymatgen/io/vasp/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,7 +1084,7 @@ def __init__(
The default behavior of the constructor is for a Gamma centered,
1x1x1 KPOINTS with no shift.
"""
if num_kpts > 0 and (not labels) and (not kpts_weights):
if num_kpts > 0 and not labels and not kpts_weights:
raise ValueError("For explicit or line-mode kpoints, either the labels or kpts_weights must be specified.")

self.comment = comment
Expand All @@ -1100,10 +1100,9 @@ def __init__(
self.tet_connections = tet_connections

@property
def style(self):
def style(self) -> KpointsSupportedModes:
"""
:return: Style for kpoint generation. One of Kpoints_supported_modes
enum.
Style for kpoint generation. One of Kpoints_supported_modes enum.
"""
return self._style

Expand All @@ -1119,11 +1118,7 @@ def style(self, style):

if (
style
in (
Kpoints.supported_modes.Automatic,
Kpoints.supported_modes.Gamma,
Kpoints.supported_modes.Monkhorst,
)
in (Kpoints.supported_modes.Automatic, Kpoints.supported_modes.Gamma, Kpoints.supported_modes.Monkhorst)
and len(self.kpts) > 1
):
raise ValueError(
Expand All @@ -1150,10 +1145,7 @@ def automatic(subdivisions):
Kpoints object
"""
return Kpoints(
"Fully automatic kpoint scheme",
0,
style=Kpoints.supported_modes.Automatic,
kpts=[[subdivisions]],
"Fully automatic kpoint scheme", 0, style=Kpoints.supported_modes.Automatic, kpts=[[subdivisions]]
)

@staticmethod
Expand All @@ -1170,13 +1162,7 @@ def gamma_automatic(kpts: tuple[int, int, int] = (1, 1, 1), shift: Vector3D = (0
Returns:
Kpoints object
"""
return Kpoints(
"Automatic kpoint scheme",
0,
Kpoints.supported_modes.Gamma,
kpts=[kpts],
kpts_shift=shift,
)
return Kpoints("Automatic kpoint scheme", 0, Kpoints.supported_modes.Gamma, kpts=[kpts], kpts_shift=shift)

@staticmethod
def monkhorst_automatic(kpts: tuple[int, int, int] = (2, 2, 2), shift: Vector3D = (0, 0, 0)):
Expand All @@ -1192,13 +1178,7 @@ def monkhorst_automatic(kpts: tuple[int, int, int] = (2, 2, 2), shift: Vector3D
Returns:
Kpoints object
"""
return Kpoints(
"Automatic kpoint scheme",
0,
Kpoints.supported_modes.Monkhorst,
kpts=[kpts],
kpts_shift=shift,
)
return Kpoints("Automatic kpoint scheme", 0, Kpoints.supported_modes.Monkhorst, kpts=[kpts], kpts_shift=shift)

@staticmethod
def automatic_density(structure: Structure, kppa: float, force_gamma: bool = False):
Expand Down Expand Up @@ -1864,11 +1844,13 @@ def electron_configuration(self):

def write_file(self, filename: str) -> None:
"""
Writes PotcarSingle to a file.
:param filename: Filename.
Write PotcarSingle to a file.
Args:
filename (str): Filename to write to.
"""
with zopen(filename, "wt") as f:
f.write(str(self))
with zopen(filename, "wt") as file:
file.write(str(self))

@staticmethod
def from_file(filename: str) -> PotcarSingle:
Expand Down

0 comments on commit e734bf3

Please sign in to comment.