Browse Source

feat: Add vector type and vector factory class for vector database integration

hwzhuhao 1 month ago
parent
commit
6f869ded43

+ 1 - 1
backend/open_webui/retrieval/utils.py

@@ -12,7 +12,7 @@ from langchain_community.retrievers import BM25Retriever
 from langchain_core.documents import Document
 
 from open_webui.config import VECTOR_DB
-from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
+from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
 
 from open_webui.models.users import UserModel
 from open_webui.models.files import Files

+ 0 - 30
backend/open_webui/retrieval/vector/connector.py

@@ -1,30 +0,0 @@
-from open_webui.config import VECTOR_DB
-
-if VECTOR_DB == "milvus":
-    from open_webui.retrieval.vector.dbs.milvus import MilvusClient
-
-    VECTOR_DB_CLIENT = MilvusClient()
-elif VECTOR_DB == "qdrant":
-    from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
-
-    VECTOR_DB_CLIENT = QdrantClient()
-elif VECTOR_DB == "opensearch":
-    from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
-
-    VECTOR_DB_CLIENT = OpenSearchClient()
-elif VECTOR_DB == "pgvector":
-    from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
-
-    VECTOR_DB_CLIENT = PgvectorClient()
-elif VECTOR_DB == "elasticsearch":
-    from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient
-
-    VECTOR_DB_CLIENT = ElasticsearchClient()
-elif VECTOR_DB == "pinecone":
-    from open_webui.retrieval.vector.dbs.pinecone import PineconeClient
-
-    VECTOR_DB_CLIENT = PineconeClient()
-else:
-    from open_webui.retrieval.vector.dbs.chroma import ChromaClient
-
-    VECTOR_DB_CLIENT = ChromaClient()

+ 48 - 0
backend/open_webui/retrieval/vector/factory.py

@@ -0,0 +1,48 @@
+from open_webui.retrieval.vector.main import VectorDBBase
+from open_webui.retrieval.vector.type import VectorType
+from open_webui.config import VECTOR_DB
+
+
+class Vector:
+
+    @staticmethod
+    def get_vector(vector_type: str) -> VectorDBBase:
+        """
+        get vector db instance by vector type
+        """
+        match vector_type:
+            case VectorType.MILVUS:
+                from open_webui.retrieval.vector.dbs.milvus import MilvusClient
+
+                return MilvusClient()
+            case VectorType.QDRANT:
+                from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
+
+                return QdrantClient()
+            case VectorType.PINECONE:
+                from open_webui.retrieval.vector.dbs.pinecone import PineconeClient
+
+                return PineconeClient()
+            case VectorType.OPENSEARCH:
+                from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
+
+                return OpenSearchClient()
+            case VectorType.PGVECTOR:
+                from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
+
+                return PgvectorClient()
+            case VectorType.ELASTICSEARCH:
+                from open_webui.retrieval.vector.dbs.elasticsearch import (
+                    ElasticsearchClient,
+                )
+
+                return ElasticsearchClient()
+            case VectorType.CHROMA:
+                from open_webui.retrieval.vector.dbs.chroma import ChromaClient
+
+                return ChromaClient()
+            case _:
+                raise ValueError(f"Unsupported vector type: {vector_type}")
+
+
+VECTOR_DB_CLIENT = Vector.get_vector(VECTOR_DB)

+ 11 - 0
backend/open_webui/retrieval/vector/type.py

@@ -0,0 +1,11 @@
+from enum import StrEnum
+
+
+class VectorType(StrEnum):
+    MILVUS = "milvus"
+    QDRANT = "qdrant"
+    CHROMA = "chroma"
+    PINECONE = "pinecone"
+    ELASTICSEARCH = "elasticsearch"
+    OPENSEARCH = "opensearch"
+    PGVECTOR = "pgvector"

+ 1 - 1
backend/open_webui/routers/knowledge.py

@@ -10,7 +10,7 @@ from open_webui.models.knowledge import (
     KnowledgeUserResponse,
 )
 from open_webui.models.files import Files, FileModel, FileMetadataResponse
-from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
+from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
 from open_webui.routers.retrieval import (
     process_file,
     ProcessFileForm,

+ 1 - 1
backend/open_webui/routers/memories.py

@@ -4,7 +4,7 @@ import logging
 from typing import Optional
 
 from open_webui.models.memories import Memories, MemoryModel
-from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
+from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
 from open_webui.utils.auth import get_verified_user
 from open_webui.env import SRC_LOG_LEVELS
 

+ 1 - 1
backend/open_webui/routers/retrieval.py

@@ -36,7 +36,7 @@ from open_webui.models.knowledge import Knowledges
 from open_webui.storage.provider import Storage
 
 
-from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
+from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
 
 # Document loaders
 from open_webui.retrieval.loaders.main import Loader