Sfoglia il codice sorgente

Added new k_reranker parameter

Marko Henning 2 mesi fa
parent
commit
41a4cf7106

+ 5 - 0
backend/open_webui/config.py

@@ -1646,6 +1646,11 @@ BYPASS_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
 RAG_TOP_K = PersistentConfig(
     "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3"))
 )
+RAG_TOP_K_RERANKER = PersistentConfig(
+    "RAG_TOP_K_RERANKER",
+    "rag.top_k_reranker",
+    int(os.environ.get("RAG_TOP_K_RERANKER", "3"))
+)
 RAG_RELEVANCE_THRESHOLD = PersistentConfig(
     "RAG_RELEVANCE_THRESHOLD",
     "rag.relevance_threshold",

+ 2 - 0
backend/open_webui/main.py

@@ -189,6 +189,7 @@ from open_webui.config import (
     DOCUMENT_INTELLIGENCE_ENDPOINT,
     DOCUMENT_INTELLIGENCE_KEY,
     RAG_TOP_K,
+    RAG_TOP_K_RERANKER,
     RAG_TEXT_SPLITTER,
     TIKTOKEN_ENCODING_NAME,
     PDF_EXTRACT_IMAGES,
@@ -535,6 +536,7 @@ app.state.FUNCTIONS = {}
 
 
 app.state.config.TOP_K = RAG_TOP_K
+app.state.config.TOP_K_RERANKER = RAG_TOP_K_RERANKER
 app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
 app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
 app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT

+ 6 - 1
backend/open_webui/retrieval/utils.py

@@ -106,6 +106,7 @@ def query_doc_with_hybrid_search(
     embedding_function,
     k: int,
     reranking_function,
+    k_reranker: int,
     r: float,
 ) -> dict:
     try:
@@ -128,7 +129,7 @@ def query_doc_with_hybrid_search(
         )
         compressor = RerankCompressor(
             embedding_function=embedding_function,
-            top_n=k,
+            top_n=k_reranker,
             reranking_function=reranking_function,
             r_score=r,
         )
@@ -267,6 +268,7 @@ def query_collection_with_hybrid_search(
     embedding_function,
     k: int,
     reranking_function,
+    k_reranker: int,
     r: float,
 ) -> dict:
     results = []
@@ -280,6 +282,7 @@ def query_collection_with_hybrid_search(
                     embedding_function=embedding_function,
                     k=k,
                     reranking_function=reranking_function,
+                    k_reranker=k_reranker,
                     r=r,
                 )
                 results.append(result)
@@ -345,6 +348,7 @@ def get_sources_from_files(
     embedding_function,
     k,
     reranking_function,
+    k_reranker,
     r,
     hybrid_search,
     full_context=False,
@@ -461,6 +465,7 @@ def get_sources_from_files(
                                     embedding_function=embedding_function,
                                     k=k,
                                     reranking_function=reranking_function,
+                                    k_reranker=k_reranker,
                                     r=r,
                                 )
                             except Exception as e:

+ 8 - 0
backend/open_webui/routers/retrieval.py

@@ -713,6 +713,7 @@ async def get_query_settings(request: Request, user=Depends(get_admin_user)):
         "status": True,
         "template": request.app.state.config.RAG_TEMPLATE,
         "k": request.app.state.config.TOP_K,
+        "k_reranker": request.app.state.config.TOP_K_RERANKER,
         "r": request.app.state.config.RELEVANCE_THRESHOLD,
         "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
     }
@@ -720,6 +721,7 @@ async def get_query_settings(request: Request, user=Depends(get_admin_user)):
 
 class QuerySettingsForm(BaseModel):
     k: Optional[int] = None
+    k_reranker: Optional[int] = None
     r: Optional[float] = None
     template: Optional[str] = None
     hybrid: Optional[bool] = None
@@ -731,6 +733,7 @@ async def update_query_settings(
 ):
     request.app.state.config.RAG_TEMPLATE = form_data.template
     request.app.state.config.TOP_K = form_data.k if form_data.k else 4
+    request.app.state.config.TOP_K_RERANKER = form_data.k_reranker or 4
     request.app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
 
     request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
@@ -741,6 +744,7 @@ async def update_query_settings(
         "status": True,
         "template": request.app.state.config.RAG_TEMPLATE,
         "k": request.app.state.config.TOP_K,
+        "k_reranker": request.app.state.config.TOP_K_RERANKER,
         "r": request.app.state.config.RELEVANCE_THRESHOLD,
         "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
     }
@@ -1488,6 +1492,7 @@ class QueryDocForm(BaseModel):
     collection_name: str
     query: str
     k: Optional[int] = None
+    k_reranker: Optional[int] = None
     r: Optional[float] = None
     hybrid: Optional[bool] = None
 
@@ -1508,6 +1513,7 @@ def query_doc_handler(
                 ),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=request.app.state.rf,
+                k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER,
                 r=(
                     form_data.r
                     if form_data.r
@@ -1536,6 +1542,7 @@ class QueryCollectionsForm(BaseModel):
     collection_names: list[str]
     query: str
     k: Optional[int] = None
+    k_reranker: Optional[int] = None
     r: Optional[float] = None
     hybrid: Optional[bool] = None
 
@@ -1556,6 +1563,7 @@ def query_collection_handler(
                 ),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=request.app.state.rf,
+                k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER,
                 r=(
                     form_data.r
                     if form_data.r

+ 1 - 0
backend/open_webui/utils/middleware.py

@@ -567,6 +567,7 @@ async def chat_completion_files_handler(
                         ),
                         k=request.app.state.config.TOP_K,
                         reranking_function=request.app.state.rf,
+                        k_reranker=request.app.state.config.TOP_K_RERANKER,
                         r=request.app.state.config.RELEVANCE_THRESHOLD,
                         hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                         full_context=request.app.state.config.RAG_FULL_CONTEXT,

+ 18 - 0
src/lib/components/admin/Settings/Documents.svelte

@@ -74,6 +74,7 @@
 		template: '',
 		r: 0.0,
 		k: 4,
+		k_reranker: 4,
 		hybrid: false
 	};
 
@@ -738,6 +739,23 @@
 						</div>
 					</div>
 
+					{#if querySettings.hybrid === true}
+						<div class="mb-2.5 flex w-full justify-between">
+							<div class="self-center text-xs font-medium">{$i18n.t('Top K Reranker')}</div>
+							<div class="flex items-center relative">
+								<input
+									class="flex-1 w-full rounded-lg text-sm bg-transparent outline-hidden"
+									type="number"
+									placeholder={$i18n.t('Enter Top K Reranker')}
+									bind:value={querySettings.k_reranker}
+									autocomplete="off"
+									min="0"
+								/>
+							</div>
+						</div>
+					{/if}
+
+
 					{#if querySettings.hybrid === true}
 						<div class="  mb-2.5 flex flex-col w-full justify-between">
 							<div class=" flex w-full justify-between">