Browse Source

refac: tool name collision handling

Timothy Jaeryang Baek 1 month ago
parent
commit
70d0477418
2 changed files with 104 additions and 82 deletions
  1. 26 15
      backend/open_webui/utils/tools.py
  2. 78 67
      src/lib/components/chat/MessageInput/InputMenu.svelte

+ 26 - 15
backend/open_webui/utils/tools.py

@@ -5,6 +5,7 @@ import inspect
 import aiohttp
 import asyncio
 import yaml
+import json
 
 from pydantic import BaseModel
 from pydantic.fields import FieldInfo
@@ -85,7 +86,9 @@ async def get_tools(
                         tool_server_data = server
                         break
 
-                assert tool_server_data is not None
+                if tool_server_data is None:
+                    log.warning(f"Tool server data not found for {server_id}")
+                    continue
 
                 tool_server_idx = tool_server_data.get("idx", 0)
                 tool_server_connection = (
@@ -131,14 +134,15 @@ async def get_tools(
                         "spec": spec,
                     }
 
-                    # TODO: if collision, prepend toolkit name
-                    if function_name in tools_dict:
+                    # Handle function name collisions
+                    while function_name in tools_dict:
                         log.warning(
                             f"Tool {function_name} already exists in another tools!"
                         )
-                        log.warning(f"Discarding {tool_id}.{function_name}")
-                    else:
-                        tools_dict[function_name] = tool_dict
+                        # Prepend server ID to function name
+                        function_name = f"{server_id}_{function_name}"
+
+                    tools_dict[function_name] = tool_dict
             else:
                 continue
         else:
@@ -198,14 +202,15 @@ async def get_tools(
                     },
                 }
 
-                # TODO: if collision, prepend toolkit name
-                if function_name in tools_dict:
+                # Handle function name collisions
+                while function_name in tools_dict:
                     log.warning(
                         f"Tool {function_name} already exists in another tools!"
                     )
-                    log.warning(f"Discarding {tool_id}.{function_name}")
-                else:
-                    tools_dict[function_name] = tool_dict
+                    # Prepend tool ID to function name
+                    function_name = f"{tool_id}_{function_name}"
+
+                tools_dict[function_name] = tool_dict
 
     return tools_dict
 
@@ -453,8 +458,8 @@ async def set_tool_servers(request: Request):
     )
 
     if request.app.state.redis is not None:
-        await request.app.state.redis.hmset(
-            "tool_servers", request.app.state.TOOL_SERVERS
+        await request.app.state.redis.set(
+            "tool_servers", json.dumps(request.app.state.TOOL_SERVERS)
         )
 
     return request.app.state.TOOL_SERVERS
@@ -463,7 +468,10 @@ async def set_tool_servers(request: Request):
 async def get_tool_servers(request: Request):
     tool_servers = []
     if request.app.state.redis is not None:
-        tool_servers = await request.app.state.redis.hgetall("tool_servers")
+        try:
+            tool_servers = json.loads(await request.app.state.redis.get("tool_servers"))
+        except Exception as e:
+            log.error(f"Error fetching tool_servers from Redis: {e}")
 
     if not tool_servers:
         await set_tool_servers(request)
@@ -536,7 +544,10 @@ async def get_tool_servers_data(
             elif auth_type == "session":
                 token = session_token
 
-            id = info.get("id", idx)
+            id = info.get("id")
+            if not id:
+                id = str(idx)
+
             server_entries.append((id, idx, server, full_url, info, token))
 
     # Create async tasks to fetch data

+ 78 - 67
src/lib/components/chat/MessageInput/InputMenu.svelte

@@ -17,6 +17,7 @@
 	import CameraSolid from '$lib/components/icons/CameraSolid.svelte';
 	import PhotoSolid from '$lib/components/icons/PhotoSolid.svelte';
 	import CommandLineSolid from '$lib/components/icons/CommandLineSolid.svelte';
+	import Spinner from '$lib/components/common/Spinner.svelte';
 
 	const i18n = getContext('i18n');
 
@@ -34,7 +35,7 @@
 
 	export let onClose: Function;
 
-	let tools = {};
+	let tools = null;
 	let show = false;
 	let showAllTools = false;
 
@@ -49,15 +50,17 @@
 
 	const init = async () => {
 		await _tools.set(await getTools(localStorage.token));
-
-		tools = $_tools.reduce((a, tool, i, arr) => {
-			a[tool.id] = {
-				name: tool.name,
-				description: tool.meta.description,
-				enabled: selectedToolIds.includes(tool.id)
-			};
-			return a;
-		}, {});
+		if ($_tools) {
+			tools = $_tools.reduce((a, tool, i, arr) => {
+				a[tool.id] = {
+					name: tool.name,
+					description: tool.meta.description,
+					enabled: selectedToolIds.includes(tool.id)
+				};
+				return a;
+			}, {});
+			selectedToolIds = selectedToolIds.filter((id) => $_tools?.some((tool) => tool.id === id));
+		}
 	};
 
 	const detectMobile = () => {
@@ -105,69 +108,77 @@
 			align="start"
 			transition={flyAndScale}
 		>
-			{#if Object.keys(tools).length > 0}
-				<div class="{showAllTools ? '' : 'max-h-28'} overflow-y-auto scrollbar-thin">
-					{#each Object.keys(tools) as toolId}
+			{#if tools}
+				{#if Object.keys(tools).length > 0}
+					<div class="{showAllTools ? '' : 'max-h-28'} overflow-y-auto scrollbar-thin">
+						{#each Object.keys(tools) as toolId}
+							<button
+								class="flex w-full justify-between gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer rounded-xl"
+								on:click={() => {
+									tools[toolId].enabled = !tools[toolId].enabled;
+								}}
+							>
+								<div class="flex-1 truncate">
+									<Tooltip
+										content={tools[toolId]?.description ?? ''}
+										placement="top-start"
+										className="flex flex-1 gap-2 items-center"
+									>
+										<div class="shrink-0">
+											<WrenchSolid />
+										</div>
+
+										<div class=" truncate">{tools[toolId].name}</div>
+									</Tooltip>
+								</div>
+
+								<div class=" shrink-0">
+									<Switch
+										state={tools[toolId].enabled}
+										on:change={async (e) => {
+											const state = e.detail;
+											await tick();
+											if (state) {
+												selectedToolIds = [...selectedToolIds, toolId];
+											} else {
+												selectedToolIds = selectedToolIds.filter((id) => id !== toolId);
+											}
+										}}
+									/>
+								</div>
+							</button>
+						{/each}
+					</div>
+					{#if Object.keys(tools).length > 3}
 						<button
-							class="flex w-full justify-between gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer rounded-xl"
+							class="flex w-full justify-center items-center text-sm font-medium cursor-pointer rounded-lg hover:bg-gray-50 dark:hover:bg-gray-800"
 							on:click={() => {
-								tools[toolId].enabled = !tools[toolId].enabled;
+								showAllTools = !showAllTools;
 							}}
+							title={showAllTools ? $i18n.t('Show Less') : $i18n.t('Show All')}
 						>
-							<div class="flex-1 truncate">
-								<Tooltip
-									content={tools[toolId]?.description ?? ''}
-									placement="top-start"
-									className="flex flex-1 gap-2 items-center"
-								>
-									<div class="shrink-0">
-										<WrenchSolid />
-									</div>
-
-									<div class=" truncate">{tools[toolId].name}</div>
-								</Tooltip>
-							</div>
-
-							<div class=" shrink-0">
-								<Switch
-									state={tools[toolId].enabled}
-									on:change={async (e) => {
-										const state = e.detail;
-										await tick();
-										if (state) {
-											selectedToolIds = [...selectedToolIds, toolId];
-										} else {
-											selectedToolIds = selectedToolIds.filter((id) => id !== toolId);
-										}
-									}}
-								/>
-							</div>
+							<svg
+								xmlns="http://www.w3.org/2000/svg"
+								fill="none"
+								viewBox="0 0 24 24"
+								stroke-width="2.5"
+								stroke="currentColor"
+								class="size-3 transition-transform duration-200 {showAllTools
+									? 'rotate-180'
+									: ''} text-gray-300 dark:text-gray-600"
+							>
+								<path stroke-linecap="round" stroke-linejoin="round" d="m19.5 8.25-7.5 7.5-7.5-7.5"
+								></path>
+							</svg>
 						</button>
-					{/each}
-				</div>
-				{#if Object.keys(tools).length > 3}
-					<button
-						class="flex w-full justify-center items-center text-sm font-medium cursor-pointer rounded-lg hover:bg-gray-50 dark:hover:bg-gray-800"
-						on:click={() => {
-							showAllTools = !showAllTools;
-						}}
-						title={showAllTools ? $i18n.t('Show Less') : $i18n.t('Show All')}
-					>
-						<svg
-							xmlns="http://www.w3.org/2000/svg"
-							fill="none"
-							viewBox="0 0 24 24"
-							stroke-width="2.5"
-							stroke="currentColor"
-							class="size-3 transition-transform duration-200 {showAllTools
-								? 'rotate-180'
-								: ''} text-gray-300 dark:text-gray-600"
-						>
-							<path stroke-linecap="round" stroke-linejoin="round" d="m19.5 8.25-7.5 7.5-7.5-7.5"
-							></path>
-						</svg>
-					</button>
+					{/if}
+					<hr class="border-black/5 dark:border-white/5 my-1" />
 				{/if}
+			{:else}
+				<div class="py-4">
+					<Spinner />
+				</div>
+
 				<hr class="border-black/5 dark:border-white/5 my-1" />
 			{/if}