Is it possible to finetune this on a custom dataset? #17

Open
asmith26 opened this issue Jan 10, 2024 · 7 comments

@asmith26

asmith26 commented Jan 10, 2024

Hi there,

Just wondering: is it possible to fine-tune this model on a custom dataset? If so, are there any examples or code?

Many thanks for any help, and for this amazing model, I'm finding it works really well!

@dvmazur
Owner

dvmazur commented Jan 10, 2024

Hi!

Full fine-tuning won't work as the model is quantized, but you could try fine-tuning the model using various PEFT techniques which work with quantized base models. Check out QLoRA for example.
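To illustrate why this works with a quantized base: in LoRA the base weight $W_0$ stays frozen and only two small matrices are trained. The worked numbers below assume a 4096 × 4096 projection (the size of Mixtral's q_proj) and rank $r = 8$.

$$
h = W_0 x + B A x, \qquad W_0 \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}} \ \text{(frozen, quantized)}, \quad A \in \mathbb{R}^{r \times d_{\text{in}}}, \quad B \in \mathbb{R}^{d_{\text{out}} \times r}
$$

$$
\text{trainable} = r\,(d_{\text{in}} + d_{\text{out}}) = 8 \cdot (4096 + 4096) = 65{,}536 \ \ll\ d_{\text{in}}\, d_{\text{out}} = 16{,}777{,}216
$$

Gradients still flow through $W_0$ to reach $A$ and $B$, but $W_0$ itself is never updated, so it can stay quantized. In this example that is about 0.39% of the layer's weights. The LoRA paper also scales $BA$ by $\alpha / r$; that factor is omitted here for brevity.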

Hope this is helpful.

@complete-dope

complete-dope commented Jan 16, 2024

@dvmazur, is there a link to an implementation of this? If you have done something similar, please share it; that would be helpful!

@asmith26, did you find any method?

@nmarafo

nmarafo commented Jan 24, 2024

The structure of the loaded model is:

```
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralAttention(
          (q_proj): HQQLinearTritonSavable()
          (k_proj): HQQLinearTritonSavable()
          (v_proj): HQQLinearTritonSavable()
          (o_proj): HQQLinearTritonSavable()
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): SparseMoeWrapper(
          (gate): Linear(in_features=4096, out_features=8, bias=False)
        )
        (input_layernorm): MixtralRMSNorm()
        (post_attention_layernorm): MixtralRMSNorm()
      )
    )
    (norm): MixtralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
```

When I try to apply LoRA with peft:

```python
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)
```

I get an error saying that peft does not support HQQLinearTritonSavable:

```
ValueError                                Traceback (most recent call last)
in <cell line: 12>()
     10 )
     11
---> 12 model = get_peft_model(model, config)
     13 print_trainable_parameters(model)

7 frames
/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/model.py in _create_new_module(lora_config, adapter_name, target, **kwargs)
    255 if new_module is None:
    256 # no module could be matched
--> 257 raise ValueError(
    258 f"Target module {target} is not supported. Currently, only the following modules are supported: "
    259 "torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D."

ValueError: Target module HQQLinearTritonSavable() is not supported. Currently, only the following modules are supported: torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D.
```
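You can catch this before calling get_peft_model by listing which target modules are not among the classes named in the error. This is a rough sketch with three simplifications:

- It approximates peft's name matching by comparing only the last component of each module name.
- It leaves out transformers' Conv1D so that it needs only torch.
- FakeQuantLinear and ToyAttention are hypothetical stand-ins for HQQLinearTritonSavable and MixtralAttention.

```python
import torch.nn as nn

# Classes the error message says LoRA supports (transformers' Conv1D omitted here).
SUPPORTED = (nn.Linear, nn.Embedding, nn.Conv2d)


def unsupported_targets(model, target_names):
    """List (module_name, class_name) for targets whose class LoRA cannot wrap."""
    found = []
    for name, module in model.named_modules():
        # Rough approximation of peft's matching: compare the last name component.
        if name.split(".")[-1] in target_names and not isinstance(module, SUPPORTED):
            found.append((name, type(module).__name__))
    return found


class FakeQuantLinear(nn.Module):  # hypothetical stand-in for HQQLinearTritonSavable
    def forward(self, x):
        return x


class ToyAttention(nn.Module):  # hypothetical stand-in for MixtralAttention
    def __init__(self):
        super().__init__()
        self.q_proj = FakeQuantLinear()
        self.k_proj = nn.Linear(4, 4)


toy = nn.Sequential(ToyAttention())
print(unsupported_targets(toy, ["q_proj", "k_proj"]))
# [('0.q_proj', 'FakeQuantLinear')]
```

On the model printed above, this should flag every q_proj, k_proj, v_proj and o_proj, since they are all HQQLinearTritonSavable.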

@dvmazur
Owner

dvmazur commented Jan 24, 2024

Hey, @nmarafo and @complete-dope!

It looks like using Hugging Face's peft to fine-tune the offloaded model is a bit tricky (mostly due to the custom layers), but I haven't looked into it myself.

A LoRA fine-tuning setup similar to the one in the original LoRA paper can be hacked together quite simply:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralAttention

class LoRALayer(nn.Module):
    def __init__(self, module: nn.Linear, rank: int):
        super().__init__()
        self.module = module  # base layer, left unchanged
        # Low-rank adapters: A projects the input down to `rank`, B projects back up.
        self.adapter_A = nn.Parameter(torch.empty(module.in_features, rank, device=module.weight.device))
        nn.init.kaiming_uniform_(self.adapter_A, a=5 ** 0.5)
        # B starts at zero, so the wrapped layer initially matches the base layer.
        self.adapter_B = nn.Parameter(torch.zeros(rank, module.out_features, device=module.weight.device))

    def forward(self, input):
        bottleneck = F.linear(input, self.adapter_A.T)
        residual = F.linear(bottleneck, self.adapter_B.T)
        return self.module(input) + residual

def custom_get_peft_model(model, rank):
    for _, module in model.named_modules():
        if not isinstance(module, MixtralAttention):
            continue
        module.q_proj = LoRALayer(module.q_proj, rank)
        # TODO: {k, v, o}_proj
    return model
```

Note that this example only applies LoRA to the attention parameters. Doing the same for the expert layers is trickier, as it might break the ExpertCache (I haven't looked into that myself yet).
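Filling in the TODO for all four projections might look like the sketch below. It differs from the snippet above in a few assumed ways:

- The shapes are passed in explicitly rather than read from module.in_features, because quantized layers may not expose that attribute.
- All pre-existing parameters are frozen, so only the adapters train.
- The adapters are cast to the activation dtype in forward, in case the model runs in float16.

add_lora and ToyAttention are hypothetical names. ToyAttention stands in for MixtralAttention so the example runs without the offloaded model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

PROJ_NAMES = ("q_proj", "k_proj", "v_proj", "o_proj")


class LoRALayer(nn.Module):
    """Wraps a frozen base layer and adds a trainable low-rank update."""

    def __init__(self, module, in_features, out_features, rank, device=None):
        super().__init__()
        self.module = module
        self.adapter_A = nn.Parameter(torch.empty(in_features, rank, device=device))
        nn.init.kaiming_uniform_(self.adapter_A, a=5 ** 0.5)
        # B starts at zero, so the wrapped layer initially reproduces the base layer.
        self.adapter_B = nn.Parameter(torch.zeros(rank, out_features, device=device))

    def forward(self, input):
        # Cast the adapters to the activation dtype (e.g. float16) so that
        # F.linear does not fail on a float32/float16 mismatch.
        bottleneck = F.linear(input, self.adapter_A.T.to(input.dtype))
        residual = F.linear(bottleneck, self.adapter_B.T.to(input.dtype))
        return self.module(input) + residual


def add_lora(model, attention_cls, shapes, rank, device=None):
    """Wrap each projection in PROJ_NAMES of every attention_cls module.

    shapes maps a projection name to (in_features, out_features).
    """
    # Freeze everything that already exists; only the adapters added below train.
    for param in model.parameters():
        param.requires_grad_(False)
    # Collect first so the module tree is not modified while it is being walked.
    attention_modules = [m for m in model.modules() if isinstance(m, attention_cls)]
    for attn in attention_modules:
        for name in PROJ_NAMES:
            in_features, out_features = shapes[name]
            wrapped = LoRALayer(getattr(attn, name), in_features, out_features, rank, device)
            setattr(attn, name, wrapped)
    return model


class ToyAttention(nn.Module):  # hypothetical stand-in for MixtralAttention
    def __init__(self, dim):
        super().__init__()
        for name in PROJ_NAMES:
            setattr(self, name, nn.Linear(dim, dim, bias=False))


toy = nn.Sequential(ToyAttention(16), ToyAttention(16))
add_lora(toy, ToyAttention, {name: (16, 16) for name in PROJ_NAMES}, rank=4)
trainable = [n for n, p in toy.named_parameters() if p.requires_grad]
print(len(trainable))  # 16: adapter_A and adapter_B for 4 projections in 2 layers
```

With the real model you would pass MixtralAttention as attention_cls. For device, you would pass something like the base layer's device (module.W_q.device, as used later in this thread). Only the parameters that still have requires_grad set would then go to the optimizer.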

@nmarafo

nmarafo commented Jan 24, 2024

Thank you very much for the answer.

Sorry for my inexperience; I'm trying to implement it like this:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralAttention

class LoRALayer(nn.Module):
    def __init__(self, module: nn.Linear, rank: int):
        super().__init__()
        self.module = module
        self.adapter_A = nn.Parameter(torch.empty(module.in_features, rank, device=module.weight.device))
        nn.init.kaiming_uniform_(self.adapter_A, a=5 ** 0.5)
        self.adapter_B = nn.Parameter(torch.zeros(rank, module.out_features, device=module.weight.device))

    def forward(self, input):
        bottleneck = F.linear(input, self.adapter_A.T)
        residual = F.linear(bottleneck, self.adapter_B.T)
        return self.module(input) + residual

def custom_get_peft_model(model, rank):
    for _, module in model.named_modules():
        if not isinstance(module, MixtralAttention):
            continue
        module.q_proj = LoRALayer(module.q_proj, rank)
        # TODO: {k, v, o}_proj
    return model

model = custom_get_peft_model(model, rank=8)
```

and I get this error:

```
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
   1693             if name in modules:
   1694                 return modules[name]
-> 1695         raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
   1696 
   1697     def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:

AttributeError: 'HQQLinearTritonSavable' object has no attribute 'in_features'
```

@nmarafo

nmarafo commented Jan 24, 2024

Perhaps this solves it:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralAttention
from src.custom_layers import HQQLinearTritonSavable

class LoRALayer(nn.Module):
    def __init__(self, module: HQQLinearTritonSavable, rank: int):
        super().__init__()
        self.module = module
        # The quantized layer has no in_features/out_features; read them from its metadata.
        in_features = module.meta['shape'][1]
        out_features = module.meta['shape'][0]
        self.adapter_A = nn.Parameter(torch.empty(in_features, rank, device=module.W_q.device))
        nn.init.kaiming_uniform_(self.adapter_A, a=5 ** 0.5)
        self.adapter_B = nn.Parameter(torch.zeros(rank, out_features, device=module.W_q.device))

    def forward(self, input):
        bottleneck = F.linear(input, self.adapter_A.T)
        residual = F.linear(bottleneck, self.adapter_B.T)
        return self.module(input) + residual

def custom_get_peft_model(model, rank):
    for _, module in model.named_modules():
        if not isinstance(module, MixtralAttention):
            continue
        module.q_proj = LoRALayer(module.q_proj, rank)
        # TODO: {k, v, o}_proj
    return model

model = custom_get_peft_model(model, rank=8)
```
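A small helper can read the shape from whichever attribute exists, giving one code path for both plain nn.Linear layers and the quantized ones. This is a sketch under an assumption worth verifying: that meta['shape'] holds the original weight shape in nn.Linear's (out_features, in_features) layout. That is what the snippet above relies on. linear_shape is a hypothetical helper, and FakeQuantLinear is a hypothetical stand-in, not the real HQQ class.

```python
import torch.nn as nn


def linear_shape(module):
    """Return (in_features, out_features) for nn.Linear or an HQQ-style layer.

    Assumption: a quantized layer stores its original weight shape in
    module.meta["shape"] as (out_features, in_features), like nn.Linear.weight.
    """
    if isinstance(module, nn.Linear):
        return module.in_features, module.out_features
    shape = module.meta["shape"]
    return shape[1], shape[0]


class FakeQuantLinear(nn.Module):  # hypothetical stand-in, not the real HQQ class
    def __init__(self, out_features, in_features):
        super().__init__()
        self.meta = {"shape": (out_features, in_features)}


print(linear_shape(nn.Linear(4096, 1024)))        # (4096, 1024)
print(linear_shape(FakeQuantLinear(1024, 4096)))  # (4096, 1024)
```

nn.Linear(4096, 1024).weight has shape (1024, 4096), which is why the first entry of the stored shape is read as out_features.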

@dvmazur
Owner

dvmazur commented Jan 24, 2024

I'm not sure whether (module.meta['shape'][1], module.meta['shape'][0]) is the correct shape. Maybe you should try pulling the correct shape from the original model's config.

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

head_dim = config.hidden_size // config.num_attention_heads
#              (in_features, out_features)
q_proj_shape = (config.hidden_size, config.num_attention_heads * head_dim)
k_proj_shape = (config.hidden_size, config.num_key_value_heads * head_dim)
v_proj_shape = (config.hidden_size, config.num_key_value_heads * head_dim)
o_proj_shape = (config.num_attention_heads * head_dim, config.hidden_size)
```

I haven't checked these shapes, but they should be correct.

If this snippet doesn't work, you could try reconstructing the original shapes from here.
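Plugging Mixtral-8x7B's config values into these formulas gives concrete numbers to compare against module.meta['shape']. The sketch below hard-codes hidden_size = 4096, num_attention_heads = 32 and num_key_value_heads = 8 in a SimpleNamespace, so it runs without downloading anything. The config object returned by AutoConfig exposes the same attribute names, so it could be passed in instead. attention_proj_shapes is a hypothetical helper name.

```python
from types import SimpleNamespace


def attention_proj_shapes(config):
    """Return {name: (in_features, out_features)} for the attention projections."""
    head_dim = config.hidden_size // config.num_attention_heads
    q_out = config.num_attention_heads * head_dim
    kv_out = config.num_key_value_heads * head_dim
    return {
        "q_proj": (config.hidden_size, q_out),
        "k_proj": (config.hidden_size, kv_out),
        "v_proj": (config.hidden_size, kv_out),
        "o_proj": (q_out, config.hidden_size),
    }


# Mixtral-8x7B values, hard-coded so the example needs no download.
mixtral = SimpleNamespace(hidden_size=4096, num_attention_heads=32, num_key_value_heads=8)
for name, shape in attention_proj_shapes(mixtral).items():
    print(name, shape)
# q_proj (4096, 4096)
# k_proj (4096, 1024)
# v_proj (4096, 1024)
# o_proj (4096, 4096)
```

Mixtral uses grouped-query attention (8 key/value heads for 32 query heads), so k_proj and v_proj come out four times narrower than q_proj. A quantized k_proj reporting a stored shape of (1024, 4096) would therefore be consistent with the (out_features, in_features) reading.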

4 participants