Browse Source

feat: add in-message progress indicator for web search

Jun Siang Cheah 1 year ago
parent
commit
3baeda7edc

+ 2 - 11
src/lib/apis/rag/index.ts

@@ -519,9 +519,7 @@ export const runWebSearch = async (
 	query: string,
 	collection_name?: string
 ): Promise<SearchDocument | undefined> => {
-	let error = null;
-
-	const res = await fetch(`${RAG_API_BASE_URL}/websearch`, {
+	return await fetch(`${RAG_API_BASE_URL}/websearch`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'application/json',
@@ -529,7 +527,7 @@ export const runWebSearch = async (
 		},
 		body: JSON.stringify({
 			query,
-			collection_name
+			collection_name: collection_name ?? ''
 		})
 	})
 		.then(async (res) => {
@@ -538,15 +536,8 @@ export const runWebSearch = async (
 		})
 		.catch((err) => {
 			console.log(err);
-			error = err.detail;
 			return undefined;
 		});
-
-	if (error) {
-		throw error;
-	}
-
-	return res;
 };
 
 export interface SearchDocument {

+ 56 - 0
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -369,6 +369,62 @@
 				class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-headings:my-0 prose-p:m-0 prose-p:-mb-6 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-img:my-0 prose-ul:-my-4 prose-ol:-my-4 prose-li:-my-3 prose-ul:-mb-6 prose-ol:-mb-8 prose-ol:p-0 prose-li:-mb-4 whitespace-pre-line"
 			>
 				<div>
+					{#if message.progress}
+						<div class="my-2.5 w-full flex overflow-x-auto gap-2 flex-wrap">
+							<div>
+								<button
+									class="h-16  flex items-center space-x-3 px-2.5 dark:bg-gray-600 rounded-xl border border-gray-200 dark:border-none text-left"
+									type="button"
+								>
+									<div class="p-2.5 bg-red-400 text-white rounded-lg">
+										<svg
+											class=" w-6 h-6 translate-y-[0.5px]"
+											fill="currentColor"
+											viewBox="0 0 24 24"
+											xmlns="http://www.w3.org/2000/svg"
+											><style>
+												.spinner_qM83 {
+													animation: spinner_8HQG 1.05s infinite;
+												}
+												.spinner_oXPr {
+													animation-delay: 0.1s;
+												}
+												.spinner_ZTLf {
+													animation-delay: 0.2s;
+												}
+												@keyframes spinner_8HQG {
+													0%,
+													57.14% {
+														animation-timing-function: cubic-bezier(0.33, 0.66, 0.66, 1);
+														transform: translate(0);
+													}
+													28.57% {
+														animation-timing-function: cubic-bezier(0.33, 0, 0.66, 0.33);
+														transform: translateY(-6px);
+													}
+													100% {
+														transform: translate(0);
+													}
+												}
+											</style><circle class="spinner_qM83" cx="4" cy="12" r="2.5" /><circle
+												class="spinner_qM83 spinner_oXPr"
+												cx="12"
+												cy="12"
+												r="2.5"
+											/><circle class="spinner_qM83 spinner_ZTLf" cx="20" cy="12" r="2.5" /></svg
+										>
+									</div>
+
+									<div class="flex flex-col justify-center -space-y-0.5">
+										<div class=" dark:text-gray-100 text-sm font-medium line-clamp-2 text-wrap">
+											{message.progress}
+										</div>
+									</div>
+								</button>
+							</div>
+						</div>
+					{/if}
+
 					{#if edit === true}
 						<div class=" w-full">
 							<textarea

+ 41 - 31
src/routes/(app)/+page.svelte

@@ -31,7 +31,11 @@
 		updateChatById
 	} from '$lib/apis/chats';
 	import { queryCollection, queryDoc, runWebSearch } from '$lib/apis/rag';
-	import { generateOpenAIChatCompletion, generateSearchQuery, generateTitle } from '$lib/apis/openai';
+	import {
+		generateOpenAIChatCompletion,
+		generateSearchQuery,
+		generateTitle
+	} from '$lib/apis/openai';
 
 	import MessageInput from '$lib/components/chat/MessageInput.svelte';
 	import Messages from '$lib/components/chat/Messages.svelte';
@@ -286,36 +290,7 @@
 					}
 
 					if (useWebSearch) {
-						// TODO: Toasts are temporary indicators for web search
-						toast.info($i18n.t('Generating search query'));
-						const searchQuery = await generateChatSearchQuery(prompt);
-						if (searchQuery) {
-							toast.info($i18n.t('Searching the web for \'{{searchQuery}}\'', { searchQuery }));
-							const searchDocUuid = uuidv4();
-							const searchDocument = await runWebSearch(localStorage.token, searchQuery, searchDocUuid);
-							if (searchDocument) {
-								const parentMessage = history.messages[parentId];
-								if (!parentMessage.files) {
-									parentMessage.files = [];
-								}
-								parentMessage.files.push({
-									collection_name: searchDocument.collection_name,
-									name: searchQuery,
-									type: 'doc',
-									upload_status: true,
-									error: ""
-								});
-								// Find message in messages and update it
-								const messageIndex = messages.findIndex((message) => message.id === parentId);
-								if (messageIndex !== -1) {
-									messages[messageIndex] = parentMessage;
-								}
-							} else {
-								toast.warning($i18n.t('No search results found'));
-							}
-						} else {
-							toast.warning($i18n.t('No search query generated'));
-						}
+						await runWebSearchForPrompt(parentId, responseMessageId, prompt);
 					}
 
 					if (model?.external) {
@@ -332,6 +307,41 @@
 		await chats.set(await getChatList(localStorage.token));
 	};
 
+	const runWebSearchForPrompt = async (parentId: string, responseId: string, prompt: string) => {
+		const responseMessage = history.messages[responseId];
+		responseMessage.progress = $i18n.t('Generating search query');
+		messages = messages;
+		const searchQuery = await generateChatSearchQuery(prompt);
+		if (!searchQuery) {
+			toast.warning($i18n.t('No search query generated'));
+			responseMessage.progress = undefined;
+			messages = messages;
+			return;
+		}
+		responseMessage.progress = $i18n.t("Searching the web for '{{searchQuery}}'", { searchQuery });
+		messages = messages;
+		const searchDocument = await runWebSearch(localStorage.token, searchQuery);
+		if (!searchDocument) {
+			toast.warning($i18n.t('No search results found'));
+			responseMessage.progress = undefined;
+			messages = messages;
+			return;
+		}
+		const parentMessage = history.messages[parentId];
+		if (!parentMessage.files) {
+			parentMessage.files = [];
+		}
+		parentMessage.files.push({
+			collection_name: searchDocument!.collection_name,
+			name: searchQuery,
+			type: 'doc',
+			upload_status: true,
+			error: ''
+		});
+		responseMessage.progress = undefined;
+		messages = messages;
+	};
+
 	const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => {
 		model = model.id;
 		const responseMessage = history.messages[responseMessageId];

+ 39 - 30
src/routes/(app)/c/[id]/+page.svelte

@@ -291,36 +291,7 @@
 					}
 
 					if (useWebSearch) {
-						// TODO: Toasts are temporary indicators for web search
-						toast.info($i18n.t('Generating search query'));
-						const searchQuery = await generateChatSearchQuery(prompt);
-						if (searchQuery) {
-							toast.info($i18n.t('Searching the web for \'{{searchQuery}}\'', { searchQuery }));
-							const searchDocUuid = uuidv4();
-							const searchDocument = await runWebSearch(localStorage.token, searchQuery, searchDocUuid);
-							if (searchDocument) {
-								const parentMessage = history.messages[parentId];
-								if (!parentMessage.files) {
-									parentMessage.files = [];
-								}
-								parentMessage.files.push({
-									collection_name: searchDocument.collection_name,
-									name: searchQuery,
-									type: 'doc',
-									upload_status: true,
-									error: ""
-								});
-								// Find message in messages and update it
-								const messageIndex = messages.findIndex((message) => message.id === parentId);
-								if (messageIndex !== -1) {
-									messages[messageIndex] = parentMessage;
-								}
-							} else {
-								toast.warning($i18n.t('No search results found'));
-							}
-						} else {
-							toast.warning($i18n.t('No search query generated'));
-						}
+						await runWebSearchForPrompt(parentId, responseMessageId, prompt);
 					}
 
 					if (model?.external) {
@@ -337,6 +308,44 @@
 		await chats.set(await getChatList(localStorage.token));
 	};
 
+	const runWebSearchForPrompt = async (parentId: string, responseId: string, prompt: string) => {
+		const responseMessage = history.messages[responseId];
+		responseMessage.progress = $i18n.t('Generating search query');
+		messages = messages;
+		const searchQuery = await generateChatSearchQuery(prompt);
+		if (!searchQuery) {
+			toast.warning($i18n.t('No search query generated'));
+			responseMessage.progress = undefined;
+			messages = messages;
+			return;
+		}
+		responseMessage.progress = $i18n.t("Searching the web for '{{searchQuery}}'", { searchQuery });
+		messages = messages;
+		const searchDocument = await runWebSearch(
+			localStorage.token,
+			searchQuery,
+		);
+		if (!searchDocument) {
+			toast.warning($i18n.t('No search results found'));
+			responseMessage.progress = undefined;
+			messages = messages;
+			return;
+		}
+		const parentMessage = history.messages[parentId];
+		if (!parentMessage.files) {
+			parentMessage.files = [];
+		}
+		parentMessage.files.push({
+			collection_name: searchDocument!.collection_name,
+			name: searchQuery,
+			type: 'doc',
+			upload_status: true,
+			error: ''
+		});
+		responseMessage.progress = undefined;
+		messages = messages;
+	};
+
 	const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => {
 		model = model.id;
 		const responseMessage = history.messages[responseMessageId];