Преглед изворног кода

Ollama embeddings adapted to pydantic

henry пре 4 месеци
родитељ
комит
cc12e9e1a3
2 измењених фајлова са 10 додато и 6 уклоњено
  1. +4 −1
      backend/open_webui/utils/embeddings.py
  2. +6 −5
      backend/open_webui/utils/payload.py

+ 4 - 1
backend/open_webui/utils/embeddings.py

@@ -10,8 +10,10 @@ from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS
 
 from open_webui.routers.openai import embeddings as openai_embeddings
 from open_webui.routers.ollama import embeddings as ollama_embeddings
+from open_webui.routers.ollama import GenerateEmbeddingsForm
 from open_webui.routers.pipelines import process_pipeline_inlet_filter
 
+
 from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
 from open_webui.utils.response import convert_response_ollama_to_openai
 
@@ -109,9 +111,10 @@ async def generate_embeddings(
     # Ollama backend
     if model.get("owned_by") == "ollama":
         ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
+        form_obj = GenerateEmbeddingsForm(**ollama_payload)
         response = await ollama_embeddings(
             request=request,
-            form_data=ollama_payload,
+            form_data=form_obj,
             user=user,
         )
         return convert_response_ollama_to_openai(response)

+ 6 - 5
backend/open_webui/utils/payload.py

@@ -336,24 +336,25 @@ def convert_embedding_payload_openai_to_ollama(openai_payload: dict) -> dict:
     Convert an embeddings request payload from OpenAI format to Ollama format.
 
     Args:
-        openai_payload (dict): The original payload designed for OpenAI API usage. 
-            Example: {"model": "...", "input": [str, ...] or str}
+        openai_payload (dict): The original payload designed for OpenAI API usage.
 
     Returns:
         dict: A payload compatible with the Ollama API embeddings endpoint.
-            Example: {"model": "...", "input": [str, ...]}
     """
     ollama_payload = {
         "model": openai_payload.get("model")
     }
     input_value = openai_payload.get("input")
-    # Ollama expects 'input' as a list. If it's a string, wrap it in a list.
+
+    # Ollama expects 'input' as a list, and 'prompt' as a single string.
     if isinstance(input_value, list):
         ollama_payload["input"] = input_value
+        ollama_payload["prompt"] = "\n".join(str(x) for x in input_value)
     else:
         ollama_payload["input"] = [input_value]
+        ollama_payload["prompt"] = str(input_value)
 
-    # Optionally forward 'options', 'truncate', 'keep_alive' if present in OpenAI request
+    # Optionally forward other fields if present
     for optional_key in ("options", "truncate", "keep_alive"):
         if optional_key in openai_payload:
             ollama_payload[optional_key] = openai_payload[optional_key]