Timothy J. Baek 1 year ago
parent
commit
50f7b20ac2

+ 39 - 7
backend/apps/rag/main.py

@@ -10,6 +10,7 @@ from fastapi import (
 )
 )
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 import os, shutil
 import os, shutil
+from typing import List
 
 
 # from chromadb.utils import embedding_functions
 # from chromadb.utils import embedding_functions
 
 
@@ -96,19 +97,22 @@ async def get_status():
     return {"status": True}
     return {"status": True}
 
 
 
 
-@app.get("/query/{collection_name}")
+class QueryCollectionForm(BaseModel):
+    collection_name: str
+    query: str
+    k: Optional[int] = 4
+
+
+@app.post("/query/collection")
 def query_collection(
 def query_collection(
-    collection_name: str,
-    query: str,
-    k: Optional[int] = 4,
+    form_data: QueryCollectionForm,
     user=Depends(get_current_user),
     user=Depends(get_current_user),
 ):
 ):
     try:
     try:
         collection = CHROMA_CLIENT.get_collection(
         collection = CHROMA_CLIENT.get_collection(
-            name=collection_name,
+            name=form_data.collection_name,
         )
         )
-        result = collection.query(query_texts=[query], n_results=k)
-
+        result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
         return result
         return result
     except Exception as e:
     except Exception as e:
         print(e)
         print(e)
@@ -118,6 +122,34 @@ def query_collection(
         )
         )
 
 
 
 
+class QueryCollectionsForm(BaseModel):
+    collection_names: List[str]
+    query: str
+    k: Optional[int] = 4
+
+
+@app.post("/query/collections")
+def query_collections(
+    form_data: QueryCollectionsForm,
+    user=Depends(get_current_user),
+):
+    results = []
+
+    for collection_name in form_data.collection_names:
+        try:
+            collection = CHROMA_CLIENT.get_collection(
+                name=collection_name,
+            )
+            result = collection.query(
+                query_texts=[form_data.query], n_results=form_data.k
+            )
+            results.append(result)
+        except:
+            pass
+
+    return results
+
+
 @app.post("/web")
 @app.post("/web")
 def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
 def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

+ 14 - 17
src/lib/apis/rag/index.ts

@@ -66,28 +66,25 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string
 
 
 export const queryVectorDB = async (
 export const queryVectorDB = async (
 	token: string,
 	token: string,
-	collection_name: string,
+	collection_names: string[],
 	query: string,
 	query: string,
 	k: number
 	k: number
 ) => {
 ) => {
 	let error = null;
 	let error = null;
-	const searchParams = new URLSearchParams();
 
 
-	searchParams.set('query', query);
-	if (k) {
-		searchParams.set('k', k.toString());
-	}
-
-	const res = await fetch(
-		`${RAG_API_BASE_URL}/query/${collection_name}/?${searchParams.toString()}`,
-		{
-			method: 'GET',
-			headers: {
-				Accept: 'application/json',
-				authorization: `Bearer ${token}`
-			}
-		}
-	)
+	const res = await fetch(`${RAG_API_BASE_URL}/query/collections`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			collection_names: collection_names,
+			query: query,
+			k: k
+		})
+	})
 		.then(async (res) => {
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
 			if (!res.ok) throw await res.json();
 			return res.json();
 			return res.json();

+ 19 - 17
src/routes/(app)/+page.svelte

@@ -232,26 +232,28 @@
 			processing = 'Reading';
 			processing = 'Reading';
 			const query = history.messages[parentId].content;
 			const query = history.messages[parentId].content;
 
 
-			let relevantContexts = await Promise.all(
-				docs.map(async (doc) => {
-					return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch(
-						(error) => {
-							console.log(error);
-							return null;
-						}
-					);
-				})
-			);
-			relevantContexts = relevantContexts.filter((context) => context);
+			let relevantContexts = await queryVectorDB(
+				localStorage.token,
+				docs.map((d) => d.collection_name),
+				query,
+				4
+			).catch((error) => {
+				console.log(error);
+				return null;
+			});
+
+			if (relevantContexts) {
+				relevantContexts = relevantContexts.filter((context) => context);
 
 
-			const contextString = relevantContexts.reduce((a, context, i, arr) => {
-				return `${a}${context.documents.join(' ')}\n`;
-			}, '');
+				const contextString = relevantContexts.reduce((a, context, i, arr) => {
+					return `${a}${context.documents.join(' ')}\n`;
+				}, '');
 
 
-			console.log(contextString);
+				console.log(contextString);
 
 
-			history.messages[parentId].raContent = RAGTemplate(contextString, query);
-			history.messages[parentId].contexts = relevantContexts;
+				history.messages[parentId].raContent = RAGTemplate(contextString, query);
+				history.messages[parentId].contexts = relevantContexts;
+			}
 			await tick();
 			await tick();
 			processing = '';
 			processing = '';
 		}
 		}

+ 19 - 17
src/routes/(app)/c/[id]/+page.svelte

@@ -246,26 +246,28 @@
 			processing = 'Reading';
 			processing = 'Reading';
 			const query = history.messages[parentId].content;
 			const query = history.messages[parentId].content;
 
 
-			let relevantContexts = await Promise.all(
-				docs.map(async (doc) => {
-					return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch(
-						(error) => {
-							console.log(error);
-							return null;
-						}
-					);
-				})
-			);
-			relevantContexts = relevantContexts.filter((context) => context);
+			let relevantContexts = await queryVectorDB(
+				localStorage.token,
+				docs.map((d) => d.collection_name),
+				query,
+				4
+			).catch((error) => {
+				console.log(error);
+				return null;
+			});
+
+			if (relevantContexts) {
+				relevantContexts = relevantContexts.filter((context) => context);
 
 
-			const contextString = relevantContexts.reduce((a, context, i, arr) => {
-				return `${a}${context.documents.join(' ')}\n`;
-			}, '');
+				const contextString = relevantContexts.reduce((a, context, i, arr) => {
+					return `${a}${context.documents.join(' ')}\n`;
+				}, '');
 
 
-			console.log(contextString);
+				console.log(contextString);
 
 
-			history.messages[parentId].raContent = RAGTemplate(contextString, query);
-			history.messages[parentId].contexts = relevantContexts;
+				history.messages[parentId].raContent = RAGTemplate(contextString, query);
+				history.messages[parentId].contexts = relevantContexts;
+			}
 			await tick();
 			await tick();
 			processing = '';
 			processing = '';
 		}
 		}