-
Notifications
You must be signed in to change notification settings - Fork 87
/
Copy path metrics_objective.py
30 lines (24 loc) · 1.2 KB
/
metrics_objective.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from typing import Union, Iterable, Callable
from golem.core.optimisers.objective import Objective
from golem.core.utilities.data_structures import ensure_wrapped_in_sequence
from fedot.core.repository.quality_metrics_repository import MetricType, MetricsRepository, ComplexityMetricsEnum
class MetricsObjective(Objective):
    """Objective built from metric identifiers and/or custom metric callables.

    Every entry of ``metrics`` is routed into one of two dictionaries, which
    are then passed positionally to ``Objective.__init__``:

    * a callable entry is stored as a quality metric under the key
      ``str(entry)``;
    * any other entry is resolved through ``MetricsRepository.metric_by_id``
      and stored under the entry itself. It goes into the complexity
      dictionary when ``ComplexityMetricsEnum.has_value(entry)`` is true,
      otherwise into the quality dictionary.

    Args:
        metrics: a single metric or an iterable of metrics (per the type
            annotation); each is either a callable or an identifier that
            ``MetricsRepository.metric_by_id`` can look up.
        is_multi_objective: passed through unchanged to ``Objective.__init__``.

    Raises:
        ValueError: if a non-callable entry resolves to a falsy value via
            ``MetricsRepository.metric_by_id``.
    """

    def __init__(self,
                 metrics: Union[MetricType, Iterable[MetricType]],
                 is_multi_objective: bool = False):
        quality_by_id = {}
        complexity_by_id = {}

        # NOTE(review): presumably ensure_wrapped_in_sequence turns a single
        # metric into a one-element sequence so both call forms work — confirm.
        for candidate in ensure_wrapped_in_sequence(metrics):
            # User-supplied callables skip the repository; their string form
            # acts as the key. For plain functions that string usually embeds
            # an object address, so it can differ between runs.
            if callable(candidate):
                quality_by_id[str(candidate)] = candidate
                continue

            resolved = MetricsRepository.metric_by_id(candidate)
            if not resolved:
                raise ValueError(f'Incorrect metric {candidate}')

            # Complexity-style identifiers are kept apart from quality ones.
            if ComplexityMetricsEnum.has_value(candidate):
                destination = complexity_by_id
            else:
                destination = quality_by_id
            destination[candidate] = resolved

        super().__init__(quality_by_id, complexity_by_id, is_multi_objective)