|
@@ -96,94 +96,118 @@ async def get_tools(
|
|
|
for tool_id in tool_ids:
|
|
|
tool = Tools.get_tool_by_id(tool_id)
|
|
|
if tool is None:
|
|
|
+
|
|
|
if tool_id.startswith("server:"):
|
|
|
- server_id = tool_id.split(":")[1]
|
|
|
+ splits = tool_id.split(":")
|
|
|
+
|
|
|
+ if len(splits) == 2:
|
|
|
+ type = "openapi"
|
|
|
+ server_id = splits[1]
|
|
|
+ elif len(splits) == 3:
|
|
|
+ type = splits[1]
|
|
|
+ server_id = splits[2]
|
|
|
+
|
|
|
+ server_id_splits = server_id.split("|")
|
|
|
+ if len(server_id_splits) == 2:
|
|
|
+ server_id = server_id_splits[0]
|
|
|
+ function_names = server_id_splits[1].split(",")
|
|
|
+
|
|
|
+ if type == "openapi":
|
|
|
+
|
|
|
+ tool_server_data = None
|
|
|
+ for server in await get_tool_servers(request):
|
|
|
+ if server["id"] == server_id:
|
|
|
+ tool_server_data = server
|
|
|
+ break
|
|
|
+
|
|
|
+ if tool_server_data is None:
|
|
|
+ log.warning(f"Tool server data not found for {server_id}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ tool_server_idx = tool_server_data.get("idx", 0)
|
|
|
+ tool_server_connection = (
|
|
|
+ request.app.state.config.TOOL_SERVER_CONNECTIONS[
|
|
|
+ tool_server_idx
|
|
|
+ ]
|
|
|
+ )
|
|
|
|
|
|
- tool_server_data = None
|
|
|
- for server in await get_tool_servers(request):
|
|
|
- if server["id"] == server_id:
|
|
|
- tool_server_data = server
|
|
|
- break
|
|
|
+ specs = tool_server_data.get("specs", [])
|
|
|
+ for spec in specs:
|
|
|
+ function_name = spec["name"]
|
|
|
|
|
|
- if tool_server_data is None:
|
|
|
- log.warning(f"Tool server data not found for {server_id}")
|
|
|
- continue
|
|
|
+ auth_type = tool_server_connection.get("auth_type", "bearer")
|
|
|
|
|
|
- tool_server_idx = tool_server_data.get("idx", 0)
|
|
|
- tool_server_connection = (
|
|
|
- request.app.state.config.TOOL_SERVER_CONNECTIONS[tool_server_idx]
|
|
|
- )
|
|
|
+ cookies = {}
|
|
|
+ headers = {}
|
|
|
|
|
|
- specs = tool_server_data.get("specs", [])
|
|
|
- for spec in specs:
|
|
|
- function_name = spec["name"]
|
|
|
+ if auth_type == "bearer":
|
|
|
+ headers["Authorization"] = (
|
|
|
+ f"Bearer {tool_server_connection.get('key', '')}"
|
|
|
+ )
|
|
|
+ elif auth_type == "none":
|
|
|
+ # No authentication
|
|
|
+ pass
|
|
|
+ elif auth_type == "session":
|
|
|
+ cookies = request.cookies
|
|
|
+ headers["Authorization"] = (
|
|
|
+ f"Bearer {request.state.token.credentials}"
|
|
|
+ )
|
|
|
+ elif auth_type == "system_oauth":
|
|
|
+ cookies = request.cookies
|
|
|
+ oauth_token = extra_params.get("__oauth_token__", None)
|
|
|
+ if oauth_token:
|
|
|
+ headers["Authorization"] = (
|
|
|
+ f"Bearer {oauth_token.get('access_token', '')}"
|
|
|
+ )
|
|
|
|
|
|
- auth_type = tool_server_connection.get("auth_type", "bearer")
|
|
|
+ headers["Content-Type"] = "application/json"
|
|
|
+
|
|
|
+ def make_tool_function(
|
|
|
+ function_name, tool_server_data, headers
|
|
|
+ ):
|
|
|
+ async def tool_function(**kwargs):
|
|
|
+ return await execute_tool_server(
|
|
|
+ url=tool_server_data["url"],
|
|
|
+ headers=headers,
|
|
|
+ cookies=cookies,
|
|
|
+ name=function_name,
|
|
|
+ params=kwargs,
|
|
|
+ server_data=tool_server_data,
|
|
|
+ )
|
|
|
|
|
|
- cookies = {}
|
|
|
- headers = {}
|
|
|
+ return tool_function
|
|
|
|
|
|
- if auth_type == "bearer":
|
|
|
- headers["Authorization"] = (
|
|
|
- f"Bearer {tool_server_connection.get('key', '')}"
|
|
|
+ tool_function = make_tool_function(
|
|
|
+ function_name, tool_server_data, headers
|
|
|
)
|
|
|
- elif auth_type == "none":
|
|
|
- # No authentication
|
|
|
- pass
|
|
|
- elif auth_type == "session":
|
|
|
- cookies = request.cookies
|
|
|
- headers["Authorization"] = (
|
|
|
- f"Bearer {request.state.token.credentials}"
|
|
|
+
|
|
|
+ callable = get_async_tool_function_and_apply_extra_params(
|
|
|
+ tool_function,
|
|
|
+ {},
|
|
|
)
|
|
|
- elif auth_type == "system_oauth":
|
|
|
- cookies = request.cookies
|
|
|
- oauth_token = extra_params.get("__oauth_token__", None)
|
|
|
- if oauth_token:
|
|
|
- headers["Authorization"] = (
|
|
|
- f"Bearer {oauth_token.get('access_token', '')}"
|
|
|
- )
|
|
|
|
|
|
- headers["Content-Type"] = "application/json"
|
|
|
-
|
|
|
- def make_tool_function(function_name, tool_server_data, headers):
|
|
|
- async def tool_function(**kwargs):
|
|
|
- return await execute_tool_server(
|
|
|
- url=tool_server_data["url"],
|
|
|
- headers=headers,
|
|
|
- cookies=cookies,
|
|
|
- name=function_name,
|
|
|
- params=kwargs,
|
|
|
- server_data=tool_server_data,
|
|
|
+ tool_dict = {
|
|
|
+ "tool_id": tool_id,
|
|
|
+ "callable": callable,
|
|
|
+ "spec": spec,
|
|
|
+ # Misc info
|
|
|
+ "type": "external",
|
|
|
+ }
|
|
|
+
|
|
|
+ # Handle function name collisions
|
|
|
+ while function_name in tools_dict:
|
|
|
+ log.warning(
|
|
|
+ f"Tool {function_name} already exists in another tools!"
|
|
|
)
|
|
|
+ # Prepend server ID to function name
|
|
|
+ function_name = f"{server_id}_{function_name}"
|
|
|
|
|
|
- return tool_function
|
|
|
-
|
|
|
- tool_function = make_tool_function(
|
|
|
- function_name, tool_server_data, headers
|
|
|
- )
|
|
|
-
|
|
|
- callable = get_async_tool_function_and_apply_extra_params(
|
|
|
- tool_function,
|
|
|
- {},
|
|
|
- )
|
|
|
+ tools_dict[function_name] = tool_dict
|
|
|
|
|
|
- tool_dict = {
|
|
|
- "tool_id": tool_id,
|
|
|
- "callable": callable,
|
|
|
- "spec": spec,
|
|
|
- # Misc info
|
|
|
- "type": "external",
|
|
|
- }
|
|
|
-
|
|
|
- # Handle function name collisions
|
|
|
- while function_name in tools_dict:
|
|
|
- log.warning(
|
|
|
- f"Tool {function_name} already exists in another tools!"
|
|
|
- )
|
|
|
- # Prepend server ID to function name
|
|
|
- function_name = f"{server_id}_{function_name}"
|
|
|
+ else:
|
|
|
+ log.warning(f"Unsupported tool server type: {type}")
|
|
|
+ continue
|
|
|
|
|
|
- tools_dict[function_name] = tool_dict
|
|
|
else:
|
|
|
continue
|
|
|
else:
|
|
@@ -579,7 +603,10 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str,
|
|
|
# Prepare list of enabled servers along with their original index
|
|
|
server_entries = []
|
|
|
for idx, server in enumerate(servers):
|
|
|
- if server.get("config", {}).get("enable"):
|
|
|
+ if (
|
|
|
+ server.get("config", {}).get("enable")
|
|
|
+ and server.get("type", "openapi") == "openapi"
|
|
|
+ ):
|
|
|
# Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL
|
|
|
openapi_path = server.get("path", "openapi.json")
|
|
|
full_url = get_tool_server_url(server.get("url"), openapi_path)
|