out of memory #171

Open · enhaofrank opened this issue Oct 20, 2023 · 3 comments
Labels: bug (Something isn't working)

Comments

@enhaofrank

Error details:
On a V100 machine with 32 GB of GPU memory: the server starts up normally, but as soon as a single short query is run, it fails with an out-of-memory error.

python3 -m lightllm.server.api_server --model_dir /app/baichuan2-13B --trust_remote_code --host 0.0.0.0 --port 8080 --tp 1 --max_total_token_num 6000
Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
INFO: Started server process [260217]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://10.0.0.0:8080 (Press CTRL+C to quit)
Task exception was never retrieved
future: <Task finished name='Task-5' coro=<RouterManager.loop_for_fwd() done, defined at /home/fangenhao/lightllm/lightllm/server/router/manager.py:88> exception=RuntimeError('Triton Error [CUDA]: out of memory')>
Traceback (most recent call last):
File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0--2b0c5161c53c71b37ae20a9996ee4bb8-c1f92808b4e4644c1732e8338187ac87-f24b6aa9b101a518b6a4a6bddded372e-12f7ac1ca211e037f62a7c0c323d9990-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float16, torch.float16, torch.float16, 'fp32', torch.float32, torch.int32, torch.int32, torch.float32, torch.float16, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (128, 128, 128), (True, True, True, (False,), True, True, True, True, True, (True, False), (True, False), (False, True), (True, False), (True, False), (False, True), (True, False), (True, False), (False, True), (True, False), (True, False), (False, True), (True, False), (False, False), (False, True)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/home/fangenhao/lightllm/lightllm/server/router/manager.py", line 91, in loop_for_fwd
await self._step()
File "/home/fangenhao/lightllm/lightllm/server/router/manager.py", line 112, in _step
await self._prefill_batch(self.running_batch)
File "/home/fangenhao/lightllm/lightllm/server/router/manager.py", line 149, in _prefill_batch
ans = await asyncio.gather(*rets)
File "/home/fangenhao/lightllm/lightllm/server/router/model_infer/model_rpc.py", line 241, in prefill_batch
ans = self._prefill_batch(batch_id)
File "/home/fangenhao/lightllm/lightllm/utils/infer_utils.py", line 54, in inner_func
result = func(*args, **kwargs)
File "/home/fangenhao/lightllm/lightllm/server/router/model_infer/model_rpc.py", line 112, in exposed_prefill_batch
return self.forward(batch_id, is_prefill=True)
File "/home/fangenhao/lightllm/lightllm/server/router/model_infer/model_rpc.py", line 163, in forward
logits = self.model.forward(**kwargs)
File "/home/fangenhao/anaconda3/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/fangenhao/lightllm/lightllm/common/basemodel/basemodel.py", line 128, in forward
return self._prefill(batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len)
File "/home/fangenhao/lightllm/lightllm/common/basemodel/basemodel.py", line 152, in _prefill
predict_logics = self._context_forward(input_ids, infer_state)
File "/home/fangenhao/lightllm/lightllm/common/basemodel/basemodel.py", line 192, in _context_forward
input_embs = self.layers_infer[i].context_forward(input_embs, infer_state, self.trans_layers_weight[i])
File "/home/fangenhao/lightllm/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py", line 129, in context_forward
self._context_attention(input_embdings,
File "/home/fangenhao/lightllm/lightllm/utils/infer_utils.py", line 21, in time_func
ans = func(*args, **kwargs)
File "/home/fangenhao/lightllm/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py", line 84, in _context_attention
o = self._context_attention_kernel(q, cache_k, cache_v, infer_state, layer_weight)
File "/home/fangenhao/lightllm/lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py", line 27, in _context_attention_kernel
return BloomTransformerLayerInfer.context_attention_kernel(self, q, k, v, infer_state, layer_weight)
File "/home/fangenhao/lightllm/lightllm/models/bloom/layer_infer/transformer_layer_infer.py", line 55, in context_attention_kernel
context_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_),
File "/home/fangenhao/anaconda3/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/fangenhao/lightllm/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py", line 233, in context_attention_fwd
_fwd_kernel[grid](
File "/home/fangenhao/anaconda3/lib/python3.9/site-packages/triton/runtime/jit.py", line 106, in launcher
return self.run(*args, grid=grid, **kwargs)
File "<string>", line 43, in _fwd_kernel
RuntimeError: Triton Error [CUDA]: out of memory
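
A note on reading this trace: the initial KeyError is Triton's JIT cache miss (the generated launcher compiles the kernel on demand when the cache lookup fails), and the real failure is the CUDA out-of-memory raised while handling that miss. A quick way to check how much headroom is actually left on the card once the weights are resident (a minimal diagnostic sketch using plain PyTorch, not a lightllm API):

```python
import torch

# Minimal diagnostic sketch (plain PyTorch; not part of lightllm).
# torch.cuda.mem_get_info() returns (free_bytes, total_bytes) for the device.
free, total = torch.cuda.mem_get_info(0)
print(f"GPU 0: {free / 1024**3:.1f} GiB free of {total / 1024**3:.1f} GiB")

# Rough arithmetic: a 13B-parameter model in fp16 needs about
# 13e9 * 2 bytes ~ 26 GiB for weights alone, so a 32 GiB V100 leaves only a
# few GiB for the KV cache (sized by --max_total_token_num) plus Triton's
# compile/launch workspace -- a small allocation spike is enough to OOM.
```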

@enhaofrank enhaofrank added the bug Something isn't working label Oct 20, 2023
@hiworldwzj
Collaborator

@enhaofrank The Baichuan2 model is not supported yet.

@enhaofrank
Author

@hiworldwzj Is there an expected timeline?

@hiworldwzj
Collaborator

@enhaofrank The new baichuan2 model will only be integrated after an upcoming Router feature is merged. You can also find the guide on how to add a new model in docs and integrate it yourself.
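
For anyone attempting that before official support lands, the traceback above already shows the pattern the codebase uses: baichuan13b's _context_attention_kernel simply forwards to BloomTransformerLayerInfer.context_attention_kernel. A new model would typically subclass an existing layer-infer class and override only what differs. The sketch below follows that pattern; the class name Baichuan2TransformerLayerInfer is hypothetical, and the real interface is described in the docs mentioned above.

```python
# Hypothetical sketch only -- the subclass name and what to override are
# assumptions; consult the "add a new model" doc for the actual interface.
# The delegation pattern is copied from the traceback above, where the
# Baichuan-13B layer forwards its context attention to the Bloom kernel.
from lightllm.models.bloom.layer_infer.transformer_layer_infer import (
    BloomTransformerLayerInfer,  # module path matches the traceback above
)


class Baichuan2TransformerLayerInfer(BloomTransformerLayerInfer):  # hypothetical
    """Reuse the shared Triton attention kernel; override only model-specific parts."""

    def _context_attention_kernel(self, q, k, v, infer_state, layer_weight):
        # Same call signature as in the traceback; delegate to the shared
        # kernel until Baichuan2-specific differences are handled explicitly.
        return BloomTransformerLayerInfer.context_attention_kernel(
            self, q, k, v, infer_state, layer_weight
        )
```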
