Skip to content

Commit

Permalink
[BE]: Update TypeGuard to TypeIs for better type inference (pytorch#133814)
Browse files Browse the repository at this point in the history

Uses TypeIs instead of TypeGuard for better inference. See https://peps.python.org/pep-0742/

Pull Request resolved: pytorch#133814
Approved by: https://github.com/ezyang
  • Loading branch information
Skylion007 authored and pytorchmergebot committed Aug 20, 2024
1 parent fbf3fc2 commit bce0cab
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 22 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,7 @@ def main():
)
install_requires = [
"filelock",
"typing-extensions>=4.8.0",
"typing-extensions>=4.10.0",
'setuptools ; python_version >= "3.12"',
'sympy==1.12.1 ; python_version == "3.8"',
'sympy==1.13.1 ; python_version >= "3.9"',
Expand Down
6 changes: 3 additions & 3 deletions torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
TypeVar as _TypeVar,
Union as _Union,
)
from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard
from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs


if TYPE_CHECKING:
Expand Down Expand Up @@ -1002,7 +1002,7 @@ def typename(obj: _Any, /) -> str:
return f"{module}.{qualname}"


def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]:
def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]:
r"""Returns True if `obj` is a PyTorch tensor.
Note that this function is simply doing ``isinstance(obj, Tensor)``.
Expand All @@ -1022,7 +1022,7 @@ def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]:
return isinstance(obj, torch.Tensor)


def is_storage(obj: _Any, /) -> _TypeGuard[_Union["TypedStorage", "UntypedStorage"]]:
def is_storage(obj: _Any, /) -> _TypeIs[_Union["TypedStorage", "UntypedStorage"]]:
r"""Returns True if `obj` is a PyTorch storage object.
Args:
Expand Down
6 changes: 3 additions & 3 deletions torch/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
Union,
ValuesView,
)
from typing_extensions import Literal, TypeGuard
from typing_extensions import Literal, TypeIs

import torch
import torch._functorch.config
Expand Down Expand Up @@ -532,14 +532,14 @@ def clear(self):


@overload
def istype(obj: object, allowed_types: Type[T]) -> TypeGuard[T]:
def istype(obj: object, allowed_types: Type[T]) -> TypeIs[T]:
...


@overload
def istype(
obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]]
) -> TypeGuard[T]:
) -> TypeIs[T]:
...


Expand Down
8 changes: 4 additions & 4 deletions torch/_inductor/pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
TypeVar,
Union,
)
from typing_extensions import Self, TypeGuard
from typing_extensions import Self, TypeIs

import torch
import torch._guards
Expand Down Expand Up @@ -277,10 +277,10 @@ def __bool__(self) -> bool:
MatchResult = Union[Match, FailedMatch]


def is_match(m: MatchResult) -> TypeGuard[Match]:
def is_match(m: MatchResult) -> TypeIs[Match]:
"""
TypeGuards cannot act on `self`. Thus this function exists to let mypy
recognize FailedMatch.__bool__ as a TypeGuard.
TypeIs cannot act on `self`. Thus this function exists to let mypy
recognize FailedMatch.__bool__ as a TypeIs.
"""
return bool(m)

Expand Down
6 changes: 3 additions & 3 deletions torch/_subclasses/fake_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
TypeVar,
Union,
)
from typing_extensions import Self, TypeGuard
from typing_extensions import Self, TypeIs
from weakref import ReferenceType

import torch
Expand Down Expand Up @@ -168,7 +168,7 @@ def get_plain_tensors(subclass: Tensor) -> List[Tensor]:
return plain_tensors


def is_fake(x: object) -> TypeGuard[Tensor]:
def is_fake(x: object) -> TypeIs[Tensor]:
if isinstance(x, FakeTensor):
return True
if is_traceable_wrapper_subclass(x):
Expand Down Expand Up @@ -1213,7 +1213,7 @@ def reset_nt_tensor_id_counter(self) -> None:
# In this case, it's insufficient to test only one FakeTensor: you need
# to distinguish between our fake tensor and other fake tensors. That's
# what this function does.
def is_our_fake(self, t: object) -> TypeGuard[FakeTensor]:
def is_our_fake(self, t: object) -> TypeIs[FakeTensor]:
return isinstance(t, FakeTensor) and t.fake_mode is self

# If we should avoid device init. This changes the behavior of various APIs:
Expand Down
4 changes: 2 additions & 2 deletions torch/masked/maskedtensor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import warnings
from typing import Any
from typing_extensions import TypeGuard
from typing_extensions import TypeIs

import torch
from torch.overrides import get_default_nowrap_functions
Expand All @@ -15,7 +15,7 @@
]


def is_masked_tensor(obj: Any, /) -> TypeGuard["MaskedTensor"]:
def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]:
r"""Returns True if the input is a MaskedTensor, else False
Args:
Expand Down
4 changes: 2 additions & 2 deletions torch/nn/parameter.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing_extensions import TypeGuard
from typing_extensions import TypeIs

from torch import device, dtype, Tensor

Expand All @@ -8,7 +8,7 @@ class Parameter(Tensor):

def is_lazy(
param: Tensor,
) -> TypeGuard[UninitializedParameter | UninitializedBuffer]: ...
) -> TypeIs[UninitializedParameter | UninitializedBuffer]: ...

class UninitializedParameter(Tensor):
def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ...
Expand Down
4 changes: 2 additions & 2 deletions torch/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
Type,
Union,
)
from typing_extensions import TypeAlias, TypeGuard # Python 3.10+
from typing_extensions import TypeAlias, TypeIs

import torch
import torch._weights_only_unpickler as _weights_only_unpickler
Expand Down Expand Up @@ -549,7 +549,7 @@ def storage_to_tensor_type(storage):
return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))


def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
def _is_path(name_or_buffer) -> TypeIs[Union[str, os.PathLike]]:
    """Return True when ``name_or_buffer`` names a filesystem path.

    A "path" here means either a ``str`` or anything implementing the
    ``os.PathLike`` protocol; file-like buffer objects return False.
    """
    path_types = (str, os.PathLike)
    return isinstance(name_or_buffer, path_types)


Expand Down
4 changes: 2 additions & 2 deletions torch/utils/_python_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque
from typing_extensions import TypeGuard
from typing_extensions import TypeIs
from collections import deque

import torch
Expand Down Expand Up @@ -354,7 +354,7 @@ def to(



def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]:
def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]:
"""
Returns whether or not a tensor subclass that implements __torch_dispatch__
is 'traceable' with torch.compile.
Expand Down

0 comments on commit bce0cab

Please sign in to comment.