Selaa lähdekoodia

Display all interfaces that the web chat and ChatGPT API are available on. Fixes #134

Alex Cheema 8 kuukautta sitten
vanhempi
commit
71591d2ebc
6 muutettua tiedostoa jossa 35 lisäystä ja 17 poistoa
  1. 0 3
      exo/api/chatgpt_api.py
  2. 15 0
      exo/helpers.py
  3. 3 3
      exo/orchestration/standard_node.py
  4. 8 8
      exo/viz/topology_viz.py
  5. 8 3
      main.py
  6. 1 0
      setup.py

+ 0 - 3
exo/api/chatgpt_api.py

@@ -391,6 +391,3 @@ class ChatGPTAPI:
     await runner.setup()
     site = web.TCPSite(runner, host, port)
     await site.start()
-    if DEBUG >= 0:
-      print(f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}")
-      print(f"ChatGPT API endpoint served at {terminal_link(f'http://localhost:{port}/v1/chat/completions')}")

+ 15 - 0
exo/helpers.py

@@ -6,6 +6,7 @@ import random
 import platform
 import psutil
 import uuid
+import netifaces
 from pathlib import Path
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
@@ -225,3 +226,17 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
         return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
     else:
         return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
+
+def get_all_ip_addresses():
+    try:
+      ip_addresses = []
+      for interface in netifaces.interfaces():
+        ifaddresses = netifaces.ifaddresses(interface)
+        if netifaces.AF_INET in ifaddresses:
+          for link in ifaddresses[netifaces.AF_INET]:
+            ip = link['addr']
+            ip_addresses.append(ip)
+      return list(set(ip_addresses))
+    except:
+      if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
+      return ["localhost"]

+ 3 - 3
exo/orchestration/standard_node.py

@@ -26,8 +26,8 @@ class StandardNode(Node):
     discovery: Discovery,
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
-    chatgpt_api_endpoint: Optional[str] = None,
-    web_chat_url: Optional[str] = None,
+    chatgpt_api_endpoints: List[str] = [],
+    web_chat_urls: List[str] = [],
     disable_tui: Optional[bool] = False,
   ):
     self.id = _id
@@ -39,7 +39,7 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-    self.topology_viz = TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url) if not disable_tui else None
+    self.topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not disable_tui else None
     self.max_generate_tokens = max_generate_tokens
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()

+ 8 - 8
exo/viz/topology_viz.py

@@ -14,9 +14,9 @@ from rich.layout import Layout
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
 
 class TopologyViz:
-  def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
-    self.chatgpt_api_endpoint = chatgpt_api_endpoint
-    self.web_chat_url = web_chat_url
+  def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []):
+    self.chatgpt_api_endpoints = chatgpt_api_endpoints
+    self.web_chat_urls = web_chat_urls
     self.topology = Topology()
     self.partitions: List[Partition] = []
     self.node_id = None
@@ -80,12 +80,12 @@ class TopologyViz:
         if 0 <= start_x + j < 100 and i < len(visualization):
           visualization[i][start_x + j] = char
 
-    # Display chatgpt_api_endpoint and web_chat_url if set
+    # Display chatgpt_api_endpoints and web_chat_urls
     info_lines = []
-    if self.web_chat_url:
-      info_lines.append(f"Web Chat URL (tinychat): {self.web_chat_url}")
-    if self.chatgpt_api_endpoint:
-      info_lines.append(f"ChatGPT API endpoint: {self.chatgpt_api_endpoint}")
+    if len(self.web_chat_urls) > 0:
+      info_lines.append(f"Web Chat URL (tinychat): {' '.join(self.web_chat_urls[:1])}")
+    if len(self.chatgpt_api_endpoints) > 0:
+      info_lines.append(f"ChatGPT API endpoint: {' '.join(self.chatgpt_api_endpoints[:1])}")
 
     info_start_y = len(exo_lines) + 1
     for i, line in enumerate(info_lines):

+ 8 - 3
main.py

@@ -11,7 +11,7 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
 from exo.inference.shard import Shard
 
 # parse args
@@ -47,14 +47,19 @@ if args.node_port is None:
 
 args.node_id = args.node_id or get_or_create_node_id()
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
+chatgpt_api_endpoints=[f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
+web_chat_urls=[f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
+if DEBUG >= 0:
+    print(f"Chat interface started:\n{'\n'.join([' - ' + terminal_link(web_chat_url) for web_chat_url in web_chat_urls])}")
+    print(f"ChatGPT API endpoint served at:\n{'\n'.join([' - ' + terminal_link(chatgpt_api_endpoint) for chatgpt_api_endpoint in chatgpt_api_endpoints])}")
 node = StandardNode(
     args.node_id,
     None,
     inference_engine,
     discovery,
+    chatgpt_api_endpoints=chatgpt_api_endpoints,
+    web_chat_urls=web_chat_urls,
     partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
-    chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions",
-    web_chat_url=f"http://localhost:{args.chatgpt_api_port}",
     disable_tui=args.disable_tui,
     max_generate_tokens=args.max_generate_tokens,
 )

+ 1 - 0
setup.py

@@ -13,6 +13,7 @@ install_requires = [
     "hf-transfer==0.1.8",
     "huggingface-hub==0.24.5",
     "Jinja2==3.1.4",
+    "netifaces==0.11.0",
     "numpy==2.0.0",
     "pillow==10.4.0",
     "prometheus-client==0.20.0",