Browse Source

refac: tool servers

Timothy Jaeryang Baek 1 month ago
parent
commit
9747a0e1f1

+ 35 - 4
backend/open_webui/routers/tools.py

@@ -1,6 +1,7 @@
 import logging
 import logging
 from pathlib import Path
 from pathlib import Path
 from typing import Optional
 from typing import Optional
+import time
 
 
 from open_webui.models.tools import (
 from open_webui.models.tools import (
     ToolForm,
     ToolForm,
@@ -43,10 +44,40 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
             request.app.state.config.TOOL_SERVER_CONNECTIONS
             request.app.state.config.TOOL_SERVER_CONNECTIONS
         )
         )
 
 
-    if user.role == "admin":
-        tools = Tools.get_tools()
-    else:
-        tools = Tools.get_tools_by_user_id(user.id, "read")
+    tools = Tools.get_tools()
+    for idx, server in enumerate(request.app.state.TOOL_SERVERS):
+        tools.append(
+            ToolUserResponse(
+                **{
+                    "id": f"server:{server['idx']}",
+                    "user_id": f"server:{server['idx']}",
+                    "name": server["openapi"]
+                    .get("info", {})
+                    .get("title", "Tool Server"),
+                    "meta": {
+                        "description": server["openapi"]
+                        .get("info", {})
+                        .get("description", ""),
+                    },
+                    "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
+                        idx
+                    ]
+                    .get("config", {})
+                    .get("access_control", None),
+                    "updated_at": int(time.time()),
+                    "created_at": int(time.time()),
+                }
+            )
+        )
+
+    if user.role != "admin":
+        tools = [
+            tool
+            for tool in tools
+            if tool.user_id == user.id
+            or has_access(user.id, "read", tool.access_control)
+        ]
+
     return tools
     return tools
 
 
 
 

+ 35 - 28
backend/open_webui/utils/tools.py

@@ -5,7 +5,7 @@ import inspect
 import aiohttp
 import aiohttp
 import asyncio
 import asyncio
 
 
-from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union
+from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union, Optional
 from functools import update_wrapper, partial
 from functools import update_wrapper, partial
 
 
 
 
@@ -348,40 +348,47 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
     return data
     return data
 
 
 
 
-async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-    enabled_servers = [
-        server for server in servers if server.get("config", {}).get("enable")
-    ]
+async def get_tool_servers_data(
+    servers: List[Dict[str, Any]], session_token: Optional[str] = None
+) -> List[Dict[str, Any]]:
+    # Prepare list of enabled servers along with their original index
+    server_entries = []
+    for idx, server in enumerate(servers):
+        if server.get("config", {}).get("enable"):
+            url_path = server.get("path", "openapi.json")
+            full_url = f"{server.get('url')}/{url_path}"
 
 
-    urls = [
-        (
-            server,
-            f"{server.get('url')}/{server.get('path', 'openapi.json')}",
-            server.get("key", ""),
-        )
-        for server in enabled_servers
-    ]
+            auth_type = server.get("auth_type", "bearer")
+            token = None
 
 
-    tasks = [get_tool_server_data(token, url) for _, url, token in urls]
+            if auth_type == "bearer":
+                token = server.get("key", "")
+            elif auth_type == "session":
+                token = session_token
+            server_entries.append((idx, server, full_url, token))
 
 
-    results: List[Dict[str, Any]] = []
+    # Create async tasks to fetch data
+    tasks = [get_tool_server_data(token, url) for (_, _, url, token) in server_entries]
 
 
+    # Execute tasks concurrently
     responses = await asyncio.gather(*tasks, return_exceptions=True)
     responses = await asyncio.gather(*tasks, return_exceptions=True)
 
 
-    for (server, _, _), response in zip(urls, responses):
+    # Build final results with index and server metadata
+    results = []
+    for (idx, server, url, _), response in zip(server_entries, responses):
         if isinstance(response, Exception):
         if isinstance(response, Exception):
-            url_path = server.get("path", "openapi.json")
-            full_url = f"{server.get('url')}/{url_path}"
-            print(f"Failed to connect to {full_url} OpenAPI tool server")
-        else:
-            results.append(
-                {
-                    "url": server.get("url"),
-                    "openapi": response["openapi"],
-                    "info": response["info"],
-                    "specs": response["specs"],
-                }
-            )
+            print(f"Failed to connect to {url} OpenAPI tool server")
+            continue
+
+        results.append(
+            {
+                "idx": idx,
+                "url": server.get("url"),
+                "openapi": response.get("openapi"),
+                "info": response.get("info"),
+                "specs": response.get("specs"),
+            }
+        )
 
 
     return results
     return results
 
 

+ 19 - 2
src/lib/components/AddServerModal.svelte

@@ -17,6 +17,7 @@
 	import Tags from './common/Tags.svelte';
 	import Tags from './common/Tags.svelte';
 	import { getToolServerData } from '$lib/apis';
 	import { getToolServerData } from '$lib/apis';
 	import { verifyToolServerConnection } from '$lib/apis/configs';
 	import { verifyToolServerConnection } from '$lib/apis/configs';
+	import AccessControl from './workspace/common/AccessControl.svelte';
 
 
 	export let onSubmit: Function = () => {};
 	export let onSubmit: Function = () => {};
 	export let onDelete: Function = () => {};
 	export let onDelete: Function = () => {};
@@ -34,6 +35,8 @@
 	let auth_type = 'bearer';
 	let auth_type = 'bearer';
 	let key = '';
 	let key = '';
 
 
+	let accessControl = null;
+
 	let enable = true;
 	let enable = true;
 
 
 	let loading = false;
 	let loading = false;
@@ -68,7 +71,8 @@
 				auth_type,
 				auth_type,
 				key,
 				key,
 				config: {
 				config: {
-					enable: enable
+					enable: enable,
+					access_control: accessControl
 				}
 				}
 			}).catch((err) => {
 			}).catch((err) => {
 				toast.error($i18n.t('Connection failed'));
 				toast.error($i18n.t('Connection failed'));
@@ -93,7 +97,8 @@
 			auth_type,
 			auth_type,
 			key,
 			key,
 			config: {
 			config: {
-				enable: enable
+				enable: enable,
+				access_control: accessControl
 			}
 			}
 		};
 		};
 
 
@@ -108,6 +113,7 @@
 		auth_type = 'bearer';
 		auth_type = 'bearer';
 
 
 		enable = true;
 		enable = true;
+		accessControl = null;
 	};
 	};
 
 
 	const init = () => {
 	const init = () => {
@@ -119,6 +125,7 @@
 			key = connection?.key ?? '';
 			key = connection?.key ?? '';
 
 
 			enable = connection.config?.enable ?? true;
 			enable = connection.config?.enable ?? true;
+			accessControl = connection.config?.access_control ?? null;
 		}
 		}
 	};
 	};
 
 
@@ -269,6 +276,16 @@
 								</div>
 								</div>
 							</div>
 							</div>
 						</div>
 						</div>
+
+						{#if !direct}
+							<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
+
+							<div class="my-2 -mx-2">
+								<div class="px-3 py-2 bg-gray-50 dark:bg-gray-950 rounded-lg">
+									<AccessControl bind:accessControl />
+								</div>
+							</div>
+						{/if}
 					</div>
 					</div>
 
 
 					<div class="flex justify-end pt-3 text-sm font-medium gap-1.5">
 					<div class="flex justify-end pt-3 text-sm font-medium gap-1.5">

+ 1 - 1
src/lib/components/chat/Settings/Tools/Connection.svelte

@@ -20,7 +20,7 @@
 
 
 <AddServerModal
 <AddServerModal
 	edit
 	edit
-	direct
+	{direct}
 	bind:show={showConfigModal}
 	bind:show={showConfigModal}
 	{connection}
 	{connection}
 	onDelete={() => {
 	onDelete={() => {