helpers.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. import os
  2. import asyncio
  3. from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
  4. import socket
  5. import random
  6. import platform
  7. import psutil
  8. import uuid
  9. import netifaces
  10. from pathlib import Path
  11. import tempfile
  12. DEBUG = int(os.getenv("DEBUG", default="0"))
  13. DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
  14. VERSION = "0.0.1"
  15. exo_text = r"""
  16. _____ _____
  17. / _ \ \/ / _ \
  18. | __/> < (_) |
  19. \___/_/\_\___/
  20. """
  21. def get_system_info():
  22. if psutil.MACOS:
  23. if platform.machine() == "arm64":
  24. return "Apple Silicon Mac"
  25. if platform.machine() in ["x86_64", "i386"]:
  26. return "Intel Mac"
  27. return "Unknown Mac architecture"
  28. if psutil.LINUX:
  29. return "Linux"
  30. return "Non-Mac, non-Linux system"
  31. def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
  32. used_ports_file = os.path.join(tempfile.gettempdir(), "exo_used_ports")
  33. def read_used_ports():
  34. if os.path.exists(used_ports_file):
  35. with open(used_ports_file, "r") as f:
  36. return [int(line.strip()) for line in f if line.strip().isdigit()]
  37. return []
  38. def write_used_port(port, used_ports):
  39. with open(used_ports_file, "w") as f:
  40. print(used_ports[-19:])
  41. for p in used_ports[-19:] + [port]:
  42. f.write(f"{p}\n")
  43. used_ports = read_used_ports()
  44. available_ports = set(range(min_port, max_port + 1)) - set(used_ports)
  45. while available_ports:
  46. port = random.choice(list(available_ports))
  47. if DEBUG >= 2: print(f"Trying to find available port {port=}")
  48. try:
  49. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  50. s.bind((host, port))
  51. write_used_port(port, used_ports)
  52. return port
  53. except socket.error:
  54. available_ports.remove(port)
  55. raise RuntimeError("No available ports in the specified range")
  56. def print_exo():
  57. print(exo_text)
  58. def print_yellow_exo():
  59. yellow = "\033[93m" # ANSI escape code for yellow
  60. reset = "\033[0m" # ANSI escape code to reset color
  61. print(f"{yellow}{exo_text}{reset}")
  62. def terminal_link(uri, label=None):
  63. if label is None:
  64. label = uri
  65. parameters = ""
  66. # OSC 8 ; params ; URI ST <name> OSC 8 ;; ST
  67. escape_mask = "\033]8;{};{}\033\\{}\033]8;;\033\\"
  68. return escape_mask.format(parameters, uri, label)
  69. T = TypeVar("T")
  70. K = TypeVar("K")
  71. class AsyncCallback(Generic[T]):
  72. def __init__(self) -> None:
  73. self.condition: asyncio.Condition = asyncio.Condition()
  74. self.result: Optional[Tuple[T, ...]] = None
  75. self.observers: list[Callable[..., None]] = []
  76. async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]:
  77. async with self.condition:
  78. await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout)
  79. assert self.result is not None # for type checking
  80. return self.result
  81. def on_next(self, callback: Callable[..., None]) -> None:
  82. self.observers.append(callback)
  83. def set(self, *args: T) -> None:
  84. self.result = args
  85. for observer in self.observers:
  86. observer(*args)
  87. asyncio.create_task(self.notify())
  88. async def notify(self) -> None:
  89. async with self.condition:
  90. self.condition.notify_all()
  91. class AsyncCallbackSystem(Generic[K, T]):
  92. def __init__(self) -> None:
  93. self.callbacks: Dict[K, AsyncCallback[T]] = {}
  94. def register(self, name: K) -> AsyncCallback[T]:
  95. if name not in self.callbacks:
  96. self.callbacks[name] = AsyncCallback[T]()
  97. return self.callbacks[name]
  98. def deregister(self, name: K) -> None:
  99. if name in self.callbacks:
  100. del self.callbacks[name]
  101. def trigger(self, name: K, *args: T) -> None:
  102. if name in self.callbacks:
  103. self.callbacks[name].set(*args)
  104. def trigger_all(self, *args: T) -> None:
  105. for callback in self.callbacks.values():
  106. callback.set(*args)
  107. K = TypeVar('K', bound=str)
  108. V = TypeVar('V')
  109. class PrefixDict(Generic[K, V]):
  110. def __init__(self):
  111. self.items: Dict[K, V] = {}
  112. def add(self, key: K, value: V) -> None:
  113. self.items[key] = value
  114. def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
  115. return [(key, value) for key, value in self.items.items() if argument.startswith(key)]
  116. def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
  117. matches = self.find_prefix(argument)
  118. if len(matches) == 0:
  119. return None
  120. return max(matches, key=lambda x: len(x[0]))
  121. def is_valid_uuid(val):
  122. try:
  123. uuid.UUID(str(val))
  124. return True
  125. except ValueError:
  126. return False
  127. def get_or_create_node_id():
  128. NODE_ID_FILE = Path(tempfile.gettempdir())/".exo_node_id"
  129. try:
  130. if NODE_ID_FILE.is_file():
  131. with open(NODE_ID_FILE, "r") as f:
  132. stored_id = f.read().strip()
  133. if is_valid_uuid(stored_id):
  134. if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
  135. return stored_id
  136. else:
  137. if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")
  138. new_id = str(uuid.uuid4())
  139. with open(NODE_ID_FILE, "w") as f:
  140. f.write(new_id)
  141. if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
  142. return new_id
  143. except IOError as e:
  144. if DEBUG >= 2: print(f"IO error creating node_id: {e}")
  145. return str(uuid.uuid4())
  146. except Exception as e:
  147. if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
  148. return str(uuid.uuid4())
  149. def pretty_print_bytes(size_in_bytes: int) -> str:
  150. if size_in_bytes < 1024:
  151. return f"{size_in_bytes} B"
  152. elif size_in_bytes < 1024**2:
  153. return f"{size_in_bytes / 1024:.2f} KB"
  154. elif size_in_bytes < 1024**3:
  155. return f"{size_in_bytes / (1024 ** 2):.2f} MB"
  156. elif size_in_bytes < 1024**4:
  157. return f"{size_in_bytes / (1024 ** 3):.2f} GB"
  158. else:
  159. return f"{size_in_bytes / (1024 ** 4):.2f} TB"
  160. def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
  161. if bytes_per_second < 1024:
  162. return f"{bytes_per_second} B/s"
  163. elif bytes_per_second < 1024**2:
  164. return f"{bytes_per_second / 1024:.2f} KB/s"
  165. elif bytes_per_second < 1024**3:
  166. return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
  167. elif bytes_per_second < 1024**4:
  168. return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
  169. else:
  170. return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
  171. def get_all_ip_addresses():
  172. try:
  173. ip_addresses = []
  174. for interface in netifaces.interfaces():
  175. ifaddresses = netifaces.ifaddresses(interface)
  176. if netifaces.AF_INET in ifaddresses:
  177. for link in ifaddresses[netifaces.AF_INET]:
  178. ip = link['addr']
  179. ip_addresses.append(ip)
  180. return list(set(ip_addresses))
  181. except:
  182. if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
  183. return ["localhost"]