Browse Source

enh/refac: distributed crdt

Timothy Jaeryang Baek 2 months ago
parent
commit
7f1f39058a
2 changed files with 133 additions and 26 deletions
  1. 25 26
      backend/open_webui/socket/main.py
  2. 108 0
      backend/open_webui/socket/utils.py

+ 25 - 26
backend/open_webui/socket/main.py

@@ -27,7 +27,7 @@ from open_webui.env import (
     WEBSOCKET_SENTINEL_HOSTS,
 )
 from open_webui.utils.auth import decode_token
-from open_webui.socket.utils import RedisDict, RedisLock
+from open_webui.socket.utils import RedisDict, RedisLock, YdocManager
 from open_webui.tasks import create_task, stop_item_tasks
 from open_webui.utils.redis import get_redis_connection
 from open_webui.utils.access_control import has_access, get_users_with_access
@@ -125,7 +125,10 @@ else:
 
 
 # TODO: Implement Yjs document management with Redis
-DOCUMENTS = {}  # document_id -> Y.YDoc instance
+YDOC_MANAGER = YdocManager(
+    redis=REDIS,
+    redis_key_prefix="open-webui:ydoc:documents",
+)
 
 
 async def periodic_usage_pool_cleanup():
@@ -374,16 +377,7 @@ async def ydoc_document_join(sid, data):
         user_color = data.get("user_color", "#000000")
 
         log.info(f"User {user_id} joining document {document_id}")
-
-        # Initialize document if it doesn't exist
-        if document_id not in DOCUMENTS:
-            DOCUMENTS[document_id] = {
-                "updates": [],  # Store updates for the document
-                "users": set(),
-            }
-
-        # Add user to document
-        DOCUMENTS[document_id]["users"].add(sid)
+        await YDOC_MANAGER.add_user(document_id=document_id, user_id=sid)
 
         # Join Socket.IO room
         await sio.enter_room(sid, f"doc_{document_id}")
@@ -392,7 +386,8 @@ async def ydoc_document_join(sid, data):
 
         # Get the Yjs document state
         ydoc = Y.Doc()
-        for update in DOCUMENTS[document_id]["updates"]:
+        updates = await YDOC_MANAGER.get_updates(document_id)
+        for update in updates:
             ydoc.apply_update(bytes(update))
 
         # Encode the entire document state as an update
@@ -461,13 +456,14 @@ async def yjs_document_state(sid, data):
             log.warning(f"Session {sid} not in room {room}. Cannot send state.")
             return
 
-        if document_id not in DOCUMENTS:
+        if not await YDOC_MANAGER.document_exists(document_id):
             log.warning(f"Document {document_id} not found")
             return
 
         # Get the Yjs document state
         ydoc = Y.Doc()
-        for update in DOCUMENTS[document_id]["updates"]:
+        updates = await YDOC_MANAGER.get_updates(document_id)
+        for update in updates:
             ydoc.apply_update(bytes(update))
 
         # Encode the entire document state as an update
@@ -491,6 +487,7 @@ async def yjs_document_update(sid, data):
     """Handle Yjs document updates"""
     try:
         document_id = data["document_id"]
+
         try:
             await stop_item_tasks(REDIS, document_id)
         except:
@@ -500,12 +497,10 @@ async def yjs_document_update(sid, data):
 
         update = data["update"]  # List of bytes from frontend
 
-        if document_id not in DOCUMENTS:
-            log.warning(f"Document {document_id} not found")
-            return
-
-        updates = DOCUMENTS[document_id]["updates"]
-        updates.append(update)
+        await YDOC_MANAGER.append_to_updates(
+            document_id=document_id,
+            update=update,  # Convert list of bytes to bytes
+        )
 
         # Broadcast update to all other users in the document
         await sio.emit(
@@ -541,8 +536,8 @@ async def yjs_document_leave(sid, data):
 
         log.info(f"User {user_id} leaving document {document_id}")
 
-        if document_id in DOCUMENTS:
-            DOCUMENTS[document_id]["users"].discard(sid)
+        # Remove user from the document
+        await YDOC_MANAGER.remove_user(document_id=document_id, user_id=sid)
 
         # Leave Socket.IO room
         await sio.leave_room(sid, f"doc_{document_id}")
@@ -554,10 +549,12 @@ async def yjs_document_leave(sid, data):
             room=f"doc_{document_id}",
         )
 
-        if document_id in DOCUMENTS and not DOCUMENTS[document_id]["users"]:
-            # If no users left, clean up the document
+        if (
+            YDOC_MANAGER.document_exists(document_id)
+            and len(await YDOC_MANAGER.get_users(document_id)) == 0
+        ):
             log.info(f"Cleaning up document {document_id} as no users are left")
-            del DOCUMENTS[document_id]
+            await YDOC_MANAGER.clear_document(document_id)
 
     except Exception as e:
         log.error(f"Error in yjs_document_leave: {e}")
@@ -594,6 +591,8 @@ async def disconnect(sid):
 
         if len(USER_POOL[user_id]) == 0:
             del USER_POOL[user_id]
+
+        await YDOC_MANAGER.remove_user_from_all_documents(sid)
     else:
         pass
         # print(f"Unknown session ID {sid} disconnected")

+ 108 - 0
backend/open_webui/socket/utils.py

@@ -1,6 +1,8 @@
 import json
 import uuid
 from open_webui.utils.redis import get_redis_connection
+from typing import Optional, List, Tuple
+import pycrdt as Y
 
 
 class RedisLock:
@@ -89,3 +91,109 @@ class RedisDict:
         if key not in self:
             self[key] = default
         return self[key]
+
+
+class YdocManager:
+    def __init__(
+        self,
+        redis=None,
+        redis_key_prefix: str = "open-webui:ydoc:documents",
+    ):
+        self._updates = {}
+        self._users = {}
+        self._redis = redis
+        self._redis_key_prefix = redis_key_prefix
+
+    async def append_to_updates(self, document_id: str, update: bytes):
+        document_id = document_id.replace(":", "_")
+        if self._redis:
+            redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
+            await self._redis.rpush(redis_key, json.dumps(list(update)))
+        else:
+            if document_id not in self._updates:
+                self._updates[document_id] = []
+            self._updates[document_id].append(update)
+
+    async def get_updates(self, document_id: str) -> List[bytes]:
+        document_id = document_id.replace(":", "_")
+
+        if self._redis:
+            redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
+            updates = await self._redis.lrange(redis_key, 0, -1)
+            return [bytes(json.loads(update)) for update in updates]
+        else:
+            return self._updates.get(document_id, [])
+
+    async def document_exists(self, document_id: str) -> bool:
+        document_id = document_id.replace(":", "_")
+
+        if self._redis:
+            redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
+            return await self._redis.exists(redis_key) > 0
+        else:
+            return document_id in self._updates
+
+    async def get_users(self, document_id: str) -> List[str]:
+        document_id = document_id.replace(":", "_")
+
+        if self._redis:
+            redis_key = f"{self._redis_key_prefix}:{document_id}:users"
+            users = await self._redis.smembers(redis_key)
+            return list(users)
+        else:
+            return self._users.get(document_id, [])
+
+    async def add_user(self, document_id: str, user_id: str):
+        document_id = document_id.replace(":", "_")
+
+        if self._redis:
+            redis_key = f"{self._redis_key_prefix}:{document_id}:users"
+            await self._redis.sadd(redis_key, user_id)
+        else:
+            if document_id not in self._users:
+                self._users[document_id] = set()
+            self._users[document_id].add(user_id)
+
+    async def remove_user(self, document_id: str, user_id: str):
+        document_id = document_id.replace(":", "_")
+
+        if self._redis:
+            redis_key = f"{self._redis_key_prefix}:{document_id}:users"
+            await self._redis.srem(redis_key, user_id)
+        else:
+            if document_id in self._users and user_id in self._users[document_id]:
+                self._users[document_id].remove(user_id)
+
+    async def remove_user_from_all_documents(self, user_id: str):
+        if self._redis:
+            keys = await self._redis.keys(f"{self._redis_key_prefix}:*")
+            for key in keys:
+                if key.endswith(":users"):
+                    await self._redis.srem(key, user_id)
+
+                    document_id = key.split(":")[-2]
+                    if len(await self.get_users(document_id)) == 0:
+                        await self.clear_document(document_id)
+
+        else:
+            for document_id in list(self._users.keys()):
+                if user_id in self._users[document_id]:
+                    self._users[document_id].remove(user_id)
+                    if not self._users[document_id]:
+                        del self._users[document_id]
+
+                        await self.clear_document(document_id)
+
+    async def clear_document(self, document_id: str):
+        document_id = document_id.replace(":", "_")
+
+        if self._redis:
+            redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
+            await self._redis.delete(redis_key)
+            redis_users_key = f"{self._redis_key_prefix}:{document_id}:users"
+            await self._redis.delete(redis_users_key)
+        else:
+            if document_id in self._updates:
+                del self._updates[document_id]
+            if document_id in self._users:
+                del self._users[document_id]