Skip to content

Commit

Permalink
Add flash attention
Browse files Browse the repository at this point in the history
  • Loading branch information
chajath committed Mar 15, 2024
1 parent 61950fd commit 1a6c4a7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ psutil==5.9.8
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycnite==2023.10.11
pydantic==2.6.4
pydantic-core==2.16.3
pydot==2.0.0
Pygments==2.17.2
pylint==3.1.0
Expand Down Expand Up @@ -123,9 +125,10 @@ tomli==2.0.1
tomlkit==0.12.4
toolz==0.12.1
tqdm==4.66.2
transformer-engine @ git+https://github.com/NVIDIA/TransformerEngine.git@0fbc76af3733ae997394eaf82b78ff9c0498fe9
typeguard==2.13.3
typing-inspect==0.9.0
typing_extensions==4.5.0
typing_extensions==4.6.1
tzdata==2024.1
urllib3==2.2.1
Werkzeug==3.0.1
Expand Down
2 changes: 1 addition & 1 deletion setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
pip3 install --no-cache-dir "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -c constraints.txt
fi
export NVTE_FRAMEWORK=jax
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
pip3 install --no-cache-dir git+https://github.com/NVIDIA/TransformerEngine.git@0fbc76af3733ae997394eaf82b78ff9c0498fe9 -c constraints.txt
fi
elif [[ $MODE == "nightly" ]]; then
# Nightly mode
Expand Down

0 comments on commit 1a6c4a7

Please sign in to comment.