
Make flashinfer kernels cuda graphs friendly #187

Open

AgrawalAmey opened this issue Mar 20, 2024 · 7 comments

AgrawalAmey commented Mar 20, 2024

Thanks for creating these awesome kernels! I am trying to get flashinfer kernels to work with CUDA graphs, but it appears that several parallelism decisions (block size, num_q_tiles, etc.) are made on the fly in the forward function based on the input data. This makes it difficult to capture flashinfer kernels in CUDA graphs in a generic manner. I think one solution would be to introduce a launcher kernel that factors in the input metadata and launches the actual CUDA kernel using dynamic parallelism. Toward that, the following are the items I have identified --

1. BatchPrefillWithPagedKVCachePyTorchWrapper::Forward -- handle return lse?
2. BatchPrefillWithPagedKVCachePyTorchWrapper::Forward -- paged_kv_t batch_size should not be on the CPU side
3. BatchPrefillWithPagedKVCacheWrapperDispatched -- make it a CUDA device function or get rid of it
4. BatchPrefillWithPagedKVCacheWrapperDispatched -- num_frags_x, num_qo_tiles, batch size need to be
5. BatchPrefillWithPagedKVCacheWrapperDispatched -- do not access handler state directly in the function
6. BatchPrefillWithPagedKVCacheDispatched -- make it a CUDA device function
7. BatchPrefillWithPagedKVCacheDispatched -- put num_qo_tiles in device-accessible memory
8. BatchPrefillWithPagedKVCacheDispatched -- make validations GPU-friendly
9. Batch size should be an explicit input parameter rather than being derived from the length of indptr, so that inputs can be padded.

@yzh119, please let me know what would be the best way to proceed.
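To make the problem concrete, here is a minimal standalone sketch (the kernel is a hypothetical stand-in, not flashinfer code) showing that CUDA graph capture freezes the grid dimensions chosen at capture time, so a launch config derived from input shapes cannot change on replay:

#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical stand-in for an attention kernel: one thread block per query tile.
__global__ void tileKernel(int *tiles_processed) {
    if (threadIdx.x == 0) atomicAdd(tiles_processed, 1);
}

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    int *tiles_processed;
    cudaMalloc(&tiles_processed, sizeof(int));
    cudaMemset(tiles_processed, 0, sizeof(int));

    // Launch config derived from the batch seen at capture time.
    int num_qo_tiles_at_capture = 4;

    cudaGraph_t graph;
    cudaGraphExec_t instance;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    tileKernel<<<num_qo_tiles_at_capture, 32, 0, stream>>>(tiles_processed);
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);

    // Replay always launches 4 blocks, even if a later batch needs a
    // different number of tiles -- the grid size is baked into the graph.
    cudaGraphLaunch(instance, stream);
    cudaStreamSynchronize(stream);

    int host_count = 0;
    cudaMemcpy(&host_count, tiles_processed, sizeof(int), cudaMemcpyDeviceToHost);
    printf("Tiles processed per replay: %d\n", host_count);

    cudaGraphExecDestroy(instance);
    cudaGraphDestroy(graph);
    cudaFree(tiles_processed);
    cudaStreamDestroy(stream);
    return 0;
}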

yzh119 (Collaborator) commented Mar 22, 2024

Hi @AgrawalAmey, thanks for bringing this up. I have some ideas about the CUDA graph integration with flashinfer:

The kernels to be executed can be determined before a decode/prefill step (for all layers) by analyzing the shapes; we can compile CUDA graphs for all possible combinations (not too many) ahead of time and dispatch to one of them according to the shapes.
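A rough sketch of that idea (the kernel and bucket sizes below are hypothetical, purely to illustrate the dispatch): capture one graph per padded shape bucket ahead of time, then round the actual batch size up to the nearest bucket and replay that graph:

#include <cuda_runtime.h>
#include <cstdio>
#include <map>

// Hypothetical stand-in kernel: one thread block per (padded) request.
__global__ void decodeKernel(float *out) {
    out[blockIdx.x] = (float)blockIdx.x;
}

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    float *out;
    cudaMalloc(&out, 64 * sizeof(float));

    // Capture one graph per padded batch-size bucket ahead of time.
    std::map<int, cudaGraphExec_t> graphs;
    int buckets[] = {8, 16, 32, 64};
    for (int bucket : buckets) {
        cudaGraph_t graph;
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
        decodeKernel<<<bucket, 32, 0, stream>>>(out);
        cudaStreamEndCapture(stream, &graph);
        cudaGraphExec_t instance;
        cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
        cudaGraphDestroy(graph);
        graphs[bucket] = instance;
    }

    // At run time, dispatch to the smallest bucket that fits the actual batch.
    int batch_size = 13;
    auto it = graphs.lower_bound(batch_size);
    cudaGraphLaunch(it->second, stream);
    cudaStreamSynchronize(stream);
    printf("Replayed bucket %d for batch size %d\n", it->first, batch_size);

    for (auto &kv : graphs) cudaGraphExecDestroy(kv.second);
    cudaFree(out);
    cudaStreamDestroy(stream);
    return 0;
}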

Regarding dynamic parallelism:

introduce a launcher kernel which would factor in the input metadata and launch the actual CUDA kernel using dynamic parallelism

It sounds tricky to me because the required shared memory size/grid size varies for different schedules.

AgrawalAmey (Author) commented

Hi @yzh119!

I have an implementation in sarathi-serve that tries to enumerate the different combinations and capture them. But with increasing batch sizes and large variance in input sequence lengths, the number of possibilities seems to explode, and prefill + decode requests batched together make it even more challenging. The memory cost of CUDA graphs becomes too high as the number of combinations increases.

The child kernel/dynamic parallelism proposal aims to solve the challenge of varying grid sizes, etc. Essentially, the launcher kernel would be triggered with a single warp; inside the launcher kernel, we can determine all the launch params and launch the actual attention kernel.

AgrawalAmey (Author) commented

A sample program to explain what I mean (note that device-side kernel launches require compiling with relocatable device code, e.g. `nvcc -rdc=true`):

#include <cuda_runtime.h>
#include <iostream>


__global__ void subKernel(int *data) {
    printf("Data before sub kernel: %d\n", *data);
    (*data) -= 1;
}

__global__ void addKernel(int *data) {
    printf("Data before add kernel: %d\n", *data);
    (*data) += 1;
}

struct UserData {
    int data;
    bool op;
};

// Launcher kernel: reads the op flag at run time and uses dynamic parallelism
// to launch the appropriate child kernel. userData lives in pinned host
// memory, so the device can read it at replay time.
__global__ void launchChildKernelFromDevice(void *_userData) {
    UserData *userData = (UserData *)_userData;
    bool op = userData->op;

    if (op) {
        addKernel<<<1, 1>>>(&userData->data);
    } else {
        subKernel<<<1, 1>>>(&userData->data);
    }
}

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    UserData *userData;
    cudaMallocHost(&userData, sizeof(UserData));

    userData->data = 10;
    userData->op = true;

    // Sanity check: run the launcher kernel once outside of graph capture
    // (op == true, so addKernel runs).

    cudaStreamSynchronize(stream);
    std::cout << "Data before kernel: " << userData->data << std::endl;
    launchChildKernelFromDevice<<<1, 1, 0, stream>>>(userData);
    cudaStreamSynchronize(stream);
    std::cout << "Data after kernel: " << userData->data << std::endl;

    cudaGraph_t graph;
    cudaGraphExec_t instance;

    // Begin graph capture
    cudaStreamSynchronize(stream);
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

    // Capture the launcher kernel; the child kernel launch happens inside it at replay time
    launchChildKernelFromDevice<<<1, 1, 0, stream>>>(userData);

    // End graph capture
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
    
    cudaStreamSynchronize(stream);

    printf("Data after graph: %d\n", userData->data);

    // Run the graph
    cudaGraphLaunch(instance, stream);
    cudaStreamSynchronize(stream);

    printf("Data after graph replay: %d\n", userData->data);

    userData->op = false;
    cudaGraphLaunch(instance, stream);
    cudaStreamSynchronize(stream);

    printf("Data after graph replay with different op: %d\n", userData->data);

    cudaGraphExecDestroy(instance);
    cudaGraphDestroy(graph);
    cudaStreamDestroy(stream);
    cudaFreeHost(userData);  // allocated with cudaMallocHost

    return 0;
}

yzh119 (Collaborator) commented Mar 23, 2024

Thanks for your explanation, that sounds reasonable.

To proceed, I'd love to write some documentation on our dispatching rules and see if we can describe them with dynamic parallelism. Before that I have to finish #75 because it will affect our dispatching strategy.

I'll be glad to follow up next week, and we can schedule a meeting on Zoom (you can drop me an email at zhye@cs.washington.edu).

AgrawalAmey (Author) commented

Yes, that would be great. I will send out a when2meet link over email, thank you!

ZSL98 commented Apr 3, 2024

Hi, @AgrawalAmey, will your sarathi or sarathi-serve be open-sourced?

AgrawalAmey (Author) commented

Hey @ZSL98, we are working with the vLLM team to get Sarathi-Serve scheduler support inside vLLM.

yzh119 added a commit that referenced this issue May 24, 2024
As requested in #187, this PR adds initial support for `CUDAGraph`
compatibility of the flashinfer batch decode attention kernels. This PR is
the first step towards full CUDAGraph support; we will implement
CUDAGraph-compatible prefill operators in later PRs.

# Proposed APIs
We add another wrapper, `CUDAGraphBatchDecodeWithPagedKVCacheWrapper`, and
the user needs to pre-allocate the paged-KV data structure buffers to
initialize this wrapper class. Once initialized, these buffers are pinned
on the GPU for the life cycle of the wrapper class.

The behavior of `CUDAGraphBatchDecodeWithPagedKVCacheWrapper` is a
little different from `BatchDecodeWithPagedKVCacheWrapper`'s: in CUDAGraph
mode we only run a fixed set of kernels, no matter what the input shape is
(the original implementation dispatches to different kernels according to
the input shapes).

This PR also fixes the addresses of all kernel input pointers to accommodate
the constraints of CUDAGraph capturing.
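
For intuition, here is a minimal standalone sketch (hypothetical kernel, not the flashinfer API) of that fixed-address constraint: the buffer addresses are baked into the captured graph, so only the buffer contents may change between replays:

#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical kernel reading per-request metadata from a pre-allocated buffer.
__global__ void readIndptr(const int *indptr, int *out) {
    out[threadIdx.x] = indptr[threadIdx.x];
}

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Buffers are allocated once; their addresses get baked into the graph.
    int *indptr, *out;
    cudaMallocHost(&indptr, 4 * sizeof(int));
    cudaMallocHost(&out, 4 * sizeof(int));
    for (int i = 0; i < 4; ++i) indptr[i] = i;

    cudaGraph_t graph;
    cudaGraphExec_t instance;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    readIndptr<<<1, 4, 0, stream>>>(indptr, out);
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
    cudaGraphDestroy(graph);

    // New batch metadata is written into the same buffer before each replay;
    // only the contents change, never the pointer.
    for (int i = 0; i < 4; ++i) indptr[i] = 10 * i;
    cudaGraphLaunch(instance, stream);
    cudaStreamSynchronize(stream);
    printf("out: %d %d %d %d\n", out[0], out[1], out[2], out[3]);

    cudaGraphExecDestroy(instance);
    cudaFreeHost(indptr);
    cudaFreeHost(out);
    cudaStreamDestroy(stream);
    return 0;
}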

# Examples
See `test_cuda_graph_batch_decode_with_paged_kv_cache` in the unit tests.
`begin_forward` functions should not be captured, as some of their
operators are not allowed to be captured.

cc @AgrawalAmey  @LiuXiaoxuanPKU  @comaniac
yzh119 added a commit that referenced this issue Jun 2, 2024