chore: token cache openai
zhudotexe committed May 7, 2024
1 parent ae50c6a commit 01c62ad
Showing 7 changed files with 75 additions and 27 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -153,3 +153,4 @@ cython_debug/


**.DS_Store
sandbox/
31 changes: 8 additions & 23 deletions kani/engines/anthropic/engine.py
@@ -10,6 +10,7 @@
from kani.models import ChatMessage, ChatRole, FunctionCall, ToolCall
from kani.prompts.pipeline import PromptPipeline
from ..base import BaseCompletion, BaseEngine, Completion
from ..mixins import TokenCached

try:
from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic
@@ -84,7 +85,7 @@ def content_transform(msg: ChatMessage):
)


class AnthropicEngine(BaseEngine):
class AnthropicEngine(TokenCached, BaseEngine):
"""Engine for using the Anthropic API.
This engine supports all Claude models. See https://docs.anthropic.com/claude/docs/getting-access-to-claude for
@@ -137,6 +138,9 @@ def __init__(
)
if max_context_size is None:
max_context_size = next(size for prefix, size in CONTEXT_SIZES_BY_PREFIX if model.startswith(prefix))

super().__init__()

self.client = client or AsyncAnthropic(
api_key=api_key, max_retries=retry, base_url=api_base, default_headers=headers
)
@@ -159,27 +163,9 @@ def __init__(
self.tokenizer = None

# ==== token counting ====
@staticmethod
def message_cache_key(message: ChatMessage):
# (role, content, tool calls)

# we'll use msgpart identity for the hash here since we'll always have a ref as long as it's in a message
# history
hashable_content = tuple(part if isinstance(part, str) else id(part) for part in message.parts)

# use (name, args) for tool calls
if message.tool_calls:
hashable_tool_calls = tuple((tc.function.name, tc.function.arguments) for tc in message.tool_calls)
else:
hashable_tool_calls = message.tool_calls

return hash((message.role, hashable_content, hashable_tool_calls))

def message_len(self, message: ChatMessage) -> int:
# use cache
cache_key = self.message_cache_key(message)
if cache_key in self.token_cache:
return self.token_cache[cache_key]
if (cached_len := self.get_cached_message_len(message)) is not None:
return cached_len

# use tokenizer
if self.tokenizer is not None:
@@ -285,8 +271,7 @@ def _translate_anthropic_message(self, message):
kani_msg = ChatMessage.assistant(content, tool_calls=tool_calls or None)

# also cache the message token len
cache_key = self.message_cache_key(kani_msg)
self.token_cache[cache_key] = message.usage.output_tokens
self.set_cached_message_len(kani_msg, message.usage.output_tokens)

return Completion(
message=kani_msg,
3 changes: 3 additions & 0 deletions kani/engines/base.py
@@ -6,6 +6,7 @@
from kani.models import ChatMessage


# ==== completions ====
class BaseCompletion(abc.ABC):
"""Base class for all LM engine completions."""

@@ -47,6 +48,7 @@ def completion_tokens(self):
return self._completion_tokens


# ==== base engines ====
class BaseEngine(abc.ABC):
"""Base class for all LM engines.
@@ -131,6 +133,7 @@ async def close(self):
pass


# ==== utils ====
class WrapperEngine(BaseEngine):
"""
A base class for engines that are meant to wrap other engines. By default, this class takes in another engine
32 changes: 32 additions & 0 deletions kani/engines/mixins.py
@@ -0,0 +1,32 @@
from kani.models import ChatMessage


class TokenCached:
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.token_cache = {}

    def message_cache_key(self, message: ChatMessage):
        # (role, content, tool calls)

        # we'll use msgpart identity for the hash here since we'll always have a ref as long as it's in a message
        # history
        hashable_content = tuple(part if isinstance(part, str) else id(part) for part in message.parts)

        # use (name, args) for tool calls
        if message.tool_calls is not None:
            hashable_tool_calls = tuple((tc.function.name, tc.function.arguments) for tc in message.tool_calls)
        else:
            hashable_tool_calls = message.tool_calls

        return hash((message.role, hashable_content, hashable_tool_calls))

    def get_cached_message_len(self, message: ChatMessage) -> int | None:
        # use cache
        cache_key = self.message_cache_key(message)
        if cache_key in self.token_cache:
            return self.token_cache[cache_key]

    def set_cached_message_len(self, message: ChatMessage, length: int):
        cache_key = self.message_cache_key(message)
        self.token_cache[cache_key] = length
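
For reference, a minimal sketch of how an engine is expected to consume this mixin. The WordCountEngine class and its word-count stand-in for real tokenization are illustrative only, not part of this commit:

from kani.engines.mixins import TokenCached
from kani.models import ChatMessage


class WordCountEngine(TokenCached):
    # hypothetical consumer: "tokenizes" by counting words, caching the result per message
    def message_len(self, message: ChatMessage) -> int:
        # return the cached count if this exact message has been measured before
        if (cached_len := self.get_cached_message_len(message)) is not None:
            return cached_len
        mlen = len((message.text or "").split())  # stand-in for a real tokenizer
        self.set_cached_message_len(message, mlen)
        return mlen


engine = WordCountEngine()
msg = ChatMessage.user("hello there, world")
assert engine.message_len(msg) == 3  # computed and stored in engine.token_cache
assert engine.message_len(msg) == 3  # served from the cache on the second call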
30 changes: 27 additions & 3 deletions kani/engines/openai/engine.py
@@ -7,6 +7,7 @@
from . import function_calling
from .translation import ChatCompletion, openai_tc_to_kani_tc, translate_functions, translate_messages
from ..base import BaseCompletion, BaseEngine, Completion
from ..mixins import TokenCached

try:
import tiktoken
@@ -43,7 +44,7 @@
]


class OpenAIEngine(BaseEngine):
class OpenAIEngine(TokenCached, BaseEngine):
"""Engine for using the OpenAI API.
This engine supports all chat-based models and fine-tunes.
@@ -84,6 +85,9 @@ def __init__(
raise ValueError("You must supply no more than one of (api_key, client).")
if max_context_size is None:
max_context_size = next(size for prefix, size in CONTEXT_SIZES_BY_PREFIX if model.startswith(prefix))

super().__init__()

self.client = client or OpenAIClient(
api_key=api_key, organization=organization, max_retries=retry, base_url=api_base, default_headers=headers
)
@@ -100,6 +104,9 @@ def _load_tokenizer(self):
self.tokenizer = tiktoken.get_encoding("cl100k_base")

def message_len(self, message: ChatMessage) -> int:
if (cached_len := self.get_cached_message_len(message)) is not None:
return cached_len

mlen = 7
if message.text:
mlen += len(self.tokenizer.encode(message.text))
@@ -109,6 +116,8 @@ def message_len(self, message: ChatMessage) -> int:
for tc in message.tool_calls:
mlen += len(self.tokenizer.encode(tc.function.name))
mlen += len(self.tokenizer.encode(tc.function.arguments))

self.set_cached_message_len(message, mlen)
return mlen

async def predict(
@@ -125,7 +134,9 @@ async def predict(
model=self.model, messages=translated_messages, tools=tool_specs, **self.hyperparams, **hyperparams
)
# translate into Kani spec and return
return ChatCompletion(openai_completion=completion)
kani_cmpl = ChatCompletion(openai_completion=completion)
self.set_cached_message_len(kani_cmpl.message, kani_cmpl.completion_tokens)
return kani_cmpl

async def stream(
self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
@@ -142,16 +153,26 @@ async def stream(
messages=translated_messages,
tools=tool_specs,
stream=True,
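# request a final streamed chunk carrying token usage (supported since openai 1.26.0)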
stream_options={"include_usage": True},
**self.hyperparams,
**hyperparams,
)

# save requested tool calls and content as streamed
content_chunks = []
tool_call_partials = {} # index -> tool call
usage = None

# iterate over the stream and yield/save
async for chunk in stream:
# save usage if present
if chunk.usage is not None:
usage = chunk.usage

if not chunk.choices:
continue

# process content delta
delta = chunk.choices[0].delta

# yield content
@@ -172,7 +193,10 @@ async def stream(
# construct the final completion with streamed tool calls
content = None if not content_chunks else "".join(content_chunks)
tool_calls = [openai_tc_to_kani_tc(tc) for tc in sorted(tool_call_partials.values(), key=lambda c: c.index)]
yield Completion(message=ChatMessage(role=ChatRole.ASSISTANT, content=content, tool_calls=tool_calls))
msg = ChatMessage(role=ChatRole.ASSISTANT, content=content, tool_calls=tool_calls)
if usage:
self.set_cached_message_len(msg, usage.completion_tokens)
yield Completion(message=msg)

def function_token_reserve(self, functions: list[AIFunction]) -> int:
if not functions:
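
A rough usage sketch of the resulting behavior, assuming the openai extra is installed; the placeholder key, model name, and prompt below are illustrative:

from kani.engines.openai import OpenAIEngine
from kani.models import ChatMessage

engine = OpenAIEngine(api_key="sk-...", model="gpt-3.5-turbo")
msg = ChatMessage.user("How does the token cache work?")

first = engine.message_len(msg)           # encodes with tiktoken, then stores the count
assert engine.message_len(msg) == first   # second call is answered from the cache

Completions returned by predict() and stream() are cached the same way, using the completion_tokens reported by the API, so re-measuring an assistant message already in the chat history does not re-tokenize it.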
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -46,7 +46,7 @@ cpp = [
]

openai = [
"openai>=1.0.0,<2.0.0",
"openai>=1.26.0,<2.0.0",
"tiktoken>=0.4.0,<1.0.0",
]

3 changes: 3 additions & 0 deletions sandbox/README.md
@@ -0,0 +1,3 @@
# sandbox

This directory contains small experimental scripts that are not part of the main library.
