Browse Source

Made models save properly

Nel Nibcord 8 months ago
parent
commit
0d3abfca95
2 changed files with 10 additions and 7 deletions
  1. +1 -2
      exo/main.py
  2. +9 -5
      exo/orchestration/standard_node.py

+ 1 - 2
exo/main.py

@@ -267,8 +267,7 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
     loss, tokens = await run_iter(node, shard, True, train, batch_size)
     print(f"epoch {epoch + 1}/{iters}\t| loss: {loss}, tokens: {tokens}")
     if save_interval > 0 and epoch > 0 and (epoch % save_interval) == 0 and checkpoint_dir is not None:
-      node.coordinate_save(checkpoint_dir, shard, epoch)
-      print("Hold up let's save a checkpoint")
+      await node.coordinate_save(shard, epoch, checkpoint_dir)
       await hold_outstanding(node)
   await hold_outstanding(node)
 

+ 9 - 5
exo/orchestration/standard_node.py

@@ -227,16 +227,20 @@ class StandardNode(Node):
   ):
     shard = self.get_current_shard(base_shard)
     model = shard.model_id
+    sid = shard.__hash__()
+    path = f"{destination}/{model}/{sid}-{iteration}.safetensors"
     self.outstanding_requests[f"{sid}::{iteration}"] = "Checking"
     if model not in self.checkpoints:
-      self.checkpoints[model_id] = {}
-    sid = shard.__hash__()
+      self.checkpoints[model] = {}
     if sid not in self.checkpoints[model]:
       self.checkpoints[model][sid] = []
-    if len(self.checkpoints[model][sid]) and self.checkpoints[model][sid][-1] < iteration:
+    if len(self.checkpoints[model][sid]) < 1 or self.checkpoints[model][sid][-1] < iteration:
+      print(f"Saving checkpoint to {path}")
       self.outstanding_requests[f"{sid}::{iteration}"] = "Saving"
-      await self.inference_engine.save_checkpoint(f"{destination}/{model}/{hash}-{iteration}")
-      self.checkpoints[model][sid] = sorted(self.checkpoints.model.sid + [iteration])
+      import os
+      os.makedirs("/".join(path.split("/")[:-1]), exist_ok=True)
+      await self.inference_engine.save_checkpoint(shard, path)
+      self.checkpoints[model][sid] = sorted(self.checkpoints[model][sid] + [iteration])
     self.outstanding_requests.pop(f"{sid}::{iteration}")
 
   async def process_example(