Browse Source

enh: sentence transformers env vars

Co-Authored-By: DrZoidberg09 <96449693+drzoidberg09@users.noreply.github.com>
Timothy Jaeryang Baek 2 months ago
parent
commit
732d7aee70
2 changed files with 56 additions and 0 deletions
  1. 47 0
      backend/open_webui/env.py
  2. 9 0
      backend/open_webui/routers/retrieval.py

+ 47 - 0
backend/open_webui/env.py

@@ -451,6 +451,51 @@ AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = (
     os.environ.get("AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL", "True").lower() == "true"
 )
 
+
+####################################
+# SENTENCE TRANSFORMERS
+####################################
+
+
+SENTENCE_TRANSFORMERS_BACKEND = os.environ.get("SENTENCE_TRANSFORMERS_BACKEND", "")
+if SENTENCE_TRANSFORMERS_BACKEND == "":
+    SENTENCE_TRANSFORMERS_BACKEND = "torch"
+
+
+SENTENCE_TRANSFORMERS_MODEL_KWARGS = os.environ.get(
+    "SENTENCE_TRANSFORMERS_MODEL_KWARGS", ""
+)
+if SENTENCE_TRANSFORMERS_MODEL_KWARGS == "":
+    SENTENCE_TRANSFORMERS_MODEL_KWARGS = None
+else:
+    try:
+        SENTENCE_TRANSFORMERS_MODEL_KWARGS = json.loads(
+            SENTENCE_TRANSFORMERS_MODEL_KWARGS
+        )
+    except Exception:
+        SENTENCE_TRANSFORMERS_MODEL_KWARGS = None
+
+
+SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = os.environ.get(
+    "SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND", ""
+)
+if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == "":
+    SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = "torch"
+
+
+SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.environ.get(
+    "SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS", ""
+)
+if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == "":
+    SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None
+else:
+    try:
+        SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = json.loads(
+            SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS
+        )
+    except Exception:
+        SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None
+
 ####################################
 # OFFLINE_MODE
 ####################################
@@ -460,6 +505,7 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
 if OFFLINE_MODE:
     os.environ["HF_HUB_OFFLINE"] = "1"
 
+
 ####################################
 # AUDIT LOGGING
 ####################################
@@ -481,6 +527,7 @@ AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders"
 AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
 AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
 
+
 ####################################
 # OPENTELEMETRY
 ####################################

+ 9 - 0
backend/open_webui/routers/retrieval.py

@@ -91,7 +91,12 @@ from open_webui.env import (
     SRC_LOG_LEVELS,
     DEVICE_TYPE,
     DOCKER,
+    SENTENCE_TRANSFORMERS_BACKEND,
+    SENTENCE_TRANSFORMERS_MODEL_KWARGS,
+    SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
+    SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
 )
+
 from open_webui.constants import ERROR_MESSAGES
 
 log = logging.getLogger(__name__)
@@ -118,6 +123,8 @@ def get_ef(
                 get_model_path(embedding_model, auto_update),
                 device=DEVICE_TYPE,
                 trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
+                backend=SENTENCE_TRANSFORMERS_BACKEND,
+                model_kwargs=SENTENCE_TRANSFORMERS_MODEL_KWARGS,
             )
         except Exception as e:
             log.debug(f"Error loading SentenceTransformer: {e}")
@@ -151,6 +158,8 @@ def get_rf(
                     get_model_path(reranking_model, auto_update),
                     device=DEVICE_TYPE,
                     trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+                    backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
+                    model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
                 )
             except Exception as e:
                 log.error(f"CrossEncoder: {e}")