Просмотр исходного кода

Chage torch import to conditional import

Marko Henning 6 месяцев назад
Родитель
Сommit
b3de3295d6
1 измененных файлов с 8 добавлено и 5 удалено
  1. 8 5
      backend/open_webui/routers/retrieval.py

+ 8 - 5
backend/open_webui/routers/retrieval.py

@@ -4,7 +4,6 @@ import mimetypes
 import os
 import shutil
 import asyncio
-import torch
 
 import uuid
 from datetime import datetime
@@ -287,8 +286,10 @@ async def update_embedding_config(
         request.app.state.EMBEDDING_FUNCTION = None
         import gc
         gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        if DEVICE_TYPE == 'cuda':
+            import torch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
     try:
         request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
         request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
@@ -820,8 +821,10 @@ async def update_rag_config(
         request.app.state.RERANKING_FUNCTION = None
         import gc
         gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        if DEVICE_TYPE == 'cuda':
+            import torch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
     request.app.state.config.RAG_RERANKING_ENGINE = (
         form_data.RAG_RERANKING_ENGINE
         if form_data.RAG_RERANKING_ENGINE is not None