Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Speculative decoding] Improve n-gram efficiency #4724

Merged
merged 6 commits into from May 13, 2024

Conversation

comaniac
Copy link
Contributor

@comaniac comaniac commented May 9, 2024

This PR refactors the ngram (PLD) implementation to make it more efficient. Specifically, in ngram worker

  1. Do not use .all() when n-gram size is 1.
  2. Keep all tensors on GPU all the time and avoid using Python native lists.
  3. When a sequence has no match, return None instead of padding 0's.
  4. When the speculative token indices exceed the input length, simply clamp them instead of padding 0's.
  5. Change ngram_prompt_lookup_min to be inclusive, so the minimum valid value becomes 1. IMO this is more intuitive to users, as --ngram-prompt-lookup-min=0 looks like we are disabling this.

Also in top1_proposer, this PR adds a feature to allow draft worker to not propose speculative tokens for certain sequences. Before this PR, unmatched sequence still proposes all 0 tokens and they will be scored by the target model as well.

Test cases are fixed accordingly to reflect the changes. The new added features are covered already.
So far I haven't observed an obvious speedup with these changes except for ngram size 1, but these changes should be nice to have. Please let me know if you have any opinion.

cc @cadedaniel @leiwen83 @LiuXiaoxuanPKU

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link
Collaborator

@cadedaniel cadedaniel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, nice improvement. some question

vllm/spec_decode/top1_proposer.py Outdated Show resolved Hide resolved
matches = (windows[:-1] == ngram_tensor).all(dim=1).max(
dim=-1)

if matches.values.item():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line looks sus, from a performance perspective

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. It might transfer the value back to CPU. I'll double check.

Copy link
Contributor Author

@comaniac comaniac May 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've checked and yes it does have to transfer the value to CPU for the branch. If we want to avoid GPU-CPU transfer at all during this entire process, instead of

# Now: Approach 1
matches = (windows[:-1] == ngram_tensor).all(dim=-1)
first_match = matches.max(dim=-1)
if first_match.values.item(): # GPU -> CPU (tensor with shape (1,))
    # generate speculative token ids and logprobs based on first_match.indices
    break

we can do

# Alternative: Approach 2
temp = (windows[:-1] == ngram_tensor).all(dim=-1)
all_matches = temp.nonzero(as_tuple=True)[0]
if all_matches.size() > 1:
    # generate speculative token ids and logprobs based on all_matches[0]
    break

Approach 2 avoids CPU-GPU communication, but keep temp with shape (same as windows[:-1]) longer. I guess it should be fine? I'll change to this approach for now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nonzero forces CPU<->GPU sync as well - https://pytorch.org/docs/stable/generated/torch.nonzero.html#torch-nonzero

I'd suggest instead doing something like num_matches = (temp != 0).sum()?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but I guess you need to sync here anyway... would probably need a custom kernel to avoid that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the spec decode framework has a few such synchronizations. we can fix this one here but until they're all solved the benefit won't be that large. (if you want to pioneer a solution e.g. lower conditional logic to kernel then fine by me 😄 )

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ways to minimize the cost of a synchronization is to batch them. e.g. do a synchronization once for all iterations of the for loop. not sure if that's possible in your ngram matching logic.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, batching is what I thought of would work here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two challenges for batching:

  1. The length of each sequence differs.
  2. The first match ngram size of each sequence differs.

Point 1 means we have to pad input IDs if we batch sequences for ngram matching. Point 2 means we have to perform ngram matching from max ngram size to min ngram size for all sequences (i.e., no break). These 2 points also increase memory footprint, because we keep 1) all input IDs with padding (but I guess this tensor should not be too large?), and 2) all slicing windows (generated by .unfold).

I'll think a bit more about this today and get you back, but we may unblock this PR if there's no obvious solution shortly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah don’t mean to block the PR on this, we can optimize further later

Copy link
Collaborator

@cadedaniel cadedaniel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small comments

vllm/spec_decode/top1_proposer.py Outdated Show resolved Hide resolved
vllm/spec_decode/top1_proposer.py Outdated Show resolved Hide resolved
vllm/spec_decode/top1_proposer.py Outdated Show resolved Hide resolved
vllm/spec_decode/top1_proposer.py Outdated Show resolved Hide resolved
matches = (windows[:-1] == ngram_tensor).all(dim=1).max(
dim=-1)

if matches.values.item():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the spec decode framework has a few such synchronizations. we can fix this one here but until they're all solved the benefit won't be that large. (if you want to pioneer a solution e.g. lower conditional logic to kernel then fine by me 😄 )

Copy link
Collaborator

@cadedaniel cadedaniel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

approving preemptively (am OOO next week)

@Yard1 Yard1 merged commit ce532ff into vllm-project:main May 13, 2024
55 checks passed
@comaniac comaniac deleted the pld-opt branch May 13, 2024 22:00
tlrmchlsmth pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants