Просмотр исходного кода

Merge pull request #16419 from expruc/feat/qdrant_improvements

feat: qdrant client improvements
Tim Jaeryang Baek 1 месяц назад
Родитель
Коммит
53425ffadb

+ 2 - 0
backend/open_webui/config.py

@@ -1924,6 +1924,8 @@ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
 QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true"
 QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "false").lower() == "true"
 QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334"))
+QDRANT_TIMEOUT = int(os.environ.get("QDRANT_TIMEOUT", "5"))
+QDRANT_HNSW_M = int(os.environ.get("QDRANT_HNSW_M", "16"))
 ENABLE_QDRANT_MULTITENANCY_MODE = (
     os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "true").lower() == "true"
 )

+ 14 - 6
backend/open_webui/retrieval/vector/dbs/qdrant.py

@@ -19,6 +19,8 @@ from open_webui.config import (
     QDRANT_GRPC_PORT,
     QDRANT_PREFER_GRPC,
     QDRANT_COLLECTION_PREFIX,
+    QDRANT_TIMEOUT,
+    QDRANT_HNSW_M,
 )
 from open_webui.env import SRC_LOG_LEVELS
 
@@ -36,6 +38,8 @@ class QdrantClient(VectorDBBase):
         self.QDRANT_ON_DISK = QDRANT_ON_DISK
         self.PREFER_GRPC = QDRANT_PREFER_GRPC
         self.GRPC_PORT = QDRANT_GRPC_PORT
+        self.QDRANT_TIMEOUT = QDRANT_TIMEOUT
+        self.QDRANT_HNSW_M = QDRANT_HNSW_M
 
         if not self.QDRANT_URI:
             self.client = None
@@ -53,9 +57,10 @@ class QdrantClient(VectorDBBase):
                 grpc_port=self.GRPC_PORT,
                 prefer_grpc=self.PREFER_GRPC,
                 api_key=self.QDRANT_API_KEY,
+                timeout=self.QDRANT_TIMEOUT,
             )
         else:
-            self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
+            self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY, timeout=QDRANT_TIMEOUT,)
 
     def _result_to_get_result(self, points) -> GetResult:
         ids = []
@@ -85,6 +90,9 @@ class QdrantClient(VectorDBBase):
                 distance=models.Distance.COSINE,
                 on_disk=self.QDRANT_ON_DISK,
             ),
+            hnsw_config=models.HnswConfigDiff(
+                m=self.QDRANT_HNSW_M,
+            ),
         )
 
         # Create payload indexes for efficient filtering
@@ -171,23 +179,23 @@ class QdrantClient(VectorDBBase):
                     )
                 )
 
-            points = self.client.query_points(
+            points = self.client.scroll(
                 collection_name=f"{self.collection_prefix}_{collection_name}",
-                query_filter=models.Filter(should=field_conditions),
+                scroll_filter=models.Filter(should=field_conditions),
                 limit=limit,
             )
-            return self._result_to_get_result(points.points)
+            return self._result_to_get_result(points[0])
         except Exception as e:
             log.exception(f"Error querying a collection '{collection_name}': {e}")
             return None
 
     def get(self, collection_name: str) -> Optional[GetResult]:
         # Get all the items in the collection.
-        points = self.client.query_points(
+        points = self.client.scroll(
             collection_name=f"{self.collection_prefix}_{collection_name}",
             limit=NO_LIMIT,  # otherwise qdrant would set limit to 10!
         )
-        return self._result_to_get_result(points.points)
+        return self._result_to_get_result(points[0])
 
     def insert(self, collection_name: str, items: list[VectorItem]):
         # Insert the items into the collection, if the collection does not exist, it will be created.

+ 18 - 7
backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py

@@ -10,6 +10,8 @@ from open_webui.config import (
     QDRANT_PREFER_GRPC,
     QDRANT_URI,
     QDRANT_COLLECTION_PREFIX,
+    QDRANT_TIMEOUT,
+    QDRANT_HNSW_M,
 )
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.retrieval.vector.main import (
@@ -51,6 +53,8 @@ class QdrantClient(VectorDBBase):
         self.QDRANT_ON_DISK = QDRANT_ON_DISK
         self.PREFER_GRPC = QDRANT_PREFER_GRPC
         self.GRPC_PORT = QDRANT_GRPC_PORT
+        self.QDRANT_TIMEOUT = QDRANT_TIMEOUT
+        self.QDRANT_HNSW_M = QDRANT_HNSW_M
 
         if not self.QDRANT_URI:
             raise ValueError(
@@ -69,9 +73,10 @@ class QdrantClient(VectorDBBase):
                 grpc_port=self.GRPC_PORT,
                 prefer_grpc=self.PREFER_GRPC,
                 api_key=self.QDRANT_API_KEY,
+                timeout=self.QDRANT_TIMEOUT,
             )
             if self.PREFER_GRPC
-            else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
+            else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY, timeout=self.QDRANT_TIMEOUT,)
         )
 
         # Main collection types for multi-tenancy
@@ -133,6 +138,12 @@ class QdrantClient(VectorDBBase):
                 distance=models.Distance.COSINE,
                 on_disk=self.QDRANT_ON_DISK,
             ),
+            # Disable global index building due to multitenancy
+            # For more details https://qdrant.tech/documentation/guides/multiple-partitions/#calibrate-performance
+            hnsw_config=models.HnswConfigDiff(
+                payload_m=self.QDRANT_HNSW_M,
+                m=0,
+            ),
         )
         log.info(
             f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
@@ -278,12 +289,12 @@ class QdrantClient(VectorDBBase):
         tenant_filter = _tenant_filter(tenant_id)
         field_conditions = [_metadata_filter(k, v) for k, v in filter.items()]
         combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
-        points = self.client.query_points(
+        points = self.client.scroll(
             collection_name=mt_collection,
-            query_filter=combined_filter,
+            scroll_filter=combined_filter,
             limit=limit,
         )
-        return self._result_to_get_result(points.points)
+        return self._result_to_get_result(points[0])
 
     def get(self, collection_name: str) -> Optional[GetResult]:
         """
@@ -296,12 +307,12 @@ class QdrantClient(VectorDBBase):
             log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
             return None
         tenant_filter = _tenant_filter(tenant_id)
-        points = self.client.query_points(
+        points = self.client.scroll(
             collection_name=mt_collection,
-            query_filter=models.Filter(must=[tenant_filter]),
+            scroll_filter=models.Filter(must=[tenant_filter]),
             limit=NO_LIMIT,
         )
-        return self._result_to_get_result(points.points)
+        return self._result_to_get_result(points[0])
 
     def upsert(self, collection_name: str, items: List[VectorItem]):
         """