Browse Source

feat: chat history support

Timothy J. Baek 1 year ago
parent
commit
4e4076e267
1 changed files with 388 additions and 128 deletions
  1. 388 128
      src/routes/+page.svelte

+ 388 - 128
src/routes/+page.svelte

@@ -39,6 +39,22 @@
 	let title = '';
 	let prompt = '';
 	let messages = [];
+	let history = {
+		messages: {},
+		currentId: null
+	};
+
+	$: if (history.currentId !== null) {
+		let _messages = [];
+
+		let currentMessage = history.messages[history.currentId];
+		while (currentMessage !== null) {
+			_messages.unshift({ ...currentMessage });
+			currentMessage =
+				currentMessage.parentId !== null ? history.messages[currentMessage.parentId] : null;
+		}
+		messages = _messages;
+	}
 
 	let showSettings = false;
 	let stopResponseFlag = false;
@@ -260,8 +276,13 @@
 		if (init || messages.length > 0) {
 			chatId = uuidv4();
 			autoScroll = true;
-			messages = [];
+
 			title = '';
+			messages = [];
+			history = {
+				messages: {},
+				currentId: null
+			};
 
 			settings = JSON.parse(localStorage.getItem('settings') ?? JSON.stringify(settings));
 
@@ -311,18 +332,58 @@
 
 	const loadChat = async (id) => {
 		const chat = await db.get('chats', id);
+		console.log(chat);
 		if (chatId !== chat.id) {
-			if (chat.messages.length > 0) {
-				chat.messages.at(-1).done = true;
+			if ('history' in chat) {
+				history = chat.history;
+			} else {
+				let _history = {
+					messages: {},
+					currentId: null
+				};
+
+				let parentMessageId = null;
+				let messageId = null;
+
+				for (const message of chat.messages) {
+					messageId = uuidv4();
+
+					if (parentMessageId !== null) {
+						_history.messages[parentMessageId].childrenIds = [
+							..._history.messages[parentMessageId].childrenIds,
+							messageId
+						];
+					}
+
+					_history.messages[messageId] = {
+						...message,
+						id: messageId,
+						parentId: parentMessageId,
+						childrenIds: []
+					};
+
+					parentMessageId = messageId;
+				}
+				_history.currentId = messageId;
+
+				history = _history;
 			}
-			messages = chat.messages;
+
+			console.log(history);
+
 			title = chat.title;
 			chatId = chat.id;
 			selectedModel = chat.model ?? selectedModel;
 			settings.system = chat.system ?? settings.system;
 			settings.temperature = chat.temperature ?? settings.temperature;
+			autoScroll = true;
 
 			await tick();
+
+			if (messages.length > 0) {
+				history.messages[messages.at(-1).id].done = true;
+			}
+
 			renderLatex();
 
 			hljs.highlightAll();
@@ -368,7 +429,8 @@
 				options: chat.options,
 				title: chat.title,
 				timestamp: chat.timestamp,
-				messages: chat.messages
+				messages: chat.messages,
+				history: chat.history
 			});
 		}
 		chats = await db.getAllFromIndex('chats', 'timestamp');
@@ -386,35 +448,44 @@
 		showSettings = true;
 	};
 
-	const editMessage = async (messageIdx) => {
-		messages = messages.map((message, idx) => {
-			if (messageIdx === idx) {
-				message.edit = true;
-				message.editedContent = message.content;
-			}
-			return message;
-		});
+	const editMessageHandler = async (messageId) => {
+		// let editMessage = history.messages[messageId];
+		history.messages[messageId].edit = true;
+		history.messages[messageId].editedContent = history.messages[messageId].content;
 	};
 
-	const confirmEditMessage = async (messageIdx) => {
-		let userPrompt = messages.at(messageIdx).editedContent;
+	const confirmEditMessage = async (messageId) => {
+		history.messages[messageId].edit = false;
 
-		messages.splice(messageIdx, messages.length - messageIdx);
-		messages = messages;
+		let userPrompt = history.messages[messageId].editedContent;
+		let userMessageId = uuidv4();
 
-		await submitPrompt(userPrompt);
-	};
+		let userMessage = {
+			id: userMessageId,
+			parentId: history.messages[messageId].parentId,
+			childrenIds: [],
+			role: 'user',
+			content: userPrompt
+		};
 
-	const cancelEditMessage = (messageIdx) => {
-		messages = messages.map((message, idx) => {
-			if (messageIdx === idx) {
-				message.edit = undefined;
-				message.editedContent = undefined;
-			}
-			return message;
-		});
+		let messageParentId = history.messages[messageId].parentId;
 
-		console.log(messages);
+		if (messageParentId !== null) {
+			history.messages[messageParentId].childrenIds = [
+				...history.messages[messageParentId].childrenIds,
+				userMessageId
+			];
+		}
+
+		history.messages[userMessageId] = userMessage;
+		history.currentId = userMessageId;
+
+		await sendPrompt(userPrompt, userMessageId);
+	};
+
+	const cancelEditMessage = (messageId) => {
+		history.messages[messageId].edit = false;
+		history.messages[messageId].editedContent = undefined;
 	};
 
 	const rateMessage = async (messageIdx, rating) => {
@@ -434,12 +505,89 @@
 				temperature: settings.temperature
 			},
 			timestamp: Date.now(),
-			messages: messages
+			messages: messages,
+			history: history
 		});
 
 		console.log(messages);
 	};
 
+	const showPreviousMessage = async (message) => {
+		if (message.parentId !== null) {
+			let messageId =
+				history.messages[message.parentId].childrenIds[
+					Math.max(history.messages[message.parentId].childrenIds.indexOf(message.id) - 1, 0)
+				];
+
+			if (message.id !== messageId) {
+				let messageChildrenIds = history.messages[messageId].childrenIds;
+
+				while (messageChildrenIds.length !== 0) {
+					messageId = messageChildrenIds.at(-1);
+					messageChildrenIds = history.messages[messageId].childrenIds;
+				}
+
+				history.currentId = messageId;
+			}
+		} else {
+			let childrenIds = Object.values(history.messages)
+				.filter((message) => message.parentId === null)
+				.map((message) => message.id);
+			let messageId = childrenIds[Math.max(childrenIds.indexOf(message.id) - 1, 0)];
+
+			if (message.id !== messageId) {
+				let messageChildrenIds = history.messages[messageId].childrenIds;
+
+				while (messageChildrenIds.length !== 0) {
+					messageId = messageChildrenIds.at(-1);
+					messageChildrenIds = history.messages[messageId].childrenIds;
+				}
+
+				history.currentId = messageId;
+			}
+		}
+	};
+
+	const showNextMessage = async (message) => {
+		if (message.parentId !== null) {
+			let messageId =
+				history.messages[message.parentId].childrenIds[
+					Math.min(
+						history.messages[message.parentId].childrenIds.indexOf(message.id) + 1,
+						history.messages[message.parentId].childrenIds.length - 1
+					)
+				];
+
+			if (message.id !== messageId) {
+				let messageChildrenIds = history.messages[messageId].childrenIds;
+
+				while (messageChildrenIds.length !== 0) {
+					messageId = messageChildrenIds.at(-1);
+					messageChildrenIds = history.messages[messageId].childrenIds;
+				}
+
+				history.currentId = messageId;
+			}
+		} else {
+			let childrenIds = Object.values(history.messages)
+				.filter((message) => message.parentId === null)
+				.map((message) => message.id);
+			let messageId =
+				childrenIds[Math.min(childrenIds.indexOf(message.id) + 1, childrenIds.length - 1)];
+
+			if (message.id !== messageId) {
+				let messageChildrenIds = history.messages[messageId].childrenIds;
+
+				while (messageChildrenIds.length !== 0) {
+					messageId = messageChildrenIds.at(-1);
+					messageChildrenIds = history.messages[messageId].childrenIds;
+				}
+
+				history.currentId = messageId;
+			}
+		}
+	};
+
 	//////////////////////////
 	// Ollama functions
 	//////////////////////////
@@ -507,21 +655,46 @@
 		}
 	};
 
-	const sendPrompt = async (userPrompt) => {
+	const sendPrompt = async (userPrompt, parentId) => {
+		// await Promise.all(
+		// 	selectedModels.map((model) => {
+		// 		if (selectedModel.includes('gpt-')) {
+		// 			await sendPromptOpenAI(userPrompt, parentId);
+		// 		} else {
+		// 			await sendPromptOllama(userPrompt, parentId);
+		// 		}
+		// 	})
+		// );
+
 		if (selectedModel.includes('gpt-')) {
-			await sendPromptOpenAI(userPrompt);
+			await sendPromptOpenAI(userPrompt, parentId);
 		} else {
-			await sendPromptOllama(userPrompt);
+			await sendPromptOllama(userPrompt, parentId);
 		}
+
+		console.log(history);
 	};
 
-	const sendPromptOllama = async (userPrompt) => {
+	const sendPromptOllama = async (userPrompt, parentId) => {
+		let responseMessageId = uuidv4();
+
 		let responseMessage = {
+			parentId: parentId,
+			id: responseMessageId,
+			childrenIds: [],
 			role: 'assistant',
 			content: ''
 		};
 
-		messages = [...messages, responseMessage];
+		history.messages[responseMessageId] = responseMessage;
+		history.currentId = responseMessageId;
+		if (parentId !== null) {
+			history.messages[parentId].childrenIds = [
+				...history.messages[parentId].childrenIds,
+				responseMessageId
+			];
+		}
+
 		window.scrollTo({ top: document.body.scrollHeight });
 
 		const res = await fetch(`${API_BASE_URL}/generate`, {
@@ -542,8 +715,9 @@
 				},
 				format: settings.requestFormat ?? undefined,
 				context:
-					messages.length > 3 && messages.at(-3).context != undefined
-						? messages.at(-3).context
+					history.messages[parentId] !== null &&
+					history.messages[parentId].parentId in history.messages
+						? history.messages[history.messages[parentId].parentId]?.context ?? undefined
 						: undefined
 			})
 		});
@@ -608,7 +782,8 @@
 					temperature: settings.temperature
 				},
 				timestamp: Date.now(),
-				messages: messages
+				messages: messages,
+				history: history
 			});
 		}
 
@@ -715,7 +890,8 @@
 							temperature: settings.temperature
 						},
 						timestamp: Date.now(),
-						messages: messages
+						messages: messages,
+						history: history
 					});
 				}
 
@@ -747,13 +923,22 @@
 		} else {
 			document.getElementById('chat-textarea').style.height = '';
 
-			messages = [
-				...messages,
-				{
-					role: 'user',
-					content: userPrompt
-				}
-			];
+			let userMessageId = uuidv4();
+
+			let userMessage = {
+				id: userMessageId,
+				parentId: messages.length !== 0 ? messages.at(-1).id : null,
+				childrenIds: [],
+				role: 'user',
+				content: userPrompt
+			};
+
+			if (messages.length !== 0) {
+				history.messages[messages.at(-1).id].childrenIds.push(userMessageId);
+			}
+
+			history.messages[userMessageId] = userMessage;
+			history.currentId = userMessageId;
 
 			prompt = '';
 
@@ -767,7 +952,8 @@
 					},
 					title: 'New Chat',
 					timestamp: Date.now(),
-					messages: messages
+					messages: messages,
+					history: history
 				});
 				chats = await db.getAllFromIndex('chats', 'timestamp');
 			}
@@ -776,7 +962,7 @@
 				window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' });
 			}, 50);
 
-			await sendPrompt(userPrompt);
+			await sendPrompt(userPrompt, userMessageId);
 
 			chats = await db.getAllFromIndex('chats', 'timestamp');
 		}
@@ -791,7 +977,7 @@
 			let userMessage = messages.at(-1);
 			let userPrompt = userMessage.content;
 
-			await sendPrompt(userPrompt);
+			await sendPrompt(userPrompt, userMessage.id);
 
 			chats = await db.getAllFromIndex('chats', 'timestamp');
 		}
@@ -1078,7 +1264,7 @@
 															<div class=" w-full">
 																<textarea
 																	class=" bg-transparent outline-none w-full resize-none"
-																	bind:value={message.editedContent}
+																	bind:value={history.messages[message.id].editedContent}
 																	on:input={(e) => {
 																		e.target.style.height = '';
 																		e.target.style.height = `${e.target.scrollHeight}px`;
@@ -1093,7 +1279,7 @@
 																	<button
 																		class="px-4 py-2.5 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded-lg"
 																		on:click={() => {
-																			confirmEditMessage(messageIdx);
+																			confirmEditMessage(message.id);
 																		}}
 																	>
 																		Save & Submit
@@ -1102,7 +1288,7 @@
 																	<button
 																		class=" px-4 py-2.5 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 text-gray-700 dark:text-gray-100 transition outline outline-1 outline-gray-200 dark:outline-gray-600 rounded-lg"
 																		on:click={() => {
-																			cancelEditMessage(messageIdx);
+																			cancelEditMessage(message.id);
 																		}}
 																	>
 																		Cancel
@@ -1113,90 +1299,113 @@
 															<div class="w-full">
 																{message.content}
 
-																<!-- <div class=" flex justify-start space-x-1">
-																	<div class="flex self-center">
-																		<button
-																			class="self-center"
-																			on:click={() => {
-																				message.selectedContentIdx = Math.max(
-																					0,
-																					message.selectedContentIdx - 1
-																				);
-																				messages = messages;
-																			}}
-																		>
-																			<svg
-																				xmlns="http://www.w3.org/2000/svg"
-																				viewBox="0 0 20 20"
-																				fill="currentColor"
-																				class="w-4 h-4"
+																<div class=" flex justify-start space-x-1">
+																	{#if message.parentId !== null && message.parentId in history.messages && (history.messages[message.parentId]?.childrenIds.length ?? 0) > 1}
+																		<div class="flex self-center">
+																			<button
+																				class="self-center"
+																				on:click={() => {
+																					showPreviousMessage(message);
+																				}}
 																			>
-																				<path
-																					fill-rule="evenodd"
-																					d="M12.79 5.23a.75.75 0 01-.02 1.06L8.832 10l3.938 3.71a.75.75 0 11-1.04 1.08l-4.5-4.25a.75.75 0 010-1.08l4.5-4.25a.75.75 0 011.06.02z"
-																					clip-rule="evenodd"
-																				/>
-																			</svg>
-																		</button>
-
-																		<div class="text-xs font-bold self-center">
-																			{message.selectedContentIdx + 1} / {message.contents.length}
+																				<svg
+																					xmlns="http://www.w3.org/2000/svg"
+																					viewBox="0 0 20 20"
+																					fill="currentColor"
+																					class="w-4 h-4"
+																				>
+																					<path
+																						fill-rule="evenodd"
+																						d="M12.79 5.23a.75.75 0 01-.02 1.06L8.832 10l3.938 3.71a.75.75 0 11-1.04 1.08l-4.5-4.25a.75.75 0 010-1.08l4.5-4.25a.75.75 0 011.06.02z"
+																						clip-rule="evenodd"
+																					/>
+																				</svg>
+																			</button>
+
+																			<div class="text-xs font-bold self-center">
+																				{history.messages[message.parentId].childrenIds.indexOf(
+																					message.id
+																				) + 1} / {history.messages[message.parentId].childrenIds
+																					.length}
+																			</div>
+
+																			<button
+																				class="self-center"
+																				on:click={() => {
+																					showNextMessage(message);
+																				}}
+																			>
+																				<svg
+																					xmlns="http://www.w3.org/2000/svg"
+																					viewBox="0 0 20 20"
+																					fill="currentColor"
+																					class="w-4 h-4"
+																				>
+																					<path
+																						fill-rule="evenodd"
+																						d="M7.21 14.77a.75.75 0 01.02-1.06L11.168 10 7.23 6.29a.75.75 0 111.04-1.08l4.5 4.25a.75.75 0 010 1.08l-4.5 4.25a.75.75 0 01-1.06-.02z"
+																						clip-rule="evenodd"
+																					/>
+																				</svg>
+																			</button>
 																		</div>
-
-																		<button
-																			class="self-center"
-																			on:click={() => {
-																				message.selectedContentIdx = Math.min(
-																					message.contents.length - 1,
-																					message.selectedContentIdx + 1
-																				);
-																				messages = messages;
-
-																				console.log(message);
-																			}}
-																		>
-																			<svg
-																				xmlns="http://www.w3.org/2000/svg"
-																				viewBox="0 0 20 20"
-																				fill="currentColor"
-																				class="w-4 h-4"
+																	{:else if message.parentId === null && Object.values(history.messages).filter((message) => message.parentId === null).length > 1}
+																		<div class="flex self-center">
+																			<button
+																				class="self-center"
+																				on:click={() => {
+																					showPreviousMessage(message);
+																				}}
 																			>
-																				<path
-																					fill-rule="evenodd"
-																					d="M7.21 14.77a.75.75 0 01.02-1.06L11.168 10 7.23 6.29a.75.75 0 111.04-1.08l4.5 4.25a.75.75 0 010 1.08l-4.5 4.25a.75.75 0 01-1.06-.02z"
-																					clip-rule="evenodd"
-																				/>
-																			</svg>
-																		</button>
-																	</div>
-																	<button
-																		class="invisible group-hover:visible p-1 rounded dark:hover:bg-gray-800 transition"
-																		on:click={() => {
-																			editMessage(messageIdx);
-																		}}
-																	>
-																		<svg
-																			xmlns="http://www.w3.org/2000/svg"
-																			fill="none"
-																			viewBox="0 0 24 24"
-																			stroke-width="1.5"
-																			stroke="currentColor"
-																			class="w-4 h-4"
-																		>
-																			<path
-																				stroke-linecap="round"
-																				stroke-linejoin="round"
-																				d="M16.862 4.487l1.687-1.688a1.875 1.875 0 112.652 2.652L6.832 19.82a4.5 4.5 0 01-1.897 1.13l-2.685.8.8-2.685a4.5 4.5 0 011.13-1.897L16.863 4.487zm0 0L19.5 7.125"
-																			/>
-																		</svg>
-																	</button>
-																</div> -->
+																				<svg
+																					xmlns="http://www.w3.org/2000/svg"
+																					viewBox="0 0 20 20"
+																					fill="currentColor"
+																					class="w-4 h-4"
+																				>
+																					<path
+																						fill-rule="evenodd"
+																						d="M12.79 5.23a.75.75 0 01-.02 1.06L8.832 10l3.938 3.71a.75.75 0 11-1.04 1.08l-4.5-4.25a.75.75 0 010-1.08l4.5-4.25a.75.75 0 011.06.02z"
+																						clip-rule="evenodd"
+																					/>
+																				</svg>
+																			</button>
+
+																			<div class="text-xs font-bold self-center">
+																				{Object.values(history.messages)
+																					.filter((message) => message.parentId === null)
+																					.map((message) => message.id)
+																					.indexOf(message.id) + 1} / {Object.values(
+																					history.messages
+																				).filter((message) => message.parentId === null).length}
+																			</div>
+
+																			<button
+																				class="self-center"
+																				on:click={() => {
+																					showNextMessage(message);
+																				}}
+																			>
+																				<svg
+																					xmlns="http://www.w3.org/2000/svg"
+																					viewBox="0 0 20 20"
+																					fill="currentColor"
+																					class="w-4 h-4"
+																				>
+																					<path
+																						fill-rule="evenodd"
+																						d="M7.21 14.77a.75.75 0 01.02-1.06L11.168 10 7.23 6.29a.75.75 0 111.04-1.08l4.5 4.25a.75.75 0 010 1.08l-4.5 4.25a.75.75 0 01-1.06-.02z"
+																						clip-rule="evenodd"
+																					/>
+																				</svg>
+																			</button>
+																		</div>
+																	{/if}
 
-																<div class=" flex justify-start space-x-1">
 																	<button
 																		class="invisible group-hover:visible p-1 rounded dark:hover:bg-gray-800 transition"
 																		on:click={() => {
-																			editMessage(messageIdx);
+																			editMessageHandler(message.id);
 																		}}
 																	>
 																		<svg
@@ -1223,6 +1432,56 @@
 
 															{#if message.done}
 																<div class=" flex justify-start space-x-1 -mt-2">
+																	{#if message.parentId !== null && message.parentId in history.messages && (history.messages[message.parentId]?.childrenIds.length ?? 0) > 1}
+																		<div class="flex self-center">
+																			<button
+																				class="self-center"
+																				on:click={() => {
+																					showPreviousMessage(message);
+																				}}
+																			>
+																				<svg
+																					xmlns="http://www.w3.org/2000/svg"
+																					viewBox="0 0 20 20"
+																					fill="currentColor"
+																					class="w-4 h-4"
+																				>
+																					<path
+																						fill-rule="evenodd"
+																						d="M12.79 5.23a.75.75 0 01-.02 1.06L8.832 10l3.938 3.71a.75.75 0 11-1.04 1.08l-4.5-4.25a.75.75 0 010-1.08l4.5-4.25a.75.75 0 011.06.02z"
+																						clip-rule="evenodd"
+																					/>
+																				</svg>
+																			</button>
+
+																			<div class="text-xs font-bold self-center">
+																				{history.messages[message.parentId].childrenIds.indexOf(
+																					message.id
+																				) + 1} / {history.messages[message.parentId].childrenIds
+																					.length}
+																			</div>
+
+																			<button
+																				class="self-center"
+																				on:click={() => {
+																					showNextMessage(message);
+																				}}
+																			>
+																				<svg
+																					xmlns="http://www.w3.org/2000/svg"
+																					viewBox="0 0 20 20"
+																					fill="currentColor"
+																					class="w-4 h-4"
+																				>
+																					<path
+																						fill-rule="evenodd"
+																						d="M7.21 14.77a.75.75 0 01.02-1.06L11.168 10 7.23 6.29a.75.75 0 111.04-1.08l4.5 4.25a.75.75 0 010 1.08l-4.5 4.25a.75.75 0 01-1.06-.02z"
+																						clip-rule="evenodd"
+																					/>
+																				</svg>
+																			</button>
+																		</div>
+																	{/if}
 																	<button
 																		class="{messageIdx + 1 === messages.length
 																			? 'visible'
@@ -1344,6 +1603,7 @@
 									class=" bg-white/20 p-1.5 rounded-full"
 									on:click={() => {
 										window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' });
+										autoScroll = true;
 									}}
 								>
 									<svg