Commit
fix learning rate schedulers in PyTorch 2.0 (closes #3202) (#3207)
In PyTorch 2.0, torch.optim.lr_scheduler._LRScheduler was renamed to LRScheduler, so the scheduler check needs to accept either name.
ilia-kats authored May 17, 2023
1 parent 7e4cd1d commit 62651dc
Showing 1 changed file with 10 additions and 4 deletions.
pyro/optim/optim.py (10 additions, 4 deletions)
@@ -138,10 +138,16 @@ def __call__(self, params: Union[List, ValuesView], *args, **kwargs) -> None:
             if self.grad_clip[p] is not None:
                 self.grad_clip[p](p)

-            if isinstance(
-                self.optim_objs[p], torch.optim.lr_scheduler._LRScheduler
-            ) or isinstance(
-                self.optim_objs[p], torch.optim.lr_scheduler.ReduceLROnPlateau
+            if (
+                hasattr(torch.optim.lr_scheduler, "_LRScheduler")
+                and isinstance(
+                    self.optim_objs[p], torch.optim.lr_scheduler._LRScheduler
+                )
+                or hasattr(torch.optim.lr_scheduler, "LRScheduler")
+                and isinstance(self.optim_objs[p], torch.optim.lr_scheduler.LRScheduler)
+                or isinstance(
+                    self.optim_objs[p], torch.optim.lr_scheduler.ReduceLROnPlateau
+                )
             ):
                 # if optim object was a scheduler, perform an optimizer step
                 self.optim_objs[p].optimizer.step(*args, **kwargs)
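For reference outside the diff, here is a minimal standalone sketch of the same version-compatibility idea (not the Pyro code itself; the helper name _is_scheduler is invented for illustration): probe which scheduler base-class names the installed torch.optim.lr_scheduler module exposes and test the wrapped object against all of them.

    import torch

    def _is_scheduler(obj) -> bool:
        lr_sched = torch.optim.lr_scheduler
        bases = []
        if hasattr(lr_sched, "_LRScheduler"):  # name used before PyTorch 2.0
            bases.append(lr_sched._LRScheduler)
        if hasattr(lr_sched, "LRScheduler"):  # name introduced in PyTorch 2.0
            bases.append(lr_sched.LRScheduler)
        # ReduceLROnPlateau has historically not shared the common base class,
        # so it is listed explicitly, mirroring the patched condition above.
        bases.append(lr_sched.ReduceLROnPlateau)
        return isinstance(obj, tuple(bases))

Probing both names with hasattr keeps the check working on PyTorch releases before and after the rename instead of pinning it to either one.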

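The branch touched here only runs when the per-parameter optim object is a learning rate scheduler rather than a bare optimizer, i.e. when the user builds one of the pyro.optim scheduler wrappers. A hedged usage sketch, assuming the documented dict-based wrapping convention (the concrete values are illustrative):

    import torch
    import pyro.optim

    # Wrap torch.optim.lr_scheduler.StepLR around SGD via pyro.optim. The
    # per-parameter objects created from this wrapper are StepLR instances,
    # so the patched isinstance check routes each step through
    # scheduler.optimizer.step() on the underlying optimizer.
    scheduler = pyro.optim.StepLR(
        {
            "optimizer": torch.optim.SGD,  # underlying optimizer class
            "optim_args": {"lr": 0.1},     # arguments for the optimizer
            "step_size": 25,               # arguments for StepLR itself
            "gamma": 0.5,
        }
    )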