
run format.py on ./exo

Alex Cheema, 6 months ago
commit 98ea71edda

+ 4 - 7
exo/api/chatgpt_api.py

@@ -178,7 +178,7 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
 
-    self.static_dir = Path(__file__).parent.parent / "tinychat"
+    self.static_dir = Path(__file__).parent.parent/"tinychat"
     self.app.router.add_get("/", self.handle_root)
     self.app.router.add_static("/", self.static_dir, name="static")
 
@@ -191,6 +191,7 @@ class ChatGPTAPI:
         return await asyncio.wait_for(handler(request), timeout=self.response_timeout)
       except asyncio.TimeoutError:
         return web.json_response({"detail": "Request timed out"}, status=408)
+
     return middleware
 
   async def log_request(self, app, handler):
@@ -204,7 +205,7 @@ class ChatGPTAPI:
     return web.FileResponse(self.static_dir/"index.html")
 
   async def handle_get_models(self, request):
-    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True } for model_name, _ in model_base_shards.items()])
+    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_base_shards.items()])
 
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
@@ -222,7 +223,6 @@ class ChatGPTAPI:
         print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
     return web.json_response(progress_data)
 
-
   async def handle_post_chat_completions(self, request):
     data = await request.json()
     if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
@@ -270,10 +270,7 @@ class ChatGPTAPI:
     if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
 
     try:
-      await asyncio.wait_for(
-        asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, image_str, request_id=request_id))),
-        timeout=self.response_timeout
-      )
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, image_str, request_id=request_id))), timeout=self.response_timeout)
 
       if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
 

+ 4 - 2
exo/download/hf/hf_helpers.py

@@ -70,8 +70,10 @@ def _add_wildcard_to_directories(pattern: str) -> str:
     return pattern + "*"
   return pattern
 
+
 def get_hf_endpoint() -> str:
-    return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
+  return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
+
 
 def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""
@@ -394,7 +396,7 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:
 
 
 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
-  default_patterns = set(["*.json","*.py","tokenizer.model","*.tiktoken","*.txt"])
+  default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
   shard_specific_patterns = set()
   if weight_map:
     for tensor_name, filename in weight_map.items():
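
Note: get_hf_endpoint() in the hunk above simply falls back to the public hub unless HF_ENDPOINT is set. A minimal usage sketch (the mirror URL below is only an example, not a recommendation):

import os

def get_hf_endpoint() -> str:
  return os.environ.get('HF_ENDPOINT', "https://huggingface.co")

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"  # example mirror/proxy URL
print(get_hf_endpoint())  # -> https://hf-mirror.com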

+ 1 - 0
exo/download/shard_download.py

@@ -25,6 +25,7 @@ class ShardDownloader(ABC):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     pass
 
+
 class NoopShardDownloader(ShardDownloader):
   async def ensure_shard(self, shard: Shard) -> Path:
     return Path("/tmp/noop_shard")

+ 1 - 1
exo/helpers.py

@@ -170,7 +170,7 @@ def is_valid_uuid(val):
 
 
 def get_or_create_node_id():
-  NODE_ID_FILE = Path(tempfile.gettempdir()) / ".exo_node_id"
+  NODE_ID_FILE = Path(tempfile.gettempdir())/".exo_node_id"
   try:
     if NODE_ID_FILE.is_file():
       with open(NODE_ID_FILE, "r") as f:
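
Note: the formatter only removes the spaces around "/" here; both spellings call pathlib's overloaded division operator and build the same path. A quick standard-library check:

import tempfile
from pathlib import Path

a = Path(tempfile.gettempdir()) / ".exo_node_id"
b = Path(tempfile.gettempdir())/".exo_node_id"
assert a == b  # identical Path objects; the spacing is purely cosmetic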

+ 1 - 0
exo/inference/dummy_inference_engine.py

@@ -5,6 +5,7 @@ import json
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 
+
 class DummyInferenceEngine(InferenceEngine):
   def __init__(self):
     self.shard = None

+ 2 - 1
exo/inference/mlx/models/qwen2.py

@@ -24,6 +24,7 @@ class ModelArgs(ModelArgs):
 
     self.shard = Shard(**self.shard)
 
+
 class Qwen2Model(nn.Module):
   def __init__(self, args: ModelArgs):
     super().__init__()
@@ -57,7 +58,7 @@ class Qwen2Model(nn.Module):
       mask = create_attention_mask(h, cache)
 
     if cache is None:
-      cache = [None] * len(self.layers)
+      cache = [None]*len(self.layers)
 
     for layer, c in zip(self.layers, cache):
       h = layer(h, mask, c)

+ 5 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -10,6 +10,7 @@ import asyncio
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 
+
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
@@ -44,7 +45,10 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
 
     if self.shard != shard:
       loop = asyncio.get_running_loop()
-      def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard))
+
+      def load_shard_wrapper():
+        return asyncio.run(load_shard(model_path, shard))
+
       model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
       self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
       self.shard = shard
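
Note: the change above only unfolds the one-line wrapper into a named function; the pattern is unchanged. A self-contained sketch of that pattern (stand-in loader, illustrative names): the synchronous wrapper calls asyncio.run on a worker thread, giving the blocking load its own event loop while the main loop stays responsive.

import asyncio
from concurrent.futures import ThreadPoolExecutor

async def load_shard(model_path, shard):
  # stand-in for the real shard loader
  await asyncio.sleep(0.1)
  return f"model@{model_path}", "tokenizer"

async def ensure_shard(model_path, shard, executor):
  loop = asyncio.get_running_loop()

  def load_shard_wrapper():
    # asyncio.run creates a fresh event loop on the executor thread
    return asyncio.run(load_shard(model_path, shard))

  return await loop.run_in_executor(executor, load_shard_wrapper)

print(asyncio.run(ensure_shard("/tmp/model", None, ThreadPoolExecutor(max_workers=1))))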

+ 1 - 0
exo/inference/mlx/sharded_model.py

@@ -8,6 +8,7 @@ from mlx_lm.sample_utils import top_p_sampling
 
 from ..shard import Shard
 
+
 # TODO: support a speculative model so we can parallelise compute across devices
 class StatefulShardedModel:
   def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):

+ 47 - 42
exo/inference/test_dummy_inference_engine.py

@@ -4,53 +4,58 @@ import numpy as np
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.shard import Shard
 
+
 class MockShardDownloader:
-    async def ensure_shard(self, shard):
-        pass
+  async def ensure_shard(self, shard):
+    pass
+
+
 @pytest.mark.asyncio
 async def test_dummy_inference_specific():
-    engine = DummyInferenceEngine(MockShardDownloader())
-    test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
-    test_prompt = "This is a test prompt"
-    
-    result, state, is_finished = await engine.infer_prompt("test_request", test_shard, test_prompt)
-    
-    print(f"Inference result shape: {result.shape}")
-    print(f"Inference state: {state}")
-    print(f"Is finished: {is_finished}")
-    
-    assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
-    assert isinstance(json.loads(state), dict), "State should be a valid JSON string"
-    assert isinstance(is_finished, bool), "is_finished should be a boolean"
+  engine = DummyInferenceEngine(MockShardDownloader())
+  test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+  test_prompt = "This is a test prompt"
+
+  result, state, is_finished = await engine.infer_prompt("test_request", test_shard, test_prompt)
+
+  print(f"Inference result shape: {result.shape}")
+  print(f"Inference state: {state}")
+  print(f"Is finished: {is_finished}")
+
+  assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
+  assert isinstance(json.loads(state), dict), "State should be a valid JSON string"
+  assert isinstance(is_finished, bool), "is_finished should be a boolean"
+
 
 @pytest.mark.asyncio
 async def test_dummy_inference_engine():
-    # Initialize the DummyInferenceEngine
-    engine = DummyInferenceEngine(MockShardDownloader())
-    
-    # Create a test shard
-    shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
-    
-    # Test infer_prompt
-    output, state, is_finished = await engine.infer_prompt("test_id", shard, "Test prompt")
-    
-    assert isinstance(output, np.ndarray), "Output should be a numpy array"
-    assert output.ndim == 2, "Output should be 2-dimensional"
-    assert isinstance(state, str), "State should be a string"
-    assert isinstance(is_finished, bool), "is_finished should be a boolean"
-
-    # Test infer_tensor
-    input_tensor = np.array([[1, 2, 3]])
-    output, state, is_finished = await engine.infer_tensor("test_id", shard, input_tensor)
-    
-    assert isinstance(output, np.ndarray), "Output should be a numpy array"
-    assert output.ndim == 2, "Output should be 2-dimensional"
-    assert isinstance(state, str), "State should be a string"
-    assert isinstance(is_finished, bool), "is_finished should be a boolean"
-
-    print("All tests passed!")
+  # Initialize the DummyInferenceEngine
+  engine = DummyInferenceEngine(MockShardDownloader())
+
+  # Create a test shard
+  shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+
+  # Test infer_prompt
+  output, state, is_finished = await engine.infer_prompt("test_id", shard, "Test prompt")
+
+  assert isinstance(output, np.ndarray), "Output should be a numpy array"
+  assert output.ndim == 2, "Output should be 2-dimensional"
+  assert isinstance(state, str), "State should be a string"
+  assert isinstance(is_finished, bool), "is_finished should be a boolean"
+
+  # Test infer_tensor
+  input_tensor = np.array([[1, 2, 3]])
+  output, state, is_finished = await engine.infer_tensor("test_id", shard, input_tensor)
+
+  assert isinstance(output, np.ndarray), "Output should be a numpy array"
+  assert output.ndim == 2, "Output should be 2-dimensional"
+  assert isinstance(state, str), "State should be a string"
+  assert isinstance(is_finished, bool), "is_finished should be a boolean"
+
+  print("All tests passed!")
+
 
 if __name__ == "__main__":
-    import asyncio
-    asyncio.run(test_dummy_inference_engine())
-    asyncio.run(test_dummy_inference_specific())
+  import asyncio
+  asyncio.run(test_dummy_inference_engine())
+  asyncio.run(test_dummy_inference_specific())

+ 2 - 12
exo/inference/test_inference_engine.py

@@ -44,12 +44,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(next_resp_full, resp4)
 
 
-asyncio.run(test_inference_engine(
-  MLXDynamicShardInferenceEngine(HFShardDownloader()),
-  MLXDynamicShardInferenceEngine(HFShardDownloader()),
-  "mlx-community/Llama-3.2-1B-Instruct-4bit",
-  16
-))
+asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "mlx-community/Llama-3.2-1B-Instruct-4bit", 16))
 
 if os.getenv("RUN_TINYGRAD", default="0") == "1":
   import tinygrad
@@ -57,10 +52,5 @@ if os.getenv("RUN_TINYGRAD", default="0") == "1":
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
   asyncio.run(
-    test_inference_engine(
-      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
-      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
-      "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
-      32
-    )
+    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", 32)
   )

+ 6 - 1
exo/inference/tokenizers.py

@@ -7,14 +7,18 @@ from transformers import AutoTokenizer, AutoProcessor
 from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG
 
+
 class DummyTokenizer:
   def __init__(self):
     self.eos_token_id = 0
+
   def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
-    return [1,2,3]
+    return [1, 2, 3]
+
   def decode(self, tokens):
     return "dummy"
 
+
 async def resolve_tokenizer(model_id: str):
   if model_id == "dummy":
     return DummyTokenizer()
@@ -29,6 +33,7 @@ async def resolve_tokenizer(model_id: str):
     if DEBUG >= 5: traceback.print_exc()
   return await _resolve_tokenizer(model_id)
 
+
 async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
   try:
     if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")

+ 31 - 14
exo/main.py

@@ -54,14 +54,13 @@ parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name
 args = parser.parse_args()
 print(f"Selected inference engine: {args.inference_engine}")
 
-
 print_yellow_exo()
 
-
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
-shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
+shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
+                                                      max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 print(f"Inference engine name after selection: {inference_engine_name}")
 
@@ -84,9 +83,23 @@ if DEBUG >= 0:
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
 
 if args.discovery_module == "udp":
-  discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
+  discovery = UDPDiscovery(
+    args.node_id,
+    args.node_port,
+    args.listen_port,
+    args.broadcast_port,
+    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    discovery_timeout=args.discovery_timeout
+  )
 elif args.discovery_module == "tailscale":
-  discovery = TailscaleDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout, tailscale_api_key=args.tailscale_api_key, tailnet=args.tailnet_name)
+  discovery = TailscaleDiscovery(
+    args.node_id,
+    args.node_port,
+    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    discovery_timeout=args.discovery_timeout,
+    tailscale_api_key=args.tailscale_api_key,
+    tailnet=args.tailnet_name
+  )
 elif args.discovery_module == "manual":
   if not args.discovery_config_path:
     raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
@@ -113,6 +126,8 @@ api = ChatGPTAPI(
 node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )
+
+
 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
     status = json.loads(opaque_status)
@@ -124,6 +139,8 @@ def preemptively_start_download(request_id: str, opaque_status: str):
     if DEBUG >= 2:
       print(f"Failed to preemptively start download: {e}")
       traceback.print_exc()
+
+
 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 
 if args.prometheus_client_port:
@@ -132,16 +149,14 @@ if args.prometheus_client_port:
 
 last_broadcast_time = 0
 
+
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
-    global last_broadcast_time
-    current_time = time.time()
-    if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
-        last_broadcast_time = current_time
-        asyncio.create_task(node.broadcast_opaque_status("", json.dumps({
-            "type": "download_progress",
-            "node_id": node.id,
-            "progress": event.to_dict()
-        })))
+  global last_broadcast_time
+  current_time = time.time()
+  if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
+    last_broadcast_time = current_time
+    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+
 
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
@@ -158,6 +173,7 @@ async def shutdown(signal, loop):
   await server.stop()
   loop.stop()
 
+
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
   shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
   if not shard:
@@ -220,5 +236,6 @@ def run():
     loop.run_until_complete(shutdown(signal.SIGTERM, loop))
     loop.close()
 
+
 if __name__ == "__main__":
   run()
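
Note: throttled_broadcast, reindented earlier in this file, rate-limits download-progress broadcasts to one every 0.1s, except that "complete" events always go out. The gating logic in isolation (illustrative sketch, not exo code):

import time

last_broadcast_time = 0

def should_broadcast(status: str) -> bool:
  global last_broadcast_time
  now = time.time()
  if status == "complete" or now - last_broadcast_time >= 0.1:
    last_broadcast_time = now
    return True
  return False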

+ 12 - 36
exo/models.py

@@ -2,12 +2,8 @@ from exo.inference.shard import Shard
 
 model_base_shards = {
   ### llama
-  "llama-3.2-1b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),
-  },
-  "llama-3.2-3b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
+  "llama-3.2-1b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),},
+  "llama-3.2-3b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
   "llama-3.1-8b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
     "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
@@ -38,36 +34,16 @@ model_base_shards = {
   ### llava
   "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),},
   ### qwen
-  "qwen-2.5-coder-1.5b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-coder-7b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-7b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-math-7b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-14b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48),
-  },
-  "qwen-2.5-72b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-  },
-  "qwen-2.5-math-72b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-  },
+  "qwen-2.5-coder-1.5b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
+  "qwen-2.5-coder-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
+  "qwen-2.5-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
+  "qwen-2.5-math-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
+  "qwen-2.5-14b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48),},
+  "qwen-2.5-72b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),},
+  "qwen-2.5-math-72b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),},
   ### nemotron
-  "nemotron-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", start_layer=0, end_layer=0, n_layers=80),
-  },
-  "nemotron-70b-bf16": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", start_layer=0, end_layer=0, n_layers=80),
-  },
+  "nemotron-70b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", start_layer=0, end_layer=0, n_layers=80),},
+  "nemotron-70b-bf16": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", start_layer=0, end_layer=0, n_layers=80),},
   # dummy
-  "dummy": {
-    "DummyInferenceEngine": Shard(model_id="dummy", start_layer=0, end_layer=7, n_layers=8),
-  },
+  "dummy": {"DummyInferenceEngine": Shard(model_id="dummy", start_layer=0, end_layer=7, n_layers=8),},
 }

+ 3 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -117,7 +117,9 @@ class GRPCPeerHandle(PeerHandle):
     response = await self.stub.CollectTopology(request)
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
-      device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8))
+      device_capabilities = DeviceCapabilities(
+        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
+      )
       topology.update_node(node_id, device_capabilities)
     for node_id, peers in response.peer_graph.items():
       for peer_id in peers.peer_ids:

File diff suppressed because it is too large
+ 0 - 3
exo/networking/grpc/node_service_pb2.py


+ 263 - 314
exo/networking/grpc/node_service_pb2_grpc.py

@@ -12,349 +12,298 @@ SCHEDULED_RELEASE_DATE = 'June 25, 2024'
 _version_not_supported = False
 
 try:
-    from grpc._utilities import first_version_is_lower
-    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+  from grpc._utilities import first_version_is_lower
+  _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
 except ImportError:
-    _version_not_supported = True
+  _version_not_supported = True
 
 if _version_not_supported:
-    warnings.warn(
-        f'The grpc package installed is at version {GRPC_VERSION},'
-        + f' but the generated code in node_service_pb2_grpc.py depends on'
-        + f' grpcio>={GRPC_GENERATED_VERSION}.'
-        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
-        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
-        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
-        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
-        RuntimeWarning
-    )
+  warnings.warn(
+    f'The grpc package installed is at version {GRPC_VERSION},' + f' but the generated code in node_service_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' +
+    f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' +
+    f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', RuntimeWarning
+  )
 
 
 class NodeServiceStub(object):
-    """Missing associated documentation comment in .proto file."""
-
-    def __init__(self, channel):
-        """Constructor.
+  """Missing associated documentation comment in .proto file."""
+  def __init__(self, channel):
+    """Constructor.
 
         Args:
             channel: A grpc.Channel.
         """
-        self.SendPrompt = channel.unary_unary(
-                '/node_service.NodeService/SendPrompt',
-                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
-                _registered_method=True)
-        self.SendTensor = channel.unary_unary(
-                '/node_service.NodeService/SendTensor',
-                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
-                _registered_method=True)
-        self.GetInferenceResult = channel.unary_unary(
-                '/node_service.NodeService/GetInferenceResult',
-                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.InferenceResult.FromString,
-                _registered_method=True)
-        self.CollectTopology = channel.unary_unary(
-                '/node_service.NodeService/CollectTopology',
-                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Topology.FromString,
-                _registered_method=True)
-        self.SendResult = channel.unary_unary(
-                '/node_service.NodeService/SendResult',
-                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
-        self.SendOpaqueStatus = channel.unary_unary(
-                '/node_service.NodeService/SendOpaqueStatus',
-                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
-        self.HealthCheck = channel.unary_unary(
-                '/node_service.NodeService/HealthCheck',
-                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
-                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
-                _registered_method=True)
+    self.SendPrompt = channel.unary_unary(
+      '/node_service.NodeService/SendPrompt',
+      request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Tensor.FromString,
+      _registered_method=True
+    )
+    self.SendTensor = channel.unary_unary(
+      '/node_service.NodeService/SendTensor',
+      request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Tensor.FromString,
+      _registered_method=True
+    )
+    self.GetInferenceResult = channel.unary_unary(
+      '/node_service.NodeService/GetInferenceResult',
+      request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+      response_deserializer=node__service__pb2.InferenceResult.FromString,
+      _registered_method=True
+    )
+    self.CollectTopology = channel.unary_unary(
+      '/node_service.NodeService/CollectTopology',
+      request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Topology.FromString,
+      _registered_method=True
+    )
+    self.SendResult = channel.unary_unary(
+      '/node_service.NodeService/SendResult',
+      request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Empty.FromString,
+      _registered_method=True
+    )
+    self.SendOpaqueStatus = channel.unary_unary(
+      '/node_service.NodeService/SendOpaqueStatus',
+      request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Empty.FromString,
+      _registered_method=True
+    )
+    self.HealthCheck = channel.unary_unary(
+      '/node_service.NodeService/HealthCheck',
+      request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
+      response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
+      _registered_method=True
+    )
 
 
 class NodeServiceServicer(object):
+  """Missing associated documentation comment in .proto file."""
+  def SendPrompt(self, request, context):
     """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def SendPrompt(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def SendTensor(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendTensor(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def GetInferenceResult(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def GetInferenceResult(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def CollectTopology(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def CollectTopology(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def SendResult(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendResult(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def SendOpaqueStatus(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendOpaqueStatus(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def HealthCheck(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def HealthCheck(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
 def add_NodeServiceServicer_to_server(servicer, server):
-    rpc_method_handlers = {
-            'SendPrompt': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendPrompt,
-                    request_deserializer=node__service__pb2.PromptRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
-            ),
-            'SendTensor': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendTensor,
-                    request_deserializer=node__service__pb2.TensorRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
-            ),
-            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.GetInferenceResult,
-                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
-                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
-            ),
-            'CollectTopology': grpc.unary_unary_rpc_method_handler(
-                    servicer.CollectTopology,
-                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
-                    response_serializer=node__service__pb2.Topology.SerializeToString,
-            ),
-            'SendResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendResult,
-                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
-            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendOpaqueStatus,
-                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
-            'HealthCheck': grpc.unary_unary_rpc_method_handler(
-                    servicer.HealthCheck,
-                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
-                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
-            ),
-    }
-    generic_handler = grpc.method_handlers_generic_handler(
-            'node_service.NodeService', rpc_method_handlers)
-    server.add_generic_rpc_handlers((generic_handler,))
-    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
+  rpc_method_handlers = {
+    'SendPrompt':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendPrompt,
+        request_deserializer=node__service__pb2.PromptRequest.FromString,
+        response_serializer=node__service__pb2.Tensor.SerializeToString,
+      ),
+    'SendTensor':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendTensor,
+        request_deserializer=node__service__pb2.TensorRequest.FromString,
+        response_serializer=node__service__pb2.Tensor.SerializeToString,
+      ),
+    'GetInferenceResult':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.GetInferenceResult,
+        request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+        response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+      ),
+    'CollectTopology':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.CollectTopology,
+        request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+        response_serializer=node__service__pb2.Topology.SerializeToString,
+      ),
+    'SendResult':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendResult,
+        request_deserializer=node__service__pb2.SendResultRequest.FromString,
+        response_serializer=node__service__pb2.Empty.SerializeToString,
+      ),
+    'SendOpaqueStatus':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendOpaqueStatus,
+        request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+        response_serializer=node__service__pb2.Empty.SerializeToString,
+      ),
+    'HealthCheck':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.HealthCheck,
+        request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
+        response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
+      ),
+  }
+  generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers)
+  server.add_generic_rpc_handlers((generic_handler,))
+  server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
 
 
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
 class NodeService(object):
-    """Missing associated documentation comment in .proto file."""
-
-    @staticmethod
-    def SendPrompt(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendPrompt',
-            node__service__pb2.PromptRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  """Missing associated documentation comment in .proto file."""
+  @staticmethod
+  def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendPrompt',
+      node__service__pb2.PromptRequest.SerializeToString,
+      node__service__pb2.Tensor.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def SendTensor(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendTensor',
-            node__service__pb2.TensorRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendTensor(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendTensor',
+      node__service__pb2.TensorRequest.SerializeToString,
+      node__service__pb2.Tensor.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def GetInferenceResult(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/GetInferenceResult',
-            node__service__pb2.GetInferenceResultRequest.SerializeToString,
-            node__service__pb2.InferenceResult.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def GetInferenceResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/GetInferenceResult',
+      node__service__pb2.GetInferenceResultRequest.SerializeToString,
+      node__service__pb2.InferenceResult.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def CollectTopology(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/CollectTopology',
-            node__service__pb2.CollectTopologyRequest.SerializeToString,
-            node__service__pb2.Topology.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def CollectTopology(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/CollectTopology',
+      node__service__pb2.CollectTopologyRequest.SerializeToString,
+      node__service__pb2.Topology.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def SendResult(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendResult',
-            node__service__pb2.SendResultRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendResult',
+      node__service__pb2.SendResultRequest.SerializeToString,
+      node__service__pb2.Empty.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def SendOpaqueStatus(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendOpaqueStatus',
-            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendOpaqueStatus(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendOpaqueStatus',
+      node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+      node__service__pb2.Empty.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def HealthCheck(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/HealthCheck',
-            node__service__pb2.HealthCheckRequest.SerializeToString,
-            node__service__pb2.HealthCheckResponse.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def HealthCheck(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/HealthCheck',
+      node__service__pb2.HealthCheckRequest.SerializeToString,
+      node__service__pb2.HealthCheckResponse.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )

+ 9 - 7
exo/networking/manual/manual_discovery.py

@@ -19,7 +19,9 @@ class ManualDiscovery(Discovery):
     self.create_peer_handle = create_peer_handle
 
     if node_id not in self.topology.peers:
-      raise ValueError(f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}")
+      raise ValueError(
+        f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
+      )
 
     self.listen_task = None
 
@@ -42,7 +44,6 @@ class ManualDiscovery(Discovery):
     if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}")
     return list(self.known_peers.values())
 
-
   async def task_find_peers_from_config(self):
     if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
     while True:
@@ -52,18 +53,19 @@ class ManualDiscovery(Discovery):
           peer = self.known_peers.get(peer_id)
           if not peer:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
-            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)  
+            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
           is_healthy = await peer.health_check()
           if is_healthy:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
             self.known_peers[peer_id] = peer
           else:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
-            try: del self.known_peers[peer_id]
-            except KeyError: pass
+            try:
+              del self.known_peers[peer_id]
+            except KeyError:
+              pass
         except Exception as e:
-            if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
+          if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
       await asyncio.sleep(1.0)
 
       if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
-

+ 0 - 1
exo/networking/manual/network_topology_config.py

@@ -17,7 +17,6 @@ class NetworkTopology(BaseModel):
   """
   node_id to PeerConfig. The node_id is used to identify the peer in the discovery process. The node that this is running from should be included in this dict.
   """
-
   @classmethod
   def from_path(cls, path: str) -> "NetworkTopology":
     try:

+ 1 - 0
exo/networking/peer_handle.py

@@ -5,6 +5,7 @@ from exo.inference.shard import Shard
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.topology.topology import Topology
 
+
 class PeerHandle(ABC):
   @abstractmethod
   def id(self) -> str:

+ 11 - 11
exo/networking/tailscale/tailscale_discovery.py

@@ -8,6 +8,7 @@ from exo.topology.device_capabilities import DeviceCapabilities, device_capabili
 from exo.helpers import DEBUG, DEBUG_DISCOVERY
 from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices, Device
 
+
 class TailscaleDiscovery(Discovery):
   def __init__(
     self,
@@ -69,14 +70,11 @@ class TailscaleDiscovery(Discovery):
         devices: dict[str, Device] = await get_tailscale_devices(self.tailscale_api_key, self.tailnet)
         current_time = time.time()
 
-        active_devices = {
-          name: device for name, device in devices.items()
-          if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30
-        }
+        active_devices = {name: device for name, device in devices.items() if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30}
 
         if DEBUG_DISCOVERY >= 4: print(f"Found tailscale devices: {devices}")
         if DEBUG_DISCOVERY >= 2: print(f"Active tailscale devices: {len(active_devices)}/{len(devices)}")
-        if DEBUG_DISCOVERY >= 2: print("Time since last seen tailscale devices", [(current_time  - device.last_seen.timestamp()) for device in devices.values()])
+        if DEBUG_DISCOVERY >= 2: print("Time since last seen tailscale devices", [(current_time - device.last_seen.timestamp()) for device in devices.values()])
 
         for device in active_devices.values():
           if device.name == self.node_id: continue
@@ -141,7 +139,13 @@ class TailscaleDiscovery(Discovery):
         for peer_id, should_remove in zip(peer_ids, results):
           if should_remove: peers_to_remove.append(peer_id)
 
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", { peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}" for peer_handle, connected_at, last_seen in self.known_peers.values() })
+        if DEBUG_DISCOVERY >= 2:
+          print(
+            "Peer statuses:", {
+              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}"
+              for peer_handle, connected_at, last_seen in self.known_peers.values()
+            }
+          )
 
         for peer_id in peers_to_remove:
           if peer_id in self.known_peers:
@@ -164,9 +168,5 @@ class TailscaleDiscovery(Discovery):
       if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
       return True
 
-    should_remove = (
-      (not is_connected and current_time - connected_at > self.discovery_timeout) or
-      (current_time - last_seen > self.discovery_timeout) or
-      (not health_ok)
-    )
+    should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
     return should_remove

+ 15 - 26
exo/networking/tailscale/tailscale_helpers.py

@@ -7,6 +7,7 @@ from exo.helpers import DEBUG_DISCOVERY
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from datetime import datetime, timezone
 
+
 class Device:
   def __init__(self, device_id: str, name: str, addresses: List[str], last_seen: Optional[datetime] = None):
     self.device_id = device_id
@@ -16,12 +17,7 @@ class Device:
 
   @classmethod
   def from_dict(cls, data: Dict[str, Any]) -> 'Device':
-    return cls(
-      device_id=data.get('id', ''),
-      name=data.get('name', ''),
-      addresses=data.get('addresses', []),
-      last_seen=cls.parse_datetime(data.get('lastSeen'))
-    )
+    return cls(device_id=data.get('id', ''), name=data.get('name', ''), addresses=data.get('addresses', []), last_seen=cls.parse_datetime(data.get('lastSeen')))
 
   @staticmethod
   def parse_datetime(date_string: Optional[str]) -> Optional[datetime]:
@@ -29,13 +25,10 @@ class Device:
       return None
     return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
 
+
 async def get_device_id() -> str:
   try:
-    process = await asyncio.create_subprocess_exec(
-      'tailscale', 'status', '--json',
-      stdout=asyncio.subprocess.PIPE,
-      stderr=asyncio.subprocess.PIPE
-    )
+    process = await asyncio.create_subprocess_exec('tailscale', 'status', '--json', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
     stdout, stderr = await process.communicate()
     if process.returncode != 0:
       raise Exception(f"Command failed with exit code {process.returncode}: {stderr.decode().strip()}.")
@@ -45,22 +38,16 @@ async def get_device_id() -> str:
   except Exception as e:
     raise Exception(f"{str(e)} Do you have the tailscale cli installed? See: https://tailscale.com/kb/1080/cli")
 
+
 async def update_device_attributes(device_id: str, api_key: str, node_id: str, node_port: int, device_capabilities: DeviceCapabilities):
   async with aiohttp.ClientSession() as session:
     base_url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
-    headers = {
-      'Authorization': f'Bearer {api_key}',
-      'Content-Type': 'application/json'
-    }
+    headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}
 
     attributes = {
-      "custom:exo_node_id": node_id.replace('-', '_'),
-      "custom:exo_node_port": node_port,
-      "custom:exo_device_capability_chip": sanitize_attribute(device_capabilities.chip),
-      "custom:exo_device_capability_model": sanitize_attribute(device_capabilities.model),
-      "custom:exo_device_capability_memory": str(device_capabilities.memory),
-      "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16),
-      "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
+      "custom:exo_node_id": node_id.replace('-', '_'), "custom:exo_node_port": node_port, "custom:exo_device_capability_chip": sanitize_attribute(device_capabilities.chip),
+      "custom:exo_device_capability_model": sanitize_attribute(device_capabilities.model), "custom:exo_device_capability_memory": str(device_capabilities.memory),
+      "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16), "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
       "custom:exo_device_capability_flops_int8": str(device_capabilities.flops.int8)
     }
 
@@ -73,12 +60,11 @@ async def update_device_attributes(device_id: str, api_key: str, node_id: str, n
         else:
           print(f"Failed to update device posture attribute {attr_name}: {response.status} {await response.text()}")
 
+
 async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int, DeviceCapabilities]:
   async with aiohttp.ClientSession() as session:
     url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
-    headers = {
-      'Authorization': f'Bearer {api_key}'
-    }
+    headers = {'Authorization': f'Bearer {api_key}'}
     async with session.get(url, headers=headers) as response:
       if response.status == 200:
         data = await response.json()
@@ -100,6 +86,7 @@ async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int,
         print(f"Failed to fetch posture attributes for {device_id}: {response.status}")
         return "", 0, DeviceCapabilities(model="", chip="", memory=0, flops=DeviceFlops(fp16=0, fp32=0, int8=0))
 
+
 def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
   result = {}
   prefix = "custom:exo_"
@@ -112,12 +99,14 @@ def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
         result[attr_name] = float(value)
   return result
 
+
 def sanitize_attribute(value: str) -> str:
   # Replace invalid characters with underscores
   sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value)
   # Truncate to 50 characters
   return sanitized_value[:50]
 
+
 async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]:
   async with aiohttp.ClientSession() as session:
     url = f"https://api.tailscale.com/api/v2/tailnet/{tailnet}/devices"
@@ -133,4 +122,4 @@ async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]
         device = Device.from_dict(device_data)
         devices[device.name] = device
 
-      return devices
+      return devices
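
A small standalone sketch of the attribute sanitisation used throughout this file. The regex and the 50-character cap are copied from sanitize_attribute above; the sample chip/model strings are made up, and the custom:exo_ keys are the same ones parse_device_attributes keys off when reading attributes back.

import re

def sanitize_attribute(value: str) -> str:
  # Same rule as above: anything outside [a-zA-Z0-9_.] becomes "_", capped at 50 chars.
  sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value)
  return sanitized_value[:50]

# Hypothetical device strings; real values come from DeviceCapabilities.
chip = sanitize_attribute("Apple M2 Max")
model = sanitize_attribute("MacBook Pro (16-inch, 2023)")
print({"custom:exo_device_capability_chip": chip, "custom:exo_device_capability_model": model})
# {'custom:exo_device_capability_chip': 'Apple_M2_Max', 'custom:exo_device_capability_model': 'MacBook_Pro__16_inch__2023_'}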

+ 2 - 0
exo/networking/tailscale/test_tailscale_discovery.py

@@ -5,6 +5,7 @@ from unittest import mock
 from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
 from exo.networking.peer_handle import PeerHandle
 
+
 class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
     self.tailscale_api_key = os.environ.get("TAILSCALE_API_KEY", "")
@@ -37,5 +38,6 @@ class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
     # Check if discovered peers are instances of GRPCPeerHandle
     print(peers)
 
+
 if __name__ == '__main__':
   unittest.main()

+ 14 - 15
exo/networking/udp/udp_discovery.py

@@ -9,6 +9,7 @@ from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses
 
+
 class ListenProtocol(asyncio.DatagramProtocol):
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
     super().__init__()
@@ -90,17 +91,13 @@ class UDPDiscovery(Discovery):
           "node_id": self.node_id,
           "grpc_port": self.node_port,
           "device_capabilities": self.device_capabilities.to_dict(),
-          "priority": 1, # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+          "priority": 1,  # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
         })
         if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr}): {message}")
 
         transport = None
         try:
-          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
-            lambda: BroadcastProtocol(message, self.broadcast_port),
-            local_addr=(addr, 0),
-            family=socket.AF_INET
-          )
+          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
           if DEBUG_DISCOVERY >= 3:
             print(f"Broadcasting presence at ({addr})")
         except Exception as e:
@@ -145,7 +142,8 @@ class UDPDiscovery(Discovery):
         if peer_id in self.known_peers:
           existing_peer_prio = self.known_peers[peer_id][3]
           if existing_peer_prio >= peer_prio:
-            if DEBUG >= 1: print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
+            if DEBUG >= 1:
+              print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
             return
         new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
         if not await new_peer_handle.health_check():
@@ -161,8 +159,7 @@ class UDPDiscovery(Discovery):
         if peer_id in self.known_peers: self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)
 
   async def task_listen_for_peers(self):
-    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
-                                                            local_addr=("0.0.0.0", self.listen_port))
+    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port))
     if DEBUG_DISCOVERY >= 2: print("Started listen task")
 
   async def task_cleanup_peers(self):
@@ -177,7 +174,13 @@ class UDPDiscovery(Discovery):
         for peer_id, should_remove in zip(peer_ids, results):
           if should_remove: peers_to_remove.append(peer_id)
 
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", { peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}, prio={prio}" for peer_handle, connected_at, last_seen, prio in self.known_peers.values() })
+        if DEBUG_DISCOVERY >= 2:
+          print(
+            "Peer statuses:", {
+              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}, prio={prio}"
+              for peer_handle, connected_at, last_seen, prio in self.known_peers.values()
+            }
+          )
 
         for peer_id in peers_to_remove:
           if peer_id in self.known_peers:
@@ -200,9 +203,5 @@ class UDPDiscovery(Discovery):
       if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
       return True
 
-    should_remove = (
-      (not is_connected and current_time - connected_at > self.discovery_timeout) or
-      (current_time - last_seen > self.discovery_timeout) or
-      (not health_ok)
-    )
+    should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
     return should_remove
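
For the priority handling touched in the on_listen_message hunk above, a tiny sketch of the comparison. The interface names in the comments are illustrative, since every interface currently broadcasts priority 1.

def should_ignore_peer(existing_peer_prio: int, peer_prio: int) -> bool:
  # Mirrors the check above: keep the peer we already know about whenever its
  # recorded priority is greater than or equal to the newly announced one.
  return existing_peer_prio >= peer_prio

print(should_ignore_peer(1, 1))  # True: same priority, the existing entry wins
print(should_ignore_peer(1, 2))  # False: a hypothetical higher-priority link (e.g. Thunderbolt) would replace it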

+ 14 - 26
exo/orchestration/standard_node.py

@@ -18,6 +18,7 @@ from exo.download.hf.hf_helpers import RepoProgressEvent
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.download.hf.hf_shard_download import HFShardDownloader
 
+
 class StandardNode(Node):
   def __init__(
     self,
@@ -87,18 +88,14 @@ class StandardNode(Node):
   def get_supported_inference_engines(self):
     supported_engine_names = []
     if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
-        supported_engine_names.append('mlx')
-        supported_engine_names.append('tinygrad')
+      supported_engine_names.append('mlx')
+      supported_engine_names.append('tinygrad')
     else:
-        supported_engine_names.append('tinygrad')
+      supported_engine_names.append('tinygrad')
     return supported_engine_names
 
   async def broadcast_supported_engines(self, supported_engines_names: List[str]):
-    status_message = json.dumps({
-        "type": "supported_inference_engines",
-        "node_id": self.id,
-        "engines": supported_engines_names
-    })
+    status_message = json.dumps({"type": "supported_inference_engines", "node_id": self.id, "engines": supported_engines_names})
     await self.broadcast_opaque_status("", status_message)
 
   def get_topology_inference_engines(self) -> List[List[str]]:
@@ -311,20 +308,16 @@ class StandardNode(Node):
     next_peer_ids = {peer.id() for peer in next_peers}
     peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids]
     peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids]
-    peers_updated = [
-      peer for peer in next_peers
-      if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())
-    ]
-    peers_unchanged = [
-      peer for peer in next_peers
-      if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())
-    ]
+    peers_updated = [peer for peer in next_peers if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())]
+    peers_unchanged = [peer for peer in next_peers if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())]
     peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()]
     peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()]
 
     def _pretty(peers: List[PeerHandle]) -> List[str]:
       return [f"{peer.id()}@{peer.addr()}" for peer in peers]
-    if DEBUG >= 2: print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
+
+    if DEBUG >= 2:
+      print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
 
     async def disconnect_with_timeout(peer, timeout=5):
       try:
@@ -344,14 +337,8 @@ class StandardNode(Node):
         traceback.print_exc()
         return False
 
-    disconnect_results = await asyncio.gather(
-      *(disconnect_with_timeout(peer) for peer in peers_to_disconnect),
-      return_exceptions=True
-    )
-    connect_results = await asyncio.gather(
-      *(connect_with_timeout(peer) for peer in peers_to_connect),
-      return_exceptions=True
-    )
+    disconnect_results = await asyncio.gather(*(disconnect_with_timeout(peer) for peer in peers_to_disconnect), return_exceptions=True)
+    connect_results = await asyncio.gather(*(connect_with_timeout(peer) for peer in peers_to_connect), return_exceptions=True)
 
     successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True]
     failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False]
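
The gather calls collapsed above fan the (dis)connect attempts out concurrently and let one slow or failing peer time out without cancelling the rest. A rough sketch of that pattern with a dummy peer class and made-up delays; only the wait_for plus gather(return_exceptions=True) shape follows the diff.

import asyncio

class DummyPeer:
  def __init__(self, peer_id: str, delay: float):
    self.peer_id = peer_id
    self.delay = delay

  def id(self) -> str:
    return self.peer_id

  async def connect(self):
    await asyncio.sleep(self.delay)

async def connect_with_timeout(peer, timeout=5):
  try:
    await asyncio.wait_for(peer.connect(), timeout)
    return True
  except Exception as e:
    print(f"Error connecting peer {peer.id()}: {e}")
    return False

async def main():
  peers_to_connect = [DummyPeer("fast", 0.1), DummyPeer("slow", 10)]
  # return_exceptions=True keeps one bad peer from aborting the whole batch.
  connect_results = await asyncio.gather(*(connect_with_timeout(p, timeout=1) for p in peers_to_connect), return_exceptions=True)
  successful = [p for p, r in zip(peers_to_connect, connect_results) if r is True]
  failed = [p for p, r in zip(peers_to_connect, connect_results) if r is False]
  print([p.id() for p in successful], [p.id() for p in failed])  # ['fast'] ['slow']

asyncio.run(main())
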
@@ -375,7 +362,7 @@ class StandardNode(Node):
         self.inference_engine = get_inference_engine("tinygrad", self.shard_downloader)
       else:
         if DEBUG >= 1: print("All nodes can use mlx, using mlx for inference")
-        self.inference_engine = get_inference_engine("mlx", self.shard_downloader) 
+        self.inference_engine = get_inference_engine("mlx", self.shard_downloader)
 
   async def periodic_topology_collection(self, interval: int):
     while True:
@@ -465,6 +452,7 @@ class StandardNode(Node):
       except Exception as e:
         print(f"Error sending opaque status to {peer.id()}: {e}")
         traceback.print_exc()
+
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
     # in the case of opaque status, we also want to receive our own opaque statuses
     self.on_opaque_status.trigger_all(request_id, status)
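
The comprehensions collapsed in the update_peers hunk above partition the incoming peer list before deciding whom to connect or disconnect. A standalone sketch of that partitioning over plain (id, addr) tuples, used here as stand-ins for PeerHandle; the peer data is made up.

from typing import List, Tuple

Peer = Tuple[str, str]  # (peer_id, addr)

def partition_peers(current: List[Peer], nxt: List[Peer]):
  current_ids = {pid for pid, _ in current}
  next_ids = {pid for pid, _ in nxt}
  added = [p for p in nxt if p[0] not in current_ids]
  removed = [p for p in current if p[0] not in next_ids]
  # "updated" means the id is already known but at least one recorded address differs.
  updated = [p for p in nxt if p[0] in current_ids and any(addr != p[1] for pid, addr in current if pid == p[0])]
  unchanged = [p for p in nxt if p[0] in current_ids and all(addr == p[1] for pid, addr in current if pid == p[0])]
  return added, removed, updated, unchanged

# Hypothetical topology change: node-b moved address, node-c disappeared, node-d is new.
current = [("node-a", "10.0.0.1:9000"), ("node-b", "10.0.0.2:9000"), ("node-c", "10.0.0.3:9000")]
nxt = [("node-a", "10.0.0.1:9000"), ("node-b", "10.0.0.9:9000"), ("node-d", "10.0.0.4:9000")]
print(partition_peers(current, nxt))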

+ 44 - 41
exo/tinychat/update_deps.py

@@ -4,49 +4,52 @@ from bs4 import BeautifulSoup
 from urllib.parse import urljoin, urlparse
 import re
 
+
 def download_file(url, local_path):
-    response = requests.get(url)
-    if response.status_code == 200:
-        os.makedirs(os.path.dirname(local_path), exist_ok=True)
-        with open(local_path, 'wb') as f:
-            f.write(response.content)
-        print(f"Downloaded: {local_path}")
-    else:
-        print(response.status_code)
-        print(f"Failed to download: {url}")
+  response = requests.get(url)
+  if response.status_code == 200:
+    os.makedirs(os.path.dirname(local_path), exist_ok=True)
+    with open(local_path, 'wb') as f:
+      f.write(response.content)
+    print(f"Downloaded: {local_path}")
+  else:
+    print(response.status_code)
+    print(f"Failed to download: {url}")
+
 
 def update_html(html_content, base_url):
-    soup = BeautifulSoup(html_content, 'html.parser')
+  soup = BeautifulSoup(html_content, 'html.parser')
 
-    for tag in soup.find_all(['script', 'link']):
-        if tag.has_attr('src'):
-            url = tag['src']
-        elif tag.has_attr('href'):
-            url = tag['href']
-        else:
-            continue
+  for tag in soup.find_all(['script', 'link']):
+    if tag.has_attr('src'):
+      url = tag['src']
+    elif tag.has_attr('href'):
+      url = tag['href']
+    else:
+      continue
+
+    if url.startswith(('http://', 'https://')):
+      full_url = url
+    else:
+      full_url = urljoin(base_url, url)
 
-        if url.startswith(('http://', 'https://')):
-            full_url = url
-        else:
-            full_url = urljoin(base_url, url)
+    parsed_url = urlparse(full_url)
+    local_path = os.path.join('static', parsed_url.netloc, parsed_url.path.lstrip('/'))
 
-        parsed_url = urlparse(full_url)
-        local_path = os.path.join('static', parsed_url.netloc, parsed_url.path.lstrip('/'))
+    download_file(full_url, local_path)
 
-        download_file(full_url, local_path)
+    relative_path = os.path.relpath(local_path, '.')
+    if tag.name == 'script':
+      tag['src'] = "/" + relative_path
+    elif tag.name == 'link':
+      tag['href'] = "/" + relative_path
 
-        relative_path = os.path.relpath(local_path, '.')
-        if tag.name == 'script':
-            tag['src'] = "/" + relative_path
-        elif tag.name == 'link':
-            tag['href'] = "/" + relative_path
+  return str(soup)
 
-    return str(soup)
 
 # Read the HTML file
 with open('./index.html', 'r') as f:
-    html_content = f.read()
+  html_content = f.read()
 
 # Update HTML and download files
 # updated_html = update_html(html_content, 'https://example.com')
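
As a rough illustration of what update_html above does to each script/link reference, here is the URL-to-local-path mapping on its own. The sample CDN URL is hypothetical and nothing is downloaded; only the urljoin/urlparse logic follows the script.

import os
from urllib.parse import urljoin, urlparse

def local_path_for(url: str, base_url: str) -> str:
  # Absolute URLs pass through, relative ones are resolved against base_url,
  # then the asset is mirrored under ./static/<host>/<path>.
  full_url = url if url.startswith(('http://', 'https://')) else urljoin(base_url, url)
  parsed = urlparse(full_url)
  return os.path.join('static', parsed.netloc, parsed.path.lstrip('/'))

print(local_path_for("https://cdn.jsdelivr.net/npm/alpinejs@3/dist/cdn.min.js", "https://example.com"))
# static/cdn.jsdelivr.net/npm/alpinejs@3/dist/cdn.min.js
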
@@ -68,7 +71,7 @@ download_file(css_url, css_output_path)
 
 # Parse CSS file for font URLs
 with open(css_output_path, 'r', encoding='utf-8') as f:
-    css_content = f.read()
+  css_content = f.read()
 
 # Extract font URLs from the CSS content
 font_urls = re.findall(r'url\((.*?\.(?:woff2|ttf))\)', css_content)
@@ -77,14 +80,14 @@ print(f"Found {len(font_urls)} font URLs")
 
 # Download font files
 for font_url in font_urls:
-    font_url = font_url.strip('"\'')
-    if font_url.startswith('../'):
-        font_url = font_url[3:]
+  font_url = font_url.strip('"\'')
+  if font_url.startswith('../'):
+    font_url = font_url[3:]
 
-    # Use base_url instead of urljoin to keep the version number
-    full_url = base_url + font_url
-    relative_path = font_url
-    output_path = os.path.join(output_dir, relative_path)
-    download_file(full_url, output_path)
+  # Use base_url instead of urljoin to keep the version number
+  full_url = base_url + font_url
+  relative_path = font_url
+  output_path = os.path.join(output_dir, relative_path)
+  download_file(full_url, output_path)
 
-print("Download complete!")
+print("Download complete!")

Some files were not shown because too many files changed in this diff