ソースを参照

feat: model sync endpoint

Timothy Jaeryang Baek 2 ヶ月 前
コミット
c1e4139e5c
2 ファイル変更62 行追加0 行削除
  1. 44 0
      backend/open_webui/models/models.py
  2. 18 0
      backend/open_webui/routers/models.py

+ 44 - 0
backend/open_webui/models/models.py

@@ -269,5 +269,49 @@ class ModelsTable:
         except Exception:
             return False
 
+    def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]:
+        try:
+            with get_db() as db:
+                # Get existing models
+                existing_models = db.query(Model).all()
+                existing_ids = {model.id for model in existing_models}
+
+                # Prepare a set of new model IDs
+                new_model_ids = {model.id for model in models}
+
+                # Update or insert models
+                for model in models:
+                    if model.id in existing_ids:
+                        db.query(Model).filter_by(id=model.id).update(
+                            {
+                                **model.model_dump(),
+                                "user_id": user_id,
+                                "updated_at": int(time.time()),
+                            }
+                        )
+                    else:
+                        new_model = Model(
+                            **{
+                                **model.model_dump(),
+                                "user_id": user_id,
+                                "updated_at": int(time.time()),
+                            }
+                        )
+                        db.add(new_model)
+
+                # Remove models that are no longer present
+                for model in existing_models:
+                    if model.id not in new_model_ids:
+                        db.delete(model)
+
+                db.commit()
+
+                return [
+                    ModelModel.model_validate(model) for model in db.query(Model).all()
+                ]
+        except Exception as e:
+            log.exception(f"Error syncing models for user {user_id}: {e}")
+            return []
+
 
 Models = ModelsTable()

+ 18 - 0
backend/open_webui/routers/models.py

@@ -7,6 +7,8 @@ from open_webui.models.models import (
     ModelUserResponse,
     Models,
 )
+
+from pydantic import BaseModel
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 
@@ -78,6 +80,22 @@ async def create_new_model(
             )
 
 
+############################
+# SyncModels
+############################
+
+
+class SyncModelsForm(BaseModel):
+    models: list[ModelModel] = []
+
+
+@router.post("/sync", response_model=list[ModelModel])
+async def sync_models(
+    request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user)
+):
+    return Models.sync_models(user.id, form_data.models)
+
+
 ###########################
 # GetModelById
 ###########################