Skip to content

Commit

Permalink
Issue 972
Browse files Browse the repository at this point in the history
Add display name to dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever to improve logs
  • Loading branch information
alwa authored and alixwar committed Apr 23, 2024
1 parent 197b4af commit 80384b9
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
* <br>
* Configurable parameters (optional):
* <br>
* - {@link #displayName}: Display name for this retriever for logging purposes.
* <br>
* - {@link #maxResults}: The maximum number of {@link Content}s to retrieve.
* <br>
* - {@link #minScore}: The minimum relevance score for the returned {@link Content}s.
Expand All @@ -32,28 +34,32 @@ public class EmbeddingStoreContentRetriever implements ContentRetriever {

public static final int DEFAULT_MAX_RESULTS = 3;
public static final double DEFAULT_MIN_SCORE = 0;
public static final String DEFAULT_DISPLAY_NAME = "Default";

private final EmbeddingStore<TextSegment> embeddingStore;
private final EmbeddingModel embeddingModel;
private final String displayName;
private final int maxResults;
private final double minScore;

public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore,
EmbeddingModel embeddingModel) {
this(embeddingStore, embeddingModel, DEFAULT_MAX_RESULTS, DEFAULT_MIN_SCORE);
this(DEFAULT_DISPLAY_NAME, embeddingStore, embeddingModel, DEFAULT_MAX_RESULTS, DEFAULT_MIN_SCORE);
}

public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore,
EmbeddingModel embeddingModel,
int maxResults) {
this(embeddingStore, embeddingModel, maxResults, DEFAULT_MIN_SCORE);
this(DEFAULT_DISPLAY_NAME, embeddingStore, embeddingModel, maxResults, DEFAULT_MIN_SCORE);
}

@Builder
public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore,
public EmbeddingStoreContentRetriever(String displayName,
EmbeddingStore<TextSegment> embeddingStore,
EmbeddingModel embeddingModel,
Integer maxResults,
Double minScore) {
this.displayName = getOrDefault(displayName, DEFAULT_DISPLAY_NAME);
this.embeddingStore = ensureNotNull(embeddingStore, "embeddingStore");
this.embeddingModel = ensureNotNull(embeddingModel, "embeddingModel");
this.maxResults = ensureGreaterThanZero(getOrDefault(maxResults, DEFAULT_MAX_RESULTS), "maxResults");
Expand All @@ -72,4 +78,13 @@ public List<Content> retrieve(Query query) {
.map(Content::from)
.collect(toList());
}

@Override
public String toString() {
return "EmbeddingStoreContentRetriever{" +
"displayName='" + displayName + '\'' +
", maxResults=" + maxResults +
", minScore=" + minScore +
'}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ void should_retrieve_with_custom_minScore() {
when(embeddingModel.embed(anyString())).thenReturn(Response.from(embedding));

ContentRetriever contentRetriever = new EmbeddingStoreContentRetriever(
null, // displayName
embeddingStore,
embeddingModel,
null, // maxResults
Expand Down Expand Up @@ -243,4 +244,79 @@ void should_retrieve_with_custom_minScore_builder() {
verify(embeddingStore).findRelevant(embedding, 3, minScore);
verifyNoMoreInteractions(embeddingStore);
}

@Test
void should_include_correct_values_in_to_string_with_default_display_name() {

// given
double minScore = 0.7;
EmbeddingStore<TextSegment> embeddingStore = mock(EmbeddingStore.class);
EmbeddingModel embeddingModel = mock(EmbeddingModel.class);

ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.minScore(minScore)
.build();

// when
String result = contentRetriever.toString();

// then
assertThat(result).contains(EmbeddingStoreContentRetriever.DEFAULT_DISPLAY_NAME);
assertThat(result).contains(Integer.toString(EmbeddingStoreContentRetriever.DEFAULT_MAX_RESULTS));
assertThat(result).contains(Double.toString(minScore));
}

@Test
void should_include_correct_values_in_to_string_with_default_max_results() {

// given
double minScore = 0.7;
String displayName = "unique";
EmbeddingStore<TextSegment> embeddingStore = mock(EmbeddingStore.class);
EmbeddingModel embeddingModel = mock(EmbeddingModel.class);

ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
.displayName(displayName)
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.minScore(minScore)
.build();

// when
String result = contentRetriever.toString();

// then
assertThat(result).contains(displayName);
assertThat(result).contains(Integer.toString(EmbeddingStoreContentRetriever.DEFAULT_MAX_RESULTS));
assertThat(result).contains(Double.toString(minScore));
}

@Test
void should_include_correct_values_in_to_string_with_explicit_max_results() {

// given
double minScore = 0.7;
int maxResults = 1711;
String displayName = "unique";
EmbeddingStore<TextSegment> embeddingStore = mock(EmbeddingStore.class);
EmbeddingModel embeddingModel = mock(EmbeddingModel.class);

ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
.displayName(displayName)
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.minScore(minScore)
.maxResults(maxResults)
.build();

// when
String result = contentRetriever.toString();

// then
assertThat(result).contains(displayName);
assertThat(result).contains(Integer.toString(maxResults));
assertThat(result).contains(Double.toString(minScore));
}
}

0 comments on commit 80384b9

Please sign in to comment.