@@ -2,16 +2,13 @@ import argparse
import asyncio
import signal
import uuid
-import platform
-import psutil
-import os
from typing import List
from exo.orchestration.standard_node import StandardNode
from exo.networking.grpc.grpc_server import GRPCServer
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.api import ChatGPTAPI
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info

# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -27,33 +24,17 @@ parser.add_argument("--inference-engine", type=str, default=None, help="Inferenc
args = parser.parse_args()

print_yellow_exo()
-print(f"Starting exo {platform.system()=} {psutil.virtual_memory()=}")
-if args.inference_engine is None:
- if psutil.MACOS:
- from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
- inference_engine = MLXDynamicShardInferenceEngine()
- else:
- from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
- import tinygrad.helpers
- tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
- inference_engine = TinygradDynamicShardInferenceEngine()
-else:
- if args.inference_engine == "mlx":
- from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
- inference_engine = MLXDynamicShardInferenceEngine()
- elif args.inference_engine == "tinygrad":
- from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
- import tinygrad.helpers
- tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
- inference_engine = TinygradDynamicShardInferenceEngine()
- else:
- raise ValueError(f"Inference engine {args.inference_engine} not supported")
-print(f"Using inference engine {inference_engine.__class__.__name__}")

+system_info = get_system_info()
+print(f"Detected system: {system_info}")
+
+inference_engine = get_inference_engine()
+print(f"Using inference engine: {inference_engine.__class__.__name__}")

if args.node_port is None:
args.node_port = find_available_port(args.node_host)
if DEBUG >= 1: print(f"Using available port: {args.node_port}")
+
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}")
server = GRPCServer(node, args.node_host, args.node_port)
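
Note on the new helpers: this diff imports get_system_info and get_inference_engine from exo.helpers but does not show their definitions. The sketch below is a guess at what they might contain, reconstructed from the inline logic the diff removes (MLX on macOS, tinygrad elsewhere, with TINYGRAD_DEBUG wired through). Only the two names are confirmed by the new import line; the signatures and bodies are assumptions, not the actual exo.helpers implementation.

# Hypothetical sketch of the helpers this diff assumes in exo/helpers.py.
# Only the names are confirmed by the import above; the bodies below are
# reconstructed from the engine-selection logic removed from main.py.
import os
import platform
from typing import Optional

import psutil


def get_system_info() -> str:
    # Assumed shape: a printable one-line host summary, standing in for the
    # old print(f"Starting exo {platform.system()=} {psutil.virtual_memory()=}").
    total_gb = psutil.virtual_memory().total / (1024 ** 3)
    return f"{platform.system()}, {total_gb:.1f} GB RAM"


def get_inference_engine(engine_name: Optional[str] = None):
    # Same selection the removed main.py block performed: honor an explicit
    # engine name, otherwise default to MLX on macOS and tinygrad elsewhere.
    if engine_name is None:
        engine_name = "mlx" if psutil.MACOS else "tinygrad"
    if engine_name == "mlx":
        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
        return MLXDynamicShardInferenceEngine()
    if engine_name == "tinygrad":
        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
        import tinygrad.helpers
        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
        return TinygradDynamicShardInferenceEngine()
    raise ValueError(f"Inference engine {engine_name} not supported")

Under this sketch, note that the call site in the diff invokes get_inference_engine() with no argument, so the --inference-engine flag would take effect only if the caller threads args.inference_engine through.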