
one token at a time

Alex Cheema, 7 months ago
commit f55a53ae7e

exo/api/chatgpt_api.py  +11 -16

@@ -367,15 +367,11 @@ class ChatGPTAPI:
         )
         await response.prepare(request)
 
-        async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
-          prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
-          self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
-          new_tokens = tokens[prev_last_tokens_len:]
+        async def stream_result(_request_id: str, token: int, is_finished: bool):
           finish_reason = None
           eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
                                                                                                                              AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
-          if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
-            new_tokens = new_tokens[:-1]
+          if token == eos_token_id:
             if is_finished:
               finish_reason = "stop"
           if is_finished and not finish_reason:
@@ -386,7 +382,7 @@ class ChatGPTAPI:
             tokenizer,
             prompt,
             request_id,
-            new_tokens,
+            [token],
             stream,
             finish_reason,
             "chat.completion",
@@ -398,12 +394,12 @@ class ChatGPTAPI:
             if DEBUG >= 2: print(f"Error streaming completion: {e}")
             if DEBUG >= 2: traceback.print_exc()
 
-        def on_result(_request_id: str, tokens: List[int], is_finished: bool):
-          if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))
+        def on_result(_request_id: str, token: int, is_finished: bool):
+          if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, token, is_finished))
 
           return _request_id == request_id and is_finished
 
-        _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout)
+        _, token, _ = await callback.wait(on_result, timeout=self.response_timeout)
         if request_id in self.stream_tasks:  # in case there is still a stream task running, wait for it to complete
           if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
           try:
@@ -413,19 +409,18 @@ class ChatGPTAPI:
         await response.write_eof()
         return response
       else:
-        _, tokens, _ = await callback.wait(
-          lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
+        _, token, _ = await callback.wait(
+          lambda _request_id, token, is_finished: _request_id == request_id and is_finished,
           timeout=self.response_timeout,
         )
 
         finish_reason = "length"
         eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
-        if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
-        if tokens[-1] == eos_token_id:
-          tokens = tokens[:-1]
+        if DEBUG >= 2: print(f"Checking if end of tokens result {token=} is {eos_token_id=}")
+        if token == eos_token_id:
           finish_reason = "stop"
 
-        return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
+        return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, [token], stream, finish_reason, "chat.completion"))
     except asyncio.TimeoutError:
       return web.json_response({"detail": "Response generation timed out"}, status=408)
     except Exception as e:
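
Note: with this change the streaming path receives exactly one token per callback, so the prev_token_lens bookkeeping and the token-list slicing disappear. A minimal sketch of the resulting control flow, assuming a HuggingFace-style tokenizer; stream_one_token and emit_chunk are illustrative names, not exo API:

from typing import Awaitable, Callable, Optional

async def stream_one_token(tokenizer, token: int, is_finished: bool,
                           emit_chunk: Callable[[str, Optional[str]], Awaitable[None]]) -> None:
  # Mirrors the new stream_result: one token in, one SSE chunk out.
  finish_reason: Optional[str] = None
  eos_token_id = getattr(tokenizer, "eos_token_id", None)
  if token == eos_token_id and is_finished:
    finish_reason = "stop"    # generation ended on the EOS token
  elif is_finished:
    finish_reason = "length"  # stopped by the max-token limit instead
  await emit_chunk(tokenizer.decode([token]), finish_reason)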

exo/main.py  +1 -1

@@ -152,7 +152,7 @@ api = ChatGPTAPI(
   default_model=args.default_model
 )
 node.on_token.register("update_topology_viz").on_next(
-  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
+  lambda req_id, token, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode([token])) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )
 
 def preemptively_start_download(request_id: str, opaque_status: str):
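
Since on_token subscribers now see one token per event, anything that wants the full text has to accumulate it per request. A small hedged sketch of such a subscriber; the buffering shown here is illustrative and not part of this commit:

from collections import defaultdict

def make_on_token(tokenizer):
  # Accumulates decoded text per request_id; assumes a decode(list[int]) tokenizer.
  buffers: dict[str, str] = defaultdict(str)

  def on_token(request_id: str, token: int, is_finished: bool) -> None:
    buffers[request_id] += tokenizer.decode([token])
    if is_finished:
      print(f"[{request_id}] {buffers.pop(request_id)}")

  return on_token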

exo/networking/grpc/grpc_peer_handle.py  +3 -13

@@ -147,16 +147,6 @@ class GRPCPeerHandle(PeerHandle):
 
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
-  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-    request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
-    response = await self.stub.GetInferenceResult(request)
-    if response.tensor is None:
-      return None, response.is_finished
-    return (
-      np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape),
-      response.is_finished,
-    )
-
   async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
     request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
     response = await self.stub.CollectTopology(request)
@@ -174,9 +164,9 @@ class GRPCPeerHandle(PeerHandle):
         topology.add_edge(node_id, conn.to_id, conn.description)
     return topology
 
-  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-    request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
-    await self.stub.SendResult(request)
+  async def send_new_token(self, request_id: str, token: int, is_finished: bool) -> None:
+    request = node_service_pb2.SendNewTokenRequest(request_id=request_id, token=token, is_finished=is_finished)
+    await self.stub.SendNewToken(request)
 
   async def send_opaque_status(self, request_id: str, status: str) -> None:
     request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
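
The new send_new_token is a thin wrapper over a unary RPC. For reference, a hedged sketch of the equivalent call against the raw generated stub; the channel setup and address are illustrative, and the import path is inferred from the file layout in this diff:

from grpc import aio
from exo.networking.grpc import node_service_pb2, node_service_pb2_grpc

async def send_new_token_raw(address: str, request_id: str, token: int, is_finished: bool) -> None:
  # One token id per RPC, matching the new SendNewTokenRequest message.
  async with aio.insecure_channel(address) as channel:
    stub = node_service_pb2_grpc.NodeServiceStub(channel)
    request = node_service_pb2.SendNewTokenRequest(request_id=request_id, token=token, is_finished=is_finished)
    await stub.SendNewToken(request)

# e.g. asyncio.run(send_new_token_raw("192.168.1.10:50051", "req-1", 42, False))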

exo/networking/grpc/node_service.proto  +3 -13

@@ -6,9 +6,8 @@ service NodeService {
   rpc SendPrompt (PromptRequest) returns (Tensor) {}
   rpc SendTensor (TensorRequest) returns (Tensor) {}
   rpc SendExample (ExampleRequest) returns (Loss) {}
-  rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
-  rpc SendResult (SendResultRequest) returns (Empty) {}
+  rpc SendNewToken (SendNewTokenRequest) returns (Empty) {}
   rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
   rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {}
 }
@@ -45,15 +44,6 @@ message Loss {
   float loss = 1;
   optional Tensor grads = 2;
 }
-  
-message GetInferenceResultRequest {
-  string request_id = 1;
-}
-
-message InferenceResult {
-  optional Tensor tensor = 1;
-  bool is_finished = 2;
-}
 
 message Tensor {
   bytes tensor_data = 1;
@@ -93,9 +83,9 @@ message DeviceCapabilities {
   DeviceFlops flops = 4;
 }
 
-message SendResultRequest {
+message SendNewTokenRequest {
   string request_id = 1;
-  repeated int32 result = 2;
+  int32 token = 2;
   bool is_finished = 3;
 }
 

exo/networking/grpc/node_service_pb2.py  (diff suppressed because it is too large)
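
node_service_pb2.py and node_service_pb2_grpc.py are generated from node_service.proto, which is why they change alongside the .proto edit above. A hedged sketch of regenerating them with grpc_tools, assuming it is run from exo/networking/grpc; exo's actual build step may differ:

from grpc_tools import protoc

# Regenerate the gRPC stubs after editing node_service.proto.
protoc.main([
  "grpc_tools.protoc",   # argv[0] placeholder expected by protoc.main
  "-I.",
  "--python_out=.",
  "--grpc_python_out=.",
  "node_service.proto",
])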


exo/networking/grpc/node_service_pb2_grpc.py  +10 -53

@@ -49,19 +49,14 @@ class NodeServiceStub(object):
                 request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Loss.FromString,
                 _registered_method=True)
-        self.GetInferenceResult = channel.unary_unary(
-                '/node_service.NodeService/GetInferenceResult',
-                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.InferenceResult.FromString,
-                _registered_method=True)
         self.CollectTopology = channel.unary_unary(
                 '/node_service.NodeService/CollectTopology',
                 request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Topology.FromString,
                 _registered_method=True)
-        self.SendResult = channel.unary_unary(
-                '/node_service.NodeService/SendResult',
-                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+        self.SendNewToken = channel.unary_unary(
+                '/node_service.NodeService/SendNewToken',
+                request_serializer=node__service__pb2.SendNewTokenRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.SendOpaqueStatus = channel.unary_unary(
@@ -97,19 +92,13 @@ class NodeServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
-    def GetInferenceResult(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
     def CollectTopology(self, request, context):
         """Missing associated documentation comment in .proto file."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
-    def SendResult(self, request, context):
+    def SendNewToken(self, request, context):
         """Missing associated documentation comment in .proto file."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
         context.set_details('Method not implemented!')
@@ -145,19 +134,14 @@ def add_NodeServiceServicer_to_server(servicer, server):
                     request_deserializer=node__service__pb2.ExampleRequest.FromString,
                     response_serializer=node__service__pb2.Loss.SerializeToString,
             ),
-            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.GetInferenceResult,
-                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
-                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
-            ),
             'CollectTopology': grpc.unary_unary_rpc_method_handler(
                     servicer.CollectTopology,
                     request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
                     response_serializer=node__service__pb2.Topology.SerializeToString,
             ),
-            'SendResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendResult,
-                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
+            'SendNewToken': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendNewToken,
+                    request_deserializer=node__service__pb2.SendNewTokenRequest.FromString,
                     response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
             'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
@@ -262,33 +246,6 @@ class NodeService(object):
             metadata,
             _registered_method=True)
 
-    @staticmethod
-    def GetInferenceResult(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/GetInferenceResult',
-            node__service__pb2.GetInferenceResultRequest.SerializeToString,
-            node__service__pb2.InferenceResult.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
     @staticmethod
     def CollectTopology(request,
             target,
@@ -317,7 +274,7 @@ class NodeService(object):
             _registered_method=True)
 
     @staticmethod
-    def SendResult(request,
+    def SendNewToken(request,
             target,
             options=(),
             channel_credentials=None,
@@ -330,8 +287,8 @@ class NodeService(object):
         return grpc.experimental.unary_unary(
             request,
             target,
-            '/node_service.NodeService/SendResult',
-            node__service__pb2.SendResultRequest.SerializeToString,
+            '/node_service.NodeService/SendNewToken',
+            node__service__pb2.SendNewTokenRequest.SerializeToString,
             node__service__pb2.Empty.FromString,
             options,
             channel_credentials,
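
The regenerated stub only declares SendNewToken; the receiving node still needs a servicer-side handler (exo's GrpcServer is among the files not shown in this truncated diff). A hypothetical sketch of such a handler, reusing trigger_on_token_callbacks from node.py below; the real wiring may differ:

from exo.networking.grpc import node_service_pb2, node_service_pb2_grpc

class NodeServiceServicer(node_service_pb2_grpc.NodeServiceServicer):
  # Hypothetical sketch only, not code from this commit.
  def __init__(self, node):
    self.node = node  # local orchestration Node

  async def SendNewToken(self, request, context):
    # Fan the single token into the local callback system.
    self.node.trigger_on_token_callbacks(request.request_id, request.token, request.is_finished)
    return node_service_pb2.Empty()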

exo/networking/peer_handle.py  +1 -5

@@ -48,11 +48,7 @@ class PeerHandle(ABC):
     pass
 
   @abstractmethod
-  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-    pass
-
-  @abstractmethod
-  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+  async def send_new_token(self, request_id: str, token: int, is_finished: bool) -> None:
     pass
 
   @abstractmethod
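
Every PeerHandle implementation now provides the per-token method in place of the old send_result/get_inference_result pair. A self-contained illustration of the new contract; TokenSink and RecordingSink are stand-ins for this sketch, not exo classes:

from typing import Protocol

class TokenSink(Protocol):
  async def send_new_token(self, request_id: str, token: int, is_finished: bool) -> None: ...

class RecordingSink:
  # Minimal conforming implementation, handy as a test double.
  def __init__(self) -> None:
    self.received: list[tuple[str, int, bool]] = []

  async def send_new_token(self, request_id: str, token: int, is_finished: bool) -> None:
    self.received.append((request_id, token, is_finished))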

exo/orchestration/node.py  +13 -19

@@ -47,7 +47,7 @@ class Node:
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
     self.default_sample_temperature = default_sample_temperature
-    self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
+    self._on_token = AsyncCallbackSystem[str, Tuple[str, int, bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
@@ -122,10 +122,9 @@ class Node:
       self.buffered_token_output[request_id][0].append(token.item())
       is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
       if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-      asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
       forward = token.reshape(1, -1)
-      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
+      self.trigger_on_token_callbacks(request_id, token.item(), is_finished)
+      asyncio.create_task(self.broadcast_new_token(request_id, token.item(), is_finished))
     else:
       forward = result
 
@@ -549,11 +548,6 @@ class Node:
         print(f"Error collecting topology: {e}")
         traceback.print_exc()
 
-  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-    if request_id not in self.buffered_token_output:
-      return None, False
-    return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
-
   async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topology:
     next_topology = Topology()
     next_topology.update_node(self.id, self.device_capabilities)
@@ -590,28 +584,28 @@ class Node:
     return self.topology
 
   @property
-  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
+  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, int, bool]]:
     return self._on_token
 
   @property
   def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
     return self._on_opaque_status
 
-  def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
-    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
-    self.on_token.trigger_all(request_id, tokens, is_finished)
+  def trigger_on_token_callbacks(self, request_id: str, token: int, is_finished: bool) -> None:
+    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {token=} {is_finished=}")
+    self.on_token.trigger_all(request_id, token, is_finished)
   
-  async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-    async def send_result_to_peer(peer):
+  async def broadcast_new_token(self, request_id: str, token: int, is_finished: bool) -> None:
+    async def send_new_token_to_peer(peer):
       try:
-        await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
+        await asyncio.wait_for(peer.send_new_token(request_id, token, is_finished), timeout=15.0)
       except asyncio.TimeoutError:
-        print(f"Timeout broadcasting result to {peer.id()}")
+        print(f"Timeout broadcasting new token to {peer.id()}")
       except Exception as e:
-        print(f"Error broadcasting result to {peer.id()}: {e}")
+        print(f"Error broadcasting new token to {peer.id()}: {e}")
         traceback.print_exc()
 
-    await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
+    await asyncio.gather(*[send_new_token_to_peer(peer) for peer in self.peers], return_exceptions=True)
 
   async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
     if DEBUG >= 8: print(f"Broadcasting opaque status: {request_id=} {status=}")
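
One motivation worth spelling out: broadcast_result re-sent the entire buffered token list after every step, so per-peer traffic grew quadratically with completion length, while broadcast_new_token sends each token id once. A rough back-of-the-envelope comparison, ignoring protobuf and transport overhead:

# Token ids sent to each peer over a single n-token completion.
n = 1000
old_ids_sent = n * (n + 1) // 2   # growing list re-broadcast after every token
new_ids_sent = n                  # one token per broadcast after this commit
print(old_ids_sent, new_ids_sent)  # 500500 vs 1000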

Some files were not shown because too many files changed in this diff