
added error

josh 6 months ago
parent
commit
a03f3a2a5c
3 changed files with 39 additions and 41 deletions
  1. +22 -33 exo/main.py
  2. +4 -0 exo/networking/udp/udp_discovery.py
  3. +13 -8 exo/orchestration/standard_node.py

+ 22 - 33
exo/main.py

@@ -2,6 +2,8 @@ import argparse
 import asyncio
 import signal
 import json
+import logging
+import seqlog
 import time
 import traceback
 import uuid
@@ -51,6 +53,13 @@ args = parser.parse_args()
 
 print_yellow_exo()
 
+seqlog.log_to_seq(
+  server_url = 'http://localhost:5341',
+  level = logging.DEBUG
+)
+logger = logging.getLogger('exo_logger')
+logger.setLevel(logging.DEBUG)
+
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
@@ -145,37 +154,18 @@ async def shutdown(signal, loop):
   loop.stop()
 
 async def select_best_inference_engine(node: StandardNode):
-  supported_engines = set(node.get_supported_inference_engines())
+  supported_engines = node.get_supported_inference_engines()
   await node.broadcast_supported_engines(supported_engines)
-  num_peers = len(node.peers)
-  all_peers_responded = asyncio.Event()
-  def check_all_responses():
-      if len(node.received_opaque_statuses) >= num_peers:
-          all_peers_responded.set()
-  node.on_opaque_status.register("engine_selection").on_next(lambda *args: check_all_responses())
-  try:
-      await asyncio.wait_for(all_peers_responded.wait(), timeout=10.0)
-  except asyncio.TimeoutError:
-      print("Timed out waiting for peer nodes to respond.")
-  node.on_opaque_status.unregister("engine_selection")
-  all_supported_engines = [supported_engines]
-  for peer_id, status in node.received_opaque_statuses:
-      try:
-          status_data = json.loads(status)
-          if status_data.get("type") == "supported_inference_engines":
-              all_supported_engines.append(set(status_data.get("engines", [])))
-      except json.JSONDecodeError:
-          continue
-  if any("tinygrad" in engines and len(engines) == 1 for engines in all_supported_engines):
-      return "tinygrad"
-  common_engine_across_peers = set.intersection(*all_supported_engines)
-  if "mlx" in common_engine_across_peers:
-      print('mlx')
-      return "mlx"
-  elif "tinygrad" in common_engine_across_peers:
-      return "tinygrad"
-  else:
-      raise ValueError("No compatible inference engine found across all nodes")
+
+  if node.get_topology_inference_engines():
+    topology_inference_engines_pool = node.get_topology_inference_engines()
+    if any("tinygrad" in engines and len(engines) == 1 for engines in topology_inference_engines_pool):
+        return "tinygrad"
+    common_engine_across_peers = set.intersection(*topology_inference_engines_pool)
+    if "mlx" in common_engine_across_peers:
+        return "mlx"
+    else:
+        raise ValueError("No compatible inference engine found across all nodes")
 
 
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
@@ -217,9 +207,8 @@ async def main():
     loop.add_signal_handler(s, handle_exit)
 
   await node.start(wait_for_peers=args.wait_for_peers)
-  if len(node.peers) > 1:
-    compatible_engine = await select_best_inference_engine(node)
-    node.inference_engine = get_inference_engine(compatible_engine, shard_downloader)
+  await select_best_inference_engine(node)
+  
 
   if args.command == "run" or args.run_model:
     model_name = args.model_name or args.run_model
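
Note on this hunk: the rewritten select_best_inference_engine only reaches a return statement when get_topology_inference_engines() is non-empty, and it drops the old fallback that returned "tinygrad" when it was merely in the common set, so the function can fall through and return None; main() now also discards the return value instead of swapping node.inference_engine. A minimal sketch of the previous selection rules, assuming the pool is a list of per-peer engine-name collections:

  # Hedged sketch, not the committed code: engine selection with the
  # "tinygrad"-in-common fallback that the removed version had.
  def pick_engine(local_engines, topology_pool):
    all_engines = [set(local_engines)] + [set(e) for e in topology_pool]
    # If any node supports only tinygrad, every node must use tinygrad.
    if any(engines == {"tinygrad"} for engines in all_engines):
      return "tinygrad"
    common = set.intersection(*all_engines)
    if "mlx" in common:
      return "mlx"
    if "tinygrad" in common:
      return "tinygrad"
    raise ValueError("No compatible inference engine found across all nodes")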

+ 4 - 0
exo/networking/udp/udp_discovery.py

@@ -159,6 +159,10 @@ class UDPDiscovery(Discovery):
           if peer_id in self.known_peers: del self.known_peers[peer_id]
           return
         if peer_id in self.known_peers: self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)
+    elif message["type"] == "supported_inference_engines":
+      peer_id = message["node_id"]
+      engines = message["engines"]
+      if peer_id in self.known_peers: self.known_peers[peer_id][0].topology_inference_engines_pool.append(engines)
 
   async def task_listen_for_peers(self):
     await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
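
For reference, the new elif branch consumes the opaque status broadcast by StandardNode.broadcast_supported_engines (see the next file); a sketch of the message shape, with an illustrative node id:

  # Hedged sketch: field names come from broadcast_supported_engines
  # below; the id and engine list values are made up.
  message = {
    "type": "supported_inference_engines",
    "node_id": "node-a",             # illustrative peer id
    "engines": ["mlx", "tinygrad"],  # one peer's engine names
  }
  # Since "engines" is itself a list, topology_inference_engines_pool
  # accumulates one list of names per peer, not flat strings.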

+ 13 - 8
exo/orchestration/standard_node.py

@@ -43,6 +43,7 @@ class StandardNode(Node):
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
+    self.topology_inference_engines_pool: List[str] = []
 
   async def start(self, wait_for_peers: int = 0) -> None:
     await self.server.start()
@@ -77,19 +78,23 @@ class StandardNode(Node):
       if DEBUG >= 1: traceback.print_exc()
 
   def get_supported_inference_engines(self):
-    supported_engines = []
-    if self.inferenceEngine == 'mlx':
-      supported_engines.extend('mlx', 'tinygrad')
+    supported_engine_names = []
+    if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
+      supported_engine_names.extend(['mlx', 'tinygrad'])
     else:
-      supported_engines.append('tinygrad')
-    return supported_engines
+      supported_engine_names.append('tinygrad')
+    return supported_engine_names
 
   async def broadcast_supported_engines(self, supported_engines: List):
-    await self.broadcast_opaque_status("", json.dumps({
+    status_message = json.dumps({
       "type": "supported_inference_engines",
-      "node_id": self.id, 
+      "node_id": self.id,
       "engines": supported_engines
-    }))
+    })
+    await self.broadcast_opaque_status("", status_message)
+
+  def get_topology_inference_engines(self) -> List[str]:
+    return self.topology_inference_engines_pool
 
   async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
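
The get_supported_inference_engines hunk fixes two bugs in the removed lines: the old code read a non-existent self.inferenceEngine attribute, and list.extend takes a single iterable, so extend('mlx', 'tinygrad') raises a TypeError. A standalone sketch of that failure and the corrected call, using only the standard library:

  # Hedged sketch reproducing the old bug and the committed fix.
  engines = []
  try:
    engines.extend('mlx', 'tinygrad')    # old call: TypeError
  except TypeError as e:
    print(e)                             # extend() takes exactly one argument (2 given)
  engines.extend(['mlx', 'tinygrad'])    # corrected call from this commit
  assert engines == ['mlx', 'tinygrad']

One remaining inconsistency worth noting: get_topology_inference_engines is annotated -> List[str], but given the udp_discovery.py change appends each peer's engine list, the pool actually holds per-peer lists, so List[List[str]] would describe it more accurately.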