Skip to content

Commit

Permalink
[FEAT] NeuralForecast Compatibility and Example Notebook (#188)
Browse files Browse the repository at this point in the history
* added median checking in reconcile

* updated drop_cols logic

* added MLFrameworksExample notebook

* updated MLFrameworksExamble nb documentation

* removed extra -median check

* updated sidebar.yml with new example nb
  • Loading branch information
dluuo authored Apr 24, 2023
1 parent 3364340 commit 3336197
Show file tree
Hide file tree
Showing 4 changed files with 9,069 additions and 4 deletions.
9 changes: 7 additions & 2 deletions hierarchicalforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ def _reverse_engineer_sigmah(Y_hat_df, y_hat, model_name):
In the future, we might deprecate this function in favor of a
direct usage of an estimated $\hat{sigma}_{h}$
"""
drop_cols = ['ds', 'y'] if 'y' in Y_hat_df.columns else ['ds']

drop_cols = ['ds']
if 'y' in Y_hat_df.columns:
drop_cols.append('y')
if model_name+'-median' in Y_hat_df.columns:
drop_cols.append(model_name+'-median')
model_names = Y_hat_df.drop(columns=drop_cols, axis=1).columns.to_list()
pi_model_names = [name for name in model_names if ('-lo' in name or '-hi' in name)]
pi_model_name = [pi_name for pi_name in pi_model_names if model_name in pi_name]
Expand Down Expand Up @@ -143,7 +148,7 @@ def _prepare_fit(self,
if Y_hat_df[model_names].isnull().values.any():
raise Exception('`Y_hat_df` contains null values')

pi_model_names = [name for name in model_names if ('-lo' in name or '-hi' in name)]
pi_model_names = [name for name in model_names if ('-lo' in name or '-hi' in name or '-median' in name)]
model_names = [name for name in model_names if name not in pi_model_names]

# TODO: Complete y_hat_insample protection
Expand Down
9 changes: 7 additions & 2 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,12 @@
" In the future, we might deprecate this function in favor of a \n",
" direct usage of an estimated $\\hat{sigma}_{h}$\n",
" \"\"\"\n",
" drop_cols = ['ds', 'y'] if 'y' in Y_hat_df.columns else ['ds']\n",
"\n",
" drop_cols = ['ds']\n",
" if 'y' in Y_hat_df.columns:\n",
" drop_cols.append('y')\n",
" if model_name+'-median' in Y_hat_df.columns:\n",
" drop_cols.append(model_name+'-median')\n",
" model_names = Y_hat_df.drop(columns=drop_cols, axis=1).columns.to_list()\n",
" pi_model_names = [name for name in model_names if ('-lo' in name or '-hi' in name)]\n",
" pi_model_name = [pi_name for pi_name in pi_model_names if model_name in pi_name]\n",
Expand Down Expand Up @@ -243,7 +248,7 @@
" if Y_hat_df[model_names].isnull().values.any():\n",
" raise Exception('`Y_hat_df` contains null values')\n",
" \n",
" pi_model_names = [name for name in model_names if ('-lo' in name or '-hi' in name)]\n",
" pi_model_names = [name for name in model_names if ('-lo' in name or '-hi' in name or '-median' in name)]\n",
" model_names = [name for name in model_names if name not in pi_model_names]\n",
" \n",
" # TODO: Complete y_hat_insample protection\n",
Expand Down
Loading

0 comments on commit 3336197

Please sign in to comment.