
With yaml config file for LightningCLI, self.save_hyperparameters() behavior abnormal #19977

Closed
t4rf9 opened this issue Jun 15, 2024 · 10 comments · Fixed by #20068
Labels: bug (Something isn't working), lightningcli (pl.cli.LightningCLI), ver: 2.2.x
Comments

@t4rf9

t4rf9 commented Jun 15, 2024

Bug description

With a YAML config file for LightningCLI, self.save_hyperparameters() in __init__ of the model and datamodule mistakenly saves a dict containing keys like class_path and init_args instead of the flat hyperparameters.

This problem appears in version 2.3.0; version 2.2.5 works correctly.
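
A minimal sketch of the difference, with a hypothetical defensive access pattern (not part of this report; the init_args unwrapping is an assumption based on the output shown in the reproduction below):

import lightning as pl


class Model(pl.LightningModule):
    def __init__(self, learning_rate: float):
        super().__init__()
        self.save_hyperparameters()
        # 2.2.5 (and non-CLI usage): self.hparams is flat, e.g. {"learning_rate": 0.01}
        # 2.3.0 via LightningCLI:    self.hparams is nested, e.g.
        #   {"_instantiator": ..., "class_path": "model.Model",
        #    "init_args": {"learning_rate": 0.01}}
        hparams = dict(self.hparams)
        hparams = hparams.get("init_args", hparams)  # unwrap only if nested (workaround)
        self.learning_rate = hparams["learning_rate"]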

What version are you seeing the problem on?

2.3.0

How to reproduce the bug

config.yaml

ckpt_path: null
seed_everything: 0
model:
  class_path: model.Model
  init_args:
    learning_rate: 1e-2
data:
  class_path: datamodule.DataModule
  init_args:
    data_dir: data
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: null
  fast_dev_run: false
  max_epochs: 100
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: 10
  limit_test_batches: null
  limit_predict_batches: null
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: lightning_logs
      name: normalized
  callbacks:
    class_path: lightning.pytorch.callbacks.ModelCheckpoint
    init_args:
      save_top_k: 5
      monitor: valid_loss
      filename: "{epoch}-{step}-{valid_loss:.8f}"
  overfit_batches: 0.0
  val_check_interval: 50
  check_val_every_n_epoch: 1
  num_sanity_val_steps: null
  log_every_n_steps: 50
  enable_checkpointing: null
  enable_progress_bar: null
  enable_model_summary: null
  accumulate_grad_batches: 1
  gradient_clip_val: null
  gradient_clip_algorithm: null
  deterministic: true
  benchmark: null
  inference_mode: true
  use_distributed_sampler: true
  profiler: null
  detect_anomaly: false
  barebones: false
  plugins: null
  sync_batchnorm: true
  reload_dataloaders_every_n_epochs: 0
  default_root_dir: null

model.py

import torch
from torch import nn
import torch.nn.functional as F
import lightning as pl


class Model(pl.LightningModule):
    def __init__(self, learning_rate: float):
        super().__init__()

        print()
        print("Model:")

        print(f"learning_rate: {learning_rate}")
        ## This outputs correctly.

        self.save_hyperparameters()
        
        print(self.hparams)
        ## This outputs:
        # "_instantiator": lightning.pytorch.cli.instantiate_module
        # "class_path":    model.Model
        # "init_args":     {'learning_rate': 0.01}

datamodule.py

from lightning import LightningDataModule
from torch.utils.data import DataLoader

from dataset import KaptchaDataset
from transform import Transform


class DataModule(LightningDataModule):
    def __init__(self, data_dir: str):
        super().__init__()
        self.save_hyperparameters()

        print()
        print("DataModule:")

        print(self.hparams)
        ## This outputs
        # "_instantiator": lightning.pytorch.cli.instantiate_module
        # "class_path":    datamodule.DataModule
        # "init_args":     {'data_dir': 'data'}

main.py

from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint

from model import Model
from datamodule import DataModule


def cli_main():
    cli = LightningCLI()


if __name__ == "__main__":
    cli_main()

Run python main.py fit --config config.yaml


cc @carmocca @mauvilsa
t4rf9 added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Jun 15, 2024
@adamjstewart
Contributor

We're seeing this too; it broke all of TorchGeo's tests: https://github.com/microsoft/torchgeo/actions/runs/9522133755/job/26251028463?pr=2119

@EthanMarx

+1

@CarlosGomes98

CarlosGomes98 commented Jun 28, 2024

We are seeing this as well: IBM/terratorch#26

As far as I can tell it stems from #19771, which (inadvertently?) affects the LightningCLI parser.

@adamjstewart
Contributor

Still broken in 2.3.1, still preventing TorchGeo from supporting newer versions of Lightning.

@adamjstewart
Contributor

Still broken in 2.3.2.

@adamjstewart
Contributor

Regarding the earlier suggestion that this stems from #19771: I checked the commits immediately before and after that PR, and the bug is present in both, so I think that PR is unrelated.

@adamjstewart
Contributor

Thanks for the steps to reproduce @t4rf9! Note that datamodule.py contains a couple extra imports that need to be removed.

I would also like to add one other file test.py:

from lightning.pytorch import Trainer

from model import Model
from datamodule import DataModule


if __name__ == "__main__":
    model = Model(learning_rate=1e-2)
    datamodule = DataModule(data_dir='data')
    trainer = Trainer()
    trainer.fit(model=model, datamodule=datamodule)

By running this, you'll see that the extra class_path and init_args keys do not exist for non-CLI usage. Therefore, this isn't a simple backwards-incompatible change that requires changing how users access hparams; it is a true bug: if you use save_hyperparameters, the same code cannot be used for both CLI and non-CLI usage.
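
A small hypothetical illustration of that point (assuming the 2.3.0 behaviour reported above; not from the thread): the same attribute lookup inside __init__ works when the module is constructed directly, but fails when LightningCLI instantiates it from the config.

import lightning as pl


class Model(pl.LightningModule):
    def __init__(self, learning_rate: float):
        super().__init__()
        self.save_hyperparameters()
        # Direct construction (as in test.py): hparams == {'learning_rate': 0.01}
        # LightningCLI on 2.3.0:               hparams == {'_instantiator': ...,
        #                                                  'class_path': 'model.Model',
        #                                                  'init_args': {'learning_rate': 0.01}}
        self.lr = self.hparams.learning_rate  # key is missing under the CLI on 2.3.0


model = Model(learning_rate=1e-2)  # works
# python main.py fit --config config.yaml would fail inside __init__ on 2.3.0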

@adamjstewart
Contributor

I tracked down the source of the bug as follows.

config.yaml

model:
  class_path: model.Model
  init_args:
    learning_rate: 1e-2
data:
  class_path: lightning.pytorch.demos.boring_classes.BoringDataModule
trainer:
  fast_dev_run: true

main.py

from lightning.pytorch.cli import LightningCLI


def cli_main():
    cli = LightningCLI()


if __name__ == '__main__':
    cli_main()

model.py

from lightning.pytorch import LightningModule


class Model(LightningModule):
    def __init__(self, learning_rate):
        super().__init__()
        self.save_hyperparameters()
        assert 'learning_rate' in self.hparams

    def training_step(*args):
        pass

    def configure_optimizers(*args):
        pass

Then run the following (git bisect run marks a commit as bad when the command exits with a non-zero status, so the assert in model.py serves as the bisect test):

> git bisect start
> git checkout 2.2.0
> git bisect good
> git checkout 2.3.0
> git bisect bad
> git bisect run python3 main.py fit --config config.yaml

This reported that the bug first appears in #18105. Pinging everyone involved in that PR: @mauvilsa @awaelchli @carmocca @Borda

@mauvilsa
Contributor

mauvilsa commented Jul 9, 2024

I suppose in #18105 we didn't notice how this could potentially break things. But I think it can be easily fixed. I will create a pull request for it tomorrow.

@mauvilsa
Contributor

Created pull request #20068.

@t4rf9, @adamjstewart, please test it out.
