Browse Source

allow overriding inference_engine and separate flag for TINYGRAD_DEBUG

Alex Cheema 9 months ago
parent
commit
945f90f676
1 changed file with 22 additions and 5 deletions
  1. 22 5
      main.py

+ 22 - 5
main.py

@@ -4,6 +4,7 @@ import signal
 import uuid
 import platform
 import psutil
+import os
 from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
@@ -21,16 +22,32 @@ parser.add_argument("--listen-port", type=int, default=5678, help="Listening por
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
+parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
 args = parser.parse_args()
 
 print_yellow_exo()
 print(f"Starting exo {platform.system()=} {psutil.virtual_memory()=}")
-if psutil.MACOS:
-    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-    inference_engine = MLXDynamicShardInferenceEngine()
+if args.inference_engine is None:
+    if psutil.MACOS:
+        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+        inference_engine = MLXDynamicShardInferenceEngine()
+    else:
+        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+        import tinygrad.helpers
+        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
+        inference_engine = TinygradDynamicShardInferenceEngine()
 else:
-    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-    inference_engine = TinygradDynamicShardInferenceEngine()
+    if args.inference_engine == "mlx":
+        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+        inference_engine = MLXDynamicShardInferenceEngine()
+    elif args.inference_engine == "tinygrad":
+        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+        import tinygrad.helpers
+        tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
+        inference_engine = TinygradDynamicShardInferenceEngine()
+    else:
+        raise ValueError(f"Inference engine {args.inference_engine} not supported")
+print(f"Using inference engine {inference_engine.__class__.__name__}")
 
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
 node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy())