diff --git a/mmengine/evaluator/metric.py b/mmengine/evaluator/metric.py index 6e6d40bee3..4d052f269a 100644 --- a/mmengine/evaluator/metric.py +++ b/mmengine/evaluator/metric.py @@ -5,12 +5,17 @@ from torch import Tensor +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu + from mmengine.dist import (broadcast_object_list, collect_results, is_main_process) from mmengine.fileio import dump from mmengine.logging import print_log from mmengine.registry import METRICS from mmengine.structures import BaseDataElement +from mmengine.device import is_npu_available class BaseMetric(metaclass=ABCMeta): @@ -49,7 +54,10 @@ def __init__(self, "`collect_device='cpu'`") self._dataset_meta: Union[None, dict] = None - self.collect_device = collect_device + if is_npu_available(): + self.collect_device = 'gpu' + else: + self.collect_device = collect_device self.results: List[Any] = [] self.prefix = prefix or self.default_prefix self.collect_dir = collect_dir