Răsfoiți Sursa

feat: move JSON model import to backend

This moves the JSON model import functionality to the backend. Instead of the frontend parsing the JSON file and sending multiple requests, it now uploads the file to a new endpoint (/api/v1/models/import), which processes the file and imports the models. This improves efficiency and provides better user feedback.
silentoplayz 1 săptămână în urmă
părinte
comite
231d182c35

+ 55 - 1
backend/open_webui/routers/models.py

@@ -1,6 +1,9 @@
 from typing import Optional
 from typing import Optional
 import io
 import io
 import base64
 import base64
+import json
+import asyncio
+import logging
 
 
 from open_webui.models.models import (
 from open_webui.models.models import (
     ModelForm,
     ModelForm,
@@ -12,7 +15,16 @@ from open_webui.models.models import (
 
 
 from pydantic import BaseModel
 from pydantic import BaseModel
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
-from fastapi import APIRouter, Depends, HTTPException, Request, status, Response
+from fastapi import (
+    APIRouter,
+    Depends,
+    HTTPException,
+    Request,
+    status,
+    Response,
+    UploadFile,
+    File,
+)
 from fastapi.responses import FileResponse, StreamingResponse
 from fastapi.responses import FileResponse, StreamingResponse
 
 
 
 
@@ -20,6 +32,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
 from open_webui.utils.access_control import has_access, has_permission
 from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
 from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
 
 
+log = logging.getLogger(__name__)
+
 router = APIRouter()
 router = APIRouter()
 
 
 
 
@@ -93,6 +107,46 @@ async def export_models(user=Depends(get_admin_user)):
     return Models.get_models()
     return Models.get_models()
 
 
 
 
+############################
+# ImportModels
+############################
+
+
+@router.post("/import", response_model=bool)
+async def import_models(
+    user: str = Depends(get_admin_user), file: UploadFile = File(...)
+):
+    try:
+        data = json.loads(await file.read())
+        if isinstance(data, list):
+            for model_data in data:
+                # Here, you can add logic to validate model_data if needed
+                model_id = model_data.get("id")
+                if model_id:
+                    existing_model = Models.get_model_by_id(model_id)
+                    if existing_model:
+                        # Update existing model
+                        model_data["meta"] = model_data.get("meta", {})
+                        model_data["params"] = model_data.get("params", {})
+
+                        updated_model = ModelForm(
+                            **{**existing_model.model_dump(), **model_data}
+                        )
+                        Models.update_model_by_id(model_id, updated_model)
+                    else:
+                        # Insert new model
+                        model_data["meta"] = model_data.get("meta", {})
+                        model_data["params"] = model_data.get("params", {})
+                        new_model = ModelForm(**model_data)
+                        Models.insert_new_model(user_id=user.id, form_data=new_model)
+            return True
+        else:
+            raise HTTPException(status_code=400, detail="Invalid JSON format")
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 ############################
 ############################
 # SyncModels
 # SyncModels
 ############################
 ############################

+ 30 - 0
src/lib/apis/models/index.ts

@@ -31,6 +31,36 @@ export const getModels = async (token: string = '') => {
 	return res;
 	return res;
 };
 };
 
 
+export const importModels = async (token: string, file: File) => {
+	let error = null;
+
+	const formData = new FormData();
+	formData.append('file', file);
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/models/import`, {
+		method: 'POST',
+		headers: {
+			authorization: `Bearer ${token}`
+		},
+		body: formData
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err;
+			console.error(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getBaseModels = async (token: string = '') => {
 export const getBaseModels = async (token: string = '') => {
 	let error = null;
 	let error = null;
 
 

+ 19 - 33
src/lib/components/admin/Settings/Models.svelte

@@ -12,7 +12,8 @@
 		deleteAllModels,
 		deleteAllModels,
 		getBaseModels,
 		getBaseModels,
 		toggleModelById,
 		toggleModelById,
-		updateModelById
+	updateModelById,
+	importModels
 	} from '$lib/apis/models';
 	} from '$lib/apis/models';
 	import { copyToClipboard } from '$lib/utils';
 	import { copyToClipboard } from '$lib/utils';
 	import { page } from '$app/stores';
 	import { page } from '$app/stores';
@@ -40,6 +41,7 @@
 
 
 	let shiftKey = false;
 	let shiftKey = false;
 
 
+let modelsImportInProgress = false;
 	let importFiles;
 	let importFiles;
 	let modelsImportInputElement: HTMLInputElement;
 	let modelsImportInputElement: HTMLInputElement;
 
 
@@ -463,48 +465,32 @@
 						type="file"
 						type="file"
 						accept=".json"
 						accept=".json"
 						hidden
 						hidden
-						on:change={() => {
-							console.log(importFiles);
-
-							let reader = new FileReader();
-							reader.onload = async (event) => {
-								let savedModels = JSON.parse(event.target.result);
-								console.log(savedModels);
-
-								for (const model of savedModels) {
-									if (Object.keys(model).includes('base_model_id')) {
-										if (model.base_model_id === null) {
-											upsertModelHandler(model);
-										}
-									} else {
-										if (model?.info ?? false) {
-											if (model.info.base_model_id === null) {
-												upsertModelHandler(model.info);
-											}
-										}
-									}
+						on:change={async () => {
+							if (importFiles.length > 0) {
+								modelsImportInProgress = true;
+								const res = await importModels(localStorage.token, importFiles[0]);
+								modelsImportInProgress = false;
+
+								if (res) {
+									toast.success($i18n.t('Models imported successfully'));
+									await init();
+								} else {
+									toast.error($i18n.t('Failed to import models'));
 								}
 								}
-
-								await _models.set(
-									await getModels(
-										localStorage.token,
-										$config?.features?.enable_direct_connections &&
-											($settings?.directConnections ?? null)
-									)
-								);
-								init();
-							};
-
-							reader.readAsText(importFiles[0]);
+							}
 						}}
 						}}
 					/>
 					/>
 
 
 					<button
 					<button
 						class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
 						class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
+						disabled={modelsImportInProgress}
 						on:click={() => {
 						on:click={() => {
 							modelsImportInputElement.click();
 							modelsImportInputElement.click();
 						}}
 						}}
 					>
 					>
+						{#if modelsImportInProgress}
+							<Spinner className="size-3" />
+						{/if}
 						<div class=" self-center mr-2 font-medium line-clamp-1">
 						<div class=" self-center mr-2 font-medium line-clamp-1">
 							{$i18n.t('Import Presets')}
 							{$i18n.t('Import Presets')}
 						</div>
 						</div>