
fix to creating engines

josh, 8 months ago
parent commit cd4d324a5e
3 files changed with 33 additions and 28 deletions
  1. exo/main.py (+2 -1)
  2. exo/networking/udp/udp_discovery.py (+1 -5)
  3. exo/orchestration/standard_node.py (+30 -22)

+ 2 - 1
exo/main.py

@@ -95,7 +95,8 @@ node = StandardNode(
   discovery,
   partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
   max_generate_tokens=args.max_generate_tokens,
-  topology_viz=topology_viz
+  topology_viz=topology_viz,
+  shard_downloader=shard_downloader
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server

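The new shard_downloader argument exists so the node can rebuild its inference engine later: select_best_inference_engine (changed below in standard_node.py) calls get_inference_engine(name, self.shard_downloader) once the peers have agreed on an engine. A minimal sketch of that wiring, assuming HFShardDownloader can be constructed with default arguments (illustrative only; the real exo/main.py may pass options):

  # Sketch only -- mirrors the wiring implied by this commit, not the exact main.py code.
  from exo.download.hf.hf_shard_download import HFShardDownloader
  from exo.inference.inference_engine import get_inference_engine

  shard_downloader = HFShardDownloader()  # assumed default construction
  # The same downloader object is passed to StandardNode(..., shard_downloader=shard_downloader)
  # so that the node can later rebuild its engine once peers agree on one:
  engine = get_inference_engine("tinygrad", shard_downloader)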
+ 1 - 5
exo/networking/udp/udp_discovery.py

@@ -159,11 +159,7 @@ class UDPDiscovery(Discovery):
          if peer_id in self.known_peers: del self.known_peers[peer_id]
          return
        if peer_id in self.known_peers: self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)
-    if message["type"] == "supported_inference_engines":
-      logger.error(f'supported_inference_engines: {message}')
-      peer_id = message["node_id"]
-      engines = message["engines"]
-      if peer_id in self.known_peers: self.known_peers[peer_id][0].topology_inference_engines_pool.append(engines)
+
   async def task_listen_for_peers(self):
     await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
                                                             local_addr=("0.0.0.0", self.listen_port))
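With this handling removed, the "supported_inference_engines" announcement no longer rides on UDP discovery packets; it now travels as an opaque status message and is consumed in StandardNode.on_node_status (see standard_node.py below). A small sketch of the message shape, using only the fields that appear in this commit (the concrete values are examples):

  import json

  # Sender side (broadcast_supported_engines builds essentially this payload):
  status_message = json.dumps({
    "type": "supported_inference_engines",
    "node_id": "example-node-id",      # example value
    "engines": ["mlx", "tinygrad"]     # example value
  })

  # Receiver side (on_node_status): parse and collect the peer's engine list.
  status_data = json.loads(status_message)
  if status_data.get("type", "") == "supported_inference_engines":
    engines = status_data.get("engines", [])
    # each peer's list is appended to topology_inference_engines_pool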

+ 30 - 22
exo/orchestration/standard_node.py

@@ -4,7 +4,7 @@ import asyncio
 import uuid
 import time
 import traceback
-from typing import List, Dict, Optional, Tuple, Union
+from typing import List, Dict, Optional, Tuple, Union, Set
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
 from .node import Node
@@ -15,6 +15,8 @@ from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
 from exo.download.hf.hf_helpers import RepoProgressEvent
+from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
 import logging

 logger = logging.getLogger(__name__)
@@ -29,6 +31,7 @@ class StandardNode(Node):
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
     topology_viz: Optional[TopologyViz] = None,
+    shard_downloader: Optional[HFShardDownloader] = None,
   ):
     self.id = _id
     self.inference_engine = inference_engine
@@ -45,7 +48,8 @@ class StandardNode(Node):
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
-    self.topology_inference_engines_pool: List[str] = []
+    self.topology_inference_engines_pool: List[List[str]] = []
+    self.shard_downloader = shard_downloader

   async def start(self, wait_for_peers: int = 0) -> None:
     await self.server.start()
@@ -62,6 +66,14 @@ class StandardNode(Node):
   def on_node_status(self, request_id, opaque_status):
     try:
       status_data = json.loads(opaque_status)
+      if status_data.get("type", "") == "supported_inference_engines":
+        logger.error(f'supported_inference_engines: {status_data}')
+        logger.error('inside on_status data')
+        node_id = status_data.get("node_id")
+        engines = status_data.get("engines", [])
+        logger.error(f'engines: {engines}')
+        self.topology_inference_engines_pool.append(engines)
+        logger.error(f'topology_inference_engines_pool: {self.topology_inference_engines_pool}')
       if status_data.get("type", "") == "node_status":
         if status_data.get("status", "").startswith("start_"):
           self.current_topology.active_node_id = status_data.get("node_id")
@@ -82,22 +94,24 @@ class StandardNode(Node):
   def get_supported_inference_engines(self):
     supported_engine_names = []
     if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
-      supported_engine_names.extend(['mlx', 'tinygrad'])
+        supported_engine_names.append('mlx')
+        supported_engine_names.append('tinygrad')
     else:
-      supported_engine_names.append('tinygrad')
+        supported_engine_names.append('tinygrad')
     return supported_engine_names

-  async def broadcast_supported_engines(self, supported_engines: List):
+  async def broadcast_supported_engines(self, supported_engines_names: List[str]):
     status_message = json.dumps({
-      "type": "supported_inference_engines",
-      "node_id": self.id,
-      "engines": supported_engines
+        "type": "supported_inference_engines",
+        "node_id": self.id,
+        "engines": supported_engines_names
     })
     logger.error(f'broadcast_supported_engines: {status_message}')
     await self.broadcast_opaque_status("", status_message)
     logger.error(f'broadcast_supported_engines: done')

-  def get_topology_inference_engines(self) -> List[str]:
+  def get_topology_inference_engines(self) -> List[List[str]]:
+    logger.error(f'topology_inference_engines_pool: {self.topology_inference_engines_pool}')
     return self.topology_inference_engines_pool

   async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
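The pool is now typed List[List[str]]: each entry is one peer's complete engine list rather than a flat list of engine names, and the selection logic in the next hunk iterates over those per-node lists. An illustration of the expected shape (example data only):

  # One inner list per node that has broadcast its supported engines.
  topology_inference_engines_pool = [
    ["mlx", "tinygrad"],  # e.g. an Apple Silicon node
    ["tinygrad"],         # e.g. a Linux/GPU node
  ]

  # Engines supported by every node can be found with a set intersection:
  common = set.intersection(*map(set, topology_inference_engines_pool))
  print(common)  # {'tinygrad'}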
@@ -365,20 +379,15 @@ class StandardNode(Node):
   async def select_best_inference_engine(self):
     supported_engines = self.get_supported_inference_engines()
     await self.broadcast_supported_engines(supported_engines)
-    logger.error('ABOVE and ALL')
-    logger.error("Topology inference engines pool: %s", self.get_topology_inference_engines())
+    logger.error("Topology inference engines pool: %s", len(self.get_topology_inference_engines()))
     logger.error(f'result:{self.get_topology_inference_engines()}')
-    if self.get_topology_inference_engines():
-      logger.info("Topology inference engines pool: %s", self.get_topology_inference_engines())
-      topology_inference_engines_pool = self.get_topology_inference_engines()
-      if any("tinygrad" in engines and len(engines) == 1 for engines in topology_inference_engines_pool):
-          return "tinygrad"
-      common_engine_across_peers = set.intersection(*topology_inference_engines_pool)
-      if "mlx" in common_engine_across_peers:
-          return "mlx"
+    if len(self.get_topology_inference_engines()):
+      if any(len(engines) == 1 and "tinygrad" in engines for engines in self.get_topology_inference_engines()):
+        logger.info("Found node with only tinygrad, using tinygrad on all nodes")
+        self.inference_engine = get_inference_engine("tinygrad", self.shard_downloader)
       else:
-          raise ValueError("No compatible inference engine found across all nodes")
-
+        logger.info("all nodes can use mlx, using mlx for inference")
+        self.inference_engine = get_inference_engine("mlx", self.shard_downloader) 

   async def periodic_topology_collection(self, interval: int):
     while True:
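Taken together, the new flow is: each node broadcasts its supported engines, the lists accumulate in topology_inference_engines_pool, and if any node supports only tinygrad the whole cluster switches to tinygrad; otherwise mlx is used. A condensed, self-contained sketch of that decision rule (get_inference_engine is the factory imported above; pick_engine is a hypothetical helper used only for illustration):

  from typing import List

  def pick_engine(pool: List[List[str]]) -> str:
    # Hypothetical helper mirroring select_best_inference_engine's rule.
    # If any peer can only run tinygrad, everyone must fall back to tinygrad.
    if any(len(engines) == 1 and "tinygrad" in engines for engines in pool):
      return "tinygrad"
    # Otherwise this commit assumes every peer can run mlx.
    return "mlx"

  assert pick_engine([["mlx", "tinygrad"], ["tinygrad"]]) == "tinygrad"
  assert pick_engine([["mlx", "tinygrad"], ["mlx", "tinygrad"]]) == "mlx"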
@@ -462,7 +471,6 @@ class StandardNode(Node):

     async def send_status_to_peer(peer):
       try:
-        status_dict = json.loads(status)
         await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
       except asyncio.TimeoutError:
         print(f"Timeout sending opaque status to {peer.id()}")