|
@@ -5,7 +5,6 @@ import os
|
|
import shutil
|
|
import shutil
|
|
import asyncio
|
|
import asyncio
|
|
|
|
|
|
-
|
|
|
|
import uuid
|
|
import uuid
|
|
from datetime import datetime
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
@@ -281,6 +280,18 @@ async def update_embedding_config(
|
|
log.info(
|
|
log.info(
|
|
f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
|
f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
|
)
|
|
)
|
|
|
|
+ if request.app.state.config.RAG_EMBEDDING_ENGINE == "":
|
|
|
|
+ # Unload the current internal embedding model and clear the VRAM cache
|
|
|
|
+ request.app.state.ef = None
|
|
|
|
+ request.app.state.EMBEDDING_FUNCTION = None
|
|
|
|
+ import gc
|
|
|
|
+
|
|
|
|
+ gc.collect()
|
|
|
|
+ if DEVICE_TYPE == "cuda":
|
|
|
|
+ import torch
|
|
|
|
+
|
|
|
|
+ if torch.cuda.is_available():
|
|
|
|
+ torch.cuda.empty_cache()
|
|
try:
|
|
try:
|
|
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
|
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
|
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
|
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
|
@@ -653,9 +664,6 @@ async def update_rag_config(
|
|
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
|
|
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
|
|
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
|
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
|
)
|
|
)
|
|
- # Free up memory if hybrid search is disabled
|
|
|
|
- if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
|
|
|
|
- request.app.state.rf = None
|
|
|
|
|
|
|
|
request.app.state.config.TOP_K_RERANKER = (
|
|
request.app.state.config.TOP_K_RERANKER = (
|
|
form_data.TOP_K_RERANKER
|
|
form_data.TOP_K_RERANKER
|
|
@@ -809,6 +817,18 @@ async def update_rag_config(
|
|
)
|
|
)
|
|
|
|
|
|
# Reranking settings
|
|
# Reranking settings
|
|
|
|
+ if request.app.state.config.RAG_RERANKING_ENGINE == "":
|
|
|
|
+ # Unload the internal reranker and clear the VRAM cache
|
|
|
|
+ request.app.state.rf = None
|
|
|
|
+ request.app.state.RERANKING_FUNCTION = None
|
|
|
|
+ import gc
|
|
|
|
+
|
|
|
|
+ gc.collect()
|
|
|
|
+ if DEVICE_TYPE == "cuda":
|
|
|
|
+ import torch
|
|
|
|
+
|
|
|
|
+ if torch.cuda.is_available():
|
|
|
|
+ torch.cuda.empty_cache()
|
|
request.app.state.config.RAG_RERANKING_ENGINE = (
|
|
request.app.state.config.RAG_RERANKING_ENGINE = (
|
|
form_data.RAG_RERANKING_ENGINE
|
|
form_data.RAG_RERANKING_ENGINE
|
|
if form_data.RAG_RERANKING_ENGINE is not None
|
|
if form_data.RAG_RERANKING_ENGINE is not None
|
|
@@ -838,19 +858,23 @@ async def update_rag_config(
|
|
)
|
|
)
|
|
|
|
|
|
try:
|
|
try:
|
|
- request.app.state.rf = get_rf(
|
|
|
|
- request.app.state.config.RAG_RERANKING_ENGINE,
|
|
|
|
- request.app.state.config.RAG_RERANKING_MODEL,
|
|
|
|
- request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
|
|
|
|
- request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
|
|
|
- True,
|
|
|
|
- )
|
|
|
|
|
|
+ if (
|
|
|
|
+ request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
|
|
|
+ and not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
|
|
|
+ ):
|
|
|
|
+ request.app.state.rf = get_rf(
|
|
|
|
+ request.app.state.config.RAG_RERANKING_ENGINE,
|
|
|
|
+ request.app.state.config.RAG_RERANKING_MODEL,
|
|
|
|
+ request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
|
|
|
|
+ request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
|
|
|
+ True,
|
|
|
|
+ )
|
|
|
|
|
|
- request.app.state.RERANKING_FUNCTION = get_reranking_function(
|
|
|
|
- request.app.state.config.RAG_RERANKING_ENGINE,
|
|
|
|
- request.app.state.config.RAG_RERANKING_MODEL,
|
|
|
|
- request.app.state.rf,
|
|
|
|
- )
|
|
|
|
|
|
+ request.app.state.RERANKING_FUNCTION = get_reranking_function(
|
|
|
|
+ request.app.state.config.RAG_RERANKING_ENGINE,
|
|
|
|
+ request.app.state.config.RAG_RERANKING_MODEL,
|
|
|
|
+ request.app.state.rf,
|
|
|
|
+ )
|
|
except Exception as e:
|
|
except Exception as e:
|
|
log.error(f"Error loading reranking model: {e}")
|
|
log.error(f"Error loading reranking model: {e}")
|
|
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
|
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|