Timothy Jaeryang Baek 3 luni în urmă
părinte
comite
5c1ba23026
1 a modificat fișierele cu 33 adăugiri și 24 ștergeri
  1. 33 24
      backend/open_webui/routers/images.py

+ 33 - 24
backend/open_webui/routers/images.py

@@ -737,7 +737,6 @@ async def image_edits(
     form_data: EditImageForm,
     user=Depends(get_verified_user),
 ):
-
     size = None
     width, height = None, None
     if (
@@ -757,29 +756,39 @@ async def image_edits(
         else form_data.model
     )
 
-    def load_url_image(string):
-        if string.startswith("http://") or string.startswith("https://"):
-            r = requests.get(string)
-            r.raise_for_status()
-            image_data = base64.b64encode(r.content).decode("utf-8")
-            return f"data:{r.headers['content-type']};base64,{image_data}"
-
-        elif string.startswith("/api/v1/files"):
-            file_id = string.split("/api/v1/files/")[1].split("/content")[0]
-            file_response = get_file_content_by_id(file_id, user)
-
-            if isinstance(file_response, FileResponse):
-                file_bytes = file_response.body
-                mime_type = file_response.headers.get("content-type", "image/png")
-                image_data = base64.b64encode(file_bytes).decode("utf-8")
-                return f"data:{mime_type};base64,{image_data}"
-        return string
-
-    # Load image(s) from URL(s) if necessary
-    if isinstance(form_data.image, str):
-        form_data.image = load_url_image(form_data.image)
-    elif isinstance(form_data.image, list):
-        form_data.image = [load_url_image(img) for img in form_data.image]
+    try:
+
+        async def load_url_image(data):
+            if data.startswith("http://") or data.startswith("https://"):
+                r = await asyncio.to_thread(requests.get, data)
+                r.raise_for_status()
+
+                image_data = base64.b64encode(r.content).decode("utf-8")
+                return f"data:{r.headers['content-type']};base64,{image_data}"
+
+            elif data.startswith("/api/v1/files"):
+                file_id = data.split("/api/v1/files/")[1].split("/content")[0]
+                file_response = await get_file_content_by_id(file_id, user)
+
+                if isinstance(file_response, FileResponse):
+                    file_path = file_response.path
+
+                    with open(file_path, "rb") as f:
+                        file_bytes = f.read()
+                        image_data = base64.b64encode(file_bytes).decode("utf-8")
+                        mime_type, _ = mimetypes.guess_type(file_path)
+
+                    return f"data:{mime_type};base64,{image_data}"
+
+            return data
+
+        # Load image(s) from URL(s) if necessary
+        if isinstance(form_data.image, str):
+            form_data.image = await load_url_image(form_data.image)
+        elif isinstance(form_data.image, list):
+            form_data.image = [await load_url_image(img) for img in form_data.image]
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
 
     r = None
     try: