Browse Source

Fixed up the ops so that batches work

Nel Nibcord 8 months ago
parent
commit
38e368f00b

+ 2 - 2
exo/inference/mlx/losses.py

@@ -19,9 +19,9 @@ def back_gradient_loss(model, inputs, gradients, lengths):
   grad = gradients.astype(mx.float32)
 
   # Mask padding tokens
-  length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+  length_mask = mx.repeat(mx.arange(inputs.shape[1])[None, :] < lengths[:, None], out.shape[-1]).reshape(out.shape)
 
-  masked_sum = (out * length_mask.T).sum(axis=1)
+  masked_sum = (out * length_mask).sum(axis=1)
   gradient_lens = mx.abs(grad * masked_sum)
   loss = gradient_lens.sum() / length_mask.sum()
 #  print(f"|    {inputs=}\n"

+ 1 - 2
exo/inference/mlx/sharded_inference_engine.py

@@ -134,8 +134,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
     #print(layers[0])
 
-    return np.array(score).reshape(inputs.shape[0], -1), np.array(layers[0]['input_layernorm']).reshape(inputs.shape[0], -1)
-    return 0, 0
+    return np.array(score).reshape(1, -1), np.array(layers[0]['input_layernorm'])
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:

+ 0 - 1
exo/main.py

@@ -225,7 +225,6 @@ async def hold_outstanding(node: Node):
     else:
       return      
 
-
 async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
   losses = []
   tokens = []

+ 2 - 2
exo/orchestration/standard_node.py

@@ -292,7 +292,7 @@ class StandardNode(Node):
           self.outstanding_requests[request_id] = "training"
           loss, grad = await self.inference_engine.train(request_id, shard, example, backgrad, length, loss="back_gradient")
         self.outstanding_requests.pop(request_id)
-        return loss.reshape(example.shape[0], -1) if shard.is_first_layer() else grad
+        return loss.reshape(1, -1) if shard.is_first_layer() else grad
       else:
         if shard.is_last_layer():
           self.outstanding_requests[request_id] = "evaluating"
@@ -303,7 +303,7 @@ class StandardNode(Node):
           self.outstanding_requests[request_id] = "waiting"
           loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
         self.outstanding_requests.pop(request_id)
-        return loss.reshape(example.shape[0], -1)
+        return loss.reshape(1, -1)
     except Exception as e:
       self.outstanding_requests.pop(request_id)
       print(f"Error processing example for shard {shard}: {e}")

+ 2 - 2
setup.py

@@ -35,8 +35,8 @@ extras_require = {
     "yapf==0.40.2",
   ],
   "apple_silicon": [
-    "mlx==0.20.0",
-    "mlx-lm==0.19.3",
+    "mlx",
+    "mlx-lm",
   ],
 }