
cleaner chatgpt api impl with async callbacks

Alex Cheema · 9 months ago
commit eb92da2c3e

+ 1 - 1
exo/__init__.py

@@ -1 +1 @@
-from exo.helpers import DEBUG
+from exo.helpers import DEBUG, DEBUG_DISCOVERY

+ 46 - 42
exo/api/chatgpt_api.py

@@ -65,6 +65,7 @@ class ChatGPTAPI:
         self.app = web.Application()
         self.app.router.add_post('/v1/chat/completions', self.handle_post)
         self.inference_engine_classname = inference_engine_classname
+        self.response_timeout_secs = 90
 
     async def handle_post(self, request):
         data = await request.json()
@@ -84,49 +85,52 @@ class ChatGPTAPI:
 
         if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
         try:
-            result = await self.node.process_prompt(shard, prompt, request_id=request_id)
+            await self.node.process_prompt(shard, prompt, request_id=request_id)
         except Exception as e:
-            pass # TODO
-            # return web.json_response({'detail': str(e)}, status=500)
-
-        # poll for the response. TODO: implement callback for specific request id
-        timeout = 90
-        start_time = time.time()
-        while time.time() - start_time < timeout:
-            try:
-                result, is_finished = await self.node.get_inference_result(request_id)
-            except Exception as e:
-                continue
-            await asyncio.sleep(0.1)
-            if is_finished:
-                eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
-                if DEBUG >= 2: print(f"Checking if end of result {result[-1]=} is {eos_token_id=}")
-                if result[-1] == eos_token_id:
-                    result = result[:-1]
-                return web.json_response({
-                    "id": f"chatcmpl-{request_id}",
-                    "object": "chat.completion",
-                    "created": int(time.time()),
-                    "model": chat_request.model,
-                    "usage": {
-                        "prompt_tokens": len(tokenizer.encode(prompt)),
-                        "completion_tokens": len(result),
-                        "total_tokens": len(tokenizer.encode(prompt)) + len(result)
-                    },
-                    "choices": [
-                        {
-                            "message": {
-                                "role": "assistant",
-                                "content": tokenizer.decode(result)
-                            },
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                            "index": 0
-                        }
-                    ]
-                })
-
-        return web.json_response({'detail': "Response generation timed out"}, status=408)
+            if DEBUG >= 2:
+                import traceback
+                traceback.print_exc()
+            return web.json_response({'detail': f"Error processing prompt (see logs): {str(e)}"}, status=500)
+
+        callback_id = f"chatgpt-api-wait-response-{request_id}"
+        callback = self.node.on_token.register(callback_id)
+
+        try:
+            if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")
+            _, result, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=self.response_timeout_secs)
+
+            eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+            if DEBUG >= 2: print(f"Checking if end of result {result[-1]=} is {eos_token_id=}")
+            if result[-1] == eos_token_id:
+                result = result[:-1]
+
+            return web.json_response({
+                "id": f"chatcmpl-{request_id}",
+                "object": "chat.completion",
+                "created": int(time.time()),
+                "model": chat_request.model,
+                "usage": {
+                    "prompt_tokens": len(tokenizer.encode(prompt)),
+                    "completion_tokens": len(result),
+                    "total_tokens": len(tokenizer.encode(prompt)) + len(result)
+                },
+                "choices": [
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": tokenizer.decode(result)
+                        },
+                        "logprobs": None,
+                        "finish_reason": "stop",
+                        "index": 0
+                    }
+                ]
+            })
+        except asyncio.TimeoutError:
+            return web.json_response({'detail': "Response generation timed out"}, status=408)
+        finally:
+            deregistered_callback = self.node.on_token.deregister(callback_id)
+            if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
 
     async def run(self, host: str = "0.0.0.0", port: int = 8000):
         runner = web.AppRunner(self.app)
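
With the polling loop gone, the handler registers a per-request callback, awaits a result whose request_id matches and whose is_finished flag is set, and always deregisters in the finally block. A minimal client sketch against a locally running instance (the default port 8000 comes from run() above; the model name is an illustrative assumption, not taken from this commit):

    import json
    import urllib.request

    # Hypothetical request; "llama-3-8b" is a placeholder model name.
    payload = {
        "model": "llama-3-8b",
        "messages": [{"role": "user", "content": "Hello"}],
    }
    req = urllib.request.Request(
        "http://localhost:8000/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        print(json.loads(resp.read())["choices"][0]["message"]["content"])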

+ 54 - 0
exo/helpers.py

@@ -1,3 +1,57 @@
 import os
+import asyncio
+from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
+DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
+
+T = TypeVar('T')
+K = TypeVar('K')
+
+class AsyncCallback(Generic[T]):
+    def __init__(self) -> None:
+        self.condition: asyncio.Condition = asyncio.Condition()
+        self.result: Optional[Tuple[T, ...]] = None
+        self.observers: list[Callable[..., None]] = []
+
+    async def wait(self,
+                   check_condition: Callable[..., bool],
+                   timeout: Optional[float] = None) -> Tuple[T, ...]:
+        async with self.condition:
+            await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout)
+            assert self.result is not None  # for type checking
+            return self.result
+
+    def on_next(self, callback: Callable[..., None]) -> None:
+        self.observers.append(callback)
+
+    def set(self, *args: T) -> None:
+        self.result = args
+        for observer in self.observers:
+            observer(*args)
+        asyncio.create_task(self.notify())
+
+    async def notify(self) -> None:
+        async with self.condition:
+            self.condition.notify_all()
+
+class AsyncCallbackSystem(Generic[K, T]):
+    def __init__(self) -> None:
+        self.callbacks: Dict[K, AsyncCallback[T]] = {}
+
+    def register(self, name: K) -> AsyncCallback[T]:
+        if name not in self.callbacks:
+            self.callbacks[name] = AsyncCallback[T]()
+        return self.callbacks[name]
+
+    def deregister(self, name: K) -> None:
+        if name in self.callbacks:
+            del self.callbacks[name]
+
+    def trigger(self, name: K, *args: T) -> None:
+        if name in self.callbacks:
+            self.callbacks[name].set(*args)
+
+    def trigger_all(self, *args: T) -> None:
+        for callback in self.callbacks.values():
+            callback.set(*args)
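
Two properties of AsyncCallback are worth noting: wait() evaluates its predicate against the stored result immediately, so a waiter that arrives after set() has already fired still completes; and set() schedules notify() with asyncio.create_task, so it must be called while an event loop is running. A minimal sketch under those assumptions (names are illustrative):

    import asyncio
    from exo.helpers import AsyncCallback

    async def demo() -> None:
        cb: AsyncCallback[int] = AsyncCallback()
        # set() must run inside the loop because it creates a notify() task.
        asyncio.get_running_loop().call_later(0.1, cb.set, 1, 2)
        a, b = await cb.wait(lambda x, y: x + y == 3, timeout=1)
        print(a, b)  # 1 2

    asyncio.run(demo())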

+ 10 - 10
exo/networking/grpc/grpc_discovery.py

@@ -7,7 +7,7 @@ from ..discovery import Discovery
 from ..peer_handle import PeerHandle
 from .grpc_peer_handle import GRPCPeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities
-from exo import DEBUG
+from exo import DEBUG_DISCOVERY
 
 class GRPCDiscovery(Discovery):
     def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities=None):
@@ -39,26 +39,26 @@ class GRPCDiscovery(Discovery):
             await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)
 
     async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-        if DEBUG >= 2: print("Starting peer discovery process...")
+        if DEBUG_DISCOVERY >= 2: print("Starting peer discovery process...")
 
         if wait_for_peers > 0:
             while len(self.known_peers) == 0:
-                if DEBUG >= 2: print("No peers discovered yet, retrying in 1 second...")
+                if DEBUG_DISCOVERY >= 2: print("No peers discovered yet, retrying in 1 second...")
                 await asyncio.sleep(1)  # Keep trying to find peers
-            if DEBUG >= 2: print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")
+            if DEBUG_DISCOVERY >= 2: print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")
 
         grace_period = 5  # seconds
         while True:
             initial_peer_count = len(self.known_peers)
-            if DEBUG >= 2: print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...")
+            if DEBUG_DISCOVERY >= 2: print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...")
             await asyncio.sleep(grace_period)
             if len(self.known_peers) == initial_peer_count:
                 if wait_for_peers > 0:
-                    if DEBUG >= 2: print(f"Waiting additional {wait_for_peers} seconds for more peers.")
+                    if DEBUG_DISCOVERY >= 2: print(f"Waiting additional {wait_for_peers} seconds for more peers.")
                     await asyncio.sleep(wait_for_peers)
                     wait_for_peers = 0
                 else:
-                    if DEBUG >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
+                    if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
                     break  # No new peers found in the grace period, we are done
 
         return list(self.known_peers.values())
@@ -94,7 +94,7 @@ class GRPCDiscovery(Discovery):
             try:
                 data, addr = await asyncio.get_event_loop().sock_recvfrom(sock, 1024)
                 message = json.loads(data.decode('utf-8'))
-                if DEBUG >= 2: print(f"received from peer {addr}: {message}")
+                if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}")
                 if message['type'] == 'discovery' and message['node_id'] != self.node_id:
                     peer_id = message['node_id']
                     peer_host = addr[0]
@@ -102,7 +102,7 @@ class GRPCDiscovery(Discovery):
                     device_capabilities = DeviceCapabilities(**message['device_capabilities'])
                     if peer_id not in self.known_peers:
                         self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
-                        if DEBUG >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
+                        if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
                     self.peer_last_seen[peer_id] = time.time()
             except Exception as e:
                 print(f"Error in peer discovery: {e}")
@@ -118,5 +118,5 @@ class GRPCDiscovery(Discovery):
             for peer_id in peers_to_remove:
                 del self.known_peers[peer_id]
                 del self.peer_last_seen[peer_id]
-                if DEBUG >= 2: print(f"Removed peer {peer_id} due to inactivity.")
+                if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
             await asyncio.sleep(self.broadcast_interval)
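
Discovery logging now has its own switch: setting the DEBUG_DISCOVERY environment variable (for example, DEBUG_DISCOVERY=2 before launching main.py) raises discovery verbosity independently of the general DEBUG level.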

+ 7 - 1
exo/orchestration/node.py

@@ -1,6 +1,7 @@
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List, Callable
 import numpy as np
 from abc import ABC, abstractmethod
+from exo.helpers import AsyncCallbackSystem
 from exo.inference.shard import Shard
 from exo.topology.topology import Topology
 
@@ -36,3 +37,8 @@ class Node(ABC):
     @abstractmethod
     async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:
         pass
+
+    @property
+    @abstractmethod
+    def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
+        pass
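
The generic parameters pin down the contract for implementers: callbacks are keyed by a str name, and each trigger delivers a (request_id, tokens, is_finished) tuple.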

+ 15 - 6
exo/orchestration/standard_node.py

@@ -8,11 +8,12 @@ from exo.topology.device_capabilities import device_capabilities
 from exo.topology.partitioning_strategy import PartitioningStrategy
 from exo.topology.partitioning_strategy import Partition
 from exo import DEBUG
+from exo.helpers import AsyncCallback, AsyncCallbackSystem
 import asyncio
 import uuid
 
 class StandardNode(Node):
-    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 256):
+    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256):
         self.id = id
         self.inference_engine = inference_engine
         self.server = server
@@ -22,7 +23,7 @@ class StandardNode(Node):
         self.topology: Topology = Topology()
         self.device_capabilities = device_capabilities()
         self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-        self.on_token = on_token
+        self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
         self.max_generate_tokens = max_generate_tokens
 
     async def start(self, wait_for_peers: int = 0) -> None:
@@ -56,14 +57,14 @@ class StandardNode(Node):
 
         if result.size == 1:
             self.buffered_token_output[request_id][0].append(result.item())
-            self.on_token(self.buffered_token_output[request_id][0])
+            self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
 
-        if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
+        if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
 
         if not is_finished:
             asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
 
-        return np.array(self.buffered_token_output[request_id]) if len(self.buffered_token_output[request_id]) > 0 else None
+        return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
 
     async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
         if request_id is None:
@@ -80,7 +81,7 @@ class StandardNode(Node):
 
             if result.size == 1:  # we got a new token out
                 self.buffered_token_output[request_id][0].append(result.item())
-                self.on_token(self.buffered_token_output[request_id][0])
+                self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
             if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id])}")
 
             if not is_finished:
@@ -227,3 +228,11 @@ class StandardNode(Node):
                 await peer.global_reset(base_shard, visited, max_depth = max_depth - 1)
             except Exception as e:
                 print(f"Error collecting topology from {peer.id()}: {e}")
+
+    @property
+    def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
+        return self._on_token
+
+    def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
+        if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
+        self.on_token.trigger_all(request_id, tokens, is_finished)
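
Since trigger_all fans out to every registered callback, each consumer filters on request_id itself, as the ChatGPT API handler does above. A sketch of a standalone observer (given an existing StandardNode instance node; the callback name is arbitrary):

    # Illustrative consumer of the new on_token callback system.
    callback = node.on_token.register("token-printer")
    callback.on_next(
        lambda request_id, tokens, is_finished:
            print(f"{request_id}: {len(tokens)} tokens, finished={is_finished}")
    )
    # ...
    node.on_token.deregister("token-printer")  # when done observing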

+ 47 - 0
exo/test_callbacks.py

@@ -0,0 +1,47 @@
+import asyncio
+from typing import Any, Callable
+from exo.helpers import AsyncCallbackSystem, AsyncCallback
+
+# Usage example
+async def main() -> None:
+    callback_system = AsyncCallbackSystem[str, Any]()
+
+    # Register callbacks
+    callback1 = callback_system.register("callback1")
+    callback2 = callback_system.register("callback2")
+
+    def on_next_callback(name: str) -> Callable[..., None]:
+        def callback(*args: Any) -> None:
+            print(f"{name} received values: {args}")
+        return callback
+
+    callback1.on_next(on_next_callback("Callback1"))
+    callback2.on_next(on_next_callback("Callback2"))
+
+    async def wait_for_callback(name: str, callback: AsyncCallback[Any], condition: Callable[..., bool]) -> None:
+        try:
+            result = await callback.wait(condition, timeout=2)
+            print(f"{name} wait completed with result: {result}")
+        except asyncio.TimeoutError:
+            print(f"{name} wait timed out")
+
+    # Trigger all callbacks at once
+    callback_system.trigger_all("Hello", 42, True)
+
+    # Wait for all callbacks with different conditions
+    await asyncio.gather(
+        wait_for_callback("Callback1", callback1, lambda msg, num, flag: isinstance(msg, str) and num > 0),
+        wait_for_callback("Callback2", callback2, lambda msg, num, flag: flag is True)
+    )
+
+    # Trigger individual callback
+    callback_system.trigger("callback2", "World", -10, False)
+
+    # Demonstrate timeout
+    new_callback = callback_system.register("new_callback")
+    new_callback.on_next(on_next_callback("NewCallback"))
+    await wait_for_callback("NewCallback", new_callback, lambda msg, num, flag: num > 100)
+
+    callback_system.trigger("callback2", "World", 200, False)
+
+asyncio.run(main())
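
As written, the script prints both callbacks receiving ("Hello", 42, True), both waits completing with that tuple, Callback2 receiving ("World", -10, False), and NewCallback timing out after 2 seconds; the final trigger targets "callback2" rather than "new_callback", so it only produces one more Callback2 print.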

+ 3 - 5
main.py

@@ -30,16 +30,14 @@ else:
     from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
     inference_engine = TinygradDynamicShardInferenceEngine()
 
-def on_token(tokens: List[int]):
-    if inference_engine.tokenizer:
-        print(inference_engine.tokenizer.decode(tokens))
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
-node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), on_token=on_token)
+node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy())
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
-
 api = ChatGPTAPI(node, inference_engine.__class__.__name__)
 
+node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if inference_engine.tokenizer else tokens))
+
 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""
     print(f"Received exit signal {signal.name}...")