Skip to content

Commit

Permalink
Use Self type in Method Signatures (#3705)
Browse files Browse the repository at this point in the history
* add some Self type annotations

* add more Self and type annotations

* add more Self and type annotations

* pre-commit auto-fixes

* add more Self and type annotations

* format docstring to google style

* add more Self types

* TEMP FIX for normalized repr

* add more Self types

* remove encoding for binary

* use `Attributes` over `Parameters`

* add Self return type to __new__ and __enter__ methods

* use `Self` for multi-line def

* `mypy` fixes in core

* `mypy` fixes in core

* move `mypy` fixes

* add Note tag

* remove unused arg properties from core.ion

* fix mypy error

* fix unit test for reaction energy calc

* fix unit test for core.units

* remove ERROR = UnitError

* google-style doc string return type format

* replace deprecated utcnow

* fix missing datetime import

* replace `app = lines.append`

* slab type hints

* relocate Slab import

* fix vector in adsorption

* use time for time count

* revert datetime utc

* switch to `time`

* switch to `datetime` for utc time for now

* take some  coderabbitai suggestions

* ignore some mypy override errors

* fix type ufloat

* fix BaseLammpsGenerator missing fields inputfile and data (did this actually work before) plus fix mypy errors

* `mypy` fixes

* `mypy` fixes

* more `mypy` fixes

* mypy fixes

down to 69 errors

* fix unit test

* reapply fix for voronoi

* fix mypy and voronoi unit test

* fix unit test for voronoi

* more `mypy` fixes

* one more `mypy` fix

* pre-commit auto-fixes

* fix number import name

* fix lammps generator

* avoid self type hard-coding

* mypy fix

* `sourcery` fix

* `sourcery` fix

* `sourcery` fix

* `mypy` fix

* suppress `mypy` error

* `mypy` and `sourcery` fix

* `mypy` fixes

* fix unit tests

* `mypy` fixes

* `mypy` fixes

* suppress `mypy` error in xcfunc

* fix `mypy` error

* fix `mypy` error

* refactor date parsing in AirssProvider to use std lib datetime instead of pypi pkg dateutil

* fix TestKpoints check kpts_shift

* fix mypy error in sets.py: cast(Sequence[Sequence[float]], kpoints)

* refactor

* fix mypy confusing keyword and positional args in Lattice.from_parameters

* remove now unused mypy ignore in Cssr.from_str

* fix bad refactor in 8bdb484

---------

Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
DanielYang59 and janosh authored Mar 29, 2024
1 parent 2500664 commit 6a06f3c
Show file tree
Hide file tree
Showing 128 changed files with 1,324 additions and 1,202 deletions.
26 changes: 13 additions & 13 deletions dev_scripts/potcar_scrambler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
import warnings
from glob import glob
from typing import TYPE_CHECKING

import numpy as np
from monty.os.path import zpath
Expand All @@ -14,6 +15,9 @@
from pymatgen.io.vasp.sets import _load_yaml_config
from pymatgen.util.testing import VASP_IN_DIR

if TYPE_CHECKING:
from typing_extensions import Self


class PotcarScrambler:
"""
Expand All @@ -34,26 +38,22 @@ class PotcarScrambler:
from existing POTCAR `input_filename`
"""

def __init__(self, potcars: Potcar | PotcarSingle):
if isinstance(potcars, PotcarSingle):
self.PSP_list = [potcars]
else:
self.PSP_list = potcars
def __init__(self, potcars: Potcar | PotcarSingle) -> None:
self.PSP_list = [potcars] if isinstance(potcars, PotcarSingle) else potcars
self.scrambled_potcars_str = ""
for psp in self.PSP_list:
scrambled_potcar_str = self.scramble_single_potcar(psp)
self.scrambled_potcars_str += scrambled_potcar_str
return

def _rand_float_from_str_with_prec(self, input_str: str, bloat: float = 1.5):
def _rand_float_from_str_with_prec(self, input_str: str, bloat: float = 1.5) -> float:
n_prec = len(input_str.split(".")[1])
bd = max(1, bloat * abs(float(input_str)))
return round(bd * np.random.rand(1)[0], n_prec)

def _read_fortran_str_and_scramble(self, input_str: str, bloat: float = 1.5):
input_str = input_str.strip()

if input_str.lower() in ("t", "f", "true", "false"):
if input_str.lower() in {"t", "f", "true", "false"}:
return bool(np.random.randint(2))

if input_str.upper() == input_str.lower() and input_str[0].isnumeric():
Expand All @@ -68,7 +68,7 @@ def _read_fortran_str_and_scramble(self, input_str: str, bloat: float = 1.5):
except ValueError:
return input_str

def scramble_single_potcar(self, potcar: PotcarSingle):
def scramble_single_potcar(self, potcar: PotcarSingle) -> str:
"""
Scramble the body of a POTCAR, retain the PSCTR header information.
Expand Down Expand Up @@ -124,20 +124,20 @@ def scramble_single_potcar(self, potcar: PotcarSingle):
)
return scrambled_potcar_str

def to_file(self, filename: str):
def to_file(self, filename: str) -> None:
with zopen(filename, mode="wt") as file:
file.write(self.scrambled_potcars_str)

@classmethod
def from_file(cls, input_filename: str, output_filename: str | None = None):
def from_file(cls, input_filename: str, output_filename: str | None = None) -> Self:
psp = Potcar.from_file(input_filename)
psp_scrambled = cls(psp)
if output_filename:
psp_scrambled.to_file(output_filename)
return psp_scrambled


def generate_fake_potcar_libraries():
def generate_fake_potcar_libraries() -> None:
"""
To test the `_gen_potcar_summary_stats` function in `pymatgen.io.vasp.inputs`,
need a library of fake POTCARs which do not violate copyright
Expand Down Expand Up @@ -173,7 +173,7 @@ def generate_fake_potcar_libraries():
break


def potcar_cleanser():
def potcar_cleanser() -> None:
"""
Function to replace copyrighted POTCARs used in io.vasp.sets testing
with dummy POTCARs that have scrambled PSP and kinetic energy values
Expand Down
12 changes: 6 additions & 6 deletions dev_scripts/regen_libxcfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ def write_libxc_docs_json(xc_funcs, json_path):
xc_funcs = deepcopy(xc_funcs)

# Remove XC_FAMILY from Family and XC_ from Kind to make strings more human-readable.
for d in xc_funcs.values():
d["Family"] = d["Family"].replace("XC_FAMILY_", "", 1)
d["Kind"] = d["Kind"].replace("XC_", "", 1)
for dct in xc_funcs.values():
dct["Family"] = dct["Family"].replace("XC_FAMILY_", "", 1)
dct["Kind"] = dct["Kind"].replace("XC_", "", 1)

# Build lightweight version with a subset of keys.
for num, d in xc_funcs.items():
xc_funcs[num] = {key: d[key] for key in ("Family", "Kind", "References")}
for num, dct in xc_funcs.items():
xc_funcs[num] = {key: dct[key] for key in ("Family", "Kind", "References")}
# Descriptions are optional
for opt in ("Description 1", "Description 2"):
desc = d.get(opt)
desc = dct.get(opt)
if desc is not None:
xc_funcs[num][opt] = desc

Expand Down
2 changes: 1 addition & 1 deletion pymatgen/alchemy/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __repr__(self):
]
)

def as_dict(self):
def as_dict(self) -> dict:
"""Returns: MSONable dict."""
return {
"@module": type(self).__module__,
Expand Down
14 changes: 8 additions & 6 deletions pymatgen/alchemy/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from typing_extensions import Self

from pymatgen.alchemy.filters import AbstractStructureFilter


Expand Down Expand Up @@ -212,7 +214,7 @@ def write_vasp_input(
**kwargs: All keyword args supported by the VASP input set.
"""
vasp_input_set(self.final_structure, **kwargs).write_input(output_dir, make_dir_if_not_present=create_directory)
with open(f"{output_dir}/transformations.json", mode="w") as file:
with open(f"{output_dir}/transformations.json", mode="w", encoding="utf-8") as file:
json.dump(self.as_dict(), file)

def __str__(self) -> str:
Expand Down Expand Up @@ -267,7 +269,7 @@ def from_cif_str(
transformations: list[AbstractTransformation] | None = None,
primitive: bool = True,
occupancy_tolerance: float = 1.0,
) -> TransformedStructure:
) -> Self:
"""Generates TransformedStructure from a cif string.
Args:
Expand Down Expand Up @@ -311,7 +313,7 @@ def from_poscar_str(
cls,
poscar_string: str,
transformations: list[AbstractTransformation] | None = None,
) -> TransformedStructure:
) -> Self:
"""Generates TransformedStructure from a poscar string.
Args:
Expand Down Expand Up @@ -339,12 +341,12 @@ def as_dict(self) -> dict[str, Any]:
dct["@module"] = type(self).__module__
dct["@class"] = type(self).__name__
dct["history"] = jsanitize(self.history)
dct["last_modified"] = str(datetime.datetime.utcnow())
dct["last_modified"] = str(datetime.datetime.now(datetime.timezone.utc))
dct["other_parameters"] = jsanitize(self.other_parameters)
return dct

@classmethod
def from_dict(cls, dct: dict) -> TransformedStructure:
def from_dict(cls, dct: dict) -> Self:
"""Creates a TransformedStructure from a dict."""
struct = Structure.from_dict(dct)
return cls(struct, history=dct["history"], other_parameters=dct.get("other_parameters"))
Expand Down Expand Up @@ -376,7 +378,7 @@ def to_snl(self, authors: list[str], **kwargs) -> StructureNL:
return StructureNL(self.final_structure, authors, history=history, **kwargs)

@classmethod
def from_snl(cls, snl: StructureNL) -> TransformedStructure:
def from_snl(cls, snl: StructureNL) -> Self:
"""Create TransformedStructure from SNL.
Args:
Expand Down
34 changes: 18 additions & 16 deletions pymatgen/alchemy/transmuters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from typing_extensions import Self

__author__ = "Shyue Ping Ong, Will Richards"
__copyright__ = "Copyright 2012, The Materials Project"
__version__ = "0.1"
Expand All @@ -42,7 +44,7 @@ def __init__(
transformations=None,
extend_collection: int = 0,
ncores: int | None = None,
):
) -> None:
"""Initializes a transmuter from an initial list of
pymatgen.alchemy.materials.TransformedStructure.
Expand Down Expand Up @@ -71,7 +73,16 @@ def __getitem__(self, index):
def __getattr__(self, name):
return [getattr(x, name) for x in self.transformed_structures]

def undo_last_change(self):
def __len__(self):
return len(self.transformed_structures)

def __str__(self):
output = ["Current structures", "------------"]
for x in self.transformed_structures:
output.append(str(x.final_structure))
return "\n".join(output)

def undo_last_change(self) -> None:
"""Undo the last transformation in the TransformedStructure.
Raises:
Expand All @@ -80,7 +91,7 @@ def undo_last_change(self):
for x in self.transformed_structures:
x.undo_last_change()

def redo_next_change(self):
def redo_next_change(self) -> None:
"""Redo the last undone transformation in the TransformedStructure.
Raises:
Expand All @@ -89,9 +100,6 @@ def redo_next_change(self):
for x in self.transformed_structures:
x.redo_next_change()

def __len__(self):
return len(self.transformed_structures)

def append_transformation(self, transformation, extend_collection=False, clear_redo=True):
"""Appends a transformation to all TransformedStructures.
Expand Down Expand Up @@ -178,12 +186,6 @@ def add_tags(self, tags):
"""
self.set_parameter("tags", tags)

def __str__(self):
output = ["Current structures", "------------"]
for x in self.transformed_structures:
output.append(str(x.final_structure))
return "\n".join(output)

def append_transformed_structures(self, trafo_structs_or_transmuter):
"""Method is overloaded to accept either a list of transformed structures
or transmuter, it which case it appends the second transmuter"s
Expand All @@ -201,7 +203,7 @@ def append_transformed_structures(self, trafo_structs_or_transmuter):
self.transformed_structures.extend(trafo_structs_or_transmuter)

@classmethod
def from_structures(cls, structures, transformations=None, extend_collection=0):
def from_structures(cls, structures, transformations=None, extend_collection=0) -> Self:
"""Alternative constructor from structures rather than
TransformedStructures.
Expand Down Expand Up @@ -256,7 +258,7 @@ def __init__(self, cif_string, transformations=None, primitive=True, extend_coll
super().__init__(transformed_structures, transformations, extend_collection)

@classmethod
def from_filenames(cls, filenames, transformations=None, primitive=True, extend_collection=False):
def from_filenames(cls, filenames, transformations=None, primitive=True, extend_collection=False) -> Self:
"""Generates a TransformedStructureCollection from a cif, possibly
containing multiple structures.
Expand All @@ -269,7 +271,7 @@ def from_filenames(cls, filenames, transformations=None, primitive=True, extend_
"""
cif_files = []
for filename in filenames:
with open(filename) as file:
with open(filename, encoding="utf-8") as file:
cif_files.append(file.read())
return cls(
"\n".join(cif_files),
Expand Down Expand Up @@ -308,7 +310,7 @@ def from_filenames(cls, poscar_filenames, transformations=None, extend_collectio
"""
trafo_structs = []
for filename in poscar_filenames:
with open(filename) as file:
with open(filename, encoding="utf-8") as file:
trafo_structs.append(TransformedStructure.from_poscar_str(file.read(), []))
return StandardTransmuter(trafo_structs, transformations, extend_collection=extend_collection)

Expand Down
Loading

0 comments on commit 6a06f3c

Please sign in to comment.