Timothy Jaeryang Baek пре 1 месец
родитељ
комит
50b8dec3ac
1 измењених фајлова са 18 додато и 13 уклоњено
  1. 18 13
      backend/open_webui/retrieval/utils.py

+ 18 - 13
backend/open_webui/retrieval/utils.py

@@ -16,6 +16,8 @@ from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.models.users import UserModel
 from open_webui.models.files import Files
 
+from open_webui.retrieval.vector.main import GetResult
+
 from open_webui.env import (
     SRC_LOG_LEVELS,
     OFFLINE_MODE,
@@ -98,7 +100,7 @@ def get_doc(collection_name: str, user: UserModel = None):
 
 def query_doc_with_hybrid_search(
     collection_name: str,
-    collection_data,
+    collection_result: GetResult,
     query: str,
     embedding_function,
     k: int,
@@ -108,8 +110,8 @@ def query_doc_with_hybrid_search(
 ) -> dict:
     try:
         bm25_retriever = BM25Retriever.from_texts(
-            texts=collection_data.documents[0],
-            metadatas=collection_data.metadatas[0],
+            texts=collection_result.documents[0],
+            metadatas=collection_result.metadatas[0],
         )
         bm25_retriever.k = k
 
@@ -135,9 +137,9 @@ def query_doc_with_hybrid_search(
 
         result = compression_retriever.invoke(query)
 
-        distances = [d.metadata.get("score") for d in collection_data]
-        documents = [d.page_content for d in collection_data]
-        metadatas = [d.metadata for d in collection_data]
+        distances = [d.metadata.get("score") for d in result]
+        documents = [d.page_content for d in result]
+        metadatas = [d.metadata for d in result]
 
         # retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
         if k < k_reranker:
@@ -146,7 +148,8 @@ def query_doc_with_hybrid_search(
             )
             sorted_items = sorted_items[:k]
             distances, documents, metadatas = map(list, zip(*sorted_items))
-        collection_data = {
+
+        result = {
             "distances": [distances],
             "documents": [documents],
             "metadatas": [metadatas],
@@ -154,9 +157,9 @@ def query_doc_with_hybrid_search(
 
         log.info(
             "query_doc_with_hybrid_search:result "
-            + f'{collection_data["metadatas"]} {collection_data["distances"]}'
+            + f'{result["metadatas"]} {result["distances"]}'
         )
-        return collection_data
+        return result
     except Exception as e:
         raise e
 
@@ -279,20 +282,22 @@ def query_collection_with_hybrid_search(
     error = False
     # Fetch collection data once per collection sequentially
     # Avoid fetching the same data multiple times later
-    collection_data = {}
+    collection_results = {}
     for collection_name in collection_names:
         try:
-            collection_data[collection_name] = VECTOR_DB_CLIENT.get(collection_name=collection_name)
+            collection_results[collection_name] = VECTOR_DB_CLIENT.get(
+                collection_name=collection_name
+            )
         except Exception as e:
             log.exception(f"Failed to fetch collection {collection_name}: {e}")
-            collection_data[collection_name] = None
+            collection_results[collection_name] = None
 
     for collection_name in collection_names:
         try:
             for query in queries:
                 result = query_doc_with_hybrid_search(
                     collection_name=collection_name,
-                    collection_data=collection_data[collection_name],
+                    collection_result=collection_results[collection_name],
                     query=query,
                     embedding_function=embedding_function,
                     k=k,