
feat(instructor): add middleware and update package #551

Closed · wants to merge 4 commits
77 changes: 77 additions & 0 deletions examples/middleware/readme.md
@@ -0,0 +1,77 @@
# Middleware in Instructor

Middleware in Instructor allows you to modify the messages sent to the language model before they are processed. This is beneficial because it enables you to perform custom preprocessing, add context, or even implement simple retrieval-augmented generation (RAG) techniques.

Middleware can be defined as a simple function or, when you need stateful variables, as a class. Either form is then registered with the Instructor client using the `with_middleware` method.

## What is Middleware?

Middleware is a way to modify the input or output of a function or method. In the context of language models and AI assistants, middleware allows you to intercept and modify the messages being sent to the model before they are processed.

Some common use cases for middleware include:

- Preprocessing the input messages (e.g. cleaning up text, adding context)
- Implementing retrieval augmented generation by fetching relevant information and appending it to the messages
- Filtering or moderating content
- Logging or monitoring the messages being sent to the model

Middleware functions take in the list of messages, make any desired changes, and return the modified list of messages to be sent to the model.

Instructor makes it easy to define and use middleware. You can create middleware as simple functions using the `@messages_middleware` decorator, or for more complex stateful middleware you can define a class that inherits from `MessageMiddleware` and implements the `__call__` method.

Once defined, middleware is registered with the Instructor client using the `with_middleware()` method. This allows chaining multiple middleware together.
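
For instance, here is a minimal sketch (based on the API introduced in this PR) of a function middleware that trims stray whitespace from message content, chained onto a client. The cleanup logic itself is purely illustrative:

```python
import instructor
import openai


@instructor.messages_middleware
def strip_whitespace(messages):
    # Assumes each message's `content` is a plain string, as in the
    # examples below; multimodal content parts would need extra handling.
    return [{**m, "content": m["content"].strip()} for m in messages]


client = instructor.from_openai(openai.OpenAI()).with_middleware(strip_whitespace)
```

Because `with_middleware()` returns the client, registrations can be chained, and middleware runs in the order it was registered.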

## Simple RAG Example

Middleware can also be used to implement more advanced techniques like retrieval-augmented generation (RAG). RAG involves retrieving relevant information from an external source and using it to augment the input to the language model. This can help provide additional context and improve the quality and accuracy of the generated responses.

To implement a simple RAG middleware, you could define a function or class that takes the input messages, performs a retrieval step to find relevant information, and then appends that information to the messages before sending them to the model. For example:

```python
@instructor.messages_middleware
def add_retrieval_augmentation(messages):
    # Perform retrieval step to find relevant information
    relevant_information = retrieve_relevant_information(messages)

    # Append the relevant information to the messages
    return messages + [{
        "role": "user",
        "content": f"Relevant Information: {relevant_information}"
    }]
```
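
Here, `retrieve_relevant_information` is a placeholder for your own retrieval logic. As a purely illustrative, dependency-free sketch, it could do naive keyword matching over an in-memory list of documents:

```python
DOCUMENTS = [
    "Jason is 20 years old",
    "Instructor extracts structured outputs from language models",
]


def retrieve_relevant_information(messages):
    # Score each document by word overlap with the last message;
    # a real implementation would use embeddings or a search index.
    query_words = set(messages[-1]["content"].lower().split())
    return max(DOCUMENTS, key=lambda doc: len(query_words & set(doc.lower().split())))
```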

## Logging and Monitoring

Another useful application of middleware is logging and monitoring the messages sent to the language model. This can be helpful for debugging, auditing, or analyzing conversations.

To implement logging middleware, you can define a function or class that takes the input messages, logs them to a file or database, and then returns the original messages unmodified. For example:

```python
import logging


@instructor.messages_middleware
def logging_middleware(messages):
    logging.info(f"Input messages: {messages}")

    # Return the original messages unmodified
    return messages
```
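
To persist logs instead of relying on the `logging` configuration, a variant could write each request's messages to a file. A minimal sketch, assuming a local JSONL file (the filename is arbitrary):

```python
import json


@instructor.messages_middleware
def jsonl_logging_middleware(messages):
    # Append this request's messages as one JSON line
    with open("messages.jsonl", "a") as f:
        f.write(json.dumps({"messages": messages}) + "\n")
    return messages
```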

## Stateful Middleware

For more advanced stateful middleware, you can define a class that inherits from `MessageMiddleware` and implements the `__call__` method. This allows you to maintain state across multiple calls to the middleware.

For example, let's say you want to implement a middleware that adds user preferences to the messages. You could define a stateful middleware class like this:

```python
class UserPreferencesMiddleware(MessageMiddleware):
    user_id: str

    def __call__(self, messages):
        # `get_user_preferences` is a placeholder for your own lookup logic
        preferences = get_user_preferences(self.user_id)
        for message in messages:
            # Messages are dicts, so use key access rather than attribute access
            if message["role"] == "system":
                message["content"] += f"\n\nUser Preferences: {preferences}"
        return messages
```
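
Since `MessageMiddleware` is a Pydantic model, fields like `user_id` are set at construction time, and the instance is registered like any other middleware. A usage sketch (the user id is a made-up value, and `get_user_preferences` is a hypothetical lookup you would supply):

```python
client = instructor.from_openai(openai.OpenAI()).with_middleware(
    UserPreferencesMiddleware(user_id="user-123")
)
```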

As these examples show, middleware provides a flexible way to modify and augment the messages sent to the language model. This can be used for a variety of purposes, such as adding relevant information, logging and monitoring conversations, and maintaining stateful interactions.
60 changes: 60 additions & 0 deletions examples/middleware/run.py
@@ -0,0 +1,60 @@
import instructor
import openai
from openai.types.chat import ChatCompletionMessageParam
from typing import List
from pydantic import BaseModel


class PrintLastUserMessage(instructor.MessageMiddleware):
    log: bool = False

    def __call__(
        self, messages: List[ChatCompletionMessageParam]
    ) -> List[ChatCompletionMessageParam]:
        if self.log:
            import pprint

            pprint.pprint({"messages": messages})
        return messages


@instructor.messages_middleware
def dumb_rag(messages):
    # TODO: use RAG to generate a response
    # TODO: add the response to the messages
    return messages + [
        {
            "role": "user",
            "content": "Search retrieved: 'Jason is 20 years old'",
        }
    ]


class User(BaseModel):
    age: int
    name: str


client = (
    instructor.from_openai(openai.OpenAI())
    .with_middleware(dumb_rag)
    .with_middleware(PrintLastUserMessage(log=True))  # can be called directly
)


user = client.chat.completions.create(
    model="gpt-4-turbo-preview",
    messages=[
        {
            "role": "user",
            "content": "How old is jason?",
        }
    ],
    response_model=User,
)

print(user)
# {'messages': [{'content': 'How old is jason?', 'role': 'user'},
#               {'content': "Search retrieved: 'Jason is 20 years old'",
#                'role': 'user'}]}
# {'age': 20, 'name': 'jason'}
8 changes: 8 additions & 0 deletions instructor/__init__.py
@@ -13,9 +13,17 @@
 from .patch import apatch, patch
 from .process_response import handle_parallel_model
 from .client import Instructor, from_openai, from_anthropic, from_litellm
+from .messages_middleware import (
+    MessageMiddleware,
+    AsyncMessageMiddleware,
+    messages_middleware,
+)
 
 __all__ = [
     "Instructor",
+    "MessageMiddleware",
+    "AsyncMessageMiddleware",
+    "messages_middleware",
     "from_openai",
     "from_anthropic",
     "from_litellm",
61 changes: 40 additions & 21 deletions instructor/client.py
@@ -22,7 +22,7 @@
 from typing_extensions import Self
 from pydantic import BaseModel
 from instructor.dsl.partial import Partial
-
+from instructor.messages_middleware import MessageMiddleware
 
 T = TypeVar("T", bound=(BaseModel | Iterable | Partial))
 
@@ -47,6 +47,7 @@ def __init__(
         self.mode = mode
         self.kwargs = kwargs
         self.provider = provider
+        self.message_middleware = []
 
     @property
     def chat(self) -> Self:
@@ -60,6 +61,10 @@ def completions(self) -> Self:
     def messages(self) -> Self:
         return self
 
+    def with_middleware(self, middleware: MessageMiddleware | Callable) -> Self:
+        self.message_middleware.append(middleware)
+        return self
+
     # TODO: we should overload a case where response_model is None
     def create(
         self,
@@ -69,9 +74,7 @@ def create(
         validation_context: dict | None = None,
         **kwargs,
     ) -> T:
-        kwargs = self.handle_kwargs(kwargs)
-
-        return self.create_fn(
+        return self._create(
             response_model=response_model,
             messages=messages,
             max_retries=max_retries,
@@ -90,13 +93,9 @@ def create_partial(
         assert self.provider != Provider.ANTHROPIC, "Anthropic doesn't support partial"
 
         kwargs["stream"] = True
-
-        kwargs = self.handle_kwargs(kwargs)
-
-        response_model = instructor.Partial[response_model]  # type: ignore
-        return self.create_fn(
+        return self._create(
             messages=messages,
-            response_model=response_model,
+            response_model=instructor.Partial[response_model],  # type: ignore
             max_retries=max_retries,
             validation_context=validation_context,
             **kwargs,
@@ -113,12 +112,9 @@ def create_iterable(
         assert self.provider != Provider.ANTHROPIC, "Anthropic doesn't support iterable"
 
         kwargs["stream"] = True
-        kwargs = self.handle_kwargs(kwargs)
-
-        response_model = Iterable[response_model]  # type: ignore
-        return self.create_fn(
+        return self._create(
             messages=messages,
-            response_model=response_model,
+            response_model=Iterable[response_model],
             max_retries=max_retries,
             validation_context=validation_context,
             **kwargs,
@@ -132,8 +128,7 @@ def create_with_completion(
         validation_context: dict | None = None,
         **kwargs,
     ) -> Tuple[T, ChatCompletion | Message]:
-        kwargs = self.handle_kwargs(kwargs)
-        model = self.create_fn(
+        model = self._create(
             messages=messages,
             response_model=response_model,
             max_retries=max_retries,
@@ -148,6 +143,27 @@ def handle_kwargs(self, kwargs: dict):
             kwargs[key] = value
         return kwargs
 
+    def _create(
+        self,
+        messages: List[ChatCompletionMessageParam],
+        response_model: Type[T],
+        max_retries: int = 3,
+        validation_context: dict | None = None,
+        **kwargs,
+    ) -> T:
+        for middleware in self.message_middleware:
+            messages = middleware(messages)
+
+        kwargs = self.handle_kwargs(kwargs)
+
+        return self.create_fn(
+            messages=messages,
+            response_model=response_model,
+            max_retries=max_retries,
+            validation_context=validation_context,
+            **kwargs,
+        )
+
 
 class AsyncInstructor(Instructor):
     client: openai.AsyncOpenAI | anthropic.AsyncAnthropic | None
@@ -374,10 +390,13 @@ def from_anthropic(
     mode: instructor.Mode = instructor.Mode.ANTHROPIC_JSON,
     **kwargs,
 ) -> Instructor | AsyncInstructor:
-    assert mode in {
-        instructor.Mode.ANTHROPIC_JSON,
-        instructor.Mode.ANTHROPIC_TOOLS,
-    }, "Mode be one of {instructor.Mode.ANTHROPIC_JSON, instructor.Mode.ANTHROPIC_TOOLS}"
+    assert (
+        mode
+        in {
+            instructor.Mode.ANTHROPIC_JSON,
+            instructor.Mode.ANTHROPIC_TOOLS,
+        }
+    ), "Mode be one of {instructor.Mode.ANTHROPIC_JSON, instructor.Mode.ANTHROPIC_TOOLS}"
 
     assert isinstance(
         client, (anthropic.Anthropic, anthropic.AsyncAnthropic)
35 changes: 35 additions & 0 deletions instructor/messages_middleware.py
@@ -0,0 +1,35 @@
from typing import List, Callable
from openai.types.chat import ChatCompletionMessageParam
from abc import ABC, abstractmethod
from pydantic import BaseModel


class MessageMiddleware(BaseModel, ABC):
    @abstractmethod
    def __call__(
        self, messages: List[ChatCompletionMessageParam]
    ) -> List[ChatCompletionMessageParam]:
        pass


class AsyncMessageMiddleware(MessageMiddleware):
    @abstractmethod
    async def __call__(
        self, messages: List[ChatCompletionMessageParam]
    ) -> List[ChatCompletionMessageParam]:
        pass


def messages_middleware(func: Callable) -> MessageMiddleware:
    import inspect
Review comment (Contributor): Consider moving the import statement for the `inspect` module to the top of the file. This aligns with Python's best practices for import statements.


if "messages" not in inspect.signature(func).parameters:
raise ValueError("`messages` must be a parameter of the middleware function")

class _Middleware(MessageMiddleware):
def __call__(
self, messages: List[ChatCompletionMessageParam]
) -> List[ChatCompletionMessageParam]:
return func(messages=messages)

return _Middleware()
1 change: 0 additions & 1 deletion tests/llm/test_openai/test_validators.py
@@ -44,7 +44,6 @@
 
 @pytest.mark.parametrize("model", models)
 def test_runmodel_validator_default_openai_client(model, client):
-
     client = instructor.from_openai(client)
 
     class QuestionAnswerNoEvil(BaseModel):