[FEATURE] Return retrieved Contents/TextSegments when using AI Service with RAG #1015

Merged
8 changes: 4 additions & 4 deletions docs/docs/tutorials/5-ai-services.md
Original file line number Diff line number Diff line change
@@ -163,13 +163,13 @@ you can change the return type of your AI Service method from `String` to someth
Currently, AI Services support the following return types:
- `String`
- `AiMessage`
- `Response<AiMessage>` (if you need to access `TokenUsage` or `FinishReason`)
- `boolean`/`Boolean` (if you need to get "yes" or "no" answer)
- `boolean`/`Boolean`, if you need to get "yes" or "no" answer
- `byte`/`Byte`/`short`/`Short`/`int`/`Integer`/`BigInteger`/`long`/`Long`/`float`/`Float`/`double`/`Double`/`BigDecimal`
- `Date`/`LocalDate`/`LocalTime`/`LocalDateTime`
- `List<String>`/`Set<String>` (if you want to get the answer in the form of a list of bullet points)
- Any `Enum` (if you want to classify text, e.g. sentiment, user intent, etc)
- `List<String>`/`Set<String>`, if you want to get the answer in the form of a list of bullet points
- Any `Enum`, if you want to classify text, e.g. sentiment, user intent, etc.
- Any custom POJO
- `Result<T>`, if you need to access `TokenUsage` or sources (`Content`s retrieved during RAG), aside from `T`, which can be of any type listed above. For example: `Result<String>`, `Result<MyCustomPojo>`

Unless the return type is `String`, `AiMessage`, or `Response<AiMessage>`,
the AI Service will automatically append instructions to the end of `UserMessage` indicating the format
15 changes: 15 additions & 0 deletions docs/docs/tutorials/7-rag.md
@@ -153,6 +153,21 @@ and retrieve relevant content from an `EmbeddingStore` that contains our documen
String answer = assistant.chat("How to do Easy RAG with LangChain4j?");
```

## Accessing Sources
If you wish to access the sources (retrieved `Content`s used to augment the message),
you can easily do so by wrapping the return type in the `Result` class:
```java
interface Assistant {

    Result<String> chat(String userMessage);
}

Result<String> result = assistant.chat("How to do Easy RAG with LangChain4j?");

String answer = result.content();
List<Content> sources = result.sources();
```
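
As a rough illustration of consuming such a result, the sketch below renders an answer together with a numbered source list. `ResultSketch` and `CitationFormatter` are hypothetical stand-ins written for this example, and plain strings take the place of `Content`; only the `content()` and `sources()` accessors mirror the snippet above. It also assumes that `sources()` may be `null` or empty when nothing was retrieved.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-in for Result<T>: models only content() and sources().
// Plain strings stand in for the retrieved Content objects.
final class ResultSketch<T> {

    private final T content;
    private final List<String> sources;

    ResultSketch(T content, List<String> sources) {
        this.content = content;
        // Assumption: sources may be absent (null) when no retrieval took place.
        this.sources = sources == null
                ? null
                : Collections.unmodifiableList(new ArrayList<>(sources));
    }

    T content() {
        return content;
    }

    List<String> sources() {
        return sources;
    }
}

final class CitationFormatter {

    // Appends a numbered "Sources" footer to the answer, or returns the answer
    // unchanged when there is nothing to cite.
    static String withCitations(ResultSketch<String> result) {
        List<String> sources = result.sources();
        if (sources == null || sources.isEmpty()) {
            return result.content();
        }
        StringBuilder sb = new StringBuilder(result.content()).append("\n\nSources:");
        for (int i = 0; i < sources.size(); i++) {
            sb.append("\n[").append(i + 1).append("] ").append(sources.get(i));
        }
        return sb.toString();
    }
}
```

With the real API, the same loop would presumably read the text from each `Content` rather than from a string.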

## RAG APIs
LangChain4j offers a rich set of APIs to make it easy for you to build custom RAG pipelines,
ranging from simple ones to advanced ones.
@@ -0,0 +1,39 @@
package dev.langchain4j.rag;


import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.rag.query.Metadata;

import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

/**
 * Represents a request for {@link ChatMessage} augmentation.
 */
public class AugmentationRequest {

    /**
     * The chat message to be augmented.
     * Currently, it is a {@link UserMessage}, but soon it could also be a {@link SystemMessage}.
     */
    private final ChatMessage chatMessage;

    /**
     * Additional metadata related to the augmentation request.
     */
    private final Metadata metadata;

    public AugmentationRequest(ChatMessage chatMessage, Metadata metadata) {
        this.chatMessage = ensureNotNull(chatMessage, "chatMessage");
        this.metadata = ensureNotNull(metadata, "metadata");
    }

    public ChatMessage chatMessage() {
        return chatMessage;
    }

    public Metadata metadata() {
        return metadata;
    }
}
@@ -0,0 +1,40 @@
package dev.langchain4j.rag;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.rag.content.Content;
import lombok.Builder;

import java.util.List;

import static dev.langchain4j.internal.Utils.copyIfNotNull;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

/**
 * Represents the result of a {@link ChatMessage} augmentation.
 */
public class AugmentationResult {

    /**
     * The augmented chat message.
     */
    private final ChatMessage chatMessage;

    /**
     * A list of content used to augment the original chat message.
     */
    private final List<Content> contents;

    @Builder
    public AugmentationResult(ChatMessage chatMessage, List<Content> contents) {
        this.chatMessage = ensureNotNull(chatMessage, "chatMessage");
        this.contents = copyIfNotNull(contents);
    }

    public ChatMessage chatMessage() {
        return chatMessage;
    }

    public List<Content> contents() {
        return contents;
    }
}
@@ -1,5 +1,6 @@
package dev.langchain4j.rag;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.aggregator.ContentAggregator;
@@ -119,10 +120,23 @@ public DefaultRetrievalAugmentor(QueryTransformer queryTransformer,
this.executor = getOrDefault(executor, Executors::newCachedThreadPool);
}

    /**
     * @deprecated use {@link #augment(AugmentationRequest)} instead.
     */
    @Override
    @Deprecated
    public UserMessage augment(UserMessage userMessage, Metadata metadata) {
        AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
        return (UserMessage) augment(augmentationRequest).chatMessage();
    }

@Override
public AugmentationResult augment(AugmentationRequest augmentationRequest) {

ChatMessage chatMessage = augmentationRequest.chatMessage();
Metadata metadata = augmentationRequest.metadata();

Query originalQuery = Query.from(userMessage.text(), metadata);
Query originalQuery = Query.from(chatMessage.text(), metadata);

Collection<Query> queries = queryTransformer.transform(originalQuery);
logQueries(originalQuery, queries);
@@ -145,10 +159,13 @@ public UserMessage augment(UserMessage userMessage, Metadata metadata) {
List<Content> contents = contentAggregator.aggregate(queryToContents);
log(queryToContents, contents);

UserMessage augmentedUserMessage = contentInjector.inject(contents, userMessage);
log(augmentedUserMessage);
ChatMessage augmentedChatMessage = contentInjector.inject(contents, chatMessage);
log(augmentedChatMessage);

return augmentedUserMessage;
return AugmentationResult.builder()
.chatMessage(augmentedChatMessage)
.contents(contents)
.build();
}
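
The net effect of the new `augment(AugmentationRequest)` is that the retrieved contents travel back to the caller alongside the augmented message instead of being discarded. A minimal sketch of that shape follows. It uses hypothetical stand-ins (`PipelineSketch`, `Augmented`, plain strings for messages and contents) and leaves out the real query transformation, routing and aggregation steps:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;

// Hypothetical result holder: the augmented text plus the contents used to build it.
final class Augmented {

    final String message;
    final List<String> contents;

    Augmented(String message, List<String> contents) {
        this.message = message;
        this.contents = Collections.unmodifiableList(new ArrayList<>(contents));
    }
}

// Simplified retrieve-then-inject flow; query transformation, routing and
// aggregation from the real pipeline are deliberately left out.
final class PipelineSketch {

    private final Function<String, List<String>> retriever;
    private final BiFunction<List<String>, String, String> injector;

    PipelineSketch(Function<String, List<String>> retriever,
                   BiFunction<List<String>, String, String> injector) {
        this.retriever = retriever;
        this.injector = injector;
    }

    Augmented augment(String userText) {
        List<String> contents = retriever.apply(userText);
        String augmentedText = injector.apply(contents, userText);
        // Returning both values is what lets a caller expose the sources later.
        return new Augmented(augmentedText, contents);
    }
}
```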

private CompletableFuture<Collection<List<Content>>> retrieveFromAll(Collection<ContentRetriever> retrievers,
@@ -247,8 +264,8 @@ private static void log(Map<Query, Collection<List<Content>>> queryToContents, L
.collect(joining("\n")));
}

private static void log(UserMessage augmentedUserMessage) {
log.trace("Augmented user message: " + escapeNewlines(augmentedUserMessage.singleText()));
private static void log(ChatMessage augmentedChatMessage) {
log.trace("Augmented chat message: {}", escapeNewlines(augmentedChatMessage.text()));
}

private static String escapeNewlines(String text) {
@@ -1,11 +1,15 @@
package dev.langchain4j.rag;

import dev.langchain4j.Experimental;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.query.Metadata;

import static dev.langchain4j.internal.Exceptions.runtime;

/**
* Augments the provided {@link UserMessage} with retrieved content.
* Augments the provided {@link ChatMessage} with retrieved {@link Content}s.
* <br>
* This serves as an entry point into the RAG flow in LangChain4j.
* <br>
@@ -16,12 +20,37 @@
@Experimental
public interface RetrievalAugmentor {

    /**
     * Augments the {@link ChatMessage} provided in the {@link AugmentationRequest} with retrieved {@link Content}s.
     * <br>
     * This method has a default implementation in order to <b>temporarily</b> support
     * current custom implementations of {@code RetrievalAugmentor}. The default implementation will be removed soon.
     *
     * @param augmentationRequest The {@code AugmentationRequest} containing the {@code ChatMessage} to augment.
     * @return The {@link AugmentationResult} containing the augmented {@code ChatMessage}.
     */
    default AugmentationResult augment(AugmentationRequest augmentationRequest) {

        if (!(augmentationRequest.chatMessage() instanceof UserMessage)) {
            throw runtime("Please implement 'AugmentationResult augment(AugmentationRequest)' method " +
                    "in order to augment " + augmentationRequest.chatMessage().getClass());
        }

        UserMessage augmented = augment((UserMessage) augmentationRequest.chatMessage(), augmentationRequest.metadata());

        return AugmentationResult.builder()
                .chatMessage(augmented)
                .build();
    }

    /**
     * Augments the provided {@link UserMessage} with retrieved content.
     *
     * @param userMessage The {@link UserMessage} to be augmented.
     * @param metadata    The {@link Metadata} that may be useful or necessary for retrieval and augmentation.
     * @return The augmented {@link UserMessage}.
     * @deprecated Use/implement {@link #augment(AugmentationRequest)} instead.
     */
    @Deprecated
    UserMessage augment(UserMessage userMessage, Metadata metadata);
}
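
The default method above is a bridge: implementations written against the old two-argument contract keep compiling and keep working when called through the new entry point, at the cost of not reporting any contents. A self-contained sketch of that pattern, using hypothetical simplified types (`Message`, `UserMsg`, `SystemMsg`, `Request`, `Outcome`, `Augmentor`) rather than the LangChain4j ones:

```java
import java.util.List;

// Hypothetical, simplified stand-ins for the message and request/result types.
interface Message {
    String text();
}

final class UserMsg implements Message {
    private final String text;
    UserMsg(String text) { this.text = text; }
    @Override public String text() { return text; }
}

final class SystemMsg implements Message {
    private final String text;
    SystemMsg(String text) { this.text = text; }
    @Override public String text() { return text; }
}

final class Request {
    final Message message;
    Request(Message message) { this.message = message; }
}

final class Outcome {
    final Message message;
    final List<String> contents; // null when the legacy path cannot report them
    Outcome(Message message, List<String> contents) {
        this.message = message;
        this.contents = contents;
    }
}

interface Augmentor {

    // New entry point. The default body bridges to the legacy method so that
    // implementations written against the old contract keep working.
    default Outcome augment(Request request) {
        if (!(request.message instanceof UserMsg)) {
            throw new IllegalStateException(
                    "Please implement augment(Request) in order to augment " + request.message.getClass());
        }
        UserMsg augmented = augment((UserMsg) request.message);
        // The legacy method cannot say what it retrieved, so contents stay null.
        return new Outcome(augmented, null);
    }

    // Legacy entry point, still abstract so existing implementations compile.
    @Deprecated
    UserMsg augment(UserMsg userMessage);
}

// An implementation written before the new method existed.
final class LegacyAugmentor implements Augmentor {
    @Override
    public UserMsg augment(UserMsg userMessage) {
        return new UserMsg(userMessage.text() + " [augmented]");
    }
}
```

One consequence worth noting for reviewers: a custom augmentor that relies on the default bridge yields a result without contents, so sources would only become available once it implements the new method itself.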
@@ -1,11 +1,15 @@
package dev.langchain4j.rag.content.injector;

import dev.langchain4j.Experimental;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.rag.content.Content;

import java.util.List;

import static dev.langchain4j.internal.Exceptions.runtime;

/**
* Injects given {@link Content}s into a given {@link UserMessage}.
* <br>
@@ -17,12 +21,35 @@
@Experimental
public interface ContentInjector {

    /**
     * Injects given {@link Content}s into a given {@link ChatMessage}.
     * <br>
     * This method has a default implementation in order to <b>temporarily</b> support
     * current custom implementations of {@code ContentInjector}. The default implementation will be removed soon.
     *
     * @param contents    The list of {@link Content} to be injected.
     * @param chatMessage The {@link ChatMessage} into which the {@link Content}s are to be injected.
     *                    Can be either a {@link UserMessage} or a {@link SystemMessage}.
     * @return The {@link ChatMessage} with the injected {@link Content}s.
     */
    default ChatMessage inject(List<Content> contents, ChatMessage chatMessage) {

        if (!(chatMessage instanceof UserMessage)) {
            throw runtime("Please implement 'ChatMessage inject(List<Content>, ChatMessage)' method " +
                    "in order to inject contents into " + chatMessage);
        }

        return inject(contents, (UserMessage) chatMessage);
    }

    /**
     * Injects given {@link Content}s into a given {@link UserMessage}.
     *
     * @param contents    The list of {@link Content} to be injected.
     * @param userMessage The {@link UserMessage} into which the {@link Content}s are to be injected.
     * @return The {@link UserMessage} with the injected {@link Content}s.
     * @deprecated Use/implement {@link #inject(List, ChatMessage)} instead.
     */
    @Deprecated
    UserMessage inject(List<Content> contents, UserMessage userMessage);
}
@@ -1,6 +1,7 @@
package dev.langchain4j.rag.content.injector;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.input.Prompt;
@@ -71,6 +72,25 @@ public DefaultContentInjector(PromptTemplate promptTemplate, List<String> metada
}

    @Override
    public ChatMessage inject(List<Content> contents, ChatMessage chatMessage) {

        if (contents.isEmpty()) {
            return chatMessage;
        }

        Prompt prompt = createPrompt(chatMessage, contents);
        return prompt.toUserMessage();
    }

    protected Prompt createPrompt(ChatMessage chatMessage, List<Content> contents) {
        return createPrompt((UserMessage) chatMessage, contents);
    }

/**
* @deprecated use {@link #inject(List, ChatMessage)} instead.
*/
@Override
@Deprecated
public UserMessage inject(List<Content> contents, UserMessage userMessage) {

if (contents.isEmpty()) {
@@ -81,6 +101,10 @@ public UserMessage inject(List<Content> contents, UserMessage userMessage) {
return prompt.toUserMessage();
}
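
As far as the visible hunks show, both `inject` overloads follow the same rule: with no contents the message passes through untouched, otherwise the contents are rendered into a prompt together with the original text. A rough sketch of that rule follows. `InjectionSketch` and the wording of the template are assumptions made for illustration and may differ from the injector's actual prompt template; plain strings stand in for `ChatMessage` and `Content`:

```java
import java.util.List;

final class InjectionSketch {

    // Returns the text unchanged when nothing was retrieved; otherwise appends
    // the contents, separated by blank lines, below the original text.
    static String inject(List<String> contents, String userText) {
        if (contents.isEmpty()) {
            return userText;
        }
        // Assumed template wording, for illustration only.
        return userText
                + "\n\nAnswer using the following information:\n"
                + String.join("\n\n", contents);
    }
}
```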

/**
* @deprecated implement/override {@link #createPrompt(ChatMessage, List)} instead.
*/
@Deprecated
protected Prompt createPrompt(UserMessage userMessage, List<Content> contents) {
Map<String, Object> variables = new HashMap<>();
variables.put("userMessage", userMessage.text());
@@ -1,5 +1,6 @@
package dev.langchain4j.rag;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.aggregator.ContentAggregator;
@@ -115,6 +116,10 @@ void should_augment_user_message(Executor executor) {
content1, content2, content3, content4,
content1, content2, content3, content4
), userMessage);
verify(contentInjector).inject(asList(
content1, content2, content3, content4,
content1, content2, content3, content4
), (ChatMessage) userMessage);
verifyNoMoreInteractions(contentInjector);
}
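
The `(ChatMessage)` cast in the added verification matters because Java picks an overload at compile time from the static type of the argument. If `userMessage` is declared as `UserMessage` in this test, as its name suggests, an uncast call would bind to the deprecated `inject(List, UserMessage)` overload instead. A self-contained sketch with hypothetical stand-in types (`Msg`, `UserMsg`, `OverloadDemo`), not the project's classes:

```java
// Hypothetical stand-ins: UserMsg is a subtype of Msg, as UserMessage is of ChatMessage.
interface Msg {}

final class UserMsg implements Msg {}

final class OverloadDemo {

    // Chosen when the argument's static type is Msg.
    static String inject(Msg message) {
        return "inject(Msg)";
    }

    // Chosen when the argument's static type is UserMsg (the more specific overload).
    static String inject(UserMsg message) {
        return "inject(UserMsg)";
    }
}
```

Calling `OverloadDemo.inject(user)` with a `UserMsg` variable selects the second overload, while `OverloadDemo.inject((Msg) user)` selects the first, even though the runtime object is the same.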

@@ -15,7 +15,6 @@

import java.net.Proxy;
import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
