implement dynamic inference engine selection

implement the system detection and inference engine selection

remove inconsistency
itsknk · 11 months ago
commit e934664168
2 changed files with 32 additions and 27 deletions
  1. exo/helpers.py  +25 -1
  2. main.py  +7 -26

exo/helpers.py  +25 -1

@@ -3,6 +3,8 @@ import asyncio
 from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
 import socket
 import random
+import platform
+import psutil
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -15,6 +17,28 @@ exo_text = r"""
  \___/_/\_\___/ 
     """
 
+def get_system_info():
+    if psutil.MACOS:
+        if platform.machine() == 'arm64':
+            return "Apple Silicon Mac"
+        elif platform.machine() in ['x86_64', 'i386']:
+            return "Intel Mac"
+        else:
+            return "Unknown Mac architecture"
+    elif psutil.LINUX:
+        return "Linux"
+    else:
+        return "Non-Mac, non-Linux system"
+
+def get_inference_engine():
+    system_info = get_system_info()
+    if system_info == "Apple Silicon Mac":
+        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+        return MLXDynamicShardInferenceEngine()
+    else:
+        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+        return TinygradDynamicShardInferenceEngine()
+
 def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
     used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')
 
@@ -113,4 +137,4 @@ class AsyncCallbackSystem(Generic[K, T]):
 
     def trigger_all(self, *args: T) -> None:
         for callback in self.callbacks.values():
-            callback.set(*args)
+            callback.set(*args)

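For reference, a minimal usage sketch of the two helpers added above (names taken from the diff; the standalone-script framing is illustrative only and assumes the exo package and psutil are installed):

    from exo.helpers import get_system_info, get_inference_engine

    # Report what was detected, mirroring what main.py now prints.
    system_info = get_system_info()  # e.g. "Apple Silicon Mac", "Intel Mac", "Linux"
    print(f"Detected system: {system_info}")

    # Lazily imports MLXDynamicShardInferenceEngine on Apple Silicon Macs,
    # TinygradDynamicShardInferenceEngine on every other platform.
    inference_engine = get_inference_engine()
    print(f"Using inference engine: {inference_engine.__class__.__name__}")
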
main.py  +7 -26

@@ -2,16 +2,13 @@ import argparse
 import asyncio
 import signal
 import uuid
-import platform
-import psutil
-import os
 from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -27,33 +24,17 @@ parser.add_argument("--inference-engine", type=str, default=None, help="Inferenc
 args = parser.parse_args()
 
 print_yellow_exo()
-print(f"Starting exo {platform.system()=} {psutil.virtual_memory()=}")
-if args.inference_engine is None:
-    if psutil.MACOS:
-        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-        inference_engine = MLXDynamicShardInferenceEngine()
-    else:
-        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-        import tinygrad.helpers
-        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-        inference_engine = TinygradDynamicShardInferenceEngine()
-else:
-    if args.inference_engine == "mlx":
-        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-        inference_engine = MLXDynamicShardInferenceEngine()
-    elif args.inference_engine == "tinygrad":
-        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-        import tinygrad.helpers
-        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-        inference_engine = TinygradDynamicShardInferenceEngine()
-    else:
-        raise ValueError(f"Inference engine {args.inference_engine} not supported")
-print(f"Using inference engine {inference_engine.__class__.__name__}")
 
+system_info = get_system_info()
+print(f"Detected system: {system_info}")
+
+inference_engine = get_inference_engine()
+print(f"Using inference engine: {inference_engine.__class__.__name__}")
 
 if args.node_port is None:
     args.node_port = find_available_port(args.node_host)
     if DEBUG >= 1: print(f"Using available port: {args.node_port}")
+
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
 node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}")
 server = GRPCServer(node, args.node_host, args.node_port)
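
The --inference-engine flag is still declared in the argument parser above but is no longer read after this change. A hypothetical sketch of how get_inference_engine could accept an explicit override while keeping the new auto-detection as the default; the engine names "mlx" and "tinygrad" and the error message come from the removed main.py branch, but the parameter itself is an assumption, not part of this commit:

    def get_inference_engine(engine_name=None):
        # Hypothetical override parameter, not part of this commit.
        if engine_name == "mlx":
            from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
            return MLXDynamicShardInferenceEngine()
        if engine_name == "tinygrad":
            from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
            return TinygradDynamicShardInferenceEngine()
        if engine_name is not None:
            raise ValueError(f"Inference engine {engine_name} not supported")
        # No explicit choice: fall back to the system detection added in exo/helpers.py.
        if get_system_info() == "Apple Silicon Mac":
            from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
            return MLXDynamicShardInferenceEngine()
        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
        return TinygradDynamicShardInferenceEngine()

With such a signature, main.py could call get_inference_engine(args.inference_engine) to restore the previous CLI behavior without reintroducing the platform checks it just removed.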