Commit

clean up

awaelchli committed Jun 12, 2021
1 parent 3aef4e4 commit 3cc54b8
Showing 4 changed files with 45 additions and 120 deletions.
125 changes: 43 additions & 82 deletions pl_examples/bug_report_model.py
@@ -1,106 +1,67 @@
-import logging
 import os
-from typing import Any, Dict
 
 import torch
-import torch.nn as nn
-import torch.optim as optim
-from torch.optim import AdamW
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
-import pytorch_lightning as pl
-from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
+from pytorch_lightning import LightningModule, Trainer
 
 
-class ToyModel(nn.Module):
+class RandomDataset(Dataset):
 
-    def __init__(self):
-        super().__init__()
-        self.net1 = nn.Linear(10, 10)
-        self.relu = nn.ReLU()
-        self.net2 = nn.Linear(10, 5)
+    def __init__(self, size, length):
+        self.len = length
+        self.data = torch.randn(length, size)
 
-    def forward(self, x):
-        return self.net2(self.relu(self.net1(x)))
+    def __getitem__(self, index):
+        return self.data[index]
+
+    def __len__(self):
+        return self.len
 
 
-class ToyTask(pl.LightningModule):
+class BoringModel(LightningModule):
 
     def __init__(self):
         super().__init__()
-        self.loss_fn = nn.MSELoss()
-
-    def setup(self, stage: str):
-        if stage == "test":
-            return
-        self.setup_model_and_optimizer()
-        print("setup called")
-
-    def setup_model_and_optimizer(self):
-        self.model = ToyModel()
-        self.optimizer = AdamW(
-            self.model.parameters(), lr=0.001, betas=[0.9, 0.999], eps=1.0e-08, weight_decay=0, amsgrad=False
-        )
+        self.layer = torch.nn.Linear(32, 2)
 
     def forward(self, x):
-        return self.model(x)
+        return self.layer(x)
 
     def training_step(self, batch, batch_idx):
-        targets = self.forward(batch["model_input"])
-        loss = self.loss_fn(targets, batch["label"])
-
-        # Log loss results per train step and per epoch
-        self.log("loss", loss)
-
-        # Tell Lightning to minimize loss
-        return loss
+        loss = self(batch).sum()
+        self.log("train_loss", loss)
+        return {"loss": loss}
+
+    def validation_step(self, batch, batch_idx):
+        loss = self(batch).sum()
+        self.log("valid_loss", loss)
+
+    def test_step(self, batch, batch_idx):
+        loss = self(batch).sum()
+        self.log("test_loss", loss)
 
     def configure_optimizers(self):
-        return self.optimizer
-
-    # def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-    #     self.setup_model_and_optimizer()
-
-
-if __name__ == "__main__":
-    task = ToyTask()
-
-    dataset = [{"model_input": torch.randn(20, 10), "label": torch.randn(20, 5)} for _ in range(10)]
-
-    train_dataloader = DataLoader(dataset, batch_size=None)
-    val_dataloader = DataLoader(dataset, batch_size=None)
-
-    model_checkpoint = ModelCheckpoint(
-        save_last=True,
-        every_n_val_epochs=1,
-    )
-
-    trainer = pl.Trainer(
-        gpus=2,
-        precision=16,
-        max_epochs=3,
-        progress_bar_refresh_rate=100,
-        log_gpu_memory=None,
-        reload_dataloaders_every_epoch=True,
-        limit_train_batches=10,
-        limit_val_batches=10,
-        limit_test_batches=10,
-        callbacks=[model_checkpoint],
-    )
-
-    results = trainer.fit(task, train_dataloader)
-
-    print(model_checkpoint.last_model_path)
-
-    trainer = pl.Trainer(
-        gpus=2,
-        precision=16,
-        max_epochs=4,
-        reload_dataloaders_every_epoch=True,
-        limit_train_batches=10,
-        limit_val_batches=10,
-        limit_test_batches=10,
-        callbacks=[model_checkpoint],
-        resume_from_checkpoint=model_checkpoint.last_model_path,
-    )
-    trainer.fit(task, train_dataloader)
+        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
+
+
+def run():
+    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
+    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
+    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=os.getcwd(),
+        limit_train_batches=1,
+        limit_val_batches=1,
+        num_sanity_val_steps=0,
+        max_epochs=1,
+        weights_summary=None,
+    )
+    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)
+    trainer.test(model, test_dataloaders=test_data)
+
+
+if __name__ == '__main__':
+    run()
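
Aside (not part of the commit): the rewritten bug_report_model.py is meant as a self-contained repro template. A minimal sketch of how a reporter might adapt it, assuming the module stays importable as pl_examples.bug_report_model and using the same 1.3-era trainer.fit keywords seen in run() above:

# Hypothetical repro script; BoringModel and RandomDataset are the classes
# from the rewritten file, while ReproModel and its override are invented here.
from torch.utils.data import DataLoader

from pl_examples.bug_report_model import BoringModel, RandomDataset
from pytorch_lightning import Trainer


class ReproModel(BoringModel):

    def validation_step(self, batch, batch_idx):
        # Override only the hook suspected of triggering the bug.
        loss = self(batch).sum()
        self.log("valid_loss", loss)


if __name__ == "__main__":
    trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2)
    trainer.fit(
        ReproModel(),
        train_dataloader=DataLoader(RandomDataset(32, 64), batch_size=2),
        val_dataloaders=DataLoader(RandomDataset(32, 64), batch_size=2),
    )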
33 changes: 0 additions & 33 deletions pl_examples/model_resume.py

This file was deleted.

2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -247,7 +247,7 @@ def _load_config(self, config):
             config = json.load(f)
         return config
 
-    def pre_dispatch(self) -> None:
+    def pre_dispatch(self):
         self.init_deepspeed()
         self.barrier()

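Aside (not from the diff): pre_dispatch is the hook run on each process right before the training loop is dispatched; for DeepSpeed it wraps engine initialization plus a barrier, as the hunk shows. A sketch of extending it in a custom plugin, assuming the 1.3-era plugin API shown above; the subclass and its print are hypothetical:

# Illustrative only; DeepSpeedPlugin is the real plugin class, while the
# VerboseDeepSpeedPlugin wrapper is invented for this example.
from pytorch_lightning.plugins import DeepSpeedPlugin


class VerboseDeepSpeedPlugin(DeepSpeedPlugin):

    def pre_dispatch(self):
        print(f"rank {self.global_rank}: initializing the deepspeed engine")
        super().pre_dispatch()  # init_deepspeed() followed by barrier(), per the hunk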
5 changes: 1 addition & 4 deletions tests/trainer/test_trainer.py
@@ -428,10 +428,7 @@ def test_model_checkpoint_only_weights(tmpdir):

     # assert restoring train state fails
     with pytest.raises(KeyError, match="checkpoint contains only the model"):
-        trainer.checkpoint_connector.resume_from_checkpoint(new_weights_path)
-        trainer.checkpoint_connector.resume_start()
-        trainer.checkpoint_connector.restore_training_state()
-        trainer.checkpoint_connector.resume_end()
+        trainer.checkpoint_connector.restore(new_weights_path)
 
 
 def test_model_freeze_unfreeze():
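Aside on why the assertion fires (a sketch under the 1.3-era API, reusing BoringModel and RandomDataset from the first diff; not part of the test): a checkpoint saved with weights_only=True keeps the state_dict but drops optimizer and loop state, so restoring full training state from it raises the KeyError matched above.

# Hypothetical stand-in for the test's new_weights_path setup; the file
# name weights.ckpt is made up.
import torch
from torch.utils.data import DataLoader

from pl_examples.bug_report_model import BoringModel, RandomDataset
from pytorch_lightning import Trainer

model = BoringModel()
trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=0)
trainer.fit(model, train_dataloader=DataLoader(RandomDataset(32, 64), batch_size=2))
trainer.save_checkpoint("weights.ckpt", weights_only=True)

ckpt = torch.load("weights.ckpt")
assert "state_dict" in ckpt              # model weights are present
assert "optimizer_states" not in ckpt    # training state was stripped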
