configs.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. import logging
  2. from fastapi import APIRouter, Depends, Request, HTTPException
  3. from pydantic import BaseModel, ConfigDict
  4. import aiohttp
  5. from typing import Optional
  6. from open_webui.utils.auth import get_admin_user, get_verified_user
  7. from open_webui.config import get_config, save_config
  8. from open_webui.config import BannerModel
  9. from open_webui.utils.tools import (
  10. get_tool_server_data,
  11. get_tool_server_url,
  12. set_tool_servers,
  13. )
  14. from open_webui.utils.mcp.client import MCPClient
  15. from open_webui.env import SRC_LOG_LEVELS
  16. from open_webui.utils.oauth import (
  17. get_discovery_urls,
  18. get_oauth_client_info_with_dynamic_client_registration,
  19. encrypt_data,
  20. decrypt_data,
  21. OAuthClientInformationFull,
  22. )
  23. from mcp.shared.auth import OAuthMetadata
  24. router = APIRouter()
  25. log = logging.getLogger(__name__)
  26. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  27. ############################
  28. # ImportConfig
  29. ############################
  30. class ImportConfigForm(BaseModel):
  31. config: dict
  32. @router.post("/import", response_model=dict)
  33. async def import_config(form_data: ImportConfigForm, user=Depends(get_admin_user)):
  34. save_config(form_data.config)
  35. return get_config()
  36. ############################
  37. # ExportConfig
  38. ############################
  39. @router.get("/export", response_model=dict)
  40. async def export_config(user=Depends(get_admin_user)):
  41. return get_config()
  42. ############################
  43. # Connections Config
  44. ############################
  45. class ConnectionsConfigForm(BaseModel):
  46. ENABLE_DIRECT_CONNECTIONS: bool
  47. ENABLE_BASE_MODELS_CACHE: bool
  48. @router.get("/connections", response_model=ConnectionsConfigForm)
  49. async def get_connections_config(request: Request, user=Depends(get_admin_user)):
  50. return {
  51. "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
  52. "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
  53. }
  54. @router.post("/connections", response_model=ConnectionsConfigForm)
  55. async def set_connections_config(
  56. request: Request,
  57. form_data: ConnectionsConfigForm,
  58. user=Depends(get_admin_user),
  59. ):
  60. request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
  61. form_data.ENABLE_DIRECT_CONNECTIONS
  62. )
  63. request.app.state.config.ENABLE_BASE_MODELS_CACHE = (
  64. form_data.ENABLE_BASE_MODELS_CACHE
  65. )
  66. return {
  67. "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
  68. "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
  69. }
  70. class OAuthClientRegistrationForm(BaseModel):
  71. url: str
  72. client_id: str
  73. client_name: Optional[str] = None
  74. @router.post("/oauth/clients/register")
  75. async def register_oauth_client(
  76. request: Request,
  77. form_data: OAuthClientRegistrationForm,
  78. type: Optional[str] = None,
  79. user=Depends(get_admin_user),
  80. ):
  81. try:
  82. oauth_client_id = form_data.client_id
  83. if type:
  84. oauth_client_id = f"{type}:{form_data.client_id}"
  85. oauth_client_info = (
  86. await get_oauth_client_info_with_dynamic_client_registration(
  87. request, oauth_client_id, form_data.url
  88. )
  89. )
  90. return {
  91. "status": True,
  92. "oauth_client_info": encrypt_data(
  93. oauth_client_info.model_dump(mode="json")
  94. ),
  95. }
  96. except Exception as e:
  97. log.debug(f"Failed to register OAuth client: {e}")
  98. raise HTTPException(
  99. status_code=400,
  100. detail=f"Failed to register OAuth client",
  101. )
  102. ############################
  103. # ToolServers Config
  104. ############################
  105. class ToolServerConnection(BaseModel):
  106. url: str
  107. path: str
  108. type: Optional[str] = "openapi" # openapi, mcp
  109. auth_type: Optional[str]
  110. key: Optional[str]
  111. config: Optional[dict]
  112. model_config = ConfigDict(extra="allow")
  113. class ToolServersConfigForm(BaseModel):
  114. TOOL_SERVER_CONNECTIONS: list[ToolServerConnection]
  115. @router.get("/tool_servers", response_model=ToolServersConfigForm)
  116. async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)):
  117. return {
  118. "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
  119. }
  120. @router.post("/tool_servers", response_model=ToolServersConfigForm)
  121. async def set_tool_servers_config(
  122. request: Request,
  123. form_data: ToolServersConfigForm,
  124. user=Depends(get_admin_user),
  125. ):
  126. request.app.state.config.TOOL_SERVER_CONNECTIONS = [
  127. connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
  128. ]
  129. await set_tool_servers(request)
  130. for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
  131. server_type = connection.get("type", "openapi")
  132. if server_type == "mcp":
  133. server_id = connection.get("info", {}).get("id")
  134. auth_type = connection.get("auth_type", "none")
  135. if auth_type == "oauth_2.1" and server_id:
  136. try:
  137. oauth_client_info = connection.get("info", {}).get(
  138. "oauth_client_info", ""
  139. )
  140. oauth_client_info = decrypt_data(oauth_client_info)
  141. await request.app.state.oauth_client_manager.add_client(
  142. f"{server_type}:{server_id}",
  143. OAuthClientInformationFull(**oauth_client_info),
  144. )
  145. except Exception as e:
  146. log.debug(f"Failed to add OAuth client for MCP tool server: {e}")
  147. continue
  148. return {
  149. "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
  150. }
  151. @router.post("/tool_servers/verify")
  152. async def verify_tool_servers_config(
  153. request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user)
  154. ):
  155. """
  156. Verify the connection to the tool server.
  157. """
  158. try:
  159. if form_data.type == "mcp":
  160. if form_data.auth_type == "oauth_2.1":
  161. discovery_urls = get_discovery_urls(form_data.url)
  162. async with aiohttp.ClientSession() as session:
  163. async with session.get(
  164. discovery_urls[0]
  165. ) as oauth_server_metadata_response:
  166. if oauth_server_metadata_response.status != 200:
  167. raise HTTPException(
  168. status_code=400,
  169. detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls[0]}",
  170. )
  171. try:
  172. oauth_server_metadata = OAuthMetadata.model_validate(
  173. await oauth_server_metadata_response.json()
  174. )
  175. return {
  176. "status": True,
  177. "oauth_server_metadata": oauth_server_metadata.model_dump(
  178. mode="json"
  179. ),
  180. }
  181. except Exception as e:
  182. log.info(
  183. f"Failed to parse OAuth 2.1 discovery document: {e}"
  184. )
  185. raise HTTPException(
  186. status_code=400,
  187. detail=f"Failed to parse OAuth 2.1 discovery document from {discovery_urls[0]}",
  188. )
  189. raise HTTPException(
  190. status_code=400,
  191. detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls[0]}",
  192. )
  193. else:
  194. try:
  195. client = MCPClient()
  196. headers = None
  197. token = None
  198. if form_data.auth_type == "bearer":
  199. token = form_data.key
  200. elif form_data.auth_type == "session":
  201. token = request.state.token.credentials
  202. elif form_data.auth_type == "system_oauth":
  203. try:
  204. if request.cookies.get("oauth_session_id", None):
  205. token = await request.app.state.oauth_manager.get_oauth_token(
  206. user.id,
  207. request.cookies.get("oauth_session_id", None),
  208. )
  209. except Exception as e:
  210. pass
  211. if token:
  212. headers = {"Authorization": f"Bearer {token}"}
  213. await client.connect(form_data.url, headers=headers)
  214. specs = await client.list_tool_specs()
  215. return {
  216. "status": True,
  217. "specs": specs,
  218. }
  219. except Exception as e:
  220. log.debug(f"Failed to create MCP client: {e}")
  221. raise HTTPException(
  222. status_code=400,
  223. detail=f"Failed to create MCP client",
  224. )
  225. finally:
  226. if client:
  227. await client.disconnect()
  228. else: # openapi
  229. token = None
  230. if form_data.auth_type == "bearer":
  231. token = form_data.key
  232. elif form_data.auth_type == "session":
  233. token = request.state.token.credentials
  234. elif form_data.auth_type == "system_oauth":
  235. try:
  236. if request.cookies.get("oauth_session_id", None):
  237. token = await request.app.state.oauth_manager.get_oauth_token(
  238. user.id,
  239. request.cookies.get("oauth_session_id", None),
  240. )
  241. except Exception as e:
  242. pass
  243. url = get_tool_server_url(form_data.url, form_data.path)
  244. return await get_tool_server_data(token, url)
  245. except HTTPException as e:
  246. raise e
  247. except Exception as e:
  248. log.debug(f"Failed to connect to the tool server: {e}")
  249. raise HTTPException(
  250. status_code=400,
  251. detail=f"Failed to connect to the tool server",
  252. )
  253. ############################
  254. # CodeInterpreterConfig
  255. ############################
  256. class CodeInterpreterConfigForm(BaseModel):
  257. ENABLE_CODE_EXECUTION: bool
  258. CODE_EXECUTION_ENGINE: str
  259. CODE_EXECUTION_JUPYTER_URL: Optional[str]
  260. CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
  261. CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
  262. CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
  263. CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int]
  264. ENABLE_CODE_INTERPRETER: bool
  265. CODE_INTERPRETER_ENGINE: str
  266. CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
  267. CODE_INTERPRETER_JUPYTER_URL: Optional[str]
  268. CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
  269. CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
  270. CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
  271. CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int]
  272. @router.get("/code_execution", response_model=CodeInterpreterConfigForm)
  273. async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
  274. return {
  275. "ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
  276. "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
  277. "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
  278. "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
  279. "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
  280. "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
  281. "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
  282. "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
  283. "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
  284. "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
  285. "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
  286. "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
  287. "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
  288. "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
  289. "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
  290. }
  291. @router.post("/code_execution", response_model=CodeInterpreterConfigForm)
  292. async def set_code_execution_config(
  293. request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
  294. ):
  295. request.app.state.config.ENABLE_CODE_EXECUTION = form_data.ENABLE_CODE_EXECUTION
  296. request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
  297. request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
  298. form_data.CODE_EXECUTION_JUPYTER_URL
  299. )
  300. request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
  301. form_data.CODE_EXECUTION_JUPYTER_AUTH
  302. )
  303. request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
  304. form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
  305. )
  306. request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
  307. form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
  308. )
  309. request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = (
  310. form_data.CODE_EXECUTION_JUPYTER_TIMEOUT
  311. )
  312. request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
  313. request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
  314. request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
  315. form_data.CODE_INTERPRETER_PROMPT_TEMPLATE
  316. )
  317. request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = (
  318. form_data.CODE_INTERPRETER_JUPYTER_URL
  319. )
  320. request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = (
  321. form_data.CODE_INTERPRETER_JUPYTER_AUTH
  322. )
  323. request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
  324. form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
  325. )
  326. request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
  327. form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
  328. )
  329. request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = (
  330. form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT
  331. )
  332. return {
  333. "ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
  334. "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
  335. "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
  336. "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
  337. "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
  338. "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
  339. "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
  340. "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
  341. "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
  342. "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
  343. "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
  344. "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
  345. "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
  346. "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
  347. "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
  348. }
  349. ############################
  350. # SetDefaultModels
  351. ############################
  352. class ModelsConfigForm(BaseModel):
  353. DEFAULT_MODELS: Optional[str]
  354. MODEL_ORDER_LIST: Optional[list[str]]
  355. @router.get("/models", response_model=ModelsConfigForm)
  356. async def get_models_config(request: Request, user=Depends(get_admin_user)):
  357. return {
  358. "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
  359. "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
  360. }
  361. @router.post("/models", response_model=ModelsConfigForm)
  362. async def set_models_config(
  363. request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user)
  364. ):
  365. request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS
  366. request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST
  367. return {
  368. "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
  369. "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
  370. }
  371. class PromptSuggestion(BaseModel):
  372. title: list[str]
  373. content: str
  374. class SetDefaultSuggestionsForm(BaseModel):
  375. suggestions: list[PromptSuggestion]
  376. @router.post("/suggestions", response_model=list[PromptSuggestion])
  377. async def set_default_suggestions(
  378. request: Request,
  379. form_data: SetDefaultSuggestionsForm,
  380. user=Depends(get_admin_user),
  381. ):
  382. data = form_data.model_dump()
  383. request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
  384. return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS
  385. ############################
  386. # SetBanners
  387. ############################
  388. class SetBannersForm(BaseModel):
  389. banners: list[BannerModel]
  390. @router.post("/banners", response_model=list[BannerModel])
  391. async def set_banners(
  392. request: Request,
  393. form_data: SetBannersForm,
  394. user=Depends(get_admin_user),
  395. ):
  396. data = form_data.model_dump()
  397. request.app.state.config.BANNERS = data["banners"]
  398. return request.app.state.config.BANNERS
  399. @router.get("/banners", response_model=list[BannerModel])
  400. async def get_banners(
  401. request: Request,
  402. user=Depends(get_verified_user),
  403. ):
  404. return request.app.state.config.BANNERS