feat: Updated torch version mapping from 2.1.2 to 2.2 (#28236)
Sai-Suraj-27 authored Feb 10, 2024
1 parent 1603886 commit 3277232
Showing 51 changed files with 674 additions and 678 deletions.
24 changes: 12 additions & 12 deletions ivy/functional/backends/torch/__init__.py
@@ -129,7 +129,7 @@ def rep_method(*args, **kwargs):
 
 # update these to add new dtypes
 valid_dtypes = {
-    "2.1.2 and below": (
+    "2.2 and below": (
         ivy.int8,
         ivy.int16,
         ivy.int32,
@@ -147,7 +147,7 @@ def rep_method(*args, **kwargs):
 
 
 valid_numeric_dtypes = {
-    "2.1.2 and below": (
+    "2.2 and below": (
         ivy.int8,
         ivy.int16,
         ivy.int32,
@@ -163,13 +163,13 @@ def rep_method(*args, **kwargs):
 }
 
 valid_int_dtypes = {
-    "2.1.2 and below": (ivy.int8, ivy.int16, ivy.int32, ivy.int64, ivy.uint8)
+    "2.2 and below": (ivy.int8, ivy.int16, ivy.int32, ivy.int64, ivy.uint8)
 }
 valid_float_dtypes = {
-    "2.1.2 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
+    "2.2 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
 }
-valid_uint_dtypes = {"2.1.2 and below": (ivy.uint8,)}
-valid_complex_dtypes = {"2.1.2 and below": (ivy.complex64, ivy.complex128)}
+valid_uint_dtypes = {"2.2 and below": (ivy.uint8,)}
+valid_complex_dtypes = {"2.2 and below": (ivy.complex64, ivy.complex128)}
 
 # leave these untouched
 valid_dtypes = _dtype_from_version(valid_dtypes, backend_version)
@@ -182,17 +182,17 @@ def rep_method(*args, **kwargs):
 # invalid data types
 # update these to add new dtypes
 invalid_dtypes = {
-    "2.1.2 and below": (
+    "2.2 and below": (
         ivy.uint16,
         ivy.uint32,
         ivy.uint64,
     )
 }
-invalid_numeric_dtypes = {"2.1.2 and below": (ivy.uint16, ivy.uint32, ivy.uint64)}
-invalid_int_dtypes = {"2.1.2 and below": (ivy.uint16, ivy.uint32, ivy.uint64)}
-invalid_float_dtypes = {"2.1.2 and below": ()}
-invalid_uint_dtypes = {"2.1.2 and below": (ivy.uint16, ivy.uint32, ivy.uint64)}
-invalid_complex_dtypes = {"2.1.2 and below": ()}
+invalid_numeric_dtypes = {"2.2 and below": (ivy.uint16, ivy.uint32, ivy.uint64)}
+invalid_int_dtypes = {"2.2 and below": (ivy.uint16, ivy.uint32, ivy.uint64)}
+invalid_float_dtypes = {"2.2 and below": ()}
+invalid_uint_dtypes = {"2.2 and below": (ivy.uint16, ivy.uint32, ivy.uint64)}
+invalid_complex_dtypes = {"2.2 and below": ()}
 invalid_dtypes = _dtype_from_version(invalid_dtypes, backend_version)
 
 # leave these untouched
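Each dict above keys a tuple of dtypes by a torch version range, and `_dtype_from_version` collapses it to the tuple matching the installed `backend_version` at import time. A minimal sketch of the matching this implies — the helper name and version parsing below are illustrative assumptions, not ivy's actual `_dtype_from_version`:

def _resolve_for_version(version_dict, backend_version):
    # Return the value whose key covers backend_version; keys look like
    # "2.2 and below", with exact-version keys also possible.
    def parse(v):
        return tuple(int(p) for p in v.split("+")[0].split(".") if p.isdigit())

    version = parse(backend_version)
    for key, value in version_dict.items():
        if key.endswith(" and below"):
            bound = parse(key[: -len(" and below")])
            # compare only as many components as the bound specifies,
            # so "2.2 and below" also covers 2.2.x patch releases
            if version[: len(bound)] <= bound:
                return value
        elif parse(key) == version:
            return value
    raise ValueError(f"no entry covers torch {backend_version}")

# e.g. on torch 2.2.0:
# _resolve_for_version({"2.2 and below": ("uint8",)}, "2.2.0") -> ("uint8",)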
20 changes: 10 additions & 10 deletions ivy/functional/backends/torch/activations.py
@@ -18,14 +18,14 @@
 import ivy.functional.backends.torch as torch_backend
 
 
-@with_unsupported_dtypes({"2.1.2 and below": ("float16",)}, backend_version)
+@with_unsupported_dtypes({"2.2 and below": ("float16",)}, backend_version)
 def relu(
     x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None
 ) -> torch.Tensor:
     return torch.relu(x)
 
 
-@with_unsupported_dtypes({"2.1.2 and below": ("float16",)}, backend_version)
+@with_unsupported_dtypes({"2.2 and below": ("float16",)}, backend_version)
 def leaky_relu(
     x: torch.Tensor,
     /,
@@ -37,7 +37,7 @@ def leaky_relu(
     return torch.nn.functional.leaky_relu(x, alpha)
 
 
-@with_unsupported_dtypes({"2.1.2 and below": ("float16",)}, backend_version)
+@with_unsupported_dtypes({"2.2 and below": ("float16",)}, backend_version)
 def gelu(
     x: torch.Tensor,
     /,
@@ -51,7 +51,7 @@ def gelu(
     return torch.nn.functional.gelu(x)
 
 
-@with_unsupported_dtypes({"2.1.2 and below": ("float16",)}, backend_version)
+@with_unsupported_dtypes({"2.2 and below": ("float16",)}, backend_version)
 def sigmoid(
     x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None
 ) -> torch.Tensor:
@@ -63,7 +63,7 @@ def sigmoid(
 sigmoid.support_native_out = True
 
 
-@with_unsupported_dtypes({"2.1.2 and below": ("bfloat16", "float16")}, backend_version)
+@with_unsupported_dtypes({"2.2 and below": ("bfloat16", "float16")}, backend_version)
 def softmax(
     x: torch.Tensor,
     /,
@@ -80,7 +80,7 @@ def softmax(
     return torch.nn.functional.softmax(x, axis)
 
 
-@with_unsupported_dtypes({"2.1.2 and below": ("float16", "bfloat16")}, backend_version)
+@with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, backend_version)
 def softplus(
     x: torch.Tensor,
     /,
@@ -97,7 +97,7 @@ def softplus(
 
 
 # Softsign
-@with_unsupported_dtypes({"2.1.2 and below": ("float16", "bfloat16")}, backend_version)
+@with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, backend_version)
 def softsign(x: torch.Tensor, /, out: Optional[torch.Tensor] = None) -> torch.Tensor:
     # return x / (1 + torch.abs(x))
     return torch.nn.functional.softsign(x)
@@ -107,7 +107,7 @@ def softsign(x: torch.Tensor, /, out: Optional[torch.Tensor] = None) -> torch.Tensor:
 
 
 @with_unsupported_dtypes(
-    {"2.1.2 and below": ("float16",)},
+    {"2.2 and below": ("float16",)},
     backend_version,
 )
 def log_softmax(
@@ -128,7 +128,7 @@ def log_softmax(
 
 
 @with_unsupported_dtypes(
-    {"2.1.2 and below": ("float16",)},
+    {"2.2 and below": ("float16",)},
     backend_version,
 )
 def mish(
@@ -146,7 +146,7 @@ def mish(
 
 @with_unsupported_dtypes(
     {
-        "2.1.2 and below": (
+        "2.2 and below": (
             "complex",
             "float16",
         )
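Every activation in this file wears the same decorator shape: a version-keyed dict naming the dtypes the wrapped kernel cannot handle (float16 is the common exclusion, since many torch CPU kernels historically lacked half-precision support), plus the resolved `backend_version`. A toy stand-in for that contract, reusing `_resolve_for_version` from the sketch above — the `unsupported_dtypes` attribute is illustrative, not ivy's real plumbing:

import functools

def with_unsupported_dtypes(version_dict, backend_version):
    # Resolve the unsupported-dtype tuple once at decoration time and
    # tag the wrapped function with it, so dispatch logic can cast or
    # raise before calling into torch.
    unsupported = _resolve_for_version(version_dict, backend_version)

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)

        wrapper.unsupported_dtypes = unsupported  # illustrative attribute
        return wrapper

    return decorator

@with_unsupported_dtypes({"2.2 and below": ("float16",)}, "2.2.0")
def relu(x):
    return x  # placeholder body

print(relu.unsupported_dtypes)  # -> ('float16',)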
8 changes: 4 additions & 4 deletions ivy/functional/backends/torch/creation.py
@@ -47,7 +47,7 @@ def _differentiable_linspace(start, stop, num, *, device, dtype=None):
     return res
 
 
-@with_unsupported_dtypes({"2.1.2 and below": ("complex",)}, backend_version)
+@with_unsupported_dtypes({"2.2 and below": ("complex",)}, backend_version)
 def arange(
     start: float,
     /,
@@ -95,7 +95,7 @@ def _stack_tensors(x, dtype):
     return x
 
 
-@with_unsupported_dtypes({"2.1.2 and below": ("bfloat16",)}, backend_version)
+@with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, backend_version)
 @_asarray_to_native_arrays_and_back
 @_asarray_infer_device
 @_asarray_handle_nestable
@@ -166,7 +166,7 @@ def empty_like(
     return torch.empty_like(x, dtype=dtype, device=device)
 
 
-@with_unsupported_dtypes({"2.1.2 and below": ("bfloat16",)}, backend_version)
+@with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, backend_version)
 def eye(
     n_rows: int,
     n_cols: Optional[int] = None,
@@ -276,7 +276,7 @@ def _slice_at_axis(sl, axis):
 
 
 @with_unsupported_device_and_dtypes(
-    {"2.1.2 and below": {"cpu": ("float16",)}}, backend_version
+    {"2.2 and below": {"cpu": ("float16",)}}, backend_version
 )
 def linspace(
     start: Union[torch.Tensor, float],
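linspace above uses the nested variant, `with_unsupported_device_and_dtypes`: the version key resolves first, then an inner dict maps device strings to unsupported dtypes, so float16 is blocked only on CPU. A lookup sketch under the same assumptions as the snippets above:

def unsupported_for(version_dict, backend_version, device):
    # Resolve the version bucket, then index by device; a device with
    # no entry carries no extra restrictions.
    per_device = _resolve_for_version(version_dict, backend_version)
    return per_device.get(device, ())

print(unsupported_for({"2.2 and below": {"cpu": ("float16",)}}, "2.2.0", "cpu"))
# -> ('float16',)
print(unsupported_for({"2.2 and below": {"cpu": ("float16",)}}, "2.2.0", "gpu"))
# -> ()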
4 changes: 2 additions & 2 deletions ivy/functional/backends/torch/data_type.py
@@ -72,7 +72,7 @@ def smallest_normal(self):
 # -------------------#
 
 
-@with_unsupported_dtypes({"2.1.2 and below": ("bfloat16", "float16")}, backend_version)
+@with_unsupported_dtypes({"2.2 and below": ("bfloat16", "float16")}, backend_version)
 def astype(
     x: torch.Tensor,
     dtype: torch.dtype,
@@ -181,7 +181,7 @@ def as_ivy_dtype(
 )
 
 
-@with_unsupported_dtypes({"2.1.2 and below": ("uint16",)}, backend_version)
+@with_unsupported_dtypes({"2.2 and below": ("uint16",)}, backend_version)
 def as_native_dtype(
     dtype_in: Union[torch.dtype, str, bool, int, float, np.dtype],
 ) -> torch.dtype:
(Diff for the remaining 47 changed files not shown.)