-
Notifications
You must be signed in to change notification settings - Fork 682
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cohere: Add CohereClient and CohereChatModel to support Chat #917
base: main
Are you sure you want to change the base?
Changes from all commits
912b9f4
f32aad6
32ef3c4
86781ec
2a02f86
f4791ce
308d79c
a1ea113
b31a642
dbf7844
54e72a5
27bf822
1e874fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,124 @@ | ||||||
package dev.langchain4j.model.cohere; | ||||||
|
||||||
import dev.langchain4j.agent.tool.ToolSpecification; | ||||||
import dev.langchain4j.data.message.*; | ||||||
import dev.langchain4j.model.chat.ChatLanguageModel; | ||||||
import dev.langchain4j.model.cohere.internal.api.CohereChatRequest; | ||||||
import dev.langchain4j.model.cohere.internal.api.CohereChatResponse; | ||||||
import dev.langchain4j.model.cohere.internal.client.CohereClient; | ||||||
import dev.langchain4j.model.output.Response; | ||||||
import lombok.Builder; | ||||||
|
||||||
import java.time.Duration; | ||||||
import java.util.List; | ||||||
|
||||||
import static dev.langchain4j.internal.RetryUtils.withRetry; | ||||||
import static dev.langchain4j.internal.Utils.getOrDefault; | ||||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; | ||||||
import static dev.langchain4j.model.cohere.internal.mapper.CohereMapper.*; | ||||||
|
||||||
/** | ||||||
* An implementation of a ChatModel that uses | ||||||
* <a href="https://docs.cohere.com/docs/command-r">Cohere Command R API</a>. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
*/ | ||||||
public class CohereChatModel implements ChatLanguageModel { | ||||||
|
||||||
private static final String DEFAULT_BASE_URL = "https://api.cohere.ai/v1/"; | ||||||
|
||||||
private final CohereClient client; | ||||||
private final String modelName; | ||||||
private final Double temperature; | ||||||
private final Double topP; | ||||||
private final Integer topK; | ||||||
private final Integer maxTokens; | ||||||
private final List<String> stopSequences; | ||||||
private final Integer maxRetries; | ||||||
|
||||||
/** | ||||||
* Constructs an instance of an {@code CohereChatModel} with the specified parameters. | ||||||
* | ||||||
* @param baseUrl The base URL of the Cohere API. Default: "https://api.cohere.ai/v1/" | ||||||
* @param apiKey The API key for authentication with the Cohere API. | ||||||
* @param modelName The name of the Cohere model to use. Default: command-r | ||||||
* @param temperature The temperature. Default: 0.3 | ||||||
* @param topP The top-P. Defaults to 0.75. min value of 0.01, max value of 0.99. | ||||||
* @param topK The top-K. Defaults to 0, min value of 0, max value of 500. | ||||||
* @param maxTokens The maximum number of tokens the model will generate as part of the response. | ||||||
* @param stopSequences The custom text sequences that will cause the model to stop generating | ||||||
* @param timeout The timeout for API requests. Default: 60 seconds | ||||||
* @param maxRetries The maximum number of retries for API requests. Default: 3 | ||||||
* @param logRequests Whether to log the content of API requests using SLF4J. Default: false | ||||||
* @param logResponses Whether to log the content of API responses using SLF4J. Default: false | ||||||
*/ | ||||||
@Builder | ||||||
private CohereChatModel(String baseUrl, | ||||||
String apiKey, | ||||||
String modelName, | ||||||
Double temperature, | ||||||
Double topP, | ||||||
Integer topK, | ||||||
Integer maxTokens, | ||||||
List<String> stopSequences, | ||||||
Duration timeout, | ||||||
Integer maxRetries, | ||||||
Boolean logRequests, | ||||||
Boolean logResponses) { | ||||||
this.client = CohereClient.builder() | ||||||
.baseUrl(getOrDefault(baseUrl, DEFAULT_BASE_URL)) | ||||||
.apiKey(apiKey) | ||||||
.timeout(getOrDefault(timeout, Duration.ofSeconds(60))) | ||||||
.logRequests(getOrDefault(logRequests, false)) | ||||||
.logResponses(getOrDefault(logResponses, false)) | ||||||
.build(); | ||||||
this.modelName = getOrDefault(modelName, "command-r"); | ||||||
this.temperature = temperature; | ||||||
this.topP = topP; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use Cohere naming, e.g. |
||||||
this.topK = topK; | ||||||
this.maxTokens = getOrDefault(maxTokens, 1024); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if this is not a mandatotry param, would remove the default |
||||||
this.stopSequences = stopSequences; | ||||||
this.maxRetries = getOrDefault(maxRetries, 3); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. minor: I would not add a default for temperature, they have a default on their side already |
||||||
} | ||||||
|
||||||
/** | ||||||
* Creates an instance of {@code CohereChatModel} with the specified API key. | ||||||
* | ||||||
* @param apiKey the API key for authentication | ||||||
* @return an {@code CohereChatModel} instance | ||||||
*/ | ||||||
public static CohereChatModel withApiKey(String apiKey) { | ||||||
return builder().apiKey(apiKey).build(); | ||||||
} | ||||||
|
||||||
@Override | ||||||
public Response<AiMessage> generate(List<ChatMessage> messages) { | ||||||
return generate(messages, (List<ToolSpecification>) null); | ||||||
} | ||||||
|
||||||
@Override | ||||||
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> tools) { | ||||||
ensureNotEmpty(messages, "messages"); | ||||||
|
||||||
CohereChatRequest request = CohereChatRequest.builder() | ||||||
.model(modelName) | ||||||
.preamble(toPreamble(messages)) | ||||||
.message(toCohereMessage(messages.get(messages.size() - 1))) | ||||||
.toolResults(toToolResults(messages)) | ||||||
.chatHistory(toChatHistory(messages.subList(0, messages.size() - 1))) | ||||||
.maxTokens(maxTokens) | ||||||
.stopSequences(stopSequences) | ||||||
.stream(false) | ||||||
.temperature(temperature) | ||||||
.p(topP) | ||||||
.k(topK) | ||||||
.tools(toCohereTools(tools)) | ||||||
.build(); | ||||||
|
||||||
CohereChatResponse response = withRetry(() -> client.chat(request), maxRetries); | ||||||
|
||||||
return Response.from( | ||||||
toAiMessage(response), | ||||||
toTokenUsage(response.getMeta().getBilledUnits()) | ||||||
); | ||||||
} | ||||||
|
||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package dev.langchain4j.model.cohere.internal.api; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Getter; | ||
|
||
@Getter | ||
@AllArgsConstructor | ||
public class BilledUnits { | ||
|
||
private Integer searchUnits; | ||
|
||
private Integer inputTokens; | ||
|
||
private Integer outputTokens; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package dev.langchain4j.model.cohere.internal.api; | ||
|
||
import lombok.Builder; | ||
import lombok.NonNull; | ||
|
||
@Builder | ||
public class ChatHistory { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: |
||
|
||
@NonNull | ||
CohereRole role; | ||
|
||
@NonNull | ||
String message; | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package dev.langchain4j.model.cohere.internal.api; | ||
|
||
import lombok.Getter; | ||
import lombok.ToString; | ||
|
||
import java.util.List; | ||
|
||
|
||
@Getter | ||
@ToString | ||
public class Citation { | ||
Integer start; | ||
|
||
Integer end; | ||
|
||
String text; | ||
|
||
List<String> documentIds; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
package dev.langchain4j.model.cohere.internal.api; | ||
|
||
import lombok.Builder; | ||
import lombok.NonNull; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
@Builder | ||
public class CohereChatRequest { | ||
|
||
@NonNull | ||
String message; | ||
|
||
String model; | ||
|
||
Boolean stream; | ||
|
||
String preamble; | ||
|
||
List<ChatHistory> chatHistory; | ||
|
||
String conversationId; | ||
|
||
String promptTruncation; | ||
|
||
List<Connector> connectors; | ||
|
||
Boolean searchQueriesOnly; | ||
|
||
List<Map<String, String>> documents; | ||
|
||
Double temperature; | ||
|
||
Integer maxTokens; | ||
|
||
Integer k; | ||
|
||
Double p; | ||
|
||
Double seed; | ||
|
||
List<String> stopSequences; | ||
|
||
Double frequencyPenalty; | ||
|
||
Double presencePenalty; | ||
|
||
List<Tool> tools; | ||
|
||
List<ToolResult> toolResults; | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
package dev.langchain4j.model.cohere.internal.api; | ||
|
||
import lombok.Getter; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
|
||
@Getter | ||
public class CohereChatResponse { | ||
|
||
String text; | ||
|
||
String generationId; | ||
|
||
List<Citation> citations; | ||
|
||
List<Map<String, String>> documents; | ||
|
||
Boolean isSearchRequired; | ||
|
||
List<SearchQuery> searchQueries; | ||
|
||
List<SearchResult> searchResults; | ||
|
||
String finishReason; | ||
|
||
List<ToolCall> toolCalls; | ||
|
||
List<ChatHistory> chatHistory; | ||
|
||
Meta meta; | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
package dev.langchain4j.model.cohere.internal.api; | ||
|
||
public enum CohereRole { | ||
|
||
CHATBOX, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be |
||
|
||
SYSTEM, | ||
|
||
USER | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what about |
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package dev.langchain4j.model.cohere.internal.api; | ||
|
||
import lombok.Builder; | ||
import lombok.NonNull; | ||
|
||
import java.util.Map; | ||
|
||
@Builder | ||
public class Connector { | ||
|
||
@NonNull | ||
String id; | ||
|
||
String userAccessToken; | ||
|
||
Boolean continueOnFailure; | ||
|
||
Map<String, String> options; | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
package dev.langchain4j.model.cohere; | ||
package dev.langchain4j.model.cohere.internal.api; | ||
|
||
import lombok.Getter; | ||
|
||
|
||
@Getter | ||
class Meta { | ||
public class Meta { | ||
|
||
private BilledUnits billedUnits; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package dev.langchain4j.model.cohere.internal.api; | ||
|
||
import lombok.Builder; | ||
import lombok.NonNull; | ||
|
||
@Builder | ||
public class ParameterDefinition { | ||
|
||
String description; | ||
|
||
@NonNull | ||
String type; | ||
|
||
Boolean required; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.