Pārlūkot izejas kodu

disable collection retrieval and bm_25 calculation if bm_25 weight is 0 or less

expruc 1 mēnesi atpakaļ
vecāks
revīzija
74b1c80132
1 mainītis faili ar 24 papildinājumiem un 18 dzēšanām
  1. 24 18
      backend/open_webui/retrieval/utils.py

+ 24 - 18
backend/open_webui/retrieval/utils.py

@@ -124,12 +124,14 @@ def query_doc_with_hybrid_search(
     hybrid_bm25_weight: float,
 ) -> dict:
     try:
-        log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
-        bm25_retriever = BM25Retriever.from_texts(
-            texts=collection_result.documents[0],
-            metadatas=collection_result.metadatas[0],
-        )
-        bm25_retriever.k = k
+        # BM_25 required only if weight is greater than 0
+        if hybrid_bm25_weight > 0:
+            log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
+            bm25_retriever = BM25Retriever.from_texts(
+                texts=collection_result.documents[0],
+                metadatas=collection_result.metadatas[0],
+            )
+            bm25_retriever.k = k
 
         vector_search_retriever = VectorSearchRetriever(
             collection_name=collection_name,
@@ -337,18 +339,22 @@ def query_collection_with_hybrid_search(
     # Fetch collection data once per collection sequentially
     # Avoid fetching the same data multiple times later
     collection_results = {}
-    for collection_name in collection_names:
-        try:
-            log.debug(
-                f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
-            )
-            collection_results[collection_name] = VECTOR_DB_CLIENT.get(
-                collection_name=collection_name
-            )
-        except Exception as e:
-            log.exception(f"Failed to fetch collection {collection_name}: {e}")
-            collection_results[collection_name] = None
-
+    # Only retrieve entire collection if bm_25 calculation is required
+    if hybrid_bm25_weight > 0:
+        for collection_name in collection_names:
+            try:
+                log.debug(
+                    f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
+                )
+                collection_results[collection_name] = VECTOR_DB_CLIENT.get(
+                    collection_name=collection_name
+                )
+            except Exception as e:
+                log.exception(f"Failed to fetch collection {collection_name}: {e}")
+                collection_results[collection_name] = None
+    else:
+        for collection_name in collection_names:
+            collection_results[collection_name] = []
     log.info(
         f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
     )