Browse Source

refac/enh: forward user info header to reranker

Timothy Jaeryang Baek 2 months ago
parent
commit
0013f5c1fc

+ 12 - 4
backend/open_webui/main.py

@@ -89,6 +89,7 @@ from open_webui.routers import (
 
 from open_webui.routers.retrieval import (
     get_embedding_function,
+    get_reranking_function,
     get_ef,
     get_rf,
 )
@@ -878,6 +879,7 @@ app.state.config.FIRECRAWL_API_KEY = FIRECRAWL_API_KEY
 app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH
 
 app.state.EMBEDDING_FUNCTION = None
+app.state.RERANKING_FUNCTION = None
 app.state.ef = None
 app.state.rf = None
 
@@ -906,8 +908,8 @@ except Exception as e:
 app.state.EMBEDDING_FUNCTION = get_embedding_function(
     app.state.config.RAG_EMBEDDING_ENGINE,
     app.state.config.RAG_EMBEDDING_MODEL,
-    app.state.ef,
-    (
+    embedding_function=app.state.ef,
+    url=(
         app.state.config.RAG_OPENAI_API_BASE_URL
         if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
         else (
@@ -916,7 +918,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
             else app.state.config.RAG_AZURE_OPENAI_BASE_URL
         )
     ),
-    (
+    key=(
         app.state.config.RAG_OPENAI_API_KEY
         if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
         else (
@@ -925,7 +927,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
             else app.state.config.RAG_AZURE_OPENAI_API_KEY
         )
     ),
-    app.state.config.RAG_EMBEDDING_BATCH_SIZE,
+    embedding_batch_size=app.state.config.RAG_EMBEDDING_BATCH_SIZE,
     azure_api_version=(
         app.state.config.RAG_AZURE_OPENAI_API_VERSION
         if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
@@ -933,6 +935,12 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
     ),
 )
 
+app.state.RERANKING_FUNCTION = get_reranking_function(
+    app.state.config.RAG_RERANKING_ENGINE,
+    app.state.config.RAG_RERANKING_MODEL,
+    reranking_function=app.state.rf,
+)
+
 ########################################
 #
 # CODE EXECUTION

+ 16 - 2
backend/open_webui/retrieval/models/external.py

@@ -1,8 +1,10 @@
 import logging
 import requests
 from typing import Optional, List, Tuple
+from urllib.parse import quote
 
-from open_webui.env import SRC_LOG_LEVELS
+
+from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
 from open_webui.retrieval.models.base_reranker import BaseReranker
 
 
@@ -21,7 +23,9 @@ class ExternalReranker(BaseReranker):
         self.url = url
         self.model = model
 
-    def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
+    def predict(
+        self, sentences: List[Tuple[str, str]], user=None
+    ) -> Optional[List[float]]:
         query = sentences[0][0]
         docs = [i[1] for i in sentences]
 
@@ -41,6 +45,16 @@ class ExternalReranker(BaseReranker):
                 headers={
                     "Content-Type": "application/json",
                     "Authorization": f"Bearer {self.api_key}",
+                    **(
+                        {
+                            "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
+                            "X-OpenWebUI-User-Id": user.id,
+                            "X-OpenWebUI-User-Email": user.email,
+                            "X-OpenWebUI-User-Role": user.role,
+                        }
+                        if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                        else {}
+                    ),
                 },
                 json=payload,
             )

+ 10 - 1
backend/open_webui/retrieval/utils.py

@@ -445,6 +445,15 @@ def get_embedding_function(
         raise ValueError(f"Unknown embedding engine: {embedding_engine}")
 
 
+def get_reranking_function(reranking_engine, reranking_model, reranking_function):
+    if reranking_engine == "external":
+        return lambda sentences, user=None: reranking_function.predict(
+            sentences, user=user
+        )
+    else:
+        return lambda sentences, user=None: reranking_function.predict(sentences)
+
+
 def get_sources_from_items(
     request,
     items,
@@ -925,7 +934,7 @@ class RerankCompressor(BaseDocumentCompressor):
         reranking = self.reranking_function is not None
 
         if reranking:
-            scores = self.reranking_function.predict(
+            scores = self.reranking_function(
                 [(query, doc.page_content) for doc in documents]
             )
         else:

+ 13 - 2
backend/open_webui/routers/retrieval.py

@@ -70,6 +70,7 @@ from open_webui.retrieval.web.external import search_external
 
 from open_webui.retrieval.utils import (
     get_embedding_function,
+    get_reranking_function,
     get_model_path,
     query_collection,
     query_collection_with_hybrid_search,
@@ -824,6 +825,12 @@ async def update_rag_config(
                 request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
                 True,
             )
+
+            request.app.state.RERANKING_FUNCTION = get_reranking_function(
+                request.app.state.config.RAG_RERANKING_ENGINE,
+                request.app.state.config.RAG_RERANKING_MODEL,
+                request.app.state.rf,
+            )
         except Exception as e:
             log.error(f"Error loading reranking model: {e}")
             request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
@@ -2042,7 +2049,9 @@ def query_doc_handler(
                     query, prefix=prefix, user=user
                 ),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
-                reranking_function=request.app.state.rf,
+                reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION(
+                    sentences, user=user
+                ),
                 k_reranker=form_data.k_reranker
                 or request.app.state.config.TOP_K_RERANKER,
                 r=(
@@ -2099,7 +2108,9 @@ def query_collection_handler(
                     query, prefix=prefix, user=user
                 ),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
-                reranking_function=request.app.state.rf,
+                reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION(
+                    sentences, user=user
+                ),
                 k_reranker=form_data.k_reranker
                 or request.app.state.config.TOP_K_RERANKER,
                 r=(

+ 3 - 1
backend/open_webui/utils/middleware.py

@@ -652,7 +652,9 @@ async def chat_completion_files_handler(
                             query, prefix=prefix, user=user
                         ),
                         k=request.app.state.config.TOP_K,
-                        reranking_function=request.app.state.rf,
+                        reranking_function=lambda sentences: request.app.state.RERANKING_FUNCTION(
+                            sentences, user=user
+                        ),
                         k_reranker=request.app.state.config.TOP_K_RERANKER,
                         r=request.app.state.config.RELEVANCE_THRESHOLD,
                         hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,