Skip to content

Commit

Permalink
pt: refactor loss (#3569)
Browse files Browse the repository at this point in the history
This PR updates the loss interface to allow for a more flexible design. 
It enables processing input tensors before feeding them into the model,
such as denoising operations (fyi @Chengqian-Zhang ). Previously, this
was done in the data loader, which was less intuitive and more
confusing. Now, users can easily handle these tasks within the loss
function itself, as demonstrated in similar implementations in uni-mol:
https://github.com/dptech-corp/Uni-Mol/blob/main/unimol/unimol/losses/unimol.py#L20.
  • Loading branch information
iProzd authored Mar 20, 2024
1 parent 9c861c2 commit 71ec631
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 53 deletions.
30 changes: 20 additions & 10 deletions deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,30 @@ def __init__(
self.use_l1_all = use_l1_all
self.inference = inference

def forward(self, model_pred, label, natoms, learning_rate, mae=False):
"""Return loss on loss and force.
def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
"""Return loss on energy and force.
Args:
- natoms: Tell atom count.
- p_energy: Predicted energy of all atoms.
- p_force: Predicted force per atom.
- l_energy: Actual energy of all atoms.
- l_force: Actual force per atom.
Parameters
----------
input_dict : dict[str, torch.Tensor]
Model inputs.
model : torch.nn.Module
Model to be used to output the predictions.
label : dict[str, torch.Tensor]
Labels.
natoms : int
The local atom number.
Returns
-------
- loss: Loss to minimize.
model_pred: dict[str, torch.Tensor]
Model predictions.
loss: torch.Tensor
Loss for model to minimize.
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
model_pred = model(**input_dict)
coef = learning_rate / self.starter_learning_rate
pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
pref_f = self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * coef
Expand Down Expand Up @@ -200,7 +210,7 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
more_loss["mae_v"] = mae_v.detach()
if not self.inference:
more_loss["rmse"] = torch.sqrt(loss.detach())
return loss, more_loss
return model_pred, loss, more_loss

@property
def label_requirement(self) -> List[DataRequirementItem]:
Expand Down
13 changes: 9 additions & 4 deletions deepmd/pt/loss/ener_spin.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,30 @@ def __init__(
self.use_l1_all = use_l1_all
self.inference = inference

def forward(self, model_pred, label, natoms, learning_rate, mae=False):
def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
"""Return energy loss with magnetic labels.
Parameters
----------
model_pred : dict[str, torch.Tensor]
Model predictions.
input_dict : dict[str, torch.Tensor]
Model inputs.
model : torch.nn.Module
Model to be used to output the predictions.
label : dict[str, torch.Tensor]
Labels.
natoms : int
The local atom number.
Returns
-------
model_pred: dict[str, torch.Tensor]
Model predictions.
loss: torch.Tensor
Loss for model to minimize.
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
model_pred = model(**input_dict)
coef = learning_rate / self.starter_learning_rate
pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
pref_fr = self.limit_pref_fr + (self.start_pref_fr - self.limit_pref_fr) * coef
Expand Down Expand Up @@ -175,7 +180,7 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):

if not self.inference:
more_loss["rmse"] = torch.sqrt(loss.detach())
return loss, more_loss
return model_pred, loss, more_loss

@property
def label_requirement(self) -> List[DataRequirementItem]:
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/loss/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, **kwargs):
"""Construct loss."""
super().__init__()

def forward(self, model_pred, label, natoms, learning_rate):
def forward(self, input_dict, model, label, natoms, learning_rate):
"""Return loss ."""
raise NotImplementedError

Expand Down
13 changes: 9 additions & 4 deletions deepmd/pt/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,30 @@ def __init__(
"Can not assian zero weight both to `pref` and `pref_atomic`"
)

def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False):
"""Return loss on local and global tensors.
Parameters
----------
model_pred : dict[str, torch.Tensor]
Model predictions.
input_dict : dict[str, torch.Tensor]
Model inputs.
model : torch.nn.Module
Model to be used to output the predictions.
label : dict[str, torch.Tensor]
Labels.
natoms : int
The local atom number.
Returns
-------
model_pred: dict[str, torch.Tensor]
Model predictions.
loss: torch.Tensor
Loss for model to minimize.
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
model_pred = model(**input_dict)
del learning_rate, mae
loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
Expand Down Expand Up @@ -133,7 +138,7 @@ def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
loss += self.global_weight * l2_global_loss
rmse_global = l2_global_loss.sqrt() / atom_num
more_loss[f"rmse_global_{self.tensor_name}"] = rmse_global.detach()
return loss, more_loss
return model_pred, loss, more_loss

@property
def label_requirement(self) -> List[DataRequirementItem]:
Expand Down
9 changes: 7 additions & 2 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,8 +696,13 @@ def step(_step_id, task_key="Default"):
module = (
self.wrapper.module if dist.is_initialized() else self.wrapper
)
loss, more_loss = module.loss[task_key](
model_pred,

def fake_model():
return model_pred

_, loss, more_loss = module.loss[task_key](
{},
fake_model,
label_dict,
int(input_dict["atype"].shape[-1]),
learning_rate=pref_lr,
Expand Down
19 changes: 12 additions & 7 deletions deepmd/pt/train/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,20 @@ def forward(
has_spin = has_spin()
if has_spin:
input_dict["spin"] = spin
model_pred = self.model[task_key](**input_dict)
natoms = atype.shape[-1]
if not self.inference_only and not inference_only:
loss, more_loss = self.loss[task_key](
model_pred, label, natoms=natoms, learning_rate=cur_lr

if self.inference_only or inference_only:
model_pred = self.model[task_key](**input_dict)
return model_pred, None, None
else:
natoms = atype.shape[-1]
model_pred, loss, more_loss = self.loss[task_key](
input_dict,
self.model[task_key],
label,
natoms=natoms,
learning_rate=cur_lr,
)
return model_pred, loss, more_loss
else:
return model_pred, None, None

def set_extra_state(self, state: Dict):
self.model_params = state["model_params"]
Expand Down
41 changes: 20 additions & 21 deletions source/tests/pt/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,34 +338,33 @@ def test_consistency(self):
batch["natoms"] = torch.tensor(
batch["natoms_vec"], device=batch["coord"].device
).unsqueeze(0)
model_predict = my_model(
batch["coord"].to(env.DEVICE),
batch["atype"].to(env.DEVICE),
batch["box"].to(env.DEVICE),
do_atomic_virial=True,
)
model_predict_1 = my_model(
batch["coord"].to(env.DEVICE),
batch["atype"].to(env.DEVICE),
batch["box"].to(env.DEVICE),
do_atomic_virial=False,
model_input = {
"coord": batch["coord"].to(env.DEVICE),
"atype": batch["atype"].to(env.DEVICE),
"box": batch["box"].to(env.DEVICE),
"do_atomic_virial": True,
}
model_input_1 = {
"coord": batch["coord"].to(env.DEVICE),
"atype": batch["atype"].to(env.DEVICE),
"box": batch["box"].to(env.DEVICE),
"do_atomic_virial": False,
}
label = {
"energy": batch["energy"].to(env.DEVICE),
"force": batch["force"].to(env.DEVICE),
}
cur_lr = my_lr.value(self.wanted_step)
model_predict, loss, _ = my_loss(
model_input, my_model, label, int(batch["natoms"][0, 0]), cur_lr
)
model_predict_1 = my_model(**model_input_1)
p_energy, p_force, p_virial, p_atomic_virial = (
model_predict["energy"],
model_predict["force"],
model_predict["virial"],
model_predict["atom_virial"],
)
cur_lr = my_lr.value(self.wanted_step)
model_pred = {
"energy": p_energy,
"force": p_force,
}
label = {
"energy": batch["energy"].to(env.DEVICE),
"force": batch["force"].to(env.DEVICE),
}
loss, _ = my_loss(model_pred, label, int(batch["natoms"][0, 0]), cur_lr)
np.testing.assert_allclose(
head_dict["energy"], p_energy.view(-1).cpu().detach().numpy()
)
Expand Down
18 changes: 14 additions & 4 deletions source/tests/pt/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,13 @@ def test_consistency(self):
self.start_pref_v,
self.limit_pref_v,
)
my_loss, my_more_loss = mine(
self.model_pred,

def fake_model():
return self.model_pred

_, my_loss, my_more_loss = mine(
{},
fake_model,
self.label,
self.nloc,
self.cur_lr,
Expand Down Expand Up @@ -345,8 +350,13 @@ def test_consistency(self):
self.start_pref_fm,
self.limit_pref_fm,
)
my_loss, my_more_loss = mine(
self.model_pred,

def fake_model():
return self.model_pred

_, my_loss, my_more_loss = mine(
{},
fake_model,
self.label,
self.nloc_tf, # use tf natoms pref
self.cur_lr,
Expand Down

0 comments on commit 71ec631

Please sign in to comment.