Prechádzať zdrojové kódy

enh: reply to message

Timothy Jaeryang Baek 1 týždeň pred
rodič
commit
1a18928c94

+ 34 - 0
backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py

@@ -0,0 +1,34 @@
+"""Add reply_to_id column to message
+
+Revision ID: a5c220713937
+Revises: 38d63c18f30f
+Create Date: 2025-09-27 02:24:18.058455
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision: str = "a5c220713937"
+down_revision: Union[str, None] = "38d63c18f30f"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # Add 'reply_to_id' column to the 'message' table for replying to messages
+    op.add_column(
+        "message",
+        sa.Column("reply_to_id", sa.Text(), nullable=True),
+    )
+    pass
+
+
+def downgrade() -> None:
+    # Remove 'reply_to_id' column from the 'message' table
+    op.drop_column("message", "reply_to_id")
+
+    pass

+ 58 - 9
backend/open_webui/models/messages.py

@@ -5,6 +5,7 @@ from typing import Optional
 
 from open_webui.internal.db import Base, get_db
 from open_webui.models.tags import TagModel, Tag, Tags
+from open_webui.models.users import Users, UserNameResponse
 
 
 from pydantic import BaseModel, ConfigDict
@@ -43,6 +44,7 @@ class Message(Base):
     user_id = Column(Text)
     channel_id = Column(Text, nullable=True)
 
+    reply_to_id = Column(Text, nullable=True)
     parent_id = Column(Text, nullable=True)
 
     content = Column(Text)
@@ -60,6 +62,7 @@ class MessageModel(BaseModel):
     user_id: str
     channel_id: Optional[str] = None
 
+    reply_to_id: Optional[str] = None
     parent_id: Optional[str] = None
 
     content: str
@@ -77,6 +80,7 @@ class MessageModel(BaseModel):
 
 class MessageForm(BaseModel):
     content: str
+    reply_to_id: Optional[str] = None
     parent_id: Optional[str] = None
     data: Optional[dict] = None
     meta: Optional[dict] = None
@@ -88,7 +92,15 @@ class Reactions(BaseModel):
     count: int
 
 
-class MessageResponse(MessageModel):
+class MessageUserResponse(MessageModel):
+    user: Optional[UserNameResponse] = None
+
+
+class MessageReplyToResponse(MessageUserResponse):
+    reply_to_message: Optional[MessageUserResponse] = None
+
+
+class MessageResponse(MessageReplyToResponse):
     latest_reply_at: Optional[int]
     reply_count: int
     reactions: list[Reactions]
@@ -107,6 +119,7 @@ class MessageTable:
                     "id": id,
                     "user_id": user_id,
                     "channel_id": channel_id,
+                    "reply_to_id": form_data.reply_to_id,
                     "parent_id": form_data.parent_id,
                     "content": form_data.content,
                     "data": form_data.data,
@@ -122,25 +135,36 @@ class MessageTable:
             db.refresh(result)
             return MessageModel.model_validate(result) if result else None
 
-    def get_message_by_id(self, id: str) -> Optional[MessageResponse]:
+    def get_message_by_id(self, id: str) -> Optional[MessageReplyToResponse]:
         with get_db() as db:
             message = db.get(Message, id)
             if not message:
                 return None
 
+            reply_to_message = (
+                self.get_message_by_id(message.reply_to_id)
+                if message.reply_to_id
+                else None
+            )
             reactions = self.get_reactions_by_message_id(id)
-            replies = self.get_replies_by_message_id(id)
+            replies = self.get_thread_replies_by_message_id(id)
 
-            return MessageResponse(
-                **{
+            user = Users.get_user_by_id(message.user_id)
+
+            return MessageReplyToResponse.model_validate(
+                {
                     **MessageModel.model_validate(message).model_dump(),
+                    "user": user.model_dump() if user else None,
+                    "reply_to_message": (
+                        reply_to_message.model_dump() if reply_to_message else None
+                    ),
                     "latest_reply_at": replies[0].created_at if replies else None,
                     "reply_count": len(replies),
                     "reactions": reactions,
                 }
             )
 
-    def get_replies_by_message_id(self, id: str) -> list[MessageModel]:
+    def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToResponse]:
         with get_db() as db:
             all_messages = (
                 db.query(Message)
@@ -148,7 +172,19 @@ class MessageTable:
                 .order_by(Message.created_at.desc())
                 .all()
             )
-            return [MessageModel.model_validate(message) for message in all_messages]
+            return [
+                MessageReplyToResponse.model_validate(
+                    {
+                        **MessageModel.model_validate(message).model_dump(),
+                        "reply_to_message": (
+                            self.get_message_by_id(message.reply_to_id).model_dump()
+                            if message.reply_to_id
+                            else None
+                        ),
+                    }
+                )
+                for message in all_messages
+            ]
 
     def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
         with get_db() as db:
@@ -159,7 +195,7 @@ class MessageTable:
 
     def get_messages_by_channel_id(
         self, channel_id: str, skip: int = 0, limit: int = 50
-    ) -> list[MessageModel]:
+    ) -> list[MessageReplyToResponse]:
         with get_db() as db:
             all_messages = (
                 db.query(Message)
@@ -169,7 +205,20 @@ class MessageTable:
                 .limit(limit)
                 .all()
             )
-            return [MessageModel.model_validate(message) for message in all_messages]
+
+            return [
+                MessageReplyToResponse.model_validate(
+                    {
+                        **MessageModel.model_validate(message).model_dump(),
+                        "reply_to_message": (
+                            self.get_message_by_id(message.reply_to_id).model_dump()
+                            if message.reply_to_id
+                            else None
+                        ),
+                    }
+                )
+                for message in all_messages
+            ]
 
     def get_messages_by_parent_id(
         self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50

+ 31 - 54
backend/open_webui/routers/channels.py

@@ -167,7 +167,7 @@ async def delete_channel_by_id(id: str, user=Depends(get_admin_user)):
 
 
 class MessageUserResponse(MessageResponse):
-    user: UserNameResponse
+    pass
 
 
 @router.get("/{id}/messages", response_model=list[MessageUserResponse])
@@ -196,15 +196,17 @@ async def get_channel_messages(
             user = Users.get_user_by_id(message.user_id)
             users[message.user_id] = user
 
-        replies = Messages.get_replies_by_message_id(message.id)
-        latest_reply_at = replies[0].created_at if replies else None
+        thread_replies = Messages.get_thread_replies_by_message_id(message.id)
+        latest_thread_reply_at = (
+            thread_replies[0].created_at if thread_replies else None
+        )
 
         messages.append(
             MessageUserResponse(
                 **{
                     **message.model_dump(),
-                    "reply_count": len(replies),
-                    "latest_reply_at": latest_reply_at,
+                    "reply_count": len(thread_replies),
+                    "latest_reply_at": latest_thread_reply_at,
                     "reactions": Messages.get_reactions_by_message_id(message.id),
                     "user": UserNameResponse(**users[message.user_id].model_dump()),
                 }
@@ -253,12 +255,26 @@ async def model_response_handler(request, channel, message, user):
     mentions = extract_mentions(message.content)
     message_content = replace_mentions(message.content)
 
+    model_mentions = {}
+
+    # check if the message is a reply to a message sent by a model
+    if (
+        message.reply_to_message
+        and message.reply_to_message.meta
+        and message.reply_to_message.meta.get("model_id", None)
+    ):
+        model_id = message.reply_to_message.meta.get("model_id", None)
+        model_mentions[model_id] = {"id": model_id, "id_type": "M"}
+
     # check if any of the mentions are models
-    model_mentions = [mention for mention in mentions if mention["id_type"] == "M"]
+    for mention in mentions:
+        if mention["id_type"] == "M" and mention["id"] not in model_mentions:
+            model_mentions[mention["id"]] = mention
+
     if not model_mentions:
         return False
 
-    for mention in model_mentions:
+    for mention in model_mentions.values():
         model_id = mention["id"]
         model = MODELS.get(model_id, None)
 
@@ -406,24 +422,14 @@ async def new_message_handler(
 
     try:
         message = Messages.insert_new_message(form_data, channel.id, user.id)
-
         if message:
+            message = Messages.get_message_by_id(message.id)
             event_data = {
                 "channel_id": channel.id,
                 "message_id": message.id,
                 "data": {
                     "type": "message",
-                    "data": MessageUserResponse(
-                        **{
-                            **message.model_dump(),
-                            "reply_count": 0,
-                            "latest_reply_at": None,
-                            "reactions": Messages.get_reactions_by_message_id(
-                                message.id
-                            ),
-                            "user": UserNameResponse(**user.model_dump()),
-                        }
-                    ).model_dump(),
+                    "data": message.model_dump(),
                 },
                 "user": UserNameResponse(**user.model_dump()).model_dump(),
                 "channel": channel.model_dump(),
@@ -447,23 +453,16 @@ async def new_message_handler(
                             "message_id": parent_message.id,
                             "data": {
                                 "type": "message:reply",
-                                "data": MessageUserResponse(
-                                    **{
-                                        **parent_message.model_dump(),
-                                        "user": UserNameResponse(
-                                            **Users.get_user_by_id(
-                                                parent_message.user_id
-                                            ).model_dump()
-                                        ),
-                                    }
-                                ).model_dump(),
+                                "data": parent_message.model_dump(),
                             },
                             "user": UserNameResponse(**user.model_dump()).model_dump(),
                             "channel": channel.model_dump(),
                         },
                         to=f"channel:{channel.id}",
                     )
-        return MessageModel(**message.model_dump()), channel
+            return message, channel
+        else:
+            raise Exception("Error creating message")
     except Exception as e:
         log.exception(e)
         raise HTTPException(
@@ -651,14 +650,7 @@ async def update_message_by_id(
                     "message_id": message.id,
                     "data": {
                         "type": "message:update",
-                        "data": MessageUserResponse(
-                            **{
-                                **message.model_dump(),
-                                "user": UserNameResponse(
-                                    **user.model_dump()
-                                ).model_dump(),
-                            }
-                        ).model_dump(),
+                        "data": message.model_dump(),
                     },
                     "user": UserNameResponse(**user.model_dump()).model_dump(),
                     "channel": channel.model_dump(),
@@ -724,9 +716,6 @@ async def add_reaction_to_message(
                     "type": "message:reaction:add",
                     "data": {
                         **message.model_dump(),
-                        "user": UserNameResponse(
-                            **Users.get_user_by_id(message.user_id).model_dump()
-                        ).model_dump(),
                         "name": form_data.name,
                     },
                 },
@@ -793,9 +782,6 @@ async def remove_reaction_by_id_and_user_id_and_name(
                     "type": "message:reaction:remove",
                     "data": {
                         **message.model_dump(),
-                        "user": UserNameResponse(
-                            **Users.get_user_by_id(message.user_id).model_dump()
-                        ).model_dump(),
                         "name": form_data.name,
                     },
                 },
@@ -882,16 +868,7 @@ async def delete_message_by_id(
                         "message_id": parent_message.id,
                         "data": {
                             "type": "message:reply",
-                            "data": MessageUserResponse(
-                                **{
-                                    **parent_message.model_dump(),
-                                    "user": UserNameResponse(
-                                        **Users.get_user_by_id(
-                                            parent_message.user_id
-                                        ).model_dump()
-                                    ),
-                                }
-                            ).model_dump(),
+                            "data": parent_message.model_dump(),
                         },
                         "user": UserNameResponse(**user.model_dump()).model_dump(),
                         "channel": channel.model_dump(),

+ 1 - 0
src/lib/apis/channels/index.ts

@@ -248,6 +248,7 @@ export const getChannelThreadMessages = async (
 };
 
 type MessageForm = {
+	reply_to_id?: string;
 	parent_id?: string;
 	content: string;
 	data?: object;

+ 21 - 7
src/lib/components/channel/Channel.svelte

@@ -20,12 +20,14 @@
 
 	let scrollEnd = true;
 	let messagesContainerElement = null;
+	let chatInputElement = null;
 
 	let top = false;
 
 	let channel = null;
 	let messages = null;
 
+	let replyToMessage = null;
 	let threadId = null;
 
 	let typingUsers = [];
@@ -141,16 +143,20 @@
 			return;
 		}
 
-		const res = await sendMessage(localStorage.token, id, { content: content, data: data }).catch(
-			(error) => {
-				toast.error(`${error}`);
-				return null;
-			}
-		);
+		const res = await sendMessage(localStorage.token, id, {
+			content: content,
+			data: data,
+			reply_to_id: replyToMessage?.id ?? null
+		}).catch((error) => {
+			toast.error(`${error}`);
+			return null;
+		});
 
 		if (res) {
 			messagesContainerElement.scrollTop = messagesContainerElement.scrollHeight;
 		}
+
+		replyToMessage = null;
 	};
 
 	const onChange = async () => {
@@ -222,8 +228,14 @@
 						{#key id}
 							<Messages
 								{channel}
-								{messages}
 								{top}
+								{messages}
+								{replyToMessage}
+								onReply={async (message) => {
+									replyToMessage = message;
+									await tick();
+									chatInputElement?.focus();
+								}}
 								onThread={(id) => {
 									threadId = id;
 								}}
@@ -250,6 +262,8 @@
 			<div class=" pb-[1rem] px-2.5">
 				<MessageInput
 					id="root"
+					bind:chatInputElement
+					bind:replyToMessage
 					{typingUsers}
 					userSuggestions={true}
 					channelSuggestions={true}

+ 37 - 5
src/lib/components/channel/MessageInput.svelte

@@ -23,20 +23,23 @@
 
 	import { getSessionUser } from '$lib/apis/auths';
 
+	import { uploadFile } from '$lib/apis/files';
+	import { WEBUI_API_BASE_URL } from '$lib/constants';
+
+	import { getSuggestionRenderer } from '../common/RichTextInput/suggestions';
+	import CommandSuggestionList from '../chat/MessageInput/CommandSuggestionList.svelte';
+
+	import InputMenu from './MessageInput/InputMenu.svelte';
 	import Tooltip from '../common/Tooltip.svelte';
 	import RichTextInput from '../common/RichTextInput.svelte';
 	import VoiceRecording from '../chat/MessageInput/VoiceRecording.svelte';
-	import InputMenu from './MessageInput/InputMenu.svelte';
-	import { uploadFile } from '$lib/apis/files';
-	import { WEBUI_API_BASE_URL } from '$lib/constants';
 	import FileItem from '../common/FileItem.svelte';
 	import Image from '../common/Image.svelte';
 	import FilesOverlay from '../chat/MessageInput/FilesOverlay.svelte';
 	import InputVariablesModal from '../chat/MessageInput/InputVariablesModal.svelte';
-	import { getSuggestionRenderer } from '../common/RichTextInput/suggestions';
-	import CommandSuggestionList from '../chat/MessageInput/CommandSuggestionList.svelte';
 	import MentionList from './MessageInput/MentionList.svelte';
 	import Skeleton from '../chat/Messages/Skeleton.svelte';
+	import XMark from '../icons/XMark.svelte';
 
 	export let placeholder = $i18n.t('Type here...');
 
@@ -60,6 +63,8 @@
 	export let userSuggestions = false;
 	export let channelSuggestions = false;
 
+	export let replyToMessage = null;
+
 	export let typingUsersClassName = 'from-white dark:from-gray-900';
 
 	let loaded = false;
@@ -773,6 +778,32 @@
 							class="flex-1 flex flex-col relative w-full shadow-lg rounded-3xl border border-gray-50 dark:border-gray-850 hover:border-gray-100 focus-within:border-gray-100 hover:dark:border-gray-800 focus-within:dark:border-gray-800 transition px-1 bg-white/90 dark:bg-gray-400/5 dark:text-gray-100"
 							dir={$settings?.chatDirection ?? 'auto'}
 						>
+							{#if replyToMessage !== null}
+								<div class="px-3 pt-3 text-left w-full flex flex-col z-10">
+									<div class="flex items-center justify-between w-full">
+										<div class="pl-[1px] flex items-center gap-2 text-sm">
+											<div class="translate-y-[0.5px]">
+												<span class=""
+													>{$i18n.t('Replying to {{NAME}}', {
+														NAME: replyToMessage?.meta?.model_name ?? replyToMessage.user.name
+													})}</span
+												>
+											</div>
+										</div>
+										<div>
+											<button
+												class="flex items-center dark:text-gray-500"
+												on:click={() => {
+													replyToMessage = null;
+												}}
+											>
+												<XMark />
+											</button>
+										</div>
+									</div>
+								</div>
+							{/if}
+
 							{#if files.length > 0}
 								<div class="mx-2 mt-2.5 -mb-1 flex flex-wrap gap-2">
 									{#each files as file, fileIdx}
@@ -890,6 +921,7 @@
 
 												if (e.key === 'Escape') {
 													console.info('Escape');
+													replyToMessage = null;
 												}
 											}}
 											on:paste={async (e) => {

+ 8 - 1
src/lib/components/channel/Messages.svelte

@@ -23,10 +23,12 @@
 	export let id = null;
 	export let channel = null;
 	export let messages = [];
+	export let replyToMessage = null;
 	export let top = false;
 	export let thread = false;
 
 	export let onLoad: Function = () => {};
+	export let onReply: Function = () => {};
 	export let onThread: Function = () => {};
 
 	let messagesLoading = false;
@@ -94,10 +96,12 @@
 			<Message
 				{message}
 				{thread}
+				replyToMessage={replyToMessage?.id === message.id}
 				disabled={!channel?.write_access}
 				showUserProfile={messageIdx === 0 ||
 					messageList.at(messageIdx - 1)?.user_id !== message.user_id ||
-					messageList.at(messageIdx - 1)?.meta?.model_id !== message?.meta?.model_id}
+					messageList.at(messageIdx - 1)?.meta?.model_id !== message?.meta?.model_id ||
+					message?.reply_to_message}
 				onDelete={() => {
 					messages = messages.filter((m) => m.id !== message.id);
 
@@ -123,6 +127,9 @@
 						return null;
 					});
 				}}
+				onReply={(message) => {
+					onReply(message);
+				}}
 				onThread={(id) => {
 					onThread(id);
 				}}

+ 90 - 5
src/lib/components/channel/Messages/Message.svelte

@@ -13,8 +13,9 @@
 	import { getContext, onMount } from 'svelte';
 	const i18n = getContext<Writable<i18nType>>('i18n');
 
-	import { settings, user, shortCodesToEmojis } from '$lib/stores';
+	import { formatDate } from '$lib/utils';
 
+	import { settings, user, shortCodesToEmojis } from '$lib/stores';
 	import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
 
 	import Markdown from '$lib/components/chat/Messages/Markdown.svelte';
@@ -32,18 +33,20 @@
 	import FaceSmile from '$lib/components/icons/FaceSmile.svelte';
 	import EmojiPicker from '$lib/components/common/EmojiPicker.svelte';
 	import ChevronRight from '$lib/components/icons/ChevronRight.svelte';
-	import { formatDate } from '$lib/utils';
 	import Emoji from '$lib/components/common/Emoji.svelte';
-	import { t } from 'i18next';
 	import Skeleton from '$lib/components/chat/Messages/Skeleton.svelte';
+	import ArrowUpLeftAlt from '$lib/components/icons/ArrowUpLeftAlt.svelte';
 
 	export let message;
 	export let showUserProfile = true;
 	export let thread = false;
+
+	export let replyToMessage = false;
 	export let disabled = false;
 
 	export let onDelete: Function = () => {};
 	export let onEdit: Function = () => {};
+	export let onReply: Function = () => {};
 	export let onThread: Function = () => {};
 	export let onReaction: Function = () => {};
 
@@ -65,9 +68,15 @@
 
 {#if message}
 	<div
+		id="message-{message.id}"
 		class="flex flex-col justify-between px-5 {showUserProfile
 			? 'pt-1.5 pb-0.5'
-			: ''} w-full max-w-full mx-auto group hover:bg-gray-300/5 dark:hover:bg-gray-700/5 transition relative"
+			: ''} w-full max-w-full mx-auto group hover:bg-gray-300/5 dark:hover:bg-gray-700/5 transition relative {replyToMessage
+			? 'border-l-4 border-blue-500 bg-blue-100/10 dark:bg-blue-100/5 pl-4'
+			: ''} {(message?.reply_to_message?.meta?.model_id ?? message?.reply_to_message?.user_id) ===
+		$user?.id
+			? 'border-l-4 border-orange-500 bg-orange-100/10 dark:bg-orange-100/5 pl-4'
+			: ''}"
 	>
 		{#if !edit && !disabled}
 			<div
@@ -95,6 +104,17 @@
 						</Tooltip>
 					</EmojiPicker>
 
+					<Tooltip content={$i18n.t('Reply')}>
+						<button
+							class="hover:bg-gray-100 dark:hover:bg-gray-800 transition rounded-lg p-0.5"
+							on:click={() => {
+								onReply(message);
+							}}
+						>
+							<ArrowUpLeftAlt className="size-5" />
+						</button>
+					</Tooltip>
+
 					{#if !thread}
 						<Tooltip content={$i18n.t('Reply in Thread')}>
 							<button
@@ -134,6 +154,56 @@
 			</div>
 		{/if}
 
+		{#if message?.reply_to_message?.user}
+			<div class="relative text-xs mb-1">
+				<div
+					class="absolute h-3 w-7 left-[18px] top-2 rounded-tl-lg border-t-2 border-l-2 border-gray-300 dark:border-gray-500 z-0"
+				></div>
+
+				<button
+					class="ml-12 flex items-center space-x-2 relative z-0"
+					on:click={() => {
+						const messageElement = document.getElementById(
+							`message-${message.reply_to_message.id}`
+						);
+						if (messageElement) {
+							messageElement.scrollIntoView({ behavior: 'smooth', block: 'center' });
+							messageElement.classList.add('highlight');
+							setTimeout(() => {
+								messageElement.classList.remove('highlight');
+							}, 2000);
+							return;
+						}
+					}}
+				>
+					{#if message?.reply_to_message?.meta?.model_id}
+						<img
+							src={`${WEBUI_API_BASE_URL}/models/model/profile/image?id=${message.reply_to_message.meta.model_id}`}
+							alt={message.reply_to_message.meta.model_name ??
+								message.reply_to_message.meta.model_id}
+							class="size-4 ml-0.5 rounded-full object-cover"
+						/>
+					{:else}
+						<img
+							src={message.reply_to_message.user?.profile_image_url ??
+								`${WEBUI_BASE_URL}/static/favicon.png`}
+							alt={message.reply_to_message.user?.name ?? $i18n.t('Unknown User')}
+							class="size-4 ml-0.5 rounded-full object-cover"
+						/>
+					{/if}
+
+					<div class="shrink-0">
+						{message?.reply_to_message.meta?.model_name ??
+							message?.reply_to_message.user?.name ??
+							$i18n.t('Unknown User')}
+					</div>
+
+					<div class="italic text-sm text-gray-500 dark:text-gray-400 line-clamp-1 w-full flex-1">
+						<Markdown id={`${message.id}-reply-to`} content={message?.reply_to_message?.content} />
+					</div>
+				</button>
+			</div>
+		{/if}
 		<div
 			class=" flex w-full message-{message.id}"
 			id="message-{message.id}"
@@ -151,7 +221,7 @@
 						<ProfilePreview user={message.user}>
 							<ProfileImage
 								src={message.user?.profile_image_url ?? `${WEBUI_BASE_URL}/static/favicon.png`}
-								className={'size-8 translate-y-1 ml-0.5'}
+								className={'size-8 ml-0.5'}
 							/>
 						</ProfilePreview>
 					{/if}
@@ -348,3 +418,18 @@
 		</div>
 	</div>
 {/if}
+
+<style>
+	.highlight {
+		animation: highlightAnimation 2s ease-in-out;
+	}
+
+	@keyframes highlightAnimation {
+		0% {
+			background-color: rgba(0, 60, 255, 0.1);
+		}
+		100% {
+			background-color: transparent;
+		}
+	}
+</style>

+ 18 - 3
src/lib/components/channel/Thread.svelte

@@ -22,11 +22,14 @@
 	let messages = null;
 	let top = false;
 
+	let messagesContainerElement = null;
+	let chatInputElement = null;
+
+	let replyToMessage = null;
+
 	let typingUsers = [];
 	let typingUsersTimeout = {};
 
-	let messagesContainerElement = null;
-
 	$: if (threadId) {
 		initHandler();
 	}
@@ -128,12 +131,15 @@
 
 		const res = await sendMessage(localStorage.token, channel.id, {
 			parent_id: threadId,
+			reply_to_id: replyToMessage?.id ?? null,
 			content: content,
 			data: data
 		}).catch((error) => {
 			toast.error(`${error}`);
 			return null;
 		});
+
+		replyToMessage = null;
 	};
 
 	const onChange = async () => {
@@ -180,9 +186,16 @@
 				<Messages
 					id={threadId}
 					{channel}
-					{messages}
 					{top}
+					{messages}
+					{replyToMessage}
 					thread={true}
+					onReply={async (message) => {
+						replyToMessage = message;
+
+						await tick();
+						chatInputElement?.focus();
+					}}
 					onLoad={async () => {
 						const newMessages = await getChannelThreadMessages(
 							localStorage.token,
@@ -207,6 +220,8 @@
 
 			<div class=" pb-[1rem] px-2.5 w-full">
 				<MessageInput
+					bind:replyToMessage
+					bind:chatInputElement
 					id={threadId}
 					disabled={!channel?.write_access}
 					placeholder={!channel?.write_access

+ 20 - 0
src/lib/components/icons/ArrowUpLeftAlt.svelte

@@ -0,0 +1,20 @@
+<script lang="ts">
+	export let className = 'w-4 h-4';
+	export let strokeWidth = '1.5';
+</script>
+
+<svg
+	class={className}
+	aria-hidden="true"
+	xmlns="http://www.w3.org/2000/svg"
+	stroke-width={strokeWidth}
+	fill="none"
+	stroke="currentColor"
+	viewBox="0 0 24 24"
+	><path d="M10.25 4.75L6.75 8.25L10.25 11.75" stroke-linecap="round" stroke-linejoin="round"
+	></path><path
+		d="M6.75 8.25L12.75 8.25C14.9591 8.25 16.75 10.0409 16.75 12.25V19.25"
+		stroke-linecap="round"
+		stroke-linejoin="round"
+	></path></svg
+>