Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/vector db pgvector #3879

Merged
merged 12 commits into from
May 10, 2024
4 changes: 3 additions & 1 deletion .github/workflows/api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,23 @@ jobs:
- name: Run Workflow
run: dev/pytest/pytest_workflow.sh

- name: Set up Vector Stores (Weaviate, Qdrant, Milvus, PgVecto-RS)
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS)
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
docker/docker-compose.qdrant.yaml
docker/docker-compose.milvus.yaml
docker/docker-compose.pgvecto-rs.yaml
docker/docker-compose.pgvector.yaml
services: |
weaviate
qdrant
etcd
minio
milvus-standalone
pgvecto-rs
pgvector

- name: Test Vector Stores
run: dev/pytest/pytest_vdb.sh
9 changes: 8 additions & 1 deletion api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON=your-google-service-account-json-base64-stri
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*

# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs
# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs, pgvector
VECTOR_STORE=weaviate

# Weaviate configuration
Expand Down Expand Up @@ -99,6 +99,13 @@ PGVECTO_RS_USER=postgres
PGVECTO_RS_PASSWORD=difyai123456
PGVECTO_RS_DATABASE=postgres

# PGVector configuration
PGVECTOR_HOST=127.0.0.1
PGVECTOR_PORT=5433
PGVECTOR_USER=postgres
PGVECTOR_PASSWORD=postgres
PGVECTOR_DATABASE=postgres

# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
Expand Down
8 changes: 8 additions & 0 deletions api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,14 @@ def migrate_knowledge_vector_database():
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "pgvector":
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'pgvector',
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

Expand Down
9 changes: 8 additions & 1 deletion api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def __init__(self):

# ------------------------
# Vector Store Configurations.
# Currently, only support: qdrant, milvus, zilliz, weaviate, relyt
# Currently, only support: qdrant, milvus, zilliz, weaviate, relyt, pgvector
# ------------------------
self.VECTOR_STORE = get_env('VECTOR_STORE')
self.KEYWORD_STORE = get_env('KEYWORD_STORE')
Expand Down Expand Up @@ -259,6 +259,13 @@ def __init__(self):
self.PGVECTO_RS_PASSWORD = get_env('PGVECTO_RS_PASSWORD')
self.PGVECTO_RS_DATABASE = get_env('PGVECTO_RS_DATABASE')

# pgvector settings
self.PGVECTOR_HOST = get_env('PGVECTOR_HOST')
self.PGVECTOR_PORT = get_env('PGVECTOR_PORT')
self.PGVECTOR_USER = get_env('PGVECTOR_USER')
self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD')
self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE')

# ------------------------
# Mail Configurations.
# ------------------------
Expand Down
9 changes: 4 additions & 5 deletions api/controllers/console/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,13 +476,13 @@ class DatasetRetrievalSettingApi(Resource):
@account_initialization_required
def get(self):
vector_type = current_app.config['VECTOR_STORE']
if vector_type == 'milvus' or vector_type == 'pgvecto_rs' or vector_type == 'relyt':
if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs"}:
return {
'retrieval_method': [
'semantic_search'
]
}
elif vector_type == 'qdrant' or vector_type == 'weaviate':
elif vector_type in {"qdrant", "weaviate"}:
return {
'retrieval_method': [
'semantic_search', 'full_text_search', 'hybrid_search'
Expand All @@ -497,14 +497,13 @@ class DatasetRetrievalSettingMockApi(Resource):
@login_required
@account_initialization_required
def get(self, vector_type):

if vector_type == 'milvus' or vector_type == 'relyt':
if vector_type in {'milvus', 'relyt', 'pgvector'}:
return {
'retrieval_method': [
'semantic_search'
]
}
elif vector_type == 'qdrant' or vector_type == 'weaviate':
elif vector_type in {'qdrant', 'weaviate'}:
return {
'retrieval_method': [
'semantic_search', 'full_text_search', 'hybrid_search'
Expand Down
Empty file.
169 changes: 169 additions & 0 deletions api/core/rag/datasource/vdb/pgvector/pgvector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any

import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, root_validator

from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client


class PGVectorConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str

@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config PGVECTOR_HOST is required")
if not values["port"]:
raise ValueError("config PGVECTOR_PORT is required")
if not values["user"]:
raise ValueError("config PGVECTOR_USER is required")
if not values["password"]:
raise ValueError("config PGVECTOR_PASSWORD is required")
if not values["database"]:
raise ValueError("config PGVECTOR_DATABASE is required")
return values


SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
id UUID PRIMARY KEY,
text TEXT NOT NULL,
meta JSONB NOT NULL,
embedding vector({dimension}) NOT NULL
) using heap;
"""


class PGVector(BaseVector):
def __init__(self, collection_name: str, config: PGVectorConfig):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = f"embedding_{collection_name}"

def get_type(self) -> str:
return "pgvector"

def _create_connection_pool(self, config: PGVectorConfig):
return psycopg2.pool.SimpleConnectionPool(
1,
5,
host=config.host,
port=config.port,
user=config.user,
password=config.password,
database=config.database,
)

@contextmanager
def _get_cursor(self):
conn = self.pool.getconn()
cur = conn.cursor()
try:
yield cur
finally:
cur.close()
conn.commit()
self.pool.putconn(conn)

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(texts, embeddings)

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
pks = []
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
embeddings[i],
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_values(
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
)
return pks

def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
return cur.fetchone() is not None

def get_by_ids(self, ids: list[str]) -> list[Document]:
with self._get_cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
docs = []
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs

def delete_by_ids(self, ids: list[str]) -> None:
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))

def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
vaayne marked this conversation as resolved.
Show resolved Hide resolved
"""
Search the nearest neighbors to a vector.

:param query_vector: The input vector to search for similar items.
:param top_k: The number of nearest neighbors to return, default is 5.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 5)

with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name} ORDER BY distance LIMIT {top_k}",
(json.dumps(query_vector),),
)
docs = []
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
for record in cur:
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# do not support bm25 search
return []

def delete(self) -> None:
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

def _create_collection(self, dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return

with self._get_cursor() as cur:
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
redis_client.set(collection_exist_cache_key, 1, ex=3600)
23 changes: 23 additions & 0 deletions api/core/rag/datasource/vdb/vector_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,29 @@ def _init_vector(self) -> BaseVector:
),
dim=dim
)
elif vector_type == "pgvector":
from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig

if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": "pgvector",
"vector_store": {"class_prefix": collection_name}}
self._dataset.index_struct = json.dumps(index_struct_dict)
return PGVector(
collection_name=collection_name,
config=PGVectorConfig(
host=config.get("PGVECTOR_HOST"),
port=config.get("PGVECTOR_PORT"),
user=config.get("PGVECTOR_USER"),
password=config.get("PGVECTOR_PASSWORD"),
database=config.get("PGVECTOR_DATABASE"),
),
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

Expand Down
1 change: 1 addition & 0 deletions api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,4 @@ pydantic~=1.10.0
pgvecto-rs==0.1.4
firecrawl-py==0.0.5
oss2==2.15.0
pgvector==0.2.5
Empty file.
30 changes: 30 additions & 0 deletions api/tests/integration_tests/vdb/pgvector/test_pgvector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_text,
setup_mock_redis,
)


class TestPGVector(AbstractVectorTest):
def __init__(self):
super().__init__()
self.vector = PGVector(
collection_name=self.collection_name,
config=PGVectorConfig(
host="localhost",
port=5433,
user="postgres",
password="difyai123456",
database="dify",
),
)

def search_by_full_text(self):
hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0


def test_pgvector(setup_mock_redis):
TestPGVector().run_all_tests()
24 changes: 24 additions & 0 deletions docker/docker-compose.pgvector.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
version: '3'
services:
# Qdrant vector store.
pgvector:
image: pgvector/pgvector:pg16
restart: always
environment:
PGUSER: postgres
# The password for the default postgres user.
POSTGRES_PASSWORD: difyai123456
# The name of the default postgres database.
POSTGRES_DB: dify
# postgres data directory
PGDATA: /var/lib/postgresql/data/pgdata
volumes:
- ./volumes/pgvector/data:/var/lib/postgresql/data
# uncomment to expose db(postgresql) port to host
ports:
- "5433:5432"
healthcheck:
test: [ "CMD", "pg_isready" ]
interval: 1s
timeout: 3s
retries: 30