
refac: robust file upload failed handling

Timothy Jaeryang Baek 1 week ago
parent
commit
23a51f2d01
2 changed files with 199 additions and 171 deletions
  1. +11 -0
      backend/open_webui/models/files.py
  2. +188 -171
      backend/open_webui/routers/retrieval.py

+ 11 - 0
backend/open_webui/models/files.py

@@ -130,6 +130,17 @@ class FilesTable:
             except Exception:
                 return None
 
+    def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]:
+        with get_db() as db:
+            try:
+                file = db.query(File).filter_by(id=id, user_id=user_id).first()
+                if file:
+                    return FileModel.model_validate(file)
+                else:
+                    return None
+            except Exception:
+                return None
+
     def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
         with get_db() as db:
             try:
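
The new lookup lets a caller scope a file fetch to its owner instead of trusting the raw id. A minimal sketch of the intended call pattern, mirroring the branch process_file gains below (fetch_file_for_user is an illustrative wrapper, not part of the commit, and the import path is assumed from the repository layout):

from typing import Optional

from open_webui.models.files import FileModel, Files


def fetch_file_for_user(file_id: str, user) -> Optional[FileModel]:
    # Admins may load any file; everyone else only sees their own uploads.
    if user.role == "admin":
        return Files.get_file_by_id(file_id)
    # The non-admin query filters by user_id, so another user's file id
    # behaves exactly like a missing file and returns None.
    return Files.get_file_by_id_and_user_id(file_id, user.id)

process_file below applies exactly this branch inline and treats a None result as a 404.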

+ 188 - 171
backend/open_webui/routers/retrieval.py

@@ -1407,59 +1407,35 @@ def process_file(
     form_data: ProcessFileForm,
     user=Depends(get_verified_user),
 ):
-    try:
+    if user.role == "admin":
         file = Files.get_file_by_id(form_data.file_id)
+    else:
+        file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id)
 
-        collection_name = form_data.collection_name
+    if file:
+        try:
 
-        if collection_name is None:
-            collection_name = f"file-{file.id}"
+            collection_name = form_data.collection_name
 
-        if form_data.content:
-            # Update the content in the file
-            # Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline)
+            if collection_name is None:
+                collection_name = f"file-{file.id}"
 
-            try:
-                # /files/{file_id}/data/content/update
-                VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
-            except:
-                # Audio file upload pipeline
-                pass
+            if form_data.content:
+                # Update the content in the file
+                # Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline)
 
-            docs = [
-                Document(
-                    page_content=form_data.content.replace("<br/>", "\n"),
-                    metadata={
-                        **file.meta,
-                        "name": file.filename,
-                        "created_by": file.user_id,
-                        "file_id": file.id,
-                        "source": file.filename,
-                    },
-                )
-            ]
-
-            text_content = form_data.content
-        elif form_data.collection_name:
-            # Check if the file has already been processed and save the content
-            # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
-
-            result = VECTOR_DB_CLIENT.query(
-                collection_name=f"file-{file.id}", filter={"file_id": file.id}
-            )
-
-            if result is not None and len(result.ids[0]) > 0:
-                docs = [
-                    Document(
-                        page_content=result.documents[0][idx],
-                        metadata=result.metadatas[0][idx],
+                try:
+                    # /files/{file_id}/data/content/update
+                    VECTOR_DB_CLIENT.delete_collection(
+                        collection_name=f"file-{file.id}"
                     )
-                    for idx, id in enumerate(result.ids[0])
-                ]
-            else:
+                except:
+                    # Audio file upload pipeline
+                    pass
+
                 docs = [
                     Document(
-                        page_content=file.data.get("content", ""),
+                        page_content=form_data.content.replace("<br/>", "\n"),
                         metadata={
                             **file.meta,
                             "name": file.filename,
@@ -1470,148 +1446,189 @@ def process_file(
                     )
                 ]
 
-            text_content = file.data.get("content", "")
-        else:
-            # Process the file and save the content
-            # Usage: /files/
-            file_path = file.path
-            if file_path:
-                file_path = Storage.get_file(file_path)
-                loader = Loader(
-                    engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
-                    DATALAB_MARKER_API_KEY=request.app.state.config.DATALAB_MARKER_API_KEY,
-                    DATALAB_MARKER_API_BASE_URL=request.app.state.config.DATALAB_MARKER_API_BASE_URL,
-                    DATALAB_MARKER_ADDITIONAL_CONFIG=request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG,
-                    DATALAB_MARKER_SKIP_CACHE=request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
-                    DATALAB_MARKER_FORCE_OCR=request.app.state.config.DATALAB_MARKER_FORCE_OCR,
-                    DATALAB_MARKER_PAGINATE=request.app.state.config.DATALAB_MARKER_PAGINATE,
-                    DATALAB_MARKER_STRIP_EXISTING_OCR=request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
-                    DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
-                    DATALAB_MARKER_FORMAT_LINES=request.app.state.config.DATALAB_MARKER_FORMAT_LINES,
-                    DATALAB_MARKER_USE_LLM=request.app.state.config.DATALAB_MARKER_USE_LLM,
-                    DATALAB_MARKER_OUTPUT_FORMAT=request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
-                    EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
-                    EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
-                    TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
-                    DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
-                    DOCLING_PARAMS={
-                        "do_ocr": request.app.state.config.DOCLING_DO_OCR,
-                        "force_ocr": request.app.state.config.DOCLING_FORCE_OCR,
-                        "ocr_engine": request.app.state.config.DOCLING_OCR_ENGINE,
-                        "ocr_lang": request.app.state.config.DOCLING_OCR_LANG,
-                        "pdf_backend": request.app.state.config.DOCLING_PDF_BACKEND,
-                        "table_mode": request.app.state.config.DOCLING_TABLE_MODE,
-                        "pipeline": request.app.state.config.DOCLING_PIPELINE,
-                        "do_picture_description": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
-                        "picture_description_mode": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
-                        "picture_description_local": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
-                        "picture_description_api": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
-                    },
-                    PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
-                    DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
-                    DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
-                    MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY,
-                )
-                docs = loader.load(
-                    file.filename, file.meta.get("content_type"), file_path
+                text_content = form_data.content
+            elif form_data.collection_name:
+                # Check if the file has already been processed and save the content
+                # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
+
+                result = VECTOR_DB_CLIENT.query(
+                    collection_name=f"file-{file.id}", filter={"file_id": file.id}
                 )
 
-                docs = [
-                    Document(
-                        page_content=doc.page_content,
-                        metadata={
-                            **doc.metadata,
-                            "name": file.filename,
-                            "created_by": file.user_id,
-                            "file_id": file.id,
-                            "source": file.filename,
+                if result is not None and len(result.ids[0]) > 0:
+                    docs = [
+                        Document(
+                            page_content=result.documents[0][idx],
+                            metadata=result.metadatas[0][idx],
+                        )
+                        for idx, id in enumerate(result.ids[0])
+                    ]
+                else:
+                    docs = [
+                        Document(
+                            page_content=file.data.get("content", ""),
+                            metadata={
+                                **file.meta,
+                                "name": file.filename,
+                                "created_by": file.user_id,
+                                "file_id": file.id,
+                                "source": file.filename,
+                            },
+                        )
+                    ]
+
+                text_content = file.data.get("content", "")
+            else:
+                # Process the file and save the content
+                # Usage: /files/
+                file_path = file.path
+                if file_path:
+                    file_path = Storage.get_file(file_path)
+                    loader = Loader(
+                        engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
+                        DATALAB_MARKER_API_KEY=request.app.state.config.DATALAB_MARKER_API_KEY,
+                        DATALAB_MARKER_API_BASE_URL=request.app.state.config.DATALAB_MARKER_API_BASE_URL,
+                        DATALAB_MARKER_ADDITIONAL_CONFIG=request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG,
+                        DATALAB_MARKER_SKIP_CACHE=request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
+                        DATALAB_MARKER_FORCE_OCR=request.app.state.config.DATALAB_MARKER_FORCE_OCR,
+                        DATALAB_MARKER_PAGINATE=request.app.state.config.DATALAB_MARKER_PAGINATE,
+                        DATALAB_MARKER_STRIP_EXISTING_OCR=request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
+                        DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
+                        DATALAB_MARKER_FORMAT_LINES=request.app.state.config.DATALAB_MARKER_FORMAT_LINES,
+                        DATALAB_MARKER_USE_LLM=request.app.state.config.DATALAB_MARKER_USE_LLM,
+                        DATALAB_MARKER_OUTPUT_FORMAT=request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
+                        EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
+                        EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
+                        TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
+                        DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
+                        DOCLING_PARAMS={
+                            "do_ocr": request.app.state.config.DOCLING_DO_OCR,
+                            "force_ocr": request.app.state.config.DOCLING_FORCE_OCR,
+                            "ocr_engine": request.app.state.config.DOCLING_OCR_ENGINE,
+                            "ocr_lang": request.app.state.config.DOCLING_OCR_LANG,
+                            "pdf_backend": request.app.state.config.DOCLING_PDF_BACKEND,
+                            "table_mode": request.app.state.config.DOCLING_TABLE_MODE,
+                            "pipeline": request.app.state.config.DOCLING_PIPELINE,
+                            "do_picture_description": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
+                            "picture_description_mode": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
+                            "picture_description_local": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
+                            "picture_description_api": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
                         },
+                        PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
+                        DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
+                        DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
+                        MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY,
                     )
-                    for doc in docs
-                ]
+                    docs = loader.load(
+                        file.filename, file.meta.get("content_type"), file_path
+                    )
+
+                    docs = [
+                        Document(
+                            page_content=doc.page_content,
+                            metadata={
+                                **doc.metadata,
+                                "name": file.filename,
+                                "created_by": file.user_id,
+                                "file_id": file.id,
+                                "source": file.filename,
+                            },
+                        )
+                        for doc in docs
+                    ]
+                else:
+                    docs = [
+                        Document(
+                            page_content=file.data.get("content", ""),
+                            metadata={
+                                **file.meta,
+                                "name": file.filename,
+                                "created_by": file.user_id,
+                                "file_id": file.id,
+                                "source": file.filename,
+                            },
+                        )
+                    ]
+                text_content = " ".join([doc.page_content for doc in docs])
+
+            log.debug(f"text_content: {text_content}")
+            Files.update_file_data_by_id(
+                file.id,
+                {"content": text_content},
+            )
+            hash = calculate_sha256_string(text_content)
+            Files.update_file_hash_by_id(file.id, hash)
+
+            if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
+                Files.update_file_data_by_id(file.id, {"status": "completed"})
+                return {
+                    "status": True,
+                    "collection_name": None,
+                    "filename": file.filename,
+                    "content": text_content,
+                }
             else:
-                docs = [
-                    Document(
-                        page_content=file.data.get("content", ""),
+                try:
+                    result = save_docs_to_vector_db(
+                        request,
+                        docs=docs,
+                        collection_name=collection_name,
                         metadata={
-                            **file.meta,
-                            "name": file.filename,
-                            "created_by": file.user_id,
                             "file_id": file.id,
-                            "source": file.filename,
+                            "name": file.filename,
+                            "hash": hash,
                         },
+                        add=(True if form_data.collection_name else False),
+                        user=user,
                     )
-                ]
-            text_content = " ".join([doc.page_content for doc in docs])
-
-        log.debug(f"text_content: {text_content}")
-        Files.update_file_data_by_id(
-            file.id,
-            {"content": text_content},
-        )
-        hash = calculate_sha256_string(text_content)
-        Files.update_file_hash_by_id(file.id, hash)
+                    log.info(f"added {len(docs)} items to collection {collection_name}")
+
+                    if result:
+                        Files.update_file_metadata_by_id(
+                            file.id,
+                            {
+                                "collection_name": collection_name,
+                            },
+                        )
 
-        if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
-            Files.update_file_data_by_id(file.id, {"status": "completed"})
-            return {
-                "status": True,
-                "collection_name": None,
-                "filename": file.filename,
-                "content": text_content,
-            }
-        else:
-            try:
-                result = save_docs_to_vector_db(
-                    request,
-                    docs=docs,
-                    collection_name=collection_name,
-                    metadata={
-                        "file_id": file.id,
-                        "name": file.filename,
-                        "hash": hash,
-                    },
-                    add=(True if form_data.collection_name else False),
-                    user=user,
-                )
-                log.info(f"added {len(docs)} items to collection {collection_name}")
+                        Files.update_file_data_by_id(
+                            file.id,
+                            {"status": "completed"},
+                        )
 
-                if result:
-                    Files.update_file_metadata_by_id(
-                        file.id,
-                        {
+                        return {
+                            "status": True,
                             "collection_name": collection_name,
-                        },
-                    )
+                            "filename": file.filename,
+                            "content": text_content,
+                        }
+                    else:
+                        raise Exception("Error saving document to vector database")
+                except Exception as e:
+                    raise e
 
-                    Files.update_file_data_by_id(
-                        file.id,
-                        {"status": "completed"},
-                    )
+        except Exception as e:
+            log.exception(e)
+            Files.update_file_data_by_id(
+                file.id,
+                {"status": "failed"},
+            )
 
-                    return {
-                        "status": True,
-                        "collection_name": collection_name,
-                        "filename": file.filename,
-                        "content": text_content,
-                    }
-                else:
-                    raise Exception("Error saving document to vector database")
-            except Exception as e:
-                raise e
+            if "No pandoc was found" in str(e):
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
+                )
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=str(e),
+                )
 
-    except Exception as e:
-        log.exception(e)
-        if "No pandoc was found" in str(e):
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
-            )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=str(e),
-            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
+        )
 
 
 class ProcessTextForm(BaseModel):
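
Taken together, the retrieval.py changes move the whole processing body inside a per-file try/except that records a "failed" status before converting the error into an HTTP 400, and they turn a missing (or foreign) file into an explicit 404 instead of an AttributeError on file.id. A condensed sketch of that control flow, assuming the imports already present at the top of the module (do_processing stands in for the extraction and embedding work and is not a real helper):

from fastapi import HTTPException, status

from open_webui.constants import ERROR_MESSAGES
from open_webui.models.files import Files


def process_with_failure_tracking(file, do_processing):
    if not file:
        # Nothing found for this user: explicit 404, matching the new else branch.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )
    try:
        return do_processing(file)
    except Exception as e:
        # Mark the upload as failed so the UI can surface it, then re-raise the
        # original error as a 400 (the real handler keeps the dedicated pandoc
        # message for missing-pandoc errors).
        Files.update_file_data_by_id(file.id, {"status": "failed"})
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))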