Skip to content

Commit

Permalink
merge compute_output_stat (#3310)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Feb 28, 2024
1 parent 16cd26c commit 3ad57da
Show file tree
Hide file tree
Showing 14 changed files with 312 additions and 89 deletions.
8 changes: 0 additions & 8 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,6 @@ def output_def(self) -> FittingOutputDef:
]
)

@property
def data_stat_key(self):
"""
Get the keys for the data statistic of the fitting.
Return a list of statistic names needed, such as "bias_atom_e".
"""
return []

def forward(
self,
descriptor: torch.Tensor,
Expand Down
34 changes: 22 additions & 12 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
)
from deepmd.pt.utils.stat import (
compute_output_bias,
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.out_stat import (
compute_stats_from_redu,
)
from deepmd.utils.path import (
DPPath,
Expand Down Expand Up @@ -135,16 +138,8 @@ def serialize(self) -> dict:
data["atom_ener"] = self.atom_ener
return data

@property
def data_stat_key(self):
"""
Get the keys for the data statistic of the fitting.
Return a list of statistic names needed, such as "bias_atom_e".
"""
return ["bias_atom_e"]

def compute_output_stats(self, merged, stat_file_path: Optional[DPPath] = None):
energy = [item["energy"] for item in merged]
energy = [item[self.var_name] for item in merged]
data_mixed_type = "real_natoms_vec" in merged[0]
if data_mixed_type:
input_natoms = [item["real_natoms_vec"] for item in merged]
Expand All @@ -155,7 +150,22 @@ def compute_output_stats(self, merged, stat_file_path: Optional[DPPath] = None):
if stat_file_path is not None and stat_file_path.is_file():
bias_atom_e = stat_file_path.load_numpy()
else:
bias_atom_e = compute_output_bias(energy, input_natoms, rcond=self.rcond)
# shape: (nframes, ndim)
merged_energy = to_numpy_array(torch.cat(energy))
# shape: (nframes, ntypes)
merged_natoms = to_numpy_array(torch.cat(input_natoms)[:, 2:])
if self.atom_ener is not None and len(self.atom_ener) > 0:
assigned_atom_ener = np.array(
[ee if ee is not None else np.nan for ee in self.atom_ener]
)
else:
assigned_atom_ener = None
bias_atom_e, _ = compute_stats_from_redu(
merged_energy,
merged_natoms,
assigned_bias=assigned_atom_ener,
rcond=self.rcond,
)
if stat_file_path is not None:
stat_file_path.save_numpy(bias_atom_e)
assert all(x is not None for x in [bias_atom_e])
Expand Down
8 changes: 0 additions & 8 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,6 @@ def share_params(self, base_class, shared_level, resume=False):
else:
raise NotImplementedError

@property
def data_stat_key(self):
"""
Get the keys for the data statistic of the fitting.
Return a list of statistic names needed, such as "bias_atom_e".
"""
raise NotImplementedError("data_stat_key is not implemented!")

def change_energy_bias(
self, config, model, old_type_map, new_type_map, bias_shift="delta", ntest=10
):
Expand Down
8 changes: 0 additions & 8 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,6 @@ def output_def(self) -> FittingOutputDef:
]
)

@property
def data_stat_key(self):
"""
Get the keys for the data statistic of the fitting.
Return a list of statistic names needed, such as "bias_atom_e".
"""
return []

def forward(
self,
descriptor: torch.Tensor,
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def iter(
Parameters
----------
data : List[Dict[str, torch.Tensor]]
The environment matrix.
The data.
Yields
------
Expand Down
21 changes: 0 additions & 21 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging

import numpy as np
import torch

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -57,23 +56,3 @@ def make_stat_input(datasets, dataloaders, nbatches):
sys_stat[key] = sys_stat_list
lst.append(sys_stat)
return lst


def compute_output_bias(energy, natoms, rcond=None):
    """Compute the per-type output bias for the fitting net by least squares.

    Every entry of ``energy`` and ``natoms`` is averaged over its leading
    axis, the averages are stacked into one row per entry, and the linear
    system ``natoms_per_type @ coef = energy`` is solved with
    ``np.linalg.lstsq``.

    The input lists are left untouched: the averages are stored in new
    lists instead of being written back into the caller's lists.

    Parameters
    ----------
    energy
        List of batched energy tensors, one per system. Presumably each has
        shape [nframes, 1] -- TODO confirm against the caller.
    natoms
        List of batched atom-count tensors, one per system. The first two
        columns are dropped and the remaining columns are used as the
        per-type atom counts (presumably ntypes + 2 columns in total).
    rcond
        Cut-off ratio for small singular values, forwarded to
        ``np.linalg.lstsq``.

    Returns
    -------
    energy_coef
        The least-squares solution, i.e. the average energy per atom for
        each atom type.
    """
    # Average over the leading axis of every entry; keepdim keeps one row
    # per entry so the results can be concatenated below.
    mean_energy = [ee.mean(dim=0, keepdim=True) for ee in energy]
    # Atom counts are cast to double before averaging.
    mean_natoms = [na.double().mean(dim=0, keepdim=True) for na in natoms]
    sys_ener = torch.cat(mean_energy).cpu()
    # Drop the first two columns; only the per-type counts enter the fit.
    sys_tynatom = torch.cat(mean_natoms)[:, 2:].cpu()
    energy_coef, _, _, _ = np.linalg.lstsq(sys_tynatom, sys_ener, rcond)
    return energy_coef
9 changes: 7 additions & 2 deletions deepmd/tf/fit/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
from deepmd.tf.utils.network import (
one_layer_rand_seed_shift,
)
from deepmd.utils.out_stat import (
compute_stats_from_redu,
)

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -225,8 +228,10 @@ def _compute_output_stats(self, all_stat, rcond=1e-3, mixed_type=False):
sys_tynatom = np.reshape(sys_tynatom, [nsys, -1])
sys_tynatom = sys_tynatom[:, 2:]

dos_shift, resd, rank, s_value = np.linalg.lstsq(
sys_tynatom, sys_dos, rcond=rcond
dos_shift, _ = compute_stats_from_redu(
sys_dos,
sys_tynatom,
rcond=rcond,
)

return dos_shift
Expand Down
25 changes: 12 additions & 13 deletions deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@
from deepmd.tf.utils.spin import (
Spin,
)
from deepmd.utils.out_stat import (
compute_stats_from_redu,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -295,21 +298,17 @@ def _compute_output_stats(self, all_stat, rcond=1e-3, mixed_type=False):
# In this situation, we directly use these assigned energies instead of computing stats.
# This will make the loss decrease quickly
assigned_atom_ener = np.array(
[ee for ee in self.atom_ener_v if ee is not None]
[ee if ee is not None else np.nan for ee in self.atom_ener_v]
)
assigned_ener_idx = [
ii for ii, ee in enumerate(self.atom_ener_v) if ee is not None
]
# np.dot out size: nframe
sys_ener -= np.dot(sys_tynatom[:, assigned_ener_idx], assigned_atom_ener)
sys_tynatom[:, assigned_ener_idx] = 0.0
energy_shift, resd, rank, s_value = np.linalg.lstsq(
sys_tynatom, sys_ener, rcond=rcond
else:
assigned_atom_ener = None
energy_shift, _ = compute_stats_from_redu(
sys_ener.reshape(-1, 1),
sys_tynatom,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)
if len(self.atom_ener) > 0:
for ii in assigned_ener_idx:
energy_shift[ii] = self.atom_ener_v[ii]
return energy_shift
return energy_shift.ravel()

def compute_input_stats(self, all_stat: dict, protection: float = 1e-2) -> None:
"""Compute the input statistics.
Expand Down
6 changes: 2 additions & 4 deletions deepmd/tf/fit/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,14 @@ def get_out_size(self) -> int:
"""Get the output size. Should be 9."""
return 9

def compute_input_stats(self, all_stat, protection=1e-2):
"""Compute the input statistics.
def compute_output_stats(self, all_stat):
"""Compute the output statistics.
Parameters
----------
all_stat
Dictionary of inputs.
can be prepared by model.make_stat_input
protection
Divided-by-zero protection
"""
if "polarizability" not in all_stat.keys():
self.avgeig = np.zeros([9])
Expand Down
11 changes: 8 additions & 3 deletions deepmd/utils/data_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from deepmd.utils.data import (
DeepmdData,
)
from deepmd.utils.out_stat import (
compute_stats_from_redu,
)

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -248,10 +251,12 @@ def compute_energy_shift(self, rcond=None, key="energy"):
sys_tynatom = np.array(self.natoms_vec, dtype=GLOBAL_NP_FLOAT_PRECISION)
sys_tynatom = np.reshape(sys_tynatom, [self.nsystems, -1])
sys_tynatom = sys_tynatom[:, 2:]
energy_shift, resd, rank, s_value = np.linalg.lstsq(
sys_tynatom, sys_ener, rcond=rcond
energy_shift, _ = compute_stats_from_redu(
sys_ener.reshape(-1, 1),
sys_tynatom,
rcond=rcond,
)
return energy_shift
return energy_shift.ravel()

def add_dict(self, adict: dict) -> None:
"""Add items to the data system by a `dict`.
Expand Down
117 changes: 117 additions & 0 deletions deepmd/utils/out_stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Output statistics."""
from typing import (
Optional,
Tuple,
)

import numpy as np


def compute_stats_from_redu(
    output_redu: np.ndarray,
    natoms: np.ndarray,
    assigned_bias: Optional[np.ndarray] = None,
    rcond: Optional[float] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the output statistics.

    Given the reduced output value and the number of atoms of each type,
    compute the least-squares solution as the atomic output bias, and the
    standard deviation of the fit residual as the output std.

    Parameters
    ----------
    output_redu
        The reduced output value, shape is [nframes, ndim].
    natoms
        The number of atoms of each type, shape is [nframes, ntypes].
    assigned_bias
        The assigned output bias, shape is [ntypes, ndim]. Set to nan
        if not assigned. A type is treated as assigned only when none of
        its ndim components is nan.
    rcond
        Cut-off ratio for small singular values, forwarded to
        ``np.linalg.lstsq``.

    Returns
    -------
    np.ndarray
        The computed output bias, shape is [ntypes, ndim].
    np.ndarray
        The std of the residual over frames, shape is [ndim].
    """
    # np.array copies its input, so the in-place edits below never touch
    # the arrays owned by the caller.
    output_redu = np.array(output_redu)
    natoms = np.array(natoms)
    # check shape
    assert output_redu.ndim == 2, "output_redu must be [nframes, ndim]"
    assert natoms.ndim == 2, "natoms must be [nframes, ntypes]"
    assert (
        output_redu.shape[0] == natoms.shape[0]
    ), "output_redu and natoms must have the same number of frames"
    if not np.issubdtype(output_redu.dtype, np.inexact):
        # Integer (or other non-floating) outputs cannot absorb the
        # floating-point in-place subtraction of the assigned bias below.
        output_redu = output_redu.astype(np.float64)

    assigned_bias_atom_mask = None
    assigned_bias_masked = None
    if assigned_bias is not None:
        assigned_bias = np.array(assigned_bias).reshape(
            natoms.shape[1], output_redu.shape[1]
        )
        # Atomic energies stats are incorrect if atomic energies are assigned.
        # In this situation, we directly use these assigned energies instead
        # of computing stats. This will make the loss decrease quickly.
        assigned_bias_atom_mask = ~np.isnan(assigned_bias).any(axis=1)
        # assigned_bias_masked: nmask, ndim
        assigned_bias_masked = assigned_bias[assigned_bias_atom_mask]
        # assigned_bias_natoms: nframes, nmask
        assigned_bias_natoms = natoms[:, assigned_bias_atom_mask]
        # output_redu: nframes, ndim; remove the assigned contribution
        output_redu -= np.einsum(
            "ij,jk->ik", assigned_bias_natoms, assigned_bias_masked
        )
        # remove assigned atom types from the fit
        natoms[:, assigned_bias_atom_mask] = 0

    # computed_output_bias: ntypes, ndim
    computed_output_bias, _, _, _ = np.linalg.lstsq(natoms, output_redu, rcond=rcond)
    if assigned_bias_atom_mask is not None:
        # add back assigned atom; their columns in natoms are zero, so this
        # does not affect the residual computed below
        computed_output_bias[assigned_bias_atom_mask] = assigned_bias_masked
    # rest_redu: nframes, ndim
    rest_redu = output_redu - np.einsum("ij,jk->ik", natoms, computed_output_bias)
    output_std = rest_redu.std(axis=0)
    return computed_output_bias, output_std


def compute_stats_from_atomic(
    output: np.ndarray,
    atype: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the output statistics from per-atom outputs.

    For every atom type, gather the outputs of all atoms of that type
    across all frames and take their mean (bias) and standard deviation.

    Parameters
    ----------
    output
        The output value, shape is [nframes, nloc, ndim].
    atype
        The type of atoms, shape is [nframes, nloc].

    Returns
    -------
    np.ndarray
        The computed output bias, shape is [ntypes, ndim].
    np.ndarray
        The computed output std, shape is [ntypes, ndim].
    """
    values = np.array(output)
    types = np.array(atype)
    # check shape
    assert values.ndim == 3
    assert types.ndim == 2
    assert values.shape[:2] == types.shape
    n_frames, n_loc, n_dim = values.shape
    # the number of types is inferred from the largest type index present
    n_types = types.max() + 1
    # flatten frames and atoms into one axis; row order matches the
    # boolean-mask selection on the unflattened arrays
    flat_values = values.reshape(n_frames * n_loc, n_dim)
    flat_types = types.reshape(n_frames * n_loc)
    bias = np.zeros((n_types, n_dim))
    std = np.zeros((n_types, n_dim))
    for idx in range(n_types):
        selected = flat_values[flat_types == idx]
        bias[idx] = selected.mean(axis=0)
        std[idx] = selected.std(axis=0)
    return bias, std
Loading

0 comments on commit 3ad57da

Please sign in to comment.