Timothy Jaeryang Baek 3 месяцев назад
Родитель
Сommit
fdf7ca15ea
1 измененных файлов с 38 добавлено и 33 удалено
  1. 38 33
      backend/open_webui/routers/openai.py

+ 38 - 33
backend/open_webui/routers/openai.py

@@ -501,50 +501,55 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
             return response
         return None
 
-    def merge_models_lists(model_lists):
-        log.debug(f"merge_models_lists {model_lists}")
-        merged_list = []
-
-        for idx, models in enumerate(model_lists):
-            if models is not None and "error" not in models:
+    def is_supported_openai_models(model_id):
+        if any(
+            name in model_id
+            for name in [
+                "babbage",
+                "dall-e",
+                "davinci",
+                "embedding",
+                "tts",
+                "whisper",
+            ]
+        ):
+            return False
+        return True
 
-                merged_list.extend(
-                    [
-                        {
+    def get_merged_models(model_lists):
+        log.debug(f"merge_models_lists {model_lists}")
+        models = {}
+
+        for idx, model_list in enumerate(model_lists):
+            if model_list is not None and "error" not in model_list:
+                for model in model_list:
+                    model_id = model.get("id") or model.get("name")
+
+                    if (
+                        "api.openai.com"
+                        in request.app.state.config.OPENAI_API_BASE_URLS[idx]
+                        and not is_supported_openai_models(model_id)
+                    ):
+                        # Skip unwanted OpenAI models
+                        continue
+
+                    if model_id and model_id not in models:
+                        models[model_id] = {
                             **model,
-                            "name": model.get("name", model["id"]),
+                            "name": model.get("name", model_id),
                             "owned_by": "openai",
                             "openai": model,
                             "connection_type": model.get("connection_type", "external"),
                             "urlIdx": idx,
                         }
-                        for model in models
-                        if (model.get("id") or model.get("name"))
-                        and (
-                            "api.openai.com"
-                            not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
-                            or not any(
-                                name in model["id"]
-                                for name in [
-                                    "babbage",
-                                    "dall-e",
-                                    "davinci",
-                                    "embedding",
-                                    "tts",
-                                    "whisper",
-                                ]
-                            )
-                        )
-                    ]
-                )
 
-        return merged_list
+        return models
 
-    models = {"data": merge_models_lists(map(extract_data, responses))}
+    models = get_merged_models(map(extract_data, responses))
     log.debug(f"models: {models}")
 
-    request.app.state.OPENAI_MODELS = {model["id"]: model for model in models["data"]}
-    return models
+    request.app.state.OPENAI_MODELS = map(extract_data, responses)
+    return {"data": list(models.values())}
 
 
 @router.get("/models")