Skip to content

Commit

Permalink
Add rerank model type for LocalAI provider (#3952)
Browse files Browse the repository at this point in the history
  • Loading branch information
thiner committed May 11, 2024
1 parent 2c1c660 commit a588df4
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ help:
supported_model_types:
- llm
- text-embedding
- rerank
- speech2text
configurate_methods:
- customizable-model
Expand Down
Empty file.
120 changes: 120 additions & 0 deletions api/core/model_runtime/model_providers/localai/rerank/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from json import dumps
from typing import Optional

import httpx
from requests import post
from yarl import URL

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel


class LocalaiRerankModel(RerankModel):
"""
LocalAI rerank model API is compatible with Jina rerank model API. So just copy the JinaRerankModel class code here.
"""

def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n documents to return
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])

server_url = credentials['server_url']
model_name = model

if not server_url:
raise CredentialsValidateFailedError('server_url is required')
if not model_name:
raise CredentialsValidateFailedError('model_name is required')

url = server_url
headers = {
'Authorization': f"Bearer {credentials.get('api_key')}",
'Content-Type': 'application/json'
}

data = {
"model": model_name,
"query": query,
"documents": docs,
"top_n": top_n
}

try:
response = post(str(URL(url) / 'rerank'), headers=headers, data=dumps(data), timeout=10)
response.raise_for_status()
results = response.json()

rerank_documents = []
for result in results['results']:
rerank_document = RerankDocument(
index=result['index'],
text=result['document']['text'],
score=result['relevance_score'],
)
if score_threshold is None or result['relevance_score'] >= score_threshold:
rerank_documents.append(rerank_document)

return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:

self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
"""
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError]
}
158 changes: 158 additions & 0 deletions api/tests/integration_tests/model_runtime/localai/test_rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import os

import pytest
from api.core.model_runtime.entities.rerank_entities import RerankResult

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel


def test_validate_credentials_for_chat_model():
model = LocalaiRerankModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='bge-reranker-v2-m3',
credentials={
'server_url': 'hahahaha',
'completion_type': 'completion',
}
)

model.validate_credentials(
model='bge-reranker-base',
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'completion',
}
)

def test_invoke_rerank_model():
model = LocalaiRerankModel()

response = model.invoke(
model='bge-reranker-base',
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL')
},
query='Organic skincare products for sensitive skin',
docs=[
"Eco-friendly kitchenware for modern homes",
"Biodegradable cleaning supplies for eco-conscious consumers",
"Organic cotton baby clothes for sensitive skin",
"Natural organic skincare range for sensitive skin",
"Tech gadgets for smart homes: 2024 edition",
"Sustainable gardening tools and compost solutions",
"Sensitive skin-friendly facial cleansers and toners",
"Organic food wraps and storage solutions",
"Yoga mats made from recycled materials"
],
top_n=3,
score_threshold=0.75,
user="abc-123"
)

assert isinstance(response, RerankResult)
assert len(response.docs) == 3
import os

import pytest
from api.core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel


def test_validate_credentials_for_chat_model():
model = LocalaiRerankModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='bge-reranker-v2-m3',
credentials={
'server_url': 'hahahaha',
'completion_type': 'completion',
}
)

model.validate_credentials(
model='bge-reranker-base',
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'completion',
}
)

def test_invoke_rerank_model():
model = LocalaiRerankModel()

response = model.invoke(
model='bge-reranker-base',
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL')
},
query='Organic skincare products for sensitive skin',
docs=[
"Eco-friendly kitchenware for modern homes",
"Biodegradable cleaning supplies for eco-conscious consumers",
"Organic cotton baby clothes for sensitive skin",
"Natural organic skincare range for sensitive skin",
"Tech gadgets for smart homes: 2024 edition",
"Sustainable gardening tools and compost solutions",
"Sensitive skin-friendly facial cleansers and toners",
"Organic food wraps and storage solutions",
"Yoga mats made from recycled materials"
],
top_n=3,
score_threshold=0.75,
user="abc-123"
)

assert isinstance(response, RerankResult)
assert len(response.docs) == 3

def test__invoke():
model = LocalaiRerankModel()

# Test case 1: Empty docs
result = model._invoke(
model='bge-reranker-base',
credentials={
'server_url': 'https://example.com',
'api_key': '1234567890'
},
query='Organic skincare products for sensitive skin',
docs=[],
top_n=3,
score_threshold=0.75,
user="abc-123"
)
assert isinstance(result, RerankResult)
assert len(result.docs) == 0

# Test case 2: Valid invocation
result = model._invoke(
model='bge-reranker-base',
credentials={
'server_url': 'https://example.com',
'api_key': '1234567890'
},
query='Organic skincare products for sensitive skin',
docs=[
"Eco-friendly kitchenware for modern homes",
"Biodegradable cleaning supplies for eco-conscious consumers",
"Organic cotton baby clothes for sensitive skin",
"Natural organic skincare range for sensitive skin",
"Tech gadgets for smart homes: 2024 edition",
"Sustainable gardening tools and compost solutions",
"Sensitive skin-friendly facial cleansers and toners",
"Organic food wraps and storage solutions",
"Yoga mats made from recycled materials"
],
top_n=3,
score_threshold=0.75,
user="abc-123"
)
assert isinstance(result, RerankResult)
assert len(result.docs) == 3
assert all(isinstance(doc, RerankDocument) for doc in result.docs)

0 comments on commit a588df4

Please sign in to comment.