1
0

images.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725
import asyncio
import base64
import io
import json
import logging
import mimetypes
import re
from pathlib import Path
from typing import Optional
from urllib.parse import quote, urlparse

import requests
from fastapi import (
    APIRouter,
    Depends,
    HTTPException,
    Request,
    UploadFile,
)
from pydantic import BaseModel

from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
from open_webui.routers.files import upload_file_handler
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.images.comfyui import (
    ComfyUIGenerateImageForm,
    ComfyUIWorkflow,
    comfyui_generate_image,
)
  30. log = logging.getLogger(__name__)
  31. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  32. IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
  33. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  34. router = APIRouter()
  35. @router.get("/config")
  36. async def get_config(request: Request, user=Depends(get_admin_user)):
  37. return {
  38. "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
  39. "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
  40. "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
  41. "openai": {
  42. "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
  43. "OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION,
  44. "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
  45. },
  46. "automatic1111": {
  47. "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
  48. "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
  49. "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
  50. "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
  51. "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
  52. },
  53. "comfyui": {
  54. "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
  55. "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
  56. "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
  57. "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  58. },
  59. "gemini": {
  60. "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
  61. "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
  62. },
  63. }
  64. class OpenAIConfigForm(BaseModel):
  65. OPENAI_API_BASE_URL: str
  66. OPENAI_API_VERSION: str
  67. OPENAI_API_KEY: str
  68. class Automatic1111ConfigForm(BaseModel):
  69. AUTOMATIC1111_BASE_URL: str
  70. AUTOMATIC1111_API_AUTH: str
  71. AUTOMATIC1111_CFG_SCALE: Optional[str | float | int]
  72. AUTOMATIC1111_SAMPLER: Optional[str]
  73. AUTOMATIC1111_SCHEDULER: Optional[str]
  74. class ComfyUIConfigForm(BaseModel):
  75. COMFYUI_BASE_URL: str
  76. COMFYUI_API_KEY: str
  77. COMFYUI_WORKFLOW: str
  78. COMFYUI_WORKFLOW_NODES: list[dict]
  79. class GeminiConfigForm(BaseModel):
  80. GEMINI_API_BASE_URL: str
  81. GEMINI_API_KEY: str
  82. class ConfigForm(BaseModel):
  83. enabled: bool
  84. engine: str
  85. prompt_generation: bool
  86. openai: OpenAIConfigForm
  87. automatic1111: Automatic1111ConfigForm
  88. comfyui: ComfyUIConfigForm
  89. gemini: GeminiConfigForm
  90. @router.post("/config/update")
  91. async def update_config(
  92. request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
  93. ):
  94. request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine
  95. request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
  96. request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = (
  97. form_data.prompt_generation
  98. )
  99. request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
  100. form_data.openai.OPENAI_API_BASE_URL
  101. )
  102. request.app.state.config.IMAGES_OPENAI_API_VERSION = (
  103. form_data.openai.OPENAI_API_VERSION
  104. )
  105. request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
  106. request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
  107. form_data.gemini.GEMINI_API_BASE_URL
  108. )
  109. request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
  110. request.app.state.config.AUTOMATIC1111_BASE_URL = (
  111. form_data.automatic1111.AUTOMATIC1111_BASE_URL
  112. )
  113. request.app.state.config.AUTOMATIC1111_API_AUTH = (
  114. form_data.automatic1111.AUTOMATIC1111_API_AUTH
  115. )
  116. request.app.state.config.AUTOMATIC1111_CFG_SCALE = (
  117. float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
  118. if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
  119. else None
  120. )
  121. request.app.state.config.AUTOMATIC1111_SAMPLER = (
  122. form_data.automatic1111.AUTOMATIC1111_SAMPLER
  123. if form_data.automatic1111.AUTOMATIC1111_SAMPLER
  124. else None
  125. )
  126. request.app.state.config.AUTOMATIC1111_SCHEDULER = (
  127. form_data.automatic1111.AUTOMATIC1111_SCHEDULER
  128. if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
  129. else None
  130. )
  131. request.app.state.config.COMFYUI_BASE_URL = (
  132. form_data.comfyui.COMFYUI_BASE_URL.strip("/")
  133. )
  134. request.app.state.config.COMFYUI_API_KEY = form_data.comfyui.COMFYUI_API_KEY
  135. request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
  136. request.app.state.config.COMFYUI_WORKFLOW_NODES = (
  137. form_data.comfyui.COMFYUI_WORKFLOW_NODES
  138. )
  139. return {
  140. "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
  141. "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
  142. "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION,
  143. "openai": {
  144. "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
  145. "OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION,
  146. "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
  147. },
  148. "automatic1111": {
  149. "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
  150. "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
  151. "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
  152. "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
  153. "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
  154. },
  155. "comfyui": {
  156. "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
  157. "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
  158. "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
  159. "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  160. },
  161. "gemini": {
  162. "GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
  163. "GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
  164. },
  165. }
  166. def get_automatic1111_api_auth(request: Request):
  167. if request.app.state.config.AUTOMATIC1111_API_AUTH is None:
  168. return ""
  169. else:
  170. auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode(
  171. "utf-8"
  172. )
  173. auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
  174. auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
  175. return f"Basic {auth1111_base64_encoded_string}"
  176. @router.get("/config/url/verify")
  177. async def verify_url(request: Request, user=Depends(get_admin_user)):
  178. if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111":
  179. try:
  180. r = requests.get(
  181. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  182. headers={"authorization": get_automatic1111_api_auth(request)},
  183. )
  184. r.raise_for_status()
  185. return True
  186. except Exception:
  187. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  188. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  189. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  190. headers = None
  191. if request.app.state.config.COMFYUI_API_KEY:
  192. headers = {
  193. "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
  194. }
  195. try:
  196. r = requests.get(
  197. url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
  198. headers=headers,
  199. )
  200. r.raise_for_status()
  201. return True
  202. except Exception:
  203. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  204. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  205. else:
  206. return True
  207. def set_image_model(request: Request, model: str):
  208. log.info(f"Setting image model to {model}")
  209. request.app.state.config.IMAGE_GENERATION_MODEL = model
  210. if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
  211. api_auth = get_automatic1111_api_auth(request)
  212. r = requests.get(
  213. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  214. headers={"authorization": api_auth},
  215. )
  216. options = r.json()
  217. if model != options["sd_model_checkpoint"]:
  218. options["sd_model_checkpoint"] = model
  219. r = requests.post(
  220. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  221. json=options,
  222. headers={"authorization": api_auth},
  223. )
  224. return request.app.state.config.IMAGE_GENERATION_MODEL
  225. def get_image_model(request):
  226. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  227. return (
  228. request.app.state.config.IMAGE_GENERATION_MODEL
  229. if request.app.state.config.IMAGE_GENERATION_MODEL
  230. else "dall-e-2"
  231. )
  232. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
  233. return (
  234. request.app.state.config.IMAGE_GENERATION_MODEL
  235. if request.app.state.config.IMAGE_GENERATION_MODEL
  236. else "imagen-3.0-generate-002"
  237. )
  238. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  239. return (
  240. request.app.state.config.IMAGE_GENERATION_MODEL
  241. if request.app.state.config.IMAGE_GENERATION_MODEL
  242. else ""
  243. )
  244. elif (
  245. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  246. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  247. ):
  248. try:
  249. r = requests.get(
  250. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  251. headers={"authorization": get_automatic1111_api_auth(request)},
  252. )
  253. options = r.json()
  254. return options["sd_model_checkpoint"]
  255. except Exception as e:
  256. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  257. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  258. class ImageConfigForm(BaseModel):
  259. MODEL: str
  260. IMAGE_SIZE: str
  261. IMAGE_STEPS: int
  262. @router.get("/image/config")
  263. async def get_image_config(request: Request, user=Depends(get_admin_user)):
  264. return {
  265. "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
  266. "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
  267. "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
  268. }
  269. @router.post("/image/config/update")
  270. async def update_image_config(
  271. request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
  272. ):
  273. set_image_model(request, form_data.MODEL)
  274. if form_data.IMAGE_SIZE == "auto" and form_data.MODEL != "gpt-image-1":
  275. raise HTTPException(
  276. status_code=400,
  277. detail=ERROR_MESSAGES.INCORRECT_FORMAT(
  278. " (auto is only allowed with gpt-image-1)."
  279. ),
  280. )
  281. pattern = r"^\d+x\d+$"
  282. if form_data.IMAGE_SIZE == "auto" or re.match(pattern, form_data.IMAGE_SIZE):
  283. request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
  284. else:
  285. raise HTTPException(
  286. status_code=400,
  287. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  288. )
  289. if form_data.IMAGE_STEPS >= 0:
  290. request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
  291. else:
  292. raise HTTPException(
  293. status_code=400,
  294. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  295. )
  296. return {
  297. "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
  298. "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
  299. "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
  300. }
  301. @router.get("/models")
  302. def get_models(request: Request, user=Depends(get_verified_user)):
  303. try:
  304. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  305. return [
  306. {"id": "dall-e-2", "name": "DALL·E 2"},
  307. {"id": "dall-e-3", "name": "DALL·E 3"},
  308. {"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
  309. ]
  310. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
  311. return [
  312. {"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
  313. ]
  314. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  315. # TODO - get models from comfyui
  316. headers = {
  317. "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
  318. }
  319. r = requests.get(
  320. url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
  321. headers=headers,
  322. )
  323. info = r.json()
  324. workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW)
  325. model_node_id = None
  326. for node in request.app.state.config.COMFYUI_WORKFLOW_NODES:
  327. if node["type"] == "model":
  328. if node["node_ids"]:
  329. model_node_id = node["node_ids"][0]
  330. break
  331. if model_node_id:
  332. model_list_key = None
  333. log.info(workflow[model_node_id]["class_type"])
  334. for key in info[workflow[model_node_id]["class_type"]]["input"][
  335. "required"
  336. ]:
  337. if "_name" in key:
  338. model_list_key = key
  339. break
  340. if model_list_key:
  341. return list(
  342. map(
  343. lambda model: {"id": model, "name": model},
  344. info[workflow[model_node_id]["class_type"]]["input"][
  345. "required"
  346. ][model_list_key][0],
  347. )
  348. )
  349. else:
  350. return list(
  351. map(
  352. lambda model: {"id": model, "name": model},
  353. info["CheckpointLoaderSimple"]["input"]["required"][
  354. "ckpt_name"
  355. ][0],
  356. )
  357. )
  358. elif (
  359. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  360. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  361. ):
  362. r = requests.get(
  363. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
  364. headers={"authorization": get_automatic1111_api_auth(request)},
  365. )
  366. models = r.json()
  367. return list(
  368. map(
  369. lambda model: {"id": model["title"], "name": model["model_name"]},
  370. models,
  371. )
  372. )
  373. except Exception as e:
  374. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  375. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  376. class GenerateImageForm(BaseModel):
  377. model: Optional[str] = None
  378. prompt: str
  379. size: Optional[str] = None
  380. n: int = 1
  381. negative_prompt: Optional[str] = None
  382. def load_b64_image_data(b64_str):
  383. try:
  384. if "," in b64_str:
  385. header, encoded = b64_str.split(",", 1)
  386. mime_type = header.split(";")[0].lstrip("data:")
  387. img_data = base64.b64decode(encoded)
  388. else:
  389. mime_type = "image/png"
  390. img_data = base64.b64decode(b64_str)
  391. return img_data, mime_type
  392. except Exception as e:
  393. log.exception(f"Error loading image data: {e}")
  394. return None, None
  395. def load_url_image_data(url, headers=None):
  396. try:
  397. if headers:
  398. r = requests.get(url, headers=headers)
  399. else:
  400. r = requests.get(url)
  401. r.raise_for_status()
  402. if r.headers["content-type"].split("/")[0] == "image":
  403. mime_type = r.headers["content-type"]
  404. return r.content, mime_type
  405. else:
  406. log.error("Url does not point to an image.")
  407. return None
  408. except Exception as e:
  409. log.exception(f"Error saving image: {e}")
  410. return None
  411. def upload_image(request, image_data, content_type, metadata, user):
  412. image_format = mimetypes.guess_extension(content_type)
  413. file = UploadFile(
  414. file=io.BytesIO(image_data),
  415. filename=f"generated-image{image_format}", # will be converted to a unique ID on upload_file
  416. headers={
  417. "content-type": content_type,
  418. },
  419. )
  420. file_item = upload_file_handler(
  421. request,
  422. file=file,
  423. metadata=metadata,
  424. process=False,
  425. user=user,
  426. )
  427. url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
  428. return url
  429. @router.post("/generations")
  430. async def image_generations(
  431. request: Request,
  432. form_data: GenerateImageForm,
  433. user=Depends(get_verified_user),
  434. ):
  435. # if IMAGE_SIZE = 'auto', default WidthxHeight to the 512x512 default
  436. # This is only relevant when the user has set IMAGE_SIZE to 'auto' with an
  437. # image model other than gpt-image-1, which is warned about on settings save
  438. size = "512x512"
  439. if (
  440. request.app.state.config.IMAGE_SIZE
  441. and "x" in request.app.state.config.IMAGE_SIZE
  442. ):
  443. size = request.app.state.config.IMAGE_SIZE
  444. if form_data.size and "x" in form_data.size:
  445. size = form_data.size
  446. width, height = tuple(map(int, size.split("x")))
  447. model = get_image_model(request)
  448. r = None
  449. try:
  450. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  451. headers = {}
  452. headers["Authorization"] = (
  453. f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}"
  454. )
  455. headers["Content-Type"] = "application/json"
  456. if ENABLE_FORWARD_USER_INFO_HEADERS:
  457. headers["X-OpenWebUI-User-Name"] = quote(user.name, safe=" ")
  458. headers["X-OpenWebUI-User-Id"] = user.id
  459. headers["X-OpenWebUI-User-Email"] = user.email
  460. headers["X-OpenWebUI-User-Role"] = user.role
  461. data = {
  462. "model": model,
  463. "prompt": form_data.prompt,
  464. "n": form_data.n,
  465. "size": (
  466. form_data.size
  467. if form_data.size
  468. else request.app.state.config.IMAGE_SIZE
  469. ),
  470. **(
  471. {}
  472. if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL
  473. else {"response_format": "b64_json"}
  474. ),
  475. }
  476. api_version_query_param = ""
  477. if request.app.state.config.IMAGES_OPENAI_API_VERSION:
  478. api_version_query_param = (
  479. f"?api-version={request.app.state.config.IMAGES_OPENAI_API_VERSION}"
  480. )
  481. # Use asyncio.to_thread for the requests.post call
  482. r = await asyncio.to_thread(
  483. requests.post,
  484. url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations{api_version_query_param}",
  485. json=data,
  486. headers=headers,
  487. )
  488. r.raise_for_status()
  489. res = r.json()
  490. images = []
  491. for image in res["data"]:
  492. if image_url := image.get("url", None):
  493. image_data, content_type = load_url_image_data(image_url, headers)
  494. else:
  495. image_data, content_type = load_b64_image_data(image["b64_json"])
  496. url = upload_image(request, image_data, content_type, data, user)
  497. images.append({"url": url})
  498. return images
  499. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
  500. headers = {}
  501. headers["Content-Type"] = "application/json"
  502. headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
  503. data = {
  504. "instances": {"prompt": form_data.prompt},
  505. "parameters": {
  506. "sampleCount": form_data.n,
  507. "outputOptions": {"mimeType": "image/png"},
  508. },
  509. }
  510. # Use asyncio.to_thread for the requests.post call
  511. r = await asyncio.to_thread(
  512. requests.post,
  513. url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
  514. json=data,
  515. headers=headers,
  516. )
  517. r.raise_for_status()
  518. res = r.json()
  519. images = []
  520. for image in res["predictions"]:
  521. image_data, content_type = load_b64_image_data(
  522. image["bytesBase64Encoded"]
  523. )
  524. url = upload_image(request, image_data, content_type, data, user)
  525. images.append({"url": url})
  526. return images
  527. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  528. data = {
  529. "prompt": form_data.prompt,
  530. "width": width,
  531. "height": height,
  532. "n": form_data.n,
  533. }
  534. if request.app.state.config.IMAGE_STEPS is not None:
  535. data["steps"] = request.app.state.config.IMAGE_STEPS
  536. if form_data.negative_prompt is not None:
  537. data["negative_prompt"] = form_data.negative_prompt
  538. form_data = ComfyUIGenerateImageForm(
  539. **{
  540. "workflow": ComfyUIWorkflow(
  541. **{
  542. "workflow": request.app.state.config.COMFYUI_WORKFLOW,
  543. "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  544. }
  545. ),
  546. **data,
  547. }
  548. )
  549. res = await comfyui_generate_image(
  550. model,
  551. form_data,
  552. user.id,
  553. request.app.state.config.COMFYUI_BASE_URL,
  554. request.app.state.config.COMFYUI_API_KEY,
  555. )
  556. log.debug(f"res: {res}")
  557. images = []
  558. for image in res["data"]:
  559. headers = None
  560. if request.app.state.config.COMFYUI_API_KEY:
  561. headers = {
  562. "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
  563. }
  564. image_data, content_type = load_url_image_data(image["url"], headers)
  565. url = upload_image(
  566. request,
  567. image_data,
  568. content_type,
  569. form_data.model_dump(exclude_none=True),
  570. user,
  571. )
  572. images.append({"url": url})
  573. return images
  574. elif (
  575. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  576. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  577. ):
  578. if form_data.model:
  579. set_image_model(request, form_data.model)
  580. data = {
  581. "prompt": form_data.prompt,
  582. "batch_size": form_data.n,
  583. "width": width,
  584. "height": height,
  585. }
  586. if request.app.state.config.IMAGE_STEPS is not None:
  587. data["steps"] = request.app.state.config.IMAGE_STEPS
  588. if form_data.negative_prompt is not None:
  589. data["negative_prompt"] = form_data.negative_prompt
  590. if request.app.state.config.AUTOMATIC1111_CFG_SCALE:
  591. data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE
  592. if request.app.state.config.AUTOMATIC1111_SAMPLER:
  593. data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER
  594. if request.app.state.config.AUTOMATIC1111_SCHEDULER:
  595. data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER
  596. # Use asyncio.to_thread for the requests.post call
  597. r = await asyncio.to_thread(
  598. requests.post,
  599. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  600. json=data,
  601. headers={"authorization": get_automatic1111_api_auth(request)},
  602. )
  603. res = r.json()
  604. log.debug(f"res: {res}")
  605. images = []
  606. for image in res["images"]:
  607. image_data, content_type = load_b64_image_data(image)
  608. url = upload_image(
  609. request,
  610. image_data,
  611. content_type,
  612. {**data, "info": res["info"]},
  613. user,
  614. )
  615. images.append({"url": url})
  616. return images
  617. except Exception as e:
  618. error = e
  619. if r != None:
  620. data = r.json()
  621. if "error" in data:
  622. error = data["error"]["message"]
  623. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))