@@ -4,6 +4,7 @@ import signal
 import uuid
 import platform
 import psutil
+import os
 from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
@@ -21,16 +22,32 @@ parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
+parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 args = parser.parse_args()
 
 print_yellow_exo()
 print(f"Starting exo {platform.system()=} {psutil.virtual_memory()=}")
-if psutil.MACOS:
-    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-    inference_engine = MLXDynamicShardInferenceEngine()
+if args.inference_engine is None:
+    if psutil.MACOS:
+        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+        inference_engine = MLXDynamicShardInferenceEngine()
+    else:
+        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+        import tinygrad.helpers
+        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
+        inference_engine = TinygradDynamicShardInferenceEngine()
 else:
-    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-    inference_engine = TinygradDynamicShardInferenceEngine()
+    if args.inference_engine == "mlx":
+        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+        inference_engine = MLXDynamicShardInferenceEngine()
+    elif args.inference_engine == "tinygrad":
+        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+        import tinygrad.helpers
+        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
+        inference_engine = TinygradDynamicShardInferenceEngine()
+    else:
+        raise ValueError(f"Inference engine {args.inference_engine} not supported")
+print(f"Using inference engine {inference_engine.__class__.__name__}")
 
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
 node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy())
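
For reference, a usage sketch of the new flag. This assumes the entrypoint is main.py and that the only accepted engine names are the two strings matched above; "onnx" is just an illustrative unsupported value.

    # No flag: the engine is auto-selected (MLX on macOS, tinygrad elsewhere)
    python main.py

    # Force the tinygrad engine; TINYGRAD_DEBUG is read by the selection code above
    TINYGRAD_DEBUG=2 python main.py --inference-engine tinygrad

    # Any other value falls through to the raise and aborts with ValueError
    python main.py --inference-engine onnx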