
Generalizing some of the dataset biz while also creating uniform batches

Nel Nibcord 8 months ago
commit a6fd7a3430
2 changed files with 52 additions and 39 deletions
  1. exo/main.py (+18, -20)
  2. exo/train/dataset.py (+34, -19)

exo/main.py (+18, -20)

@@ -11,6 +11,7 @@ import time
 import traceback
 import uuid
 import numpy as np
+from functools import partial
 from tqdm import tqdm
 from tqdm.asyncio import tqdm_asyncio
 from exo.train.dataset import load_dataset, iterate_batches
@@ -215,43 +216,41 @@ def clean_path(path):
         path = path.strip('Optional("').rstrip('")')
     return os.path.expanduser(path)
 
-async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataset, batch_size, num_batches=-1):
+async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, num_batches=-1):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
   if not shard:
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     return
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
-  all_losses = []
-  ntokens = 0
-
+  train, val, test = dataloader(tokenizer)
+  dataset = test
   print(f"Evaluating {len(dataset)} examples with batch_size {batch_size}")
   losses = []
   tokens = []
-  for batch in tqdm(iterate_batches(dataset, tokenizer, batch_size), total=len(dataset) // batch_size):
+  for batch in tqdm(iterate_batches(dataset, batch_size), total=len(dataset) // batch_size):
     _, _, lengths = batch
     losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch)))
     tokens.append(np.sum(lengths))
   total_loss = np.sum(losses) / np.sum(tokens)
   print(f"total | loss: {total_loss}, tokens: {np.sum(tokens)}")
 
-async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataset, batch_size, iters):
+async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, iters):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
   if not shard:
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     return
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
-  all_losses = []
-  ntokens = 0
-
-  print(f"Training on {len(dataset)} examples with batch_size {batch_size}")
-  losses = []
-  tokens = []
-  for batch in tqdm(iterate_batches(dataset, tokenizer, batch_size), total=len(dataset) // batch_size):
-    _, _, lengths = batch
-    losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch)))
-    tokens.append(np.sum(lengths))
+  train, val, test = dataloader(tokenizer)
+  print(f"Training on {len(train)} examples with batch_size {batch_size}")
+  for epoch in range(iters):
+    losses = []
+    tokens = []
+    for batch in tqdm(iterate_batches(train, batch_size), total=len(train) // batch_size):
+      _, _, lengths = batch
+      losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch)))
+      tokens.append(np.sum(lengths))
   total_loss = np.sum(losses) / np.sum(tokens)
   print(f"total | loss: {total_loss}, tokens: {np.sum(tokens)}")
 
@@ -301,19 +300,18 @@ async def main():
       return
     await run_model_cli(node, inference_engine, model_name, args.prompt)
   elif args.command == "eval" or args.command == 'train':
-    data_path = args.data
-    train, val, test = load_dataset(data_path)
     model_name = args.model_name
+    dataloader = lambda tok: load_dataset(args.data, preprocess=lambda i: tok.encode(i["text"]))
     if args.command == 'eval':
       if not model_name:
         print("Error: Much like a human, I can't evaluate anything without a model")
         return
-      await eval_model_cli(node, inference_engine, model_name, test, args.batch_size)
+      await eval_model_cli(node, inference_engine, model_name, dataloader, args.batch_size)
     else:
       if not model_name:
         print("Error: This train ain't leaving the station without a model")
         return
-      await train_model_cli(node, inference_engine, model_name, test, args.batch_size, args.iters)
+      await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters)
     
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
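In main(), dataset loading is now deferred behind a closure: load_dataset only runs once a tokenizer is available, with a preprocess hook that encodes each record's "text" field. A rough standalone illustration of that shape, using a hypothetical stand-in tokenizer and an example data path, and assuming a directory of {name}.jsonl splits as load_dataset expects:

from exo.train.dataset import load_dataset

class ToyTokenizer:
  # Hypothetical stand-in; only encode() is needed by the preprocess hook
  def encode(self, text):
    return [ord(c) for c in text]

# Same shape as the lambda in main(): bind the data path now, tokenize later
def make_dataloader(data_path):
  return lambda tok: load_dataset(data_path, preprocess=lambda item: tok.encode(item["text"]))

dataloader = make_dataloader("./data")         # example path
train, val, test = dataloader(ToyTokenizer())  # three Dataset objects, as unpacked in eval/train_model_cli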

exo/train/dataset.py (+34, -19)

@@ -2,63 +2,78 @@
 from pathlib import Path
 import numpy as np
 import json
+from functools import partial, reduce
+def compose(*funcs):    
+  return reduce(lambda f, g: lambda x: f(g(x)), funcs, lambda x : x)
 
-def make_batch(tokens):
+def batch_with_lengths(tokens, maxlen = None):
   lengths = [len(x) for x in tokens]
   batch_size = len(lengths)
-
-  # Check if any sequence is longer than 2048 tokens
-  if max(lengths) > 2048:
-    print("You've got sequences with over 2048 tokens in here! Split your data fool!")
+  if maxlen is None:
+    maxlen = max(lengths)
+  else:
+    lengths = [min(maxlen, l) for l in lengths]
 
   # Pad to the max length
-  batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
+  batch_arr = np.zeros((batch_size, maxlen), np.int32)
 
   for j in range(batch_size):
     batch_arr[j, : lengths[j]] = tokens[j]
   batch = np.array(batch_arr)
   return batch[:, :-1], batch[:, 1:], np.array(lengths)
 
-def iterate_batches(dset, tokenizer, batch_size, train=False):
+def batch_chunk(batch_size):
+  return lambda d, i: d[i:i + batch_size]
+  
+
+def iterate_batches(dset, batch_size, train=False, uniform_length=True):
 # Shuffle indices
+  make_batch = lambda b: batch_with_lengths(b, maxlen=dset._maxlen if uniform_length else None)
+  chunk = batch_chunk(batch_size)
   while True:
     indices = np.arange(len(dset))
     if train:
       indices = np.random.permutation(indices)
+    batch = compose(make_batch, lambda i: [dset[k] for k in i], partial(chunk, indices))
 
     # Collect batches from dataset
     for i in range(0, len(indices) - batch_size + 1, batch_size):
-      # Encode batch
-      yield make_batch([tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)])
+      yield batch(i)
 
     if not train:
       break
 
 class Dataset:
-  """
-  Light-weight wrapper to hold lines from a jsonl file
-  """
-
-  def __init__(self, path: Path, key: str = "text"):
+  preprocess = lambda item: item
+  load = lambda line: line
+  def __init__(self, path: Path, preprocess=None, load=None, metrics={}):
     if not path.exists():
       self._data = None
     else:
+      if preprocess is not None:
+        self.preprocess = preprocess
+      if load is not None:
+        self.load = load
       with open(path, "r") as fid:
-        self._data = [json.loads(l) for l in fid]
-    self._key = key
+        self._data = [load(l) for l in fid]
+        self._maxlen = max([len(self.preprocess(x)) for x in self._data])
+        # Check if any sequence is longer than 2048 tokens
+        if self._maxlen > 2048:
+          print("You've got sequences with over 2048 tokens in here! Split your data fool!")
+
 
   def __getitem__(self, idx: int):
-    return self._data[idx][self._key]
+    return self.preprocess(self._data[idx])
 
   def __len__(self):
     return len(self._data)
 
 
-def load_dataset(data_path: str):
+def load_dataset(data_path: str, preprocess=None):
   def load_and_check(name):
     dataset_path = Path(data_path) / f"{name}.jsonl"
     try:
-      return Dataset(dataset_path)
+      return Dataset(dataset_path, preprocess=preprocess, load=json.loads)
     except Exception as e:
       print(f"Unable to build dataset {dataset_path} ({e})")
       raise
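The uniform-batch behaviour comes from batch_with_lengths padding every sequence to a fixed maxlen (iterate_batches passes dset._maxlen when uniform_length is True) and returning shifted input/target views plus per-sequence lengths. A small sketch with arbitrary token ids, assuming the module path from this commit:

from exo.train.dataset import batch_with_lengths

# Arbitrary pre-tokenized sequences of different lengths
tokens = [[5, 9, 2], [7, 3], [1, 8, 6, 4]]

# Fixed maxlen: every batch comes out the same width
inputs, targets, lengths = batch_with_lengths(tokens, maxlen=5)
print(inputs.shape, targets.shape)  # (3, 4) (3, 4): padded to 5, then shifted by one
print(lengths)                      # [3 2 4]

# Without maxlen it pads only to the longest sequence in this batch,
# so widths can differ from batch to batch
inputs, targets, _ = batch_with_lengths(tokens)
print(inputs.shape)                 # (3, 3)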