# chatgpt_api.py

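"""OpenAI-compatible HTTP API for an exo cluster.

Serves chat completions, prompt token encoding, image generation, model listing and
download-progress endpoints on top of an exo Node, along with the bundled tinychat web UI.
"""
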
import uuid
import time
import asyncio
import json
from pathlib import Path
from transformers import AutoTokenizer
from typing import Callable, Dict, List, Literal, Optional, Union
from aiohttp import web
import aiohttp_cors
import traceback
import os
import signal
import sys
from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
from PIL import Image
import numpy as np
import base64
from io import BytesIO
import mlx.core as mx


class Message:
  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
    self.role = role
    self.content = content

  def to_dict(self):
    return {"role": self.role, "content": self.content}


class ChatCompletionRequest:
  def __init__(self, model: str, messages: List[Message], temperature: float):
    self.model = model
    self.messages = messages
    self.temperature = temperature

  def to_dict(self):
    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}


def generate_completion(
  chat_request: ChatCompletionRequest,
  tokenizer,
  prompt: str,
  request_id: str,
  tokens: List[int],
  stream: bool,
  finish_reason: Union[Literal["length", "stop"], None],
  object_type: Literal["chat.completion", "text_completion"],
) -> dict:
  completion = {
    "id": f"chatcmpl-{request_id}",
    "object": object_type,
    "created": int(time.time()),
    "model": chat_request.model,
    "system_fingerprint": f"exo_{VERSION}",
    "choices": [{
      "index": 0,
      "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
      "logprobs": None,
      "finish_reason": finish_reason,
    }],
  }

  if not stream:
    completion["usage"] = {
      "prompt_tokens": len(tokenizer.encode(prompt)),
      "completion_tokens": len(tokens),
      "total_tokens": len(tokenizer.encode(prompt)) + len(tokens),
    }

  choice = completion["choices"][0]
  if object_type.startswith("chat.completion"):
    key_name = "delta" if stream else "message"
    choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
  elif object_type == "text_completion":
    choice["text"] = tokenizer.decode(tokens)
  else:
    raise ValueError(f"Unsupported response type: {object_type}")

  return completion
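
# Illustrative shape of the non-streaming payload built above (values are examples only;
# in streaming mode the choice carries a "delta" instead of a "message" and "usage" is omitted):
# {
#   "id": "chatcmpl-<request_id>",
#   "object": "chat.completion",
#   "created": 1700000000,
#   "model": "llama-3.2-1b",
#   "system_fingerprint": "exo_<VERSION>",
#   "choices": [{"index": 0,
#                "message": {"role": "assistant", "content": "Hello!"},
#                "logprobs": None,
#                "finish_reason": "stop"}],
#   "usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}
# }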


def remap_messages(messages: List[Message]) -> List[Message]:
  """Replace image parts with text placeholders, then restore only the most recent image.

  Every image part is swapped for a placeholder string; afterwards the last image seen is
  put back into the most recent placeholder, so at most one image is forwarded to the model.
  """
  remapped_messages = []
  last_image = None
  for message in messages:
    if not isinstance(message.content, list):
      remapped_messages.append(message)
      continue

    remapped_content = []
    for content in message.content:
      if isinstance(content, dict):
        if content.get("type") in ["image_url", "image"]:
          image_url = content.get("image_url", {}).get("url") or content.get("image")
          if image_url:
            last_image = {"type": "image", "image": image_url}
            remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
        else:
          remapped_content.append(content)
      else:
        remapped_content.append(content)
    remapped_messages.append(Message(role=message.role, content=remapped_content))

  if last_image:
    # Replace the last image placeholder with the actual image content
    for message in reversed(remapped_messages):
      for i, content in enumerate(message.content):
        if isinstance(content, dict):
          if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
            message.content[i] = last_image
            return remapped_messages

  return remapped_messages


def build_prompt(tokenizer, _messages: List[Message]):
  messages = remap_messages(_messages)
  prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
  return prompt


def parse_message(data: dict):
  if "role" not in data or "content" not in data:
    raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
  return Message(data["role"], data["content"])


def parse_chat_request(data: dict, default_model: str):
  return ChatCompletionRequest(
    data.get("model", default_model),
    [parse_message(msg) for msg in data["messages"]],
    data.get("temperature", 0.0),
  )
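
# Illustrative request body handled by parse_chat_request/parse_message (field names follow
# the OpenAI chat-completions format; values are examples only):
# {
#   "model": "llama-3.2-1b",
#   "temperature": 0.0,
#   "messages": [{"role": "user", "content": "Hello"}]
# }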


class PromptSession:
  def __init__(self, request_id: str, timestamp: int, prompt: str):
    self.request_id = request_id
    self.timestamp = timestamp
    self.prompt = prompt


class ChatGPTAPI:
  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
    self.node = node
    self.inference_engine_classname = inference_engine_classname
    self.response_timeout = response_timeout
    self.on_chat_completion_request = on_chat_completion_request
    self.app = web.Application(client_max_size=100*1024*1024)  # 100MB to support image upload
    self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
    self.prev_token_lens: Dict[str, int] = {}
    self.stream_tasks: Dict[str, asyncio.Task] = {}
    self.default_model = default_model or "llama-3.2-1b"

    cors = aiohttp_cors.setup(self.app)
    cors_options = aiohttp_cors.ResourceOptions(
      allow_credentials=True,
      expose_headers="*",
      allow_headers="*",
      allow_methods="*",
    )
    cors.add(self.app.router.add_get("/models", self.handle_get_models), {"*": cors_options})
    cors.add(self.app.router.add_get("/v1/models", self.handle_get_models), {"*": cors_options})
    cors.add(self.app.router.add_post("/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
    cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
    cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
    cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
    cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
    cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
    cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
    cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})

    # Serve the bundled tinychat web UI when running from source (not compiled)
    if "__compiled__" not in globals():
      self.static_dir = Path(__file__).parent.parent/"tinychat"
      self.app.router.add_get("/", self.handle_root)
      self.app.router.add_static("/", self.static_dir, name="static")
    # Generated images are served from the exo images directory
    self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')

    self.app.middlewares.append(self.timeout_middleware)
    self.app.middlewares.append(self.log_request)
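
  # Minimal usage sketch (assumes `node` is an already-initialized exo Node, that this runs
  # inside an asyncio event loop, and that the engine classname and port below are examples):
  #
  #   api = ChatGPTAPI(node, inference_engine_classname="MLXDynamicShardInferenceEngine")
  #   await api.run(port=52415)  # serves until the surrounding event loop is stopped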

  async def handle_quit(self, request):
    if DEBUG >= 1: print("Received quit signal")
    response = web.json_response({"detail": "Quit signal received"}, status=200)
    await response.prepare(request)
    await response.write_eof()
    await shutdown(signal.SIGINT, asyncio.get_event_loop(), self.node.server)

  async def timeout_middleware(self, app, handler):
    async def middleware(request):
      try:
        return await asyncio.wait_for(handler(request), timeout=self.response_timeout)
      except asyncio.TimeoutError:
        return web.json_response({"detail": "Request timed out"}, status=408)

    return middleware

  async def log_request(self, app, handler):
    async def middleware(request):
      if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
      return await handler(request)

    return middleware

  async def handle_root(self, request):
    return web.FileResponse(self.static_dir/"index.html")

  async def handle_healthcheck(self, request):
    return web.json_response({"status": "ok"})

  async def handle_model_support(self, request):
    return web.json_response({
      "model pool": {
        model_name: pretty_name.get(model_name, model_name)
        for model_name in get_supported_models(self.node.topology_inference_engines_pool)
      }
    })

  async def handle_get_models(self, request):
    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
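
  # Example requests against the read-only endpoints above (host and port are illustrative):
  #   curl http://localhost:52415/healthcheck
  #   curl http://localhost:52415/v1/models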

  async def handle_post_chat_token_encode(self, request):
    data = await request.json()
    shard = build_base_shard(self.default_model, self.inference_engine_classname)
    messages = [parse_message(msg) for msg in data.get("messages", [])]
    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
    return web.json_response({"length": len(build_prompt(tokenizer, messages))})

  async def handle_get_download_progress(self, request):
    progress_data = {}
    for node_id, progress_event in self.node.node_download_progress.items():
      if isinstance(progress_event, RepoProgressEvent):
        progress_data[node_id] = progress_event.to_dict()
      else:
        print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
    return web.json_response(progress_data)

  async def handle_post_chat_completions(self, request):
    data = await request.json()
    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
    stream = data.get("stream", False)
    chat_request = parse_chat_request(data, self.default_model)
    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to default model
      chat_request.model = self.default_model
    if not chat_request.model or chat_request.model not in model_cards:
      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
      chat_request.model = self.default_model
    shard = build_base_shard(chat_request.model, self.inference_engine_classname)
    if not shard:
      supported_models = [model for model, info in model_cards.items() if self.inference_engine_classname in info.get("repo", {})]
      return web.json_response(
        {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
        status=400,
      )

    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

    prompt = build_prompt(tokenizer, chat_request.messages)
    request_id = str(uuid.uuid4())
    if self.on_chat_completion_request:
      try:
        self.on_chat_completion_request(request_id, chat_request, prompt)
      except Exception as e:
        if DEBUG >= 2: traceback.print_exc()

    # request_id = None
    # match = self.prompts.find_longest_prefix(prompt)
    # if match and len(prompt) > len(match[1].prompt):
    #   if DEBUG >= 2:
    #     print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
    #   request_id = match[1].request_id
    #   self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
    #   # remove the matching prefix from the prompt
    #   prompt = prompt[len(match[1].prompt):]
    # else:
    #   request_id = str(uuid.uuid4())
    #   self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))

    callback_id = f"chatgpt-api-wait-response-{request_id}"
    callback = self.node.on_token.register(callback_id)

    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")

    try:
      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)

      if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")

      if stream:
        response = web.StreamResponse(
          status=200,
          reason="OK",
          headers={
            "Content-Type": "text/event-stream",
            "Cache-Control": "no-cache",
          },
        )
        await response.prepare(request)

        async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
          prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
          self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
          new_tokens = tokens[prev_last_tokens_len:]
          finish_reason = None
          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
          if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
            new_tokens = new_tokens[:-1]
            if is_finished:
              finish_reason = "stop"
          if is_finished and not finish_reason:
            finish_reason = "length"
          completion = generate_completion(
            chat_request,
            tokenizer,
            prompt,
            request_id,
            new_tokens,
            stream,
            finish_reason,
            "chat.completion",
          )
          if DEBUG >= 2: print(f"Streaming completion: {completion}")
          try:
            await response.write(f"data: {json.dumps(completion)}\n\n".encode())
          except Exception as e:
            if DEBUG >= 2: print(f"Error streaming completion: {e}")
            if DEBUG >= 2: traceback.print_exc()

        def on_result(_request_id: str, tokens: List[int], is_finished: bool):
          if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))
          return _request_id == request_id and is_finished

        _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout)

        if request_id in self.stream_tasks:  # in case there is still a stream task running, wait for it to complete
          if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
          try:
            await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
          except asyncio.TimeoutError:
            print("WARNING: Stream task timed out. This should not happen.")
        await response.write_eof()
        return response
      else:
        _, tokens, _ = await callback.wait(
          lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
          timeout=self.response_timeout,
        )

        finish_reason = "length"
        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
        if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
        if tokens[-1] == eos_token_id:
          tokens = tokens[:-1]
          finish_reason = "stop"

        return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
    except asyncio.TimeoutError:
      return web.json_response({"detail": "Response generation timed out"}, status=408)
    except Exception as e:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
    finally:
      deregistered_callback = self.node.on_token.deregister(callback_id)
      if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
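
  # Example (streaming) request against the endpoint above; host, port and model name are
  # illustrative, and the response is sent as server-sent events ("data: {...}" lines):
  #   curl http://localhost:52415/v1/chat/completions \
  #     -H "Content-Type: application/json" \
  #     -d '{"model": "llama-3.2-1b", "messages": [{"role": "user", "content": "Hello"}], "stream": true}'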

  async def handle_post_image_generations(self, request):
    data = await request.json()
    if DEBUG >= 2: print(f"Handling image generations request from {request.remote}: {data}")
    stream = data.get("stream", False)
    model = data.get("model", "")
    prompt = data.get("prompt", "")
    image_url = data.get("image_url", "")
    print(f"model: {model}, prompt: {prompt}, stream: {stream}")
    shard = build_base_shard(model, self.inference_engine_classname)
    print(f"shard: {shard}")
    if not shard:
      return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)

    request_id = str(uuid.uuid4())
    callback_id = f"chatgpt-api-wait-response-{request_id}"
    callback = self.node.on_token.register(callback_id)
    try:
      img = self.base64_decode(image_url) if image_url else None
      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)

      response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream', "Cache-Control": "no-cache"})
      await response.prepare(request)

      def get_progress_bar(current_step, total_steps, bar_length=50):
        # Calculate the fraction of completion
        percent = float(current_step) / total_steps
        # Build the arrow and trailing padding for the bar
        arrow = '-' * int(round(percent * bar_length) - 1) + '>'
        spaces = ' ' * (bar_length - len(arrow))
        # Create the progress bar string
        progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
        return progress_bar

      async def stream_image(_request_id: str, result, is_finished: bool):
        if isinstance(result, list):
          # Intermediate results arrive as [current_step, total_steps] progress updates
          await response.write(json.dumps({'progress': get_progress_bar(result[0], result[1])}).encode('utf-8') + b'\n')
        elif isinstance(result, np.ndarray):
          im = Image.fromarray(np.array(result))
          images_folder = get_exo_images_dir()
          # Save the image to a file
          image_filename = f"{_request_id}.png"
          image_path = images_folder / image_filename
          im.save(image_path)
          image_url = request.app.router['static_images'].url_for(filename=image_filename)
          base_url = f"{request.scheme}://{request.host}"
          # Construct the full URL correctly
          full_image_url = base_url + str(image_url)
          await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
        if is_finished:
          await response.write_eof()

      stream_task = None

      def on_result(_request_id: str, result, is_finished: bool):
        nonlocal stream_task
        stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
        return _request_id == request_id and is_finished

      await callback.wait(on_result, timeout=self.response_timeout*10)

      if stream_task:
        # Wait for the stream task to complete before returning
        await stream_task

      return response
    except Exception as e:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
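
  # Example request against the image generation endpoint; host, port and model name are
  # illustrative, and the response streams newline-delimited JSON progress/result objects:
  #   curl http://localhost:52415/v1/image/generations \
  #     -H "Content-Type: application/json" \
  #     -d '{"model": "stable-diffusion-2-1-base", "prompt": "a watercolor painting of a fox"}'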

  async def run(self, host: str = "0.0.0.0", port: int = 52415):
    runner = web.AppRunner(self.app)
    await runner.setup()
    site = web.TCPSite(runner, host, port)
    await site.start()

  def base64_decode(self, base64_string):
    # Decode a base64 (optionally data-URI) image and reshape it for the model:
    # dimensions are rounded down to multiples of 64 and pixel values scaled to [-1, 1].
    if base64_string.startswith('data:image'):
      base64_string = base64_string.split(',')[1]
    image_data = base64.b64decode(base64_string)
    img = Image.open(BytesIO(image_data))
    W, H = (dim - dim % 64 for dim in (img.width, img.height))
    if W != img.width or H != img.height:
      print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
      img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
    img = mx.array(np.array(img))
    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
    img = img[None]
    return img