
[Feature Suggest] Tensor Parallelism for Accelerating LLM #29

Open
zhengpeirong opened this issue Apr 26, 2024 · 22 comments

zhengpeirong commented Apr 26, 2024

Dear Author,

Your contribution is critical for the open-source community. The distributed-llama repo implements tensor parallelism from scratch, and the results are remarkable. However, there are still improvements that could be made. Since my coding ability is limited and I cannot make these improvements myself, I hope you can look at my suggestions below.

Challenge: the root node's special tasks and synchronization

When I run version 0.1.0 of the repo, I find that the softmax operations in MultiHead are performed on the root node only, and this step accounts for a significant portion of the total time. Second, the syncFfnA and syncFfn2 functions also cost a lot of time.

Mature solutions

In fact, these challenges have already been addressed in this paper: https://arxiv.org/abs/1909.08053 (Megatron-LM). Its solution is shown in the image:

image

First, it performs the attention mechanism (including softmax) on every worker. Second, two consecutive matrices are split along the column dimension and the row dimension respectively, reducing the two synchronization operations to one.

If you are willing to make further improvements to the repo, the following tutorial presents a mature solution for every component of Llama 2 using tensor parallelism and sequence parallelism:
https://pytorch.org/tutorials/intermediate/TP_tutorial.html
However, it is implemented in Python, so you would be the first to implement the solution in C++.
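
For illustration, here is a rough C++ sketch of the first point (the names and memory layout are my assumptions, not the repo's actual API): every worker computes softmax attention only over the heads assigned to it, so the root node no longer has to run softmax for all heads.

#include <cmath>
#include <vector>

// Sketch only: attention for the heads owned by one worker. Everything this
// worker needs (its q rows and its slice of the kv-cache) is local data.
void attentionForLocalHeads(
    const float* q,       // [nLocalHeads * headDim] queries of the local heads
    const float* kCache,  // [seqLen * nLocalHeads * headDim] local key cache
    const float* vCache,  // [seqLen * nLocalHeads * headDim] local value cache
    float* out,           // [nLocalHeads * headDim] attention output of local heads
    int nLocalHeads, int headDim, int seqLen)
{
    std::vector<float> att(seqLen);
    for (int h = 0; h < nLocalHeads; h++) {
        const float* qh = q + h * headDim;
        float* oh = out + h * headDim;

        // scores = q·k / sqrt(headDim), then a numerically stable softmax
        float maxScore = -INFINITY;
        for (int t = 0; t < seqLen; t++) {
            const float* kh = kCache + (t * nLocalHeads + h) * headDim;
            float s = 0.0f;
            for (int i = 0; i < headDim; i++) s += qh[i] * kh[i];
            att[t] = s / sqrtf((float)headDim);
            if (att[t] > maxScore) maxScore = att[t];
        }
        float sum = 0.0f;
        for (int t = 0; t < seqLen; t++) { att[t] = expf(att[t] - maxScore); sum += att[t]; }

        // weighted sum of values for this head
        for (int i = 0; i < headDim; i++) oh[i] = 0.0f;
        for (int t = 0; t < seqLen; t++) {
            const float* vh = vCache + (t * nLocalHeads + h) * headDim;
            float w = att[t] / sum;
            for (int i = 0; i < headDim; i++) oh[i] += w * vh[i];
        }
    }
}

In the Megatron scheme, only the partial results of the output projection (W_O) then need a single synchronization.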

Thanks for your contribution!!!
Best Regards

zhengpeirong commented Apr 29, 2024

image
As a supplement, the figure shows the detailed time cost of each task when 4 Raspberry Pis run Llama2-7B-Q40. As you can see, the aforementioned functions account for a large share of the time. If you parallelize these computations according to the 'Mature solutions' above, the time will decrease nearly linearly as the number of devices increases. @b4rtaz

b4rtaz commented Apr 29, 2024

Nice measurements! It seems multiheadAtt is super slow.

@zhengpeirong please check the 0.3.1 version. Now all tasks are executed in parallel so it should be a bit better.

zhengpeirong commented:

@b4rtaz The 'qkv' change has been reverted. Do you plan to address this issue? Not only does 'MulHead' cost time, but 'Finalize' also takes a big portion of the time.

b4rtaz commented Apr 30, 2024

@zhengpeirong yes, I know. The qkv layer seems quite well optimized if you compare it with the rest of the layers. Still, qkv may be improved in the way you suggested in the first post; I haven't had time to read it yet.

The problem with the finalize layer is that its output is large (vocabSize), so I don't think it's a good idea to synchronize it. But maybe it could be optimized so that each worker runs the sampler on its own slice of the output, and the root node then merges the results somehow. Different samplers would require different merging logic, but it looks doable (for example sample_argmax looks super easy).
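
For example, a rough sketch of the argmax case (hypothetical names, not the current code): each worker sends only the best (token, logit) pair from its slice, and the root compares one pair per node instead of receiving the whole vocab-sized output.

#include <vector>

struct SliceBest {
    unsigned int token; // global token id of the best logit in the slice
    float logit;        // value of that logit
};

// Worker side: argmax over the local logits slice [sliceStart, sliceStart + sliceLen).
SliceBest sliceArgmax(const float* logitsSlice, unsigned int sliceStart, unsigned int sliceLen) {
    SliceBest best = { sliceStart, logitsSlice[0] };
    for (unsigned int i = 1; i < sliceLen; i++) {
        if (logitsSlice[i] > best.logit) {
            best.logit = logitsSlice[i];
            best.token = sliceStart + i;
        }
    }
    return best;
}

// Root side: merge the per-node results into the final sampled token.
unsigned int mergeArgmax(const std::vector<SliceBest>& results) {
    SliceBest best = results[0];
    for (size_t i = 1; i < results.size(); i++)
        if (results[i].logit > best.logit) best = results[i];
    return best.token;
}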

Yes, I want to keep working on this project. More hands are welcome. :-)

zhengpeirong commented:

@b4rtaz Thanks for your persistence and effort.

  1. The qkv layer can be optimized; all you need to read is Section 3, "Model Parallel Transformers", of the paper "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism".
  2. The finalize layer can be optimized following your design.
  3. The number of transfers for the FFN layer can be reduced from 2 to 1 by using the method from the Megatron-LM paper.

If you combine all these mechanisms, the non-parallel functions will be optimized! Here is a draft workflow:

TransformerArch buildLlama2Arch(TransformerSpec* spec) {
    TransformerArch a;

    // Inference

    a.I(sendPoke, TASK_TYPE_TRANSFER);
    for (int i = 0; i < spec->nLayers; i++) {
        a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE); // Combine the existing llamaRmsAtt and llamaRmsAttNorm
        a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE); // Quantization
        a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); // Sending
        a.I(llamaQkv, TASK_TYPE_INFERENCE); // Compute Q K V
        a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE); // Merge kv-cache, add RoPE encoding, compute a part of multi-head attention locally
        a.I(llamaAttOutput, TASK_TYPE_INFERENCE); // Worker computes W_O matrix
        a.I(llamaQuantizeAtt, TASK_TYPE_INFERENCE);
        a.I(llamaSyncAtt, TASK_TYPE_TRANSFER); // First communication time-consuming
        a.I(llamaDequantizeAtt, TASK_TYPE_INFERENCE);
        a.I(llamaMergeAtt, TASK_TYPE_INFERENCE); // Merge all attention matrices
        a.I(llamaRmfFfn, TASK_TYPE_INFERENCE);
        a.I(llamaRmfFfnNorm, TASK_TYPE_INFERENCE);
        a.I(llamaQuantizeRmfFfn, TASK_TYPE_INFERENCE);
        a.I(llamaSyncRmfFfn, TASK_TYPE_TRANSFER);
        a.I(llamaFfn, TASK_TYPE_INFERENCE); // Compute SwiGLU activation
        a.I(llamaFfn2, TASK_TYPE_INFERENCE); // Compute the second FFN
        a.I(llamaQuantizeFfn2, TASK_TYPE_INFERENCE);
        a.I(llamaSyncFfn2, TASK_TYPE_TRANSFER); // Second communication time-consuming
        a.I(llamaDequantizeFfn2, TASK_TYPE_INFERENCE);
        a.I(llamaMergeFfn2, TASK_TYPE_INFERENCE);
        a.I(llamaNextBlock, TASK_TYPE_INFERENCE);
    }
    a.I(llamaRmsFinal, TASK_TYPE_INFERENCE);
    a.I(llamaRmsFinalNorm, TASK_TYPE_INFERENCE);
    a.I(llamaLogits, TASK_TYPE_INFERENCE);
    a.I(llamaQuantizeLogits, TASK_TYPE_INFERENCE);
    a.I(llamaSyncLogits, TASK_TYPE_TRANSFER);
    a.I(llamaDequantizeLogits, TASK_TYPE_INFERENCE);
    a.I(llamaMergeLogits, TASK_TYPE_INFERENCE);

    // Worker

    for (int i = 0; i < spec->nLayers; i++) {
        a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
        a.W(llamaQkv, TASK_TYPE_INFERENCE); // Compute Q K V
        a.W(llamaMultiheadAtt, TASK_TYPE_INFERENCE); // Merge kv-cache, add RoPE encoding, compute a part of multi-head attention locally
        a.W(llamaAttOutput, TASK_TYPE_INFERENCE); // Worker computes W_O matrix
        a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE);
        a.W(llamaSyncAtt, TASK_TYPE_TRANSFER);
        a.W(llamaSyncRmfFfn, TASK_TYPE_TRANSFER);
        a.W(llamaFfn, TASK_TYPE_INFERENCE);
        a.W(llamaFfn2, TASK_TYPE_INFERENCE);
        a.W(llamaQuantizeFfn2, TASK_TYPE_INFERENCE);
        a.W(llamaSyncFfn2, TASK_TYPE_TRANSFER);
        a.W(llamaNextBlock, TASK_TYPE_INFERENCE);
    }
    a.W(llamaLogits, TASK_TYPE_INFERENCE);
    a.W(llamaQuantizeLogits, TASK_TYPE_INFERENCE);
    a.W(llamaSyncLogits, TASK_TYPE_TRANSFER);

    return a;
}

I hope this repo can catch up with the state-of-the-art algorithm as soon as possible~~

zhengpeirong commented:

The optimized result would take only 72% of the original generation time, i.e. a 1.39x speedup over this version.
I have roughly estimated the optimized result: the main transfer happens only twice per block, and the root node's workload is divided among the 4 workers.
image

b4rtaz commented May 7, 2024

@zhengpeirong this is just a guess; have you verified it with any implementation?

Currently I've noticed a problem with the RoPE layer: it's not easy to split, because to calculate the output of this layer we need:

Output digits <0; kvDim) = q & k
Output digits <kvDim; dim) = q

So the current implementation divides the q and k outputs into equal parts (<0; s), <s; 2s), ...). This won't work for the RoPE, because the first node would require a bit of the k output from the second node, etc.

I see some possibility to solve it, but it is much more complex than I thought. I should probably split the Q & K layers into many small columns (width=2) and assign columns to nodes.

worker 1: output digits 1, 2, 5, 6, ... (n, n + 1)
worker 2: output digits 3, 4, 7, 8, ... (n + 2, n + 3)

The paper you linked doesn't cover RoPE at all, so we probably have a different case here.

zhengpeirong commented May 7, 2024

@zhengpeirong this is just a guess; have you verified it with any implementation?

Currently I've noticed a problem with the RoPE layer: it's not easy to split, because to calculate the output of this layer we need:

Output digits <0; kvDim) = q & k
Output digits <kvDim; dim) = q

So the current implementation divides the q and k outputs into equal parts (<0; s), <s; 2s), ...). This won't work for the RoPE, because the first node would require a bit of the k output from the second node, etc.

I see some possibility to solve it, but it is much more complex than I thought. I should probably split the Q & K layers into many small columns (width=2) and assign columns to nodes.

worker 1: output digits 1, 2, 5, 6, ... (n, n + 1)
worker 2: output digits 3, 4, 7, 8, ... (n + 2, n + 3)

The paper you linked doesn't cover RoPE at all, so we probably have a different case here.

The issue you are currently facing lies in separately calculating the QKV matrices, which are split according to the hidden dimension. Therefore, splitting RoPE cannot be easily implemented.

However, if tensor parallelism is supported, the splitting is performed along the num_head dimension, dividing the attention heads across different devices. This is independent of the hidden dimension where RoPE resides, thus avoiding the problem you encountered.

In summary, the RoPE computation and the multi-head attention computation are orthogonal, operating on different dimensions: the former on the hidden dimension and the latter on the num_head dimension. The RoPE part can be easily completed separately on each device.
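
For illustration, here is a minimal sketch of applying RoPE independently to whichever heads a node owns (assuming llama2.c-style rotation of adjacent element pairs; the names are hypothetical, not the repo's API):

#include <cmath>

// Sketch only: apply RoPE to the heads owned by one node. Each head is
// rotated using only its own elements, so splitting heads across nodes
// does not require any cross-node data for this step.
void ropeLocalHeads(float* q, int nLocalHeads, int headDim, int pos, float theta = 10000.0f) {
    for (int h = 0; h < nLocalHeads; h++) {
        float* head = q + h * headDim;
        for (int i = 0; i < headDim; i += 2) {
            float freq = 1.0f / powf(theta, (float)i / (float)headDim);
            float angle = pos * freq;
            float c = cosf(angle), s = sinf(angle);
            float x0 = head[i], x1 = head[i + 1];
            head[i]     = x0 * c - x1 * s;
            head[i + 1] = x0 * s + x1 * c;
        }
    }
}

Because each head is rotated using only its own elements, a node never needs q or k values from another node for this step.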

b4rtaz commented May 9, 2024

I needed a bit of time to notice my thinking error. After all, the RoPE layer is now split across the root node and the workers. 🎉 I tested it with 1, 2, and 4 nodes, and the macbeth test generates the same output on the different topologies *.

* The macbeth test doesn't work with buffer quantization (it generates a different output), because the RoPE is now applied before the transfer quantization. Previously, it was applied after the transfer dequantization. I expect this affects the perplexity somehow. This will probably be resolved once the llamaMultiheadAttJoin function is also split out.

Now all nodes keep a RoPE cache, and the cache size differs between nodes. This could be optimized a bit, but "so far so good".

root node:
🕒 ropeCache: 8192 kB

1 worker:
🕒 ropeCache: 28672 kB

2 worker:
🕒 ropeCache: 20480 kB

3 worker:
🕒 ropeCache: 26624 kB

Next, I'll try to split out the llamaMultiheadAttJoin function.

b4rtaz commented May 11, 2024

Finally, I split the multihead layer out across all nodes (still not merged; I need to fix the Mixtral & Grok architectures). First measurements:

Model: Llama 3 8B Q40
Buffer: Q80
Setup: 4 x Raspberry Pi 5 8GB + TP-Link LS1008G Switch

Transfer size / token

| Devices | 0.3.0 | This PR | Percentage change |
|---|---|---|---|
| 2 x Raspberry Pi 5 | S 646 kB + R 476 kB = 1122 kB | S 578 kB + R 442 kB = 1020 kB | -9.09% |
| 4 x Raspberry Pi 5 | S 2295 kB + R 714 kB = 3009 kB | S 2193 kB + R 663 kB = 2856 kB | -5.08% |

Avg tokens / second

| Devices | Metric | 0.3.0 | This PR | Percentage change |
|---|---|---|---|---|
| 2 x Raspberry Pi 5 | Avg generation time | 444.27 ms | 381.81 ms | |
| 2 x Raspberry Pi 5 | Avg inference time | 362.73 ms | 349.94 ms | -3.53% |
| 2 x Raspberry Pi 5 | Avg transfer time | 80.11 ms | 30.31 ms* | |
| 4 x Raspberry Pi 5 | Avg generation time | 331.47 ms | 359.44 ms | |
| 4 x Raspberry Pi 5 | Avg inference time | 267.62 ms | 258.00 ms | -3.59% |
| 4 x Raspberry Pi 5 | Avg transfer time | 62.34 ms | 99.69 ms | |

* I think the switch I used is completely non-deterministic; it achieves a random speed at different times. So I recommend comparing only the avg inference time.

It looks like that gave a tiny speed up (maybe 3%). I expected a bit more. 🤔

b4rtaz commented May 11, 2024

Update: I changed the implementation a bit; now there is no synchronization between llamaQuantizeMultiheadAtt and llamaAtt. So basically we now have state-of-the-art parallelism of the attention layers. 🎉

Transfer size / token

| Devices | 0.3.0 | PR v2 | Percentage change |
|---|---|---|---|
| 2 devices | S 646 kB + R 476 kB = 1122 kB | S 510 kB + R 442 kB = 952 kB | -15.15% |
| 4 devices | S 2295 kB + R 714 kB = 3009 kB | S 1887 kB + R 867 kB = 2754 kB | -8.47% |
| 8 devices | S 5771 kB + R 833 kB = 6604 kB | S 4819 kB + R 1487 kB = 6306 kB | -4.51% |

The final state of the attention synchronization looks like this for a single block:

root --- xb  ---> node
root <-- xbv ---- node
merge att

The previous implementation:

root --- xb  --> node
root <-- q  ---- node
root <-- k  ---- node
root <-- v  ---- node
root --- xb ---> node
root <-- xb2 --- node
merge att

zhengpeirong commented:

@b4rtaz 🎉 You have completed state-of-the-art tensor parallelism for the attention layer!!!
Moreover, continuing our earlier discussion, there are still two optimizations that can be done:

  1. Computation:
    The last layer (Finalize) occupies 11% of the total time. It can be decomposed into parallel computation + synchronization (merge).
    Then 11% can be reduced to 2.75% + synchronization.

  2. Communication:
    Currently, there are 3 main synchronization functions in one transformer block: the attention layer takes 1 and the FFN layer takes 2. You are using 2 all-gather operations (syncFfnA and syncFfn2) in the FFN. They can be replaced by a single all-reduce operation, syncFfn (see the sketch at the end of this comment). The slicing approach is explained in detail in the PyTorch tutorial:

        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(),
        "feed_forward.w3": ColwiseParallel(),

    Then 4.76% of the time can be saved.

In summary, at most 12% acceleration can be gained over the current version. As the number of workers increases (4 workers in this issue), this acceleration would benefit from even more parallelism.
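
For the FFN part, here is a rough C++ sketch of that slicing (hypothetical names, not the repo's actual API): w1 and w3 are column-sliced, w2 is row-sliced, each node produces a partial output of length dim, and a single all-reduce (sum) finishes the layer instead of two all-gathers.

#include <cmath>
#include <vector>

// Sketch only: one node's share of a SwiGLU FFN. w1/w3 are column-sliced,
// w2 is row-sliced, so the node produces partial sums over the full dim and
// a single all-reduce (sum) completes the layer.
void ffnSliceForward(
    const float* x,        // [dim] full input vector, replicated on every node
    const float* w1Slice,  // [hiddenSlice x dim] column slice of w1 (gate)
    const float* w3Slice,  // [hiddenSlice x dim] column slice of w3 (up)
    const float* w2Slice,  // [dim x hiddenSlice] row slice of w2 (down)
    float* outPartial,     // [dim] this node's partial output (to be all-reduced)
    int dim, int hiddenSlice)
{
    std::vector<float> h(hiddenSlice);
    for (int i = 0; i < hiddenSlice; i++) {
        float g = 0.0f, u = 0.0f;
        for (int j = 0; j < dim; j++) {
            g += w1Slice[i * dim + j] * x[j];
            u += w3Slice[i * dim + j] * x[j];
        }
        g = g / (1.0f + expf(-g)); // SiLU
        h[i] = g * u;              // SwiGLU gating
    }
    for (int j = 0; j < dim; j++) {
        float acc = 0.0f;
        for (int i = 0; i < hiddenSlice; i++)
            acc += w2Slice[j * hiddenSlice + i] * h[i];
        outPartial[j] = acc;
    }
}

Summing outPartial across all nodes gives exactly the same result as the unsplit FFN, and only that one sum has to cross the network.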

zhengpeirong commented May 20, 2024

https://github.com/huggingface/transformers/blob/bb48e921868ac750417956de941606f7e2fa02ca/src/transformers/models/llama/modeling_llama.py#L199-L219

@b4rtaz Just for your reference, this code implements the FFN layer of Llama with tensor parallel acceleration.
In summary, the only two dimensions tensor parallelism splits are the head dimension for the attention layer and the intermediate hidden dimension for the MLP layer.

b4rtaz commented May 25, 2024

@zhengpeirong it seems that after I adjusted the MLP layers according to your suggestion, the transfer has dropped by ~40% per token. 🤯

| Devices | 0.5.0 | 0.7.1 | Percentage change |
|---|---|---|---|
| 2 devices | S 510 kB + R 442 kB = 952 kB | S 272 kB + R 272 kB = 544 kB | -42.8% |
| 4 devices | S 1887 kB + R 867 kB = 2754 kB | S 816 kB + R 816 kB = 1632 kB | -40.7% |

PR

Later I'll check the impact on the generation time.

b4rtaz commented May 25, 2024

Where do you see the generation time data? 🤔

zhengpeirong commented May 25, 2024

@zhengpeirong it seems that after I adjusted the MLP layers according to your suggestion, the transfer has dropped by ~40% per token. 🤯

| Devices | 0.5.0 | PR | Percentage change |
|---|---|---|---|
| 2 devices | S 510 kB + R 442 kB = 952 kB | S 272 kB + R 272 kB = 544 kB | -42.8% |
| 4 devices | S 1887 kB + R 867 kB = 2754 kB | S 816 kB + R 816 kB = 1632 kB | -40.7% |

PR

Later I'll check the impact on the generation time.

'272 kB' is consistent with the theoretical analysis:

272 / 32 / 4 ≈ 2.125

Excluding the transfer data for the embedding layers, we can treat this as 2, which means there are 2 all-reduce transfers in a single transformer block.

And S = R is exactly what an all-reduce will show.
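
For reference, a rough back-of-the-envelope check (assuming the test used Llama 3 8B with dim = 4096 and Q80 buffers, where a Q80 block stores 32 int8 values plus a 2-byte scale, i.e. 34 bytes per 32 values):

    4096 / 32 * 34 bytes ≈ 4.25 kB per synced activation vector
    2 all-reduces per block * 32 blocks * 4.25 kB ≈ 272 kB per token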

Congratulations on finishing this feature suggestion!

b4rtaz commented May 25, 2024

Llama 2 7B Q40

nTokens = 90, buffer = Q80

4 x Raspberry Pi 5 8GB

| Version | Avg tokens / second | Avg generation time | Avg inference time | Avg transfer time |
|---|---|---|---|---|
| 0.7.1 | 4.08 | 245.08 ms | 169.33 ms | 75.34 ms |
| 0.7.0 | 3.90 | 256.23 ms | 168.77 ms | 87.12 ms |
| 0.6.0 | 4.24 | 235.69 ms | 143.44 ms | 91.77 ms |

2 x Raspberry Pi 5 8GB

| Version | Avg tokens / second | Avg generation time | Avg inference time | Avg transfer time |
|---|---|---|---|---|
| 0.7.1 | 3.07 | 325.46 ms | 269.04 ms | 56.39 ms |
| 0.7.0 | 2.91 | 343.44 ms | 266.51 ms | 76.87 ms |
| 0.6.0 | 3.06 | 327.17 ms | 249.80 ms | 77.28 ms |

TinyLlama 1.3B 3T Q40

nTokens = 128, buffer = Q80

2 x Raspberry Pi 5 8GB

| Version | Avg tokens / second | Avg generation time | Avg inference time | Avg transfer time |
|---|---|---|---|---|
| 0.7.1 | 16.86 | 59.31 ms | 50.37 ms | 8.58 ms |
| 0.7.0 | 15.17 | 65.93 ms | 52.07 ms | 13.45 ms |

Llama 3 8B Q40

nTokens = 90, buffer = Q80

2 x AMD EPYC 7402P 24-Core Processor

| Version | Avg tokens / second | Avg generation time | Avg inference time | Avg transfer time |
|---|---|---|---|---|
| 0.7.1 | 13.04 | 76.67 ms | 45.33 ms | 30.93 ms |
| 0.7.0 | 12.79 | 78.21 ms | 46.30 ms | 31.49 ms |
| 0.6.0 | 12.55 | 79.71 ms | 47.08 ms | 32.22 ms |

b4rtaz commented May 25, 2024

In all cases the average transfer time has dropped. What is interesting is that the non-blocking sockets reduce the speed on Raspberry Pi, but not on a strong machine. Maybe this mode should be optional.

zhengpeirong commented:

In all cases the average transfer time has dropped. What is interesting is that the non-blocking sockets reduce the speed on Raspberry Pi, but not on a strong machine. Maybe this mode should be optional.

Do you mean blocking sockets reduce the speed?

Could you try 8 x Raspberry Pi? Since there are obvious transfer delays with 8 devices, I am curious whether they are caused by network traffic congestion.

BTW, I think it's time to update the README.md with the newest generation times for people who are new here.

b4rtaz commented May 25, 2024

Do you mean blocking sockets reduce the speed?

No, the non-blocking sockets, I think. Since 0.6.1, Distributed Llama has enabled non-blocking sockets for root <> node communication.

Could you try 8 x Raspberry Pi?

Unfortunately I don't have 8 devices anymore. I have only 4 x Raspberry Pi 5 8GB.

BTW, I think it's time to update the README.md with the newest generation time for people new here.

You're right. I'll do it soon.

zhengpeirong commented:

Do you mean blocking sockets reduce the speed?

No, the non-blocking sockets, I think. Since 0.6.1, Distributed Llama has enabled non-blocking sockets for root <> node communication.

Non-blocking sockets let the CPU do other work instead of waiting. But what is the logical connection between non-blocking sockets and increased inference time?

Could you try 8 x Raspberry Pi?

Unfortunately I don't have 8 devices anymore. I have only 4 x Raspberry Pi 5 8GB.

In this discussion, you are invited to conduct experiments with more devices, find the best number of devices, and then present it in the README.

If dllama could support any even number of devices, more scenarios could be covered, since there is a big gap between 8, 16, and 32 devices.

b4rtaz commented May 27, 2024

Non-blocking sockets let the CPU do other work instead of waiting. But what is the logical connection between non-blocking sockets and increased inference time?

I think this problem appears only on slow devices like the Raspberry Pi. I cannot explain it, but you can see the drop in speed from 0.6.0 to 0.7.0, and that was only a minor change between these versions.

Maybe we need more tests.
