# embeddings.py
  1. import random
  2. import logging
  3. import sys
  4. from fastapi import Request
  5. from open_webui.models.users import UserModel
  6. from open_webui.models.models import Models
  7. from open_webui.utils.models import check_model_access
  8. from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
  9. from open_webui.routers.openai import embeddings as openai_embeddings
  10. from open_webui.routers.ollama import embeddings as ollama_embeddings
  11. from open_webui.routers.pipelines import process_pipeline_inlet_filter
  12. from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
  13. from open_webui.utils.response import convert_response_ollama_to_openai
  14. logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
  15. log = logging.getLogger(__name__)
  16. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  17. async def generate_embeddings(
  18. request: Request,
  19. form_data: dict,
  20. user: UserModel,
  21. bypass_filter: bool = False,
  22. ):
  23. """
  24. Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama, Arena, pipeline, etc).
  25. Args:
  26. request (Request): The FastAPI request context.
  27. form_data (dict): The input data sent to the endpoint.
  28. user (UserModel): The authenticated user.
  29. bypass_filter (bool): If True, disables access filtering (default False).
  30. Returns:
  31. dict: The embeddings response, following OpenAI API compatibility.
  32. """
  33. if BYPASS_MODEL_ACCESS_CONTROL:
  34. bypass_filter = True
  35. # Attach extra metadata from request.state if present
  36. if hasattr(request.state, "metadata"):
  37. if "metadata" not in form_data:
  38. form_data["metadata"] = request.state.metadata
  39. else:
  40. form_data["metadata"] = {
  41. **form_data["metadata"],
  42. **request.state.metadata,
  43. }
  44. # If "direct" flag present, use only that model
  45. if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
  46. models = {
  47. request.state.model["id"]: request.state.model,
  48. }
  49. else:
  50. models = request.app.state.MODELS
  51. model_id = form_data.get("model")
  52. if model_id not in models:
  53. raise Exception("Model not found")
  54. model = models[model_id]
  55. # Access filtering
  56. if not getattr(request.state, "direct", False):
  57. if not bypass_filter and user.role == "user":
  58. check_model_access(user, model)
  59. # Arena "meta-model": select a submodel at random
  60. if model.get("owned_by") == "arena":
  61. model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
  62. filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
  63. if model_ids and filter_mode == "exclude":
  64. model_ids = [
  65. m["id"]
  66. for m in list(models.values())
  67. if m.get("owned_by") != "arena" and m["id"] not in model_ids
  68. ]
  69. if isinstance(model_ids, list) and model_ids:
  70. selected_model_id = random.choice(model_ids)
  71. else:
  72. model_ids = [
  73. m["id"]
  74. for m in list(models.values())
  75. if m.get("owned_by") != "arena"
  76. ]
  77. selected_model_id = random.choice(model_ids)
  78. inner_form = dict(form_data)
  79. inner_form["model"] = selected_model_id
  80. response = await generate_embeddings(
  81. request, inner_form, user, bypass_filter=True
  82. )
  83. # Tag which concreted model was chosen
  84. if isinstance(response, dict):
  85. response = {
  86. **response,
  87. "selected_model_id": selected_model_id,
  88. }
  89. return response
  90. # Pipeline/Function models
  91. if model.get("pipe"):
  92. # The pipeline handler should provide OpenAI-compatible schema
  93. return await process_pipeline_inlet_filter(request, form_data, user, models)
  94. # Ollama backend
  95. if model.get("owned_by") == "ollama":
  96. ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
  97. response = await ollama_embeddings(
  98. request=request,
  99. form_data=ollama_payload,
  100. user=user,
  101. )
  102. return convert_response_ollama_to_openai(response)
  103. # Default: OpenAI or compatible backend
  104. return await openai_embeddings(
  105. request=request,
  106. form_data=form_data,
  107. user=user,
  108. )