Browse Source

make bm25_weight a regular parameter of query_doc.. / get_sources_from_files functions

Jan Kessler 4 tháng trước cách đây
mục cha
commit
308d8ac04a

+ 8 - 4
backend/open_webui/retrieval/utils.py

@@ -29,7 +29,6 @@ from open_webui.config import (
     RAG_EMBEDDING_QUERY_PREFIX,
     RAG_EMBEDDING_CONTENT_PREFIX,
     RAG_EMBEDDING_PREFIX_FIELD_NAME,
-    RAG_BM25_WEIGHT,
 )
 
 log = logging.getLogger(__name__)
@@ -117,6 +116,7 @@ def query_doc_with_hybrid_search(
     reranking_function,
     k_reranker: int,
     r: float,
+    bm25_weight: float,
 ) -> dict:
     try:
         log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
@@ -132,18 +132,18 @@ def query_doc_with_hybrid_search(
             top_k=k,
         )
 
-        if RAG_BM25_WEIGHT <= 0:
+        if bm25_weight <= 0:
             ensemble_retriever = EnsembleRetriever(
                 retrievers=[vector_search_retriever], weights=[1.]
             )
-        elif RAG_BM25_WEIGHT >= 1:
+        elif bm25_weight >= 1:
             ensemble_retriever = EnsembleRetriever(
                 retrievers=[bm25_retriever], weights=[1.]
             )
         else:
             ensemble_retriever = EnsembleRetriever(
                 retrievers=[bm25_retriever, vector_search_retriever],
-                weights=[RAG_BM25_WEIGHT, 1. - RAG_BM25_WEIGHT]
+                weights=[bm25_weight, 1. - bm25_weight]
             )
 
         compressor = RerankCompressor(
@@ -325,6 +325,7 @@ def query_collection_with_hybrid_search(
     reranking_function,
     k_reranker: int,
     r: float,
+    bm25_weight: float,
 ) -> dict:
     results = []
     error = False
@@ -358,6 +359,7 @@ def query_collection_with_hybrid_search(
                 reranking_function=reranking_function,
                 k_reranker=k_reranker,
                 r=r,
+                bm25_weight=bm25_weight,
             )
             return result, None
         except Exception as e:
@@ -445,6 +447,7 @@ def get_sources_from_files(
     reranking_function,
     k_reranker,
     r,
+    bm25_weight,
     hybrid_search,
     full_context=False,
 ):
@@ -562,6 +565,7 @@ def get_sources_from_files(
                                     reranking_function=reranking_function,
                                     k_reranker=k_reranker,
                                     r=r,
+                                    bm25_weight=bm25_weight,
                                 )
                             except Exception as e:
                                 log.debug(

+ 10 - 0
backend/open_webui/routers/retrieval.py

@@ -1782,6 +1782,11 @@ def query_doc_handler(
                     if form_data.r
                     else request.app.state.config.RELEVANCE_THRESHOLD
                 ),
+                bm25_weight=(
+                    form_data.bm25_weight
+                    if form_data.bm25_weight
+                    else request.app.state.config.BM25_WEIGHT
+                ),
                 user=user,
             )
         else:
@@ -1833,6 +1838,11 @@ def query_collection_handler(
                     if form_data.r
                     else request.app.state.config.RELEVANCE_THRESHOLD
                 ),
+                bm25_weight=(
+                    form_data.bm25_weight
+                    if form_data.bm25_weight
+                    else request.app.state.config.BM25_WEIGHT
+                ),
             )
         else:
             return query_collection(

+ 1 - 0
backend/open_webui/utils/middleware.py

@@ -603,6 +603,7 @@ async def chat_completion_files_handler(
                         reranking_function=request.app.state.rf,
                         k_reranker=request.app.state.config.TOP_K_RERANKER,
                         r=request.app.state.config.RELEVANCE_THRESHOLD,
+                        bm25_weight=request.app.state.config.BM25_WEIGHT,
                         hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                         full_context=request.app.state.config.RAG_FULL_CONTEXT,
                     ),