1
0
Timothy Jaeryang Baek 4 сар өмнө
parent
commit
2e56b1f13d

+ 126 - 38
backend/open_webui/routers/openai.py

@@ -647,6 +647,63 @@ async def verify_connection(
             raise HTTPException(status_code=500, detail=error_detail)
 
 
+def convert_to_azure_payload(
+    url,
+    payload: dict,
+):
+    model = payload.get("model", "")
+
+    # Filter allowed parameters based on Azure OpenAI API
+    allowed_params = {
+        "messages",
+        "temperature",
+        "role",
+        "content",
+        "contentPart",
+        "contentPartImage",
+        "enhancements",
+        "dataSources",
+        "n",
+        "stream",
+        "stop",
+        "max_tokens",
+        "presence_penalty",
+        "frequency_penalty",
+        "logit_bias",
+        "user",
+        "function_call",
+        "functions",
+        "tools",
+        "tool_choice",
+        "top_p",
+        "log_probs",
+        "top_logprobs",
+        "response_format",
+        "seed",
+        "max_completion_tokens",
+    }
+
+    # Special handling for o-series models
+    if model.startswith("o") and model.endswith("-mini"):
+        # Convert max_tokens to max_completion_tokens for o-series models
+        if "max_tokens" in payload:
+            payload["max_completion_tokens"] = payload["max_tokens"]
+            del payload["max_tokens"]
+
+        # Remove temperature if not 1 for o-series models
+        if "temperature" in payload and payload["temperature"] != 1:
+            log.debug(
+                f"Removing temperature parameter for o-series model {model} as only default value (1) is supported"
+            )
+            del payload["temperature"]
+
+    # Filter out unsupported parameters
+    payload = {k: v for k, v in payload.items() if k in allowed_params}
+
+    url = f"{url}/openai/deployments/{model}"
+    return url, payload
+
+
 @router.post("/chat/completions")
 async def generate_chat_completion(
     request: Request,
@@ -747,6 +804,38 @@ async def generate_chat_completion(
             convert_logit_bias_input_to_json(payload["logit_bias"])
         )
 
+    headers = {
+        "Content-Type": "application/json",
+        **(
+            {
+                "HTTP-Referer": "https://openwebui.com/",
+                "X-Title": "Open WebUI",
+            }
+            if "openrouter.ai" in url
+            else {}
+        ),
+        **(
+            {
+                "X-OpenWebUI-User-Name": user.name,
+                "X-OpenWebUI-User-Id": user.id,
+                "X-OpenWebUI-User-Email": user.email,
+                "X-OpenWebUI-User-Role": user.role,
+            }
+            if ENABLE_FORWARD_USER_INFO_HEADERS
+            else {}
+        ),
+    }
+
+    if api_config.get("azure", False):
+        request_url, payload = convert_to_azure_payload(url, payload)
+        api_version = api_config.get("api_version", "2023-03-15-preview")
+        headers["api-key"] = key
+        headers["api-version"] = api_version
+        request_url = f"{request_url}/chat/completions?api-version={api_version}"
+    else:
+        request_url = f"{url}/chat/completions"
+        headers["Authorization"] = f"Bearer {key}"
+
     payload = json.dumps(payload)
 
     r = None
@@ -761,30 +850,9 @@ async def generate_chat_completion(
 
         r = await session.request(
             method="POST",
-            url=f"{url}/chat/completions",
+            url=request_url,
             data=payload,
-            headers={
-                "Authorization": f"Bearer {key}",
-                "Content-Type": "application/json",
-                **(
-                    {
-                        "HTTP-Referer": "https://openwebui.com/",
-                        "X-Title": "Open WebUI",
-                    }
-                    if "openrouter.ai" in url
-                    else {}
-                ),
-                **(
-                    {
-                        "X-OpenWebUI-User-Name": user.name,
-                        "X-OpenWebUI-User-Id": user.id,
-                        "X-OpenWebUI-User-Email": user.email,
-                        "X-OpenWebUI-User-Role": user.role,
-                    }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS
-                    else {}
-                ),
-            },
+            headers=headers,
             ssl=AIOHTTP_CLIENT_SESSION_SSL,
         )
 
@@ -840,31 +908,51 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     idx = 0
     url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
     key = request.app.state.config.OPENAI_API_KEYS[idx]
+    api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
+        str(idx),
+        request.app.state.config.OPENAI_API_CONFIGS.get(
+            request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
+        ),  # Legacy support
+    )
 
     r = None
     session = None
     streaming = False
 
     try:
+        headers = {
+            "Content-Type": "application/json",
+            **(
+                {
+                    "X-OpenWebUI-User-Name": user.name,
+                    "X-OpenWebUI-User-Id": user.id,
+                    "X-OpenWebUI-User-Email": user.email,
+                    "X-OpenWebUI-User-Role": user.role,
+                }
+                if ENABLE_FORWARD_USER_INFO_HEADERS
+                else {}
+            ),
+        }
+
+        if api_config.get("azure", False):
+            headers["api-key"] = key
+            headers["api-version"] = api_config.get("api_version", "2023-03-15-preview")
+
+            payload = json.loads(body)
+            url, payload = convert_to_azure_payload(url, payload)
+            body = json.dumps(payload).encode()
+
+            request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}"
+        else:
+            headers["Authorization"] = f"Bearer {key}"
+            request_url = f"{url}/{path}"
+
         session = aiohttp.ClientSession(trust_env=True)
         r = await session.request(
             method=request.method,
-            url=f"{url}/{path}",
+            url=request_url,
             data=body,
-            headers={
-                "Authorization": f"Bearer {key}",
-                "Content-Type": "application/json",
-                **(
-                    {
-                        "X-OpenWebUI-User-Name": user.name,
-                        "X-OpenWebUI-User-Id": user.id,
-                        "X-OpenWebUI-User-Email": user.email,
-                        "X-OpenWebUI-User-Role": user.role,
-                    }
-                    if ENABLE_FORWARD_USER_INFO_HEADERS
-                    else {}
-                ),
-            },
+            headers=headers,
             ssl=AIOHTTP_CLIENT_SESSION_SSL,
         )
         r.raise_for_status()