TFT Model: Trainer.__init__() got an unexpected keyword argument 'grn_activation' #1182

Closed
skuzmier opened this issue Oct 16, 2024 · 1 comment · Fixed by #1183
@skuzmier

What happened + What you expected to happen

Using the TFT model example code from the website, I get an error saying `Trainer.__init__() got an unexpected keyword argument 'grn_activation'`. This is running in a fresh Anaconda environment where only neuralforecast, its dependencies, and matplotlib have been installed.
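The error names `Trainer.__init__()` rather than `TFT`, even though `grn_activation` was passed to `TFT`. One plausible explanation, assuming neuralforecast models forward keyword arguments they do not recognise to PyTorch Lightning's `Trainer` through a `**trainer_kwargs` catch-all (the library documents a `trainer_kwargs` parameter), is sketched below. `Trainer` and `TFT` here are hypothetical stand-ins, not the real classes.

```python
# Hypothetical stand-ins, NOT the real neuralforecast or PyTorch Lightning
# classes: they only mimic the keyword-forwarding pattern.


class Trainer:
    """Stand-in trainer that accepts a fixed set of keyword arguments."""

    def __init__(self, max_steps=100, enable_progress_bar=True):
        self.max_steps = max_steps
        self.enable_progress_bar = enable_progress_bar


class TFT:
    """Stand-in model: keyword arguments it does not name are forwarded to Trainer."""

    def __init__(self, h, input_size, hidden_size=128, **trainer_kwargs):
        self.h = h
        self.input_size = input_size
        self.hidden_size = hidden_size
        # An option this version does not know about lands in trainer_kwargs
        # and is passed straight through, so Trainer is the one that rejects it.
        self.trainer = Trainer(**trainer_kwargs)


message = None
try:
    TFT(h=12, input_size=48, hidden_size=20, grn_activation="ELU")
except TypeError as err:
    # On Python 3.10+ the message starts with "Trainer.__init__()",
    # the same shape as the error reported in this issue.
    message = str(err)

print(message)

# Without the unknown argument, construction succeeds.
model = TFT(h=12, input_size=48, hidden_size=20, max_steps=300)
```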

Versions / Dependencies

neuralforecast version 1.7.5
PyTorch Lightning version 2.4.0

Reproduction script

```python
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from neuralforecast import NeuralForecast
from neuralforecast.models import TFT
from neuralforecast.losses.pytorch import DistributionLoss
from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic

AirPassengersPanel['month'] = AirPassengersPanel.ds.dt.month
Y_train_df = AirPassengersPanel[AirPassengersPanel.ds < AirPassengersPanel['ds'].values[-12]]  # 132 train
Y_test_df = AirPassengersPanel[AirPassengersPanel.ds >= AirPassengersPanel['ds'].values[-12]].reset_index(drop=True)  # 12 test

nf = NeuralForecast(
    models=[
        TFT(
            h=12,
            input_size=48,
            hidden_size=20,
            grn_activation='ELU',
            loss=DistributionLoss(distribution='StudentT', level=[80, 90]),
            learning_rate=0.005,
            stat_exog_list=['airline1'],
            futr_exog_list=['y_[lag12]', 'month'],
            hist_exog_list=['trend'],
            max_steps=300,
            val_check_steps=10,
            early_stop_patience_steps=10,
            scaler_type='robust',
            windows_batch_size=None,
            enable_progress_bar=True,
        ),
    ],
    freq='M',
)
nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)
Y_hat_df = nf.predict(futr_df=Y_test_df)
```

Issue Severity

High: It blocks me from completing my task.

skuzmier added the bug label on Oct 16, 2024
@elephaint (Contributor)

Hi,

`grn_activation` isn't in a released version yet, but the docs were already updated, unfortunately. Just remove that argument.
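Removing the argument is the fix suggested above. If the same script has to run both on 1.7.5 and on a later release that does accept `grn_activation`, one option is to pass it only when the installed model's constructor names it explicitly. The sketch below assumes the argument will appear as a named constructor parameter once released (not confirmed in this thread), uses only the standard-library `inspect.signature`, and demonstrates a hypothetical helper `optional_kwargs` on stand-in classes `OldTFT` and `NewTFT` rather than the real `TFT`.

```python
import inspect


def optional_kwargs(cls, **candidates):
    """Keep only the candidate keyword arguments that cls.__init__ names explicitly.

    A name that is not listed would fall into a **kwargs catch-all instead
    (which, per the error in this issue, can end up at the trainer), so it is dropped.
    """
    params = inspect.signature(cls).parameters
    named = {
        name
        for name, p in params.items()
        if p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    }
    return {name: value for name, value in candidates.items() if name in named}


# Hypothetical stand-ins for an older and a newer model signature.
class OldTFT:
    def __init__(self, h, input_size, **trainer_kwargs):
        self.trainer_kwargs = trainer_kwargs


class NewTFT:
    def __init__(self, h, input_size, grn_activation="placeholder", **trainer_kwargs):
        self.grn_activation = grn_activation
        self.trainer_kwargs = trainer_kwargs


old_extra = optional_kwargs(OldTFT, grn_activation="ELU")  # argument not supported: dropped
new_extra = optional_kwargs(NewTFT, grn_activation="ELU")  # argument supported: kept

old_model = OldTFT(h=12, input_size=48, **old_extra)
new_model = NewTFT(h=12, input_size=48, **new_extra)
print(old_extra, new_extra)
```

With the real library this would be used as `TFT(h=12, input_size=48, **optional_kwargs(TFT, grn_activation='ELU'), ...)`. The helper checks only the specific new option rather than filtering every keyword argument against the signature, because legitimate trainer options may also travel through the catch-all and must not be dropped.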
