Skip to content

Commit

Permalink
fix promote_types_of_inputs: if one of the inputs is a scalar, make s…
Browse files Browse the repository at this point in the history
…ure it will have the same device as the other one when converted to [native]array regardless of default device (#11928)
  • Loading branch information
xoiga123 authored Mar 9, 2023
1 parent d5a7eac commit 93964b5
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 14 deletions.
8 changes: 6 additions & 2 deletions ivy/functional/frontends/jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,9 +416,13 @@ def promote_types_of_jax_inputs(
type1 = ivy.default_dtype(item=x1).strip("u123456789")
type2 = ivy.default_dtype(item=x2).strip("u123456789")
if hasattr(x1, "dtype") and not hasattr(x2, "dtype") and type1 == type2:
x2 = ivy.asarray(x2, dtype=x1.dtype)
x2 = ivy.asarray(
x2, dtype=x1.dtype, device=ivy.default_device(item=x1, as_native=False)
)
elif not hasattr(x1, "dtype") and hasattr(x2, "dtype") and type1 == type2:
x1 = ivy.asarray(x1, dtype=x2.dtype)
x1 = ivy.asarray(
x1, dtype=x2.dtype, device=ivy.default_device(item=x2, as_native=False)
)
else:
x1 = ivy.asarray(x1)
x2 = ivy.asarray(x2)
Expand Down
8 changes: 6 additions & 2 deletions ivy/functional/frontends/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,13 @@ def promote_types_of_numpy_inputs(
type2 = ivy.default_dtype(item=x2).strip("u123456789")
if hasattr(x1, "dtype") and not hasattr(x2, "dtype") and type1 == type2:
x1 = ivy.asarray(x1)
x2 = ivy.asarray(x2, dtype=x1.dtype)
x2 = ivy.asarray(
x2, dtype=x1.dtype, device=ivy.default_device(item=x1, as_native=False)
)
elif not hasattr(x1, "dtype") and hasattr(x2, "dtype") and type1 == type2:
x1 = ivy.asarray(x1, dtype=x2.dtype)
x1 = ivy.asarray(
x1, dtype=x2.dtype, device=ivy.default_device(item=x2, as_native=False)
)
x2 = ivy.asarray(x2)
else:
x1 = ivy.asarray(x1)
Expand Down
8 changes: 6 additions & 2 deletions ivy/functional/frontends/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,13 @@ def promote_types_of_torch_inputs(
type2 = ivy.default_dtype(item=x2).strip("u123456789")
if hasattr(x1, "dtype") and not hasattr(x2, "dtype") and type1 == type2:
x1 = ivy.asarray(x1)
x2 = ivy.asarray(x2, dtype=x1.dtype)
x2 = ivy.asarray(
x2, dtype=x1.dtype, device=ivy.default_device(item=x1, as_native=False)
)
elif not hasattr(x1, "dtype") and hasattr(x2, "dtype") and type1 == type2:
x1 = ivy.asarray(x1, dtype=x2.dtype)
x1 = ivy.asarray(
x1, dtype=x2.dtype, device=ivy.default_device(item=x2, as_native=False)
)
x2 = ivy.asarray(x2)
else:
x1 = ivy.asarray(x1)
Expand Down
18 changes: 10 additions & 8 deletions ivy/functional/ivy/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2320,30 +2320,32 @@ def _special_case(a1, a2):
return isinstance(a1, float) and "int" in str(a2.dtype)

if hasattr(x1, "dtype") and not hasattr(x2, "dtype"):
device = ivy.default_device(item=x1, as_native=True)
if x1.dtype == bool and not isinstance(x2, bool):
x2 = (
ivy.asarray(x2)
ivy.asarray(x2, device=device)
if not _special_case(x2, x1)
else ivy.asarray(x2, dtype="float64")
else ivy.asarray(x2, dtype="float64", device=device)
)
else:
x2 = (
ivy.asarray(x2, dtype=x1.dtype)
ivy.asarray(x2, dtype=x1.dtype, device=device)
if not _special_case(x2, x1)
else ivy.asarray(x2, dtype="float64")
else ivy.asarray(x2, dtype="float64", device=device)
)
elif hasattr(x2, "dtype") and not hasattr(x1, "dtype"):
device = ivy.default_device(item=x2, as_native=True)
if x2.dtype == bool and not isinstance(x1, bool):
x1 = (
ivy.asarray(x1)
ivy.asarray(x1, device=device)
if not _special_case(x1, x2)
else ivy.asarray(x1, dtype="float64")
else ivy.asarray(x1, dtype="float64", device=device)
)
else:
x1 = (
ivy.asarray(x1, dtype=x2.dtype)
ivy.asarray(x1, dtype=x2.dtype, device=device)
if not _special_case(x1, x2)
else ivy.asarray(x1, dtype="float64")
else ivy.asarray(x1, dtype="float64", device=device)
)
elif not (hasattr(x1, "dtype") or hasattr(x2, "dtype")):
x1 = ivy.asarray(x1)
Expand Down

0 comments on commit 93964b5

Please sign in to comment.