Jelajahi Sumber

feat(openai): Add stream_options to payload if api_version supports

Signed-off-by: Adam Tao <tcx4c70@gmail.com>
Adam Tao 3 bulan lalu
induk
melakukan
baafdb752c
1 mengubah file dengan 21 tambahan dan 14 penghapusan
  1. 21 14
      backend/open_webui/routers/openai.py

+ 21 - 14
backend/open_webui/routers/openai.py

@@ -633,13 +633,7 @@ async def verify_connection(
             raise HTTPException(status_code=500, detail=error_detail)
 
 
-def convert_to_azure_payload(
-    url,
-    payload: dict,
-):
-    model = payload.get("model", "")
-
-    # Filter allowed parameters based on Azure OpenAI API
+def get_azure_allowed_params(api_version: str) -> set[str]:
     allowed_params = {
         "messages",
         "temperature",
@@ -668,6 +662,20 @@ def convert_to_azure_payload(
         "seed",
         "max_completion_tokens",
     }
+    if api_version >= "2024-09-01-preview":
+        allowed_params.add("stream_options")
+    return allowed_params
+
+
+def convert_to_azure_payload(
+    url,
+    payload: dict,
+    api_version: str
+):
+    model = payload.get("model", "")
+
+    # Filter allowed parameters based on Azure OpenAI API
+    allowed_params = get_azure_allowed_params(api_version)
 
     # Special handling for o-series models
     if model.startswith("o") and model.endswith("-mini"):
@@ -817,8 +825,8 @@ async def generate_chat_completion(
     }
 
     if api_config.get("azure", False):
-        request_url, payload = convert_to_azure_payload(url, payload)
-        api_version = api_config.get("api_version", "") or "2023-03-15-preview"
+        api_version = api_config.get("api_version", "2023-03-15-preview")
+        request_url, payload = convert_to_azure_payload(url, payload, api_version)
         headers["api-key"] = key
         headers["api-version"] = api_version
         request_url = f"{request_url}/chat/completions?api-version={api_version}"
@@ -1007,16 +1015,15 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
         }
 
         if api_config.get("azure", False):
+            api_version = api_config.get("api_version", "2023-03-15-preview")
             headers["api-key"] = key
-            headers["api-version"] = (
-                api_config.get("api_version", "") or "2023-03-15-preview"
-            )
+            headers["api-version"] = api_version
 
             payload = json.loads(body)
-            url, payload = convert_to_azure_payload(url, payload)
+            url, payload = convert_to_azure_payload(url, payload, api_version)
             body = json.dumps(payload).encode()
 
-            request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}"
+            request_url = f"{url}/{path}?api-version={api_version}"
         else:
             headers["Authorization"] = f"Bearer {key}"
             request_url = f"{url}/{path}"