images.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681
  1. import asyncio
  2. import base64
  3. import io
  4. import json
  5. import logging
  6. import mimetypes
  7. import re
  8. from pathlib import Path
  9. from typing import Optional
  10. import requests
  11. from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
  12. from open_webui.config import CACHE_DIR
  13. from open_webui.constants import ERROR_MESSAGES
  14. from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
  15. from open_webui.routers.files import upload_file
  16. from open_webui.utils.auth import get_admin_user, get_verified_user
  17. from open_webui.utils.images.comfyui import (
  18. ComfyUIGenerateImageForm,
  19. ComfyUIWorkflow,
  20. comfyui_generate_image,
  21. )
  22. from pydantic import BaseModel
  # Module logger, set to the level configured for the "IMAGES" source.
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["IMAGES"])

  # Cache directory for image generations; created eagerly at import time.
  IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
  IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)

  # Router that every image endpoint below is registered on.
  router = APIRouter()
  @router.get("/config")
  async def get_config(request: Request, user=Depends(get_admin_user)):
      """Return the image-generation settings, grouped per engine.

      Admin only (``get_admin_user``). The payload includes the stored API keys
      and AUTOMATIC1111 credentials.
      """
      cfg = request.app.state.config

      openai_settings = {
          "OPENAI_API_BASE_URL": cfg.IMAGES_OPENAI_API_BASE_URL,
          "OPENAI_API_KEY": cfg.IMAGES_OPENAI_API_KEY,
      }
      automatic1111_settings = {
          "AUTOMATIC1111_BASE_URL": cfg.AUTOMATIC1111_BASE_URL,
          "AUTOMATIC1111_API_AUTH": cfg.AUTOMATIC1111_API_AUTH,
          "AUTOMATIC1111_CFG_SCALE": cfg.AUTOMATIC1111_CFG_SCALE,
          "AUTOMATIC1111_SAMPLER": cfg.AUTOMATIC1111_SAMPLER,
          "AUTOMATIC1111_SCHEDULER": cfg.AUTOMATIC1111_SCHEDULER,
      }
      comfyui_settings = {
          "COMFYUI_BASE_URL": cfg.COMFYUI_BASE_URL,
          "COMFYUI_API_KEY": cfg.COMFYUI_API_KEY,
          "COMFYUI_WORKFLOW": cfg.COMFYUI_WORKFLOW,
          "COMFYUI_WORKFLOW_NODES": cfg.COMFYUI_WORKFLOW_NODES,
      }
      gemini_settings = {
          "GEMINI_API_BASE_URL": cfg.IMAGES_GEMINI_API_BASE_URL,
          "GEMINI_API_KEY": cfg.IMAGES_GEMINI_API_KEY,
      }

      return {
          "enabled": cfg.ENABLE_IMAGE_GENERATION,
          "engine": cfg.IMAGE_GENERATION_ENGINE,
          "prompt_generation": cfg.ENABLE_IMAGE_PROMPT_GENERATION,
          "openai": openai_settings,
          "automatic1111": automatic1111_settings,
          "comfyui": comfyui_settings,
          "gemini": gemini_settings,
      }
  class OpenAIConfigForm(BaseModel):
      # OpenAI-compatible image API settings (stored as IMAGES_OPENAI_*).
      OPENAI_API_BASE_URL: str
      OPENAI_API_KEY: str  # sent as a Bearer token on generation requests
  class Automatic1111ConfigForm(BaseModel):
      # AUTOMATIC1111 connection and sampling settings.
      AUTOMATIC1111_BASE_URL: str
      AUTOMATIC1111_API_AUTH: str  # base64-encoded into an HTTP Basic header
      AUTOMATIC1111_CFG_SCALE: str | float | int | None  # float() when truthy
      AUTOMATIC1111_SAMPLER: str | None  # empty value is stored as None
      AUTOMATIC1111_SCHEDULER: str | None  # empty value is stored as None
  class ComfyUIConfigForm(BaseModel):
      # ComfyUI server settings and the workflow used for generation.
      COMFYUI_BASE_URL: str
      COMFYUI_API_KEY: str  # used as a Bearer token
      COMFYUI_WORKFLOW: str  # workflow JSON text, parsed with json.loads
      COMFYUI_WORKFLOW_NODES: list[dict]  # entries read via "type"/"node_ids"
  class GeminiConfigForm(BaseModel):
      # Gemini image API settings (stored as IMAGES_GEMINI_*).
      GEMINI_API_BASE_URL: str
      GEMINI_API_KEY: str  # sent in the x-goog-api-key header
  class ConfigForm(BaseModel):
      # Body of POST /config/update: global switches plus one section per engine.
      enabled: bool
      engine: str  # "openai", "gemini", "comfyui", "automatic1111" ("" = A1111)
      prompt_generation: bool
      openai: OpenAIConfigForm
      automatic1111: Automatic1111ConfigForm
      comfyui: ComfyUIConfigForm
      gemini: GeminiConfigForm
  @router.post("/config/update")
  async def update_config(
      request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
  ):
      """Store a new image-generation configuration and return it (admin only).

      The AUTOMATIC1111 CFG scale is parsed before any setting is written. A
      malformed value is therefore rejected with HTTP 400, instead of failing
      with a 500 after the configuration has been partially updated.

      Returns the same payload as ``GET /config``.

      Raises:
          HTTPException(400): if ``AUTOMATIC1111_CFG_SCALE`` is not numeric.
      """
      cfg = request.app.state.config
      a1111 = form_data.automatic1111

      # Parse up front: float() on a free-form string can raise ValueError.
      cfg_scale = None
      if a1111.AUTOMATIC1111_CFG_SCALE:
          try:
              cfg_scale = float(a1111.AUTOMATIC1111_CFG_SCALE)
          except (TypeError, ValueError):
              raise HTTPException(
                  status_code=400,
                  detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 7.5)."),
              ) from None

      cfg.IMAGE_GENERATION_ENGINE = form_data.engine
      cfg.ENABLE_IMAGE_GENERATION = form_data.enabled
      cfg.ENABLE_IMAGE_PROMPT_GENERATION = form_data.prompt_generation

      cfg.IMAGES_OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
      cfg.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY

      cfg.IMAGES_GEMINI_API_BASE_URL = form_data.gemini.GEMINI_API_BASE_URL
      cfg.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY

      cfg.AUTOMATIC1111_BASE_URL = a1111.AUTOMATIC1111_BASE_URL
      cfg.AUTOMATIC1111_API_AUTH = a1111.AUTOMATIC1111_API_AUTH
      cfg.AUTOMATIC1111_CFG_SCALE = cfg_scale
      # Empty sampler/scheduler values are stored as None ("not set").
      cfg.AUTOMATIC1111_SAMPLER = a1111.AUTOMATIC1111_SAMPLER or None
      cfg.AUTOMATIC1111_SCHEDULER = a1111.AUTOMATIC1111_SCHEDULER or None

      # Trim slashes so f"{COMFYUI_BASE_URL}/path" joins cleanly.
      cfg.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/")
      cfg.COMFYUI_API_KEY = form_data.comfyui.COMFYUI_API_KEY
      cfg.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
      cfg.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES

      # Echo the stored values back in exactly the shape GET /config returns.
      return await get_config(request, user)
  def get_automatic1111_api_auth(request: Request):
      """Return the ``authorization`` header value for AUTOMATIC1111 requests.

      ``AUTOMATIC1111_API_AUTH`` is UTF-8 encoded, base64 encoded and sent as
      HTTP Basic auth (presumably in ``username:password`` form -- confirm).
      Returns ``""`` when no credentials are configured (``None`` or empty).
      """
      credentials = request.app.state.config.AUTOMATIC1111_API_AUTH
      # The config form always submits a string, so "" means "not configured"
      # just like None; never send "Basic " with empty credentials.
      if not credentials:
          return ""
      token = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
      return f"Basic {token}"
  @router.get("/config/url/verify")
  async def verify_url(request: Request, user=Depends(get_admin_user)):
      """Check that the configured AUTOMATIC1111 or ComfyUI server responds.

      Returns ``True`` on success, and for engines that have no URL to check.
      On any failure, image generation is disabled and HTTP 400 is raised.
      """
      cfg = request.app.state.config
      engine = cfg.IMAGE_GENERATION_ENGINE
      if engine not in ("automatic1111", "comfyui"):
          return True

      try:
          if engine == "automatic1111":
              url = f"{cfg.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
              headers = {"authorization": get_automatic1111_api_auth(request)}
          else:
              url = f"{cfg.COMFYUI_BASE_URL}/object_info"
              headers = (
                  {"Authorization": f"Bearer {cfg.COMFYUI_API_KEY}"}
                  if cfg.COMFYUI_API_KEY
                  else None
              )
          # requests is blocking; run it in a worker thread so this async
          # handler does not stall the event loop while the server answers.
          r = await asyncio.to_thread(requests.get, url=url, headers=headers)
          r.raise_for_status()
      except Exception:
          cfg.ENABLE_IMAGE_GENERATION = False
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
      return True
  def set_image_model(request: Request, model: str):
      """Store *model* as the image model; for AUTOMATIC1111 also load it.

      For the ``"automatic1111"`` engine (and the empty default engine), the
      server's options are fetched. ``sd_model_checkpoint`` is posted back when
      it differs from *model*. These are blocking HTTP calls.

      Returns the stored ``IMAGE_GENERATION_MODEL``.

      Raises:
          requests.HTTPError: if the AUTOMATIC1111 options cannot be read.
      """
      log.info(f"Setting image model to {model}")
      cfg = request.app.state.config
      cfg.IMAGE_GENERATION_MODEL = model

      if cfg.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
          auth_headers = {"authorization": get_automatic1111_api_auth(request)}
          options_url = f"{cfg.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"

          r = requests.get(url=options_url, headers=auth_headers)
          # Surface HTTP failures directly instead of an opaque JSON/KeyError.
          r.raise_for_status()
          options = r.json()

          if model != options.get("sd_model_checkpoint"):
              options["sd_model_checkpoint"] = model
              r = requests.post(url=options_url, json=options, headers=auth_headers)
              if not r.ok:
                  # Still best effort (no exception), but no longer silent.
                  log.warning(
                      f"Failed to switch AUTOMATIC1111 checkpoint to {model}: "
                      f"HTTP {r.status_code}"
                  )

      return cfg.IMAGE_GENERATION_MODEL
  def get_image_model(request):
      """Return the effective image model name for the configured engine.

      - ``openai``: the configured model, else ``"dall-e-2"``.
      - ``gemini``: the configured model, else ``"imagen-3.0-generate-002"``.
      - ``comfyui``: the configured model, else ``""``.
      - ``automatic1111`` / ``""``: the server's current ``sd_model_checkpoint``.
        On failure, image generation is disabled and HTTPException(400) is raised.
      - any other engine: ``None``.
      """
      cfg = request.app.state.config
      engine = cfg.IMAGE_GENERATION_ENGINE

      if engine == "openai":
          return cfg.IMAGE_GENERATION_MODEL or "dall-e-2"
      if engine == "gemini":
          return cfg.IMAGE_GENERATION_MODEL or "imagen-3.0-generate-002"
      if engine == "comfyui":
          return cfg.IMAGE_GENERATION_MODEL or ""
      if engine in ("automatic1111", ""):
          try:
              r = requests.get(
                  url=f"{cfg.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                  headers={"authorization": get_automatic1111_api_auth(request)},
              )
              # Report HTTP failures as such rather than as a JSON/KeyError.
              r.raise_for_status()
              return r.json()["sd_model_checkpoint"]
          except Exception as e:
              cfg.ENABLE_IMAGE_GENERATION = False
              raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
      return None
  class ImageConfigForm(BaseModel):
      # Body of POST /image/config/update (same keys GET /image/config returns).
      MODEL: str
      IMAGE_SIZE: str  # "<width>x<height>", e.g. "512x512"; checked by the endpoint
      IMAGE_STEPS: int  # must be >= 0; checked by the endpoint
  @router.get("/image/config")
  async def get_image_config(request: Request, user=Depends(get_admin_user)):
      """Return the configured image model, image size and step count (admin only)."""
      cfg = request.app.state.config
      return dict(
          MODEL=cfg.IMAGE_GENERATION_MODEL,
          IMAGE_SIZE=cfg.IMAGE_SIZE,
          IMAGE_STEPS=cfg.IMAGE_STEPS,
      )
  @router.post("/image/config/update")
  async def update_image_config(
      request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
  ):
      """Validate and store the image model, size and step count (admin only).

      Size and steps are validated before anything changes. A rejected request
      therefore no longer leaves the model switched with the old size/steps.

      Returns the stored ``MODEL``, ``IMAGE_SIZE`` and ``IMAGE_STEPS``.

      Raises:
          HTTPException(400): if IMAGE_SIZE is not ``<width>x<height>`` (e.g.
              ``512x512``) or IMAGE_STEPS is negative.
      """
      # fullmatch: re.match with a "$" anchor also accepted a trailing newline.
      if not re.fullmatch(r"\d+x\d+", form_data.IMAGE_SIZE):
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
          )
      if form_data.IMAGE_STEPS < 0:
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
          )

      # set_image_model may make blocking HTTP calls to AUTOMATIC1111; run it
      # in a worker thread instead of on the event loop.
      await asyncio.to_thread(set_image_model, request, form_data.MODEL)

      cfg = request.app.state.config
      cfg.IMAGE_SIZE = form_data.IMAGE_SIZE
      cfg.IMAGE_STEPS = form_data.IMAGE_STEPS

      return {
          "MODEL": cfg.IMAGE_GENERATION_MODEL,
          "IMAGE_SIZE": cfg.IMAGE_SIZE,
          "IMAGE_STEPS": cfg.IMAGE_STEPS,
      }
  def _list_comfyui_models(config):
      """Return ``{"id", "name"}`` entries for the ComfyUI workflow's model choices.

      The choices come from ComfyUI's ``/object_info``. They are read from the
      first input whose name contains ``"_name"`` on the workflow's model node,
      falling back to ``CheckpointLoaderSimple``'s ``ckpt_name`` list.
      """
      # Authenticate only when a key is configured, like the other ComfyUI calls.
      headers = (
          {"Authorization": f"Bearer {config.COMFYUI_API_KEY}"}
          if config.COMFYUI_API_KEY
          else None
      )
      r = requests.get(url=f"{config.COMFYUI_BASE_URL}/object_info", headers=headers)
      r.raise_for_status()
      info = r.json()
      workflow = json.loads(config.COMFYUI_WORKFLOW)

      # First id of the first "model"-typed workflow node that lists any ids.
      model_node_id = next(
          (
              node["node_ids"][0]
              for node in config.COMFYUI_WORKFLOW_NODES
              if node["type"] == "model" and node["node_ids"]
          ),
          None,
      )

      if model_node_id:
          class_type = workflow[model_node_id]["class_type"]
          log.debug(f"ComfyUI model node class_type: {class_type}")
          required = info[class_type]["input"]["required"]
          # The model list is the input named like "*_name" (e.g. "ckpt_name").
          model_list_key = next((key for key in required if "_name" in key), None)
          if model_list_key:
              return [{"id": m, "name": m} for m in required[model_list_key][0]]

      # No usable model node: fall back to the stock checkpoint loader's list
      # (previously a model node without a "*_name" input returned None).
      checkpoints = info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0]
      return [{"id": m, "name": m} for m in checkpoints]


  @router.get("/models")
  def get_models(request: Request, user=Depends(get_verified_user)):
      """List the models available for the configured image engine.

      Returns a list of ``{"id": ..., "name": ...}`` dicts. OpenAI and Gemini
      use a fixed list; ComfyUI and AUTOMATIC1111 report their own models.
      Unknown engines return ``None``. On any failure, image generation is
      disabled and HTTP 400 is raised.
      """
      cfg = request.app.state.config
      engine = cfg.IMAGE_GENERATION_ENGINE
      try:
          if engine == "openai":
              return [
                  {"id": "dall-e-2", "name": "DALL·E 2"},
                  {"id": "dall-e-3", "name": "DALL·E 3"},
                  {"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
              ]
          if engine == "gemini":
              return [
                  {"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
              ]
          if engine == "comfyui":
              return _list_comfyui_models(cfg)
          if engine in ("automatic1111", ""):
              r = requests.get(
                  url=f"{cfg.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
                  headers={"authorization": get_automatic1111_api_auth(request)},
              )
              r.raise_for_status()
              return [
                  {"id": model["title"], "name": model["model_name"]}
                  for model in r.json()
              ]
      except Exception as e:
          cfg.ENABLE_IMAGE_GENERATION = False
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
      return None
  class GenerateImageForm(BaseModel):
      # Body of POST /generations.
      model: str | None = None  # read here only by the AUTOMATIC1111 path
      prompt: str
      size: str | None = None  # read here only by the OpenAI path
      n: int = 1  # number of images requested
      negative_prompt: str | None = None  # used by the ComfyUI and A1111 paths
  def load_b64_image_data(b64_str):
      """Decode a base64 image, optionally wrapped in a ``data:`` URL.

      Accepts ``data:<mime>;base64,<payload>`` or a bare base64 payload. A bare
      payload is reported as ``image/png``.

      Returns ``(bytes, mime_type)``, or ``(None, None)`` if decoding fails
      (the error is logged, not raised).
      """
      try:
          if "," in b64_str:
              header, encoded = b64_str.split(",", 1)
              # removeprefix, not lstrip: lstrip("data:") strips any leading
              # d/a/t/: characters, so "data:audio/wav" became "udio/wav".
              mime_type = header.split(";")[0].removeprefix("data:")
              img_data = base64.b64decode(encoded)
          else:
              mime_type = "image/png"
              img_data = base64.b64decode(b64_str)
          return img_data, mime_type
      except Exception as e:
          log.exception(f"Error loading image data: {e}")
          return None, None
  def load_url_image_data(url, headers=None):
      """Download an image from *url*, optionally sending extra *headers*.

      Returns ``(bytes, mime_type)`` when the response's content type is
      ``image/*``. Otherwise, or on any error, returns ``(None, None)`` (the
      problem is logged). The pair shape matches ``load_b64_image_data``, so
      callers that unpack two values no longer crash on failure.
      """
      try:
          # requests treats headers=None as "no extra headers".
          r = requests.get(url, headers=headers)
          r.raise_for_status()
          # Drop parameters such as "; charset=..." so the bare MIME type can be
          # mapped to a file extension; a missing header is not a KeyError.
          mime_type = r.headers.get("content-type", "").split(";")[0].strip()
          if mime_type.split("/")[0] == "image":
              return r.content, mime_type
          log.error("Url does not point to an image.")
          return None, None
      except Exception as e:
          log.exception(f"Error saving image: {e}")
          return None, None
  def upload_image(request, image_data, content_type, metadata, user):
      """Store generated image bytes as an internal file and return its URL.

      The file is passed to ``upload_file`` with *metadata* and *user*. The
      returned URL is the ``get_file_content_by_id`` route for the new file.

      Raises:
          ValueError: if *image_data* is None (the image loaders in this module
              return None when loading fails).
      """
      if image_data is None:
          raise ValueError("No image data to upload")

      # guess_extension returns None for unknown types; avoid a filename like
      # "generated-imageNone".
      image_format = mimetypes.guess_extension(content_type) or ""
      file = UploadFile(
          file=io.BytesIO(image_data),
          filename=f"generated-image{image_format}",  # will be converted to a unique ID on upload_file
          headers={
              "content-type": content_type,
          },
      )
      file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
      return request.app.url_path_for("get_file_content_by_id", id=file_item.id)
  def _upstream_error_message(response, fallback):
      """Best-effort error message from an upstream image-API response.

      Understands ``{"error": {"message": ...}}`` bodies and bodies whose
      ``error`` is a plain value (e.g. a string). Returns *fallback* when the
      body is not JSON or has no usable ``error`` entry.
      """
      try:
          payload = response.json()
      except ValueError:
          # Non-JSON body (e.g. an HTML error page).
          return fallback
      error = payload.get("error") if isinstance(payload, dict) else None
      if isinstance(error, dict):
          error = error.get("message")
      return error or fallback


  @router.post("/generations")
  async def image_generations(
      request: Request,
      form_data: GenerateImageForm,
      user=Depends(get_verified_user),
  ):
      """Generate images with the configured engine and store them as files.

      Returns a list of ``{"url": ...}`` dicts, one per generated image, each
      pointing at the uploaded file. Any failure during generation is reported
      as HTTPException(400). The message is the upstream API's error when the
      last upstream response carries one.
      """
      cfg = request.app.state.config
      width, height = map(int, cfg.IMAGE_SIZE.split("x"))
      engine = cfg.IMAGE_GENERATION_ENGINE
      r = None  # last upstream response, used for error reporting

      try:
          if engine == "openai":
              headers = {
                  "Authorization": f"Bearer {cfg.IMAGES_OPENAI_API_KEY}",
                  "Content-Type": "application/json",
              }
              if ENABLE_FORWARD_USER_INFO_HEADERS:
                  headers["X-OpenWebUI-User-Name"] = user.name
                  headers["X-OpenWebUI-User-Id"] = user.id
                  headers["X-OpenWebUI-User-Email"] = user.email
                  headers["X-OpenWebUI-User-Role"] = user.role

              # Falls back to "dall-e-2" when no model is configured; checking
              # this value below also avoids `"gpt-image-1" in None`.
              model = get_image_model(request)
              data = {
                  "model": model,
                  "prompt": form_data.prompt,
                  "n": form_data.n,
                  "size": form_data.size or cfg.IMAGE_SIZE,
                  # NOTE(review): response_format is omitted for gpt-image-1,
                  # presumably because that model does not accept it.
                  **(
                      {}
                      if "gpt-image-1" in model
                      else {"response_format": "b64_json"}
                  ),
              }
              # requests is blocking; run it in a worker thread.
              r = await asyncio.to_thread(
                  requests.post,
                  url=f"{cfg.IMAGES_OPENAI_API_BASE_URL}/images/generations",
                  json=data,
                  headers=headers,
              )
              r.raise_for_status()
              res = r.json()

              images = []
              for image in res["data"]:
                  if image_url := image.get("url", None):
                      # NOTE(review): this also sends the OpenAI Authorization
                      # (and any forwarded user-info) headers to the image host.
                      image_data, content_type = await asyncio.to_thread(
                          load_url_image_data, image_url, headers
                      )
                  else:
                      image_data, content_type = load_b64_image_data(
                          image["b64_json"]
                      )
                  url = upload_image(request, image_data, content_type, data, user)
                  images.append({"url": url})
              return images

          elif engine == "gemini":
              headers = {
                  "Content-Type": "application/json",
                  "x-goog-api-key": cfg.IMAGES_GEMINI_API_KEY,
              }
              model = get_image_model(request)
              data = {
                  "instances": {"prompt": form_data.prompt},
                  "parameters": {
                      "sampleCount": form_data.n,
                      "outputOptions": {"mimeType": "image/png"},
                  },
              }
              # requests is blocking; run it in a worker thread.
              r = await asyncio.to_thread(
                  requests.post,
                  url=f"{cfg.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
                  json=data,
                  headers=headers,
              )
              r.raise_for_status()
              res = r.json()

              images = []
              for image in res["predictions"]:
                  image_data, content_type = load_b64_image_data(
                      image["bytesBase64Encoded"]
                  )
                  url = upload_image(request, image_data, content_type, data, user)
                  images.append({"url": url})
              return images

          elif engine == "comfyui":
              data = {
                  "prompt": form_data.prompt,
                  "width": width,
                  "height": height,
                  "n": form_data.n,
              }
              if cfg.IMAGE_STEPS is not None:
                  data["steps"] = cfg.IMAGE_STEPS
              if form_data.negative_prompt is not None:
                  data["negative_prompt"] = form_data.negative_prompt

              # Use a separate name instead of rebinding form_data (the request).
              comfyui_form = ComfyUIGenerateImageForm(
                  workflow=ComfyUIWorkflow(
                      workflow=cfg.COMFYUI_WORKFLOW,
                      nodes=cfg.COMFYUI_WORKFLOW_NODES,
                  ),
                  **data,
              )
              res = await comfyui_generate_image(
                  cfg.IMAGE_GENERATION_MODEL,
                  comfyui_form,
                  user.id,
                  cfg.COMFYUI_BASE_URL,
                  cfg.COMFYUI_API_KEY,
              )
              log.debug(f"res: {res}")

              # Same headers for every download; only authenticate with a key.
              headers = (
                  {"Authorization": f"Bearer {cfg.COMFYUI_API_KEY}"}
                  if cfg.COMFYUI_API_KEY
                  else None
              )
              images = []
              for image in res["data"]:
                  image_data, content_type = await asyncio.to_thread(
                      load_url_image_data, image["url"], headers
                  )
                  url = upload_image(
                      request,
                      image_data,
                      content_type,
                      comfyui_form.model_dump(exclude_none=True),
                      user,
                  )
                  images.append({"url": url})
              return images

          elif engine in ("automatic1111", ""):
              if form_data.model:
                  # set_image_model makes blocking HTTP calls; keep them off
                  # the event loop.
                  await asyncio.to_thread(set_image_model, request, form_data.model)

              data = {
                  "prompt": form_data.prompt,
                  "batch_size": form_data.n,
                  "width": width,
                  "height": height,
              }
              if cfg.IMAGE_STEPS is not None:
                  data["steps"] = cfg.IMAGE_STEPS
              if form_data.negative_prompt is not None:
                  data["negative_prompt"] = form_data.negative_prompt
              if cfg.AUTOMATIC1111_CFG_SCALE:
                  data["cfg_scale"] = cfg.AUTOMATIC1111_CFG_SCALE
              if cfg.AUTOMATIC1111_SAMPLER:
                  data["sampler_name"] = cfg.AUTOMATIC1111_SAMPLER
              if cfg.AUTOMATIC1111_SCHEDULER:
                  data["scheduler"] = cfg.AUTOMATIC1111_SCHEDULER

              # requests is blocking; run it in a worker thread.
              r = await asyncio.to_thread(
                  requests.post,
                  url=f"{cfg.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                  json=data,
                  headers={"authorization": get_automatic1111_api_auth(request)},
              )
              r.raise_for_status()
              res = r.json()
              log.debug(f"res: {res}")

              images = []
              for image in res["images"]:
                  image_data, content_type = load_b64_image_data(image)
                  url = upload_image(
                      request,
                      image_data,
                      content_type,
                      {**data, "info": res["info"]},
                      user,
                  )
                  images.append({"url": url})
              return images
      except Exception as e:
          error = e
          if r is not None:
              # Must not raise itself: the body may be non-JSON, or "error" may
              # be a string rather than {"message": ...}.
              error = _upstream_error_message(r, error)
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))