Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add code for AstraDB #197

Open
wants to merge 26 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions mirascope/astra/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""A module for interacting with Astra vectorstores."""
from .types import AstraParams, AstraQueryResult, AstraSettings
from .vectorstores import AstraVectorStore
brenkao marked this conversation as resolved.
Show resolved Hide resolved

__all__ = [
"AstraParams",
"AstraQueryResult",
"AstraSettings",
"AstraVectorStore",
]
94 changes: 94 additions & 0 deletions mirascope/astra/types.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add docstrings to the module and all public classes and functions. They should follow the Google docstring style

Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Types for interacting with Astra DB using Mirascope."""
from pydantic import BaseModel
brenkao marked this conversation as resolved.
Show resolved Hide resolved
from typing import Any, Optional, Literal, Dict
from .vectorstores import BaseVectorStoreParams # Adjust import based on actual location


class AstraSettings(BaseModel):
brenkao marked this conversation as resolved.
Show resolved Hide resolved
"""
AstraSettings stores the configuration parameters necessary to establish a connection with AstraDB.
These parameters include the API endpoint, the application token, etc. to be used when interacting
with AstraDB.
"""
token: str
api_endpoint: str
api_path: Optional[str] = None
api_version: Optional[str] = None
namespace: Optional[str] = None
caller_name: Optional[str] = None
caller_version: Optional[str] = None

def kwargs(self):
"""Return a dictionary of settings suitable for passing to AstraDB client initialization."""
return self.dict(exclude_none=True)

class AstraParams(BaseVectorStoreParams):

"""AstraParams defines the parameters used for managing AstraDB collections.
These can include options for collection creation, dimensions for vector search,
metric choices, and other database-specific settings.

Example usage:
params = AstraParams(
collection_name="example_collection",
dimension=128,
metric="cosine",
service_dict={"example_key": "example_value"}
)

collection = my_astra_db_instance.create_collection(
params.collection_name,
options=params.options,
dimension=params.dimension,
metric=params.metric,
service_dict=params.service_dict,
timeout_info=params.timeout_info
)
"""
# Additional parameters can be added here as needed.
brenkao marked this conversation as resolved.
Show resolved Hide resolved
collection_name: str
options: Optional[Dict[str, Any]] = None
dimension: Optional[int] = None
metric: Optional[str] = None
service_dict: Optional[Dict[str, str]] = None
timeout_info: Optional[Any] = None # Type here should match the expected type for timeout_info

class Config:
extra = "allow" # This allows the class to accept other fields not explicitly defined here, if needed.



class AstraQueryResult(BaseModel):
"""
AstraQueryResult defines the structure of the results returned by queries to AstraDB.
It primarily wraps the documents retrieved as a list of lists, where each inner list
represents a document and its associated details.

Attributes:
documents (Optional[list[list[str]]]): A nested list where each sublist contains
details of a document, such as its text and source.

Methods:
convert(api_response: Any) -> 'AstraQueryResult':
Converts the API response into an AstraQueryResult format, extracting relevant
document details such as text and source from the response.
"""
documents: Optional[list[list[str]]] = None

@staticmethod
def convert(api_response):
"""
Converts a raw API response into an organized AstraQueryResult, making it easier to handle.

Args:
api_response (Any): The raw API response from which document details will be extracted.

Returns:
AstraQueryResult: The result object containing the structured documents.
"""
return AstraQueryResult(
documents=[
[f"text: {result['text']}", f"source: {result['source']}"]
for result in api_response
]
)
97 changes: 97 additions & 0 deletions mirascope/astra/vectorstores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""A module for calling Astra DB's Client and Collection."""
import logging
brenkao marked this conversation as resolved.
Show resolved Hide resolved
from contextlib import suppress
from functools import cached_property
from typing import Any, ClassVar, Optional, Union, Dict
from pydantic import BaseModel

from astrapy.db import AstraDB
from ..rag.types import Document
from ..rag.vectorstores import BaseVectorStore
from .types import AstraParams, AstraQueryResult, AstraSettings

class AstraVectorStoreParams(BaseModel):
get_or_create: bool = True
additional_params: Optional[Dict[str, Any]] = {}

class AstraVectorStore(BaseVectorStore):
"""AstraVectorStore integrates AstraDB with a vector storage mechanism, allowing for efficient
storage and retrieval of document vectors. This class handles the connection and operations
specific to AstraDB, such as adding and retrieving documents based on vector similarity.

"""

client_settings: ClassVar[AstraSettings] = AstraSettings()
index_name: ClassVar[str] = "default_collection" # Use BaseVectorStore's index_name if applicable
vectorstore_params: ClassVar[AstraVectorStoreParams] = AstraVectorStoreParams()


def retrieve(self, text: Optional[Union[str, list[str]]] = None, **kwargs: Any) -> AstraQueryResult:
"""
Queries the AstraDB vectorstore for documents that are the closest match to the input text.

Args:
text (str | list[str], optional): The text or list of texts to query against the database.
**kwargs: Additional keyword arguments for configuring the query, such as limit.

Returns:
AstraQueryResult: Contains the documents and possibly embeddings retrieved from the database.
"""

embedded_query = self.embedder(text)[0] if text else None
query_params = {**self.vectorstore_params.additional_params, **kwargs}
results = self._collection.vector_find(
embedded_query, **query_params
)

documents = []
embeddings = []
for result in results:
documents.append([result['text'], result['source']])
embeddings.append(result['embeddings']) # Assuming 'embeddings' is part of the results

return AstraQueryResult(documents=documents, embeddings=embeddings)


def add(self, text: Union[str, list[Document]], **kwargs: Any) -> None:
"""
Adds a new document or a list of documents to the AstraDB collection. Each document
must include the text, its embeddings, and optionally the source.

Args:
text (str | list[Document]): The text or documents to be added.
**kwargs: Additional keyword arguments such as filename which represents the source of the document.

Returns:
None
"""

if not text:
raise ValueError("No text provided for addition.")

documents = self.chunker.chunk(text) if isinstance(text, str) else text
for document in documents:
embeddings = self.embedder(document.text)[0]
document_to_insert = {
"text": document.text,
"$vector": embeddings, # Include vector embeddings
"source": kwargs.get("filename", "unknown") # Optionally include source file name
}
self._collection.insert_one(document_to_insert)

############################# PRIVATE PROPERTIES #################################

@cached_property
def _client(self) -> AstraDB:
"""Instantiate and return an AstraDB client configured from settings."""
try:
return AstraDB(**self.client_settings.kwargs()) # Dynamically passing settings
except Exception as e:
logging.error(f"Failed to initialize AstraDB client: {e}")
raise

@cached_property
def _collection(self):
"""Access or create the collection based on the parameters."""
collection_params = {**self.vectorstore_params.dict(), "name": self.index_name}
return self._client.create_collection(**collection_params)