Преглед изворни кода

Correctly unloads embedding/reranker models

Marko Henning пре 1 месец
родитељ
комит
39fe385017
2 измењених фајлова са 40 додато и 23 уклоњено
  1. 10 8
      backend/open_webui/main.py
  2. 30 15
      backend/open_webui/routers/retrieval.py

+ 10 - 8
backend/open_webui/main.py

@@ -924,14 +924,16 @@ try:
         app.state.config.RAG_EMBEDDING_MODEL,
         RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     )
-
-    app.state.rf = get_rf(
-        app.state.config.RAG_RERANKING_ENGINE,
-        app.state.config.RAG_RERANKING_MODEL,
-        app.state.config.RAG_EXTERNAL_RERANKER_URL,
-        app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
-        RAG_RERANKING_MODEL_AUTO_UPDATE,
-    )
+    if ENABLE_RAG_HYBRID_SEARCH and not BYPASS_EMBEDDING_AND_RETRIEVAL:
+        app.state.rf = get_rf(
+            app.state.config.RAG_RERANKING_ENGINE,
+            app.state.config.RAG_RERANKING_MODEL,
+            app.state.config.RAG_EXTERNAL_RERANKER_URL,
+            app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
+            RAG_RERANKING_MODEL_AUTO_UPDATE,
+        )
+    else:
+        app.state.rf = None
 except Exception as e:
     log.error(f"Error updating models: {e}")
     pass

+ 30 - 15
backend/open_webui/routers/retrieval.py

@@ -321,6 +321,14 @@ async def update_embedding_config(
                 form_data.embedding_batch_size
             )
 
+        # unloads current embedding model and clears VRAM cache
+        request.app.state.ef = None
+        request.app.state.EMBEDDING_FUNCTION = None
+        import gc
+        gc.collect()
+        import torch
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
         request.app.state.ef = get_ef(
             request.app.state.config.RAG_EMBEDDING_ENGINE,
             request.app.state.config.RAG_EMBEDDING_MODEL,
@@ -653,9 +661,6 @@ async def update_rag_config(
         if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
         else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
     )
-    # Free up memory if hybrid search is disabled
-    if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
-        request.app.state.rf = None
 
     request.app.state.config.TOP_K_RERANKER = (
         form_data.TOP_K_RERANKER
@@ -838,19 +843,29 @@ async def update_rag_config(
         )
 
         try:
-            request.app.state.rf = get_rf(
-                request.app.state.config.RAG_RERANKING_ENGINE,
-                request.app.state.config.RAG_RERANKING_MODEL,
-                request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
-                request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
-                True,
-            )
+            # Unloading the reranker and clear VRAM memory.
+            if request.app.state.rf != None:
+                request.app.state.rf = None
+                request.app.state.RERANKING_FUNCTION = None
+                import gc
+                gc.collect()
+                import torch
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+            if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and not request.app.state.BYPASS_EMBEDDING_AND_RETRIEVAL:
+                request.app.state.rf = get_rf(
+                    request.app.state.config.RAG_RERANKING_ENGINE,
+                    request.app.state.config.RAG_RERANKING_MODEL,
+                    request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
+                    request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
+                    True,
+                )
 
-            request.app.state.RERANKING_FUNCTION = get_reranking_function(
-                request.app.state.config.RAG_RERANKING_ENGINE,
-                request.app.state.config.RAG_RERANKING_MODEL,
-                request.app.state.rf,
-            )
+                request.app.state.RERANKING_FUNCTION = get_reranking_function(
+                    request.app.state.config.RAG_RERANKING_ENGINE,
+                    request.app.state.config.RAG_RERANKING_MODEL,
+                    request.app.state.rf,
+                )
         except Exception as e:
             log.error(f"Error loading reranking model: {e}")
             request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False