[BUG] model Yi-34B compat #181

Open
Qubitium opened this issue Mar 14, 2024 · 1 comment

Comments

@Qubitium
Contributor

We tested sglang with flashinfer 0.0.2 and flashinfer 0.0.3-dev (238563f); both crash inside flashinfer with the following stack trace on an A100.

Model: Yi-34B
OS: Ubuntu 22.04
GPU: A100 80GB

Yi-6B and Yi-9B have no such issue. Yi is a llama2-based architecture, if I am not mistaken.

@yzh119 The stack trace is vague to me (BatchPrefillWithPagedKVCache failed to dispatch with dtype Half), so I am reporting the bug here first. If you think this is sglang related, I will move the bug to sglang. Thanks!

Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/router/model_rpc.py", line 184, in exposed_step
    self.forward_step()
  File "/root/miniconda3/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/router/model_rpc.py", line 199, in forward_step
    self.forward_fill_batch(new_batch)
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/router/model_rpc.py", line 412, in forward_fill_batch
    ) = self.model_runner.forward(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/router/model_runner.py", line 506, in forward
    return self.forward_extend(**kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/router/model_runner.py", line 411, in forward_extend
    return self.model.forward(input_ids, input_metadata.positions, input_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/models/llama2.py", line 269, in forward
    hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/models/llama2.py", line 239, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/models/llama2.py", line 191, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/models/llama2.py", line 140, in forward
    attn_output = self.attn(q, k, v, input_metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/layers/radix_attention.py", line 115, in forward
    return self.extend_forward(q, k, v, input_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/layers/radix_attention.py", line 91, in prefill_forward_flashinfer
    o = input_metadata.prefill_wrapper.forward(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/flashinfer/prefill.py", line 507, in forward
    return self._wrapper.forward(
           ^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: BatchPrefillWithPagedKVCache failed to dispatch with dtype Half
@yzh119
Collaborator

yzh119 commented Mar 14, 2024

This is related to #35. Yi has a GQA group size (num_qo_heads/num_kv_heads) of 7, which is not compiled in our kernels. I'm refactoring the code so that we don't need a specialized kernel for each group size, and the issue will be resolved then.

Sorry about the confusing error message; it's a dispatching issue, but it is not related to the data type.
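
For readers hitting the same error, here is a minimal sketch of the group-size arithmetic described above. It is not flashinfer code: the function names and the set of "compiled" group sizes are hypothetical placeholders, and the head counts (56 query heads and 8 KV heads for Yi-34B, 32 and 4 for the smaller models) are assumptions taken from the published model configs, so check them against your own config.json.

```python
# Sketch only: illustrates why a GQA group size of 7 can fail kernel dispatch.
# The set below is a HYPOTHETICAL stand-in for whatever group sizes a given
# flashinfer build was compiled with; it is not read from the library.
ASSUMED_COMPILED_GROUP_SIZES = frozenset({1, 4, 8})


def gqa_group_size(num_qo_heads: int, num_kv_heads: int) -> int:
    """Return num_qo_heads / num_kv_heads, the GQA group size."""
    if num_kv_heads <= 0 or num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a positive multiple of num_kv_heads")
    return num_qo_heads // num_kv_heads


def is_group_size_supported(
    num_qo_heads: int,
    num_kv_heads: int,
    compiled_sizes: frozenset = ASSUMED_COMPILED_GROUP_SIZES,
) -> bool:
    """True if the model's group size is in the (assumed) compiled set."""
    return gqa_group_size(num_qo_heads, num_kv_heads) in compiled_sizes


if __name__ == "__main__":
    # Head counts are assumptions; verify against the model's config.json.
    for name, qo, kv in [("Yi-34B", 56, 8), ("Yi-6B", 32, 4)]:
        size = gqa_group_size(qo, kv)
        print(name, "group size", size, "supported:", is_group_size_supported(qo, kv))
```

Running a check like this before loading a model makes it easier to tell a group-size dispatch failure apart from a real dtype problem, since the error message alone points at the dtype.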
