diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 51e34e9b08..82184dec02 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -681,7 +681,17 @@ def _pass_filter( tf.shape(inputs_i)[0], self.nei_type_vec, # extra input for atten ) - inputs_i *= mask + if self.smooth: + inputs_i = tf.where( + tf.cast(mask, tf.bool), + inputs_i, + # (nframes * nloc, 1) -> (nframes * nloc, ndescrpt) + tf.tile( + tf.reshape(self.avg_looked_up, [-1, 1]), [1, self.ndescrpt] + ), + ) + else: + inputs_i *= mask if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor: inputs_i = descrpt2r4(inputs_i, atype) layer, qmat = self._filter(