How to use this framework to train an LLM with multiple GPUs? #149

Open
zhhvvv opened this issue Jan 18, 2024 · 7 comments

zhhvvv commented Jan 18, 2024

Does the framework support multi-GPU training?
I want to use the framework to train a 70B model, but I did not find any parameter settings or methods for multi-GPU training.

yezhem commented Jan 18, 2024

The parallel functionality is currently under development and testing, and is expected to be available by the end of the month.

yezhem added the good first issue label Jan 18, 2024

zhhvvv commented Jan 18, 2024

Hi, thanks for replying!

Will your team test the framework on the llama2-70B model? I found that the framework supports llama2-7b and llama2-13b, but it does not support llama2-70B.

Maybe this problem is caused by inconsistent output dimensions of the q/k/v projection layers in the 70B model? (The 7B and 13B models have the same output dimension for their q/k/v projections.)

[Screenshot 2024-01-18 16:37:19]

I hope you can show me how to fix the code to address this problem! Thanks.

And here is the traceback of the error:

Traceback (most recent call last):
  File "/mnt/project/longcontext/multi-lora-fine-tune/mlora.py", line 199, in <module>
    mlora.train(mlora.Dispatcher(config, tokenizer), model,
  File "/mnt/project/longcontext/multi-lora-fine-tune/mlora/train.py", line 69, in train
    output, router_outputs = llm_model.forward(input)
  File "/mnt/project/longcontext/multi-lora-fine-tune/mlora/model_llama.py", line 292, in forward
    data = seq_layer.forward(data)
  File "/mnt/project/longcontext/multi-lora-fine-tune/mlora/model_llama.py", line 223, in forward
    output = CheckpointRecomputeFunction.apply(
  File "/home/miniconda3/envs/mlora_env/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/mnt/project/longcontext/multi-lora-fine-tune/mlora/checkpoint.py", line 99, in forward
    outputs = run_function(*args)
  File "/mnt/project/longcontext/multi-lora-fine-tune/mlora/model_llama.py", line 158, in forward
    xv = self.wv_.forward(attention_norm_data, input_args)
  File "/mnt/project/longcontext/multi-lora-fine-tune/mlora/lora_liner.py", line 110, in forward
    result[start_idx: end_idx] += self.loras_[
RuntimeError: The size of tensor a (1024) must match the size of tensor b (8192) at non-singleton dimension 2

I printed the tensors' shapes with the following (in lora_liner.py, Linear.forward()):

print('==============================')
print(result[start_idx: end_idx].shape)
# print(self.loras_.shape)
print(self.loras_[adapter_name].forward(data[start_idx:end_idx]).shape)
print('++++++++++++++++++++++++++++++')

I got


==============================
torch.Size([1, 104, 8192])
torch.Size([1, 104, 8192])
++++++++++++++++++++++++++++++
==============================
torch.Size([1, 104, 1024])
torch.Size([1, 104, 8192])
++++++++++++++++++++++++++++++

It printed twice before the execution errored out.
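
For what it's worth, the 1024 looks consistent with grouped-query attention in llama2-70B: the k/v projections use 8 key/value heads instead of the 64 query heads, so their output width is 1024 while q_proj stays at 8192. A minimal sketch of that arithmetic (the numbers are the published llama2-70B config values, not something read from this repo):

# llama2-70B attention widths under grouped-query attention (GQA)
hidden_size = 8192            # model dimension
num_attention_heads = 64      # query heads
num_key_value_heads = 8       # shared key/value heads
head_dim = hidden_size // num_attention_heads   # 128

q_out = num_attention_heads * head_dim    # 8192 -> q_proj out_features
kv_out = num_key_value_heads * head_dim   # 1024 -> k_proj/v_proj out_features, the "tensor a (1024)" in the error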

yezhem commented Jan 18, 2024

Seems like a wrong LoRA dimension.
Can you check the in_dim and out_dim values in the init_lora_weight function (mlora/lora_liner.py:58)?
The LoRA dimensions are initialized in that function, so maybe you can fix the problem there.
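
Roughly, the LoRA shapes there have to come from each wrapped linear's own in_features/out_features rather than from a single model-wide dimension. A sketch of the idea (illustrative names only, not the repo's actual signature):

import torch

def init_lora_weight(base_linear, r, device):
    # take the dimensions from the wrapped layer itself, so that k_proj/v_proj
    # (out_features=1024 on llama2-70B) get a matching lora_b
    in_dim = base_linear.in_features
    out_dim = base_linear.out_features
    lora_a = torch.empty((r, in_dim), device=device)
    lora_b = torch.zeros((out_dim, r), device=device)   # B starts at zero, as usual for LoRA
    torch.nn.init.kaiming_normal_(lora_a)
    return lora_a, lora_b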

zhhvvv commented Jan 19, 2024

In model_llama.py:301, I printed the layer info of the mlora.model_llama.LlamaModel object:

for idx, layer in enumerate(llama_model.model.layers):
    model.layers_[idx].wq_ = Linear(
        layer.self_attn.q_proj, device=device)
    model.layers_[idx].wk_ = Linear(
        layer.self_attn.k_proj, device=device)
    model.layers_[idx].wv_ = Linear(
        layer.self_attn.v_proj, device=device)
    print('==============================')
    print(idx)
    print(layer.self_attn.k_proj)
    print(model.layers_[idx].wk_)
    print('++++++++++++++++++++++++++++++')
    model.layers_[idx].wo_ = Linear(
        layer.self_attn.o_proj, device=device)
    model.layers_[idx].ffn_ = FeedForward(
        norm=RMSNorm(layer.post_attention_layernorm.weight.to(
            device=device).detach(), model.norm_eps_),
        w1=Linear(layer.mlp.gate_proj, device=device),
        w2=Linear(layer.mlp.down_proj, device=device),
        w3=Linear(layer.mlp.up_proj, device=device),
        device=device
    )
    model.layers_[idx].attention_norm_ = RMSNorm(
        layer.input_layernorm.weight.to(device=device).detach(), model.norm_eps_)
return model

And I got this result:

Linear4bit(in_features=8192, out_features=8192, bias=False)
Linear(
  (weight_): Linear4bit(in_features=8192, out_features=8192, bias=False)

The layer.self_attn.k_proj of each layer is different from what the llama2-70b model's config specifies, which should be Linear(in_features=8192, out_features=1024, bias=False).

I think there are bugs in the LlamaModel class and LLMModelArgs class.
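
For reference, this is the shape check I would expect a correctly loaded llama2-70B to pass (the 8192/1024 values come from the 70B config; the attribute names are the standard transformers ones):

# every k_proj / v_proj on llama2-70B should map 8192 -> 1024
for idx, layer in enumerate(llama_model.model.layers):
    k_proj = layer.self_attn.k_proj
    assert k_proj.in_features == 8192, f"layer {idx}: bad in_features {k_proj.in_features}"
    assert k_proj.out_features == 1024, f"layer {idx}: bad out_features {k_proj.out_features}"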

yezhem commented Jan 19, 2024

The raw linears are created from the transformers.LlamaForCausalLM, and we get the in_features and out_features from each Linear.
Can you check the llama_model's structure at mlora/model_llama.py:332?
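
One way to tell whether the problem is on our side or in the loaded model is to inspect the transformers objects right after from_pretrained, before any mlora wrapping, for example:

# inspect the raw transformers model before it is wrapped by mlora
print(llama_model.config)                                   # llama2-70B should report num_key_value_heads=8
print(getattr(llama_model.config, "num_key_value_heads", None))
print(llama_model.model.layers[0].self_attn.k_proj)         # expect in_features=8192, out_features=1024

If the config does not even have a num_key_value_heads field, the installed transformers version predates grouped-query attention support.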

zhhvvv commented Jan 19, 2024

I solved the problem by upgrading transformers.

I found that transformers==4.30.2 cannot load a llama2-70b model:

size mismatch for model.layers.79.self_attn.k_proj.weight: copying a param with shape torch.Size([1024, 8192]) from checkpoint, the shape in current model is torch.Size([8192, 8192]).
size mismatch for model.layers.79.self_attn.v_proj.weight: copying a param with shape torch.Size([1024, 8192]) from checkpoint, the shape in current model is torch.Size([8192, 8192]).
	You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

I recommend upgrading the transformers dependency in a subsequent release.
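
For anyone hitting the same thing, this is roughly the check I ran after upgrading (I believe grouped-query attention support landed in transformers 4.31, so that is my assumed floor; the checkpoint path below is just a placeholder):

import transformers
from transformers import LlamaForCausalLM

print(transformers.__version__)   # should be >= 4.31.0 for llama2-70B (GQA)

model = LlamaForCausalLM.from_pretrained(
    "/path/to/llama2-70b",        # placeholder path to the local checkpoint
    torch_dtype="auto",
)
print(model.config.num_key_value_heads)        # 8 on llama2-70B
print(model.model.layers[0].self_attn.k_proj)  # Linear(in_features=8192, out_features=1024, bias=False)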

yezhem added the bug label and removed the good first issue label Jan 19, 2024

yezhem commented Jan 19, 2024

Thanks, we will evaluate this.
