diff --git a/tapeagents/finetune/rl/__init__.py b/tapeagents/finetune/rl/__init__.py
index 941c18ef..1597d610 100644
--- a/tapeagents/finetune/rl/__init__.py
+++ b/tapeagents/finetune/rl/__init__.py
@@ -136,15 +136,20 @@ def rl_step(model: PreTrainedModel, batch: dict, config: RLConfig) -> tuple[torc
 
             assert approx_kl.shape == masks_.shape
             assert approx_kl.shape == surrogate_loss.shape
-            loss = -masked_sum(surrogate_loss - config.kl_coef * approx_kl, masks_)
+            loss = surrogate_loss - config.kl_coef * approx_kl
+            # loss = -masked_sum(surrogate_loss - config.kl_coef * approx_kl, masks_)
         case "reinforce":
             surr1 = torch.zeros_like(ratio_new_old)
             surr2 = torch.zeros_like(ratio_new_old)
-            loss = -masked_sum(new_log_probs * log_p_weights - config.kl_coef * approx_kl, masks_)
+            loss = new_log_probs * log_p_weights - config.kl_coef * approx_kl
+            # loss = -masked_sum(new_log_probs * log_p_weights - config.kl_coef * approx_kl, masks_)
         case _:
             raise ValueError(f"Unknown algorithm {config.algo}")
 
+    num_nans = torch.isnan(loss).sum()
+    loss = -masked_sum(loss, masks_)
     assert torch.isfinite(loss).all(), f"Loss is not finite: {loss}"
+    # normalize the loss by the micro batch size
     loss = loss / masks.shape[0]
 
     stats = {
@@ -181,6 +186,7 @@ def rl_step(model: PreTrainedModel, batch: dict, config: RLConfig) -> tuple[torc
         "ratio_ref_new": masked_mean(torch.exp(log_ratio_ref_new), masks_).item(),
         "ratio_ref_old": masked_mean(torch.exp(ref_logprobs - old_logprobs), masks_).item(),
         "clamp_log_ratio_ref_new_indicators": masked_mean(clamp_log_ratio_ref_new_indicators, masks_).item(),
+        "num_nans": num_nans.item(),
     }
     return loss, stats
 
diff --git a/tapeagents/finetune/rl/utils.py b/tapeagents/finetune/rl/utils.py
index f2da96ef..fae36519 100644
--- a/tapeagents/finetune/rl/utils.py
+++ b/tapeagents/finetune/rl/utils.py
@@ -33,17 +33,17 @@ def get_avg_rl_stats(rl_stats):
 def masked_sum(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
     """Compute sum of tensor with a masked values."""
     if axis is not None:
-        return (values * mask).sum(axis=axis)  # type: ignore
+        return (values * mask).nan_to_num(0).sum(axis=axis)  # type: ignore
     else:
-        return (values * mask).sum()
+        return (values * mask).nan_to_num(0).sum()
 
 
 def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
     """Compute mean of tensor with a masked values."""
     if axis is not None:
-        return (values * mask).sum(axis=axis) / mask.sum(axis=axis)  # type: ignore
+        return (values * mask).nan_to_num(0).sum(axis=axis) / mask.sum(axis=axis)  # type: ignore
     else:
-        return (values * mask).sum() / mask.sum()
+        return (values * mask).nan_to_num(0).sum() / mask.sum()
 
 
 def calculate_rewards_with_implicit_kl(row, reward_minus_kl_coef):
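
Illustration of the behavioral change (a minimal standalone sketch, not code from the repository; the toy tensors and values below are invented): with the patched masked_sum, NaNs in the per-token loss are zeroed before the reduction, so the final scalar loss stays finite instead of propagating NaN, and rl_step additionally counts the NaNs and reports them as the "num_nans" stat.

import torch


def masked_sum(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Same idea as the patched utils.masked_sum: zero out NaNs before summing
    # over the masked-in positions.
    return (values * mask).nan_to_num(0).sum()


# Toy batch of 2 sequences x 4 tokens; one per-token loss entry is NaN
# (e.g. produced by an extreme log-probability ratio).
per_token_loss = torch.tensor([
    [0.5, -0.2, float("nan"), 0.1],
    [0.3, 0.4, 0.2, 0.0],
])
masks = torch.tensor([
    [1.0, 1.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
])

num_nans = torch.isnan(per_token_loss).sum()  # reported as the "num_nans" stat
loss = -masked_sum(per_token_loss, masks)     # the NaN is zeroed, so the sum stays finite
loss = loss / masks.shape[0]                  # normalize by the micro batch size (2 here)
print(num_nans.item(), loss.item())           # 1, approximately -0.5

With the old masked_sum, the single NaN would have made the whole loss NaN and tripped the torch.isfinite assertion; the new code keeps the step usable and surfaces how many NaN entries were silently dropped.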