Staging PR for implementing Phi-2 support #97

Open
Wants to merge 54 commits into base: main

Commits (54)
1a21471
Added kernel for LayerNorms from Triton tutorials [https://triton-lang…
cm2435 Jan 18, 2024
b60c138
formatting and typo fix
cm2435 Jan 22, 2024
598dda3
starting to structure out test suite
cm2435 Jan 22, 2024
3954f16
Contributed ReLU Triton kernel, removed a small amount of boilerplate …
cm2435 Jan 22, 2024
48eb887
update test to use common torch.allclose function util
cm2435 Jan 24, 2024
4a7c20c
implement forward and backward kernels for GeLU
cm2435 Jan 24, 2024
b6d224b
fixed incorrect bwd pass for GeLU, wrapped functions in PyTorch convi…
cm2435 Jan 24, 2024
e520e9a
added partial scaling to RoPE embeddings
cm2435 Feb 4, 2024
cbd31c3
added new seeded dropout kernel
cm2435 Feb 4, 2024
b62c886
updated tests
cm2435 Feb 4, 2024
e3e41a7
formatting
cm2435 Feb 4, 2024
440ef5d
uncommented out init.py
cm2435 Feb 6, 2024
a1e2b0d
doing work to implement model
cm2435 Feb 6, 2024
f2112b1
updated pre_patch for phi2
cm2435 Feb 6, 2024
0386c96
wrote decoder model fwd for phi2
cm2435 Feb 12, 2024
6ec3c4f
I am dumb. fixing all the broken stuff
cm2435 Feb 19, 2024
f070fad
Quick fixes (#101)
danielhanchen Jan 19, 2024
2904ad9
Revert quantization methods
danielhanchen Jan 19, 2024
2da8a7d
getattr issues (#103)
danielhanchen Jan 19, 2024
24f943f
Update _utils.py
danielhanchen Jan 19, 2024
ddb7bee
Quick fixes (#106)
danielhanchen Jan 19, 2024
9a9e6d4
Hotfix for Jan 2024 Release (#110)
danielhanchen Jan 20, 2024
b392c28
Fixed saving! (#113)
danielhanchen Jan 20, 2024
3c880df
Update save.py
danielhanchen Jan 20, 2024
770b5ac
Update save.py
danielhanchen Jan 20, 2024
164319a
Update save.py
danielhanchen Jan 20, 2024
8f996e2
Hotfix (#118)
danielhanchen Jan 21, 2024
7e6f313
2-4x faster native HF inference (#119)
danielhanchen Jan 22, 2024
f61ed0e
Fix bugs (#129)
danielhanchen Jan 25, 2024
43c146d
More bug fixes (#133)
danielhanchen Jan 26, 2024
cb4c49c
Inference bug fix (#134)
danielhanchen Jan 26, 2024
d08a042
Fix bugs + more accurate Swiglu (#137)
danielhanchen Jan 27, 2024
c1d6501
1 more bug (#138)
danielhanchen Jan 27, 2024
7a2f5d2
Fix saving issues (#139)
danielhanchen Jan 28, 2024
77866a2
Nightly (#140)
danielhanchen Jan 28, 2024
7b667a4
Fix inference attention mask (#142)
danielhanchen Jan 29, 2024
d129628
Hotfix - fix inference (#146)
danielhanchen Jan 30, 2024
60acab2
2x faster inference (#151)
danielhanchen Feb 4, 2024
3618e5b
ReadMe Revamp (#156)
shimmyshimmer Feb 6, 2024
2d5f7ed
Torch 2.2 (#157)
danielhanchen Feb 6, 2024
e62c037
Nightly (#161)
danielhanchen Feb 7, 2024
e81b78d
Update README.md (#162)
danielhanchen Feb 8, 2024
c3ea900
Update mapper.py
danielhanchen Feb 8, 2024
1f5f2e3
Update README.md (#164)
danielhanchen Feb 9, 2024
53b7af5
Update README.md (#165)
danielhanchen Feb 9, 2024
38c3f43
add HF tagging in unsloth (#170)
younesbelkada Feb 13, 2024
c5fd5cb
Prelim Feb release (#173)
danielhanchen Feb 14, 2024
0bb66a9
edited layernorm for better implementation (credits to ludacrin
cm2435 Feb 26, 2024
e031ed8
fix decoder and model implementation
cm2435 Feb 26, 2024
c8198a0
fix partial rope embedding
cm2435 Feb 26, 2024
69005f4
added new kernel test
cm2435 Feb 26, 2024
63328d1
resolve merge conflicts
cm2435 Feb 26, 2024
2d19215
updated tests, experimenting with layernorm
cm2435 Mar 4, 2024
1142330
fixed kernels, currently debugging gelu
cm2435 Mar 4, 2024
6 changes: 6 additions & 0 deletions .idea/.idea.unsloth.dir/.idea/projectSettingsUpdater.xml


141 changes: 141 additions & 0 deletions .idea/.idea.unsloth.dir/.idea/workspace.xml


32 changes: 32 additions & 0 deletions experiments/benchmark.ipynb
@@ -0,0 +1,32 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "unsloth",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Empty file added tests/__init__.py
Empty file.
Empty file added tests/kernels/__init__.py
Empty file.
73 changes: 73 additions & 0 deletions tests/kernels/conftest.py
@@ -0,0 +1,73 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import gc
from contextlib import contextmanager

import os
import pytest
import numpy as np
import torch
import torch._dynamo as dynamo


@contextmanager
def set_seed(seed: int = 0):
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
yield

@pytest.fixture(autouse=True)
def reset_dyno_state():
    """Reset torch._dynamo and free CUDA memory around every test."""
    cache_limit = dynamo.config.cache_size_limit
    try:
        dynamo.config.cache_size_limit = 8192
        dynamo.reset()
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()
        yield
    finally:
        dynamo.config.cache_size_limit = cache_limit
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()


def assert_all_close(a: torch.Tensor, b: torch.Tensor, rtol=0, atol=1e-1) -> None:
"""
Check that all elements of tensors a and b are within provided thresholds.
"""
assert a.shape == b.shape, f"Shapes don't match: {a.shape} != {b.shape}"
assert a.dtype == b.dtype, f"Dtypes don't match: {a.dtype} != {b.dtype}"
assert a.device == b.device, f"Devices don't match: {a.device} != {b.device}"
max_abs_diff = torch.max(torch.abs(a - b))
    rel_diff = torch.abs((a - b) / b)  # relative difference, not the raw ratio
max_rel_diff = torch.max(rel_diff)
mismatch_elements = torch.sum(torch.abs(a - b) > atol + rtol * torch.abs(b))
nb_elements = torch.numel(a)
msg = (
f"Differences: "
f"{max_abs_diff:.3f} (max abs), "
f"{max_rel_diff:.3f} (max rel), "
f"{mismatch_elements}/{nb_elements} (mismatch elements)"
)
assert torch.allclose(a, b, rtol=rtol, atol=atol), msg
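For reference, a minimal sketch (not part of this diff) of how these two helpers are intended to be combined in a kernel test: set_seed seeds the RNGs as a context manager around input construction, and assert_all_close then compares a kernel's output against a PyTorch reference. The test name and tensor shapes below are illustrative only.

import torch

from tests.kernels.conftest import set_seed, assert_all_close


def test_identity_example():
    # Seed NumPy/PyTorch inside the context manager so inputs are reproducible
    with set_seed(0):
        x = torch.randn(1024, 1024, device="cuda")
    # Trivial comparison, just to show the assertion helper's call signature
    assert_all_close(x, x.clone(), rtol=0, atol=1e-1)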
22 changes: 22 additions & 0 deletions tests/kernels/test_crossentropy.py
@@ -0,0 +1,22 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss

def test_fast_cross_entropy_loss():
    # Labels must be integer class indices in [0, vocab_size), not floats
    x = torch.randn(1, 126, 51200, device='cuda')
    y = torch.randint(0, 51200, (1, 126), device='cuda')

    fast_cross_entropy_loss(logits=x, labels=y)
55 changes: 55 additions & 0 deletions tests/kernels/test_gelu.py
@@ -0,0 +1,55 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest
import torch

from unsloth.kernels.gelu import gelu_forward_kenel, gelu_backward_kernel
from tests.kernels.conftest import set_seed, assert_all_close

@pytest.fixture(params=[(100, 100), (1024, 1024), (5000, 1024), (12345, 5678)])
def test_matrix(request):
    shape = request.param
    # Seed inside the fixture body; set_seed is a context manager, not a decorator
    with set_seed(0):
        x = torch.randn(shape, device='cuda')
    return x

# Test function for the GeLU forward kernel
def test_gelu_kernel_fwd(test_matrix):
    # Apply the Triton-based GeLU forward kernel
    triton_output = gelu_forward_kenel(test_matrix)

    # Apply PyTorch's GeLU for comparison
    torch_gelu = torch.nn.GELU()
    torch_output = torch_gelu(test_matrix)

    # Check that the outputs are close enough using assert_all_close
    assert_all_close(triton_output, torch_output, rtol=1e-05, atol=1e-08)


# Test function for the GeLU backward kernel
def test_gelu_backward_kernel(test_matrix):
    # Upstream gradients (e.g. random gradients flowing back from the next layer)
    grad_input = torch.randn_like(test_matrix)

    # Apply the Triton-based GeLU backward kernel
    triton_output = gelu_backward_kernel(test_matrix, grad_input)

    # Compute PyTorch's GeLU gradient for comparison via autograd
    x = test_matrix.clone().requires_grad_(True)
    torch_gelu = torch.nn.GELU()
    torch_forward = torch_gelu(x)
    torch_output = torch.autograd.grad(torch_forward, x, grad_outputs=grad_input)[0]

    # Check if the outputs are close enough using assert_all_close
    assert_all_close(triton_output, torch_output, rtol=1e-05, atol=1e-08)
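The GeLU kernels under test (gelu_forward_kenel, gelu_backward_kernel) are defined elsewhere in the PR and are not shown in this diff. Purely as an illustration of the shape of such a kernel, and not the PR's actual implementation, a Triton GeLU forward pass could look like the sketch below. It uses the tanh approximation of GeLU, so it would not match torch.nn.GELU()'s default exact (erf-based) formulation to the tight tolerances used in the tests above; the helper name gelu_forward_sketch is hypothetical.

import torch
import triton
import triton.language as tl


@triton.jit
def _gelu_fwd_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
    # tanh-approximate GeLU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    inner = 0.7978845608028654 * (x + 0.044715 * x * x * x)
    # tanh(z) = 2 / (1 + exp(-2z)) - 1, written with exp to avoid version-specific math ops
    tanh_inner = 2.0 / (1.0 + tl.exp(-2.0 * inner)) - 1.0
    y = 0.5 * x * (1.0 + tanh_inner)
    tl.store(out_ptr + offsets, y, mask=mask)


def gelu_forward_sketch(x: torch.Tensor) -> torch.Tensor:
    # Launch one program instance per BLOCK_SIZE chunk of the flattened input
    x = x.contiguous()
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _gelu_fwd_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    return out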