浏览代码

Naive network-propagated loss implementation on MLX

Nel Nibcord 7 月之前
父节点
当前提交
75c8650f1f

+ 11 - 0
exo/inference/mlx/losses.py

@@ -12,3 +12,14 @@ def length_masked_ce_loss(model, inputs, targets, lengths):
   loss = ce.sum() / length_mask.sum()
   return loss
 
# Naive intermediate-layer loss: instead of token targets we receive the
# gradient flowing back from the next shard, and the "loss" is simply the mean
# of its elementwise product with this shard's last-position output. Crude,
# but enough to propagate a training signal across shards for now.
def back_gradient_loss(model, inputs, gradients, shard_proportion):
  """Mean of the last-step model output weighted by back-propagated gradients.

  `shard_proportion` is accepted for signature parity with other loss
  functions but is currently unused.
  """
  last_step = model(inputs)[:, -1, :]
  return (last_step * gradients).mean()
+
# Registry mapping the loss name sent over the wire to its implementation;
# keep keys in sync with the `loss` string parameters in the inference engines.
loss_fns = {
  "length_masked_ce": length_masked_ce_loss,
  "back_gradient": back_gradient_loss,
}

+ 16 - 10
exo/inference/mlx/sharded_inference_engine.py

@@ -6,7 +6,7 @@ import mlx.optimizers as optim
 from ..inference_engine import InferenceEngine
 from .stateful_model import StatefulModel
 from .sharded_utils import load_shard, get_image_from_str
-from .losses import length_masked_ce_loss
+from .losses import loss_fns 
 from ..shard import Shard
 from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
@@ -64,33 +64,39 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     #print(f"infer_tensor out -> {output_data}")
     return output_data
   
-  async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
+  async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
     await self.ensure_shard(shard)
-    await self.ensure_session('loss', lambda: loss)
+    await self.ensure_session('loss', lambda: loss_fns[loss])
     await self.ensure_session('task', lambda: ('eval', self.model.eval()))
     #print(f"evaluate in <- {inputs}")
     x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
-    y = mx.array(targets).astype(mx.int64)
+    y = mx.array(targets)
     l = mx.array(lengths)
     score = await asyncio.get_running_loop().run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
     #print(f"evaluate out -> {score}")
     return np.array(score)
+
+  async def update_model(self, grad, lval):
+    await self.ensure_shard(shard)
+    self.session['opt'].update(self.model, grad)
+    mx.eval(self.model.parameters(), self.session['opt'].state, lval)
   
-  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss, opt=optim.Adam, lr=1e-5):
+  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.Adam, lr=1e-5):
     await self.ensure_shard(shard)
-    await self.ensure_session('loss', lambda: loss)
+    await self.ensure_session('loss', lambda: loss_fns[loss])
     await self.ensure_session('LVaG', lambda: nn.value_and_grad(self.model, self.session['loss']))
     await self.ensure_session('opt', lambda: opt(lr))
     await self.ensure_session('task', lambda: ('train', self.model.train()))
 
     x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
-    y = mx.array(targets).astype(mx.int64)
+    y = mx.array(targets)
     l = mx.array(lengths)
     loop = asyncio.get_running_loop()
-    loss, grad = await loop.run_in_executor(self.executor, self.session['LVaG'], self.model, x, y, l)
-    await loop.run_in_executor(self.executor, lambda: self.session['opt'].update(self.model, grad))
+    score, grad = await loop.run_in_executor(self.executor, self.session['LVaG'], self.model, x, y, l)
+    loop.run_in_executor(self.executor, self.update_model, grad, score)
+    layers = [{k: v["weight"].shape for k,v in l.items() if 'weight' in v} for l in grad['model']['model']['layers'] if l]
 
-    return np.array(loss), np.array(grad)
+    return np.array(score).reshape(inputs.shape[0], -1), np.array(layers[0]['input_layernorm']).reshape(inputs.shape[0], -1)
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:

+ 2 - 2
exo/main.py

@@ -225,7 +225,7 @@ async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_na
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
   train, val, test = dataloader(lambda i: tokenizer.encode(i))
   dataset = test
-  print(f"Evaluating {len(dataset)} examples with batch_size {batch_size}")
+  print(f"Evaluating {len(test)} examples with batch_size {batch_size}")
   losses = []
   tokens = []
   for batch in tqdm(iterate_batches(test, batch_size), total=len(dataset) // batch_size):
@@ -243,7 +243,7 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
     return
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
   train, val, test = dataloader(lambda i: tokenizer.encode(i))
-  print(f"Training on {len(val)} examples with batch_size {batch_size}")
+  print(f"Training on {len(train)} examples with batch_size {batch_size}")
   for epoch in range(iters):
     losses = []
     tokens = []

+ 0 - 15
exo/networking/grpc/grpc_server.py

@@ -87,21 +87,6 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {example=} {target=} {length=} {request_id=} result: {result}")
     tensor_data = result.tobytes()
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
-
-  async def SendLoss(self, request, context):
-    shard = Shard(
-      model_id=request.shard.model_id,
-      start_layer=request.shard.start_layer,
-      end_layer=request.shard.end_layer,
-      n_layers=request.shard.n_layers,
-    )
-    loss = np.frombuffer(request.loss.tensor_data, dtype=np.dtype(request.loss.dtype)).reshape(request.loss.shape)
-    request_id = request.request_id
-
-    if shard.is_first_layer():
-      asyncself.node.backward_loss(shard, loss, request_id)
-    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {example=} {target=} {length=} {request_id=} result: {result}")
-    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
     
   async def CollectTopology(self, request, context):
     max_depth = request.max_depth

+ 5 - 1
exo/networking/grpc/node_service.proto

@@ -5,7 +5,6 @@ package node_service;
 service NodeService {
   rpc SendPrompt (PromptRequest) returns (Tensor) {}
   rpc SendTensor (TensorRequest) returns (Tensor) {}
-  rpc SendLoss (TensorRequest) returns (Empty) {}
   rpc SendExample (ExampleRequest) returns (Tensor) {}
   rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
@@ -41,6 +40,11 @@ message ExampleRequest {
   bool train = 5;
   optional string request_id = 6;
 }
+
// Payload pairing a training loss with the gradients produced alongside it.
// NOTE(review): no RPC in NodeService references LossRequest after SendLoss
// was removed in this same change — confirm this message is still needed.
message LossRequest {
  Tensor loss = 1;
  Tensor grads = 2;
}
   
 message GetInferenceResultRequest {
   string request_id = 1;

文件差异内容过多而无法显示
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 0 - 43
exo/networking/grpc/node_service_pb2_grpc.py

@@ -44,11 +44,6 @@ class NodeServiceStub(object):
                 request_serializer=node__service__pb2.TensorRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Tensor.FromString,
                 _registered_method=True)
-        self.SendLoss = channel.unary_unary(
-                '/node_service.NodeService/SendLoss',
-                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
         self.SendExample = channel.unary_unary(
                 '/node_service.NodeService/SendExample',
                 request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
@@ -96,12 +91,6 @@ class NodeServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
-    def SendLoss(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
     def SendExample(self, request, context):
         """Missing associated documentation comment in .proto file."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -151,11 +140,6 @@ def add_NodeServiceServicer_to_server(servicer, server):
                     request_deserializer=node__service__pb2.TensorRequest.FromString,
                     response_serializer=node__service__pb2.Tensor.SerializeToString,
             ),
-            'SendLoss': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendLoss,
-                    request_deserializer=node__service__pb2.TensorRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
             'SendExample': grpc.unary_unary_rpc_method_handler(
                     servicer.SendExample,
                     request_deserializer=node__service__pb2.ExampleRequest.FromString,
@@ -251,33 +235,6 @@ class NodeService(object):
             metadata,
             _registered_method=True)
 
-    @staticmethod
-    def SendLoss(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendLoss',
-            node__service__pb2.TensorRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
     @staticmethod
     def SendExample(request,
             target,

+ 12 - 12
exo/orchestration/standard_node.py

@@ -269,24 +269,24 @@ class StandardNode(Node):
     if request_id is None:
       request_id = str(uuid.uuid4())
     shard = self.get_current_shard(base_shard)
-
     if DEBUG >= 1: print(f"[{request_id}] process_example: {example.shape=}")
     try:
-      if shard.is_last_layer():
-        if train:
+      target = target.astype(int)
+      if train:
+        if shard.is_last_layer():
           loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
-          return loss.reshape(example.shape[0], -1) if shard.is_first_layer() else grad
         else:
-          loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
-          return loss.reshape(example.shape[0], -1)
+          step = await self.inference_engine.infer_tensor(request_id, shard, example)
+          backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
+          loss, grad = await self.inference_engine.train(request_id, shard, example, backgrad, length, loss="back_gradient")
+        return loss.reshape(example.shape[0], -1) if shard.is_first_layer() else grad
       else:
-        step = await self.inference_engine.infer_tensor(request_id, shard, example)
-        result = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
-        if train:
-          forward = self.get_current_shard(self.get_partition_index(offset = 1))
-          return result
+        if shard.is_last_layer():
+          loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
         else:
-          return result.reshape(example.shape[0], -1)
+          step = await self.inference_engine.infer_tensor(request_id, shard, example)
+          loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
+        return loss.reshape(example.shape[0], -1)
     except Exception as e:
       print(f"Error processing example for shard {shard}: {e}")
       traceback.print_exc()

部分文件因为文件数量过多而无法显示