1
0
Timothy Jaeryang Baek 1 сар өмнө
parent
commit
5e1f4fa0ff

+ 130 - 53
backend/open_webui/routers/files.py

@@ -6,8 +6,10 @@ from fnmatch import fnmatch
 from pathlib import Path
 from typing import Optional
 from urllib.parse import quote
+import asyncio
 
 from fastapi import (
+    BackgroundTasks,
     APIRouter,
     Depends,
     File,
@@ -18,6 +20,7 @@ from fastapi import (
     status,
     Query,
 )
+
 from fastapi.responses import FileResponse, StreamingResponse
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import SRC_LOG_LEVELS
@@ -42,7 +45,6 @@ from pydantic import BaseModel
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
-
 router = APIRouter()
 
 
@@ -83,13 +85,64 @@ def has_access_to_file(
 ############################
 
 
+def process_uploaded_file(request, file, file_item, file_metadata, user):
+    """
+    Process an uploaded file and record the outcome on the file record.
+
+    Depending on the upload's content type the file is either transcribed
+    and then processed with the transcript as content, processed directly,
+    or left unprocessed (image/video types unless the content extraction
+    engine is "external"). On success the file data is updated with
+    ``{"status": "completed"}``; on any exception it is updated with
+    ``{"status": "failed", "error": ...}`` instead of re-raising.
+
+    :param request: Request whose ``app.state.config`` supplies the settings.
+    :param file: The uploaded file; only ``content_type`` is read here.
+    :param file_item: The stored file record (``id`` and ``path`` are used).
+    :param file_metadata: Metadata forwarded to ``transcribe``.
+    :param user: The uploading user, forwarded to ``process_file``.
+    """
+    try:
+        if file.content_type:
+            stt_supported_content_types = getattr(
+                request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
+            )
+
+            # Fall back to the default patterns when the configured list is
+            # empty or contains only blank entries.
+            if any(
+                fnmatch(file.content_type, content_type)
+                for content_type in (
+                    stt_supported_content_types
+                    if stt_supported_content_types
+                    and any(t.strip() for t in stt_supported_content_types)
+                    else ["audio/*", "video/webm"]
+                )
+            ):
+                # The storage path is taken from the file record: this
+                # function has no `file_path` argument, so reading a local of
+                # that name before assignment raised UnboundLocalError.
+                file_path = Storage.get_file(file_item.path)
+                result = transcribe(request, file_path, file_metadata)
+
+                process_file(
+                    request,
+                    ProcessFileForm(
+                        file_id=file_item.id, content=result.get("text", "")
+                    ),
+                    user=user,
+                )
+            elif (not file.content_type.startswith(("image/", "video/"))) or (
+                request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
+            ):
+                process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
+        else:
+            log.info(
+                f"File type {file.content_type} is not provided, but trying to process anyway"
+            )
+            process_file(request, ProcessFileForm(file_id=file_item.id), user=user)
+
+        Files.update_file_data_by_id(
+            file_item.id,
+            {"status": "completed"},
+        )
+    except Exception as e:
+        # Nothing upstream observes this exception, so keep the traceback in
+        # the log before persisting the failure on the file record.
+        log.exception(e)
+        log.error(f"Error processing file: {file_item.id}")
+        Files.update_file_data_by_id(
+            file_item.id,
+            {
+                "status": "failed",
+                "error": str(e.detail) if hasattr(e, "detail") else str(e),
+            },
+        )
+
+
 @router.post("/", response_model=FileModelResponse)
 def upload_file(
     request: Request,
+    background_tasks: BackgroundTasks,
     file: UploadFile = File(...),
     metadata: Optional[dict | str] = Form(None),
     process: bool = Query(True),
-    internal: bool = False,
     user=Depends(get_verified_user),
 ):
     log.info(f"file.content_type: {file.content_type}")
@@ -112,7 +165,7 @@ def upload_file(
         # Remove the leading dot from the file extension
         file_extension = file_extension[1:] if file_extension else ""
 
-        if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
+        if process and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
             request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
                 ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
             ]
@@ -147,6 +200,9 @@ def upload_file(
                     "id": id,
                     "filename": name,
                     "path": file_path,
+                    "data": {
+                        **({"status": "pending"} if process else {}),
+                    },
                     "meta": {
                         "name": name,
                         "content_type": file.content_type,
@@ -156,58 +212,25 @@ def upload_file(
                 }
             ),
         )
-        if process:
-            try:
-                if file.content_type:
-                    stt_supported_content_types = getattr(
-                        request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", []
-                    )
 
-                    if any(
-                        fnmatch(file.content_type, content_type)
-                        for content_type in (
-                            stt_supported_content_types
-                            if stt_supported_content_types
-                            and any(t.strip() for t in stt_supported_content_types)
-                            else ["audio/*", "video/webm"]
-                        )
-                    ):
-                        file_path = Storage.get_file(file_path)
-                        result = transcribe(request, file_path, file_metadata)
-
-                        process_file(
-                            request,
-                            ProcessFileForm(file_id=id, content=result.get("text", "")),
-                            user=user,
-                        )
-                    elif (not file.content_type.startswith(("image/", "video/"))) or (
-                        request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
-                    ):
-                        process_file(request, ProcessFileForm(file_id=id), user=user)
-                else:
-                    log.info(
-                        f"File type {file.content_type} is not provided, but trying to process anyway"
-                    )
-                    process_file(request, ProcessFileForm(file_id=id), user=user)
-
-                file_item = Files.get_file_by_id(id=id)
-            except Exception as e:
-                log.exception(e)
-                log.error(f"Error processing file: {file_item.id}")
-                file_item = FileModelResponse(
-                    **{
-                        **file_item.model_dump(),
-                        "error": str(e.detail) if hasattr(e, "detail") else str(e),
-                    }
-                )
-
-        if file_item:
-            return file_item
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
+        if process:
+            background_tasks.add_task(
+                process_uploaded_file,
+                request,
+                file,
+                file_item,
+                file_metadata,
+                user,
             )
+            return {"status": True, **file_item.model_dump()}
+        else:
+            if file_item:
+                return file_item
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
+                )
 
     except Exception as e:
         log.exception(e)
@@ -334,6 +357,60 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
         )
 
 
+@router.get("/{id}/process/status")
+async def get_file_process_status(
+    id: str, stream: bool = Query(False), user=Depends(get_verified_user)
+):
+    """
+    Report the processing status of a file, once or as an event stream.
+
+    With ``stream`` false, returns ``{"status": ...}`` read from the file's
+    data, defaulting to "pending" when no status is stored. With ``stream``
+    true, returns a ``text/event-stream`` response that re-reads the file
+    every 0.5 seconds and emits ``data: {...}`` events until the status is
+    "completed" or "failed", the file has no status, or the file is gone.
+
+    Raises a 404 ``HTTPException`` when the file does not exist or the user
+    is neither its owner, an admin, nor granted read access.
+    """
+    file = Files.get_file_by_id(id)
+
+    if not file:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if (
+        file.user_id == user.id
+        or user.role == "admin"
+        or has_access_to_file(id, "read", user)
+    ):
+        if stream:
+            # Used as the maximum number of polling iterations; with the
+            # 0.5 s sleep below that is about one hour of polling.
+            # NOTE(review): the name suggests seconds - confirm the limit.
+            MAX_FILE_PROCESSING_DURATION = 3600 * 2
+
+            async def event_stream(file_id):
+                for _ in range(MAX_FILE_PROCESSING_DURATION):
+                    file_item = Files.get_file_by_id(file_id)
+                    if not file_item:
+                        # The file disappeared while polling. End the stream
+                        # instead of dereferencing None on the next pass.
+                        break
+
+                    # Tolerate a stored `data` value of None.
+                    data = file_item.model_dump().get("data") or {}
+                    process_status = data.get("status")
+
+                    if not process_status:
+                        # Legacy
+                        break
+
+                    event = {"status": process_status}
+                    if process_status == "failed":
+                        event["error"] = data.get("error")
+
+                    yield f"data: {json.dumps(event)}\n\n"
+                    if process_status in ("completed", "failed"):
+                        break
+
+                    await asyncio.sleep(0.5)
+
+            return StreamingResponse(
+                event_stream(file.id),
+                media_type="text/event-stream",
+            )
+        else:
+            return {"status": (file.data or {}).get("status", "pending")}
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
 ############################
 # Get File Data Content By Id
 ############################

+ 3 - 1
backend/open_webui/routers/images.py

@@ -469,7 +469,9 @@ def upload_image(request, image_data, content_type, metadata, user):
             "content-type": content_type,
         },
     )
-    file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
+    file_item = upload_file(
+        request, file=file, metadata=metadata, process=False, user=user
+    )
     url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
     return url
 

+ 1 - 1
backend/open_webui/routers/retrieval.py

@@ -1476,7 +1476,7 @@ def process_file(
         log.debug(f"text_content: {text_content}")
         Files.update_file_data_by_id(
             file.id,
-            {"content": text_content},
+            {"status": "completed", "content": text_content},
         )
 
         hash = calculate_sha256_string(text_content)

+ 70 - 0
src/lib/apis/files/index.ts

@@ -1,4 +1,5 @@
 import { WEBUI_API_BASE_URL } from '$lib/constants';
+import { splitStream } from '$lib/utils';
 
 export const uploadFile = async (token: string, file: File, metadata?: object | null) => {
 	const data = new FormData();
@@ -31,6 +32,75 @@ export const uploadFile = async (token: string, file: File, metadata?: object |
 		throw error;
 	}
 
+	if (res) {
+		const status = await getFileProcessStatus(token, res.id);
+
+		if (status && status.ok) {
+			const reader = status.body
+				.pipeThrough(new TextDecoderStream())
+				.pipeThrough(splitStream('\n'))
+				.getReader();
+
+			while (true) {
+				const { value, done } = await reader.read();
+				if (done) {
+					break;
+				}
+
+				try {
+					let lines = value.split('\n');
+
+					for (const line of lines) {
+						if (line !== '') {
+							console.log(line);
+							if (line === 'data: [DONE]') {
+								console.log(line);
+							} else {
+								let data = JSON.parse(line.replace(/^data: /, ''));
+								console.log(data);
+
+								if (data?.error) {
+									console.error(data.error);
+									res.error = data.error;
+								}
+							}
+						}
+					}
+				} catch (error) {
+					console.log(error);
+				}
+			}
+		}
+	}
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+/**
+ * Request the processing status of an uploaded file as a stream.
+ *
+ * Sends a GET to `/files/{id}/process/status?stream=true` under
+ * WEBUI_API_BASE_URL with the bearer token and returns the raw fetch
+ * `Response` without reading its body, so the caller consumes the stream.
+ *
+ * If `fetch` rejects, the rejection is logged and `null` is returned. An
+ * error is thrown only when the rejection carries a truthy `detail`.
+ *
+ * NOTE(review): the request sends `Accept: application/json`, while the
+ * streaming endpoint answers with `text/event-stream` - confirm intent.
+ */
+export const getFileProcessStatus = async (token: string, id: string) => {
+	const queryParams = new URLSearchParams();
+	queryParams.append('stream', 'true');
+
+	let error = null;
+	const res = await fetch(`${WEBUI_API_BASE_URL}/files/${id}/process/status?${queryParams}`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	}).catch((err) => {
+		error = err.detail;
+		console.error(err);
+		return null;
+	});
+
+	if (error) {
+		throw error;
+	}
+
 	return res;
 };
 

+ 1 - 0
src/lib/components/chat/Chat.svelte

@@ -90,6 +90,7 @@
 	import { fade } from 'svelte/transition';
 	import Tooltip from '../common/Tooltip.svelte';
 	import Sidebar from '../icons/Sidebar.svelte';
+	import { uploadFile } from '$lib/apis/files';
 
 	export let chatIdProp = '';
 

+ 6 - 0
src/lib/components/workspace/Knowledge/KnowledgeBase.svelte

@@ -182,6 +182,12 @@
 
 			if (uploadedFile) {
 				console.log(uploadedFile);
+
+				if (uploadedFile.error) {
+					console.warn('File upload warning:', uploadedFile.error);
+					toast.warning(uploadedFile.error);
+				}
+
 				knowledge.files = knowledge.files.map((item) => {
 					if (item.itemId === tempItemId) {
 						item.id = uploadedFile.id;