Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 972: [FEATURE] Add name to dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever to improve logs #1007

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
* <br>
* Configurable parameters (optional):
* <br>
* - {@code displayName}: Display name for logging purposes, e.g. when multiple instances are used.
* <br>
* - {@code maxResults}: The maximum number of {@link Content}s to retrieve.
* <br>
* - {@code dynamicMaxResults}: It is a {@link Function} that accepts a {@link Query} and returns a {@code maxResults} value.
Expand All @@ -57,16 +59,21 @@ public class EmbeddingStoreContentRetriever implements ContentRetriever {
public static final Function<Query, Double> DEFAULT_MIN_SCORE = (query) -> 0.0;
public static final Function<Query, Filter> DEFAULT_FILTER = (query) -> null;

public static final String DEFAULT_DISPLAY_NAME = "Default";

private final EmbeddingStore<TextSegment> embeddingStore;
private final EmbeddingModel embeddingModel;

private final Function<Query, Integer> maxResultsProvider;
private final Function<Query, Double> minScoreProvider;
private final Function<Query, Filter> filterProvider;

private final String displayName;

public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore,
EmbeddingModel embeddingModel) {
this(
DEFAULT_DISPLAY_NAME,
embeddingStore,
embeddingModel,
DEFAULT_MAX_RESULTS,
Expand All @@ -79,6 +86,7 @@ public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore
EmbeddingModel embeddingModel,
int maxResults) {
this(
DEFAULT_DISPLAY_NAME,
embeddingStore,
embeddingModel,
(query) -> maxResults,
Expand All @@ -92,6 +100,7 @@ public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore
Integer maxResults,
Double minScore) {
this(
DEFAULT_DISPLAY_NAME,
embeddingStore,
embeddingModel,
(query) -> maxResults,
Expand All @@ -101,11 +110,13 @@ public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore
}

@Builder
private EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore,
private EmbeddingStoreContentRetriever(String displayName,
EmbeddingStore<TextSegment> embeddingStore,
EmbeddingModel embeddingModel,
Function<Query, Integer> dynamicMaxResults,
Function<Query, Double> dynamicMinScore,
Function<Query, Filter> dynamicFilter) {
this.displayName = getOrDefault(displayName, DEFAULT_DISPLAY_NAME);
this.embeddingStore = ensureNotNull(embeddingStore, "embeddingStore");
this.embeddingModel = ensureNotNull(
getOrDefault(embeddingModel, EmbeddingStoreContentRetriever::loadEmbeddingModel),
Expand Down Expand Up @@ -181,4 +192,11 @@ public List<Content> retrieve(Query query) {
.map(Content::from)
.collect(toList());
}

@Override
public String toString() {
return "EmbeddingStoreContentRetriever{" +
"displayName='" + displayName + '\'' +
'}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.*;
Expand Down Expand Up @@ -46,13 +47,7 @@ void beforeEach() {
EMBEDDING_MODEL = mock(EmbeddingModel.class);
when(EMBEDDING_MODEL.embed(anyString())).thenReturn(Response.from(EMBEDDING));
}

@AfterEach
void afterEach() {
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}


@Test
void should_retrieve() {

Expand All @@ -69,6 +64,8 @@ void should_retrieve() {
.minScore(DEFAULT_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -90,6 +87,8 @@ void should_retrieve_builder() {
.minScore(DEFAULT_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -112,6 +111,8 @@ void should_retrieve_with_custom_maxResults() {
.minScore(DEFAULT_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -134,6 +135,8 @@ void should_retrieve_with_custom_maxResults_builder() {
.minScore(DEFAULT_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -156,6 +159,8 @@ void should_retrieve_with_custom_dynamicMaxResults_builder() {
.minScore(DEFAULT_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -179,6 +184,8 @@ void should_retrieve_with_custom_minScore_ctor() {
.minScore(CUSTOM_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -201,6 +208,8 @@ void should_retrieve_with_custom_minScore_builder() {
.minScore(CUSTOM_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -223,6 +232,8 @@ void should_retrieve_with_custom_dynamicMinScore_builder() {
.minScore(CUSTOM_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -248,6 +259,8 @@ void should_retrieve_with_custom_filter() {
.filter(metadataFilter)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -273,5 +286,51 @@ void should_retrieve_with_custom_dynamicFilter() {
.filter(metadataFilter)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
void should_include_explicit_display_name_in_to_string() {

// given
double minScore = 0.7;
String displayName = "MyName";
EmbeddingStore<TextSegment> embeddingStore = mock(EmbeddingStore.class);
EmbeddingModel embeddingModel = mock(EmbeddingModel.class);

ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
.displayName(displayName)
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.minScore(minScore)
.build();

// when
String result = contentRetriever.toString();

// then
assertThat(result).contains(displayName);
}

@Test
void should_include_implicit_display_name_in_to_string() {

// given
double minScore = 0.7;
EmbeddingStore<TextSegment> embeddingStore = mock(EmbeddingStore.class);
EmbeddingModel embeddingModel = mock(EmbeddingModel.class);

ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.minScore(minScore)
.build();

// when
String result = contentRetriever.toString();

// then
assertThat(result).contains(EmbeddingStoreContentRetriever.DEFAULT_DISPLAY_NAME);
}
}