-
Notifications
You must be signed in to change notification settings - Fork 479
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Restrict PyTorch version to <2.3.0 to resolve import error #577
Conversation
Addressing the import error encountered with PyTorch 2.3.0 as detailed in issue bigscience-workshop#576. fixes bigscience-workshop#576
Tested this PR locally, seems to have fixed the issue for me! 😁 |
Yep! Even I have tested in my local. It's working fine. |
setup.cfg
Outdated
@@ -32,7 +32,7 @@ package_dir = | |||
packages = find: | |||
python_requires = >=3.8 | |||
install_requires = | |||
torch>=1.12 | |||
torch==2.2.2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution! The requirement seems to be overly restrictive though (it will force users with lower versions of PyTorch to update), can you change it to inequality constraints?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, that's right! Updated PyTorch version constraint in the cfg to torch>=1.12,<2.3.0 to have compatibility with earlier versions while avoiding the issue introduced in 2.3.0. Please review the updated pull request.
Modified the version constraint for PyTorch in setup.cfg to torch>=1.12,<2.3.0 to avoid the import errors introduced in version 2.3.0 while still supporting earlier compatible versions. This change follows feedback from @mryab to allow flexibility for users on different versions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modified the version constraint for PyTorch in setup.cfg to torch>=1.12,<2.3.0 to avoid the import errors introduced in version 2.3.0 while still supporting earlier compatible versions. This change follows feedback from @mryab to allow flexibility for users on different versions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for the contribution!
Fixes #576
This addresses the import error encountered with PyTorch 2.3.0 as detailed in issue #576. The error 'cannot import name '_refresh_per_optimizer_state' from 'torch.cuda.amp.grad_scaler' is resolved by pinning the PyTorch version to 2.2.2 in setup.cfg.
Small change:
setup.cfg
to specifytorch==2.2.2
underinstall_requires
.