1
0
Эх сурвалжийг харах

inference engine selection improvements

JakobDylanC 9 сар өмнө
parent
commit
f2f61ccee6
2 өөрчлөгдсөн, 7 нэмэгдсэн, 5 устгасан
  1. 5 4
      exo/helpers.py
  2. 2 1
      main.py

+ 5 - 4
exo/helpers.py

@@ -30,14 +30,15 @@ def get_system_info():
     else:
         return "Non-Mac, non-Linux system"
 
-def get_inference_engine():
-    system_info = get_system_info()
-    if system_info == "Apple Silicon Mac":
+def get_inference_engine(inference_engine_name):
+    if inference_engine_name == "mlx":
         from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
         return MLXDynamicShardInferenceEngine()
-    else:
+    elif inference_engine_name == "tinygrad":
         from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
         return TinygradDynamicShardInferenceEngine()
+    else:
+        raise ValueError(f"Inference engine {inference_engine_name} not supported")
 
 def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
     used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')

+ 2 - 1
main.py

@@ -30,7 +30,8 @@ print_yellow_exo()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
-inference_engine = get_inference_engine()
+inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
+inference_engine = get_inference_engine(inference_engine_name)
 print(f"Using inference engine: {inference_engine.__class__.__name__}")
 
 if args.node_port is None: