
ggml : add RPC backend #761

Closed · wants to merge 9 commits

Conversation

rgerganov (Collaborator)

We have a use case where we want to build and run ggml programs on low-end machines (without GPUs) and leverage the computational resources of some high-end machines (with GPUs) over the network. In this PR I am trying to prototype an RPC backend which proxies all operations to another host. On the remote host, the RPC backend simply delegates to one of the existing backends (CUDA, Metal, etc.):

flowchart TB
    rpcc---|gRPC|rpcs
    subgraph hosta[Host B]
    rpcs[RPC Server]---backend["Backend (CUDA,Metal,etc.)"]
    end
    subgraph hostb[Host A]
    ggml[ggml program]---rpcc[RPC Client]
    end

I am using gRPC for the remote calls; you can find the interface definition in ggml-rpc.proto. I have a simple program (client.cpp) which creates some tensors and successfully stores data into them and reads it back using the RPC backend.
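For illustration, the client boils down to something like the following rough sketch (the exact API in this prototype may differ; ggml_backend_rpc_init and the "ggml-rpc.h" header are placeholders here):

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-rpc.h"   // placeholder header exposing the RPC backend

#include <cstdio>
#include <vector>

int main() {
    // connect to the RPC server; every backend call below is proxied over gRPC
    ggml_backend_t backend = ggml_backend_rpc_init("localhost:50051");

    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 16,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,   // tensor data lives in the remote backend buffer
    };
    struct ggml_context * ctx = ggml_init(params);
    struct ggml_tensor  * t   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);

    // allocate a buffer on the remote host and place the tensor in it
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    std::vector<float> data(8, 1.0f);
    ggml_backend_tensor_set(t, data.data(), 0, ggml_nbytes(t));   // round-trip to the server
    ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t));
    printf("t[0] = %f\n", data[0]);

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}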

You can give it a try with the following steps:

  1. Install gRPC by following this guide.
  2. Build ggml with:

     cmake -DGGML_RPC=ON -DCMAKE_PREFIX_PATH=$MY_INSTALL_DIR -DGGML_CUBLAS=ON -DCMAKE_BUILD_TYPE=Debug ..
     make

     With this configuration the RPC backend will delegate to the CUDA backend.
  3. Start the RPC server:

     bin/server 50051

  4. Run the sample client:

     bin/client localhost:50051

I am currently looking for some guidance on how to implement graph_compute with this approach. Any help is appreciated.

@slaren (Collaborator) commented Mar 7, 2024

Very cool! To implement graph_compute, you would need to re-create the same graph in the server. Most of the tensor attributes can be copied as they are from the client, but you would need to map the buffer pointers to the local copy of the buffer, as well as the tensor pointers. Roughly, something like this should work:

tensor map_tensor(t):
    if t is NULL:
        return NULL
    if t in tensor_map:
        return tensor_map[t]
    tensor new_t = t
    new_t->view_src = map_tensor(t->view_src)
    new_t->buffer = map_buffer(t->buffer)
    for i in range(GGML_MAX_SRC):
        new_t->src[i] = map_tensor(t->src[i])
    tensor_map[t] = new_t
    return new_t

for i in range(remote_graph->n_nodes):
    local_graph->nodes[i] = map_tensor(remote_graph->nodes[i])

For the CUDA backend, it is important to also wrap the buffer init_tensor function, and update the local tensor with the values changed by the backend (should only be the extra and backend attributes).
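
A rough C++ rendering of the same idea, where tensor_map, map_buffer and the server-side ggml_context are assumed state rather than code from this PR:

#include "ggml.h"
#include "ggml-backend.h"
#include <unordered_map>

// assumed helper: maps the client's buffer handle to the local buffer
ggml_backend_buffer_t map_buffer(ggml_backend_buffer_t remote_buffer);

static std::unordered_map<ggml_tensor *, ggml_tensor *> tensor_map;

static ggml_tensor * map_tensor(struct ggml_context * ctx, ggml_tensor * t) {
    if (t == NULL) {
        return NULL;
    }
    auto it = tensor_map.find(t);
    if (it != tensor_map.end()) {
        return it->second;
    }
    ggml_tensor * new_t = ggml_dup_tensor(ctx, t);   // allocate a tensor header locally
    *new_t = *t;                                     // copy all plain attributes verbatim
    new_t->view_src = map_tensor(ctx, t->view_src);  // remap tensor pointers
    new_t->buffer   = map_buffer(t->buffer);         // remap the buffer pointer
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        new_t->src[i] = map_tensor(ctx, t->src[i]);
    }
    tensor_map[t] = new_t;
    return new_t;
}

// in the server's graph_compute:
//   for (int i = 0; i < remote_graph->n_nodes; i++) {
//       local_graph->nodes[i] = map_tensor(ctx, remote_graph->nodes[i]);
//   }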

@rgerganov (Collaborator, Author)

Thanks for the hints. I have added a Tensor protobuf structure, and the client-side implementation of graph_compute serializes the graph into an array of Tensor structures. I am using the local address of the tensor as a unique identifier, so I can reconstruct the same relations on the server. As for the tensor buffer, I am keeping the remote address of the buffer in the local buffer context, so I take it from there. The data pointer is already a remote address, so I copy it as-is.
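
For illustration, the serialization amounts to roughly this; the struct and field names below are placeholders, not necessarily what ggml-rpc.proto uses:

#include "ggml.h"
#include "ggml-backend-impl.h"   // for the ggml_backend_buffer internals
#include <cstdint>

// hypothetical wire representation; the real one is a protobuf message
struct rpc_tensor {
    uint64_t id;                   // local address of the tensor on the client
    uint64_t buffer;               // remote (server-side) address of the buffer
    uint64_t data;                 // already a remote address, copied as-is
    uint64_t view_src;             // id of the view source, 0 if none
    uint64_t src[GGML_MAX_SRC];    // ids of the source tensors, 0 if none
    // type, ne, name, ... omitted here
};

// hypothetical client-side buffer context holding the remote buffer address
struct ggml_backend_rpc_buffer_context {
    uint64_t remote_ptr;
};

static void serialize_tensor(const ggml_tensor * t, rpc_tensor * out) {
    out->id       = reinterpret_cast<uint64_t>(t);            // local address as unique id
    out->data     = reinterpret_cast<uint64_t>(t->data);
    out->view_src = reinterpret_cast<uint64_t>(t->view_src);  // resolved back into a tensor on the server
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        out->src[i] = reinterpret_cast<uint64_t>(t->src[i]);
    }
    out->buffer = 0;
    if (t->buffer != NULL) {
        auto * buf_ctx = (ggml_backend_rpc_buffer_context *) t->buffer->context;
        out->buffer = buf_ctx->remote_ptr;                    // remote buffer address kept in the local context
    }
}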

The purpose of the RPC backend is to proxy all operations to another
host where they are implemented with one of the existing backends (e.g.
CUDA, Metal, etc.).
@rgerganov (Collaborator, Author)

@slaren Thanks for the review, I have addressed your comments.

The simple example which multiplies two tensors is working. I am now trying to make gpt-2 work with this backend, but it is currently producing garbage:

main: seed = 1234
gpt2_model_load: loading model from '../models-mnt/gpt-2/ggml-model-gpt-2-117M.bin'
gpt2_model_load: n_vocab = 50257
gpt2_model_load: n_ctx   = 1024
gpt2_model_load: n_embd  = 768
gpt2_model_load: n_head  = 12
gpt2_model_load: n_layer = 12
gpt2_model_load: ftype   = 1
gpt2_model_load: qntvr   = 0
gpt2_model_load: using RPC backend
gpt2_model_load: ggml tensor size    = 368 bytes
gpt2_model_load: backend buffer size = 312.70 MB
gpt2_model_load: memory size =   144.00 MB, n_mem = 24576
gpt2_model_load: model size  =   239.08 MB
extract_tests_from_file : No test file found.
test_gpt_tokenizer : 0 tests failed out of 0 tests.
main: compute buffer size: 6.87 MB
main: prompt: 'I believe the meaning of life is'
main: number of tokens in prompt = 7, first 8 tokens: 40 1975 262 3616 286 1204 318 

I believe the meaning of life is '... for that to the. ©.

 is a:. The the's- in. ... has 456.AE = by of in to in- was

I suspect there is something wrong with the way I reconstruct the compute graph but I am still debugging ...

@slaren (Collaborator) commented Mar 11, 2024

What backend are you wrapping in the server? The CPU backend should be the simplest to make work.

@rgerganov (Collaborator, Author)

What backend are you wrapping in the server? The CPU backend should be the simplest to make work.

I am wrapping the CPU backend on the server.

rgerganov marked this pull request as ready for review on March 11, 2024.
@slaren (Collaborator) commented Mar 11, 2024

The issue may be that tensor->nb is not serialized. For non-contiguous views this is important.

src/ggml-rpc.cpp (outdated review excerpt):
}
result->flags = protobuf_tensor.flags();
result->data = reinterpret_cast<void *>(protobuf_tensor.data());
strncpy(result->name, protobuf_tensor.name().c_str(), GGML_MAX_NAME);

Review comment (Collaborator):

In practice this will work as long as the server and the client have the same GGML_MAX_NAME, but it is not safe to use strncpy in this way because it does not guarantee that the string will be NUL-terminated. snprintf should be safe.
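
For example:

// strncpy may leave result->name without a terminating NUL when the source
// string is >= GGML_MAX_NAME characters long:
//     strncpy(result->name, protobuf_tensor.name().c_str(), GGML_MAX_NAME);
// snprintf always writes at most GGML_MAX_NAME - 1 characters plus the NUL:
snprintf(result->name, GGML_MAX_NAME, "%s", protobuf_tensor.name().c_str());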

@slaren (Collaborator) commented Mar 11, 2024

I get this error when building (followed by a million more):

In file included from /home/diego/usr/include/google/protobuf/stubs/common.h:20,
                 from /home/diego/usr/include/google/protobuf/io/coded_stream.h:107,
                 from /home/diego/code/ggml/build/src/ggml-rpc.pb.h:26,
                 from /home/diego/code/ggml/build/src/ggml-rpc.pb.cc:4:
/home/diego/usr/include/absl/strings/string_view.h:53:26: error: ‘string_view’ in namespace ‘std’ does not name a type
   53 | using string_view = std::string_view;
      |                          ^~~~~~~~~~~
/home/diego/usr/include/absl/strings/string_view.h:53:21: note: ‘std::string_view’ is only available from C++17 onwards
   53 | using string_view = std::string_view;
      |                     ^~~

It seems that gRPC requires building with C++17. Am I missing something?

@rgerganov (Collaborator, Author)

It seems that gRPC requires building with C++17. Am I missing something?

Build gRPC by adding -DCMAKE_CXX_STANDARD=14 to cmake
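
For reference, the gRPC quickstart build then becomes roughly the following (paths follow the guide; adjust to your setup):

cd grpc/cmake/build
cmake -DgRPC_INSTALL=ON -DgRPC_BUILD_TESTS=OFF -DCMAKE_CXX_STANDARD=14 -DCMAKE_INSTALL_PREFIX=$MY_INSTALL_DIR ../..
make -j4 install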

@rgerganov (Collaborator, Author)

The issue may be that tensor->nb is not serialized. For non-contiguous views this is important.

That was exactly the problem, gpt-2 works now!

@slaren (Collaborator) commented Mar 11, 2024

Build gRPC by adding -DCMAKE_CXX_STANDARD=14 to cmake

Does it work with 11 too? Usually ggml targets C++11.

@rgerganov (Collaborator, Author)

gRPC doesn't build with -DCMAKE_CXX_STANDARD=11, but when built with -DCMAKE_CXX_STANDARD=14 it works fine with ggml.

@slaren (Collaborator) commented Mar 11, 2024

I got it to work now, very nice. It seems to work fine with CUDA as well. It is very slow, but I guess this is because of all the GetAllocSize/BufferGetBase/InitTensor calls. I think the result of BufferGetBase could be cached in the client, and the other functions could be buffered until there is a call that uses the tensor, such as a tensor_set, tensor_get or graph_compute, and then submitted to the server in a large batch.
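
A minimal sketch of the caching part, assuming the RPC buffer keeps a small client-side context like the one sketched earlier (names are illustrative, not from this PR):

#include "ggml-backend-impl.h"
#include <cstdint>

// illustrative only: cache the remote base pointer so BufferGetBase hits the
// server at most once per buffer
struct ggml_backend_rpc_buffer_context {
    uint64_t remote_ptr;    // buffer handle on the server
    void *   base_ptr;      // cached result of the remote get_base call
    bool     base_cached;
};

// assumed helper performing the actual RPC round-trip
void * remote_get_base(uint64_t remote_ptr);

static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
    auto * ctx = (ggml_backend_rpc_buffer_context *) buffer->context;
    if (!ctx->base_cached) {
        ctx->base_ptr    = remote_get_base(ctx->remote_ptr);
        ctx->base_cached = true;
    }
    return ctx->base_ptr;
}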

@rgerganov (Collaborator, Author)

Yes, these exact three functions are called many, many times. GetAllocSize and BufferGetBase are easy to handle, and InitTensor will require some tricks as you already suggested. When I cache the result of the first two and skip the third one (only with the CPU backend), I get the following:

main:     load time =  1615.98 ms
main:   sample time =    17.24 ms
main:  predict time =  2319.07 ms / 33.13 ms per token
main:    total time =  3957.53 ms

The same prompt with CPU backend gives:

main:     load time =   140.88 ms
main:   sample time =     6.73 ms
main:  predict time =   536.77 ms / 7.67 ms per token
main:    total time =   685.79 ms

@slaren (Collaborator) commented Mar 11, 2024

Actually, GetAllocSize cannot be buffered, since ggml-alloc needs to know the value immediately. Not sure how to address that. A simple hack for now could be to return ggml_nbytes, which will be correct most of the time.

@slaren (Collaborator) commented Mar 11, 2024

This is the performance that I get after removing these calls:

117M
CPU: main: predict time = 171.29 ms / 4.89 ms per token
RPC: main: predict time = 374.76 ms / 10.71 ms per token

1558M
CPU: main: predict time = 1847.17 ms / 52.78 ms per token
RPC: main: predict time = 2238.46 ms / 63.96 ms per token

diff --git a/src/ggml-alloc.c b/src/ggml-alloc.c
index e675306..8e72957 100644
--- a/src/ggml-alloc.c
+++ b/src/ggml-alloc.c
@@ -369,6 +369,7 @@ struct node_alloc {
 struct ggml_gallocr {
     ggml_backend_buffer_type_t * bufts; // [n_buffers]
     ggml_backend_buffer_t * buffers; // [n_buffers]
+    void ** buffer_bases; // [n_buffers]
     struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers]
     int n_buffers;

@@ -392,6 +393,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs
     galloc->buffers = calloc(sizeof(ggml_backend_buffer_t) * n_bufs, 1);
     GGML_ASSERT(galloc->buffers != NULL);

+    galloc->buffer_bases = calloc(sizeof(void *) * n_bufs, 1);
+    GGML_ASSERT(galloc->buffer_bases != NULL);
+
     galloc->buf_tallocs = calloc(sizeof(struct ggml_dyn_tallocr *) * n_bufs, 1);
     GGML_ASSERT(galloc->buf_tallocs != NULL);

@@ -733,6 +737,7 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
 #endif
             ggml_backend_buffer_free(galloc->buffers[i]);
             galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size);
+            galloc->buffer_bases[i] = ggml_backend_buffer_get_base(galloc->buffers[i]);
             if (galloc->buffers[i] == NULL) {
                 fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
                 return false;
@@ -763,7 +768,7 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor *
         if (node->data == NULL) {
             assert(tensor_alloc->offset != SIZE_MAX);
             assert(ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], node) <= tensor_alloc->size_max);
-            void * base = ggml_backend_buffer_get_base(galloc->buffers[buffer_id]);
+            void * base = galloc->buffer_bases[buffer_id];
             void * addr = (char *)base + tensor_alloc->offset;
             ggml_backend_tensor_alloc(galloc->buffers[buffer_id], node, addr);
         } else {
diff --git a/src/ggml-backend.c b/src/ggml-backend.c
index d60d984..3b56b4b 100644
--- a/src/ggml-backend.c
+++ b/src/ggml-backend.c
@@ -1637,20 +1637,20 @@ void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * t
     tensor->buffer = buffer;
     tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
     tensor->backend = tensor->view_src->backend;
-    ggml_backend_buffer_init_tensor(buffer, tensor);
+    //ggml_backend_buffer_init_tensor(buffer, tensor);
 }

 void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
     GGML_ASSERT(tensor->buffer == NULL);
     GGML_ASSERT(tensor->data == NULL);
     GGML_ASSERT(tensor->view_src == NULL);
-    GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));
-    GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
-                (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
+    //GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));
+    //GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
+    //            (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));

     tensor->buffer = buffer;
     tensor->data = addr;
-    ggml_backend_buffer_init_tensor(buffer, tensor);
+    //ggml_backend_buffer_init_tensor(buffer, tensor);
 }

 static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,
diff --git a/src/ggml-rpc.cpp b/src/ggml-rpc.cpp
index 0846429..e5592cd 100644
--- a/src/ggml-rpc.cpp
+++ b/src/ggml-rpc.cpp
@@ -255,6 +255,7 @@ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_
 }

 GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    return ggml_nbytes(tensor);
     GGML_PRINT_DEBUG("get alloc size\n");
     ggml::GetAllocSizeRequest request;
     ggml::Tensor * protobuf_tensor = request.mutable_tensor();
@@ -333,7 +334,7 @@ static void add_node(ggml::GraphComputeRequest & request, ggml_tensor * node, st
     add_node(request, node->view_src, visited);

     ggml::Tensor * protobuf_tensor = request.add_tensors();
-    GGML_PRINT_DEBUG("add node: %p\n", (void*)node);
+    //GGML_PRINT_DEBUG("add node: %p\n", (void*)node);
     serialize_tensor(node, protobuf_tensor);
 }

@@ -574,7 +575,7 @@ static struct ggml_tensor * create_node(uint64_t id,
     }
     for (int i = 0; i < request->tensors_size(); i++) {
         if (request->tensors(i).id() == id) {
-            GGML_PRINT_DEBUG("create node: %lx\n", id);
+            //GGML_PRINT_DEBUG("create node: %lx\n", id);
             const ggml::Tensor & protobuf_tensor = request->tensors(i);
             struct ggml_tensor * result = deserialize_tensor(ctx, protobuf_tensor);
             tensor_map[id] = result;

@ggerganov (Owner)

Metal backend also works:

(screenshot omitted)

This is the performance without gRPC:

main:     load time =  4307.36 ms
main:   sample time =    16.92 ms
main:  predict time =  2078.37 ms / 14.24 ms per token
main:    total time =  6403.82 ms

@ggerganov (Owner)

With the backend implementation now supporting pipeline parallelism, it seems possible to extend this RPC backend to perform distributed inference across many devices. This would be advantageous compared to the MPI backend because the latter does not support pipeline parallelisation and it is not obvious how to implement it.

Are there any obvious blockers? If not, maybe we should put this on the roadmap and try to support it eventually. It would be a cool technical feat and might even unlock some interesting inference use cases.

@slaren (Collaborator) commented Mar 14, 2024

It should be doable, but pipeline parallelism requires the ability to perform asynchronous copies and asynchronous event synchronization between backends, and it could be tricky to implement that. Servers would probably need to be able to communicate between themselves to do this.

@rgerganov (Collaborator, Author)

I believe distributed ggml would be a huge win, especially for very large models like Grok. Async operations are not a problem with gRPC, but I need to get more familiar with pipeline parallelism. In any case, I think this would be much better than MPI in the long term.

ggerganov changed the title from "Add RPC backend" to "ggml : add RPC backend" on Mar 22, 2024.
ggerganov added the enhancement (New feature or request) label on Mar 22, 2024.
@slaren (Collaborator) commented Mar 22, 2024

Pipeline parallelism allows evaluating large batches in parallel when using multiple devices. The idea is that to evaluate a layer, you only need the KV cache of this layer and the previous layers. This allows evaluating multiple batches in a pipeline such as this (excuse my terrible spreadsheet art):
(pipeline schedule illustration omitted)

The synchronization necessary to do this is implemented in ggml_backend_sched, and backends need to implement the async and event interfaces. The RPC backend can implement this behavior without relying on the underlying backend to also implement it, by running an asynchronous queue on a separate thread that issues synchronous calls to the backend.
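
Purely as an illustration of that queue idea (not code from this PR), something along these lines would let the RPC backend expose async operations while issuing only synchronous calls underneath:

#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>

// Each async backend operation is enqueued as a task; the worker thread
// executes the tasks in order using the existing synchronous calls.
struct rpc_async_queue {
    std::queue<std::function<void()>> tasks;
    std::mutex mtx;
    std::condition_variable cv;
    bool stop = false;
    std::thread worker{[this] { run(); }};

    void enqueue(std::function<void()> task) {
        {
            std::lock_guard<std::mutex> lock(mtx);
            tasks.push(std::move(task));
        }
        cv.notify_one();
    }

    // block until everything queued so far has executed; events can be built
    // the same way, by enqueuing a task that signals a flag
    void synchronize() {
        std::mutex done_mtx;
        std::condition_variable done_cv;
        bool done = false;
        enqueue([&] {
            std::lock_guard<std::mutex> lock(done_mtx);
            done = true;
            done_cv.notify_one();
        });
        std::unique_lock<std::mutex> lock(done_mtx);
        done_cv.wait(lock, [&] { return done; });
    }

    void run() {
        for (;;) {
            std::function<void()> task;
            {
                std::unique_lock<std::mutex> lock(mtx);
                cv.wait(lock, [this] { return stop || !tasks.empty(); });
                if (stop && tasks.empty()) {
                    return;
                }
                task = std::move(tasks.front());
                tasks.pop();
            }
            task();   // e.g. a synchronous tensor_set/tensor_get/graph_compute RPC
        }
    }

    ~rpc_async_queue() {
        {
            std::lock_guard<std::mutex> lock(mtx);
            stop = true;
        }
        cv.notify_one();
        worker.join();
    }
};

The backend's set_tensor_async, for example, would then just enqueue the corresponding synchronous RPC call and return immediately.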

About the issue of the latency of the calls to init_tensor/get_base/get_alloc_size:

get_base can be cached as mentioned previously.

After a recent change in the CUDA backend, the init_tensor call no longer changes the extra and backend parameters of ggml_tensor for normal buffers. Split buffers still add an extra, but these can be ignored for now. init_tensor still needs to be called for quantized tensors to clear some padding, but these are not normally created during inference.

get_alloc_size is only different from ggml_nbytes when using quantized tensors (this is where the padding is added).

The simplest solution for now would be to replace these calls with a client-side implementation when possible, and only call the backend in the RPC server if strictly necessary. Ideally this would be done by identifying the backend in the server, and providing a client-side implementation if the backend is recognized. For now, I think only passing these calls to the server if the tensor is quantized would be good enough.
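
A sketch of that shortcut on the client, where get_alloc_size_from_server is a placeholder for the actual RPC round-trip:

// assumed helper wrapping the actual RPC call
static size_t get_alloc_size_from_server(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor);

GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    // only quantized tensors may need padding beyond ggml_nbytes, so only they
    // are worth an RPC round-trip
    if (ggml_is_quantized(tensor->type)) {
        return get_alloc_size_from_server(buft, tensor);
    }
    return ggml_nbytes(tensor);
}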

@ggerganov (Owner) commented May 15, 2024

Superseded by: ggerganov/llama.cpp#6829

ggerganov closed this on May 15, 2024.