# configs.py: admin configuration endpoints (import/export, direct connections,
# tool servers, code execution, models, prompt suggestions, banners, SCIM).
import logging
import secrets
import string
from datetime import datetime, timedelta
from typing import Optional

from fastapi import APIRouter, Depends, Request, HTTPException
from pydantic import BaseModel, ConfigDict

from open_webui.config import BannerModel
from open_webui.config import get_config, save_config
from open_webui.env import WEBUI_AUTH
from open_webui.models.groups import Groups
from open_webui.models.users import Users
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data
  14. router = APIRouter()
  15. ############################
  16. # ImportConfig
  17. ############################
  18. class ImportConfigForm(BaseModel):
  19. config: dict
  20. @router.post("/import", response_model=dict)
  21. async def import_config(form_data: ImportConfigForm, user=Depends(get_admin_user)):
  22. save_config(form_data.config)
  23. return get_config()
  24. ############################
  25. # ExportConfig
  26. ############################
  27. @router.get("/export", response_model=dict)
  28. async def export_config(user=Depends(get_admin_user)):
  29. return get_config()
  30. ############################
  31. # Direct Connections Config
  32. ############################
  33. class DirectConnectionsConfigForm(BaseModel):
  34. ENABLE_DIRECT_CONNECTIONS: bool
  35. @router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
  36. async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
  37. return {
  38. "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
  39. }
  40. @router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
  41. async def set_direct_connections_config(
  42. request: Request,
  43. form_data: DirectConnectionsConfigForm,
  44. user=Depends(get_admin_user),
  45. ):
  46. request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
  47. form_data.ENABLE_DIRECT_CONNECTIONS
  48. )
  49. return {
  50. "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
  51. }
  52. ############################
  53. # ToolServers Config
  54. ############################
  55. class ToolServerConnection(BaseModel):
  56. url: str
  57. path: str
  58. auth_type: Optional[str]
  59. key: Optional[str]
  60. config: Optional[dict]
  61. model_config = ConfigDict(extra="allow")
  62. class ToolServersConfigForm(BaseModel):
  63. TOOL_SERVER_CONNECTIONS: list[ToolServerConnection]
  64. @router.get("/tool_servers", response_model=ToolServersConfigForm)
  65. async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)):
  66. return {
  67. "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
  68. }
  69. @router.post("/tool_servers", response_model=ToolServersConfigForm)
  70. async def set_tool_servers_config(
  71. request: Request,
  72. form_data: ToolServersConfigForm,
  73. user=Depends(get_admin_user),
  74. ):
  75. request.app.state.config.TOOL_SERVER_CONNECTIONS = [
  76. connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
  77. ]
  78. request.app.state.TOOL_SERVERS = await get_tool_servers_data(
  79. request.app.state.config.TOOL_SERVER_CONNECTIONS
  80. )
  81. return {
  82. "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
  83. }
  84. @router.post("/tool_servers/verify")
  85. async def verify_tool_servers_config(
  86. request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user)
  87. ):
  88. """
  89. Verify the connection to the tool server.
  90. """
  91. try:
  92. token = None
  93. if form_data.auth_type == "bearer":
  94. token = form_data.key
  95. elif form_data.auth_type == "session":
  96. token = request.state.token.credentials
  97. url = f"{form_data.url}/{form_data.path}"
  98. return await get_tool_server_data(token, url)
  99. except Exception as e:
  100. raise HTTPException(
  101. status_code=400,
  102. detail=f"Failed to connect to the tool server: {str(e)}",
  103. )
  104. ############################
  105. # CodeInterpreterConfig
  106. ############################
  107. class CodeInterpreterConfigForm(BaseModel):
  108. ENABLE_CODE_EXECUTION: bool
  109. CODE_EXECUTION_ENGINE: str
  110. CODE_EXECUTION_JUPYTER_URL: Optional[str]
  111. CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
  112. CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
  113. CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
  114. CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int]
  115. ENABLE_CODE_INTERPRETER: bool
  116. CODE_INTERPRETER_ENGINE: str
  117. CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
  118. CODE_INTERPRETER_JUPYTER_URL: Optional[str]
  119. CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
  120. CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
  121. CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
  122. CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int]
  123. @router.get("/code_execution", response_model=CodeInterpreterConfigForm)
  124. async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
  125. return {
  126. "ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
  127. "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
  128. "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
  129. "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
  130. "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
  131. "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
  132. "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
  133. "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
  134. "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
  135. "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
  136. "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
  137. "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
  138. "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
  139. "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
  140. "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
  141. }
  142. @router.post("/code_execution", response_model=CodeInterpreterConfigForm)
  143. async def set_code_execution_config(
  144. request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
  145. ):
  146. request.app.state.config.ENABLE_CODE_EXECUTION = form_data.ENABLE_CODE_EXECUTION
  147. request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
  148. request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
  149. form_data.CODE_EXECUTION_JUPYTER_URL
  150. )
  151. request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
  152. form_data.CODE_EXECUTION_JUPYTER_AUTH
  153. )
  154. request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
  155. form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
  156. )
  157. request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
  158. form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
  159. )
  160. request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = (
  161. form_data.CODE_EXECUTION_JUPYTER_TIMEOUT
  162. )
  163. request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
  164. request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
  165. request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
  166. form_data.CODE_INTERPRETER_PROMPT_TEMPLATE
  167. )
  168. request.app.state.config.CODE_INTERPRETER_JUPYTER_URL = (
  169. form_data.CODE_INTERPRETER_JUPYTER_URL
  170. )
  171. request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH = (
  172. form_data.CODE_INTERPRETER_JUPYTER_AUTH
  173. )
  174. request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN = (
  175. form_data.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
  176. )
  177. request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
  178. form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
  179. )
  180. request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = (
  181. form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT
  182. )
  183. return {
  184. "ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
  185. "CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
  186. "CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
  187. "CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
  188. "CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
  189. "CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
  190. "CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
  191. "ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
  192. "CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
  193. "CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
  194. "CODE_INTERPRETER_JUPYTER_URL": request.app.state.config.CODE_INTERPRETER_JUPYTER_URL,
  195. "CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
  196. "CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
  197. "CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
  198. "CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
  199. }
  200. ############################
  201. # SetDefaultModels
  202. ############################
  203. class ModelsConfigForm(BaseModel):
  204. DEFAULT_MODELS: Optional[str]
  205. MODEL_ORDER_LIST: Optional[list[str]]
  206. @router.get("/models", response_model=ModelsConfigForm)
  207. async def get_models_config(request: Request, user=Depends(get_admin_user)):
  208. return {
  209. "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
  210. "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
  211. }
  212. @router.post("/models", response_model=ModelsConfigForm)
  213. async def set_models_config(
  214. request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user)
  215. ):
  216. request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS
  217. request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST
  218. return {
  219. "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
  220. "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
  221. }
  222. class PromptSuggestion(BaseModel):
  223. title: list[str]
  224. content: str
  225. class SetDefaultSuggestionsForm(BaseModel):
  226. suggestions: list[PromptSuggestion]
  227. @router.post("/suggestions", response_model=list[PromptSuggestion])
  228. async def set_default_suggestions(
  229. request: Request,
  230. form_data: SetDefaultSuggestionsForm,
  231. user=Depends(get_admin_user),
  232. ):
  233. data = form_data.model_dump()
  234. request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
  235. return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS
  236. ############################
  237. # SetBanners
  238. ############################
  239. class SetBannersForm(BaseModel):
  240. banners: list[BannerModel]
  241. @router.post("/banners", response_model=list[BannerModel])
  242. async def set_banners(
  243. request: Request,
  244. form_data: SetBannersForm,
  245. user=Depends(get_admin_user),
  246. ):
  247. data = form_data.model_dump()
  248. request.app.state.config.BANNERS = data["banners"]
  249. return request.app.state.config.BANNERS
  250. @router.get("/banners", response_model=list[BannerModel])
  251. async def get_banners(
  252. request: Request,
  253. user=Depends(get_verified_user),
  254. ):
  255. return request.app.state.config.BANNERS
  256. ############################
  257. # SCIM Configuration
  258. ############################
  259. class SCIMConfigForm(BaseModel):
  260. enabled: bool
  261. token: Optional[str] = None
  262. token_created_at: Optional[str] = None
  263. token_expires_at: Optional[str] = None
  264. class SCIMTokenRequest(BaseModel):
  265. expires_in: Optional[int] = None # seconds until expiration, None = never
  266. class SCIMTokenResponse(BaseModel):
  267. token: str
  268. created_at: str
  269. expires_at: Optional[str] = None
  270. class SCIMStats(BaseModel):
  271. total_users: int
  272. total_groups: int
  273. last_sync: Optional[str] = None
  274. # In-memory storage for SCIM tokens (in production, use database)
  275. scim_tokens = {}
  276. def generate_scim_token(length: int = 48) -> str:
  277. """Generate a secure random token for SCIM authentication"""
  278. alphabet = string.ascii_letters + string.digits + "-_"
  279. return "".join(secrets.choice(alphabet) for _ in range(length))
  280. @router.get("/scim", response_model=SCIMConfigForm)
  281. async def get_scim_config(request: Request, user=Depends(get_admin_user)):
  282. """Get current SCIM configuration"""
  283. # Get token info from storage
  284. token_info = None
  285. scim_token = getattr(request.app.state.config, "SCIM_TOKEN", None)
  286. # Handle both PersistentConfig and direct value
  287. if hasattr(scim_token, 'value'):
  288. scim_token = scim_token.value
  289. if scim_token and scim_token in scim_tokens:
  290. token_info = scim_tokens[scim_token]
  291. scim_enabled = getattr(request.app.state.config, "SCIM_ENABLED", False)
  292. print(f"Getting SCIM config - raw SCIM_ENABLED: {scim_enabled}, type: {type(scim_enabled)}")
  293. # Handle both PersistentConfig and direct value
  294. if hasattr(scim_enabled, 'value'):
  295. scim_enabled = scim_enabled.value
  296. print(f"Returning SCIM config: enabled={scim_enabled}, token={'set' if scim_token else 'not set'}")
  297. return SCIMConfigForm(
  298. enabled=scim_enabled,
  299. token="***" if scim_token else None, # Don't expose actual token
  300. token_created_at=token_info.get("created_at") if token_info else None,
  301. token_expires_at=token_info.get("expires_at") if token_info else None,
  302. )
  303. @router.post("/scim", response_model=SCIMConfigForm)
  304. async def update_scim_config(request: Request, config: SCIMConfigForm, user=Depends(get_admin_user)):
  305. """Update SCIM configuration"""
  306. if not WEBUI_AUTH:
  307. raise HTTPException(400, detail="Authentication must be enabled for SCIM")
  308. print(f"Updating SCIM config: enabled={config.enabled}")
  309. # Import here to avoid circular import
  310. from open_webui.config import save_config, get_config
  311. # Get current config data
  312. config_data = get_config()
  313. # Update SCIM settings in config data
  314. if "scim" not in config_data:
  315. config_data["scim"] = {}
  316. config_data["scim"]["enabled"] = config.enabled
  317. # Save config to database
  318. save_config(config_data)
  319. # Also update the runtime config
  320. scim_enabled_attr = getattr(request.app.state.config, "SCIM_ENABLED", None)
  321. if scim_enabled_attr:
  322. if hasattr(scim_enabled_attr, 'value'):
  323. # It's a PersistentConfig object
  324. print(f"Updating PersistentConfig SCIM_ENABLED from {scim_enabled_attr.value} to {config.enabled}")
  325. scim_enabled_attr.value = config.enabled
  326. else:
  327. # Direct assignment
  328. print(f"Direct assignment SCIM_ENABLED to {config.enabled}")
  329. request.app.state.config.SCIM_ENABLED = config.enabled
  330. else:
  331. # Create if doesn't exist
  332. print(f"Creating SCIM_ENABLED with value {config.enabled}")
  333. request.app.state.config.SCIM_ENABLED = config.enabled
  334. # Return updated config
  335. return await get_scim_config(request=request, user=user)
  336. @router.post("/scim/token", response_model=SCIMTokenResponse)
  337. async def generate_scim_token_endpoint(
  338. request: Request, token_request: SCIMTokenRequest, user=Depends(get_admin_user)
  339. ):
  340. """Generate a new SCIM bearer token"""
  341. token = generate_scim_token()
  342. created_at = datetime.utcnow()
  343. expires_at = None
  344. if token_request.expires_in:
  345. expires_at = created_at + timedelta(seconds=token_request.expires_in)
  346. # Store token info
  347. token_info = {
  348. "token": token,
  349. "created_at": created_at.isoformat(),
  350. "expires_at": expires_at.isoformat() if expires_at else None,
  351. }
  352. scim_tokens[token] = token_info
  353. # Import here to avoid circular import
  354. from open_webui.config import save_config, get_config
  355. # Get current config data
  356. config_data = get_config()
  357. # Update SCIM token in config data
  358. if "scim" not in config_data:
  359. config_data["scim"] = {}
  360. config_data["scim"]["token"] = token
  361. # Save config to database
  362. save_config(config_data)
  363. # Also update the runtime config
  364. scim_token_attr = getattr(request.app.state.config, "SCIM_TOKEN", None)
  365. if scim_token_attr:
  366. if hasattr(scim_token_attr, 'value'):
  367. # It's a PersistentConfig object
  368. scim_token_attr.value = token
  369. else:
  370. # Direct assignment
  371. request.app.state.config.SCIM_TOKEN = token
  372. else:
  373. # Create if doesn't exist
  374. request.app.state.config.SCIM_TOKEN = token
  375. return SCIMTokenResponse(
  376. token=token,
  377. created_at=token_info["created_at"],
  378. expires_at=token_info["expires_at"],
  379. )
  380. @router.delete("/scim/token")
  381. async def revoke_scim_token(request: Request, user=Depends(get_admin_user)):
  382. """Revoke the current SCIM token"""
  383. # Get current token
  384. scim_token = getattr(request.app.state.config, "SCIM_TOKEN", None)
  385. if hasattr(scim_token, 'value'):
  386. scim_token = scim_token.value
  387. # Remove from storage
  388. if scim_token and scim_token in scim_tokens:
  389. del scim_tokens[scim_token]
  390. # Import here to avoid circular import
  391. from open_webui.config import save_config, get_config
  392. # Get current config data
  393. config_data = get_config()
  394. # Remove SCIM token from config data
  395. if "scim" in config_data:
  396. config_data["scim"]["token"] = None
  397. # Save config to database
  398. save_config(config_data)
  399. # Also update the runtime config
  400. scim_token_attr = getattr(request.app.state.config, "SCIM_TOKEN", None)
  401. if scim_token_attr:
  402. if hasattr(scim_token_attr, 'value'):
  403. # It's a PersistentConfig object
  404. scim_token_attr.value = None
  405. else:
  406. # Direct assignment
  407. request.app.state.config.SCIM_TOKEN = None
  408. return {"detail": "SCIM token revoked successfully"}
  409. @router.get("/scim/stats", response_model=SCIMStats)
  410. async def get_scim_stats(request: Request, user=Depends(get_admin_user)):
  411. """Get SCIM statistics"""
  412. users = Users.get_users()
  413. groups = Groups.get_groups()
  414. # Get last sync time (in production, track this properly)
  415. last_sync = None
  416. return SCIMStats(
  417. total_users=len(users),
  418. total_groups=len(groups) if groups else 0,
  419. last_sync=last_sync,
  420. )