Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cohere: Add CohereClient and CohereChatModel to support Chat #917

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package dev.langchain4j.model.cohere;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.cohere.internal.api.CohereChatRequest;
import dev.langchain4j.model.cohere.internal.api.CohereChatResponse;
import dev.langchain4j.model.cohere.internal.client.CohereClient;
import dev.langchain4j.model.output.Response;
import lombok.Builder;

import java.time.Duration;
import java.util.List;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.model.cohere.internal.mapper.CohereMapper.*;

/**
* An implementation of a ChatModel that uses
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* An implementation of a ChatModel that uses
* An implementation of a ChatLanguageModel that uses

* <a href="https://docs.cohere.com/docs/command-r">Cohere Command R API</a>.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* <a href="https://docs.cohere.com/docs/command-r">Cohere Command R API</a>.
* <a href="https://docs.cohere.com/reference/chat">Cohere Chat API</a>.

*/
public class CohereChatModel implements ChatLanguageModel {

private static final String DEFAULT_BASE_URL = "https://api.cohere.ai/v1/";

private final CohereClient client;
private final String modelName;
private final Double temperature;
private final Double topP;
private final Integer topK;
private final Integer maxTokens;
private final List<String> stopSequences;
private final Integer maxRetries;

/**
* Constructs an instance of a {@code CohereChatModel} with the specified parameters.
*
* @param baseUrl The base URL of the Cohere API. Default: "https://api.cohere.ai/v1/"
* @param apiKey The API key for authentication with the Cohere API.
* @param modelName The name of the Cohere model to use. Default: command-r
* @param temperature The temperature. Default: 0.3
* @param topP The top-P. Defaults to 0.75. min value of 0.01, max value of 0.99.
* @param topK The top-K. Defaults to 0, min value of 0, max value of 500.
* @param maxTokens The maximum number of tokens the model will generate as part of the response. Default: 1024
* @param stopSequences The custom text sequences that will cause the model to stop generating
* @param timeout The timeout for API requests. Default: 60 seconds
* @param maxRetries The maximum number of retries for API requests. Default: 3
* @param logRequests Whether to log the content of API requests using SLF4J. Default: false
* @param logResponses Whether to log the content of API responses using SLF4J. Default: false
*/
@Builder
private CohereChatModel(String baseUrl,
String apiKey,
String modelName,
Double temperature,
Double topP,
Integer topK,
Integer maxTokens,
List<String> stopSequences,
Duration timeout,
Integer maxRetries,
Boolean logRequests,
Boolean logResponses) {
this.client = CohereClient.builder()
.baseUrl(getOrDefault(baseUrl, DEFAULT_BASE_URL))
.apiKey(apiKey)
.timeout(getOrDefault(timeout, Duration.ofSeconds(60)))
.logRequests(getOrDefault(logRequests, false))
.logResponses(getOrDefault(logResponses, false))
.build();
this.modelName = getOrDefault(modelName, "command-r");
this.temperature = temperature;
this.topP = topP;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use Cohere naming, e.g. p instead of topP.
Same with other params.

this.topK = topK;
this.maxTokens = getOrDefault(maxTokens, 1024);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is not a mandatory param, I would remove the default.

this.stopSequences = stopSequences;
this.maxRetries = getOrDefault(maxRetries, 3);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: I would not add a default for temperature; they already have a default on their side.

}

/**
 * Convenience factory that builds a {@code CohereChatModel} from an API key alone.
 * <p>
 * Only the API key is set on the builder; every other builder option is left unset.
 *
 * @param apiKey the API key for authentication
 * @return a {@code CohereChatModel} instance
 */
public static CohereChatModel withApiKey(String apiKey) {
    final CohereChatModel model = builder()
            .apiKey(apiKey)
            .build();
    return model;
}

/**
 * Generates a response for the given messages without offering any tools.
 *
 * @param messages the conversation to send
 * @return the result of the tool-aware overload invoked with a {@code null} tool list
 */
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
    // A typed null binds the call to the List<ToolSpecification> overload.
    final List<ToolSpecification> noTools = null;
    return generate(messages, noTools);
}

/**
 * Builds a non-streaming Cohere chat request from the conversation and returns the reply.
 * <p>
 * The last element of {@code messages} is sent as the current message and all elements
 * before it are passed as chat history. The preamble and tool results are derived from
 * the full message list by the mapper helpers.
 *
 * @param messages the conversation; validated with {@code ensureNotEmpty}
 * @param tools    the tool specifications to offer, converted with {@code toCohereTools}
 * @return the mapped AI message together with token usage derived from the billed units
 */
@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> tools) {
    ensureNotEmpty(messages, "messages");

    final int lastIndex = messages.size() - 1;
    final ChatMessage currentMessage = messages.get(lastIndex);
    final List<ChatMessage> earlierMessages = messages.subList(0, lastIndex);

    final CohereChatRequest request = CohereChatRequest.builder()
            .model(modelName)
            .preamble(toPreamble(messages))
            .message(toCohereMessage(currentMessage))
            .toolResults(toToolResults(messages))
            .chatHistory(toChatHistory(earlierMessages))
            .maxTokens(maxTokens)
            .stopSequences(stopSequences)
            .stream(false)
            .temperature(temperature)
            .p(topP)
            .k(topK)
            .tools(toCohereTools(tools))
            .build();

    // The client call is wrapped in withRetry, bounded by maxRetries.
    final CohereChatResponse response = withRetry(() -> client.chat(request), maxRetries);

    final AiMessage reply = toAiMessage(response);
    return Response.from(reply, toTokenUsage(response.getMeta().getBilledUnits()));
}

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package dev.langchain4j.model.cohere;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.cohere.internal.api.RerankRequest;
import dev.langchain4j.model.cohere.internal.api.RerankResponse;
import dev.langchain4j.model.cohere.internal.api.Result;
import dev.langchain4j.model.cohere.internal.client.CohereClient;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.scoring.ScoringModel;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package dev.langchain4j.model.cohere.internal.api;

import lombok.AllArgsConstructor;
import lombok.Getter;

/**
 * Billing counters attached to a Cohere API response.
 * <p>
 * {@code Meta} holds an instance of this class in its {@code billedUnits} field, and the
 * chat model passes it to {@code toTokenUsage} to report token usage.
 * <p>
 * NOTE(review): field order defines the parameter order of the Lombok-generated
 * all-args constructor, so do not reorder the fields.
 */
@Getter
@AllArgsConstructor
public class BilledUnits {

// NOTE(review): presumably the number of billed search units (rerank) - confirm against the Cohere API docs.
private Integer searchUnits;

// NOTE(review): presumably the number of billed input (prompt) tokens - confirm.
private Integer inputTokens;

// NOTE(review): presumably the number of billed output (generated) tokens - confirm.
private Integer outputTokens;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package dev.langchain4j.model.cohere.internal.api;

import lombok.Builder;
import lombok.NonNull;

@Builder
public class ChatHistory {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: ChatMessage or Message seems to be a better name.
History is the complete list of messages, not a single one.


@NonNull
CohereRole role;

@NonNull
String message;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package dev.langchain4j.model.cohere.internal.api;

import lombok.Getter;
import lombok.ToString;

import java.util.List;


/**
 * A single citation entry; {@code CohereChatResponse} exposes a list of these in its
 * {@code citations} field.
 * <p>
 * The class declares no constructor or setters.
 * NOTE(review): presumably instances are populated reflectively by the JSON
 * deserializer - confirm.
 */
@Getter
@ToString
public class Citation {
// NOTE(review): presumably the start offset of the cited span in the response text - confirm.
Integer start;

// NOTE(review): presumably the end offset of the cited span - confirm whether inclusive or exclusive.
Integer end;

// NOTE(review): presumably the cited text itself - confirm.
String text;

// NOTE(review): presumably the identifiers of the documents backing this citation - confirm.
List<String> documentIds;
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package dev.langchain4j.model.cohere;
package dev.langchain4j.model.cohere.internal.api;

import retrofit2.Call;
import retrofit2.http.Body;
import retrofit2.http.Header;
import retrofit2.http.Headers;
import retrofit2.http.POST;

interface CohereApi {
public interface CohereApi {

@POST("chat")
@Headers({"Content-Type: application/json"})
Call<CohereChatResponse> chat(@Body CohereChatRequest request, @Header("Authorization") String authorizationHeader);

@POST("rerank")
@Headers({"Content-Type: application/json"})
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package dev.langchain4j.model.cohere.internal.api;

import lombok.Builder;
import lombok.NonNull;

import java.util.List;
import java.util.Map;

/**
 * Request body for the {@code chat} endpoint declared in {@code CohereApi}.
 * <p>
 * Instances are created through the Lombok-generated builder. Only {@code message} is
 * marked {@code @NonNull}; every other field may be left unset.
 * NOTE(review): presumably unset (null) fields are omitted from the serialized JSON and
 * camelCase names are mapped to the API's naming by the client's JSON configuration - confirm.
 */
@Builder
public class CohereChatRequest {

// The current message to send; the chat model sets it from the last message of the conversation.
@NonNull
String message;

String model;

// The chat model always sets this to false.
Boolean stream;

String preamble;

// The chat model fills this from all messages before the last one.
List<ChatHistory> chatHistory;

String conversationId;

String promptTruncation;

List<Connector> connectors;

Boolean searchQueriesOnly;

List<Map<String, String>> documents;

Double temperature;

Integer maxTokens;

// Cohere's name for top-K; the chat model passes its topK setting here.
Integer k;

// Cohere's name for top-P; the chat model passes its topP setting here.
Double p;

// NOTE(review): declared as Double - confirm against the Cohere API reference whether seed should be an integer.
Double seed;

List<String> stopSequences;

Double frequencyPenalty;

Double presencePenalty;

List<Tool> tools;

List<ToolResult> toolResults;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package dev.langchain4j.model.cohere.internal.api;

import lombok.Getter;

import java.util.List;
import java.util.Map;


/**
 * Response body of the {@code chat} endpoint declared in {@code CohereApi}.
 * <p>
 * The class exposes getters only.
 * NOTE(review): presumably instances are populated reflectively by the JSON
 * deserializer - confirm.
 */
@Getter
public class CohereChatResponse {

// NOTE(review): presumably the generated reply text - confirm.
String text;

String generationId;

List<Citation> citations;

List<Map<String, String>> documents;

Boolean isSearchRequired;

List<SearchQuery> searchQueries;

List<SearchResult> searchResults;

String finishReason;

// NOTE(review): presumably the tool invocations requested by the model, if any - confirm.
List<ToolCall> toolCalls;

List<ChatHistory> chatHistory;

// The chat model reads billed units from here to compute token usage.
Meta meta;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package dev.langchain4j.model.cohere.internal.api;

public enum CohereRole {

CHATBOX,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be CHATBOT? Does this even work?


SYSTEM,

USER

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about TOOL role?

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package dev.langchain4j.model.cohere.internal.api;

import lombok.Builder;
import lombok.NonNull;

import java.util.Map;

/**
 * Connector entry; {@code CohereChatRequest} carries a list of these in its
 * {@code connectors} field.
 * <p>
 * Instances are created through the Lombok-generated builder. Only {@code id} is marked
 * {@code @NonNull}.
 */
@Builder
public class Connector {

// Identifier of the connector; the only field marked as required.
@NonNull
String id;

// NOTE(review): presumably a per-user token forwarded to the connector - confirm against the Cohere API docs.
String userAccessToken;

// NOTE(review): presumably tells the API to proceed when this connector fails - confirm.
Boolean continueOnFailure;

// NOTE(review): presumably connector-specific options - confirm the expected keys.
Map<String, String> options;

}
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package dev.langchain4j.model.cohere;
package dev.langchain4j.model.cohere.internal.api;

import lombok.Getter;


@Getter
class Meta {
public class Meta {

private BilledUnits billedUnits;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package dev.langchain4j.model.cohere.internal.api;

import lombok.Builder;
import lombok.NonNull;

/**
 * Description of a single parameter, created through the Lombok-generated builder.
 * <p>
 * Only {@code type} is marked {@code @NonNull}.
 * NOTE(review): presumably used as the value type of a tool's parameter definitions in
 * chat requests - confirm against {@code Tool} and the mapper.
 */
@Builder
public class ParameterDefinition {

String description;

// NOTE(review): presumably the type name expected by the Cohere API (e.g. "str", "int") - confirm.
@NonNull
String type;

Boolean required;
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package dev.langchain4j.model.cohere;
package dev.langchain4j.model.cohere.internal.api;

import lombok.Builder;

import java.util.List;

@Builder
class RerankRequest {
public class RerankRequest {

private String model;
private String query;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package dev.langchain4j.model.cohere;
package dev.langchain4j.model.cohere.internal.api;


import lombok.Getter;

import java.util.List;


@Getter
class RerankResponse {
public class RerankResponse {

private List<Result> results;
private Meta meta;
Expand Down