
Initial fused GPTQ implementation #141

Open
wants to merge 13 commits into main
Conversation

jeromeku

GPTQ PEFT Fine-tuning

GPTQ fast_lora

Adds a fast_lora implementation for PEFT fine-tuning of GPTQ-quantized models.

  • Following the methodology of the existing bitsandbytes fast_lora custom autograd, fuses the triton quant / dequant matmul kernels from auto_gptq together with the LoRA adapters into a custom torch.autograd.Function (see unsloth/gptq/fast_lora.py and the sketch after this list).
  • The default Huggingface GPTQ PEFT fine-tuning path uses the auto_gptq cuda QuantLinear layer, which in turn falls back to a torch-only implementation, since the custom cuda kernel employed by auto_gptq does not implement a backward pass.
  • The current implementation runs slower than the default Huggingface implementation.
  • Additional tuning / optimizations are in the works.
  • See this issue for further profiling details.
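
For orientation, a rough sketch of the structure described above, with a stand-in dequant helper in place of the actual triton dequant / matmul kernels (names and shapes are illustrative only, not the code in unsloth/gptq/fast_lora.py):

```python
import torch

def dequant(qstate):
    # Stand-in for the triton dequant path: here `qstate` is already a dense
    # weight so the sketch runs end to end. The real implementation unpacks
    # the GPTQ qweight / qzeros / scales instead.
    return qstate

class LoRA_QuantMatmul(torch.autograd.Function):
    """Sketch of y = x @ W_deq + s * (x @ A) @ B, with W_deq dequantized from GPTQ state."""

    @staticmethod
    def forward(ctx, x, qstate, A, B, s):
        W = dequant(qstate)                     # frozen base weight, [in, out]
        y = x @ W + s * (x @ A) @ B             # base path + LoRA path
        ctx.save_for_backward(x, qstate, A, B)
        ctx.s = s
        return y

    @staticmethod
    def backward(ctx, dy):
        x, qstate, A, B = ctx.saved_tensors
        s = ctx.s
        W = dequant(qstate)                     # re-dequantize (or cache) for the backward matmuls
        x2 = x.reshape(-1, x.shape[-1])         # flatten batch / sequence dims
        dy2 = dy.reshape(-1, dy.shape[-1])
        dx = dy @ W.t() + s * (dy @ B.t()) @ A.t()
        dA = s * x2.t() @ (dy2 @ B.t())
        dB = s * (x2 @ A).t() @ dy2
        # The quantized base weight stays frozen, so no gradient for qstate or s.
        return dx, None, dA, dB, None

# Illustrative usage:
x = torch.randn(2, 16, 64, requires_grad=True)
W = torch.randn(64, 128)                        # stands in for the quantized layer's weight
A = torch.randn(64, 8, requires_grad=True)
B = torch.randn(8, 128, requires_grad=True)
LoRA_QuantMatmul.apply(x, W, A, B, 0.5).sum().backward()
```

The point of the custom autograd.Function is that the frozen quantized weight only ever participates in matmuls, so only the LoRA adapters A and B receive gradients.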

Profiling

  • Also includes a profiling / benchmarking script for comparing unsloth models with huggingface models.
  • See benchmarks/Profiling.MD for documentation.

@danielhanchen
Contributor

@jeromeku Oh my this is a LARGE PR!!!! I'll take a read through it today :)

@danielhanchen
Contributor

Ohh now I understand why you added the merged matmul triton kernels rather than a separate dequantize kernel followed by a matmul, ie:

out = dequantize_and_matmul(X, W)

vs

W = dequantize(W)
out = torch.matmul(X, W)

I took a look through GPTQ's repo, and indeed I cannot find any standalone dequantization kernel, written in Triton or otherwise.

To attain maximal performance, that technically means including a GPTQ dequantize kernel only, ie without the matrix multiplies inside the Triton kernel, which can screw with the compiler.

I'll see what I can do if I have some more bandwidth - sadly I don't have too much knowledge about GPTQ so I'll have to dive into their papers a bit on how their dequantization even works :)

Great work so far @jeromeku, and thanks so much for trying to add GPTQ!

@jeromeku
Author

@danielhanchen
I have a pretty good handle on the situation -- will try to strip out the dequant part (in addition to some other optimizations).

@danielhanchen
Contributor

@jeromeku Ok cool!! :)

@jeromeku
Author

jeromeku commented Jan 30, 2024

@danielhanchen

Stripped out the dequant portion of the fused dequant matmul and did some quick benchmarking of the default quant linear forward per huggingface gptq vs. a torch.compile'd dequant + torch.mm.

Promising early results (forward only):

    seqlen  reference_gptq_quantlinear  torch.compile(dequant+mm)
0     32.0                    2.581504                   0.406528
1     64.0                    2.563072                   0.407552
2    128.0                    2.591728                   0.430080
3    256.0                    2.689024                   0.502784
4    512.0                    2.971648                   0.780288
5   1024.0                    3.467648                   1.236992
6   2048.0                    4.403200                   2.150400
7   4096.0                    6.563840                   4.184480
8   8192.0                   10.655744                   8.019968
9  16384.0                   19.193855                  15.764481

These are median times (ms) for various sequence lengths.

However, running both forward and backward degrades the performance of the compiled version vs the reference, which is confusing since the backward graph is just a transposed matmul. Needs further investigation.

Interestingly, the triton kernel that gets codegen'ed for the dequant forward part is similar to, if not more efficient than, the hand-written dequant portion of the previous triton kernel.
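
For context, a minimal sketch of what the torch-level dequant + mm can look like, assuming the standard auto_gptq 4-bit packing (8 values per int32, per-group scales / qzeros, zeros stored offset by one); helper names and shapes are illustrative and not the code used for the benchmark above:

```python
import torch

def dequant_gptq_4bit(qweight, qzeros, scales, g_idx, bits: int = 4):
    """Unpack GPTQ 4-bit state into a dense weight (illustrative layout).

    Assumed shapes (standard auto_gptq packing, an assumption rather than this PR's code):
      qweight: int32 [in_features // 8, out_features]   8 x 4-bit packed along rows
      qzeros:  int32 [n_groups, out_features // 8]      8 x 4-bit packed along columns
      scales:  fp16  [n_groups, out_features]
      g_idx:   int64 [in_features]                      row -> group index
    """
    maxq = (1 << bits) - 1
    shifts = torch.arange(0, 32, bits, device=qweight.device, dtype=torch.int32)

    # Unpack 4-bit weights: [K // 8, 8, N] -> [K, N]
    w = torch.bitwise_right_shift(qweight.unsqueeze(1), shifts.view(1, -1, 1)) & maxq
    w = w.reshape(-1, qweight.shape[1])

    # Unpack 4-bit zero points: [G, N // 8, 8] -> [G, N]
    z = torch.bitwise_right_shift(qzeros.unsqueeze(2), shifts.view(1, 1, -1)) & maxq
    z = z.reshape(qzeros.shape[0], -1) + 1    # assumes the older auto_gptq "zeros stored minus one" convention

    # Broadcast per-group scales / zeros to per-row via g_idx, then dequantize.
    return (w - z[g_idx]).to(scales.dtype) * scales[g_idx]

# The "dequant + mm" forward handed to torch.compile.
@torch.compile
def quantlinear_forward(x, qweight, qzeros, scales, g_idx):
    W = dequant_gptq_4bit(qweight, qzeros, scales, g_idx)   # [in_features, out_features]
    return x @ W
```

Since the unpacking is pure elementwise / reshape work, torch.compile can fuse it into a single generated kernel, which is presumably why the codegen'ed dequant ends up competitive with the hand-written one.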

@danielhanchen
Contributor

@jeromeku Cool, great work again! Ye it definitely looks like torch.compile is destroying the hand-written GPTQ kernel inside HF's codebase lol! Ye the backward is a transpose - but I'm assuming it's because the strides are reversed, causing a performance hit - just my speculation.

@jeromeku
Author

jeromeku commented Feb 3, 2024

@danielhanchen

Good news -- refactored the fast_lora implementation with a new triton kernel that does the dequant separately from the matmul (the previous impl was an adapted version of the fused dequant matmul kernel from auto_gptq).

Performance now is on par with fast_lora bnb for llama and mistral models.

Will run some additional tests / benchmarks and PR should be ready for review.
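
For reference, a standalone Triton dequant kernel along these lines might look roughly like the sketch below (same assumed auto_gptq 4-bit layout as in the torch sketch above; the one-lane-per-element scheme, kernel name, and block size are illustrative, not the actual kernel in this PR):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def dequant4_kernel(qweight_ptr, qzeros_ptr, scales_ptr, g_idx_ptr, out_ptr,
                    N, numels, BLOCK: tl.constexpr):
    # One lane per element of the dense [K, N] output weight.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < numels
    row = offs // N                  # input-feature index k
    col = offs % N                   # output-feature index n

    # 8 x 4-bit values per int32: packed along rows for qweight, along columns for qzeros.
    qw = tl.load(qweight_ptr + (row // 8) * N + col, mask=mask, other=0)
    w4 = (qw >> ((row % 8) * 4)) & 0xF

    g = tl.load(g_idx_ptr + row, mask=mask, other=0)
    qz = tl.load(qzeros_ptr + g * (N // 8) + col // 8, mask=mask, other=0)
    z4 = ((qz >> ((col % 8) * 4)) & 0xF) + 1   # assumed "zeros stored minus one" convention

    s = tl.load(scales_ptr + g * N + col, mask=mask, other=0.0)
    w = (w4 - z4).to(tl.float32) * s.to(tl.float32)
    tl.store(out_ptr + offs, w.to(tl.float16), mask=mask)

def dequant4(qweight, qzeros, scales, g_idx):
    K, N = qweight.shape[0] * 8, qweight.shape[1]
    out = torch.empty((K, N), dtype=torch.float16, device=qweight.device)
    grid = lambda meta: (triton.cdiv(out.numel(), meta["BLOCK"]),)
    dequant4_kernel[grid](qweight, qzeros, scales, g_idx, out, N, out.numel(), BLOCK=1024)
    return out
```

With the dense weight materialized once, the base matmul and its transposed counterpart in the backward pass can be left to torch.matmul / cuBLAS.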

Trainer results after 20 steps on guanaco for llama-{gptq,bnb} 4-bit:

  • hf-gptq
{
  "train_runtime": 113.4277,
  "train_samples_per_second": 1.411,
  "train_steps_per_second": 0.176,
  "train_loss": 1.3709101617336272,
  "epoch": 0.02
}
  • unsloth-gptq-triton
{
  "train_runtime": 69.5648,
  "train_samples_per_second": 2.3,
  "train_steps_per_second": 0.288,
  "train_loss": 1.3829106092453003,
  "epoch": 0.02
}
  • unsloth-bnb
{
  "train_runtime": 63.8765,
  "train_samples_per_second": 2.505,
  "train_steps_per_second": 0.313,
  "train_loss": 1.3803951740264893,
  "epoch": 0.02
}

@danielhanchen
Contributor

@jeromeku Extremely extremely fabulous work!!! Now that is a fantastic performance boost from HF's GPTQ!! It looks like splitting the dequantization step and matmul did the trick!! Again super duper appreciate you adding GPTQ support into Unsloth - highly appreciate it :)

@jeromeku
Author

jeromeku commented Feb 5, 2024

@danielhanchen

Cleaned up the dequant kernel.

Re-running the above benchmark (20 train steps on TheBloke/Llama-2-7b-GPTQ) gives the following:

{
  "train_runtime": 67.3811,
  "train_samples_per_second": 2.375,
  "train_steps_per_second": 0.297,
  "train_loss": 1.3829236447811126,
  "epoch": 0.02
}

To reproduce, run

python benchmark.py --model_name=llama --model_type=unsloth-gptq-triton --dtype=float16 --dataset_id=guanaco --output_dir=./bench_results

Replace --model_type with hf-gptq-default or unsloth-bnb to benchmark those configurations.

See PROFILING.MD for more details on running the benchmark script.
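
For convenience, a small sketch that sweeps all three configurations with the same flags (just a loop around the command above):

```python
import subprocess

for model_type in ("unsloth-gptq-triton", "hf-gptq-default", "unsloth-bnb"):
    subprocess.run(
        [
            "python", "benchmark.py",
            "--model_name=llama",
            f"--model_type={model_type}",
            "--dtype=float16",
            "--dataset_id=guanaco",
            "--output_dir=./bench_results",
        ],
        check=True,
    )
```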

@danielhanchen
Contributor

@jeromeku Super duper great work again! I will take a look later today! Thanks so much for your contribution again!

@danielhanchen
Contributor

@jeromeku Hey, sorry about the delay! Extreme apologies again, I didn't have time to take a look :( I will do so asap in the next few days! Sorry again, and super great work again! :)

@jeromeku jeromeku marked this pull request as ready for review March 7, 2024 02:28