# chatgpt_api.py

import uuid
import time
import asyncio
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoProcessor
from typing import List, Literal, Union, Dict
from aiohttp import web
import aiohttp_cors
import traceback
from exo import DEBUG, VERSION
from exo.helpers import terminal_link, PrefixDict
from exo.inference.shard import Shard
from exo.orchestration import Node
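

# Maps a short model name to, per inference-engine class name, the Shard spec
# (Hugging Face model_id and total layer count) used to serve that model.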
shard_mappings = {
  ### llama
  "llama-3.1-8b": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
  },
  "llama-3.1-70b": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
  },
  "llama-3.1-405b": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
  },
  "llama-3-8b": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
  },
  "llama-3-70b": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
  },
  ### mistral
  "mistral-nemo": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
  },
  "mistral-large": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
  },
  ### deepseek v2
  "deepseek-coder-v2-lite": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
  },
  ### llava
  "llava-1.5-7b-hf": {
    "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
  },
}


class Message:
  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
    self.role = role
    self.content = content

  def to_dict(self):
    return {
      "role": self.role,
      "content": self.content
    }


class ChatCompletionRequest:
  def __init__(self, model: str, messages: List[Message], temperature: float):
    self.model = model
    self.messages = messages
    self.temperature = temperature

  def to_dict(self):
    return {
      "model": self.model,
      "messages": [message.to_dict() for message in self.messages],
      "temperature": self.temperature
    }


async def resolve_tokenizer(model_id: str):
  """Resolve a tokenizer for model_id: try AutoProcessor first (covers multimodal
  models such as llava), then fall back to AutoTokenizer."""
  try:
    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
    processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
    if not hasattr(processor, 'eos_token_id'):
      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
    if not hasattr(processor, 'encode'):
      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
    if not hasattr(processor, 'decode'):
      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
    return processor
  except Exception as e:
    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
    if DEBUG >= 4: print(traceback.format_exc())

  try:
    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
    return AutoTokenizer.from_pretrained(model_id)
  except Exception as e:
    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
    if DEBUG >= 4: print(traceback.format_exc())

  raise ValueError(f"[TODO] Unsupported model: {model_id}")


def generate_completion(
  chat_request: ChatCompletionRequest,
  tokenizer,
  prompt: str,
  request_id: str,
  tokens: List[int],
  stream: bool,
  finish_reason: Union[Literal["length", "stop"], None],
  object_type: Literal["chat.completion", "text_completion"],
) -> dict:
  """Build an OpenAI-style completion dict from the generated tokens. Streaming
  chat completions put the decoded text under "delta" rather than "message";
  usage statistics are only included for non-streaming responses."""
  completion = {
    "id": f"chatcmpl-{request_id}",
    "object": object_type,
    "created": int(time.time()),
    "model": chat_request.model,
    "system_fingerprint": f"exo_{VERSION}",
    "choices": [
      {
        "index": 0,
        "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
        "logprobs": None,
        "finish_reason": finish_reason,
      }
    ],
  }

  if not stream:
    completion["usage"] = {
      "prompt_tokens": len(tokenizer.encode(prompt)),
      "completion_tokens": len(tokens),
      "total_tokens": len(tokenizer.encode(prompt)) + len(tokens),
    }

  choice = completion["choices"][0]
  if object_type.startswith("chat.completion"):
    key_name = "delta" if stream else "message"
    choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
  elif object_type == "text_completion":
    choice["text"] = tokenizer.decode(tokens)
  else:
    raise ValueError(f"Unsupported response type: {object_type}")

  return completion
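

# Illustrative shape of a non-streaming chat completion built above (a sketch;
# actual content, timestamps and token counts depend on the model and tokenizer):
#
#   {
#     "id": "chatcmpl-<request_id>",
#     "object": "chat.completion",
#     "created": 1722000000,
#     "model": "llama-3.1-8b",
#     "system_fingerprint": "exo_<VERSION>",
#     "choices": [{"index": 0, "message": {"role": "assistant", "content": "..."},
#                  "logprobs": None, "finish_reason": "stop"}],
#     "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}
#   }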


def remap_messages(messages: List[Message]) -> List[Message]:
  """Flatten multimodal content for prompt building: every image part becomes a
  text placeholder, then only the most recent image is restored as a real image
  part (the API currently supports a single image per prompt)."""
  remapped_messages = []
  last_image = None
  for message in messages:
    if not isinstance(message.content, list):
      remapped_messages.append(message)
      continue

    remapped_content = []
    for content in message.content:
      if isinstance(content, dict):
        if content.get("type") in ["image_url", "image"]:
          image_url = content.get("image_url", {}).get("url") or content.get("image")
          if image_url:
            last_image = {"type": "image", "image": image_url}
            remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
        else:
          remapped_content.append(content)
      else:
        remapped_content.append(content)
    remapped_messages.append(Message(role=message.role, content=remapped_content))

  if last_image:
    # Replace the last image placeholder with the actual image content
    for message in reversed(remapped_messages):
      for i, content in enumerate(message.content):
        if isinstance(content, dict):
          if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
            message.content[i] = last_image
            return remapped_messages

  return remapped_messages
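

# For example (sketch): if a conversation contains two image parts, only the later
# image survives as {"type": "image", "image": <url-or-base64>}; the earlier one is
# left as the "[An image was uploaded but is not displayed here]" text part.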


def build_prompt(tokenizer, _messages: List[Message]):
  """Apply the tokenizer's chat template and pull out the (single) image, if any."""
  messages = remap_messages(_messages)
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  image_str = None
  for message in messages:
    if not isinstance(message.content, list):
      continue

    for content in message.content:
      # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
      # follows the convention in https://platform.openai.com/docs/guides/vision
      if isinstance(content, dict) and content.get("type", None) == "image":
        image_str = content.get("image", None)
        break

  return prompt, image_str


def parse_message(data: dict):
  if "role" not in data or "content" not in data:
    raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
  return Message(data["role"], data["content"])


def parse_chat_request(data: dict):
  return ChatCompletionRequest(
    data.get("model", "llama-3.1-8b"),
    [parse_message(msg) for msg in data["messages"]],
    data.get("temperature", 0.0),
  )
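

# Example request body accepted by parse_chat_request (illustrative):
#
#   {
#     "model": "llama-3.1-8b",
#     "messages": [{"role": "user", "content": "What is the capital of France?"}],
#     "temperature": 0.0
#   }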


class PromptSession:
  def __init__(self, request_id: str, timestamp: int, prompt: str):
    self.request_id = request_id
    self.timestamp = timestamp
    self.prompt = prompt


class ChatGPTAPI:
  """OpenAI-compatible HTTP API (aiohttp) on top of an exo Node: serves
  /v1/chat/completions, /v1/chat/token/encode and the static tinychat UI."""

  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
    self.node = node
    self.inference_engine_classname = inference_engine_classname
    self.response_timeout_secs = response_timeout_secs
    self.app = web.Application(client_max_size=100 * 1024 * 1024)  # 100MB to support image upload
    self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
    self.prev_token_lens: Dict[str, int] = {}
    self.stream_tasks: Dict[str, asyncio.Task] = {}
    cors = aiohttp_cors.setup(self.app)
    cors_options = aiohttp_cors.ResourceOptions(
      allow_credentials=True,
      expose_headers="*",
      allow_headers="*",
      allow_methods="*",
    )
    cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
    cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
    self.static_dir = Path(__file__).parent.parent.parent / "tinychat/examples/tinychat"
    self.app.router.add_get("/", self.handle_root)
    self.app.router.add_static("/", self.static_dir, name="static")

    # Add middleware to log every request
    self.app.middlewares.append(self.log_request)

  async def log_request(self, app, handler):
    async def middleware(request):
      if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
      return await handler(request)

    return middleware

  async def handle_root(self, request):
    return web.FileResponse(self.static_dir / "index.html")

  async def handle_post_chat_token_encode(self, request):
    data = await request.json()
    shard = shard_mappings.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
    messages = [parse_message(msg) for msg in data.get("messages", [])]
    tokenizer = await resolve_tokenizer(shard.model_id)
    return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
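
  # Example body for /v1/chat/token/encode (illustrative):
  #
  #   {"model": "llama-3.1-8b", "messages": [{"role": "user", "content": "Hello"}]}
  #
  # The response is the token length of the templated prompt, e.g. {"length": 42}.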

  async def handle_post_chat_completions(self, request):
    """Handle an OpenAI-style chat completion request, either streamed as
    SSE-style "data: ..." chunks or returned as a single JSON response."""
    data = await request.json()
    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
    stream = data.get("stream", False)
    chat_request = parse_chat_request(data)
    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
      chat_request.model = "llama-3.1-8b"
    if not chat_request.model or chat_request.model not in shard_mappings:
      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
      chat_request.model = "llama-3.1-8b"
    shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
    if not shard:
      supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
      return web.json_response(
        {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
        status=400,
      )

    tokenizer = await resolve_tokenizer(shard.model_id)
    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

    prompt, image_str = build_prompt(tokenizer, chat_request.messages)
    request_id = str(uuid.uuid4())
    # request_id = None
    # match = self.prompts.find_longest_prefix(prompt)
    # if match and len(prompt) > len(match[1].prompt):
    #   if DEBUG >= 2:
    #     print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
    #   request_id = match[1].request_id
    #   self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
    #   # remove the matching prefix from the prompt
    #   prompt = prompt[len(match[1].prompt):]
    # else:
    #   request_id = str(uuid.uuid4())
    #   self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))

    callback_id = f"chatgpt-api-wait-response-{request_id}"
    callback = self.node.on_token.register(callback_id)

    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
    try:
      await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
    except Exception as e:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

    try:
      if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")

      if stream:
        response = web.StreamResponse(
          status=200,
          reason="OK",
          headers={
            "Content-Type": "application/json",
            "Cache-Control": "no-cache",
          },
        )
        await response.prepare(request)

        async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
          prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
          self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
          new_tokens = tokens[prev_last_tokens_len:]
          finish_reason = None
          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
          if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
            new_tokens = new_tokens[:-1]
            if is_finished:
              finish_reason = "stop"
          if is_finished and not finish_reason:
            finish_reason = "length"

          completion = generate_completion(
            chat_request,
            tokenizer,
            prompt,
            request_id,
            new_tokens,
            stream,
            finish_reason,
            "chat.completion",
          )
          if DEBUG >= 2: print(f"Streaming completion: {completion}")
          try:
            await response.write(f"data: {json.dumps(completion)}\n\n".encode())
          except Exception as e:
            if DEBUG >= 2: print(f"Error streaming completion: {e}")
            if DEBUG >= 2: traceback.print_exc()

        def on_result(_request_id: str, tokens: List[int], is_finished: bool):
          self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
          return _request_id == request_id and is_finished

        _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
        if request_id in self.stream_tasks:  # in case there is still a stream task running, wait for it to complete
          if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
          try:
            await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
          except asyncio.TimeoutError:
            print("WARNING: Stream task timed out. This should not happen.")
        await response.write_eof()
        return response
      else:
        _, tokens, _ = await callback.wait(
          lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
          timeout=self.response_timeout_secs,
        )

        finish_reason = "length"
        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
        if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
        if tokens[-1] == eos_token_id:
          tokens = tokens[:-1]
          finish_reason = "stop"

        return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
    except asyncio.TimeoutError:
      return web.json_response({"detail": "Response generation timed out"}, status=408)
    finally:
      deregistered_callback = self.node.on_token.deregister(callback_id)
      if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")

  async def run(self, host: str = "0.0.0.0", port: int = 8000):
    runner = web.AppRunner(self.app)
    await runner.setup()
    site = web.TCPSite(runner, host, port)
    await site.start()
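

# Example usage (a sketch, not part of this module): assuming an exo Node has been
# constructed and started elsewhere, the API can be served and queried like this.
#
#   api = ChatGPTAPI(node, "MLXDynamicShardInferenceEngine", response_timeout_secs=90)
#   await api.run(host="0.0.0.0", port=8000)
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llama-3.1-8b", "messages": [{"role": "user", "content": "Hello"}], "temperature": 0.0, "stream": false}'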