Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE] Return retrieved Contents/TextSegments when using AI Service with RAG #1015

Merged
22 changes: 22 additions & 0 deletions docs/docs/tutorials/7-rag.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,28 @@ and retrieve relevant content from an `EmbeddingStore` that contains our documen
```java
String answer = assistant.chat("How to do Easy RAG with LangChain4j?");
```
### Augmented User Message
If you need to access the data used for augmenting the user message, you can easily do so by wrapping the response in the `WithSources` class.

Here's an illustration:
```java
interface Assistant {

WithSources<String> chat(String userMessage);
}
... The remaining code is the same :)
```
`WithSources` class contains the information used to augment the user message :
* `response` : The response to the user's input.
* `augmentedMessage`: A wrapper for augmentation details.
* `usermessage`: The augmented user message that was sent to the LLM.
* `contents`: The list of documents used to enrich the user message including the `metadata` of each document.

Attempting to use the `WithSources` class without specifying a generic type will result in an `IllegalArgumentException`. For instance:
```java

WithSources chat(String userMessage); // Throw an IllegalArgumentException
```

## RAG APIs
LangChain4j offers a rich set of APIs to make it easy for you to build custom RAG pipelines,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package dev.langchain4j.data.message;

import dev.langchain4j.rag.content.Content;
import lombok.Builder;
import lombok.Getter;

import java.util.List;

/**
* Represents an augmented message containing information about the user message and associated contents.
*/
@Getter
@Builder
public class AugmentedMessage {
private final UserMessage userMessage; // The augmented user message.
private final List<Content> contents; // The list of contents used to augment the associated the user message.
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dev.langchain4j.rag;

import dev.langchain4j.data.message.AugmentedMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.aggregator.ContentAggregator;
Expand Down Expand Up @@ -95,8 +96,9 @@
* @see DefaultQueryRouter
* @see DefaultContentAggregator
* @see DefaultContentInjector
* @see AugmentedMessage
*/
public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
public class DefaultRetrievalAugmentor implements RetrievalAugmentor<AugmentedMessage> {

private static final Logger log = LoggerFactory.getLogger(DefaultRetrievalAugmentor.class);

Expand All @@ -120,7 +122,7 @@ public DefaultRetrievalAugmentor(QueryTransformer queryTransformer,
}

@Override
public UserMessage augment(UserMessage userMessage, Metadata metadata) {
public AugmentedMessage augment(UserMessage userMessage, Metadata metadata) {

Query originalQuery = Query.from(userMessage.text(), metadata);

Expand Down Expand Up @@ -148,7 +150,10 @@ public UserMessage augment(UserMessage userMessage, Metadata metadata) {
UserMessage augmentedUserMessage = contentInjector.inject(contents, userMessage);
log(augmentedUserMessage);

return augmentedUserMessage;
return AugmentedMessage.builder()
.userMessage(augmentedUserMessage)
.contents(contents)
.build();
}

private CompletableFuture<Collection<List<Content>>> retrieveFromAll(Collection<ContentRetriever> retrievers,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* @see DefaultRetrievalAugmentor
*/
@Experimental
public interface RetrievalAugmentor {
public interface RetrievalAugmentor<T> {

/**
* Augments the provided {@link UserMessage} with retrieved content.
Expand All @@ -23,5 +23,5 @@ public interface RetrievalAugmentor {
* @param metadata The {@link Metadata} that may be useful or necessary for retrieval and augmentation.
* @return The augmented {@link UserMessage}.
*/
UserMessage augment(UserMessage userMessage, Metadata metadata);
T augment(UserMessage userMessage, Metadata metadata);
langchain4j marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void should_augment_user_message(Executor executor) {

ContentInjector contentInjector = spy(new TestContentInjector());

RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
DefaultRetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.queryTransformer(queryTransformer)
.queryRouter(queryRouter)
.contentAggregator(contentAggregator)
Expand All @@ -67,7 +67,7 @@ void should_augment_user_message(Executor executor) {
Metadata metadata = Metadata.from(userMessage, null, null);

// when
UserMessage augmented = retrievalAugmentor.augment(userMessage, metadata);
UserMessage augmented = retrievalAugmentor.augment(userMessage, metadata).getUserMessage();

// then
assertThat(augmented.singleText()).isEqualTo(
Expand Down Expand Up @@ -126,7 +126,7 @@ void should_not_augment_when_router_does_not_return_retrievers(Executor executor
List<ContentRetriever> retrievers = emptyList();
QueryRouter queryRouter = spy(new TestQueryRouter(retrievers));

RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
DefaultRetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.queryRouter(queryRouter)
.executor(executor)
.build();
Expand All @@ -136,7 +136,7 @@ void should_not_augment_when_router_does_not_return_retrievers(Executor executor
Metadata metadata = Metadata.from(userMessage, null, null);

// when
UserMessage augmentedUserMessage = retrievalAugmentor.augment(userMessage, metadata);
UserMessage augmentedUserMessage = retrievalAugmentor.augment(userMessage, metadata).getUserMessage();

// then
assertThat(augmentedUserMessage).isEqualTo(userMessage);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dev.langchain4j.chain;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.AugmentedMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
Expand Down Expand Up @@ -30,7 +31,7 @@ public class ConversationalRetrievalChain implements Chain<String, String> {

private final ChatLanguageModel chatLanguageModel;
private final ChatMemory chatMemory;
private final RetrievalAugmentor retrievalAugmentor;
private final RetrievalAugmentor<AugmentedMessage> retrievalAugmentor;

public ConversationalRetrievalChain(ChatLanguageModel chatLanguageModel,
ChatMemory chatMemory,
Expand Down Expand Up @@ -77,7 +78,7 @@ public String execute(String query) {

UserMessage userMessage = UserMessage.from(query);
Metadata metadata = Metadata.from(userMessage, chatMemory.id(), chatMemory.messages());
userMessage = retrievalAugmentor.augment(userMessage, metadata);
userMessage = retrievalAugmentor.augment(userMessage, metadata).getUserMessage();
chatMemory.add(userMessage);

AiMessage aiMessage = chatLanguageModel.generate(chatMemory.messages()).content();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AugmentedMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
Expand Down Expand Up @@ -31,7 +32,7 @@ public class AiServiceContext {
public List<ToolSpecification> toolSpecifications;
public Map<String, ToolExecutor> toolExecutors;

public RetrievalAugmentor retrievalAugmentor;
public RetrievalAugmentor<AugmentedMessage> retrievalAugmentor;

public Function<Object, Optional<String>> systemMessageProvider = DEFAULT_MESSAGE_PROVIDER;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,19 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio

Optional<SystemMessage> systemMessage = prepareSystemMessage(memoryId, method, args);
UserMessage userMessage = prepareUserMessage(method, args);

AugmentedMessage augmentedMessage = null;
if (context.retrievalAugmentor != null) {
List<ChatMessage> chatMemory = context.hasChatMemory()
? context.chatMemory(memoryId).messages()
: null;
Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory);
userMessage = context.retrievalAugmentor.augment(userMessage, metadata);
augmentedMessage = context.retrievalAugmentor.augment(userMessage, metadata);
userMessage = augmentedMessage.getUserMessage();
}

// TODO give user ability to provide custom OutputParser
String outputFormatInstructions = outputFormatInstructions(method.getReturnType());
Class<?> returnType = method.getReturnType();
String outputFormatInstructions = outputFormatInstructions(returnType);
userMessage = UserMessage.from(userMessage.text() + outputFormatInstructions);

if (context.hasChatMemory()) {
Expand All @@ -118,7 +120,7 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio

Future<Moderation> moderationFuture = triggerModerationIfNeeded(method, messages);

if (method.getReturnType() == TokenStream.class) {
if (returnType == TokenStream.class) {
return new AiServiceTokenStream(messages, context, memoryId); // TODO moderation
}

Expand Down Expand Up @@ -173,7 +175,26 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio
}

response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason());
return parse(response, method.getReturnType());
if (returnType != WithSources.class) {
return parse(response, returnType);
}
AnnotatedType annotatedReturnType = method.getAnnotatedReturnType();

Type withSourcesAnnotatedType = annotatedReturnType.getType();
if (withSourcesAnnotatedType instanceof ParameterizedType) {
ParameterizedType type = (ParameterizedType) withSourcesAnnotatedType;
Type[] typeArguments = type.getActualTypeArguments();
for (Type typeArg : typeArguments) {
returnType = Class.forName(typeArg.getTypeName());
}
} else {
throw illegalArgument("WithSources needs to have a generic class defined for the following method : %s", method.getName());
}
Object parsedResponse = parse(response, returnType);
return WithSources.builder()
.response(parsedResponse)
.augmentedMessage(augmentedMessage)
.build();
}

private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMessage> messages) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ public static String outputFormatInstructions(Class<?> returnType) {
if (returnType == String.class
|| returnType == AiMessage.class
|| returnType == TokenStream.class
|| returnType == Response.class) {
|| returnType == Response.class
|| returnType == WithSources.class) {
return "";
}

Expand Down
17 changes: 17 additions & 0 deletions langchain4j/src/main/java/dev/langchain4j/service/WithSources.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package dev.langchain4j.service;

import dev.langchain4j.data.message.AugmentedMessage;
import lombok.Builder;
import lombok.Getter;

/**
* Represents a container holding augmented information associated with a response.
*
* @param <T> The type of the response.
*/
@Getter
@Builder
public class WithSources<T> {
langchain4j marked this conversation as resolved.
Show resolved Hide resolved
private T response; // The response associated with the augmented information.
private AugmentedMessage augmentedMessage; // Wrapper for the augmentation details.
langchain4j marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import dev.langchain4j.store.embedding.filter.builder.sql.LanguageModelSqlFilterBuilder;
import dev.langchain4j.store.embedding.filter.builder.sql.TableDefinition;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
Expand All @@ -58,6 +59,7 @@
import static java.util.Collections.emptyList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.*;

class AiServicesWithRagIT {
Expand Down Expand Up @@ -210,6 +212,77 @@ void should_use_query_transformer_and_content_retriever(ChatLanguageModel model)
assertThat(answer).containsAnyOf(ALLOWED_CANCELLATION_PERIOD_DAYS, MIN_BOOKING_PERIOD_DAYS);
}

interface AssistantWithSources {
WithSources<String> answer(String query);
WithSources answerWithNoGenericType(String query);
}

@ParameterizedTest
@MethodSource("models")
void should_use_query_transformer_and_content_retriever_and_retrieve_sources(ChatLanguageModel model) {
langchain4j marked this conversation as resolved.
Show resolved Hide resolved

// given
QueryTransformer queryTransformer = new ExpandingQueryTransformer(model);

ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.maxResults(1)
.build();

AssistantWithSources assistant = AiServices.builder(AssistantWithSources.class)
.chatLanguageModel(model)
.retrievalAugmentor(DefaultRetrievalAugmentor.builder()
.queryTransformer(queryTransformer)
.contentRetriever(contentRetriever)
.build())
.build();

// when
WithSources<String> answer = assistant.answer("Can I cancel my booking?");

// then
assertThat(answer.getResponse()).containsAnyOf(ALLOWED_CANCELLATION_PERIOD_DAYS, MIN_BOOKING_PERIOD_DAYS);
Assertions.assertNotNull(answer.getAugmentedMessage());
assertThat(answer.getAugmentedMessage().getContents().size()).isEqualTo(1);
assertThat(answer.getAugmentedMessage().getContents().get(0).textSegment().text().replace(System.lineSeparator(), ""))
.isEqualTo(
"4. Cancellation Policy" +
"4.1 Reservations can be cancelled up to 61 days prior to the start of the booking period." +
"4.2 If the booking period is less than 17 days, cancellations are not permitted."
);
assertThat(answer.getAugmentedMessage().getContents().get(0).textSegment().metadata("index")).isEqualTo("3");
assertThat(answer.getAugmentedMessage().getContents().get(0).textSegment().metadata("file_name")).isEqualTo("miles-of-smiles-terms-of-use.txt");


}

@ParameterizedTest
@MethodSource("models")
void should_use_query_transformer_and_content_retriever_and_through_exception_when_generic_type_is_not_set(ChatLanguageModel model) {
langchain4j marked this conversation as resolved.
Show resolved Hide resolved

// given
QueryTransformer queryTransformer = new ExpandingQueryTransformer(model);

ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingStore)
.embeddingModel(embeddingModel)
.maxResults(1)
.build();

AssistantWithSources assistant = AiServices.builder(AssistantWithSources.class)
.chatLanguageModel(model)
.retrievalAugmentor(DefaultRetrievalAugmentor.builder()
.queryTransformer(queryTransformer)
.contentRetriever(contentRetriever)
.build())
.build();

// when
assertThrows(IllegalArgumentException.class, () -> assistant.answerWithNoGenericType("Can I cancel my booking?"));
langchain4j marked this conversation as resolved.
Show resolved Hide resolved

}
langchain4j marked this conversation as resolved.
Show resolved Hide resolved

langchain4j marked this conversation as resolved.
Show resolved Hide resolved
@ParameterizedTest
@MethodSource("models")
void should_use_query_router_and_content_retriever(ChatLanguageModel model) {
Expand Down