Skip to content

Commit

Permalink
Improve code quality (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede authored Aug 14, 2024
1 parent 9e6c7b4 commit b0f948f
Show file tree
Hide file tree
Showing 29 changed files with 190 additions and 127 deletions.
100 changes: 100 additions & 0 deletions .github/workflows/ubuntu-pytorch-1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: Tests (Ubuntu, PyTorch V1)

on:
push:
branches:
- main
- master
paths-ignore:
- "doc*/**"
- "./*.ya?ml"
- "**/*.md"
- "**/*.rst"

pull_request:
paths-ignore:
- "doc*/**"
- "./*.ya?ml"
- "**/*.md"
- "**/*.rst"

workflow_dispatch:

jobs:
main:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ["3.8", "3.9", "3.10", "3.11"]
torch-version: ["1.11.0", "1.12.1", "1.13.1"]
exclude:
# Check latest versions here: https://download.pytorch.org/whl/torch/
#
# PyTorch now fully supports Python=<3.11
# see: https://github.com/pytorch/pytorch/issues/86566
#
# PyTorch does now support Python 3.12 (Linux) for 2.2.0 and newer
# see: https://github.com/pytorch/pytorch/issues/110436
#
# PyTorch<1.13.0 does only support Python=<3.10
- python-version: "3.11"
torch-version: "1.11.0"
- python-version: "3.11"
torch-version: "1.12.1"

runs-on: ${{ matrix.os }}

defaults:
run:
shell: bash {0}

steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python3 -m pip install --upgrade pip
python3 -m pip install tox
- name: Determine TOXENV
run: echo "TOXENV=py$(echo ${{ matrix.python-version }} | tr -d '.')-torch$(echo ${{ matrix.torch-version }} | tr -d '.')" >> $GITHUB_ENV

- name: Print TOXENV
run: echo "TOXENV is set to '${{ env.TOXENV }}'."

- name: Unittests with tox
run: tox -e ${{ env.TOXENV }}

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
if: >
matrix.python-version == '3.11' &&
matrix.torch-version == '2.2.2' &&
matrix.os == 'ubuntu-latest'
with:
files: ./coverage.xml # optional
token: ${{ secrets.CODECOV_TOKEN }} # required
verbose: true # optional (default = false)
18 changes: 3 additions & 15 deletions .github/workflows/ubuntu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,20 @@ jobs:
matrix:
os: [ubuntu-latest]
# python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
# torch-version: ["1.11.0", "1.12.1", "1.13.1", "2.0.1", "2.1.2", "2.2.2", "2.3.1"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
torch-version: ["1.11.0", "1.12.1", "1.13.1", "2.0.1", "2.1.2", "2.2.2"]
torch-version: ["2.0.1", "2.1.2", "2.2.2", "2.3.1"]
exclude:
# Check latest versions here: https://download.pytorch.org/whl/torch/
#
# PyTorch now fully supports Python=<3.11
# PyTorch fully supports Python=<3.11
# see: https://github.com/pytorch/pytorch/issues/86566
#
# PyTorch does now support Python 3.12 (Linux) for 2.2.0 and newer
# PyTorch supports Python 3.12 (Linux) for 2.2.0 and newer
# see: https://github.com/pytorch/pytorch/issues/110436
- python-version: "3.12"
torch-version: "1.11.0"
- python-version: "3.12"
torch-version: "1.12.1"
- python-version: "3.12"
torch-version: "1.13.1"
- python-version: "3.12"
torch-version: "2.0.1"
- python-version: "3.12"
torch-version: "2.1.2"
# PyTorch<1.13.0 does only support Python=<3.10
- python-version: "3.11"
torch-version: "1.11.0"
- python-version: "3.11"
torch-version: "1.12.1"

runs-on: ${{ matrix.os }}

Expand Down
11 changes: 6 additions & 5 deletions examples/profiling/dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
dxtb.timer.start("Setup")

dxtb.timer.start("Ihelp", parent_uid="Setup")
ihelp = dxtb.IndexHelper.from_numbers(numbers, dxtb.GFN1_XTB, batch_mode=batch_mode)
ihelp_cuda = dxtb.IndexHelper.from_numbers(
numbers, dxtb.GFN1_XTB, batch_mode=batch_mode
)
dxtb.timer.stop("Ihelp")

dxtb.timer.start("Class", parent_uid="Setup")
Expand All @@ -56,7 +58,7 @@
dxtb.timer.start("Cache")

torch.cuda.synchronize()
cache = obj.get_cache(numbers, ihelp=ihelp)
cache = obj.get_cache(numbers, ihelp=ihelp_cuda)

torch.cuda.synchronize()
dxtb.timer.stop("Cache")
Expand All @@ -74,14 +76,13 @@

numbers = numbers.cpu()
positions = positions.cpu()
ihelp = ihelp.cpu()
dd: DD = {"device": torch.device("cpu"), "dtype": torch.double}

dxtb.timer.reset()
dxtb.timer.start("Setup")

dxtb.timer.start("Ihelp", parent_uid="Setup")
ihelp = dxtb.IndexHelper.from_numbers(numbers, dxtb.GFN1_XTB, batch_mode=batch_mode)
ihelp_cpu = dxtb.IndexHelper.from_numbers(numbers, dxtb.GFN1_XTB, batch_mode=batch_mode)
dxtb.timer.stop("Ihelp")

dxtb.timer.start("Class", parent_uid="Setup")
Expand All @@ -92,7 +93,7 @@
dxtb.timer.stop("Setup")
dxtb.timer.start("Cache")

cache = obj.get_cache(numbers, ihelp=ihelp)
cache = obj.get_cache(numbers, ihelp=ihelp_cpu)

dxtb.timer.stop("Cache")
dxtb.timer.start("Energy")
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ install_requires =
tad-multicharge
tomli
tomli-w
torch>=1.11.0,<=2.2.2
torch>=1.11.0,<2.4
typing-extensions
python_requires = >=3.8, <3.12
package_dir =
Expand Down
5 changes: 2 additions & 3 deletions src/dxtb/_src/basis/bas.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def to_bse(
f"Available options are: {allowed_formats}."
)

header = ""
if with_header is True:
l = 70 * "-"
header = (
Expand All @@ -317,9 +318,7 @@ def to_bse(
s = 0
fulltxt = ""
for i, number in enumerate(self.unique.tolist()):
txt = ""
if with_header is True:
txt += header # type: ignore
txt = header

if qcformat == "gaussian94":
txt += f"{pse.Z2S[number]}\n"
Expand Down
4 changes: 2 additions & 2 deletions src/dxtb/_src/calculators/properties/vibration/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ def _get_rotational_modes(mass: Tensor, positions: Tensor):

# Eigendecomposition yields the principal moments of inertia (w)
# and the principal axes of rotation (paxes) of a molecule.
w, paxes = storch.eighb(im)
_, paxes = storch.eighb(im)

# make z-axis rotation vector with smallest moment of inertia
w = torch.flip(w, [-1])
# w = torch.flip(w, [-1])
paxes = torch.flip(paxes, [-1])
ex, ey, ez = paxes.mT

Expand Down
12 changes: 3 additions & 9 deletions src/dxtb/_src/calculators/types/analytical.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def forces_analytical(
Tensor
Atomic forces of shape ``(..., nat, 3)``.
"""
total_grad = torch.zeros(positions.shape, **self.dd)

# DEVNOTE: We need to save certain properties from the energy
# calculation for the analytical derivative. So, we check in the
# options if those quantities were cached. If not, we correct the
Expand All @@ -111,14 +113,6 @@ def forces_analytical(

self.energy(positions, chrg, spin, **kwargs)

# Setup

chrg = any_to_tensor(chrg, **self.dd)
if spin is not None:
spin = any_to_tensor(spin, **self.dd)

total_grad = torch.zeros(positions.shape, **self.dd)

# CLASSICAL CONTRIBUTIONS

if len(self.classicals.components) > 0:
Expand All @@ -139,7 +133,7 @@ def forces_analytical(
timer.stop("Classicals Gradient")
OutputHandler.write_stdout("done", v=3)

if any(x in ["all", "scf"] for x in self.opts.exclude):
if {"all", "scf"} & set(self.opts.exclude):
return -total_grad

# SELF-CONSISTENT FIELD PROCEDURE
Expand Down
3 changes: 0 additions & 3 deletions src/dxtb/_src/calculators/types/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@
__all__ = ["EnergyCalculator"]


logger = logging.getLogger(__name__)


class EnergyCalculator(BaseCalculator):
"""
Parametrized calculator defining the extended tight-binding model.
Expand Down
4 changes: 1 addition & 3 deletions src/dxtb/_src/cli/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,7 @@ def singlepoint(self) -> Result | None:
numbers = pack(_n)
positions = pack(_p)
else:
_n, _p = read.read_from_path(args.file[0], args.filetype)
numbers = torch.tensor(_n, dtype=torch.long, device=dd["device"])
positions = torch.tensor(_p, **dd)
numbers, positions = read.read_from_path(args.file[0], args.filetype, **dd)

timer.stop("Read Files")

Expand Down
5 changes: 4 additions & 1 deletion src/dxtb/_src/components/interactions/solvation/alpb.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
from tad_mctc.math import einsum

from dxtb._src.param import Param
from dxtb._src.typing import DD, Any, Tensor, TensorLike, get_default_dtype
from dxtb._src.typing import DD, Any, Tensor, TensorLike, get_default_dtype, override
from dxtb._src.typing.exceptions import DeviceError

from ..base import Interaction, InteractionCache
Expand Down Expand Up @@ -234,6 +234,7 @@ def __init__(
kwargs["rvdw"] = VDW_D3.to(**self.dd)[numbers]
self.born_kwargs = kwargs

@override
def get_cache(
self, numbers: Tensor, positions: Tensor, **_
) -> GeneralizedBornCache:
Expand Down Expand Up @@ -289,9 +290,11 @@ def get_cache(
self.cache = GeneralizedBornCache(mat)
return self.cache

@override
def get_atom_energy(self, charges: Tensor, cache: GeneralizedBornCache) -> Tensor:
return 0.5 * charges * self.get_atom_potential(charges, cache)

@override
def get_atom_potential(
self, charges: Tensor, cache: GeneralizedBornCache
) -> Tensor:
Expand Down
2 changes: 2 additions & 0 deletions src/dxtb/_src/integral/driver/libcint/base_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def setup(self, positions: Tensor, **kwargs) -> None:
IndexHelper.from_numbers(number, self.par)
for number in self.numbers
]
else:
raise ValueError(f"Unknown batch mode '{self.ihelp.batch_mode}'.")

assert isinstance(atombases, list)
self.drv = [
Expand Down
20 changes: 9 additions & 11 deletions src/dxtb/_src/integral/driver/pytorch/impls/md/explicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,6 @@ def de_s(
e030 = rpj * e020 + e021
e020 = rpj * e010 + e011

e021 = xij * e010 + rpj * e011
e030 = rpj * e020 + e021
e022 = xij * e011
e031 = xij * e020 + rpj * e021 + 2 * e022
e040 = rpj * e030 + e031
Expand Down Expand Up @@ -789,14 +787,14 @@ def de_p(
e210 = rpi * e110 + e111
f110 = a * e210 - e010

e021 = xij * e010 + rpj * e011
e030 = rpj * e020 + e021
# e021 = xij * e010 + rpj * e011
# e030 = rpj * e020 + e021
f020 = 2 * e010 - b * e030

e111 = xij * e100 + rpj * e101
# e111 = xij * e100 + rpj * e101
e112 = xij * e011
e211 = xij * e110 + rpi * e111 + 2 * e112
e120 = rpj * e110 + e111
# e120 = rpj * e110 + e111
e220 = rpj * e210 + e211
f120 = a * e220 - e020

Expand Down Expand Up @@ -1612,10 +1610,10 @@ def de_f(
f110 = a * e210 - e010
f210 = a * e310 - 2 * e110

e302 = xij * e201 + rpi * e202
e021 = xij * e010 + rpj * e011
e022 = xij * e011
e032 = xij * e021 + rpj * e022
# e302 = xij * e201 + rpi * e202
# e021 = xij * e010 + rpj * e011
# e022 = xij * e011
# e032 = xij * e021 + rpj * e022
e311 = xij * e300 + rpj * e301 + 2 * e302

e410 = rpi * e310 + e311
Expand All @@ -1630,7 +1628,7 @@ def de_f(
f120 = a * e220 - e020
f130 = a * e230 - e030

e320 = rpj * e310 + e311
# e320 = rpj * e310 + e311
f220 = a * e320 - 2 * e120
f230 = a * e330 - 2 * e130

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def md_recursion_gradient(
# for single gaussians (e.g. in tests)
if len(vec.shape) == 1:
vec = torch.unsqueeze(vec, 0)
s3d = torch.unsqueeze(s3d, 0)
# s3d = torch.unsqueeze(s3d, 0)
ds3d = torch.unsqueeze(ds3d, 0)

# calc E function for all (ai, aj)-combis for all vecs in batch
Expand Down
Loading

0 comments on commit b0f948f

Please sign in to comment.