2024 Release (#96)
* Fix tokenizer, dropout, bias for LoRA

* Update loader.py

* Fix LoRA downcasting

* Update _utils.py

* Saving to GGUF

* fix

* colab_quantize_to_gguf

* move save modules

* save module

* Update __init__.py

* Update save.py

* Temp downgrade due to TRL issue

* Fix up bugs

* Faster saving + other changes

* Update llama.py

* Saving modules

* spelling

* Update llama.py

* Update save.py

* Update save.py

* Update loader.py

* Update llama.py

* patch saving

* Update save.py

* Update save.py

* Update save.py

* patch saving

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* original_model

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* saving to RAM leakage?

* Update save.py

* new_save_directory

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update pyproject.toml

* Update pyproject.toml

* Update pyproject.toml
danielhanchen committed Jan 18, 2024
1 parent 4112eb4 commit b8b1eaf
Showing 9 changed files with 1,095 additions and 189 deletions.
3 changes: 1 addition & 2 deletions unsloth/__init__.py
@@ -65,8 +65,7 @@
     libcuda_dirs()
 except:
     warnings.warn(
-        "CUDA is not linked properly.\n"\
-        "We shall run `ldconfig /usr/lib64-nvidia` to try to fix it."
+        "Running `ldconfig /usr/lib64-nvidia` to link CUDA."\
     )
     os.system("ldconfig /usr/lib64-nvidia")
     importlib.reload(bnb)
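
For readers skimming the diff, the shortened warning sits inside unsloth's import-time CUDA check: if the CUDA driver library cannot be resolved for bitsandbytes, the package warns, runs `ldconfig`, and reloads bitsandbytes. A minimal sketch of that fallback pattern, assuming `libcuda_dirs` comes from triton's build helpers (the wrapper name `_ensure_cuda_linked` is illustrative, not from the repo):

```python
import importlib
import os
import warnings

import bitsandbytes as bnb


def _ensure_cuda_linked():
    """Illustrative sketch of the try/except fallback shown in the hunk above."""
    try:
        from triton.common.build import libcuda_dirs  # assumed import path
        libcuda_dirs()  # raises if the CUDA driver library cannot be located
    except Exception:
        warnings.warn("Running `ldconfig /usr/lib64-nvidia` to link CUDA.")
        os.system("ldconfig /usr/lib64-nvidia")  # relink NVIDIA libs (Colab-style path)
        importlib.reload(bnb)  # retry bitsandbytes with the refreshed linker cache
```
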
5 changes: 3 additions & 2 deletions unsloth/kernels/rms_layernorm.py
@@ -41,12 +41,13 @@ def _rms_layernorm_forward(
     r += row_idx * r_row_stride

     X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
-    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
+    W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)

     row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
-    inv_var = 1 / tl.sqrt(row_var + eps)
+    inv_var = 1.0 / tl.sqrt(row_var + eps)
     tl.store(r, inv_var)
     normed = X_row * inv_var
+    normed = normed.to(W_row.dtype) # Exact copy from HF
     output = normed * W_row
     tl.store(Y + col_offsets, output, mask = mask)
 pass
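
The two changes above make the Triton kernel follow Hugging Face's LlamaRMSNorm numerics: the variance and normalization are computed in float32, the normalized row is cast back to the weight's dtype, and only then multiplied by W. A rough PyTorch equivalent of that ordering (a reference sketch, not the kernel itself; tensor names mirror the hunk):

```python
import torch

def rms_layernorm_reference(X: torch.Tensor, W: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Accumulate the variance in float32 for numerical stability
    X_fp32 = X.to(torch.float32)
    row_var = X_fp32.pow(2).mean(dim=-1, keepdim=True)
    inv_var = torch.rsqrt(row_var + eps)
    normed = X_fp32 * inv_var
    # Cast back to the weight dtype *before* multiplying by W ("Exact copy from HF")
    return normed.to(W.dtype) * W
```
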
12 changes: 7 additions & 5 deletions unsloth/kernels/swiglu.py
@@ -25,10 +25,11 @@ def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
     mask = offsets < n_elements

     e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
-    g_row = tl.load(g + offsets, mask = mask, other = 0).to(tl.float32)
+    g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)

     # f = e * sigmoid(e)
     f_row = e_row / (1 + tl.exp(-e_row))
+    f_row = f_row.to(g_row.dtype) # Exact copy from HF
     # h = f * g
     h_row = f_row * g_row

@@ -53,12 +54,13 @@ def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
     offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements

-    DW_row = tl.load(DW + offsets, mask = mask, other = 0).to(tl.float32)
-    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
-    g_row = tl.load(g + offsets, mask = mask, other = 0).to(tl.float32)
+    DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
+    e_row = tl.load(e + offsets, mask = mask, other = 0)#.to(tl.float32)
+    g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)

     # f = e * sigmoid(e)
-    se_row = 1 / (1 + tl.exp(-e_row))
+    se_row = 1 / (1 + tl.exp(-e_row.to(tl.float32)))
+    se_row = se_row.to(e_row.dtype) # Exact copy from HF
     # f = e * se
     f_row = e_row * se_row
     # h = f * g
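
In both kernels the pattern is the same: the SiLU gate f = e * sigmoid(e) is evaluated in float32, then cast back to the other operand's dtype before the elementwise product, so the result matches Hugging Face's mixed-precision behaviour. A short PyTorch sketch of the forward path (illustrative only; the real work happens in the Triton kernels above):

```python
import torch
import torch.nn.functional as F

def swiglu_forward_reference(e: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    # f = e * sigmoid(e), i.e. SiLU, computed in float32
    f = F.silu(e.to(torch.float32))
    # Downcast to g's dtype before multiplying, mirroring f_row.to(g_row.dtype) above
    return f.to(g.dtype) * g
```
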
2 changes: 0 additions & 2 deletions unsloth/models/_utils.py
@@ -14,9 +14,7 @@

 import torch
 from typing import Union, Optional, List, Any, Callable
-import numpy as np
 import warnings
-import gc
 warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
 import bitsandbytes as bnb
 from transformers.models.llama.modeling_llama import logger
