From f90ba1ad484d5740beed0ff8d4c668e133e2ff24 Mon Sep 17 00:00:00 2001 From: Linh Nguyen Date: Thu, 26 Dec 2024 21:36:42 -0500 Subject: [PATCH] generalize horizon len --- orion/primitives/timesfm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/orion/primitives/timesfm.py b/orion/primitives/timesfm.py index f5089174..defd5770 100644 --- a/orion/primitives/timesfm.py +++ b/orion/primitives/timesfm.py @@ -39,7 +39,7 @@ def __init__(self, pred_len=1, repo_id="google/timesfm-1.0-200m-pytorch", batch_size=32, - freq=0, + freq=0, target=0): self.window_size = window_size @@ -65,19 +65,19 @@ def predict(self, X): """ frequency_input = [self.freq]*len(X) d = X.shape[-1] - - #Univariate + + # Univariate if d == 1: y_hat, _ = self.model.forecast(X[:, :, 0], freq=frequency_input) return y_hat[:, 0] - - #Multivariate + + # Multivariate covariates = list(range(d)) covariates = covariates.remove(self.target) X_cont = X[:, :, self.target] X_cov = np.delete(X, self.target, axis=2) - #Append covariates with future values + # Append covariates with future values m, n, k = X_cov.shape X_cov_new = np.zeros((m, n+self.pred_len, k)) X_cov_new[:, :-self.pred_len, :] = X_cov