|
@@ -30,14 +30,15 @@ def get_system_info():
|
|
|
else:
|
|
|
return "Non-Mac, non-Linux system"
|
|
|
|
|
|
-def get_inference_engine():
|
|
|
- system_info = get_system_info()
|
|
|
- if system_info == "Apple Silicon Mac":
|
|
|
+def get_inference_engine(inference_engine_name):
|
|
|
+ if inference_engine_name == "mlx":
|
|
|
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
|
|
return MLXDynamicShardInferenceEngine()
|
|
|
- else:
|
|
|
+ elif inference_engine_name == "tinygrad":
|
|
|
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
|
|
|
return TinygradDynamicShardInferenceEngine()
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Inference engine {inference_engine_name} not supported")
|
|
|
|
|
|
def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
|
|
|
used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')
|