[Feature request] Support GPTQ quantization #39

Open
araleza opened this issue Dec 17, 2023 · 35 comments
Labels
help wanted Help from the OSS community wanted! on roadmap Feature request on roadmap


@araleza

araleza commented Dec 17, 2023

So I have a GPTQ llama model I downloaded (from TheBloke), and it's already 4 bit quantized. I have to pass in False for the load_in_4bit parameter of:

model, tokenizer = FastLlamaModel.from_pretrained(

because if I don't, I get an error thrown saying:

The model is already quantized with gptq. You can't quantize it again with bitsandbytes

But, if I pass in False for load_in_4bit, this code makes bnb_config be None:

        bnb_config = None
        if load_in_4bit:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit              = True,
                bnb_4bit_use_double_quant = True,
                bnb_4bit_quant_type       = "nf4",
                bnb_4bit_compute_dtype    = dtype,
            )

and that makes quantization_config be None as well:

quantization_config = bnb_config,

and that crashes here:

        if hasattr(self, "quantization_config"):
            output["quantization_config"] = (
                self.quantization_config.to_dict()

with the error message:

'NoneType' object has no attribute 'to_dict'

So I'm not sure how to LoRA train this llama model. Any thoughts?

@araleza
Author

araleza commented Dec 17, 2023

I tried adding:

[...] and self.quantization_config is not None:

to the end of that line there (and similar additions in two other places that came up), and it hasn't crashed, but it's now taking a very long time to load the model, so maybe it's doing some unwanted conversion?
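For reference, the guard described above can be sketched on its own. HypotheticalConfig and StubQuantConfig are made-up stand-ins, not the transformers classes. The point is only that hasattr() is true even when the attribute holds None, so the check has to test the value too. As the next comments show, this silences the crash but does not make a GPTQ checkpoint loadable.

```python
class StubQuantConfig:
    """Hypothetical stand-in for a quantization config object."""

    def to_dict(self):
        return {"quant_method": "gptq", "bits": 4}


class HypotheticalConfig:
    """Hypothetical stand-in for a model config whose quantization_config may be None."""

    def __init__(self, quantization_config=None):
        # The attribute always exists, so hasattr() alone cannot detect None
        self.quantization_config = quantization_config

    def to_dict(self):
        output = {"model_type": "llama"}
        # Guard on the value, not just on the attribute's existence
        if getattr(self, "quantization_config", None) is not None:
            output["quantization_config"] = self.quantization_config.to_dict()
        return output


print(HypotheticalConfig().to_dict())                   # no quantization_config key
print(HypotheticalConfig(StubQuantConfig()).to_dict())  # includes the nested dict
```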

@araleza
Author

araleza commented Dec 17, 2023

Yeah, it finally 'loaded' but then it said some weights of the model checkpoint were not used when initializing LlamaForCausalLM, and it listed a giant list of weights, which I'm guessing was all of them.

Then the LoRA training crashed with:

Cannot copy out of meta tensor; no data!

So something definitely did not go well.

@danielhanchen
Contributor

@araleza Oh no I don't think GPTQ models are supported as of yet

@danielhanchen
Contributor

Currently only QLoRA via bitsandbytes is supported, hence all the error messages. If GPTQ is a super popular request, I will add it in - the dequantization steps will just be replaced, but I will have to read up on how GPTQ does it internally.

For now, is it possible to use a non GPTQ quantized model?

@danielhanchen danielhanchen changed the title Can't load a 4-bit quantized GPTQ model [Feature request] Support GPTQ quantization Dec 17, 2023
@araleza
Author

araleza commented Dec 17, 2023

For now, is it possible to use a non GPTQ quantized model?

I don't know actually... I've only done LoRA training with oobabooga's Training tab, and it can only do LoRA training with unquantized models, or GPTQ models (which you have to load with the Transformers loader). So I don't know how to load a quantized model of any format except GPTQ onto my GPU. Any tips for which format to use instead, but still have it fit on my 24GB GPU?

@danielhanchen
Contributor

@araleza Would it be possible to try loading a non-quantized model, then pass load_in_4bit = True via Unsloth? It should load on your CPU / RAM, quantize it there, and then load it onto the GPU.

@danielhanchen
Contributor

I'll see for a future release if I can add GPTQ support!

@danielhanchen
Contributor

I was actually just reading up on HQQ (Half-Quadratic Quantization) https://github.com/mobiusml/hqq and maybe I'll be adding HQQ instead of GPTQ, since HQQ needs no calibration data, whilst GPTQ does.
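Needing no calibration data means the quantizer only looks at the weights themselves. The simplest calibration-free baseline is asymmetric round-to-nearest with one scale and zero point per group, sketched below in plain Python. This is not HQQ: roughly speaking, HQQ starts from a similar grouped affine format but refines the zero point (and optionally the scale) with a half-quadratic solver. The group size of 4 and the unclamped integer zero point are simplifying assumptions.

```python
def quantize_rtn(weights, bits=4, group_size=4):
    """Asymmetric round-to-nearest quantization with one (scale, zero) per group.

    Uses only the weights themselves: no calibration activations are needed.
    """
    qmax = (1 << bits) - 1
    q, scales, zeros = [], [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / qmax if hi > lo else 1.0
        zero = round(-lo / scale)  # integer zero point, left unclamped for simplicity
        scales.append(scale)
        zeros.append(zero)
        # Clamp codes into the representable range [0, qmax]
        q.extend(min(qmax, max(0, round(w / scale) + zero)) for w in group)
    return q, scales, zeros


def dequantize_rtn(q, scales, zeros, group_size=4):
    """Map integer codes back to floats: w ~= (q - zero) * scale."""
    return [(code - zeros[i // group_size]) * scales[i // group_size]
            for i, code in enumerate(q)]


weights = [0.12, -0.5, 0.33, 0.9, -1.2, 0.05, 0.7, -0.3]
codes, scales, zeros = quantize_rtn(weights)
restored = dequantize_rtn(codes, scales, zeros)
for i, (w, r) in enumerate(zip(weights, restored)):
    # Reconstruction error is at most about half a quantization step
    print(f"{w:+.3f} -> {r:+.3f} (step {scales[i // 4]:.3f})")
```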

@araleza
Author

araleza commented Dec 31, 2023

Sounds good. I think you've got two groups of people who want to use your software:

  1. people who have a big model and big training data, and want the fine tuning to be faster
  2. people with 24GB cards who want to train larger models, but without quantizing them so badly that the training is meaningless.

Supporting HQQ would help the people in group 2, like me.

@danielhanchen
Contributor

@araleza Cool I'll get on with HQQ! It seems like even Mixtral can supposedly fit on a 24GB card!

But HQQ supports 8, 4, 3 and 2 bit quantization so it'll be pretty useful!
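A back-of-the-envelope check of the Mixtral claim: weight memory is roughly parameters × bits / 8 bytes. The sketch assumes about 46.7B total parameters for Mixtral 8x7B and treats the card as 24 GiB. It ignores per-group scales and zero points, activations, LoRA weights, optimizer state and the KV cache, so real usage is higher.

```python
def weight_gib(n_params, bits):
    """Approximate weight storage in GiB for a given bit width (weights only)."""
    return n_params * bits / 8 / 2**30


MIXTRAL_PARAMS = 46.7e9  # assumption: approximate total parameters of Mixtral 8x7B
BUDGET_GIB = 24          # assumption: a "24GB" card treated as 24 GiB, no driver overhead

for bits in (8, 4, 3, 2):  # the bit widths HQQ supports
    gib = weight_gib(MIXTRAL_PARAMS, bits)
    print(f"{bits}-bit: {gib:5.1f} GiB, headroom {BUDGET_GIB - gib:+5.1f} GiB")
```

At 4 bits the weights alone nearly fill the card, which is why the 3- and 2-bit options matter for the 24GB group described above.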

@jeromeku

@danielhanchen happy to pitch in with quantization (or other feature requests). let me know how best to contribute!

@danielhanchen
Contributor

@jeromeku More than happy to collaborate! I was actually taking a look at GPTQ the other day - I guess technically Unsloth can add in GPTQ during training - all we need is to port GPTQ's dequantization kernels (to float16 / bfloat16), and if that works, then GPTQ will be supported.

For now, I'm using bitsandbytes' dequantization kernels for everything.

Again more than happy to collaborate if you're interested!

@jeromeku

jeromeku commented Jan 14, 2024

@danielhanchen
That should work -- this is what QLoRA does under the hood for non-LoRA weights right? I.e., dequantizes 'frozen' weights to f16 / bf16 in order to pass grads through non-LoRA layers.

I can take a crack at this if you're more keen working on hqq...

@danielhanchen
Contributor

@jeromeku I'll investigate GPTQ's dequant kernels as well! But if you're interested in adding GPTQ support - I'm more than happy for a few more OSS collaborators!

Essentially in terms of the main gist of things:

  1. Find how GPTQ dequantizes its quantized weights to float16 / bfloat16
  2. Extract this functionality from say Huggingface internals or some other provider like Exllama / llama.cpp etc
  3. Replace fast_dequantize with GPTQ equivalent kernels
  4. Fix up a few lines where Linear4bit naming conventions are seen with GPTQ equivalent conventions.
  5. If 3 works as is, then Unsloth is now GPTQ compatible!

If you wanna take a crack at that - I'll be super grateful! In fact just step 1 or 2 is enough for a general GPTQ integration!
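Steps 1 and 2 mostly come down to unpacking. GPTQ checkpoints store several 4-bit codes per 32-bit integer, so dequantization is a shift-and-mask followed by (q - zero) * scale. The sketch below assumes eight codes per word, lowest nibble first, and a single scale and zero per group. The real AutoGPTQ layout (which dimension is packed, how the zeros are themselves packed and offset, act-order g_idx) may differ and should be checked against its kernels.

```python
def pack_int4(codes):
    """Pack eight 4-bit codes into one 32-bit word, lowest nibble first (assumed layout)."""
    assert len(codes) == 8 and all(0 <= c < 16 for c in codes)
    word = 0
    for i, c in enumerate(codes):
        word |= c << (4 * i)
    return word


def unpack_int4(word):
    """Recover the eight codes with a right shift and a 0xF mask."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]


def dequantize_word(word, scale, zero):
    """Dequantize one packed word into floats: w = (q - zero) * scale."""
    return [(q - zero) * scale for q in unpack_int4(word)]


codes = [0, 3, 7, 8, 9, 12, 15, 1]
word = pack_int4(codes)
print(hex(word), unpack_int4(word))
print(dequantize_word(word, scale=0.25, zero=8))
```

Replacing fast_dequantize (step 3) would then mean doing this unpack-and-scale on the GPU for the whole weight matrix before the matmul.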

@jeromeku

@danielhanchen
Will work on it!

@danielhanchen
Contributor

@jeromeku Great! If you need any help - ask away! I guess we can use this Github issue as a central discussion area. I'll see if I have some time on GPTQ - probably next week ish - I'm trying to work on some other stuff currently.

Again thanks!

@jeromeku

@danielhanchen

Trying to understand design decisions / coding style of the library.

What is the purpose of patching {Mistral, Llama}_fast_forward when initializing Mistral (pre_patch)? It seems you are extracting sections directly from the original HF implementations of these layers (which already support flash-attn2) and in some cases using xformers for some of the ops.

Why the pass after every function? AFAIK this is a rather unconventional Python coding style.

@danielhanchen
Contributor

@jeromeku prepatch essentially just patches some portions of each function to call their relevant efficient implementation - ie as you mentioned some xformers some FA2.

Oh ye sorry on my coding style - I come from a C++ / C background, so I generally like all functions / if / for loops etc to be "enclosed" to make them "look" compartmentalized.

But you can have whatever coding style you like - for eg I like spaces around equals signs, even in keyword arguments, whilst the general style is f(var=2) and not f(var = 2). It definitely comes from my C background!!

If you're contributing code - I don't mind on style - that's the least of worries! :)) You can use any style you desire - it just has to work :)

@jeromeku

@danielhanchen

Any tools / tests you use to check the correctness of gradient implementations?

@danielhanchen
Contributor

@jeromeku Oh lol what I do is to get HF to do training, copy paste the training losses to Google Sheets, then with ur updated gradient implementation, log if the new training loss is mostly identical.

Another approach is to use torch.dist or torch.allclose on W.grad and new_W.grad to confirm the gradients. Since the output is not a scalar, you'll have to pass a gradient to backward, for eg Y.backward(dY) with a random dY, to get the gradients.
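torch.allclose tests |actual - expected| <= atol + rtol * |expected| elementwise, with defaults rtol=1e-5 and atol=1e-8, which is very strict for fp16 runs (torch.testing.assert_close, for comparison, defaults to rtol=1e-3, atol=1e-5 for float16). A small stand-alone helper in the same spirit, which also reports the max absolute and relative differences, works equally on flattened gradients or on per-step training losses. The name compare, its return fields and the loss values are hypothetical, for illustration only.

```python
import math


def compare(actual, expected, rtol=1e-5, atol=1e-8):
    """Elementwise closeness check modeled on torch.allclose, plus summary stats."""
    if len(actual) != len(expected):
        raise ValueError("length mismatch")
    max_abs = 0.0
    max_rel = 0.0
    close = True
    for a, e in zip(actual, expected):
        diff = abs(a - e)
        max_abs = max(max_abs, diff)
        if e != 0:
            max_rel = max(max_rel, diff / abs(e))
        elif diff != 0:
            max_rel = math.inf
        # Same criterion torch.allclose applies to every element
        close = close and diff <= atol + rtol * abs(e)
    return {"allclose": close, "max_abs_diff": max_abs, "max_rel_diff": max_rel}


# Made-up per-step losses from a reference run and a patched run
ref_losses = [1.8320, 1.7410, 1.6950]
new_losses = [1.8325, 1.7402, 1.6951]
print(compare(new_losses, ref_losses))             # strict defaults: not close
print(compare(new_losses, ref_losses, rtol=1e-3))  # fp16-style tolerance: close
```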

@jeromeku

@danielhanchen

Ok, was wondering if there was a more efficient way to do this verification. I was trying to use torch.autograd.gradcheck, but it runs into issues with large inputs / outputs and mixed precision, since it needs to realize the full Jacobian during the numerical / analytical gradient calculation.

I've adapted GPTQ code to re-implement fast_lora custom fwd / bwd and should have the rest done by early next week.

A minimal way to check the gradient is being calculated correctly -- akin to a unit test -- without having to do a training run would be a worthwhile effort both for existing and future implementations.
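One way to get such a unit test without gradcheck's memory cost is to apply the same idea to deliberately tiny shapes in float64: compute the custom backward's gradients analytically and compare them against central finite differences of the scalar loss L = sum(dY * Y). The sketch below does this in plain Python for a single LoRA linear layer Y = W x + s * B (A x). The shapes, scaling s and seed are arbitrary, and B is random rather than zero so that dA is not trivially zero (similar in spirit to init_lora_weights=False in the script below). Because L is linear in each individual entry, central differences are exact up to rounding here.

```python
import random


def matvec(M, v):
    """Dense matrix-vector product on nested lists."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def lora_forward(W, A, B, x, s):
    """Y = W x + s * B (A x) for a single input vector x."""
    bax = matvec(B, matvec(A, x))
    return [wx + s * b for wx, b in zip(matvec(W, x), bax)]


def lora_backward(W, A, B, x, s, dy):
    """Analytic vector-Jacobian products for dA, dB and dx given upstream dY."""
    ax = matvec(A, x)
    r, d_out, d_in = len(A), len(W), len(x)
    bt_dy = [sum(B[i][k] * dy[i] for i in range(d_out)) for k in range(r)]
    dA = [[s * bt_dy[k] * x[j] for j in range(d_in)] for k in range(r)]
    dB = [[s * dy[i] * ax[k] for k in range(r)] for i in range(d_out)]
    dx = [sum(dy[i] * W[i][j] for i in range(d_out))
          + s * sum(bt_dy[k] * A[k][j] for k in range(r))
          for j in range(d_in)]
    return dA, dB, dx


def numerical_grad(loss_fn, M, h=1e-6):
    """Central differences w.r.t. every entry of nested list M (restored afterwards)."""
    grad = []
    for row in M:
        g_row = []
        for j in range(len(row)):
            original = row[j]
            row[j] = original + h
            up = loss_fn()
            row[j] = original - h
            down = loss_fn()
            row[j] = original
            g_row.append((up - down) / (2 * h))
        grad.append(g_row)
    return grad


def max_abs_diff(M1, M2):
    return max(abs(a - b) for r1, r2 in zip(M1, M2) for a, b in zip(r1, r2))


def check_lora_grads(d_in=5, d_out=4, r=2, s=2.0, seed=0):
    rng = random.Random(seed)

    def rand(rows, cols):
        return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

    W, A, B = rand(d_out, d_in), rand(r, d_in), rand(d_out, r)
    x = [rng.gauss(0.0, 1.0) for _ in range(d_in)]
    dy = [rng.gauss(0.0, 1.0) for _ in range(d_out)]  # random upstream gradient

    def loss():
        return sum(g * y for g, y in zip(dy, lora_forward(W, A, B, x, s)))

    dA, dB, dx = lora_backward(W, A, B, x, s, dy)
    return {
        "dA": max_abs_diff(dA, numerical_grad(loss, A)),
        "dB": max_abs_diff(dB, numerical_grad(loss, B)),
        "dx": max_abs_diff([dx], numerical_grad(loss, [x])),
    }


print(check_lora_grads())  # all differences should be tiny (float64 rounding only)
```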

@danielhanchen
Contributor

@jeromeku Actually I did technically make some functions to check gradients somewhere - I manually made some random inputs and some random output gradients, then backpropagated with outputs.backward(grads), and checked every item's .grad to confirm it - I just need to find where I wrote it :))

@jeromeku

@danielhanchen

I wrote a small test script to do gradient checking:

import torch
from datasets import load_dataset

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.utils.data import DataLoader

from unsloth import FastLanguageModel

DTYPE = torch.float16


def get_model(
    model_id="unsloth/mistral-7b-bnb-4bit",
    reference=True,
    max_seq_length=2048,
    dtype=torch.float16,
    load_in_4bit=True,
    init_lora_weights=False,
    upcast=True,
):
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
    )

    lora_config = LoraConfig(
        r=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        task_type="CAUSAL_LM",
        init_lora_weights=init_lora_weights,
    )

    if reference:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": True},
        )
        model = get_peft_model(model, lora_config)
    else:
        config = lora_config.to_dict()
        del config["task_type"]
        model = FastLanguageModel.get_peft_model(
            model,
            use_gradient_checkpointing=True,
            random_state=3407,
            max_seq_length=max_seq_length,
            upcast=upcast,
            **config,
        )

    return model, tokenizer


ref_model, _ = get_model(dtype=DTYPE)
test_model, _ = get_model(dtype=DTYPE, reference=False)


def check_grad(model, dtype, seed=0, scale=1):
    wrapped_model = model.model.model
    embed_layer = wrapped_model.embed_tokens
    self_attn = wrapped_model.layers[0].self_attn
    mlp = wrapped_model.layers[0].mlp
    torch.manual_seed(seed)

    with torch.autocast(device_type="cuda", dtype=dtype):
        # embeddings = embed_layer(inputs)

        embeddings = torch.randn(
            1, 1, embed_layer.weight.shape[1], dtype=dtype, requires_grad=True
        ).cuda()
        print(f"Attention input dtype: {embeddings.dtype}")
        attn_out, *_ = self_attn(embeddings)
        print(f"Attn out dtype: {attn_out.dtype}")
        mlp_out = mlp(attn_out)

        torch.manual_seed(seed)
        fake_grad_output = scale * torch.randn(mlp_out.shape, dtype=torch.float32).to(
            mlp_out.device
        )
        mlp_out.backward(fake_grad_output)

    return mlp_out, mlp, attn_out, fake_grad_output


mlp_out_ref, mlp_ref, attn_out_ref, fake_grad_ref = check_grad(ref_model, dtype=DTYPE)
print(
    "Grad check after reference backwards:",
    test_model.model.model.layers[0].mlp.down_proj.lora_B.default.weight.grad,
)
mlp_out, mlp, attn_out, fake_grad = check_grad(test_model, dtype=DTYPE)

ref_type = torch.float32
print()
print(
    f"Checking fake grad (dY): {torch.allclose(fake_grad.to(ref_type), fake_grad_ref.to(ref_type))}"
)
# torch.max(torch.abs(fake_grad.to(ref_type) - fake_grad_ref.to(ref_type)))
# torch.allclose(mlp_out.to(ref_type), mlp_out_ref.to(ref_type))

print(f"Checking mlp grads:")
for (n1, m1), (n2, m2) in zip(mlp.named_parameters(), mlp_ref.named_parameters()):
    if "lora" in n1 and "lora" in n2:
        n1 = ".".join(n1.split(".")[:2])
        print(f"{n1}")
        print(
            f"Mean grad:\n  UNSLOTH: {m1.grad.max():.10f}\n  REF: {m2.grad.mean():.10f}\nMax abs diff: {torch.max(torch.abs(m1.grad - m2.grad)):.10f}\nMean abs diff: {torch.mean(torch.abs(m1.grad - m2.grad)):.10f}"
        )
        print()

print("Checking attn grads:")
for (n1, m1), (n2, m2) in zip(
    ref_model.model.model.layers[0].self_attn.named_parameters(),
    test_model.model.model.layers[0].self_attn.named_parameters(),
):
    if "lora" in n1 and "lora" in n2:
        # torch.allclose(m1.grad.to(dtype), m2.grad.to(dtype))
        n1 = ".".join(n1.split(".")[:2])
        print(f"{n1}")
        print(
            f"Mean grad:\n  UNSLOTH: {m1.grad.max():.10f}\n  REF: {m2.grad.max():.10f}\nMax abs diff: {torch.max(torch.abs(m1.grad - m2.grad)):.10f}\nMean abs diff: {torch.mean(torch.abs(m1.grad - m2.grad)):.10f}"
        )
        print()

Note: there are small inconsistencies between prepare_model_for_kbit_training in unsloth vs. huggingface peft when doing QLoRA fine-tuning -- peft upcasts all non-INT-8 params to fp32 -- see here.

I added an upcast kwarg to unsloth FastLanguageModel.get_peft_model that is passed to prepare_model_for_kbit_training to replicate this behavior:

from typing import Any, Optional

import torch


def prepare_model_for_kbit_training(
    model: Any,
    use_gradient_checkpointing: bool = True,
    use_reentrant: Optional[bool] = True,
    upcast: bool = False,
) -> Any:
    """
    Calculates where to place the gradient checkpoints given n_layers.
    We also freeze all other layers' gradients.

    Args:
        model: Any LlamaModel with layers.
        use_gradient_checkpointing (`bool`, *optional*):
            Default enabled. Provides memory savings by not saving all activations,
            but only some.
        use_reentrant (`bool`, *optional*):
            https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354
            Optimal gradient checkpointing algorithm which will be the default in
            future Pytorch versions.
        upcast (`bool`, *optional*):
            Default disabled. If True, casts float16 / bfloat16 parameters to
            float32, replicating peft's prepare_model_for_kbit_training.
    """

    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad_(False)

    # Cast non INT8 parameters to fp32
    if upcast:
        for param in model.parameters():
            if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
                param.data = param.data.to(torch.float32)

    if use_gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # If use_reentrant = True which is the Pytorch default, we just make the input requires_grad.
    if use_reentrant:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    return model

Here is the output from running the above script:

Checking mlp grads:
gate_proj.lora_A
Mean grad:
  UNSLOTH: 0.0441589355
  REF: 0.0000020351
Max abs diff: 0.1207160950
Mean abs diff: 0.0097856047

gate_proj.lora_B
Mean grad:
  UNSLOTH: 0.0051155090
  REF: 0.0000001698
Max abs diff: 0.0086461902
Mean abs diff: 0.0002924677

up_proj.lora_A
Mean grad:
  UNSLOTH: 0.0850219727
  REF: -0.0000299520
Max abs diff: 0.1020736694
Mean abs diff: 0.0135316616

up_proj.lora_B
Mean grad:
  UNSLOTH: 0.0048866272
  REF: -0.0000000757
Max abs diff: 0.0068296790
Mean abs diff: 0.0002973406

down_proj.lora_A
Mean grad:
  UNSLOTH: 0.0928344727
  REF: -0.0000352956
Max abs diff: 0.2047328949
Mean abs diff: 0.0073212739

down_proj.lora_B
Mean grad:
  UNSLOTH: 0.0037288666
  REF: 0.0000003116
Max abs diff: 0.0040407181
Mean abs diff: 0.0002820148

Checking attn grads:
q_proj.lora_A
Mean grad:
  UNSLOTH: -0.0000000000
  REF: -0.0000000000
Max abs diff: 0.0000000000
Mean abs diff: 0.0000000000

q_proj.lora_B
Mean grad:
  UNSLOTH: 0.0000000000
  REF: -0.0000000000
Max abs diff: 0.0000000000
Mean abs diff: 0.0000000000

k_proj.lora_A
Mean grad:
  UNSLOTH: -0.0000000000
  REF: -0.0000000000
Max abs diff: 0.0000000000
Mean abs diff: 0.0000000000

k_proj.lora_B
Mean grad:
  UNSLOTH: -0.0000000000
  REF: 0.0000000000
Max abs diff: 0.0000000000
Mean abs diff: 0.0000000000

v_proj.lora_A
Mean grad:
  UNSLOTH: 0.1055297852
  REF: 0.1329345703
Max abs diff: 0.1655731201
Mean abs diff: 0.0144135132

v_proj.lora_B
Mean grad:
  UNSLOTH: 0.0139694214
  REF: 0.0166625977
Max abs diff: 0.0193632841
Mean abs diff: 0.0024413881

o_proj.lora_A
Mean grad:
  UNSLOTH: 0.1630859375
  REF: 0.1149902344
Max abs diff: 0.1842651367
Mean abs diff: 0.0191203523

o_proj.lora_B
Mean grad:
  UNSLOTH: 0.0102157593
  REF: 0.0053596497
Max abs diff: 0.0119572878
Mean abs diff: 0.0010805393

Thoughts?

@danielhanchen
Contributor

@jeromeku Great work! Some pointers:

torch.manual_seed sadly does not actually work on GPUs - torch.cuda.manual_seed is the one you want!!

torch.randn can also take device = "cuda" - so I guess my first point of manual_seed is irrelevant since ur copying from CPU to GPU

Yep, one issue is the upcasting to float32: skipping it is one of the optimizations we found for VRAM reduction.
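To put a rough number on the VRAM side of the upcast: under QLoRA the 4-bit linear layers are untouched by it, but every remaining 16-bit parameter grows from 2 to 4 bytes. The sketch assumes Mistral-7B's shape (32000-token vocabulary, hidden size 4096, 32 layers, untied embeddings) and assumes the embeddings, lm_head and RMSNorm weights are the only parameters left in 16-bit after 4-bit loading. It counts weights only, giving roughly half a gigabyte.

```python
def upcast_extra_bytes(vocab=32000, hidden=4096, n_layers=32, tied_embeddings=False):
    """Extra bytes from casting the non-quantized 16-bit weights to fp32.

    Assumes (hypothetically) that only the embeddings, lm_head and RMSNorm
    weights remain in 16-bit after 4-bit loading.
    """
    embed = vocab * hidden
    lm_head = 0 if tied_embeddings else vocab * hidden
    norms = n_layers * 2 * hidden + hidden  # two RMSNorms per layer plus the final norm
    n_params = embed + lm_head + norms
    return n_params, n_params * (4 - 2)  # fp32 (4 bytes) minus fp16/bf16 (2 bytes)


n_params, extra = upcast_extra_bytes()
print(f"{n_params:,} params upcast -> +{extra / 2**30:.2f} GiB")
```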

You can see there are error differences - mainly due to Flash Attention - Pytorch does Q @ K.T and other attention ops in float16, whilst FA upcasts internally to fp32, which makes it more equivalent to full float32 training - hence the error differences.
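A deliberately extreme toy shows why the precision of intermediate ops matters: sum 10,000 copies of 0.01 while rounding the running total to half precision after every step, versus single precision. The struct module's 'e' and 'f' formats are used to emulate IEEE fp16 and fp32 rounding in plain Python. Real GPU matmul and attention kernels generally keep fp32 accumulators even for fp16 inputs, so their errors are far smaller; the toy only illustrates the direction of the effect, not its size.

```python
import struct


def round_to(fmt, x):
    """Round a Python float to IEEE half ('e') or single ('f') precision."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def accumulate(values, fmt):
    """Sum values, rounding the running total to the given precision at every step."""
    total = 0.0
    for v in values:
        total = round_to(fmt, total + round_to(fmt, v))
    return total


values = [0.01] * 10_000  # exact sum is 100
sum_fp16 = accumulate(values, "e")
sum_fp32 = accumulate(values, "f")
print(f"fp16 running sum: {sum_fp16}")  # stalls at 32.0, where 0.01 < half an fp16 ulp
print(f"fp32 running sum: {sum_fp32}")  # close to 100
```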

I think the reference model you used does not have FA enabled.

But ye - great work again - super useful script :)))

@jeromeku

@danielhanchen

What do you consider permissible range of gradient discrepancies between the unsloth and the reference HF implementation?

I.e., there are differences (e.g., up_proj) that are on the same order of magnitude as the mean grads themselves -- can this be chalked up to the use of f32 vs f16?

@danielhanchen
Contributor

@jeromeku Ye, that's one of the issues I found as well when verifying Unsloth vs normal HF - that's why for now I opted to just compare training losses directly

@jeromeku

@danielhanchen

Just wanted to give a quick update:

  • I have a working implementation of gptq fast_lora.
    • I patched a triton quantized matmul kernel into the existing fused forward / backward layers
    • Training works and the losses are on par with the default HF gptq fine-tuner (the non-fused, torch-only GPTQ fine-tuning model if you provide a gptq quantized model to the standard from_pretrained loader).
    • However, the training runs are slower than the default HF model (and also the unsloth bnb version).
  • Need to do some additional profiling / debugging to see where the problems are and whether a torch.compile version of the quantized matmul kernel outperforms the triton kernel.

@danielhanchen
Contributor

@jeromeku Super great work! Are you testing it on a Tesla T4 or an Ampere-based GPU? I found Triton kernels to be noticeably slower on older GPUs.

Also, through experimentation I found that instead of writing 1 fully fused kernel for the matrix multiply and dequantization, it's better to split it into 2. The dequant step should only take 1-2ms, whilst the matrix multiply takes 30ms or so. The compiler can be "confused" by the dequant steps and fail to optimize correctly, so I found using torch.matmul for the multiply to be most effective.

@jeromeku

@danielhanchen
I've been testing on an Ampere-based GPU (A6000).

  • Going to do some additional profiling to determine bottlenecks vs. vanilla HF implementation and the unsloth bnb version.
  • Additional optimizations after above analysis.
  • Will post a draft PR to make collab easier.

@danielhanchen
Contributor

@jeromeku Oh ok cool! If I have to guess, it's that NVCC / the Triton compiler is not optimizing "properly" - also did u use the matmul Triton autotuner? It could be that maybe?

@jeromeku

@danielhanchen
Yes - used a custom autotuner that is essentially the same as default triton matmul autotuner. Without the autotuner, performance is even worse.

@danielhanchen
Contributor

@jeromeku Ohh ok ok interesting - I'm just guessing somewhere the compiler is not optimizing the dequantization parts properly

@danielhanchen danielhanchen added on roadmap Feature request on roadmap help wanted Help from the OSS community wanted! labels Jan 27, 2024
@jeromeku

Did some preliminary profiling using torch.profiler of 4 implementations:

  • Default huggingface GPTQ peft tuning, which defaults to the auto_gptq cuda kernel which in turn defaults to a torch-only implementation
  • Huggingface GPTQ model but with auto_gptq triton quant linear layers replacing default quant linear layer
  • unsloth GPTQ fast_lora implementation which fuses triton quant matmul with LoRA adapters
  • unsloth bitsandbytes fast_lora implementation

All were 4-bit Mistral models ("TheBloke/Mistral-7B-v0.1-GPTQ" and "unsloth/mistral-7b-bnb-4bit") running a sample batch of data for 10 iterations (5 warmup, 5 active) using float16 as the torch.autocast dtype.

Summary results, sorted by CUDA time:

  • huggingface default gptq peft
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*        54.22%        3.263s        76.08%        4.578s     915.608ms        2.130s        34.20%        4.582s     916.372ms             5  
                                               aten::mm         3.13%     188.433ms         3.13%     188.433ms      21.007us     858.252ms        13.78%     858.252ms      95.680us          8970  
                                              aten::mul         1.24%      74.524ms         1.24%      74.524ms       9.083us     480.920ms         7.72%     480.920ms      58.613us          8205  
                                            aten::index         3.63%     218.297ms         3.84%     230.936ms      90.209us     414.696ms         6.66%     428.958ms     167.562us          2560  
                                            aten::copy_         1.02%      61.614ms         1.02%      61.614ms       5.030us     394.459ms         6.33%     394.459ms      32.201us         12250  
                              aten::bitwise_right_shift         0.40%      24.235ms         0.40%      24.235ms      10.819us     245.560ms         3.94%     245.560ms     109.625us          2240  
                                              aten::sub         0.13%       7.855ms         0.13%       7.855ms       6.951us     154.506ms         2.48%     154.506ms     136.731us          1130  
                                              aten::add         0.85%      50.884ms         0.85%      50.884ms       8.819us     137.561ms         2.21%     137.561ms      23.841us          5770  
                                      aten::bitwise_and         1.16%      69.900ms         1.16%      69.900ms      31.205us     107.831ms         1.73%     107.831ms      48.139us          2240  
autograd::engine::evaluate_function: BackwardHookFun...         3.71%     223.360ms         4.34%     261.250ms      20.104us      99.978ms         1.61%     189.637ms      14.593us         12995  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.018s
Self CUDA time total: 6.228s
  • huggingface gptq with triton patch
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*        50.17%        4.002s        69.23%        5.522s        1.104s        2.921s        35.56%        5.457s        1.091s             5  
                                    QuantLinearFunction         5.56%     443.474ms         5.60%     447.013ms     399.119us        1.879s        22.88%        1.892s       1.689ms          1120  
                            QuantLinearFunctionBackward         3.29%     262.817ms         3.34%     266.309ms     237.776us        1.697s        20.65%        1.709s       1.526ms          1120  
                                            aten::copy_         1.11%      88.517ms         1.11%      88.517ms       6.109us     272.443ms         3.32%     272.443ms      18.802us         14490  
                                              aten::mul         0.89%      71.151ms         0.89%      71.151ms      10.042us     208.454ms         2.54%     208.454ms      29.422us          7085  
                                               aten::mm         2.01%     160.412ms         2.01%     160.412ms      23.835us     120.875ms         1.47%     120.875ms      17.961us          6730  
                                             aten::add_         0.39%      31.360ms         0.39%      31.360ms       6.316us      84.104ms         1.02%      84.104ms      16.939us          4965  
                                              aten::add         0.26%      20.900ms         0.26%      20.900ms       8.672us      65.398ms         0.80%      65.398ms      27.136us          2410  
                                              aten::bmm         0.27%      21.311ms         0.27%      21.311ms      22.199us      53.940ms         0.66%      53.940ms      56.188us           960  
                                         aten::_to_copy         2.16%     172.124ms         3.33%     265.847ms      24.312us      51.294ms         0.62%     294.112ms      26.896us         10935  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.977s
Self CUDA time total: 8.214s
  • unsloth triton
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*        36.56%        2.499s        66.76%        4.564s     912.765ms        2.388s        35.07%        4.500s     899.941ms             5  
                                       LoRA_MLPBackward         3.19%     217.928ms         6.51%     445.002ms       2.781ms        1.435s        21.08%        1.618s      10.110ms           160  
                                               LoRA_MLP         2.21%     150.863ms         3.60%     246.300ms       1.539ms        1.370s        20.11%        1.466s       9.162ms           160  
                                               LoRA_QKV         2.02%     138.021ms         3.51%     240.104ms       1.501ms     239.376ms         3.51%     304.731ms       1.905ms           160  
                                       LoRA_QKVBackward        14.58%     996.555ms        18.14%        1.240s       7.752ms     233.183ms         3.42%     370.001ms       2.313ms           160  
                                               aten::mm         1.94%     132.859ms         1.94%     132.859ms      14.811us     162.068ms         2.38%     162.068ms      18.068us          8970  
                                                 LoRA_W         0.75%      51.438ms         1.25%      85.611ms     535.069us     130.279ms         1.91%     154.580ms     966.125us           160  
                                         LoRA_WBackward         1.09%      74.258ms         2.19%     149.957ms     937.231us     117.188ms         1.72%     162.495ms       1.016ms           160  
                                             aten::add_         0.45%      30.965ms         0.45%      30.965ms       6.048us     102.670ms         1.51%     102.670ms      20.053us          5120  
                                            aten::copy_         0.82%      56.340ms         0.82%      56.340ms       5.567us      55.663ms         0.82%      55.663ms       5.500us         10120  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.836s
Self CUDA time total: 6.810s
  • unsloth bitsandbytes
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*        42.36%        1.415s        63.84%        2.132s     426.477ms        1.277s        38.22%        2.111s     422.283ms             5  
                                               aten::mm         4.95%     165.347ms         4.95%     165.347ms      14.750us     882.710ms        26.43%     882.710ms      78.743us         11210  
                                       LoRA_MLPBackward         4.35%     145.286ms        11.18%     373.488ms       2.334ms     193.455ms         5.79%     665.271ms       4.158ms           160  
                                               LoRA_MLP         2.83%      94.647ms         5.66%     189.087ms       1.182ms     163.861ms         4.91%     547.066ms       3.419ms           160  
                                             aten::add_         1.27%      42.497ms         1.27%      42.497ms       5.774us     119.956ms         3.59%     119.956ms      16.298us          7360  
                                        aten::transpose         3.73%     124.711ms         3.86%     128.893ms       5.835us      66.135ms         1.98%      99.478ms       4.503us         22090  
                                       LoRA_QKVBackward         3.88%     129.752ms        10.58%     353.476ms       2.209ms      59.565ms         1.78%     266.389ms       1.665ms           160  
                                                aten::t         3.40%     113.613ms         6.65%     222.031ms      11.754us      57.437ms         1.72%     143.028ms       7.572us         18890  
                                            aten::copy_         1.33%      44.395ms         1.33%      44.395ms       4.933us      51.266ms         1.53%      51.266ms       5.696us          9000  
                                           aten::matmul         4.35%     145.459ms        12.34%     412.209ms      30.659us      44.673ms         1.34%        1.002s      74.520us         13445  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 3.340s
Self CUDA time total: 3.340s

It seems the custom LoRA layers of my GPTQ implementation are not as efficient as the existing bitsandbytes fast_lora implementations.

  • Going to play around with optimizing these custom autograd.Functions to see what other juice can be squeezed from triton quantized matmul kernels which are already autotuned.
  • Will also try torch.compile to see what kind of fused implementations it can spit out
  • Additional analysis using nsys also

Will draft PR the profiling script and documentation along with current fast_lora gptq implementation once cleaned up.

@danielhanchen
Contributor

@jeromeku LOVEE the detailed profiling!!! Just love it!! Great work again. Interesting, so the Unsloth BnB kernels run in around 3.34s whilst HF's GPTQ runs in 6.2s. HF GPTQ with your Triton patch is 8 ish seconds, and Unsloth with your Triton patch is 6.8 seconds.

Very interesting results! Did you manage to test a GPTQ dequantize-only kernel with Unsloth? I can see that in Unsloth BnB, matrix multiplies (aten::mm) are taking 26% of all CUDA time, whilst in HF GPTQ it's 13%, Unsloth Triton 3% (looks like overhead?) and HF + Triton 1.5%. The goal is to move the majority of the time over to matrix multiplies in order to leverage the GPU's Tensor Cores :))

But anyways I love the table and results and fabulous work!
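To make the comparison concrete, the sketch below divides each profile's "Self CUDA time total" by the Unsloth bitsandbytes one. The numbers are copied from the four tables above; only the dictionary keys are made up.

```python
# "Self CUDA time total" (seconds) from the four torch.profiler tables above
self_cuda_total = {
    "hf_gptq_default": 6.228,
    "hf_gptq_triton": 8.214,
    "unsloth_gptq_triton": 6.810,
    "unsloth_bnb": 3.340,
}


def relative_to(baseline, totals):
    """Express every total as a multiple of the baseline's total."""
    return {name: t / totals[baseline] for name, t in totals.items()}


for name, ratio in relative_to("unsloth_bnb", self_cuda_total).items():
    print(f"{name:20s} {ratio:.2f}x")
```

On this sample batch, the GPTQ fast_lora path takes about 2x the CUDA time of the bitsandbytes path and about 9% more than default HF GPTQ.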

@jeromeku

@danielhanchen

Yes -- there seems to be some overhead issues with the unsloth triton quant / dequant kernels.

Just opened a draft PR with the changes.
