2.4x faster Gemma (#197)
* Update save.py (×2)
* linking
* llama.cpp bugs
* Update save.py (×2)
* saving
* Update save.py (×13)
* Update __init__.py
* Update save.py (×3)
* save
* trainer
* spaces
* original
* Gemma
* Update pyproject.toml
* Update mapper.py
* Update fast_lora.py
* FastGemmaModel
* model_type
* Update llama.py (×2)
* Update gemma.py (×3)
* Update llama.py (×2)
* Update fast_lora.py
* Update llama.py (×2)
* Update cross_entropy_loss.py
* Update llama.py (×2)
* gemma
* Update llama.py (×4)
* Update fast_lora.py (×2)
* Fast CE Loss
* Update cross_entropy_loss.py (×7)
* Update llama.py (×4)
* CE
* Update llama.py (×2)
* Update cross_entropy_loss.py
* Update geglu.py
* Update cross_entropy_loss.py
* revert
* Update llama.py (×2)
* norm
* Update gemma.py (×2)
* position_ids
* Update gemma.py (×2)
* pos
* Update llama.py
* Update gemma.py (×8)
* Update cross_entropy_loss.py
* Update gemma.py (×7)
* Update llama.py
* Update gemma.py (×7)
* Update llama.py
* Update cross_entropy_loss.py (×2)
* revert (×2)
* Update gemma.py (×14)
* Update llama.py
* Update gemma.py (×11)
* Update cross_entropy_loss.py
* Update gemma.py (×5)
* Update llama.py (×4)
* rope
* Update gemma.py (×55)
* llama
* Update llama.py
* gemma
* Update cross_entropy_loss.py (×19)
* Update gemma.py (×10)
* Update save.py
* RoPE
* Update llama.py (×3)
* Update gemma.py
* correct_dtype
* Update gemma.py
* Update cross_entropy_loss.py (×2)
* Chat Templates
* Update README.md (×2)
danielhanchen committed Feb 26, 2024
1 parent 3e4c5a3 commit f946bed
Showing 14 changed files with 767 additions and 141 deletions.
40 changes: 21 additions & 19 deletions README.md
@@ -10,7 +10,7 @@
<a href="https://discord.gg/u54VK8m8tk"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/Discord button.png" height="48"></a>
<a href="https://ko-fi.com/unsloth"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/buy me a coffee button.png" height="48"></a>

### Finetune Mistral, Llama 2-5x faster with 70% less memory!
### Finetune Mistral, Gemma, Llama 2-5x faster with 70% less memory!

![](https://i.ibb.co/sJ7RhGG/image-41.png)

@@ -22,28 +22,30 @@ All notebooks are **beginner friendly**! Add your dataset, click "Run All", and

| Unsloth supports | Free Notebooks | Performance | Memory use |
|-----------------|--------------------------------------------------------------------------------------------------------------------------|-------------|----------|
| **Gemma 7b** | [▶️ Start on Colab](https://colab.research.google.com/drive/10NbwlsRChbma1v55m8LAPYG15uQv6HLo?usp=sharing) | 2.4x faster | 58% less |
| **Mistral 7b** | [▶️ Start on Colab](https://colab.research.google.com/drive/1Dyauq4kTZoLewQ1cApceUQVNcnnNTzg_?usp=sharing) | 2.2x faster | 62% less |
| **Llama-2 7b** | [▶️ Start on Colab](https://colab.research.google.com/drive/1lBzz5KeZJKXjvivbYvmGarix9Ao6Wxe5?usp=sharing) | 2.2x faster | 43% less |
| **DPO - Zephyr** | [▶️ Start on Colab](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing) | 1.9x faster | 19% less |
| **TinyLlama** | [▶️ Start on Colab](https://colab.research.google.com/drive/1AZghoNBQaMDgWJpi4RbffGM1h6raLUj9?usp=sharing) | 3.9x faster | 74% less |
| **CodeLlama 34b** A100 | [▶️ Start on Colab](https://colab.research.google.com/drive/1y7A0AxE3y8gdj4AVkl2aZX47Xu3P1wJT?usp=sharing) | 1.9x faster | 27% less |
| **Mistral 7b** 1xT4 | [▶️ Start on Kaggle](https://www.kaggle.com/code/danielhanchen/kaggle-mistral-7b-unsloth-notebook) | 5x faster\* | 62% less |
| **DPO - Zephyr** | [▶️ Start on Colab](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing) | 1.9x faster | 19% less |

- This [conversational notebook](https://colab.research.google.com/drive/1Aau3lgPzeZKQ-98h69CCu1UJcvIBLmy2?usp=sharing) is useful for ShareGPT ChatML / Vicuna templates.
- This [text completion notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing) is for raw text. This [DPO notebook](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing) replicates Zephyr.
- Colab sometimes provides a free GPU. Kaggle offers 30 hours free per week, with a 12-hour cap per session.
- \* Kaggle has 2x T4s, but we use 1. Due to overhead, 1x T4 is 5x faster. Use Colab as Kaggle takes 10 mins to install.
- \* Kaggle has 2x T4s, but we use 1. Due to overhead, 1x T4 is 5x faster.

## 🦥 Unsloth.ai News
- 📣 [DPO support](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing) is now included. [More info](#DPO) on DPO.
- 📣 [TinyLlama 1.1b](https://colab.research.google.com/drive/1AZghoNBQaMDgWJpi4RbffGM1h6raLUj9?usp=sharing) on 3T tokens now works.
- 📣 We did a [blog](https://huggingface.co/blog/unsloth-trl) with 🤗Hugging Face, and we're in their official docs! Check out the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth).
- 📣 Now supports **Llama, Yi, Mistral, CodeLlama, Qwen (llamafied), Deepseek** and their derived models (**Open Hermes** etc). Llama 7, 13, 70b; CodeLlama 7, 13, 34, 70b; Yi 6, 34b are all supported!
- 📣 **Download models 4x faster** from 🤗Hugging Face! Eg: `unsloth/mistral-7b-bnb-4bit` See our [HF collection](https://huggingface.co/collections/unsloth/load-4bit-models-4x-faster-659042e3a41c3cbad582e734) for more!
- 📣 [Gemma 7b](https://colab.research.google.com/drive/10NbwlsRChbma1v55m8LAPYG15uQv6HLo?usp=sharing) on 6T tokens now works, along with a [Gemma 2b notebook](https://colab.research.google.com/drive/15gGm7x_jTm017_Ic8e317tdIpDG53Mtu?usp=sharing)
- 📣 Added [conversational notebooks](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing) and [raw text notebooks](https://colab.research.google.com/drive/1bMOKOBzxQWUIGZBs_B0zm8pimuEnZdfM?usp=sharing)
- 📣 [2x faster inference](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing) added for all our models
- 📣 [DPO support](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing) is now included. [More info](#DPO) on DPO
- 📣 We did a [blog](https://huggingface.co/blog/unsloth-trl) with 🤗Hugging Face and are in their official docs! Check out the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)
- 📣 [Download models 4x faster](https://huggingface.co/collections/unsloth/) from 🤗Hugging Face. Eg: `unsloth/mistral-7b-bnb-4bit`
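
As a quick illustration of the 4-bit loading path mentioned above — a minimal sketch, assuming a CUDA GPU and an installed `unsloth` (the model name comes from the list above):

```python
from unsloth import FastLanguageModel

# Pre-quantized 4-bit checkpoint: smaller download, same fine-tuning API.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/mistral-7b-bnb-4bit",
    max_seq_length = 2048,
    load_in_4bit = True,
)
```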

## 🔗 Links and Resources
| Type | Links |
| ------------------------------- | --------------------------------------- |
| 📚 **Wiki & FAQ** | [Read Our Wiki](https://github.com/unslothai/unsloth/wiki) |
| 📜 **Documentation** | [Read The Doc](https://github.com/unslothai/unsloth/tree/main#-documentation) |
| 💾 **Installation** | [unsloth/README.md](https://github.com/unslothai/unsloth/tree/main#installation-instructions)|
| <img height="14" src="https://upload.wikimedia.org/wikipedia/commons/6/6f/Logo_of_Twitter.svg" />&nbsp; **Twitter (aka X)** | [Follow us on X](https://twitter.com/unslothai)|
@@ -113,30 +115,30 @@ pip install --upgrade --force-reinstall --no-cache-dir torch==2.1.0 triton \
```bash
pip install "unsloth[cu118] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118_ampere] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121_ampere] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-ampere] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-ampere] @ git+https://github.com/unslothai/unsloth.git"
```
3. For Pytorch 2.1.1: Use the `"ampere"` path for newer RTX 30xx GPUs or higher.
```bash
pip install --upgrade --force-reinstall --no-cache-dir torch==2.1.1 triton \
--index-url https://download.pytorch.org/whl/cu121
```
```bash
pip install "unsloth[cu118_torch211] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121_torch211] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118_ampere_torch211] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121_ampere_torch211] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-torch211] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-torch211] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-ampere-torch211] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-ampere-torch211] @ git+https://github.com/unslothai/unsloth.git"
```
4. For Pytorch 2.2.0: Use the `"ampere"` path for newer RTX 30xx GPUs or higher.
```bash
pip install --upgrade --force-reinstall --no-cache-dir torch==2.2.0 triton \
--index-url https://download.pytorch.org/whl/cu121
```
```bash
pip install "unsloth[cu118_torch220] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121_torch220] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118_ampere_torch220] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121_ampere_torch220] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-torch220] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-torch220] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-ampere-torch220] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-ampere-torch220] @ git+https://github.com/unslothai/unsloth.git"
```
5. If you get errors, try the below first, then go back to step 1:
```bash
...
```
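
To pick the right extras in steps 2-4, check your local torch and CUDA builds first — a minimal check, assuming only a working `torch` install:

```python
import torch
print(torch.__version__)   # e.g. "2.2.0" -> use the torch220 extras
print(torch.version.cuda)  # e.g. "12.1"  -> use the cu121 extras
```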
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -33,7 +33,7 @@ exclude = ["images*"]

[project.optional-dependencies]
huggingface = [
"transformers>=4.37.0",
"transformers>=4.38.0",
"datasets",
"sentencepiece",
"accelerate>=0.26.1",
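
The bump to `transformers>=4.38.0` is the key dependency change: Gemma model code first shipped in transformers 4.38, so older installs cannot load the architecture at all. A minimal guard — a sketch, using the `packaging` library (a separate dependency, though usually already present):

```python
from packaging import version
import transformers

# Gemma support first appeared in transformers 4.38.0.
assert version.parse(transformers.__version__) >= version.parse("4.38.0"), \
    "Upgrade transformers to >= 4.38.0 for Gemma"
```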
77 changes: 75 additions & 2 deletions unsloth/chat_templates.py
@@ -217,6 +217,35 @@
CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token,)


# https://huggingface.co/google/gemma-7b-it
# Note that we must use |trim to emulate lstrip and rstrip. <start_of_turn> maps to token 106,
# <end_of_turn> to token 107. "user" and "model" are ordinary text tokens, not special tokens.
gemma_template = \
"{% for message in messages %}"\
"{% if message['role'] == 'user' %}"\
"{{'<start_of_turn>user\n' + message['content'] | trim + '<end_of_turn>\n'}}"\
"{% elif message['role'] == 'assistant' %}"\
"{{'<start_of_turn>model\n' + message['content'] | trim + '<end_of_turn>\n' }}"\
"{% else %}"\
"{{ '<start_of_turn>system\n' + message['content'] | trim + '<end_of_turn>\n' }}"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}"\
"{{ '<start_of_turn>model\n' }}"\
"{% endif %}"
gemma_eos_token = "<end_of_turn>"
CHAT_TEMPLATES["gemma"] = (gemma_template, gemma_eos_token,)


# Gemma with ChatML instead
gemma_chatml_template = chatml_template
gemma_chatml_eos_token = (
{"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"},
"<|im_end|>",
)
CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token,)


def get_chat_template(
tokenizer,
chat_template = "chatml",
@@ -229,7 +258,7 @@ def get_chat_template(

old_padding_side = tokenizer.padding_side

if type(chat_template) in (list, tuple):
if type(chat_template) in (list, tuple,):
chat_template, stop_word = chat_template
assert(type(chat_template) is str)
assert(type(stop_word) is str)
@@ -238,7 +267,38 @@ def get_chat_template(

chat_template, stop_word = CHAT_TEMPLATES[chat_template]

if stop_word != "eos_token":
if type(stop_word) in (list, tuple,):
token_mapping, stop_word = stop_word
assert(type(token_mapping) is dict)
else:
token_mapping = None

assert(type(stop_word) is str)

# token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
# For Gemma :)
if token_mapping is not None:

string_vocab = tokenizer._tokenizer.to_str()

for old_token, new_token in token_mapping.items():
old_count = string_vocab.count(f'"{old_token}"')
new_count = string_vocab.count(f'"{new_token}"')
if new_count != 0:
print(f"{new_token} is already a token. Skipping.")
elif old_count == 0:
raise RuntimeError(f"{old_token} was not part of the tokenizer!")
else:
string_vocab = string_vocab.replace(f'"{old_token}"', f'"{new_token}"')
pass
pass
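# Net effect for "gemma_chatml": the serialized vocab entry '"<start_of_turn>"'
# (token id 106) is renamed to '"<|im_start|>"', and '"<end_of_turn>"' (107) to
# '"<|im_end|>"', so the ChatML markers reuse the existing embedding rows.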

logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)
new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)

elif stop_word != "eos_token":
logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")

# Replaces the old EOS token with a new one.
@@ -252,6 +312,7 @@ def get_chat_template(
new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)
pass

else:
raise TypeError(
f"Unsloth: `chat_template` must be a tuple of (your_template, eos_token,) or one of\n"\
@@ -318,6 +379,7 @@ def test_chat_templates():
{"role": "user", "content": " No it's 100% 5! "},
]

# Zephyr
from transformers import AutoTokenizer
template = zephyr_template
correct_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
@@ -326,27 +388,31 @@
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)

# Chatml
template = chatml_template
correct_tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)

# Mistral
template = mistral_template
correct_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)

# Llama
template = llama_template
correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b-chat")
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = template
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert(correct_prompt == our_prompt)

# Vicuna
try:
from fastchat.conversation import get_conv_template
except:
@@ -381,4 +447,11 @@ def test_chat_templates():
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
# We add </s> ourselves
assert(correct_prompt == our_prompt.replace("</s>", ""))

# Gemma
correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-7b-it")
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
correct_tokenizer.chat_template = gemma_template
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
assert(our_prompt == correct_prompt)
pass
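
The new Gemma test follows the same pattern as the others: render the prompt with the official `unsloth/gemma-7b-it` tokenizer, then with our template, and assert exact equality. A hypothetical extra check for the remapped ChatML variant (not in the diff; `<|im_start|>assistant\n` is the standard ChatML generation prompt):

```python
tokenizer = get_chat_template(tokenizer, chat_template = "gemma_chatml")
prompt = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
assert prompt.endswith("<|im_start|>assistant\n")
```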
4 changes: 3 additions & 1 deletion unsloth/kernels/__init__.py
@@ -16,9 +16,11 @@
from .rms_layernorm import fast_rms_layernorm
from .rope_embedding import fast_rope_embedding, inplace_rope_embedding
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
from .geglu import geglu_forward_kernel, geglu_backward_kernel
from .fast_lora import (
get_lora_parameters,
apply_lora_mlp,
apply_lora_mlp_swiglu,
apply_lora_mlp_geglu,
apply_lora_qkv,
apply_lora_o,
)
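
These new exports mirror Gemma's MLP shape: where Llama gates with SiLU (SwiGLU), Gemma gates with GELU (GeGLU). A plain-PyTorch sketch of the semantics the fused Triton kernels implement — a reference only, not the kernels themselves, and whether the exact or tanh-approximate GELU applies depends on the model config:

```python
import torch
import torch.nn.functional as F

def geglu_mlp_reference(x, gate_proj, up_proj, down_proj):
    # Gemma MLP: down( gelu(gate(x)) * up(x) )
    return down_proj(F.gelu(gate_proj(x)) * up_proj(x))
```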
