Timothy J. Baek 9 months ago
parent
commit
00eb022450
1 changed files with 53 additions and 62 deletions
  1. 53 62
      backend/open_webui/apps/retrieval/main.py

+ 53 - 62
backend/open_webui/apps/retrieval/main.py

@@ -628,39 +628,25 @@ async def update_query_settings(
 ####################################
 
 
-def store_data_in_vector_db(
-    data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False
+def save_docs_to_vector_db(
+    docs,
+    collection_name,
+    metadata: Optional[dict] = None,
+    overwrite: bool = False,
+    split: bool = True,
 ) -> bool:
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=app.state.config.CHUNK_SIZE,
-        chunk_overlap=app.state.config.CHUNK_OVERLAP,
-        add_start_index=True,
-    )
-    docs = text_splitter.split_documents(data)
-
-    if len(docs) > 0:
-        log.info(f"store_data_in_vector_db {docs}")
-        return store_docs_in_vector_db(docs, collection_name, metadata, overwrite)
-    else:
-        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
-
-
-def store_text_in_vector_db(
-    text, metadata, collection_name, overwrite: bool = False
-) -> bool:
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=app.state.config.CHUNK_SIZE,
-        chunk_overlap=app.state.config.CHUNK_OVERLAP,
-        add_start_index=True,
-    )
-    docs = text_splitter.create_documents([text], metadatas=[metadata])
-    return store_docs_in_vector_db(docs, collection_name, overwrite=overwrite)
+    log.info(f"save_docs_to_vector_db {docs} {collection_name}")
 
+    if split:
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=app.state.config.CHUNK_SIZE,
+            chunk_overlap=app.state.config.CHUNK_OVERLAP,
+            add_start_index=True,
+        )
+        docs = text_splitter.split_documents(docs)
 
-def store_docs_in_vector_db(
-    docs, collection_name, metadata: Optional[dict] = None, overwrite: bool = False
-) -> bool:
-    log.info(f"store_docs_in_vector_db {docs} {collection_name}")
+    if len(docs) == 0:
+        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
 
     texts = [doc.page_content for doc in docs]
     metadatas = [{**doc.metadata, **(metadata if metadata else {})} for doc in docs]
@@ -728,21 +714,24 @@ def process_file(
         file = Files.get_file_by_id(form_data.file_id)
         file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")
 
+        collection_name = form_data.collection_name
+        if collection_name is None:
+            with open(file_path, "rb") as f:
+                collection_name = calculate_sha256(f)[:63]
+
         loader = Loader(
             engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
             TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
             PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
         )
-        data = loader.load(file.filename, file.meta.get("content_type"), file_path)
+        docs = loader.load(file.filename, file.meta.get("content_type"), file_path)
 
-        collection_name = form_data.collection_name
-        if collection_name is None:
-            with open(file_path, "rb") as f:
-                collection_name = calculate_sha256(f)[:63]
+        raw_content = " ".join([doc.page_content for doc in docs])
+        print(raw_content)
 
         try:
-            result = store_data_in_vector_db(
-                data,
+            result = save_docs_to_vector_db(
+                docs,
                 collection_name,
                 {
                     "file_id": form_data.file_id,
@@ -790,11 +779,13 @@ def process_text(
     if collection_name is None:
         collection_name = calculate_sha256_string(form_data.content)
 
-    result = store_text_in_vector_db(
-        form_data.content,
-        metadata={"name": form_data.name, "created_by": user.id},
-        collection_name=collection_name,
-    )
+    docs = [
+        Document(
+            page_content=form_data.content,
+            metadata={"name": form_data.name, "created_by": user.id},
+        )
+    ]
+    result = save_docs_to_vector_db(docs, collection_name)
 
     if result:
         return {"status": True, "collection_name": collection_name}
@@ -822,10 +813,10 @@ def process_docs_dir(user=Depends(get_admin_user)):
                     TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
                     PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
                 )
-                data = loader.load(filename, file_content_type[0], str(path))
+                docs = loader.load(filename, file_content_type[0], str(path))
 
                 try:
-                    result = store_data_in_vector_db(data, collection_name)
+                    result = save_docs_to_vector_db(docs, collection_name)
 
                     if result:
                         sanitized_filename = sanitize_filename(filename)
@@ -870,19 +861,19 @@ def process_docs_dir(user=Depends(get_admin_user)):
 @app.post("/process/youtube")
 def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
     try:
+        collection_name = form_data.collection_name
+        if not collection_name:
+            collection_name = calculate_sha256_string(form_data.url)[:63]
+
         loader = YoutubeLoader.from_youtube_url(
             form_data.url,
             add_video_info=True,
             language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
             translation=app.state.YOUTUBE_LOADER_TRANSLATION,
         )
-        data = loader.load()
-
-        collection_name = form_data.collection_name
-        if not collection_name:
-            collection_name = calculate_sha256_string(form_data.url)[:63]
+        docs = loader.load()
 
-        store_data_in_vector_db(data, collection_name, overwrite=True)
+        save_docs_to_vector_db(docs, collection_name, overwrite=True)
 
         return {
             "status": True,
@@ -900,18 +891,17 @@ def process_youtube_video(form_data: ProcessUrlForm, user=Depends(get_verified_u
 @app.post("/process/web")
 def process_web(form_data: ProcessUrlForm, user=Depends(get_verified_user)):
     try:
+        collection_name = form_data.collection_name
+        if not collection_name:
+            collection_name = calculate_sha256_string(form_data.url)[:63]
+
         loader = get_web_loader(
             form_data.url,
             verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
             requests_per_second=app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
         )
-        data = loader.load()
-
-        collection_name = form_data.collection_name
-        if not collection_name:
-            collection_name = calculate_sha256_string(form_data.url)[:63]
-
-        store_data_in_vector_db(data, collection_name, overwrite=True)
+        docs = loader.load()
+        save_docs_to_vector_db(docs, collection_name, overwrite=True)
 
         return {
             "status": True,
@@ -1060,15 +1050,16 @@ def process_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
         )
 
     try:
-        urls = [result.link for result in web_results]
-        loader = get_web_loader(urls)
-        data = loader.load()
-
         collection_name = form_data.collection_name
         if collection_name == "":
             collection_name = calculate_sha256_string(form_data.query)[:63]
 
-        store_data_in_vector_db(data, collection_name, overwrite=True)
+        urls = [result.link for result in web_results]
+
+        loader = get_web_loader(urls)
+        docs = loader.load()
+        save_docs_to_vector_db(docs, collection_name, overwrite=True)
+
         return {
             "status": True,
             "collection_name": collection_name,