Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jina AI Embedding model integration #997

Merged
merged 14 commits into from
May 22, 2024
64 changes: 64 additions & 0 deletions langchain4j-jina-ai/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-aggregator</artifactId>
<version>0.30.0</version>
</parent>

<artifactId>langchain4j-jina-ai</artifactId>
lucifer-Hell marked this conversation as resolved.
Show resolved Hide resolved

<properties>
<maven.compiler.source>22</maven.compiler.source>
lucifer-Hell marked this conversation as resolved.
Show resolved Hide resolved
<maven.compiler.target>22</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>0.30.0</version>
lucifer-Hell marked this conversation as resolved.
Show resolved Hide resolved
<scope>compile</scope>
lucifer-Hell marked this conversation as resolved.
Show resolved Hide resolved
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.30</version>
lucifer-Hell marked this conversation as resolved.
Show resolved Hide resolved
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>retrofit</artifactId>
<version>2.9.0</version>
lucifer-Hell marked this conversation as resolved.
Show resolved Hide resolved
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>converter-gson</artifactId>
<version>2.9.0</version>
lucifer-Hell marked this conversation as resolved.
Show resolved Hide resolved
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>nimbus-jose-jwt</artifactId>
lucifer-Hell marked this conversation as resolved.
Show resolved Hide resolved
<version>9.37.3</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>testcontainers</artifactId>
lucifer-Hell marked this conversation as resolved.
Show resolved Hide resolved
<version>1.19.7</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>3.24.2</version>
<scope>test</scope>
</dependency>
</dependencies>

</project>
lucifer-Hell marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package dev.langchain4j.model.jinaAi;

import lombok.Builder;

import java.util.List;
@Builder
public class EmbeddingRequest {
String model;
List<String> input;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package dev.langchain4j.model.jinaAi;

import lombok.Data;

import java.util.List;
@Data
public class EmbeddingResponse {
Usage usage;
List<JinaAiEmbedding> data;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package dev.langchain4j.model.jinaAi;

import retrofit2.Call;
import retrofit2.http.Body;
import retrofit2.http.Header;
import retrofit2.http.Headers;
import retrofit2.http.POST;

public interface JinaAiApi {
@POST("v1/embeddings")
@Headers({"Content-Type: application/json"})
Call<EmbeddingResponse> embed(@Body EmbeddingRequest request, @Header("Authorization") String authorizationHeader);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package dev.langchain4j.model.jinaAi;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import lombok.Builder;
import okhttp3.OkHttpClient;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;

import java.io.IOException;
import java.time.Duration;

import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;

public class JinaAiClient {
private static final Gson GSON = new GsonBuilder()
lucifer-Hell marked this conversation as resolved.
Show resolved Hide resolved
.setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
.setPrettyPrinting()
.create();

private final JinaAiApi jinaAiApi;
private final String authorizationHeader;

@Builder
JinaAiClient(String baseUrl, String apiKey, Duration timeout){
OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder()
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
.writeTimeout(timeout);
Retrofit retrofit = new Retrofit.Builder()
.baseUrl(baseUrl)
.client(okHttpClientBuilder.build())
.addConverterFactory(GsonConverterFactory.create(GSON))
.build();


this.jinaAiApi= retrofit.create(JinaAiApi.class);
this.authorizationHeader = "Bearer " + ensureNotBlank(apiKey, "apiKey");
}

public EmbeddingResponse embed(EmbeddingRequest request) {
try {
retrofit2.Response<EmbeddingResponse> retrofitResponse
= jinaAiApi.embed(request, authorizationHeader).execute();

if (retrofitResponse.isSuccessful()) {
return retrofitResponse.body();
} else {
throw toException(retrofitResponse);
}
} catch (IOException e) {
throw new RuntimeException(e);
}
}



private static RuntimeException toException(retrofit2.Response<?> response) throws IOException {
int code = response.code();
String body = response.errorBody().string();
String errorMessage = String.format("status code: %s; body: %s", code, body);
return new RuntimeException(errorMessage);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package dev.langchain4j.model.jinaAi;

import dev.langchain4j.data.embedding.Embedding;

import java.util.List;

public class JinaAiEmbedding {
long index;
float[] embedding;
String object;

public Embedding toEmbedding(){
return Embedding.from(embedding);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package dev.langchain4j.model.jinaAi;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;

import java.time.Duration;
import java.util.List;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static java.time.Duration.ofSeconds;
import static java.util.stream.Collectors.toList;

/**
* An integration with Nomic Atlas's Text Embeddings API.
* See more details <a href="https://api.jina.ai/redoc#tag/embeddings">Jina API reference</a>
*/

public class JinaAiEmbeddingModel implements EmbeddingModel {


private static final String DEFAULT_BASE_URL = "https://api.jina.ai/";

private final JinaAiClient client;
private final String modelName;
private final Integer maxRetries;

@Builder
public JinaAiEmbeddingModel(String baseUrl,
String apiKey,
String modelName,
Duration timeout,
Integer maxRetries) {
this.client = JinaAiClient.builder()
.baseUrl(getOrDefault(baseUrl,DEFAULT_BASE_URL))
.apiKey(apiKey)
.timeout(getOrDefault(timeout, ofSeconds(60)))
.build();
this.modelName = getOrDefault(modelName, "jina-embeddings-v2-base-en");
this.maxRetries = getOrDefault(maxRetries, 3);
}

public static JinaAiEmbeddingModel withApiKey(String apiKey) {
return JinaAiEmbeddingModel.builder().apiKey(apiKey).build();
}


@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
EmbeddingRequest request = EmbeddingRequest.builder()
.model(modelName)
.input(textSegments.stream().map(TextSegment::text).collect(toList()))
.build();

EmbeddingResponse response = withRetry(() -> client.embed(request), maxRetries);

List<Embedding> embeddings = response.getData().stream()
.map(JinaAiEmbedding::toEmbedding).collect(toList());

TokenUsage tokenUsage = new TokenUsage(response.getUsage().getPromptTokens(),0 );
return Response.from(embeddings,tokenUsage);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package dev.langchain4j.model.jinaAi;

import com.google.gson.annotations.SerializedName;
import lombok.Getter;

@Getter
class Usage {
@SerializedName("total_tokens")
private Integer totalTokens;
@SerializedName("prompt_tokens")
private Integer promptTokens;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package dev.langchain4j.model.jinaAi;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.embedding.CosineSimilarity;
import org.junit.Test;
lucifer-Hell marked this conversation as resolved.
Show resolved Hide resolved

import java.util.List;

import static java.time.Duration.ofSeconds;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;



public class JinaAiEmbeddingModelIT {
@Test
public void should_embed_single_text() {

// given
EmbeddingModel model = JinaAiEmbeddingModel.withApiKey(System.getenv("JINA_AI_API_KEY"));

String text = "hello";

// when
Response<Embedding> response = model.embed(text);

// then
assertThat(response.content().dimension()).isEqualTo(768);

assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(3);
assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(0);
assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(3);
}

@Test
public void should_embed_multiple_segments() {

// given
EmbeddingModel model = JinaAiEmbeddingModel.builder()
.baseUrl("https://api.jina.ai/")
.apiKey(System.getenv("JINA_AI_API_KEY"))
.modelName("jina-embeddings-v2-base-en")
.timeout(ofSeconds(10))
.maxRetries(2)
.build();

TextSegment segment1 = TextSegment.from("hello");
TextSegment segment2 = TextSegment.from("hi");

// when
Response<List<Embedding>> response = model.embedAll(asList(segment1, segment2));

// then
assertThat(response.content()).hasSize(2);

Embedding embedding1 = response.content().get(0);
assertThat(embedding1.dimension()).isEqualTo(768);

Embedding embedding2 = response.content().get(1);
assertThat(embedding2.dimension()).isEqualTo(768);

assertThat(CosineSimilarity.between(embedding1, embedding2)).isGreaterThan(0.9);

assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(6);
assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(0);
assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(6);
}
}
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
<module>langchain4j-vertex-ai</module>
<module>langchain4j-vertex-ai-gemini</module>
<module>langchain4j-zhipu-ai</module>
<module>langchain4j-jina-ai</module>

<!-- embedding stores -->
<module>langchain4j-azure-ai-search</module>
Expand Down