Browse Source

Merge pull request #14774 from rragundez/images-from-db

fix: Store and load code interpreter generated images from a central location (DB and/or cloud storage)
Tim Jaeryang Baek 8 months ago
parent
commit
6cb519ca0e
2 changed files with 31 additions and 43 deletions
  1. 2 2
      backend/open_webui/routers/images.py
  2. 29 41
      backend/open_webui/utils/middleware.py

+ 2 - 2
backend/open_webui/routers/images.py

@@ -420,7 +420,7 @@ def load_b64_image_data(b64_str):
     try:
         if "," in b64_str:
             header, encoded = b64_str.split(",", 1)
-            mime_type = header.split(";")[0]
+            mime_type = header.split(";")[0].lstrip("data:")
             img_data = base64.b64decode(encoded)
         else:
             mime_type = "image/png"
@@ -428,7 +428,7 @@ def load_b64_image_data(b64_str):
         return img_data, mime_type
     except Exception as e:
         log.exception(f"Error loading image data: {e}")
-        return None
+        return None, None
 
 
 def load_url_image_data(url, headers=None):

+ 29 - 41
backend/open_webui/utils/middleware.py

@@ -37,7 +37,12 @@ from open_webui.routers.tasks import (
     generate_chat_tags,
 )
 from open_webui.routers.retrieval import process_web_search, SearchForm
-from open_webui.routers.images import image_generations, GenerateImageForm
+from open_webui.routers.images import (
+    load_b64_image_data,
+    image_generations,
+    GenerateImageForm,
+    upload_image,
+)
 from open_webui.routers.pipelines import (
     process_pipeline_inlet_filter,
     process_pipeline_outlet_filter,
@@ -2259,28 +2264,19 @@ async def process_chat_response(
                                         stdoutLines = stdout.split("\n")
                                         for idx, line in enumerate(stdoutLines):
                                             if "data:image/png;base64" in line:
-                                                id = str(uuid4())
-
-                                                # ensure the path exists
-                                                os.makedirs(
-                                                    os.path.join(CACHE_DIR, "images"),
-                                                    exist_ok=True,
-                                                )
-
-                                                image_path = os.path.join(
-                                                    CACHE_DIR,
-                                                    f"images/{id}.png",
-                                                )
-
-                                                with open(image_path, "wb") as f:
-                                                    f.write(
-                                                        base64.b64decode(
-                                                            line.split(",")[1]
-                                                        )
+                                                image_url = ""
+                                                # Extract base64 image data from the line
+                                                image_data, content_type = load_b64_image_data(line)
+                                                if image_data is not None:
+                                                    image_url = upload_image(
+                                                        request,
+                                                        image_data,
+                                                        content_type,
+                                                        metadata,
+                                                        user,
                                                     )
-
                                                 stdoutLines[idx] = (
-                                                    f"![Output Image {idx}](/cache/images/{id}.png)"
+                                                    f"![Output Image]({image_url})"
                                                 )
 
                                         output["stdout"] = "\n".join(stdoutLines)
@@ -2291,30 +2287,22 @@ async def process_chat_response(
                                         resultLines = result.split("\n")
                                         for idx, line in enumerate(resultLines):
                                             if "data:image/png;base64" in line:
-                                                id = str(uuid4())
-
-                                                # ensure the path exists
-                                                os.makedirs(
-                                                    os.path.join(CACHE_DIR, "images"),
-                                                    exist_ok=True,
-                                                )
-
-                                                image_path = os.path.join(
-                                                    CACHE_DIR,
-                                                    f"images/{id}.png",
+                                                image_url = ""
+                                                # Extract base64 image data from the line
+                                                image_data, content_type = (
+                                                    load_b64_image_data(line)
                                                 )
-
-                                                with open(image_path, "wb") as f:
-                                                    f.write(
-                                                        base64.b64decode(
-                                                            line.split(",")[1]
-                                                        )
+                                                if image_data is not None:
+                                                    image_url = upload_image(
+                                                        request,
+                                                        image_data,
+                                                        content_type,
+                                                        metadata,
+                                                        user,
                                                     )
-
                                                 resultLines[idx] = (
-                                                    f"![Output Image {idx}](/cache/images/{id}.png)"
+                                                    f"![Output Image]({image_url})"
                                                 )
-
                                         output["result"] = "\n".join(resultLines)
                         except Exception as e:
                             output = str(e)