
When doing tuner.scale_batch_size, check full dataset length first #19850

Open
fingoldo opened this issue May 6, 2024 · 0 comments
Labels: feature, needs triage


fingoldo commented May 6, 2024

Description & Motivation

Currently, on a relatively small tabular dataset (3,366,292 rows by 62 columns, RAM footprint < 500 MB), when searching for the optimal batch size with

```python
tuner = Tuner(trainer)
# Auto-scale batch size with binary search
tuner.scale_batch_size(model=model, datamodule=dm, mode="binsearch")
```

I have to wait ~10 minutes until it figures out that `The batch size 4194304 is greater or equal than the length of your dataset.`

Pitch

I'd like an option to start the search backwards, from the length of the entire dataset down to one.
Or at least I'd like each algorithm to first check the full dataset length (this could be gated behind a flag defaulting to False, to protect people training on terabytes of images, etc.); if that batch size succeeds, don't proceed with anything else. A rough sketch of this pre-check is below.
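As a rough illustration only: `init_val` and `max_trials` are real `scale_batch_size` arguments, but whether a failed full-length trial escapes as an exception (rather than being handled inside the tuner) depends on tuner internals, and the wrapper name is mine, so treat this as a sketch rather than working Lightning code:

```python
from lightning.pytorch.tuner import Tuner

def scale_batch_size_top_down(trainer, model, dm):
    """Sketch of the proposed pre-check: try batch_size == len(dataset) first."""
    tuner = Tuner(trainer)
    full_len = len(dm.train_dataloader().dataset)
    try:
        # One trial starting at the full dataset length; if it fits in memory,
        # the search is over immediately.
        return tuner.scale_batch_size(model=model, datamodule=dm,
                                      mode="binsearch", init_val=full_len,
                                      max_trials=1)
    except RuntimeError:
        # Full-dataset batch did not fit (e.g. CUDA OOM); fall back to the
        # standard bottom-up binary search.
        return tuner.scale_batch_size(model=model, datamodule=dm, mode="binsearch")
```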

Alternatives

Waiting for 10 minutes :)
I just realized that implementing the `__getitems__` method dropped the running time from 10 minutes to 1.2 minutes in my case, but still.

Additional context

Possibly related is a proposal for the dataloader to be able to skip sampling from the underlying dataset and return all samples at once when batch_size equals or exceeds the dataset length. I found a mention of the `__getitems__` function in the Dataset class and confirmed it's really used. But please confirm/refute whether the indices sent to `__getitems__` are just `range(len(ds))` when batch_size equals or exceeds the dataset length. I'm currently using this workaround:


```python
# method on my Dataset subclass (List comes from typing)
def __getitems__(self, indices: List[int]):
    # Fast path: the batch spans the whole dataset, so return everything
    # at once instead of fancy-indexing per sample.
    if len(indices) == len(self.labels):
        return self.features, self.labels
    return self.features[indices], self.labels[indices]
```

Note that I also have to pass `collate_fn=lambda x: x` to the dataloader for `__getitems__` to work; maybe that's worth reflecting in the Dataset class comments.
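For reference, a minimal sketch of the wiring this refers to, assuming `dataset` is the tabular Dataset with the `__getitems__` workaround above; since `__getitems__` already returns a collated `(features, labels)` pair, the identity collate just passes it through:

```python
from torch.utils.data import DataLoader

# The default collate_fn would try to re-collate the already-batched tensors,
# so it is replaced with an identity function.
loader = DataLoader(dataset, batch_size=len(dataset), collate_fn=lambda x: x)
```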

PS: Oh, I can see the indices are indeed `range(len(ds))`. Then a further optimization is possible: send `None` in such cases so the dataset can return everything, but that would need documenting.
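What that convention could look like on the dataset side (purely hypothetical; a `None` index batch is not something PyTorch's DataLoader sends today):

```python
def __getitems__(self, indices):
    # Hypothetical convention: None means "return the whole dataset",
    # skipping both index generation and fancy indexing.
    if indices is None:
        return self.features, self.labels
    return self.features[indices], self.labels[indices]
```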

cc @Borda
