main.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
  1. import asyncio
  2. import random
  3. import socketio
  4. import logging
  5. import sys
  6. import time
  7. from typing import Dict, Set
  8. from redis import asyncio as aioredis
  9. import pycrdt as Y
  10. from open_webui.models.users import Users, UserNameResponse
  11. from open_webui.models.channels import Channels
  12. from open_webui.models.chats import Chats
  13. from open_webui.models.notes import Notes, NoteUpdateForm
  14. from open_webui.utils.redis import (
  15. get_sentinels_from_env,
  16. get_sentinel_url_from_env,
  17. )
  18. from open_webui.env import (
  19. ENABLE_WEBSOCKET_SUPPORT,
  20. WEBSOCKET_MANAGER,
  21. WEBSOCKET_REDIS_URL,
  22. WEBSOCKET_REDIS_LOCK_TIMEOUT,
  23. WEBSOCKET_SENTINEL_PORT,
  24. WEBSOCKET_SENTINEL_HOSTS,
  25. )
  26. from open_webui.utils.auth import decode_token
  27. from open_webui.socket.utils import RedisDict, RedisLock
  28. from open_webui.tasks import create_task, stop_item_tasks
  29. from open_webui.utils.redis import get_redis_connection
  30. from open_webui.utils.access_control import has_access, get_users_with_access
  31. from open_webui.env import (
  32. GLOBAL_LOG_LEVEL,
  33. SRC_LOG_LEVELS,
  34. )
# Configure root logging to write to stdout at the globally configured level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
# This module's logger takes its level from the "SOCKET" entry of SRC_LOG_LEVELS.
log.setLevel(SRC_LOG_LEVELS["SOCKET"])

# Redis connection built from the websocket Redis URL / sentinel settings with
# async_mode=True. In this file it is handed to create_task() and
# stop_item_tasks() for the debounced document saves.
REDIS = get_redis_connection(
    redis_url=WEBSOCKET_REDIS_URL,
    redis_sentinels=get_sentinels_from_env(
        WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
    ),
    async_mode=True,
)
  45. if WEBSOCKET_MANAGER == "redis":
  46. if WEBSOCKET_SENTINEL_HOSTS:
  47. mgr = socketio.AsyncRedisManager(
  48. get_sentinel_url_from_env(
  49. WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
  50. )
  51. )
  52. else:
  53. mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
  54. sio = socketio.AsyncServer(
  55. cors_allowed_origins=[],
  56. async_mode="asgi",
  57. transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
  58. allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
  59. always_connect=True,
  60. client_manager=mgr,
  61. )
  62. else:
  63. sio = socketio.AsyncServer(
  64. cors_allowed_origins=[],
  65. async_mode="asgi",
  66. transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
  67. allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
  68. always_connect=True,
  69. )
# Seconds after which a usage entry counts as stale; also the pause between
# two passes of periodic_usage_pool_cleanup().
TIMEOUT_DURATION = 3

# Shared state used by the handlers below:
#   SESSION_POOL   sid -> user.model_dump() of the authenticated user
#   USER_POOL      user id -> list of sids belonging to that user
#   USAGE_POOL     model id -> {sid: {"updated_at": unix timestamp}}
#   DOCUMENTS      document_id -> {"ydoc": Y.Doc, "users": set of sids}
#   DOCUMENT_USERS document_id -> set of sids
if WEBSOCKET_MANAGER == "redis":
    log.debug("Using Redis to manage websockets.")
    redis_sentinels = get_sentinels_from_env(
        WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
    )
    # The three pools are RedisDict instances under distinct key names.
    SESSION_POOL = RedisDict(
        "open-webui:session_pool",
        redis_url=WEBSOCKET_REDIS_URL,
        redis_sentinels=redis_sentinels,
    )
    USER_POOL = RedisDict(
        "open-webui:user_pool",
        redis_url=WEBSOCKET_REDIS_URL,
        redis_sentinels=redis_sentinels,
    )
    USAGE_POOL = RedisDict(
        "open-webui:usage_pool",
        redis_url=WEBSOCKET_REDIS_URL,
        redis_sentinels=redis_sentinels,
    )

    # TODO: Implement Yjs document management with Redis
    # Until then the document state is a plain in-process dict even in Redis mode.
    DOCUMENTS = {}
    DOCUMENT_USERS = {}

    # Lock used by periodic_usage_pool_cleanup() through the three hooks below.
    clean_up_lock = RedisLock(
        redis_url=WEBSOCKET_REDIS_URL,
        lock_name="usage_cleanup_lock",
        timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
        redis_sentinels=redis_sentinels,
    )
    # "aquire" (sic) is the spelling of the RedisLock method being referenced.
    aquire_func = clean_up_lock.aquire_lock
    renew_func = clean_up_lock.renew_lock
    release_func = clean_up_lock.release_lock
else:
    SESSION_POOL = {}
    USER_POOL = {}
    USAGE_POOL = {}

    DOCUMENTS = {}  # document_id -> {"ydoc": Y.Doc, "users": set of sids}
    DOCUMENT_USERS = {}  # document_id -> set of user sids

    # Without Redis there is no lock: all three hooks are the same lambda that
    # always returns True.
    aquire_func = release_func = renew_func = lambda: True
  112. async def periodic_usage_pool_cleanup():
  113. max_retries = 2
  114. retry_delay = random.uniform(
  115. WEBSOCKET_REDIS_LOCK_TIMEOUT / 2, WEBSOCKET_REDIS_LOCK_TIMEOUT
  116. )
  117. for attempt in range(max_retries + 1):
  118. if aquire_func():
  119. break
  120. else:
  121. if attempt < max_retries:
  122. log.debug(
  123. f"Cleanup lock already exists. Retry {attempt + 1} after {retry_delay}s..."
  124. )
  125. await asyncio.sleep(retry_delay)
  126. else:
  127. log.warning(
  128. "Failed to acquire cleanup lock after retries. Skipping cleanup."
  129. )
  130. return
  131. log.debug("Running periodic_cleanup")
  132. try:
  133. while True:
  134. if not renew_func():
  135. log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.")
  136. raise Exception("Unable to renew usage pool cleanup lock.")
  137. now = int(time.time())
  138. send_usage = False
  139. for model_id, connections in list(USAGE_POOL.items()):
  140. # Creating a list of sids to remove if they have timed out
  141. expired_sids = [
  142. sid
  143. for sid, details in connections.items()
  144. if now - details["updated_at"] > TIMEOUT_DURATION
  145. ]
  146. for sid in expired_sids:
  147. del connections[sid]
  148. if not connections:
  149. log.debug(f"Cleaning up model {model_id} from usage pool")
  150. del USAGE_POOL[model_id]
  151. else:
  152. USAGE_POOL[model_id] = connections
  153. send_usage = True
  154. await asyncio.sleep(TIMEOUT_DURATION)
  155. finally:
  156. release_func()
# ASGI application wrapping the Socket.IO server, with the Socket.IO path set
# to /ws/socket.io.
app = socketio.ASGIApp(
    sio,
    socketio_path="/ws/socket.io",
)
  161. def get_models_in_use():
  162. # List models that are currently in use
  163. models_in_use = list(USAGE_POOL.keys())
  164. return models_in_use
  165. def get_active_user_ids():
  166. """Get the list of active user IDs."""
  167. return list(USER_POOL.keys())
  168. def get_user_active_status(user_id):
  169. """Check if a user is currently active."""
  170. return user_id in USER_POOL
  171. def get_user_id_from_session_pool(sid):
  172. user = SESSION_POOL.get(sid)
  173. if user:
  174. return user["id"]
  175. return None
  176. def get_user_ids_from_room(room):
  177. active_session_ids = sio.manager.get_participants(
  178. namespace="/",
  179. room=room,
  180. )
  181. active_user_ids = list(
  182. set(
  183. [SESSION_POOL.get(session_id[0])["id"] for session_id in active_session_ids]
  184. )
  185. )
  186. return active_user_ids
  187. def get_active_status_by_user_id(user_id):
  188. if user_id in USER_POOL:
  189. return True
  190. return False
  191. @sio.on("usage")
  192. async def usage(sid, data):
  193. if sid in SESSION_POOL:
  194. model_id = data["model"]
  195. # Record the timestamp for the last update
  196. current_time = int(time.time())
  197. # Store the new usage data and task
  198. USAGE_POOL[model_id] = {
  199. **(USAGE_POOL[model_id] if model_id in USAGE_POOL else {}),
  200. sid: {"updated_at": current_time},
  201. }
  202. @sio.event
  203. async def connect(sid, environ, auth):
  204. user = None
  205. if auth and "token" in auth:
  206. data = decode_token(auth["token"])
  207. if data is not None and "id" in data:
  208. user = Users.get_user_by_id(data["id"])
  209. if user:
  210. SESSION_POOL[sid] = user.model_dump()
  211. if user.id in USER_POOL:
  212. USER_POOL[user.id] = USER_POOL[user.id] + [sid]
  213. else:
  214. USER_POOL[user.id] = [sid]
  215. @sio.on("user-join")
  216. async def user_join(sid, data):
  217. auth = data["auth"] if "auth" in data else None
  218. if not auth or "token" not in auth:
  219. return
  220. data = decode_token(auth["token"])
  221. if data is None or "id" not in data:
  222. return
  223. user = Users.get_user_by_id(data["id"])
  224. if not user:
  225. return
  226. SESSION_POOL[sid] = user.model_dump()
  227. if user.id in USER_POOL:
  228. USER_POOL[user.id] = USER_POOL[user.id] + [sid]
  229. else:
  230. USER_POOL[user.id] = [sid]
  231. # Join all the channels
  232. channels = Channels.get_channels_by_user_id(user.id)
  233. log.debug(f"{channels=}")
  234. for channel in channels:
  235. await sio.enter_room(sid, f"channel:{channel.id}")
  236. return {"id": user.id, "name": user.name}
  237. @sio.on("join-channels")
  238. async def join_channel(sid, data):
  239. auth = data["auth"] if "auth" in data else None
  240. if not auth or "token" not in auth:
  241. return
  242. data = decode_token(auth["token"])
  243. if data is None or "id" not in data:
  244. return
  245. user = Users.get_user_by_id(data["id"])
  246. if not user:
  247. return
  248. # Join all the channels
  249. channels = Channels.get_channels_by_user_id(user.id)
  250. log.debug(f"{channels=}")
  251. for channel in channels:
  252. await sio.enter_room(sid, f"channel:{channel.id}")
  253. @sio.on("channel-events")
  254. async def channel_events(sid, data):
  255. room = f"channel:{data['channel_id']}"
  256. participants = sio.manager.get_participants(
  257. namespace="/",
  258. room=room,
  259. )
  260. sids = [sid for sid, _ in participants]
  261. if sid not in sids:
  262. return
  263. event_data = data["data"]
  264. event_type = event_data["type"]
  265. if event_type == "typing":
  266. await sio.emit(
  267. "channel-events",
  268. {
  269. "channel_id": data["channel_id"],
  270. "message_id": data.get("message_id", None),
  271. "data": event_data,
  272. "user": UserNameResponse(**SESSION_POOL[sid]).model_dump(),
  273. },
  274. room=room,
  275. )
  276. @sio.on("yjs:document:join")
  277. async def yjs_document_join(sid, data):
  278. """Handle user joining a document"""
  279. user = SESSION_POOL.get(sid)
  280. try:
  281. document_id = data["document_id"]
  282. if document_id.startswith("note:"):
  283. note_id = document_id.split(":")[1]
  284. note = Notes.get_note_by_id(note_id)
  285. if not note:
  286. log.error(f"Note {note_id} not found")
  287. return
  288. if user.get("role") != "admin" and has_access(
  289. user.get("id"), type="read", access_control=note.access_control
  290. ):
  291. log.error(
  292. f"User {user.get('id')} does not have access to note {note_id}"
  293. )
  294. return
  295. user_id = data.get("user_id", sid)
  296. user_name = data.get("user_name", "Anonymous")
  297. user_color = data.get("user_color", "#000000")
  298. log.info(f"User {user_id} joining document {document_id}")
  299. # Initialize document if it doesn't exist
  300. if document_id not in DOCUMENTS:
  301. DOCUMENTS[document_id] = {
  302. "ydoc": Y.Doc(), # Create actual Yjs document
  303. "users": set(),
  304. }
  305. DOCUMENT_USERS[document_id] = set()
  306. # Add user to document
  307. DOCUMENTS[document_id]["users"].add(sid)
  308. DOCUMENT_USERS[document_id].add(sid)
  309. # Join Socket.IO room
  310. await sio.enter_room(sid, f"doc_{document_id}")
  311. # Send current document state as a proper Yjs update
  312. ydoc = DOCUMENTS[document_id]["ydoc"]
  313. # Encode the entire document state as an update
  314. state_update = ydoc.get_update()
  315. await sio.emit(
  316. "yjs:document:state",
  317. {
  318. "document_id": document_id,
  319. "state": list(state_update), # Convert bytes to list for JSON
  320. },
  321. room=sid,
  322. )
  323. # Notify other users about the new user
  324. await sio.emit(
  325. "yjs:user:joined",
  326. {
  327. "document_id": document_id,
  328. "user_id": user_id,
  329. "user_name": user_name,
  330. "user_color": user_color,
  331. },
  332. room=f"doc_{document_id}",
  333. skip_sid=sid,
  334. )
  335. log.info(f"User {user_id} successfully joined document {document_id}")
  336. except Exception as e:
  337. log.error(f"Error in yjs_document_join: {e}")
  338. await sio.emit("error", {"message": "Failed to join document"}, room=sid)
  339. async def document_save_handler(document_id, data, user):
  340. if document_id.startswith("note:"):
  341. note_id = document_id.split(":")[1]
  342. note = Notes.get_note_by_id(note_id)
  343. if not note:
  344. log.error(f"Note {note_id} not found")
  345. return
  346. if user.get("role") != "admin" and has_access(
  347. user.get("id"), type="read", access_control=note.access_control
  348. ):
  349. log.error(f"User {user.get('id')} does not have access to note {note_id}")
  350. return
  351. Notes.update_note_by_id(note_id, NoteUpdateForm(data=data))
  352. @sio.on("yjs:document:update")
  353. async def yjs_document_update(sid, data):
  354. """Handle Yjs document updates"""
  355. try:
  356. document_id = data["document_id"]
  357. await stop_item_tasks(REDIS, document_id)
  358. user_id = data.get("user_id", sid)
  359. update = data["update"] # List of bytes from frontend
  360. if document_id not in DOCUMENTS:
  361. log.warning(f"Document {document_id} not found")
  362. return
  363. # Apply the update to the server's Yjs document
  364. ydoc = DOCUMENTS[document_id]["ydoc"]
  365. update_bytes = bytes(update)
  366. try:
  367. ydoc.apply_update(update_bytes)
  368. except Exception as e:
  369. log.error(f"Failed to apply Yjs update: {e}")
  370. return
  371. # Broadcast update to all other users in the document
  372. await sio.emit(
  373. "yjs:document:update",
  374. {
  375. "document_id": document_id,
  376. "user_id": user_id,
  377. "update": update,
  378. "socket_id": sid, # Add socket_id to match frontend filtering
  379. },
  380. room=f"doc_{document_id}",
  381. skip_sid=sid,
  382. )
  383. async def debounced_save():
  384. await asyncio.sleep(0.5)
  385. await document_save_handler(
  386. document_id, data.get("data", {}), SESSION_POOL.get(sid)
  387. )
  388. await stop_item_tasks(REDIS, document_id) # Cancel previous in-flight save
  389. await create_task(REDIS, debounced_save(), document_id)
  390. except Exception as e:
  391. log.error(f"Error in yjs_document_update: {e}")
  392. @sio.on("yjs:document:leave")
  393. async def yjs_document_leave(sid, data):
  394. """Handle user leaving a document"""
  395. try:
  396. document_id = data["document_id"]
  397. user_id = data.get("user_id", sid)
  398. log.info(f"User {user_id} leaving document {document_id}")
  399. if document_id in DOCUMENTS:
  400. DOCUMENTS[document_id]["users"].discard(sid)
  401. if document_id in DOCUMENT_USERS:
  402. DOCUMENT_USERS[document_id].discard(sid)
  403. # Leave Socket.IO room
  404. await sio.leave_room(sid, f"doc_{document_id}")
  405. # Notify other users
  406. await sio.emit(
  407. "yjs:user:left",
  408. {"document_id": document_id, "user_id": user_id},
  409. room=f"doc_{document_id}",
  410. )
  411. if document_id in DOCUMENTS and not DOCUMENTS[document_id]["users"]:
  412. # If no users left, clean up the document
  413. log.info(f"Cleaning up document {document_id} as no users are left")
  414. del DOCUMENTS[document_id]
  415. del DOCUMENT_USERS[document_id]
  416. except Exception as e:
  417. log.error(f"Error in yjs_document_leave: {e}")
  418. @sio.on("yjs:awareness:update")
  419. async def yjs_awareness_update(sid, data):
  420. """Handle awareness updates (cursors, selections, etc.)"""
  421. try:
  422. document_id = data["document_id"]
  423. user_id = data.get("user_id", sid)
  424. update = data["update"]
  425. # Broadcast awareness update to all other users in the document
  426. await sio.emit(
  427. "yjs:awareness:update",
  428. {"document_id": document_id, "user_id": user_id, "update": update},
  429. room=f"doc_{document_id}",
  430. skip_sid=sid,
  431. )
  432. except Exception as e:
  433. log.error(f"Error in yjs_awareness_update: {e}")
  434. @sio.event
  435. async def disconnect(sid):
  436. if sid in SESSION_POOL:
  437. user = SESSION_POOL[sid]
  438. del SESSION_POOL[sid]
  439. user_id = user["id"]
  440. USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
  441. if len(USER_POOL[user_id]) == 0:
  442. del USER_POOL[user_id]
  443. else:
  444. pass
  445. # print(f"Unknown session ID {sid} disconnected")
  446. def get_event_emitter(request_info, update_db=True):
  447. async def __event_emitter__(event_data):
  448. user_id = request_info["user_id"]
  449. session_ids = list(
  450. set(
  451. USER_POOL.get(user_id, [])
  452. + (
  453. [request_info.get("session_id")]
  454. if request_info.get("session_id")
  455. else []
  456. )
  457. )
  458. )
  459. emit_tasks = [
  460. sio.emit(
  461. "chat-events",
  462. {
  463. "chat_id": request_info.get("chat_id", None),
  464. "message_id": request_info.get("message_id", None),
  465. "data": event_data,
  466. },
  467. to=session_id,
  468. )
  469. for session_id in session_ids
  470. ]
  471. await asyncio.gather(*emit_tasks)
  472. if update_db:
  473. if "type" in event_data and event_data["type"] == "status":
  474. Chats.add_message_status_to_chat_by_id_and_message_id(
  475. request_info["chat_id"],
  476. request_info["message_id"],
  477. event_data.get("data", {}),
  478. )
  479. if "type" in event_data and event_data["type"] == "message":
  480. message = Chats.get_message_by_id_and_message_id(
  481. request_info["chat_id"],
  482. request_info["message_id"],
  483. )
  484. if message:
  485. content = message.get("content", "")
  486. content += event_data.get("data", {}).get("content", "")
  487. Chats.upsert_message_to_chat_by_id_and_message_id(
  488. request_info["chat_id"],
  489. request_info["message_id"],
  490. {
  491. "content": content,
  492. },
  493. )
  494. if "type" in event_data and event_data["type"] == "replace":
  495. content = event_data.get("data", {}).get("content", "")
  496. Chats.upsert_message_to_chat_by_id_and_message_id(
  497. request_info["chat_id"],
  498. request_info["message_id"],
  499. {
  500. "content": content,
  501. },
  502. )
  503. return __event_emitter__
  504. def get_event_call(request_info):
  505. async def __event_caller__(event_data):
  506. response = await sio.call(
  507. "chat-events",
  508. {
  509. "chat_id": request_info.get("chat_id", None),
  510. "message_id": request_info.get("message_id", None),
  511. "data": event_data,
  512. },
  513. to=request_info["session_id"],
  514. )
  515. return response
  516. return __event_caller__
  517. get_event_caller = get_event_call