# chatgpt_api.py

import uuid
import time
import asyncio
import json
import os
from pathlib import Path
from transformers import AutoTokenizer
from typing import List, Literal, Union, Dict, Optional
from aiohttp import web
import aiohttp_cors
import traceback
import signal
from exo import DEBUG, VERSION
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.models import build_base_shard, build_full_shard, model_cards, get_repo, get_supported_models, get_pretty_name
from typing import Callable, Optional
from PIL import Image
import numpy as np
import base64
from io import BytesIO
import platform
from exo.download.download_progress import RepoProgressEvent
from exo.download.new_shard_download import delete_model
import tempfile
from exo.apputil import create_animation_mp4
from collections import defaultdict

if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
  import mlx.core as mx
else:
  import numpy as mx
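
# Lightweight wrappers for the request payloads handled by the API endpoints below.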
class Message:
  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
    self.role = role
    self.content = content
    self.tools = tools

  def to_dict(self):
    data = {"role": self.role, "content": self.content}
    if self.tools:
      data["tools"] = self.tools
    return data

class ChatCompletionRequest:
  def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
    self.model = model
    self.messages = messages
    self.temperature = temperature
    self.tools = tools

  def to_dict(self):
    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}
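
# Build an OpenAI-style response payload from the decoded tokens: a full "chat.completion"
# or "text_completion" object, or a streaming delta chunk when stream=True.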
def generate_completion(
  chat_request: ChatCompletionRequest,
  tokenizer,
  prompt: str,
  request_id: str,
  tokens: List[int],
  stream: bool,
  finish_reason: Union[Literal["length", "stop"], None],
  object_type: Literal["chat.completion", "text_completion"],
) -> dict:
  completion = {
    "id": f"chatcmpl-{request_id}",
    "object": object_type,
    "created": int(time.time()),
    "model": chat_request.model,
    "system_fingerprint": f"exo_{VERSION}",
    "choices": [{
      "index": 0,
      "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
      "logprobs": None,
      "finish_reason": finish_reason,
    }],
  }

  if not stream:
    completion["usage"] = {
      "prompt_tokens": len(tokenizer.encode(prompt)),
      "completion_tokens": len(tokens),
      "total_tokens": len(tokenizer.encode(prompt)) + len(tokens),
    }

  choice = completion["choices"][0]
  if object_type.startswith("chat.completion"):
    key_name = "delta" if stream else "message"
    choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
  elif object_type == "text_completion":
    choice["text"] = tokenizer.decode(tokens)
  else:
    raise ValueError(f"Unsupported response type: {object_type}")

  return completion
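
# Replace image attachments in the conversation with a text placeholder, keeping only the
# most recent image so it can be passed through to the model.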
def remap_messages(messages: List[Message]) -> List[Message]:
  remapped_messages = []
  last_image = None
  for message in messages:
    if not isinstance(message.content, list):
      remapped_messages.append(message)
      continue

    remapped_content = []
    for content in message.content:
      if isinstance(content, dict):
        if content.get("type") in ["image_url", "image"]:
          image_url = content.get("image_url", {}).get("url") or content.get("image")
          if image_url:
            last_image = {"type": "image", "image": image_url}
            remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
        else:
          remapped_content.append(content)
      else:
        remapped_content.append(content)
    remapped_messages.append(Message(role=message.role, content=remapped_content))

  if last_image:
    # Replace the last image placeholder with the actual image content
    for message in reversed(remapped_messages):
      for i, content in enumerate(message.content):
        if isinstance(content, dict):
          if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
            message.content[i] = last_image
            return remapped_messages

  return remapped_messages
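
# Render the conversation into a single prompt string via the tokenizer's chat template.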
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
  messages = remap_messages(_messages)
  chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
  if tools:
    chat_template_args["tools"] = tools

  try:
    prompt = tokenizer.apply_chat_template(**chat_template_args)
    if DEBUG >= 3: print(f"!!! Prompt: {prompt}")
    return prompt
  except UnicodeEncodeError:
    # Handle Unicode encoding by ensuring everything is UTF-8
    chat_template_args["conversation"] = [
      {k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v
       for k, v in m.to_dict().items()}
      for m in messages
    ]
    prompt = tokenizer.apply_chat_template(**chat_template_args)
    if DEBUG >= 3: print(f"!!! Prompt (UTF-8 encoded): {prompt}")
    return prompt

def parse_message(data: dict):
  if "role" not in data or "content" not in data:
    raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
  return Message(data["role"], data["content"], data.get("tools"))


def parse_chat_request(data: dict, default_model: str):
  return ChatCompletionRequest(
    data.get("model", default_model),
    [parse_message(msg) for msg in data["messages"]],
    data.get("temperature", 0.0),
    data.get("tools", None),
  )

class PromptSession:
  def __init__(self, request_id: str, timestamp: int, prompt: str):
    self.request_id = request_id
    self.timestamp = timestamp
    self.prompt = prompt
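
# aiohttp server exposing an OpenAI-compatible API on top of an exo Node.
# A typical request against the chat endpoint (assuming the default bind of
# 0.0.0.0:52415 set in run() and the default model) looks roughly like:
#
#   curl http://localhost:52415/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llama-3.2-1b", "messages": [{"role": "user", "content": "Hello"}], "temperature": 0.7}'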
class ChatGPTAPI:
  def __init__(
    self,
    node: Node,
    inference_engine_classname: str,
    response_timeout: int = 90,
    on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None,
    default_model: Optional[str] = None,
    system_prompt: Optional[str] = None
  ):
    self.node = node
    self.inference_engine_classname = inference_engine_classname
    self.response_timeout = response_timeout
    self.on_chat_completion_request = on_chat_completion_request
    self.app = web.Application(client_max_size=100*1024*1024)  # 100MB to support image upload
    self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
    self.prev_token_lens: Dict[str, int] = {}
    self.stream_tasks: Dict[str, asyncio.Task] = {}
    self.default_model = default_model or "llama-3.2-1b"
    self.token_queues = defaultdict(asyncio.Queue)

    # Get the callback system and register our handler
    self.token_callback = node.on_token.register("chatgpt-api-token-handler")
    self.token_callback.on_next(lambda _request_id, tokens, is_finished: asyncio.create_task(self.handle_tokens(_request_id, tokens, is_finished)))
    self.system_prompt = system_prompt

    cors = aiohttp_cors.setup(self.app)
    cors_options = aiohttp_cors.ResourceOptions(
      allow_credentials=True,
      expose_headers="*",
      allow_headers="*",
      allow_methods="*",
    )
    cors.add(self.app.router.add_get("/models", self.handle_get_models), {"*": cors_options})
    cors.add(self.app.router.add_get("/v1/models", self.handle_get_models), {"*": cors_options})
    cors.add(self.app.router.add_post("/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
    cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
    cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
    cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
    cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
    cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
    cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
    cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
    cors.add(self.app.router.add_delete("/models/{model_name}", self.handle_delete_model), {"*": cors_options})
    cors.add(self.app.router.add_get("/initial_models", self.handle_get_initial_models), {"*": cors_options})
    cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
    cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
    cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options})
    cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})

    # Add static routes
    if "__compiled__" not in globals():
      self.static_dir = Path(__file__).parent.parent/"tinychat"
      self.app.router.add_get("/", self.handle_root)
      self.app.router.add_static("/", self.static_dir, name="static")

    # Always add images route, regardless of compilation status
    self.images_dir = get_exo_images_dir()
    self.images_dir.mkdir(parents=True, exist_ok=True)
    self.app.router.add_static('/images/', self.images_dir, name='static_images')

    self.app.middlewares.append(self.timeout_middleware)
    self.app.middlewares.append(self.log_request)

  async def handle_quit(self, request):
    if DEBUG >= 1: print("Received quit signal")
    response = web.json_response({"detail": "Quit signal received"}, status=200)
    await response.prepare(request)
    await response.write_eof()
    await shutdown(signal.SIGINT, asyncio.get_event_loop(), self.node.server)

  async def timeout_middleware(self, app, handler):
    async def middleware(request):
      try:
        return await asyncio.wait_for(handler(request), timeout=self.response_timeout)
      except asyncio.TimeoutError:
        return web.json_response({"detail": "Request timed out"}, status=408)

    return middleware

  async def log_request(self, app, handler):
    async def middleware(request):
      if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
      return await handler(request)

    return middleware

  async def handle_root(self, request):
    return web.FileResponse(self.static_dir/"index.html")

  async def handle_healthcheck(self, request):
    return web.json_response({"status": "ok"})

  async def handle_model_support(self, request):
    try:
      response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive'})
      await response.prepare(request)

      downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname)
      for (path, d) in downloads:
        model_data = {
          d.shard.model_id: {
            "downloaded": d.downloaded_bytes == d.total_bytes,
            "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes),
            "total_size": d.total_bytes,
            "total_downloaded": d.downloaded_bytes
          }
        }
        await response.write(f"data: {json.dumps(model_data)}\n\n".encode())

      await response.write(b"data: [DONE]\n\n")
      return response
    except Exception as e:
      print(f"Error in handle_model_support: {str(e)}")
      traceback.print_exc()
      return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)

  async def handle_get_models(self, request):
    models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
    return web.json_response({"object": "list", "data": models_list})

  async def handle_post_chat_token_encode(self, request):
    data = await request.json()
    model = data.get("model", self.default_model)
    if model and model.startswith("gpt-"):  # Handle gpt- model requests
      model = self.default_model
    if not model or model not in model_cards:
      if DEBUG >= 1: print(f"Invalid model: {model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
      model = self.default_model
    shard = build_base_shard(model, self.inference_engine_classname)
    messages = [parse_message(msg) for msg in data.get("messages", [])]
    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
    prompt = build_prompt(tokenizer, messages, data.get("tools", None))
    tokens = tokenizer.encode(prompt)
    return web.json_response({
      "length": len(prompt),
      "num_tokens": len(tokens),
      "encoded_tokens": tokens,
      "encoded_prompt": prompt,
    })

  async def handle_get_download_progress(self, request):
    progress_data = {}
    for node_id, progress_event in self.node.node_download_progress.items():
      if isinstance(progress_event, RepoProgressEvent):
        if progress_event.status != "in_progress": continue
        progress_data[node_id] = progress_event.to_dict()
      else:
        print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
    return web.json_response(progress_data)
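
  # Main chat endpoint: resolves the model and tokenizer, kicks off node.process_prompt, then
  # either streams SSE chunks from this request's token queue or collects all tokens and
  # returns a single JSON completion.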
  async def handle_post_chat_completions(self, request):
    data = await request.json()
    if DEBUG >= 2: print(f"[ChatGPTAPI] Handling chat completions request from {request.remote}: {data}")
    stream = data.get("stream", False)
    chat_request = parse_chat_request(data, self.default_model)
    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to the default model
      chat_request.model = self.default_model
    if not chat_request.model or chat_request.model not in model_cards:
      if DEBUG >= 1: print(f"[ChatGPTAPI] Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
      chat_request.model = self.default_model
    shard = build_base_shard(chat_request.model, self.inference_engine_classname)
    if not shard:
      supported_models = [model for model, info in model_cards.items() if self.inference_engine_classname in info.get("repo", {})]
      return web.json_response(
        {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
        status=400,
      )

    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
    if DEBUG >= 4: print(f"[ChatGPTAPI] Resolved tokenizer: {tokenizer}")

    # Add system prompt if set
    if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
      chat_request.messages.insert(0, Message("system", self.system_prompt))

    prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
    request_id = str(uuid.uuid4())
    if self.on_chat_completion_request:
      try:
        self.on_chat_completion_request(request_id, chat_request, prompt)
      except Exception as e:
        if DEBUG >= 2: traceback.print_exc()

    if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}")

    try:
      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)

      if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. timeout={self.response_timeout}s")

      if stream:
        response = web.StreamResponse(
          status=200,
          reason="OK",
          headers={
            "Content-Type": "text/event-stream",
            "Cache-Control": "no-cache",
          },
        )
        await response.prepare(request)

        try:
          # Stream tokens while waiting for inference to complete
          while True:
            if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
            tokens, is_finished = await asyncio.wait_for(
              self.token_queues[request_id].get(),
              timeout=self.response_timeout
            )
            if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {tokens=} {is_finished=}")

            eos_token_id = None
            if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
            if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")

            finish_reason = None
            if is_finished: finish_reason = "stop" if tokens[-1] == eos_token_id else "length"
            if DEBUG >= 2: print(f"{eos_token_id=} {tokens[-1]=} {finish_reason=}")

            completion = generate_completion(
              chat_request,
              tokenizer,
              prompt,
              request_id,
              tokens,
              stream,
              finish_reason,
              "chat.completion",
            )

            await response.write(f"data: {json.dumps(completion)}\n\n".encode())

            if is_finished:
              break

          await response.write_eof()
          return response
        except asyncio.TimeoutError:
          if DEBUG >= 2: print(f"[ChatGPTAPI] Timeout waiting for token: {request_id=}")
          return web.json_response({"detail": "Response generation timed out"}, status=408)
        except Exception as e:
          if DEBUG >= 2:
            print(f"[ChatGPTAPI] Error processing prompt: {e}")
            traceback.print_exc()
          return web.json_response(
            {"detail": f"Error processing prompt: {str(e)}"},
            status=500
          )
        finally:
          # Clean up the queue for this request
          if request_id in self.token_queues:
            if DEBUG >= 2: print(f"[ChatGPTAPI] Cleaning up token queue: {request_id=}")
            del self.token_queues[request_id]
      else:
        tokens = []
        while True:
          _tokens, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
          tokens.extend(_tokens)
          if is_finished:
            break

        finish_reason = "length"
        eos_token_id = None
        if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
        if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
        if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
        if tokens[-1] == eos_token_id:
          finish_reason = "stop"

        return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
    except asyncio.TimeoutError:
      return web.json_response({"detail": "Response generation timed out"}, status=408)
    except Exception as e:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
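
  # Image generation endpoint: streams progress updates and, once the node returns an
  # np.ndarray, saves it as a PNG under images_dir and streams back its URL.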
  async def handle_post_image_generations(self, request):
    data = await request.json()
    if DEBUG >= 2: print(f"Handling image generations request from {request.remote}: {data}")
    stream = data.get("stream", False)
    model = data.get("model", "")
    prompt = data.get("prompt", "")
    image_url = data.get("image_url", "")
    if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")
    shard = build_base_shard(model, self.inference_engine_classname)
    if DEBUG >= 2: print(f"shard: {shard}")
    if not shard:
      return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)

    request_id = str(uuid.uuid4())
    callback_id = f"chatgpt-api-wait-response-{request_id}"
    callback = self.node.on_token.register(callback_id)
    try:
      if image_url != "" and image_url is not None:
        img = self.base64_decode(image_url)
      else:
        img = None
      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)

      response = web.StreamResponse(status=200, reason='OK', headers={
        'Content-Type': 'application/octet-stream',
        "Cache-Control": "no-cache",
      })
      await response.prepare(request)

      def get_progress_bar(current_step, total_steps, bar_length=50):
        # Calculate the percentage of completion
        percent = float(current_step)/total_steps
        # Calculate the number of hashes to display
        arrow = '-'*int(round(percent*bar_length) - 1) + '>'
        spaces = ' '*(bar_length - len(arrow))
        # Create the progress bar string
        progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
        return progress_bar

      async def stream_image(_request_id: str, result, is_finished: bool):
        if isinstance(result, list):
          await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
        elif isinstance(result, np.ndarray):
          try:
            im = Image.fromarray(np.array(result))
            # Save the image to a file
            image_filename = f"{_request_id}.png"
            image_path = self.images_dir/image_filename
            im.save(image_path)

            # Get URL for the saved image
            try:
              image_url = request.app.router['static_images'].url_for(filename=image_filename)
              base_url = f"{request.scheme}://{request.host}"
              full_image_url = base_url + str(image_url)
              await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
            except KeyError as e:
              if DEBUG >= 2: print(f"Error getting image URL: {e}")
              # Fallback to direct file path if URL generation fails
              await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')

            if is_finished:
              await response.write_eof()
          except Exception as e:
            if DEBUG >= 2: print(f"Error processing image: {e}")
            if DEBUG >= 2: traceback.print_exc()
            await response.write(json.dumps({'error': str(e)}).encode('utf-8') + b'\n')

      stream_task = None

      def on_result(_request_id: str, result, is_finished: bool):
        nonlocal stream_task
        stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
        return _request_id == request_id and is_finished

      await callback.wait(on_result, timeout=self.response_timeout*10)

      if stream_task:
        # Wait for the stream task to complete before returning
        await stream_task

      return response
    except Exception as e:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

  async def handle_delete_model(self, request):
    model_id = request.match_info.get('model_name')
    try:
      if await delete_model(model_id, self.inference_engine_classname): return web.json_response({"status": "success", "message": f"Model {model_id} deleted successfully"})
      else: return web.json_response({"detail": f"Model {model_id} files not found"}, status=404)
    except Exception as e:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"detail": f"Error deleting model: {str(e)}"}, status=500)

  async def handle_get_initial_models(self, request):
    model_data = {}
    for model_id in get_supported_models([[self.inference_engine_classname]]):
      model_data[model_id] = {
        "name": get_pretty_name(model_id),
        "downloaded": None,  # Initially unknown
        "download_percentage": None,  # Change from 0 to null
        "total_size": None,
        "total_downloaded": None,
        "loading": True  # Add loading state
      }
    return web.json_response(model_data)

  async def handle_create_animation(self, request):
    try:
      data = await request.json()
      replacement_image_path = data.get("replacement_image_path")
      device_name = data.get("device_name", "Local Device")
      prompt_text = data.get("prompt", "")
      if DEBUG >= 2: print(f"Creating animation with params: replacement_image={replacement_image_path}, device={device_name}, prompt={prompt_text}")

      if not replacement_image_path:
        return web.json_response({"error": "replacement_image_path is required"}, status=400)

      # Create temp directory if it doesn't exist
      tmp_dir = Path(tempfile.gettempdir())/"exo_animations"
      tmp_dir.mkdir(parents=True, exist_ok=True)

      # Generate unique output filename in temp directory
      output_filename = f"animation_{uuid.uuid4()}.mp4"
      output_path = str(tmp_dir/output_filename)

      if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")

      # Create the animation
      create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text)

      return web.json_response({"status": "success", "output_path": output_path})
    except Exception as e:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"error": str(e)}, status=500)

  async def handle_post_download(self, request):
    try:
      data = await request.json()
      model_name = data.get("model")
      if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
      if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
      shard = build_full_shard(model_name, self.inference_engine_classname)
      if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
      asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))

      return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"})
    except Exception as e:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"error": str(e)}, status=500)

  async def handle_get_topology(self, request):
    try:
      topology = self.node.current_topology
      if topology:
        return web.json_response(topology.to_json())
      else:
        return web.json_response({})
    except Exception as e:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
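
  # Called via the node's on_token callback (registered in __init__); pushes token batches
  # onto the per-request queue consumed by handle_post_chat_completions.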
  async def handle_tokens(self, request_id: str, tokens: List[int], is_finished: bool):
    await self.token_queues[request_id].put((tokens, is_finished))

  async def run(self, host: str = "0.0.0.0", port: int = 52415):
    runner = web.AppRunner(self.app)
    await runner.setup()
    site = web.TCPSite(runner, host, port)
    await site.start()

  def base64_decode(self, base64_string):
    # decode and reshape image
    if base64_string.startswith('data:image'):
      base64_string = base64_string.split(',')[1]
    image_data = base64.b64decode(base64_string)
    img = Image.open(BytesIO(image_data))
    W, H = (dim - dim % 64 for dim in (img.width, img.height))
    if W != img.width or H != img.height:
      if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
      img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
    img = mx.array(np.array(img))
    img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1
    img = img[None]
    return img