Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
1402564807 committed Apr 23, 2024
1 parent f492df6 commit eb0c25a
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ private static Function toFunction(ToolSpecification toolSpecification) {
}

private static Parameters toFunctionParameters(ToolParameters toolParameters) {
if (toolParameters == null) {
return Parameters.builder().build();
}
return Parameters.builder()
.properties(toolParameters.properties())
.required(toolParameters.required())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,52 @@ void should_execute_a_tool_then_answer() {

assertThat(secondResponse.finishReason()).isEqualTo(STOP);
}


ToolSpecification currentTime = ToolSpecification.builder()
.name("currentTime")
.description("currentTime")
.build();

@Test
void should_execute_get_current_time_tool_and_then_answer() {
// given
UserMessage userMessage = userMessage("What's the time now?");
List<ToolSpecification> toolSpecifications = singletonList(currentTime);

// when
Response<AiMessage> response = chatModel.generate(singletonList(userMessage), toolSpecifications);

// then
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isNull();
assertThat(aiMessage.toolExecutionRequests()).hasSize(1);

ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
assertThat(toolExecutionRequest.name()).isEqualTo("currentTime");

TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());

assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);

// given
ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "2024-04-23 12:00:20");
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);

// when
Response<AiMessage> secondResponse = chatModel.generate(messages);

// then
AiMessage secondAiMessage = secondResponse.content();
assertThat(secondAiMessage.text()).contains("2024-04-23 12:00:20");
assertThat(secondAiMessage.toolExecutionRequests()).isNull();

TokenUsage secondTokenUsage = secondResponse.tokenUsage();
assertThat(secondTokenUsage.totalTokenCount())
.isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount());

assertThat(secondResponse.finishReason()).isEqualTo(STOP);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,56 @@ void should_execute_a_tool_then_stream_answer() {

assertThat(secondResponse.finishReason()).isEqualTo(STOP);
}


ToolSpecification currentTime = ToolSpecification.builder()
.name("currentTime")
.description("currentTime")
.build();

@Test
void should_execute_get_current_time_tool_and_then_answer() {
// given
UserMessage userMessage = userMessage("What's the time now?");
List<ToolSpecification> toolSpecifications = singletonList(currentTime);

// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(singletonList(userMessage), toolSpecifications, handler);

// then
Response<AiMessage> response = handler.get();
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isNull();
assertThat(aiMessage.toolExecutionRequests()).hasSize(1);

ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
assertThat(toolExecutionRequest.name()).isEqualTo("currentTime");

TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());

assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);

// given
ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "2024-04-23 12:00:20");
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);

// when
TestStreamingResponseHandler<AiMessage> secondHandler = new TestStreamingResponseHandler<>();
model.generate(messages, secondHandler);

// then
Response<AiMessage> secondResponse = secondHandler.get();
AiMessage secondAiMessage = secondResponse.content();
assertThat(secondAiMessage.text()).contains("2024-04-23 12:00:20");
assertThat(secondAiMessage.toolExecutionRequests()).isNull();

TokenUsage secondTokenUsage = secondResponse.tokenUsage();
assertThat(secondTokenUsage.totalTokenCount())
.isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount());

assertThat(secondResponse.finishReason()).isEqualTo(STOP);
}
}

0 comments on commit eb0c25a

Please sign in to comment.