Timothy Jaeryang Baek 4 miesięcy temu
rodzic
commit
9306ae5972
1 zmienionych plików z 22 dodań i 52 usunięć
  1. 22 52
      backend/open_webui/retrieval/utils.py

+ 22 - 52
backend/open_webui/retrieval/utils.py

@@ -818,63 +818,33 @@ def generate_embeddings(
             text = f"{prefix}{text}"
 
     if engine == "ollama":
-        if isinstance(text, list):
-            embeddings = generate_ollama_batch_embeddings(
-                **{
-                    "model": model,
-                    "texts": text,
-                    "url": url,
-                    "key": key,
-                    "prefix": prefix,
-                    "user": user,
-                }
-            )
-        else:
-            embeddings = generate_ollama_batch_embeddings(
-                **{
-                    "model": model,
-                    "texts": [text],
-                    "url": url,
-                    "key": key,
-                    "prefix": prefix,
-                    "user": user,
-                }
-            )
+        embeddings = generate_ollama_batch_embeddings(
+            **{
+                "model": model,
+                "texts": text if isinstance(text, list) else [text],
+                "url": url,
+                "key": key,
+                "prefix": prefix,
+                "user": user,
+            }
+        )
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
-        if isinstance(text, list):
-            embeddings = generate_openai_batch_embeddings(
-                model, text, url, key, prefix, user
-            )
-        else:
-            embeddings = generate_openai_batch_embeddings(
-                model, [text], url, key, prefix, user
-            )
+        embeddings = generate_openai_batch_embeddings(
+            model, text if isinstance(text, list) else [text], url, key, prefix, user
+        )
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "azure_openai":
         azure_api_version = kwargs.get("azure_api_version", "")
-        if isinstance(text, list):
-            embeddings = generate_azure_openai_batch_embeddings(
-                model,
-                text,
-                url,
-                key,
-                model,
-                azure_api_version,
-                prefix,
-                user,
-            )
-        else:
-            embeddings = generate_azure_openai_batch_embeddings(
-                model,
-                [text],
-                url,
-                key,
-                model,
-                azure_api_version,
-                prefix,
-                user,
-            )
+        embeddings = generate_azure_openai_batch_embeddings(
+            model,
+            text if isinstance(text, list) else [text],
+            url,
+            key,
+            azure_api_version,
+            prefix,
+            user,
+        )
         return embeddings[0] if isinstance(text, str) else embeddings