Skip to content

Commit

Permalink
langchain4j#435: Add metadata support (read/write) to pinecone embedd…
Browse files Browse the repository at this point in the history
…ed store
  • Loading branch information
rgrebski committed Apr 17, 2024
1 parent 1b3b9f2 commit d5b67ca
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ void should_add_embedding_with_segment_with_metadata() {
}

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 1);
assertThat(relevant).hasSize(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ void should_add_embedding() {
assertThat(id).isNotBlank();

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
assertThat(relevant).hasSize(1);
Expand All @@ -68,6 +69,7 @@ void should_add_embedding_with_id() {
embeddingStore().add(id, embedding);

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
assertThat(relevant).hasSize(1);
Expand Down Expand Up @@ -95,6 +97,7 @@ void should_add_embedding_with_segment() {
assertThat(id).isNotBlank();

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
assertThat(relevant).hasSize(1);
Expand Down Expand Up @@ -125,6 +128,7 @@ void should_add_multiple_embeddings() {
assertThat(ids.get(0)).isNotEqualTo(ids.get(1));

awaitUntilPersisted();
awaitUntilPersisted(firstEmbedding, 2);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
assertThat(relevant).hasSize(2);
Expand Down Expand Up @@ -171,6 +175,7 @@ void should_add_multiple_embeddings_with_segments() {
assertThat(ids.get(0)).isNotEqualTo(ids.get(1));

awaitUntilPersisted();
awaitUntilPersisted(firstEmbedding, 2);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
assertThat(relevant).hasSize(2);
Expand Down Expand Up @@ -210,6 +215,7 @@ void should_find_with_min_score() {
embeddingStore().add(secondId, secondEmbedding);

awaitUntilPersisted();
awaitUntilPersisted(firstEmbedding, 2);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
assertThat(relevant).hasSize(2);
Expand Down Expand Up @@ -282,6 +288,7 @@ void should_return_correct_score() {
assertThat(id).isNotBlank();

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);

Embedding referenceEmbedding = embeddingModel().embed("hi").content();

Expand All @@ -304,4 +311,8 @@ void should_return_correct_score() {
/**
 * Hook invoked by the tests after embeddings have been added and before they are queried.
 * The default implementation returns immediately; subclasses may override it to wait.
 */
protected void awaitUntilPersisted() {
    // intentionally empty: no waiting by default
}

/**
 * Hook invoked by the tests after embeddings have been added and before they are queried.
 * The default implementation returns immediately; subclasses may override it to wait.
 *
 * @param firstEmbedding the embedding the tests pass in (the one they subsequently search with)
 * @param expectedSize   the number of entries the tests expect to find afterwards
 */
protected void awaitUntilPersisted(Embedding firstEmbedding, int expectedSize) {
    // intentionally empty: no waiting by default
}
}
9 changes: 8 additions & 1 deletion langchain4j-pinecone/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>


<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
Expand All @@ -91,4 +98,4 @@

</dependencies>

</project>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.CosineSimilarity;
Expand All @@ -13,10 +14,10 @@
import io.pinecone.PineconeConnection;
import io.pinecone.PineconeConnectionConfig;
import io.pinecone.proto.*;
import io.pinecone.proto.Vector;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.*;
import java.util.stream.Collectors;

import static dev.langchain4j.internal.Utils.randomUUID;
import static java.util.Collections.emptyList;
Expand All @@ -41,13 +42,13 @@ public class PineconeEmbeddingStore implements EmbeddingStore<TextSegment> {
/**
* Creates an instance of PineconeEmbeddingStore.
*
* @param apiKey The Pinecone API key.
* @param environment The environment (e.g., "northamerica-northeast1-gcp").
* @param projectId The ID of the project (e.g., "19a129b"). This is <b>not</b> a project name.
* The ID can be found in the Pinecone URL: https://app.pinecone.io/organizations/.../projects/...:{projectId}/indexes.
* @param index The name of the index (e.g., "test").
* @param nameSpace (Optional) Namespace. If not provided, "default" will be used.
* @param metadataTextKey (Optional) The key to find the text in the metadata. If not provided, "text_segment" will be used.
* @param apiKey The Pinecone API key.
* @param environment The environment (e.g., "northamerica-northeast1-gcp").
* @param projectId The ID of the project (e.g., "19a129b"). This is <b>not</b> a project name.
* The ID can be found in the Pinecone URL: {@code https://app.pinecone.io/organizations/.../projects/...:{projectId}/indexes}.
* @param index The name of the index (e.g., "test").
* @param nameSpace (Optional) Namespace. If not provided, "default" will be used.
* @param metadataTextKey (Optional) The key to find the text in the metadata. If not provided, "text_segment" will be used.
*/
public PineconeEmbeddingStore(String apiKey,
String environment,
Expand Down Expand Up @@ -136,12 +137,16 @@ private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<T
vectorBuilder.setMetadata(Struct.newBuilder()
.putFields(metadataTextKey, Value.newBuilder()
.setStringValue(textSegments.get(i).text())
.build()));
.build())
.putAllFields(textSegments.get(i).metadata().asMap().entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> Value.newBuilder().setStringValue(e.getValue()).build()))));
}

upsertRequestBuilder.addVectors(vectorBuilder.build());
}

//noinspection ResultOfMethodCallIgnored
connection.getBlockingStub().upsert(upsertRequestBuilder.build());
}

Expand Down Expand Up @@ -184,22 +189,47 @@ public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbeddi
return matches;
}


/**
 * Converts a Pinecone {@link Vector} into an {@link EmbeddingMatch}.
 * The score is derived from the cosine similarity between the vector's values and
 * {@code referenceEmbedding}. The text is read from the metadata field named by
 * {@code metadataTextKey}; all other metadata fields become the segment's {@link Metadata}.
 *
 * @param vector             the vector whose id, values and metadata are used
 * @param referenceEmbedding the embedding the cosine similarity is computed against
 * @return the match; its embedded {@link TextSegment} is {@code null} when the vector's
 *         metadata has no field under {@code metadataTextKey}
 */
private EmbeddingMatch<TextSegment> toEmbeddingMatch(Vector vector, Embedding referenceEmbedding) {
Value textSegmentValue = vector.getMetadata()
Struct metadataStruct = vector.getMetadata();

// The segment text is stored in the metadata under metadataTextKey.
Value textSegmentValue = metadataStruct
.getFieldsMap()
.get(metadataTextKey);

// Exclude the text field itself so it is not duplicated as a metadata entry.
boolean filterOutMetadataTextKey = true;
Map<String, String> metadataMap = structToMap(metadataStruct, filterOutMetadataTextKey);
Metadata metadata = Metadata.from(metadataMap);

Embedding embedding = Embedding.from(vector.getValuesList());
double cosineSimilarity = CosineSimilarity.between(embedding, referenceEmbedding);

return new EmbeddingMatch<>(
RelevanceScore.fromCosineSimilarity(cosineSimilarity),
vector.getId(),
embedding,
textSegmentValue == null ? null : TextSegment.from(textSegmentValue.getStringValue())
textSegmentValue == null ? null : TextSegment.from(textSegmentValue.getStringValue(), metadata)
);
}

/**
 * Flattens the fields of a protobuf {@link Struct} into a {@code String -> String} map.
 * String values are copied as-is; number and boolean values are converted to their textual
 * form instead of being silently read back as empty strings. Any other kind of value
 * (null, list, nested struct) keeps the previous behaviour and maps to
 * {@link Value#getStringValue()}.
 *
 * @param struct                   the struct whose fields are read
 * @param filterOutMetadataTextKey when {@code true}, the field recognised by
 *                                 {@code isMetadataTextKey} is left out of the result
 * @return a new mutable map with one entry per retained field
 */
private Map<String, String> structToMap(Struct struct, boolean filterOutMetadataTextKey) {
    Map<String, String> result = new HashMap<>();

    for (Map.Entry<String, Value> entry : struct.getFieldsMap().entrySet()) {
        String key = entry.getKey();
        if (filterOutMetadataTextKey && isMetadataTextKey(key)) {
            continue;
        }
        result.put(key, valueToString(entry.getValue()));
    }

    return result;
}

/**
 * Renders a single protobuf {@link Value} as text according to the kind of value it holds.
 */
private static String valueToString(Value value) {
    switch (value.getKindCase()) {
        case NUMBER_VALUE:
            double number = value.getNumberValue();
            // Struct stores every number as a double; print whole numbers without a trailing ".0".
            boolean whole = number == Math.rint(number) && Math.abs(number) < 1e15;
            return whole ? Long.toString((long) number) : Double.toString(number);
        case BOOL_VALUE:
            return Boolean.toString(value.getBoolValue());
        default:
            // STRING_VALUE returns the text; the remaining kinds fall back to the accessor's default.
            return value.getStringValue();
    }
}

/**
 * Tells whether {@code key} is the metadata field configured to hold the segment text.
 *
 * @param key the metadata field name to check
 * @return {@code true} if {@code key} equals {@code metadataTextKey}
 */
private boolean isMetadataTextKey(String key) {
    final boolean isTextKey = metadataTextKey.equals(key);
    return isTextKey;
}

/**
 * Creates a new, empty {@link Builder}.
 *
 * @return a fresh builder instance
 */
public static Builder builder() {
    final Builder freshBuilder = new Builder();
    return freshBuilder;
}
Expand Down Expand Up @@ -231,7 +261,7 @@ public Builder environment(String environment) {

/**
* @param projectId The ID of the project (e.g., "19a129b"). This is <b>not</b> a project name.
* The ID can be found in the Pinecone URL: https://app.pinecone.io/organizations/.../projects/...:{projectId}/indexes.
* The ID can be found in the Pinecone URL: {@code https://app.pinecone.io/organizations/.../projects/...:{projectId}/indexes}.
*/
public Builder projectId(String projectId) {
this.projectId = projectId;
Expand All @@ -257,6 +287,7 @@ public Builder nameSpace(String nameSpace) {
/**
* @param metadataTextKey (Optional) The key to find the text in the metadata. If not provided, "text_segment" will be used.
*/
@SuppressWarnings("unused")
public Builder metadataTextKey(String metadataTextKey) {
this.metadataTextKey = metadataTextKey;
return this;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
package dev.langchain4j.store.embedding.pinecone;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithoutMetadataIT;
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import java.time.Duration;

import static dev.langchain4j.internal.Utils.randomUUID;
import static org.awaitility.Awaitility.await;

@EnabledIfEnvironmentVariable(named = "PINECONE_API_KEY", matches = ".+")
class PineconeEmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT {
class PineconeEmbeddingStoreIT extends EmbeddingStoreIT {

EmbeddingStore<TextSegment> embeddingStore = PineconeEmbeddingStore.builder()
.apiKey(System.getenv("PINECONE_API_KEY"))
Expand All @@ -31,4 +35,11 @@ protected EmbeddingStore<TextSegment> embeddingStore() {
protected EmbeddingModel embeddingModel() {
return embeddingModel;
}
}

/**
 * Polls the store until a search for {@code embedding} limited to {@code expectedSize}
 * results returns exactly {@code expectedSize} matches, giving up after 15 seconds.
 *
 * @param embedding    the embedding to search with
 * @param expectedSize the number of matches to wait for
 */
@Override
protected void awaitUntilPersisted(Embedding embedding, int expectedSize) {
    Duration maxWait = Duration.ofSeconds(15);
    await()
            .timeout(maxWait)
            .until(() -> {
                int found = embeddingStore.findRelevant(embedding, expectedSize).size();
                return found == expectedSize;
            });
}
}

0 comments on commit d5b67ca

Please sign in to comment.