Browse Source

Merge pull request #16779 from mahenning/fix--clean-unload-embed/reranker-models

Fix: Free VRAM memory when updating embedding / reranking models
Tim Jaeryang Baek 1 month ago
parent
commit
5a66f69460
2 changed files with 53 additions and 24 deletions
  1. + 13 - 8
      backend/open_webui/main.py
  2. + 40 - 16
      backend/open_webui/routers/retrieval.py

+ 13 - 8
backend/open_webui/main.py

@@ -924,14 +924,19 @@ try:
         app.state.config.RAG_EMBEDDING_MODEL,
         RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     )
-
-    app.state.rf = get_rf(
-        app.state.config.RAG_RERANKING_ENGINE,
-        app.state.config.RAG_RERANKING_MODEL,
-        app.state.config.RAG_EXTERNAL_RERANKER_URL,
-        app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
-        RAG_RERANKING_MODEL_AUTO_UPDATE,
-    )
+    if (
+        app.state.config.ENABLE_RAG_HYBRID_SEARCH
+        and not app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
+    ):
+        app.state.rf = get_rf(
+            app.state.config.RAG_RERANKING_ENGINE,
+            app.state.config.RAG_RERANKING_MODEL,
+            app.state.config.RAG_EXTERNAL_RERANKER_URL,
+            app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
+            RAG_RERANKING_MODEL_AUTO_UPDATE,
+        )
+    else:
+        app.state.rf = None
 except Exception as e:
     log.error(f"Error updating models: {e}")
     pass

+ 40 - 16
backend/open_webui/routers/retrieval.py

@@ -5,7 +5,6 @@ import os
 import shutil
 import asyncio
 
-
 import uuid
 from datetime import datetime
 from pathlib import Path
@@ -281,6 +280,18 @@ async def update_embedding_config(
     log.info(
         f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
     )
+    if request.app.state.config.RAG_EMBEDDING_ENGINE == "":
+        # Unload the current internal embedding model and clear the VRAM cache
+        request.app.state.ef = None
+        request.app.state.EMBEDDING_FUNCTION = None
+        import gc
+
+        gc.collect()
+        if DEVICE_TYPE == "cuda":
+            import torch
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
     try:
         request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
         request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
@@ -653,9 +664,6 @@ async def update_rag_config(
         if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
         else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
     )
-    # Free up memory if hybrid search is disabled
-    if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
-        request.app.state.rf = None
 
     request.app.state.config.TOP_K_RERANKER = (
         form_data.TOP_K_RERANKER
@@ -809,6 +817,18 @@ async def update_rag_config(
     )
 
     # Reranking settings
+    if request.app.state.config.RAG_RERANKING_ENGINE == "":
+        # Unload the internal reranker and clear the VRAM cache
+        request.app.state.rf = None
+        request.app.state.RERANKING_FUNCTION = None
+        import gc
+
+        gc.collect()
+        if DEVICE_TYPE == "cuda":
+            import torch
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
     request.app.state.config.RAG_RERANKING_ENGINE = (
         form_data.RAG_RERANKING_ENGINE
         if form_data.RAG_RERANKING_ENGINE is not None
@@ -838,19 +858,23 @@ async def update_rag_config(
         )
 
         try:
-            request.app.state.rf = get_rf(
-                request.app.state.config.RAG_RERANKING_ENGINE,
-                request.app.state.config.RAG_RERANKING_MODEL,
-                request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
-                request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
-                True,
-            )
+            if (
+                request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
+                and not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
+            ):
+                request.app.state.rf = get_rf(
+                    request.app.state.config.RAG_RERANKING_ENGINE,
+                    request.app.state.config.RAG_RERANKING_MODEL,
+                    request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
+                    request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
+                    True,
+                )
 
-            request.app.state.RERANKING_FUNCTION = get_reranking_function(
-                request.app.state.config.RAG_RERANKING_ENGINE,
-                request.app.state.config.RAG_RERANKING_MODEL,
-                request.app.state.rf,
-            )
+                request.app.state.RERANKING_FUNCTION = get_reranking_function(
+                    request.app.state.config.RAG_RERANKING_ENGINE,
+                    request.app.state.config.RAG_RERANKING_MODEL,
+                    request.app.state.rf,
+                )
         except Exception as e:
             log.error(f"Error loading reranking model: {e}")
             request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False