
WIP: Training works on mlx

Still debugging some tinygrad stuff, and fixing comms
Nel Nibcord, 8 months ago
parent
commit
836856824e

+ 14 - 0
exo/inference/inference_engine.py

@@ -24,6 +24,20 @@ class InferenceEngine(ABC):
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     pass
   
+  async def save_session(self, key, value):
+    self.session[key] = value
+  
+  async def ensure_session(self, key, value_gen):
+    if key not in self.session:
+      await self.save_session(key, value_gen())
+  
+  async def ensure_session_match(self, key, check, value_gen):
+    if key not in self.session or not check(self.session[key]):
+      await self.save_session(key, value_gen())
+  
+  async def clear_session(self):
+    self.session.clear()
+  
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
     tokens = await self.encode(shard, prompt)
     x = tokens.reshape(1, -1)
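
The four session helpers above give each engine a small per-instance cache for training state (loss function, optimizer, current task). A minimal usage sketch, assuming `engine` is a concrete subclass whose constructor sets `self.session = {}` (as the MLX and tinygrad engines below now do); the optimizer choice is only illustrative:

import mlx.optimizers as optim

async def prepare_training(engine):
  # Created once; later calls with the same key reuse the cached optimizer.
  await engine.ensure_session('opt', lambda: optim.Adam(learning_rate=1e-5))
  # Recreated only when the cached value fails the check (e.g. the task switched).
  await engine.ensure_session_match('task', lambda t: t == 'train', lambda: 'train')
  # await engine.clear_session() would drop all cached training state again.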

+ 3 - 5
exo/inference/mlx/losses.py

@@ -2,15 +2,13 @@ import mlx.core as mx
 import mlx.nn as nn
 def length_masked_ce_loss(model, inputs, targets, lengths):
   # Run model on inputs
-  logits = model(inputs)
-  logits = logits.astype(mx.float32)
+  logits = model(inputs).astype(mx.float32)
 
   # Mask padding tokens
   length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
 
   # Calculate the loss
   ce = nn.losses.cross_entropy(logits, targets) * length_mask
-  ntoks = length_mask.sum()
-  ce = ce.sum() / ntoks
-  return ce
+  loss = ce.sum() / length_mask.sum()
+  return loss
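
The refactor keeps the masking logic unchanged: a broadcasted comparison builds a (batch, seq_len) boolean mask so padded positions contribute nothing to the mean. A tiny numpy sketch with made-up lengths, only to show what the mask looks like:

import numpy as np

lengths = np.array([2, 3])                      # true token counts per row
length_mask = np.arange(4)[None, :] < lengths[:, None]
# array([[ True,  True, False, False],
#        [ True,  True,  True, False]])
# The per-token cross-entropy is multiplied by this mask, then summed and
# divided by length_mask.sum() (here 5), i.e. a mean over real tokens only.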
 

+ 30 - 11
exo/inference/mlx/sharded_inference_engine.py

@@ -2,6 +2,7 @@ import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
 from mlx_lm.sample_utils import top_p_sampling
+import mlx.optimizers as optim
 from ..inference_engine import InferenceEngine
 from .stateful_model import StatefulModel
 from .sharded_utils import load_shard, get_image_from_str
@@ -38,6 +39,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.shard = None
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
+    self.session = {}
 
   async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
     y = mx.array(x)
@@ -61,6 +63,34 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
     #print(f"infer_tensor out -> {output_data}")
     return output_data
+  
+  async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
+    await self.ensure_shard(shard)
+    await self.ensure_session('loss', lambda: loss)
+    await self.ensure_session('task', lambda: ('eval', self.model.eval()))
+    #print(f"evaluate in <- {inputs}")
+    x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
+    y = mx.array(targets).astype(mx.int64)
+    l = mx.array(lengths)
+    score = await asyncio.get_running_loop().run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
+    #print(f"evaluate out -> {score}")
+    return np.array(score)
+  
+  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss, opt=optim.Adam, lr=1e-5):
+    await self.ensure_shard(shard)
+    await self.ensure_session('loss', lambda: loss)
+    await self.ensure_session('LVaG', lambda: nn.value_and_grad(self.model, self.session['loss']))
+    await self.ensure_session('opt', lambda: opt(lr))
+    await self.ensure_session('task', lambda: ('train', self.model.train()))
+
+    x = mx.array(inputs).astype(mx.int64) if self.shard.is_first_layer() else mx.array(inputs)
+    y = mx.array(targets).astype(mx.int64)
+    l = mx.array(lengths)
+    loop = asyncio.get_running_loop()
+    loss, grad = await loop.run_in_executor(self.executor, self.session['LVaG'], self.model, x, y, l)
+    await loop.run_in_executor(self.executor, lambda: self.session['opt'].update(self.model, grad))
+
+    return np.array(loss), np.array(grad)
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
@@ -78,14 +108,3 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
       self.shard = shard
       self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 
 
-  async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
-    await self.ensure_shard(shard)
-    #print(f"evaluate in <- {inputs}")
-    x = mx.array(inputs).astype(mx.int64)
-    y = mx.array(targets).astype(mx.int64)
-    l = mx.array(lengths)
-    def model_wrapper(e):
-      return self.model(e, request_id)
-    score = await asyncio.get_running_loop().run_in_executor(self.executor, loss, model_wrapper, x, y, l)
-    #print(f"evaluate out -> {score}")
-    return np.array(score)
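
For reference, a hedged sketch of how the new MLX train entry point would be driven; the engine, shard, request id, and batch shapes here are assumptions, not part of the diff:

import numpy as np
import mlx.optimizers as optim

async def train_step(engine, shard):
  inputs  = np.zeros((1, 8), dtype=np.int64)    # padded token ids
  targets = np.zeros((1, 8), dtype=np.int64)    # next-token labels
  lengths = np.array([8])                       # real tokens per row
  loss, grad = await engine.train("example-request", shard, inputs, targets,
                                  lengths, opt=optim.Adam, lr=1e-5)
  return loss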

+ 3 - 3
exo/inference/mlx/stateful_model.py

@@ -1,4 +1,4 @@
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Optional
 from collections import OrderedDict
 
 import mlx.core as mx
@@ -29,9 +29,9 @@ class StatefulModel(nn.Module):
 
     self.caches[request_id] = cache
 
-  def __call__(self, x, request_id: str, use_cache: bool = True):
+  def __call__(self, x, request_id: Optional[str] = None, use_cache: bool = True):
     #print(f"StatefulModel in <- {x}")
-    if use_cache:
+    if use_cache and request_id is not None:
       if request_id not in self.caches:
         self.init_cache(request_id)
       else:
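
Making request_id optional lets the training and evaluation paths call the wrapped model without touching the per-request KV cache; the tinygrad StatefulModel below gets the same treatment. A two-line sketch of the call shapes (model and x are placeholders):

h = model(x)                         # training/eval: no request_id, cache handling skipped
h = model(x, request_id="req-1")     # inference: per-request KV cache as before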

+ 32 - 14
exo/inference/tinygrad/inference.py

@@ -5,7 +5,7 @@ from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggin
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import load_state_dict
-from tinygrad import Tensor, nn, Context
+from tinygrad import Tensor, nn, Context, TinyJit
 from exo.inference.inference_engine import InferenceEngine
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
@@ -15,7 +15,7 @@ from .stateful_model import StatefulModel
 from .losses import length_masked_ce_loss
 import asyncio
 
-Tensor.no_grad = True
+Tensor.no_grad = False
 # default settings
 TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
 TOP_K = 25
@@ -63,6 +63,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard = None
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
+    self.session = {}
 
   async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
     logits = x[:, -1, :]
@@ -82,11 +83,37 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    #print(f"infer_tensor in <- {input_data}")
     output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize())
-    #print(f"infer_tensor out -> {output_data}")
     return output_data.numpy()
 
+  async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
+    def step(x, y, l):
+      Tensor.training = False
+      return self.session['loss'](self.model, x, y, l)
+    await self.ensure_shard(shard)
+    await self.ensure_session('loss', lambda: loss)
+    await self.ensure_session('jit', lambda: TinyJit(step)) 
+    score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths))
+    out = score.numpy()
+    return out
+  
+  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss, opt=nn.optim.Adam, lr=1e-5):
+    def step(x, y, l):
+      Tensor.training = True
+      score = self.session['loss'](self.model, x, y, l)
+      self.session['opt'].zero_grad()
+      score.backward()
+      self.session['opt'].step()
+      return score
+    await self.ensure_shard(shard)
+    await self.ensure_session('loss', lambda: loss)
+    await self.ensure_session('opt', lambda: opt(nn.state.get_parameters(self.model.model), lr=lr))
+    await self.ensure_session('jit', lambda: TinyJit(step)) 
+      
+    score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths).realize())
+    
+    return score.numpy(), score.numpy()
+
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return
@@ -101,13 +128,4 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
       tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
       self.tokenizer = await resolve_tokenizer(tokenizer_path)
       self.shard = shard
-      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 
-
-  async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
-    await self.ensure_shard(shard)
-    def model_wrapper(x):
-      return self.model(x, request_id)
-    score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: loss(model_wrapper, Tensor(inputs), Tensor(targets), Tensor(lengths)).realize())
-    out = score.numpy()
-    return out
-
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard)
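
The new tinygrad path wraps the whole loss/backward/step sequence in TinyJit so repeated calls with same-shaped tensors are compiled once. A standalone sketch of that pattern on a toy linear model (the model, shapes, and hyperparameters are illustrative, not the engine's own):

from tinygrad import Tensor, TinyJit, nn

Tensor.no_grad = False
toy_model = nn.Linear(4, 2)
toy_opt = nn.optim.Adam(nn.state.get_parameters(toy_model), lr=1e-5)

@TinyJit
def toy_step(x: Tensor, y: Tensor) -> Tensor:
  Tensor.training = True                        # gradients on inside the jitted step
  loss = toy_model(x).sparse_categorical_crossentropy(y)
  toy_opt.zero_grad()
  loss.backward()
  toy_opt.step()
  return loss.realize()

print(toy_step(Tensor.rand(3, 4), Tensor([0, 1, 0])).numpy())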

+ 6 - 7
exo/inference/tinygrad/losses.py

@@ -1,15 +1,14 @@
 from tinygrad import Tensor, dtypes
+import numpy as np
 def length_masked_ce_loss(model, inputs, targets, lengths):
   # Run model on inputs
-  logits = model(inputs)
-  logits = logits.cast(dtypes.float32)
+  logits = model(inputs).cast(dtypes.float32).contiguous()
 
   # Mask padding tokens
-  length_mask = Tensor.arange(inputs.shape[1])[None, :] < lengths[:, None]
+  length_mask = Tensor(np.arange(inputs.shape[1])[None, :] < lengths[:, None], requires_grad=False)
 
   # Calculate the loss
-  ce = logits.sparse_categorical_crossentropy(targets) * length_mask
-  ntoks = length_mask.sum()
-  ce = ce.sum() / ntoks
-  return ce
+  ce = logits.sparse_categorical_crossentropy(Tensor(targets, requires_grad=False)).mul(length_mask)
+  loss = ce.sum() / length_mask.sum()
+  return loss
 

+ 3 - 3
exo/inference/tinygrad/stateful_model.py

@@ -1,6 +1,6 @@
 from tinygrad import Tensor, Variable 
 from collections import OrderedDict
-from typing import List
+from typing import List, Optional
 
 def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
   cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
@@ -30,10 +30,10 @@ class StatefulModel:
 
     self.states[request_id] = ModelState(cache)
 
-  def __call__(self, x: Tensor, request_id: str, use_cache: bool = True): 
+  def __call__(self, x: Tensor, request_id: Optional[str] = None, use_cache: bool = True): 
     h = self.model.embed(x)
     #print(f"StatefulModel in <- {h}")
-    if use_cache:
+    if use_cache and request_id is not None:
       if request_id not in self.states:
         self.init_cache(h, request_id)
       else:

+ 9 - 8
exo/main.py

@@ -14,7 +14,7 @@ import numpy as np
 from functools import partial
 from tqdm import tqdm
 from tqdm.asyncio import tqdm_asyncio
-from exo.train.dataset import load_dataset, iterate_batches
+from exo.train.dataset import load_dataset, iterate_batches, compose
 from exo.networking.manual.manual_discovery import ManualDiscovery
 from exo.networking.manual.network_topology_config import NetworkTopology
 from exo.orchestration.standard_node import StandardNode
@@ -40,7 +40,7 @@ parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser.add_argument("command", nargs="?", choices=["run", "eval", "train"], help="Command to run")
 parser.add_argument("model_name", nargs="?", help="Model name to run")
 parser.add_argument("--default-model", type=str, default=None, help="Default model")
-parser.add_argument("--iters", type=int, default=600, help="Training iterations")
+parser.add_argument("--iters", type=int, default=100, help="Training iterations")
 parser.add_argument("--data", type=str, default="exo/train/data/lora", help="Directory where training data lives")
 parser.add_argument("--batch-size", type=int, default=1, help="Minibatch size.")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
@@ -223,7 +223,7 @@ async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_na
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     return
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
-  train, val, test = dataloader(tokenizer)
+  train, val, test = dataloader(lambda i: tokenizer.encode(i))
   dataset = test
   print(f"Evaluating {len(dataset)} examples with batch_size {batch_size}")
   losses = []
@@ -242,14 +242,14 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     return
   tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
-  train, val, test = dataloader(tokenizer)
-  print(f"Training on {len(train)} examples with batch_size {batch_size}")
+  train, val, test = dataloader(lambda i: tokenizer.encode(i))
+  print(f"Training on {len(val)} examples with batch_size {batch_size}")
   for epoch in range(iters):
     losses = []
     tokens = []
-    for batch in tqdm(iterate_batches(train, batch_size), total=len(dataset) // batch_size):
+    for batch in tqdm(iterate_batches(train, batch_size), total=len(train) // batch_size):
       _, _, lengths = batch
-      losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch)))
+      losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch, train=True)))
       tokens.append(np.sum(lengths))
   total_loss = np.sum(losses) / np.sum(tokens)
   print(f"total | loss: {total_loss}, tokens: {np.sum(tokens)}")
@@ -301,7 +301,8 @@ async def main():
     await run_model_cli(node, inference_engine, model_name, args.prompt)
   elif args.command == "eval" or args.command == 'train':
     model_name = args.model_name
-    dataloader = lambda tok: load_dataset(args.data, preprocess=lambda i: tok.encode(i["text"]))
+    dataloader = lambda tok: load_dataset(args.data, preprocess=lambda item: tok(item)
+                                                   , loadline=lambda line: json.loads(line).get("text",""))
     if args.command == 'eval':
       if not model_name:
         print("Error: Much like a human, I can't evaluate anything without a model")

+ 2 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -107,7 +107,7 @@ class GRPCPeerHandle(PeerHandle):
 
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
   
-  async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
+  async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.ExampleRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -118,6 +118,7 @@ class GRPCPeerHandle(PeerHandle):
       example=node_service_pb2.Tensor(tensor_data=example.tobytes(), shape=example.shape, dtype=str(example.dtype)),
       target=node_service_pb2.Tensor(tensor_data=target.tobytes(), shape=target.shape, dtype=str(target.dtype)),
       length=node_service_pb2.Tensor(tensor_data=length.tobytes(), shape=length.shape, dtype=str(length.dtype)),
+      train = train,
       request_id=request_id,
     )
     response = await self.stub.SendExample(request)

+ 2 - 1
exo/networking/grpc/grpc_server.py

@@ -80,9 +80,10 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     example = np.frombuffer(request.example.tensor_data, dtype=np.dtype(request.example.dtype)).reshape(request.example.shape)
     target = np.frombuffer(request.target.tensor_data, dtype=np.dtype(request.target.dtype)).reshape(request.target.shape)
     length = np.frombuffer(request.length.tensor_data, dtype=np.dtype(request.length.dtype)).reshape(request.length.shape)
+    train = request.train
     request_id = request.request_id
 
-    result = await self.node.process_example(shard, example, target, length, request_id)
+    result = await self.node.process_example(shard, example, target, length, train, request_id)
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {example=} {target=} {length=} {request_id=} result: {result}")
     tensor_data = result.tobytes()
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))

+ 2 - 1
exo/networking/grpc/node_service.proto

@@ -38,7 +38,8 @@ message ExampleRequest {
   Tensor example = 2;
   Tensor target = 3;
   Tensor length = 4;
-  optional string request_id = 5;
+  bool train = 5;
+  optional string request_id = 6;
 }
   
 message GetInferenceResultRequest {

File diff suppressed because it is too large
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 20 - 12
exo/orchestration/standard_node.py

@@ -203,11 +203,11 @@ class StandardNode(Node):
   ):
     shard = self.get_current_shard(base_shard)
     if shard.is_first_layer():
-      resp = await self.process_example(shard, example, target, length, request_id)
+      resp = await self.process_example(shard, example, target, length, train, request_id)
     else:
       if request_id is None:
         request_id = str(uuid.uuid4())
-      resp = await self.forward_example(shard, example, target, length, request_id, 0) 
+      resp = await self.forward_example(shard, example, target, length, train, request_id, 0) 
     return resp
     
 
@@ -217,8 +217,8 @@ class StandardNode(Node):
     example: np.ndarray,
     target: np.ndarray, 
     length: np.ndarray,
-    request_id: Optional[str] = None,
     train: bool = False,
+    request_id: Optional[str] = None,
   ):
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
@@ -237,7 +237,7 @@ class StandardNode(Node):
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_example(shard, example, target, length, request_id, train=train)
+    resp = await self._process_example(shard, example, target, length, train, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -256,15 +256,15 @@ class StandardNode(Node):
       )
     )
     return resp
-  
+
   async def _process_example(
     self,
     base_shard: Shard,
     example: np.ndarray,
     target: np.ndarray, 
     length: np.ndarray,
-    request_id: Optional[str] = None,
     train: bool = False,
+    request_id: Optional[str] = None,
   ) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
@@ -273,13 +273,20 @@ class StandardNode(Node):
     if DEBUG >= 1: print(f"[{request_id}] process_example: {example.shape=}")
     try:
       if shard.is_last_layer():
-        loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
-        loss_tensor = loss.reshape(1, -1)
-        return loss_tensor
+        if train:
+          loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
+          return loss.reshape(example.shape[0], -1) if shard.is_first_layer() else grad
+        else:
+          loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
+          return loss.reshape(example.shape[0], -1)
       else:
         step = await self.inference_engine.infer_tensor(request_id, shard, example)
-        loss = await self.forward_example(shard, step, target, length, request_id, self.get_partition_index(offset = 1))
-        return loss
+        result = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
+        if train:
+          forward = self.get_current_shard(self.get_partition_index(offset = 1))
+          return result
+        else:
+          return result.reshape(example.shape[0], -1)
     except Exception as e:
       print(f"Error processing example for shard {shard}: {e}")
       traceback.print_exc()
@@ -354,6 +361,7 @@ class StandardNode(Node):
     step: np.ndarray,
     target: np.ndarray,
     length: np.ndarray,
+    train: bool,
     request_id: str,
     target_index: int,
   ) -> None:
@@ -365,7 +373,7 @@ class StandardNode(Node):
     if not target_peer:
       raise ValueError(f"peer for {target_index} not found")
     if DEBUG >= 1: print(f"sending example to {target_peer.id()}: {step} => {target} ({length})")
-    ret = await target_peer.send_example(target_shard, step, target, length, request_id=request_id)
+    ret = await target_peer.send_example(target_shard, step, target, length, request_id=request_id, train=train)
     return ret
 
   async def forward_loss(

+ 6 - 11
exo/train/dataset.py

@@ -44,19 +44,14 @@ def iterate_batches(dset, batch_size, train=False, uniform_length=True):
       break
 
 class Dataset:
-  preprocess = lambda item: item
-  load = lambda line: line
-  def __init__(self, path: Path, preprocess=None, load=None, metrics={}):
+  def __init__(self, path: Path, preprocess=lambda item: item, loadline=json.loads, metrics={}):
     if not path.exists():
       self._data = None
     else:
-      if preprocess is not None:
-        self.preprocess = preprocess
-      if load is not None:
-        self.load = load
+      self.preprocess = preprocess
       with open(path, "r") as fid:
-        self._data = [load(l) for l in fid]
-        self._maxlen = max([len(self.preprocess(x)) for x in self._data])
+        self._data = [loadline(l) for l in fid]
+        self._maxlen = max([len(preprocess(x)) for x in self._data])
         # Check if any sequence is longer than 2048 tokens
         if self._maxlen > 2048:
           print("You've got sequences with over 2048 tokens in here! Split your data fool!")
@@ -69,11 +64,11 @@ class Dataset:
     return len(self._data)
 
 
-def load_dataset(data_path: str, preprocess=None):
+def load_dataset(data_path: str, preprocess=lambda i: i, loadline=json.loads):
   def load_and_check(name):
     dataset_path = Path(data_path) / f"{name}.jsonl"
     try:
-      return Dataset(dataset_path, preprocess=preprocess, load=json.loads)
+      return Dataset(dataset_path, preprocess=preprocess, loadline=loadline)
     except Exception as e:
       print(f"Unable to build dataset {dataset_path} ({e})")
       raise
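
Putting the reworked pieces together, a usage sketch under the defaults above; the data directory is the CLI default from main.py and the tokenizer callable is an assumption:

import json

train, valid, test = load_dataset(
  "exo/train/data/lora",                                    # --data default
  preprocess=lambda text: tokenizer.encode(text),           # str -> token ids
  loadline=lambda line: json.loads(line).get("text", ""),   # JSONL row -> str
)
for batch in iterate_batches(train, batch_size=1):
  inputs, targets, lengths = batch    # unpacked the same way as in train_model_cli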

Some files were not shown because too many files changed in this diff