Skip to content

Commit

Permalink
Add Jina as a scoring model. Feature request: langchain4j#974
Browse files Browse the repository at this point in the history
Module added according to langchain4j#973 in order to minimize conflicts.
Implementation is in line with Cohere reranking.
  • Loading branch information
ksmeyers committed May 3, 2024
1 parent d28f5ab commit 59d7af8
Show file tree
Hide file tree
Showing 14 changed files with 514 additions and 0 deletions.
74 changes: 74 additions & 0 deletions langchain4j-jina-ai/pom.xml
@@ -0,0 +1,74 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Maven descriptor for the langchain4j-jina-ai integration module.
No dependency below declares a version; presumably versions are managed by
the langchain4j-parent POM referenced in <parent> (confirm against that file).
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<!-- Inherits build configuration from the sibling langchain4j-parent module. -->
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.31.0-SNAPSHOT</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>

<artifactId>langchain4j-jina-ai</artifactId>
<name>LangChain4j :: Integration :: Jina.ai</name>

<dependencies>
<!-- Core LangChain4j API (ScoringModel, TextSegment, Response, ...). -->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</dependency>

<!-- HTTP client stack: Retrofit with the Jackson converter on top of OkHttp. -->
<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>retrofit</artifactId>
</dependency>

<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>converter-jackson</artifactId>
</dependency>

<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
</dependency>

<!-- Lombok is compile-time only (provided scope). -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>

<!-- Test-scoped dependencies: JUnit 5, AssertJ and tinylog as the SLF4J backend. -->
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

</project>
@@ -0,0 +1,9 @@
package dev.langchain4j.model.jinaAi.rerank;

import lombok.Getter;

/**
 * Simple holder for a single piece of document text.
 */
public class Document {

    private String text;

    /**
     * @return the text held by this document, or {@code null} if none was set
     */
    public String getText() {
        return this.text;
    }
}
@@ -0,0 +1,14 @@
package dev.langchain4j.model.jinaAi.rerank;

import retrofit2.Call;
import retrofit2.http.Body;
import retrofit2.http.Header;
import retrofit2.http.Headers;
import retrofit2.http.POST;

/**
 * Retrofit declaration of the Jina HTTP endpoint used by this module.
 */
interface JinaApi {

    /**
     * Declares a JSON {@code POST} to the relative path {@code rerank}.
     *
     * @param request             payload sent as the request body
     * @param authorizationHeader value sent verbatim in the {@code Authorization} header
     * @return a Retrofit call whose response body is converted to a {@link RerankResponse}
     */
    @Headers("Content-Type: application/json")
    @POST("rerank")
    Call<RerankResponse> rerank(
            @Body RerankRequest request,
            @Header("Authorization") String authorizationHeader
    );
}
@@ -0,0 +1,80 @@
package dev.langchain4j.model.jinaAi.rerank;

import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.Builder;
import okhttp3.OkHttpClient;
import retrofit2.Retrofit;
import retrofit2.converter.jackson.JacksonConverterFactory;

import java.io.IOException;
import java.time.Duration;

import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL;
import static com.fasterxml.jackson.databind.PropertyNamingStrategies.SNAKE_CASE;
import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT;
import static com.fasterxml.jackson.databind.cfg.EnumFeature.WRITE_ENUMS_TO_LOWERCASE;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;

/**
 * Synchronous client for the Jina rerank endpoint, built on Retrofit and OkHttp.
 * Requests and responses are (de)serialized with a shared Jackson {@link ObjectMapper}.
 */
class JinaClient {

    // Shared mapper: snake_case property names, null fields omitted,
    // enums written in lower case, pretty-printed output.
    private static final ObjectMapper objectMapper = new ObjectMapper();

    static {
        objectMapper.enable(INDENT_OUTPUT);
        objectMapper.setPropertyNamingStrategy(SNAKE_CASE);
        objectMapper.configure(WRITE_ENUMS_TO_LOWERCASE, true);
        objectMapper.setSerializationInclusion(NON_NULL);
    }

    private final JinaApi jinaApi;
    private final String authorizationHeader;

    /**
     * Creates a client.
     *
     * @param baseUrl      base URL handed to Retrofit
     * @param apiKey       API key, sent as a {@code Bearer} token; must not be blank
     * @param timeout      applied to call, connect, read and write timeouts;
     *                     when {@code null} the OkHttp builder defaults are left untouched
     * @param logRequests  whether to install {@link RequestLoggingInterceptor}; {@code null} means no
     * @param logResponses whether to install {@code ResponseLoggingInterceptor}; {@code null} means no
     */
    @Builder
    JinaClient(String baseUrl, String apiKey, Duration timeout, Boolean logRequests, Boolean logResponses) {

        // Validate the key first so a misconfigured client fails before any HTTP machinery is built.
        this.authorizationHeader = "Bearer " + ensureNotBlank(apiKey, "apiKey");

        OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder();
        if (timeout != null) {
            okHttpClientBuilder
                    .callTimeout(timeout)
                    .connectTimeout(timeout)
                    .readTimeout(timeout)
                    .writeTimeout(timeout);
        }

        // The flags are boxed and may be left unset on the builder; compare null-safely
        // instead of auto-unboxing, which would throw a NullPointerException.
        if (Boolean.TRUE.equals(logRequests)) {
            okHttpClientBuilder.addInterceptor(new RequestLoggingInterceptor());
        }
        if (Boolean.TRUE.equals(logResponses)) {
            okHttpClientBuilder.addInterceptor(new ResponseLoggingInterceptor());
        }

        Retrofit retrofit = new Retrofit.Builder()
                .baseUrl(baseUrl)
                .client(okHttpClientBuilder.build())
                .addConverterFactory(JacksonConverterFactory.create(objectMapper))
                .build();

        this.jinaApi = retrofit.create(JinaApi.class);
    }

    /**
     * Executes the rerank call synchronously.
     *
     * @param request the rerank request to send
     * @return the parsed response body of a successful call
     * @throws RuntimeException if the call is not successful (message carries the status code
     *                          and error body) or if an {@link IOException} occurs (kept as cause)
     */
    public RerankResponse rerank(RerankRequest request) {
        try {
            retrofit2.Response<RerankResponse> retrofitResponse
                    = jinaApi.rerank(request, authorizationHeader).execute();

            if (retrofitResponse.isSuccessful()) {
                return retrofitResponse.body();
            } else {
                throw toException(retrofitResponse);
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the exception describing an unsuccessful HTTP response.
     * A missing error body is reported as an empty string rather than failing with a
     * {@link NullPointerException} that would hide the status code.
     */
    private static RuntimeException toException(retrofit2.Response<?> response) throws IOException {
        int code = response.code();
        String body = response.errorBody() == null ? "" : response.errorBody().string();
        String errorMessage = String.format("status code: %s; body: %s", code, body);
        return new RuntimeException(errorMessage);
    }
}
@@ -0,0 +1,82 @@
package dev.langchain4j.model.jinaAi.rerank;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.scoring.ScoringModel;
import lombok.Builder;

import java.time.Duration;
import java.util.List;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static java.time.Duration.ofSeconds;
import static java.util.Comparator.comparingInt;
import static java.util.stream.Collectors.toList;

/**
 * A {@link ScoringModel} backed by the
 * <a href="https://jina.ai/reranker">Jina Rerank API</a>.
 */
public class JinaScoringModel implements ScoringModel {

    /** Base URL used when the caller does not supply one. */
    private static final String DEFAULT_BASE_URL = "https://api.jina.ai/v1/";

    /** Reranker model used when the caller does not supply a model name. */
    private static final String DEFAULT_MODEL = "jina-reranker-v1-base-en";

    private final JinaClient client;
    private final String modelName;
    private final Integer maxRetries;

    /**
     * Creates a scoring model. Every argument except {@code apiKey} may be {@code null},
     * in which case a default is used.
     *
     * @param baseUrl      API base URL; defaults to {@code https://api.jina.ai/v1/}
     * @param apiKey       API key; must not be blank
     * @param modelName    reranker model name; defaults to {@code jina-reranker-v1-base-en}
     * @param timeout      HTTP timeout; defaults to 60 seconds
     * @param maxRetries   value passed to the retry helper; defaults to 3
     * @param logRequests  whether requests are logged; defaults to {@code false}
     * @param logResponses whether responses are logged; defaults to {@code false}
     */
    @Builder
    public JinaScoringModel(
            String baseUrl,
            String apiKey,
            String modelName,
            Duration timeout,
            Integer maxRetries,
            Boolean logRequests,
            Boolean logResponses
    ) {
        this.modelName = getOrDefault(modelName, DEFAULT_MODEL);
        this.maxRetries = getOrDefault(maxRetries, 3);
        this.client = JinaClient.builder()
                .baseUrl(getOrDefault(baseUrl, DEFAULT_BASE_URL))
                .apiKey(ensureNotBlank(apiKey, "apiKey"))
                .timeout(getOrDefault(timeout, ofSeconds(60)))
                .logRequests(getOrDefault(logRequests, false))
                .logResponses(getOrDefault(logResponses, false))
                .build();
    }

    /**
     * Convenience factory that uses defaults for everything but the API key.
     *
     * @param apiKey API key; must not be blank
     * @return a scoring model configured with default settings
     */
    public static JinaScoringModel withApiKey(String apiKey) {
        return JinaScoringModel.builder()
                .apiKey(apiKey)
                .build();
    }

    /**
     * Scores all segments against the query with a single rerank request.
     *
     * @param segments segments whose text is sent as the documents to rank
     * @param query    query the documents are ranked against
     * @return relevance scores ordered by the result index reported by the API
     *         (presumably the position of the segment in {@code segments}; confirm
     *         against the API documentation), together with the reported token usage
     */
    @Override
    public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) {

        List<String> documents = segments.stream()
                .map(TextSegment::text)
                .collect(toList());

        RerankRequest request = RerankRequest.builder()
                .model(modelName)
                .query(query)
                .documents(documents)
                .build();

        RerankResponse response = withRetry(() -> client.rerank(request), maxRetries);

        // Results are ordered by their index before the scores are extracted.
        List<Double> scores = response.getResults().stream()
                .sorted(comparingInt(Result::getIndex))
                .map(Result::getRelevanceScore)
                .collect(toList());

        TokenUsage tokenUsage = new TokenUsage(response.getUsage().getTotalTokens());
        return Response.from(scores, tokenUsage);
    }
}
@@ -0,0 +1,82 @@
package dev.langchain4j.model.jinaAi.rerank;

import okhttp3.Headers;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import okio.Buffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.util.stream.StreamSupport.stream;

class RequestLoggingInterceptor implements Interceptor {

private static final Logger log = LoggerFactory.getLogger(RequestLoggingInterceptor.class);

private static final Pattern BEARER_PATTERN = Pattern.compile("(Bearer\\s)(\\w{2})(\\w+)(\\w{2})");

public Response intercept(Interceptor.Chain chain) throws IOException {
Request request = chain.request();
log(request);
return chain.proceed(request);
}

private void log(Request request) {
log.debug(
"Request:\n" +
"- method: {}\n" +
"- url: {}\n" +
"- headers: {}\n" +
"- body: {}",
request.method(),
request.url(),
inOneLine(request.headers()),
getBody(request)
);
}

static String inOneLine(Headers headers) {
return stream(headers.spliterator(), false)
.map((header) -> {
String headerKey = header.component1();
String headerValue = header.component2();
if (headerKey.equals("Authorization")) {
headerValue = maskAuthorizationHeaderValue(headerValue);
}
return String.format("[%s: %s]", headerKey, headerValue);
}).collect(Collectors.joining(", "));
}

private static String maskAuthorizationHeaderValue(String authorizationHeaderValue) {
try {
Matcher matcher = BEARER_PATTERN.matcher(authorizationHeaderValue);
StringBuffer sb = new StringBuffer();

while (matcher.find()) {
matcher.appendReplacement(sb, matcher.group(1) + matcher.group(2) + "..." + matcher.group(4));
}

matcher.appendTail(sb);
return sb.toString();
} catch (Exception e) {
return "[failed to mask the API key]";
}
}

private static String getBody(Request request) {
try {
Buffer buffer = new Buffer();
request.body().writeTo(buffer);
return buffer.readUtf8();
} catch (Exception e) {
log.warn("Exception happened while reading request body", e);
return "[Exception happened while reading request body. Check logs for more details.]";
}
}
}
@@ -0,0 +1,19 @@
package dev.langchain4j.model.jinaAi.rerank;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;

/**
 * Request body for the rerank call declared in {@code JinaApi}.
 * Lombok generates the builder, both constructors and the accessors.
 */
@Builder
@NoArgsConstructor
@AllArgsConstructor
@Data
class RerankRequest {

// Name of the reranker model to use.
private String model;
// Query the documents are ranked against.
private String query;
// Texts of the documents to rank.
private List<String> documents;
}
@@ -0,0 +1,15 @@
package dev.langchain4j.model.jinaAi.rerank;

import lombok.Getter;
import lombok.NoArgsConstructor;

import java.util.List;

/**
 * Response body of the rerank call: the model name, usage information and the ranked results.
 */
class RerankResponse {

    private String model;
    private Usage usage;
    private List<Result> results;

    /** Creates an empty response; fields are left at {@code null}. */
    public RerankResponse() {
    }

    /** @return the model name reported in the response */
    public String getModel() {
        return this.model;
    }

    /** @return the usage information reported in the response */
    public Usage getUsage() {
        return this.usage;
    }

    /** @return the individual rerank results */
    public List<Result> getResults() {
        return this.results;
    }
}

0 comments on commit 59d7af8

Please sign in to comment.