main.py

import argparse
import asyncio
import signal
import uuid
from typing import List

import mlx.core as mx
import mlx.nn as nn

from exo.orchestration.standard_node import StandardNode
from exo.networking.grpc.grpc_server import GRPCServer
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.api import ChatGPTAPI

# Parse command-line arguments.
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
parser.add_argument("--node-id", type=str, default=str(uuid.uuid4()), help="Node ID")
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
parser.add_argument("--node-port", type=int, default=8080, help="Node port")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
args = parser.parse_args()
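
# Example invocation (illustrative only; every flag is optional and the node ID
# shown here is made up -- omitted flags fall back to the defaults above):
#
#   python main.py --node-id node-1 --node-port 8080 --chatgpt-api-port 8000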

inference_engine = MLXDynamicShardInferenceEngine()

def on_token(tokens: List[int]):
    # Print decoded tokens as they stream in, once the tokenizer is loaded.
    if inference_engine.tokenizer:
        print(inference_engine.tokenizer.decode(tokens))

discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port)
node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), on_token=on_token)  # server is attached below
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server
api = ChatGPTAPI(node)
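
# Once the node is running, the ChatGPT-compatible API listens on
# --chatgpt-api-port (default 8000). Assuming ChatGPTAPI exposes the usual
# OpenAI-style chat completions route (an assumption, not confirmed by this
# file), a request might look like:
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "hello"}]}'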

async def shutdown(sig, loop):
    """Gracefully shut down the server and stop the asyncio loop."""
    print(f"Received exit signal {sig.name}...")
    server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    print(f"Cancelling {len(server_tasks)} outstanding tasks")
    for task in server_tasks:
        task.cancel()
    # Wait for cancellations to propagate before stopping the server and loop.
    await asyncio.gather(*server_tasks, return_exceptions=True)
    await server.stop()
    loop.stop()

async def main():
    loop = asyncio.get_running_loop()

    # Register signal handlers directly on the event loop so SIGINT/SIGTERM
    # trigger a graceful shutdown.
    def handle_exit():
        asyncio.ensure_future(shutdown(signal.SIGTERM, loop))

    for s in [signal.SIGINT, signal.SIGTERM]:
        loop.add_signal_handler(s, handle_exit)

    await node.start(wait_for_peers=args.wait_for_peers)
    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task.
    await asyncio.Event().wait()  # Run until a signal stops the loop.

if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    except KeyboardInterrupt:
        print("Received keyboard interrupt. Shutting down...")
    finally:
        loop.run_until_complete(shutdown(signal.SIGTERM, loop))
        loop.close()
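
# Example: a small two-node cluster on one machine (illustrative sketch; node
# IDs and ports are made up, and running two discovery listeners on the same
# host may need distinct ports depending on the platform). Both processes
# share the default discovery listen/broadcast port (5678), and the second
# waits for one peer before starting:
#
#   python main.py --node-id a --node-port 8080 --chatgpt-api-port 8000
#   python main.py --node-id b --node-port 8081 --chatgpt-api-port 8001 --wait-for-peers 1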