
Basic model saving

Nel Nibcord, 8 months ago
commit 9eadee310b

3 changed files with 29 additions and 7 deletions
  1. exo/inference/mlx/sharded_inference_engine.py (+2 -2)
  2. exo/main.py (+6 -4)
  3. exo/orchestration/standard_node.py (+21 -1)

+ 2 - 2
exo/inference/mlx/sharded_inference_engine.py

@@ -70,11 +70,11 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
     return tokens
 
-  async def save_checkpoint(self, path: str):
+  async def save_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
     await asyncio.get_running_loop().run_in_executor(self.executor, self.model.save_weights, path)
 
-  async def load_checkpoint(self, path: str):
+  async def load_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
     await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
     
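With this change both checkpoint methods receive the `Shard` explicitly, so the `ensure_shard(shard)` calls in their bodies now resolve (before this commit, `shard` was undefined there). A minimal usage sketch, assuming exo's `Shard` dataclass; the model id and layer range below are illustrative, not from this commit:

```python
# Sketch: driving the shard-aware checkpoint API. Field values are hypothetical.
from exo.inference.shard import Shard

async def checkpoint_roundtrip(engine, path: str):
  # ensure_shard() runs inside both calls, so the right weights are in memory first
  shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16)
  await engine.save_checkpoint(shard, path)
  await engine.load_checkpoint(shard, path)
```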

+ 6 - 4
exo/main.py

@@ -44,7 +44,8 @@ parser.add_argument("--iters", type=int, default=100, help="Training iterations"
 parser.add_argument("--save-every", type=int, default=5, help="Save the model every N iterations.")
 parser.add_argument("--data", type=str, default="exo/train/data/lora", help="Directory where training data lives")
 parser.add_argument("--batch-size", type=int, default=1, help="Minibatch size.")
-parser.add_argument("--checkpoint-dir", type=str, default="checkpoints", help="Directory from which to load and save checkpoints")
+parser.add_argument("--resume-checkpoint", type=str, default=None, help="Path to a custom checkpoint to load")
+parser.add_argument("--save-checkpoint-dir", type=str, default="checkpoints", help="Path to a folder where checkpoints are stored")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
@@ -264,8 +265,9 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
     await asyncio.sleep(1)
   for epoch in range(iters):
     loss, tokens = await run_iter(node, shard, True, train, batch_size)
-    print(f"epoch {epoch + 1}/{iters}\t| {loss=}, {tokens=}")
-    if save_interval > 0 and epoch > 0 and (epoch % save_interval) == 0:
+    print(f"epoch {epoch + 1}/{iters}\t| loss: {loss}, tokens: {tokens}")
+    if save_interval > 0 and epoch > 0 and (epoch % save_interval) == 0 and checkpoint_dir is not None:
+      await node.coordinate_save(shard, epoch, checkpoint_dir)
       print("Hold up let's save a checkpoint")
       await hold_outstanding(node)
   await hold_outstanding(node)
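As a side note on the guard above, a standalone check of which loop indices trigger a save (the iteration counts here are assumed, not from this commit):

```python
# Which loop indices pass the save guard, assuming iters=20 and save_interval=5.
save_interval, iters = 5, 20
saves = [e for e in range(iters) if save_interval > 0 and e > 0 and e % save_interval == 0]
print(saves)  # [5, 10, 15] -- shown to the user as epochs 6/20, 11/20, 16/20
```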
@@ -329,7 +331,7 @@ async def main():
       if not model_name:
         print("Error: This train ain't leaving the station without a model")
         return
-      await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every)
+      await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
     
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
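Note that `--resume-checkpoint` is added here without a corresponding load anywhere in this diff; only `--save-checkpoint-dir` is threaded through to `train_model_cli`. A standalone mirror of the two new flags, just to pin down their defaults:

```python
import argparse

# Mirror of the two arguments added above (same names and defaults).
parser = argparse.ArgumentParser()
parser.add_argument("--resume-checkpoint", type=str, default=None, help="Path to a custom checkpoint to load")
parser.add_argument("--save-checkpoint-dir", type=str, default="checkpoints", help="Path to a folder where checkpoints are stored")

args = parser.parse_args([])
assert args.resume_checkpoint is None
assert args.save_checkpoint_dir == "checkpoints"
```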

+ 21 - 1
exo/orchestration/standard_node.py

@@ -43,6 +43,7 @@ class StandardNode(Node):
     self.buffered_logits: Dict[str, List[np.ndarray]] = {}
     self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
     self.buffered_partials: Dict[str, List[np.ndarray]] = {}
+    self.checkpoints: Dict[str, Dict[int, List[int]]] = {}
     
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
@@ -217,7 +218,26 @@ class StandardNode(Node):
       self.outstanding_requests[request_id] = "waiting"
       resp = await self.forward_example(shard, example, target, length, train, request_id, 0) 
     return resp
-    
+
+  async def coordinate_save(
+    self,
+    base_shard: Shard,
+    iteration: int,
+    destination: str,
+  ):
+    shard = self.get_current_shard(base_shard)
+    model = shard.model_id
+    sid = hash(shard)
+    self.outstanding_requests[f"{sid}::{iteration}"] = "Checking"
+    if model not in self.checkpoints:
+      self.checkpoints[model] = {}
+    if sid not in self.checkpoints[model]:
+      self.checkpoints[model][sid] = []
+    if not self.checkpoints[model][sid] or self.checkpoints[model][sid][-1] < iteration:
+      self.outstanding_requests[f"{sid}::{iteration}"] = "Saving"
+      await self.inference_engine.save_checkpoint(shard, f"{destination}/{model}/{sid}-{iteration}")
+      self.checkpoints[model][sid] = sorted(self.checkpoints[model][sid] + [iteration])
+    self.outstanding_requests.pop(f"{sid}::{iteration}")
 
   async def process_example(
     self,
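The ledger that `coordinate_save` maintains (`checkpoints[model][shard hash] -> sorted saved iterations`) can be exercised in isolation. A minimal sketch of the same bookkeeping, with a made-up model id and shard hash:

```python
from typing import Dict, List

# Standalone mirror of coordinate_save()'s ledger logic: record a (model, shard)
# checkpoint only when the iteration is newer than the last one saved.
checkpoints: Dict[str, Dict[int, List[int]]] = {}

def should_save(model: str, sid: int, iteration: int) -> bool:
  history = checkpoints.setdefault(model, {}).setdefault(sid, [])
  if not history or history[-1] < iteration:
    checkpoints[model][sid] = sorted(history + [iteration])
    return True
  return False

assert should_save("llama-3.2-1b", 12345, 5) is True    # first checkpoint is taken
assert should_save("llama-3.2-1b", 12345, 5) is False   # same iteration is skipped
assert should_save("llama-3.2-1b", 12345, 10) is True   # newer iteration is taken
```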