Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Checkpoint every 10 #86

Merged
merged 7 commits into from
Jul 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions aviary/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

import os
from datetime import datetime
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
Expand Down Expand Up @@ -56,6 +56,7 @@ def train_model(
train_loader: DataLoader | InMemoryDataLoader,
test_loader: DataLoader | InMemoryDataLoader,
checkpoint: Literal["local", "wandb"] | None = None,
checkpoint_frequency: int = 10,
learning_rate: float = 1e-4,
model_params: dict[str, Any] | None = None,
run_params: dict[str, Any] | None = None,
Expand Down Expand Up @@ -90,6 +91,7 @@ def train_model(
checkpoint = wandb.restore("checkpoint.pth", run_path)
torch.load(checkpoint.name)
```
checkpoint_frequency (int): How often to save a checkpoint. Defaults to 10.
learning_rate (float): The optimizer's learning rate. Defaults to 1e-4.
model_params (dict): Arguments passed to model class. E.g. dict(n_attn_layers=6,
embedding_aggregation=("mean", "std")) for Wrenformer.
Expand Down Expand Up @@ -227,7 +229,7 @@ def train_model(
**wandb_kwargs or {},
)

for epoch in tqdm(range(epochs), disable=None, desc="Training epoch"):
for epoch in tqdm(range(1, epochs + 1), disable=None, desc="Training epoch"):
train_metrics = model.evaluate(
train_loader,
loss_dict,
Expand Down Expand Up @@ -265,6 +267,25 @@ def train_model(
if wandb_path:
wandb.log({"training": train_metrics, "validation": val_metrics})

if epoch % checkpoint_frequency == 0 and epoch < epochs:
inference_model = swa_model if swa_start else model
inference_model.eval()
checkpoint_model(
checkpoint_endpoint=checkpoint,
model_params=model_params,
inference_model=inference_model,
optimizer_instance=optimizer_instance,
lr_scheduler=lr_scheduler,
loss_dict=loss_dict,
epochs=epoch,
test_metrics=val_metrics,
timestamp=timestamp,
run_name=run_name,
normalizer_dict=normalizer_dict,
run_params=run_params,
scheduler_name=scheduler_name,
)

# get test set predictions
if swa_start is not None:
n_swa_epochs = int((1 - swa_start) * epochs)
Expand Down Expand Up @@ -327,7 +348,7 @@ def train_model(
loss_dict=loss_dict,
epochs=epochs,
test_metrics=test_metrics,
timestamp=timestamp or datetime.now().astimezone().strftime("%Y%m%d-%H%M%S"),
timestamp=timestamp,
run_name=run_name,
normalizer_dict=normalizer_dict,
run_params=run_params,
Expand Down Expand Up @@ -365,21 +386,24 @@ def train_model(


def checkpoint_model(
checkpoint_endpoint: str,
checkpoint_endpoint: Literal["local", "wandb"] | None,
model_params: dict | None,
inference_model: nn.Module,
optimizer_instance: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
loss_dict: dict,
epochs: int,
test_metrics: dict,
timestamp: str,
timestamp: str | None,
run_name: str,
normalizer_dict: dict,
run_params: dict,
scheduler_name: str,
):
"""Save model checkpoint to different endpoints."""
if checkpoint_endpoint is None:
return

if model_params is None:
raise ValueError("Must provide model_params to save checkpoint, got None")

Expand All @@ -393,22 +417,29 @@ def checkpoint_model(
metrics=test_metrics,
run_name=run_name,
normalizer_dict=normalizer_dict,
run_params=run_params.copy(),
run_params=deepcopy(run_params),
)
if scheduler_name == "LambdaLR":
# exclude lr_lambda from pickled checkpoint since it causes errors when
# torch.load()-ing a checkpoint and the file defining lr_lambda() was
# renamed
checkpoint_dict["run_params"]["lr_scheduler"].pop("params")

if checkpoint_endpoint == "local":
os.makedirs(f"{ROOT}/models", exist_ok=True)
checkpoint_path = f"{ROOT}/models/{timestamp}-{run_name}.pth"
checkpoint_path = (
f"{ROOT}/models/{timestamp+'-' if timestamp else ''}{run_name}-{epochs}.pth"
)
torch.save(checkpoint_dict, checkpoint_path)

if checkpoint_endpoint == "wandb":
assert (
wandb.run is not None
), "can't save model checkpoint to Weights and Biases, wandb.run is None"
torch.save(checkpoint_dict, f"{wandb.run.dir}/checkpoint.pth")
torch.save(
checkpoint_dict,
f"{wandb.run.dir}/{timestamp+'-' if timestamp else ''}{run_name}-{epochs}.pth",
)


def train_wrenformer(
Expand Down