Bläddra i källkod

feat: collaborative note

Timothy Jaeryang Baek 2 månader sedan
förälder
incheckning
2fbff741da

+ 221 - 0
backend/open_webui/socket/main.py

@@ -5,11 +5,14 @@ import socketio
 import logging
 import sys
 import time
+from typing import Dict, Set
 from redis import asyncio as aioredis
+import pycrdt as Y
 
 from open_webui.models.users import Users, UserNameResponse
 from open_webui.models.channels import Channels
 from open_webui.models.chats import Chats
+from open_webui.models.notes import Notes, NoteUpdateForm
 from open_webui.utils.redis import (
     get_sentinels_from_env,
     get_sentinel_url_from_env,
@@ -25,6 +28,10 @@ from open_webui.env import (
 )
 from open_webui.utils.auth import decode_token
 from open_webui.socket.utils import RedisDict, RedisLock
+from open_webui.tasks import create_task, stop_item_tasks
+from open_webui.utils.redis import get_redis_connection
+from open_webui.utils.access_control import has_access, get_users_with_access
+
 
 from open_webui.env import (
     GLOBAL_LOG_LEVEL,
@@ -37,6 +44,14 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["SOCKET"])
 
 
+REDIS = get_redis_connection(
+    redis_url=WEBSOCKET_REDIS_URL,
+    redis_sentinels=get_sentinels_from_env(
+        WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
+    ),
+    async_mode=True,
+)
+
 if WEBSOCKET_MANAGER == "redis":
     if WEBSOCKET_SENTINEL_HOSTS:
         mgr = socketio.AsyncRedisManager(
@@ -90,6 +105,9 @@ if WEBSOCKET_MANAGER == "redis":
         redis_sentinels=redis_sentinels,
     )
 
+    DOCUMENTS = {}
+    DOCUMENT_USERS = {}
+
     clean_up_lock = RedisLock(
         redis_url=WEBSOCKET_REDIS_URL,
         lock_name="usage_cleanup_lock",
@@ -103,6 +121,9 @@ else:
     SESSION_POOL = {}
     USER_POOL = {}
     USAGE_POOL = {}
+
+    DOCUMENTS = {}  # document_id -> Y.YDoc instance
+    DOCUMENT_USERS = {}  # document_id -> set of user sids
     aquire_func = release_func = renew_func = lambda: True
 
 
@@ -316,6 +337,206 @@ async def channel_events(sid, data):
         )
 
 
+@sio.on("yjs:document:join")
+async def yjs_document_join(sid, data):
+    """Handle user joining a document"""
+    user = SESSION_POOL.get(sid)
+
+    try:
+        document_id = data["document_id"]
+
+        if document_id.startswith("note:"):
+            note_id = document_id.split(":")[1]
+            note = Notes.get_note_by_id(note_id)
+            if not note:
+                log.error(f"Note {note_id} not found")
+                return
+
+            if user.get("role") != "admin" and has_access(
+                user.get("id"), type="read", access_control=note.access_control
+            ):
+                log.error(
+                    f"User {user.get('id')} does not have access to note {note_id}"
+                )
+                return
+
+        user_id = data.get("user_id", sid)
+        user_name = data.get("user_name", "Anonymous")
+        user_color = data.get("user_color", "#000000")
+
+        log.info(f"User {user_id} joining document {document_id}")
+
+        # Initialize document if it doesn't exist
+        if document_id not in DOCUMENTS:
+            DOCUMENTS[document_id] = {
+                "ydoc": Y.Doc(),  # Create actual Yjs document
+                "users": set(),
+            }
+            DOCUMENT_USERS[document_id] = set()
+
+        # Add user to document
+        DOCUMENTS[document_id]["users"].add(sid)
+        DOCUMENT_USERS[document_id].add(sid)
+
+        # Join Socket.IO room
+        await sio.enter_room(sid, f"doc_{document_id}")
+
+        # Send current document state as a proper Yjs update
+        ydoc = DOCUMENTS[document_id]["ydoc"]
+
+        # Encode the entire document state as an update
+        state_update = ydoc.get_update()
+        await sio.emit(
+            "yjs:document:state",
+            {
+                "document_id": document_id,
+                "state": list(state_update),  # Convert bytes to list for JSON
+            },
+            room=sid,
+        )
+
+        # Notify other users about the new user
+        await sio.emit(
+            "yjs:user:joined",
+            {
+                "document_id": document_id,
+                "user_id": user_id,
+                "user_name": user_name,
+                "user_color": user_color,
+            },
+            room=f"doc_{document_id}",
+            skip_sid=sid,
+        )
+
+        log.info(f"User {user_id} successfully joined document {document_id}")
+
+    except Exception as e:
+        log.error(f"Error in yjs_document_join: {e}")
+        await sio.emit("error", {"message": "Failed to join document"}, room=sid)
+
+
+async def document_save_handler(document_id, data, user):
+    if document_id.startswith("note:"):
+        note_id = document_id.split(":")[1]
+        note = Notes.get_note_by_id(note_id)
+        if not note:
+            log.error(f"Note {note_id} not found")
+            return
+
+        if user.get("role") != "admin" and has_access(
+            user.get("id"), type="read", access_control=note.access_control
+        ):
+            log.error(f"User {user.get('id')} does not have access to note {note_id}")
+            return
+
+        Notes.update_note_by_id(note_id, NoteUpdateForm(data=data))
+
+
+@sio.on("yjs:document:update")
+async def yjs_document_update(sid, data):
+    """Handle Yjs document updates"""
+    try:
+        document_id = data["document_id"]
+        await stop_item_tasks(REDIS, document_id)
+
+        user_id = data.get("user_id", sid)
+        update = data["update"]  # List of bytes from frontend
+
+        if document_id not in DOCUMENTS:
+            log.warning(f"Document {document_id} not found")
+            return
+
+        # Apply the update to the server's Yjs document
+        ydoc = DOCUMENTS[document_id]["ydoc"]
+        update_bytes = bytes(update)
+
+        try:
+            ydoc.apply_update(update_bytes)
+        except Exception as e:
+            log.error(f"Failed to apply Yjs update: {e}")
+            return
+
+        # Broadcast update to all other users in the document
+        await sio.emit(
+            "yjs:document:update",
+            {
+                "document_id": document_id,
+                "user_id": user_id,
+                "update": update,
+                "socket_id": sid,  # Add socket_id to match frontend filtering
+            },
+            room=f"doc_{document_id}",
+            skip_sid=sid,
+        )
+
+        async def debounced_save():
+            await asyncio.sleep(0.5)
+            await document_save_handler(
+                document_id, data.get("data", {}), SESSION_POOL.get(sid)
+            )
+
+        await stop_item_tasks(REDIS, document_id)  # Cancel previous in-flight save
+        await create_task(REDIS, debounced_save(), document_id)
+
+    except Exception as e:
+        log.error(f"Error in yjs_document_update: {e}")
+
+
+@sio.on("yjs:document:leave")
+async def yjs_document_leave(sid, data):
+    """Handle user leaving a document"""
+    try:
+        document_id = data["document_id"]
+        user_id = data.get("user_id", sid)
+
+        log.info(f"User {user_id} leaving document {document_id}")
+
+        if document_id in DOCUMENTS:
+            DOCUMENTS[document_id]["users"].discard(sid)
+
+        if document_id in DOCUMENT_USERS:
+            DOCUMENT_USERS[document_id].discard(sid)
+
+        # Leave Socket.IO room
+        await sio.leave_room(sid, f"doc_{document_id}")
+
+        # Notify other users
+        await sio.emit(
+            "yjs:user:left",
+            {"document_id": document_id, "user_id": user_id},
+            room=f"doc_{document_id}",
+        )
+
+        if document_id in DOCUMENTS and not DOCUMENTS[document_id]["users"]:
+            # If no users left, clean up the document
+            log.info(f"Cleaning up document {document_id} as no users are left")
+            del DOCUMENTS[document_id]
+            del DOCUMENT_USERS[document_id]
+
+    except Exception as e:
+        log.error(f"Error in yjs_document_leave: {e}")
+
+
+@sio.on("yjs:awareness:update")
+async def yjs_awareness_update(sid, data):
+    """Handle awareness updates (cursors, selections, etc.)"""
+    try:
+        document_id = data["document_id"]
+        user_id = data.get("user_id", sid)
+        update = data["update"]
+
+        # Broadcast awareness update to all other users in the document
+        await sio.emit(
+            "yjs:awareness:update",
+            {"document_id": document_id, "user_id": user_id, "update": update},
+            room=f"doc_{document_id}",
+            skip_sid=sid,
+        )
+
+    except Exception as e:
+        log.error(f"Error in yjs_awareness_update: {e}")
+
+
 @sio.event
 async def disconnect(sid):
     if sid in SESSION_POOL:

+ 289 - 104
src/lib/components/common/RichTextInput.svelte

@@ -56,13 +56,22 @@
 
 	import { Fragment, DOMParser } from 'prosemirror-model';
 	import { EditorState, Plugin, PluginKey, TextSelection, Selection } from 'prosemirror-state';
-	import { receiveTransaction, sendableSteps, getVersion } from 'prosemirror-collab';
-	import { Step } from 'prosemirror-transform';
-	import { Decoration, DecorationSet } from 'prosemirror-view';
 	import { Editor, Extension } from '@tiptap/core';
 
+	// Yjs imports
+	import * as Y from 'yjs';
+	import {
+		ySyncPlugin,
+		yCursorPlugin,
+		yUndoPlugin,
+		undo,
+		redo,
+		prosemirrorJSONToYDoc,
+		yDocToProsemirrorJSON
+	} from 'y-prosemirror';
+	import { keymap } from 'prosemirror-keymap';
+
 	import { AIAutocompletion } from './RichTextInput/AutoCompletion.js';
-	import History from '@tiptap/extension-history';
 	import Table from '@tiptap/extension-table';
 	import TableRow from '@tiptap/extension-table-row';
 	import TableHeader from '@tiptap/extension-table-header';
@@ -126,98 +135,292 @@
 	export let largeTextAsFile = false;
 	export let insertPromptAsRichText = false;
 
-	let isConnected = false;
-	let collaborators = new Map();
-	let version = 0;
-
-	// Custom collaboration plugin
-	const collaborationPlugin = () => {
-		return new Plugin({
-			key: new PluginKey('collaboration'),
-			state: {
-				init: () => ({ version: 0 }),
-				apply: (tr, pluginState) => {
-					const newState = { ...pluginState };
-
-					if (tr.getMeta('collaboration')) {
-						newState.version = tr.getMeta('collaboration').version;
+	let content = null;
+	let htmlValue = '';
+	let jsonValue = '';
+	let mdValue = '';
+
+	// Yjs setup
+	let ydoc = null;
+	let yXmlFragment = null;
+	let awareness = null;
+
+	// Custom Yjs Socket.IO provider
+	class SocketIOProvider {
+		constructor(doc, documentId, socket, user) {
+			this.doc = doc;
+			this.documentId = documentId;
+			this.socket = socket;
+			this.user = user;
+			this.isConnected = false;
+			this.synced = false;
+
+			this.setupEventListeners();
+		}
+
+		onConnect() {
+			this.isConnected = true;
+			this.joinDocument();
+		}
+
+		onDisconnect() {
+			this.isConnected = false;
+			this.synced = false;
+		}
+
+		setupEventListeners() {
+			// Listen for document updates from server
+			this.socket.on('yjs:document:update', (data) => {
+				if (data.document_id === this.documentId && data.socket_id !== this.socket.id) {
+					try {
+						const update = new Uint8Array(data.update);
+						Y.applyUpdate(this.doc, update);
+					} catch (error) {
+						console.error('Error applying Yjs update:', error);
 					}
+				}
+			});
 
-					return newState;
+			// Listen for document state from server
+			this.socket.on('yjs:document:state', async (data) => {
+				if (data.document_id === this.documentId) {
+					try {
+						if (data.state) {
+							const state = new Uint8Array(data.state);
+
+							if (state.length === 2 && state[0] === 0 && state[1] === 0) {
+								// Empty state, check if we have content to initialize
+								if (content) {
+									const pydoc = prosemirrorJSONToYDoc(editor.schema, content);
+									if (pydoc) {
+										Y.applyUpdate(this.doc, Y.encodeStateAsUpdate(pydoc));
+									}
+								}
+							} else {
+								Y.applyUpdate(this.doc, state);
+							}
+						}
+						this.synced = true;
+					} catch (error) {
+						console.error('Error applying Yjs state:', error);
+					}
 				}
-			},
-			view: () => ({
-				update: (view, prevState) => {
-					const sendable = sendableSteps(view.state);
-					if (sendable) {
-						socket.emit('document_steps', {
-							document_id: documentId,
-							user_id: user?.id,
-							version: sendable.version,
-							steps: sendable.steps.map((step) => step.toJSON()),
-							clientID: sendable.clientID
-						});
+			});
+
+			// Listen for awareness updates
+			this.socket.on('yjs:awareness:update', (data) => {
+				if (data.document_id === this.documentId && awareness) {
+					try {
+						const awarenessUpdate = new Uint8Array(data.update);
+						awareness.applyUpdate(awarenessUpdate, 'server');
+					} catch (error) {
+						console.error('Error applying awareness update:', error);
 					}
 				}
-			})
-		});
-	};
+			});
 
-	function initializeCollaboration() {
-		if (!socket || !user || !documentId) {
-			console.warn('Collaboration not initialized: missing socket, user, or documentId');
-			return;
+			// Handle connection events
+			this.socket.on('connect', this.onConnect);
+			this.socket.on('disconnect', this.onDisconnect);
+
+			// Listen for document updates from Yjs
+			this.doc.on('update', async (update, origin) => {
+				if (origin !== 'server' && this.isConnected) {
+					await tick(); // Ensure the DOM is updated before sending
+					this.socket.emit('yjs:document:update', {
+						document_id: this.documentId,
+						user_id: this.user?.id,
+						socket_id: this.socket.id,
+						update: Array.from(update),
+						data: {
+							content: {
+								md: mdValue,
+								html: htmlValue,
+								json: jsonValue
+							}
+						}
+					});
+				}
+			});
+
+			// Listen for awareness updates from Yjs
+			if (awareness) {
+				awareness.on('change', ({ added, updated, removed }, origin) => {
+					if (origin !== 'server' && this.isConnected) {
+						const changedClients = added.concat(updated).concat(removed);
+						const awarenessUpdate = awareness.encodeUpdate(changedClients);
+						this.socket.emit('yjs:awareness:update', {
+							document_id: this.documentId,
+							user_id: this.socket.id,
+							update: Array.from(awarenessUpdate)
+						});
+					}
+				});
+			}
+
+			if (this.socket.connected) {
+				this.isConnected = true;
+				this.joinDocument();
+			}
 		}
 
-		socket.emit('join_document', {
-			document_id: documentId,
-			user_id: user?.id,
-			user_name: user?.name,
-			user_color: user?.color
-		});
+		generateUserColor() {
+			const colors = [
+				'#FF6B6B',
+				'#4ECDC4',
+				'#45B7D1',
+				'#96CEB4',
+				'#FFEAA7',
+				'#DDA0DD',
+				'#98D8C8',
+				'#F7DC6F',
+				'#BB8FCE',
+				'#85C1E9'
+			];
+			return colors[Math.floor(Math.random() * colors.length)];
+		}
 
-		socket.on('document_steps', handleDocumentSteps);
-		socket.on('document_state', handleDocumentState);
-		socket.on('user_joined', handleUserJoined);
-		socket.on('user_left', handleUserLeft);
-		socket.on('connect', () => {
-			isConnected = true;
-		});
-		socket.on('disconnect', () => {
-			isConnected = false;
-		});
-	}
+		joinDocument() {
+			const userColor = this.generateUserColor();
+			this.socket.emit('yjs:document:join', {
+				document_id: this.documentId,
+				user_id: this.user?.id,
+				user_name: this.user?.name,
+				user_color: userColor
+			});
 
-	function handleDocumentSteps(data) {
-		if (data.user_id !== user?.id && editor) {
-			const steps = data.steps.map((stepJSON) => Step.fromJSON(editor.schema, stepJSON));
-			const tr = receiveTransaction(editor.state, steps, data.clientID);
+			// Set user awareness info
+			if (awareness && this.user) {
+				awareness.setLocalStateField('user', {
+					name: `${this.user.name}`,
+					color: userColor,
+					id: this.socket.id
+				});
+			}
+		}
 
-			if (tr) {
-				editor.view.dispatch(tr);
+		destroy() {
+			this.socket.off('yjs:document:update');
+			this.socket.off('yjs:document:state');
+			this.socket.off('yjs:awareness:update');
+			this.socket.off('connect', this.onConnect);
+			this.socket.off('disconnect', this.onDisconnect);
+
+			if (this.isConnected) {
+				this.socket.emit('yjs:document:leave', {
+					document_id: this.documentId,
+					user_id: this.user?.id
+				});
 			}
 		}
 	}
 
-	function handleDocumentState(data) {
-		version = data.version;
-		if (data.content && editor) {
-			editor.commands.setContent(data.content);
+	let provider = null;
+
+	// Simple awareness implementation
+	class SimpleAwareness {
+		constructor(yDoc) {
+			// Yjs awareness expects clientID (not clientId) property
+			this.clientID = yDoc ? yDoc.clientID : Math.floor(Math.random() * 0xffffffff);
+			// Map from clientID (number) to state (object)
+			this._states = new Map(); // _states, not states; will make getStates() for compat
+			this._updateHandlers = [];
+			this._localState = {};
+			// As in Yjs Awareness, add our local state to the states map from the start:
+			this._states.set(this.clientID, this._localState);
+		}
+		on(event, handler) {
+			if (event === 'change') this._updateHandlers.push(handler);
+		}
+		off(event, handler) {
+			if (event === 'change') {
+				const i = this._updateHandlers.indexOf(handler);
+				if (i !== -1) this._updateHandlers.splice(i, 1);
+			}
+		}
+		getLocalState() {
+			return this._states.get(this.clientID) || null;
+		}
+		getStates() {
+			// Yjs returns a Map (clientID->state)
+			return this._states;
+		}
+		setLocalStateField(field, value) {
+			let localState = this._states.get(this.clientID);
+			if (!localState) {
+				localState = {};
+				this._states.set(this.clientID, localState);
+			}
+			localState[field] = value;
+			// After updating, fire 'update' event to all handlers
+			for (const cb of this._updateHandlers) {
+				// Follows Yjs Awareness ({ added, updated, removed }, origin)
+				cb({ added: [], updated: [this.clientID], removed: [] }, 'local');
+			}
+		}
+		applyUpdate(update, origin) {
+			// Very simple: Accepts a serialized JSON state for now as Uint8Array
+			try {
+				const str = new TextDecoder().decode(update);
+				const obj = JSON.parse(str);
+				// Should be a plain object: { clientID: state, ... }
+				for (const [k, v] of Object.entries(obj)) {
+					this._states.set(+k, v);
+				}
+				for (const cb of this._updateHandlers) {
+					cb({ added: [], updated: Array.from(Object.keys(obj)).map(Number), removed: [] }, origin);
+				}
+			} catch (e) {
+				console.warn('SimpleAwareness: Could not decode update:', e);
+			}
+		}
+		encodeUpdate(clients) {
+			// Encodes the states for the given clientIDs as Uint8Array (JSON)
+			const obj = {};
+			for (const id of clients || Array.from(this._states.keys())) {
+				const st = this._states.get(id);
+				if (st) obj[id] = st;
+			}
+			const json = JSON.stringify(obj);
+			return new TextEncoder().encode(json);
 		}
-		isConnected = true;
 	}
 
-	function handleUserJoined(data) {
-		collaborators.set(data.user_id, {
-			name: data.user_name,
-			color: data.user_color
-		});
-		collaborators = collaborators;
-	}
+	// Yjs collaboration extension
+	const YjsCollaboration = Extension.create({
+		name: 'yjsCollaboration',
+
+		addProseMirrorPlugins() {
+			if (!collaboration || !yXmlFragment) return [];
 
-	function handleUserLeft(data) {
-		collaborators.delete(data.user_id);
-		collaborators = collaborators;
+			const plugins = [
+				ySyncPlugin(yXmlFragment),
+				yUndoPlugin(),
+				keymap({
+					'Mod-z': undo,
+					'Mod-y': redo,
+					'Mod-Shift-z': redo
+				})
+			];
+
+			if (awareness) {
+				plugins.push(yCursorPlugin(awareness));
+			}
+
+			return plugins;
+		}
+	});
+
+	function initializeCollaboration() {
+		if (!collaboration) return;
+
+		// Create Yjs document
+		ydoc = new Y.Doc();
+		yXmlFragment = ydoc.getXmlFragment('prosemirror');
+		awareness = new SimpleAwareness(ydoc);
+
+		// Create custom Socket.IO provider
+		provider = new SocketIOProvider(ydoc, documentId, socket, user);
 	}
 
 	let floatingMenuElement = null;
@@ -538,7 +741,7 @@
 	};
 
 	onMount(async () => {
-		let content = value;
+		content = value;
 
 		if (json) {
 			if (!content) {
@@ -655,28 +858,18 @@
 							})
 						]
 					: []),
-
-				...(collaboration
-					? [
-							Extension.create({
-								name: 'socketCollaboration',
-								addProseMirrorPlugins() {
-									return [collaborationPlugin()];
-								}
-							})
-						]
-					: [])
+				...(collaboration ? [YjsCollaboration] : [])
 			],
-			content: content,
+			content: collaboration ? undefined : content,
 			autofocus: messageInput ? true : false,
 			onTransaction: () => {
 				// force re-render so `editor.isActive` works as expected
 				editor = editor;
 
-				const htmlValue = editor.getHTML();
-				const jsonValue = editor.getJSON();
+				htmlValue = editor.getHTML();
+				jsonValue = editor.getJSON();
 
-				let mdValue = turndownService
+				mdValue = turndownService
 					.turndown(
 						htmlValue
 							.replace(/<p><\/p>/g, '<br/>')
@@ -872,16 +1065,8 @@
 	});
 
 	onDestroy(() => {
-		if (socket) {
-			socket.off('document_steps', handleDocumentSteps);
-			socket.off('document_state', handleDocumentState);
-			socket.off('user_joined', handleUserJoined);
-			socket.off('user_left', handleUserLeft);
-
-			socket.emit('leave_document', {
-				document_id: documentId,
-				user_id: userId
-			});
+		if (provider) {
+			provider.destroy();
 		}
 
 		if (editor) {
@@ -889,7 +1074,7 @@
 		}
 	});
 
-	$: if (value !== null && editor) {
+	$: if (value !== null && editor && !collaboration) {
 		onValueChange();
 	}
 

+ 7 - 7
src/lib/components/notes/NoteEditor.svelte

@@ -31,7 +31,7 @@
 	import { uploadFile } from '$lib/apis/files';
 	import { chatCompletion } from '$lib/apis/openai';
 
-	import { config, models, settings, showSidebar } from '$lib/stores';
+	import { config, models, settings, showSidebar, socket, user } from '$lib/stores';
 
 	import NotePanel from '$lib/components/notes/NotePanel.svelte';
 	import MenuLines from '../icons/MenuLines.svelte';
@@ -171,10 +171,6 @@
 		}, 200);
 	};
 
-	$: if (note) {
-		changeDebounceHandler();
-	}
-
 	$: if (id) {
 		init();
 	}
@@ -862,7 +858,7 @@ Provide the enhanced notes in markdown format. Use markdown syntax for headings,
 						</div>
 					</div>
 
-					<div class=" mb-2.5 px-2.5">
+					<div class="  px-2.5">
 						<div
 							class=" flex w-full bg-transparent overflow-x-auto scrollbar-none"
 							on:wheel={(e) => {
@@ -906,7 +902,7 @@ Provide the enhanced notes in markdown format. Use markdown syntax for headings,
 					</div>
 
 					<div
-						class=" flex-1 w-full h-full overflow-auto px-3.5 pb-20 relative"
+						class=" flex-1 w-full h-full overflow-auto px-3.5 pb-20 relative z-40 pt-2.5"
 						id="note-content-container"
 					>
 						{#if enhancing}
@@ -959,6 +955,10 @@ Provide the enhanced notes in markdown format. Use markdown syntax for headings,
 							html={note.data?.content?.html}
 							json={true}
 							link={true}
+							documentId={`note:${note.id}`}
+							collaboration={true}
+							socket={$socket}
+							user={$user}
 							placeholder={$i18n.t('Write something...')}
 							editable={versionIdx === null && !enhancing}
 							onChange={(content) => {