Skip to content

Commit

Permalink
Udpate multicharge API call (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede authored Jan 17, 2024
1 parent f6988f5 commit e04d9ee
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 38 deletions.
4 changes: 2 additions & 2 deletions src/tad_dftd4/disp.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from tad_mctc import storch
from tad_mctc.batch import real_pairs
from tad_mctc.ncoord import cn_d4, erf_count
from tad_multicharge.eeq import get_charges
from tad_multicharge import get_eeq_charges

from . import data, defaults
from .cutoff import Cutoff
Expand Down Expand Up @@ -113,7 +113,7 @@ def dftd4(
if r4r2 is None:
r4r2 = data.R4R2.to(**dd)[numbers]
if q is None:
q = get_charges(numbers, positions, charge, cutoff=cutoff.cn_eeq)
q = get_eeq_charges(numbers, positions, charge, cutoff=cutoff.cn_eeq)

if numbers.shape != positions.shape[:-1]:
raise ValueError(
Expand Down
38 changes: 19 additions & 19 deletions test/test_disp/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,22 +123,22 @@ class Record(Molecule, Refs):
{
"q": torch.tensor(
[
7.733478452374956e-01,
1.076268996435148e-01,
+7.733478452374956e-01,
+1.076268996435148e-01,
-3.669996418388237e-01,
4.928336699714377e-02,
+4.928336699714377e-02,
-1.833320732359188e-01,
2.333021537750765e-01,
6.618377120945702e-02,
+2.333021537750765e-01,
+6.618377120945702e-02,
-5.439442982394790e-01,
-2.702644018249256e-01,
2.666190421598861e-01,
2.627250807290775e-01,
+2.666190421598861e-01,
+2.627250807290775e-01,
-7.153145902661326e-02,
-3.733008547230057e-01,
3.845854622327315e-02,
+3.845854622327315e-02,
-5.058512350299781e-01,
5.176772579438197e-01,
+5.176772579438197e-01,
],
dtype=torch.float64,
),
Expand Down Expand Up @@ -211,21 +211,21 @@ class Record(Molecule, Refs):
{
"q": torch.tensor(
[
7.383947733600110e-02,
+7.383947733600110e-02,
-1.683548174961888e-01,
-3.476428218085238e-01,
-7.054893587245280e-01,
7.735482364313262e-01,
2.302076155262403e-01,
1.027485077260907e-01,
9.478181684909791e-02,
2.442577506263613e-02,
2.349849530340006e-01,
+7.735482364313262e-01,
+2.302076155262403e-01,
+1.027485077260907e-01,
+9.478181684909791e-02,
+2.442577506263613e-02,
+2.349849530340006e-01,
-3.178399496308427e-01,
6.671128897373533e-01,
+6.671128897373533e-01,
-4.781198581208199e-01,
6.575365749452318e-02,
1.082591170860242e-01,
+6.575365749452318e-02,
+1.082591170860242e-01,
-3.582152405023902e-01,
],
dtype=torch.float64,
Expand Down
1 change: 0 additions & 1 deletion test/test_grad/test_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def test_single(dtype: torch.dtype, name: str) -> None:
sample["hessian"].to(**dd),
torch.Size(2 * (numbers.shape[-1], 3)),
)
print(ref)

# variable to be differentiated
positions.requires_grad_(True)
Expand Down
91 changes: 87 additions & 4 deletions test/test_model/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
class Refs(TypedDict):
"""Format of reference values."""

q: Tensor
"""EEQ charges."""

gw: Tensor
"""
Gaussian weights. Shape must be `(nrefs, natoms)`, which we have to
Expand All @@ -48,7 +51,14 @@ class Record(Molecule, Refs):

refs: dict[str, Refs] = {
"LiH": Refs(
{ # CN, q
{ # CN
"q": torch.tensor(
[
3.708714958301688e-01,
-3.708714958301688e-01,
],
dtype=torch.float64,
),
"gw": reshape_fortran(
torch.tensor(
[
Expand All @@ -75,7 +85,17 @@ class Record(Molecule, Refs):
}
),
"SiH4": Refs(
{ # CN, q
{ # CN
"q": torch.tensor(
[
-8.412842390895063e-02,
2.103210597723753e-02,
2.103210597723774e-02,
2.103210597723764e-02,
2.103210597723773e-02,
],
dtype=torch.float64,
),
"gw": reshape_fortran(
torch.tensor(
[
Expand Down Expand Up @@ -143,6 +163,27 @@ class Record(Molecule, Refs):
),
"MB16_43_01": Refs(
{ # CN
"q": torch.tensor(
[
+7.733478452374956e-01,
+1.076268996435148e-01,
-3.669996418388237e-01,
+4.928336699714377e-02,
-1.833320732359188e-01,
+2.333021537750765e-01,
+6.618377120945702e-02,
-5.439442982394790e-01,
-2.702644018249256e-01,
+2.666190421598861e-01,
+2.627250807290775e-01,
-7.153145902661326e-02,
-3.733008547230057e-01,
+3.845854622327315e-02,
-5.058512350299781e-01,
+5.176772579438197e-01,
],
dtype=torch.float64,
),
"gw": reshape_fortran(
torch.tensor(
[
Expand Down Expand Up @@ -495,7 +536,28 @@ class Record(Molecule, Refs):
}
),
"MB16_43_02": Refs(
{ # q
{
"q": torch.tensor(
[
+7.383947733600110e-02,
-1.683548174961888e-01,
-3.476428218085238e-01,
-7.054893587245280e-01,
+7.735482364313262e-01,
+2.302076155262403e-01,
+1.027485077260907e-01,
+9.478181684909791e-02,
+2.442577506263613e-02,
+2.349849530340006e-01,
-3.178399496308427e-01,
+6.671128897373533e-01,
-4.781198581208199e-01,
+6.575365749452318e-02,
+1.082591170860242e-01,
-3.582152405023902e-01,
],
dtype=torch.float64,
),
"gw": reshape_fortran(
torch.tensor(
[
Expand Down Expand Up @@ -848,7 +910,28 @@ class Record(Molecule, Refs):
}
),
"MB16_43_03": Refs(
{ # CN, q
{ # CN
"q": torch.tensor(
[
-1.7778832703574010e-01,
-8.2294323973571670e-01,
4.0457879113787724e-02,
5.7971038082866722e-01,
6.9960183636529338e-01,
6.8430976075776473e-02,
-3.4297147449169296e-01,
4.6495478328605205e-02,
6.7701246205863264e-02,
8.4993144140514468e-02,
-5.2228521752048518e-01,
-2.9251488187370783e-01,
-3.9837556749973635e-01,
2.0976964648102694e-01,
7.2314045922878123e-01,
3.6577661388763623e-02,
],
dtype=torch.float64,
),
"gw": reshape_fortran(
torch.tensor(
[
Expand Down
14 changes: 7 additions & 7 deletions test/test_model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import torch
from tad_mctc.batch import pack
from tad_mctc.ncoord import cn_d4
from tad_multicharge.eeq import get_charges # get rid!

from tad_dftd4.model import D4Model
from tad_dftd4.typing import DD
Expand All @@ -44,14 +43,12 @@ def test_single(name: str, dtype: torch.dtype) -> None:
sample = samples[name]
numbers = sample["numbers"].to(DEVICE)
positions = sample["positions"].to(**dd)
q = sample["q"].to(**dd)
ref = sample["c6"].to(**dd)

d4 = D4Model(numbers, **dd)

cn = cn_d4(numbers, positions)
total_charge = torch.tensor(0.0, **dd)
q = get_charges(numbers, positions, total_charge)

gw = d4.weight_references(cn=cn, q=q)
c6 = d4.get_atomic_c6(gw)
assert pytest.approx(ref.cpu(), abs=tol, rel=tol) == c6.cpu()
Expand All @@ -77,6 +74,12 @@ def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None:
sample2["positions"].to(**dd),
]
)
q = pack(
[
sample1["q"].to(**dd),
sample2["q"].to(**dd),
]
)
refs = pack(
[
sample1["c6"].to(**dd),
Expand All @@ -87,9 +90,6 @@ def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None:
d4 = D4Model(numbers, **dd)

cn = cn_d4(numbers, positions)
total_charge = torch.zeros(numbers.shape[0], **dd)
q = get_charges(numbers, positions, total_charge)

gw = d4.weight_references(cn=cn, q=q)
c6 = d4.get_atomic_c6(gw)
assert pytest.approx(refs.cpu(), abs=tol, rel=tol) == c6.cpu()
12 changes: 7 additions & 5 deletions test/test_model/test_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import torch.nn.functional as F
from tad_mctc.batch import pack
from tad_mctc.ncoord import cn_d4
from tad_multicharge.eeq import get_charges

from tad_dftd4.model import D4Model
from tad_dftd4.typing import DD
Expand Down Expand Up @@ -55,7 +54,7 @@ def single(
cn = None # positions.new_zeros(numbers.shape)

if with_q is True:
q = get_charges(numbers, positions, torch.tensor(0.0, **dd))
q = sample["q"].to(**dd)
else:
q = None # positions.new_zeros(numbers.shape)

Expand Down Expand Up @@ -120,13 +119,16 @@ def test_batch(name1: str, name2: str, dtype: torch.dtype) -> None:
sample2["positions"].to(**dd),
]
)
q = pack(
[
sample1["q"].to(**dd),
sample2["q"].to(**dd),
]
)

d4 = D4Model(numbers, **dd)

cn = cn_d4(numbers, positions)
total_charge = positions.new_zeros(numbers.shape[0])
q = get_charges(numbers, positions, total_charge)

gwvec = d4.weight_references(cn, q)

# pad reference tensor to always be of shape `(natoms, 7)`
Expand Down

0 comments on commit e04d9ee

Please sign in to comment.