|
@@ -2,6 +2,9 @@ import logging
|
|
|
from pathlib import Path
|
|
|
from typing import Optional
|
|
|
import time
|
|
|
+import re
|
|
|
+import aiohttp
|
|
|
+from pydantic import BaseModel, HttpUrl
|
|
|
|
|
|
from open_webui.models.tools import (
|
|
|
ToolForm,
|
|
@@ -21,6 +24,7 @@ from open_webui.env import SRC_LOG_LEVELS
|
|
|
|
|
|
from open_webui.utils.tools import get_tool_servers_data
|
|
|
|
|
|
+
|
|
|
log = logging.getLogger(__name__)
|
|
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
|
|
|
|
@@ -95,6 +99,81 @@ async def get_tool_list(user=Depends(get_verified_user)):
|
|
|
return tools
|
|
|
|
|
|
|
|
|
+############################
|
|
|
+# LoadFunctionFromLink
|
|
|
+############################
|
|
|
+
|
|
|
+
|
|
|
+class LoadUrlForm(BaseModel):
|
|
|
+ url: HttpUrl
|
|
|
+
|
|
|
+
|
|
|
+def github_url_to_raw_url(url: str) -> str:
|
|
|
+ # Handle 'tree' (folder) URLs (add main.py at the end)
|
|
|
+ m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
|
|
|
+ if m1:
|
|
|
+ org, repo, branch, path = m1.groups()
|
|
|
+ return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
|
|
|
+
|
|
|
+ # Handle 'blob' (file) URLs
|
|
|
+ m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
|
|
|
+ if m2:
|
|
|
+ org, repo, branch, path = m2.groups()
|
|
|
+ return (
|
|
|
+ f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
|
|
|
+ )
|
|
|
+
|
|
|
+ # No match; return as-is
|
|
|
+ return url
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/load/url", response_model=Optional[dict])
|
|
|
+async def load_tool_from_url(
|
|
|
+ request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
|
|
|
+):
|
|
|
+ # NOTE: This is NOT a SSRF vulnerability:
|
|
|
+ # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
|
|
|
+ # and does NOT accept untrusted user input. Access is enforced by authentication.
|
|
|
+
|
|
|
+ url = str(form_data.url)
|
|
|
+ if not url:
|
|
|
+ raise HTTPException(status_code=400, detail="Please enter a valid URL")
|
|
|
+
|
|
|
+ url = github_url_to_raw_url(url)
|
|
|
+ url_parts = url.rstrip("/").split("/")
|
|
|
+
|
|
|
+ file_name = url_parts[-1]
|
|
|
+ tool_name = (
|
|
|
+ file_name[:-3]
|
|
|
+ if (
|
|
|
+ file_name.endswith(".py")
|
|
|
+ and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
|
|
|
+ )
|
|
|
+ else url_parts[-2] if len(url_parts) > 1 else "function"
|
|
|
+ )
|
|
|
+
|
|
|
+ try:
|
|
|
+ async with aiohttp.ClientSession() as session:
|
|
|
+ async with session.get(
|
|
|
+ url, headers={"Content-Type": "application/json"}
|
|
|
+ ) as resp:
|
|
|
+ if resp.status != 200:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=resp.status, detail="Failed to fetch the tool"
|
|
|
+ )
|
|
|
+ data = await resp.text()
|
|
|
+ if not data:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=400, detail="No data received from the URL"
|
|
|
+ )
|
|
|
+ return {
|
|
|
+ "name": tool_name,
|
|
|
+ "content": data,
|
|
|
+ }
|
|
|
+ except Exception as e:
|
|
|
+ raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
|
|
|
+
|
|
|
+
|
|
|
############################
|
|
|
# ExportTools
|
|
|
############################
|