main.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795
  1. import asyncio
  2. import random
  3. import socketio
  4. import logging
  5. import sys
  6. import time
  7. from typing import Dict, Set
  8. from redis import asyncio as aioredis
  9. import pycrdt as Y
  10. from open_webui.models.users import Users, UserNameResponse
  11. from open_webui.models.channels import Channels
  12. from open_webui.models.chats import Chats
  13. from open_webui.models.notes import Notes, NoteUpdateForm
  14. from open_webui.utils.redis import (
  15. get_sentinels_from_env,
  16. get_sentinel_url_from_env,
  17. )
  18. from open_webui.config import (
  19. CORS_ALLOW_ORIGIN,
  20. )
  21. from open_webui.env import (
  22. VERSION,
  23. ENABLE_WEBSOCKET_SUPPORT,
  24. WEBSOCKET_MANAGER,
  25. WEBSOCKET_REDIS_URL,
  26. WEBSOCKET_REDIS_CLUSTER,
  27. WEBSOCKET_REDIS_LOCK_TIMEOUT,
  28. WEBSOCKET_SENTINEL_PORT,
  29. WEBSOCKET_SENTINEL_HOSTS,
  30. REDIS_KEY_PREFIX,
  31. )
  32. from open_webui.utils.auth import decode_token
  33. from open_webui.socket.utils import RedisDict, RedisLock, YdocManager
  34. from open_webui.tasks import create_task, stop_item_tasks
  35. from open_webui.utils.redis import get_redis_connection
  36. from open_webui.utils.access_control import has_access, get_users_with_access
  37. from open_webui.env import (
  38. GLOBAL_LOG_LEVEL,
  39. SRC_LOG_LEVELS,
  40. )
  41. logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
  42. log = logging.getLogger(__name__)
  43. log.setLevel(SRC_LOG_LEVELS["SOCKET"])
  44. REDIS = None
  45. # Configure CORS for Socket.IO
  46. SOCKETIO_CORS_ORIGINS = "*" if CORS_ALLOW_ORIGIN == ["*"] else CORS_ALLOW_ORIGIN
  47. if WEBSOCKET_MANAGER == "redis":
  48. if WEBSOCKET_SENTINEL_HOSTS:
  49. mgr = socketio.AsyncRedisManager(
  50. get_sentinel_url_from_env(
  51. WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
  52. )
  53. )
  54. else:
  55. mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
  56. sio = socketio.AsyncServer(
  57. cors_allowed_origins=SOCKETIO_CORS_ORIGINS,
  58. async_mode="asgi",
  59. transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
  60. allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
  61. always_connect=True,
  62. client_manager=mgr,
  63. )
  64. else:
  65. sio = socketio.AsyncServer(
  66. cors_allowed_origins=SOCKETIO_CORS_ORIGINS,
  67. async_mode="asgi",
  68. transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
  69. allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
  70. always_connect=True,
  71. )
  72. # Timeout duration in seconds
  73. TIMEOUT_DURATION = 3
  74. # Dictionary to maintain the user pool
  75. if WEBSOCKET_MANAGER == "redis":
  76. log.debug("Using Redis to manage websockets.")
  77. REDIS = get_redis_connection(
  78. redis_url=WEBSOCKET_REDIS_URL,
  79. redis_sentinels=get_sentinels_from_env(
  80. WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
  81. ),
  82. redis_cluster=WEBSOCKET_REDIS_CLUSTER,
  83. async_mode=True,
  84. )
  85. redis_sentinels = get_sentinels_from_env(
  86. WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
  87. )
  88. SESSION_POOL = RedisDict(
  89. f"{REDIS_KEY_PREFIX}:session_pool",
  90. redis_url=WEBSOCKET_REDIS_URL,
  91. redis_sentinels=redis_sentinels,
  92. redis_cluster=WEBSOCKET_REDIS_CLUSTER,
  93. )
  94. USER_POOL = RedisDict(
  95. f"{REDIS_KEY_PREFIX}:user_pool",
  96. redis_url=WEBSOCKET_REDIS_URL,
  97. redis_sentinels=redis_sentinels,
  98. redis_cluster=WEBSOCKET_REDIS_CLUSTER,
  99. )
  100. USAGE_POOL = RedisDict(
  101. f"{REDIS_KEY_PREFIX}:usage_pool",
  102. redis_url=WEBSOCKET_REDIS_URL,
  103. redis_sentinels=redis_sentinels,
  104. redis_cluster=WEBSOCKET_REDIS_CLUSTER,
  105. )
  106. clean_up_lock = RedisLock(
  107. redis_url=WEBSOCKET_REDIS_URL,
  108. lock_name=f"{REDIS_KEY_PREFIX}:usage_cleanup_lock",
  109. timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
  110. redis_sentinels=redis_sentinels,
  111. redis_cluster=WEBSOCKET_REDIS_CLUSTER,
  112. )
  113. aquire_func = clean_up_lock.aquire_lock
  114. renew_func = clean_up_lock.renew_lock
  115. release_func = clean_up_lock.release_lock
  116. else:
  117. SESSION_POOL = {}
  118. USER_POOL = {}
  119. USAGE_POOL = {}
  120. aquire_func = release_func = renew_func = lambda: True
  121. YDOC_MANAGER = YdocManager(
  122. redis=REDIS,
  123. redis_key_prefix=f"{REDIS_KEY_PREFIX}:ydoc:documents",
  124. )
  125. async def periodic_usage_pool_cleanup():
  126. max_retries = 2
  127. retry_delay = random.uniform(
  128. WEBSOCKET_REDIS_LOCK_TIMEOUT / 2, WEBSOCKET_REDIS_LOCK_TIMEOUT
  129. )
  130. for attempt in range(max_retries + 1):
  131. if aquire_func():
  132. break
  133. else:
  134. if attempt < max_retries:
  135. log.debug(
  136. f"Cleanup lock already exists. Retry {attempt + 1} after {retry_delay}s..."
  137. )
  138. await asyncio.sleep(retry_delay)
  139. else:
  140. log.warning(
  141. "Failed to acquire cleanup lock after retries. Skipping cleanup."
  142. )
  143. return
  144. log.debug("Running periodic_cleanup")
  145. try:
  146. while True:
  147. if not renew_func():
  148. log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.")
  149. raise Exception("Unable to renew usage pool cleanup lock.")
  150. now = int(time.time())
  151. send_usage = False
  152. for model_id, connections in list(USAGE_POOL.items()):
  153. # Creating a list of sids to remove if they have timed out
  154. expired_sids = [
  155. sid
  156. for sid, details in connections.items()
  157. if now - details["updated_at"] > TIMEOUT_DURATION
  158. ]
  159. for sid in expired_sids:
  160. del connections[sid]
  161. if not connections:
  162. log.debug(f"Cleaning up model {model_id} from usage pool")
  163. del USAGE_POOL[model_id]
  164. else:
  165. USAGE_POOL[model_id] = connections
  166. send_usage = True
  167. await asyncio.sleep(TIMEOUT_DURATION)
  168. finally:
  169. release_func()
  170. app = socketio.ASGIApp(
  171. sio,
  172. socketio_path="/ws/socket.io",
  173. )
  174. def get_models_in_use():
  175. # List models that are currently in use
  176. models_in_use = list(USAGE_POOL.keys())
  177. return models_in_use
  178. def get_active_user_ids():
  179. """Get the list of active user IDs."""
  180. return list(USER_POOL.keys())
  181. def get_user_active_status(user_id):
  182. """Check if a user is currently active."""
  183. return user_id in USER_POOL
  184. def get_user_id_from_session_pool(sid):
  185. user = SESSION_POOL.get(sid)
  186. if user:
  187. return user["id"]
  188. return None
  189. def get_session_ids_from_room(room):
  190. """Get all session IDs from a specific room."""
  191. active_session_ids = sio.manager.get_participants(
  192. namespace="/",
  193. room=room,
  194. )
  195. return [session_id[0] for session_id in active_session_ids]
  196. def get_user_ids_from_room(room):
  197. active_session_ids = get_session_ids_from_room(room)
  198. active_user_ids = list(
  199. set([SESSION_POOL.get(session_id)["id"] for session_id in active_session_ids])
  200. )
  201. return active_user_ids
  202. def get_active_status_by_user_id(user_id):
  203. if user_id in USER_POOL:
  204. return True
  205. return False
  206. @sio.on("usage")
  207. async def usage(sid, data):
  208. if sid in SESSION_POOL:
  209. model_id = data["model"]
  210. # Record the timestamp for the last update
  211. current_time = int(time.time())
  212. # Store the new usage data and task
  213. USAGE_POOL[model_id] = {
  214. **(USAGE_POOL[model_id] if model_id in USAGE_POOL else {}),
  215. sid: {"updated_at": current_time},
  216. }
  217. @sio.event
  218. async def connect(sid, environ, auth):
  219. user = None
  220. if auth and "token" in auth:
  221. data = decode_token(auth["token"])
  222. if data is not None and "id" in data:
  223. user = Users.get_user_by_id(data["id"])
  224. if user:
  225. SESSION_POOL[sid] = user.model_dump(
  226. exclude=["date_of_birth", "bio", "gender"]
  227. )
  228. if user.id in USER_POOL:
  229. USER_POOL[user.id] = USER_POOL[user.id] + [sid]
  230. else:
  231. USER_POOL[user.id] = [sid]
  232. await sio.enter_room(sid, f"user:{user.id}")
  233. @sio.on("user-join")
  234. async def user_join(sid, data):
  235. auth = data["auth"] if "auth" in data else None
  236. if not auth or "token" not in auth:
  237. return
  238. data = decode_token(auth["token"])
  239. if data is None or "id" not in data:
  240. return
  241. user = Users.get_user_by_id(data["id"])
  242. if not user:
  243. return
  244. SESSION_POOL[sid] = user.model_dump(exclude=["date_of_birth", "bio", "gender"])
  245. if user.id in USER_POOL:
  246. USER_POOL[user.id] = USER_POOL[user.id] + [sid]
  247. else:
  248. USER_POOL[user.id] = [sid]
  249. await sio.enter_room(sid, f"user:{user.id}")
  250. # Join all the channels
  251. channels = Channels.get_channels_by_user_id(user.id)
  252. log.debug(f"{channels=}")
  253. for channel in channels:
  254. await sio.enter_room(sid, f"channel:{channel.id}")
  255. return {"id": user.id, "name": user.name}
  256. @sio.on("join-channels")
  257. async def join_channel(sid, data):
  258. auth = data["auth"] if "auth" in data else None
  259. if not auth or "token" not in auth:
  260. return
  261. data = decode_token(auth["token"])
  262. if data is None or "id" not in data:
  263. return
  264. user = Users.get_user_by_id(data["id"])
  265. if not user:
  266. return
  267. # Join all the channels
  268. channels = Channels.get_channels_by_user_id(user.id)
  269. log.debug(f"{channels=}")
  270. for channel in channels:
  271. await sio.enter_room(sid, f"channel:{channel.id}")
  272. @sio.on("join-note")
  273. async def join_note(sid, data):
  274. auth = data["auth"] if "auth" in data else None
  275. if not auth or "token" not in auth:
  276. return
  277. token_data = decode_token(auth["token"])
  278. if token_data is None or "id" not in token_data:
  279. return
  280. user = Users.get_user_by_id(token_data["id"])
  281. if not user:
  282. return
  283. note = Notes.get_note_by_id(data["note_id"])
  284. if not note:
  285. log.error(f"Note {data['note_id']} not found for user {user.id}")
  286. return
  287. if (
  288. user.role != "admin"
  289. and user.id != note.user_id
  290. and not has_access(user.id, type="read", access_control=note.access_control)
  291. ):
  292. log.error(f"User {user.id} does not have access to note {data['note_id']}")
  293. return
  294. log.debug(f"Joining note {note.id} for user {user.id}")
  295. await sio.enter_room(sid, f"note:{note.id}")
  296. @sio.on("events:channel")
  297. async def channel_events(sid, data):
  298. room = f"channel:{data['channel_id']}"
  299. participants = sio.manager.get_participants(
  300. namespace="/",
  301. room=room,
  302. )
  303. sids = [sid for sid, _ in participants]
  304. if sid not in sids:
  305. return
  306. event_data = data["data"]
  307. event_type = event_data["type"]
  308. if event_type == "typing":
  309. await sio.emit(
  310. "events:channel",
  311. {
  312. "channel_id": data["channel_id"],
  313. "message_id": data.get("message_id", None),
  314. "data": event_data,
  315. "user": UserNameResponse(**SESSION_POOL[sid]).model_dump(),
  316. },
  317. room=room,
  318. )
  319. @sio.on("ydoc:document:join")
  320. async def ydoc_document_join(sid, data):
  321. """Handle user joining a document"""
  322. user = SESSION_POOL.get(sid)
  323. try:
  324. document_id = data["document_id"]
  325. if document_id.startswith("note:"):
  326. note_id = document_id.split(":")[1]
  327. note = Notes.get_note_by_id(note_id)
  328. if not note:
  329. log.error(f"Note {note_id} not found")
  330. return
  331. if (
  332. user.get("role") != "admin"
  333. and user.get("id") != note.user_id
  334. and not has_access(
  335. user.get("id"), type="read", access_control=note.access_control
  336. )
  337. ):
  338. log.error(
  339. f"User {user.get('id')} does not have access to note {note_id}"
  340. )
  341. return
  342. user_id = data.get("user_id", sid)
  343. user_name = data.get("user_name", "Anonymous")
  344. user_color = data.get("user_color", "#000000")
  345. log.info(f"User {user_id} joining document {document_id}")
  346. await YDOC_MANAGER.add_user(document_id=document_id, user_id=sid)
  347. # Join Socket.IO room
  348. await sio.enter_room(sid, f"doc_{document_id}")
  349. active_session_ids = get_session_ids_from_room(f"doc_{document_id}")
  350. # Get the Yjs document state
  351. ydoc = Y.Doc()
  352. updates = await YDOC_MANAGER.get_updates(document_id)
  353. for update in updates:
  354. ydoc.apply_update(bytes(update))
  355. # Encode the entire document state as an update
  356. state_update = ydoc.get_update()
  357. await sio.emit(
  358. "ydoc:document:state",
  359. {
  360. "document_id": document_id,
  361. "state": list(state_update), # Convert bytes to list for JSON
  362. "sessions": active_session_ids,
  363. },
  364. room=sid,
  365. )
  366. # Notify other users about the new user
  367. await sio.emit(
  368. "ydoc:user:joined",
  369. {
  370. "document_id": document_id,
  371. "user_id": user_id,
  372. "user_name": user_name,
  373. "user_color": user_color,
  374. },
  375. room=f"doc_{document_id}",
  376. skip_sid=sid,
  377. )
  378. log.info(f"User {user_id} successfully joined document {document_id}")
  379. except Exception as e:
  380. log.error(f"Error in yjs_document_join: {e}")
  381. await sio.emit("error", {"message": "Failed to join document"}, room=sid)
  382. async def document_save_handler(document_id, data, user):
  383. if document_id.startswith("note:"):
  384. note_id = document_id.split(":")[1]
  385. note = Notes.get_note_by_id(note_id)
  386. if not note:
  387. log.error(f"Note {note_id} not found")
  388. return
  389. if (
  390. user.get("role") != "admin"
  391. and user.get("id") != note.user_id
  392. and not has_access(
  393. user.get("id"), type="read", access_control=note.access_control
  394. )
  395. ):
  396. log.error(f"User {user.get('id')} does not have access to note {note_id}")
  397. return
  398. Notes.update_note_by_id(note_id, NoteUpdateForm(data=data))
  399. @sio.on("ydoc:document:state")
  400. async def yjs_document_state(sid, data):
  401. """Send the current state of the Yjs document to the user"""
  402. try:
  403. document_id = data["document_id"]
  404. room = f"doc_{document_id}"
  405. active_session_ids = get_session_ids_from_room(room)
  406. if sid not in active_session_ids:
  407. log.warning(f"Session {sid} not in room {room}. Cannot send state.")
  408. return
  409. if not await YDOC_MANAGER.document_exists(document_id):
  410. log.warning(f"Document {document_id} not found")
  411. return
  412. # Get the Yjs document state
  413. ydoc = Y.Doc()
  414. updates = await YDOC_MANAGER.get_updates(document_id)
  415. for update in updates:
  416. ydoc.apply_update(bytes(update))
  417. # Encode the entire document state as an update
  418. state_update = ydoc.get_update()
  419. await sio.emit(
  420. "ydoc:document:state",
  421. {
  422. "document_id": document_id,
  423. "state": list(state_update), # Convert bytes to list for JSON
  424. "sessions": active_session_ids,
  425. },
  426. room=sid,
  427. )
  428. except Exception as e:
  429. log.error(f"Error in yjs_document_state: {e}")
  430. @sio.on("ydoc:document:update")
  431. async def yjs_document_update(sid, data):
  432. """Handle Yjs document updates"""
  433. try:
  434. document_id = data["document_id"]
  435. try:
  436. await stop_item_tasks(REDIS, document_id)
  437. except:
  438. pass
  439. user_id = data.get("user_id", sid)
  440. update = data["update"] # List of bytes from frontend
  441. await YDOC_MANAGER.append_to_updates(
  442. document_id=document_id,
  443. update=update, # Convert list of bytes to bytes
  444. )
  445. # Broadcast update to all other users in the document
  446. await sio.emit(
  447. "ydoc:document:update",
  448. {
  449. "document_id": document_id,
  450. "user_id": user_id,
  451. "update": update,
  452. "socket_id": sid, # Add socket_id to match frontend filtering
  453. },
  454. room=f"doc_{document_id}",
  455. skip_sid=sid,
  456. )
  457. async def debounced_save():
  458. await asyncio.sleep(0.5)
  459. await document_save_handler(
  460. document_id, data.get("data", {}), SESSION_POOL.get(sid)
  461. )
  462. if data.get("data"):
  463. await create_task(REDIS, debounced_save(), document_id)
  464. except Exception as e:
  465. log.error(f"Error in yjs_document_update: {e}")
  466. @sio.on("ydoc:document:leave")
  467. async def yjs_document_leave(sid, data):
  468. """Handle user leaving a document"""
  469. try:
  470. document_id = data["document_id"]
  471. user_id = data.get("user_id", sid)
  472. log.info(f"User {user_id} leaving document {document_id}")
  473. # Remove user from the document
  474. await YDOC_MANAGER.remove_user(document_id=document_id, user_id=sid)
  475. # Leave Socket.IO room
  476. await sio.leave_room(sid, f"doc_{document_id}")
  477. # Notify other users
  478. await sio.emit(
  479. "ydoc:user:left",
  480. {"document_id": document_id, "user_id": user_id},
  481. room=f"doc_{document_id}",
  482. )
  483. if (
  484. await YDOC_MANAGER.document_exists(document_id)
  485. and len(await YDOC_MANAGER.get_users(document_id)) == 0
  486. ):
  487. log.info(f"Cleaning up document {document_id} as no users are left")
  488. await YDOC_MANAGER.clear_document(document_id)
  489. except Exception as e:
  490. log.error(f"Error in yjs_document_leave: {e}")
  491. @sio.on("ydoc:awareness:update")
  492. async def yjs_awareness_update(sid, data):
  493. """Handle awareness updates (cursors, selections, etc.)"""
  494. try:
  495. document_id = data["document_id"]
  496. user_id = data.get("user_id", sid)
  497. update = data["update"]
  498. # Broadcast awareness update to all other users in the document
  499. await sio.emit(
  500. "ydoc:awareness:update",
  501. {"document_id": document_id, "user_id": user_id, "update": update},
  502. room=f"doc_{document_id}",
  503. skip_sid=sid,
  504. )
  505. except Exception as e:
  506. log.error(f"Error in yjs_awareness_update: {e}")
  507. @sio.event
  508. async def disconnect(sid):
  509. if sid in SESSION_POOL:
  510. user = SESSION_POOL[sid]
  511. del SESSION_POOL[sid]
  512. user_id = user["id"]
  513. USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
  514. if len(USER_POOL[user_id]) == 0:
  515. del USER_POOL[user_id]
  516. await YDOC_MANAGER.remove_user_from_all_documents(sid)
  517. else:
  518. pass
  519. # print(f"Unknown session ID {sid} disconnected")
  520. def get_event_emitter(request_info, update_db=True):
  521. async def __event_emitter__(event_data):
  522. user_id = request_info["user_id"]
  523. chat_id = request_info["chat_id"]
  524. message_id = request_info["message_id"]
  525. await sio.emit(
  526. "events",
  527. {
  528. "chat_id": chat_id,
  529. "message_id": message_id,
  530. "data": event_data,
  531. },
  532. room=f"user:{user_id}",
  533. )
  534. if (
  535. update_db
  536. and message_id
  537. and not request_info.get("chat_id", "").startswith("local:")
  538. ):
  539. if "type" in event_data and event_data["type"] == "status":
  540. Chats.add_message_status_to_chat_by_id_and_message_id(
  541. request_info["chat_id"],
  542. request_info["message_id"],
  543. event_data.get("data", {}),
  544. )
  545. if "type" in event_data and event_data["type"] == "message":
  546. message = Chats.get_message_by_id_and_message_id(
  547. request_info["chat_id"],
  548. request_info["message_id"],
  549. )
  550. if message:
  551. content = message.get("content", "")
  552. content += event_data.get("data", {}).get("content", "")
  553. Chats.upsert_message_to_chat_by_id_and_message_id(
  554. request_info["chat_id"],
  555. request_info["message_id"],
  556. {
  557. "content": content,
  558. },
  559. )
  560. if "type" in event_data and event_data["type"] == "replace":
  561. content = event_data.get("data", {}).get("content", "")
  562. Chats.upsert_message_to_chat_by_id_and_message_id(
  563. request_info["chat_id"],
  564. request_info["message_id"],
  565. {
  566. "content": content,
  567. },
  568. )
  569. if "type" in event_data and event_data["type"] == "embeds":
  570. message = Chats.get_message_by_id_and_message_id(
  571. request_info["chat_id"],
  572. request_info["message_id"],
  573. )
  574. embeds = event_data.get("data", {}).get("embeds", [])
  575. embeds.extend(message.get("embeds", []))
  576. Chats.upsert_message_to_chat_by_id_and_message_id(
  577. request_info["chat_id"],
  578. request_info["message_id"],
  579. {
  580. "embeds": embeds,
  581. },
  582. )
  583. if "type" in event_data and event_data["type"] == "files":
  584. message = Chats.get_message_by_id_and_message_id(
  585. request_info["chat_id"],
  586. request_info["message_id"],
  587. )
  588. files = event_data.get("data", {}).get("files", [])
  589. files.extend(message.get("files", []))
  590. Chats.upsert_message_to_chat_by_id_and_message_id(
  591. request_info["chat_id"],
  592. request_info["message_id"],
  593. {
  594. "files": files,
  595. },
  596. )
  597. if event_data.get("type") in ["source", "citation"]:
  598. data = event_data.get("data", {})
  599. if data.get("type") == None:
  600. message = Chats.get_message_by_id_and_message_id(
  601. request_info["chat_id"],
  602. request_info["message_id"],
  603. )
  604. sources = message.get("sources", [])
  605. sources.append(data)
  606. Chats.upsert_message_to_chat_by_id_and_message_id(
  607. request_info["chat_id"],
  608. request_info["message_id"],
  609. {
  610. "sources": sources,
  611. },
  612. )
  613. if (
  614. "user_id" in request_info
  615. and "chat_id" in request_info
  616. and "message_id" in request_info
  617. ):
  618. return __event_emitter__
  619. else:
  620. return None
  621. def get_event_call(request_info):
  622. async def __event_caller__(event_data):
  623. response = await sio.call(
  624. "events",
  625. {
  626. "chat_id": request_info.get("chat_id", None),
  627. "message_id": request_info.get("message_id", None),
  628. "data": event_data,
  629. },
  630. to=request_info["session_id"],
  631. )
  632. return response
  633. if (
  634. "session_id" in request_info
  635. and "chat_id" in request_info
  636. and "message_id" in request_info
  637. ):
  638. return __event_caller__
  639. else:
  640. return None
  641. get_event_caller = get_event_call