reformat with yapf format.py

Alex Cheema, 8 months ago
parent
commit ea70c9fb76
48 changed files with 1873 additions and 1854 deletions
  1. examples/llama3_distributed.py (+38 -45)
  2. exo/api/chatgpt_api.py (+62 -59)
  3. exo/download/download_progress.py (+50 -64)
  4. exo/download/hf/hf_helpers.py (+334 -297)
  5. exo/download/hf/hf_shard_download.py (+58 -60)
  6. exo/download/shard_download.py (+10 -8)
  7. exo/helpers.py (+85 -74)
  8. exo/inference/debug_inference_engine.py (+5 -7)
  9. exo/inference/inference_engine.py (+2 -0)
  10. exo/inference/mlx/models/base.py (+1 -0)
  11. exo/inference/mlx/models/deepseek_v2.py (+4 -5)
  12. exo/inference/mlx/models/llama.py (+5 -3)
  13. exo/inference/mlx/models/llava.py (+521 -555)
  14. exo/inference/mlx/sharded_inference_engine.py (+1 -0)
  15. exo/inference/mlx/sharded_model.py (+4 -9)
  16. exo/inference/mlx/sharded_utils.py (+29 -25)
  17. exo/inference/mlx/test_sharded_llava.py (+6 -6)
  18. exo/inference/mlx/test_sharded_model.py (+2 -1)
  19. exo/inference/shard.py (+2 -4)
  20. exo/inference/test_inference_engine.py (+13 -11)
  21. exo/inference/tinygrad/inference.py (+8 -11)
  22. exo/inference/tinygrad/models/llama.py (+73 -34)
  23. exo/inference/tinygrad/tinygrad_helpers.py (+7 -2)
  24. exo/inference/tokenizers.py (+1 -0)
  25. exo/models.py (+25 -32)
  26. exo/networking/discovery.py (+1 -0)
  27. exo/networking/grpc/grpc_discovery.py (+11 -11)
  28. exo/networking/grpc/grpc_peer_handle.py (+1 -0)
  29. exo/networking/grpc/grpc_server.py (+9 -9)
  30. exo/networking/grpc/node_service_pb2.py (+0 -3)
  31. exo/networking/grpc/node_service_pb2_grpc.py (+229 -271)
  32. exo/networking/grpc/test_grpc_discovery.py (+1 -0)
  33. exo/networking/peer_handle.py (+1 -0)
  34. exo/networking/server.py (+1 -0)
  35. exo/orchestration/node.py (+1 -0)
  36. exo/orchestration/standard_node.py (+3 -0)
  37. exo/orchestration/test_node.py (+1 -0)
  38. exo/test_callbacks.py (+1 -0)
  39. exo/topology/partitioning_strategy.py (+1 -0)
  40. exo/topology/ring_memory_weighted_partitioning_strategy.py (+1 -0)
  41. exo/topology/test_device_capabilities.py (+1 -0)
  42. exo/topology/test_map_partitions.py (+1 -0)
  43. exo/topology/test_ring_memory_weighted_partitioning_strategy.py (+1 -0)
  44. exo/topology/topology.py (+1 -0)
  45. exo/viz/test_topology_viz.py (+49 -47)
  46. exo/viz/topology_viz.py (+64 -66)
  47. extra/download_hf.py (+35 -37)
  48. main.py (+113 -98)
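
Every hunk below is a whitespace-only change produced by the formatter. The format.py script named in the commit message is not part of this page, so the exact yapf settings are not visible here; the following is a minimal sketch of such a wrapper, with style values (pep8 as the base style, 2-space indents, a very wide column limit) inferred from the reformatted code rather than taken from the repository.

# format.py (hypothetical sketch, not the repository's actual script)
# Style values below are assumptions inferred from the diff.
import subprocess
import sys

YAPF_STYLE = "{based_on_style: pep8, indent_width: 2, column_limit: 200}"


def main() -> int:
  # Run yapf in-place and recursively over the given paths (default: repo root).
  targets = sys.argv[1:] or ["."]
  return subprocess.call(["yapf", "--in-place", "--recursive", "--style", YAPF_STYLE, *targets])


if __name__ == "__main__":
  sys.exit(main())

Running such a wrapper (or invoking yapf directly with the same --style string) yields diffs of the shape shown below.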

examples/llama3_distributed.py (+38 -45)

@@ -13,8 +13,8 @@ import argparse
 import uuid
 import uuid
 
 
 models = {
 models = {
-    "mlx-community/Meta-Llama-3-8B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "mlx-community/Meta-Llama-3-70B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80)
+  "mlx-community/Meta-Llama-3-8B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+  "mlx-community/Meta-Llama-3-70B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80)
 }
 }
 
 
 path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
 path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
@@ -29,60 +29,53 @@ tokenizer = load_tokenizer(model_path, tokenizer_config)
 #     "localhost:8080",
 #     "localhost:8080",
 #     DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
 #     DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
 # )
 # )
-peer2 = GRPCPeerHandle(
-    "node2",
-    "localhost:8081",
-    DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
-)
+peer2 = GRPCPeerHandle("node2", "localhost:8081", DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
 shard = models[path_or_hf_repo]
 shard = models[path_or_hf_repo]
 request_id = str(uuid.uuid4())
 request_id = str(uuid.uuid4())
 
 
+
 async def run_prompt(prompt: str):
 async def run_prompt(prompt: str):
-    if tokenizer.chat_template is None:
-        tokenizer.chat_template = tokenizer.default_chat_template
-    if (
-        hasattr(tokenizer, "apply_chat_template")
-        and tokenizer.chat_template is not None
-    ):
-        messages = [{"role": "user", "content": prompt}]
-        prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
+  if tokenizer.chat_template is None:
+    tokenizer.chat_template = tokenizer.default_chat_template
+  if (hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None):
+    messages = [{"role": "user", "content": prompt}]
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+  await peer2.connect()
 
 
-    await peer2.connect()
+  try:
+    await peer2.send_prompt(shard, prompt, request_id)
+  except Exception as e:
+    print(e)
 
 
+  import time
+  # poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user)
+  previous_length = 0
+  n_tokens = 0
+  start_time = time.perf_counter()
+  while True:
     try:
     try:
-        await peer2.send_prompt(shard, prompt, request_id)
+      result, is_finished = await peer2.get_inference_result(request_id)
     except Exception as e:
     except Exception as e:
-        print(e)
+      continue
+    await asyncio.sleep(0.1)
 
 
-    import time
-    # poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user)
-    previous_length = 0
-    n_tokens = 0
-    start_time = time.perf_counter()
-    while True:
-        try:
-            result, is_finished = await peer2.get_inference_result(request_id)
-        except Exception as e:
-            continue
-        await asyncio.sleep(0.1)
+    # Print the updated string in place
+    updated_string = tokenizer.decode(result)
+    n_tokens = len(result)
+    print(updated_string[previous_length:], end='', flush=True)
+    previous_length = len(updated_string)
 
 
-        # Print the updated string in place
-        updated_string = tokenizer.decode(result)
-        n_tokens = len(result)
-        print(updated_string[previous_length:], end='', flush=True)
-        previous_length = len(updated_string)
+    if is_finished:
+      print("\nDone")
+      break
+  end_time = time.perf_counter()
+  print(f"\nDone. Processed {n_tokens} tokens in {end_time - start_time:.2f} seconds ({n_tokens / (end_time - start_time):.2f} tokens/second)")
 
 
-        if is_finished:
-            print("\nDone")
-            break
-    end_time = time.perf_counter()
-    print(f"\nDone. Processed {n_tokens} tokens in {end_time - start_time:.2f} seconds ({n_tokens / (end_time - start_time):.2f} tokens/second)")
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run prompt")
-    parser.add_argument("--prompt", type=str, help="The prompt to run")
-    args = parser.parse_args()
+  parser = argparse.ArgumentParser(description="Run prompt")
+  parser.add_argument("--prompt", type=str, help="The prompt to run")
+  args = parser.parse_args()
 
 
-    asyncio.run(run_prompt(args.prompt))
+  asyncio.run(run_prompt(args.prompt))

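For reference, the reformatted example keeps the same CLI: it is run as python examples/llama3_distributed.py --prompt "...", and, as the hard-coded GRPCPeerHandle above shows, it expects an exo node already listening on localhost:8081.
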
exo/api/chatgpt_api.py (+62 -59)

@@ -16,29 +16,27 @@ from exo.orchestration import Node
 from exo.models import model_base_shards
 from exo.models import model_base_shards
 from typing import Callable
 from typing import Callable
 
 
+
 class Message:
 class Message:
-    def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
-        self.role = role
-        self.content = content
 
 
-    def to_dict(self):
-        return {
-            "role": self.role,
-            "content": self.content
-        }
+  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
+    self.role = role
+    self.content = content
+
+  def to_dict(self):
+    return {"role": self.role, "content": self.content}
+
 
 
 class ChatCompletionRequest:
 class ChatCompletionRequest:
-    def __init__(self, model: str, messages: List[Message], temperature: float):
-        self.model = model
-        self.messages = messages
-        self.temperature = temperature
-
-    def to_dict(self):
-        return {
-            "model": self.model,
-            "messages": [message.to_dict() for message in self.messages],
-            "temperature": self.temperature
-        }
+
+  def __init__(self, model: str, messages: List[Message], temperature: float):
+    self.model = model
+    self.messages = messages
+    self.temperature = temperature
+
+  def to_dict(self):
+    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
+
 
 
 def generate_completion(
 def generate_completion(
   chat_request: ChatCompletionRequest,
   chat_request: ChatCompletionRequest,
@@ -56,14 +54,12 @@ def generate_completion(
     "created": int(time.time()),
     "created": int(time.time()),
     "model": chat_request.model,
     "model": chat_request.model,
     "system_fingerprint": f"exo_{VERSION}",
     "system_fingerprint": f"exo_{VERSION}",
-    "choices": [
-      {
-        "index": 0,
-        "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
-        "logprobs": None,
-        "finish_reason": finish_reason,
-      }
-    ],
+    "choices": [{
+      "index": 0,
+      "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
+      "logprobs": None,
+      "finish_reason": finish_reason,
+    }],
   }
   }
 
 
   if not stream:
   if not stream:
@@ -86,37 +82,38 @@ def generate_completion(
 
 
 
 
 def remap_messages(messages: List[Message]) -> List[Message]:
 def remap_messages(messages: List[Message]) -> List[Message]:
-    remapped_messages = []
-    last_image = None
-    for message in messages:
-        if not isinstance(message.content, list):
-           remapped_messages.append(message)
-           continue
-
-        remapped_content = []
-        for content in message.content:
-            if isinstance(content, dict):
-                if content.get("type") in ["image_url", "image"]:
-                    image_url = content.get("image_url", {}).get("url") or content.get("image")
-                    if image_url:
-                        last_image = {"type": "image", "image": image_url}
-                        remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
-                else:
-                    remapped_content.append(content)
-            else:
-                remapped_content.append(content)
-        remapped_messages.append(Message(role=message.role, content=remapped_content))
-
-    if last_image:
-        # Replace the last image placeholder with the actual image content
-        for message in reversed(remapped_messages):
-            for i, content in enumerate(message.content):
-                if isinstance(content, dict):
-                  if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
-                      message.content[i] = last_image
-                      return remapped_messages
-
-    return remapped_messages
+  remapped_messages = []
+  last_image = None
+  for message in messages:
+    if not isinstance(message.content, list):
+      remapped_messages.append(message)
+      continue
+
+    remapped_content = []
+    for content in message.content:
+      if isinstance(content, dict):
+        if content.get("type") in ["image_url", "image"]:
+          image_url = content.get("image_url", {}).get("url") or content.get("image")
+          if image_url:
+            last_image = {"type": "image", "image": image_url}
+            remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
+        else:
+          remapped_content.append(content)
+      else:
+        remapped_content.append(content)
+    remapped_messages.append(Message(role=message.role, content=remapped_content))
+
+  if last_image:
+    # Replace the last image placeholder with the actual image content
+    for message in reversed(remapped_messages):
+      for i, content in enumerate(message.content):
+        if isinstance(content, dict):
+          if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
+            message.content[i] = last_image
+            return remapped_messages
+
+  return remapped_messages
+
 
 
 def build_prompt(tokenizer, _messages: List[Message]):
 def build_prompt(tokenizer, _messages: List[Message]):
   messages = remap_messages(_messages)
   messages = remap_messages(_messages)
@@ -149,13 +146,17 @@ def parse_chat_request(data: dict):
     data.get("temperature", 0.0),
     data.get("temperature", 0.0),
   )
   )
 
 
+
 class PromptSession:
 class PromptSession:
+
   def __init__(self, request_id: str, timestamp: int, prompt: str):
   def __init__(self, request_id: str, timestamp: int, prompt: str):
     self.request_id = request_id
     self.request_id = request_id
     self.timestamp = timestamp
     self.timestamp = timestamp
     self.prompt = prompt
     self.prompt = prompt
 
 
+
 class ChatGPTAPI:
 class ChatGPTAPI:
+
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
     self.node = node
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.inference_engine_classname = inference_engine_classname
@@ -182,6 +183,7 @@ class ChatGPTAPI:
     self.app.middlewares.append(self.log_request)
     self.app.middlewares.append(self.log_request)
 
 
   async def log_request(self, app, handler):
   async def log_request(self, app, handler):
+
     async def middleware(request):
     async def middleware(request):
       if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
       if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
       return await handler(request)
       return await handler(request)
@@ -268,7 +270,8 @@ class ChatGPTAPI:
           self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
           self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
           new_tokens = tokens[prev_last_tokens_len:]
           new_tokens = tokens[prev_last_tokens_len:]
           finish_reason = None
           finish_reason = None
-          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
+          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
+                                                                                                                             AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
           if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
           if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
             new_tokens = new_tokens[:-1]
             new_tokens = new_tokens[:-1]
             if is_finished:
             if is_finished:

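The reformatted remap_messages above is behaviour-preserving; as a small usage illustration (a sketch assuming exo is importable, with made-up message content):

from exo.api.chatgpt_api import Message, remap_messages

msgs = [Message("user", [
  {"type": "text", "text": "What is in this picture?"},
  {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
])]
# The image is first replaced by a text placeholder, then the most recent image is
# swapped back in at the placeholder's position.
print(remap_messages(msgs)[0].content)
# [{'type': 'text', 'text': 'What is in this picture?'}, {'type': 'image', 'image': 'https://example.com/cat.png'}]
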
exo/download/download_progress.py (+50 -64)

@@ -2,81 +2,67 @@ from typing import Dict, Callable, Coroutine, Any, Literal
 from dataclasses import dataclass
 from dataclasses import dataclass
 from datetime import timedelta
 from datetime import timedelta
 
 
+
 @dataclass
 @dataclass
 class RepoFileProgressEvent:
 class RepoFileProgressEvent:
-    repo_id: str
-    repo_revision: str
-    file_path: str
-    downloaded: int
-    downloaded_this_session: int
-    total: int
-    speed: int
-    eta: timedelta
-    status: Literal["not_started", "in_progress", "complete"]
+  repo_id: str
+  repo_revision: str
+  file_path: str
+  downloaded: int
+  downloaded_this_session: int
+  total: int
+  speed: int
+  eta: timedelta
+  status: Literal["not_started", "in_progress", "complete"]
+
+  def to_dict(self):
+    return {
+      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
+      "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
+    }
 
 
-    def to_dict(self):
-        return {
-            "repo_id": self.repo_id,
-            "repo_revision": self.repo_revision,
-            "file_path": self.file_path,
-            "downloaded": self.downloaded,
-            "downloaded_this_session": self.downloaded_this_session,
-            "total": self.total,
-            "speed": self.speed,
-            "eta": self.eta.total_seconds(),
-            "status": self.status
-        }
+  @classmethod
+  def from_dict(cls, data):
+    # Convert eta from seconds back to timedelta
+    if 'eta' in data:
+      data['eta'] = timedelta(seconds=data['eta'])
+    return cls(**data)
 
 
-    @classmethod
-    def from_dict(cls, data):
-        # Convert eta from seconds back to timedelta
-        if 'eta' in data:
-            data['eta'] = timedelta(seconds=data['eta'])
-        return cls(**data)
 
 
 @dataclass
 @dataclass
 class RepoProgressEvent:
 class RepoProgressEvent:
-    repo_id: str
-    repo_revision: str
-    completed_files: int
-    total_files: int
-    downloaded_bytes: int
-    downloaded_bytes_this_session: int
-    total_bytes: int
-    overall_speed: int
-    overall_eta: timedelta
-    file_progress: Dict[str, RepoFileProgressEvent]
-    status: Literal["not_started", "in_progress", "complete"]
+  repo_id: str
+  repo_revision: str
+  completed_files: int
+  total_files: int
+  downloaded_bytes: int
+  downloaded_bytes_this_session: int
+  total_bytes: int
+  overall_speed: int
+  overall_eta: timedelta
+  file_progress: Dict[str, RepoFileProgressEvent]
+  status: Literal["not_started", "in_progress", "complete"]
+
+  def to_dict(self):
+    return {
+      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
+      "downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
+      "file_progress": {k: v.to_dict()
+                        for k, v in self.file_progress.items()}, "status": self.status
+    }
 
 
-    def to_dict(self):
-        return {
-            "repo_id": self.repo_id,
-            "repo_revision": self.repo_revision,
-            "completed_files": self.completed_files,
-            "total_files": self.total_files,
-            "downloaded_bytes": self.downloaded_bytes,
-            "downloaded_bytes_this_session": self.downloaded_bytes_this_session,
-            "total_bytes": self.total_bytes,
-            "overall_speed": self.overall_speed,
-            "overall_eta": self.overall_eta.total_seconds(),
-            "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
-            "status": self.status
-        }
+  @classmethod
+  def from_dict(cls, data):
+    # Convert overall_eta from seconds back to timedelta
+    if 'overall_eta' in data:
+      data['overall_eta'] = timedelta(seconds=data['overall_eta'])
 
 
-    @classmethod
-    def from_dict(cls, data):
-        # Convert overall_eta from seconds back to timedelta
-        if 'overall_eta' in data:
-            data['overall_eta'] = timedelta(seconds=data['overall_eta'])
+    # Parse file_progress
+    if 'file_progress' in data:
+      data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}
 
 
-        # Parse file_progress
-        if 'file_progress' in data:
-            data['file_progress'] = {
-                k: RepoFileProgressEvent.from_dict(v)
-                for k, v in data['file_progress'].items()
-            }
+    return cls(**data)
 
 
-        return cls(**data)
 
 
 RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]]
 RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]]
 RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]]
 RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]]

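The to_dict/from_dict pair on RepoFileProgressEvent (and likewise on RepoProgressEvent) is symmetric, with eta serialised as seconds; a minimal round-trip illustration with made-up values, assuming exo is importable:

from datetime import timedelta
from exo.download.download_progress import RepoFileProgressEvent

event = RepoFileProgressEvent(
  "mlx-community/Meta-Llama-3-8B-Instruct-4bit", "main", "model.safetensors",
  1024, 512, 4096, 256, timedelta(seconds=12), "in_progress"
)
# Dataclass equality compares field-by-field, so a lossless round trip satisfies this.
assert RepoFileProgressEvent.from_dict(event.to_dict()) == event
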
exo/download/hf/hf_helpers.py (+334 -297)

@@ -16,282 +16,322 @@ import aiofiles
 from aiofiles import os as aios
 from aiofiles import os as aios
 
 
 T = TypeVar("T")
 T = TypeVar("T")
+
+
 def filter_repo_objects(
 def filter_repo_objects(
-    items: Iterable[T],
-    *,
-    allow_patterns: Optional[Union[List[str], str]] = None,
-    ignore_patterns: Optional[Union[List[str], str]] = None,
-    key: Optional[Callable[[T], str]] = None,
+  items: Iterable[T],
+  *,
+  allow_patterns: Optional[Union[List[str], str]] = None,
+  ignore_patterns: Optional[Union[List[str], str]] = None,
+  key: Optional[Callable[[T], str]] = None,
 ) -> Generator[T, None, None]:
 ) -> Generator[T, None, None]:
-    if isinstance(allow_patterns, str):
-        allow_patterns = [allow_patterns]
-    if isinstance(ignore_patterns, str):
-        ignore_patterns = [ignore_patterns]
-    if allow_patterns is not None:
-        allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
-    if ignore_patterns is not None:
-        ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
-
-    if key is None:
-        def _identity(item: T) -> str:
-            if isinstance(item, str):
-                return item
-            if isinstance(item, Path):
-                return str(item)
-            raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
-        key = _identity
-
-    for item in items:
-        path = key(item)
-        if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
-            continue
-        if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
-            continue
-        yield item
+  if isinstance(allow_patterns, str):
+    allow_patterns = [allow_patterns]
+  if isinstance(ignore_patterns, str):
+    ignore_patterns = [ignore_patterns]
+  if allow_patterns is not None:
+    allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
+  if ignore_patterns is not None:
+    ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
+
+  if key is None:
+
+    def _identity(item: T) -> str:
+      if isinstance(item, str):
+        return item
+      if isinstance(item, Path):
+        return str(item)
+      raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
+
+    key = _identity
+
+  for item in items:
+    path = key(item)
+    if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
+      continue
+    if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
+      continue
+    yield item
+
 
 
 def _add_wildcard_to_directories(pattern: str) -> str:
 def _add_wildcard_to_directories(pattern: str) -> str:
-    if pattern[-1] == "/":
-        return pattern + "*"
-    return pattern
+  if pattern[-1] == "/":
+    return pattern + "*"
+  return pattern
+
 
 
 def get_hf_home() -> Path:
 def get_hf_home() -> Path:
-    """Get the Hugging Face home directory."""
-    return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
+  """Get the Hugging Face home directory."""
+  return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
+
 
 
 async def get_hf_token():
 async def get_hf_token():
-    """Retrieve the Hugging Face token from the user's HF_HOME directory."""
-    token_path = get_hf_home() / "token"
-    if await aios.path.exists(token_path):
-        async with aiofiles.open(token_path, 'r') as f:
-            return (await f.read()).strip()
-    return None
+  """Retrieve the Hugging Face token from the user's HF_HOME directory."""
+  token_path = get_hf_home() / "token"
+  if await aios.path.exists(token_path):
+    async with aiofiles.open(token_path, 'r') as f:
+      return (await f.read()).strip()
+  return None
+
 
 
 async def get_auth_headers():
 async def get_auth_headers():
-    """Get authentication headers if a token is available."""
-    token = await get_hf_token()
-    if token:
-        return {"Authorization": f"Bearer {token}"}
-    return {}
+  """Get authentication headers if a token is available."""
+  token = await get_hf_token()
+  if token:
+    return {"Authorization": f"Bearer {token}"}
+  return {}
+
 
 
 def get_repo_root(repo_id: str) -> Path:
 def get_repo_root(repo_id: str) -> Path:
-    """Get the root directory for a given repo ID in the Hugging Face cache."""
-    sanitized_repo_id = repo_id.replace("/", "--")
-    return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
+  """Get the root directory for a given repo ID in the Hugging Face cache."""
+  sanitized_repo_id = repo_id.replace("/", "--")
+  return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
+
 
 
 async def fetch_file_list(session, repo_id, revision, path=""):
 async def fetch_file_list(session, repo_id, revision, path=""):
-    api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
-    url = f"{api_url}/{path}" if path else api_url
-
-    headers = await get_auth_headers()
-    async with session.get(url, headers=headers) as response:
-        if response.status == 200:
-            data = await response.json()
-            files = []
-            for item in data:
-                if item["type"] == "file":
-                    files.append({"path": item["path"], "size": item["size"]})
-                elif item["type"] == "directory":
-                    subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
-                    files.extend(subfiles)
-            return files
-        else:
-            raise Exception(f"Failed to fetch file list: {response.status}")
+  api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
+  url = f"{api_url}/{path}" if path else api_url
+
+  headers = await get_auth_headers()
+  async with session.get(url, headers=headers) as response:
+    if response.status == 200:
+      data = await response.json()
+      files = []
+      for item in data:
+        if item["type"] == "file":
+          files.append({"path": item["path"], "size": item["size"]})
+        elif item["type"] == "directory":
+          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
+          files.extend(subfiles)
+      return files
+    else:
+      raise Exception(f"Failed to fetch file list: {response.status}")
 
 
 
 
 @retry(
 @retry(
-    stop=stop_after_attempt(5),
-    wait=wait_exponential(multiplier=1, min=4, max=60),
-    retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
-    reraise=True
+  stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
 )
 )
-async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True):
-    base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
-    url = urljoin(base_url, file_path)
-    local_path = os.path.join(save_directory, file_path)
-
-    await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
-
-    # Check if file already exists and get its size
-    local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
-
-    headers = await get_auth_headers()
-    if use_range_request:
-        headers["Range"] = f"bytes={local_file_size}-"
-
-    async with session.get(url, headers=headers) as response:
-        total_size = int(response.headers.get('Content-Length', 0))
-        downloaded_size = local_file_size
-        downloaded_this_session = 0
-        mode = 'ab' if use_range_request else 'wb'
+async def download_file(
+  session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
+):
+  base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
+  url = urljoin(base_url, file_path)
+  local_path = os.path.join(save_directory, file_path)
+
+  await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
+
+  # Check if file already exists and get its size
+  local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
+
+  headers = await get_auth_headers()
+  if use_range_request:
+    headers["Range"] = f"bytes={local_file_size}-"
+
+  async with session.get(url, headers=headers) as response:
+    total_size = int(response.headers.get('Content-Length', 0))
+    downloaded_size = local_file_size
+    downloaded_this_session = 0
+    mode = 'ab' if use_range_request else 'wb'
+    if downloaded_size == total_size:
+      if DEBUG >= 2: print(f"File already downloaded: {file_path}")
+      if progress_callback:
+        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+      return
+
+    if response.status == 200:
+      # File doesn't support range requests or we're not using them, start from beginning
+      mode = 'wb'
+      downloaded_size = 0
+    elif response.status == 206:
+      # Partial content, resume download
+      content_range = response.headers.get('Content-Range', '')
+      try:
+        total_size = int(content_range.split('/')[-1])
+      except ValueError:
+        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
+    elif response.status == 416:
+      # Range not satisfiable, get the actual file size
+      content_range = response.headers.get('Content-Range', '')
+      try:
+        total_size = int(content_range.split('/')[-1])
         if downloaded_size == total_size:
         if downloaded_size == total_size:
-            if DEBUG >= 2: print(f"File already downloaded: {file_path}")
-            if progress_callback:
-                await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-            return
-
-        if response.status == 200:
-            # File doesn't support range requests or we're not using them, start from beginning
-            mode = 'wb'
-            downloaded_size = 0
-        elif response.status == 206:
-            # Partial content, resume download
-            content_range = response.headers.get('Content-Range', '')
-            try:
-                total_size = int(content_range.split('/')[-1])
-            except ValueError:
-                if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-                return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-        elif response.status == 416:
-            # Range not satisfiable, get the actual file size
-            content_range = response.headers.get('Content-Range', '')
-            try:
-                total_size = int(content_range.split('/')[-1])
-                if downloaded_size == total_size:
-                    if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
-                    if progress_callback:
-                        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-                    return
-            except ValueError:
-                if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-                return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-        else:
-            raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
-
-        if downloaded_size == total_size:
-            print(f"File already downloaded: {file_path}")
-            if progress_callback:
-                await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-            return
-
-        DOWNLOAD_CHUNK_SIZE = 32768
-        start_time = datetime.now()
-        async with aiofiles.open(local_path, mode) as f:
-            async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
-                await f.write(chunk)
-                downloaded_size += len(chunk)
-                downloaded_this_session += len(chunk)
-                if progress_callback and total_size:
-                    elapsed_time = (datetime.now() - start_time).total_seconds()
-                    speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
-                    remaining_size = total_size - downloaded_size
-                    eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
-                    status = "in_progress" if downloaded_size < total_size else "complete"
-                    if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
-                    await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
-        if DEBUG >= 2: print(f"Downloaded: {file_path}")
-
-async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, max_parallel_downloads: int = 4) -> Path:
-    repo_root = get_repo_root(repo_id)
-    refs_dir = repo_root / "refs"
-    snapshots_dir = repo_root / "snapshots"
-    cachedreqs_dir = repo_root / "cachedreqs"
-
-    # Ensure directories exist
-    await aios.makedirs(refs_dir, exist_ok=True)
-    await aios.makedirs(snapshots_dir, exist_ok=True)
-    await aios.makedirs(cachedreqs_dir, exist_ok=True)
-
-    # Check if we have a cached commit hash
-    refs_file = refs_dir / revision
-    if await aios.path.exists(refs_file):
-        async with aiofiles.open(refs_file, 'r') as f:
-            commit_hash = (await f.read()).strip()
-            if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}")
+          if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
+          if progress_callback:
+            await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+          return
+      except ValueError:
+        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
     else:
     else:
-        async with aiohttp.ClientSession() as session:
-            # Fetch the commit hash for the given revision
-            api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
-            headers = await get_auth_headers()
-            async with session.get(api_url, headers=headers) as response:
-                if response.status != 200:
-                    raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
-                revision_info = await response.json()
-                commit_hash = revision_info['sha']
-
-            # Cache the commit hash
-            async with aiofiles.open(refs_file, 'w') as f:
-                await f.write(commit_hash)
-
-    # Set up the snapshot directory
-    snapshot_dir = snapshots_dir / commit_hash
-    await aios.makedirs(snapshot_dir, exist_ok=True)
-
-    # Set up the cached file list directory
-    cached_file_list_dir = cachedreqs_dir / commit_hash
-    await aios.makedirs(cached_file_list_dir, exist_ok=True)
-    cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
-
+      raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
+
+    if downloaded_size == total_size:
+      print(f"File already downloaded: {file_path}")
+      if progress_callback:
+        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+      return
+
+    DOWNLOAD_CHUNK_SIZE = 32768
+    start_time = datetime.now()
+    async with aiofiles.open(local_path, mode) as f:
+      async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
+        await f.write(chunk)
+        downloaded_size += len(chunk)
+        downloaded_this_session += len(chunk)
+        if progress_callback and total_size:
+          elapsed_time = (datetime.now() - start_time).total_seconds()
+          speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
+          remaining_size = total_size - downloaded_size
+          eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
+          status = "in_progress" if downloaded_size < total_size else "complete"
+          if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
+          await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
+    if DEBUG >= 2: print(f"Downloaded: {file_path}")
+
+
+async def download_repo_files(
+  repo_id: str,
+  revision: str = "main",
+  progress_callback: Optional[RepoProgressCallback] = None,
+  allow_patterns: Optional[Union[List[str], str]] = None,
+  ignore_patterns: Optional[Union[List[str], str]] = None,
+  max_parallel_downloads: int = 4
+) -> Path:
+  repo_root = get_repo_root(repo_id)
+  refs_dir = repo_root / "refs"
+  snapshots_dir = repo_root / "snapshots"
+  cachedreqs_dir = repo_root / "cachedreqs"
+
+  # Ensure directories exist
+  await aios.makedirs(refs_dir, exist_ok=True)
+  await aios.makedirs(snapshots_dir, exist_ok=True)
+  await aios.makedirs(cachedreqs_dir, exist_ok=True)
+
+  # Check if we have a cached commit hash
+  refs_file = refs_dir / revision
+  if await aios.path.exists(refs_file):
+    async with aiofiles.open(refs_file, 'r') as f:
+      commit_hash = (await f.read()).strip()
+      if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}")
+  else:
     async with aiohttp.ClientSession() as session:
     async with aiohttp.ClientSession() as session:
-        # Check if we have a cached file list
-        if await aios.path.exists(cached_file_list_path):
-            async with aiofiles.open(cached_file_list_path, 'r') as f:
-                file_list = json.loads(await f.read())
-            if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
-        else:
-            file_list = await fetch_file_list(session, repo_id, revision)
-            # Cache the file list
-            async with aiofiles.open(cached_file_list_path, 'w') as f:
-                await f.write(json.dumps(file_list))
-            if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
-
-        filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
-        total_files = len(filtered_file_list)
-        total_bytes = sum(file["size"] for file in filtered_file_list)
-        file_progress: Dict[str, RepoFileProgressEvent] = {file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
-        start_time = datetime.now()
-
-        async def download_with_progress(file_info, progress_state):
-            local_path = snapshot_dir / file_info["path"]
-            if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
-                if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
-                progress_state['completed_files'] += 1
-                progress_state['downloaded_bytes'] += file_info["size"]
-                file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
-                if progress_callback:
-                    elapsed_time = (datetime.now() - start_time).total_seconds()
-                    overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
-                    remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-                    overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-                    status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-                    await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
-                return
-
-            async def file_progress_callback(event: RepoFileProgressEvent):
-                progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
-                progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
-                file_progress[event.file_path] = event
-                if progress_callback:
-                    elapsed_time = (datetime.now() - start_time).total_seconds()
-                    overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
-                    remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-                    overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-                    status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
-                    await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
-
-            await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
-            progress_state['completed_files'] += 1
-            file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
-            if progress_callback:
-                elapsed_time = (datetime.now() - start_time).total_seconds()
-                overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
-                remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-                overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-                status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-                await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
-
-        progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
-
-        semaphore = asyncio.Semaphore(max_parallel_downloads)
-        async def download_with_semaphore(file_info):
-            async with semaphore:
-                await download_with_progress(file_info, progress_state)
-        tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
-        await asyncio.gather(*tasks)
-
-    return snapshot_dir
+      # Fetch the commit hash for the given revision
+      api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
+      headers = await get_auth_headers()
+      async with session.get(api_url, headers=headers) as response:
+        if response.status != 200:
+          raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
+        revision_info = await response.json()
+        commit_hash = revision_info['sha']
+
+      # Cache the commit hash
+      async with aiofiles.open(refs_file, 'w') as f:
+        await f.write(commit_hash)
+
+  # Set up the snapshot directory
+  snapshot_dir = snapshots_dir / commit_hash
+  await aios.makedirs(snapshot_dir, exist_ok=True)
+
+  # Set up the cached file list directory
+  cached_file_list_dir = cachedreqs_dir / commit_hash
+  await aios.makedirs(cached_file_list_dir, exist_ok=True)
+  cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
+
+  async with aiohttp.ClientSession() as session:
+    # Check if we have a cached file list
+    if await aios.path.exists(cached_file_list_path):
+      async with aiofiles.open(cached_file_list_path, 'r') as f:
+        file_list = json.loads(await f.read())
+      if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
+    else:
+      file_list = await fetch_file_list(session, repo_id, revision)
+      # Cache the file list
+      async with aiofiles.open(cached_file_list_path, 'w') as f:
+        await f.write(json.dumps(file_list))
+      if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
+
+    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
+    total_files = len(filtered_file_list)
+    total_bytes = sum(file["size"] for file in filtered_file_list)
+    file_progress: Dict[str, RepoFileProgressEvent] = {
+      file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
+      for file in filtered_file_list
+    }
+    start_time = datetime.now()
+
+    async def download_with_progress(file_info, progress_state):
+      local_path = snapshot_dir / file_info["path"]
+      if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
+        if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
+        progress_state['completed_files'] += 1
+        progress_state['downloaded_bytes'] += file_info["size"]
+        file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
+        if progress_callback:
+          elapsed_time = (datetime.now() - start_time).total_seconds()
+          overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+          overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+          status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+          await progress_callback(
+            RepoProgressEvent(
+              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
+              overall_eta, file_progress, status
+            )
+          )
+        return
+
+      async def file_progress_callback(event: RepoFileProgressEvent):
+        progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
+        progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
+        file_progress[event.file_path] = event
+        if progress_callback:
+          elapsed_time = (datetime.now() - start_time).total_seconds()
+          overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+          overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+          status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
+          await progress_callback(
+            RepoProgressEvent(
+              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
+              overall_eta, file_progress, status
+            )
+          )
+
+      await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
+      progress_state['completed_files'] += 1
+      file_progress[
+        file_info["path"]
+      ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
+      if progress_callback:
+        elapsed_time = (datetime.now() - start_time).total_seconds()
+        overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+        remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+        overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+        status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+        await progress_callback(
+          RepoProgressEvent(
+            repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
+            overall_eta, file_progress, status
+          )
+        )
+
+    progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
+
+    semaphore = asyncio.Semaphore(max_parallel_downloads)
+
+    async def download_with_semaphore(file_info):
+      async with semaphore:
+        await download_with_progress(file_info, progress_state)
+
+    tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
+    await asyncio.gather(*tasks)
+
+  return snapshot_dir
+
 
 
 async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
 async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
-    """
+  """
     Retrieve the weight map from the model.safetensors.index.json file.
     Retrieve the weight map from the model.safetensors.index.json file.
 
 
     Args:
     Args:
@@ -302,55 +342,52 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[
         Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
         Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
     """
     """
 
 
-    # Download the index file
-    await download_repo_files(
-        repo_id=repo_id,
-        revision=revision,
-        allow_patterns="model.safetensors.index.json"
-    )
+  # Download the index file
+  await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
 
 
-    # Check if the file exists
-    repo_root = get_repo_root(repo_id)
-    snapshot_dir = repo_root / "snapshots"
-    index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
+  # Check if the file exists
+  repo_root = get_repo_root(repo_id)
+  snapshot_dir = repo_root / "snapshots"
+  index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
 
 
-    if index_file:
-        index_file_path = snapshot_dir / index_file
-        if await aios.path.exists(index_file_path):
-            async with aiofiles.open(index_file_path, 'r') as f:
-                index_data = json.loads(await f.read())
-            return index_data.get("weight_map")
+  if index_file:
+    index_file_path = snapshot_dir / index_file
+    if await aios.path.exists(index_file_path):
+      async with aiofiles.open(index_file_path, 'r') as f:
+        index_data = json.loads(await f.read())
+      return index_data.get("weight_map")
+
+  return None
 
 
-    return None
 
 
 def extract_layer_num(tensor_name: str) -> Optional[int]:
 def extract_layer_num(tensor_name: str) -> Optional[int]:
-    # This is a simple example and might need to be adjusted based on the actual naming convention
-    parts = tensor_name.split('.')
-    for part in parts:
-        if part.isdigit():
-            return int(part)
-    return None
+  # This is a simple example and might need to be adjusted based on the actual naming convention
+  parts = tensor_name.split('.')
+  for part in parts:
+    if part.isdigit():
+      return int(part)
+  return None
 
 
 
 
 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
-    default_patterns = [
-        "*.json",
-        "*.py",
-        "tokenizer.model",
-        "*.tiktoken",
-        "*.txt",
-    ]
-    shard_specific_patterns = []
-    if weight_map:
-        for tensor_name, filename in weight_map.items():
-            layer_num = extract_layer_num(tensor_name)
-            if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
-                shard_specific_patterns.append(filename)
-        sorted_file_names = sorted(weight_map.values())
-        if shard.is_first_layer():
-            shard_specific_patterns.append(sorted_file_names[0])
-        elif shard.is_last_layer():
-            shard_specific_patterns.append(sorted_file_names[-1])
-    else:
-        shard_specific_patterns = ["*.safetensors"]
-    return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates
+  default_patterns = [
+    "*.json",
+    "*.py",
+    "tokenizer.model",
+    "*.tiktoken",
+    "*.txt",
+  ]
+  shard_specific_patterns = []
+  if weight_map:
+    for tensor_name, filename in weight_map.items():
+      layer_num = extract_layer_num(tensor_name)
+      if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
+        shard_specific_patterns.append(filename)
+    sorted_file_names = sorted(weight_map.values())
+    if shard.is_first_layer():
+      shard_specific_patterns.append(sorted_file_names[0])
+    elif shard.is_last_layer():
+      shard_specific_patterns.append(sorted_file_names[-1])
+  else:
+    shard_specific_patterns = ["*.safetensors"]
+  return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates
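
A minimal sketch of how these helpers compose, using an invented two-entry weight map (a real map comes from get_weight_map reading model.safetensors.index.json):

  from exo.inference.shard import Shard
  from exo.download.hf.hf_helpers import get_allow_patterns

  # Invented weight map: layer 0 lives in the first file, layer 31 in the second.
  weight_map = {
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.31.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
  }
  shard = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=15, n_layers=32)
  # Only the first safetensors file falls inside this shard's layer range (it is also
  # added because the shard holds the first layer), so the second file is skipped.
  print(sorted(get_allow_patterns(weight_map, shard)))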

+ 58 - 60  exo/download/hf/hf_shard_download.py

@@ -8,72 +8,70 @@ from exo.download.download_progress import RepoProgressEvent
 from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
 from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
 from exo.helpers import AsyncCallbackSystem, DEBUG
 from exo.helpers import AsyncCallbackSystem, DEBUG
 
 
+
 class HFShardDownloader(ShardDownloader):
 class HFShardDownloader(ShardDownloader):
-    def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
-        self.quick_check = quick_check
-        self.max_parallel_downloads = max_parallel_downloads
-        self.active_downloads: Dict[Shard, asyncio.Task] = {}
-        self.completed_downloads: Dict[Shard, Path] = {}
-        self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
 
 
-    async def ensure_shard(self, shard: Shard) -> Path:
-        if shard in self.completed_downloads:
-            return self.completed_downloads[shard]
-        if self.quick_check:
-            repo_root = get_repo_root(shard.model_id)
-            snapshots_dir = repo_root / "snapshots"
-            if snapshots_dir.exists():
-                most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
-                return most_recent_dir
+  def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
+    self.quick_check = quick_check
+    self.max_parallel_downloads = max_parallel_downloads
+    self.active_downloads: Dict[Shard, asyncio.Task] = {}
+    self.completed_downloads: Dict[Shard, Path] = {}
+    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+
+  async def ensure_shard(self, shard: Shard) -> Path:
+    if shard in self.completed_downloads:
+      return self.completed_downloads[shard]
+    if self.quick_check:
+      repo_root = get_repo_root(shard.model_id)
+      snapshots_dir = repo_root / "snapshots"
+      if snapshots_dir.exists():
+        most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
+        return most_recent_dir
+
+    # If a download on this shard is already in progress, keep that one
+    for active_shard in self.active_downloads:
+      if active_shard == shard:
+        if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
+        return await self.active_downloads[shard]
 
 
-        # If a download on this shard is already in progress, keep that one
-        for active_shard in self.active_downloads:
-            if active_shard == shard:
-                if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
-                return await self.active_downloads[shard]
+    # Cancel any downloads for this model_id on a different shard
+    existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
+    for active_shard in existing_active_shards:
+      if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
+      task = self.active_downloads[active_shard]
+      task.cancel()
+      try:
+        await task
+      except asyncio.CancelledError:
+        pass  # This is expected when cancelling a task
+      except Exception as e:
+        if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
+        traceback.print_exc()
+    self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
 
 
-        # Cancel any downloads for this model_id on a different shard
-        existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
-        for active_shard in existing_active_shards:
-            if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
-            task = self.active_downloads[active_shard]
-            task.cancel()
-            try:
-                await task
-            except asyncio.CancelledError:
-                pass  # This is expected when cancelling a task
-            except Exception as e:
-                if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
-                traceback.print_exc()
-        self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
+    # Start new download
+    download_task = asyncio.create_task(self._download_shard(shard))
+    self.active_downloads[shard] = download_task
+    try:
+      path = await download_task
+      self.completed_downloads[shard] = path
+      return path
+    finally:
+      # Ensure the task is removed even if an exception occurs
+      print(f"Removing download task for {shard}: {shard in self.active_downloads}")
+      if shard in self.active_downloads:
+        self.active_downloads.pop(shard)
 
 
-        # Start new download
-        download_task = asyncio.create_task(self._download_shard(shard))
-        self.active_downloads[shard] = download_task
-        try:
-            path = await download_task
-            self.completed_downloads[shard] = path
-            return path
-        finally:
-            # Ensure the task is removed even if an exception occurs
-            print(f"Removing download task for {shard}: {shard in self.active_downloads}")
-            if shard in self.active_downloads:
-                self.active_downloads.pop(shard)
+  async def _download_shard(self, shard: Shard) -> Path:
 
 
-    async def _download_shard(self, shard: Shard) -> Path:
-        async def wrapped_progress_callback(event: RepoProgressEvent):
-            self._on_progress.trigger_all(shard, event)
+    async def wrapped_progress_callback(event: RepoProgressEvent):
+      self._on_progress.trigger_all(shard, event)
 
 
-        weight_map = await get_weight_map(shard.model_id)
-        allow_patterns = get_allow_patterns(weight_map, shard)
+    weight_map = await get_weight_map(shard.model_id)
+    allow_patterns = get_allow_patterns(weight_map, shard)
 
 
-        return await download_repo_files(
-            repo_id=shard.model_id,
-            progress_callback=wrapped_progress_callback,
-            allow_patterns=allow_patterns,
-            max_parallel_downloads=self.max_parallel_downloads
-        )
+    return await download_repo_files(repo_id=shard.model_id, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
 
 
-    @property
-    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
-        return self._on_progress
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self._on_progress
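
A hedged sketch of driving HFShardDownloader end to end; the model id is only an example and network access is assumed:

  import asyncio
  from exo.inference.shard import Shard
  from exo.download.hf.hf_shard_download import HFShardDownloader

  async def main():
    downloader = HFShardDownloader(quick_check=False, max_parallel_downloads=4)
    shard = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=15, n_layers=32)
    # ensure_shard reuses an identical in-flight download, cancels downloads of the
    # same model on a different shard, and returns the local snapshot path.
    path = await downloader.ensure_shard(shard)
    print(path)

  asyncio.run(main())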

+ 10 - 8  exo/download/shard_download.py

@@ -5,10 +5,12 @@ from exo.inference.shard import Shard
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import AsyncCallbackSystem
 from exo.helpers import AsyncCallbackSystem
 
 
+
 class ShardDownloader(ABC):
 class ShardDownloader(ABC):
-    @abstractmethod
-    async def ensure_shard(self, shard: Shard) -> Path:
-        """
+
+  @abstractmethod
+  async def ensure_shard(self, shard: Shard) -> Path:
+    """
         Ensures that the shard is downloaded.
         Ensures that the shard is downloaded.
         Does not allow multiple overlapping downloads at once.
         Does not allow multiple overlapping downloads at once.
         If you try to download a Shard which overlaps a Shard that is already being downloaded,
         If you try to download a Shard which overlaps a Shard that is already being downloaded,
@@ -17,9 +19,9 @@ class ShardDownloader(ABC):
         Args:
         Args:
             shard (Shard): The shard to download.
             shard (Shard): The shard to download.
         """
         """
-        pass
+    pass
 
 
-    @property
-    @abstractmethod
-    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
-        pass
+  @property
+  @abstractmethod
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    pass
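
For tests, the abstract interface can be satisfied without any network I/O; a sketch of a local-only downloader (the directory layout here is an assumption, not part of exo):

  from pathlib import Path
  from typing import Tuple
  from exo.inference.shard import Shard
  from exo.download.download_progress import RepoProgressEvent
  from exo.download.shard_download import ShardDownloader
  from exo.helpers import AsyncCallbackSystem

  class LocalShardDownloader(ShardDownloader):
    """Serves shards from a pre-populated local directory (hypothetical helper)."""

    def __init__(self, root: Path):
      self.root = root
      self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

    async def ensure_shard(self, shard: Shard) -> Path:
      # Nothing to download; just point at the directory for this model.
      return self.root / shard.model_id.replace("/", "--")

    @property
    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
      return self._on_progress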

+ 85 - 74  exo/helpers.py

@@ -20,6 +20,7 @@ exo_text = r"""
  \___/_/\_\___/ 
  \___/_/\_\___/ 
     """
     """
 
 
+
 def get_system_info():
 def get_system_info():
   if psutil.MACOS:
   if psutil.MACOS:
     if platform.machine() == "arm64":
     if platform.machine() == "arm64":
@@ -87,7 +88,10 @@ def terminal_link(uri, label=None):
 
 
 T = TypeVar("T")
 T = TypeVar("T")
 K = TypeVar("K")
 K = TypeVar("K")
+
+
 class AsyncCallback(Generic[T]):
 class AsyncCallback(Generic[T]):
+
   def __init__(self) -> None:
   def __init__(self) -> None:
     self.condition: asyncio.Condition = asyncio.Condition()
     self.condition: asyncio.Condition = asyncio.Condition()
     self.result: Optional[Tuple[T, ...]] = None
     self.result: Optional[Tuple[T, ...]] = None
@@ -95,9 +99,7 @@ class AsyncCallback(Generic[T]):
 
 
   async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]:
   async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]:
     async with self.condition:
     async with self.condition:
-      await asyncio.wait_for(
-        self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout
-      )
+      await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout)
       assert self.result is not None  # for type checking
       assert self.result is not None  # for type checking
       return self.result
       return self.result
 
 
@@ -116,6 +118,7 @@ class AsyncCallback(Generic[T]):
 
 
 
 
 class AsyncCallbackSystem(Generic[K, T]):
 class AsyncCallbackSystem(Generic[K, T]):
+
   def __init__(self) -> None:
   def __init__(self) -> None:
     self.callbacks: Dict[K, AsyncCallback[T]] = {}
     self.callbacks: Dict[K, AsyncCallback[T]] = {}
 
 
@@ -139,89 +142,97 @@ class AsyncCallbackSystem(Generic[K, T]):
 
 
 K = TypeVar('K', bound=str)
 K = TypeVar('K', bound=str)
 V = TypeVar('V')
 V = TypeVar('V')
+
+
 class PrefixDict(Generic[K, V]):
 class PrefixDict(Generic[K, V]):
-    def __init__(self):
-        self.items: Dict[K, V] = {}
 
 
-    def add(self, key: K, value: V) -> None:
-        self.items[key] = value
+  def __init__(self):
+    self.items: Dict[K, V] = {}
+
+  def add(self, key: K, value: V) -> None:
+    self.items[key] = value
 
 
-    def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
-        return [(key, value) for key, value in self.items.items() if argument.startswith(key)]
+  def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
+    return [(key, value) for key, value in self.items.items() if argument.startswith(key)]
 
 
-    def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
-        matches = self.find_prefix(argument)
-        if len(matches) == 0:
-            return None
+  def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
+    matches = self.find_prefix(argument)
+    if len(matches) == 0:
+      return None
+
+    return max(matches, key=lambda x: len(x[0]))
 
 
-        return max(matches, key=lambda x: len(x[0]))
 
 
 def is_valid_uuid(val):
 def is_valid_uuid(val):
-    try:
-        uuid.UUID(str(val))
-        return True
-    except ValueError:
-        return False
+  try:
+    uuid.UUID(str(val))
+    return True
+  except ValueError:
+    return False
+
 
 
 def get_or_create_node_id():
 def get_or_create_node_id():
-    NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__))) / ".exo_node_id"
-    try:
-        if NODE_ID_FILE.is_file():
-            with open(NODE_ID_FILE, "r") as f:
-                stored_id = f.read().strip()
-            if is_valid_uuid(stored_id):
-                if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
-                return stored_id
-            else:
-                if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")
-
-        new_id = str(uuid.uuid4())
-        with open(NODE_ID_FILE, "w") as f:
-            f.write(new_id)
-
-        if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
-        return new_id
-    except IOError as e:
-        if DEBUG >= 2: print(f"IO error creating node_id: {e}")
-        return str(uuid.uuid4())
-    except Exception as e:
-        if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
-        return str(uuid.uuid4())
+  NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__))) / ".exo_node_id"
+  try:
+    if NODE_ID_FILE.is_file():
+      with open(NODE_ID_FILE, "r") as f:
+        stored_id = f.read().strip()
+      if is_valid_uuid(stored_id):
+        if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
+        return stored_id
+      else:
+        if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")
+
+    new_id = str(uuid.uuid4())
+    with open(NODE_ID_FILE, "w") as f:
+      f.write(new_id)
+
+    if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
+    return new_id
+  except IOError as e:
+    if DEBUG >= 2: print(f"IO error creating node_id: {e}")
+    return str(uuid.uuid4())
+  except Exception as e:
+    if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
+    return str(uuid.uuid4())
+
 
 
 def pretty_print_bytes(size_in_bytes: int) -> str:
 def pretty_print_bytes(size_in_bytes: int) -> str:
-    if size_in_bytes < 1024:
-        return f"{size_in_bytes} B"
-    elif size_in_bytes < 1024 ** 2:
-        return f"{size_in_bytes / 1024:.2f} KB"
-    elif size_in_bytes < 1024 ** 3:
-        return f"{size_in_bytes / (1024 ** 2):.2f} MB"
-    elif size_in_bytes < 1024 ** 4:
-        return f"{size_in_bytes / (1024 ** 3):.2f} GB"
-    else:
-        return f"{size_in_bytes / (1024 ** 4):.2f} TB"
+  if size_in_bytes < 1024:
+    return f"{size_in_bytes} B"
+  elif size_in_bytes < 1024**2:
+    return f"{size_in_bytes / 1024:.2f} KB"
+  elif size_in_bytes < 1024**3:
+    return f"{size_in_bytes / (1024 ** 2):.2f} MB"
+  elif size_in_bytes < 1024**4:
+    return f"{size_in_bytes / (1024 ** 3):.2f} GB"
+  else:
+    return f"{size_in_bytes / (1024 ** 4):.2f} TB"
+
 
 
 def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
 def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
-    if bytes_per_second < 1024:
-        return f"{bytes_per_second} B/s"
-    elif bytes_per_second < 1024 ** 2:
-        return f"{bytes_per_second / 1024:.2f} KB/s"
-    elif bytes_per_second < 1024 ** 3:
-        return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
-    elif bytes_per_second < 1024 ** 4:
-        return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
-    else:
-        return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
+  if bytes_per_second < 1024:
+    return f"{bytes_per_second} B/s"
+  elif bytes_per_second < 1024**2:
+    return f"{bytes_per_second / 1024:.2f} KB/s"
+  elif bytes_per_second < 1024**3:
+    return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
+  elif bytes_per_second < 1024**4:
+    return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
+  else:
+    return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
+
 
 
 def get_all_ip_addresses():
 def get_all_ip_addresses():
-    try:
-      ip_addresses = []
-      for interface in netifaces.interfaces():
-        ifaddresses = netifaces.ifaddresses(interface)
-        if netifaces.AF_INET in ifaddresses:
-          for link in ifaddresses[netifaces.AF_INET]:
-            ip = link['addr']
-            ip_addresses.append(ip)
-      return list(set(ip_addresses))
-    except:
-      if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
-      return ["localhost"]
+  try:
+    ip_addresses = []
+    for interface in netifaces.interfaces():
+      ifaddresses = netifaces.ifaddresses(interface)
+      if netifaces.AF_INET in ifaddresses:
+        for link in ifaddresses[netifaces.AF_INET]:
+          ip = link['addr']
+          ip_addresses.append(ip)
+    return list(set(ip_addresses))
+  except:
+    if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
+    return ["localhost"]
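
The behaviour of the small utilities above is easy to check interactively; the values below are invented:

  from exo.helpers import PrefixDict, pretty_print_bytes, pretty_print_bytes_per_second

  routes = PrefixDict()
  routes.add("/v1", "api-root")
  routes.add("/v1/chat", "chat-handler")
  # find_longest_prefix returns the most specific registered prefix.
  print(routes.find_longest_prefix("/v1/chat/completions"))  # ('/v1/chat', 'chat-handler')

  print(pretty_print_bytes(1536))                    # 1.50 KB
  print(pretty_print_bytes_per_second(3 * 1024**2))  # 3.00 MB/s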

+ 5 - 7  exo/inference/debug_inference_engine.py

@@ -52,10 +52,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(next_resp_full, resp4)
   assert np.array_equal(next_resp_full, resp4)
 
 
 
 
-asyncio.run(
-  test_inference_engine(
-    TinygradDynamicShardInferenceEngine(),
-    TinygradDynamicShardInferenceEngine(),
-    "llama3-8b-sfr",
-  )
-)
+asyncio.run(test_inference_engine(
+  TinygradDynamicShardInferenceEngine(),
+  TinygradDynamicShardInferenceEngine(),
+  "llama3-8b-sfr",
+))

+ 2 - 0  exo/inference/inference_engine.py

@@ -5,7 +5,9 @@ from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from .shard import Shard
 from .shard import Shard
 
 
+
 class InferenceEngine(ABC):
 class InferenceEngine(ABC):
+
   @abstractmethod
   @abstractmethod
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     pass
     pass
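
A hedged sketch of the infer_prompt contract for any concrete engine; the request id and prompt are placeholders:

  from typing import Tuple
  import numpy as np
  from exo.inference.shard import Shard
  from exo.inference.inference_engine import InferenceEngine

  async def run_prompt(engine: InferenceEngine, shard: Shard, prompt: str) -> Tuple[np.ndarray, bool]:
    output, inference_state, is_finished = await engine.infer_prompt("request-1", shard, prompt)
    # output: this shard's result as an np.ndarray; inference_state: opaque string to
    # feed back into the next call; is_finished: whether generation has completed.
    return output, is_finished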

+ 1 - 0  exo/inference/mlx/models/base.py

@@ -5,5 +5,6 @@ from mlx_lm.models.base import KVCache
 
 
 
 
 class IdentityBlock(nn.Module):
 class IdentityBlock(nn.Module):
+
   def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array:
   def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array:
     return x
     return x
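
IdentityBlock stands in for layers hosted on other nodes, so activations pass through unchanged; a quick check, assuming mlx is installed:

  import mlx.core as mx
  from exo.inference.mlx.models.base import IdentityBlock

  block = IdentityBlock()
  x = mx.ones((1, 4))
  assert block(x) is x  # the input is returned untouched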

+ 4 - 5  exo/inference/mlx/models/deepseek_v2.py

@@ -7,7 +7,7 @@ import mlx.nn as nn
 from mlx_lm.models.base import KVCache
 from mlx_lm.models.base import KVCache
 from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
 from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
 from .base import IdentityBlock
 from .base import IdentityBlock
-from ...shard import Shard
+from exo.inference.shard import Shard
 
 
 
 
 @dataclass
 @dataclass
@@ -24,6 +24,7 @@ class ModelArgs(ModelArgs):
 
 
 
 
 class DeepseekV2Model(nn.Module):
 class DeepseekV2Model(nn.Module):
+
   def __init__(self, config: ModelArgs):
   def __init__(self, config: ModelArgs):
     super().__init__()
     super().__init__()
     self.args = config
     self.args = config
@@ -70,6 +71,7 @@ class DeepseekV2Model(nn.Module):
 
 
 
 
 class Model(nn.Module):
 class Model(nn.Module):
+
   def __init__(self, config: ModelArgs):
   def __init__(self, config: ModelArgs):
     super().__init__()
     super().__init__()
     self.args = config
     self.args = config
@@ -107,10 +109,7 @@ class Model(nn.Module):
         for k in ["weight", "scales", "biases"]:
         for k in ["weight", "scales", "biases"]:
           if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
           if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
             to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)]
             to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)]
-            shard_state_dict[
-              f"{prefix}.mlp.switch_mlp.{
-       m}.{k}"
-            ] = mx.stack(to_join)
+            shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
 
 
     return shard_state_dict
     return shard_state_dict
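
The sanitize change above stacks the per-expert weight matrices into a single switch_mlp tensor; a standalone illustration with invented shapes:

  import mlx.core as mx

  n_routed_experts = 4
  per_expert = [mx.zeros((8, 16)) for _ in range(n_routed_experts)]
  stacked = mx.stack(per_expert)  # shape (4, 8, 16): one leading axis per expert
  print(stacked.shape)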
 
 

+ 5 - 3  exo/inference/mlx/models/llama.py

@@ -24,7 +24,9 @@ class ModelArgs(ModelArgs):
 
 
     self.shard = Shard(**self.shard)
     self.shard = Shard(**self.shard)
 
 
+
 class LlamaModel(nn.Module):
 class LlamaModel(nn.Module):
+
   def __init__(self, args: ModelArgs):
   def __init__(self, args: ModelArgs):
     super().__init__()
     super().__init__()
     self.args = args
     self.args = args
@@ -66,7 +68,9 @@ class LlamaModel(nn.Module):
       h = self.norm(h)
       h = self.norm(h)
     return h
     return h
 
 
+
 class Model(nn.Module):
 class Model(nn.Module):
+
   def __init__(self, args: ModelArgs):
   def __init__(self, args: ModelArgs):
     super().__init__()
     super().__init__()
     self.args = args
     self.args = args
@@ -116,9 +120,7 @@ class Model(nn.Module):
 
 
   @property
   @property
   def head_dim(self):
   def head_dim(self):
-    return (
-      self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
-    )
+    return (self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads)
 
 
   @property
   @property
   def n_kv_heads(self):
   def n_kv_heads(self):
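
Two small behaviours from this file, sketched with invented values: __post_init__ rebuilds the shard from a plain dict, and head_dim falls back to hidden_size // num_attention_heads:

  from exo.inference.shard import Shard

  # A shard supplied as a dict (e.g. decoded from a JSON config) becomes a Shard.
  shard = {"model_id": "demo", "start_layer": 0, "end_layer": 15, "n_layers": 32}
  if isinstance(shard, dict):
    shard = Shard(**shard)

  # The head_dim fallback used by the property above.
  hidden_size, num_attention_heads, head_dim = 4096, 32, None
  print(head_dim or hidden_size // num_attention_heads)  # 128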

+ 521 - 555  exo/inference/mlx/models/llava.py

@@ -15,619 +15,585 @@ import numpy as np
 
 
 @dataclass
 @dataclass
 class VisionConfig:
 class VisionConfig:
-    model_type: str
-    num_hidden_layers: int = 24
-    hidden_size: int = 1024
-    intermediate_size: int = 4096
-    num_attention_heads: int = 16
-    image_size: int = 336
-    patch_size: int = 14
-    projection_dim: int = 768
-    vocab_size: int = 32000
-    num_channels: int = 3
-    layer_norm_eps: float = 1e-5
-
-    @classmethod
-    def from_dict(cls, params):
-        return cls(
-            **{
-                k: v
-                for k, v in params.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
+  model_type: str
+  num_hidden_layers: int = 24
+  hidden_size: int = 1024
+  intermediate_size: int = 4096
+  num_attention_heads: int = 16
+  image_size: int = 336
+  patch_size: int = 14
+  projection_dim: int = 768
+  vocab_size: int = 32000
+  num_channels: int = 3
+  layer_norm_eps: float = 1e-5
+
+  @classmethod
+  def from_dict(cls, params):
+    return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
 
 
 
 
 class VisionAttention(nn.Module):
 class VisionAttention(nn.Module):
-    def __init__(
-            self,
-            dims: int,
-            num_heads: int,
-            query_input_dims: Optional[int] = None,
-            key_input_dims: Optional[int] = None,
-            value_input_dims: Optional[int] = None,
-            value_dims: Optional[int] = None,
-            value_output_dims: Optional[int] = None,
-            bias: bool = False,
-    ):
-        super().__init__()
-
-        if (dims % num_heads) != 0:
-            raise ValueError(
-                "The input feature dimensions should be divisible by the "
-                f"number of heads ({dims} % {num_heads}) != 0"
-            )
-
-        query_input_dims = query_input_dims or dims
-        key_input_dims = key_input_dims or dims
-        value_input_dims = value_input_dims or key_input_dims
-        value_dims = value_dims or dims
-        value_output_dims = value_output_dims or dims
-
-        self.num_heads = num_heads
-        self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
-        self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
-        self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
-        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
-
-    def __call__(self, queries, keys, values, mask=None):
-        queries = self.q_proj(queries)
-        keys = self.k_proj(keys)
-        values = self.v_proj(values)
-
-        num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-        scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys
-        if mask is not None:
-            scores = scores + mask.astype(scores.dtype)
-        scores = mx.softmax(scores, axis=-1)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.out_proj(values_hat)
+
+  def __init__(
+    self,
+    dims: int,
+    num_heads: int,
+    query_input_dims: Optional[int] = None,
+    key_input_dims: Optional[int] = None,
+    value_input_dims: Optional[int] = None,
+    value_dims: Optional[int] = None,
+    value_output_dims: Optional[int] = None,
+    bias: bool = False,
+  ):
+    super().__init__()
+
+    if (dims % num_heads) != 0:
+      raise ValueError("The input feature dimensions should be divisible by the "
+                       f"number of heads ({dims} % {num_heads}) != 0")
+
+    query_input_dims = query_input_dims or dims
+    key_input_dims = key_input_dims or dims
+    value_input_dims = value_input_dims or key_input_dims
+    value_dims = value_dims or dims
+    value_output_dims = value_output_dims or dims
+
+    self.num_heads = num_heads
+    self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+    self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+    self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+    self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+  def __call__(self, queries, keys, values, mask=None):
+    queries = self.q_proj(queries)
+    keys = self.k_proj(keys)
+    values = self.v_proj(values)
+
+    num_heads = self.num_heads
+    B, L, D = queries.shape
+    _, S, _ = keys.shape
+    queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+    keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
+    values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+    scale = math.sqrt(1 / queries.shape[-1])
+    scores = (queries * scale) @ keys
+    if mask is not None:
+      scores = scores + mask.astype(scores.dtype)
+    scores = mx.softmax(scores, axis=-1)
+    values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+    return self.out_proj(values_hat)
 
 
 
 
 class VisionMLP(nn.Module):
 class VisionMLP(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.activation_fn = nn.GELU(approx="fast")
-        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
 
 
-    def __call__(self, x: mx.array) -> mx.array:
-        x = self.activation_fn(self.fc1(x))
-        x = self.fc2(x)
-        return x
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.activation_fn = nn.GELU(approx="fast")
+    self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+    self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+  def __call__(self, x: mx.array) -> mx.array:
+    x = self.activation_fn(self.fc1(x))
+    x = self.fc2(x)
+    return x
 
 
 
 
 class VisionEncoderLayer(nn.Module):
 class VisionEncoderLayer(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.embed_dim = config.hidden_size
-        self.self_attn = VisionAttention(
-            config.hidden_size, config.num_attention_heads, bias=True
-        )
-        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-        self.mlp = VisionMLP(config)
-        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-
-    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
-        y = self.layer_norm1(x)
-        y = self.self_attn(y, y, y, mask)
-        x = x + y
-        y = self.layer_norm2(x)
-        y = self.mlp(y)
-        return x + y
+
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.embed_dim = config.hidden_size
+    self.self_attn = VisionAttention(config.hidden_size, config.num_attention_heads, bias=True)
+    self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+    self.mlp = VisionMLP(config)
+    self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+  def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+    y = self.layer_norm1(x)
+    y = self.self_attn(y, y, y, mask)
+    x = x + y
+    y = self.layer_norm2(x)
+    y = self.mlp(y)
+    return x + y
 
 
 
 
 class VisionEncoder(nn.Module):
 class VisionEncoder(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
 
 
 
 
 class VisionEmbeddings(nn.Module):
 class VisionEmbeddings(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.config = config
-        self.embed_dim = config.hidden_size
-        self.image_size = config.image_size
-        self.patch_size = config.patch_size
-
-        self.class_embedding = mx.zeros((config.hidden_size,))
-
-        self.patch_embedding = nn.Conv2d(
-            in_channels=config.num_channels,
-            out_channels=self.embed_dim,
-            kernel_size=self.patch_size,
-            stride=self.patch_size,
-            bias=False,
-        )
-
-        self.num_patches = (self.image_size // self.patch_size) ** 2
-        self.num_positions = self.num_patches + 1
-        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-
-    def __call__(self, x: mx.array) -> mx.array:
-        batch_size = x.shape[0]
-        patch_embeddings = self.patch_embedding(x)
-        patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
-        embed_dim = patch_embeddings.shape[-1]
-        cls_embeddings = mx.broadcast_to(
-            self.class_embedding, (batch_size, 1, embed_dim)
-        )
-        embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
-        embeddings += self.position_embedding.weight
-        return embeddings
+
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.config = config
+    self.embed_dim = config.hidden_size
+    self.image_size = config.image_size
+    self.patch_size = config.patch_size
+
+    self.class_embedding = mx.zeros((config.hidden_size, ))
+
+    self.patch_embedding = nn.Conv2d(
+      in_channels=config.num_channels,
+      out_channels=self.embed_dim,
+      kernel_size=self.patch_size,
+      stride=self.patch_size,
+      bias=False,
+    )
+
+    self.num_patches = (self.image_size // self.patch_size)**2
+    self.num_positions = self.num_patches + 1
+    self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+  def __call__(self, x: mx.array) -> mx.array:
+    batch_size = x.shape[0]
+    patch_embeddings = self.patch_embedding(x)
+    patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+    embed_dim = patch_embeddings.shape[-1]
+    cls_embeddings = mx.broadcast_to(self.class_embedding, (batch_size, 1, embed_dim))
+    embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
+    embeddings += self.position_embedding.weight
+    return embeddings
 
 
 
 
 class ClipVisionModel(nn.Module):
 class ClipVisionModel(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.embeddings = VisionEmbeddings(config)
-        self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
-        self.encoder = VisionEncoder(config)
-        self.post_layernorm = nn.LayerNorm(config.hidden_size)
 
 
-    def __call__(
-        self,
-        x: mx.array,
-        output_hidden_states: Optional[bool] = None,
-    ) -> mx.array:
-        x = self.embeddings(x)
-        x = self.pre_layrnorm(x)
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.embeddings = VisionEmbeddings(config)
+    self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
+    self.encoder = VisionEncoder(config)
+    self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+  def __call__(
+    self,
+    x: mx.array,
+    output_hidden_states: Optional[bool] = None,
+  ) -> mx.array:
+    x = self.embeddings(x)
+    x = self.pre_layrnorm(x)
 
 
-        encoder_states = (x,) if output_hidden_states else None
+    encoder_states = (x, ) if output_hidden_states else None
 
 
-        for l in self.encoder.layers:
-            x = l(x, mask=None)
-            if output_hidden_states:
-                encoder_states = encoder_states + (x,)
+    for l in self.encoder.layers:
+      x = l(x, mask=None)
+      if output_hidden_states:
+        encoder_states = encoder_states + (x, )
 
 
-        pooler_output = self.post_layernorm(x[:, 0, :])
-        return pooler_output, x, encoder_states
+    pooler_output = self.post_layernorm(x[:, 0, :])
+    return pooler_output, x, encoder_states
 
 
 
 
 class VisionModel(nn.Module):
 class VisionModel(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-
-        self.model_type = config.model_type
-        if self.model_type != "clip_vision_model":
-            raise ValueError(f"Unsupported model type: {self.model_type}")
-
-        self.vision_model = ClipVisionModel(config)
-
-    def __call__(
-            self, x: mx.array, output_hidden_states: Optional[bool] = None
-    ) -> mx.array:
-        return self.vision_model(x, output_hidden_states)
-
-    def sanitize(self, weights):
-        sanitized_weights = {}
-        for k, v in weights.items():
-            if "position_ids" in k:
-                # Remove unused position_ids
-                continue
-            elif "patch_embedding.weight" in k:
-                # PyTorch conv2d weight tensors have shape:
-                #   [out_channels, in_channels, kH, KW]
-                # MLX conv2d expects the weight be of shape:
-                #   [out_channels, kH, KW, in_channels]
-                sanitized_weights[k] = v.transpose(0, 2, 3, 1)
-            else:
-                sanitized_weights[k] = v
-
-        return sanitized_weights
+
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+
+    self.model_type = config.model_type
+    if self.model_type != "clip_vision_model":
+      raise ValueError(f"Unsupported model type: {self.model_type}")
+
+    self.vision_model = ClipVisionModel(config)
+
+  def __call__(self, x: mx.array, output_hidden_states: Optional[bool] = None) -> mx.array:
+    return self.vision_model(x, output_hidden_states)
+
+  def sanitize(self, weights):
+    sanitized_weights = {}
+    for k, v in weights.items():
+      if "position_ids" in k:
+        # Remove unused position_ids
+        continue
+      elif "patch_embedding.weight" in k:
+        # PyTorch conv2d weight tensors have shape:
+        #   [out_channels, in_channels, kH, KW]
+        # MLX conv2d expects the weight be of shape:
+        #   [out_channels, kH, KW, in_channels]
+        sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+      else:
+        sanitized_weights[k] = v
+
+    return sanitized_weights
 
 
 
 
 @dataclass
 @dataclass
 class TextConfig:
 class TextConfig:
-    model_type: str
-    hidden_size: int = 4096
-    num_hidden_layers: int = 32
-    intermediate_size: int = 11008
-    num_attention_heads: int = 32
-    head_dim: int = None
-    rms_norm_eps: float = 1e-6
-    vocab_size: int = 32000
-    num_key_value_heads: int = None
-    rope_theta: float = 10000
-    rope_traditional: bool = False
-    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-
-    @classmethod
-    def from_dict(cls, params):
-        return cls(
-            **{
-                k: v
-                for k, v in params.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
-
-    def __post_init__(self):
-        if self.num_key_value_heads is None:
-            self.num_key_value_heads = self.num_attention_heads
-
-        if self.head_dim is None:
-            self.head_dim = self.hidden_size // self.num_attention_heads
-
-        if self.model_type is None:
-            self.model_type = "llama"
-
-        if self.rope_scaling:
-            required_keys = {"factor", "type"}
-            if not all(key in self.rope_scaling for key in required_keys):
-                raise ValueError(f"rope_scaling must contain keys {required_keys}")
-
-            if self.rope_scaling["type"] != "linear":
-                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
+  model_type: str
+  hidden_size: int = 4096
+  num_hidden_layers: int = 32
+  intermediate_size: int = 11008
+  num_attention_heads: int = 32
+  head_dim: int = None
+  rms_norm_eps: float = 1e-6
+  vocab_size: int = 32000
+  num_key_value_heads: int = None
+  rope_theta: float = 10000
+  rope_traditional: bool = False
+  rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+
+  @classmethod
+  def from_dict(cls, params):
+    return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+  def __post_init__(self):
+    if self.num_key_value_heads is None:
+      self.num_key_value_heads = self.num_attention_heads
+
+    if self.head_dim is None:
+      self.head_dim = self.hidden_size // self.num_attention_heads
+
+    if self.model_type is None:
+      self.model_type = "llama"
+
+    if self.rope_scaling:
+      required_keys = {"factor", "type"}
+      if not all(key in self.rope_scaling for key in required_keys):
+        raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+      if self.rope_scaling["type"] != "linear":
+        raise ValueError("rope_scaling 'type' currently only supports 'linear'")
 
 
 
 
 class TextAttention(nn.Module):
 class TextAttention(nn.Module):
-    def __init__(self, config: TextConfig):
-        super().__init__()
-
-        dim = config.hidden_size
-        self.n_heads = n_heads = config.num_attention_heads
-        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
-
-        self.repeats = n_heads // n_kv_heads
-
-        head_dim = config.hidden_size // n_heads
-        self.scale = head_dim ** -0.5
-
-        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
-        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
-        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
-        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
-
-        rope_scale = (
-            1 / config.rope_scaling["factor"]
-            if config.rope_scaling is not None
-               and config.rope_scaling["type"] == "linear"
-            else 1
-        )
-        self.rope = nn.RoPE(
-            head_dim,
-            traditional=config.rope_traditional,
-            base=config.rope_theta,
-            scale=rope_scale,
-        )
-
-    def __call__(
-            self,
-            x: mx.array,
-            mask: Optional[mx.array] = None,
-            cache: Optional[KVCache] = None,
-    ) -> mx.array:
-        B, L, D = x.shape
-
-        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-        # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-
-        if cache is not None:
-            queries = self.rope(queries, offset=cache.offset)
-            keys = self.rope(keys, offset=cache.offset)
-            keys, values = cache.update_and_fetch(keys, values)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
-
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
-        )
-        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output)
+
+  def __init__(self, config: TextConfig):
+    super().__init__()
+
+    dim = config.hidden_size
+    self.n_heads = n_heads = config.num_attention_heads
+    self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+
+    self.repeats = n_heads // n_kv_heads
+
+    head_dim = config.hidden_size // n_heads
+    self.scale = head_dim**-0.5
+
+    self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+    self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+    self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+    self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+    rope_scale = (1 / config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1)
+    self.rope = nn.RoPE(
+      head_dim,
+      traditional=config.rope_traditional,
+      base=config.rope_theta,
+      scale=rope_scale,
+    )
+
+  def __call__(
+    self,
+    x: mx.array,
+    mask: Optional[mx.array] = None,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    B, L, D = x.shape
+
+    queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+    # Prepare the queries, keys and values for the attention computation
+    queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+    keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+    values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+    if cache is not None:
+      queries = self.rope(queries, offset=cache.offset)
+      keys = self.rope(keys, offset=cache.offset)
+      keys, values = cache.update_and_fetch(keys, values)
+    else:
+      queries = self.rope(queries)
+      keys = self.rope(keys)
+
+    output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale, mask=mask)
+    output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+    return self.o_proj(output)
 
 
 
 
 class TextMLP(nn.Module):
 class TextMLP(nn.Module):
-    def __init__(self, dim, hidden_dim):
-        super().__init__()
-        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
-        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
-        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
 
 
-    def __call__(self, x) -> mx.array:
-        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+  def __init__(self, dim, hidden_dim):
+    super().__init__()
+    self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+    self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+    self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+  def __call__(self, x) -> mx.array:
+    return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
 
 
 
 
 class TransformerBlock(nn.Module):
 class TransformerBlock(nn.Module):
-    def __init__(self, config: TextConfig):
-        super().__init__()
-        self.num_attention_heads = config.num_attention_heads
-        self.hidden_size = config.hidden_size
-        self.self_attn = TextAttention(config)
-        self.mlp = TextMLP(config.hidden_size, config.intermediate_size)
-        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps
-        )
-        self.config = config
-
-    def __call__(
-            self,
-            x: mx.array,
-            mask: Optional[mx.array] = None,
-            cache: Optional[KVCache] = None,
-    ) -> mx.array:
-        r = self.self_attn(self.input_layernorm(x), mask, cache)
-        h = x + r
-        r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
-        return out
+
+  def __init__(self, config: TextConfig):
+    super().__init__()
+    self.num_attention_heads = config.num_attention_heads
+    self.hidden_size = config.hidden_size
+    self.self_attn = TextAttention(config)
+    self.mlp = TextMLP(config.hidden_size, config.intermediate_size)
+    self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+    self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+    self.config = config
+
+  def __call__(
+    self,
+    x: mx.array,
+    mask: Optional[mx.array] = None,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    r = self.self_attn(self.input_layernorm(x), mask, cache)
+    h = x + r
+    r = self.mlp(self.post_attention_layernorm(h))
+    out = h + r
+    return out
 
 
 
 
 class Llama(nn.Module):
 class Llama(nn.Module):
-    def __init__(self, config: TextConfig, shard: Shard):
-        super().__init__()
-        self.config = config
-        self.shard = shard
-        self.vocab_size = config.vocab_size
-        self.model_type = config.model_type
-        self.num_hidden_layers = config.num_hidden_layers
-        self.num_key_value_heads = config.num_key_value_heads
-        self.head_dim = config.head_dim
-        assert self.vocab_size > 0
-        if self.shard.is_first_layer():
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.layers = []
-        for i in range(self.num_hidden_layers):
-          if self.shard.start_layer <= i <= self.shard.end_layer:
-            self.layers.append(TransformerBlock(config=config))
-          else:
-            self.layers.append(IdentityBlock())
-        if self.shard.is_last_layer():
-            self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-    def __call__(
-            self,
-            inputs: mx.array,
-            cache=None,
-            inputs_embeds=None,
-    ):
-        # for passing merged input embeddings
-        if inputs_embeds is None:
-            if self.shard.is_first_layer():
-                h = self.embed_tokens(inputs)
-            else:
-                h = inputs
-        else:
-            h = inputs_embeds
-
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
-
-        if cache is None:
-            cache = [None] * len(self.layers)
-
-
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, c)
-
-        if self.shard.is_last_layer():
-            h = self.norm(h)
-        return h
+
+  def __init__(self, config: TextConfig, shard: Shard):
+    super().__init__()
+    self.config = config
+    self.shard = shard
+    self.vocab_size = config.vocab_size
+    self.model_type = config.model_type
+    self.num_hidden_layers = config.num_hidden_layers
+    self.num_key_value_heads = config.num_key_value_heads
+    self.head_dim = config.head_dim
+    assert self.vocab_size > 0
+    if self.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.shard.start_layer <= i <= self.shard.end_layer:
+        self.layers.append(TransformerBlock(config=config))
+      else:
+        self.layers.append(IdentityBlock())
+    if self.shard.is_last_layer():
+      self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+    inputs_embeds=None,
+  ):
+    # for passing merged input embeddings
+    if inputs_embeds is None:
+      if self.shard.is_first_layer():
+        h = self.embed_tokens(inputs)
+      else:
+        h = inputs
+    else:
+      h = inputs_embeds
+
+    mask = None
+    if h.shape[1] > 1:
+      mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
+      mask = mask.astype(h.dtype)
+
+    if cache is None:
+      cache = [None] * len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, c)
+
+    if self.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
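
The Llama wrapper above builds TransformerBlocks only for layers inside the shard and IdentityBlocks elsewhere, so only the resident slice carries real weights. A plain-Python sketch of that gating with an invented shard:

  from exo.inference.shard import Shard

  shard = Shard(model_id="demo", start_layer=8, end_layer=15, n_layers=32)
  kinds = [
    "TransformerBlock" if shard.start_layer <= i <= shard.end_layer else "IdentityBlock"
    for i in range(shard.n_layers)
  ]
  print(kinds.count("TransformerBlock"))  # 8 resident layers on this node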
 
 
 class LanguageModel(nn.Module):
 class LanguageModel(nn.Module):
-    def __init__(self, config: TextConfig, shard: Shard):
-        super().__init__()
-        self.model_type = config.model_type
-        if self.model_type != "llama":
-            raise ValueError(
-                f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
-            )
-        self.shard = shard
-        self.model = Llama(config, shard)
-        if self.shard.is_last_layer():
-            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-    def __call__(
-        self,
-        inputs: mx.array,
-        cache=None,
-        inputs_embeds=None,
-    ):
-        out = self.model(inputs, cache, inputs_embeds)
-        if self.shard.is_last_layer():
-            out = self.lm_head(out)
-        return out
-
-    def sanitize(self, weights):
-        shard_state_dict = {}
-        for key, value in weights.items():
-            if "self_attn.rotary_emb.inv_freq" in key:
-                continue
-
-            if key.startswith('language_model.model.layers.'):
-                layer_num = int(key.split('.')[3])
-                if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer:
-                    continue
-            if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'):
-                continue
-            elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')):
-                continue
-
-            shard_state_dict[key] = value
-
-        return shard_state_dict
+
+  def __init__(self, config: TextConfig, shard: Shard):
+    super().__init__()
+    self.model_type = config.model_type
+    if self.model_type != "llama":
+      raise ValueError(f"Model type {self.model_type} not supported. Currently only 'llama' is supported")
+    self.shard = shard
+    self.model = Llama(config, shard)
+    if self.shard.is_last_layer():
+      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+    inputs_embeds=None,
+  ):
+    out = self.model(inputs, cache, inputs_embeds)
+    if self.shard.is_last_layer():
+      out = self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+
+      if key.startswith('language_model.model.layers.'):
+        layer_num = int(key.split('.')[3])
+        if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer:
+          continue
+      if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'):
+        continue
+      elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')):
+        continue
+
+      shard_state_dict[key] = value
+
+    return shard_state_dict
+
 
 
 @dataclass
 @dataclass
 class LlaVAConfig(BaseModelArgs):
 class LlaVAConfig(BaseModelArgs):
-    text_config: TextConfig
-    vision_config: VisionConfig = None
-    model_type: str = "llava"
-    ignore_index: int = -100
-    image_token_index: int = 32000
-    vision_feature_select_strategy: str = "default"
-    vision_feature_layer: int = -2
-    vocab_size: int = 32000
-
-    @classmethod
-    def from_dict(cls, params):
-        updated_params = {}
-        class_params = inspect.signature(cls).parameters
-        for k, v in params.items():
-            if k in class_params:
-                if k in ["text_config", "vision_config"]:
-                    v = class_params[k].annotation.from_dict(v)
-                updated_params.update({k: v})
-
-        return cls(**updated_params)
+  text_config: TextConfig
+  vision_config: VisionConfig = None
+  model_type: str = "llava"
+  ignore_index: int = -100
+  image_token_index: int = 32000
+  vision_feature_select_strategy: str = "default"
+  vision_feature_layer: int = -2
+  vocab_size: int = 32000
+
+  @classmethod
+  def from_dict(cls, params):
+    updated_params = {}
+    class_params = inspect.signature(cls).parameters
+    for k, v in params.items():
+      if k in class_params:
+        if k in ["text_config", "vision_config"]:
+          v = class_params[k].annotation.from_dict(v)
+        updated_params.update({k: v})
+
+    return cls(**updated_params)
 
 
 
 
 @dataclass
 @dataclass
 class ModelArgs(LlaVAConfig):
 class ModelArgs(LlaVAConfig):
-    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
 
 
-    def __post_init__(self):
-        if isinstance(self.shard, dict):
-            self.shard = Shard(**self.shard)
+  def __post_init__(self):
+    if isinstance(self.shard, dict):
+      self.shard = Shard(**self.shard)
 
 
-        if not isinstance(self.shard, Shard):
-            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+    if not isinstance(self.shard, Shard):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
 
 
-        if not self.shard.is_first_layer():
-            self.vision_config = None
+    if not self.shard.is_first_layer():
+      self.vision_config = None
 
 
 
 
 class LlavaMultiModalProjector(nn.Module):
 class LlavaMultiModalProjector(nn.Module):
-    def __init__(self, config: LlaVAConfig):
-        super().__init__()
-        self.linear_1 = nn.Linear(
-            config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
-        )
-        self.gelu = nn.GELU()
-        self.linear_2 = nn.Linear(
-            config.text_config.hidden_size, config.text_config.hidden_size, bias=True
-        )
-
-    def __call__(self, x: mx.array) -> mx.array:
-        x = self.linear_1(x)
-        x = self.gelu(x)
-        x = self.linear_2(x)
-        return x
+
+  def __init__(self, config: LlaVAConfig):
+    super().__init__()
+    self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
+    self.gelu = nn.GELU()
+    self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
+
+  def __call__(self, x: mx.array) -> mx.array:
+    x = self.linear_1(x)
+    x = self.gelu(x)
+    x = self.linear_2(x)
+    return x
 
 
 
 
 class Model(nn.Module):
 class Model(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        self.config = config
-        self.model_type = config.model_type
-        if config.vision_config:
-            self.vision_tower = VisionModel(config.vision_config)
-            self.multi_modal_projector = LlavaMultiModalProjector(config)
-            self.vision_feature_layer = config.vision_feature_layer
-            self.vision_feature_select_strategy = config.vision_feature_select_strategy
-        self.language_model = LanguageModel(config.text_config, config.shard)
-
-    def get_input_embeddings(
-            self,
-            input_ids: Optional[mx.array] = None,
-            pixel_values: Optional[mx.array] = None,
-    ):
-        if pixel_values is None:
-            return self.language_model(input_ids)
-
-        # Get the input embeddings from the language model
-        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-
-        # Get the ouptut hidden states from the vision model
-        *_, hidden_states = self.vision_tower(
-            pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
-        )
-
-        # Select the hidden states from the desired layer
-        selected_image_feature = hidden_states[self.vision_feature_layer]
-
-        if self.vision_feature_select_strategy == "default":
-            selected_image_feature = selected_image_feature[:, 1:]
-        elif self.vision_feature_select_strategy == "full":
-            selected_image_feature = selected_image_feature
-        else:
-            raise ValueError(
-                "Unexpected feature selection strategy: "
-                f"{self.vision_feature_select_strategy}"
-            )
-
-        # Pass image features through the multi-modal projector
-        image_features = self.multi_modal_projector(selected_image_feature)
-
-        # Insert special image tokens in the input_ids
-        final_inputs_embeds = self._merge_input_ids_with_image_features(
-            image_features, inputs_embeds, input_ids
-        )
-        return final_inputs_embeds
-
-    def _merge_input_ids_with_image_features(
-            self, image_features, inputs_embeds, input_ids
-    ):
-        image_token_index = self.config.image_token_index
-        num_images, num_image_patches, embed_dim = image_features.shape
-
-        # Positions of <image> tokens in input_ids, assuming batch size is 1
-        image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
-
-        if len(image_positions) != num_images:
-            raise ValueError(
-                f"The number of image tokens ({len(image_positions)}) does not "
-                f" match the number of image inputs ({num_images})."
-            )
-
-        text_segments = []
-        start_idx = 0
-
-        for position in image_positions:
-            text_segments.append(inputs_embeds[:, start_idx:position])
-            start_idx = position + 1
-
-        image_embeddings = mx.split(image_features, image_features.shape[0])
-        final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
-        final_embeddings += [inputs_embeds[:, start_idx:]]
-
-        # Create a final embedding of shape
-        # (1, num_image_patches*num_images + sequence_len, embed_dim)
-        return mx.concatenate(final_embeddings, axis=1)
-
-    def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
-        input_embddings = None
-        if pixel_values is not None:
-            input_embddings = self.get_input_embeddings(input_ids, pixel_values)
-        logits = self.language_model(
-            input_ids, cache=cache, inputs_embeds=input_embddings
-        )
-        return logits
-
-    def sanitize(self, weights):
-        if self.config.vision_config:
-          weights = self.vision_tower.sanitize(weights)
-        else:
-          weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))}
-        weights = self.language_model.sanitize(weights)
-        return weights
-
-    @property
-    def layers(self):
-        return self.language_model.model.layers
-
-    @property
-    def head_dim(self):
-        return (
-                self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads
-        )
-
-    @property
-    def n_kv_heads(self):
-        return self.language_model.model.num_key_value_heads
+
+  def __init__(self, config: ModelArgs):
+    super().__init__()
+    self.config = config
+    self.model_type = config.model_type
+    if config.vision_config:
+      self.vision_tower = VisionModel(config.vision_config)
+      self.multi_modal_projector = LlavaMultiModalProjector(config)
+      self.vision_feature_layer = config.vision_feature_layer
+      self.vision_feature_select_strategy = config.vision_feature_select_strategy
+    self.language_model = LanguageModel(config.text_config, config.shard)
+
+  def get_input_embeddings(
+    self,
+    input_ids: Optional[mx.array] = None,
+    pixel_values: Optional[mx.array] = None,
+  ):
+    if pixel_values is None:
+      return self.language_model(input_ids)
+
+    # Get the input embeddings from the language model
+    inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+    # Get the output hidden states from the vision model
+    *_, hidden_states = self.vision_tower(pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True)
+
+    # Select the hidden states from the desired layer
+    selected_image_feature = hidden_states[self.vision_feature_layer]
+
+    if self.vision_feature_select_strategy == "default":
+      selected_image_feature = selected_image_feature[:, 1:]
+    elif self.vision_feature_select_strategy == "full":
+      selected_image_feature = selected_image_feature
+    else:
+      raise ValueError("Unexpected feature selection strategy: "
+                       f"{self.vision_feature_select_strategy}")
+
+    # Pass image features through the multi-modal projector
+    image_features = self.multi_modal_projector(selected_image_feature)
+
+    # Insert special image tokens in the input_ids
+    final_inputs_embeds = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids)
+    return final_inputs_embeds
+
+  def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids):
+    image_token_index = self.config.image_token_index
+    num_images, num_image_patches, embed_dim = image_features.shape
+
+    # Positions of <image> tokens in input_ids, assuming batch size is 1
+    image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
+
+    if len(image_positions) != num_images:
+      raise ValueError(f"The number of image tokens ({len(image_positions)}) does not "
+                       f" match the number of image inputs ({num_images}).")
+
+    text_segments = []
+    start_idx = 0
+
+    for position in image_positions:
+      text_segments.append(inputs_embeds[:, start_idx:position])
+      start_idx = position + 1
+
+    image_embeddings = mx.split(image_features, image_features.shape[0])
+    final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
+    final_embeddings += [inputs_embeds[:, start_idx:]]
+
+    # Create a final embedding of shape
+    # (1, num_image_patches*num_images + sequence_len, embed_dim)
+    return mx.concatenate(final_embeddings, axis=1)
+
+  def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
+    input_embddings = None
+    if pixel_values is not None:
+      input_embddings = self.get_input_embeddings(input_ids, pixel_values)
+    logits = self.language_model(input_ids, cache=cache, inputs_embeds=input_embddings)
+    return logits
+
+  def sanitize(self, weights):
+    if self.config.vision_config:
+      weights = self.vision_tower.sanitize(weights)
+    else:
+      weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))}
+    weights = self.language_model.sanitize(weights)
+    return weights
+
+  @property
+  def layers(self):
+    return self.language_model.model.layers
+
+  @property
+  def head_dim(self):
+    return (self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads)
+
+  @property
+  def n_kv_heads(self):
+    return self.language_model.model.num_key_value_heads

+ 1 - 0
exo/inference/mlx/sharded_inference_engine.py

@@ -9,6 +9,7 @@ from exo.download.shard_download import ShardDownloader
 
 
 
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
 class MLXDynamicShardInferenceEngine(InferenceEngine):
+
   def __init__(self, shard_downloader: ShardDownloader):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard = None
     self.shard_downloader = shard_downloader
     self.shard_downloader = shard_downloader

+ 4 - 9
exo/inference/mlx/sharded_model.py

@@ -10,6 +10,7 @@ from ..shard import Shard
 
 
 
 
 class StatefulShardedModel:
 class StatefulShardedModel:
+
   def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
   def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
     self.shard = shard
     self.shard = shard
     self.model = model
     self.model = model
@@ -26,6 +27,7 @@ class StatefulShardedModel:
     top_p: float = 1.0,
     top_p: float = 1.0,
     logit_bias: Optional[Dict[int, float]] = None,
     logit_bias: Optional[Dict[int, float]] = None,
   ) -> Generator[Tuple[mx.array, mx.array], None, None]:
   ) -> Generator[Tuple[mx.array, mx.array], None, None]:
+
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
       if logit_bias:
       if logit_bias:
         indices = mx.array(list(logit_bias.keys()))
         indices = mx.array(list(logit_bias.keys()))
@@ -74,16 +76,9 @@ class StatefulShardedModel:
     return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
     return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
 
 
   def init_cache(self, request_id: str):
   def init_cache(self, request_id: str):
-    kv_heads = (
-      [self.model.n_kv_heads] * len(self.model.layers)
-      if isinstance(self.model.n_kv_heads, int)
-      else self.model.n_kv_heads
-    )
+    kv_heads = ([self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
     if self.max_kv_size is not None:
     if self.max_kv_size is not None:
-      cache = [
-        RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4)
-        for n in kv_heads
-      ]
+      cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
     else:
     else:
       cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
       cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
 
 

+ 29 - 25
exo/inference/mlx/sharded_utils.py

@@ -25,6 +25,7 @@ from ..shard import Shard
 
 
 
 
 class ModelNotFoundError(Exception):
 class ModelNotFoundError(Exception):
+
   def __init__(self, message):
   def __init__(self, message):
     self.message = message
     self.message = message
     super().__init__(self.message)
     super().__init__(self.message)
@@ -139,9 +140,10 @@ def load_model_shard(
   if (quantization := config.get("quantization", None)) is not None:
   if (quantization := config.get("quantization", None)) is not None:
     # Handle legacy models which may not have everything quantized
     # Handle legacy models which may not have everything quantized
     def class_predicate(p, m):
     def class_predicate(p, m):
-        if not hasattr(m, "to_quantized"):
-            return False
-        return f"{p}.scales" in weights
+      if not hasattr(m, "to_quantized"):
+        return False
+      return f"{p}.scales" in weights
+
     nn.quantize(
     nn.quantize(
       model,
       model,
       **quantization,
       **quantization,
@@ -156,6 +158,7 @@ def load_model_shard(
   model.eval()
   model.eval()
   return model
   return model
 
 
+
 async def load_shard(
 async def load_shard(
   model_path: str,
   model_path: str,
   shard: Shard,
   shard: Shard,
@@ -179,26 +182,27 @@ async def load_shard(
     tokenizer = load_tokenizer(model_path, tokenizer_config)
     tokenizer = load_tokenizer(model_path, tokenizer_config)
     return model, tokenizer
     return model, tokenizer
 
 
+
 async def get_image_from_str(_image_str: str):
 async def get_image_from_str(_image_str: str):
-    image_str = _image_str.strip()
-
-    if image_str.startswith("http"):
-        async with aiohttp.ClientSession() as session:
-            async with session.get(image_str, timeout=10) as response:
-                content = await response.read()
-                return Image.open(BytesIO(content)).convert("RGB")
-    elif image_str.startswith("data:image/"):
-        # Extract the image format and base64 data
-        format_prefix, base64_data = image_str.split(";base64,")
-        image_format = format_prefix.split("/")[1].lower()
-        if DEBUG >= 2: print(f"{image_str=} {image_format=}")
-        imgdata = base64.b64decode(base64_data)
-        img = Image.open(BytesIO(imgdata))
-
-        # Convert to RGB if not already
-        if img.mode != "RGB":
-            img = img.convert("RGB")
-
-        return img
-    else:
-        raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
+  image_str = _image_str.strip()
+
+  if image_str.startswith("http"):
+    async with aiohttp.ClientSession() as session:
+      async with session.get(image_str, timeout=10) as response:
+        content = await response.read()
+        return Image.open(BytesIO(content)).convert("RGB")
+  elif image_str.startswith("data:image/"):
+    # Extract the image format and base64 data
+    format_prefix, base64_data = image_str.split(";base64,")
+    image_format = format_prefix.split("/")[1].lower()
+    if DEBUG >= 2: print(f"{image_str=} {image_format=}")
+    imgdata = base64.b64decode(base64_data)
+    img = Image.open(BytesIO(imgdata))
+
+    # Convert to RGB if not already
+    if img.mode != "RGB":
+      img = img.convert("RGB")
+
+    return img
+  else:
+    raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")

+ 6 - 6
exo/inference/mlx/test_sharded_llava.py

@@ -39,8 +39,8 @@ y = full.step("full", input_ids, pixel_values, temp=0)
 full_generated_tokens = [y.item()]
 full_generated_tokens = [y.item()]
 
 
 for _ in range(13):
 for _ in range(13):
-    y = full.step("full", y, temp=0)
-    full_generated_tokens.append(y.item())
+  y = full.step("full", y, temp=0)
+  full_generated_tokens.append(y.item())
 
 
 full_response = full_processor.tokenizer.decode(full_generated_tokens)
 full_response = full_processor.tokenizer.decode(full_generated_tokens)
 print("full response:", full_response)
 print("full response:", full_response)
@@ -54,11 +54,11 @@ y = m2.step("shard", y, temp=0)
 full_generated_tokens = [y.item()]
 full_generated_tokens = [y.item()]
 
 
 for _ in range(13):
 for _ in range(13):
-    y = m1.step("shard", y, temp=0)
-    y = m2.step("shard", y, temp=0)
-    full_generated_tokens.append(y.item())
+  y = m1.step("shard", y, temp=0)
+  y = m2.step("shard", y, temp=0)
+  full_generated_tokens.append(y.item())
 
 
 sharded_response = processor2.tokenizer.decode(full_generated_tokens)
 sharded_response = processor2.tokenizer.decode(full_generated_tokens)
 print("sharded response:", sharded_response)
 print("sharded response:", sharded_response)
 
 
-assert full_response == sharded_response
+assert full_response == sharded_response

+ 2 - 1
exo/inference/mlx/test_sharded_model.py

@@ -6,6 +6,7 @@ import numpy as np
 
 
 
 
 class DummyModel(nn.Module):
 class DummyModel(nn.Module):
+
   def __init__(self, shard: Optional[Shard] = None):
   def __init__(self, shard: Optional[Shard] = None):
     self.shard = shard
     self.shard = shard
     self.layers = [
     self.layers = [
@@ -21,7 +22,7 @@ class DummyModel(nn.Module):
 
 
   def __call__(self, x, cache=None):
   def __call__(self, x, cache=None):
     if self.shard:
     if self.shard:
-      for layer in self.layers[self.shard.start_layer : self.shard.end_layer + 1]:
+      for layer in self.layers[self.shard.start_layer:self.shard.end_layer + 1]:
         x = layer(x)
         x = layer(x)
       if self.shard.is_last_layer():
       if self.shard.is_last_layer():
         x = x.reshape((1, 2, 4))
         x = x.reshape((1, 2, 4))

+ 2 - 4
exo/inference/shard.py

@@ -34,8 +34,6 @@ class Shard:
   def overlaps(self, other: 'Shard') -> bool:
   def overlaps(self, other: 'Shard') -> bool:
     return shards_overlap(self, other)
     return shards_overlap(self, other)
 
 
+
 def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
 def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
-  return (
-      shard1.model_id == shard2.model_id
-      and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer)
-  )
+  return (shard1.model_id == shard2.model_id and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer))

+ 13 - 11
exo/inference/test_inference_engine.py

@@ -7,6 +7,7 @@ import os
 import asyncio
 import asyncio
 import numpy as np
 import numpy as np
 
 
+
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
   prompt = "In a single word only, what is the last name of the current president of the USA?"
@@ -22,7 +23,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt)
   resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt)
   resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
   resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
     "B",
     "B",
-    shard=Shard(model_id=model_id, start_layer=pp+1, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
     input_data=resp1,
     input_data=resp1,
     inference_state=inference_state_1,
     inference_state=inference_state_1,
   )
   )
@@ -34,7 +35,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   )
   )
   resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
   resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
     "B",
     "B",
-    shard=Shard(model_id=model_id, start_layer=pp+1, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
     input_data=resp3,
     input_data=resp3,
     inference_state=inference_state_3,
     inference_state=inference_state_3,
   )
   )
@@ -42,21 +43,22 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(resp_full, resp2)
   assert np.array_equal(resp_full, resp2)
   assert np.array_equal(next_resp_full, resp4)
   assert np.array_equal(next_resp_full, resp4)
 
 
-asyncio.run(
-  test_inference_engine(
-    MLXDynamicShardInferenceEngine(HFShardDownloader()),
-    MLXDynamicShardInferenceEngine(HFShardDownloader()),
-    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-  )
-)
+
+asyncio.run(test_inference_engine(
+  MLXDynamicShardInferenceEngine(HFShardDownloader()),
+  MLXDynamicShardInferenceEngine(HFShardDownloader()),
+  "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+))
 
 
 if os.getenv("RUN_TINYGRAD", default="0") == "1":
 if os.getenv("RUN_TINYGRAD", default="0") == "1":
   import tinygrad
   import tinygrad
   import os
   import os
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-  asyncio.run(test_inference_engine(
+  asyncio.run(
+    test_inference_engine(
       TinygradDynamicShardInferenceEngine(HFShardDownloader()),
       TinygradDynamicShardInferenceEngine(HFShardDownloader()),
       TinygradDynamicShardInferenceEngine(HFShardDownloader()),
       TinygradDynamicShardInferenceEngine(HFShardDownloader()),
       "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
       "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
-  ))
+    )
+  )

+ 8 - 11
exo/inference/tinygrad/inference.py

@@ -20,16 +20,11 @@ TOP_P = 0.9
 ALPHA_F = 0.1
 ALPHA_F = 0.1
 ALPHA_P = 0.0
 ALPHA_P = 0.0
 MODEL_PARAMS = {
 MODEL_PARAMS = {
-  "8B": {
-    "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
-    "files": 1
-  },
-  "70B": {
-    "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 28672},
-    "files": 8
-  }
+  "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
+  "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
 }
 }
 
 
+
 def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
 def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
   # build model
   # build model
   linear = nn.Linear
   linear = nn.Linear
@@ -48,10 +43,12 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
 
 
   with Context(BEAM=0):
   with Context(BEAM=0):
     # replace weights in model
     # replace weights in model
-    load_state_dict(model, weights, strict=False, consume=False) # consume=True
+    load_state_dict(model, weights, strict=False, consume=False)  # consume=True
   return model
   return model
 
 
+
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
+
   def __init__(self, shard_downloader: ShardDownloader):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard = None
     self.shard_downloader = shard_downloader
     self.shard_downloader = shard_downloader
@@ -64,7 +61,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     toks = self.tokenizer.encode(prompt)
     toks = self.tokenizer.encode(prompt)
     h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()
     h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()
 
 
-    if h.shape == (1,):
+    if h.shape == (1, ):
       start_pos += len(toks)
       start_pos += len(toks)
       start_pos += 1
       start_pos += 1
       n_captured_toks = 0
       n_captured_toks = 0
@@ -80,7 +77,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
 
 
     h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()
     h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()
 
 
-    if h.shape == (1,):
+    if h.shape == (1, ):
       start_pos += n_captured_toks
       start_pos += n_captured_toks
       start_pos += 1
       start_pos += 1
       n_captured_toks = 0
       n_captured_toks = 0

+ 73 - 34
exo/inference/tinygrad/models/llama.py

@@ -2,21 +2,24 @@ from typing import Tuple, Union, Optional, Dict, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 from tinygrad.helpers import getenv
 
 
+
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
-  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
+  freqs = 1.0 / (theta**(Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
   # TODO: move dtype outside this
-  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim//2, 2)
+  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
+
 
 
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 def complex_mult(A, c, d):
 def complex_mult(A, c, d):
-  a,b = A[..., 0:1], A[..., 1:2]
-  ro = a*c - b*d
-  co = a*d + b*c
+  a, b = A[..., 0:1], A[..., 1:2]
+  ro = a * c - b * d
+  co = a * d + b * c
   return ro.cat(co, dim=-1)
   return ro.cat(co, dim=-1)
 
 
-def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
+
+def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
   assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
@@ -26,16 +29,19 @@ def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Te
   xk_out = complex_mult(xk, c, d)
   xk_out = complex_mult(xk, c, d)
   return xq_out.flatten(3), xk_out.flatten(3)
   return xq_out.flatten(3), xk_out.flatten(3)
 
 
-def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
+
+def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
   bs, seqlen, n_kv_heads, head_dim = x.shape
   bs, seqlen, n_kv_heads, head_dim = x.shape
   if n_rep == 1: return x
   if n_rep == 1: return x
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
 
 
+
 class Attention:
 class Attention:
+
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
     self.n_heads = n_heads
-    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
+    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
     self.head_dim = dim // n_heads
     self.head_dim = dim // n_heads
     self.n_rep = self.n_heads // self.n_kv_heads
     self.n_rep = self.n_heads // self.n_kv_heads
     self.max_context = max_context
     self.max_context = max_context
@@ -45,7 +51,7 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
     self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
 
 
-  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
     if getenv("WQKV"):
     if getenv("WQKV"):
       if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
       if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
       xqkv = x @ self.wqkv.T
       xqkv = x @ self.wqkv.T
@@ -69,10 +75,10 @@ class Attention:
 
 
     # update the cache
     # update the cache
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
+    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
 
 
-    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
+    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
+    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
 
 
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -80,26 +86,31 @@ class Attention:
     attn = attn.reshape(bsz, seqlen, -1)
     attn = attn.reshape(bsz, seqlen, -1)
     return self.wo(attn)
     return self.wo(attn)
 
 
+
 class FeedForward:
 class FeedForward:
-  def __init__(self, dim:int, hidden_dim:int, linear=nn.Linear):
+
+  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
     self.w1 = linear(dim, hidden_dim, bias=False)
     self.w1 = linear(dim, hidden_dim, bias=False)
     self.w2 = linear(hidden_dim, dim, bias=False)
     self.w2 = linear(hidden_dim, dim, bias=False)
-    self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
+    self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit
+
+  def __call__(self, x: Tensor) -> Tensor:
+    return self.w2(self.w1(x).silu() * self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]
 
 
-  def __call__(self, x:Tensor) -> Tensor:
-    return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
 
 
 class TransformerBlock:
 class TransformerBlock:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
+
+  def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int, linear=nn.Linear, feed_forward=FeedForward):
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
 
 
-  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
 
 
+
 # standard openai sampling
 # standard openai sampling
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert logits.ndim == 1, "only works on 1d tensors"
   assert logits.ndim == 1, "only works on 1d tensors"
@@ -127,8 +138,8 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
     output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
     output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
     for i in range(k):
     for i in range(k):
       t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
       t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
-      output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
-      output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
+      output = output + t_max.unsqueeze(0).pad(((i, k - i - 1), ))
+      output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1), ))
       t = (counter == t_argmax).where(0, t)
       t = (counter == t_argmax).where(0, t)
 
 
     # approximate top p
     # approximate top p
@@ -149,10 +160,28 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
 
 
   return output_token
   return output_token
 
 
+
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 
 
+
 class Transformer:
 class Transformer:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, shard: Shard=None, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
+
+  def __init__(
+    self,
+    dim: int,
+    hidden_dim: int,
+    n_heads: int,
+    n_layers: int,
+    norm_eps: float,
+    vocab_size,
+    shard: Shard = None,
+    linear=nn.Linear,
+    n_kv_heads=None,
+    rope_theta=10000,
+    max_context=1024,
+    jit=True,
+    feed_forward=FeedForward
+  ):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
@@ -162,10 +191,10 @@ class Transformer:
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
     self.shard = shard
 
 
-  def forward(self, x:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
+  def forward(self, x: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
     seqlen = x.shape[1]
     seqlen = x.shape[1]
-    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
-    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos+1).realize() if seqlen > 1 else None
+    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
+    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
 
 
     if self.shard.is_first_layer():
     if self.shard.is_first_layer():
       h = self.tok_embeddings(x)
       h = self.tok_embeddings(x)
@@ -182,24 +211,33 @@ class Transformer:
     else:
     else:
       return h
       return h
 
 
-  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
+  def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
     # TODO: better way to handle the first call v.s. the rest?
     # TODO: better way to handle the first call v.s. the rest?
-    if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
+    if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
       return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
       return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
     return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
     return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
 
 
+
 # *** helpers ***
 # *** helpers ***
 
 
-def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
+
+def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
+
   def permute(v: Tensor, n_heads: int):
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
 
 
   keymap = {
   keymap = {
     "model.embed_tokens.weight": "tok_embeddings.weight",
     "model.embed_tokens.weight": "tok_embeddings.weight",
-    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
-    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
+    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
+       for l in range(len(model.layers))},
+    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
+       for x in ["q", "k", "v", "o"]
+       for l in range(len(model.layers))},
+    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
+       for l in range(len(model.layers))},
+    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
+       for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
+       for l in range(len(model.layers))},
     "model.norm.weight": "norm.weight",
     "model.norm.weight": "norm.weight",
     "lm_head.weight": "output.weight",
     "lm_head.weight": "output.weight",
   }
   }
@@ -215,9 +253,10 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
     sd[keymap[k]] = v
     sd[keymap[k]] = v
   return sd
   return sd
 
 
-def fix_bf16(weights:Dict[Any, Tensor]):
+
+def fix_bf16(weights: Dict[Any, Tensor]):
   if getenv("SUPPORT_BF16", 1):
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
     # TODO: without casting to float16, 70B llama OOM on tinybox.
-    return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
   # TODO: check if device supports bf16
   # TODO: check if device supports bf16
-  return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}

+ 7 - 2
exo/inference/tinygrad/tinygrad_helpers.py

@@ -8,8 +8,10 @@ from exo.helpers import DEBUG
 from exo.download.hf.hf_helpers import get_allow_patterns
 from exo.download.hf.hf_helpers import get_allow_patterns
 from fnmatch import fnmatch
 from fnmatch import fnmatch
 
 
+
 # **** helper functions ****
 # **** helper functions ****
 def concat_weights(models, device=None):
 def concat_weights(models, device=None):
+
   def convert(name) -> Tensor:
   def convert(name) -> Tensor:
     disk_tensors: List[Tensor] = [model[name] for model in models]
     disk_tensors: List[Tensor] = [model[name] for model in models]
     if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
     if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
@@ -17,11 +19,14 @@ def concat_weights(models, device=None):
     axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
     axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
     lazy_tensors = [data.to(device=device) for data in disk_tensors]
     lazy_tensors = [data.to(device=device) for data in disk_tensors]
     return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
     return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
+
   return {name: convert(name) for name in {name: None for model in models for name in model}}
   return {name: convert(name) for name in {name: None for model in models for name in model}}
 
 
-def load(fn:str, shard: Shard):
+
+def load(fn: str, shard: Shard):
   if fn.endswith('.index.json'):
   if fn.endswith('.index.json'):
-    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
+    with open(fn) as fp:
+      weight_map = json.load(fp)['weight_map']
     parts = {}
     parts = {}
     filtered_weight_map = {}
     filtered_weight_map = {}
     allow_patterns = get_allow_patterns(weight_map, shard)
     allow_patterns = get_allow_patterns(weight_map, shard)

+ 1 - 0
exo/inference/tokenizers.py

@@ -2,6 +2,7 @@ import traceback
 from transformers import AutoTokenizer, AutoProcessor
 from transformers import AutoTokenizer, AutoProcessor
 from exo.helpers import DEBUG
 from exo.helpers import DEBUG
 
 
+
 async def resolve_tokenizer(model_id: str):
 async def resolve_tokenizer(model_id: str):
   try:
   try:
     if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
     if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")

+ 25 - 32
exo/models.py

@@ -2,39 +2,32 @@ from exo.inference.shard import Shard
 
 
 model_base_shards = {
 model_base_shards = {
   ### llama
   ### llama
-  "llama-3.1-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
-  },
-  "llama-3.1-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
-  },
-  "llama-3.1-405b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
-  },
-  "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
-  },
-  "llama-3-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
-  },
+  "llama-3.1-8b":
+    {
+      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+      "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
+    },
+  "llama-3.1-70b":
+    {
+      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+      "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
+    },
+  "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126), },
+  "llama-3-8b":
+    {
+      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+      "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
+    },
+  "llama-3-70b":
+    {
+      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+      "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
+    },
   ### mistral
   ### mistral
-  "mistral-nemo": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
-  },
-  "mistral-large": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
-  },
+  "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40), },
+  "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88), },
   ### deepseek v2
   ### deepseek v2
-  "deepseek-coder-v2-lite": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
-  },
+  "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27), },
   ### llava
   ### llava
-  "llava-1.5-7b-hf": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
-  },
+  "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32), },
 }
 }
-

+ 1 - 0
exo/networking/discovery.py

@@ -4,6 +4,7 @@ from .peer_handle import PeerHandle
 
 
 
 
 class Discovery(ABC):
 class Discovery(ABC):
+
   @abstractmethod
   @abstractmethod
   async def start(self) -> None:
   async def start(self) -> None:
     pass
     pass

+ 11 - 11
exo/networking/grpc/grpc_discovery.py

@@ -11,6 +11,7 @@ from exo import DEBUG_DISCOVERY
 
 
 
 
 class ListenProtocol(asyncio.DatagramProtocol):
 class ListenProtocol(asyncio.DatagramProtocol):
+
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
     super().__init__()
     super().__init__()
     self.on_message = on_message
     self.on_message = on_message
@@ -24,6 +25,7 @@ class ListenProtocol(asyncio.DatagramProtocol):
 
 
 
 
 class GRPCDiscovery(Discovery):
 class GRPCDiscovery(Discovery):
+
   def __init__(
   def __init__(
     self,
     self,
     node_id: str,
     node_id: str,
@@ -97,14 +99,12 @@ class GRPCDiscovery(Discovery):
     sock = transport.get_extra_info("socket")
     sock = transport.get_extra_info("socket")
     sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
     sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
 
 
-    message = json.dumps(
-      {
-        "type": "discovery",
-        "node_id": self.node_id,
-        "grpc_port": self.node_port,
-        "device_capabilities": self.device_capabilities.to_dict(),
-      }
-    ).encode("utf-8")
+    message = json.dumps({
+      "type": "discovery",
+      "node_id": self.node_id,
+      "grpc_port": self.node_port,
+      "device_capabilities": self.device_capabilities.to_dict(),
+    }).encode("utf-8")
 
 
     while True:
     while True:
       try:
       try:
@@ -166,14 +166,14 @@ class GRPCDiscovery(Discovery):
       try:
       try:
         current_time = time.time()
         current_time = time.time()
         peers_to_remove = [
         peers_to_remove = [
-          peer_handle.id()
-          for peer_handle, connected_at, last_seen in self.known_peers.values()
+          peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
           if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
           if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
         ]
         ]
         if DEBUG_DISCOVERY >= 2:
         if DEBUG_DISCOVERY >= 2:
           print(
           print(
             "Peer statuses:",
             "Peer statuses:",
-            {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()},
+            {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}"
+             for peer_handle, connected_at, last_seen in self.known_peers.values()},
           )
           )
         if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0:
         if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0:
           print(f"Cleaning up peers: {peers_to_remove}")
           print(f"Cleaning up peers: {peers_to_remove}")

+ 1 - 0
exo/networking/grpc/grpc_peer_handle.py

@@ -13,6 +13,7 @@ from exo.topology.device_capabilities import DeviceCapabilities
 
 
 
 
 class GRPCPeerHandle(PeerHandle):
 class GRPCPeerHandle(PeerHandle):
+
   def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
   def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
     self._id = _id
     self._id = _id
     self.address = address
     self.address = address

+ 9 - 9
exo/networking/grpc/grpc_server.py

@@ -11,6 +11,7 @@ from exo.orchestration import Node
 
 
 
 
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
+
   def __init__(self, node: Node, host: str, port: int):
   def __init__(self, node: Node, host: str, port: int):
     self.node = node
     self.node = node
     self.host = host
     self.host = host
@@ -81,9 +82,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       node_service_pb2.InferenceResult(
       node_service_pb2.InferenceResult(
         tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)),
         tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)),
         is_finished=result[1],
         is_finished=result[1],
-      )
-      if result[0] is not None
-      else node_service_pb2.InferenceResult(is_finished=result[1])
+      ) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1])
     )
     )
 
 
   async def CollectTopology(self, request, context):
   async def CollectTopology(self, request, context):
@@ -91,12 +90,13 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     visited = set(request.visited)
     visited = set(request.visited)
     topology = await self.node.collect_topology(visited, max_depth)
     topology = await self.node.collect_topology(visited, max_depth)
     nodes = {
     nodes = {
-      node_id: node_service_pb2.DeviceCapabilities(
-        model=cap.model,
-        chip=cap.chip,
-        memory=cap.memory,
-        flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
-      )
+      node_id:
+        node_service_pb2.DeviceCapabilities(
+          model=cap.model,
+          chip=cap.chip,
+          memory=cap.memory,
+          flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
+        )
       for node_id, cap in topology.nodes.items()
       for node_id, cap in topology.nodes.items()
     }
     }
     peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
     peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}

The file diff has been suppressed because it is too large
+ 0 - 3
exo/networking/grpc/node_service_pb2.py


+ 229 - 271
exo/networking/grpc/node_service_pb2_grpc.py

@@ -12,306 +12,264 @@ SCHEDULED_RELEASE_DATE = 'June 25, 2024'
 _version_not_supported = False
 _version_not_supported = False
 
 
 try:
 try:
-    from grpc._utilities import first_version_is_lower
-    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+  from grpc._utilities import first_version_is_lower
+  _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
 except ImportError:
 except ImportError:
-    _version_not_supported = True
+  _version_not_supported = True
 
 
 if _version_not_supported:
 if _version_not_supported:
-    warnings.warn(
-        f'The grpc package installed is at version {GRPC_VERSION},'
-        + f' but the generated code in node_service_pb2_grpc.py depends on'
-        + f' grpcio>={GRPC_GENERATED_VERSION}.'
-        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
-        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
-        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
-        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
-        RuntimeWarning
-    )
+  warnings.warn(
+    f'The grpc package installed is at version {GRPC_VERSION},' + f' but the generated code in node_service_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' +
+    f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' +
+    f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', RuntimeWarning
+  )
 
 
 
 
 class NodeServiceStub(object):
 class NodeServiceStub(object):
-    """Missing associated documentation comment in .proto file."""
+  """Missing associated documentation comment in .proto file."""
 
 
-    def __init__(self, channel):
-        """Constructor.
+  def __init__(self, channel):
+    """Constructor.
 
 
         Args:
         Args:
             channel: A grpc.Channel.
             channel: A grpc.Channel.
         """
         """
-        self.SendPrompt = channel.unary_unary(
-                '/node_service.NodeService/SendPrompt',
-                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
-                _registered_method=True)
-        self.SendTensor = channel.unary_unary(
-                '/node_service.NodeService/SendTensor',
-                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
-                _registered_method=True)
-        self.GetInferenceResult = channel.unary_unary(
-                '/node_service.NodeService/GetInferenceResult',
-                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.InferenceResult.FromString,
-                _registered_method=True)
-        self.CollectTopology = channel.unary_unary(
-                '/node_service.NodeService/CollectTopology',
-                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Topology.FromString,
-                _registered_method=True)
-        self.SendResult = channel.unary_unary(
-                '/node_service.NodeService/SendResult',
-                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
-        self.SendOpaqueStatus = channel.unary_unary(
-                '/node_service.NodeService/SendOpaqueStatus',
-                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
+    self.SendPrompt = channel.unary_unary(
+      '/node_service.NodeService/SendPrompt',
+      request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Tensor.FromString,
+      _registered_method=True
+    )
+    self.SendTensor = channel.unary_unary(
+      '/node_service.NodeService/SendTensor',
+      request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Tensor.FromString,
+      _registered_method=True
+    )
+    self.GetInferenceResult = channel.unary_unary(
+      '/node_service.NodeService/GetInferenceResult',
+      request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+      response_deserializer=node__service__pb2.InferenceResult.FromString,
+      _registered_method=True
+    )
+    self.CollectTopology = channel.unary_unary(
+      '/node_service.NodeService/CollectTopology',
+      request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Topology.FromString,
+      _registered_method=True
+    )
+    self.SendResult = channel.unary_unary(
+      '/node_service.NodeService/SendResult',
+      request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Empty.FromString,
+      _registered_method=True
+    )
+    self.SendOpaqueStatus = channel.unary_unary(
+      '/node_service.NodeService/SendOpaqueStatus',
+      request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Empty.FromString,
+      _registered_method=True
+    )
 
 
 
 
 class NodeServiceServicer(object):
 class NodeServiceServicer(object):
-    """Missing associated documentation comment in .proto file."""
+  """Missing associated documentation comment in .proto file."""
 
 
-    def SendPrompt(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendPrompt(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
-    def SendTensor(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendTensor(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
-    def GetInferenceResult(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def GetInferenceResult(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
-    def CollectTopology(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def CollectTopology(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
-    def SendResult(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendResult(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
-    def SendOpaqueStatus(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendOpaqueStatus(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
 
 
 def add_NodeServiceServicer_to_server(servicer, server):
 def add_NodeServiceServicer_to_server(servicer, server):
-    rpc_method_handlers = {
-            'SendPrompt': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendPrompt,
-                    request_deserializer=node__service__pb2.PromptRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
-            ),
-            'SendTensor': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendTensor,
-                    request_deserializer=node__service__pb2.TensorRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
-            ),
-            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.GetInferenceResult,
-                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
-                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
-            ),
-            'CollectTopology': grpc.unary_unary_rpc_method_handler(
-                    servicer.CollectTopology,
-                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
-                    response_serializer=node__service__pb2.Topology.SerializeToString,
-            ),
-            'SendResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendResult,
-                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
-            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendOpaqueStatus,
-                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
-    }
-    generic_handler = grpc.method_handlers_generic_handler(
-            'node_service.NodeService', rpc_method_handlers)
-    server.add_generic_rpc_handlers((generic_handler,))
-    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
+  rpc_method_handlers = {
+    'SendPrompt':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendPrompt,
+        request_deserializer=node__service__pb2.PromptRequest.FromString,
+        response_serializer=node__service__pb2.Tensor.SerializeToString,
+      ),
+    'SendTensor':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendTensor,
+        request_deserializer=node__service__pb2.TensorRequest.FromString,
+        response_serializer=node__service__pb2.Tensor.SerializeToString,
+      ),
+    'GetInferenceResult':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.GetInferenceResult,
+        request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+        response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+      ),
+    'CollectTopology':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.CollectTopology,
+        request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+        response_serializer=node__service__pb2.Topology.SerializeToString,
+      ),
+    'SendResult':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendResult,
+        request_deserializer=node__service__pb2.SendResultRequest.FromString,
+        response_serializer=node__service__pb2.Empty.SerializeToString,
+      ),
+    'SendOpaqueStatus':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendOpaqueStatus,
+        request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+        response_serializer=node__service__pb2.Empty.SerializeToString,
+      ),
+  }
+  generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers)
+  server.add_generic_rpc_handlers((generic_handler, ))
+  server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
 
 
 
 
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
 class NodeService(object):
 class NodeService(object):
-    """Missing associated documentation comment in .proto file."""
+  """Missing associated documentation comment in .proto file."""
 
 
-    @staticmethod
-    def SendPrompt(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendPrompt',
-            node__service__pb2.PromptRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendPrompt',
+      node__service__pb2.PromptRequest.SerializeToString,
+      node__service__pb2.Tensor.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
 
-    @staticmethod
-    def SendTensor(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendTensor',
-            node__service__pb2.TensorRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendTensor(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendTensor',
+      node__service__pb2.TensorRequest.SerializeToString,
+      node__service__pb2.Tensor.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
 
-    @staticmethod
-    def GetInferenceResult(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/GetInferenceResult',
-            node__service__pb2.GetInferenceResultRequest.SerializeToString,
-            node__service__pb2.InferenceResult.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def GetInferenceResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/GetInferenceResult',
+      node__service__pb2.GetInferenceResultRequest.SerializeToString,
+      node__service__pb2.InferenceResult.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
 
-    @staticmethod
-    def CollectTopology(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/CollectTopology',
-            node__service__pb2.CollectTopologyRequest.SerializeToString,
-            node__service__pb2.Topology.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def CollectTopology(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/CollectTopology',
+      node__service__pb2.CollectTopologyRequest.SerializeToString,
+      node__service__pb2.Topology.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
 
-    @staticmethod
-    def SendResult(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendResult',
-            node__service__pb2.SendResultRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendResult',
+      node__service__pb2.SendResultRequest.SerializeToString,
+      node__service__pb2.Empty.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
 
-    @staticmethod
-    def SendOpaqueStatus(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendOpaqueStatus',
-            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendOpaqueStatus(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendOpaqueStatus',
+      node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+      node__service__pb2.Empty.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
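
The hunks above only re-indent generated code; node_service_pb2_grpc.py is normally produced by grpcio-tools rather than edited by hand. A minimal sketch of regenerating it, assuming the proto lives at exo/networking/grpc/node_service.proto (the path is an assumption inferred from the generated module names); regenerating also keeps GRPC_GENERATED_VERSION in sync with the installed grpcio and avoids the version warning shown above:

# Hedged sketch: regenerate the pb2/pb2_grpc modules with grpc_tools
# (pip install grpcio-tools). The proto path below is an assumption.
from grpc_tools import protoc

protoc.main([
  "grpc_tools.protoc",
  "-Iexo/networking/grpc",
  "--python_out=exo/networking/grpc",
  "--grpc_python_out=exo/networking/grpc",
  "node_service.proto",
])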

+ 1 - 0
exo/networking/grpc/test_grpc_discovery.py

@@ -4,6 +4,7 @@ from .grpc_discovery import GRPCDiscovery
 
 
 
 
 class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase):
 class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase):
+
   async def asyncSetUp(self):
   async def asyncSetUp(self):
     self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
     self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
     self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)
     self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)

+ 1 - 0
exo/networking/peer_handle.py

@@ -7,6 +7,7 @@ from exo.topology.topology import Topology
 
 
 
 
 class PeerHandle(ABC):
 class PeerHandle(ABC):
+
   @abstractmethod
   @abstractmethod
   def id(self) -> str:
   def id(self) -> str:
     pass
     pass

+ 1 - 0
exo/networking/server.py

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
 
 
 
 
 class Server(ABC):
 class Server(ABC):
+
   @abstractmethod
   @abstractmethod
   async def start(self) -> None:
   async def start(self) -> None:
     pass
     pass

+ 1 - 0
exo/orchestration/node.py

@@ -7,6 +7,7 @@ from exo.topology.topology import Topology
 
 
 
 
 class Node(ABC):
 class Node(ABC):
+
   @abstractmethod
   @abstractmethod
   async def start(self, wait_for_peers: int = 0) -> None:
   async def start(self, wait_for_peers: int = 0) -> None:
     pass
     pass

+ 3 - 0
exo/orchestration/standard_node.py

@@ -18,6 +18,7 @@ from exo.download.hf.hf_helpers import RepoProgressEvent
 
 
 
 
 class StandardNode(Node):
 class StandardNode(Node):
+
   def __init__(
   def __init__(
     self,
     self,
     _id: str,
     _id: str,
@@ -359,6 +360,7 @@ class StandardNode(Node):
     self.on_token.trigger_all(request_id, tokens, is_finished)
     self.on_token.trigger_all(request_id, tokens, is_finished)
 
 
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+
     async def send_result_to_peer(peer):
     async def send_result_to_peer(peer):
       try:
       try:
         await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
         await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
@@ -372,6 +374,7 @@ class StandardNode(Node):
 
 
   async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
   async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
     if DEBUG >= 5: print(f"Broadcasting opaque status: {request_id=} {status=}")
     if DEBUG >= 5: print(f"Broadcasting opaque status: {request_id=} {status=}")
+
     async def send_status_to_peer(peer):
     async def send_status_to_peer(peer):
       try:
       try:
         await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
         await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
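
Both broadcast helpers above follow the same fan-out pattern: a nested per-peer coroutine guarded by a 15-second timeout. A minimal, self-contained sketch of that pattern; the peer objects and the concurrent gather are stand-ins, since the hunk only shows the per-peer send:

import asyncio

async def broadcast(peers, send, timeout=15.0):
  # Fan out one guarded send per peer and run them concurrently.
  async def send_to_peer(peer):
    try:
      await asyncio.wait_for(send(peer), timeout=timeout)
    except asyncio.TimeoutError:
      print(f"Timeout sending to peer {peer}")

  await asyncio.gather(*(send_to_peer(p) for p in peers), return_exceptions=True)

# Usage sketch: await broadcast(node.peers, lambda p: p.send_result(request_id, result, is_finished))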

+ 1 - 0
exo/orchestration/test_node.py

@@ -7,6 +7,7 @@ from exo.networking.peer_handle import PeerHandle
 
 
 
 
 class TestNode(unittest.IsolatedAsyncioTestCase):
 class TestNode(unittest.IsolatedAsyncioTestCase):
+
   def setUp(self):
   def setUp(self):
     self.mock_inference_engine = AsyncMock()
     self.mock_inference_engine = AsyncMock()
     self.mock_server = AsyncMock()
     self.mock_server = AsyncMock()

+ 1 - 0
exo/test_callbacks.py

@@ -12,6 +12,7 @@ async def main() -> None:
   callback2 = callback_system.register("callback2")
   callback2 = callback_system.register("callback2")
 
 
   def on_next_callback(name: str) -> Callable[..., None]:
   def on_next_callback(name: str) -> Callable[..., None]:
+
     def callback(*args: Any) -> None:
     def callback(*args: Any) -> None:
       print(f"{name} received values: {args}")
       print(f"{name} received values: {args}")
 
 

+ 1 - 0
exo/topology/partitioning_strategy.py

@@ -14,6 +14,7 @@ class Partition:
 
 
 
 
 class PartitioningStrategy(ABC):
 class PartitioningStrategy(ABC):
+
   @abstractmethod
   @abstractmethod
   def partition(self, topology: Topology) -> List[Partition]:
   def partition(self, topology: Topology) -> List[Partition]:
     pass
     pass

+ 1 - 0
exo/topology/ring_memory_weighted_partitioning_strategy.py

@@ -5,6 +5,7 @@ from .partitioning_strategy import Partition
 
 
 
 
 class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
 class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
+
   def partition(self, topology: Topology) -> List[Partition]:
   def partition(self, topology: Topology) -> List[Partition]:
     nodes = list(topology.all_nodes())
     nodes = list(topology.all_nodes())
     nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True)
     nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True)
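
For context, the sort above orders nodes by memory (descending, with node id as a tie-break) before partitions are assigned. A hedged sketch of what memory-weighted partitioning computes, consistent with the cumulative [start, end) fractions in the tests further down; the actual implementation may differ in rounding and details:

from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class Partition:
  node_id: str
  start: float
  end: float

def partition_by_memory(sorted_nodes: List[Tuple[str, int]]) -> List[Partition]:
  # sorted_nodes: (node_id, memory) pairs, already sorted by (memory, id) descending.
  total = sum(mem for _, mem in sorted_nodes)
  partitions, start = [], 0.0
  for node_id, mem in sorted_nodes:
    end = round(start + mem/total, 5)
    partitions.append(Partition(node_id, start, end))
    start = end
  return partitions

# A 3000/2000/2000 memory split gives node1 roughly [0.0, 0.42857) of the layers.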

+ 1 - 0
exo/topology/test_device_capabilities.py

@@ -4,6 +4,7 @@ from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapa
 
 
 
 
 class TestMacDeviceCapabilities(unittest.TestCase):
 class TestMacDeviceCapabilities(unittest.TestCase):
+
   @patch("subprocess.check_output")
   @patch("subprocess.check_output")
   def test_mac_device_capabilities_pro(self, mock_check_output):
   def test_mac_device_capabilities_pro(self, mock_check_output):
     # Mock the subprocess output
     # Mock the subprocess output

+ 1 - 0
exo/topology/test_map_partitions.py

@@ -5,6 +5,7 @@ from exo.inference.shard import Shard
 
 
 
 
 class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
 class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
+
   def test_map_partitions_to_shards(self):
   def test_map_partitions_to_shards(self):
     partitions = [
     partitions = [
       Partition("node1", 0.0, 0.42857),
       Partition("node1", 0.0, 0.42857),

+ 1 - 0
exo/topology/test_ring_memory_weighted_partitioning_strategy.py

@@ -6,6 +6,7 @@ from exo.topology.partitioning_strategy import Partition
 
 
 
 
 class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
 class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
+
   def test_partition(self):
   def test_partition(self):
     # triangle
     # triangle
     # node1 -> node2 -> node3 -> node1
     # node1 -> node2 -> node3 -> node1

+ 1 - 0
exo/topology/topology.py

@@ -3,6 +3,7 @@ from typing import Dict, Set, Optional
 
 
 
 
 class Topology:
 class Topology:
+
   def __init__(self):
   def __init__(self):
     self.nodes: Dict[str, DeviceCapabilities] = {}  # Maps node IDs to DeviceCapabilities
     self.nodes: Dict[str, DeviceCapabilities] = {}  # Maps node IDs to DeviceCapabilities
     self.peer_graph: Dict[str, Set[str]] = {}  # Adjacency list representing the graph
     self.peer_graph: Dict[str, Set[str]] = {}  # Adjacency list representing the graph

+ 49 - 47
exo/viz/test_topology_viz.py

@@ -9,58 +9,60 @@ from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
 
 
 
 
 def create_hf_repo_progress_event(
 def create_hf_repo_progress_event(
-    completed_files: int = 5,
-    total_files: int = 10,
-    downloaded_bytes: int = 500000000,
-    downloaded_bytes_this_session: int = 250000000,
-    total_bytes: int = 1000000000,
-    overall_speed: int = 5000000,
-    overall_eta: timedelta = timedelta(seconds=100),
-    file_progress: dict = None,
-    status: str = "in_progress"
+  completed_files: int = 5,
+  total_files: int = 10,
+  downloaded_bytes: int = 500000000,
+  downloaded_bytes_this_session: int = 250000000,
+  total_bytes: int = 1000000000,
+  overall_speed: int = 5000000,
+  overall_eta: timedelta = timedelta(seconds=100),
+  file_progress: dict = None,
+  status: str = "in_progress"
 ) -> RepoProgressEvent:
 ) -> RepoProgressEvent:
-    if file_progress is None:
-        file_progress = {
-            "file1.bin": RepoFileProgressEvent(
-                repo_id="repo_id",
-                repo_revision="repo_revision",
-                file_path="file1.bin",
-                downloaded=100000000,
-                downloaded_this_session=50000000,
-                total=200000000,
-                speed=1000000,
-                eta=timedelta(seconds=100),
-                status="in_progress"
-            ),
-            "file2.bin": RepoFileProgressEvent(
-                repo_id="repo_id",
-                repo_revision="repo_revision",
-                file_path="file2.bin",
-                downloaded=200000000,
-                downloaded_this_session=100000000,
-                total=200000000,
-                speed=2000000,
-                eta=timedelta(seconds=0),
-                status="complete"
-            )
-        }
+  if file_progress is None:
+    file_progress = {
+      "file1.bin":
+        RepoFileProgressEvent(
+          repo_id="repo_id",
+          repo_revision="repo_revision",
+          file_path="file1.bin",
+          downloaded=100000000,
+          downloaded_this_session=50000000,
+          total=200000000,
+          speed=1000000,
+          eta=timedelta(seconds=100),
+          status="in_progress"
+        ), "file2.bin":
+          RepoFileProgressEvent(
+            repo_id="repo_id",
+            repo_revision="repo_revision",
+            file_path="file2.bin",
+            downloaded=200000000,
+            downloaded_this_session=100000000,
+            total=200000000,
+            speed=2000000,
+            eta=timedelta(seconds=0),
+            status="complete"
+          )
+    }
 
 
-    return RepoProgressEvent(
-        repo_id="repo_id",
-        repo_revision="repo_revision",
-        completed_files=completed_files,
-        total_files=total_files,
-        downloaded_bytes=downloaded_bytes,
-        downloaded_bytes_this_session=downloaded_bytes_this_session,
-        total_bytes=total_bytes,
-        overall_speed=overall_speed,
-        overall_eta=overall_eta,
-        file_progress=file_progress,
-        status=status
-    )
+  return RepoProgressEvent(
+    repo_id="repo_id",
+    repo_revision="repo_revision",
+    completed_files=completed_files,
+    total_files=total_files,
+    downloaded_bytes=downloaded_bytes,
+    downloaded_bytes_this_session=downloaded_bytes_this_session,
+    total_bytes=total_bytes,
+    overall_speed=overall_speed,
+    overall_eta=overall_eta,
+    file_progress=file_progress,
+    status=status
+  )
 
 
 
 
 class TestNodeViz(unittest.IsolatedAsyncioTestCase):
 class TestNodeViz(unittest.IsolatedAsyncioTestCase):
+
   async def asyncSetUp(self):
   async def asyncSetUp(self):
     self.topology = Topology()
     self.topology = Topology()
     self.topology.update_node(
     self.topology.update_node(

+ 64 - 66
exo/viz/topology_viz.py

@@ -16,7 +16,9 @@ from rich.syntax import Syntax
 from rich.panel import Panel
 from rich.panel import Panel
 from rich.markdown import Markdown
 from rich.markdown import Markdown
 
 
+
 class TopologyViz:
 class TopologyViz:
+
   def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []):
   def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []):
     self.chatgpt_api_endpoints = chatgpt_api_endpoints
     self.chatgpt_api_endpoints = chatgpt_api_endpoints
     self.web_chat_urls = web_chat_urls
     self.web_chat_urls = web_chat_urls
@@ -28,11 +30,7 @@ class TopologyViz:
 
 
     self.console = Console()
     self.console = Console()
     self.layout = Layout()
     self.layout = Layout()
-    self.layout.split(
-      Layout(name="main"),
-      Layout(name="prompt_output", size=15),
-      Layout(name="download", size=25)
-    )
+    self.layout.split(Layout(name="main"), Layout(name="prompt_output", size=15), Layout(name="download", size=25))
     self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
     self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
     self.prompt_output_panel = Panel("", title="Prompt and Output", border_style="green")
     self.prompt_output_panel = Panel("", title="Prompt and Output", border_style="green")
     self.download_panel = Panel("", title="Download Progress", border_style="cyan")
     self.download_panel = Panel("", title="Download Progress", border_style="cyan")
@@ -75,11 +73,11 @@ class TopologyViz:
 
 
     # Update and show/hide prompt and output panel
     # Update and show/hide prompt and output panel
     if any(r[0] or r[1] for r in self.requests.values()):
     if any(r[0] or r[1] for r in self.requests.values()):
-        self.prompt_output_panel = self._generate_prompt_output_layout()
-        self.layout["prompt_output"].update(self.prompt_output_panel)
-        self.layout["prompt_output"].visible = True
+      self.prompt_output_panel = self._generate_prompt_output_layout()
+      self.layout["prompt_output"].update(self.prompt_output_panel)
+      self.layout["prompt_output"].visible = True
     else:
     else:
-        self.layout["prompt_output"].visible = False
+      self.layout["prompt_output"].visible = False
 
 
     # Only show download_panel if there are in-progress downloads
     # Only show download_panel if there are in-progress downloads
     if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
     if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
@@ -97,33 +95,33 @@ class TopologyViz:
     max_lines = 13  # Maximum number of lines for the entire panel content
     max_lines = 13  # Maximum number of lines for the entire panel content
 
 
     for (prompt, output) in reversed(requests):
     for (prompt, output) in reversed(requests):
-        prompt_icon, output_icon = "💬️", "🤖"
-
-        # Process prompt
-        prompt_lines = prompt.split('\n')
-        if len(prompt_lines) > max_lines // 2:
-            prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
-        prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
-        prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
-
-        # Process output
-        output_lines = output.split('\n')
-        remaining_lines = max_lines - len(prompt_lines) - 2  # -2 for spacing
-        if len(output_lines) > remaining_lines:
-            output_lines = output_lines[:remaining_lines - 1] + ['...']
-        output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
-        output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
-
-        content.append(prompt_text)
-        content.append(output_text)
-        content.append(Text())  # Empty line between entries
+      prompt_icon, output_icon = "💬️", "🤖"
+
+      # Process prompt
+      prompt_lines = prompt.split('\n')
+      if len(prompt_lines) > max_lines // 2:
+        prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
+      prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
+      prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
+
+      # Process output
+      output_lines = output.split('\n')
+      remaining_lines = max_lines - len(prompt_lines) - 2  # -2 for spacing
+      if len(output_lines) > remaining_lines:
+        output_lines = output_lines[:remaining_lines - 1] + ['...']
+      output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
+      output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
+
+      content.append(prompt_text)
+      content.append(output_text)
+      content.append(Text())  # Empty line between entries
 
 
     return Panel(
     return Panel(
-        Group(*content),
-        title="",
-        border_style="cyan",
-        height=15,  # Increased height to accommodate multiple lines
-        expand=True  # Allow the panel to expand to full width
+      Group(*content),
+      title="",
+      border_style="cyan",
+      height=15,  # Increased height to accommodate multiple lines
+      expand=True  # Allow the panel to expand to full width
     )
     )
 
 
   def _generate_main_layout(self) -> str:
   def _generate_main_layout(self) -> str:
@@ -185,14 +183,14 @@ class TopologyViz:
       visualization[bar_y][bar_start_x + i] = segment
       visualization[bar_y][bar_start_x + i] = segment
 
 
     # Add labels
     # Add labels
-    visualization[bar_y - 1][bar_start_x - 10 : bar_start_x - 3] = "GPU poor"
-    visualization[bar_y - 1][bar_start_x + bar_width * 2 + 2 : bar_start_x + bar_width * 2 + 11] = "GPU rich"
+    visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = "GPU poor"
+    visualization[bar_y - 1][bar_start_x + bar_width * 2 + 2:bar_start_x + bar_width * 2 + 11] = "GPU rich"
 
 
     # Add position indicator and FLOPS value
     # Add position indicator and FLOPS value
     pos_x = bar_start_x + int(bar_pos * bar_width)
     pos_x = bar_start_x + int(bar_pos * bar_width)
     flops_str = f"{total_flops:.2f} TFLOPS"
     flops_str = f"{total_flops:.2f} TFLOPS"
     visualization[bar_y - 1][pos_x] = "▼"
     visualization[bar_y - 1][pos_x] = "▼"
-    visualization[bar_y + 1][pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
+    visualization[bar_y + 1][pos_x - len(flops_str) // 2:pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
     visualization[bar_y + 2][pos_x] = "▲"
     visualization[bar_y + 2][pos_x] = "▲"
 
 
     # Add an extra empty line for spacing
     # Add an extra empty line for spacing
@@ -270,41 +268,41 @@ class TopologyViz:
 
 
     # Current node download progress
     # Current node download progress
     if self.node_id in self.node_download_progress:
     if self.node_id in self.node_download_progress:
-        download_progress = self.node_download_progress[self.node_id]
-        title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):"
-        summary.add_row(Text(title, style="bold"))
-        progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
-        summary.add_row(progress_info)
+      download_progress = self.node_download_progress[self.node_id]
+      title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):"
+      summary.add_row(Text(title, style="bold"))
+      progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
+      summary.add_row(progress_info)
 
 
-        eta_info = f"{download_progress.overall_eta}"
-        summary.add_row(eta_info)
+      eta_info = f"{download_progress.overall_eta}"
+      summary.add_row(eta_info)
 
 
-        summary.add_row("")  # Empty row for spacing
+      summary.add_row("")  # Empty row for spacing
 
 
-        for file_path, file_progress in download_progress.file_progress.items():
-            if file_progress.status != "complete":
-                progress = int(file_progress.downloaded / file_progress.total * 30)
-                bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
-                percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
-                summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)
+      for file_path, file_progress in download_progress.file_progress.items():
+        if file_progress.status != "complete":
+          progress = int(file_progress.downloaded / file_progress.total * 30)
+          bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
+          percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
+          summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)
 
 
     summary.add_row("")  # Empty row for spacing
     summary.add_row("")  # Empty row for spacing
 
 
     # Other nodes download progress summary
     # Other nodes download progress summary
     summary.add_row(Text("Other Nodes Download Progress:", style="bold"))
     summary.add_row(Text("Other Nodes Download Progress:", style="bold"))
     for node_id, progress in self.node_download_progress.items():
     for node_id, progress in self.node_download_progress.items():
-        if node_id != self.node_id:
-            device = self.topology.nodes.get(node_id)
-            partition = next((p for p in self.partitions if p.node_id == node_id), None)
-            partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
-            percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
-            speed = pretty_print_bytes_per_second(progress.overall_speed)
-            device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
-            progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
-            progress_bar = f"[{'=' * int(percentage // 3.33)}{' ' * (30 - int(percentage // 3.33))}]"
-            percentage_str = f"{percentage:.1f}%"
-            eta_str = f"{progress.overall_eta}"
-            summary.add_row(device_info, progress_info, percentage_str)
-            summary.add_row("", progress_bar, eta_str)
-
-    return summary
+      if node_id != self.node_id:
+        device = self.topology.nodes.get(node_id)
+        partition = next((p for p in self.partitions if p.node_id == node_id), None)
+        partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
+        percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
+        speed = pretty_print_bytes_per_second(progress.overall_speed)
+        device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
+        progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
+        progress_bar = f"[{'=' * int(percentage // 3.33)}{' ' * (30 - int(percentage // 3.33))}]"
+        percentage_str = f"{percentage:.1f}%"
+        eta_str = f"{progress.overall_eta}"
+        summary.add_row(device_info, progress_info, percentage_str)
+        summary.add_row("", progress_bar, eta_str)
+
+    return summary
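
The re-indented TopologyViz code above is built around a three-row rich layout whose lower rows are shown or hidden on demand. A minimal sketch of that structure, with names and sizes taken from the hunk; how the refresh loop is driven is not shown here and is assumed:

from rich.console import Console
from rich.layout import Layout
from rich.panel import Panel

layout = Layout()
layout.split(Layout(name="main"), Layout(name="prompt_output", size=15), Layout(name="download", size=25))
layout["main"].update(Panel("topology goes here", title="Exo Cluster (0 nodes)", border_style="bright_yellow"))
layout["prompt_output"].update(Panel("", title="Prompt and Output", border_style="green"))
layout["download"].update(Panel("", title="Download Progress", border_style="cyan"))
layout["prompt_output"].visible = False  # toggled on when a request has prompt/output
layout["download"].visible = False       # toggled on while a download is in progress
Console().print(layout)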

+ 35 - 37
extra/download_hf.py

@@ -3,51 +3,49 @@ import asyncio
 from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent
 from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent
 
 
 DEFAULT_ALLOW_PATTERNS = [
 DEFAULT_ALLOW_PATTERNS = [
-    "*.json",
-    "*.py",
-    "tokenizer.model",
-    "*.tiktoken",
-    "*.txt",
-    "*.safetensors",
+  "*.json",
+  "*.py",
+  "tokenizer.model",
+  "*.tiktoken",
+  "*.txt",
+  "*.safetensors",
 ]
 ]
 # Always ignore `.git` and `.cache/huggingface` folders in commits
 # Always ignore `.git` and `.cache/huggingface` folders in commits
 DEFAULT_IGNORE_PATTERNS = [
 DEFAULT_IGNORE_PATTERNS = [
-    ".git",
-    ".git/*",
-    "*/.git",
-    "**/.git/**",
-    ".cache/huggingface",
-    ".cache/huggingface/*",
-    "*/.cache/huggingface",
-    "**/.cache/huggingface/**",
+  ".git",
+  ".git/*",
+  "*/.git",
+  "**/.git/**",
+  ".cache/huggingface",
+  ".cache/huggingface/*",
+  "*/.cache/huggingface",
+  "**/.cache/huggingface/**",
 ]
 ]
 
 
+
 async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
 async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
-    async def progress_callback(event: RepoProgressEvent):
-        print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
-        print(f"Estimated time remaining: {event.overall_eta}")
-        print("File Progress:")
-        for file_path, progress in event.file_progress.items():
-            status_icon = {
-                'not_started': '⚪',
-                'in_progress': '🔵',
-                'complete': '✅'
-            }[progress.status]
-            eta_str = str(progress.eta)
-            print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
-                  f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
-        print("\n")
-
-    await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
+
+  async def progress_callback(event: RepoProgressEvent):
+    print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
+    print(f"Estimated time remaining: {event.overall_eta}")
+    print("File Progress:")
+    for file_path, progress in event.file_progress.items():
+      status_icon = {'not_started': '⚪', 'in_progress': '🔵', 'complete': '✅'}[progress.status]
+      eta_str = str(progress.eta)
+      print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
+            f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
+    print("\n")
+
+  await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
-    parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
-    parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
-    parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
-    parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
+  parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
+  parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
+  parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
+  parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
+  parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
 
 
-    args = parser.parse_args()
+  args = parser.parse_args()
 
 
-    asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
+  asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
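
The same download can be driven without the argparse wrapper by awaiting main() directly. A hedged usage sketch: it assumes extra/ is importable from the working directory, and reuses the example repo id from the --repo-id help text above:

import asyncio
from extra.download_hf import main, DEFAULT_ALLOW_PATTERNS, DEFAULT_IGNORE_PATTERNS

# Equivalent CLI: python extra/download_hf.py --repo-id meta-llama/Meta-Llama-3.1-8B-Instruct
asyncio.run(main(
  "meta-llama/Meta-Llama-3.1-8B-Instruct",
  revision="main",
  allow_patterns=DEFAULT_ALLOW_PATTERNS,
  ignore_patterns=DEFAULT_IGNORE_PATTERNS,
))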

+ 113 - 98
main.py

@@ -53,131 +53,146 @@ inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 
 
 if args.node_port is None:
 if args.node_port is None:
-    args.node_port = find_available_port(args.node_host)
-    if DEBUG >= 1: print(f"Using available port: {args.node_port}")
+  args.node_port = find_available_port(args.node_host)
+  if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 
 
 args.node_id = args.node_id or get_or_create_node_id()
 args.node_id = args.node_id or get_or_create_node_id()
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
-chatgpt_api_endpoints=[f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
-web_chat_urls=[f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
+chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
+web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
 if DEBUG >= 0:
 if DEBUG >= 0:
-    print("Chat interface started:")
-    for web_chat_url in web_chat_urls:
-        print(f" - {terminal_link(web_chat_url)}")
-    print("ChatGPT API endpoint served at:")
-    for chatgpt_api_endpoint in chatgpt_api_endpoints:
-        print(f" - {terminal_link(chatgpt_api_endpoint)}")
+  print("Chat interface started:")
+  for web_chat_url in web_chat_urls:
+    print(f" - {terminal_link(web_chat_url)}")
+  print("ChatGPT API endpoint served at:")
+  for chatgpt_api_endpoint in chatgpt_api_endpoints:
+    print(f" - {terminal_link(chatgpt_api_endpoint)}")
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
 node = StandardNode(
-    args.node_id,
-    None,
-    inference_engine,
-    discovery,
-    chatgpt_api_endpoints=chatgpt_api_endpoints,
-    web_chat_urls=web_chat_urls,
-    partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
-    disable_tui=args.disable_tui,
-    max_generate_tokens=args.max_generate_tokens,
-    topology_viz=topology_viz
+  args.node_id,
+  None,
+  inference_engine,
+  discovery,
+  chatgpt_api_endpoints=chatgpt_api_endpoints,
+  web_chat_urls=web_chat_urls,
+  partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
+  disable_tui=args.disable_tui,
+  max_generate_tokens=args.max_generate_tokens,
+  topology_viz=topology_viz
 )
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 node.server = server
-api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs, on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None)
+api = ChatGPTAPI(
+  node,
+  inference_engine.__class__.__name__,
+  response_timeout_secs=args.chatgpt_api_response_timeout_secs,
+  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
+)
 node.on_token.register("update_topology_viz").on_next(
 node.on_token.register("update_topology_viz").on_next(
-    lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens) if topology_viz else None
+  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id,
+                                                               inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens) if topology_viz else None
 )
 )
+
+
 def preemptively_start_download(request_id: str, opaque_status: str):
 def preemptively_start_download(request_id: str, opaque_status: str):
-    try:
-        status = json.loads(opaque_status)
-        if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
-            current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
-            if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-            asyncio.create_task(shard_downloader.ensure_shard(current_shard))
-    except Exception as e:
-        if DEBUG >= 2:
-            print(f"Failed to preemptively start download: {e}")
-            traceback.print_exc()
+  try:
+    status = json.loads(opaque_status)
+    if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
+      current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
+      if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
+      asyncio.create_task(shard_downloader.ensure_shard(current_shard))
+  except Exception as e:
+    if DEBUG >= 2:
+      print(f"Failed to preemptively start download: {e}")
+      traceback.print_exc()
+
+
 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 if args.prometheus_client_port:
 if args.prometheus_client_port:
-    from exo.stats.metrics import start_metrics_server
-    start_metrics_server(node, args.prometheus_client_port)
+  from exo.stats.metrics import start_metrics_server
+  start_metrics_server(node, args.prometheus_client_port)
 
 
 last_broadcast_time = 0
 last_broadcast_time = 0
+
+
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
-    global last_broadcast_time
-    current_time = time.time()
-    if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
-        last_broadcast_time = current_time
-        asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+  global last_broadcast_time
+  current_time = time.time()
+  if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
+    last_broadcast_time = current_time
+    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+
+
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
 
+
 async def shutdown(signal, loop):
 async def shutdown(signal, loop):
-    """Gracefully shutdown the server and close the asyncio loop."""
-    print(f"Received exit signal {signal.name}...")
-    print("Thank you for using exo.")
-    print_yellow_exo()
-    server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-    [task.cancel() for task in server_tasks]
-    print(f"Cancelling {len(server_tasks)} outstanding tasks")
-    await asyncio.gather(*server_tasks, return_exceptions=True)
-    await server.stop()
-    loop.stop()
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+  loop.stop()
+
 
 
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
-    shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
-    if not shard:
-        print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
-        return
-    tokenizer = await resolve_tokenizer(shard.model_id)
-    request_id = str(uuid.uuid4())
-    callback_id = f"cli-wait-response-{request_id}"
-    callback = node.on_token.register(callback_id)
-    if topology_viz:
-        topology_viz.update_prompt(request_id, prompt)
-    prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
-
-    try:
-        print(f"Processing prompt: {prompt}")
-        await node.process_prompt(shard, prompt, None, request_id=request_id)
-
-        _, tokens, _ = await callback.wait(
-            lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
-            timeout=300
-        )
-
-        print("\nGenerated response:")
-        print(tokenizer.decode(tokens))
-    except Exception as e:
-        print(f"Error processing prompt: {str(e)}")
-        traceback.print_exc()
-    finally:
-        node.on_token.deregister(callback_id)
+  shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
+  if not shard:
+    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+    return
+  tokenizer = await resolve_tokenizer(shard.model_id)
+  request_id = str(uuid.uuid4())
+  callback_id = f"cli-wait-response-{request_id}"
+  callback = node.on_token.register(callback_id)
+  if topology_viz:
+    topology_viz.update_prompt(request_id, prompt)
+  prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
+
+  try:
+    print(f"Processing prompt: {prompt}")
+    await node.process_prompt(shard, prompt, None, request_id=request_id)
+
+    _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
+
+    print("\nGenerated response:")
+    print(tokenizer.decode(tokens))
+  except Exception as e:
+    print(f"Error processing prompt: {str(e)}")
+    traceback.print_exc()
+  finally:
+    node.on_token.deregister(callback_id)
+
 
 
 async def main():
 async def main():
-    loop = asyncio.get_running_loop()
+  loop = asyncio.get_running_loop()
+
+  # Use a more direct approach to handle signals
+  def handle_exit():
+    asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
 
 
-    # Use a more direct approach to handle signals
-    def handle_exit():
-        asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
+  for s in [signal.SIGINT, signal.SIGTERM]:
+    loop.add_signal_handler(s, handle_exit)
 
 
-    for s in [signal.SIGINT, signal.SIGTERM]:
-        loop.add_signal_handler(s, handle_exit)
+  await node.start(wait_for_peers=args.wait_for_peers)
 
 
-    await node.start(wait_for_peers=args.wait_for_peers)
+  if args.run_model:
+    await run_model_cli(node, inference_engine, args.run_model, args.prompt)
+  else:
+    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
+    await asyncio.Event().wait()
 
 
-    if args.run_model:
-        await run_model_cli(node, inference_engine, args.run_model, args.prompt)
-    else:
-        asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
-        await asyncio.Event().wait()
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    try:
-        loop.run_until_complete(main())
-    except KeyboardInterrupt:
-        print("Received keyboard interrupt. Shutting down...")
-    finally:
-        loop.run_until_complete(shutdown(signal.SIGTERM, loop))
-        loop.close()
+  loop = asyncio.new_event_loop()
+  asyncio.set_event_loop(loop)
+  try:
+    loop.run_until_complete(main())
+  except KeyboardInterrupt:
+    print("Received keyboard interrupt. Shutting down...")
+  finally:
+    loop.run_until_complete(shutdown(signal.SIGTERM, loop))
+    loop.close()
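
Every hunk in this diff makes the same kind of change: 4-space indentation becomes 2-space indentation and many wrapped calls are re-flowed onto long lines. A minimal sketch of reproducing that with yapf's Python API; the specific style options are assumptions inferred from the output, not a statement of the project's actual configuration:

from yapf.yapflib.yapf_api import FormatCode

source = "def greet(name):\n    print(f'hello {name}')\n"
formatted, changed = FormatCode(
  source,
  style_config={"based_on_style": "pep8", "indent_width": 2, "column_limit": 200},
)
print(formatted)
# def greet(name):
#   print(f'hello {name}')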

Some files were not shown because too many files have changed