@@ -1032,6 +1032,82 @@ class OpenAIChatCompletionForm(BaseModel):
     model_config = ConfigDict(extra="allow")
 
 
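+# Request body for the OpenAI-style completions endpoint; extra="allow" lets
+# unmodeled fields (e.g. temperature, stream) pass through to the upstream request.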
+class OpenAICompletionForm(BaseModel):
+    model: str
+    prompt: str
+
+    model_config = ConfigDict(extra="allow")
+
+
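+# OpenAI-compatible (legacy) completions endpoint. An explicit {url_idx}
+# targets a specific configured Ollama connection; otherwise the connection
+# is resolved from the requested model.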
+@app.post("/v1/completions")
+@app.post("/v1/completions/{url_idx}")
+async def generate_openai_completion(
+    form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user)
+):
+    try:
+        form_data = OpenAICompletionForm(**form_data)
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=400,
+            detail=str(e),
+        )
+
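+    # Serialize the validated form; "metadata" is excluded twice (exclude list
+    # plus an explicit del) so internal metadata is never forwarded upstream.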
+    payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
+    if "metadata" in payload:
+        del payload["metadata"]
+
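+    # Ollama model names default to the ":latest" tag when no tag is given.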
+    model_id = form_data.model
+    if ":" not in model_id:
+        model_id = f"{model_id}:latest"
+
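+    # A workspace model may alias a base model and carry preset parameters
+    # that are merged into the request body.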
+    model_info = Models.get_model_by_id(model_id)
+    if model_info:
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+        params = model_info.params.model_dump()
+
+        if params:
+            payload = apply_model_params_to_body_openai(params, payload)
+
+        # Check if user has access to the model
+        if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
+            if not (
+                user.id == model_info.user_id
+                or has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                )
+            ):
+                raise HTTPException(
+                    status_code=403,
+                    detail="Model not found",
+                )
+    else:
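+        # Models with no workspace entry are reachable only by admins.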
+        if user.role != "admin":
+            raise HTTPException(
+                status_code=403,
+                detail="Model not found",
+            )
+
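+    # The payload may now name a base model; make sure it carries a tag too.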
+    if ":" not in payload["model"]:
+        payload["model"] = f"{payload['model']}:latest"
+
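+    # Resolve the upstream Ollama URL, honoring an explicit url_idx when given.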
+    url = await get_ollama_url(url_idx, payload["model"])
+    log.info(f"url: {url}")
+
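+    # A connection's "prefix_id" namespaces its model names; strip the prefix
+    # so the upstream server sees the model name it actually knows.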
+    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    prefix_id = api_config.get("prefix_id", None)
+
+    if prefix_id:
+        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
+
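+    # Proxy to the upstream OpenAI-compatible endpoint, streaming the response
+    # back when the client requested stream=True.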
+    return await post_streaming_url(
+        f"{url}/v1/completions",
+        json.dumps(payload),
+        stream=payload.get("stream", False),
+    )
+
+
 @app.post("/v1/chat/completions")
 @app.post("/v1/chat/completions/{url_idx}")
 async def generate_openai_chat_completion(