|
@@ -10,6 +10,7 @@ from fastapi import (
|
|
)
|
|
)
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
import os, shutil
|
|
import os, shutil
|
|
|
|
+from typing import List
|
|
|
|
|
|
# from chromadb.utils import embedding_functions
|
|
# from chromadb.utils import embedding_functions
|
|
|
|
|
|
@@ -96,19 +97,22 @@ async def get_status():
|
|
return {"status": True}
|
|
return {"status": True}
|
|
|
|
|
|
|
|
|
|
-@app.get("/query/{collection_name}")
|
|
|
|
|
|
+class QueryCollectionForm(BaseModel):
|
|
|
|
+ collection_name: str
|
|
|
|
+ query: str
|
|
|
|
+ k: Optional[int] = 4
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.post("/query/collection")
|
|
def query_collection(
|
|
def query_collection(
|
|
- collection_name: str,
|
|
|
|
- query: str,
|
|
|
|
- k: Optional[int] = 4,
|
|
|
|
|
|
+ form_data: QueryCollectionForm,
|
|
user=Depends(get_current_user),
|
|
user=Depends(get_current_user),
|
|
):
|
|
):
|
|
try:
|
|
try:
|
|
collection = CHROMA_CLIENT.get_collection(
|
|
collection = CHROMA_CLIENT.get_collection(
|
|
- name=collection_name,
|
|
|
|
|
|
+ name=form_data.collection_name,
|
|
)
|
|
)
|
|
- result = collection.query(query_texts=[query], n_results=k)
|
|
|
|
-
|
|
|
|
|
|
+ result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
|
|
return result
|
|
return result
|
|
except Exception as e:
|
|
except Exception as e:
|
|
print(e)
|
|
print(e)
|
|
@@ -118,6 +122,34 @@ def query_collection(
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
+class QueryCollectionsForm(BaseModel):
|
|
|
|
+ collection_names: List[str]
|
|
|
|
+ query: str
|
|
|
|
+ k: Optional[int] = 4
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.post("/query/collections")
|
|
|
|
+def query_collections(
|
|
|
|
+ form_data: QueryCollectionsForm,
|
|
|
|
+ user=Depends(get_current_user),
|
|
|
|
+):
|
|
|
|
+ results = []
|
|
|
|
+
|
|
|
|
+ for collection_name in form_data.collection_names:
|
|
|
|
+ try:
|
|
|
|
+ collection = CHROMA_CLIENT.get_collection(
|
|
|
|
+ name=collection_name,
|
|
|
|
+ )
|
|
|
|
+ result = collection.query(
|
|
|
|
+ query_texts=[form_data.query], n_results=form_data.k
|
|
|
|
+ )
|
|
|
|
+ results.append(result)
|
|
|
|
+ except:
|
|
|
|
+ pass
|
|
|
|
+
|
|
|
|
+ return results
|
|
|
|
+
|
|
|
|
+
|
|
@app.post("/web")
|
|
@app.post("/web")
|
|
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
|
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
|
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
|
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|