
Restrict PyTorch version to <2.3.0 to resolve import error #577

Merged
merged 2 commits into from
Apr 27, 2024

Conversation

Priyanshupareek (Contributor)
Fixes #576
This addresses the import error encountered with PyTorch 2.3.0, as detailed in issue #576. The error "cannot import name '_refresh_per_optimizer_state' from 'torch.cuda.amp.grad_scaler'" is resolved by pinning the PyTorch version to 2.2.2 in setup.cfg.

Small change:

  • Updated setup.cfg to specify torch==2.2.2 under install_requires.

Addressing the import error encountered with PyTorch 2.3.0, as detailed in issue bigscience-workshop#576.
Fixes bigscience-workshop#576
@nsarrazin

Tested this PR locally, seems to have fixed the issue for me! 😁

@CherukupalleNaveen

Yep! I've also tested it locally, and it's working fine.

setup.cfg Outdated
@@ -32,7 +32,7 @@ package_dir =
 packages = find:
 python_requires = >=3.8
 install_requires =
-    torch>=1.12
+    torch==2.2.2
Member
Thanks for the contribution! The requirement seems overly restrictive, though: it will force users on lower PyTorch versions to upgrade. Can you change it to inequality constraints instead?

Contributor Author

Thank you, that's right! I've updated the PyTorch version constraint in setup.cfg to torch>=1.12,<2.3.0, keeping compatibility with earlier versions while avoiding the issue introduced in 2.3.0. Please review the updated pull request.
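For reference, the updated install_requires section in setup.cfg would then look roughly like this (a sketch of the final state described above, not the exact merged hunk):

```
install_requires =
    torch>=1.12,<2.3.0
```

With this range, pip keeps accepting any already-installed PyTorch from 1.12 up to 2.2.x, and only refuses 2.3.0 and later.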

Modified the version constraint for PyTorch in setup.cfg to torch>=1.12,<2.3.0 to avoid the import errors introduced in version 2.3.0 while still supporting earlier compatible versions. This change follows feedback from @mryab to allow flexibility for users on different versions.
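To illustrate what the range constraint torch>=1.12,<2.3.0 accepts and rejects, here is a small sketch. Note that real pip resolves versions with full PEP 440 semantics (pre-releases, local versions, etc.), so this numeric-only parser is a simplification for illustration:

```python
def parse(version: str) -> tuple:
    """Parse a plain numeric version like '2.2.2' into a comparable tuple.
    (Simplified: does not handle pre-release tags such as '2.3.0a0'.)"""
    return tuple(int(part) for part in version.split("."))

def satisfies(version: str, low: str = "1.12", high: str = "2.3.0") -> bool:
    """Return True if version falls in the half-open range [low, high)."""
    return parse(low) <= parse(version) < parse(high)

print(satisfies("2.2.2"))   # True: last release before the breaking change
print(satisfies("2.3.0"))   # False: excluded, the import error occurs here
print(satisfies("1.13.1"))  # True: older versions remain supported
```

The half-open range mirrors how pip treats `>=1.12,<2.3.0`: the lower bound is inclusive and the upper bound is exclusive, which is exactly what is needed to shut out 2.3.0 while leaving every earlier supported release installable.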
@mryab (Member) left a comment
Thanks a lot for the contribution!

@mryab mryab changed the title Pin PyTorch version to 2.2.2 to resolve import error Restrict PyTorch version to <2.3.0 to resolve import error Apr 27, 2024
@mryab mryab merged commit e268c99 into bigscience-workshop:main Apr 27, 2024
11 checks passed
@Priyanshupareek Priyanshupareek deleted the patch-1 branch April 28, 2024 19:56
Successfully merging this pull request may close these issues.

Error with PyTorch 2.3.0: Missing '_refresh_per_optimizer_state' in 'torch.cuda.amp.grad_scaler'
4 participants