Commit

Refactor XLA and TPU checks across codebase (#14550)
carmocca authored Oct 4, 2022
1 parent acaeab2 commit 7ef8746
Showing 60 changed files with 457 additions and 406 deletions.
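Nearly every hunk below applies the same pattern: the eagerly computed, module-level `_TPU_AVAILABLE` flag is replaced by a lazy `RequirementCache("torch_xla")` check plus a cached runtime probe behind `TPUAccelerator.is_available()`, and `torch_xla` imports move into the methods that need them. A condensed sketch of that pattern (simplified from the hunks that follow, not the verbatim library code):

import functools

from lightning_utilities.core.imports import RequirementCache

# Truthy only when the `torch_xla` package is importable; str() explains why not.
_XLA_AVAILABLE = RequirementCache("torch_xla")


class TPUAccelerator:
    def __init__(self) -> None:
        # Fail at construction time instead of importing torch_xla eagerly at module import.
        if not _XLA_AVAILABLE:
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))

    @staticmethod
    @functools.lru_cache(maxsize=1)
    def is_available() -> bool:
        # Cheap package check first, then the (cached) runtime device probe.
        return bool(_XLA_AVAILABLE) and _probe_tpu_devices()


def _probe_tpu_devices() -> bool:  # simplified stand-in for `_is_device_tpu` below
    import torch_xla.core.xla_model as xm  # deferred: only runs when XLA is installed

    return bool(xm.get_xla_supported_devices("TPU"))
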
5 changes: 5 additions & 0 deletions dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -39,6 +39,11 @@ local tputests = base.BaseTest {
echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS
export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}"
echo "--- Sanity check TPU availability ---"
python -c "from lightning_lite.accelerators import TPUAccelerator; assert TPUAccelerator.is_available()"
python -c "from pytorch_lightning.accelerators import TPUAccelerator; assert TPUAccelerator.is_available()"
echo "Sanity check passed!"
echo "--- Running Lite tests ---"
cd tests/tests_lite
PL_RUN_TPU_TESTS=1 coverage run --source=lightning_lite -m pytest -vv --durations=0 ./
2 changes: 0 additions & 2 deletions docs/source-lit/conf.py
@@ -406,8 +406,6 @@ def find_source():
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.utilities import (
_APEX_AVAILABLE,
_XLA_AVAILABLE,
_TPU_AVAILABLE,
_TORCHVISION_AVAILABLE,
_TORCH_GREATER_EQUAL_1_10,
)
2 changes: 1 addition & 1 deletion docs/source-pytorch/accelerators/mps_basic.rst
@@ -57,7 +57,7 @@ If Lightning can't detect the Apple Silicon hardware, it will raise this exception

.. code::
MisconfigurationException: MPSAccelerator can not run on your system since the accelerator is not available.
MisconfigurationException: `MPSAccelerator` can not run on your system since the accelerator is not available.
If you are seeing this despite running on an ARM-enabled Mac, the most likely cause is that your Python is being emulated and thinks it is running on an Intel CPU.
To solve this, re-install your python executable (and if using environment managers like conda, you have to reinstall these as well) by downloading the Apple M1/M2 build (not Intel!), for example `here <https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links>`_.
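A quick way to confirm the emulation problem described above (a supplementary check, not taken from the documentation page itself): a natively installed interpreter on Apple Silicon reports the machine type as arm64, while an emulated one reports x86_64.

import platform

# A native Apple Silicon build prints "arm64"; "x86_64" on an M1/M2 Mac means the
# interpreter is being emulated and the MPS accelerator will not be detected.
print(platform.machine())
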
2 changes: 0 additions & 2 deletions docs/source-pytorch/conf.py
@@ -394,8 +394,6 @@ def package_list_from_file(file):
from pytorch_lightning.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
from pytorch_lightning.utilities import (
_APEX_AVAILABLE,
_XLA_AVAILABLE,
_TPU_AVAILABLE,
_TORCHVISION_AVAILABLE,
_TORCH_GREATER_EQUAL_1_10,
)
75 changes: 72 additions & 3 deletions src/lightning_lite/accelerators/tpu.py
@@ -11,18 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Union
import functools
import queue as q
import traceback
from multiprocessing import Process, Queue
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from lightning_utilities.core.imports import RequirementCache

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.utilities.device_parser import _check_data_type
from lightning_lite.utilities.imports import _TPU_AVAILABLE


class TPUAccelerator(Accelerator):
"""Accelerator for TPU devices."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(*args, **kwargs)

def setup_device(self, device: torch.device) -> None:
pass

@@ -47,8 +56,10 @@ def auto_device_count() -> int:
return 8

@staticmethod
@functools.lru_cache(maxsize=1)
def is_available() -> bool:
return _TPU_AVAILABLE
# check `_XLA_AVAILABLE` again to avoid launching processes
return bool(_XLA_AVAILABLE) and _is_device_tpu()

@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
@@ -59,6 +70,64 @@ def register_accelerators(cls, accelerator_registry: Dict) -> None:
)


# define TPU availability timeout in seconds
TPU_CHECK_TIMEOUT = 60


def _inner_f(queue: Queue, func: Callable, *args: Any, **kwargs: Any) -> None: # pragma: no cover
try:
queue.put(func(*args, **kwargs))
except Exception:
traceback.print_exc()
queue.put(None)


def _multi_process(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Union[bool, Any]:
queue: Queue = Queue()
proc = Process(target=_inner_f, args=(queue, func, *args), kwargs=kwargs)
proc.start()
proc.join(TPU_CHECK_TIMEOUT)
try:
return queue.get_nowait()
except q.Empty:
traceback.print_exc()
return False

return wrapper


@_multi_process
def _is_device_tpu() -> bool:
"""Check if TPU devices are available. Runs XLA device check within a separate process.
Return:
A boolean value indicating if TPU devices are available
"""
if not _XLA_AVAILABLE:
return False
import torch_xla.core.xla_model as xm

# For the TPU Pod training process, for example, if we have
# TPU v3-32 with 4 VMs, the world size would be 4 and as
# we would have to use `torch_xla.distributed.xla_dist` for
# multiple VMs and TPU_CONFIG won't be available, running
# `xm.get_xla_supported_devices("TPU")` won't be possible.
return (xm.xrt_world_size() > 1) or bool(xm.get_xla_supported_devices("TPU"))


_XLA_AVAILABLE = RequirementCache("torch_xla")


def tpu_distributed() -> bool:
if not TPUAccelerator.is_available():
return False
import torch_xla.core.xla_model as xm

return xm.xrt_world_size() > 1


def parse_tpu_cores(tpu_cores: Optional[Union[int, str, List[int]]]) -> Optional[Union[int, List[int]]]:
"""
Parses the tpu_cores given in the format as accepted by the
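A usage sketch for the reworked availability API above; it assumes only that `lightning_lite` is installed and runs fine without a TPU:

from lightning_lite.accelerators import TPUAccelerator

# The first call may spawn a short-lived probe process, bounded by TPU_CHECK_TIMEOUT
# (60 s); functools.lru_cache memoizes the result, so later calls return immediately.
if TPUAccelerator.is_available():
    print("TPU devices detected")
else:
    print("torch_xla is missing or no TPU device is attached")
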
17 changes: 10 additions & 7 deletions src/lightning_lite/connector.py
@@ -55,7 +55,7 @@
from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES
from lightning_lite.utilities import _StrategyType, rank_zero_info, rank_zero_warn
from lightning_lite.utilities.device_parser import determine_root_gpu_device
from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE, _TPU_AVAILABLE
from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE

_PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
_PLUGIN_INPUT = Union[_PLUGIN, str]
@@ -301,7 +301,7 @@ def _check_device_config_and_set_final_flags(
def _choose_auto_accelerator(self) -> str:
"""Choose the accelerator type (str) based on availability when ``accelerator='auto'``."""
if self._accelerator_flag == "auto":
if _TPU_AVAILABLE:
if TPUAccelerator.is_available():
return "tpu"
if _IPU_AVAILABLE:
return "ipu"
@@ -328,23 +328,26 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
else:
assert self._accelerator_flag is not None
self.accelerator = ACCELERATOR_REGISTRY.get(self._accelerator_flag)
accelerator_cls = self.accelerator.__class__

if not self.accelerator.is_available():
if not accelerator_cls.is_available():
available_accelerator = [
acc_str for acc_str in self._registered_accelerators if ACCELERATOR_REGISTRY.get(acc_str).is_available()
acc_str
for acc_str in self._registered_accelerators
if ACCELERATOR_REGISTRY[acc_str]["accelerator"].is_available()
]
raise RuntimeError(
f"{self.accelerator.__class__.__qualname__} can not run on your system"
f"`{accelerator_cls.__qualname__}` can not run on your system"
" since the accelerator is not available. The following accelerator(s)"
" is available and can be passed into `accelerator` argument of"
f" `Lite`: {available_accelerator}."
)

self._set_devices_flag_if_auto_passed()

self._devices_flag = self.accelerator.parse_devices(self._devices_flag)
self._devices_flag = accelerator_cls.parse_devices(self._devices_flag)
if not self._parallel_devices:
self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag)
self._parallel_devices = accelerator_cls.get_parallel_devices(self._devices_flag)

def _set_devices_flag_if_auto_passed(self) -> None:
if self._devices_flag == "auto" or self._devices_flag is None:
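The connector above now queries the accelerator class rather than an instance, because `TPUAccelerator.__init__` raises when `torch_xla` is missing while the class-level checks are always safe to call. A minimal sketch of that distinction (assumes only that `lightning_lite` is installed):

from lightning_lite.accelerators import TPUAccelerator

accelerator_cls = TPUAccelerator
print(accelerator_cls.is_available())       # cached subprocess probe, see tpu.py above
print(accelerator_cls.auto_device_count())  # 8
# Instantiate only after the class-level availability check has passed.
if accelerator_cls.is_available():
    accelerator = accelerator_cls()
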
26 changes: 20 additions & 6 deletions src/lightning_lite/plugins/environments/xla_environment.py
@@ -13,13 +13,10 @@
# limitations under the License.
import logging
import os
from typing import Any

from lightning_lite.accelerators.tpu import _XLA_AVAILABLE, TPUAccelerator
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.utilities.imports import _TPU_AVAILABLE

if _TPU_AVAILABLE:
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm

log = logging.getLogger(__name__)

@@ -31,36 +28,53 @@ class XLAEnvironment(ClusterEnvironment):
`here <https://github.com/pytorch/xla/blob/master/torch_xla/core/xla_env_vars.py>`_.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(*args, **kwargs)

@property
def creates_processes_externally(self) -> bool:
return False

@property
def main_address(self) -> str:
import torch_xla.core.xla_env_vars as xenv

return os.environ[xenv.TPU_MESH_CTLER_ADDR]

@property
def main_port(self) -> int:
import torch_xla.core.xla_env_vars as xenv

return int(os.environ[xenv.TPU_MESH_CTLER_PORT])

@staticmethod
def detect() -> bool:
return _TPU_AVAILABLE
return TPUAccelerator.is_available()

def world_size(self) -> int:
import torch_xla.core.xla_model as xm

return xm.xrt_world_size()

def set_world_size(self, size: int) -> None:
log.debug("XLAEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

def global_rank(self) -> int:
import torch_xla.core.xla_model as xm

return xm.get_ordinal()

def set_global_rank(self, rank: int) -> None:
log.debug("XLAEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

def local_rank(self) -> int:
import torch_xla.core.xla_model as xm

return xm.get_local_ordinal()

def node_rank(self) -> int:
import torch_xla.core.xla_env_vars as xenv

return int(os.environ.get(xenv.HOST_ORDINAL, 0))
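A hedged usage sketch for the environment above; the rank and world-size queries assume the process is already running under the XLA/XRT runtime, and on any other machine `detect()` simply returns False:

from lightning_lite.plugins.environments.xla_environment import XLAEnvironment

if XLAEnvironment.detect():  # delegates to TPUAccelerator.is_available()
    env = XLAEnvironment()   # raises ModuleNotFoundError if torch_xla is absent
    print(env.world_size(), env.global_rank(), env.local_rank(), env.node_rank())
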
13 changes: 9 additions & 4 deletions src/lightning_lite/plugins/io/xla_plugin.py
@@ -16,21 +16,24 @@

from lightning_utilities.core.apply_func import apply_to_collection

from lightning_lite.accelerators.tpu import _XLA_AVAILABLE
from lightning_lite.plugins.io.torch_plugin import TorchCheckpointIO
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.imports import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
from lightning_lite.utilities.imports import _OMEGACONF_AVAILABLE
from lightning_lite.utilities.types import _PATH

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm

if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf


class XLACheckpointIO(TorchCheckpointIO):
"""CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(*args, **kwargs)

def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
@@ -55,4 +58,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options
# Ref: https://github.com/pytorch/xla/issues/2773
if _OMEGACONF_AVAILABLE:
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
import torch_xla.core.xla_model as xm

xm.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, path)
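A usage sketch for the checkpoint plugin above; it needs `torch_xla` at construction time and an XLA device for the actual save, and the checkpoint contents and path here are illustrative:

from lightning_lite.plugins.io.xla_plugin import XLACheckpointIO

ckpt_io = XLACheckpointIO()  # raises ModuleNotFoundError without torch_xla
ckpt_io.save_checkpoint({"state_dict": {}, "epoch": 0}, "model.ckpt")
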
11 changes: 5 additions & 6 deletions src/lightning_lite/strategies/launchers/xla.py
@@ -17,15 +17,10 @@

from torch.multiprocessing import get_context

from lightning_lite.accelerators.tpu import _XLA_AVAILABLE
from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
from lightning_lite.utilities import _TPU_AVAILABLE
from lightning_lite.utilities.apply_func import move_data_to_device

if _TPU_AVAILABLE:
import torch_xla.distributed.xla_multiprocessing as xmp
else:
xmp = None

if TYPE_CHECKING:
from lightning_lite.strategies import XLAStrategy

@@ -47,6 +42,8 @@ class _XLALauncher(_MultiProcessingLauncher):
"""

def __init__(self, strategy: "XLAStrategy") -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(strategy=strategy, start_method="fork")

@property
@@ -66,6 +63,8 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
"""
context = get_context(self._start_method)
return_queue = context.SimpleQueue()
import torch_xla.distributed.xla_multiprocessing as xmp

xmp.spawn(
self._wrapping_function,
args=(function, args, kwargs, return_queue),
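For context on what the deferred import in `launch()` ultimately drives, a minimal `xmp.spawn` sketch; it requires `torch_xla` plus a TPU runtime, and the worker function and process count are illustrative:

def _worker(index: int) -> None:
    import torch_xla.core.xla_model as xm

    print(index, xm.xla_device())


def main() -> None:
    # Imported only when processes are actually launched, mirroring the hunk above.
    import torch_xla.distributed.xla_multiprocessing as xmp

    xmp.spawn(_worker, args=(), nprocs=8, start_method="fork")


if __name__ == "__main__":
    main()
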