torch IterableDataset are not fully supported #594

Open
ottonemo opened this issue Feb 20, 2020 · 3 comments

@ottonemo
Member

PyTorch 1.2.0 introduced IterableDataset, which implements only __iter__: it has no __len__ and certainly no __getitem__. This is a problem because we use Subset to split the input dataset. Subset wraps the original dataset, provides __getitem__, and delegates each call to the wrapped dataset's __getitem__, which an IterableDataset does not implement because it is iterable-only.
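A minimal reproduction of the failure, assuming a recent PyTorch. The toy Stream class is hypothetical; Subset and IterableDataset are the real torch.utils.data classes. Iterating the stream works, while indexing through Subset falls through to the base Dataset.__getitem__, which raises NotImplementedError.

```python
from torch.utils.data import IterableDataset, Subset


class Stream(IterableDataset):
    """Hypothetical iterable-style dataset: implements only __iter__."""

    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        return iter(range(self.n))


stream = Stream(10)
print(list(stream))  # plain iteration works

# Subset forwards subset[i] to stream[indices[i]], i.e. to __getitem__.
subset = Subset(stream, indices=[0, 2, 4])
try:
    subset[0]
    failed = False
except NotImplementedError:
    # The base Dataset.__getitem__ inherited by IterableDataset raises this
    failed = True
print(failed)
```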

The simplest solution to this is to not split IterableDataset in any way. What do you think?

@ottonemo ottonemo added the bug label Feb 20, 2020
@ottonemo ottonemo added this to To do in 0.8.0 via automation Feb 20, 2020
@BenjaminBossan
Collaborator

Didn't know about this one. I always wondered whether Datasets really needed __getitem__; this answers the question :)

Splitting the way skorch (or rather, sklearn) does it cannot easily be supported with IterableDataset. The __len__ part would be okay, since our Dataset supports passing the length explicitly. For a train/valid split, however, a user would currently need to define two separate datasets up front.
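A sketch of the "define two datasets up front" workaround, using only torch. ArrayStream is a hypothetical helper, and the 80/20 cut is chosen by the user rather than by a splitter. Defining __len__ is optional for an IterableDataset; it is included here only to show that the length is not the obstacle.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class ArrayStream(IterableDataset):
    """Hypothetical helper: streams (x, y) pairs from in-memory tensors."""

    def __init__(self, X, y):
        super().__init__()
        self.X, self.y = X, y

    def __len__(self):
        # Optional for an IterableDataset; lets len(DataLoader) work
        return len(self.X)

    def __iter__(self):
        # A fresh iterator on every call, so each epoch sees every row
        return zip(self.X, self.y)


X = torch.randn(100, 5)
y = torch.randint(0, 2, (100,))

# The split is decided by the user, not by the training library.
train_ds = ArrayStream(X[:80], y[:80])
valid_ds = ArrayStream(X[80:], y[80:])

train_loader = DataLoader(train_ds, batch_size=16)
valid_loader = DataLoader(valid_ds, batch_size=16)
n_train = sum(len(xb) for xb, _ in train_loader)
n_valid = sum(len(xb) for xb, _ in valid_loader)
print(n_train, n_valid)  # 80 20
```

Nothing here stratifies or shuffles; any such logic has to be applied before the two streams are built.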

We could think about a wrapper class that allows splitting an IterableDataset by using every n-th element for validation, but stratified, group-based, or predefined splits, for example, wouldn't work.
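A minimal sketch of such a wrapper, assuming the underlying stream can be iterated more than once. EveryNthSplit is a hypothetical name, not an existing skorch or torch class. The split is purely positional, which is why stratified or group-based splits cannot be expressed this way.

```python
from torch.utils.data import IterableDataset


class EveryNthSplit(IterableDataset):
    """Hypothetical wrapper: deterministic train/valid split of a stream.

    Elements at positions 0, n, 2n, ... form the validation part;
    all other elements form the training part.
    """

    def __init__(self, dataset, n, valid):
        super().__init__()
        if n < 2:
            raise ValueError("n must be at least 2")
        self.dataset = dataset
        self.n = n
        self.valid = valid

    def __iter__(self):
        for i, item in enumerate(self.dataset):
            is_valid = i % self.n == 0
            if is_valid == self.valid:
                yield item


class Stream(IterableDataset):
    """Toy stream of the integers 0..n-1."""

    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        return iter(range(self.n))


stream = Stream(10)
train_part = EveryNthSplit(stream, n=5, valid=False)
valid_part = EveryNthSplit(stream, n=5, valid=True)
print(list(valid_part))  # [0, 5]
print(list(train_part))  # [1, 2, 3, 4, 6, 7, 8, 9]
```

Both parts re-iterate the underlying stream, so an expensive source (files, network) is read twice per epoch.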

@thomasjpfan
Member

> Didn't know about this one. I always wondered whether Datasets really needed __getitem__; this answers the question :)

Did you mean __len__? :)

> We could think about a wrapper class that allows splitting an IterableDataset by using every n-th element for validation, but stratified, group-based, or predefined splits, for example, wouldn't work.

From the perspective of an IterableDataset user, I don't think there is a sensible default for splitting the data.

> The simplest solution to this is to not split IterableDataset in any way. What do you think?

I agree.

Is there currently an issue with passing an IterableDataset directly to fit? Something like this works (a little hacky?):

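import numpy as np
import torch
from sklearn.datasets import make_classification
from skorch import NeuralNetClassifier


class MyDataset(torch.utils.data.IterableDataset):
    """Streams (X[i], y[i]) pairs; implements __iter__ only."""

    def __init__(self, X, y):
        super().__init__()
        self.X = X
        self.y = y

    def __iter__(self):
        # Return a fresh generator on every call so that each epoch
        # starts again at the first sample and walks through all of them.
        for xi, yi in zip(self.X, self.y):
            yield xi, yi


X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
dataset = MyDataset(X, y)

# ClassifierModule: a torch.nn.Module defined elsewhere (not shown here)
net = NeuralNetClassifier(ClassifierModule, train_split=None)
net.fit(dataset, y=None)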
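Because skorch hands the dataset to a torch DataLoader by default, one way to sanity-check an iterable dataset outside of skorch is to count how many samples a DataLoader yields on consecutive passes. A sketch, assuming random stand-in data instead of make_classification: a dataset whose iterator keeps its position in an instance attribute between calls would come up short on the second pass, while returning a fresh generator from __iter__ yields all 1000 samples every time.

```python
import numpy as np
from torch.utils.data import DataLoader, IterableDataset


class MyDataset(IterableDataset):
    def __init__(self, X, y):
        super().__init__()
        self.X = X
        self.y = y

    def __iter__(self):
        # A new generator per call: each pass starts from the first sample
        for xi, yi in zip(self.X, self.y):
            yield xi, yi


rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 20)).astype(np.float32)
y = rng.integers(0, 2, size=1000)
loader = DataLoader(MyDataset(X, y), batch_size=128)

# Two passes over the loader stand in for two training epochs
counts = [sum(len(xb) for xb, _ in loader) for _ in range(2)]
print(counts)  # [1000, 1000]
```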

Moving forward, should we raise an error when train_split is not None and an IterableDataset is passed in?
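One possible shape for that check, written as a hypothetical standalone function (check_train_split is not part of skorch's API): it rejects the combination up front instead of letting Subset fail later with a less helpful error.

```python
from torch.utils.data import IterableDataset


def check_train_split(X, train_split):
    """Hypothetical validation step run at the start of fit.

    Index-based splitting (Subset, sklearn CV splitters) needs
    __getitem__ and usually __len__, which an IterableDataset does not
    provide, so refuse to split it rather than fail inside Subset.
    """
    if isinstance(X, IterableDataset) and train_split is not None:
        raise ValueError(
            "Cannot split an IterableDataset; pass train_split=None and "
            "supply validation data separately."
        )


class Stream(IterableDataset):
    def __iter__(self):
        return iter(range(3))


check_train_split(Stream(), train_split=None)  # no error: nothing to split
try:
    check_train_split(Stream(), train_split=object())  # any non-None splitter
    rejected = False
except ValueError as exc:
    rejected = True
    print(exc)
```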

@BenjaminBossan
Collaborator

> Did you mean __len__? :)

I meant __getitem__, but my sentence was not very clear. What I wanted to say is that I had wondered why torch's Dataset was not implemented as an iterable instead of relying on __getitem__ to access its members. I had concluded that there was probably a technical reason for it, but the existence of IterableDataset shows that __getitem__ is not strictly necessary (though still helpful in some situations).
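The distinction can be sketched with two toy datasets (MapStyle and IterStyle are hypothetical names). A DataLoader consumes both sequentially and produces the same batches, but shuffling relies on index-based sampling, and the DataLoader rejects shuffle=True for an IterableDataset.

```python
from torch.utils.data import DataLoader, Dataset, IterableDataset


class MapStyle(Dataset):
    """Random access via __getitem__ plus __len__."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return idx


class IterStyle(IterableDataset):
    """Sequential access via __iter__ only."""

    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        return iter(range(self.n))


# Both can be consumed sequentially by a DataLoader.
map_batches = [b.tolist() for b in DataLoader(MapStyle(6), batch_size=4)]
iter_batches = [b.tolist() for b in DataLoader(IterStyle(6), batch_size=4)]
print(map_batches == iter_batches)  # True

# Shuffling needs index-based sampling, i.e. __getitem__ and __len__.
try:
    DataLoader(IterStyle(6), batch_size=4, shuffle=True)
    shuffle_rejected = False
except ValueError as exc:
    shuffle_rejected = True
    print("shuffle rejected:", exc)
```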

> From the perspective of an IterableDataset user, I don't think there is a sensible default for splitting the data.

> The simplest solution to this is to not split IterableDataset in any way. What do you think?

> I agree.

I agree with both.

> Something like this works (a little hacky?)

I don't think it's too hacky. Maybe this could be added to helper.py?

> Moving forward, should we raise an error when train_split is not None and an IterableDataset is passed in?

I agree with this too.
