images.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697
  1. import asyncio
  2. import base64
  3. import io
  4. import json
  5. import logging
  6. import mimetypes
  7. import re
  8. from pathlib import Path
  9. from typing import Optional
  10. from urllib.parse import quote
  11. import requests
  12. from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
  13. from open_webui.config import CACHE_DIR
  14. from open_webui.constants import ERROR_MESSAGES
  15. from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
  16. from open_webui.routers.files import upload_file
  17. from open_webui.utils.auth import get_admin_user, get_verified_user
  18. from open_webui.utils.images.comfyui import (
  19. ComfyUIGenerateImageForm,
  20. ComfyUIWorkflow,
  21. comfyui_generate_image,
  22. )
  23. from pydantic import BaseModel
  # FastAPI router collecting every image endpoint defined in this module.
  router = APIRouter()

  # Cache directory for image generations (<CACHE_DIR>/image/generations),
  # created at import time if it does not exist yet.
  IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
  IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)

  # Module logger; verbosity comes from the "IMAGES" entry of SRC_LOG_LEVELS.
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  @router.get("/config")
  async def get_config(request: Request, user=Depends(get_admin_user)):
      """Return the image-generation settings for every engine (admin only).

      The payload contains the global switches (enabled, engine, prompt
      generation) plus one sub-dict of connection settings per engine.
      """
      cfg = request.app.state.config

      openai_settings = {
          "OPENAI_API_BASE_URL": cfg.IMAGES_OPENAI_API_BASE_URL,
          "OPENAI_API_KEY": cfg.IMAGES_OPENAI_API_KEY,
      }
      a1111_settings = {
          "AUTOMATIC1111_BASE_URL": cfg.AUTOMATIC1111_BASE_URL,
          "AUTOMATIC1111_API_AUTH": cfg.AUTOMATIC1111_API_AUTH,
          "AUTOMATIC1111_CFG_SCALE": cfg.AUTOMATIC1111_CFG_SCALE,
          "AUTOMATIC1111_SAMPLER": cfg.AUTOMATIC1111_SAMPLER,
          "AUTOMATIC1111_SCHEDULER": cfg.AUTOMATIC1111_SCHEDULER,
      }
      comfyui_settings = {
          "COMFYUI_BASE_URL": cfg.COMFYUI_BASE_URL,
          "COMFYUI_API_KEY": cfg.COMFYUI_API_KEY,
          "COMFYUI_WORKFLOW": cfg.COMFYUI_WORKFLOW,
          "COMFYUI_WORKFLOW_NODES": cfg.COMFYUI_WORKFLOW_NODES,
      }
      gemini_settings = {
          "GEMINI_API_BASE_URL": cfg.IMAGES_GEMINI_API_BASE_URL,
          "GEMINI_API_KEY": cfg.IMAGES_GEMINI_API_KEY,
      }

      return {
          "enabled": cfg.ENABLE_IMAGE_GENERATION,
          "engine": cfg.IMAGE_GENERATION_ENGINE,
          "prompt_generation": cfg.ENABLE_IMAGE_PROMPT_GENERATION,
          "openai": openai_settings,
          "automatic1111": a1111_settings,
          "comfyui": comfyui_settings,
          "gemini": gemini_settings,
      }
  class OpenAIConfigForm(BaseModel):
      # Connection settings for an OpenAI-compatible images API.
      # (Comments instead of a docstring: Pydantic would publish a docstring
      # as the schema description.)
      OPENAI_API_BASE_URL: str  # base URL; "/images/generations" is appended
      OPENAI_API_KEY: str  # sent as a Bearer token
  class Automatic1111ConfigForm(BaseModel):
      # Connection and sampling settings for an AUTOMATIC1111 server.
      AUTOMATIC1111_BASE_URL: str
      AUTOMATIC1111_API_AUTH: str  # credentials sent as HTTP Basic auth
      # Required-but-nullable fields: they must be present, but may be null.
      AUTOMATIC1111_CFG_SCALE: str | float | int | None  # converted to float on save
      AUTOMATIC1111_SAMPLER: str | None
      AUTOMATIC1111_SCHEDULER: str | None
  class ComfyUIConfigForm(BaseModel):
      # Connection settings and workflow definition for a ComfyUI server.
      COMFYUI_BASE_URL: str
      COMFYUI_API_KEY: str  # sent as a Bearer token when non-empty
      COMFYUI_WORKFLOW: str  # workflow as a JSON string (parsed with json.loads)
      # Node descriptors; entries are read via node["type"] / node["node_ids"].
      COMFYUI_WORKFLOW_NODES: list[dict]
  class GeminiConfigForm(BaseModel):
      # Connection settings for the Gemini / Imagen predict API.
      GEMINI_API_BASE_URL: str  # "/models/{model}:predict" is appended
      GEMINI_API_KEY: str  # sent in the x-goog-api-key header
  class ConfigForm(BaseModel):
      # Full payload of POST /config/update: global switches + per-engine settings.
      enabled: bool  # master switch for image generation
      engine: str  # "openai", "gemini", "comfyui", "automatic1111" (or "")
      prompt_generation: bool  # ENABLE_IMAGE_PROMPT_GENERATION
      openai: OpenAIConfigForm
      automatic1111: Automatic1111ConfigForm
      comfyui: ComfyUIConfigForm
      gemini: GeminiConfigForm
  @router.post("/config/update")
  async def update_config(
      request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
  ):
      """Apply new image-generation settings and return the resulting config.

      All validation happens before any setting is written, so a rejected
      request leaves the stored configuration untouched.

      Returns:
          The same payload as ``GET /config``.

      Raises:
          HTTPException(400): if ``AUTOMATIC1111_CFG_SCALE`` is not a number.
      """
      config = request.app.state.config
      a1111 = form_data.automatic1111
      comfyui = form_data.comfyui

      # Parse the CFG scale up front: float() on a bad string used to raise a
      # ValueError (HTTP 500) after half of the settings had been overwritten.
      cfg_scale = None
      if a1111.AUTOMATIC1111_CFG_SCALE:
          try:
              cfg_scale = float(a1111.AUTOMATIC1111_CFG_SCALE)
          except (TypeError, ValueError) as e:
              raise HTTPException(
                  status_code=400,
                  detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 7.5)."),
              ) from e

      config.IMAGE_GENERATION_ENGINE = form_data.engine
      config.ENABLE_IMAGE_GENERATION = form_data.enabled
      config.ENABLE_IMAGE_PROMPT_GENERATION = form_data.prompt_generation

      config.IMAGES_OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
      config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY

      config.IMAGES_GEMINI_API_BASE_URL = form_data.gemini.GEMINI_API_BASE_URL
      config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY

      config.AUTOMATIC1111_BASE_URL = a1111.AUTOMATIC1111_BASE_URL
      config.AUTOMATIC1111_API_AUTH = a1111.AUTOMATIC1111_API_AUTH
      config.AUTOMATIC1111_CFG_SCALE = cfg_scale
      # Empty values are stored as None; the generation code only sends these
      # options when they are truthy.
      config.AUTOMATIC1111_SAMPLER = a1111.AUTOMATIC1111_SAMPLER or None
      config.AUTOMATIC1111_SCHEDULER = a1111.AUTOMATIC1111_SCHEDULER or None

      # Only trailing slashes are dropped (URLs are built as f"{base}/path");
      # strip("/") would also remove leading ones.
      config.COMFYUI_BASE_URL = comfyui.COMFYUI_BASE_URL.rstrip("/")
      config.COMFYUI_API_KEY = comfyui.COMFYUI_API_KEY
      config.COMFYUI_WORKFLOW = comfyui.COMFYUI_WORKFLOW
      config.COMFYUI_WORKFLOW_NODES = comfyui.COMFYUI_WORKFLOW_NODES

      # Reuse the GET handler so both endpoints always return the same shape.
      return await get_config(request, user=user)
  def get_automatic1111_api_auth(request: "Request") -> str:
      """Build the ``authorization`` header value for the AUTOMATIC1111 API.

      ``AUTOMATIC1111_API_AUTH`` holds the credentials that are sent with HTTP
      Basic auth (presumably ``user:password`` -- confirm against A1111's
      ``--api-auth`` format).

      Returns:
          ``"Basic <base64(credentials)>"``, or ``""`` when no credentials are
          configured (None or empty string).
      """
      credentials = request.app.state.config.AUTOMATIC1111_API_AUTH
      # An empty string means "no auth" just like None; encoding it would send
      # a bare "Basic " header carrying no credentials.
      if not credentials:
          return ""
      token = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
      return f"Basic {token}"
  @router.get("/config/url/verify")
  async def verify_url(request: Request, user=Depends(get_admin_user)):
      """Check that the configured AUTOMATIC1111 or ComfyUI server is reachable.

      AUTOMATIC1111 is probed at ``/sdapi/v1/options`` and ComfyUI at
      ``/object_info``. Other engines are not probed.

      Returns:
          True if the probe succeeded or no probe applies to the engine.

      Raises:
          HTTPException(400): the probe failed; image generation is also
              disabled in that case.
      """
      config = request.app.state.config
      engine = config.IMAGE_GENERATION_ENGINE
      if engine not in ("automatic1111", "comfyui"):
          return True

      try:
          if engine == "automatic1111":
              url = f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
              headers = {"authorization": get_automatic1111_api_auth(request)}
          else:
              url = f"{config.COMFYUI_BASE_URL}/object_info"
              headers = (
                  {"Authorization": f"Bearer {config.COMFYUI_API_KEY}"}
                  if config.COMFYUI_API_KEY
                  else None
              )
          # requests is blocking; run it in a worker thread so this async
          # handler does not stall the event loop (the generation endpoint
          # does the same for its outbound calls).
          r = await asyncio.to_thread(requests.get, url=url, headers=headers)
          r.raise_for_status()
      except Exception as e:
          config.ENABLE_IMAGE_GENERATION = False
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL) from e
      return True
  def set_image_model(request: Request, model: str):
      """Store *model* as the image model; for AUTOMATIC1111 also load it as checkpoint.

      For the AUTOMATIC1111 engine (or an empty engine setting), the server's
      options are fetched. If its ``sd_model_checkpoint`` differs from *model*,
      the options are posted back with the new checkpoint.

      Returns:
          The stored ``IMAGE_GENERATION_MODEL``.

      Raises:
          requests.HTTPError: if AUTOMATIC1111 answers either options call with
              an error status.
      """
      log.info(f"Setting image model to {model}")
      config = request.app.state.config
      config.IMAGE_GENERATION_MODEL = model

      if config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
          options_url = f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
          auth_headers = {"authorization": get_automatic1111_api_auth(request)}

          r = requests.get(url=options_url, headers=auth_headers)
          # Without this an error body would be parsed as the options dict.
          r.raise_for_status()
          options = r.json()

          if model != options["sd_model_checkpoint"]:
              options["sd_model_checkpoint"] = model
              r = requests.post(url=options_url, json=options, headers=auth_headers)
              # A rejected checkpoint switch used to be ignored silently.
              r.raise_for_status()

      return config.IMAGE_GENERATION_MODEL
  def get_image_model(request):
      """Return the image model in effect for the configured engine.

      For OpenAI, Gemini and ComfyUI this is ``IMAGE_GENERATION_MODEL``, or an
      engine-specific fallback when that setting is empty. For AUTOMATIC1111
      (also used when the engine setting is empty) the active checkpoint is
      queried from the server.

      Returns:
          The model identifier, or None for an unrecognised engine.

      Raises:
          HTTPException(400): if the AUTOMATIC1111 query fails; image
              generation is also disabled in that case.
      """
      config = request.app.state.config
      engine = config.IMAGE_GENERATION_ENGINE

      # Fallback used when IMAGE_GENERATION_MODEL is empty or None.
      default_models = {
          "openai": "dall-e-2",
          "gemini": "imagen-3.0-generate-002",
          "comfyui": "",
      }
      if engine in default_models:
          return config.IMAGE_GENERATION_MODEL or default_models[engine]

      if engine in ("automatic1111", ""):
          try:
              r = requests.get(
                  url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                  headers={"authorization": get_automatic1111_api_auth(request)},
              )
              # Surface HTTP errors directly instead of a KeyError on the
              # error body.
              r.raise_for_status()
              return r.json()["sd_model_checkpoint"]
          except Exception as e:
              config.ENABLE_IMAGE_GENERATION = False
              raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e

      return None
  class ImageConfigForm(BaseModel):
      # Payload of POST /image/config/update.
      MODEL: str  # image model / checkpoint identifier
      IMAGE_SIZE: str  # "WIDTHxHEIGHT" or "auto" (gpt-image-1 only)
      IMAGE_STEPS: int  # sampling steps; must be >= 0
  @router.get("/image/config")
  async def get_image_config(request: Request, user=Depends(get_admin_user)):
      """Return the current image model, size and step count (admin only)."""
      cfg = request.app.state.config
      return dict(
          MODEL=cfg.IMAGE_GENERATION_MODEL,
          IMAGE_SIZE=cfg.IMAGE_SIZE,
          IMAGE_STEPS=cfg.IMAGE_STEPS,
      )
  @router.post("/image/config/update")
  async def update_image_config(
      request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
  ):
      """Validate and apply the image model, size and step settings.

      Every input is validated before anything changes. A rejected request
      therefore leaves the stored settings untouched, and never triggers the
      checkpoint switch that ``set_image_model`` performs for AUTOMATIC1111.

      Returns:
          The stored ``MODEL``, ``IMAGE_SIZE`` and ``IMAGE_STEPS``.

      Raises:
          HTTPException(400): ``IMAGE_SIZE`` is ``auto`` with a model other than
              gpt-image-1, or is not ``WIDTHxHEIGHT``, or ``IMAGE_STEPS`` is
              negative.
      """
      size = form_data.IMAGE_SIZE
      if size == "auto" and form_data.MODEL != "gpt-image-1":
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT(
                  " (auto is only allowed with gpt-image-1)."
              ),
          )
      # fullmatch: re.match with "$" also accepted a trailing newline.
      if size != "auto" and not re.fullmatch(r"\d+x\d+", size):
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
          )
      if form_data.IMAGE_STEPS < 0:
          raise HTTPException(
              status_code=400,
              detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
          )

      # Only now mutate state (this may also switch the A1111 checkpoint).
      set_image_model(request, form_data.MODEL)
      config = request.app.state.config
      config.IMAGE_SIZE = size
      config.IMAGE_STEPS = form_data.IMAGE_STEPS

      return {
          "MODEL": config.IMAGE_GENERATION_MODEL,
          "IMAGE_SIZE": config.IMAGE_SIZE,
          "IMAGE_STEPS": config.IMAGE_STEPS,
      }
  @router.get("/models")
  def get_models(request: Request, user=Depends(get_verified_user)):
      """List the models selectable for the configured image engine.

      Returns:
          A list of ``{"id": ..., "name": ...}`` dicts, or None for an
          unrecognised engine.

      Raises:
          HTTPException(400): on any failure while querying ComfyUI or
              AUTOMATIC1111; image generation is also disabled in that case.
      """
      config = request.app.state.config
      engine = config.IMAGE_GENERATION_ENGINE
      try:
          if engine == "openai":
              return [
                  {"id": "dall-e-2", "name": "DALL·E 2"},
                  {"id": "dall-e-3", "name": "DALL·E 3"},
                  {"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
              ]
          elif engine == "gemini":
              return [
                  {"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
              ]
          elif engine == "comfyui":
              return _get_comfyui_models(config)
          elif engine in ("automatic1111", ""):
              r = requests.get(
                  url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
                  headers={"authorization": get_automatic1111_api_auth(request)},
              )
              r.raise_for_status()
              return [
                  {"id": model["title"], "name": model["model_name"]}
                  for model in r.json()
              ]
      except Exception as e:
          config.ENABLE_IMAGE_GENERATION = False
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
      return None


  def _get_comfyui_models(config):
      """Return the choices of the workflow's model-loader input from ComfyUI /object_info."""
      # Same auth rule as the other ComfyUI calls: only send a token if one is set.
      headers = (
          {"Authorization": f"Bearer {config.COMFYUI_API_KEY}"}
          if config.COMFYUI_API_KEY
          else None
      )
      r = requests.get(url=f"{config.COMFYUI_BASE_URL}/object_info", headers=headers)
      r.raise_for_status()
      info = r.json()
      workflow = json.loads(config.COMFYUI_WORKFLOW)

      # First node id of the first "model"-typed workflow node that has any ids.
      model_node_id = next(
          (
              node["node_ids"][0]
              for node in config.COMFYUI_WORKFLOW_NODES
              if node["type"] == "model" and node["node_ids"]
          ),
          None,
      )

      if model_node_id:
          class_type = workflow[model_node_id]["class_type"]
          log.debug(class_type)
          required_inputs = info[class_type]["input"]["required"]
          # The model list lives under the first required input whose name
          # contains "_name" (e.g. "ckpt_name").
          model_list_key = next(
              (key for key in required_inputs if "_name" in key), None
          )
          if model_list_key:
              return [
                  {"id": model, "name": model}
                  for model in required_inputs[model_list_key][0]
              ]

      # Fall back to the stock checkpoint loader's choices, also when the model
      # node exposes no "*_name" input (that case previously returned None).
      checkpoints = info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0]
      return [{"id": model, "name": model} for model in checkpoints]
  class GenerateImageForm(BaseModel):
      # Payload of POST /generations.
      model: str | None = None  # per-request model (used by the A1111 path)
      prompt: str
      size: str | None = None  # per-request size (used by the OpenAI path)
      n: int = 1  # number of images to generate
      negative_prompt: str | None = None
  def load_b64_image_data(b64_str):
      """Decode a base64-encoded image, optionally given as a ``data:`` URI.

      Args:
          b64_str: Raw base64, or ``data:<mime>;base64,<payload>``.

      Returns:
          ``(image_bytes, mime_type)``. The mime type defaults to ``image/png``
          when the input does not specify one. ``(None, None)`` if decoding
          fails; the error is logged.
      """
      try:
          if "," in b64_str:
              header, encoded = b64_str.split(",", 1)
              # removeprefix, not lstrip: lstrip("data:") strips any of the
              # characters d/a/t/: and mangles types like "application/...".
              mime_type = header.split(";")[0].removeprefix("data:") or "image/png"
              img_data = base64.b64decode(encoded)
          else:
              mime_type = "image/png"
              img_data = base64.b64decode(b64_str)
          return img_data, mime_type
      except Exception as e:
          log.exception(f"Error loading image data: {e}")
          return None, None
  def load_url_image_data(url, headers=None):
      """Download an image from *url*.

      Args:
          url: Location of the image.
          headers: Optional request headers (e.g. authorization) to send.

      Returns:
          ``(image_bytes, mime_type)`` on success. ``(None, None)`` if the
          request fails or the response is not an image. Callers unpack two
          values, so a bare ``None`` would raise TypeError at the call site.
      """
      try:
          r = requests.get(url, headers=headers)
          r.raise_for_status()
          # Drop parameters such as "; charset=..." so the type can later be
          # mapped to a file extension; tolerate a missing header.
          mime_type = r.headers.get("content-type", "").split(";")[0].strip()
          if mime_type.split("/")[0] == "image":
              return r.content, mime_type
          log.error("Url does not point to an image.")
          return None, None
      except Exception as e:
          log.exception(f"Error loading image: {e}")
          return None, None
  def upload_image(request, image_data, content_type, metadata, user):
      """Store generated image bytes through the files router and return its URL.

      Args:
          request: Incoming request; used for ``upload_file`` and URL building.
          image_data: Raw image bytes.
          content_type: MIME type of *image_data*; selects the file extension.
          metadata: Generation parameters stored alongside the file.
          user: Owner of the uploaded file.

      Returns:
          ``request.app.url_path_for("get_file_content_by_id", ...)`` for the
          new file.

      Raises:
          ValueError: if *image_data* is None (the image could not be loaded).
      """
      if image_data is None:
          raise ValueError("No image data to upload (the image could not be loaded).")

      # guess_extension returns None for unknown types; fall back to no
      # extension instead of producing "generated-imageNone".
      image_format = mimetypes.guess_extension(content_type or "") or ""
      file = UploadFile(
          file=io.BytesIO(image_data),
          filename=f"generated-image{image_format}",  # will be converted to a unique ID on upload_file
          headers={
              "content-type": content_type,
          },
      )
      file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
      url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
      return url
  def _image_url_headers(image_url, api_base_url, headers):
      """Return *headers* only if *image_url* is on the same host as *api_base_url*.

      Returned image URLs may point at another host (e.g. a storage bucket or
      CDN). Forwarding the API key's Authorization header there would leak it.
      """
      from urllib.parse import urlparse

      if urlparse(image_url).netloc == urlparse(api_base_url or "").netloc:
          return headers
      return None


  def _response_error_detail(response, default):
      """Return the upstream error message from *response*'s JSON body, else *default*."""
      if response is None:
          return default
      try:
          payload = response.json()
      except Exception:
          # Best effort only: a non-JSON error body must not mask the real error.
          return default
      error = payload.get("error") if isinstance(payload, dict) else None
      if isinstance(error, dict) and "message" in error:
          return error["message"]
      if isinstance(error, str):
          return error
      return default


  async def _generate_openai(request, form_data, user):
      """Generate images via an OpenAI-compatible ``/images/generations`` endpoint."""
      config = request.app.state.config
      headers = {
          "Authorization": f"Bearer {config.IMAGES_OPENAI_API_KEY}",
          "Content-Type": "application/json",
      }
      if ENABLE_FORWARD_USER_INFO_HEADERS:
          headers["X-OpenWebUI-User-Name"] = quote(user.name)
          headers["X-OpenWebUI-User-Id"] = quote(user.id)
          headers["X-OpenWebUI-User-Email"] = quote(user.email)
          headers["X-OpenWebUI-User-Role"] = quote(user.role)

      # get_image_model falls back to dall-e-2 for an empty/None setting
      # (None previously crashed the "gpt-image-1" membership test below).
      model = get_image_model(request)
      data = {
          "model": model,
          "prompt": form_data.prompt,
          "n": form_data.n,
          "size": form_data.size if form_data.size else config.IMAGE_SIZE,
          # response_format is not sent for gpt-image-1 models.
          **({} if "gpt-image-1" in model else {"response_format": "b64_json"}),
      }
      r = await asyncio.to_thread(
          requests.post,
          url=f"{config.IMAGES_OPENAI_API_BASE_URL}/images/generations",
          json=data,
          headers=headers,
      )
      r.raise_for_status()

      images = []
      for image in r.json()["data"]:
          if image_url := image.get("url", None):
              image_data, content_type = await asyncio.to_thread(
                  load_url_image_data,
                  image_url,
                  _image_url_headers(
                      image_url, config.IMAGES_OPENAI_API_BASE_URL, headers
                  ),
              )
          else:
              image_data, content_type = load_b64_image_data(image["b64_json"])
          url = upload_image(request, image_data, content_type, data, user)
          images.append({"url": url})
      return images


  async def _generate_gemini(request, form_data, user):
      """Generate images via the Gemini ``models/{model}:predict`` endpoint."""
      config = request.app.state.config
      headers = {
          "Content-Type": "application/json",
          "x-goog-api-key": config.IMAGES_GEMINI_API_KEY,
      }
      model = get_image_model(request)
      data = {
          "instances": {"prompt": form_data.prompt},
          "parameters": {
              "sampleCount": form_data.n,
              "outputOptions": {"mimeType": "image/png"},
          },
      }
      r = await asyncio.to_thread(
          requests.post,
          url=f"{config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
          json=data,
          headers=headers,
      )
      r.raise_for_status()

      images = []
      for image in r.json()["predictions"]:
          image_data, content_type = load_b64_image_data(image["bytesBase64Encoded"])
          url = upload_image(request, image_data, content_type, data, user)
          images.append({"url": url})
      return images


  async def _generate_comfyui(request, form_data, user, width, height):
      """Generate images by running the configured ComfyUI workflow."""
      config = request.app.state.config
      data = {
          "prompt": form_data.prompt,
          "width": width,
          "height": height,
          "n": form_data.n,
      }
      if config.IMAGE_STEPS is not None:
          data["steps"] = config.IMAGE_STEPS
      if form_data.negative_prompt is not None:
          data["negative_prompt"] = form_data.negative_prompt

      comfyui_form = ComfyUIGenerateImageForm(
          workflow=ComfyUIWorkflow(
              workflow=config.COMFYUI_WORKFLOW,
              nodes=config.COMFYUI_WORKFLOW_NODES,
          ),
          **data,
      )
      res = await comfyui_generate_image(
          config.IMAGE_GENERATION_MODEL,
          comfyui_form,
          user.id,
          config.COMFYUI_BASE_URL,
          config.COMFYUI_API_KEY,
      )
      log.debug(f"res: {res}")

      headers = (
          {"Authorization": f"Bearer {config.COMFYUI_API_KEY}"}
          if config.COMFYUI_API_KEY
          else None
      )
      images = []
      for image in res["data"]:
          image_data, content_type = await asyncio.to_thread(
              load_url_image_data, image["url"], headers
          )
          url = upload_image(
              request,
              image_data,
              content_type,
              comfyui_form.model_dump(exclude_none=True),
              user,
          )
          images.append({"url": url})
      return images


  async def _generate_automatic1111(request, form_data, user, width, height):
      """Generate images via AUTOMATIC1111's ``/sdapi/v1/txt2img`` endpoint."""
      config = request.app.state.config
      if form_data.model:
          set_image_model(request, form_data.model)

      data = {
          "prompt": form_data.prompt,
          "batch_size": form_data.n,
          "width": width,
          "height": height,
      }
      if config.IMAGE_STEPS is not None:
          data["steps"] = config.IMAGE_STEPS
      if form_data.negative_prompt is not None:
          data["negative_prompt"] = form_data.negative_prompt
      if config.AUTOMATIC1111_CFG_SCALE:
          data["cfg_scale"] = config.AUTOMATIC1111_CFG_SCALE
      if config.AUTOMATIC1111_SAMPLER:
          data["sampler_name"] = config.AUTOMATIC1111_SAMPLER
      if config.AUTOMATIC1111_SCHEDULER:
          data["scheduler"] = config.AUTOMATIC1111_SCHEDULER

      r = await asyncio.to_thread(
          requests.post,
          url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
          json=data,
          headers={"authorization": get_automatic1111_api_auth(request)},
      )
      r.raise_for_status()
      res = r.json()
      log.debug(f"res: {res}")

      images = []
      for image in res["images"]:
          image_data, content_type = load_b64_image_data(image)
          url = upload_image(
              request,
              image_data,
              content_type,
              {**data, "info": res["info"]},
              user,
          )
          images.append({"url": url})
      return images


  @router.post("/generations")
  async def image_generations(
      request: Request,
      form_data: GenerateImageForm,
      user=Depends(get_verified_user),
  ):
      """Generate images with the configured engine and store them as files.

      Returns:
          A list of ``{"url": ...}`` dicts, one per stored image, or None when
          the configured engine is not recognised.

      Raises:
          HTTPException(400): on any generation failure. The upstream JSON
              ``error.message`` is used as the detail when available.
      """
      config = request.app.state.config
      # IMAGE_SIZE 'auto' has no WIDTHxHEIGHT; fall back to 512x512. This only
      # matters for engines that need explicit dimensions (ComfyUI, A1111);
      # update_image_config rejects 'auto' for models other than gpt-image-1.
      width, height = (
          tuple(map(int, config.IMAGE_SIZE.split("x")))
          if "x" in config.IMAGE_SIZE
          else (512, 512)
      )

      engine = config.IMAGE_GENERATION_ENGINE
      try:
          if engine == "openai":
              return await _generate_openai(request, form_data, user)
          if engine == "gemini":
              return await _generate_gemini(request, form_data, user)
          if engine == "comfyui":
              return await _generate_comfyui(request, form_data, user, width, height)
          if engine in ("automatic1111", ""):
              return await _generate_automatic1111(
                  request, form_data, user, width, height
              )
      except Exception as e:
          # requests' HTTPError carries the failing response; prefer the
          # upstream error message when its body provides one.
          error = _response_error_detail(getattr(e, "response", None), e)
          raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)) from e
      return None