# chatgpt_api.py

import uuid
import time
import asyncio
import json
import os
from pathlib import Path
from transformers import AutoTokenizer
from typing import Callable, Dict, List, Literal, Optional, Union
from aiohttp import web
import aiohttp_cors
import traceback
import signal
from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
from PIL import Image
import numpy as np
import base64
from io import BytesIO
import platform

# On Apple Silicon, use MLX arrays; elsewhere fall back to NumPy under the same alias.
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
  import mlx.core as mx
else:
  import numpy as mx

import tempfile
from exo.download.hf.hf_shard_download import HFShardDownloader
import shutil
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from exo.apputil import create_animation_mp4

class Message:
  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
    self.role = role
    self.content = content
    self.tools = tools

  def to_dict(self):
    data = {"role": self.role, "content": self.content}
    if self.tools:
      data["tools"] = self.tools
    return data


class ChatCompletionRequest:
  def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
    self.model = model
    self.messages = messages
    self.temperature = temperature
    self.tools = tools

  def to_dict(self):
    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}

def generate_completion(
  chat_request: ChatCompletionRequest,
  tokenizer,
  prompt: str,
  request_id: str,
  tokens: List[int],
  stream: bool,
  finish_reason: Union[Literal["length", "stop"], None],
  object_type: Literal["chat.completion", "text_completion"],
) -> dict:
  completion = {
    "id": f"chatcmpl-{request_id}",
    "object": object_type,
    "created": int(time.time()),
    "model": chat_request.model,
    "system_fingerprint": f"exo_{VERSION}",
    "choices": [{
      "index": 0,
      "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
      "logprobs": None,
      "finish_reason": finish_reason,
    }],
  }

  if not stream:
    completion["usage"] = {
      "prompt_tokens": len(tokenizer.encode(prompt)),
      "completion_tokens": len(tokens),
      "total_tokens": len(tokenizer.encode(prompt)) + len(tokens),
    }

  choice = completion["choices"][0]
  if object_type.startswith("chat.completion"):
    key_name = "delta" if stream else "message"
    choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
  elif object_type == "text_completion":
    choice["text"] = tokenizer.decode(tokens)
  else:
    raise ValueError(f"Unsupported response type: {object_type}")

  return completion
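
# For reference, a streamed "chat.completion" chunk produced by generate_completion looks
# roughly like the following (field values are illustrative, not taken from a real run):
#
#   {"id": "chatcmpl-<request_id>", "object": "chat.completion", "created": 1700000000,
#    "model": "llama-3.2-1b", "system_fingerprint": "exo_<VERSION>",
#    "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hello"},
#                 "logprobs": null, "finish_reason": null}]}
#
# Non-streamed responses use "message" instead of "delta" and also include a "usage" object.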

def remap_messages(messages: List[Message]) -> List[Message]:
  remapped_messages = []
  last_image = None
  for message in messages:
    if not isinstance(message.content, list):
      remapped_messages.append(message)
      continue

    remapped_content = []
    for content in message.content:
      if isinstance(content, dict):
        if content.get("type") in ["image_url", "image"]:
          image_url = content.get("image_url", {}).get("url") or content.get("image")
          if image_url:
            last_image = {"type": "image", "image": image_url}
            remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
        else:
          remapped_content.append(content)
      else:
        remapped_content.append(content)
    remapped_messages.append(Message(role=message.role, content=remapped_content))

  if last_image:
    # Replace the last image placeholder with the actual image content
    for message in reversed(remapped_messages):
      for i, content in enumerate(message.content):
        if isinstance(content, dict):
          if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
            message.content[i] = last_image
            return remapped_messages

  return remapped_messages
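
# Illustrative example of the remapping above (the message structure follows the OpenAI-style
# vision format handled by this module; the URL is a placeholder):
#
#   in:  [{"type": "text", "text": "What is this?"},
#         {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]
#   out: [{"type": "text", "text": "What is this?"},
#         {"type": "image", "image": "data:image/png;base64,..."}]
#
# Earlier images in the conversation are left as the text placeholder; only the last image
# seen is restored into the remapped messages.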

def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
  messages = remap_messages(_messages)
  chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
  if tools:
    chat_template_args["tools"] = tools

  try:
    prompt = tokenizer.apply_chat_template(**chat_template_args)
    if DEBUG >= 3: print(f"!!! Prompt: {prompt}")
    return prompt
  except UnicodeEncodeError:
    # Handle Unicode encoding by ensuring everything is UTF-8
    chat_template_args["conversation"] = [
      {k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v
       for k, v in m.to_dict().items()}
      for m in messages
    ]
    prompt = tokenizer.apply_chat_template(**chat_template_args)
    if DEBUG >= 3: print(f"!!! Prompt (UTF-8 encoded): {prompt}")
    return prompt

def parse_message(data: dict):
  if "role" not in data or "content" not in data:
    raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
  return Message(data["role"], data["content"], data.get("tools"))


def parse_chat_request(data: dict, default_model: str):
  return ChatCompletionRequest(
    data.get("model", default_model),
    [parse_message(msg) for msg in data["messages"]],
    data.get("temperature", 0.0),
    data.get("tools", None),
  )
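
# The request body expected by parse_chat_request mirrors the OpenAI chat completions schema;
# a minimal illustrative payload (the model name shown is the default used further below):
#
#   {"model": "llama-3.2-1b",
#    "messages": [{"role": "user", "content": "Why is the sky blue?"}],
#    "temperature": 0.0,
#    "stream": true}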

class PromptSession:
  def __init__(self, request_id: str, timestamp: int, prompt: str):
    self.request_id = request_id
    self.timestamp = timestamp
    self.prompt = prompt

class ChatGPTAPI:
  def __init__(
    self,
    node: Node,
    inference_engine_classname: str,
    response_timeout: int = 90,
    on_chat_completion_request: Optional[Callable[[str, ChatCompletionRequest, str], None]] = None,
    default_model: Optional[str] = None,
    system_prompt: Optional[str] = None
  ):
    self.node = node
    self.inference_engine_classname = inference_engine_classname
    self.response_timeout = response_timeout
    self.on_chat_completion_request = on_chat_completion_request
    self.app = web.Application(client_max_size=100*1024*1024)  # 100MB to support image upload
    self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
    self.prev_token_lens: Dict[str, int] = {}
    self.stream_tasks: Dict[str, asyncio.Task] = {}
    self.default_model = default_model or "llama-3.2-1b"
    self.system_prompt = system_prompt

    cors = aiohttp_cors.setup(self.app)
    cors_options = aiohttp_cors.ResourceOptions(
      allow_credentials=True,
      expose_headers="*",
      allow_headers="*",
      allow_methods="*",
    )
    cors.add(self.app.router.add_get("/models", self.handle_get_models), {"*": cors_options})
    cors.add(self.app.router.add_get("/v1/models", self.handle_get_models), {"*": cors_options})
    cors.add(self.app.router.add_post("/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
    cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
    cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
    cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
    cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
    cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
    cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
    cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
    cors.add(self.app.router.add_delete("/models/{model_name}", self.handle_delete_model), {"*": cors_options})
    cors.add(self.app.router.add_get("/initial_models", self.handle_get_initial_models), {"*": cors_options})
    cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
    cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
    cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})

    # Add static routes
    if "__compiled__" not in globals():
      self.static_dir = Path(__file__).parent.parent/"tinychat"
      self.app.router.add_get("/", self.handle_root)
      self.app.router.add_static("/", self.static_dir, name="static")

    # Always add images route, regardless of compilation status
    self.images_dir = get_exo_images_dir()
    self.images_dir.mkdir(parents=True, exist_ok=True)
    self.app.router.add_static('/images/', self.images_dir, name='static_images')

    self.app.middlewares.append(self.timeout_middleware)
    self.app.middlewares.append(self.log_request)

  async def handle_quit(self, request):
    if DEBUG >= 1: print("Received quit signal")
    response = web.json_response({"detail": "Quit signal received"}, status=200)
    await response.prepare(request)
    await response.write_eof()
    await shutdown(signal.SIGINT, asyncio.get_event_loop(), self.node.server)

  async def timeout_middleware(self, app, handler):
    async def middleware(request):
      try:
        return await asyncio.wait_for(handler(request), timeout=self.response_timeout)
      except asyncio.TimeoutError:
        return web.json_response({"detail": "Request timed out"}, status=408)

    return middleware

  async def log_request(self, app, handler):
    async def middleware(request):
      if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
      return await handler(request)

    return middleware

  async def handle_root(self, request):
    return web.FileResponse(self.static_dir/"index.html")

  async def handle_healthcheck(self, request):
    return web.json_response({"status": "ok"})

  async def handle_model_support(self, request):
    try:
      response = web.StreamResponse(status=200, reason='OK', headers={
        'Content-Type': 'text/event-stream',
        'Cache-Control': 'no-cache',
        'Connection': 'keep-alive',
      })
      await response.prepare(request)

      async def process_model(model_name, pretty):
        if model_name in model_cards:
          model_info = model_cards[model_name]

          if self.inference_engine_classname in model_info.get("repo", {}):
            shard = build_base_shard(model_name, self.inference_engine_classname)
            if shard:
              downloader = HFShardDownloader(quick_check=True)
              downloader.current_shard = shard
              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
              status = await downloader.get_shard_download_status()

              download_percentage = status.get("overall") if status else None
              total_size = status.get("total_size") if status else None
              total_downloaded = status.get("total_downloaded") if status else False

              model_data = {
                model_name: {
                  "name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
                  "total_downloaded": total_downloaded
                }
              }

              await response.write(f"data: {json.dumps(model_data)}\n\n".encode())

      # Process all models in parallel
      await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])

      await response.write(b"data: [DONE]\n\n")
      return response

    except Exception as e:
      print(f"Error in handle_model_support: {str(e)}")
      traceback.print_exc()
      return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
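
  # The /modelpool endpoint above emits one server-sent event per model, followed by a final
  # "data: [DONE]" event. Each event payload looks roughly like the following (the model name,
  # pretty name and sizes are illustrative):
  #
  #   data: {"llama-3.2-1b": {"name": "Llama 3.2 1B", "downloaded": true,
  #          "download_percentage": 100, "total_size": 2470000000, "total_downloaded": 2470000000}}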

  async def handle_get_models(self, request):
    models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
    return web.json_response({"object": "list", "data": models_list})

  async def handle_post_chat_token_encode(self, request):
    data = await request.json()
    model = data.get("model", self.default_model)
    if model and model.startswith("gpt-"):  # Handle gpt- model requests
      model = self.default_model
    if not model or model not in model_cards:
      if DEBUG >= 1: print(f"Invalid model: {model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
      model = self.default_model
    shard = build_base_shard(model, self.inference_engine_classname)
    messages = [parse_message(msg) for msg in data.get("messages", [])]
    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
    prompt = build_prompt(tokenizer, messages, data.get("tools", None))
    tokens = tokenizer.encode(prompt)
    return web.json_response({
      "length": len(prompt),
      "num_tokens": len(tokens),
      "encoded_tokens": tokens,
      "encoded_prompt": prompt,
    })

  async def handle_get_download_progress(self, request):
    progress_data = {}
    for node_id, progress_event in self.node.node_download_progress.items():
      if isinstance(progress_event, RepoProgressEvent):
        progress_data[node_id] = progress_event.to_dict()
      else:
        print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
    return web.json_response(progress_data)

  async def handle_post_chat_completions(self, request):
    data = await request.json()
    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
    stream = data.get("stream", False)
    chat_request = parse_chat_request(data, self.default_model)
    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to default model
      chat_request.model = self.default_model
    if not chat_request.model or chat_request.model not in model_cards:
      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
      chat_request.model = self.default_model
    shard = build_base_shard(chat_request.model, self.inference_engine_classname)
    if not shard:
      supported_models = [model for model, info in model_cards.items() if self.inference_engine_classname in info.get("repo", {})]
      return web.json_response(
        {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
        status=400,
      )

    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

    # Add system prompt if set
    if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
      chat_request.messages.insert(0, Message("system", self.system_prompt))

    prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
    request_id = str(uuid.uuid4())
    if self.on_chat_completion_request:
      try:
        self.on_chat_completion_request(request_id, chat_request, prompt)
      except Exception as e:
        if DEBUG >= 2: traceback.print_exc()

    # request_id = None
    # match = self.prompts.find_longest_prefix(prompt)
    # if match and len(prompt) > len(match[1].prompt):
    #   if DEBUG >= 2:
    #     print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
    #   request_id = match[1].request_id
    #   self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
    #   # remove the matching prefix from the prompt
    #   prompt = prompt[len(match[1].prompt):]
    # else:
    #   request_id = str(uuid.uuid4())
    #   self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))

    callback_id = f"chatgpt-api-wait-response-{request_id}"
    callback = self.node.on_token.register(callback_id)

    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")

    try:
      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)

      if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")

      if stream:
        response = web.StreamResponse(
          status=200,
          reason="OK",
          headers={
            "Content-Type": "text/event-stream",
            "Cache-Control": "no-cache",
          },
        )
        await response.prepare(request)

        async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
          prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
          self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
          new_tokens = tokens[prev_last_tokens_len:]
          finish_reason = None
          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
          if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
            new_tokens = new_tokens[:-1]
            if is_finished:
              finish_reason = "stop"
          if is_finished and not finish_reason:
            finish_reason = "length"
          completion = generate_completion(
            chat_request,
            tokenizer,
            prompt,
            request_id,
            new_tokens,
            stream,
            finish_reason,
            "chat.completion",
          )
          if DEBUG >= 2: print(f"Streaming completion: {completion}")
          try:
            await response.write(f"data: {json.dumps(completion)}\n\n".encode())
          except Exception as e:
            if DEBUG >= 2: print(f"Error streaming completion: {e}")
            if DEBUG >= 2: traceback.print_exc()

        def on_result(_request_id: str, tokens: List[int], is_finished: bool):
          if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))

          return _request_id == request_id and is_finished

        _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout)
        if request_id in self.stream_tasks:  # in case there is still a stream task running, wait for it to complete
          if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
          try:
            await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
          except asyncio.TimeoutError:
            print("WARNING: Stream task timed out. This should not happen.")
        await response.write_eof()
        return response
      else:
        _, tokens, _ = await callback.wait(
          lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
          timeout=self.response_timeout,
        )

        finish_reason = "length"
        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
        if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
        if tokens[-1] == eos_token_id:
          tokens = tokens[:-1]
          finish_reason = "stop"

        return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
    except asyncio.TimeoutError:
      return web.json_response({"detail": "Response generation timed out"}, status=408)
    except Exception as e:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
    finally:
      deregistered_callback = self.node.on_token.deregister(callback_id)
      if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")

  async def handle_post_image_generations(self, request):
    data = await request.json()

    if DEBUG >= 2: print(f"Handling image generations request from {request.remote}: {data}")
    stream = data.get("stream", False)
    model = data.get("model", "")
    prompt = data.get("prompt", "")
    image_url = data.get("image_url", "")
    if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")
    shard = build_base_shard(model, self.inference_engine_classname)
    if DEBUG >= 2: print(f"shard: {shard}")
    if not shard:
      return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)

    request_id = str(uuid.uuid4())
    callback_id = f"chatgpt-api-wait-response-{request_id}"
    callback = self.node.on_token.register(callback_id)
    try:
      if image_url != "" and image_url is not None:
        img = self.base64_decode(image_url)
      else:
        img = None
      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)

      response = web.StreamResponse(status=200, reason='OK', headers={
        'Content-Type': 'application/octet-stream',
        "Cache-Control": "no-cache",
      })
      await response.prepare(request)

      def get_progress_bar(current_step, total_steps, bar_length=50):
        # Calculate the percentage of completion
        percent = float(current_step)/total_steps
        # Calculate the number of hashes to display
        arrow = '-'*int(round(percent*bar_length) - 1) + '>'
        spaces = ' '*(bar_length - len(arrow))

        # Create the progress bar string
        progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
        return progress_bar

      async def stream_image(_request_id: str, result, is_finished: bool):
        if isinstance(result, list):
          await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')

        elif isinstance(result, np.ndarray):
          try:
            im = Image.fromarray(np.array(result))
            # Save the image to a file
            image_filename = f"{_request_id}.png"
            image_path = self.images_dir/image_filename
            im.save(image_path)

            # Get URL for the saved image
            try:
              image_url = request.app.router['static_images'].url_for(filename=image_filename)
              base_url = f"{request.scheme}://{request.host}"
              full_image_url = base_url + str(image_url)

              await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
            except KeyError as e:
              if DEBUG >= 2: print(f"Error getting image URL: {e}")
              # Fallback to direct file path if URL generation fails
              await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')

            if is_finished:
              await response.write_eof()

          except Exception as e:
            if DEBUG >= 2: print(f"Error processing image: {e}")
            if DEBUG >= 2: traceback.print_exc()
            await response.write(json.dumps({'error': str(e)}).encode('utf-8') + b'\n')

      stream_task = None

      def on_result(_request_id: str, result, is_finished: bool):
        nonlocal stream_task
        stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
        return _request_id == request_id and is_finished

      await callback.wait(on_result, timeout=self.response_timeout*10)

      if stream_task:
        # Wait for the stream task to complete before returning
        await stream_task

      return response

    except Exception as e:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

  async def handle_delete_model(self, request):
    try:
      model_name = request.match_info.get('model_name')
      if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")

      if not model_name or model_name not in model_cards:
        return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)

      shard = build_base_shard(model_name, self.inference_engine_classname)
      if not shard:
        return web.json_response({"detail": "Could not build shard for model"}, status=400)

      repo_id = get_repo(shard.model_id, self.inference_engine_classname)
      if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")

      # Get the HF cache directory using the helper function
      hf_home = get_hf_home()
      cache_dir = get_repo_root(repo_id)

      if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")

      if os.path.exists(cache_dir):
        if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
        try:
          shutil.rmtree(cache_dir)
          return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
        except Exception as e:
          return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
      else:
        return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)

    except Exception as e:
      print(f"Error in handle_delete_model: {str(e)}")
      traceback.print_exc()
      return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)

  async def handle_get_initial_models(self, request):
    model_data = {}
    for model_name, pretty in pretty_name.items():
      model_data[model_name] = {
        "name": pretty,
        "downloaded": None,  # Initially unknown
        "download_percentage": None,  # Change from 0 to null
        "total_size": None,
        "total_downloaded": None,
        "loading": True  # Add loading state
      }
    return web.json_response(model_data)

  async def handle_create_animation(self, request):
    try:
      data = await request.json()
      replacement_image_path = data.get("replacement_image_path")
      device_name = data.get("device_name", "Local Device")
      prompt_text = data.get("prompt", "")

      if DEBUG >= 2: print(f"Creating animation with params: replacement_image={replacement_image_path}, device={device_name}, prompt={prompt_text}")

      if not replacement_image_path:
        return web.json_response({"error": "replacement_image_path is required"}, status=400)

      # Create temp directory if it doesn't exist
      tmp_dir = Path(tempfile.gettempdir())/"exo_animations"
      tmp_dir.mkdir(parents=True, exist_ok=True)

      # Generate unique output filename in temp directory
      output_filename = f"animation_{uuid.uuid4()}.mp4"
      output_path = str(tmp_dir/output_filename)

      if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")

      # Create the animation
      create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text)

      return web.json_response({"status": "success", "output_path": output_path})

    except Exception as e:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"error": str(e)}, status=500)

  async def handle_post_download(self, request):
    try:
      data = await request.json()
      model_name = data.get("model")
      if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
      if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
      shard = build_base_shard(model_name, self.inference_engine_classname)
      if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
      asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))

      return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"})
    except Exception as e:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"error": str(e)}, status=500)

  async def handle_get_topology(self, request):
    try:
      topology = self.node.current_topology
      if topology:
        return web.json_response(topology.to_json())
      else:
        return web.json_response({})
    except Exception as e:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)

  async def run(self, host: str = "0.0.0.0", port: int = 52415):
    runner = web.AppRunner(self.app)
    await runner.setup()
    site = web.TCPSite(runner, host, port)
    await site.start()

  def base64_decode(self, base64_string):
    # Decode a base64 (optionally data-URL prefixed) image, crop to a multiple of 64, and rescale to [-1, 1]
    if base64_string.startswith('data:image'):
      base64_string = base64_string.split(',')[1]
    image_data = base64.b64decode(base64_string)
    img = Image.open(BytesIO(image_data))
    W, H = (dim - dim%64 for dim in (img.width, img.height))
    if W != img.width or H != img.height:
      if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
      img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
    img = mx.array(np.array(img))
    img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1
    img = img[None]
    return img
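
# Illustrative usage (not part of this module): once an exo node is running with this API
# attached, the chat endpoint can be exercised with an OpenAI-style request. The model name
# and port shown are just the defaults used in this file.
#
#   curl http://localhost:52415/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llama-3.2-1b", "messages": [{"role": "user", "content": "Hello"}], "stream": false}'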