Commit d062099
use psycopg2, as the SQLAlchemy version is too slow
vaayne committed Apr 26, 2024
1 parent 942b77b commit d062099
Showing 1 changed file with 84 additions and 68 deletions.
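
The heart of the change is swapping SQLAlchemy's engine for a raw psycopg2.pool.SimpleConnectionPool wrapped in a small context manager that borrows a connection, yields a cursor, then commits and hands the connection back. A minimal, self-contained sketch of that pattern (the connection values here are placeholders, not values from the commit):

from contextlib import contextmanager

import psycopg2.pool

# Placeholder connection values -- not part of the commit.
pool = psycopg2.pool.SimpleConnectionPool(
    1,  # minconn
    5,  # maxconn
    host="localhost",
    port=5432,
    user="postgres",
    password="postgres",
    database="dify",
)

@contextmanager
def get_cursor():
    conn = pool.getconn()  # borrow a pooled connection
    cur = conn.cursor()
    try:
        yield cur
    finally:
        cur.close()
        conn.commit()  # note: runs even if the block raised, mirroring the diff
        pool.putconn(conn)  # return the connection to the pool for reuse

# Usage: any statement runs on a pooled connection and is committed on exit.
with get_cursor() as cur:
    cur.execute("SELECT 1")
    print(cur.fetchone())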
152 changes: 84 additions & 68 deletions api/core/rag/datasource/vdb/pgvector/pgvector.py
@@ -1,87 +1,110 @@
-from pgvector.sqlalchemy import Vector
+import json
+import uuid
+from contextlib import contextmanager
 from typing import Any
 
+import psycopg2.pool
 from pydantic import BaseModel
-from sqlalchemy import JSON, Column, MetaData, String, Table, create_engine
-from sqlalchemy.engine import URL
-from sqlalchemy.orm import mapped_column
 
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 
 
 class PGVectorConfig(BaseModel):
-    host: str
-    port: int
+    host: str = "localhost"
+    port: int = 5432
     user: str
     password: str
     database: str
 
 
-def new_embedding_table(collection_name: str, dimension: int):
-    metadata_obj = MetaData()
-    return Table(
-        f"embedding_{collection_name}",
-        metadata_obj,
-        Column("id", String, primary_key=True),
-        Column("meta", JSON),
-        Column("text", String),
-        Column("embedding", mapped_column(Vector(dimension))),
-    )
+SQL_CREATE_TABLE = """
+CREATE TABLE IF NOT EXISTS {table_name} (
+    id UUID PRIMARY KEY,
+    text TEXT NOT NULL,
+    meta JSONB NOT NULL,
+    embedding vector({dimension}) NOT NULL
+) using heap;
+"""


 class PGVector(BaseVector):
     def __init__(self, collection_name: str, config: PGVectorConfig, dimension: int):
         super().__init__(collection_name)
-        self.engine = self._create_engine(config)
-        self.table = None
+        self.pool = self._create_connection_pool(config)
+        self.table_name = f"embedding_{collection_name}"
 
     def get_type(self) -> str:
         return "pgvector"
 
-    def _create_engine(self, config: PGVectorConfig):
-        url = URL(
-            drivername="postgresql",
-            username=config.user,
-            password=config.password,
+    def _create_connection_pool(self, config: PGVectorConfig):
+        return psycopg2.pool.SimpleConnectionPool(
+            1,
+            5,
             host=config.host,
             port=config.port,
+            user=config.user,
+            password=config.password,
             database=config.database,
         )
-        return create_engine(url)
 
+    @contextmanager
+    def _get_cursor(self):
+        conn = self.pool.getconn()
+        cur = conn.cursor()
+        try:
+            yield cur
+        finally:
+            cur.close()
+            conn.commit()
+            self.pool.putconn(conn)
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         dimension = len(embeddings[0])
-        self.table = new_embedding_table(self._collection_name, dimension)
-        self._create_collection()
-        self.add_texts(texts, embeddings)
+        self._create_collection(dimension)
+        return self.add_texts(texts, embeddings)

     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
-        data = []
-        ids = []
+        values = []
+        pks = []
         for i, doc in enumerate(documents):
-            ids.append(doc.metadata["doc_id"])
-            data.append(
-                {
-                    "id": doc.metadata["doc_id"],
-                    "meta": doc.metadata,
-                    "text": doc.page_content,
-                    "embedding": embeddings[i],
-                }
+            doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
+            pks.append(doc_id)
+            values.append(
+                (
+                    doc_id,
+                    doc.page_content,
+                    json.dumps(doc.metadata),
+                    embeddings[i],
+                )
             )
-        with self.engine.connect() as conn:
-            conn.execute(self.table.insert(), data)
+        with self._get_cursor() as cur:
+            cur.executemany(
+                f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (%s, %s, %s, %s)", values
+            )
+        return pks

     def text_exists(self, id: str) -> bool:
-        with self.engine.connect() as conn:
-            return conn.execute(self.table.select().where(self.table.c.id == id)).fetchone() is not None
+        with self._get_cursor() as cur:
+            cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
+            return cur.fetchone() is not None
+
+    def get_by_ids(self, ids: list[str]) -> list[Document]:
+        with self._get_cursor() as cur:
+            cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
+            docs = []
+            for record in cur:
+                docs.append(Document(page_content=record[1], metadata=record[0]))
+        return docs
 
     def delete_by_ids(self, ids: list[str]) -> None:
-        with self.engine.connect() as conn:
-            conn.execute(self.table.delete().where(self.table.c.id.in_(ids)))
+        with self._get_cursor() as cur:
+            cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
 
     def delete_by_metadata_field(self, key: str, value: str) -> None:
-        with self.engine.connect() as conn:
-            conn.execute(self.table.delete().where(self.table.c.meta[key].astext == value))
+        with self._get_cursor() as cur:
+            cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         """
@@ -92,44 +115,37 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         :param distance_metric: The distance metric to use ('l2', 'max_inner_product', 'cosine').
         :return: List of Documents that are nearest to the query vector.
         """
-        distance_metric = kwargs.get("distance_metric", "l2")
         top_k = kwargs.get("top_k", 5)
 
-        # Build the order_by clause based on the distance metric specified
-        if distance_metric == "l2":
-            order_clause = self.table.c.embedding.l2_distance(query_vector)
-        elif distance_metric == "max_inner_product":
-            order_clause = self.table.c.embedding.max_inner_product(query_vector)
-        elif distance_metric == "cosine":
-            order_clause = self.table.c.embedding.cosine_distance(query_vector)
-        else:
-            raise ValueError(f"Unsupported distance metric: {distance_metric}")
-
-        with self.engine.connect() as conn:
-            results = conn.scalars(self.table.select().order_by(order_clause).limit(top_k)).all()
-            docs = []
-            for ret in results:
-                docs.append(Document(page_content=ret.text, metadata=ret.meta))
+        with self._get_cursor() as cur:
+            cur.execute(
+                f"SELECT meta, text FROM {self.table_name} ORDER BY embedding <-> %s LIMIT {top_k}",
+                (json.dumps(query_vector),),
+            )
+            docs = []
+            for record in cur:
+                docs.append(Document(page_content=record[1], metadata=record[0]))
         return docs

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         # do not support bm25 search
         return []
 
     def delete(self) -> None:
-        with self.engine.connect() as conn:
-            conn.execute(self.table.drop())
+        with self._get_cursor() as cur:
+            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
 
-    def _create_collection(self):
+    def _create_collection(self, dimension: int):
         cache_key = f"vector_indexing_{self._collection_name}"
         lock_name = f"{cache_key}_lock"
         with redis_client.lock(lock_name, timeout=20):
             collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
                 return
-            with self.engine.connect() as conn:
-                conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
-                conn.execute(self.table.drop(if_exists=True))
-                conn.execute(self.table.create())
-
+            with self._get_cursor() as cur:
+                cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+                cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+                cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
+                # TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
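
The trailing # TODO defers indexing to the linked pgvector README, so every search_by_vector call above scans the whole table. Purely as an illustration of what the TODO points at (not part of this commit; the index name is invented, and HNSW requires pgvector 0.5.0+), an index matching the <-> operator used by the query could be created right after the table:

# Hypothetical follow-up inside _create_collection -- the commit leaves this as a TODO.
with self._get_cursor() as cur:
    # vector_l2_ops matches the <-> (L2 distance) operator used by search_by_vector.
    cur.execute(
        f"CREATE INDEX IF NOT EXISTS {self.table_name}_embedding_idx "
        f"ON {self.table_name} USING hnsw (embedding vector_l2_ops)"
    )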

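One side effect worth noting: the docstring still advertises a distance_metric argument ('l2', 'max_inner_product', 'cosine'), but the rewritten query hard-codes pgvector's <-> (L2 distance) operator, so the argument is now ignored. A sketch of how the old names map onto pgvector's operators, were the option to be restored (not part of the commit):

# pgvector comparison operators keyed by the old distance_metric names (sketch only).
PGVECTOR_OPERATORS = {
    "l2": "<->",                 # Euclidean (L2) distance
    "max_inner_product": "<#>",  # negative inner product
    "cosine": "<=>",             # cosine distance
}

def distance_operator(distance_metric: str = "l2") -> str:
    try:
        return PGVECTOR_OPERATORS[distance_metric]
    except KeyError:
        raise ValueError(f"Unsupported distance metric: {distance_metric}")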