Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto batch size for torch model #2318

Open
wants to merge 4 commits into
base: master
Choose a base branch
from

Conversation

BohdanBilonoh
Copy link
Contributor

@BohdanBilonoh BohdanBilonoh commented Apr 11, 2024

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Summary

Auto batch size finding for TorchForecastingModel. This is just a wrapper for lightning.pytorch.tuner.tuning.Tuner.scale_batch_size

Copy link

codecov bot commented Apr 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 94.02%. Comparing base (a0cc279) to head (35d1d58).

Current head 35d1d58 differs from pull request most recent head ae7a128

Please upload reports for the commit ae7a128 to get more accurate results.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2318      +/-   ##
==========================================
+ Coverage   93.75%   94.02%   +0.26%     
==========================================
  Files         138      138              
  Lines       14352    14152     -200     
==========================================
- Hits        13456    13306     -150     
+ Misses        896      846      -50     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

max_samples_per_ts: Optional[int] = None,
num_loader_workers: int = 0,
method: Literal["fit", "validate", "test", "predict"] = "fit",
mode: str = "power",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you not using Literal here? This variable can just be power or linear right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took it from lightning. I think this is motivated by fact that the mode potentially could be many more modes

Copy link
Collaborator

@dennisbader dennisbader left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for implementing this @BohdanBilonoh. It looks already pretty good 🚀
I added some suggestions on how we could simplify things a bit, and also to support "predict" method.

Comment on lines 1258 to 1260
epochs: int = 0,
max_samples_per_ts: Optional[int] = None,
num_loader_workers: int = 0,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are not required I guess

Suggested change
epochs: int = 0,
max_samples_per_ts: Optional[int] = None,
num_loader_workers: int = 0,

epochs: int = 0,
max_samples_per_ts: Optional[int] = None,
num_loader_workers: int = 0,
method: Literal["fit", "validate", "test", "predict"] = "fit",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"test" is not supported for darts.
"predict" would require a datamodule for prediction

in my opinion, we should only support "fit" and "predict", since we use the same batch size for train and val.

Comment on lines 1318 to 1319
batch_arg_name
The name of the argument to scale in the model. Defaults to 'batch_size'.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not required

Suggested change
batch_arg_name
The name of the argument to scale in the model. Defaults to 'batch_size'.

self.batch_size = batch_size

def train_dataloader(self):
return DataLoader(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we use this also in _setup_for_train and predict_from_dataset, it would be good to have this logic in a private method for example _build_dataloader(), that takes as input a dataset, and returns the dataloader according to "train", "val" and "predict"

def scale_batch_size(
self,
series: Union[TimeSeries, Sequence[TimeSeries]],
val_series: Union[TimeSeries, Sequence[TimeSeries]],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can remove all val_* arguments, since we use the same batch size for training and evaluation.

For the val_dataloader in the DataModule, we can just use series, past_covariates, future_covariates for the input dataset

@@ -1373,6 +1373,21 @@ def test_lr_find(self):
)
assert scores["worst"] > scores["suggested"]

@pytest.mark.slow
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not slow

Suggested change
@pytest.mark.slow

@@ -1373,6 +1373,21 @@ def test_lr_find(self):
)
assert scores["worst"] > scores["suggested"]

@pytest.mark.slow
def test_scale_batch_size(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test this for method="fit" and "predict"

CHANGELOG.md Outdated
@@ -89,6 +89,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Renamed the private `_is_probabilistic` property to a public `supports_probabilistic_prediction`.
- Improvements to `DataTransformer`: [#2267](https://github.com/unit8co/darts/pull/2267) by [Alicja Krzeminska-Sciga](https://github.com/alicjakrzeminska).
- `InvertibleDataTransformer` now supports parallelized inverse transformation for `series` being a list of lists of `TimeSeries` (`Sequence[Sequence[TimeSeries]]`). This `series` type represents for example the output from `historical_forecasts()` when using multiple series.
- New method `TorchForecastingModel.scale_batch_size()` that helps to find batch size automatically. [#2318](https://github.com/unit8co/darts/pull/2318) by [Bohdan Bilonoh](https://github.com/BohdanBilonoh)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- New method `TorchForecastingModel.scale_batch_size()` that helps to find batch size automatically. [#2318](https://github.com/unit8co/darts/pull/2318) by [Bohdan Bilonoh](https://github.com/BohdanBilonoh)
- Improvements to `TorchForecastingModel`:
- New method `TorchForecastingModel.scale_batch_size()` to find the maximum batch size for fit and predict before memory would run out. [#2318](https://github.com/unit8co/darts/pull/2318) by [Bohdan Bilonoh](https://github.com/BohdanBilonoh)

@BohdanBilonoh
Copy link
Contributor Author

@dennisbader could you please help with predict mode. It is tricky because TorchForecastingModel requires set_predict_parameters to be called before predict. Unfortunately I don't know how to make it works

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants