SimCLR: using LARS with ddp_sharded causes TypeError: unexpected keyword argument 'lr' #562

Closed · hchau630 opened this issue Feb 17, 2021 · 1 comment · Fixed by #613
Labels: bug (Something isn't working), help wanted (Extra attention is needed)


hchau630 commented Feb 17, 2021

🐛 Bug

Using a trainer with accelerator='ddp' and plugins='ddp_sharded' to train a SimCLR model with lars_wrapper=True causes the following error:

Traceback (most recent call last):
  File "/share/ctn/users/hc3190/issa/disentangle/bug.py", line 15, in <module>
    trainer.fit(model, datamodule=dm)
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 510, in fit
    results = self.accelerator_backend.train()
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 158, in train
    results = self.ddp_train(process_idx=self.task_idx, model=model)
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 301, in ddp_train
    model = self.configure_ddp(model, device_ids)
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 318, in configure_ddp
    model = self.ddp_plugin.configure_ddp(model, device_ids)
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/pytorch_lightning/plugins/sharded_plugin.py", line 38, in configure_ddp
    self._wrap_optimizers(model)
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/pytorch_lightning/plugins/sharded_plugin.py", line 60, in _wrap_optimizers
    self._reinit_with_fairscale_oss(trainer)
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/pytorch_lightning/plugins/sharded_plugin.py", line 69, in _reinit_with_fairscale_oss
    zero_optimizer = OSS(
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/fairscale/optim/oss.py", line 89, in __init__
    self.optim = optim(self.partition_parameters()[self.rank], **default)
TypeError: __init__() got an unexpected keyword argument 'lr'

The error disappears when either lars_wrapper=False or plugins=None. I suspect this happens because LARSWrapper is not a subclass of torch.optim.Optimizer and, unlike the standard torch optimizers, does not accept an 'lr' keyword argument; fairscale nevertheless treats LARSWrapper as a regular optimizer class and passes the 'lr' keyword argument anyway when re-instantiating it.
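
A minimal sketch of the suspected mismatch (the LARSWrapper import path here assumes pl_bolts 0.3.0):

# fairscale's OSS rebuilds the optimizer via
#     self.optim = optim(self.partition_parameters()[self.rank], **default)
# (see the traceback above), where `default` comes from the wrapped SGD and
# therefore contains 'lr'. LARSWrapper wraps an already-constructed
# optimizer and takes no 'lr' keyword, so that re-instantiation fails.
import torch
from pl_bolts.optimizers.lars_scheduling import LARSWrapper

params = [torch.nn.Parameter(torch.zeros(1))]
base = torch.optim.SGD(params, lr=0.1)
wrapped = LARSWrapper(base)  # wraps an existing optimizer

print(isinstance(wrapped, torch.optim.Optimizer))  # False
type(wrapped)(params, lr=0.1)  # TypeError: __init__() got an unexpected keyword argument 'lr'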

To Reproduce

Run the following code sample.

Code sample

import pytorch_lightning as pl
from pl_bolts.datamodules import ImagenetDataModule
from pl_bolts.models.self_supervised import SimCLR

IMAGENET_DIR_PATH = "/path/to/imagenet"
gpus = 4
batch_size = 32

dm = ImagenetDataModule(data_dir=IMAGENET_DIR_PATH, batch_size=batch_size)
# The failing combination: sharded DDP plus the LARS wrapper.
trainer = pl.Trainer(gpus=gpus, accelerator='ddp', plugins='ddp_sharded', fast_dev_run=True)
model = SimCLR(gpus, dm.num_samples, batch_size, 'imagenet', lars_wrapper=True)
trainer.fit(model, datamodule=dm)

Expected behavior

Expect no errors.
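
Workaround

Per the observation above, the only workaround I have found is to drop one of the two pieces. A minimal sketch keeping sharded training and disabling the wrapper (the reverse, lars_wrapper=True with plugins=None, also works):

# Keep ddp_sharded but disable the LARS wrapper.
model = SimCLR(gpus, dm.num_samples, batch_size, 'imagenet', lars_wrapper=False)
trainer = pl.Trainer(gpus=gpus, accelerator='ddp', plugins='ddp_sharded', fast_dev_run=True)
trainer.fit(model, datamodule=dm)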

Environment

  • PyTorch: 1.7.0
  • Lightning version: 1.1.8
  • Lightning bolts version: 0.3.0
  • Fairscale version: 0.1.6
  • Python version: 3.8.5
  • CUDA version: 11.1
  • GPU models and configuration: GeForce RTX 2080 Ti (x4)
hchau630 added the fix and help wanted labels on Feb 17, 2021
github-actions commented

Hi! Thanks for your contribution, great first issue!

Borda added the bug label and removed the fix label on Jun 20, 2023