Browse Source

fix: image generation

Timothy Jaeryang Baek 1 tháng trước cách đây
mục cha
commit
72b25ab78b
1 tập tin đã thay đổi với 25 bổ sung5 xóa
  1. 25 5
      backend/open_webui/routers/images.py

+ 25 - 5
backend/open_webui/routers/images.py

@@ -10,7 +10,15 @@ from typing import Optional
 
 from urllib.parse import quote
 import requests
-from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
+from fastapi import (
+    APIRouter,
+    Depends,
+    HTTPException,
+    Request,
+    UploadFile,
+    BackgroundTasks,
+)
+
 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
@@ -460,7 +468,7 @@ def load_url_image_data(url, headers=None):
         return None
 
 
-def upload_image(request, image_data, content_type, metadata, user):
+def upload_image(request, background_tasks, image_data, content_type, metadata, user):
     image_format = mimetypes.guess_extension(content_type)
     file = UploadFile(
         file=io.BytesIO(image_data),
@@ -470,7 +478,12 @@ def upload_image(request, image_data, content_type, metadata, user):
         },
     )
     file_item = upload_file(
-        request, file=file, metadata=metadata, process=False, user=user
+        request,
+        background_tasks,
+        file=file,
+        metadata=metadata,
+        process=False,
+        user=user,
     )
     url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
     return url
@@ -479,6 +492,7 @@ def upload_image(request, image_data, content_type, metadata, user):
 @router.post("/generations")
 async def image_generations(
     request: Request,
+    background_tasks: BackgroundTasks,
     form_data: GenerateImageForm,
     user=Depends(get_verified_user),
 ):
@@ -552,7 +566,9 @@ async def image_generations(
                 else:
                     image_data, content_type = load_b64_image_data(image["b64_json"])
 
-                url = upload_image(request, image_data, content_type, data, user)
+                url = upload_image(
+                    request, background_tasks, image_data, content_type, data, user
+                )
                 images.append({"url": url})
             return images
 
@@ -586,7 +602,9 @@ async def image_generations(
                 image_data, content_type = load_b64_image_data(
                     image["bytesBase64Encoded"]
                 )
-                url = upload_image(request, image_data, content_type, data, user)
+                url = upload_image(
+                    request, background_tasks, image_data, content_type, data, user
+                )
                 images.append({"url": url})
 
             return images
@@ -637,6 +655,7 @@ async def image_generations(
                 image_data, content_type = load_url_image_data(image["url"], headers)
                 url = upload_image(
                     request,
+                    background_tasks,
                     image_data,
                     content_type,
                     form_data.model_dump(exclude_none=True),
@@ -690,6 +709,7 @@ async def image_generations(
                 image_data, content_type = load_b64_image_data(image)
                 url = upload_image(
                     request,
+                    background_tasks,
                     image_data,
                     content_type,
                     {**data, "info": res["info"]},