ソースを参照

Unloads only if internal models are used.

Marko Henning 1 ヶ月 前
コミット
6663fc3a6c
1 ファイル変更17 行追加18 行削除
  1. 17 18
      backend/open_webui/routers/retrieval.py

+ 17 - 18
backend/open_webui/routers/retrieval.py

@@ -4,7 +4,7 @@ import mimetypes
 import os
 import shutil
 import asyncio
-
+import torch
 
 import uuid
 from datetime import datetime
@@ -281,6 +281,14 @@ async def update_embedding_config(
     log.info(
         f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
     )
+    if request.app.state.config.RAG_EMBEDDING_ENGINE == '':
+        # unloads current internal embedding model and clears VRAM cache
+        request.app.state.ef = None
+        request.app.state.EMBEDDING_FUNCTION = None
+        import gc
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
     try:
         request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
         request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
@@ -321,14 +329,6 @@ async def update_embedding_config(
                 form_data.embedding_batch_size
             )
 
-        # unloads current embedding model and clears VRAM cache
-        request.app.state.ef = None
-        request.app.state.EMBEDDING_FUNCTION = None
-        import gc
-        gc.collect()
-        import torch
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
         request.app.state.ef = get_ef(
             request.app.state.config.RAG_EMBEDDING_ENGINE,
             request.app.state.config.RAG_EMBEDDING_MODEL,
@@ -814,6 +814,14 @@ async def update_rag_config(
     )
 
     # Reranking settings
+    if request.app.state.config.RAG_RERANKING_ENGINE == '':
+        # Unloading the internal reranker and clear VRAM memory
+        request.app.state.rf = None
+        request.app.state.RERANKING_FUNCTION = None
+        import gc
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
     request.app.state.config.RAG_RERANKING_ENGINE = (
         form_data.RAG_RERANKING_ENGINE
         if form_data.RAG_RERANKING_ENGINE is not None
@@ -843,15 +851,6 @@ async def update_rag_config(
         )
 
         try:
-            # Unloading the reranker and clear VRAM memory.
-            if request.app.state.rf != None:
-                request.app.state.rf = None
-                request.app.state.RERANKING_FUNCTION = None
-                import gc
-                gc.collect()
-                import torch
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
             if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
                 request.app.state.rf = get_rf(
                     request.app.state.config.RAG_RERANKING_ENGINE,