
refac: batch file processing

Co-Authored-By: Sihyeon Jang <24850223+sihyeonn@users.noreply.github.com>
Timothy Jaeryang Baek, 3 months ago
Commit a65cc196a5
2 changed files with 57 additions and 20 deletions
  1. backend/open_webui/models/files.py (+30 -0)
  2. backend/open_webui/routers/retrieval.py (+27 -20)

backend/open_webui/models/files.py (+30 -0)

@@ -98,6 +98,13 @@ class FileForm(BaseModel):
     access_control: Optional[dict] = None
 
 
+class FileUpdateForm(BaseModel):
+    id: str
+    hash: Optional[str] = None
+    data: Optional[dict] = None
+    meta: Optional[dict] = None
+
+
 class FilesTable:
     def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
         with get_db() as db:
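
The FileUpdateForm added above is a partial-update payload: only id is required, and any of hash, data, or meta left as None is simply skipped by the new update method further below. A minimal sketch of constructing one (the id and content are made up):

    from open_webui.models.files import FileUpdateForm

    # Only the fields being changed need to be passed; hash and meta stay None
    # and will be ignored by Files.update_file_by_id.
    form = FileUpdateForm(
        id="file-1234",  # hypothetical file id
        data={"content": "re-extracted text"},
    )
    assert form.hash is None and form.meta is None
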
@@ -204,6 +211,29 @@ class FilesTable:
                 for file in db.query(File).filter_by(user_id=user_id).all()
             ]
 
+    def update_file_by_id(
+        self, id: str, form_data: FileUpdateForm
+    ) -> Optional[FileModel]:
+        with get_db() as db:
+            try:
+                file = db.query(File).filter_by(id=id).first()
+
+                if form_data.hash is not None:
+                    file.hash = form_data.hash
+
+                if form_data.data is not None:
+                    file.data = {**(file.data if file.data else {}), **form_data.data}
+
+                if form_data.meta is not None:
+                    file.meta = {**(file.meta if file.meta else {}), **form_data.meta}
+
+                file.updated_at = int(time.time())
+                db.commit()
+                return FileModel.model_validate(file)
+            except Exception as e:
+                log.exception(f"Error updating file completely by id: {e}")
+                return None
+
     def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
         with get_db() as db:
             try:
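
The new Files.update_file_by_id folds the separate hash/data/meta writes into a single call: None fields are skipped, dict fields are shallow-merged into the stored row, updated_at is refreshed, and any exception is logged and converted into a None return. A usage sketch with made-up values:

    from open_webui.models.files import FileUpdateForm, Files

    updated = Files.update_file_by_id(
        id="file-1234",  # hypothetical file id
        form_data=FileUpdateForm(
            id="file-1234",
            hash="<sha256 of the content>",             # placeholder value
            data={"content": "extracted plain text"},   # merged into existing data
        ),
    )
    if updated is None:
        # the method catches exceptions internally and returns None on failure
        pass
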

backend/open_webui/routers/retrieval.py (+27 -20)

@@ -32,7 +32,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSpl
 from langchain_text_splitters import MarkdownHeaderTextSplitter
 from langchain_core.documents import Document
 
-from open_webui.models.files import FileModel, Files
+from open_webui.models.files import FileModel, FileUpdateForm, Files
 from open_webui.models.knowledge import Knowledges
 from open_webui.storage.provider import Storage
 
@@ -2452,16 +2452,19 @@ def process_files_batch(
     """
     Process a batch of files and save them to the vector database.
     """
-    results: List[BatchProcessFilesResult] = []
-    errors: List[BatchProcessFilesResult] = []
+
     collection_name = form_data.collection_name
 
+    file_results: List[BatchProcessFilesResult] = []
+    file_errors: List[BatchProcessFilesResult] = []
+    file_updates: List[FileUpdateForm] = []
+
     # Prepare all documents first
     all_docs: List[Document] = []
+
     for file in form_data.files:
         try:
             text_content = file.data.get("content", "")
-
             docs: List[Document] = [
                 Document(
                     page_content=text_content.replace("<br/>", "\n"),
@@ -2475,16 +2478,22 @@ def process_files_batch(
                 )
             ]
 
-            hash = calculate_sha256_string(text_content)
-            Files.update_file_hash_by_id(file.id, hash)
-            Files.update_file_data_by_id(file.id, {"content": text_content})
-
             all_docs.extend(docs)
-            results.append(BatchProcessFilesResult(file_id=file.id, status="prepared"))
+
+            file_updates.append(
+                FileUpdateForm(
+                    id=file.id,
+                    hash=calculate_sha256_string(text_content),
+                    data={"content": text_content},
+                )
+            )
+            file_results.append(
+                BatchProcessFilesResult(file_id=file.id, status="prepared")
+            )
 
         except Exception as e:
             log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
-            errors.append(
+            file_errors.append(
                 BatchProcessFilesResult(file_id=file.id, status="failed", error=str(e))
             )
 
@@ -2500,20 +2509,18 @@ def process_files_batch(
             )
 
             # Update all files with collection name
-            for result in results:
-                Files.update_file_metadata_by_id(
-                    result.file_id, {"collection_name": collection_name}
-                )
-                result.status = "completed"
+            for file_update, file_result in zip(file_updates, file_results):
+                Files.update_file_by_id(id=file_result.file_id, form_data=file_update)
+                file_result.status = "completed"
 
         except Exception as e:
             log.error(
                 f"process_files_batch: Error saving documents to vector DB: {str(e)}"
             )
-            for result in results:
-                result.status = "failed"
-                errors.append(
-                    BatchProcessFilesResult(file_id=result.file_id, error=str(e))
+            for file_result in file_results:
+                file_result.status = "failed"
+                file_errors.append(
+                    BatchProcessFilesResult(file_id=file_result.file_id, error=str(e))
                 )
 
-    return BatchProcessFilesResponse(results=results, errors=errors)
+    return BatchProcessFilesResponse(results=file_results, errors=file_errors)
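
The router refactor turns the per-file writes into a deferred batch: while documents are prepared, each file only contributes a FileUpdateForm (hash plus content) to file_updates, and the rows are written in one pass after save_docs_to_vector_db succeeds, with zip pairing each update to its result (safe because both lists are appended in lockstep inside the same try block). A stripped-down sketch of that two-phase pattern; the dataclasses, ids, and save_batch callable are illustrative stand-ins, not router code:

    import hashlib
    from dataclasses import dataclass
    from typing import Callable, Dict, List, Optional


    @dataclass
    class PendingUpdate:  # stand-in for FileUpdateForm
        id: str
        hash: str
        data: Dict


    @dataclass
    class FileResult:  # stand-in for BatchProcessFilesResult
        file_id: str
        status: str = "prepared"
        error: Optional[str] = None


    def batch_process(files: Dict[str, str], save_batch: Callable[[List[str]], None]):
        updates: List[PendingUpdate] = []
        results: List[FileResult] = []
        errors: List[FileResult] = []

        # Phase 1: prepare everything, defer the database writes.
        for file_id, content in files.items():
            try:
                updates.append(
                    PendingUpdate(
                        id=file_id,
                        hash=hashlib.sha256(content.encode()).hexdigest(),
                        data={"content": content},
                    )
                )
                results.append(FileResult(file_id=file_id))
            except Exception as e:
                errors.append(FileResult(file_id=file_id, status="failed", error=str(e)))

        # Phase 2: one batch step (the vector DB insert in the router), then
        # flush the deferred writes; zip is valid because updates and results
        # grew together in phase 1.
        try:
            save_batch([u.data["content"] for u in updates])
            for update, result in zip(updates, results):
                # here the router calls Files.update_file_by_id(id=..., form_data=update)
                result.status = "completed"
        except Exception as e:
            for result in results:
                result.status = "failed"
                errors.append(FileResult(file_id=result.file_id, error=str(e)))

        return results, errors

One design consequence worth noting: the flush only runs when the vector DB insert succeeds, so a failure there leaves the files table untouched, which is closer to all-or-nothing behavior than the previous write-as-you-go version.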