# embeddings.py
  1. import random
  2. import logging
  3. import sys
  4. from fastapi import Request
  5. from open_webui.models.users import UserModel
  6. from open_webui.models.models import Models
  7. from open_webui.utils.models import check_model_access
  8. from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
  9. from open_webui.routers.openai import embeddings as openai_embeddings
  10. from open_webui.routers.ollama import (
  11. embeddings as ollama_embeddings,
  12. GenerateEmbeddingsForm,
  13. )
  14. from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
  15. from open_webui.utils.response import convert_embedding_response_ollama_to_openai
# Send log records to stdout so process/container supervisors capture them;
# the root level comes from the deployment-wide GLOBAL_LOG_LEVEL setting.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
# Per-subsystem verbosity override, keyed under "MAIN" in SRC_LOG_LEVELS.
log.setLevel(SRC_LOG_LEVELS["MAIN"])
  19. async def generate_embeddings(
  20. request: Request,
  21. form_data: dict,
  22. user: UserModel,
  23. bypass_filter: bool = False,
  24. ):
  25. """
  26. Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama).
  27. Args:
  28. request (Request): The FastAPI request context.
  29. form_data (dict): The input data sent to the endpoint.
  30. user (UserModel): The authenticated user.
  31. bypass_filter (bool): If True, disables access filtering (default False).
  32. Returns:
  33. dict: The embeddings response, following OpenAI API compatibility.
  34. """
  35. if BYPASS_MODEL_ACCESS_CONTROL:
  36. bypass_filter = True
  37. # Attach extra metadata from request.state if present
  38. if hasattr(request.state, "metadata"):
  39. if "metadata" not in form_data:
  40. form_data["metadata"] = request.state.metadata
  41. else:
  42. form_data["metadata"] = {
  43. **form_data["metadata"],
  44. **request.state.metadata,
  45. }
  46. # If "direct" flag present, use only that model
  47. if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
  48. models = {
  49. request.state.model["id"]: request.state.model,
  50. }
  51. else:
  52. models = request.app.state.MODELS
  53. model_id = form_data.get("model")
  54. if model_id not in models:
  55. raise Exception("Model not found")
  56. model = models[model_id]
  57. # Access filtering
  58. if not getattr(request.state, "direct", False):
  59. if not bypass_filter and user.role == "user":
  60. check_model_access(user, model)
  61. # Ollama backend
  62. if model.get("owned_by") == "ollama":
  63. ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
  64. response = await ollama_embeddings(
  65. request=request,
  66. form_data=GenerateEmbeddingsForm(**ollama_payload),
  67. user=user,
  68. )
  69. return convert_embedding_response_ollama_to_openai(response)
  70. # Default: OpenAI or compatible backend
  71. return await openai_embeddings(
  72. request=request,
  73. form_data=form_data,
  74. user=user,
  75. )