-
Notifications
You must be signed in to change notification settings - Fork 691
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature #1005 - Add streaming API for Bedrock Anthropics #1006
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
e4439a8
Feature #1005 - Add streaming API for Bedrock Anthropics
michalkozminski fb41bfd
Merge branch 'main' into feature-1005
michalkozminski d0a1b20
Merge branch 'main' into feature-1005
michalkozminski 60824da
Merge branch 'main' into feature-1005
michalkozminski d24ecfa
Merge branch 'main' into feature-1005
michalkozminski 060e8c1
Merge branch 'main' into feature-1005
michalkozminski a90f8a7
Merge branch 'main' into feature-1005
michalkozminski File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
33 changes: 33 additions & 0 deletions
33
...drock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package dev.langchain4j.model.bedrock; | ||
|
||
import dev.langchain4j.model.bedrock.internal.AbstractBedrockStreamingChatModel; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.experimental.SuperBuilder; | ||
|
||
@Getter | ||
@SuperBuilder | ||
public class BedrockAnthropicStreamingChatModel extends AbstractBedrockStreamingChatModel { | ||
@Builder.Default | ||
private final String model = BedrockAnthropicStreamingChatModel.Types.AnthropicClaudeV2.getValue(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. minor: should it use one of v3 models (e.g. haiku) by default? |
||
|
||
@Override | ||
protected String getModelId() { | ||
return model; | ||
} | ||
|
||
@Getter | ||
/** | ||
* Bedrock Anthropic model ids | ||
*/ | ||
public enum Types { | ||
AnthropicClaudeV2("anthropic.claude-v2"), | ||
AnthropicClaudeV2_1("anthropic.claude-v2:1"); | ||
|
||
private final String value; | ||
|
||
Types(String modelID) { | ||
this.value = modelID; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
87 changes: 87 additions & 0 deletions
87
...c/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
package dev.langchain4j.model.bedrock.internal; | ||
|
||
import dev.langchain4j.agent.tool.ToolSpecification; | ||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.UserMessage; | ||
import dev.langchain4j.internal.Json; | ||
import dev.langchain4j.model.StreamingResponseHandler; | ||
import dev.langchain4j.model.chat.StreamingChatLanguageModel; | ||
import dev.langchain4j.model.output.Response; | ||
import lombok.Getter; | ||
import lombok.experimental.SuperBuilder; | ||
import software.amazon.awssdk.core.SdkBytes; | ||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; | ||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest; | ||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.concurrent.atomic.AtomicReference; | ||
|
||
/** | ||
* Bedrock Streaming chat model | ||
*/ | ||
@Getter | ||
@SuperBuilder | ||
public abstract class AbstractBedrockStreamingChatModel extends AbstractSharedBedrockChatModel implements StreamingChatLanguageModel { | ||
@Getter | ||
private final BedrockRuntimeAsyncClient asyncClient = initAsyncClient(); | ||
|
||
class StreamingResponse { | ||
public String completion; | ||
} | ||
|
||
@Override | ||
public void generate(String userMessage, StreamingResponseHandler<AiMessage> handler) { | ||
List<ChatMessage> messages = new ArrayList<>(); | ||
messages.add(new UserMessage(userMessage)); | ||
generate(messages, handler); | ||
} | ||
|
||
@Override | ||
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) { | ||
InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder() | ||
.body(SdkBytes.fromUtf8String(convertMessagesToAwsBody(messages))) | ||
.modelId(getModelId()) | ||
.contentType("application/json") | ||
.accept("application/json") | ||
.build(); | ||
|
||
StringBuffer finalCompletion = new StringBuffer(); | ||
|
||
InvokeModelWithResponseStreamResponseHandler.Visitor visitor = InvokeModelWithResponseStreamResponseHandler.Visitor.builder() | ||
.onChunk(chunk -> { | ||
StreamingResponse sr = Json.fromJson(chunk.bytes().asUtf8String(), StreamingResponse.class); | ||
finalCompletion.append(sr.completion); | ||
handler.onNext(sr.completion); | ||
}) | ||
.build(); | ||
|
||
InvokeModelWithResponseStreamResponseHandler h = InvokeModelWithResponseStreamResponseHandler.builder() | ||
.onEventStream(stream -> stream.subscribe(event -> event.accept(visitor))) | ||
.onComplete(() -> { | ||
handler.onComplete(Response.from(new AiMessage(finalCompletion.toString()))); | ||
}) | ||
.onError(handler::onError) | ||
.build(); | ||
asyncClient.invokeModelWithResponseStream(request, h).join(); | ||
|
||
} | ||
|
||
/** | ||
* Initialize async bedrock client | ||
* | ||
* @return async bedrock client | ||
*/ | ||
private BedrockRuntimeAsyncClient initAsyncClient() { | ||
BedrockRuntimeAsyncClient client = BedrockRuntimeAsyncClient.builder() | ||
.region(region) | ||
.credentialsProvider(credentialsProvider) | ||
.build(); | ||
return client; | ||
} | ||
|
||
|
||
|
||
} |
112 changes: 112 additions & 0 deletions
112
.../src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
package dev.langchain4j.model.bedrock.internal; | ||
|
||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.ChatMessageType; | ||
import dev.langchain4j.internal.Json; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.experimental.SuperBuilder; | ||
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; | ||
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; | ||
import software.amazon.awssdk.regions.Region; | ||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; | ||
|
||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static java.util.stream.Collectors.joining; | ||
|
||
@Getter | ||
@SuperBuilder | ||
public abstract class AbstractSharedBedrockChatModel { | ||
// Claude requires you to enclose the prompt as follows: | ||
// String enclosedPrompt = "Human: " + prompt + "\n\nAssistant:"; | ||
protected static final String HUMAN_PROMPT = "Human:"; | ||
protected static final String ASSISTANT_PROMPT = "Assistant:"; | ||
protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31"; | ||
|
||
@Builder.Default | ||
protected final String humanPrompt = HUMAN_PROMPT; | ||
@Builder.Default | ||
protected final String assistantPrompt = ASSISTANT_PROMPT; | ||
@Builder.Default | ||
protected final Integer maxRetries = 5; | ||
@Builder.Default | ||
protected final Region region = Region.US_EAST_1; | ||
@Builder.Default | ||
protected final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build(); | ||
@Builder.Default | ||
protected final int maxTokens = 300; | ||
@Builder.Default | ||
protected final double temperature = 1; | ||
@Builder.Default | ||
protected final float topP = 0.999f; | ||
@Builder.Default | ||
protected final String[] stopSequences = new String[]{}; | ||
@Builder.Default | ||
protected final int topK = 250; | ||
@Builder.Default | ||
protected final String anthropicVersion = DEFAULT_ANTHROPIC_VERSION; | ||
|
||
|
||
/** | ||
* Convert chat message to string | ||
* | ||
* @param message chat message | ||
* @return string | ||
*/ | ||
protected String chatMessageToString(ChatMessage message) { | ||
switch (message.type()) { | ||
case SYSTEM: | ||
return message.text(); | ||
case USER: | ||
return humanPrompt + " " + message.text(); | ||
case AI: | ||
return assistantPrompt + " " + message.text(); | ||
case TOOL_EXECUTION_RESULT: | ||
throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models"); | ||
} | ||
|
||
throw new IllegalArgumentException("Unknown message type: " + message.type()); | ||
} | ||
|
||
protected String convertMessagesToAwsBody(List<ChatMessage> messages) { | ||
final String context = messages.stream() | ||
.filter(message -> message.type() == ChatMessageType.SYSTEM) | ||
.map(ChatMessage::text) | ||
.collect(joining("\n")); | ||
|
||
final String userMessages = messages.stream() | ||
.filter(message -> message.type() != ChatMessageType.SYSTEM) | ||
.map(this::chatMessageToString) | ||
.collect(joining("\n")); | ||
|
||
final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT); | ||
final Map<String, Object> requestParameters = getRequestParameters(prompt); | ||
final String body = Json.toJson(requestParameters); | ||
return body; | ||
} | ||
|
||
protected Map<String, Object> getRequestParameters(String prompt) { | ||
final Map<String, Object> parameters = new HashMap<>(7); | ||
|
||
parameters.put("prompt", prompt); | ||
parameters.put("max_tokens_to_sample", getMaxTokens()); | ||
parameters.put("temperature", getTemperature()); | ||
parameters.put("top_k", topK); | ||
parameters.put("top_p", getTopP()); | ||
parameters.put("stop_sequences", getStopSequences()); | ||
parameters.put("anthropic_version", anthropicVersion); | ||
|
||
return parameters; | ||
} | ||
|
||
/** | ||
* Get model id | ||
* | ||
* @return model id | ||
*/ | ||
protected abstract String getModelId(); | ||
|
||
} |
39 changes: 39 additions & 0 deletions
39
...in4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package dev.langchain4j.model.bedrock; | ||
|
||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.UserMessage; | ||
import dev.langchain4j.model.chat.TestStreamingResponseHandler; | ||
import dev.langchain4j.model.output.Response; | ||
import org.junit.jupiter.api.Disabled; | ||
import org.junit.jupiter.api.Test; | ||
import software.amazon.awssdk.regions.Region; | ||
|
||
import static dev.langchain4j.data.message.UserMessage.userMessage; | ||
import static java.util.Collections.singletonList; | ||
import static org.assertj.core.api.Assertions.assertThat; | ||
|
||
public class BedrockStreamingChatModelIT { | ||
@Test | ||
@Disabled("To run this test, you must have provide your own access key, secret, region") | ||
void testBedrockAnthropicStreamingChatModel() { | ||
//given | ||
BedrockAnthropicStreamingChatModel bedrockChatModel = BedrockAnthropicStreamingChatModel | ||
.builder() | ||
.temperature(0.5) | ||
.maxTokens(300) | ||
.region(Region.US_EAST_1) | ||
.maxRetries(1) | ||
.build(); | ||
UserMessage userMessage = userMessage("What's the capital of Poland?"); | ||
|
||
//when | ||
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>(); | ||
bedrockChatModel.generate(singletonList(userMessage), handler); | ||
Response<AiMessage> response = handler.get(); | ||
|
||
//then | ||
assertThat(response.content().text()).contains("Warsaw"); | ||
} | ||
|
||
|
||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update https://github.com/langchain4j/langchain4j/blob/main/docs/docs/integrations/language-models/index.md as well