Skip to content

Commit

Permalink
handle ampere and non ampere installs
Browse files Browse the repository at this point in the history
  • Loading branch information
tohrnii committed Mar 14, 2024
1 parent f0c8e6d commit ae073c4
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ get_pytorch_version () {
fi
}

# Function to get GPU architecture
get_gpu_type () {
GPU_MAJOR_VERSION=$(python -c "import torch; print(torch.cuda.get_device_capability()[0])")
if [[ "$GPU_MAJOR_VERSION" -ge 8 ]]; then
echo "ampere"
else
echo ""
fi
}

# Function to install packages via Conda
conda_install_packages () {
conda create --name unsloth_env python=3.10 -y
Expand All @@ -23,7 +33,11 @@ conda_install_packages () {
# Function to install packages via Pip
pip_install_packages () {
pip install --upgrade --force-reinstall --no-cache-dir torch==${PYTORCH_CORE_VERSION}+${CUDA_TAG} triton --index-url https://download.pytorch.org/whl/${CUDA_TAG}
pip install "unsloth[${CUDA_TAG}] @ git+https://github.com/unslothai/unsloth.git"
if [[ "$PYTORCH_VERSION_TAG" == "210" ]]; then
pip install "unsloth[${CUDA_TAG}${GPU_TYPE:+-$GPU_TYPE}] @ git+https://github.com/unslothai/unsloth.git"
else
pip install "unsloth[${CUDA_TAG}${GPU_TYPE:+-$GPU_TYPE}-$PYTORCH_VERSION_TAG] @ git+https://github.com/unslothai/unsloth.git"
fi
}

# Check if conda is installed
Expand Down Expand Up @@ -56,7 +70,12 @@ else
echo "CUDA version detected: $CUDA_VERSION"
PYTORCH_VERSION=$(get_pytorch_version)
echo "PyTorch version detected: $PYTORCH_VERSION"

GPU_TYPE=$(get_gpu_type)
if [[ $GPU_TYPE == "ampere"]]; then
echo "Ampere or newer architecture detected. Proceeding with ampere specific installation."
else
echo "Older GPU architecture detected. Proceeding with non-ampere specific installation."
fi
# Define CUDA tag based on CUDA version
if [[ "$CUDA_VERSION" == "11.8" ]]; then
CUDA_TAG="cu118"
Expand All @@ -69,6 +88,7 @@ else

# Extract PyTorch version (ignoring any suffix)
PYTORCH_CORE_VERSION=$(echo $PYTORCH_VERSION | cut -d'+' -f1)
PYTORCH_VERSION_TAG="torch${PYTORCH_CORE_VERSION//./}"

pip_install_packages

Expand Down

0 comments on commit ae073c4

Please sign in to comment.