Skip to content

Commit

Permalink
generalize horizon len
Browse files Browse the repository at this point in the history
  • Loading branch information
Linh-nk committed Dec 27, 2024
1 parent 1e204b6 commit f90ba1a
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions orion/primitives/timesfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self,
pred_len=1,
repo_id="google/timesfm-1.0-200m-pytorch",
batch_size=32,
freq=0,
freq=0,
target=0):

self.window_size = window_size
Expand All @@ -65,19 +65,19 @@ def predict(self, X):
"""
frequency_input = [self.freq]*len(X)
d = X.shape[-1]
#Univariate

# Univariate
if d == 1:
y_hat, _ = self.model.forecast(X[:, :, 0], freq=frequency_input)
return y_hat[:, 0]
#Multivariate

# Multivariate
covariates = list(range(d))
covariates = covariates.remove(self.target)
X_cont = X[:, :, self.target]
X_cov = np.delete(X, self.target, axis=2)

#Append covariates with future values
# Append covariates with future values
m, n, k = X_cov.shape
X_cov_new = np.zeros((m, n+self.pred_len, k))
X_cov_new[:, :-self.pred_len, :] = X_cov
Expand Down

0 comments on commit f90ba1a

Please sign in to comment.