From 8c2f919cfe6e84bfd005013694d6626404cb5049 Mon Sep 17 00:00:00 2001
From: Ville Brofeldt
Date: Wed, 15 Jul 2020 13:45:01 +0300
Subject: [PATCH 1/7] feat: add prophet post processing operation

---
 setup.py                                |   1 +
 superset/utils/pandas_postprocessing.py | 119 ++++++++++++++++++++++++
 tests/pandas_postprocessing_tests.py    |  42 +++++++++
 3 files changed, 162 insertions(+)

diff --git a/setup.py b/setup.py
index 7a6c5d9f5c874..1ebdf319a255e 100644
--- a/setup.py
+++ b/setup.py
@@ -124,6 +124,7 @@ def get_git_sha():
         "cockroachdb": ["cockroachdb==0.3.3"],
         "thumbnails": ["Pillow>=7.0.0, <8.0.0"],
         "excel": ["xlrd>=1.2.0, <1.3"],
+        "prophet": ["fbprophet>=0.6,,0.7"],
     },
     python_requires="~=3.6",
     author="Apache Software Foundation",
diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py
index e71df74717df8..dc097fa9b4a9c 100644
--- a/superset/utils/pandas_postprocessing.py
+++ b/superset/utils/pandas_postprocessing.py
@@ -72,6 +72,25 @@
     "cumsum",
 )
 
+PROPHET_TIME_GRAIN_MAP = {
+    "PT1S": "S",
+    "PT1M": "min",
+    "PT5M": "5min",
+    "PT10M": "10min",
+    "PT15M": "15min",
+    "PT0.5H": "30min",
+    "PT1H": "H",
+    "P1D": "D",
+    "P1W": "W",
+    "P1M": "M",
+    "P0.25Y": "Q",
+    "P1Y": "A",
+    "1969-12-28T00:00:00Z/P1W": "W",
+    "1969-12-29T00:00:00Z/P1W": "W",
+    "P1W/1970-01-03T00:00:00Z": "W",
+    "P1W/1970-01-04T00:00:00Z": "W",
+}
+
 
 def _flatten_column_after_pivot(
     column: Union[str, Tuple[str, ...]], aggregates: Dict[str, Dict[str, Any]]
@@ -544,3 +563,103 @@ def contribution(
     if temporal_series is not None:
         contribution_df.insert(0, DTTM_ALIAS, temporal_series)
     return contribution_df
+
+
+def prophet(
+    df: DataFrame,
+    time_grain: str,
+    periods: int,
+    confidence_interval: float,
+    yearly_seasonality: Optional[Union[bool, int]] = None,
+    weekly_seasonality: Optional[Union[bool, int]] = None,
+    daily_seasonality: Optional[Union[bool, int]] = None,
+):
+    """
+    Add forecasts to each series in a timeseries dataframe, along with confidence
+    intervals for the prediction. For each series, the operation creates three
+    new columns with the column name suffixed with the following values:
+
+    - `__yhat`: the forecast for the given date
+    - `__yhat_lower`: the lower bound of the forecast for the given date
+    - `__yhat_upper`: the upper bound of the forecast for the given date
+
+    :param df: DataFrame containing all-numeric data (temporal column ignored)
+    :param time_grain: Time grain used to specify time period increments in prediction
+    :param periods: Time periods (in units of `time_grain`) to predict into the future
+    :param confidence_interval: Width of predicted confidence interval
+    :param yearly_seasonality: Should yearly seasonality be applied.
+        An integer value will specify Fourier order of seasonality.
+    :param weekly_seasonality: Should weekly seasonality be applied.
+        An integer value will specify Fourier order of seasonality, `None` will
+        automatically detect seasonality.
+    :param daily_seasonality: Should daily seasonality be applied.
+        An integer value will specify Fourier order of seasonality, `None` will
+        automatically detect seasonality.
+    :return: DataFrame with forecasts, with temporal column at beginning if present
+    """
+    # validate inputs
+    if not time_grain:
+        raise QueryObjectValidationError(_("Time grain missing"))
+    freq = PROPHET_TIME_GRAIN_MAP.get(time_grain)
+    if not freq:
+        raise QueryObjectValidationError(
+            _("Unsupported time grain: %(time_grain)s", time_grain=time_grain,)
+        )
+    if not periods or periods < 0:
+        raise QueryObjectValidationError(_("Periods must be a positive integer value"))
+    if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1:
+        raise QueryObjectValidationError(
+            _("Confidence interval must be between 0 and 1 (exclusive)")
+        )
+    if DTTM_ALIAS not in df.columns:
+        raise QueryObjectValidationError(_("DataFrame must include temporal column"))
+    if len(df.columns) < 2:
+        raise QueryObjectValidationError(_("DataFrame must include at least one series"))
+
+    try:
+        from fbprophet import Prophet
+    except ModuleNotFoundError:
+        raise QueryObjectValidationError(_("`fbprophet` package not installed"))
+
+    def _parse_seasonality(
+        input_value: Optional[Union[bool, int]]
+    ) -> Union[bool, str, int]:
+        if input_value is None:
+            return "auto"
+        if isinstance(input_value, bool):
+            return input_value
+        try:
+            return int(input_value)
+        except ValueError:
+            return input_value
+
+    target_df: Optional[DataFrame] = None
+    for column in [column for column in df.columns if column != DTTM_ALIAS]:
+        df_fit = df[[DTTM_ALIAS, column]]
+        df_fit.columns = ["ds", "y"]
+        model = Prophet(
+            interval_width=confidence_interval,
+            yearly_seasonality=_parse_seasonality(yearly_seasonality),
+            weekly_seasonality=_parse_seasonality(weekly_seasonality),
+            daily_seasonality=_parse_seasonality(daily_seasonality),
+        )
+        model.fit(df_fit)
+        future = model.make_future_dataframe(periods=periods, freq=freq)
+        forecast = model.predict(future)[["ds", "yhat", "yhat_lower", "yhat_upper"]]
+        joined = forecast.join(df_fit.set_index("ds"), on="ds").set_index(["ds"])
+        new_columns = [
+            f"{column}__yhat",
+            f"{column}__yhat_lower",
+            f"{column}__yhat_upper",
+            f"{column}",
+        ]
+        joined.columns = new_columns
+        if target_df is None:
+            target_df = joined
+        else:
+            for new_column in new_columns:
+                target_df = target_df.assign(**{new_column: joined[new_column]})
+    target_df.reset_index(level=0, inplace=True)
+    return target_df.rename(columns={"ds": DTTM_ALIAS})
diff --git a/tests/pandas_postprocessing_tests.py b/tests/pandas_postprocessing_tests.py
index ea708349ea71f..88c17629d1a3a 100644
--- a/tests/pandas_postprocessing_tests.py
+++ b/tests/pandas_postprocessing_tests.py
@@ -20,6 +20,7 @@ from typing import Any, List, Optional
 
 from pandas import DataFrame, Series
+import pytest
 
 from superset.exceptions import QueryObjectValidationError
 from superset.utils import pandas_postprocessing as proc
@@ -508,3 +509,44 @@ def test_contribution(self):
         self.assertListEqual(df.columns.tolist(), ["a", "b"])
         self.assertListEqual(series_to_list(column_df["a"]), [0.25, 0.75])
         self.assertListEqual(series_to_list(column_df["b"]), [0.1, 0.9])
+
+    def test_prophet(self):
+        pytest.importorskip("fbprophet")
+        df_orig = DataFrame(
+            {
+                "__timestamp": [
+                    datetime(2018, 12, 31),
+                    datetime(2019, 12, 31),
+                    datetime(2020, 12, 31),
+                    datetime(2021, 12, 31),
+                ],
+                "a": [1.1, 1, 1.9, 3.15],
+                "b": [4, 3, 4.1, 3.95],
+            }
+        )
+
+        df = proc.prophet(
+            df=df_orig, time_grain="P1M", periods=3, confidence_interval=0.9
+        )
+        columns = {column for column in df.columns}
+        assert columns == {
+            "__timestamp",
+            "a__yhat",
+            "a__yhat_upper",
"a__yhat_lower", + "a", + "b__yhat", + "b__yhat_upper", + "b__yhat_lower", + "b", + } + assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 3, 31) + assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31) + assert len(df) == 7 + + df = proc.prophet( + df=df_orig, time_grain="P1M", periods=5, confidence_interval=0.9 + ) + assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 3, 31) + assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2026, 12, 31) + assert len(df) == 9 From 4ff25f59a227dfa6e8019fe9c65228306cebf581 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 15 Jul 2020 14:16:57 +0300 Subject: [PATCH 2/7] add tests --- setup.py | 2 +- superset/charts/schemas.py | 98 ++++++++++++++++++++----- superset/utils/pandas_postprocessing.py | 4 +- tests/charts/api_tests.py | 39 +++++++++- tests/pandas_postprocessing_tests.py | 81 +++++++++++++++++++- 5 files changed, 198 insertions(+), 26 deletions(-) diff --git a/setup.py b/setup.py index 1ebdf319a255e..58c22a17094bb 100644 --- a/setup.py +++ b/setup.py @@ -124,7 +124,7 @@ def get_git_sha(): "cockroachdb": ["cockroachdb==0.3.3"], "thumbnails": ["Pillow>=7.0.0, <8.0.0"], "excel": ["xlrd>=1.2.0, <1.3"], - "prophet": ["fbprophet>=0.6,,0.7"], + "prophet": ["fbprophet>=0.6, <0.7"], }, python_requires="~=3.6", author="Apache Software Foundation", diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 9ccd0f20b9a43..39bda90b06786 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -101,6 +101,26 @@ } +TIME_GRAINS = ( + "PT1S", + "PT1M", + "PT5M", + "PT10M", + "PT15M", + "PT0.5H", + "PT1H", + "P1D", + "P1W", + "P1M", + "P0.25Y", + "P1Y", + "1969-12-28T00:00:00Z/P1W", # Week starting Sunday + "1969-12-29T00:00:00Z/P1W", # Week starting Monday + "P1W/1970-01-03T00:00:00Z", # Week ending Saturday + "P1W/1970-01-04T00:00:00Z", # Week ending Sunday +) + + class ChartPostSchema(Schema): """ Schema to add a new chart. @@ -423,6 +443,62 @@ class ChartDataContributionOptionsSchema(ChartDataPostProcessingOperationOptions ) +class ChartDataProphetOptionsSchema(ChartDataPostProcessingOperationOptionsSchema): + """ + Prophet operation config. + """ + + time_grain = fields.String( + description="Time grain used to specify time period increments in prediction. " + "Supports [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) " + "durations.", + validate=validate.OneOf(choices=TIME_GRAINS), + example="P1D", + required=True, + ) + periods = fields.Integer( + descrption="Time periods (in units of `time_grain`) to predict into the future", + min=1, + example=7, + required=True, + ) + confidence_interval = fields.Float( + description="Width of predicted confidence interval", + validate=[ + Range( + min=0, + max=1, + min_inclusive=False, + max_inclusive=False, + error=_("`confidence_interval` must be between 0 and 1 (exclusive)"), + ) + ], + example=0.8, + required=True, + ) + yearly_seasonality = fields.Raw( + # TODO: add correct union type once supported by Marshmallow + description="Should yearly seasonality be applied. " + "An integer value will specify Fourier order of seasonality, `None` will " + "automatically detect seasonality.", + example=False, + ) + weekly_seasonality = fields.Raw( + # TODO: add correct union type once supported by Marshmallow + description="Should weekly seasonality be applied. 
" + "An integer value will specify Fourier order of seasonality, `None` will " + "automatically detect seasonality.", + example=False, + ) + monthly_seasonality = fields.Raw( + # TODO: add correct union type once supported by Marshmallow + description="Should monthly seasonality be applied. " + "An integer value will specify Fourier order of seasonality, `None` will " + "automatically detect seasonality.", + example=False, + ) + + class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema): """ Pivot operation config. @@ -534,6 +610,7 @@ class ChartDataPostProcessingOperationSchema(Schema): "geohash_decode", "geohash_encode", "pivot", + "prophet", "rolling", "select", "sort", @@ -613,26 +690,7 @@ class ChartDataExtrasSchema(Schema): description="To what level of granularity should the temporal column be " "aggregated. Supports " "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) durations.", - validate=validate.OneOf( - choices=( - "PT1S", - "PT1M", - "PT5M", - "PT10M", - "PT15M", - "PT0.5H", - "PT1H", - "P1D", - "P1W", - "P1M", - "P0.25Y", - "P1Y", - "1969-12-28T00:00:00Z/P1W", # Week starting Sunday - "1969-12-29T00:00:00Z/P1W", # Week starting Monday - "P1W/1970-01-03T00:00:00Z", # Week ending Saturday - "P1W/1970-01-04T00:00:00Z", # Week ending Sunday - ), - ), + validate=validate.OneOf(choices=TIME_GRAINS), example="P1D", allow_none=True, ) diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py index dc097fa9b4a9c..2e395743afb28 100644 --- a/superset/utils/pandas_postprocessing.py +++ b/superset/utils/pandas_postprocessing.py @@ -607,7 +607,9 @@ def prophet( raise QueryObjectValidationError( _("Unsupported time grain: %(time_grain)s", time_grain=time_grain,) ) - if not periods or periods < 0: + # check type at runtime due to marhsmallow schema not being able to handle + # union types + if not periods or periods < 0 or not isinstance(periods, int): raise QueryObjectValidationError(_("Periods must be a positive integer value")) if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1: raise QueryObjectValidationError( diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index c115b1d9c7907..9b42a9f6977f5 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -21,8 +21,9 @@ from datetime import datetime from unittest import mock -import prison import humanize +import prison +import pytest from sqlalchemy.sql import func from tests.test_app import app @@ -796,6 +797,42 @@ def test_chart_data_mixed_case_filter_op(self): result = response_payload["result"][0] self.assertEqual(result["rowcount"], 10) + def test_chart_data_prophet(self): + """ + Chart data API: Ensure prophet post transformation works + """ + pytest.importorskip("fbprophet") + self.login(username="admin") + table = self.get_table_by_name("birth_names") + request_payload = get_query_context(table.name, table.id, table.type) + time_grain = "P1Y" + request_payload["queries"][0]["is_timeseries"] = True + request_payload["queries"][0]["groupby"] = [] + request_payload["queries"][0]["extras"] = {"time_grain_sqla": time_grain} + request_payload["queries"][0]["granularity"] = "ds" + request_payload["queries"][0]["post_processing"] = [ + { + "operation": "prophet", + "options": { + "time_grain": time_grain, + "periods": 3, + "confidence_interval": 0.9, + }, + } + ] + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + print(rv.data) + self.assertEqual(rv.status_code, 200) + 
response_payload = json.loads(rv.data.decode("utf-8")) + result = response_payload["result"][0] + row = result["data"][0] + self.assertIn("__timestamp", row) + self.assertIn("sum__num", row) + self.assertIn("sum__num__yhat", row) + self.assertIn("sum__num__yhat_upper", row) + self.assertIn("sum__num__yhat_lower", row) + self.assertEqual(result["rowcount"], 47) + def test_chart_data_no_data(self): """ Chart data API: Test chart data with empty result diff --git a/tests/pandas_postprocessing_tests.py b/tests/pandas_postprocessing_tests.py index 88c17629d1a3a..bb4b32e06eeb9 100644 --- a/tests/pandas_postprocessing_tests.py +++ b/tests/pandas_postprocessing_tests.py @@ -510,6 +510,81 @@ def test_contribution(self): self.assertListEqual(series_to_list(column_df["a"]), [0.25, 0.75]) self.assertListEqual(series_to_list(column_df["b"]), [0.1, 0.9]) + def test_prophet_incorrect_values(self): + df = DataFrame({"a": [1.1, 1, 1.9, 3.15], "b": [4, 3, 4.1, 3.95],}) + + # missing temporal column + self.assertRaises( + QueryObjectValidationError, + proc.prophet, + df=df, + time_grain="P1M", + periods=3, + confidence_interval=0.9, + ) + + df = DataFrame( + { + "__timestamp": [ + datetime(2018, 12, 31), + datetime(2019, 12, 31), + datetime(2020, 12, 31), + datetime(2021, 12, 31), + ], + "a": [1.1, 1, 1.9, 3.15], + "b": [4, 3, 4.1, 3.95], + } + ) + + # incorrect confidence interval + self.assertRaises( + QueryObjectValidationError, + proc.prophet, + df=df, + time_grain="P1M", + periods=3, + confidence_interval=0.0, + ) + + self.assertRaises( + QueryObjectValidationError, + proc.prophet, + df=df, + time_grain="P1M", + periods=3, + confidence_interval=1.1, + ) + + # incorrect confidence interval + self.assertRaises( + QueryObjectValidationError, + proc.prophet, + df=df, + time_grain="P1M", + periods=3, + confidence_interval=0.0, + ) + + # incorrect time periods + self.assertRaises( + QueryObjectValidationError, + proc.prophet, + df=df, + time_grain="P1M", + periods=0, + confidence_interval=0.8, + ) + + # incorrect time grain + self.assertRaises( + QueryObjectValidationError, + proc.prophet, + df=df, + time_grain="yearly", + periods=10, + confidence_interval=0.8, + ) + def test_prophet(self): pytest.importorskip("fbprophet") df_orig = DataFrame( @@ -540,13 +615,13 @@ def test_prophet(self): "b__yhat_lower", "b", } - assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 3, 31) + assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31) assert len(df) == 7 df = proc.prophet( df=df_orig, time_grain="P1M", periods=5, confidence_interval=0.9 ) - assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 3, 31) - assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2026, 12, 31) + assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) + assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31) assert len(df) == 9 From ad6e4fec5e5267d6cf1392295918f4f05c99b837 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 15 Jul 2020 21:54:35 +0300 Subject: [PATCH 3/7] lint --- superset/utils/pandas_postprocessing.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py index 2e395743afb28..450363fe8ac0f 100644 --- a/superset/utils/pandas_postprocessing.py +++ b/superset/utils/pandas_postprocessing.py @@ -565,7 +565,7 @@ def contribution( return contribution_df -def prophet( +def 
prophet( # pylint: disable=too-many-arguments,too-many-locals df: DataFrame, time_grain: str, periods: int, @@ -573,7 +573,7 @@ def prophet( yearly_seasonality: Optional[Union[bool, int]] = None, weekly_seasonality: Optional[Union[bool, int]] = None, daily_seasonality: Optional[Union[bool, int]] = None, -): +) -> DataFrame: """ Add forecasts to each series in a timeseries dataframe, along with confidence intervals for the prediction. For each series, the operation creates three @@ -621,7 +621,7 @@ def prophet( raise QueryObjectValidationError(_("DataFrame include at least one series")) try: - from fbprophet import Prophet + from fbprophet import Prophet # pylint: disable=import-error except ModuleNotFoundError: raise QueryObjectValidationError(_("`fbprophet` package not installed")) @@ -637,7 +637,7 @@ def _parse_seasonality( except ValueError: return input_value - target_df: Optional[DataFrame] = None + target_df = DataFrame() for column in [column for column in df.columns if column != DTTM_ALIAS]: df_fit = df[[DTTM_ALIAS, column]] df_fit.columns = ["ds", "y"] @@ -658,7 +658,7 @@ def _parse_seasonality( f"{column}", ] joined.columns = new_columns - if target_df is None: + if target_df.empty: target_df = joined else: for new_column in new_columns: From 5c961ab05b61c1bcc33c49582cab0af6750809f5 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 15 Jul 2020 23:37:09 +0300 Subject: [PATCH 4/7] whitespace --- superset-frontend/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset-frontend/package.json b/superset-frontend/package.json index 747c94b76a229..b7f85d72ead3a 100644 --- a/superset-frontend/package.json +++ b/superset-frontend/package.json @@ -1,5 +1,5 @@ { - "name": "superset", + "name": "superset", "version": "0.999.0dev", "description": "Superset is a data exploration platform designed to be visual, intuitive, and interactive.", "license": "Apache-2.0", From 6a61ffdc01986becbe45d9516af378a8fbd1bce2 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 15 Jul 2020 23:37:29 +0300 Subject: [PATCH 5/7] remove whitespace --- superset-frontend/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superset-frontend/package.json b/superset-frontend/package.json index b7f85d72ead3a..747c94b76a229 100644 --- a/superset-frontend/package.json +++ b/superset-frontend/package.json @@ -1,5 +1,5 @@ { - "name": "superset", + "name": "superset", "version": "0.999.0dev", "description": "Superset is a data exploration platform designed to be visual, intuitive, and interactive.", "license": "Apache-2.0", From f24b785f56384b4f729af3eed2ab92b13185af8e Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Mon, 20 Jul 2020 10:14:52 +0300 Subject: [PATCH 6/7] address comments --- superset/utils/pandas_postprocessing.py | 89 ++++++++++++------- tests/fixtures/dataframes.py | 15 +++- tests/pandas_postprocessing_tests.py | 112 ++++++++---------------- 3 files changed, 107 insertions(+), 109 deletions(-) diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py index 450363fe8ac0f..73336ebdb7500 100644 --- a/superset/utils/pandas_postprocessing.py +++ b/superset/utils/pandas_postprocessing.py @@ -565,7 +565,48 @@ def contribution( return contribution_df -def prophet( # pylint: disable=too-many-arguments,too-many-locals +def _prophet_parse_seasonality( + input_value: Optional[Union[bool, int]] +) -> Union[bool, str, int]: + if input_value is None: + return "auto" + if isinstance(input_value, bool): + return 
input_value + try: + return int(input_value) + except ValueError: + return input_value + + +def _prophet_fit_and_predict( # pylint: disable=too-many-arguments + df: DataFrame, + confidence_interval: float, + yearly_seasonality: Union[bool, str, int], + weekly_seasonality: Union[bool, str, int], + daily_seasonality: Union[bool, str, int], + periods: int, + freq: str, +) -> DataFrame: + """ + Fit a prophet model and return a DataFrame with predicted results. + """ + try: + from fbprophet import Prophet # pylint: disable=import-error + except ModuleNotFoundError: + raise QueryObjectValidationError(_("`fbprophet` package not installed")) + model = Prophet( + interval_width=confidence_interval, + yearly_seasonality=yearly_seasonality, + weekly_seasonality=weekly_seasonality, + daily_seasonality=daily_seasonality, + ) + model.fit(df) + future = model.make_future_dataframe(periods=periods, freq=freq) + forecast = model.predict(future)[["ds", "yhat", "yhat_lower", "yhat_upper"]] + return forecast.join(df.set_index("ds"), on="ds").set_index(["ds"]) + + +def prophet( # pylint: disable=too-many-arguments df: DataFrame, time_grain: str, periods: int, @@ -602,11 +643,11 @@ def prophet( # pylint: disable=too-many-arguments,too-many-locals # validate inputs if not time_grain: raise QueryObjectValidationError(_("Time grain missing")) - freq = PROPHET_TIME_GRAIN_MAP.get(time_grain) - if not freq: + if time_grain not in PROPHET_TIME_GRAIN_MAP: raise QueryObjectValidationError( _("Unsupported time grain: %(time_grain)s", time_grain=time_grain,) ) + freq = PROPHET_TIME_GRAIN_MAP[time_grain] # check type at runtime due to marhsmallow schema not being able to handle # union types if not periods or periods < 0 or not isinstance(periods, int): @@ -620,48 +661,28 @@ def prophet( # pylint: disable=too-many-arguments,too-many-locals if len(df.columns) < 2: raise QueryObjectValidationError(_("DataFrame include at least one series")) - try: - from fbprophet import Prophet # pylint: disable=import-error - except ModuleNotFoundError: - raise QueryObjectValidationError(_("`fbprophet` package not installed")) - - def _parse_seasonality( - input_value: Optional[Union[bool, int]] - ) -> Union[bool, str, int]: - if input_value is None: - return "auto" - if isinstance(input_value, bool): - return input_value - try: - return int(input_value) - except ValueError: - return input_value - target_df = DataFrame() for column in [column for column in df.columns if column != DTTM_ALIAS]: - df_fit = df[[DTTM_ALIAS, column]] - df_fit.columns = ["ds", "y"] - model = Prophet( - interval_width=confidence_interval, - yearly_seasonality=_parse_seasonality(yearly_seasonality), - weekly_seasonality=_parse_seasonality(weekly_seasonality), - daily_seasonality=_parse_seasonality(daily_seasonality), + fit_df = _prophet_fit_and_predict( + df=df[[DTTM_ALIAS, column]].rename(columns={DTTM_ALIAS: "ds", column: "y"}), + confidence_interval=confidence_interval, + yearly_seasonality=_prophet_parse_seasonality(yearly_seasonality), + weekly_seasonality=_prophet_parse_seasonality(weekly_seasonality), + daily_seasonality=_prophet_parse_seasonality(daily_seasonality), + periods=periods, + freq=freq, ) - model.fit(df_fit) - future = model.make_future_dataframe(periods=periods, freq=freq) - forecast = model.predict(future)[["ds", "yhat", "yhat_lower", "yhat_upper"]] - joined = forecast.join(df_fit.set_index("ds"), on="ds").set_index(["ds"]) new_columns = [ f"{column}__yhat", f"{column}__yhat_lower", f"{column}__yhat_upper", f"{column}", ] - joined.columns = 
new_columns + fit_df.columns = new_columns if target_df.empty: - target_df = joined + target_df = fit_df else: for new_column in new_columns: - target_df = target_df.assign(**{new_column: joined[new_column]}) + target_df = target_df.assign(**{new_column: fit_df[new_column]}) target_df.reset_index(level=0, inplace=True) return target_df.rename(columns={"ds": DTTM_ALIAS}) diff --git a/tests/fixtures/dataframes.py b/tests/fixtures/dataframes.py index dd01085a18a47..d93242879916b 100644 --- a/tests/fixtures/dataframes.py +++ b/tests/fixtures/dataframes.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from datetime import date +from datetime import date, datetime from pandas import DataFrame, to_datetime @@ -133,3 +133,16 @@ ], } ) + +prophet_df = DataFrame( + { + "__timestamp": [ + datetime(2018, 12, 31), + datetime(2019, 12, 31), + datetime(2020, 12, 31), + datetime(2021, 12, 31), + ], + "a": [1.1, 1, 1.9, 3.15], + "b": [4, 3, 4.1, 3.95], + } +) diff --git a/tests/pandas_postprocessing_tests.py b/tests/pandas_postprocessing_tests.py index bb4b32e06eeb9..479df423c6357 100644 --- a/tests/pandas_postprocessing_tests.py +++ b/tests/pandas_postprocessing_tests.py @@ -27,7 +27,7 @@ from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation from .base_tests import SupersetTestCase -from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df +from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df, prophet_df AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}} AGGREGATES_MULTIPLE = { @@ -510,118 +510,82 @@ def test_contribution(self): self.assertListEqual(series_to_list(column_df["a"]), [0.25, 0.75]) self.assertListEqual(series_to_list(column_df["b"]), [0.1, 0.9]) - def test_prophet_incorrect_values(self): - df = DataFrame({"a": [1.1, 1, 1.9, 3.15], "b": [4, 3, 4.1, 3.95],}) + def test_prophet_valid(self): + pytest.importorskip("fbprophet") - # missing temporal column - self.assertRaises( - QueryObjectValidationError, - proc.prophet, - df=df, - time_grain="P1M", - periods=3, - confidence_interval=0.9, + df = proc.prophet( + df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9 ) + columns = {column for column in df.columns} + assert columns == { + DTTM_ALIAS, + "a__yhat", + "a__yhat_upper", + "a__yhat_lower", + "a", + "b__yhat", + "b__yhat_upper", + "b__yhat_lower", + "b", + } + assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) + assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31) + assert len(df) == 7 - df = DataFrame( - { - "__timestamp": [ - datetime(2018, 12, 31), - datetime(2019, 12, 31), - datetime(2020, 12, 31), - datetime(2021, 12, 31), - ], - "a": [1.1, 1, 1.9, 3.15], - "b": [4, 3, 4.1, 3.95], - } + df = proc.prophet( + df=prophet_df, time_grain="P1M", periods=5, confidence_interval=0.9 ) + assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) + assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31) + assert len(df) == 9 + + def test_prophet_missing_temporal_column(self): + df = prophet_df.drop(DTTM_ALIAS, axis=1) - # incorrect confidence interval self.assertRaises( QueryObjectValidationError, proc.prophet, df=df, time_grain="P1M", periods=3, - confidence_interval=0.0, + confidence_interval=0.9, ) + def test_prophet_incorrect_confidence_interval(self): self.assertRaises( QueryObjectValidationError, proc.prophet, - df=df, + df=prophet_df, 
time_grain="P1M", periods=3, - confidence_interval=1.1, + confidence_interval=0.0, ) - # incorrect confidence interval self.assertRaises( QueryObjectValidationError, proc.prophet, - df=df, + df=prophet_df, time_grain="P1M", periods=3, - confidence_interval=0.0, + confidence_interval=1.0, ) - # incorrect time periods + def test_prophet_incorrect_periods(self): self.assertRaises( QueryObjectValidationError, proc.prophet, - df=df, + df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.8, ) - # incorrect time grain + def test_prophet_incorrect_time_grain(self): self.assertRaises( QueryObjectValidationError, proc.prophet, - df=df, + df=prophet_df, time_grain="yearly", periods=10, confidence_interval=0.8, ) - - def test_prophet(self): - pytest.importorskip("fbprophet") - df_orig = DataFrame( - { - "__timestamp": [ - datetime(2018, 12, 31), - datetime(2019, 12, 31), - datetime(2020, 12, 31), - datetime(2021, 12, 31), - ], - "a": [1.1, 1, 1.9, 3.15], - "b": [4, 3, 4.1, 3.95], - } - ) - - df = proc.prophet( - df=df_orig, time_grain="P1M", periods=3, confidence_interval=0.9 - ) - columns = {column for column in df.columns} - assert columns == { - "__timestamp", - "a__yhat", - "a__yhat_upper", - "a__yhat_lower", - "a", - "b__yhat", - "b__yhat_upper", - "b__yhat_lower", - "b", - } - assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) - assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31) - assert len(df) == 7 - - df = proc.prophet( - df=df_orig, time_grain="P1M", periods=5, confidence_interval=0.9 - ) - assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) - assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31) - assert len(df) == 9 From 2bfad7e3efcd1850ef2c7c57648e6eece2497c14 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Mon, 20 Jul 2020 10:47:26 +0300 Subject: [PATCH 7/7] add note to UPDATING.md --- UPDATING.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/UPDATING.md b/UPDATING.md index 5b4369114e28e..420cb031068f4 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -23,6 +23,8 @@ assists people when migrating to a new version. ## Next +* [10324](https://github.com/apache/incubator-superset/pull/10324): Facebook Prophet has been introduced as an optional dependency to add support for timeseries forecasting in the chart data API. To enable this feature, install Superset with the optional dependency `prophet` or directly `pip install fbprophet`. + * [10320](https://github.com/apache/incubator-superset/pull/10320): References to blacklst/whitelist language have been replaced with more appropriate alternatives. All configs refencing containing `WHITE`/`BLACK` have been replaced with `ALLOW`/`DENY`. Affected config variables that need to be updated: `TIME_GRAIN_BLACKLIST`, `VIZ_TYPE_BLACKLIST`, `DRUID_DATA_SOURCE_BLACKLIST`. * [9964](https://github.com/apache/incubator-superset/pull/9964): Breaking change on Flask-AppBuilder 3. If you're using OAuth, find out what needs to be changed [here](https://github.com/dpgaspar/Flask-AppBuilder/blob/master/README.rst#change-log).