|
@@ -12,7 +12,12 @@ from exo.topology.topology import Topology
|
|
|
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
|
|
|
from exo.helpers import DEBUG
|
|
|
import json
|
|
|
-import mlx.core as mx
|
|
|
+import platform
|
|
|
+
|
|
|
+if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
|
|
|
+ import mlx.core as mx
|
|
|
+else:
|
|
|
+ import numpy as mx
|
|
|
|
|
|
class GRPCPeerHandle(PeerHandle):
|
|
|
def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
|