浏览代码

revive the chatgpt api endpoint on :8000

Alex Cheema 9 月之前
父节点
当前提交
f2895cbcee
共有 6 个文件被更改,包括 128 次插入6 次删除
  1. 13 1
      README.md
  2. 1 0
      exo/api/__init__.py
  3. 104 0
      exo/api/chatgpt_api.py
  4. 2 2
      exo/orchestration/node.py
  5. 3 3
      exo/orchestration/standard_node.py
  6. 5 0
      main.py

+ 13 - 1
README.md

@@ -74,7 +74,19 @@ python3 main.py
 
 That's it! No configuration required - exo will automatically discover the other device(s).
 
-Until the below is fixed, the only way to access inference is via peer handles. See how it's done in [this example for Llama 3](examples/llama3_distributed.py).
+The native way to access models running on exo is using the exo library with peer handles. See how in [this example for Llama 3](examples/llama3_distributed.py).
+
+exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000. Note: this is currently only supported by tail nodes (i.e. nodes selected to be at the end of the ring topology). Example request:
+
+```
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+     "model": "llama-3-70b",
+     "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
+     "temperature": 0.7
+   }'
+```
 
 // A ChatGPT-like web interface will be available on each device on port 8000 http://localhost:8000 and Chat-GPT-compatible API on port 8001 (currently doesn't work see https://github.com/exo-explore/exo/issues/6).
 

+ 1 - 0
exo/api/__init__.py

@@ -0,0 +1 @@
+from exo.api.chatgpt_api import ChatGPTAPI

+ 104 - 0
exo/api/chatgpt_api.py

@@ -0,0 +1,104 @@
+import uuid
+import time
+import asyncio
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from typing import List
+from aiohttp import web
+from exo import DEBUG
+from exo.inference.shard import Shard
+from exo.orchestration import Node
+from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
+
+shard_mappings = {
+    "llama-3-8b": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "llama-3-70b": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+}
+
+class Message:
+    def __init__(self, role: str, content: str):
+        self.role = role
+        self.content = content
+
+class ChatCompletionRequest:
+    def __init__(self, model: str, messages: List[Message], temperature: float):
+        self.model = model
+        self.messages = messages
+        self.temperature = temperature
+
+class ChatGPTAPI:
+    def __init__(self, node: Node):
+        self.node = node
+        self.app = web.Application()
+        self.app.router.add_post('/v1/chat/completions', self.handle_post)
+
+    async def handle_post(self, request):
+        data = await request.json()
+        messages = [Message(**msg) for msg in data['messages']]
+        chat_request = ChatCompletionRequest(data['model'], messages, data['temperature'])
+        prompt = " ".join([msg.content for msg in chat_request.messages if msg.role == "user"])
+        shard = shard_mappings.get(chat_request.model)
+        if not shard:
+            return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
+        request_id = str(uuid.uuid4())
+
+        tokenizer = load_tokenizer(get_model_path(shard.model_id))
+        prompt = tokenizer.apply_chat_template(
+            chat_request.messages, tokenize=False, add_generation_prompt=True
+        )
+
+        if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
+        try:
+            result = await self.node.process_prompt(shard, prompt, request_id=request_id)
+        except Exception as e:
+            pass # TODO
+            # return web.json_response({'detail': str(e)}, status=500)
+        
+        # poll for the response. TODO: implement callback for specific request id
+        timeout = 90
+        start_time = time.time()
+        while time.time() - start_time < timeout:
+            print("poll")
+            try:
+                result, is_finished = await self.node.get_inference_result(request_id)
+            except Exception as e:
+                continue
+            await asyncio.sleep(0.1)
+            if is_finished:
+                return web.json_response({
+                    "id": f"chatcmpl-{request_id}",
+                    "object": "chat.completion",
+                    "created": int(time.time()),
+                    "model": chat_request.model,
+                    "usage": {
+                        "prompt_tokens": len(tokenizer.encode(prompt)),
+                        "completion_tokens": len(result),
+                        "total_tokens": len(tokenizer.encode(prompt)) + len(result)
+                    },
+                    "choices": [
+                        {
+                            "message": {
+                                "role": "assistant",
+                                "content": tokenizer.decode(result)
+                            },
+                            "logprobs": None,
+                            "finish_reason": "stop",
+                            "index": 0
+                        }
+                    ]
+                })
+
+        return web.json_response({'detail': "Response generation timed out"}, status=408)
+
+    async def run(self, host: str = "0.0.0.0", port: int = 8000):
+        runner = web.AppRunner(self.app)
+        await runner.setup()
+        site = web.TCPSite(runner, host, port)
+        await site.start()
+        print(f"Starting ChatGPT API server at {host}:{port}")
+
+# Usage example
+if __name__ == "__main__":
+    loop = asyncio.get_event_loop()
+    node = Node()  # Assuming Node is properly defined elsewhere
+    api = ChatGPTAPI(node)
+    loop.run_until_complete(api.run())

+ 2 - 2
exo/orchestration/node.py

@@ -14,11 +14,11 @@ class Node(ABC):
         pass
 
     @abstractmethod
-    async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.ndarray]:
+    async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
         pass
 
     @abstractmethod
-    async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.ndarray]:
+    async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
         pass
 
     @abstractmethod

+ 3 - 3
exo/orchestration/standard_node.py

@@ -12,7 +12,7 @@ import asyncio
 import uuid
 
 class StandardNode(Node):
-    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 50):
+    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 256):
         self.id = id
         self.inference_engine = inference_engine
         self.server = server
@@ -50,7 +50,7 @@ class StandardNode(Node):
             return
 
         result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
-        is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
+        is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
         if is_finished:
             self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
 
@@ -74,7 +74,7 @@ class StandardNode(Node):
         try:
             if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
             result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
-            is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
+            is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
             if is_finished:
                 self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
 

+ 5 - 0
main.py

@@ -11,6 +11,8 @@ from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceE
 from exo.inference.shard import Shard
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
+from exo.api import ChatGPTAPI
+
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -20,6 +22,7 @@ parser.add_argument("--node-port", type=int, default=8080, help="Node port")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
+parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 args = parser.parse_args()
 
 
@@ -32,6 +35,7 @@ node = StandardNode(args.node_id, None, inference_engine, discovery, partitionin
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 
+api = ChatGPTAPI(node)
 
 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""
@@ -54,6 +58,7 @@ async def main():
         loop.add_signal_handler(s, handle_exit)
 
     await node.start(wait_for_peers=args.wait_for_peers)
+    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
 
     await asyncio.Event().wait()