images.py 40 KB


  1. import asyncio
  2. import base64
  3. import uuid
  4. import io
  5. import json
  6. import logging
  7. import mimetypes
  8. import re
  9. from pathlib import Path
  10. from typing import Optional
  11. from urllib.parse import quote
  12. import requests
  13. from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
  14. from fastapi.responses import FileResponse
  15. from open_webui.config import CACHE_DIR
  16. from open_webui.constants import ERROR_MESSAGES
  17. from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
  18. from open_webui.routers.files import upload_file_handler, get_file_content_by_id
  19. from open_webui.utils.auth import get_admin_user, get_verified_user
  20. from open_webui.utils.headers import include_user_info_headers
  21. from open_webui.utils.images.comfyui import (
  22. ComfyUICreateImageForm,
  23. ComfyUIEditImageForm,
  24. ComfyUIWorkflow,
  25. comfyui_upload_image,
  26. comfyui_create_image,
  27. comfyui_edit_image,
  28. )
  29. from pydantic import BaseModel
# Module-level logger; its level is taken from the "IMAGES" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])

# Directory under CACHE_DIR for image generations; created at import time if missing.
IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Router on which the endpoint functions of this module are registered.
router = APIRouter()
  35. def set_image_model(request: Request, model: str):
  36. log.info(f"Setting image model to {model}")
  37. request.app.state.config.IMAGE_GENERATION_MODEL = model
  38. if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
  39. api_auth = get_automatic1111_api_auth(request)
  40. r = requests.get(
  41. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  42. headers={"authorization": api_auth},
  43. )
  44. options = r.json()
  45. if model != options["sd_model_checkpoint"]:
  46. options["sd_model_checkpoint"] = model
  47. r = requests.post(
  48. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  49. json=options,
  50. headers={"authorization": api_auth},
  51. )
  52. return request.app.state.config.IMAGE_GENERATION_MODEL
  53. def get_image_model(request):
  54. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  55. return (
  56. request.app.state.config.IMAGE_GENERATION_MODEL
  57. if request.app.state.config.IMAGE_GENERATION_MODEL
  58. else "dall-e-2"
  59. )
  60. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
  61. return (
  62. request.app.state.config.IMAGE_GENERATION_MODEL
  63. if request.app.state.config.IMAGE_GENERATION_MODEL
  64. else "imagen-3.0-generate-002"
  65. )
  66. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  67. return (
  68. request.app.state.config.IMAGE_GENERATION_MODEL
  69. if request.app.state.config.IMAGE_GENERATION_MODEL
  70. else ""
  71. )
  72. elif (
  73. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  74. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  75. ):
  76. try:
  77. r = requests.get(
  78. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  79. headers={"authorization": get_automatic1111_api_auth(request)},
  80. )
  81. options = r.json()
  82. return options["sd_model_checkpoint"]
  83. except Exception as e:
  84. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  85. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  86. class ImagesConfig(BaseModel):
  87. ENABLE_IMAGE_GENERATION: bool
  88. ENABLE_IMAGE_PROMPT_GENERATION: bool
  89. IMAGE_GENERATION_ENGINE: str
  90. IMAGE_GENERATION_MODEL: str
  91. IMAGE_SIZE: Optional[str]
  92. IMAGE_STEPS: Optional[int]
  93. IMAGES_OPENAI_API_BASE_URL: str
  94. IMAGES_OPENAI_API_KEY: str
  95. IMAGES_OPENAI_API_VERSION: str
  96. AUTOMATIC1111_BASE_URL: str
  97. AUTOMATIC1111_API_AUTH: str
  98. AUTOMATIC1111_PARAMS: Optional[dict | str]
  99. COMFYUI_BASE_URL: str
  100. COMFYUI_API_KEY: str
  101. COMFYUI_WORKFLOW: str
  102. COMFYUI_WORKFLOW_NODES: list[dict]
  103. IMAGES_GEMINI_API_BASE_URL: str
  104. IMAGES_GEMINI_API_KEY: str
  105. IMAGES_GEMINI_ENDPOINT_METHOD: str
  106. IMAGE_EDIT_ENGINE: str
  107. IMAGE_EDIT_MODEL: str
  108. IMAGE_EDIT_SIZE: Optional[str]
  109. IMAGES_EDIT_OPENAI_API_BASE_URL: str
  110. IMAGES_EDIT_OPENAI_API_KEY: str
  111. IMAGES_EDIT_OPENAI_API_VERSION: str
  112. IMAGES_EDIT_GEMINI_API_BASE_URL: str
  113. IMAGES_EDIT_GEMINI_API_KEY: str
  114. IMAGES_EDIT_COMFYUI_BASE_URL: str
  115. IMAGES_EDIT_COMFYUI_API_KEY: str
  116. IMAGES_EDIT_COMFYUI_WORKFLOW: str
  117. IMAGES_EDIT_COMFYUI_WORKFLOW_NODES: list[dict]
  118. @router.get("/config", response_model=ImagesConfig)
  119. async def get_config(request: Request, user=Depends(get_admin_user)):
  120. return {
  121. "ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION,
  122. "ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
  123. "IMAGE_GENERATION_ENGINE": request.app.state.config.IMAGE_GENERATION_ENGINE,
  124. "IMAGE_GENERATION_MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
  125. "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
  126. "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
  127. "IMAGES_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
  128. "IMAGES_OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
  129. "IMAGES_OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION,
  130. "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
  131. "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
  132. "AUTOMATIC1111_PARAMS": request.app.state.config.AUTOMATIC1111_PARAMS,
  133. "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
  134. "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
  135. "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
  136. "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  137. "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
  138. "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
  139. "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
  140. "IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
  141. "IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
  142. "IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
  143. "IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL,
  144. "IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY,
  145. "IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION,
  146. "IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL,
  147. "IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY,
  148. "IMAGES_EDIT_COMFYUI_BASE_URL": request.app.state.config.IMAGES_EDIT_COMFYUI_BASE_URL,
  149. "IMAGES_EDIT_COMFYUI_API_KEY": request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY,
  150. "IMAGES_EDIT_COMFYUI_WORKFLOW": request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW,
  151. "IMAGES_EDIT_COMFYUI_WORKFLOW_NODES": request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES,
  152. }
  153. @router.post("/config/update")
  154. async def update_config(
  155. request: Request, form_data: ImagesConfig, user=Depends(get_admin_user)
  156. ):
  157. request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.ENABLE_IMAGE_GENERATION
  158. # Create Image
  159. request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
  160. form_data.ENABLE_IMAGE_PROMPT_GENERATION
  161. )
  162. request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.IMAGE_GENERATION_ENGINE
  163. set_image_model(request, form_data.IMAGE_GENERATION_MODEL)
  164. if (
  165. form_data.IMAGE_SIZE == "auto"
  166. and form_data.IMAGE_GENERATION_MODEL != "gpt-image-1"
  167. ):
  168. raise HTTPException(
  169. status_code=400,
  170. detail=ERROR_MESSAGES.INCORRECT_FORMAT(
  171. " (auto is only allowed with gpt-image-1)."
  172. ),
  173. )
  174. pattern = r"^\d+x\d+$"
  175. if (
  176. form_data.IMAGE_SIZE == "auto"
  177. or form_data.IMAGE_SIZE == ""
  178. or re.match(pattern, form_data.IMAGE_SIZE)
  179. ):
  180. request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
  181. else:
  182. raise HTTPException(
  183. status_code=400,
  184. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  185. )
  186. if form_data.IMAGE_STEPS >= 0:
  187. request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
  188. else:
  189. raise HTTPException(
  190. status_code=400,
  191. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  192. )
  193. request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
  194. form_data.IMAGES_OPENAI_API_BASE_URL
  195. )
  196. request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.IMAGES_OPENAI_API_KEY
  197. request.app.state.config.IMAGES_OPENAI_API_VERSION = (
  198. form_data.IMAGES_OPENAI_API_VERSION
  199. )
  200. request.app.state.config.AUTOMATIC1111_BASE_URL = form_data.AUTOMATIC1111_BASE_URL
  201. request.app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
  202. request.app.state.config.AUTOMATIC1111_PARAMS = form_data.AUTOMATIC1111_PARAMS
  203. request.app.state.config.COMFYUI_BASE_URL = form_data.COMFYUI_BASE_URL.strip("/")
  204. request.app.state.config.COMFYUI_API_KEY = form_data.COMFYUI_API_KEY
  205. request.app.state.config.COMFYUI_WORKFLOW = form_data.COMFYUI_WORKFLOW
  206. request.app.state.config.COMFYUI_WORKFLOW_NODES = form_data.COMFYUI_WORKFLOW_NODES
  207. request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
  208. form_data.IMAGES_GEMINI_API_BASE_URL
  209. )
  210. request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.IMAGES_GEMINI_API_KEY
  211. request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD = (
  212. form_data.IMAGES_GEMINI_ENDPOINT_METHOD
  213. )
  214. # Edit Image
  215. request.app.state.config.IMAGE_EDIT_ENGINE = form_data.IMAGE_EDIT_ENGINE
  216. request.app.state.config.IMAGE_EDIT_MODEL = form_data.IMAGE_EDIT_MODEL
  217. request.app.state.config.IMAGE_EDIT_SIZE = form_data.IMAGE_EDIT_SIZE
  218. request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL = (
  219. form_data.IMAGES_OPENAI_API_BASE_URL
  220. )
  221. request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY = (
  222. form_data.IMAGES_OPENAI_API_KEY
  223. )
  224. request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION = (
  225. form_data.IMAGES_EDIT_OPENAI_API_VERSION
  226. )
  227. request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL = (
  228. form_data.IMAGES_EDIT_GEMINI_API_BASE_URL
  229. )
  230. request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY = (
  231. form_data.IMAGES_EDIT_GEMINI_API_KEY
  232. )
  233. request.app.state.config.IMAGES_EDIT_COMFYUI_BASE_URL = (
  234. form_data.IMAGES_EDIT_COMFYUI_BASE_URL.strip("/")
  235. )
  236. request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY = (
  237. form_data.IMAGES_EDIT_COMFYUI_API_KEY
  238. )
  239. request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW = (
  240. form_data.IMAGES_EDIT_COMFYUI_WORKFLOW
  241. )
  242. request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES = (
  243. form_data.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES
  244. )
  245. return {
  246. "ENABLE_IMAGE_GENERATION": request.app.state.config.ENABLE_IMAGE_GENERATION,
  247. "ENABLE_IMAGE_PROMPT_GENERATION": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
  248. "IMAGE_GENERATION_ENGINE": request.app.state.config.IMAGE_GENERATION_ENGINE,
  249. "IMAGE_GENERATION_MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
  250. "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
  251. "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
  252. "IMAGES_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
  253. "IMAGES_OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
  254. "IMAGES_OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION,
  255. "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
  256. "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
  257. "AUTOMATIC1111_PARAMS": request.app.state.config.AUTOMATIC1111_PARAMS,
  258. "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
  259. "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
  260. "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
  261. "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  262. "IMAGES_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
  263. "IMAGES_GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
  264. "IMAGES_GEMINI_ENDPOINT_METHOD": request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD,
  265. "IMAGE_EDIT_ENGINE": request.app.state.config.IMAGE_EDIT_ENGINE,
  266. "IMAGE_EDIT_MODEL": request.app.state.config.IMAGE_EDIT_MODEL,
  267. "IMAGE_EDIT_SIZE": request.app.state.config.IMAGE_EDIT_SIZE,
  268. "IMAGES_EDIT_OPENAI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL,
  269. "IMAGES_EDIT_OPENAI_API_KEY": request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY,
  270. "IMAGES_EDIT_OPENAI_API_VERSION": request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION,
  271. "IMAGES_EDIT_GEMINI_API_BASE_URL": request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL,
  272. "IMAGES_EDIT_GEMINI_API_KEY": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY,
  273. "IMAGES_EDIT_COMFYUI_BASE_URL": request.app.state.config.IMAGES_EDIT_COMFYUI_BASE_URL,
  274. "IMAGES_EDIT_COMFYUI_API_KEY": request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY,
  275. "IMAGES_EDIT_COMFYUI_WORKFLOW": request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW,
  276. "IMAGES_EDIT_COMFYUI_WORKFLOW_NODES": request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES,
  277. }
  278. def get_automatic1111_api_auth(request: Request):
  279. if request.app.state.config.AUTOMATIC1111_API_AUTH is None:
  280. return ""
  281. else:
  282. auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode(
  283. "utf-8"
  284. )
  285. auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
  286. auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
  287. return f"Basic {auth1111_base64_encoded_string}"
  288. @router.get("/config/url/verify")
  289. async def verify_url(request: Request, user=Depends(get_admin_user)):
  290. if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111":
  291. try:
  292. r = requests.get(
  293. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  294. headers={"authorization": get_automatic1111_api_auth(request)},
  295. )
  296. r.raise_for_status()
  297. return True
  298. except Exception:
  299. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  300. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  301. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  302. headers = None
  303. if request.app.state.config.COMFYUI_API_KEY:
  304. headers = {
  305. "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
  306. }
  307. try:
  308. r = requests.get(
  309. url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
  310. headers=headers,
  311. )
  312. r.raise_for_status()
  313. return True
  314. except Exception:
  315. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  316. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  317. else:
  318. return True
  319. @router.get("/models")
  320. def get_models(request: Request, user=Depends(get_verified_user)):
  321. try:
  322. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  323. return [
  324. {"id": "dall-e-2", "name": "DALL·E 2"},
  325. {"id": "dall-e-3", "name": "DALL·E 3"},
  326. {"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
  327. ]
  328. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
  329. return [
  330. {"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
  331. ]
  332. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  333. # TODO - get models from comfyui
  334. headers = {
  335. "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
  336. }
  337. r = requests.get(
  338. url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
  339. headers=headers,
  340. )
  341. info = r.json()
  342. workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW)
  343. model_node_id = None
  344. for node in request.app.state.config.COMFYUI_WORKFLOW_NODES:
  345. if node["type"] == "model":
  346. if node["node_ids"]:
  347. model_node_id = node["node_ids"][0]
  348. break
  349. if model_node_id:
  350. model_list_key = None
  351. log.info(workflow[model_node_id]["class_type"])
  352. for key in info[workflow[model_node_id]["class_type"]]["input"][
  353. "required"
  354. ]:
  355. if "_name" in key:
  356. model_list_key = key
  357. break
  358. if model_list_key:
  359. return list(
  360. map(
  361. lambda model: {"id": model, "name": model},
  362. info[workflow[model_node_id]["class_type"]]["input"][
  363. "required"
  364. ][model_list_key][0],
  365. )
  366. )
  367. else:
  368. return list(
  369. map(
  370. lambda model: {"id": model, "name": model},
  371. info["CheckpointLoaderSimple"]["input"]["required"][
  372. "ckpt_name"
  373. ][0],
  374. )
  375. )
  376. elif (
  377. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  378. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  379. ):
  380. r = requests.get(
  381. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
  382. headers={"authorization": get_automatic1111_api_auth(request)},
  383. )
  384. models = r.json()
  385. return list(
  386. map(
  387. lambda model: {"id": model["title"], "name": model["model_name"]},
  388. models,
  389. )
  390. )
  391. except Exception as e:
  392. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  393. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  394. class CreateImageForm(BaseModel):
  395. model: Optional[str] = None
  396. prompt: str
  397. size: Optional[str] = None
  398. n: int = 1
  399. negative_prompt: Optional[str] = None
  400. GenerateImageForm = CreateImageForm # Alias for backward compatibility
  401. def get_image_data(data: str, headers=None):
  402. try:
  403. if data.startswith("http://") or data.startswith("https://"):
  404. if headers:
  405. r = requests.get(data, headers=headers)
  406. else:
  407. r = requests.get(data)
  408. r.raise_for_status()
  409. if r.headers["content-type"].split("/")[0] == "image":
  410. mime_type = r.headers["content-type"]
  411. return r.content, mime_type
  412. else:
  413. log.error("Url does not point to an image.")
  414. return None
  415. else:
  416. if "," in data:
  417. header, encoded = data.split(",", 1)
  418. mime_type = header.split(";")[0].lstrip("data:")
  419. img_data = base64.b64decode(encoded)
  420. else:
  421. mime_type = "image/png"
  422. img_data = base64.b64decode(data)
  423. return img_data, mime_type
  424. except Exception as e:
  425. log.exception(f"Error loading image data: {e}")
  426. return None, None
  427. def upload_image(request, image_data, content_type, metadata, user):
  428. image_format = mimetypes.guess_extension(content_type)
  429. file = UploadFile(
  430. file=io.BytesIO(image_data),
  431. filename=f"generated-image{image_format}", # will be converted to a unique ID on upload_file
  432. headers={
  433. "content-type": content_type,
  434. },
  435. )
  436. file_item = upload_file_handler(
  437. request,
  438. file=file,
  439. metadata=metadata,
  440. process=False,
  441. user=user,
  442. )
  443. url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
  444. return url
  445. @router.post("/generations")
  446. async def image_generations(
  447. request: Request,
  448. form_data: CreateImageForm,
  449. user=Depends(get_verified_user),
  450. ):
  451. # if IMAGE_SIZE = 'auto', default WidthxHeight to the 512x512 default
  452. # This is only relevant when the user has set IMAGE_SIZE to 'auto' with an
  453. # image model other than gpt-image-1, which is warned about on settings save
  454. size = "512x512"
  455. if (
  456. request.app.state.config.IMAGE_SIZE
  457. and "x" in request.app.state.config.IMAGE_SIZE
  458. ):
  459. size = request.app.state.config.IMAGE_SIZE
  460. if form_data.size and "x" in form_data.size:
  461. size = form_data.size
  462. width, height = tuple(map(int, size.split("x")))
  463. model = get_image_model(request)
  464. r = None
  465. try:
  466. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  467. headers = {
  468. "Authorization": f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}",
  469. "Content-Type": "application/json",
  470. }
  471. if ENABLE_FORWARD_USER_INFO_HEADERS:
  472. headers = include_user_info_headers(headers, user)
  473. data = {
  474. "model": model,
  475. "prompt": form_data.prompt,
  476. "n": form_data.n,
  477. "size": (
  478. form_data.size
  479. if form_data.size
  480. else request.app.state.config.IMAGE_SIZE
  481. ),
  482. **(
  483. {}
  484. if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL
  485. else {"response_format": "b64_json"}
  486. ),
  487. }
  488. api_version_query_param = ""
  489. if request.app.state.config.IMAGES_OPENAI_API_VERSION:
  490. api_version_query_param = (
  491. f"?api-version={request.app.state.config.IMAGES_OPENAI_API_VERSION}"
  492. )
  493. # Use asyncio.to_thread for the requests.post call
  494. r = await asyncio.to_thread(
  495. requests.post,
  496. url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations{api_version_query_param}",
  497. json=data,
  498. headers=headers,
  499. )
  500. r.raise_for_status()
  501. res = r.json()
  502. images = []
  503. for image in res["data"]:
  504. if image_url := image.get("url", None):
  505. image_data, content_type = get_image_data(image_url, headers)
  506. else:
  507. image_data, content_type = get_image_data(image["b64_json"])
  508. url = upload_image(request, image_data, content_type, data, user)
  509. images.append({"url": url})
  510. return images
  511. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
  512. headers = {
  513. "Content-Type": "application/json",
  514. "x-goog-api-key": request.app.state.config.IMAGES_GEMINI_API_KEY,
  515. }
  516. data = {}
  517. if (
  518. request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD == ""
  519. or request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD == "predict"
  520. ):
  521. model = f"{model}:predict"
  522. data = {
  523. "instances": {"prompt": form_data.prompt},
  524. "parameters": {
  525. "sampleCount": form_data.n,
  526. "outputOptions": {"mimeType": "image/png"},
  527. },
  528. }
  529. elif (
  530. request.app.state.config.IMAGES_GEMINI_ENDPOINT_METHOD
  531. == "generateContent"
  532. ):
  533. model = f"{model}:generateContent"
  534. data = {"contents": [{"parts": [{"text": form_data.prompt}]}]}
  535. # Use asyncio.to_thread for the requests.post call
  536. r = await asyncio.to_thread(
  537. requests.post,
  538. url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}",
  539. json=data,
  540. headers=headers,
  541. )
  542. r.raise_for_status()
  543. res = r.json()
  544. images = []
  545. if model.endswith(":predict"):
  546. for image in res["predictions"]:
  547. image_data, content_type = get_image_data(
  548. image["bytesBase64Encoded"]
  549. )
  550. url = upload_image(request, image_data, content_type, data, user)
  551. images.append({"url": url})
  552. elif model.endswith(":generateContent"):
  553. for image in res["candidates"]:
  554. for part in image["content"]["parts"]:
  555. if part.get("inlineData", {}).get("data"):
  556. image_data, content_type = get_image_data(
  557. part["inlineData"]["data"]
  558. )
  559. url = upload_image(
  560. request, image_data, content_type, data, user
  561. )
  562. images.append({"url": url})
  563. return images
  564. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  565. data = {
  566. "prompt": form_data.prompt,
  567. "width": width,
  568. "height": height,
  569. "n": form_data.n,
  570. }
  571. if request.app.state.config.IMAGE_STEPS is not None:
  572. data["steps"] = request.app.state.config.IMAGE_STEPS
  573. if form_data.negative_prompt is not None:
  574. data["negative_prompt"] = form_data.negative_prompt
  575. form_data = ComfyUICreateImageForm(
  576. **{
  577. "workflow": ComfyUIWorkflow(
  578. **{
  579. "workflow": request.app.state.config.COMFYUI_WORKFLOW,
  580. "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  581. }
  582. ),
  583. **data,
  584. }
  585. )
  586. res = await comfyui_create_image(
  587. model,
  588. form_data,
  589. user.id,
  590. request.app.state.config.COMFYUI_BASE_URL,
  591. request.app.state.config.COMFYUI_API_KEY,
  592. )
  593. log.debug(f"res: {res}")
  594. images = []
  595. for image in res["data"]:
  596. headers = None
  597. if request.app.state.config.COMFYUI_API_KEY:
  598. headers = {
  599. "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
  600. }
  601. image_data, content_type = get_image_data(image["url"], headers)
  602. url = upload_image(
  603. request,
  604. image_data,
  605. content_type,
  606. form_data.model_dump(exclude_none=True),
  607. user,
  608. )
  609. images.append({"url": url})
  610. return images
  611. elif (
  612. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  613. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  614. ):
  615. if form_data.model:
  616. set_image_model(request, form_data.model)
  617. data = {
  618. "prompt": form_data.prompt,
  619. "batch_size": form_data.n,
  620. "width": width,
  621. "height": height,
  622. }
  623. if request.app.state.config.IMAGE_STEPS is not None:
  624. data["steps"] = request.app.state.config.IMAGE_STEPS
  625. if form_data.negative_prompt is not None:
  626. data["negative_prompt"] = form_data.negative_prompt
  627. if request.app.state.config.AUTOMATIC1111_PARAMS:
  628. data = {**data, **request.app.state.config.AUTOMATIC1111_PARAMS}
  629. # Use asyncio.to_thread for the requests.post call
  630. r = await asyncio.to_thread(
  631. requests.post,
  632. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  633. json=data,
  634. headers={"authorization": get_automatic1111_api_auth(request)},
  635. )
  636. res = r.json()
  637. log.debug(f"res: {res}")
  638. images = []
  639. for image in res["images"]:
  640. image_data, content_type = get_image_data(image)
  641. url = upload_image(
  642. request,
  643. image_data,
  644. content_type,
  645. {**data, "info": res["info"]},
  646. user,
  647. )
  648. images.append({"url": url})
  649. return images
  650. except Exception as e:
  651. error = e
  652. if r != None:
  653. data = r.json()
  654. if "error" in data:
  655. error = data["error"]["message"]
  656. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
  657. class EditImageForm(BaseModel):
  658. image: str | list[str] # base64-encoded image(s) or URL(s)
  659. prompt: str
  660. model: Optional[str] = None
  661. size: Optional[str] = None
  662. n: Optional[int] = None
  663. negative_prompt: Optional[str] = None
  664. @router.post("/edit")
  665. async def image_edits(
  666. request: Request,
  667. form_data: EditImageForm,
  668. user=Depends(get_verified_user),
  669. ):
  670. size = None
  671. width, height = None, None
  672. if (
  673. request.app.state.config.IMAGE_EDIT_SIZE
  674. and "x" in request.app.state.config.IMAGE_EDIT_SIZE
  675. ) or (form_data.size and "x" in form_data.size):
  676. size = (
  677. form_data.size
  678. if form_data.size
  679. else request.app.state.config.IMAGE_EDIT_SIZE
  680. )
  681. width, height = tuple(map(int, size.split("x")))
  682. model = (
  683. request.app.state.config.IMAGE_EDIT_MODEL
  684. if form_data.model is None
  685. else form_data.model
  686. )
  687. try:
  688. async def load_url_image(data):
  689. if data.startswith("http://") or data.startswith("https://"):
  690. r = await asyncio.to_thread(requests.get, data)
  691. r.raise_for_status()
  692. image_data = base64.b64encode(r.content).decode("utf-8")
  693. return f"data:{r.headers['content-type']};base64,{image_data}"
  694. elif data.startswith("/api/v1/files"):
  695. file_id = data.split("/api/v1/files/")[1].split("/content")[0]
  696. file_response = await get_file_content_by_id(file_id, user)
  697. if isinstance(file_response, FileResponse):
  698. file_path = file_response.path
  699. with open(file_path, "rb") as f:
  700. file_bytes = f.read()
  701. image_data = base64.b64encode(file_bytes).decode("utf-8")
  702. mime_type, _ = mimetypes.guess_type(file_path)
  703. return f"data:{mime_type};base64,{image_data}"
  704. return data
  705. # Load image(s) from URL(s) if necessary
  706. if isinstance(form_data.image, str):
  707. form_data.image = await load_url_image(form_data.image)
  708. elif isinstance(form_data.image, list):
  709. form_data.image = [await load_url_image(img) for img in form_data.image]
  710. except Exception as e:
  711. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  712. def get_image_file_item(base64_string):
  713. data = base64_string
  714. header, encoded = data.split(",", 1)
  715. mime_type = header.split(";")[0].lstrip("data:")
  716. image_data = base64.b64decode(encoded)
  717. return (
  718. "image",
  719. (
  720. f"{uuid.uuid4()}.png",
  721. io.BytesIO(image_data),
  722. mime_type if mime_type else "image/png",
  723. ),
  724. )
  725. r = None
  726. try:
  727. if request.app.state.config.IMAGE_EDIT_ENGINE == "openai":
  728. headers = {
  729. "Authorization": f"Bearer {request.app.state.config.IMAGES_EDIT_OPENAI_API_KEY}",
  730. }
  731. if ENABLE_FORWARD_USER_INFO_HEADERS:
  732. headers = include_user_info_headers(headers, user)
  733. data = {
  734. "model": model,
  735. "prompt": form_data.prompt,
  736. **({"n": form_data.n} if form_data.n else {}),
  737. **({"size": size} if size else {}),
  738. **(
  739. {}
  740. if "gpt-image-1" in request.app.state.config.IMAGE_EDIT_MODEL
  741. else {"response_format": "b64_json"}
  742. ),
  743. }
  744. files = []
  745. if isinstance(form_data.image, str):
  746. files = [get_image_file_item(form_data.image)]
  747. elif isinstance(form_data.image, list):
  748. for img in form_data.image:
  749. files.append(get_image_file_item(img))
  750. url_search_params = ""
  751. if request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION:
  752. url_search_params += f"?api-version={request.app.state.config.IMAGES_EDIT_OPENAI_API_VERSION}"
  753. # Use asyncio.to_thread for the requests.post call
  754. r = await asyncio.to_thread(
  755. requests.post,
  756. url=f"{request.app.state.config.IMAGES_EDIT_OPENAI_API_BASE_URL}/images/edits{url_search_params}",
  757. headers=headers,
  758. files=files,
  759. data=data,
  760. )
  761. r.raise_for_status()
  762. res = r.json()
  763. images = []
  764. for image in res["data"]:
  765. if image_url := image.get("url", None):
  766. image_data, content_type = get_image_data(image_url, headers)
  767. else:
  768. image_data, content_type = get_image_data(image["b64_json"])
  769. url = upload_image(request, image_data, content_type, data, user)
  770. images.append({"url": url})
  771. return images
  772. elif request.app.state.config.IMAGE_EDIT_ENGINE == "gemini":
  773. headers = {
  774. "Content-Type": "application/json",
  775. "x-goog-api-key": request.app.state.config.IMAGES_EDIT_GEMINI_API_KEY,
  776. }
  777. model = f"{model}:generateContent"
  778. data = {"contents": [{"parts": [{"text": form_data.prompt}]}]}
  779. if isinstance(form_data.image, str):
  780. data["contents"][0]["parts"].append(
  781. {
  782. "inline_data": {
  783. "mime_type": "image/png",
  784. "data": form_data.image.split(",", 1)[1],
  785. }
  786. }
  787. )
  788. elif isinstance(form_data.image, list):
  789. data["contents"][0]["parts"].extend(
  790. [
  791. {
  792. "inline_data": {
  793. "mime_type": "image/png",
  794. "data": image.split(",", 1)[1],
  795. }
  796. }
  797. for image in form_data.image
  798. ]
  799. )
  800. # Use asyncio.to_thread for the requests.post call
  801. r = await asyncio.to_thread(
  802. requests.post,
  803. url=f"{request.app.state.config.IMAGES_EDIT_GEMINI_API_BASE_URL}/models/{model}",
  804. json=data,
  805. headers=headers,
  806. )
  807. r.raise_for_status()
  808. res = r.json()
  809. images = []
  810. for image in res["candidates"]:
  811. for part in image["content"]["parts"]:
  812. if part.get("inlineData", {}).get("data"):
  813. image_data, content_type = get_image_data(
  814. part["inlineData"]["data"]
  815. )
  816. url = upload_image(
  817. request, image_data, content_type, data, user
  818. )
  819. images.append({"url": url})
  820. return images
  821. elif request.app.state.config.IMAGE_EDIT_ENGINE == "comfyui":
  822. try:
  823. files = []
  824. if isinstance(form_data.image, str):
  825. files = [get_image_file_item(form_data.image)]
  826. elif isinstance(form_data.image, list):
  827. for img in form_data.image:
  828. files.append(get_image_file_item(img))
  829. # Upload images to ComfyUI and get their names
  830. comfyui_images = []
  831. for file_item in files:
  832. res = await comfyui_upload_image(
  833. file_item,
  834. request.app.state.config.IMAGES_EDIT_COMFYUI_BASE_URL,
  835. request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY,
  836. )
  837. comfyui_images.append(res.get("name", file_item[1][0]))
  838. except Exception as e:
  839. log.debug(f"Error uploading images to ComfyUI: {e}")
  840. raise Exception("Failed to upload images to ComfyUI.")
  841. data = {
  842. "image": comfyui_images,
  843. "prompt": form_data.prompt,
  844. **({"width": width} if width is not None else {}),
  845. **({"height": height} if height is not None else {}),
  846. **({"n": form_data.n} if form_data.n else {}),
  847. }
  848. form_data = ComfyUIEditImageForm(
  849. **{
  850. "workflow": ComfyUIWorkflow(
  851. **{
  852. "workflow": request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW,
  853. "nodes": request.app.state.config.IMAGES_EDIT_COMFYUI_WORKFLOW_NODES,
  854. }
  855. ),
  856. **data,
  857. }
  858. )
  859. res = await comfyui_edit_image(
  860. model,
  861. form_data,
  862. user.id,
  863. request.app.state.config.IMAGES_EDIT_COMFYUI_BASE_URL,
  864. request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY,
  865. )
  866. log.debug(f"res: {res}")
  867. image_urls = set()
  868. for image in res["data"]:
  869. image_urls.add(image["url"])
  870. image_urls = list(image_urls)
  871. # Prioritize output type URLs if available
  872. output_type_urls = [url for url in image_urls if "type=output" in url]
  873. if output_type_urls:
  874. image_urls = output_type_urls
  875. log.debug(f"Image URLs: {image_urls}")
  876. images = []
  877. for image_url in image_urls:
  878. headers = None
  879. if request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY:
  880. headers = {
  881. "Authorization": f"Bearer {request.app.state.config.IMAGES_EDIT_COMFYUI_API_KEY}"
  882. }
  883. image_data, content_type = get_image_data(image_url, headers)
  884. url = upload_image(
  885. request,
  886. image_data,
  887. content_type,
  888. form_data.model_dump(exclude_none=True),
  889. user,
  890. )
  891. images.append({"url": url})
  892. return images
  893. except Exception as e:
  894. error = e
  895. if r != None:
  896. data = r.text
  897. try:
  898. data = json.loads(data)
  899. if "error" in data:
  900. error = data["error"]["message"]
  901. except Exception:
  902. error = data
  903. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))