浏览代码

feat: move JSON model import to backend

This moves the JSON model import functionality to the backend. Instead of the frontend parsing the JSON file and sending multiple requests, it now uploads the file to a new endpoint (/api/v1/models/import), which processes the file and imports the models. This improves efficiency and provides better user feedback.
silentoplayz 1 周之前
父节点
当前提交
231d182c35
共有 3 个文件被更改,包括 104 次插入34 次删除
  1. 55 1
      backend/open_webui/routers/models.py
  2. 30 0
      src/lib/apis/models/index.ts
  3. 19 33
      src/lib/components/admin/Settings/Models.svelte

+ 55 - 1
backend/open_webui/routers/models.py

@@ -1,6 +1,9 @@
 from typing import Optional
 import io
 import base64
+import json
+import asyncio
+import logging
 
 from open_webui.models.models import (
     ModelForm,
@@ -12,7 +15,16 @@ from open_webui.models.models import (
 
 from pydantic import BaseModel
 from open_webui.constants import ERROR_MESSAGES
-from fastapi import APIRouter, Depends, HTTPException, Request, status, Response
+from fastapi import (
+    APIRouter,
+    Depends,
+    HTTPException,
+    Request,
+    status,
+    Response,
+    UploadFile,
+    File,
+)
 from fastapi.responses import FileResponse, StreamingResponse
 
 
@@ -20,6 +32,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
 from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
 
+log = logging.getLogger(__name__)
+
 router = APIRouter()
 
 
@@ -93,6 +107,46 @@ async def export_models(user=Depends(get_admin_user)):
     return Models.get_models()
 
 
+############################
+# ImportModels
+############################
+
+
+@router.post("/import", response_model=bool)
+async def import_models(
+    user: str = Depends(get_admin_user), file: UploadFile = File(...)
+):
+    try:
+        data = json.loads(await file.read())
+        if isinstance(data, list):
+            for model_data in data:
+                # Here, you can add logic to validate model_data if needed
+                model_id = model_data.get("id")
+                if model_id:
+                    existing_model = Models.get_model_by_id(model_id)
+                    if existing_model:
+                        # Update existing model
+                        model_data["meta"] = model_data.get("meta", {})
+                        model_data["params"] = model_data.get("params", {})
+
+                        updated_model = ModelForm(
+                            **{**existing_model.model_dump(), **model_data}
+                        )
+                        Models.update_model_by_id(model_id, updated_model)
+                    else:
+                        # Insert new model
+                        model_data["meta"] = model_data.get("meta", {})
+                        model_data["params"] = model_data.get("params", {})
+                        new_model = ModelForm(**model_data)
+                        Models.insert_new_model(user_id=user.id, form_data=new_model)
+            return True
+        else:
+            raise HTTPException(status_code=400, detail="Invalid JSON format")
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 ############################
 # SyncModels
 ############################

+ 30 - 0
src/lib/apis/models/index.ts

@@ -31,6 +31,36 @@ export const getModels = async (token: string = '') => {
 	return res;
 };
 
+export const importModels = async (token: string, file: File) => {
+	let error = null;
+
+	const formData = new FormData();
+	formData.append('file', file);
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/models/import`, {
+		method: 'POST',
+		headers: {
+			authorization: `Bearer ${token}`
+		},
+		body: formData
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err;
+			console.error(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getBaseModels = async (token: string = '') => {
 	let error = null;
 

+ 19 - 33
src/lib/components/admin/Settings/Models.svelte

@@ -12,7 +12,8 @@
 		deleteAllModels,
 		getBaseModels,
 		toggleModelById,
-		updateModelById
+	updateModelById,
+	importModels
 	} from '$lib/apis/models';
 	import { copyToClipboard } from '$lib/utils';
 	import { page } from '$app/stores';
@@ -40,6 +41,7 @@
 
 	let shiftKey = false;
 
+let modelsImportInProgress = false;
 	let importFiles;
 	let modelsImportInputElement: HTMLInputElement;
 
@@ -463,48 +465,32 @@
 						type="file"
 						accept=".json"
 						hidden
-						on:change={() => {
-							console.log(importFiles);
-
-							let reader = new FileReader();
-							reader.onload = async (event) => {
-								let savedModels = JSON.parse(event.target.result);
-								console.log(savedModels);
-
-								for (const model of savedModels) {
-									if (Object.keys(model).includes('base_model_id')) {
-										if (model.base_model_id === null) {
-											upsertModelHandler(model);
-										}
-									} else {
-										if (model?.info ?? false) {
-											if (model.info.base_model_id === null) {
-												upsertModelHandler(model.info);
-											}
-										}
-									}
+						on:change={async () => {
+							if (importFiles.length > 0) {
+								modelsImportInProgress = true;
+								const res = await importModels(localStorage.token, importFiles[0]);
+								modelsImportInProgress = false;
+
+								if (res) {
+									toast.success($i18n.t('Models imported successfully'));
+									await init();
+								} else {
+									toast.error($i18n.t('Failed to import models'));
 								}
-
-								await _models.set(
-									await getModels(
-										localStorage.token,
-										$config?.features?.enable_direct_connections &&
-											($settings?.directConnections ?? null)
-									)
-								);
-								init();
-							};
-
-							reader.readAsText(importFiles[0]);
+							}
 						}}
 					/>
 
 					<button
 						class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
+						disabled={modelsImportInProgress}
 						on:click={() => {
 							modelsImportInputElement.click();
 						}}
 					>
+						{#if modelsImportInProgress}
+							<Spinner className="size-3" />
+						{/if}
 						<div class=" self-center mr-2 font-medium line-clamp-1">
 							{$i18n.t('Import Presets')}
 						</div>