Browse Source

Merge pull request #17871 from silentoplayz/backend-json-model-import

feat: move JSON model import to backend for massive speedup
Tim Jaeryang Baek 1 week ago
parent
commit
1c418a7f83

+ 57 - 1
backend/open_webui/routers/models.py

@@ -1,6 +1,9 @@
 from typing import Optional
 import io
 import base64
+import json
+import asyncio
+import logging
 
 from open_webui.models.models import (
     ModelForm,
@@ -12,7 +15,14 @@ from open_webui.models.models import (
 
 from pydantic import BaseModel
 from open_webui.constants import ERROR_MESSAGES
-from fastapi import APIRouter, Depends, HTTPException, Request, status, Response
+from fastapi import (
+    APIRouter,
+    Depends,
+    HTTPException,
+    Request,
+    status,
+    Response,
+)
 from fastapi.responses import FileResponse, StreamingResponse
 
 
@@ -20,6 +30,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
 from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
 
+log = logging.getLogger(__name__)
+
 router = APIRouter()
 
 
@@ -93,6 +105,50 @@ async def export_models(user=Depends(get_admin_user)):
     return Models.get_models()
 
 
+############################
+# ImportModels
+############################
+
+
+class ModelsImportForm(BaseModel):
+    models: list[dict]
+
+
+@router.post("/import", response_model=bool)
+async def import_models(
+    user: str = Depends(get_admin_user), form_data: ModelsImportForm = (...)
+):
+    try:
+        data = form_data.models
+        if isinstance(data, list):
+            for model_data in data:
+                # Here, you can add logic to validate model_data if needed
+                model_id = model_data.get("id")
+                if model_id:
+                    existing_model = Models.get_model_by_id(model_id)
+                    if existing_model:
+                        # Update existing model
+                        model_data["meta"] = model_data.get("meta", {})
+                        model_data["params"] = model_data.get("params", {})
+
+                        updated_model = ModelForm(
+                            **{**existing_model.model_dump(), **model_data}
+                        )
+                        Models.update_model_by_id(model_id, updated_model)
+                    else:
+                        # Insert new model
+                        model_data["meta"] = model_data.get("meta", {})
+                        model_data["params"] = model_data.get("params", {})
+                        new_model = ModelForm(**model_data)
+                        Models.insert_new_model(user_id=user.id, form_data=new_model)
+            return True
+        else:
+            raise HTTPException(status_code=400, detail="Invalid JSON format")
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 ############################
 # SyncModels
 ############################

+ 28 - 0
src/lib/apis/models/index.ts

@@ -31,6 +31,34 @@ export const getModels = async (token: string = '') => {
 	return res;
 };
 
+export const importModels = async (token: string, models: object[]) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/models/import`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({ models: models })
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err;
+			console.error(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getBaseModels = async (token: string = '') => {
 	let error = null;
 

+ 27 - 31
src/lib/components/admin/Settings/Models.svelte

@@ -12,7 +12,8 @@
 		deleteAllModels,
 		getBaseModels,
 		toggleModelById,
-		updateModelById
+	updateModelById,
+	importModels
 	} from '$lib/apis/models';
 	import { copyToClipboard } from '$lib/utils';
 	import { page } from '$app/stores';
@@ -40,6 +41,7 @@
 
 	let shiftKey = false;
 
+let modelsImportInProgress = false;
 	let importFiles;
 	let modelsImportInputElement: HTMLInputElement;
 
@@ -464,47 +466,41 @@
 						accept=".json"
 						hidden
 						on:change={() => {
-							console.log(importFiles);
-
-							let reader = new FileReader();
-							reader.onload = async (event) => {
-								let savedModels = JSON.parse(event.target.result);
-								console.log(savedModels);
-
-								for (const model of savedModels) {
-									if (Object.keys(model).includes('base_model_id')) {
-										if (model.base_model_id === null) {
-											upsertModelHandler(model);
-										}
-									} else {
-										if (model?.info ?? false) {
-											if (model.info.base_model_id === null) {
-												upsertModelHandler(model.info);
-											}
+							if (importFiles.length > 0) {
+								const reader = new FileReader();
+								reader.onload = async (event) => {
+									try {
+										const models = JSON.parse(String(event.target.result));
+										modelsImportInProgress = true;
+										const res = await importModels(localStorage.token, models);
+										modelsImportInProgress = false;
+
+										if (res) {
+											toast.success($i18n.t('Models imported successfully'));
+											await init();
+										} else {
+											toast.error($i18n.t('Failed to import models'));
 										}
+									} catch (e) {
+										toast.error($i18n.t('Invalid JSON file'));
+										console.error(e);
 									}
-								}
-
-								await _models.set(
-									await getModels(
-										localStorage.token,
-										$config?.features?.enable_direct_connections &&
-											($settings?.directConnections ?? null)
-									)
-								);
-								init();
-							};
-
-							reader.readAsText(importFiles[0]);
+								};
+								reader.readAsText(importFiles[0]);
+							}
 						}}
 					/>
 
 					<button
 						class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
+						disabled={modelsImportInProgress}
 						on:click={() => {
 							modelsImportInputElement.click();
 						}}
 					>
+						{#if modelsImportInProgress}
+							<Spinner className="size-3" />
+						{/if}
 						<div class=" self-center mr-2 font-medium line-clamp-1">
 							{$i18n.t('Import Presets')}
 						</div>