Skip to content

Commit

Permalink
model minor update
Browse files Browse the repository at this point in the history
  • Loading branch information
edwinnglabs committed Nov 17, 2023
1 parent 94e7c6a commit 4dbafab
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 7 deletions.
4 changes: 2 additions & 2 deletions docs/examples/attribution.ipynb

Large diffs are not rendered by default.

14 changes: 11 additions & 3 deletions karpiu/model_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,20 @@ def __init__(
# base comp related
df_zero = self.df.copy()
df_zero.loc[:, self.target_regressors] = 0.0
zero_pred_df = model.predict(df=df_zero, decompose=True)
zero_pred_df = model.predict(df=df_zero)
full_pred_df = model.predict(df=df)
resid = - full_pred_df

# prediction when all target regressors are set to zero
# we also need to include the residual terms during the training period
# prevent over-under bias in the fitting obs realization

# prediction when all target regressors are set to zero
# (n_calc_steps, )
self.base_comp_calc = zero_pred_df.loc[self.calc_mask, "prediction"].values
self.base_comp_calc = (
zero_pred_df.loc[self.calc_mask, "prediction"].values +
(df[self.calc_mask, model.response_col].values -
full_pred_df[self.calc_mask, "prediction"].values)
)
# (n_result_steps, )
self.base_comp_result = zero_pred_df.loc[self.result_mask, "prediction"].values
# (n_input_steps, )
Expand Down
3 changes: 1 addition & 2 deletions karpiu/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,8 +609,7 @@ def predict(
# _5 and _95 probably won't exist with median prediction for current version
pred_tr_col = [
x
for x in ["prediction_5", "prediction", "prediction_95"]
if x in pred.columns
for x in ["prediction_5", "prediction", "prediction_95"] if x in pred.columns
]
pred[pred_tr_col] = pred[pred_tr_col].apply(np.exp)

Expand Down

0 comments on commit 4dbafab

Please sign in to comment.