diff --git a/pyro/optim/optim.py b/pyro/optim/optim.py
index 97db2c2632..b123d26bcb 100644
--- a/pyro/optim/optim.py
+++ b/pyro/optim/optim.py
@@ -138,10 +138,16 @@ def __call__(self, params: Union[List, ValuesView], *args, **kwargs) -> None:
             if self.grad_clip[p] is not None:
                 self.grad_clip[p](p)
 
-            if isinstance(
-                self.optim_objs[p], torch.optim.lr_scheduler._LRScheduler
-            ) or isinstance(
-                self.optim_objs[p], torch.optim.lr_scheduler.ReduceLROnPlateau
+            if (
+                hasattr(torch.optim.lr_scheduler, "_LRScheduler")
+                and isinstance(
+                    self.optim_objs[p], torch.optim.lr_scheduler._LRScheduler
+                )
+                or hasattr(torch.optim.lr_scheduler, "LRScheduler")
+                and isinstance(self.optim_objs[p], torch.optim.lr_scheduler.LRScheduler)
+                or isinstance(
+                    self.optim_objs[p], torch.optim.lr_scheduler.ReduceLROnPlateau
+                )
             ):
                 # if optim object was a scheduler, perform an optimizer step
                 self.optim_objs[p].optimizer.step(*args, **kwargs)
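
Note (not part of the diff): the change guards each `isinstance` check with `hasattr`, because older PyTorch releases only expose the private base class `torch.optim.lr_scheduler._LRScheduler`, while newer releases add the public `LRScheduler` name. A minimal standalone sketch of the same version-compatibility pattern, with `_SCHEDULER_BASES` and `is_scheduler` as hypothetical helper names, might look like this:

    # Sketch only: build the tuple of scheduler base classes available in the
    # installed PyTorch version, then use it for a single isinstance check.
    import torch

    _SCHEDULER_BASES = tuple(
        getattr(torch.optim.lr_scheduler, name)
        for name in ("_LRScheduler", "LRScheduler", "ReduceLROnPlateau")
        if hasattr(torch.optim.lr_scheduler, name)
    )

    def is_scheduler(obj) -> bool:
        # True when ``obj`` is a learning-rate scheduler under either the
        # old (private) or new (public) base-class name.
        return isinstance(obj, _SCHEDULER_BASES)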