Browse Source

~ call knowledge searches in parallel in non-hybrid mode

Alexander Grimm 3 months ago
parent
commit
d182155fac
1 changed file with 40 additions and 16 deletions
  1. 40 16
      backend/open_webui/retrieval/utils.py

+ 40 - 16
backend/open_webui/retrieval/utils.py

@@ -260,23 +260,47 @@ def query_collection(
     k: int,
 ) -> dict:
     results = []
-    for query in queries:
-        log.debug(f"query_collection:query {query}")
-        query_embedding = embedding_function(query, prefix=RAG_EMBEDDING_QUERY_PREFIX)
-        for collection_name in collection_names:
+    error = False
+
+    def process_query_collection(collection_name, query_embedding):
+        try:
             if collection_name:
-                try:
-                    result = query_doc(
-                        collection_name=collection_name,
-                        k=k,
-                        query_embedding=query_embedding,
-                    )
-                    if result is not None:
-                        results.append(result.model_dump())
-                except Exception as e:
-                    log.exception(f"Error when querying the collection: {e}")
-            else:
-                pass
+                result = query_doc(
+                    collection_name=collection_name,
+                    k=k,
+                    query_embedding=query_embedding,
+                )
+                if result is not None:
+                    return result.model_dump(), None
+            return None, None
+        except Exception as e:
+            log.exception(f"Error when querying the collection: {e}")
+            return None, e
+
+    # Generate all query embeddings (in one call)
+    query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
+    log.debug(
+        f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
+    )
+
+    with ThreadPoolExecutor() as executor:
+        future_results = []
+        for query_embedding in query_embeddings:
+            for collection_name in collection_names:
+                result = executor.submit(
+                    process_query_collection, collection_name, query_embedding
+                )
+                future_results.append(result)
+        task_results = [future.result() for future in future_results]
+
+    for result, err in task_results:
+        if err is not None:
+            error = True
+        elif result is not None:
+            results.append(result)
+
+    if error and not results:
+        log.warning("All collection queries failed. No results returned.")
 
     return merge_and_sort_query_results(results, k=k)