Browse Source

refac: rerank

Timothy Jaeryang Baek 3 months ago
parent
commit
bc739de024
1 changed files with 7 additions and 7 deletions
  1. 7 7
      backend/open_webui/retrieval/utils.py

+ 7 - 7
backend/open_webui/retrieval/utils.py

@@ -168,7 +168,7 @@ def query_doc_with_hybrid_search(
         ):
         ):
             log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}")
             log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}")
             return {"documents": [], "metadatas": [], "distances": []}
             return {"documents": [], "metadatas": [], "distances": []}
-        
+
         # Now safely check the documents content after confirming attributes exist
         # Now safely check the documents content after confirming attributes exist
         if (
         if (
             not collection_result.documents
             not collection_result.documents
@@ -516,11 +516,13 @@ def get_reranking_function(reranking_engine, reranking_model, reranking_function
     if reranking_function is None:
     if reranking_function is None:
         return None
         return None
     if reranking_engine == "external":
     if reranking_engine == "external":
-        return lambda sentences, user=None: reranking_function.predict(
-            sentences, user=user
+        return lambda query, documents, user=None: reranking_function.predict(
+            [(query, doc.page_content) for doc in documents], user=user
         )
         )
     else:
     else:
-        return lambda sentences, user=None: reranking_function.predict(sentences)
+        return lambda query, documents, user=None: reranking_function.predict(
+            [(query, doc.page_content) for doc in documents]
+        )
 
 
 
 
 def get_sources_from_items(
 def get_sources_from_items(
@@ -1064,9 +1066,7 @@ class RerankCompressor(BaseDocumentCompressor):
 
 
         scores = None
         scores = None
         if reranking:
         if reranking:
-            scores = self.reranking_function(
-                [(query, doc.page_content) for doc in documents]
-            )
+            scores = self.reranking_function(query, documents)
         else:
         else:
             from sentence_transformers import util
             from sentence_transformers import util