Merge OpenAiStreamingChatModel and OpenAiChatModel classes #971

Status: Draft. Wants to merge 28 commits into main (changes shown from 13 commits).

Commits:
- 008214c Avoid using OpenAiStreamingChatModel class (anunnakian, Apr 18, 2024)
- 12d7e82 Polish (anunnakian, Apr 18, 2024)
- 0b32b58 Avoid using OpenAiStreamingChatModel class (anunnakian, Apr 18, 2024)
- b7c85e3 Polish (anunnakian, Apr 18, 2024)
- 04682e7 Polish (anunnakian, Apr 20, 2024)
- b2fa5fc Merge remote-tracking branch 'origin/remove_OpenAiStreamingChatModel'… (anunnakian, Apr 21, 2024)
- 3091962 Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, Apr 22, 2024)
- d83026b Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, Apr 23, 2024)
- 5d7b790 Update the doc (anunnakian, Apr 23, 2024)
- 9f92054 Merge remote-tracking branch 'origin/remove_OpenAiStreamingChatModel'… (anunnakian, Apr 23, 2024)
- 1c90526 Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, Apr 24, 2024)
- 41c2a82 Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, Apr 25, 2024)
- a317064 Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, Apr 25, 2024)
- a32990f Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, Apr 27, 2024)
- 8dd12fd Deprecate OpenAiStreamingChatModel class (anunnakian, May 1, 2024)
- 0a0ac22 Polish (anunnakian, May 1, 2024)
- 89b2abf Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, May 1, 2024)
- 6ae3eec Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, May 2, 2024)
- adb6064 Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, May 3, 2024)
- bcffa06 Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, May 6, 2024)
- 6ba9487 Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, May 7, 2024)
- 23f7b42 Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, May 14, 2024)
- 61e3f79 Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, May 15, 2024)
- 930558a Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, May 17, 2024)
- 178af50 Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, May 19, 2024)
- 65a09d6 Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, May 23, 2024)
- 2887d35 Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, May 25, 2024)
- ba204fd Merge branch 'main' into remove_OpenAiStreamingChatModel (anunnakian, May 27, 2024)
docs/docs/integrations/language-models/open-ai.md (5 changes: 3 additions & 2 deletions)

@@ -73,14 +73,15 @@ OpenAiChatModel model = OpenAiChatModel.builder()
.logResponses(...)
.tokenizer(...)
.customHeaders(...)
.isStreaming(...)
.build();
```
See the description of some of the parameters above [here](https://platform.openai.com/docs/api-reference/chat/create).

-## OpenAiStreamingChatModel
+## Streaming ChatModel

```java
-OpenAiStreamingChatModel model = OpenAiStreamingChatModel.withApiKey(System.getenv("OPENAI_API_KEY"));
+OpenAiChatModel model = OpenAiChatModel.withApiKey(System.getenv("OPENAI_API_KEY"), true);

model.generate("Say 'Hello World'", new StreamingResponseHandler<AiMessage>() {

    // ...
});
```
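The diff cuts the handler off here. For orientation, a complete streaming call in the new style might look like the following sketch; it is illustrative, not part of the diff, and assumes langchain4j's `StreamingResponseHandler` callbacks (`onNext`, `onComplete`, `onError`) together with the two-argument `withApiKey` overload this PR introduces:

```java
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.output.Response;

public class StreamingHelloWorld {

    public static void main(String[] args) {
        // 'true' selects the streaming configuration (overload introduced by this PR)
        OpenAiChatModel model =
                OpenAiChatModel.withApiKey(System.getenv("OPENAI_API_KEY"), true);

        model.generate("Say 'Hello World'", new StreamingResponseHandler<AiMessage>() {

            @Override
            public void onNext(String token) {
                System.out.print(token); // called once per streamed token
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                System.out.println("\nToken usage: " + response.tokenUsage());
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}
```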
@@ -4,7 +4,7 @@
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
-import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
+import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Test;

@@ -18,13 +18,14 @@
*/
class OllamaOpenAiStreamingChatModelIT extends AbstractOllamaLanguageModelInfrastructure {

-StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder()
+StreamingChatLanguageModel model = OpenAiChatModel.builder()
.apiKey("does not matter") // TODO make apiKey optional when using custom baseUrl?
.baseUrl(ollama.getEndpoint() + "/v1") // TODO add "/v1" by default?
.modelName(TINY_DOLPHIN_MODEL)
.temperature(0.0)
.logRequests(true)
.logResponses(true)
.isStreaming(true)
.build();

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
package dev.langchain4j.model.openai;

import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.chat.ChatCompletionChoice;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.Delta;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.openai.spi.OpenAiChatModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import lombok.Builder;

import java.net.Proxy;
import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
@@ -31,7 +35,7 @@
* Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and gpt-4.
* You can find description of parameters <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
*/
-public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
+public class OpenAiChatModel implements ChatLanguageModel, StreamingChatLanguageModel, TokenCountEstimator {

private final OpenAiClient client;
private final String modelName;
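Since the class now implements both `ChatLanguageModel` and `StreamingChatLanguageModel`, a single configured instance can serve blocking and streaming calls. A minimal sketch of that dual use, assuming the `isStreaming` builder flag from this PR (class name and prompts are illustrative):

```java
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.openai.OpenAiChatModel;

public class MergedModelSketch {

    public static void main(String[] args) {
        OpenAiChatModel model = OpenAiChatModel.builder()
                .apiKey(System.getenv("OPENAI_API_KEY"))
                .isStreaming(true) // flag added in this PR
                .build();

        // Blocking call, via the ChatLanguageModel interface
        String answer = model.generate("Hello");
        System.out.println(answer);

        // Streaming call, via the StreamingChatLanguageModel interface
        model.generate("Hello again", new StreamingResponseHandler<AiMessage>() {

            @Override
            public void onNext(String token) {
                System.out.print(token);
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}
```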
@@ -47,6 +51,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final String user;
private final Integer maxRetries;
private final Tokenizer tokenizer;
private final boolean isOpenAiModel;
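// true for genuine OpenAI model names; when false (e.g. an OpenAI-compatible backend), token usage is stripped from streamed responses (see generate below)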

@Builder
public OpenAiChatModel(String baseUrl,
@@ -69,7 +74,8 @@ public OpenAiChatModel(String baseUrl,
Boolean logRequests,
Boolean logResponses,
Tokenizer tokenizer,
-Map<String, String> customHeaders) {
+Map<String, String> customHeaders,
+boolean isStreaming) {

baseUrl = getOrDefault(baseUrl, OPENAI_URL);
if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
@@ -78,7 +84,7 @@ public OpenAiChatModel(String baseUrl,

timeout = getOrDefault(timeout, ofSeconds(60));

-this.client = OpenAiClient.builder()
+OpenAiClient.Builder openAiClientBuilder = OpenAiClient.builder()
.openAiApiKey(apiKey)
.baseUrl(baseUrl)
.organizationId(organizationId)
@@ -90,9 +96,18 @@ public OpenAiChatModel(String baseUrl,
.logRequests(logRequests)
.logResponses(logResponses)
.userAgent(DEFAULT_USER_AGENT)
-.customHeaders(customHeaders)
-.build();
+.customHeaders(customHeaders);

if (isStreaming) {
    // streaming responses are logged via the streaming-specific setting
    openAiClientBuilder.logStreamingResponses(logResponses);
} else {
    openAiClientBuilder.logResponses(logResponses);
}
this.client = openAiClientBuilder.build();

this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
this.isOpenAiModel = isOpenAiModel(this.modelName); // check the defaulted name, not the raw (possibly null) parameter

this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
this.stop = stop;
@@ -113,22 +128,22 @@ public String modelName() {

@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
-return generate(messages, null, null);
+return generateMessage(messages, null, null);
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
-return generate(messages, toolSpecifications, null);
+return generateMessage(messages, toolSpecifications, null);
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
-return generate(messages, singletonList(toolSpecification), toolSpecification);
+return generateMessage(messages, singletonList(toolSpecification), toolSpecification);
}

-private Response<AiMessage> generate(List<ChatMessage> messages,
-                                     List<ToolSpecification> toolSpecifications,
-                                     ToolSpecification toolThatMustBeExecuted
+private Response<AiMessage> generateMessage(List<ChatMessage> messages,
+                                            List<ToolSpecification> toolSpecifications,
+                                            ToolSpecification toolThatMustBeExecuted
) {
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.model(modelName)
@@ -168,7 +183,11 @@ public int estimateTokenCount(List<ChatMessage> messages) {
}

public static OpenAiChatModel withApiKey(String apiKey) {
-return builder().apiKey(apiKey).build();
+return withApiKey(apiKey, false);
}

public static OpenAiChatModel withApiKey(String apiKey, boolean isStreaming) {
return builder().apiKey(apiKey).isStreaming(isStreaming).build();
}

public static OpenAiChatModelBuilder builder() {
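As a usage note (a sketch, not part of the diff), the two factory overloads then select the mode explicitly:

```java
// Blocking configuration; delegates to withApiKey(apiKey, false)
OpenAiChatModel blocking = OpenAiChatModel.withApiKey(System.getenv("OPENAI_API_KEY"));

// Streaming configuration
OpenAiChatModel streaming = OpenAiChatModel.withApiKey(System.getenv("OPENAI_API_KEY"), true);
```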
@@ -195,4 +214,92 @@ public OpenAiChatModelBuilder modelName(OpenAiChatModelName modelName) {
return this;
}
}
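// StreamingChatLanguageModel implementation; this logic previously lived in OpenAiStreamingChatModel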

@Override
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
generate(messages, null, null, handler);
}

@Override
public void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler<AiMessage> handler) {
generate(messages, toolSpecifications, null, handler);
}

@Override
public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
generate(messages, null, toolSpecification, handler);
}

private void generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted,
StreamingResponseHandler<AiMessage> handler
) {
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.stream(true)
.model(modelName)
.messages(toOpenAiMessages(messages))
.temperature(temperature)
.topP(topP)
.stop(stop)
.maxTokens(maxTokens)
.presencePenalty(presencePenalty)
.frequencyPenalty(frequencyPenalty)
.logitBias(logitBias)
.responseFormat(responseFormat)
.seed(seed)
.user(user);

if (toolThatMustBeExecuted != null) {
requestBuilder.tools(toTools(singletonList(toolThatMustBeExecuted)));
requestBuilder.toolChoice(toolThatMustBeExecuted.name());
} else if (!isNullOrEmpty(toolSpecifications)) {
requestBuilder.tools(toTools(toolSpecifications));
}

ChatCompletionRequest request = requestBuilder.build();

int inputTokenCount = countInputTokens(messages, toolSpecifications, toolThatMustBeExecuted);
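// the response builder accumulates streamed chunks into a complete Response<AiMessage>, using the locally estimated prompt token count for usage stats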
OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder(inputTokenCount);

client.chatCompletion(request)
.onPartialResponse(partialResponse -> {
responseBuilder.append(partialResponse);
handle(partialResponse, handler);
})
.onComplete(() -> {
Response<AiMessage> response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null);
if (!isOpenAiModel) {
response = removeTokenUsage(response);
}
handler.onComplete(response);
})
.onError(handler::onError)
.execute();
}

private int countInputTokens(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted) {
int inputTokenCount = tokenizer.estimateTokenCountInMessages(messages);
if (toolThatMustBeExecuted != null) {
inputTokenCount += tokenizer.estimateTokenCountInForcefulToolSpecification(toolThatMustBeExecuted);
} else if (!isNullOrEmpty(toolSpecifications)) {
inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
}
return inputTokenCount;
}

private static void handle(ChatCompletionResponse partialResponse,
StreamingResponseHandler<AiMessage> handler) {
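// forward only content deltas to the handler here; tool-call deltas are accumulated by the response builder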
List<ChatCompletionChoice> choices = partialResponse.choices();
if (choices == null || choices.isEmpty()) {
return;
}
Delta delta = choices.get(0).delta();
String content = delta.content();
if (content != null) {
handler.onNext(content);
}
}
}