🐛 Bug

Using a trainer with accelerator='ddp' and plugins='ddp_sharded' to train a SimCLR model with lars_wrapper=True causes the following error:
```
Traceback (most recent call last):
  File "/share/ctn/users/hc3190/issa/disentangle/bug.py", line 15, in <module>
    trainer.fit(model, datamodule=dm)
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 510, in fit
    results = self.accelerator_backend.train()
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 158, in train
    results = self.ddp_train(process_idx=self.task_idx, model=model)
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 301, in ddp_train
    model = self.configure_ddp(model, device_ids)
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 318, in configure_ddp
    model = self.ddp_plugin.configure_ddp(model, device_ids)
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/pytorch_lightning/plugins/sharded_plugin.py", line 38, in configure_ddp
    self._wrap_optimizers(model)
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/pytorch_lightning/plugins/sharded_plugin.py", line 60, in _wrap_optimizers
    self._reinit_with_fairscale_oss(trainer)
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/pytorch_lightning/plugins/sharded_plugin.py", line 69, in _reinit_with_fairscale_oss
    zero_optimizer = OSS(
  File "/home/hc3190/.conda/envs/pytorch_env/lib/python3.8/site-packages/fairscale/optim/oss.py", line 89, in __init__
    self.optim = optim(self.partition_parameters()[self.rank], **default)
TypeError: __init__() got an unexpected keyword argument 'lr'
```
The error disappears when either lars_wrapper=False or plugins=None. I suspect this happens because LARSWrapper is not a subclass of torch.optim.Optimizer and its constructor does not accept an 'lr' keyword argument, yet the sharded plugin and fairscale's OSS treat it like a regular torch optimizer and re-instantiate it with the wrapped optimizer's defaults, which include 'lr'.
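For illustration, the same TypeError can be reproduced outside the Trainer. This is a minimal sketch, not the reporter's code: it assumes the pl_bolts LARSWrapper import path of the time (which may differ between versions) and paraphrases the re-instantiation step that fairscale's OSS performs in the traceback above.

```python
import torch
from pl_bolts.optimizers.lars_scheduling import LARSWrapper  # assumed import path; may vary by pl_bolts version

params = [torch.nn.Parameter(torch.randn(2, 2))]
inner = torch.optim.SGD(params, lr=0.1)
wrapped = LARSWrapper(inner)

# Paraphrase of the failing step in fairscale.optim.OSS.__init__ (see traceback):
# an optimizer of type(wrapped) is rebuilt from the original optimizer's defaults.
optim_class = type(wrapped)      # LARSWrapper, not a torch.optim.Optimizer subclass
defaults = inner.defaults        # {'lr': 0.1, 'momentum': 0, ...} from the wrapped SGD
optim_class(params, **defaults)  # TypeError: __init__() got an unexpected keyword argument 'lr'
```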
To Reproduce
Run the following code sample.
Code sample
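The original code sample was not preserved in this copy. Based on the description and the trainer.fit(model, datamodule=dm) call in the traceback, a hedged reconstruction of the setup might look like the following; the SimCLR and CIFAR10DataModule constructor arguments are assumptions and may differ between pl_bolts versions.

```python
import pytorch_lightning as pl
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised import SimCLR

# Hypothetical reconstruction of bug.py; argument names and values are assumptions,
# not the reporter's exact script.
batch_size = 128
dm = CIFAR10DataModule(data_dir=".", batch_size=batch_size)  # SimCLR transforms omitted for brevity

model = SimCLR(
    gpus=2,
    num_samples=dm.num_samples,
    batch_size=batch_size,
    dataset="cifar10",
    lars_wrapper=True,  # the setting that triggers the error under ddp_sharded
)

trainer = pl.Trainer(gpus=2, accelerator="ddp", plugins="ddp_sharded", max_epochs=1)
trainer.fit(model, datamodule=dm)
```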
Expected behavior
No error is raised; training should proceed as it does when lars_wrapper=False or plugins=None.
Environment