|
@@ -44,7 +44,8 @@ parser.add_argument("--iters", type=int, default=100, help="Training iterations"
|
|
|
parser.add_argument("--save-every", type=int, default=5, help="Save the model every N iterations.")
|
|
|
parser.add_argument("--data", type=str, default="exo/train/data/lora", help="Directory where training data lives")
|
|
|
parser.add_argument("--batch-size", type=int, default=1, help="Minibatch size.")
|
|
|
-parser.add_argument("--checkpoint-dir", type=str, default="checkpoints", help="Directory from which to load and save checkpoints")
|
|
|
+parser.add_argument("--resume-checkpoint", type=str, default=None, help="Path to a custom checkpoint to load")
|
|
|
+parser.add_argument("--save-checkpoint-dir", type=str, default="checkpoints", help="Path to a folder where checkpoints are stored")
|
|
|
parser.add_argument("--node-id", type=str, default=None, help="Node ID")
|
|
|
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
|
|
|
parser.add_argument("--node-port", type=int, default=None, help="Node port")
|
|
@@ -264,8 +265,9 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
|
|
|
await asyncio.sleep(1)
|
|
|
for epoch in range(iters):
|
|
|
loss, tokens = await run_iter(node, shard, True, train, batch_size)
|
|
|
- print(f"epoch {epoch + 1}/{iters}\t| {loss=}, {tokens=}")
|
|
|
- if save_interval > 0 and epoch > 0 and (epoch % save_interval) == 0:
|
|
|
+ print(f"epoch {epoch + 1}/{iters}\t| loss: {loss}, tokens: {tokens}")
|
|
|
+ if save_interval > 0 and epoch > 0 and (epoch % save_interval) == 0 and checkpoint_dir is not None:
|
|
|
+ node.coordinate_save(checkpoint_dir, shard, epoch)
|
|
|
print("Hold up let's save a checkpoint")
|
|
|
await hold_outstanding(node)
|
|
|
await hold_outstanding(node)
|
|
@@ -329,7 +331,7 @@ async def main():
|
|
|
if not model_name:
|
|
|
print("Error: This train ain't leaving the station without a model")
|
|
|
return
|
|
|
- await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every)
|
|
|
+ await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
|
|
|
|
|
|
else:
|
|
|
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
|