Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rerank model type for LocalAI provider #3952

Merged
merged 6 commits into from
May 11, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions api/core/model_runtime/model_providers/localai/localai.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ help:
supported_model_types:
- llm
- text-embedding
- rerank
# - tts
# - speech2text
configurate_methods:
- customizable-model
model_credential_schema:
Expand Down
Empty file.
120 changes: 120 additions & 0 deletions api/core/model_runtime/model_providers/localai/rerank/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from json import dumps
from typing import Optional

import httpx
from requests import post
from yarl import URL

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel


class LocalaiRerankModel(RerankModel):
    """
    Rerank model served by a LocalAI server.

    LocalAI rerank model API is compatible with Jina rerank model API, so this class follows the
    JinaRerankModel implementation.
    """

    def _invoke(self, model: str, credentials: dict,
                query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
                user: Optional[str] = None) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials; ``server_url`` is required, ``api_key`` is optional
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold; results scoring below it are dropped
        :param top_n: top n documents to return (forwarded to the server as-is)
        :param user: unique user id (not sent to the server)
        :return: rerank result
        :raises CredentialsValidateFailedError: if ``server_url`` or the model name is missing
        :raises InvokeServerUnavailableError: if the server answers with an HTTP error status
        """
        # Imported locally because the request below is sent with ``requests``;
        # ``raise_for_status`` therefore raises the ``requests`` HTTPError, not the httpx one.
        from requests.exceptions import HTTPError

        # Nothing to rank: skip the network round trip entirely.
        if not docs:
            return RerankResult(model=model, docs=[])

        # ``.get`` so that a missing key reports the explicit message below instead of a bare KeyError.
        server_url = credentials.get('server_url')
        model_name = model

        if not server_url:
            raise CredentialsValidateFailedError('server_url is required')
        if not model_name:
            raise CredentialsValidateFailedError('model_name is required')

        headers = {
            'Content-Type': 'application/json'
        }
        # Only authenticate when a key is configured; otherwise the header would read "Bearer None".
        api_key = credentials.get('api_key')
        if api_key:
            headers['Authorization'] = f"Bearer {api_key}"

        data = {
            "model": model_name,
            "query": query,
            "documents": docs,
            "top_n": top_n
        }

        try:
            response = post(str(URL(server_url) / 'rerank'), headers=headers, data=dumps(data), timeout=10)
            response.raise_for_status()
        except (HTTPError, httpx.HTTPStatusError) as e:
            # httpx.HTTPStatusError is kept for backward compatibility with the previous handler.
            raise InvokeServerUnavailableError(str(e)) from e

        results = response.json()

        # Keep every result when no threshold is given, otherwise only those scoring at or above it.
        rerank_documents = [
            RerankDocument(
                index=result['index'],
                text=result['document']['text'],
                score=result['relevance_score'],
            )
            for result in results['results']
            if score_threshold is None or result['relevance_score'] >= score_threshold
        ]

        return RerankResult(model=model, docs=rerank_documents)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials by issuing a small rerank request

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if the probe request fails for any reason
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8
            )
        except Exception as ex:
            # Any failure of the probe call means the credentials cannot be used as given.
            raise CredentialsValidateFailedError(str(ex)) from ex

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error

        The HTTP call is made with ``requests``, so its exception types are listed next to the
        httpx ones that were mapped originally.
        """
        # Local import: only ``post`` is imported from ``requests`` at module level.
        from requests import exceptions as requests_exceptions

        return {
            InvokeConnectionError: [
                httpx.ConnectError,
                requests_exceptions.ConnectionError,
                requests_exceptions.Timeout,
            ],
            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [httpx.HTTPStatusError],
            InvokeBadRequestError: [httpx.RequestError, requests_exceptions.RequestException]
        }
154 changes: 154 additions & 0 deletions api/tests/integration_tests/model_runtime/localai/test_rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import os

import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel
from api.core.model_runtime.entities.rerank_entities import RerankResult


def test_validate_credentials_for_chat_model():
    """Credential validation rejects a malformed server URL and accepts the configured LocalAI server."""
    reranker = LocalaiRerankModel()

    # A value that is not a usable URL must surface as a credentials validation failure.
    broken_credentials = {
        'server_url': 'hahahaha',
        'completion_type': 'completion',
    }
    with pytest.raises(CredentialsValidateFailedError):
        reranker.validate_credentials(model='bge-reranker-v2-m3', credentials=broken_credentials)

    # The server named by LOCALAI_SERVER_URL is expected to validate without raising.
    working_credentials = {
        'server_url': os.environ.get('LOCALAI_SERVER_URL'),
        'completion_type': 'completion',
    }
    reranker.validate_credentials(model='bge-reranker-base', credentials=working_credentials)

def test_invoke_rerank_model():
    """Reranking nine candidate documents with top_n=3 against the LocalAI server yields three results."""
    reranker = LocalaiRerankModel()

    server_credentials = {
        'server_url': os.environ.get('LOCALAI_SERVER_URL')
    }
    candidates = [
        "Eco-friendly kitchenware for modern homes",
        "Biodegradable cleaning supplies for eco-conscious consumers",
        "Organic cotton baby clothes for sensitive skin",
        "Natural organic skincare range for sensitive skin",
        "Tech gadgets for smart homes: 2024 edition",
        "Sustainable gardening tools and compost solutions",
        "Sensitive skin-friendly facial cleansers and toners",
        "Organic food wraps and storage solutions",
        "Yoga mats made from recycled materials"
    ]

    outcome = reranker.invoke(
        model='bge-reranker-base',
        credentials=server_credentials,
        query='Organic skincare products for sensitive skin',
        docs=candidates,
        top_n=3,
        score_threshold=0.75,
        user="abc-123"
    )

    assert isinstance(outcome, RerankResult)
    assert len(outcome.docs) == 3
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel
from api.core.model_runtime.entities.rerank_entities import RerankResult, RerankDocument

def test_validate_credentials_for_chat_model():
    """Validation must fail for an unusable server URL and pass for the server in LOCALAI_SERVER_URL."""
    rerank_model = LocalaiRerankModel()

    # Negative case: 'hahahaha' is not a reachable endpoint.
    invalid = {
        'server_url': 'hahahaha',
        'completion_type': 'completion',
    }
    with pytest.raises(CredentialsValidateFailedError):
        rerank_model.validate_credentials(model='bge-reranker-v2-m3', credentials=invalid)

    # Positive case: no exception is expected for the configured server.
    valid = {
        'server_url': os.environ.get('LOCALAI_SERVER_URL'),
        'completion_type': 'completion',
    }
    rerank_model.validate_credentials(model='bge-reranker-base', credentials=valid)

def test_invoke_rerank_model():
    """Invoke the rerank model through the public entry point and expect the top three documents back."""
    rerank_model = LocalaiRerankModel()

    documents = [
        "Eco-friendly kitchenware for modern homes",
        "Biodegradable cleaning supplies for eco-conscious consumers",
        "Organic cotton baby clothes for sensitive skin",
        "Natural organic skincare range for sensitive skin",
        "Tech gadgets for smart homes: 2024 edition",
        "Sustainable gardening tools and compost solutions",
        "Sensitive skin-friendly facial cleansers and toners",
        "Organic food wraps and storage solutions",
        "Yoga mats made from recycled materials"
    ]
    credentials = {
        'server_url': os.environ.get('LOCALAI_SERVER_URL')
    }

    reranked = rerank_model.invoke(
        model='bge-reranker-base',
        credentials=credentials,
        query='Organic skincare products for sensitive skin',
        docs=documents,
        top_n=3,
        score_threshold=0.75,
        user="abc-123"
    )

    assert isinstance(reranked, RerankResult)
    assert len(reranked.docs) == 3

def test__invoke():
    """
    Exercise ``LocalaiRerankModel._invoke`` without a live rerank server.

    The empty-docs case returns before any request is made. The valid-invocation case stubs the
    module-level ``post`` used by the rerank module, because ``https://example.com`` is a placeholder
    and does not serve a rerank endpoint.
    """
    import sys
    from unittest.mock import MagicMock, patch

    model = LocalaiRerankModel()
    credentials = {
        'server_url': 'https://example.com',
        'api_key': '1234567890'
    }

    # Test case 1: Empty docs
    result = model._invoke(
        model='bge-reranker-base',
        credentials=credentials,
        query='Organic skincare products for sensitive skin',
        docs=[],
        top_n=3,
        score_threshold=0.75,
        user="abc-123"
    )
    assert isinstance(result, RerankResult)
    assert len(result.docs) == 0

    # Test case 2: Valid invocation
    docs = [
        "Eco-friendly kitchenware for modern homes",
        "Biodegradable cleaning supplies for eco-conscious consumers",
        "Organic cotton baby clothes for sensitive skin",
        "Natural organic skincare range for sensitive skin",
        "Tech gadgets for smart homes: 2024 edition",
        "Sustainable gardening tools and compost solutions",
        "Sensitive skin-friendly facial cleansers and toners",
        "Organic food wraps and storage solutions",
        "Yoga mats made from recycled materials"
    ]

    # Canned response in the shape _invoke reads: results[*].index / document.text / relevance_score.
    # All three scores are above the 0.75 threshold, so none is filtered out.
    fake_response = MagicMock()
    fake_response.json.return_value = {
        'results': [
            {'index': 3, 'document': {'text': docs[3]}, 'relevance_score': 0.98},
            {'index': 6, 'document': {'text': docs[6]}, 'relevance_score': 0.91},
            {'index': 2, 'document': {'text': docs[2]}, 'relevance_score': 0.82},
        ]
    }

    # Patch ``post`` in the module where the model class is defined, whatever path it was imported under.
    rerank_module = sys.modules[LocalaiRerankModel.__module__]
    with patch.object(rerank_module, 'post', return_value=fake_response) as mocked_post:
        result = model._invoke(
            model='bge-reranker-base',
            credentials=credentials,
            query='Organic skincare products for sensitive skin',
            docs=docs,
            top_n=3,
            score_threshold=0.75,
            user="abc-123"
        )

    mocked_post.assert_called_once()
    assert isinstance(result, RerankResult)
    assert len(result.docs) == 3
    assert all(isinstance(doc, RerankDocument) for doc in result.docs)