
Feat/model deeptime #1329

Open
wants to merge 23 commits into master
Conversation

madtoinou
Collaborator

@madtoinou madtoinou commented Oct 31, 2022

Fixes #1152.

Summary

Implement the DeepTIMe model from https://arxiv.org/pdf/2207.06046.pdf, based on the original repository https://github.com/salesforce/DeepTime and the article pseudo-code.

Also implement some basic tests, inspired by the tests for N-BEATS.

Other Information

In the original article, distinct optimizers are defined for the three groups of parameters: the Ridge Regression regularization term, the bias/norm parameters of the Implicit Neural Representation (INR) network, and the weights of the INR. This was accomplished by overriding the configure_optimizers method, which partially breaks the logic behind the lr_scheduler_cls and lr_scheduler_kwargs arguments. To make the model easier to use out of the box, the default arguments correspond to the original article's parameters (including for the optimizer).
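For illustration, here is a minimal sketch of how such a parameter grouping could be built and returned from Lightning's configure_optimizers hook; the group names, learning rates and weight decays below are hypothetical, not the exact values used in this PR:

import torch
from torch import nn

def build_deeptime_optimizer(module: nn.Module) -> torch.optim.Optimizer:
    # Hypothetical split into three parameter groups: the ridge-regression
    # regularization term, the INR biases/normalization parameters, and the
    # remaining INR weights (the only group with weight decay).
    ridge_params, bias_norm_params, weight_params = [], [], []
    for name, param in module.named_parameters():
        if "ridge" in name:
            ridge_params.append(param)
        elif "bias" in name or "norm" in name:
            bias_norm_params.append(param)
        else:
            weight_params.append(param)
    return torch.optim.Adam([
        {"params": ridge_params, "lr": 1.0, "weight_decay": 0.0},
        {"params": bias_norm_params, "lr": 1e-3, "weight_decay": 0.0},
        {"params": weight_params, "lr": 1e-3, "weight_decay": 1e-4},
    ])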

All the modules necessary for this architecture were included in the same file to limit code fragmentation. The Ridge Regression and INR modules could, however, be extracted if other models require them.

Support for the nr_params functionality is not implemented yet.

@codecov-commenter

codecov-commenter commented Oct 31, 2022

Codecov Report

Base: 93.97% // Head: 94.03% // Increases project coverage by +0.05% 🎉

Coverage data is based on head (ea19348) compared to base (d712ce5).
Patch coverage: 96.95% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1329      +/-   ##
==========================================
+ Coverage   93.97%   94.03%   +0.05%     
==========================================
  Files          82       83       +1     
  Lines        8917     9102     +185     
==========================================
+ Hits         8380     8559     +179     
- Misses        537      543       +6     
Impacted Files Coverage Δ
darts/dataprocessing/transformers/scaler.py 97.56% <ø> (ø)
darts/utils/data/training_dataset.py 89.47% <ø> (ø)
darts/models/forecasting/deeptime.py 96.93% <96.93%> (ø)
...arts/models/forecasting/torch_forecasting_model.py 87.70% <100.00%> (-0.03%) ⬇️
darts/timeseries.py 91.94% <0.00%> (-0.06%) ⬇️
darts/models/forecasting/block_rnn_model.py 98.24% <0.00%> (-0.04%) ⬇️
darts/models/forecasting/nhits.py 99.27% <0.00%> (-0.01%) ⬇️


Contributor

@eliane-maalouf eliane-maalouf left a comment

A few comments on quick initial things I noticed.

Contributor

@hrzn hrzn left a comment

Looks very good, nice job @madtoinou !
After glancing at the paper, I also think it would be a nice addition to Darts.
I haven't looked into all the minute details of the processing being done, but I trust you :) I've got a few small comments. Perhaps the main one concerns nr_params, which we should try to exploit before we merge.

…_global_forecasting_models, reduced the number of epochs to the strict minimum, corrected typo in docstring, removed the mutable default argument, added check in TorchForecastingModel for n_epochs
…crease the length of the prediction from 2 to 3, the last version was relying on erroneous broadcasting
…eepTime, removed some comments in the forward method, corrected typo
@madtoinou madtoinou requested a review from hrzn November 17, 2022 09:07
@hrzn
Contributor

hrzn commented Nov 28, 2022

I find that the results do not seem fantastic in the probabilistic setting. E.g. when running the following code:

import numpy as np

from darts.datasets import AirPassengersDataset
from darts.dataprocessing.transformers import Scaler
from darts.models import DeepTimeModel
from darts.utils.likelihood_models import GaussianLikelihood

# Load and scale the series; hold out the last 36 points for validation
series = AirPassengersDataset().load().astype(np.float32)

scaler = Scaler()
train, val = scaler.fit_transform(series[:-36]), scaler.transform(series[-36:])

model = DeepTimeModel(input_chunk_length=24,
                      output_chunk_length=12,
                      likelihood=GaussianLikelihood())

model.fit(train, epochs=100)

# Sample 300 trajectories to inspect the predictive distribution
pred = model.predict(series=train, n=36, num_samples=300)
train.plot()
pred.plot()

I get this - the variance seems almost zero.
[image: forecast plot where the sampled trajectories show almost no variance]

I'm wondering whether this might be due to our treatment of the distribution parameters, which perhaps happens too early in the processing (when creating the time representations) and could (maybe?) cause degenerate results. Could we maybe find a way to "tile" the tensors somewhere later in the forward pass? WDYT @madtoinou ?

@madtoinou
Collaborator Author

This is indeed a bit disappointing; I should have spent more time looking at the variance of the resulting distribution.

There is not much room for tiling downstream: after the INR (fully connected network), there is only the ridge regression, which solves the equation AX = B, where A is the time representation transposed multiplied by itself and B is the time representation transposed multiplied by the observations. I don't see how we could tweak this part.
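For reference, a minimal sketch of that closed-form step (names and shapes are illustrative only; Phi stands for the INR time representation of the lookback window and Y for the corresponding observations):

import torch

def ridge_regression(Phi: torch.Tensor, Y: torch.Tensor, lam: float = 1.0) -> torch.Tensor:
    # Solve (Phi^T Phi + lam * I) W = Phi^T Y for the readout weights W.
    d = Phi.shape[-1]
    A = Phi.transpose(-1, -2) @ Phi + lam * torch.eye(d)  # A = Phi^T Phi + lam * I
    B = Phi.transpose(-1, -2) @ Y                         # B = Phi^T Y
    return torch.linalg.solve(A, B)

# Hypothetical shapes: 24 lookback steps, 64 INR features, 1 target component
Phi = torch.randn(24, 64)
Y = torch.randn(24, 1)
W = ridge_regression(Phi, Y)  # (64, 1) readout mapping future time representations to forecasts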

I am going to experiment with using different Fourier features for each distribution parameter (before the INR); it should add a bit of heterogeneity, but I am not sure it will solve the issue.
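Purely as an illustration of that idea (the actual feature construction in the PR may differ), one could draw an independent set of random Fourier frequencies per distribution parameter and stack the resulting time representations before feeding them to the INR:

import math

import torch

def fourier_features(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Map time coordinates t of shape (n, 1) to [sin, cos] features of shape (n, 2 * n_freqs).
    angles = 2 * math.pi * t * freqs
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

n_params = 2   # e.g. mean and scale of a Gaussian likelihood
n_freqs = 16
t = torch.linspace(0, 1, 24).unsqueeze(-1)  # normalized time coordinates of the lookback window

# One independent frequency draw per distribution parameter
reps = [fourier_features(t, torch.randn(1, n_freqs).abs()) for _ in range(n_params)]
inr_input = torch.stack(reps, dim=0)  # (n_params, 24, 2 * n_freqs)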

Development

Successfully merging this pull request may close these issues.

[new model] DeepTIMe
4 participants