diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet index 43ab3ab2559d5..754904efd03c7 100644 --- a/dockers/tpu-tests/tpu_test_cases.jsonnet +++ b/dockers/tpu-tests/tpu_test_cases.jsonnet @@ -39,6 +39,11 @@ local tputests = base.BaseTest { echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}" + echo "--- Sanity check TPU availability ---" + python -c "from lightning_lite.accelerators import TPUAccelerator; assert TPUAccelerator.is_available()" + python -c "from pytorch_lightning.accelerators import TPUAccelerator; assert TPUAccelerator.is_available()" + echo "Sanity check passed!" + echo "--- Running Lite tests ---" cd tests/tests_lite PL_RUN_TPU_TESTS=1 coverage run --source=lightning_lite -m pytest -vv --durations=0 ./ diff --git a/docs/source-lit/conf.py b/docs/source-lit/conf.py index 5a7cdd25b59dc..8d5c4c47465e3 100644 --- a/docs/source-lit/conf.py +++ b/docs/source-lit/conf.py @@ -406,8 +406,6 @@ def find_source(): from pytorch_lightning.cli import LightningCLI from pytorch_lightning.utilities import ( _APEX_AVAILABLE, - _XLA_AVAILABLE, - _TPU_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCH_GREATER_EQUAL_1_10, ) diff --git a/docs/source-pytorch/accelerators/mps_basic.rst b/docs/source-pytorch/accelerators/mps_basic.rst index 5db866a531e13..eec8967c1ce1a 100644 --- a/docs/source-pytorch/accelerators/mps_basic.rst +++ b/docs/source-pytorch/accelerators/mps_basic.rst @@ -57,7 +57,7 @@ If Lightning can't detect the Apple Silicon hardware, it will raise this excepti .. code:: - MisconfigurationException: MPSAccelerator can not run on your system since the accelerator is not available. + MisconfigurationException: `MPSAccelerator` can not run on your system since the accelerator is not available. If you are seeing this despite running on an ARM-enabled Mac, the most likely cause is that your Python is being emulated and thinks it is running on an Intel CPU. To solve this, re-install your python executable (and if using environment managers like conda, you have to reinstall these as well) by downloading the Apple M1/M2 build (not Intel!), for example `here `_. diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py index c732a7c181acd..467abdb9613a0 100644 --- a/docs/source-pytorch/conf.py +++ b/docs/source-pytorch/conf.py @@ -394,8 +394,6 @@ def package_list_from_file(file): from pytorch_lightning.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE from pytorch_lightning.utilities import ( _APEX_AVAILABLE, - _XLA_AVAILABLE, - _TPU_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCH_GREATER_EQUAL_1_10, ) diff --git a/src/lightning_lite/accelerators/tpu.py b/src/lightning_lite/accelerators/tpu.py index 7a326e47596c3..81905d460d1f6 100644 --- a/src/lightning_lite/accelerators/tpu.py +++ b/src/lightning_lite/accelerators/tpu.py @@ -11,18 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional, Union +import functools +import queue as q +import traceback +from multiprocessing import Process, Queue +from typing import Any, Callable, Dict, List, Optional, Union import torch +from lightning_utilities.core.imports import RequirementCache from lightning_lite.accelerators.accelerator import Accelerator from lightning_lite.utilities.device_parser import _check_data_type -from lightning_lite.utilities.imports import _TPU_AVAILABLE class TPUAccelerator(Accelerator): """Accelerator for TPU devices.""" + def __init__(self, *args: Any, **kwargs: Any) -> None: + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) + super().__init__(*args, **kwargs) + def setup_device(self, device: torch.device) -> None: pass @@ -47,8 +56,10 @@ def auto_device_count() -> int: return 8 @staticmethod + @functools.lru_cache(maxsize=1) def is_available() -> bool: - return _TPU_AVAILABLE + # check `_XLA_AVAILABLE` again to avoid launching processes + return bool(_XLA_AVAILABLE) and _is_device_tpu() @classmethod def register_accelerators(cls, accelerator_registry: Dict) -> None: @@ -59,6 +70,64 @@ def register_accelerators(cls, accelerator_registry: Dict) -> None: ) +# define TPU availability timeout in seconds +TPU_CHECK_TIMEOUT = 60 + + +def _inner_f(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None: # pragma: no cover + try: + queue.put(func(*args, **kwargs)) + except Exception: + traceback.print_exc() + queue.put(None) + + +def _multi_process(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Union[bool, Any]: + queue: Queue = Queue() + proc = Process(target=_inner_f, args=(queue, func, *args), kwargs=kwargs) + proc.start() + proc.join(TPU_CHECK_TIMEOUT) + try: + return queue.get_nowait() + except q.Empty: + traceback.print_exc() + return False + + return wrapper + + +@_multi_process +def _is_device_tpu() -> bool: + """Check if TPU devices are available. Runs XLA device check within a separate process. + + Return: + A boolean value indicating if TPU devices are available + """ + if not _XLA_AVAILABLE: + return False + import torch_xla.core.xla_model as xm + + # For the TPU Pod training process, for example, if we have + # TPU v3-32 with 4 VMs, the world size would be 4 and as + # we would have to use `torch_xla.distributed.xla_dist` for + # multiple VMs and TPU_CONFIG won't be available, running + # `xm.get_xla_supported_devices("TPU")` won't be possible. 
+ return (xm.xrt_world_size() > 1) or bool(xm.get_xla_supported_devices("TPU")) + + +_XLA_AVAILABLE = RequirementCache("torch_xla") + + +def tpu_distributed() -> bool: + if not TPUAccelerator.is_available(): + return False + import torch_xla.core.xla_model as xm + + return xm.xrt_world_size() > 1 + + def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]: """ Parses the tpu_cores given in the format as accepted by the diff --git a/src/lightning_lite/connector.py b/src/lightning_lite/connector.py index 3e9a7560d6472..4a9b9598535ec 100644 --- a/src/lightning_lite/connector.py +++ b/src/lightning_lite/connector.py @@ -55,7 +55,7 @@ from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES from lightning_lite.utilities import _StrategyType, rank_zero_info, rank_zero_warn from lightning_lite.utilities.device_parser import determine_root_gpu_device -from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE, _TPU_AVAILABLE +from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE _PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO] _PLUGIN_INPUT = Union[_PLUGIN, str] @@ -301,7 +301,7 @@ def _check_device_config_and_set_final_flags( def _choose_auto_accelerator(self) -> str: """Choose the accelerator type (str) based on availability when ``accelerator='auto'``.""" if self._accelerator_flag == "auto": - if _TPU_AVAILABLE: + if TPUAccelerator.is_available(): return "tpu" if _IPU_AVAILABLE: return "ipu" @@ -328,13 +328,16 @@ def _set_parallel_devices_and_init_accelerator(self) -> None: else: assert self._accelerator_flag is not None self.accelerator = ACCELERATOR_REGISTRY.get(self._accelerator_flag) + accelerator_cls = self.accelerator.__class__ - if not self.accelerator.is_available(): + if not accelerator_cls.is_available(): available_accelerator = [ - acc_str for acc_str in self._registered_accelerators if ACCELERATOR_REGISTRY.get(acc_str).is_available() + acc_str + for acc_str in self._registered_accelerators + if ACCELERATOR_REGISTRY[acc_str]["accelerator"].is_available() ] raise RuntimeError( - f"{self.accelerator.__class__.__qualname__} can not run on your system" + f"`{accelerator_cls.__qualname__}` can not run on your system" " since the accelerator is not available. The following accelerator(s)" " is available and can be passed into `accelerator` argument of" f" `Lite`: {available_accelerator}." @@ -342,9 +345,9 @@ def _set_parallel_devices_and_init_accelerator(self) -> None: self._set_devices_flag_if_auto_passed() - self._devices_flag = self.accelerator.parse_devices(self._devices_flag) + self._devices_flag = accelerator_cls.parse_devices(self._devices_flag) if not self._parallel_devices: - self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag) + self._parallel_devices = accelerator_cls.get_parallel_devices(self._devices_flag) def _set_devices_flag_if_auto_passed(self) -> None: if self._devices_flag == "auto" or self._devices_flag is None: diff --git a/src/lightning_lite/plugins/environments/xla_environment.py b/src/lightning_lite/plugins/environments/xla_environment.py index da5a99c000d56..ce969cbef19c8 100644 --- a/src/lightning_lite/plugins/environments/xla_environment.py +++ b/src/lightning_lite/plugins/environments/xla_environment.py @@ -13,13 +13,10 @@ # limitations under the License. 
import logging import os +from typing import Any +from lightning_lite.accelerators.tpu import _XLA_AVAILABLE, TPUAccelerator from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment -from lightning_lite.utilities.imports import _TPU_AVAILABLE - -if _TPU_AVAILABLE: - import torch_xla.core.xla_env_vars as xenv - import torch_xla.core.xla_model as xm log = logging.getLogger(__name__) @@ -31,36 +28,53 @@ class XLAEnvironment(ClusterEnvironment): `here `_. """ + def __init__(self, *args: Any, **kwargs: Any) -> None: + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) + super().__init__(*args, **kwargs) + @property def creates_processes_externally(self) -> bool: return False @property def main_address(self) -> str: + import torch_xla.core.xla_env_vars as xenv + return os.environ[xenv.TPU_MESH_CTLER_ADDR] @property def main_port(self) -> int: + import torch_xla.core.xla_env_vars as xenv + return int(os.environ[xenv.TPU_MESH_CTLER_PORT]) @staticmethod def detect() -> bool: - return _TPU_AVAILABLE + return TPUAccelerator.is_available() def world_size(self) -> int: + import torch_xla.core.xla_model as xm + return xm.xrt_world_size() def set_world_size(self, size: int) -> None: log.debug("XLAEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") def global_rank(self) -> int: + import torch_xla.core.xla_model as xm + return xm.get_ordinal() def set_global_rank(self, rank: int) -> None: log.debug("XLAEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.") def local_rank(self) -> int: + import torch_xla.core.xla_model as xm + return xm.get_local_ordinal() def node_rank(self) -> int: + import torch_xla.core.xla_env_vars as xenv + return int(os.environ.get(xenv.HOST_ORDINAL, 0)) diff --git a/src/lightning_lite/plugins/io/xla_plugin.py b/src/lightning_lite/plugins/io/xla_plugin.py index 1b97736d8f71d..75c13898ebeae 100644 --- a/src/lightning_lite/plugins/io/xla_plugin.py +++ b/src/lightning_lite/plugins/io/xla_plugin.py @@ -16,14 +16,12 @@ from lightning_utilities.core.apply_func import apply_to_collection +from lightning_lite.accelerators.tpu import _XLA_AVAILABLE from lightning_lite.plugins.io.torch_plugin import TorchCheckpointIO from lightning_lite.utilities.cloud_io import get_filesystem -from lightning_lite.utilities.imports import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE +from lightning_lite.utilities.imports import _OMEGACONF_AVAILABLE from lightning_lite.utilities.types import _PATH -if _TPU_AVAILABLE: - import torch_xla.core.xla_model as xm - if _OMEGACONF_AVAILABLE: from omegaconf import DictConfig, ListConfig, OmegaConf @@ -31,6 +29,11 @@ class XLACheckpointIO(TorchCheckpointIO): """CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies.""" + def __init__(self, *args: Any, **kwargs: Any) -> None: + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) + super().__init__(*args, **kwargs) + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. 
@@ -55,4 +58,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio # Ref: https://github.com/pytorch/xla/issues/2773 if _OMEGACONF_AVAILABLE: checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container) + import torch_xla.core.xla_model as xm + xm.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, path) diff --git a/src/lightning_lite/strategies/launchers/xla.py b/src/lightning_lite/strategies/launchers/xla.py index 1351229ec933f..bcb770d942791 100644 --- a/src/lightning_lite/strategies/launchers/xla.py +++ b/src/lightning_lite/strategies/launchers/xla.py @@ -17,15 +17,10 @@ from torch.multiprocessing import get_context +from lightning_lite.accelerators.tpu import _XLA_AVAILABLE from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher -from lightning_lite.utilities import _TPU_AVAILABLE from lightning_lite.utilities.apply_func import move_data_to_device -if _TPU_AVAILABLE: - import torch_xla.distributed.xla_multiprocessing as xmp -else: - xmp = None - if TYPE_CHECKING: from lightning_lite.strategies import XLAStrategy @@ -47,6 +42,8 @@ class _XLALauncher(_MultiProcessingLauncher): """ def __init__(self, strategy: "XLAStrategy") -> None: + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__(strategy=strategy, start_method="fork") @property @@ -66,6 +63,8 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: """ context = get_context(self._start_method) return_queue = context.SimpleQueue() + import torch_xla.distributed.xla_multiprocessing as xmp + xmp.spawn( self._wrapping_function, args=(function, args, kwargs, return_queue), diff --git a/src/lightning_lite/strategies/xla.py b/src/lightning_lite/strategies/xla.py index 48e2338f637c6..80165777814ac 100644 --- a/src/lightning_lite/strategies/xla.py +++ b/src/lightning_lite/strategies/xla.py @@ -13,7 +13,7 @@ # limitations under the License. 
import io import os -from typing import Any, Dict, List, Mapping, Optional, Sequence, Union +from typing import Any, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union import torch from torch import Tensor @@ -21,6 +21,7 @@ from torch.utils.data import DataLoader from lightning_lite.accelerators import Accelerator +from lightning_lite.accelerators.tpu import _XLA_AVAILABLE from lightning_lite.plugins.environments import XLAEnvironment from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO from lightning_lite.plugins.io.xla_plugin import XLACheckpointIO @@ -28,20 +29,14 @@ from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy from lightning_lite.strategies.launchers.xla import _XLALauncher from lightning_lite.strategies.strategy import TBroadcast -from lightning_lite.utilities import _TPU_AVAILABLE from lightning_lite.utilities.apply_func import apply_to_collection from lightning_lite.utilities.data import has_len from lightning_lite.utilities.distributed import ReduceOp from lightning_lite.utilities.rank_zero import rank_zero_only from lightning_lite.utilities.types import _PATH -if _TPU_AVAILABLE: - import torch_xla.core.xla_env_vars as xenv - import torch_xla.core.xla_model as xm - from torch_xla.core.xla_model import rendezvous +if TYPE_CHECKING and _XLA_AVAILABLE: from torch_xla.distributed.parallel_loader import MpDeviceLoader -else: - xm, xmp, MpDeviceLoader, rendezvous = [None] * 4 class XLAStrategy(DDPSpawnStrategy): @@ -70,6 +65,8 @@ def __init__( def root_device(self) -> torch.device: if not self._launched: raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.") + import torch_xla.core.xla_model as xm + return xm.xla_device() @property @@ -88,6 +85,8 @@ def distributed_sampler_kwargs(self) -> Dict[str, int]: @property def is_distributed(self) -> bool: + import torch_xla.core.xla_env_vars as xenv + # HOST_WORLD_SIZE is not set outside the xmp.spawn process return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1 @@ -105,8 +104,10 @@ def setup_module(self, module: Module) -> Module: def module_to_device(self, module: Module) -> None: module.to(self.root_device) - def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader: + def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader": XLAStrategy._validate_dataloader(dataloader) + from torch_xla.distributed.parallel_loader import MpDeviceLoader + dataloader = MpDeviceLoader(dataloader, self.root_device) # Mimic interface to torch.utils.data.DataLoader dataloader.dataset = dataloader._loader.dataset @@ -125,6 +126,7 @@ def reduce( "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:" f" {reduce_op}" ) + import torch_xla.core.xla_model as xm output = xm.mesh_reduce("reduce", output, sum) @@ -135,7 +137,9 @@ def reduce( def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: if self.is_distributed: - rendezvous(name) + import torch_xla.core.xla_model as xm + + xm.rendezvous(name) def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: if not self.is_distributed: @@ -144,6 +148,8 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: torch.save(obj, buffer) data = bytearray(buffer.getbuffer()) data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float) + import torch_xla.core.xla_model as xm + data = xm.all_gather(data_tensor) buffer = io.BytesIO(data.cpu().byte().numpy()) obj = torch.load(buffer) @@ -161,6 +167,8 @@ def 
all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo """ if isinstance(tensor, Tensor) and tensor.dim() == 0: tensor = tensor.unsqueeze(0) + import torch_xla.core.xla_model as xm + return xm.all_gather(tensor) def save_checkpoint( diff --git a/src/lightning_lite/utilities/__init__.py b/src/lightning_lite/utilities/__init__.py index edeab0cd5d360..4237b5c23a405 100644 --- a/src/lightning_lite/utilities/__init__.py +++ b/src/lightning_lite/utilities/__init__.py @@ -29,8 +29,6 @@ _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_12, - _TPU_AVAILABLE, - _XLA_AVAILABLE, ) from lightning_lite.utilities.rank_zero import ( # noqa: F401 rank_zero_deprecation, diff --git a/src/lightning_lite/utilities/device_parser.py b/src/lightning_lite/utilities/device_parser.py index 9c0feec8f7275..c9caa4e5122db 100644 --- a/src/lightning_lite/utilities/device_parser.py +++ b/src/lightning_lite/utilities/device_parser.py @@ -13,8 +13,7 @@ # limitations under the License. from typing import Any, List, MutableSequence, Optional, Tuple, Union -from lightning_lite.accelerators.cuda import _get_all_available_cuda_gpus -from lightning_lite.accelerators.mps import _get_all_available_mps_gpus +import lightning_lite.accelerators as accelerators # avoid circular dependency from lightning_lite.plugins.environments.torchelastic_environment import TorchElasticEnvironment from lightning_lite.utilities.exceptions import MisconfigurationException from lightning_lite.utilities.types import _DEVICE @@ -161,8 +160,8 @@ def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = Fals Returns: A list of all available GPUs """ - cuda_gpus = _get_all_available_cuda_gpus() if include_cuda else [] - mps_gpus = _get_all_available_mps_gpus() if include_mps else [] + cuda_gpus = accelerators.cuda._get_all_available_cuda_gpus() if include_cuda else [] + mps_gpus = accelerators.mps._get_all_available_mps_gpus() if include_mps else [] return cuda_gpus + mps_gpus diff --git a/src/lightning_lite/utilities/distributed.py b/src/lightning_lite/utilities/distributed.py index be308d1fc778d..ce1f27e82d05d 100644 --- a/src/lightning_lite/utilities/distributed.py +++ b/src/lightning_lite/utilities/distributed.py @@ -3,18 +3,14 @@ from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple, Union import torch +import torch.nn.functional as F from torch import Tensor -from torch.nn import functional as F from torch.utils.data import Dataset, DistributedSampler, Sampler from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment -from lightning_lite.utilities.imports import _HPU_AVAILABLE, _TPU_AVAILABLE +from lightning_lite.utilities.imports import _HPU_AVAILABLE from lightning_lite.utilities.rank_zero import rank_zero_info -if _TPU_AVAILABLE: - import torch_xla.core.xla_model as xm - - if torch.distributed.is_available(): from torch.distributed import group, ReduceOp else: @@ -89,6 +85,8 @@ def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> L def distributed_available() -> bool: + from lightning_lite.accelerators.tpu import tpu_distributed + return torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed() @@ -245,10 +243,6 @@ def init_dist_connection( ) -def tpu_distributed() -> bool: - return _TPU_AVAILABLE and xm.xrt_world_size() > 1 - - def get_default_process_group_backend_for_device(device: torch.device) -> str: return "nccl" if device.type == "cuda" else "gloo" diff --git 
a/src/lightning_lite/utilities/imports.py b/src/lightning_lite/utilities/imports.py index aa9a1fed3726b..737d1d7a4a151 100644 --- a/src/lightning_lite/utilities/imports.py +++ b/src/lightning_lite/utilities/imports.py @@ -35,12 +35,6 @@ _HOROVOD_AVAILABLE = module_available("horovod.torch") _OMEGACONF_AVAILABLE = package_available("omegaconf") _POPTORCH_AVAILABLE = package_available("poptorch") -_XLA_AVAILABLE: bool = package_available("torch_xla") - - -from lightning_lite.utilities.xla_device import XLADeviceUtils # noqa: E402 - -_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() if _POPTORCH_AVAILABLE: import poptorch diff --git a/src/lightning_lite/utilities/xla_device.py b/src/lightning_lite/utilities/xla_device.py deleted file mode 100644 index cc0bfb78823bc..0000000000000 --- a/src/lightning_lite/utilities/xla_device.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import functools -import os -import queue as q -import traceback -from multiprocessing import Process, Queue -from typing import Any, Callable, Union - -from lightning_lite.utilities.imports import _XLA_AVAILABLE - -if _XLA_AVAILABLE: - import torch_xla.core.xla_model as xm - -# define TPU availability timeout in seconds -TPU_CHECK_TIMEOUT = 60 - - -def inner_f(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None: # pragma: no cover - try: - queue.put(func(*args, **kwargs)) - # todo: specify the possible exception - except Exception: - traceback.print_exc() - queue.put(None) - - -def pl_multi_process(func: Callable) -> Callable: - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Union[bool, Any]: - queue: Queue = Queue() - proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs) - proc.start() - proc.join(TPU_CHECK_TIMEOUT) - try: - return queue.get_nowait() - except q.Empty: - traceback.print_exc() - return False - - return wrapper - - -class XLADeviceUtils: - """Used to detect the type of XLA device.""" - - _TPU_AVAILABLE = False - - @staticmethod - @pl_multi_process - def _is_device_tpu() -> bool: - """Check if TPU devices are available. - - Return: - A boolean value indicating if TPU devices are available - """ - # For the TPU Pod training process, for example, if we have - # TPU v3-32 with 4 VMs, the world size would be 4 and as - # we would have to use `torch_xla.distributed.xla_dist` for - # multiple VMs and TPU_CONFIG won't be available, running - # `xm.get_xla_supported_devices("TPU")` won't be possible. - return (xm.xrt_world_size() > 1) or bool(xm.get_xla_supported_devices("TPU")) - - @staticmethod - def xla_available() -> bool: - """Check if XLA library is installed. - - Return: - A boolean value indicating if a XLA is installed - """ - return _XLA_AVAILABLE - - @staticmethod - def tpu_device_exists() -> bool: - """Runs XLA device check within a separate process. 
- - Return: - A boolean value indicating if a TPU device exists on the system - """ - if os.getenv("PL_TPU_AVAILABLE", "0") == "1": - XLADeviceUtils._TPU_AVAILABLE = True - - if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE: - - XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu() - - if XLADeviceUtils._TPU_AVAILABLE: - os.environ["PL_TPU_AVAILABLE"] = "1" - return XLADeviceUtils._TPU_AVAILABLE diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 79af45be188f8..ace0dc6c4a0ce 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -141,7 +141,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated the internal `pl.core.mixins.DeviceDtypeModuleMixin` class ([#14511](https://github.com/Lightning-AI/lightning/pull/14511), [#14548](https://github.com/Lightning-AI/lightning/pull/14548)) -- Deprecated all functions in `pytorch_lightning.utilities.xla_device` in favor of `lightning_lite.utilities.xla_device` ([#14514](https://github.com/Lightning-AI/lightning/pull/14514)) +- Deprecated all functions in `pytorch_lightning.utilities.xla_device` ([#14514](https://github.com/Lightning-AI/lightning/pull/14514), [#14550](https://github.com/Lightning-AI/lightning/pull/14550)) + * Deprecated the internal `inner_f` function + * Deprecated the internal `pl_multi_process` function + * Deprecated the internal `XLADeviceUtils.xla_available` staticmethod + * Deprecated the `XLADeviceUtils.tpu_device_exists` staticmethod in favor of `pytorch_lightning.accelerators.TPUAccelerator.is_available()` + + +- Deprecated `pytorch_lightning.utilities.distributed.tpu_distributed` in favor of `lightning_lite.accelerators.tpu.tpu_distributed` ([#14550](https://github.com/Lightning-AI/lightning/pull/14550)) - Deprecated all functions in `pytorch_lightning.utilities.cloud_io` in favor of `lightning_lite.utilities.cloud_io` ([#14515](https://github.com/Lightning-AI/lightning/pull/14515)) diff --git a/src/pytorch_lightning/accelerators/tpu.py b/src/pytorch_lightning/accelerators/tpu.py index ddb981d3545a1..ae1916fb8381a 100644 --- a/src/pytorch_lightning/accelerators/tpu.py +++ b/src/pytorch_lightning/accelerators/tpu.py @@ -15,15 +15,20 @@ import torch -from lightning_lite.accelerators.tpu import parse_tpu_cores +from lightning_lite.accelerators.tpu import _XLA_AVAILABLE, parse_tpu_cores +from lightning_lite.accelerators.tpu import TPUAccelerator as LiteTPUAccelerator from lightning_lite.utilities.types import _DEVICE from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.utilities.imports import _TPU_AVAILABLE class TPUAccelerator(Accelerator): """Accelerator for TPU devices.""" + def __init__(self, *args: Any, **kwargs: Any) -> None: + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) + super().__init__(*args, **kwargs) + def setup_device(self, device: torch.device) -> None: pass @@ -69,12 +74,12 @@ def auto_device_count() -> int: @staticmethod def is_available() -> bool: - return _TPU_AVAILABLE + return LiteTPUAccelerator.is_available() @classmethod def register_accelerators(cls, accelerator_registry: Dict) -> None: accelerator_registry.register( "tpu", cls, - description=f"{cls.__class__.__name__}", + description=cls.__class__.__name__, ) diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py index 3af98a7b26ce8..83dfdaa3cc84f 100644 --- 
a/src/pytorch_lightning/plugins/precision/tpu.py +++ b/src/pytorch_lightning/plugins/precision/tpu.py @@ -15,18 +15,20 @@ from typing import Any, Callable import pytorch_lightning as pl +from lightning_lite.accelerators.tpu import _XLA_AVAILABLE from lightning_lite.utilities.types import Steppable from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin -from pytorch_lightning.utilities import _XLA_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException -if _XLA_AVAILABLE: - import torch_xla.core.xla_model as xm - class TPUPrecisionPlugin(PrecisionPlugin): """Precision plugin for TPU integration.""" + def __init__(self, *args: Any, **kwargs: Any) -> None: + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) + super().__init__(*args, **kwargs) + def optimizer_step( # type: ignore[override] self, optimizer: Steppable, @@ -35,6 +37,8 @@ def optimizer_step( # type: ignore[override] closure: Callable[[], Any], **kwargs: Any, ) -> Any: + import torch_xla.core.xla_model as xm + closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs}) skipped_backward = closure_result is None diff --git a/src/pytorch_lightning/profilers/xla.py b/src/pytorch_lightning/profilers/xla.py index 0f86d63b546eb..ef103a9a45842 100644 --- a/src/pytorch_lightning/profilers/xla.py +++ b/src/pytorch_lightning/profilers/xla.py @@ -14,12 +14,8 @@ import logging from typing import Dict +from lightning_lite.accelerators.tpu import _XLA_AVAILABLE from pytorch_lightning.profilers.profiler import Profiler -from pytorch_lightning.utilities import _TPU_AVAILABLE -from pytorch_lightning.utilities.exceptions import MisconfigurationException - -if _TPU_AVAILABLE: - import torch_xla.debug.profiler as xp log = logging.getLogger(__name__) @@ -43,8 +39,8 @@ def __init__(self, port: int = 9012) -> None: port: the port to start the profiler server on. An exception is raised if the provided port is invalid or busy. 
""" - if not _TPU_AVAILABLE: - raise MisconfigurationException("`XLAProfiler` is only supported on TPUs") + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__(dirpath=None, filename=None) self.port = port self._recording_map: Dict = {} @@ -52,6 +48,8 @@ def __init__(self, port: int = 9012) -> None: self._start_trace: bool = False def start(self, action_name: str) -> None: + import torch_xla.debug.profiler as xp + if action_name in self.RECORD_FUNCTIONS: if not self._start_trace: self.server = xp.start_server(self.port) diff --git a/src/pytorch_lightning/strategies/launchers/xla.py b/src/pytorch_lightning/strategies/launchers/xla.py index 1528698445f66..d6e623da58937 100644 --- a/src/pytorch_lightning/strategies/launchers/xla.py +++ b/src/pytorch_lightning/strategies/launchers/xla.py @@ -18,6 +18,7 @@ import torch.multiprocessing as mp import pytorch_lightning as pl +from lightning_lite.accelerators.tpu import _XLA_AVAILABLE from lightning_lite.strategies.launchers.xla import _rank_teardown from lightning_lite.utilities import move_data_to_device from pytorch_lightning.strategies.launchers.multiprocessing import ( @@ -27,14 +28,8 @@ _WorkerOutput, ) from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.rank_zero import rank_zero_debug -if _TPU_AVAILABLE: - import torch_xla.distributed.xla_multiprocessing as xmp -else: - xmp = None - class _XLALauncher(_MultiProcessingLauncher): r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the @@ -53,6 +48,8 @@ class _XLALauncher(_MultiProcessingLauncher): """ def __init__(self, strategy: "pl.strategies.TPUSpawnStrategy") -> None: + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__(strategy=strategy, start_method="fork") @property @@ -74,6 +71,8 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] """ context = mp.get_context(self._start_method) return_queue = context.SimpleQueue() + import torch_xla.distributed.xla_multiprocessing as xmp + xmp.spawn( self._wrapping_function, args=(trainer, function, args, kwargs, return_queue), diff --git a/src/pytorch_lightning/strategies/single_tpu.py b/src/pytorch_lightning/strategies/single_tpu.py index afe4c02ea8a3d..07543fb8881ad 100644 --- a/src/pytorch_lightning/strategies/single_tpu.py +++ b/src/pytorch_lightning/strategies/single_tpu.py @@ -15,14 +15,12 @@ from typing import Dict, Optional import pytorch_lightning as pl +from lightning_lite.accelerators.tpu import _XLA_AVAILABLE from lightning_lite.plugins import CheckpointIO, XLACheckpointIO from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.single_device import SingleDeviceStrategy -from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters - -if _TPU_AVAILABLE: - import torch_xla.core.xla_model as xm +from pytorch_lightning.utilities import find_shared_parameters, set_shared_parameters class SingleTPUStrategy(SingleDeviceStrategy): @@ -38,6 +36,10 @@ def __init__( precision_plugin: Optional[PrecisionPlugin] = None, debug: bool = False, ): + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) + import torch_xla.core.xla_model as xm + super().__init__( accelerator=accelerator, device=xm.xla_device(device), diff --git 
a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py index d220b719ccbfb..76c505a0a3ed2 100644 --- a/src/pytorch_lightning/strategies/tpu_spawn.py +++ b/src/pytorch_lightning/strategies/tpu_spawn.py @@ -13,7 +13,7 @@ # limitations under the License. import io import os -from typing import Any, Dict, List, Mapping, Optional, Sequence, Union +from typing import Any, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -22,6 +22,7 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl +from lightning_lite.accelerators.tpu import _XLA_AVAILABLE from lightning_lite.plugins import CheckpointIO, XLACheckpointIO from lightning_lite.plugins.environments import XLAEnvironment from lightning_lite.utilities.data import has_len @@ -36,19 +37,15 @@ from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters +from pytorch_lightning.utilities import find_shared_parameters, set_shared_parameters from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import rank_zero_only from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS -if _TPU_AVAILABLE: - import torch_xla.core.xla_env_vars as xenv - import torch_xla.core.xla_model as xm - import torch_xla.distributed.xla_multiprocessing as xmp - from torch_xla.core.xla_model import rendezvous +if TYPE_CHECKING and _XLA_AVAILABLE: from torch_xla.distributed.parallel_loader import MpDeviceLoader else: - xm, xmp, MpDeviceLoader, rendezvous = [None] * 4 + MpDeviceLoader = None class TPUSpawnStrategy(DDPSpawnStrategy): @@ -66,6 +63,8 @@ def __init__( debug: bool = False, **_: Any, ) -> None: + if not _XLA_AVAILABLE: + raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__( accelerator=accelerator, parallel_devices=parallel_devices, @@ -95,6 +94,8 @@ def checkpoint_io(self, io: Optional[CheckpointIO]) -> None: def root_device(self) -> torch.device: if not self._launched: raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.") + import torch_xla.core.xla_model as xm + return xm.xla_device() @staticmethod @@ -126,6 +127,8 @@ def _validate_patched_dataloaders(model: "pl.LightningModule") -> None: def connect(self, model: "pl.LightningModule") -> None: TPUSpawnStrategy._validate_patched_dataloaders(model) + import torch_xla.distributed.xla_multiprocessing as xmp + self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model)) return super().connect(model) @@ -160,10 +163,14 @@ def distributed_sampler_kwargs(self) -> Dict[str, int]: @property def is_distributed(self) -> bool: # HOST_WORLD_SIZE is not set outside the xmp.spawn process + import torch_xla.core.xla_env_vars as xenv + return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1 - def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader: + def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader": TPUSpawnStrategy._validate_dataloader(dataloader) + from torch_xla.distributed.parallel_loader import MpDeviceLoader + dataloader = MpDeviceLoader(dataloader, self.root_device) # Mimic interface to torch.utils.data.DataLoader 
dataloader.dataset = dataloader._loader.dataset @@ -177,7 +184,9 @@ def model_to_device(self) -> None: def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: if self.is_distributed: - rendezvous(name) + import torch_xla.core.xla_model as xm + + xm.rendezvous(name) def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: if not self.is_distributed: @@ -186,6 +195,8 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: torch.save(obj, buffer) data = bytearray(buffer.getbuffer()) data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float) + import torch_xla.core.xla_model as xm + data = xm.all_gather(data_tensor) buffer = io.BytesIO(data.cpu().byte().numpy()) obj = torch.load(buffer) @@ -205,6 +216,8 @@ def reduce( f" {reduce_op}" ) + import torch_xla.core.xla_model as xm + output = xm.mesh_reduce("reduce", output, sum) if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): @@ -249,6 +262,8 @@ def _pod_progress_bar_force_stdout(self) -> None: # from different vms to the main worker doesn't work well with tqdm # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140 # The print statement seems to force tqdm to flush stdout. + import torch_xla.core.xla_env_vars as xenv + if self.global_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1: print() @@ -286,6 +301,8 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo """ if isinstance(tensor, Tensor) and tensor.dim() == 0: tensor = tensor.unsqueeze(0) + import torch_xla.core.xla_model as xm + return xm.all_gather(tensor) def teardown(self) -> None: diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py index 8ba07c53d798e..a97d7bffab28b 100644 --- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -83,7 +83,6 @@ _IPU_AVAILABLE, _IS_INTERACTIVE, _TORCH_GREATER_EQUAL_1_11, - _TPU_AVAILABLE, ) from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn @@ -494,7 +493,7 @@ def _set_accelerator_if_ipu_strategy_is_passed(self) -> None: def _choose_auto_accelerator(self) -> str: """Choose the accelerator type (str) based on availability when ``accelerator='auto'``.""" if self._accelerator_flag == "auto": - if _TPU_AVAILABLE: + if TPUAccelerator.is_available(): return "tpu" if _IPU_AVAILABLE: return "ipu" @@ -521,15 +520,16 @@ def _set_parallel_devices_and_init_accelerator(self) -> None: else: assert self._accelerator_flag is not None self.accelerator = AcceleratorRegistry.get(self._accelerator_flag) + accelerator_cls = self.accelerator.__class__ - if not self.accelerator.is_available(): + if not accelerator_cls.is_available(): available_accelerator = [ acc_str for acc_str in self._accelerator_types if AcceleratorRegistry[acc_str]["accelerator"].is_available() ] raise MisconfigurationException( - f"{self.accelerator.__class__.__qualname__} can not run on your system" + f"`{accelerator_cls.__qualname__}` can not run on your system" " since the accelerator is not available. The following accelerator(s)" " is available and can be passed into `accelerator` argument of" f" `Trainer`: {available_accelerator}." 
@@ -542,9 +542,9 @@ def _set_parallel_devices_and_init_accelerator(self) -> None: self._set_devices_flag_if_auto_select_gpus_passed() - self._devices_flag = self.accelerator.parse_devices(self._devices_flag) + self._devices_flag = accelerator_cls.parse_devices(self._devices_flag) if not self._parallel_devices: - self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag) + self._parallel_devices = accelerator_cls.get_parallel_devices(self._devices_flag) def _set_devices_flag_if_auto_passed(self) -> None: if self._devices_flag == "auto" or self._devices_flag is None: diff --git a/src/pytorch_lightning/trainer/setup.py b/src/pytorch_lightning/trainer/setup.py index 00c5c0f762711..bf9eb275a8157 100644 --- a/src/pytorch_lightning/trainer/setup.py +++ b/src/pytorch_lightning/trainer/setup.py @@ -33,7 +33,7 @@ SimpleProfiler, XLAProfiler, ) -from pytorch_lightning.utilities import _HPU_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE +from pytorch_lightning.utilities import _HPU_AVAILABLE, _IPU_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn @@ -162,7 +162,7 @@ def _log_device_info(trainer: "pl.Trainer") -> None: rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}") num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, TPUAccelerator) else 0 - rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores") + rank_zero_info(f"TPU available: {TPUAccelerator.is_available()}, using: {num_tpu_cores} TPU cores") num_ipus = trainer.num_devices if isinstance(trainer.accelerator, IPUAccelerator) else 0 rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs") @@ -178,7 +178,7 @@ def _log_device_info(trainer: "pl.Trainer") -> None: category=PossibleUserWarning, ) - if _TPU_AVAILABLE and not isinstance(trainer.accelerator, TPUAccelerator): + if TPUAccelerator.is_available() and not isinstance(trainer.accelerator, TPUAccelerator): rank_zero_warn( "TPU available but not used. Set `accelerator` and `devices` using" f" `Trainer(accelerator='tpu', devices={TPUAccelerator.auto_device_count()})`." 
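The `TPUAccelerator.is_available()` call that `_choose_auto_accelerator` and `_log_device_info` now rely on is backed by the probe defined in `src/lightning_lite/accelerators/tpu.py` earlier in this patch: the XLA device query runs in a child process and is abandoned after a timeout. Below is a condensed, standalone sketch of that pattern; only `TPU_CHECK_TIMEOUT` and the overall shape come from the diff, the remaining names are illustrative.

```python
# Sketch of the probe behind TPUAccelerator.is_available(): run a potentially
# hanging check in a child process and treat a timeout or crash as "no TPU".
import functools
import queue as q
from multiprocessing import Process, Queue
from typing import Any, Callable, Union

TPU_CHECK_TIMEOUT = 60  # seconds, matching the constant introduced above


def _child(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None:
    # Runs inside the child process; ship the result (or None on error) back to the parent.
    try:
        queue.put(func(*args, **kwargs))
    except Exception:
        queue.put(None)


def run_with_timeout(func: Callable) -> Callable:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Union[bool, Any]:
        queue: Queue = Queue()
        proc = Process(target=_child, args=(queue, func, *args), kwargs=kwargs)
        proc.start()
        proc.join(TPU_CHECK_TIMEOUT)
        try:
            return queue.get_nowait()
        except q.Empty:
            return False  # probe hung or died without reporting back

    return wrapper


@run_with_timeout
def probe_tpu() -> bool:
    # The real implementation imports torch_xla.core.xla_model here and checks
    # xm.get_xla_supported_devices("TPU") / xm.xrt_world_size(); a constant is
    # used so this sketch runs without torch_xla installed.
    return False
```

In the diff, `is_available()` itself carries `functools.lru_cache(maxsize=1)`, so the child process is launched at most once per interpreter; the pattern presumably relies on the default `fork` start method of the Linux TPU images.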
diff --git a/src/pytorch_lightning/utilities/__init__.py b/src/pytorch_lightning/utilities/__init__.py index c9740016538f6..dc5c81f2a8919 100644 --- a/src/pytorch_lightning/utilities/__init__.py +++ b/src/pytorch_lightning/utilities/__init__.py @@ -34,8 +34,6 @@ _TORCH_GREATER_EQUAL_1_12, _TORCH_QUANTIZE_AVAILABLE, _TORCHVISION_AVAILABLE, - _TPU_AVAILABLE, - _XLA_AVAILABLE, ) from pytorch_lightning.utilities.parameter_tying import find_shared_parameters, set_shared_parameters # noqa: F401 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401 diff --git a/src/pytorch_lightning/utilities/distributed.py b/src/pytorch_lightning/utilities/distributed.py index 6f01a1a5b447e..537f1eccf6bff 100644 --- a/src/pytorch_lightning/utilities/distributed.py +++ b/src/pytorch_lightning/utilities/distributed.py @@ -26,7 +26,6 @@ from lightning_lite.utilities.distributed import init_dist_connection as new_init_dist_connection from lightning_lite.utilities.distributed import sync_ddp as new_sync_ddp from lightning_lite.utilities.distributed import sync_ddp_if_available as new_sync_ddp_if_available -from lightning_lite.utilities.distributed import tpu_distributed as new_tpu_distributed from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_info @@ -213,6 +212,8 @@ def sync_ddp_if_available(*args: Any, **kwargs: Any) -> Any: def tpu_distributed() -> bool: rank_zero_deprecation( "`pytorch_lightning.utilities.distributed.tpu_distributed` has been deprecated in v1.8.0 and will" - " be removed in v1.10.0. Please use `lightning_lite.utilities.distributed.tpu_distributed` instead." + " be removed in v1.10.0. Please use `lightning_lite.accelerators.tpu.tpu_distributed` instead." ) - return new_tpu_distributed() + from lightning_lite.accelerators.tpu import tpu_distributed + + return tpu_distributed() diff --git a/src/pytorch_lightning/utilities/imports.py b/src/pytorch_lightning/utilities/imports.py index d870d0faab823..de686d93f5e92 100644 --- a/src/pytorch_lightning/utilities/imports.py +++ b/src/pytorch_lightning/utilities/imports.py @@ -42,12 +42,6 @@ _RICH_AVAILABLE = package_available("rich") and compare_version("rich", operator.ge, "10.2.2") _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"]) _TORCHVISION_AVAILABLE = RequirementCache("torchvision") -_XLA_AVAILABLE: bool = package_available("torch_xla") - - -from lightning_lite.utilities.xla_device import XLADeviceUtils # noqa: E402 - -_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() if _POPTORCH_AVAILABLE: import poptorch diff --git a/src/pytorch_lightning/utilities/xla_device.py b/src/pytorch_lightning/utilities/xla_device.py index a515058a63c1f..7584c91027d46 100644 --- a/src/pytorch_lightning/utilities/xla_device.py +++ b/src/pytorch_lightning/utilities/xla_device.py @@ -15,48 +15,53 @@ from multiprocessing import Queue from typing import Any, Callable -from lightning_lite.utilities.xla_device import inner_f as new_inner_f -from lightning_lite.utilities.xla_device import pl_multi_process as new_pl_multi_process -from lightning_lite.utilities.xla_device import XLADeviceUtils as NewXLADeviceUtils -from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation +from pytorch_lightning.utilities import rank_zero_deprecation def inner_f(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None: # pragma: no cover rank_zero_deprecation( "`pytorch_lightning.utilities.xla_device.inner_f` has been 
deprecated in v1.8.0 and will be" - " removed in v1.10.0. Please use `lightning_lite.utilities.xla_device.inner_f` instead." + " removed in v1.10.0. This class is internal but you can copy over its implementation." ) - return new_inner_f(queue, func, *args, **kwargs) + from lightning_lite.accelerators.tpu import _inner_f + + return _inner_f(queue, func, *args, **kwargs) def pl_multi_process(func: Callable) -> Callable: rank_zero_deprecation( "`pytorch_lightning.utilities.xla_device.pl_multi_process` has been deprecated in v1.8.0 and will be" - " removed in v1.10.0. Please use `lightning_lite.utilities.xla_device.pl_multi_process` instead." + " removed in v1.10.0. This class is internal but you can copy over its implementation." ) - return new_pl_multi_process(func) + from lightning_lite.accelerators.tpu import _multi_process + + return _multi_process(func) -class XLADeviceUtils(NewXLADeviceUtils): +class XLADeviceUtils: def __init__(self) -> None: rank_zero_deprecation( "`pytorch_lightning.utilities.xla_device.XLADeviceUtils` has been deprecated in v1.8.0 and will be" - " removed in v1.10.0. Please use `lightning_lite.utilities.xla_device.XLADeviceUtils` instead." + " removed in v1.10.0. This class is internal." ) - super().__init__() @staticmethod def xla_available() -> bool: rank_zero_deprecation( - "`pytorch_lightning.utilities.xla_device.XLADeviceUtils` has been deprecated in v1.8.0 and will be" - " removed in v1.10.0. Please use `lightning_lite.utilities.xla_device.XLADeviceUtils` instead." + "`pytorch_lightning.utilities.xla_device.XLADeviceUtils.xla_available` has been deprecated in v1.8.0 and" + " will be removed in v1.10.0. This method is internal." ) - return NewXLADeviceUtils.xla_available() + from pytorch_lightning.accelerators.tpu import _XLA_AVAILABLE + + return bool(_XLA_AVAILABLE) @staticmethod def tpu_device_exists() -> bool: rank_zero_deprecation( - "`pytorch_lightning.utilities.xla_device.XLADeviceUtils` has been deprecated in v1.8.0 and will be" - " removed in v1.10.0. Please use `lightning_lite.utilities.xla_device.XLADeviceUtils` instead." + "`pytorch_lightning.utilities.xla_device.XLADeviceUtils.tpu_device_exists` has been deprecated in v1.8.0" + " and will be removed in v1.10.0. Please use `pytorch_lightning.accelerators.TPUAccelerator.is_available()`" + " instead." 
) - return NewXLADeviceUtils.tpu_device_exists() + from pytorch_lightning.accelerators.tpu import TPUAccelerator + + return TPUAccelerator.is_available() diff --git a/tests/tests_lite/conftest.py b/tests/tests_lite/conftest.py index c38d93ff24aae..efce3dcf79b1b 100644 --- a/tests/tests_lite/conftest.py +++ b/tests/tests_lite/conftest.py @@ -17,6 +17,8 @@ import pytest import torch.distributed +import lightning_lite + @pytest.fixture(scope="function", autouse=True) def preserve_global_rank_variable(): @@ -52,17 +54,6 @@ def restore_env_variables(): "HOROVOD_FUSION_THRESHOLD", "RANK", # set by DeepSpeed "POPLAR_ENGINE_OPTIONS", # set by IPUStrategy - # set by XLA - "TF2_BEHAVIOR", - "XRT_MESH_SERVICE_ADDRESS", - "XRT_TORCH_DIST_ROOT", - "XRT_MULTI_PROCESSING_DEVICE", - "XRT_SHARD_WORLD_SIZE", - "XRT_LOCAL_WORKER", - "XRT_HOST_WORLD_SIZE", - "XRT_SHARD_ORDINAL", - "XRT_SHARD_LOCAL_ORDINAL", - "TF_CPP_MIN_LOG_LEVEL", } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" @@ -83,6 +74,19 @@ def reset_deterministic_algorithm(): torch.use_deterministic_algorithms(False) +@pytest.fixture(scope="function") +def xla_available(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(lightning_lite.accelerators.tpu, "_XLA_AVAILABLE", True) + monkeypatch.setattr(lightning_lite.plugins.environments.xla_environment, "_XLA_AVAILABLE", True) + monkeypatch.setattr(lightning_lite.strategies.xla, "_XLA_AVAILABLE", True) + monkeypatch.setattr(lightning_lite.strategies.launchers.xla, "_XLA_AVAILABLE", True) + + +@pytest.fixture(scope="function") +def tpu_available(xla_available, monkeypatch) -> None: + monkeypatch.setattr(lightning_lite.accelerators.tpu.TPUAccelerator, "is_available", lambda: True) + + @pytest.fixture def caplog(caplog): """Workaround for https://github.com/pytest-dev/pytest/issues/3697. 
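The `xla_available` and `tpu_available` fixtures added to `tests/tests_lite/conftest.py` just above replace ad-hoc patching of the removed `_TPU_AVAILABLE` constant: a test that wants to pretend a TPU is attached simply requests the fixture. A minimal usage sketch (the test name and body are hypothetical):

```python
from lightning_lite.accelerators import TPUAccelerator


def test_pretends_a_tpu_is_attached(tpu_available):
    # The fixture monkeypatches TPUAccelerator.is_available to return True and
    # marks torch_xla as importable, so TPU-only branches run on CPU-only CI.
    assert TPUAccelerator.is_available()
```

Patching the accelerator's classmethod rather than a module constant is what lets tests further down (for example `test_accelerator_cpu_with_tpu_cores_flag`) trade their `RunIf(tpu=True)` marker for `RunIf(skip_windows=True)` plus the fixture.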
diff --git a/tests/tests_lite/helpers/runif.py b/tests/tests_lite/helpers/runif.py index a3f484255c84e..6a40a47b9a770 100644 --- a/tests/tests_lite/helpers/runif.py +++ b/tests/tests_lite/helpers/runif.py @@ -20,10 +20,11 @@ from packaging.version import Version from pkg_resources import get_distribution +from lightning_lite.accelerators import TPUAccelerator from lightning_lite.accelerators.mps import MPSAccelerator from lightning_lite.strategies.deepspeed import _DEEPSPEED_AVAILABLE from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE -from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10, _TPU_AVAILABLE +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10 class RunIf: @@ -112,7 +113,7 @@ def __new__( reasons.append("unimplemented on Windows") if tpu: - conditions.append(not _TPU_AVAILABLE) + conditions.append(not TPUAccelerator.is_available()) reasons.append("TPU") # used in conftest.py::pytest_collection_modifyitems kwargs["tpu"] = True diff --git a/tests/tests_lite/plugins/environments/test_xla_environment.py b/tests/tests_lite/plugins/environments/test_xla_environment.py index 313aab368b2ff..17b858a0646f2 100644 --- a/tests/tests_lite/plugins/environments/test_xla_environment.py +++ b/tests/tests_lite/plugins/environments/test_xla_environment.py @@ -72,8 +72,8 @@ def test_attributes_from_environment_variables(): def test_detect(monkeypatch): """Test the detection of a xla environment configuration.""" - monkeypatch.setattr(lightning_lite.plugins.environments.xla_environment, "_TPU_AVAILABLE", False) + monkeypatch.setattr(lightning_lite.accelerators.tpu.TPUAccelerator, "is_available", lambda: False) assert not XLAEnvironment.detect() - monkeypatch.setattr(lightning_lite.plugins.environments.xla_environment, "_TPU_AVAILABLE", True) + monkeypatch.setattr(lightning_lite.accelerators.tpu.TPUAccelerator, "is_available", lambda: True) assert XLAEnvironment.detect() diff --git a/tests/tests_lite/strategies/launchers/test_xla.py b/tests/tests_lite/strategies/launchers/test_xla.py index 223414e81e537..846c64c4ae4af 100644 --- a/tests/tests_lite/strategies/launchers/test_xla.py +++ b/tests/tests_lite/strategies/launchers/test_xla.py @@ -7,19 +7,19 @@ @RunIf(skip_windows=True) -def test_xla_launcher_default_start_method(): +def test_xla_launcher_default_start_method(xla_available): launcher = _XLALauncher(strategy=Mock()) assert launcher._start_method == "fork" @RunIf(skip_windows=True) -def test_xla_launcher_interactive_compatible(): +def test_xla_launcher_interactive_compatible(xla_available): launcher = _XLALauncher(strategy=Mock()) assert launcher.is_interactive_compatible -@RunIf(skip_windows=True) -@mock.patch("lightning_lite.strategies.launchers.xla.xmp") +@RunIf(skip_windows=True, tpu=True) +@mock.patch("torch_xla.distributed.xla_multiprocessing") @mock.patch("lightning_lite.strategies.launchers.xla.get_context") def test_xla_launcher_xmp_spawn(get_context_mock, xmp_mock): strategy = Mock() diff --git a/tests/tests_lite/strategies/test_xla.py b/tests/tests_lite/strategies/test_xla.py index 536393d98d599..c0abc61c7bb6e 100644 --- a/tests/tests_lite/strategies/test_xla.py +++ b/tests/tests_lite/strategies/test_xla.py @@ -1,4 +1,6 @@ +import os from functools import partial +from unittest import mock import pytest from tests_lite.helpers.runif import RunIf @@ -32,6 +34,7 @@ def broadcast_on_tpu_fn(strategy): @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_broadcast_on_tpu(): """Checks if an object 
from the main process is broadcasted to other processes correctly.""" xla_launch(broadcast_on_tpu_fn) @@ -54,6 +57,7 @@ def tpu_reduce_fn(strategy): @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_tpu_reduce(): """Test tpu spawn reduce operation.""" xla_launch(tpu_reduce_fn) diff --git a/tests/tests_lite/test_connector.py b/tests/tests_lite/test_connector.py index 73a22fc473c3f..7595040537543 100644 --- a/tests/tests_lite/test_connector.py +++ b/tests/tests_lite/test_connector.py @@ -225,8 +225,7 @@ def test_ipython_compatible_dp_strategy_gpu(_, monkeypatch): @RunIf(skip_windows=True) -@mock.patch("lightning_lite.accelerators.tpu.TPUAccelerator.is_available", return_value=True) -def test_ipython_compatible_strategy_tpu(_, monkeypatch): +def test_ipython_compatible_strategy_tpu(tpu_available, monkeypatch): monkeypatch.setattr(lightning_lite.utilities, "_IS_INTERACTIVE", True) connector = _Connector(accelerator="tpu") assert connector.strategy.launcher.is_interactive_compatible @@ -258,13 +257,13 @@ def test_strategy_choice_multi_node_gpu(_, strategy, strategy_class, devices): @mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=0) -def test_accelerator_cpu(*_): +def test_accelerator_cpu(_): connector = _Connector(accelerator="cpu") assert isinstance(connector.accelerator, CPUAccelerator) with pytest.raises( RuntimeError, - match="CUDAAccelerator can not run on your system since the accelerator is not available.", + match="CUDAAccelerator` can not run on your system since the accelerator is not available.", ): _Connector(accelerator="cuda", devices=1) @@ -587,23 +586,21 @@ def test_strategy_choice_ddp_cpu_slurm(strategy): assert connector.strategy.local_rank == 0 -@mock.patch("lightning_lite.accelerators.tpu.TPUAccelerator.is_available", return_value=True) @mock.patch.dict(os.environ, {}, clear=True) -def test_unsupported_tpu_choice(*_): - +def test_unsupported_tpu_choice(tpu_available): with pytest.raises(NotImplementedError, match=r"accelerator='tpu', precision=64\)` is not implemented"): _Connector(accelerator="tpu", precision=64) # if user didn't set strategy, _Connector will choose the TPUSingleStrategy or TPUSpawnStrategy - with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"): - with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but native AMP is not supported"): - _Connector(accelerator="tpu", precision=16, strategy="ddp") + with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"), pytest.warns( + UserWarning, match=r"accelerator='tpu', precision=16\)` but native AMP is not supported" + ): + _Connector(accelerator="tpu", precision=16, strategy="ddp") @mock.patch("lightning_lite.accelerators.cuda.CUDAAccelerator.is_available", return_value=False) -@mock.patch("lightning_lite.accelerators.tpu.TPUAccelerator.is_available", return_value=False) @mock.patch("lightning_lite.accelerators.mps.MPSAccelerator.is_available", return_value=False) -def test_devices_auto_choice_cpu(*_): +def test_devices_auto_choice_cpu(tpu_available, *_): connector = _Connector(accelerator="auto", devices="auto") assert isinstance(connector.accelerator, CPUAccelerator) assert isinstance(connector.strategy, SingleDeviceStrategy) @@ -681,9 +678,9 @@ def test_gpu_accelerator_backend_choice_cuda(*_): assert isinstance(connector.accelerator, CUDAAccelerator) +@RunIf(min_torch="1.12") 
@mock.patch("lightning_lite.accelerators.mps.MPSAccelerator.is_available", return_value=True) @mock.patch("lightning_lite.accelerators.mps._get_all_available_mps_gpus", return_value=[0]) -@mock.patch("torch.device", return_value="mps") # necessary because torch doesn't allow creation of mps devices def test_gpu_accelerator_backend_choice_mps(*_): connector = _Connector(accelerator="gpu") assert connector._accelerator_flag == "mps" diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py index 73cdb9a7bd8e2..4ebb4bdcab347 100644 --- a/tests/tests_lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -371,6 +371,7 @@ def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy): pytest.param("gpu", "mps:0", marks=RunIf(mps=True)), ], ) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_to_device(accelerator, expected): """Test that the to_device method can move various objects to the device determined by the accelerator.""" diff --git a/tests/tests_lite/utilities/test_xla_device_utils.py b/tests/tests_lite/utilities/test_xla_device_utils.py index 87c92b772c520..fd40d704db194 100644 --- a/tests/tests_lite/utilities/test_xla_device_utils.py +++ b/tests/tests_lite/utilities/test_xla_device_utils.py @@ -17,20 +17,19 @@ import pytest from tests_lite.helpers.runif import RunIf -import lightning_lite.utilities.xla_device as xla_utils -from lightning_lite.utilities.imports import _XLA_AVAILABLE +from lightning_lite.accelerators.tpu import _multi_process, _XLA_AVAILABLE, TPUAccelerator @pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent") def test_tpu_device_absence(): - """Check tpu_device_exists returns False when torch_xla is not available.""" - assert not xla_utils.XLADeviceUtils.tpu_device_exists() + """Check `is_available` returns True when TPU is available.""" + assert not TPUAccelerator.is_available() @RunIf(tpu=True) def test_tpu_device_presence(): - """Check tpu_device_exists returns True when TPU is available.""" - assert xla_utils.XLADeviceUtils.tpu_device_exists() + """Check `is_available` returns True when TPU is available.""" + assert TPUAccelerator.is_available() def sleep_fn(sleep_time: float) -> bool: @@ -38,16 +37,18 @@ def sleep_fn(sleep_time: float) -> bool: return True -@patch("lightning_lite.utilities.xla_device.TPU_CHECK_TIMEOUT", 3) +@patch("lightning_lite.accelerators.tpu.TPU_CHECK_TIMEOUT", 3) @pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present") def test_result_returns_within_timeout_seconds(): """Check that pl_multi_process returns within 3 seconds.""" - fn = xla_utils.pl_multi_process(sleep_fn) + fn = _multi_process(sleep_fn) start = time.time() - result = fn(xla_utils.TPU_CHECK_TIMEOUT * 0.5) + from lightning_lite.accelerators.tpu import TPU_CHECK_TIMEOUT + + result = fn(TPU_CHECK_TIMEOUT * 0.5) end = time.time() elapsed_time = int(end - start) - assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT + assert elapsed_time <= TPU_CHECK_TIMEOUT assert result diff --git a/tests/tests_pytorch/accelerators/test_hpu.py b/tests/tests_pytorch/accelerators/test_hpu.py index 4947000b47162..113266c8b61a7 100644 --- a/tests/tests_pytorch/accelerators/test_hpu.py +++ b/tests/tests_pytorch/accelerators/test_hpu.py @@ -47,7 +47,7 @@ def test_device_name(): @pytest.mark.skipif(_HPU_AVAILABLE, reason="test requires non-HPU machine") def test_fail_if_no_hpus(): - with pytest.raises(MisconfigurationException, match="HPUAccelerator can not run on your system"): + with 
pytest.raises(MisconfigurationException, match="HPUAccelerator` can not run on your system"): Trainer(accelerator="hpu", devices=1) diff --git a/tests/tests_pytorch/accelerators/test_tpu.py b/tests/tests_pytorch/accelerators/test_tpu.py index 74edcc6ea86d4..85ce3cac3a31c 100644 --- a/tests/tests_pytorch/accelerators/test_tpu.py +++ b/tests/tests_pytorch/accelerators/test_tpu.py @@ -20,7 +20,6 @@ import pytest import torch from torch import nn -from torch.multiprocessing import ProcessExitedException from torch.utils.data import DataLoader from pytorch_lightning import Trainer @@ -49,6 +48,7 @@ def forward(self, x): @RunIf(tpu=True, standalone=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_resume_training_on_cpu(tmpdir): """Checks if training can be resumed from a saved checkpoint on CPU.""" # Train a model on TPU @@ -69,8 +69,7 @@ def test_resume_training_on_cpu(tmpdir): @RunIf(tpu=True) -@mock.patch.dict(os.environ, {}, clear=True) -@pytest.mark.xfail(raises=ProcessExitedException, reason="https://github.com/pytorch/xla/issues/1666") +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_if_test_works_after_train(tmpdir): """Ensure that .test() works after .fit()""" model = BoringModel() @@ -80,8 +79,8 @@ def test_if_test_works_after_train(tmpdir): assert len(out) == 1 -@RunIf(tpu=True) -def test_accelerator_cpu_with_tpu_cores_flag(): +@RunIf(skip_windows=True) +def test_accelerator_cpu_with_tpu_cores_flag(tpu_available): assert TPUAccelerator.is_available() trainer = Trainer(accelerator="cpu", devices=8) @@ -92,9 +91,9 @@ def test_accelerator_cpu_with_tpu_cores_flag(): assert isinstance(trainer.strategy, TPUSpawnStrategy) -@RunIf(tpu=True) +@RunIf(skip_windows=True) @pytest.mark.parametrize(["accelerator", "devices"], [("auto", 8), ("auto", "auto"), ("tpu", None)]) -def test_accelerator_tpu(accelerator, devices): +def test_accelerator_tpu(accelerator, devices, tpu_available): assert TPUAccelerator.is_available() trainer = Trainer(accelerator=accelerator, devices=devices) @@ -103,10 +102,9 @@ def test_accelerator_tpu(accelerator, devices): assert trainer.num_devices == 8 -@RunIf(tpu=True) -def test_accelerator_tpu_with_tpu_cores_priority(): +@RunIf(skip_windows=True) +def test_accelerator_tpu_with_tpu_cores_priority(tpu_available): """Test for checking `tpu_cores` flag takes priority over `devices`.""" tpu_cores = 8 with pytest.warns(UserWarning, match="The flag `devices=1` will be ignored,"): trainer = Trainer(accelerator="tpu", devices=1, tpu_cores=tpu_cores) @@ -115,8 +113,8 @@ def test_accelerator_tpu_with_tpu_cores_priority(): assert trainer.num_devices == tpu_cores -@RunIf(tpu=True) -def test_set_devices_if_none_tpu(): +@RunIf(skip_windows=True) +def test_set_devices_if_none_tpu(tpu_available): with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."): trainer = Trainer(accelerator="tpu", tpu_cores=8) assert isinstance(trainer.accelerator, TPUAccelerator) @@ -124,6 +122,7 @@ @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_manual_optimization_tpus(tmpdir): class ManualOptimizationModel(BoringModel): @@ -197,25 +196,25 @@ def on_train_end(self): assert not torch.equal(param.cpu().data, param_copy.data) -@RunIf(tpu=True) -def test_strategy_choice_tpu_str_ddp_spawn(tmpdir): +def test_strategy_choice_tpu_str_ddp_spawn(tpu_available): with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"):
Trainer(strategy="ddp_spawn", accelerator="tpu", devices=8) -@RunIf(tpu=True) -def test_strategy_choice_tpu_str_tpu_spawn_debug(tmpdir): +@RunIf(skip_windows=True) +def test_strategy_choice_tpu_str_tpu_spawn_debug(tpu_available): trainer = Trainer(strategy="tpu_spawn_debug", accelerator="tpu", devices=8) assert isinstance(trainer.strategy, TPUSpawnStrategy) @RunIf(tpu=True) -def test_strategy_choice_tpu_strategy(tmpdir): +def test_strategy_choice_tpu_strategy(): trainer = Trainer(strategy=TPUSpawnStrategy(), accelerator="tpu", devices=8) assert isinstance(trainer.strategy, TPUSpawnStrategy) @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_auto_parameters_tying_tpus(tmpdir): model = WeightSharingModule() @@ -230,6 +229,7 @@ def test_auto_parameters_tying_tpus(tmpdir): @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_auto_parameters_tying_tpus_nested_module(tmpdir): class SubModule(nn.Module): def __init__(self, layer): @@ -261,8 +261,7 @@ def forward(self, x): assert torch.all(torch.eq(model.net_a.layer.weight, model.net_b.layer.weight)) -@RunIf(tpu=True) -def test_tpu_invalid_raises(): +def test_tpu_invalid_raises(tpu_available): strategy = TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin()) with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"): Trainer(strategy=strategy, devices=8) @@ -272,8 +271,7 @@ def test_tpu_invalid_raises(): Trainer(strategy=strategy, devices=8) -@RunIf(tpu=True) -def test_tpu_invalid_raises_set_precision_with_strategy(): +def test_tpu_invalid_raises_set_precision_with_strategy(tpu_available): accelerator = TPUAccelerator() strategy = TPUSpawnStrategy(accelerator=accelerator, precision_plugin=PrecisionPlugin()) with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"): @@ -287,13 +285,14 @@ def test_tpu_invalid_raises_set_precision_with_strategy(): Trainer(strategy=strategy, devices=8) -@RunIf(tpu=True) -def test_xla_checkpoint_plugin_being_default(): +@RunIf(skip_windows=True) +def test_xla_checkpoint_plugin_being_default(tpu_available): trainer = Trainer(accelerator="tpu", devices=8) assert isinstance(trainer.strategy.checkpoint_io, XLACheckpointIO) -@patch("pytorch_lightning.strategies.tpu_spawn.MpDeviceLoader") +@RunIf(tpu=True) +@patch("torch_xla.distributed.parallel_loader.MpDeviceLoader") @patch("pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.root_device") def test_mp_device_dataloader_attribute(root_device_mock, mp_loader_mock): dataset = RandomDataset(32, 64) @@ -303,8 +302,7 @@ def test_mp_device_dataloader_attribute(root_device_mock, mp_loader_mock): assert processed_dataloader.dataset == processed_dataloader._loader.dataset -@RunIf(tpu=True) -def test_warning_if_tpus_not_used(): +def test_warning_if_tpus_not_used(tpu_available): with pytest.warns(UserWarning, match="TPU available but not used. 
Set `accelerator` and `devices`"): Trainer() @@ -320,6 +318,7 @@ def test_warning_if_tpus_not_used(): ("2,", [2]), ], ) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_trainer_config_device_ids(devices, expected_device_ids): trainer = Trainer(accelerator="tpu", devices=devices) assert trainer.device_ids == expected_device_ids diff --git a/tests/tests_pytorch/benchmarks/test_basic_parity.py b/tests/tests_pytorch/benchmarks/test_basic_parity.py index 16a1dc0d1a2d7..1e817af34d892 100644 --- a/tests/tests_pytorch/benchmarks/test_basic_parity.py +++ b/tests/tests_pytorch/benchmarks/test_basic_parity.py @@ -57,7 +57,7 @@ def assert_parity_absolute(pl_values, pt_values, norm_by: float = 1, max_diff: f ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_pytorch_parity( - tmpdir, cls_model: LightningModule, max_diff_speed: float, max_diff_memory: float, num_epochs: int, num_runs: int + cls_model: LightningModule, max_diff_speed: float, max_diff_memory: float, num_epochs: int, num_runs: int ): """Verify that the same pytorch and lightning models achieve the same results.""" lightning = measure_loops(cls_model, kind="PT Lightning", num_epochs=num_epochs, num_runs=num_runs) diff --git a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py index 2a2bae8a2e5a4..36b30dc346d65 100644 --- a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py +++ b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Dict, Optional from unittest import mock from unittest.mock import Mock @@ -97,6 +98,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_device_stats_monitor_tpu(tmpdir): """Test TPU stats are logged using a logger.""" diff --git a/tests/tests_pytorch/callbacks/test_finetuning_callback.py b/tests/tests_pytorch/callbacks/test_finetuning_callback.py index 127bf93fe3221..cd9b1df221bf7 100644 --- a/tests/tests_pytorch/callbacks/test_finetuning_callback.py +++ b/tests/tests_pytorch/callbacks/test_finetuning_callback.py @@ -22,7 +22,11 @@ from pytorch_lightning import LightningModule, seed_everything, Trainer from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_12 +from pytorch_lightning.utilities.imports import ( + _TORCH_GREATER_EQUAL_1_11, + _TORCH_GREATER_EQUAL_1_12, + _TORCH_GREATER_EQUAL_1_13, +) class TestBackboneFinetuningCallback(BackboneFinetuning): @@ -370,6 +374,8 @@ def test_callbacks_restore(tmpdir): expected["maximize"] = False if _TORCH_GREATER_EQUAL_1_12: expected["foreach"] = None + if _TORCH_GREATER_EQUAL_1_13: + expected["differentiable"] = False assert callback._internal_optimizer_metadata[0][0] == expected @@ -386,6 +392,8 @@ def test_callbacks_restore(tmpdir): expected["maximize"] = False if _TORCH_GREATER_EQUAL_1_12: expected["foreach"] = None + if _TORCH_GREATER_EQUAL_1_13: + expected["differentiable"] = False assert callback._internal_optimizer_metadata[0][1] == expected diff --git 
a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py index 54f2c11983c85..9486b8fd0b586 100644 --- a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py @@ -78,7 +78,7 @@ def validation_step(self, batch, batch_idx): assert f"epoch={idx + 1}" in best_model_path -def test_trainer_save_checkpoint_storage_options(tmpdir): +def test_trainer_save_checkpoint_storage_options(tmpdir, xla_available): """This test validates that storage_options argument is properly passed to ``CheckpointIO``""" model = BoringModel() trainer = Trainer( diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index 725953d992e25..a8b54e737c21a 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -26,7 +26,7 @@ import pytorch_lightning from lightning_lite.plugins.environments.lightning_environment import find_free_network_port from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector -from pytorch_lightning.utilities.imports import _IS_WINDOWS +from pytorch_lightning.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_12 from tests_pytorch import _PATH_DATASETS @@ -72,17 +72,6 @@ def restore_env_variables(): "HOROVOD_FUSION_THRESHOLD", "RANK", # set by DeepSpeed "POPLAR_ENGINE_OPTIONS", # set by IPUStrategy - # set by XLA - "TF2_BEHAVIOR", - "XRT_MESH_SERVICE_ADDRESS", - "XRT_TORCH_DIST_ROOT", - "XRT_MULTI_PROCESSING_DEVICE", - "XRT_SHARD_WORLD_SIZE", - "XRT_LOCAL_WORKER", - "XRT_HOST_WORLD_SIZE", - "XRT_SHARD_ORDINAL", - "XRT_SHARD_LOCAL_ORDINAL", - "TF_CPP_MIN_LOG_LEVEL", } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" @@ -147,6 +136,9 @@ def cuda_count_4(monkeypatch): def mock_mps_count(monkeypatch, n: int) -> None: + if n > 0 and not _TORCH_GREATER_EQUAL_1_12: + # torch doesn't allow creation of mps devices on older versions + monkeypatch.setattr("torch.device", lambda *_: "mps") monkeypatch.setattr(lightning_lite.accelerators.mps, "_get_all_available_mps_gpus", lambda: list(range(n))) monkeypatch.setattr(lightning_lite.accelerators.mps.MPSAccelerator, "is_available", lambda *_: n > 0) @@ -161,6 +153,36 @@ def mps_count_1(monkeypatch): mock_mps_count(monkeypatch, 1) +@pytest.fixture(scope="function") +def mps_count_2(monkeypatch): + mock_mps_count(monkeypatch, 2) + + +@pytest.fixture(scope="function") +def mps_count_4(monkeypatch): + mock_mps_count(monkeypatch, 4) + + +@pytest.fixture(scope="function") +def xla_available(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(pytorch_lightning.accelerators.tpu, "_XLA_AVAILABLE", True) + monkeypatch.setattr(pytorch_lightning.strategies.tpu_spawn, "_XLA_AVAILABLE", True) + monkeypatch.setattr(pytorch_lightning.strategies.single_tpu, "_XLA_AVAILABLE", True) + monkeypatch.setattr(pytorch_lightning.plugins.precision.tpu, "_XLA_AVAILABLE", True) + monkeypatch.setattr(pytorch_lightning.strategies.launchers.xla, "_XLA_AVAILABLE", True) + monkeypatch.setattr(lightning_lite.accelerators.tpu, "_XLA_AVAILABLE", True) + monkeypatch.setattr(lightning_lite.plugins.environments.xla_environment, "_XLA_AVAILABLE", True) + monkeypatch.setattr(lightning_lite.plugins.io.xla_plugin, "_XLA_AVAILABLE", True) + monkeypatch.setattr(lightning_lite.strategies.xla, "_XLA_AVAILABLE", True) + monkeypatch.setattr(lightning_lite.strategies.launchers.xla, "_XLA_AVAILABLE", True) 
+ + +@pytest.fixture(scope="function") +def tpu_available(xla_available, monkeypatch) -> None: + monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "is_available", lambda: True) + monkeypatch.setattr(lightning_lite.accelerators.tpu.TPUAccelerator, "is_available", lambda: True) + + @pytest.fixture def caplog(caplog): """Workaround for https://github.com/pytest-dev/pytest/issues/3697. diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py index 71f2e07d23709..2c0757d1cb82d 100644 --- a/tests/tests_pytorch/core/test_lightning_module.py +++ b/tests/tests_pytorch/core/test_lightning_module.py @@ -24,7 +24,7 @@ from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11 +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13 from tests_pytorch.helpers.runif import RunIf @@ -333,7 +333,7 @@ def __init__(self, spec): m_0 = BoringModelWithShardedTensor(spec) m_0.sharded_tensor.local_shards()[0].tensor.fill_(1) - name_st = ".sharded_tensor" if _TORCH_GREATER_EQUAL_1_11 else "sharded_tensor" + name_st = ".sharded_tensor" if _TORCH_GREATER_EQUAL_1_11 and not _TORCH_GREATER_EQUAL_1_13 else "sharded_tensor" assert name_st in m_0.state_dict(), 'Expect "sharded_tensor" to appear in the state dict' m_1 = BoringModelWithShardedTensor(spec) diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py index 3ff5490695baa..4e85fac609ebe 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -132,12 +132,17 @@ def test_v1_10_deprecated_xla_device_utilities(): with pytest.deprecated_call(match="xla_device.XLADeviceUtils` has been deprecated in v1.8.0"): XLADeviceUtils() - with pytest.deprecated_call(match="xla_device.XLADeviceUtils` has been deprecated in v1.8.0"): + with pytest.deprecated_call(match="xla_device.XLADeviceUtils.xla_available` has been deprecated in v1.8.0"): XLADeviceUtils.xla_available() - with pytest.deprecated_call(match="xla_device.XLADeviceUtils` has been deprecated in v1.8.0"): + with pytest.deprecated_call(match="xla_device.XLADeviceUtils.tpu_device_exists` has been deprecated in v1.8.0"): XLADeviceUtils.tpu_device_exists() + from pytorch_lightning.utilities.distributed import tpu_distributed + + with pytest.deprecated_call(match="tpu_distributed` has been deprecated in v1.8.0"): + tpu_distributed() + def test_v1_10_deprecated_apply_func_utilities(): with pytest.deprecated_call(match="apply_func.apply_to_collection` has been deprecated in v1.8.0"): @@ -277,8 +282,7 @@ def test_lite_convert_deprecated_gpus_argument(cuda_count_2): @RunIf(skip_windows=True) -@mock.patch("lightning_lite.accelerators.TPUAccelerator.is_available", return_value=True) -def test_lite_convert_deprecated_tpus_argument(*_): +def test_lite_convert_deprecated_tpus_argument(tpu_available): with pytest.deprecated_call(match=escape("Setting `Lite(tpu_cores=8)` is deprecated in v1.8.0")): lite = EmptyLite(tpu_cores=8) assert isinstance(lite._accelerator, LiteTPUAccelerator) diff --git a/tests/tests_pytorch/deprecated_api/test_remove_2-0.py b/tests/tests_pytorch/deprecated_api/test_remove_2-0.py index 3110948cd8ddf..548c7feec41e1 100644 --- 
a/tests/tests_pytorch/deprecated_api/test_remove_2-0.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_2-0.py @@ -34,9 +34,7 @@ def test_v2_0_0_deprecated_gpus(cuda_count_4): @RunIf(skip_windows=True) -@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True) -@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8) -def test_v2_0_0_deprecated_tpu_cores(*_): +def test_v2_0_0_deprecated_tpu_cores(tpu_available): with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."): _ = Trainer(tpu_cores=8) diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index 1f369b6c759a4..98b1530500bc8 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -22,6 +22,7 @@ from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.accelerators.mps import MPSAccelerator +from pytorch_lightning.accelerators.tpu import TPUAccelerator from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE @@ -35,7 +36,6 @@ _PSUTIL_AVAILABLE, _TORCH_GREATER_EQUAL_1_10, _TORCH_QUANTIZE_AVAILABLE, - _TPU_AVAILABLE, ) _HOROVOD_NCCL_AVAILABLE = False @@ -172,7 +172,7 @@ def __new__( reasons.append("unimplemented on Windows") if tpu: - conditions.append(not _TPU_AVAILABLE) + conditions.append(not TPUAccelerator.is_available()) reasons.append("TPU") # used in conftest.py::pytest_collection_modifyitems kwargs["tpu"] = True diff --git a/tests/tests_pytorch/models/test_gpu.py b/tests/tests_pytorch/models/test_gpu.py index b240360db6f6a..7fc86cd9b3b6e 100644 --- a/tests/tests_pytorch/models/test_gpu.py +++ b/tests/tests_pytorch/models/test_gpu.py @@ -83,8 +83,7 @@ def test_single_gpu_model(tmpdir, devices): "-1", ], ) -@mock.patch("lightning_lite.accelerators.mps.MPSAccelerator.is_available", return_value=False) -def test_root_gpu_property_0_raising(_, devices): +def test_root_gpu_property_0_raising(mps_count_0, cuda_count_0, devices): """Test that asking for a GPU when none are available will result in a MisconfigurationException.""" with pytest.raises(MisconfigurationException, match="No supported gpu backend found!"): Trainer(accelerator="gpu", devices=devices, strategy="ddp") diff --git a/tests/tests_pytorch/models/test_tpu.py b/tests/tests_pytorch/models/test_tpu.py index 6c2f407687156..95274624d7bb4 100644 --- a/tests/tests_pytorch/models/test_tpu.py +++ b/tests/tests_pytorch/models/test_tpu.py @@ -29,13 +29,9 @@ from pytorch_lightning.strategies import TPUSpawnStrategy from pytorch_lightning.strategies.launchers.xla import _XLALauncher from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync -from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.runif import RunIf -if _TPU_AVAILABLE: - import torch_xla - class SerialLoaderBoringModel(BoringModel): def train_dataloader(self): @@ -46,6 +42,7 @@ def val_dataloader(self): @RunIf(tpu=True, standalone=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_model_tpu_devices_1(tmpdir): """Make sure model trains on TPU.""" tutils.reset_seed() @@ -65,6 +62,7 @@ def test_model_tpu_devices_1(tmpdir): @pytest.mark.parametrize("tpu_core", [1, 5]) @RunIf(tpu=True, standalone=True) 
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_model_tpu_index(tmpdir, tpu_core): """Make sure model trains on TPU.""" tutils.reset_seed() @@ -80,10 +78,13 @@ def test_model_tpu_index(tmpdir, tpu_core): model = BoringModel() tpipes.run_model_test(trainer_options, model, with_hpc=False) + import torch_xla + assert torch_xla._XLAC._xla_get_default_device() == f"xla:{tpu_core}" @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_model_tpu_devices_8(tmpdir): """Make sure model trains on TPU.""" tutils.reset_seed() @@ -103,6 +104,7 @@ def test_model_tpu_devices_8(tmpdir): @RunIf(tpu=True, standalone=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_model_16bit_tpu_devices_1(tmpdir): """Make sure model trains on TPU.""" tutils.reset_seed() @@ -123,6 +125,7 @@ def test_model_16bit_tpu_devices_1(tmpdir): @pytest.mark.parametrize("tpu_core", [1, 5]) @RunIf(tpu=True, standalone=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_model_16bit_tpu_index(tmpdir, tpu_core): """Make sure model trains on TPU.""" tutils.reset_seed() @@ -139,10 +142,13 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core): model = BoringModel() tpipes.run_model_test(trainer_options, model) + import torch_xla + assert torch_xla._XLAC._xla_get_default_device() == f"xla:{tpu_core}" @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_model_16bit_tpu_devices_8(tmpdir): """Make sure model trains on TPU.""" tutils.reset_seed() @@ -163,6 +169,7 @@ def test_model_16bit_tpu_devices_8(tmpdir): @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_model_tpu_early_stop(tmpdir): """Test if single TPU core training works.""" @@ -189,6 +196,7 @@ def validation_step(self, *args, **kwargs): @RunIf(tpu=True, standalone=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_tpu_grad_norm(tmpdir): """Test if grad_norm works on TPU.""" tutils.reset_seed() @@ -208,6 +216,7 @@ def test_tpu_grad_norm(tmpdir): @RunIf(tpu=True, standalone=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_tpu_clip_grad_by_value(tmpdir): """Test if clip_gradients by value works on TPU.""" tutils.reset_seed() @@ -228,6 +237,7 @@ def test_tpu_clip_grad_by_value(tmpdir): @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_dataloaders_passed_to_fit(tmpdir): """Test if dataloaders passed to trainer works on TPU.""" tutils.reset_seed() @@ -237,23 +247,22 @@ def test_dataloaders_passed_to_fit(tmpdir): trainer.fit(model, train_dataloaders=model.train_dataloader(), val_dataloaders=model.val_dataloader()) -@RunIf(tpu=True) @pytest.mark.parametrize("tpu_cores", [[1, 8], "9, ", [9], [0], 2, 10]) -def test_tpu_misconfiguration(tpu_cores): +def test_tpu_misconfiguration(tpu_cores, tpu_available): with pytest.raises(TypeError, match="`tpu_cores` can only be"): Trainer(accelerator="tpu", devices=tpu_cores) -@pytest.mark.skipif(_TPU_AVAILABLE, reason="test requires missing TPU") -def test_exception_when_no_tpu_found(): +@pytest.mark.skipif(TPUAccelerator.is_available(), reason="test requires missing TPU") +def test_exception_when_no_tpu_found(xla_available): """Test if exception is thrown when xla devices are not available.""" - - with pytest.raises(MisconfigurationException, match="TPUAccelerator can not run on your system"): + with pytest.raises(MisconfigurationException, match="TPUAccelerator` can not run on your system"): 
Trainer(accelerator="tpu", devices=8) @pytest.mark.parametrize("tpu_cores", [1, 8, [1]]) @RunIf(tpu=True, standalone=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_accelerator_set_when_using_tpu(tpu_cores): """Test if the accelerator is set to `tpu` when tpu_cores is not None.""" assert isinstance(Trainer(accelerator="tpu", devices=tpu_cores).accelerator, TPUAccelerator) @@ -264,6 +273,7 @@ def test_accelerator_set_when_using_tpu(tpu_cores): [("--tpu_cores=8", {"tpu_cores": 8}), ("--tpu_cores=1,", {"tpu_cores": "1,"})], ) @RunIf(tpu=True, standalone=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_tpu_cores_with_argparse(cli_args, expected): """Test passing tpu_cores in command line.""" cli_args = cli_args.split(" ") if cli_args else [] @@ -278,28 +288,24 @@ def test_tpu_cores_with_argparse(cli_args, expected): assert Trainer.from_argparse_args(args) -@RunIf(tpu=True, standalone=True) -@pytest.mark.parametrize("clip_val", [10]) +@RunIf(min_torch="1.10") +@pytest.mark.parametrize("clip_val", [0, 10]) @mock.patch("torch.nn.utils.clip_grad_norm_") -def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir): - """Ensure that clip gradients is only called if the value is greater than 0. - - TODO: Fix (test fails with parametrize) - """ - tutils.reset_seed() - trainer_options = dict( +def test_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir): + """Ensure that clip gradients is only called if the value is greater than 0.""" + # TODO: shouldn't be in the TPU file + model = BoringModel() + trainer = Trainer( default_root_dir=tmpdir, enable_progress_bar=False, max_epochs=1, - accelerator="tpu", devices=1, precision=16, limit_train_batches=4, - limit_val_batches=4, + limit_val_batches=0, gradient_clip_val=clip_val, ) - model = BoringModel() - tpipes.run_model_test(trainer_options, model, with_hpc=False) + trainer.fit(model) if clip_val > 0: mock_clip_grad_norm.assert_called() @@ -308,6 +314,7 @@ def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir): @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_if_test_works_with_checkpoint_false(tmpdir): """Ensure that model trains properly when `enable_checkpointing` is set to False.""" @@ -349,12 +356,14 @@ def tpu_sync_dist_fn(strategy): @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_tpu_sync_dist(): """Test tpu spawn sync dist operation.""" xla_launch(tpu_sync_dist_fn) @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_tpu_debug_mode(tmpdir): """Test if debug mode works on TPU.""" @@ -382,6 +391,7 @@ def teardown(self, stage): @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_tpu_host_world_size(tmpdir): """Test Host World size env setup on TPU.""" diff --git a/tests/tests_pytorch/plugins/precision/test_tpu_bf16_plugin.py b/tests/tests_pytorch/plugins/precision/test_tpu_bf16_plugin.py index abf02548fde7d..fb6296ed5eda8 100644 --- a/tests/tests_pytorch/plugins/precision/test_tpu_bf16_plugin.py +++ b/tests/tests_pytorch/plugins/precision/test_tpu_bf16_plugin.py @@ -17,7 +17,7 @@ from pytorch_lightning.plugins import TPUBf16PrecisionPlugin -def test_teardown(): +def test_teardown(xla_available): plugin = TPUBf16PrecisionPlugin() plugin.connect(Mock(), Mock(), Mock()) assert os.environ.get("XLA_USE_BF16") == "1" diff --git a/tests/tests_pytorch/plugins/test_cluster_integration.py 
b/tests/tests_pytorch/plugins/test_cluster_integration.py index f8005f2d8a80e..24646c117d3c2 100644 --- a/tests/tests_pytorch/plugins/test_cluster_integration.py +++ b/tests/tests_pytorch/plugins/test_cluster_integration.py @@ -85,8 +85,7 @@ def test_ranks_available_manual_strategy_selection(mock_gpu_acc_available, strat dict(strategy="ddp_spawn", accelerator="gpu", devices=[1, 2]), ], ) -@mock.patch("lightning_lite.utilities.device_parser._get_all_available_mps_gpus", return_value=list(range(4))) -def test_ranks_available_automatic_strategy_selection(_, cuda_count_4, trainer_kwargs): +def test_ranks_available_automatic_strategy_selection(mps_count_4, cuda_count_4, trainer_kwargs): """Test that the rank information is readily available after Trainer initialization.""" num_nodes = 2 trainer_kwargs.update(num_nodes=num_nodes) diff --git a/tests/tests_pytorch/profilers/test_xla_profiler.py b/tests/tests_pytorch/profilers/test_xla_profiler.py index 694d978905177..2f18141ab025f 100644 --- a/tests/tests_pytorch/profilers/test_xla_profiler.py +++ b/tests/tests_pytorch/profilers/test_xla_profiler.py @@ -13,23 +13,19 @@ # limitations under the License. import os from multiprocessing import Event, Process +from unittest import mock import pytest from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.profilers import XLAProfiler -from pytorch_lightning.utilities import _TPU_AVAILABLE from tests_pytorch.helpers.runif import RunIf -if _TPU_AVAILABLE: - import torch_xla.debug.profiler as xp - import torch_xla.utils.utils as xu - @RunIf(tpu=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_xla_profiler_instance(tmpdir): - model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler="xla", accelerator="tpu", devices=8) @@ -37,8 +33,10 @@ def test_xla_profiler_instance(tmpdir): trainer.fit(model) -@pytest.mark.skipif(True, reason="XLA Profiler doesn't support Prog. capture yet") +@pytest.mark.skip(reason="XLA Profiler doesn't support Prog. 
capture yet") def test_xla_profiler_prog_capture(tmpdir): + import torch_xla.debug.profiler as xp + import torch_xla.utils.utils as xu port = xu.get_free_tcp_ports()[0] training_started = Event() diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py index ad3e891ad607f..ef1a5ccce1547 100644 --- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py @@ -18,6 +18,7 @@ import torch from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher +from tests_pytorch.helpers.runif import RunIf @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[]) @@ -26,7 +27,7 @@ def test_multiprocessing_launcher_forking_on_unsupported_platform(_): _MultiProcessingLauncher(strategy=Mock(), start_method="fork") -@pytest.mark.parametrize("start_method", ["spawn", "fork"]) +@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))]) @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp") def test_multiprocessing_launcher_start_method(mp_mock, start_method): mp_mock.get_all_start_methods.return_value = [start_method] @@ -41,7 +42,7 @@ def test_multiprocessing_launcher_start_method(mp_mock, start_method): ) -@pytest.mark.parametrize("start_method", ["spawn", "fork"]) +@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))]) @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp") def test_multiprocessing_launcher_restore_globals(mp_mock, start_method): """Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'.""" diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py index 7fcb18791ba9d..561a00c1931a9 100644 --- a/tests/tests_pytorch/strategies/test_ddp.py +++ b/tests/tests_pytorch/strategies/test_ddp.py @@ -57,8 +57,7 @@ def test_multi_gpu_model_ddp_fit_test(tmpdir): @RunIf(skip_windows=True) -@mock.patch("lightning_lite.utilities.device_parser._get_all_available_mps_gpus", return_value=list(range(2))) -def test_torch_distributed_backend_invalid(_, cuda_count_2, tmpdir): +def test_torch_distributed_backend_invalid(cuda_count_2, tmpdir): """This test set `undefined` as torch backend and should raise an `Backend.UNDEFINED` ValueError.""" model = BoringModel() trainer = Trainer( diff --git a/tests/tests_pytorch/strategies/test_registry.py b/tests/tests_pytorch/strategies/test_registry.py index dcb182b657c49..8536e0b8b3438 100644 --- a/tests/tests_pytorch/strategies/test_registry.py +++ b/tests/tests_pytorch/strategies/test_registry.py @@ -56,8 +56,7 @@ def test_deepspeed_strategy_registry_with_trainer(tmpdir, strategy): @RunIf(skip_windows=True) -def test_tpu_spawn_debug_strategy_registry(): - +def test_tpu_spawn_debug_strategy_registry(xla_available): strategy = "tpu_spawn_debug" assert strategy in StrategyRegistry @@ -65,7 +64,6 @@ def test_tpu_spawn_debug_strategy_registry(): assert StrategyRegistry[strategy]["strategy"] == TPUSpawnStrategy trainer = Trainer(strategy=strategy) - assert isinstance(trainer.strategy, TPUSpawnStrategy) diff --git a/tests/tests_pytorch/strategies/test_tpu_spawn.py b/tests/tests_pytorch/strategies/test_tpu_spawn.py index 967e44a42c9de..fca76a3be65ac 100644 --- 
a/tests/tests_pytorch/strategies/test_tpu_spawn.py +++ b/tests/tests_pytorch/strategies/test_tpu_spawn.py @@ -55,9 +55,8 @@ def predict_dataloader(self): (None, [_loader, _loader_no_len], None, None), ], ) -@mock.patch("pytorch_lightning.strategies.tpu_spawn.xm") def test_error_iterable_dataloaders_passed_to_fit( - _, tmpdir, train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders + xla_available, train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders ): """Test that the TPUSpawnStrategy identifies dataloaders with iterable datasets and fails early.""" trainer = Trainer() @@ -76,8 +75,7 @@ def test_error_iterable_dataloaders_passed_to_fit( TPUSpawnStrategy(MagicMock()).connect(model) -@mock.patch("pytorch_lightning.strategies.tpu_spawn.xm") -def test_error_process_iterable_dataloader(_): +def test_error_process_iterable_dataloader(xla_available): with pytest.raises(MisconfigurationException, match="TPUs do not currently support"): TPUSpawnStrategy(MagicMock()).process_dataloader(_loader_no_len) @@ -90,6 +88,7 @@ def on_train_start(self) -> None: @RunIf(tpu=True, standalone=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_model_tpu_one_core(): """Tests if device/debug flag is set correctly when training and after teardown for TPUSpawnStrategy.""" model = BoringModelTPU() diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 0496ed8e2b465..a77951ddbd1a5 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -210,8 +210,7 @@ def test_dist_backend_accelerator_mapping(cuda_count_0): assert trainer.strategy.local_rank == 0 -@mock.patch("lightning_lite.utilities.device_parser._get_all_available_mps_gpus", return_value=[0, 1]) -def test_ipython_incompatible_backend_error(_, cuda_count_2, monkeypatch): +def test_ipython_incompatible_backend_error(mps_count_2, cuda_count_2, monkeypatch): monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"): Trainer(strategy="ddp", accelerator="gpu", devices=2) @@ -234,8 +233,7 @@ def test_ipython_compatible_dp_strategy_gpu(cuda_count_2, monkeypatch): @RunIf(skip_windows=True) -@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True) -def test_ipython_compatible_strategy_tpu(_, monkeypatch): +def test_ipython_compatible_strategy_tpu(tpu_available, monkeypatch): monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True) trainer = Trainer(accelerator="tpu") assert trainer.strategy.launcher.is_interactive_compatible @@ -271,14 +269,14 @@ def test_accelerator_cpu(cuda_count_0): with pytest.raises( MisconfigurationException, - match="CUDAAccelerator can not run on your system since the accelerator is not available.", + match="CUDAAccelerator` can not run on your system since the accelerator is not available.", ): with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed"): Trainer(gpus=1) with pytest.raises( MisconfigurationException, - match="CUDAAccelerator can not run on your system since the accelerator is not available.", + match="CUDAAccelerator` can not run on your system since the accelerator is not available.", ): Trainer(accelerator="cuda") @@ -456,7 +454,6 @@ def test_strategy_choice_ddp_cuda(strategy, 
expected_cls, mps_count_0, cuda_coun assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment) -@RunIf(mps=True) @pytest.mark.parametrize("strategy,expected_cls", [("ddp", DDPStrategy), ("ddp_spawn", DDPSpawnStrategy)]) def test_strategy_choice_ddp_mps(strategy, expected_cls, mps_count_1, cuda_count_0): trainer = Trainer(fast_dev_run=True, strategy=strategy, accelerator="gpu", devices=1) @@ -605,20 +602,20 @@ def test_check_native_fsdp_strategy_and_fallback(): Trainer(accelerator="cpu", strategy="fsdp_native") -@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True) -def test_unsupported_tpu_choice(mock_tpu_acc_avail): - +def test_unsupported_tpu_choice(tpu_available): with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"): Trainer(accelerator="tpu", precision=64) # if user didn't set strategy, AcceleratorConnector will choose the TPUSingleStrategy or TPUSpawnStrategy - with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"): - with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but native AMP is not supported"): - Trainer(accelerator="tpu", precision=16, strategy="ddp") + with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"), pytest.warns( + UserWarning, match=r"accelerator='tpu', precision=16\)` but native AMP is not supported" + ): + Trainer(accelerator="tpu", precision=16, strategy="ddp") - with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"): - with pytest.warns(UserWarning, match=r"accelerator='tpu', precision=16\)` but apex AMP is not supported"): - Trainer(accelerator="tpu", precision=16, amp_backend="apex", strategy="single_device") + with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"), pytest.warns( + UserWarning, match=r"accelerator='tpu', precision=16\)` but apex AMP is not supported" + ): + Trainer(accelerator="tpu", precision=16, amp_backend="apex", strategy="single_device") @mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True) @@ -634,7 +631,7 @@ def test_unsupported_ipu_choice(mock_ipu_acc_avail, monkeypatch): Trainer(accelerator="ipu", precision=64) -@mock.patch("pytorch_lightning.utilities.imports._TPU_AVAILABLE", return_value=False) +@mock.patch("pytorch_lightning.accelerators.tpu._XLA_AVAILABLE", return_value=False) @mock.patch("pytorch_lightning.utilities.imports._IPU_AVAILABLE", return_value=False) @mock.patch("pytorch_lightning.utilities.imports._HPU_AVAILABLE", return_value=False) def test_devices_auto_choice_cpu(cuda_count_0, *_): @@ -760,12 +757,8 @@ def test_gpu_accelerator_backend_choice_cuda(cuda_count_1): assert isinstance(trainer.accelerator, CUDAAccelerator) -@mock.patch("lightning_lite.accelerators.mps.MPSAccelerator.is_available", return_value=True) -@mock.patch("lightning_lite.accelerators.mps._get_all_available_mps_gpus", return_value=[0]) -@mock.patch("torch.device", return_value="mps") # necessary because torch doesn't allow creation of mps devices -def test_gpu_accelerator_backend_choice_mps(*_): +def test_gpu_accelerator_backend_choice_mps(mps_count_1): trainer = Trainer(accelerator="gpu") - assert trainer._accelerator_connector._accelerator_flag == "mps" assert isinstance(trainer.accelerator, MPSAccelerator) diff --git 
a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py index 2dec57277e2a4..0095885a06aa9 100644 --- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py +++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +import os from unittest import mock from unittest.mock import PropertyMock @@ -141,6 +142,7 @@ def test_num_stepping_batches_gpu(trainer_kwargs, estimated_steps, monkeypatch): @RunIf(tpu=True, standalone=True) +@mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_num_stepping_batches_with_tpu_single(): """Test stepping batches with the single-core TPU strategy.""" trainer = Trainer(accelerator="tpu", devices=1, max_epochs=1) diff --git a/tests/tests_pytorch/trainer/test_config_validator.py b/tests/tests_pytorch/trainer/test_config_validator.py index a954f90402e84..40d4356bcbba4 100644 --- a/tests/tests_pytorch/trainer/test_config_validator.py +++ b/tests/tests_pytorch/trainer/test_config_validator.py @@ -144,7 +144,7 @@ def test_raise_exception_with_batch_transfer_hooks(monkeypatch, hook, trainer_kw mock_cuda_count(monkeypatch, 2) elif trainer_kwargs.get("accelerator") == "ipu": match_pattern = rf"Overriding `{hook}` is not .* with IPUs" - monkeypatch.setattr(pl.accelerators.ipu.IPUAccelerator, "is_available", lambda _: True) + monkeypatch.setattr(pl.accelerators.ipu.IPUAccelerator, "is_available", lambda: True) monkeypatch.setattr(pl.strategies.ipu, "_IPU_AVAILABLE", lambda: True) def custom_method(self, batch, *_, **__): diff --git a/tests/tests_pytorch/trainer/test_supporters.py b/tests/tests_pytorch/trainer/test_supporters.py index fa043bb126338..15958500c2dec 100644 --- a/tests/tests_pytorch/trainer/test_supporters.py +++ b/tests/tests_pytorch/trainer/test_supporters.py @@ -314,10 +314,9 @@ def test_nested_calc_num_data(input_data, compute_func, expected_length): @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) -@mock.patch("lightning_lite.utilities.device_parser._get_all_available_mps_gpus", return_value=[0, 1]) @pytest.mark.parametrize("use_fault_tolerant", [False, True]) @pytest.mark.parametrize("replace_sampler_ddp", [False, True]) -def test_combined_data_loader_validation_test(_, cuda_count_2, use_fault_tolerant, replace_sampler_ddp): +def test_combined_data_loader_validation_test(mps_count_2, cuda_count_2, use_fault_tolerant, replace_sampler_ddp): """This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader.""" diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 4738f71fe2a97..6e2841547ac3c 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -32,13 +32,12 @@ from torch.optim import SGD from torch.utils.data import DataLoader, IterableDataset -import lightning_lite import pytorch_lightning import tests_pytorch.helpers.utils as tutils from lightning_lite.utilities.cloud_io import load as pl_load from lightning_lite.utilities.seed import seed_everything from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer -from pytorch_lightning.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator +from pytorch_lightning.accelerators import CPUAccelerator, CUDAAccelerator from pytorch_lightning.callbacks import EarlyStopping, 
GradientAccumulationScheduler, ModelCheckpoint, Timer from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter @@ -63,7 +62,7 @@ from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_1_12 -from tests_pytorch.conftest import mock_cuda_count +from tests_pytorch.conftest import mock_cuda_count, mock_mps_count from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel @@ -2215,10 +2214,9 @@ def test_trainer_config_device_ids(monkeypatch, trainer_kwargs, expected_device_ if trainer_kwargs.get("accelerator") in ("cuda", "gpu"): mock_cuda_count(monkeypatch, 4) elif trainer_kwargs.get("accelerator") in ("mps", "gpu"): - monkeypatch.setattr(lightning_lite.utilities.device_parser, "_get_all_available_mps_gpus", lambda: [0]) - monkeypatch.setattr(MPSAccelerator, "is_available", lambda *_: True) + mock_mps_count(monkeypatch, 1) elif trainer_kwargs.get("accelerator") == "ipu": - monkeypatch.setattr(pytorch_lightning.accelerators.ipu.IPUAccelerator, "is_available", lambda _: True) + monkeypatch.setattr(pytorch_lightning.accelerators.ipu.IPUAccelerator, "is_available", lambda: True) monkeypatch.setattr(pytorch_lightning.strategies.ipu, "_IPU_AVAILABLE", lambda: True) trainer = Trainer(**trainer_kwargs)
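
For reviewers who want to exercise the new availability check without TPU hardware, below is a minimal sketch (not part of the diff) of how the `tpu_available` fixture added to `tests/tests_pytorch/conftest.py` in this PR can be used; the test name is hypothetical, and it assumes the fixture behaves as shown in the conftest hunk (patching `TPUAccelerator.is_available` to return `True` and `_XLA_AVAILABLE` to `True` via the `xla_available` fixture it depends on).

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import TPUAccelerator

def test_tpu_flag_without_hardware(tpu_available):  # hypothetical test, for illustration only
    # `tpu_available` monkeypatches TPUAccelerator.is_available and the
    # module-level `_XLA_AVAILABLE` flags, so the accelerator connector can
    # resolve `accelerator="tpu"` without torch_xla or a real TPU present.
    trainer = Trainer(accelerator="tpu", devices=8)
    assert isinstance(trainer.accelerator, TPUAccelerator)
    assert trainer.num_devices == 8

This mirrors the pattern the PR applies throughout the test suite: tests that only need the TPU code paths (connector choices, error messages, deprecations) take the `tpu_available` or `xla_available` fixture and drop `@RunIf(tpu=True)`, while tests that actually spawn XLA processes keep `@RunIf(tpu=True)` plus the `mock.patch.dict(os.environ, ...)` guard.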