Quellcode durchsuchen

refac: external tools server support

Timothy Jaeryang Baek vor 1 Monat
Ursprung
Commit
d1bc2cfa2f

+ 1 - 0
backend/open_webui/main.py

@@ -1052,6 +1052,7 @@ async def chat_completion(
             "message_id": form_data.pop("id", None),
             "session_id": form_data.pop("session_id", None),
             "tool_ids": form_data.get("tool_ids", None),
+            "tool_servers": form_data.pop("tool_servers", None),
             "files": form_data.get("files", None),
             "features": form_data.get("features", None),
             "variables": form_data.get("variables", None),

+ 27 - 8
backend/open_webui/utils/middleware.py

@@ -213,8 +213,9 @@ async def chat_completion_tools_handler(
                                 "type": "execute:tool",
                                 "data": {
                                     "id": str(uuid4()),
-                                    "tool": tool,
+                                    "name": tool_function_name,
                                     "params": tool_function_params,
+                                    "tool": tool,
                                     "server": tool.get("server", {}),
                                     "session_id": metadata.get("session_id", None),
                                 },
@@ -224,17 +225,30 @@ async def chat_completion_tools_handler(
                 except Exception as e:
                     tool_output = str(e)
 
+                if isinstance(tool_output, dict):
+                    tool_output = json.dumps(tool_output, indent=4)
+
                 if isinstance(tool_output, str):
-                    if tools[tool_function_name]["citation"]:
+                    tool_id = tools[tool_function_name].get("toolkit_id", "")
+                    if tools[tool_function_name].get("citation", False):
+
                         sources.append(
                             {
                                 "source": {
-                                    "name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
+                                    "name": (
+                                        f"TOOL:" + f"{tool_id}/{tool_function_name}"
+                                        if tool_id
+                                        else f"{tool_function_name}"
+                                    ),
                                 },
                                 "document": [tool_output],
                                 "metadata": [
                                     {
-                                        "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
+                                        "source": (
+                                            f"TOOL:" + f"{tool_id}/{tool_function_name}"
+                                            if tool_id
+                                            else f"{tool_function_name}"
+                                        )
                                     }
                                 ],
                             }
@@ -246,13 +260,17 @@ async def chat_completion_tools_handler(
                                 "document": [tool_output],
                                 "metadata": [
                                     {
-                                        "source": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"
+                                        "source": (
+                                            f"TOOL:" + f"{tool_id}/{tool_function_name}"
+                                            if tool_id
+                                            else f"{tool_function_name}"
+                                        )
                                     }
                                 ],
                             }
                         )
 
-                    if tools[tool_function_name]["file_handler"]:
+                    if tools[tool_function_name].get("file_handler", False):
                         skip_files = True
 
             # check if "tool_calls" in result
@@ -788,7 +806,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     # Server side tools
     tool_ids = metadata.get("tool_ids", None)
     # Client side tools
-    tool_servers = form_data.get("tool_servers", None)
+    tool_servers = metadata.get("tool_servers", None)
 
     log.debug(f"{tool_ids=}")
     log.debug(f"{tool_servers=}")
@@ -1824,8 +1842,9 @@ async def process_chat_response(
                                             "type": "execute:tool",
                                             "data": {
                                                 "id": str(uuid4()),
-                                                "tool": tool,
+                                                "name": tool_name,
                                                 "params": tool_function_params,
+                                                "tool": tool,
                                                 "server": tool.get("server", {}),
                                                 "session_id": metadata.get(
                                                     "session_id", None

+ 133 - 0
src/lib/apis/index.ts

@@ -1,4 +1,5 @@
 import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
+import { convertOpenApiToToolPayload } from '$lib/utils';
 import { getOpenAIModelsDirect } from './openai';
 
 export const getModels = async (
@@ -256,6 +257,138 @@ export const stopTask = async (token: string, id: string) => {
 	return res;
 };
 
+export const getToolServerData = async (token: string, url: string) => {
+	let error = null;
+
+	const res = await fetch(`${url}/openapi.json`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = err;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	const data = {
+		openapi: res,
+		info: res.info,
+		specs: convertOpenApiToToolPayload(res)
+	};
+
+	console.log(data);
+	return data;
+};
+
+export const getToolServersData = async (servers: object[]) => {
+	return await Promise.all(
+		servers
+			.filter(async (server) => server?.config?.enable)
+			.map(async (server) => {
+				const data = await getToolServerData(server?.key, server?.url).catch((err) => {
+					console.error(err);
+					return null;
+				});
+
+				if (data) {
+					const { openapi, info, specs } = data;
+					return {
+						url: server?.url,
+						openapi: openapi,
+						info: info,
+						specs: specs
+					};
+				}
+			})
+	);
+};
+
+export const executeToolServer = async (
+	token: string,
+	url: string,
+	name: string,
+	params: object,
+	serverData: { openapi: any; info: any; specs: any }
+) => {
+	let error = null;
+
+	try {
+		// Find the matching operationId in the OpenAPI specification
+		const matchingRoute = Object.entries(serverData.openapi.paths).find(([path, methods]) => {
+			return Object.entries(methods).some(
+				([method, operation]: any) => operation.operationId === name
+			);
+		});
+
+		if (!matchingRoute) {
+			throw new Error(`No matching route found for operationId: ${name}`);
+		}
+
+		const [route, methods] = matchingRoute;
+		const methodEntry = Object.entries(methods).find(
+			([method, operation]: any) => operation.operationId === name
+		);
+
+		if (!methodEntry) {
+			throw new Error(`No matching method found for operationId: ${name}`);
+		}
+
+		const [httpMethod, operation]: [string, any] = methodEntry;
+
+		// Replace path parameters in the URL
+		let finalUrl = `${url}${route}`;
+		if (operation.parameters) {
+			Object.entries(params).forEach(([key, value]) => {
+				finalUrl = finalUrl.replace(`{${key}}`, encodeURIComponent(value as string));
+			});
+		}
+
+		// Headers and request options
+		const headers = {
+			...(token && { authorization: `Bearer ${token}` }),
+			'Content-Type': 'application/json'
+		};
+
+		let requestOptions: RequestInit = {
+			method: httpMethod.toUpperCase(),
+			headers
+		};
+
+		// Handle request body for POST, PUT, PATCH
+		if (['post', 'put', 'patch'].includes(httpMethod.toLowerCase()) && operation.requestBody) {
+			requestOptions.body = JSON.stringify(params);
+		}
+
+		// Execute the request
+		const res = await fetch(finalUrl, requestOptions);
+		if (!res.ok) {
+			throw new Error(`HTTP error! Status: ${res.status}`);
+		}
+
+		return await res.json();
+	} catch (err: any) {
+		error = err.message;
+		console.error('API Request Error:', error);
+		return { error };
+	}
+};
+
 export const getTaskConfig = async (token: string = '') => {
 	let error = null;
 

+ 5 - 7
src/lib/components/chat/Chat.svelte

@@ -35,7 +35,8 @@
 		showOverview,
 		chatTitle,
 		showArtifacts,
-		tools
+		tools,
+		toolServers
 	} from '$lib/stores';
 	import {
 		convertMessagesToHistory,
@@ -120,8 +121,6 @@
 	let webSearchEnabled = false;
 	let codeInterpreterEnabled = false;
 
-	let toolServers = [];
-
 	let chat = null;
 	let tags = [];
 
@@ -194,8 +193,6 @@
 		setToolIds();
 	}
 
-	$: toolServers = ($settings?.toolServers ?? []).filter((server) => server?.config?.enable);
-
 	const setToolIds = async () => {
 		if (!$tools) {
 			tools.set(await getTools(localStorage.token));
@@ -1570,6 +1567,7 @@
 
 				files: (files?.length ?? 0) > 0 ? files : undefined,
 				tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
+				tool_servers: $toolServers,
 
 				features: {
 					image_generation:
@@ -2038,7 +2036,7 @@
 								bind:codeInterpreterEnabled
 								bind:webSearchEnabled
 								bind:atSelectedModel
-								{toolServers}
+								toolServers={$toolServers}
 								transparentBackground={$settings?.backgroundImageUrl ?? false}
 								{stopResponse}
 								{createMessagePair}
@@ -2092,7 +2090,7 @@
 								bind:webSearchEnabled
 								bind:atSelectedModel
 								transparentBackground={$settings?.backgroundImageUrl ?? false}
-								{toolServers}
+								toolServers={$toolServers}
 								{stopResponse}
 								{createMessagePair}
 								on:upload={async (e) => {

+ 4 - 2
src/lib/components/chat/Settings/Tools.svelte

@@ -1,12 +1,12 @@
 <script lang="ts">
 	import { toast } from 'svelte-sonner';
 	import { createEventDispatcher, onMount, getContext, tick } from 'svelte';
-	import { getModels as _getModels } from '$lib/apis';
+	import { getModels as _getModels, getToolServersData } from '$lib/apis';
 
 	const dispatch = createEventDispatcher();
 	const i18n = getContext('i18n');
 
-	import { models, settings, user } from '$lib/stores';
+	import { models, settings, toolServers, user } from '$lib/stores';
 
 	import Switch from '$lib/components/common/Switch.svelte';
 	import Spinner from '$lib/components/common/Spinner.svelte';
@@ -30,6 +30,8 @@
 		await saveSettings({
 			toolServers: servers
 		});
+
+		toolServers.set(await getToolServersData($settings?.toolServers ?? []));
 	};
 
 	onMount(async () => {

+ 2 - 0
src/lib/stores/index.ts

@@ -58,6 +58,8 @@ export const knowledge: Writable<null | Document[]> = writable(null);
 export const tools = writable(null);
 export const functions = writable(null);
 
+export const toolServers = writable([]);
+
 export const banners: Writable<Banner[]> = writable([]);
 
 export const settings: Writable<Settings> = writable({});

+ 56 - 0
src/lib/utils/index.ts

@@ -1070,3 +1070,59 @@ export const getLineCount = (text) => {
 	console.log(typeof text);
 	return text ? text.split('\n').length : 0;
 };
+
+export const convertOpenApiToToolPayload = (openApiSpec) => {
+	const toolPayload = [];
+
+	for (const [path, methods] of Object.entries(openApiSpec.paths)) {
+		for (const [method, operation] of Object.entries(methods)) {
+			const tool = {
+				type: 'function',
+				name: operation.operationId,
+				description: operation.summary || 'No description available.',
+				parameters: {
+					type: 'object',
+					properties: {},
+					required: []
+				}
+			};
+
+			// Extract path or query parameters
+			if (operation.parameters) {
+				operation.parameters.forEach((param) => {
+					tool.parameters.properties[param.name] = {
+						type: param.schema.type,
+						description: param.schema.description || ''
+					};
+
+					if (param.required) {
+						tool.parameters.required.push(param.name);
+					}
+				});
+			}
+
+			// Extract parameters from requestBody if applicable
+			if (operation.requestBody) {
+				const ref = operation.requestBody.content['application/json'].schema['$ref'];
+				if (ref) {
+					const schemaName = ref.split('/').pop();
+					const schemaDef = openApiSpec.components.schemas[schemaName];
+
+					if (schemaDef && schemaDef.properties) {
+						for (const [prop, details] of Object.entries(schemaDef.properties)) {
+							tool.parameters.properties[prop] = {
+								type: details.type,
+								description: details.description || ''
+							};
+						}
+						tool.parameters.required = schemaDef.required || [];
+					}
+				}
+			}
+
+			toolPayload.push(tool);
+		}
+	}
+
+	return toolPayload;
+};

+ 6 - 2
src/routes/(app)/+layout.svelte

@@ -12,7 +12,7 @@
 
 	import { getKnowledgeBases } from '$lib/apis/knowledge';
 	import { getFunctions } from '$lib/apis/functions';
-	import { getModels, getVersionUpdates } from '$lib/apis';
+	import { getModels, getToolServersData, getVersionUpdates } from '$lib/apis';
 	import { getAllTags } from '$lib/apis/chats';
 	import { getPrompts } from '$lib/apis/prompts';
 	import { getTools } from '$lib/apis/tools';
@@ -35,7 +35,8 @@
 		banners,
 		showSettings,
 		showChangelog,
-		temporaryChatEnabled
+		temporaryChatEnabled,
+		toolServers
 	} from '$lib/stores';
 
 	import Sidebar from '$lib/components/layout/Sidebar.svelte';
@@ -43,6 +44,7 @@
 	import ChangelogModal from '$lib/components/ChangelogModal.svelte';
 	import AccountPending from '$lib/components/layout/Overlay/AccountPending.svelte';
 	import UpdateInfoToast from '$lib/components/layout/UpdateInfoToast.svelte';
+	import { get } from 'svelte/store';
 
 	const i18n = getContext('i18n');
 
@@ -99,8 +101,10 @@
 					$config?.features?.enable_direct_connections && ($settings?.directConnections ?? null)
 				)
 			);
+
 			banners.set(await getBanners(localStorage.token));
 			tools.set(await getTools(localStorage.token));
+			toolServers.set(await getToolServersData($settings?.toolServers ?? []));
 
 			document.addEventListener('keydown', async function (event) {
 				const isCtrlPressed = event.ctrlKey || event.metaKey; // metaKey is for Cmd key on Mac

+ 31 - 12
src/routes/+layout.svelte

@@ -31,7 +31,7 @@
 	import { page } from '$app/stores';
 	import { Toaster, toast } from 'svelte-sonner';
 
-	import { getBackendConfig } from '$lib/apis';
+	import { executeToolServer, getBackendConfig } from '$lib/apis';
 	import { getSessionUser } from '$lib/apis/auths';
 
 	import '../tailwind.css';
@@ -205,17 +205,36 @@
 
 	const executeTool = async (data, cb) => {
 		console.log(data);
-		// TODO: MCP (SSE) support
-		// TODO: API Server support
-
-		if (cb) {
-			cb(
-				JSON.parse(
-					JSON.stringify({
-						result: null
-					})
-				)
-			);
+
+		const toolServer = $settings?.toolServers?.find((server) => server.url === data.server?.url);
+
+		if (toolServer) {
+			const res = await executeToolServer(
+				toolServer.key,
+				toolServer.url,
+				data?.name,
+				data?.params,
+				toolServer
+			).catch((error) => {
+				console.error('executeToolServer', error);
+				return {
+					error: error
+				};
+			});
+
+			if (cb) {
+				cb(JSON.parse(JSON.stringify(res)));
+			}
+		} else {
+			if (cb) {
+				cb(
+					JSON.parse(
+						JSON.stringify({
+							error: 'Tool Server Not Found'
+						})
+					)
+				);
+			}
 		}
 	};