import random
import logging
import sys

from fastapi import Request

from open_webui.models.users import UserModel
from open_webui.models.models import Models
from open_webui.utils.models import check_model_access
from open_webui.env import (
    SRC_LOG_LEVELS,
    GLOBAL_LOG_LEVEL,
    BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.routers.openai import embeddings as openai_embeddings
from open_webui.routers.ollama import embeddings as ollama_embeddings
from open_webui.routers.ollama import GenerateEmbeddingsForm
from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
from open_webui.utils.response import convert_embedding_response_ollama_to_openai

logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])


async def generate_embeddings(
    request: Request,
    form_data: dict,
    user: UserModel,
    bypass_filter: bool = False,
):
    """
    Dispatch embeddings generation to the appropriate backend based on the
    model type (OpenAI, Ollama, Arena, pipeline, etc.).

    Args:
        request (Request): The FastAPI request context.
        form_data (dict): The input data sent to the endpoint.
        user (UserModel): The authenticated user.
        bypass_filter (bool): If True, disables access filtering (default False).

    Returns:
        dict: The embeddings response, following OpenAI API compatibility.
    """
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    # Attach extra metadata from request.state if present
    if hasattr(request.state, "metadata"):
        if "metadata" not in form_data:
            form_data["metadata"] = request.state.metadata
        else:
            form_data["metadata"] = {
                **form_data["metadata"],
                **request.state.metadata,
            }
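    # Note: with dict unpacking, later entries win, so values from
    # request.state.metadata override same-named keys already present in
    # form_data["metadata"].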

    # If "direct" flag present, use only that model
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data.get("model")
    if model_id not in models:
        raise Exception("Model not found")
    model = models[model_id]

    # Access filtering
    if not getattr(request.state, "direct", False):
        if not bypass_filter and user.role == "user":
            check_model_access(user, model)

    # Arena "meta-model": select a submodel at random
    if model.get("owned_by") == "arena":
        model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
        filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
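        # Illustrative meta shape (the ids are made-up examples):
        #   {"model_ids": ["model-a", "model-b"], "filter_mode": "exclude"}
        # With filter_mode == "exclude", the candidate pool is every
        # non-arena model NOT listed in model_ids; otherwise model_ids
        # itself (when set) is used directly as the pool.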
        if model_ids and filter_mode == "exclude":
            model_ids = [
                m["id"]
                for m in list(models.values())
                if m.get("owned_by") != "arena" and m["id"] not in model_ids
            ]
        if isinstance(model_ids, list) and model_ids:
            selected_model_id = random.choice(model_ids)
        else:
            model_ids = [
                m["id"]
                for m in list(models.values())
                if m.get("owned_by") != "arena"
            ]
            selected_model_id = random.choice(model_ids)

        inner_form = dict(form_data)
        inner_form["model"] = selected_model_id
        response = await generate_embeddings(
            request, inner_form, user, bypass_filter=True
        )
        # Tag which concrete model was chosen
        if isinstance(response, dict):
            response = {
                **response,
                "selected_model_id": selected_model_id,
            }
        return response

    # Pipeline/Function models
    if model.get("pipe"):
        # The pipeline handler should provide an OpenAI-compatible schema
        return await process_pipeline_inlet_filter(request, form_data, user, models)

    # Ollama backend
    if model.get("owned_by") == "ollama":
        ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
        form_obj = GenerateEmbeddingsForm(**ollama_payload)
        response = await ollama_embeddings(
            request=request,
            form_data=form_obj,
            user=user,
        )
        return convert_embedding_response_ollama_to_openai(response)

    # Default: OpenAI or compatible backend
    return await openai_embeddings(
        request=request,
        form_data=form_data,
        user=user,
    )
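

# --- Usage sketch (illustrative, not part of the module above) -------------
# A minimal sketch of how this dispatcher might be exposed as a FastAPI
# route, kept commented out so the module's behavior is unchanged. The
# route path and the `get_verified_user` dependency (and its import
# location) are assumptions for illustration; substitute whatever auth
# dependency your application actually uses.
#
#   from fastapi import APIRouter, Depends, Request
#
#   from open_webui.utils.auth import get_verified_user  # assumed import path
#
#   router = APIRouter()
#
#   @router.post("/v1/embeddings")
#   async def create_embeddings(
#       request: Request,
#       form_data: dict,
#       user=Depends(get_verified_user),
#   ):
#       # form_data follows the OpenAI embeddings schema, e.g.
#       #   {"model": "<model-id>", "input": "hello world"}
#       # where "input" may also be a list of strings.
#       return await generate_embeddings(request, form_data, user)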