diff --git a/.github/workflows/intelligent-tests.yml b/.github/workflows/intelligent-tests.yml
index 0cd2c83ece094..9059f274e1005 100644
--- a/.github/workflows/intelligent-tests.yml
+++ b/.github/workflows/intelligent-tests.yml
@@ -46,7 +46,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -93,7 +93,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -139,7 +139,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -186,7 +186,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -232,7 +232,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -278,7 +278,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -324,7 +324,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -370,7 +370,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -416,7 +416,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -462,7 +462,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -508,7 +508,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -554,7 +554,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -600,7 +600,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -646,7 +646,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -692,7 +692,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -738,7 +738,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -784,7 +784,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -830,7 +830,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -876,7 +876,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -922,7 +922,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -968,7 +968,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -1014,7 +1014,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -1060,7 +1060,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -1106,7 +1106,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -1152,7 +1152,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -1198,7 +1198,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -1244,7 +1244,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -1290,7 +1290,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -1336,7 +1336,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -1382,7 +1382,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -1428,7 +1428,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
@@ -1474,7 +1474,7 @@ jobs:
id: tests
run: |
cd ivy
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} 'false' ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
diff --git a/.github/workflows/manual-tests.yml b/.github/workflows/manual-tests.yml
index 5d4a9d13fbcf5..3a3746bf4960c 100644
--- a/.github/workflows/manual-tests.yml
+++ b/.github/workflows/manual-tests.yml
@@ -38,7 +38,7 @@ jobs:
pip install pymongo
cd ivy
python setup_tests.py ${{ github.event.inputs.test }}
- python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} ${{ github.event.inputs.version}} ${{ steps.jobs.outputs.html_url }}
+ python run_tests.py ${{ secrets.REDIS_CONNECTION_URL }} ${{ secrets.REDIS_PASSWORD }} ${{ secrets.MONGODB_PASSWORD }} ${{ github.event.inputs.version}} ${{ github.run_id }} ${{ steps.jobs.outputs.html_url }}
continue-on-error: true
- name: Check on failures
diff --git a/.idea/ivy.iml b/.idea/ivy.iml
index f4b5a229e00b1..533092f4ecf7f 100644
--- a/.idea/ivy.iml
+++ b/.idea/ivy.iml
@@ -2,14 +2,14 @@
-
+
+
-
-
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
index 8cc1b33864f84..1dc76391487a0 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -3,8 +3,8 @@
-
+
-
+
\ No newline at end of file
diff --git a/docker/rebuild_all_dockerfiles.sh b/docker/rebuild_all_dockerfiles.sh
index 792fa432bd510..386faff74fae5 100755
--- a/docker/rebuild_all_dockerfiles.sh
+++ b/docker/rebuild_all_dockerfiles.sh
@@ -1,5 +1,4 @@
#!/bin/bash
docker build -t unifyai/ivy:latest --no-cache -f Dockerfile ..
-docker build -t unifyai/ivy:latest-gpu --no-cache DockerfileGPU ..
-docker build -t unifyai/ivy:latest-copsim --no-cache DockerfileCopsim ..
+docker build -t unifyai/ivy:latest-gpu --no-cache -f DockerfileGPU ..
diff --git a/docs/partial_source/deep_dive/arrays.rst b/docs/partial_source/deep_dive/arrays.rst
index 0744f3ce0795b..5ed5fa9c178e0 100644
--- a/docs/partial_source/deep_dive/arrays.rst
+++ b/docs/partial_source/deep_dive/arrays.rst
@@ -24,7 +24,8 @@ Arrays
.. _`arrays channel`: https://discord.com/channels/799879767196958751/933380487353872454
.. _`arrays forum`: https://discord.com/channels/799879767196958751/1028296936203235359
.. _`wrapped logic`: https://github.com/unifyai/ivy/blob/6a729004c5e0db966412b00aa2fce174482da7dd/ivy/func_wrapper.py#L95
-
+.. _`NumPy's`: https://numpy.org/doc/stable/user/basics.dispatch.html#basics-dispatch
+.. _`PyTorch's`: https://pytorch.org/docs/stable/notes/extending.html#extending-torch
There are two types of array in Ivy, there is the :class:`ivy.NativeArray` and also the :class:`ivy.Array`.
Native Array
@@ -105,6 +106,68 @@ For the reasons explained above, this would be a problem.
Therefore, all compositional functions have a separate piece of `wrapped logic`_ to ensure that all :class:`ivy.NativeArray` instances are converted to :class:`ivy.Array` instances before entering into the compositional function.
+Integrating custom classes with Ivy
+-----------------------------------
+
+Ivy's functional API and its functions can easily be integrated with non-Ivy classes. Whether these classes are ones that inherit from Ivy or completely standalone custom classes, using Ivy's :code:`__array_function__`, Ivy's functions can handle inputs of those types.
+
+To make use of that feature, the class must provide implementations for the Ivy functions it intends to override, and it must also implement the method :code:`__array_function__`. If a non-Ivy class is passed to an Ivy function, a call to this class's :code:`__array_function__` is made which directs Ivy's function to handle that input type correctly. This allows users to define custom implementations for any of the functions that can be found in Ivy's functional API which would further make it easy to integrate those classes with other Ivy projects.
+
+**Note**
+This functionality is inspired by `NumPy's`_ :code:`__array_function__` and `PyTorch's`_ :code:`__torch_function__`.
+
+As an example, consider the following class :code:`MyArray` with the following definition:
+
+.. code-block:: python
+>>> class MyArray:
+>>> def __init__(self, data=None):
+>>> self.data = data
+
+Running any of Ivy’s functions using a :code:`MyArray` object as input will throw an :code:`IvyBackendException` since Ivy’s functions do not support this class type as input. This is where :code:`__array_function__` comes into play. Let’s add the method to our :code:`MyArray` class to see how it works.
+
+There are different ways to do so. One way is to use a global dict :code:`HANDLED_FUNCTIONS` which will map Ivy’s functions to the custom variant functions:
+
+.. code-block:: python
+>>> HANDLED_FUNCTIONS = {}
+>>> class MyArray:
+>>> def __init__(self, data=None):
+>>> self.data = data
+>>> def __array_function__(self, func, types, args, kwargs):
+>>> if func not in HANDLED_FUNCTIONS:
+>>> return NotImplemented
+>>> if not all(issubclass(t, (MyArray, ivy.Array, ivy.NativeArray)) for t in types):
+>>> return NotImplemented
+>>> return HANDLED_FUNCTIONS[func](*args, **kwargs)
+
+:code:`__array_function__` accepts four parameters: :code:`func` representing a reference to the array API function being
+overridden, :code:`types` a list of the types of objects implementing :code:`__array_function__`, :code:`args` a tuple of arguments supplied to the function, and :code:`kwargs` being a dictionary of keyword arguments passed to the function.
+While this class contains an implementation for :code:`__array_function__`, it is still not enough as it is necessary to implement any needed Ivy functions with the new :code:`MyArray` class as input(s) for the code to run successfully.
+We will define a decorator function :code:`implements` that can be used to add functions to :code:`HANDLED_FUNCTIONS`:
+
+.. code-block:: python
+>>> def implements(ivy_function):
+>>> def decorator(func):
+>>> HANDLED_FUNCTIONS[ivy_function] = func
+>>> return func
+>>> return decorator
+
+Lastly, we need to apply that decorator to the override function. Let’s consider for example a function that overrides :code:`ivy.abs`:
+
+.. code-block:: python
+>>> @implements(ivy.abs)
+>>> def my_abs(my_array):
+>>> my_array.data = abs(my_array.data)
+
+Now that we have added the function to :code:`HANDLED_FUNCTIONS`, we can now use :code:`ivy.abs` with :code:`MyArray` objects:
+.. code-block:: python
+
+>>> X = MyArray(-3)
+>>> ivy.abs(X)
+
+Of course :code:`ivy.abs` is an example of a function that is easy to override since it only requires one operand. The same approach can be used to override functions with multiple operands, including arrays or array-like objects that define :code:`__array_function__`.
+
+It is relevant to mention again that any function not stored inside the dict :code:`HANDLED_FUNCTIONS` will not work and it is also important to notice that the operands passed to the function must match that of the function stored in the dict. For instance :code:`my_abs` takes only one parameter which is a :code:`MyArray` object. So, passing any other operands to the function will result in an exception :code:`IvyBackendException` being thrown. Lastly, for a custom class to be covered completely with Ivy's functional API, it is necessary to create an implementation for all the relevant functions within the API that will be used by this custom class. That can be all the functions in the API or only a subset of them.
+
**Round Up**
This should have hopefully given you a good feel for the different types of arrays, and how these are handled in Ivy.
@@ -118,4 +181,4 @@ If you have any questions, please feel free to reach out on `discord`_ in the `a
\ No newline at end of file
+
diff --git a/ivy/__init__.py b/ivy/__init__.py
index 97ad0fa6776ec..7c02fa7318b95 100644
--- a/ivy/__init__.py
+++ b/ivy/__init__.py
@@ -24,9 +24,13 @@
import jaxlib
except ImportError:
jax = SimpleNamespace()
+ jax.interpreters = SimpleNamespace()
+ jax.interpreters.xla = SimpleNamespace()
jax.interpreters.xla._DeviceArray = SimpleNamespace()
+ jaxlib = SimpleNamespace()
+ jaxlib.xla_extension = SimpleNamespace()
jaxlib.xla_extension.DeviceArray = SimpleNamespace()
- jax.Buffer = SimpleNamespace()
+ jaxlib.xla_extension.Buffer = SimpleNamespace()
warnings.filterwarnings("ignore", module="^(?!.*ivy).*$")
@@ -203,7 +207,7 @@ def __new__(cls, shape_tup):
torch.Size,
jax.interpreters.xla._DeviceArray,
jaxlib.xla_extension.DeviceArray,
- jax.Buffer,
+ jaxlib.xla_extension.Buffer,
np.ndarray,
tf.Tensor,
)
@@ -302,6 +306,7 @@ class Node(str):
array_decimal_values_stack = list()
warning_level_stack = list()
nan_policy_stack = list()
+dynamic_backend_stack = list()
warn_to_regex = {"all": "!.*", "ivy_only": "^(?!.*ivy).*$", "none": ".*"}
@@ -838,7 +843,6 @@ class GlobalsDict(dict):
"valid_dtypes": valid_dtypes,
"valid_numeric_dtypes": valid_numeric_dtypes,
"valid_int_dtypes": valid_int_dtypes,
- "valid_int_dtypes": valid_int_dtypes,
"valid_uint_dtypes": valid_uint_dtypes,
"valid_complex_dtypes": valid_complex_dtypes,
"valid_devices": valid_devices,
@@ -862,6 +866,7 @@ class GlobalsDict(dict):
"default_int_dtype_stack": data_type.default_int_dtype_stack,
"default_uint_dtype_stack": data_type.default_uint_dtype_stack,
"nan_policy_stack": nan_policy_stack,
+ "dynamic_backend_stack": dynamic_backend_stack,
}
)
@@ -1094,3 +1099,55 @@ def unset_nan_policy():
global nan_policy_stack
if nan_policy_stack:
nan_policy_stack.pop(-1)
+
+
+# Dynamic Backend
+
+
+def get_dynamic_backend():
+ """Returns the current dynamic backend setting, with the default being True"""
+ global dynamic_backend_stack
+ if not dynamic_backend_stack:
+ return True
+ else:
+ return dynamic_backend_stack[-1]
+
+
+def set_dynamic_backend(flag):
+ """Sets the global dynamic backend setting to the provided flag (True or False)"""
+ global dynamic_backend_stack
+ if flag not in [True, False]:
+ raise ValueError("dynamic_backend must be a boolean value (True or False)")
+ dynamic_backend_stack.append(flag)
+
+
+def unset_dynamic_backend():
+ """
+ Removes the current dynamic backend setting,
+ restoring the previous setting (if any)
+ """
+ global dynamic_backend_stack
+ if dynamic_backend_stack:
+ dynamic_backend_stack.pop()
+
+
+# Context Managers
+
+
+class DynamicBackendContext:
+ def __init__(self, value):
+ self.value = value
+ self.original = None
+
+ def __enter__(self):
+ self.original = get_dynamic_backend()
+ set_dynamic_backend(self.value)
+
+ def __exit__(self, type, value, traceback):
+ unset_dynamic_backend()
+ if self.original is not None:
+ set_dynamic_backend(self.original)
+
+
+def dynamic_backend_as(value):
+ return DynamicBackendContext(value)
diff --git a/ivy/array/array.py b/ivy/array/array.py
index e95e61e1e9023..25f71a8c54a7a 100644
--- a/ivy/array/array.py
+++ b/ivy/array/array.py
@@ -72,7 +72,7 @@ class Array(
ArrayWithStatisticalExperimental,
ArrayWithUtilityExperimental,
):
- def __init__(self, data):
+ def __init__(self, data, dynamic_backend=None):
ArrayWithActivations.__init__(self)
ArrayWithCreation.__init__(self)
ArrayWithDataTypes.__init__(self)
@@ -112,9 +112,9 @@ def __init__(self, data):
ArrayWithSortingExperimental.__init__(self),
ArrayWithStatisticalExperimental.__init__(self),
ArrayWithUtilityExperimental.__init__(self),
- self._init(data)
+ self._init(data, dynamic_backend)
- def _init(self, data):
+ def _init(self, data, dynamic_backend=None):
if ivy.is_ivy_array(data):
self._data = data.data
else:
@@ -135,10 +135,46 @@ def _init(self, data):
else:
self._post_repr = ")"
self.backend = ivy.current_backend_str()
+ if dynamic_backend is not None:
+ self._dynamic_backend = dynamic_backend
+ else:
+ self._dynamic_backend = ivy.get_dynamic_backend()
# Properties #
# ---------- #
+ @property
+ def dynamic_backend(self):
+ return self._dynamic_backend
+
+ @dynamic_backend.setter
+ def dynamic_backend(self, value):
+ from ivy.functional.ivy.gradients import _variable
+ from ivy.backend_handler import _determine_backend_from_args
+
+ if value == False:
+ self._backend = _determine_backend_from_args(self)
+
+ else:
+ is_variable = self._backend.is_variable
+ to_numpy = self._backend.to_numpy
+ variable_data = self._backend.variable_data
+
+ if is_variable(self.data) and not (
+ str(self._backend).__contains__("jax")
+ or str(self._backend).__contains__("numpy")
+ ):
+ native_data = variable_data(self.data)
+ np_data = to_numpy(native_data)
+ new_arr = ivy.array(np_data)
+ self._data = _variable(new_arr).data
+
+ else:
+ np_data = to_numpy(self.data)
+ self._data = ivy.array(np_data).data
+
+ self._dynamic_backend = value
+
@property
def data(self) -> ivy.NativeArray:
"""The native array being wrapped in self."""
@@ -217,6 +253,24 @@ def __torch_function__(cls, func, types, args=(), kwargs={}):
args, kwargs = args_to_native(*args, **kwargs)
return func(*args, **kwargs)
+ def __array_function__(self, func, types, args, kwargs):
+ # Cannot handle items that have __array_function__ other than those of
+ # ivy arrays or native arrays.
+ for t in types:
+ if (
+ hasattr(t, "__array_function__")
+ and (t.__array_function__ is not ivy.Array.__array_function__)
+ or (
+ hasattr(ivy.NativeArray, "__array_function__")
+ and (t.__array_function__ is not ivy.NativeArray.__array_function__)
+ )
+ ):
+ return NotImplemented
+
+ # Arguments contain no overrides, so we can safely call the
+ # overloaded function again.
+ return func(*args, **kwargs)
+
def __array__(self, *args, **kwargs):
args, kwargs = args_to_native(*args, **kwargs)
return self._data.__array__(*args, **kwargs)
diff --git a/ivy/array/experimental/activations.py b/ivy/array/experimental/activations.py
index 1d39a81b02ff2..527fb0af9a1ab 100644
--- a/ivy/array/experimental/activations.py
+++ b/ivy/array/experimental/activations.py
@@ -80,3 +80,32 @@ def thresholded_relu(
ivy.array([0., 0., 1.])
"""
return ivy.thresholded_relu(self._data, threshold=threshold, out=out)
+
+ def prelu(
+ self,
+ slope: Union[float, ivy.NativeArray, ivy.Array],
+ /,
+ *,
+ out: Optional["ivy.Array"] = None,
+ ) -> ivy.Array:
+ """
+ Prelu takes input data (Array) and slope array as input,
+ and produces one output data (array) where the function
+ f(x) = slope * x for x < 0, f(x) = x for x >= 0., is applied
+ to the data array elementwise. This operator supports unidirectional
+ broadcasting (array slope should be unidirectional broadcastable to
+ input tensor X);
+
+ Parameters
+ ----------
+ self
+ input array.
+ slope
+ Slope Array. The shape of slope can be smaller then first input X;
+ if so, its shape must be unidirectional broadcastable to X.
+ out
+ Optional output array.
+ Returns
+ -------
+ """
+ return ivy.prelu(self._data, slope, out=out)
diff --git a/ivy/array/experimental/elementwise.py b/ivy/array/experimental/elementwise.py
index 26337b16ee9d6..8df937aa86393 100644
--- a/ivy/array/experimental/elementwise.py
+++ b/ivy/array/experimental/elementwise.py
@@ -849,9 +849,13 @@ def allclose(
)
def diff(
- self: Union[ivy.Array, int, float, list, tuple],
+ self: ivy.Array,
/,
*,
+ n: int = 1,
+ axis: int = -1,
+ prepend: Optional[Union[ivy.Array, ivy.NativeArray, int, list, tuple]] = None,
+ append: Optional[Union[ivy.Array, ivy.NativeArray, int, list, tuple]] = None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""ivy.Array instance method variant of ivy.diff. This method simply
@@ -862,6 +866,16 @@ def diff(
----------
self
array-like input.
+ n
+ The number of times values are differenced. If zero, the input is returned
+ as-is.
+ axis
+ The axis along which the difference is taken, default is the last axis.
+ prepend,append
+ Values to prepend/append to x along given axis prior to performing the
+ difference. Scalar values are expanded to arrays with length 1 in the
+ direction of axis and the shape of the input array in along all other
+ axes. Otherwise the dimension and shape must match x except along axis.
out
optional output array, for writing the result to.
@@ -872,15 +886,13 @@ def diff(
Examples
--------
- >>> x = ivy.Container(a=ivy.array([1, 2, 4, 7, 0]),\
- b=ivy.array([1, 2, 4, 7, 0]))
- >>> ivy.Container.static_diff(x)
- {
- a: ivy.array([ 1, 2, 3, -7])
- b: ivy.array([ 1, 2, 3, -7])
- }
+ >>> x = ivy.array([1, 2, 4, 7, 0])
+ >>> x.diff()
+ ivy.array([ 1, 2, 3, -7])
"""
- return ivy.diff(self._data, out=out)
+ return ivy.diff(
+ self._data, n=n, axis=axis, prepend=prepend, append=append, out=out
+ )
def fix(
self: ivy.Array,
diff --git a/ivy/array/experimental/layers.py b/ivy/array/experimental/layers.py
index 5067a136d4cdf..870345697a513 100644
--- a/ivy/array/experimental/layers.py
+++ b/ivy/array/experimental/layers.py
@@ -1,6 +1,6 @@
# global
import abc
-from typing import Optional, Union, Tuple, Literal
+from typing import Optional, Union, Tuple, Literal, Sequence
# local
import ivy
@@ -76,6 +76,8 @@ def max_pool2d(
/,
*,
data_format: str = "NHWC",
+ dilation: Union[int, Tuple[int], Tuple[int, int]] = 1,
+ ceil_mode: bool = False,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
@@ -135,6 +137,8 @@ def max_pool2d(
strides,
padding,
data_format=data_format,
+ dilation=dilation,
+ ceil_mode=ceil_mode,
out=out,
)
@@ -630,3 +634,57 @@ def dft(
norm=norm,
out=out,
)
+
+ def interpolate(
+ self,
+ size: Union[Sequence[int], int],
+ /,
+ *,
+ mode: Union[Literal["linear", "bilinear", "trilinear", "nearest"]] = "linear",
+ align_corners: Optional[bool] = None,
+ antialias: Optional[bool] = False,
+ out: Optional[ivy.Array] = None,
+ ) -> ivy.Array:
+ """
+ Down/up samples the input to the given size.
+ The algorithm used for interpolation is determined by mode.
+
+ Parameters
+ ----------
+ self
+ Input array. Must have the shape
+ [batch x channels x [optional depth] x [optional height] x width].
+ size
+ Output size.
+ mode
+ Interpolation mode. Can be one of the following:
+ - linear
+ - bilinear
+ - trilinear
+ - nearest
+ align_corners
+ If True, the corner pixels of the input and output tensors are aligned,
+ and thus preserving the values at the corner pixels. If False, the corner
+ pixels are not aligned, and the interpolation uses edge value padding for
+ out-of-boundary values.
+ only has an effect when mode is 'linear', 'bilinear',
+ 'bicubic' or 'trilinear'. Default: None
+ antialias
+ If True, antialiasing is applied when downsampling an image.
+ Supported modes: 'bilinear', 'bicubic'.
+ out
+ Optional output array, for writing the result to. It must
+ have a shape that the inputs broadcast to.
+
+ Returns
+ -------
+ resized array
+ """
+ return ivy.interpolate(
+ self._data,
+ size,
+ mode=mode,
+ align_corners=align_corners,
+ antialias=antialias,
+ out=out,
+ )
diff --git a/ivy/array/experimental/manipulation.py b/ivy/array/experimental/manipulation.py
index c99e5abd6aa85..b3b7afa9b3ec0 100644
--- a/ivy/array/experimental/manipulation.py
+++ b/ivy/array/experimental/manipulation.py
@@ -750,11 +750,9 @@ def take_along_axis(
def hsplit(
self: ivy.Array,
- indices_or_sections: Union[int, Tuple[int]],
+ indices_or_sections: Union[int, Tuple[int, ...]],
/,
- *,
- out: Optional[ivy.Array] = None,
- ) -> ivy.Array:
+ ) -> List[ivy.Array]:
"""
ivy.Array instance method variant of ivy.hsplit. This method simply
wraps the function, and so the docstring for ivy.hsplit also applies
@@ -765,20 +763,15 @@ def hsplit(
self
Input array.
indices_or_sections
- If indices_or_sections is an integer n, the array is split into n sections.
- If the array is divisible by n horizontally, each section will be of
- equal size. If input is not divisible by n, the sizes of the first
- int(ary.size(0) % n) sections will have size int(ary.size(0) / n) + 1, and
- the rest will have size int(ary.size(0) / n).
+ If indices_or_sections is an integer n, the array is split into n
+ equal sections, provided that n must be a divisor of the split axis.
If indices_or_sections is a tuple of ints, then input is split at each of
the indices in the tuple.
- out
- Optional output, for writing the result to.
Returns
-------
ret
- input array split horizontally.
+ list of arrays split horizontally from input array.
Examples
--------
@@ -798,7 +791,7 @@ def hsplit(
[10., 11.],
[14., 15.]]))
"""
- return ivy.hsplit(self._data, indices_or_sections, out=out)
+ return ivy.hsplit(self._data, indices_or_sections)
def expand(
self: ivy.Array,
diff --git a/ivy/array/experimental/norms.py b/ivy/array/experimental/norms.py
index 651b3419035cc..358ab86bfaee3 100644
--- a/ivy/array/experimental/norms.py
+++ b/ivy/array/experimental/norms.py
@@ -122,3 +122,33 @@ def instance_norm(
track_running_stats=track_running_stats,
out=out,
)
+
+ def lp_normalize(self, /, *, p: float = 2, axis=None, out=None):
+ """Normalizes the array to have Lp norm.
+
+ Parameters
+ ----------
+ self
+ Input array.
+ p
+ p-norm to use for normalization.
+ axis
+ Axis along which to normalize. If ``None``, the whole array
+ is normalized.
+ out
+ optional output array, for writing the result to. It must have a
+ shape that the inputs broadcast to.
+
+ Returns
+ -------
+ ret
+ The normalized array.
+
+ Examples
+ --------
+ >>> x = ivy.array([[1., 2.], [3., 4.]])
+ >>> x.lp_normalize(p=2, axis=1)
+ ivy.array([[0.4472, 0.8944],
+ [0.6, 0.8]])
+ """
+ return ivy.lp_normalize(self, p=p, axis=axis, out=out)
diff --git a/ivy/array/experimental/statistical.py b/ivy/array/experimental/statistical.py
index 051a1d7aabf98..dea12138e7ce1 100644
--- a/ivy/array/experimental/statistical.py
+++ b/ivy/array/experimental/statistical.py
@@ -256,3 +256,61 @@ def corrcoef(
[-1., nan, 1.]])
"""
return ivy.corrcoef(self._data, y=y, rowvar=rowvar, out=out)
+
+ def nanmedian(
+ self: ivy.Array,
+ /,
+ *,
+ axis: Optional[Union[Tuple[int], int]] = None,
+ keepdims: Optional[bool] = False,
+ overwrite_input: Optional[bool] = False,
+ out: Optional[ivy.Array] = None,
+ ) -> ivy.Array:
+ """ivy.Array instance method variant of ivy.nanmedian. This method simply
+ wraps the function, and so the docstring for ivy.nanmedian also applies to
+ this method with minimal changes.
+
+ Parameters
+ ----------
+ self
+ Input array.
+ axis
+ Axis or axes along which the means are computed.
+ The default is to compute the mean of the flattened array.
+ keepdims
+ If this is set to True, the axes which are reduced are left in the result
+ as dimensions with size one. With this option, the result will broadcast
+ correctly against the original a. If the value is anything but the default,
+ then keepdims will be passed through to the mean or sum methods of
+ sub-classes of ndarray. If the sub-class's method does not implement
+ keepdims, an exception will be raised.
+ overwrite_input
+ If True, then allow use of memory of input array a for calculations.
+ The input array will be modified by the call to median. This will
+ save memory when you do not need to preserve the contents of the input array.
+ Treat the input as undefined, but it will probably be fully or partially sorted.
+ Default is False. If overwrite_input is True and a is not already an ndarray,
+ an error will be raised.
+ out
+ optional output array, for writing the result to.
+
+ Returns
+ -------
+ ret
+ A new array holding the result. If the input contains integers
+
+ Examples
+ >>> a = ivy.array([[10.0, ivy.nan, 4], [3, 2, 1]])
+ >>> a.nanmedian()
+ 3.0
+ >>> a.nanmedian(axis=0)
+ ivy.array([6.5, 2. , 2.5])
+ """
+
+ return ivy.nanmedian(
+ self._data,
+ axis=axis,
+ keepdims=keepdims,
+ overwrite_input=overwrite_input,
+ out=out,
+ )
diff --git a/ivy/array/layers.py b/ivy/array/layers.py
index 69b6eb4db84f9..e30608deb920a 100644
--- a/ivy/array/layers.py
+++ b/ivy/array/layers.py
@@ -453,12 +453,12 @@ def depthwise_conv2d(
def conv2d(
self: ivy.Array,
filters: Union[ivy.Array, ivy.NativeArray],
- strides: Union[int, Tuple[int], Tuple[int, int]],
+ strides: Union[int, Tuple[int, int]],
padding: str,
/,
*,
data_format: str = "NHWC",
- dilations: Optional[Union[int, Tuple[int], Tuple[int, int]]] = 1,
+ dilations: Union[int, Tuple[int, int]] = 1,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
@@ -468,8 +468,8 @@ def conv2d(
Parameters
----------
- x
- Input image *[batch_size,h,w,d_in]*.
+ self
+ Input image *[batch_size,h,w,d_in]* or *[batch_size,d_in,h,w]*.
filters
Convolution filters *[fh,fw,d_in,d_out]*.
strides
diff --git a/ivy/array/linear_algebra.py b/ivy/array/linear_algebra.py
index 6699a3b936876..ee5ec75577098 100644
--- a/ivy/array/linear_algebra.py
+++ b/ivy/array/linear_algebra.py
@@ -16,6 +16,8 @@ def matmul(
*,
transpose_a: bool = False,
transpose_b: bool = False,
+ adjoint_a: bool = False,
+ adjoint_b: bool = False,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
@@ -57,7 +59,8 @@ def matmul(
ivy.array(11.)
"""
return ivy.matmul(
- self._data, x2, transpose_a=transpose_a, transpose_b=transpose_b, out=out
+ self._data, x2, transpose_a=transpose_a, transpose_b=transpose_b,
+ adjoint_a=adjoint_a, adjoint_b=adjoint_b, out=out
)
def cholesky(
diff --git a/ivy/assertions.py b/ivy/assertions.py
index 9911849e497d8..c78b1cd37844d 100644
--- a/ivy/assertions.py
+++ b/ivy/assertions.py
@@ -247,3 +247,17 @@ def check_dimensions(x):
"input must have greater than one dimension; "
+ " {} has {} dimensions".format(x, len(x.shape))
)
+
+
+def check_kernel_padding_size(kernel_size, padding_size):
+ for i in range(len(kernel_size)):
+ if (
+ padding_size[i][0] > kernel_size[i] // 2
+ or padding_size[i][1] > kernel_size[i] // 2
+ ):
+ raise ValueError(
+ "Padding size should be less than or equal to half of the kernel size. "
+ "Got kernel_size: {} and padding_size: {}".format(
+ kernel_size, padding_size
+ )
+ )
diff --git a/ivy/backend_handler.py b/ivy/backend_handler.py
index 4566b9d2d50f6..b720e9aeecead 100644
--- a/ivy/backend_handler.py
+++ b/ivy/backend_handler.py
@@ -4,6 +4,7 @@
import numpy as np
from ivy import verbosity
from typing import Optional
+import gc
# local
from ivy.func_wrapper import _wrap_function
@@ -30,6 +31,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
_array_types["jax.interpreters.xla"] = "ivy.functional.backends.jax"
_array_types["jaxlib.xla_extension"] = "ivy.functional.backends.jax"
_array_types["tensorflow.python.framework.ops"] = "ivy.functional.backends.tensorflow"
+_array_types[
+ "tensorflow.python.ops.resource_variable_ops"
+] = "ivy.functional.backends.tensorflow"
_array_types["torch"] = "ivy.functional.backends.torch"
_array_types["torch.nn.parameter"] = "ivy.functional.backends.torch"
@@ -45,7 +49,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
_backend_reverse_dict["ivy.functional.backends.tensorflow"] = "tensorflow"
_backend_reverse_dict["ivy.functional.backends.torch"] = "torch"
-
# Backend Getting/Setting #
# ----------------------- #
@@ -74,24 +77,28 @@ def _determine_backend_from_args(args):
# noqa
"""
- for arg in args:
- arg_type = type(arg)
- # function is called recursively if arg is a list/tuple
- if arg_type in [list, tuple]:
- lib = _determine_backend_from_args(arg)
+ arg_type = type(args)
+ if isinstance(args, ivy.Array):
+ args = args.data
+
+ if isinstance(args, dict):
+ for key, value in args.items():
+ # recursively call the function for each value in the dictionary
+ lib = _determine_backend_from_args(value)
if lib:
return lib
- # function is called recursively if arg is a dict
- elif arg_type is dict:
- lib = _determine_backend_from_args(list(arg.values()))
+ # check if args is a list or tuple
+ elif arg_type in [list, tuple]:
+ for arg in args:
+ # recursively call the function for each element in the list/tuple
+ lib = _determine_backend_from_args(arg)
if lib:
return lib
- else:
- # use the _array_types dict to map the module where arg comes from, to the
- # corresponding Ivy backend
- if arg.__class__.__module__ in _array_types:
- module_name = _array_types[arg.__class__.__module__]
- return importlib.import_module(module_name)
+ else:
+ # check if the class module of the arg is in _array_types
+ if args.__class__.__module__ in _array_types:
+ module_name = _array_types[args.__class__.__module__]
+ return importlib.import_module(module_name)
def fn_name_from_version_specific_fn_name(name, version):
@@ -200,14 +207,16 @@ def current_backend(*args, **kwargs):
# noqa
"""
global implicit_backend
- # if a global backend has been set with set_backend then this will be returned
+ # if a global backend has been set with
+ # set_backend then this will be returned
if backend_stack:
f = backend_stack[-1]
if verbosity.level > 0:
verbosity.cprint("Using backend from stack: {}".format(f))
return f
- # if no global backend exists, we try to infer the backend from the arguments
+ # if no global backend exists, we try to infer
+ # the backend from the arguments
f = _determine_backend_from_args(list(args) + list(kwargs.values()))
if f is not None:
implicit_backend = f.current_backend_str()
@@ -217,8 +226,118 @@ def current_backend(*args, **kwargs):
return importlib.import_module(_backend_dict[implicit_backend])
-def set_backend(backend: str):
+def convert_from_source_backend_to_numpy(variable_ids, numpy_objs):
+ # Dynamic Backend
+ from ivy.functional.ivy.gradients import _is_variable, _variable_data
+
+ def _is_var(obj):
+ if isinstance(obj, ivy.Container):
+
+ def _map_fn(x):
+ x = x.data if isinstance(x, ivy.Array) else x
+ if x.__class__.__module__ in (
+ "numpy",
+ "jax.interpreters.xla",
+ "jaxlib.xla_extension",
+ ):
+ return False
+
+ return _is_variable(x)
+
+ return obj.cont_map(lambda x, kc: _map_fn(x)).cont_all_true()
+
+ else:
+ obj = obj.data if isinstance(obj, ivy.Array) else obj
+ if obj.__class__.__module__ in (
+ "numpy",
+ "jax.interpreters.xla",
+ "jaxlib.xla_extension",
+ ):
+ return False
+ return _is_variable(obj)
+
+ def _remove_intermediate_arrays(arr_list, cont_list):
+ cont_list = [cont.cont_to_flat_list() for cont in cont_list]
+
+ cont_ids = [
+ id(item.data) if isinstance(item, ivy.Array) else id(item)
+ for cont in cont_list
+ for item in cont
+ ]
+ arr_ids = [
+ id(item.data) if isinstance(item, ivy.Array) else id(item)
+ for item in arr_list
+ ]
+
+ new_objs = {k: v for k, v in zip(arr_ids, arr_list) if k not in cont_ids}
+
+ return list(new_objs.values())
+
+ # get all ivy array and container instances in the project scope
+ array_list, container_list = [
+ [obj for obj in gc.get_objects() if isinstance(obj, obj_type)]
+ for obj_type in (ivy.Array, ivy.Container)
+ ]
+
+ # filter uninitialized arrays
+ array_list = [arr for arr in array_list if arr.__dict__]
+
+ # remove numpy intermediate objects
+ new_objs = _remove_intermediate_arrays(array_list, container_list)
+ new_objs += container_list
+
+ # now convert all ivy.Array and ivy.Container instances
+ # to numpy using the current backend
+ for obj in new_objs:
+ if obj.dynamic_backend:
+ numpy_objs.append(obj)
+ if _is_var(obj):
+ # add variable object id to set
+ variable_ids.add(id(obj))
+ native_var = _variable_data(obj)
+ np_data = ivy.to_numpy(native_var)
+
+ else:
+ np_data = obj.to_numpy()
+
+ if isinstance(obj, ivy.Container):
+ obj.cont_inplace_update(np_data)
+ else:
+ obj._data = np_data
+
+ return variable_ids, numpy_objs
+
+
+def convert_from_numpy_to_target_backend(variable_ids, numpy_objs):
+ # Dynamic Backend
+ from ivy.functional.ivy.gradients import _variable
+
+ # convert all ivy.Array and ivy.Container instances from numpy
+ # to native arrays using the newly set backend
+ for obj in numpy_objs:
+ np_arr = obj.data if isinstance(obj, ivy.Array) else obj
+ # check if object was originally a variable
+ if id(obj) in variable_ids:
+ native_arr = ivy.nested_map(
+ np_arr, current_backend().asarray, include_derived=True, shallow=False
+ )
+ new_data = _variable(native_arr)
+
+ else:
+ new_data = ivy.nested_map(
+ np_arr, current_backend().asarray, include_derived=True, shallow=False
+ )
+
+ if isinstance(obj, ivy.Container):
+ obj.cont_inplace_update(new_data)
+ else:
+ obj._data = new_data.data
+
+
+def set_backend(backend: str, dynamic: bool = False):
"""Sets `backend` to be the global backend.
+ Will also convert all Array and Container objects \
+ to the new backend if `dynamic` = True
Examples
--------
@@ -236,11 +355,22 @@ def set_backend(backend: str):
>>> native = ivy.native_array([1])
>>> print(type(native))
- """
+ """ # noqa
ivy.assertions.check_false(
isinstance(backend, str) and backend not in _backend_dict,
"backend must be one from {}".format(list(_backend_dict.keys())),
)
+
+ variable_ids = set() # create an empty set to store variable object ids
+ numpy_objs = [] # create an empty list to store numpy objects
+ # created during 1st conversion step
+
+ if dynamic:
+ variable_ids, numpy_objs = convert_from_source_backend_to_numpy(
+ variable_ids, numpy_objs
+ )
+
+ # update the global dict with the new backend
ivy.locks["backend_setter"].acquire()
global ivy_original_dict
if not backend_stack:
@@ -269,18 +399,21 @@ def set_backend(backend: str):
key=k, to_wrap=backend.__dict__[k], original=v, compositional=compositional
)
+ if dynamic:
+ convert_from_numpy_to_target_backend(variable_ids, numpy_objs)
+
if verbosity.level > 0:
verbosity.cprint("backend stack: {}".format(backend_stack))
ivy.locks["backend_setter"].release()
def set_numpy_backend():
- """Sets NumPy to be the global backend. equivalent to `ivy.set_backend("numpy")`."""
+ """Sets NumPy to be the global backend. equivalent to `ivy.set_backend("numpy")`.""" # noqa
set_backend("numpy")
def set_jax_backend():
- """Sets JAX to be the global backend. equivalent to `ivy.set_backend("jax")`."""
+ """Sets JAX to be the global backend. equivalent to `ivy.set_backend("jax")`.""" # noqa
set_backend("jax")
@@ -293,7 +426,7 @@ def set_tensorflow_backend():
def set_torch_backend():
- """Sets torch to be the global backend. equivalent to `ivy.set_backend("torch")`."""
+ """Sets torch to be the global backend. equivalent to `ivy.set_backend("torch")`.""" # noqa
set_backend("torch")
@@ -326,11 +459,11 @@ def get_backend(backend: Optional[str] = None):
>>> ivy.set_backend("jax")
>>> ivy_jax = ivy.get_backend()
>>> print(ivy_jax)
- # noqa
- """
- # ToDo: change this so that it doesn't depend at all on the global ivy. Currently
- # all backend-agnostic implementations returned in this module will still
- # use the global ivy backend.
+
+ """ # noqa
+ # ToDo: change this so that it doesn't depend at all on the global ivy.
+ # Currently all backend-agnostic implementations returned in this
+ # module will still use the global ivy backend.
global ivy_original_dict
if not backend_stack:
ivy_original_dict = ivy.__dict__.copy()
@@ -378,17 +511,18 @@ def unset_backend():
>>> x = ivy.native_array([1])
>>> print(type(x))
- """
+ """ # noqa
backend = None
- # if the backend stack is empty, nothing is done and we just return `None`
+ # if the backend stack is empty, nothing is done then we just return `None`
if backend_stack:
backend = backend_stack.pop(-1) # remove last backend from the stack
if backend.current_backend_str() == "numpy":
ivy.unset_default_device()
elif backend.current_backend_str() == "jax":
ivy.del_global_attr("RNG")
- # the new backend is the backend that was set before the one we just removed
- # from the stack, or Ivy if there was no previously set backend
+ # the new backend is the backend that was set before the one
+ # we just removed from the stack, or Ivy if there was no
+ # previously set backend
if backend_stack:
new_backend = backend_stack[-1]
if new_backend.current_backend_str() == "numpy":
diff --git a/ivy/container/base.py b/ivy/container/base.py
index c077701738999..203759e0d30cc 100644
--- a/ivy/container/base.py
+++ b/ivy/container/base.py
@@ -48,6 +48,8 @@ def _repr(x):
# noinspection PyMissingConstructor
+
+
class ContainerBase(dict, abc.ABC):
def __init__(
self,
@@ -66,6 +68,7 @@ def __init__(
rebuild_child_containers=False,
types_to_iteratively_nest=None,
alphabetical_keys=True,
+ dynamic_backend=None,
**kwargs,
):
"""Initialize container object from input dict representation.
@@ -130,6 +133,10 @@ def __init__(
self._loaded_containers_from_queues = dict()
self._queue_load_sizes_cum = np.cumsum(queue_load_sizes)
self._queue_timeout = ivy.default(queue_timeout, ivy.get_queue_timeout())
+ if dynamic_backend is not None:
+ self._dynamic_backend = dynamic_backend
+ else:
+ self._dynamic_backend = ivy.get_dynamic_backend()
if dict_in is None:
if kwargs:
dict_in = dict(**kwargs)
@@ -157,7 +164,6 @@ def __init__(
# Class Methods #
# --------------#
-
@staticmethod
def cont_multi_map_in_function(
fn,
@@ -1586,7 +1592,7 @@ def cont_inplace_update(
**config
"""
- # update config
+ # update config
self.cont_update_config(**config)
# update container values inplace
@@ -4025,6 +4031,58 @@ def __setitem__(self, query, val):
New container after updating.
"""
+
+ def _map_fn(fn, x):
+ x = x.data if isinstance(x, ivy.Array) else x
+ return fn(x)
+
+ if query == "_backend":
+ self._backend = val
+ return
+
+ if query == "dynamic_backend":
+ from ivy.functional.ivy.gradients import _variable
+ from ivy.backend_handler import _determine_backend_from_args
+
+ if not val:
+ self._backend = _determine_backend_from_args(self)
+ else:
+ is_variable = self._backend.is_variable
+ to_numpy = self._backend.to_numpy
+ variable_data = self._backend.variable_data
+
+ def _is_var(x):
+ x = x.data if isinstance(x, ivy.Array) else x
+ return is_variable(x)
+
+ is_var = self.cont_map(lambda x, kc: _is_var(x)).cont_all_true()
+ if is_var and not (
+ str(self._backend).__contains__("jax")
+ or str(self._backend).__contains__("numpy")
+ ):
+ self.cont_map(lambda x, kc: _map_fn(variable_data, x), inplace=True)
+ self.cont_map(lambda x, kc: _map_fn(to_numpy, x), inplace=True)
+ self.cont_map(lambda x, kc: _map_fn(ivy.array, x), inplace=True)
+ self.cont_map(lambda x, kc: _map_fn(_variable, x), inplace=True)
+
+ else:
+ self.cont_map(lambda x, kc: _map_fn(to_numpy, x), inplace=True)
+ self.cont_map(lambda x, kc: _map_fn(ivy.array, x), inplace=True)
+
+ def _set_dyn_backend(obj, val):
+ if isinstance(obj, ivy.Array):
+ obj._dynamic_backend = val
+ return
+
+ if isinstance(obj, ivy.Container):
+ for item in obj.values():
+ _set_dyn_backend(item, val)
+
+ obj._dynamic_backend = val
+
+ _set_dyn_backend(self, val)
+ return
+
if isinstance(query, str) and ("/" in query or "." in query):
return self.cont_set_at_key_chain(query, val, inplace=True)
else:
@@ -4132,3 +4190,11 @@ def cont_max_depth(self):
if not kcs:
return 0
return max([len(kc.split("/")) for kc in kcs])
+
+ @property
+ def dynamic_backend(self):
+ return self._dynamic_backend
+
+ @dynamic_backend.setter
+ def dynamic_backend(self, value):
+ self._dynamic_backend = value
diff --git a/ivy/container/container.py b/ivy/container/container.py
index ef6cab7c39503..ff143532771bc 100644
--- a/ivy/container/container.py
+++ b/ivy/container/container.py
@@ -108,6 +108,7 @@ def __init__(
rebuild_child_containers=False,
types_to_iteratively_nest=None,
alphabetical_keys=True,
+ dynamic_backend=None,
**kwargs
):
ContainerBase.__init__(
@@ -127,6 +128,7 @@ def __init__(
rebuild_child_containers,
types_to_iteratively_nest,
alphabetical_keys,
+ dynamic_backend,
**kwargs
)
diff --git a/ivy/container/experimental/activations.py b/ivy/container/experimental/activations.py
index d4103f4546765..d6f7ce2230431 100644
--- a/ivy/container/experimental/activations.py
+++ b/ivy/container/experimental/activations.py
@@ -249,3 +249,77 @@ def thresholded_relu(
map_sequences=map_sequences,
out=out,
)
+
+ @staticmethod
+ def static_prelu(
+ x: Union[ivy.NativeArray, ivy.Array, ivy.Container],
+ slope: Union[float, ivy.NativeArray, ivy.Array, ivy.Container],
+ /,
+ *,
+ key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
+ to_apply: bool = True,
+ prune_unapplied: bool = False,
+ map_sequences: bool = False,
+ out: Optional["ivy.Array"] = None,
+ ) -> ivy.Container:
+ """
+
+ Parameters
+ ----------
+ x
+ slope
+ key_chains
+ to_apply
+ prune_unapplied
+ map_sequences
+ out
+
+ Returns
+ -------
+
+ """
+ return ContainerBase.cont_multi_map_in_function(
+ "prelu",
+ x,
+ slope,
+ key_chains=key_chains,
+ to_apply=to_apply,
+ prune_unapplied=prune_unapplied,
+ map_sequences=map_sequences,
+ out=out,
+ )
+
+ def prelu(
+ self: ivy.Container,
+ slope: Union[float, ivy.NativeArray, ivy.Array, ivy.Container],
+ *,
+ key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
+ to_apply: bool = True,
+ prune_unapplied: bool = False,
+ map_sequences: bool = False,
+ out: Optional[ivy.Container] = None,
+ ) -> ivy.Container:
+ """
+
+ Parameters
+ ----------
+ slope
+ key_chains
+ to_apply
+ prune_unapplied
+ map_sequences
+ out
+
+ Returns
+ -------
+
+ """
+ return self.static_prelu(
+ self,
+ slope,
+ key_chains=key_chains,
+ to_apply=to_apply,
+ prune_unapplied=prune_unapplied,
+ map_sequences=map_sequences,
+ out=out,
+ )
diff --git a/ivy/container/experimental/elementwise.py b/ivy/container/experimental/elementwise.py
index 12376fe5f6221..5a0824fa8a467 100644
--- a/ivy/container/experimental/elementwise.py
+++ b/ivy/container/experimental/elementwise.py
@@ -2232,9 +2232,13 @@ def allclose(
@staticmethod
def static_diff(
- x: Union[ivy.Array, ivy.NativeArray, ivy.Container, int, list, tuple],
+ x: Union[ivy.Container, ivy.Array, ivy.NativeArray],
/,
*,
+ n: int = 1,
+ axis: int = -1,
+ prepend: Optional[Union[ivy.Array, ivy.NativeArray, int, list, tuple]] = None,
+ append: Optional[Union[ivy.Array, ivy.NativeArray, int, list, tuple]] = None,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
@@ -2250,6 +2254,27 @@ def static_diff(
----------
x
input container with array-like items.
+ n
+ The number of times values are differenced. If zero, the input is returned
+ as-is.
+ axis
+ The axis along which the difference is taken, default is the last axis.
+ prepend,append
+ Values to prepend/append to x along given axis prior to performing the
+ difference. Scalar values are expanded to arrays with length 1 in the
+ direction of axis and the shape of the input array along all other
+ axes. Otherwise the dimension and shape must match x except along axis.
+ key_chains
+ The key-chains to apply or not apply the method to. Default is ``None``.
+ to_apply
+ If True, the method will be applied to key_chains, otherwise key_chains
+ will be skipped. Default is ``True``.
+ prune_unapplied
+ Whether to prune key_chains for which the function was not applied.
+ Default is ``False``.
+ map_sequences
+ Whether to also map method to sequences (lists, tuples).
+ Default is ``False``.
out
optional output container, for writing the result to.
@@ -2261,17 +2286,21 @@ def static_diff(
Examples
--------
- >>> x = ivy.Container(a=ivy.array([1, 2, 4, 7, 0]),\
- b=ivy.array([1, 2, 4, 7, 0]))
+ >>> x = ivy.Container(a=ivy.array([1, 2, 4, 7, 0]),
+ b=ivy.array([1, 2, 4, 7, 0]))
>>> ivy.Container.static_diff(x)
{
- a: ivy.array([ 1, 2, 3, -7])
+ a: ivy.array([ 1, 2, 3, -7]),
b: ivy.array([ 1, 2, 3, -7])
}
"""
return ContainerBase.cont_multi_map_in_function(
"diff",
x,
+ n=n,
+ axis=axis,
+ prepend=prepend,
+ append=append,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
@@ -2283,6 +2312,10 @@ def diff(
self: ivy.Container,
/,
*,
+ n: int = 1,
+ axis: int = -1,
+ prepend: Optional[Union[ivy.Array, ivy.NativeArray, int, list, tuple]] = None,
+ append: Optional[Union[ivy.Array, ivy.NativeArray, int, list, tuple]] = None,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""ivy.Container instance method variant of ivy.diff. This method simply
@@ -2293,6 +2326,16 @@ def diff(
----------
self
input container with array-like items.
+ n
+ The number of times values are differenced. If zero, the input is returned
+ as-is.
+ axis
+ The axis along which the difference is taken, default is the last axis.
+ prepend,append
+ Values to prepend/append to x along given axis prior to performing the
+ difference. Scalar values are expanded to arrays with length 1 in the
+ direction of axis and the shape of the input array along all other
+ axes. Otherwise the dimension and shape must match x except along axis.
out
optional output container, for writing the result to.
@@ -2304,15 +2347,17 @@ def diff(
Examples
--------
- >>> x = ivy.Container(a=ivy.array([1, 2, 4, 7, 0]),\
- b=ivy.array([1, 2, 4, 7, 0]))
- >>> ivy.Container.static_diff(x)
+ >>> x = ivy.Container(a=ivy.array([1, 2, 4, 7, 0]),
+ b=ivy.array([1, 2, 4, 7, 0]))
+ >>> x.diff()
{
- a: ivy.array([ 1, 2, 3, -7])
- b: ivy.array([ 1, 2, 3, -7])
+ a: ivy.array([1, 2, 3, -7]),
+ b: ivy.array([1, 2, 3, -7])
}
"""
- return self.static_diff(self, out=out)
+ return self.static_diff(
+ self, n=n, axis=axis, prepend=prepend, append=append, out=out
+ )
@staticmethod
def static_fix(
diff --git a/ivy/container/experimental/layers.py b/ivy/container/experimental/layers.py
index d3f141c021dd5..326b3fa965a7b 100644
--- a/ivy/container/experimental/layers.py
+++ b/ivy/container/experimental/layers.py
@@ -1,5 +1,5 @@
# global
-from typing import Optional, Union, List, Dict, Tuple, Literal
+from typing import Optional, Union, List, Dict, Tuple, Literal, Sequence
# local
import ivy
@@ -150,6 +150,8 @@ def static_max_pool2d(
/,
*,
data_format: str = "NHWC",
+ dilation: Union[int, Tuple[int], Tuple[int, int]] = 1,
+ ceil_mode: bool = False,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
@@ -200,6 +202,8 @@ def static_max_pool2d(
strides,
padding,
data_format=data_format,
+ dilation=dilation,
+ ceil_mode=ceil_mode,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
@@ -215,6 +219,8 @@ def max_pool2d(
/,
*,
data_format: str = "NHWC",
+ dilation: Union[int, Tuple[int], Tuple[int, int]] = 1,
+ ceil_mode: bool = False,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
@@ -266,6 +272,8 @@ def max_pool2d(
strides,
padding,
data_format=data_format,
+ dilation=dilation,
+ ceil_mode=ceil_mode,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
@@ -1351,3 +1359,129 @@ def dft(
map_sequences=map_sequences,
out=out,
)
+
+ @staticmethod
+ def static_interpolate(
+ x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
+ size: Union[Sequence[int], int],
+ /,
+ *,
+ mode: Union[Literal["linear", "bilinear", "trilinear", "nearest"]] = "linear",
+ align_corners: Optional[bool] = None,
+ antialias: Optional[bool] = False,
+ key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
+ to_apply: bool = True,
+ prune_unapplied: bool = False,
+ map_sequences: bool = False,
+ out: Optional[ivy.Container] = None,
+ ) -> ivy.Container:
+ """
+ Down/up samples the input to the given size.
+ The algorithm used for interpolation is determined by mode.
+
+ Parameters
+ ----------
+ x
+ Input array. Must have the shape
+ [batch x channels x [optional depth] x [optional height] x width].
+ size
+ Output size.
+ mode
+ Interpolation mode. Can be one of the following:
+ - linear
+ - bilinear
+ - trilinear
+ - nearest
+ align_corners
+ If True, the corner pixels of the input and output tensors are aligned,
+ and thus preserving the values at the corner pixels. If False, the corner
+ pixels are not aligned, and the interpolation uses edge value padding for
+ out-of-boundary values.
+ only has an effect when mode is 'linear', 'bilinear',
+ 'bicubic' or 'trilinear'. Default: None
+ antialias
+ If True, antialiasing is applied when downsampling an image.
+ Supported modes: 'bilinear', 'bicubic'.
+ out
+ Optional output array, for writing the result to. It must
+ have a shape that the inputs broadcast to.
+
+ Returns
+ -------
+ resized array
+ """
+ return ContainerBase.cont_multi_map_in_function(
+ "interpolate",
+ x,
+ size,
+ mode=mode,
+ align_corners=align_corners,
+ antialias=antialias,
+ key_chains=key_chains,
+ to_apply=to_apply,
+ prune_unapplied=prune_unapplied,
+ map_sequences=map_sequences,
+ out=out,
+ )
+
+ def interpolate(
+ self: ivy.Container,
+ size: Union[Sequence[int], int],
+ /,
+ *,
+ mode: Union[Literal["linear", "bilinear", "trilinear", "nearest"]] = "linear",
+ align_corners: Optional[bool] = None,
+ antialias: Optional[bool] = False,
+ key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
+ to_apply: bool = True,
+ prune_unapplied: bool = False,
+ map_sequences: bool = False,
+ out: Optional[ivy.Container] = None,
+ ) -> ivy.Container:
+ """
+ Down/up samples the input to the given size.
+ The algorithm used for interpolation is determined by mode.
+
+ Parameters
+ ----------
+ self
+ Input array. Must have the shape
+ [batch x channels x [optional depth] x [optional height] x width].
+ size
+ Output size.
+ mode
+ Interpolation mode. Can be one of the following:
+ - linear
+ - bilinear
+ - trilinear
+ - nearest
+ align_corners
+ If True, the corner pixels of the input and output tensors are aligned,
+ and thus preserving the values at the corner pixels. If False, the corner
+ pixels are not aligned, and the interpolation uses edge value padding for
+ out-of-boundary values.
+ only has an effect when mode is 'linear', 'bilinear',
+ 'bicubic' or 'trilinear'. Default: None
+ antialias
+ If True, antialiasing is applied when downsampling an image.
+ Supported modes: 'bilinear', 'bicubic'.
+ out
+ Optional output array, for writing the result to. It must
+ have a shape that the inputs broadcast to.
+
+ Returns
+ -------
+ resized array
+ """
+ return self.static_interpolate(
+ self,
+ size,
+ mode=mode,
+ align_corners=align_corners,
+ antialias=antialias,
+ key_chains=key_chains,
+ to_apply=to_apply,
+ prune_unapplied=prune_unapplied,
+ map_sequences=map_sequences,
+ out=out,
+ )
diff --git a/ivy/container/experimental/manipulation.py b/ivy/container/experimental/manipulation.py
index 7c7e5295264b9..7bbe3df584005 100644
--- a/ivy/container/experimental/manipulation.py
+++ b/ivy/container/experimental/manipulation.py
@@ -2046,15 +2046,14 @@ def take_along_axis(
@staticmethod
def static_hsplit(
ary: Union[ivy.Array, ivy.NativeArray, ivy.Container],
- indices_or_sections: Union[int, Tuple[int]],
+ indices_or_sections: Union[int, Tuple[int, ...]],
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
- out: Optional[ivy.Container] = None,
- ) -> ivy.Container:
+ ) -> List[ivy.Container]:
"""
ivy.Container static method variant of ivy.hsplit. This method simply wraps
the function, and so the docstring for ivy.hsplit also applies to this method
@@ -2065,54 +2064,53 @@ def static_hsplit(
ary
the container with array inputs.
indices_or_sections
- If indices_or_sections is an integer n, the array is split into n sections.
- If the array is divisible by n horizontally, each section will be of equal
- size. If input is not divisible by n, the sizes of the first
- int(ary.size(0) % n) sections will have size int(ary.size(0) / n) + 1, and
- the rest will have size int(ary.size(0) / n).
+ If indices_or_sections is an integer n, the array is split into n
+ equal sections, provided that n must be a divisor of the split axis.
If indices_or_sections is a tuple of ints, then input is split at each of
the indices in the tuple.
- out
- optional output container, for writing the result to.
+ key_chains
+ The keychains to apply or not apply the method to. Default is ``None``.
+ to_apply
+ If True, the method will be applied to key_chains, otherwise key_chains
+ will be skipped. Default is ``True``.
+ prune_unapplied
+ Whether to prune key_chains for which the function was not applied.
+ Default is ``False``.
+ map_sequences
+ Whether to also map method to sequences (lists, tuples).
+ Default is ``False``.
Returns
-------
ret
- container including input arrays split horizontally.
+ list of containers split horizontally from input array.
Examples
--------
>>> ary = ivy.Container(
- a = ivy.ivy.array(
+ a = ivy.array(
[[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]]
),
b=ivy.array(
- [[ 0., 1., 2., 3.],
- [ 4., 5., 6., 7.],
- [ 8., 9., 10., 11.],
- [12., 13., 14., 15.]])
+ [0., 1., 2., 3.,
+ 4., 5., 6., 7.,
+ 8., 9., 10., 11.,
+ 12., 13., 14., 15.]
)
)
>>> ivy.Container.static_hsplit(ary, 2)
- {
- a: ivy.ivy.array(
- [[[0., 1.],
- [2., 3.]],
- [[4., 5.],
- [6., 7.]]]
- ),
- b: [ivy.array([[ 0., 1.],
- [ 4., 5.],
- [ 8., 9.],
- [12., 13.]]),
- ivy.array([[ 2., 3.],
- [ 6., 7.],
- [10., 11.],
- [14., 15.]])
- }
+ [{
+ a: ivy.array([[[0., 1.]],
+ [[4., 5.]]]),
+ b: ivy.array([0., 1., 2., 3., 4., 5., 6., 7.])
+ }, {
+ a: ivy.array([[[2., 3.]],
+ [[6., 7.]]]),
+ b: ivy.array([8., 9., 10., 11., 12., 13., 14., 15.])
+ }]
"""
return ContainerBase.cont_multi_map_in_function(
"hsplit",
@@ -2122,16 +2120,13 @@ def static_hsplit(
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
- out=out,
)
def hsplit(
self: ivy.Container,
- indices_or_sections: Union[int, Tuple[int]],
+ indices_or_sections: Union[int, Tuple[int, ...]],
/,
- *,
- out: Optional[ivy.Container] = None,
- ) -> ivy.Container:
+ ) -> List[ivy.Container]:
"""ivy.Container instance method variant of ivy.hsplit. This method simply
wraps the function, and so the docstring for ivy.hsplit also applies to this
method with minimal changes.
@@ -2141,57 +2136,44 @@ def hsplit(
self
the container with array inputs.
indices_or_sections
- If indices_or_sections is an integer n, the array is split into n sections.
- If the array is divisible by n horizontally, each section will be of equal
- size. If input is not divisible by n, the sizes of the first
- int(ary.size(0) % n) sections will have size int(ary.size(0) / n) + 1, and
- the rest will have size int(ary.size(0) / n).
+ If indices_or_sections is an integer n, the array is split into n
+ equal sections, provided that n must be a divisor of the split axis.
If indices_or_sections is a tuple of ints, then input is split at each of
the indices in the tuple.
- out
- optional output container, for writing the result to.
Returns
-------
ret
- container including arrays with the modified Bessel
- function evaluated at each of the elements of x.
+ list of containers split horizontally from input container
Examples
--------
>>> ary = ivy.Container(
- a = ivy.ivy.array(
+ a = ivy.array(
[[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]]
),
b=ivy.array(
- [[ 0., 1., 2., 3.],
- [ 4., 5., 6., 7.],
- [ 8., 9., 10., 11.],
- [12., 13., 14., 15.]])
+ [0., 1., 2., 3.,
+ 4., 5., 6., 7.,
+ 8., 9., 10., 11.,
+ 12., 13., 14., 15.]
)
)
>>> ary.hsplit(2)
- {
- a: ivy.ivy.array(
- [[[0., 1.],
- [2., 3.]],
- [[4., 5.],
- [6., 7.]]]
- ),
- b: [ivy.array([[ 0., 1.],
- [ 4., 5.],
- [ 8., 9.],
- [12., 13.]]),
- ivy.array([[ 2., 3.],
- [ 6., 7.],
- [10., 11.],
- [14., 15.]])
- }
+ [{
+ a: ivy.array([[[0., 1.]],
+ [[4., 5.]]]),
+ b: ivy.array([0., 1., 2., 3., 4., 5., 6., 7.])
+ }, {
+ a: ivy.array([[[2., 3.]],
+ [[6., 7.]]]),
+ b: ivy.array([8., 9., 10., 11., 12., 13., 14., 15.])
+ }]
"""
- return self.static_hsplit(self, indices_or_sections, out=out)
+ return self.static_hsplit(self, indices_or_sections)
@staticmethod
def static_broadcast_shapes(
diff --git a/ivy/container/experimental/norms.py b/ivy/container/experimental/norms.py
index f92704c33c8cd..0b17dda667b67 100644
--- a/ivy/container/experimental/norms.py
+++ b/ivy/container/experimental/norms.py
@@ -398,3 +398,138 @@ def instance_norm(
track_running_stats=track_running_stats,
out=out,
)
+
+ @staticmethod
+ def static_lp_normalize(
+ x: Union[ivy.Container, ivy.Array, ivy.NativeArray],
+ p: float = 2,
+ axis: int = None,
+ key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
+ to_apply: bool = True,
+ prune_unapplied: bool = False,
+ map_sequences: bool = False,
+ out=None,
+ ) -> ivy.Container:
+ """
+ ivy.Container static method variant of ivy.lp_normalize.
+ This method simply wraps the function, and so the
+ docstring for ivy.lp_normalize also applies to this method
+ with minimal changes.
+
+ Parameters
+ ----------
+ x
+ The input container with leaves to be normalized.
+ p
+ The order of the norm.
+ axis
+ The axis along which to normalize.
+ key_chains
+ The key-chains to apply or not apply the method to. Default is ``None``.
+ to_apply
+ If True, the method will be applied to key_chains, otherwise key_chains
+ will be skipped. Default is ``True``.
+ prune_unapplied
+ Whether to prune key_chains for which the function was not applied.
+ Default is ``False``.
+ map_sequences
+ Whether to also map method to sequences (lists, tuples).
+ Default is ``False``.
+ out
+ optional output container, for writing the result to. It must have a shape
+ that the inputs broadcast to.
+
+ Returns
+ -------
+ ret
+ a container containing the normalized leaves.
+
+ Examples
+ --------
+        >>> x = ivy.Container(a=ivy.array([[0.5, 1.5, 2.5], [3.5, 4.5, 5.5]]),
+        ...                   b=ivy.array([[-1., -1.], [-1., -0.5]]))
+ >>> y = ivy.Container.static_lp_normalize(x, p=1, axis=1)
+ >>> print(y)
+ {
+            a: ivy.array([[0.11111111, 0.33333333, 0.55555556],
+                          [0.25925926, 0.33333333, 0.40740741]]),
+            b: ivy.array([[-0.5, -0.5],
+                          [-0.66666667, -0.33333333]])
+ }
+ """
+ return ContainerBase.cont_multi_map_in_function(
+ "lp_normalize",
+ x,
+ p=p,
+ axis=axis,
+ key_chains=key_chains,
+ to_apply=to_apply,
+ prune_unapplied=prune_unapplied,
+ map_sequences=map_sequences,
+ out=out,
+ )
+
+ def lp_normalize(
+ self,
+ p: float = 2,
+ axis: int = None,
+ key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
+ to_apply: bool = True,
+ prune_unapplied: bool = False,
+ map_sequences: bool = False,
+ out=None,
+ ) -> ivy.Container:
+        """ivy.Container instance method variant of ivy.lp_normalize.
+        This method simply wraps the function, and so the
+        docstring for ivy.lp_normalize also applies to this method
+ with minimal changes.
+
+ Parameters
+ ----------
+ self
+ The input container with leaves to be normalized.
+ axis
+ The axis along which to normalize.
+ key_chains
+ The key-chains to apply or not apply the method to. Default is ``None``.
+ to_apply
+ If True, the method will be applied to key_chains, otherwise key_chains
+ will be skipped. Default is ``True``.
+ prune_unapplied
+ Whether to prune key_chains for which the function was not applied.
+ Default is ``False``.
+ map_sequences
+ Whether to also map method to sequences (lists, tuples).
+ Default is ``False``.
+ out
+ optional output container, for writing the result to. It must have a shape
+ that the inputs broadcast to.
+
+ Returns
+ -------
+ ret
+ a container containing the normalized leaves.
+
+ Examples
+ --------
+        >>> x = ivy.Container(a=ivy.array([[0.5, 1.5, 2.5], [3.5, 4.5, 5.5]]),
+        ...                   b=ivy.array([[-1., -1.], [-1., -0.5]]))
+        >>> y = x.lp_normalize(axis=1)
+ >>> print(y)
+ {
+ a: ivy.array([[0.16903085, 0.50709254, 0.84515423],
+ [0.44183609, 0.56807494, 0.69431382]]),
+ b: ivy.array([[-0.70710677, -0.70710677],
+ [-0.89442718, -0.44721359]])
+ }
+ """
+ return self.static_lp_normalize(
+ self,
+ p=p,
+ axis=axis,
+ key_chains=key_chains,
+ to_apply=to_apply,
+ prune_unapplied=prune_unapplied,
+ map_sequences=map_sequences,
+ out=out,
+ )
diff --git a/ivy/container/experimental/statistical.py b/ivy/container/experimental/statistical.py
index bcd52db471007..6953e80c03207 100644
--- a/ivy/container/experimental/statistical.py
+++ b/ivy/container/experimental/statistical.py
@@ -668,3 +668,115 @@ def corrcoef(
}
"""
return self.static_corrcoef(self, y=y, rowvar=rowvar, out=out)
+
+ @staticmethod
+ def static_nanmedian(
+ input: ivy.Container,
+ /,
+ *,
+ axis: Optional[Union[Tuple[int], int]] = None,
+ keepdims: Optional[bool] = False,
+ key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
+ to_apply: bool = True,
+ prune_unapplied: bool = False,
+ map_sequences: bool = False,
+ out: Optional[ivy.Array] = None,
+ ) -> ivy.Container:
+ """
+        ivy.Container static method variant of ivy.nanmedian. This method simply wraps
+        the function, and so the docstring for ivy.nanmedian also applies to this method
+ with minimal changes.
+
+ Parameters
+ ----------
+ input
+ Input container including arrays.
+ axis
+ Axis or axes along which the medians are computed. The default is to compute
+ the median along a flattened version of the array.
+ keepdims
+ If this is set to True, the axes which are reduced are left in the result
+ as dimensions with size one.
+ out
+ optional output array, for writing the result to.
+
+ Returns
+ -------
+ ret
+ The median of the array elements.
+
+ Examples
+ --------
+ With one :class:`ivy.Container` input:
+ >>> x = ivy.Container(a=ivy.zeros((3, 4, 5)), b=ivy.zeros((2,7,6)))
+        >>> ivy.Container.static_nanmedian(x, axis=(0, -1)).shape
+        {
+            a: (4,)
+            b: (7,)
+        }
+ """
+ return ContainerBase.cont_multi_map_in_function(
+ "nanmedian",
+ input,
+ axis=axis,
+ keepdims=keepdims,
+ key_chains=key_chains,
+ to_apply=to_apply,
+ prune_unapplied=prune_unapplied,
+ map_sequences=map_sequences,
+ out=out,
+ )
+
+ def nanmedian(
+ self: ivy.Container,
+ /,
+ *,
+ axis: Optional[Union[Tuple[int], int]] = None,
+ keepdims: Optional[bool] = False,
+ overwrite_input: Optional[bool] = False,
+ out: Optional[ivy.Container] = None,
+ ) -> ivy.Container:
+        """ivy.Container instance method variant of ivy.nanmedian. This method simply
+ wraps the function, and so the docstring for ivy.nanmedian also applies to
+ this method with minimal changes.
+
+ Parameters
+ ----------
+ self
+ Input array.
+ axis
+ Axis or axes along which the means are computed.
+ The default is to compute the mean of the flattened array.
+ keepdims
+ If this is set to True, the axes which are reduced are left in the result
+ as dimensions with size one. With this option, the result will broadcast
+ correctly against the original a. If the value is anything but the default,
+ then keepdims will be passed through to the mean or sum methods of
+ sub-classes of ndarray. If the sub-classes methods does not implement
+ keepdims any exceptions will be raised.
+ overwrite_input
+ If True, then allow use of memory of input array a for calculations.
+ The input array will be modified by the call to median. This will
+ save memory when you do not need to preserve the contents of the input array.
+ Treat the input as undefined, but it will probably be fully or partially sorted.
+ Default is False. If overwrite_input is True and a is not already an ndarray,
+ an error will be raised.
+ out
+ optional output array, for writing the result to.
+
+ Returns
+ -------
+ ret
+            A new container holding the median of the non-NaN elements of each array.
+
+ Examples
+        >>> a = ivy.Container(x=ivy.array([[10.0, ivy.nan, 4.], [3., 2., 1.]]))
+        >>> a.nanmedian()
+        { x: ivy.array(3.) }
+        >>> a.nanmedian(axis=0)
+        { x: ivy.array([6.5, 2., 2.5]) }
+ """
+
+        return self.static_nanmedian(
+            self, axis=axis, keepdims=keepdims, out=out
+        )  # NOTE(review): static_nanmedian has no overwrite_input parameter, so it is not forwarded
diff --git a/ivy/container/layers.py b/ivy/container/layers.py
index 27405341bd50b..10eff40781b6f 100644
--- a/ivy/container/layers.py
+++ b/ivy/container/layers.py
@@ -979,12 +979,12 @@ def conv1d(
def static_conv2d(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
filters: Union[ivy.Array, ivy.NativeArray, ivy.Container],
- strides: Union[int, Tuple[int], Tuple[int, int]],
+ strides: Union[int, Tuple[int, int]],
padding: str,
/,
*,
data_format: str = "NHWC",
- dilations: Optional[Union[int, Tuple[int], Tuple[int, int]]] = 1,
+ dilations: Union[int, Tuple[int, int]] = 1,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
@@ -1011,6 +1011,17 @@ def static_conv2d(
"NHWC" or "NCHW". Defaults to "NHWC".
dilations
The dilation factor for each dimension of input. (Default value = 1)
+ key_chains
+ The key-chains to apply or not apply the method to. Default is ``None``.
+ to_apply
+ If True, the method will be applied to key_chains, otherwise key_chains
+ will be skipped. Default is ``True``.
+ prune_unapplied
+ Whether to prune key_chains for which the function was not applied.
+ Default is ``False``.
+ map_sequences
+ Whether to also map method to sequences (lists, tuples).
+ Default is ``False``.
out
optional output array, for writing the result to. It must have a shape
that the inputs broadcast to.
@@ -1052,12 +1063,12 @@ def static_conv2d(
def conv2d(
self: ivy.Container,
filters: Union[ivy.Array, ivy.NativeArray, ivy.Container],
- strides: Union[int, Tuple[int], Tuple[int, int]],
+ strides: Union[int, Tuple[int, int]],
padding: str,
/,
*,
data_format: str = "NHWC",
- dilations: Optional[Union[int, Tuple[int], Tuple[int, int]]] = 1,
+ dilations: Union[int, Tuple[int, int]] = 1,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
@@ -1071,7 +1082,7 @@ def conv2d(
Parameters
----------
- x
+ self
Input image *[batch_size,h,w,d_in]*.
filters
Convolution filters *[fh,fw,d_in,d_out]*.
@@ -1084,6 +1095,17 @@ def conv2d(
"NHWC" or "NCHW". Defaults to "NHWC".
dilations
The dilation factor for each dimension of input. (Default value = 1)
+ key_chains
+ The key-chains to apply or not apply the method to. Default is ``None``.
+ to_apply
+ If True, the method will be applied to key_chains, otherwise key_chains
+ will be skipped. Default is ``True``.
+ prune_unapplied
+ Whether to prune key_chains for which the function was not applied.
+ Default is ``False``.
+ map_sequences
+ Whether to also map method to sequences (lists, tuples).
+ Default is ``False``.
out
optional output array, for writing the result to. It must have a shape
that the inputs broadcast to.
diff --git a/ivy/container/linear_algebra.py b/ivy/container/linear_algebra.py
index c359bb14616e1..862af6372ddda 100644
--- a/ivy/container/linear_algebra.py
+++ b/ivy/container/linear_algebra.py
@@ -21,6 +21,8 @@ def static_matmul(
*,
transpose_a: bool = False,
transpose_b: bool = False,
+ adjoint_a: bool = False,
+ adjoint_b: bool = False,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
@@ -77,6 +79,8 @@ def static_matmul(
x2,
transpose_a=transpose_a,
transpose_b=transpose_b,
+ adjoint_a=adjoint_a,
+ adjoint_b=adjoint_b,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
@@ -91,6 +95,8 @@ def matmul(
*,
transpose_a: bool = False,
transpose_b: bool = False,
+ adjoint_a: bool = False,
+ adjoint_b: bool = False,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
@@ -146,6 +152,8 @@ def matmul(
x2,
transpose_a=transpose_a,
transpose_b=transpose_b,
+ adjoint_a=adjoint_a,
+ adjoint_b=adjoint_b,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
diff --git a/ivy/func_wrapper.py b/ivy/func_wrapper.py
index 7360742d44481..e8ac4dd7f23e8 100644
--- a/ivy/func_wrapper.py
+++ b/ivy/func_wrapper.py
@@ -8,6 +8,7 @@
# for wrapping (sequence matters)
FN_DECORATORS = [
+ "array_function_wrapper",
"infer_device",
"infer_dtype",
"integer_arrays_to_float",
@@ -28,6 +29,30 @@
# --------#
+def try_array_function_override(func, overloaded_args, types, args, kwargs):
+ if not overloaded_args:
+ return False, None
+
+ for overloaded_arg in overloaded_args:
+ # Note that we're only calling __array_function__ on the *first*
+        # occurrence of each argument type. This is necessary for reasonable
+ # performance with a possibly long list of overloaded arguments, for
+ # which each __array_function__ implementation might reasonably need to
+ # check all argument types.
+ try:
+ result = overloaded_arg.__array_function__(func, types, args, kwargs)
+ except Exception:
+ raise ivy.exceptions.IvyNotImplementedException
+
+ if result is not NotImplemented:
+ return True, result
+
+ raise TypeError(
+ "no implementation found for {} on types that implement "
+ "__array_function__: {}".format(func, list(map(type, overloaded_args)))
+ )
+
+
def _get_first_array(*args, **kwargs):
# ToDo: make this more efficient, with function ivy.nested_nth_index_where
arr = None
@@ -50,6 +75,65 @@ def _get_first_array(*args, **kwargs):
# ---------------#
+def handle_array_function(func):
+ """
+ Wrap a function to extract the relevant argument types to be passed to
+ array_function method.
+ """
+
+ @functools.wraps(func)
+ def new_func(*args, **kwargs):
+ overloaded_types = []
+ overloaded_args = []
+
+ for arg in args + tuple(kwargs.values()):
+ if ivy.exists(arg) and (
+ not isinstance(arg, ivy.Container)
+ and hasattr(arg, "__array_function__")
+ ):
+ if type(arg) not in overloaded_types:
+ overloaded_types.append(type(arg))
+ if (
+ arg.__array_function__ is not ivy.Array.__array_function__
+ and not isinstance(arg, (ivy.Array, ivy.NativeArray))
+ ):
+ index = len(overloaded_args)
+ for i, old_arg in enumerate(overloaded_args):
+ if issubclass(type(arg), type(old_arg)):
+ index = i
+ break
+ overloaded_args.insert(index, arg)
+ if ivy.exists(arg) and isinstance(arg, ivy.Container):
+ arg = ivy.Container.cont_flatten_key_chains(arg)
+ indices = ivy.nested_argwhere(
+ arg, lambda x: hasattr(x, "__array_function__")
+ )
+ for a in indices:
+ if type(getattr(arg, a[0])) not in overloaded_types:
+ overloaded_types.append(type(getattr(arg, a[0])))
+ if getattr(
+ arg, a[0]
+ ).__array_function__ is not ivy.Array.__array_function__ and not isinstance(
+ getattr(arg, a[0]), (ivy.Array, ivy.NativeArray)
+ ):
+ index = len(overloaded_args)
+ for i, old_arg in enumerate(overloaded_args):
+ if issubclass(type(getattr(arg, a[0])), type(old_arg)):
+ index = i
+ break
+ overloaded_args.insert(index, arg)
+
+ success, value = try_array_function_override(
+ ivy.__dict__[func.__name__], overloaded_args, overloaded_types, args, kwargs
+ )
+ if success:
+ return value
+ return func(*args, **kwargs)
+
+ new_func.array_function_wrapper = True
+ return new_func
+
+
def handle_array_like_without_promotion(fn: Callable) -> Callable:
@functools.wraps(fn)
def new_fn(*args, **kwargs):
diff --git a/ivy/functional/__init__.py b/ivy/functional/__init__.py
index 22620995859c9..01a9fd3d4b36e 100644
--- a/ivy/functional/__init__.py
+++ b/ivy/functional/__init__.py
@@ -2,4 +2,3 @@
from .ivy.experimental import *
from . import ivy
from .ivy import *
-from . import frontends
diff --git a/ivy/functional/backends/jax/experimental/elementwise.py b/ivy/functional/backends/jax/experimental/elementwise.py
index f2197cba3c3c1..2b1d7b1d65190 100644
--- a/ivy/functional/backends/jax/experimental/elementwise.py
+++ b/ivy/functional/backends/jax/experimental/elementwise.py
@@ -2,8 +2,14 @@
from typing import Optional, Union, Tuple, List
from numbers import Number
-from ivy import promote_types_of_inputs, default_float_dtype, is_float_dtype
+from ivy import (
+ promote_types_of_inputs,
+ default_float_dtype,
+ is_float_dtype,
+)
from ivy.functional.backends.jax import JaxArray
+from ivy.func_wrapper import with_unsupported_dtypes
+from . import backend_version
import jax.numpy as jnp
import jax.scipy as js
@@ -19,6 +25,7 @@ def sinc(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.sinc(x)
+@with_unsupported_dtypes({"0.3.14 and below": ("bfloat16",)}, backend_version)
def fmod(
x1: JaxArray,
x2: JaxArray,
@@ -70,7 +77,12 @@ def float_power(
*,
out: Optional[JaxArray] = None,
) -> JaxArray:
- return jnp.float_power(x1, x2)
+ x1, x2 = promote_types_of_inputs(x1, x2)
+ if jnp.any(jnp.iscomplex(x1)) or jnp.any(jnp.iscomplex(x2)):
+ out_dtype = jnp.complex128
+ else:
+ out_dtype = jnp.float64
+ return jnp.float_power(x1, x2).astype(out_dtype)
def exp2(
@@ -210,15 +222,20 @@ def allclose(
def diff(
- x: Union[JaxArray, int, float, list, tuple],
+ x: JaxArray,
/,
*,
- n: Optional[int] = 1,
- axis: Optional[int] = -1,
+ n: int = 1,
+ axis: int = -1,
prepend: Optional[Union[JaxArray, int, float, list, tuple]] = None,
append: Optional[Union[JaxArray, int, float, list, tuple]] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
+ x = jnp.asarray(x)
+ if isinstance(prepend, (list, tuple)):
+ prepend = jnp.asarray(prepend)
+ if isinstance(append, (list, tuple)):
+ append = jnp.asarray(append)
return jnp.diff(x, n=n, axis=axis, prepend=prepend, append=append)
@@ -274,11 +291,7 @@ def zeta(
temp = jnp.logical_and(jnp.not_equal(jnp.remainder(x, 2), 0), jnp.greater(x, 1))
temp = jnp.logical_and(temp, jnp.less_equal(q, 0))
nan_indices = jnp.logical_or(temp, jnp.less(x, 1))
- n, res = 1, 1 / q**x
- while n < 10000:
- term = 1 / (q + n) ** x
- n, res = n + 1, res + term
- ret = jnp.round(res, decimals=4)
+ ret = js.special.zeta(x, q)
ret = ret.at[nan_indices].set(jnp.nan)
ret = ret.at[inf_indices].set(jnp.inf)
return ret
@@ -509,5 +522,5 @@ def xlogy(x: JaxArray, y: JaxArray, /, *, out: Optional[JaxArray] = None) -> Jax
return js.special.xlogy(x, y)
-def real(x: Union[JaxArray], /, *, out: Optional[JaxArray] = None) -> JaxArray:
+def real(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jnp.real(x)
diff --git a/ivy/functional/backends/jax/experimental/layers.py b/ivy/functional/backends/jax/experimental/layers.py
index c20c5c1847df9..b8386ccb10b43 100644
--- a/ivy/functional/backends/jax/experimental/layers.py
+++ b/ivy/functional/backends/jax/experimental/layers.py
@@ -1,5 +1,5 @@
# global
-from typing import Optional, Union, Tuple, Literal
+from typing import Optional, Union, Tuple, Literal, Sequence
import jax
import jax.lax as jlax
import jax.numpy as jnp
@@ -12,12 +12,29 @@
from ivy.functional.ivy.layers import _handle_padding
-def general_pool(inputs, init, reduce_fn, window_shape, strides, padding):
+def _from_int_to_tuple(arg, dim):
+ if isinstance(arg, int):
+ return (arg,) * dim
+ if isinstance(arg, tuple) and len(arg) == 1:
+ return (arg[0],) * dim
+ return arg
- if isinstance(strides, int):
- strides = (strides,) * len(window_shape)
- elif len(strides) == 1:
- strides = (strides[0],) * len(window_shape)
+
+def general_pool(
+ inputs, init, reduce_fn, window_shape, strides, padding, dim, dilation, ceil_mode
+):
+ window_shape = _from_int_to_tuple(window_shape, dim)
+ strides = _from_int_to_tuple(strides, dim)
+ dilation = _from_int_to_tuple(dilation, dim)
+ if isinstance(padding, int):
+ padding = [(padding,) * 2] * dim
+ elif isinstance(padding, tuple) and len(padding) == 1:
+ padding = [(padding[0],) * 2] * dim
+ elif isinstance(padding, tuple) and len(padding) == 2:
+ padding = [(padding[0],) * 2, (padding[1],) * 2]
+
+ if isinstance(padding, (tuple, list)):
+ ivy.assertions.check_kernel_padding_size(window_shape, padding)
assert len(window_shape) == len(
strides
@@ -26,6 +43,7 @@ def general_pool(inputs, init, reduce_fn, window_shape, strides, padding):
window_shape = tuple(window_shape)
strides = (1,) + strides + (1,)
dims = (1,) + window_shape + (1,)
+ dilation = (1,) + tuple(dilation) + (1,)
is_single_input = False
if inputs.ndim == len(dims) - 1:
@@ -36,11 +54,18 @@ def general_pool(inputs, init, reduce_fn, window_shape, strides, padding):
assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})"
- # doing manual padding instead of
+ # shape of window after dilation
+ new_window_shape = tuple(
+ [
+ window_shape[i - 1] + (dilation[i] - 1) * (window_shape[i - 1] - 1)
+ for i in range(1, len(dims) - 1)
+ ]
+ )
+ # manual padding
if isinstance(padding, str):
pad_int = [
_handle_padding(
- inputs.shape[i + 1], strides[i + 1], window_shape[i], padding
+ inputs.shape[i + 1], strides[i + 1], new_window_shape[i], padding
)
for i in range(len(dims) - 2)
]
@@ -49,8 +74,20 @@ def general_pool(inputs, init, reduce_fn, window_shape, strides, padding):
]
pad_list = [(0, 0)] + pad_list + [(0, 0)]
else:
- pad_list = [(0, 0)] + padding + [(0, 0)]
- y = jlax.reduce_window(inputs, init, reduce_fn, dims, strides, pad_list)
+ pad_list = [(0, 0)] + list(padding) + [(0, 0)]
+
+ if ceil_mode:
+ for i in range(len(dims) - 2):
+ pad_list[i + 1] = ivy.padding_ceil_mode(
+ inputs.shape[i + 1],
+ new_window_shape[i],
+ pad_list[i + 1],
+ strides[i + 1],
+ )
+
+ y = jlax.reduce_window(
+ inputs, init, reduce_fn, dims, strides, pad_list, window_dilation=dilation
+ )
if is_single_input:
y = jnp.squeeze(y, axis=0)
return y
@@ -79,7 +116,7 @@ def max_pool1d(
elif len(kernel) == 1:
kernel = (kernel[0],)
- res = general_pool(x, -jnp.inf, jlax.max, kernel, strides, padding)
+ res = general_pool(x, -jnp.inf, jlax.max, kernel, strides, padding, 1)
if data_format == "NCW":
res = jnp.transpose(x, (0, 2, 1))
@@ -90,16 +127,20 @@ def max_pool2d(
x: JaxArray,
kernel: Union[int, Tuple[int], Tuple[int, int]],
strides: Union[int, Tuple[int], Tuple[int, int]],
- padding: str,
+ padding: Union[str, int, Tuple[int], Tuple[int, int]],
/,
*,
data_format: str = "NHWC",
+ dilation: Union[int, Tuple[int], Tuple[int, int]] = 1,
+ ceil_mode: bool = False,
out: Optional[JaxArray] = None,
) -> JaxArray:
if data_format == "NCHW":
x = jnp.transpose(x, (0, 2, 3, 1))
- res = general_pool(x, -jnp.inf, jlax.max, kernel, strides, padding)
+ res = general_pool(
+ x, -jnp.inf, jlax.max, kernel, strides, padding, 2, dilation, ceil_mode
+ )
if data_format == "NCHW":
return jnp.transpose(res, (0, 3, 1, 2))
@@ -121,7 +162,7 @@ def max_pool3d(
x = jnp.transpose(x, (0, 2, 3, 4, 1))
if isinstance(kernel, int):
kernel = (kernel,) * 3
- res = general_pool(x, -jnp.inf, jlax.max, kernel, strides, padding)
+ res = general_pool(x, -jnp.inf, jlax.max, kernel, strides, padding, 3)
if data_format == "NCDHW":
res = jnp.transpose(x, (0, 2, 3, 4, 1))
@@ -153,7 +194,7 @@ def avg_pool1d(
elif len(strides) == 1:
strides = (strides[0],)
- res = general_pool(x, 0.0, jlax.add, kernel, strides, padding)
+ res = general_pool(x, 0.0, jlax.add, kernel, strides, padding, 1)
div_shape = x.shape[:-1] + (1,)
if len(div_shape) - 2 == len(kernel):
div_shape = (1,) + div_shape[1:]
@@ -189,7 +230,7 @@ def avg_pool2d(
if data_format == "NCHW":
x = jnp.transpose(x, (0, 2, 3, 1))
- res = general_pool(x, 0.0, jlax.add, kernel, strides, padding)
+ res = general_pool(x, 0.0, jlax.add, kernel, strides, padding, 2)
div_shape = x.shape[:-1] + (1,)
if len(div_shape) - 2 == len(kernel):
div_shape = (1,) + div_shape[1:]
@@ -225,7 +266,7 @@ def avg_pool3d(
if data_format == "NCDHW":
x = jnp.transpose(x, (0, 2, 3, 4, 1))
- res = general_pool(x, 0.0, jlax.add, kernel, strides, padding)
+ res = general_pool(x, 0.0, jlax.add, kernel, strides, padding, 3)
res = res / general_pool(
jnp.ones_like(x, dtype=res.dtype), 0.0, jlax.add, kernel, strides, padding
@@ -387,3 +428,37 @@ def ifft(
if norm != "backward" and norm != "ortho" and norm != "forward":
raise ivy.exceptions.IvyError(f"Unrecognized normalization mode {norm}")
return jnp.fft.ifft(x, n, dim, norm)
+
+
+def interpolate(
+ x: JaxArray,
+ size: Union[Sequence[int], int],
+ /,
+ *,
+ mode: Union[Literal["linear", "bilinear"]] = "linear",
+ align_corners: Optional[bool] = None,
+ antialias: Optional[bool] = False,
+):
+ # keeping the batch and channel dimension same
+ size = [*x.shape[0:1], *size]
+ if align_corners:
+ return ivy.interpolate(
+ x, size, mode=mode, align_corners=align_corners, antialias=antialias
+ )
+ elif mode == "linear":
+ x = jnp.transpose(x, (0, 2, 1))
+ return jnp.transpose(
+ jax.image.resize(x, shape=size, method=mode, antialias=antialias), (0, 2, 1)
+ )
+ elif mode == "bilinear":
+ x = jnp.transpose(x, (0, 2, 3, 1))
+ return jnp.transpose(
+ jax.image.resize(x, shape=size, method=mode, antialias=antialias),
+ (0, 3, 1, 2),
+ )
+ elif mode == "trilinear":
+ x = jnp.transpose(x, (0, 2, 3, 4, 1))
+ return jnp.transpose(
+ jax.image.resize(x, shape=size, method=mode, antialias=antialias),
+ (0, 4, 1, 2, 3),
+ )
diff --git a/ivy/functional/backends/jax/experimental/manipulation.py b/ivy/functional/backends/jax/experimental/manipulation.py
index d7541c0d90ca4..40a0343a73a48 100644
--- a/ivy/functional/backends/jax/experimental/manipulation.py
+++ b/ivy/functional/backends/jax/experimental/manipulation.py
@@ -276,11 +276,9 @@ def take_along_axis(
def hsplit(
ary: JaxArray,
- indices_or_sections: Union[int, Tuple[int]],
+ indices_or_sections: Union[int, Tuple[int, ...]],
/,
- *,
- out: Optional[JaxArray] = None,
-) -> JaxArray:
+) -> List[JaxArray]:
return jnp.hsplit(ary, indices_or_sections)
diff --git a/ivy/functional/backends/jax/experimental/norms.py b/ivy/functional/backends/jax/experimental/norms.py
index ecbf064a94a0b..8f18fa0670275 100644
--- a/ivy/functional/backends/jax/experimental/norms.py
+++ b/ivy/functional/backends/jax/experimental/norms.py
@@ -7,7 +7,10 @@
@with_unsupported_dtypes({"0.3.14 and below": ("float16",)}, backend_version)
def l2_normalize(x: JaxArray, /, *, axis: int = None, out=None) -> JaxArray:
- denorm = jnp.linalg.norm(x, axis=axis, ord=2, keepdims=True)
+ if axis is None:
+ denorm = jnp.linalg.norm(x.flatten(), 2, axis)
+ else:
+ denorm = jnp.linalg.norm(x, 2, axis, keepdims=True)
denorm = jnp.maximum(denorm, 1e-12)
return x / denorm
@@ -64,3 +67,16 @@ def instance_norm(
if data_format == "NHWC":
normalized = jnp.transpose(normalized, (0, 2, 3, 1))
return normalized
+
+
+@with_unsupported_dtypes({"0.3.14 and below": ("float16",)}, backend_version)
+def lp_normalize(
+ x: JaxArray, /, *, p: float = 2, axis: int = None, out=None
+) -> JaxArray:
+ if axis is None:
+ denorm = jnp.linalg.norm(x.flatten(), axis=axis, ord=p)
+ else:
+ denorm = jnp.linalg.norm(x, axis=axis, ord=p, keepdims=True)
+
+ denorm = jnp.maximum(denorm, 1e-12)
+ return jnp.divide(x, denorm)
diff --git a/ivy/functional/backends/jax/experimental/statistical.py b/ivy/functional/backends/jax/experimental/statistical.py
index cadfed8b3b8a6..17816e81453d8 100644
--- a/ivy/functional/backends/jax/experimental/statistical.py
+++ b/ivy/functional/backends/jax/experimental/statistical.py
@@ -69,3 +69,17 @@ def corrcoef(
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.corrcoef(x, y=y, rowvar=rowvar)
+
+
+def nanmedian(
+ input: JaxArray,
+ /,
+ *,
+ axis: Optional[Union[Tuple[int], int]] = None,
+ keepdims: Optional[bool] = False,
+ overwrite_input: Optional[bool] = False,
+ out: Optional[JaxArray] = None,
+) -> JaxArray:
+ return jnp.nanmedian(
+ input, axis=axis, keepdims=keepdims, overwrite_input=overwrite_input, out=out
+ )
diff --git a/ivy/functional/backends/jax/layers.py b/ivy/functional/backends/jax/layers.py
index 9ace4ab1dc9e4..4c229bed055cf 100644
--- a/ivy/functional/backends/jax/layers.py
+++ b/ivy/functional/backends/jax/layers.py
@@ -110,8 +110,8 @@ def conv2d(
padding: Union[str, Sequence[Tuple[int, int]]],
/,
*,
- data_format: Optional[str] = "NHWC",
- dilations: Optional[Union[int, Tuple[int, int]]] = 1,
+ data_format: str = "NHWC",
+ dilations: Union[int, Tuple[int, int]] = 1,
out: Optional[JaxArray] = None,
) -> JaxArray:
strides = [strides] * 2 if isinstance(strides, int) else strides
diff --git a/ivy/functional/backends/jax/linear_algebra.py b/ivy/functional/backends/jax/linear_algebra.py
index 634f70f70e4cf..4942af660b292 100644
--- a/ivy/functional/backends/jax/linear_algebra.py
+++ b/ivy/functional/backends/jax/linear_algebra.py
@@ -168,12 +168,18 @@ def matmul(
*,
transpose_a: bool = False,
transpose_b: bool = False,
+ adjoint_a: bool = False,
+ adjoint_b: bool = False,
out: Optional[JaxArray] = None,
) -> JaxArray:
if transpose_a is True:
x1 = jnp.transpose(x1)
if transpose_b is True:
x2 = jnp.transpose(x2)
+ if adjoint_a is True:
+ x1 = jnp.transpose(jnp.conjugate(x1))
+ if adjoint_b is True:
+ x2 = jnp.transpose(jnp.conjugate(x2))
return jnp.matmul(x1, x2)
diff --git a/ivy/functional/backends/numpy/experimental/elementwise.py b/ivy/functional/backends/numpy/experimental/elementwise.py
index 393960202768c..0cecbff7b30da 100644
--- a/ivy/functional/backends/numpy/experimental/elementwise.py
+++ b/ivy/functional/backends/numpy/experimental/elementwise.py
@@ -119,6 +119,7 @@ def trapz(
trapz.support_native_out = False
+@_scalar_output_to_0d_array
def float_power(
x1: Union[np.ndarray, float, list, tuple],
x2: Union[np.ndarray, float, list, tuple],
@@ -126,7 +127,8 @@ def float_power(
*,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
- return np.asarray(np.float_power(x1, x2, out=out), dtype=x1.dtype)
+ x1, x2 = promote_types_of_inputs(x1, x2)
+ return np.float_power(x1, x2, out=out)
float_power.support_native_out = True
@@ -316,11 +318,11 @@ def hypot(
def diff(
- x: Union[np.ndarray, int, float, list, tuple],
+ x: Union[np.ndarray, list, tuple],
/,
*,
- n: Optional[int] = 1,
- axis: Optional[int] = -1,
+ n: int = 1,
+ axis: int = -1,
prepend: Optional[Union[np.ndarray, int, float, list, tuple]] = None,
append: Optional[Union[np.ndarray, int, float, list, tuple]] = None,
out: Optional[np.ndarray] = None,
@@ -425,5 +427,5 @@ def xlogy(
return x * np.log(y)
-def real(x: Union[np.ndarray], /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
+def real(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
return np.real(x)
diff --git a/ivy/functional/backends/numpy/experimental/layers.py b/ivy/functional/backends/numpy/experimental/layers.py
index 0860e978038fb..87830b990584c 100644
--- a/ivy/functional/backends/numpy/experimental/layers.py
+++ b/ivy/functional/backends/numpy/experimental/layers.py
@@ -7,6 +7,7 @@
# local
import ivy
from ivy.functional.ivy.layers import _handle_padding
+from ivy.functional.backends.numpy.layers import _add_dilations
def max_pool1d(
@@ -69,10 +70,12 @@ def max_pool2d(
x: np.ndarray,
kernel: Union[int, Tuple[int], Tuple[int, int]],
strides: Union[int, Tuple[int], Tuple[int, int]],
- padding: str,
+ padding: Union[str, int, Tuple[int], Tuple[int, int]],
/,
*,
data_format: str = "NHWC",
+ dilation: Union[int, Tuple[int], Tuple[int, int]] = 1,
+ ceil_mode: bool = False,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
@@ -86,15 +89,41 @@ def max_pool2d(
elif len(strides) == 1:
strides = [strides[0]] * 2
+ if isinstance(dilation, int):
+ dilation = [dilation] * 2
+ elif len(dilation) == 1:
+ dilation = [dilation[0]] * 2
+
+ if isinstance(padding, int):
+ padding = [(padding,) * 2] * 2
+ elif isinstance(padding, tuple) and len(padding) == 1:
+ padding = [(padding[0],) * 2] * 2
+ elif isinstance(padding, tuple) and len(padding) == 2:
+ padding = [(padding[0],) * 2, (padding[1],) * 2]
+
+ if isinstance(padding, (tuple, list)):
+ ivy.assertions.check_kernel_padding_size(kernel, padding)
+
if data_format == "NCHW":
x = np.transpose(x, (0, 2, 3, 1))
x_shape = list(x.shape[1:3])
+ filters = np.ones((list(kernel)), dtype=x.dtype)
+ for j in range(2):
+ if dilation[j] > 1:
+ filters = _add_dilations(filters, dilation[j], axis=j, values=0)
+ kernel = list(filters.shape)
pad_list = padding
if isinstance(padding, str):
pad_h = _handle_padding(x_shape[0], strides[0], kernel[0], padding)
pad_w = _handle_padding(x_shape[1], strides[1], kernel[1], padding)
pad_list = [(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)]
+ pad_list = list(pad_list)
+ if ceil_mode:
+ for i in range(2):
+ pad_list[i] = ivy.padding_ceil_mode(
+ x_shape[i], kernel[i], pad_list[i], strides[i]
+ )
x = np.pad(
x,
@@ -103,7 +132,8 @@ def max_pool2d(
*pad_list,
(0, 0),
],
- "edge",
+ "constant",
+ constant_values=-math.inf,
)
x_shape = x.shape
@@ -112,8 +142,8 @@ def max_pool2d(
new_shape = [x_shape[0], new_h, new_w] + list(kernel) + [x_shape[-1]]
new_strides = (
x.strides[0],
- x.strides[1] * strides[1],
- x.strides[2] * strides[0],
+ x.strides[1] * strides[0],
+ x.strides[2] * strides[1],
x.strides[1],
x.strides[2],
x.strides[3],
@@ -123,6 +153,11 @@ def max_pool2d(
x, new_shape, new_strides, writeable=False
)
+ # B x OH x OW x KH x KW x I
+ sub_matrices = np.where(
+ filters.reshape([1] * 3 + list(kernel) + [1]), sub_matrices, -math.inf
+ )
+
# B x OH x OW x O
res = sub_matrices.max(axis=(3, 4))
if data_format == "NCHW":
diff --git a/ivy/functional/backends/numpy/experimental/manipulation.py b/ivy/functional/backends/numpy/experimental/manipulation.py
index 03e3f4e3163ad..fbde81072f915 100644
--- a/ivy/functional/backends/numpy/experimental/manipulation.py
+++ b/ivy/functional/backends/numpy/experimental/manipulation.py
@@ -272,11 +272,9 @@ def take_along_axis(
def hsplit(
ary: np.ndarray,
- indices_or_sections: Union[int, Tuple[int]],
+ indices_or_sections: Union[int, Tuple[int, ...]],
/,
- *,
- out: Optional[np.ndarray] = None,
-) -> np.ndarray:
+) -> List[np.ndarray]:
return np.hsplit(ary, indices_or_sections)
diff --git a/ivy/functional/backends/numpy/experimental/norms.py b/ivy/functional/backends/numpy/experimental/norms.py
index 7133549f77ac8..32b449980e4f3 100644
--- a/ivy/functional/backends/numpy/experimental/norms.py
+++ b/ivy/functional/backends/numpy/experimental/norms.py
@@ -6,7 +6,10 @@
@with_unsupported_dtypes({"1.23.0 and below": ("float16",)}, backend_version)
def l2_normalize(x: np.ndarray, /, *, axis: int = None, out=None) -> np.ndarray:
- denorm = np.linalg.norm(x, axis=axis, ord=2, keepdims=True)
+ if axis is None:
+ denorm = np.linalg.norm(x.flatten(), 2, axis)
+ else:
+ denorm = np.linalg.norm(x, 2, axis, keepdims=True)
denorm = np.maximum(denorm, 1e-12)
return x / denorm
@@ -62,3 +65,14 @@ def instance_norm(
if data_format == "NHWC":
normalized = np.transpose(normalized, (0, 2, 3, 1))
return normalized
+
+
+def lp_normalize(
+ x: np.ndarray, /, *, p: float = 2, axis: int = None, out=None
+) -> np.ndarray:
+ if axis is None:
+ denorm = np.linalg.norm(x.flatten(), axis=axis, ord=p)
+ else:
+ denorm = np.linalg.norm(x, axis=axis, ord=p, keepdims=True)
+ denorm = np.maximum(denorm, 1e-12)
+ return np.divide(x, denorm, out=out)
diff --git a/ivy/functional/backends/numpy/experimental/statistical.py b/ivy/functional/backends/numpy/experimental/statistical.py
index 6204af8e922b1..30fd25b063d4b 100644
--- a/ivy/functional/backends/numpy/experimental/statistical.py
+++ b/ivy/functional/backends/numpy/experimental/statistical.py
@@ -72,3 +72,17 @@ def corrcoef(
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return np.corrcoef(x, y=y, rowvar=rowvar, dtype=x.dtype)
+
+
+def nanmedian(
+ input: np.ndarray,
+ /,
+ *,
+ axis: Optional[Union[Tuple[int], int]] = None,
+ keepdims: Optional[bool] = False,
+ overwrite_input: Optional[bool] = False,
+ out: Optional[np.ndarray] = None,
+) -> np.ndarray:
+ return np.nanmedian(
+ input, axis=axis, keepdims=keepdims, overwrite_input=overwrite_input, out=out
+ )
diff --git a/ivy/functional/backends/numpy/layers.py b/ivy/functional/backends/numpy/layers.py
index b0cc4458ee2dc..68271291661d9 100644
--- a/ivy/functional/backends/numpy/layers.py
+++ b/ivy/functional/backends/numpy/layers.py
@@ -14,11 +14,11 @@
)
-def _add_dilations(x, dilations, axis):
+def _add_dilations(x, dilations, axis, values=0):
return np.insert(
x,
[i for i in range(1, x.shape[axis])] * (dilations - 1),
- values=0,
+ values=values,
axis=axis,
)
@@ -184,8 +184,8 @@ def conv2d(
padding: Union[str, Sequence[Tuple[int, int]]],
/,
*,
- data_format: Optional[str] = "NHWC",
- dilations: Optional[Union[int, Tuple[int, int]]] = 1,
+ data_format: str = "NHWC",
+ dilations: Union[int, Tuple[int, int]] = 1,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
strides = [strides] * 2 if isinstance(strides, int) else strides
diff --git a/ivy/functional/backends/numpy/linear_algebra.py b/ivy/functional/backends/numpy/linear_algebra.py
index 808cbd694ff5c..8b84e9e5eae1a 100644
--- a/ivy/functional/backends/numpy/linear_algebra.py
+++ b/ivy/functional/backends/numpy/linear_algebra.py
@@ -121,12 +121,18 @@ def matmul(
*,
transpose_a: bool = False,
transpose_b: bool = False,
+ adjoint_a: bool = False,
+ adjoint_b: bool = False,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if transpose_a is True:
x1 = np.transpose(x1)
if transpose_b is True:
x2 = np.transpose(x2)
+ if adjoint_a is True:
+ x1 = np.transpose(np.conjugate(x1))
+ if adjoint_b is True:
+ x2 = np.transpose(np.conjugate(x2))
ret = np.matmul(x1, x2, out=out)
if len(x1.shape) == len(x2.shape) == 1:
ret = np.array(ret)
diff --git a/ivy/functional/backends/tensorflow/experimental/elementwise.py b/ivy/functional/backends/tensorflow/experimental/elementwise.py
index ffd12195e6a17..85dba99a17f6b 100644
--- a/ivy/functional/backends/tensorflow/experimental/elementwise.py
+++ b/ivy/functional/backends/tensorflow/experimental/elementwise.py
@@ -91,6 +91,9 @@ def trapz(
return tfp.math.trapz(y, x=x, dx=dx, axis=axis, name=None)
+@with_unsupported_dtypes(
+ {"2.9.1 and below": ("uint8", "uint16", "uint32", "uint64")}, backend_version
+)
def float_power(
x1: Union[tf.Tensor, tf.Variable, float, list, tuple],
x2: Union[tf.Tensor, tf.Variable, float, list, tuple],
@@ -98,7 +101,12 @@ def float_power(
*,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
- return tf.experimental.numpy.float_power(x1, x2)
+ x1, x2 = ivy.promote_types_of_inputs(x1, x2)
+ if ivy.any(ivy.is_complex_dtype(x1)) or ivy.any(ivy.is_complex_dtype(x2)):
+ out_dtype = tf.complex128
+ else:
+ out_dtype = tf.float64
+ return tf.cast(tf.experimental.numpy.float_power(x1, x2), out_dtype)
def exp2(
@@ -297,15 +305,17 @@ def nextafter(
{"2.9.1 and below": ("uint8", "uint16", "uint32", "uint64")}, backend_version
)
def diff(
- x: Union[tf.Tensor, tf.Variable, int, float, list, tuple],
+ x: Union[tf.Tensor, tf.Variable, list, tuple],
/,
*,
- n: Optional[int] = 1,
- axis: Optional[int] = -1,
+ n: int = 1,
+ axis: int = -1,
prepend: Optional[Union[tf.Tensor, tf.Variable, int, float, list, tuple]] = None,
append: Optional[Union[tf.Tensor, tf.Variable, int, float, list, tuple]] = None,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
+ if n == 0:
+ return x
if prepend is not None:
x = tf.experimental.numpy.append(prepend, x, axis=axis)
if append is not None:
@@ -353,12 +363,12 @@ def angle(
backend_version,
)
def imag(
- input: Union[tf.Tensor, tf.Variable],
+ val: Union[tf.Tensor, tf.Variable],
/,
*,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
- return tf.math.imag(input, name=None)
+ return tf.math.imag(val, name=None)
@with_supported_dtypes(
diff --git a/ivy/functional/backends/tensorflow/experimental/layers.py b/ivy/functional/backends/tensorflow/experimental/layers.py
index 70b6f806b379c..f567733312147 100644
--- a/ivy/functional/backends/tensorflow/experimental/layers.py
+++ b/ivy/functional/backends/tensorflow/experimental/layers.py
@@ -1,9 +1,21 @@
+# global
import math
-from typing import Union, Optional, Tuple, Literal
+from typing import Union, Optional, Tuple, Literal, Sequence
import tensorflow as tf
+
+# local
from ivy.func_wrapper import with_unsupported_dtypes
from .. import backend_version
import ivy
+from ivy.functional.ivy.layers import _handle_padding
+
+
+def _from_int_to_tuple(arg, dim):
+ if isinstance(arg, int):
+ return (arg,) * dim
+ if isinstance(arg, tuple) and len(arg) == 1:
+ return (arg[0],) * dim
+ return arg
def max_pool1d(
@@ -30,17 +42,49 @@ def max_pool2d(
x: Union[tf.Tensor, tf.Variable],
kernel: Union[int, Tuple[int], Tuple[int, int]],
strides: Union[int, Tuple[int], Tuple[int, int]],
- padding: str,
+ padding: Union[str, int, Tuple[int], Tuple[int, int]],
/,
*,
data_format: str = "NHWC",
+ dilation: Union[int, Tuple[int], Tuple[int, int]] = 1,
+ ceil_mode: bool = False,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
if data_format == "NCHW":
x = tf.transpose(x, (0, 2, 3, 1))
- if not isinstance(padding, str):
- padding = [(0, 0)] + padding + [(0, 0)]
- res = tf.nn.max_pool2d(x, kernel, strides, padding)
+
+ dilation = _from_int_to_tuple(dilation, 2)
+ strides = _from_int_to_tuple(strides, 2)
+ kernel = _from_int_to_tuple(kernel, 2)
+ if isinstance(padding, int):
+ padding = [(padding,) * 2] * 2
+ elif isinstance(padding, tuple) and len(padding) == 1:
+ padding = [(padding[0],) * 2] * 2
+ elif isinstance(padding, tuple) and len(padding) == 2:
+ padding = [(padding[0],) * 2, (padding[1],) * 2]
+
+ if isinstance(padding, (tuple, list)):
+ ivy.assertions.check_kernel_padding_size(kernel, padding)
+ new_kernel = [kernel[i] + (kernel[i] - 1) * (dilation[i] - 1) for i in range(2)]
+ if isinstance(padding, str):
+ pad_h = _handle_padding(x.shape[1], strides[0], new_kernel[0], padding)
+ pad_w = _handle_padding(x.shape[2], strides[1], new_kernel[1], padding)
+ padding = [(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)]
+
+ x_shape = x.shape[1:-1]
+
+ if ceil_mode:
+ for i in range(2):
+ padding[i] = ivy.padding_ceil_mode(
+ x_shape[i], new_kernel[i], padding[i], strides[i]
+ )
+
+ padding = [(0, 0)] + list(padding) + [(0, 0)]
+ x = tf.pad(x, padding, constant_values=-math.inf)
+ res = tf.nn.pool(x, kernel, "MAX", strides, "VALID", dilations=dilation)
+
+ # converting minimum value to -inf because tensorflow clips -inf to minimum value
+ res = tf.where(res <= ivy.finfo(res.dtype).min, -math.inf, res)
if data_format == "NCHW":
return tf.transpose(res, (0, 3, 1, 2))
return res
@@ -319,3 +363,32 @@ def embedding(
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
return tf.nn.embedding_lookup(weights, indices, max_norm=max_norm)
+
+
+def interpolate(
+ x: Union[tf.Tensor, tf.Variable],
+ size: Union[Sequence[int], int],
+ /,
+ *,
+ mode: Union[Literal["linear", "bilinear", "trilinear"]] = "linear",
+ align_corners: Optional[bool] = None,
+ antialias: Optional[bool] = False,
+):
+ if align_corners:
+ return ivy.functional.experimental.interpolate(
+ x, size, mode=mode, align_corners=align_corners, antialias=antialias
+ )
+ elif mode == "linear":
+ x = tf.transpose(x, (0, 2, 1))
+ return tf.transpose(
+ tf.image.resize(
+ x, size=[x.shape[0], size], method="bilinear", antialias=antialias
+ ),
+ (0, 2, 1),
+ )
+ elif mode == "bilinear":
+ x = tf.transpose(x, (0, 2, 3, 1))
+ return tf.transpose(tf.image.resize(x, size=size, method=mode), (0, 3, 1, 2))
+ elif mode == "trilinear":
+ x = tf.transpose(x, (0, 2, 3, 4, 1))
+ return tf.transpose(tf.image.resize(x, size=size, method=mode), (0, 4, 1, 2, 3))
diff --git a/ivy/functional/backends/tensorflow/experimental/manipulation.py b/ivy/functional/backends/tensorflow/experimental/manipulation.py
index dd3c0808925bd..306030a9e7020 100644
--- a/ivy/functional/backends/tensorflow/experimental/manipulation.py
+++ b/ivy/functional/backends/tensorflow/experimental/manipulation.py
@@ -180,11 +180,9 @@ def take_along_axis(
def hsplit(
ary: Union[tf.Tensor, tf.Variable],
- indices_or_sections: Union[int, Tuple[int]],
+ indices_or_sections: Union[int, Tuple[int, ...]],
/,
- *,
- out: Optional[Union[tf.Tensor, tf.Variable]] = None,
-) -> Union[tf.Tensor, tf.Variable]:
+) -> List[Union[tf.Tensor, tf.Variable]]:
return tf.experimental.numpy.hsplit(ary, indices_or_sections)
diff --git a/ivy/functional/backends/tensorflow/experimental/norms.py b/ivy/functional/backends/tensorflow/experimental/norms.py
index d2de5a5bd7710..b453a84c4156f 100644
--- a/ivy/functional/backends/tensorflow/experimental/norms.py
+++ b/ivy/functional/backends/tensorflow/experimental/norms.py
@@ -66,3 +66,11 @@ def instance_norm(
if data_format == "NCHW":
normalized = tf.transpose(normalized, (0, 3, 1, 2))
return normalized
+
+
+def lp_normalize(
+ x: Union[tf.Tensor, tf.Variable], /, *, p: float = 2, axis: int = None, out=None
+) -> tf.Tensor:
+ denorm = tf.norm(x, ord=p, axis=axis, keepdims=True)
+ denorm = tf.math.maximum(denorm, 1e-12)
+ return tf.math.divide(x, denorm)
diff --git a/ivy/functional/backends/tensorflow/experimental/statistical.py b/ivy/functional/backends/tensorflow/experimental/statistical.py
index a63aa44b00dd5..105030bb073f3 100644
--- a/ivy/functional/backends/tensorflow/experimental/statistical.py
+++ b/ivy/functional/backends/tensorflow/experimental/statistical.py
@@ -96,3 +96,20 @@ def corrcoef(
cov2_t = tf.linalg.diag(1 / tf.sqrt(tf.linalg.diag_part(cov_t)))
cor = cov2_t @ cov_t @ cov2_t
return cor
+
+
+def nanmedian(
+ input: Union[tf.Tensor, tf.Variable],
+ /,
+ *,
+ axis: Optional[Union[Tuple[int], int]] = None,
+ keepdims: Optional[bool] = False,
+ out: Optional[Union[tf.Tensor, tf.Variable]] = None,
+) -> Union[tf.Tensor, tf.Variable]:
+ return tfp.stats.percentile(
+ input,
+ 50.0,
+ axis=axis,
+ interpolation="midpoint",
+ keepdims=keepdims,
+ )
diff --git a/ivy/functional/backends/tensorflow/layers.py b/ivy/functional/backends/tensorflow/layers.py
index 459007a80ddbc..a9ab030c6292a 100644
--- a/ivy/functional/backends/tensorflow/layers.py
+++ b/ivy/functional/backends/tensorflow/layers.py
@@ -127,8 +127,8 @@ def conv2d(
padding: Union[str, Sequence[Tuple[int, int]]],
/,
*,
- data_format: Optional[str] = "NHWC",
- dilations: Optional[Union[int, Tuple[int, int]]] = 1,
+ data_format: str = "NHWC",
+ dilations: Union[int, Tuple[int, int]] = 1,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
if data_format == "NCHW":
diff --git a/ivy/functional/backends/tensorflow/linear_algebra.py b/ivy/functional/backends/tensorflow/linear_algebra.py
index 2d298e222ac8e..c6a943ca6c58f 100644
--- a/ivy/functional/backends/tensorflow/linear_algebra.py
+++ b/ivy/functional/backends/tensorflow/linear_algebra.py
@@ -213,6 +213,8 @@ def matmul(
*,
transpose_a: bool = False,
transpose_b: bool = False,
+ adjoint_a: bool = False,
+ adjoint_b: bool = False,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
@@ -224,6 +226,11 @@ def matmul(
if transpose_b is True:
x2 = tf.transpose(x2)
+ if adjoint_a is True:
+ x1 = tf.linalg.adjoint(x1)
+ if adjoint_b is True:
+ x2 = tf.linalg.adjoint(x2)
+
if dtype_from.is_unsigned or dtype_from == tf.int8 or dtype_from == tf.int16:
x1 = tf.cast(x1, tf.int64)
x2 = tf.cast(x2, tf.int64)
@@ -542,6 +549,7 @@ def slogdet(
def solve(
x1: Union[tf.Tensor, tf.Variable],
x2: Union[tf.Tensor, tf.Variable],
+ /,
*,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
diff --git a/ivy/functional/backends/torch/experimental/elementwise.py b/ivy/functional/backends/torch/experimental/elementwise.py
index ac6737ed6075b..f4c4efb24668b 100644
--- a/ivy/functional/backends/torch/experimental/elementwise.py
+++ b/ivy/functional/backends/torch/experimental/elementwise.py
@@ -113,7 +113,11 @@ def float_power(
) -> torch.Tensor:
# Native out is supported but with restrictions leading
# to failures hence letting ivy handle it.
- return torch.float_power(x1, x2).to(x1.dtype)
+ x1, x2 = promote_types_of_inputs(x1, x2)
+ return torch.float_power(x1, x2, out=out)
+
+
+float_power.support_native_out = True
def exp2(
@@ -241,14 +245,15 @@ def angle(
def imag(
- input: torch.Tensor,
+ val: torch.Tensor,
/,
*,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
- if input.dtype != torch.complex64:
- input = input.to(torch.complex64)
- return torch.imag(input)
+ if val.dtype not in (torch.complex64, torch.complex128):
+ ret = torch.imag(val.to(torch.complex64))
+ return ret.to(val.dtype)
+ return torch.imag(val)
imag.support_native_out = False
@@ -290,11 +295,11 @@ def logaddexp2(
def diff(
- x: Union[torch.Tensor, int, float, list, tuple],
+ x: Union[torch.Tensor, list, tuple],
/,
*,
- n: Optional[int] = 1,
- axis: Optional[int] = -1,
+ n: int = 1,
+ axis: int = -1,
prepend: Optional[Union[torch.Tensor, int, float, list, tuple]] = None,
append: Optional[Union[torch.Tensor, int, float, list, tuple]] = None,
out: Optional[torch.Tensor] = None,
@@ -425,7 +430,5 @@ def xlogy(
return torch.xlogy(x, y, out=out)
-def real(
- x: Union[torch.Tensor], /, *, out: Optional[torch.Tensor] = None
-) -> torch.Tensor:
+def real(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.real(x)
diff --git a/ivy/functional/backends/torch/experimental/layers.py b/ivy/functional/backends/torch/experimental/layers.py
index 1840c4ff12f5f..09be45a145dd7 100644
--- a/ivy/functional/backends/torch/experimental/layers.py
+++ b/ivy/functional/backends/torch/experimental/layers.py
@@ -1,5 +1,5 @@
# global
-from typing import Optional, Union, Tuple, Literal
+from typing import Optional, Union, Tuple, Literal, Sequence
import torch
import math
@@ -59,10 +59,12 @@ def max_pool2d(
x: torch.Tensor,
kernel: Union[int, Tuple[int], Tuple[int, int]],
strides: Union[int, Tuple[int], Tuple[int, int]],
- padding: str,
+ padding: Union[str, int, Tuple[int], Tuple[int, int]],
/,
*,
data_format: str = "NHWC",
+ dilation: Union[int, Tuple[int], Tuple[int, int]] = 1,
+ ceil_mode: bool = False,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(strides, int):
@@ -75,13 +77,30 @@ def max_pool2d(
elif len(kernel) == 1:
kernel = (kernel[0], kernel[0])
+ if isinstance(dilation, int):
+ dilation = (dilation, dilation)
+ elif len(dilation) == 1:
+ dilation = (dilation[0], dilation[0])
+
+ if isinstance(padding, int):
+ padding = [(padding,) * 2] * 2
+ elif isinstance(padding, tuple) and len(padding) == 1:
+ padding = [(padding[0],) * 2] * 2
+ elif isinstance(padding, tuple) and len(padding) == 2:
+ padding = [(padding[0],) * 2, (padding[1],) * 2]
+
+ if isinstance(padding, (tuple, list)):
+ ivy.assertions.check_kernel_padding_size(kernel, padding)
+
if data_format == "NHWC":
x = x.permute(0, 3, 1, 2)
x_shape = list(x.shape[2:])
+ new_kernel = [kernel[i] + (kernel[i] - 1) * (dilation[i] - 1) for i in range(2)]
+
if isinstance(padding, str):
- pad_h = _handle_padding(x_shape[0], strides[0], kernel[0], padding)
- pad_w = _handle_padding(x_shape[1], strides[1], kernel[1], padding)
+ pad_h = _handle_padding(x_shape[0], strides[0], new_kernel[0], padding)
+ pad_w = _handle_padding(x_shape[1], strides[1], new_kernel[1], padding)
pad_list = [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
else:
# torch pad takes width padding first, then height padding
@@ -94,7 +113,7 @@ def max_pool2d(
value=float("-inf"),
)
- res = torch.nn.functional.max_pool2d(x, kernel, strides, 0)
+ res = torch.nn.functional.max_pool2d(x, kernel, strides, 0, dilation, ceil_mode)
if data_format == "NHWC":
return res.permute(0, 2, 3, 1)
return res
@@ -484,3 +503,21 @@ def embedding(
embedding.support_native_out = False
+
+
+def interpolate(
+ x: torch.Tensor,
+ size: Union[Sequence[int], int],
+ /,
+ *,
+ mode: Optional[Literal["linear", "bilinear", "trilinear"]] = "linear",
+ align_corners: Optional[bool] = None,
+ antialias: Optional[bool] = False,
+):
+ return torch.nn.functional.interpolate(
+ x,
+ size,
+ mode=mode,
+ align_corners=align_corners,
+ antialias=antialias,
+ )
diff --git a/ivy/functional/backends/torch/experimental/manipulation.py b/ivy/functional/backends/torch/experimental/manipulation.py
index ad523e95b1454..3891989e412e2 100644
--- a/ivy/functional/backends/torch/experimental/manipulation.py
+++ b/ivy/functional/backends/torch/experimental/manipulation.py
@@ -220,12 +220,10 @@ def take_along_axis(
def hsplit(
ary: torch.Tensor,
- indices_or_sections: Union[int, Tuple[int]],
+ indices_or_sections: Union[int, Tuple[int, ...]],
/,
- *,
- out: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
- return torch.hsplit(ary, indices_or_sections)
+) -> List[torch.Tensor]:
+ return list(torch.hsplit(ary, indices_or_sections))
take_along_axis.support_native_out = True
diff --git a/ivy/functional/backends/torch/experimental/norms.py b/ivy/functional/backends/torch/experimental/norms.py
index 9267a02515668..5cafde1662c7b 100644
--- a/ivy/functional/backends/torch/experimental/norms.py
+++ b/ivy/functional/backends/torch/experimental/norms.py
@@ -75,3 +75,14 @@ def instance_norm(
instance_norm.support_native_out = False
+
+
+@with_unsupported_dtypes({"1.11.0 and below": ("float16",)}, backend_version)
+def lp_normalize(
+ x: torch.Tensor, /, *, p: float = 2, axis: int = None, out: torch.Tensor = None
+) -> torch.Tensor:
+
+ return torch.nn.functional.normalize(x, p=p, dim=axis, out=out)
+
+
+lp_normalize.support_native_out = True
diff --git a/ivy/functional/backends/torch/experimental/statistical.py b/ivy/functional/backends/torch/experimental/statistical.py
index df0c0e31b714c..0ddfcf56d90f9 100644
--- a/ivy/functional/backends/torch/experimental/statistical.py
+++ b/ivy/functional/backends/torch/experimental/statistical.py
@@ -118,3 +118,17 @@ def corrcoef(
xarr = xarr.T if not rowvar else xarr
return torch.corrcoef(xarr)
+
+
+def nanmedian(
+ input: torch.tensor,
+ /,
+ *,
+ axis: Optional[Union[Tuple[int], int]] = None,
+ keepdims: Optional[bool] = False,
+ overwrite_input: Optional[bool] = False,
+ out: Optional[torch.tensor] = None,
+) -> torch.tensor:
+ return torch.nanmedian(
+ input, axis=axis, keepdims=keepdims, overwrite_input=overwrite_input, out=out
+ )
diff --git a/ivy/functional/backends/torch/layers.py b/ivy/functional/backends/torch/layers.py
index 0c901961926ff..052024b696446 100644
--- a/ivy/functional/backends/torch/layers.py
+++ b/ivy/functional/backends/torch/layers.py
@@ -161,8 +161,8 @@ def conv2d(
padding: Union[str, Sequence[Tuple[int, int]]],
/,
*,
- data_format: Optional[str] = "NHWC",
- dilations: Optional[Union[int, Tuple[int, int]]] = 1,
+ data_format: str = "NHWC",
+ dilations: Union[int, Tuple[int, int]] = 1,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if data_format == "NHWC":
diff --git a/ivy/functional/backends/torch/linear_algebra.py b/ivy/functional/backends/torch/linear_algebra.py
index 4263b3ce6bf41..46dbdc85aa688 100644
--- a/ivy/functional/backends/torch/linear_algebra.py
+++ b/ivy/functional/backends/torch/linear_algebra.py
@@ -164,6 +164,8 @@ def matmul(
*,
transpose_a: bool = False,
transpose_b: bool = False,
+ adjoint_a: bool = False,
+ adjoint_b: bool = False,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@@ -171,6 +173,10 @@ def matmul(
x1 = torch.t(x1)
if transpose_b is True:
x2 = torch.t(x2)
+ if adjoint_a is True:
+ x1 = torch.adjoint(x1)
+ if adjoint_b is True:
+ x2 = torch.adjoint(x2)
x1, x2 = ivy.promote_types_of_inputs(x1, x2)
return torch.matmul(x1, x2, out=out)
@@ -414,7 +420,7 @@ def trace(
if len(x) == 0:
return ivy.array([])
ret = torch.diagonal(x, offset=offset, dim1=axis1, dim2=axis2)
- ret = torch.sum(ret)
+ ret = torch.sum(ret, dim=-1)
return ret
diff --git a/ivy/functional/backends/torch/random.py b/ivy/functional/backends/torch/random.py
index bcfa4d05a6114..e42fc5505b7fb 100644
--- a/ivy/functional/backends/torch/random.py
+++ b/ivy/functional/backends/torch/random.py
@@ -104,13 +104,14 @@ def randint(
) -> torch.Tensor:
if not dtype:
dtype = ivy.default_int_dtype()
+ if not shape:
+ shape = (1,)
dtype = ivy.as_native_dtype(dtype)
_randint_check_dtype_and_bound(low, high, dtype)
shape = _check_bounds_and_get_shape(low, high, shape)
- rand_range = high - low
if seed:
torch.manual_seed(seed)
- return torch.rand(shape, device=device).to(dtype) * rand_range + low
+ return torch.randint(low=low, high=high, size=shape, device=device, dtype=dtype)
def seed(*, seed_value: int = 0) -> None:
diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py
index d485fe375efb7..ef76ae286738b 100644
--- a/ivy/functional/frontends/jax/nn/non_linear_activations.py
+++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py
@@ -294,7 +294,7 @@ def soft_sign(x):
@to_ivy_arrays_and_back
-def softmax(x, axis=-1):
+def softmax(x, axis=-1, where=None, initial=None):
return ivy.softmax(x, axis=axis)
diff --git a/ivy/functional/frontends/jax/numpy/creation.py b/ivy/functional/frontends/jax/numpy/creation.py
index cd0ee76ad2427..720957553a7e2 100644
--- a/ivy/functional/frontends/jax/numpy/creation.py
+++ b/ivy/functional/frontends/jax/numpy/creation.py
@@ -104,3 +104,18 @@ def full_like(a, fill_value, dtype=None, shape=None):
@to_ivy_arrays_and_back
def ndim(a):
return ivy.astype(ivy.array(a.ndim), ivy.int64)
+
+
+@handle_jax_dtype
+@to_ivy_arrays_and_back
+def empty_like(a, dtype=None, shape=None):
+ # XLA cannot create uninitialized arrays
+ # jax.numpy.empty_like returns an array initialized with zeros.
+ if shape:
+ return ivy.zeros(shape, dtype=dtype)
+ return ivy.zeros_like(a, dtype=dtype)
+
+
+@to_ivy_arrays_and_back
+def full(shape, fill_value, dtype=None):
+ return ivy.full(shape, fill_value, dtype=dtype)
diff --git a/ivy/functional/frontends/jax/numpy/logic.py b/ivy/functional/frontends/jax/numpy/logic.py
index 5861e7577fda8..52787ba575faf 100644
--- a/ivy/functional/frontends/jax/numpy/logic.py
+++ b/ivy/functional/frontends/jax/numpy/logic.py
@@ -148,6 +148,11 @@ def invert(x, /):
return ivy.bitwise_invert(x)
+@to_ivy_arrays_and_back
+def isfinite(x, /):
+ return ivy.isfinite(x)
+
+
@to_ivy_arrays_and_back
def isinf(x, /):
return ivy.isinf(x)
@@ -157,3 +162,8 @@ def isinf(x, /):
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
a, b = promote_jax_arrays(a, b)
return ivy.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+
+@to_ivy_arrays_and_back
+def logical_not(x, /):
+ return ivy.logical_not(x)
diff --git a/ivy/functional/frontends/jax/numpy/manipulations.py b/ivy/functional/frontends/jax/numpy/manipulations.py
index dbc0d8fda8007..a3260d77559ef 100644
--- a/ivy/functional/frontends/jax/numpy/manipulations.py
+++ b/ivy/functional/frontends/jax/numpy/manipulations.py
@@ -123,6 +123,16 @@ def atleast_3d(*arys):
return ivy.atleast_3d(*arys)
+@to_ivy_arrays_and_back
+def atleast_1d(*arys):
+ return ivy.atleast_1d(*arys)
+
+
@to_ivy_arrays_and_back
def atleast_2d(*arys):
return ivy.atleast_2d(*arys)
+
+
+@to_ivy_arrays_and_back
+def squeeze(a, axis=None):
+ return ivy.squeeze(a, axis)
diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py
index 9c249d303d46c..f023542991dd8 100644
--- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py
+++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py
@@ -21,6 +21,11 @@ def add(x1, x2):
return ivy.add(x1, x2)
+@to_ivy_arrays_and_back
+def diff(a, n=1, axis=-1, prepend=None, append=None):
+ return ivy.diff(a, n=n, axis=axis, prepend=prepend, append=append, out=None)
+
+
@to_ivy_arrays_and_back
def arctan(x):
ret = ivy.atan(x)
@@ -222,6 +227,14 @@ def negative(
return ivy.negative(x)
+@to_ivy_arrays_and_back
+def positive(
+ x,
+ /,
+):
+ return ivy.positive(x)
+
+
@to_ivy_arrays_and_back
def rad2deg(
x,
@@ -409,3 +422,9 @@ def hypot(x1, x2, /):
@to_ivy_arrays_and_back
def floor_divide(x1, x2, /, out=None):
return ivy.floor_divide(x1, x2, out=out)
+
+
+@to_ivy_arrays_and_back
+def inner(a, b):
+ a, b = promote_types_of_jax_inputs(a, b)
+ return ivy.inner(a, b)
diff --git a/ivy/functional/frontends/numpy/linalg/decompositions.py b/ivy/functional/frontends/numpy/linalg/decompositions.py
index e0b0e55db1fa2..0452e9c19a271 100644
--- a/ivy/functional/frontends/numpy/linalg/decompositions.py
+++ b/ivy/functional/frontends/numpy/linalg/decompositions.py
@@ -16,4 +16,4 @@ def qr(a, mode="reduced"):
@to_ivy_arrays_and_back
def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
# Todo: conpute_uv and hermitian handling
- return ivy.svd(a, full_matrices=full_matrices)
+ return ivy.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
diff --git a/ivy/functional/frontends/numpy/linalg/solving_equations_and_inverting_matrices.py b/ivy/functional/frontends/numpy/linalg/solving_equations_and_inverting_matrices.py
index 28148b87cc6a3..f177ad4413299 100644
--- a/ivy/functional/frontends/numpy/linalg/solving_equations_and_inverting_matrices.py
+++ b/ivy/functional/frontends/numpy/linalg/solving_equations_and_inverting_matrices.py
@@ -4,7 +4,7 @@
from ivy.func_wrapper import with_unsupported_dtypes
from ivy.functional.frontends.numpy import promote_types_of_numpy_inputs
-
+from ivy.functional.frontends.numpy.linalg.norms_and_other_numbers import matrix_rank
# solve
@with_unsupported_dtypes({"1.23.0 and below": ("float16",)}, "numpy")
@@ -44,3 +44,15 @@ def tensorinv(a, ind=2):
ia = ivy.inv(a)
new_shape = tuple([*invshape])
return ivy.reshape(ia, shape=new_shape)
+
+
+@to_ivy_arrays_and_back
+@with_unsupported_dtypes({"1.23.0 and below": ("float16",)}, "numpy")
+def lstsq(a, b, rcond="warn"):
+ solution = ivy.matmul(
+ ivy.pinv(a, rtol=1e-15).astype(ivy.float64), b.astype(ivy.float64)
+ )
+ svd = ivy.svd(a, compute_uv=False)
+ rank = matrix_rank(a).astype(ivy.int32)
+ residuals = ivy.sum((b - ivy.matmul(a, solution)) ** 2).astype(ivy.float64)
+ return (solution, residuals, rank, svd[0])
diff --git a/ivy/functional/frontends/numpy/mathematical_functions/arithmetic_operations.py b/ivy/functional/frontends/numpy/mathematical_functions/arithmetic_operations.py
index c6bd0fbf6c451..07e766c44c55c 100644
--- a/ivy/functional/frontends/numpy/mathematical_functions/arithmetic_operations.py
+++ b/ivy/functional/frontends/numpy/mathematical_functions/arithmetic_operations.py
@@ -152,9 +152,8 @@ def _float_power(
dtype=None,
subok=True,
):
- x1 = ivy.astype(x1, ivy.as_ivy_dtype("float64"))
x1, x2 = promote_types_of_numpy_inputs(x1, x2)
- ret = ivy.pow(x1, x2, out=out)
+ ret = ivy.float_power(x1, x2, out=out)
if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
return ret
diff --git a/ivy/functional/frontends/numpy/mathematical_functions/miscellaneous.py b/ivy/functional/frontends/numpy/mathematical_functions/miscellaneous.py
index b926a25713fad..41a095de03921 100644
--- a/ivy/functional/frontends/numpy/mathematical_functions/miscellaneous.py
+++ b/ivy/functional/frontends/numpy/mathematical_functions/miscellaneous.py
@@ -242,61 +242,7 @@ def real_if_close(a, tol=100):
@to_ivy_arrays_and_back
@from_zero_dim_arrays_to_scalar
def interp(x, xp, fp, left=None, right=None, period=None):
- x_arr = ivy.array(x)
- fix_later = False
- if x_arr.shape == ():
- x_arr = ivy.array([x])
- fix_later = True
- x = ivy.astype(x_arr, "float64")
- xp = ivy.astype(ivy.array(xp), "float64")
- fp = ivy.astype(ivy.array(fp), "float64")
- ivy.assertions.check_equal(xp.ndim, 1)
- ivy.assertions.check_equal(fp.ndim, 1)
- ivy.assertions.check_equal(xp.shape[0], fp.shape[0])
- if period is not None:
- ivy.assertions.check_equal(period, 0, inverse=True)
- period = ivy.abs(period)
- x = ivy.remainder(x, period)
- xp = ivy.remainder(xp, period)
- asort_xp = ivy.argsort(xp)
- xp = xp[asort_xp]
- fp = fp[asort_xp]
- xp = ivy.concat((xp[-1:] - period, xp, xp[0:1] + period))
- fp = ivy.concat((fp[-1:], fp, fp[0:1]))
-
- def interp_inner(value):
- if value < xp[0]:
- return left if left is not None else fp[0]
- elif value > xp[-1]:
- return right if right is not None else fp[-1]
- else:
- last = None
- if xp.shape[0] < 3:
- for i in range(xp.shape[0]):
- if xp[i] == value:
- return fp[i]
- elif xp[i] < value:
- last = i
- else:
- first = 0
- last = xp.shape[0]
- while first <= last:
- midpoint = (first + last) // 2
- if xp[midpoint] == value:
- return fp[midpoint]
- else:
- if value < xp[midpoint]:
- last = midpoint - 1
- else:
- first = midpoint + 1
- dist = (value - xp[last]) / (xp[last + 1] - xp[last])
- return (fp[last + 1] - fp[last]) * dist + fp[last]
-
- ret = ivy.map(interp_inner, unique={"value": x})
- if fix_later:
- return ivy.astype(ivy.array(ret[0]), "float64")
- else:
- return ivy.astype(ivy.array(ret), "float64")
+ return ivy.interp(x, xp, fp, left=left, right=right, period=period)
@handle_numpy_out
diff --git a/ivy/functional/frontends/numpy/mathematical_functions/sums_products_differences.py b/ivy/functional/frontends/numpy/mathematical_functions/sums_products_differences.py
index e757e7b0cf3a0..41a7be75612d1 100644
--- a/ivy/functional/frontends/numpy/mathematical_functions/sums_products_differences.py
+++ b/ivy/functional/frontends/numpy/mathematical_functions/sums_products_differences.py
@@ -116,3 +116,8 @@ def nancumprod(a, /, axis=None, dtype=None, out=None):
def nancumsum(a, /, axis=None, dtype=None, out=None):
a = ivy.where(ivy.isnan(a), ivy.zeros_like(a), a)
return ivy.cumsum(a, axis=axis, dtype=dtype, out=out)
+
+
+@to_ivy_arrays_and_back
+def diff(x, /, *, n=1, axis=-1, prepend=None, append=None):
+ return ivy.diff(x, n=n, axis=axis, prepend=prepend, append=append)
diff --git a/ivy/functional/frontends/numpy/ndarray/ndarray.py b/ivy/functional/frontends/numpy/ndarray/ndarray.py
index c3e1d6290321a..4578cb9f330c8 100644
--- a/ivy/functional/frontends/numpy/ndarray/ndarray.py
+++ b/ivy/functional/frontends/numpy/ndarray/ndarray.py
@@ -275,6 +275,9 @@ def std(
where=where,
)
+ def tobytes(self, order="C") -> bytes:
+ return np_frontend.tobytes(self.data, order=order)
+
def __add__(self, value, /):
return np_frontend.add(self._ivy_array, value)
@@ -290,6 +293,9 @@ def __mul__(self, value, /):
def __truediv__(self, value, /):
return np_frontend.true_divide(self._ivy_array, value)
+ def __floordiv__(self, value, /):
+ return np_frontend.floor_divide(self._ivy_array, value)
+
def __rtruediv__(self, value, /):
return np_frontend.true_divide(value, self._ivy_array)
@@ -385,6 +391,9 @@ def __imul__(self, value, /):
def __itruediv__(self, value, /):
return np_frontend.true_divide(self._ivy_array, value)
+ def __ifloordiv__(self, value, /):
+ return np_frontend.floor_divide(self._ivy_array, value, out=self)
+
def __ipow__(self, value, /):
return np_frontend.power(self._ivy_array, value)
@@ -403,6 +412,9 @@ def __imod__(self, value, /):
def __abs__(self):
return np_frontend.absolute(self._ivy_array)
+ def __array__(self, dtype, /):
+ return ivy.array(ivy.reshape(self._ivy_array, -1), dtype)[0]
+
def __getitem__(self, query):
ret = ivy.get_item(self._ivy_array, query)
return np_frontend.numpy_dtype_to_scalar[ivy.dtype(self._ivy_array)](
diff --git a/ivy/functional/frontends/tensorflow/compat/v1/nn.py b/ivy/functional/frontends/tensorflow/compat/v1/nn.py
index e69de29bb2d1d..a56fe2a747768 100644
--- a/ivy/functional/frontends/tensorflow/compat/v1/nn.py
+++ b/ivy/functional/frontends/tensorflow/compat/v1/nn.py
@@ -0,0 +1,74 @@
+# local
+import ivy
+from ivy.functional.frontends.tensorflow.func_wrapper import to_ivy_arrays_and_back
+from ivy.func_wrapper import with_supported_dtypes
+
+
+# should also support float16, but ivy.sqrt does not currently support it
+@to_ivy_arrays_and_back
+@with_supported_dtypes({"2.9.0 and below": ("float32",)}, "tensorflow")
+def fused_batch_norm(
+ x,
+ scale,
+ offset,
+ mean=None,
+ variance=None,
+ epsilon=1e-3,
+ data_format="NHWC",
+ is_training=True,
+ name=None,
+ exponential_avg_factor=1.0,
+):
+ min_epsilon = 1.001e-5
+ epsilon = epsilon if epsilon > min_epsilon else min_epsilon
+
+ dims = len(x.shape)
+ if data_format[1] == "C":
+ if dims == 4:
+ x = ivy.permute_dims(x, axes=(0, 2, 3, 1))
+ elif dims == 5:
+ x = ivy.permute_dims(x, axes=(0, 2, 3, 4, 1))
+ else:
+ raise ivy.exceptions.IvyException(
+ "input tensor must be of 4 or 5 dimensions, got {}".format(dims)
+ )
+
+ scale = scale.astype(ivy.float32)
+ offset = offset.astype(ivy.float32)
+ old_mean = mean.astype(ivy.float32)
+ old_var = variance.astype(ivy.float32)
+ x = x.astype(ivy.float32)
+
+ if is_training:
+ depth = x.shape[-1]
+ rest_size = ivy.prod(x.shape) // depth
+ x_rest_by_depth = ivy.reshape(x, [rest_size, depth])
+ mean = ivy.mean(x_rest_by_depth, axis=0, keepdims=True)
+ variance = ivy.var(x_rest_by_depth, axis=0, keepdims=True)
+ y = ivy.reshape(
+ scale * (x_rest_by_depth - mean) / ivy.sqrt(variance + epsilon) + offset,
+ x.shape,
+ )
+ variance = variance * rest_size / (rest_size - 1) if rest_size > 1 else variance
+ mean = ivy.reshape(
+ mean * exponential_avg_factor + old_mean * (1 - exponential_avg_factor),
+ old_mean.shape,
+ )
+ variance = ivy.reshape(
+ variance * exponential_avg_factor + old_var * (1 - exponential_avg_factor),
+ old_var.shape,
+ )
+ else:
+ y = scale * (x - old_mean) / ivy.sqrt(old_var + epsilon) + offset
+
+ # permute dimensions back
+ if data_format[1] == "C":
+ if dims == 4:
+ y = ivy.permute_dims(y, axes=(0, 3, 1, 2))
+ elif dims == 5:
+ y = ivy.permute_dims(y, axes=(0, 4, 1, 2, 3))
+
+ if is_training:
+ return y, mean, variance
+ else:
+ return y, old_mean, old_var
diff --git a/ivy/functional/frontends/tensorflow/general_functions.py b/ivy/functional/frontends/tensorflow/general_functions.py
index 921ece2b48b06..c6a32edbf3c7d 100644
--- a/ivy/functional/frontends/tensorflow/general_functions.py
+++ b/ivy/functional/frontends/tensorflow/general_functions.py
@@ -34,6 +34,23 @@ def clip_by_value(t, clip_value_min, clip_value_max):
return ivy.clip(t, clip_value_min, clip_value_max)
+@to_ivy_arrays_and_back
+def clip_by_norm(t, clip_norm, axes=None):
+ t = ivy.array(t)
+ l2sum = ivy.sum(t * t, axis=axes, keepdims=True)
+ pred = l2sum > 0
+
+ l2sum_safe = ivy.where(pred, l2sum, ivy.ones_like(l2sum))
+ l2norm = ivy.where(pred, ivy.sqrt(l2sum_safe), l2sum)
+ intermediate = t * clip_norm
+ assert t.shape == intermediate.shape, "Dimensions %s and %s are not compatible" % (
+ t.shape,
+ intermediate.shape,
+ )
+ t_clip = intermediate / ivy.maximum(l2norm, clip_norm)
+ return t_clip
+
+
@with_unsupported_dtypes({"2.9.0 and below": ("float16", "bfloat16")}, "tensorflow")
@handle_tf_dtype
@to_ivy_arrays_and_back
@@ -202,9 +219,13 @@ def boolean_mask(tensor, mask, axis=None, name=None):
k = ivy.get_num_dims(mask)
if axis < 0:
axis = n + axis
- ivy.assertions.check_less(k + axis, n, allow_equal=True,
- message="Value of axis must be \
- such that axis + dim(mask) <= dim(tensor)")
+ ivy.assertions.check_less(
+ k + axis,
+ n,
+ allow_equal=True,
+ message="Value of axis must be \
+ such that axis + dim(mask) <= dim(tensor)",
+ )
tensor_shape = ivy.shape(tensor)
for i in range(axis - 1, -1, -1):
mask = ivy.expand_dims(mask, axis=0)
diff --git a/ivy/functional/frontends/tensorflow/linalg.py b/ivy/functional/frontends/tensorflow/linalg.py
index 535633a89e926..3f1692258ac3d 100644
--- a/ivy/functional/frontends/tensorflow/linalg.py
+++ b/ivy/functional/frontends/tensorflow/linalg.py
@@ -189,7 +189,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None):
@to_ivy_arrays_and_back
def trace(x, name=None):
- return ivy.trace(x)
+ return ivy.trace(x, axis1=-2, axis2=-1)
@to_ivy_arrays_and_back
diff --git a/ivy/functional/frontends/torch/creation_ops.py b/ivy/functional/frontends/torch/creation_ops.py
index e571c55a5d282..26b9aef0f8ef9 100644
--- a/ivy/functional/frontends/torch/creation_ops.py
+++ b/ivy/functional/frontends/torch/creation_ops.py
@@ -20,7 +20,7 @@ def empty(
raise TypeError("empty() got multiple values for argument 'shape'")
if size is not None:
return ivy.empty(shape=size, dtype=dtype, device=device, out=out)
- size = args[0] if isinstance(args[0], tuple) else args
+ size = args[0] if isinstance(args[0], (tuple, list)) else args
return ivy.empty(shape=size, dtype=dtype, device=device, out=out)
@@ -47,7 +47,7 @@ def ones(*args, size=None, out=None, dtype=None, device=None, requires_grad=Fals
raise TypeError("ones() got multiple values for argument 'shape'")
if size is not None:
return ivy.ones(shape=size, dtype=dtype, device=device, out=out)
- size = args[0] if isinstance(args[0], tuple) else args
+ size = args[0] if isinstance(args[0], (tuple, list)) else args
return ivy.ones(shape=size, dtype=dtype, device=device, out=out)
@@ -81,7 +81,7 @@ def zeros(*args, size=None, out=None, dtype=None, device=None, requires_grad=Fal
raise TypeError("zeros() got multiple values for argument 'shape'")
if size is not None:
return ivy.zeros(shape=size, dtype=dtype, device=device, out=out)
- size = args[0] if isinstance(args[0], tuple) else args
+ size = args[0] if isinstance(args[0], (tuple, list)) else args
return ivy.zeros(shape=size, dtype=dtype, device=device, out=out)
diff --git a/ivy/functional/frontends/torch/indexing_slicing_joining_mutating_ops.py b/ivy/functional/frontends/torch/indexing_slicing_joining_mutating_ops.py
index 887dbfb0357f6..c88fb063d7945 100644
--- a/ivy/functional/frontends/torch/indexing_slicing_joining_mutating_ops.py
+++ b/ivy/functional/frontends/torch/indexing_slicing_joining_mutating_ops.py
@@ -208,5 +208,9 @@ def dsplit(input, indices_or_sections):
@to_ivy_arrays_and_back
+def hsplit(input, indices_or_sections):
+ return tuple(ivy.hsplit(input, indices_or_sections))
+
+
def row_stack(tensors, *, out=None):
return ivy.vstack(tensors, out=out)
diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py
index bc94e95f6661d..458a92a88eb3e 100644
--- a/ivy/functional/frontends/torch/linalg.py
+++ b/ivy/functional/frontends/torch/linalg.py
@@ -95,7 +95,7 @@ def matrix_norm(input, ord="fro", dim=(-2, -1), keepdim=False, *, dtype=None, ou
@to_ivy_arrays_and_back
@with_unsupported_dtypes({"1.11.0 and below": ("float16",)}, "torch")
def cross(input, other, *, dim=None, out=None):
- return torch_frontend.cross(input, other, dim=dim, out=out)
+ return torch_frontend.miscellaneous_ops.cross(input, other, dim=dim, out=out)
@to_ivy_arrays_and_back
diff --git a/ivy/functional/frontends/torch/nn/functional/convolution_functions.py b/ivy/functional/frontends/torch/nn/functional/convolution_functions.py
index 40262e2f0cff7..7e67e9ea13ca1 100644
--- a/ivy/functional/frontends/torch/nn/functional/convolution_functions.py
+++ b/ivy/functional/frontends/torch/nn/functional/convolution_functions.py
@@ -1,16 +1,12 @@
+# global
+import math
+
+# local
import ivy
from ivy.func_wrapper import with_unsupported_dtypes
from ivy.functional.frontends.torch.func_wrapper import to_ivy_arrays_and_back
-def _div_rtn(x, y):
- q = x / y
- r = x % y
- if (r != 0) and ((r < 0) != (y < 0)):
- q = q - 1
- return q
-
-
def _valid_shapes(input, weight, bias, stride, padding, groups, transpose=False):
in_channels = input.shape[1]
@@ -36,7 +32,7 @@ def _valid_shapes(input, weight, bias, stride, padding, groups, transpose=False)
stride, 1, message="padding cannot be 'same' for stride > 1"
)
else:
- for i in padding:
+ for i in stride:
ivy.assertions.check_equal(
i, 1, message="padding cannot be 'same' for stride > 1"
)
@@ -46,394 +42,283 @@ def _valid_shapes(input, weight, bias, stride, padding, groups, transpose=False)
ivy.assertions.check_equal(
in_channels,
in_channels_by_groups * groups,
- message="in_channels must be consistent",
+ message="in_channels must be consistent between input and weight",
)
else:
ivy.assertions.check_equal(
- in_channels, weight.shape[0], message="out_channels must be consistent"
+ in_channels,
+ weight.shape[0],
+ message="in_channels must be consistent between input and weight",
)
-@with_unsupported_dtypes(
- {
- "1.11.0 and below": (
- "float16",
- "bfloat16",
- )
- },
- "torch",
-)
-@to_ivy_arrays_and_back
-def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+def _conv(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ dims = len(input.shape) - 2
_valid_shapes(input, weight, bias, stride, padding, groups)
- if type(padding) == str:
+ if isinstance(padding, str):
padding = padding.upper()
else:
- _pad_w = padding if isinstance(padding, int) else padding[0]
- input = ivy.zero_pad(
- input,
- pad_width=[(0, 0), (0, 0), (_pad_w, _pad_w)],
- )
+ padding = [padding] * dims if isinstance(padding, int) else padding
+ pad_width = [(0, 0), (0, 0), *[(p, p) for p in padding]]
+ input = ivy.zero_pad(input, pad_width)
padding = "VALID"
- weight = ivy.permute_dims(weight, axes=(2, 1, 0))
+ weight = ivy.permute_dims(weight, axes=(*range(2, dims + 2), 1, 0))
ret = ivy.conv(
input,
weight,
stride,
padding,
+ dims=dims,
data_format="channel_first",
dilations=dilation,
feature_group_count=groups,
- dims=1,
)
-
if bias is not None:
- return ivy.add(ret, ivy.expand_dims(bias, axis=(0, 2)))
+ return ivy.add(ret, ivy.expand_dims(bias, axis=(0, *range(2, dims + 2))))
return ret
-@with_unsupported_dtypes(
- {
- "1.11.0 and below": (
- "float16",
- "bfloat16",
- )
- },
- "torch",
-)
+@with_unsupported_dtypes({"1.11.0 and below": ("float16", "bfloat16")}, "torch")
@to_ivy_arrays_and_back
-def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
- _valid_shapes(input, weight, bias, stride, padding, groups)
+def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ return _conv(
+ input,
+ weight,
+ bias=bias,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
- if isinstance(padding, str):
- padding = padding.upper()
- else:
- _pad_h, _pad_w = (
- (padding, padding) if isinstance(padding, int) else (padding[0], padding[1])
- )
- input = ivy.zero_pad(
- input, pad_width=[(0, 0), (0, 0), (_pad_h, _pad_h), (_pad_w, _pad_w)]
- )
- padding = "VALID"
- weight = ivy.permute_dims(weight, axes=(2, 3, 1, 0))
- ret = ivy.conv(
+@with_unsupported_dtypes({"1.11.0 and below": ("float16", "bfloat16")}, "torch")
+@to_ivy_arrays_and_back
+def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ return _conv(
input,
weight,
- stride,
- padding,
- data_format="channel_first",
- dilations=dilation,
- feature_group_count=groups,
+ bias=bias,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
)
- if bias is not None:
- return ivy.add(ret, ivy.expand_dims(bias, axis=(0, 2, 3)))
- return ret
-@with_unsupported_dtypes(
- {
- "1.11.0 and below": (
- "float16",
- "bfloat16",
- )
- },
- "torch",
-)
+@with_unsupported_dtypes({"1.11.0 and below": ("float16", "bfloat16")}, "torch")
@to_ivy_arrays_and_back
def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
- _valid_shapes(input, weight, bias, stride, padding, groups)
+ return _conv(
+ input,
+ weight,
+ bias=bias,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
- if isinstance(padding, str):
- padding = padding.upper()
- else:
- _pad_t, _pad_h, _pad_w = (
- (padding, padding, padding)
- if isinstance(padding, int)
- else (padding[0], padding[1], padding[2])
- )
- input = ivy.zero_pad(
- input,
- pad_width=[
- (0, 0),
- (0, 0),
- (_pad_t, _pad_t),
- (_pad_h, _pad_h),
- (_pad_w, _pad_w),
- ],
- )
- padding = "VALID"
- weight = ivy.permute_dims(weight, axes=(2, 3, 4, 1, 0))
- ret = ivy.conv(
+# ToDo: add support / debug non-default stride, padding, and output_padding
+def _conv_transpose(
+ input,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ output_padding=0,
+ groups=1,
+ dilation=1,
+):
+ dims = len(input.shape) - 2
+ _valid_shapes(input, weight, bias, stride, padding, groups, transpose=True)
+
+ padding = [padding] * dims if isinstance(padding, int) else list(padding)
+ paired_padding = [(padding[i], padding[i]) for i in reversed(range(len(padding)))]
+
+ weight = ivy.permute_dims(weight, axes=(*range(2, dims + 2), 0, 1))
+
+ ret = ivy.conv_general_transpose(
input,
weight,
stride,
- padding,
+ paired_padding,
+ dims=dims,
data_format="channel_first",
dilations=dilation,
feature_group_count=groups,
- dims=3,
)
if bias is not None:
- return ivy.add(ret, ivy.expand_dims(bias, axis=(0, 2, 3, 4)))
- return ret
+ ret = ivy.add(ret, ivy.expand_dims(bias, axis=(0, *range(2, dims + 2))))
-
-@with_unsupported_dtypes(
- {
- "1.11.0 and below": (
- "uint8",
- "integer",
- )
- },
- "torch",
-)
-@to_ivy_arrays_and_back
-def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
-
- kernel_size = ivy.repeat(ivy.asarray(kernel_size), 2)[:2]
- dilation = ivy.repeat(ivy.asarray(dilation), 2)[:2]
- padding = ivy.repeat(ivy.asarray(padding), 2)[:2]
- stride = ivy.repeat(ivy.asarray(stride), 2)[:2]
-
- kernel_height, kernel_width = kernel_size
- dilation_height, dilation_width = dilation
- pad_height, pad_width = padding
- stride_height, stride_width = stride
-
- ivy.assertions.check_true(
- kernel_width > 0 and kernel_height > 0,
- message="kernel size should be greater than zero",
- )
- ivy.assertions.check_true(
- dilation_width > 0 and dilation_height > 0,
- message="dilation should be greater than zero",
+ out_pad = (
+ [output_padding] * dims
+ if isinstance(output_padding, int)
+ else list(output_padding)
)
- ivy.assertions.check_true(
- pad_width >= 0 and pad_height >= 0, message="padding should be non-negative"
- )
- ivy.assertions.check_true(
- stride_width > 0 and stride_height > 0,
- message="stride should be greater than zero",
- )
-
- input = ivy.asarray(input)
- ndim = input.ndim
+ paired_out_pad = [(out_pad[i], out_pad[i]) for i in reversed(range(len(out_pad)))]
- valid_dims = input.shape[1] != 0 and input.shape[2] != 0
- ivy.assertions.check_true(
- (ndim == 3 and input.shape[0] != 0 and valid_dims)
- or (ndim == 4 and valid_dims and input.shape[3] != 0),
- message="expected 3D or 4D (batch mode) tensor "
- "with possibly 0 batch size "
- "and other non-zero dimensions for input",
- )
-
- dim_batch = 0
- if ndim == 3:
- dim_batch = -1
-
- input_height = input.shape[dim_batch + 2]
- input_width = input.shape[dim_batch + 3]
- output_height = int(
- _div_rtn(
- input_height + 2 * pad_height - (dilation_height * (kernel_height - 1) + 1),
- stride_height,
- )
- + 1
- )
- output_width = int(
- _div_rtn(
- input_width + 2 * pad_width - (dilation_width * (kernel_width - 1) + 1),
- stride_width,
- )
- + 1
- )
+ ret = ivy.zero_pad(ret, [(0, 0), (0, 0), *paired_out_pad])
+ return ret
- ivy.assertions.check_true(
- output_width >= 1 and output_height >= 1,
- message="calculated shape of the array " "of sliding blocks is non-positive",
- )
- batched_input = True
- if input.ndim == 3:
- batched_input = False
- input = ivy.reshape(input, (1, input.shape[0], input.shape[1], input.shape[2]))
-
- batch_size = input.shape[0]
- n_input_channels = input.shape[1]
- input_height = input.shape[2]
- input_width = input.shape[3]
- output_height = int(
- (input_height + 2 * pad_height - (dilation_height * (kernel_height - 1) + 1))
- / stride_height
- + 1
- )
- output_width = int(
- (input_width + 2 * pad_width - (dilation_width * (kernel_width - 1) + 1))
- / stride_width
- + 1
+@with_unsupported_dtypes({"1.11.0 and below": ("float16", "bfloat16")}, "torch")
+@to_ivy_arrays_and_back
+def conv_transpose1d(
+ input,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ output_padding=0,
+ groups=1,
+ dilation=1,
+):
+ return _conv_transpose(
+ input,
+ weight,
+ bias=bias,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
)
- n_output_channels = n_input_channels * kernel_width * kernel_height
- output_length = output_height * output_width
-
- output = ivy.zeros((batch_size, int(n_output_channels), output_length))
-
- height_col = output_height
- width_col = output_width
- channels_col = int(n_input_channels * kernel_height * kernel_width)
-
- for elt in range(batch_size):
- data_im = input[elt]
- data_col = output[elt]
-
- for c_col in range(channels_col):
- w_offset = c_col % kernel_width
- h_offset = int((c_col / kernel_width) % kernel_height)
- c_im = int(c_col / kernel_height / kernel_width)
-
- for h_col in range(height_col):
- h_im = h_col * stride_height - pad_height + h_offset * dilation_height
-
- for w_col in range(width_col):
- w_im = w_col * stride_width - pad_width + w_offset * dilation_width
-
- if 0 <= h_im < input_height and 0 <= w_im < input_width:
- data_col[h_col, c_col + w_col] = data_im[c_im, h_im, w_im]
-
- if not batched_input:
- output = ivy.squeeze(output, axis=0)
-
- return output
+@with_unsupported_dtypes({"1.11.0 and below": ("float16", "bfloat16")}, "torch")
@to_ivy_arrays_and_back
-def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
-
- output_size = ivy.repeat(ivy.asarray(output_size), 2)[:2]
- kernel_size = ivy.repeat(ivy.asarray(kernel_size), 2)[:2]
- dilation = ivy.repeat(ivy.asarray(dilation), 2)[:2]
- padding = ivy.repeat(ivy.asarray(padding), 2)[:2]
- stride = ivy.repeat(ivy.asarray(stride), 2)[:2]
-
- output_height, output_width = output_size
- kernel_height, kernel_width = kernel_size
- dilation_height, dilation_width = dilation
- pad_height, pad_width = padding
- stride_height, stride_width = stride
-
- ivy.assertions.check_true(
- output_width >= 1 or output_height >= 1,
- message="expected output spatial size to be positive",
- )
- ivy.assertions.check_true(
- kernel_width > 0 and kernel_height > 0,
- message="kernel size should be greater than zero",
- )
- ivy.assertions.check_true(
- stride_width > 0 and stride_height > 0,
- message="stride should be greater than zero",
- )
- ivy.assertions.check_true(
- dilation_width > 0 and dilation_height > 0,
- message="dilation should be greater than zero",
+def conv_transpose2d(
+ input,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ output_padding=0,
+ groups=1,
+ dilation=1,
+):
+ return _conv_transpose(
+ input,
+ weight,
+ bias=bias,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
)
- input = ivy.asarray(input)
- ndim = input.ndim
- ivy.assertions.check_true(
- (ndim == 2 and input.shape[0] and input.shape[1])
- or (ndim == 3 and input.shape[1] and input.shape[2]),
- message="expected 2D or 3D (batch mode) tensor "
- "with possibly 0 batch size and "
- "non-zero dimensions for input",
+@with_unsupported_dtypes({"1.11.0 and below": ("float16", "bfloat16")}, "torch")
+@to_ivy_arrays_and_back
+def conv_transpose3d(
+ input,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ output_padding=0,
+ groups=1,
+ dilation=1,
+):
+ return _conv_transpose(
+ input,
+ weight,
+ bias=bias,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
)
- dim_batch = 0
- if input.ndim == 3:
- dim_batch = -1
- n_input_channels = input.shape[dim_batch + 1]
-
- ivy.assertions.check_true(
- n_input_channels % (kernel_width * kernel_height) != 0,
- message="expected size of input's dimension 1 to be "
- "divisible by the product of kernel_size",
- )
+# ToDo: for both fold and unfold, the conversion to numpy and back to ivy can be removed
+# as soon as scatter_nd stops failing for jax and tensorflow when given slices.
- input_length = input.shape[dim_batch + 2]
- blocks_height = int(
- _div_rtn(
- output_height
- + 2 * pad_height
- - (dilation_height * (kernel_height - 1) - 1),
- stride_height,
- )
- + 1
- )
- blocks_width = int(
- _div_rtn(
- output_width + 2 * pad_width - (dilation_width * (kernel_width - 1) - 1),
- stride_width,
- )
+@to_ivy_arrays_and_back
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+ if input.ndim != 4:
+ raise ivy.exceptions.IvyException("only batched 4D inputs are supported")
+ stride = [stride] * 2 if isinstance(stride, int) else stride
+ dilation = [dilation] * 2 if isinstance(dilation, int) else dilation
+ padding = [padding] * 2 if isinstance(padding, int) else padding
+ kernel_size = [kernel_size] * 2 if isinstance(kernel_size, int) else kernel_size
+ output_shape = [
+ (input.shape[i + 2] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1)
+ // stride[i]
+ 1
+ for i in range(2)
+ ]
+ ret = ivy.zeros((*input.shape[0:2], *kernel_size, *output_shape), dtype=input.dtype)
+ input_padded = ivy.zero_pad(
+ input,
+ ((0, 0), (0, 0), (padding[0],) * 2, (padding[1],) * 2),
)
-
- ivy.assertions.check_true(
- input_length != (blocks_height * blocks_width),
- message="expected size of input's dimension 2 to "
- "match the calculated number of sliding blocks",
+ ret = ret.to_numpy()
+ input_padded = input_padded.to_numpy()
+ for i in range(output_shape[0]):
+ for j in range(output_shape[1]):
+ i_in = i * stride[0]
+ j_in = j * stride[1]
+ ret[:, :, :, :, i, j] = input_padded[
+ :,
+ :,
+ i_in : i_in + kernel_size[0] * dilation[0] : dilation[0],
+ j_in : j_in + kernel_size[1] * dilation[1] : dilation[1],
+ ]
+ return ivy.reshape(
+ ret, (input.shape[0], input.shape[1] * math.prod(kernel_size), -1)
)
- batched_input = True
- if input.ndim == 2:
- batched_input = False
- input = ivy.reshape(input, (1, input.shape[0], input.shape[1]))
-
- batch_size = input.shape[0]
- n_output_channels = int(n_input_channels / (kernel_width * kernel_height))
-
- output = ivy.zeros(
- (batch_size, n_output_channels, int(output_height), int(output_width))
- )
- height_col = int(
- (output_height + 2 * pad_height - (dilation_height * (kernel_height - 1) + 1))
- / stride_height
+@to_ivy_arrays_and_back
+def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
+ orig_ndim = input.ndim
+ if orig_ndim == 2:
+ input = ivy.expand_dims(input, axis=0)
+ elif orig_ndim != 3:
+ raise ivy.exceptions.IvyException("only 2D or batched 3D inputs are supported")
+ stride = [stride] * 2 if isinstance(stride, int) else stride
+ dilation = [dilation] * 2 if isinstance(dilation, int) else dilation
+ padding = [padding] * 2 if isinstance(padding, int) else padding
+ kernel_size = [kernel_size] * 2 if isinstance(kernel_size, int) else kernel_size
+ output_size = [output_size] * 2 if isinstance(output_size, int) else output_size
+ input_shape = [
+ (output_size[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1)
+ // stride[i]
+ 1
+ for i in range(2)
+ ]
+ n_batches = input.shape[0]
+ n_channels = input.shape[1] // math.prod(kernel_size)
+ output = ivy.zeros((n_batches, n_channels, *output_size), dtype=input.dtype)
+ output_padded = ivy.zero_pad(
+ output,
+ ((0, 0), (0, 0), (padding[0],) * 2, (padding[1],) * 2),
)
- width_col = int(
- (output_width + 2 * pad_width - (dilation_width * (kernel_width - 1) + 1))
- / stride_width
- + 1
+ output_padded = ivy.to_numpy(output_padded)
+ k = 0
+ for i in range(input_shape[0]):
+ for j in range(input_shape[1]):
+ i_in = i * stride[0]
+ j_in = j * stride[1]
+ patch = ivy.to_numpy(
+ input[:, :, k].reshape((n_batches, n_channels, *kernel_size))
+ )
+ output_padded[
+ :,
+ :,
+ i_in : i_in + kernel_size[0] * dilation[0] : dilation[0],
+ j_in : j_in + kernel_size[1] * dilation[1] : dilation[1],
+ ] += patch
+ k += 1
+ return ivy.array(
+ output_padded[:, :, padding[0] : -padding[0], padding[1] : -padding[1]]
)
- channels_col = int(n_output_channels * kernel_height * kernel_width)
-
- for elt in range(batch_size):
- data_col = input[elt]
- data_im = output[elt]
-
- for c_col in range(channels_col):
- w_offset = c_col % kernel_width
- h_offset = int((c_col / kernel_width) % kernel_height)
- c_im = int(c_col / kernel_height / kernel_width)
-
- for h_col in range(height_col):
- h_im = h_col * stride_height - pad_height + h_offset * dilation_height
-
- for w_col in range(width_col):
- w_im = w_col * stride_width - pad_width + w_offset * dilation_width
-
- if 0 <= h_im < output_height and 0 <= w_im < output_width:
- data_im[c_im, h_im, w_im] += data_col[h_col, c_col + w_col]
-
- if not batched_input:
- output = ivy.squeeze(output, axis=0)
-
- return output
diff --git a/ivy/functional/frontends/torch/nn/functional/loss_functions.py b/ivy/functional/frontends/torch/nn/functional/loss_functions.py
index bb187808bc070..af889c8549cb1 100644
--- a/ivy/functional/frontends/torch/nn/functional/loss_functions.py
+++ b/ivy/functional/frontends/torch/nn/functional/loss_functions.py
@@ -1,5 +1,6 @@
# global
import ivy
+import ivy.functional.frontends.torch as torch_frontend
from ivy.functional.frontends.torch.func_wrapper import to_ivy_arrays_and_back
from ivy.func_wrapper import with_unsupported_dtypes
@@ -127,6 +128,77 @@ def binary_cross_entropy_with_logits(
return result
+@to_ivy_arrays_and_back
+@with_unsupported_dtypes({"1.11.0 and below": ("float16", "bfloat16")}, "torch")
+def cosine_embedding_loss(
+ input1,
+ input2,
+ target,
+ margin=0.0,
+ size_average=None,
+ reduce=None,
+ reduction="mean"
+):
+ def norm(input, axis):
+ return ivy.sqrt(ivy.sum(ivy.square(input), axis=axis))
+
+ def cosine_similarity(x1, x2):
+ axis = None
+ if len(x1.shape) == len(x2.shape) and len(x2.shape) == 2:
+ axis = 1
+ input1_norm = norm(x1, axis=axis)
+ input2_norm = norm(x2, axis=axis)
+ norm_mm = input1_norm * input2_norm
+ norm_mm, eps = torch_frontend.promote_types_of_torch_inputs(norm_mm, 1e-08)
+ return ivy.sum(x1 * x2, axis=axis) / ivy.maximum(norm_mm, eps)
+
+ def calculate_loss(x1, x2, target):
+ cos = cosine_similarity(x1, x2)
+ if target == ivy.array(1.0):
+ loss = 1.0 - cos
+ elif target == ivy.array(-1.0):
+ loss = ivy.maximum(ivy.array(0.0), cos - ivy.array(margin))
+ else:
+ _, zero = torch_frontend.promote_types_of_torch_inputs(input1,
+ ivy.array(0.0))
+ return zero
+
+ return loss
+
+ ivy.assertions.check_true(
+ target.ndim + 1 == input1.ndim and target.ndim + 1 == input2.ndim,
+ "{}D target tensor expects {}D input tensors, but "
+ "found inputs with sizes {} and {}.".format(
+ target.ndim, target.ndim + 1, list(input1.shape), list(input2.shape)
+ )
+ )
+
+ ivy.assertions.check_true(
+ target.ndim < 2,
+ "0D or 1D target tensor expected, multi-target not supported"
+ )
+
+ ivy.assertions.check_shape(input1, input2)
+
+ if target.ndim == 1:
+ ivy.assertions.check_true(
+ target.shape[0] == input1.shape[0],
+ "The size of target tensor ({}) must match the size of input tensor ({}) "
+ "at non-singleton dimension 0 ".format(
+ target.shape[0], input1.shape[0])
+ )
+
+ if target.ndim == 0:
+ loss = calculate_loss(input1, input2, target)
+ else:
+ loss = ivy.array([calculate_loss(input1[i], input2[i], target[i])
+ for i in range(input1.shape[0])])
+
+ reduction = _get_reduction(reduction, size_average, reduce)
+ loss = reduction(loss)
+ return loss
+
+
@to_ivy_arrays_and_back
def mse_loss(input, target, size_average=None, reduce=None, reduction="mean"):
reduction = _get_reduction(reduction, size_average, reduce)
diff --git a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py
index 377bc56e38ce2..97f5090fd57bd 100644
--- a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py
+++ b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py
@@ -1,4 +1,5 @@
import ivy
+from ivy import with_unsupported_dtypes
from ivy.functional.frontends.tensorflow.func_wrapper import (
to_ivy_arrays_and_back,
)
@@ -66,3 +67,35 @@ def avg_pool2d(
padding_str,
data_format=data_format,
)
+
+
+@with_unsupported_dtypes({"1.11.0 and below": ("float16",)}, "torch")
+@to_ivy_arrays_and_back
+def max_pool2d(
+ input,
+ kernel_size,
+ stride=None,
+ padding=0,
+ dilation=1,
+ ceil_mode=False,
+ return_indices=False,
+):
+ # ToDo: Add return_indices once superset is implemented
+ dim_check = False
+ if input.ndim == 3:
+ input = input.expand_dims()
+ dim_check = True
+ if not stride:
+ stride = kernel_size
+ ret = ivy.max_pool2d(
+ input,
+ kernel_size,
+ stride,
+ padding,
+ data_format="NCHW",
+ dilation=dilation,
+ ceil_mode=ceil_mode,
+ )
+ if dim_check:
+ return ret.squeeze(0)
+ return ret
diff --git a/ivy/functional/frontends/torch/reduction_ops.py b/ivy/functional/frontends/torch/reduction_ops.py
index ecfabd131b1bd..383608e1eaaa8 100644
--- a/ivy/functional/frontends/torch/reduction_ops.py
+++ b/ivy/functional/frontends/torch/reduction_ops.py
@@ -145,3 +145,20 @@ def var_mean(input, dim, unbiased, keepdim=False, *, out=None):
)
temp_mean = ivy.mean(input, axis=dim, keepdims=keepdim, out=out)
return (temp_var, temp_mean)
+
+
+@to_ivy_arrays_and_back
+def aminmax(input, *, dim=None, keepdim=False, out=None):
+ minmax_tuple = namedtuple("minmax", ["min", "max"])
+ return minmax_tuple(
+ ivy.min(input, axis=dim, keepdims=keepdim, out=out),
+ ivy.max(input, axis=dim, keepdims=keepdim, out=out),
+ )
+
+
+aminmax.unsupported_dtypes = {
+ "torch": ("float16", "bfloat16"),
+ "numpy": ("float16", "bfloat16"),
+ "jax": ("float16", "bfloat16"),
+ "tensorflow": ("float16", "bfloat16"),
+}
diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py
index 3d6871ca6a9a3..8ce39e5f466bf 100644
--- a/ivy/functional/frontends/torch/tensor.py
+++ b/ivy/functional/frontends/torch/tensor.py
@@ -54,7 +54,7 @@ def reshape(self, *args, shape=None):
if shape is not None:
return torch_frontend.reshape(self._ivy_array, shape)
if args:
- if isinstance(args[0], tuple):
+ if isinstance(args[0], (tuple, list)):
shape = args[0]
return torch_frontend.reshape(self._ivy_array, shape)
else:
@@ -321,11 +321,11 @@ def arctan(self):
def arctan_(self):
self._ivy_array = self.arctan().ivy_array
return self
-
+
@with_unsupported_dtypes({"1.11.0 and below": ("float16", "bfloat16")}, "torch")
def arctan2(self, other):
return torch_frontend.arctan2(self._ivy_array, other)
-
+
@with_unsupported_dtypes({"1.11.0 and below": ("float16",)}, "torch")
def acos(self):
return torch_frontend.acos(self._ivy_array)
@@ -362,15 +362,27 @@ def new_tensor(
def view_as(self, other):
return self.view(other.shape)
- def expand(self, *sizes):
+ def expand(self, *args, size=None):
+ if args and size:
+ raise TypeError("expand() got multiple values for argument 'size'")
+ if args:
+ if isinstance(args[0], (tuple, list)):
+ size = args[0]
+ else:
+ size = args
- sizes = list(sizes)
- for i, dim in enumerate(sizes):
+ size = list(size)
+ for i, dim in enumerate(size):
if dim < 0:
- sizes[i] = self.shape[i]
+ size[i] = self.shape[i]
return torch_frontend.tensor(
- ivy.broadcast_to(self._ivy_array, shape=tuple(sizes))
+ ivy.broadcast_to(self._ivy_array, shape=tuple(size))
+ )
+
+ def expand_as(self, other):
+ return self.expand(
+ ivy.shape(other.ivy_array if isinstance(other, Tensor) else other)
)
def detach(self):
@@ -449,7 +461,7 @@ def pow_(self, exponent):
return self
def size(self, dim=None):
- shape = ivy.shape(self._ivy_array, as_array=True)
+ shape = ivy.shape(self._ivy_array)
if dim is None:
return shape
else:
@@ -489,7 +501,7 @@ def permute(self, *args, dims=None):
if dims is not None:
return torch_frontend.permute(self._ivy_array, dims)
if args:
- if isinstance(args[0], tuple):
+ if isinstance(args[0], (tuple, list)):
dims = args[0]
return torch_frontend.permute(self._ivy_array, dims)
else:
@@ -573,6 +585,11 @@ def clamp(self, min=None, max=None, *, out=None):
return torch_frontend.tensor(ivy.array(self._ivy_array).full_like(max))
return torch_frontend.clamp(self._ivy_array, min=min, max=max, out=out)
+ @with_unsupported_dtypes({"1.11.0 and below": ("bfloat16", "float16")}, "torch")
+ def clamp_(self, min=None, max=None, *, out=None):
+ self._ivy_array = self.clamp(min=min, max=max, out=out).ivy_array
+ return self
+
@with_unsupported_dtypes({"1.11.0 and below": ("float16", "bfloat16")}, "torch")
def sqrt(self):
return torch_frontend.sqrt(self._ivy_array)
@@ -624,7 +641,9 @@ def __mod__(self, other):
return torch_frontend.remainder(self._ivy_array, other)
def __long__(self, memory_format=None):
- return torch_frontend.tensor(ivy.astype(self._ivy_array, ivy.int64))
+ cast_tensor = self.clone()
+ cast_tensor.ivy_array = ivy.astype(self._ivy_array, ivy.int64)
+ return cast_tensor
def __getitem__(self, query):
ret = ivy.get_item(self._ivy_array, query)
diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py
index 2b89740385669..3f56b3c93a5a9 100644
--- a/ivy/functional/ivy/activations.py
+++ b/ivy/functional/ivy/activations.py
@@ -7,6 +7,7 @@
import ivy
from ivy.backend_handler import current_backend
from ivy.func_wrapper import (
+ handle_array_function,
handle_out_argument,
to_native_arrays_and_back,
handle_nestable,
@@ -100,6 +101,7 @@ def deserialize(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def gelu(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -207,6 +209,7 @@ def get(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def leaky_relu(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -271,6 +274,7 @@ def leaky_relu(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def log_softmax(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -341,6 +345,7 @@ def log_softmax(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def relu(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
) -> ivy.Array:
@@ -394,6 +399,7 @@ def relu(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def sigmoid(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
) -> ivy.Array:
@@ -435,6 +441,7 @@ def sigmoid(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def softmax(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -483,6 +490,7 @@ def softmax(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def softplus(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -539,6 +547,7 @@ def softplus(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def mish(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
) -> ivy.Array:
diff --git a/ivy/functional/ivy/creation.py b/ivy/functional/ivy/creation.py
index f423dfd0979f4..0af916bf58b06 100644
--- a/ivy/functional/ivy/creation.py
+++ b/ivy/functional/ivy/creation.py
@@ -11,6 +11,7 @@
from ivy.backend_handler import current_backend
from ivy.exceptions import handle_exceptions
from ivy.func_wrapper import (
+ handle_array_function,
infer_device,
infer_dtype,
handle_out_argument,
@@ -156,7 +157,8 @@ def __len__(self, /) -> int:
@handle_out_argument
@infer_device
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
def arange(
start: Number,
/,
@@ -224,7 +226,8 @@ def arange(
@handle_out_argument
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
def asarray(
obj: Union[
ivy.Array,
@@ -285,7 +288,8 @@ def asarray(
@infer_dtype
@infer_device
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
def zeros(
shape: Union[ivy.Shape, ivy.NativeShape],
*,
@@ -344,7 +348,8 @@ def zeros(
@infer_dtype
@infer_device
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
def ones(
shape: Union[ivy.Shape, ivy.NativeShape],
*,
@@ -428,7 +433,8 @@ def ones(
@infer_device
@infer_dtype
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
@handle_array_like_without_promotion
def full_like(
x: Union[ivy.Array, ivy.NativeArray],
@@ -533,7 +539,8 @@ def full_like(
@infer_dtype
@infer_device
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
@handle_array_like_without_promotion
def ones_like(
x: Union[ivy.Array, ivy.NativeArray],
@@ -650,7 +657,8 @@ def ones_like(
@infer_dtype
@infer_device
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
@handle_array_like_without_promotion
def zeros_like(
x: Union[ivy.Array, ivy.NativeArray],
@@ -765,7 +773,8 @@ def zeros_like(
@to_native_arrays_and_back
@handle_out_argument
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
@handle_array_like_without_promotion
def tril(
x: Union[ivy.Array, ivy.NativeArray],
@@ -813,7 +822,8 @@ def tril(
@to_native_arrays_and_back
@handle_out_argument
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
@handle_array_like_without_promotion
def triu(
x: Union[ivy.Array, ivy.NativeArray],
@@ -863,7 +873,8 @@ def triu(
@infer_dtype
@infer_device
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
def empty(
shape: Union[ivy.Shape, ivy.NativeShape],
*,
@@ -910,7 +921,8 @@ def empty(
@infer_dtype
@infer_device
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
@handle_array_like_without_promotion
def empty_like(
x: Union[ivy.Array, ivy.NativeArray],
@@ -960,7 +972,8 @@ def empty_like(
@infer_dtype
@infer_device
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
def eye(
n_rows: int,
n_cols: Optional[int] = None,
@@ -1105,7 +1118,8 @@ def eye(
@infer_dtype
@infer_device
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
def linspace(
start: Union[ivy.Array, ivy.NativeArray, float],
stop: Union[ivy.Array, ivy.NativeArray, float],
@@ -1210,7 +1224,8 @@ def linspace(
@to_native_arrays_and_back
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
def meshgrid(
*arrays: Union[ivy.Array, ivy.NativeArray],
sparse: bool = False,
@@ -1324,7 +1339,8 @@ def meshgrid(
@handle_out_argument
@infer_device
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
def full(
shape: Union[ivy.Shape, ivy.NativeShape],
fill_value: Union[float, bool],
@@ -1429,7 +1445,8 @@ def full(
@to_native_arrays_and_back
@handle_out_argument
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
def from_dlpack(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
) -> ivy.Array:
@@ -1479,7 +1496,8 @@ def from_dlpack(
@inputs_to_native_arrays
@handle_out_argument
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
def copy_array(
x: Union[ivy.Array, ivy.NativeArray],
*,
@@ -1583,7 +1601,7 @@ def copy_array(
return current_backend(x).copy_array(x, to_ivy_array=to_ivy_array, out=out)
-@handle_exceptions
+@handle_array_like_without_promotion
def native_array(
x: Union[ivy.Array, ivy.NativeArray, List[Number], Tuple[Number], np.ndarray],
/,
@@ -1620,7 +1638,8 @@ def native_array(
@handle_out_argument
@infer_device
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
@handle_array_like_without_promotion
def one_hot(
indices: Union[ivy.Array, ivy.NativeArray],
@@ -1733,7 +1752,8 @@ def one_hot(
@infer_dtype
@infer_device
@handle_nestable
-@handle_exceptions
+@handle_array_like_without_promotion
+@handle_array_function
@handle_array_like_without_promotion
def logspace(
start: Union[ivy.Array, ivy.NativeArray, float],
diff --git a/ivy/functional/ivy/data_type.py b/ivy/functional/ivy/data_type.py
index 029e18295f6e9..9f6f52ad02b24 100644
--- a/ivy/functional/ivy/data_type.py
+++ b/ivy/functional/ivy/data_type.py
@@ -11,6 +11,7 @@
import ivy
from ivy.backend_handler import current_backend
from ivy.func_wrapper import (
+ handle_array_function,
handle_out_argument,
to_native_arrays_and_back,
inputs_to_native_arrays,
@@ -223,6 +224,7 @@ def _get_dtypes(fn, complement=True):
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def astype(
x: Union[ivy.Array, ivy.NativeArray],
dtype: Union[ivy.Dtype, ivy.NativeDtype],
@@ -328,6 +330,7 @@ def astype(
@to_native_arrays_and_back
@handle_nestable
@handle_exceptions
+@handle_array_function
def broadcast_arrays(*arrays: Union[ivy.Array, ivy.NativeArray]) -> List[ivy.Array]:
"""Broadcasts one or more arrays against one another.
@@ -405,6 +408,7 @@ def broadcast_arrays(*arrays: Union[ivy.Array, ivy.NativeArray]) -> List[ivy.Arr
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def broadcast_to(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -474,6 +478,7 @@ def broadcast_to(
@inputs_to_ivy_arrays
@handle_nestable
@handle_exceptions
+@handle_array_function
def can_cast(
from_: Union[ivy.Dtype, ivy.Array, ivy.NativeArray],
to: ivy.Dtype,
diff --git a/ivy/functional/ivy/elementwise.py b/ivy/functional/ivy/elementwise.py
index d5727f834bb3b..cdce8c301cfcc 100644
--- a/ivy/functional/ivy/elementwise.py
+++ b/ivy/functional/ivy/elementwise.py
@@ -5,6 +5,7 @@
# local
import ivy
from ivy.func_wrapper import (
+ handle_array_function,
handle_out_argument,
to_native_arrays_and_back,
handle_nestable,
@@ -23,6 +24,8 @@
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
+@handle_array_function
def abs(
x: Union[float, ivy.Array, ivy.NativeArray],
/,
@@ -108,6 +111,7 @@ def abs(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def acos(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -186,6 +190,7 @@ def acos(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def acosh(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -265,6 +270,7 @@ def acosh(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def add(
x1: Union[float, ivy.Array, ivy.NativeArray],
x2: Union[float, ivy.Array, ivy.NativeArray],
@@ -385,6 +391,7 @@ def add(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def asin(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -468,6 +475,7 @@ def asin(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def asinh(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -555,6 +563,7 @@ def asinh(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def atan(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -636,6 +645,7 @@ def atan(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def atan2(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -809,6 +819,7 @@ def atan2(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def atanh(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -877,6 +888,7 @@ def atanh(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def bitwise_and(
x1: Union[int, bool, ivy.Array, ivy.NativeArray],
x2: Union[int, bool, ivy.Array, ivy.NativeArray],
@@ -968,6 +980,7 @@ def bitwise_and(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def bitwise_invert(
x: Union[int, bool, ivy.Array, ivy.NativeArray, ivy.Container],
/,
@@ -1040,6 +1053,7 @@ def bitwise_invert(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def bitwise_left_shift(
x1: Union[int, ivy.Array, ivy.NativeArray],
x2: Union[int, ivy.Array, ivy.NativeArray],
@@ -1087,6 +1101,7 @@ def bitwise_left_shift(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def bitwise_or(
x1: Union[int, bool, ivy.Array, ivy.NativeArray],
x2: Union[int, bool, ivy.Array, ivy.NativeArray],
@@ -1173,6 +1188,7 @@ def bitwise_or(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def bitwise_right_shift(
x1: Union[int, ivy.Array, ivy.NativeArray],
x2: Union[int, ivy.Array, ivy.NativeArray],
@@ -1285,6 +1301,7 @@ def bitwise_right_shift(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def bitwise_xor(
x1: Union[int, bool, ivy.Array, ivy.NativeArray],
x2: Union[int, bool, ivy.Array, ivy.NativeArray],
@@ -1388,6 +1405,7 @@ def bitwise_xor(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def ceil(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1475,6 +1493,7 @@ def ceil(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def cos(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1555,6 +1574,7 @@ def cos(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def cosh(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1638,6 +1658,7 @@ def cosh(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def divide(
x1: Union[float, ivy.Array, ivy.NativeArray],
x2: Union[float, ivy.Array, ivy.NativeArray],
@@ -1722,6 +1743,7 @@ def divide(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def equal(
x1: Union[float, ivy.Array, ivy.NativeArray, ivy.Container],
x2: Union[float, ivy.Array, ivy.NativeArray, ivy.Container],
@@ -1809,6 +1831,7 @@ def equal(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def exp(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1862,6 +1885,7 @@ def exp(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def expm1(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1944,6 +1968,7 @@ def expm1(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def floor(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -2031,6 +2056,7 @@ def floor(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def floor_divide(
x1: Union[float, ivy.Array, ivy.NativeArray],
x2: Union[float, ivy.Array, ivy.NativeArray],
@@ -2117,6 +2143,7 @@ def floor_divide(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def greater(
x1: Union[float, ivy.Array, ivy.NativeArray],
x2: Union[float, ivy.Array, ivy.NativeArray],
@@ -2212,6 +2239,7 @@ def greater(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def greater_equal(
x1: Union[float, ivy.Array, ivy.NativeArray],
x2: Union[float, ivy.Array, ivy.NativeArray],
@@ -2310,6 +2338,7 @@ def greater_equal(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def less_equal(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -2388,6 +2417,7 @@ def less_equal(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def multiply(
x1: Union[float, ivy.Array, ivy.NativeArray],
x2: Union[float, ivy.Array, ivy.NativeArray],
@@ -2483,6 +2513,7 @@ def multiply(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def isfinite(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -2560,6 +2591,7 @@ def isfinite(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def isinf(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -2665,6 +2697,7 @@ def isinf(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def isnan(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -2756,6 +2789,7 @@ def isnan(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def less(
x1: Union[float, ivy.Array, ivy.NativeArray],
x2: Union[float, ivy.Array, ivy.NativeArray],
@@ -2843,6 +2877,7 @@ def less(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def log(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -2916,6 +2951,7 @@ def log(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def log10(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -2999,6 +3035,7 @@ def log10(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def log1p(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -3089,6 +3126,7 @@ def log1p(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def log2(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -3143,6 +3181,7 @@ def log2(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def logaddexp(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -3245,6 +3284,7 @@ def logaddexp(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def logical_and(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -3343,6 +3383,7 @@ def logical_and(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def logical_not(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -3435,6 +3476,7 @@ def logical_not(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def logical_or(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -3522,6 +3564,7 @@ def logical_or(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def logical_xor(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -3618,6 +3661,7 @@ def logical_xor(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def negative(
x: Union[float, ivy.Array, ivy.NativeArray],
/,
@@ -3690,6 +3734,7 @@ def negative(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def not_equal(
x1: Union[float, ivy.Array, ivy.NativeArray, ivy.Container],
x2: Union[float, ivy.Array, ivy.NativeArray, ivy.Container],
@@ -3847,6 +3892,7 @@ def not_equal(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def positive(
x: Union[float, ivy.Array, ivy.NativeArray],
/,
@@ -3920,6 +3966,7 @@ def positive(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def pow(
x1: Union[float, ivy.Array, ivy.NativeArray],
x2: Union[float, ivy.Array, ivy.NativeArray],
@@ -4054,6 +4101,7 @@ def pow(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def remainder(
x1: Union[float, ivy.Array, ivy.NativeArray],
x2: Union[float, ivy.Array, ivy.NativeArray],
@@ -4180,6 +4228,7 @@ def remainder(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def round(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -4275,6 +4324,7 @@ def round(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def sign(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -4352,6 +4402,7 @@ def sign(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def sin(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -4437,6 +4488,7 @@ def sin(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def sinh(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -4515,6 +4567,7 @@ def sinh(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def sqrt(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -4597,6 +4650,7 @@ def sqrt(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def square(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -4665,6 +4719,7 @@ def square(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def subtract(
x1: Union[float, ivy.Array, ivy.NativeArray],
x2: Union[float, ivy.Array, ivy.NativeArray],
@@ -4727,6 +4782,7 @@ def subtract(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def tan(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -4813,6 +4869,7 @@ def tan(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def tanh(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -4898,6 +4955,7 @@ def tanh(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def trunc(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -4986,6 +5044,7 @@ def trunc(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def erf(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -5020,6 +5079,7 @@ def erf(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def maximum(
x1: Union[ivy.Array, ivy.NativeArray, Number],
x2: Union[ivy.Array, ivy.NativeArray, Number],
@@ -5107,6 +5167,7 @@ def maximum(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def minimum(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -5196,6 +5257,7 @@ def minimum(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def reciprocal(
x: Union[float, ivy.Array, ivy.NativeArray],
/,
@@ -5232,6 +5294,7 @@ def reciprocal(
@handle_out_argument
@handle_nestable
@handle_array_like_without_promotion
+@handle_array_function
def deg2rad(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -5309,6 +5372,7 @@ def deg2rad(
@handle_out_argument
@handle_nestable
@handle_array_like_without_promotion
+@handle_array_function
def rad2deg(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -5383,6 +5447,7 @@ def rad2deg(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def trunc_divide(
x1: Union[float, ivy.Array, ivy.NativeArray],
x2: Union[float, ivy.Array, ivy.NativeArray],
@@ -5431,6 +5496,7 @@ def trunc_divide(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def isreal(
x: Union[ivy.Array, ivy.NativeArray],
/,
diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py
index 57dd8d6e3c3c9..45bfe95d94e9a 100644
--- a/ivy/functional/ivy/experimental/activations.py
+++ b/ivy/functional/ivy/experimental/activations.py
@@ -78,7 +78,7 @@ def prelu(
f(x) = slope * x for x < 0, f(x) = x for x >= 0., is applied
to the data array elementwise. This operator supports unidirectional
broadcasting (array slope should be unidirectional broadcastable to
- input tensor X); for more details please check Broadcasting in ONNX.
+ input tensor X).
Parameters
----------
@@ -97,7 +97,10 @@ def prelu(
"""
try:
return ivy.where(x > 0, x, x * slope, out=out)
- except ValueError as e:
+ except ivy.exceptions.IvyError(
+ f"The shape {slope.shape} is not Unidirectional Broadcastable\n"
+ f"as per ONNX standards"
+ ) as IvyException:
if len(slope.shape) == 1:
dim = slope.shape[0]
new_shape = []
@@ -111,7 +114,7 @@ def prelu(
if n == 1:
xs = x * slope.reshape(tuple(new_shape), out=out)
return ivy.where(x > 0, x, xs, out=out)
- raise e
+ raise IvyException
@to_native_arrays_and_back
diff --git a/ivy/functional/ivy/experimental/elementwise.py b/ivy/functional/ivy/experimental/elementwise.py
index a384be6a7b13d..c5552591efab8 100644
--- a/ivy/functional/ivy/experimental/elementwise.py
+++ b/ivy/functional/ivy/experimental/elementwise.py
@@ -899,11 +899,11 @@ def hypot(
@handle_nestable
@handle_array_like_without_promotion
def diff(
- x: Union[ivy.Array, ivy.NativeArray, int, list, tuple],
+ x: Union[ivy.Array, ivy.NativeArray, list, tuple],
/,
*,
- n: Optional[int] = 1,
- axis: Optional[int] = -1,
+ n: int = 1,
+ axis: int = -1,
prepend: Optional[Union[ivy.Array, ivy.NativeArray, int, list, tuple]] = None,
append: Optional[Union[ivy.Array, ivy.NativeArray, int, list, tuple]] = None,
out: Optional[ivy.Array] = None,
@@ -923,14 +923,18 @@ def diff(
Values to prepend/append to x along given axis prior to performing the
difference. Scalar values are expanded to arrays with length 1 in the direction
of axis and the shape of the input array in along all other axes. Otherwise the
- dimension and shape must match a except along axis.
+ dimension and shape must match x except along axis.
out
optional output array, for writing the result to.
Returns
-------
ret
- Rreturns the n-th discrete difference along the given axis.
+ Returns the n-th discrete difference along the given axis.
+
+ Both the description and the type hints above assume an array input for simplicity,
+ but this function is *nestable*, and therefore also accepts :class:`ivy.Container`
+ instances in place of any of the arguments.
Examples
--------
diff --git a/ivy/functional/ivy/experimental/layers.py b/ivy/functional/ivy/experimental/layers.py
index 4088e298e5246..aec9d13aa5f54 100644
--- a/ivy/functional/ivy/experimental/layers.py
+++ b/ivy/functional/ivy/experimental/layers.py
@@ -1,4 +1,9 @@
-from typing import Optional, Union, Tuple, Literal
+# global
+import math
+from typing import Optional, Union, Tuple, Literal, Sequence
+
+
+# local
import ivy
from ivy.func_wrapper import (
handle_array_like_without_promotion,
@@ -79,10 +84,12 @@ def max_pool2d(
x: Union[ivy.Array, ivy.NativeArray],
kernel: Union[int, Tuple[int], Tuple[int, int]],
strides: Union[int, Tuple[int], Tuple[int, int]],
- padding: str,
+ padding: Union[str, int, Tuple[int], Tuple[int, int]],
/,
*,
data_format: str = "NHWC",
+ dilation: Union[int, Tuple[int], Tuple[int, int]] = 1,
+ ceil_mode: bool = False,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""Computes a 2-D max pool given 4-D input x.
@@ -98,7 +105,7 @@ def max_pool2d(
The stride of the sliding window for each dimension of input.
padding
SAME" or "VALID" indicating the algorithm, or list
- indicating the per-dimensio paddings.
+ indicating the per-dimension paddings.
data_format
NHWC" or "NCHW". Defaults to "NHWC".
out
@@ -138,7 +145,16 @@ def max_pool2d(
[[46, 47]]]])
"""
- return ivy.current_backend(x).max_pool2d(x, kernel, strides, padding, out=out)
+ return ivy.current_backend(x).max_pool2d(
+ x,
+ kernel,
+ strides,
+ padding,
+ data_format=data_format,
+ dilation=dilation,
+ ceil_mode=ceil_mode,
+ out=out,
+ )
@to_native_arrays_and_back
@@ -849,3 +865,329 @@ def dft(
slices[axis] = slice(0, res.shape[axis] // 2 + 1)
res = res[tuple(slices)]
return res
+
+
+@to_native_arrays_and_back
+@handle_exceptions
+@handle_out_argument
+@handle_nestable
+def interp(x, xp, fp, left=None, right=None, period=None):
+ x_arr = ivy.array(x)
+ fix_later = False
+ if x_arr.shape == ():
+ x_arr = ivy.array([x])
+ fix_later = True
+ x = ivy.astype(x_arr, "float64")
+ xp = ivy.astype(ivy.array(xp), "float64")
+ fp = ivy.astype(ivy.array(fp), "float64")
+ ivy.assertions.check_equal(xp.ndim, 1)
+ ivy.assertions.check_equal(fp.ndim, 1)
+ ivy.assertions.check_equal(xp.shape[0], fp.shape[0])
+ if period is not None:
+ ivy.assertions.check_equal(period, 0, inverse=True)
+ period = ivy.abs(period)
+ x = ivy.remainder(x, period)
+ xp = ivy.remainder(xp, period)
+ asort_xp = ivy.argsort(xp)
+ xp = xp[asort_xp]
+ fp = fp[asort_xp]
+ xp = ivy.concat((xp[-1:] - period, xp, xp[0:1] + period))
+ fp = ivy.concat((fp[-1:], fp, fp[0:1]))
+
+ def interp_inner(value):
+ value = ivy.array(value)
+ if value < xp[0]:
+ return left if left is not None else fp[0]
+ elif value > xp[-1]:
+ return right if right is not None else fp[-1]
+ else:
+ last = None
+ if xp.shape[0] < 3:
+ for i in range(xp.shape[0] - 1, -1, -1):
+ if xp[i] == value:
+ return fp[i]
+ elif xp[i] < value:
+ last = i
+ else:
+ first = 0
+ last = xp.shape[0]
+ while first < last:
+ midpoint = (first + last) // 2
+ if xp[midpoint] == value:
+ already_exists = ivy.argwhere(xp == value)
+ if already_exists.shape[0] > 0:
+ return fp[already_exists[-1][0]]
+ return fp[midpoint]
+ else:
+ if value < xp[midpoint]:
+ last = midpoint - 1
+ else:
+ first = midpoint + 1
+ dist = (value - xp[last]) / (xp[last + 1] - xp[last])
+ return (fp[last + 1] - fp[last]) * dist + fp[last]
+
+ ret = ivy.map(interp_inner, unique={"value": x})
+ if fix_later:
+ return ivy.astype(ivy.array(ret[0]), "float64")
+ else:
+ return ivy.astype(ivy.array(ret), "float64")
+
+
+@to_native_arrays_and_back
+@handle_exceptions
+@handle_out_argument
+@handle_nestable
+def interpolate(
+ x: Union[ivy.Array, ivy.NativeArray],
+ size: Union[Sequence[int], int],
+ /,
+ *,
+ mode: Union[
+ Literal["linear", "bilinear", "trilinear", "nearest", "area", "nearest_exact"]
+ ] = "linear",
+ align_corners: Optional[bool] = None,
+ antialias: Optional[bool] = False,
+ out: Optional[ivy.Array] = None,
+) -> ivy.Array:
+ """
+ Down/up samples the input to the given size.
+ The algorithm used for interpolation is determined by mode.
+
+ Parameters
+ ----------
+ x
+ Input array. Must have the shape
+ [batch x channels x [optional depth] x [optional height] x width].
+ size
+ Output size.
+ mode
+ Interpolation mode. Can be one of the following:
+ - linear
+ - bilinear
+ - trilinear
+ - nearest
+ - area
+ align_corners
+ If True, the corner pixels of the input and output tensors are aligned,
+ and thus preserving the values at the corner pixels. If False, the corner
+ pixels are not aligned, and the interpolation uses edge value padding for
+ out-of-boundary values.
+ Only has an effect when mode is 'linear', 'bilinear',
+ 'bicubic' or 'trilinear'. Default: None
+ antialias
+ If True, antialiasing is applied when downsampling an image.
+ Supported modes: 'bilinear', 'bicubic'.
+ out
+ Optional output array, for writing the result to. It must
+ have a shape that the inputs broadcast to.
+
+ Returns
+ -------
+ resized array
+
+ """
+ dims = len(x.shape) - 2
+ size = (size,) * dims if isinstance(size, int) else tuple(size)
+ if mode == "linear":
+ size = size[0]
+ if not align_corners or align_corners is None:
+ x_up = ivy.arange(0, ivy.shape(x)[-1])
+ missing = (ivy.arange(0, size) + 0.5) * (ivy.shape(x)[-1] / size) - 0.5
+ else:
+ x_up = ivy.linspace(0, 1, ivy.shape(x)[-1])
+ missing = ivy.linspace(0, 1, size)
+ ret = ivy.zeros(ivy.shape(x)[:-1] + (size,))
+ for i, ba in enumerate(x):
+ for j, ch in enumerate(ba):
+ ret[i][j] = ivy.interp(missing, x_up, ch)
+ elif mode == "bilinear":
+ if not align_corners or align_corners is None:
+ x_up_h = ivy.arange(0, ivy.shape(x)[-2])
+ x_up_w = ivy.arange(0, ivy.shape(x)[-1])
+ missing_h = (ivy.arange(0, size[0]) + 0.5) * (
+ ivy.shape(x)[-2] / size[0]
+ ) - 0.5
+ missing_w = (ivy.arange(0, size[1]) + 0.5) * (
+ ivy.shape(x)[-1] / size[1]
+ ) - 0.5
+ else:
+ x_up_h = ivy.linspace(0, 1, ivy.shape(x)[-2])
+ x_up_w = ivy.linspace(0, 1, ivy.shape(x)[-1])
+ missing_h = ivy.linspace(0, 1, size[0])
+ missing_w = ivy.linspace(0, 1, size[1])
+ ret = ivy.zeros(ivy.shape(x)[:-2] + (size[1], size[0]))
+ for i, ba in enumerate(x):
+ for j, ch in enumerate(ba):
+ row_ret = ivy.zeros((ivy.shape(x)[-2], size[1]))
+ for k, row in enumerate(ch):
+ row_ret[k] = ivy.interp(missing_w, x_up_w, row)
+ row_ret = row_ret.T
+ for k, col in enumerate(row_ret):
+ ret[i][j][k] = ivy.interp(missing_h, x_up_h, col)
+ ret = ivy.permute_dims(ret, (0, 1, 3, 2))
+ elif mode == "trilinear":
+ if not align_corners or align_corners is None:
+ x_up_d = ivy.arange(0, ivy.shape(x)[-3])
+ x_up_h = ivy.arange(0, ivy.shape(x)[-2])
+ x_up_w = ivy.arange(0, ivy.shape(x)[-1])
+ missing_d = (ivy.arange(0, size[0]) + 0.5) * (
+ ivy.shape(x)[-3] / size[0]
+ ) - 0.5
+ missing_h = (ivy.arange(0, size[1]) + 0.5) * (
+ ivy.shape(x)[-2] / size[1]
+ ) - 0.5
+ missing_w = (ivy.arange(0, size[2]) + 0.5) * (
+ ivy.shape(x)[-1] / size[2]
+ ) - 0.5
+ else:
+ x_up_d = ivy.linspace(0, 1, ivy.shape(x)[-3])
+ x_up_h = ivy.linspace(0, 1, ivy.shape(x)[-2])
+ x_up_w = ivy.linspace(0, 1, ivy.shape(x)[-1])
+ missing_d = ivy.linspace(0, 1, size[0])
+ missing_h = ivy.linspace(0, 1, size[1])
+ missing_w = ivy.linspace(0, 1, size[2])
+ ret = ivy.zeros(ivy.shape(x)[:-3] + (size[1], size[2], size[0]))
+ for i, ba in enumerate(x):
+ for j, ch in enumerate(ba):
+ depth_ret = ivy.zeros((x.shape[-3], size[2], size[1]))
+ row_ret = ivy.zeros((ivy.shape(x)[-3], ivy.shape(x)[-2], size[2]))
+ for k, depth in enumerate(ch):
+ for (
+ l,
+ row,
+ ) in enumerate(ch[k]):
+ row_ret[k][l] = ivy.interp(missing_w, x_up_w, row)
+ row_ret = row_ret.transpose((0, 2, 1))
+ for k, row in enumerate(ch):
+ for (
+ l,
+ col,
+ ) in enumerate(row_ret[k]):
+ depth_ret[k][l] = ivy.interp(missing_h, x_up_h, col)
+ depth_ret = depth_ret.transpose((2, 1, 0))
+ for k, col in enumerate(depth_ret):
+ for (
+ l,
+ depth,
+ ) in enumerate(depth_ret[k]):
+ ret[i][j][k][l] = ivy.interp(missing_d, x_up_d, depth)
+ ret = ret.transpose((0, 1, 4, 2, 3))
+ elif mode == "nearest" or mode == "nearest_exact":
+ ret = ivy.zeros((x.shape[:2] + tuple(size)))
+ for i, ba in enumerate(x):
+ for j, ch in enumerate(ba):
+ w_scale = size[-1] / x.shape[-1]
+ if dims == 3:
+ h_scale = size[-2] / x.shape[-2]
+ d_scale = size[-3] / x.shape[-3]
+ for d_dim in range(size[0]):
+ for h_dim in range(size[1]):
+ for w_dim in range(size[2]):
+ ret[i][j][d_dim][h_dim][w_dim] = x[i][j][
+ round(d_dim // d_scale)
+ ][round(h_dim // h_scale)][round(w_dim // w_scale)]
+ elif dims == 2:
+ h_scale = size[-2] / x.shape[-2]
+ for h_dim in range(size[0]):
+ for w_dim in range(size[1]):
+ ret[i][j][h_dim][w_dim] = x[i][j][round(h_dim // h_scale)][
+ round(w_dim // w_scale)
+ ]
+ elif dims == 1:
+ for w_dim in range(size[0]):
+ ret[i][j][w_dim] = x[i][j][round(w_dim // w_scale)]
+ elif mode == "area":
+ ret = ivy.zeros((x.shape[:2] + size))
+ scale = ivy.divide(ivy.shape(x)[2:], size)
+ for i, ba in enumerate(x):
+ for j, ch in enumerate(ba):
+ if dims == 3:
+ for d_dim in range(size[0]):
+ for h_dim in range(size[1]):
+ for w_dim in range(size[2]):
+ d_index = (
+ int(d_dim * scale[0]),
+ math.ceil((d_dim + 1) * scale[0]),
+ )
+ h_index = (
+ int(h_dim * scale[1]),
+ math.ceil((h_dim + 1) * scale[1]),
+ )
+ w_index = (
+ int(w_dim * scale[2]),
+ math.ceil((w_dim + 1) * scale[2]),
+ )
+ scale_z = d_index[1] - d_index[0]
+ scale_y = h_index[1] - h_index[0]
+ scale_x = w_index[1] - w_index[0]
+ area = scale_z * scale_y * scale_x
+ ret[i][j][d_dim][h_dim][w_dim] = ivy.sum(
+ ch[
+ d_index[0] : d_index[1],
+ h_index[0] : h_index[1],
+ w_index[0] : w_index[1],
+ ]
+ ) * (1 / area)
+ elif dims == 2:
+ for h_dim in range(size[0]):
+ for w_dim in range(size[1]):
+ h_index = (
+ int(h_dim * scale[0]),
+ math.ceil((h_dim + 1) * scale[0]),
+ )
+ w_index = (
+ int(w_dim * scale[1]),
+ math.ceil((w_dim + 1) * scale[1]),
+ )
+ scale_y = h_index[1] - h_index[0]
+ scale_x = w_index[1] - w_index[0]
+ area = scale_y * scale_x
+ ret[i][j][h_dim][w_dim] = ivy.sum(
+ ch[h_index[0] : h_index[1], w_index[0] : w_index[1]]
+ ) * (1 / area)
+ else:
+ for w_dim in range(size[0]):
+ w_index = (
+ int(w_dim * scale[0]),
+ math.ceil((w_dim + 1) * scale[0]),
+ )
+ scale_x = w_index[1] - w_index[0]
+ ret[i][j][w_dim] = ivy.sum(ch[w_index[0] : w_index[1]]) * (
+ 1 / scale_x
+ )
+ return ivy.astype(ret, ivy.dtype(x))
+
+
+interpolate.mixed_function = True
+
+
+# Helpers #
+
+
+def _output_ceil_shape(w, f, p, s):
+ return math.ceil((w - f + p) / s) + 1
+
+
+def padding_ceil_mode(w, f, p, s):
+ remaining_pixels = (w - f + sum(p)) % s
+ if s > 1 and remaining_pixels != 0 and f > 1:
+ input_size = w + sum(p)
+ # making sure that the remaining pixels are supposed
+ # to be covered by the window
+ # they won't be covered if stride is big enough to skip them
+ if input_size - remaining_pixels - (f - 1) + s > input_size:
+ return p
+ output_shape = _output_ceil_shape(
+ w,
+ f,
+ sum(p),
+ s,
+ )
+ # calculating new padding with ceil_output_shape
+ new_pad = (output_shape - 1) * s + f - w
+ # updating pad_list with new padding by adding it to the end
+ p = (
+ p[0],
+ p[1] + new_pad - sum(p),
+ )
+ return p
diff --git a/ivy/functional/ivy/experimental/manipulation.py b/ivy/functional/ivy/experimental/manipulation.py
index 7f908c0a4b144..874b3b60edb75 100644
--- a/ivy/functional/ivy/experimental/manipulation.py
+++ b/ivy/functional/ivy/experimental/manipulation.py
@@ -1484,16 +1484,13 @@ def take_along_axis(
@to_native_arrays_and_back
-@handle_out_argument
@handle_nestable
@handle_array_like_without_promotion
def hsplit(
ary: Union[ivy.Array, ivy.NativeArray],
- indices_or_sections: Union[int, Tuple[int]],
+ indices_or_sections: Union[int, Tuple[int, ...]],
/,
- *,
- out: Optional[ivy.Array] = None,
-) -> ivy.Array:
+) -> List[ivy.Array]:
"""Split an array into multiple sub-arrays horizontally.
Parameters
@@ -1501,15 +1498,10 @@ def hsplit(
ary
Array input.
indices_or_sections
- If indices_or_sections is an integer n, the array is split into n sections.
- If the array is divisible by n along the 3rd axis, each section will be of
- equal size. If input is not divisible by n, the sizes of the first
- int(ary.size(0) % n) sections will have size int(ary.size(0) / n) + 1,
- and the rest will have size int(ary.size(0) / n).
+ If indices_or_sections is an integer n, the array is split into n
+ equal sections, provided that n must be a divisor of the split axis.
If indices_or_sections is a tuple of ints, then input is split at each of
the indices in the tuple.
- out
- optional output array, for writing the result to.
Returns
-------
@@ -1525,16 +1517,16 @@ def hsplit(
[12., 13., 14., 15.]]
)
>>> ivy.hsplit(ary, 2)
- [ivy.array([[ 0., 1.],
+ [ivy.array([[ 0., 1.],
[ 4., 5.],
[ 8., 9.],
[12., 13.]]),
ivy.array([[ 2., 3.],
[ 6., 7.],
[10., 11.],
- [14., 15.]]))
+ [14., 15.]])]
"""
- return ivy.current_backend(ary).hsplit(ary, indices_or_sections, out=out)
+ return ivy.current_backend(ary).hsplit(ary, indices_or_sections)
@handle_exceptions
diff --git a/ivy/functional/ivy/experimental/norms.py b/ivy/functional/ivy/experimental/norms.py
index f1f730413648f..9a5a0ebe24b79 100644
--- a/ivy/functional/ivy/experimental/norms.py
+++ b/ivy/functional/ivy/experimental/norms.py
@@ -140,3 +140,44 @@ def instance_norm(
track_running_stats=track_running_stats,
out=out,
)
+
+
+@to_native_arrays_and_back
+@handle_out_argument
+@handle_nestable
+@handle_exceptions
+def lp_normalize(
+ x: Union[ivy.Array, ivy.NativeArray],
+ /,
+ *,
+ p: float = 2,
+ axis: Optional[int] = None,
+ out: Optional[ivy.Array] = None,
+) -> ivy.Array:
+ """Normalizes the input array along the given axis to have Lp norm equal to 1.
+
+ Parameters
+ ----------
+ x
+ Input array.
+ p
+ The Lp norm to use for normalization. Default is L2 norm (p=2).
+ axis
+ Axis along which to normalize. If ``None``, the whole array is normalized.
+ out
+ optional output array, for writing the result to. It must have a shape that the
+ inputs broadcast to.
+
+ Returns
+ -------
+ ret
+ The normalized array.
+
+ Examples
+ --------
+ >>> x = ivy.array([[1., 2.], [3., 4.]])
+ >>> ivy.lp_normalize(x, p=1, axis=1)
+ ivy.array([[0.3333, 0.6666],
+ [0.4286, 0.5714]])
+ """
+ return current_backend(x).lp_normalize(x, p=p, axis=axis, out=out)
diff --git a/ivy/functional/ivy/experimental/statistical.py b/ivy/functional/ivy/experimental/statistical.py
index 7cfea91304484..285f0e071a26a 100644
--- a/ivy/functional/ivy/experimental/statistical.py
+++ b/ivy/functional/ivy/experimental/statistical.py
@@ -235,3 +235,62 @@ def corrcoef(
out: Optional[ivy.Array] = None,
) -> ivy.Array:
return ivy.current_backend().corrcoef(x, y=y, rowvar=rowvar, out=out)
+
+
+@to_native_arrays_and_back
+@handle_out_argument
+@handle_nestable
+@handle_exceptions
+def nanmedian(
+ input: ivy.Array,
+ /,
+ *,
+ axis: Optional[Union[Tuple[int], int]] = None,
+ keepdims: Optional[bool] = False,
+ overwrite_input: Optional[bool] = False,
+ out: Optional[ivy.Array] = None,
+) -> ivy.Array:
+ """ivy.Array instance method variant of ivy.nanmedian. This method simply
+ wraps the function, and so the docstring for ivy.nanmedian also applies to
+ this method with minimal changes.
+
+ Parameters
+ ----------
+ input
+ Input array.
+ axis
+ Axis or axes along which the medians are computed.
+ The default is to compute the median of the flattened array.
+ keepdims
+ If this is set to True, the axes which are reduced are left in the result
+ as dimensions with size one. With this option, the result will broadcast
+ correctly against the original a. If the value is anything but the default,
+ then keepdims will be passed through to the mean or sum methods of
+ sub-classes of ndarray. If the sub-classes methods does not implement
+ keepdims any exceptions will be raised.
+ overwrite_input
+ If True, then allow use of memory of input array a for calculations.
+ The input array will be modified by the call to median. This will
+ save memory when you do not need to preserve the contents of the input array.
+ Treat the input as undefined, but it will probably be fully or partially sorted.
+ Default is False. If overwrite_input is True and a is not already an ndarray,
+ an error will be raised.
+ out
+ optional output array, for writing the result to.
+
+ Returns
+ -------
+ ret
+ A new array holding the result. If the input contains integers or
+
+ Examples
+ >>> a = ivy.array([[10.0, ivy.nan, 4], [3, 2, 1]])
+ >>> ivy.nanmedian(a)
+ 3.0
+ >>> ivy.nanmedian(a, axis=0)
+ array([6.5, 2. , 2.5])
+ """
+
+ return ivy.current_backend().nanmedian(
+ input, axis=axis, keepdims=keepdims, overwrite_input=overwrite_input, out=out
+ )
diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py
index 70877ac167435..82f700b531658 100644
--- a/ivy/functional/ivy/general.py
+++ b/ivy/functional/ivy/general.py
@@ -16,6 +16,7 @@
from ivy.functional.ivy.gradients import _is_variable
from ivy.exceptions import handle_exceptions
from ivy.func_wrapper import (
+ handle_array_function,
inputs_to_ivy_arrays,
inputs_to_native_arrays,
outputs_to_ivy_arrays,
@@ -513,6 +514,7 @@ def get_show_func_wrapper_trace_mode() -> bool:
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def array_equal(
x0: Union[ivy.Array, ivy.NativeArray],
x1: Union[ivy.Array, ivy.NativeArray],
@@ -559,6 +561,7 @@ def array_equal(
@to_native_arrays_and_back
@handle_nestable
@handle_exceptions
+@handle_array_function
def all_equal(
*xs: Iterable[Any], equality_matrix: bool = False
) -> Union[bool, ivy.Array, ivy.NativeArray]:
@@ -650,6 +653,7 @@ def all_equal(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def to_numpy(
x: Union[ivy.Array, ivy.NativeArray], /, *, copy: bool = True
) -> np.ndarray:
@@ -719,6 +723,7 @@ def isscalar(x: Any, /) -> bool:
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def to_scalar(x: Union[ivy.Array, ivy.NativeArray], /) -> Number:
"""Converts an array with a single element into a scalar.
@@ -773,6 +778,7 @@ def to_scalar(x: Union[ivy.Array, ivy.NativeArray], /) -> Number:
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def to_list(x: Union[ivy.Array, ivy.NativeArray], /) -> List:
"""Creates a (possibly nested) list from input array.
@@ -843,6 +849,7 @@ def to_list(x: Union[ivy.Array, ivy.NativeArray], /) -> List:
@handle_nestable
@outputs_to_ivy_arrays
@handle_exceptions
+@handle_array_function
def clip_vector_norm(
x: Union[ivy.Array, ivy.NativeArray],
max_norm: float,
@@ -929,6 +936,7 @@ def clip_vector_norm(
@handle_nestable
@handle_exceptions
+@handle_array_function
def clip_matrix_norm(
x: Union[ivy.Array, ivy.NativeArray],
max_norm: float,
@@ -1009,6 +1017,7 @@ def clip_matrix_norm(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def fourier_encode(
x: Union[ivy.Array, ivy.NativeArray],
max_freq: Union[float, ivy.Array, ivy.NativeArray],
@@ -1112,6 +1121,7 @@ def fourier_encode(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def value_is_nan(
x: Union[ivy.Array, ivy.NativeArray, Number],
/,
@@ -1171,6 +1181,7 @@ def value_is_nan(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def has_nans(
x: Union[ivy.Array, ivy.NativeArray], /, *, include_infs: bool = True
) -> bool:
@@ -1677,6 +1688,7 @@ def current_backend_str() -> Union[str, None]:
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def einops_rearrange(
x: Union[ivy.Array, ivy.NativeArray],
pattern: str,
@@ -1800,6 +1812,7 @@ def einops_rearrange(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def einops_reduce(
x: Union[ivy.Array, ivy.NativeArray],
pattern: str,
@@ -1869,6 +1882,7 @@ def einops_reduce(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def einops_repeat(
x: Union[ivy.Array, ivy.NativeArray],
pattern: str,
@@ -1949,6 +1963,7 @@ def get_min_denominator() -> float:
@handle_exceptions
+@handle_array_function
def set_min_denominator(val: float) -> None:
"""
Set the global minimum denominator used by ivy for numerically stable division.
@@ -1995,6 +2010,7 @@ def get_min_base() -> float:
@handle_exceptions
+@handle_array_function
def set_min_base(val: float) -> None:
"""Set the global minimum base used by ivy for numerically stable power raising.
@@ -2022,6 +2038,7 @@ def set_min_base(val: float) -> None:
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def stable_divide(
numerator: Union[Number, ivy.Array, ivy.NativeArray],
denominator: Union[Number, ivy.Array, ivy.NativeArray],
@@ -2120,6 +2137,7 @@ def stable_divide(
@inputs_to_ivy_arrays
@handle_nestable
@handle_exceptions
+@handle_array_function
def stable_pow(
base: Union[Number, ivy.Array, ivy.NativeArray],
exponent: Union[Number, ivy.Array, ivy.NativeArray],
@@ -2159,13 +2177,18 @@ def stable_pow(
@handle_exceptions
-def get_all_arrays_in_memory():
+def get_all_arrays_in_memory() -> List[Union[ivy.Array, ivy.NativeArray]]:
"""Gets all arrays which are currently alive."""
all_arrays = list()
for obj in gc.get_objects():
try:
- if ivy.is_native_array(obj):
- all_arrays.append(obj)
+ if ivy.current_backend_str() in ["", "numpy"]:
+ if ivy.is_ivy_array(obj):
+ all_arrays.append(obj)
+ else:
+ if ivy.is_native_array(obj):
+ all_arrays.append(obj)
+
except Exception:
pass
return all_arrays
@@ -2188,6 +2211,7 @@ def print_all_arrays_in_memory():
@handle_exceptions
+@handle_array_function
def set_queue_timeout(timeout: float):
"""
Set the global queue timeout value (in seconds)
@@ -2358,6 +2382,7 @@ def inplace_variables_supported(f=None):
@inputs_to_native_arrays
@handle_nestable
@handle_exceptions
+@handle_array_function
def supports_inplace_updates(x: Union[ivy.Array, ivy.NativeArray], /) -> bool:
"""
Determines whether in-place operations are supported for x's data type,
@@ -2413,6 +2438,7 @@ def supports_inplace_updates(x: Union[ivy.Array, ivy.NativeArray], /) -> bool:
@inputs_to_native_arrays
@handle_nestable
@handle_exceptions
+@handle_array_function
def assert_supports_inplace(x: Union[ivy.Array, ivy.NativeArray], /) -> bool:
"""Asserts that inplace operations are supported for x, else raises exception.
@@ -2439,6 +2465,7 @@ def assert_supports_inplace(x: Union[ivy.Array, ivy.NativeArray], /) -> bool:
@to_native_arrays_and_back
@handle_nestable
@handle_array_like_without_promotion
+@handle_array_function
def get_item(
x: Union[ivy.Array, ivy.NativeArray],
query: Union[ivy.Array, ivy.NativeArray, Tuple],
@@ -2478,6 +2505,7 @@ def get_item(
@handle_nestable
@handle_exceptions
@inputs_to_ivy_arrays
+@handle_array_function
def inplace_update(
x: Union[ivy.Array, ivy.NativeArray],
val: Union[ivy.Array, ivy.NativeArray],
@@ -2515,6 +2543,7 @@ def inplace_update(
@handle_nestable
@handle_exceptions
@inputs_to_ivy_arrays
+@handle_array_function
def inplace_decrement(
x: Union[ivy.Array, ivy.NativeArray],
val: Union[ivy.Array, ivy.NativeArray],
@@ -2584,6 +2613,7 @@ def inplace_decrement(
@handle_nestable
@handle_exceptions
@inputs_to_ivy_arrays
+@handle_array_function
def inplace_increment(
x: Union[ivy.Array, ivy.NativeArray],
val: Union[ivy.Array, ivy.NativeArray],
@@ -2642,6 +2672,7 @@ def inplace_increment(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def scatter_flat(
indices: Union[ivy.Array, ivy.NativeArray],
updates: Union[ivy.Array, ivy.NativeArray],
@@ -2681,6 +2712,7 @@ def scatter_flat(
@to_native_arrays_and_back
@handle_nestable
@handle_exceptions
+@handle_array_function
def scatter_nd(
indices: Union[ivy.Array, ivy.NativeArray],
updates: Union[ivy.Array, ivy.NativeArray],
@@ -2762,6 +2794,7 @@ def scatter_nd(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def gather(
params: Union[ivy.Array, ivy.NativeArray],
indices: Union[ivy.Array, ivy.NativeArray],
@@ -2868,6 +2901,7 @@ def gather(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def gather_nd(
params: Union[ivy.Array, ivy.NativeArray],
indices: Union[ivy.Array, ivy.NativeArray],
@@ -2939,6 +2973,7 @@ def gather_nd(
@handle_nestable
@handle_exceptions
+@handle_array_function
def multiprocessing(context: str = None):
"""Return backend-specific multiprocessing module.
@@ -2961,6 +2996,7 @@ def multiprocessing(context: str = None):
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def shape(
x: Union[ivy.Array, ivy.NativeArray], /, *, as_array: bool = False
) -> Union[ivy.Shape, ivy.NativeShape]:
@@ -3058,6 +3094,7 @@ def shape_array_mode() -> bool:
@to_native_arrays_and_back
@handle_nestable
@handle_array_like_without_promotion
+@handle_array_function
def get_num_dims(
x: Union[ivy.Array, ivy.NativeArray], /, *, as_array: bool = False
) -> int:
diff --git a/ivy/functional/ivy/gradients.py b/ivy/functional/ivy/gradients.py
index 4386f23acdb83..b8ae22e624a60 100644
--- a/ivy/functional/ivy/gradients.py
+++ b/ivy/functional/ivy/gradients.py
@@ -10,6 +10,7 @@
from ivy.backend_handler import current_backend
from ivy.func_wrapper import (
+ handle_array_function,
inputs_to_ivy_arrays,
to_native_arrays_and_back,
handle_out_argument,
@@ -321,6 +322,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
# noinspection PyShadowingNames
@handle_exceptions
+@handle_array_function
def with_grads(*, with_grads: bool = None) -> bool:
"""
Enter a nested code space where gradients are computed. This method
@@ -371,6 +373,7 @@ def with_grads(*, with_grads: bool = None) -> bool:
# noinspection PyShadowingNames
@handle_exceptions
+@handle_array_function
def set_with_grads(with_grads: bool):
"""
Enter a nested code space where gradients are computed. This method
@@ -416,6 +419,7 @@ def set_with_grads(with_grads: bool):
@handle_exceptions
+@handle_array_function
def unset_with_grads():
"""
Enter a nested code space where gradients are computed. This method
@@ -452,6 +456,7 @@ def unset_with_grads():
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def stop_gradient(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -527,6 +532,7 @@ def stop_gradient(
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def execute_with_gradients(
func, xs, /, *, retain_grads=False, xs_grad_idxs=None, ret_grad_idxs=None
):
@@ -571,6 +577,7 @@ def execute_with_gradients(
@to_native_arrays_and_back
@handle_exceptions
+@handle_array_function
def value_and_grad(func):
"""
Create a function that evaluates both func and the gradient of func.
@@ -606,6 +613,7 @@ def value_and_grad(func):
@to_native_arrays_and_back
@handle_exceptions
+@handle_array_function
def jac(func):
"""Call function func, and return func's Jacobian partial derivatives.
@@ -641,6 +649,7 @@ def jac(func):
@to_native_arrays_and_back
@handle_exceptions
+@handle_array_function
def grad(func):
"""Call function func, and return func's gradients.
@@ -678,6 +687,7 @@ def grad(func):
@inputs_to_ivy_arrays
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def adam_step(
dcdw: Union[ivy.Array, ivy.NativeArray],
mw: Union[ivy.Array, ivy.NativeArray],
@@ -831,6 +841,7 @@ def adam_step(
@inputs_to_ivy_arrays
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def optimizer_update(
w: Union[ivy.Array, ivy.NativeArray],
effective_grad: Union[ivy.Array, ivy.NativeArray],
@@ -953,6 +964,7 @@ def optimizer_update(
@inputs_to_ivy_arrays
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def gradient_descent_update(
w: Union[ivy.Array, ivy.NativeArray],
dcdw: Union[ivy.Array, ivy.NativeArray],
@@ -1045,6 +1057,7 @@ def gradient_descent_update(
@inputs_to_ivy_arrays
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def lars_update(
w: Union[ivy.Array, ivy.NativeArray],
dcdw: Union[ivy.Array, ivy.NativeArray],
@@ -1095,6 +1108,7 @@ def lars_update(
@inputs_to_ivy_arrays
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def adam_update(
w: Union[ivy.Array, ivy.NativeArray],
dcdw: Union[ivy.Array, ivy.NativeArray],
@@ -1259,6 +1273,7 @@ def adam_update(
@inputs_to_ivy_arrays
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def lamb_update(
w: Union[ivy.Array, ivy.NativeArray],
dcdw: Union[ivy.Array, ivy.NativeArray],
diff --git a/ivy/functional/ivy/layers.py b/ivy/functional/ivy/layers.py
index e400b2c4f81a6..7a544043522e1 100644
--- a/ivy/functional/ivy/layers.py
+++ b/ivy/functional/ivy/layers.py
@@ -7,6 +7,7 @@
import ivy
from ivy.backend_handler import current_backend
from ivy.func_wrapper import (
+ handle_array_function,
inputs_to_ivy_arrays,
to_native_arrays_and_back,
handle_out_argument,
@@ -26,6 +27,7 @@
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def linear(
x: Union[ivy.Array, ivy.NativeArray],
weight: Union[ivy.Array, ivy.NativeArray],
@@ -171,6 +173,7 @@ def linear(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def dropout(
x: Union[ivy.Array, ivy.NativeArray],
prob: float,
@@ -341,6 +344,7 @@ def dropout(
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def scaled_dot_product_attention(
q: Union[ivy.Array, ivy.NativeArray],
k: Union[ivy.Array, ivy.NativeArray],
@@ -545,6 +549,7 @@ def scaled_dot_product_attention(
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def multi_head_attention(
x: Union[ivy.Array, ivy.NativeArray],
scale: float,
@@ -780,6 +785,7 @@ def call_einops(t):
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def conv1d(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
@@ -873,6 +879,7 @@ def conv1d(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def conv1d_transpose(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
@@ -933,6 +940,7 @@ def conv1d_transpose(
@handle_out_argument
@handle_nestable
@handle_array_like_without_promotion
+@handle_array_function
def conv2d(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
@@ -940,8 +948,8 @@ def conv2d(
padding: Union[str, Sequence[Tuple[int, int]]],
/,
*,
- data_format: Optional[str] = "NHWC",
- dilations: Optional[Union[int, Tuple[int, int]]] = 1,
+ data_format: str = "NHWC",
+ dilations: Union[int, Tuple[int, int]] = 1,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""Computes a 2-D convolution given 4-D input x and filters arrays.
@@ -977,7 +985,6 @@ def conv2d(
but this function is *nestable*, and therefore also accepts :class:`ivy.Container`
instances in place of any of the arguments.
-
Examples
--------
With :class:`ivy.Array` input:
@@ -1061,6 +1068,7 @@ def conv2d(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def conv2d_transpose(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
@@ -1179,6 +1187,7 @@ def conv2d_transpose(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def depthwise_conv2d(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
@@ -1316,6 +1325,7 @@ def depthwise_conv2d(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def conv3d(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
filters: Union[ivy.Array, ivy.NativeArray, ivy.Container],
@@ -1431,6 +1441,7 @@ def conv3d(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def conv3d_transpose(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
@@ -1539,6 +1550,7 @@ def conv3d_transpose(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def conv_general_dilated(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
@@ -1617,6 +1629,7 @@ def conv_general_dilated(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def conv_general_transpose(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
@@ -1690,6 +1703,7 @@ def conv_general_transpose(
@handle_out_argument
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def conv(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
@@ -1793,6 +1807,7 @@ def conv(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def lstm_update(
x: Union[ivy.Array, ivy.NativeArray],
init_h: Union[ivy.Array, ivy.NativeArray],
diff --git a/ivy/functional/ivy/linear_algebra.py b/ivy/functional/ivy/linear_algebra.py
index f5b4296019be9..4f99c1cddd4a8 100644
--- a/ivy/functional/ivy/linear_algebra.py
+++ b/ivy/functional/ivy/linear_algebra.py
@@ -5,6 +5,7 @@
import ivy
from ivy.backend_handler import current_backend
from ivy.func_wrapper import (
+ handle_array_function,
to_native_arrays_and_back,
handle_out_argument,
handle_nestable,
@@ -25,6 +26,7 @@
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def cholesky(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -161,6 +163,7 @@ def cholesky(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def cross(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -252,6 +255,7 @@ def cross(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def det(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
) -> ivy.Array:
@@ -321,6 +325,7 @@ def det(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def diagonal(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -497,6 +502,7 @@ def diagonal(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def eig(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -557,6 +563,7 @@ def eig(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def eigh(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -621,6 +628,7 @@ def eigh(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def eigvalsh(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -722,6 +730,7 @@ def eigvalsh(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def inner(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -760,6 +769,7 @@ def inner(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def inv(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -849,6 +859,7 @@ def inv(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def matmul(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -856,6 +867,8 @@ def matmul(
*,
transpose_a: bool = False,
transpose_b: bool = False,
+ adjoint_a: bool = False,
+ adjoint_b: bool = False,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""Computes the matrix product.
@@ -985,7 +998,8 @@ def matmul(
"""
return current_backend(x1).matmul(
- x1, x2, transpose_a=transpose_a, transpose_b=transpose_b, out=out
+ x1, x2, transpose_a=transpose_a, transpose_b=transpose_b,
+ adjoint_a=adjoint_a, adjoint_b=adjoint_b, out=out
)
@@ -994,6 +1008,7 @@ def matmul(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def matrix_norm(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1141,6 +1156,7 @@ def matrix_norm(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def matrix_power(
x: Union[ivy.Array, ivy.NativeArray], n: int, /, *, out: Optional[ivy.Array] = None
) -> ivy.Array:
@@ -1233,6 +1249,7 @@ def matrix_power(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def matrix_rank(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1336,6 +1353,7 @@ def matrix_rank(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def matrix_transpose(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
) -> ivy.Array:
@@ -1416,6 +1434,7 @@ def matrix_transpose(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def outer(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -1503,6 +1522,7 @@ def outer(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def pinv(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1573,6 +1593,7 @@ def pinv(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def qr(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1635,6 +1656,7 @@ def qr(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def slogdet(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1715,6 +1737,7 @@ def slogdet(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def solve(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -1768,6 +1791,7 @@ def solve(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def svd(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1891,6 +1915,7 @@ def svd(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def svdvals(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
) -> ivy.Array:
@@ -2013,6 +2038,7 @@ def svdvals(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def tensordot(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -2094,6 +2120,7 @@ def tensordot(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def tensorsolve(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -2139,6 +2166,7 @@ def tensorsolve(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def trace(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -2252,6 +2280,7 @@ def trace(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def vecdot(
x1: Union[ivy.Array, ivy.NativeArray],
x2: Union[ivy.Array, ivy.NativeArray],
@@ -2318,6 +2347,7 @@ def vecdot(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def vector_norm(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -2416,6 +2446,7 @@ def vector_norm(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def diag(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -2496,6 +2527,7 @@ def diag(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def vander(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -2568,6 +2600,7 @@ def vander(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def vector_to_skew_symmetric_matrix(
vector: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
) -> ivy.Array:
diff --git a/ivy/functional/ivy/losses.py b/ivy/functional/ivy/losses.py
index 9ecb2a88edaf1..d5e5fe3128daf 100644
--- a/ivy/functional/ivy/losses.py
+++ b/ivy/functional/ivy/losses.py
@@ -3,7 +3,11 @@
# local
import ivy
from typing import Optional, Union
-from ivy.func_wrapper import handle_nestable, handle_array_like_without_promotion
+from ivy.func_wrapper import (
+ handle_array_function,
+ handle_nestable,
+ handle_array_like_without_promotion,
+)
from ivy.exceptions import handle_exceptions
# Helpers #
@@ -26,6 +30,7 @@ def _reduce_loss(red, loss, axis, out):
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def cross_entropy(
true: Union[ivy.Array, ivy.NativeArray],
pred: Union[ivy.Array, ivy.NativeArray],
@@ -80,6 +85,7 @@ def cross_entropy(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def binary_cross_entropy(
true: Union[ivy.Array, ivy.NativeArray],
pred: Union[ivy.Array, ivy.NativeArray],
@@ -186,6 +192,7 @@ def binary_cross_entropy(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def sparse_cross_entropy(
true: Union[ivy.Array, ivy.NativeArray],
pred: Union[ivy.Array, ivy.NativeArray],
diff --git a/ivy/functional/ivy/manipulation.py b/ivy/functional/ivy/manipulation.py
index f54739695c6a2..f7478f69d55fd 100644
--- a/ivy/functional/ivy/manipulation.py
+++ b/ivy/functional/ivy/manipulation.py
@@ -8,6 +8,7 @@
import ivy
from ivy.backend_handler import current_backend
from ivy.func_wrapper import (
+ handle_array_function,
to_native_arrays_and_back,
handle_out_argument,
handle_nestable,
@@ -37,6 +38,7 @@ def _calculate_out_shape(axis, array_shape):
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def concat(
xs: Union[
Tuple[Union[ivy.Array, ivy.NativeArray], ...],
@@ -94,6 +96,7 @@ def concat(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def expand_dims(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -219,6 +222,7 @@ def expand_dims(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def flip(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -303,6 +307,7 @@ def flip(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def permute_dims(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -392,6 +397,7 @@ def permute_dims(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def reshape(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -510,6 +516,7 @@ def reshape(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def roll(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
/,
@@ -619,6 +626,7 @@ def roll(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def squeeze(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -700,6 +708,7 @@ def squeeze(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def stack(
arrays: Union[
Tuple[Union[ivy.Array, ivy.NativeArray], ...],
@@ -785,6 +794,7 @@ def stack(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def clip(
x: Union[ivy.Array, ivy.NativeArray],
x_min: Union[Number, ivy.Array, ivy.NativeArray],
@@ -910,6 +920,7 @@ def clip(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def constant_pad(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -996,6 +1007,7 @@ def constant_pad(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def repeat(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1064,6 +1076,7 @@ def repeat(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def split(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1138,6 +1151,7 @@ def split(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def swapaxes(
x: Union[ivy.Array, ivy.NativeArray],
axis0: int,
@@ -1239,6 +1253,7 @@ def swapaxes(
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def tile(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1317,6 +1332,7 @@ def tile(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def unstack(
x: Union[ivy.Array, ivy.NativeArray], /, *, axis: int = 0, keepdims: bool = False
) -> List[ivy.Array]:
@@ -1400,6 +1416,7 @@ def unstack(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def zero_pad(
x: Union[ivy.Array, ivy.NativeArray],
/,
diff --git a/ivy/functional/ivy/meta.py b/ivy/functional/ivy/meta.py
index dfe3d31100737..d9df968624647 100644
--- a/ivy/functional/ivy/meta.py
+++ b/ivy/functional/ivy/meta.py
@@ -1,5 +1,6 @@
# global
import ivy
+from ivy.func_wrapper import handle_array_function
from ivy.functional.ivy.gradients import gradient_descent_update
from ivy.exceptions import handle_exceptions
@@ -399,6 +400,7 @@ def _train_tasks(
@handle_exceptions
+@handle_array_function
def fomaml_step(
batch: ivy.Container,
inner_cost_fn: Callable,
@@ -519,6 +521,7 @@ def fomaml_step(
@handle_exceptions
+@handle_array_function
def reptile_step(
batch: ivy.Container,
cost_fn: Callable,
@@ -610,6 +613,7 @@ def reptile_step(
@handle_exceptions
+@handle_array_function
def maml_step(
batch: ivy.Container,
inner_cost_fn: Callable,
diff --git a/ivy/functional/ivy/nest.py b/ivy/functional/ivy/nest.py
index 32a2b770c8976..0545f5aaeee7c 100644
--- a/ivy/functional/ivy/nest.py
+++ b/ivy/functional/ivy/nest.py
@@ -1120,6 +1120,7 @@ def nested_map(
}
if shallow:
x.update(**ret)
+ return x
return class_instance(**ret)
return fn(x)
diff --git a/ivy/functional/ivy/norms.py b/ivy/functional/ivy/norms.py
index 944d1448281ca..4e85598822d7b 100644
--- a/ivy/functional/ivy/norms.py
+++ b/ivy/functional/ivy/norms.py
@@ -5,6 +5,7 @@
from typing import List, Union, Optional
import ivy
from ivy.func_wrapper import (
+ handle_array_function,
inputs_to_ivy_arrays,
integer_arrays_to_float,
handle_array_like_without_promotion,
@@ -20,6 +21,7 @@
@integer_arrays_to_float
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def layer_norm(
x: Union[ivy.Array, ivy.NativeArray],
normalized_idxs: List[int],
diff --git a/ivy/functional/ivy/random.py b/ivy/functional/ivy/random.py
index dab9f2cbfc63b..09f0f2b03e2db 100644
--- a/ivy/functional/ivy/random.py
+++ b/ivy/functional/ivy/random.py
@@ -6,6 +6,7 @@
# local
import ivy
from ivy.func_wrapper import (
+ handle_array_function,
infer_dtype,
infer_device,
handle_out_argument,
@@ -95,6 +96,7 @@ def _check_shapes_broadcastable(out, inp):
@infer_dtype
@handle_nestable
@handle_exceptions
+@handle_array_function
def random_uniform(
*,
low: Union[float, ivy.NativeArray, ivy.Array] = 0.0,
@@ -208,6 +210,7 @@ def random_uniform(
@infer_dtype
@handle_nestable
@handle_exceptions
+@handle_array_function
def random_normal(
*,
mean: Union[float, ivy.NativeArray, ivy.Array] = 0.0,
@@ -317,6 +320,7 @@ def random_normal(
@infer_device
@handle_nestable
@handle_exceptions
+@handle_array_function
def multinomial(
population_size: int,
num_samples: int,
@@ -425,6 +429,7 @@ def multinomial(
@infer_device
@handle_nestable
@handle_exceptions
+@handle_array_function
def randint(
low: Union[int, ivy.NativeArray, ivy.Array],
high: Union[int, ivy.NativeArray, ivy.Array],
@@ -520,6 +525,7 @@ def seed(*, seed_value: int = 0) -> None:
@handle_out_argument
@handle_nestable
@handle_exceptions
+@handle_array_function
def shuffle(
x: Union[ivy.Array, ivy.NativeArray],
/,
diff --git a/ivy/functional/ivy/searching.py b/ivy/functional/ivy/searching.py
index ae94383f9f201..560b8a46cd944 100644
--- a/ivy/functional/ivy/searching.py
+++ b/ivy/functional/ivy/searching.py
@@ -7,6 +7,7 @@
from ivy.backend_handler import current_backend
from ivy.exceptions import handle_exceptions
from ivy.func_wrapper import (
+ handle_array_function,
to_native_arrays_and_back,
handle_out_argument,
handle_nestable,
@@ -23,6 +24,7 @@
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def argmax(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -126,6 +128,7 @@ def argmax(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def argmin(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -229,6 +232,7 @@ def argmin(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def nonzero(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -366,6 +370,7 @@ def nonzero(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def where(
condition: Union[ivy.Array, ivy.NativeArray],
x1: Union[ivy.Array, ivy.NativeArray],
@@ -455,6 +460,7 @@ def where(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def argwhere(
x: Union[ivy.Array, ivy.NativeArray],
/,
diff --git a/ivy/functional/ivy/set.py b/ivy/functional/ivy/set.py
index bfeff6acc4abf..ed5d81af9156e 100644
--- a/ivy/functional/ivy/set.py
+++ b/ivy/functional/ivy/set.py
@@ -4,6 +4,7 @@
# local
import ivy
from ivy.func_wrapper import (
+ handle_array_function,
to_native_arrays_and_back,
handle_out_argument,
handle_nestable,
@@ -20,6 +21,7 @@
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def unique_all(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -139,6 +141,7 @@ def unique_all(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def unique_inverse(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -240,6 +243,7 @@ def unique_inverse(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def unique_values(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -306,6 +310,7 @@ def unique_values(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def unique_counts(
x: Union[ivy.Array, ivy.NativeArray],
/,
diff --git a/ivy/functional/ivy/sorting.py b/ivy/functional/ivy/sorting.py
index 4c52c5d1eea00..9ec87c3110414 100644
--- a/ivy/functional/ivy/sorting.py
+++ b/ivy/functional/ivy/sorting.py
@@ -4,6 +4,7 @@
# local
import ivy
from ivy.func_wrapper import (
+ handle_array_function,
to_native_arrays_and_back,
handle_out_argument,
handle_nestable,
@@ -21,6 +22,7 @@
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def argsort(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -136,6 +138,7 @@ def argsort(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def sort(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -239,6 +242,7 @@ def sort(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def lexsort(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -258,6 +262,7 @@ def lexsort(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def searchsorted(
x: Union[ivy.Array, ivy.NativeArray],
v: Union[ivy.Array, ivy.NativeArray],
diff --git a/ivy/functional/ivy/statistical.py b/ivy/functional/ivy/statistical.py
index c86e55c3b6c14..9c664ac68263c 100644
--- a/ivy/functional/ivy/statistical.py
+++ b/ivy/functional/ivy/statistical.py
@@ -5,6 +5,7 @@
import ivy
from ivy.backend_handler import current_backend
from ivy.func_wrapper import (
+ handle_array_function,
to_native_arrays_and_back,
handle_out_argument,
handle_nestable,
@@ -37,6 +38,7 @@ def _get_promoted_type_of_operands(operands):
@handle_out_argument
@handle_nestable
@handle_array_like_without_promotion
+@handle_array_function
def min(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -138,6 +140,7 @@ def min(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def max(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -245,6 +248,7 @@ def max(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def mean(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -352,6 +356,7 @@ def mean(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def prod(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -473,6 +478,7 @@ def prod(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def std(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -605,6 +611,7 @@ def std(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def sum(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -737,6 +744,7 @@ def sum(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def var(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -851,6 +859,7 @@ def var(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def cumsum(
x: Union[ivy.Array, ivy.NativeArray],
axis: int = 0,
@@ -992,6 +1001,7 @@ def cumsum(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def cumprod(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -1141,6 +1151,7 @@ def cumprod(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def einsum(
equation: str,
*operands: Union[ivy.Array, ivy.NativeArray],
diff --git a/ivy/functional/ivy/utility.py b/ivy/functional/ivy/utility.py
index 49004fe3dca49..f1b06e92e92b9 100644
--- a/ivy/functional/ivy/utility.py
+++ b/ivy/functional/ivy/utility.py
@@ -4,6 +4,7 @@
# local
import ivy
from ivy.func_wrapper import (
+ handle_array_function,
to_native_arrays_and_back,
handle_out_argument,
handle_nestable,
@@ -21,6 +22,7 @@
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def all(
x: Union[ivy.Array, ivy.NativeArray],
/,
@@ -131,6 +133,7 @@ def all(
@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
+@handle_array_function
def any(
x: Union[ivy.Array, ivy.NativeArray],
/,
diff --git a/ivy/inspection.py b/ivy/inspection.py
index f6e49708222ef..da5d111973e09 100644
--- a/ivy/inspection.py
+++ b/ivy/inspection.py
@@ -108,7 +108,6 @@ def fn_array_spec(fn):
type_hints = get_type_hints(fn)
except Exception as e:
type_hints = dict()
- print(f"exception found:{e} resorting to type_hindts=dict()")
array_idxs = list()
for i, (k, v) in enumerate(type_hints.items()):
a_idxs = _get_array_idxs(v)
diff --git a/ivy/stateful/converters.py b/ivy/stateful/converters.py
index f2fe96a6dd19b..54569196f836e 100644
--- a/ivy/stateful/converters.py
+++ b/ivy/stateful/converters.py
@@ -248,7 +248,7 @@ def _build(self, *args, **kwargs):
# noinspection PyUnresolvedReferences
params_hk = self._native_module.init(RNG, *args, **kwargs)
params_dict = _hk_flat_map_to_dict(params_hk)
- self._hk_params = ivy.Container(params_dict)
+ self._hk_params = ivy.Container(params_dict, dynamic_backend=False)
param_iterator = self._hk_params.cont_to_iterator()
_, param0 = next(param_iterator)
self._dev = ivy.as_ivy_dev(param0.device())
@@ -352,7 +352,8 @@ def _build(self, *args, **kwargs):
for param in self._native_module.variables
]
)
- )
+ ),
+ dynamic_backend=False,
)
def _forward(self, *a, **kw):
@@ -458,7 +459,8 @@ def _build(self, *args, **kwargs):
).items()
]
)
- )
+ ),
+ dynamic_backend=False,
)
@staticmethod
diff --git a/ivy/stateful/module.py b/ivy/stateful/module.py
index 5511b9afc26f3..94cca38a29361 100644
--- a/ivy/stateful/module.py
+++ b/ivy/stateful/module.py
@@ -3,7 +3,6 @@
# global
import os
import abc
-import ivy.functional.backends.numpy
# local
import ivy
@@ -34,6 +33,7 @@ def __init__(
with_partial_v=False,
devices=None,
dtype=None,
+ dynamic_backend=None,
**kwargs,
):
"""
@@ -119,7 +119,7 @@ def __init__(
self._kwargs = kwargs
if build_mode != "on_init":
return
- self.build(*args, **kwargs)
+ self.build(*args, dynamic_backend=dynamic_backend, **kwargs)
# Private #
# --------#
@@ -511,7 +511,15 @@ def save_weights(self, weights_path, /):
os.makedirs("/".join(weights_path.split("/")[:-1]), exist_ok=True)
self.v.cont_to_disk_as_hdf5(weights_path)
- def build(self, *args, from_call=False, device=None, dtype=None, **kwargs):
+ def build(
+ self,
+ *args,
+ from_call=False,
+ device=None,
+ dtype=None,
+ dynamic_backend=None,
+ **kwargs,
+ ):
"""
Build the internal layers and variables for this module.
@@ -547,8 +555,13 @@ def build(self, *args, from_call=False, device=None, dtype=None, **kwargs):
# build variables based on locally built layers, if v not passed in constructor
v_from_constructor = self._v_in
- created = Container(self._create_variables(device=self._dev, dtype=dtype))
- created_n_found = Container(dict(**self._find_variables(obj=self), **created))
+ created = Container(
+ self._create_variables(device=self._dev, dtype=dtype), dynamic_backend=False
+ )
+ created_n_found = Container(
+ dict(**self._find_variables(obj=self), **created),
+ dynamic_backend=dynamic_backend,
+ )
if ivy.exists(v_from_constructor):
if self._with_partial_v:
if v_from_constructor:
diff --git a/ivy_tests/test_docstrings.py b/ivy_tests/test_docstrings.py
index 96cd508eee257..baa70d188ff65 100644
--- a/ivy_tests/test_docstrings.py
+++ b/ivy_tests/test_docstrings.py
@@ -85,7 +85,6 @@ def check_docstring_examples_run(
).__doc__
else:
docstring = ivy.backend_handler.ivy_original_dict[fn_name].__doc__
-
if docstring is None:
return True
@@ -93,7 +92,6 @@ def check_docstring_examples_run(
trimmed_docstring = trim(docstring=docstring)
trimmed_docstring = trimmed_docstring.split("\n")
-
# end_index: -1, if print statement is not found in the docstring
end_index = -1
@@ -109,6 +107,8 @@ def check_docstring_examples_run(
if s.startswith(">>>") or s.lower().startswith("with"):
end_index = index + i + 1
break
+ else:
+ end_index = len(trimmed_docstring)
p_output = trimmed_docstring[index + 1 : end_index]
p_output = ("").join(p_output).replace(" ", "")
p_output = p_output.replace("...", "")
@@ -124,10 +124,11 @@ def check_docstring_examples_run(
for line in trimmed_docstring:
if line.startswith(">>>"):
executable_lines.append(line.split(">>>")[1][1:])
- if line.startswith("..."):
+ is_multiline_executable = True
+ if line.startswith("...") and is_multiline_executable:
executable_lines[-1] += line.split("...")[1][1:]
if ">>> print(" in line:
- break
+ is_multiline_executable = False
# noinspection PyBroadException
f = StringIO()
diff --git a/ivy_tests/test_ivy/conftest.py b/ivy_tests/test_ivy/conftest.py
index 175a855a4737c..9bc479895dc31 100644
--- a/ivy_tests/test_ivy/conftest.py
+++ b/ivy_tests/test_ivy/conftest.py
@@ -2,7 +2,8 @@
import os
import pytest
from typing import Dict
-
+import subprocess
+import importlib
mod_frontend = {
"tensorflow": None,
@@ -58,7 +59,19 @@ def pytest_configure(config):
if frontend:
frontend_strs = frontend.split(",")
for i in frontend_strs:
- mod_frontend[i.split("/")[0]] = i
+ process = subprocess.Popen(
+ [
+ "/opt/miniconda/envs/multienv/bin/python",
+ "multiversion_frontend_test.py",
+ "numpy" + "/" + importlib.import_module("numpy").__version__,
+ i,
+ ],
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ )
+ mod_frontend[i.split("/")[0]] = [i, process]
# compile_graph
raw_value = config.getoption("--compile_graph")
diff --git a/ivy_tests/test_ivy/helpers/function_testing.py b/ivy_tests/test_ivy/helpers/function_testing.py
index 831443c2fca6b..ae320fd1dd4e9 100644
--- a/ivy_tests/test_ivy/helpers/function_testing.py
+++ b/ivy_tests/test_ivy/helpers/function_testing.py
@@ -11,7 +11,26 @@
import jsonpickle
except:
pass
-import subprocess
+
+
+def framework_comparator(frontend):  # True iff installed backend version matches the "name/version[/...]" spec
+    if frontend.split("/")[0] == "jax":
+        fw = frontend.split("/")[1] + frontend.split("/")[3]  # jax spec bundles jaxlib: "jax/<v>/jaxlib/<v>"
+        backend_fw = (
+            importlib.import_module("jax").__version__
+            + importlib.import_module("jaxlib").__version__
+        )
+        return backend_fw == fw
+    elif frontend.split("/")[0] == "torch":
+        return (
+            frontend.split("/")[1]
+            == importlib.import_module(frontend.split("/")[0]).__version__.split("+")[0]
+        )
+    else:
+        return (
+            frontend.split("/")[1]
+            == importlib.import_module(frontend.split("/")[0]).__version__
+        )
try:
@@ -43,14 +62,10 @@
available_frameworks = available_frameworkss()
-def multiversion_native_array_check(fw):
- dic = {"torch": "Tensor", "tensorflow": "Tensor", "numpy": "ndarray"}
- param = dic[fw.__name__]
-
- def func(val):
- return isinstance(val, getattr(fw, param, None))
-
- return func
+def make_json_pickable(s):  # rewrite type names jsonpickle can't resolve across processes
+    s = s.replace("builtins.bfloat16", "ivy.bfloat16")
+    s = s.replace("jax._src.device_array.reconstruct_device_array", "jax.numpy.array")
+    return s
def empty_func(*args, **kwargs):
@@ -502,6 +517,8 @@ def test_frontend_function(
# frontend function
# parse function name and frontend submodules (jax.lax, jax.numpy etc.)
+ if isinstance(frontend, list):
+ frontend, frontend_proc = frontend
split_index = fn_tree.rfind(".")
frontend_submods, fn_name = fn_tree[:split_index], fn_tree[split_index + 1 :]
function_module = importlib.import_module(frontend_submods)
@@ -622,111 +639,41 @@ def test_frontend_function(
# multiversion zone, changes made in non-multiversion zone should
# be applied here too
- if (
- frontend.split("/")[1]
- != importlib.import_module(frontend.split("/")[0]).__version__
- ):
+ if not framework_comparator(frontend):
try:
- # create frontend framework args
- args_frontend = ivy.nested_map(
- args_np,
- lambda x: ivy.native_array(x)
- if isinstance(x, np.ndarray)
- else ivy.as_native_dtype(x)
- if isinstance(x, ivy.Dtype)
- else x,
- shallow=False,
- )
- kwargs_frontend = ivy.nested_map(
- kwargs_np,
- lambda x: ivy.native_array(x) if isinstance(x, np.ndarray) else x,
- shallow=False,
- )
-
- # change ivy dtypes to native dtypes
- if "dtype" in kwargs_frontend:
- kwargs_frontend["dtype"] = ivy.as_native_dtype(
- kwargs_frontend["dtype"]
- )
-
- # change ivy device to native devices
- if "device" in kwargs_frontend:
- kwargs_frontend["device"] = ivy.as_native_dev(
- kwargs_frontend["device"]
- )
-
- # check and replace the NativeClass objects in arguments
- # with true counterparts
- args_frontend = ivy.nested_map(
- args_frontend, fn=convtrue, include_derived=True, max_depth=10
- )
- kwargs_frontend = ivy.nested_map(
- kwargs_frontend, fn=convtrue, include_derived=True, max_depth=10
- )
# compute the return via the frontend framework
module_name = fn_tree[25 : fn_tree.rfind(".")]
- # frontend_fw = importlib.import_module(module_name)
-
- # temp=dict()
- # if frontend.split("/")[0]=='jax':
- # # we prepare for jaxlib
- # pass
pickle_dict = {"a": args_np, "b": kwargs_np}
-
- frontend_ret = subprocess.run(
- [
- "/opt/miniconda/envs/multienv/bin/python",
- "test.py",
- jsonpickle.dumps(pickle_dict),
- fn_name,
- module_name,
- "numpy" + "/" + np.__version__,
- frontend,
- ],
- capture_output=True,
- text=True,
- )
-
- if frontend_ret.stdout:
- frontend_ret = jsonpickle.loads(frontend_ret.stdout)
+ process = frontend_proc
+ z = make_json_pickable(jsonpickle.dumps(pickle_dict))
+ try:
+ process.stdin.write(z + "\n")
+ process.stdin.write(module_name + "\n")
+ process.stdin.write(fn_name + "\n")
+ process.stdin.flush()
+ except Exception as e:
+ print(
+ "Something bad happened to the subprocess, here are the logs:\n\n"
+ )
+ print(process.stdout.readlines())
+ raise e
+ frontend_ret = process.stdout.readline()
+ if frontend_ret:
+ frontend_ret = jsonpickle.loads(make_json_pickable(frontend_ret))
else:
- print(frontend_ret.stderr)
+ print(process.stderr.readlines())
raise Exception
- ivy.set_backend("numpy")
- frontend_ret = ivy.to_native(frontend_ret)
-
- # globally_done = (
- # frontend.split("/")[0]
- # + "/"
- # + importlib.import_module(frontend.split("/")[0]).__version__
- # )
- # try:
- # frontend_fw = config.custom_import(
- # frontend.split("/")[0] + "/" + frontend.split("/")[1],
- # module_name,
- # globally_done=globally_done,
- # )
- # except Exception as e:
- # raise e
- #
- # print(frontend_fw.__version__)
- # frontend_ret = frontend_fw.__dict__[fn_name](
- # *args_frontend, **kwargs_frontend
- # )
- # frontend_ret = np.asarray(
- # frontend_ret
- # ) # we do this because frontend_ret comes from a module in another file
-
if ivy.isscalar(frontend_ret):
frontend_ret_np_flat = [np.asarray(frontend_ret)]
else:
+ frontend_ret = ivy.to_ivy(frontend_ret)
# tuplify the frontend return
if not isinstance(frontend_ret, tuple):
frontend_ret = (frontend_ret,)
frontend_ret_idxs = ivy.nested_argwhere(
- frontend_ret, ivy.is_native_array
+ frontend_ret, lambda x: isinstance(x, np.ndarray)
)
frontend_ret_flat = ivy.multi_index_nest(
frontend_ret, frontend_ret_idxs
@@ -1387,6 +1334,8 @@ def test_frontend_method(
ret_gt
optional, return value from the Ground Truth function
"""
+ if isinstance(frontend, list):
+ frontend, frontend_proc = frontend
_assert_dtypes_are_valid(init_input_dtypes)
_assert_dtypes_are_valid(method_input_dtypes)
diff --git a/ivy_tests/test_ivy/helpers/globals.py b/ivy_tests/test_ivy/helpers/globals.py
index 0f0adf3fbf427..41707303d3946 100644
--- a/ivy_tests/test_ivy/helpers/globals.py
+++ b/ivy_tests/test_ivy/helpers/globals.py
@@ -176,8 +176,9 @@ def _set_frontend(framework: str):
global CURRENT_FRONTEND
if CURRENT_FRONTEND is not _Notsetval:
raise InterruptedTest(CURRENT_RUNNING_TEST)
- if "/" in framework:
- CURRENT_FRONTEND = FWS_DICT[framework.split("/")[0]]
+ if isinstance(framework, list):
+
+ CURRENT_FRONTEND = FWS_DICT[framework[0].split("/")[0]]
else:
CURRENT_FRONTEND = FWS_DICT[framework]
diff --git a/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py b/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py
index 1c1588fefd618..64cb41a47266f 100644
--- a/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py
+++ b/ivy_tests/test_ivy/helpers/hypothesis_helpers/array_helpers.py
@@ -900,7 +900,13 @@ def array_and_broadcastable_shape(draw, dtype):
@st.composite
def arrays_for_pooling(
- draw, min_dims, max_dims, min_side, max_side, allow_explicit_padding=False
+ draw,
+ min_dims,
+ max_dims,
+ min_side,
+ max_side,
+ allow_explicit_padding=False,
+ return_dilation=False,
):
in_shape = draw(
nph.array_shapes(
@@ -931,24 +937,27 @@ def arrays_for_pooling(
)
if array_dim == 3:
kernel = draw(st.tuples(st.integers(1, in_shape[1])))
+ new_kernel = kernel
+ if return_dilation:
+ new_kernel = []
+ dilations = []
+ for i in range(len(kernel)):
+ if kernel[i] > 1:
+ max_dilation = (in_shape[i + 1] - kernel[i]) // (kernel[i] - 1) + 1
+ dilations.append(draw(st.integers(1, max_dilation)))
+ new_kernel.append(kernel[i] + (kernel[i] - 1) * (dilations[i] - 1))
+ else:
+ dilations.append(1)
+ new_kernel.append(kernel[i])
if allow_explicit_padding:
padding = []
for i in range(array_dim - 2):
- max_pad = kernel[i] // 2
- possible_pad_combos = [
- (i, max_pad - i)
- for i in range(0, max_pad)
- if i + (max_pad - i) == max_pad
- ]
- if len(possible_pad_combos) == 0:
- pad_selected_combo = (0, 0)
- else:
- pad_selected_combo = draw(st.sampled_from(possible_pad_combos))
+ max_pad = new_kernel[i] // 2
padding.append(
draw(
st.tuples(
- st.integers(0, pad_selected_combo[0]),
- st.integers(0, pad_selected_combo[1]),
+ st.integers(0, max_pad),
+ st.integers(0, max_pad),
)
)
)
@@ -956,4 +965,6 @@ def arrays_for_pooling(
else:
padding = draw(st.sampled_from(["VALID", "SAME"]))
strides = draw(st.tuples(st.integers(1, in_shape[1])))
+ if return_dilation:
+ return dtype, x, kernel, strides, padding, dilations
return dtype, x, kernel, strides, padding
diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_creation.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_creation.py
index c272e5ac07419..dcd2202a5c374 100644
--- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_creation.py
+++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_creation.py
@@ -12,8 +12,14 @@
fn_tree="jax.numpy.array",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("numeric"),
+ num_arrays=st.integers(min_value=1, max_value=10),
+ min_num_dims=0,
+ max_num_dims=5,
+ min_dim_size=1,
+ max_dim_size=5,
+ shared_dtype=True,
),
- dtype=helpers.get_dtypes("numeric", full=False, none=True),
+ as_list=st.booleans(),
copy=st.booleans(),
ndmin=helpers.ints(min_value=0, max_value=10),
test_with_out=st.just(False),
@@ -21,7 +27,7 @@
def test_jax_numpy_array(
*,
dtype_and_x,
- dtype,
+ as_list,
copy,
ndmin,
on_device,
@@ -30,14 +36,24 @@ def test_jax_numpy_array(
frontend,
):
input_dtype, x = dtype_and_x
+
+ if as_list:
+ if isinstance(x, list):
+ x = [list(i) if len(i.shape) > 0 else [float(i)] for i in x]
+ else:
+ x = list(x)
+ else:
+ if len(x) == 1:
+ x = x[0]
+
helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
- object=x[0],
- dtype=dtype[0],
+ object=x,
+ dtype=input_dtype[0],
copy=copy,
order="K",
ndmin=ndmin,
@@ -467,3 +483,74 @@ def test_jax_numpy_ndim(
on_device=on_device,
a=x[0],
)
+
+
+# empty_like
+@handle_frontend_test(
+ fn_tree="jax.numpy.empty_like",
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("valid"),
+ ),
+ shape=helpers.get_shape(
+ allow_none=True,
+ min_num_dims=1,
+ max_num_dims=5,
+ min_dim_size=1,
+ max_dim_size=10,
+ ),
+ dtype=helpers.get_dtypes("valid", full=False),
+ test_with_out=st.just(False),
+)
+def test_jax_numpy_empty_like(
+ dtype_and_x,
+ shape,
+ dtype,
+ test_flags,
+ frontend,
+ fn_tree,
+ on_device,
+):
+ input_dtype, x = dtype_and_x
+ helpers.test_frontend_function(
+ input_dtypes=input_dtype,
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ a=x[0],
+ dtype=dtype[0],
+ shape=shape,
+ )
+
+
+# full
+@handle_frontend_test(
+ fn_tree="jax.numpy.full",
+ shape=helpers.get_shape(
+ min_num_dims=1,
+ max_num_dims=5,
+ min_dim_size=1,
+ max_dim_size=10,
+ ),
+ input_fill_dtype=_input_fill_and_dtype(),
+ test_with_out=st.just(False),
+)
+def test_jax_numpy_full(
+ shape,
+ input_fill_dtype,
+ frontend,
+ test_flags,
+ fn_tree,
+ on_device,
+):
+ input_dtype, _, fill_value, dtype = input_fill_dtype
+ helpers.test_frontend_function(
+ input_dtypes=input_dtype,
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ shape=shape,
+ fill_value=fill_value,
+ dtype=dtype,
+ )
diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py
index cb522926d7d05..116c81c8738b2 100644
--- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py
+++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_logic.py
@@ -604,6 +604,33 @@ def test_jax_numpy_invert(
)
+# isfinite
+@handle_frontend_test(
+ fn_tree="jax.numpy.isfinite",
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("numeric"), allow_nan=True
+ ),
+ test_with_out=st.just(False),
+)
+def test_jax_numpy_isfinite(
+ *,
+ dtype_and_x,
+ on_device,
+ fn_tree,
+ frontend,
+ test_flags,
+):
+ input_dtype, x = dtype_and_x
+ helpers.test_frontend_function(
+ input_dtypes=input_dtype,
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ x=x[0],
+ )
+
+
# isinf
@handle_frontend_test(
fn_tree="jax.numpy.isinf",
@@ -661,3 +688,29 @@ def test_jax_numpy_isclose(
b=input[1],
equal_nan=equal_nan,
)
+
+
+# logical_not
+@handle_frontend_test(
+ fn_tree="jax.numpy.logical_not",
+ dtypes_values=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("bool"),
+ num_arrays=1,
+ ),
+)
+def test_jax_numpy_logical_not(
+ dtypes_values,
+ on_device,
+ fn_tree,
+ frontend,
+ test_flags,
+):
+ x_dtypes, x = dtypes_values
+ np_helpers.test_frontend_function(
+ input_dtypes=x_dtypes,
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ x=x[0],
+ )
diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_manipulation.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_manipulation.py
index 489a5a257ce52..5446c03d94f45 100644
--- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_manipulation.py
+++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_manipulation.py
@@ -689,3 +689,73 @@ def test_jax_numpy_atleast_2d(
on_device=on_device,
**arys,
)
+
+
+# atleast_1d
+@handle_frontend_test(
+ fn_tree="jax.numpy.atleast_1d",
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("valid"),
+ num_arrays=helpers.ints(min_value=1, max_value=10),
+ ),
+ test_with_out=st.just(False),
+)
+def test_jax_numpy_atleast_1d(
+ *,
+ dtype_and_x,
+ on_device,
+ fn_tree,
+ frontend,
+ test_flags,
+):
+ input_dtype, arrays = dtype_and_x
+ arys = {}
+ for i, (array, idtype) in enumerate(zip(arrays, input_dtype)):
+ arys["arrs{}".format(i)] = np.asarray(array, dtype=idtype)
+ test_flags.num_positional_args = len(arys)
+ helpers.test_frontend_function(
+ input_dtypes=input_dtype,
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ **arys,
+ )
+
+
+@st.composite
+def _squeeze_helper(draw):
+ shape = draw(st.shared(helpers.get_shape(), key="shape"))
+ valid_axes = [idx for idx in range(len(shape)) if shape[idx] == 1] + [None]
+ return draw(st.sampled_from(valid_axes))
+
+
+# squeeze
+@handle_frontend_test(
+ fn_tree="jax.numpy.squeeze",
+ dtype_and_values=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("valid"),
+ shape=st.shared(helpers.get_shape(), key="shape"),
+ ),
+ axis=_squeeze_helper(),
+ test_with_out=st.just(False),
+)
+def test_jax_numpy_squeeze(
+ *,
+ dtype_and_values,
+ axis,
+ on_device,
+ fn_tree,
+ frontend,
+ test_flags,
+):
+ input_dtype, values = dtype_and_values
+ helpers.test_frontend_function(
+ input_dtypes=input_dtype,
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ a=values[0],
+ axis=axis,
+ )
diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py
index 4d17fb447121e..c4c092395d583 100644
--- a/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py
+++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_jax_numpy_math.py
@@ -70,6 +70,55 @@ def test_jax_numpy_add(
)
+# diff
+@st.composite
+def _get_dtype_input_and_vector(draw):
+ size1 = draw(helpers.ints(min_value=1, max_value=5))
+ size2 = draw(helpers.ints(min_value=1, max_value=5))
+ dtype = draw(helpers.get_dtypes("integer"))
+ vec1 = draw(helpers.array_values(dtype=dtype[0], shape=(size1, size2)))
+ return dtype, vec1
+
+
+@handle_frontend_test(
+ fn_tree="jax.numpy.diff",
+ dtype_and_x=_get_dtype_input_and_vector(),
+ n=helpers.ints(
+ min_value=0,
+ max_value=10,
+ ),
+ axis=helpers.ints(
+ min_value=-1,
+ max_value=10,
+ ),
+)
+def test_jax_numpy_diff(
+ *,
+ dtype_and_x,
+ test_flags,
+ on_device,
+ fn_tree,
+ frontend,
+ n,
+ axis,
+):
+ input_dtype, x = dtype_and_x
+ if axis > (x[0].ndim - 1):
+ axis = x[0].ndim - 1
+ helpers.test_frontend_function(
+ input_dtypes=input_dtype,
+ test_flags=test_flags,
+ frontend=frontend,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ a=x[0],
+ n=n,
+ axis=axis,
+ prepend=None,
+ append=None,
+ )
+
+
# arctan
@handle_frontend_test(
fn_tree="jax.numpy.arctan",
@@ -1402,6 +1451,32 @@ def test_jax_numpy_negative(
)
+# positive
+@handle_frontend_test(
+ fn_tree="jax.numpy.positive",
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("numeric"), min_num_dims=1
+ ),
+ test_with_out=st.just(False),
+)
+def test_jax_numpy_positive(
+ dtype_and_x,
+ frontend,
+ test_flags,
+ fn_tree,
+ on_device,
+):
+ input_dtype, x = dtype_and_x
+ helpers.test_frontend_function(
+ input_dtypes=input_dtype,
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ x=x[0],
+ )
+
+
# rad2deg
@handle_frontend_test(
fn_tree="jax.numpy.rad2deg",
@@ -2114,3 +2189,34 @@ def test_jax_numpy_real(
test_values=True,
val=x[0],
)
+
+
+# inner
+@handle_frontend_test(
+ fn_tree="jax.numpy.inner",
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("numeric"),
+ min_value=-10,
+ max_value=10,
+ num_arrays=2,
+ shared_dtype=True,
+ ),
+)
+def test_jax_numpy_inner(
+ *,
+ dtype_and_x,
+ test_flags,
+ on_device,
+ fn_tree,
+ frontend,
+):
+ input_dtypes, xs = dtype_and_x
+ helpers.test_frontend_function(
+ input_dtypes=input_dtypes,
+ test_flags=test_flags,
+ frontend=frontend,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ a=xs[0],
+ b=xs[1],
+ )
diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_building_matrices.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_building_matrices.py
index 906f5671205ab..8d5a4f1c856aa 100644
--- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_building_matrices.py
+++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_creation_routines/test_building_matrices.py
@@ -5,6 +5,7 @@
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_frontend_test
+from ivy_tests.test_ivy.test_functional.test_core.test_linalg import _diag_helper
# tril
@@ -44,26 +45,16 @@ def test_numpy_tril(
# diag
@handle_frontend_test(
fn_tree="numpy.diag",
- dtype_and_x=helpers.dtype_and_values(
- available_dtypes=helpers.get_dtypes("numeric"),
- num_arrays=1,
- min_num_dims=2,
- max_num_dims=2,
- min_dim_size=1,
- max_dim_size=2,
- ),
- k=helpers.ints(min_value=-10, max_value=10),
- test_with_out=st.just(False),
+ dtype_and_x_k=_diag_helper(),
)
def test_numpy_diag(
- dtype_and_x,
- k,
+ dtype_and_x_k,
frontend,
test_flags,
fn_tree,
on_device,
):
- dtype, x = dtype_and_x
+ dtype, x, k = dtype_and_x_k
helpers.test_frontend_function(
input_dtypes=dtype,
frontend=frontend,
diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linear_algebra/test_decompositions.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linear_algebra/test_decompositions.py
index 7e4f7af324354..e655ece0953bd 100644
--- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linear_algebra/test_decompositions.py
+++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linear_algebra/test_decompositions.py
@@ -89,14 +89,18 @@ def test_numpy_qr(
fn_tree="numpy.linalg.svd",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
- min_value=0,
+ min_value=0.1,
max_value=10,
shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])),
),
+ full_matrices=st.booleans(),
+ compute_uv=st.booleans(),
test_with_out=st.just(False),
)
def test_numpy_svd(
dtype_and_x,
+ full_matrices,
+ compute_uv,
frontend,
test_flags,
fn_tree,
@@ -114,10 +118,11 @@ def test_numpy_svd(
test_values=False,
fn_tree=fn_tree,
on_device=on_device,
- rtol=1e-02,
a=x,
+ full_matrices=full_matrices,
+ compute_uv=compute_uv,
)
for u, v in zip(ret, ret_gt):
u = ivy.to_numpy(ivy.abs(u))
v = ivy.to_numpy(ivy.abs(v))
- helpers.value_test(ret_np_flat=u, ret_np_from_gt_flat=v)
+ helpers.value_test(ret_np_flat=u, ret_np_from_gt_flat=v, rtol=1e-04, atol=1e-04)
diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linear_algebra/test_norms_and_other_numbers.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linear_algebra/test_norms_and_other_numbers.py
index 6d28bdee166b0..da4c24d306a50 100644
--- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linear_algebra/test_norms_and_other_numbers.py
+++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linear_algebra/test_norms_and_other_numbers.py
@@ -156,11 +156,26 @@ def test_numpy_slogdet(
@handle_frontend_test(
fn_tree="numpy.trace",
- dtype_and_x=_get_dtype_and_matrix(),
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("float"),
+ min_num_dims=2,
+ max_num_dims=5,
+ min_dim_size=2,
+ max_dim_size=10,
+ large_abs_safety_factor=2,
+ small_abs_safety_factor=2,
+ safety_factor_scale="log",
+ ),
test_with_out=st.just(False),
+ offset=st.integers(min_value=0, max_value=0),
+ axis1=st.integers(min_value=0, max_value=0),
+ axis2=st.integers(min_value=1, max_value=1),
)
def test_numpy_trace(
dtype_and_x,
+ offset,
+ axis1,
+ axis2,
frontend,
test_flags,
fn_tree,
@@ -174,4 +189,7 @@ def test_numpy_trace(
fn_tree=fn_tree,
on_device=on_device,
a=x[0],
+ offset=offset,
+ axis1=axis1,
+ axis2=axis2,
)
diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linear_algebra/test_solving_equations_and_inverting_matrices.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linear_algebra/test_solving_equations_and_inverting_matrices.py
index 1e96e01e7b126..3c43d9b1fb507 100644
--- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_linear_algebra/test_solving_equations_and_inverting_matrices.py
+++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_linear_algebra/test_solving_equations_and_inverting_matrices.py
@@ -9,6 +9,7 @@
_get_first_matrix,
_get_second_matrix,
)
+from hypothesis import reproduce_failure
# solve
@@ -167,3 +168,56 @@ def test_numpy_tensorinv(
a=x,
ind=ind,
)
+
+
+@st.composite
+def _get_lstsq_matrices(draw):
+ shape1 = draw(helpers.ints(min_value=2, max_value=10))
+ shape2 = draw(helpers.ints(min_value=2, max_value=10))
+ input_dtype = "float64"
+ a = draw(
+ helpers.array_values(
+ dtype=input_dtype,
+ shape=(shape1, shape2),
+ min_value=10,
+ max_value=20,
+ exclude_min=False,
+ exclude_max=False,
+ )
+ )
+ b = draw(
+ helpers.array_values(
+ dtype=input_dtype,
+ shape=(shape1, 1),
+ min_value=10,
+ max_value=20,
+ exclude_min=False,
+ exclude_max=False,
+ )
+ )
+ return input_dtype, a, b
+
+
+# lstsq
+@handle_frontend_test(
+ fn_tree="numpy.linalg.lstsq",
+ params=_get_lstsq_matrices(),
+ test_with_out=st.just(False),
+)
+def test_numpy_lstsq(
+ params,
+ frontend,
+ test_flags,
+ fn_tree,
+ on_device,
+):
+ input_dtype, fir, sec = params
+ helpers.test_frontend_function(
+ input_dtypes=[input_dtype, input_dtype],
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ a=fir,
+ b=sec,
+ )
diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_arithmetic_operations.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_arithmetic_operations.py
index 0fccb645e16cb..3c5ecf446f8ea 100644
--- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_arithmetic_operations.py
+++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_arithmetic_operations.py
@@ -6,6 +6,9 @@
import ivy_tests.test_ivy.helpers as helpers
import ivy_tests.test_ivy.test_frontends.test_numpy.helpers as np_frontend_helpers
from ivy_tests.test_ivy.helpers import handle_frontend_test
+from ivy_tests.test_ivy.test_functional.test_experimental.test_core.test_elementwise import ( # noqa
+ _float_power_helper,
+)
# add
@@ -287,14 +290,9 @@ def test_numpy_power(
@handle_frontend_test(
fn_tree="numpy.float_power",
dtypes_values_casting=np_frontend_helpers.dtypes_values_casting_dtype(
- arr_func=[
- lambda: helpers.dtype_and_values(
- available_dtypes=helpers.get_dtypes("float"),
- num_arrays=2,
- shared_dtype=True,
- )
- ],
- get_dtypes_kind="float",
+ arr_func=[lambda: _float_power_helper()],
+ get_dtypes_kind="float_and_complex",
+ special=False,
),
where=np_frontend_helpers.where(),
number_positional_args=np_frontend_helpers.get_num_positional_args_ufunc(
@@ -309,11 +307,16 @@ def test_numpy_float_power(
fn_tree,
):
input_dtypes, xs, casting, dtype = dtypes_values_casting
+ xs = list(xs[0])
+ input_dtypes = list(input_dtypes[0])
where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools(
where=where,
input_dtype=input_dtypes,
test_flags=test_flags,
)
+ # removing casting options as they raise errors for this function
+ assume(casting == "same_kind")
+ assume(dtype != "bool")
np_frontend_helpers.test_frontend_function(
input_dtypes=input_dtypes,
frontend=frontend,
diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_rounding.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_rounding.py
index 837e2150b70fa..f1297dfd14d31 100644
--- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_rounding.py
+++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_rounding.py
@@ -144,11 +144,10 @@ def test_numpy_rint(
on_device,
):
input_dtype, x, casting, dtype = dtypes_values_casting
- where, as_variable, native_array = np_frontend_helpers.handle_where_and_array_bools(
+ where, input_dtype, test_flags = np_frontend_helpers.handle_where_and_array_bools(
where=where,
input_dtype=input_dtype,
- as_variable=test_flags.as_variable,
- native_array=test_flags.native_arrays,
+ test_flags=test_flags,
)
np_frontend_helpers.test_frontend_function(
input_dtypes=input_dtype,
diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_sums_products_differences.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_sums_products_differences.py
index 62e47d71de6aa..138dcc2e8d814 100644
--- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_sums_products_differences.py
+++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_sums_products_differences.py
@@ -326,3 +326,32 @@ def test_numpy_nansum(
where=where,
keepdims=keepdims,
)
+
+
+# diff
+@handle_frontend_test(
+ fn_tree="numpy.diff",
+ dtype_x_axis=helpers.dtype_values_axis(
+ available_dtypes=helpers.get_dtypes("valid"),
+ min_num_dims=1,
+ valid_axis=True,
+ force_int_axis=True,
+ ),
+)
+def test_numpy_diff(
+ dtype_x_axis,
+ frontend,
+ test_flags,
+ fn_tree,
+ on_device,
+):
+ input_dtype, x, axis = dtype_x_axis
+ np_frontend_helpers.test_frontend_function(
+ input_dtypes=input_dtype,
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ x=x[0],
+ axis=axis,
+ )
diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py
index a0debe8a0223c..62e8fbed5fe76 100644
--- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py
+++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_ndarray/test_ndarray.py
@@ -1215,6 +1215,46 @@ def test_numpy_instance_mul__(
)
+# __floordiv__ test
+@handle_frontend_method(
+ class_tree=CLASS_TREE,
+ init_tree="numpy.array",
+ method_name="__floordiv__",
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("integer"),
+ num_arrays=2,
+ allow_inf=False,
+ large_abs_safety_factor=4,
+ safety_factor_scale="linear",
+ shared_dtype=True,
+ ),
+)
+def test_numpy_instance_floordiv__(
+ dtype_and_x,
+ frontend_method_data,
+ init_flags,
+ method_flags,
+ frontend,
+):
+ input_dtypes, xs = dtype_and_x
+ assume(not np.any(np.isclose(xs[1], 0)))
+ helpers.test_frontend_method(
+ init_input_dtypes=input_dtypes,
+ init_all_as_kwargs_np={
+ "object": xs[0],
+ },
+ method_input_dtypes=input_dtypes,
+ method_all_as_kwargs_np={
+ "value": xs[1],
+ },
+ init_flags=init_flags,
+ method_flags=method_flags,
+ frontend_method_data=frontend_method_data,
+ frontend=frontend,
+ atol_=1,
+ )
+
+
@handle_frontend_method(
class_tree=CLASS_TREE,
init_tree="numpy.array",
@@ -1553,6 +1593,45 @@ def test_numpy_instance_pos__(
)
+@handle_frontend_method(
+ class_tree=CLASS_TREE,
+ init_tree="numpy.array",
+ method_name="__ifloordiv__",
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("integer"),
+ num_arrays=2,
+ allow_inf=False,
+ large_abs_safety_factor=4,
+ safety_factor_scale="linear",
+ shared_dtype=True,
+ ),
+)
+def test_numpy_instance_ifloordiv__(
+ dtype_and_x,
+ frontend_method_data,
+ init_flags,
+ method_flags,
+ frontend,
+):
+ input_dtypes, xs = dtype_and_x
+ assume(not np.any(np.isclose(xs[1], 0)))
+ helpers.test_frontend_method(
+ init_input_dtypes=input_dtypes,
+ init_all_as_kwargs_np={
+ "object": xs[0],
+ },
+ method_input_dtypes=input_dtypes,
+ method_all_as_kwargs_np={
+ "value": xs[1],
+ },
+ init_flags=init_flags,
+ method_flags=method_flags,
+ frontend_method_data=frontend_method_data,
+ frontend=frontend,
+ atol_=1,
+ )
+
+
@handle_frontend_method(
class_tree=CLASS_TREE,
init_tree="numpy.array",
@@ -2257,3 +2336,43 @@ def test_numpy_instance_len__(
frontend=frontend,
frontend_method_data=frontend_method_data,
)
+
+# __array__
+@handle_frontend_method(
+ class_tree=CLASS_TREE,
+ init_tree="numpy.array",
+ method_name="__array__",
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("numeric"),
+ ),
+)
+def test_numpy_instance_array__(
+
+@handle_frontend_method(
+ class_tree=CLASS_TREE,
+ init_tree="numpy.array",
+ method_name="__tobytes__",
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("valid"),
+ ),
+)
+def test_numpy_instance_tobytes__(
+ dtype_and_x,
+ frontend_method_data,
+ init_flags,
+ method_flags,
+ frontend,
+):
+ input_dtypes, x = dtype_and_x
+ helpers.test_frontend_method(
+ init_input_dtypes=input_dtypes,
+ init_all_as_kwargs_np={
+ "object": x[0],
+ },
+ method_input_dtypes=input_dtypes,
+ method_all_as_kwargs_np={},
+ init_flags=init_flags,
+ method_flags=method_flags,
+ frontend=frontend,
+ frontend_method_data=frontend_method_data,
+ )
diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_compat/test_v1/test_nn.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_compat/test_v1/test_nn.py
index e69de29bb2d1d..89d84a2d45766 100644
--- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_compat/test_v1/test_nn.py
+++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_compat/test_v1/test_nn.py
@@ -0,0 +1,73 @@
+# global
+from hypothesis import strategies as st
+import numpy as np
+
+# local
+import ivy_tests.test_ivy.helpers as helpers
+from ivy_tests.test_ivy.helpers import handle_frontend_test
+
+
+@st.composite
+def _batch_norm_helper(draw):
+ num_dims = draw(st.integers(min_value=4, max_value=5))
+ dtype, x = draw(
+ helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("float"),
+ min_num_dims=num_dims,
+ max_num_dims=num_dims,
+ min_value=-1e02,
+ max_value=1e02,
+ )
+ )
+ epsilon = draw(st.floats(min_value=1e-07, max_value=1e-04))
+ factor = draw(st.floats(min_value=0.5, max_value=1))
+ training = draw(st.booleans())
+ if num_dims == 4:
+ data_format = draw(st.sampled_from(["NHWC", "NCHW"]))
+ else:
+ data_format = draw(st.sampled_from(["NDHWC", "NCDHW"]))
+ num_channels = x[0].shape[data_format.rfind("C")]
+ dtypes, vectors = draw(
+ helpers.dtype_and_values(
+ available_dtypes=["float32"],
+ shape=(num_channels,),
+ num_arrays=4,
+ min_value=-1e02,
+ max_value=1e02,
+ )
+ )
+ vectors[3] = np.abs(vectors[3]) # non-negative variance
+ return dtype + dtypes, x, epsilon, factor, training, data_format, vectors
+
+
+@handle_frontend_test(
+ fn_tree="tensorflow.compat.v1.nn.fused_batch_norm",
+ dtypes_args=_batch_norm_helper(),
+ test_with_out=st.just(False),
+)
+def test_tensorflow_fused_batch_norm(
+ *,
+ dtypes_args,
+ test_flags,
+ frontend,
+ fn_tree,
+ on_device,
+):
+ dtypes, x, epsilon, factor, training, data_format, vectors = dtypes_args
+ helpers.test_frontend_function(
+ input_dtypes=dtypes,
+ test_flags=test_flags,
+ frontend=frontend,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ atol=1e-02,
+ x=x[0],
+ scale=vectors[0],
+ offset=vectors[1],
+ mean=vectors[2],
+ variance=vectors[3],
+ epsilon=epsilon,
+ data_format=data_format,
+ is_training=training,
+ exponential_avg_factor=factor,
+ )
diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py
index df28fb7036b0b..d5d349f2b3350 100644
--- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py
+++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py
@@ -97,6 +97,51 @@ def test_tensorflow_clip_by_value(
)
+@st.composite
+def _get_norm_clip_inputs(draw):
+ x_dtype, x = draw(
+ helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("numeric"),
+ min_num_dims=1,
+ min_value=-100,
+ max_value=100,
+ )
+ )
+ norm_dtype, norm = draw(
+ helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("numeric"), shape=(1,)
+ )
+ )
+ print(x_dtype, x, norm_dtype, norm)
+ return x_dtype, x, norm_dtype, norm
+
+
+# clip_by_norm
+@handle_frontend_test(
+ fn_tree="tensorflow.clip_by_norm",
+ input_and_norm=_get_norm_clip_inputs(),
+ test_with_out=st.just(False),
+)
+def test_tensorflow_clip_by_norm(
+ *,
+ input_and_norm,
+ frontend,
+ test_flags,
+ fn_tree,
+ on_device,
+):
+ x_dtype, x, norm_dtype, norm = input_and_norm
+ helpers.test_frontend_function(
+ input_dtypes=[x_dtype, norm_dtype],
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ t=x[0],
+ clip_norm=norm[0],
+ )
+
+
# eye
@handle_frontend_test(
fn_tree="tensorflow.eye",
@@ -1282,17 +1327,10 @@ def _boolean_mask_helper(draw):
# Param: tensor
tensor = draw(
helpers.array_values(
- dtype=dtype,
- shape=tensor_shape,
- min_value=-5.0,
- max_value=5.0),
- )
- mask_dim = draw(
- helpers.number(
- min_value=1,
- max_value=len(tensor_shape)
- )
+ dtype=dtype, shape=tensor_shape, min_value=-5.0, max_value=5.0
+ ),
)
+ mask_dim = draw(helpers.number(min_value=1, max_value=len(tensor_shape)))
mask_shape = tensor_shape[:mask_dim]
# Param:stop
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_convolution_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_convolution_functions.py
index 95ba4b3af5662..fe11c3deaeba4 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_convolution_functions.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_convolution_functions.py
@@ -1,9 +1,14 @@
# global
+import math
from hypothesis import strategies as st
# local
+import ivy
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_frontend_test
+from ivy_tests.test_ivy.test_functional.test_nn.test_layers import (
+ _assume_tf_dilation_gt_1,
+)
@st.composite
@@ -16,13 +21,27 @@ def x_and_filters(draw, dim: int = 2, transpose: bool = False):
st.integers(min_value=1, max_value=3),
)
)
- padding = draw(
- st.one_of(
- st.sampled_from(["same", "valid"]) if strides == 1 else st.just("valid"),
- st.integers(min_value=1, max_value=3),
- st.lists(st.integers(min_value=1, max_value=2), min_size=dim, max_size=dim),
+ if not transpose:
+ padding = draw(
+ st.one_of(
+ st.sampled_from(["same", "valid"])
+ if strides == 1
+ else st.just("valid"),
+ st.integers(min_value=1, max_value=3),
+ st.lists(
+ st.integers(min_value=1, max_value=2), min_size=dim, max_size=dim
+ ),
+ )
+ )
+ else:
+ padding = draw(
+ st.one_of(
+ st.integers(min_value=1, max_value=3),
+ st.lists(
+ st.integers(min_value=1, max_value=2), min_size=dim, max_size=dim
+ ),
+ )
)
- )
batch_size = draw(st.integers(1, 5))
filter_shape = draw(
helpers.get_shape(
@@ -38,9 +57,14 @@ def x_and_filters(draw, dim: int = 2, transpose: bool = False):
else:
group_list = list(filter(lambda x: (output_channels % x == 0), group_list))
fc = draw(st.sampled_from(group_list))
- dilations = draw(st.integers(1, 3))
-
- x_dim = []
+ dilations = draw(
+ st.one_of(
+ st.lists(st.integers(min_value=1, max_value=3), min_size=dim, max_size=dim),
+ st.integers(min_value=1, max_value=3),
+ )
+ )
+ full_strides = [strides] * dim if isinstance(strides, int) else strides
+ full_dilations = [dilations] * dim if isinstance(dilations, int) else dilations
if transpose:
x_dim = draw(
helpers.get_shape(
@@ -48,8 +72,9 @@ def x_and_filters(draw, dim: int = 2, transpose: bool = False):
)
)
else:
+ x_dim = []
for i in range(dim):
- min_x = filter_shape[i] + (filter_shape[i] - 1) * (dilations - 1)
+ min_x = filter_shape[i] + (filter_shape[i] - 1) * (full_dilations[i] - 1)
x_dim.append(draw(st.integers(min_x, 15)))
x_dim = tuple(x_dim)
if not transpose:
@@ -57,7 +82,7 @@ def x_and_filters(draw, dim: int = 2, transpose: bool = False):
filter_shape = (output_channels, input_channels // fc) + filter_shape
else:
input_channels = input_channels * fc
- filter_shape = filter_shape + (input_channels, output_channels // fc)
+ filter_shape = (input_channels, output_channels // fc) + filter_shape
x_shape = (batch_size, input_channels) + x_dim
vals = draw(
helpers.array_values(
@@ -83,7 +108,29 @@ def x_and_filters(draw, dim: int = 2, transpose: bool = False):
max_value=1.0,
)
)
- return dtype, vals, filters, bias, dilations, strides, padding, fc
+ if transpose:
+ output_padding = draw(
+ st.lists(st.integers(min_value=1, max_value=2), min_size=dim, max_size=dim)
+ )
+ for i, p in enumerate(output_padding):
+ m = min(full_strides[i], full_dilations[i])
+ if p >= m:
+ output_padding[i] = m - 1
+ if draw(st.booleans()):
+ output_padding = min(output_padding)
+ return (
+ dtype,
+ vals,
+ filters,
+ bias,
+ dilations,
+ strides,
+ padding,
+ output_padding,
+ fc,
+ )
+ else:
+ return dtype, vals, filters, bias, dilations, strides, padding, fc
@handle_frontend_test(
@@ -99,6 +146,8 @@ def test_torch_conv1d(
test_flags,
):
dtype, vals, weight, bias, dilations, strides, padding, fc = dtype_vals
+ # ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it.
+ _assume_tf_dilation_gt_1(ivy.current_backend_str(), on_device, dilations)
helpers.test_frontend_function(
input_dtypes=dtype,
frontend=frontend,
@@ -128,6 +177,7 @@ def test_torch_conv2d(
test_flags,
):
dtype, vals, weight, bias, dilations, strides, padding, fc = dtype_vals
+ _assume_tf_dilation_gt_1(ivy.current_backend_str(), on_device, dilations)
helpers.test_frontend_function(
input_dtypes=dtype,
frontend=frontend,
@@ -157,6 +207,7 @@ def test_torch_conv3d(
test_flags,
):
dtype, vals, weight, bias, dilations, strides, padding, fc = dtype_vals
+ _assume_tf_dilation_gt_1(ivy.current_backend_str(), on_device, dilations)
helpers.test_frontend_function(
input_dtypes=dtype,
frontend=frontend,
@@ -173,94 +224,246 @@ def test_torch_conv3d(
)
+@handle_frontend_test(
+ fn_tree="torch.nn.functional.conv_transpose1d",
+ dtype_vals=x_and_filters(dim=1, transpose=True),
+)
+def test_torch_conv_transpose1d(
+ *,
+ dtype_vals,
+ on_device,
+ fn_tree,
+ frontend,
+ test_flags,
+):
+ dtype, vals, weight, bias, dilations, strides, padding, output_pad, fc = dtype_vals
+ # ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it.
+ _assume_tf_dilation_gt_1(ivy.current_backend_str(), on_device, dilations)
+ helpers.test_frontend_function(
+ input_dtypes=dtype,
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ input=vals,
+ weight=weight,
+ bias=bias,
+ stride=1,
+ # stride=strides,
+ padding=0,
+ # padding=padding,
+ output_padding=0,
+ # output_padding=output_pad,
+ groups=fc,
+ dilation=dilations,
+ )
+
+
+@handle_frontend_test(
+ fn_tree="torch.nn.functional.conv_transpose2d",
+ dtype_vals=x_and_filters(dim=2, transpose=True),
+)
+def test_torch_conv_transpose2d(
+ *,
+ dtype_vals,
+ on_device,
+ fn_tree,
+ frontend,
+ test_flags,
+):
+ dtype, vals, weight, bias, dilations, strides, padding, output_pad, fc = dtype_vals
+ _assume_tf_dilation_gt_1(ivy.current_backend_str(), on_device, dilations)
+ helpers.test_frontend_function(
+ input_dtypes=dtype,
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ input=vals,
+ weight=weight,
+ bias=bias,
+ stride=1,
+ # stride=strides,
+ padding=0,
+ # padding=padding,
+ output_padding=0,
+ # output_padding=output_pad,
+ groups=fc,
+ dilation=dilations,
+ )
+
+
+@handle_frontend_test(
+ fn_tree="torch.nn.functional.conv_transpose3d",
+ dtype_vals=x_and_filters(dim=3, transpose=True),
+)
+def test_torch_conv_transpose3d(
+ *,
+ dtype_vals,
+ on_device,
+ fn_tree,
+ frontend,
+ test_flags,
+):
+ dtype, vals, weight, bias, dilations, strides, padding, output_pad, fc = dtype_vals
+ _assume_tf_dilation_gt_1(ivy.current_backend_str(), on_device, dilations)
+ helpers.test_frontend_function(
+ input_dtypes=dtype,
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ input=vals,
+ weight=weight,
+ bias=bias,
+ stride=1,
+ # stride=strides,
+ padding=0,
+ # padding=padding,
+ output_padding=0,
+ # output_padding=output_pad,
+ groups=fc,
+ dilation=dilations,
+ )
+
+
@st.composite
-def _int_or_tuple(draw, min_val, max_val):
- val = draw(
+def _fold_unfold_helper(draw, dim):
+ stride = draw(
+ st.one_of(
+ st.lists(st.integers(min_value=1, max_value=3), min_size=dim, max_size=dim),
+ st.integers(min_value=1, max_value=3),
+ )
+ )
+ padding = draw(
st.one_of(
- st.integers(min_val, max_val),
- st.tuples(
- st.integers(min_val, max_val),
- st.integers(min_val, max_val),
+ st.integers(min_value=1, max_value=3),
+ st.lists(st.integers(min_value=1, max_value=2), min_size=dim, max_size=dim),
+ )
+ )
+ dilation = draw(
+ st.one_of(
+ st.lists(st.integers(min_value=1, max_value=3), min_size=dim, max_size=dim),
+ st.integers(min_value=1, max_value=3),
+ )
+ )
+ kernel_size = draw(
+ st.one_of(
+ st.integers(min_value=1, max_value=5),
+ helpers.get_shape(
+ min_num_dims=dim, max_num_dims=dim, min_dim_size=1, max_dim_size=5
),
)
)
- return val
+ return stride, padding, dilation, kernel_size
+
+
+@st.composite
+def _unfold_helper(draw, dim=2):
+ stride, padding, dilation, kernel_size = draw(_fold_unfold_helper(dim))
+ dilations = [dilation] * dim if isinstance(dilation, int) else dilation
+ kernel_sizes = [kernel_size] * dim if isinstance(kernel_size, int) else kernel_size
+ x_dim = []
+ for i in range(dim):
+ min_x = kernel_sizes[i] + (kernel_sizes[i] - 1) * (dilations[i] - 1)
+ x_dim.append(draw(st.integers(min_x, 15)))
+ batch_size = draw(st.integers(1, 5))
+ input_channels = draw(st.integers(1, 3))
+ x_shape = (batch_size, input_channels) + tuple(x_dim)
+ dtype, [vals] = draw(
+ helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("float"),
+ shape=x_shape,
+ min_value=0.0,
+ max_value=1.0,
+ )
+ )
+ return dtype, vals, kernel_size, dilation, stride, padding
@handle_frontend_test(
fn_tree="torch.nn.functional.unfold",
- dtype_and_input_and_shape=helpers.dtype_and_values(
- available_dtypes=helpers.get_dtypes("numeric"),
- shape=(1, 3, 6, 6),
- ),
- kernel_size=_int_or_tuple(2, 5),
- dilation=_int_or_tuple(1, 3),
- padding=_int_or_tuple(0, 2),
- stride=_int_or_tuple(1, 3),
+ dtype_vals=_unfold_helper(),
)
def test_torch_unfold(
*,
- dtype_and_input_and_shape,
- kernel_size,
- dilation,
- padding,
- stride,
+ dtype_vals,
on_device,
fn_tree,
frontend,
test_flags,
):
- args_dtypes = list([dtype_and_input_and_shape[0][0]] + ["uint8"] * 4)
+ dtype, vals, kernel_shape, dilations, strides, padding = dtype_vals
helpers.test_frontend_function(
- input_dtypes=args_dtypes,
+ input_dtypes=dtype,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
- input=dtype_and_input_and_shape[1][0],
- kernel_size=kernel_size,
- dilation=dilation,
+ input=vals,
+ kernel_size=kernel_shape,
+ dilation=dilations,
padding=padding,
- stride=stride,
+ stride=strides,
)
+@st.composite
+def _fold_helper(draw, dim=2):
+ stride, padding, dilation, kernel_size = draw(_fold_unfold_helper(dim))
+ strides = [stride] * dim if isinstance(stride, int) else stride
+ paddings = [padding] * dim if isinstance(padding, int) else padding
+ dilations = [dilation] * dim if isinstance(dilation, int) else dilation
+ kernel_sizes = [kernel_size] * dim if isinstance(kernel_size, int) else kernel_size
+ output_shape = ()
+ for i in range(dim):
+ min_dim = kernel_sizes[i] + (kernel_sizes[i] - 1) * (dilations[i] - 1)
+ output_shape = output_shape + (draw(st.integers(min_dim, 15)),)
+ batch_size = draw(st.integers(1, 5))
+ n_channels = draw(st.integers(1, 3))
+ x_shape = [
+ (output_shape[i] + 2 * paddings[i] - dilations[i] * (kernel_sizes[i] - 1) - 1)
+ // strides[i]
+ + 1
+        for i in range(dim)
+ ]
+ x_shape = (batch_size, n_channels * math.prod(kernel_sizes), math.prod(x_shape))
+ dtype, [vals] = draw(
+ helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("float"),
+ shape=x_shape,
+ min_value=0.0,
+ max_value=1.0,
+ )
+ )
+ if vals.shape[0] == 1: # un-batched inputs are also supported
+ vals = draw(st.one_of(st.just(vals), st.just(ivy.squeeze(vals, axis=0))))
+ return dtype, vals, kernel_size, output_shape, dilation, stride, padding
+
+
@handle_frontend_test(
fn_tree="torch.nn.functional.fold",
- dtype_and_input_and_shape=helpers.dtype_and_values(
- available_dtypes=helpers.get_dtypes("numeric"),
- shape=(1, 12, 12),
- ),
- output_size=_int_or_tuple(3, 5),
- kernel_size=_int_or_tuple(2, 5),
- dilation=_int_or_tuple(1, 3),
- padding=_int_or_tuple(0, 2),
- stride=_int_or_tuple(1, 3),
+ dtype_vals=_fold_helper(),
)
def test_torch_fold(
*,
- dtype_and_input_and_shape,
- output_size,
- kernel_size,
- dilation,
- padding,
- stride,
+ dtype_vals,
on_device,
fn_tree,
frontend,
test_flags,
):
- args_dtypes = list([dtype_and_input_and_shape[0][0]] + ["uint8"] * 5)
+ dtype, vals, kernel_shape, output_shape, dilations, strides, padding = dtype_vals
helpers.test_frontend_function(
- input_dtypes=args_dtypes,
+ input_dtypes=dtype,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
- input=dtype_and_input_and_shape[1][0],
- output_size=output_size,
- kernel_size=kernel_size,
- dilation=dilation,
+ input=vals,
+ output_size=output_shape,
+ kernel_size=kernel_shape,
+ dilation=dilations,
padding=padding,
- stride=stride,
+ stride=strides,
)
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_indexing_slicing_joining_mutating_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_indexing_slicing_joining_mutating_ops.py
index 347b06e9464d2..98b5f35097cd4 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_indexing_slicing_joining_mutating_ops.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_indexing_slicing_joining_mutating_ops.py
@@ -2,7 +2,9 @@
from hypothesis import strategies as st, assume
import math
+
# local
+import ivy
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_frontend_test
from ivy_tests.test_ivy.test_functional.test_core.test_manipulation import _get_splits
@@ -922,7 +924,10 @@ def _get_split_locations(draw, min_num_dims, axis):
shape = draw(
st.shared(helpers.get_shape(min_num_dims=min_num_dims), key="value_shape")
)
- axis = draw(st.just(axis))
+ if len(shape) == 1:
+ axis = draw(st.just(0))
+ else:
+ axis = draw(st.just(axis))
@st.composite
def get_int_split(draw):
@@ -979,6 +984,46 @@ def test_torch_dsplit(
)
+# hsplit
+@handle_frontend_test(
+ fn_tree="torch.hsplit",
+ dtype_value=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("valid"),
+ shape=st.shared(helpers.get_shape(min_num_dims=1), key="value_shape"),
+ ),
+ indices_or_sections=_get_split_locations(min_num_dims=1, axis=1),
+ number_positional_args=st.just(2),
+)
+def test_torch_hsplit(
+ *,
+ dtype_value,
+ indices_or_sections,
+ on_device,
+ fn_tree,
+ frontend,
+ test_flags,
+):
+ input_dtype, value = dtype_value
+ # TODO: remove the assumption when these bugfixes are merged and version-pinned
+ # https://github.com/tensorflow/tensorflow/pull/59523
+ # https://github.com/google/jax/pull/14275
+ assume(
+ not (
+ len(value[0].shape) == 1
+ and ivy.current_backend_str() in ("tensorflow", "jax")
+ )
+ )
+ helpers.test_frontend_function(
+ input_dtypes=input_dtype,
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ input=value[0],
+ indices_or_sections=indices_or_sections,
+ )
+
+
# row_stack
@handle_frontend_test(
fn_tree="torch.row_stack",
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_loss_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_loss_functions.py
index 3a434d06954cc..c737273cd469b 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_loss_functions.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_loss_functions.py
@@ -4,6 +4,7 @@
# local
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_frontend_test
+import ivy
# cross_entropy
@@ -242,6 +243,70 @@ def test_torch_binary_cross_entropy_with_logits(
)
+# cosine_embedding_loss
+@handle_frontend_test(
+ fn_tree="torch.nn.functional.cosine_embedding_loss",
+ dtype_and_inputs=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("float"),
+ min_value=2,
+ max_value=5,
+ min_num_dims=1,
+ max_num_dims=2,
+ min_dim_size=2,
+ shared_dtype=True,
+ num_arrays=2,
+ ),
+ margin=st.floats(
+ min_value=-1.0,
+ max_value=1.0,
+ width=16,
+ ),
+ size_average=st.booleans(),
+ reduce=st.booleans(),
+ reduction=st.sampled_from(["none", "mean", "sum"]),
+ test_with_out=st.just(False),
+)
+def test_torch_cosine_embedding_loss(
+ *,
+ dtype_and_inputs,
+ margin,
+ size_average,
+ reduce,
+ reduction,
+ test_flags,
+ fn_tree,
+ frontend,
+ on_device,
+):
+ input_dtype, x = dtype_and_inputs
+ input1_dtype, input1 = input_dtype[0], x[0]
+ input2_dtype, input2 = input_dtype[1], x[1]
+
+ if input1.ndim == input2.ndim == 1:
+ tar = ivy.array(1.0)
+ else:
+ third = input1.shape[0] // 3
+ ones = ivy.ones(input1.shape[0] - (third * 2))
+ minus_ones = ivy.ones(third) * -1
+ randoms = ivy.random_uniform(shape=[third])
+ tar = ivy.hstack((ones, minus_ones, randoms)).shuffle()
+
+ helpers.test_frontend_function(
+ input_dtypes=[input1_dtype, input2_dtype],
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ input1=input1,
+ input2=input2,
+ target=tar,
+ margin=margin,
+ size_average=size_average,
+ reduce=reduce,
+ reduction=reduction,
+ )
+
+
# mse_loss
@handle_frontend_test(
fn_tree="torch.nn.functional.mse_loss",
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_non_linear_activation_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_non_linear_activation_functions.py
index 86454f137c042..1abe7793ca7b6 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_non_linear_activation_functions.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_non_linear_activation_functions.py
@@ -64,7 +64,7 @@ def test_torch_sigmoid(
force_int_axis=True,
valid_axis=True,
),
- dtypes=helpers.get_dtypes("float", none=True, full=False),
+ dtypes=helpers.get_dtypes("float", full=False),
)
def test_torch_softmax(
*,
@@ -94,6 +94,7 @@ def test_torch_softmax(
fn_tree="torch.nn.functional.gelu",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
+ max_value=1e04,
),
)
def test_torch_gelu(
@@ -212,7 +213,7 @@ def test_torch_logsigmoid(
force_int_axis=True,
valid_axis=True,
),
- dtypes=helpers.get_dtypes("float", none=True, full=False),
+ dtypes=helpers.get_dtypes("float", full=False),
)
def test_torch_softmin(
*,
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_pooling_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_pooling_functions.py
index 37a79f9d866ad..322bf105d0d0f 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_pooling_functions.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_pooling_functions.py
@@ -56,3 +56,47 @@ def test_torch_avg_pool2d(
count_include_pad=True,
divisor_override=None,
)
+
+
+# max_pool2d
+@handle_frontend_test(
+ fn_tree="torch.nn.functional.max_pool2d",
+ x_k_s_p=helpers.arrays_for_pooling(
+ min_dims=4,
+ max_dims=4,
+ min_side=1,
+ max_side=4,
+ allow_explicit_padding=True,
+ return_dilation=True,
+ ).filter(lambda x: x[4] != "VALID" and x[4] != "SAME"),
+ test_with_out=st.just(False),
+ ceil_mode=st.just(True),
+)
+def test_torch_max_pool2d(
+ x_k_s_p,
+ ceil_mode,
+ *,
+ test_flags,
+ frontend,
+ fn_tree,
+ on_device,
+):
+ dtype, x, kernel, stride, pad, dilation = x_k_s_p
+ # Torch ground truth func expects input to be consistent
+ # with a channels first format i.e. NCHW
+ x[0] = x[0].reshape((x[0].shape[0], x[0].shape[-1], *x[0].shape[1:-1]))
+ pad = (pad[0][0], pad[1][0])
+
+ helpers.test_frontend_function(
+ input_dtypes=dtype,
+ test_flags=test_flags,
+ frontend=frontend,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ input=x[0],
+ kernel_size=kernel,
+ stride=stride,
+ padding=pad,
+ dilation=dilation,
+ ceil_mode=ceil_mode,
+ )
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py
index 601312a515eed..9a5932a6763d5 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py
@@ -634,3 +634,35 @@ def test_torch_var_mean(
unbiased=bool(correction),
keepdim=keepdims,
)
+
+
+@handle_frontend_test(
+ fn_tree="torch.aminmax",
+ dtype_input_axis=helpers.dtype_values_axis(
+ available_dtypes=helpers.get_dtypes("numeric"),
+ min_num_dims=1,
+ min_axis=-1,
+ max_axis=0,
+ ),
+ keepdims=st.booleans(),
+)
+def test_torch_aminmax(
+ *,
+ dtype_input_axis,
+ keepdims,
+ test_flags,
+ on_device,
+ fn_tree,
+ frontend,
+):
+ input_dtype, x, axis = dtype_input_axis
+ helpers.test_frontend_function(
+ input_dtypes=input_dtype,
+ frontend=frontend,
+ test_flags=test_flags,
+ fn_tree=fn_tree,
+ on_device=on_device,
+ input=x[0],
+ dim=axis,
+ keepdim=keepdims,
+ )
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
index a04c45f1186ba..b5689dd8d7963 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
@@ -4,7 +4,6 @@
import ivy
import torch
from hypothesis import strategies as st, given
-import hypothesis.extra.numpy as hnp
# local
import ivy_tests.test_ivy.helpers as helpers
@@ -95,7 +94,7 @@ def test_torch_tensor_property_dtype(
dtype_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
ret_shape=True,
- ),
+ ).filter(lambda x: "bfloat16" not in x[0]),
)
def test_torch_tensor_property_shape(dtype_x):
dtype, data, shape = dtype_x
@@ -2171,7 +2170,7 @@ def _fill_value_and_size(
key="shape",
)
)
- fill_value = draw(helpers.ints())
+ fill_value = draw(helpers.ints()) if "int" in dtype[0] else draw(helpers.floats())
return dtype, [array, size, fill_value]
@@ -2248,15 +2247,24 @@ def test_torch_instance_new_empty(
@st.composite
def _expand_helper(draw):
- shape, _ = draw(hnp.mutually_broadcastable_shapes(num_shapes=2, min_dims=2))
- shape1, shape2 = shape
- dtype_x = draw(
+ num_dims = draw(st.integers(min_value=1, max_value=10))
+ shape = draw(
+ helpers.get_shape(min_num_dims=num_dims, max_num_dims=num_dims).filter(
+ lambda x: any(i == 1 for i in x)
+ )
+ )
+ new_shape = draw(
+ helpers.get_shape(min_num_dims=num_dims, max_num_dims=num_dims).filter(
+ lambda x: all(x[i] == v if v != 1 else True for i, v in enumerate(shape))
+ )
+ )
+ dtype, x = draw(
helpers.dtype_and_values(
- available_dtypes=helpers.get_dtypes("valid", full=True), shape=shape1
+ available_dtypes=helpers.get_dtypes("valid"),
+ shape=shape,
)
)
- dtype, x = dtype_x
- return dtype, x, shape1
+ return dtype, x, new_shape
@handle_frontend_method(
@@ -2264,22 +2272,68 @@ def _expand_helper(draw):
init_tree="torch.tensor",
method_name="expand",
dtype_x_shape=_expand_helper(),
+ unpack_shape=st.booleans(),
)
def test_torch_instance_expand(
dtype_x_shape,
+ unpack_shape,
frontend_method_data,
init_flags,
method_flags,
frontend,
):
input_dtype, x, shape = dtype_x_shape
+ if unpack_shape:
+ method_flags.num_positional_args = len(shape) + 1
+ size = {}
+ i = 0
+ for x_ in shape:
+ size["x{}".format(i)] = x_
+ i += 1
+ else:
+ size = {
+ "size": shape,
+ }
+ helpers.test_frontend_method(
+ init_input_dtypes=input_dtype,
+ init_all_as_kwargs_np={
+ "data": x[0],
+ },
+ method_input_dtypes=input_dtype,
+ method_all_as_kwargs_np=size,
+ frontend_method_data=frontend_method_data,
+ init_flags=init_flags,
+ method_flags=method_flags,
+ frontend=frontend,
+ )
+
+
+# expand_as
+@handle_frontend_method(
+ class_tree=CLASS_TREE,
+ init_tree="torch.tensor",
+ method_name="expand_as",
+ dtype_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("valid", full=True), num_arrays=2
+ ),
+)
+def test_torch_instance_expand_as(
+ dtype_x,
+ frontend_method_data,
+ init_flags,
+ method_flags,
+ frontend,
+):
+ input_dtype, x = dtype_x
helpers.test_frontend_method(
init_input_dtypes=input_dtype,
init_all_as_kwargs_np={
"data": x[0],
},
method_input_dtypes=input_dtype,
- method_all_as_kwargs_np={str(i): s for i, s in enumerate(shape)},
+ method_all_as_kwargs_np={
+ "other": x[1],
+ },
frontend_method_data=frontend_method_data,
init_flags=init_flags,
method_flags=method_flags,
@@ -4238,6 +4292,35 @@ def test_torch_instance_clamp(
)
+# clamp_
+@handle_frontend_method(
+ class_tree=CLASS_TREE,
+ init_tree="torch.tensor",
+ method_name="clamp_",
+ dtype_and_x_min_max=_get_clamp_inputs(),
+)
+def test_torch_instance_clamp_(
+ dtype_and_x_min_max,
+ frontend,
+ frontend_method_data,
+ init_flags,
+ method_flags,
+):
+ input_dtype, x, min, max = dtype_and_x_min_max
+ helpers.test_frontend_method(
+ init_input_dtypes=input_dtype,
+ init_all_as_kwargs_np={
+ "data": x[0],
+ },
+ method_input_dtypes=input_dtype,
+ method_all_as_kwargs_np={"min": min, "max": max},
+ frontend_method_data=frontend_method_data,
+ init_flags=init_flags,
+ method_flags=method_flags,
+ frontend=frontend,
+ )
+
+
# __gt__
@handle_frontend_method(
class_tree=CLASS_TREE,
diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_device.py b/ivy_tests/test_ivy/test_functional/test_core/test_device.py
index e6531a78ebc3a..a97b914fc621c 100644
--- a/ivy_tests/test_ivy/test_functional/test_core/test_device.py
+++ b/ivy_tests/test_ivy/test_functional/test_core/test_device.py
@@ -67,31 +67,22 @@ def _empty_dir(path, recreate=False):
# Device Queries #
+
# dev
@handle_test(
fn_tree="functional.ivy.dev",
- array_shape=helpers.lists(
- arg=helpers.ints(min_value=2, max_value=3),
- min_size="num_dims",
- max_size="num_dims",
- size_bounds=[1, 3],
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("numeric"),
),
- dtype=helpers.get_dtypes("numeric", full=False),
- as_variable_flags=st.booleans(),
)
-def test_dev(
- *,
- array_shape,
- dtype,
- test_flags,
- backend_fw,
-):
- assume(not (backend_fw == "torch" and "int" in dtype))
- x = np.random.uniform(size=tuple(array_shape)).astype(dtype[0])
+def test_dev(*, dtype_and_x, test_flags):
+ dtype, x = dtype_and_x
+ dtype = dtype[0]
+ x = x[0]
for device in _get_possible_devices():
x = ivy.array(x, device=device)
- if test_flags.as_variable and ivy.is_float_dtype(dtype[0]):
+ if test_flags.as_variable and ivy.is_float_dtype(dtype):
x = _variable(x)
ret = ivy.dev(x)
@@ -111,28 +102,18 @@ def test_dev(
# as_ivy_dev
@handle_test(
fn_tree="functional.ivy.as_ivy_dev",
- array_shape=helpers.lists(
- arg=helpers.ints(min_value=2, max_value=3),
- min_size="num_dims",
- max_size="num_dims",
- size_bounds=[1, 3],
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("numeric"),
),
- dtype=helpers.get_dtypes("numeric", full=False),
)
-def test_as_ivy_dev(
- *,
- array_shape,
- dtype,
- test_flags,
- backend_fw,
-):
- assume(not (backend_fw == "torch" and "int" in dtype))
-
- x = np.random.uniform(size=tuple(array_shape)).astype(dtype[0])
+def test_as_ivy_dev(*, dtype_and_x, test_flags):
+ dtype, x = dtype_and_x
+ dtype = dtype[0]
+ x = x[0]
for device in _get_possible_devices():
x = ivy.array(x, device=device)
- if test_flags.as_variable and ivy.is_float_dtype(dtype[0]):
+ if test_flags.as_variable and ivy.is_float_dtype(dtype):
x = _variable(x)
native_device = ivy.dev(x, as_native=True)
@@ -145,26 +126,16 @@ def test_as_ivy_dev(
# as_native_dev
-# TODO: possible refactor to use the helpers.test_function method
@handle_test(
fn_tree="functional.ivy.as_native_dev",
- array_shape=helpers.lists(
- arg=helpers.ints(min_value=1, max_value=3),
- min_size="num_dims",
- max_size="num_dims",
- size_bounds=[1, 3],
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("numeric"),
),
- dtype=helpers.get_dtypes("float", index=1, full=False),
)
-def test_as_native_dev(
- *,
- array_shape,
- dtype,
- test_flags,
- on_device,
-):
- # TODO: should be replaced with the helpers.dtype_values function
- x = np.random.uniform(size=tuple(array_shape)).astype(dtype[0])
+def test_as_native_dev(*, dtype_and_x, test_flags, on_device):
+ dtype, x = dtype_and_x
+ dtype = dtype[0]
+ x = x[0]
for device in _get_possible_devices():
x = ivy.asarray(x, device=on_device)
@@ -212,34 +183,30 @@ def test_default_device():
# to_dev
@handle_test(
fn_tree="functional.ivy.to_device",
- array_shape=helpers.lists(
- arg=helpers.ints(min_value=1, max_value=3),
- min_size="num_dims",
- max_size="num_dims",
- size_bounds=[1, 3],
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("numeric"),
),
- dtype=helpers.get_dtypes("numeric", full=False),
stream=helpers.ints(min_value=0, max_value=50),
)
def test_to_device(
*,
- array_shape,
- dtype,
+ dtype_and_x,
stream,
test_flags,
backend_fw,
on_device,
):
- assume(not (backend_fw == "torch" and "int" in dtype))
+ dtype, x = dtype_and_x
+ dtype = dtype[0]
+ x = x[0]
- x = np.random.uniform(size=tuple(array_shape)).astype(dtype[0])
x = ivy.asarray(x)
- if test_flags.as_variable and ivy.is_float_dtype(dtype[0]):
+ if test_flags.as_variable and ivy.is_float_dtype(dtype):
x = _variable(x)
# create a dummy array for out that is broadcastable to x
out = (
- ivy.zeros(ivy.shape(x), device=on_device, dtype=dtype[0])
+ ivy.zeros(ivy.shape(x), device=on_device, dtype=dtype)
if test_flags.with_out
else None
)
@@ -394,10 +361,7 @@ def func(t0, t1):
@handle_test(
fn_tree="functional.ivy.Profiler",
)
-def test_profiler(
- *,
- backend_fw,
-):
+def test_profiler(*, backend_fw):
# ToDo: find way to prevent this test from hanging when run
# alongside other tests in parallel
diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py b/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py
index 81eda8ffbc8a1..b21dfb8466305 100644
--- a/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py
+++ b/ivy_tests/test_ivy/test_functional/test_core/test_linalg.py
@@ -190,11 +190,16 @@ def _get_first_matrix_and_dtype(draw, *, transpose=False):
max_value=5,
)
)
- if transpose is True:
+ if transpose:
transpose = draw(st.booleans())
- if transpose:
+ adjoint = draw(st.booleans())
+ if adjoint and transpose:
+            adjoint = draw(st.just(False))
+ if transpose and not adjoint:
matrix = np.transpose(matrix)
- return [input_dtype], matrix, transpose
+ if adjoint and not transpose:
+ matrix = np.transpose(np.conjugate(matrix))
+ return [input_dtype], matrix, transpose, adjoint
return [input_dtype], matrix
@@ -221,11 +226,16 @@ def _get_second_matrix_and_dtype(draw, *, transpose=False):
max_value=5,
)
)
- if transpose is True:
+ if transpose:
transpose = draw(st.booleans())
- if transpose:
+ adjoint = draw(st.booleans())
+ if adjoint and transpose:
+            adjoint = draw(st.just(False))
+ if transpose and not adjoint:
matrix = np.transpose(matrix)
- return [input_dtype], matrix, transpose
+ if adjoint and not transpose:
+ matrix = np.transpose(np.conjugate(matrix))
+ return [input_dtype], matrix, transpose, adjoint
return [input_dtype], matrix
@@ -328,8 +338,8 @@ def test_matmul(
on_device,
ground_truth_backend,
):
- input_dtype1, x_1, transpose_a = x
- input_dtype2, y_1, transpose_b = y
+ input_dtype1, x_1, transpose_a, adjoint_a = x
+ input_dtype2, y_1, transpose_b, adjoint_b = y
helpers.test_function(
ground_truth_backend=ground_truth_backend,
input_dtypes=input_dtype1 + input_dtype2,
@@ -343,6 +353,8 @@ def test_matmul(
x2=y_1,
transpose_a=transpose_a,
transpose_b=transpose_b,
+ adjoint_a=adjoint_a,
+ adjoint_b=adjoint_b,
)
@@ -818,8 +830,8 @@ def test_tensordot(
max_num_dims=2,
min_dim_size=1,
max_dim_size=10,
- large_abs_safety_factor=2,
- small_abs_safety_factor=2,
+ large_abs_safety_factor=16,
+ small_abs_safety_factor=16,
safety_factor_scale="log",
),
offset=st.integers(min_value=0, max_value=0),
diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py
index 8c41ea39bf3fb..75d27cf5773b1 100644
--- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py
+++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_elementwise.py
@@ -269,20 +269,38 @@ def test_trapz(
)
+# float_power_helper
+@st.composite
+def _float_power_helper(draw, *, available_dtypes=None):
+ if available_dtypes is None:
+ available_dtypes = helpers.get_dtypes("numeric")
+ dtype1, x1 = draw(
+ helpers.dtype_and_values(
+ available_dtypes=available_dtypes,
+ small_abs_safety_factor=16,
+ large_abs_safety_factor=16,
+ safety_factor_scale="log",
+ )
+ )
+ dtype2 = draw(helpers.get_dtypes("numeric"))
+ if ivy.is_int_dtype(dtype2[0]):
+ min_value = 0
+ else:
+ min_value = -10
+ dtype2, x2 = draw(
+ helpers.dtype_and_values(
+ min_value=min_value,
+ max_value=10,
+ dtype=dtype2,
+ )
+ )
+ return (dtype1[0], dtype2[0]), (x1[0], x2[0])
+
+
# float_power
@handle_test(
fn_tree="functional.ivy.experimental.float_power",
- dtype_and_x=helpers.dtype_and_values(
- available_dtypes=helpers.get_dtypes("float"),
- min_value=-10,
- max_value=10,
- num_arrays=2,
- shared_dtype=True,
- min_num_dims=1,
- max_num_dims=3,
- min_dim_size=1,
- max_dim_size=3,
- ),
+ dtype_and_x=_float_power_helper(),
test_gradients=st.just(False),
)
def test_float_power(
@@ -293,16 +311,18 @@ def test_float_power(
on_device,
ground_truth_backend,
):
- input_dtype, x = dtype_and_x
+ input_dtypes, x = dtype_and_x
helpers.test_function(
- input_dtypes=input_dtype,
+ input_dtypes=input_dtypes,
test_flags=test_flags,
on_device=on_device,
ground_truth_backend=ground_truth_backend,
fw=backend_fw,
fn_name=fn_name,
- x1=np.asarray(x[0], dtype=input_dtype[0]),
- x2=np.asarray(x[1], dtype=input_dtype[1]),
+ x1=x[0],
+ x2=x[1],
+ rtol_=1e-1,
+ atol_=1e-1,
)
@@ -475,7 +495,7 @@ def test_nansum(
fw=backend_fw,
on_device=on_device,
rtol_=1e-02,
- atol_=1e-02,
+ atol_=1,
fn_name=fn_name,
x=x[0],
axis=axis,
@@ -612,7 +632,7 @@ def test_angle(
@handle_test(
fn_tree="functional.ivy.experimental.imag",
dtype_and_x=helpers.dtype_and_values(
- available_dtypes=["float32"],
+ available_dtypes=helpers.get_dtypes("valid"),
min_value=-5,
max_value=5,
max_dim_size=5,
@@ -840,28 +860,40 @@ def test_nextafter(
# diff
@handle_test(
fn_tree="functional.ivy.experimental.diff",
- dtype_and_x=helpers.dtype_and_values(
- available_dtypes=helpers.get_dtypes("integer"),
- num_arrays=2,
- shared_dtype=True,
+ dtype_n_x_n_axis=helpers.dtype_values_axis(
+ available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"),
min_num_dims=1,
- max_num_dims=3,
- min_value=-100,
- max_value=100,
- allow_nan=False,
+ valid_axis=True,
+ force_int_axis=True,
+ ),
+ n=st.integers(min_value=0, max_value=5),
+ dtype_prepend=helpers.dtype_and_values(
+ available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"),
+ min_num_dims=1,
+ max_num_dims=1,
+ ),
+ dtype_append=helpers.dtype_and_values(
+ available_dtypes=st.shared(helpers.get_dtypes("valid"), key="dtype"),
+ min_num_dims=1,
+ max_num_dims=1,
),
test_gradients=st.just(False),
)
def test_diff(
*,
- dtype_and_x,
+ dtype_n_x_n_axis,
+ n,
+ dtype_prepend,
+ dtype_append,
test_flags,
backend_fw,
fn_name,
on_device,
ground_truth_backend,
):
- input_dtype, x = dtype_and_x
+ input_dtype, x, axis = dtype_n_x_n_axis
+ _, prepend = dtype_prepend
+ _, append = dtype_append
helpers.test_function(
ground_truth_backend=ground_truth_backend,
input_dtypes=input_dtype,
@@ -870,6 +902,10 @@ def test_diff(
fn_name=fn_name,
on_device=on_device,
x=x[0],
+ n=n,
+ axis=axis,
+ prepend=prepend[0],
+ append=append[0],
)
diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py
index 6db0702335480..44df6c82ceb3b 100644
--- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py
+++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_manipulation.py
@@ -602,6 +602,44 @@ def test_pad(
)
+@st.composite
+def _get_split_locations(draw, min_num_dims, axis):
+ """
+    Generate valid splits: either an integer that evenly divides the length of
+    the split axis, or a sorted list of split locations.
+ """
+ shape = draw(
+ st.shared(helpers.get_shape(min_num_dims=min_num_dims), key="value_shape")
+ )
+ if len(shape) == 1:
+ axis = draw(st.just(0))
+ else:
+ axis = draw(st.just(axis))
+
+ @st.composite
+ def get_int_split(draw):
+ if shape[axis] == 0:
+ return 0
+ factors = []
+ for i in range(1, shape[axis] + 1):
+ if shape[axis] % i == 0:
+ factors.append(i)
+ return draw(st.sampled_from(factors))
+
+ @st.composite
+ def get_list_split(draw):
+ return draw(
+ st.lists(
+ st.integers(min_value=0, max_value=shape[axis]),
+ min_size=0,
+ max_size=shape[axis],
+ unique=True,
+ ).map(sorted)
+ )
+
+ return draw(get_list_split() | get_int_split())
+
+
# vsplit
@handle_test(
fn_tree="functional.ivy.experimental.vsplit",
@@ -644,12 +682,10 @@ def test_vsplit(
@handle_test(
fn_tree="functional.ivy.experimental.dsplit",
dtype_and_x=helpers.dtype_and_values(
- available_dtypes=helpers.get_dtypes("float"),
- shape=st.shared(helpers.get_shape(min_num_dims=3), key="dsplit_shape"),
- ),
- indices_or_sections=helpers.get_shape(
- min_num_dims=1, max_num_dims=3, min_dim_size=1, max_dim_size=3
+ available_dtypes=helpers.get_dtypes("valid"),
+ shape=st.shared(helpers.get_shape(min_num_dims=3), key="value_shape"),
),
+ indices_or_sections=_get_split_locations(min_num_dims=3, axis=2),
test_gradients=st.just(False),
test_with_out=st.just(False),
)
@@ -663,7 +699,6 @@ def test_dsplit(
ground_truth_backend,
):
input_dtype, x = dtype_and_x
- indices_or_sections = sorted(indices_or_sections)
helpers.test_function(
ground_truth_backend=ground_truth_backend,
input_dtypes=input_dtype,
@@ -865,6 +900,7 @@ def test_take_along_axis(
min_num_dims=1, max_num_dims=3, min_dim_size=1, max_dim_size=3
),
test_gradients=st.just(False),
+ test_with_out=st.just(False),
)
def test_hsplit(
dtype_and_x,
diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_norms.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_norms.py
index 6d24668287139..5c946aac627cd 100644
--- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_norms.py
+++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_norms.py
@@ -24,6 +24,41 @@ def test_l2_normalize(
fn_name,
on_device,
ground_truth_backend,
+):
+ input_dtype, x, axis = dtype_and_x
+ helpers.test_function(
+ ground_truth_backend=ground_truth_backend,
+ input_dtypes=input_dtype,
+ test_flags=test_flags,
+ on_device=on_device,
+ fw=backend_fw,
+ fn_name=fn_name,
+ x=x[0],
+ axis=axis,
+ )
+
+
+# lp_normalize
+@handle_test(
+ fn_tree="functional.ivy.experimental.lp_normalize",
+ dtype_and_x=helpers.arrays_and_axes(
+ available_dtypes=helpers.get_dtypes("float"),
+ num=1,
+ returndtype=True,
+ force_int_axis=True,
+ ),
+ p=st.floats(min_value=0.1, max_value=2),
+ test_gradients=st.just(False),
+)
+def test_lp_normalize(
+ *,
+ dtype_and_x,
+ test_flags,
+ backend_fw,
+ fn_name,
+ on_device,
+ ground_truth_backend,
+ p,
):
input_dtype, x, axis = dtype_and_x
helpers.test_function(
@@ -37,4 +72,5 @@ def test_l2_normalize(
atol_=1e-1,
x=x[0],
axis=axis,
+ p=p,
)
diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py
index f7a668564617a..f49c7f49d40c4 100644
--- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py
+++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py
@@ -69,3 +69,40 @@ def test_thresholded_relu(
x=x[0],
threshold=threshold,
)
+
+
+@handle_test(
+ fn_tree="prelu",
+ dtype_and_x=helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("float"),
+ shape=st.shared(helpers.get_shape(), key="prelu"),
+ large_abs_safety_factor=8,
+ small_abs_safety_factor=8,
+ safety_factor_scale="log",
+ ),
+ slope=helpers.array_values(
+ dtype=helpers.get_dtypes("float"),
+ shape=st.shared(helpers.get_shape(), key="prelu"),
+ ),
+)
+def test_prelu(
+ *,
+ dtype_and_x,
+ slope,
+ test_flags,
+ backend_fw,
+ fn_name,
+ on_device,
+ ground_truth_backend,
+):
+ dtype, x = dtype_and_x
+ helpers.test_function(
+ ground_truth_backend=ground_truth_backend,
+ input_dtypes=dtype,
+ fw=backend_fw,
+ test_flags=test_flags,
+ fn_name=fn_name,
+ on_device=on_device,
+ x=x[0],
+ slope=slope,
+ )
diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py
index 35d7eb25b2718..9927b4b7e9e4d 100644
--- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py
+++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py
@@ -1,6 +1,5 @@
# global
-from hypothesis import strategies as st
-
+from hypothesis import strategies as st, assume
# local
import ivy_tests.test_ivy.helpers as helpers
@@ -10,19 +9,40 @@
@handle_test(
fn_tree="functional.ivy.experimental.max_pool2d",
x_k_s_p=helpers.arrays_for_pooling(
- min_dims=4, max_dims=4, min_side=1, max_side=4, allow_explicit_padding=True
+ min_dims=4,
+ max_dims=4,
+ min_side=2,
+ max_side=4,
+ allow_explicit_padding=True,
+ return_dilation=True,
),
+ ceil_mode=st.just(True),
test_gradients=st.just(False),
+ # problem with containers converting tuple padding to
+ # lists which jax does not support
container_flags=st.just([False]),
)
def test_max_pool2d(
*,
x_k_s_p,
+ ceil_mode,
test_flags,
backend_fw,
fn_name,
):
- dtype, x, kernel, stride, pad = x_k_s_p
+ dtype, x, kernel, stride, pad, dilation = x_k_s_p
+ assume(
+ not (
+ backend_fw.current_backend_str() == "tensorflow"
+ and (
+ (stride[0] > kernel[0] or stride[0] > kernel[1])
+ or (
+ (stride[0] > 1 and dilation[0] > 1)
+ or (stride[0] > 1 and dilation[1] > 1)
+ )
+ )
+ )
+ )
helpers.test_function(
ground_truth_backend="jax",
input_dtypes=dtype,
@@ -35,6 +55,8 @@ def test_max_pool2d(
kernel=kernel,
strides=stride,
padding=pad,
+ dilation=dilation,
+ ceil_mode=ceil_mode,
)
@@ -234,6 +256,83 @@ def test_dct(
)
+@st.composite
+def _interp_args(draw):
+ mode = draw(st.sampled_from(["linear", "bilinear", "trilinear", "nearest", "area"]))
+ align_corners = draw(st.one_of(st.booleans(), st.none()))
+ if mode == "linear":
+ size = draw(helpers.ints(min_value=1, max_value=5))
+ num_dims = 3
+ elif mode == "bilinear":
+ size = draw(
+ helpers.lists(
+ arg=helpers.ints(min_value=1, max_value=5), min_size=2, max_size=2
+ )
+ )
+ num_dims = 4
+ elif mode == "trilinear":
+ size = draw(
+ helpers.lists(
+ arg=helpers.ints(min_value=1, max_value=5), min_size=3, max_size=3
+ )
+ )
+ num_dims = 5
+ elif mode == "nearest" or mode == "area":
+ dim = draw(helpers.ints(min_value=1, max_value=3))
+ size = draw(
+ helpers.lists(
+ arg=helpers.ints(min_value=1, max_value=5), min_size=dim, max_size=dim
+ )
+ )
+ size = size[0] if dim == 1 else size
+ num_dims = dim + 2
+ align_corners = None
+ dtype, x = draw(
+ helpers.dtype_and_values(
+ available_dtypes=helpers.get_dtypes("float"),
+ min_num_dims=num_dims,
+ max_num_dims=num_dims,
+ min_dim_size=1,
+ max_dim_size=3,
+ large_abs_safety_factor=30,
+ small_abs_safety_factor=30,
+ safety_factor_scale="log",
+ )
+ )
+
+ return dtype, x, mode, size, align_corners
+
+
+@handle_test(
+ fn_tree="functional.ivy.experimental.interpolate",
+ dtype_x_mode=_interp_args(),
+ test_gradients=st.just(False),
+ number_positional_args=st.just(2),
+)
+def test_interpolate(
+ dtype_x_mode,
+ test_flags,
+ backend_fw,
+ fn_name,
+ on_device,
+):
+ input_dtype, x, mode, size, align_corners = dtype_x_mode
+ helpers.test_function(
+ ground_truth_backend="torch",
+ input_dtypes=input_dtype,
+ test_flags=test_flags,
+ fw=backend_fw,
+ fn_name=fn_name,
+ on_device=on_device,
+ rtol_=1e-01,
+ atol_=1e-01,
+ x=x[0],
+ size=size,
+ mode=mode,
+ align_corners=align_corners,
+ )
+
+
@st.composite
def x_and_fft(draw, dtypes):
min_fft_points = 2
diff --git a/ivy_tests/test_ivy/test_misc/test_array.py b/ivy_tests/test_ivy/test_misc/test_array.py
index 190d4675c8a76..a6233e2f30c84 100644
--- a/ivy_tests/test_ivy/test_misc/test_array.py
+++ b/ivy_tests/test_ivy/test_misc/test_array.py
@@ -30,6 +30,46 @@ def _getitem_setitem(draw, available_dtypes=None):
return index, x
+def test_array_function():
+ HANDLED_FUNCTIONS = {}
+
+ class MyArray:
+ def __init__(self, data=None):
+ self.data = data
+
+ def __array_function__(self, func, types, args, kwargs):
+ if func not in HANDLED_FUNCTIONS:
+ return NotImplemented
+ if not all(
+ issubclass(t, (MyArray, ivy.Array, ivy.NativeArray)) for t in types
+ ):
+ return NotImplemented
+ return HANDLED_FUNCTIONS[func](*args, **kwargs)
+
+ def implements(ivy_function):
+ """Register an __array_function__ implementation for MyArray objects."""
+
+ def decorator(func):
+ HANDLED_FUNCTIONS[ivy_function] = func
+ return func
+
+ return decorator
+
+ @implements(ivy.abs)
+ def _(my_array, ivy_array):
+ my_array.data = abs(my_array.data)
+ ivy_array = ivy.abs(ivy_array)
+ return (my_array, ivy_array)
+
+ x = MyArray(-3)
+ y = ivy.array([1, -1])
+ xy = ivy.abs(x, ivy_array=y) # works
+ x1 = xy[0]
+ y1 = xy[1]
+ assert x1.data == 3
+ assert all(y1 == ivy.array([1, 1]))
+
+
# TODO do not use dummy fn_tree
@handle_test(
fn_tree="functional.ivy.native_array", # dummy fn_tree
diff --git a/ivy_tests/test_ivy/test_misc/test_backend_handler.py b/ivy_tests/test_ivy/test_misc/test_backend_handler.py
index 4ab2a366a60dd..48a35e59e383e 100644
--- a/ivy_tests/test_ivy/test_misc/test_backend_handler.py
+++ b/ivy_tests/test_ivy/test_misc/test_backend_handler.py
@@ -16,10 +16,11 @@
torch.tensor = lambda x: x
try:
import jax.numpy as jnp
+ import jax
except ImportError:
jnp = types.SimpleNamespace()
jnp.array = lambda x: x
-
+ jax = types.SimpleNamespace()
import numpy as np
# local
@@ -166,3 +167,142 @@ def test_get_backend(backend):
# checking whether the backend is returned correctly
ivy.assertions.check_equal(ivy.get_backend(backend), imported_backend)
+
+
+# Dynamic Backend
+
+backends = ["numpy", "torch", "tensorflow", "jax"]
+backend_combinations = [(a, b) for a in backends for b in backends if a != b]
+
+
+@pytest.mark.parametrize("middle_backend,end_backend", backend_combinations)
+def test_dynamic_backend_all_combos(middle_backend, end_backend):
+
+ # create an ivy array, container and native container
+ a = ivy.array([1, 2, 3])
+ b = ivy.array([4, 5, 6])
+ ivy_cont = ivy.Container({"w": a, "b": b})
+ nativ_cont = ivy.Container(
+ {"w": tf.Variable([1, 2, 3]), "b": tf.Variable([4, 5, 6])}
+ )
+
+ # clear the backend stack after initialization of inputs
+ ivy.clear_backend_stack()
+
+ # set dynamic_backend to false for all objects
+ ivy_cont.dynamic_backend = False
+ nativ_cont.dynamic_backend = False
+ a.dynamic_backend = False
+ b.dynamic_backend = False
+
+ # set the middle backend
+ ivy.set_backend(middle_backend, dynamic=True)
+
+ # set dynamic_backend to true for all objects
+ ivy_cont.dynamic_backend = True
+ nativ_cont.dynamic_backend = True
+ a.dynamic_backend = True
+ b.dynamic_backend = True
+
+ # set the final backend
+ ivy.set_backend(end_backend, dynamic=True)
+
+ # add the necessary asserts to check if the data
+ # of the objects are in the correct format
+
+ if end_backend == "numpy":
+ assert isinstance(a.data, np.ndarray)
+ elif end_backend == "torch":
+ assert isinstance(a.data, torch.Tensor)
+ elif end_backend == "jax":
+ assert isinstance(a.data, jax.interpreters.xla.DeviceArray)
+ elif end_backend == "tensorflow":
+ assert isinstance(a.data, tf.Tensor)
+
+ if end_backend == "numpy":
+ assert isinstance(ivy_cont["b"].data, np.ndarray)
+ elif end_backend == "torch":
+ assert isinstance(ivy_cont["b"].data, torch.Tensor)
+ elif end_backend == "jax":
+ assert isinstance(ivy_cont["b"].data, jax.interpreters.xla.DeviceArray)
+ elif end_backend == "tensorflow":
+ assert isinstance(ivy_cont["b"].data, tf.Tensor)
+
+ if end_backend == "numpy":
+ assert isinstance(nativ_cont["b"].data, np.ndarray)
+ elif end_backend == "jax":
+ assert isinstance(nativ_cont["b"].data, jax.interpreters.xla.DeviceArray)
+
+ if middle_backend not in ("jax", "numpy"):
+ # these frameworks don't support native variables
+ if end_backend == "torch":
+ assert (
+ isinstance(nativ_cont["b"].data, torch.Tensor)
+ and nativ_cont["b"].data.requires_grad is True
+ )
+ if end_backend == "tensorflow":
+ assert isinstance(nativ_cont["b"].data, tf.Variable)
+
+ else:
+ if end_backend == "torch":
+ assert isinstance(nativ_cont["b"].data, torch.Tensor)
+ if end_backend == "tensorflow":
+ assert isinstance(nativ_cont["b"].data, tf.Tensor)
+
+
+def test_dynamic_backend_setter():
+
+ a = ivy.array([1, 2, 3])
+ type_a = type(a.data)
+ a.dynamic_backend = False
+
+ # clear the backend stack after initialization of inputs
+ ivy.clear_backend_stack()
+
+ ivy.set_backend("tensorflow", dynamic=True)
+ assert type(a.data) == type_a
+
+ a.dynamic_backend = True
+ assert isinstance(a.data, tf.Tensor)
+
+ ivy.set_backend("torch", dynamic=True)
+ assert isinstance(a.data, torch.Tensor)
+
+
+def test_variables():
+
+ # clear the backend stack
+ ivy.clear_backend_stack()
+
+ ivy.set_backend("tensorflow", dynamic=True)
+
+ a = tf.Variable(0)
+ b = tf.Variable(1)
+
+ dyn_cont = ivy.Container({"w": a, "b": b})
+ stat_cont = ivy.Container({"w": a, "b": b})
+ stat_cont.dynamic_backend = False
+
+ ivy.set_backend("torch", dynamic=True)
+ assert (
+ isinstance(dyn_cont["w"].data, torch.Tensor)
+ and dyn_cont["w"].data.requires_grad is True
+ )
+
+ assert isinstance(stat_cont["w"], tf.Variable)
+
+
+def test_dynamic_backend_context_manager():
+
+ with ivy.dynamic_backend_as(True):
+ a = ivy.array([0.0, 1.0])
+ b = ivy.array([2.0, 3.0])
+
+ with ivy.dynamic_backend_as(False):
+ c = ivy.array([4.0, 5.0])
+ d = ivy.array([6.0, 7.0])
+
+ assert a.dynamic_backend is True
+ assert b.dynamic_backend is True
+ assert c.dynamic_backend is False
+ assert d.dynamic_backend is False
diff --git a/multiversion_frontend_test.py b/multiversion_frontend_test.py
new file mode 100644
index 0000000000000..1ae7a2eb56a0f
--- /dev/null
+++ b/multiversion_frontend_test.py
@@ -0,0 +1,130 @@
+from ivy_tests import config
+import sys
+import jsonpickle
+import importlib
+
+
+def available_frameworks():
+ available_frameworks_lis = ["numpy", "jax", "tensorflow", "torch"]
+ try:
+ import jax
+
+ assert jax, "jax is imported to see if the user has it installed"
+ except ImportError:
+ available_frameworks_lis.remove("jax")
+
+ try:
+ import tensorflow as tf
+
+ assert tf, "tensorflow is imported to see if the user has it installed"
+ except ImportError:
+ available_frameworks_lis.remove("tensorflow")
+
+ try:
+ import torch
+
+ assert torch, "torch is imported to see if the user has it installed"
+ except ImportError:
+ available_frameworks_lis.remove("torch")
+ return available_frameworks_lis
+
+
+def convtrue(argument):
+ """Convert NativeClass in argument to true framework counter part"""
+ if isinstance(argument, NativeClass):
+ return argument._native_class
+ return argument
+
+
+class NativeClass:
+ """
+    An empty class to represent a class that only exists in a specific framework.
+
+ Attributes
+ ----------
+ _native_class : class reference
+ A reference to the framework-specific class.
+ """
+
+ def __init__(self, native_class):
+ """
+ Constructs the native class object.
+
+ Parameters
+ ----------
+ native_class : class reference
+            A reference to the framework-specific class being represented.
+ """
+ self._native_class = native_class
+
+
+if __name__ == "__main__":
+
+ arg_lis = sys.argv
+ fw_lis = []
+ for i in arg_lis[1:]:
+ if i.split("/")[0] == "jax":
+ fw_lis.append(i.split("/")[0] + "/" + i.split("/")[1])
+ fw_lis.append(i.split("/")[2] + "/" + i.split("/")[3])
+ else:
+ fw_lis.append(i)
+ config.allow_global_framework_imports(fw=fw_lis)
+
+ j = 1
+ import ivy
+
+ # ivy.bfloat16
+ ivy.set_backend(arg_lis[2].split("/")[0])
+ import numpy
+
+ while j:
+ try:
+ z = input()
+ pickle_dict = jsonpickle.loads(z)
+ frontend_fw = input()
+
+ frontend_fw = importlib.import_module(frontend_fw)
+
+ func = input()
+
+ args_np, kwargs_np = pickle_dict["a"], pickle_dict["b"]
+ args_frontend = ivy.nested_map(
+ args_np,
+ lambda x: ivy.native_array(x)
+ if isinstance(x, numpy.ndarray)
+ else ivy.as_native_dtype(x)
+ if isinstance(x, ivy.Dtype)
+ else x,
+ shallow=False,
+ )
+ kwargs_frontend = ivy.nested_map(
+ kwargs_np,
+ lambda x: ivy.native_array(x) if isinstance(x, numpy.ndarray) else x,
+ shallow=False,
+ )
+
+ # change ivy dtypes to native dtypes
+ if "dtype" in kwargs_frontend:
+ kwargs_frontend["dtype"] = ivy.as_native_dtype(kwargs_frontend["dtype"])
+
+ # change ivy device to native devices
+ if "device" in kwargs_frontend:
+ kwargs_frontend["device"] = ivy.as_native_dev(kwargs_frontend["device"])
+
+ # check and replace the NativeClass objects in arguments
+ # with true counterparts
+ args_frontend = ivy.nested_map(
+ args_frontend, fn=convtrue, include_derived=True, max_depth=10
+ )
+ kwargs_frontend = ivy.nested_map(
+ kwargs_frontend, fn=convtrue, include_derived=True, max_depth=10
+ )
+
+ frontend_ret = frontend_fw.__dict__[func](*args_frontend, **kwargs_frontend)
+ frontend_ret = ivy.to_numpy(frontend_ret)
+ frontend_ret = jsonpickle.dumps(frontend_ret)
+ print(frontend_ret)
+ except EOFError:
+ continue
+ except Exception as e:
+ raise e
diff --git a/requirements/optional_m1_1.txt b/requirements/optional_m1_1.txt
index 1b1ad69394775..cc7565beb116a 100644
--- a/requirements/optional_m1_1.txt
+++ b/requirements/optional_m1_1.txt
@@ -2,6 +2,8 @@ h5py==3.7.0
pytest==7.1.2
networkx==2.8.4
hypothesis==6.48.2
+pymongo==4.3.3
+redis==4.3.4
matplotlib==3.5.2
opencv-python==4.6.0.66 # mod_name=cv2
jax==0.3.14
diff --git a/requirements/optional_m1_2.txt b/requirements/optional_m1_2.txt
index aae7fa77bb5bf..7e8569f380ed0 100644
--- a/requirements/optional_m1_2.txt
+++ b/requirements/optional_m1_2.txt
@@ -2,3 +2,6 @@ torch-scatter==2.0.9 # torch_scatter requires a prior existing installation of t
scipy==1.8.1
dm-haiku==0.0.6 # mod_name=haiku
protobuf==3.19.4
+pydriller
+tqdm
+coverage
\ No newline at end of file
diff --git a/run_tests.py b/run_tests.py
index df6d6c064a15e..78b0c3eaa3b09 100644
--- a/run_tests.py
+++ b/run_tests.py
@@ -75,14 +75,17 @@ def run_multiversion_testing(failed):
with open("tests_to_run", "r") as f:
for line in f:
test, frontend, backend = line.split(",")
+ frontend, backend = frontend.split("=")[1], backend.split("=")[1].replace(
+ ":", ","
+ )
+ print(test, frontend, backend)
ret = os.system(
f'docker run --rm -v "$(pwd)":/ivy -v "$(pwd)"/.hypothesis:/.hypothesis unifyai/multiversion /opt/miniconda/envs/multienv/bin/python -m pytest --tb=short {test} --frontend={frontend} --backend={backend}' # noqa
)
if ret != 0:
- failed = True
-
- if failed:
exit(1)
+ else:
+ exit(0)
if __name__ == "__main__":
@@ -90,10 +93,12 @@ def run_multiversion_testing(failed):
redis_pass = sys.argv[2]
mongo_key = sys.argv[3]
version_flag = sys.argv[4]
- if len(sys.argv) > 4:
- run_id = sys.argv[5]
+ workflow_id = sys.argv[5]
+ if len(sys.argv) > 6:
+ print(f"Job URL available -: {sys.argv}")
+ run_id = sys.argv[6]
else:
- run_id = "https://github.com/unifyai/ivy/actions/"
+ run_id = "https://github.com/unifyai/ivy/actions/runs/" + workflow_id
failed = False
# multiversion testing
if version_flag == "true":