
Correct loss propagation so we can see the actual loss instead of just the requestor shard's loss

Nel Nibcord, 8 months ago
parent
commit 9283f6d7bd

+ 0 - 1
exo/api/chatgpt_api.py

@@ -153,7 +153,6 @@ class ChatGPTAPI:
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
-    print(self.inference_engine_classname)
     self.response_timeout = response_timeout
     self.on_chat_completion_request = on_chat_completion_request
     self.app = web.Application(client_max_size=100*1024*1024)  # 100MB to support image upload

+ 2 - 2
exo/inference/mlx/sharded_inference_engine.py

@@ -99,7 +99,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     l = mx.array(lengths)
     score = await loop.run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
     #print(f"evaluate out -> {score}")
-    return np.array(score)
+    return score
 
   async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
     await self.ensure_shard(shard)
@@ -134,7 +134,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
     #print(layers[0])
 
-    return np.array(score).reshape(1, -1), np.array(layers[0]['input_layernorm'])
+    return score, np.array(layers[0]['input_layernorm'])
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
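
Dropping the np.array() wrapper means evaluate (and train) now hand back the MLX scalar itself rather than a wrapped ndarray. A minimal self-contained sketch of what that implies for callers (the 0.37 is a stand-in value, not from the code above):

import mlx.core as mx
import numpy as np

# Before this change the engine wrapped the MLX scalar in np.array(score);
# now the mx.array comes back unwrapped. Both forms convert cleanly:
score = mx.array(0.37)        # stand-in for the loss evaluate() now returns
print(score.item())           # 0.37 -- callers needing a plain float are unaffected
print(np.array(score).shape)  # () -- the numpy view is still one call away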

+ 7 - 7
exo/networking/grpc/grpc_peer_handle.py

@@ -118,16 +118,16 @@ class GRPCPeerHandle(PeerHandle):
       example=node_service_pb2.Tensor(tensor_data=example.tobytes(), shape=example.shape, dtype=str(example.dtype)),
       target=node_service_pb2.Tensor(tensor_data=target.tobytes(), shape=target.shape, dtype=str(target.dtype)),
       length=node_service_pb2.Tensor(tensor_data=length.tobytes(), shape=length.shape, dtype=str(length.dtype)),
-      train = train,
+      train=train,
       request_id=request_id,
     )
     response = await self.stub.SendExample(request)
-
-    if not response.tensor_data or not response.shape or not response.dtype:
-      return None
-
-    out = np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
-    return out
+    loss = response.loss
+    if train and not shard.is_first_layer():
+      grads = np.frombuffer(response.grads.tensor_data, dtype=np.dtype(response.grads.dtype)).reshape(response.grads.shape)
+      return loss, grads
+    else:
+      return loss
   
   async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.TensorRequest(
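
On the peer-handle side, send_example now returns either the scalar loss alone or a (loss, grads) pair when training a non-first shard. A hedged sketch of a caller-side normalizer (the helper name is hypothetical, not part of exo's API):

# Hypothetical helper: normalize send_example's dual-shape return.
def unpack_example_response(resp):
    if isinstance(resp, tuple):
        loss, grads = resp        # mid-pipeline shard: gradients flow back
    else:
        loss, grads = resp, None  # first shard: only the scalar loss returns
    return loss, grads

print(unpack_example_response(0.42))           # (0.42, None)
print(unpack_example_response((0.42, [1.0])))  # (0.42, [1.0])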

+ 8 - 4
exo/networking/grpc/grpc_server.py

@@ -83,10 +83,14 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     train = request.train
     request_id = request.request_id
 
-    result = await self.node.process_example(shard, example, target, length, train, request_id)
-    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {example=} {target=} {length=} {request_id=} result: {result}")
-    tensor_data = result.tobytes()
-    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
+    if train and not shard.is_first_layer():
+      loss, grad = await self.node.process_example(shard, example, target, length, train, request_id)
+      tensor_data = grad.tobytes()
+      grad_tensor = node_service_pb2.Tensor(tensor_data=tensor_data, shape=grad.shape, dtype=str(grad.dtype))
+      return node_service_pb2.Loss(loss=loss, grads=grad_tensor)
+    else:
+      loss = await self.node.process_example(shard, example, target, length, train, request_id)
+      return node_service_pb2.Loss(loss=loss, grads=None)
     
   async def CollectTopology(self, request, context):
     max_depth = request.max_depth
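
The server mirrors the client's contract: gradients are attached to the Loss reply only on the train path for non-first shards. Since grads is declared as an optional message field, a receiver could also branch on field presence instead of re-deriving the condition from its own train flag; a sketch of that alternative (decode_loss_reply is a hypothetical helper, not code from this commit):

import numpy as np

# Sketch: proto3 message fields support presence checks, so a client can
# tell whether the server attached gradients before deserializing them.
def decode_loss_reply(response):
    grads = None
    if response.HasField('grads'):
        t = response.grads
        grads = np.frombuffer(t.tensor_data, dtype=np.dtype(t.dtype)).reshape(t.shape)
    return response.loss, grads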

+ 4 - 4
exo/networking/grpc/node_service.proto

@@ -5,7 +5,7 @@ package node_service;
 service NodeService {
   rpc SendPrompt (PromptRequest) returns (Tensor) {}
   rpc SendTensor (TensorRequest) returns (Tensor) {}
-  rpc SendExample (ExampleRequest) returns (Tensor) {}
+  rpc SendExample (ExampleRequest) returns (Loss) {}
   rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
   rpc SendResult (SendResultRequest) returns (Empty) {}
@@ -41,9 +41,9 @@ message ExampleRequest {
   optional string request_id = 6;
 }
 
-message LossRequest {
-  Tensor loss = 1;
-  Tensor grads = 2;
+message Loss {
+  float loss = 1;
+  optional Tensor grads = 2;
 }
   
 message GetInferenceResultRequest {
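
For reference, constructing the new message from Python looks roughly like this (a sketch; the import path is assumed from the file layout above, and the values are placeholders):

import numpy as np
from exo.networking.grpc import node_service_pb2

# The float travels on its own; the gradient tensor is attached only when a
# downstream shard needs it for backprop.
grad = np.zeros((1, 4), dtype=np.float32)
grads = node_service_pb2.Tensor(tensor_data=grad.tobytes(),
                                shape=grad.shape, dtype=str(grad.dtype))
with_grads = node_service_pb2.Loss(loss=0.42, grads=grads)
loss_only = node_service_pb2.Loss(loss=0.42)  # grads simply left unset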

File diff suppressed because it is too large
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 3 - 3
exo/networking/grpc/node_service_pb2_grpc.py

@@ -47,7 +47,7 @@ class NodeServiceStub(object):
         self.SendExample = channel.unary_unary(
                 '/node_service.NodeService/SendExample',
                 request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
+                response_deserializer=node__service__pb2.Loss.FromString,
                 _registered_method=True)
         self.GetInferenceResult = channel.unary_unary(
                 '/node_service.NodeService/GetInferenceResult',
@@ -143,7 +143,7 @@ def add_NodeServiceServicer_to_server(servicer, server):
             'SendExample': grpc.unary_unary_rpc_method_handler(
                     servicer.SendExample,
                     request_deserializer=node__service__pb2.ExampleRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+                    response_serializer=node__service__pb2.Loss.SerializeToString,
             ),
             'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
                     servicer.GetInferenceResult,
@@ -251,7 +251,7 @@ class NodeService(object):
             target,
             '/node_service.NodeService/SendExample',
             node__service__pb2.ExampleRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
+            node__service__pb2.Loss.FromString,
             options,
             channel_credentials,
             insecure,

+ 13 - 28
exo/orchestration/standard_node.py

@@ -122,7 +122,6 @@ class StandardNode(Node):
       await self.inference_engine.ensure_shard(shard)
       self.buffered_token_output[request_id][0].append(token.item())
       is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
       if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
       asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
       forward = token.reshape(1, -1)
@@ -211,13 +210,14 @@ class StandardNode(Node):
   ):
     shard = self.get_current_shard(base_shard)
     if shard.is_first_layer():
-      resp = await self.process_example(shard, example, target, length, train, request_id)
+      loss = await self.process_example(shard, example, target, length, train, request_id)
+      return loss
     else:
       if request_id is None:
         request_id = str(uuid.uuid4())
       self.outstanding_requests[request_id] = "waiting"
-      resp = await self.forward_example(shard, example, target, length, train, request_id, 0) 
-    return resp
+      loss = await self.forward_example(shard, example, target, length, train, request_id, 0) 
+    return loss
 
   async def coordinate_save(
     self,
@@ -279,7 +279,6 @@ class StandardNode(Node):
           "shard": shard.to_dict(),
           "request_id": request_id,
           "elapsed_time_ns": elapsed_time_ns,
-          "result_size": resp.size if resp is not None else 0,
         }),
       )
     )
@@ -308,11 +307,14 @@ class StandardNode(Node):
           self.outstanding_requests[request_id] = "preprocessing"
           step = await self.inference_engine.infer_tensor(request_id, shard, example)
           self.outstanding_requests[request_id] = "waiting"
-          backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
+          loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
           self.outstanding_requests[request_id] = "training"
-          loss, grad = await self.inference_engine.train(request_id, shard, example, backgrad, length, loss="back_gradient")
+          partial_loss, grad = await self.inference_engine.train(request_id, shard, example, backgrad, length, loss="back_gradient")
         self.outstanding_requests.pop(request_id)
-        return loss.reshape(1, -1) if shard.is_first_layer() else grad
+        if shard.is_first_layer():
+          return loss
+        else:
+          return loss, grad
       else:
         if shard.is_last_layer():
           self.outstanding_requests[request_id] = "evaluating"
@@ -323,7 +325,7 @@ class StandardNode(Node):
           self.outstanding_requests[request_id] = "waiting"
           loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
         self.outstanding_requests.pop(request_id)
-        return loss.reshape(1, -1)
+        return loss
     except Exception as e:
       self.outstanding_requests.pop(request_id)
       print(f"Error processing example for shard {shard}: {e}")
@@ -413,25 +415,8 @@ class StandardNode(Node):
     if not target_peer:
       raise ValueError(f"peer for {target_index} not found")
     if DEBUG >= 1: print(f"sending example to {target_peer.id()}: {step} => {target} ({length})")
-    ret = await target_peer.send_example(target_shard, step, target, length, request_id=request_id, train=train)
-    return ret
-
-  async def forward_loss(
-    self,
-    base_shard: Shard,
-    loss: np.ndarray,
-    request_id: str,
-    target_index: int,
-  ) -> None:
-    if DEBUG >= 1: print(f"target partition index: {target_index}")
-    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
-    target_shard = self.get_current_shard(base_shard, target_index)
-    if DEBUG >= 2: print(f"computed target from: {base_shard} {target_index}, {self.topology}. target shard: {target_shard}")
-    target_peer = next((p for p in self.peers if p.id() == target_id), None)
-    if not target_peer:
-      raise ValueError(f"peer for {target_index} not found")
-    if DEBUG >= 1: print(f"sending tensor to {target_peer.id()}: {loss}")
-    await target_peer.send_loss(target_shard, step, target, length, request_id=request_id)
+    resp = await target_peer.send_example(target_shard, step, target, length, request_id=request_id, train=train)
+    return resp
 
   async def forward_prompt(
     self,
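
Taken together, the orchestration change makes the loss a pass-through value: the last shard computes it, every intermediate shard returns it alongside its back-gradient, and the first shard hands the caller the actual loss instead of its own partial result. A self-contained toy sketch of that return-shape convention (all numbers are made up; this is not exo's actual API):

# Toy model of the new return shapes:
#   last shard   -> (loss, grad)  : computes the real loss, starts backprop
#   middle shard -> (loss, grad)  : passes the loss through, backprops grad
#   first shard  -> loss          : the requestor sees the actual loss
def process(shard_index, n_shards):
    if shard_index == n_shards - 1:
        return 0.42, [1.0]                  # real loss + first back-gradient
    loss, grad = process(shard_index + 1, n_shards)
    grad = [g * 0.5 for g in grad]          # this shard's backprop step
    return loss if shard_index == 0 else (loss, grad)

print(process(0, 3))  # 0.42 -- unchanged all the way back to the requestor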

+ 2 - 2
setup.py

@@ -35,8 +35,8 @@ extras_require = {
     "yapf==0.40.2",
   ],
   "apple_silicon": [
-    "mlx",
-    "mlx-lm",
+    "mlx==0.20.0",
+    "mlx-lm==0.19.3",
   ],
 }
 

Some files were not shown because too many files changed in this diff