Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jina AI Embedding model integration #997

Merged
merged 14 commits into from
May 22, 2024
60 changes: 60 additions & 0 deletions langchain4j-jina/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-parent</artifactId>
        <version>0.31.0-SNAPSHOT</version>
        <relativePath>../langchain4j-parent/pom.xml</relativePath>
    </parent>

    <artifactId>langchain4j-jina</artifactId>
    <name>LangChain4j :: Integration :: Jina</name>

    <!-- Dependency versions are not declared here; they are expected to be managed by langchain4j-parent. -->
    <dependencies>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-core</artifactId>
        </dependency>

        <!-- Lombok is only needed at compile time (annotation processing), so it must not be
             propagated to consumers of this module. -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <scope>provided</scope>
        </dependency>

        <dependency>
            <groupId>com.squareup.retrofit2</groupId>
            <artifactId>retrofit</artifactId>
        </dependency>
        <dependency>
            <groupId>com.squareup.retrofit2</groupId>
            <artifactId>converter-gson</artifactId>
        </dependency>

        <!-- Test-only dependencies. AssertJ is used exclusively by the integration tests, so it is
             scoped to test to keep it off the classpath of downstream projects. -->
        <dependency>
            <groupId>org.assertj</groupId>
            <artifactId>assertj-core</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter-engine</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter-params</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.tinylog</groupId>
            <artifactId>tinylog-impl</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.tinylog</groupId>
            <artifactId>slf4j-tinylog</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package dev.langchain4j.model.jina;

import lombok.Builder;

import java.util.List;
/**
 * Request body sent to the Jina embeddings endpoint.
 * Instances are created through the Lombok-generated builder.
 */
@Builder
public class EmbeddingRequest {
// Name of the embedding model to use.
String model;
// Texts to embed.
List<String> input;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package dev.langchain4j.model.jina;

import lombok.Data;

import java.util.List;
/**
 * Response body returned by the Jina embeddings endpoint.
 * Accessors are generated by Lombok's {@code @Data}.
 */
@Data
public class EmbeddingResponse {
// Token accounting reported for the request.
Usage usage;
// Returned embedding entries; presumably one per input text, in input order -- TODO confirm against the API reference.
List<JinaEmbedding> data;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package dev.langchain4j.model.jina;

import retrofit2.Call;
import retrofit2.http.Body;
import retrofit2.http.Header;
import retrofit2.http.Headers;
import retrofit2.http.POST;

/**
 * Retrofit declaration of the Jina HTTP API used by this module.
 */
public interface JinaApi {
/**
 * POSTs an embedding request as JSON to {@code v1/embeddings}, relative to the Retrofit base URL.
 *
 * @param request             body of the request
 * @param authorizationHeader value sent in the {@code Authorization} header
 * @return a call that yields the parsed {@link EmbeddingResponse} when executed
 */
@POST("v1/embeddings")
@Headers({"Content-Type: application/json"})
Call<EmbeddingResponse> embed(@Body EmbeddingRequest request, @Header("Authorization") String authorizationHeader);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package dev.langchain4j.model.jina;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import lombok.Builder;
import okhttp3.OkHttpClient;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;

import java.io.IOException;
import java.time.Duration;

import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;

/**
 * Synchronous HTTP client for the Jina API, built on OkHttp and Retrofit with Gson (de)serialization.
 */
public class JinaClient {

    // Shared Gson instance: snake_case field naming, pretty-printed output.
    private static final Gson GSON = new GsonBuilder()
            .setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
            .setPrettyPrinting()
            .create();

    private final JinaApi jinaApi;
    private final String authorizationHeader;

    /**
     * Creates a client; exposed through the Lombok-generated builder.
     *
     * @param baseUrl base URL the API paths are resolved against
     * @param apiKey  API key sent as a bearer token; must not be blank
     * @param timeout single duration applied to the call, connect, read and write timeouts
     */
    @Builder
    JinaClient(String baseUrl, String apiKey, Duration timeout) {
        // The same timeout is used for every phase of the HTTP exchange.
        OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder();
        httpClientBuilder.callTimeout(timeout);
        httpClientBuilder.connectTimeout(timeout);
        httpClientBuilder.readTimeout(timeout);
        httpClientBuilder.writeTimeout(timeout);

        Retrofit.Builder retrofitBuilder = new Retrofit.Builder();
        retrofitBuilder.baseUrl(baseUrl);
        retrofitBuilder.client(httpClientBuilder.build());
        retrofitBuilder.addConverterFactory(GsonConverterFactory.create(GSON));
        Retrofit retrofit = retrofitBuilder.build();

        this.jinaApi = retrofit.create(JinaApi.class);
        this.authorizationHeader = "Bearer " + ensureNotBlank(apiKey, "apiKey");
    }

    /**
     * Executes the embedding request synchronously.
     *
     * @param request the request to send
     * @return the parsed response body when the HTTP call succeeds
     * @throws RuntimeException when the server answers with a non-successful status (the message
     *                          carries the status code and error body) or when an I/O error occurs
     *                          (the {@link IOException} is kept as the cause)
     */
    public EmbeddingResponse embed(EmbeddingRequest request) {
        try {
            retrofit2.Response<EmbeddingResponse> httpResponse =
                    jinaApi.embed(request, authorizationHeader).execute();

            if (!httpResponse.isSuccessful()) {
                throw toException(httpResponse);
            }
            return httpResponse.body();
        } catch (IOException ioFailure) {
            throw new RuntimeException(ioFailure);
        }
    }

    /**
     * Builds an exception describing a failed HTTP response.
     *
     * @param response the non-successful response
     * @return an exception whose message contains the status code and the error body text
     * @throws IOException if the error body cannot be read
     */
    private static RuntimeException toException(retrofit2.Response<?> response) throws IOException {
        int statusCode = response.code();
        String errorBodyText = response.errorBody().string();
        return new RuntimeException(String.format("status code: %s; body: %s", statusCode, errorBodyText));
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package dev.langchain4j.model.jina;

import dev.langchain4j.data.embedding.Embedding;

/**
 * One embedding entry from the {@code data} list of an {@link EmbeddingResponse}.
 */
public class JinaEmbedding {
// NOTE(review): presumably the position of the corresponding input text -- confirm against the API reference.
long index;
// Raw embedding vector.
float[] embedding;
// NOTE(review): presumably the API's object-type tag -- confirm against the API reference.
String object;

/**
 * Converts this entry into a LangChain4j {@link Embedding} built from the vector.
 *
 * @return the embedding created from {@code embedding}
 */
public Embedding toEmbedding(){
return Embedding.from(embedding);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package dev.langchain4j.model.jina;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;

import java.time.Duration;
import java.util.List;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static java.time.Duration.ofSeconds;
import static java.util.stream.Collectors.toList;

/**
 * An integration with Jina AI's Embeddings API.
 * See more details in the <a href="https://api.jina.ai/redoc#tag/embeddings">Jina API reference</a>.
 */
public class JinaEmbeddingModel implements EmbeddingModel {

    private static final String DEFAULT_BASE_URL = "https://api.jina.ai/";

    private final JinaClient client;
    private final String modelName;
    private final Integer maxRetries;

    /**
     * Creates the model; also exposed through the Lombok-generated builder.
     *
     * @param baseUrl    base URL of the API; defaults to {@code https://api.jina.ai/} when {@code null}
     * @param apiKey     API key; must not be blank
     * @param modelName  embedding model name; defaults to {@code jina-embeddings-v2-base-en} when {@code null}
     * @param timeout    HTTP timeout; defaults to 60 seconds when {@code null}
     * @param maxRetries value passed to the retry helper for each request; defaults to 3 when {@code null}
     */
    @Builder
    public JinaEmbeddingModel(String baseUrl,
                              String apiKey,
                              String modelName,
                              Duration timeout,
                              Integer maxRetries) {
        this.client = JinaClient.builder()
                .baseUrl(getOrDefault(baseUrl, DEFAULT_BASE_URL))
                .apiKey(apiKey)
                .timeout(getOrDefault(timeout, ofSeconds(60)))
                .build();
        this.modelName = getOrDefault(modelName, "jina-embeddings-v2-base-en");
        this.maxRetries = getOrDefault(maxRetries, 3);
    }

    /**
     * Creates a model with the given API key and default values for every other setting.
     *
     * @param apiKey API key; must not be blank
     * @return a model configured with defaults
     */
    public static JinaEmbeddingModel withApiKey(String apiKey) {
        return JinaEmbeddingModel.builder().apiKey(apiKey).build();
    }

    /**
     * Embeds the text of all given segments in a single API request.
     *
     * @param textSegments segments whose text is embedded
     * @return the embeddings in the order returned by the API, together with the token usage
     */
    @Override
    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
        EmbeddingRequest request = EmbeddingRequest.builder()
                .model(modelName)
                .input(textSegments.stream().map(TextSegment::text).collect(toList()))
                .build();

        EmbeddingResponse response = withRetry(() -> client.embed(request), maxRetries);

        List<Embedding> embeddings = response.getData().stream()
                .map(JinaEmbedding::toEmbedding)
                .collect(toList());

        // Embedding calls generate no output tokens, so only the prompt tokens are reported.
        TokenUsage tokenUsage = new TokenUsage(response.getUsage().getPromptTokens(), 0);
        return Response.from(embeddings, tokenUsage);
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package dev.langchain4j.model.jina;

import com.google.gson.annotations.SerializedName;
import lombok.Getter;

/**
 * Token usage section of an {@link EmbeddingResponse}; getters are generated by Lombok.
 */
@Getter
class Usage {
// Mapped from the JSON key "total_tokens".
@SerializedName("total_tokens")
private Integer totalTokens;
// Mapped from the JSON key "prompt_tokens".
@SerializedName("prompt_tokens")
private Integer promptTokens;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package dev.langchain4j.model.jina;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.embedding.CosineSimilarity;
import org.junit.jupiter.api.Test;

import java.util.List;

import static java.time.Duration.ofSeconds;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;



/**
 * Integration tests for {@link JinaEmbeddingModel}. They call the live API using the key read
 * from the {@code JINA_AI_API_KEY} environment variable.
 */
public class JinaEmbeddingModelIT {

    @Test
    public void should_embed_single_text() {

        // given
        String apiKey = System.getenv("JINA_AI_API_KEY");
        EmbeddingModel embeddingModel = JinaEmbeddingModel.withApiKey(apiKey);

        // when
        Response<Embedding> result = embeddingModel.embed("hello");

        // then
        Embedding embedding = result.content();
        assertThat(embedding.dimension()).isEqualTo(768);

        assertThat(result.tokenUsage().inputTokenCount()).isEqualTo(3);
        assertThat(result.tokenUsage().outputTokenCount()).isEqualTo(0);
        assertThat(result.tokenUsage().totalTokenCount()).isEqualTo(3);
    }

    @Test
    public void should_embed_multiple_segments() {

        // given
        String apiKey = System.getenv("JINA_AI_API_KEY");
        EmbeddingModel embeddingModel = JinaEmbeddingModel.builder()
                .baseUrl("https://api.jina.ai/")
                .apiKey(apiKey)
                .modelName("jina-embeddings-v2-base-en")
                .timeout(ofSeconds(10))
                .maxRetries(2)
                .build();

        List<TextSegment> segments = asList(TextSegment.from("hello"), TextSegment.from("hi"));

        // when
        Response<List<Embedding>> result = embeddingModel.embedAll(segments);

        // then
        List<Embedding> embeddings = result.content();
        assertThat(embeddings).hasSize(2);

        Embedding first = embeddings.get(0);
        Embedding second = embeddings.get(1);
        assertThat(first.dimension()).isEqualTo(768);
        assertThat(second.dimension()).isEqualTo(768);

        assertThat(CosineSimilarity.between(first, second)).isGreaterThan(0.9);

        assertThat(result.tokenUsage().inputTokenCount()).isEqualTo(6);
        assertThat(result.tokenUsage().outputTokenCount()).isEqualTo(0);
        assertThat(result.tokenUsage().totalTokenCount()).isEqualTo(6);
    }
}
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
<module>langchain4j-vertex-ai</module>
<module>langchain4j-vertex-ai-gemini</module>
<module>langchain4j-zhipu-ai</module>
<module>langchain4j-jina</module>

<!-- embedding stores -->
<module>langchain4j-azure-ai-search</module>
Expand Down