pt: refactor loss #3569

Merged · 4 commits · Mar 20, 2024

30 changes: 20 additions & 10 deletions deepmd/pt/loss/ener.py
@@ -90,20 +90,30 @@
self.use_l1_all = use_l1_all
self.inference = inference

def forward(self, model_pred, label, natoms, learning_rate, mae=False):
"""Return loss on loss and force.
def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):

"""Return loss on energy and force.

Args:
- natoms: Tell atom count.
- p_energy: Predicted energy of all atoms.
- p_force: Predicted force per atom.
- l_energy: Actual energy of all atoms.
- l_force: Actual force per atom.
Parameters
----------
input_dict : dict[str, torch.Tensor]
Model inputs.
model : torch.nn.Module
Model to be used to output the predictions.
label : dict[str, torch.Tensor]
Labels.
natoms : int
The local atom number.

Returns
-------
- loss: Loss to minimize.
model_pred: dict[str, torch.Tensor]
Model predictions.
loss: torch.Tensor
Loss for model to minimize.
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
model_pred = model(**input_dict)

coef = learning_rate / self.starter_learning_rate
pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
pref_f = self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * coef
@@ -200,7 +210,7 @@
more_loss["mae_v"] = mae_v.detach()
if not self.inference:
more_loss["rmse"] = torch.sqrt(loss.detach())
return loss, more_loss
return model_pred, loss, more_loss


@property
def label_requirement(self) -> List[DataRequirementItem]:
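For readers skimming the hunk above: the `pref_e`/`pref_f` lines interpolate the energy and force weights between their start and limit values using the ratio of the current learning rate to the starting one. A small standalone sketch of that schedule (the function name and printed values are illustrative only, not part of this diff):

```python
def loss_prefactor(start_pref: float, limit_pref: float,
                   learning_rate: float, starter_learning_rate: float) -> float:
    """Interpolate a loss prefactor between its start and limit values.

    Mirrors the schedule in the energy loss: with coef = lr / lr0, the weight
    equals start_pref at the start of training (coef == 1) and approaches
    limit_pref as the learning rate decays toward zero.
    """
    coef = learning_rate / starter_learning_rate
    return limit_pref + (start_pref - limit_pref) * coef


print(loss_prefactor(0.02, 1.0, 1e-3, 1e-3))  # 0.02 at the start of training
print(loss_prefactor(0.02, 1.0, 1e-8, 1e-3))  # ~1.0 once the learning rate has decayed
```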
13 changes: 9 additions & 4 deletions deepmd/pt/loss/ener_spin.py
@@ -63,25 +63,30 @@
self.use_l1_all = use_l1_all
self.inference = inference

def forward(self, model_pred, label, natoms, learning_rate, mae=False):
def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):

"""Return energy loss with magnetic labels.

Parameters
----------
model_pred : dict[str, torch.Tensor]
Model predictions.
input_dict : dict[str, torch.Tensor]
Model inputs.
model : torch.nn.Module
Model to be used to output the predictions.
label : dict[str, torch.Tensor]
Labels.
natoms : int
The local atom number.

Returns
-------
model_pred: dict[str, torch.Tensor]
Model predictions.
loss: torch.Tensor
Loss for model to minimize.
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
model_pred = model(**input_dict)

coef = learning_rate / self.starter_learning_rate
pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
pref_fr = self.limit_pref_fr + (self.start_pref_fr - self.limit_pref_fr) * coef
@@ -175,7 +180,7 @@

if not self.inference:
more_loss["rmse"] = torch.sqrt(loss.detach())
return loss, more_loss
return model_pred, loss, more_loss


@property
def label_requirement(self) -> List[DataRequirementItem]:
2 changes: 1 addition & 1 deletion deepmd/pt/loss/loss.py
@@ -19,7 +19,7 @@
"""Construct loss."""
super().__init__()

def forward(self, model_pred, label, natoms, learning_rate):
def forward(self, input_dict, model, label, natoms, learning_rate):

"""Return loss ."""
raise NotImplementedError

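To make the new abstract signature concrete, the sketch below shows roughly what a loss implementing this contract looks like after the refactor: it receives the raw model inputs plus the model, calls `model(**input_dict)` itself, and returns the predictions together with the scalar loss and a dict of extra terms. Everything here (class names, label keys, the toy model) is an illustrative assumption, not code from the PR.

```python
import torch


class ToyEnergyLoss(torch.nn.Module):
    """Minimal illustration of the refactored loss contract (not part of the PR)."""

    def forward(self, input_dict, model, label, natoms, learning_rate):
        del learning_rate  # unused in this toy; the real losses use it to schedule prefactors
        model_pred = model(**input_dict)      # the loss now runs the forward pass itself
        diff = model_pred["energy"] - label["energy"]
        loss = (diff * diff).mean() / natoms  # toy per-atom MSE, purely illustrative
        more_loss = {"rmse": loss.detach().sqrt()}
        return model_pred, loss, more_loss


class ToyModel(torch.nn.Module):
    """Toy stand-in for a DeePMD model: energy is a scaled sum of coordinates."""

    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(1))

    def forward(self, coord, atype):
        return {"energy": self.scale * coord.sum(dim=(-1, -2))}


model = ToyModel()
input_dict = {"coord": torch.randn(2, 4, 3), "atype": torch.zeros(2, 4, dtype=torch.long)}
label = {"energy": torch.randn(2)}
pred, loss, more = ToyEnergyLoss()(input_dict, model, label, natoms=4, learning_rate=1e-3)
loss.backward()  # gradients flow through the predictions returned by the loss
```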
13 changes: 9 additions & 4 deletions deepmd/pt/loss/tensor.py
@@ -63,25 +63,30 @@
"Can not assian zero weight both to `pref` and `pref_atomic`"
)

def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False):

"""Return loss on local and global tensors.

Parameters
----------
model_pred : dict[str, torch.Tensor]
Model predictions.
input_dict : dict[str, torch.Tensor]
Model inputs.
model : torch.nn.Module
Model to be used to output the predictions.
label : dict[str, torch.Tensor]
Labels.
natoms : int
The local atom number.

Returns
-------
model_pred: dict[str, torch.Tensor]
Model predictions.
loss: torch.Tensor
Loss for model to minimize.
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
model_pred = model(**input_dict)

del learning_rate, mae
loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
@@ -133,7 +138,7 @@
loss += self.global_weight * l2_global_loss
rmse_global = l2_global_loss.sqrt() / atom_num
more_loss[f"rmse_global_{self.tensor_name}"] = rmse_global.detach()
return loss, more_loss
return model_pred, loss, more_loss


@property
def label_requirement(self) -> List[DataRequirementItem]:
9 changes: 7 additions & 2 deletions deepmd/pt/train/training.py
@@ -696,8 +696,13 @@
module = (
self.wrapper.module if dist.is_initialized() else self.wrapper
)
loss, more_loss = module.loss[task_key](
model_pred,

def fake_model():
return model_pred


_, loss, more_loss = module.loss[task_key](

{},
fake_model,
label_dict,
int(input_dict["atype"].shape[-1]),
learning_rate=pref_lr,
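The `fake_model` closure above is worth spelling out: because the refactored losses call `model(**input_dict)`, passing an empty `input_dict` together with a zero-argument callable lets predictions that were already computed flow through the new interface unchanged. A self-contained toy version of the idiom (the `toy_loss` function and tensor values are stand-ins, not the real `EnergyStdLoss`):

```python
import torch


def toy_loss(input_dict, model, label, natoms, learning_rate):
    """Stand-in for a refactored loss: it calls model(**input_dict) itself."""
    del natoms, learning_rate  # not needed for this toy
    model_pred = model(**input_dict)
    loss = ((model_pred["energy"] - label["energy"]) ** 2).mean()
    return model_pred, loss, {}


# Predictions computed earlier (values are illustrative).
cached_pred = {"energy": torch.tensor([1.0, 2.0])}
label = {"energy": torch.tensor([1.1, 1.9])}


def fake_model():
    # With an empty input_dict, model(**input_dict) reduces to fake_model(),
    # so the cached predictions pass through the new interface unchanged.
    return cached_pred


_, loss, _ = toy_loss({}, fake_model, label, natoms=2, learning_rate=1e-3)
print(loss)  # ~tensor(0.0100)
```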
19 changes: 12 additions & 7 deletions deepmd/pt/train/wrapper.py
@@ -168,15 +168,20 @@
has_spin = has_spin()
if has_spin:
input_dict["spin"] = spin
model_pred = self.model[task_key](**input_dict)
natoms = atype.shape[-1]
if not self.inference_only and not inference_only:
loss, more_loss = self.loss[task_key](
model_pred, label, natoms=natoms, learning_rate=cur_lr

if self.inference_only or inference_only:
model_pred = self.model[task_key](**input_dict)
return model_pred, None, None

else:
natoms = atype.shape[-1]
model_pred, loss, more_loss = self.loss[task_key](

input_dict,
self.model[task_key],
label,
natoms=natoms,
learning_rate=cur_lr,
)
return model_pred, loss, more_loss
else:
return model_pred, None, None

def set_extra_state(self, state: Dict):
self.model_params = state["model_params"]
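The net effect in the wrapper is that the model forward pass moves into the loss during training, while pure inference still calls the model directly and returns no loss terms. A hedged sketch of that control flow with illustrative names (the real wrapper additionally selects the model and loss by `task_key`):

```python
def wrapper_forward(model, loss_fn, input_dict, label, natoms, cur_lr, inference_only):
    """Illustrative control flow only, not the actual ModelWrapper code."""
    if inference_only:
        # Inference path: evaluate the model directly; no loss is computed.
        return model(**input_dict), None, None
    # Training path: the loss runs the model itself and returns the
    # predictions together with the loss and the extra logging terms.
    return loss_fn(input_dict, model, label, natoms=natoms, learning_rate=cur_lr)
```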
41 changes: 20 additions & 21 deletions source/tests/pt/model/test_model.py
@@ -338,34 +338,33 @@ def test_consistency(self):
batch["natoms"] = torch.tensor(
batch["natoms_vec"], device=batch["coord"].device
).unsqueeze(0)
model_predict = my_model(
batch["coord"].to(env.DEVICE),
batch["atype"].to(env.DEVICE),
batch["box"].to(env.DEVICE),
do_atomic_virial=True,
)
model_predict_1 = my_model(
batch["coord"].to(env.DEVICE),
batch["atype"].to(env.DEVICE),
batch["box"].to(env.DEVICE),
do_atomic_virial=False,
model_input = {
"coord": batch["coord"].to(env.DEVICE),
"atype": batch["atype"].to(env.DEVICE),
"box": batch["box"].to(env.DEVICE),
"do_atomic_virial": True,
}
model_input_1 = {
"coord": batch["coord"].to(env.DEVICE),
"atype": batch["atype"].to(env.DEVICE),
"box": batch["box"].to(env.DEVICE),
"do_atomic_virial": False,
}
label = {
"energy": batch["energy"].to(env.DEVICE),
"force": batch["force"].to(env.DEVICE),
}
cur_lr = my_lr.value(self.wanted_step)
model_predict, loss, _ = my_loss(
model_input, my_model, label, int(batch["natoms"][0, 0]), cur_lr
)
model_predict_1 = my_model(**model_input_1)
p_energy, p_force, p_virial, p_atomic_virial = (
model_predict["energy"],
model_predict["force"],
model_predict["virial"],
model_predict["atom_virial"],
)
cur_lr = my_lr.value(self.wanted_step)
model_pred = {
"energy": p_energy,
"force": p_force,
}
label = {
"energy": batch["energy"].to(env.DEVICE),
"force": batch["force"].to(env.DEVICE),
}
loss, _ = my_loss(model_pred, label, int(batch["natoms"][0, 0]), cur_lr)
np.testing.assert_allclose(
head_dict["energy"], p_energy.view(-1).cpu().detach().numpy()
)
18 changes: 14 additions & 4 deletions source/tests/pt/test_loss.py
@@ -171,8 +171,13 @@ def test_consistency(self):
self.start_pref_v,
self.limit_pref_v,
)
my_loss, my_more_loss = mine(
self.model_pred,

def fake_model():
return self.model_pred

_, my_loss, my_more_loss = mine(
{},
fake_model,
self.label,
self.nloc,
self.cur_lr,
@@ -345,8 +350,13 @@ def test_consistency(self):
self.start_pref_fm,
self.limit_pref_fm,
)
my_loss, my_more_loss = mine(
self.model_pred,

def fake_model():
return self.model_pred

_, my_loss, my_more_loss = mine(
{},
fake_model,
self.label,
self.nloc_tf, # use tf natoms pref
self.cur_lr,
Expand Down