Skip to content

Commit

Permalink
Merge pull request #1 from unifyai/main
Browse files Browse the repository at this point in the history
Interpolate ivy-llc#22679
  • Loading branch information
azhanxjaved authored Aug 29, 2023
2 parents e9a6e88 + 43f4291 commit 5febda3
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 43 deletions.
89 changes: 50 additions & 39 deletions ivy/functional/backends/paddle/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,17 +405,32 @@ def scatter_nd(
if ivy.exists(out)
else list(indices.shape[:-1]) + list(shape[indices.shape[-1] :])
)
updates = _broadcast_to(updates, expected_shape)._data

if indices.ndim > 1:
indices, unique_idxs = ivy.unique_all(indices, axis=0)[:2]
indices, unique_idxs = indices.data, unique_idxs.data
updates = ivy.gather(updates, unique_idxs, axis=0).data
updates = _broadcast_to(updates, expected_shape).data

# remove duplicate indices
# necessary because we will be using scatter_nd_add
if indices.ndim > 1 and reduction != "sum":
indices_shape = indices.shape
indices = paddle.reshape(indices, (-1, indices.shape[-1]))
num_indices = indices.shape[0]
# use flip to keep the last occurrence of each value
indices, unique_idxs = ivy.unique_all(ivy.flip(indices, axis=[0]), axis=0, by_value=True)[:2]
indices = indices.data
if len(unique_idxs) < num_indices:
updates = paddle.reshape(updates, (-1, *updates.shape[len(indices_shape)-1:]))
updates = ivy.gather(ivy.flip(updates, axis=[0]), unique_idxs, axis=0).data
expected_shape = (
list(indices.shape[:-1]) + list(out.shape[indices.shape[-1]:])
if ivy.exists(out)
else list(indices.shape[:-1]) + list(shape[indices.shape[-1]:])
)
else:
indices = paddle.reshape(indices, indices_shape)

# implementation
target_given = ivy.exists(out)
if target_given:
target = out._data
target = out.data
else:
shape = list(shape) if ivy.exists(shape) else out.shape
target = paddle.zeros(shape=shape).astype(updates.dtype)
Expand All @@ -429,26 +444,30 @@ def scatter_nd(
'"sum", "min", "max" or "replace"'.format(reduction)
)
if reduction == "min":
updates = ivy.minimum(ivy.gather_nd(target, indices), updates)._data
reduction = "replace"
updates = ivy.minimum(ivy.gather_nd(target, indices), updates).data
elif reduction == "max":
updates = ivy.maximum(ivy.gather_nd(target, indices), updates)._data
reduction = "replace"
updates = ivy.maximum(ivy.gather_nd(target, indices), updates).data
elif reduction == "sum":
updates = ivy.add(ivy.gather_nd(target, indices), updates).data
if indices.ndim <= 1:
indices = ivy.expand_dims(indices, axis=0)._data
updates = ivy.expand_dims(updates, axis=0)._data
indices = ivy.expand_dims(indices, axis=0).data
updates = ivy.expand_dims(updates, axis=0).data
updates_ = _broadcast_to(ivy.gather_nd(target, indices), expected_shape).data
target_dtype = target.dtype
if target_dtype in [
paddle.complex64,
paddle.complex128,
]:
if reduction == "replace":
updates = paddle_backend.subtract(
updates,
paddle_backend.gather_nd(target, indices),
)
result_real = paddle.scatter_nd_add(target.real(), indices, updates.real())
result_imag = paddle.scatter_nd_add(target.imag(), indices, updates.imag())
result_real = paddle.scatter_nd_add(
paddle.scatter_nd_add(target.real(), indices, -updates_.real()),
indices,
updates.real(),
)
result_imag = paddle.scatter_nd_add(
paddle.scatter_nd_add(target.imag(), indices, -updates_.imag()),
indices,
updates.imag(),
)
ret = paddle.complex(result_real, result_imag)
elif target_dtype in [
paddle.int8,
Expand All @@ -457,26 +476,18 @@ def scatter_nd(
paddle.float16,
paddle.bool,
]:
if reduction == "replace":
updates = paddle.subtract(
updates.cast("float32"),
paddle.gather_nd(target.cast("float32"), indices),
)
ret = paddle.scatter_nd_add(target.cast("float32"), indices, updates).cast(
target_dtype
)
target, updates, updates_ = target.cast("float32"), updates.cast("float32"), updates_.cast("float32")
ret = paddle.scatter_nd_add(
paddle.scatter_nd_add(target, indices, -updates_),
indices,
updates,
).cast(target_dtype)
else:
if reduction == "replace":
gathered_vals = paddle.gather_nd(target, indices)
# values greater than 2^24 - 1 can only be accurately represented as float64
if (np.abs(gathered_vals.numpy()).max() >= 2**24) or (
np.abs(updates.numpy()).max() >= 2**24
):
gathered_vals = gathered_vals.cast("float64")
target = target.cast("float64")
updates = updates.cast("float64")
updates = paddle.subtract(updates, gathered_vals)
ret = paddle.scatter_nd_add(target, indices, updates).cast(target_dtype)
ret = paddle.scatter_nd_add(
paddle.scatter_nd_add(target, indices, -updates_),
indices,
updates,
)
if ivy.exists(out):
return ivy.inplace_update(out, ret)
return ret
Expand Down
14 changes: 12 additions & 2 deletions ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,12 @@ def equal(self, other):
def erf(self, *, out=None):
return torch_frontend.erf(self, out=out)

# NOTE(review): this span is a diff capture that contains BOTH revisions of the
# method with no +/- markers. The first decorator (with_unsupported_dtypes) and
# the first body assignment (torch_frontend.erf) are the PRE-change lines; the
# with_supported_dtypes decorator and the self.erf(...) assignment that follow
# are their replacements. Do not read this as one method with two decorators
# and two assignments — reconcile against the actual repository before reuse.
@with_unsupported_dtypes({"2.0.1 and below": ("float16", "complex")}, "torch")
@with_supported_dtypes(
{"2.0.1 and below": ("float32", "float64", "bfloat16")}, "torch"
)
# In-place erf: overwrite this tensor's backing ivy array and return self.
def erf_(self, *, out=None):
self.ivy_array = torch_frontend.erf(self, out=out).ivy_array
self.ivy_array = self.erf(out=out).ivy_array
return self

def new_zeros(
self,
Expand Down Expand Up @@ -1556,6 +1559,13 @@ def logdet(self):
def copysign(self, other, *, out=None):
return torch_frontend.copysign(self, other, out=out)

@with_supported_dtypes(
    {"2.0.1 and below": ("float16", "float32", "float64")}, "torch"
)
def copysign_(self, other, *, out=None):
    """In-place variant of ``copysign``.

    Computes ``self.copysign(other, out=out)``, stores the result's backing
    ivy array on this tensor, and returns ``self`` so calls can be chained.
    """
    # Delegate to the out-of-place method, then swap in its underlying array.
    result = self.copysign(other, out=out)
    self.ivy_array = result.ivy_array
    return self

@with_unsupported_dtypes(
{"2.0.1 and below": ("complex", "bfloat16", "bool")}, "torch"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1881,6 +1881,7 @@ def dtype_array_query(
min_value=-s + 1,
max_value=s - 1,
dtype=["int64"],
max_num_dims=4,
)
)
new_index = new_index[0]
Expand Down
41 changes: 40 additions & 1 deletion ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5135,6 +5135,45 @@ def test_torch_tensor_copysign(
)


# copysign_
@handle_frontend_method(
    class_tree=CLASS_TREE,
    init_tree="torch.tensor",
    method_name="copysign_",
    dtype_and_x=helpers.dtype_and_values(
        available_dtypes=helpers.get_dtypes("valid"),
        min_num_dims=1,
        num_arrays=2,
    ),
)
def test_torch_tensor_copysign_(
    dtype_and_x,
    frontend_method_data,
    init_flags,
    method_flags,
    frontend,
    on_device,
    backend_fw,
):
    """Frontend test for the in-place ``Tensor.copysign_`` method."""
    # Unpack the drawn dtypes and the two generated arrays: the first
    # initializes the tensor, the second is the ``other`` argument.
    dtypes, arrays = dtype_and_x
    init_kwargs = {"data": arrays[0]}
    method_kwargs = {"other": arrays[1]}
    helpers.test_frontend_method(
        init_input_dtypes=dtypes,
        backend_to_test=backend_fw,
        init_all_as_kwargs_np=init_kwargs,
        method_input_dtypes=dtypes,
        method_all_as_kwargs_np=method_kwargs,
        frontend_method_data=frontend_method_data,
        init_flags=init_flags,
        method_flags=method_flags,
        frontend=frontend,
        on_device=on_device,
    )


# cos
@handle_frontend_method(
class_tree=CLASS_TREE,
Expand Down Expand Up @@ -6141,7 +6180,7 @@ def test_torch_tensor_erf(
init_tree="torch.tensor",
method_name="erf_",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
available_dtypes=helpers.get_dtypes("valid"),
),
)
def test_torch_tensor_erf_(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1643,7 +1643,6 @@ def test_set_item(
on_device=on_device,
backend_to_test=backend_fw,
fn_name=fn_name,
rtol_=1e-03, # needed only for the paddle backend
x=x,
query=query,
val=val,
Expand Down

0 comments on commit 5febda3

Please sign in to comment.