Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge OpenAiStreamingChatModel and OpenAiChatModel classes #971

Draft
wants to merge 28 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
008214c
Avoid using OpenAiStreamingChatModel class
anunnakian Apr 18, 2024
12d7e82
Polish
anunnakian Apr 18, 2024
0b32b58
Avoid using OpenAiStreamingChatModel class
anunnakian Apr 18, 2024
b7c85e3
Polish
anunnakian Apr 18, 2024
04682e7
Polish
anunnakian Apr 20, 2024
b2fa5fc
Merge remote-tracking branch 'origin/remove_OpenAiStreamingChatModel'…
anunnakian Apr 21, 2024
3091962
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian Apr 22, 2024
d83026b
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian Apr 23, 2024
5d7b790
Update the doc
anunnakian Apr 23, 2024
9f92054
Merge remote-tracking branch 'origin/remove_OpenAiStreamingChatModel'…
anunnakian Apr 23, 2024
1c90526
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian Apr 24, 2024
41c2a82
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian Apr 25, 2024
a317064
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian Apr 25, 2024
a32990f
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian Apr 27, 2024
8dd12fd
Deprecate OpenAiStreamingChatModel class
anunnakian May 1, 2024
0a0ac22
Polish
anunnakian May 1, 2024
89b2abf
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian May 1, 2024
6ae3eec
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian May 2, 2024
adb6064
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian May 3, 2024
bcffa06
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian May 6, 2024
6ba9487
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian May 7, 2024
23f7b42
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian May 14, 2024
61e3f79
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian May 15, 2024
930558a
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian May 17, 2024
178af50
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian May 19, 2024
65a09d6
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian May 23, 2024
2887d35
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian May 25, 2024
ba204fd
Merge branch 'main' into remove_OpenAiStreamingChatModel
anunnakian May 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/docs/integrations/language-models/open-ai.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,17 @@ OpenAiChatModel model = OpenAiChatModel.builder()
.proxy(...)
.logRequests(...)
.logResponses(...)
.logStreamingResponses(...)
.tokenizer(...)
.customHeaders(...)
.build();
```
See the description of some of the parameters above [here](https://platform.openai.com/docs/api-reference/chat/create).

## OpenAiStreamingChatModel
## Streaming ChatModel

```java
OpenAiStreamingChatModel model = OpenAiStreamingChatModel.withApiKey(System.getenv("OPENAI_API_KEY"));
OpenAiChatModel model = OpenAiChatModel.withApiKey(System.getenv("OPENAI_API_KEY"), true);

model.generate("Say 'Hello World'", new StreamingResponseHandler<AiMessage>() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Test;

Expand All @@ -18,7 +18,7 @@
*/
class OllamaOpenAiStreamingChatModelIT extends AbstractOllamaLanguageModelInfrastructure {

StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder()
StreamingChatLanguageModel model = OpenAiChatModel.builder()
.apiKey("does not matter") // TODO make apiKey optional when using custom baseUrl?
.baseUrl(ollama.getEndpoint() + "/v1") // TODO add "/v1" by default?
.modelName(TINY_DOLPHIN_MODEL)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package dev.langchain4j.model.openai;

import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.chat.ChatCompletionChoice;
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.Delta;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
Expand All @@ -26,6 +30,7 @@

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
Expand All @@ -38,7 +43,7 @@
* You can find description of parameters <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
*/
@Slf4j
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
public class OpenAiChatModel implements ChatLanguageModel, StreamingChatLanguageModel, TokenCountEstimator {

private final OpenAiClient client;
private final String modelName;
Expand All @@ -54,6 +59,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final String user;
private final Integer maxRetries;
private final Tokenizer tokenizer;
private final boolean isOpenAiModel;
private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners;

@Builder
Expand Down Expand Up @@ -98,10 +104,14 @@ public OpenAiChatModel(String baseUrl,
.proxy(proxy)
.logRequests(logRequests)
.logResponses(logResponses)
.logStreamingResponses(logResponses)
.userAgent(DEFAULT_USER_AGENT)
.customHeaders(customHeaders)
.build();

this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
this.isOpenAiModel = isOpenAiModel(modelName);

this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
this.stop = stop;
Expand All @@ -123,22 +133,22 @@ public String modelName() {

@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
return generate(messages, null, null);
return generateMessage(messages, null, null);
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
return generate(messages, toolSpecifications, null);
return generateMessage(messages, toolSpecifications, null);
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generate(messages, singletonList(toolSpecification), toolSpecification);
return generateMessage(messages, singletonList(toolSpecification), toolSpecification);
}

private Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted
private Response<AiMessage> generateMessage(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted
) {
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.model(modelName)
Expand Down Expand Up @@ -248,4 +258,92 @@ public OpenAiChatModelBuilder modelName(OpenAiChatModelName modelName) {
return this;
}
}

/**
 * Streams a chat completion for the given messages without any tools.
 *
 * @param messages the conversation so far
 * @param handler  receives each partial response and the final aggregated response
 */
@Override
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
    this.generate(messages, null, null, handler);
}

/**
 * Streams a chat completion, letting the model choose among the given tools.
 *
 * @param messages           the conversation so far
 * @param toolSpecifications tools the model may decide to call
 * @param handler            receives each partial response and the final aggregated response
 */
@Override
public void generate(List<ChatMessage> messages,
                     List<ToolSpecification> toolSpecifications,
                     StreamingResponseHandler<AiMessage> handler) {
    this.generate(messages, toolSpecifications, null, handler);
}

/**
 * Streams a chat completion while forcing the model to call the given tool.
 *
 * @param messages          the conversation so far
 * @param toolSpecification the tool the model must call
 * @param handler           receives each partial response and the final aggregated response
 */
@Override
public void generate(List<ChatMessage> messages,
                     ToolSpecification toolSpecification,
                     StreamingResponseHandler<AiMessage> handler) {
    this.generate(messages, null, toolSpecification, handler);
}

/**
 * Sends a streaming chat completion request and forwards results to the handler.
 *
 * @param messages               the conversation so far
 * @param toolSpecifications     tools the model may call; ignored when toolThatMustBeExecuted is set
 * @param toolThatMustBeExecuted a tool the model is forced to call, or null
 * @param handler                receives partial responses, the final response, and errors
 */
private void generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted,
StreamingResponseHandler<AiMessage> handler
) {
// stream(true) is what distinguishes this request from the blocking generateMessage(...) path
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.stream(true)
.model(modelName)
.messages(toOpenAiMessages(messages))
.temperature(temperature)
.topP(topP)
.stop(stop)
.maxTokens(maxTokens)
.presencePenalty(presencePenalty)
.frequencyPenalty(frequencyPenalty)
.logitBias(logitBias)
.responseFormat(responseFormat)
.seed(seed)
.user(user);

if (toolThatMustBeExecuted != null) {
// force the model to call exactly this one tool
requestBuilder.tools(toTools(singletonList(toolThatMustBeExecuted)));
requestBuilder.toolChoice(toolThatMustBeExecuted.name());
} else if (!isNullOrEmpty(toolSpecifications)) {
requestBuilder.tools(toTools(toolSpecifications));
}

ChatCompletionRequest request = requestBuilder.build();

// input token count is estimated locally so it can be attached to the streamed response
int inputTokenCount = countInputTokens(messages, toolSpecifications, toolThatMustBeExecuted);
OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(inputTokenCount);

client.chatCompletion(request)
.onPartialResponse(partialResponse -> {
// accumulate the full response while also forwarding each content token to the handler
responseBuilder.append(partialResponse);
handle(partialResponse, handler);
})
.onComplete(() -> {
Response<AiMessage> response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null);
if (!isOpenAiModel) {
// NOTE(review): presumably non-OpenAI backends served via an OpenAI-compatible
// endpoint report unreliable usage data — confirm against removeTokenUsage(...)
response = removeTokenUsage(response);
}
handler.onComplete(response);
})
.onError(handler::onError)
.execute();
}

/**
 * Estimates the token count of the request input: the messages plus,
 * if present, either the forced tool or the offered tool specifications.
 */
private int countInputTokens(List<ChatMessage> messages,
                             List<ToolSpecification> toolSpecifications,
                             ToolSpecification toolThatMustBeExecuted) {
    int messageTokens = tokenizer.estimateTokenCountInMessages(messages);
    if (toolThatMustBeExecuted != null) {
        return messageTokens + tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
    }
    if (!isNullOrEmpty(toolSpecifications)) {
        return messageTokens + tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
    }
    return messageTokens;
}

/**
 * Forwards the textual content of one partial chat completion chunk to the handler.
 * Chunks with no choices, or whose first choice carries no textual content, are ignored.
 */
private static void handle(ChatCompletionResponse partialResponse,
                           StreamingResponseHandler<AiMessage> handler) {
    List<ChatCompletionChoice> allChoices = partialResponse.choices();
    if (allChoices != null && !allChoices.isEmpty()) {
        String token = allChoices.get(0).delta().content();
        if (token != null) {
            handler.onNext(token);
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,12 @@
* Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and gpt-4.
* The model's response is streamed token by token and should be handled with {@link StreamingResponseHandler}.
* You can find description of parameters <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
*
*
* @deprecated use {@link OpenAiChatModel} instead.
*/
@Slf4j
@Deprecated()
public class OpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {

private final OpenAiClient client;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
/**
 * A factory for building {@link OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder} instances.
 *
 * @deprecated use {@link OpenAiChatModel}, which now also implements the streaming interface, instead
 */
@Deprecated
public interface OpenAiStreamingChatModelBuilderFactory extends Supplier<OpenAiStreamingChatModel.OpenAiStreamingChatModelBuilder> {
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.azure.AzureOpenAiStreamingChatModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.mistralai.MistralAiStreamingChatModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;
Expand All @@ -33,7 +33,7 @@ public class StreamingAiServicesIT {

static Stream<StreamingChatLanguageModel> models() {
return Stream.of(
OpenAiStreamingChatModel.builder()
OpenAiChatModel.builder()
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
Expand Down Expand Up @@ -238,7 +238,7 @@ void should_execute_a_tool_then_stream_answer(StreamingChatLanguageModel model)
void should_execute_multiple_tools_sequentially_then_answer() throws Exception {

// TODO test more models
StreamingChatLanguageModel streamingChatModel = OpenAiStreamingChatModel.builder()
StreamingChatLanguageModel streamingChatModel = OpenAiChatModel.builder()
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
Expand Down Expand Up @@ -337,7 +337,7 @@ void should_execute_multiple_tools_in_parallel_then_answer() throws Exception {
Calculator calculator = spy(new Calculator());

// TODO test more models
StreamingChatLanguageModel streamingChatModel = OpenAiStreamingChatModel.builder()
StreamingChatLanguageModel streamingChatModel = OpenAiChatModel.builder()
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.mistralai.MistralAiStreamingChatModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
Expand All @@ -32,7 +32,7 @@ class StreamingAiServicesWithToolsIT {

static Stream<StreamingChatLanguageModel> models() {
return Stream.of(
OpenAiStreamingChatModel.builder()
OpenAiChatModel.builder()
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.apiKey(System.getenv("OPENAI_API_KEY"))
.organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
Expand Down