Skip to content

Commit

Permalink
Fix code scanning issues (#165)
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede authored Sep 16, 2024
1 parent 0a9e4c2 commit 4562481
Show file tree
Hide file tree
Showing 196 changed files with 2,385 additions and 762 deletions.
7 changes: 5 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# adapt path to include the source code
sys.path.insert(0, op.join(op.dirname(__file__), "../../", "src"))

import dxtb
import dxtb # pylint: disable=unused-import

project = "Fully Differentiable Extended Tight-Binding"
author = "Grimme Group"
Expand Down Expand Up @@ -83,7 +83,10 @@
"tad_dftd4": ("https://tad-dftd4.readthedocs.io/en/latest/", None),
"tad_libcint": ("https://tad-libcint.readthedocs.io/en/latest/", None),
"tad_mctc": ("https://tad-mctc.readthedocs.io/en/latest/", None),
"tad_multicharge": ("https://tad-multicharge.readthedocs.io/en/latest/", None),
"tad_multicharge": (
"https://tad-multicharge.readthedocs.io/en/latest/",
None,
),
"torch": ("https://pytorch.org/docs/stable/", None),
}

Expand Down
4 changes: 3 additions & 1 deletion examples/profiling/dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@
dxtb.timer.start("Setup")

dxtb.timer.start("Ihelp", parent_uid="Setup")
ihelp_cpu = dxtb.IndexHelper.from_numbers(numbers, dxtb.GFN1_XTB, batch_mode=batch_mode)
ihelp_cpu = dxtb.IndexHelper.from_numbers(
numbers, dxtb.GFN1_XTB, batch_mode=batch_mode
)
dxtb.timer.stop("Ihelp")

dxtb.timer.start("Class", parent_uid="Setup")
Expand Down
4 changes: 2 additions & 2 deletions examples/profiling/importing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

############################

import torch
import torch # noqa

t1 = time.perf_counter()

Expand All @@ -37,7 +37,7 @@

############################

import scipy
import scipy # noqa

t4 = time.perf_counter()

Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,11 @@ omit = [
fail_under = 75


[tool.black]
line-length = 80


[tool.isort]
line_length = 80
profile = "black"
skip = ["./src/dxtb/__init__.py",]
30 changes: 24 additions & 6 deletions src/dxtb/_src/basis/bas.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@
from tad_mctc.data import pse
from tad_mctc.exceptions import DtypeError

from dxtb._src.param import Param, get_elem_param, get_elem_pqn, get_elem_valence
from dxtb._src.param import (
Param,
get_elem_param,
get_elem_pqn,
get_elem_valence,
)
from dxtb._src.typing import Literal, Self, Tensor, TensorLike, override

from .indexhelper import IndexHelper
Expand Down Expand Up @@ -104,13 +109,21 @@ def __init__(
self.ihelp = ihelp

self.ngauss = get_elem_param(
self.unique, par.element, "ngauss", device=self.device, dtype=torch.uint8
self.unique,
par.element,
"ngauss",
device=self.device,
dtype=torch.uint8,
)
self.slater = get_elem_param(
self.unique, par.element, "slater", **self.dd
)
self.slater = get_elem_param(self.unique, par.element, "slater", **self.dd)
self.pqn = get_elem_pqn(
self.unique, par.element, device=self.device, dtype=torch.uint8
)
self.valence = get_elem_valence(self.unique, par.element, device=self.device)
self.valence = get_elem_valence(
self.unique, par.element, device=self.device
)

def create_cgtos(self) -> tuple[list[Tensor], list[Tensor]]:
"""
Expand Down Expand Up @@ -194,7 +207,9 @@ def unique_shell_pairs(
# products upon multiplication (fundamental theorem of arithmetic).
orbs = primes.to(self.device)[sh2ush]
orbs = orbs.unsqueeze(-2) * orbs.unsqueeze(-1)
sh2orb = self.ihelp.spread_shell_to_orbital(self.ihelp.orbitals_per_shell)
sh2orb = self.ihelp.spread_shell_to_orbital(
self.ihelp.orbitals_per_shell
)

# extra offset along only one dimension to distinguish (n, m) and
# (m, n) of the same orbital block (e.g. 1x3 sp and 3x1 ps block)
Expand Down Expand Up @@ -326,7 +341,10 @@ def to_bse(
shells = self.ihelp.shells_per_atom[i]
for _ in range(shells):
alpha, coeff = slater_to_gauss(
self.ngauss[s], self.pqn[s], self.ihelp.angular[s], self.slater[s]
self.ngauss[s],
self.pqn[s],
self.ihelp.angular[s],
self.slater[s],
)
if self.valence[s].item() is False:
alpha, coeff = orthogonalize(
Expand Down
31 changes: 23 additions & 8 deletions src/dxtb/_src/basis/indexhelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,9 @@ def from_numbers_angular(
device=cpu,
)

ushell_index = torch.cumsum(ushells_per_unique, dim=-1) - ushells_per_unique
ushell_index = (
torch.cumsum(ushells_per_unique, dim=-1) - ushells_per_unique
)
ushells_to_unique = _fill(ushell_index, ushells_per_unique)

if batch_mode > 0:
Expand Down Expand Up @@ -431,13 +433,17 @@ def from_numbers_angular(
lsh >= 0, 2 * lsh + 1, torch.tensor(0, device=cpu)
)

orbital_index = torch.cumsum(orbitals_per_shell, -1) - orbitals_per_shell
orbital_index = (
torch.cumsum(orbitals_per_shell, -1) - orbitals_per_shell
)
orbital_index[orbitals_per_shell == 0] = PAD

if batch_mode > 0:
orbitals_to_shell = pack(
[
_fill(orbital_index[_batch, :], orbitals_per_shell[_batch, :])
_fill(
orbital_index[_batch, :], orbitals_per_shell[_batch, :]
)
for _batch in range(numbers.shape[0])
],
value=PAD,
Expand Down Expand Up @@ -490,7 +496,9 @@ def reduce_orbital_to_shell(
Shell-resolved tensor.
"""

return wrap_scatter_reduce(x, dim, self.orbitals_to_shell, reduce, extra=extra)
return wrap_scatter_reduce(
x, dim, self.orbitals_to_shell, reduce, extra=extra
)

def reduce_shell_to_atom(
self,
Expand Down Expand Up @@ -520,7 +528,9 @@ def reduce_shell_to_atom(
Atom-resolved tensor.
"""

return wrap_scatter_reduce(x, dim, self.shells_to_atom, reduce, extra=extra)
return wrap_scatter_reduce(
x, dim, self.shells_to_atom, reduce, extra=extra
)

def reduce_orbital_to_atom(
self,
Expand Down Expand Up @@ -551,7 +561,9 @@ def reduce_orbital_to_atom(
"""

return self.reduce_shell_to_atom(
self.reduce_orbital_to_shell(x, dim=dim, reduce=reduce, extra=extra),
self.reduce_orbital_to_shell(
x, dim=dim, reduce=reduce, extra=extra
),
dim=dim,
reduce=reduce,
extra=extra,
Expand Down Expand Up @@ -973,7 +985,9 @@ def _orbitals_to_shell_cart(self) -> Tensor:
if self.batch_mode > 0:
orbitals_to_shell = pack(
[
_fill(orbital_index[_batch, :], orbitals_per_shell[_batch, :])
_fill(
orbital_index[_batch, :], orbitals_per_shell[_batch, :]
)
for _batch in range(self.angular.shape[0])
],
value=PAD,
Expand Down Expand Up @@ -1080,7 +1094,8 @@ def orbitals_per_atom(self) -> Tensor:
try:
# batch mode
pad = torch.nn.utils.rnn.pad_sequence(
[self.shells_to_atom.mT, self.orbitals_to_shell.T], padding_value=PAD
[self.shells_to_atom.mT, self.orbitals_to_shell.T],
padding_value=PAD,
)
pad = einsum("ijk->kji", pad) # [2, bs, norb_max]
except RuntimeError:
Expand Down
5 changes: 4 additions & 1 deletion src/dxtb/_src/basis/ortho.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ def orthogonalize(

# Create new basis function from the pair which is orthogonal to the first
# basis function
alpha_new[: alpha_j.shape[-1]], alpha_new[alpha_j.shape[-1] :] = alpha_j, alpha_i
alpha_new[: alpha_j.shape[-1]], alpha_new[alpha_j.shape[-1] :] = (
alpha_j,
alpha_i,
)
coeff_new[: coeff_j.shape[-1]], coeff_new[coeff_j.shape[-1] :] = (
coeff_j,
-overlap * coeff_i,
Expand Down
4 changes: 3 additions & 1 deletion src/dxtb/_src/basis/slater.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@
# Two over pi
top = 2.0 / math.pi

dfactorial = torch.tensor([1.0, 1.0, 3.0, 15.0, 105.0, 945.0, 10395.0, 135135.0])
dfactorial = torch.tensor(
[1.0, 1.0, 3.0, 15.0, 105.0, 945.0, 10395.0, 135135.0]
)
"""
Double factorial up to 7!! for normalization of the Gaussian basis functions.
Expand Down
4 changes: 3 additions & 1 deletion src/dxtb/_src/calculators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,6 @@ def calculate(
spin: Tensor | float | int | None = defaults.SPIN,
**kwargs: Any,
):
AutogradCalculator.calculate(self, properties, positions, chrg, spin, **kwargs)
AutogradCalculator.calculate(
self, properties, positions, chrg, spin, **kwargs
)
16 changes: 13 additions & 3 deletions src/dxtb/_src/calculators/config/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@
import torch

from dxtb._src.constants import defaults, labels
from dxtb._src.typing import Any, PathLike, Self, get_default_device, get_default_dtype
from dxtb._src.typing import (
Any,
PathLike,
Self,
get_default_device,
get_default_dtype,
)

from .cache import ConfigCache
from .integral import ConfigIntegrals
Expand Down Expand Up @@ -359,7 +365,9 @@ def batch_mode(self, value: int) -> None:
If the batch mode is invalid.
"""
if value not in (0, 1, 2):
raise ValueError(f"Invalid batch mode '{value}'. Must be one of [0, 1, 2].")
raise ValueError(
f"Invalid batch mode '{value}'. Must be one of [0, 1, 2]."
)

self._batch_mode = value
self.scf.batch_mode = value
Expand Down Expand Up @@ -400,7 +408,9 @@ def to_json(self, path: PathLike | None = None) -> str:
config_info = self.info()

def serialize(value):
if isinstance(value, torch.device) or isinstance(value, torch.dtype):
if isinstance(value, torch.device) or isinstance(
value, torch.dtype
):
return str(value)
elif isinstance(value, list):
# Recursively serialize lists
Expand Down
4 changes: 3 additions & 1 deletion src/dxtb/_src/calculators/gfn1.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def __init__(
numbers: Tensor,
*,
classical: list[Classical] | tuple[Classical] | Classical | None = None,
interaction: list[Interaction] | tuple[Interaction] | Interaction | None = None,
interaction: (
list[Interaction] | tuple[Interaction] | Interaction | None
) = None,
opts: dict[str, Any] | Config | None = None,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
Expand Down
4 changes: 3 additions & 1 deletion src/dxtb/_src/calculators/gfn2.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def __init__(
numbers: Tensor,
*,
classical: list[Classical] | tuple[Classical] | Classical | None = None,
interaction: list[Interaction] | tuple[Interaction] | Interaction | None = None,
interaction: (
list[Interaction] | tuple[Interaction] | Interaction | None
) = None,
opts: dict[str, Any] | Config | None = None,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
Expand Down
8 changes: 6 additions & 2 deletions src/dxtb/_src/calculators/properties/moments/quad.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
__all__ = ["quadrupole"]


def quadrupole(qat: Tensor, dpat: Tensor, qpat: Tensor, positions: Tensor) -> Tensor:
def quadrupole(
qat: Tensor, dpat: Tensor, qpat: Tensor, positions: Tensor
) -> Tensor:
"""
Analytical calculation of traceless electric quadrupole moment.
Expand Down Expand Up @@ -84,7 +86,9 @@ def quadrupole(qat: Tensor, dpat: Tensor, qpat: Tensor, positions: Tensor) -> Te

# Compute the atomic contributions to molecular quadrupole moment
cart = torch.empty(
(*positions.shape[:-1], 6), device=positions.device, dtype=positions.dtype
(*positions.shape[:-1], 6),
device=positions.device,
dtype=positions.dtype,
)
cart[..., 0] = pv2d[..., 0]
cart[..., 1] = (
Expand Down
4 changes: 3 additions & 1 deletion src/dxtb/_src/calculators/properties/vibration/raman.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def depol_unit(self, *_: Any) -> NoReturn:

# conversion

def to_unit(self, value: Literal["freqs", "ints", "depol"], unit: str) -> Tensor:
def to_unit(
self, value: Literal["freqs", "ints", "depol"], unit: str
) -> Tensor:
"""
Convert a value from one unit to another based on the converter
dictionary.
Expand Down
16 changes: 12 additions & 4 deletions src/dxtb/_src/calculators/properties/vibration/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def freqs_unit(self, value: str) -> None:
raise ValueError(f"Unsupported frequency unit: {value}")
self._freqs_unit = value

def _convert(self, value: Tensor, unit: str, converter: dict[str, float]) -> Tensor:
def _convert(
self, value: Tensor, unit: str, converter: dict[str, float]
) -> Tensor:
"""
Convert a tensor from one unit to another based on the converter
dictionary.
Expand Down Expand Up @@ -113,7 +115,9 @@ def _convert(self, value: Tensor, unit: str, converter: dict[str, float]) -> Ten

return value * converter[unit]

def save_prop_to_pt(self, prop: str, filepath: PathLike | None = None) -> None:
def save_prop_to_pt(
self, prop: str, filepath: PathLike | None = None
) -> None:
"""
Save the results to a PyTorch file.
Expand Down Expand Up @@ -157,7 +161,9 @@ def save_all_to_pt(self, filepaths: list[PathLike] | None = None) -> None:
self.save_prop_to_pt(prop, path)

def __iter__(self) -> Generator[Tensor, None, None]:
return iter(getattr(self, s) for s in get_all_slots(self) if "unit" not in s)
return iter(
getattr(self, s) for s in get_all_slots(self) if "unit" not in s
)

def __getitem__(self, key: str) -> Tensor:
s = get_all_slots(self)
Expand All @@ -173,7 +179,9 @@ def __getitem__(self, key: str) -> Tensor:
if key in s:
return getattr(self, key)

raise KeyError(f"Invalid key: '{key}'. Possible keys are '{', '.join(s)}'.")
raise KeyError(
f"Invalid key: '{key}'. Possible keys are '{', '.join(s)}'."
)

def __str__(self) -> str:
text = ""
Expand Down
Loading

0 comments on commit 4562481

Please sign in to comment.