chatgpt_api.py

import uuid
import time
import asyncio
from typing import List
from aiohttp import web
from exo import DEBUG
from exo.inference.shard import Shard
from exo.orchestration import Node
from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer

# Map user-facing model names to the shards that serve them.
shard_mappings = {
    "llama-3-8b": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
    "llama-3-70b": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
}


class Message:
    def __init__(self, role: str, content: str):
        self.role = role
        self.content = content


class ChatCompletionRequest:
    def __init__(self, model: str, messages: List[Message], temperature: float):
        self.model = model
        self.messages = messages
        self.temperature = temperature


class ChatGPTAPI:
    def __init__(self, node: Node):
        self.node = node
        self.app = web.Application()
        # Expose an OpenAI-compatible chat completions endpoint
        self.app.router.add_post('/v1/chat/completions', self.handle_post)

    async def handle_post(self, request):
        data = await request.json()
        messages = [Message(**msg) for msg in data['messages']]
        # temperature is optional in the OpenAI API, so default it when missing
        chat_request = ChatCompletionRequest(data['model'], messages, data.get('temperature', 0.0))
        shard = shard_mappings.get(chat_request.model)
        if not shard:
            return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
        request_id = str(uuid.uuid4())
        tokenizer = load_tokenizer(get_model_path(shard.model_id))
        # Build the prompt with the model's chat template; the template expects plain dicts, not Message objects.
        prompt = tokenizer.apply_chat_template(
            [{"role": msg.role, "content": msg.content} for msg in chat_request.messages],
            tokenize=False, add_generation_prompt=True
        )
        if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
        try:
            await self.node.process_prompt(shard, prompt, request_id=request_id)
        except Exception as e:
            if DEBUG >= 2: print(f"Error processing prompt: {e}")
            # TODO: surface the error to the client
            # return web.json_response({'detail': str(e)}, status=500)

        # Poll for the response. TODO: implement a callback for the specific request id
        timeout = 90
        start_time = time.time()
        while time.time() - start_time < timeout:
            # Sleep before each poll so a failed poll does not busy-loop.
            await asyncio.sleep(0.1)
            try:
                result, is_finished = await self.node.get_inference_result(request_id)
            except Exception:
                continue
            if is_finished:
                # Drop the trailing EOS token before decoding.
                if result[-1] == tokenizer._tokenizer.eos_token_id:
                    result = result[:-1]
                return web.json_response({
                    "id": f"chatcmpl-{request_id}",
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": chat_request.model,
                    "usage": {
                        "prompt_tokens": len(tokenizer.encode(prompt)),
                        "completion_tokens": len(result),
                        "total_tokens": len(tokenizer.encode(prompt)) + len(result)
                    },
                    "choices": [
                        {
                            "message": {
                                "role": "assistant",
                                "content": tokenizer.decode(result)
                            },
                            "logprobs": None,
                            "finish_reason": "stop",
                            "index": 0
                        }
                    ]
                })
        return web.json_response({'detail': "Response generation timed out"}, status=408)

    async def run(self, host: str = "0.0.0.0", port: int = 8000):
        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, host, port)
        if DEBUG >= 1: print(f"Starting ChatGPT API server at {host}:{port}")
        await site.start()


# Usage example
if __name__ == "__main__":
    loop = asyncio.get_event_loop()
    node = Node()  # Assuming Node is properly defined elsewhere
    api = ChatGPTAPI(node)
    loop.run_until_complete(api.run())
    # run() returns once the server has started, so keep the event loop alive
    loop.run_forever()
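
# Example client call (a minimal sketch, assuming the server above is running on
# localhost:8000; "llama-3-8b" is one of the keys in shard_mappings and the JSON
# body mirrors the fields read in handle_post):
#
#   import aiohttp
#
#   async def example_request():
#       async with aiohttp.ClientSession() as session:
#           async with session.post(
#               "http://localhost:8000/v1/chat/completions",
#               json={
#                   "model": "llama-3-8b",
#                   "temperature": 0.7,
#                   "messages": [{"role": "user", "content": "Hello!"}],
#               },
#           ) as resp:
#               print(await resp.json())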