main.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. import argparse
  2. import asyncio
  3. import atexit
  4. import signal
  5. import json
  6. import logging
  7. import platform
  8. import os
  9. import sys
  10. import time
  11. import traceback
  12. import uuid
  13. from exo.networking.manual.manual_discovery import ManualDiscovery
  14. from exo.networking.manual.network_topology_config import NetworkTopology
  15. from exo.orchestration.standard_node import StandardNode
  16. from exo.networking.grpc.grpc_server import GRPCServer
  17. from exo.networking.udp.udp_discovery import UDPDiscovery
  18. from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
  19. from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
  20. from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
  21. from exo.api import ChatGPTAPI
  22. from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
  23. from exo.download.hf.hf_shard_download import HFShardDownloader
  24. from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
  25. from exo.inference.shard import Shard
  26. from exo.inference.inference_engine import get_inference_engine, InferenceEngine
  27. from exo.inference.tokenizers import resolve_tokenizer
  28. from exo.orchestration.node import Node
  29. from exo.models import build_base_shard, get_repo
  30. from exo.viz.topology_viz import TopologyViz
  31. from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
# Positional CLI: `exo run <model_name>` (both optional; API-server mode when omitted).
parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
parser.add_argument("model_name", nargs="?", help="Model name to run")
parser.add_argument("--default-model", type=str, default=None, help="Default model")
# Node identity / networking.
parser.add_argument("--node-id", type=str, default=None, help="Node ID")
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
parser.add_argument("--node-port", type=int, default=None, help="Node port")
# Model shard download behaviour.
parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
# Peer discovery (UDP broadcast by default; tailscale/manual alternatives below).
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
# ChatGPT-compatible HTTP API.
parser.add_argument("--chatgpt-api-port", type=int, default=52415, help="ChatGPT API port")
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
# Inference engine selection and one-shot run options.
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
parser.add_argument("--default-temp", type=float, help="Default token sampling temperature", default=0.0)
# Tailscale discovery credentials / filters.
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)")
args = parser.parse_args()
print(f"Selected inference engine: {args.inference_engine}")
print_yellow_exo()
system_info = get_system_info()
print(f"Detected system: {system_info}")
# The "dummy" engine gets a no-op downloader so no shards are ever fetched.
shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
                                                      max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
# Engine default: MLX on Apple Silicon Macs, tinygrad everywhere else.
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
print(f"Inference engine name after selection: {inference_engine_name}")
inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
# Pick a free port when --node-port was not supplied.
if args.node_port is None:
  args.node_port = find_available_port(args.node_host)
  if DEBUG >= 1: print(f"Using available port: {args.node_port}")
args.node_id = args.node_id or get_or_create_node_id()
# Build advertised URLs for every local IP/interface pair.
chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip, _ in get_all_ip_addresses_and_interfaces()]
web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip, _ in get_all_ip_addresses_and_interfaces()]
if DEBUG >= 0:
  print("Chat interface started:")
  for web_chat_url in web_chat_urls:
    print(f" - {terminal_link(web_chat_url)}")
  print("ChatGPT API endpoint served at:")
  for chatgpt_api_endpoint in chatgpt_api_endpoints:
    print(f" - {terminal_link(chatgpt_api_endpoint)}")
# Convert node-id-filter and interface-type-filter to lists if provided
allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None
allowed_interface_types = args.interface_type_filter.split(',') if args.interface_type_filter else None
  89. if args.discovery_module == "udp":
  90. discovery = UDPDiscovery(
  91. args.node_id,
  92. args.node_port,
  93. args.listen_port,
  94. args.broadcast_port,
  95. lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
  96. discovery_timeout=args.discovery_timeout,
  97. allowed_node_ids=allowed_node_ids,
  98. allowed_interface_types=allowed_interface_types
  99. )
  100. elif args.discovery_module == "tailscale":
  101. discovery = TailscaleDiscovery(
  102. args.node_id,
  103. args.node_port,
  104. lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
  105. discovery_timeout=args.discovery_timeout,
  106. tailscale_api_key=args.tailscale_api_key,
  107. tailnet=args.tailnet_name,
  108. allowed_node_ids=allowed_node_ids
  109. )
  110. elif args.discovery_module == "manual":
  111. if not args.discovery_config_path:
  112. raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
  113. discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
# TUI topology visualization; None when --disable-tui is set.
topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
node = StandardNode(
  args.node_id,
  None,  # server is attached below, after GRPCServer is constructed
  inference_engine,
  discovery,
  partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
  max_generate_tokens=args.max_generate_tokens,
  topology_viz=topology_viz,
  shard_downloader=shard_downloader,
  default_sample_temperature=args.default_temp
)
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server
api = ChatGPTAPI(
  node,
  inference_engine.__class__.__name__,
  response_timeout=args.chatgpt_api_response_timeout,
  # Mirror incoming chat prompts into the TUI when it is enabled.
  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
  default_model=args.default_model
)
# Stream decoded token output into the TUI as generation progresses
# (only if the engine exposes a tokenizer attribute).
node.on_token.register("update_topology_viz").on_next(
  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
)
  138. def preemptively_start_download(request_id: str, opaque_status: str):
  139. try:
  140. status = json.loads(opaque_status)
  141. if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
  142. current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
  143. if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
  144. asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
  145. except Exception as e:
  146. if DEBUG >= 2:
  147. print(f"Failed to preemptively start download: {e}")
  148. traceback.print_exc()
# Listen for peer status broadcasts to warm shard downloads early.
node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
if args.prometheus_client_port:
  # Imported lazily so the prometheus dependency is only needed when metrics are enabled.
  from exo.stats.metrics import start_metrics_server
  start_metrics_server(node, args.prometheus_client_port)
  153. last_broadcast_time = 0
  154. def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
  155. global last_broadcast_time
  156. current_time = time.time()
  157. if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
  158. last_broadcast_time = current_time
  159. asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
  160. shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
  161. async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
  162. inference_class = inference_engine.__class__.__name__
  163. shard = build_base_shard(model_name, inference_class)
  164. if not shard:
  165. print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
  166. return
  167. tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
  168. request_id = str(uuid.uuid4())
  169. callback_id = f"cli-wait-response-{request_id}"
  170. callback = node.on_token.register(callback_id)
  171. if topology_viz:
  172. topology_viz.update_prompt(request_id, prompt)
  173. prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
  174. try:
  175. print(f"Processing prompt: {prompt}")
  176. await node.process_prompt(shard, prompt, request_id=request_id)
  177. _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
  178. print("\nGenerated response:")
  179. print(tokenizer.decode(tokens))
  180. except Exception as e:
  181. print(f"Error processing prompt: {str(e)}")
  182. traceback.print_exc()
  183. finally:
  184. node.on_token.deregister(callback_id)
  185. def clean_path(path):
  186. """Clean and resolve path"""
  187. if path.startswith("Optional("):
  188. path = path.strip('Optional("').rstrip('")')
  189. return os.path.expanduser(path)
  190. async def main():
  191. loop = asyncio.get_running_loop()
  192. # Check HuggingFace directory permissions
  193. hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
  194. if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
  195. print(f"{has_read=}, {has_write=}")
  196. if not has_read or not has_write:
  197. print(f"""
  198. WARNING: Limited permissions for model storage directory: {hf_home}.
  199. This may prevent model downloads from working correctly.
  200. {"❌ No read access" if not has_read else ""}
  201. {"❌ No write access" if not has_write else ""}
  202. """)
  203. if not args.models_seed_dir is None:
  204. try:
  205. models_seed_dir = clean_path(args.models_seed_dir)
  206. await move_models_to_hf(models_seed_dir)
  207. except Exception as e:
  208. print(f"Error moving models to .cache/huggingface: {e}")
  209. def restore_cursor():
  210. if platform.system() != "Windows":
  211. os.system("tput cnorm") # Show cursor
  212. # Restore the cursor when the program exits
  213. atexit.register(restore_cursor)
  214. # Use a more direct approach to handle signals
  215. def handle_exit():
  216. asyncio.ensure_future(shutdown(signal.SIGTERM, loop, node.server))
  217. if platform.system() != "Windows":
  218. for s in [signal.SIGINT, signal.SIGTERM]:
  219. loop.add_signal_handler(s, handle_exit)
  220. await node.start(wait_for_peers=args.wait_for_peers)
  221. if args.command == "run" or args.run_model:
  222. model_name = args.model_name or args.run_model
  223. if not model_name:
  224. print("Error: Model name is required when using 'run' command or --run-model")
  225. return
  226. await run_model_cli(node, inference_engine, model_name, args.prompt)
  227. else:
  228. asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
  229. await asyncio.Event().wait()
def run():
  """Program entry point: drive main() on a fresh event loop and always tear down cleanly."""
  loop = asyncio.new_event_loop()
  asyncio.set_event_loop(loop)
  try:
    loop.run_until_complete(main())
  except KeyboardInterrupt:
    print("Received keyboard interrupt. Shutting down...")
  finally:
    # Run the shutdown sequence even on normal exit so the server and peers are released.
    loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
    loop.close()
if __name__ == "__main__":
  run()