From fc48a90c3370219afa6abe2da5301522d3f79869 Mon Sep 17 00:00:00 2001 From: Brian Ebiyau <38182764+ebyau@users.noreply.github.com> Date: Tue, 14 Feb 2023 11:46:07 +0300 Subject: [PATCH] tensorflow.roll (#10354) * tensorflow.roll implementation * code formating * added required changes --- .../frontends/tensorflow/general_functions.py | 4 ++ .../test_tensorflow/test_general_functions.py | 46 +++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/ivy/functional/frontends/tensorflow/general_functions.py b/ivy/functional/frontends/tensorflow/general_functions.py index c6a32edbf3c7d..16e15c95dc23b 100644 --- a/ivy/functional/frontends/tensorflow/general_functions.py +++ b/ivy/functional/frontends/tensorflow/general_functions.py @@ -339,3 +339,7 @@ def where(condition: ivy.array, x=None, y=None, name=None): return ivy.argwhere(condition) else: return ivy.where(condition, x, y) + + +def roll(input, shift, axis, name=None): + return ivy.roll(input, shift, axis=axis) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py index d5d349f2b3350..feb2f8718c0b3 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py @@ -1445,3 +1445,49 @@ def test_tensorflow_where_with_xy( x=x, y=y, ) + + +# roll +@handle_frontend_test( + fn_tree="tensorflow.roll", + dtype_and_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + ), + shift=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + force_tuple=True, + ), + axis=helpers.get_axis( + shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"), + force_tuple=True, + ), +) +def test_tensorflow_roll( + *, + dtype_and_values, + shift, + axis, + on_device, + fn_tree, + frontend, + test_flags, +): + input_dtype, value = dtype_and_values + if isinstance(shift, int) and isinstance(axis, tuple): + axis = axis[0] + if isinstance(shift, tuple) and isinstance(axis, tuple): + if len(shift) != len(axis): + mn = min(len(shift), len(axis)) + shift = shift[:mn] + axis = axis[:mn] + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=value[0], + shift=shift, + axis=axis, + )