enh: ollama `/v1/completions` endpoint support

Timothy Jaeryang Baek, 7 months ago
commit 1439f6862d
1 changed file with 76 additions and 0 deletions

backend/open_webui/apps/ollama/main.py  +76 −0

@@ -1032,6 +1032,82 @@ class OpenAIChatCompletionForm(BaseModel):
     model_config = ConfigDict(extra="allow")
 
 
+class OpenAICompletionForm(BaseModel):
+    model: str
+    prompt: str
+
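+    # extra="allow" lets clients pass through optional OpenAI-style params
+    # (e.g. stream, temperature) without listing each one here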
+    model_config = ConfigDict(extra="allow")
+
+
+@app.post("/v1/completions")
+@app.post("/v1/completions/{url_idx}")
+async def generate_openai_completion(
+    form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user)
+):
+    try:
+        form_data = OpenAICompletionForm(**form_data)
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=400,
+            detail=str(e),
+        )
+
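+    # Build the upstream payload; "metadata" is Open WebUI-internal, so it is
+    # excluded from the dump and, as a fallback, deleted if it slipped through
+    # as an extra field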
+    payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
+    if "metadata" in payload:
+        del payload["metadata"]
+
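+    # Ollama model names are tagged; default to ":latest" when no tag is given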
+    model_id = form_data.model
+    if ":" not in model_id:
+        model_id = f"{model_id}:latest"
+
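+    # Look up the workspace model: custom models may wrap a base model, carry
+    # preset parameters, and define per-user access control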
+    model_info = Models.get_model_by_id(model_id)
+    if model_info:
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+        params = model_info.params.model_dump()
+
+        if params:
+            payload = apply_model_params_to_body_openai(params, payload)
+
+        # Check if user has access to the model
+        if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
+            if not (
+                user.id == model_info.user_id
+                or has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                )
+            ):
+                raise HTTPException(
+                    status_code=403,
+                    detail="Model not found",
+                )
+    else:
+        if user.role != "admin":
+            raise HTTPException(
+                status_code=403,
+                detail="Model not found",
+            )
+
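+    # The model may have been swapped to its base model above, so re-apply
+    # the default ":latest" tag if needed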
+    if ":" not in payload["model"]:
+        payload["model"] = f"{payload['model']}:latest"
+
+    url = await get_ollama_url(url_idx, payload["model"])
+    log.info(f"url: {url}")
+
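+    # A per-connection API config may namespace model names with a prefix id;
+    # strip it before forwarding upstream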
+    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    prefix_id = api_config.get("prefix_id", None)
+
+    if prefix_id:
+        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
+
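+    # Proxy the request to the selected host's OpenAI-compatible completions
+    # endpoint, streaming the response back when the client asked for it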
+    return await post_streaming_url(
+        f"{url}/v1/completions",
+        json.dumps(payload),
+        stream=payload.get("stream", False),
+    )
+
+
 @app.post("/v1/chat/completions")
 @app.post("/v1/chat/completions/{url_idx}")
 async def generate_openai_chat_completion(
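
For reference, a minimal sketch of calling the new route, assuming the Ollama sub-app is mounted at /ollama (as in the main Open WebUI app) and using placeholder values for the host and API key; the /v1/completions/{url_idx} variant pins the request to a specific configured Ollama host instead:

import requests

BASE_URL = "http://localhost:8080/ollama"  # placeholder: Open WebUI host + Ollama mount
API_KEY = "sk-..."                         # placeholder: an Open WebUI API key

resp = requests.post(
    f"{BASE_URL}/v1/completions",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "model": "llama3.2",  # ":latest" is appended server-side if no tag is given
        "prompt": "Why is the sky blue?",
        "stream": False,      # extra fields pass through (extra="allow")
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])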