Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature #1005 - Add streaming API for Bedrock Anthropics #1006

Merged
merged 7 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ See example [here](https://github.com/langchain4j/langchain4j-examples/blob/main
| [OpenAI](https://docs.langchain4j.dev/integrations/language-models/open-ai) | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [Azure OpenAI](https://docs.langchain4j.dev/integrations/language-models/azure-open-ai) | | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [Hugging Face](https://docs.langchain4j.dev/integrations/language-models/hugging-face) | | ✅ | | ✅ | | | | |
| [Amazon Bedrock](https://docs.langchain4j.dev/integrations/language-models/amazon-bedrock) | | ✅ | | ✅ | ✅ | | |
| [Amazon Bedrock](https://docs.langchain4j.dev/integrations/language-models/amazon-bedrock) | | ✅ | | ✅ | ✅ | | |
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

| [Google Vertex AI Gemini](https://docs.langchain4j.dev/integrations/language-models/google-gemini) | | ✅ | ✅ | | ✅ | | ✅ |
| [Google Vertex AI](https://docs.langchain4j.dev/integrations/language-models/google-palm) | ✅ | ✅ | | ✅ | ✅ | | |
| [Mistral AI](https://docs.langchain4j.dev/integrations/language-models/mistral-ai) | | ✅ | ✅ | ✅ | | | ✅ |
Expand Down
8 changes: 8 additions & 0 deletions langchain4j-bedrock/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>

</dependencies>

<licenses>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package dev.langchain4j.model.bedrock;

import dev.langchain4j.model.bedrock.internal.AbstractBedrockStreamingChatModel;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.SuperBuilder;

@Getter
@SuperBuilder
public class BedrockAnthropicStreamingChatModel extends AbstractBedrockStreamingChatModel {
@Builder.Default
private final String model = BedrockAnthropicStreamingChatModel.Types.AnthropicClaudeV2.getValue();
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: should it use one of v3 models (e.g. haiku) by default?


@Override
protected String getModelId() {
return model;
}

@Getter
/**
* Bedrock Anthropic model ids
*/
public enum Types {
AnthropicClaudeV2("anthropic.claude-v2"),
AnthropicClaudeV2_1("anthropic.claude-v2:1");

private final String value;

Types(String modelID) {
this.value = modelID;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
Expand All @@ -30,47 +31,14 @@
*/
@Getter
@SuperBuilder
public abstract class AbstractBedrockChatModel<T extends BedrockChatModelResponse> implements ChatLanguageModel {
private static final String HUMAN_PROMPT = "Human:";
private static final String ASSISTANT_PROMPT = "Assistant:";

@Builder.Default
private final String humanPrompt = HUMAN_PROMPT;
@Builder.Default
private final String assistantPrompt = ASSISTANT_PROMPT;
@Builder.Default
private final Integer maxRetries = 5;
@Builder.Default
private final Region region = Region.US_EAST_1;
@Builder.Default
private final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build();
@Builder.Default
private final int maxTokens = 300;
@Builder.Default
private final float temperature = 1;
@Builder.Default
private final float topP = 0.999f;
@Builder.Default
private final String[] stopSequences = new String[]{};
public abstract class AbstractBedrockChatModel<T extends BedrockChatModelResponse> extends AbstractSharedBedrockChatModel implements ChatLanguageModel {
@Getter(lazy = true)
private final BedrockRuntimeClient client = initClient();

@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {

final String context = messages.stream()
.filter(message -> message.type() == ChatMessageType.SYSTEM)
.map(ChatMessage::text)
.collect(joining("\n"));

final String userMessages = messages.stream()
.filter(message -> message.type() != ChatMessageType.SYSTEM)
.map(this::chatMessageToString)
.collect(joining("\n"));

final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT);
final Map<String, Object> requestParameters = getRequestParameters(prompt);
final String body = Json.toJson(requestParameters);
final String body = convertMessagesToAwsBody(messages);

InvokeModelResponse invokeModelResponse = withRetry(() -> invoke(body), maxRetries);
final String response = invokeModelResponse.body().asUtf8String();
Expand All @@ -81,26 +49,6 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {
result.getFinishReason());
}

/**
* Convert chat message to string
*
* @param message chat message
* @return string
*/
protected String chatMessageToString(ChatMessage message) {
switch (message.type()) {
case SYSTEM:
return message.text();
case USER:
return humanPrompt + " " + message.text();
case AI:
return assistantPrompt + " " + message.text();
case TOOL_EXECUTION_RESULT:
throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
}

throw new IllegalArgumentException("Unknown message type: " + message.type());
}

/**
* Get request parameters
Expand All @@ -110,13 +58,6 @@ protected String chatMessageToString(ChatMessage message) {
*/
protected abstract Map<String, Object> getRequestParameters(final String prompt);

/**
* Get model id
*
* @return model id
*/
protected abstract String getModelId();


/**
* Get response class type
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package dev.langchain4j.model.bedrock.internal;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Json;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import lombok.Getter;
import lombok.experimental.SuperBuilder;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Base class for Bedrock streaming chat models.
 * <p>
 * Builds the request body via {@link AbstractSharedBedrockChatModel#convertMessagesToAwsBody},
 * streams chunks back through the supplied {@link StreamingResponseHandler}, and blocks the
 * calling thread until the stream completes.
 */
@Getter
@SuperBuilder
public abstract class AbstractBedrockStreamingChatModel extends AbstractSharedBedrockChatModel implements StreamingChatLanguageModel {

    @Getter
    private final BedrockRuntimeAsyncClient asyncClient = initAsyncClient();

    /**
     * JSON shape of one streamed completion chunk from Bedrock Anthropic models.
     * Declared static: JSON deserializers cannot reliably instantiate a non-static
     * inner class, which needs a hidden reference to the enclosing instance.
     */
    static class StreamingResponse {
        public String completion;
    }

    @Override
    public void generate(String userMessage, StreamingResponseHandler<AiMessage> handler) {
        List<ChatMessage> messages = new ArrayList<>();
        messages.add(new UserMessage(userMessage));
        generate(messages, handler);
    }

    @Override
    public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
        InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder()
                .body(SdkBytes.fromUtf8String(convertMessagesToAwsBody(messages)))
                .modelId(getModelId())
                .contentType("application/json")
                .accept("application/json")
                .build();

        // StringBuffer (synchronized): chunk callbacks may arrive on SDK worker threads.
        StringBuffer finalCompletion = new StringBuffer();

        InvokeModelWithResponseStreamResponseHandler.Visitor visitor = InvokeModelWithResponseStreamResponseHandler.Visitor.builder()
                .onChunk(chunk -> {
                    StreamingResponse sr = Json.fromJson(chunk.bytes().asUtf8String(), StreamingResponse.class);
                    finalCompletion.append(sr.completion);
                    handler.onNext(sr.completion);
                })
                .build();

        InvokeModelWithResponseStreamResponseHandler h = InvokeModelWithResponseStreamResponseHandler.builder()
                .onEventStream(stream -> stream.subscribe(event -> event.accept(visitor)))
                .onComplete(() -> {
                    handler.onComplete(Response.from(new AiMessage(finalCompletion.toString())));
                })
                .onError(handler::onError)
                .build();

        // Block until the stream finishes so completion/errors are observed before returning.
        asyncClient.invokeModelWithResponseStream(request, h).join();
    }

    /**
     * Initialize the async Bedrock runtime client with the configured region and credentials.
     *
     * @return async bedrock client
     */
    private BedrockRuntimeAsyncClient initAsyncClient() {
        return BedrockRuntimeAsyncClient.builder()
                .region(region)
                .credentialsProvider(credentialsProvider)
                .build();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package dev.langchain4j.model.bedrock.internal;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.internal.Json;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.SuperBuilder;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static java.util.stream.Collectors.joining;

@Getter
@SuperBuilder
public abstract class AbstractSharedBedrockChatModel {
    // Claude requires the prompt to be framed as:
    // "Human: " + prompt + "\n\nAssistant:"
    protected static final String HUMAN_PROMPT = "Human:";
    protected static final String ASSISTANT_PROMPT = "Assistant:";
    protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31";

    @Builder.Default
    protected final String humanPrompt = HUMAN_PROMPT;
    @Builder.Default
    protected final String assistantPrompt = ASSISTANT_PROMPT;
    @Builder.Default
    protected final Integer maxRetries = 5;
    @Builder.Default
    protected final Region region = Region.US_EAST_1;
    @Builder.Default
    protected final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build();
    @Builder.Default
    protected final int maxTokens = 300;
    @Builder.Default
    protected final double temperature = 1;
    @Builder.Default
    protected final float topP = 0.999f;
    @Builder.Default
    protected final String[] stopSequences = new String[]{};
    @Builder.Default
    protected final int topK = 250;
    @Builder.Default
    protected final String anthropicVersion = DEFAULT_ANTHROPIC_VERSION;


    /**
     * Render a single chat message as one prompt line, prefixed per its role.
     *
     * @param message chat message
     * @return string
     */
    protected String chatMessageToString(ChatMessage message) {
        ChatMessageType type = message.type();
        if (type == ChatMessageType.SYSTEM) {
            return message.text();
        }
        if (type == ChatMessageType.USER) {
            return humanPrompt + " " + message.text();
        }
        if (type == ChatMessageType.AI) {
            return assistantPrompt + " " + message.text();
        }
        if (type == ChatMessageType.TOOL_EXECUTION_RESULT) {
            throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
        }
        throw new IllegalArgumentException("Unknown message type: " + type);
    }

    /**
     * Assemble the JSON request body: system messages first, then the conversation,
     * closed with the assistant prompt so the model continues as the assistant.
     *
     * @param messages chat messages
     * @return JSON body for the InvokeModel request
     */
    protected String convertMessagesToAwsBody(List<ChatMessage> messages) {
        StringBuilder systemContext = new StringBuilder();
        StringBuilder conversation = new StringBuilder();

        // Single pass, joining with "\n" exactly as Collectors.joining would.
        for (ChatMessage message : messages) {
            if (message.type() == ChatMessageType.SYSTEM) {
                if (systemContext.length() > 0) {
                    systemContext.append("\n");
                }
                systemContext.append(message.text());
            } else {
                if (conversation.length() > 0) {
                    conversation.append("\n");
                }
                conversation.append(chatMessageToString(message));
            }
        }

        final String prompt = String.format("%s\n\n%s\n%s", systemContext, conversation, ASSISTANT_PROMPT);
        return Json.toJson(getRequestParameters(prompt));
    }

    /**
     * Build the Anthropic request parameter map for the given prompt.
     *
     * @param prompt fully assembled prompt text
     * @return request parameters
     */
    protected Map<String, Object> getRequestParameters(String prompt) {
        final Map<String, Object> params = new HashMap<>(7);
        params.put("prompt", prompt);
        params.put("max_tokens_to_sample", getMaxTokens());
        params.put("temperature", getTemperature());
        params.put("top_k", topK);
        params.put("top_p", getTopP());
        params.put("stop_sequences", getStopSequences());
        params.put("anthropic_version", anthropicVersion);
        return params;
    }

    /**
     * Get model id
     *
     * @return model id
     */
    protected abstract String getModelId();

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package dev.langchain4j.model.bedrock;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.regions.Region;

import static dev.langchain4j.data.message.UserMessage.userMessage;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;

public class BedrockStreamingChatModelIT {

    @Test
    @Disabled("To run this test, you must have provide your own access key, secret, region")
    void testBedrockAnthropicStreamingChatModel() {
        // given: a streaming model configured for us-east-1 with a single retry
        BedrockAnthropicStreamingChatModel model = BedrockAnthropicStreamingChatModel.builder()
                .region(Region.US_EAST_1)
                .maxRetries(1)
                .maxTokens(300)
                .temperature(0.5)
                .build();

        // when: streaming a single user question and collecting the full response
        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
        model.generate(singletonList(userMessage("What's the capital of Poland?")), handler);
        Response<AiMessage> response = handler.get();

        // then: the aggregated completion mentions the expected answer
        assertThat(response.content().text()).contains("Warsaw");
    }
}