Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#435: Add metadata support (read/write) to pinecone embedded store #955

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ void should_add_embedding_with_segment_with_metadata() {
}

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed to add this function to be able to use Awaitility with pinecone client to wait for records to be retrievable.
Pinecone is eventually consistent, so records are/may not be available just after write

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great stuff! What do you think about changing existing awaitUntilPersisted() instead of introducing new overloaded method? I guess the implementation can also be moved to the EmbeddingStoreIT so that all implementations can benefit from using awaitility?


List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 1);
assertThat(relevant).hasSize(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ void should_add_embedding() {
assertThat(id).isNotBlank();

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
assertThat(relevant).hasSize(1);
Expand All @@ -68,6 +69,7 @@ void should_add_embedding_with_id() {
embeddingStore().add(id, embedding);

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
assertThat(relevant).hasSize(1);
Expand Down Expand Up @@ -95,6 +97,7 @@ void should_add_embedding_with_segment() {
assertThat(id).isNotBlank();

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
assertThat(relevant).hasSize(1);
Expand Down Expand Up @@ -125,6 +128,7 @@ void should_add_multiple_embeddings() {
assertThat(ids.get(0)).isNotEqualTo(ids.get(1));

awaitUntilPersisted();
awaitUntilPersisted(firstEmbedding, 2);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
assertThat(relevant).hasSize(2);
Expand Down Expand Up @@ -171,6 +175,7 @@ void should_add_multiple_embeddings_with_segments() {
assertThat(ids.get(0)).isNotEqualTo(ids.get(1));

awaitUntilPersisted();
awaitUntilPersisted(firstEmbedding, 2);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
assertThat(relevant).hasSize(2);
Expand Down Expand Up @@ -210,6 +215,7 @@ void should_find_with_min_score() {
embeddingStore().add(secondId, secondEmbedding);

awaitUntilPersisted();
awaitUntilPersisted(firstEmbedding, 2);

List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(firstEmbedding, 10);
assertThat(relevant).hasSize(2);
Expand Down Expand Up @@ -282,6 +288,7 @@ void should_return_correct_score() {
assertThat(id).isNotBlank();

awaitUntilPersisted();
awaitUntilPersisted(embedding, 1);

Embedding referenceEmbedding = embeddingModel().embed("hi").content();

Expand All @@ -304,4 +311,8 @@ void should_return_correct_score() {
protected void awaitUntilPersisted() {
// not waiting by default
}

protected void awaitUntilPersisted(Embedding firstEmbedding, int expectedSize) {
// not waiting by default
}
}
28 changes: 9 additions & 19 deletions langchain4j-pinecone/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,7 @@
<dependency>
<groupId>io.pinecone</groupId>
<artifactId>pinecone-client</artifactId>
<version>0.6.0</version>
<exclusions>
<!-- CVE-2023-44487 -->
<exclusion>
<groupId>io.netty</groupId>
<artifactId>netty-codec-http2</artifactId>
</exclusion>
<!-- CVE-2023-3635 -->
<exclusion>
<groupId>com.squareup.okio</groupId>
<artifactId>okio-jvm</artifactId>
</exclusion>
<!-- CVE-2020-29582 -->
<exclusion>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-stdlib</artifactId>
</exclusion>
</exclusions>
<version>1.0.0</version>
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for migrating to 1.0.0! 🤗

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pinecone API 1.1.0 is out, it looks like its only added a new feature (pinecone-io/pinecone-java-client@v1.0.0...v1.1.0) and a couple of bug fixes, might be worth upgrading while we have eyes on it.

</dependency>
<dependency>
<groupId>io.netty</groupId>
Expand Down Expand Up @@ -83,6 +66,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>


<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
Expand All @@ -91,4 +81,4 @@

</dependencies>

</project>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,24 @@

import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import io.pinecone.PineconeClient;
import io.pinecone.PineconeClientConfig;
import io.pinecone.PineconeConnection;
import io.pinecone.PineconeConnectionConfig;
import io.pinecone.clients.Index;
import io.pinecone.clients.Pinecone;
import io.pinecone.proto.*;
import io.pinecone.proto.Vector;
import io.pinecone.unsigned_indices_model.QueryResponseWithUnsignedIndices;
import io.pinecone.unsigned_indices_model.ScoredVectorWithUnsignedIndices;
import io.pinecone.unsigned_indices_model.VectorWithUnsignedIndices;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.*;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static dev.langchain4j.internal.Utils.randomUUID;
import static java.util.Collections.emptyList;
Expand All @@ -34,41 +37,38 @@ public class PineconeEmbeddingStore implements EmbeddingStore<TextSegment> {
private static final String DEFAULT_NAMESPACE = "default"; // do not change, will break backward compatibility!
private static final String DEFAULT_METADATA_TEXT_KEY = "text_segment"; // do not change, will break backward compatibility!

private final PineconeConnection connection;
private final String nameSpace;
private final String metadataTextKey;
private final Index pineconeIndex;
private final Pinecone pinecone;
private final Consumer<List<VectorWithUnsignedIndices>> afterUpsertAction;

/**
* Creates an instance of PineconeEmbeddingStore.
*
* @param apiKey The Pinecone API key.
* @param environment The environment (e.g., "northamerica-northeast1-gcp").
* @param projectId The ID of the project (e.g., "19a129b"). This is <b>not</b> a project name.
* The ID can be found in the Pinecone URL: https://app.pinecone.io/organizations/.../projects/...:{projectId}/indexes.
* @param index The name of the index (e.g., "test").
* @param nameSpace (Optional) Namespace. If not provided, "default" will be used.
* @param apiKey The Pinecone API key.
* @param environment The environment (e.g., "northamerica-northeast1-gcp").
* @param projectId The ID of the project (e.g., "19a129b"). This is <b>not</b> a project name.
* The ID can be found in the Pinecone URL: <a href="https://app.pinecone.io/organizations/.../projects/">...</a>...:{projectId}/indexes.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change required?

* @param index The name of the index (e.g., "test").
* @param nameSpace (Optional) Namespace. If not provided, "default" will be used.
* @param metadataTextKey (Optional) The key to find the text in the metadata. If not provided, "text_segment" will be used.
* @param afterUpsertAction
*/
public PineconeEmbeddingStore(String apiKey,
String environment,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that environment and projectId are not used any more. Does Pinecone resolve them from api key now?
I think we should mark there params as @Deprecated in the builder then, WDYT?

String projectId,
String index,
String nameSpace,
String metadataTextKey) {
String metadataTextKey,
Consumer<List<VectorWithUnsignedIndices>> afterUpsertAction) {

PineconeClientConfig configuration = new PineconeClientConfig()
.withApiKey(apiKey)
.withEnvironment(environment)
.withProjectName(projectId);

PineconeClient pineconeClient = new PineconeClient(configuration);

PineconeConnectionConfig connectionConfig = new PineconeConnectionConfig()
.withIndexName(index);

this.connection = pineconeClient.connect(connectionConfig);
this.pinecone = new Pinecone.Builder(apiKey).build();
this.pineconeIndex = pinecone.getIndexConnection(index);
this.nameSpace = nameSpace == null ? DEFAULT_NAMESPACE : nameSpace;
this.metadataTextKey = metadataTextKey == null ? DEFAULT_METADATA_TEXT_KEY : metadataTextKey;
this.afterUpsertAction = afterUpsertAction;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please explain why afterUpsertAction is needed? Seems to not be used anywhere, also no javadoc with explanation.
I would remove it if not necessary.

}

@Override
Expand Down Expand Up @@ -120,60 +120,49 @@ private void addInternal(String id, Embedding embedding, TextSegment textSegment

private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {

UpsertRequest.Builder upsertRequestBuilder = UpsertRequest.newBuilder()
.setNamespace(nameSpace);

List<VectorWithUnsignedIndices> vectorList = new ArrayList<>();
for (int i = 0; i < embeddings.size(); i++) {

String id = ids.get(i);
Embedding embedding = embeddings.get(i);

Vector.Builder vectorBuilder = Vector.newBuilder()
.setId(id)
.addAllValues(embedding.vectorAsList());

VectorWithUnsignedIndices vector = new VectorWithUnsignedIndices(id, embedding.vectorAsList());


if (textSegments != null) {
vectorBuilder.setMetadata(Struct.newBuilder()
Struct.Builder metadataStructBuilder = Struct.newBuilder()
.putFields(metadataTextKey, Value.newBuilder()
.setStringValue(textSegments.get(i).text())
.build()));
.build())
.putAllFields(textSegments.get(i).metadata().asMap().entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> Value.newBuilder().setStringValue(e.getValue()).build())));
vector.setMetadata(metadataStructBuilder.build());
}

upsertRequestBuilder.addVectors(vectorBuilder.build());
vectorList.add(vector);
}

connection.getBlockingStub().upsert(upsertRequestBuilder.build());
pineconeIndex.upsert(vectorList, nameSpace);
Optional.ofNullable(afterUpsertAction).ifPresent(action -> action.accept(vectorList));
}

@Override
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {

QueryRequest queryRequest = QueryRequest
.newBuilder()
.addAllVector(referenceEmbedding.vectorAsList())
.setNamespace(nameSpace)
.setTopK(maxResults)
.build();

List<String> matchedVectorIds = connection.getBlockingStub()
.query(queryRequest)
.getMatchesList()
.stream()
.map(ScoredVector::getId)
.collect(toList());

if (matchedVectorIds.isEmpty()) {
QueryResponseWithUnsignedIndices matchedVectorIds = pineconeIndex.queryByVector(maxResults, referenceEmbedding.vectorAsList(), nameSpace);

if (matchedVectorIds.getMatchesList().isEmpty()) {
return emptyList();
}
List<String> ids = matchedVectorIds.getMatchesList().stream()
.map(ScoredVectorWithUnsignedIndices::getId)
.collect(toList());

Collection<Vector> matchedVectors = connection.getBlockingStub().fetch(FetchRequest.newBuilder()
.addAllIds(matchedVectorIds)
.setNamespace(nameSpace)
.build())
.getVectorsMap()
.values();
FetchResponse fetchResponse = pineconeIndex.fetch(ids, nameSpace);

List<EmbeddingMatch<TextSegment>> matches = matchedVectors.stream()
List<EmbeddingMatch<TextSegment>> matches = fetchResponse.getVectorsMap().values().stream()
.map(vector -> toEmbeddingMatch(vector, referenceEmbedding))
.filter(match -> match.score() >= minScore)
.sorted(comparingDouble(EmbeddingMatch::score))
Expand All @@ -184,22 +173,47 @@ public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbeddi
return matches;
}


private EmbeddingMatch<TextSegment> toEmbeddingMatch(Vector vector, Embedding referenceEmbedding) {
Value textSegmentValue = vector.getMetadata()
Struct metadataStruct = vector.getMetadata();

Value textSegmentValue = metadataStruct
.getFieldsMap()
.get(metadataTextKey);

boolean filterOutMetadataTextKey = true;
Map<String, String> metadataMap = structToMap(metadataStruct, filterOutMetadataTextKey);
Metadata metadata = Metadata.from(metadataMap);

Embedding embedding = Embedding.from(vector.getValuesList());
double cosineSimilarity = CosineSimilarity.between(embedding, referenceEmbedding);

return new EmbeddingMatch<>(
RelevanceScore.fromCosineSimilarity(cosineSimilarity),
vector.getId(),
embedding,
textSegmentValue == null ? null : TextSegment.from(textSegmentValue.getStringValue())
textSegmentValue == null ? null : TextSegment.from(textSegmentValue.getStringValue(), metadata)
);
}

private Map<String, String> structToMap(Struct struct, boolean filterOutMetadataTextKey) {
Map<String, String> result = new HashMap<>();
Map<String, Value> fields = struct.getFieldsMap();

for (Map.Entry<String, Value> entry : fields.entrySet()) {
if (filterOutMetadataTextKey && isMetadataTextKey(entry.getKey())) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when retreiving I am filtering out metadataTextKey from metadata, otherwise tests were failing.
I assumed metadataTextKey is only a technical thing to let us store original content in metadata, but it is not something that should be exposed

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We store original embedded text under text_segment (or whatever is defined in metadataTextKey property) key in Pinecone metadata, and it should be returned back inside TextSegment.text().

The problem can happen in such case:
TextSegment textSegment = TextSegment.from("hello", new Metadata().put("text_segment", "bye"))
Since metadata key matches key of the text, text will be overriden and no metadata when retrieveing it back:
TextSegment { text = "bye" metadata = {} }

One option is to prepend all metadata keys with metadata_ prefix and then removing this prefix when retreiving back, so this TextSegment.from("hello", new Metadata().put("text_segment", "bye")) will become:

text_segment -> "hello"
metadata_text_segment -> "bye"

in Pinecone's metadata.

WDYT?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about just making sure your not going to have a key collision and throwing an exception if it happens?
My reasoning being is some people are going to have multiple systems feeding into their store, and they will have to go modify other systems to support this paradigm.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sfotex makes sense!

continue;
}
result.put(entry.getKey(), entry.getValue().getStringValue());
}

return result;
}

private boolean isMetadataTextKey(String key) {
return metadataTextKey.equals(key);
}

public static Builder builder() {
return new Builder();
}
Expand All @@ -212,6 +226,7 @@ public static class Builder {
private String index;
private String nameSpace;
private String metadataTextKey;
private Consumer<List<VectorWithUnsignedIndices>> afterUpsertAction;

/**
* @param apiKey The Pinecone API key.
Expand All @@ -231,7 +246,7 @@ public Builder environment(String environment) {

/**
* @param projectId The ID of the project (e.g., "19a129b"). This is <b>not</b> a project name.
* The ID can be found in the Pinecone URL: https://app.pinecone.io/organizations/.../projects/...:{projectId}/indexes.
* The ID can be found in the Pinecone URL: <a href="https://app.pinecone.io/organizations/.../projects/">...</a>...:{projectId}/indexes.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

*/
public Builder projectId(String projectId) {
this.projectId = projectId;
Expand All @@ -257,13 +272,20 @@ public Builder nameSpace(String nameSpace) {
/**
* @param metadataTextKey (Optional) The key to find the text in the metadata. If not provided, "text_segment" will be used.
*/
@SuppressWarnings("unused")
public Builder metadataTextKey(String metadataTextKey) {
this.metadataTextKey = metadataTextKey;
return this;
}

public PineconeEmbeddingStore build() {
return new PineconeEmbeddingStore(apiKey, environment, projectId, index, nameSpace, metadataTextKey);
return new PineconeEmbeddingStore(apiKey, environment, projectId, index, nameSpace, metadataTextKey, afterUpsertAction);
}

@SuppressWarnings("unused")
public Builder afterUpsertAction(Consumer<List<VectorWithUnsignedIndices>> onUpsertAction) {
this.afterUpsertAction = onUpsertAction;
return this;
}
}
}