Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add optional prophet forecasting functionality to chart data api #10324

Merged
merged 7 commits into from
Jul 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions UPDATING.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ assists people when migrating to a new version.

## Next

* [10324](https://github.com/apache/incubator-superset/pull/10324): Facebook Prophet has been introduced as an optional dependency to add support for timeseries forecasting in the chart data API. To enable this feature, install Superset with the optional dependency `prophet` or directly `pip install fbprophet`.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how to install Superset with the optional dependency prophet? i have already installed superset. Now how to install prophet to my existing superset?

Copy link

@muchemwal muchemwal Jul 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @wakilkhan96

If you haven't managed to solve installing prophet, access the terminal for your superset_app and run the following commands:

  1. docker exec -it superset_app /bin/bash
  2. pip uninstall fbprophet pystan
  3. pip --no-cache-dir install pystan==2.19.1.1
  4. pip install prophet

That should solve that issue.
pystan>=3.0 is currently not supported for prophet that why you should specify that version.


* [10320](https://github.com/apache/incubator-superset/pull/10320): References to blacklst/whitelist language have been replaced with more appropriate alternatives. All configs refencing containing `WHITE`/`BLACK` have been replaced with `ALLOW`/`DENY`. Affected config variables that need to be updated: `TIME_GRAIN_BLACKLIST`, `VIZ_TYPE_BLACKLIST`, `DRUID_DATA_SOURCE_BLACKLIST`.

* [9964](https://github.com/apache/incubator-superset/pull/9964): Breaking change on Flask-AppBuilder 3. If you're using OAuth, find out what needs to be changed [here](https://github.com/dpgaspar/Flask-AppBuilder/blob/master/README.rst#change-log).
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def get_git_sha():
"cockroachdb": ["cockroachdb==0.3.3"],
"thumbnails": ["Pillow>=7.0.0, <8.0.0"],
"excel": ["xlrd>=1.2.0, <1.3"],
"prophet": ["fbprophet>=0.6, <0.7"],
},
python_requires="~=3.6",
author="Apache Software Foundation",
Expand Down
98 changes: 78 additions & 20 deletions superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,26 @@
}


TIME_GRAINS = (
"PT1S",
"PT1M",
"PT5M",
"PT10M",
"PT15M",
"PT0.5H",
"PT1H",
"P1D",
"P1W",
"P1M",
"P0.25Y",
"P1Y",
"1969-12-28T00:00:00Z/P1W", # Week starting Sunday
"1969-12-29T00:00:00Z/P1W", # Week starting Monday
"P1W/1970-01-03T00:00:00Z", # Week ending Saturday
"P1W/1970-01-04T00:00:00Z", # Week ending Sunday
)


class ChartPostSchema(Schema):
"""
Schema to add a new chart.
Expand Down Expand Up @@ -423,6 +443,62 @@ class ChartDataContributionOptionsSchema(ChartDataPostProcessingOperationOptions
)


class ChartDataProphetOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
"""
Prophet operation config.
"""

time_grain = fields.String(
description="Time grain used to specify time period increments in prediction. "
"Supports [ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) "
"durations.",
validate=validate.OneOf(choices=TIME_GRAINS),
example="P1D",
required=True,
)
periods = fields.Integer(
descrption="Time periods (in units of `time_grain`) to predict into the future",
min=1,
example=7,
required=True,
)
confidence_interval = fields.Float(
description="Width of predicted confidence interval",
validate=[
Range(
min=0,
max=1,
min_inclusive=False,
max_inclusive=False,
error=_("`confidence_interval` must be between 0 and 1 (exclusive)"),
)
],
example=0.8,
required=True,
)
yearly_seasonality = fields.Raw(
# TODO: add correct union type once supported by Marshmallow
description="Should yearly seasonality be applied. "
"An integer value will specify Fourier order of seasonality, `None` will "
"automatically detect seasonality.",
example=False,
)
weekly_seasonality = fields.Raw(
# TODO: add correct union type once supported by Marshmallow
description="Should weekly seasonality be applied. "
"An integer value will specify Fourier order of seasonality, `None` will "
"automatically detect seasonality.",
example=False,
)
monthly_seasonality = fields.Raw(
# TODO: add correct union type once supported by Marshmallow
description="Should monthly seasonality be applied. "
"An integer value will specify Fourier order of seasonality, `None` will "
"automatically detect seasonality.",
example=False,
)


class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
"""
Pivot operation config.
Expand Down Expand Up @@ -534,6 +610,7 @@ class ChartDataPostProcessingOperationSchema(Schema):
"geohash_decode",
"geohash_encode",
"pivot",
"prophet",
"rolling",
"select",
"sort",
Expand Down Expand Up @@ -613,26 +690,7 @@ class ChartDataExtrasSchema(Schema):
description="To what level of granularity should the temporal column be "
"aggregated. Supports "
"[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) durations.",
validate=validate.OneOf(
choices=(
"PT1S",
"PT1M",
"PT5M",
"PT10M",
"PT15M",
"PT0.5H",
"PT1H",
"P1D",
"P1W",
"P1M",
"P0.25Y",
"P1Y",
"1969-12-28T00:00:00Z/P1W", # Week starting Sunday
"1969-12-29T00:00:00Z/P1W", # Week starting Monday
"P1W/1970-01-03T00:00:00Z", # Week ending Saturday
"P1W/1970-01-04T00:00:00Z", # Week ending Sunday
),
),
validate=validate.OneOf(choices=TIME_GRAINS),
example="P1D",
allow_none=True,
)
Expand Down
142 changes: 142 additions & 0 deletions superset/utils/pandas_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,25 @@
"cumsum",
)

PROPHET_TIME_GRAIN_MAP = {
"PT1S": "S",
"PT1M": "min",
"PT5M": "5min",
"PT10M": "10min",
"PT15M": "15min",
"PT0.5H": "30min",
"PT1H": "H",
"P1D": "D",
"P1W": "W",
"P1M": "M",
"P0.25Y": "Q",
"P1Y": "A",
"1969-12-28T00:00:00Z/P1W": "W",
"1969-12-29T00:00:00Z/P1W": "W",
"P1W/1970-01-03T00:00:00Z": "W",
"P1W/1970-01-04T00:00:00Z": "W",
}


def _flatten_column_after_pivot(
column: Union[str, Tuple[str, ...]], aggregates: Dict[str, Dict[str, Any]]
Expand Down Expand Up @@ -544,3 +563,126 @@ def contribution(
if temporal_series is not None:
contribution_df.insert(0, DTTM_ALIAS, temporal_series)
return contribution_df


def _prophet_parse_seasonality(
input_value: Optional[Union[bool, int]]
) -> Union[bool, str, int]:
if input_value is None:
return "auto"
if isinstance(input_value, bool):
return input_value
try:
return int(input_value)
except ValueError:
return input_value


def _prophet_fit_and_predict( # pylint: disable=too-many-arguments
df: DataFrame,
confidence_interval: float,
yearly_seasonality: Union[bool, str, int],
weekly_seasonality: Union[bool, str, int],
daily_seasonality: Union[bool, str, int],
periods: int,
freq: str,
) -> DataFrame:
"""
Fit a prophet model and return a DataFrame with predicted results.
"""
try:
from fbprophet import Prophet # pylint: disable=import-error
except ModuleNotFoundError:
raise QueryObjectValidationError(_("`fbprophet` package not installed"))
model = Prophet(
interval_width=confidence_interval,
yearly_seasonality=yearly_seasonality,
weekly_seasonality=weekly_seasonality,
daily_seasonality=daily_seasonality,
)
model.fit(df)
future = model.make_future_dataframe(periods=periods, freq=freq)
forecast = model.predict(future)[["ds", "yhat", "yhat_lower", "yhat_upper"]]
return forecast.join(df.set_index("ds"), on="ds").set_index(["ds"])


def prophet( # pylint: disable=too-many-arguments
df: DataFrame,
time_grain: str,
periods: int,
confidence_interval: float,
yearly_seasonality: Optional[Union[bool, int]] = None,
weekly_seasonality: Optional[Union[bool, int]] = None,
daily_seasonality: Optional[Union[bool, int]] = None,
) -> DataFrame:
"""
Add forecasts to each series in a timeseries dataframe, along with confidence
intervals for the prediction. For each series, the operation creates three
new columns with the column name suffixed with the following values:

- `__yhat`: the forecast for the given date
- `__yhat_lower`: the lower bound of the forecast for the given date
- `__yhat_upper`: the upper bound of the forecast for the given date
- `__yhat_upper`: the upper bound of the forecast for the given date


:param df: DataFrame containing all-numeric data (temporal column ignored)
:param time_grain: Time grain used to specify time period increments in prediction
:param periods: Time periods (in units of `time_grain`) to predict into the future
:param confidence_interval: Width of predicted confidence interval
:param yearly_seasonality: Should yearly seasonality be applied.
An integer value will specify Fourier order of seasonality.
:param weekly_seasonality: Should weekly seasonality be applied.
An integer value will specify Fourier order of seasonality, `None` will
automatically detect seasonality.
:param daily_seasonality: Should daily seasonality be applied.
An integer value will specify Fourier order of seasonality, `None` will
automatically detect seasonality.
:return: DataFrame with contributions, with temporal column at beginning if present
"""
# validate inputs
if not time_grain:
raise QueryObjectValidationError(_("Time grain missing"))
if time_grain not in PROPHET_TIME_GRAIN_MAP:
raise QueryObjectValidationError(
_("Unsupported time grain: %(time_grain)s", time_grain=time_grain,)
)
freq = PROPHET_TIME_GRAIN_MAP[time_grain]
# check type at runtime due to marhsmallow schema not being able to handle
# union types
if not periods or periods < 0 or not isinstance(periods, int):
raise QueryObjectValidationError(_("Periods must be a positive integer value"))
if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1:
raise QueryObjectValidationError(
_("Confidence interval must be between 0 and 1 (exclusive)")
)
if DTTM_ALIAS not in df.columns:
raise QueryObjectValidationError(_("DataFrame must include temporal column"))
if len(df.columns) < 2:
raise QueryObjectValidationError(_("DataFrame include at least one series"))

target_df = DataFrame()
for column in [column for column in df.columns if column != DTTM_ALIAS]:
fit_df = _prophet_fit_and_predict(
df=df[[DTTM_ALIAS, column]].rename(columns={DTTM_ALIAS: "ds", column: "y"}),
confidence_interval=confidence_interval,
yearly_seasonality=_prophet_parse_seasonality(yearly_seasonality),
weekly_seasonality=_prophet_parse_seasonality(weekly_seasonality),
daily_seasonality=_prophet_parse_seasonality(daily_seasonality),
periods=periods,
freq=freq,
)
new_columns = [
f"{column}__yhat",
f"{column}__yhat_lower",
f"{column}__yhat_upper",
f"{column}",
]
fit_df.columns = new_columns
if target_df.empty:
target_df = fit_df
else:
for new_column in new_columns:
target_df = target_df.assign(**{new_column: fit_df[new_column]})
target_df.reset_index(level=0, inplace=True)
return target_df.rename(columns={"ds": DTTM_ALIAS})
39 changes: 38 additions & 1 deletion tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
from datetime import datetime
from unittest import mock

import prison
import humanize
import prison
import pytest
from sqlalchemy.sql import func

from tests.test_app import app
Expand Down Expand Up @@ -796,6 +797,42 @@ def test_chart_data_mixed_case_filter_op(self):
result = response_payload["result"][0]
self.assertEqual(result["rowcount"], 10)

def test_chart_data_prophet(self):
"""
Chart data API: Ensure prophet post transformation works
"""
pytest.importorskip("fbprophet")
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
time_grain = "P1Y"
request_payload["queries"][0]["is_timeseries"] = True
request_payload["queries"][0]["groupby"] = []
request_payload["queries"][0]["extras"] = {"time_grain_sqla": time_grain}
request_payload["queries"][0]["granularity"] = "ds"
request_payload["queries"][0]["post_processing"] = [
{
"operation": "prophet",
"options": {
"time_grain": time_grain,
"periods": 3,
"confidence_interval": 0.9,
},
}
]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
print(rv.data)
self.assertEqual(rv.status_code, 200)
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
row = result["data"][0]
self.assertIn("__timestamp", row)
self.assertIn("sum__num", row)
self.assertIn("sum__num__yhat", row)
self.assertIn("sum__num__yhat_upper", row)
self.assertIn("sum__num__yhat_lower", row)
self.assertEqual(result["rowcount"], 47)

def test_chart_data_no_data(self):
"""
Chart data API: Test chart data with empty result
Expand Down
15 changes: 14 additions & 1 deletion tests/fixtures/dataframes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import date
from datetime import date, datetime

from pandas import DataFrame, to_datetime

Expand Down Expand Up @@ -133,3 +133,16 @@
],
}
)

prophet_df = DataFrame(
{
"__timestamp": [
datetime(2018, 12, 31),
datetime(2019, 12, 31),
datetime(2020, 12, 31),
datetime(2021, 12, 31),
],
"a": [1.1, 1, 1.9, 3.15],
"b": [4, 3, 4.1, 3.95],
}
)
Loading