Timothy Jaeryang Baek 1 hónapja
szülő
commit
37a3de0703
2 módosított fájl, 37 hozzáadás és 24 törlés
  1. 32 10
      backend/open_webui/routers/files.py
  2. 5 14
      backend/open_webui/routers/images.py

+ 32 - 10
backend/open_webui/routers/files.py

@@ -144,6 +144,17 @@ def upload_file(
     metadata: Optional[dict | str] = Form(None),
     process: bool = Query(True),
     user=Depends(get_verified_user),
+):
+    return upload_file_handler(request, file, metadata, process, user, background_tasks)
+
+
+def upload_file_handler(
+    request: Request,
+    file: UploadFile = File(...),
+    metadata: Optional[dict | str] = Form(None),
+    process: bool = Query(True),
+    user=Depends(get_verified_user),
+    background_tasks: Optional[BackgroundTasks] = None,
 ):
     log.info(f"file.content_type: {file.content_type}")
 
@@ -214,16 +225,27 @@ def upload_file(
         )
 
         if process:
-            background_tasks.add_task(
-                process_uploaded_file,
-                request,
-                file,
-                file_path,
-                file_item,
-                file_metadata,
-                user,
-            )
-            return {"status": True, **file_item.model_dump()}
+            if background_tasks:
+                background_tasks.add_task(
+                    process_uploaded_file,
+                    request,
+                    file,
+                    file_path,
+                    file_item,
+                    file_metadata,
+                    user,
+                )
+                return {"status": True, **file_item.model_dump()}
+            else:
+                process_uploaded_file(
+                    request,
+                    file,
+                    file_path,
+                    file_item,
+                    file_metadata,
+                    user,
+                )
+                return {"status": True, **file_item.model_dump()}
         else:
             if file_item:
                 return file_item

+ 5 - 14
backend/open_webui/routers/images.py

@@ -16,13 +16,12 @@ from fastapi import (
     HTTPException,
     Request,
     UploadFile,
-    BackgroundTasks,
 )
 
 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
-from open_webui.routers.files import upload_file
+from open_webui.routers.files import upload_file_handler
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.images.comfyui import (
     ComfyUIGenerateImageForm,
@@ -468,7 +467,7 @@ def load_url_image_data(url, headers=None):
         return None
 
 
-def upload_image(request, background_tasks, image_data, content_type, metadata, user):
+def upload_image(request, image_data, content_type, metadata, user):
     image_format = mimetypes.guess_extension(content_type)
     file = UploadFile(
         file=io.BytesIO(image_data),
@@ -477,9 +476,8 @@ def upload_image(request, background_tasks, image_data, content_type, metadata,
             "content-type": content_type,
         },
     )
-    file_item = upload_file(
+    file_item = upload_file_handler(
         request,
-        background_tasks,
         file=file,
         metadata=metadata,
         process=False,
@@ -492,7 +490,6 @@ def upload_image(request, background_tasks, image_data, content_type, metadata,
 @router.post("/generations")
 async def image_generations(
     request: Request,
-    background_tasks: BackgroundTasks,
     form_data: GenerateImageForm,
     user=Depends(get_verified_user),
 ):
@@ -566,9 +563,7 @@ async def image_generations(
                 else:
                     image_data, content_type = load_b64_image_data(image["b64_json"])
 
-                url = upload_image(
-                    request, background_tasks, image_data, content_type, data, user
-                )
+                url = upload_image(request, image_data, content_type, data, user)
                 images.append({"url": url})
             return images
 
@@ -602,9 +597,7 @@ async def image_generations(
                 image_data, content_type = load_b64_image_data(
                     image["bytesBase64Encoded"]
                 )
-                url = upload_image(
-                    request, background_tasks, image_data, content_type, data, user
-                )
+                url = upload_image(request, image_data, content_type, data, user)
                 images.append({"url": url})
 
             return images
@@ -655,7 +648,6 @@ async def image_generations(
                 image_data, content_type = load_url_image_data(image["url"], headers)
                 url = upload_image(
                     request,
-                    background_tasks,
                     image_data,
                     content_type,
                     form_data.model_dump(exclude_none=True),
@@ -709,7 +701,6 @@ async def image_generations(
                 image_data, content_type = load_b64_image_data(image)
                 url = upload_image(
                     request,
-                    background_tasks,
                     image_data,
                     content_type,
                     {**data, "info": res["info"]},