configs.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. import logging
  2. import copy
  3. from fastapi import APIRouter, Depends, Request, HTTPException
  4. from pydantic import BaseModel, ConfigDict
  5. import aiohttp
  6. from typing import Optional
  7. from open_webui.utils.auth import get_admin_user, get_verified_user
  8. from open_webui.config import get_config, save_config
  9. from open_webui.config import BannerModel
  10. from open_webui.utils.tools import (
  11. get_tool_server_data,
  12. get_tool_server_url,
  13. set_tool_servers,
  14. )
  15. from open_webui.utils.mcp.client import MCPClient
  16. from open_webui.models.oauth_sessions import OAuthSessions
  17. from open_webui.env import SRC_LOG_LEVELS
  18. from open_webui.utils.oauth import (
  19. get_discovery_urls,
  20. get_oauth_client_info_with_dynamic_client_registration,
  21. encrypt_data,
  22. decrypt_data,
  23. OAuthClientInformationFull,
  24. )
  25. from mcp.shared.auth import OAuthMetadata
# Router on which every endpoint defined in this module is registered.
router = APIRouter()
# Module-level logger; its level comes from the "MAIN" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
  29. ############################
  30. # ImportConfig
  31. ############################
class ImportConfigForm(BaseModel):
    """Request body of the ``/import`` route."""

    # Configuration dictionary that the import handler hands to ``save_config``.
    config: dict
  34. @router.post("/import", response_model=dict)
  35. async def import_config(form_data: ImportConfigForm, user=Depends(get_admin_user)):
  36. save_config(form_data.config)
  37. return get_config()
  38. ############################
  39. # ExportConfig
  40. ############################
  41. @router.get("/export", response_model=dict)
  42. async def export_config(user=Depends(get_admin_user)):
  43. return get_config()
  44. ############################
  45. # Connections Config
  46. ############################
class ConnectionsConfigForm(BaseModel):
    """Request body and response model of the ``/connections`` routes."""

    # Both flags mirror the same-named attributes on ``request.app.state.config``.
    ENABLE_DIRECT_CONNECTIONS: bool
    ENABLE_BASE_MODELS_CACHE: bool
  50. @router.get("/connections", response_model=ConnectionsConfigForm)
  51. async def get_connections_config(request: Request, user=Depends(get_admin_user)):
  52. return {
  53. "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
  54. "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
  55. }
  56. @router.post("/connections", response_model=ConnectionsConfigForm)
  57. async def set_connections_config(
  58. request: Request,
  59. form_data: ConnectionsConfigForm,
  60. user=Depends(get_admin_user),
  61. ):
  62. request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
  63. form_data.ENABLE_DIRECT_CONNECTIONS
  64. )
  65. request.app.state.config.ENABLE_BASE_MODELS_CACHE = (
  66. form_data.ENABLE_BASE_MODELS_CACHE
  67. )
  68. return {
  69. "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
  70. "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
  71. }
class OAuthClientRegistrationForm(BaseModel):
    """Request body of the ``/oauth/clients/register`` route."""

    # URL forwarded to the dynamic client registration helper.
    url: str
    # Client identifier; the handler may prefix it with "<type>:".
    client_id: str
    # Optional; defaults to None. Not read by the registration handler in this module.
    client_name: Optional[str] = None
  76. @router.post("/oauth/clients/register")
  77. async def register_oauth_client(
  78. request: Request,
  79. form_data: OAuthClientRegistrationForm,
  80. type: Optional[str] = None,
  81. user=Depends(get_admin_user),
  82. ):
  83. try:
  84. oauth_client_id = form_data.client_id
  85. if type:
  86. oauth_client_id = f"{type}:{form_data.client_id}"
  87. oauth_client_info = (
  88. await get_oauth_client_info_with_dynamic_client_registration(
  89. request, oauth_client_id, form_data.url
  90. )
  91. )
  92. return {
  93. "status": True,
  94. "oauth_client_info": encrypt_data(
  95. oauth_client_info.model_dump(mode="json")
  96. ),
  97. }
  98. except Exception as e:
  99. log.debug(f"Failed to register OAuth client: {e}")
  100. raise HTTPException(
  101. status_code=400,
  102. detail=f"Failed to register OAuth client",
  103. )
  104. ############################
  105. # ToolServers Config
  106. ############################
class ToolServerConnection(BaseModel):
    """One tool-server connection entry.

    Used as the element type of ``ToolServersConfigForm`` and as the request
    body of ``/tool_servers/verify``.
    """

    # Base URL of the tool server.
    url: str
    # Path combined with ``url`` via ``get_tool_server_url`` for OpenAPI servers.
    path: str
    # Server kind; handlers in this module compare against "mcp" and treat
    # anything else as OpenAPI.
    type: Optional[str] = "openapi"  # openapi, mcp
    # Handlers in this module compare against "bearer", "session",
    # "system_oauth" and "oauth_2.1".
    auth_type: Optional[str]
    # Used as the bearer token when ``auth_type`` is "bearer".
    key: Optional[str]
    config: Optional[dict]
    # Extra keys are allowed; handlers in this module read an ``info`` key
    # (with ``id`` / ``oauth_client_info``) from stored connections.
    model_config = ConfigDict(extra="allow")
class ToolServersConfigForm(BaseModel):
    """Request body and response model of the ``/tool_servers`` routes."""

    # Complete list of tool-server connections.
    TOOL_SERVER_CONNECTIONS: list[ToolServerConnection]
  117. @router.get("/tool_servers", response_model=ToolServersConfigForm)
  118. async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)):
  119. return {
  120. "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
  121. }
  122. @router.post("/tool_servers", response_model=ToolServersConfigForm)
  123. async def set_tool_servers_config(
  124. request: Request,
  125. form_data: ToolServersConfigForm,
  126. user=Depends(get_admin_user),
  127. ):
  128. for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
  129. server_type = connection.get("type", "openapi")
  130. auth_type = connection.get("auth_type", "none")
  131. if auth_type == "oauth_2.1":
  132. # Remove existing OAuth clients for tool servers
  133. server_id = connection.get("info", {}).get("id")
  134. client_key = f"{server_type}:{server_id}"
  135. try:
  136. request.app.state.oauth_client_manager.remove_client(client_key)
  137. except:
  138. pass
  139. # Set new tool server connections
  140. request.app.state.config.TOOL_SERVER_CONNECTIONS = [
  141. connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
  142. ]
  143. await set_tool_servers(request)
  144. for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
  145. server_type = connection.get("type", "openapi")
  146. if server_type == "mcp":
  147. server_id = connection.get("info", {}).get("id")
  148. auth_type = connection.get("auth_type", "none")
  149. if auth_type == "oauth_2.1" and server_id:
  150. try:
  151. oauth_client_info = connection.get("info", {}).get(
  152. "oauth_client_info", ""
  153. )
  154. oauth_client_info = decrypt_data(oauth_client_info)
  155. request.app.state.oauth_client_manager.add_client(
  156. f"{server_type}:{server_id}",
  157. OAuthClientInformationFull(**oauth_client_info),
  158. )
  159. except Exception as e:
  160. log.debug(f"Failed to add OAuth client for MCP tool server: {e}")
  161. continue
  162. return {
  163. "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
  164. }
  165. @router.post("/tool_servers/verify")
  166. async def verify_tool_servers_config(
  167. request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user)
  168. ):
  169. """
  170. Verify the connection to the tool server.
  171. """
  172. try:
  173. if form_data.type == "mcp":
  174. if form_data.auth_type == "oauth_2.1":
  175. discovery_urls = get_discovery_urls(form_data.url)
  176. for discovery_url in discovery_urls:
  177. log.debug(
  178. f"Trying to fetch OAuth 2.1 discovery document from {discovery_url}"
  179. )
  180. async with aiohttp.ClientSession(trust_env=True) as session:
  181. async with session.get(
  182. discovery_url
  183. ) as oauth_server_metadata_response:
  184. if oauth_server_metadata_response.status == 200:
  185. try:
  186. oauth_server_metadata = (
  187. OAuthMetadata.model_validate(
  188. await oauth_server_metadata_response.json()
  189. )
  190. )
  191. return {
  192. "status": True,
  193. "oauth_server_metadata": oauth_server_metadata.model_dump(
  194. mode="json"
  195. ),
  196. }
  197. except Exception as e:
  198. log.info(
  199. f"Failed to parse OAuth 2.1 discovery document: {e}"
  200. )
  201. raise HTTPException(
  202. status_code=400,
  203. detail=f"Failed to parse OAuth 2.1 discovery document from {discovery_url}",
  204. )
  205. raise HTTPException(
  206. status_code=400,
  207. detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls}",
  208. )
  209. else:
  210. try:
  211. client = MCPClient()
  212. headers = None
  213. token = None
  214. if form_data.auth_type == "bearer":
  215. token = form_data.key
  216. elif form_data.auth_type == "session":
  217. token = request.state.token.credentials
  218. elif form_data.auth_type == "system_oauth":
  219. try:
  220. if request.cookies.get("oauth_session_id", None):
  221. token = await request.app.state.oauth_manager.get_oauth_token(
  222. user.id,
  223. request.cookies.get("oauth_session_id", None),
  224. )
  225. except Exception as e:
  226. pass
  227. if token:
  228. headers = {"Authorization": f"Bearer {token}"}
  229. await client.connect(form_data.url, headers=headers)
  230. specs = await client.list_tool_specs()
  231. return {
  232. "status": True,
  233. "specs": specs,
  234. }
  235. except Exception as e:
  236. log.debug(f"Failed to create MCP client: {e}")
  237. raise HTTPException(
  238. status_code=400,
  239. detail=f"Failed to create MCP client",
  240. )
  241. finally:
  242. if client:
  243. await client.disconnect()
  244. else: # openapi
  245. token = None
  246. if form_data.auth_type == "bearer":
  247. token = form_data.key
  248. elif form_data.auth_type == "session":
  249. token = request.state.token.credentials
  250. elif form_data.auth_type == "system_oauth":
  251. try:
  252. if request.cookies.get("oauth_session_id", None):
  253. token = await request.app.state.oauth_manager.get_oauth_token(
  254. user.id,
  255. request.cookies.get("oauth_session_id", None),
  256. )
  257. except Exception as e:
  258. pass
  259. url = get_tool_server_url(form_data.url, form_data.path)
  260. return await get_tool_server_data(token, url)
  261. except HTTPException as e:
  262. raise e
  263. except Exception as e:
  264. log.debug(f"Failed to connect to the tool server: {e}")
  265. raise HTTPException(
  266. status_code=400,
  267. detail=f"Failed to connect to the tool server",
  268. )
  269. ############################
  270. # CodeInterpreterConfig
  271. ############################
class CodeInterpreterConfigForm(BaseModel):
    """Request body and response model of the ``/code_execution`` routes.

    Every field mirrors the same-named attribute on ``request.app.state.config``.
    """

    # Code execution settings.
    ENABLE_CODE_EXECUTION: bool
    CODE_EXECUTION_ENGINE: str
    CODE_EXECUTION_JUPYTER_URL: Optional[str]
    CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
    CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
    CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
    CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int]
    # Code interpreter settings.
    ENABLE_CODE_INTERPRETER: bool
    CODE_INTERPRETER_ENGINE: str
    CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
    CODE_INTERPRETER_JUPYTER_URL: Optional[str]
    CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
    CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
    CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
    CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int]
  288. @router.get("/code_execution", response_model=CodeInterpreterConfigForm)
  289. async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
  290. return {
  291. "ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
  292. "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
  293. "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
  294. "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
  295. "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
  296. "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
  297. "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
  298. "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
  299. "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
  300. "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
  301. "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
  302. "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
  303. "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
  304. "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
  305. "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
  306. }
  307. @router.post("/code_execution", response_model=CodeInterpreterConfigForm)
  308. async def set_code_execution_config(
  309. request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
  310. ):
  311. request.app.state.config.ENABLE_CODE_EXECUTION = form_data.ENABLE_CODE_EXECUTION
  312. request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
  313. request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
  314. form_data.CODE_EXECUTION_JUPYTER_URL
  315. )
  316. request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
  317. form_data.CODE_EXECUTION_JUPYTER_AUTH
  318. )
  319. request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
  320. form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
  321. )
  322. request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
  323. form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
  324. )
  325. request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = (
  326. form_data.CODE_EXECUTION_JUPYTER_TIMEOUT
  327. )
  328. request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
  329. request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
  330. request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
  331. form_data.CODE_INTERPRETER_PROMPT_TEMPLATE
  332. )
  333. request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = (
  334. form_data.CODE_INTERPRETER_JUPYTER_URL
  335. )
  336. request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = (
  337. form_data.CODE_INTERPRETER_JUPYTER_AUTH
  338. )
  339. request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
  340. form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
  341. )
  342. request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
  343. form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
  344. )
  345. request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = (
  346. form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT
  347. )
  348. return {
  349. "ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
  350. "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
  351. "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
  352. "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
  353. "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
  354. "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
  355. "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
  356. "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
  357. "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
  358. "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
  359. "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
  360. "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
  361. "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
  362. "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
  363. "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
  364. }
  365. ############################
  366. # SetDefaultModels
  367. ############################
class ModelsConfigForm(BaseModel):
    """Request body and response model of the ``/models`` routes."""

    # Both fields mirror the same-named attributes on ``request.app.state.config``.
    DEFAULT_MODELS: Optional[str]
    MODEL_ORDER_LIST: Optional[list[str]]
  371. @router.get("/models", response_model=ModelsConfigForm)
  372. async def get_models_config(request: Request, user=Depends(get_admin_user)):
  373. return {
  374. "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
  375. "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
  376. }
  377. @router.post("/models", response_model=ModelsConfigForm)
  378. async def set_models_config(
  379. request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user)
  380. ):
  381. request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS
  382. request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST
  383. return {
  384. "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
  385. "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
  386. }
class PromptSuggestion(BaseModel):
    """One prompt suggestion; element type of the ``/suggestions`` payload."""

    # Title given as a list of strings.
    title: list[str]
    content: str
class SetDefaultSuggestionsForm(BaseModel):
    """Request body of the ``/suggestions`` route."""

    # Complete list of suggestions to store as DEFAULT_PROMPT_SUGGESTIONS.
    suggestions: list[PromptSuggestion]
  392. @router.post("/suggestions", response_model=list[PromptSuggestion])
  393. async def set_default_suggestions(
  394. request: Request,
  395. form_data: SetDefaultSuggestionsForm,
  396. user=Depends(get_admin_user),
  397. ):
  398. data = form_data.model_dump()
  399. request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
  400. return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS
  401. ############################
  402. # SetBanners
  403. ############################
class SetBannersForm(BaseModel):
    """Request body of the POST ``/banners`` route."""

    # Complete list of banners to store as BANNERS.
    banners: list[BannerModel]
  406. @router.post("/banners", response_model=list[BannerModel])
  407. async def set_banners(
  408. request: Request,
  409. form_data: SetBannersForm,
  410. user=Depends(get_admin_user),
  411. ):
  412. data = form_data.model_dump()
  413. request.app.state.config.BANNERS = data["banners"]
  414. return request.app.state.config.BANNERS
  415. @router.get("/banners", response_model=list[BannerModel])
  416. async def get_banners(
  417. request: Request,
  418. user=Depends(get_verified_user),
  419. ):
  420. return request.app.state.config.BANNERS