main.py

import argparse
import asyncio
import aiofiles
import signal
import json
import time
import traceback
import uuid
from typing import Optional
from pathlib import Path
from exo.orchestration.standard_node import StandardNode
from exo.networking.grpc.grpc_server import GRPCServer
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.api import ChatGPTAPI
from exo.download.shard_download import ShardDownloader, RepoProgressEvent
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
from exo.inference.shard import Shard
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration.node import Node
from exo.models import model_base_shards
from exo.viz.topology_viz import TopologyViz
# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
parser.add_argument("--node-id", type=str, default=None, help="Node ID")
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
parser.add_argument("--node-port", type=int, default=None, help="Node port")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait for before starting")
parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds")
parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max tokens to generate in each request")
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
parser.add_argument("--max-kv-size", type=int, default=1024, help="Max KV size for inference engine")
parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
parser.add_argument("--file", type=str, help="File to use for the model when using --run-model", default=None)
args = parser.parse_args()
print_yellow_exo()

system_info = get_system_info()
print(f"Detected system: {system_info}")
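# Set up the shard downloader and pick the inference engine: MLX on Apple Silicon Macs, tinygrad elsewhere, unless overridden with --inference-engine.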
shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
inference_engine = get_inference_engine(inference_engine_name, shard_downloader, max_kv_size=args.max_kv_size)
print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
if args.node_port is None:
  args.node_port = find_available_port(args.node_host)
  if DEBUG >= 1: print(f"Using available port: {args.node_port}")

args.node_id = args.node_id or get_or_create_node_id()
chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
if DEBUG >= 0:
  print("Chat interface started:")
  for web_chat_url in web_chat_urls:
    print(f" - {terminal_link(web_chat_url)}")
  print("ChatGPT API endpoint served at:")
  for chatgpt_api_endpoint in chatgpt_api_endpoints:
    print(f" - {terminal_link(chatgpt_api_endpoint)}")
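# Wire up peer discovery, the optional topology TUI, the node itself, its gRPC server, and the ChatGPT-compatible API.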
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
node = StandardNode(
  args.node_id,
  None,
  inference_engine,
  discovery,
  partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
  max_generate_tokens=args.max_generate_tokens,
  topology_viz=topology_viz
)
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server
api = ChatGPTAPI(
  node,
  inference_engine.__class__.__name__,
  response_timeout_secs=args.chatgpt_api_response_timeout_secs,
  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
)
node.on_token.register("update_topology_viz").on_next(
  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
)
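# When a peer reports it has started processing a prompt, preemptively download this node's shard of the model so it is ready before the work arrives.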
def preemptively_start_download(request_id: str, opaque_status: str):
  try:
    status = json.loads(opaque_status)
    if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
      current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
      if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
      asyncio.create_task(shard_downloader.ensure_shard(current_shard))
  except Exception as e:
    if DEBUG >= 2:
      print(f"Failed to preemptively start download: {e}")
      traceback.print_exc()

node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
if args.prometheus_client_port:
  from exo.stats.metrics import start_metrics_server
  start_metrics_server(node, args.prometheus_client_port)
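# Broadcast shard download progress to peers, throttled to at most one update every 100ms; completion events are always sent.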
last_broadcast_time = 0

def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
  global last_broadcast_time
  current_time = time.time()
  if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
    last_broadcast_time = current_time
    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))

shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
async def shutdown(signal, loop):
  """Gracefully shut down the server and close the asyncio loop."""
  print(f"Received exit signal {signal.name}...")
  print("Thank you for using exo.")
  print_yellow_exo()
  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
  [task.cancel() for task in server_tasks]
  print(f"Cancelling {len(server_tasks)} outstanding tasks")
  await asyncio.gather(*server_tasks, return_exceptions=True)
  await server.stop()
  loop.stop()
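# One-shot CLI mode used by --run-model: send a single prompt (optionally prefixed with a file's contents) through the node and print the generated response.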
async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str, file_path: Optional[str] = None):
  if file_path:
    try:
      import textract
      prompt = "Input file: " + textract.process(file_path).decode('utf-8') + "\n\n---\n\n" + prompt
    except Exception as e:
      print(f"Error reading file {file_path}: {str(e)}")
      return
  shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
  if not shard:
    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
    return
  tokenizer = await resolve_tokenizer(shard.model_id)
  request_id = str(uuid.uuid4())
  callback_id = f"cli-wait-response-{request_id}"
  callback = node.on_token.register(callback_id)
  if topology_viz:
    topology_viz.update_prompt(request_id, prompt)
  prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
  try:
    print(f"Processing prompt (len={len(prompt)}): {prompt}")
    await node.process_prompt(shard, prompt, None, request_id=request_id)
    _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
    print("\nGenerated response:")
    print(tokenizer.decode(tokens))
  except Exception as e:
    print(f"Error processing prompt: {str(e)}")
    traceback.print_exc()
  finally:
    node.on_token.deregister(callback_id)
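# Entry point: install signal handlers, start the node, then either run a single prompt (--run-model) or serve the ChatGPT-compatible API until interrupted.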
async def main():
  loop = asyncio.get_running_loop()

  # Schedule a graceful shutdown on SIGINT/SIGTERM via the running loop
  def handle_exit():
    asyncio.ensure_future(shutdown(signal.SIGTERM, loop))

  for s in [signal.SIGINT, signal.SIGTERM]:
    loop.add_signal_handler(s, handle_exit)

  await node.start(wait_for_peers=args.wait_for_peers)

  if args.run_model:
    await run_model_cli(node, inference_engine, args.run_model, args.prompt, args.file)
  else:
    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
    await asyncio.Event().wait()  # Keep serving until a shutdown signal stops the loop
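# Create a dedicated event loop, run main(), and always attempt a graceful shutdown on exit.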
if __name__ == "__main__":
  loop = asyncio.new_event_loop()
  asyncio.set_event_loop(loop)
  try:
    loop.run_until_complete(main())
  except KeyboardInterrupt:
    print("Received keyboard interrupt. Shutting down...")
  finally:
    loop.run_until_complete(shutdown(signal.SIGTERM, loop))
    loop.close()