Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed May 13, 2024
1 parent 7325f11 commit 1389e4f
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 26 deletions.
12 changes: 9 additions & 3 deletions vllm/spec_decode/ngram_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,15 @@ def sampler_output(
# Do not match itself
matches = (windows[:-1] == ngram_tensor).all(dim=-1)

match_indices = matches.nonzero(as_tuple=True)[0]
if match_indices.size()[0] > 0:
proposal_start_idx = match_indices[0].add_(ngram_size)
# first_match includes "values" (bool), indicating whether
# the match is found, and "indices", indicating the index
# of the first match.
# Note that "first_match.values.item()" triggers GPU-CPU
# sync so it is a bit inefficient, but we have not found
# a better way to do this.
first_match = matches.max(dim=-1)
if first_match.values.item():
proposal_start_idx = first_match.indices.add_(ngram_size)
spec_indices = (
proposal_start_idx).repeat(sample_len) + torch.arange(
sample_len, device=self.device)
Expand Down
79 changes: 56 additions & 23 deletions vllm/spec_decode/top1_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,29 +73,13 @@ def get_proposals(
execute_model_req=nonzero_execute_model_req,
sample_len=proposal_len,
)
if maybe_sampler_output is not None:
# Some sequences do not get speculative tokens from
# the draft worker. Remove these sequences from nonzero
# proposal len seqs to reduce scoring overhead
zero_seq_idxs = []
for seq_idx, sampler_output in zip(
nonzero_proposal_len_indices, maybe_sampler_output):
if sampler_output is None:
proposal_lens[seq_idx] = 0
zero_seq_idxs.append(seq_idx)
nonzero_proposal_len_indices = [
idx for idx in nonzero_proposal_len_indices
if idx not in zero_seq_idxs
]
maybe_sampler_output = [
sampler_output for sampler_output in maybe_sampler_output
if sampler_output is not None
]
# We assume sampler_output will not return a list of
# maybe_sampler_output with all Nones. In this case it should
# directly return maybe_sampler_output=None and should not
# enter this branch
assert maybe_sampler_output
(
proposal_lens,
maybe_sampler_output,
nonzero_proposal_len_indices,
) = self._remove_no_proposal_seqs(proposal_lens,
maybe_sampler_output,
nonzero_proposal_len_indices)
else:
# If no sequences can be speculated, set sampler output to None.
maybe_sampler_output = None
Expand Down Expand Up @@ -163,6 +147,55 @@ def _split_by_proposal_len(
nonzero_proposal_len_indices,
)

def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output,
nonzero_proposal_len_indices):
"""Remove sequences from nonzero_proposal_len_indices and reset
their proposal_len to 0 the draft worker does not provide a proposal
(maybe_sampler_output=None). This can avoid scoring overheads.
"""
if maybe_sampler_output is None:
return (proposal_lens, maybe_sampler_output,
nonzero_proposal_len_indices)

new_proposal_lens: List[int] = []
new_nonzero_proposal_len_indices: List[int] = []
new_maybe_sampler_output: List[SamplerOutput] = []
nonzero_proposal_len_idx_ptr = 0
seq_idx = 0
while seq_idx < len(
proposal_lens) and nonzero_proposal_len_idx_ptr < len(
nonzero_proposal_len_indices):
if seq_idx < nonzero_proposal_len_indices[
nonzero_proposal_len_idx_ptr]:
# Sequence is not in the original nonzero_proposal_len_indices,
# meaning that it has a proposal length of 0 before sending to
# the draft worker.
assert proposal_lens[seq_idx] == 0
new_proposal_lens.append(0)
else:
# Sequence is in the original nonzero_proposal_len_indices
if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None:
# but does not have a proposal from the draft worker.
new_proposal_lens.append(0)
else:
# and has a proposal from the draft worker. Add it to the
# new nonzero proposal list and keep the sampler output.
new_proposal_lens.append(proposal_lens[seq_idx])
new_nonzero_proposal_len_indices.append(seq_idx)
new_maybe_sampler_output.append(
maybe_sampler_output[nonzero_proposal_len_idx_ptr])
nonzero_proposal_len_idx_ptr += 1
seq_idx += 1

# The remaining sequences should have proposal length of 0.
new_proposal_lens.extend(proposal_lens[seq_idx:])

# We assume sampler_output will not be a list of all Nones.
# In this case this function should not be called.
assert new_maybe_sampler_output
return (new_proposal_lens, new_maybe_sampler_output,
new_nonzero_proposal_len_indices)

def _merge_outputs(
self,
batch_size: int,
Expand Down

0 comments on commit 1389e4f

Please sign in to comment.