
Tips on speeding up training of the TFT #2344

Open
chododom opened this issue Apr 21, 2024 · 1 comment
Labels: bug (Something isn't working), triage (Issue waiting for triaging)

Comments

@chododom

Hi, I am comparing the predictions of a basic LSTM, a DeepAR, and a TemporalFusionTransformer on some multivariate sets of series.

I have around 4,000 different TimeSeries objects; put together, they amount to a few million data points.

For the LSTM and DeepAR, the training times are acceptable.
With the TFT, I am experiencing something really strange: when I train a model with 2 attention heads, 1 RNN layer, and a hidden dimension of 8, an epoch takes about an hour, but if I increase the hidden dimension from 8 to 10, an epoch is suddenly estimated to take 11 hours.

How is this possible? The model only has about 40k parameters according to the summary...

I am also using torch.set_float32_matmul_precision('medium') to try to speed up training, but I'm having absolutely no luck.

Any explanations regarding the complexity or tips on improving the computation speed would be very much welcome, thank you!
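For context, here is a minimal sketch of the setup described above, using darts' TFTModel; the chunk lengths and trainer settings are illustrative placeholders, only the hidden size, RNN layers and attention heads match what I mentioned:

```python
import torch
from darts.models import TFTModel

# Allow lower-precision matmuls (TF32/bfloat16) on supported GPUs.
torch.set_float32_matmul_precision("medium")

model = TFTModel(
    input_chunk_length=48,       # illustrative
    output_chunk_length=12,      # illustrative
    hidden_size=8,               # ~1 h/epoch; changing this to 10 -> estimated ~11 h/epoch
    lstm_layers=1,
    num_attention_heads=2,
    pl_trainer_kwargs={"accelerator": "gpu", "devices": 1},  # illustrative
)
```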

chododom added the bug (Something isn't working) and triage (Issue waiting for triaging) labels on Apr 21, 2024
@dennisbader (Collaborator)

Hi @chododom, TFTModel is a transformer-based model, so it is more complex and less efficient than the other models.
From my checks, increasing the hidden size from 16 to 32 (a factor of 2) increased the number of trainable parameters by a factor of 4.
So increasing the hidden size from 8 to 10 (a factor of 1.25) leading to an increase in training time by a factor of 11 does indeed sound strange (not saying yet that it is a bug, though).
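A rough sketch of how such a parameter-count check could be reproduced: darts builds the underlying torch module lazily, so this fits a single epoch on a tiny dummy series first. Accessing the module via `model.model`, and using `add_relative_index` to avoid passing future covariates, are assumptions about the current darts API and may differ between versions.

```python
import numpy as np
from darts import TimeSeries
from darts.models import TFTModel

def count_trainable_params(hidden_size: int) -> int:
    """Fit one epoch on a dummy series so the torch module gets built,
    then count its trainable parameters."""
    series = TimeSeries.from_values(np.random.randn(200).astype(np.float32))
    model = TFTModel(
        input_chunk_length=24,
        output_chunk_length=12,
        hidden_size=hidden_size,
        lstm_layers=1,
        num_attention_heads=2,
        n_epochs=1,
        add_relative_index=True,  # avoids having to provide future covariates
    )
    model.fit(series)
    # model.model is assumed to be the internal (Lightning) torch module
    return sum(p.numel() for p in model.model.parameters() if p.requires_grad)

for h in (8, 10, 16, 32):
    print(h, count_trainable_params(h))
```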

We would have to perform an in-depth analysis and profile the model to see whether this is normal. Currently, we don't have much capacity on our side for this as we're working on higher-priority tasks. So any help from the community would be greatly appreciated :)

Also, we have some additional recommendations for model performance in our user guide for torch models.
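For illustration, a sketch of the kind of knobs that guide covers (larger batch size, mixed precision via pl_trainer_kwargs, parallel data loading); `train_series` is a placeholder, and the exact argument names (e.g. `num_loader_workers`) may differ across darts/Lightning versions:

```python
from darts.models import TFTModel

model = TFTModel(
    input_chunk_length=48,
    output_chunk_length=12,
    hidden_size=8,
    lstm_layers=1,
    num_attention_heads=2,
    batch_size=256,                 # larger batches -> fewer optimizer steps per epoch
    pl_trainer_kwargs={
        "accelerator": "gpu",
        "devices": 1,
        "precision": "16-mixed",    # mixed precision, if supported by the GPU
    },
)

# Parallelize the DataLoader workers; name/placement of this option may vary by version.
model.fit(train_series, num_loader_workers=4)
```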
