Explorar o código

feat: image edit support

Timothy Jaeryang Baek hai 3 meses
pai
achega
72f8539fd2

+ 58 - 12
backend/open_webui/config.py

@@ -3074,16 +3074,30 @@ EXTERNAL_WEB_LOADER_API_KEY = PersistentConfig(
 # Images
 # Images
 ####################################
 ####################################
 
 
+ENABLE_IMAGE_GENERATION = PersistentConfig(
+    "ENABLE_IMAGE_GENERATION",
+    "image_generation.enable",
+    os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
+)
+
 IMAGE_GENERATION_ENGINE = PersistentConfig(
 IMAGE_GENERATION_ENGINE = PersistentConfig(
     "IMAGE_GENERATION_ENGINE",
     "IMAGE_GENERATION_ENGINE",
     "image_generation.engine",
     "image_generation.engine",
     os.getenv("IMAGE_GENERATION_ENGINE", "openai"),
     os.getenv("IMAGE_GENERATION_ENGINE", "openai"),
 )
 )
 
 
-ENABLE_IMAGE_GENERATION = PersistentConfig(
-    "ENABLE_IMAGE_GENERATION",
-    "image_generation.enable",
-    os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
+IMAGE_GENERATION_MODEL = PersistentConfig(
+    "IMAGE_GENERATION_MODEL",
+    "image_generation.model",
+    os.getenv("IMAGE_GENERATION_MODEL", ""),
+)
+
+IMAGE_SIZE = PersistentConfig(
+    "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
+)
+
+IMAGE_STEPS = PersistentConfig(
+    "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
 )
 )
 
 
 ENABLE_IMAGE_PROMPT_GENERATION = PersistentConfig(
 ENABLE_IMAGE_PROMPT_GENERATION = PersistentConfig(
@@ -3285,19 +3299,51 @@ IMAGES_GEMINI_ENDPOINT_METHOD = PersistentConfig(
     os.getenv("IMAGES_GEMINI_ENDPOINT_METHOD", ""),
     os.getenv("IMAGES_GEMINI_ENDPOINT_METHOD", ""),
 )
 )
 
 
-IMAGE_SIZE = PersistentConfig(
-    "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
+
+IMAGE_EDIT_ENGINE = PersistentConfig(
+    "IMAGE_EDIT_ENGINE",
+    "images.edit.engine",
+    os.getenv("IMAGE_EDIT_ENGINE", "openai"),
 )
 )
 
 
-IMAGE_STEPS = PersistentConfig(
-    "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
+IMAGE_EDIT_MODEL = PersistentConfig(
+    "IMAGE_EDIT_MODEL",
+    "images.edit.model",
+    os.getenv("IMAGE_EDIT_MODEL", ""),
 )
 )
 
 
-IMAGE_GENERATION_MODEL = PersistentConfig(
-    "IMAGE_GENERATION_MODEL",
-    "image_generation.model",
-    os.getenv("IMAGE_GENERATION_MODEL", ""),
+IMAGE_EDIT_SIZE = PersistentConfig(
+    "IMAGE_EDIT_SIZE", "images.edit.size", os.getenv("IMAGE_EDIT_SIZE", "")
+)
+
+IMAGES_EDIT_OPENAI_API_BASE_URL = PersistentConfig(
+    "IMAGES_EDIT_OPENAI_API_BASE_URL",
+    "images.edit.openai.api_base_url",
+    os.getenv("IMAGES_EDIT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
+)
+IMAGES_EDIT_OPENAI_API_VERSION = PersistentConfig(
+    "IMAGES_EDIT_OPENAI_API_VERSION",
+    "images.edit.openai.api_version",
+    os.getenv("IMAGES_EDIT_OPENAI_API_VERSION", ""),
+)
+
+IMAGES_EDIT_OPENAI_API_KEY = PersistentConfig(
+    "IMAGES_EDIT_OPENAI_API_KEY",
+    "images.edit.openai.api_key",
+    os.getenv("IMAGES_EDIT_OPENAI_API_KEY", OPENAI_API_KEY),
+)
+
+IMAGES_EDIT_GEMINI_API_BASE_URL = PersistentConfig(
+    "IMAGES_EDIT_GEMINI_API_BASE_URL",
+    "images.edit.gemini.api_base_url",
+    os.getenv("IMAGES_EDIT_GEMINI_API_BASE_URL", GEMINI_API_BASE_URL),
 )
 )
+IMAGES_EDIT_GEMINI_API_KEY = PersistentConfig(
+    "IMAGES_EDIT_GEMINI_API_KEY",
+    "images.edit.gemini.api_key",
+    os.getenv("IMAGES_EDIT_GEMINI_API_KEY", GEMINI_API_KEY),
+)
+
 
 
 ####################################
 ####################################
 # Audio
 # Audio

+ 18 - 1
backend/open_webui/main.py

@@ -163,6 +163,14 @@ from open_webui.config import (
     IMAGES_GEMINI_API_BASE_URL,
     IMAGES_GEMINI_API_BASE_URL,
     IMAGES_GEMINI_API_KEY,
     IMAGES_GEMINI_API_KEY,
     IMAGES_GEMINI_ENDPOINT_METHOD,
     IMAGES_GEMINI_ENDPOINT_METHOD,
+    IMAGE_EDIT_ENGINE,
+    IMAGE_EDIT_MODEL,
+    IMAGE_EDIT_SIZE,
+    IMAGES_EDIT_OPENAI_API_BASE_URL,
+    IMAGES_EDIT_OPENAI_API_KEY,
+    IMAGES_EDIT_OPENAI_API_VERSION,
+    IMAGES_EDIT_GEMINI_API_BASE_URL,
+    IMAGES_EDIT_GEMINI_API_KEY,
     # Audio
     # Audio
     AUDIO_STT_ENGINE,
     AUDIO_STT_ENGINE,
     AUDIO_STT_MODEL,
     AUDIO_STT_MODEL,
@@ -1078,7 +1086,6 @@ app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL
 app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
 app.state.config.IMAGES_GEMINI_API_KEY = IMAGES_GEMINI_API_KEY
 app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD = IMAGES_GEMINI_ENDPOINT_METHOD
 app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD = IMAGES_GEMINI_ENDPOINT_METHOD
 
 
-
 app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
 app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
 app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
 app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
 app.state.config.AUTOMATIC1111_PARAMS = AUTOMATIC1111_PARAMS
 app.state.config.AUTOMATIC1111_PARAMS = AUTOMATIC1111_PARAMS
@@ -1089,6 +1096,16 @@ app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
 app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
 app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
 
 
 
 
+app.state.config.IMAGE_EDIT_ENGINE = IMAGE_EDIT_ENGINE
+app.state.config.IMAGE_EDIT_MODEL = IMAGE_EDIT_MODEL
+app.state.config.IMAGE_EDIT_SIZE = IMAGE_EDIT_SIZE
+app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = IMAGES_EDIT_OPENAI_API_BASE_URL
+app.state.config.IMAGES_EDIT_OPENAI_API_KEY = IMAGES_EDIT_OPENAI_API_KEY
+app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = IMAGES_EDIT_OPENAI_API_VERSION
+app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = IMAGES_EDIT_GEMINI_API_BASE_URL
+app.state.config.IMAGES_EDIT_GEMINI_API_KEY = IMAGES_EDIT_GEMINI_API_KEY
+
+
 ########################################
 ########################################
 #
 #
 # AUDIO
 # AUDIO

+ 306 - 8
backend/open_webui/routers/images.py

@@ -1,5 +1,6 @@
 import asyncio
 import asyncio
 import base64
 import base64
+import uuid
 import io
 import io
 import json
 import json
 import logging
 import logging
@@ -10,18 +11,13 @@ from typing import Optional
 
 
 from urllib.parse import quote
 from urllib.parse import quote
 import requests
 import requests
-from fastapi import (
-    APIRouter,
-    Depends,
-    HTTPException,
-    Request,
-    UploadFile,
-)
+from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
+from fastapi.responses import FileResponse
 
 
 from open_webui.config import CACHE_DIR
 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
 from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
-from open_webui.routers.files import upload_file_handler
+from open_webui.routers.files import upload_file_handler, get_file_content_by_id
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.headers import include_user_info_headers
 from open_webui.utils.headers import include_user_info_headers
 from open_webui.utils.images.comfyui import (
 from open_webui.utils.images.comfyui import (
@@ -121,6 +117,16 @@ class ImagesConfig(BaseModel):
     IMAGES_GEMINI_API_KEY: str
     IMAGES_GEMINI_API_KEY: str
     IMAGES_GEMINI_ENDPOINT_METHOD: str
     IMAGES_GEMINI_ENDPOINT_METHOD: str
 
 
+    IMAGE_EDIT_ENGINE: str
+    IMAGE_EDIT_MODEL: str
+    IMAGE_EDIT_SIZE: Optional[str]
+
+    IMAGES_EDIT_OPENAI_API_BASE_URL: str
+    IMAGES_EDIT_OPENAI_API_KEY: str
+    IMAGES_EDIT_OPENAI_API_VERSION: str
+    IMAGES_EDIT_GEMINI_API_BASE_URL: str
+    IMAGES_EDIT_GEMINI_API_KEY: str
+
 
 
 @router.get("/config", response_model=ImagesConfig)
 @router.get("/config", response_model=ImagesConfig)
 async def get_config(request: Request, user=Depends(get_admin_user)):
 async def get_config(request: Request, user=Depends(get_admin_user)):
@@ -144,6 +150,14 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
         "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
         "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
         "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
         "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
         "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
         "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
+        "IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
+        "IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
+        "IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
+        "IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL,
+        "IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY,
+        "IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION,
+        "IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL,
+        "IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY,
     }
     }
 
 
 
 
@@ -152,6 +166,8 @@ async def update_config(
     request: Request, form_data: ImagesConfig, user=Depends(get_admin_user)
     request: Request, form_data: ImagesConfig, user=Depends(get_admin_user)
 ):
 ):
     request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.ENABLE_IMAGE_GENERATION
     request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.ENABLE_IMAGE_GENERATION
+
+    # Create Image
     request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
     request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
         form_data.ENABLE_IMAGE_PROMPT_GENERATION
         form_data.ENABLE_IMAGE_PROMPT_GENERATION
     )
     )
@@ -215,6 +231,28 @@ async def update_config(
         form_data.IMAGES_GEMINI_ENDPOINT_METHOD
         form_data.IMAGES_GEMINI_ENDPOINT_METHOD
     )
     )
 
 
+    # Edit Image
+    request.app.state.config.IMAGE_EDIT_ENGINE = form_data.IMAGE_EDIT_ENGINE
+    request.app.state.config.IMAGE_EDIT_MODEL = form_data.IMAGE_EDIT_MODEL
+    request.app.state.config.IMAGE_EDIT_SIZE = form_data.IMAGE_EDIT_SIZE
+
+    request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = (
+        form_data.IMAGES_OPENAI_API_BASE_URL
+    )
+    request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY = (
+        form_data.IMAGES_OPENAI_API_KEY
+    )
+    request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = (
+        form_data.IMAGES_EDIT_OPENAI_API_VERSION
+    )
+
+    request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = (
+        form_data.IMAGES_EDIT_GEMINI_API_BASE_URL
+    )
+    request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY = (
+        form_data.IMAGES_EDIT_GEMINI_API_KEY
+    )
+
     return {
     return {
         "ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION,
         "ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION,
         "ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
         "ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
@@ -235,6 +273,14 @@ async def update_config(
         "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
         "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
         "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
         "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
         "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
         "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
+        "IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
+        "IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
+        "IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
+        "IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL,
+        "IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY,
+        "IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION,
+        "IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL,
+        "IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY,
     }
     }
 
 
 
 
@@ -674,3 +720,255 @@ async def image_generations(
             if "error" in data:
             if "error" in data:
                 error = data["error"]["message"]
                 error = data["error"]["message"]
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
+
+
# Request body accepted by the POST /edit image-edit endpoint.
class EditImageForm(BaseModel):
    # Source image(s): base64-encoded image(s) or URL(s); either a single
    # string or a list of strings.
    image: str | list[str]
    # Text prompt forwarded to the configured edit engine.
    prompt: str
    # Per-request model; the configured IMAGE_EDIT_MODEL is used when None.
    model: str | None = None
    # Per-request "<width>x<height>" size; takes precedence over the
    # configured IMAGE_EDIT_SIZE when set.
    size: str | None = None
    # Number of images to request.
    n: int | None = None
    # Only read by the ComfyUI branch of the edit endpoint.
    negative_prompt: str | None = None
+
def _parse_image_size(value):
    """Return ``(width, height)`` parsed from a ``"<W>x<H>"`` string, else ``None``."""
    if not value or "x" not in value:
        return None
    try:
        width, height = (int(part) for part in value.split("x"))
    except ValueError:
        # Non-numeric parts or the wrong number of parts (e.g. "auto", "1x2x3").
        return None
    return width, height


def _split_data_url(value, default_mime_type="image/png"):
    """Split an image string into ``(mime_type, base64_payload)``.

    ``data:<mime>;base64,<payload>`` URLs are split on the first comma. Any
    other string is treated as a bare base64 payload and paired with
    ``default_mime_type``.
    """
    if value.startswith("data:") and "," in value:
        header, payload = value.split(",", 1)
        # removeprefix drops the literal "data:" prefix; lstrip("data:") would
        # strip a *character set* and corrupt types such as "application/...".
        mime_type = header.removeprefix("data:").split(";")[0].strip()
        return mime_type or default_mime_type, payload
    return default_mime_type, value


def _read_file_bytes(path):
    """Read and return the full contents of the file at ``path``."""
    with open(path, "rb") as file:
        return file.read()


@router.post("/edit")
async def image_edits(
    request: Request,
    form_data: EditImageForm,
    user=Depends(get_verified_user),
):
    """Edit one or more images with the configured image-edit engine.

    The source image(s) in ``form_data.image`` may be data URLs, bare base64
    strings, ``http(s)`` URLs or ``/api/v1/files/<id>/content`` references;
    remote and file references are resolved to data URLs first. The request is
    then dispatched on ``IMAGE_EDIT_ENGINE`` (``openai``, ``gemini`` or
    ``comfyui``), each produced image is stored through ``upload_image`` and
    the resulting URLs are returned.

    Args:
        request: Incoming request; configuration is read from
            ``request.app.state.config``.
        form_data: Prompt, source image(s) and optional overrides.
        user: The verified user issuing the request.

    Returns:
        A list of ``{"url": ...}`` dicts, one per stored image.

    Raises:
        HTTPException: Status 400 when an image cannot be loaded, the engine is
            unsupported, or the engine call / response handling fails.
    """
    import inspect

    config = request.app.state.config

    # The per-request size wins over the configured one, but only a value that
    # really parses as "<W>x<H>" is used; anything else falls through to the
    # next candidate instead of raising outside the error handler.
    size = None
    width, height = None, None
    for candidate in (form_data.size, config.IMAGE_EDIT_SIZE):
        dimensions = _parse_image_size(candidate)
        if dimensions is not None:
            size = candidate
            width, height = dimensions
            break

    model = config.IMAGE_EDIT_MODEL if form_data.model is None else form_data.model

    async def load_url_image(source):
        """Resolve a URL or file reference to a data URL; pass others through."""
        if source.startswith(("http://", "https://")):
            # NOTE(review): this fetches a caller-supplied URL from the server
            # with no timeout or allow-list; consider restricting it (SSRF).
            response = await asyncio.to_thread(requests.get, source)
            response.raise_for_status()
            encoded = base64.b64encode(response.content).decode("utf-8")
            mime_type = response.headers.get("content-type", "image/png")
            return f"data:{mime_type};base64,{encoded}"

        if source.startswith("/api/v1/files"):
            _, _, remainder = source.partition("/api/v1/files/")
            file_id = remainder.split("/content")[0]
            if not file_id:
                raise ValueError(f"Invalid file reference: {source}")

            file_response = get_file_content_by_id(file_id, user)
            # The handler may be a coroutine function; await it when it is.
            if inspect.isawaitable(file_response):
                file_response = await file_response

            if not isinstance(file_response, FileResponse):
                raise ValueError(f"Unable to load image file: {file_id}")

            # FileResponse streams from disk and exposes ``path`` rather than
            # an in-memory ``body``, so the bytes are read from the file.
            file_bytes = await asyncio.to_thread(_read_file_bytes, file_response.path)
            mime_type = file_response.headers.get("content-type", "image/png")
            encoded = base64.b64encode(file_bytes).decode("utf-8")
            return f"data:{mime_type};base64,{encoded}"

        return source

    # Load image(s) from URL(s) if necessary. A local list is used so the
    # caller's form object is not mutated.
    raw_images = (
        [form_data.image] if isinstance(form_data.image, str) else list(form_data.image)
    )
    try:
        source_images = [await load_url_image(image) for image in raw_images]
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)
        ) from e

    engine = config.IMAGE_EDIT_ENGINE

    r = None
    try:
        if engine == "openai":
            headers = {
                "Authorization": f"Bearer {config.IMAGES_EDIT_OPENAI_API_KEY}",
            }

            if ENABLE_FORWARD_USER_INFO_HEADERS:
                headers = include_user_info_headers(headers, user)

            data = {
                "model": model,
                "prompt": form_data.prompt,
                **({"n": form_data.n} if form_data.n else {}),
                **({"size": size} if size else {}),
                # The check is made against the model actually used for this
                # edit request, not the image *generation* model.
                **(
                    {}
                    if "gpt-image-1" in (model or "")
                    else {"response_format": "b64_json"}
                ),
            }

            files = []
            for image in source_images:
                mime_type, encoded = _split_data_url(image)
                files.append(
                    (
                        "image",
                        (
                            f"{uuid.uuid4()}.png",
                            io.BytesIO(base64.b64decode(encoded)),
                            mime_type,
                        ),
                    )
                )

            url_search_params = ""
            if config.IMAGES_EDIT_OPENAI_API_VERSION:
                url_search_params += (
                    f"?api-version={config.IMAGES_EDIT_OPENAI_API_VERSION}"
                )

            # Use asyncio.to_thread for the requests.post call. The edit base
            # URL is used so it matches the edit API key sent above.
            r = await asyncio.to_thread(
                requests.post,
                url=f"{config.IMAGES_EDIT_OPENAI_API_BASE_URL}/images/edits{url_search_params}",
                headers=headers,
                files=files,
                data=data,
            )

            r.raise_for_status()
            res = r.json()

            images = []
            for image in res["data"]:
                if image_url := image.get("url", None):
                    image_data, content_type = get_image_data(image_url, headers)
                else:
                    image_data, content_type = get_image_data(image["b64_json"])

                url = upload_image(request, image_data, content_type, data, user)
                images.append({"url": url})
            return images

        elif engine == "gemini":
            headers = {
                "Content-Type": "application/json",
                "x-goog-api-key": config.IMAGES_EDIT_GEMINI_API_KEY,
            }

            model = f"{model}:generateContent"

            parts = [{"text": form_data.prompt}]
            for image in source_images:
                mime_type, encoded = _split_data_url(image)
                parts.append(
                    {"inline_data": {"mime_type": mime_type, "data": encoded}}
                )
            data = {"contents": [{"parts": parts}]}

            # Use asyncio.to_thread for the requests.post call
            r = await asyncio.to_thread(
                requests.post,
                url=f"{config.IMAGES_EDIT_GEMINI_API_BASE_URL}/models/{model}",
                json=data,
                headers=headers,
            )

            r.raise_for_status()
            res = r.json()

            images = []
            for candidate in res["candidates"]:
                for part in candidate["content"]["parts"]:
                    if part.get("inlineData", {}).get("data"):
                        image_data, content_type = get_image_data(
                            part["inlineData"]["data"]
                        )
                        url = upload_image(
                            request, image_data, content_type, data, user
                        )
                        images.append({"url": url})

            return images

        elif engine == "comfyui":
            # NOTE(review): the source images are not forwarded to ComfyUI
            # here; this branch runs the configured generation workflow with
            # the edit prompt only.
            data = {
                "prompt": form_data.prompt,
                "width": width,
                "height": height,
                **({"n": form_data.n} if form_data.n else {}),
            }

            # IMAGE_EDIT_STEPS is read here but no visible code registers it,
            # so its absence is tolerated rather than failing the request.
            # NOTE(review): confirm how the config object reports unknown keys.
            try:
                steps = config.IMAGE_EDIT_STEPS
            except (AttributeError, KeyError):
                steps = None
            if steps is not None:
                data["steps"] = steps

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            comfyui_form = ComfyUICreateImageForm(
                **{
                    "workflow": ComfyUIWorkflow(
                        **{
                            "workflow": config.COMFYUI_WORKFLOW,
                            "nodes": config.COMFYUI_WORKFLOW_NODES,
                        }
                    ),
                    **data,
                }
            )
            res = await comfyui_create_image(
                model,
                comfyui_form,
                user.id,
                config.COMFYUI_BASE_URL,
                config.COMFYUI_API_KEY,
            )
            log.debug(f"res: {res}")

            headers = None
            if config.COMFYUI_API_KEY:
                headers = {"Authorization": f"Bearer {config.COMFYUI_API_KEY}"}

            images = []
            for image in res["data"]:
                image_data, content_type = get_image_data(image["url"], headers)
                url = upload_image(
                    request,
                    image_data,
                    content_type,
                    comfyui_form.model_dump(exclude_none=True),
                    user,
                )
                images.append({"url": url})
            return images

        else:
            # Previously an unknown engine fell through and returned None.
            raise ValueError(f"Unsupported image edit engine: {engine}")
    except Exception as e:
        error = e
        if r is not None:
            # Prefer the upstream service's own error message when available.
            payload = r.text
            try:
                payload = json.loads(payload)
                if "error" in payload:
                    error = payload["error"]["message"]
            except Exception:
                error = payload

        raise HTTPException(
            status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)
        ) from e

+ 152 - 65
backend/open_webui/utils/middleware.py

@@ -47,7 +47,8 @@ from open_webui.routers.retrieval import (
 from open_webui.routers.images import (
 from open_webui.routers.images import (
     image_generations,
     image_generations,
     CreateImageForm,
     CreateImageForm,
-    upload_image,
+    image_edits,
+    EditImageForm,
 )
 )
 from open_webui.routers.pipelines import (
 from open_webui.routers.pipelines import (
     process_pipeline_inlet_filter,
     process_pipeline_inlet_filter,
@@ -717,9 +718,31 @@ async def chat_web_search_handler(
     return form_data
     return form_data
 
 
 
 
def get_last_images(message_list):
    """Return the image URLs attached to the most recent message that has any.

    Messages are scanned from newest to oldest; the scan stops at the first
    message whose ``files`` contain at least one entry of type ``"image"``
    with a non-empty ``url``.

    Args:
        message_list: Chat messages as dicts, oldest first. ``None`` or an
            empty list yields an empty result.

    Returns:
        The URLs of that message's image files in their original order, or an
        empty list when no message carries a usable image.
    """
    for message in reversed(message_list or []):
        # ``files`` may be missing or explicitly None; treat both as empty.
        files = message.get("files") or []
        images = [
            file["url"]
            for file in files
            if isinstance(file, dict)
            and file.get("type") == "image"
            and file.get("url")  # skip image entries that carry no URL
        ]
        if images:
            return images

    return []
+
+
 async def chat_image_generation_handler(
 async def chat_image_generation_handler(
     request: Request, form_data: dict, extra_params: dict, user
     request: Request, form_data: dict, extra_params: dict, user
 ):
 ):
+    metadata = extra_params.get("__metadata__", {})
+    chat_id = metadata.get("chat_id", None)
+    if not chat_id:
+        return form_data
+
+    chat = Chats.get_chat_by_id_and_user_id(chat_id, user.id)
+
     __event_emitter__ = extra_params["__event_emitter__"]
     __event_emitter__ = extra_params["__event_emitter__"]
     await __event_emitter__(
     await __event_emitter__(
         {
         {
@@ -728,87 +751,151 @@ async def chat_image_generation_handler(
         }
         }
     )
     )
 
 
-    messages = form_data["messages"]
-    user_message = get_last_user_message(messages)
+    messages_map = chat.chat.get("history", {}).get("messages", {})
+    message_id = chat.chat.get("history", {}).get("currentId")
+    message_list = get_message_list(messages_map, message_id)
+    user_message = get_last_user_message(message_list)
 
 
     prompt = user_message
     prompt = user_message
-    negative_prompt = ""
+    input_images = get_last_images(message_list)
 
 
-    if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
-        try:
-            res = await generate_image_prompt(
-                request,
-                {
-                    "model": form_data["model"],
-                    "messages": messages,
-                },
-                user,
-            )
+    system_message_content = ""
+    if len(input_images) == 0:
+        # Create image(s)
+        if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
+            try:
+                res = await generate_image_prompt(
+                    request,
+                    {
+                        "model": form_data["model"],
+                        "messages": form_data["messages"],
+                    },
+                    user,
+                )
 
 
-            response = res["choices"][0]["message"]["content"]
+                response = res["choices"][0]["message"]["content"]
 
 
-            try:
-                bracket_start = response.find("{")
-                bracket_end = response.rfind("}") + 1
+                try:
+                    bracket_start = response.find("{")
+                    bracket_end = response.rfind("}") + 1
 
 
-                if bracket_start == -1 or bracket_end == -1:
-                    raise Exception("No JSON object found in the response")
+                    if bracket_start == -1 or bracket_end == -1:
+                        raise Exception("No JSON object found in the response")
+
+                    response = response[bracket_start:bracket_end]
+                    response = json.loads(response)
+                    prompt = response.get("prompt", [])
+                except Exception as e:
+                    prompt = user_message
 
 
-                response = response[bracket_start:bracket_end]
-                response = json.loads(response)
-                prompt = response.get("prompt", [])
             except Exception as e:
             except Exception as e:
+                log.exception(e)
                 prompt = user_message
                 prompt = user_message
 
 
+        try:
+            images = await image_generations(
+                request=request,
+                form_data=CreateImageForm(**{"prompt": prompt}),
+                user=user,
+            )
+
+            await __event_emitter__(
+                {
+                    "type": "status",
+                    "data": {"description": "Image created", "done": True},
+                }
+            )
+
+            await __event_emitter__(
+                {
+                    "type": "files",
+                    "data": {
+                        "files": [
+                            {
+                                "type": "image",
+                                "url": image["url"],
+                            }
+                            for image in images
+                        ]
+                    },
+                }
+            )
+
+            system_message_content = "<context>The requested image has been created and is now being shown to the user. Let them know that it has been generated.</context>"
         except Exception as e:
         except Exception as e:
-            log.exception(e)
-            prompt = user_message
+            log.debug(e)
 
 
-    system_message_content = ""
+            error_message = ""
+            if isinstance(e, HTTPException):
+                if e.detail and isinstance(e.detail, dict):
+                    error_message = e.detail.get("message", str(e.detail))
+                else:
+                    error_message = str(e.detail)
 
 
-    try:
-        images = await image_generations(
-            request=request,
-            form_data=CreateImageForm(**{"prompt": prompt}),
-            user=user,
-        )
+            await __event_emitter__(
+                {
+                    "type": "status",
+                    "data": {
+                        "description": f"An error occurred while generating an image",
+                        "done": True,
+                    },
+                }
+            )
 
 
-        await __event_emitter__(
-            {
-                "type": "status",
-                "data": {"description": "Image created", "done": True},
-            }
-        )
+            system_message_content = f"<context>Image generation was attempted but failed. The system is currently unable to generate the image. Tell the user that an error occurred: {error_message}</context>"
+    else:
+        # Edit image(s)
+        try:
+            images = await image_edits(
+                request=request,
+                form_data=EditImageForm(**{"prompt": prompt, "image": input_images}),
+                user=user,
+            )
 
 
-        await __event_emitter__(
-            {
-                "type": "files",
-                "data": {
-                    "files": [
-                        {
-                            "type": "image",
-                            "url": image["url"],
-                        }
-                        for image in images
-                    ]
-                },
-            }
-        )
+            await __event_emitter__(
+                {
+                    "type": "status",
+                    "data": {"description": "Image created", "done": True},
+                }
+            )
 
 
-        system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
-    except Exception as e:
-        log.exception(e)
-        await __event_emitter__(
-            {
-                "type": "status",
-                "data": {
-                    "description": f"An error occurred while generating an image",
-                    "done": True,
-                },
-            }
-        )
+            await __event_emitter__(
+                {
+                    "type": "files",
+                    "data": {
+                        "files": [
+                            {
+                                "type": "image",
+                                "url": image["url"],
+                            }
+                            for image in images
+                        ]
+                    },
+                }
+            )
+
+            system_message_content = "<context>The requested image has been created and is now being shown to the user. Let them know that it has been generated.</context>"
+        except Exception as e:
+            log.debug(e)
+
+            error_message = ""
+            if isinstance(e, HTTPException):
+                if e.detail and isinstance(e.detail, dict):
+                    error_message = e.detail.get("message", str(e.detail))
+                else:
+                    error_message = str(e.detail)
+
+            await __event_emitter__(
+                {
+                    "type": "status",
+                    "data": {
+                        "description": f"An error occurred while generating an image",
+                        "done": True,
+                    },
+                }
+            )
 
 
-        system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>"
+            system_message_content = f"<context>Image generation was attempted but failed. The system is currently unable to generate the image. Tell the user that an error occurred: {error_message}</context>"
 
 
     if system_message_content:
     if system_message_content:
         form_data["messages"] = add_or_update_system_message(
         form_data["messages"] = add_or_update_system_message(

+ 81 - 25
src/lib/components/admin/Settings/Images.svelte

@@ -29,7 +29,7 @@
 	let config = null;
 	let config = null;
 
 
 	let showComfyUIWorkflowEditor = false;
 	let showComfyUIWorkflowEditor = false;
-	let requiredWorkflowNodes = [
+	let REQUIRED_WORKFLOW_NODES = [
 		{
 		{
 			type: 'prompt',
 			type: 'prompt',
 			key: 'text',
 			key: 'text',
@@ -62,6 +62,29 @@
 		}
 		}
 	];
 	];
 
 
+	let REQUIRED_EDIT_WORKFLOW_NODES = [
+		{
+			type: 'prompt',
+			key: 'text',
+			node_ids: ''
+		},
+		{
+			type: 'model',
+			key: 'ckpt_name',
+			node_ids: ''
+		},
+		{
+			type: 'width',
+			key: 'width',
+			node_ids: ''
+		},
+		{
+			type: 'height',
+			key: 'height',
+			node_ids: ''
+		}
+	];
+
 	const getModels = async () => {
 	const getModels = async () => {
 		models = await getImageGenerationModels(localStorage.token).catch((error) => {
 		models = await getImageGenerationModels(localStorage.token).catch((error) => {
 			toast.error(`${error}`);
 			toast.error(`${error}`);
@@ -137,7 +160,7 @@
 		}
 		}
 
 
 		if (config?.COMFYUI_WORKFLOW) {
 		if (config?.COMFYUI_WORKFLOW) {
-			config.COMFYUI_WORKFLOW_NODES = requiredWorkflowNodes.map((node) => {
+			config.COMFYUI_WORKFLOW_NODES = REQUIRED_WORKFLOW_NODES.map((node) => {
 				return {
 				return {
 					type: node.type,
 					type: node.type,
 					key: node.key,
 					key: node.key,
@@ -178,7 +201,7 @@
 				}
 				}
 			}
 			}
 
 
-			requiredWorkflowNodes = requiredWorkflowNodes.map((node) => {
+			REQUIRED_WORKFLOW_NODES = REQUIRED_WORKFLOW_NODES.map((node) => {
 				const n = config.COMFYUI_WORKFLOW_NODES.find((n) => n.type === node.type) ?? node;
 				const n = config.COMFYUI_WORKFLOW_NODES.find((n) => n.type === node.type) ?? node;
 				console.debug(n);
 				console.debug(n);
 
 
@@ -665,7 +688,7 @@
 								</div>
 								</div>
 
 
 								<div class="mt-1 text-xs flex flex-col gap-1.5">
 								<div class="mt-1 text-xs flex flex-col gap-1.5">
-									{#each requiredWorkflowNodes as node}
+									{#each REQUIRED_WORKFLOW_NODES as node}
 										<div class="flex w-full flex-col">
 										<div class="flex w-full flex-col">
 											<div class="shrink-0">
 											<div class="shrink-0">
 												<div class=" capitalize line-clamp-1 w-20 text-gray-400 dark:text-gray-500">
 												<div class=" capitalize line-clamp-1 w-20 text-gray-400 dark:text-gray-500">
@@ -791,13 +814,13 @@
 								placeholder={$i18n.t('Select Engine')}
 								placeholder={$i18n.t('Select Engine')}
 							>
 							>
 								<option value="openai">{$i18n.t('Default (Open AI)')}</option>
 								<option value="openai">{$i18n.t('Default (Open AI)')}</option>
-								<option value="comfyui">{$i18n.t('ComfyUI')}</option>
-								<option value="comfyui">{$i18n.t('Gemini')}</option>
+								<!-- <option value="comfyui">{$i18n.t('ComfyUI')}</option> -->
+								<option value="gemini">{$i18n.t('Gemini')}</option>
 							</select>
 							</select>
 						</div>
 						</div>
 					</div>
 					</div>
 
 
-					{#if config.ENABLE_IMAGE_EDIT}
+					{#if config.ENABLE_IMAGE_GENERATION}
 						<div class="mb-2.5">
 						<div class="mb-2.5">
 							<div class="flex w-full justify-between items-center">
 							<div class="flex w-full justify-between items-center">
 								<div class="text-xs pr-2">
 								<div class="text-xs pr-2">
@@ -918,7 +941,7 @@
 										<input
 										<input
 											class="w-full text-sm bg-transparent outline-hidden text-right"
 											class="w-full text-sm bg-transparent outline-hidden text-right"
 											placeholder={$i18n.t('Enter URL (e.g. http://127.0.0.1:7860/)')}
 											placeholder={$i18n.t('Enter URL (e.g. http://127.0.0.1:7860/)')}
-											bind:value={config.COMFYUI_BASE_URL}
+											bind:value={config.IMAGES_EDIT_COMFYUI_BASE_URL}
 										/>
 										/>
 									</div>
 									</div>
 									<button
 									<button
@@ -967,7 +990,7 @@
 										<SensitiveInput
 										<SensitiveInput
 											inputClassName="text-right w-full"
 											inputClassName="text-right w-full"
 											placeholder={$i18n.t('sk-1234')}
 											placeholder={$i18n.t('sk-1234')}
-											bind:value={config.COMFYUI_API_KEY}
+											bind:value={config.IMAGES_EDIT_COMFYUI_API_KEY}
 											required={false}
 											required={false}
 										/>
 										/>
 									</div>
 									</div>
@@ -977,7 +1000,7 @@
 
 
 						<div class="mb-2.5">
 						<div class="mb-2.5">
 							<input
 							<input
-								id="upload-comfyui-workflow-input"
+								id="upload-comfyui-edit-workflow-input"
 								hidden
 								hidden
 								type="file"
 								type="file"
 								accept=".json"
 								accept=".json"
@@ -986,7 +1009,7 @@
 									const reader = new FileReader();
 									const reader = new FileReader();
 
 
 									reader.onload = (e) => {
 									reader.onload = (e) => {
-										config.COMFYUI_WORKFLOW = e.target.result;
+										config.IMAGES_EDIT_COMFYUI_WORKFLOW = e.target.result;
 										e.target.value = null;
 										e.target.value = null;
 									};
 									};
 
 
@@ -1002,7 +1025,7 @@
 
 
 								<div class="flex w-full">
 								<div class="flex w-full">
 									<div class="flex-1 mr-2 justify-end flex gap-1">
 									<div class="flex-1 mr-2 justify-end flex gap-1">
-										{#if config.COMFYUI_WORKFLOW}
+										{#if config.IMAGES_EDIT_COMFYUI_WORKFLOW}
 											<button
 											<button
 												class="text-xs text-gray-700 dark:text-gray-400 underline"
 												class="text-xs text-gray-700 dark:text-gray-400 underline"
 												type="button"
 												type="button"
@@ -1022,7 +1045,7 @@
 												type="button"
 												type="button"
 												aria-label={$i18n.t('Click here to upload a workflow.json file.')}
 												aria-label={$i18n.t('Click here to upload a workflow.json file.')}
 												on:click={() => {
 												on:click={() => {
-													document.getElementById('upload-comfyui-workflow-input')?.click();
+													document.getElementById('upload-comfyui-edit-workflow-input')?.click();
 												}}
 												}}
 											>
 											>
 												{$i18n.t('Upload')}
 												{$i18n.t('Upload')}
@@ -1035,28 +1058,20 @@
 							<div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
 							<div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
 								<CodeEditorModal
 								<CodeEditorModal
 									bind:show={showComfyUIWorkflowEditor}
 									bind:show={showComfyUIWorkflowEditor}
-									value={config.COMFYUI_WORKFLOW}
+									value={config.IMAGES_EDIT_COMFYUI_WORKFLOW}
 									lang="json"
 									lang="json"
 									onChange={(e) => {
 									onChange={(e) => {
-										config.COMFYUI_WORKFLOW = e;
+										config.IMAGES_EDIT_COMFYUI_WORKFLOW = e;
 									}}
 									}}
 									onSave={() => {
 									onSave={() => {
 										console.log('Saved');
 										console.log('Saved');
 									}}
 									}}
 								/>
 								/>
-								<!-- {#if config.COMFYUI_WORKFLOW}
-									<Textarea
-										class="w-full rounded-lg my-1 py-2 px-3 text-xs bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-hidden disabled:text-gray-600 resize-none"
-										rows="10"
-										bind:value={config.COMFYUI_WORKFLOW}
-										required
-									/>
-								{/if} -->
 								{$i18n.t('Make sure to export a workflow.json file as API format from ComfyUI.')}
 								{$i18n.t('Make sure to export a workflow.json file as API format from ComfyUI.')}
 							</div>
 							</div>
 						</div>
 						</div>
 
 
-						{#if config.COMFYUI_WORKFLOW}
+						{#if config.IMAGES_EDIT_COMFYUI_WORKFLOW}
 							<div class="mb-2.5">
 							<div class="mb-2.5">
 								<div class="flex w-full justify-between items-center">
 								<div class="flex w-full justify-between items-center">
 									<div class="text-xs pr-2 shrink-0">
 									<div class="text-xs pr-2 shrink-0">
@@ -1067,7 +1082,7 @@
 								</div>
 								</div>
 
 
 								<div class="mt-1 text-xs flex flex-col gap-1.5">
 								<div class="mt-1 text-xs flex flex-col gap-1.5">
-									{#each requiredWorkflowNodes as node}
+									{#each REQUIRED_EDIT_WORKFLOW_NODES as node}
 										<div class="flex w-full flex-col">
 										<div class="flex w-full flex-col">
 											<div class="shrink-0">
 											<div class="shrink-0">
 												<div class=" capitalize line-clamp-1 w-20 text-gray-400 dark:text-gray-500">
 												<div class=" capitalize line-clamp-1 w-20 text-gray-400 dark:text-gray-500">
@@ -1111,6 +1126,47 @@
 								</div>
 								</div>
 							</div>
 							</div>
 						{/if}
 						{/if}
+					{:else if config?.IMAGE_GENERATION_ENGINE === 'gemini'}
+						<div class="mb-2.5">
+							<div class="flex w-full justify-between items-center">
+								<div class="text-xs pr-2 shrink-0">
+									<div class="">
+										{$i18n.t('Gemini Base URL')}
+									</div>
+								</div>
+
+								<div class="flex w-full">
+									<div class="flex-1">
+										<input
+											class="w-full text-sm bg-transparent outline-hidden text-right"
+											placeholder={$i18n.t('API Base URL')}
+											bind:value={config.IMAGES_EDIT_GEMINI_API_BASE_URL}
+										/>
+									</div>
+								</div>
+							</div>
+						</div>
+
+						<div class="mb-2.5">
+							<div class="flex w-full justify-between items-center">
+								<div class="text-xs pr-2 shrink-0">
+									<div class="">
+										{$i18n.t('Gemini API Key')}
+									</div>
+								</div>
+
+								<div class="flex w-full">
+									<div class="flex-1">
+										<SensitiveInput
+											inputClassName="text-right w-full"
+											placeholder={$i18n.t('API Key')}
+											bind:value={config.IMAGES_EDIT_GEMINI_API_KEY}
+											required={true}
+										/>
+									</div>
+								</div>
+							</div>
+						</div>
 					{/if}
 					{/if}
 				</div>
 				</div>
 			</div>
 			</div>