Skip to content

Commit

Permalink
Issue 972
Browse files Browse the repository at this point in the history
Add display name to dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever to improve logs
  • Loading branch information
alwa authored and alixwar committed Apr 23, 2024
1 parent 6a87b9b commit 6dd3567
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
* <br>
* Configurable parameters (optional):
* <br>
* - {@code displayName}: Display name for this retriever for logging purposes.
* <br>
* - {@code maxResults}: The maximum number of {@link Content}s to retrieve.
* <br>
* - {@code dynamicMaxResults}: It is a {@link Function} that accepts a {@link Query} and returns a {@code maxResults} value.
Expand All @@ -57,16 +59,21 @@ public class EmbeddingStoreContentRetriever implements ContentRetriever {
public static final Function<Query, Double> DEFAULT_MIN_SCORE = (query) -> 0.0;
public static final Function<Query, Filter> DEFAULT_FILTER = (query) -> null;

/**
 * Display name used when the caller does not supply one: the public constructors pass it
 * explicitly, and the builder constructor uses it as its fallback value.
 */
public static final String DEFAULT_DISPLAY_NAME = "Default";

private final EmbeddingStore<TextSegment> embeddingStore;
private final EmbeddingModel embeddingModel;

private final Function<Query, Integer> maxResultsProvider;
private final Function<Query, Double> minScoreProvider;
private final Function<Query, Filter> filterProvider;

/** Name identifying this retriever; set once in the constructor and reported by {@code toString()}. */
private final String displayName;

public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore,
EmbeddingModel embeddingModel) {
this(
DEFAULT_DISPLAY_NAME,
embeddingStore,
embeddingModel,
DEFAULT_MAX_RESULTS,
Expand All @@ -79,6 +86,7 @@ public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore
EmbeddingModel embeddingModel,
int maxResults) {
this(
DEFAULT_DISPLAY_NAME,
embeddingStore,
embeddingModel,
(query) -> maxResults,
Expand All @@ -92,6 +100,7 @@ public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore
Integer maxResults,
Double minScore) {
this(
DEFAULT_DISPLAY_NAME,
embeddingStore,
embeddingModel,
(query) -> maxResults,
Expand All @@ -101,11 +110,13 @@ public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore
}

@Builder
private EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore,
private EmbeddingStoreContentRetriever(String displayName,
EmbeddingStore<TextSegment> embeddingStore,
EmbeddingModel embeddingModel,
Function<Query, Integer> dynamicMaxResults,
Function<Query, Double> dynamicMinScore,
Function<Query, Filter> dynamicFilter) {
this.displayName = getOrDefault(displayName, DEFAULT_DISPLAY_NAME);
this.embeddingStore = ensureNotNull(embeddingStore, "embeddingStore");
this.embeddingModel = ensureNotNull(
getOrDefault(embeddingModel, EmbeddingStoreContentRetriever::loadEmbeddingModel),
Expand Down Expand Up @@ -181,4 +192,11 @@ public List<Content> retrieve(Query query) {
.map(Content::from)
.collect(toList());
}

/**
 * Renders this retriever as {@code EmbeddingStoreContentRetriever{displayName='<name>'}}.
 *
 * @return a short textual form containing only the display name
 */
@Override
public String toString() {
    final StringBuilder text = new StringBuilder("EmbeddingStoreContentRetriever{");
    text.append("displayName='").append(displayName).append('\'');
    text.append('}');
    return text.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.*;
Expand Down Expand Up @@ -46,13 +47,7 @@ void beforeEach() {
EMBEDDING_MODEL = mock(EmbeddingModel.class);
when(EMBEDDING_MODEL.embed(anyString())).thenReturn(Response.from(EMBEDDING));
}

/**
 * Runs after every test and checks the embedding-model mock: {@code embed} must have been
 * called with the text of {@code QUERY}, and no other interaction with the mock may remain
 * unverified.
 */
@AfterEach
void afterEach() {
// verify(mock) with no explicit mode uses Mockito's default of exactly one invocation
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}


@Test
void should_retrieve() {

Expand All @@ -69,6 +64,8 @@ void should_retrieve() {
.minScore(DEFAULT_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -90,6 +87,8 @@ void should_retrieve_builder() {
.minScore(DEFAULT_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -112,6 +111,8 @@ void should_retrieve_with_custom_maxResults() {
.minScore(DEFAULT_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -134,6 +135,8 @@ void should_retrieve_with_custom_maxResults_builder() {
.minScore(DEFAULT_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -156,6 +159,8 @@ void should_retrieve_with_custom_dynamicMaxResults_builder() {
.minScore(DEFAULT_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -179,6 +184,8 @@ void should_retrieve_with_custom_minScore_ctor() {
.minScore(CUSTOM_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -201,6 +208,8 @@ void should_retrieve_with_custom_minScore_builder() {
.minScore(CUSTOM_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -223,6 +232,8 @@ void should_retrieve_with_custom_dynamicMinScore_builder() {
.minScore(CUSTOM_MIN_SCORE)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -248,6 +259,8 @@ void should_retrieve_with_custom_filter() {
.filter(metadataFilter)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

@Test
Expand All @@ -273,5 +286,28 @@ void should_retrieve_with_custom_dynamicFilter() {
.filter(metadataFilter)
.build());
verifyNoMoreInteractions(EMBEDDING_STORE);
verify(EMBEDDING_MODEL).embed(QUERY.text());
verifyNoMoreInteractions(EMBEDDING_MODEL);
}

/**
 * A retriever built without an explicit display name must mention
 * {@code DEFAULT_DISPLAY_NAME} in its {@code toString()} output.
 */
@Test
void should_include_correct_values_in_to_string_with_default_display_name() {

    // given: mocked collaborators and a builder that never sets displayName
    EmbeddingStore<TextSegment> store = mock(EmbeddingStore.class);
    EmbeddingModel model = mock(EmbeddingModel.class);

    ContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
            .embeddingStore(store)
            .embeddingModel(model)
            .minScore(0.7)
            .build();

    // when / then: the textual form carries the default name
    assertThat(retriever.toString())
            .contains(EmbeddingStoreContentRetriever.DEFAULT_DISPLAY_NAME);
}
}

0 comments on commit 6dd3567

Please sign in to comment.