Bladeren bron

Fixed MLX import blocking native Windows execution of exo. (Not Final)

Sandesh Bharadwaj 3 maanden geleden
bovenliggende
commit
6737e36e23
3 gewijzigde bestanden met toevoegingen van 20 en 3 verwijderingen
  1. 7 1
      exo/api/chatgpt_api.py
  2. 6 1
      exo/networking/grpc/grpc_peer_handle.py
  3. 7 1
      exo/networking/grpc/grpc_server.py

+ 7 - 1
exo/api/chatgpt_api.py

@@ -21,7 +21,13 @@ from PIL import Image
 import numpy as np
 import base64
 from io import BytesIO
-import mlx.core as mx
+import platform
+
+if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
+  import mlx.core as mx
+else:
+  import numpy as mx
+
 import tempfile
 from exo.download.hf.hf_shard_download import HFShardDownloader
 import shutil

+ 6 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -12,7 +12,12 @@ from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.helpers import DEBUG
 import json
-import mlx.core as mx
+import platform
+
+if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
+  import mlx.core as mx
+else:
+  import numpy as mx
 
 class GRPCPeerHandle(PeerHandle):
   def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):

+ 7 - 1
exo/networking/grpc/grpc_server.py

@@ -3,13 +3,19 @@ from concurrent import futures
 import numpy as np
 from asyncio import CancelledError
 
+import platform
+
 from . import node_service_pb2
 from . import node_service_pb2_grpc
 from exo import DEBUG
 from exo.inference.shard import Shard
 from exo.orchestration import Node
 import json
-import mlx.core as mx
+
+if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
+  import mlx.core as mx
+else:
+  import numpy as mx
 
 
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):