-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Torch frontend lerp #10690
Torch frontend lerp #10690
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your contribution! I left a few comments please go over them and ask for another review once you're done.
|
||
@to_ivy_arrays_and_back | ||
def lerp(start, end, weight): | ||
return ivy.lerp(start, end, weight) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ivy.lerp doesn't exist in the ivy API, you can suggest its addition by following the steps here: https://lets-unify.ai/ivy/deep_dive/ivy_frontends.html#missing-ivy-functions.
That being said the function can simply composed by a few arithmetic operations as described here: https://pytorch.org/docs/stable/generated/torch.lerp.html.
Also please add the out
argument
import torch | ||
from torch.testing import assert_allclose | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove these imports
|
||
|
||
@to_ivy_arrays_and_back | ||
def lerp(start, end, weight): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The first argument should be input
as stated in the torch docs: https://pytorch.org/docs/stable/generated/torch.lerp.html
expected_output = torch.lerp(start, end, weight) | ||
output = frontend.torch.lerp(start, end, weight) | ||
|
||
assert_allclose(expected_output, output, rtol=1e-03, atol=1e-05) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to do this here, test_frontend_fuction
handles this
input=(start, end, weight), | ||
expected_output=expected_output, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't pack all input arguments into one, make them seperate keyword args, i.e. input=start
, end=end
, weight=weight
.
Remove expected_output
No need to include the out
argument when you add it
Thank you so much for your detailed review, will help me a lot. |
Hi @fspyridakos, as far as the addition of ivy.lerp function is concerned, I've raised a new issue #10762. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the changes! Also your ivy.lerp issue mentions the .interp
function in jax and numpy, and ivy.interp
already exists, if you think you can use that function here go ahead, but I think lerp
and interp
do different things, just chaning some elementwise functions is fine for now
if out is None: | ||
out = ivy.zeros_like(input) | ||
out = ivy.multiply_add(ivy.subtract(end, input), weight, ivy.subtract(input, out), out=out) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
most lines are redundant, you can just do return <operation>
as with the above functions (out
in the docs practically means the return
) also look over your implementation, ivy.multiply_add
doesn't exist
fn_tree="torch.lerp", | ||
dtype_and_input=helpers.dtype_and_values( | ||
available_dtypes=helpers.get_dtypes("float"), | ||
default=0.5, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure what default
is here, but you should probably use num_arrays=3
input_dtype, inputs = dtype_and_input | ||
start, end, weight = inputs | ||
helpers.test_frontend_function( | ||
input_dtypes=input_dtype, | ||
frontend=frontend, | ||
test_flags=test_flags, | ||
fn_tree=fn_tree, | ||
on_device=on_device, | ||
input_start=start, | ||
input_end=end, | ||
input_weight=weight, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fix indentation here to match functions above
Hi @fspyridakos, once again thanks a lot. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Few more chnages and you're golden
@@ -450,3 +450,9 @@ def hypot(input, other, *, out=None): | |||
@to_ivy_arrays_and_back | |||
def sigmoid(input, *, out=None): | |||
return ivy.sigmoid(input, out=out) | |||
|
|||
|
|||
@to_ivy_arrays_and_back |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add float16 and bfloat16 as unsupprted dtypes for torch (like other functions above)
fn_tree="torch.lerp", | ||
dtype_and_input=helpers.dtype_and_values( | ||
available_dtypes=helpers.get_dtypes("float"), | ||
num_arrays=3, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add the following:
large_abs_safety_factor=2.5, small_abs_safety_factor=2.5, safety_factor_scale="log",
Some large values will not interpolate properly due to float percision limitations, this prevents those from being generated
input_start=start, | ||
input_end=end, | ||
input_weight=weight, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just have this be input=start, end=end
and so on, the kwargs must have the same names as the function args
Hi @fspyridakos, thanks for all your help. |
LGTM, had to make 2 small changes but it's good now and the test passes. Thanks again for your contribution! |
Thanks a lot for merging @fspyridakos. |
Co-authored-by: Fotios Spyridakos <[email protected]>
Co-authored-by: Fotios Spyridakos <[email protected]>
Co-authored-by: Fotios Spyridakos <[email protected]>
Co-authored-by: Fotios Spyridakos <[email protected]>
Co-authored-by: Fotios Spyridakos <[email protected]>
Closes #10683