Browse Source

Merge pull request #12307 from paddy313/feature/mistral_ocr

feat: Added support for Mistral OCR for Content Extraction
Timothy Jaeryang Baek 3 months ago
parent
commit
0554bbb1cb

+ 5 - 0
backend/open_webui/config.py

@@ -1727,6 +1727,11 @@ DOCUMENT_INTELLIGENCE_KEY = PersistentConfig(
     os.getenv("DOCUMENT_INTELLIGENCE_KEY", ""),
 )
 
+MISTRAL_OCR_API_KEY = PersistentConfig(
+    "MISTRAL_OCR_API_KEY",
+    "rag.mistral_ocr_api_key",
+    os.getenv("MISTRAL_OCR_API_KEY", ""),
+)
 
 BYPASS_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
     "BYPASS_EMBEDDING_AND_RETRIEVAL",

+ 2 - 0
backend/open_webui/main.py

@@ -191,6 +191,7 @@ from open_webui.config import (
     DOCLING_SERVER_URL,
     DOCUMENT_INTELLIGENCE_ENDPOINT,
     DOCUMENT_INTELLIGENCE_KEY,
+    MISTRAL_OCR_API_KEY,
     RAG_TOP_K,
     RAG_TOP_K_RERANKER,
     RAG_TEXT_SPLITTER,
@@ -582,6 +583,7 @@ app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
 app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
 app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
 app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
+app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY
 
 app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
 app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME

+ 12 - 0
backend/open_webui/retrieval/loaders/main.py

@@ -20,6 +20,9 @@ from langchain_community.document_loaders import (
     YoutubeLoader,
 )
 from langchain_core.documents import Document
+
+from open_webui.retrieval.loaders.mistral import MistralLoader
+
 from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
 
 logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
@@ -222,6 +225,15 @@ class Loader:
                 api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
                 api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
             )
+        elif (
+            self.engine == "mistral_ocr"
+            and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
+            and file_ext
+            in ["pdf"]  # Mistral OCR currently only supports PDF and images
+        ):
+            loader = MistralLoader(
+                api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
+            )
         else:
             if file_ext == "pdf":
                 loader = PyPDFLoader(

+ 225 - 0
backend/open_webui/retrieval/loaders/mistral.py

@@ -0,0 +1,225 @@
+import requests
+import logging
+import os
+import sys
+from typing import List, Dict, Any
+
+from langchain_core.documents import Document
+from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
+
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+class MistralLoader:
+    """
+    Loads documents by processing them through the Mistral OCR API.
+    """
+
+    BASE_API_URL = "https://api.mistral.ai/v1"
+
+    def __init__(self, api_key: str, file_path: str):
+        """
+        Initializes the loader.
+
+        Args:
+            api_key: Your Mistral API key.
+            file_path: The local path to the PDF file to process.
+        """
+        if not api_key:
+            raise ValueError("API key cannot be empty.")
+        if not os.path.exists(file_path):
+            raise FileNotFoundError(f"File not found at {file_path}")
+
+        self.api_key = api_key
+        self.file_path = file_path
+        self.headers = {"Authorization": f"Bearer {self.api_key}"}
+
+    def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
+        """Checks response status and returns JSON content."""
+        try:
+            response.raise_for_status()  # Raises HTTPError for bad responses (4xx or 5xx)
+            # Handle potential empty responses for certain successful requests (e.g., DELETE)
+            if response.status_code == 204 or not response.content:
+                return {}  # Return empty dict if no content
+            return response.json()
+        except requests.exceptions.HTTPError as http_err:
+            log.error(f"HTTP error occurred: {http_err} - Response: {response.text}")
+            raise
+        except requests.exceptions.RequestException as req_err:
+            log.error(f"Request exception occurred: {req_err}")
+            raise
+        except ValueError as json_err:  # Includes JSONDecodeError
+            log.error(f"JSON decode error: {json_err} - Response: {response.text}")
+            raise  # Re-raise after logging
+
+    def _upload_file(self) -> str:
+        """Uploads the file to Mistral for OCR processing."""
+        log.info("Uploading file to Mistral API")
+        url = f"{self.BASE_API_URL}/files"
+        file_name = os.path.basename(self.file_path)
+
+        try:
+            with open(self.file_path, "rb") as f:
+                files = {"file": (file_name, f, "application/pdf")}
+                data = {"purpose": "ocr"}
+
+                upload_headers = self.headers.copy()  # Avoid modifying self.headers
+
+                response = requests.post(
+                    url, headers=upload_headers, files=files, data=data
+                )
+
+            response_data = self._handle_response(response)
+            file_id = response_data.get("id")
+            if not file_id:
+                raise ValueError("File ID not found in upload response.")
+            log.info(f"File uploaded successfully. File ID: {file_id}")
+            return file_id
+        except Exception as e:
+            log.error(f"Failed to upload file: {e}")
+            raise
+
+    def _get_signed_url(self, file_id: str) -> str:
+        """Retrieves a temporary signed URL for the uploaded file."""
+        log.info(f"Getting signed URL for file ID: {file_id}")
+        url = f"{self.BASE_API_URL}/files/{file_id}/url"
+        params = {"expiry": 1}
+        signed_url_headers = {**self.headers, "Accept": "application/json"}
+
+        try:
+            response = requests.get(url, headers=signed_url_headers, params=params)
+            response_data = self._handle_response(response)
+            signed_url = response_data.get("url")
+            if not signed_url:
+                raise ValueError("Signed URL not found in response.")
+            log.info("Signed URL received.")
+            return signed_url
+        except Exception as e:
+            log.error(f"Failed to get signed URL: {e}")
+            raise
+
+    def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
+        """Sends the signed URL to the OCR endpoint for processing."""
+        log.info("Processing OCR via Mistral API")
+        url = f"{self.BASE_API_URL}/ocr"
+        ocr_headers = {
+            **self.headers,
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+        payload = {
+            "model": "mistral-ocr-latest",
+            "document": {
+                "type": "document_url",
+                "document_url": signed_url,
+            },
+            "include_image_base64": False,
+        }
+
+        try:
+            response = requests.post(url, headers=ocr_headers, json=payload)
+            ocr_response = self._handle_response(response)
+            log.info("OCR processing done.")
+            log.debug("OCR response: %s", ocr_response)
+            return ocr_response
+        except Exception as e:
+            log.error(f"Failed during OCR processing: {e}")
+            raise
+
+    def _delete_file(self, file_id: str) -> None:
+        """Deletes the file from Mistral storage."""
+        log.info(f"Deleting uploaded file ID: {file_id}")
+        url = f"{self.BASE_API_URL}/files/{file_id}"
+        # No specific Accept header needed, default or Authorization is usually sufficient
+
+        try:
+            response = requests.delete(url, headers=self.headers)
+            delete_response = self._handle_response(
+                response
+            )  # Check status, ignore response body unless needed
+            log.info(
+                f"File deleted successfully: {delete_response}"
+            )  # Log the response if available
+        except Exception as e:
+            # Log error but don't necessarily halt execution if deletion fails
+            log.error(f"Failed to delete file ID {file_id}: {e}")
+            # Depending on requirements, you might choose to raise the error here
+
+    def load(self) -> List[Document]:
+        """
+        Executes the full OCR workflow: upload, get URL, process OCR, delete file.
+
+        Returns:
+            A list of Document objects, one for each page processed.
+        """
+        file_id = None
+        try:
+            # 1. Upload file
+            file_id = self._upload_file()
+
+            # 2. Get Signed URL
+            signed_url = self._get_signed_url(file_id)
+
+            # 3. Process OCR
+            ocr_response = self._process_ocr(signed_url)
+
+            # 4. Process results
+            pages_data = ocr_response.get("pages")
+            if not pages_data:
+                log.warning("No pages found in OCR response.")
+                return [Document(page_content="No text content found", metadata={})]
+
+            documents = []
+            total_pages = len(pages_data)
+            for page_data in pages_data:
+                page_content = page_data.get("markdown")
+                page_index = page_data.get("index")  # API uses 0-based index
+
+                if page_content is not None and page_index is not None:
+                    documents.append(
+                        Document(
+                            page_content=page_content,
+                            metadata={
+                                "page": page_index,  # 0-based index from API
+                                "page_label": page_index
+                                + 1,  # 1-based label for convenience
+                                "total_pages": total_pages,
+                                # Add other relevant metadata from page_data if available/needed
+                                # e.g., page_data.get('width'), page_data.get('height')
+                            },
+                        )
+                    )
+                else:
+                    log.warning(
+                        f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
+                    )
+
+            if not documents:
+                # Case where pages existed but none had valid markdown/index
+                log.warning(
+                    "OCR response contained pages, but none had valid content/index."
+                )
+                return [
+                    Document(
+                        page_content="No text content found in valid pages", metadata={}
+                    )
+                ]
+
+            return documents
+
+        except Exception as e:
+            log.error(f"An error occurred during the loading process: {e}")
+            # Return an empty list or a specific error document on failure
+            return [Document(page_content=f"Error during processing: {e}", metadata={})]
+        finally:
+            # 5. Delete file (attempt even if prior steps failed after upload)
+            if file_id:
+                try:
+                    self._delete_file(file_id)
+                except Exception as del_e:
+                    # Log deletion error, but don't overwrite original error if one occurred
+                    log.error(
+                        f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
+                    )

+ 16 - 0
backend/open_webui/routers/retrieval.py

@@ -364,6 +364,9 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
                 "endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
                 "key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
             },
+            "mistral_ocr_config": {
+                "api_key": request.app.state.config.MISTRAL_OCR_API_KEY,
+            },
         },
         "chunk": {
             "text_splitter": request.app.state.config.TEXT_SPLITTER,
@@ -427,11 +430,16 @@ class DocumentIntelligenceConfigForm(BaseModel):
     key: str
 
 
+class MistralOCRConfigForm(BaseModel):
+    api_key: str
+
+
 class ContentExtractionConfig(BaseModel):
     engine: str = ""
     tika_server_url: Optional[str] = None
     docling_server_url: Optional[str] = None
     document_intelligence_config: Optional[DocumentIntelligenceConfigForm] = None
+    mistral_ocr_config: Optional[MistralOCRConfigForm] = None
 
 
 class ChunkParamUpdateForm(BaseModel):
@@ -553,6 +561,10 @@ async def update_rag_config(
             request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = (
                 form_data.content_extraction.document_intelligence_config.key
             )
+        if form_data.content_extraction.mistral_ocr_config is not None:
+            request.app.state.config.MISTRAL_OCR_API_KEY = (
+                form_data.content_extraction.mistral_ocr_config.api_key
+            )
 
     if form_data.chunk is not None:
         request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
@@ -659,6 +671,9 @@ async def update_rag_config(
                 "endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
                 "key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
             },
+            "mistral_ocr_config": {
+                "api_key": request.app.state.config.MISTRAL_OCR_API_KEY,
+            },
         },
         "chunk": {
             "text_splitter": request.app.state.config.TEXT_SPLITTER,
@@ -1007,6 +1022,7 @@ def process_file(
                     PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
                     DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
                     DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
+                    MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY,
                 )
                 docs = loader.load(
                     file.filename, file.meta.get("content_type"), file_path

+ 22 - 5
src/lib/components/admin/Settings/Documents.svelte

@@ -54,6 +54,8 @@
 	let documentIntelligenceEndpoint = '';
 	let documentIntelligenceKey = '';
 	let showDocumentIntelligenceConfig = false;
+	let mistralApiKey = '';
+	let showMistralOcrConfig = false;
 
 	let textSplitter = '';
 	let chunkSize = 0;
@@ -189,6 +191,10 @@
 			toast.error($i18n.t('Document Intelligence endpoint and key required.'));
 			return;
 		}
+		if (contentExtractionEngine === 'mistral_ocr' && mistralApiKey === '') {
+			toast.error($i18n.t('Mistral OCR API Key required.'));
+			return;
+		}
 
 		if (!BYPASS_EMBEDDING_AND_RETRIEVAL) {
 			await embeddingModelUpdateHandler();
@@ -220,6 +226,9 @@
 				document_intelligence_config: {
 					key: documentIntelligenceKey,
 					endpoint: documentIntelligenceEndpoint
+				},
+				mistral_ocr_config: {
+					api_key: mistralApiKey
 				}
 			}
 		});
@@ -284,6 +293,8 @@
 			documentIntelligenceEndpoint = res.content_extraction.document_intelligence_config.endpoint;
 			documentIntelligenceKey = res.content_extraction.document_intelligence_config.key;
 			showDocumentIntelligenceConfig = contentExtractionEngine === 'document_intelligence';
+			mistralApiKey = res.content_extraction.mistral_ocr_config.api_key;
+			showMistralOcrConfig = contentExtractionEngine === 'mistral_ocr';
 
 			fileMaxSize = res?.file.max_size ?? '';
 			fileMaxCount = res?.file.max_count ?? '';
@@ -335,21 +346,21 @@
 
 				<hr class=" border-gray-100 dark:border-gray-850 my-2" />
 
-				<div class="  mb-2.5 flex flex-col w-full justify-between">
+				<div class="mb-2.5 flex flex-col w-full justify-between">
 					<div class="flex w-full justify-between">
-						<div class=" self-center text-xs font-medium">
+						<div class="self-center text-xs font-medium">
 							{$i18n.t('Content Extraction Engine')}
 						</div>
-
 						<div class="">
 							<select
 								class="dark:bg-gray-900 w-fit pr-8 rounded-sm px-2 text-xs bg-transparent outline-hidden text-right"
 								bind:value={contentExtractionEngine}
 							>
-								<option value="">{$i18n.t('Default')} </option>
+								<option value="">{$i18n.t('Default')}</option>
 								<option value="tika">{$i18n.t('Tika')}</option>
 								<option value="docling">{$i18n.t('Docling')}</option>
 								<option value="document_intelligence">{$i18n.t('Document Intelligence')}</option>
+								<option value="mistral_ocr">{$i18n.t('Mistral OCR')}</option>
 							</select>
 						</div>
 					</div>
@@ -378,12 +389,18 @@
 								placeholder={$i18n.t('Enter Document Intelligence Endpoint')}
 								bind:value={documentIntelligenceEndpoint}
 							/>
-
 							<SensitiveInput
 								placeholder={$i18n.t('Enter Document Intelligence Key')}
 								bind:value={documentIntelligenceKey}
 							/>
 						</div>
+					{:else if contentExtractionEngine === 'mistral_ocr'}
+						<div class="my-0.5 flex gap-2 pr-2">
+							<SensitiveInput
+								placeholder={$i18n.t('Enter Mistral API Key')}
+								bind:value={mistralApiKey}
+							/>
+						</div>
 					{/if}
 				</div>