
bench wip

Alex Cheema · 5 months ago
parent commit efb5279975

+ 6 - 3
exo/inference/mlx/sharded_inference_engine.py

@@ -26,15 +26,18 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
       inputs = await loop.run_in_executor(self.executor, tokenize)
       pixel_values = mx.array(inputs["pixel_values"])
       input_ids = mx.array(inputs["input_ids"])
-      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
+      o = await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values)
+      output_data: np.ndarray = np.array(o, copy=False)
     else:
       input_ids = mx.array(await loop.run_in_executor(self.executor, self.tokenizer.encode, prompt))
-      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
+      o = await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids)
+      output_data: np.ndarray = np.array(o, copy=False)
     return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
-    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, mx.array(input_data)))
+    o = await asyncio.get_running_loop().run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, mx.array(input_data))
+    output_data: np.ndarray = np.array(o, copy=False)
     return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
   async def ensure_shard(self, shard: Shard):

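A note on the pattern above: `stateful_sharded_model.step` returns an `mx.array`, and the MLX docs recommend `np.array(a, copy=False)` for converting one to NumPy without duplicating the buffer; the old one-liner forced a full copy on every step. A minimal sketch of the same idea in plain NumPy (no MLX required):

```python
import numpy as np

# copy=False reuses the source buffer when the input already satisfies
# the requested dtype/layout; np.array's default is to always copy.
src = np.arange(4)
view = np.array(src, copy=False)
assert view is src  # same object came back; nothing was allocated
```
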
+ 8 - 1
exo/inference/tokenizers.py

@@ -19,7 +19,15 @@ class DummyTokenizer:
     return "dummy"
 
 
+cached_tokenizers = {}
 async def resolve_tokenizer(model_id: str):
+  if model_id in cached_tokenizers:
+    return cached_tokenizers[model_id]
+  tokenizer = await resolve_tokenizer_no_cache(model_id)
+  cached_tokenizers[model_id] = tokenizer
+  return tokenizer
+
+async def resolve_tokenizer_no_cache(model_id: str):
   if model_id == "dummy":
     return DummyTokenizer()
   local_path = await get_local_snapshot_dir(model_id)
@@ -33,7 +41,6 @@ async def resolve_tokenizer(model_id: str):
     if DEBUG >= 5: traceback.print_exc()
   return await _resolve_tokenizer(model_id)
 
-
 async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
   try:
     if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")

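The cache above is a plain module-level dict, so two coroutines that miss at the same time will each resolve the same tokenizer before one result wins. If that ever matters, a lock-guarded variant is straightforward; a hypothetical sketch (the `_tokenizer_cache` and `_cache_lock` names are illustrative, not from the repo):

```python
import asyncio

_tokenizer_cache: dict[str, object] = {}
_cache_lock = asyncio.Lock()

async def resolve_tokenizer_locked(model_id: str):
  # Serialize first-time loads so concurrent callers for the same
  # model_id don't each pay the slow snapshot/AutoProcessor resolution.
  async with _cache_lock:
    if model_id not in _tokenizer_cache:
      _tokenizer_cache[model_id] = await resolve_tokenizer_no_cache(model_id)
    return _tokenizer_cache[model_id]
```
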
+ 70 - 20
exo/main.py

@@ -25,10 +25,11 @@ from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.models import model_base_shards
 from exo.viz.topology_viz import TopologyViz
+import yappi
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
-parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
+parser.add_argument("command", nargs="?", choices=["run", "bench"], help="Command to run")
 parser.add_argument("model_name", nargs="?", help="Model name to run")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
@@ -48,6 +49,7 @@ parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
+parser.add_argument("--bench", action=argparse.BooleanOptionalAction, help="Benchmark the model")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
@@ -123,7 +125,7 @@ api = ChatGPTAPI(
   inference_engine.__class__.__name__,
   response_timeout=args.chatgpt_api_response_timeout,
   on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
-)
+) if not args.bench else None
 node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )
@@ -143,24 +145,7 @@ def preemptively_start_download(request_id: str, opaque_status: str):
 
 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 
-stats = {}
-def collect_tps_stats(request_id: str, tokens: List[int], is_finished: bool):
-  if not request_id in stats:
-    stats[request_id] = { "first_token_time": time.perf_counter_ns() }
-  stats[request_id]["n_tokens"] = len(tokens)
-  stats[request_id]["tps"] = 1e9 * (stats[request_id]["n_tokens"] - 1) / (time.perf_counter_ns() - stats[request_id]["first_token_time"])
-  if is_finished:
-    stats[request_id]["end_time"] = time.perf_counter_ns()
-  print(stats)
-
-node.on_token.register("collect_tps_stats").on_next(collect_tps_stats)
-if args.prometheus_client_port:
-  from exo.stats.metrics import start_metrics_server
-  start_metrics_server(node, args.prometheus_client_port)
-
 last_broadcast_time = 0
-
-
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
   global last_broadcast_time
   current_time = time.time()
@@ -212,6 +197,59 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
   finally:
     node.on_token.deregister(callback_id)
 
+async def bench(node: Node, inference_engine: InferenceEngine):
+  prompt = "write an essay about war"
+  model_name = "llama-3.2-3b"
+  shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
+  if not shard:
+    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+    return
+  tokenizer = await resolve_tokenizer(shard.model_id)
+  request_id = str(uuid.uuid4())
+  callback_id = f"cli-wait-response-{request_id}"
+  callback = node.on_token.register(callback_id)
+  if topology_viz:
+    topology_viz.update_prompt(request_id, prompt)
+  prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
+
+  stats = { "start_time": time.perf_counter_ns() }
+  def collect_tps_stats(request_id: str, tokens: List[int], is_finished: bool):
+    if request_id not in stats:
+      stats[request_id] = { "first_token_time": time.perf_counter_ns() }
+    stats[request_id]["n_tokens"] = len(tokens)
+    stats[request_id]["tps"] = 1e9 * (stats[request_id]["n_tokens"] - 1) / (time.perf_counter_ns() - stats[request_id]["first_token_time"])
+    if is_finished:
+      stats[request_id]["end_time"] = time.perf_counter_ns()
+    print(stats)
+  node.on_token.register("collect_tps_stats").on_next(collect_tps_stats)
+
+  try:
+    print(f"Processing prompt: {prompt}")
+    await node.process_prompt(shard, prompt, None, request_id=request_id)
+
+    _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
+
+    print("\nGenerated response:")
+    print(tokenizer.decode(tokens))
+
+    # Write benchmark results to file
+    prompt_tokens = tokenizer.encode(prompt)  # prompt already has the chat template applied above
+    benchmark_results = {
+      "total_time_ns": stats[request_id]["end_time"] - stats["start_time"],
+      "first_token_latency_ns": stats[request_id]["first_token_time"] - stats["start_time"],
+      "tokens_generated": stats[request_id]["n_tokens"],
+      "prompt_num_tokens": len(prompt_tokens),
+      "prompt_tokens_per_second": 1e9 * len(prompt_tokens) / (stats[request_id]["first_token_time"] - stats["start_time"]),
+      "generation_tokens_per_second": 1e9 * (stats[request_id]["n_tokens"] - 1) / (stats[request_id]["end_time"] - stats[request_id]["first_token_time"])
+    }
+    with open("benchmark_results.json", "w") as f:
+      json.dump(benchmark_results, f, indent=2)
+  except Exception as e:
+    print(f"Error processing prompt: {str(e)}")
+    traceback.print_exc()
+  finally:
+    node.on_token.deregister(callback_id)
+
 
 async def main():
   loop = asyncio.get_running_loop()
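
The inlined `collect_tps_stats` clocks from the first token, so throughput counts only the remaining `n_tokens - 1` tokens over the time elapsed since that first arrival. A worked example of the same arithmetic:

```python
import time

first_token_time = time.perf_counter_ns()
# Suppose 100 more tokens arrive over the next two seconds:
now = first_token_time + 2_000_000_000  # ns
n_tokens = 101
tps = 1e9 * (n_tokens - 1) / (now - first_token_time)
assert tps == 50.0  # 100 tokens / 2 s
```
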
@@ -225,7 +263,9 @@ async def main():
 
   await node.start(wait_for_peers=args.wait_for_peers)
 
-  if args.command == "run" or args.run_model:
+  if args.command == "bench" or args.bench:
+    await bench(node, inference_engine)
+  elif args.command == "run" or args.run_model:
     model_name = args.model_name or args.run_model
     if not model_name:
       print("Error: Model name is required when using 'run' command or --run-model")
@@ -240,7 +280,17 @@ def run():
   loop = asyncio.new_event_loop()
   asyncio.set_event_loop(loop)
   try:
+    yappi.set_clock_type('wall')
+    yappi.start()
+
     loop.run_until_complete(main())
+
+    yappi.stop()
+    # Print profiling results
+    func_stats = yappi.get_func_stats()
+    func_stats.sort('ttot').print_all()
+    func_stats.save('callgrind.out', type='callgrind')
+
   except KeyboardInterrupt:
     print("Received keyboard interrupt. Shutting down...")
   finally:

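The yappi calls above wrap the entire event-loop run with wall-clock profiling, which, unlike the default CPU clock, also attributes time spent awaiting I/O; that is usually what you want for an asyncio program. The same pattern in isolation, runnable on its own:

```python
import yappi

def busy():
  # Toy workload so the profile has something to show.
  return sum(i * i for i in range(100_000))

yappi.set_clock_type('wall')  # count awaited/blocked time, not just CPU
yappi.start()
busy()
yappi.stop()

stats = yappi.get_func_stats()
stats.sort('ttot').print_all()                 # top functions by total time
stats.save('callgrind.out', type='callgrind')  # inspect in qcachegrind
```
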
+ 3 - 3
exo/orchestration/standard_node.py

@@ -171,7 +171,7 @@ class StandardNode(Node):
     if not is_finished:
       asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, image_str=image_str, inference_state=inference_state))
 
-    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+    return np.asarray(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
 
   async def process_tensor(
     self,
@@ -247,7 +247,7 @@ class StandardNode(Node):
       if not is_finished:
         asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
 
-      return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+      return np.asarray(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
@@ -380,7 +380,7 @@ class StandardNode(Node):
   async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
     if request_id not in self.buffered_token_output:
       return None, False
-    return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
+    return np.array(self.buffered_token_output[request_id][0], copy=False), self.buffered_token_output[request_id][1]
 
   async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
     next_topology = Topology()
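
One caveat on the mixed spellings in this file: `buffered_token_output[request_id][0]` appears to hold a plain Python list of token ids, and converting a list to an array can never be zero-copy. NumPy 1.x treats `copy=False` as "copy only if needed", but NumPy 2.0 raises `ValueError` when a copy is unavoidable, which makes `np.asarray` the safer spelling in `get_inference_result` as well. A quick demonstration:

```python
import numpy as np

tokens = [1, 2, 3]      # stand-in for buffered_token_output[...][0]
a = np.asarray(tokens)  # copies because it must, never raises

# On NumPy >= 2.0 the next line raises ValueError (list -> array is
# never zero-copy); on 1.x it silently copied. Kept commented so this
# sketch runs on either version.
# b = np.array(tokens, copy=False)
```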