Browse Source

reformat with yapf format.py

Alex Cheema 8 tháng trước cách đây
mục cha
commit
ea70c9fb76
48 tập tin đã thay đổi với 1873 bổ sung1854 xóa
  1. 38 45
      examples/llama3_distributed.py
  2. 62 59
      exo/api/chatgpt_api.py
  3. 50 64
      exo/download/download_progress.py
  4. 334 297
      exo/download/hf/hf_helpers.py
  5. 58 60
      exo/download/hf/hf_shard_download.py
  6. 10 8
      exo/download/shard_download.py
  7. 85 74
      exo/helpers.py
  8. 5 7
      exo/inference/debug_inference_engine.py
  9. 2 0
      exo/inference/inference_engine.py
  10. 1 0
      exo/inference/mlx/models/base.py
  11. 4 5
      exo/inference/mlx/models/deepseek_v2.py
  12. 5 3
      exo/inference/mlx/models/llama.py
  13. 521 555
      exo/inference/mlx/models/llava.py
  14. 1 0
      exo/inference/mlx/sharded_inference_engine.py
  15. 4 9
      exo/inference/mlx/sharded_model.py
  16. 29 25
      exo/inference/mlx/sharded_utils.py
  17. 6 6
      exo/inference/mlx/test_sharded_llava.py
  18. 2 1
      exo/inference/mlx/test_sharded_model.py
  19. 2 4
      exo/inference/shard.py
  20. 13 11
      exo/inference/test_inference_engine.py
  21. 8 11
      exo/inference/tinygrad/inference.py
  22. 73 34
      exo/inference/tinygrad/models/llama.py
  23. 7 2
      exo/inference/tinygrad/tinygrad_helpers.py
  24. 1 0
      exo/inference/tokenizers.py
  25. 25 32
      exo/models.py
  26. 1 0
      exo/networking/discovery.py
  27. 11 11
      exo/networking/grpc/grpc_discovery.py
  28. 1 0
      exo/networking/grpc/grpc_peer_handle.py
  29. 9 9
      exo/networking/grpc/grpc_server.py
  30. 0 3
      exo/networking/grpc/node_service_pb2.py
  31. 229 271
      exo/networking/grpc/node_service_pb2_grpc.py
  32. 1 0
      exo/networking/grpc/test_grpc_discovery.py
  33. 1 0
      exo/networking/peer_handle.py
  34. 1 0
      exo/networking/server.py
  35. 1 0
      exo/orchestration/node.py
  36. 3 0
      exo/orchestration/standard_node.py
  37. 1 0
      exo/orchestration/test_node.py
  38. 1 0
      exo/test_callbacks.py
  39. 1 0
      exo/topology/partitioning_strategy.py
  40. 1 0
      exo/topology/ring_memory_weighted_partitioning_strategy.py
  41. 1 0
      exo/topology/test_device_capabilities.py
  42. 1 0
      exo/topology/test_map_partitions.py
  43. 1 0
      exo/topology/test_ring_memory_weighted_partitioning_strategy.py
  44. 1 0
      exo/topology/topology.py
  45. 49 47
      exo/viz/test_topology_viz.py
  46. 64 66
      exo/viz/topology_viz.py
  47. 35 37
      extra/download_hf.py
  48. 113 98
      main.py

+ 38 - 45
examples/llama3_distributed.py

@@ -13,8 +13,8 @@ import argparse
 import uuid
 
 models = {
-    "mlx-community/Meta-Llama-3-8B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "mlx-community/Meta-Llama-3-70B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80)
+  "mlx-community/Meta-Llama-3-8B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+  "mlx-community/Meta-Llama-3-70B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80)
 }
 
 path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
@@ -29,60 +29,53 @@ tokenizer = load_tokenizer(model_path, tokenizer_config)
 #     "localhost:8080",
 #     DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
 # )
-peer2 = GRPCPeerHandle(
-    "node2",
-    "localhost:8081",
-    DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
-)
+peer2 = GRPCPeerHandle("node2", "localhost:8081", DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
 shard = models[path_or_hf_repo]
 request_id = str(uuid.uuid4())
 
+
 async def run_prompt(prompt: str):
-    if tokenizer.chat_template is None:
-        tokenizer.chat_template = tokenizer.default_chat_template
-    if (
-        hasattr(tokenizer, "apply_chat_template")
-        and tokenizer.chat_template is not None
-    ):
-        messages = [{"role": "user", "content": prompt}]
-        prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
+  if tokenizer.chat_template is None:
+    tokenizer.chat_template = tokenizer.default_chat_template
+  if (hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None):
+    messages = [{"role": "user", "content": prompt}]
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+  await peer2.connect()
 
-    await peer2.connect()
+  try:
+    await peer2.send_prompt(shard, prompt, request_id)
+  except Exception as e:
+    print(e)
 
+  import time
+  # poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user)
+  previous_length = 0
+  n_tokens = 0
+  start_time = time.perf_counter()
+  while True:
     try:
-        await peer2.send_prompt(shard, prompt, request_id)
+      result, is_finished = await peer2.get_inference_result(request_id)
     except Exception as e:
-        print(e)
+      continue
+    await asyncio.sleep(0.1)
 
-    import time
-    # poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user)
-    previous_length = 0
-    n_tokens = 0
-    start_time = time.perf_counter()
-    while True:
-        try:
-            result, is_finished = await peer2.get_inference_result(request_id)
-        except Exception as e:
-            continue
-        await asyncio.sleep(0.1)
+    # Print the updated string in place
+    updated_string = tokenizer.decode(result)
+    n_tokens = len(result)
+    print(updated_string[previous_length:], end='', flush=True)
+    previous_length = len(updated_string)
 
-        # Print the updated string in place
-        updated_string = tokenizer.decode(result)
-        n_tokens = len(result)
-        print(updated_string[previous_length:], end='', flush=True)
-        previous_length = len(updated_string)
+    if is_finished:
+      print("\nDone")
+      break
+  end_time = time.perf_counter()
+  print(f"\nDone. Processed {n_tokens} tokens in {end_time - start_time:.2f} seconds ({n_tokens / (end_time - start_time):.2f} tokens/second)")
 
-        if is_finished:
-            print("\nDone")
-            break
-    end_time = time.perf_counter()
-    print(f"\nDone. Processed {n_tokens} tokens in {end_time - start_time:.2f} seconds ({n_tokens / (end_time - start_time):.2f} tokens/second)")
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run prompt")
-    parser.add_argument("--prompt", type=str, help="The prompt to run")
-    args = parser.parse_args()
+  parser = argparse.ArgumentParser(description="Run prompt")
+  parser.add_argument("--prompt", type=str, help="The prompt to run")
+  args = parser.parse_args()
 
-    asyncio.run(run_prompt(args.prompt))
+  asyncio.run(run_prompt(args.prompt))

+ 62 - 59
exo/api/chatgpt_api.py

@@ -16,29 +16,27 @@ from exo.orchestration import Node
 from exo.models import model_base_shards
 from typing import Callable
 
+
 class Message:
-    def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
-        self.role = role
-        self.content = content
 
-    def to_dict(self):
-        return {
-            "role": self.role,
-            "content": self.content
-        }
+  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
+    self.role = role
+    self.content = content
+
+  def to_dict(self):
+    return {"role": self.role, "content": self.content}
+
 
 class ChatCompletionRequest:
-    def __init__(self, model: str, messages: List[Message], temperature: float):
-        self.model = model
-        self.messages = messages
-        self.temperature = temperature
-
-    def to_dict(self):
-        return {
-            "model": self.model,
-            "messages": [message.to_dict() for message in self.messages],
-            "temperature": self.temperature
-        }
+
+  def __init__(self, model: str, messages: List[Message], temperature: float):
+    self.model = model
+    self.messages = messages
+    self.temperature = temperature
+
+  def to_dict(self):
+    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
+
 
 def generate_completion(
   chat_request: ChatCompletionRequest,
@@ -56,14 +54,12 @@ def generate_completion(
     "created": int(time.time()),
     "model": chat_request.model,
     "system_fingerprint": f"exo_{VERSION}",
-    "choices": [
-      {
-        "index": 0,
-        "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
-        "logprobs": None,
-        "finish_reason": finish_reason,
-      }
-    ],
+    "choices": [{
+      "index": 0,
+      "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
+      "logprobs": None,
+      "finish_reason": finish_reason,
+    }],
   }
 
   if not stream:
@@ -86,37 +82,38 @@ def generate_completion(
 
 
 def remap_messages(messages: List[Message]) -> List[Message]:
-    remapped_messages = []
-    last_image = None
-    for message in messages:
-        if not isinstance(message.content, list):
-           remapped_messages.append(message)
-           continue
-
-        remapped_content = []
-        for content in message.content:
-            if isinstance(content, dict):
-                if content.get("type") in ["image_url", "image"]:
-                    image_url = content.get("image_url", {}).get("url") or content.get("image")
-                    if image_url:
-                        last_image = {"type": "image", "image": image_url}
-                        remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
-                else:
-                    remapped_content.append(content)
-            else:
-                remapped_content.append(content)
-        remapped_messages.append(Message(role=message.role, content=remapped_content))
-
-    if last_image:
-        # Replace the last image placeholder with the actual image content
-        for message in reversed(remapped_messages):
-            for i, content in enumerate(message.content):
-                if isinstance(content, dict):
-                  if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
-                      message.content[i] = last_image
-                      return remapped_messages
-
-    return remapped_messages
+  remapped_messages = []
+  last_image = None
+  for message in messages:
+    if not isinstance(message.content, list):
+      remapped_messages.append(message)
+      continue
+
+    remapped_content = []
+    for content in message.content:
+      if isinstance(content, dict):
+        if content.get("type") in ["image_url", "image"]:
+          image_url = content.get("image_url", {}).get("url") or content.get("image")
+          if image_url:
+            last_image = {"type": "image", "image": image_url}
+            remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
+        else:
+          remapped_content.append(content)
+      else:
+        remapped_content.append(content)
+    remapped_messages.append(Message(role=message.role, content=remapped_content))
+
+  if last_image:
+    # Replace the last image placeholder with the actual image content
+    for message in reversed(remapped_messages):
+      for i, content in enumerate(message.content):
+        if isinstance(content, dict):
+          if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
+            message.content[i] = last_image
+            return remapped_messages
+
+  return remapped_messages
+
 
 def build_prompt(tokenizer, _messages: List[Message]):
   messages = remap_messages(_messages)
@@ -149,13 +146,17 @@ def parse_chat_request(data: dict):
     data.get("temperature", 0.0),
   )
 
+
 class PromptSession:
+
   def __init__(self, request_id: str, timestamp: int, prompt: str):
     self.request_id = request_id
     self.timestamp = timestamp
     self.prompt = prompt
 
+
 class ChatGPTAPI:
+
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
@@ -182,6 +183,7 @@ class ChatGPTAPI:
     self.app.middlewares.append(self.log_request)
 
   async def log_request(self, app, handler):
+
     async def middleware(request):
       if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
       return await handler(request)
@@ -268,7 +270,8 @@ class ChatGPTAPI:
           self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
           new_tokens = tokens[prev_last_tokens_len:]
           finish_reason = None
-          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
+          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
+                                                                                                                             AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
           if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
             new_tokens = new_tokens[:-1]
             if is_finished:

+ 50 - 64
exo/download/download_progress.py

@@ -2,81 +2,67 @@ from typing import Dict, Callable, Coroutine, Any, Literal
 from dataclasses import dataclass
 from datetime import timedelta
 
+
 @dataclass
 class RepoFileProgressEvent:
-    repo_id: str
-    repo_revision: str
-    file_path: str
-    downloaded: int
-    downloaded_this_session: int
-    total: int
-    speed: int
-    eta: timedelta
-    status: Literal["not_started", "in_progress", "complete"]
+  repo_id: str
+  repo_revision: str
+  file_path: str
+  downloaded: int
+  downloaded_this_session: int
+  total: int
+  speed: int
+  eta: timedelta
+  status: Literal["not_started", "in_progress", "complete"]
+
+  def to_dict(self):
+    return {
+      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
+      "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
+    }
 
-    def to_dict(self):
-        return {
-            "repo_id": self.repo_id,
-            "repo_revision": self.repo_revision,
-            "file_path": self.file_path,
-            "downloaded": self.downloaded,
-            "downloaded_this_session": self.downloaded_this_session,
-            "total": self.total,
-            "speed": self.speed,
-            "eta": self.eta.total_seconds(),
-            "status": self.status
-        }
+  @classmethod
+  def from_dict(cls, data):
+    # Convert eta from seconds back to timedelta
+    if 'eta' in data:
+      data['eta'] = timedelta(seconds=data['eta'])
+    return cls(**data)
 
-    @classmethod
-    def from_dict(cls, data):
-        # Convert eta from seconds back to timedelta
-        if 'eta' in data:
-            data['eta'] = timedelta(seconds=data['eta'])
-        return cls(**data)
 
 @dataclass
 class RepoProgressEvent:
-    repo_id: str
-    repo_revision: str
-    completed_files: int
-    total_files: int
-    downloaded_bytes: int
-    downloaded_bytes_this_session: int
-    total_bytes: int
-    overall_speed: int
-    overall_eta: timedelta
-    file_progress: Dict[str, RepoFileProgressEvent]
-    status: Literal["not_started", "in_progress", "complete"]
+  repo_id: str
+  repo_revision: str
+  completed_files: int
+  total_files: int
+  downloaded_bytes: int
+  downloaded_bytes_this_session: int
+  total_bytes: int
+  overall_speed: int
+  overall_eta: timedelta
+  file_progress: Dict[str, RepoFileProgressEvent]
+  status: Literal["not_started", "in_progress", "complete"]
+
+  def to_dict(self):
+    return {
+      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
+      "downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
+      "file_progress": {k: v.to_dict()
+                        for k, v in self.file_progress.items()}, "status": self.status
+    }
 
-    def to_dict(self):
-        return {
-            "repo_id": self.repo_id,
-            "repo_revision": self.repo_revision,
-            "completed_files": self.completed_files,
-            "total_files": self.total_files,
-            "downloaded_bytes": self.downloaded_bytes,
-            "downloaded_bytes_this_session": self.downloaded_bytes_this_session,
-            "total_bytes": self.total_bytes,
-            "overall_speed": self.overall_speed,
-            "overall_eta": self.overall_eta.total_seconds(),
-            "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
-            "status": self.status
-        }
+  @classmethod
+  def from_dict(cls, data):
+    # Convert overall_eta from seconds back to timedelta
+    if 'overall_eta' in data:
+      data['overall_eta'] = timedelta(seconds=data['overall_eta'])
 
-    @classmethod
-    def from_dict(cls, data):
-        # Convert overall_eta from seconds back to timedelta
-        if 'overall_eta' in data:
-            data['overall_eta'] = timedelta(seconds=data['overall_eta'])
+    # Parse file_progress
+    if 'file_progress' in data:
+      data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}
 
-        # Parse file_progress
-        if 'file_progress' in data:
-            data['file_progress'] = {
-                k: RepoFileProgressEvent.from_dict(v)
-                for k, v in data['file_progress'].items()
-            }
+    return cls(**data)
 
-        return cls(**data)
 
 RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]]
 RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]]

+ 334 - 297
exo/download/hf/hf_helpers.py

@@ -16,282 +16,322 @@ import aiofiles
 from aiofiles import os as aios
 
 T = TypeVar("T")
+
+
 def filter_repo_objects(
-    items: Iterable[T],
-    *,
-    allow_patterns: Optional[Union[List[str], str]] = None,
-    ignore_patterns: Optional[Union[List[str], str]] = None,
-    key: Optional[Callable[[T], str]] = None,
+  items: Iterable[T],
+  *,
+  allow_patterns: Optional[Union[List[str], str]] = None,
+  ignore_patterns: Optional[Union[List[str], str]] = None,
+  key: Optional[Callable[[T], str]] = None,
 ) -> Generator[T, None, None]:
-    if isinstance(allow_patterns, str):
-        allow_patterns = [allow_patterns]
-    if isinstance(ignore_patterns, str):
-        ignore_patterns = [ignore_patterns]
-    if allow_patterns is not None:
-        allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
-    if ignore_patterns is not None:
-        ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
-
-    if key is None:
-        def _identity(item: T) -> str:
-            if isinstance(item, str):
-                return item
-            if isinstance(item, Path):
-                return str(item)
-            raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
-        key = _identity
-
-    for item in items:
-        path = key(item)
-        if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
-            continue
-        if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
-            continue
-        yield item
+  if isinstance(allow_patterns, str):
+    allow_patterns = [allow_patterns]
+  if isinstance(ignore_patterns, str):
+    ignore_patterns = [ignore_patterns]
+  if allow_patterns is not None:
+    allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
+  if ignore_patterns is not None:
+    ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
+
+  if key is None:
+
+    def _identity(item: T) -> str:
+      if isinstance(item, str):
+        return item
+      if isinstance(item, Path):
+        return str(item)
+      raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
+
+    key = _identity
+
+  for item in items:
+    path = key(item)
+    if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
+      continue
+    if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
+      continue
+    yield item
+
 
 def _add_wildcard_to_directories(pattern: str) -> str:
-    if pattern[-1] == "/":
-        return pattern + "*"
-    return pattern
+  if pattern[-1] == "/":
+    return pattern + "*"
+  return pattern
+
 
 def get_hf_home() -> Path:
-    """Get the Hugging Face home directory."""
-    return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
+  """Get the Hugging Face home directory."""
+  return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
+
 
 async def get_hf_token():
-    """Retrieve the Hugging Face token from the user's HF_HOME directory."""
-    token_path = get_hf_home() / "token"
-    if await aios.path.exists(token_path):
-        async with aiofiles.open(token_path, 'r') as f:
-            return (await f.read()).strip()
-    return None
+  """Retrieve the Hugging Face token from the user's HF_HOME directory."""
+  token_path = get_hf_home() / "token"
+  if await aios.path.exists(token_path):
+    async with aiofiles.open(token_path, 'r') as f:
+      return (await f.read()).strip()
+  return None
+
 
 async def get_auth_headers():
-    """Get authentication headers if a token is available."""
-    token = await get_hf_token()
-    if token:
-        return {"Authorization": f"Bearer {token}"}
-    return {}
+  """Get authentication headers if a token is available."""
+  token = await get_hf_token()
+  if token:
+    return {"Authorization": f"Bearer {token}"}
+  return {}
+
 
 def get_repo_root(repo_id: str) -> Path:
-    """Get the root directory for a given repo ID in the Hugging Face cache."""
-    sanitized_repo_id = repo_id.replace("/", "--")
-    return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
+  """Get the root directory for a given repo ID in the Hugging Face cache."""
+  sanitized_repo_id = repo_id.replace("/", "--")
+  return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
+
 
 async def fetch_file_list(session, repo_id, revision, path=""):
-    api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
-    url = f"{api_url}/{path}" if path else api_url
-
-    headers = await get_auth_headers()
-    async with session.get(url, headers=headers) as response:
-        if response.status == 200:
-            data = await response.json()
-            files = []
-            for item in data:
-                if item["type"] == "file":
-                    files.append({"path": item["path"], "size": item["size"]})
-                elif item["type"] == "directory":
-                    subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
-                    files.extend(subfiles)
-            return files
-        else:
-            raise Exception(f"Failed to fetch file list: {response.status}")
+  api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
+  url = f"{api_url}/{path}" if path else api_url
+
+  headers = await get_auth_headers()
+  async with session.get(url, headers=headers) as response:
+    if response.status == 200:
+      data = await response.json()
+      files = []
+      for item in data:
+        if item["type"] == "file":
+          files.append({"path": item["path"], "size": item["size"]})
+        elif item["type"] == "directory":
+          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
+          files.extend(subfiles)
+      return files
+    else:
+      raise Exception(f"Failed to fetch file list: {response.status}")
 
 
 @retry(
-    stop=stop_after_attempt(5),
-    wait=wait_exponential(multiplier=1, min=4, max=60),
-    retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
-    reraise=True
+  stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
 )
-async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True):
-    base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
-    url = urljoin(base_url, file_path)
-    local_path = os.path.join(save_directory, file_path)
-
-    await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
-
-    # Check if file already exists and get its size
-    local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
-
-    headers = await get_auth_headers()
-    if use_range_request:
-        headers["Range"] = f"bytes={local_file_size}-"
-
-    async with session.get(url, headers=headers) as response:
-        total_size = int(response.headers.get('Content-Length', 0))
-        downloaded_size = local_file_size
-        downloaded_this_session = 0
-        mode = 'ab' if use_range_request else 'wb'
+async def download_file(
+  session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
+):
+  base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
+  url = urljoin(base_url, file_path)
+  local_path = os.path.join(save_directory, file_path)
+
+  await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
+
+  # Check if file already exists and get its size
+  local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
+
+  headers = await get_auth_headers()
+  if use_range_request:
+    headers["Range"] = f"bytes={local_file_size}-"
+
+  async with session.get(url, headers=headers) as response:
+    total_size = int(response.headers.get('Content-Length', 0))
+    downloaded_size = local_file_size
+    downloaded_this_session = 0
+    mode = 'ab' if use_range_request else 'wb'
+    if downloaded_size == total_size:
+      if DEBUG >= 2: print(f"File already downloaded: {file_path}")
+      if progress_callback:
+        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+      return
+
+    if response.status == 200:
+      # File doesn't support range requests or we're not using them, start from beginning
+      mode = 'wb'
+      downloaded_size = 0
+    elif response.status == 206:
+      # Partial content, resume download
+      content_range = response.headers.get('Content-Range', '')
+      try:
+        total_size = int(content_range.split('/')[-1])
+      except ValueError:
+        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
+    elif response.status == 416:
+      # Range not satisfiable, get the actual file size
+      content_range = response.headers.get('Content-Range', '')
+      try:
+        total_size = int(content_range.split('/')[-1])
         if downloaded_size == total_size:
-            if DEBUG >= 2: print(f"File already downloaded: {file_path}")
-            if progress_callback:
-                await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-            return
-
-        if response.status == 200:
-            # File doesn't support range requests or we're not using them, start from beginning
-            mode = 'wb'
-            downloaded_size = 0
-        elif response.status == 206:
-            # Partial content, resume download
-            content_range = response.headers.get('Content-Range', '')
-            try:
-                total_size = int(content_range.split('/')[-1])
-            except ValueError:
-                if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-                return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-        elif response.status == 416:
-            # Range not satisfiable, get the actual file size
-            content_range = response.headers.get('Content-Range', '')
-            try:
-                total_size = int(content_range.split('/')[-1])
-                if downloaded_size == total_size:
-                    if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
-                    if progress_callback:
-                        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-                    return
-            except ValueError:
-                if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-                return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-        else:
-            raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
-
-        if downloaded_size == total_size:
-            print(f"File already downloaded: {file_path}")
-            if progress_callback:
-                await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-            return
-
-        DOWNLOAD_CHUNK_SIZE = 32768
-        start_time = datetime.now()
-        async with aiofiles.open(local_path, mode) as f:
-            async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
-                await f.write(chunk)
-                downloaded_size += len(chunk)
-                downloaded_this_session += len(chunk)
-                if progress_callback and total_size:
-                    elapsed_time = (datetime.now() - start_time).total_seconds()
-                    speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
-                    remaining_size = total_size - downloaded_size
-                    eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
-                    status = "in_progress" if downloaded_size < total_size else "complete"
-                    if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
-                    await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
-        if DEBUG >= 2: print(f"Downloaded: {file_path}")
-
-async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, max_parallel_downloads: int = 4) -> Path:
-    repo_root = get_repo_root(repo_id)
-    refs_dir = repo_root / "refs"
-    snapshots_dir = repo_root / "snapshots"
-    cachedreqs_dir = repo_root / "cachedreqs"
-
-    # Ensure directories exist
-    await aios.makedirs(refs_dir, exist_ok=True)
-    await aios.makedirs(snapshots_dir, exist_ok=True)
-    await aios.makedirs(cachedreqs_dir, exist_ok=True)
-
-    # Check if we have a cached commit hash
-    refs_file = refs_dir / revision
-    if await aios.path.exists(refs_file):
-        async with aiofiles.open(refs_file, 'r') as f:
-            commit_hash = (await f.read()).strip()
-            if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}")
+          if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
+          if progress_callback:
+            await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+          return
+      except ValueError:
+        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
     else:
-        async with aiohttp.ClientSession() as session:
-            # Fetch the commit hash for the given revision
-            api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
-            headers = await get_auth_headers()
-            async with session.get(api_url, headers=headers) as response:
-                if response.status != 200:
-                    raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
-                revision_info = await response.json()
-                commit_hash = revision_info['sha']
-
-            # Cache the commit hash
-            async with aiofiles.open(refs_file, 'w') as f:
-                await f.write(commit_hash)
-
-    # Set up the snapshot directory
-    snapshot_dir = snapshots_dir / commit_hash
-    await aios.makedirs(snapshot_dir, exist_ok=True)
-
-    # Set up the cached file list directory
-    cached_file_list_dir = cachedreqs_dir / commit_hash
-    await aios.makedirs(cached_file_list_dir, exist_ok=True)
-    cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
-
+      raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
+
+    if downloaded_size == total_size:
+      print(f"File already downloaded: {file_path}")
+      if progress_callback:
+        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+      return
+
+    DOWNLOAD_CHUNK_SIZE = 32768
+    start_time = datetime.now()
+    async with aiofiles.open(local_path, mode) as f:
+      async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
+        await f.write(chunk)
+        downloaded_size += len(chunk)
+        downloaded_this_session += len(chunk)
+        if progress_callback and total_size:
+          elapsed_time = (datetime.now() - start_time).total_seconds()
+          speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
+          remaining_size = total_size - downloaded_size
+          eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
+          status = "in_progress" if downloaded_size < total_size else "complete"
+          if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
+          await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
+    if DEBUG >= 2: print(f"Downloaded: {file_path}")
+
+
+async def download_repo_files(
+  repo_id: str,
+  revision: str = "main",
+  progress_callback: Optional[RepoProgressCallback] = None,
+  allow_patterns: Optional[Union[List[str], str]] = None,
+  ignore_patterns: Optional[Union[List[str], str]] = None,
+  max_parallel_downloads: int = 4
+) -> Path:
+  repo_root = get_repo_root(repo_id)
+  refs_dir = repo_root / "refs"
+  snapshots_dir = repo_root / "snapshots"
+  cachedreqs_dir = repo_root / "cachedreqs"
+
+  # Ensure directories exist
+  await aios.makedirs(refs_dir, exist_ok=True)
+  await aios.makedirs(snapshots_dir, exist_ok=True)
+  await aios.makedirs(cachedreqs_dir, exist_ok=True)
+
+  # Check if we have a cached commit hash
+  refs_file = refs_dir / revision
+  if await aios.path.exists(refs_file):
+    async with aiofiles.open(refs_file, 'r') as f:
+      commit_hash = (await f.read()).strip()
+      if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}")
+  else:
     async with aiohttp.ClientSession() as session:
-        # Check if we have a cached file list
-        if await aios.path.exists(cached_file_list_path):
-            async with aiofiles.open(cached_file_list_path, 'r') as f:
-                file_list = json.loads(await f.read())
-            if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
-        else:
-            file_list = await fetch_file_list(session, repo_id, revision)
-            # Cache the file list
-            async with aiofiles.open(cached_file_list_path, 'w') as f:
-                await f.write(json.dumps(file_list))
-            if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
-
-        filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
-        total_files = len(filtered_file_list)
-        total_bytes = sum(file["size"] for file in filtered_file_list)
-        file_progress: Dict[str, RepoFileProgressEvent] = {file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") for file in filtered_file_list}
-        start_time = datetime.now()
-
-        async def download_with_progress(file_info, progress_state):
-            local_path = snapshot_dir / file_info["path"]
-            if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
-                if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
-                progress_state['completed_files'] += 1
-                progress_state['downloaded_bytes'] += file_info["size"]
-                file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
-                if progress_callback:
-                    elapsed_time = (datetime.now() - start_time).total_seconds()
-                    overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
-                    remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-                    overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-                    status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-                    await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
-                return
-
-            async def file_progress_callback(event: RepoFileProgressEvent):
-                progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
-                progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
-                file_progress[event.file_path] = event
-                if progress_callback:
-                    elapsed_time = (datetime.now() - start_time).total_seconds()
-                    overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
-                    remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-                    overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-                    status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
-                    await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
-
-            await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
-            progress_state['completed_files'] += 1
-            file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
-            if progress_callback:
-                elapsed_time = (datetime.now() - start_time).total_seconds()
-                overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
-                remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-                overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-                status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-                await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
-
-        progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
-
-        semaphore = asyncio.Semaphore(max_parallel_downloads)
-        async def download_with_semaphore(file_info):
-            async with semaphore:
-                await download_with_progress(file_info, progress_state)
-        tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
-        await asyncio.gather(*tasks)
-
-    return snapshot_dir
+      # Fetch the commit hash for the given revision
+      api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
+      headers = await get_auth_headers()
+      async with session.get(api_url, headers=headers) as response:
+        if response.status != 200:
+          raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
+        revision_info = await response.json()
+        commit_hash = revision_info['sha']
+
+      # Cache the commit hash
+      async with aiofiles.open(refs_file, 'w') as f:
+        await f.write(commit_hash)
+
+  # Set up the snapshot directory
+  snapshot_dir = snapshots_dir / commit_hash
+  await aios.makedirs(snapshot_dir, exist_ok=True)
+
+  # Set up the cached file list directory
+  cached_file_list_dir = cachedreqs_dir / commit_hash
+  await aios.makedirs(cached_file_list_dir, exist_ok=True)
+  cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
+
+  async with aiohttp.ClientSession() as session:
+    # Check if we have a cached file list
+    if await aios.path.exists(cached_file_list_path):
+      async with aiofiles.open(cached_file_list_path, 'r') as f:
+        file_list = json.loads(await f.read())
+      if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
+    else:
+      file_list = await fetch_file_list(session, repo_id, revision)
+      # Cache the file list
+      async with aiofiles.open(cached_file_list_path, 'w') as f:
+        await f.write(json.dumps(file_list))
+      if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
+
+    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
+    total_files = len(filtered_file_list)
+    total_bytes = sum(file["size"] for file in filtered_file_list)
+    file_progress: Dict[str, RepoFileProgressEvent] = {
+      file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
+      for file in filtered_file_list
+    }
+    start_time = datetime.now()
+
+    async def download_with_progress(file_info, progress_state):
+      local_path = snapshot_dir / file_info["path"]
+      if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
+        if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
+        progress_state['completed_files'] += 1
+        progress_state['downloaded_bytes'] += file_info["size"]
+        file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
+        if progress_callback:
+          elapsed_time = (datetime.now() - start_time).total_seconds()
+          overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+          overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+          status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+          await progress_callback(
+            RepoProgressEvent(
+              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
+              overall_eta, file_progress, status
+            )
+          )
+        return
+
+      async def file_progress_callback(event: RepoFileProgressEvent):
+        progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
+        progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
+        file_progress[event.file_path] = event
+        if progress_callback:
+          elapsed_time = (datetime.now() - start_time).total_seconds()
+          overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+          overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+          status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
+          await progress_callback(
+            RepoProgressEvent(
+              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
+              overall_eta, file_progress, status
+            )
+          )
+
+      await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
+      progress_state['completed_files'] += 1
+      file_progress[
+        file_info["path"]
+      ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
+      if progress_callback:
+        elapsed_time = (datetime.now() - start_time).total_seconds()
+        overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+        remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+        overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+        status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+        await progress_callback(
+          RepoProgressEvent(
+            repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
+            overall_eta, file_progress, status
+          )
+        )
+
+    progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
+
+    semaphore = asyncio.Semaphore(max_parallel_downloads)
+
+    async def download_with_semaphore(file_info):
+      async with semaphore:
+        await download_with_progress(file_info, progress_state)
+
+    tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
+    await asyncio.gather(*tasks)
+
+  return snapshot_dir
+
 
 async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
-    """
+  """
     Retrieve the weight map from the model.safetensors.index.json file.
 
     Args:
@@ -302,55 +342,52 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[
         Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
     """
 
-    # Download the index file
-    await download_repo_files(
-        repo_id=repo_id,
-        revision=revision,
-        allow_patterns="model.safetensors.index.json"
-    )
+  # Download the index file
+  await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
 
-    # Check if the file exists
-    repo_root = get_repo_root(repo_id)
-    snapshot_dir = repo_root / "snapshots"
-    index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
+  # Check if the file exists
+  repo_root = get_repo_root(repo_id)
+  snapshot_dir = repo_root / "snapshots"
+  index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
 
-    if index_file:
-        index_file_path = snapshot_dir / index_file
-        if await aios.path.exists(index_file_path):
-            async with aiofiles.open(index_file_path, 'r') as f:
-                index_data = json.loads(await f.read())
-            return index_data.get("weight_map")
+  if index_file:
+    index_file_path = snapshot_dir / index_file
+    if await aios.path.exists(index_file_path):
+      async with aiofiles.open(index_file_path, 'r') as f:
+        index_data = json.loads(await f.read())
+      return index_data.get("weight_map")
+
+  return None
 
-    return None
 
 def extract_layer_num(tensor_name: str) -> Optional[int]:
-    # This is a simple example and might need to be adjusted based on the actual naming convention
-    parts = tensor_name.split('.')
-    for part in parts:
-        if part.isdigit():
-            return int(part)
-    return None
+  # This is a simple example and might need to be adjusted based on the actual naming convention
+  parts = tensor_name.split('.')
+  for part in parts:
+    if part.isdigit():
+      return int(part)
+  return None
 
 
 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
-    default_patterns = [
-        "*.json",
-        "*.py",
-        "tokenizer.model",
-        "*.tiktoken",
-        "*.txt",
-    ]
-    shard_specific_patterns = []
-    if weight_map:
-        for tensor_name, filename in weight_map.items():
-            layer_num = extract_layer_num(tensor_name)
-            if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
-                shard_specific_patterns.append(filename)
-        sorted_file_names = sorted(weight_map.values())
-        if shard.is_first_layer():
-            shard_specific_patterns.append(sorted_file_names[0])
-        elif shard.is_last_layer():
-            shard_specific_patterns.append(sorted_file_names[-1])
-    else:
-        shard_specific_patterns = ["*.safetensors"]
-    return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates
+  default_patterns = [
+    "*.json",
+    "*.py",
+    "tokenizer.model",
+    "*.tiktoken",
+    "*.txt",
+  ]
+  shard_specific_patterns = []
+  if weight_map:
+    for tensor_name, filename in weight_map.items():
+      layer_num = extract_layer_num(tensor_name)
+      if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
+        shard_specific_patterns.append(filename)
+    sorted_file_names = sorted(weight_map.values())
+    if shard.is_first_layer():
+      shard_specific_patterns.append(sorted_file_names[0])
+    elif shard.is_last_layer():
+      shard_specific_patterns.append(sorted_file_names[-1])
+  else:
+    shard_specific_patterns = ["*.safetensors"]
+  return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates

+ 58 - 60
exo/download/hf/hf_shard_download.py

@@ -8,72 +8,70 @@ from exo.download.download_progress import RepoProgressEvent
 from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
 from exo.helpers import AsyncCallbackSystem, DEBUG
 
+
 class HFShardDownloader(ShardDownloader):
-    def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
-        self.quick_check = quick_check
-        self.max_parallel_downloads = max_parallel_downloads
-        self.active_downloads: Dict[Shard, asyncio.Task] = {}
-        self.completed_downloads: Dict[Shard, Path] = {}
-        self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
 
-    async def ensure_shard(self, shard: Shard) -> Path:
-        if shard in self.completed_downloads:
-            return self.completed_downloads[shard]
-        if self.quick_check:
-            repo_root = get_repo_root(shard.model_id)
-            snapshots_dir = repo_root / "snapshots"
-            if snapshots_dir.exists():
-                most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
-                return most_recent_dir
+  def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
+    self.quick_check = quick_check
+    self.max_parallel_downloads = max_parallel_downloads
+    self.active_downloads: Dict[Shard, asyncio.Task] = {}
+    self.completed_downloads: Dict[Shard, Path] = {}
+    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+
+  async def ensure_shard(self, shard: Shard) -> Path:
+    if shard in self.completed_downloads:
+      return self.completed_downloads[shard]
+    if self.quick_check:
+      repo_root = get_repo_root(shard.model_id)
+      snapshots_dir = repo_root / "snapshots"
+      if snapshots_dir.exists():
+        most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
+        return most_recent_dir
+
+    # If a download on this shard is already in progress, keep that one
+    for active_shard in self.active_downloads:
+      if active_shard == shard:
+        if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
+        return await self.active_downloads[shard]
 
-        # If a download on this shard is already in progress, keep that one
-        for active_shard in self.active_downloads:
-            if active_shard == shard:
-                if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
-                return await self.active_downloads[shard]
+    # Cancel any downloads for this model_id on a different shard
+    existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
+    for active_shard in existing_active_shards:
+      if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
+      task = self.active_downloads[active_shard]
+      task.cancel()
+      try:
+        await task
+      except asyncio.CancelledError:
+        pass  # This is expected when cancelling a task
+      except Exception as e:
+        if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
+        traceback.print_exc()
+    self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
 
-        # Cancel any downloads for this model_id on a different shard
-        existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
-        for active_shard in existing_active_shards:
-            if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
-            task = self.active_downloads[active_shard]
-            task.cancel()
-            try:
-                await task
-            except asyncio.CancelledError:
-                pass  # This is expected when cancelling a task
-            except Exception as e:
-                if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
-                traceback.print_exc()
-        self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
+    # Start new download
+    download_task = asyncio.create_task(self._download_shard(shard))
+    self.active_downloads[shard] = download_task
+    try:
+      path = await download_task
+      self.completed_downloads[shard] = path
+      return path
+    finally:
+      # Ensure the task is removed even if an exception occurs
+      print(f"Removing download task for {shard}: {shard in self.active_downloads}")
+      if shard in self.active_downloads:
+        self.active_downloads.pop(shard)
 
-        # Start new download
-        download_task = asyncio.create_task(self._download_shard(shard))
-        self.active_downloads[shard] = download_task
-        try:
-            path = await download_task
-            self.completed_downloads[shard] = path
-            return path
-        finally:
-            # Ensure the task is removed even if an exception occurs
-            print(f"Removing download task for {shard}: {shard in self.active_downloads}")
-            if shard in self.active_downloads:
-                self.active_downloads.pop(shard)
+  async def _download_shard(self, shard: Shard) -> Path:
 
-    async def _download_shard(self, shard: Shard) -> Path:
-        async def wrapped_progress_callback(event: RepoProgressEvent):
-            self._on_progress.trigger_all(shard, event)
+    async def wrapped_progress_callback(event: RepoProgressEvent):
+      self._on_progress.trigger_all(shard, event)
 
-        weight_map = await get_weight_map(shard.model_id)
-        allow_patterns = get_allow_patterns(weight_map, shard)
+    weight_map = await get_weight_map(shard.model_id)
+    allow_patterns = get_allow_patterns(weight_map, shard)
 
-        return await download_repo_files(
-            repo_id=shard.model_id,
-            progress_callback=wrapped_progress_callback,
-            allow_patterns=allow_patterns,
-            max_parallel_downloads=self.max_parallel_downloads
-        )
+    return await download_repo_files(repo_id=shard.model_id, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
 
-    @property
-    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
-        return self._on_progress
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self._on_progress

+ 10 - 8
exo/download/shard_download.py

@@ -5,10 +5,12 @@ from exo.inference.shard import Shard
 from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import AsyncCallbackSystem
 
+
 class ShardDownloader(ABC):
-    @abstractmethod
-    async def ensure_shard(self, shard: Shard) -> Path:
-        """
+
+  @abstractmethod
+  async def ensure_shard(self, shard: Shard) -> Path:
+    """
         Ensures that the shard is downloaded.
         Does not allow multiple overlapping downloads at once.
         If you try to download a Shard which overlaps a Shard that is already being downloaded,
@@ -17,9 +19,9 @@ class ShardDownloader(ABC):
         Args:
             shard (Shard): The shard to download.
         """
-        pass
+    pass
 
-    @property
-    @abstractmethod
-    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
-        pass
+  @property
+  @abstractmethod
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    pass

+ 85 - 74
exo/helpers.py

@@ -20,6 +20,7 @@ exo_text = r"""
  \___/_/\_\___/ 
     """
 
+
 def get_system_info():
   if psutil.MACOS:
     if platform.machine() == "arm64":
@@ -87,7 +88,10 @@ def terminal_link(uri, label=None):
 
 T = TypeVar("T")
 K = TypeVar("K")
+
+
 class AsyncCallback(Generic[T]):
+
   def __init__(self) -> None:
     self.condition: asyncio.Condition = asyncio.Condition()
     self.result: Optional[Tuple[T, ...]] = None
@@ -95,9 +99,7 @@ class AsyncCallback(Generic[T]):
 
   async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]:
     async with self.condition:
-      await asyncio.wait_for(
-        self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout
-      )
+      await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout)
       assert self.result is not None  # for type checking
       return self.result
 
@@ -116,6 +118,7 @@ class AsyncCallback(Generic[T]):
 
 
 class AsyncCallbackSystem(Generic[K, T]):
+
   def __init__(self) -> None:
     self.callbacks: Dict[K, AsyncCallback[T]] = {}
 
@@ -139,89 +142,97 @@ class AsyncCallbackSystem(Generic[K, T]):
 
 K = TypeVar('K', bound=str)
 V = TypeVar('V')
+
+
 class PrefixDict(Generic[K, V]):
-    def __init__(self):
-        self.items: Dict[K, V] = {}
 
-    def add(self, key: K, value: V) -> None:
-        self.items[key] = value
+  def __init__(self):
+    self.items: Dict[K, V] = {}
+
+  def add(self, key: K, value: V) -> None:
+    self.items[key] = value
 
-    def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
-        return [(key, value) for key, value in self.items.items() if argument.startswith(key)]
+  def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
+    return [(key, value) for key, value in self.items.items() if argument.startswith(key)]
 
-    def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
-        matches = self.find_prefix(argument)
-        if len(matches) == 0:
-            return None
+  def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
+    matches = self.find_prefix(argument)
+    if len(matches) == 0:
+      return None
+
+    return max(matches, key=lambda x: len(x[0]))
 
-        return max(matches, key=lambda x: len(x[0]))
 
 def is_valid_uuid(val):
-    try:
-        uuid.UUID(str(val))
-        return True
-    except ValueError:
-        return False
+  try:
+    uuid.UUID(str(val))
+    return True
+  except ValueError:
+    return False
+
 
 def get_or_create_node_id():
-    NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__))) / ".exo_node_id"
-    try:
-        if NODE_ID_FILE.is_file():
-            with open(NODE_ID_FILE, "r") as f:
-                stored_id = f.read().strip()
-            if is_valid_uuid(stored_id):
-                if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
-                return stored_id
-            else:
-                if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")
-
-        new_id = str(uuid.uuid4())
-        with open(NODE_ID_FILE, "w") as f:
-            f.write(new_id)
-
-        if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
-        return new_id
-    except IOError as e:
-        if DEBUG >= 2: print(f"IO error creating node_id: {e}")
-        return str(uuid.uuid4())
-    except Exception as e:
-        if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
-        return str(uuid.uuid4())
+  NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__))) / ".exo_node_id"
+  try:
+    if NODE_ID_FILE.is_file():
+      with open(NODE_ID_FILE, "r") as f:
+        stored_id = f.read().strip()
+      if is_valid_uuid(stored_id):
+        if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
+        return stored_id
+      else:
+        if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")
+
+    new_id = str(uuid.uuid4())
+    with open(NODE_ID_FILE, "w") as f:
+      f.write(new_id)
+
+    if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
+    return new_id
+  except IOError as e:
+    if DEBUG >= 2: print(f"IO error creating node_id: {e}")
+    return str(uuid.uuid4())
+  except Exception as e:
+    if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
+    return str(uuid.uuid4())
+
 
 def pretty_print_bytes(size_in_bytes: int) -> str:
-    if size_in_bytes < 1024:
-        return f"{size_in_bytes} B"
-    elif size_in_bytes < 1024 ** 2:
-        return f"{size_in_bytes / 1024:.2f} KB"
-    elif size_in_bytes < 1024 ** 3:
-        return f"{size_in_bytes / (1024 ** 2):.2f} MB"
-    elif size_in_bytes < 1024 ** 4:
-        return f"{size_in_bytes / (1024 ** 3):.2f} GB"
-    else:
-        return f"{size_in_bytes / (1024 ** 4):.2f} TB"
+  if size_in_bytes < 1024:
+    return f"{size_in_bytes} B"
+  elif size_in_bytes < 1024**2:
+    return f"{size_in_bytes / 1024:.2f} KB"
+  elif size_in_bytes < 1024**3:
+    return f"{size_in_bytes / (1024 ** 2):.2f} MB"
+  elif size_in_bytes < 1024**4:
+    return f"{size_in_bytes / (1024 ** 3):.2f} GB"
+  else:
+    return f"{size_in_bytes / (1024 ** 4):.2f} TB"
+
 
 def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
-    if bytes_per_second < 1024:
-        return f"{bytes_per_second} B/s"
-    elif bytes_per_second < 1024 ** 2:
-        return f"{bytes_per_second / 1024:.2f} KB/s"
-    elif bytes_per_second < 1024 ** 3:
-        return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
-    elif bytes_per_second < 1024 ** 4:
-        return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
-    else:
-        return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
+  if bytes_per_second < 1024:
+    return f"{bytes_per_second} B/s"
+  elif bytes_per_second < 1024**2:
+    return f"{bytes_per_second / 1024:.2f} KB/s"
+  elif bytes_per_second < 1024**3:
+    return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
+  elif bytes_per_second < 1024**4:
+    return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
+  else:
+    return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
+
 
 def get_all_ip_addresses():
-    try:
-      ip_addresses = []
-      for interface in netifaces.interfaces():
-        ifaddresses = netifaces.ifaddresses(interface)
-        if netifaces.AF_INET in ifaddresses:
-          for link in ifaddresses[netifaces.AF_INET]:
-            ip = link['addr']
-            ip_addresses.append(ip)
-      return list(set(ip_addresses))
-    except:
-      if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
-      return ["localhost"]
+  try:
+    ip_addresses = []
+    for interface in netifaces.interfaces():
+      ifaddresses = netifaces.ifaddresses(interface)
+      if netifaces.AF_INET in ifaddresses:
+        for link in ifaddresses[netifaces.AF_INET]:
+          ip = link['addr']
+          ip_addresses.append(ip)
+    return list(set(ip_addresses))
+  except:
+    if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
+    return ["localhost"]

+ 5 - 7
exo/inference/debug_inference_engine.py

@@ -52,10 +52,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(next_resp_full, resp4)
 
 
-asyncio.run(
-  test_inference_engine(
-    TinygradDynamicShardInferenceEngine(),
-    TinygradDynamicShardInferenceEngine(),
-    "llama3-8b-sfr",
-  )
-)
+asyncio.run(test_inference_engine(
+  TinygradDynamicShardInferenceEngine(),
+  TinygradDynamicShardInferenceEngine(),
+  "llama3-8b-sfr",
+))

+ 2 - 0
exo/inference/inference_engine.py

@@ -5,7 +5,9 @@ from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from .shard import Shard
 
+
 class InferenceEngine(ABC):
+
   @abstractmethod
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     pass

+ 1 - 0
exo/inference/mlx/models/base.py

@@ -5,5 +5,6 @@ from mlx_lm.models.base import KVCache
 
 
 class IdentityBlock(nn.Module):
+
   def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array:
     return x

+ 4 - 5
exo/inference/mlx/models/deepseek_v2.py

@@ -7,7 +7,7 @@ import mlx.nn as nn
 from mlx_lm.models.base import KVCache
 from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
 from .base import IdentityBlock
-from ...shard import Shard
+from exo.inference.shard import Shard
 
 
 @dataclass
@@ -24,6 +24,7 @@ class ModelArgs(ModelArgs):
 
 
 class DeepseekV2Model(nn.Module):
+
   def __init__(self, config: ModelArgs):
     super().__init__()
     self.args = config
@@ -70,6 +71,7 @@ class DeepseekV2Model(nn.Module):
 
 
 class Model(nn.Module):
+
   def __init__(self, config: ModelArgs):
     super().__init__()
     self.args = config
@@ -107,10 +109,7 @@ class Model(nn.Module):
         for k in ["weight", "scales", "biases"]:
           if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
             to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)]
-            shard_state_dict[
-              f"{prefix}.mlp.switch_mlp.{
-       m}.{k}"
-            ] = mx.stack(to_join)
+            shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
 
     return shard_state_dict
 

+ 5 - 3
exo/inference/mlx/models/llama.py

@@ -24,7 +24,9 @@ class ModelArgs(ModelArgs):
 
     self.shard = Shard(**self.shard)
 
+
 class LlamaModel(nn.Module):
+
   def __init__(self, args: ModelArgs):
     super().__init__()
     self.args = args
@@ -66,7 +68,9 @@ class LlamaModel(nn.Module):
       h = self.norm(h)
     return h
 
+
 class Model(nn.Module):
+
   def __init__(self, args: ModelArgs):
     super().__init__()
     self.args = args
@@ -116,9 +120,7 @@ class Model(nn.Module):
 
   @property
   def head_dim(self):
-    return (
-      self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
-    )
+    return (self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads)
 
   @property
   def n_kv_heads(self):

+ 521 - 555
exo/inference/mlx/models/llava.py

@@ -15,619 +15,585 @@ import numpy as np
 
 @dataclass
 class VisionConfig:
-    model_type: str
-    num_hidden_layers: int = 24
-    hidden_size: int = 1024
-    intermediate_size: int = 4096
-    num_attention_heads: int = 16
-    image_size: int = 336
-    patch_size: int = 14
-    projection_dim: int = 768
-    vocab_size: int = 32000
-    num_channels: int = 3
-    layer_norm_eps: float = 1e-5
-
-    @classmethod
-    def from_dict(cls, params):
-        return cls(
-            **{
-                k: v
-                for k, v in params.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
+  model_type: str
+  num_hidden_layers: int = 24
+  hidden_size: int = 1024
+  intermediate_size: int = 4096
+  num_attention_heads: int = 16
+  image_size: int = 336
+  patch_size: int = 14
+  projection_dim: int = 768
+  vocab_size: int = 32000
+  num_channels: int = 3
+  layer_norm_eps: float = 1e-5
+
+  @classmethod
+  def from_dict(cls, params):
+    return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
 
 
 class VisionAttention(nn.Module):
-    def __init__(
-            self,
-            dims: int,
-            num_heads: int,
-            query_input_dims: Optional[int] = None,
-            key_input_dims: Optional[int] = None,
-            value_input_dims: Optional[int] = None,
-            value_dims: Optional[int] = None,
-            value_output_dims: Optional[int] = None,
-            bias: bool = False,
-    ):
-        super().__init__()
-
-        if (dims % num_heads) != 0:
-            raise ValueError(
-                "The input feature dimensions should be divisible by the "
-                f"number of heads ({dims} % {num_heads}) != 0"
-            )
-
-        query_input_dims = query_input_dims or dims
-        key_input_dims = key_input_dims or dims
-        value_input_dims = value_input_dims or key_input_dims
-        value_dims = value_dims or dims
-        value_output_dims = value_output_dims or dims
-
-        self.num_heads = num_heads
-        self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
-        self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
-        self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
-        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
-
-    def __call__(self, queries, keys, values, mask=None):
-        queries = self.q_proj(queries)
-        keys = self.k_proj(keys)
-        values = self.v_proj(values)
-
-        num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-        scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys
-        if mask is not None:
-            scores = scores + mask.astype(scores.dtype)
-        scores = mx.softmax(scores, axis=-1)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.out_proj(values_hat)
+
+  def __init__(
+    self,
+    dims: int,
+    num_heads: int,
+    query_input_dims: Optional[int] = None,
+    key_input_dims: Optional[int] = None,
+    value_input_dims: Optional[int] = None,
+    value_dims: Optional[int] = None,
+    value_output_dims: Optional[int] = None,
+    bias: bool = False,
+  ):
+    super().__init__()
+
+    if (dims % num_heads) != 0:
+      raise ValueError("The input feature dimensions should be divisible by the "
+                       f"number of heads ({dims} % {num_heads}) != 0")
+
+    query_input_dims = query_input_dims or dims
+    key_input_dims = key_input_dims or dims
+    value_input_dims = value_input_dims or key_input_dims
+    value_dims = value_dims or dims
+    value_output_dims = value_output_dims or dims
+
+    self.num_heads = num_heads
+    self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+    self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+    self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+    self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+  def __call__(self, queries, keys, values, mask=None):
+    queries = self.q_proj(queries)
+    keys = self.k_proj(keys)
+    values = self.v_proj(values)
+
+    num_heads = self.num_heads
+    B, L, D = queries.shape
+    _, S, _ = keys.shape
+    queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+    keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
+    values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+    scale = math.sqrt(1 / queries.shape[-1])
+    scores = (queries * scale) @ keys
+    if mask is not None:
+      scores = scores + mask.astype(scores.dtype)
+    scores = mx.softmax(scores, axis=-1)
+    values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+    return self.out_proj(values_hat)
 
 
 class VisionMLP(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.activation_fn = nn.GELU(approx="fast")
-        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
 
-    def __call__(self, x: mx.array) -> mx.array:
-        x = self.activation_fn(self.fc1(x))
-        x = self.fc2(x)
-        return x
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.activation_fn = nn.GELU(approx="fast")
+    self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+    self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+  def __call__(self, x: mx.array) -> mx.array:
+    x = self.activation_fn(self.fc1(x))
+    x = self.fc2(x)
+    return x
 
 
 class VisionEncoderLayer(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.embed_dim = config.hidden_size
-        self.self_attn = VisionAttention(
-            config.hidden_size, config.num_attention_heads, bias=True
-        )
-        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-        self.mlp = VisionMLP(config)
-        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-
-    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
-        y = self.layer_norm1(x)
-        y = self.self_attn(y, y, y, mask)
-        x = x + y
-        y = self.layer_norm2(x)
-        y = self.mlp(y)
-        return x + y
+
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.embed_dim = config.hidden_size
+    self.self_attn = VisionAttention(config.hidden_size, config.num_attention_heads, bias=True)
+    self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+    self.mlp = VisionMLP(config)
+    self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+  def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+    y = self.layer_norm1(x)
+    y = self.self_attn(y, y, y, mask)
+    x = x + y
+    y = self.layer_norm2(x)
+    y = self.mlp(y)
+    return x + y
 
 
 class VisionEncoder(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
 
 
 class VisionEmbeddings(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.config = config
-        self.embed_dim = config.hidden_size
-        self.image_size = config.image_size
-        self.patch_size = config.patch_size
-
-        self.class_embedding = mx.zeros((config.hidden_size,))
-
-        self.patch_embedding = nn.Conv2d(
-            in_channels=config.num_channels,
-            out_channels=self.embed_dim,
-            kernel_size=self.patch_size,
-            stride=self.patch_size,
-            bias=False,
-        )
-
-        self.num_patches = (self.image_size // self.patch_size) ** 2
-        self.num_positions = self.num_patches + 1
-        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-
-    def __call__(self, x: mx.array) -> mx.array:
-        batch_size = x.shape[0]
-        patch_embeddings = self.patch_embedding(x)
-        patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
-        embed_dim = patch_embeddings.shape[-1]
-        cls_embeddings = mx.broadcast_to(
-            self.class_embedding, (batch_size, 1, embed_dim)
-        )
-        embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
-        embeddings += self.position_embedding.weight
-        return embeddings
+
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.config = config
+    self.embed_dim = config.hidden_size
+    self.image_size = config.image_size
+    self.patch_size = config.patch_size
+
+    self.class_embedding = mx.zeros((config.hidden_size, ))
+
+    self.patch_embedding = nn.Conv2d(
+      in_channels=config.num_channels,
+      out_channels=self.embed_dim,
+      kernel_size=self.patch_size,
+      stride=self.patch_size,
+      bias=False,
+    )
+
+    self.num_patches = (self.image_size // self.patch_size)**2
+    self.num_positions = self.num_patches + 1
+    self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+  def __call__(self, x: mx.array) -> mx.array:
+    batch_size = x.shape[0]
+    patch_embeddings = self.patch_embedding(x)
+    patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+    embed_dim = patch_embeddings.shape[-1]
+    cls_embeddings = mx.broadcast_to(self.class_embedding, (batch_size, 1, embed_dim))
+    embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
+    embeddings += self.position_embedding.weight
+    return embeddings
 
 
 class ClipVisionModel(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.embeddings = VisionEmbeddings(config)
-        self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
-        self.encoder = VisionEncoder(config)
-        self.post_layernorm = nn.LayerNorm(config.hidden_size)
 
-    def __call__(
-        self,
-        x: mx.array,
-        output_hidden_states: Optional[bool] = None,
-    ) -> mx.array:
-        x = self.embeddings(x)
-        x = self.pre_layrnorm(x)
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.embeddings = VisionEmbeddings(config)
+    self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
+    self.encoder = VisionEncoder(config)
+    self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+  def __call__(
+    self,
+    x: mx.array,
+    output_hidden_states: Optional[bool] = None,
+  ) -> mx.array:
+    x = self.embeddings(x)
+    x = self.pre_layrnorm(x)
 
-        encoder_states = (x,) if output_hidden_states else None
+    encoder_states = (x, ) if output_hidden_states else None
 
-        for l in self.encoder.layers:
-            x = l(x, mask=None)
-            if output_hidden_states:
-                encoder_states = encoder_states + (x,)
+    for l in self.encoder.layers:
+      x = l(x, mask=None)
+      if output_hidden_states:
+        encoder_states = encoder_states + (x, )
 
-        pooler_output = self.post_layernorm(x[:, 0, :])
-        return pooler_output, x, encoder_states
+    pooler_output = self.post_layernorm(x[:, 0, :])
+    return pooler_output, x, encoder_states
 
 
 class VisionModel(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-
-        self.model_type = config.model_type
-        if self.model_type != "clip_vision_model":
-            raise ValueError(f"Unsupported model type: {self.model_type}")
-
-        self.vision_model = ClipVisionModel(config)
-
-    def __call__(
-            self, x: mx.array, output_hidden_states: Optional[bool] = None
-    ) -> mx.array:
-        return self.vision_model(x, output_hidden_states)
-
-    def sanitize(self, weights):
-        sanitized_weights = {}
-        for k, v in weights.items():
-            if "position_ids" in k:
-                # Remove unused position_ids
-                continue
-            elif "patch_embedding.weight" in k:
-                # PyTorch conv2d weight tensors have shape:
-                #   [out_channels, in_channels, kH, KW]
-                # MLX conv2d expects the weight be of shape:
-                #   [out_channels, kH, KW, in_channels]
-                sanitized_weights[k] = v.transpose(0, 2, 3, 1)
-            else:
-                sanitized_weights[k] = v
-
-        return sanitized_weights
+
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+
+    self.model_type = config.model_type
+    if self.model_type != "clip_vision_model":
+      raise ValueError(f"Unsupported model type: {self.model_type}")
+
+    self.vision_model = ClipVisionModel(config)
+
+  def __call__(self, x: mx.array, output_hidden_states: Optional[bool] = None) -> mx.array:
+    return self.vision_model(x, output_hidden_states)
+
+  def sanitize(self, weights):
+    sanitized_weights = {}
+    for k, v in weights.items():
+      if "position_ids" in k:
+        # Remove unused position_ids
+        continue
+      elif "patch_embedding.weight" in k:
+        # PyTorch conv2d weight tensors have shape:
+        #   [out_channels, in_channels, kH, KW]
+        # MLX conv2d expects the weight be of shape:
+        #   [out_channels, kH, KW, in_channels]
+        sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+      else:
+        sanitized_weights[k] = v
+
+    return sanitized_weights
 
 
 @dataclass
 class TextConfig:
-    model_type: str
-    hidden_size: int = 4096
-    num_hidden_layers: int = 32
-    intermediate_size: int = 11008
-    num_attention_heads: int = 32
-    head_dim: int = None
-    rms_norm_eps: float = 1e-6
-    vocab_size: int = 32000
-    num_key_value_heads: int = None
-    rope_theta: float = 10000
-    rope_traditional: bool = False
-    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-
-    @classmethod
-    def from_dict(cls, params):
-        return cls(
-            **{
-                k: v
-                for k, v in params.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
-
-    def __post_init__(self):
-        if self.num_key_value_heads is None:
-            self.num_key_value_heads = self.num_attention_heads
-
-        if self.head_dim is None:
-            self.head_dim = self.hidden_size // self.num_attention_heads
-
-        if self.model_type is None:
-            self.model_type = "llama"
-
-        if self.rope_scaling:
-            required_keys = {"factor", "type"}
-            if not all(key in self.rope_scaling for key in required_keys):
-                raise ValueError(f"rope_scaling must contain keys {required_keys}")
-
-            if self.rope_scaling["type"] != "linear":
-                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
+  model_type: str
+  hidden_size: int = 4096
+  num_hidden_layers: int = 32
+  intermediate_size: int = 11008
+  num_attention_heads: int = 32
+  head_dim: int = None
+  rms_norm_eps: float = 1e-6
+  vocab_size: int = 32000
+  num_key_value_heads: int = None
+  rope_theta: float = 10000
+  rope_traditional: bool = False
+  rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+
+  @classmethod
+  def from_dict(cls, params):
+    return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+  def __post_init__(self):
+    if self.num_key_value_heads is None:
+      self.num_key_value_heads = self.num_attention_heads
+
+    if self.head_dim is None:
+      self.head_dim = self.hidden_size // self.num_attention_heads
+
+    if self.model_type is None:
+      self.model_type = "llama"
+
+    if self.rope_scaling:
+      required_keys = {"factor", "type"}
+      if not all(key in self.rope_scaling for key in required_keys):
+        raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+      if self.rope_scaling["type"] != "linear":
+        raise ValueError("rope_scaling 'type' currently only supports 'linear'")
 
 
 class TextAttention(nn.Module):
-    def __init__(self, config: TextConfig):
-        super().__init__()
-
-        dim = config.hidden_size
-        self.n_heads = n_heads = config.num_attention_heads
-        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
-
-        self.repeats = n_heads // n_kv_heads
-
-        head_dim = config.hidden_size // n_heads
-        self.scale = head_dim ** -0.5
-
-        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
-        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
-        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
-        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
-
-        rope_scale = (
-            1 / config.rope_scaling["factor"]
-            if config.rope_scaling is not None
-               and config.rope_scaling["type"] == "linear"
-            else 1
-        )
-        self.rope = nn.RoPE(
-            head_dim,
-            traditional=config.rope_traditional,
-            base=config.rope_theta,
-            scale=rope_scale,
-        )
-
-    def __call__(
-            self,
-            x: mx.array,
-            mask: Optional[mx.array] = None,
-            cache: Optional[KVCache] = None,
-    ) -> mx.array:
-        B, L, D = x.shape
-
-        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-        # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-
-        if cache is not None:
-            queries = self.rope(queries, offset=cache.offset)
-            keys = self.rope(keys, offset=cache.offset)
-            keys, values = cache.update_and_fetch(keys, values)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
-
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
-        )
-        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output)
+
+  def __init__(self, config: TextConfig):
+    super().__init__()
+
+    dim = config.hidden_size
+    self.n_heads = n_heads = config.num_attention_heads
+    self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+
+    self.repeats = n_heads // n_kv_heads
+
+    head_dim = config.hidden_size // n_heads
+    self.scale = head_dim**-0.5
+
+    self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+    self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+    self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+    self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+    rope_scale = (1 / config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1)
+    self.rope = nn.RoPE(
+      head_dim,
+      traditional=config.rope_traditional,
+      base=config.rope_theta,
+      scale=rope_scale,
+    )
+
+  def __call__(
+    self,
+    x: mx.array,
+    mask: Optional[mx.array] = None,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    B, L, D = x.shape
+
+    queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+    # Prepare the queries, keys and values for the attention computation
+    queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+    keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+    values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+    if cache is not None:
+      queries = self.rope(queries, offset=cache.offset)
+      keys = self.rope(keys, offset=cache.offset)
+      keys, values = cache.update_and_fetch(keys, values)
+    else:
+      queries = self.rope(queries)
+      keys = self.rope(keys)
+
+    output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale, mask=mask)
+    output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+    return self.o_proj(output)
 
 
 class TextMLP(nn.Module):
-    def __init__(self, dim, hidden_dim):
-        super().__init__()
-        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
-        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
-        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
 
-    def __call__(self, x) -> mx.array:
-        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+  def __init__(self, dim, hidden_dim):
+    super().__init__()
+    self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+    self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+    self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+  def __call__(self, x) -> mx.array:
+    return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, config: TextConfig):
-        super().__init__()
-        self.num_attention_heads = config.num_attention_heads
-        self.hidden_size = config.hidden_size
-        self.self_attn = TextAttention(config)
-        self.mlp = TextMLP(config.hidden_size, config.intermediate_size)
-        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps
-        )
-        self.config = config
-
-    def __call__(
-            self,
-            x: mx.array,
-            mask: Optional[mx.array] = None,
-            cache: Optional[KVCache] = None,
-    ) -> mx.array:
-        r = self.self_attn(self.input_layernorm(x), mask, cache)
-        h = x + r
-        r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
-        return out
+
+  def __init__(self, config: TextConfig):
+    super().__init__()
+    self.num_attention_heads = config.num_attention_heads
+    self.hidden_size = config.hidden_size
+    self.self_attn = TextAttention(config)
+    self.mlp = TextMLP(config.hidden_size, config.intermediate_size)
+    self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+    self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+    self.config = config
+
+  def __call__(
+    self,
+    x: mx.array,
+    mask: Optional[mx.array] = None,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    r = self.self_attn(self.input_layernorm(x), mask, cache)
+    h = x + r
+    r = self.mlp(self.post_attention_layernorm(h))
+    out = h + r
+    return out
 
 
 class Llama(nn.Module):
-    def __init__(self, config: TextConfig, shard: Shard):
-        super().__init__()
-        self.config = config
-        self.shard = shard
-        self.vocab_size = config.vocab_size
-        self.model_type = config.model_type
-        self.num_hidden_layers = config.num_hidden_layers
-        self.num_key_value_heads = config.num_key_value_heads
-        self.head_dim = config.head_dim
-        assert self.vocab_size > 0
-        if self.shard.is_first_layer():
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.layers = []
-        for i in range(self.num_hidden_layers):
-          if self.shard.start_layer <= i <= self.shard.end_layer:
-            self.layers.append(TransformerBlock(config=config))
-          else:
-            self.layers.append(IdentityBlock())
-        if self.shard.is_last_layer():
-            self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-    def __call__(
-            self,
-            inputs: mx.array,
-            cache=None,
-            inputs_embeds=None,
-    ):
-        # for passing merged input embeddings
-        if inputs_embeds is None:
-            if self.shard.is_first_layer():
-                h = self.embed_tokens(inputs)
-            else:
-                h = inputs
-        else:
-            h = inputs_embeds
-
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
-
-        if cache is None:
-            cache = [None] * len(self.layers)
-
-
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, c)
-
-        if self.shard.is_last_layer():
-            h = self.norm(h)
-        return h
+
+  def __init__(self, config: TextConfig, shard: Shard):
+    super().__init__()
+    self.config = config
+    self.shard = shard
+    self.vocab_size = config.vocab_size
+    self.model_type = config.model_type
+    self.num_hidden_layers = config.num_hidden_layers
+    self.num_key_value_heads = config.num_key_value_heads
+    self.head_dim = config.head_dim
+    assert self.vocab_size > 0
+    if self.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.shard.start_layer <= i <= self.shard.end_layer:
+        self.layers.append(TransformerBlock(config=config))
+      else:
+        self.layers.append(IdentityBlock())
+    if self.shard.is_last_layer():
+      self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+    inputs_embeds=None,
+  ):
+    # for passing merged input embeddings
+    if inputs_embeds is None:
+      if self.shard.is_first_layer():
+        h = self.embed_tokens(inputs)
+      else:
+        h = inputs
+    else:
+      h = inputs_embeds
+
+    mask = None
+    if h.shape[1] > 1:
+      mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
+      mask = mask.astype(h.dtype)
+
+    if cache is None:
+      cache = [None] * len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, c)
+
+    if self.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
 
 class LanguageModel(nn.Module):
-    def __init__(self, config: TextConfig, shard: Shard):
-        super().__init__()
-        self.model_type = config.model_type
-        if self.model_type != "llama":
-            raise ValueError(
-                f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
-            )
-        self.shard = shard
-        self.model = Llama(config, shard)
-        if self.shard.is_last_layer():
-            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-    def __call__(
-        self,
-        inputs: mx.array,
-        cache=None,
-        inputs_embeds=None,
-    ):
-        out = self.model(inputs, cache, inputs_embeds)
-        if self.shard.is_last_layer():
-            out = self.lm_head(out)
-        return out
-
-    def sanitize(self, weights):
-        shard_state_dict = {}
-        for key, value in weights.items():
-            if "self_attn.rotary_emb.inv_freq" in key:
-                continue
-
-            if key.startswith('language_model.model.layers.'):
-                layer_num = int(key.split('.')[3])
-                if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer:
-                    continue
-            if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'):
-                continue
-            elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')):
-                continue
-
-            shard_state_dict[key] = value
-
-        return shard_state_dict
+
+  def __init__(self, config: TextConfig, shard: Shard):
+    super().__init__()
+    self.model_type = config.model_type
+    if self.model_type != "llama":
+      raise ValueError(f"Model type {self.model_type} not supported. Currently only 'llama' is supported")
+    self.shard = shard
+    self.model = Llama(config, shard)
+    if self.shard.is_last_layer():
+      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+    inputs_embeds=None,
+  ):
+    out = self.model(inputs, cache, inputs_embeds)
+    if self.shard.is_last_layer():
+      out = self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+
+      if key.startswith('language_model.model.layers.'):
+        layer_num = int(key.split('.')[3])
+        if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer:
+          continue
+      if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'):
+        continue
+      elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')):
+        continue
+
+      shard_state_dict[key] = value
+
+    return shard_state_dict
+
 
 @dataclass
 class LlaVAConfig(BaseModelArgs):
-    text_config: TextConfig
-    vision_config: VisionConfig = None
-    model_type: str = "llava"
-    ignore_index: int = -100
-    image_token_index: int = 32000
-    vision_feature_select_strategy: str = "default"
-    vision_feature_layer: int = -2
-    vocab_size: int = 32000
-
-    @classmethod
-    def from_dict(cls, params):
-        updated_params = {}
-        class_params = inspect.signature(cls).parameters
-        for k, v in params.items():
-            if k in class_params:
-                if k in ["text_config", "vision_config"]:
-                    v = class_params[k].annotation.from_dict(v)
-                updated_params.update({k: v})
-
-        return cls(**updated_params)
+  text_config: TextConfig
+  vision_config: VisionConfig = None
+  model_type: str = "llava"
+  ignore_index: int = -100
+  image_token_index: int = 32000
+  vision_feature_select_strategy: str = "default"
+  vision_feature_layer: int = -2
+  vocab_size: int = 32000
+
+  @classmethod
+  def from_dict(cls, params):
+    updated_params = {}
+    class_params = inspect.signature(cls).parameters
+    for k, v in params.items():
+      if k in class_params:
+        if k in ["text_config", "vision_config"]:
+          v = class_params[k].annotation.from_dict(v)
+        updated_params.update({k: v})
+
+    return cls(**updated_params)
 
 
 @dataclass
 class ModelArgs(LlaVAConfig):
-    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
 
-    def __post_init__(self):
-        if isinstance(self.shard, dict):
-            self.shard = Shard(**self.shard)
+  def __post_init__(self):
+    if isinstance(self.shard, dict):
+      self.shard = Shard(**self.shard)
 
-        if not isinstance(self.shard, Shard):
-            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+    if not isinstance(self.shard, Shard):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
 
-        if not self.shard.is_first_layer():
-            self.vision_config = None
+    if not self.shard.is_first_layer():
+      self.vision_config = None
 
 
 class LlavaMultiModalProjector(nn.Module):
-    def __init__(self, config: LlaVAConfig):
-        super().__init__()
-        self.linear_1 = nn.Linear(
-            config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
-        )
-        self.gelu = nn.GELU()
-        self.linear_2 = nn.Linear(
-            config.text_config.hidden_size, config.text_config.hidden_size, bias=True
-        )
-
-    def __call__(self, x: mx.array) -> mx.array:
-        x = self.linear_1(x)
-        x = self.gelu(x)
-        x = self.linear_2(x)
-        return x
+
+  def __init__(self, config: LlaVAConfig):
+    super().__init__()
+    self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
+    self.gelu = nn.GELU()
+    self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
+
+  def __call__(self, x: mx.array) -> mx.array:
+    x = self.linear_1(x)
+    x = self.gelu(x)
+    x = self.linear_2(x)
+    return x
 
 
 class Model(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        self.config = config
-        self.model_type = config.model_type
-        if config.vision_config:
-            self.vision_tower = VisionModel(config.vision_config)
-            self.multi_modal_projector = LlavaMultiModalProjector(config)
-            self.vision_feature_layer = config.vision_feature_layer
-            self.vision_feature_select_strategy = config.vision_feature_select_strategy
-        self.language_model = LanguageModel(config.text_config, config.shard)
-
-    def get_input_embeddings(
-            self,
-            input_ids: Optional[mx.array] = None,
-            pixel_values: Optional[mx.array] = None,
-    ):
-        if pixel_values is None:
-            return self.language_model(input_ids)
-
-        # Get the input embeddings from the language model
-        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-
-        # Get the ouptut hidden states from the vision model
-        *_, hidden_states = self.vision_tower(
-            pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
-        )
-
-        # Select the hidden states from the desired layer
-        selected_image_feature = hidden_states[self.vision_feature_layer]
-
-        if self.vision_feature_select_strategy == "default":
-            selected_image_feature = selected_image_feature[:, 1:]
-        elif self.vision_feature_select_strategy == "full":
-            selected_image_feature = selected_image_feature
-        else:
-            raise ValueError(
-                "Unexpected feature selection strategy: "
-                f"{self.vision_feature_select_strategy}"
-            )
-
-        # Pass image features through the multi-modal projector
-        image_features = self.multi_modal_projector(selected_image_feature)
-
-        # Insert special image tokens in the input_ids
-        final_inputs_embeds = self._merge_input_ids_with_image_features(
-            image_features, inputs_embeds, input_ids
-        )
-        return final_inputs_embeds
-
-    def _merge_input_ids_with_image_features(
-            self, image_features, inputs_embeds, input_ids
-    ):
-        image_token_index = self.config.image_token_index
-        num_images, num_image_patches, embed_dim = image_features.shape
-
-        # Positions of <image> tokens in input_ids, assuming batch size is 1
-        image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
-
-        if len(image_positions) != num_images:
-            raise ValueError(
-                f"The number of image tokens ({len(image_positions)}) does not "
-                f" match the number of image inputs ({num_images})."
-            )
-
-        text_segments = []
-        start_idx = 0
-
-        for position in image_positions:
-            text_segments.append(inputs_embeds[:, start_idx:position])
-            start_idx = position + 1
-
-        image_embeddings = mx.split(image_features, image_features.shape[0])
-        final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
-        final_embeddings += [inputs_embeds[:, start_idx:]]
-
-        # Create a final embedding of shape
-        # (1, num_image_patches*num_images + sequence_len, embed_dim)
-        return mx.concatenate(final_embeddings, axis=1)
-
-    def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
-        input_embddings = None
-        if pixel_values is not None:
-            input_embddings = self.get_input_embeddings(input_ids, pixel_values)
-        logits = self.language_model(
-            input_ids, cache=cache, inputs_embeds=input_embddings
-        )
-        return logits
-
-    def sanitize(self, weights):
-        if self.config.vision_config:
-          weights = self.vision_tower.sanitize(weights)
-        else:
-          weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))}
-        weights = self.language_model.sanitize(weights)
-        return weights
-
-    @property
-    def layers(self):
-        return self.language_model.model.layers
-
-    @property
-    def head_dim(self):
-        return (
-                self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads
-        )
-
-    @property
-    def n_kv_heads(self):
-        return self.language_model.model.num_key_value_heads
+
+  def __init__(self, config: ModelArgs):
+    super().__init__()
+    self.config = config
+    self.model_type = config.model_type
+    if config.vision_config:
+      self.vision_tower = VisionModel(config.vision_config)
+      self.multi_modal_projector = LlavaMultiModalProjector(config)
+      self.vision_feature_layer = config.vision_feature_layer
+      self.vision_feature_select_strategy = config.vision_feature_select_strategy
+    self.language_model = LanguageModel(config.text_config, config.shard)
+
+  def get_input_embeddings(
+    self,
+    input_ids: Optional[mx.array] = None,
+    pixel_values: Optional[mx.array] = None,
+  ):
+    if pixel_values is None:
+      return self.language_model(input_ids)
+
+    # Get the input embeddings from the language model
+    inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+    # Get the ouptut hidden states from the vision model
+    *_, hidden_states = self.vision_tower(pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True)
+
+    # Select the hidden states from the desired layer
+    selected_image_feature = hidden_states[self.vision_feature_layer]
+
+    if self.vision_feature_select_strategy == "default":
+      selected_image_feature = selected_image_feature[:, 1:]
+    elif self.vision_feature_select_strategy == "full":
+      selected_image_feature = selected_image_feature
+    else:
+      raise ValueError("Unexpected feature selection strategy: "
+                       f"{self.vision_feature_select_strategy}")
+
+    # Pass image features through the multi-modal projector
+    image_features = self.multi_modal_projector(selected_image_feature)
+
+    # Insert special image tokens in the input_ids
+    final_inputs_embeds = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids)
+    return final_inputs_embeds
+
+  def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids):
+    image_token_index = self.config.image_token_index
+    num_images, num_image_patches, embed_dim = image_features.shape
+
+    # Positions of <image> tokens in input_ids, assuming batch size is 1
+    image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
+
+    if len(image_positions) != num_images:
+      raise ValueError(f"The number of image tokens ({len(image_positions)}) does not "
+                       f" match the number of image inputs ({num_images}).")
+
+    text_segments = []
+    start_idx = 0
+
+    for position in image_positions:
+      text_segments.append(inputs_embeds[:, start_idx:position])
+      start_idx = position + 1
+
+    image_embeddings = mx.split(image_features, image_features.shape[0])
+    final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
+    final_embeddings += [inputs_embeds[:, start_idx:]]
+
+    # Create a final embedding of shape
+    # (1, num_image_patches*num_images + sequence_len, embed_dim)
+    return mx.concatenate(final_embeddings, axis=1)
+
+  def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
+    input_embddings = None
+    if pixel_values is not None:
+      input_embddings = self.get_input_embeddings(input_ids, pixel_values)
+    logits = self.language_model(input_ids, cache=cache, inputs_embeds=input_embddings)
+    return logits
+
+  def sanitize(self, weights):
+    if self.config.vision_config:
+      weights = self.vision_tower.sanitize(weights)
+    else:
+      weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))}
+    weights = self.language_model.sanitize(weights)
+    return weights
+
+  @property
+  def layers(self):
+    return self.language_model.model.layers
+
+  @property
+  def head_dim(self):
+    return (self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads)
+
+  @property
+  def n_kv_heads(self):
+    return self.language_model.model.num_key_value_heads

+ 1 - 0
exo/inference/mlx/sharded_inference_engine.py

@@ -9,6 +9,7 @@ from exo.download.shard_download import ShardDownloader
 
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
+
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader

+ 4 - 9
exo/inference/mlx/sharded_model.py

@@ -10,6 +10,7 @@ from ..shard import Shard
 
 
 class StatefulShardedModel:
+
   def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
     self.shard = shard
     self.model = model
@@ -26,6 +27,7 @@ class StatefulShardedModel:
     top_p: float = 1.0,
     logit_bias: Optional[Dict[int, float]] = None,
   ) -> Generator[Tuple[mx.array, mx.array], None, None]:
+
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
       if logit_bias:
         indices = mx.array(list(logit_bias.keys()))
@@ -74,16 +76,9 @@ class StatefulShardedModel:
     return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
 
   def init_cache(self, request_id: str):
-    kv_heads = (
-      [self.model.n_kv_heads] * len(self.model.layers)
-      if isinstance(self.model.n_kv_heads, int)
-      else self.model.n_kv_heads
-    )
+    kv_heads = ([self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
     if self.max_kv_size is not None:
-      cache = [
-        RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4)
-        for n in kv_heads
-      ]
+      cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
     else:
       cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
 

+ 29 - 25
exo/inference/mlx/sharded_utils.py

@@ -25,6 +25,7 @@ from ..shard import Shard
 
 
 class ModelNotFoundError(Exception):
+
   def __init__(self, message):
     self.message = message
     super().__init__(self.message)
@@ -139,9 +140,10 @@ def load_model_shard(
   if (quantization := config.get("quantization", None)) is not None:
     # Handle legacy models which may not have everything quantized
     def class_predicate(p, m):
-        if not hasattr(m, "to_quantized"):
-            return False
-        return f"{p}.scales" in weights
+      if not hasattr(m, "to_quantized"):
+        return False
+      return f"{p}.scales" in weights
+
     nn.quantize(
       model,
       **quantization,
@@ -156,6 +158,7 @@ def load_model_shard(
   model.eval()
   return model
 
+
 async def load_shard(
   model_path: str,
   shard: Shard,
@@ -179,26 +182,27 @@ async def load_shard(
     tokenizer = load_tokenizer(model_path, tokenizer_config)
     return model, tokenizer
 
+
 async def get_image_from_str(_image_str: str):
-    image_str = _image_str.strip()
-
-    if image_str.startswith("http"):
-        async with aiohttp.ClientSession() as session:
-            async with session.get(image_str, timeout=10) as response:
-                content = await response.read()
-                return Image.open(BytesIO(content)).convert("RGB")
-    elif image_str.startswith("data:image/"):
-        # Extract the image format and base64 data
-        format_prefix, base64_data = image_str.split(";base64,")
-        image_format = format_prefix.split("/")[1].lower()
-        if DEBUG >= 2: print(f"{image_str=} {image_format=}")
-        imgdata = base64.b64decode(base64_data)
-        img = Image.open(BytesIO(imgdata))
-
-        # Convert to RGB if not already
-        if img.mode != "RGB":
-            img = img.convert("RGB")
-
-        return img
-    else:
-        raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
+  image_str = _image_str.strip()
+
+  if image_str.startswith("http"):
+    async with aiohttp.ClientSession() as session:
+      async with session.get(image_str, timeout=10) as response:
+        content = await response.read()
+        return Image.open(BytesIO(content)).convert("RGB")
+  elif image_str.startswith("data:image/"):
+    # Extract the image format and base64 data
+    format_prefix, base64_data = image_str.split(";base64,")
+    image_format = format_prefix.split("/")[1].lower()
+    if DEBUG >= 2: print(f"{image_str=} {image_format=}")
+    imgdata = base64.b64decode(base64_data)
+    img = Image.open(BytesIO(imgdata))
+
+    # Convert to RGB if not already
+    if img.mode != "RGB":
+      img = img.convert("RGB")
+
+    return img
+  else:
+    raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")

+ 6 - 6
exo/inference/mlx/test_sharded_llava.py

@@ -39,8 +39,8 @@ y = full.step("full", input_ids, pixel_values, temp=0)
 full_generated_tokens = [y.item()]
 
 for _ in range(13):
-    y = full.step("full", y, temp=0)
-    full_generated_tokens.append(y.item())
+  y = full.step("full", y, temp=0)
+  full_generated_tokens.append(y.item())
 
 full_response = full_processor.tokenizer.decode(full_generated_tokens)
 print("full response:", full_response)
@@ -54,11 +54,11 @@ y = m2.step("shard", y, temp=0)
 full_generated_tokens = [y.item()]
 
 for _ in range(13):
-    y = m1.step("shard", y, temp=0)
-    y = m2.step("shard", y, temp=0)
-    full_generated_tokens.append(y.item())
+  y = m1.step("shard", y, temp=0)
+  y = m2.step("shard", y, temp=0)
+  full_generated_tokens.append(y.item())
 
 sharded_response = processor2.tokenizer.decode(full_generated_tokens)
 print("sharded response:", sharded_response)
 
-assert full_response == sharded_response
+assert full_response == sharded_response

+ 2 - 1
exo/inference/mlx/test_sharded_model.py

@@ -6,6 +6,7 @@ import numpy as np
 
 
 class DummyModel(nn.Module):
+
   def __init__(self, shard: Optional[Shard] = None):
     self.shard = shard
     self.layers = [
@@ -21,7 +22,7 @@ class DummyModel(nn.Module):
 
   def __call__(self, x, cache=None):
     if self.shard:
-      for layer in self.layers[self.shard.start_layer : self.shard.end_layer + 1]:
+      for layer in self.layers[self.shard.start_layer:self.shard.end_layer + 1]:
         x = layer(x)
       if self.shard.is_last_layer():
         x = x.reshape((1, 2, 4))

+ 2 - 4
exo/inference/shard.py

@@ -34,8 +34,6 @@ class Shard:
   def overlaps(self, other: 'Shard') -> bool:
     return shards_overlap(self, other)
 
+
 def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
-  return (
-      shard1.model_id == shard2.model_id
-      and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer)
-  )
+  return (shard1.model_id == shard2.model_id and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer))

+ 13 - 11
exo/inference/test_inference_engine.py

@@ -7,6 +7,7 @@ import os
 import asyncio
 import numpy as np
 
+
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
@@ -22,7 +23,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt)
   resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
     "B",
-    shard=Shard(model_id=model_id, start_layer=pp+1, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
     input_data=resp1,
     inference_state=inference_state_1,
   )
@@ -34,7 +35,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   )
   resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
     "B",
-    shard=Shard(model_id=model_id, start_layer=pp+1, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
     input_data=resp3,
     inference_state=inference_state_3,
   )
@@ -42,21 +43,22 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(resp_full, resp2)
   assert np.array_equal(next_resp_full, resp4)
 
-asyncio.run(
-  test_inference_engine(
-    MLXDynamicShardInferenceEngine(HFShardDownloader()),
-    MLXDynamicShardInferenceEngine(HFShardDownloader()),
-    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-  )
-)
+
+asyncio.run(test_inference_engine(
+  MLXDynamicShardInferenceEngine(HFShardDownloader()),
+  MLXDynamicShardInferenceEngine(HFShardDownloader()),
+  "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+))
 
 if os.getenv("RUN_TINYGRAD", default="0") == "1":
   import tinygrad
   import os
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-  asyncio.run(test_inference_engine(
+  asyncio.run(
+    test_inference_engine(
       TinygradDynamicShardInferenceEngine(HFShardDownloader()),
       TinygradDynamicShardInferenceEngine(HFShardDownloader()),
       "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
-  ))
+    )
+  )

+ 8 - 11
exo/inference/tinygrad/inference.py

@@ -20,16 +20,11 @@ TOP_P = 0.9
 ALPHA_F = 0.1
 ALPHA_P = 0.0
 MODEL_PARAMS = {
-  "8B": {
-    "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
-    "files": 1
-  },
-  "70B": {
-    "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 28672},
-    "files": 8
-  }
+  "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
+  "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
 }
 
+
 def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
   # build model
   linear = nn.Linear
@@ -48,10 +43,12 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
 
   with Context(BEAM=0):
     # replace weights in model
-    load_state_dict(model, weights, strict=False, consume=False) # consume=True
+    load_state_dict(model, weights, strict=False, consume=False)  # consume=True
   return model
 
+
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
+
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
@@ -64,7 +61,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     toks = self.tokenizer.encode(prompt)
     h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()
 
-    if h.shape == (1,):
+    if h.shape == (1, ):
       start_pos += len(toks)
       start_pos += 1
       n_captured_toks = 0
@@ -80,7 +77,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
 
     h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()
 
-    if h.shape == (1,):
+    if h.shape == (1, ):
       start_pos += n_captured_toks
       start_pos += 1
       n_captured_toks = 0

+ 73 - 34
exo/inference/tinygrad/models/llama.py

@@ -2,21 +2,24 @@ from typing import Tuple, Union, Optional, Dict, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 
+
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
-  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
+  freqs = 1.0 / (theta**(Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
-  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim//2, 2)
+  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
+
 
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 def complex_mult(A, c, d):
-  a,b = A[..., 0:1], A[..., 1:2]
-  ro = a*c - b*d
-  co = a*d + b*c
+  a, b = A[..., 0:1], A[..., 1:2]
+  ro = a * c - b * d
+  co = a * d + b * c
   return ro.cat(co, dim=-1)
 
-def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
+
+def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
   assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
@@ -26,16 +29,19 @@ def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Te
   xk_out = complex_mult(xk, c, d)
   return xq_out.flatten(3), xk_out.flatten(3)
 
-def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
+
+def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
   bs, seqlen, n_kv_heads, head_dim = x.shape
   if n_rep == 1: return x
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
 
+
 class Attention:
+
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
-    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
+    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
     self.head_dim = dim // n_heads
     self.n_rep = self.n_heads // self.n_kv_heads
     self.max_context = max_context
@@ -45,7 +51,7 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
 
-  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
     if getenv("WQKV"):
       if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
       xqkv = x @ self.wqkv.T
@@ -69,10 +75,10 @@ class Attention:
 
     # update the cache
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
+    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
 
-    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
+    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
+    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
 
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -80,26 +86,31 @@ class Attention:
     attn = attn.reshape(bsz, seqlen, -1)
     return self.wo(attn)
 
+
 class FeedForward:
-  def __init__(self, dim:int, hidden_dim:int, linear=nn.Linear):
+
+  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
     self.w1 = linear(dim, hidden_dim, bias=False)
     self.w2 = linear(hidden_dim, dim, bias=False)
-    self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
+    self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit
+
+  def __call__(self, x: Tensor) -> Tensor:
+    return self.w2(self.w1(x).silu() * self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]
 
-  def __call__(self, x:Tensor) -> Tensor:
-    return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
 
 class TransformerBlock:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
+
+  def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int, linear=nn.Linear, feed_forward=FeedForward):
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
 
-  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
 
+
 # standard openai sampling
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert logits.ndim == 1, "only works on 1d tensors"
@@ -127,8 +138,8 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
     output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
     for i in range(k):
       t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
-      output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
-      output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
+      output = output + t_max.unsqueeze(0).pad(((i, k - i - 1), ))
+      output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1), ))
       t = (counter == t_argmax).where(0, t)
 
     # approximate top p
@@ -149,10 +160,28 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
 
   return output_token
 
+
 from exo.inference.shard import Shard
 
+
 class Transformer:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, shard: Shard=None, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
+
+  def __init__(
+    self,
+    dim: int,
+    hidden_dim: int,
+    n_heads: int,
+    n_layers: int,
+    norm_eps: float,
+    vocab_size,
+    shard: Shard = None,
+    linear=nn.Linear,
+    n_kv_heads=None,
+    rope_theta=10000,
+    max_context=1024,
+    jit=True,
+    feed_forward=FeedForward
+  ):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
@@ -162,10 +191,10 @@ class Transformer:
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
 
-  def forward(self, x:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
+  def forward(self, x: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
     seqlen = x.shape[1]
-    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
-    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos+1).realize() if seqlen > 1 else None
+    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
+    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
 
     if self.shard.is_first_layer():
       h = self.tok_embeddings(x)
@@ -182,24 +211,33 @@ class Transformer:
     else:
       return h
 
-  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
+  def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
     # TODO: better way to handle the first call v.s. the rest?
-    if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
+    if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
       return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
     return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
 
+
 # *** helpers ***
 
-def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
+
+def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
+
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
 
   keymap = {
     "model.embed_tokens.weight": "tok_embeddings.weight",
-    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
-    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
+    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
+       for l in range(len(model.layers))},
+    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
+       for x in ["q", "k", "v", "o"]
+       for l in range(len(model.layers))},
+    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
+       for l in range(len(model.layers))},
+    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
+       for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
+       for l in range(len(model.layers))},
     "model.norm.weight": "norm.weight",
     "lm_head.weight": "output.weight",
   }
@@ -215,9 +253,10 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
     sd[keymap[k]] = v
   return sd
 
-def fix_bf16(weights:Dict[Any, Tensor]):
+
+def fix_bf16(weights: Dict[Any, Tensor]):
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
-    return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
   # TODO: check if device supports bf16
-  return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}

+ 7 - 2
exo/inference/tinygrad/tinygrad_helpers.py

@@ -8,8 +8,10 @@ from exo.helpers import DEBUG
 from exo.download.hf.hf_helpers import get_allow_patterns
 from fnmatch import fnmatch
 
+
 # **** helper functions ****
 def concat_weights(models, device=None):
+
   def convert(name) -> Tensor:
     disk_tensors: List[Tensor] = [model[name] for model in models]
     if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
@@ -17,11 +19,14 @@ def concat_weights(models, device=None):
     axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
     lazy_tensors = [data.to(device=device) for data in disk_tensors]
     return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
+
   return {name: convert(name) for name in {name: None for model in models for name in model}}
 
-def load(fn:str, shard: Shard):
+
+def load(fn: str, shard: Shard):
   if fn.endswith('.index.json'):
-    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
+    with open(fn) as fp:
+      weight_map = json.load(fp)['weight_map']
     parts = {}
     filtered_weight_map = {}
     allow_patterns = get_allow_patterns(weight_map, shard)

+ 1 - 0
exo/inference/tokenizers.py

@@ -2,6 +2,7 @@ import traceback
 from transformers import AutoTokenizer, AutoProcessor
 from exo.helpers import DEBUG
 
+
 async def resolve_tokenizer(model_id: str):
   try:
     if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")

+ 25 - 32
exo/models.py

@@ -2,39 +2,32 @@ from exo.inference.shard import Shard
 
 model_base_shards = {
   ### llama
-  "llama-3.1-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
-  },
-  "llama-3.1-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
-  },
-  "llama-3.1-405b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
-  },
-  "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
-  },
-  "llama-3-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
-  },
+  "llama-3.1-8b":
+    {
+      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+      "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
+    },
+  "llama-3.1-70b":
+    {
+      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+      "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
+    },
+  "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126), },
+  "llama-3-8b":
+    {
+      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+      "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
+    },
+  "llama-3-70b":
+    {
+      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+      "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
+    },
   ### mistral
-  "mistral-nemo": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
-  },
-  "mistral-large": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
-  },
+  "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40), },
+  "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88), },
   ### deepseek v2
-  "deepseek-coder-v2-lite": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
-  },
+  "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27), },
   ### llava
-  "llava-1.5-7b-hf": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
-  },
+  "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32), },
 }
-

+ 1 - 0
exo/networking/discovery.py

@@ -4,6 +4,7 @@ from .peer_handle import PeerHandle
 
 
 class Discovery(ABC):
+
   @abstractmethod
   async def start(self) -> None:
     pass

+ 11 - 11
exo/networking/grpc/grpc_discovery.py

@@ -11,6 +11,7 @@ from exo import DEBUG_DISCOVERY
 
 
 class ListenProtocol(asyncio.DatagramProtocol):
+
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
     super().__init__()
     self.on_message = on_message
@@ -24,6 +25,7 @@ class ListenProtocol(asyncio.DatagramProtocol):
 
 
 class GRPCDiscovery(Discovery):
+
   def __init__(
     self,
     node_id: str,
@@ -97,14 +99,12 @@ class GRPCDiscovery(Discovery):
     sock = transport.get_extra_info("socket")
     sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
 
-    message = json.dumps(
-      {
-        "type": "discovery",
-        "node_id": self.node_id,
-        "grpc_port": self.node_port,
-        "device_capabilities": self.device_capabilities.to_dict(),
-      }
-    ).encode("utf-8")
+    message = json.dumps({
+      "type": "discovery",
+      "node_id": self.node_id,
+      "grpc_port": self.node_port,
+      "device_capabilities": self.device_capabilities.to_dict(),
+    }).encode("utf-8")
 
     while True:
       try:
@@ -166,14 +166,14 @@ class GRPCDiscovery(Discovery):
       try:
         current_time = time.time()
         peers_to_remove = [
-          peer_handle.id()
-          for peer_handle, connected_at, last_seen in self.known_peers.values()
+          peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
           if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
         ]
         if DEBUG_DISCOVERY >= 2:
           print(
             "Peer statuses:",
-            {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()},
+            {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}"
+             for peer_handle, connected_at, last_seen in self.known_peers.values()},
           )
         if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0:
           print(f"Cleaning up peers: {peers_to_remove}")

+ 1 - 0
exo/networking/grpc/grpc_peer_handle.py

@@ -13,6 +13,7 @@ from exo.topology.device_capabilities import DeviceCapabilities
 
 
 class GRPCPeerHandle(PeerHandle):
+
   def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
     self._id = _id
     self.address = address

+ 9 - 9
exo/networking/grpc/grpc_server.py

@@ -11,6 +11,7 @@ from exo.orchestration import Node
 
 
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
+
   def __init__(self, node: Node, host: str, port: int):
     self.node = node
     self.host = host
@@ -81,9 +82,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       node_service_pb2.InferenceResult(
         tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)),
         is_finished=result[1],
-      )
-      if result[0] is not None
-      else node_service_pb2.InferenceResult(is_finished=result[1])
+      ) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1])
     )
 
   async def CollectTopology(self, request, context):
@@ -91,12 +90,13 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     visited = set(request.visited)
     topology = await self.node.collect_topology(visited, max_depth)
     nodes = {
-      node_id: node_service_pb2.DeviceCapabilities(
-        model=cap.model,
-        chip=cap.chip,
-        memory=cap.memory,
-        flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
-      )
+      node_id:
+        node_service_pb2.DeviceCapabilities(
+          model=cap.model,
+          chip=cap.chip,
+          memory=cap.memory,
+          flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
+        )
       for node_id, cap in topology.nodes.items()
     }
     peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}

Những thai đổi đã bị hủy bỏ vì nó quá lớn
+ 0 - 3
exo/networking/grpc/node_service_pb2.py


+ 229 - 271
exo/networking/grpc/node_service_pb2_grpc.py

@@ -12,306 +12,264 @@ SCHEDULED_RELEASE_DATE = 'June 25, 2024'
 _version_not_supported = False
 
 try:
-    from grpc._utilities import first_version_is_lower
-    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+  from grpc._utilities import first_version_is_lower
+  _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
 except ImportError:
-    _version_not_supported = True
+  _version_not_supported = True
 
 if _version_not_supported:
-    warnings.warn(
-        f'The grpc package installed is at version {GRPC_VERSION},'
-        + f' but the generated code in node_service_pb2_grpc.py depends on'
-        + f' grpcio>={GRPC_GENERATED_VERSION}.'
-        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
-        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
-        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
-        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
-        RuntimeWarning
-    )
+  warnings.warn(
+    f'The grpc package installed is at version {GRPC_VERSION},' + f' but the generated code in node_service_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' +
+    f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' +
+    f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', RuntimeWarning
+  )
 
 
 class NodeServiceStub(object):
-    """Missing associated documentation comment in .proto file."""
+  """Missing associated documentation comment in .proto file."""
 
-    def __init__(self, channel):
-        """Constructor.
+  def __init__(self, channel):
+    """Constructor.
 
         Args:
             channel: A grpc.Channel.
         """
-        self.SendPrompt = channel.unary_unary(
-                '/node_service.NodeService/SendPrompt',
-                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
-                _registered_method=True)
-        self.SendTensor = channel.unary_unary(
-                '/node_service.NodeService/SendTensor',
-                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
-                _registered_method=True)
-        self.GetInferenceResult = channel.unary_unary(
-                '/node_service.NodeService/GetInferenceResult',
-                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.InferenceResult.FromString,
-                _registered_method=True)
-        self.CollectTopology = channel.unary_unary(
-                '/node_service.NodeService/CollectTopology',
-                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Topology.FromString,
-                _registered_method=True)
-        self.SendResult = channel.unary_unary(
-                '/node_service.NodeService/SendResult',
-                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
-        self.SendOpaqueStatus = channel.unary_unary(
-                '/node_service.NodeService/SendOpaqueStatus',
-                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
+    self.SendPrompt = channel.unary_unary(
+      '/node_service.NodeService/SendPrompt',
+      request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Tensor.FromString,
+      _registered_method=True
+    )
+    self.SendTensor = channel.unary_unary(
+      '/node_service.NodeService/SendTensor',
+      request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Tensor.FromString,
+      _registered_method=True
+    )
+    self.GetInferenceResult = channel.unary_unary(
+      '/node_service.NodeService/GetInferenceResult',
+      request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+      response_deserializer=node__service__pb2.InferenceResult.FromString,
+      _registered_method=True
+    )
+    self.CollectTopology = channel.unary_unary(
+      '/node_service.NodeService/CollectTopology',
+      request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Topology.FromString,
+      _registered_method=True
+    )
+    self.SendResult = channel.unary_unary(
+      '/node_service.NodeService/SendResult',
+      request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Empty.FromString,
+      _registered_method=True
+    )
+    self.SendOpaqueStatus = channel.unary_unary(
+      '/node_service.NodeService/SendOpaqueStatus',
+      request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Empty.FromString,
+      _registered_method=True
+    )
 
 
 class NodeServiceServicer(object):
-    """Missing associated documentation comment in .proto file."""
+  """Missing associated documentation comment in .proto file."""
 
-    def SendPrompt(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendPrompt(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def SendTensor(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendTensor(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def GetInferenceResult(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def GetInferenceResult(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def CollectTopology(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def CollectTopology(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def SendResult(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendResult(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def SendOpaqueStatus(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendOpaqueStatus(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
 def add_NodeServiceServicer_to_server(servicer, server):
-    rpc_method_handlers = {
-            'SendPrompt': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendPrompt,
-                    request_deserializer=node__service__pb2.PromptRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
-            ),
-            'SendTensor': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendTensor,
-                    request_deserializer=node__service__pb2.TensorRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
-            ),
-            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.GetInferenceResult,
-                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
-                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
-            ),
-            'CollectTopology': grpc.unary_unary_rpc_method_handler(
-                    servicer.CollectTopology,
-                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
-                    response_serializer=node__service__pb2.Topology.SerializeToString,
-            ),
-            'SendResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendResult,
-                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
-            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendOpaqueStatus,
-                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
-    }
-    generic_handler = grpc.method_handlers_generic_handler(
-            'node_service.NodeService', rpc_method_handlers)
-    server.add_generic_rpc_handlers((generic_handler,))
-    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
+  rpc_method_handlers = {
+    'SendPrompt':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendPrompt,
+        request_deserializer=node__service__pb2.PromptRequest.FromString,
+        response_serializer=node__service__pb2.Tensor.SerializeToString,
+      ),
+    'SendTensor':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendTensor,
+        request_deserializer=node__service__pb2.TensorRequest.FromString,
+        response_serializer=node__service__pb2.Tensor.SerializeToString,
+      ),
+    'GetInferenceResult':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.GetInferenceResult,
+        request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+        response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+      ),
+    'CollectTopology':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.CollectTopology,
+        request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+        response_serializer=node__service__pb2.Topology.SerializeToString,
+      ),
+    'SendResult':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendResult,
+        request_deserializer=node__service__pb2.SendResultRequest.FromString,
+        response_serializer=node__service__pb2.Empty.SerializeToString,
+      ),
+    'SendOpaqueStatus':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendOpaqueStatus,
+        request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+        response_serializer=node__service__pb2.Empty.SerializeToString,
+      ),
+  }
+  generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers)
+  server.add_generic_rpc_handlers((generic_handler, ))
+  server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
 
 
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
 class NodeService(object):
-    """Missing associated documentation comment in .proto file."""
+  """Missing associated documentation comment in .proto file."""
 
-    @staticmethod
-    def SendPrompt(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendPrompt',
-            node__service__pb2.PromptRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendPrompt',
+      node__service__pb2.PromptRequest.SerializeToString,
+      node__service__pb2.Tensor.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def SendTensor(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendTensor',
-            node__service__pb2.TensorRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendTensor(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendTensor',
+      node__service__pb2.TensorRequest.SerializeToString,
+      node__service__pb2.Tensor.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def GetInferenceResult(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/GetInferenceResult',
-            node__service__pb2.GetInferenceResultRequest.SerializeToString,
-            node__service__pb2.InferenceResult.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def GetInferenceResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/GetInferenceResult',
+      node__service__pb2.GetInferenceResultRequest.SerializeToString,
+      node__service__pb2.InferenceResult.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def CollectTopology(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/CollectTopology',
-            node__service__pb2.CollectTopologyRequest.SerializeToString,
-            node__service__pb2.Topology.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def CollectTopology(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/CollectTopology',
+      node__service__pb2.CollectTopologyRequest.SerializeToString,
+      node__service__pb2.Topology.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def SendResult(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendResult',
-            node__service__pb2.SendResultRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendResult',
+      node__service__pb2.SendResultRequest.SerializeToString,
+      node__service__pb2.Empty.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def SendOpaqueStatus(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendOpaqueStatus',
-            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendOpaqueStatus(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendOpaqueStatus',
+      node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+      node__service__pb2.Empty.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )

+ 1 - 0
exo/networking/grpc/test_grpc_discovery.py

@@ -4,6 +4,7 @@ from .grpc_discovery import GRPCDiscovery
 
 
 class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase):
+
   async def asyncSetUp(self):
     self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
     self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)

+ 1 - 0
exo/networking/peer_handle.py

@@ -7,6 +7,7 @@ from exo.topology.topology import Topology
 
 
 class PeerHandle(ABC):
+
   @abstractmethod
   def id(self) -> str:
     pass

+ 1 - 0
exo/networking/server.py

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
 
 
 class Server(ABC):
+
   @abstractmethod
   async def start(self) -> None:
     pass

+ 1 - 0
exo/orchestration/node.py

@@ -7,6 +7,7 @@ from exo.topology.topology import Topology
 
 
 class Node(ABC):
+
   @abstractmethod
   async def start(self, wait_for_peers: int = 0) -> None:
     pass

+ 3 - 0
exo/orchestration/standard_node.py

@@ -18,6 +18,7 @@ from exo.download.hf.hf_helpers import RepoProgressEvent
 
 
 class StandardNode(Node):
+
   def __init__(
     self,
     _id: str,
@@ -359,6 +360,7 @@ class StandardNode(Node):
     self.on_token.trigger_all(request_id, tokens, is_finished)
 
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+
     async def send_result_to_peer(peer):
       try:
         await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
@@ -372,6 +374,7 @@ class StandardNode(Node):
 
   async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
     if DEBUG >= 5: print(f"Broadcasting opaque status: {request_id=} {status=}")
+
     async def send_status_to_peer(peer):
       try:
         await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)

+ 1 - 0
exo/orchestration/test_node.py

@@ -7,6 +7,7 @@ from exo.networking.peer_handle import PeerHandle
 
 
 class TestNode(unittest.IsolatedAsyncioTestCase):
+
   def setUp(self):
     self.mock_inference_engine = AsyncMock()
     self.mock_server = AsyncMock()

+ 1 - 0
exo/test_callbacks.py

@@ -12,6 +12,7 @@ async def main() -> None:
   callback2 = callback_system.register("callback2")
 
   def on_next_callback(name: str) -> Callable[..., None]:
+
     def callback(*args: Any) -> None:
       print(f"{name} received values: {args}")
 

+ 1 - 0
exo/topology/partitioning_strategy.py

@@ -14,6 +14,7 @@ class Partition:
 
 
 class PartitioningStrategy(ABC):
+
   @abstractmethod
   def partition(self, topology: Topology) -> List[Partition]:
     pass

+ 1 - 0
exo/topology/ring_memory_weighted_partitioning_strategy.py

@@ -5,6 +5,7 @@ from .partitioning_strategy import Partition
 
 
 class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
+
   def partition(self, topology: Topology) -> List[Partition]:
     nodes = list(topology.all_nodes())
     nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True)

+ 1 - 0
exo/topology/test_device_capabilities.py

@@ -4,6 +4,7 @@ from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapa
 
 
 class TestMacDeviceCapabilities(unittest.TestCase):
+
   @patch("subprocess.check_output")
   def test_mac_device_capabilities_pro(self, mock_check_output):
     # Mock the subprocess output

+ 1 - 0
exo/topology/test_map_partitions.py

@@ -5,6 +5,7 @@ from exo.inference.shard import Shard
 
 
 class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
+
   def test_map_partitions_to_shards(self):
     partitions = [
       Partition("node1", 0.0, 0.42857),

+ 1 - 0
exo/topology/test_ring_memory_weighted_partitioning_strategy.py

@@ -6,6 +6,7 @@ from exo.topology.partitioning_strategy import Partition
 
 
 class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
+
   def test_partition(self):
     # triangle
     # node1 -> node2 -> node3 -> node1

+ 1 - 0
exo/topology/topology.py

@@ -3,6 +3,7 @@ from typing import Dict, Set, Optional
 
 
 class Topology:
+
   def __init__(self):
     self.nodes: Dict[str, DeviceCapabilities] = {}  # Maps node IDs to DeviceCapabilities
     self.peer_graph: Dict[str, Set[str]] = {}  # Adjacency list representing the graph

+ 49 - 47
exo/viz/test_topology_viz.py

@@ -9,58 +9,60 @@ from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
 
 
 def create_hf_repo_progress_event(
-    completed_files: int = 5,
-    total_files: int = 10,
-    downloaded_bytes: int = 500000000,
-    downloaded_bytes_this_session: int = 250000000,
-    total_bytes: int = 1000000000,
-    overall_speed: int = 5000000,
-    overall_eta: timedelta = timedelta(seconds=100),
-    file_progress: dict = None,
-    status: str = "in_progress"
+  completed_files: int = 5,
+  total_files: int = 10,
+  downloaded_bytes: int = 500000000,
+  downloaded_bytes_this_session: int = 250000000,
+  total_bytes: int = 1000000000,
+  overall_speed: int = 5000000,
+  overall_eta: timedelta = timedelta(seconds=100),
+  file_progress: dict = None,
+  status: str = "in_progress"
 ) -> RepoProgressEvent:
-    if file_progress is None:
-        file_progress = {
-            "file1.bin": RepoFileProgressEvent(
-                repo_id="repo_id",
-                repo_revision="repo_revision",
-                file_path="file1.bin",
-                downloaded=100000000,
-                downloaded_this_session=50000000,
-                total=200000000,
-                speed=1000000,
-                eta=timedelta(seconds=100),
-                status="in_progress"
-            ),
-            "file2.bin": RepoFileProgressEvent(
-                repo_id="repo_id",
-                repo_revision="repo_revision",
-                file_path="file2.bin",
-                downloaded=200000000,
-                downloaded_this_session=100000000,
-                total=200000000,
-                speed=2000000,
-                eta=timedelta(seconds=0),
-                status="complete"
-            )
-        }
+  if file_progress is None:
+    file_progress = {
+      "file1.bin":
+        RepoFileProgressEvent(
+          repo_id="repo_id",
+          repo_revision="repo_revision",
+          file_path="file1.bin",
+          downloaded=100000000,
+          downloaded_this_session=50000000,
+          total=200000000,
+          speed=1000000,
+          eta=timedelta(seconds=100),
+          status="in_progress"
+        ), "file2.bin":
+          RepoFileProgressEvent(
+            repo_id="repo_id",
+            repo_revision="repo_revision",
+            file_path="file2.bin",
+            downloaded=200000000,
+            downloaded_this_session=100000000,
+            total=200000000,
+            speed=2000000,
+            eta=timedelta(seconds=0),
+            status="complete"
+          )
+    }
 
-    return RepoProgressEvent(
-        repo_id="repo_id",
-        repo_revision="repo_revision",
-        completed_files=completed_files,
-        total_files=total_files,
-        downloaded_bytes=downloaded_bytes,
-        downloaded_bytes_this_session=downloaded_bytes_this_session,
-        total_bytes=total_bytes,
-        overall_speed=overall_speed,
-        overall_eta=overall_eta,
-        file_progress=file_progress,
-        status=status
-    )
+  return RepoProgressEvent(
+    repo_id="repo_id",
+    repo_revision="repo_revision",
+    completed_files=completed_files,
+    total_files=total_files,
+    downloaded_bytes=downloaded_bytes,
+    downloaded_bytes_this_session=downloaded_bytes_this_session,
+    total_bytes=total_bytes,
+    overall_speed=overall_speed,
+    overall_eta=overall_eta,
+    file_progress=file_progress,
+    status=status
+  )
 
 
 class TestNodeViz(unittest.IsolatedAsyncioTestCase):
+
   async def asyncSetUp(self):
     self.topology = Topology()
     self.topology.update_node(

+ 64 - 66
exo/viz/topology_viz.py

@@ -16,7 +16,9 @@ from rich.syntax import Syntax
 from rich.panel import Panel
 from rich.markdown import Markdown
 
+
 class TopologyViz:
+
   def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []):
     self.chatgpt_api_endpoints = chatgpt_api_endpoints
     self.web_chat_urls = web_chat_urls
@@ -28,11 +30,7 @@ class TopologyViz:
 
     self.console = Console()
     self.layout = Layout()
-    self.layout.split(
-      Layout(name="main"),
-      Layout(name="prompt_output", size=15),
-      Layout(name="download", size=25)
-    )
+    self.layout.split(Layout(name="main"), Layout(name="prompt_output", size=15), Layout(name="download", size=25))
     self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
     self.prompt_output_panel = Panel("", title="Prompt and Output", border_style="green")
     self.download_panel = Panel("", title="Download Progress", border_style="cyan")
@@ -75,11 +73,11 @@ class TopologyViz:
 
     # Update and show/hide prompt and output panel
     if any(r[0] or r[1] for r in self.requests.values()):
-        self.prompt_output_panel = self._generate_prompt_output_layout()
-        self.layout["prompt_output"].update(self.prompt_output_panel)
-        self.layout["prompt_output"].visible = True
+      self.prompt_output_panel = self._generate_prompt_output_layout()
+      self.layout["prompt_output"].update(self.prompt_output_panel)
+      self.layout["prompt_output"].visible = True
     else:
-        self.layout["prompt_output"].visible = False
+      self.layout["prompt_output"].visible = False
 
     # Only show download_panel if there are in-progress downloads
     if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
@@ -97,33 +95,33 @@ class TopologyViz:
     max_lines = 13  # Maximum number of lines for the entire panel content
 
     for (prompt, output) in reversed(requests):
-        prompt_icon, output_icon = "💬️", "🤖"
-
-        # Process prompt
-        prompt_lines = prompt.split('\n')
-        if len(prompt_lines) > max_lines // 2:
-            prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
-        prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
-        prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
-
-        # Process output
-        output_lines = output.split('\n')
-        remaining_lines = max_lines - len(prompt_lines) - 2  # -2 for spacing
-        if len(output_lines) > remaining_lines:
-            output_lines = output_lines[:remaining_lines - 1] + ['...']
-        output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
-        output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
-
-        content.append(prompt_text)
-        content.append(output_text)
-        content.append(Text())  # Empty line between entries
+      prompt_icon, output_icon = "💬️", "🤖"
+
+      # Process prompt
+      prompt_lines = prompt.split('\n')
+      if len(prompt_lines) > max_lines // 2:
+        prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
+      prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
+      prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
+
+      # Process output
+      output_lines = output.split('\n')
+      remaining_lines = max_lines - len(prompt_lines) - 2  # -2 for spacing
+      if len(output_lines) > remaining_lines:
+        output_lines = output_lines[:remaining_lines - 1] + ['...']
+      output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
+      output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
+
+      content.append(prompt_text)
+      content.append(output_text)
+      content.append(Text())  # Empty line between entries
 
     return Panel(
-        Group(*content),
-        title="",
-        border_style="cyan",
-        height=15,  # Increased height to accommodate multiple lines
-        expand=True  # Allow the panel to expand to full width
+      Group(*content),
+      title="",
+      border_style="cyan",
+      height=15,  # Increased height to accommodate multiple lines
+      expand=True  # Allow the panel to expand to full width
     )
 
   def _generate_main_layout(self) -> str:
@@ -185,14 +183,14 @@ class TopologyViz:
       visualization[bar_y][bar_start_x + i] = segment
 
     # Add labels
-    visualization[bar_y - 1][bar_start_x - 10 : bar_start_x - 3] = "GPU poor"
-    visualization[bar_y - 1][bar_start_x + bar_width * 2 + 2 : bar_start_x + bar_width * 2 + 11] = "GPU rich"
+    visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = "GPU poor"
+    visualization[bar_y - 1][bar_start_x + bar_width * 2 + 2:bar_start_x + bar_width * 2 + 11] = "GPU rich"
 
     # Add position indicator and FLOPS value
     pos_x = bar_start_x + int(bar_pos * bar_width)
     flops_str = f"{total_flops:.2f} TFLOPS"
     visualization[bar_y - 1][pos_x] = "▼"
-    visualization[bar_y + 1][pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
+    visualization[bar_y + 1][pos_x - len(flops_str) // 2:pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
     visualization[bar_y + 2][pos_x] = "▲"
 
     # Add an extra empty line for spacing
@@ -270,41 +268,41 @@ class TopologyViz:
 
     # Current node download progress
     if self.node_id in self.node_download_progress:
-        download_progress = self.node_download_progress[self.node_id]
-        title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):"
-        summary.add_row(Text(title, style="bold"))
-        progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
-        summary.add_row(progress_info)
+      download_progress = self.node_download_progress[self.node_id]
+      title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):"
+      summary.add_row(Text(title, style="bold"))
+      progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
+      summary.add_row(progress_info)
 
-        eta_info = f"{download_progress.overall_eta}"
-        summary.add_row(eta_info)
+      eta_info = f"{download_progress.overall_eta}"
+      summary.add_row(eta_info)
 
-        summary.add_row("")  # Empty row for spacing
+      summary.add_row("")  # Empty row for spacing
 
-        for file_path, file_progress in download_progress.file_progress.items():
-            if file_progress.status != "complete":
-                progress = int(file_progress.downloaded / file_progress.total * 30)
-                bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
-                percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
-                summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)
+      for file_path, file_progress in download_progress.file_progress.items():
+        if file_progress.status != "complete":
+          progress = int(file_progress.downloaded / file_progress.total * 30)
+          bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
+          percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
+          summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)
 
     summary.add_row("")  # Empty row for spacing
 
     # Other nodes download progress summary
     summary.add_row(Text("Other Nodes Download Progress:", style="bold"))
     for node_id, progress in self.node_download_progress.items():
-        if node_id != self.node_id:
-            device = self.topology.nodes.get(node_id)
-            partition = next((p for p in self.partitions if p.node_id == node_id), None)
-            partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
-            percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
-            speed = pretty_print_bytes_per_second(progress.overall_speed)
-            device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
-            progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
-            progress_bar = f"[{'=' * int(percentage // 3.33)}{' ' * (30 - int(percentage // 3.33))}]"
-            percentage_str = f"{percentage:.1f}%"
-            eta_str = f"{progress.overall_eta}"
-            summary.add_row(device_info, progress_info, percentage_str)
-            summary.add_row("", progress_bar, eta_str)
-
-    return summary
+      if node_id != self.node_id:
+        device = self.topology.nodes.get(node_id)
+        partition = next((p for p in self.partitions if p.node_id == node_id), None)
+        partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
+        percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
+        speed = pretty_print_bytes_per_second(progress.overall_speed)
+        device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
+        progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
+        progress_bar = f"[{'=' * int(percentage // 3.33)}{' ' * (30 - int(percentage // 3.33))}]"
+        percentage_str = f"{percentage:.1f}%"
+        eta_str = f"{progress.overall_eta}"
+        summary.add_row(device_info, progress_info, percentage_str)
+        summary.add_row("", progress_bar, eta_str)
+
+    return summary

+ 35 - 37
extra/download_hf.py

@@ -3,51 +3,49 @@ import asyncio
 from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent
 
 DEFAULT_ALLOW_PATTERNS = [
-    "*.json",
-    "*.py",
-    "tokenizer.model",
-    "*.tiktoken",
-    "*.txt",
-    "*.safetensors",
+  "*.json",
+  "*.py",
+  "tokenizer.model",
+  "*.tiktoken",
+  "*.txt",
+  "*.safetensors",
 ]
 # Always ignore `.git` and `.cache/huggingface` folders in commits
 DEFAULT_IGNORE_PATTERNS = [
-    ".git",
-    ".git/*",
-    "*/.git",
-    "**/.git/**",
-    ".cache/huggingface",
-    ".cache/huggingface/*",
-    "*/.cache/huggingface",
-    "**/.cache/huggingface/**",
+  ".git",
+  ".git/*",
+  "*/.git",
+  "**/.git/**",
+  ".cache/huggingface",
+  ".cache/huggingface/*",
+  "*/.cache/huggingface",
+  "**/.cache/huggingface/**",
 ]
 
+
 async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
-    async def progress_callback(event: RepoProgressEvent):
-        print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
-        print(f"Estimated time remaining: {event.overall_eta}")
-        print("File Progress:")
-        for file_path, progress in event.file_progress.items():
-            status_icon = {
-                'not_started': '⚪',
-                'in_progress': '🔵',
-                'complete': '✅'
-            }[progress.status]
-            eta_str = str(progress.eta)
-            print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
-                  f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
-        print("\n")
-
-    await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
+
+  async def progress_callback(event: RepoProgressEvent):
+    print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
+    print(f"Estimated time remaining: {event.overall_eta}")
+    print("File Progress:")
+    for file_path, progress in event.file_progress.items():
+      status_icon = {'not_started': '⚪', 'in_progress': '🔵', 'complete': '✅'}[progress.status]
+      eta_str = str(progress.eta)
+      print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
+            f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
+    print("\n")
+
+  await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
-    parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
-    parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
-    parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
-    parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
+  parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
+  parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
+  parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
+  parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
+  parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
 
-    args = parser.parse_args()
+  args = parser.parse_args()
 
-    asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
+  asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))

+ 113 - 98
main.py

@@ -53,131 +53,146 @@ inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 
 if args.node_port is None:
-    args.node_port = find_available_port(args.node_host)
-    if DEBUG >= 1: print(f"Using available port: {args.node_port}")
+  args.node_port = find_available_port(args.node_host)
+  if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 
 args.node_id = args.node_id or get_or_create_node_id()
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
-chatgpt_api_endpoints=[f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
-web_chat_urls=[f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
+chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
+web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
 if DEBUG >= 0:
-    print("Chat interface started:")
-    for web_chat_url in web_chat_urls:
-        print(f" - {terminal_link(web_chat_url)}")
-    print("ChatGPT API endpoint served at:")
-    for chatgpt_api_endpoint in chatgpt_api_endpoints:
-        print(f" - {terminal_link(chatgpt_api_endpoint)}")
+  print("Chat interface started:")
+  for web_chat_url in web_chat_urls:
+    print(f" - {terminal_link(web_chat_url)}")
+  print("ChatGPT API endpoint served at:")
+  for chatgpt_api_endpoint in chatgpt_api_endpoints:
+    print(f" - {terminal_link(chatgpt_api_endpoint)}")
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
-    args.node_id,
-    None,
-    inference_engine,
-    discovery,
-    chatgpt_api_endpoints=chatgpt_api_endpoints,
-    web_chat_urls=web_chat_urls,
-    partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
-    disable_tui=args.disable_tui,
-    max_generate_tokens=args.max_generate_tokens,
-    topology_viz=topology_viz
+  args.node_id,
+  None,
+  inference_engine,
+  discovery,
+  chatgpt_api_endpoints=chatgpt_api_endpoints,
+  web_chat_urls=web_chat_urls,
+  partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
+  disable_tui=args.disable_tui,
+  max_generate_tokens=args.max_generate_tokens,
+  topology_viz=topology_viz
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
-api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs, on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None)
+api = ChatGPTAPI(
+  node,
+  inference_engine.__class__.__name__,
+  response_timeout_secs=args.chatgpt_api_response_timeout_secs,
+  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
+)
 node.on_token.register("update_topology_viz").on_next(
-    lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens) if topology_viz else None
+  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id,
+                                                               inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens) if topology_viz else None
 )
+
+
 def preemptively_start_download(request_id: str, opaque_status: str):
-    try:
-        status = json.loads(opaque_status)
-        if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
-            current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
-            if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-            asyncio.create_task(shard_downloader.ensure_shard(current_shard))
-    except Exception as e:
-        if DEBUG >= 2:
-            print(f"Failed to preemptively start download: {e}")
-            traceback.print_exc()
+  try:
+    status = json.loads(opaque_status)
+    if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
+      current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
+      if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
+      asyncio.create_task(shard_downloader.ensure_shard(current_shard))
+  except Exception as e:
+    if DEBUG >= 2:
+      print(f"Failed to preemptively start download: {e}")
+      traceback.print_exc()
+
+
 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 if args.prometheus_client_port:
-    from exo.stats.metrics import start_metrics_server
-    start_metrics_server(node, args.prometheus_client_port)
+  from exo.stats.metrics import start_metrics_server
+  start_metrics_server(node, args.prometheus_client_port)
 
 last_broadcast_time = 0
+
+
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
-    global last_broadcast_time
-    current_time = time.time()
-    if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
-        last_broadcast_time = current_time
-        asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+  global last_broadcast_time
+  current_time = time.time()
+  if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
+    last_broadcast_time = current_time
+    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+
+
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
+
 async def shutdown(signal, loop):
-    """Gracefully shutdown the server and close the asyncio loop."""
-    print(f"Received exit signal {signal.name}...")
-    print("Thank you for using exo.")
-    print_yellow_exo()
-    server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-    [task.cancel() for task in server_tasks]
-    print(f"Cancelling {len(server_tasks)} outstanding tasks")
-    await asyncio.gather(*server_tasks, return_exceptions=True)
-    await server.stop()
-    loop.stop()
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+  loop.stop()
+
 
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
-    shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
-    if not shard:
-        print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
-        return
-    tokenizer = await resolve_tokenizer(shard.model_id)
-    request_id = str(uuid.uuid4())
-    callback_id = f"cli-wait-response-{request_id}"
-    callback = node.on_token.register(callback_id)
-    if topology_viz:
-        topology_viz.update_prompt(request_id, prompt)
-    prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
-
-    try:
-        print(f"Processing prompt: {prompt}")
-        await node.process_prompt(shard, prompt, None, request_id=request_id)
-
-        _, tokens, _ = await callback.wait(
-            lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
-            timeout=300
-        )
-
-        print("\nGenerated response:")
-        print(tokenizer.decode(tokens))
-    except Exception as e:
-        print(f"Error processing prompt: {str(e)}")
-        traceback.print_exc()
-    finally:
-        node.on_token.deregister(callback_id)
+  shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
+  if not shard:
+    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+    return
+  tokenizer = await resolve_tokenizer(shard.model_id)
+  request_id = str(uuid.uuid4())
+  callback_id = f"cli-wait-response-{request_id}"
+  callback = node.on_token.register(callback_id)
+  if topology_viz:
+    topology_viz.update_prompt(request_id, prompt)
+  prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
+
+  try:
+    print(f"Processing prompt: {prompt}")
+    await node.process_prompt(shard, prompt, None, request_id=request_id)
+
+    _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
+
+    print("\nGenerated response:")
+    print(tokenizer.decode(tokens))
+  except Exception as e:
+    print(f"Error processing prompt: {str(e)}")
+    traceback.print_exc()
+  finally:
+    node.on_token.deregister(callback_id)
+
 
 async def main():
-    loop = asyncio.get_running_loop()
+  loop = asyncio.get_running_loop()
+
+  # Use a more direct approach to handle signals
+  def handle_exit():
+    asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
 
-    # Use a more direct approach to handle signals
-    def handle_exit():
-        asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
+  for s in [signal.SIGINT, signal.SIGTERM]:
+    loop.add_signal_handler(s, handle_exit)
 
-    for s in [signal.SIGINT, signal.SIGTERM]:
-        loop.add_signal_handler(s, handle_exit)
+  await node.start(wait_for_peers=args.wait_for_peers)
 
-    await node.start(wait_for_peers=args.wait_for_peers)
+  if args.run_model:
+    await run_model_cli(node, inference_engine, args.run_model, args.prompt)
+  else:
+    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
+    await asyncio.Event().wait()
 
-    if args.run_model:
-        await run_model_cli(node, inference_engine, args.run_model, args.prompt)
-    else:
-        asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
-        await asyncio.Event().wait()
 
 if __name__ == "__main__":
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    try:
-        loop.run_until_complete(main())
-    except KeyboardInterrupt:
-        print("Received keyboard interrupt. Shutting down...")
-    finally:
-        loop.run_until_complete(shutdown(signal.SIGTERM, loop))
-        loop.close()
+  loop = asyncio.new_event_loop()
+  asyncio.set_event_loop(loop)
+  try:
+    loop.run_until_complete(main())
+  except KeyboardInterrupt:
+    print("Received keyboard interrupt. Shutting down...")
+  finally:
+    loop.run_until_complete(shutdown(signal.SIGTERM, loop))
+    loop.close()

Một số tệp đã không được hiển thị bởi vì quá nhiều tập tin thay đổi trong này khác