Преглед изворни кода

fix treating token as a list

Alex Cheema пре 3 месеци
родитељ
комит
9954ce8e4d

+ 4 - 4
exo/api/chatgpt_api.py

@@ -408,16 +408,16 @@ class ChatGPTAPI:
           # Stream tokens while waiting for inference to complete
           while True:
             if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
-            token, is_finished = await asyncio.wait_for(
+            tokens, is_finished = await asyncio.wait_for(
               self.token_queues[request_id].get(),
               timeout=self.response_timeout
             )
-            if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {token=} {is_finished=}")
+            if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {tokens=} {is_finished=}")
 
             finish_reason = None
             eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None)
 
-            if token == eos_token_id:
+            if tokens[-1] == eos_token_id:
               if is_finished:
                 finish_reason = "stop"
             if is_finished and not finish_reason:
@@ -428,7 +428,7 @@ class ChatGPTAPI:
               tokenizer,
               prompt,
               request_id,
-              [token],
+              tokens,
               stream,
               finish_reason,
               "chat.completion",

+ 1 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -123,7 +123,7 @@ class GRPCPeerHandle(PeerHandle):
       request_id=request_id,
       inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
     )
-    response =await self.stub.SendTensor(request)
+    response = await self.stub.SendTensor(request)
 
     if not response.tensor_data or not response.shape or not response.dtype:
       return None

+ 4 - 4
exo/networking/grpc/node_service.proto

@@ -3,11 +3,11 @@ syntax = "proto3";
 package node_service;
 
 service NodeService {
-  rpc SendPrompt (PromptRequest) returns (Empty) {}
-  rpc SendTensor (TensorRequest) returns (Empty) {}
+  rpc SendPrompt (PromptRequest) returns (Tensor) {}
+  rpc SendTensor (TensorRequest) returns (Tensor) {}
   rpc SendExample (ExampleRequest) returns (Loss) {}
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
-  rpc SendNewToken (SendNewTokenRequest) returns (Empty) {}
+  rpc SendResult (SendResultRequest) returns (Empty) {}
   rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
   rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {}
 }
@@ -95,7 +95,7 @@ message DeviceCapabilities {
   DeviceFlops flops = 4;
 }
 
-message SendNewTokenRequest {
+message SendResultRequest {
   string request_id = 1;
   repeated int32 result = 2;
   optional Tensor tensor = 3;

Разлика између датотеке није приказан због своје велике величине
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 16 - 16
exo/networking/grpc/node_service_pb2_grpc.py

@@ -37,12 +37,12 @@ class NodeServiceStub(object):
         self.SendPrompt = channel.unary_unary(
                 '/node_service.NodeService/SendPrompt',
                 request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
                 _registered_method=True)
         self.SendTensor = channel.unary_unary(
                 '/node_service.NodeService/SendTensor',
                 request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
                 _registered_method=True)
         self.SendExample = channel.unary_unary(
                 '/node_service.NodeService/SendExample',
@@ -54,9 +54,9 @@ class NodeServiceStub(object):
                 request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Topology.FromString,
                 _registered_method=True)
-        self.SendNewToken = channel.unary_unary(
-                '/node_service.NodeService/SendNewToken',
-                request_serializer=node__service__pb2.SendNewTokenRequest.SerializeToString,
+        self.SendResult = channel.unary_unary(
+                '/node_service.NodeService/SendResult',
+                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.SendOpaqueStatus = channel.unary_unary(
@@ -98,7 +98,7 @@ class NodeServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
-    def SendNewToken(self, request, context):
+    def SendResult(self, request, context):
         """Missing associated documentation comment in .proto file."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
         context.set_details('Method not implemented!')
@@ -122,12 +122,12 @@ def add_NodeServiceServicer_to_server(servicer, server):
             'SendPrompt': grpc.unary_unary_rpc_method_handler(
                     servicer.SendPrompt,
                     request_deserializer=node__service__pb2.PromptRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
             ),
             'SendTensor': grpc.unary_unary_rpc_method_handler(
                     servicer.SendTensor,
                     request_deserializer=node__service__pb2.TensorRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
             ),
             'SendExample': grpc.unary_unary_rpc_method_handler(
                     servicer.SendExample,
@@ -139,9 +139,9 @@ def add_NodeServiceServicer_to_server(servicer, server):
                     request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
                     response_serializer=node__service__pb2.Topology.SerializeToString,
             ),
-            'SendNewToken': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendNewToken,
-                    request_deserializer=node__service__pb2.SendNewTokenRequest.FromString,
+            'SendResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendResult,
+                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
                     response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
             'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
@@ -181,7 +181,7 @@ class NodeService(object):
             target,
             '/node_service.NodeService/SendPrompt',
             node__service__pb2.PromptRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
+            node__service__pb2.Tensor.FromString,
             options,
             channel_credentials,
             insecure,
@@ -208,7 +208,7 @@ class NodeService(object):
             target,
             '/node_service.NodeService/SendTensor',
             node__service__pb2.TensorRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
+            node__service__pb2.Tensor.FromString,
             options,
             channel_credentials,
             insecure,
@@ -274,7 +274,7 @@ class NodeService(object):
             _registered_method=True)
 
     @staticmethod
-    def SendNewToken(request,
+    def SendResult(request,
             target,
             options=(),
             channel_credentials=None,
@@ -287,8 +287,8 @@ class NodeService(object):
         return grpc.experimental.unary_unary(
             request,
             target,
-            '/node_service.NodeService/SendNewToken',
-            node__service__pb2.SendNewTokenRequest.SerializeToString,
+            '/node_service.NodeService/SendResult',
+            node__service__pb2.SendResultRequest.SerializeToString,
             node__service__pb2.Empty.FromString,
             options,
             channel_credentials,

Неке датотеке нису приказане због велике количине промена