Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MinTrace's protection to Schafer-Strimmer covariance, eliminated stat… #97

Merged
merged 7 commits into from
Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ dependencies:
- numba
- pandas
- scikit-learn
- statsmodels
kdgutier marked this conversation as resolved.
Show resolved Hide resolved
- pip
- pip:
- nbdev
Expand Down
2 changes: 2 additions & 0 deletions hierarchicalforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods._reconcile_fcst_proportions': ( 'methods.html#_reconcile_fcst_proportions',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.cov2corr': ( 'methods.html#cov2corr',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.crossprod': ( 'methods.html#crossprod',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.is_strictly_hierarchical': ( 'methods.html#is_strictly_hierarchical',
Expand Down
2 changes: 2 additions & 0 deletions hierarchicalforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def reconcile(self,
uids = Y_hat_df.index.unique()
# check if Y_hat_df has the same uids as S
if len(S.index.difference(uids)) > 0 or len(Y_hat_df.index.difference(S.index.unique())) > 0:
print(len(S.index.difference(uids)))
print(len(Y_hat_df.index.difference(S.index.unique())))
kdgutier marked this conversation as resolved.
Show resolved Hide resolved
raise Exception('Summing matrix `S` and `Y_hat_df` do not have the same time series, please check.')
# same order of Y_hat_df to prevent errors
S_ = S.loc[uids]
Expand Down
58 changes: 51 additions & 7 deletions hierarchicalforecast/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import numpy as np
from numba import njit
from quadprog import solve_qp
from statsmodels.stats.moment_helpers import cov2corr

# %% ../nbs/methods.ipynb 5
def _reconcile(S: np.ndarray,
Expand Down Expand Up @@ -336,6 +335,25 @@ def crossprod(x):
return x.T @ x

# %% ../nbs/methods.ipynb 33
def cov2corr(cov, return_std=False):
    """ convert covariance matrix to correlation matrix

    **Parameters:**<br>
    `cov`: array_like, 2d covariance matrix.<br>
    `return_std`: bool=False, if True also return the standard deviations.<br>

    **Returns:**<br>
    `corr`: ndarray (subclass) correlation matrix
    """
    cov_arr = np.asanyarray(cov)
    # Standard deviations are the square roots of the diagonal variances.
    stds = np.sqrt(np.diag(cov_arr))
    # Divide each entry by sigma_i * sigma_j to normalize to correlations.
    corr = cov_arr / np.outer(stds, stds)
    if not return_std:
        return corr
    return corr, stds

# %% ../nbs/methods.ipynb 34
kdgutier marked this conversation as resolved.
Show resolved Hide resolved
class MinTrace:
"""MinTrace Reconciliation Class.

Expand Down Expand Up @@ -398,27 +416,53 @@ def reconcile(self,
elif self.method == 'wls_struct':
W = np.diag(S @ np.ones((n_bottom,)))
elif self.method in res_methods:
#we need residuals with shape (obs, n_hiers)
# Residuals with shape (obs, n_hiers)
residuals = (y_insample - y_hat_insample).T
n, _ = residuals.shape

# Protection: against overfitted model
residuals_sum = np.sum(residuals, axis=0)
zero_residual_prc = np.abs(residuals_sum) < 1e-4
zero_residual_prc = np.mean(zero_residual_prc)
if zero_residual_prc > .98:
raise Exception(f'Insample residuals close to 0, zero_residual_prc={zero_residual_prc}. Check `Y_df`')

# Protection: cases where data is unavailable/nan
masked_res = np.ma.array(residuals, mask=np.isnan(residuals))
covm = np.ma.cov(masked_res, rowvar=False, allow_masked=True).data

if self.method == 'wls_var':
W = np.diag(np.diag(covm))
elif self.method == 'mint_cov':
W = covm
elif self.method == 'mint_shrink':
# Schäfer and Strimmer 2005, scale invariant shrinkage
# lasso or ridge might improve numerical stability but
# this version follows https://robjhyndman.com/papers/MinT.pdf
tar = np.diag(np.diag(covm))

# Protection: constant's correlation set to 0
corm = cov2corr(covm)
xs = np.divide(residuals, np.sqrt(np.diag(covm)))
corm = np.nan_to_num(corm, nan=0.0)

# Protection: standardized residuals 0 where residual_std=0
residual_std = np.sqrt(np.diag(covm))
kdgutier marked this conversation as resolved.
Show resolved Hide resolved
xs = np.divide(residuals, residual_std,
out=np.zeros_like(residuals), where=residual_std!=0)

xs = xs[~np.isnan(xs).any(axis=1), :]
v = (1 / (n * (n - 1))) * (crossprod(xs ** 2) - (1 / n) * (crossprod(xs) ** 2))
np.fill_diagonal(v, 0)

# Protection: constant's correlation set to 0
corapn = cov2corr(tar)
corapn = np.nan_to_num(corapn, nan=0.0)
d = (corm - corapn) ** 2
lmd = v.sum() / d.sum()
lmd = max(min(lmd, 1), 0)
W = lmd * tar + (1 - lmd) * covm

# Protection: final ridge diagonal protection
W = (lmd * tar + (1 - lmd) * covm) + 1e-8
else:
raise ValueError(f'Unkown reconciliation method {self.method}')

Expand Down Expand Up @@ -468,7 +512,7 @@ def reconcile(self,

__call__ = reconcile

# %% ../nbs/methods.ipynb 40
# %% ../nbs/methods.ipynb 41
class OptimalCombination(MinTrace):
"""Optimal Combination Reconciliation Class.

Expand Down Expand Up @@ -503,7 +547,7 @@ def __init__(self,
self.nonnegative = nonnegative
self.insample = False

# %% ../nbs/methods.ipynb 46
# %% ../nbs/methods.ipynb 47
@njit
def lasso(X: np.ndarray, y: np.ndarray,
lambda_reg: float, max_iters: int = 1_000,
Expand Down Expand Up @@ -535,7 +579,7 @@ def lasso(X: np.ndarray, y: np.ndarray,
#print(it)
return beta

# %% ../nbs/methods.ipynb 47
# %% ../nbs/methods.ipynb 48
class ERM:
"""Optimal Combination Reconciliation Class.

Expand Down
84 changes: 57 additions & 27 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@
" uids = Y_hat_df.index.unique()\n",
" # check if Y_hat_df has the same uids as S\n",
" if len(S.index.difference(uids)) > 0 or len(Y_hat_df.index.difference(S.index.unique())) > 0:\n",
" print(len(S.index.difference(uids)))\n",
" print(len(Y_hat_df.index.difference(S.index.unique())))\n",
" raise Exception('Summing matrix `S` and `Y_hat_df` do not have the same time series, please check.')\n",
" # same order of Y_hat_df to prevent errors\n",
" S_ = S.loc[uids]\n",
Expand Down Expand Up @@ -333,7 +335,7 @@
"df.insert(0, 'Country', 'Australia')\n",
"\n",
"# non strictly hierarchical structure\n",
"hiers_grouped = [\n",
"hierS_grouped_df = [\n",
" ['Country'],\n",
" ['Country', 'State'], \n",
" ['Country', 'Purpose'], \n",
Expand All @@ -349,7 +351,7 @@
"]\n",
"\n",
"# getting df\n",
"hier_grouped_df, S_grouped, tags_grouped = aggregate(df, hiers_grouped)\n",
"hier_grouped_df, S_grouped_df, tags_grouped = aggregate(df, hierS_grouped_df)\n",
"hier_strict_df, S_strict, tags_strict = aggregate(df, hiers_strictly)"
]
},
Expand All @@ -362,8 +364,8 @@
"#| hide\n",
"hier_grouped_df['y_model'] = hier_grouped_df['y']\n",
"# we should be able to recover y using the methods\n",
"hier_grouped_df_h = hier_grouped_df.groupby('unique_id').tail(12)\n",
"ds_h = hier_grouped_df_h['ds'].unique()\n",
"hier_grouped_hat_df = hier_grouped_df.groupby('unique_id').tail(12)\n",
"ds_h = hier_grouped_hat_df['ds'].unique()\n",
"hier_grouped_df = hier_grouped_df.query('~(ds in @ds_h)')\n",
"#adding noise to `y_model` to avoid perfectly fitted values\n",
"hier_grouped_df['y_model'] += np.random.uniform(-1, 1, len(hier_grouped_df))\n",
Expand All @@ -383,8 +385,8 @@
" # ERM recovers but needs bigger eps\n",
" #ERM(method='reg_bu', lambda_reg=None),\n",
"])\n",
"reconciled = hrec.reconcile(Y_hat_df=hier_grouped_df_h, Y_df=hier_grouped_df, \n",
" S=S_grouped, tags=tags_grouped)\n",
"reconciled = hrec.reconcile(Y_hat_df=hier_grouped_hat_df, Y_df=hier_grouped_df, \n",
" S=S_grouped_df, tags=tags_grouped)\n",
"for model in reconciled.drop(columns=['ds', 'y']).columns:\n",
" if 'ERM' in model:\n",
" eps = 3\n",
Expand All @@ -407,19 +409,19 @@
"test_fail(\n",
" hrec.reconcile,\n",
" contains='do not have the same time series',\n",
" args=(hier_grouped_df_h.drop('Australia'), S_grouped, tags_grouped, hier_grouped_df),\n",
" args=(hier_grouped_hat_df.drop('Australia'), S_grouped_df, tags_grouped, hier_grouped_df),\n",
" \n",
")\n",
"test_fail(\n",
" hrec.reconcile,\n",
" contains='do not have the same time series',\n",
" args=(hier_grouped_df_h, S_grouped.drop('Australia'), tags_grouped, hier_grouped_df),\n",
" args=(hier_grouped_hat_df, S_grouped_df.drop('Australia'), tags_grouped, hier_grouped_df),\n",
" \n",
")\n",
"test_fail(\n",
" hrec.reconcile,\n",
" contains='do not have the same time series',\n",
" args=(hier_grouped_df_h, S_grouped, tags_grouped, hier_grouped_df.drop('Australia')),\n",
" args=(hier_grouped_hat_df, S_grouped_df, tags_grouped, hier_grouped_df.drop('Australia')),\n",
" \n",
")"
]
Expand All @@ -440,8 +442,8 @@
" MinTrace(method='ols', nonnegative=True),\n",
" MinTrace(method='wls_struct', nonnegative=True),\n",
"])\n",
"reconciled = hrec.reconcile(Y_hat_df=hier_grouped_df_h,\n",
" S=S_grouped, tags=tags_grouped)\n",
"reconciled = hrec.reconcile(Y_hat_df=hier_grouped_hat_df,\n",
" S=S_grouped_df, tags=tags_grouped)\n",
"for model in reconciled.drop(columns=['ds', 'y']).columns:\n",
" if 'ERM' in model:\n",
" eps = 3\n",
Expand All @@ -465,7 +467,7 @@
"test_fail(\n",
" hrec.reconcile,\n",
" contains='requires strictly hierarchical structures',\n",
" args=(hier_grouped_df_h, S_grouped, tags_grouped, hier_grouped_df,)\n",
" args=(hier_grouped_hat_df, S_grouped_df, tags_grouped, hier_grouped_df,)\n",
")"
]
},
Expand Down Expand Up @@ -556,6 +558,27 @@
" test_close(reconciled['y'], reconciled[model], eps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# MinTrace should break\n",
"# with extremely overfitted model, y_model==y\n",
"\n",
"zero_df = hier_grouped_df.copy()\n",
"zero_df['y'] = 0\n",
"zero_df['y_model'] = 0\n",
"hrec = HierarchicalReconciliation([MinTrace(method='mint_shrink')])\n",
"test_fail(\n",
" hrec.reconcile,\n",
" contains='Insample residuals',\n",
" args=(hier_grouped_hat_df, S_grouped_df, tags_grouped, zero_df)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -568,9 +591,9 @@
"#that argument\n",
"hrec = HierarchicalReconciliation([MinTrace(method='ols')])\n",
"reconciled = hrec.reconcile(\n",
" Y_hat_df=hier_grouped_df_h, \n",
" Y_hat_df=hier_grouped_hat_df, \n",
" Y_df=hier_grouped_df.drop(columns=['y_model']), \n",
" S=S_grouped, \n",
" S=S_grouped_df, \n",
" tags=tags_grouped\n",
")\n",
"for model in reconciled.drop(columns=['ds', 'y']).columns:\n",
Expand All @@ -597,8 +620,8 @@
"#test methods bootstrap prediction\n",
"#intervals\n",
"hrec = HierarchicalReconciliation([BottomUp()])\n",
"reconciled = hrec.reconcile(hier_grouped_df_h, \n",
" Y_df=hier_grouped_df, S=S_grouped, tags=tags_grouped,\n",
"reconciled = hrec.reconcile(hier_grouped_hat_df, \n",
" Y_df=hier_grouped_df, S=S_grouped_df, tags=tags_grouped,\n",
" level=[80, 90], \n",
" intervals_method='bootstrap')\n",
"total = reconciled.loc[tags_grouped['Country/State/Region/Purpose']].groupby('ds').sum().reset_index()\n",
Expand All @@ -618,11 +641,11 @@
"#| hide\n",
"# test methods that require prediction intervals\n",
"# normality\n",
"hier_grouped_df_h['y_model-lo-80'] = hier_grouped_df_h['y_model'] - 1.96\n",
"hier_grouped_df_h['y_model-hi-80'] = hier_grouped_df_h['y_model'] + 1.96\n",
"hier_grouped_hat_df['y_model-lo-80'] = hier_grouped_hat_df['y_model'] - 1.96\n",
"hier_grouped_hat_df['y_model-hi-80'] = hier_grouped_hat_df['y_model'] + 1.96\n",
"hrec = HierarchicalReconciliation([BottomUp()])\n",
"reconciled = hrec.reconcile(hier_grouped_df_h, \n",
" Y_df=hier_grouped_df, S=S_grouped, tags=tags_grouped,\n",
"reconciled = hrec.reconcile(hier_grouped_hat_df, \n",
" Y_df=hier_grouped_df, S=S_grouped_df, tags=tags_grouped,\n",
" level=[80, 90], \n",
" intervals_method='normality')\n",
"total = reconciled.loc[tags_grouped['Country/State/Region/Purpose']].groupby('ds').sum().reset_index()\n",
Expand All @@ -645,13 +668,13 @@
"\n",
"# test expect error with grouped structure \n",
"# (non strictly hierarchical)\n",
"hier_grouped_df_h['y_model-lo-80'] = hier_grouped_df_h['y_model'] - 1.96\n",
"hier_grouped_df_h['y_model-hi-80'] = hier_grouped_df_h['y_model'] + 1.96\n",
"hier_grouped_hat_df['y_model-lo-80'] = hier_grouped_hat_df['y_model'] - 1.96\n",
"hier_grouped_hat_df['y_model-hi-80'] = hier_grouped_hat_df['y_model'] + 1.96\n",
"hrec = HierarchicalReconciliation([BottomUp()])\n",
"test_fail(\n",
" hrec.reconcile,\n",
" contains='requires strictly hierarchical structures',\n",
" args=(hier_grouped_df_h, S_grouped, tags_grouped, hier_grouped_df, [80, 90], 'permbu',)\n",
" args=(hier_grouped_hat_df, S_grouped_df, tags_grouped, hier_grouped_df, [80, 90], 'permbu',)\n",
")\n",
"\n",
"# test PERMBU\n",
Expand Down Expand Up @@ -708,7 +731,7 @@
" ['Country', 'State', 'Purpose'], \n",
" ['Country', 'State', 'Region', 'Purpose']]\n",
"\n",
"Y_df, S, tags = aggregate(df=df, spec=hierarchy_levels)\n",
"Y_df, S_df, tags = aggregate(df=df, spec=hierarchy_levels)\n",
"qs = Y_df['ds'].str.replace(r'(\\d+) (Q\\d)', r'\\1-\\2', regex=True)\n",
"Y_df['ds'] = pd.PeriodIndex(qs, freq='Q').to_timestamp()\n",
"Y_df = Y_df.reset_index()\n",
Expand All @@ -732,16 +755,23 @@
" MinTrace(method='ols')]\n",
"hrec = HierarchicalReconciliation(reconcilers=reconcilers)\n",
"Y_rec_df = hrec.reconcile(Y_hat_df=Y_hat_df, Y_df=Y_train_df,\n",
" S=S, tags=tags)\n",
" S=S_df, tags=tags)\n",
"Y_rec_df.groupby('unique_id').head(2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "hierarchicalforecast",
"language": "python",
"name": "python3"
"name": "hierarchicalforecast"
}
},
"nbformat": 4,
Expand Down
Loading