
Merge pull request #7881 from gabriel-ecegi/dev

feat: Batch Processing for Large-Scale Document Import
Timothy Jaeryang Baek, 7 months ago
commit 9abae36264
2 files changed, 181 insertions(+), 6 deletions(-)
  1. backend/open_webui/routers/knowledge.py (+85, -4)
  2. backend/open_webui/routers/retrieval.py (+96, -2)

+85 -4
backend/open_webui/routers/knowledge.py

@@ -1,5 +1,4 @@
-import json
-from typing import Optional, Union
+from typing import List, Optional
 from pydantic import BaseModel
 from fastapi import APIRouter, Depends, HTTPException, status, Request
 import logging
@@ -12,11 +11,11 @@ from open_webui.models.knowledge import (
 )
 from open_webui.models.files import Files, FileModel
 from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
-from open_webui.routers.retrieval import process_file, ProcessFileForm
+from open_webui.routers.retrieval import process_file, ProcessFileForm, process_files_batch, BatchProcessFilesForm


 from open_webui.constants import ERROR_MESSAGES
-from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_verified_user
 from open_webui.utils.access_control import has_access, has_permission


@@ -514,3 +513,85 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
     knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})

     return knowledge
+
+
+############################
+# AddFilesToKnowledge
+############################
+
+@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
+def add_files_to_knowledge_batch(
+    id: str,
+    form_data: list[KnowledgeFileIdForm],
+    user=Depends(get_verified_user),
+):
+    """
+    Add multiple files to a knowledge base
+    """
+    knowledge = Knowledges.get_knowledge_by_id(id=id)
+    if not knowledge:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if knowledge.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    # Get files content
+    print(f"files/batch/add - {len(form_data)} files")
+    files: List[FileModel] = []
+    for form in form_data:
+        file = Files.get_file_by_id(form.file_id)
+        if not file:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"File {form.file_id} not found",
+            )
+        files.append(file)
+
+    # Process files
+    try:
+        result = process_files_batch(BatchProcessFilesForm(
+            files=files,
+            collection_name=id
+        ))
+    except Exception as e:
+        log.error(f"add_files_to_knowledge_batch: Exception occurred: {e}", exc_info=True)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=str(e)
+        )
+    
+    # Add successful files to knowledge base
+    data = knowledge.data or {}
+    existing_file_ids = data.get("file_ids", [])
+    
+    # Only add files that were successfully processed
+    successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
+    for file_id in successful_file_ids:
+        if file_id not in existing_file_ids:
+            existing_file_ids.append(file_id)
+    
+    data["file_ids"] = existing_file_ids
+    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
+
+    # If there were any errors, include them in the response
+    if result.errors:
+        error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
+        return KnowledgeFilesResponse(
+            **knowledge.model_dump(),
+            files=Files.get_files_by_ids(existing_file_ids),
+            warnings={
+                "message": "Some files failed to process",
+                "errors": error_details
+            }
+        )
+
+    return KnowledgeFilesResponse(
+        **knowledge.model_dump(),
+        files=Files.get_files_by_ids(existing_file_ids)
+    )

+96 -2
backend/open_webui/routers/retrieval.py

@@ -7,7 +7,7 @@ import shutil
 import uuid
 from datetime import datetime
 from pathlib import Path
-from typing import Iterator, Optional, Sequence, Union
+from typing import Iterator, List, Optional, Sequence, Union

 from fastapi import (
     Depends,
@@ -28,7 +28,7 @@ import tiktoken
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
 from langchain_core.documents import Document

-from open_webui.models.files import Files
+from open_webui.models.files import FileModel, Files
 from open_webui.models.knowledge import Knowledges
 from open_webui.storage.provider import Storage

@@ -1428,3 +1428,97 @@ if ENV == "dev":
     @router.get("/ef/{text}")
     async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
         return {"result": request.app.state.EMBEDDING_FUNCTION(text)}
+
+class BatchProcessFilesForm(BaseModel):
+    files: List[FileModel]
+    collection_name: str
+
+class BatchProcessFilesResult(BaseModel):
+    file_id: str
+    status: str
+    error: Optional[str] = None
+
+class BatchProcessFilesResponse(BaseModel):
+    results: List[BatchProcessFilesResult]
+    errors: List[BatchProcessFilesResult]
+
+@router.post("/process/files/batch")
+def process_files_batch(
+    form_data: BatchProcessFilesForm,
+    user=Depends(get_verified_user),
+) -> BatchProcessFilesResponse:
+    """
+    Process a batch of files and save them to the vector database.
+    """
+    results: List[BatchProcessFilesResult] = []
+    errors: List[BatchProcessFilesResult] = []
+    collection_name = form_data.collection_name
+
+    # Prepare all documents first
+    all_docs: List[Document] = []
+    for file in form_data.files:
+        try:
+            text_content = file.data.get("content", "")
+            
+            docs: List[Document] = [
+                Document(
+                    page_content=text_content.replace("<br/>", "\n"),
+                    metadata={
+                        **file.meta,
+                        "name": file.filename,
+                        "created_by": file.user_id,
+                        "file_id": file.id,
+                        "source": file.filename,
+                    },
+                )
+            ]
+
+            hash = calculate_sha256_string(text_content)
+            Files.update_file_hash_by_id(file.id, hash)
+            Files.update_file_data_by_id(file.id, {"content": text_content})
+            
+            all_docs.extend(docs)
+            results.append(BatchProcessFilesResult(
+                file_id=file.id,
+                status="prepared"
+            ))
+
+        except Exception as e:
+            log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
+            errors.append(BatchProcessFilesResult(
+                file_id=file.id,
+                status="failed",
+                error=str(e)
+            ))
+
+    # Save all documents in one batch
+    if all_docs:
+        try:
+            save_docs_to_vector_db(
+                docs=all_docs,
+                collection_name=collection_name,
+                add=True
+            )
+            
+            # Update all files with collection name
+            for result in results:
+                Files.update_file_metadata_by_id(
+                    result.file_id,
+                    {"collection_name": collection_name}
+                )
+                result.status = "completed"
+
+        except Exception as e:
+            log.error(f"process_files_batch: Error saving documents to vector DB: {str(e)}")
+            for result in results:
+                result.status = "failed"
+                errors.append(BatchProcessFilesResult(
+                    file_id=result.file_id,
+                    status="failed",
+                    error=str(e)
+                ))
+
+    return BatchProcessFilesResponse(
+        results=results,
+        errors=errors
+    )
+
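For reference, here is a small self-contained sketch of how a caller can interpret the BatchProcessFilesResponse returned by process_files_batch, mirroring the aggregation done in add_files_to_knowledge_batch above. The model definitions are copied from this diff; the sample data is hypothetical.

# Standalone sketch; the response contents below are made up for illustration.
from typing import List, Optional
from pydantic import BaseModel

class BatchProcessFilesResult(BaseModel):
    file_id: str
    status: str
    error: Optional[str] = None

class BatchProcessFilesResponse(BaseModel):
    results: List[BatchProcessFilesResult]
    errors: List[BatchProcessFilesResult]

response = BatchProcessFilesResponse(
    results=[
        BatchProcessFilesResult(file_id="file-a", status="completed"),
        BatchProcessFilesResult(file_id="file-b", status="failed", error="empty content"),
    ],
    errors=[
        BatchProcessFilesResult(file_id="file-b", status="failed", error="empty content"),
    ],
)

# Only files that reached "completed" get attached to the knowledge base.
completed = [r.file_id for r in response.results if r.status == "completed"]
# Failed files are surfaced to the client as warning strings.
warnings = [f"{e.file_id}: {e.error}" for e in response.errors]

print(completed)  # ['file-a']
print(warnings)   # ['file-b: empty content']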