Browse Source

feat: prototype frontend web search integration

Jun Siang Cheah 1 year ago
parent
commit
2660a6e5b8

+ 9 - 3
backend/apps/rag/main.py

@@ -93,6 +93,7 @@ from config import (
     CHUNK_OVERLAP,
     RAG_TEMPLATE,
     ENABLE_RAG_LOCAL_WEB_FETCH,
+    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
 )
 
 from constants import ERROR_MESSAGES
@@ -538,18 +539,23 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
             detail=ERROR_MESSAGES.DEFAULT(e),
         )
 
+
 def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
     # Check if the URL is valid
     if not validate_url(url):
         raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    return WebBaseLoader(url, verify_ssl=verify_ssl)
+    return WebBaseLoader(
+        url,
+        verify_ssl=verify_ssl,
+        requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+    )
 
 
 def validate_url(url: Union[str, Sequence[str]]):
     if isinstance(url, str):
         if isinstance(validators.url(url), validators.ValidationError):
             raise ValueError(ERROR_MESSAGES.INVALID_URL)
-        if not ENABLE_LOCAL_WEB_FETCH:
+        if not ENABLE_RAG_LOCAL_WEB_FETCH:
             # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
             parsed_url = urllib.parse.urlparse(url)
             # Get IPv4 and IPv6 addresses
@@ -593,7 +599,7 @@ def store_websearch(form_data: SearchForm, user=Depends(get_current_user)):
             )
         urls = [result.link for result in web_results]
         loader = get_web_loader(urls)
-        data = loader.load()
+        data = loader.aload()
 
         collection_name = form_data.collection_name
         if collection_name == "":

+ 3 - 3
backend/apps/rag/search/brave.py

@@ -3,7 +3,7 @@ import logging
 import requests
 
 from apps.rag.search.main import SearchResult
-from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
+from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -22,7 +22,7 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]:
         "Accept-Encoding": "gzip",
         "X-Subscription-Token": api_key,
     }
-    params = {"q": query, "count": WEB_SEARCH_RESULT_COUNT}
+    params = {"q": query, "count": RAG_WEB_SEARCH_RESULT_COUNT}
 
     response = requests.get(url, headers=headers, params=params)
     response.raise_for_status()
@@ -33,5 +33,5 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]:
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("snippet")
         )
-        for result in results[:WEB_SEARCH_RESULT_COUNT]
+        for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
     ]

+ 2 - 2
backend/apps/rag/search/google_pse.py

@@ -4,7 +4,7 @@ import logging
 import requests
 
 from apps.rag.search.main import SearchResult
-from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
+from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -27,7 +27,7 @@ def search_google_pse(
         "cx": search_engine_id,
         "q": query,
         "key": api_key,
-        "num": WEB_SEARCH_RESULT_COUNT,
+        "num": RAG_WEB_SEARCH_RESULT_COUNT,
     }
 
     response = requests.request("GET", url, headers=headers, params=params)

+ 2 - 2
backend/apps/rag/search/searxng.py

@@ -3,7 +3,7 @@ import logging
 import requests
 
 from apps.rag.search.main import SearchResult
-from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
+from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -40,5 +40,5 @@ def search_searxng(query_url: str, query: str) -> list[SearchResult]:
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("content")
         )
-        for result in sorted_results[:WEB_SEARCH_RESULT_COUNT]
+        for result in sorted_results[:RAG_WEB_SEARCH_RESULT_COUNT]
     ]

+ 2 - 2
backend/apps/rag/search/serper.py

@@ -4,7 +4,7 @@ import logging
 import requests
 
 from apps.rag.search.main import SearchResult
-from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
+from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -35,5 +35,5 @@ def search_serper(api_key: str, query: str) -> list[SearchResult]:
             title=result.get("title"),
             snippet=result.get("description"),
         )
-        for result in results[:WEB_SEARCH_RESULT_COUNT]
+        for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
     ]

+ 2 - 2
backend/apps/rag/search/serpstack.py

@@ -4,7 +4,7 @@ import logging
 import requests
 
 from apps.rag.search.main import SearchResult
-from config import SRC_LOG_LEVELS, WEB_SEARCH_RESULT_COUNT
+from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -39,5 +39,5 @@ def search_serpstack(
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("snippet")
         )
-        for result in results[:WEB_SEARCH_RESULT_COUNT]
+        for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
     ]

+ 4 - 1
backend/config.py

@@ -549,7 +549,10 @@ BRAVE_SEARCH_API_KEY = os.getenv("BRAVE_SEARCH_API_KEY", "")
 SERPSTACK_API_KEY = os.getenv("SERPSTACK_API_KEY", "")
 SERPSTACK_HTTPS = os.getenv("SERPSTACK_HTTPS", "True").lower() == "true"
 SERPER_API_KEY = os.getenv("SERPER_API_KEY", "")
-WEB_SEARCH_RESULT_COUNT = int(os.getenv("WEB_SEARCH_RESULT_COUNT", "10"))
+RAG_WEB_SEARCH_RESULT_COUNT = int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "10"))
+RAG_WEB_SEARCH_CONCURRENT_REQUESTS = int(
+    os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")
+)
 
 ####################################
 # Transcribe

+ 116 - 0
src/lib/apis/openai/index.ts

@@ -318,3 +318,119 @@ export const generateTitle = async (
 
 	return res?.choices[0]?.message?.content ?? 'New Chat';
 };
+
+export const generateSearchQuery = async (
+	token: string = '',
+	// template: string,
+	model: string,
+	prompt: string,
+	url: string = OPENAI_API_BASE_URL
+): Promise<string | undefined> => {
+	let error = null;
+
+	// TODO: Allow users to specify the prompt
+	// template = promptTemplate(template, prompt);
+
+	// Get the current date in the format "January 20, 2024"
+	const currentDate = new Intl.DateTimeFormat('en-US', {
+		year: 'numeric',
+		month: 'long',
+		day: '2-digit'
+	}).format(new Date());
+	const yesterdayDate = new Intl.DateTimeFormat('en-US', {
+		year: 'numeric',
+		month: 'long',
+		day: '2-digit'
+	}).format(new Date());
+
+	// console.log(template);
+
+	const res = await fetch(`${url}/chat/completions`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			model: model,
+			// Few shot prompting
+			messages: [
+				{
+					role: 'assistant',
+					content: `You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is ${currentDate}.`
+				},
+				{
+					role: 'user',
+					content: `Previous Questions:
+- Who is the president of France?
+
+Current Question: What about Mexico?`
+				},
+				{
+					role: 'assistant',
+					content: 'President of Mexico'
+				},
+				{
+					role: 'user',
+					content: `Previous questions: 
+- When is the next formula 1 grand prix?
+
+Current Question: Where is it being hosted?`
+				},
+				{
+					role: 'assistant',
+					content: 'location of next formula 1 grand prix'
+				},
+				{
+					role: 'user',
+					content: 'Current Question: What type of printhead does the Epson F2270 DTG printer use?'
+				},
+				{
+					role: 'assistant',
+					content: 'Epson F2270 DTG printer printhead'
+				},
+				{
+					role: 'user',
+					content: 'What were the news yesterday?'
+				},
+				{
+					role: 'assistant',
+					content: `news ${yesterdayDate}`
+				},
+				{
+					role: 'user',
+					content: 'What is the current weather in Paris?'
+				},
+				{
+					role: 'assistant',
+					content: `weather in Paris ${currentDate}`
+				},
+				{
+					role: 'user',
+					content: `Current Question: ${prompt}`
+				}
+			],
+			stream: false,
+			// Restricting the max tokens to 30 to avoid long search queries
+			max_tokens: 30
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			}
+			return undefined;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res?.choices[0]?.message?.content.replace(/["']/g, '') ?? undefined;
+};

+ 41 - 0
src/lib/apis/rag/index.ts

@@ -507,3 +507,44 @@ export const updateRerankingConfig = async (token: string, payload: RerankingMod
 
 	return res;
 };
+
+export const runWebSearch = async (
+	token: string,
+	query: string,
+	collection_name?: string
+): Promise<SearchDocument | undefined> => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/websearch`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			query,
+			collection_name
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return undefined;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export interface SearchDocument {
+	status: boolean;
+	collection_name: string;
+	filenames: string[];
+}

+ 62 - 2
src/routes/(app)/+page.svelte

@@ -30,8 +30,8 @@
 		getTagsById,
 		updateChatById
 	} from '$lib/apis/chats';
-	import { queryCollection, queryDoc } from '$lib/apis/rag';
-	import { generateOpenAIChatCompletion, generateTitle } from '$lib/apis/openai';
+	import { queryCollection, queryDoc, runWebSearch } from '$lib/apis/rag';
+	import { generateOpenAIChatCompletion, generateSearchQuery, generateTitle } from '$lib/apis/openai';
 
 	import MessageInput from '$lib/components/chat/MessageInput.svelte';
 	import Messages from '$lib/components/chat/Messages.svelte';
@@ -55,6 +55,8 @@
 	let selectedModels = [''];
 	let atSelectedModel = '';
 
+	let useWebSearch = false;
+
 	let selectedModelfile = null;
 	$: selectedModelfile =
 		selectedModels.length === 1 &&
@@ -275,6 +277,39 @@
 						];
 					}
 
+					if (useWebSearch) {
+						// TODO: Toasts are temporary indicators for web search
+						toast.info($i18n.t('Generating search query'));
+						const searchQuery = await generateChatSearchQuery(prompt);
+						if (searchQuery) {
+							toast.info($i18n.t('Searching the web for \'{{searchQuery}}\'', { searchQuery }));
+							const searchDocUuid = uuidv4();
+							const searchDocument = await runWebSearch(localStorage.token, searchQuery, searchDocUuid);
+							if (searchDocument) {
+								const parentMessage = history.messages[parentId];
+								if (!parentMessage.files) {
+									parentMessage.files = [];
+								}
+								parentMessage.files.push({
+									collection_name: searchDocument.collection_name,
+									name: searchQuery,
+									type: 'doc',
+									upload_status: true,
+									error: ""
+								});
+								// Find message in messages and update it
+								const messageIndex = messages.findIndex((message) => message.id === parentId);
+								if (messageIndex !== -1) {
+									messages[messageIndex] = parentMessage;
+								}
+							} else {
+								toast.warning($i18n.t('No search results found'));
+							}
+						} else {
+							toast.warning($i18n.t('No search query generated'));
+						}
+					}
+
 					if (model?.external) {
 						await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
 					} else if (model) {
@@ -807,6 +842,30 @@
 		}
 	};
 
+	// TODO: Add support for adding all the user's messages as context, and not just the last message
+	const generateChatSearchQuery = async (userPrompt: string) => {
+		const model = $models.find((model) => model.id === selectedModels[0]);
+
+		// TODO: rename titleModel to taskModel - this is the model used for non-chat tasks (e.g. title generation, search query generation)
+		const titleModelId =
+			model?.external ?? false
+				? $settings?.title?.modelExternal ?? selectedModels[0]
+				: $settings?.title?.model ?? selectedModels[0];
+		const titleModel = $models.find((model) => model.id === titleModelId);
+
+		console.log(titleModel);
+		return await generateSearchQuery(
+			localStorage.token,
+			titleModelId,
+			userPrompt,
+			titleModel?.external ?? false
+				? titleModel?.source?.toLowerCase() === 'litellm'
+					? `${LITELLM_API_BASE_URL}/v1`
+					: `${OPENAI_API_BASE_URL}`
+				: `${OLLAMA_API_BASE_URL}/v1`
+		);
+	};
+
 	const setChatTitle = async (_chatId, _title) => {
 		if (_chatId === $chatId) {
 			title = _title;
@@ -906,6 +965,7 @@
 	bind:prompt
 	bind:autoScroll
 	bind:selectedModel={atSelectedModel}
+	bind:useWebSearch
 	{messages}
 	{submitPrompt}
 	{stopResponse}

+ 62 - 1
src/routes/(app)/c/[id]/+page.svelte

@@ -30,7 +30,7 @@
 		getTagsById,
 		updateChatById
 	} from '$lib/apis/chats';
-	import { generateOpenAIChatCompletion, generateTitle } from '$lib/apis/openai';
+	import { generateOpenAIChatCompletion, generateSearchQuery, generateTitle } from '$lib/apis/openai';
 
 	import MessageInput from '$lib/components/chat/MessageInput.svelte';
 	import Messages from '$lib/components/chat/Messages.svelte';
@@ -43,6 +43,7 @@
 		WEBUI_BASE_URL
 	} from '$lib/constants';
 	import { createOpenAITextStream } from '$lib/apis/streaming';
+	import { runWebSearch } from '$lib/apis/rag';
 
 	const i18n = getContext('i18n');
 
@@ -59,6 +60,8 @@
 	let selectedModels = [''];
 	let atSelectedModel = '';
 
+	let useWebSearch = false;
+
 	let selectedModelfile = null;
 
 	$: selectedModelfile =
@@ -287,6 +290,39 @@
 						];
 					}
 
+					if (useWebSearch) {
+						// TODO: Toasts are temporary indicators for web search
+						toast.info($i18n.t('Generating search query'));
+						const searchQuery = await generateChatSearchQuery(prompt);
+						if (searchQuery) {
+							toast.info($i18n.t('Searching the web for \'{{searchQuery}}\'', { searchQuery }));
+							const searchDocUuid = uuidv4();
+							const searchDocument = await runWebSearch(localStorage.token, searchQuery, searchDocUuid);
+							if (searchDocument) {
+								const parentMessage = history.messages[parentId];
+								if (!parentMessage.files) {
+									parentMessage.files = [];
+								}
+								parentMessage.files.push({
+									collection_name: searchDocument.collection_name,
+									name: searchQuery,
+									type: 'doc',
+									upload_status: true,
+									error: ""
+								});
+								// Find message in messages and update it
+								const messageIndex = messages.findIndex((message) => message.id === parentId);
+								if (messageIndex !== -1) {
+									messages[messageIndex] = parentMessage;
+								}
+							} else {
+								toast.warning($i18n.t('No search results found'));
+							}
+						} else {
+							toast.warning($i18n.t('No search query generated'));
+						}
+					}
+
 					if (model?.external) {
 						await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
 					} else if (model) {
@@ -819,6 +855,30 @@
 		}
 	};
 
+	// TODO: Add support for adding all the user's messages as context, and not just the last message
+	const generateChatSearchQuery = async (userPrompt: string) => {
+		const model = $models.find((model) => model.id === selectedModels[0]);
+
+		// TODO: rename titleModel to taskModel - this is the model used for non-chat tasks (e.g. title generation, search query generation)
+		const titleModelId =
+			model?.external ?? false
+				? $settings?.title?.modelExternal ?? selectedModels[0]
+				: $settings?.title?.model ?? selectedModels[0];
+		const titleModel = $models.find((model) => model.id === titleModelId);
+
+		console.log(titleModel);
+		return await generateSearchQuery(
+			localStorage.token,
+			titleModelId,
+			userPrompt,
+			titleModel?.external ?? false
+				? titleModel?.source?.toLowerCase() === 'litellm'
+					? `${LITELLM_API_BASE_URL}/v1`
+					: `${OPENAI_API_BASE_URL}`
+				: `${OLLAMA_API_BASE_URL}/v1`
+		);
+	};
+
 	const setChatTitle = async (_chatId, _title) => {
 		if (_chatId === $chatId) {
 			title = _title;
@@ -929,6 +989,7 @@
 		bind:prompt
 		bind:autoScroll
 		bind:selectedModel={atSelectedModel}
+		bind:useWebSearch
 		suggestionPrompts={selectedModelfile?.suggestionPrompts ?? $config.default_prompt_suggestions}
 		{messages}
 		{submitPrompt}