
implement dynamic inference engine selection

implement the system detection and inference engine selection

remove inconsistency
itsknk, 1 year ago
parent
commit e934664168
2 changed files with 32 additions and 27 deletions
  1. exo/helpers.py (+25 −1)
  2. main.py (+7 −26)

exo/helpers.py (+25 −1)

@@ -3,6 +3,8 @@ import asyncio
 from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
 import socket
 import random
+import platform
+import psutil
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -15,6 +17,28 @@ exo_text = r"""
  \___/_/\_\___/ 
     """
 
+def get_system_info():
+    if psutil.MACOS:
+        if platform.machine() == 'arm64':
+            return "Apple Silicon Mac"
+        elif platform.machine() in ['x86_64', 'i386']:
+            return "Intel Mac"
+        else:
+            return "Unknown Mac architecture"
+    elif psutil.LINUX:
+        return "Linux"
+    else:
+        return "Non-Mac, non-Linux system"
+
+def get_inference_engine():
+    system_info = get_system_info()
+    if system_info == "Apple Silicon Mac":
+        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+        return MLXDynamicShardInferenceEngine()
+    else:
+        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+        return TinygradDynamicShardInferenceEngine()
+
 def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
     used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')
 
@@ -113,4 +137,4 @@ class AsyncCallbackSystem(Generic[K, T]):
 
     def trigger_all(self, *args: T) -> None:
         for callback in self.callbacks.values():
-            callback.set(*args)
+            callback.set(*args)

main.py (+7 −26)

@@ -2,16 +2,13 @@ import argparse
 import asyncio
 import signal
 import uuid
-import platform
-import psutil
-import os
 from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -27,33 +24,17 @@ parser.add_argument("--inference-engine", type=str, default=None, help="Inferenc
 args = parser.parse_args()
 
 print_yellow_exo()
-print(f"Starting exo {platform.system()=} {psutil.virtual_memory()=}")
-if args.inference_engine is None:
-    if psutil.MACOS:
-        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-        inference_engine = MLXDynamicShardInferenceEngine()
-    else:
-        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-        import tinygrad.helpers
-        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-        inference_engine = TinygradDynamicShardInferenceEngine()
-else:
-    if args.inference_engine == "mlx":
-        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-        inference_engine = MLXDynamicShardInferenceEngine()
-    elif args.inference_engine == "tinygrad":
-        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-        import tinygrad.helpers
-        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-        inference_engine = TinygradDynamicShardInferenceEngine()
-    else:
-        raise ValueError(f"Inference engine {args.inference_engine} not supported")
-print(f"Using inference engine {inference_engine.__class__.__name__}")
 
+system_info = get_system_info()
+print(f"Detected system: {system_info}")
+
+inference_engine = get_inference_engine()
+print(f"Using inference engine: {inference_engine.__class__.__name__}")
 
 if args.node_port is None:
     args.node_port = find_available_port(args.node_host)
     if DEBUG >= 1: print(f"Using available port: {args.node_port}")
+
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
 node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}")
 server = GRPCServer(node, args.node_host, args.node_port)