What happened + What you expected to happen
Using the TFT model example code from the website, I get an error saying "Trainer.__init__() got an unexpected keyword argument 'grn_activation'". This is running in a fresh Anaconda environment in which only neuralforecast, its dependencies, and matplotlib have been installed.
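The message names Trainer.__init__ rather than TFT. That is consistent with the installed TFT not declaring `grn_activation` at all: neuralforecast models accept `**trainer_kwargs` and hand any keyword they don't recognize to the PyTorch Lightning Trainer, which then rejects it. The sketch below models that mechanism with hypothetical stand-in classes (`Trainer`, `OldTFT`). It is a simplification under that assumption, not neuralforecast's or Lightning's actual code. In the sketch the keyword is accepted silently at construction and the error only surfaces when `fit()` builds the trainer.

```python
# Hypothetical, simplified stand-ins; NOT neuralforecast's or Lightning's real classes.
from typing import Any, Dict


class Trainer:
    """Stand-in for a trainer whose __init__ accepts only options it knows."""

    def __init__(self, max_steps: int = -1, enable_progress_bar: bool = True) -> None:
        self.max_steps = max_steps
        self.enable_progress_bar = enable_progress_bar


class OldTFT:
    """Stand-in for a model version that predates `grn_activation`.

    Any keyword it does not declare is collected in **trainer_kwargs and is
    only passed on when fit() constructs the Trainer.
    """

    def __init__(self, h: int, max_steps: int = 1000, **trainer_kwargs: Any) -> None:
        self.h = h
        self.max_steps = max_steps
        self.trainer_kwargs: Dict[str, Any] = trainer_kwargs

    def fit(self) -> Trainer:
        # The unrecognized keyword reaches Trainer here and raises TypeError.
        return Trainer(max_steps=self.max_steps, **self.trainer_kwargs)


def reproduce() -> str:
    # Construction succeeds: grn_activation is just stored as a trainer kwarg.
    model = OldTFT(h=12, grn_activation="ELU", enable_progress_bar=True)
    try:
        model.fit()
    except TypeError as err:
        return str(err)
    return "no error"


if __name__ == "__main__":
    print(reproduce())
```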
Versions / Dependencies
neuralforecast version 1.7.5
PyTorch Lightning version 2.4.0
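To confirm which versions are actually active in the environment, and whether the installed TFT exposes `grn_activation` as a real parameter, a short check along these lines may help. It relies only on `importlib.metadata` and `inspect`; the helper names are hypothetical, and the distribution name `pytorch-lightning` is an assumption (Lightning is also published as `lightning`). If `grn_activation` is indeed missing from 1.7.5, the last line should report `False`.

```python
import inspect
from importlib import metadata
from typing import Callable, Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def declares_param(func: Callable, name: str) -> bool:
    """True only if `name` is an explicit parameter of `func`.

    A name that would merely be swallowed by *args or **kwargs does not count.
    """
    param = inspect.signature(func).parameters.get(name)
    return param is not None and param.kind not in (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )


if __name__ == "__main__":
    for dist in ("neuralforecast", "pytorch-lightning"):
        print(f"{dist}: {installed_version(dist)}")
    try:
        from neuralforecast.models import TFT
    except ImportError:
        print("neuralforecast is not importable here")
    else:
        print("TFT declares grn_activation:", declares_param(TFT.__init__, "grn_activation"))
```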
Reproduction script
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from neuralforecast import NeuralForecast
from neuralforecast.models import TFT
from neuralforecast.losses.pytorch import DistributionLoss
from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic
# Calendar month as an extra future-known exogenous feature
AirPassengersPanel['month'] = AirPassengersPanel.ds.dt.month
# Hold out the last 12 months of each series for testing
Y_train_df = AirPassengersPanel[AirPassengersPanel.ds < AirPassengersPanel['ds'].values[-12]]  # 132 train
Y_test_df = AirPassengersPanel[AirPassengersPanel.ds >= AirPassengersPanel['ds'].values[-12]].reset_index(drop=True)  # 12 test
nf = NeuralForecast(
    models=[
        TFT(
            h=12,
            input_size=48,
            hidden_size=20,
            grn_activation='ELU',  # this argument triggers the TypeError with 1.7.5
            loss=DistributionLoss(distribution='StudentT', level=[80, 90]),
            learning_rate=0.005,
            stat_exog_list=['airline1'],
            futr_exog_list=['y_[lag12]', 'month'],
            hist_exog_list=['trend'],
            max_steps=300,
            val_check_steps=10,
            early_stop_patience_steps=10,
            scaler_type='robust',
            windows_batch_size=None,
            enable_progress_bar=True,
        ),
    ],
    freq='M',
)
nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)
Y_hat_df = nf.predict(futr_df=Y_test_df)
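Upgrading neuralforecast to a release whose TFT documentation lists `grn_activation` is presumably the cleanest fix, on the assumption that the website example tracks a newer version than 1.7.5. If 1.7.5 has to stay pinned, one workaround is to leave `grn_activation` out when the installed TFT doesn't declare it. The helper below, `drop_undeclared`, is hypothetical and not part of neuralforecast. It checks only the names you list, so genuine trainer options such as `enable_progress_bar` still pass through `**trainer_kwargs` untouched. Dropping the argument means the installed version's built-in GRN activation is used; whether that already matches `'ELU'` should be checked in the installed source rather than assumed.

```python
import inspect
from typing import Any, Callable, Dict, Iterable, List, Tuple


def drop_undeclared(
    func: Callable, kwargs: Dict[str, Any], optional_names: Iterable[str]
) -> Tuple[Dict[str, Any], List[str]]:
    """Hypothetical helper: remove the listed names that `func` doesn't declare.

    Names not in `optional_names` are kept even if `func` only receives them
    through **kwargs, so trainer options are not filtered out by mistake.
    """
    params = inspect.signature(func).parameters
    explicit = {
        name
        for name, p in params.items()
        if p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    }
    dropped = [n for n in optional_names if n in kwargs and n not in explicit]
    kept = {k: v for k, v in kwargs.items() if k not in dropped}
    return kept, dropped


if __name__ == "__main__":
    try:
        from neuralforecast.models import TFT
    except ImportError:
        print("neuralforecast is not importable here")
    else:
        tft_kwargs = {"h": 12, "input_size": 48, "hidden_size": 20,
                      "grn_activation": "ELU", "enable_progress_bar": True}
        tft_kwargs, dropped = drop_undeclared(TFT.__init__, tft_kwargs, ["grn_activation"])
        print("dropped:", dropped)  # empty on versions that support grn_activation
        model = TFT(**tft_kwargs)
```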
Issue Severity
High: It blocks me from completing my task.