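"""ChatGPT-compatible API server for exo.

Exposes an OpenAI-style /v1/chat/completions endpoint (streaming and
non-streaming) and a /v1/chat/token/encode endpoint, and serves the tinychat
web UI from the bundled static directory. Incoming requests are mapped to
model shards via shard_mappings and forwarded to the orchestration Node.
"""
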
import uuid
import time
import asyncio
import json
from pathlib import Path
from transformers import AutoTokenizer
from typing import List, Literal, Union, Dict
from aiohttp import web
import aiohttp_cors
from exo import DEBUG, VERSION
from exo.helpers import terminal_link
from exo.inference.shard import Shard
from exo.orchestration import Node

shard_mappings = {
    "llama-3-8b": {
        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
    },
    "llama-3-70b": {
        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
    },
}


class Message:
    def __init__(self, role: str, content: str):
        self.role = role
        self.content = content


class ChatCompletionRequest:
    def __init__(self, model: str, messages: List[Message], temperature: float):
        self.model = model
        self.messages = messages
        self.temperature = temperature


def resolve_tinygrad_tokenizer(model_id: str):
    if model_id == "llama3-8b-sfr":
        return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
    elif model_id == "llama3-70b-sfr":
        # Note: the 70B shard resolves to the same 8B tokenizer repo.
        return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
    else:
        raise ValueError(f"tinygrad doesn't currently support arbitrary model downloading. unsupported model: {model_id}")


def resolve_tokenizer(model_id: str):
    try:
        if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
        return AutoTokenizer.from_pretrained(model_id)
    except Exception:
        import traceback
        if DEBUG >= 2: print(traceback.format_exc())
        if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer")

    try:
        if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
        return resolve_tinygrad_tokenizer(model_id)
    except Exception:
        import traceback
        if DEBUG >= 2: print(traceback.format_exc())
        if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer")

    if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
    from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
    return load_tokenizer(get_model_path(model_id))


def generate_completion(
    chat_request: ChatCompletionRequest,
    tokenizer,
    prompt: str,
    request_id: str,
    tokens: List[int],
    stream: bool,
    finish_reason: Union[Literal["length", "stop"], None],
    object_type: Literal["chat.completion", "text_completion"],
) -> dict:
    completion = {
        "id": f"chatcmpl-{request_id}",
        "object": object_type,
        "created": int(time.time()),
        "model": chat_request.model,
        "system_fingerprint": f"exo_{VERSION}",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": tokenizer.decode(tokens),
                },
                "logprobs": None,
                "finish_reason": finish_reason,
            }
        ],
    }

    if not stream:
        completion["usage"] = {
            "prompt_tokens": len(tokenizer.encode(prompt)),
            "completion_tokens": len(tokens),
            "total_tokens": len(tokenizer.encode(prompt)) + len(tokens),
        }

    choice = completion["choices"][0]
    if object_type.startswith("chat.completion"):
        key_name = "delta" if stream else "message"
        choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
    elif object_type == "text_completion":
        choice["text"] = tokenizer.decode(tokens)
    else:
        raise ValueError(f"Unsupported response type: {object_type}")

    return completion


def build_prompt(tokenizer, messages: List[Message]):
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
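
# Illustrative sketch only: apply_chat_template renders role/content messages with
# the model's chat template, so a minimal call might look like
#   build_prompt(tokenizer, [{"role": "user", "content": "What is the capital of France?"}])
# The exact prompt string produced depends on the tokenizer's template.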


class ChatGPTAPI:
    def __init__(self, node: Node, inference_engine_classname: str):
        self.node = node
        self.inference_engine_classname = inference_engine_classname
        self.response_timeout_secs = 90
        self.app = web.Application()
        self.prev_token_lens: Dict[str, int] = {}
        self.stream_tasks: Dict[str, asyncio.Task] = {}
        cors = aiohttp_cors.setup(self.app)
        cors_options = aiohttp_cors.ResourceOptions(
            allow_credentials=True,
            expose_headers="*",
            allow_headers="*",
            allow_methods="*",
        )
        cors.add(self.app.router.add_post('/v1/chat/completions', self.handle_post_chat_completions), {"*": cors_options})
        cors.add(self.app.router.add_post('/v1/chat/token/encode', self.handle_post_chat_token_encode), {"*": cors_options})
        self.static_dir = Path(__file__).parent.parent.parent / 'tinychat/examples/tinychat'
        self.app.router.add_get('/', self.handle_root)
        self.app.router.add_static('/', self.static_dir, name='static')

    async def handle_root(self, request):
        return web.FileResponse(self.static_dir / 'index.html')

    async def handle_post_chat_token_encode(self, request):
        data = await request.json()
        shard = shard_mappings.get(data.get('model', 'llama-3-8b'), {}).get(self.inference_engine_classname)
        messages = data.get('messages', [])
        tokenizer = resolve_tokenizer(shard.model_id)
        return web.json_response({'length': len(build_prompt(tokenizer, messages))})

    async def handle_post_chat_completions(self, request):
        data = await request.json()
        stream = data.get('stream', False)
        messages = [Message(**msg) for msg in data['messages']]
        chat_request = ChatCompletionRequest(data.get('model', 'llama-3-8b'), messages, data.get('temperature', 0.0))
        if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
            chat_request.model = "llama-3-8b"
        shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
        if not shard:
            return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
        request_id = str(uuid.uuid4())

        tokenizer = resolve_tokenizer(shard.model_id)
        if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

        prompt = build_prompt(tokenizer, messages)
        callback_id = f"chatgpt-api-wait-response-{request_id}"
        callback = self.node.on_token.register(callback_id)

        if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
        try:
            await self.node.process_prompt(shard, prompt, request_id=request_id)
        except Exception as e:
            if DEBUG >= 2:
                import traceback
                traceback.print_exc()
            return web.json_response({'detail': f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

        try:
            if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")

            if stream:
                response = web.StreamResponse(
                    status=200,
                    reason="OK",
                    headers={
                        "Content-Type": "application/json",
                        "Cache-Control": "no-cache",
                    },
                )
                await response.prepare(request)

                async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
                    prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
                    self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
                    new_tokens = tokens[prev_last_tokens_len:]
                    finish_reason = None
                    eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
                    if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
                        new_tokens = new_tokens[:-1]
                        if is_finished:
                            finish_reason = "stop"
                    if is_finished and not finish_reason:
                        finish_reason = "length"
                    completion = generate_completion(chat_request, tokenizer, prompt, request_id, new_tokens, stream, finish_reason, "chat.completion")
                    if DEBUG >= 2: print(f"Streaming completion: {completion}")
                    await response.write(f"data: {json.dumps(completion)}\n\n".encode())

                def on_result(_request_id: str, tokens: List[int], is_finished: bool):
                    self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
                    return _request_id == request_id and is_finished

                _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
                if request_id in self.stream_tasks:  # in case there is still a stream task running, wait for it to complete
                    if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
                    try:
                        await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
                    except asyncio.TimeoutError:
                        print("WARNING: Stream task timed out. This should not happen.")
                await response.write_eof()
                return response
            else:
                _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=self.response_timeout_secs)

                finish_reason = "length"
                eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
                if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
                if tokens[-1] == eos_token_id:
                    tokens = tokens[:-1]
                    finish_reason = "stop"

                return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
        except asyncio.TimeoutError:
            return web.json_response({'detail': "Response generation timed out"}, status=408)
        finally:
            deregistered_callback = self.node.on_token.deregister(callback_id)
            if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")

    async def run(self, host: str = "0.0.0.0", port: int = 8000):
        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, host, port)
        await site.start()
        if DEBUG >= 0:
            print(f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}")
            print(f"ChatGPT API endpoint served at {terminal_link(f'http://localhost:{port}/v1/chat/completions')}")
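
# Example (sketch, not part of the module): wiring the API to a running exo Node.
# Constructing the Node itself is the job of the main entrypoint; `node` below is
# assumed to already exist, and the engine classname is one of the inference engine
# names used in shard_mappings.
#
#   async def main(node: Node):
#       api = ChatGPTAPI(node, inference_engine_classname="MLXDynamicShardInferenceEngine")
#       await api.run(host="0.0.0.0", port=8000)
#       await asyncio.Event().wait()  # keep the server alive
#
# The endpoint then accepts OpenAI-style requests, e.g.:
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llama-3-8b", "messages": [{"role": "user", "content": "Hello"}], "temperature": 0.7, "stream": false}'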