
Coordination biz

Nel Nibcord, 8 months ago
parent · commit 175ebc1c42

exo/inference/mlx/sharded_inference_engine.py  (+8 -0)

@@ -56,6 +56,14 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
     return tokens
+
+  async def save_checkpoint(self, shard: Shard, path: Path):
+    await self.ensure_shard(shard)
+    await asyncio.get_running_loop().run_in_executor(self.executor, self.model.save_weights, path)
+
+  async def load_checkpoint(self, shard: Shard, path: Path):
+    await self.ensure_shard(shard)
+    await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
     
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
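
For context, a minimal sketch of how a caller might drive the new checkpoint hooks. This is not part of the commit: engine stands for an already-initialized MLXDynamicShardInferenceEngine, shard for the Shard it serves, and the directory and file name are arbitrary.

from pathlib import Path

async def checkpoint_round_trip(engine, shard, checkpoint_dir: str = "checkpoints") -> Path:
  """Save the current shard weights to disk, then load them back (illustration only)."""
  path = Path(checkpoint_dir) / "adapter_step_0.safetensors"
  path.parent.mkdir(parents=True, exist_ok=True)
  await engine.save_checkpoint(shard, path)  # runs model.save_weights in the engine's executor
  await engine.load_checkpoint(shard, path)  # runs model.load_weights the same way
  return path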

exo/main.py  (+14 -3)

@@ -41,8 +41,10 @@ parser.add_argument("command", nargs="?", choices=["run", "eval", "train"], help
 parser.add_argument("model_name", nargs="?", help="Model name to run")
 parser.add_argument("--default-model", type=str, default=None, help="Default model")
 parser.add_argument("--iters", type=int, default=100, help="Training iterations")
+parser.add_argument("--save-every", type=int, default=5, help="Save the model every N iterations.")
 parser.add_argument("--data", type=str, default="exo/train/data/lora", help="Directory where training data lives")
 parser.add_argument("--batch-size", type=int, default=1, help="Minibatch size.")
+parser.add_argument("--checkpoint-dir", type=str, default="checkpoints", help="Directory from which to load and save checkpoints")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
@@ -235,7 +237,7 @@ async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_na
   total_loss = np.sum(losses) / np.sum(tokens)
   print(f"total | loss: {total_loss}, tokens: {np.sum(tokens)}")
 
-async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, iters):
+async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, iters, save_interval=0, checkpoint_dir=None):
   inference_class = inference_engine.__class__.__name__
   shard = build_base_shard(model_name, inference_class)
   if not shard:
@@ -251,8 +253,15 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
       _, _, lengths = batch
       losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch, train=True)))
       tokens.append(np.sum(lengths))
-  total_loss = np.sum(losses) / np.sum(tokens)
-  print(f"total | loss: {total_loss}, tokens: {np.sum(tokens)}")
+    total_loss = np.sum(losses) / np.sum(tokens)
+    print(f"epoch {iters}\t| loss: {total_loss}, tokens: {np.sum(tokens)}")
+
+async def hold_outstanding(node: Node):
+  while True:
+    if node.outstanding_requests:
+      await asyncio.sleep(.1)
+    else:
+      return      
 
 async def main():
   loop = asyncio.get_running_loop()
@@ -317,6 +326,8 @@ async def main():
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
     await asyncio.Event().wait()
+  
+  await hold_outstanding(node)
   if args.wait_for_peers > 0:
     print("Cooldown to allow peers to exit gracefully")
     for i in tqdm(range(50)):
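
The new --save-every and --checkpoint-dir flags reach train_model_cli as save_interval and checkpoint_dir, but the hunks above do not show where a checkpoint is actually written. Below is a sketch of one way the training loop could consume them, assuming it sits in exo/main.py next to the new hold_outstanding helper so in-flight work drains first; maybe_save_checkpoint, its argument list, and the file naming are illustrative assumptions, not exo's API.

from pathlib import Path

async def maybe_save_checkpoint(node, inference_engine, shard, epoch: int,
                                save_interval: int, checkpoint_dir: str):
  # Skip unless checkpointing is enabled and this epoch is on the save cadence.
  if save_interval <= 0 or checkpoint_dir is None or epoch % save_interval != 0:
    return
  await hold_outstanding(node)  # helper defined above: wait for outstanding requests
  path = Path(checkpoint_dir) / f"epoch_{epoch}.safetensors"
  path.parent.mkdir(parents=True, exist_ok=True)
  await inference_engine.save_checkpoint(shard, path)  # hook added in this commit

A training loop would then call something like await maybe_save_checkpoint(node, inference_engine, shard, epoch, args.save_every, args.checkpoint_dir) once per epoch.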

exo/orchestration/standard_node.py  (+22 -2)

@@ -53,6 +53,7 @@ class StandardNode(Node):
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
     self.topology_inference_engines_pool: List[List[str]] = []
     self.shard_downloader = shard_downloader
+    self.outstanding_requests = {}
 
   async def start(self, wait_for_peers: int = 0) -> None:
     await self.server.start()
@@ -119,8 +120,10 @@ class StandardNode(Node):
       token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
       await self.inference_engine.ensure_shard(shard)
       self.buffered_token_output[request_id][0].append(token.item())
+      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
       if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
+      asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
       forward = token.reshape(1, -1)
       self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
       asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
@@ -129,7 +132,9 @@ class StandardNode(Node):
 
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+      self.outstanding_requests.pop(request_id)
     else:
+      self.outstanding_requests[request_id] = "waiting"
       asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
 
     return np.array(self.buffered_token_output[request_id][0])
@@ -185,11 +190,13 @@ class StandardNode(Node):
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
     if not shard.is_first_layer():
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+      self.outstanding_requests[request_id] = "waiting"
       resp = await self.forward_prompt(shard, prompt, request_id, 0)
       return None
     else:
+      self.outstanding_requests[request_id] = "processing"
       result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
-      ret = await self.process_inference_result(shard, result, request_id) 
+      ret = await self.process_inference_result(shard, result, request_id)
       return result
 
   async def enqueue_example(
@@ -207,6 +214,7 @@ class StandardNode(Node):
     else:
       if request_id is None:
         request_id = str(uuid.uuid4())
+      self.outstanding_requests[request_id] = "waiting"
       resp = await self.forward_example(shard, example, target, length, train, request_id, 0) 
     return resp
     
@@ -274,20 +282,30 @@ class StandardNode(Node):
       target = target.astype(int)
       if train:
         if shard.is_last_layer():
+          self.outstanding_requests[request_id] = "training"
           loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
         else:
+          self.outstanding_requests[request_id] = "preprocessing"
           step = await self.inference_engine.infer_tensor(request_id, shard, example)
+          self.outstanding_requests[request_id] = "waiting"
           backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
+          self.outstanding_requests[request_id] = "training"
           loss, grad = await self.inference_engine.train(request_id, shard, example, backgrad, length, loss="back_gradient")
+        self.outstanding_requests.pop(request_id)
         return loss.reshape(example.shape[0], -1) if shard.is_first_layer() else grad
       else:
         if shard.is_last_layer():
+          self.outstanding_requests[request_id] = "evaluating"
           loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
         else:
+          self.outstanding_requests[request_id] = "preprocessing"
           step = await self.inference_engine.infer_tensor(request_id, shard, example)
+          self.outstanding_requests[request_id] = "waiting"
           loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
+        self.outstanding_requests.pop(request_id)
         return loss.reshape(example.shape[0], -1)
     except Exception as e:
+      self.outstanding_requests.pop(request_id)
       print(f"Error processing example for shard {shard}: {e}")
       traceback.print_exc()
       return None
@@ -347,10 +365,12 @@ class StandardNode(Node):
 
     if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
+      self.outstanding_requests[request_id] = "processing"
       result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
       ret = await self.process_inference_result(shard, result, request_id) 
       return ret
     except Exception as e:
+      self.outstanding_requests.pop(request_id)
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
       return None
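
Taken together, the standard_node.py changes give every request a coarse lifecycle in outstanding_requests ("waiting", "processing", "preprocessing", "training", "evaluating") and drop the entry when the request completes or errors out; that is what lets hold_outstanding in exo/main.py block until the node is idle before shutdown. A self-contained sketch of the pattern, with purely illustrative names:

import asyncio

class TinyNode:
  """Minimal stand-in for StandardNode's outstanding_requests bookkeeping."""
  def __init__(self):
    self.outstanding_requests = {}

  async def handle(self, request_id: str):
    self.outstanding_requests[request_id] = "processing"
    try:
      await asyncio.sleep(0.05)  # stand-in for inference or training work
    finally:
      # Always clear the entry, mirroring the pop() calls on both success and error paths.
      self.outstanding_requests.pop(request_id, None)

async def drain(node: TinyNode):
  # Same shape as hold_outstanding() in exo/main.py: poll until nothing is in flight.
  while node.outstanding_requests:
    await asyncio.sleep(0.1)

async def demo():
  node = TinyNode()
  tasks = [asyncio.create_task(node.handle(r)) for r in ("a", "b")]
  await asyncio.sleep(0)  # let both handlers register themselves
  await drain(node)       # returns once both entries have been popped
  await asyncio.gather(*tasks)

asyncio.run(demo())

Clearing the entry on the error path (as the commit does in its except blocks) matters: a request that failed with its entry still in the dict would keep the drain loop spinning forever.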