configs.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. import logging
  2. from fastapi import APIRouter, Depends, Request, HTTPException
  3. from pydantic import BaseModel, ConfigDict
  4. import aiohttp
  5. from typing import Optional
  6. from open_webui.utils.auth import get_admin_user, get_verified_user
  7. from open_webui.config import get_config, save_config
  8. from open_webui.config import BannerModel
  9. from open_webui.utils.tools import (
  10. get_tool_server_data,
  11. get_tool_server_url,
  12. set_tool_servers,
  13. )
  14. from open_webui.utils.mcp.client import MCPClient
  15. from open_webui.env import SRC_LOG_LEVELS
  16. from open_webui.utils.oauth import (
  17. get_discovery_urls,
  18. get_oauth_client_info_with_dynamic_client_registration,
  19. encrypt_data,
  20. decrypt_data,
  21. OAuthClientInformationFull,
  22. )
  23. from mcp.shared.auth import OAuthMetadata
  24. router = APIRouter()
  25. log = logging.getLogger(__name__)
  26. log.setLevel(SRC_LOG_LEVELS["MAIN"])
############################
# ImportConfig
############################


# Request body for POST /import: an arbitrary configuration dict that the
# endpoint hands to save_config() as-is (no schema is enforced beyond `dict`).
class ImportConfigForm(BaseModel):
    config: dict
  32. @router.post("/import", response_model=dict)
  33. async def import_config(form_data: ImportConfigForm, user=Depends(get_admin_user)):
  34. save_config(form_data.config)
  35. return get_config()
  36. ############################
  37. # ExportConfig
  38. ############################
  39. @router.get("/export", response_model=dict)
  40. async def export_config(user=Depends(get_admin_user)):
  41. return get_config()
############################
# Connections Config
############################


# Request/response shape for GET/POST /connections: two boolean settings that
# the endpoints read from / write to request.app.state.config under the same names.
class ConnectionsConfigForm(BaseModel):
    ENABLE_DIRECT_CONNECTIONS: bool
    ENABLE_BASE_MODELS_CACHE: bool
  48. @router.get("/connections", response_model=ConnectionsConfigForm)
  49. async def get_connections_config(request: Request, user=Depends(get_admin_user)):
  50. return {
  51. "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
  52. "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
  53. }
  54. @router.post("/connections", response_model=ConnectionsConfigForm)
  55. async def set_connections_config(
  56. request: Request,
  57. form_data: ConnectionsConfigForm,
  58. user=Depends(get_admin_user),
  59. ):
  60. request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
  61. form_data.ENABLE_DIRECT_CONNECTIONS
  62. )
  63. request.app.state.config.ENABLE_BASE_MODELS_CACHE = (
  64. form_data.ENABLE_BASE_MODELS_CACHE
  65. )
  66. return {
  67. "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
  68. "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
  69. }
# Request body for POST /oauth/clients/register.
#   url:         server URL passed to the dynamic-client-registration helper.
#   client_id:   id the client is registered under (the endpoint may prefix it
#                with "<type>:").
#   client_name: optional; not read by the register endpoint in this module.
class OAuthClientRegistrationForm(BaseModel):
    url: str
    client_id: str
    client_name: Optional[str] = None
  74. @router.post("/oauth/clients/register")
  75. async def register_oauth_client(
  76. request: Request,
  77. form_data: OAuthClientRegistrationForm,
  78. type: Optional[str] = None,
  79. user=Depends(get_admin_user),
  80. ):
  81. try:
  82. oauth_client_id = form_data.client_id
  83. if type:
  84. oauth_client_id = f"{type}:{form_data.client_id}"
  85. oauth_client_info = (
  86. await get_oauth_client_info_with_dynamic_client_registration(
  87. request, oauth_client_id, form_data.url
  88. )
  89. )
  90. return {
  91. "status": True,
  92. "oauth_client_info": encrypt_data(
  93. oauth_client_info.model_dump(mode="json")
  94. ),
  95. }
  96. except Exception as e:
  97. log.debug(f"Failed to register OAuth client: {e}")
  98. raise HTTPException(
  99. status_code=400,
  100. detail=f"Failed to register OAuth client",
  101. )
  102. ############################
  103. # ToolServers Config
  104. ############################
  105. class ToolServerConnection(BaseModel):
  106. url: str
  107. path: str
  108. type: Optional[str] = "openapi" # openapi, mcp
  109. auth_type: Optional[str]
  110. key: Optional[str]
  111. config: Optional[dict]
  112. model_config = ConfigDict(extra="allow")
# Request/response shape for GET/POST /tool_servers: the full list of
# tool-server connections (replaced wholesale on POST).
class ToolServersConfigForm(BaseModel):
    TOOL_SERVER_CONNECTIONS: list[ToolServerConnection]
  115. @router.get("/tool_servers", response_model=ToolServersConfigForm)
  116. async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)):
  117. return {
  118. "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
  119. }
  120. @router.post("/tool_servers", response_model=ToolServersConfigForm)
  121. async def set_tool_servers_config(
  122. request: Request,
  123. form_data: ToolServersConfigForm,
  124. user=Depends(get_admin_user),
  125. ):
  126. request.app.state.config.TOOL_SERVER_CONNECTIONS = [
  127. connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
  128. ]
  129. await set_tool_servers(request)
  130. for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
  131. server_type = connection.get("type", "openapi")
  132. if server_type == "mcp":
  133. server_id = connection.get("info", {}).get("id")
  134. auth_type = connection.get("auth_type", "none")
  135. if auth_type == "oauth_2.1" and server_id:
  136. try:
  137. oauth_client_info = connection.get("info", {}).get(
  138. "oauth_client_info", ""
  139. )
  140. oauth_client_info = decrypt_data(oauth_client_info)
  141. await request.app.state.oauth_client_manager.add_client(
  142. f"{server_type}:{server_id}",
  143. OAuthClientInformationFull(**oauth_client_info),
  144. )
  145. except Exception as e:
  146. log.debug(f"Failed to add OAuth client for MCP tool server: {e}")
  147. continue
  148. return {
  149. "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
  150. }
  151. @router.post("/tool_servers/verify")
  152. async def verify_tool_servers_config(
  153. request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user)
  154. ):
  155. """
  156. Verify the connection to the tool server.
  157. """
  158. try:
  159. if form_data.type == "mcp":
  160. if form_data.auth_type == "oauth_2.1":
  161. discovery_urls = get_discovery_urls(form_data.url)
  162. for discovery_url in discovery_urls:
  163. log.debug(
  164. f"Trying to fetch OAuth 2.1 discovery document from {discovery_url}"
  165. )
  166. async with aiohttp.ClientSession() as session:
  167. async with session.get(
  168. discovery_url
  169. ) as oauth_server_metadata_response:
  170. if oauth_server_metadata_response.status == 200:
  171. try:
  172. oauth_server_metadata = (
  173. OAuthMetadata.model_validate(
  174. await oauth_server_metadata_response.json()
  175. )
  176. )
  177. return {
  178. "status": True,
  179. "oauth_server_metadata": oauth_server_metadata.model_dump(
  180. mode="json"
  181. ),
  182. }
  183. except Exception as e:
  184. log.info(
  185. f"Failed to parse OAuth 2.1 discovery document: {e}"
  186. )
  187. raise HTTPException(
  188. status_code=400,
  189. detail=f"Failed to parse OAuth 2.1 discovery document from {discovery_url}",
  190. )
  191. raise HTTPException(
  192. status_code=400,
  193. detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls}",
  194. )
  195. else:
  196. try:
  197. client = MCPClient()
  198. headers = None
  199. token = None
  200. if form_data.auth_type == "bearer":
  201. token = form_data.key
  202. elif form_data.auth_type == "session":
  203. token = request.state.token.credentials
  204. elif form_data.auth_type == "system_oauth":
  205. try:
  206. if request.cookies.get("oauth_session_id", None):
  207. token = await request.app.state.oauth_manager.get_oauth_token(
  208. user.id,
  209. request.cookies.get("oauth_session_id", None),
  210. )
  211. except Exception as e:
  212. pass
  213. if token:
  214. headers = {"Authorization": f"Bearer {token}"}
  215. await client.connect(form_data.url, headers=headers)
  216. specs = await client.list_tool_specs()
  217. return {
  218. "status": True,
  219. "specs": specs,
  220. }
  221. except Exception as e:
  222. log.debug(f"Failed to create MCP client: {e}")
  223. raise HTTPException(
  224. status_code=400,
  225. detail=f"Failed to create MCP client",
  226. )
  227. finally:
  228. if client:
  229. await client.disconnect()
  230. else: # openapi
  231. token = None
  232. if form_data.auth_type == "bearer":
  233. token = form_data.key
  234. elif form_data.auth_type == "session":
  235. token = request.state.token.credentials
  236. elif form_data.auth_type == "system_oauth":
  237. try:
  238. if request.cookies.get("oauth_session_id", None):
  239. token = await request.app.state.oauth_manager.get_oauth_token(
  240. user.id,
  241. request.cookies.get("oauth_session_id", None),
  242. )
  243. except Exception as e:
  244. pass
  245. url = get_tool_server_url(form_data.url, form_data.path)
  246. return await get_tool_server_data(token, url)
  247. except HTTPException as e:
  248. raise e
  249. except Exception as e:
  250. log.debug(f"Failed to connect to the tool server: {e}")
  251. raise HTTPException(
  252. status_code=400,
  253. detail=f"Failed to connect to the tool server",
  254. )
  255. ############################
  256. # CodeInterpreterConfig
  257. ############################
  258. class CodeInterpreterConfigForm(BaseModel):
  259. ENABLE_CODE_EXECUTION: bool
  260. CODE_EXECUTION_ENGINE: str
  261. CODE_EXECUTION_JUPYTER_URL: Optional[str]
  262. CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
  263. CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
  264. CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
  265. CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int]
  266. ENABLE_CODE_INTERPRETER: bool
  267. CODE_INTERPRETER_ENGINE: str
  268. CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
  269. CODE_INTERPRETER_JUPYTER_URL: Optional[str]
  270. CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
  271. CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
  272. CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
  273. CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int]
  274. @router.get("/code_execution", response_model=CodeInterpreterConfigForm)
  275. async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
  276. return {
  277. "ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
  278. "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
  279. "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
  280. "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
  281. "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
  282. "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
  283. "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
  284. "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
  285. "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
  286. "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
  287. "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
  288. "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
  289. "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
  290. "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
  291. "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
  292. }
  293. @router.post("/code_execution", response_model=CodeInterpreterConfigForm)
  294. async def set_code_execution_config(
  295. request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
  296. ):
  297. request.app.state.config.ENABLE_CODE_EXECUTION = form_data.ENABLE_CODE_EXECUTION
  298. request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
  299. request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
  300. form_data.CODE_EXECUTION_JUPYTER_URL
  301. )
  302. request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
  303. form_data.CODE_EXECUTION_JUPYTER_AUTH
  304. )
  305. request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
  306. form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
  307. )
  308. request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
  309. form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
  310. )
  311. request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = (
  312. form_data.CODE_EXECUTION_JUPYTER_TIMEOUT
  313. )
  314. request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
  315. request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
  316. request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
  317. form_data.CODE_INTERPRETER_PROMPT_TEMPLATE
  318. )
  319. request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = (
  320. form_data.CODE_INTERPRETER_JUPYTER_URL
  321. )
  322. request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = (
  323. form_data.CODE_INTERPRETER_JUPYTER_AUTH
  324. )
  325. request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
  326. form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
  327. )
  328. request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
  329. form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
  330. )
  331. request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = (
  332. form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT
  333. )
  334. return {
  335. "ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
  336. "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
  337. "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
  338. "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
  339. "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
  340. "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
  341. "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
  342. "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
  343. "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
  344. "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
  345. "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
  346. "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
  347. "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
  348. "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
  349. "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
  350. }
############################
# SetDefaultModels
############################


# Request/response shape for GET/POST /models. Both fields are mirrored onto
# request.app.state.config under the same names.
# DEFAULT_MODELS: a single string (presumably a delimiter-joined list of model
#   ids — confirm with the frontend).
# MODEL_ORDER_LIST: list of model identifiers (presumably ids) giving display order.
# NOTE: under Pydantic v2 both fields are required keys (null is accepted).
class ModelsConfigForm(BaseModel):
    DEFAULT_MODELS: Optional[str]
    MODEL_ORDER_LIST: Optional[list[str]]
  357. @router.get("/models", response_model=ModelsConfigForm)
  358. async def get_models_config(request: Request, user=Depends(get_admin_user)):
  359. return {
  360. "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
  361. "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
  362. }
  363. @router.post("/models", response_model=ModelsConfigForm)
  364. async def set_models_config(
  365. request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user)
  366. ):
  367. request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS
  368. request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST
  369. return {
  370. "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
  371. "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
  372. }
# One default prompt suggestion: a list of title strings (presumably the
# lines of the displayed heading — confirm) plus the prompt content itself.
class PromptSuggestion(BaseModel):
    title: list[str]
    content: str
# Request body for POST /suggestions: the complete replacement list.
class SetDefaultSuggestionsForm(BaseModel):
    suggestions: list[PromptSuggestion]
  378. @router.post("/suggestions", response_model=list[PromptSuggestion])
  379. async def set_default_suggestions(
  380. request: Request,
  381. form_data: SetDefaultSuggestionsForm,
  382. user=Depends(get_admin_user),
  383. ):
  384. data = form_data.model_dump()
  385. request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
  386. return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS
############################
# SetBanners
############################


# Request body for POST /banners: the complete replacement list of banners.
class SetBannersForm(BaseModel):
    banners: list[BannerModel]
  392. @router.post("/banners", response_model=list[BannerModel])
  393. async def set_banners(
  394. request: Request,
  395. form_data: SetBannersForm,
  396. user=Depends(get_admin_user),
  397. ):
  398. data = form_data.model_dump()
  399. request.app.state.config.BANNERS = data["banners"]
  400. return request.app.state.config.BANNERS
  401. @router.get("/banners", response_model=list[BannerModel])
  402. async def get_banners(
  403. request: Request,
  404. user=Depends(get_verified_user),
  405. ):
  406. return request.app.state.config.BANNERS