diff --git a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
index d554769b7e..66c3a3a113 100644
--- a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
+++ b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
@@ -106,17 +106,15 @@ def step(self, batch, batch_idx):
 
     def training_step(self, batch, batch_idx):
         loss, logs = self.step(batch, batch_idx)
-        result = pl.TrainResult(minimize=loss)
-        result.log_dict(
+        self.log_dict(
             {f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False
         )
-        return result
+        return loss
 
     def validation_step(self, batch, batch_idx):
         loss, logs = self.step(batch, batch_idx)
-        result = pl.EvalResult(checkpoint_on=loss)
-        result.log_dict({f"val_{k}": v for k, v in logs.items()})
-        return result
+        self.log_dict({f"val_{k}": v for k, v in logs.items()})
+        return loss
 
     def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=self.lr)
diff --git a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py
index 40ba1be428..ab6671d000 100644
--- a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py
+++ b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py
@@ -139,17 +139,15 @@ def step(self, batch, batch_idx):
 
     def training_step(self, batch, batch_idx):
         loss, logs = self.step(batch, batch_idx)
-        result = pl.TrainResult(minimize=loss)
-        result.log_dict(
+        self.log_dict(
             {f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False
         )
-        return result
+        return loss
 
     def validation_step(self, batch, batch_idx):
         loss, logs = self.step(batch, batch_idx)
-        result = pl.EvalResult(checkpoint_on=loss)
-        result.log_dict({f"val_{k}": v for k, v in logs.items()})
-        return result
+        self.log_dict({f"val_{k}": v for k, v in logs.items()})
+        return loss
 
     def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=self.lr)
diff --git a/pl_bolts/models/gans/basic/basic_gan_module.py b/pl_bolts/models/gans/basic/basic_gan_module.py
index 1459327daf..7311cb260b 100644
--- a/pl_bolts/models/gans/basic/basic_gan_module.py
+++ b/pl_bolts/models/gans/basic/basic_gan_module.py
@@ -136,18 +136,16 @@ def generator_step(self, x):
 
         # log to prog bar on each step AND for the full epoch
         # use the generator loss for checkpointing
-        result = pl.TrainResult(minimize=g_loss, checkpoint_on=g_loss)
-        result.log('g_loss', g_loss, on_epoch=True, prog_bar=True)
-        return result
+        self.log('g_loss', g_loss, on_epoch=True, prog_bar=True)
+        return g_loss
 
     def discriminator_step(self, x):
         # Measure discriminator's ability to classify real from generated samples
         d_loss = self.discriminator_loss(x)
 
         # log to prog bar on each step AND for the full epoch
-        result = pl.TrainResult(minimize=d_loss)
-        result.log('d_loss', d_loss, on_epoch=True, prog_bar=True)
-        return result
+        self.log('d_loss', d_loss, on_epoch=True, prog_bar=True)
+        return d_loss
 
     def configure_optimizers(self):
         lr = self.hparams.learning_rate
diff --git a/pl_bolts/models/regression/logistic_regression.py b/pl_bolts/models/regression/logistic_regression.py
index c9b2a2ff4f..79626cd143 100644
--- a/pl_bolts/models/regression/logistic_regression.py
+++ b/pl_bolts/models/regression/logistic_regression.py
@@ -2,7 +2,7 @@
 
 import pytorch_lightning as pl
 import torch
-from pytorch_lightning.metrics.classification import accuracy
+from pytorch_lightning.metrics.functional import accuracy
 from torch import nn
 from torch.nn import functional as F
 from torch.optim import Adam
diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py
index 04e7f81ae2..95c68bbee7 100644
--- a/pl_bolts/models/self_supervised/byol/byol_module.py
+++ b/pl_bolts/models/self_supervised/byol/byol_module.py
@@ -136,19 +136,17 @@ def training_step(self, batch, batch_idx):
         loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)
 
         # log results
-        result = pl.TrainResult(minimize=total_loss)
-        result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
+        self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
 
-        return result
+        return total_loss
 
     def validation_step(self, batch, batch_idx):
         loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)
 
         # log results
-        result = pl.EvalResult(early_stop_on=total_loss, checkpoint_on=total_loss)
-        result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
+        self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
 
-        return result
+        return total_loss
 
     def configure_optimizers(self):
         optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
diff --git a/pl_bolts/models/self_supervised/cpc/cpc_module.py b/pl_bolts/models/self_supervised/cpc/cpc_module.py
index b94d4611c0..e6aa3b5b0c 100644
--- a/pl_bolts/models/self_supervised/cpc/cpc_module.py
+++ b/pl_bolts/models/self_supervised/cpc/cpc_module.py
@@ -35,7 +35,7 @@ class CPCV2(pl.LightningModule):
     def __init__(
         self,
         datamodule: pl.LightningDataModule = None,
-        encoder: Union[str, torch.nn.Module, pl.LightningModule] = 'cpc_encoder',
+        encoder_name: str = 'cpc_encoder',
         patch_size: int = 8,
         patch_overlap: int = 4,
         online_ft: int = True,
@@ -50,7 +50,7 @@ def __init__(
         """
         Args:
             datamodule: A Datamodule (optional). Otherwise set the dataloaders directly
-            encoder: A string for any of the resnets in torchvision, or the original CPC encoder,
+            encoder_name: A string for any of the resnets in torchvision, or the original CPC encoder,
                 or a custon nn.Module encoder
             patch_size: How big to make the image patches
             patch_overlap: How much overlap should each patch have.
@@ -66,28 +66,20 @@ def __init__(
         super().__init__()
         self.save_hyperparameters()
 
+        # HACK - datamodule not pickleable so we remove it from hparams.
+        # TODO - remove datamodule from init. data should be decoupled from models.
+        del self.hparams['datamodule']
+
         self.online_evaluator = self.hparams.online_ft
 
         if pretrained:
             self.hparams.dataset = pretrained
             self.online_evaluator = True
 
-        # link data
-        # if datamodule is None:
-        #     datamodule = CIFAR10DataModule(
-        #         self.hparams.data_dir,
-        #         num_workers=self.hparams.num_workers,
-        #         batch_size=batch_size
-        #     )
-        #     datamodule.train_transforms = CPCTrainTransformsCIFAR10()
-        #     datamodule.val_transforms = CPCEvalTransformsCIFAR10()
         assert datamodule
         self.datamodule = datamodule
 
-        # init encoder
-        self.encoder = encoder
-        if isinstance(encoder, str):
-            self.encoder = self.init_encoder()
+        self.encoder = self.init_encoder()
 
         # info nce loss
         c, h = self.__compute_final_nb_c(self.hparams.patch_size)
@@ -97,20 +89,22 @@ def __init__(
         self.num_classes = self.datamodule.num_classes
 
         if pretrained:
-            self.load_pretrained(encoder)
+            self.load_pretrained(self.hparams.encoder_name)
+
+        print(self.hparams)
 
-    def load_pretrained(self, encoder):
+    def load_pretrained(self, encoder_name):
         available_weights = {'resnet18'}
 
-        if encoder in available_weights:
-            load_pretrained(self, f'CPCV2-{encoder}')
-        elif available_weights not in available_weights:
-            rank_zero_warn(f'{encoder} not yet available')
+        if encoder_name in available_weights:
+            load_pretrained(self, f'CPCV2-{encoder_name}')
+        elif encoder_name not in available_weights:
+            rank_zero_warn(f'{encoder_name} not yet available')
 
     def init_encoder(self):
         dummy_batch = torch.zeros((2, 3, self.hparams.patch_size, self.hparams.patch_size))
 
-        encoder_name = self.hparams.encoder
+        encoder_name = self.hparams.encoder_name
         if encoder_name == 'cpc_encoder':
             return cpc_resnet101(dummy_batch)
         else:
@@ -160,18 +154,16 @@ def training_step(self, batch, batch_nb):
         nce_loss = self.shared_step(batch)
 
         # result
-        result = pl.TrainResult(nce_loss)
-        result.log('train_nce_loss', nce_loss)
-        return result
+        self.log('train_nce_loss', nce_loss)
+        return nce_loss
 
     def validation_step(self, batch, batch_nb):
         # calculate loss
         nce_loss = self.shared_step(batch)
 
         # result
-        result = pl.EvalResult(checkpoint_on=nce_loss)
-        result.log('val_nce', nce_loss, prog_bar=True)
-        return result
+        self.log('val_nce', nce_loss, prog_bar=True)
+        return nce_loss
 
     def shared_step(self, batch):
         try:
diff --git a/pl_bolts/models/self_supervised/simclr/simclr_module.py b/pl_bolts/models/self_supervised/simclr/simclr_module.py
index 7fbe562827..582883991a 100644
--- a/pl_bolts/models/self_supervised/simclr/simclr_module.py
+++ b/pl_bolts/models/self_supervised/simclr/simclr_module.py
@@ -157,16 +157,14 @@ def forward(self, x):
     def training_step(self, batch, batch_idx):
         loss = self.shared_step(batch, batch_idx)
 
-        result = pl.TrainResult(minimize=loss)
-        result.log('train_loss', loss, on_epoch=True)
-        return result
+        self.log('train_loss', loss, on_epoch=True)
+        return loss
 
     def validation_step(self, batch, batch_idx):
         loss = self.shared_step(batch, batch_idx)
 
-        result = pl.EvalResult(checkpoint_on=loss)
-        result.log('avg_val_loss', loss)
-        return result
+        self.log('avg_val_loss', loss)
+        return loss
 
     def shared_step(self, batch, batch_idx):
         (img1, img2), y = batch
diff --git a/pl_bolts/models/self_supervised/ssl_finetuner.py b/pl_bolts/models/self_supervised/ssl_finetuner.py
index d3a3e95377..f07e697a42 100644
--- a/pl_bolts/models/self_supervised/ssl_finetuner.py
+++ b/pl_bolts/models/self_supervised/ssl_finetuner.py
@@ -59,21 +59,18 @@ def on_train_epoch_start(self) -> None:
 
     def training_step(self, batch, batch_idx):
         loss, acc = self.shared_step(batch)
-        result = pl.TrainResult(loss)
-        result.log('train_acc', acc, prog_bar=True)
-        return result
+        self.log('train_acc', acc, prog_bar=True)
+        return loss
 
     def validation_step(self, batch, batch_idx):
         loss, acc = self.shared_step(batch)
-        result = pl.EvalResult(checkpoint_on=loss, early_stop_on=loss)
-        result.log_dict({'val_acc': acc, 'val_loss': loss}, prog_bar=True)
-        return result
+        self.log_dict({'val_acc': acc, 'val_loss': loss}, prog_bar=True)
+        return loss
 
     def test_step(self, batch, batch_idx):
         loss, acc = self.shared_step(batch)
-        result = pl.EvalResult()
-        result.log_dict({'test_acc': acc, 'test_loss': loss})
-        return result
+        self.log_dict({'test_acc': acc, 'test_loss': loss})
+        return loss
 
     def shared_step(self, batch):
         x, y = batch
diff --git a/requirements/base.txt b/requirements/base.txt
index 62766e3de2..c434a7c377 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -1,2 +1,2 @@
-pytorch-lightning>=0.9.1rc3
+pytorch-lightning>=0.10.0
 torch>=1.6
\ No newline at end of file
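
For readers migrating their own modules, below is a minimal sketch (not part of this patch) of the pytorch-lightning >= 0.10.0 logging pattern these diffs adopt. The class, metric names, and toy model are illustrative only; checkpointing and early stopping, formerly set through checkpoint_on / early_stop_on on the result objects, are instead configured via callbacks that monitor a logged metric (for example ModelCheckpoint(monitor='val_loss')).

    import pytorch_lightning as pl
    import torch
    from torch import nn
    from torch.nn import functional as F

    class LitSketch(pl.LightningModule):
        """Illustrative module using self.log / self.log_dict instead of Train/EvalResult."""

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 32)

        def step(self, batch):
            x, _ = batch  # assumes (input, target) batches; target unused in this toy example
            return F.mse_loss(self.layer(x), x)

        def training_step(self, batch, batch_idx):
            loss = self.step(batch)
            # replaces: result = pl.TrainResult(minimize=loss); result.log(...)
            self.log('train_loss', loss, on_step=True, on_epoch=False)
            return loss

        def validation_step(self, batch, batch_idx):
            loss = self.step(batch)
            # replaces: result = pl.EvalResult(checkpoint_on=loss); result.log_dict(...)
            self.log_dict({'val_loss': loss}, prog_bar=True)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)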