# chatgpt_api.py

import uuid
import time
import asyncio
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoProcessor
from typing import List, Literal, Union, Dict
from aiohttp import web
import aiohttp_cors
from exo import DEBUG, VERSION
from exo.helpers import terminal_link
from exo.inference.shard import Shard
from exo.orchestration import Node

shard_mappings = {
    ### llama
    "llama-3.1-8b": {
        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
    },
    "llama-3.1-70b": {
        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
    },
    "llama-3.1-405b": {
        "MLXDynamicShardInferenceEngine": Shard(model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
    },
    "llama-3-8b": {
        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
    },
    "llama-3-70b": {
        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
    },
    ### mistral
    "mistral-nemo": {
        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
    },
    "mistral-large": {
        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
    },
    ### deepseek v2
    "deepseek-coder-v2-lite": {
        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
    },
    ### llava
    "llava-1.5-7b-hf": {
        "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
    },
}
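# The shard for a request is looked up first by model name, then by the active
# inference engine class name, e.g.:
#   shard_mappings["llama-3.1-8b"]["MLXDynamicShardInferenceEngine"]
# Models with no entry for the current engine are rejected in
# handle_post_chat_completions below.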


class Message:
    def __init__(self, role: str, content: Union[str, list]):
        self.role = role
        self.content = content


class ChatCompletionRequest:
    def __init__(self, model: str, messages: List[Message], temperature: float):
        self.model = model
        self.messages = messages
        self.temperature = temperature


def resolve_tinygrad_tokenizer(model_id: str):
    if model_id == "llama3-8b-sfr":
        return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
    elif model_id == "llama3-70b-sfr":
        return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
    else:
        raise ValueError(f"tinygrad doesn't currently support arbitrary model downloading. unsupported model: {model_id}")
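

# Tokenizer resolution falls back through several loaders, in order:
#   1. AutoProcessor (needed for multimodal models such as llava),
#   2. AutoTokenizer,
#   3. the hard-coded tinygrad tokenizers above,
#   4. the MLX tokenizer loaded from the downloaded model path.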
async def resolve_tokenizer(model_id: str):
    try:
        if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}")
        processor = AutoProcessor.from_pretrained(model_id)
        processor.eos_token_id = processor.tokenizer.eos_token_id
        processor.encode = processor.tokenizer.encode
        return processor
    except Exception as e:
        if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}")
        import traceback
        if DEBUG >= 2: print(traceback.format_exc())

    try:
        if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
        return AutoTokenizer.from_pretrained(model_id)
    except Exception as e:
        if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
        import traceback
        if DEBUG >= 2: print(traceback.format_exc())

    try:
        if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
        return resolve_tinygrad_tokenizer(model_id)
    except Exception as e:
        if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer. Error: {e}")
        import traceback
        if DEBUG >= 2: print(traceback.format_exc())

    if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
    from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
    return load_tokenizer(await get_model_path(model_id))
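

# Builds an OpenAI-style response dict. For streaming chat completions the new
# tokens go under choices[0]["delta"]; otherwise they go under
# choices[0]["message"]. Token usage counts are only included when not streaming.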
def generate_completion(
    chat_request: ChatCompletionRequest,
    tokenizer,
    prompt: str,
    request_id: str,
    tokens: List[int],
    stream: bool,
    finish_reason: Union[Literal["length", "stop"], None],
    object_type: Literal["chat.completion", "text_completion"],
) -> dict:
    completion = {
        "id": f"chatcmpl-{request_id}",
        "object": object_type,
        "created": int(time.time()),
        "model": chat_request.model,
        "system_fingerprint": f"exo_{VERSION}",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
                "logprobs": None,
                "finish_reason": finish_reason,
            }
        ],
    }

    if not stream:
        completion["usage"] = {
            "prompt_tokens": len(tokenizer.encode(prompt)),
            "completion_tokens": len(tokens),
            "total_tokens": len(tokenizer.encode(prompt)) + len(tokens),
        }

    choice = completion["choices"][0]
    if object_type.startswith("chat.completion"):
        key_name = "delta" if stream else "message"
        choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
    elif object_type == "text_completion":
        choice["text"] = tokenizer.decode(tokens)
    else:
        raise ValueError(f"Unsupported response type: {object_type}")

    return completion
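

# Applies the tokenizer's chat template and extracts the first image reference,
# if any, from list-style message content (items shaped like
# {"type": "image", "image": ...}), which is how llava image requests arrive.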
def build_prompt(tokenizer, messages: List[Message]):
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_str = None
    for message in messages:
        if not isinstance(message.content, list):
            continue
        for content in message.content:
            if content.get("type", None) == "image":
                image_str = content.get("image", None)
                break
    return prompt, image_str
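

# Incoming request bodies follow the OpenAI chat format, roughly:
#   {"model": "llama-3.1-8b", "temperature": 0.0, "stream": false,
#    "messages": [{"role": "user", "content": "Hello"}]}
# A missing "model" or "temperature" defaults to "llama-3.1-8b" and 0.0.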
def parse_message(data: dict):
    if "role" not in data or "content" not in data:
        raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
    return Message(data["role"], data["content"])


def parse_chat_request(data: dict):
    return ChatCompletionRequest(
        data.get("model", "llama-3.1-8b"),
        [parse_message(msg) for msg in data["messages"]],
        data.get("temperature", 0.0),
    )
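

# ChatGPTAPI exposes an OpenAI-compatible HTTP server on top of an exo Node:
# it serves the tinychat web UI at "/" and the chat completions and token
# encode endpoints under "/v1/".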
class ChatGPTAPI:
    def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
        self.node = node
        self.inference_engine_classname = inference_engine_classname
        self.response_timeout_secs = response_timeout_secs
        self.app = web.Application()
        self.prev_token_lens: Dict[str, int] = {}
        self.stream_tasks: Dict[str, asyncio.Task] = {}
        cors = aiohttp_cors.setup(self.app)
        cors_options = aiohttp_cors.ResourceOptions(
            allow_credentials=True,
            expose_headers="*",
            allow_headers="*",
            allow_methods="*",
        )
        cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
        cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
        self.static_dir = Path(__file__).parent.parent.parent / "tinychat/examples/tinychat"
        self.app.router.add_get("/", self.handle_root)
        self.app.router.add_static("/", self.static_dir, name="static")

        # Add middleware to log every request
        self.app.middlewares.append(self.log_request)

    async def log_request(self, app, handler):
        async def middleware(request):
            if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
            return await handler(request)
        return middleware

    async def handle_root(self, request):
        print(f"Handling root request from {request.remote}")
        return web.FileResponse(self.static_dir / "index.html")

    async def handle_post_chat_token_encode(self, request):
        data = await request.json()
        shard = shard_mappings.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
        messages = [parse_message(msg) for msg in data.get("messages", [])]
        tokenizer = await resolve_tokenizer(shard.model_id)
        return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
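
    # Request flow: resolve the shard and tokenizer for the requested model,
    # build the prompt, hand it to the node, then either stream SSE chunks as
    # token callbacks fire or wait for the final token list and return a single
    # JSON completion.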
    async def handle_post_chat_completions(self, request):
        data = await request.json()
        if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
        stream = data.get("stream", False)
        chat_request = parse_chat_request(data)
        if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
            chat_request.model = "llama-3.1-8b"
        if not chat_request.model or chat_request.model not in shard_mappings:
            if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
            chat_request.model = "llama-3.1-8b"
        shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
        if not shard:
            supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
            return web.json_response(
                {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
                status=400,
            )
        request_id = str(uuid.uuid4())

        tokenizer = await resolve_tokenizer(shard.model_id)
        if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

        prompt, image_str = build_prompt(tokenizer, chat_request.messages)
        callback_id = f"chatgpt-api-wait-response-{request_id}"
        callback = self.node.on_token.register(callback_id)

        if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
        try:
            await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
        except Exception as e:
            if DEBUG >= 2:
                import traceback
                traceback.print_exc()
            return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

        try:
            if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")

            if stream:
                response = web.StreamResponse(
                    status=200,
                    reason="OK",
                    headers={
                        "Content-Type": "application/json",
                        "Cache-Control": "no-cache",
                    },
                )
                await response.prepare(request)

                async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
                    prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
                    self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
                    new_tokens = tokens[prev_last_tokens_len:]
                    finish_reason = None
                    eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
                    if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
                        new_tokens = new_tokens[:-1]
                        if is_finished:
                            finish_reason = "stop"
                    if is_finished and not finish_reason:
                        finish_reason = "length"
                    completion = generate_completion(
                        chat_request,
                        tokenizer,
                        prompt,
                        request_id,
                        new_tokens,
                        stream,
                        finish_reason,
                        "chat.completion",
                    )
                    if DEBUG >= 2: print(f"Streaming completion: {completion}")
                    await response.write(f"data: {json.dumps(completion)}\n\n".encode())

                def on_result(_request_id: str, tokens: List[int], is_finished: bool):
                    self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
                    return _request_id == request_id and is_finished

                _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
                if request_id in self.stream_tasks:  # in case there is still a stream task running, wait for it to complete
                    if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
                    try:
                        await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
                    except asyncio.TimeoutError:
                        print("WARNING: Stream task timed out. This should not happen.")
                await response.write_eof()
                return response
            else:
                _, tokens, _ = await callback.wait(
                    lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
                    timeout=self.response_timeout_secs,
                )

                finish_reason = "length"
                eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
                if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
                if tokens[-1] == eos_token_id:
                    tokens = tokens[:-1]
                    finish_reason = "stop"

                return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
        except asyncio.TimeoutError:
            return web.json_response({"detail": "Response generation timed out"}, status=408)
        finally:
            deregistered_callback = self.node.on_token.deregister(callback_id)
            if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")

    async def run(self, host: str = "0.0.0.0", port: int = 8000):
        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, host, port)
        await site.start()
        if DEBUG >= 0:
            print(f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}")
            print(f"ChatGPT API endpoint served at {terminal_link(f'http://localhost:{port}/v1/chat/completions')}")
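

# Typical wiring (a sketch; constructing the exo Node and picking the inference
# engine class happen elsewhere in exo and are assumed here):
#
#   node = ...  # an already-started exo orchestration Node
#   api = ChatGPTAPI(node, "MLXDynamicShardInferenceEngine", response_timeout_secs=90)
#   await api.run(host="0.0.0.0", port=8000)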