|
@@ -137,7 +137,10 @@ def get_ef(
|
|
|
|
|
|
|
|
|
def get_rf(
|
|
|
+ engine: str = "",
|
|
|
reranking_model: Optional[str] = None,
|
|
|
+ external_reranker_url: str = "",
|
|
|
+ external_reranker_api_key: str = "",
|
|
|
auto_update: bool = False,
|
|
|
):
|
|
|
rf = None
|
|
@@ -155,19 +158,33 @@ def get_rf(
|
|
|
log.error(f"ColBERT: {e}")
|
|
|
raise Exception(ERROR_MESSAGES.DEFAULT(e))
|
|
|
else:
|
|
|
- import sentence_transformers
|
|
|
+ if engine == "external":
|
|
|
+ try:
|
|
|
+ from open_webui.retrieval.models.external import ExternalReranker
|
|
|
+
|
|
|
+ rf = ExternalReranker(
|
|
|
+ url=external_reranker_url,
|
|
|
+ api_key=external_reranker_api_key,
|
|
|
+ model=reranking_model,
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"ExternalReranking: {e}")
|
|
|
+ raise Exception(ERROR_MESSAGES.DEFAULT(e))
|
|
|
+ else:
|
|
|
+ import sentence_transformers
|
|
|
+
|
|
|
+ try:
|
|
|
+ rf = sentence_transformers.CrossEncoder(
|
|
|
+ get_model_path(reranking_model, auto_update),
|
|
|
+ device=DEVICE_TYPE,
|
|
|
+ trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
|
|
+ backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
|
|
|
+ model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"CrossEncoder: {e}")
|
|
|
+ raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
|
|
|
|
|
|
- try:
|
|
|
- rf = sentence_transformers.CrossEncoder(
|
|
|
- get_model_path(reranking_model, auto_update),
|
|
|
- device=DEVICE_TYPE,
|
|
|
- trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
|
|
- backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
|
|
|
- model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
|
|
|
- )
|
|
|
- except Exception as e:
|
|
|
- log.error(f"CrossEncoder: {e}")
|
|
|
- raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
|
|
|
return rf
|
|
|
|
|
|
|
|
@@ -225,14 +242,6 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
|
|
|
}
|
|
|
|
|
|
|
|
|
-@router.get("/reranking")
|
|
|
-async def get_reraanking_config(request: Request, user=Depends(get_admin_user)):
|
|
|
- return {
|
|
|
- "status": True,
|
|
|
- "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
|
|
|
- }
|
|
|
-
|
|
|
-
|
|
|
class OpenAIConfigForm(BaseModel):
|
|
|
url: str
|
|
|
key: str
|
|
@@ -327,41 +336,6 @@ async def update_embedding_config(
|
|
|
)
|
|
|
|
|
|
|
|
|
-class RerankingModelUpdateForm(BaseModel):
|
|
|
- reranking_model: str
|
|
|
-
|
|
|
-
|
|
|
-@router.post("/reranking/update")
|
|
|
-async def update_reranking_config(
|
|
|
- request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
|
|
|
-):
|
|
|
- log.info(
|
|
|
- f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
|
|
|
- )
|
|
|
- try:
|
|
|
- request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
|
|
|
-
|
|
|
- try:
|
|
|
- request.app.state.rf = get_rf(
|
|
|
- request.app.state.config.RAG_RERANKING_MODEL,
|
|
|
- True,
|
|
|
- )
|
|
|
- except Exception as e:
|
|
|
- log.error(f"Error loading reranking model: {e}")
|
|
|
- request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
|
|
-
|
|
|
- return {
|
|
|
- "status": True,
|
|
|
- "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
|
|
|
- }
|
|
|
- except Exception as e:
|
|
|
- log.exception(f"Problem updating reranking model: {e}")
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
- detail=ERROR_MESSAGES.DEFAULT(e),
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
@router.get("/config")
|
|
|
async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
|
|
return {
|
|
@@ -385,6 +359,11 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
|
|
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
|
|
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
|
|
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
|
|
|
+ # Reranking settings
|
|
|
+ "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
|
|
|
+ "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
|
|
|
+ "RAG_EXTERNAL_RERANKING_URL": request.app.state.config.RAG_EXTERNAL_RERANKING_URL,
|
|
|
+ "RAG_EXTERNAL_RERANKING_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY,
|
|
|
# Chunking settings
|
|
|
"TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
|
|
|
"CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
|
|
@@ -521,6 +500,12 @@ class ConfigForm(BaseModel):
|
|
|
DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
|
|
|
MISTRAL_OCR_API_KEY: Optional[str] = None
|
|
|
|
|
|
+ # Reranking settings
|
|
|
+ RAG_RERANKING_MODEL: Optional[str] = None
|
|
|
+ RAG_RERANKING_ENGINE: Optional[str] = None
|
|
|
+ RAG_EXTERNAL_RERANKING_URL: Optional[str] = None
|
|
|
+ RAG_EXTERNAL_RERANKING_API_KEY: Optional[str] = None
|
|
|
+
|
|
|
# Chunking settings
|
|
|
TEXT_SPLITTER: Optional[str] = None
|
|
|
CHUNK_SIZE: Optional[int] = None
|
|
@@ -632,6 +617,49 @@ async def update_rag_config(
|
|
|
else request.app.state.config.MISTRAL_OCR_API_KEY
|
|
|
)
|
|
|
|
|
|
+ # Reranking settings
|
|
|
+ request.app.state.config.RAG_RERANKING_ENGINE = (
|
|
|
+ form_data.RAG_RERANKING_ENGINE
|
|
|
+ if form_data.RAG_RERANKING_ENGINE is not None
|
|
|
+ else request.app.state.config.RAG_RERANKING_ENGINE
|
|
|
+ )
|
|
|
+
|
|
|
+ request.app.state.config.RAG_EXTERNAL_RERANKING_URL = (
|
|
|
+ form_data.RAG_EXTERNAL_RERANKING_URL
|
|
|
+ if form_data.RAG_EXTERNAL_RERANKING_URL is not None
|
|
|
+ else request.app.state.config.RAG_EXTERNAL_RERANKING_URL
|
|
|
+ )
|
|
|
+
|
|
|
+ request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY = (
|
|
|
+ form_data.RAG_EXTERNAL_RERANKING_API_KEY
|
|
|
+ if form_data.RAG_EXTERNAL_RERANKING_API_KEY is not None
|
|
|
+ else request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY
|
|
|
+ )
|
|
|
+
|
|
|
+ log.info(
|
|
|
+ f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}"
|
|
|
+ )
|
|
|
+ try:
|
|
|
+ request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL
|
|
|
+
|
|
|
+ try:
|
|
|
+ request.app.state.rf = get_rf(
|
|
|
+ request.app.state.config.RAG_RERANKING_ENGINE,
|
|
|
+ request.app.state.config.RAG_RERANKING_MODEL,
|
|
|
+ request.app.state.config.RAG_EXTERNAL_RERANKING_URL,
|
|
|
+ request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY,
|
|
|
+ True,
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ log.error(f"Error loading reranking model: {e}")
|
|
|
+ request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
|
|
+ except Exception as e:
|
|
|
+ log.exception(f"Problem updating reranking model: {e}")
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
+ detail=ERROR_MESSAGES.DEFAULT(e),
|
|
|
+ )
|
|
|
+
|
|
|
# Chunking settings
|
|
|
request.app.state.config.TEXT_SPLITTER = (
|
|
|
form_data.TEXT_SPLITTER
|
|
@@ -788,6 +816,11 @@ async def update_rag_config(
|
|
|
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
|
|
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
|
|
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
|
|
|
+ # Reranking settings
|
|
|
+ "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
|
|
|
+ "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
|
|
|
+ "RAG_EXTERNAL_RERANKING_URL": request.app.state.config.RAG_EXTERNAL_RERANKING_URL,
|
|
|
+ "RAG_EXTERNAL_RERANKING_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY,
|
|
|
# Chunking settings
|
|
|
"TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
|
|
|
"CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
|