Forráskód Böngészése

Merge pull request #13098 from athoik/dev

feat: Add abstract base class for vector database integration
Tim Jaeryang Baek 3 hónapja
szülő
commit
d3e516934c

+ 7 - 2
backend/open_webui/retrieval/vector/dbs/chroma.py

@@ -5,7 +5,12 @@ from chromadb.utils.batch_utils import create_batches
 
 from typing import Optional
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     CHROMA_DATA_PATH,
     CHROMA_HTTP_HOST,
@@ -23,7 +28,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-class ChromaClient:
+class ChromaClient(VectorDBBase):
     def __init__(self):
         settings_dict = {
             "allow_reset": True,

+ 7 - 2
backend/open_webui/retrieval/vector/dbs/elasticsearch.py

@@ -2,7 +2,12 @@ from elasticsearch import Elasticsearch, BadRequestError
 from typing import Optional
 import ssl
 from elasticsearch.helpers import bulk, scan
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     ELASTICSEARCH_URL,
     ELASTICSEARCH_CA_CERTS,
@@ -15,7 +20,7 @@ from open_webui.config import (
 )
 
 
-class ElasticsearchClient:
+class ElasticsearchClient(VectorDBBase):
     """
     Important:
     in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating

+ 7 - 2
backend/open_webui/retrieval/vector/dbs/milvus.py

@@ -4,7 +4,12 @@ import json
 import logging
 from typing import Optional
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     MILVUS_URI,
     MILVUS_DB,
@@ -16,7 +21,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-class MilvusClient:
+class MilvusClient(VectorDBBase):
     def __init__(self):
         self.collection_prefix = "open_webui"
         if MILVUS_TOKEN is None:

+ 7 - 2
backend/open_webui/retrieval/vector/dbs/opensearch.py

@@ -2,7 +2,12 @@ from opensearchpy import OpenSearch
 from opensearchpy.helpers import bulk
 from typing import Optional
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     OPENSEARCH_URI,
     OPENSEARCH_SSL,
@@ -12,7 +17,7 @@ from open_webui.config import (
 )
 
 
-class OpenSearchClient:
+class OpenSearchClient(VectorDBBase):
     def __init__(self):
         self.index_prefix = "open_webui"
         self.client = OpenSearch(

+ 7 - 2
backend/open_webui/retrieval/vector/dbs/pgvector.py

@@ -22,7 +22,12 @@ from pgvector.sqlalchemy import Vector
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.exc import NoSuchTableError
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
 
 from open_webui.env import SRC_LOG_LEVELS
@@ -44,7 +49,7 @@ class DocumentChunk(Base):
     vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
 
 
-class PgvectorClient:
+class PgvectorClient(VectorDBBase):
     def __init__(self) -> None:
 
         # if no pgvector uri, use the existing database connection

+ 7 - 2
backend/open_webui/retrieval/vector/dbs/pinecone.py

@@ -2,7 +2,12 @@ from typing import Optional, List, Dict, Any, Union
 import logging
 from pinecone import Pinecone, ServerlessSpec
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     PINECONE_API_KEY,
     PINECONE_ENVIRONMENT,
@@ -20,7 +25,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-class PineconeClient:
+class PineconeClient(VectorDBBase):
     def __init__(self):
         self.collection_prefix = "open-webui"
 

+ 7 - 2
backend/open_webui/retrieval/vector/dbs/qdrant.py

@@ -6,7 +6,12 @@ from qdrant_client import QdrantClient as Qclient
 from qdrant_client.http.models import PointStruct
 from qdrant_client.models import models
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
     QDRANT_URI,
     QDRANT_API_KEY,
@@ -22,7 +27,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-class QdrantClient:
+class QdrantClient(VectorDBBase):
     def __init__(self):
         self.collection_prefix = "open-webui"
         self.QDRANT_URI = QDRANT_URI

+ 68 - 1
backend/open_webui/retrieval/vector/main.py

@@ -1,5 +1,6 @@
 from pydantic import BaseModel
-from typing import Optional, List, Any
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union
 
 
 class VectorItem(BaseModel):
@@ -17,3 +18,69 @@ class GetResult(BaseModel):
 
 class SearchResult(GetResult):
     distances: Optional[List[List[float | int]]]
+
+
+class VectorDBBase(ABC):
+    """
+    Abstract base class for all vector database backends.
+
+    Implementations of this class provide methods for collection management,
+    vector insertion, deletion, similarity search, and metadata filtering.
+
+    Any custom vector database integration must inherit from this class and
+    implement all abstract methods.
+    """
+
+    @abstractmethod
+    def has_collection(self, collection_name: str) -> bool:
+        """Check if the collection exists in the vector DB."""
+        pass
+
+    @abstractmethod
+    def delete_collection(self, collection_name: str) -> None:
+        """Delete a collection from the vector DB."""
+        pass
+
+    @abstractmethod
+    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Insert a list of vector items into a collection."""
+        pass
+
+    @abstractmethod
+    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Insert or update vector items in a collection."""
+        pass
+
+    @abstractmethod
+    def search(
+        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
+    ) -> Optional[SearchResult]:
+        """Search for similar vectors in a collection."""
+        pass
+
+    @abstractmethod
+    def query(
+        self, collection_name: str, filter: Dict, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        """Query vectors from a collection using metadata filter."""
+        pass
+
+    @abstractmethod
+    def get(self, collection_name: str) -> Optional[GetResult]:
+        """Retrieve all vectors from a collection."""
+        pass
+
+    @abstractmethod
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Dict] = None,
+    ) -> None:
+        """Delete vectors by ID or filter from a collection."""
+        pass
+
+    @abstractmethod
+    def reset(self) -> None:
+        """Reset the vector database by removing all collections or those matching a condition."""
+        pass