Skip to content

Commit

Permalink
build: hypothesis testing
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed Nov 11, 2023
1 parent ef7cd4e commit 2a57e63
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 7 deletions.
19 changes: 12 additions & 7 deletions tests/test_kani.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import random
import string
from hypothesis import HealthCheck, given, settings, strategies as st

from kani import ChatMessage, ChatRole, Kani
from tests.engine import TestEngine
Expand Down Expand Up @@ -45,13 +44,19 @@ async def test_always_include():
assert flatten_chatmessages(prompt) == "12a"


async def test_spam():
@settings(suppress_health_check=(HealthCheck.too_slow,), deadline=None)
@given(st.data())
async def test_spam(data):
# spam the kani with a bunch of random prompts
# and make sure it never breaks
ai = Kani(engine, desired_response_tokens=3, system_prompt="1", always_included_messages=[ChatMessage.user("2")])
for _ in range(1000):
query_len = random.randint(0, 5)
query = "".join(random.choice(string.ascii_letters) for _ in range(query_len))
ai = Kani(
engine,
desired_response_tokens=3,
system_prompt=data.draw(st.text(min_size=0, max_size=1)),
always_included_messages=[ChatMessage.user(data.draw(st.text(min_size=0, max_size=1)))],
)
queries = data.draw(st.lists(st.text(min_size=0, max_size=5)))
for query in queries:
resp = await ai.chat_round_str(query, test_echo=True)
assert resp == query

Expand Down
66 changes: 66 additions & 0 deletions tests/test_mock_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Ensure that any messages sent to OpenAI are valid (mock the API and just echo)."""

import time

from hypothesis import HealthCheck, given, settings, strategies as st
from pydantic import RootModel

from kani import ChatMessage, Kani
from kani.engines.openai import OpenAIClient, OpenAIEngine
from kani.engines.openai.models import OpenAIChatMessage


class MockOpenAIClient(OpenAIClient):
    """Test double for the OpenAI client: validates outgoing chat payloads and echoes them.

    Only the ``/chat/completions`` route is mocked; any other route raises ``ValueError``.
    """

    async def request(self, method: str, route: str, headers=None, retry=None, **kwargs):
        """Check that the ``messages`` in the request's ``json`` body parse as OpenAI chat messages.

        :param method: HTTP method (not inspected by the mock).
        :param route: API route; must be ``/chat/completions``.
        :param headers: Not inspected by the mock.
        :param retry: Not inspected by the mock.
        :param kwargs: Must contain a ``json`` mapping with a ``messages`` entry.
        :raises ValueError: If ``route`` is anything other than ``/chat/completions``.
        """
        if route != "/chat/completions":
            raise ValueError("only chat completions is mocked in tests")

        # the validation itself is the assertion: every outgoing message must
        # parse as an OpenAIChatMessage
        payload = kwargs["json"]
        message_list_model = RootModel[list[OpenAIChatMessage]]
        message_list_model.model_validate(payload["messages"])

    async def post(self, route: str, **kwargs):
        """Validate the request, then answer with a completion that echoes the last message.

        :param route: API route; must be ``/chat/completions``.
        :param kwargs: Must contain a ``json`` mapping with ``messages`` and ``model`` entries.
        :returns: A chat-completion-shaped dict whose single choice is the last
            message sent, or an assistant message with no content if none were sent.
        :raises ValueError: If ``route`` is anything other than ``/chat/completions``.
        """
        if route != "/chat/completions":
            raise ValueError("only chat completions is mocked in tests")

        # run the payload through the same validation as a raw request
        await self.request("POST", route, **kwargs)

        payload = kwargs["json"]
        sent_messages = payload["messages"]
        if sent_messages:
            echoed = sent_messages[-1]
        else:
            # nothing to echo back, so reply with a content-less assistant message
            echoed = {"role": "assistant", "content": None}

        return {
            "id": "some-id",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": payload["model"],
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
            "choices": [{"message": echoed, "index": 0}],
        }


class MockOpenAIEngine(OpenAIEngine):
    """OpenAI engine variant whose message translation is a plain per-message conversion."""

    @staticmethod
    def translate_messages(messages, cls=OpenAIChatMessage):
        """Convert each message with ``cls.from_chatmessage`` and return the results as a list."""
        # tool call bindings are irrelevant for these tests; only the
        # message-by-message translation matters
        translated = []
        for message in messages:
            translated.append(cls.from_chatmessage(message))
        return translated


# module-level fixtures shared by the tests below: a mocked client (the API key
# is a placeholder string) and an engine wired to that client
client = MockOpenAIClient("sk-fake-api-key")
engine = MockOpenAIEngine(client=client)


# st.builds calls client.create_chat_completion synchronously, generating its
# arguments from the type annotations of that async function. The call yields
# an un-awaited coroutine, which the async test body then awaits.
@settings(deadline=None, suppress_health_check=(HealthCheck.too_slow,))
@given(st.builds(client.create_chat_completion))
async def test_chat_completions_valid(coro):
    """Await a hypothesis-built chat completion call against the mocked client."""
    await coro


def build_kani_state(msgs: list[ChatMessage]):
    """Create a Kani on the mocked engine with ``msgs`` as its chat history.

    The ``msgs`` annotation is kept as-is: the test below passes this function
    to ``st.builds``, which generates arguments from type annotations.
    """
    ai = Kani(engine, chat_history=msgs)
    return ai


@settings(deadline=None, suppress_health_check=(HealthCheck.too_slow,))
@given(st.builds(build_kani_state))
async def test_kani_chatmessages_valid(ai):
    """Request a model completion from a Kani built with a hypothesis-generated chat history."""
    await ai.get_model_completion()

0 comments on commit 2a57e63

Please sign in to comment.