From cdb86de3b78f5d5397d10b05f91300ba983434e5 Mon Sep 17 00:00:00 2001
From: Ashwin Vaidya
Date: Tue, 21 Jan 2025 14:38:15 +0100
Subject: [PATCH 1/4] Add XPU accelerator

Signed-off-by: Ashwin Vaidya
---
 src/anomalib/engine/__init__.py             |  4 +-
 src/anomalib/engine/accelerator/__init__.py |  8 +++
 src/anomalib/engine/accelerator/xpu.py      | 64 +++++++++++++++++++++
 src/anomalib/engine/strategy/__init__.py    |  8 +++
 src/anomalib/engine/strategy/xpu_single.py  | 43 ++++++++++++++
 5 files changed, 126 insertions(+), 1 deletion(-)
 create mode 100644 src/anomalib/engine/accelerator/__init__.py
 create mode 100644 src/anomalib/engine/accelerator/xpu.py
 create mode 100644 src/anomalib/engine/strategy/__init__.py
 create mode 100644 src/anomalib/engine/strategy/xpu_single.py

diff --git a/src/anomalib/engine/__init__.py b/src/anomalib/engine/__init__.py
index e887d4f7bb..cd31dc5234 100644
--- a/src/anomalib/engine/__init__.py
+++ b/src/anomalib/engine/__init__.py
@@ -26,6 +26,8 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+from .accelerator import XPUAccelerator
 from .engine import Engine
+from .strategy import SingleXPUStrategy
 
-__all__ = ["Engine"]
+__all__ = ["Engine", "SingleXPUStrategy", "XPUAccelerator"]
diff --git a/src/anomalib/engine/accelerator/__init__.py b/src/anomalib/engine/accelerator/__init__.py
new file mode 100644
index 0000000000..d40b6990b2
--- /dev/null
+++ b/src/anomalib/engine/accelerator/__init__.py
@@ -0,0 +1,8 @@
+"""Accelerator for Lightning Trainer."""
+
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .xpu import XPUAccelerator
+
+__all__ = ["XPUAccelerator"]
diff --git a/src/anomalib/engine/accelerator/xpu.py b/src/anomalib/engine/accelerator/xpu.py
new file mode 100644
index 0000000000..e313596af2
--- /dev/null
+++ b/src/anomalib/engine/accelerator/xpu.py
@@ -0,0 +1,64 @@
+"""XPU Accelerator."""
+
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+import torch
+from lightning.pytorch.accelerators import Accelerator, AcceleratorRegistry
+
+
+class XPUAccelerator(Accelerator):
+    """Support for an XPU, optimized for large-scale machine learning."""
+
+    accelerator_name = "xpu"
+
+    @staticmethod
+    def setup_device(device: torch.device) -> None:
+        """Sets up the specified device."""
+        if device.type != "xpu":
+            msg = f"Device should be xpu, got {device} instead"
+            raise RuntimeError(msg)
+
+        torch.xpu.set_device(device)
+
+    @staticmethod
+    def parse_devices(devices: str | list | torch.device) -> list:
+        """Parses devices for multi-XPU training."""
+        if isinstance(devices, list):
+            return devices
+        return [devices]
+
+    @staticmethod
+    def get_parallel_devices(devices: list) -> list[torch.device]:
+        """Generates a list of parallel devices."""
+        return [torch.device("xpu", idx) for idx in devices]
+
+    @staticmethod
+    def auto_device_count() -> int:
+        """Returns the number of available XPU devices."""
+        return torch.xpu.device_count()
+
+    @staticmethod
+    def is_available() -> bool:
+        """Checks if an XPU is available."""
+        return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+    @staticmethod
+    def get_device_stats(device: str | torch.device) -> dict[str, Any]:
+        """Returns XPU device stats."""
+        return {}
+
+    def teardown(self) -> None:
+        """Tears down the XPU accelerator.
+
+        This method is intentionally left empty; it must be overridden, otherwise the base class raises an error.
+        """
+
+
+AcceleratorRegistry.register(
+    XPUAccelerator.accelerator_name,
+    XPUAccelerator,
+    description="Accelerator that supports XPU devices",
+)
diff --git a/src/anomalib/engine/strategy/__init__.py b/src/anomalib/engine/strategy/__init__.py
new file mode 100644
index 0000000000..3669ef1803
--- /dev/null
+++ b/src/anomalib/engine/strategy/__init__.py
@@ -0,0 +1,8 @@
+"""Strategy for Lightning Trainer."""
+
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .xpu_single import SingleXPUStrategy
+
+__all__ = ["SingleXPUStrategy"]
diff --git a/src/anomalib/engine/strategy/xpu_single.py b/src/anomalib/engine/strategy/xpu_single.py
new file mode 100644
index 0000000000..b42e5d9350
--- /dev/null
+++ b/src/anomalib/engine/strategy/xpu_single.py
@@ -0,0 +1,43 @@
+"""Lightning strategy for a single XPU device."""
+
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import lightning.pytorch as pl
+import torch
+from lightning.pytorch.strategies import SingleDeviceStrategy, StrategyRegistry
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning_fabric.plugins import CheckpointIO
+from lightning_fabric.plugins.precision import Precision
+from lightning_fabric.utilities.types import _DEVICE
+
+
+class SingleXPUStrategy(SingleDeviceStrategy):
+    """Strategy for training on a single XPU device."""
+
+    strategy_name = "xpu_single"
+
+    def __init__(
+        self,
+        device: _DEVICE = "xpu:0",
+        accelerator: pl.accelerators.Accelerator | None = None,
+        checkpoint_io: CheckpointIO | None = None,
+        precision_plugin: Precision | None = None,
+    ) -> None:
+        if not (hasattr(torch, "xpu") and torch.xpu.is_available()):
+            msg = "`SingleXPUStrategy` requires XPU devices to run"
+            raise MisconfigurationException(msg)
+
+        super().__init__(
+            accelerator=accelerator,
+            device=device,
+            checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
+        )
+
+
+StrategyRegistry.register(
+    SingleXPUStrategy.strategy_name,
+    SingleXPUStrategy,
+    description="Strategy that enables training on a single XPU",
+)

From 4f2209cc9f6f729d26297e64183061b107092fc8 Mon Sep 17 00:00:00 2001
From: Ashwin Vaidya
Date: Tue, 21 Jan 2025 14:41:30 +0100
Subject: [PATCH 2/4] Update changelog

Signed-off-by: Ashwin Vaidya
---
 CHANGELOG.md                    | 2 ++
 src/anomalib/engine/__init__.py | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 036a2f0e49..2a4237d2e2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- 🚀 Add XPU accelerator and strategy by @ashwinvaidya17 in https://github.com/openvinotoolkit/anomalib/pull/2530
+
 ### Removed
 
 ### Changed
diff --git a/src/anomalib/engine/__init__.py b/src/anomalib/engine/__init__.py
index cd31dc5234..92f1d7a81e 100644
--- a/src/anomalib/engine/__init__.py
+++ b/src/anomalib/engine/__init__.py
@@ -23,7 +23,7 @@
     >>> engine = Engine(config=config)  # doctest: +SKIP
 """
 
-# Copyright (C) 2024 Intel Corporation
+# Copyright (C) 2024-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
 from .accelerator import XPUAccelerator

From b78075433ee735413f6aeddf3f7bf39b635661f8 Mon Sep 17 00:00:00 2001
From: Ashwin Vaidya
Date: Tue, 21 Jan 2025 15:01:35 +0100
Subject: [PATCH 3/4] precommit

Signed-off-by: Ashwin Vaidya
---
 src/anomalib/engine/accelerator/xpu.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/anomalib/engine/accelerator/xpu.py b/src/anomalib/engine/accelerator/xpu.py
index e313596af2..026268b5fa 100644
--- a/src/anomalib/engine/accelerator/xpu.py
+++ b/src/anomalib/engine/accelerator/xpu.py
@@ -48,6 +48,7 @@ def is_available() -> bool:
     @staticmethod
     def get_device_stats(device: str | torch.device) -> dict[str, Any]:
         """Returns XPU device stats."""
+        del device  # Unused
         return {}
 
     def teardown(self) -> None:

From 33a0e0eda43ece23757750d55fb90c7da58ab78f Mon Sep 17 00:00:00 2001
From: Ashwin Vaidya
Date: Tue, 21 Jan 2025 18:04:10 +0100
Subject: [PATCH 4/4] Add documentation

Signed-off-by: Ashwin Vaidya
---
 README.md                                   | 28 ++++++++++
 docs/source/markdown/guides/how_to/index.md |  8 +++
 .../how_to/training_on_intel_gpus/index.md  | 52 +++++++++++++++++++
 3 files changed, 88 insertions(+)
 create mode 100644 docs/source/markdown/guides/how_to/training_on_intel_gpus/index.md

diff --git a/README.md b/README.md
index 96140268ab..d9e9a14da7 100644
--- a/README.md
+++ b/README.md
@@ -180,6 +180,34 @@ anomalib predict --model anomalib.models.Patchcore \
 
 > 📘 **Note:** For advanced inference options including Gradio and OpenVINO, check our [Inference Documentation](https://anomalib.readthedocs.io).
 
+# Training on Intel GPUs
+
+> [!NOTE]
+> Currently, only single-GPU training is supported on Intel GPUs.
+> These commands were tested on an Arc A750 and an Arc A770.
+
+Ensure that you have PyTorch with XPU support installed. For more information, please refer to the [PyTorch XPU documentation](https://pytorch.org/docs/stable/notes/get_start_xpu.html).
+
+## 🔌 API
+
+```python
+from anomalib.data import MVTec
+from anomalib.engine import Engine, SingleXPUStrategy, XPUAccelerator
+from anomalib.models import Stfpm
+
+engine = Engine(
+    strategy=SingleXPUStrategy(),
+    accelerator=XPUAccelerator(),
+)
+engine.train(Stfpm(), datamodule=MVTec())
+```
+
+## ⌨️ CLI
+
+```bash
+anomalib train --model Padim --data MVTec --trainer.accelerator xpu --trainer.strategy xpu_single
+```
+
 # ⚙️ Hyperparameter Optimization
 
 Anomalib supports hyperparameter optimization (HPO) using [Weights & Biases](https://wandb.ai/) and [Comet.ml](https://www.comet.com/).
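
As a quick sanity check on the README section above, the static helpers that [PATCH 1/4] defines on `XPUAccelerator` can be called directly. A minimal sketch, assuming the patched `anomalib` is installed alongside an XPU-enabled PyTorch build:

```python
from anomalib.engine import XPUAccelerator

# is_available() first checks hasattr(torch, "xpu"), so this call is safe
# even on PyTorch builds without XPU support.
if XPUAccelerator.is_available():
    print(f"XPU devices found: {XPUAccelerator.auto_device_count()}")
    # parse_devices/get_parallel_devices mirror Lightning's device plumbing:
    # a list of indices becomes a list of torch.device("xpu", idx) objects.
    devices = XPUAccelerator.get_parallel_devices(XPUAccelerator.parse_devices([0]))
    print(devices)  # [device(type='xpu', index=0)]
else:
    print("No XPU available; train on CPU or CUDA instead.")
```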
diff --git a/docs/source/markdown/guides/how_to/index.md b/docs/source/markdown/guides/how_to/index.md
index 171c655ab6..235d837252 100644
--- a/docs/source/markdown/guides/how_to/index.md
+++ b/docs/source/markdown/guides/how_to/index.md
@@ -72,6 +72,13 @@ Learn more about anomalib's deployment capabilities
 Learn more about anomalib hpo, sweep and benchmarking pipelines
 :::
 
+:::{grid-item-card} {octicon}`cpu` Training on Intel GPUs
+:link: ./training_on_intel_gpus/index
+:link-type: doc
+
+Learn more about training on Intel GPUs
+:::
+
 ::::
 
 ```{toctree}
@@ -83,4 +90,5 @@
 ./models/index
 ./pipelines/index
 ./visualization/index
+./training_on_intel_gpus/index
 ```
diff --git a/docs/source/markdown/guides/how_to/training_on_intel_gpus/index.md b/docs/source/markdown/guides/how_to/training_on_intel_gpus/index.md
new file mode 100644
index 0000000000..17888d2cc3
--- /dev/null
+++ b/docs/source/markdown/guides/how_to/training_on_intel_gpus/index.md
@@ -0,0 +1,52 @@
+# Training on Intel GPUs
+
+This tutorial demonstrates how to train a model on Intel GPUs using anomalib.
+Anomalib comes with an XPU accelerator and strategy for PyTorch Lightning, which allow you to train your models on Intel GPUs.
+
+> [!NOTE]
+> Currently, only single-GPU training is supported on Intel GPUs.
+> These commands were tested on an Arc A750 and an Arc A770.
+
+## Installing Drivers
+
+First, check that you have the correct drivers installed. If you are on Ubuntu, you can refer to the [following guide](https://dgpu-docs.intel.com/driver/client/overview.html).
+
+Another recommended tool is `xpu-smi`, which can be installed from the [releases](https://github.com/intel/xpumanager) page.
+
+If everything is installed correctly, you should be able to see your card using the following command:
+
+```bash
+xpu-smi discovery
+```
+
+## Installing PyTorch
+
+Then, ensure that you have PyTorch with XPU support installed. For more information, please refer to the [PyTorch XPU documentation](https://pytorch.org/docs/stable/notes/get_start_xpu.html).
+
+To verify that your PyTorch installation supports XPU, run the following command:
+
+```bash
+python -c "import torch; print(torch.xpu.is_available())"
+```
+
+If the command prints `True`, your PyTorch installation supports XPU.
+
+## 🔌 API
+
+```python
+from anomalib.data import MVTec
+from anomalib.engine import Engine, SingleXPUStrategy, XPUAccelerator
+from anomalib.models import Stfpm
+
+engine = Engine(
+    strategy=SingleXPUStrategy(),
+    accelerator=XPUAccelerator(),
+)
+engine.train(Stfpm(), datamodule=MVTec())
+```
+
+## ⌨️ CLI
+
+```bash
+anomalib train --model Padim --data MVTec --trainer.accelerator xpu --trainer.strategy xpu_single
+```
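
Since [PATCH 1/4] registers both classes with Lightning's `AcceleratorRegistry` and `StrategyRegistry` at import time, the string names used by the CLI above can also be passed to a plain Lightning `Trainer`. A minimal sketch, assuming `anomalib.engine` has been imported so the registrations have run; the anomalib `Engine` remains the documented entry point:

```python
import anomalib.engine  # noqa: F401  # importing runs the "xpu"/"xpu_single" registrations
from lightning.pytorch import Trainer

# Lightning resolves these strings through its accelerator/strategy registries,
# the same mechanism the `--trainer.accelerator xpu` CLI flags rely on.
trainer = Trainer(accelerator="xpu", strategy="xpu_single", devices=1)
```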