embeddings.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import random
  2. import logging
  3. import sys
  4. from fastapi import Request
  5. from open_webui.models.users import UserModel
  6. from open_webui.models.models import Models
  7. from open_webui.utils.models import check_model_access
  8. from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
  9. from open_webui.routers.openai import embeddings as openai_embeddings
  10. from open_webui.routers.ollama import embeddings as ollama_embeddings
  11. from open_webui.routers.ollama import GenerateEmbeddingsForm
  12. from open_webui.routers.pipelines import process_pipeline_inlet_filter
  13. from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
  14. from open_webui.utils.response import convert_embedding_response_ollama_to_openai
  15. logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
  16. log = logging.getLogger(__name__)
  17. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  18. async def generate_embeddings(
  19. request: Request,
  20. form_data: dict,
  21. user: UserModel,
  22. bypass_filter: bool = False,
  23. ):
  24. """
  25. Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama, Arena, pipeline, etc).
  26. Args:
  27. request (Request): The FastAPI request context.
  28. form_data (dict): The input data sent to the endpoint.
  29. user (UserModel): The authenticated user.
  30. bypass_filter (bool): If True, disables access filtering (default False).
  31. Returns:
  32. dict: The embeddings response, following OpenAI API compatibility.
  33. """
  34. if BYPASS_MODEL_ACCESS_CONTROL:
  35. bypass_filter = True
  36. # Attach extra metadata from request.state if present
  37. if hasattr(request.state, "metadata"):
  38. if "metadata" not in form_data:
  39. form_data["metadata"] = request.state.metadata
  40. else:
  41. form_data["metadata"] = {
  42. **form_data["metadata"],
  43. **request.state.metadata,
  44. }
  45. # If "direct" flag present, use only that model
  46. if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
  47. models = {
  48. request.state.model["id"]: request.state.model,
  49. }
  50. else:
  51. models = request.app.state.MODELS
  52. model_id = form_data.get("model")
  53. if model_id not in models:
  54. raise Exception("Model not found")
  55. model = models[model_id]
  56. # Access filtering
  57. if not getattr(request.state, "direct", False):
  58. if not bypass_filter and user.role == "user":
  59. check_model_access(user, model)
  60. # Arena "meta-model": select a submodel at random
  61. if model.get("owned_by") == "arena":
  62. model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
  63. filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
  64. if model_ids and filter_mode == "exclude":
  65. model_ids = [
  66. m["id"]
  67. for m in list(models.values())
  68. if m.get("owned_by") != "arena" and m["id"] not in model_ids
  69. ]
  70. if isinstance(model_ids, list) and model_ids:
  71. selected_model_id = random.choice(model_ids)
  72. else:
  73. model_ids = [
  74. m["id"]
  75. for m in list(models.values())
  76. if m.get("owned_by") != "arena"
  77. ]
  78. selected_model_id = random.choice(model_ids)
  79. inner_form = dict(form_data)
  80. inner_form["model"] = selected_model_id
  81. response = await generate_embeddings(
  82. request, inner_form, user, bypass_filter=True
  83. )
  84. # Tag which concreted model was chosen
  85. if isinstance(response, dict):
  86. response = {
  87. **response,
  88. "selected_model_id": selected_model_id,
  89. }
  90. return response
  91. # Pipeline/Function models
  92. if model.get("pipe"):
  93. # The pipeline handler should provide OpenAI-compatible schema
  94. return await process_pipeline_inlet_filter(request, form_data, user, models)
  95. # Ollama backend
  96. if model.get("owned_by") == "ollama":
  97. ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
  98. form_obj = GenerateEmbeddingsForm(**ollama_payload)
  99. response = await ollama_embeddings(
  100. request=request,
  101. form_data=form_obj,
  102. user=user,
  103. )
  104. return convert_embedding_response_ollama_to_openai(response)
  105. # Default: OpenAI or compatible backend
  106. return await openai_embeddings(
  107. request=request,
  108. form_data=form_data,
  109. user=user,
  110. )