
Working distributed training

Only works on unquantized models on MLX so far. Also, for some weird reason, any optimizer other than SGD seems to NaN everything
Nel Nibcord 8 months ago
parent
commit
dd3d99043b

+ 4 - 2
exo/inference/inference_engine.py

@@ -27,9 +27,11 @@ class InferenceEngine(ABC):
   async def save_session(self, key, value):
     self.session[key] = value
   
-  async def ensure_session(self, key, value_gen):
-    if key not in self.session:
+  async def ensure_session(self, key, check, value_gen, hook=None):
+    if key not in self.session or not check(self.session[key]):
       await self.save_session(key, value_gen())
+      if hook is not None:
+        hook()
   
   async def ensure_session_match(self, key, check, value_gen):
     if key not in self.session or not check(self.session[key]):
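
The reworked ensure_session now takes a validity check and an optional post-creation hook, mirroring ensure_session_match. A minimal sketch of how a caller might use the new signature (the key name, check, and hook below are illustrative, not taken from this commit):

  import mlx.optimizers as optim

  async def ensure_sgd_opt(engine, lr: float):
    # Rebuild the cached optimizer unless the stored value still passes the check;
    # the hook fires only when a fresh value was generated.
    await engine.ensure_session(
      'opt',
      check=lambda o: isinstance(o, optim.SGD),
      value_gen=lambda: optim.SGD(lr),
      hook=lambda: print("optimizer (re)built"),
    )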

+ 17 - 5
exo/inference/mlx/losses.py

@@ -3,20 +3,32 @@ import mlx.nn as nn
 def length_masked_ce_loss(model, inputs, targets, lengths):
   # Run model on inputs
   logits = model(inputs).astype(mx.float32)
-
+  
   # Mask padding tokens
   length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
 
   # Calculate the loss
   ce = nn.losses.cross_entropy(logits, targets) * length_mask
   loss = ce.sum() / length_mask.sum()
+#  print(f"|    {inputs=}\n| ==>{logits=}\n| ~^~{ce=}\n| == {loss=}")
   return loss
 
 #Naive intermediate layer loss, where we replace the targets with gradients and just multiply the output by the gradients to derive the loss. This is naive and may warrant some further iteration, but will do the job for now
-def back_gradient_loss(model, inputs, gradients, shard_proportion):
-  out = model(inputs)
-  logits = out[:, -1, :]
-  loss = (logits * gradients).mean()
+def back_gradient_loss(model, inputs, gradients, lengths):
+  out = model(inputs).astype(mx.float32)
+  grad = gradients.astype(mx.float32)
+
+  # Mask padding tokens
+  length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
+
+  masked_sum = (out * length_mask.T).sum(axis=1)
+  gradient_lens = mx.abs(grad * masked_sum)
+  loss = gradient_lens.sum() / length_mask.sum()
+#  print(f"|    {inputs=}\n"
+#      + f"| ==>{out=}\n"
+#      + f"| ~^~{masked_sum=}\n"
+#      + f"| <~>{gradient_lens=}\n"
+#      + f"| == {loss=}")
   return loss
 
 loss_fns = {
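
As a sanity check on the masking above, here is a hypothetical toy run of length_masked_ce_loss (not part of this commit): a stand-in "model" returns uniform logits, so only the unpadded positions contribute and the loss comes out to ln(vocab).

  import mlx.core as mx
  from exo.inference.mlx.losses import length_masked_ce_loss

  vocab = 5
  inputs  = mx.array([[1, 2, 3, 0], [4, 0, 0, 0]])    # two padded sequences
  targets = mx.array([[2, 3, 4, 0], [1, 0, 0, 0]])
  lengths = mx.array([3, 1])                          # real (unpadded) lengths

  fake_model = lambda x: mx.zeros((*x.shape, vocab))  # uniform logits over the vocab
  loss = length_masked_ce_loss(fake_model, inputs, targets, lengths)
  print(loss)  # ~1.609 == ln(5); the padded positions are masked out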

+ 57 - 25
exo/inference/mlx/sharded_inference_engine.py

@@ -4,7 +4,6 @@ import mlx.nn as nn
 from mlx_lm.sample_utils import top_p_sampling
 import mlx.optimizers as optim
 from ..inference_engine import InferenceEngine
-from .stateful_model import StatefulModel
 from .sharded_utils import load_shard, get_image_from_str
 from .losses import loss_fns 
 from ..shard import Shard
@@ -12,6 +11,9 @@ from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from collections import OrderedDict
+from mlx_lm.models.cache import make_prompt_cache
 
 def sample_logits(
   logits: mx.array,
@@ -39,8 +41,19 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.shard = None
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
+    self.caches = OrderedDict()
     self.session = {}
 
+  async def poll_cache(self, request_id: str, max_caches=2):
+    if request_id in self.caches:
+      self.caches.move_to_end(request_id)
+    else:
+      newcache = await asyncio.get_running_loop().run_in_executor(self.executor, make_prompt_cache, self.model)
+      if len(self.caches) > max_caches:
+        self.caches.popitem(last=False)
+      self.caches[request_id] = newcache
+    return self.caches[request_id]
+
   async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
     y = mx.array(x)
     logits = y[:, -1, :]
@@ -57,54 +70,72 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
     return tokens
 
-  async def save_checkpoint(self, path: Path):
+  async def save_checkpoint(self, path: str):
     await self.ensure_shard(shard)
     await asyncio.get_running_loop().run_in_executor(self.executor, self.model.save_weights, path)
 
-  async def load_checkpoint(self, path: Path):
+  async def load_checkpoint(self, path: str):
     await self.ensure_shard(shard)
     await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
     
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
     #print(f"infer_tensor in <- {input_data}")
-    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
+    loop = asyncio.get_running_loop()
+    cache = await self.poll_cache(request_id)
+    x = mx.array(input_data).astype(mx.int64) if self.shard.is_first_layer() else mx.array(input_data)
+    #print(f"Infer Tensor: {x=}")
+    output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, lambda: self.model(x, cache=cache)))
     #print(f"infer_tensor out -> {output_data}")
     return output_data
-  
+
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
     await self.ensure_shard(shard)
-    await self.ensure_session('loss', lambda: loss_fns[loss])
-    await self.ensure_session('task', lambda: ('eval', self.model.eval()))
+    await self.save_session('loss', loss_fns[loss])
+    loop = asyncio.get_running_loop()
     #print(f"evaluate in <- {inputs}")
     x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
     y = mx.array(targets)
     l = mx.array(lengths)
-    score = await asyncio.get_running_loop().run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
+    score = await loop.run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
     #print(f"evaluate out -> {score}")
     return np.array(score)
 
-  async def update_model(self, grad, lval):
+  async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
     await self.ensure_shard(shard)
-    self.session['opt'].update(self.model, grad)
-    mx.eval(self.model.parameters(), self.session['opt'].state, lval)
-  
-  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.Adam, lr=1e-5):
-    await self.ensure_shard(shard)
-    await self.ensure_session('loss', lambda: loss_fns[loss])
-    await self.ensure_session('LVaG', lambda: nn.value_and_grad(self.model, self.session['loss']))
-    await self.ensure_session('opt', lambda: opt(lr))
-    await self.ensure_session('task', lambda: ('train', self.model.train()))
+    if 'train_layers' not in self.session or self.session['train_layers'] != trainable_layers:
+      await self.save_session('train_layers', trainable_layers)
+      self.model.freeze()
+      self.model.apply_to_modules(lambda k, v: v.unfreeze() if any(k.endswith(i) for i in trainable_layers) else None)
+    if 'lossname' not in self.session or 'LVaG' not in self.session or self.session['lossname'] != loss:
+      await self.save_session('lossname', loss)
+      await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))
+    if 'opt' not in self.session:
+      await self.save_session('opt', opt(lr))
+    return True
+
+  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
+    loop = asyncio.get_running_loop()
+    nothin = await self.ensure_train(shard, loss, opt, lr)
+    def train_step(inp, tar, lng):
+      lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
+      gradlayers = grad['model']['layers']
+      self.session['opt'].update(self.model, grad)
+      mx.eval(self.model.parameters(), self.session['opt'].state, lval)
+      return lval, gradlayers
 
     x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
     y = mx.array(targets)
     l = mx.array(lengths)
-    loop = asyncio.get_running_loop()
-    score, grad = await loop.run_in_executor(self.executor, self.session['LVaG'], self.model, x, y, l)
-    layers = [{k: v["weight"].shape for k,v in l.items() if 'weight' in v} for l in grad['model']['model']['layers'] if l]
-    await loop.run_in_executor(self.executor, self.update_model, grad, score)
+
+    score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
+    #print(f"{score=}")
+      
+    layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
+    #print(layers[0])
 
     return np.array(score).reshape(inputs.shape[0], -1), np.array(layers[0]['input_layernorm']).reshape(inputs.shape[0], -1)
+    return 0, 0
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
@@ -113,12 +144,13 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
 
     if self.shard != shard:
-      loop = asyncio.get_running_loop()
 
       def load_shard_wrapper():
         return asyncio.run(load_shard(model_path, shard))
 
-      model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
+      model_shard, self.tokenizer = await asyncio.get_running_loop().run_in_executor(self.executor, load_shard_wrapper)
       self.shard = shard
-      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 
+      self.model = model_shard 
+      self.caches = OrderedDict()
+      self.session = {}
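
poll_cache keeps per-request prompt caches in an OrderedDict used as a small LRU: a hit is moved to the end, a miss builds a fresh cache and evicts the oldest entry once the limit is reached. A stripped-down, standalone sketch of that idea (strings stand in for real prompt caches, and this version caps the dict at exactly MAX_CACHES):

  from collections import OrderedDict

  caches = OrderedDict()
  MAX_CACHES = 2

  def poll(request_id: str):
    if request_id in caches:
      caches.move_to_end(request_id)        # mark as most recently used
    else:
      if len(caches) >= MAX_CACHES:
        caches.popitem(last=False)          # evict the least recently used entry
      caches[request_id] = f"cache-for-{request_id}"  # stand-in for make_prompt_cache(model)
    return caches[request_id]

  poll("a"); poll("b"); poll("a"); poll("c")
  print(list(caches))  # ['a', 'c'] -- 'b' was the least recently used, so it was evicted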
 

+ 3 - 0
exo/inference/mlx/sharded_utils.py

@@ -145,6 +145,8 @@ def load_model_shard(
 
   if hasattr(model, "sanitize"):
     weights = model.sanitize(weights)
+  if DEBUG >= 8:
+    print(f"\n|| {config=} ||\n")
 
   if (quantization := config.get("quantization", None)) is not None:
     # Handle legacy models which may not have everything quantized
@@ -153,6 +155,7 @@ def load_model_shard(
         return False
       return f"{p}.scales" in weights
 
+
     nn.quantize(
       model,
       **quantization,

+ 2 - 16
exo/inference/mlx/stateful_model.py

@@ -4,9 +4,9 @@ from collections import OrderedDict
 import mlx.core as mx
 import mlx.nn as nn
 from mlx_lm.models.cache import make_prompt_cache
+import numpy as np
 
 from ..shard import Shard
-
 class StatefulModel(nn.Module):
   def __init__(self, model, max_kv_size: int = 1024, max_caches: int = 2):
     super().__init__()
@@ -15,20 +15,6 @@ class StatefulModel(nn.Module):
     self.max_caches = max_caches
     self.caches = OrderedDict()
   
-  def init_cache(self, request_id: str):
-    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
-    # if self.max_kv_size is not None:
-      # cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
-      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
-    # else:
-      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
-    cache = make_prompt_cache(self.model)
-
-    if len(self.caches) >= self.max_caches:
-      self.caches.popitem(last=False)
-
-    self.caches[request_id] = cache
-
   def __call__(self, x, request_id: Optional[str] = None, use_cache: bool = True):
     #print(f"StatefulModel in <- {x}")
     if use_cache and request_id is not None:
@@ -37,7 +23,7 @@ class StatefulModel(nn.Module):
       else:
         self.caches.move_to_end(request_id)
 
-      cache = self.caches[request_id]
+      cache = mx.array(self.caches[request_id])
       y = self.model(x, cache=cache)
     else:
       y = self.model(x)

+ 1 - 1
exo/inference/tinygrad/inference.py

@@ -128,4 +128,4 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
       tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
       self.tokenizer = await resolve_tokenizer(tokenizer_path)
       self.shard = shard
-      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard)
+      self.model = model_shard

+ 37 - 29
exo/main.py

@@ -218,6 +218,26 @@ def clean_path(path):
         path = path.strip('Optional("').rstrip('")')
     return os.path.expanduser(path)
 
+async def hold_outstanding(node: Node):
+  while True:
+    if node.outstanding_requests:
+      await asyncio.sleep(1)
+    else:
+      return      
+
+
+async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
+  losses = []
+  tokens = []
+  for batch in tqdm(iterate_batches(data, batch_size), total=len(data) // batch_size):
+    _, _, lengths = batch
+    losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch, train=train)))
+    tokens.append(np.sum(lengths))
+  total_tokens = np.sum(tokens)
+  total_loss = np.sum(losses) / total_tokens
+  
+  return total_loss, total_tokens
+
 async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, num_batches=-1):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
@@ -225,17 +245,12 @@ async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_na
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     return
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
-  train, val, test = dataloader(lambda i: tokenizer.encode(i))
-  dataset = test
+  train, val, test = dataloader(tokenizer.encode)
   print(f"Evaluating {len(test)} examples with batch_size {batch_size}")
-  losses = []
-  tokens = []
-  for batch in tqdm(iterate_batches(test, batch_size), total=len(dataset) // batch_size):
-    _, _, lengths = batch
-    losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch)))
-    tokens.append(np.sum(lengths))
-  total_loss = np.sum(losses) / np.sum(tokens)
-  print(f"total | loss: {total_loss}, tokens: {np.sum(tokens)}")
+  loss, tokens = await run_iter(node, shard, False, test, batch_size)
+  print(f"total | {loss=}, {tokens=}")
+  print("Waiting for outstanding tasks")
+  await hold_outstanding(node)
 
 async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, iters, save_interval=0, checkpoint_dir=None):
   inference_class = inference_engine.__class__.__name__
@@ -244,25 +259,19 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     return
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
-  train, val, test = dataloader(lambda i: tokenizer.encode(i))
-  print(f"Training on {len(train)} examples with batch_size {batch_size}")
+  train, val, test = dataloader(tokenizer.encode)
+  print(f"Training on {len(train)} examples with batch_size {batch_size} for {iters} epochs")
+  for i in tqdm(range(3)):
+    await asyncio.sleep(1)
   for epoch in range(iters):
-    losses = []
-    tokens = []
-    for batch in tqdm(iterate_batches(train, batch_size), total=len(train) // batch_size):
-      _, _, lengths = batch
-      losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch, train=True)))
-      tokens.append(np.sum(lengths))
-    total_loss = np.sum(losses) / np.sum(tokens)
-    print(f"epoch {iters}\t| loss: {total_loss}, tokens: {np.sum(tokens)}")
-
-async def hold_outstanding(node: Node):
-  while True:
-    if node.outstanding_requests:
-      await asyncio.sleep(.1)
-    else:
-      return      
+    loss, tokens = await run_iter(node, shard, True, train, batch_size)
+    print(f"epoch {epoch + 1}/{iters}\t| {loss=}, {tokens=}")
+    if save_interval > 0 and epoch > 0 and (epoch % save_interval) == 0:
+      print("Hold up let's save a checkpoint")
+      await hold_outstanding(node)
+  await hold_outstanding(node)
 
+  
 async def main():
   loop = asyncio.get_running_loop()
 
@@ -321,13 +330,12 @@ async def main():
       if not model_name:
         print("Error: This train ain't leaving the station without a model")
         return
-      await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters)
+      await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every)
     
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
     await asyncio.Event().wait()
   
-  await hold_outstanding(node)
   if args.wait_for_peers > 0:
     print("Cooldown to allow peers to exit gracefully")
     for i in tqdm(range(50)):
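
run_iter weights each example's loss by its token count, so the reported total is an average loss per token rather than per batch. A small worked example of that arithmetic with hypothetical numbers (plain numpy, independent of exo):

  import numpy as np

  # (per-example losses, per-example token lengths) for two batches of size 2
  batches = [
    (np.array([1.5, 2.5]), np.array([4, 6])),
    (np.array([0.5, 1.0]), np.array([8, 2])),
  ]

  losses = [np.sum(lengths * loss) for loss, lengths in batches]  # token-weighted sums: 21, 6
  tokens = [np.sum(lengths) for _, lengths in batches]            # 10, 10

  total_loss = np.sum(losses) / np.sum(tokens)  # 27 / 20 = 1.35 average loss per token
  print(total_loss, np.sum(tokens))             # 1.35 20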

+ 1 - 1
exo/models.py

@@ -20,7 +20,7 @@ model_cards = {
   "llama-3.2-3b": {
     "layers": 28,
     "repo": {
-       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
+       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct",
        "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
     },
   },