Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Jan 8, 2025
1 parent 6c234c2 commit 350cc54
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 12 deletions.
2 changes: 2 additions & 0 deletions src/dxtb/_src/calculators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class Calculator(AnalyticalCalculator, AutogradCalculator, NumericalCalculator):
)
"""Names of implemented methods of the Calculator."""

# The main implementation can be found in calculator base classes:
# dxtb/_src/calculators/types/base.py
def calculate(
self,
properties: list[str],
Expand Down
9 changes: 9 additions & 0 deletions src/dxtb/_src/calculators/config/integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,12 @@ def __init__(
"The driver must be of type 'int' or 'str', but "
f"'{type(driver)}' was given."
)

def __str__(self) -> str: # pragma: no cover
return (
f"ConfigIntegrals(level={self.level}, cutoff={self.cutoff}, "
f"driver={self.driver}, uplo={self.uplo})"
)

def __repr__(self) -> str: # pragma: no cover
return str(self)
2 changes: 1 addition & 1 deletion src/dxtb/_src/calculators/config/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
force_convergence: bool = False,
fermi_etemp: float = defaults.FERMI_ETEMP,
fermi_maxiter: int = defaults.FERMI_MAXITER,
fermi_thresh: dict = defaults.FERMI_THRESH,
fermi_thresh: float | int | None = defaults.FERMI_THRESH,
fermi_partition: str | int = defaults.FERMI_PARTITION,
# cache
cache_enabled: bool = defaults.CACHE_ENABLED,
Expand Down
18 changes: 11 additions & 7 deletions src/dxtb/_src/calculators/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,13 +521,17 @@ def __init__(
opts = Config(**opts, **dd)
self.opts = opts

# set integral level based on parametrization
if par.meta is not None:
if par.meta.name is not None:
if "gfn1" in par.meta.name.casefold():
self.opts.ints.level = labels.INTLEVEL_HCORE
elif "gfn2" in par.meta.name.casefold():
self.opts.ints.level = labels.INTLEVEL_QUADRUPOLE
# Set integral level based on parametrization. For the tests, we want
# to turn this off. Otherwise, all GFN2-xTB tests will fail without
# `libcint`, even when integrals are not tested (e.g. D4SC). This is
# caused by the integral factories in this very constructor.
if kwargs.pop("auto_int_level", True):
if par.meta is not None:
if par.meta.name is not None:
if "gfn1" in par.meta.name.casefold():
self.opts.ints.level = labels.INTLEVEL_HCORE
elif "gfn2" in par.meta.name.casefold():
self.opts.ints.level = labels.INTLEVEL_QUADRUPOLE

# create cache
self.cache = CalculatorCache(**dd) if cache is None else cache
Expand Down
2 changes: 1 addition & 1 deletion src/dxtb/_src/scf/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def get_refocc(

n0 = torch.where(
orb_per_shell != 0,
storch.divide(refocc, orb_per_shell),
storch.divide(refocc, orb_per_shell.type(refocc.dtype)),
torch.tensor(0, device=refs.device, dtype=refs.dtype),
)

Expand Down
4 changes: 2 additions & 2 deletions test/test_classical/test_dispersion/test_d4sc.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_single(dtype: torch.dtype, name: str):
ref = sample["edisp_d4sc"].to(**dd)
charges = torch.tensor(0.0, **dd)

calc = Calculator(numbers, GFN2_XTB, opts=opts, **dd)
calc = Calculator(numbers, GFN2_XTB, opts=opts, **dd, auto_int_level=False)

result = calc.singlepoint(positions, charges)
d4sc = calc.interactions.get_interaction("DispersionD4SC")
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_batch(dtype: torch.dtype, name1: str, name2: str):
)
)

calc = Calculator(numbers, GFN2_XTB, opts=opts, **dd)
calc = Calculator(numbers, GFN2_XTB, opts=opts, **dd, auto_int_level=False)

result = calc.singlepoint(positions)
d4sc = calc.interactions.get_interaction("DispersionD4SC")
Expand Down
2 changes: 1 addition & 1 deletion test/test_coulomb/test_es3_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_change_device_fail() -> None:


def test_device_fail_numbers() -> None:
n = torch.tensor([3, 1], device="cpu")
n = torch.tensor([3, 1], dtype=torch.float, device="cpu")
numbers = MockTensor(n)
numbers.device = "cuda"

Expand Down

0 comments on commit 350cc54

Please sign in to comment.