From d8790c33ec327e61241f073e7978e030000fe3a4 Mon Sep 17 00:00:00 2001 From: marvinfriede <51965259+marvinfriede@users.noreply.github.com> Date: Sun, 20 Oct 2024 12:38:04 -0500 Subject: [PATCH] Fix batched charges --- examples/issues/179/run.py | 105 ++++++++++++++++++++ src/dxtb/_src/calculators/types/energy.py | 3 + src/dxtb/_src/calculators/utils.py | 73 ++++++++++++++ src/dxtb/_src/components/classicals/list.py | 2 +- src/dxtb/_src/scf/iterator.py | 5 +- test/test_calculator/test_general.py | 48 ++++++++- 6 files changed, 232 insertions(+), 4 deletions(-) create mode 100644 examples/issues/179/run.py create mode 100644 src/dxtb/_src/calculators/utils.py diff --git a/examples/issues/179/run.py b/examples/issues/179/run.py new file mode 100644 index 00000000..ddbac09a --- /dev/null +++ b/examples/issues/179/run.py @@ -0,0 +1,105 @@ +import torch + +import dxtb +from dxtb.typing import DD + +dd: DD = {"device": torch.device("cpu"), "dtype": torch.double} + +num1 = torch.tensor( + [8, 1, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1], + device=dd["device"], +) +num2 = torch.tensor( + [8, 1, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1], + device=dd["device"], +) + +pos1 = torch.tensor( + [ + [-7.68281384, 1.3350934, 0.74846383], + [-5.7428588, 1.31513411, 0.36896714], + [-8.23756184, -0.19765779, 1.67193897], + [-8.13313558, 2.93710683, 1.6453921], + [-2.95915993, 1.40005084, 0.24966306], + [-2.1362031, 1.4795743, -1.38758999], + [-2.40235213, 2.84218589, 1.24419946], + [-8.2640369, 5.79677268, 2.54733192], + [-8.68767571, 7.18194193, 1.3350556], + [-9.27787497, 6.09327071, 4.03498102], + [-9.34575393, -2.54164384, 3.28062124], + [-8.59029812, -3.46388688, 4.6567765], + [-10.71898011, -3.58163572, 2.65211723], + [-9.5591796, 9.66793334, -0.53212042], + [-8.70438089, 11.29169941, -0.5990394], + [-11.12723654, 9.8483266, -1.43755624], + [-2.69970054, 5.55135395, 2.96084179], + [-1.59244386, 6.50972855, 4.06699298], + [-4.38439138, 6.18065165, 3.1939773], + ], + **dd +) + +pos2 = torch.tensor( + [ + [-7.67436676, 1.33433562, 0.74512468], + [-5.75285545, 1.30220838, 0.37189432], + [-8.23155251, -0.20308887, 1.67397231], + [-8.15184386, 2.94589406, 1.6474141], + [-2.96380866, 1.39739578, 0.24572676], + [-2.14413995, 1.48993378, -1.37321106], + [-2.39808135, 2.86614761, 1.25247646], + [-8.26855335, 5.79452391, 2.54948621], + [-8.69277797, 7.18061912, 1.33247046], + [-9.28819287, 6.08797948, 4.03809906], + [-9.3377226, -2.54245643, 3.27861813], + [-8.59693106, -3.48501402, 4.65503795], + [-10.72627446, -3.59514726, 2.66139579], + [-9.55955755, 9.6716561, -0.53106973], + [-8.7077635, 11.28708848, -0.59527696], + [-11.12540351, 9.87000175, -1.44181568], + [-2.70194931, 5.55490663, 2.9641866], + [-1.60305656, 6.49854138, 4.07984311], + [-4.39083534, 6.17898869, 3.18702311], + ], + **dd +) + +charge1 = torch.tensor(1, **dd) +charge2 = torch.tensor(1, **dd) + + +############################################################################## + + +numbers = torch.stack([num1, num2]) +positions = torch.stack([pos1, pos2]) +charge = torch.tensor([charge1, charge2]) + +# no conformers -> batched mode 1 +opts = {"verbosity": 0, "batch_mode": 1} + +calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, opts=opts, **dd) +result = calc.energy(positions, chrg=charge) + + +############################################################################## + + +calc = dxtb.Calculator(num1, dxtb.GFN1_XTB, opts={"verbosity": 0}, **dd) +result1 = calc.energy(pos1, chrg=charge1) + + +############################################################################## + + +calc = dxtb.Calculator(num2, dxtb.GFN1_XTB, opts={"verbosity": 0}, **dd) +result2 = calc.energy(pos2, chrg=charge2) + + +############################################################################## + + +assert torch.allclose(result[0], result1) +assert torch.allclose(result[1], result2) + +print("Issue 179 is fixed!") diff --git a/src/dxtb/_src/calculators/types/energy.py b/src/dxtb/_src/calculators/types/energy.py index d096eff1..602c0be5 100644 --- a/src/dxtb/_src/calculators/types/energy.py +++ b/src/dxtb/_src/calculators/types/energy.py @@ -36,6 +36,7 @@ from dxtb._src.utils.tensors import tensor_id from ..result import Result +from ..utils import shape_checks_chrg from . import decorators as cdec from .base import BaseCalculator @@ -116,6 +117,8 @@ def singlepoint( if spin is not None: spin = any_to_tensor(spin, **self.dd) + assert shape_checks_chrg(chrg, self.numbers.ndim, name="Charge") + result = Result(positions, **self.dd) ########################### diff --git a/src/dxtb/_src/calculators/utils.py b/src/dxtb/_src/calculators/utils.py new file mode 100644 index 00000000..39076d85 --- /dev/null +++ b/src/dxtb/_src/calculators/utils.py @@ -0,0 +1,73 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Calculators: Utility +==================== + +Collection of utility functions for the calculator. +""" + +from __future__ import annotations + +from dxtb._src.typing import Literal, NoReturn, Tensor + +__all__ = ["shape_checks_chrg"] + + +def shape_checks_chrg( + t: Tensor, ndims: int, name: str = "Charge" +) -> Literal[True] | NoReturn: + """ + Check the shape of a tensor. + + Parameters + ---------- + t : Tensor + The tensor to check. + ndims : int + The number of dimensions indicating single or batched calculations. + + Raises + ------ + ValueError + If the tensor has not the expected number of dimensions. + """ + + if t.ndim > 1: + raise ValueError( + f"{name.title()} tensor has more than 1 dimension. " + "Please use a 1D tensor for batched calculations " + "(e.g., `torch.tensor([1.0, 0.0])`), instead of " + "a 2D tensor (e.g., NOT `torch.tensor([[1.0], [0.0]])`)." + ) + + if t.ndim == 1 and t.numel() == 1: + raise ValueError( + f"{name.title()} tensor has only one element. Please use a " + "scalar for single structures (e.g., `torch.tensor(1.0)`) and " + "a 1D tensor for batched calculations (e.g., " + ) + + if ndims != t.ndim + 1: + raise ValueError( + f"{name.title()} tensor has invalid shape: {t.shape}.\n " + "Please use a scalar for single structures (e.g., " + "`torch.tensor(1.0)`) and a 1D tensor for batched " + "calculations (e.g., `torch.tensor([1.0, 0.0])`)." + ) + + return True diff --git a/src/dxtb/_src/components/classicals/list.py b/src/dxtb/_src/components/classicals/list.py index a7002afc..7030c80b 100644 --- a/src/dxtb/_src/components/classicals/list.py +++ b/src/dxtb/_src/components/classicals/list.py @@ -112,7 +112,7 @@ def get_energy( if len(self.components) <= 0: return {"none": positions.new_zeros(positions.shape[:-1])} - energies = {} + energies: dict[str, Tensor] = {} for classical in self.components: timer.start(classical.label, parent_uid="Classicals") energies[classical.label] = classical.get_energy( diff --git a/src/dxtb/_src/scf/iterator.py b/src/dxtb/_src/scf/iterator.py index 1321d675..2c2a33fd 100644 --- a/src/dxtb/_src/scf/iterator.py +++ b/src/dxtb/_src/scf/iterator.py @@ -32,6 +32,7 @@ from tad_mctc import storch from dxtb import IndexHelper +from dxtb._src.calculators.utils import shape_checks_chrg from dxtb._src.components.interactions import ( InteractionList, InteractionListCache, @@ -176,8 +177,10 @@ def get_refocc( torch.tensor(0, device=refs.device, dtype=refs.dtype), ) + assert shape_checks_chrg(chrg, n0.ndim, name="Charge") + # Obtain the reference occupation and total number of electrons - nel = torch.sum(n0, -1) - torch.sum(chrg, -1) + nel = torch.sum(n0, -1) - chrg # get alpha and beta electrons and occupation nab = filling.get_alpha_beta_occupation(nel, spin) diff --git a/test/test_calculator/test_general.py b/test/test_calculator/test_general.py index 81adbc19..e86a45f3 100644 --- a/test/test_calculator/test_general.py +++ b/test/test_calculator/test_general.py @@ -27,10 +27,16 @@ from dxtb import GFN1_XTB as par from dxtb import Calculator, labels from dxtb._src.timing import timer +from dxtb.typing import DD +from ..conftest import DEVICE -def test_fail() -> None: - numbers = torch.tensor([6, 1, 1, 1, 1], dtype=torch.double) + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_fail(dtype: torch.dtype) -> None: + dd: DD = {"dtype": dtype, "device": DEVICE} + + numbers = torch.tensor([6, 1, 1, 1, 1], **dd) with pytest.raises(DtypeError): Calculator(numbers, par, opts={"verbosity": 0}) @@ -39,6 +45,44 @@ def test_fail() -> None: timer.reset() +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_fail_charge_single(dtype: torch.dtype) -> None: + dd: DD = {"dtype": dtype, "device": DEVICE} + + numbers = torch.tensor([3, 1], device=DEVICE) + positions = torch.zeros(2, 3, **dd) + + calc = Calculator(numbers, par, opts={"verbosity": 0}) + + # charge must be a scalar for single structure + with pytest.raises(ValueError) as excinfo: + charge = torch.tensor([0.0], **dd) + calc.singlepoint(positions, chrg=charge) + + assert "Charge tensor has only one element" in str(excinfo) + + +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_fail_charge_batch(dtype: torch.dtype) -> None: + dd: DD = {"dtype": dtype, "device": DEVICE} + + numbers = torch.tensor([[3, 1], [3, 1]], device=DEVICE) + positions = torch.zeros(2, 2, 3, **dd) + + calc = Calculator(numbers, par, opts={"verbosity": 0}) + with pytest.raises(ValueError) as excinfo: + charge = torch.tensor([[0.0], [0.0]], **dd) + calc.singlepoint(positions, chrg=charge) + + assert "Charge tensor has more than 1 dimension" in str(excinfo) + + with pytest.raises(ValueError) as excinfo: + charge = torch.tensor(0.0, **dd) + calc.singlepoint(positions, chrg=charge) + + assert "Charge tensor has invalid shape" in str(excinfo) + + def run_asserts(c: Calculator, dtype: torch.dtype) -> None: assert c.dtype == dtype assert c.classicals.dtype == dtype