Browse Source

support chatgpt api endpoint from any node #24

Alex Cheema cách đây 10 tháng
mục cha
commit
8a35fd83f6

+ 5 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -1,6 +1,6 @@
 import grpc
 import numpy as np
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
 
 # These would be generated from the .proto file
 from . import node_service_pb2
@@ -92,3 +92,7 @@ class GRPCPeerHandle(PeerHandle):
     async def global_reset(self, base_shard: Shard, visited: set[str], max_depth: int) -> None:
         request = node_service_pb2.GlobalResetRequest(base_shard=node_service_pb2.Shard(model_id=base_shard.model_id, start_layer=base_shard.start_layer, end_layer=base_shard.end_layer, n_layers=base_shard.n_layers), visited=visited, max_depth=max_depth)
         await self.stub.GlobalReset(request)
+
+    async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+        request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
+        await self.stub.SendResult(request)

+ 10 - 0
exo/networking/grpc/grpc_server.py

@@ -70,11 +70,21 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         topology = await self.node.collect_topology(visited, max_depth)
         nodes = {node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory) for node_id, cap in topology.nodes.items()}
         peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
+        if DEBUG >= 2: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
         return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
 
     async def GlobalReset(self, request, context):
         base_shard = Shard(model_id=request.base_shard.model_id, start_layer=request.base_shard.start_layer, end_layer=request.base_shard.end_layer, n_layers=request.base_shard.n_layers)
         visited = set(request.visited)
         max_depth = request.max_depth
+        if DEBUG >= 2: print(f"Received GlobalReset request: {base_shard=} {visited=} {max_depth=}")
         await self.node.global_reset(base_shard, visited, max_depth)
         return node_service_pb2.Empty()
+
+    async def SendResult(self, request, context):
+        request_id = request.request_id
+        result = request.result
+        is_finished = request.is_finished
+        if DEBUG >= 2: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
+        self.node.on_token.trigger_all(request_id, result, is_finished)
+        return node_service_pb2.Empty()

+ 7 - 0
exo/networking/grpc/node_service.proto

@@ -9,6 +9,7 @@ service NodeService {
   rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
   rpc GlobalReset (GlobalResetRequest) returns (Empty) {}
+  rpc SendResult (SendResultRequest) returns (Empty) {}
 }
 
 message Shard {
@@ -77,4 +78,10 @@ message DeviceCapabilities {
   int32 memory = 3;
 }
 
// Payload for NodeService.SendResult: pushes generated tokens for a request
// back to a peer node.
message SendResultRequest {
  string request_id = 1;     // correlates the result with the originating inference request
  repeated int32 result = 2; // presumably generated token ids — confirm against StandardNode.buffered_token_output
  bool is_finished = 3;      // true once generation for request_id is complete
}
+
 message Empty {}

Những thay đổi đã bị hủy bỏ vì nó quá lớn
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 43 - 0
exo/networking/grpc/node_service_pb2_grpc.py

@@ -69,6 +69,11 @@ class NodeServiceStub(object):
                 request_serializer=node__service__pb2.GlobalResetRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
+        self.SendResult = channel.unary_unary(
+                '/node_service.NodeService/SendResult',
+                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
 
 
 class NodeServiceServicer(object):
@@ -110,6 +115,12 @@ class NodeServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
+    def SendResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
 
 def add_NodeServiceServicer_to_server(servicer, server):
     rpc_method_handlers = {
@@ -143,6 +154,11 @@ def add_NodeServiceServicer_to_server(servicer, server):
                     request_deserializer=node__service__pb2.GlobalResetRequest.FromString,
                     response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
+            'SendResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendResult,
+                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
             'node_service.NodeService', rpc_method_handlers)
@@ -315,3 +331,30 @@ class NodeService(object):
             timeout,
             metadata,
             _registered_method=True)
+
    @staticmethod
    def SendResult(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        # Auto-generated one-shot client helper (grpc.experimental API) for the
        # unary-unary SendResult RPC; do not edit by hand — regenerate from the .proto.
        return grpc.experimental.unary_unary(
            request,
            target,
            '/node_service.NodeService/SendResult',
            node__service__pb2.SendResultRequest.SerializeToString,
            node__service__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

+ 5 - 1
exo/networking/peer_handle.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
 import numpy as np
 from exo.inference.shard import Shard
 from exo.topology.device_capabilities import DeviceCapabilities
@@ -49,3 +49,7 @@ class PeerHandle(ABC):
    @abstractmethod
    async def global_reset(self, base_shard: Shard, visited: set[str], max_depth: int) -> None:
        """Ask this peer to perform a global reset rooted at `base_shard`.

        `visited` carries node ids already reached so the reset is not
        re-propagated to them; `max_depth` bounds further propagation.
        """
        pass
+
    @abstractmethod
    async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
        """Send the tokens generated for `request_id` to this peer.

        `result` is the token ids produced so far; `is_finished` marks the end
        of generation for the request.
        """
        pass

+ 11 - 0
exo/orchestration/standard_node.py

@@ -54,6 +54,7 @@ class StandardNode(Node):
         is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
         if is_finished:
             self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+            asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
 
         if result.size == 1:
             self.buffered_token_output[request_id][0].append(result.item())
@@ -78,6 +79,7 @@ class StandardNode(Node):
             is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
             if is_finished:
                 self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+                asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
 
             if result.size == 1:  # we got a new token out
                 self.buffered_token_output[request_id][0].append(result.item())
@@ -236,3 +238,12 @@ class StandardNode(Node):
     def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
         if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
         self.on_token.trigger_all(request_id, tokens, is_finished)
+
+    async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+        for peer in self.peers:
+            try:
+                await peer.send_result(request_id, result, is_finished)
+            except Exception as e:
+                import traceback
+                traceback.print_exc()
+                print(f"Error broadcasting result to {peer.id()}: {e}")

Một số tệp đã không được hiển thị bởi vì quá nhiều tập tin thay đổi trong này khác