@@ -11,6 +11,7 @@ import time
 import traceback
 import uuid
 import numpy as np
+from functools import partial
 from tqdm import tqdm
 from tqdm.asyncio import tqdm_asyncio
 from exo.train.dataset import load_dataset, iterate_batches
@@ -215,43 +216,41 @@ def clean_path(path):
   path = path.strip('Optional("').rstrip('")')
   return os.path.expanduser(path)
 
 
-async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataset, batch_size, num_batches=-1):
+async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, num_batches=-1):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
   if not shard:
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     return
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
-  all_losses = []
-  ntokens = 0
-
+  train, val, test = dataloader(tokenizer)
+  dataset = test
   print(f"Evaluating {len(dataset)} examples with batch_size {batch_size}")
   losses = []
   tokens = []
-  for batch in tqdm(iterate_batches(dataset, tokenizer, batch_size), total=len(dataset) // batch_size):
+  for batch in tqdm(iterate_batches(test, batch_size), total=len(dataset) // batch_size):
     _, _, lengths = batch
     losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch)))
     tokens.append(np.sum(lengths))
   total_loss = np.sum(losses) / np.sum(tokens)
   print(f"total | loss: {total_loss}, tokens: {np.sum(tokens)}")
 
 
-async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataset, batch_size, iters):
+async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, iters):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
   if not shard:
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     return
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
-  all_losses = []
-  ntokens = 0
-
-  print(f"Training on {len(dataset)} examples with batch_size {batch_size}")
-  losses = []
-  tokens = []
-  for batch in tqdm(iterate_batches(dataset, tokenizer, batch_size), total=len(dataset) // batch_size):
-    _, _, lengths = batch
-    losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch)))
-    tokens.append(np.sum(lengths))
+  train, val, test = dataloader(tokenizer)
+  print(f"Training on {len(train)} examples with batch_size {batch_size}")
+  for epoch in range(iters):
+    losses = []
+    tokens = []
+    for batch in tqdm(iterate_batches(train, batch_size), total=len(train) // batch_size):
+      _, _, lengths = batch
+      losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch)))
+      tokens.append(np.sum(lengths))
   total_loss = np.sum(losses) / np.sum(tokens)
   print(f"total | loss: {total_loss}, tokens: {np.sum(tokens)}")
@@ -301,19 +300,18 @@ async def main():
       return
     await run_model_cli(node, inference_engine, model_name, args.prompt)
   elif args.command == "eval" or args.command == 'train':
-    data_path = args.data
-    train, val, test = load_dataset(data_path)
     model_name = args.model_name
+    dataloader = lambda tok: load_dataset(args.data, preprocess=lambda i: tok.encode(i["text"]))
     if args.command == 'eval':
       if not model_name:
         print("Error: Much like a human, I can't evaluate anything without a model")
         return
-      await eval_model_cli(node, inference_engine, model_name, test, args.batch_size)
+      await eval_model_cli(node, inference_engine, model_name, dataloader, args.batch_size)
     else:
       if not model_name:
         print("Error: This train ain't leaving the station without a model")
         return
-      await train_model_cli(node, inference_engine, model_name, test, args.batch_size, args.iters)
+      await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters)
 
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
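Since the tokenizer is only resolved inside `eval_model_cli`/`train_model_cli` once the model's shard is known, `main()` can no longer load the dataset eagerly; it instead passes a closure that performs the load once a tokenizer exists. A self-contained illustration of the same deferred-tokenization pattern (the `load_dataset(path, preprocess=...)` signature mirrors the diff; the tokenizer and loader stubs are hypothetical):

```python
class ToyTokenizer:
  def encode(self, text):
    return [ord(c) for c in text]  # stand-in for a real tokenizer

def load_dataset_stub(path, preprocess):
  rows = [{"text": "hello"}, {"text": "world"}]  # pretend file contents
  encoded = [preprocess(r) for r in rows]
  return encoded, encoded, encoded  # (train, val, test) splits

# The closure captures the path now but tokenizes later, exactly like
# the `dataloader` lambda above.
dataloader = lambda tok: load_dataset_stub("data/", preprocess=lambda i: tok.encode(i["text"]))

# Later, once the model's tokenizer has been resolved:
train, val, test = dataloader(ToyTokenizer())
print(train[0][:3])  # [104, 101, 108]
```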