Using Skorch NeuralNetClassifier as booster for XGBoost #789

Open
francescamanni1989 opened this issue Jul 5, 2021 · 2 comments

@francescamanni1989

I couldn't re-open issue #737. Do you have any suggestions on how to implement XGBoost-style boosting with NeuralNetClassifier as the weak learner instead of trees?

Thanks in advance.

@BenjaminBossan
Collaborator

You probably meant issue #787

For sklearn's GradientBoostingClassifier, it doesn't look like you could easily replace the base estimator, since it is hard-coded: https://github.com/scikit-learn/scikit-learn/blob/2beed55847ee70d363bdbfe14ee4401438fba057/sklearn/ensemble/_gb.py#L195-L207

Theoretically, you could override this method and perhaps be successful, but my guess is that this is going to be very hard since this was not intended by the authors, and it could easily break in new sklearn versions. For XGBoost, this is going to be even more difficult, since significant parts of it are written in C/C++.

I believe your best shot to get something similar would be to use skorch with AdaBoostClassifier as described here.

@thomasjpfan
Member

There is a discussion in scikit-learn (scikit-learn/scikit-learn#17660) about extending GradientBoosting* to accept any base estimator.

Regarding gradient boosting + neural networks, there is recent research on this topic: https://arxiv.org/pdf/2002.07971.pdf (GrowNet).
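
To make the "any base estimator" idea concrete, here is a minimal sketch of gradient boosting for squared-error regression, where the weak learner is any object with `fit(X, y)` and `predict(X)`. It is not how sklearn, XGBoost, or GrowNet implement it. The names `LineLearner` and `SimpleGradientBooster` are hypothetical, and the one-dimensional least-squares line is only a stand-in for a small neural network.

```python
import copy


class LineLearner:
    """Hypothetical stand-in weak learner: a 1-D least-squares line."""

    def fit(self, X, y):
        n = len(X)
        mean_x = sum(X) / n
        mean_y = sum(y) / n
        var_x = sum((x - mean_x) ** 2 for x in X)
        cov_xy = sum((x - mean_x) * (t - mean_y) for x, t in zip(X, y))
        self.slope = cov_xy / var_x if var_x > 0 else 0.0
        self.intercept = mean_y - self.slope * mean_x
        return self

    def predict(self, X):
        return [self.intercept + self.slope * x for x in X]


class SimpleGradientBooster:
    """Hypothetical booster: each stage fits a fresh learner to the residuals."""

    def __init__(self, base_learner, n_stages=50, learning_rate=0.3):
        self.base_learner = base_learner
        self.n_stages = n_stages
        self.learning_rate = learning_rate

    def fit(self, X, y):
        # Start from the constant prediction that minimizes squared error.
        self.init_ = sum(y) / len(y)
        self.learners_ = []
        pred = [self.init_] * len(y)
        for _ in range(self.n_stages):
            # For squared error, the negative gradient is the residual.
            residuals = [t - p for t, p in zip(y, pred)]
            learner = copy.deepcopy(self.base_learner).fit(X, residuals)
            update = learner.predict(X)
            pred = [p + self.learning_rate * u for p, u in zip(pred, update)]
            self.learners_.append(learner)
        return self

    def predict(self, X):
        pred = [self.init_] * len(X)
        for learner in self.learners_:
            update = learner.predict(X)
            pred = [p + self.learning_rate * u for p, u in zip(pred, update)]
        return pred


def mse(y_true, y_pred):
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


X = [0.0, 1.0, 2.0, 3.0, 4.0]
y = [1.0, 3.0, 5.0, 7.0, 9.0]
model = SimpleGradientBooster(LineLearner()).fit(X, y)
baseline_mse = mse(y, [sum(y) / len(y)] * len(y))
train_mse = mse(y, model.predict(X))
```

The booster only relies on the `fit`/`predict` duck type, so swapping in a different learner would not change the loop. Classification would additionally need a differentiable loss and its gradient in place of the plain residual.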
