πŸš€ Add XPU accelerator #2530

Merged Β· 4 commits Β· Jan 22, 2025
Changes from all commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

- πŸš€ Add XPU accelerator and strategy by @ashwinvaidya17 in https://github.com/openvinotoolkit/anomalib/pull/2530

### Removed

### Changed
28 changes: 28 additions & 0 deletions README.md
@@ -180,6 +180,34 @@ anomalib predict --model anomalib.models.Patchcore \

> πŸ“˜ **Note:** For advanced inference options including Gradio and OpenVINO, check our [Inference Documentation](https://anomalib.readthedocs.io).

# Training on Intel GPUs

> [!Note]
> Currently, only single GPU training is supported on Intel GPUs.
> These commands were tested on Arc 750 and Arc 770.

Ensure that you have PyTorch with XPU support installed. For more information, refer to the [PyTorch XPU documentation](https://pytorch.org/docs/stable/notes/get_start_xpu.html).
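
A quick way to verify the installation (the same check described in the Intel GPU how-to guide below):

```bash
python -c "import torch; print(torch.xpu.is_available())"
```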

## πŸ”Œ API

```python
from anomalib.data import MVTec
from anomalib.engine import Engine, SingleXPUStrategy, XPUAccelerator
from anomalib.models import Stfpm

engine = Engine(
    strategy=SingleXPUStrategy(),
    accelerator=XPUAccelerator(),
)
engine.train(Stfpm(), datamodule=MVTec())
```

## ⌨️ CLI

```bash
anomalib train --model Padim --data MVTec --trainer.accelerator xpu --trainer.strategy xpu_single
```

# βš™οΈ Hyperparameter Optimization

Anomalib supports hyperparameter optimization (HPO) using [Weights & Biases](https://wandb.ai/) and [Comet.ml](https://www.comet.com/).
8 changes: 8 additions & 0 deletions docs/source/markdown/guides/how_to/index.md
@@ -72,6 +72,13 @@ Learn more about anomalib's deployment capabilities
Learn more about anomalib hpo, sweep and benchmarking pipelines
:::

:::{grid-item-card} {octicon}`cpu` Training on Intel GPUs
:link: ./training_on_intel_gpus/index
:link-type: doc

Learn more about training on Intel GPUs
:::

::::

```{toctree}
@@ -83,4 +90,5 @@ Learn more about anomalib hpo, sweep and benchmarking pipelines
./models/index
./pipelines/index
./visualization/index
./training_on_intel_gpus/index
```
52 changes: 52 additions & 0 deletions docs/source/markdown/guides/how_to/training_on_intel_gpus/index.md
@@ -0,0 +1,52 @@
# Training on Intel GPUs

This tutorial demonstrates how to train a model on Intel GPUs using anomalib.
Anomalib ships with an XPU accelerator and strategy for PyTorch Lightning, which allows you to train your models on Intel GPUs.

> [!Note]
> Currently, only single GPU training is supported on Intel GPUs.
> These commands were tested on Arc 750 and Arc 770.

## Installing Drivers

First, check if you have the correct drivers installed. If you are on Ubuntu, you can refer to the [following guide](https://dgpu-docs.intel.com/driver/client/overview.html).

Another recommended tool is `xpu-smi`, which can be installed from the [releases](https://github.com/intel/xpumanager) page.

If everything is installed correctly, you should be able to see your card using the following command:

```bash
xpu-smi discovery
```

## Installing PyTorch

Then, ensure that you have PyTorch with XPU support installed. For more information, refer to the [PyTorch XPU documentation](https://pytorch.org/docs/stable/notes/get_start_xpu.html).

To ensure that your PyTorch installation supports XPU, you can run the following command:

```bash
python -c "import torch; print(torch.xpu.is_available())"
```

If the command returns `True`, then your PyTorch installation supports XPU.
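
As an optional further sanity check, you can run a small tensor operation on the device. The following is an illustrative snippet (not part of this PR) that assumes an XPU-enabled PyTorch build:

```python
# Optional sanity check: allocate a tensor on the XPU and run a small matmul.
import torch

x = torch.randn(4, 4, device="xpu")
y = x @ x  # executes on the Intel GPU
print(y.device)  # expected: xpu:0
```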

## πŸ”Œ API

```python
from anomalib.data import MVTec
from anomalib.engine import Engine, SingleXPUStrategy, XPUAccelerator
from anomalib.models import Stfpm

engine = Engine(
    strategy=SingleXPUStrategy(),
    accelerator=XPUAccelerator(),
)
engine.train(Stfpm(), datamodule=MVTec())
```

## ⌨️ CLI

```bash
anomalib train --model Padim --data MVTec --trainer.accelerator xpu --trainer.strategy xpu_single
```
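
After training, the same `Engine` can evaluate the model on the XPU. This is a minimal sketch (not part of this PR) that reuses the model and datamodule from the API example above:

```python
# Hedged sketch: train, then evaluate on the XPU with the same Engine instance.
from anomalib.data import MVTec
from anomalib.engine import Engine, SingleXPUStrategy, XPUAccelerator
from anomalib.models import Stfpm

datamodule = MVTec()
model = Stfpm()
engine = Engine(
    strategy=SingleXPUStrategy(),
    accelerator=XPUAccelerator(),
)

engine.train(model, datamodule=datamodule)
engine.test(model, datamodule=datamodule)  # evaluate the trained model on the test split
```
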
6 changes: 4 additions & 2 deletions src/anomalib/engine/__init__.py
@@ -23,9 +23,11 @@
>>> engine = Engine(config=config) # doctest: +SKIP
"""

-# Copyright (C) 2024 Intel Corporation
+# Copyright (C) 2024-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

+from .accelerator import XPUAccelerator
 from .engine import Engine
+from .strategy import SingleXPUStrategy

-__all__ = ["Engine"]
+__all__ = ["Engine", "SingleXPUStrategy", "XPUAccelerator"]
8 changes: 8 additions & 0 deletions src/anomalib/engine/accelerator/__init__.py
@@ -0,0 +1,8 @@
"""Accelerator for Lightning Trainer."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .xpu import XPUAccelerator

__all__ = ["XPUAccelerator"]
65 changes: 65 additions & 0 deletions src/anomalib/engine/accelerator/xpu.py
@@ -0,0 +1,65 @@
"""XPU Accelerator."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Any

import torch
from lightning.pytorch.accelerators import Accelerator, AcceleratorRegistry


class XPUAccelerator(Accelerator):
    """Support for an XPU, optimized for large-scale machine learning."""

    accelerator_name = "xpu"

    @staticmethod
    def setup_device(device: torch.device) -> None:
        """Sets up the specified device."""
        if device.type != "xpu":
            msg = f"Device should be xpu, got {device} instead"
            raise RuntimeError(msg)

        torch.xpu.set_device(device)

    @staticmethod
    def parse_devices(devices: str | list | torch.device) -> list:
        """Parses devices for multi-GPU training."""
        if isinstance(devices, list):
            return devices
        return [devices]

    @staticmethod
    def get_parallel_devices(devices: list) -> list[torch.device]:
        """Generates a list of parallel devices."""
        return [torch.device("xpu", idx) for idx in devices]

    @staticmethod
    def auto_device_count() -> int:
        """Returns the number of XPU devices available."""
        return torch.xpu.device_count()

    @staticmethod
    def is_available() -> bool:
        """Checks if an XPU is available."""
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    @staticmethod
    def get_device_stats(device: str | torch.device) -> dict[str, Any]:
        """Returns XPU device stats."""
        del device  # Unused
        return {}

    def teardown(self) -> None:
        """Teardown the XPU accelerator.

        This method is empty, as it needs to be overridden; otherwise the base class will throw an error.
        """


AcceleratorRegistry.register(
    XPUAccelerator.accelerator_name,
    XPUAccelerator,
    description="Accelerator that supports XPU devices",
)
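
Because the accelerator registers itself under the name `xpu`, it can also be selected by name in a plain Lightning `Trainer`, mirroring the CLI flags shown in the README. A minimal sketch (not part of the diff), assuming anomalib is installed and an XPU is present:

```python
# Illustrative sketch: the registered names can be used with a plain Lightning Trainer.
from lightning.pytorch import Trainer

from anomalib.engine import SingleXPUStrategy, XPUAccelerator  # importing registers "xpu" and "xpu_single"

# Pass instances directly...
trainer = Trainer(accelerator=XPUAccelerator(), strategy=SingleXPUStrategy())

# ...or refer to the registered names, as the CLI does.
trainer = Trainer(accelerator="xpu", strategy="xpu_single")
```
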
8 changes: 8 additions & 0 deletions src/anomalib/engine/strategy/__init__.py
@@ -0,0 +1,8 @@
"""Strategy for Lightning Trainer."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .xpu_single import SingleXPUStrategy

__all__ = ["SingleXPUStrategy"]
43 changes: 43 additions & 0 deletions src/anomalib/engine/strategy/xpu_single.py
@@ -0,0 +1,43 @@
"""Lightning strategy for single XPU device."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import lightning.pytorch as pl
import torch
from lightning.pytorch.strategies import SingleDeviceStrategy, StrategyRegistry
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning_fabric.plugins import CheckpointIO
from lightning_fabric.plugins.precision import Precision
from lightning_fabric.utilities.types import _DEVICE


class SingleXPUStrategy(SingleDeviceStrategy):
    """Strategy for training on a single XPU device."""

    strategy_name = "xpu_single"

    def __init__(
        self,
        device: _DEVICE = "xpu:0",
        accelerator: pl.accelerators.Accelerator | None = None,
        checkpoint_io: CheckpointIO | None = None,
        precision_plugin: Precision | None = None,
    ) -> None:
        if not (hasattr(torch, "xpu") and torch.xpu.is_available()):
            msg = "`SingleXPUStrategy` requires XPU devices to run"
            raise MisconfigurationException(msg)

        super().__init__(
            accelerator=accelerator,
            device=device,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )


StrategyRegistry.register(
    SingleXPUStrategy.strategy_name,
    SingleXPUStrategy,
    description="Strategy that enables training on a single XPU",
)
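
The constructor guard above means the strategy fails fast on machines without XPU support. A short illustration (not part of the diff) of that behavior:

```python
# Illustrative sketch: constructing the strategy without an XPU raises immediately.
from lightning.pytorch.utilities.exceptions import MisconfigurationException

from anomalib.engine import SingleXPUStrategy

try:
    strategy = SingleXPUStrategy()  # requires torch.xpu to be present and available
except MisconfigurationException as err:
    print(f"XPU not available: {err}")
```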