浏览代码

Merge pull request #13098 from athoik/dev

feat: Add abstract base class for vector database integration
Tim Jaeryang Baek 3 月之前
父节点
当前提交
d3e516934c

+ 7 - 2
backend/open_webui/retrieval/vector/dbs/chroma.py

@@ -5,7 +5,12 @@ from chromadb.utils.batch_utils import create_batches
 
 
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
 from open_webui.config import (
     CHROMA_DATA_PATH,
     CHROMA_DATA_PATH,
     CHROMA_HTTP_HOST,
     CHROMA_HTTP_HOST,
@@ -23,7 +28,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 
 
-class ChromaClient:
+class ChromaClient(VectorDBBase):
     def __init__(self):
     def __init__(self):
         settings_dict = {
         settings_dict = {
             "allow_reset": True,
             "allow_reset": True,

+ 7 - 2
backend/open_webui/retrieval/vector/dbs/elasticsearch.py

@@ -2,7 +2,12 @@ from elasticsearch import Elasticsearch, BadRequestError
 from typing import Optional
 from typing import Optional
 import ssl
 import ssl
 from elasticsearch.helpers import bulk, scan
 from elasticsearch.helpers import bulk, scan
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
 from open_webui.config import (
     ELASTICSEARCH_URL,
     ELASTICSEARCH_URL,
     ELASTICSEARCH_CA_CERTS,
     ELASTICSEARCH_CA_CERTS,
@@ -15,7 +20,7 @@ from open_webui.config import (
 )
 )
 
 
 
 
-class ElasticsearchClient:
+class ElasticsearchClient(VectorDBBase):
     """
     """
     Important:
     Important:
     in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
     in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating

+ 7 - 2
backend/open_webui/retrieval/vector/dbs/milvus.py

@@ -4,7 +4,12 @@ import json
 import logging
 import logging
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
 from open_webui.config import (
     MILVUS_URI,
     MILVUS_URI,
     MILVUS_DB,
     MILVUS_DB,
@@ -16,7 +21,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 
 
-class MilvusClient:
+class MilvusClient(VectorDBBase):
     def __init__(self):
     def __init__(self):
         self.collection_prefix = "open_webui"
         self.collection_prefix = "open_webui"
         if MILVUS_TOKEN is None:
         if MILVUS_TOKEN is None:

+ 7 - 2
backend/open_webui/retrieval/vector/dbs/opensearch.py

@@ -2,7 +2,12 @@ from opensearchpy import OpenSearch
 from opensearchpy.helpers import bulk
 from opensearchpy.helpers import bulk
 from typing import Optional
 from typing import Optional
 
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
 from open_webui.config import (
     OPENSEARCH_URI,
     OPENSEARCH_URI,
     OPENSEARCH_SSL,
     OPENSEARCH_SSL,
@@ -12,7 +17,7 @@ from open_webui.config import (
 )
 )
 
 
 
 
-class OpenSearchClient:
+class OpenSearchClient(VectorDBBase):
     def __init__(self):
     def __init__(self):
         self.index_prefix = "open_webui"
         self.index_prefix = "open_webui"
         self.client = OpenSearch(
         self.client = OpenSearch(

+ 7 - 2
backend/open_webui/retrieval/vector/dbs/pgvector.py

@@ -22,7 +22,12 @@ from pgvector.sqlalchemy import Vector
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.exc import NoSuchTableError
 from sqlalchemy.exc import NoSuchTableError
 
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
 from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
 
 
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
@@ -44,7 +49,7 @@ class DocumentChunk(Base):
     vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
     vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
 
 
 
 
-class PgvectorClient:
+class PgvectorClient(VectorDBBase):
     def __init__(self) -> None:
     def __init__(self) -> None:
 
 
         # if no pgvector uri, use the existing database connection
         # if no pgvector uri, use the existing database connection

+ 7 - 2
backend/open_webui/retrieval/vector/dbs/pinecone.py

@@ -2,7 +2,12 @@ from typing import Optional, List, Dict, Any, Union
 import logging
 import logging
 from pinecone import Pinecone, ServerlessSpec
 from pinecone import Pinecone, ServerlessSpec
 
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
 from open_webui.config import (
     PINECONE_API_KEY,
     PINECONE_API_KEY,
     PINECONE_ENVIRONMENT,
     PINECONE_ENVIRONMENT,
@@ -20,7 +25,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 
 
-class PineconeClient:
+class PineconeClient(VectorDBBase):
     def __init__(self):
     def __init__(self):
         self.collection_prefix = "open-webui"
         self.collection_prefix = "open-webui"
 
 

+ 7 - 2
backend/open_webui/retrieval/vector/dbs/qdrant.py

@@ -6,7 +6,12 @@ from qdrant_client import QdrantClient as Qclient
 from qdrant_client.http.models import PointStruct
 from qdrant_client.http.models import PointStruct
 from qdrant_client.models import models
 from qdrant_client.models import models
 
 
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
 from open_webui.config import (
 from open_webui.config import (
     QDRANT_URI,
     QDRANT_URI,
     QDRANT_API_KEY,
     QDRANT_API_KEY,
@@ -22,7 +27,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 
 
-class QdrantClient:
+class QdrantClient(VectorDBBase):
     def __init__(self):
     def __init__(self):
         self.collection_prefix = "open-webui"
         self.collection_prefix = "open-webui"
         self.QDRANT_URI = QDRANT_URI
         self.QDRANT_URI = QDRANT_URI

+ 68 - 1
backend/open_webui/retrieval/vector/main.py

@@ -1,5 +1,6 @@
 from pydantic import BaseModel
 from pydantic import BaseModel
-from typing import Optional, List, Any
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union
 
 
 
 
 class VectorItem(BaseModel):
 class VectorItem(BaseModel):
@@ -17,3 +18,69 @@ class GetResult(BaseModel):
 
 
 class SearchResult(GetResult):
 class SearchResult(GetResult):
     distances: Optional[List[List[float | int]]]
     distances: Optional[List[List[float | int]]]
+
+
+class VectorDBBase(ABC):
+    """
+    Abstract base class for all vector database backends.
+
+    Implementations of this class provide methods for collection management,
+    vector insertion, deletion, similarity search, and metadata filtering.
+
+    Any custom vector database integration must inherit from this class and
+    implement all abstract methods.
+    """
+
+    @abstractmethod
+    def has_collection(self, collection_name: str) -> bool:
+        """Check if the collection exists in the vector DB."""
+        pass
+
+    @abstractmethod
+    def delete_collection(self, collection_name: str) -> None:
+        """Delete a collection from the vector DB."""
+        pass
+
+    @abstractmethod
+    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Insert a list of vector items into a collection."""
+        pass
+
+    @abstractmethod
+    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
+        """Insert or update vector items in a collection."""
+        pass
+
+    @abstractmethod
+    def search(
+        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
+    ) -> Optional[SearchResult]:
+        """Search for similar vectors in a collection."""
+        pass
+
+    @abstractmethod
+    def query(
+        self, collection_name: str, filter: Dict, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        """Query vectors from a collection using metadata filter."""
+        pass
+
+    @abstractmethod
+    def get(self, collection_name: str) -> Optional[GetResult]:
+        """Retrieve all vectors from a collection."""
+        pass
+
+    @abstractmethod
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Dict] = None,
+    ) -> None:
+        """Delete vectors by ID or filter from a collection."""
+        pass
+
+    @abstractmethod
+    def reset(self) -> None:
+        """Reset the vector database by removing all collections or those matching a condition."""
+        pass