1
0
Эх сурвалжийг харах

Merge pull request #48 from itsknk/intel-mac

Implement dynamic inference engine selection #45
Alex Cheema 1 жил өмнө
parent
commit
2e419ba211
2 өөрчлөгдсөн 32 нэмэгдсэн , 27 устгасан
  1. 25 1
      exo/helpers.py
  2. 7 26
      main.py

+ 25 - 1
exo/helpers.py

@@ -3,6 +3,8 @@ import asyncio
 from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
 import socket
 import random
+import platform
+import psutil
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -15,6 +17,28 @@ exo_text = r"""
  \___/_/\_\___/ 
     """
 
+def get_system_info():
+    if psutil.MACOS:
+        if platform.machine() == 'arm64':
+            return "Apple Silicon Mac"
+        elif platform.machine() in ['x86_64', 'i386']:
+            return "Intel Mac"
+        else:
+            return "Unknown Mac architecture"
+    elif psutil.LINUX:
+        return "Linux"
+    else:
+        return "Non-Mac, non-Linux system"
+
+def get_inference_engine():
+    system_info = get_system_info()
+    if system_info == "Apple Silicon Mac":
+        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+        return MLXDynamicShardInferenceEngine()
+    else:
+        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+        return TinygradDynamicShardInferenceEngine()
+
 def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
     used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')
 
@@ -113,4 +137,4 @@ class AsyncCallbackSystem(Generic[K, T]):
 
     def trigger_all(self, *args: T) -> None:
         for callback in self.callbacks.values():
-            callback.set(*args)
+            callback.set(*args)

+ 7 - 26
main.py

@@ -2,16 +2,13 @@ import argparse
 import asyncio
 import signal
 import uuid
-import platform
-import psutil
-import os
 from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -28,33 +25,17 @@ parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help
 args = parser.parse_args()
 
 print_yellow_exo()
-print(f"Starting exo {platform.system()=} {psutil.virtual_memory()=}")
-if args.inference_engine is None:
-    if psutil.MACOS:
-        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-        inference_engine = MLXDynamicShardInferenceEngine()
-    else:
-        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-        import tinygrad.helpers
-        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-        inference_engine = TinygradDynamicShardInferenceEngine()
-else:
-    if args.inference_engine == "mlx":
-        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-        inference_engine = MLXDynamicShardInferenceEngine()
-    elif args.inference_engine == "tinygrad":
-        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-        import tinygrad.helpers
-        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-        inference_engine = TinygradDynamicShardInferenceEngine()
-    else:
-        raise ValueError(f"Inference engine {args.inference_engine} not supported")
-print(f"Using inference engine {inference_engine.__class__.__name__}")
 
+system_info = get_system_info()
+print(f"Detected system: {system_info}")
+
+inference_engine = get_inference_engine()
+print(f"Using inference engine: {inference_engine.__class__.__name__}")
 
 if args.node_port is None:
     args.node_port = find_available_port(args.node_host)
     if DEBUG >= 1: print(f"Using available port: {args.node_port}")
+
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
 node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}", disable_tui=args.disable_tui)
 server = GRPCServer(node, args.node_host, args.node_port)