diff --git a/references/recognition/train_pytorch_ddp.py b/references/recognition/train_pytorch_ddp.py
index fd1906dcb4..ec18ce0298 100644
--- a/references/recognition/train_pytorch_ddp.py
+++ b/references/recognition/train_pytorch_ddp.py
@@ -312,7 +312,7 @@ def main(rank: int, world_size: int, args):
                 # random parameters and gradients are synchronized in backward passes.
                 # Therefore, saving it in one process is sufficient.
                 print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
-                torch.save(model.state_dict(), f"./{exp_name}.pt")
+                torch.save(model.module.state_dict(), f"./{exp_name}.pt")
                 min_loss = val_loss
             mb.write(
                 f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} "
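
For context (not part of the diff): `DistributedDataParallel` wraps the underlying model, so every key in the wrapper's `state_dict()` carries a `module.` prefix. A checkpoint saved from the wrapper therefore cannot be loaded into the plain model without remapping keys, which is why the fix saves `model.module.state_dict()` instead. Below is a minimal runnable sketch of that behavior, assuming a toy `nn.Linear` in place of the recognition model and a single-process gloo group:

```python
# Minimal sketch (not part of the diff): why DDP checkpoints should be saved
# from model.module. Uses a toy nn.Linear standing in for the actual model.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process process group, just enough to construct a DDP wrapper.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

net = torch.nn.Linear(4, 2)  # stands in for the recognition model
ddp_net = DDP(net)

# The wrapper prefixes every parameter key with "module.":
print(list(ddp_net.state_dict())[0])         # -> "module.weight"
print(list(ddp_net.module.state_dict())[0])  # -> "weight"

# Saving the inner module keeps the checkpoint loadable without DDP:
torch.save(ddp_net.module.state_dict(), "ckpt.pt")
torch.nn.Linear(4, 2).load_state_dict(torch.load("ckpt.pt"))

dist.destroy_process_group()
```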