comfyui.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. import asyncio
  2. import json
  3. import logging
  4. import random
  5. import requests
  6. import aiohttp
  7. import urllib.parse
  8. import urllib.request
  9. from typing import Optional
  10. import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
  11. from open_webui.env import SRC_LOG_LEVELS
  12. from pydantic import BaseModel
  13. log = logging.getLogger(__name__)
  14. log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
  15. default_headers = {"User-Agent": "Mozilla/5.0"}
  16. def queue_prompt(prompt, client_id, base_url, api_key):
  17. log.info("queue_prompt")
  18. p = {"prompt": prompt, "client_id": client_id}
  19. data = json.dumps(p).encode("utf-8")
  20. log.debug(f"queue_prompt data: {data}")
  21. try:
  22. req = urllib.request.Request(
  23. f"{base_url}/prompt",
  24. data=data,
  25. headers={**default_headers, "Authorization": f"Bearer {api_key}"},
  26. )
  27. response = urllib.request.urlopen(req).read()
  28. return json.loads(response)
  29. except Exception as e:
  30. log.exception(f"Error while queuing prompt: {e}")
  31. raise e
  32. def get_image(filename, subfolder, folder_type, base_url, api_key):
  33. log.info("get_image")
  34. data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
  35. url_values = urllib.parse.urlencode(data)
  36. req = urllib.request.Request(
  37. f"{base_url}/view?{url_values}",
  38. headers={**default_headers, "Authorization": f"Bearer {api_key}"},
  39. )
  40. with urllib.request.urlopen(req) as response:
  41. return response.read()
  42. def get_image_url(filename, subfolder, folder_type, base_url):
  43. log.info("get_image")
  44. data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
  45. url_values = urllib.parse.urlencode(data)
  46. return f"{base_url}/view?{url_values}"
  47. def get_history(prompt_id, base_url, api_key):
  48. log.info("get_history")
  49. req = urllib.request.Request(
  50. f"{base_url}/history/{prompt_id}",
  51. headers={**default_headers, "Authorization": f"Bearer {api_key}"},
  52. )
  53. with urllib.request.urlopen(req) as response:
  54. return json.loads(response.read())
  55. def get_images(ws, prompt, client_id, base_url, api_key):
  56. prompt_id = queue_prompt(prompt, client_id, base_url, api_key)["prompt_id"]
  57. output_images = []
  58. while True:
  59. out = ws.recv()
  60. if isinstance(out, str):
  61. message = json.loads(out)
  62. if message["type"] == "executing":
  63. data = message["data"]
  64. if data["node"] is None and data["prompt_id"] == prompt_id:
  65. break # Execution is done
  66. else:
  67. continue # previews are binary data
  68. history = get_history(prompt_id, base_url, api_key)[prompt_id]
  69. for o in history["outputs"]:
  70. for node_id in history["outputs"]:
  71. node_output = history["outputs"][node_id]
  72. if "images" in node_output:
  73. for image in node_output["images"]:
  74. url = get_image_url(
  75. image["filename"], image["subfolder"], image["type"], base_url
  76. )
  77. output_images.append({"url": url})
  78. return {"data": output_images}
  79. async def comfyui_upload_image(image_file_item, base_url, api_key):
  80. url = f"{base_url}/api/upload/image"
  81. headers = {}
  82. if api_key:
  83. headers["Authorization"] = f"Bearer {api_key}"
  84. _, (filename, file_bytes, mime_type) = image_file_item
  85. form = aiohttp.FormData()
  86. form.add_field("image", file_bytes, filename=filename, content_type=mime_type)
  87. form.add_field("type", "input") # required by ComfyUI
  88. async with aiohttp.ClientSession() as session:
  89. async with session.post(url, data=form, headers=headers) as resp:
  90. resp.raise_for_status()
  91. return await resp.json()
  92. class ComfyUINodeInput(BaseModel):
  93. type: Optional[str] = None
  94. node_ids: list[str] = []
  95. key: Optional[str] = "text"
  96. value: Optional[str] = None
  97. class ComfyUIWorkflow(BaseModel):
  98. workflow: str
  99. nodes: list[ComfyUINodeInput]
  100. class ComfyUICreateImageForm(BaseModel):
  101. workflow: ComfyUIWorkflow
  102. prompt: str
  103. negative_prompt: Optional[str] = None
  104. width: int
  105. height: int
  106. n: int = 1
  107. steps: Optional[int] = None
  108. seed: Optional[int] = None
  109. async def comfyui_create_image(
  110. model: str, payload: ComfyUICreateImageForm, client_id, base_url, api_key
  111. ):
  112. ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
  113. workflow = json.loads(payload.workflow.workflow)
  114. for node in payload.workflow.nodes:
  115. if node.type:
  116. if node.type == "model":
  117. for node_id in node.node_ids:
  118. workflow[node_id]["inputs"][node.key] = model
  119. elif node.type == "prompt":
  120. for node_id in node.node_ids:
  121. workflow[node_id]["inputs"][
  122. node.key if node.key else "text"
  123. ] = payload.prompt
  124. elif node.type == "negative_prompt":
  125. for node_id in node.node_ids:
  126. workflow[node_id]["inputs"][
  127. node.key if node.key else "text"
  128. ] = payload.negative_prompt
  129. elif node.type == "width":
  130. for node_id in node.node_ids:
  131. workflow[node_id]["inputs"][
  132. node.key if node.key else "width"
  133. ] = payload.width
  134. elif node.type == "height":
  135. for node_id in node.node_ids:
  136. workflow[node_id]["inputs"][
  137. node.key if node.key else "height"
  138. ] = payload.height
  139. elif node.type == "n":
  140. for node_id in node.node_ids:
  141. workflow[node_id]["inputs"][
  142. node.key if node.key else "batch_size"
  143. ] = payload.n
  144. elif node.type == "steps":
  145. for node_id in node.node_ids:
  146. workflow[node_id]["inputs"][
  147. node.key if node.key else "steps"
  148. ] = payload.steps
  149. elif node.type == "seed":
  150. seed = (
  151. payload.seed
  152. if payload.seed
  153. else random.randint(0, 1125899906842624)
  154. )
  155. for node_id in node.node_ids:
  156. workflow[node_id]["inputs"][node.key] = seed
  157. else:
  158. for node_id in node.node_ids:
  159. workflow[node_id]["inputs"][node.key] = node.value
  160. try:
  161. ws = websocket.WebSocket()
  162. headers = {"Authorization": f"Bearer {api_key}"}
  163. ws.connect(f"{ws_url}/ws?clientId={client_id}", header=headers)
  164. log.info("WebSocket connection established.")
  165. except Exception as e:
  166. log.exception(f"Failed to connect to WebSocket server: {e}")
  167. return None
  168. try:
  169. log.info("Sending workflow to WebSocket server.")
  170. log.info(f"Workflow: {workflow}")
  171. images = await asyncio.to_thread(
  172. get_images, ws, workflow, client_id, base_url, api_key
  173. )
  174. except Exception as e:
  175. log.exception(f"Error while receiving images: {e}")
  176. images = None
  177. ws.close()
  178. return images
  179. class ComfyUIEditImageForm(BaseModel):
  180. workflow: ComfyUIWorkflow
  181. image: str | list[str]
  182. prompt: str
  183. width: Optional[int] = None
  184. height: Optional[int] = None
  185. n: Optional[int] = None
  186. steps: Optional[int] = None
  187. seed: Optional[int] = None
  188. async def comfyui_edit_image(
  189. model: str, payload: ComfyUIEditImageForm, client_id, base_url, api_key
  190. ):
  191. ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
  192. workflow = json.loads(payload.workflow.workflow)
  193. for node in payload.workflow.nodes:
  194. if node.type:
  195. if node.type == "model":
  196. for node_id in node.node_ids:
  197. workflow[node_id]["inputs"][node.key] = model
  198. elif node.type == "image":
  199. if isinstance(payload.image, list):
  200. # check if multiple images are provided
  201. for idx, node_id in enumerate(node.node_ids):
  202. if idx < len(payload.image):
  203. workflow[node_id]["inputs"][node.key] = payload.image[idx]
  204. else:
  205. for node_id in node.node_ids:
  206. workflow[node_id]["inputs"][node.key] = payload.image
  207. elif node.type == "prompt":
  208. for node_id in node.node_ids:
  209. workflow[node_id]["inputs"][
  210. node.key if node.key else "text"
  211. ] = payload.prompt
  212. elif node.type == "negative_prompt":
  213. for node_id in node.node_ids:
  214. workflow[node_id]["inputs"][
  215. node.key if node.key else "text"
  216. ] = payload.negative_prompt
  217. elif node.type == "width":
  218. for node_id in node.node_ids:
  219. workflow[node_id]["inputs"][
  220. node.key if node.key else "width"
  221. ] = payload.width
  222. elif node.type == "height":
  223. for node_id in node.node_ids:
  224. workflow[node_id]["inputs"][
  225. node.key if node.key else "height"
  226. ] = payload.height
  227. elif node.type == "n":
  228. for node_id in node.node_ids:
  229. workflow[node_id]["inputs"][
  230. node.key if node.key else "batch_size"
  231. ] = payload.n
  232. elif node.type == "steps":
  233. for node_id in node.node_ids:
  234. workflow[node_id]["inputs"][
  235. node.key if node.key else "steps"
  236. ] = payload.steps
  237. elif node.type == "seed":
  238. seed = (
  239. payload.seed
  240. if payload.seed
  241. else random.randint(0, 1125899906842624)
  242. )
  243. for node_id in node.node_ids:
  244. workflow[node_id]["inputs"][node.key] = seed
  245. else:
  246. for node_id in node.node_ids:
  247. workflow[node_id]["inputs"][node.key] = node.value
  248. try:
  249. ws = websocket.WebSocket()
  250. headers = {"Authorization": f"Bearer {api_key}"}
  251. ws.connect(f"{ws_url}/ws?clientId={client_id}", header=headers)
  252. log.info("WebSocket connection established.")
  253. except Exception as e:
  254. log.exception(f"Failed to connect to WebSocket server: {e}")
  255. return None
  256. try:
  257. log.info("Sending workflow to WebSocket server.")
  258. log.info(f"Workflow: {workflow}")
  259. images = await asyncio.to_thread(
  260. get_images, ws, workflow, client_id, base_url, api_key
  261. )
  262. except Exception as e:
  263. log.exception(f"Error while receiving images: {e}")
  264. images = None
  265. ws.close()
  266. return images