main.py

import argparse
import asyncio
import signal
import mlx.core as mx
import mlx.nn as nn
import uuid
from typing import List
from exo.orchestration.standard_node import StandardNode
from exo.networking.grpc.grpc_server import GRPCServer
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.inference.shard import Shard
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy

# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
parser.add_argument("--node-id", type=str, default=str(uuid.uuid4()), help="Node ID")
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
parser.add_argument("--node-port", type=int, default=8080, help="Node port")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait for before starting")
args = parser.parse_args()
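
# Example invocation (illustrative values; assumes the exo package is importable):
#   python main.py --node-id node-a --node-port 8080 --wait-for-peers 1
# With the default listen/broadcast port (5678), nodes started this way on the same
# network should be able to discover each other.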

inference_engine = MLXDynamicShardInferenceEngine()


def on_token(tokens: List[int]):
    # Decode and print tokens as they stream in, once the engine has a tokenizer.
    if inference_engine.tokenizer:
        print(inference_engine.tokenizer.decode(tokens))
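
# Wiring: discovery looks for peers on the listen/broadcast ports, the node
# orchestrates sharded inference through the MLX engine with a memory-weighted
# ring partitioning strategy, and the gRPC server exposes this node to its peers
# on --node-host/--node-port; node.server hands the node a reference back to its
# own server.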
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), on_token=on_token)
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server


async def shutdown(signal, loop):
    """Gracefully shut down the server and close the asyncio loop."""
    print(f"Received exit signal {signal.name}...")
    server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    for task in server_tasks:
        task.cancel()
    print(f"Cancelling {len(server_tasks)} outstanding tasks")
    await asyncio.gather(*server_tasks, return_exceptions=True)
    await server.stop()
    loop.stop()


async def main():
    loop = asyncio.get_running_loop()

    # Install signal handlers so SIGINT/SIGTERM trigger a graceful shutdown.
    def handle_exit():
        asyncio.ensure_future(shutdown(signal.SIGTERM, loop))

    for s in [signal.SIGINT, signal.SIGTERM]:
        loop.add_signal_handler(s, handle_exit)

    await node.start(wait_for_peers=args.wait_for_peers)

    # Keep the process alive until a signal stops the loop.
    await asyncio.Event().wait()
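

# Entrypoint: run main() on a fresh event loop; whether the process exits via a
# signal or a KeyboardInterrupt, the finally block runs shutdown() so outstanding
# tasks are cancelled and the server is stopped before the loop closes.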
if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    except KeyboardInterrupt:
        print("Received keyboard interrupt. Shutting down...")
    finally:
        loop.run_until_complete(shutdown(signal.SIGTERM, loop))
        loop.close()