Browse Source

Merge pull request #656 from ollama-webui/openai-voice

feat: openai tts support
Timothy Jaeryang Baek 1 year ago
parent
commit
7f3ba3d2ac

+ 68 - 4
backend/apps/openai/main.py

@@ -1,15 +1,19 @@
 from fastapi import FastAPI, Request, Response, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
 
 import requests
 import json
 from pydantic import BaseModel
 
+
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import decode_token, get_current_user
-from config import OPENAI_API_BASE_URL, OPENAI_API_KEY
+from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR
+
+import hashlib
+from pathlib import Path
 
 app = FastAPI()
 app.add_middleware(
@@ -66,6 +70,68 @@ async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_u
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
 
+@app.post("/audio/speech")
+async def speech(request: Request, user=Depends(get_current_user)):
+    target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
+
+    if user.role not in ["user", "admin"]:
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+    if app.state.OPENAI_API_KEY == "":
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
+
+    body = await request.body()
+
+    name = hashlib.sha256(body).hexdigest()
+
+    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
+    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
+    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
+
+    # Check if the file already exists in the cache
+    if file_path.is_file():
+        return FileResponse(file_path)
+
+    headers = {}
+    headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
+    headers["Content-Type"] = "application/json"
+
+    try:
+        print("openai")
+        r = requests.post(
+            url=target_url,
+            data=body,
+            headers=headers,
+            stream=True,
+        )
+
+        r.raise_for_status()
+
+        # Save the streaming content to a file
+        with open(file_path, "wb") as f:
+            for chunk in r.iter_content(chunk_size=8192):
+                f.write(chunk)
+
+        with open(file_body_path, "w") as f:
+            json.dump(json.loads(body.decode("utf-8")), f)
+
+        # Return the saved file
+        return FileResponse(file_path)
+
+    except Exception as e:
+        print(e)
+        error_detail = "Ollama WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"External: {res['error']}"
+            except:
+                error_detail = f"External: {e}"
+
+        raise HTTPException(status_code=r.status_code, detail=error_detail)
+
+
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def proxy(path: str, request: Request, user=Depends(get_current_user)):
     target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
@@ -129,8 +195,6 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
 
             response_data = r.json()
 
-            print(type(response_data))
-
             if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
                 response_data["data"] = list(
                     filter(lambda model: "gpt" in model["id"], response_data["data"])

+ 8 - 0
backend/config.py

@@ -35,6 +35,14 @@ FRONTEND_BUILD_DIR = str(Path(os.getenv("FRONTEND_BUILD_DIR", "../build")))
 UPLOAD_DIR = f"{DATA_DIR}/uploads"
 Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
 
+
+####################################
+# Cache DIR
+####################################
+
+CACHE_DIR = f"{DATA_DIR}/cache"
+Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
+
 ####################################
 # OLLAMA_API_BASE_URL
 ####################################

+ 31 - 0
src/lib/apis/openai/index.ts

@@ -229,3 +229,34 @@ export const generateOpenAIChatCompletion = async (token: string = '', body: obj
 
 	return res;
 };
+
+export const synthesizeOpenAISpeech = async (
+	token: string = '',
+	speaker: string = 'alloy',
+	text: string = ''
+) => {
+	let error = null;
+
+	const res = await fetch(`${OPENAI_API_BASE_URL}/audio/speech`, {
+		method: 'POST',
+		headers: {
+			Authorization: `Bearer ${token}`,
+			'Content-Type': 'application/json'
+		},
+		body: JSON.stringify({
+			model: 'tts-1',
+			input: text,
+			voice: speaker
+		})
+	}).catch((err) => {
+		console.log(err);
+		error = err;
+		return null;
+	});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 91 - 7
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -1,7 +1,8 @@
 <script lang="ts">
+	import toast from 'svelte-french-toast';
 	import dayjs from 'dayjs';
 	import { marked } from 'marked';
-	import { settings, voices } from '$lib/stores';
+	import { settings } from '$lib/stores';
 	import tippy from 'tippy.js';
 	import auto_render from 'katex/dist/contrib/auto-render.mjs';
 	import 'katex/dist/katex.min.css';
@@ -13,6 +14,8 @@
 	import Skeleton from './Skeleton.svelte';
 	import CodeBlock from './CodeBlock.svelte';
 
+	import { synthesizeOpenAISpeech } from '$lib/apis/openai';
+
 	export let modelfiles = [];
 	export let message;
 	export let siblings;
@@ -31,7 +34,10 @@
 	let editedContent = '';
 
 	let tooltipInstance = null;
+
+	let audioMap = {};
 	let speaking = null;
+	let loadingSpeech = false;
 
 	$: tokens = marked.lexer(message.content);
 
@@ -114,12 +120,58 @@
 		if (speaking) {
 			speechSynthesis.cancel();
 			speaking = null;
+
+			audioMap[message.id].pause();
+			audioMap[message.id].currentTime = 0;
 		} else {
 			speaking = true;
-			const speak = new SpeechSynthesisUtterance(message.content);
-			const voice = $voices?.filter((v) => v.name === $settings?.speakVoice)?.at(0) ?? undefined;
-			speak.voice = voice;
-			speechSynthesis.speak(speak);
+
+			if ($settings?.speech?.engine === 'openai') {
+				loadingSpeech = true;
+				const res = await synthesizeOpenAISpeech(
+					localStorage.token,
+					$settings?.speech?.speaker,
+					message.content
+				).catch((error) => {
+					toast.error(error);
+					return null;
+				});
+
+				if (res) {
+					const blob = await res.blob();
+					const blobUrl = URL.createObjectURL(blob);
+					console.log(blobUrl);
+
+					loadingSpeech = false;
+
+					const audio = new Audio(blobUrl);
+					audioMap[message.id] = audio;
+
+					audio.onended = () => {
+						speaking = null;
+					};
+					audio.play().catch((e) => console.error('Error playing audio:', e));
+				}
+			} else {
+				let voices = [];
+				const getVoicesLoop = setInterval(async () => {
+					voices = await speechSynthesis.getVoices();
+					if (voices.length > 0) {
+						clearInterval(getVoicesLoop);
+
+						const voice =
+							voices?.filter((v) => v.name === $settings?.speech?.speaker)?.at(0) ?? undefined;
+
+						const speak = new SpeechSynthesisUtterance(message.content);
+
+						speak.onend = () => {
+							speaking = null;
+						};
+						speak.voice = voice;
+						speechSynthesis.speak(speak);
+					}
+				}, 100);
+			}
 		}
 	};
 
@@ -410,10 +462,42 @@
 												? 'visible'
 												: 'invisible group-hover:visible'} p-1 rounded dark:hover:bg-gray-800 transition"
 											on:click={() => {
-												toggleSpeakMessage(message);
+												if (!loadingSpeech) {
+													toggleSpeakMessage(message);
+												}
 											}}
 										>
-											{#if speaking}
+											{#if loadingSpeech}
+												<svg
+													class=" w-4 h-4"
+													fill="currentColor"
+													viewBox="0 0 24 24"
+													xmlns="http://www.w3.org/2000/svg"
+													><style>
+														.spinner_S1WN {
+															animation: spinner_MGfb 0.8s linear infinite;
+															animation-delay: -0.8s;
+														}
+														.spinner_Km9P {
+															animation-delay: -0.65s;
+														}
+														.spinner_JApP {
+															animation-delay: -0.5s;
+														}
+														@keyframes spinner_MGfb {
+															93.75%,
+															100% {
+																opacity: 0.2;
+															}
+														}
+													</style><circle class="spinner_S1WN" cx="4" cy="12" r="3" /><circle
+														class="spinner_S1WN spinner_Km9P"
+														cx="12"
+														cy="12"
+														r="3"
+													/><circle class="spinner_S1WN spinner_JApP" cx="20" cy="12" r="3" /></svg
+												>
+											{:else if speaking}
 												<svg
 													xmlns="http://www.w3.org/2000/svg"
 													fill="none"

+ 82 - 39
src/lib/components/chat/Settings/Voice.svelte

@@ -1,27 +1,49 @@
 <script lang="ts">
 	import { createEventDispatcher, onMount } from 'svelte';
-	import { voices } from '$lib/stores';
 	const dispatch = createEventDispatcher();
 
 	export let saveSettings: Function;
 
 	// Voice
-	let speakVoice = '';
+	let engines = ['', 'openai'];
+	let engine = '';
 
-	onMount(async () => {
-		let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
+	let voices = [];
+	let speaker = '';
 
-		speakVoice = settings.speakVoice ?? '';
+	const getOpenAIVoices = () => {
+		voices = [
+			{ name: 'alloy' },
+			{ name: 'echo' },
+			{ name: 'fable' },
+			{ name: 'onyx' },
+			{ name: 'nova' },
+			{ name: 'shimmer' }
+		];
+	};
 
+	const getWebAPIVoices = () => {
 		const getVoicesLoop = setInterval(async () => {
-			const _voices = await speechSynthesis.getVoices();
-			await voices.set(_voices);
+			voices = await speechSynthesis.getVoices();
 
 			// do your loop
-			if (_voices.length > 0) {
+			if (voices.length > 0) {
 				clearInterval(getVoicesLoop);
 			}
 		}, 100);
+	};
+
+	onMount(async () => {
+		let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
+
+		engine = settings?.speech?.engine ?? '';
+		speaker = settings?.speech?.speaker ?? '';
+
+		if (engine === 'openai') {
+			getOpenAIVoices();
+		} else {
+			getWebAPIVoices();
+		}
 	});
 </script>
 
@@ -29,24 +51,52 @@
 	class="flex flex-col h-full justify-between space-y-3 text-sm"
 	on:submit|preventDefault={() => {
 		saveSettings({
-			speakVoice: speakVoice !== '' ? speakVoice : undefined
+			speech: {
+				engine: engine !== '' ? engine : undefined,
+				speaker: speaker !== '' ? speaker : undefined
+			}
 		});
 		dispatch('save');
 	}}
 >
 	<div class=" space-y-3">
-		<div class=" space-y-3">
+		<div class=" py-0.5 flex w-full justify-between">
+			<div class=" self-center text-sm font-medium">Speech Engine</div>
+			<div class="flex items-center relative">
+				<select
+					class="w-fit pr-8 rounded py-2 px-2 text-xs bg-transparent outline-none text-right"
+					bind:value={engine}
+					placeholder="Select a mode"
+					on:change={(e) => {
+						if (e.target.value === 'openai') {
+							getOpenAIVoices();
+							speaker = 'alloy';
+						} else {
+							getWebAPIVoices();
+							speaker = '';
+						}
+					}}
+				>
+					<option value="">Default (Web API)</option>
+					<option value="openai">Open AI</option>
+				</select>
+			</div>
+		</div>
+
+		<hr class=" dark:border-gray-700" />
+
+		{#if engine === ''}
 			<div>
-				<div class=" mb-2.5 text-sm font-medium">Set Default Voice</div>
+				<div class=" mb-2.5 text-sm font-medium">Set Voice</div>
 				<div class="flex w-full">
 					<div class="flex-1">
 						<select
 							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-							bind:value={speakVoice}
+							bind:value={speaker}
 							placeholder="Select a voice"
 						>
 							<option value="" selected>Default</option>
-							{#each $voices.filter((v) => v.localService === true) as voice}
+							{#each voices.filter((v) => v.localService === true) as voice}
 								<option value={voice.name} class="bg-gray-100 dark:bg-gray-700">{voice.name}</option
 								>
 							{/each}
@@ -54,32 +104,25 @@
 					</div>
 				</div>
 			</div>
-		</div>
-
-		<!--
-							<div>
-								<div class=" mb-2.5 text-sm font-medium">
-									Gravatar Email <span class=" text-gray-400 text-sm">(optional)</span>
-								</div>
-								<div class="flex w-full">
-									<div class="flex-1">
-										<input
-											class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-											placeholder="Enter Your Email"
-											bind:value={gravatarEmail}
-											autocomplete="off"
-											type="email"
-										/>
-									</div>
-								</div>
-								<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
-									Changes user profile image to match your <a
-										class=" text-gray-500 dark:text-gray-300 font-medium"
-										href="https://gravatar.com/"
-										target="_blank">Gravatar.</a
-									>
-								</div>
-							</div> -->
+		{:else if engine === 'openai'}
+			<div>
+				<div class=" mb-2.5 text-sm font-medium">Set Voice</div>
+				<div class="flex w-full">
+					<div class="flex-1">
+						<select
+							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+							bind:value={speaker}
+							placeholder="Select a voice"
+						>
+							{#each voices as voice}
+								<option value={voice.name} class="bg-gray-100 dark:bg-gray-700">{voice.name}</option
+								>
+							{/each}
+						</select>
+					</div>
+				</div>
+			</div>
+		{/if}
 	</div>
 
 	<div class="flex justify-end pt-3 text-sm font-medium">

+ 0 - 1
src/lib/stores/index.ts

@@ -12,7 +12,6 @@ export const chatId = writable('');
 export const chats = writable([]);
 export const tags = writable([]);
 export const models = writable([]);
-export const voices = writable([]);
 
 export const modelfiles = writable([]);
 export const prompts = writable([]);