@@ -9,6 +9,8 @@ from fastapi import (
 )
 from fastapi.middleware.cors import CORSMiddleware
 import os, shutil
+
+from pathlib import Path
 from typing import List
 
 from chromadb.utils import embedding_functions
@@ -28,23 +30,45 @@ from langchain_community.document_loaders import (
 )
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
-
-
 from pydantic import BaseModel
 from typing import Optional
-
+import mimetypes
 import uuid
 
+from apps.web.models.documents import (
+    Documents,
+    DocumentForm,
+    DocumentResponse,
+)
 
-from utils.misc import calculate_sha256, calculate_sha256_string
+from utils.misc import (
+    calculate_sha256,
+    calculate_sha256_string,
+    sanitize_filename,
+    extract_folders_after_data_docs,
+)
 from utils.utils import get_current_user, get_admin_user
-from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
+from config import (
+    UPLOAD_DIR,
+    DOCS_DIR,
+    SENTENCE_TRANSFORMER_EMBED_MODEL,
+    CHROMA_CLIENT,
+    CHUNK_SIZE,
+    CHUNK_OVERLAP,
+    RAG_TEMPLATE,
+)
+
 from constants import ERROR_MESSAGES
 
 sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=SENTENCE_TRANSFORMER_EMBED_MODEL)
 
 app = FastAPI()
 
+app.state.CHUNK_SIZE = CHUNK_SIZE
+app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
+app.state.RAG_TEMPLATE = RAG_TEMPLATE
+
+
 origins = ["*"]
 
 app.add_middleware(
@@ -66,7 +90,7 @@ class StoreWebForm(CollectionNameForm):
 
 def store_data_in_vector_db(data, collection_name) -> bool:
     text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
+        chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
     )
     docs = text_splitter.split_documents(data)
 
@@ -96,7 +120,60 @@ def store_data_in_vector_db(data, collection_name) -> bool:
 
 @app.get("/")
 async def get_status():
-    return {"status": True}
+    return {
+        "status": True,
+        "chunk_size": app.state.CHUNK_SIZE,
+        "chunk_overlap": app.state.CHUNK_OVERLAP,
+    }
+
+
+@app.get("/chunk")
+async def get_chunk_params(user=Depends(get_admin_user)):
+    return {
+        "status": True,
+        "chunk_size": app.state.CHUNK_SIZE,
+        "chunk_overlap": app.state.CHUNK_OVERLAP,
+    }
+
+
+class ChunkParamUpdateForm(BaseModel):
+    chunk_size: int
+    chunk_overlap: int
+
+
+@app.post("/chunk/update")
+async def update_chunk_params(
+    form_data: ChunkParamUpdateForm, user=Depends(get_admin_user)
+):
+    app.state.CHUNK_SIZE = form_data.chunk_size
+    app.state.CHUNK_OVERLAP = form_data.chunk_overlap
+
+    return {
+        "status": True,
+        "chunk_size": app.state.CHUNK_SIZE,
+        "chunk_overlap": app.state.CHUNK_OVERLAP,
+    }
+
+
+@app.get("/template")
+async def get_rag_template(user=Depends(get_current_user)):
+    return {
+        "status": True,
+        "template": app.state.RAG_TEMPLATE,
+    }
+
+
+class RAGTemplateForm(BaseModel):
+    template: str
+
+
+@app.post("/template/update")
+async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
+    # TODO: check template requirements
+    app.state.RAG_TEMPLATE = (
+        form_data.template if form_data.template != "" else RAG_TEMPLATE
+    )
+    return {"status": True, "template": app.state.RAG_TEMPLATE}
 
 
 class QueryDocForm(BaseModel):
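
The endpoints added above make chunking parameters and the RAG prompt template configurable at runtime via `app.state`, instead of being fixed at import time from `config`. A minimal client sketch — the base URL and bearer token below are placeholders for illustration, not part of this diff:

```python
import requests  # assumes the `requests` package is installed

BASE = "http://localhost:8080/rag"                   # assumed mount point; adjust to your deployment
HEADERS = {"Authorization": "Bearer <admin-token>"}  # placeholder admin JWT

# Read the current chunking parameters (admin only).
print(requests.get(f"{BASE}/chunk", headers=HEADERS).json())

# Update chunk size/overlap; the next upload is split with the new values.
requests.post(
    f"{BASE}/chunk/update",
    json={"chunk_size": 1500, "chunk_overlap": 100},
    headers=HEADERS,
)

# Replace the RAG prompt template; an empty string falls back to the default RAG_TEMPLATE.
requests.post(f"{BASE}/template/update", json={"template": ""}, headers=HEADERS)
```

Because `store_data_in_vector_db` now reads `app.state.CHUNK_SIZE` and `app.state.CHUNK_OVERLAP`, the updated values take effect without a restart.
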
@@ -239,8 +316,8 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
         )
 
 
-def get_loader(file, file_path):
-    file_ext = file.filename.split(".")[-1].lower()
+def get_loader(filename: str, file_content_type: str, file_path: str):
+    file_ext = filename.split(".")[-1].lower()
     known_type = True
 
     known_source_ext = [
@@ -298,20 +375,20 @@ def get_loader(file, file_path):
         loader = UnstructuredXMLLoader(file_path)
     elif file_ext == "md":
         loader = UnstructuredMarkdownLoader(file_path)
-    elif file.content_type == "application/epub+zip":
+    elif file_content_type == "application/epub+zip":
         loader = UnstructuredEPubLoader(file_path)
     elif (
-        file.content_type
+        file_content_type
         == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
         or file_ext in ["doc", "docx"]
     ):
         loader = Docx2txtLoader(file_path)
-    elif file.content_type in [
+    elif file_content_type in [
         "application/vnd.ms-excel",
         "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
     ] or file_ext in ["xls", "xlsx"]:
         loader = UnstructuredExcelLoader(file_path)
-    elif file_ext in known_source_ext or file.content_type.find("text/") >= 0:
+    elif file_ext in known_source_ext or file_content_type.find("text/") >= 0:
         loader = TextLoader(file_path)
     else:
         loader = TextLoader(file_path)
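
Changing `get_loader` to take a plain filename, content type, and path decouples it from FastAPI's `UploadFile`, so the same loader dispatch can be reused for files that never arrived through an upload request. A hypothetical direct call, with `mimetypes` supplying the content type an upload would normally carry:

```python
import mimetypes
from pathlib import Path

path = Path("data/docs/handbook.pdf")          # hypothetical file already on disk
content_type, _ = mimetypes.guess_type(path)   # e.g. "application/pdf"

# Same extension/MIME-type dispatch as for uploads, no UploadFile required.
loader, known_type = get_loader(path.name, content_type, str(path))
data = loader.load()
```
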
@@ -342,7 +419,7 @@ def store_doc(
         collection_name = calculate_sha256(f)[:63]
         f.close()
 
-        loader, known_type = get_loader(file, file_path)
+        loader, known_type = get_loader(file.filename, file.content_type, file_path)
         data = loader.load()
         result = store_data_in_vector_db(data, collection_name)
 
@@ -372,6 +449,63 @@ def store_doc(
         )
 
 
+@app.get("/scan")
+def scan_docs_dir(user=Depends(get_admin_user)):
+    try:
+        for path in Path(DOCS_DIR).rglob("./**/*"):
+            if path.is_file() and not path.name.startswith("."):
+                tags = extract_folders_after_data_docs(path)
+                filename = path.name
+                file_content_type = mimetypes.guess_type(path)
+
+                f = open(path, "rb")
+                collection_name = calculate_sha256(f)[:63]
+                f.close()
+
+                loader, known_type = get_loader(
+                    filename, file_content_type[0], str(path)
+                )
+                data = loader.load()
+
+                result = store_data_in_vector_db(data, collection_name)
+
+                if result:
+                    sanitized_filename = sanitize_filename(filename)
+                    doc = Documents.get_doc_by_name(sanitized_filename)
+
+                    if doc == None:
+                        doc = Documents.insert_new_doc(
+                            user.id,
+                            DocumentForm(
+                                **{
+                                    "name": sanitized_filename,
+                                    "title": filename,
+                                    "collection_name": collection_name,
+                                    "filename": filename,
+                                    "content": (
+                                        json.dumps(
+                                            {
+                                                "tags": list(
+                                                    map(
+                                                        lambda name: {"name": name},
+                                                        tags,
+                                                    )
+                                                )
+                                            }
+                                        )
+                                        if len(tags)
+                                        else "{}"
+                                    ),
+                                }
+                            ),
+                        )
+
+    except Exception as e:
+        print(e)
+
+    return True
+
+
 @app.get("/reset/db")
 def reset_vector_db(user=Depends(get_admin_user)):
     CHROMA_CLIENT.reset()
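
The new `/scan` endpoint walks `DOCS_DIR`, embeds each readable file into a Chroma collection keyed by its SHA-256 hash, and registers it in the `Documents` table, deriving tags from the folder structure via `extract_folders_after_data_docs`. A sketch of driving it over HTTP — base URL and token are again placeholders, not part of this diff:

```python
import requests  # assumes the `requests` package is installed

BASE = "http://localhost:8080/rag"                   # assumed mount point
HEADERS = {"Authorization": "Bearer <admin-token>"}  # placeholder admin JWT

# Index every file under DOCS_DIR without going through the upload flow.
requests.get(f"{BASE}/scan", headers=HEADERS)

# Admin-only reset: wipes the Chroma vector store entirely.
requests.get(f"{BASE}/reset/db", headers=HEADERS)
```
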