Browse Source

feat: external reranker

Co-Authored-By: Brendan Campbell <20541191+bcambs09@users.noreply.github.com>
Timothy Jaeryang Baek 1 month ago
parent
commit
d5fd3b3600

+ 19 - 0
backend/open_webui/config.py

@@ -1965,6 +1965,12 @@ RAG_EMBEDDING_PREFIX_FIELD_NAME = os.environ.get(
     "RAG_EMBEDDING_PREFIX_FIELD_NAME", None
 )
 
+RAG_RERANKING_ENGINE = PersistentConfig(
+    "RAG_RERANKING_ENGINE",
+    "rag.reranking_engine",
+    os.environ.get("RAG_RERANKING_ENGINE", ""),
+)
+
 RAG_RERANKING_MODEL = PersistentConfig(
     "RAG_RERANKING_MODEL",
     "rag.reranking_model",
@@ -1973,6 +1979,7 @@ RAG_RERANKING_MODEL = PersistentConfig(
 if RAG_RERANKING_MODEL.value != "":
     log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
 
+
 RAG_RERANKING_MODEL_AUTO_UPDATE = (
     not OFFLINE_MODE
     and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
@@ -1982,6 +1989,18 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
 )
 
+RAG_EXTERNAL_RERANKER_URL = PersistentConfig(
+    "RAG_EXTERNAL_RERANKER_URL",
+    "rag.external_reranker_url",
+    os.environ.get("RAG_EXTERNAL_RERANKER_URL", ""),
+)
+
+RAG_EXTERNAL_RERANKER_API_KEY = PersistentConfig(
+    "RAG_EXTERNAL_RERANKER_API_KEY",
+    "rag.external_reranker_api_key",
+    os.environ.get("RAG_EXTERNAL_RERANKER_API_KEY", ""),
+)
+
 
 RAG_TEXT_SPLITTER = PersistentConfig(
     "RAG_TEXT_SPLITTER",

+ 11 - 0
backend/open_webui/main.py

@@ -188,7 +188,10 @@ from open_webui.config import (
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
+    RAG_RERANKING_ENGINE,
     RAG_RERANKING_MODEL,
+    RAG_EXTERNAL_RERANKER_URL,
+    RAG_EXTERNAL_RERANKER_API_KEY,
     RAG_RERANKING_MODEL_AUTO_UPDATE,
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     RAG_EMBEDDING_ENGINE,
@@ -655,7 +658,12 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
+
+app.state.config.RAG_RERANKING_ENGINE = RAG_RERANKING_ENGINE
 app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
+app.state.config.RAG_EXTERNAL_RERANKER_URL = RAG_EXTERNAL_RERANKER_URL
+app.state.config.RAG_EXTERNAL_RERANKER_API_KEY = RAG_EXTERNAL_RERANKER_API_KEY
+
 app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
 
 app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
@@ -736,7 +744,10 @@ try:
     )
 
     app.state.rf = get_rf(
+        app.state.config.RAG_RERANKING_ENGINE,
         app.state.config.RAG_RERANKING_MODEL,
+        app.state.config.RAG_EXTERNAL_RERANKER_URL,
+        app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
         RAG_RERANKING_MODEL_AUTO_UPDATE,
     )
 except Exception as e:

+ 58 - 0
backend/open_webui/retrieval/models/external.py

@@ -0,0 +1,58 @@
+import logging
+import requests
+from typing import Optional, List, Tuple
+
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+class ExternalReranker:
+    """Reranker that scores documents through an external HTTP endpoint.
+
+    Instances expose ``predict`` over ``(query, document)`` pairs; ``get_rf``
+    returns one of these instead of a local CrossEncoder when the reranking
+    engine is ``"external"``.
+    """
+
+    def __init__(
+        self,
+        api_key: str,
+        url: str = "http://localhost:8080/v1/rerank",
+        model: str = "reranker",
+    ):
+        """Store the connection settings; no request is made here.
+
+        Args:
+            api_key: Sent as a ``Bearer`` token in the ``Authorization`` header.
+            url: Full URL of the rerank endpoint that receives the POST.
+            model: Model name placed in the request payload.
+        """
+        self.api_key = api_key
+        self.url = url
+        self.model = model
+
+    def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
+        """Return one relevance score per ``(query, document)`` pair.
+
+        The query is taken from the first pair only, so all pairs are assumed
+        to share the same query (TODO confirm against callers).
+
+        Args:
+            sentences: ``(query, document)`` pairs to score.
+
+        Returns:
+            Scores ordered by the ``index`` field of each response result, an
+            empty list when ``sentences`` is empty, or ``None`` when the
+            request fails or the response has no ``"results"`` key.
+        """
+        # Nothing to score: avoid indexing into an empty list and skip the
+        # network round trip entirely.
+        if not sentences:
+            return []
+
+        query = sentences[0][0]
+        docs = [pair[1] for pair in sentences]
+
+        payload = {
+            "model": self.model,
+            "query": query,
+            "documents": docs,
+            # Ask for every document back so each input gets a score.
+            "top_n": len(docs),
+        }
+
+        try:
+            log.info(f"ExternalReranker:predict:model {self.model}")
+            log.info(f"ExternalReranker:predict:query {query}")
+
+            # NOTE(review): no timeout is passed, so this call can block for
+            # as long as the server keeps the connection open.
+            r = requests.post(
+                f"{self.url}",
+                headers={
+                    "Content-Type": "application/json",
+                    "Authorization": f"Bearer {self.api_key}",
+                },
+                json=payload,
+            )
+
+            r.raise_for_status()
+            data = r.json()
+
+            if "results" not in data:
+                log.error("No results found in external reranking response")
+                return None
+
+            # Results may come back in any order; restore input order via
+            # the "index" field before extracting the scores.
+            sorted_results = sorted(data["results"], key=lambda x: x["index"])
+            return [result["relevance_score"] for result in sorted_results]
+
+        except Exception as e:
+            # Best effort: any transport, HTTP or parsing failure is logged
+            # and reported to the caller as None.
+            log.exception(f"Error in external reranking: {e}")
+            return None

+ 88 - 55
backend/open_webui/routers/retrieval.py

@@ -137,7 +137,10 @@ def get_ef(
 
 
 def get_rf(
+    engine: str = "",
     reranking_model: Optional[str] = None,
+    external_reranker_url: str = "",
+    external_reranker_api_key: str = "",
     auto_update: bool = False,
 ):
     rf = None
@@ -155,19 +158,33 @@ def get_rf(
                 log.error(f"ColBERT: {e}")
                 raise Exception(ERROR_MESSAGES.DEFAULT(e))
         else:
-            import sentence_transformers
+            if engine == "external":
+                try:
+                    from open_webui.retrieval.models.external import ExternalReranker
+
+                    rf = ExternalReranker(
+                        url=external_reranker_url,
+                        api_key=external_reranker_api_key,
+                        model=reranking_model,
+                    )
+                except Exception as e:
+                    log.error(f"ExternalReranking: {e}")
+                    raise Exception(ERROR_MESSAGES.DEFAULT(e))
+            else:
+                import sentence_transformers
+
+                try:
+                    rf = sentence_transformers.CrossEncoder(
+                        get_model_path(reranking_model, auto_update),
+                        device=DEVICE_TYPE,
+                        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+                        backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
+                        model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
+                    )
+                except Exception as e:
+                    log.error(f"CrossEncoder: {e}")
+                    raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
 
-            try:
-                rf = sentence_transformers.CrossEncoder(
-                    get_model_path(reranking_model, auto_update),
-                    device=DEVICE_TYPE,
-                    trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-                    backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
-                    model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
-                )
-            except Exception as e:
-                log.error(f"CrossEncoder: {e}")
-                raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
     return rf
 
 
@@ -225,14 +242,6 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
     }
 
 
-@router.get("/reranking")
-async def get_reraanking_config(request: Request, user=Depends(get_admin_user)):
-    return {
-        "status": True,
-        "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
-    }
-
-
 class OpenAIConfigForm(BaseModel):
     url: str
     key: str
@@ -327,41 +336,6 @@ async def update_embedding_config(
         )
 
 
-class RerankingModelUpdateForm(BaseModel):
-    reranking_model: str
-
-
-@router.post("/reranking/update")
-async def update_reranking_config(
-    request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
-):
-    log.info(
-        f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
-    )
-    try:
-        request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
-
-        try:
-            request.app.state.rf = get_rf(
-                request.app.state.config.RAG_RERANKING_MODEL,
-                True,
-            )
-        except Exception as e:
-            log.error(f"Error loading reranking model: {e}")
-            request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
-
-        return {
-            "status": True,
-            "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
-        }
-    except Exception as e:
-        log.exception(f"Problem updating reranking model: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
 @router.get("/config")
 async def get_rag_config(request: Request, user=Depends(get_admin_user)):
     return {
@@ -385,6 +359,11 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
         "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
         "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
         "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
+        # Reranking settings
+        # The state attributes are registered in main.py under the "RERANKER"
+        # spelling; the response keys keep their "RERANKING" spelling.
+        "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
+        "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
+        "RAG_EXTERNAL_RERANKING_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
+        "RAG_EXTERNAL_RERANKING_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
         # Chunking settings
         "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
         "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
@@ -521,6 +500,12 @@ class ConfigForm(BaseModel):
     DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
     MISTRAL_OCR_API_KEY: Optional[str] = None
 
+    # Reranking settings
+    RAG_RERANKING_MODEL: Optional[str] = None
+    RAG_RERANKING_ENGINE: Optional[str] = None
+    RAG_EXTERNAL_RERANKING_URL: Optional[str] = None
+    RAG_EXTERNAL_RERANKING_API_KEY: Optional[str] = None
+
     # Chunking settings
     TEXT_SPLITTER: Optional[str] = None
     CHUNK_SIZE: Optional[int] = None
@@ -632,6 +617,49 @@ async def update_rag_config(
         else request.app.state.config.MISTRAL_OCR_API_KEY
     )
 
+    # Reranking settings
+    # The state attributes are registered in main.py under the "RERANKER"
+    # spelling (RAG_EXTERNAL_RERANKER_URL / RAG_EXTERNAL_RERANKER_API_KEY); the
+    # form fields keep their "RERANKING" spelling so the request schema is
+    # unchanged.
+    request.app.state.config.RAG_RERANKING_ENGINE = (
+        form_data.RAG_RERANKING_ENGINE
+        if form_data.RAG_RERANKING_ENGINE is not None
+        else request.app.state.config.RAG_RERANKING_ENGINE
+    )
+
+    request.app.state.config.RAG_EXTERNAL_RERANKER_URL = (
+        form_data.RAG_EXTERNAL_RERANKING_URL
+        if form_data.RAG_EXTERNAL_RERANKING_URL is not None
+        else request.app.state.config.RAG_EXTERNAL_RERANKER_URL
+    )
+
+    request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY = (
+        form_data.RAG_EXTERNAL_RERANKING_API_KEY
+        if form_data.RAG_EXTERNAL_RERANKING_API_KEY is not None
+        else request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY
+    )
+
+    log.info(
+        f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}"
+    )
+    try:
+        # Keep the current model when the form omits it, matching how the
+        # other optional fields are handled; otherwise a partial update
+        # would overwrite the configured model with None.
+        request.app.state.config.RAG_RERANKING_MODEL = (
+            form_data.RAG_RERANKING_MODEL
+            if form_data.RAG_RERANKING_MODEL is not None
+            else request.app.state.config.RAG_RERANKING_MODEL
+        )
+
+        try:
+            # Rebuild the reranking function from the settings just stored.
+            request.app.state.rf = get_rf(
+                request.app.state.config.RAG_RERANKING_ENGINE,
+                request.app.state.config.RAG_RERANKING_MODEL,
+                request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
+                request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
+                True,
+            )
+        except Exception as e:
+            # A reranker that fails to load is not fatal for the config
+            # update; hybrid search is switched off instead.
+            log.error(f"Error loading reranking model: {e}")
+            request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
+    except Exception as e:
+        log.exception(f"Problem updating reranking model: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
     # Chunking settings
     request.app.state.config.TEXT_SPLITTER = (
         form_data.TEXT_SPLITTER
@@ -788,6 +816,11 @@ async def update_rag_config(
         "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
         "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
         "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
+        # Reranking settings
+        # The state attributes are registered in main.py under the "RERANKER"
+        # spelling; the response keys keep their "RERANKING" spelling.
+        "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
+        "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
+        "RAG_EXTERNAL_RERANKING_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
+        "RAG_EXTERNAL_RERANKING_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
         # Chunking settings
         "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
         "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,