diff --git a/ivy/functional/frontends/jax/numpy/__init__.py b/ivy/functional/frontends/jax/numpy/__init__.py index b5617adb92a27..36b894a47d4f6 100644 --- a/ivy/functional/frontends/jax/numpy/__init__.py +++ b/ivy/functional/frontends/jax/numpy/__init__.py @@ -416,9 +416,13 @@ def promote_types_of_jax_inputs( type1 = ivy.default_dtype(item=x1).strip("u123456789") type2 = ivy.default_dtype(item=x2).strip("u123456789") if hasattr(x1, "dtype") and not hasattr(x2, "dtype") and type1 == type2: - x2 = ivy.asarray(x2, dtype=x1.dtype) + x2 = ivy.asarray( + x2, dtype=x1.dtype, device=ivy.default_device(item=x1, as_native=False) + ) elif not hasattr(x1, "dtype") and hasattr(x2, "dtype") and type1 == type2: - x1 = ivy.asarray(x1, dtype=x2.dtype) + x1 = ivy.asarray( + x1, dtype=x2.dtype, device=ivy.default_device(item=x2, as_native=False) + ) else: x1 = ivy.asarray(x1) x2 = ivy.asarray(x2) diff --git a/ivy/functional/frontends/numpy/__init__.py b/ivy/functional/frontends/numpy/__init__.py index d56b95692a6f7..7fb81ad4c0e1f 100644 --- a/ivy/functional/frontends/numpy/__init__.py +++ b/ivy/functional/frontends/numpy/__init__.py @@ -439,9 +439,13 @@ def promote_types_of_numpy_inputs( type2 = ivy.default_dtype(item=x2).strip("u123456789") if hasattr(x1, "dtype") and not hasattr(x2, "dtype") and type1 == type2: x1 = ivy.asarray(x1) - x2 = ivy.asarray(x2, dtype=x1.dtype) + x2 = ivy.asarray( + x2, dtype=x1.dtype, device=ivy.default_device(item=x1, as_native=False) + ) elif not hasattr(x1, "dtype") and hasattr(x2, "dtype") and type1 == type2: - x1 = ivy.asarray(x1, dtype=x2.dtype) + x1 = ivy.asarray( + x1, dtype=x2.dtype, device=ivy.default_device(item=x2, as_native=False) + ) x2 = ivy.asarray(x2) else: x1 = ivy.asarray(x1) diff --git a/ivy/functional/frontends/torch/__init__.py b/ivy/functional/frontends/torch/__init__.py index 6ee554a05aa13..56de2bb480710 100644 --- a/ivy/functional/frontends/torch/__init__.py +++ b/ivy/functional/frontends/torch/__init__.py @@ -268,9 +268,13 @@ def promote_types_of_torch_inputs( type2 = ivy.default_dtype(item=x2).strip("u123456789") if hasattr(x1, "dtype") and not hasattr(x2, "dtype") and type1 == type2: x1 = ivy.asarray(x1) - x2 = ivy.asarray(x2, dtype=x1.dtype) + x2 = ivy.asarray( + x2, dtype=x1.dtype, device=ivy.default_device(item=x1, as_native=False) + ) elif not hasattr(x1, "dtype") and hasattr(x2, "dtype") and type1 == type2: - x1 = ivy.asarray(x1, dtype=x2.dtype) + x1 = ivy.asarray( + x1, dtype=x2.dtype, device=ivy.default_device(item=x2, as_native=False) + ) x2 = ivy.asarray(x2) else: x1 = ivy.asarray(x1) diff --git a/ivy/functional/ivy/data_type.py b/ivy/functional/ivy/data_type.py index 318c85e9298c9..df479ee524799 100644 --- a/ivy/functional/ivy/data_type.py +++ b/ivy/functional/ivy/data_type.py @@ -2320,30 +2320,32 @@ def _special_case(a1, a2): return isinstance(a1, float) and "int" in str(a2.dtype) if hasattr(x1, "dtype") and not hasattr(x2, "dtype"): + device = ivy.default_device(item=x1, as_native=True) if x1.dtype == bool and not isinstance(x2, bool): x2 = ( - ivy.asarray(x2) + ivy.asarray(x2, device=device) if not _special_case(x2, x1) - else ivy.asarray(x2, dtype="float64") + else ivy.asarray(x2, dtype="float64", device=device) ) else: x2 = ( - ivy.asarray(x2, dtype=x1.dtype) + ivy.asarray(x2, dtype=x1.dtype, device=device) if not _special_case(x2, x1) - else ivy.asarray(x2, dtype="float64") + else ivy.asarray(x2, dtype="float64", device=device) ) elif hasattr(x2, "dtype") and not hasattr(x1, "dtype"): + device = ivy.default_device(item=x2, as_native=True) if x2.dtype == bool and not isinstance(x1, bool): x1 = ( - ivy.asarray(x1) + ivy.asarray(x1, device=device) if not _special_case(x1, x2) - else ivy.asarray(x1, dtype="float64") + else ivy.asarray(x1, dtype="float64", device=device) ) else: x1 = ( - ivy.asarray(x1, dtype=x2.dtype) + ivy.asarray(x1, dtype=x2.dtype, device=device) if not _special_case(x1, x2) - else ivy.asarray(x1, dtype="float64") + else ivy.asarray(x1, dtype="float64", device=device) ) elif not (hasattr(x1, "dtype") or hasattr(x2, "dtype")): x1 = ivy.asarray(x1)