images.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. import asyncio
  2. import base64
  3. import json
  4. import logging
  5. import mimetypes
  6. import re
  7. import uuid
  8. from pathlib import Path
  9. from typing import Optional
  10. import requests
  11. from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
  12. from fastapi.middleware.cors import CORSMiddleware
  13. from pydantic import BaseModel
  14. from open_webui.config import CACHE_DIR
  15. from open_webui.constants import ERROR_MESSAGES
  16. from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
  17. from open_webui.utils.auth import get_admin_user, get_verified_user
  18. from open_webui.utils.images.comfyui import (
  19. ComfyUIGenerateImageForm,
  20. ComfyUIWorkflow,
  21. comfyui_generate_image,
  22. )
  23. log = logging.getLogger(__name__)
  24. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  25. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  26. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  27. router = APIRouter()
  28. @router.get("/config")
  29. async def get_config(request: Request, user=Depends(get_admin_user)):
  30. return {
  31. "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
  32. "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
  33. "openai": {
  34. "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
  35. "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
  36. },
  37. "automatic1111": {
  38. "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
  39. "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
  40. "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
  41. "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
  42. "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
  43. },
  44. "comfyui": {
  45. "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
  46. "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
  47. "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  48. },
  49. }
  50. class OpenAIConfigForm(BaseModel):
  51. OPENAI_API_BASE_URL: str
  52. OPENAI_API_KEY: str
  53. class Automatic1111ConfigForm(BaseModel):
  54. AUTOMATIC1111_BASE_URL: str
  55. AUTOMATIC1111_API_AUTH: str
  56. AUTOMATIC1111_CFG_SCALE: Optional[str | float | int]
  57. AUTOMATIC1111_SAMPLER: Optional[str]
  58. AUTOMATIC1111_SCHEDULER: Optional[str]
  59. class ComfyUIConfigForm(BaseModel):
  60. COMFYUI_BASE_URL: str
  61. COMFYUI_WORKFLOW: str
  62. COMFYUI_WORKFLOW_NODES: list[dict]
  63. class ConfigForm(BaseModel):
  64. enabled: bool
  65. engine: str
  66. openai: OpenAIConfigForm
  67. automatic1111: Automatic1111ConfigForm
  68. comfyui: ComfyUIConfigForm
  69. @router.post("/config/update")
  70. async def update_config(
  71. request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
  72. ):
  73. request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine
  74. request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled
  75. request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
  76. form_data.openai.OPENAI_API_BASE_URL
  77. )
  78. request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
  79. request.app.state.config.AUTOMATIC1111_BASE_URL = (
  80. form_data.automatic1111.AUTOMATIC1111_BASE_URL
  81. )
  82. request.app.state.config.AUTOMATIC1111_API_AUTH = (
  83. form_data.automatic1111.AUTOMATIC1111_API_AUTH
  84. )
  85. request.app.state.config.AUTOMATIC1111_CFG_SCALE = (
  86. float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
  87. if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
  88. else None
  89. )
  90. request.app.state.config.AUTOMATIC1111_SAMPLER = (
  91. form_data.automatic1111.AUTOMATIC1111_SAMPLER
  92. if form_data.automatic1111.AUTOMATIC1111_SAMPLER
  93. else None
  94. )
  95. request.app.state.config.AUTOMATIC1111_SCHEDULER = (
  96. form_data.automatic1111.AUTOMATIC1111_SCHEDULER
  97. if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
  98. else None
  99. )
  100. request.app.state.config.COMFYUI_BASE_URL = (
  101. form_data.comfyui.COMFYUI_BASE_URL.strip("/")
  102. )
  103. request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
  104. request.app.state.config.COMFYUI_WORKFLOW_NODES = (
  105. form_data.comfyui.COMFYUI_WORKFLOW_NODES
  106. )
  107. return {
  108. "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
  109. "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
  110. "openai": {
  111. "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
  112. "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
  113. },
  114. "automatic1111": {
  115. "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
  116. "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
  117. "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
  118. "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
  119. "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
  120. },
  121. "comfyui": {
  122. "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
  123. "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
  124. "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  125. },
  126. }
  127. def get_automatic1111_api_auth(request: Request):
  128. if request.app.state.config.AUTOMATIC1111_API_AUTH is None:
  129. return ""
  130. else:
  131. auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode(
  132. "utf-8"
  133. )
  134. auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
  135. auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
  136. return f"Basic {auth1111_base64_encoded_string}"
  137. @router.get("/config/url/verify")
  138. async def verify_url(request: Request, user=Depends(get_admin_user)):
  139. if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111":
  140. try:
  141. r = requests.get(
  142. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  143. headers={"authorization": get_automatic1111_api_auth(request)},
  144. )
  145. r.raise_for_status()
  146. return True
  147. except Exception:
  148. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  149. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  150. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  151. try:
  152. r = requests.get(
  153. url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
  154. )
  155. r.raise_for_status()
  156. return True
  157. except Exception:
  158. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  159. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  160. else:
  161. return True
  162. def set_image_model(request: Request, model: str):
  163. log.info(f"Setting image model to {model}")
  164. request.app.state.config.MODEL = model
  165. if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
  166. api_auth = get_automatic1111_api_auth()
  167. r = requests.get(
  168. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  169. headers={"authorization": api_auth},
  170. )
  171. options = r.json()
  172. if model != options["sd_model_checkpoint"]:
  173. options["sd_model_checkpoint"] = model
  174. r = requests.post(
  175. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  176. json=options,
  177. headers={"authorization": api_auth},
  178. )
  179. return request.app.state.config.MODEL
  180. def get_image_model():
  181. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  182. return (
  183. request.app.state.config.MODEL
  184. if request.app.state.config.MODEL
  185. else "dall-e-2"
  186. )
  187. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  188. return request.app.state.config.MODEL if request.app.state.config.MODEL else ""
  189. elif (
  190. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  191. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  192. ):
  193. try:
  194. r = requests.get(
  195. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  196. headers={"authorization": get_automatic1111_api_auth()},
  197. )
  198. options = r.json()
  199. return options["sd_model_checkpoint"]
  200. except Exception as e:
  201. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  202. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  203. class ImageConfigForm(BaseModel):
  204. MODEL: str
  205. IMAGE_SIZE: str
  206. IMAGE_STEPS: int
  207. @router.get("/image/config")
  208. async def get_image_config(request: Request, user=Depends(get_admin_user)):
  209. return {
  210. "MODEL": request.app.state.config.MODEL,
  211. "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
  212. "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
  213. }
  214. @router.post("/image/config/update")
  215. async def update_image_config(
  216. request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
  217. ):
  218. set_image_model(request, form_data.MODEL)
  219. pattern = r"^\d+x\d+$"
  220. if re.match(pattern, form_data.IMAGE_SIZE):
  221. request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
  222. else:
  223. raise HTTPException(
  224. status_code=400,
  225. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  226. )
  227. if form_data.IMAGE_STEPS >= 0:
  228. request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
  229. else:
  230. raise HTTPException(
  231. status_code=400,
  232. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  233. )
  234. return {
  235. "MODEL": request.app.state.config.MODEL,
  236. "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
  237. "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
  238. }
  239. @router.get("/models")
  240. def get_models(request: Request, user=Depends(get_verified_user)):
  241. try:
  242. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  243. return [
  244. {"id": "dall-e-2", "name": "DALL·E 2"},
  245. {"id": "dall-e-3", "name": "DALL·E 3"},
  246. ]
  247. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  248. # TODO - get models from comfyui
  249. r = requests.get(
  250. url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
  251. )
  252. info = r.json()
  253. workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW)
  254. model_node_id = None
  255. for node in request.app.state.config.COMFYUI_WORKFLOW_NODES:
  256. if node["type"] == "model":
  257. if node["node_ids"]:
  258. model_node_id = node["node_ids"][0]
  259. break
  260. if model_node_id:
  261. model_list_key = None
  262. print(workflow[model_node_id]["class_type"])
  263. for key in info[workflow[model_node_id]["class_type"]]["input"][
  264. "required"
  265. ]:
  266. if "_name" in key:
  267. model_list_key = key
  268. break
  269. if model_list_key:
  270. return list(
  271. map(
  272. lambda model: {"id": model, "name": model},
  273. info[workflow[model_node_id]["class_type"]]["input"][
  274. "required"
  275. ][model_list_key][0],
  276. )
  277. )
  278. else:
  279. return list(
  280. map(
  281. lambda model: {"id": model, "name": model},
  282. info["CheckpointLoaderSimple"]["input"]["required"][
  283. "ckpt_name"
  284. ][0],
  285. )
  286. )
  287. elif (
  288. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  289. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  290. ):
  291. r = requests.get(
  292. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
  293. headers={"authorization": get_automatic1111_api_auth()},
  294. )
  295. models = r.json()
  296. return list(
  297. map(
  298. lambda model: {"id": model["title"], "name": model["model_name"]},
  299. models,
  300. )
  301. )
  302. except Exception as e:
  303. request.app.state.config.ENABLE_IMAGE_GENERATION = False
  304. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  305. class GenerateImageForm(BaseModel):
  306. model: Optional[str] = None
  307. prompt: str
  308. size: Optional[str] = None
  309. n: int = 1
  310. negative_prompt: Optional[str] = None
  311. def save_b64_image(b64_str):
  312. try:
  313. image_id = str(uuid.uuid4())
  314. if "," in b64_str:
  315. header, encoded = b64_str.split(",", 1)
  316. mime_type = header.split(";")[0]
  317. img_data = base64.b64decode(encoded)
  318. image_format = mimetypes.guess_extension(mime_type)
  319. image_filename = f"{image_id}{image_format}"
  320. file_path = IMAGE_CACHE_DIR / f"{image_filename}"
  321. with open(file_path, "wb") as f:
  322. f.write(img_data)
  323. return image_filename
  324. else:
  325. image_filename = f"{image_id}.png"
  326. file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
  327. img_data = base64.b64decode(b64_str)
  328. # Write the image data to a file
  329. with open(file_path, "wb") as f:
  330. f.write(img_data)
  331. return image_filename
  332. except Exception as e:
  333. log.exception(f"Error saving image: {e}")
  334. return None
  335. def save_url_image(url):
  336. image_id = str(uuid.uuid4())
  337. try:
  338. r = requests.get(url)
  339. r.raise_for_status()
  340. if r.headers["content-type"].split("/")[0] == "image":
  341. mime_type = r.headers["content-type"]
  342. image_format = mimetypes.guess_extension(mime_type)
  343. if not image_format:
  344. raise ValueError("Could not determine image type from MIME type")
  345. image_filename = f"{image_id}{image_format}"
  346. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
  347. with open(file_path, "wb") as image_file:
  348. for chunk in r.iter_content(chunk_size=8192):
  349. image_file.write(chunk)
  350. return image_filename
  351. else:
  352. log.error("Url does not point to an image.")
  353. return None
  354. except Exception as e:
  355. log.exception(f"Error saving image: {e}")
  356. return None
  357. @router.post("/generations")
  358. async def image_generations(
  359. request: Request,
  360. form_data: GenerateImageForm,
  361. user=Depends(get_verified_user),
  362. ):
  363. width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))
  364. r = None
  365. try:
  366. if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
  367. headers = {}
  368. headers["Authorization"] = (
  369. f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}"
  370. )
  371. headers["Content-Type"] = "application/json"
  372. if ENABLE_FORWARD_USER_INFO_HEADERS:
  373. headers["X-OpenWebUI-User-Name"] = user.name
  374. headers["X-OpenWebUI-User-Id"] = user.id
  375. headers["X-OpenWebUI-User-Email"] = user.email
  376. headers["X-OpenWebUI-User-Role"] = user.role
  377. data = {
  378. "model": (
  379. request.app.state.config.MODEL
  380. if request.app.state.config.MODEL != ""
  381. else "dall-e-2"
  382. ),
  383. "prompt": form_data.prompt,
  384. "n": form_data.n,
  385. "size": (
  386. form_data.size
  387. if form_data.size
  388. else request.app.state.config.IMAGE_SIZE
  389. ),
  390. "response_format": "b64_json",
  391. }
  392. # Use asyncio.to_thread for the requests.post call
  393. r = await asyncio.to_thread(
  394. requests.post,
  395. url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations",
  396. json=data,
  397. headers=headers,
  398. )
  399. r.raise_for_status()
  400. res = r.json()
  401. images = []
  402. for image in res["data"]:
  403. image_filename = save_b64_image(image["b64_json"])
  404. images.append({"url": f"/cache/image/generations/{image_filename}"})
  405. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  406. with open(file_body_path, "w") as f:
  407. json.dump(data, f)
  408. return images
  409. elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
  410. data = {
  411. "prompt": form_data.prompt,
  412. "width": width,
  413. "height": height,
  414. "n": form_data.n,
  415. }
  416. if request.app.state.config.IMAGE_STEPS is not None:
  417. data["steps"] = request.app.state.config.IMAGE_STEPS
  418. if form_data.negative_prompt is not None:
  419. data["negative_prompt"] = form_data.negative_prompt
  420. form_data = ComfyUIGenerateImageForm(
  421. **{
  422. "workflow": ComfyUIWorkflow(
  423. **{
  424. "workflow": request.app.state.config.COMFYUI_WORKFLOW,
  425. "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
  426. }
  427. ),
  428. **data,
  429. }
  430. )
  431. res = await comfyui_generate_image(
  432. request.app.state.config.MODEL,
  433. form_data,
  434. user.id,
  435. request.app.state.config.COMFYUI_BASE_URL,
  436. )
  437. log.debug(f"res: {res}")
  438. images = []
  439. for image in res["data"]:
  440. image_filename = save_url_image(image["url"])
  441. images.append({"url": f"/cache/image/generations/{image_filename}"})
  442. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  443. with open(file_body_path, "w") as f:
  444. json.dump(form_data.model_dump(exclude_none=True), f)
  445. log.debug(f"images: {images}")
  446. return images
  447. elif (
  448. request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
  449. or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
  450. ):
  451. if form_data.model:
  452. set_image_model(form_data.model)
  453. data = {
  454. "prompt": form_data.prompt,
  455. "batch_size": form_data.n,
  456. "width": width,
  457. "height": height,
  458. }
  459. if request.app.state.config.IMAGE_STEPS is not None:
  460. data["steps"] = request.app.state.config.IMAGE_STEPS
  461. if form_data.negative_prompt is not None:
  462. data["negative_prompt"] = form_data.negative_prompt
  463. if request.app.state.config.AUTOMATIC1111_CFG_SCALE:
  464. data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE
  465. if request.app.state.config.AUTOMATIC1111_SAMPLER:
  466. data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER
  467. if request.app.state.config.AUTOMATIC1111_SCHEDULER:
  468. data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER
  469. # Use asyncio.to_thread for the requests.post call
  470. r = await asyncio.to_thread(
  471. requests.post,
  472. url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  473. json=data,
  474. headers={"authorization": get_automatic1111_api_auth()},
  475. )
  476. res = r.json()
  477. log.debug(f"res: {res}")
  478. images = []
  479. for image in res["images"]:
  480. image_filename = save_b64_image(image)
  481. images.append({"url": f"/cache/image/generations/{image_filename}"})
  482. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  483. with open(file_body_path, "w") as f:
  484. json.dump({**data, "info": res["info"]}, f)
  485. return images
  486. except Exception as e:
  487. error = e
  488. if r != None:
  489. data = r.json()
  490. if "error" in data:
  491. error = data["error"]["message"]
  492. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))