소스 검색

Merge pull request #14245 from PVBLIC-F/dev

perf Update mistral.py
Tim Jaeryang Baek 4 달 전
부모
커밋
c8f1bdf928
1개의 변경된 파일423개의 추가작업 그리고 73개의 파일을 삭제
  1. 423 73
      backend/open_webui/retrieval/loaders/mistral.py

+ 423 - 73
backend/open_webui/retrieval/loaders/mistral.py

@@ -1,8 +1,12 @@
 import requests
+import aiohttp
+import asyncio
 import logging
 import os
 import sys
+import time
 from typing import List, Dict, Any
+from contextlib import asynccontextmanager
 
 from langchain_core.documents import Document
 from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
@@ -14,18 +18,29 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 class MistralLoader:
     """
+    Enhanced Mistral OCR loader with both sync and async support.
     Loads documents by processing them through the Mistral OCR API.
     """
 
     BASE_API_URL = "https://api.mistral.ai/v1"
 
-    def __init__(self, api_key: str, file_path: str):
+    def __init__(
+        self, 
+        api_key: str, 
+        file_path: str,
+        timeout: int = 300,  # 5 minutes default
+        max_retries: int = 3,
+        enable_debug_logging: bool = False
+    ):
         """
-        Initializes the loader.
+        Initializes the loader with enhanced features.
 
         Args:
             api_key: Your Mistral API key.
             file_path: The local path to the PDF file to process.
+            timeout: Request timeout in seconds.
+            max_retries: Maximum number of retry attempts.
+            enable_debug_logging: Enable detailed debug logs.
         """
         if not api_key:
             raise ValueError("API key cannot be empty.")
@@ -34,7 +49,23 @@ class MistralLoader:
 
         self.api_key = api_key
         self.file_path = file_path
-        self.headers = {"Authorization": f"Bearer {self.api_key}"}
+        self.timeout = timeout
+        self.max_retries = max_retries
+        self.debug = enable_debug_logging
+        
+        # Pre-compute file info for performance
+        self.file_name = os.path.basename(file_path)
+        self.file_size = os.path.getsize(file_path)
+        
+        self.headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "User-Agent": "OpenWebUI-MistralLoader/2.0"
+        }
+
+    def _debug_log(self, message: str, *args) -> None:
+        """Conditional debug logging for performance."""
+        if self.debug:
+            log.debug(message, *args)
 
     def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
         """Checks response status and returns JSON content."""
@@ -54,24 +85,81 @@ class MistralLoader:
             log.error(f"JSON decode error: {json_err} - Response: {response.text}")
             raise  # Re-raise after logging
 
+    async def _handle_response_async(self, response: aiohttp.ClientResponse) -> Dict[str, Any]:
+        """Async version of response handling with better error info."""
+        try:
+            response.raise_for_status()
+            
+            # Check content type
+            content_type = response.headers.get('content-type', '')
+            if 'application/json' not in content_type:
+                if response.status == 204:
+                    return {}
+                text = await response.text()
+                raise ValueError(f"Unexpected content type: {content_type}, body: {text[:200]}...")
+            
+            return await response.json()
+            
+        except aiohttp.ClientResponseError as e:
+            error_text = await response.text() if response else "No response"
+            log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
+            raise
+        except aiohttp.ClientError as e:
+            log.error(f"Client error: {e}")
+            raise
+        except Exception as e:
+            log.error(f"Unexpected error processing response: {e}")
+            raise
+
+    def _retry_request_sync(self, request_func, *args, **kwargs):
+        """Synchronous retry logic with exponential backoff."""
+        for attempt in range(self.max_retries):
+            try:
+                return request_func(*args, **kwargs)
+            except (requests.exceptions.RequestException, Exception) as e:
+                if attempt == self.max_retries - 1:
+                    raise
+                
+                wait_time = (2 ** attempt) + 0.5
+                log.warning(f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...")
+                time.sleep(wait_time)
+
+    async def _retry_request_async(self, request_func, *args, **kwargs):
+        """Async retry logic with exponential backoff."""
+        for attempt in range(self.max_retries):
+            try:
+                return await request_func(*args, **kwargs)
+            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
+                if attempt == self.max_retries - 1:
+                    raise
+                
+                wait_time = (2 ** attempt) + 0.5
+                log.warning(f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...")
+                await asyncio.sleep(wait_time)
+
     def _upload_file(self) -> str:
-        """Uploads the file to Mistral for OCR processing."""
+        """Uploads the file to Mistral for OCR processing (sync version)."""
         log.info("Uploading file to Mistral API")
         url = f"{self.BASE_API_URL}/files"
         file_name = os.path.basename(self.file_path)
 
-        try:
+        def upload_request():
             with open(self.file_path, "rb") as f:
                 files = {"file": (file_name, f, "application/pdf")}
                 data = {"purpose": "ocr"}
 
-                upload_headers = self.headers.copy()  # Avoid modifying self.headers
-
                 response = requests.post(
-                    url, headers=upload_headers, files=files, data=data
+                    url, 
+                    headers=self.headers, 
+                    files=files, 
+                    data=data,
+                    timeout=self.timeout
                 )
 
-            response_data = self._handle_response(response)
+            return self._handle_response(response)
+
+        try:
+            response_data = self._retry_request_sync(upload_request)
             file_id = response_data.get("id")
             if not file_id:
                 raise ValueError("File ID not found in upload response.")
@@ -81,16 +169,63 @@ class MistralLoader:
             log.error(f"Failed to upload file: {e}")
             raise
 
+    async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
+        """Async file upload with streaming for better memory efficiency."""
+        url = f"{self.BASE_API_URL}/files"
+        
+        async def upload_request():
+            # Create multipart writer for streaming upload
+            writer = aiohttp.MultipartWriter('form-data')
+            
+            # Add purpose field
+            purpose_part = writer.append('ocr')
+            purpose_part.set_content_disposition('form-data', name='purpose')
+            
+            # Add file part with streaming
+            file_part = writer.append_payload(aiohttp.streams.FilePayload(
+                self.file_path,
+                filename=self.file_name,
+                content_type='application/pdf'
+            ))
+            file_part.set_content_disposition('form-data', name='file', filename=self.file_name)
+            
+            self._debug_log(f"Uploading file: {self.file_name} ({self.file_size:,} bytes)")
+            
+            async with session.post(
+                url, 
+                data=writer, 
+                headers=self.headers,
+                timeout=aiohttp.ClientTimeout(total=self.timeout)
+            ) as response:
+                return await self._handle_response_async(response)
+                
+        response_data = await self._retry_request_async(upload_request)
+        
+        file_id = response_data.get("id")
+        if not file_id:
+            raise ValueError("File ID not found in upload response.")
+        
+        log.info(f"File uploaded successfully. File ID: {file_id}")
+        return file_id
+
     def _get_signed_url(self, file_id: str) -> str:
-        """Retrieves a temporary signed URL for the uploaded file."""
+        """Retrieves a temporary signed URL for the uploaded file (sync version)."""
         log.info(f"Getting signed URL for file ID: {file_id}")
         url = f"{self.BASE_API_URL}/files/{file_id}/url"
         params = {"expiry": 1}
         signed_url_headers = {**self.headers, "Accept": "application/json"}
 
+        def url_request():
+            response = requests.get(
+                url, 
+                headers=signed_url_headers, 
+                params=params,
+                timeout=self.timeout
+            )
+            return self._handle_response(response)
+
         try:
-            response = requests.get(url, headers=signed_url_headers, params=params)
-            response_data = self._handle_response(response)
+            response_data = self._retry_request_sync(url_request)
             signed_url = response_data.get("url")
             if not signed_url:
                 raise ValueError("Signed URL not found in response.")
@@ -100,8 +235,37 @@ class MistralLoader:
             log.error(f"Failed to get signed URL: {e}")
             raise
 
+    async def _get_signed_url_async(self, session: aiohttp.ClientSession, file_id: str) -> str:
+        """Async signed URL retrieval."""
+        url = f"{self.BASE_API_URL}/files/{file_id}/url"
+        params = {"expiry": 1}
+        
+        headers = {
+            **self.headers,
+            "Accept": "application/json"
+        }
+        
+        async def url_request():
+            self._debug_log(f"Getting signed URL for file ID: {file_id}")
+            async with session.get(
+                url, 
+                headers=headers, 
+                params=params,
+                timeout=aiohttp.ClientTimeout(total=self.timeout)
+            ) as response:
+                return await self._handle_response_async(response)
+        
+        response_data = await self._retry_request_async(url_request)
+        
+        signed_url = response_data.get("url")
+        if not signed_url:
+            raise ValueError("Signed URL not found in response.")
+        
+        self._debug_log("Signed URL received successfully")
+        return signed_url
+
     def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
-        """Sends the signed URL to the OCR endpoint for processing."""
+        """Sends the signed URL to the OCR endpoint for processing (sync version)."""
         log.info("Processing OCR via Mistral API")
         url = f"{self.BASE_API_URL}/ocr"
         ocr_headers = {
@@ -118,43 +282,179 @@ class MistralLoader:
             "include_image_base64": False,
         }
 
+        def ocr_request():
+            response = requests.post(
+                url, 
+                headers=ocr_headers, 
+                json=payload,
+                timeout=self.timeout
+            )
+            return self._handle_response(response)
+
         try:
-            response = requests.post(url, headers=ocr_headers, json=payload)
-            ocr_response = self._handle_response(response)
+            ocr_response = self._retry_request_sync(ocr_request)
             log.info("OCR processing done.")
-            log.debug("OCR response: %s", ocr_response)
+            self._debug_log("OCR response: %s", ocr_response)
             return ocr_response
         except Exception as e:
             log.error(f"Failed during OCR processing: {e}")
             raise
 
+    async def _process_ocr_async(self, session: aiohttp.ClientSession, signed_url: str) -> Dict[str, Any]:
+        """Async OCR processing with timing metrics."""
+        url = f"{self.BASE_API_URL}/ocr"
+        
+        headers = {
+            **self.headers,
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+        
+        payload = {
+            "model": "mistral-ocr-latest",
+            "document": {
+                "type": "document_url",
+                "document_url": signed_url,
+            },
+            "include_image_base64": False,
+        }
+        
+        async def ocr_request():
+            log.info("Starting OCR processing via Mistral API")
+            start_time = time.time()
+            
+            async with session.post(
+                url, 
+                json=payload, 
+                headers=headers,
+                timeout=aiohttp.ClientTimeout(total=self.timeout)
+            ) as response:
+                ocr_response = await self._handle_response_async(response)
+                
+            processing_time = time.time() - start_time
+            log.info(f"OCR processing completed in {processing_time:.2f}s")
+            
+            return ocr_response
+        
+        return await self._retry_request_async(ocr_request)
+
     def _delete_file(self, file_id: str) -> None:
-        """Deletes the file from Mistral storage."""
+        """Deletes the file from Mistral storage (sync version)."""
         log.info(f"Deleting uploaded file ID: {file_id}")
         url = f"{self.BASE_API_URL}/files/{file_id}"
-        # No specific Accept header needed, default or Authorization is usually sufficient
 
         try:
-            response = requests.delete(url, headers=self.headers)
-            delete_response = self._handle_response(
-                response
-            )  # Check status, ignore response body unless needed
-            log.info(
-                f"File deleted successfully: {delete_response}"
-            )  # Log the response if available
+            response = requests.delete(url, headers=self.headers, timeout=30)
+            delete_response = self._handle_response(response)
+            log.info(f"File deleted successfully: {delete_response}")
         except Exception as e:
             # Log error but don't necessarily halt execution if deletion fails
             log.error(f"Failed to delete file ID {file_id}: {e}")
-            # Depending on requirements, you might choose to raise the error here
+
+    async def _delete_file_async(self, session: aiohttp.ClientSession, file_id: str) -> None:
+        """Async file deletion with error tolerance."""
+        try:
+            async def delete_request():
+                self._debug_log(f"Deleting file ID: {file_id}")
+                async with session.delete(
+                    url=f"{self.BASE_API_URL}/files/{file_id}", 
+                    headers=self.headers,
+                    timeout=aiohttp.ClientTimeout(total=30)  # Shorter timeout for cleanup
+                ) as response:
+                    return await self._handle_response_async(response)
+            
+            await self._retry_request_async(delete_request)
+            self._debug_log(f"File {file_id} deleted successfully")
+            
+        except Exception as e:
+            # Don't fail the entire process if cleanup fails
+            log.warning(f"Failed to delete file ID {file_id}: {e}")
+
+    @asynccontextmanager
+    async def _get_session(self):
+        """Context manager for HTTP session with optimized settings."""
+        connector = aiohttp.TCPConnector(
+            limit=10,  # Total connection limit
+            limit_per_host=5,  # Per-host connection limit
+            ttl_dns_cache=300,  # DNS cache TTL
+            use_dns_cache=True,
+            keepalive_timeout=30,
+            enable_cleanup_closed=True
+        )
+        
+        async with aiohttp.ClientSession(
+            connector=connector,
+            timeout=aiohttp.ClientTimeout(total=self.timeout),
+            headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"}
+        ) as session:
+            yield session
+
+    def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
+        """Process OCR results into Document objects with enhanced metadata."""
+        pages_data = ocr_response.get("pages")
+        if not pages_data:
+            log.warning("No pages found in OCR response.")
+            return [Document(page_content="No text content found", metadata={"error": "no_pages"})]
+
+        documents = []
+        total_pages = len(pages_data)
+        skipped_pages = 0
+        
+        for page_data in pages_data:
+            page_content = page_data.get("markdown")
+            page_index = page_data.get("index")  # API uses 0-based index
+
+            if page_content is not None and page_index is not None:
+                # Clean up content efficiently
+                cleaned_content = page_content.strip() if isinstance(page_content, str) else str(page_content)
+                
+                if cleaned_content:  # Only add non-empty pages
+                    documents.append(
+                        Document(
+                            page_content=cleaned_content,
+                            metadata={
+                                "page": page_index,  # 0-based index from API
+                                "page_label": page_index + 1,  # 1-based label for convenience
+                                "total_pages": total_pages,
+                                "file_name": self.file_name,
+                                "file_size": self.file_size,
+                                "processing_engine": "mistral-ocr"
+                            },
+                        )
+                    )
+                else:
+                    skipped_pages += 1
+                    self._debug_log(f"Skipping empty page {page_index}")
+            else:
+                skipped_pages += 1
+                self._debug_log(f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}")
+
+        if skipped_pages > 0:
+            log.info(f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages")
+
+        if not documents:
+            # Case where pages existed but none had valid markdown/index
+            log.warning("OCR response contained pages, but none had valid content/index.")
+            return [
+                Document(
+                    page_content="No valid text content found in document",
+                    metadata={"error": "no_valid_pages", "total_pages": total_pages}
+                )
+            ]
+
+        return documents
 
     def load(self) -> List[Document]:
         """
         Executes the full OCR workflow: upload, get URL, process OCR, delete file.
+        Synchronous version for backward compatibility.
 
         Returns:
             A list of Document objects, one for each page processed.
         """
         file_id = None
+        start_time = time.time()
+        
         try:
             # 1. Upload file
             file_id = self._upload_file()
@@ -166,53 +466,21 @@ class MistralLoader:
             ocr_response = self._process_ocr(signed_url)
 
             # 4. Process results
-            pages_data = ocr_response.get("pages")
-            if not pages_data:
-                log.warning("No pages found in OCR response.")
-                return [Document(page_content="No text content found", metadata={})]
-
-            documents = []
-            total_pages = len(pages_data)
-            for page_data in pages_data:
-                page_content = page_data.get("markdown")
-                page_index = page_data.get("index")  # API uses 0-based index
-
-                if page_content is not None and page_index is not None:
-                    documents.append(
-                        Document(
-                            page_content=page_content,
-                            metadata={
-                                "page": page_index,  # 0-based index from API
-                                "page_label": page_index
-                                + 1,  # 1-based label for convenience
-                                "total_pages": total_pages,
-                                # Add other relevant metadata from page_data if available/needed
-                                # e.g., page_data.get('width'), page_data.get('height')
-                            },
-                        )
-                    )
-                else:
-                    log.warning(
-                        f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
-                    )
-
-            if not documents:
-                # Case where pages existed but none had valid markdown/index
-                log.warning(
-                    "OCR response contained pages, but none had valid content/index."
-                )
-                return [
-                    Document(
-                        page_content="No text content found in valid pages", metadata={}
-                    )
-                ]
-
+            documents = self._process_results(ocr_response)
+            
+            total_time = time.time() - start_time
+            log.info(f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents")
+            
             return documents
 
         except Exception as e:
-            log.error(f"An error occurred during the loading process: {e}")
-            # Return an empty list or a specific error document on failure
-            return [Document(page_content=f"Error during processing: {e}", metadata={})]
+            total_time = time.time() - start_time
+            log.error(f"An error occurred during the loading process after {total_time:.2f}s: {e}")
+            # Return an error document on failure
+            return [Document(
+                page_content=f"Error during processing: {e}", 
+                metadata={"error": "processing_failed", "file_name": self.file_name}
+            )]
         finally:
             # 5. Delete file (attempt even if prior steps failed after upload)
             if file_id:
@@ -220,6 +488,88 @@ class MistralLoader:
                     self._delete_file(file_id)
                 except Exception as del_e:
                     # Log deletion error, but don't overwrite original error if one occurred
-                    log.error(
-                        f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
-                    )
+                    log.error(f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}")
+
+    async def load_async(self) -> List[Document]:
+        """
+        Asynchronous OCR workflow execution with optimized performance.
+        
+        Returns:
+            A list of Document objects, one for each page processed.
+        """
+        file_id = None
+        start_time = time.time()
+        
+        try:
+            async with self._get_session() as session:
+                # 1. Upload file with streaming
+                file_id = await self._upload_file_async(session)
+
+                # 2. Get signed URL
+                signed_url = await self._get_signed_url_async(session, file_id)
+
+                # 3. Process OCR
+                ocr_response = await self._process_ocr_async(session, signed_url)
+
+                # 4. Process results
+                documents = self._process_results(ocr_response)
+                
+                total_time = time.time() - start_time
+                log.info(f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents")
+                
+                return documents
+
+        except Exception as e:
+            total_time = time.time() - start_time
+            log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
+            return [Document(
+                page_content=f"Error during OCR processing: {e}",
+                metadata={"error": "processing_failed", "file_name": self.file_name}
+            )]
+        finally:
+            # 5. Cleanup - always attempt file deletion
+            if file_id:
+                try:
+                    async with self._get_session() as session:
+                        await self._delete_file_async(session, file_id)
+                except Exception as cleanup_error:
+                    log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")
+
+    @staticmethod
+    async def load_multiple_async(loaders: List['MistralLoader']) -> List[List[Document]]:
+        """
+        Process multiple files concurrently for maximum performance.
+        
+        Args:
+            loaders: List of MistralLoader instances
+            
+        Returns:
+            List of document lists, one for each loader
+        """
+        if not loaders:
+            return []
+            
+        log.info(f"Starting concurrent processing of {len(loaders)} files")
+        start_time = time.time()
+        
+        # Process all files concurrently
+        tasks = [loader.load_async() for loader in loaders]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+        
+        # Handle any exceptions in results
+        processed_results = []
+        for i, result in enumerate(results):
+            if isinstance(result, Exception):
+                log.error(f"File {i} failed: {result}")
+                processed_results.append([Document(
+                    page_content=f"Error processing file: {result}",
+                    metadata={"error": "batch_processing_failed", "file_index": i}
+                )])
+            else:
+                processed_results.append(result)
+        
+        total_time = time.time() - start_time
+        total_docs = sum(len(docs) for docs in processed_results)
+        log.info(f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents")
+        
+        return processed_results