Skip to content

Commit

Permalink
refactor: added support_native_out attr & update.. (#10736)
Browse files Browse the repository at this point in the history
Co-authored by MahmoudAshraf97 <[email protected]>
  • Loading branch information
rishabgit authored Feb 22, 2023
1 parent b8151bf commit 73e1c53
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 5 deletions.
22 changes: 21 additions & 1 deletion ivy/array/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,7 +1199,8 @@ def scatter_flat(
updates
Values for the new array to hold.
size
The size of the result.
The size of the result. Default is `None`, in which case tensor
argument out must be provided.
reduction
The reduction method for the scatter, one of 'sum', 'min', 'max' or
'replace'
Expand All @@ -1211,6 +1212,25 @@ def scatter_flat(
-------
ret
New array of given shape, with the values scattered at the indices.
Examples
--------
With :class:`ivy.Array` input:
>>> indices = ivy.array([0, 0, 1, 0, 2, 2, 3, 3])
>>> updates = ivy.array([5, 1, 7, 2, 3, 2, 1, 3])
>>> size = 8
>>> out = indices.scatter_flat(updates, size=size)
>>> print(out)
ivy.array([8, 7, 5, 4, 0, 0, 0, 0])
With :class:`ivy.Array` input:
>>> indices = ivy.array([0, 0, 1, 0, 2, 2, 3, 3])
>>> updates = ivy.array([5, 1, 7, 2, 3, 2, 1, 3])
>>> out = ivy.array([0, 0, 0, 0, 0, 0, 0, 0])
>>> indices.scatter_flat(updates, out=out)
>>> print(out)
ivy.array([8, 7, 5, 4, 0, 0, 0, 0])
"""
return ivy.scatter_flat(self, updates, size=size, reduction=reduction, out=out)

Expand Down
20 changes: 18 additions & 2 deletions ivy/container/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,7 +1852,8 @@ def static_scatter_flat(
updates
values to update input tensor with
size
The size of the result.
The size of the result. Default is `None`, in which case tensor
argument out must be provided.
reduction
The reduction method for the scatter, one of 'sum', 'min', 'max'
or 'replace'
Expand Down Expand Up @@ -1915,7 +1916,8 @@ def scatter_flat(
updates
values to update input tensor with
size
The size of the result.
The size of the result. Default is `None`, in which case tensor
argument out must be provided.
reduction
The reduction method for the scatter, one of 'sum', 'min', 'max'
or 'replace'
Expand All @@ -1938,6 +1940,20 @@ def scatter_flat(
-------
ret
New container of given shape, with the values updated at the indices.
Examples
--------
With :class:`ivy.Container` input:
>>> indices = ivy.Container(a=ivy.array([1, 0, 1, 0, 2, 2, 3, 3]), \
b=ivy.array([0, 0, 1, 0, 2, 2, 3, 3]))
>>> updates = ivy.Container(a=ivy.array([9, 2, 0, 2, 3, 2, 1, 8]), \
b=ivy.array([5, 1, 7, 2, 3, 2, 1, 3]))
>>> size = 8
>>> print(ivy.scatter_flat(indices, updates, size=size))
{
a: ivy.array([4, 9, 5, 9, 0, 0, 0, 0]),
b: ivy.array([8, 7, 5, 4, 0, 0, 0, 0])
}
"""
return self.static_scatter_flat(
self,
Expand Down
3 changes: 3 additions & 0 deletions ivy/functional/backends/jax/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ def scatter_flat(
return _to_device(target)


scatter_flat.support_native_out = True


def scatter_nd(
indices: JaxArray,
updates: JaxArray,
Expand Down
3 changes: 3 additions & 0 deletions ivy/functional/backends/numpy/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ def scatter_flat(
return _to_device(target)


scatter_flat.support_native_out = True


def scatter_nd(
indices: np.ndarray,
updates: np.ndarray,
Expand Down
3 changes: 3 additions & 0 deletions ivy/functional/backends/tensorflow/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,9 @@ def scatter_flat(
return res


scatter_flat.support_native_out = True


@with_unsupported_dtypes({"2.9.1 and below": ("bfloat16",)}, backend_version)
def scatter_nd(
indices: Union[tf.Tensor, tf.Variable],
Expand Down
5 changes: 4 additions & 1 deletion ivy/functional/backends/torch/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def scatter_flat(
size: Optional[int] = None,
reduction: str = "sum",
out: Optional[torch.Tensor] = None,
):
) -> torch.Tensor:
target = out
target_given = ivy.exists(target)
if ivy.exists(size) and ivy.exists(target):
Expand Down Expand Up @@ -354,6 +354,9 @@ def scatter_flat(
return res


scatter_flat.support_native_out = True


@with_unsupported_dtypes(
{
"1.11.0 and below": (
Expand Down
50 changes: 49 additions & 1 deletion ivy/functional/ivy/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2775,7 +2775,8 @@ def scatter_flat(
updates
Values for the new array to hold.
size
The size of the result.
The size of the result. Default is `None`, in which case tensor
argument out must be provided.
reduction
The reduction method for the scatter, one of 'sum', 'min', 'max' or 'replace'
out
Expand All @@ -2787,6 +2788,53 @@ def scatter_flat(
ret
New array of given shape, with the values scattered at the indices.
This function is *nestable*, and therefore also accepts :code:'ivy.Container'
instance in place of the argument.
Examples
--------
With :class:`ivy.Array` input:
>>> indices = ivy.array([0, 0, 1, 0, 2, 2, 3, 3])
>>> updates = ivy.array([5, 1, 7, 2, 3, 2, 1, 3])
>>> out = ivy.array([0, 0, 0, 0, 0, 0, 0, 0])
>>> ivy.scatter_flat(indices, updates, out=out)
>>> print(out)
ivy.array([8, 7, 5, 4, 0, 0, 0, 0])
With :class:`ivy.Array` input:
>>> indices = ivy.array([1, 0, 1, 0, 2, 2, 3, 3])
>>> updates = ivy.array([9, 2, 0, 2, 3, 2, 1, 8])
>>> size = 8
>>> print(ivy.scatter_flat(indices, updates, size=size))
ivy.array([4, 9, 5, 9, 0, 0, 0, 0])
With :class:`ivy.Container` and :class:`ivy.Array` input:
>>> indices = ivy.array([1, 0, 1, 0, 2, 2, 3, 3])
>>> updates = ivy.Container(a=ivy.array([9, 2, 0, 2, 3, 2, 1, 8]), \
b=ivy.array([5, 1, 7, 2, 3, 2, 1, 3]))
>>> size = 8
>>> print(ivy.scatter_flat(indices, updates, size=size))
{
a: ivy.array([4, 9, 5, 9, 0, 0, 0, 0]),
b: ivy.array([3, 12, 5, 4, 0, 0, 0, 0])
}
With :class:`ivy.Container` input:
>>> indices = ivy.Container(a=ivy.array([1, 0, 1, 0, 2, 2, 3, 3]), \
b=ivy.array([0, 0, 1, 0, 2, 2, 3, 3]))
>>> updates = ivy.Container(a=ivy.array([9, 2, 0, 2, 3, 2, 1, 8]), \
b=ivy.array([5, 1, 7, 2, 3, 2, 1, 3]))
>>> size = 8
>>> print(ivy.scatter_flat(indices, updates, size=size))
{
a: ivy.array([4, 9, 5, 9, 0, 0, 0, 0]),
b: ivy.array([8, 7, 5, 4, 0, 0, 0, 0])
}
"""
return current_backend(indices).scatter_flat(
indices, updates, size=size, reduction=reduction, out=out
Expand Down

0 comments on commit 73e1c53

Please sign in to comment.