
Allow to override retrieval_chain flag (for custom chains with multi-key output) #25

Open
nb-programmer opened this issue Jul 21, 2023 · 3 comments


@nb-programmer

The title says it all. The current implementation assumes that any chain with `len(chain.output_keys) > 1` is a retrieval chain, but that is not always the case: it can be any other type of chain, such as a custom one that doesn't produce the `source_documents` key.

Alternatively, the detection could be made stricter by checking the actual output keys.
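A minimal sketch contrasting the two detection strategies. It assumes only what the issue states: the current check counts output keys, and a retrieval chain is expected to return `source_documents`. Both function names are hypothetical and are not part of langcorn:

```python
from typing import List


def is_retrieval_chain_by_count(output_keys: List[str]) -> bool:
    # Heuristic described in the issue: any chain with more than one output key
    return len(output_keys) > 1


def is_retrieval_chain_by_keys(output_keys: List[str]) -> bool:
    # Stricter check: only chains that actually declare "source_documents"
    return "source_documents" in output_keys


# A custom chain with two unrelated output keys
custom_keys = ["output", "other"]
print(is_retrieval_chain_by_count(custom_keys))  # True (misclassified)
print(is_retrieval_chain_by_keys(custom_keys))   # False
```

Checking for the key itself means a custom chain with several unrelated outputs is no longer mistaken for a retrieval chain.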

@msoedov
Owner

msoedov commented Jul 21, 2023

Hi @nb-programmer, do you have a code example that reproduces this issue?

@nb-programmer
Author

nb-programmer commented Jul 21, 2023

Sure:

demo_bug.py:

```python
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain


class CustomChain(Chain):
    input_key: str = "input"
    output_key: str = "output"

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        print("input:", inputs)
        # Returns two keys, neither of which is "source_documents"
        return {self.output_key: "Hello", "other": "test"}

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """:meta private:"""
        # More than one output key, so the server treats this as a retrieval chain
        return [self.output_key, "other"]


chain = CustomChain()
```

Serve it with:

```shell
$ langcorn server demo_bug:chain
```

Then call the endpoint with any input, e.g.:

```json
{
  "input": "test"
}
```

500 Error: Internal Server Error

```
File "venv\lib\site-packages\langcorn\server\api.py", line 118, in handler
    source_documents=[str(t) for t in output.get("source_documents")],
TypeError: 'NoneType' object is not iterable
```

Expected: it should return all keys, or just the output key, without raising an exception.
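A minimal sketch of a response builder that avoids this crash, assuming the handler builds its response from the chain's output dict as the traceback suggests. `build_response` and the `extra` field are hypothetical, not langcorn's actual handler or response schema. The key point is falling back to an empty list instead of iterating over the `None` returned by `output.get("source_documents")`:

```python
from typing import Any, Dict, List


def build_response(output: Dict[str, Any], output_key: str) -> Dict[str, Any]:
    # Hypothetical response builder, not langcorn's actual handler.
    # Fall back to [] when "source_documents" is missing or None.
    source_documents: List[str] = [
        str(t) for t in (output.get("source_documents") or [])
    ]
    # Pass through any other keys the chain returned (e.g. "other")
    extra = {
        k: v for k, v in output.items() if k not in (output_key, "source_documents")
    }
    return {
        "output": output.get(output_key),
        "source_documents": source_documents,
        "extra": extra,
    }


print(build_response({"output": "Hello", "other": "test"}, "output"))
# {'output': 'Hello', 'source_documents': [], 'extra': {'other': 'test'}}
```

This covers both expectations from the report: the main output key is returned, and the additional keys are kept rather than dropped.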

@msoedov
Owner

msoedov commented Aug 9, 2023

Thanks for the example, @nb-programmer. Fixed the bug in e803e5d. Going to release a new version today.
