Răsfoiți Sursa

Added server side Prefixing

jayteaftw 3 luni în urmă
părinte
comite
6d2f87e904
2 a modificat fișierele cu 24 adăugiri și 3 ștergeri
  1. 1 1
      backend/open_webui/config.py
  2. 23 2
      backend/open_webui/retrieval/utils.py

+ 1 - 1
backend/open_webui/config.py

@@ -1339,7 +1339,7 @@ RAG_EMBEDDING_PASSAGE_PREFIX = (
 )
 
 RAG_EMBEDDING_PREFIX_FIELD_NAME = (
-    os.environ.get("RAG_EMBEDDING_PREFIX_FIELD_NAME", "input_type")
+    os.environ.get("RAG_EMBEDDING_PREFIX_FIELD_NAME", None)
 )
 
 RAG_RERANKING_MODEL = PersistentConfig(

+ 23 - 2
backend/open_webui/retrieval/utils.py

@@ -418,14 +418,22 @@ def get_model_path(model: str, update_model: bool = False):
 def generate_openai_batch_embeddings(
     model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = "", prefix: str = None
 ) -> Optional[list[list[float]]]:
+    
     try:
+        json_data = {
+            "input": texts, 
+            "model": model
+        }
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME,str) and isinstance(prefix,str):
+            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
         r = requests.post(
             f"{url}/embeddings",
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
             },
-            json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, RAG_EMBEDDING_PREFIX_FIELD_NAME: prefix},
+            json=json_data,
         )
         r.raise_for_status()
         data = r.json()
@@ -442,13 +450,20 @@ def generate_ollama_batch_embeddings(
     model: str, texts: list[str], url: str, key: str = "", prefix: str = None 
 ) -> Optional[list[list[float]]]:
     try:
+        json_data = {
+            "input": texts, 
+            "model": model
+        }
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME,str) and isinstance(prefix,str):
+            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+            
         r = requests.post(
             f"{url}/api/embed",
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
             },
-            json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, RAG_EMBEDDING_PREFIX_FIELD_NAME: prefix},
+            json=json_data,
         )
         r.raise_for_status()
         data = r.json()
@@ -466,6 +481,12 @@ def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], pr
     url = kwargs.get("url", "")
     key = kwargs.get("key", "")
 
+    if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
+        if isinstance(text, list):
+            text = [f'{prefix}{text_element}' for text_element in text]
+        else:
+            text = f'{prefix}{text}'
+
     if engine == "ollama":
         if isinstance(text, list):
             embeddings = generate_ollama_batch_embeddings(