forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding jina as scoring model. Feature request: langchain4j#974
Module added according to langchain4j#973 in order to minimize conflicts. The implementation is in line with the Cohere reranking integration.
- Loading branch information
Showing
14 changed files
with
514 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<project xmlns="http://maven.apache.org/POM/4.0.0" | ||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||
<modelVersion>4.0.0</modelVersion> | ||
<parent> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-parent</artifactId> | ||
<version>0.31.0-SNAPSHOT</version> | ||
<relativePath>../langchain4j-parent/pom.xml</relativePath> | ||
</parent> | ||
|
||
<artifactId>langchain4j-jina-ai</artifactId> | ||
<name>LangChain4j :: Integration :: Jina.ai</name> | ||
|
||
<dependencies> | ||
<dependency> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-core</artifactId> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>com.squareup.retrofit2</groupId> | ||
<artifactId>retrofit</artifactId> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>com.squareup.retrofit2</groupId> | ||
<artifactId>converter-jackson</artifactId> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>com.squareup.okhttp3</groupId> | ||
<artifactId>okhttp</artifactId> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.projectlombok</groupId> | ||
<artifactId>lombok</artifactId> | ||
<scope>provided</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.junit.jupiter</groupId> | ||
<artifactId>junit-jupiter-engine</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.junit.jupiter</groupId> | ||
<artifactId>junit-jupiter-params</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.assertj</groupId> | ||
<artifactId>assertj-core</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.tinylog</groupId> | ||
<artifactId>tinylog-impl</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.tinylog</groupId> | ||
<artifactId>slf4j-tinylog</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
</dependencies> | ||
|
||
</project> |
9 changes: 9 additions & 0 deletions
9
langchain4j-jina-ai/src/main/java/dev/langchain4j/model/jinaAi/rerank/Document.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
package dev.langchain4j.model.jinaAi.rerank; | ||
|
||
import lombok.Getter; | ||
|
||
@Getter | ||
public class Document { | ||
|
||
private String text; | ||
} |
14 changes: 14 additions & 0 deletions
14
langchain4j-jina-ai/src/main/java/dev/langchain4j/model/jinaAi/rerank/JinaApi.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package dev.langchain4j.model.jinaAi.rerank; | ||
|
||
import retrofit2.Call; | ||
import retrofit2.http.Body; | ||
import retrofit2.http.Header; | ||
import retrofit2.http.Headers; | ||
import retrofit2.http.POST; | ||
|
||
interface JinaApi { | ||
|
||
@POST("rerank") | ||
@Headers({"Content-Type: application/json"}) | ||
Call<RerankResponse> rerank(@Body RerankRequest request, @Header("Authorization") String authorizationHeader); | ||
} |
80 changes: 80 additions & 0 deletions
80
langchain4j-jina-ai/src/main/java/dev/langchain4j/model/jinaAi/rerank/JinaClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package dev.langchain4j.model.jinaAi.rerank; | ||
|
||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import lombok.Builder; | ||
import okhttp3.OkHttpClient; | ||
import retrofit2.Retrofit; | ||
import retrofit2.converter.jackson.JacksonConverterFactory; | ||
|
||
import java.io.IOException; | ||
import java.time.Duration; | ||
|
||
import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; | ||
import static com.fasterxml.jackson.databind.PropertyNamingStrategies.SNAKE_CASE; | ||
import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT; | ||
import static com.fasterxml.jackson.databind.cfg.EnumFeature.WRITE_ENUMS_TO_LOWERCASE; | ||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; | ||
|
||
class JinaClient { | ||
|
||
private static final ObjectMapper objectMapper = new ObjectMapper(); | ||
|
||
static { | ||
objectMapper.enable(INDENT_OUTPUT); | ||
objectMapper.setPropertyNamingStrategy(SNAKE_CASE); | ||
objectMapper.configure(WRITE_ENUMS_TO_LOWERCASE, true); | ||
objectMapper.setSerializationInclusion(NON_NULL); | ||
objectMapper.enable(INDENT_OUTPUT); | ||
} | ||
|
||
private final JinaApi jinaApi; | ||
private final String authorizationHeader; | ||
|
||
@Builder | ||
JinaClient(String baseUrl, String apiKey, Duration timeout, Boolean logRequests, Boolean logResponses) { | ||
|
||
OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder() | ||
.callTimeout(timeout) | ||
.connectTimeout(timeout) | ||
.readTimeout(timeout) | ||
.writeTimeout(timeout); | ||
|
||
if (logRequests) { | ||
okHttpClientBuilder.addInterceptor(new RequestLoggingInterceptor()); | ||
} | ||
if (logResponses) { | ||
okHttpClientBuilder.addInterceptor(new ResponseLoggingInterceptor()); | ||
} | ||
|
||
Retrofit retrofit = new Retrofit.Builder() | ||
.baseUrl(baseUrl) | ||
.client(okHttpClientBuilder.build()) | ||
.addConverterFactory(JacksonConverterFactory.create(objectMapper)) | ||
.build(); | ||
|
||
this.jinaApi = retrofit.create(JinaApi.class); | ||
this.authorizationHeader = "Bearer " + ensureNotBlank(apiKey, "apiKey"); | ||
} | ||
|
||
public RerankResponse rerank(RerankRequest request) { | ||
try { | ||
retrofit2.Response<RerankResponse> retrofitResponse | ||
= jinaApi.rerank(request, authorizationHeader).execute(); | ||
|
||
if (retrofitResponse.isSuccessful()) { | ||
return retrofitResponse.body(); | ||
} else { | ||
throw toException(retrofitResponse); | ||
} | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
private static RuntimeException toException(retrofit2.Response<?> response) throws IOException { | ||
int code = response.code(); | ||
String body = response.errorBody().string(); | ||
String errorMessage = String.format("status code: %s; body: %s", code, body); | ||
return new RuntimeException(errorMessage); | ||
} | ||
} |
82 changes: 82 additions & 0 deletions
82
langchain4j-jina-ai/src/main/java/dev/langchain4j/model/jinaAi/rerank/JinaScoringModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package dev.langchain4j.model.jinaAi.rerank; | ||
|
||
import dev.langchain4j.data.segment.TextSegment; | ||
import dev.langchain4j.model.output.Response; | ||
import dev.langchain4j.model.output.TokenUsage; | ||
import dev.langchain4j.model.scoring.ScoringModel; | ||
import lombok.Builder; | ||
|
||
import java.time.Duration; | ||
import java.util.List; | ||
|
||
import static dev.langchain4j.internal.RetryUtils.withRetry; | ||
import static dev.langchain4j.internal.Utils.getOrDefault; | ||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; | ||
import static java.time.Duration.ofSeconds; | ||
import static java.util.Comparator.comparingInt; | ||
import static java.util.stream.Collectors.toList; | ||
|
||
/** | ||
* An implementation of a {@link ScoringModel} that uses | ||
* <a href="https://jina.ai/reranker">Jina Rerank API</a>. | ||
*/ | ||
public class JinaScoringModel implements ScoringModel { | ||
|
||
private static final String DEFAULT_BASE_URL = "https://api.jina.ai/v1/"; | ||
/** | ||
* This is the leading Jina Reranker model | ||
*/ | ||
private static final String DEFAULT_MODEL = "jina-reranker-v1-base-en"; | ||
|
||
private final JinaClient client; | ||
private final String modelName; | ||
private final Integer maxRetries; | ||
|
||
@Builder | ||
public JinaScoringModel( | ||
String baseUrl, | ||
String apiKey, | ||
String modelName, | ||
Duration timeout, | ||
Integer maxRetries, | ||
Boolean logRequests, | ||
Boolean logResponses | ||
) { | ||
this.client = JinaClient.builder() | ||
.baseUrl(getOrDefault(baseUrl, DEFAULT_BASE_URL)) | ||
.apiKey(ensureNotBlank(apiKey, "apiKey")) | ||
.timeout(getOrDefault(timeout, ofSeconds(60))) | ||
.logRequests(getOrDefault(logRequests, false)) | ||
.logResponses(getOrDefault(logResponses, false)) | ||
.build(); | ||
this.modelName = getOrDefault(modelName, DEFAULT_MODEL); | ||
this.maxRetries = getOrDefault(maxRetries, 3); | ||
} | ||
|
||
public static JinaScoringModel withApiKey(String apiKey) { | ||
return JinaScoringModel.builder().apiKey(apiKey).build(); | ||
} | ||
|
||
|
||
|
||
@Override | ||
public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) { | ||
|
||
RerankRequest request = RerankRequest.builder() | ||
.model(modelName) | ||
.query(query) | ||
.documents(segments.stream() | ||
.map(TextSegment::text) | ||
.collect(toList())) | ||
.build(); | ||
|
||
RerankResponse response = withRetry(() -> client.rerank(request), maxRetries); | ||
|
||
List<Double> scores = response.getResults().stream() | ||
.sorted(comparingInt(Result::getIndex)) | ||
.map(Result::getRelevanceScore) | ||
.collect(toList()); | ||
|
||
return Response.from(scores, new TokenUsage(response.getUsage().getTotalTokens())); | ||
} | ||
} |
82 changes: 82 additions & 0 deletions
82
...-jina-ai/src/main/java/dev/langchain4j/model/jinaAi/rerank/RequestLoggingInterceptor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package dev.langchain4j.model.jinaAi.rerank; | ||
|
||
import okhttp3.Headers; | ||
import okhttp3.Interceptor; | ||
import okhttp3.Request; | ||
import okhttp3.Response; | ||
import okio.Buffer; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import java.io.IOException; | ||
import java.util.regex.Matcher; | ||
import java.util.regex.Pattern; | ||
import java.util.stream.Collectors; | ||
|
||
import static java.util.stream.StreamSupport.stream; | ||
|
||
class RequestLoggingInterceptor implements Interceptor { | ||
|
||
private static final Logger log = LoggerFactory.getLogger(RequestLoggingInterceptor.class); | ||
|
||
private static final Pattern BEARER_PATTERN = Pattern.compile("(Bearer\\s)(\\w{2})(\\w+)(\\w{2})"); | ||
|
||
public Response intercept(Interceptor.Chain chain) throws IOException { | ||
Request request = chain.request(); | ||
log(request); | ||
return chain.proceed(request); | ||
} | ||
|
||
private void log(Request request) { | ||
log.debug( | ||
"Request:\n" + | ||
"- method: {}\n" + | ||
"- url: {}\n" + | ||
"- headers: {}\n" + | ||
"- body: {}", | ||
request.method(), | ||
request.url(), | ||
inOneLine(request.headers()), | ||
getBody(request) | ||
); | ||
} | ||
|
||
static String inOneLine(Headers headers) { | ||
return stream(headers.spliterator(), false) | ||
.map((header) -> { | ||
String headerKey = header.component1(); | ||
String headerValue = header.component2(); | ||
if (headerKey.equals("Authorization")) { | ||
headerValue = maskAuthorizationHeaderValue(headerValue); | ||
} | ||
return String.format("[%s: %s]", headerKey, headerValue); | ||
}).collect(Collectors.joining(", ")); | ||
} | ||
|
||
private static String maskAuthorizationHeaderValue(String authorizationHeaderValue) { | ||
try { | ||
Matcher matcher = BEARER_PATTERN.matcher(authorizationHeaderValue); | ||
StringBuffer sb = new StringBuffer(); | ||
|
||
while (matcher.find()) { | ||
matcher.appendReplacement(sb, matcher.group(1) + matcher.group(2) + "..." + matcher.group(4)); | ||
} | ||
|
||
matcher.appendTail(sb); | ||
return sb.toString(); | ||
} catch (Exception e) { | ||
return "[failed to mask the API key]"; | ||
} | ||
} | ||
|
||
private static String getBody(Request request) { | ||
try { | ||
Buffer buffer = new Buffer(); | ||
request.body().writeTo(buffer); | ||
return buffer.readUtf8(); | ||
} catch (Exception e) { | ||
log.warn("Exception happened while reading request body", e); | ||
return "[Exception happened while reading request body. Check logs for more details.]"; | ||
} | ||
} | ||
} |
19 changes: 19 additions & 0 deletions
19
langchain4j-jina-ai/src/main/java/dev/langchain4j/model/jinaAi/rerank/RerankRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package dev.langchain4j.model.jinaAi.rerank; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Builder; | ||
import lombok.Data; | ||
import lombok.NoArgsConstructor; | ||
|
||
import java.util.List; | ||
|
||
@Builder | ||
@NoArgsConstructor | ||
@AllArgsConstructor | ||
@Data | ||
class RerankRequest { | ||
|
||
private String model; | ||
private String query; | ||
private List<String> documents; | ||
} |
15 changes: 15 additions & 0 deletions
15
langchain4j-jina-ai/src/main/java/dev/langchain4j/model/jinaAi/rerank/RerankResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package dev.langchain4j.model.jinaAi.rerank; | ||
|
||
import lombok.Getter; | ||
import lombok.NoArgsConstructor; | ||
|
||
import java.util.List; | ||
|
||
@NoArgsConstructor | ||
@Getter | ||
class RerankResponse { | ||
|
||
private String model; | ||
private Usage usage; | ||
private List<Result> results; | ||
} |
Oops, something went wrong.