main.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546
  1. import re
  2. import requests
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. )
  9. from fastapi.middleware.cors import CORSMiddleware
  10. from constants import ERROR_MESSAGES
  11. from utils.utils import (
  12. get_verified_user,
  13. get_admin_user,
  14. )
  15. from apps.images.utils.comfyui import ComfyUIGenerateImageForm, comfyui_generate_image
  16. from typing import Optional
  17. from pydantic import BaseModel
  18. from pathlib import Path
  19. import mimetypes
  20. import uuid
  21. import base64
  22. import json
  23. import logging
  24. from config import (
  25. SRC_LOG_LEVELS,
  26. CACHE_DIR,
  27. IMAGE_GENERATION_ENGINE,
  28. ENABLE_IMAGE_GENERATION,
  29. AUTOMATIC1111_BASE_URL,
  30. AUTOMATIC1111_API_AUTH,
  31. COMFYUI_BASE_URL,
  32. COMFYUI_WORKFLOW,
  33. IMAGES_OPENAI_API_BASE_URL,
  34. IMAGES_OPENAI_API_KEY,
  35. IMAGE_GENERATION_MODEL,
  36. IMAGE_SIZE,
  37. IMAGE_STEPS,
  38. CORS_ALLOW_ORIGIN,
  39. AppConfig,
  40. )
  41. log = logging.getLogger(__name__)
  42. log.setLevel(SRC_LOG_LEVELS["IMAGES"])
  43. IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
  44. IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  45. app = FastAPI()
  46. app.add_middleware(
  47. CORSMiddleware,
  48. allow_origins=CORS_ALLOW_ORIGIN,
  49. allow_credentials=True,
  50. allow_methods=["*"],
  51. allow_headers=["*"],
  52. )
  53. app.state.config = AppConfig()
  54. app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
  55. app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
  56. app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
  57. app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
  58. app.state.config.MODEL = IMAGE_GENERATION_MODEL
  59. app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  60. app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
  61. app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  62. app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
  63. app.state.config.IMAGE_SIZE = IMAGE_SIZE
  64. app.state.config.IMAGE_STEPS = IMAGE_STEPS
  65. def get_automatic1111_api_auth():
  66. if app.state.config.AUTOMATIC1111_API_AUTH is None:
  67. return ""
  68. else:
  69. auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
  70. auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
  71. auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
  72. return f"Basic {auth1111_base64_encoded_string}"
  73. @app.get("/config")
  74. async def get_config(request: Request, user=Depends(get_admin_user)):
  75. return {
  76. "engine": app.state.config.ENGINE,
  77. "enabled": app.state.config.ENABLED,
  78. }
  79. class ConfigUpdateForm(BaseModel):
  80. engine: str
  81. enabled: bool
  82. @app.post("/config/update")
  83. async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  84. app.state.config.ENGINE = form_data.engine
  85. app.state.config.ENABLED = form_data.enabled
  86. return {
  87. "engine": app.state.config.ENGINE,
  88. "enabled": app.state.config.ENABLED,
  89. }
  90. class EngineUrlUpdateForm(BaseModel):
  91. AUTOMATIC1111_BASE_URL: Optional[str] = None
  92. AUTOMATIC1111_API_AUTH: Optional[str] = None
  93. COMFYUI_BASE_URL: Optional[str] = None
  94. @app.get("/url")
  95. async def get_engine_url(user=Depends(get_admin_user)):
  96. return {
  97. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  98. "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
  99. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  100. }
  101. @app.post("/url/update")
  102. async def update_engine_url(
  103. form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
  104. ):
  105. if form_data.AUTOMATIC1111_BASE_URL is None:
  106. app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
  107. else:
  108. url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
  109. try:
  110. r = requests.head(url)
  111. r.raise_for_status()
  112. app.state.config.AUTOMATIC1111_BASE_URL = url
  113. except Exception as e:
  114. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  115. if form_data.COMFYUI_BASE_URL is None:
  116. app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
  117. else:
  118. url = form_data.COMFYUI_BASE_URL.strip("/")
  119. try:
  120. r = requests.head(url)
  121. r.raise_for_status()
  122. app.state.config.COMFYUI_BASE_URL = url
  123. except Exception as e:
  124. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
  125. if form_data.AUTOMATIC1111_API_AUTH is None:
  126. app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
  127. else:
  128. app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
  129. return {
  130. "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
  131. "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
  132. "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
  133. "status": True,
  134. }
  135. class OpenAIConfigUpdateForm(BaseModel):
  136. url: str
  137. key: str
  138. @app.get("/openai/config")
  139. async def get_openai_config(user=Depends(get_admin_user)):
  140. return {
  141. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  142. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  143. }
  144. @app.post("/openai/config/update")
  145. async def update_openai_config(
  146. form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
  147. ):
  148. if form_data.key == "":
  149. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
  150. app.state.config.OPENAI_API_BASE_URL = form_data.url
  151. app.state.config.OPENAI_API_KEY = form_data.key
  152. return {
  153. "status": True,
  154. "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
  155. "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
  156. }
  157. class ImageSizeUpdateForm(BaseModel):
  158. size: str
  159. @app.get("/size")
  160. async def get_image_size(user=Depends(get_admin_user)):
  161. return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
  162. @app.post("/size/update")
  163. async def update_image_size(
  164. form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
  165. ):
  166. pattern = r"^\d+x\d+$" # Regular expression pattern
  167. if re.match(pattern, form_data.size):
  168. app.state.config.IMAGE_SIZE = form_data.size
  169. return {
  170. "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
  171. "status": True,
  172. }
  173. else:
  174. raise HTTPException(
  175. status_code=400,
  176. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
  177. )
  178. class ImageStepsUpdateForm(BaseModel):
  179. steps: int
  180. @app.get("/steps")
  181. async def get_image_size(user=Depends(get_admin_user)):
  182. return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
  183. @app.post("/steps/update")
  184. async def update_image_size(
  185. form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
  186. ):
  187. if form_data.steps >= 0:
  188. app.state.config.IMAGE_STEPS = form_data.steps
  189. return {
  190. "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
  191. "status": True,
  192. }
  193. else:
  194. raise HTTPException(
  195. status_code=400,
  196. detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
  197. )
  198. @app.get("/models")
  199. def get_models(user=Depends(get_verified_user)):
  200. try:
  201. if app.state.config.ENGINE == "openai":
  202. return [
  203. {"id": "dall-e-2", "name": "DALL·E 2"},
  204. {"id": "dall-e-3", "name": "DALL·E 3"},
  205. ]
  206. elif app.state.config.ENGINE == "comfyui":
  207. r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
  208. info = r.json()
  209. return list(
  210. map(
  211. lambda model: {"id": model, "name": model},
  212. info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
  213. )
  214. )
  215. else:
  216. r = requests.get(
  217. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
  218. headers={"authorization": get_automatic1111_api_auth()},
  219. )
  220. models = r.json()
  221. return list(
  222. map(
  223. lambda model: {"id": model["title"], "name": model["model_name"]},
  224. models,
  225. )
  226. )
  227. except Exception as e:
  228. app.state.config.ENABLED = False
  229. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  230. @app.get("/models/default")
  231. async def get_default_model(user=Depends(get_admin_user)):
  232. try:
  233. if app.state.config.ENGINE == "openai":
  234. return {
  235. "model": (
  236. app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
  237. )
  238. }
  239. elif app.state.config.ENGINE == "comfyui":
  240. return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
  241. else:
  242. r = requests.get(
  243. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  244. headers={"authorization": get_automatic1111_api_auth()},
  245. )
  246. options = r.json()
  247. return {"model": options["sd_model_checkpoint"]}
  248. except Exception as e:
  249. app.state.config.ENABLED = False
  250. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
  251. class UpdateModelForm(BaseModel):
  252. model: str
  253. def set_model_handler(model: str):
  254. if app.state.config.ENGINE in ["openai", "comfyui"]:
  255. app.state.config.MODEL = model
  256. return app.state.config.MODEL
  257. else:
  258. api_auth = get_automatic1111_api_auth()
  259. r = requests.get(
  260. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  261. headers={"authorization": api_auth},
  262. )
  263. options = r.json()
  264. if model != options["sd_model_checkpoint"]:
  265. options["sd_model_checkpoint"] = model
  266. r = requests.post(
  267. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
  268. json=options,
  269. headers={"authorization": api_auth},
  270. )
  271. return options
  272. @app.post("/models/default/update")
  273. def update_default_model(
  274. form_data: UpdateModelForm,
  275. user=Depends(get_verified_user),
  276. ):
  277. return set_model_handler(form_data.model)
  278. class GenerateImageForm(BaseModel):
  279. model: Optional[str] = None
  280. prompt: str
  281. n: int = 1
  282. size: Optional[str] = None
  283. negative_prompt: Optional[str] = None
  284. def save_b64_image(b64_str):
  285. try:
  286. image_id = str(uuid.uuid4())
  287. if "," in b64_str:
  288. header, encoded = b64_str.split(",", 1)
  289. mime_type = header.split(";")[0]
  290. img_data = base64.b64decode(encoded)
  291. image_format = mimetypes.guess_extension(mime_type)
  292. image_filename = f"{image_id}{image_format}"
  293. file_path = IMAGE_CACHE_DIR / f"{image_filename}"
  294. with open(file_path, "wb") as f:
  295. f.write(img_data)
  296. return image_filename
  297. else:
  298. image_filename = f"{image_id}.png"
  299. file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
  300. img_data = base64.b64decode(b64_str)
  301. # Write the image data to a file
  302. with open(file_path, "wb") as f:
  303. f.write(img_data)
  304. return image_filename
  305. except Exception as e:
  306. log.exception(f"Error saving image: {e}")
  307. return None
  308. def save_url_image(url):
  309. image_id = str(uuid.uuid4())
  310. try:
  311. r = requests.get(url)
  312. r.raise_for_status()
  313. if r.headers["content-type"].split("/")[0] == "image":
  314. mime_type = r.headers["content-type"]
  315. image_format = mimetypes.guess_extension(mime_type)
  316. if not image_format:
  317. raise ValueError("Could not determine image type from MIME type")
  318. image_filename = f"{image_id}{image_format}"
  319. file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
  320. with open(file_path, "wb") as image_file:
  321. for chunk in r.iter_content(chunk_size=8192):
  322. image_file.write(chunk)
  323. return image_filename
  324. else:
  325. log.error(f"Url does not point to an image.")
  326. return None
  327. except Exception as e:
  328. log.exception(f"Error saving image: {e}")
  329. return None
  330. @app.post("/generations")
  331. async def image_generations(
  332. form_data: GenerateImageForm,
  333. user=Depends(get_verified_user),
  334. ):
  335. width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
  336. r = None
  337. try:
  338. if app.state.config.ENGINE == "openai":
  339. headers = {}
  340. headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
  341. headers["Content-Type"] = "application/json"
  342. data = {
  343. "model": (
  344. app.state.config.MODEL
  345. if app.state.config.MODEL != ""
  346. else "dall-e-2"
  347. ),
  348. "prompt": form_data.prompt,
  349. "n": form_data.n,
  350. "size": (
  351. form_data.size if form_data.size else app.state.config.IMAGE_SIZE
  352. ),
  353. "response_format": "b64_json",
  354. }
  355. r = requests.post(
  356. url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
  357. json=data,
  358. headers=headers,
  359. )
  360. r.raise_for_status()
  361. res = r.json()
  362. images = []
  363. for image in res["data"]:
  364. image_filename = save_b64_image(image["b64_json"])
  365. images.append({"url": f"/cache/image/generations/{image_filename}"})
  366. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  367. with open(file_body_path, "w") as f:
  368. json.dump(data, f)
  369. return images
  370. elif app.state.config.ENGINE == "comfyui":
  371. data = {
  372. "prompt": form_data.prompt,
  373. "width": width,
  374. "height": height,
  375. "n": form_data.n,
  376. }
  377. if app.state.config.IMAGE_STEPS is not None:
  378. data["steps"] = app.state.config.IMAGE_STEPS
  379. if form_data.negative_prompt is not None:
  380. data["negative_prompt"] = form_data.negative_prompt
  381. form_data = ComfyUIGenerateImageForm(**data)
  382. res = await comfyui_generate_image(
  383. app.state.config.MODEL,
  384. form_data,
  385. user.id,
  386. app.state.config.COMFYUI_BASE_URL,
  387. )
  388. log.debug(f"res: {res}")
  389. images = []
  390. for image in res["data"]:
  391. image_filename = save_url_image(image["url"])
  392. images.append({"url": f"/cache/image/generations/{image_filename}"})
  393. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  394. with open(file_body_path, "w") as f:
  395. json.dump(data.model_dump(exclude_none=True), f)
  396. log.debug(f"images: {images}")
  397. return images
  398. else:
  399. if form_data.model:
  400. set_model_handler(form_data.model)
  401. data = {
  402. "prompt": form_data.prompt,
  403. "batch_size": form_data.n,
  404. "width": width,
  405. "height": height,
  406. }
  407. if app.state.config.IMAGE_STEPS is not None:
  408. data["steps"] = app.state.config.IMAGE_STEPS
  409. if form_data.negative_prompt is not None:
  410. data["negative_prompt"] = form_data.negative_prompt
  411. r = requests.post(
  412. url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
  413. json=data,
  414. headers={"authorization": get_automatic1111_api_auth()},
  415. )
  416. res = r.json()
  417. log.debug(f"res: {res}")
  418. images = []
  419. for image in res["images"]:
  420. image_filename = save_b64_image(image)
  421. images.append({"url": f"/cache/image/generations/{image_filename}"})
  422. file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
  423. with open(file_body_path, "w") as f:
  424. json.dump({**data, "info": res["info"]}, f)
  425. return images
  426. except Exception as e:
  427. error = e
  428. if r != None:
  429. data = r.json()
  430. if "error" in data:
  431. error = data["error"]["message"]
  432. raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))