@@ -1,18 +1,15 @@
 import argparse
 import asyncio
 import signal
-import mlx.core as mx
-import mlx.nn as nn
 import uuid
+import platform
 from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
-from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
 
-
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser.add_argument("--node-id", type=str, default=str(uuid.uuid4()), help="Node ID")
@@ -24,8 +21,14 @@ parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of pee
 parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 args = parser.parse_args()
 
+print(f"Starting {platform.system()=}")
+if platform.system() == "Darwin":
+  from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+  inference_engine = MLXDynamicShardInferenceEngine()
+else:
+  from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+  inference_engine = TinygradDynamicShardInferenceEngine()
 
-inference_engine = MLXDynamicShardInferenceEngine()
 def on_token(tokens: List[int]):
   if inference_engine.tokenizer:
     print(inference_engine.tokenizer.decode(tokens))