
New Mistral open mixtral8x22b model #978

Merged · 40 commits · May 22, 2024

Commits
160d53d  Add Mistral AI model provider (czelabueno, Jan 15, 2024)
0cf6b07  MistralAI chat completions req/resp (czelabueno, Jan 15, 2024)
d2a47dc  Mistral AI embeddings req/resp (czelabueno, Jan 15, 2024)
f656f61  Mistral AI Token usage (czelabueno, Jan 15, 2024)
62be0b2  Mistral AI models req/resp (czelabueno, Jan 15, 2024)
9c95a28  Mistral AI Client code (czelabueno, Jan 15, 2024)
4e6f008  Mistral AI Chat model (czelabueno, Jan 15, 2024)
c858389  Mistral AI Chat Streaming model support (czelabueno, Jan 15, 2024)
a3c9fc8  Mistral AI embedding model support (czelabueno, Jan 15, 2024)
2293142  Mistral AI get models from API (czelabueno, Jan 15, 2024)
06e7cce  Mistral AI chat model tests (czelabueno, Jan 15, 2024)
13f056b  Mistral AI embeddings model tests (czelabueno, Jan 15, 2024)
ba19306  Mistral AI chat streaming model tests (czelabueno, Jan 15, 2024)
f09458f  Mistral AI get models tests (czelabueno, Jan 15, 2024)
63d6924  Merge branch 'main' into main (langchain4j, Jan 16, 2024)
9117ee3  MistralAI - renamed classes to the project convention names to avoid … (czelabueno, Jan 17, 2024)
2c1a22f  Mistral AI logRequestResponse and commit suggestions (czelabueno, Jan 19, 2024)
6b2e2b1  Merge branch 'main' into main (langchain4j, Jan 24, 2024)
0d87349  Mistral AI token masking until 4 symbols (czelabueno, Jan 24, 2024)
7da8928  MistralAI update chat model enum (czelabueno, Jan 24, 2024)
79237cc  MistralAI update embedding model enum (czelabueno, Jan 24, 2024)
b51d0a5  Mistral AI fix get usageInfo from last chat completion response (czelabueno, Jan 24, 2024)
2078d32  Mistral AI fix logging streaming and rename enums (czelabueno, Jan 24, 2024)
49dd2e4  Merge branch 'main' into main (langchain4j, Jan 25, 2024)
4d44c3b  Merge conflict with upstream repo (czelabueno, Jan 29, 2024)
134a9dd  Merge remote-tracking branch 'upstream/main' (czelabueno, Feb 5, 2024)
d982a78  Merge remote-tracking branch 'upstream/main' (czelabueno, Feb 10, 2024)
004e942  Merge remote-tracking branch 'upstream/main' (czelabueno, Feb 23, 2024)
08938d7  Merge remote-tracking branch 'upstream/main' (czelabueno, Mar 8, 2024)
1d902d7  Merge remote-tracking branch 'upstream/main' (czelabueno, Mar 11, 2024)
f7028e7  Merge remote-tracking branch 'upstream/main' (czelabueno, Mar 15, 2024)
9ddb2ce  update overview integration table (czelabueno, Mar 16, 2024)
53a6840  Update docs/docs/integrations/index.mdx (LizeRaes, Mar 16, 2024)
507e748  Update docs/docs/integrations/index.mdx (LizeRaes, Mar 16, 2024)
1f21d10  Merge remote-tracking branch 'upstream/main' (czelabueno, Mar 25, 2024)
14259c4  Merge remote-tracking branch 'upstream/main' (czelabueno, Mar 26, 2024)
188befb  Merge remote-tracking branch 'upstream/main' (czelabueno, Apr 6, 2024)
b799984  Merge remote-tracking branch 'upstream/main' (czelabueno, Apr 14, 2024)
d499ae2  Merge remote-tracking branch 'upstream/main' (czelabueno, Apr 16, 2024)
a40995c  add new open-weight model (czelabueno, Apr 18, 2024)
Files changed
MistralAiChatModelName.java

@@ -32,6 +32,7 @@ public enum MistralAiChatModelName {
MISTRAL_TINY("mistral-tiny"),

OPEN_MIXTRAL_8x7B("open-mixtral-8x7b"), // aka mistral-small-2312
+OPEN_MIXTRAL_8X22B("open-mixtral-8x22b"), // aka open-mixtral-8x22b-2404

/**
* @deprecated As of release 0.29.0, replaced by {@link #MISTRAL_SMALL_LATEST}
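The new constant slots into the existing builder API like any other entry in the enum. A minimal usage sketch (the wrapper class and the question are illustrative; the builder calls and the OPEN_MIXTRAL_8X22B name are taken from the tests below, and MISTRAL_AI_API_KEY is assumed to be set in the environment):

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.mistralai.MistralAiChatModel;
import dev.langchain4j.model.mistralai.MistralAiChatModelName;
import dev.langchain4j.model.output.Response;

import static dev.langchain4j.data.message.UserMessage.userMessage;

public class OpenMixtral8x22BExample {

    public static void main(String[] args) {
        // Chat model backed by the newly added open-weight 8x22B model
        ChatLanguageModel model = MistralAiChatModel.builder()
                .apiKey(System.getenv("MISTRAL_AI_API_KEY"))
                .modelName(MistralAiChatModelName.OPEN_MIXTRAL_8X22B)
                .temperature(0.1)
                .build();

        Response<AiMessage> response = model.generate(userMessage("What is the capital of Peru?"));
        System.out.println(response.content().text());
    }
}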
MistralAiChatModelIT.java

@@ -38,21 +38,29 @@ class MistralAiChatModelIT {
.logResponses(true)
.build();

-ChatLanguageModel model = MistralAiChatModel.builder()
+ChatLanguageModel defaultModel = MistralAiChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
.temperature(0.1)
.logRequests(true)
.logResponses(true)
.build();

+ChatLanguageModel openMixtral8x22BModel = MistralAiChatModel.builder()
+.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
+.modelName(MistralAiChatModelName.OPEN_MIXTRAL_8X22B)
+.temperature(0.1)
+.logRequests(true)
+.logResponses(true)
+.build();

@Test
void should_generate_answer_and_return_token_usage_and_finish_reason_stop() {

// given
UserMessage userMessage = userMessage("What is the capital of Peru?");

// when
-Response<AiMessage> response = model.generate(userMessage);
+Response<AiMessage> response = defaultModel.generate(userMessage);

// then
assertThat(response.content().text()).contains("Lima");
@@ -132,7 +140,7 @@ void should_generate_answer_and_return_token_usage_and_finish_reason_stop_with_m
UserMessage userMessage3 = userMessage("What is the capital of Canada?");

// when
-Response<AiMessage> response = model.generate(userMessage1, userMessage2, userMessage3);
+Response<AiMessage> response = defaultModel.generate(userMessage1, userMessage2, userMessage3);

// then
assertThat(response.content().text()).contains("Lima");
@@ -236,14 +244,14 @@ void should_generate_answer_using_model_medium_and_return_token_usage_and_finish
}

@Test
-void should_execute_tool_and_return_finishReason_tool_execution(){
+void should_execute_tool_using_model_open8x22B_and_return_finishReason_tool_execution(){

// given
UserMessage userMessage = userMessage("What is the status of transaction T123?");
List<ToolSpecification> toolSpecifications = singletonList(retrievePaymentStatus);

// when
-Response<AiMessage> response = mistralLargeModel.generate(singletonList(userMessage), toolSpecifications);
+Response<AiMessage> response = openMixtral8x22BModel.generate(singletonList(userMessage), toolSpecifications);

// then
AiMessage aiMessage = response.content();
@@ -264,7 +272,7 @@ void should_execute_tool_and_return_finishReason_tool_execution(){
}

@Test
-void should_execute_tool_when_toolChoice_is_auto_and_answer(){
+void should_execute_tool_using_model_open8x22B_when_toolChoice_is_auto_and_answer(){
// given
ToolSpecification retrievePaymentDate = ToolSpecification.builder()
.name("retrieve-payment-date")
Expand All @@ -279,7 +287,7 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){
List<ToolSpecification> toolSpecifications = asList(retrievePaymentStatus,retrievePaymentDate);

// when
-Response<AiMessage> response = mistralLargeModel.generate(chatMessages, toolSpecifications);
+Response<AiMessage> response = openMixtral8x22BModel.generate(chatMessages, toolSpecifications);

// then
AiMessage aiMessage = response.content();
Expand All @@ -299,7 +307,7 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){
chatMessages.add(toolExecutionResultMessage);

// when
-Response<AiMessage> response2 = mistralLargeModel.generate(chatMessages);
+Response<AiMessage> response2 = openMixtral8x22BModel.generate(chatMessages);

// then
AiMessage aiMessage2 = response2.content();
@@ -308,7 +316,7 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){
assertThat(aiMessage2.toolExecutionRequests()).isNull();

TokenUsage tokenUsage2 = response2.tokenUsage();
-assertThat(tokenUsage2.inputTokenCount()).isEqualTo(69);
+assertThat(tokenUsage2.inputTokenCount()).isEqualTo(74);
assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage2.totalTokenCount())
.isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
@@ -317,7 +325,7 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){
}

@Test
-void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() {
+void should_execute_tool_forcefully_using_model_open8x22B_when_toolChoice_is_any_and_answer() {
// given
ToolSpecification retrievePaymentDate = ToolSpecification.builder()
.name("retrieve-payment-date")
@@ -330,7 +338,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() {
chatMessages.add(userMessage);

// when
-Response<AiMessage> response = mistralLargeModel.generate(singletonList(userMessage), retrievePaymentDate);
+Response<AiMessage> response = openMixtral8x22BModel.generate(singletonList(userMessage), retrievePaymentDate);

// then
AiMessage aiMessage = response.content();
@@ -355,7 +363,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() {
chatMessages.add(toolExecutionResultMessage);

// when
-Response<AiMessage> response2 = mistralLargeModel.generate(chatMessages);
+Response<AiMessage> response2 = openMixtral8x22BModel.generate(chatMessages);

// then
AiMessage aiMessage2 = response2.content();
@@ -364,7 +372,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() {
assertThat(aiMessage2.toolExecutionRequests()).isNull();

TokenUsage tokenUsage2 = response2.tokenUsage();
-assertThat(tokenUsage2.inputTokenCount()).isEqualTo(78);
+assertThat(tokenUsage2.inputTokenCount()).isEqualTo(83);
assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage2.totalTokenCount())
.isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
@@ -373,7 +381,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() {
}

@Test
-void should_return_valid_json_object(){
+void should_return_valid_json_object_using_model_large(){

// given
String userMessage = "Return JSON with two fields: transactionId and status with the values T123 and paid.";
@@ -397,7 +405,7 @@ void should_return_valid_json_object(){
}

@Test
-void should_execute_multiple_tools_then_answer(){
+void should_execute_multiple_tools_using_model_open8x22B_then_answer(){
// given
ToolSpecification retrievePaymentDate = ToolSpecification.builder()
.name("retrieve-payment-date")
@@ -448,7 +456,7 @@ void should_execute_multiple_tools_then_answer(){
assertThat(aiMessage2.toolExecutionRequests()).isNull();

TokenUsage tokenUsage2 = response2.tokenUsage();
-assertThat(tokenUsage2.inputTokenCount()).isEqualTo(128);
+assertThat(tokenUsage2.inputTokenCount()).isEqualTo(132);
assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage2.totalTokenCount())
.isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
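Note on the tool tests above: they all follow the same two-step round trip. The first generate call returns an AiMessage carrying a ToolExecutionRequest instead of a text answer; the test executes the tool itself, appends a ToolExecutionResultMessage, and a second generate call yields the natural-language answer (and the token counts asserted above). A condensed sketch of that flow, reusing the openMixtral8x22BModel field and static imports from the test class (the tool name, description, and "PAID" result are illustrative):

ToolSpecification retrievePaymentStatus = ToolSpecification.builder()
        .name("retrieve-payment-status")
        .description("Returns the payment status of a transaction")
        .build();

List<ChatMessage> chatMessages = new ArrayList<>();
chatMessages.add(userMessage("What is the status of transaction T123?"));

// First call: the model requests the tool instead of answering directly
Response<AiMessage> response = openMixtral8x22BModel.generate(chatMessages, singletonList(retrievePaymentStatus));
AiMessage aiMessage = response.content();
chatMessages.add(aiMessage);

// The caller executes the tool and feeds the result back into the conversation
ToolExecutionRequest request = aiMessage.toolExecutionRequests().get(0);
chatMessages.add(ToolExecutionResultMessage.from(request, "PAID"));

// Second call: the model turns the tool result into a natural-language answer
Response<AiMessage> answer = openMixtral8x22BModel.generate(chatMessages);
System.out.println(answer.content().text());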
MistralAiStreamingChatModelIT.java

@@ -39,11 +39,19 @@ class MistralAiStreamingChatModelIT {
.logResponses(true)
.build();

-StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder()
+StreamingChatLanguageModel defaultModel = MistralAiStreamingChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
.temperature(0.1)
.logRequests(true)
.logResponses(true)
.build();

+StreamingChatLanguageModel openMixtral8x22BModel = MistralAiStreamingChatModel.builder()
+.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
+.modelName(MistralAiChatModelName.OPEN_MIXTRAL_8X22B)
+.temperature(0.1)
+.logRequests(true)
+.logResponses(true)
+.build();

@Test
@@ -54,7 +62,7 @@ void should_stream_answer_and_return_token_usage_and_finish_reason_stop() {

// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
-model.generate(singletonList(userMessage), handler);
+defaultModel.generate(singletonList(userMessage), handler);

Response<AiMessage> response = handler.get();

@@ -139,7 +147,7 @@ void should_stream_answer_and_return_token_usage_and_finish_reason_stop_with_mul

// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
-model.generate(asList(userMessage1, userMessage2, userMessage3), handler);
+defaultModel.generate(asList(userMessage1, userMessage2, userMessage3), handler);
Response<AiMessage> response = handler.get();

// then
@@ -244,15 +252,15 @@ void should_stream_answer_using_model_medium_and_return_token_usage_and_finish_r
}

@Test
-void should_execute_tool_and_return_finishReason_tool_execution(){
+void should_execute_tool_using_model_open8x22B_and_return_finishReason_tool_execution(){

// given
UserMessage userMessage = userMessage("What is the status of transaction T123?");
List<ToolSpecification> toolSpecifications = singletonList(retrievePaymentStatus);

// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
-mistralLargeStreamingModel.generate(singletonList(userMessage), toolSpecifications,handler);
+openMixtral8x22BModel.generate(singletonList(userMessage), toolSpecifications,handler);

Response<AiMessage> response = handler.get();

@@ -275,7 +283,7 @@ void should_execute_tool_and_return_finishReason_tool_execution(){
}

@Test
-void should_execute_tool_when_toolChoice_is_auto_and_answer(){
+void should_execute_tool_using_model_open8x22B_when_toolChoice_is_auto_and_answer(){
// given
ToolSpecification retrievePaymentDate = ToolSpecification.builder()
.name("retrieve-payment-date")
@@ -291,7 +299,7 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){

// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
-mistralLargeStreamingModel.generate(chatMessages, toolSpecifications, handler);
+openMixtral8x22BModel.generate(chatMessages, toolSpecifications, handler);
Response<AiMessage> response = handler.get();

// then
@@ -313,7 +321,7 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){

// when
TestStreamingResponseHandler<AiMessage> handler2 = new TestStreamingResponseHandler<>();
-mistralLargeStreamingModel.generate(chatMessages, handler2);
+openMixtral8x22BModel.generate(chatMessages, handler2);
Response<AiMessage> response2 = handler2.get();

// then
@@ -323,7 +331,7 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){
assertThat(aiMessage2.toolExecutionRequests()).isNull();

TokenUsage tokenUsage2 = response2.tokenUsage();
-assertThat(tokenUsage2.inputTokenCount()).isEqualTo(69);
+assertThat(tokenUsage2.inputTokenCount()).isGreaterThan(70);
assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage2.totalTokenCount())
.isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
@@ -332,7 +340,7 @@ void should_execute_tool_when_toolChoice_is_auto_and_answer(){
}

@Test
-void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() {
+void should_execute_tool_forcefully_using_model_open8x22B_when_toolChoice_is_any_and_answer() {
// given
ToolSpecification retrievePaymentDate = ToolSpecification.builder()
.name("retrieve-payment-date")
@@ -346,7 +354,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() {

// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
-mistralLargeStreamingModel.generate(singletonList(userMessage), retrievePaymentDate, handler);
+openMixtral8x22BModel.generate(singletonList(userMessage), retrievePaymentDate, handler);
Response<AiMessage> response = handler.get();

// then
@@ -373,7 +381,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() {

// when
TestStreamingResponseHandler<AiMessage> handler2 = new TestStreamingResponseHandler<>();
-mistralLargeStreamingModel.generate(chatMessages, handler2);
+openMixtral8x22BModel.generate(chatMessages, handler2);
Response<AiMessage> response2 = handler2.get();

// then
@@ -383,7 +391,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() {
assertThat(aiMessage2.toolExecutionRequests()).isNull();

TokenUsage tokenUsage2 = response2.tokenUsage();
-assertThat(tokenUsage2.inputTokenCount()).isEqualTo(78);
+assertThat(tokenUsage2.inputTokenCount()).isGreaterThan(80);
assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage2.totalTokenCount())
.isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
@@ -392,7 +400,7 @@ void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() {
}

@Test
-void should_execute_multiple_tools_then_answer(){
+void should_execute_multiple_tools_using_model_large_then_answer(){
// given
ToolSpecification retrievePaymentDate = ToolSpecification.builder()
.name("retrieve-payment-date")
@@ -447,7 +455,7 @@ void should_execute_multiple_tools_then_answer(){
assertThat(aiMessage2.toolExecutionRequests()).isNull();

TokenUsage tokenUsage2 = response2.tokenUsage();
-assertThat(tokenUsage2.inputTokenCount()).isEqualTo(128);
+assertThat(tokenUsage2.inputTokenCount()).isEqualTo(132);
assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
assertThat(tokenUsage2.totalTokenCount())
.isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());
@@ -456,7 +464,7 @@ void should_return_valid_json_object(){
}

@Test
-void should_return_valid_json_object(){
+void should_return_valid_json_object_using_model_large(){

// given
String userMessage = "Return JSON with two fields: transactionId and status with the values T123 and paid.";
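The streaming tests drive the same scenarios through a callback handler and gather the final Response with TestStreamingResponseHandler. In application code the equivalent is implementing StreamingResponseHandler directly; a minimal sketch with the new model (the wrapper class and printed output are illustrative):

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.mistralai.MistralAiChatModelName;
import dev.langchain4j.model.mistralai.MistralAiStreamingChatModel;
import dev.langchain4j.model.output.Response;

public class OpenMixtral8x22BStreamingExample {

    public static void main(String[] args) {
        StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder()
                .apiKey(System.getenv("MISTRAL_AI_API_KEY"))
                .modelName(MistralAiChatModelName.OPEN_MIXTRAL_8X22B)
                .temperature(0.1)
                .build();

        model.generate("What is the capital of Peru?", new StreamingResponseHandler<AiMessage>() {

            @Override
            public void onNext(String token) {
                System.out.print(token); // partial tokens arrive as they are generated
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                System.out.println("\nToken usage: " + response.tokenUsage());
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}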