Преглед изворни кода

feat: channel/thread @ model

Timothy Jaeryang Baek пре 3 недеља
родитељ
комит
4fe97d8794

+ 8 - 2
backend/open_webui/models/messages.py

@@ -201,8 +201,14 @@ class MessageTable:
         with get_db() as db:
             message = db.get(Message, id)
             message.content = form_data.content
-            message.data = form_data.data
-            message.meta = form_data.meta
+            message.data = {
+                **(message.data if message.data else {}),
+                **(form_data.data if form_data.data else {}),
+            }
+            message.meta = {
+                **(message.meta if message.meta else {}),
+                **(form_data.meta if form_data.meta else {}),
+            }
             message.updated_at = int(time.time_ns())
             db.commit()
             db.refresh(message)

+ 166 - 18
backend/open_webui/routers/channels.py

@@ -24,9 +24,17 @@ from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import SRC_LOG_LEVELS
 
 
+from open_webui.utils.models import (
+    get_all_models,
+    get_filtered_models,
+)
+from open_webui.utils.chat import generate_chat_completion
+
+
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access, get_users_with_access
 from open_webui.utils.webhook import post_webhook
+from open_webui.utils.channels import extract_mentions, replace_mentions
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -221,13 +229,131 @@ async def send_notification(name, webui_url, channel, message, active_user_ids):
     return True
 
 
-@router.post("/{id}/messages/post", response_model=Optional[MessageModel])
-async def post_new_message(
-    request: Request,
-    id: str,
-    form_data: MessageForm,
-    background_tasks: BackgroundTasks,
-    user=Depends(get_verified_user),
+async def model_response_handler(request, channel, message, user):
+    MODELS = {
+        model["id"]: model
+        for model in get_filtered_models(await get_all_models(request, user=user), user)
+    }
+
+    mentions = extract_mentions(message.content)
+    message_content = replace_mentions(message.content)
+
+    # check if any of the mentions are models
+    model_mentions = [mention for mention in mentions if mention["id_type"] == "M"]
+    if not model_mentions:
+        return False
+
+    for mention in model_mentions:
+        model_id = mention["id"]
+        model = MODELS.get(model_id, None)
+
+        if model:
+            try:
+                # reverse to get in chronological order
+                thread_messages = Messages.get_messages_by_parent_id(
+                    channel.id,
+                    message.parent_id if message.parent_id else message.id,
+                )[::-1]
+
+                response_message, channel = await new_message_handler(
+                    request,
+                    channel.id,
+                    MessageForm(
+                        **{
+                            "parent_id": (
+                                message.parent_id if message.parent_id else message.id
+                            ),
+                            "content": f"",
+                            "data": {},
+                            "meta": {
+                                "model_id": model_id,
+                                "model_name": model.get("name", model_id),
+                            },
+                        }
+                    ),
+                    user,
+                )
+
+                thread_history = []
+                message_users = {}
+
+                for thread_message in thread_messages:
+                    message_user = None
+                    if thread_message.user_id not in message_users:
+                        message_user = Users.get_user_by_id(thread_message.user_id)
+                        message_users[thread_message.user_id] = message_user
+                    else:
+                        message_user = message_users[thread_message.user_id]
+
+                    if thread_message.meta and thread_message.meta.get(
+                        "model_id", None
+                    ):
+                        # If the message was sent by a model, use the model name
+                        message_model_id = thread_message.meta.get("model_id", None)
+                        message_model = MODELS.get(message_model_id, None)
+                        username = (
+                            message_model.get("name", message_model_id)
+                            if message_model
+                            else message_model_id
+                        )
+                    else:
+                        username = message_user.name if message_user else "Unknown"
+
+                    thread_history.append(
+                        f"{username}: {replace_mentions(thread_message.content)}"
+                    )
+
+                system_message = {
+                    "role": "system",
+                    "content": f"You are {model.get('name', model_id)}, an AI assistant participating in a threaded conversation. Be helpful, concise, and conversational."
+                    + (
+                        f"Here's the thread history:\n\n{''.join([f'{msg}' for msg in thread_history])}\n\nContinue the conversation naturally, addressing the most recent message while being aware of the full context."
+                        if thread_history
+                        else ""
+                    ),
+                }
+
+                form_data = {
+                    "model": model_id,
+                    "messages": [
+                        system_message,
+                        {
+                            "role": "user",
+                            "content": f"{user.name if user else 'User'}: {message_content}",
+                        },
+                    ],
+                    "stream": False,
+                }
+
+                res = await generate_chat_completion(
+                    request,
+                    form_data=form_data,
+                    user=user,
+                )
+
+                if res:
+                    await update_message_by_id(
+                        channel.id,
+                        response_message.id,
+                        MessageForm(
+                            **{
+                                "content": res["choices"][0]["message"]["content"],
+                                "meta": {
+                                    "done": True,
+                                },
+                            }
+                        ),
+                        user,
+                    )
+            except Exception as e:
+                log.info(e)
+                pass
+
+    return True
+
+
+async def new_message_handler(
+    request: Request, id: str, form_data: MessageForm, user=Depends(get_verified_user)
 ):
     channel = Channels.get_channel_by_id(id)
     if not channel:
@@ -301,21 +427,43 @@ async def post_new_message(
                         },
                         to=f"channel:{channel.id}",
                     )
+        return MessageModel(**message.model_dump()), channel
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
 
-            active_user_ids = get_user_ids_from_room(f"channel:{channel.id}")
 
-            async def background_handler():
-                await send_notification(
-                    request.app.state.WEBUI_NAME,
-                    request.app.state.config.WEBUI_URL,
-                    channel,
-                    message,
-                    active_user_ids,
-                )
+@router.post("/{id}/messages/post", response_model=Optional[MessageModel])
+async def post_new_message(
+    request: Request,
+    id: str,
+    form_data: MessageForm,
+    background_tasks: BackgroundTasks,
+    user=Depends(get_verified_user),
+):
+
+    try:
+        message, channel = await new_message_handler(request, id, form_data, user)
+        active_user_ids = get_user_ids_from_room(f"channel:{channel.id}")
+
+        async def background_handler():
+            await model_response_handler(request, channel, message, user)
+            await send_notification(
+                request.app.state.WEBUI_NAME,
+                request.app.state.config.WEBUI_URL,
+                channel,
+                message,
+                active_user_ids,
+            )
 
-            background_tasks.add_task(background_handler)
+        background_tasks.add_task(background_handler)
 
-        return MessageModel(**message.model_dump())
+        return message
+
+    except HTTPException as e:
+        raise e
     except Exception as e:
         log.exception(e)
         raise HTTPException(

+ 38 - 2
backend/open_webui/routers/models.py

@@ -1,4 +1,6 @@
 from typing import Optional
+import io
+import base64
 
 from open_webui.models.models import (
     ModelForm,
@@ -10,12 +12,13 @@ from open_webui.models.models import (
 
 from pydantic import BaseModel
 from open_webui.constants import ERROR_MESSAGES
-from fastapi import APIRouter, Depends, HTTPException, Request, status
+from fastapi import APIRouter, Depends, HTTPException, Request, status, Response
+from fastapi.responses import FileResponse, StreamingResponse
 
 
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
-from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
+from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
 
 router = APIRouter()
 
@@ -129,6 +132,39 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
         )
 
 
+###########################
+# GetModelById
+###########################
+
+
+@router.get("/model/profile/image")
+async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
+    model = Models.get_model_by_id(id)
+    if model:
+        if model.meta.profile_image_url:
+            if model.meta.profile_image_url.startswith("http"):
+                return Response(
+                    status_code=status.HTTP_302_FOUND,
+                    headers={"Location": model.meta.profile_image_url},
+                )
+            elif model.meta.profile_image_url.startswith("data:image"):
+                try:
+                    header, base64_data = model.meta.profile_image_url.split(",", 1)
+                    image_data = base64.b64decode(base64_data)
+                    image_buffer = io.BytesIO(image_data)
+
+                    return StreamingResponse(
+                        image_buffer,
+                        media_type="image/png",
+                        headers={"Content-Disposition": "inline; filename=image.png"},
+                    )
+                except Exception as e:
+                    pass
+        return FileResponse(f"{STATIC_DIR}/favicon.png")
+    else:
+        return FileResponse(f"{STATIC_DIR}/favicon.png")
+
+
 ############################
 # ToggleModelById
 ############################

+ 31 - 0
backend/open_webui/utils/channels.py

@@ -0,0 +1,31 @@
+import re
+
+
+def extract_mentions(message: str, triggerChar: str = "@"):
+    # Escape triggerChar in case it's a regex special character
+    triggerChar = re.escape(triggerChar)
+    pattern = rf"<{triggerChar}([A-Z]):([^|>]+)"
+
+    matches = re.findall(pattern, message)
+    return [{"id_type": id_type, "id": id_value} for id_type, id_value in matches]
+
+
+def replace_mentions(message: str, triggerChar: str = "@", use_label: bool = True):
+    """
+    Replace mentions in the message with either their label (after the pipe `|`)
+    or their id if no label exists.
+
+    Example:
+      "<@M:gpt-4.1|GPT-4>" -> "GPT-4"   (if use_label=True)
+      "<@M:gpt-4.1|GPT-4>" -> "gpt-4.1" (if use_label=False)
+    """
+    # Escape triggerChar
+    triggerChar = re.escape(triggerChar)
+
+    def replacer(match):
+        id_type, id_value, label = match.groups()
+        return label if use_label and label else id_value
+
+    # Regex captures: idType, id, optional label
+    pattern = rf"<{triggerChar}([A-Z]):([^|>]+)(?:\|([^>]+))?>"
+    return re.sub(pattern, replacer, message)

+ 2 - 1
src/lib/components/channel/Messages.svelte

@@ -95,7 +95,8 @@
 				{message}
 				{thread}
 				showUserProfile={messageIdx === 0 ||
-					messageList.at(messageIdx - 1)?.user_id !== message.user_id}
+					messageList.at(messageIdx - 1)?.user_id !== message.user_id ||
+					messageList.at(messageIdx - 1)?.meta?.model_id !== message?.meta?.model_id}
 				onDelete={() => {
 					messages = messages.filter((m) => m.id !== message.id);
 

+ 31 - 13
src/lib/components/channel/Messages/Message.svelte

@@ -15,7 +15,7 @@
 
 	import { settings, user, shortCodesToEmojis } from '$lib/stores';
 
-	import { WEBUI_BASE_URL } from '$lib/constants';
+	import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
 
 	import Markdown from '$lib/components/chat/Messages/Markdown.svelte';
 	import ProfileImage from '$lib/components/chat/Messages/ProfileImage.svelte';
@@ -34,6 +34,8 @@
 	import ChevronRight from '$lib/components/icons/ChevronRight.svelte';
 	import { formatDate } from '$lib/utils';
 	import Emoji from '$lib/components/common/Emoji.svelte';
+	import { t } from 'i18next';
+	import Skeleton from '$lib/components/chat/Messages/Skeleton.svelte';
 
 	export let message;
 	export let showUserProfile = true;
@@ -138,12 +140,20 @@
 		>
 			<div class={`shrink-0 mr-3 w-9`}>
 				{#if showUserProfile}
-					<ProfilePreview user={message.user}>
-						<ProfileImage
-							src={message.user?.profile_image_url ?? `${WEBUI_BASE_URL}/static/favicon.png`}
-							className={'size-8 translate-y-1 ml-0.5'}
+					{#if message?.meta?.model_id}
+						<img
+							src={`${WEBUI_API_BASE_URL}/models/model/profile/image?id=${message.meta.model_id}`}
+							alt={message.meta.model_name ?? message.meta.model_id}
+							class="size-8 translate-y-1 ml-0.5 object-cover rounded-full"
 						/>
-					</ProfilePreview>
+					{:else}
+						<ProfilePreview user={message.user}>
+							<ProfileImage
+								src={message.user?.profile_image_url ?? `${WEBUI_BASE_URL}/static/favicon.png`}
+								className={'size-8 translate-y-1 ml-0.5'}
+							/>
+						</ProfilePreview>
+					{/if}
 				{:else}
 					<!-- <div class="w-7 h-7 rounded-full bg-transparent" /> -->
 
@@ -163,7 +173,11 @@
 				{#if showUserProfile}
 					<Name>
 						<div class=" self-end text-base shrink-0 font-medium truncate">
-							{message?.user?.name}
+							{#if message?.meta?.model_id}
+								{message?.meta?.model_name ?? message?.meta?.model_id}
+							{:else}
+								{message?.user?.name}
+							{/if}
 						</div>
 
 						{#if message.created_at}
@@ -251,12 +265,16 @@
 					</div>
 				{:else}
 					<div class=" min-w-full markdown-prose">
-						<Markdown
-							id={message.id}
-							content={message.content}
-						/>{#if message.created_at !== message.updated_at}<span class="text-gray-500 text-[10px]"
-								>(edited)</span
-							>{/if}
+						{#if (message?.content ?? '').trim() === '' && message?.meta?.model_id}
+							<Skeleton />
+						{:else}
+							<Markdown
+								id={message.id}
+								content={message.content}
+							/>{#if message.created_at !== message.updated_at && (message?.meta?.model_id ?? null) === null}<span
+									class="text-gray-500 text-[10px]">({$i18n.t('edited')})</span
+								>{/if}
+						{/if}
 					</div>
 
 					{#if (message?.reactions ?? []).length > 0}

+ 1 - 1
src/lib/components/channel/Messages/Message/UserStatusLinkPreview.svelte

@@ -27,7 +27,7 @@
 
 {#if user}
 	<LinkPreview.Content
-		class="w-full max-w-[260px] rounded-2xl border border-gray-100  dark:border-gray-800 z-50 bg-white dark:bg-gray-850 dark:text-white shadow-lg transition"
+		class="w-full max-w-[260px] rounded-2xl border border-gray-100  dark:border-gray-800 z-999 bg-white dark:bg-gray-850 dark:text-white shadow-lg transition"
 		{side}
 		{align}
 		{sideOffset}