Streaming Support for Nvidia's Triton Integration (#13135)
Rohith-2 committed Apr 30, 2024
1 parent af8c8eb commit 93cb095
Showing 6 changed files with 71 additions and 5 deletions.
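At a glance: `NvidiaTriton.stream_complete` and `NvidiaTriton.stream_chat` previously raised `NotImplementedError`; this commit implements both on top of the gRPC client's streaming request path. A minimal usage sketch of the new surface (the server URL and model name below are placeholders for your own Triton deployment, following the notebook example in this commit):

```python
from llama_index.core.llms import ChatMessage
from llama_index.llms.nvidia_triton import NvidiaTriton

# Placeholder endpoint and model name; point these at a running Triton server.
llm = NvidiaTriton(server_url="localhost:8001", model_name="ensemble", tokens=32)

# New in this commit: tokens arrive incrementally instead of in one blocking call.
for chunk in llm.stream_complete("The tallest mountain in North America is "):
    print(chunk.delta, end="", flush=True)

# stream_chat is also wired up, via the completion-to-chat adapter shown below.
for chunk in llm.stream_chat([ChatMessage(role="user", content="Hello!")]):
    print(chunk.delta, end="", flush=True)
```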
docs/docs/examples/llm/nvidia_triton.ipynb (28 additions, 0 deletions)
@@ -144,6 +144,34 @@
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Call `stream_complete` with a prompt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"resp = NvidiaTriton(server_url=triton_url, model_name=model_name, tokens=32).stream_complete(\"The tallest mountain in North America is \")\n",
"for delta in resp:\n",
" print(delta.delta, end=\" \")\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should expect the following response as a stream\n",
"```\n",
"the Great Pyramid of Giza, which is about 1,000 feet high. The Great Pyramid of Giza is the tallest mountain in North America.\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -47,6 +47,7 @@
from llama_index.core.llms.callbacks import llm_chat_callback
from llama_index.core.base.llms.generic_utils import (
    completion_to_chat_decorator,
    stream_completion_to_chat_decorator,
)
from llama_index.core.llms.llm import LLM
from llama_index.llms.nvidia_triton.utils import GrpcTritonClient
@@ -236,10 +237,12 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        chat_fn = completion_to_chat_decorator(self.complete)
        return chat_fn(messages, **kwargs)

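    # `stream_completion_to_chat_decorator` (imported above) adapts the prompt-based
    # `stream_complete` to the chat interface: it flattens the chat messages into a
    # single prompt and re-wraps each streamed CompletionResponse as a ChatResponse.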
    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
-        raise NotImplementedError
+        chat_stream_fn = stream_completion_to_chat_decorator(self.stream_complete)
+        return chat_stream_fn(messages, **kwargs)

    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
@@ -266,7 +269,7 @@ def complete(
            if isinstance(token, InferenceServerException):
                client.stop_stream(model_params["model_name"], request_id)
                raise token
-            response = response + token
+            response += token

        return CompletionResponse(
            text=response,
@@ -275,7 +278,34 @@
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
-        raise NotImplementedError
        from tritonclient.utils import InferenceServerException

        client = self._get_client()

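        # Merge per-call kwargs into the model's default generation parameters and
        # tag the request with a random id so the stream can be cancelled later.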
        invocation_params = self._get_model_default_parameters
        invocation_params.update(kwargs)
        invocation_params["prompt"] = [[prompt]]
        model_params = self._identifying_params
        model_params.update(kwargs)
        request_id = str(random.randint(1, 9999999))  # nosec

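        # Optionally ask the Triton server to load the model before inferencing.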
        if self.triton_load_model_call:
            client.load_model(model_params["model_name"])

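        # request_streaming returns a queue-like iterator that yields decoded tokens,
        # or an InferenceServerException if the server reports an error mid-stream.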
        result_queue = client.request_streaming(
            model_params["model_name"], request_id, **invocation_params
        )

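        # Each yielded CompletionResponse carries the full accumulated text so far,
        # with the newest token exposed separately as `delta`.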
        def gen() -> CompletionResponseGen:
            text = ""
            for token in result_queue:
                if isinstance(token, InferenceServerException):
                    client.stop_stream(model_params["model_name"], request_id)
                    raise token
                text += token
                yield CompletionResponse(text=text, delta=token)

        return gen()

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
@@ -21,13 +21,13 @@ ignore_missing_imports = true
python_version = "3.8"

[tool.poetry]
authors = ["Your Name <you@example.com>"]
authors = ["Rohith Ramakrishnan <rrohith2001@gmail.com>"]
description = "llama-index llms nvidia triton integration"
exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-nvidia-triton"
readme = "README.md"
version = "0.1.4"
version = "0.1.5"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
@@ -0,0 +1 @@
python_tests()
Empty file.
@@ -0,0 +1,7 @@
from llama_index.core.base.llms.base import BaseLLM
from llama_index.llms.nvidia_triton import NvidiaTriton


def test_text_inference_embedding_class():
    names_of_base_classes = [b.__name__ for b in NvidiaTriton.__mro__]
    assert BaseLLM.__name__ in names_of_base_classes
