Skip to content

Commit

Permalink
Validate train test frequencies
Browse files Browse the repository at this point in the history
  • Loading branch information
kdgutier committed Aug 1, 2020
1 parent a0731d7 commit cc9311c
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions ESRNN/ESRNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,10 +432,21 @@ def fit(self, X_df, y_df, X_test_df=None, y_test_df=None, y_hat_benchmark='y_hat
n_series = self.train_dataloader.n_series
self.instantiate_esrnn(exogenous_size, n_series)

# Infer freq of model
if self.mc.frequency is None:
self.mc.frequency = pd.infer_freq(X_df.head()['ds'])
print("Infered frequency: {}".format(self.mc.frequency))
# Validating frequencies
X_train_frequency = pd.infer_freq(X_df.head()['ds'])
y_train_frequency = pd.infer_freq(y_df.head()['ds'])
self.frequencies = [X_train_frequency, y_train_frequency]

if (X_test_df is not None) and (y_test_df is not None):
X_test_frequency = pd.infer_freq(X_test_df.head()['ds'])
y_test_frequency = pd.infer_freq(y_test_df.head()['ds'])
self.frequencies += [X_test_frequency, y_test_frequency]

assert len(set(self.frequencies)) <= 1, \
"Match the frequencies of the dataframes {}".format(self.frequencies)

self.mc.frequency = self.frequencies[0]
print("Infered frequency: {}".format(self.mc.frequency))

# Train model
self._fitted = True
Expand Down

0 comments on commit cc9311c

Please sign in to comment.