# mistral.py — Mistral OCR document loader (sync + async) for Open WebUI RAG.
  1. import requests
  2. import aiohttp
  3. import asyncio
  4. import logging
  5. import os
  6. import sys
  7. import time
  8. from typing import List, Dict, Any
  9. from contextlib import asynccontextmanager
  10. from langchain_core.documents import Document
  11. from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
  12. logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
  13. log = logging.getLogger(__name__)
  14. log.setLevel(SRC_LOG_LEVELS["RAG"])
  15. class MistralLoader:
  16. """
  17. Enhanced Mistral OCR loader with both sync and async support.
  18. Loads documents by processing them through the Mistral OCR API.
  19. """
  20. BASE_API_URL = "https://api.mistral.ai/v1"
  21. def __init__(
  22. self,
  23. api_key: str,
  24. file_path: str,
  25. timeout: int = 300, # 5 minutes default
  26. max_retries: int = 3,
  27. enable_debug_logging: bool = False
  28. ):
  29. """
  30. Initializes the loader with enhanced features.
  31. Args:
  32. api_key: Your Mistral API key.
  33. file_path: The local path to the PDF file to process.
  34. timeout: Request timeout in seconds.
  35. max_retries: Maximum number of retry attempts.
  36. enable_debug_logging: Enable detailed debug logs.
  37. """
  38. if not api_key:
  39. raise ValueError("API key cannot be empty.")
  40. if not os.path.exists(file_path):
  41. raise FileNotFoundError(f"File not found at {file_path}")
  42. self.api_key = api_key
  43. self.file_path = file_path
  44. self.timeout = timeout
  45. self.max_retries = max_retries
  46. self.debug = enable_debug_logging
  47. # Pre-compute file info for performance
  48. self.file_name = os.path.basename(file_path)
  49. self.file_size = os.path.getsize(file_path)
  50. self.headers = {
  51. "Authorization": f"Bearer {self.api_key}",
  52. "User-Agent": "OpenWebUI-MistralLoader/2.0"
  53. }
  54. def _debug_log(self, message: str, *args) -> None:
  55. """Conditional debug logging for performance."""
  56. if self.debug:
  57. log.debug(message, *args)
  58. def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
  59. """Checks response status and returns JSON content."""
  60. try:
  61. response.raise_for_status() # Raises HTTPError for bad responses (4xx or 5xx)
  62. # Handle potential empty responses for certain successful requests (e.g., DELETE)
  63. if response.status_code == 204 or not response.content:
  64. return {} # Return empty dict if no content
  65. return response.json()
  66. except requests.exceptions.HTTPError as http_err:
  67. log.error(f"HTTP error occurred: {http_err} - Response: {response.text}")
  68. raise
  69. except requests.exceptions.RequestException as req_err:
  70. log.error(f"Request exception occurred: {req_err}")
  71. raise
  72. except ValueError as json_err: # Includes JSONDecodeError
  73. log.error(f"JSON decode error: {json_err} - Response: {response.text}")
  74. raise # Re-raise after logging
  75. async def _handle_response_async(self, response: aiohttp.ClientResponse) -> Dict[str, Any]:
  76. """Async version of response handling with better error info."""
  77. try:
  78. response.raise_for_status()
  79. # Check content type
  80. content_type = response.headers.get('content-type', '')
  81. if 'application/json' not in content_type:
  82. if response.status == 204:
  83. return {}
  84. text = await response.text()
  85. raise ValueError(f"Unexpected content type: {content_type}, body: {text[:200]}...")
  86. return await response.json()
  87. except aiohttp.ClientResponseError as e:
  88. error_text = await response.text() if response else "No response"
  89. log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
  90. raise
  91. except aiohttp.ClientError as e:
  92. log.error(f"Client error: {e}")
  93. raise
  94. except Exception as e:
  95. log.error(f"Unexpected error processing response: {e}")
  96. raise
  97. def _retry_request_sync(self, request_func, *args, **kwargs):
  98. """Synchronous retry logic with exponential backoff."""
  99. for attempt in range(self.max_retries):
  100. try:
  101. return request_func(*args, **kwargs)
  102. except (requests.exceptions.RequestException, Exception) as e:
  103. if attempt == self.max_retries - 1:
  104. raise
  105. wait_time = (2 ** attempt) + 0.5
  106. log.warning(f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...")
  107. time.sleep(wait_time)
  108. async def _retry_request_async(self, request_func, *args, **kwargs):
  109. """Async retry logic with exponential backoff."""
  110. for attempt in range(self.max_retries):
  111. try:
  112. return await request_func(*args, **kwargs)
  113. except (aiohttp.ClientError, asyncio.TimeoutError) as e:
  114. if attempt == self.max_retries - 1:
  115. raise
  116. wait_time = (2 ** attempt) + 0.5
  117. log.warning(f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...")
  118. await asyncio.sleep(wait_time)
  119. def _upload_file(self) -> str:
  120. """Uploads the file to Mistral for OCR processing (sync version)."""
  121. log.info("Uploading file to Mistral API")
  122. url = f"{self.BASE_API_URL}/files"
  123. file_name = os.path.basename(self.file_path)
  124. def upload_request():
  125. with open(self.file_path, "rb") as f:
  126. files = {"file": (file_name, f, "application/pdf")}
  127. data = {"purpose": "ocr"}
  128. response = requests.post(
  129. url,
  130. headers=self.headers,
  131. files=files,
  132. data=data,
  133. timeout=self.timeout
  134. )
  135. return self._handle_response(response)
  136. try:
  137. response_data = self._retry_request_sync(upload_request)
  138. file_id = response_data.get("id")
  139. if not file_id:
  140. raise ValueError("File ID not found in upload response.")
  141. log.info(f"File uploaded successfully. File ID: {file_id}")
  142. return file_id
  143. except Exception as e:
  144. log.error(f"Failed to upload file: {e}")
  145. raise
  146. async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
  147. """Async file upload with streaming for better memory efficiency."""
  148. url = f"{self.BASE_API_URL}/files"
  149. async def upload_request():
  150. # Create multipart writer for streaming upload
  151. writer = aiohttp.MultipartWriter('form-data')
  152. # Add purpose field
  153. purpose_part = writer.append('ocr')
  154. purpose_part.set_content_disposition('form-data', name='purpose')
  155. # Add file part with streaming
  156. file_part = writer.append_payload(aiohttp.streams.FilePayload(
  157. self.file_path,
  158. filename=self.file_name,
  159. content_type='application/pdf'
  160. ))
  161. file_part.set_content_disposition('form-data', name='file', filename=self.file_name)
  162. self._debug_log(f"Uploading file: {self.file_name} ({self.file_size:,} bytes)")
  163. async with session.post(
  164. url,
  165. data=writer,
  166. headers=self.headers,
  167. timeout=aiohttp.ClientTimeout(total=self.timeout)
  168. ) as response:
  169. return await self._handle_response_async(response)
  170. response_data = await self._retry_request_async(upload_request)
  171. file_id = response_data.get("id")
  172. if not file_id:
  173. raise ValueError("File ID not found in upload response.")
  174. log.info(f"File uploaded successfully. File ID: {file_id}")
  175. return file_id
  176. def _get_signed_url(self, file_id: str) -> str:
  177. """Retrieves a temporary signed URL for the uploaded file (sync version)."""
  178. log.info(f"Getting signed URL for file ID: {file_id}")
  179. url = f"{self.BASE_API_URL}/files/{file_id}/url"
  180. params = {"expiry": 1}
  181. signed_url_headers = {**self.headers, "Accept": "application/json"}
  182. def url_request():
  183. response = requests.get(
  184. url,
  185. headers=signed_url_headers,
  186. params=params,
  187. timeout=self.timeout
  188. )
  189. return self._handle_response(response)
  190. try:
  191. response_data = self._retry_request_sync(url_request)
  192. signed_url = response_data.get("url")
  193. if not signed_url:
  194. raise ValueError("Signed URL not found in response.")
  195. log.info("Signed URL received.")
  196. return signed_url
  197. except Exception as e:
  198. log.error(f"Failed to get signed URL: {e}")
  199. raise
  200. async def _get_signed_url_async(self, session: aiohttp.ClientSession, file_id: str) -> str:
  201. """Async signed URL retrieval."""
  202. url = f"{self.BASE_API_URL}/files/{file_id}/url"
  203. params = {"expiry": 1}
  204. headers = {
  205. **self.headers,
  206. "Accept": "application/json"
  207. }
  208. async def url_request():
  209. self._debug_log(f"Getting signed URL for file ID: {file_id}")
  210. async with session.get(
  211. url,
  212. headers=headers,
  213. params=params,
  214. timeout=aiohttp.ClientTimeout(total=self.timeout)
  215. ) as response:
  216. return await self._handle_response_async(response)
  217. response_data = await self._retry_request_async(url_request)
  218. signed_url = response_data.get("url")
  219. if not signed_url:
  220. raise ValueError("Signed URL not found in response.")
  221. self._debug_log("Signed URL received successfully")
  222. return signed_url
  223. def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
  224. """Sends the signed URL to the OCR endpoint for processing (sync version)."""
  225. log.info("Processing OCR via Mistral API")
  226. url = f"{self.BASE_API_URL}/ocr"
  227. ocr_headers = {
  228. **self.headers,
  229. "Content-Type": "application/json",
  230. "Accept": "application/json",
  231. }
  232. payload = {
  233. "model": "mistral-ocr-latest",
  234. "document": {
  235. "type": "document_url",
  236. "document_url": signed_url,
  237. },
  238. "include_image_base64": False,
  239. }
  240. def ocr_request():
  241. response = requests.post(
  242. url,
  243. headers=ocr_headers,
  244. json=payload,
  245. timeout=self.timeout
  246. )
  247. return self._handle_response(response)
  248. try:
  249. ocr_response = self._retry_request_sync(ocr_request)
  250. log.info("OCR processing done.")
  251. self._debug_log("OCR response: %s", ocr_response)
  252. return ocr_response
  253. except Exception as e:
  254. log.error(f"Failed during OCR processing: {e}")
  255. raise
  256. async def _process_ocr_async(self, session: aiohttp.ClientSession, signed_url: str) -> Dict[str, Any]:
  257. """Async OCR processing with timing metrics."""
  258. url = f"{self.BASE_API_URL}/ocr"
  259. headers = {
  260. **self.headers,
  261. "Content-Type": "application/json",
  262. "Accept": "application/json",
  263. }
  264. payload = {
  265. "model": "mistral-ocr-latest",
  266. "document": {
  267. "type": "document_url",
  268. "document_url": signed_url,
  269. },
  270. "include_image_base64": False,
  271. }
  272. async def ocr_request():
  273. log.info("Starting OCR processing via Mistral API")
  274. start_time = time.time()
  275. async with session.post(
  276. url,
  277. json=payload,
  278. headers=headers,
  279. timeout=aiohttp.ClientTimeout(total=self.timeout)
  280. ) as response:
  281. ocr_response = await self._handle_response_async(response)
  282. processing_time = time.time() - start_time
  283. log.info(f"OCR processing completed in {processing_time:.2f}s")
  284. return ocr_response
  285. return await self._retry_request_async(ocr_request)
  286. def _delete_file(self, file_id: str) -> None:
  287. """Deletes the file from Mistral storage (sync version)."""
  288. log.info(f"Deleting uploaded file ID: {file_id}")
  289. url = f"{self.BASE_API_URL}/files/{file_id}"
  290. try:
  291. response = requests.delete(url, headers=self.headers, timeout=30)
  292. delete_response = self._handle_response(response)
  293. log.info(f"File deleted successfully: {delete_response}")
  294. except Exception as e:
  295. # Log error but don't necessarily halt execution if deletion fails
  296. log.error(f"Failed to delete file ID {file_id}: {e}")
  297. async def _delete_file_async(self, session: aiohttp.ClientSession, file_id: str) -> None:
  298. """Async file deletion with error tolerance."""
  299. try:
  300. async def delete_request():
  301. self._debug_log(f"Deleting file ID: {file_id}")
  302. async with session.delete(
  303. url=f"{self.BASE_API_URL}/files/{file_id}",
  304. headers=self.headers,
  305. timeout=aiohttp.ClientTimeout(total=30) # Shorter timeout for cleanup
  306. ) as response:
  307. return await self._handle_response_async(response)
  308. await self._retry_request_async(delete_request)
  309. self._debug_log(f"File {file_id} deleted successfully")
  310. except Exception as e:
  311. # Don't fail the entire process if cleanup fails
  312. log.warning(f"Failed to delete file ID {file_id}: {e}")
  313. @asynccontextmanager
  314. async def _get_session(self):
  315. """Context manager for HTTP session with optimized settings."""
  316. connector = aiohttp.TCPConnector(
  317. limit=10, # Total connection limit
  318. limit_per_host=5, # Per-host connection limit
  319. ttl_dns_cache=300, # DNS cache TTL
  320. use_dns_cache=True,
  321. keepalive_timeout=30,
  322. enable_cleanup_closed=True
  323. )
  324. async with aiohttp.ClientSession(
  325. connector=connector,
  326. timeout=aiohttp.ClientTimeout(total=self.timeout),
  327. headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"}
  328. ) as session:
  329. yield session
  330. def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
  331. """Process OCR results into Document objects with enhanced metadata."""
  332. pages_data = ocr_response.get("pages")
  333. if not pages_data:
  334. log.warning("No pages found in OCR response.")
  335. return [Document(page_content="No text content found", metadata={"error": "no_pages"})]
  336. documents = []
  337. total_pages = len(pages_data)
  338. skipped_pages = 0
  339. for page_data in pages_data:
  340. page_content = page_data.get("markdown")
  341. page_index = page_data.get("index") # API uses 0-based index
  342. if page_content is not None and page_index is not None:
  343. # Clean up content efficiently
  344. cleaned_content = page_content.strip() if isinstance(page_content, str) else str(page_content)
  345. if cleaned_content: # Only add non-empty pages
  346. documents.append(
  347. Document(
  348. page_content=cleaned_content,
  349. metadata={
  350. "page": page_index, # 0-based index from API
  351. "page_label": page_index + 1, # 1-based label for convenience
  352. "total_pages": total_pages,
  353. "file_name": self.file_name,
  354. "file_size": self.file_size,
  355. "processing_engine": "mistral-ocr"
  356. },
  357. )
  358. )
  359. else:
  360. skipped_pages += 1
  361. self._debug_log(f"Skipping empty page {page_index}")
  362. else:
  363. skipped_pages += 1
  364. self._debug_log(f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}")
  365. if skipped_pages > 0:
  366. log.info(f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages")
  367. if not documents:
  368. # Case where pages existed but none had valid markdown/index
  369. log.warning("OCR response contained pages, but none had valid content/index.")
  370. return [
  371. Document(
  372. page_content="No valid text content found in document",
  373. metadata={"error": "no_valid_pages", "total_pages": total_pages}
  374. )
  375. ]
  376. return documents
  377. def load(self) -> List[Document]:
  378. """
  379. Executes the full OCR workflow: upload, get URL, process OCR, delete file.
  380. Synchronous version for backward compatibility.
  381. Returns:
  382. A list of Document objects, one for each page processed.
  383. """
  384. file_id = None
  385. start_time = time.time()
  386. try:
  387. # 1. Upload file
  388. file_id = self._upload_file()
  389. # 2. Get Signed URL
  390. signed_url = self._get_signed_url(file_id)
  391. # 3. Process OCR
  392. ocr_response = self._process_ocr(signed_url)
  393. # 4. Process results
  394. documents = self._process_results(ocr_response)
  395. total_time = time.time() - start_time
  396. log.info(f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents")
  397. return documents
  398. except Exception as e:
  399. total_time = time.time() - start_time
  400. log.error(f"An error occurred during the loading process after {total_time:.2f}s: {e}")
  401. # Return an error document on failure
  402. return [Document(
  403. page_content=f"Error during processing: {e}",
  404. metadata={"error": "processing_failed", "file_name": self.file_name}
  405. )]
  406. finally:
  407. # 5. Delete file (attempt even if prior steps failed after upload)
  408. if file_id:
  409. try:
  410. self._delete_file(file_id)
  411. except Exception as del_e:
  412. # Log deletion error, but don't overwrite original error if one occurred
  413. log.error(f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}")
  414. async def load_async(self) -> List[Document]:
  415. """
  416. Asynchronous OCR workflow execution with optimized performance.
  417. Returns:
  418. A list of Document objects, one for each page processed.
  419. """
  420. file_id = None
  421. start_time = time.time()
  422. try:
  423. async with self._get_session() as session:
  424. # 1. Upload file with streaming
  425. file_id = await self._upload_file_async(session)
  426. # 2. Get signed URL
  427. signed_url = await self._get_signed_url_async(session, file_id)
  428. # 3. Process OCR
  429. ocr_response = await self._process_ocr_async(session, signed_url)
  430. # 4. Process results
  431. documents = self._process_results(ocr_response)
  432. total_time = time.time() - start_time
  433. log.info(f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents")
  434. return documents
  435. except Exception as e:
  436. total_time = time.time() - start_time
  437. log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
  438. return [Document(
  439. page_content=f"Error during OCR processing: {e}",
  440. metadata={"error": "processing_failed", "file_name": self.file_name}
  441. )]
  442. finally:
  443. # 5. Cleanup - always attempt file deletion
  444. if file_id:
  445. try:
  446. async with self._get_session() as session:
  447. await self._delete_file_async(session, file_id)
  448. except Exception as cleanup_error:
  449. log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")
  450. @staticmethod
  451. async def load_multiple_async(loaders: List['MistralLoader']) -> List[List[Document]]:
  452. """
  453. Process multiple files concurrently for maximum performance.
  454. Args:
  455. loaders: List of MistralLoader instances
  456. Returns:
  457. List of document lists, one for each loader
  458. """
  459. if not loaders:
  460. return []
  461. log.info(f"Starting concurrent processing of {len(loaders)} files")
  462. start_time = time.time()
  463. # Process all files concurrently
  464. tasks = [loader.load_async() for loader in loaders]
  465. results = await asyncio.gather(*tasks, return_exceptions=True)
  466. # Handle any exceptions in results
  467. processed_results = []
  468. for i, result in enumerate(results):
  469. if isinstance(result, Exception):
  470. log.error(f"File {i} failed: {result}")
  471. processed_results.append([Document(
  472. page_content=f"Error processing file: {result}",
  473. metadata={"error": "batch_processing_failed", "file_index": i}
  474. )])
  475. else:
  476. processed_results.append(result)
  477. total_time = time.time() - start_time
  478. total_docs = sum(len(docs) for docs in processed_results)
  479. log.info(f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents")
  480. return processed_results