Skip to content

Commit

Permalink
Pin nvidia-cudnn-cu12==8.9.7.29
Browse files Browse the repository at this point in the history
cudnn 9 is not uploaded to pypi due to pypi restrictions
  • Loading branch information
michelle-yooh committed Mar 27, 2024
1 parent 4e4d7f5 commit fc15c4b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
nvidia-cudnn-cu12==8.9.7.29
jax>=0.4.23
jaxlib>=0.4.23
orbax-checkpoint>=0.5.5
Expand Down
4 changes: 2 additions & 2 deletions setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
echo "Installing stable jax, jaxlib for NVIDIA gpu"
if [[ -n "$JAX_VERSION" ]]; then
echo "Installing stable jax, jaxlib ${JAX_VERSION}"
pip3 install -U "jax[cuda12_pip]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip3 install -U "jax[cuda12_local]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
else
echo "Installing stable jax, jaxlib, libtpu for NVIDIA gpu"
pip3 install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip3 install -U "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
fi
fi
elif [[ $MODE == "nightly" ]]; then
Expand Down

0 comments on commit fc15c4b

Please sign in to comment.