Browse Source

feat: follow ups backend integration

Timothy Jaeryang Baek 3 weeks ago
parent
commit
185249623b

+ 5 - 4
backend/open_webui/config.py

@@ -1419,12 +1419,13 @@ FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
 )
 
 DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task:
-Suggest 3-5 relevant follow-up questions or discussion prompts based on the chat history to help continue or deepen the conversation.
+SSuggest 3-5 relevant follow-up questions or prompts that the **user** might naturally ask next in this conversation, based on the chat history, to help continue or deepen the discussion.
 ### Guidelines:
+- Phrase all follow-up questions from the user’s perspective, addressed to the assistant or expert.
 - Make questions concise, clear, and directly related to the discussed topic(s).
-- Only generate follow-ups that make sense given the chat content and do not repeat what was already covered.
-- If the conversation is very short or not specific, suggest more general follow-ups.
-- Use the chat's primary language; default to English if multilingual.
+- Only suggest follow-ups that make sense given the chat content and do not repeat what was already covered.
+- If the conversation is very short or not specific, suggest more general (but relevant) follow-ups the user might ask.
+- Use the conversation's primary language; default to English if multilingual.
 - Response must be a JSON array of strings, no extra text or formatting.
 ### Output:
 JSON format: { "follow_ups": ["Question 1?", "Question 2?", "Question 3?"] }

+ 54 - 0
backend/open_webui/utils/middleware.py

@@ -32,6 +32,7 @@ from open_webui.socket.main import (
 from open_webui.routers.tasks import (
     generate_queries,
     generate_title,
+    generate_follow_ups,
     generate_image_prompt,
     generate_chat_tags,
 )
@@ -1104,6 +1105,59 @@ async def process_chat_response(
                             }
                         )
 
+                if (
+                    TASKS.FOLLOW_UP_GENERATION in tasks
+                    and tasks[TASKS.FOLLOW_UP_GENERATION]
+                ):
+                    res = await generate_follow_ups(
+                        request,
+                        {
+                            "model": message["model"],
+                            "messages": messages,
+                            "message_id": metadata["message_id"],
+                            "chat_id": metadata["chat_id"],
+                        },
+                        user,
+                    )
+
+                    if res and isinstance(res, dict):
+                        if len(res.get("choices", [])) == 1:
+                            follow_ups_string = (
+                                res.get("choices", [])[0]
+                                .get("message", {})
+                                .get("content", "")
+                            )
+                        else:
+                            follow_ups_string = ""
+
+                        follow_ups_string = follow_ups_string[
+                            follow_ups_string.find("{") : follow_ups_string.rfind("}")
+                            + 1
+                        ]
+
+                        try:
+                            follow_ups = json.loads(follow_ups_string).get(
+                                "follow_ups", []
+                            )
+                            Chats.upsert_message_to_chat_by_id_and_message_id(
+                                metadata["chat_id"],
+                                metadata["message_id"],
+                                {
+                                    "followUps": follow_ups,
+                                },
+                            )
+
+                            await event_emitter(
+                                {
+                                    "type": "chat:message:follow_ups",
+                                    "data": {
+                                        "follow_ups": follow_ups,
+                                    },
+                                }
+                            )
+                        except Exception as e:
+                            pass
+
                 if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
                     res = await generate_chat_tags(
                         request,

+ 72 - 0
src/lib/apis/index.ts

@@ -612,6 +612,78 @@ export const generateTitle = async (
 	}
 };
 
+export const generateFollowUps = async (
+	token: string = '',
+	model: string,
+	messages: string,
+	chat_id?: string
+) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/v1/tasks/follow_ups/completions`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			model: model,
+			messages: messages,
+			...(chat_id && { chat_id: chat_id })
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.error(err);
+			if ('detail' in err) {
+				error = err.detail;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	try {
+		// Step 1: Safely extract the response string
+		const response = res?.choices[0]?.message?.content ?? '';
+
+		// Step 2: Attempt to fix common JSON format issues like single quotes
+		const sanitizedResponse = response.replace(/['‘’`]/g, '"'); // Convert single quotes to double quotes for valid JSON
+
+		// Step 3: Find the relevant JSON block within the response
+		const jsonStartIndex = sanitizedResponse.indexOf('{');
+		const jsonEndIndex = sanitizedResponse.lastIndexOf('}');
+
+		// Step 4: Check if we found a valid JSON block (with both `{` and `}`)
+		if (jsonStartIndex !== -1 && jsonEndIndex !== -1) {
+			const jsonResponse = sanitizedResponse.substring(jsonStartIndex, jsonEndIndex + 1);
+
+			// Step 5: Parse the JSON block
+			const parsed = JSON.parse(jsonResponse);
+
+			// Step 6: If there's a "follow_ups" key, return the follow_ups array; otherwise, return an empty array
+			if (parsed && parsed.follow_ups) {
+				return Array.isArray(parsed.follow_ups) ? parsed.follow_ups : [];
+			} else {
+				return [];
+			}
+		}
+
+		// If no valid JSON block found, return an empty array
+		return [];
+	} catch (e) {
+		// Catch and safely return empty array on any parsing errors
+		console.error('Failed to parse response: ', e);
+		return [];
+	}
+};
+
 export const generateTags = async (
 	token: string = '',
 	model: string,

+ 13 - 10
src/lib/components/chat/Chat.svelte

@@ -300,6 +300,8 @@
 					message.content = data.content;
 				} else if (type === 'chat:message:files' || type === 'files') {
 					message.files = data.files;
+				} else if (type === 'chat:message:follow_ups') {
+					message.followUps = data.follow_ups;
 				} else if (type === 'chat:title') {
 					chatTitle.set(data);
 					currentChatPage.set(1);
@@ -1678,19 +1680,20 @@
 				chat_id: $chatId,
 				id: responseMessageId,
 
-				...(!$temporaryChatEnabled &&
-				(messages.length == 1 ||
-					(messages.length == 2 &&
-						messages.at(0)?.role === 'system' &&
-						messages.at(1)?.role === 'user')) &&
-				(selectedModels[0] === model.id || atSelectedModel !== undefined)
-					? {
-							background_tasks: {
+				background_tasks: {
+					...(!$temporaryChatEnabled &&
+					(messages.length == 1 ||
+						(messages.length == 2 &&
+							messages.at(0)?.role === 'system' &&
+							messages.at(1)?.role === 'user')) &&
+					(selectedModels[0] === model.id || atSelectedModel !== undefined)
+						? {
 								title_generation: $settings?.title?.auto ?? true,
 								tags_generation: $settings?.autoTags ?? true
 							}
-						}
-					: {}),
+						: {}),
+					follow_up_generation: $settings?.autoFollowUps ?? true
+				},
 
 				...(stream && (model.info?.meta?.capabilities?.usage ?? false)
 					? {

+ 1 - 1
src/lib/components/chat/Messages/ProfileImage.svelte

@@ -15,7 +15,7 @@
 			  src.startsWith('/')
 			? src
 			: `/user.png`}
-	class=" {className} object-cover rounded-full -translate-y-[1px]"
+	class=" {className} object-cover rounded-full"
 	alt="profile"
 	draggable="false"
 />

+ 14 - 1
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -48,6 +48,8 @@
 	import ContentRenderer from './ContentRenderer.svelte';
 	import { KokoroWorker } from '$lib/workers/KokoroWorker';
 	import FileItem from '$lib/components/common/FileItem.svelte';
+	import FollowUps from './ResponseMessage/FollowUps.svelte';
+	import { fade } from 'svelte/transition';
 
 	interface MessageType {
 		id: string;
@@ -606,7 +608,7 @@
 			/>
 		</div>
 
-		<div class="flex-auto w-0 pl-1 relative">
+		<div class="flex-auto w-0 pl-1 relative -translate-y-0.5">
 			<Name>
 				<Tooltip content={model?.name ?? message.model} placement="top-start">
 					<span class="line-clamp-1 text-black dark:text-white">
@@ -1425,6 +1427,17 @@
 							}}
 						/>
 					{/if}
+
+					{#if isLastMessage && message.done && !readOnly && (message?.followUps ?? []).length > 0}
+						<div class="mt-2.5" in:fade={{ duration: 50 }}>
+							<FollowUps
+								followUps={message?.followUps}
+								onClick={() => {
+									console.log('Follow-ups clicked');
+								}}
+							/>
+						</div>
+					{/if}
 				{/if}
 			</div>
 

+ 27 - 0
src/lib/components/chat/Messages/ResponseMessage/FollowUps.svelte

@@ -0,0 +1,27 @@
+<script lang="ts">
+	import { onMount, tick, getContext } from 'svelte';
+
+	const i18n = getContext('i18n');
+
+	export let followUps: string[] = [];
+	export let onClick: (followUp: string) => void = () => {};
+</script>
+
+<div class="mt-2">
+	<div class="text-xs text-gray-500 font-medium">
+		{$i18n.t('Follow up')}
+	</div>
+
+	<div class="flex flex-col text-left gap-1 mt-1.5">
+		{#each followUps as followUp, idx (idx)}
+			<button
+				class=" mr-2 py-1 bg-transparent text-left text-xs"
+				on:click={() => onClick(followUp)}
+				title={followUp}
+				aria-label={followUp}
+			>
+				{followUp}
+			</button>
+		{/each}
+	</div>
+</div>

+ 26 - 0
src/lib/components/chat/Settings/Interface.svelte

@@ -17,6 +17,7 @@
 
 	// Addons
 	let titleAutoGenerate = true;
+	let autoFollowUps = true;
 	let autoTags = true;
 
 	let responseAutoCopy = false;
@@ -197,6 +198,11 @@
 		});
 	};
 
+	const toggleAutoFollowUps = async () => {
+		autoFollowUps = !autoFollowUps;
+		saveSettings({ autoFollowUps });
+	};
+
 	const toggleAutoTags = async () => {
 		autoTags = !autoTags;
 		saveSettings({ autoTags });
@@ -619,6 +625,26 @@
 				</div>
 			</div>
 
+			<div>
+				<div class=" py-0.5 flex w-full justify-between">
+					<div class=" self-center text-xs">{$i18n.t('Follow-Up Auto-Generation')}</div>
+
+					<button
+						class="p-1 px-3 text-xs flex rounded-sm transition"
+						on:click={() => {
+							toggleAutoFollowUps();
+						}}
+						type="button"
+					>
+						{#if autoFollowUps === true}
+							<span class="ml-2 self-center">{$i18n.t('On')}</span>
+						{:else}
+							<span class="ml-2 self-center">{$i18n.t('Off')}</span>
+						{/if}
+					</button>
+				</div>
+			</div>
+
 			<div>
 				<div class=" py-0.5 flex w-full justify-between">
 					<div class=" self-center text-xs">{$i18n.t('Chat Tags Auto-Generation')}</div>