|
@@ -227,16 +227,20 @@ class StandardNode(Node):
|
|
|
):
|
|
|
shard = self.get_current_shard(base_shard)
|
|
|
model = shard.model_id
|
|
|
+ sid = shard.__hash__()
|
|
|
+ path = f"{destination}/{model}/{sid}-{iteration}.safetensors"
|
|
|
self.outstanding_requests[f"{sid}::{iteration}"] = "Checking"
|
|
|
if model not in self.checkpoints:
|
|
|
- self.checkpoints[model_id] = {}
|
|
|
- sid = shard.__hash__()
|
|
|
+ self.checkpoints[model] = {}
|
|
|
if sid not in self.checkpoints[model]:
|
|
|
self.checkpoints[model][sid] = []
|
|
|
- if len(self.checkpoints[model][sid]) and self.checkpoints[model][sid][-1] < iteration:
|
|
|
+ if len(self.checkpoints[model][sid]) < 1 or self.checkpoints[model][sid][-1] < iteration:
|
|
|
+ print(f"Saving checkpoint to {path}")
|
|
|
self.outstanding_requests[f"{sid}::{iteration}"] = "Saving"
|
|
|
- await self.inference_engine.save_checkpoint(f"{destination}/{model}/{hash}-{iteration}")
|
|
|
- self.checkpoints[model][sid] = sorted(self.checkpoints.model.sid + [iteration])
|
|
|
+ import os
|
|
|
+ os.makedirs("/".join(path.split("/")[:-1]), exist_ok=True)
|
|
|
+ await self.inference_engine.save_checkpoint(shard, path)
|
|
|
+ self.checkpoints[model][sid] = sorted(self.checkpoints[model][sid] + [iteration])
|
|
|
self.outstanding_requests.pop(f"{sid}::{iteration}")
|
|
|
|
|
|
async def process_example(
|