|
|
@@ -4,7 +4,6 @@ import mimetypes
|
|
|
import os
|
|
|
import shutil
|
|
|
import asyncio
|
|
|
-import torch
|
|
|
|
|
|
import uuid
|
|
|
from datetime import datetime
|
|
|
@@ -287,8 +286,10 @@ async def update_embedding_config(
|
|
|
request.app.state.EMBEDDING_FUNCTION = None
|
|
|
import gc
|
|
|
gc.collect()
|
|
|
- if torch.cuda.is_available():
|
|
|
- torch.cuda.empty_cache()
|
|
|
+ if DEVICE_TYPE == 'cuda':
|
|
|
+ import torch
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ torch.cuda.empty_cache()
|
|
|
try:
|
|
|
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
|
|
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
|
|
@@ -820,8 +821,10 @@ async def update_rag_config(
|
|
|
request.app.state.RERANKING_FUNCTION = None
|
|
|
import gc
|
|
|
gc.collect()
|
|
|
- if torch.cuda.is_available():
|
|
|
- torch.cuda.empty_cache()
|
|
|
+ if DEVICE_TYPE == 'cuda':
|
|
|
+ import torch
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ torch.cuda.empty_cache()
|
|
|
request.app.state.config.RAG_RERANKING_ENGINE = (
|
|
|
form_data.RAG_RERANKING_ENGINE
|
|
|
if form_data.RAG_RERANKING_ENGINE is not None
|