Skip to content

Commit

Permalink
Constraint dependency versions
Browse files Browse the repository at this point in the history
Workaround of google#516

Also pin other dependencies for mostly reproducible container build
  • Loading branch information
chajath committed Mar 14, 2024
1 parent 1046886 commit b259042
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 5 deletions.
1 change: 1 addition & 0 deletions .dockerignore
133 changes: 133 additions & 0 deletions constraints.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
absl-py==1.4.0
aqtp==0.6.1
array-record==0.5.0
astroid==3.1.0
astunparse==1.6.3
attrs==23.2.0
cachetools==5.3.3
certifi==2024.2.2
charset-normalizer==3.3.2
chex==0.1.85
click==8.1.7
cloud-tpu-diagnostics==0.1.5
cloudpickle==3.0.0
contextlib2==21.6.0
dill==0.3.8
dm-tree==0.1.8
etils==1.7.0
exceptiongroup==1.2.0
flatbuffers==24.3.7
flax==0.8.1
fsspec==2024.2.0
gast==0.4.0
google-api-core==2.17.1
google-auth==2.28.2
google-auth-oauthlib==1.0.0
google-cloud-core==2.4.1
google-cloud-storage==2.15.0
google-crc32c==1.5.0
google-pasta==0.2.0
google-resumable-media==2.7.0
googleapis-common-protos==1.63.0
grain-nightly==0.0.6
grpcio==1.62.1
gviz-api==1.10.0
h5py==3.10.0
idna==3.6
immutabledict==4.2.0
importlab==0.8.1
importlib_resources==6.3.0
iniconfig==2.0.0
isort==5.13.2
jax==0.4.25
jaxlib==0.4.25
jaxtyping==0.2.28
Jinja2==3.1.3
keras==2.13.1
libclang==16.0.6
libcst==1.2.0
Markdown==3.5.2
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mccabe==0.7.0
mdurl==0.1.2
ml-collections==0.1.1
ml-dtypes==0.3.2
mlperf-logging==3.0.0
more-itertools==10.2.0
msgpack==1.0.8
msgspec==0.18.6
mypy-extensions==1.0.0
nest-asyncio==1.6.0
networkx==3.1
ninja==1.11.1.1
numpy==1.24.3
nvidia-cublas-cu12==12.4.2.65
nvidia-cuda-cupti-cu12==12.4.99
nvidia-cuda-nvcc-cu12==12.4.99
nvidia-cuda-nvrtc-cu12==12.4.99
nvidia-cuda-runtime-cu12==12.4.99
nvidia-cudnn-cu12==8.9.7.29
nvidia-cufft-cu12==11.2.0.44
nvidia-cusolver-cu12==11.6.0.99
nvidia-cusparse-cu12==12.3.0.142
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.99
oauthlib==3.2.2
opt-einsum==3.3.0
optax==0.2.1
orbax-checkpoint==0.5.5
packaging==24.0
pandas==2.2.1
platformdirs==4.2.0
pluggy==1.4.0
promise==2.3
protobuf==3.20.3
psutil==5.9.8
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycnite==2023.10.11
pydot==2.0.0
Pygments==2.17.2
pylint==3.1.0
pyparsing==3.1.2
pytest==8.1.1
python-dateutil==2.9.0.post0
pytype==2024.3.11
pytz==2024.1
PyYAML==6.0.1
requests==2.31.0
requests-oauthlib==1.4.0
rich==13.7.1
rsa==4.9
scipy==1.12.0
sentencepiece==0.1.97
six==1.16.0
tabulate==0.9.0
tensorboard==2.13.0
tensorboard-data-server==0.7.2
tensorboard_plugin_profile==2.15.1
tensorboardX==2.6.2.2
tensorflow==2.13.1
tensorflow-datasets==4.9.4
tensorflow-estimator==2.13.0
tensorflow-hub==0.16.1
tensorflow-io-gcs-filesystem==0.36.0
tensorflow-metadata==1.14.0
tensorflow-text==2.13.0
tensorstore==0.1.54
termcolor==2.4.0
tf-keras==2.15.0
toml==0.10.2
tomli==2.0.1
tomlkit==0.12.4
toolz==0.12.1
tqdm==4.66.2
typeguard==2.13.3
typing-inspect==0.9.0
typing_extensions==4.5.0
tzdata==2024.1
urllib3==2.2.1
Werkzeug==3.0.1
wrapt==1.16.0
zipp==3.18.0
2 changes: 1 addition & 1 deletion maxtext_gpu_dependencies.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM ghcr.io/nvidia/jax:base
FROM ghcr.io/nvidia/jax:base-2024-03-13

# Install dependencies for adjusting network rto
RUN apt-get update && apt-get install -y iproute2 ethtool lsof
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ pylint
pytest
pytype
sentencepiece==0.1.97
tensorflow-text>=2.13.0
tensorflow>=2.13.0
# Limit tf version pending investigation https://github.com/google/maxtext/issues/516.
tensorflow-text>=2.13.0,<2.15
tensorflow>=2.13.0,<2.15
tensorflow-datasets
tensorboardx
tensorboard-plugin-profile
Expand Down
9 changes: 7 additions & 2 deletions setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
pip3 install -U "jax[cuda12_pip]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
else
echo "Installing stable jax, jaxlib, libtpu for NVIDIA gpu"
pip3 install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip3 install --no-cache-dir "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -c constraints.txt
fi
fi
elif [[ $MODE == "nightly" ]]; then
Expand Down Expand Up @@ -205,4 +205,9 @@ else
fi

# Install dependencies from requirements.txt
cd $run_name_folder_path && pip install --upgrade pip && pip3 install -r requirements.txt
cd $run_name_folder_path && pip install --upgrade pip
if [[ $DEVICE == "gpu" ]] && [[ "$MODE" == "stable" || ! -v MODE ]] && [[ ! -v JAX_VERSION ]]; then
pip3 install --no-cache-dir -r requirements.txt -c constraints.txt
else
pip3 install -U -r requirements.txt
fi

0 comments on commit b259042

Please sign in to comment.