
fix to broadcast

josh committed 8 months ago
parent · commit 593d810db8
3 changed files with 9 additions and 8 deletions
  1. exo/main.py (+1 -2)
  2. exo/networking/udp/udp_discovery.py (+1 -6)
  3. exo/orchestration/standard_node.py (+7 -0)

+1 -2  exo/main.py

@@ -210,13 +210,12 @@ async def main():
     loop.add_signal_handler(s, handle_exit)
 
   await node.start(wait_for_peers=args.wait_for_peers)
-
+  await select_best_inference_engine(node)
   if args.command == "run" or args.run_model:
     model_name = args.model_name or args.run_model
     if not model_name:
       print("Error: Model name is required when using 'run' command or --run-model")
       return
-    await select_best_inference_engine(node)
     await run_model_cli(node, inference_engine, model_name, args.prompt)
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
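Hoisting `select_best_inference_engine(node)` out of the `run` branch means engine negotiation now happens right after peers connect in both modes, CLI runs and the ChatGPT-compatible API server, instead of only when a model is run from the command line. A minimal sketch of what that negotiation could look like; the helper below is hypothetical and only illustrates picking an engine supported by the whole topology:

```python
# Hypothetical sketch of engine negotiation; the real
# select_best_inference_engine in exo may differ.
from typing import List, Optional

def pick_common_engine(pools: List[List[str]], preference: List[str]) -> Optional[str]:
  """Return the most-preferred engine supported by every node, if any.

  `pools` is a list of per-node engine lists, e.g. the values gathered in
  topology_inference_engines_pool; `preference` is an assumed priority order.
  """
  if not pools:
    return None  # no peers have broadcast yet; keep the local default
  common = set(pools[0]).intersection(*map(set, pools[1:]))
  for engine in preference:
    if engine in common:
      return engine
  return None

# Example: pick_common_engine([["mlx", "tinygrad"], ["tinygrad"]], ["mlx", "tinygrad"])
# -> "tinygrad"
```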

+1 -6  exo/networking/udp/udp_discovery.py

@@ -159,12 +159,7 @@ class UDPDiscovery(Discovery):
           if peer_id in self.known_peers: del self.known_peers[peer_id]
           return
         if peer_id in self.known_peers: self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)
-    elif message["type"] == "supported_inference_engines":
-      logger.error(f'supported_inference_engines: {message}')
-      peer_id = message["node_id"]
-      engines = message["engines"]
-      if peer_id in self.known_peers: self.known_peers[peer_id][0].topology_inference_engines_pool.append(engines)
-
+  
   async def task_listen_for_peers(self):
     await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
                                                             local_addr=("0.0.0.0", self.listen_port))
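The deleted `elif` was the receive side: peers advertised their engines inside UDP discovery messages, and the handler appended them to the sender's `topology_inference_engines_pool`. After this commit the advertisement travels over the opaque-status broadcast instead (see the `standard_node.py` hunks below), so a receiver would handle it roughly like this; the handler name and how it is wired up are assumptions, not exo's actual API:

```python
import json

# Hedged sketch of the receive side over the opaque-status channel;
# the handler name and its registration are assumptions.
def on_opaque_status(node, request_id: str, status: str) -> None:
  payload = json.loads(status)
  if payload.get("type") == "supported_inference_engines":
    # Pool the advertised engines so the cluster can later pick a common one.
    node.topology_inference_engines_pool.append(payload["engines"])
```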

+7 -0  exo/orchestration/standard_node.py

@@ -15,7 +15,9 @@ from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
 from exo.download.hf.hf_helpers import RepoProgressEvent
+import logging
 
+logger = logging.getLogger(__name__)
 
 class StandardNode(Node):
   def __init__(
@@ -91,7 +93,9 @@ class StandardNode(Node):
       "node_id": self.id,
       "engines": supported_engines
     })
+    logger.error(f'broadcast_supported_engines: {status_message}')
     await self.broadcast_opaque_status("", status_message)
+    logger.error(f'broadcast_supported_engines: done')
 
   def get_topology_inference_engines(self) -> List[str]:
     return self.topology_inference_engines_pool
@@ -438,6 +442,9 @@ class StandardNode(Node):
 
     async def send_status_to_peer(peer):
       try:
+        status_dict = json.loads(status)
+        if status_dict.get("type") == "supported_inference_engines":
+          logger.error(f'broadcasting_inference_engines: {status_dict}')
         await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
       except asyncio.TimeoutError:
         print(f"Timeout sending opaque status to {peer.id()}")