Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent unexpected unpacking error when calling lr_finder.plot() with suggest_lr=True #98

Merged
merged 2 commits into from
Aug 26, 2024

Conversation

NaleRaphael
Copy link
Contributor

@NaleRaphael NaleRaphael commented Aug 3, 2024

@davidtvs, this PR should fix #88. And if it is resolved, maybe we can go on #97 to make a new release to PyPI.
@chAwater, I've rewritten part of the test case you made, please feel free to advise me if there is anything can be improved.

Problem

When calling lr_finder.plot(..., suggest_lr=True), we usually expect the returned value is actually a tuple containing ax and suggested_lr, and the API is also described as it.

But it would return only ax if there is no sufficient data points to calculate gradient of lr-loss curve:

min_grad_idx = None
try:
min_grad_idx = (np.gradient(np.array(losses))).argmin()
except ValueError:
print(
"Failed to compute the gradients, there might not be enough points."
)

if suggest_lr and min_grad_idx is not None:
return ax, lrs[min_grad_idx]
else:
return ax

Therefore, if users prefer unpacking the returned value directly as below, they would sometimes ran into the error TypeError: cannot unpack non-iterable AxesSubplot object.

ax, suggested_lr = lr_finder.plot(ax=ax, suggest_lr=True)

Solution

Always return 2 values if suggest_lr=True. This makes sure it work with the 2 kinds of syntax as follows:

# 1. unpack returned value directly
ax, suggested_lr = lr_finder.plot(ax=ax, suggest_lr=True)

# 2. use a single variable to catch the returned value, then unpack them manually (user can check it before unpacking)
retval = lr_finder.plot(ax=ax, suggest_lr=True)
assert isinstance(retval, tuple) and len(retval) == 2
ax, suggested_lr = retval

The responsibility of "check whether suggested_lr is available/none" is left back to users now. But it should be fine since the warning message would show up and it's easy to check. Also, the warning message is now more verbose to help user figure out the problem.

Note

Though I think this issue could be resolved better by separating the feature of "suggest learning rate" into a new function, it should be safer to keep the API unchanged before we decide to support more different methods to find a suggested learning rate in the future.

As it's mentioned in davidtvs#88, suggested lr would not be returned along
with `ax` (`matplotlib.Axes`) if there is no sufficient data points
to calculate gradient of lr-loss curve.

Though it would warn user about this problem [1], but users might be
confused by another error caused by unpacking returned value. This is
because users would usually expect it works as below:
```python
ax, lr = lr_finder.plot(..., suggest_lr=True)
```

But the second returned value `lr` might not exist when it failed
to find a suggested lr, then the returned value would be a single
value instead. Therefore, the unpacking syntax `ax, lr = ...` would
fail and result in the error reported in davidtvs#88.

So we fix it by always returning both `ax` and `suggested_lr` when
the flag `suggest_lr` is True to meet the expectation, and leave
the responsibility of "check whether `lr` is null" back to user.

[1]: https://github.com/davidtvs/pytorch-lr-finder/blob/fd9e949/torch_lr_finder/lr_finder.py#L539-L542
Copy link
Owner

@davidtvs davidtvs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

I agree that suggest_lr should be split from the plot function. It will make the API of plot simpler and allow users to get the suggested LR via a function call - at their own risk of course.

torch_lr_finder/lr_finder.py Outdated Show resolved Hide resolved
@davidtvs davidtvs mentioned this pull request Aug 25, 2024
…uggest LR

Now LR finder will raise a RuntimeError if there is no sufficient data
points to calculate gradient for suggested LR when
`lr_finder.plot(..., suggest_lr=True)` is called.

The error message will clarify the details of failure, so users can fix
the issue earlier as well.
Copy link
Owner

@davidtvs davidtvs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. Thanks @NaleRaphael!

@davidtvs davidtvs merged commit 5b9d92c into davidtvs:master Aug 26, 2024
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

TypeError: cannot unpack non-iterable AxesSubplot object
2 participants