Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

πŸš€ Add Multi-GPU Training Support #2435

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions src/anomalib/data/validators/torch/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,10 +588,10 @@
Examples:
>>> import torch
>>> from anomalib.data.validators import VideoBatchValidator
>>> gt_masks = torch.rand(2, 10, 224, 224) > 0.5 # 2 videos, 10 frames each
>>> gt_masks = torch.rand(10, 224, 224) > 0.5 # 10 frames each
>>> validated_masks = VideoBatchValidator.validate_gt_mask(gt_masks)
>>> print(validated_masks.shape)
torch.Size([2, 10, 224, 224])
torch.Size([10, 224, 224])
>>> single_frame_masks = torch.rand(4, 456, 256) > 0.5 # 4 single-frame images
>>> validated_single_frame = VideoBatchValidator.validate_gt_mask(single_frame_masks)
>>> print(validated_single_frame.shape)
Expand All @@ -600,17 +600,18 @@
if mask is None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to revisit the docstrings for this method

return None
if not isinstance(mask, torch.Tensor):
msg = f"Masks must be a torch.Tensor, got {type(mask)}."
msg = f"Ground truth mask must be a torch.Tensor, got {type(mask)}."
raise TypeError(msg)
if mask.ndim not in {3, 4, 5}:
msg = f"Masks must have shape [B, H, W], [B, T, H, W] or [B, T, 1, H, W], got shape {mask.shape}."
if mask.ndim not in {2, 3, 4}:
msg = f"Ground truth mask must have shape [H, W] or [N, H, W] or [N, 1, H, W] got shape {mask.shape}."

Check warning on line 606 in src/anomalib/data/validators/torch/video.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/data/validators/torch/video.py#L606

Added line #L606 was not covered by tests
raise ValueError(msg)
if mask.ndim == 5:
if mask.shape[2] != 1:
msg = f"Masks must have 1 channel, got {mask.shape[2]}."
if mask.ndim == 2:
mask = mask.unsqueeze(0)

Check warning on line 609 in src/anomalib/data/validators/torch/video.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/data/validators/torch/video.py#L609

Added line #L609 was not covered by tests
if mask.ndim == 4:
if mask.shape[1] != 1:
msg = f"Ground truth mask must have 1 channel, got {mask.shape[1]}."
raise ValueError(msg)
mask = mask.squeeze(2)

mask = mask.squeeze(1)
return Mask(mask, dtype=torch.bool)

@staticmethod
Expand Down
3 changes: 0 additions & 3 deletions src/anomalib/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,6 @@ def _setup_trainer(self, model: AnomalibModule) -> None:
# Setup anomalib callbacks to be used with the trainer
self._setup_anomalib_callbacks(model)

# Temporarily set devices to 1 to avoid issues with multiple processes
self._cache.args["devices"] = 1

# Instantiate the trainer if it is not already instantiated
if self._trainer is None:
self._trainer = Trainer(**self._cache.args)
Expand Down
14 changes: 12 additions & 2 deletions src/anomalib/metrics/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from collections.abc import Sequence
from typing import Any

Expand All @@ -14,6 +15,8 @@

from anomalib.metrics import AnomalibMetric

logger = logging.getLogger(__name__)


class Evaluator(nn.Module, Callback):
"""Evaluator module for LightningModule.
Expand Down Expand Up @@ -53,8 +56,15 @@
super().__init__()
self.val_metrics = ModuleList(self.validate_metrics(val_metrics))
self.test_metrics = ModuleList(self.validate_metrics(test_metrics))

if compute_on_cpu:
self.compute_on_cpu = compute_on_cpu

def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
"""Move metrics to cpu if ``num_devices == 1`` and ``compute_on_cpu`` is set to ``True``."""
del pl_module, stage # Unused arguments.
if trainer.num_devices > 1:
if self.compute_on_cpu:
logger.warning("Number of devices is greater than 1, setting compute_on_cpu to False.")

Check warning on line 66 in src/anomalib/metrics/evaluator.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/metrics/evaluator.py#L65-L66

Added lines #L65 - L66 were not covered by tests
elif self.compute_on_cpu:
self.metrics_to_cpu(self.val_metrics)
self.metrics_to_cpu(self.test_metrics)

Expand Down
5 changes: 3 additions & 2 deletions src/anomalib/models/components/base/memory_bank_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.register_buffer("_is_fitted", torch.tensor([False]))
self.device: torch.device # defined in lightning module
self._is_fitted: torch.Tensor

@abstractmethod
Expand All @@ -34,10 +35,10 @@
"""Ensure that the model is fitted before validation starts."""
if not self._is_fitted:
self.fit()
self._is_fitted = torch.tensor([True])
self._is_fitted = torch.tensor([True], device=self.device)

def on_train_epoch_end(self) -> None:
"""Ensure that the model is fitted before validation starts."""
if not self._is_fitted:
self.fit()
self._is_fitted = torch.tensor([True])
self._is_fitted = torch.tensor([True], device=self.device)

Check warning on line 44 in src/anomalib/models/components/base/memory_bank_module.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/components/base/memory_bank_module.py#L44

Added line #L44 was not covered by tests
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@

# if max training points is non-zero and smaller than number of staged features, select random subset
if embeddings.shape[0] > self.max_training_points:
selected_idx = torch.tensor(random.sample(range(embeddings.shape[0]), self.max_training_points))
selected_idx = torch.tensor(

Check warning on line 96 in src/anomalib/models/components/classification/kde_classifier.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/components/classification/kde_classifier.py#L96

Added line #L96 was not covered by tests
random.sample(range(embeddings.shape[0]), self.max_training_points),
device=embeddings.device,
)
selected_features = embeddings[selected_idx]
else:
selected_features = embeddings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def fit(self, dataset: torch.Tensor) -> None:
else:
num_components = int(self.n_components)

self.num_components = torch.Tensor([num_components])
self.num_components = torch.tensor([num_components], device=dataset.device)

self.singular_vectors = v_h.transpose(-2, -1)[:, :num_components].float()
self.singular_values = sig[:num_components].float()
Expand All @@ -98,7 +98,7 @@ def fit_transform(self, dataset: torch.Tensor) -> torch.Tensor:
mean = dataset.mean(dim=0)
dataset -= mean
num_components = int(self.n_components)
self.num_components = torch.Tensor([num_components])
self.num_components = torch.tensor([num_components], device=dataset.device)

v_h = torch.linalg.svd(dataset)[-1]
self.singular_vectors = v_h.transpose(-2, -1)[:, :num_components]
Expand Down
3 changes: 3 additions & 0 deletions src/anomalib/models/image/dfkde/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
embedding = self.model(batch.image)
self.embeddings.append(embedding)

# Return a dummy loss tensor
return torch.tensor(0.0, requires_grad=True, device=self.device)

def fit(self) -> None:
"""Fit a KDE Model to the embedding collected from the training set."""
embeddings = torch.vstack(self.embeddings)
Expand Down
3 changes: 3 additions & 0 deletions src/anomalib/models/image/dfm/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
embedding = self.model.get_features(batch.image).squeeze()
self.embeddings.append(embedding)

# Return a dummy loss tensor
return torch.tensor(0.0, requires_grad=True, device=self.device)

def fit(self) -> None:
"""Fit a PCA transformation and a Gaussian model to dataset."""
logger.info("Aggregating the embedding extracted from the training set.")
Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/image/dfm/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
dataset (torch.Tensor): Input dataset to fit the model.
"""
num_samples = dataset.shape[1]
self.mean_vec = torch.mean(dataset, dim=1)
self.mean_vec = torch.mean(dataset, dim=1, device=dataset.device)

Check warning on line 44 in src/anomalib/models/image/dfm/torch_model.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/image/dfm/torch_model.py#L44

Added line #L44 was not covered by tests
data_centered = (dataset - self.mean_vec.reshape(-1, 1)) / math.sqrt(num_samples)
self.u_mat, self.sigma_mat, _ = torch.linalg.svd(data_centered, full_matrices=False)

Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/image/dsr/anomaly_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
masks_list: list[Tensor] = []
for _ in range(batch_size):
if torch.rand(1) > self.p_anomalous: # include normal samples
masks_list.append(torch.zeros((1, height, width)))
masks_list.append(torch.zeros((1, height, width), device=batch.device))

Check warning on line 76 in src/anomalib/models/image/dsr/anomaly_generator.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/image/dsr/anomaly_generator.py#L76

Added line #L76 was not covered by tests
else:
mask = self.generate_anomaly(height, width)
masks_list.append(mask)
Expand Down
5 changes: 4 additions & 1 deletion src/anomalib/models/image/padim/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
del args, kwargs # These variables are not used.

embedding = self.model(batch.image)
self.embeddings.append(embedding.cpu())
self.embeddings.append(embedding)

# Return a dummy loss tensor
return torch.tensor(0.0, requires_grad=True, device=self.device)

def fit(self) -> None:
"""Fit a Gaussian to the embedding collected from the training set."""
Expand Down
2 changes: 2 additions & 0 deletions src/anomalib/models/image/patchcore/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:

embedding = self.model(batch.image)
self.embeddings.append(embedding)
# Return a dummy loss tensor
return torch.tensor(0.0, requires_grad=True, device=self.device)

def fit(self) -> None:
"""Apply subsampling to the embedding collected from the training set."""
Expand Down
13 changes: 5 additions & 8 deletions src/anomalib/models/video/ai_vad/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from dataclasses import replace
from typing import Any

import torch
from lightning.pytorch.utilities.types import STEP_OUTPUT

from anomalib import LearningType
Expand Down Expand Up @@ -123,6 +123,9 @@
self.model.density_estimator.update(features, video_path)
self.total_detections += len(next(iter(features.values())))

# Return a dummy loss tensor
return torch.tensor(0.0, requires_grad=True, device=self.device)

Check warning on line 127 in src/anomalib/models/video/ai_vad/lightning_model.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/video/ai_vad/lightning_model.py#L127

Added line #L127 was not covered by tests

def fit(self) -> None:
"""Fit the density estimators to the extracted features from the training set."""
if self.total_detections == 0:
Expand All @@ -146,13 +149,7 @@
del args, kwargs # Unused arguments.

predictions = self.model(batch.image)

return replace(
batch,
pred_score=predictions.pred_score,
anomaly_map=predictions.anomaly_map,
pred_mask=predictions.pred_mask,
)
return batch.update(pred_score=predictions.pred_score, anomaly_map=predictions.anomaly_map)

Check warning on line 152 in src/anomalib/models/video/ai_vad/lightning_model.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/video/ai_vad/lightning_model.py#L152

Added line #L152 was not covered by tests

@property
def trainer_arguments(self) -> dict[str, Any]:
Expand Down
7 changes: 0 additions & 7 deletions src/anomalib/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,3 @@ def _show_warnings(config: DictConfig | ListConfig | Namespace) -> None:
"Anomalib's models and visualizer are currently not compatible with video datasets with a clip length > 1. "
"Custom changes to these modules will be needed to prevent errors and/or unpredictable behaviour.",
)
if (
"devices" in config.trainer
and (config.trainer.devices is None or config.trainer.devices != 1)
and config.trainer.accelerator != "cpu"
):
logger.warning("Anomalib currently does not support multi-gpu training. Setting devices to 1.")
config.trainer.devices = 1
10 changes: 5 additions & 5 deletions tests/unit/data/validators/torch/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ def test_validate_gt_label_invalid_type(self) -> None:

def test_validate_gt_mask_valid(self) -> None:
"""Test validation of valid ground truth masks."""
masks = torch.randint(0, 2, (2, 10, 224, 224))
masks = torch.randint(0, 2, (10, 1, 224, 224))
validated_masks = self.validator.validate_gt_mask(masks)
assert isinstance(validated_masks, Mask)
assert validated_masks.shape == (2, 10, 224, 224)
assert validated_masks.shape == (10, 224, 224)
assert validated_masks.dtype == torch.bool

def test_validate_gt_mask_none(self) -> None:
Expand All @@ -186,13 +186,13 @@ def test_validate_gt_mask_none(self) -> None:

def test_validate_gt_mask_invalid_type(self) -> None:
"""Test validation of ground truth masks with invalid type."""
with pytest.raises(TypeError, match="Masks must be a torch.Tensor"):
with pytest.raises(TypeError, match="Ground truth mask must be a torch.Tensor"):
self.validator.validate_gt_mask([torch.zeros(10, 224, 224)])

def test_validate_gt_mask_invalid_shape(self) -> None:
"""Test validation of ground truth masks with invalid shape."""
with pytest.raises(ValueError, match="Masks must have 1 channel, got 2."):
self.validator.validate_gt_mask(torch.zeros(2, 10, 2, 224, 224))
with pytest.raises(ValueError, match="Ground truth mask must have 1 channel, got 2."):
self.validator.validate_gt_mask(torch.zeros(10, 2, 224, 224))

def test_validate_anomaly_map_valid(self) -> None:
"""Test validation of a valid anomaly map batch."""
Expand Down