Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prediction Intervals #1149

Merged
merged 49 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
b5590fb
add prediction intervals functionality
GrigoriJasnovidov Aug 18, 2023
09bcc9e
add unit tests
GrigoriJasnovidov Aug 18, 2023
849268e
add example
GrigoriJasnovidov Aug 18, 2023
0b881c0
remove redundant type-hints in doc strings
GrigoriJasnovidov Aug 18, 2023
35dbcf3
delete empty strings in solvers/mutation_of_best_pipeline.py
GrigoriJasnovidov Aug 18, 2023
082363b
add params_with_replacement in test_solver_mutations.py
GrigoriJasnovidov Aug 18, 2023
500d91c
shorten fixture params in test_solver_mutations.py
GrigoriJasnovidov Aug 18, 2023
f8a1ccc
add test data in csv and script for generating/model+data
GrigoriJasnovidov Aug 21, 2023
8c75266
update example
GrigoriJasnovidov Aug 21, 2023
f76295f
correct return in main/PredictionIntervals forecast
GrigoriJasnovidov Aug 21, 2023
6dba725
delete defualt value in metrics quantile loss
GrigoriJasnovidov Aug 21, 2023
43537a9
delete redundant cwc in metrics
GrigoriJasnovidov Aug 21, 2023
0d09cb4
delete redundant np(...) in metrics
GrigoriJasnovidov Aug 21, 2023
9977aff
add detailed error message in utils/check_init_params
GrigoriJasnovidov Aug 21, 2023
a0ae79b
add exception if predictions is ocasionaly [] in solver_mutation
GrigoriJasnovidov Aug 21, 2023
f4be34c
Refactor metrics with numpy functions
kasyanovse Aug 22, 2023
da638a8
Replace some logic with numpy functions
kasyanovse Aug 22, 2023
0dce11b
Fix
kasyanovse Aug 22, 2023
ce92a07
remove copy_model; instead copy last generation and other attributes
GrigoriJasnovidov Aug 23, 2023
fa0a5b6
improve get_last_generations_function
GrigoriJasnovidov Aug 24, 2023
261bcfc
improve get_last_generation
GrigoriJasnovidov Aug 24, 2023
93d07cd
updated unit tests
GrigoriJasnovidov Aug 24, 2023
ea3468a
improve example
GrigoriJasnovidov Aug 24, 2023
57f7204
move figwidth/figheight to def plot_prediction_intervals
GrigoriJasnovidov Aug 24, 2023
bceb018
delete redundant transformations in utils/compute_prediction_intervals
GrigoriJasnovidov Aug 24, 2023
ef1d36c
add details for utils/check_init_params
GrigoriJasnovidov Aug 24, 2023
fb6bf9e
fix self.horizon issues
GrigoriJasnovidov Aug 25, 2023
16c223f
add check ts_test length in visualization/plot_prediction_intervals
GrigoriJasnovidov Aug 25, 2023
62aecc6
improve ts_mutation/get_mutations
GrigoriJasnovidov Aug 25, 2023
f017c99
add plt.show() in visualiztion/plot_prediction_intervals
GrigoriJasnovidov Aug 25, 2023
87e5f68
starting removing utils/pipeline_simple_structure
GrigoriJasnovidov Aug 25, 2023
b505bfb
extend get_distance_between functionality
GrigoriJasnovidov Aug 28, 2023
a9e5f48
rewrite utils/get_different_pipeplines
GrigoriJasnovidov Aug 28, 2023
cba4efb
update test_mutation
GrigoriJasnovidov Aug 28, 2023
f89cf20
delete redundant utils/simple_pipeline_structure
GrigoriJasnovidov Aug 28, 2023
0078a12
delete validataion_blocks parameter for QL tuners
GrigoriJasnovidov Aug 28, 2023
11a89bc
use all possiblie mutations in get_ts_mutation/ts_mutation
GrigoriJasnovidov Aug 29, 2023
d887da4
update ts_mutation/get_different_mutations
GrigoriJasnovidov Aug 29, 2023
821f408
move operations definition for mutation method to PredictionIntervals…
GrigoriJasnovidov Aug 30, 2023
12f7809
remove arima from operations
GrigoriJasnovidov Aug 30, 2023
fec22dc
expand model list for mutation method
GrigoriJasnovidov Aug 30, 2023
795df22
add comment on operations list
GrigoriJasnovidov Aug 30, 2023
217cbb2
add TODOs
GrigoriJasnovidov Aug 30, 2023
b25dae4
pep8 issues
GrigoriJasnovidov Aug 31, 2023
7d83e03
improve utils/check_init_params
GrigoriJasnovidov Aug 31, 2023
06a0028
add restriction to while sycle in ts_mutation/get_different_pipeline
GrigoriJasnovidov Aug 31, 2023
09dc258
improve horizon and forecast definitions in main
GrigoriJasnovidov Aug 31, 2023
7733b3a
correct choice of best individual
GrigoriJasnovidov Sep 1, 2023
d88d9ec
pep issues
GrigoriJasnovidov Sep 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions examples/advanced/time_series_forecasting/prediction_intervals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from fedot.api.main import Fedot
from fedot.core.repository.dataset_types import DataTypesEnum
from fedot.core.data.data import InputData
from fedot.core.repository.tasks import TsForecastingParams, Task, TaskTypesEnum
from fedot.core.utils import fedot_project_root

# class for making prediction intervals
from fedot.core.pipelines.prediction_intervals.main import PredictionIntervals
from fedot.core.pipelines.prediction_intervals.params import PredictionIntervalsParams
# metrics to evaluate results
from fedot.core.pipelines.prediction_intervals.metrics import interval_score, picp


def build_pred_ints(start=5000, end=7000, horizon=200):

# initialize and plot ts
d = pd.read_csv(f'{fedot_project_root()}/examples/data/ts/ts_long.csv')
init_series = d[d['series_id'] == 'temp']['value'].to_numpy()

# d = pd.read_csv('ts study 1.txt')
# init_series = d[d['label'] =='temp']['value'].to_numpy()
kasyanovse marked this conversation as resolved.
Show resolved Hide resolved

ts = init_series[start:end]
ts_test = init_series[end:end + horizon]

fig, ax = plt.subplots()
ax.plot(range(len(ts)), ts)
ax.plot(range(len(ts), len(ts) + len(ts_test)), ts_test)

# create fedot model
task = Task(TaskTypesEnum.ts_forecasting, TsForecastingParams(forecast_length=horizon))
idx = np.array(range(len(np.array(ts))))
kasyanovse marked this conversation as resolved.
Show resolved Hide resolved
train_input = InputData(idx=idx,
features=ts,
target=ts,
task=task,
data_type=DataTypesEnum.ts)
model = Fedot(problem='ts_forecasting',
task_params=task.task_params,
timeout=3,
preset='ts',
show_progress=False)

model.fit(train_input)

model.forecast()
model.plot_prediction()

# initilize PredictionIntervals instance
params = PredictionIntervalsParams(number_mutations=50, show_progress=False, mutations_choice='different')
pred_ints = PredictionIntervals(model=model,
horizon=horizon,
method='mutation_of_best_pipeline',
params=params)

pred_ints.fit(train_input)

x = pred_ints.forecast()
pred_ints.plot(ts_test=ts_test)
nicl-nno marked this conversation as resolved.
Show resolved Hide resolved
GrigoriJasnovidov marked this conversation as resolved.
Show resolved Hide resolved

pred_ints.get_base_quantiles(train_input)
pred_ints.plot_base_quantiles()

# Evaluate results using metrcis picp (predicition interval coverage probability) and interval_score,
# see https://arxiv.org/pdf/2007.05709.pdf

print(f'''intervals_score: {interval_score(ts_test,up=x['up_int'],low=x['low_int'])}
picp: {picp(ts_test,low = x['low_int'],up=x['up_int'])}
''')


if __name__ == '__main__':
build_pred_ints()
Empty file.
261 changes: 261 additions & 0 deletions fedot/core/pipelines/prediction_intervals/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
import numpy as np
from functools import partial


from golem.core.log import default_log, Log
from fedot.api.main import Fedot
from fedot.core.data.data import InputData

from fedot.core.pipelines.prediction_intervals.utils import compute_prediction_intervals, model_copy, \
get_base_quantiles, check_init_params
from fedot.core.pipelines.prediction_intervals.visualization import plot_prediction_intervals, \
_plot_prediction_intervals
from fedot.core.pipelines.prediction_intervals.solvers.best_pipelines_quantiles import solver_best_pipelines_quantiles
from fedot.core.pipelines.prediction_intervals.solvers.last_generation_quantile_loss import solver_last_generation_ql
from fedot.core.pipelines.prediction_intervals.solvers.mutation_of_best_pipeline import solver_mutation_of_best_pipeline
from fedot.core.pipelines.prediction_intervals.params import PredictionIntervalsParams


class PredictionIntervals:
"""Class for building prediction intervals based on Fedot functionality.

Args:
model: a fitted Fedot class object for times series forecasting task
horizon: horizon for ts forecasting
nominal_error: nominal error for prediction intervals task
kasyanovse marked this conversation as resolved.
Show resolved Hide resolved
method: general method for creating forecasts. Possible options:

- 'last_generation_ql' -> a number of pipelines from the last generation is fitted using
quantile loss/pinball loss metric.
- 'best_pipelines_quantiles' -> prediction intervals are computed based on predictions of all
pipelines in last generation
- 'mutation_of_best_pipeline' -> prediction intervals are built relying on mutations of the final
pipeline

params: params for building prediction intervals.
"""

def __init__(self,
model: Fedot,
horizon: int = None,
kasyanovse marked this conversation as resolved.
Show resolved Hide resolved
nominal_error: float = 0.1,
method: str = 'mutation_of_best_pipeline',
params: PredictionIntervalsParams = PredictionIntervalsParams()):
GrigoriJasnovidov marked this conversation as resolved.
Show resolved Hide resolved

# check whether given Fedot class object is fitted and argument 'method' is written correctly
check_init_params(model, method)

# general parameters
if params.copy_model:
self.model = model_copy(model)
else:
self.model = model
self.horizon = horizon
self.nominal_error = nominal_error
self.method = method
self.model_forecast = model.forecast(horizon=horizon)
self.best_pipeline = model.current_pipeline

# setup logger
Log().reset_logging_level(params.logging_level)
self.logger = default_log(prefix='PredictionIntervals')

# arrays of auxilary forecasts to build prediction intervals
self.up_predictions = None
self.low_predictions = None
self.all_predictions = None

# prediction intervals
self.up_int = None
self.low_int = None

# base quantiles
self.base_quantiles = None

# flags whether PredictionIntervals instance is fitted/forecasted
self.is_fitted = False
self.is_forecasted = False

# flag whether base quantiles are computed
self.base_quantiles_are_computed = False

# initialize solver for building prediction intervals
if self.method == 'mutation_of_best_pipeline':
self.solver = partial(solver_mutation_of_best_pipeline,
model=self.model,
horizon=self.horizon,
forecast=self.model_forecast,
number_mutations=params.number_mutations,
mutations_choice=params.mutations_choice,
n_jobs=params.n_jobs,
show_progress=params.show_progress,
discard_inapropriate_pipelines=params.mutations_discard_inapropriate_pipelines,
keep_percentage=params.mutation_keep_percentage,
logger=self.logger)

elif self.method == 'best_pipelines_quantiles':
self.solver = partial(solver_best_pipelines_quantiles,
model=self.model,
horizon=self.horizon,
number_models=params.bpq_number_models,
show_progress=params.show_progress,
logger=self.logger)

elif self.method == 'last_generation_ql':
self.solver = partial(solver_last_generation_ql,
model=self.model,
horizon=self.horizon,
nominal_error=self.nominal_error,
number_models=params.ql_number_models,
iterations=params.ql_tuner_iterations,
minutes=params.ql_tuner_minutes,
n_jobs=params.n_jobs,
show_progress=params.show_progress,
validation_blocks=params.ql_tuner_validation_blocks,
up_tuner=params.ql_up_tuner,
low_tuner=params.ql_low_tuner,
logger=self.logger)


regime_up = {'quantile': 'quantile_up', 'mean': 'mean', 'median': 'median', 'absolute_bounds': 'max'}
regime_low = {'quantile': 'quantile_low', 'mean': 'mean', 'median': 'median', 'absolute_bounds': 'min'}


def fit(self, train_input: InputData):
"""This method creates several np.arrays that will be used in method 'forecast' to build prediction intervals.

Fitting process rans by self.solver initialized in method '__init__'. According to the solver several pipelines
are generated and their forecasts are transfered then in method 'forecast'.

Args:
train_input: train data used for training the model.
"""
x = self.solver(train_input)

if self.method == 'last_generation_ql':
self.up_predictions = x['up_predictions']
self.low_predictions = x['low_predictions']
else:
self.all_predictions = x

self.is_fitted = True

def forecast(self, regime: str = 'quantile'):
"""This method builds prediction intervals based on the output of method 'fit'.

Args:
regime (str): a way to compute prediction intervals if argument 'method' is 'last_generation_ql'.

Returns:
dictionary of upper and low prediction intervals.
"""
if not self.is_fitted:
self.logger.critical('PredictionIntervals instance is not fitted! Fit the instance first.')

if self.method == 'last_generation_ql':
quantiles_up = compute_prediction_intervals(self.up_predictions, nominal_error=self.nominal_error)
quantiles_low = compute_prediction_intervals(self.low_predictions, nominal_error=self.nominal_error)

up_int = quantiles_up[self.regime_up[regime]]
low_int = quantiles_low[self.regime_low[regime]]

elif self.method in ['best_pipelines_quantiles', 'mutation_of_best_pipeline']:

quantiles = compute_prediction_intervals(self.all_predictions, nominal_error=self.nominal_error)
up_int = quantiles['quantile_up']
low_int = quantiles['quantile_low']

self.up_int = np.maximum(up_int, self.model_forecast)
self.low_int = np.minimum(low_int, self.model_forecast)

self.is_forecasted = True

return {'up_int': up_int, 'low_int': low_int}
kasyanovse marked this conversation as resolved.
Show resolved Hide resolved


def get_base_quantiles(self, train_input: InputData):
"""Method to get quantiles based on predictions of final pipeline over train_data.

Args:
train_input: InputData train data.

Returns:
dictionary consisting of upper and low quantiles computed for residuals of model forecast over train ts.
"""
base_quantiles = get_base_quantiles(train_input,
pipeline=self.best_pipeline,
nominal_error=self.nominal_error)

self.base_quantiles = {'up': self.model_forecast + base_quantiles['up'],
'low': self.model_forecast + base_quantiles['low']}
self.base_quantiles_are_computed = True

return self.base_quantiles


def plot(self,
show_history: bool = True,
show_forecast: bool = True,
ts_test: np.array = None):
"""Method for plotting obtained prediction intervals, model forecast and test data."""

if self.is_forecasted is False:
self.logger.critical('Prediction intervals are not built! Use fit and then forecast methods first.')

plot_prediction_intervals(model_forecast=self.model_forecast,
up_int=self.up_int,
low_int=self.low_int,
ts=self.model.train_data.features,
show_history=show_history,
show_forecast=show_forecast,
ts_test=ts_test,
labels='pred_ints')


def plot_base_quantiles(self,
show_history: bool = True,
show_forecast: bool = True,
ts_test: np.array = None):
"""Method for plotting prediction intervals built on base quantiles, model forecast and test data."""


if self.base_quantiles_are_computed is False:
self.logger.critical('Base quantiles are not computed! Use get_base_quantiles method first.')

plot_prediction_intervals(model_forecast=self.model_forecast,
up_int=self.base_quantiles['up'],
low_int=self.base_quantiles['low'],
ts=self.model.train_data.features,
show_history=show_history,
show_forecast=show_forecast,
ts_test=ts_test,
labels='base_quantiles')


def _plot(self,
show_up_int=True,
show_low_int=True,
show_forecast=True,
show_history=True,
show_up_train=True,
show_low_train=True,
show_train=True,
ts_test: np.array = None):
""" Old method for plotting prediction intervals, train and test data. Used for developing, will be removed."""

_plot_prediction_intervals(horizon=len(self.model_forecast),
up_predictions=self.up_predictions,
low_predictions=self.low_predictions,
predictions=self.all_predictions,
model_forecast=self.model_forecast,
up_int=self.up_int,
low_int=self.low_int,
ts=self.model.train_data.features,
show_up_int=show_up_int,
show_low_int=show_low_int,
show_forecast=show_forecast,
show_history=show_history,
show_up_train=show_up_train,
show_low_train=show_low_train,
show_train=show_train,
ts_test=ts_test)
Loading