ソースを参照

make weight for bm25 retriever in hybrid search ui-configurable

Jan Kessler 4 ヶ月 前
コミット
b5ddaf6417

+ 5 - 0
backend/open_webui/config.py

@@ -1928,6 +1928,11 @@ RAG_RELEVANCE_THRESHOLD = PersistentConfig(
     "rag.relevance_threshold",
     float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")),
 )
+RAG_BM25_WEIGHT = PersistentConfig(
+    "RAG_BM25_WEIGHT",
+    "rag.bm25_weight",
+    float(os.environ.get("RAG_BM25_WEIGHT", "0.5")),
+)
 
 ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
     "ENABLE_RAG_HYBRID_SEARCH",

+ 4 - 2
backend/open_webui/main.py

@@ -196,7 +196,10 @@ from open_webui.config import (
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_BATCH_SIZE,
+    RAG_TOP_K,
+    RAG_TOP_K_RERANKER,
     RAG_RELEVANCE_THRESHOLD,
+    RAG_BM25_WEIGHT,
     RAG_ALLOWED_FILE_EXTENSIONS,
     RAG_FILE_MAX_COUNT,
     RAG_FILE_MAX_SIZE,
@@ -217,8 +220,6 @@ from open_webui.config import (
     DOCUMENT_INTELLIGENCE_ENDPOINT,
     DOCUMENT_INTELLIGENCE_KEY,
     MISTRAL_OCR_API_KEY,
-    RAG_TOP_K,
-    RAG_TOP_K_RERANKER,
     RAG_TEXT_SPLITTER,
     TIKTOKEN_ENCODING_NAME,
     PDF_EXTRACT_IMAGES,
@@ -646,6 +647,7 @@ app.state.FUNCTIONS = {}
 app.state.config.TOP_K = RAG_TOP_K
 app.state.config.TOP_K_RERANKER = RAG_TOP_K_RERANKER
 app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
+app.state.config.BM25_WEIGHT = RAG_BM25_WEIGHT
 app.state.config.ALLOWED_FILE_EXTENSIONS = RAG_ALLOWED_FILE_EXTENSIONS
 app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
 app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT

+ 15 - 3
backend/open_webui/retrieval/utils.py

@@ -29,6 +29,7 @@ from open_webui.config import (
     RAG_EMBEDDING_QUERY_PREFIX,
     RAG_EMBEDDING_CONTENT_PREFIX,
     RAG_EMBEDDING_PREFIX_FIELD_NAME,
+    RAG_BM25_WEIGHT,
 )
 
 log = logging.getLogger(__name__)
@@ -131,9 +132,20 @@ def query_doc_with_hybrid_search(
             top_k=k,
         )
 
-        ensemble_retriever = EnsembleRetriever(
-            retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
-        )
+        if RAG_BM25_WEIGHT <= 0:
+            ensemble_retriever = EnsembleRetriever(
+                retrievers=[vector_search_retriever], weights=[1.]
+            )
+        elif RAG_BM25_WEIGHT >= 1:
+            ensemble_retriever = EnsembleRetriever(
+                retrievers=[bm25_retriever], weights=[1.]
+            )
+        else:
+            ensemble_retriever = EnsembleRetriever(
+                retrievers=[bm25_retriever, vector_search_retriever],
+                weights=[RAG_BM25_WEIGHT, 1. - RAG_BM25_WEIGHT]
+            )
+
         compressor = RerankCompressor(
             embedding_function=embedding_function,
             top_n=k_reranker,

+ 8 - 0
backend/open_webui/routers/retrieval.py

@@ -349,6 +349,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
         "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
         "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER,
         "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD,
+        "BM25_WEIGHT": request.app.state.config.BM25_WEIGHT,
         # Content extraction settings
         "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
         "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
@@ -492,6 +493,7 @@ class ConfigForm(BaseModel):
     ENABLE_RAG_HYBRID_SEARCH: Optional[bool] = None
     TOP_K_RERANKER: Optional[int] = None
     RELEVANCE_THRESHOLD: Optional[float] = None
+    BM25_WEIGHT: Optional[float] = None
 
     # Content extraction settings
     CONTENT_EXTRACTION_ENGINE: Optional[str] = None
@@ -578,6 +580,11 @@ async def update_rag_config(
         if form_data.RELEVANCE_THRESHOLD is not None
         else request.app.state.config.RELEVANCE_THRESHOLD
     )
+    request.app.state.config.BM25_WEIGHT = (
+        form_data.BM25_WEIGHT
+        if form_data.BM25_WEIGHT is not None
+        else request.app.state.config.BM25_WEIGHT
+    )
 
     # Content extraction settings
     request.app.state.config.CONTENT_EXTRACTION_ENGINE = (
@@ -837,6 +844,7 @@ async def update_rag_config(
         "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
         "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER,
         "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD,
+        "BM25_WEIGHT": request.app.state.config.BM25_WEIGHT,
         # Content extraction settings
         "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
         "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,

+ 18 - 0
src/lib/components/admin/Settings/Documents.svelte

@@ -770,6 +770,24 @@
 									</div>
 								</div>
 							{/if}
+
+							{#if RAGConfig.ENABLE_RAG_HYBRID_SEARCH === true}
+								<div class="mb-2.5 flex w-full justify-between">
+									<div class="self-center text-xs font-medium">{$i18n.t('BM25 Weight')}</div>
+									<div class="flex items-center relative">
+										<input
+											class="flex-1 w-full text-sm bg-transparent outline-hidden"
+											type="number"
+											step="0.01"
+											placeholder={$i18n.t('Enter BM25 Weight')}
+											bind:value={RAGConfig.BM25_WEIGHT}
+											autocomplete="off"
+											min="0.0"
+											max="1.0"
+										/>
+									</div>
+								</div>
+							{/if}
 						{/if}
 
 						<div class="  mb-2.5 flex flex-col w-full justify-between">

+ 1 - 0
src/lib/i18n/locales/en-US/translation.json

@@ -425,6 +425,7 @@
 	"Enter Application DN Password": "",
 	"Enter Bing Search V7 Endpoint": "",
 	"Enter Bing Search V7 Subscription Key": "",
+	"Enter BM25 Weight": "",
 	"Enter Bocha Search API Key": "",
 	"Enter Brave Search API Key": "",
 	"Enter certificate path": "",