Sfoglia il codice sorgente

fix: image generation with allowed file extensions

Timothy Jaeryang Baek 4 mesi fa
parent
commit
7a593b63b2

+ 3 - 3
backend/open_webui/routers/files.py

@@ -85,12 +85,12 @@ def upload_file(
     request: Request,
     file: UploadFile = File(...),
     user=Depends(get_verified_user),
-    file_metadata: dict = None,
+    metadata: dict = None,
     process: bool = Query(True),
 ):
     log.info(f"file.content_type: {file.content_type}")
 
-    file_metadata = file_metadata if file_metadata else {}
+    file_metadata = metadata if metadata else {}
     try:
         unsanitized_filename = file.filename
         filename = os.path.basename(unsanitized_filename)
@@ -99,7 +99,7 @@ def upload_file(
         # Remove the leading dot from the file extension
         file_extension = file_extension[1:] if file_extension else ""
 
-        if request.app.state.config.ALLOWED_FILE_EXTENSIONS:
+        if not file_metadata and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
             request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
                 ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
             ]

+ 3 - 3
backend/open_webui/routers/images.py

@@ -451,7 +451,7 @@ def load_url_image_data(url, headers=None):
         return None
 
 
-def upload_image(request, image_metadata, image_data, content_type, user):
+def upload_image(request, image_data, content_type, metadata, user):
     image_format = mimetypes.guess_extension(content_type)
     file = UploadFile(
         file=io.BytesIO(image_data),
@@ -460,7 +460,7 @@ def upload_image(request, image_metadata, image_data, content_type, user):
             "content-type": content_type,
         },
     )
-    file_item = upload_file(request, file, user, file_metadata=image_metadata)
+    file_item = upload_file(request, file, user, metadata=metadata)
     url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
     return url
 
@@ -527,7 +527,7 @@ async def image_generations(
                 else:
                     image_data, content_type = load_b64_image_data(image["b64_json"])
 
-                url = upload_image(request, data, image_data, content_type, user)
+                url = upload_image(request, image_data, content_type, data, user)
                 images.append({"url": url})
             return images