
Naive network-propagated loss implementation on MLX

Nel Nibcord, 7 months ago
commit 75c8650f1f

+ 11 - 0
exo/inference/mlx/losses.py

@@ -12,3 +12,14 @@ def length_masked_ce_loss(model, inputs, targets, lengths):
   loss = ce.sum() / length_mask.sum()
   return loss
 
+# Naive intermediate-layer loss: the targets are replaced by the gradients propagated back from the next shard, and the loss is just the shard's output multiplied by those gradients. This is naive and may warrant further iteration, but it does the job for now.
+def back_gradient_loss(model, inputs, gradients, shard_proportion):
+  out = model(inputs)
+  logits = out[:, -1, :]
+  loss = (logits * gradients).mean()
+  return loss
+
+loss_fns = {
+  "back_gradient": back_gradient_loss,
+  "length_masked_ce": length_masked_ce_loss,
+}
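Why the surrogate works: for a loss of the form sum(output · g), the gradient with respect to the output is exactly g, so differentiating `back_gradient_loss` routes the upstream gradient straight into the local backward pass (the committed version takes `.mean()` over only the last position, which rescales g and ignores the other positions). A minimal sketch under toy assumptions, with `nn.Linear` stages standing in for two shards (names are illustrative, not exo's API):

```python
import mlx.core as mx
import mlx.nn as nn

# Two toy stages standing in for two pipeline shards (illustrative only).
stage1, stage2 = nn.Linear(4, 4), nn.Linear(4, 1)
x = mx.random.normal((2, 4))

# End-to-end: the gradient the full loss sends into stage 1's parameters.
_, true_grad = nn.value_and_grad(stage1, lambda m: stage2(m(x)).sum())(stage1)

# Split in two: stage 2 computes g = dL/dh at the activation boundary...
h = stage1(x)
g = mx.grad(lambda a: stage2(a).sum())(h)

# ...and stage 1 recovers the same parameter gradients from the surrogate
# loss (output * g).sum(), because d/dh of (h * g).sum() is exactly g.
_, surr_grad = nn.value_and_grad(stage1, lambda m: (m(x) * g).sum())(stage1)

print(mx.allclose(true_grad["weight"], surr_grad["weight"]))  # True
```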

+ 16 - 10
exo/inference/mlx/sharded_inference_engine.py

@@ -6,7 +6,7 @@ import mlx.optimizers as optim
 from ..inference_engine import InferenceEngine
 from .stateful_model import StatefulModel
 from .sharded_utils import load_shard, get_image_from_str
-from .losses import length_masked_ce_loss
+from .losses import loss_fns
 from ..shard import Shard
 from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
@@ -64,33 +64,39 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     #print(f"infer_tensor out -> {output_data}")
     return output_data
  
-  async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
+  async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
     await self.ensure_shard(shard)
-    await self.ensure_session('loss', lambda: loss)
+    await self.ensure_session('loss', lambda: loss_fns[loss])
     await self.ensure_session('task', lambda: ('eval', self.model.eval()))
     #print(f"evaluate in <- {inputs}")
     x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
-    y = mx.array(targets).astype(mx.int64)
+    y = mx.array(targets)
     l = mx.array(lengths)
     score = await asyncio.get_running_loop().run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
     #print(f"evaluate out -> {score}")
     return np.array(score)
+
+  def update_model(self, grad, lval):
+    # Runs on the executor thread; mx.eval forces the lazy optimizer step.
+    self.session['opt'].update(self.model, grad)
+    mx.eval(self.model.parameters(), self.session['opt'].state, lval)
  
-  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss, opt=optim.Adam, lr=1e-5):
+  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.Adam, lr=1e-5):
     await self.ensure_shard(shard)
-    await self.ensure_session('loss', lambda: loss)
+    await self.ensure_session('loss', lambda: loss_fns[loss])
     await self.ensure_session('LVaG', lambda: nn.value_and_grad(self.model, self.session['loss']))
     await self.ensure_session('opt', lambda: opt(lr))
     await self.ensure_session('task', lambda: ('train', self.model.train()))
 
     x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
-    y = mx.array(targets).astype(mx.int64)
+    y = mx.array(targets)
     l = mx.array(lengths)
     loop = asyncio.get_running_loop()
-    loss, grad = await loop.run_in_executor(self.executor, self.session['LVaG'], self.model, x, y, l)
-    await loop.run_in_executor(self.executor, lambda: self.session['opt'].update(self.model, grad))
+    score, grad = await loop.run_in_executor(self.executor, self.session['LVaG'], self.model, x, y, l)
+    await loop.run_in_executor(self.executor, self.update_model, grad, score)
+    layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in grad['model']['model']['layers'] if l]
 
-    return np.array(loss), np.array(grad)
+    return np.array(score).reshape(inputs.shape[0], -1), np.array(layers[0]['input_layernorm']).reshape(inputs.shape[0], -1)
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:

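One detail worth calling out in `update_model`: MLX arrays are lazy, so `opt.update` only records the new parameter graph, and nothing is actually computed until `mx.eval` forces the parameters, optimizer state, and loss value together. A self-contained sketch of the same pattern (toy model and shapes are illustrative):

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(8, 8)               # stand-in for the shard's model
opt = optim.Adam(learning_rate=1e-5)  # matches train()'s opt=optim.Adam, lr=1e-5

# Same shape as the engine's 'LVaG' session entry: loss value and grads together.
loss_and_grad = nn.value_and_grad(model, lambda m, x: (m(x) ** 2).mean())
lval, grad = loss_and_grad(model, mx.random.normal((4, 8)))

# The update_model pattern: record the optimizer step, then force evaluation
# of the new parameters, the optimizer state, and the loss in one mx.eval.
opt.update(model, grad)
mx.eval(model.parameters(), opt.state, lval)
```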
+ 2 - 2
exo/main.py

@@ -225,7 +225,7 @@ async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_na
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
   train, val, test = dataloader(lambda i: tokenizer.encode(i))
   dataset = test
-  print(f"Evaluating {len(dataset)} examples with batch_size {batch_size}")
+  print(f"Evaluating {len(test)} examples with batch_size {batch_size}")
   losses = []
   tokens = []
   for batch in tqdm(iterate_batches(test, batch_size), total=len(dataset) // batch_size):
@@ -243,7 +243,7 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
     return
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
   train, val, test = dataloader(lambda i: tokenizer.encode(i))
-  print(f"Training on {len(val)} examples with batch_size {batch_size}")
+  print(f"Training on {len(train)} examples with batch_size {batch_size}")
   for epoch in range(iters):
     losses = []
     tokens = []

+ 0 - 15
exo/networking/grpc/grpc_server.py

@@ -87,21 +87,6 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {example=} {target=} {length=} {request_id=} result: {result}")
     tensor_data = result.tobytes()
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
-
-  async def SendLoss(self, request, context):
-    shard = Shard(
-      model_id=request.shard.model_id,
-      start_layer=request.shard.start_layer,
-      end_layer=request.shard.end_layer,
-      n_layers=request.shard.n_layers,
-    )
-    loss = np.frombuffer(request.loss.tensor_data, dtype=np.dtype(request.loss.dtype)).reshape(request.loss.shape)
-    request_id = request.request_id
-
-    if shard.is_first_layer():
-      asyncself.node.backward_loss(shard, loss, request_id)
-    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {example=} {target=} {length=} {request_id=} result: {result}")
-    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
     
   async def CollectTopology(self, request, context):
     max_depth = request.max_depth

+ 5 - 1
exo/networking/grpc/node_service.proto

@@ -5,7 +5,6 @@ package node_service;
 service NodeService {
   rpc SendPrompt (PromptRequest) returns (Tensor) {}
   rpc SendTensor (TensorRequest) returns (Tensor) {}
-  rpc SendLoss (TensorRequest) returns (Empty) {}
   rpc SendExample (ExampleRequest) returns (Tensor) {}
   rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
@@ -41,6 +40,11 @@ message ExampleRequest {
   bool train = 5;
   optional string request_id = 6;
 }
+
+message LossRequest {
+  Tensor loss = 1;
+  Tensor grads = 2;
+}
  
 message GetInferenceResultRequest {
   string request_id = 1;
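Nothing references the new `LossRequest` message yet in this commit (the `SendLoss` RPC it might have served is removed above). If a caller were to populate it, the natural packing is the same one `grpc_server.py` already uses for `Tensor` responses; a hedged sketch, where the helper name and shapes are hypothetical:

```python
import numpy as np
from exo.networking.grpc import node_service_pb2

def to_proto_tensor(arr: np.ndarray) -> node_service_pb2.Tensor:
    # Identical packing to the Tensor responses built in grpc_server.py.
    return node_service_pb2.Tensor(tensor_data=arr.tobytes(), shape=arr.shape, dtype=str(arr.dtype))

loss = np.zeros((1, 1), dtype=np.float32)      # per-example loss, illustrative shape
grads = np.zeros((1, 4096), dtype=np.float32)  # boundary gradient, illustrative shape
request = node_service_pb2.LossRequest(loss=to_proto_tensor(loss), grads=to_proto_tensor(grads))
```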

File diff suppressed because it is too large
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 0 - 43
exo/networking/grpc/node_service_pb2_grpc.py

@@ -44,11 +44,6 @@ class NodeServiceStub(object):
                 request_serializer=node__service__pb2.TensorRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Tensor.FromString,
                 _registered_method=True)
-        self.SendLoss = channel.unary_unary(
-                '/node_service.NodeService/SendLoss',
-                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
         self.SendExample = channel.unary_unary(
                 '/node_service.NodeService/SendExample',
                 request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
@@ -96,12 +91,6 @@ class NodeServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
-    def SendLoss(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
     def SendExample(self, request, context):
         """Missing associated documentation comment in .proto file."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -151,11 +140,6 @@ def add_NodeServiceServicer_to_server(servicer, server):
                     request_deserializer=node__service__pb2.TensorRequest.FromString,
                     response_serializer=node__service__pb2.Tensor.SerializeToString,
             ),
-            'SendLoss': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendLoss,
-                    request_deserializer=node__service__pb2.TensorRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
             'SendExample': grpc.unary_unary_rpc_method_handler(
                     servicer.SendExample,
                     request_deserializer=node__service__pb2.ExampleRequest.FromString,
@@ -251,33 +235,6 @@ class NodeService(object):
             metadata,
             _registered_method=True)
 
-    @staticmethod
-    def SendLoss(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendLoss',
-            node__service__pb2.TensorRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
     @staticmethod
     def SendExample(request,
             target,

+ 12 - 12
exo/orchestration/standard_node.py

@@ -269,24 +269,24 @@ class StandardNode(Node):
     if request_id is None:
       request_id = str(uuid.uuid4())
     shard = self.get_current_shard(base_shard)
-
     if DEBUG >= 1: print(f"[{request_id}] process_example: {example.shape=}")
     try:
-      if shard.is_last_layer():
-        if train:
+      target = target.astype(int)
+      if train:
+        if shard.is_last_layer():
           loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
-          return loss.reshape(example.shape[0], -1) if shard.is_first_layer() else grad
         else:
-          loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
-          return loss.reshape(example.shape[0], -1)
+          step = await self.inference_engine.infer_tensor(request_id, shard, example)
+          backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
+          loss, grad = await self.inference_engine.train(request_id, shard, example, backgrad, length, loss="back_gradient")
+        return loss.reshape(example.shape[0], -1) if shard.is_first_layer() else grad
       else:
-        step = await self.inference_engine.infer_tensor(request_id, shard, example)
-        result = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
-        if train:
-          forward = self.get_current_shard(self.get_partition_index(offset = 1))
-          return result
+        if shard.is_last_layer():
+          loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
         else:
-          return result.reshape(example.shape[0], -1)
+          step = await self.inference_engine.infer_tensor(request_id, shard, example)
+          loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
+        return loss.reshape(example.shape[0], -1)
     except Exception as e:
       print(f"Error processing example for shard {shard}: {e}")
       traceback.print_exc()
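Read as a whole, the training path in `process_example` is now a recursion down the pipeline: each non-terminal shard runs its forward pass, hands activations to the next shard, receives the boundary gradient back, and trains against it with the `back_gradient` surrogate. A condensed paraphrase of the branch structure above (same method names as the diff; error handling elided):

```python
async def process_example_train(self, shard, example, target, length, request_id):
    if shard.is_last_layer():
        # Terminal shard: real masked cross-entropy against token targets.
        loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
    else:
        # Forward this shard, recurse into the rest of the pipeline, and get
        # back the gradient at the activation boundary from the next shard...
        step = await self.inference_engine.infer_tensor(request_id, shard, example)
        backgrad = await self.forward_example(shard, step, target, length, True, request_id,
                                              self.get_partition_index(offset=1))
        # ...then train locally against that gradient via the surrogate loss.
        loss, grad = await self.inference_engine.train(request_id, shard, example, backgrad,
                                                       length, loss="back_gradient")
    # The first shard reports the loss; every other shard hands its gradient upstream.
    return loss.reshape(example.shape[0], -1) if shard.is_first_layer() else grad
```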

Some files were not shown because too many files changed in this diff