Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt): support use_aparam_as_mask for pt backend #4246

Merged
merged 13 commits into from
Oct 26, 2024
8 changes: 6 additions & 2 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,11 @@ def __init__(
else:
self.aparam_avg, self.aparam_inv_std = None, None
# init networks
in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam
in_dim = (
self.dim_descrpt
+ self.numb_fparam
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
)
self.nets = NetworkCollection(
1 if not self.mixed_types else 0,
self.ntypes,
Expand Down Expand Up @@ -401,7 +405,7 @@ def _call_common(
axis=-1,
)
# check aparam dim, concatenate to input descriptor
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
assert aparam is not None, "aparam should not be None"
if aparam.shape[-1] != self.numb_aparam:
raise ValueError(
Expand Down
4 changes: 0 additions & 4 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,6 @@ def __init__(
raise NotImplementedError("tot_ener_zero is not implemented")
if spin is not None:
raise NotImplementedError("spin is not implemented")
if use_aparam_as_mask:
raise NotImplementedError("use_aparam_as_mask is not implemented")
if use_aparam_as_mask:
raise NotImplementedError("use_aparam_as_mask is not implemented")
if layer_name is not None:
raise NotImplementedError("layer_name is not implemented")

Expand Down
15 changes: 11 additions & 4 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ class GeneralFitting(Fitting):
length as `ntypes` signaling whether or not to remove the vacuum contribution for the atom types in the list.
type_map: list[str], Optional
A list of strings. Give the name to each type of atoms.
use_aparam_as_mask: bool
If True, the aparam is not concatenated to the input of the fitting net and does not count toward its input dimension.
"""

def __init__(
Expand All @@ -147,6 +149,7 @@ def __init__(
trainable: Union[bool, list[bool]] = True,
remove_vaccum_contribution: Optional[list[bool]] = None,
type_map: Optional[list[str]] = None,
use_aparam_as_mask: bool = False,
**kwargs,
):
super().__init__()
Expand All @@ -164,6 +167,7 @@ def __init__(
self.rcond = rcond
self.seed = seed
self.type_map = type_map
self.use_aparam_as_mask = use_aparam_as_mask
# order matters, should be placed after the assignment of ntypes
self.reinit_exclude(exclude_types)
self.trainable = trainable
Expand Down Expand Up @@ -208,7 +212,11 @@ def __init__(
else:
self.aparam_avg, self.aparam_inv_std = None, None

in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam
in_dim = (
self.dim_descrpt
+ self.numb_fparam
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
)

self.filter_layers = NetworkCollection(
1 if not self.mixed_types else 0,
Expand Down Expand Up @@ -293,13 +301,12 @@ def serialize(self) -> dict:
# "trainable": self.trainable ,
# "atom_ener": self.atom_ener ,
# "layer_name": self.layer_name ,
# "use_aparam_as_mask": self.use_aparam_as_mask ,
# "spin": self.spin ,
## NOTICE: not supported so far
"tot_ener_zero": False,
"trainable": [self.trainable] * (len(self.neuron) + 1),
"layer_name": None,
"use_aparam_as_mask": False,
"use_aparam_as_mask": self.use_aparam_as_mask,
"spin": None,
}

Expand Down Expand Up @@ -441,7 +448,7 @@ def _forward_common(
dim=-1,
)
# check aparam dim, concatenate to input descriptor
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
assert aparam is not None, "aparam should not be None"
assert self.aparam_avg is not None
assert self.aparam_inv_std is not None
Expand Down
5 changes: 4 additions & 1 deletion deepmd/pt/model/task/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class InvarFitting(GeneralFitting):
The `set_davg_zero` key in the descriptor should be set.
type_map: list[str], Optional
A list of strings. Give the name to each type of atoms.

use_aparam_as_mask: bool
If True, the aparam is not concatenated to the input of the fitting net and does not count toward its input dimension.
"""

def __init__(
Expand All @@ -99,6 +100,7 @@ def __init__(
exclude_types: list[int] = [],
atom_ener: Optional[list[Optional[torch.Tensor]]] = None,
type_map: Optional[list[str]] = None,
use_aparam_as_mask: bool = False,
**kwargs,
):
self.dim_out = dim_out
Expand All @@ -122,6 +124,7 @@ def __init__(
if atom_ener is None or len([x for x in atom_ener if x is not None]) == 0
else [x is not None for x in atom_ener],
type_map=type_map,
use_aparam_as_mask=use_aparam_as_mask,
**kwargs,
)

Expand Down
32 changes: 21 additions & 11 deletions deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def _build_lower(
ext_fparam = tf.reshape(ext_fparam, [-1, self.numb_fparam])
ext_fparam = tf.cast(ext_fparam, self.fitting_precision)
layer = tf.concat([layer, ext_fparam], axis=1)
if aparam is not None:
if aparam is not None and not self.use_aparam_as_mask:
ext_aparam = tf.slice(
aparam,
[0, start_index * self.numb_aparam],
Expand Down Expand Up @@ -561,7 +561,7 @@ def build(
trainable=False,
initializer=tf.constant_initializer(self.fparam_inv_std),
)
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
t_aparam_avg = tf.get_variable(
"t_aparam_avg",
self.numb_aparam,
Expand All @@ -576,6 +576,13 @@ def build(
trainable=False,
initializer=tf.constant_initializer(self.aparam_inv_std),
)
else:
t_aparam_avg = tf.zeros(
self.numb_aparam, dtype=GLOBAL_TF_FLOAT_PRECISION
)
t_aparam_istd = tf.ones(
self.numb_aparam, dtype=GLOBAL_TF_FLOAT_PRECISION
)

inputs = tf.reshape(inputs, [-1, natoms[0], self.dim_descrpt])
if len(self.atom_ener):
Expand All @@ -602,12 +609,11 @@ def build(
fparam = (fparam - t_fparam_avg) * t_fparam_istd

aparam = None
if not self.use_aparam_as_mask:
if self.numb_aparam > 0:
aparam = input_dict["aparam"]
aparam = tf.reshape(aparam, [-1, self.numb_aparam])
aparam = (aparam - t_aparam_avg) * t_aparam_istd
aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]])
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
aparam = input_dict["aparam"]
aparam = tf.reshape(aparam, [-1, self.numb_aparam])
aparam = (aparam - t_aparam_avg) * t_aparam_istd
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]])

atype_nall = tf.reshape(atype, [-1, natoms[1]])
self.atype_nloc = tf.slice(
Expand Down Expand Up @@ -783,7 +789,7 @@ def init_variables(
self.fparam_inv_std = get_tensor_by_name_from_graph(
graph, f"fitting_attr{suffix}/t_fparam_istd"
)
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
self.aparam_avg = get_tensor_by_name_from_graph(
graph, f"fitting_attr{suffix}/t_aparam_avg"
)
Expand Down Expand Up @@ -883,7 +889,7 @@ def deserialize(cls, data: dict, suffix: str = ""):
if fitting.numb_fparam > 0:
fitting.fparam_avg = data["@variables"]["fparam_avg"]
fitting.fparam_inv_std = data["@variables"]["fparam_inv_std"]
if fitting.numb_aparam > 0:
if fitting.numb_aparam > 0 and not fitting.use_aparam_as_mask:
fitting.aparam_avg = data["@variables"]["aparam_avg"]
fitting.aparam_inv_std = data["@variables"]["aparam_inv_std"]
return fitting
Expand Down Expand Up @@ -922,7 +928,11 @@ def serialize(self, suffix: str = "") -> dict:
"nets": self.serialize_network(
ntypes=self.ntypes,
ndim=0 if self.mixed_types else 1,
in_dim=self.dim_descrpt + self.numb_fparam + self.numb_aparam,
in_dim=(
self.dim_descrpt
+ self.numb_fparam
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
),
neuron=self.n_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.resnet_dt,
Expand Down
8 changes: 7 additions & 1 deletion source/tests/consistent/fitting/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class FittingTest:
"""Useful utilities for fitting tests."""

def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, suffix):
def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, aparam, suffix):
t_inputs = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_inputs")
t_natoms = tf.placeholder(tf.int32, natoms.shape, name="i_natoms")
t_atype = tf.placeholder(tf.int32, [None], name="i_atype")
Expand All @@ -30,6 +30,12 @@ def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, suffix):
)
extras["fparam"] = t_fparam
feed_dict[t_fparam] = fparam
if aparam is not None:
t_aparam = tf.placeholder(
GLOBAL_TF_FLOAT_PRECISION, [None, None], name="i_aparam"
)
extras["aparam"] = t_aparam
feed_dict[t_aparam] = aparam
t_out = obj.build(
t_inputs,
t_natoms,
Expand Down
22 changes: 22 additions & 0 deletions source/tests/consistent/fitting/test_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
("float64", "float32"), # precision
(True, False), # mixed_types
(0, 1), # numb_fparam
(0, 1), # numb_aparam
(10, 20), # numb_dos
)
class TestDOS(CommonTest, FittingTest, unittest.TestCase):
Expand All @@ -68,13 +69,15 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return {
"neuron": [5, 5, 5],
"resnet_dt": resnet_dt,
"precision": precision,
"numb_fparam": numb_fparam,
"numb_aparam": numb_aparam,
"seed": 20240217,
"numb_dos": numb_dos,
}
Expand All @@ -86,6 +89,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable numb_aparam is not used.
numb_dos,
) = self.param
return CommonTest.skip_pt
Expand Down Expand Up @@ -115,6 +119,9 @@
# inconsistent if not sorted
self.atype.sort()
self.fparam = -np.ones((1,), dtype=GLOBAL_NP_FLOAT_PRECISION)
self.aparam = np.zeros_like(
self.atype, dtype=GLOBAL_NP_FLOAT_PRECISION
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
Expand All @@ -123,6 +130,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return {
Expand All @@ -137,6 +145,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return self.build_tf_fitting(
Expand All @@ -145,6 +154,7 @@
self.natoms,
self.atype,
self.fparam if numb_fparam else None,
self.aparam if numb_aparam else None,
suffix,
)

Expand All @@ -154,6 +164,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return (
Expand All @@ -163,6 +174,9 @@
fparam=torch.from_numpy(self.fparam).to(device=PT_DEVICE)
if numb_fparam
else None,
aparam=torch.from_numpy(self.aparam).to(device=PT_DEVICE)
if numb_aparam
else None,
)["dos"]
.detach()
.cpu()
Expand All @@ -175,12 +189,14 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return dp_obj(
self.inputs,
self.atype.reshape(1, -1),
fparam=self.fparam if numb_fparam else None,
aparam=self.aparam if numb_aparam else None,
)["dos"]

def eval_jax(self, jax_obj: Any) -> Any:
Expand All @@ -189,13 +205,15 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return np.asarray(
jax_obj(
jnp.asarray(self.inputs),
jnp.asarray(self.atype.reshape(1, -1)),
fparam=jnp.asarray(self.fparam) if numb_fparam else None,
aparam=jnp.asarray(self.aparam) if numb_aparam else None,
)["dos"]
)

Expand All @@ -206,13 +224,15 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return np.asarray(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None,
aparam=array_api_strict.asarray(self.aparam) if numb_aparam else None,
)["dos"]
)

Expand All @@ -230,6 +250,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
if precision == "float64":
Expand All @@ -247,6 +268,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
if precision == "float64":
Expand Down
Loading
Loading