Allow passing '**predict_params' to the 'predict' method, as for the 'fit' method #1042

Open
corradomio opened this issue Jan 26, 2024 · 2 comments

corradomio commented Jan 26, 2024

It would be useful to be able to pass custom 'predict params' to the 'predict' method, just as fit params can already be passed to the 'fit' method.

Starting from 'predict', only 4 skorch methods need to change: 'predict', 'predict_proba', 'forward_iter' and 'evaluation_step'.
Here is a simple implementation I wrote:

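import numpy as np
import torch
import skorch
from skorch.dataset import unpack_data
from skorch.utils import to_device, to_numpy


class NeuralNetRegressor(skorch.NeuralNetRegressor):
    """NeuralNetRegressor whose predict/predict_proba accept **predict_params.

    The extra keyword arguments are forwarded through forward_iter and
    evaluation_step down to infer, which passes them to the module's forward().
    """

    def __init__(
            self,
            module,
            *args,
            criterion=torch.nn.MSELoss,
            **kwargs
    ):
        super().__init__(
            module,
            *args,
            criterion=criterion,
            **kwargs
        )

    def predict(self, X, **predict_params):
        # For regression, predict simply returns predict_proba's output
        return self.predict_proba(X, **predict_params)

    def predict_proba(self, X, **predict_params):
        nonlin = self._get_predict_nonlinearity()
        y_probas = []
        for yp in self.forward_iter(X, training=False, **predict_params):
            # If the module returns a tuple, only the first element is the prediction
            yp = yp[0] if isinstance(yp, tuple) else yp
            yp = nonlin(yp)
            y_probas.append(to_numpy(yp))
        y_proba = np.concatenate(y_probas, 0)
        return y_proba

    def forward_iter(self, X, training=False, device='cpu', **params):
        dataset = self.get_dataset(X)
        iterator = self.get_iterator(dataset, training=training)
        for batch in iterator:
            # Forward the extra params to every evaluation step
            yp = self.evaluation_step(batch, training=training, **params)
            yield to_device(yp, device=device)

    def evaluation_step(self, batch, training=False, **eval_params):
        self.check_is_fitted()
        Xi, _ = unpack_data(batch)
        with torch.set_grad_enabled(training):
            self._set_training(training)
            # infer() passes the extra keyword arguments on to module_.forward()
            return self.infer(Xi, **eval_params)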
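To show how the keyword arguments travel through the four overrides without needing skorch or torch installed, here is a simplified, dependency-free stand-in. `ToyNet`, `forward` and `scale` are hypothetical names. The class only mimics the call chain (mini-batching, then calling the module) and skips nonlinearities, devices and tensors.

```python
from typing import Any, Callable, Iterator, List, Sequence


class ToyNet:
    """Hypothetical stand-in mimicking the override chain:
    predict -> predict_proba -> forward_iter -> evaluation_step -> infer.
    Extra keyword arguments are forwarded at every hop to the module."""

    def __init__(self, module: Callable[..., List[float]], batch_size: int = 2):
        self.module = module
        self.batch_size = batch_size

    def predict(self, X: Sequence[float], **predict_params: Any) -> List[float]:
        return self.predict_proba(X, **predict_params)

    def predict_proba(self, X: Sequence[float], **predict_params: Any) -> List[float]:
        y: List[float] = []
        for yp in self.forward_iter(X, training=False, **predict_params):
            y.extend(yp)  # plays the role of np.concatenate
        return y

    def forward_iter(self, X: Sequence[float], training: bool = False,
                     **params: Any) -> Iterator[List[float]]:
        # Split X into mini-batches, as skorch's iterator would
        for start in range(0, len(X), self.batch_size):
            batch = X[start:start + self.batch_size]
            yield self.evaluation_step(batch, training=training, **params)

    def evaluation_step(self, batch: Sequence[float], training: bool = False,
                        **eval_params: Any) -> List[float]:
        return self.infer(batch, **eval_params)

    def infer(self, x: Sequence[float], **params: Any) -> List[float]:
        # The extra params finally reach the module's forward as kwargs
        return self.module(x, **params)


def forward(x: Sequence[float], scale: float = 1.0) -> List[float]:
    # Hypothetical module forward with an extra keyword argument
    return [v * scale for v in x]


net = ToyNet(forward)
print(net.predict([1.0, 2.0, 3.0]))              # [1.0, 2.0, 3.0]
print(net.predict([1.0, 2.0, 3.0], scale=10.0))  # [10.0, 20.0, 30.0]
```

With the real subclass, the equivalent call would presumably be `net.predict(X, scale=10.0)` on a fitted net whose module's `forward` accepts a `scale` keyword argument.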

@githubnemo
Contributor

Hey! Thanks for the suggestion.

Indeed, this is not symmetrical with .fit().
There are sklearn classifiers and transformers that accept additional parameters at prediction time, so I don't see an immediate reason against it.

Are you interested in working on this and submitting a PR? :)

@ramonamezquita

I believe this can already be achieved by making X a dict and using a proper collate function.
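A minimal dependency-free sketch of that alternative, assuming that skorch passes the keys of a dict-valued X to the module's forward() as keyword arguments, and that a custom collate function would be supplied to the prediction DataLoader (for example via something like `iterator_valid__collate_fn`). `collate_with_constant`, `forward` and `scale` are hypothetical names, and a real collate function would stack the per-sample values into tensors rather than lists.

```python
from typing import Any, Dict, List, Sequence


def collate_with_constant(samples: List[Dict[str, Any]],
                          constant_keys: Sequence[str] = ("scale",)) -> Dict[str, Any]:
    """Merge per-sample dicts into one batch dict (hypothetical helper).

    Keys in constant_keys are assumed to carry the same value for every
    sample, so only the first one is kept; all other keys are gathered into
    lists (a real collate_fn would stack them into tensors).
    """
    batch: Dict[str, Any] = {}
    for key in samples[0]:
        values = [sample[key] for sample in samples]
        batch[key] = values[0] if key in constant_keys else values
    return batch


def forward(x: List[float], scale: float = 1.0) -> List[float]:
    # Hypothetical module forward: the dict keys become keyword arguments
    return [v * scale for v in x]


# Each sample of a dict-valued X carries its own copy of the predict param
samples = [{"x": 1.0, "scale": 2.0}, {"x": 3.0, "scale": 2.0}]
batch = collate_with_constant(samples)
print(batch)             # {'x': [1.0, 3.0], 'scale': 2.0}
print(forward(**batch))  # [2.0, 6.0]
```

The trade-off is that the predict param presumably has to be repeated for every sample of X so that it can be indexed per sample, whereas the explicit '**predict_params' route proposed above passes it once per call.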
