Browse Source

Merge branch 'main' into better_networking

Alex Cheema 11 months ago
parent
commit
2341aa1acf
72 changed files with 2941 additions and 2379 deletions
  1. .circleci/config.yml (+5 -0)
  2. .gitignore (+19 -19)
  3. .style.yapf (+19 -0)
  4. examples/chatgpt_api.sh (+39 -0)
  5. examples/llama3_distributed.py (+0 -88)
  6. exo/api/chatgpt_api.py (+84 -137)
  7. exo/download/download_progress.py (+47 -68)
  8. exo/download/hf/hf_helpers.py (+347 -292)
  9. exo/download/hf/hf_shard_download.py (+58 -60)
  10. exo/download/shard_download.py (+9 -8)
  11. exo/helpers.py (+82 -89)
  12. exo/inference/debug_inference_engine.py (+6 -8)
  13. exo/inference/inference_engine.py (+17 -0)
  14. exo/inference/mlx/models/deepseek_v2.py (+3 -6)
  15. exo/inference/mlx/models/llama.py (+4 -7)
  16. exo/inference/mlx/models/llava.py (+507 -555)
  17. exo/inference/mlx/sharded_model.py (+8 -5)
  18. exo/inference/mlx/sharded_utils.py (+31 -28)
  19. exo/inference/mlx/test_sharded_llava.py (+6 -6)
  20. exo/inference/mlx/test_sharded_model.py (+2 -2)
  21. exo/inference/shard.py (+2 -4)
  22. exo/inference/test_inference_engine.py (+13 -11)
  23. exo/inference/tinygrad/inference.py (+10 -13)
  24. exo/inference/tinygrad/models/llama.py (+76 -42)
  25. exo/inference/tinygrad/tinygrad_helpers.py (+7 -3)
  26. exo/inference/tokenizers.py (+41 -0)
  27. exo/models.py (+29 -0)
  28. exo/networking/grpc/grpc_peer_handle.py (+1 -1)
  29. exo/networking/grpc/grpc_server.py (+17 -14)
  30. exo/networking/grpc/node_service_pb2.py (+0 -3)
  31. exo/networking/grpc/node_service_pb2_grpc.py (+228 -273)
  32. exo/networking/udp_discovery.py (+9 -11)
  33. exo/orchestration/standard_node.py (+63 -70)
  34. exo/stats/metrics.py (+1 -1)
  35. exo/topology/device_capabilities.py (+73 -64)
  36. exo/topology/partitioning_strategy.py (+2 -2)
  37. exo/topology/ring_memory_weighted_partitioning_strategy.py (+1 -1)
  38. exo/topology/test_device_capabilities.py (+1 -1)
  39. exo/topology/test_map_partitions.py (+2 -2)
  40. exo/topology/test_ring_memory_weighted_partitioning_strategy.py (+3 -3)
  41. exo/viz/test_topology_viz.py (+52 -51)
  42. exo/viz/topology_viz.py (+122 -57)
  43. extra/download_hf.py (+34 -37)
  44. extra/start_openwebui.sh (+3 -0)
  45. format.py (+10 -87)
  46. main.py (+120 -67)
  47. pyproject.toml (+0 -10)
  48. setup.py (+41 -42)
  49. test/test_tokenizers.py (+34 -0)
  50. tinychat/examples/tinychat/index.html (+119 -131)
  51. tinychat/examples/tinychat/static/cdn.jsdelivr.net/npm/@alpine-collective/toolkit@1.0.2/dist/cdn.min.js (+0 -0)
  52. tinychat/examples/tinychat/static/cdn.jsdelivr.net/npm/@alpinejs/focus@3.x.x/dist/cdn.min.js (+0 -0)
  53. tinychat/examples/tinychat/static/cdn.jsdelivr.net/npm/@alpinejs/intersect@3.x.x/dist/cdn.min.js (+1 -0)
  54. tinychat/examples/tinychat/static/cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css (+11 -0)
  55. tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css (+5 -0)
  56. tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.ttf (BIN)
  57. tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.woff2 (BIN)
  58. tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.ttf (BIN)
  59. tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.woff2 (BIN)
  60. tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.ttf (BIN)
  61. tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.woff2 (BIN)
  62. tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.ttf (BIN)
  63. tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.woff2 (BIN)
  64. tinychat/examples/tinychat/static/fonts.googleapis.com/css2 (+7 -0)
  65. tinychat/examples/tinychat/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/highlight.min.js (+316 -0)
  66. tinychat/examples/tinychat/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css (+1 -0)
  67. tinychat/examples/tinychat/static/unpkg.com/@marcreichel/alpine-autosize@1.3.x/dist/alpine-autosize.min.js (+0 -0)
  68. tinychat/examples/tinychat/static/unpkg.com/alpinejs@3.x.x/dist/cdn.min.js (+0 -0)
  69. tinychat/examples/tinychat/static/unpkg.com/dompurify@3.1.5/dist/purify.min.js (+1 -0)
  70. tinychat/examples/tinychat/static/unpkg.com/marked-highlight@2.1.2/lib/index.umd.js (+97 -0)
  71. tinychat/examples/tinychat/static/unpkg.com/marked@13.0.0/marked.min.js (+5 -0)
  72. tinychat/examples/tinychat/update_deps.py (+90 -0)

+ 5 - 0
.circleci/config.yml

@@ -44,6 +44,7 @@ commands:
             # Check processes before proceeding
             check_processes
 
+            echo "Sending request to first instance..."
             response_1=$(curl -s http://localhost:8000/v1/chat/completions \
               -H "Content-Type: application/json" \
               -d '{
@@ -56,6 +57,7 @@ commands:
             # Check processes after first response
             check_processes
 
+            echo "Sending request to second instance..."
             response_2=$(curl -s http://localhost:8001/v1/chat/completions \
               -H "Content-Type: application/json" \
               -d '{
@@ -110,7 +112,10 @@ jobs:
           command: |
             source env/bin/activate
             # set TEMPERATURE to 0 for deterministic sampling
+            echo "Running inference engine tests..."
             METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
+            echo "Running tokenizer tests..."
+            python3 ./test/test_tokenizers.py
 
   discovery_integration_test:
     macos:

+ 19 - 19
.gitignore

@@ -14,24 +14,24 @@ __pycache__/
 *.so
 
 # Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
+/.Python
+/build/
+/develop-eggs/
+/dist/
+/downloads/
+/eggs/
+/.eggs/
+/lib/
+/lib64/
+/parts/
+/sdist/
+/var/
+/wheels/
+/share/python-wheels/
+/*.egg-info/
+/.installed.cfg
+/*.egg
+/MANIFEST
 
 # PyInstaller
 #  Usually these files are written by a python script from a template
@@ -169,4 +169,4 @@ cython_debug/
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 
-**/*.xcodeproj/*
+**/*.xcodeproj/*

+ 19 - 0
.style.yapf

@@ -0,0 +1,19 @@
+[style]
+based_on_style = pep8
+indent_width = 2
+column_limit = 200
+allow_split_before_dict_value = False
+dedent_closing_brackets = True
+split_before_first_argument = False
+split_complex_comprehension = False
+continuation_indent_width = 2
+indent_dictionary_value = True
+allow_multiline_dictionary_keys = True
+each_dict_entry_on_separate_line = False
+allow_multiline_lambdas = True
+blank_line_before_nested_class_or_def = False
+arithmetic_precedence_indication = True
+no_spaces_around_selected_binary_operators = "*,/"
+coalesce_brackets = True
+space_between_ending_comma_and_closing_bracket = False
+split_before_expression_after_opening_paren = False
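
The new .style.yapf defines the repository's formatting rules (two-space indents, a 200-character column limit, coalesced brackets), and the diffs below reflow the Python sources accordingly. A minimal sketch of applying the style programmatically, assuming yapf is installed and run from the repository root; the sample source string is illustrative only:

  # Reformat a snippet with the repo's yapf style (sketch; assumes `pip install yapf`).
  from yapf.yapflib.yapf_api import FormatCode

  source = "def f( a,b ):\n    return {'x':a, 'y':b}\n"
  formatted, changed = FormatCode(source, style_config=".style.yapf")
  print(formatted)  # reformatted with indent_width=2, column_limit=200, etc.
  print(changed)    # True if yapf altered the input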

+ 39 - 0
examples/chatgpt_api.sh

@@ -0,0 +1,39 @@
+# exo provides an API that aims to be a drop-in replacements for the ChatGPT-API.
+# This example shows how you can use the API first without streaming and second with streaming.
+# This works the same in a single-node set up and in a multi-node setup.
+# You need to start exo before running this by running `python3 main.py`.
+
+API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):8000}"
+MODEL="llama-3.1-8b"
+PROMPT="What is the meaning of exo?"
+TEMPERATURE=0.7
+
+echo ""
+echo ""
+echo "--- Output without streaming:"
+echo ""
+curl "${API_ENDPOINT}/v1/chat/completions" --silent \
+  -H "Content-Type: application/json" \
+  -d '{
+     "model": "'"${MODEL}"'",
+     "messages": [{"role": "user", "content": "'"${PROMPT}"'"}],
+     "temperature": '"${TEMPERATURE}"'
+   }'
+
+echo ""
+echo ""
+echo "--- Output with streaming:"
+echo ""
+curl "${API_ENDPOINT}/v1/chat/completions" --silent \
+  -H "Content-Type: application/json" \
+  -d '{
+     "model": "'"${MODEL}"'",
+     "messages": [{"role": "user", "content": "'"${PROMPT}"'"}],
+     "temperature": '"${TEMPERATURE}"',
+     "stream": true
+   }' | while read -r line; do
+       if [[ $line == data:* ]]; then
+           content=$(echo "$line" | sed 's/^data: //')
+           echo "$content" | jq -r '.choices[].delta.content' --unbuffered | tr -d '\n'
+       fi
+   done
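
The same non-streaming request can be issued from Python instead of curl. A minimal sketch, assuming exo is already running locally on port 8000 (via `python3 main.py`, as the script notes) and that the third-party requests package is available; the model, prompt, and temperature mirror the script above, and the response shape follows the /v1/chat/completions payload built in exo/api/chatgpt_api.py:

  # Python equivalent of the non-streaming curl call above (sketch).
  import requests

  response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
      "model": "llama-3.1-8b",
      "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
      "temperature": 0.7,
    },
  )
  print(response.json()["choices"][0]["message"]["content"])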

+ 0 - 88
examples/llama3_distributed.py

@@ -1,88 +0,0 @@
-# In this example, a user is running a home cluster with 3 shards.
-# They are prompting the cluster to generate a response to a question.
-# The cluster is given the question, and the user is given the response.
-
-from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
-from exo.inference.shard import Shard
-from exo.networking.peer_handle import PeerHandle
-from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
-from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
-from typing import List
-import asyncio
-import argparse
-import uuid
-
-models = {
-    "mlx-community/Meta-Llama-3-8B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "mlx-community/Meta-Llama-3-70B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80)
-}
-
-path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
-model_path = get_model_path(path_or_hf_repo)
-tokenizer_config = {}
-tokenizer = load_tokenizer(model_path, tokenizer_config)
-
-# we intentionally leave out peer1 to demonstrate equality of nodes in exo.
-# there is no "master" node in exo, all nodes are equal and can take on any role.
-# peer1 = GRPCPeerHandle(
-#     "node1",
-#     "localhost:8080",
-#     DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
-# )
-peer2 = GRPCPeerHandle(
-    "node2",
-    "localhost:8081",
-    DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
-)
-shard = models[path_or_hf_repo]
-request_id = str(uuid.uuid4())
-
-async def run_prompt(prompt: str):
-    if tokenizer.chat_template is None:
-        tokenizer.chat_template = tokenizer.default_chat_template
-    if (
-        hasattr(tokenizer, "apply_chat_template")
-        and tokenizer.chat_template is not None
-    ):
-        messages = [{"role": "user", "content": prompt}]
-        prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
-
-    await peer2.connect()
-
-    try:
-        await peer2.send_prompt(shard, prompt, request_id)
-    except Exception as e:
-        print(e)
-
-    import time
-    # poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user)
-    previous_length = 0
-    n_tokens = 0
-    start_time = time.perf_counter()
-    while True:
-        try:
-            result, is_finished = await peer2.get_inference_result(request_id)
-        except Exception as e:
-            continue
-        await asyncio.sleep(0.1)
-
-        # Print the updated string in place
-        updated_string = tokenizer.decode(result)
-        n_tokens = len(result)
-        print(updated_string[previous_length:], end='', flush=True)
-        previous_length = len(updated_string)
-
-        if is_finished:
-            print("\nDone")
-            break
-    end_time = time.perf_counter()
-    print(f"\nDone. Processed {n_tokens} tokens in {end_time - start_time:.2f} seconds ({n_tokens / (end_time - start_time):.2f} tokens/second)")
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run prompt")
-    parser.add_argument("--prompt", type=str, help="The prompt to run")
-    args = parser.parse_args()
-
-    asyncio.run(run_prompt(args.prompt))

+ 84 - 137
exo/api/chatgpt_api.py

@@ -3,106 +3,37 @@ import time
 import asyncio
 import json
 from pathlib import Path
-from transformers import AutoTokenizer, AutoProcessor
+from transformers import AutoTokenizer
 from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
 from exo import DEBUG, VERSION
-from exo.helpers import terminal_link, PrefixDict
+from exo.helpers import PrefixDict
 from exo.inference.shard import Shard
+from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-
-shard_mappings = {
-  ### llama
-  "llama-3.1-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
-  },
-  "llama-3.1-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
-  },
-  "llama-3.1-405b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),
-  },
-  "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
-  },
-  "llama-3-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
-  },
-  ### mistral
-  "mistral-nemo": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
-  },
-  "mistral-large": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
-  },
-  ### deepseek v2
-  "deepseek-coder-v2-lite": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
-  },
-  ### llava
-  "llava-1.5-7b-hf": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),
-  },
-}
-
+from exo.models import model_base_shards
+from typing import Callable
 
 
 
 
 class Message:
 class Message:
-    def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
-        self.role = role
-        self.content = content
+  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
+    self.role = role
+    self.content = content
 
 
-    def to_dict(self):
-        return {
-            "role": self.role,
-            "content": self.content
-        }
+  def to_dict(self):
+    return {"role": self.role, "content": self.content}
 
 
 
 
 class ChatCompletionRequest:
 class ChatCompletionRequest:
-    def __init__(self, model: str, messages: List[Message], temperature: float):
-        self.model = model
-        self.messages = messages
-        self.temperature = temperature
-
-    def to_dict(self):
-        return {
-            "model": self.model,
-            "messages": [message.to_dict() for message in self.messages],
-            "temperature": self.temperature
-        }
-
-
-
-async def resolve_tokenizer(model_id: str):
-  try:
-    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id}")
-    processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
-    if not hasattr(processor, 'eos_token_id'):
-      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
-    if not hasattr(processor, 'encode'):
-      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
-    if not hasattr(processor, 'decode'):
-      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
-    return processor
-  except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load processor for {model_id}. Error: {e}")
-    if DEBUG >= 4: print(traceback.format_exc())
-
-  try:
-    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id}")
-    return AutoTokenizer.from_pretrained(model_id)
-  except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
-    if DEBUG >= 4: print(traceback.format_exc())
-
-  raise ValueError(f"[TODO] Unsupported model: {model_id}")
+  def __init__(self, model: str, messages: List[Message], temperature: float):
+    self.model = model
+    self.messages = messages
+    self.temperature = temperature
+
+  def to_dict(self):
+    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
 
 
 
 
 def generate_completion(
 def generate_completion(
@@ -121,14 +52,12 @@ def generate_completion(
     "created": int(time.time()),
     "created": int(time.time()),
     "model": chat_request.model,
     "model": chat_request.model,
     "system_fingerprint": f"exo_{VERSION}",
     "system_fingerprint": f"exo_{VERSION}",
-    "choices": [
-      {
-        "index": 0,
-        "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
-        "logprobs": None,
-        "finish_reason": finish_reason,
-      }
-    ],
+    "choices": [{
+      "index": 0,
+      "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
+      "logprobs": None,
+      "finish_reason": finish_reason,
+    }],
   }
   }
 
 
   if not stream:
   if not stream:
@@ -151,37 +80,38 @@ def generate_completion(
 
 
 
 
 def remap_messages(messages: List[Message]) -> List[Message]:
 def remap_messages(messages: List[Message]) -> List[Message]:
-    remapped_messages = []
-    last_image = None
-    for message in messages:
-        if not isinstance(message.content, list):
-           remapped_messages.append(message)
-           continue
-
-        remapped_content = []
-        for content in message.content:
-            if isinstance(content, dict):
-                if content.get("type") in ["image_url", "image"]:
-                    image_url = content.get("image_url", {}).get("url") or content.get("image")
-                    if image_url:
-                        last_image = {"type": "image", "image": image_url}
-                        remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
-                else:
-                    remapped_content.append(content)
-            else:
-                remapped_content.append(content)
-        remapped_messages.append(Message(role=message.role, content=remapped_content))
-
-    if last_image:
-        # Replace the last image placeholder with the actual image content
-        for message in reversed(remapped_messages):
-            for i, content in enumerate(message.content):
-                if isinstance(content, dict):
-                  if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
-                      message.content[i] = last_image
-                      return remapped_messages
-
-    return remapped_messages
+  remapped_messages = []
+  last_image = None
+  for message in messages:
+    if not isinstance(message.content, list):
+      remapped_messages.append(message)
+      continue
+
+    remapped_content = []
+    for content in message.content:
+      if isinstance(content, dict):
+        if content.get("type") in ["image_url", "image"]:
+          image_url = content.get("image_url", {}).get("url") or content.get("image")
+          if image_url:
+            last_image = {"type": "image", "image": image_url}
+            remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
+        else:
+          remapped_content.append(content)
+      else:
+        remapped_content.append(content)
+    remapped_messages.append(Message(role=message.role, content=remapped_content))
+
+  if last_image:
+    # Replace the last image placeholder with the actual image content
+    for message in reversed(remapped_messages):
+      for i, content in enumerate(message.content):
+        if isinstance(content, dict):
+          if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
+            message.content[i] = last_image
+            return remapped_messages
+
+  return remapped_messages
+
 
 
 def build_prompt(tokenizer, _messages: List[Message]):
 def build_prompt(tokenizer, _messages: List[Message]):
   messages = remap_messages(_messages)
   messages = remap_messages(_messages)
@@ -214,18 +144,21 @@ def parse_chat_request(data: dict):
     data.get("temperature", 0.0),
     data.get("temperature", 0.0),
   )
   )
 
 
+
 class PromptSession:
 class PromptSession:
   def __init__(self, request_id: str, timestamp: int, prompt: str):
   def __init__(self, request_id: str, timestamp: int, prompt: str):
     self.request_id = request_id
     self.request_id = request_id
     self.timestamp = timestamp
     self.timestamp = timestamp
     self.prompt = prompt
     self.prompt = prompt
 
 
+
 class ChatGPTAPI:
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
     self.node = node
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout_secs = response_timeout_secs
     self.response_timeout_secs = response_timeout_secs
-    self.app = web.Application(client_max_size=100 * 1024 * 1024)  # 100MB to support image upload
+    self.on_chat_completion_request = on_chat_completion_request
+    self.app = web.Application(client_max_size=100*1024*1024)  # 100MB to support image upload
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
@@ -236,9 +169,14 @@ class ChatGPTAPI:
       allow_headers="*",
       allow_headers="*",
       allow_methods="*",
       allow_methods="*",
     )
     )
-    cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
+    cors.add(self.app.router.add_get("/models", self.handle_get_models), {"*": cors_options})
+    cors.add(self.app.router.add_get("/v1/models", self.handle_get_models), {"*": cors_options})
+    cors.add(self.app.router.add_post("/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
-    self.static_dir = Path(__file__).parent.parent.parent / "tinychat/examples/tinychat"
+    cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
+    cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
+
+    self.static_dir = Path(__file__).parent.parent.parent/"tinychat/examples/tinychat"
     self.app.router.add_get("/", self.handle_root)
     self.app.router.add_get("/", self.handle_root)
     self.app.router.add_static("/", self.static_dir, name="static")
     self.app.router.add_static("/", self.static_dir, name="static")
 
 
@@ -253,11 +191,14 @@ class ChatGPTAPI:
     return middleware
     return middleware
 
 
   async def handle_root(self, request):
   async def handle_root(self, request):
-    return web.FileResponse(self.static_dir / "index.html")
+    return web.FileResponse(self.static_dir/"index.html")
+
+  async def handle_get_models(self, request):
+    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True } for model_name, _ in model_base_shards.items()])
 
 
   async def handle_post_chat_token_encode(self, request):
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
     data = await request.json()
-    shard = shard_mappings.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
+    shard = model_base_shards.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(shard.model_id)
     tokenizer = await resolve_tokenizer(shard.model_id)
     return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
     return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
@@ -269,12 +210,12 @@ class ChatGPTAPI:
     chat_request = parse_chat_request(data)
     chat_request = parse_chat_request(data)
     if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
     if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
       chat_request.model = "llama-3.1-8b"
       chat_request.model = "llama-3.1-8b"
-    if not chat_request.model or chat_request.model not in shard_mappings:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
+    if not chat_request.model or chat_request.model not in model_base_shards:
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to llama-3.1-8b")
       chat_request.model = "llama-3.1-8b"
       chat_request.model = "llama-3.1-8b"
-    shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
+    shard = model_base_shards[chat_request.model].get(self.inference_engine_classname, None)
     if not shard:
     if not shard:
-      supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
+      supported_models = [model for model, engines in model_base_shards.items() if self.inference_engine_classname in engines]
       return web.json_response(
       return web.json_response(
         {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
         {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
         status=400,
         status=400,
@@ -285,6 +226,11 @@ class ChatGPTAPI:
 
 
     prompt, image_str = build_prompt(tokenizer, chat_request.messages)
     prompt, image_str = build_prompt(tokenizer, chat_request.messages)
     request_id = str(uuid.uuid4())
     request_id = str(uuid.uuid4())
+    if self.on_chat_completion_request:
+      try:
+        self.on_chat_completion_request(request_id, chat_request, prompt)
+      except Exception as e:
+        if DEBUG >= 2: traceback.print_exc()
     # request_id = None
     # request_id = None
     # match = self.prompts.find_longest_prefix(prompt)
     # match = self.prompts.find_longest_prefix(prompt)
     # if match and len(prompt) > len(match[1].prompt):
     # if match and len(prompt) > len(match[1].prompt):
@@ -316,7 +262,7 @@ class ChatGPTAPI:
           status=200,
           status=200,
           reason="OK",
           reason="OK",
           headers={
           headers={
-            "Content-Type": "application/json",
+            "Content-Type": "text/event-stream",
             "Cache-Control": "no-cache",
             "Cache-Control": "no-cache",
           },
           },
         )
         )
@@ -327,7 +273,8 @@ class ChatGPTAPI:
           self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
           self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
           new_tokens = tokens[prev_last_tokens_len:]
           new_tokens = tokens[prev_last_tokens_len:]
           finish_reason = None
           finish_reason = None
-          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
+          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
+                                                                                                                             AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
           if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
           if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
             new_tokens = new_tokens[:-1]
             new_tokens = new_tokens[:-1]
             if is_finished:
             if is_finished:

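The ChatGPTAPI constructor gains an optional on_chat_completion_request callback, invoked with (request_id, chat_request, prompt) once the prompt has been built; exceptions raised by the callback are only printed when DEBUG >= 2. A minimal sketch of wiring it up, assuming `node` is an exo orchestration Node constructed elsewhere and that the inference engine class name matches one of the shards in exo/models.py; the logging body is illustrative only:

  # Sketch: observe incoming chat completion requests via the new callback.
  from exo.api.chatgpt_api import ChatGPTAPI, ChatCompletionRequest

  def log_request(request_id: str, chat_request: ChatCompletionRequest, prompt: str) -> None:
    # Called before inference starts for each /v1/chat/completions request.
    print(f"[{request_id}] model={chat_request.model} prompt_len={len(prompt)}")

  api = ChatGPTAPI(node, "MLXDynamicShardInferenceEngine", response_timeout_secs=90, on_chat_completion_request=log_request)
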
+ 47 - 68
exo/download/download_progress.py

@@ -2,81 +2,60 @@ from typing import Dict, Callable, Coroutine, Any, Literal
 from dataclasses import dataclass
 from datetime import timedelta
 
+
 @dataclass
 class RepoFileProgressEvent:
-    repo_id: str
-    repo_revision: str
-    file_path: str
-    downloaded: int
-    downloaded_this_session: int
-    total: int
-    speed: int
-    eta: timedelta
-    status: Literal["not_started", "in_progress", "complete"]
-
-    def to_dict(self):
-        return {
-            "repo_id": self.repo_id,
-            "repo_revision": self.repo_revision,
-            "file_path": self.file_path,
-            "downloaded": self.downloaded,
-            "downloaded_this_session": self.downloaded_this_session,
-            "total": self.total,
-            "speed": self.speed,
-            "eta": self.eta.total_seconds(),
-            "status": self.status
-        }
+  repo_id: str
+  repo_revision: str
+  file_path: str
+  downloaded: int
+  downloaded_this_session: int
+  total: int
+  speed: int
+  eta: timedelta
+  status: Literal["not_started", "in_progress", "complete"]
+
+  def to_dict(self):
+    return {
+      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
+      "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
+    }
+
+  @classmethod
+  def from_dict(cls, data):
+    if 'eta' in data: data['eta'] = timedelta(seconds=data['eta'])
+    return cls(**data)
 
 
-    @classmethod
-    def from_dict(cls, data):
-        # Convert eta from seconds back to timedelta
-        if 'eta' in data:
-            data['eta'] = timedelta(seconds=data['eta'])
-        return cls(**data)
 
 
 @dataclass
 @dataclass
 class RepoProgressEvent:
 class RepoProgressEvent:
-    repo_id: str
-    repo_revision: str
-    completed_files: int
-    total_files: int
-    downloaded_bytes: int
-    downloaded_bytes_this_session: int
-    total_bytes: int
-    overall_speed: int
-    overall_eta: timedelta
-    file_progress: Dict[str, RepoFileProgressEvent]
-    status: Literal["not_started", "in_progress", "complete"]
-
-    def to_dict(self):
-        return {
-            "repo_id": self.repo_id,
-            "repo_revision": self.repo_revision,
-            "completed_files": self.completed_files,
-            "total_files": self.total_files,
-            "downloaded_bytes": self.downloaded_bytes,
-            "downloaded_bytes_this_session": self.downloaded_bytes_this_session,
-            "total_bytes": self.total_bytes,
-            "overall_speed": self.overall_speed,
-            "overall_eta": self.overall_eta.total_seconds(),
-            "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()},
-            "status": self.status
-        }
-
-    @classmethod
-    def from_dict(cls, data):
-        # Convert overall_eta from seconds back to timedelta
-        if 'overall_eta' in data:
-            data['overall_eta'] = timedelta(seconds=data['overall_eta'])
-
-        # Parse file_progress
-        if 'file_progress' in data:
-            data['file_progress'] = {
-                k: RepoFileProgressEvent.from_dict(v)
-                for k, v in data['file_progress'].items()
-            }
+  repo_id: str
+  repo_revision: str
+  completed_files: int
+  total_files: int
+  downloaded_bytes: int
+  downloaded_bytes_this_session: int
+  total_bytes: int
+  overall_speed: int
+  overall_eta: timedelta
+  file_progress: Dict[str, RepoFileProgressEvent]
+  status: Literal["not_started", "in_progress", "complete"]
+
+  def to_dict(self):
+    return {
+      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
+      "downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
+      "file_progress": {k: v.to_dict()
+                        for k, v in self.file_progress.items()}, "status": self.status
+    }
+
+  @classmethod
+  def from_dict(cls, data):
+    if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta'])
+    if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}
+
+    return cls(**data)
 
 
-        return cls(**data)
 
 
 RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]]
 RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]]
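
The reformatted progress dataclasses keep their dict round-trip behaviour: to_dict() serialises eta as seconds and from_dict() restores the timedelta. A minimal sketch with illustrative values (the file name and sizes are made up for the example):

  # Round-trip a RepoFileProgressEvent through its dict form (sketch).
  from datetime import timedelta
  from exo.download.download_progress import RepoFileProgressEvent

  event = RepoFileProgressEvent(
    repo_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", repo_revision="main", file_path="model.safetensors",
    downloaded=1024, downloaded_this_session=1024, total=4096, speed=512, eta=timedelta(seconds=6), status="in_progress"
  )
  restored = RepoFileProgressEvent.from_dict(event.to_dict())
  assert restored == event  # dataclass equality; eta is converted back from seconds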

+ 347 - 292
exo/download/hf/hf_helpers.py

@@ -16,283 +16,341 @@ import aiofiles
 from aiofiles import os as aios
 
 
 T = TypeVar("T")
 T = TypeVar("T")
+
+async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
+  refs_dir = get_repo_root(repo_id)/"refs"
+  refs_file = refs_dir/revision
+  if await aios.path.exists(refs_file):
+    async with aiofiles.open(refs_file, 'r') as f:
+      commit_hash = (await f.read()).strip()
+      snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash
+      return snapshot_dir
+  return None
+
+
 def filter_repo_objects(
 def filter_repo_objects(
-    items: Iterable[T],
-    *,
-    allow_patterns: Optional[Union[List[str], str]] = None,
-    ignore_patterns: Optional[Union[List[str], str]] = None,
-    key: Optional[Callable[[T], str]] = None,
+  items: Iterable[T],
+  *,
+  allow_patterns: Optional[Union[List[str], str]] = None,
+  ignore_patterns: Optional[Union[List[str], str]] = None,
+  key: Optional[Callable[[T], str]] = None,
 ) -> Generator[T, None, None]:
 ) -> Generator[T, None, None]:
-    if isinstance(allow_patterns, str):
-        allow_patterns = [allow_patterns]
-    if isinstance(ignore_patterns, str):
-        ignore_patterns = [ignore_patterns]
-    if allow_patterns is not None:
-        allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
-    if ignore_patterns is not None:
-        ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
-
-    if key is None:
-        def _identity(item: T) -> str:
-            if isinstance(item, str):
-                return item
-            if isinstance(item, Path):
-                return str(item)
-            raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
-        key = _identity
-
-    for item in items:
-        path = key(item)
-        if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
-            continue
-        if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
-            continue
-        yield item
+  if isinstance(allow_patterns, str):
+    allow_patterns = [allow_patterns]
+  if isinstance(ignore_patterns, str):
+    ignore_patterns = [ignore_patterns]
+  if allow_patterns is not None:
+    allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
+  if ignore_patterns is not None:
+    ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
+
+  if key is None:
+
+    def _identity(item: T) -> str:
+      if isinstance(item, str):
+        return item
+      if isinstance(item, Path):
+        return str(item)
+      raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
+
+    key = _identity
+
+  for item in items:
+    path = key(item)
+    if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
+      continue
+    if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
+      continue
+    yield item
+
 
 
 def _add_wildcard_to_directories(pattern: str) -> str:
 def _add_wildcard_to_directories(pattern: str) -> str:
-    if pattern[-1] == "/":
-        return pattern + "*"
-    return pattern
+  if pattern[-1] == "/":
+    return pattern + "*"
+  return pattern
+
 
 
 def get_hf_home() -> Path:
 def get_hf_home() -> Path:
-    """Get the Hugging Face home directory."""
-    return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
+  """Get the Hugging Face home directory."""
+  return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))
+
 
 
 async def get_hf_token():
 async def get_hf_token():
-    """Retrieve the Hugging Face token from the user's HF_HOME directory."""
-    token_path = get_hf_home() / "token"
-    if await aios.path.exists(token_path):
-        async with aiofiles.open(token_path, 'r') as f:
-            return (await f.read()).strip()
-    return None
+  """Retrieve the Hugging Face token from the user's HF_HOME directory."""
+  token_path = get_hf_home()/"token"
+  if await aios.path.exists(token_path):
+    async with aiofiles.open(token_path, 'r') as f:
+      return (await f.read()).strip()
+  return None
+
 
 
 async def get_auth_headers():
 async def get_auth_headers():
-    """Get authentication headers if a token is available."""
-    token = await get_hf_token()
-    if token:
-        return {"Authorization": f"Bearer {token}"}
-    return {}
+  """Get authentication headers if a token is available."""
+  token = await get_hf_token()
+  if token:
+    return {"Authorization": f"Bearer {token}"}
+  return {}
+
 
 
 def get_repo_root(repo_id: str) -> Path:
 def get_repo_root(repo_id: str) -> Path:
-    """Get the root directory for a given repo ID in the Hugging Face cache."""
-    sanitized_repo_id = repo_id.replace("/", "--")
-    return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
+  """Get the root directory for a given repo ID in the Hugging Face cache."""
+  sanitized_repo_id = repo_id.replace("/", "--")
+  return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
+
 
 
 async def fetch_file_list(session, repo_id, revision, path=""):
 async def fetch_file_list(session, repo_id, revision, path=""):
-    api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
-    url = f"{api_url}/{path}" if path else api_url
-
-    headers = await get_auth_headers()
-    async with session.get(url, headers=headers) as response:
-        if response.status == 200:
-            data = await response.json()
-            files = []
-            for item in data:
-                if item["type"] == "file":
-                    files.append({"path": item["path"], "size": item["size"]})
-                elif item["type"] == "directory":
-                    subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
-                    files.extend(subfiles)
-            return files
-        else:
-            raise Exception(f"Failed to fetch file list: {response.status}")
+  api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}"
+  url = f"{api_url}/{path}" if path else api_url
+
+  headers = await get_auth_headers()
+  async with session.get(url, headers=headers) as response:
+    if response.status == 200:
+      data = await response.json()
+      files = []
+      for item in data:
+        if item["type"] == "file":
+          files.append({"path": item["path"], "size": item["size"]})
+        elif item["type"] == "directory":
+          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
+          files.extend(subfiles)
+      return files
+    else:
+      raise Exception(f"Failed to fetch file list: {response.status}")
 
 
 
 
 @retry(
 @retry(
-    stop=stop_after_attempt(5),
-    wait=wait_exponential(multiplier=1, min=4, max=60),
-    retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)),
-    reraise=True
+  stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
 )
 )
-async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True):
-    base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
-    url = urljoin(base_url, file_path)
-    local_path = os.path.join(save_directory, file_path)
-
-    await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
-
-    # Check if file already exists and get its size
-    local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
-
-    headers = await get_auth_headers()
-    if use_range_request:
-        headers["Range"] = f"bytes={local_file_size}-"
-
-    async with session.get(url, headers=headers) as response:
-        total_size = int(response.headers.get('Content-Length', 0))
-        downloaded_size = local_file_size
-        downloaded_this_session = 0
-        mode = 'ab' if use_range_request else 'wb'
-        if downloaded_size == total_size:
-            if DEBUG >= 2: print(f"File already downloaded: {file_path}")
-            if progress_callback:
-                await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-            return
-
-        if response.status == 200:
-            # File doesn't support range requests or we're not using them, start from beginning
-            mode = 'wb'
-            downloaded_size = 0
-        elif response.status == 206:
-            # Partial content, resume download
-            content_range = response.headers.get('Content-Range', '')
-            try:
-                total_size = int(content_range.split('/')[-1])
-            except ValueError:
-                if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-                return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-        elif response.status == 416:
-            # Range not satisfiable, get the actual file size
-            content_range = response.headers.get('Content-Range', '')
-            try:
-                total_size = int(content_range.split('/')[-1])
-                if downloaded_size == total_size:
-                    if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
-                    if progress_callback:
-                        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-                    return
-            except ValueError:
-                if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-                return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-        else:
-            raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
-
+async def download_file(
+  session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
+):
+  base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/"
+  url = urljoin(base_url, file_path)
+  local_path = os.path.join(save_directory, file_path)
+
+  await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
+
+  # Check if file already exists and get its size
+  local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
+
+  headers = await get_auth_headers()
+  if use_range_request:
+    headers["Range"] = f"bytes={local_file_size}-"
+
+  async with session.get(url, headers=headers) as response:
+    total_size = int(response.headers.get('Content-Length', 0))
+    downloaded_size = local_file_size
+    downloaded_this_session = 0
+    mode = 'ab' if use_range_request else 'wb'
+    if downloaded_size == total_size:
+      if DEBUG >= 2: print(f"File already downloaded: {file_path}")
+      if progress_callback:
+        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+      return
+
+    if response.status == 200:
+      # File doesn't support range requests or we're not using them, start from beginning
+      mode = 'wb'
+      downloaded_size = 0
+    elif response.status == 206:
+      # Partial content, resume download
+      content_range = response.headers.get('Content-Range', '')
+      try:
+        total_size = int(content_range.split('/')[-1])
+      except ValueError:
+        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
+    elif response.status == 416:
+      # Range not satisfiable, get the actual file size
+      content_range = response.headers.get('Content-Range', '')
+      try:
+        total_size = int(content_range.split('/')[-1])
         if downloaded_size == total_size:
         if downloaded_size == total_size:
-            print(f"File already downloaded: {file_path}")
-            if progress_callback:
-                await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-            return
-
-        DOWNLOAD_CHUNK_SIZE = 32768
-        start_time = datetime.now()
-        async with aiofiles.open(local_path, mode) as f:
-            async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
-                await f.write(chunk)
-                downloaded_size += len(chunk)
-                downloaded_this_session += len(chunk)
-                if progress_callback and total_size:
-                    elapsed_time = (datetime.now() - start_time).total_seconds()
-                    speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
-                    remaining_size = total_size - downloaded_size
-                    eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
-                    status = "in_progress" if downloaded_size < total_size else "complete"
-                    if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
-                    await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
-        if DEBUG >= 2: print(f"Downloaded: {file_path}")
-
-async def download_repo_files(repo_id: str, revision: str = "main", progress_callback: Optional[RepoProgressCallback] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, max_parallel_downloads: int = 4) -> Path:
-    repo_root = get_repo_root(repo_id)
-    refs_dir = repo_root / "refs"
-    snapshots_dir = repo_root / "snapshots"
-    cachedreqs_dir = repo_root / "cachedreqs"
-
-    # Ensure directories exist
-    await aios.makedirs(refs_dir, exist_ok=True)
-    await aios.makedirs(snapshots_dir, exist_ok=True)
-    await aios.makedirs(cachedreqs_dir, exist_ok=True)
-
-    # Check if we have a cached commit hash
-    refs_file = refs_dir / revision
-    if await aios.path.exists(refs_file):
-        async with aiofiles.open(refs_file, 'r') as f:
-            commit_hash = (await f.read()).strip()
-            if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}")
+          if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
+          if progress_callback:
+            await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+          return
+      except ValueError:
+        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
     else:
     else:
-        async with aiohttp.ClientSession() as session:
-            # Fetch the commit hash for the given revision
-            api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
-            headers = await get_auth_headers()
-            async with session.get(api_url, headers=headers) as response:
-                if response.status != 200:
-                    raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
-                revision_info = await response.json()
-                commit_hash = revision_info['sha']
-
-            # Cache the commit hash
-            async with aiofiles.open(refs_file, 'w') as f:
-                await f.write(commit_hash)
-
-    # Set up the snapshot directory
-    snapshot_dir = snapshots_dir / commit_hash
-    await aios.makedirs(snapshot_dir, exist_ok=True)
-
-    # Set up the cached file list directory
-    cached_file_list_dir = cachedreqs_dir / commit_hash
-    await aios.makedirs(cached_file_list_dir, exist_ok=True)
-    cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
-
+      raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
+
+    if downloaded_size == total_size:
+      print(f"File already downloaded: {file_path}")
+      if progress_callback:
+        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+      return
+
+    DOWNLOAD_CHUNK_SIZE = 32768
+    start_time = datetime.now()
+    async with aiofiles.open(local_path, mode) as f:
+      async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
+        await f.write(chunk)
+        downloaded_size += len(chunk)
+        downloaded_this_session += len(chunk)
+        if progress_callback and total_size:
+          elapsed_time = (datetime.now() - start_time).total_seconds()
+          speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
+          remaining_size = total_size - downloaded_size
+          eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
+          status = "in_progress" if downloaded_size < total_size else "complete"
+          if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
+          await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
+    if DEBUG >= 2: print(f"Downloaded: {file_path}")
+
+
+async def download_repo_files(
+  repo_id: str,
+  revision: str = "main",
+  progress_callback: Optional[RepoProgressCallback] = None,
+  allow_patterns: Optional[Union[List[str], str]] = None,
+  ignore_patterns: Optional[Union[List[str], str]] = None,
+  max_parallel_downloads: int = 4
+) -> Path:
+  repo_root = get_repo_root(repo_id)
+  refs_dir = repo_root/"refs"
+  snapshots_dir = repo_root/"snapshots"
+  cachedreqs_dir = repo_root/"cachedreqs"
+
+  # Ensure directories exist
+  await aios.makedirs(refs_dir, exist_ok=True)
+  await aios.makedirs(snapshots_dir, exist_ok=True)
+  await aios.makedirs(cachedreqs_dir, exist_ok=True)
+
+  # Check if we have a cached commit hash
+  refs_file = refs_dir/revision
+  if await aios.path.exists(refs_file):
+    async with aiofiles.open(refs_file, 'r') as f:
+      commit_hash = (await f.read()).strip()
+      if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}")
+  else:
     async with aiohttp.ClientSession() as session:
     async with aiohttp.ClientSession() as session:
-        # Check if we have a cached file list
-        if await aios.path.exists(cached_file_list_path):
-            async with aiofiles.open(cached_file_list_path, 'r') as f:
-                file_list = json.loads(await f.read())
-            if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
-        else:
-            file_list = await fetch_file_list(session, repo_id, revision)
-            # Cache the file list
-            async with aiofiles.open(cached_file_list_path, 'w') as f:
-                await f.write(json.dumps(file_list))
-            if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
-
+      # Fetch the commit hash for the given revision
+      api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
+      headers = await get_auth_headers()
+      async with session.get(api_url, headers=headers) as response:
+        if response.status != 200:
+          raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
+        revision_info = await response.json()
+        commit_hash = revision_info['sha']
+
+      # Cache the commit hash
+      async with aiofiles.open(refs_file, 'w') as f:
+        await f.write(commit_hash)
+
+  # Set up the snapshot directory
+  snapshot_dir = snapshots_dir/commit_hash
+  await aios.makedirs(snapshot_dir, exist_ok=True)
+
+  # Set up the cached file list directory
+  cached_file_list_dir = cachedreqs_dir/commit_hash
+  await aios.makedirs(cached_file_list_dir, exist_ok=True)
+  cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"
+
+  async with aiohttp.ClientSession() as session:
+    # Check if we have a cached file list
+    if await aios.path.exists(cached_file_list_path):
+      async with aiofiles.open(cached_file_list_path, 'r') as f:
+        file_list = json.loads(await f.read())
+      if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
+    else:
+      file_list = await fetch_file_list(session, repo_id, revision)
+      # Cache the file list
+      async with aiofiles.open(cached_file_list_path, 'w') as f:
+        await f.write(json.dumps(file_list))
+      if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
+
+    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
+    total_files = len(filtered_file_list)
+    total_bytes = sum(file["size"] for file in filtered_file_list)
+    file_progress: Dict[str, RepoFileProgressEvent] = {
+      file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
+      for file in filtered_file_list
+    }
+    start_time = datetime.now()
+
+    async def download_with_progress(file_info, progress_state):
+      local_path = snapshot_dir/file_info["path"]
+      if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
+        if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
+        progress_state['completed_files'] += 1
+        progress_state['downloaded_bytes'] += file_info["size"]
+        file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
+        if progress_callback:
+          elapsed_time = (datetime.now() - start_time).total_seconds()
+          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
+          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+          status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+          await progress_callback(
+            RepoProgressEvent(
+              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
+              overall_eta, file_progress, status
+            )
+          )
+        return
+
+      async def file_progress_callback(event: RepoFileProgressEvent):
+        progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
+        progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
+        file_progress[event.file_path] = event
+        if progress_callback:
+          elapsed_time = (datetime.now() - start_time).total_seconds()
+          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
+          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+          status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
+          await progress_callback(
+            RepoProgressEvent(
+              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
+              overall_eta, file_progress, status
+            )
+          )
+
+      await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
+      progress_state['completed_files'] += 1
+      file_progress[
+        file_info["path"]
+      ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
+      if progress_callback:
+        elapsed_time = (datetime.now() - start_time).total_seconds()
+        overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
+        remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+        overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+        status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+        await progress_callback(
+          RepoProgressEvent(
+            repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
+            overall_eta, file_progress, status
+          )
+        )
+
+    progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
+
+    semaphore = asyncio.Semaphore(max_parallel_downloads)
+
+    async def download_with_semaphore(file_info):
+      async with semaphore:
+        await download_with_progress(file_info, progress_state)
+
+    tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
+    await asyncio.gather(*tasks)
+
+  return snapshot_dir
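A usage sketch for download_repo_files as defined above (the repo id and patterns are illustrative):

import asyncio

async def main():
  snapshot_dir = await download_repo_files(
    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",  # illustrative repo id
    revision="main",
    allow_patterns=["*.json", "tokenizer.model"],
    max_parallel_downloads=2,
  )
  print(f"Snapshot ready at {snapshot_dir}")

asyncio.run(main())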
 
 
-        async def download_with_progress(file_info, progress_state):
-            local_path = snapshot_dir / file_info["path"]
-            if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
-                if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
-                progress_state['completed_files'] += 1
-                progress_state['downloaded_bytes'] += file_info["size"]
-                file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
-                if progress_callback:
-                    elapsed_time = (datetime.now() - start_time).total_seconds()
-                    overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
-                    remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-                    overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-                    status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-                    await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
-                return
-
-            async def file_progress_callback(event: RepoFileProgressEvent):
-                progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
-                progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
-                file_progress[event.file_path] = event
-                if progress_callback:
-                    elapsed_time = (datetime.now() - start_time).total_seconds()
-                    overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
-                    remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-                    overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-                    status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
-                    await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
-
-            await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
-            progress_state['completed_files'] += 1
-            file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
-            if progress_callback:
-                elapsed_time = (datetime.now() - start_time).total_seconds()
-                overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
-                remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-                overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-                status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-                await progress_callback(RepoProgressEvent(repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, overall_eta, file_progress, status))
-
-        progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
-
-        semaphore = asyncio.Semaphore(max_parallel_downloads)
-        async def download_with_semaphore(file_info):
-            async with semaphore:
-                await download_with_progress(file_info, progress_state)
-        tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
-        await asyncio.gather(*tasks)
-
-    return snapshot_dir
 
 
 async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
-    """
+  """
     Retrieve the weight map from the model.safetensors.index.json file.

     Args:
@@ -303,56 +361,53 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[
         Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
     """
 
 
-    # Download the index file
-    await download_repo_files(
-        repo_id=repo_id,
-        revision=revision,
-        allow_patterns="model.safetensors.index.json"
-    )
+  # Download the index file
+  await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
+
+  # Check if the file exists
+  repo_root = get_repo_root(repo_id)
+  snapshot_dir = repo_root/"snapshots"
+  index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
 
 
-    # Check if the file exists
-    repo_root = get_repo_root(repo_id)
-    snapshot_dir = repo_root / "snapshots"
-    index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
+  if index_file:
+    index_file_path = snapshot_dir/index_file
+    if await aios.path.exists(index_file_path):
+      async with aiofiles.open(index_file_path, 'r') as f:
+        index_data = json.loads(await f.read())
+      return index_data.get("weight_map")
 
 
-    if index_file:
-        index_file_path = snapshot_dir / index_file
-        if await aios.path.exists(index_file_path):
-            async with aiofiles.open(index_file_path, 'r') as f:
-                index_data = json.loads(await f.read())
-            return index_data.get("weight_map")
+  return None
 
 
-    return None
 
 
 def extract_layer_num(tensor_name: str) -> Optional[int]:
-    # This is a simple example and might need to be adjusted based on the actual naming convention
-    parts = tensor_name.split('.')
-    for part in parts:
-        if part.isdigit():
-            return int(part)
-    return None
+  # This is a simple example and might need to be adjusted based on the actual naming convention
+  parts = tensor_name.split('.')
+  for part in parts:
+    if part.isdigit():
+      return int(part)
+  return None
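For example, with the usual safetensors tensor names:

extract_layer_num("model.layers.12.self_attn.q_proj.weight")  # -> 12
extract_layer_num("model.embed_tokens.weight")                # -> None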
 
 
 
 
 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
-    default_patterns = [
-        "*.json",
-        "*.py",
-        "tokenizer.model",
-        "*.tiktoken",
-        "*.txt",
-    ]
-    shard_specific_patterns = []
-    if weight_map:
-        for tensor_name, filename in weight_map.items():
-            layer_num = extract_layer_num(tensor_name)
-            if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
-                shard_specific_patterns.append(filename)
-        sorted_file_names = sorted(weight_map.values())
-        if shard.is_first_layer():
-            shard_specific_patterns.append(sorted_file_names[0])
-        elif shard.is_last_layer():
-            shard_specific_patterns.append(sorted_file_names[-1])
-    else:
-        shard_specific_patterns = ["*.safetensors"]
-    if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
-    return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates
+  default_patterns = [
+    "*.json",
+    "*.py",
+    "tokenizer.model",
+    "*.tiktoken",
+    "*.txt",
+  ]
+  shard_specific_patterns = []
+  if weight_map:
+    for tensor_name, filename in weight_map.items():
+      layer_num = extract_layer_num(tensor_name)
+      if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
+        shard_specific_patterns.append(filename)
+    sorted_file_names = sorted(weight_map.values())
+    if shard.is_first_layer():
+      shard_specific_patterns.append(sorted_file_names[0])
+    elif shard.is_last_layer():
+      shard_specific_patterns.append(sorted_file_names[-1])
+  else:
+    shard_specific_patterns = ["*.safetensors"]
+  if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
+  return list(set(default_patterns + shard_specific_patterns))  # Remove duplicates
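A toy example of the shard-specific selection (the weight map and shard below are made up):

weight_map = {
  "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
  "model.layers.31.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
}
shard = Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=15, n_layers=32)
get_allow_patterns(weight_map, shard)
# -> the default patterns plus "model-00001-of-00002.safetensors"
#    (layer 0 falls inside the shard, and that file is also first in sorted order for the first layer)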

+ 58 - 60
exo/download/hf/hf_shard_download.py

@@ -8,72 +8,70 @@ from exo.download.download_progress import RepoProgressEvent
 from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
 from exo.helpers import AsyncCallbackSystem, DEBUG
 
+
 class HFShardDownloader(ShardDownloader):
-    def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
-        self.quick_check = quick_check
-        self.max_parallel_downloads = max_parallel_downloads
-        self.active_downloads: Dict[Shard, asyncio.Task] = {}
-        self.completed_downloads: Dict[Shard, Path] = {}
-        self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+  def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
+    self.quick_check = quick_check
+    self.max_parallel_downloads = max_parallel_downloads
+    self.active_downloads: Dict[Shard, asyncio.Task] = {}
+    self.completed_downloads: Dict[Shard, Path] = {}
+    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
 
 
-    async def ensure_shard(self, shard: Shard) -> Path:
-        if shard in self.completed_downloads:
-            return self.completed_downloads[shard]
-        if self.quick_check:
-            repo_root = get_repo_root(shard.model_id)
-            snapshots_dir = repo_root / "snapshots"
-            if snapshots_dir.exists():
-                most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
-                return most_recent_dir
+  async def ensure_shard(self, shard: Shard) -> Path:
+    if shard in self.completed_downloads:
+      return self.completed_downloads[shard]
+    if self.quick_check:
+      repo_root = get_repo_root(shard.model_id)
+      snapshots_dir = repo_root/"snapshots"
+      if snapshots_dir.exists():
+        visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
+        if visible_dirs:
+          most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime)
+          return most_recent_dir
 
 
-        # If a download on this shard is already in progress, keep that one
-        for active_shard in self.active_downloads:
-            if active_shard == shard:
-                if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
-                return await self.active_downloads[shard]
+    # If a download on this shard is already in progress, keep that one
+    for active_shard in self.active_downloads:
+      if active_shard == shard:
+        if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
+        return await self.active_downloads[shard]
 
 
-        # Cancel any downloads for this model_id on a different shard
-        existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
-        for active_shard in existing_active_shards:
-            if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
-            task = self.active_downloads[active_shard]
-            task.cancel()
-            try:
-                await task
-            except asyncio.CancelledError:
-                pass
-            except Exception as e:
-                if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
-                traceback.print_exc()
-        self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
+    # Cancel any downloads for this model_id on a different shard
+    existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
+    for active_shard in existing_active_shards:
+      if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
+      task = self.active_downloads[active_shard]
+      task.cancel()
+      try:
+        await task
+      except asyncio.CancelledError:
+        pass  # This is expected when cancelling a task
+      except Exception as e:
+        if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
+        traceback.print_exc()
+    self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
 
 
-        # Start new download
-        download_task = asyncio.create_task(self._download_shard(shard))
-        self.active_downloads[shard] = download_task
-        try:
-            path = await download_task
-            self.completed_downloads[shard] = path
-            return path
-        finally:
-            # Ensure the task is removed even if an exception occurs
-            print(f"Removing download task for {shard}: {shard in self.active_downloads}")
-            if shard in self.active_downloads:
-                self.active_downloads.pop(shard)
+    # Start new download
+    download_task = asyncio.create_task(self._download_shard(shard))
+    self.active_downloads[shard] = download_task
+    try:
+      path = await download_task
+      self.completed_downloads[shard] = path
+      return path
+    finally:
+      # Ensure the task is removed even if an exception occurs
+      print(f"Removing download task for {shard}: {shard in self.active_downloads}")
+      if shard in self.active_downloads:
+        self.active_downloads.pop(shard)
 
 
-    async def _download_shard(self, shard: Shard) -> Path:
-        async def wrapped_progress_callback(event: RepoProgressEvent):
-            self._on_progress.trigger_all(shard, event)
+  async def _download_shard(self, shard: Shard) -> Path:
+    async def wrapped_progress_callback(event: RepoProgressEvent):
+      self._on_progress.trigger_all(shard, event)
 
 
-        weight_map = await get_weight_map(shard.model_id)
-        allow_patterns = get_allow_patterns(weight_map, shard)
+    weight_map = await get_weight_map(shard.model_id)
+    allow_patterns = get_allow_patterns(weight_map, shard)
 
 
-        return await download_repo_files(
-            repo_id=shard.model_id,
-            progress_callback=wrapped_progress_callback,
-            allow_patterns=allow_patterns,
-            max_parallel_downloads=self.max_parallel_downloads
-        )
+    return await download_repo_files(repo_id=shard.model_id, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
 
 
-    @property
-    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
-        return self._on_progress
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self._on_progress
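A minimal sketch of how this downloader is driven (the shard values are illustrative; progress events for the shard are published through on_progress):

import asyncio

async def fetch(shard: Shard) -> Path:
  downloader = HFShardDownloader(quick_check=False, max_parallel_downloads=4)
  return await downloader.ensure_shard(shard)

asyncio.run(fetch(Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=15, n_layers=32)))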

+ 9 - 8
exo/download/shard_download.py

@@ -5,10 +5,11 @@ from exo.inference.shard import Shard
 from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import AsyncCallbackSystem
 
+
 class ShardDownloader(ABC):
-    @abstractmethod
-    async def ensure_shard(self, shard: Shard) -> Path:
-        """
+  @abstractmethod
+  async def ensure_shard(self, shard: Shard) -> Path:
+    """
         Ensures that the shard is downloaded.
         Does not allow multiple overlapping downloads at once.
         If you try to download a Shard which overlaps a Shard that is already being downloaded,
@@ -17,9 +18,9 @@ class ShardDownloader(ABC):
         Args:
             shard (Shard): The shard to download.
         """
-        pass
+    pass
 
 
-    @property
-    @abstractmethod
-    def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
-        pass
+  @property
+  @abstractmethod
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    pass
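A concrete subclass only needs these two members; a test-oriented sketch (the fixed local directory is an assumption, not part of exo):

class LocalShardDownloader(ShardDownloader):
  """Pretends every shard already exists under a fixed local directory (handy for tests)."""

  def __init__(self, root: Path):
    self.root = root
    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

  async def ensure_shard(self, shard: Shard) -> Path:
    return self.root/shard.model_id

  @property
  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
    return self._on_progress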

+ 82 - 89
exo/helpers.py

@@ -20,6 +20,7 @@ exo_text = r"""
  \___/_/\_\___/ 
     """
 
+
 def get_system_info():
   if psutil.MACOS:
     if platform.machine() == "arm64":
@@ -32,21 +33,6 @@ def get_system_info():
   return "Non-Mac, non-Linux system"
 
 
 
 
-def get_inference_engine(inference_engine_name, shard_downloader: 'ShardDownloader'):
-  if inference_engine_name == "mlx":
-    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-
-    return MLXDynamicShardInferenceEngine(shard_downloader)
-  elif inference_engine_name == "tinygrad":
-    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-    import tinygrad.helpers
-    tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-
-    return TinygradDynamicShardInferenceEngine(shard_downloader)
-  else:
-    raise ValueError(f"Inference engine {inference_engine_name} not supported")
-
-
 def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
 def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
   used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".exo_used_ports")
   used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".exo_used_ports")
 
 
@@ -102,6 +88,8 @@ def terminal_link(uri, label=None):
 
 
 T = TypeVar("T")
 T = TypeVar("T")
 K = TypeVar("K")
 K = TypeVar("K")
+
+
 class AsyncCallback(Generic[T]):
 class AsyncCallback(Generic[T]):
   def __init__(self) -> None:
   def __init__(self) -> None:
     self.condition: asyncio.Condition = asyncio.Condition()
     self.condition: asyncio.Condition = asyncio.Condition()
@@ -110,9 +98,7 @@ class AsyncCallback(Generic[T]):
 
 
   async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]:
   async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]:
     async with self.condition:
     async with self.condition:
-      await asyncio.wait_for(
-        self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout
-      )
+      await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout)
       assert self.result is not None  # for type checking
       return self.result
 
 
@@ -154,89 +140,96 @@ class AsyncCallbackSystem(Generic[K, T]):
 
 
 K = TypeVar('K', bound=str)
 V = TypeVar('V')
+
+
 class PrefixDict(Generic[K, V]):
-    def __init__(self):
-        self.items: Dict[K, V] = {}
+  def __init__(self):
+    self.items: Dict[K, V] = {}
 
 
-    def add(self, key: K, value: V) -> None:
-        self.items[key] = value
+  def add(self, key: K, value: V) -> None:
+    self.items[key] = value
 
 
-    def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
-        return [(key, value) for key, value in self.items.items() if argument.startswith(key)]
+  def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
+    return [(key, value) for key, value in self.items.items() if argument.startswith(key)]
 
 
-    def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
-        matches = self.find_prefix(argument)
-        if len(matches) == 0:
-            return None
+  def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
+    matches = self.find_prefix(argument)
+    if len(matches) == 0:
+      return None
+
+    return max(matches, key=lambda x: len(x[0]))
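For instance (the handler names are illustrative):

d = PrefixDict()
d.add("/v1/chat", "chat_handler")
d.add("/v1/chat/completions", "completions_handler")
d.find_longest_prefix("/v1/chat/completions?stream=true")
# -> ("/v1/chat/completions", "completions_handler")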
 
 
-        return max(matches, key=lambda x: len(x[0]))
 
 
 def is_valid_uuid(val):
-    try:
-        uuid.UUID(str(val))
-        return True
-    except ValueError:
-        return False
+  try:
+    uuid.UUID(str(val))
+    return True
+  except ValueError:
+    return False
+
 
 
 def get_or_create_node_id():
-    NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__))) / ".exo_node_id"
-    try:
-        if NODE_ID_FILE.is_file():
-            with open(NODE_ID_FILE, "r") as f:
-                stored_id = f.read().strip()
-            if is_valid_uuid(stored_id):
-                if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
-                return stored_id
-            else:
-                if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")
-
-        new_id = str(uuid.uuid4())
-        with open(NODE_ID_FILE, "w") as f:
-            f.write(new_id)
-
-        if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
-        return new_id
-    except IOError as e:
-        if DEBUG >= 2: print(f"IO error creating node_id: {e}")
-        return str(uuid.uuid4())
-    except Exception as e:
-        if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
-        return str(uuid.uuid4())
+  NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__)))/".exo_node_id"
+  try:
+    if NODE_ID_FILE.is_file():
+      with open(NODE_ID_FILE, "r") as f:
+        stored_id = f.read().strip()
+      if is_valid_uuid(stored_id):
+        if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
+        return stored_id
+      else:
+        if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")
+
+    new_id = str(uuid.uuid4())
+    with open(NODE_ID_FILE, "w") as f:
+      f.write(new_id)
+
+    if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
+    return new_id
+  except IOError as e:
+    if DEBUG >= 2: print(f"IO error creating node_id: {e}")
+    return str(uuid.uuid4())
+  except Exception as e:
+    if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
+    return str(uuid.uuid4())
+
 
 
 def pretty_print_bytes(size_in_bytes: int) -> str:
-    if size_in_bytes < 1024:
-        return f"{size_in_bytes} B"
-    elif size_in_bytes < 1024 ** 2:
-        return f"{size_in_bytes / 1024:.2f} KB"
-    elif size_in_bytes < 1024 ** 3:
-        return f"{size_in_bytes / (1024 ** 2):.2f} MB"
-    elif size_in_bytes < 1024 ** 4:
-        return f"{size_in_bytes / (1024 ** 3):.2f} GB"
-    else:
-        return f"{size_in_bytes / (1024 ** 4):.2f} TB"
+  if size_in_bytes < 1024:
+    return f"{size_in_bytes} B"
+  elif size_in_bytes < 1024**2:
+    return f"{size_in_bytes / 1024:.2f} KB"
+  elif size_in_bytes < 1024**3:
+    return f"{size_in_bytes / (1024 ** 2):.2f} MB"
+  elif size_in_bytes < 1024**4:
+    return f"{size_in_bytes / (1024 ** 3):.2f} GB"
+  else:
+    return f"{size_in_bytes / (1024 ** 4):.2f} TB"
+
 
 
 def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
-    if bytes_per_second < 1024:
-        return f"{bytes_per_second} B/s"
-    elif bytes_per_second < 1024 ** 2:
-        return f"{bytes_per_second / 1024:.2f} KB/s"
-    elif bytes_per_second < 1024 ** 3:
-        return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
-    elif bytes_per_second < 1024 ** 4:
-        return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
-    else:
-        return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
+  if bytes_per_second < 1024:
+    return f"{bytes_per_second} B/s"
+  elif bytes_per_second < 1024**2:
+    return f"{bytes_per_second / 1024:.2f} KB/s"
+  elif bytes_per_second < 1024**3:
+    return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
+  elif bytes_per_second < 1024**4:
+    return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
+  else:
+    return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
+
 
 
 def get_all_ip_addresses():
-    try:
-      ip_addresses = []
-      for interface in netifaces.interfaces():
-        ifaddresses = netifaces.ifaddresses(interface)
-        if netifaces.AF_INET in ifaddresses:
-          for link in ifaddresses[netifaces.AF_INET]:
-            ip = link['addr']
-            ip_addresses.append(ip)
-      return list(set(ip_addresses))
-    except:
-      if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
-      return ["localhost"]
+  try:
+    ip_addresses = []
+    for interface in netifaces.interfaces():
+      ifaddresses = netifaces.ifaddresses(interface)
+      if netifaces.AF_INET in ifaddresses:
+        for link in ifaddresses[netifaces.AF_INET]:
+          ip = link['addr']
+          ip_addresses.append(ip)
+    return list(set(ip_addresses))
+  except:
+    if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
+    return ["localhost"]

+ 6 - 8
exo/inference/debug_inference_engine.py

@@ -10,7 +10,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   from exo.inference.tinygrad.inference import Tokenizer
   from pathlib import Path
 
 
-  _tokenizer = Tokenizer(str(Path(model_id) / "tokenizer.model"))
+  _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))
 
 
   prompt = "In a single word only, what is the last name of the president of the United States? "
   prompt = "In a single word only, what is the last name of the president of the United States? "
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
@@ -52,10 +52,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(next_resp_full, resp4)
 
 
 
 
-asyncio.run(
-  test_inference_engine(
-    TinygradDynamicShardInferenceEngine(),
-    TinygradDynamicShardInferenceEngine(),
-    "llama3-8b-sfr",
-  )
-)
+asyncio.run(test_inference_engine(
+  TinygradDynamicShardInferenceEngine(),
+  TinygradDynamicShardInferenceEngine(),
+  "llama3-8b-sfr",
+))

+ 17 - 0
exo/inference/inference_engine.py

@@ -1,9 +1,11 @@
 import numpy as np
+import os
 
 from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from .shard import Shard
 
 
+
 class InferenceEngine(ABC):
   @abstractmethod
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
@@ -12,3 +14,18 @@ class InferenceEngine(ABC):
   @abstractmethod
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass
+
+
+def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+  if inference_engine_name == "mlx":
+    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+
+    return MLXDynamicShardInferenceEngine(shard_downloader)
+  elif inference_engine_name == "tinygrad":
+    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+    import tinygrad.helpers
+    tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
+
+    return TinygradDynamicShardInferenceEngine(shard_downloader)
+  else:
+    raise ValueError(f"Inference engine {inference_engine_name} not supported")

+ 3 - 6
exo/inference/mlx/models/deepseek_v2.py

@@ -7,7 +7,7 @@ import mlx.nn as nn
 from mlx_lm.models.base import KVCache
 from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
 from .base import IdentityBlock
-from ...shard import Shard
+from exo.inference.shard import Shard
 
 
 
 
 @dataclass
@@ -59,7 +59,7 @@ class DeepseekV2Model(nn.Module):
       mask = mask.astype(h.dtype)
 
 
     if cache is None:
-      cache = [None] * len(self.layers)
+      cache = [None]*len(self.layers)
 
 
     for layer, c in zip(self.layers, cache):
       h = layer(h, mask, c)
@@ -107,10 +107,7 @@ class Model(nn.Module):
         for k in ["weight", "scales", "biases"]:
         for k in ["weight", "scales", "biases"]:
           if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
           if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
             to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)]
             to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)]
-            shard_state_dict[
-              f"{prefix}.mlp.switch_mlp.{
-       m}.{k}"
-            ] = mx.stack(to_join)
+            shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
 
 
     return shard_state_dict
 
 

+ 4 - 7
exo/inference/mlx/models/llama.py

@@ -3,7 +3,7 @@ from dataclasses import dataclass, field
 import mlx.core as mx
 import mlx.nn as nn
 
 
-from mlx_lm.models.base import create_additive_causal_mask
+from mlx_lm.models.base import create_attention_mask
 from mlx_lm.models.llama import TransformerBlock, ModelArgs
 
 
 from ...shard import Shard
@@ -40,7 +40,6 @@ class LlamaModel(nn.Module):
         self.layers.append(TransformerBlock(args=args))
       else:
         self.layers.append(IdentityBlock())
-
     if self.args.shard.is_last_layer():
       self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 
 
@@ -56,11 +55,10 @@ class LlamaModel(nn.Module):
 
 
     mask = None
     if h.shape[1] > 1:
-      mask = create_additive_causal_mask(h.shape[1], cache[0].offset if cache is not None else 0)
-      mask = mask.astype(h.dtype)
+      mask = create_attention_mask(h, cache)
 
 
     if cache is None:
     if cache is None:
-      cache = [None] * len(self.layers)
+      cache = [None]*len(self.layers)
 
 
     for layer, c in zip(self.layers, cache):
       h = layer(h, mask, cache=c)
@@ -75,7 +73,6 @@ class Model(nn.Module):
     super().__init__()
     self.args = args
     self.model_type = args.model_type
-
     self.model = LlamaModel(args)
     if self.args.shard.is_last_layer():
       if not args.tie_word_embeddings:
@@ -121,7 +118,7 @@ class Model(nn.Module):
 
 
   @property
   def head_dim(self):
-    return self.args.hidden_size // self.args.num_attention_heads
+    return (self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads)
 
 
   @property
   def n_kv_heads(self):

+ 507 - 555
exo/inference/mlx/models/llava.py

@@ -15,619 +15,571 @@ import numpy as np
 
 
 @dataclass
 class VisionConfig:
-    model_type: str
-    num_hidden_layers: int = 24
-    hidden_size: int = 1024
-    intermediate_size: int = 4096
-    num_attention_heads: int = 16
-    image_size: int = 336
-    patch_size: int = 14
-    projection_dim: int = 768
-    vocab_size: int = 32000
-    num_channels: int = 3
-    layer_norm_eps: float = 1e-5
-
-    @classmethod
-    def from_dict(cls, params):
-        return cls(
-            **{
-                k: v
-                for k, v in params.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
+  model_type: str
+  num_hidden_layers: int = 24
+  hidden_size: int = 1024
+  intermediate_size: int = 4096
+  num_attention_heads: int = 16
+  image_size: int = 336
+  patch_size: int = 14
+  projection_dim: int = 768
+  vocab_size: int = 32000
+  num_channels: int = 3
+  layer_norm_eps: float = 1e-5
+
+  @classmethod
+  def from_dict(cls, params):
+    return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
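The filtering comprehension simply drops keys that VisionConfig does not declare, e.g.:

cfg = VisionConfig.from_dict({
  "model_type": "clip_vision_model",
  "hidden_size": 1024,
  "torch_dtype": "float16",  # not a VisionConfig field, silently ignored
})
# cfg.hidden_size == 1024; "torch_dtype" never reaches the dataclass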
 
 
 
 
 class VisionAttention(nn.Module):
-    def __init__(
-            self,
-            dims: int,
-            num_heads: int,
-            query_input_dims: Optional[int] = None,
-            key_input_dims: Optional[int] = None,
-            value_input_dims: Optional[int] = None,
-            value_dims: Optional[int] = None,
-            value_output_dims: Optional[int] = None,
-            bias: bool = False,
-    ):
-        super().__init__()
-
-        if (dims % num_heads) != 0:
-            raise ValueError(
-                "The input feature dimensions should be divisible by the "
-                f"number of heads ({dims} % {num_heads}) != 0"
-            )
-
-        query_input_dims = query_input_dims or dims
-        key_input_dims = key_input_dims or dims
-        value_input_dims = value_input_dims or key_input_dims
-        value_dims = value_dims or dims
-        value_output_dims = value_output_dims or dims
-
-        self.num_heads = num_heads
-        self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
-        self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
-        self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
-        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
-
-    def __call__(self, queries, keys, values, mask=None):
-        queries = self.q_proj(queries)
-        keys = self.k_proj(keys)
-        values = self.v_proj(values)
-
-        num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-        scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys
-        if mask is not None:
-            scores = scores + mask.astype(scores.dtype)
-        scores = mx.softmax(scores, axis=-1)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.out_proj(values_hat)
+  def __init__(
+    self,
+    dims: int,
+    num_heads: int,
+    query_input_dims: Optional[int] = None,
+    key_input_dims: Optional[int] = None,
+    value_input_dims: Optional[int] = None,
+    value_dims: Optional[int] = None,
+    value_output_dims: Optional[int] = None,
+    bias: bool = False,
+  ):
+    super().__init__()
+
+    if (dims % num_heads) != 0:
+      raise ValueError("The input feature dimensions should be divisible by the "
+                       f"number of heads ({dims} % {num_heads}) != 0")
+
+    query_input_dims = query_input_dims or dims
+    key_input_dims = key_input_dims or dims
+    value_input_dims = value_input_dims or key_input_dims
+    value_dims = value_dims or dims
+    value_output_dims = value_output_dims or dims
+
+    self.num_heads = num_heads
+    self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+    self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+    self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+    self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+  def __call__(self, queries, keys, values, mask=None):
+    queries = self.q_proj(queries)
+    keys = self.k_proj(keys)
+    values = self.v_proj(values)
+
+    num_heads = self.num_heads
+    B, L, D = queries.shape
+    _, S, _ = keys.shape
+    queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+    keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
+    values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+    scale = math.sqrt(1/queries.shape[-1])
+    scores = (queries*scale) @ keys
+    if mask is not None:
+      scores = scores + mask.astype(scores.dtype)
+    scores = mx.softmax(scores, axis=-1)
+    values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+    return self.out_proj(values_hat)
 
 
 
 
 class VisionMLP(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.activation_fn = nn.GELU(approx="fast")
-        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.activation_fn = nn.GELU(approx="fast")
+    self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+    self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
 
 
-    def __call__(self, x: mx.array) -> mx.array:
-        x = self.activation_fn(self.fc1(x))
-        x = self.fc2(x)
-        return x
+  def __call__(self, x: mx.array) -> mx.array:
+    x = self.activation_fn(self.fc1(x))
+    x = self.fc2(x)
+    return x
 
 
 
 
 class VisionEncoderLayer(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.embed_dim = config.hidden_size
-        self.self_attn = VisionAttention(
-            config.hidden_size, config.num_attention_heads, bias=True
-        )
-        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-        self.mlp = VisionMLP(config)
-        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-
-    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
-        y = self.layer_norm1(x)
-        y = self.self_attn(y, y, y, mask)
-        x = x + y
-        y = self.layer_norm2(x)
-        y = self.mlp(y)
-        return x + y
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.embed_dim = config.hidden_size
+    self.self_attn = VisionAttention(config.hidden_size, config.num_attention_heads, bias=True)
+    self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+    self.mlp = VisionMLP(config)
+    self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+  def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+    y = self.layer_norm1(x)
+    y = self.self_attn(y, y, y, mask)
+    x = x + y
+    y = self.layer_norm2(x)
+    y = self.mlp(y)
+    return x + y
 
 
 
 
 class VisionEncoder(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
 
 
 
 
 class VisionEmbeddings(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.config = config
-        self.embed_dim = config.hidden_size
-        self.image_size = config.image_size
-        self.patch_size = config.patch_size
-
-        self.class_embedding = mx.zeros((config.hidden_size,))
-
-        self.patch_embedding = nn.Conv2d(
-            in_channels=config.num_channels,
-            out_channels=self.embed_dim,
-            kernel_size=self.patch_size,
-            stride=self.patch_size,
-            bias=False,
-        )
-
-        self.num_patches = (self.image_size // self.patch_size) ** 2
-        self.num_positions = self.num_patches + 1
-        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-
-    def __call__(self, x: mx.array) -> mx.array:
-        batch_size = x.shape[0]
-        patch_embeddings = self.patch_embedding(x)
-        patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
-        embed_dim = patch_embeddings.shape[-1]
-        cls_embeddings = mx.broadcast_to(
-            self.class_embedding, (batch_size, 1, embed_dim)
-        )
-        embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
-        embeddings += self.position_embedding.weight
-        return embeddings
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.config = config
+    self.embed_dim = config.hidden_size
+    self.image_size = config.image_size
+    self.patch_size = config.patch_size
+
+    self.class_embedding = mx.zeros((config.hidden_size,))
+
+    self.patch_embedding = nn.Conv2d(
+      in_channels=config.num_channels,
+      out_channels=self.embed_dim,
+      kernel_size=self.patch_size,
+      stride=self.patch_size,
+      bias=False,
+    )
+
+    self.num_patches = (self.image_size // self.patch_size)**2
+    self.num_positions = self.num_patches + 1
+    self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+  def __call__(self, x: mx.array) -> mx.array:
+    batch_size = x.shape[0]
+    patch_embeddings = self.patch_embedding(x)
+    patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+    embed_dim = patch_embeddings.shape[-1]
+    cls_embeddings = mx.broadcast_to(self.class_embedding, (batch_size, 1, embed_dim))
+    embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
+    embeddings += self.position_embedding.weight
+    return embeddings
 
 
 
 
 class ClipVisionModel(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.embeddings = VisionEmbeddings(config)
-        self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
-        self.encoder = VisionEncoder(config)
-        self.post_layernorm = nn.LayerNorm(config.hidden_size)
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.embeddings = VisionEmbeddings(config)
+    self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
+    self.encoder = VisionEncoder(config)
+    self.post_layernorm = nn.LayerNorm(config.hidden_size)
 
 
-    def __call__(
-        self,
-        x: mx.array,
-        output_hidden_states: Optional[bool] = None,
-    ) -> mx.array:
-        x = self.embeddings(x)
-        x = self.pre_layrnorm(x)
+  def __call__(
+    self,
+    x: mx.array,
+    output_hidden_states: Optional[bool] = None,
+  ) -> mx.array:
+    x = self.embeddings(x)
+    x = self.pre_layrnorm(x)
 
 
-        encoder_states = (x,) if output_hidden_states else None
+    encoder_states = (x,) if output_hidden_states else None
 
 
-        for l in self.encoder.layers:
-            x = l(x, mask=None)
-            if output_hidden_states:
-                encoder_states = encoder_states + (x,)
+    for l in self.encoder.layers:
+      x = l(x, mask=None)
+      if output_hidden_states:
+        encoder_states = encoder_states + (x,)
 
 
-        pooler_output = self.post_layernorm(x[:, 0, :])
-        return pooler_output, x, encoder_states
+    pooler_output = self.post_layernorm(x[:, 0, :])
+    return pooler_output, x, encoder_states
 
 
 
 
 class VisionModel(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-
-        self.model_type = config.model_type
-        if self.model_type != "clip_vision_model":
-            raise ValueError(f"Unsupported model type: {self.model_type}")
-
-        self.vision_model = ClipVisionModel(config)
-
-    def __call__(
-            self, x: mx.array, output_hidden_states: Optional[bool] = None
-    ) -> mx.array:
-        return self.vision_model(x, output_hidden_states)
-
-    def sanitize(self, weights):
-        sanitized_weights = {}
-        for k, v in weights.items():
-            if "position_ids" in k:
-                # Remove unused position_ids
-                continue
-            elif "patch_embedding.weight" in k:
-                # PyTorch conv2d weight tensors have shape:
-                #   [out_channels, in_channels, kH, KW]
-                # MLX conv2d expects the weight be of shape:
-                #   [out_channels, kH, KW, in_channels]
-                sanitized_weights[k] = v.transpose(0, 2, 3, 1)
-            else:
-                sanitized_weights[k] = v
-
-        return sanitized_weights
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+
+    self.model_type = config.model_type
+    if self.model_type != "clip_vision_model":
+      raise ValueError(f"Unsupported model type: {self.model_type}")
+
+    self.vision_model = ClipVisionModel(config)
+
+  def __call__(self, x: mx.array, output_hidden_states: Optional[bool] = None) -> mx.array:
+    return self.vision_model(x, output_hidden_states)
+
+  def sanitize(self, weights):
+    sanitized_weights = {}
+    for k, v in weights.items():
+      if "position_ids" in k:
+        # Remove unused position_ids
+        continue
+      elif "patch_embedding.weight" in k:
+        # PyTorch conv2d weight tensors have shape:
+        #   [out_channels, in_channels, kH, kW]
+        # MLX conv2d expects the weight to be of shape:
+        #   [out_channels, kH, kW, in_channels]
+        sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+      else:
+        sanitized_weights[k] = v
+
+    return sanitized_weights
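
As a quick illustration of the patch-embedding transpose performed by sanitize above, a minimal sketch; the shapes are assumed CLIP-style values, not taken from this diff:

import mlx.core as mx

w_torch = mx.zeros((1024, 3, 14, 14))   # PyTorch layout: [out_channels, in_channels, kH, kW]
w_mlx = w_torch.transpose(0, 2, 3, 1)   # MLX layout: [out_channels, kH, kW, in_channels]
assert tuple(w_mlx.shape) == (1024, 14, 14, 3)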
 
 
 
 
 @dataclass
 @dataclass
 class TextConfig:
 class TextConfig:
-    model_type: str
-    hidden_size: int = 4096
-    num_hidden_layers: int = 32
-    intermediate_size: int = 11008
-    num_attention_heads: int = 32
-    head_dim: int = None
-    rms_norm_eps: float = 1e-6
-    vocab_size: int = 32000
-    num_key_value_heads: int = None
-    rope_theta: float = 10000
-    rope_traditional: bool = False
-    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-
-    @classmethod
-    def from_dict(cls, params):
-        return cls(
-            **{
-                k: v
-                for k, v in params.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
-
-    def __post_init__(self):
-        if self.num_key_value_heads is None:
-            self.num_key_value_heads = self.num_attention_heads
-
-        if self.head_dim is None:
-            self.head_dim = self.hidden_size // self.num_attention_heads
-
-        if self.model_type is None:
-            self.model_type = "llama"
-
-        if self.rope_scaling:
-            required_keys = {"factor", "type"}
-            if not all(key in self.rope_scaling for key in required_keys):
-                raise ValueError(f"rope_scaling must contain keys {required_keys}")
-
-            if self.rope_scaling["type"] != "linear":
-                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
+  model_type: str
+  hidden_size: int = 4096
+  num_hidden_layers: int = 32
+  intermediate_size: int = 11008
+  num_attention_heads: int = 32
+  head_dim: int = None
+  rms_norm_eps: float = 1e-6
+  vocab_size: int = 32000
+  num_key_value_heads: int = None
+  rope_theta: float = 10000
+  rope_traditional: bool = False
+  rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+
+  @classmethod
+  def from_dict(cls, params):
+    return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+  def __post_init__(self):
+    if self.num_key_value_heads is None:
+      self.num_key_value_heads = self.num_attention_heads
+
+    if self.head_dim is None:
+      self.head_dim = self.hidden_size // self.num_attention_heads
+
+    if self.model_type is None:
+      self.model_type = "llama"
+
+    if self.rope_scaling:
+      required_keys = {"factor", "type"}
+      if not all(key in self.rope_scaling for key in required_keys):
+        raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+      if self.rope_scaling["type"] != "linear":
+        raise ValueError("rope_scaling 'type' currently only supports 'linear'")
 
 
 
 
 class TextAttention(nn.Module):
 class TextAttention(nn.Module):
-    def __init__(self, config: TextConfig):
-        super().__init__()
-
-        dim = config.hidden_size
-        self.n_heads = n_heads = config.num_attention_heads
-        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
-
-        self.repeats = n_heads // n_kv_heads
-
-        head_dim = config.hidden_size // n_heads
-        self.scale = head_dim ** -0.5
-
-        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
-        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
-        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
-        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
-
-        rope_scale = (
-            1 / config.rope_scaling["factor"]
-            if config.rope_scaling is not None
-               and config.rope_scaling["type"] == "linear"
-            else 1
-        )
-        self.rope = nn.RoPE(
-            head_dim,
-            traditional=config.rope_traditional,
-            base=config.rope_theta,
-            scale=rope_scale,
-        )
-
-    def __call__(
-            self,
-            x: mx.array,
-            mask: Optional[mx.array] = None,
-            cache: Optional[KVCache] = None,
-    ) -> mx.array:
-        B, L, D = x.shape
-
-        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-        # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-
-        if cache is not None:
-            queries = self.rope(queries, offset=cache.offset)
-            keys = self.rope(keys, offset=cache.offset)
-            keys, values = cache.update_and_fetch(keys, values)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
-
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
-        )
-        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output)
+  def __init__(self, config: TextConfig):
+    super().__init__()
+
+    dim = config.hidden_size
+    self.n_heads = n_heads = config.num_attention_heads
+    self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+
+    self.repeats = n_heads // n_kv_heads
+
+    head_dim = config.hidden_size // n_heads
+    self.scale = head_dim**-0.5
+
+    self.q_proj = nn.Linear(dim, n_heads*head_dim, bias=False)
+    self.k_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False)
+    self.v_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False)
+    self.o_proj = nn.Linear(n_heads*head_dim, dim, bias=False)
+
+    rope_scale = (1/config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1)
+    self.rope = nn.RoPE(
+      head_dim,
+      traditional=config.rope_traditional,
+      base=config.rope_theta,
+      scale=rope_scale,
+    )
+
+  def __call__(
+    self,
+    x: mx.array,
+    mask: Optional[mx.array] = None,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    B, L, D = x.shape
+
+    queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+    # Prepare the queries, keys and values for the attention computation
+    queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+    keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+    values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+    if cache is not None:
+      queries = self.rope(queries, offset=cache.offset)
+      keys = self.rope(keys, offset=cache.offset)
+      keys, values = cache.update_and_fetch(keys, values)
+    else:
+      queries = self.rope(queries)
+      keys = self.rope(keys)
+
+    output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale, mask=mask)
+    output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+    return self.o_proj(output)
 
 
 
 
 class TextMLP(nn.Module):
 class TextMLP(nn.Module):
-    def __init__(self, dim, hidden_dim):
-        super().__init__()
-        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
-        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
-        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+  def __init__(self, dim, hidden_dim):
+    super().__init__()
+    self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+    self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+    self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
 
 
-    def __call__(self, x) -> mx.array:
-        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+  def __call__(self, x) -> mx.array:
+    return self.down_proj(nn.silu(self.gate_proj(x))*self.up_proj(x))
 
 
 
 
 class TransformerBlock(nn.Module):
 class TransformerBlock(nn.Module):
-    def __init__(self, config: TextConfig):
-        super().__init__()
-        self.num_attention_heads = config.num_attention_heads
-        self.hidden_size = config.hidden_size
-        self.self_attn = TextAttention(config)
-        self.mlp = TextMLP(config.hidden_size, config.intermediate_size)
-        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps
-        )
-        self.config = config
-
-    def __call__(
-            self,
-            x: mx.array,
-            mask: Optional[mx.array] = None,
-            cache: Optional[KVCache] = None,
-    ) -> mx.array:
-        r = self.self_attn(self.input_layernorm(x), mask, cache)
-        h = x + r
-        r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
-        return out
+  def __init__(self, config: TextConfig):
+    super().__init__()
+    self.num_attention_heads = config.num_attention_heads
+    self.hidden_size = config.hidden_size
+    self.self_attn = TextAttention(config)
+    self.mlp = TextMLP(config.hidden_size, config.intermediate_size)
+    self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+    self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+    self.config = config
+
+  def __call__(
+    self,
+    x: mx.array,
+    mask: Optional[mx.array] = None,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    r = self.self_attn(self.input_layernorm(x), mask, cache)
+    h = x + r
+    r = self.mlp(self.post_attention_layernorm(h))
+    out = h + r
+    return out
 
 
 
 
 class Llama(nn.Module):
 class Llama(nn.Module):
-    def __init__(self, config: TextConfig, shard: Shard):
-        super().__init__()
-        self.config = config
-        self.shard = shard
-        self.vocab_size = config.vocab_size
-        self.model_type = config.model_type
-        self.num_hidden_layers = config.num_hidden_layers
-        self.num_key_value_heads = config.num_key_value_heads
-        self.head_dim = config.head_dim
-        assert self.vocab_size > 0
-        if self.shard.is_first_layer():
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.layers = []
-        for i in range(self.num_hidden_layers):
-          if self.shard.start_layer <= i <= self.shard.end_layer:
-            self.layers.append(TransformerBlock(config=config))
-          else:
-            self.layers.append(IdentityBlock())
-        if self.shard.is_last_layer():
-            self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-    def __call__(
-            self,
-            inputs: mx.array,
-            cache=None,
-            inputs_embeds=None,
-    ):
-        # for passing merged input embeddings
-        if inputs_embeds is None:
-            if self.shard.is_first_layer():
-                h = self.embed_tokens(inputs)
-            else:
-                h = inputs
-        else:
-            h = inputs_embeds
-
-        mask = None
-        if h.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-            mask = mask.astype(h.dtype)
-
-        if cache is None:
-            cache = [None] * len(self.layers)
-
-
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, c)
-
-        if self.shard.is_last_layer():
-            h = self.norm(h)
-        return h
+  def __init__(self, config: TextConfig, shard: Shard):
+    super().__init__()
+    self.config = config
+    self.shard = shard
+    self.vocab_size = config.vocab_size
+    self.model_type = config.model_type
+    self.num_hidden_layers = config.num_hidden_layers
+    self.num_key_value_heads = config.num_key_value_heads
+    self.head_dim = config.head_dim
+    assert self.vocab_size > 0
+    if self.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.shard.start_layer <= i <= self.shard.end_layer:
+        self.layers.append(TransformerBlock(config=config))
+      else:
+        self.layers.append(IdentityBlock())
+    if self.shard.is_last_layer():
+      self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+    inputs_embeds=None,
+  ):
+    # for passing merged input embeddings
+    if inputs_embeds is None:
+      if self.shard.is_first_layer():
+        h = self.embed_tokens(inputs)
+      else:
+        h = inputs
+    else:
+      h = inputs_embeds
+
+    mask = None
+    if h.shape[1] > 1:
+      mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
+      mask = mask.astype(h.dtype)
+
+    if cache is None:
+      cache = [None]*len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, c)
+
+    if self.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
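
A small sketch of the shard-aware layer selection in the Llama constructor above; the 32-layer split owning layers 8-15 is an assumed example:

from exo.inference.shard import Shard

shard = Shard(model_id="llava", start_layer=8, end_layer=15, n_layers=32)
owned = [i for i in range(shard.n_layers) if shard.start_layer <= i <= shard.end_layer]
assert owned == list(range(8, 16))   # all other indices get an IdentityBlock placeholder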
 
 
 class LanguageModel(nn.Module):
 class LanguageModel(nn.Module):
-    def __init__(self, config: TextConfig, shard: Shard):
-        super().__init__()
-        self.model_type = config.model_type
-        if self.model_type != "llama":
-            raise ValueError(
-                f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
-            )
-        self.shard = shard
-        self.model = Llama(config, shard)
-        if self.shard.is_last_layer():
-            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-    def __call__(
-        self,
-        inputs: mx.array,
-        cache=None,
-        inputs_embeds=None,
-    ):
-        out = self.model(inputs, cache, inputs_embeds)
-        if self.shard.is_last_layer():
-            out = self.lm_head(out)
-        return out
-
-    def sanitize(self, weights):
-        shard_state_dict = {}
-        for key, value in weights.items():
-            if "self_attn.rotary_emb.inv_freq" in key:
-                continue
-
-            if key.startswith('language_model.model.layers.'):
-                layer_num = int(key.split('.')[3])
-                if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer:
-                    continue
-            if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'):
-                continue
-            elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')):
-                continue
-
-            shard_state_dict[key] = value
-
-        return shard_state_dict
+  def __init__(self, config: TextConfig, shard: Shard):
+    super().__init__()
+    self.model_type = config.model_type
+    if self.model_type != "llama":
+      raise ValueError(f"Model type {self.model_type} not supported. Currently only 'llama' is supported")
+    self.shard = shard
+    self.model = Llama(config, shard)
+    if self.shard.is_last_layer():
+      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+    inputs_embeds=None,
+  ):
+    out = self.model(inputs, cache, inputs_embeds)
+    if self.shard.is_last_layer():
+      out = self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+
+      if key.startswith('language_model.model.layers.'):
+        layer_num = int(key.split('.')[3])
+        if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer:
+          continue
+      if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'):
+        continue
+      elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')):
+        continue
+
+      shard_state_dict[key] = value
+
+    return shard_state_dict
+
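
The layer-number filter in sanitize keys off the fourth dot-separated field of the weight name; a standalone illustration with an assumed key:

key = "language_model.model.layers.20.self_attn.q_proj.weight"
layer_num = int(key.split(".")[3])
assert layer_num == 20   # kept only if start_layer <= 20 <= end_layer for this shard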
 
 
 @dataclass
 @dataclass
 class LlaVAConfig(BaseModelArgs):
 class LlaVAConfig(BaseModelArgs):
-    text_config: TextConfig
-    vision_config: VisionConfig = None
-    model_type: str = "llava"
-    ignore_index: int = -100
-    image_token_index: int = 32000
-    vision_feature_select_strategy: str = "default"
-    vision_feature_layer: int = -2
-    vocab_size: int = 32000
-
-    @classmethod
-    def from_dict(cls, params):
-        updated_params = {}
-        class_params = inspect.signature(cls).parameters
-        for k, v in params.items():
-            if k in class_params:
-                if k in ["text_config", "vision_config"]:
-                    v = class_params[k].annotation.from_dict(v)
-                updated_params.update({k: v})
-
-        return cls(**updated_params)
+  text_config: TextConfig
+  vision_config: VisionConfig = None
+  model_type: str = "llava"
+  ignore_index: int = -100
+  image_token_index: int = 32000
+  vision_feature_select_strategy: str = "default"
+  vision_feature_layer: int = -2
+  vocab_size: int = 32000
+
+  @classmethod
+  def from_dict(cls, params):
+    updated_params = {}
+    class_params = inspect.signature(cls).parameters
+    for k, v in params.items():
+      if k in class_params:
+        if k in ["text_config", "vision_config"]:
+          v = class_params[k].annotation.from_dict(v)
+        updated_params.update({k: v})
+
+    return cls(**updated_params)
 
 
 
 
 @dataclass
 @dataclass
 class ModelArgs(LlaVAConfig):
 class ModelArgs(LlaVAConfig):
-    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
 
 
-    def __post_init__(self):
-        if isinstance(self.shard, dict):
-            self.shard = Shard(**self.shard)
+  def __post_init__(self):
+    if isinstance(self.shard, dict):
+      self.shard = Shard(**self.shard)
 
 
-        if not isinstance(self.shard, Shard):
-            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+    if not isinstance(self.shard, Shard):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
 
 
-        if not self.shard.is_first_layer():
-            self.vision_config = None
+    if not self.shard.is_first_layer():
+      self.vision_config = None
 
 
 
 
 class LlavaMultiModalProjector(nn.Module):
 class LlavaMultiModalProjector(nn.Module):
-    def __init__(self, config: LlaVAConfig):
-        super().__init__()
-        self.linear_1 = nn.Linear(
-            config.vision_config.hidden_size, config.text_config.hidden_size, bias=True
-        )
-        self.gelu = nn.GELU()
-        self.linear_2 = nn.Linear(
-            config.text_config.hidden_size, config.text_config.hidden_size, bias=True
-        )
-
-    def __call__(self, x: mx.array) -> mx.array:
-        x = self.linear_1(x)
-        x = self.gelu(x)
-        x = self.linear_2(x)
-        return x
+  def __init__(self, config: LlaVAConfig):
+    super().__init__()
+    self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
+    self.gelu = nn.GELU()
+    self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
+
+  def __call__(self, x: mx.array) -> mx.array:
+    x = self.linear_1(x)
+    x = self.gelu(x)
+    x = self.linear_2(x)
+    return x
 
 
 
 
 class Model(nn.Module):
 class Model(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        self.config = config
-        self.model_type = config.model_type
-        if config.vision_config:
-            self.vision_tower = VisionModel(config.vision_config)
-            self.multi_modal_projector = LlavaMultiModalProjector(config)
-            self.vision_feature_layer = config.vision_feature_layer
-            self.vision_feature_select_strategy = config.vision_feature_select_strategy
-        self.language_model = LanguageModel(config.text_config, config.shard)
-
-    def get_input_embeddings(
-            self,
-            input_ids: Optional[mx.array] = None,
-            pixel_values: Optional[mx.array] = None,
-    ):
-        if pixel_values is None:
-            return self.language_model(input_ids)
-
-        # Get the input embeddings from the language model
-        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-
-        # Get the ouptut hidden states from the vision model
-        *_, hidden_states = self.vision_tower(
-            pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
-        )
-
-        # Select the hidden states from the desired layer
-        selected_image_feature = hidden_states[self.vision_feature_layer]
-
-        if self.vision_feature_select_strategy == "default":
-            selected_image_feature = selected_image_feature[:, 1:]
-        elif self.vision_feature_select_strategy == "full":
-            selected_image_feature = selected_image_feature
-        else:
-            raise ValueError(
-                "Unexpected feature selection strategy: "
-                f"{self.vision_feature_select_strategy}"
-            )
-
-        # Pass image features through the multi-modal projector
-        image_features = self.multi_modal_projector(selected_image_feature)
-
-        # Insert special image tokens in the input_ids
-        final_inputs_embeds = self._merge_input_ids_with_image_features(
-            image_features, inputs_embeds, input_ids
-        )
-        return final_inputs_embeds
-
-    def _merge_input_ids_with_image_features(
-            self, image_features, inputs_embeds, input_ids
-    ):
-        image_token_index = self.config.image_token_index
-        num_images, num_image_patches, embed_dim = image_features.shape
-
-        # Positions of <image> tokens in input_ids, assuming batch size is 1
-        image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
-
-        if len(image_positions) != num_images:
-            raise ValueError(
-                f"The number of image tokens ({len(image_positions)}) does not "
-                f" match the number of image inputs ({num_images})."
-            )
-
-        text_segments = []
-        start_idx = 0
-
-        for position in image_positions:
-            text_segments.append(inputs_embeds[:, start_idx:position])
-            start_idx = position + 1
-
-        image_embeddings = mx.split(image_features, image_features.shape[0])
-        final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
-        final_embeddings += [inputs_embeds[:, start_idx:]]
-
-        # Create a final embedding of shape
-        # (1, num_image_patches*num_images + sequence_len, embed_dim)
-        return mx.concatenate(final_embeddings, axis=1)
-
-    def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
-        input_embddings = None
-        if pixel_values is not None:
-            input_embddings = self.get_input_embeddings(input_ids, pixel_values)
-        logits = self.language_model(
-            input_ids, cache=cache, inputs_embeds=input_embddings
-        )
-        return logits
-
-    def sanitize(self, weights):
-        if self.config.vision_config:
-          weights = self.vision_tower.sanitize(weights)
-        else:
-          weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))}
-        weights = self.language_model.sanitize(weights)
-        return weights
-
-    @property
-    def layers(self):
-        return self.language_model.model.layers
-
-    @property
-    def head_dim(self):
-        return (
-                self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads
-        )
-
-    @property
-    def n_kv_heads(self):
-        return self.language_model.model.num_key_value_heads
+  def __init__(self, config: ModelArgs):
+    super().__init__()
+    self.config = config
+    self.model_type = config.model_type
+    if config.vision_config:
+      self.vision_tower = VisionModel(config.vision_config)
+      self.multi_modal_projector = LlavaMultiModalProjector(config)
+      self.vision_feature_layer = config.vision_feature_layer
+      self.vision_feature_select_strategy = config.vision_feature_select_strategy
+    self.language_model = LanguageModel(config.text_config, config.shard)
+
+  def get_input_embeddings(
+    self,
+    input_ids: Optional[mx.array] = None,
+    pixel_values: Optional[mx.array] = None,
+  ):
+    if pixel_values is None:
+      return self.language_model(input_ids)
+
+    # Get the input embeddings from the language model
+    inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+    # Get the output hidden states from the vision model
+    *_, hidden_states = self.vision_tower(pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True)
+
+    # Select the hidden states from the desired layer
+    selected_image_feature = hidden_states[self.vision_feature_layer]
+
+    if self.vision_feature_select_strategy == "default":
+      selected_image_feature = selected_image_feature[:, 1:]
+    elif self.vision_feature_select_strategy == "full":
+      selected_image_feature = selected_image_feature
+    else:
+      raise ValueError("Unexpected feature selection strategy: "
+                       f"{self.vision_feature_select_strategy}")
+
+    # Pass image features through the multi-modal projector
+    image_features = self.multi_modal_projector(selected_image_feature)
+
+    # Insert special image tokens in the input_ids
+    final_inputs_embeds = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids)
+    return final_inputs_embeds
+
+  def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids):
+    image_token_index = self.config.image_token_index
+    num_images, num_image_patches, embed_dim = image_features.shape
+
+    # Positions of <image> tokens in input_ids, assuming batch size is 1
+    image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
+
+    if len(image_positions) != num_images:
+      raise ValueError(f"The number of image tokens ({len(image_positions)}) does not "
+                       f" match the number of image inputs ({num_images}).")
+
+    text_segments = []
+    start_idx = 0
+
+    for position in image_positions:
+      text_segments.append(inputs_embeds[:, start_idx:position])
+      start_idx = position + 1
+
+    image_embeddings = mx.split(image_features, image_features.shape[0])
+    final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
+    final_embeddings += [inputs_embeds[:, start_idx:]]
+
+    # Create a final embedding of shape
+    # (1, num_image_patches*num_images + sequence_len, embed_dim)
+    return mx.concatenate(final_embeddings, axis=1)
+
+  def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
+    input_embeddings = None
+    if pixel_values is not None:
+      input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
+    logits = self.language_model(input_ids, cache=cache, inputs_embeds=input_embeddings)
+    return logits
+
+  def sanitize(self, weights):
+    if self.config.vision_config:
+      weights = self.vision_tower.sanitize(weights)
+    else:
+      weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))}
+    weights = self.language_model.sanitize(weights)
+    return weights
+
+  @property
+  def layers(self):
+    return self.language_model.model.layers
+
+  @property
+  def head_dim(self):
+    return (self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads)
+
+  @property
+  def n_kv_heads(self):
+    return self.language_model.model.num_key_value_heads
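
A minimal sketch of the <image> position lookup used by _merge_input_ids_with_image_features, with an assumed toy sequence and the default image_token_index of 32000:

import numpy as np

input_ids = np.array([[1, 32000, 306, 626, 32000, 2]])
image_positions = np.where(input_ids[0] == 32000)[0].tolist()
assert image_positions == [1, 4]   # one text segment is sliced off before each image slot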

+ 8 - 5
exo/inference/mlx/sharded_model.py

@@ -3,7 +3,7 @@ from collections import OrderedDict
 
 
 import mlx.core as mx
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.nn as nn
-from mlx_lm.models.base import RotatingKVCache
+from mlx_lm.models.base import KVCache, RotatingKVCache
 from mlx_lm.sample_utils import top_p_sampling
 from mlx_lm.sample_utils import top_p_sampling
 
 
 from ..shard import Shard
 from ..shard import Shard
@@ -38,7 +38,7 @@ class StatefulShardedModel:
         if top_p > 0 and top_p < 1.0:
         if top_p > 0 and top_p < 1.0:
           token = top_p_sampling(logits, top_p, temp)
           token = top_p_sampling(logits, top_p, temp)
         else:
         else:
-          token = mx.random.categorical(logits * (1 / temp))
+          token = mx.random.categorical(logits*(1/temp))
 
 
       return token
       return token
 
 
@@ -74,10 +74,13 @@ class StatefulShardedModel:
     return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
     return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
 
 
   def init_cache(self, request_id: str):
   def init_cache(self, request_id: str):
-    kv_heads = [self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads
-    new_cache = [RotatingKVCache(self.model.head_dim, n, self.max_kv_size) for n in kv_heads]
+    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
+    if self.max_kv_size is not None:
+      cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
+    else:
+      cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
 
 
     if len(self.caches) >= self.max_caches:
     if len(self.caches) >= self.max_caches:
       self.caches.popitem(last=False)
       self.caches.popitem(last=False)
 
 
-    self.caches[request_id] = new_cache
+    self.caches[request_id] = cache
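
A hedged sketch of the cache selection added in init_cache above; head_dim, kv_heads, layer count and max_kv_size are illustrative values, and an mlx_lm version exposing both cache classes is assumed:

from mlx_lm.models.base import KVCache, RotatingKVCache

head_dim, kv_heads, n_layers, max_kv_size = 128, 8, 32, 4096
if max_kv_size is not None:
  cache = [RotatingKVCache(head_dim, kv_heads, max_size=max_kv_size, keep=4) for _ in range(n_layers)]
else:
  cache = [KVCache(head_dim, kv_heads) for _ in range(n_layers)]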

+ 31 - 28
exo/inference/mlx/sharded_utils.py

@@ -60,7 +60,7 @@ def _get_classes(config: dict):
 
 
 def load_config(model_path: Path) -> dict:
 def load_config(model_path: Path) -> dict:
   try:
   try:
-    with open(model_path / "config.json", "r") as f:
+    with open(model_path/"config.json", "r") as f:
       config = json.load(f)
       config = json.load(f)
   except FileNotFoundError:
   except FileNotFoundError:
     logging.error(f"Config file not found in {model_path}")
     logging.error(f"Config file not found in {model_path}")
@@ -103,11 +103,11 @@ def load_model_shard(
     "n_layers": shard.n_layers,
     "n_layers": shard.n_layers,
   }
   }
 
 
-  weight_files = glob.glob(str(model_path / "model*.safetensors"))
+  weight_files = glob.glob(str(model_path/"model*.safetensors"))
 
 
   if not weight_files:
   if not weight_files:
     # Try weight for back-compat
     # Try weight for back-compat
-    weight_files = glob.glob(str(model_path / "weight*.safetensors"))
+    weight_files = glob.glob(str(model_path/"weight*.safetensors"))
 
 
   if not weight_files:
   if not weight_files:
     logging.error(f"No safetensors found in {model_path}")
     logging.error(f"No safetensors found in {model_path}")
@@ -139,9 +139,10 @@ def load_model_shard(
   if (quantization := config.get("quantization", None)) is not None:
   if (quantization := config.get("quantization", None)) is not None:
     # Handle legacy models which may not have everything quantized
     # Handle legacy models which may not have everything quantized
     def class_predicate(p, m):
     def class_predicate(p, m):
-        if not hasattr(m, "to_quantized"):
-            return False
-        return f"{p}.scales" in weights
+      if not hasattr(m, "to_quantized"):
+        return False
+      return f"{p}.scales" in weights
+
     nn.quantize(
     nn.quantize(
       model,
       model,
       **quantization,
       **quantization,
@@ -156,6 +157,7 @@ def load_model_shard(
   model.eval()
   model.eval()
   return model
   return model
 
 
+
 async def load_shard(
 async def load_shard(
   model_path: str,
   model_path: str,
   shard: Shard,
   shard: Shard,
@@ -179,26 +181,27 @@ async def load_shard(
     tokenizer = load_tokenizer(model_path, tokenizer_config)
     tokenizer = load_tokenizer(model_path, tokenizer_config)
     return model, tokenizer
     return model, tokenizer
 
 
+
 async def get_image_from_str(_image_str: str):
 async def get_image_from_str(_image_str: str):
-    image_str = _image_str.strip()
-
-    if image_str.startswith("http"):
-        async with aiohttp.ClientSession() as session:
-            async with session.get(image_str, timeout=10) as response:
-                content = await response.read()
-                return Image.open(BytesIO(content)).convert("RGB")
-    elif image_str.startswith("data:image/"):
-        # Extract the image format and base64 data
-        format_prefix, base64_data = image_str.split(";base64,")
-        image_format = format_prefix.split("/")[1].lower()
-        if DEBUG >= 2: print(f"{image_str=} {image_format=}")
-        imgdata = base64.b64decode(base64_data)
-        img = Image.open(BytesIO(imgdata))
-
-        # Convert to RGB if not already
-        if img.mode != "RGB":
-            img = img.convert("RGB")
-
-        return img
-    else:
-        raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
+  image_str = _image_str.strip()
+
+  if image_str.startswith("http"):
+    async with aiohttp.ClientSession() as session:
+      async with session.get(image_str, timeout=10) as response:
+        content = await response.read()
+        return Image.open(BytesIO(content)).convert("RGB")
+  elif image_str.startswith("data:image/"):
+    # Extract the image format and base64 data
+    format_prefix, base64_data = image_str.split(";base64,")
+    image_format = format_prefix.split("/")[1].lower()
+    if DEBUG >= 2: print(f"{image_str=} {image_format=}")
+    imgdata = base64.b64decode(base64_data)
+    img = Image.open(BytesIO(imgdata))
+
+    # Convert to RGB if not already
+    if img.mode != "RGB":
+      img = img.convert("RGB")
+
+    return img
+  else:
+    raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")

+ 6 - 6
exo/inference/mlx/test_sharded_llava.py

@@ -39,8 +39,8 @@ y = full.step("full", input_ids, pixel_values, temp=0)
 full_generated_tokens = [y.item()]
 full_generated_tokens = [y.item()]
 
 
 for _ in range(13):
 for _ in range(13):
-    y = full.step("full", y, temp=0)
-    full_generated_tokens.append(y.item())
+  y = full.step("full", y, temp=0)
+  full_generated_tokens.append(y.item())
 
 
 full_response = full_processor.tokenizer.decode(full_generated_tokens)
 full_response = full_processor.tokenizer.decode(full_generated_tokens)
 print("full response:", full_response)
 print("full response:", full_response)
@@ -54,11 +54,11 @@ y = m2.step("shard", y, temp=0)
 full_generated_tokens = [y.item()]
 full_generated_tokens = [y.item()]
 
 
 for _ in range(13):
 for _ in range(13):
-    y = m1.step("shard", y, temp=0)
-    y = m2.step("shard", y, temp=0)
-    full_generated_tokens.append(y.item())
+  y = m1.step("shard", y, temp=0)
+  y = m2.step("shard", y, temp=0)
+  full_generated_tokens.append(y.item())
 
 
 sharded_response = processor2.tokenizer.decode(full_generated_tokens)
 sharded_response = processor2.tokenizer.decode(full_generated_tokens)
 print("sharded response:", sharded_response)
 print("sharded response:", sharded_response)
 
 
-assert full_response == sharded_response
+assert full_response == sharded_response

+ 2 - 2
exo/inference/mlx/test_sharded_model.py

@@ -21,7 +21,7 @@ class DummyModel(nn.Module):
 
 
   def __call__(self, x, cache=None):
   def __call__(self, x, cache=None):
     if self.shard:
     if self.shard:
-      for layer in self.layers[self.shard.start_layer : self.shard.end_layer + 1]:
+      for layer in self.layers[self.shard.start_layer:self.shard.end_layer + 1]:
         x = layer(x)
         x = layer(x)
       if self.shard.is_last_layer():
       if self.shard.is_last_layer():
         x = x.reshape((1, 2, 4))
         x = x.reshape((1, 2, 4))
@@ -38,7 +38,7 @@ model.save_weights("./test_weights.npz")
 n_layers = 5
 n_layers = 5
 shard1 = Shard("test", 0, n_layers // 2, n_layers)
 shard1 = Shard("test", 0, n_layers // 2, n_layers)
 sharded_model1 = DummyModel(shard1)
 sharded_model1 = DummyModel(shard1)
-shard2 = Shard("test", n_layers // 2 + 1, n_layers - 1, n_layers)
+shard2 = Shard("test", n_layers//2 + 1, n_layers - 1, n_layers)
 sharded_model2 = DummyModel(shard2)
 sharded_model2 = DummyModel(shard2)
 
 
 model.load_weights("./test_weights.npz")
 model.load_weights("./test_weights.npz")
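
For reference, the split arithmetic used by the two dummy shards above works out as follows (worked numbers only):

n_layers = 5
assert (0, n_layers // 2) == (0, 2)                 # shard1 covers layers 0-2
assert (n_layers // 2 + 1, n_layers - 1) == (3, 4)  # shard2 covers layers 3-4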

+ 2 - 4
exo/inference/shard.py

@@ -34,8 +34,6 @@ class Shard:
   def overlaps(self, other: 'Shard') -> bool:
   def overlaps(self, other: 'Shard') -> bool:
     return shards_overlap(self, other)
     return shards_overlap(self, other)
 
 
+
 def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
 def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
-  return (
-      shard1.model_id == shard2.model_id
-      and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer)
-  )
+  return (shard1.model_id == shard2.model_id and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer))
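
A small usage sketch of shards_overlap via Shard.overlaps, assuming exo is importable; the model id and layer ranges are made up:

from exo.inference.shard import Shard

a = Shard(model_id="m", start_layer=0, end_layer=15, n_layers=32)
b = Shard(model_id="m", start_layer=16, end_layer=31, n_layers=32)
c = Shard(model_id="m", start_layer=12, end_layer=31, n_layers=32)
assert not a.overlaps(b)   # adjacent but disjoint layer ranges
assert a.overlaps(c)       # layers 12-15 are shared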

+ 13 - 11
exo/inference/test_inference_engine.py

@@ -7,6 +7,7 @@ import os
 import asyncio
 import asyncio
 import numpy as np
 import numpy as np
 
 
+
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
   prompt = "In a single word only, what is the last name of the current president of the USA?"
@@ -22,7 +23,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt)
   resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt)
   resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
   resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
     "B",
     "B",
-    shard=Shard(model_id=model_id, start_layer=pp+1, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
     input_data=resp1,
     input_data=resp1,
     inference_state=inference_state_1,
     inference_state=inference_state_1,
   )
   )
@@ -34,7 +35,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   )
   )
   resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
   resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
     "B",
     "B",
-    shard=Shard(model_id=model_id, start_layer=pp+1, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
     input_data=resp3,
     input_data=resp3,
     inference_state=inference_state_3,
     inference_state=inference_state_3,
   )
   )
@@ -42,21 +43,22 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(resp_full, resp2)
   assert np.array_equal(resp_full, resp2)
   assert np.array_equal(next_resp_full, resp4)
   assert np.array_equal(next_resp_full, resp4)
 
 
-asyncio.run(
-  test_inference_engine(
-    MLXDynamicShardInferenceEngine(HFShardDownloader()),
-    MLXDynamicShardInferenceEngine(HFShardDownloader()),
-    "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-  )
-)
+
+asyncio.run(test_inference_engine(
+  MLXDynamicShardInferenceEngine(HFShardDownloader()),
+  MLXDynamicShardInferenceEngine(HFShardDownloader()),
+  "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+))
 
 
 if os.getenv("RUN_TINYGRAD", default="0") == "1":
 if os.getenv("RUN_TINYGRAD", default="0") == "1":
   import tinygrad
   import tinygrad
   import os
   import os
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-  asyncio.run(test_inference_engine(
+  asyncio.run(
+    test_inference_engine(
       TinygradDynamicShardInferenceEngine(HFShardDownloader()),
       TinygradDynamicShardInferenceEngine(HFShardDownloader()),
       TinygradDynamicShardInferenceEngine(HFShardDownloader()),
       TinygradDynamicShardInferenceEngine(HFShardDownloader()),
       "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
       "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
-  ))
+    )
+  )

+ 10 - 13
exo/inference/tinygrad/inference.py

@@ -3,6 +3,7 @@ import json
 import os
 import os
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
+from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict
 from tinygrad import Tensor, dtypes, nn, Context
 from tinygrad import Tensor, dtypes, nn, Context
 from transformers import AutoTokenizer
 from transformers import AutoTokenizer
@@ -20,16 +21,11 @@ TOP_P = 0.9
 ALPHA_F = 0.1
 ALPHA_F = 0.1
 ALPHA_P = 0.0
 ALPHA_P = 0.0
 MODEL_PARAMS = {
 MODEL_PARAMS = {
-  "8B": {
-    "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
-    "files": 1
-  },
-  "70B": {
-    "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 28672},
-    "files": 8
-  }
+  "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
+  "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
 }
 }
 
 
+
 def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
 def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
   # build model
   # build model
   linear = nn.Linear
   linear = nn.Linear
@@ -38,9 +34,9 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
 
 
   # load weights
   # load weights
   if model_path.is_dir():
   if model_path.is_dir():
-    if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"), shard)
-    elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"), shard)
-    else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
+    if (model_path/"model.safetensors.index.json").exists(): weights = load(str(model_path/"model.safetensors.index.json"), shard)
+    elif (model_path/"model.safetensors").exists(): weights = load(str(model_path/"model.safetensors"), shard)
+    else: weights = concat_weights([load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
   else:
   else:
     weights = load(str(model_path), shard)
     weights = load(str(model_path), shard)
   weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
   weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
@@ -48,9 +44,10 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
 
 
   with Context(BEAM=0):
   with Context(BEAM=0):
     # replace weights in model
     # replace weights in model
-    load_state_dict(model, weights, strict=False, consume=False) # consume=True
+    load_state_dict(model, weights, strict=False, consume=False)  # consume=True
   return model
   return model
 
 
+
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard = None
@@ -94,5 +91,5 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
 
 
     model_path = await self.shard_downloader.ensure_shard(shard)
     model_path = await self.shard_downloader.ensure_shard(shard)
     self.model = build_transformer(model_path, shard, model_size="8B" if "8b" in shard.model_id.lower() else "70B")
     self.model = build_transformer(model_path, shard, model_size="8B" if "8b" in shard.model_id.lower() else "70B")
-    self.tokenizer = AutoTokenizer.from_pretrained(str((model_path if model_path.is_dir() else model_path.parent)))
+    self.tokenizer = await resolve_tokenizer(str((model_path if model_path.is_dir() else model_path.parent)))
     self.shard = shard
     self.shard = shard
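
The model-size heuristic in the hunk above keys off the shard's model id; a minimal sketch with an assumed id:

model_id = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
model_size = "8B" if "8b" in model_id.lower() else "70B"
assert model_size == "8B"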

+ 76 - 42
exo/inference/tinygrad/models/llama.py

@@ -2,21 +2,24 @@ from typing import Tuple, Union, Optional, Dict, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 from tinygrad.helpers import getenv
 
 
+
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
-  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
-  freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
+  freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
+  freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
   # TODO: move dtype outside this
-  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim//2, 2)
+  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
+
 
 
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 def complex_mult(A, c, d):
 def complex_mult(A, c, d):
-  a,b = A[..., 0:1], A[..., 1:2]
+  a, b = A[..., 0:1], A[..., 1:2]
   ro = a*c - b*d
   ro = a*c - b*d
   co = a*d + b*c
   co = a*d + b*c
   return ro.cat(co, dim=-1)
   return ro.cat(co, dim=-1)
 
 
-def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
+
+def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
   assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
@@ -26,26 +29,28 @@ def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Te
   xk_out = complex_mult(xk, c, d)
   xk_out = complex_mult(xk, c, d)
   return xq_out.flatten(3), xk_out.flatten(3)
   return xq_out.flatten(3), xk_out.flatten(3)
 
 
-def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
+
+def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
   bs, seqlen, n_kv_heads, head_dim = x.shape
   bs, seqlen, n_kv_heads, head_dim = x.shape
   if n_rep == 1: return x
   if n_rep == 1: return x
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
-  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
+  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
+
 
 
 class Attention:
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
     self.n_heads = n_heads
-    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
+    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
     self.head_dim = dim // n_heads
     self.head_dim = dim // n_heads
     self.n_rep = self.n_heads // self.n_kv_heads
     self.n_rep = self.n_heads // self.n_kv_heads
     self.max_context = max_context
     self.max_context = max_context
 
 
-    self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
-    self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
-    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
-    self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
+    self.wq = linear(dim, self.n_heads*self.head_dim, bias=False)
+    self.wk = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
+    self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
+    self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)
 
 
-  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
     if getenv("WQKV"):
     if getenv("WQKV"):
       if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
       if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
       xqkv = x @ self.wqkv.T
       xqkv = x @ self.wqkv.T
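
A hedged shape check for the repeat_kv trick in the hunk above, using assumed GQA dimensions (8 KV heads repeated 4x to serve 32 query heads):

from tinygrad import Tensor

x = Tensor.zeros(1, 10, 8, 128)                      # (bs, seqlen, n_kv_heads, head_dim)
y = x.repeat((1, 1, 1, 4)).reshape(1, 10, 8*4, 128)  # not the same as repeating on axis 2
assert y.shape == (1, 10, 32, 128)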
@@ -69,10 +74,10 @@ class Attention:
 
 
     # update the cache
     # update the cache
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
+    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
 
 
-    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
+    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
+    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
 
 
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -80,26 +85,29 @@ class Attention:
     attn = attn.reshape(bsz, seqlen, -1)
     attn = attn.reshape(bsz, seqlen, -1)
     return self.wo(attn)
     return self.wo(attn)
 
 
+
 class FeedForward:
 class FeedForward:
-  def __init__(self, dim:int, hidden_dim:int, linear=nn.Linear):
+  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
     self.w1 = linear(dim, hidden_dim, bias=False)
     self.w1 = linear(dim, hidden_dim, bias=False)
     self.w2 = linear(hidden_dim, dim, bias=False)
     self.w2 = linear(hidden_dim, dim, bias=False)
-    self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
+    self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit
+
+  def __call__(self, x: Tensor) -> Tensor:
+    return self.w2(self.w1(x).silu()*self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]
 
 
-  def __call__(self, x:Tensor) -> Tensor:
-    return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
 
 
 class TransformerBlock:
 class TransformerBlock:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
+  def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int, linear=nn.Linear, feed_forward=FeedForward):
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
 
 
-  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
 
 
+
 # standard openai sampling
 # standard openai sampling
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert logits.ndim == 1, "only works on 1d tensors"
   assert logits.ndim == 1, "only works on 1d tensors"
@@ -113,20 +121,20 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   if af or ap:
   if af or ap:
     if not hasattr(sample, "alpha_counter"):
     if not hasattr(sample, "alpha_counter"):
       setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
       setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
-    logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0) * ap)
+    logits = logits - (sample.alpha_counter*af + (sample.alpha_counter > 0)*ap)
 
 
   # replace NaNs with -inf
   # replace NaNs with -inf
   logits = (logits != logits).where(-float("inf"), logits)
   logits = (logits != logits).where(-float("inf"), logits)
 
 
   # softmax
   # softmax
-  t = (logits / temp).softmax()
+  t = (logits/temp).softmax()
 
 
    counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
    # top k
    if k:
      output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
      for i in range(k):
-      t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
+      t_argmax = (t.numel() - ((t == (t_max := t.max()))*counter2).max() - 1).cast(dtypes.default_int)
        output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
        output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
        t = (counter == t_argmax).where(0, t)
@@ -134,8 +142,8 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
      # approximate top p
      # because we are already limited to top k elements we can do top p "without sorting"
      output_cumsum = output[::-1]._cumsum()[::-1] + t.sum()
-    output = (output_cumsum >= (1 - p)) * output
-    output_indices = (output_cumsum >= (1 - p)) * output_indices
+    output = (output_cumsum >= (1 - p))*output
+    output_indices = (output_cumsum >= (1 - p))*output_indices
 
 
      # sample
      output_idx = output.multinomial()
@@ -149,23 +157,40 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
 
 
    return output_token

+
  from exo.inference.shard import Shard

+
  class Transformer:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, shard: Shard=None, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
+  def __init__(
+    self,
+    dim: int,
+    hidden_dim: int,
+    n_heads: int,
+    n_layers: int,
+    norm_eps: float,
+    vocab_size,
+    shard: Shard = None,
+    linear=nn.Linear,
+    n_kv_heads=None,
+    rope_theta=10000,
+    max_context=1024,
+    jit=True,
+    feed_forward=FeedForward
+  ):
      self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
      self.norm = nn.RMSNorm(dim, norm_eps)
      self.tok_embeddings = nn.Embedding(vocab_size, dim)
      self.output = nn.Linear(dim, vocab_size, bias=False)
      self.max_context = max_context
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous()
+    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta).contiguous()
      self.forward_jit = TinyJit(self.forward) if jit else None
      self.shard = shard

-  def forward(self, x:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
+  def forward(self, x: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
      seqlen = x.shape[1]
-    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
-    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos+1).realize() if seqlen > 1 else None
+    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
+    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
 
 
      if self.shard.is_first_layer():
        h = self.tok_embeddings(x)
@@ -182,24 +207,32 @@ class Transformer:
      else:
        return h

-  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
+  def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
      # TODO: better way to handle the first call v.s. the rest?
-    if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
+    if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
        return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
      return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)

+
  # *** helpers ***

-def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
+
+def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
    def permute(v: Tensor, n_heads: int):
      return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])

    keymap = {
      "model.embed_tokens.weight": "tok_embeddings.weight",
-    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
-    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
+    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
+       for l in range(len(model.layers))},
+    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
+       for x in ["q", "k", "v", "o"]
+       for l in range(len(model.layers))},
+    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
+       for l in range(len(model.layers))},
+    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
+       for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
+       for l in range(len(model.layers))},
      "model.norm.weight": "norm.weight",
      "lm_head.weight": "output.weight",
    }
@@ -215,9 +248,10 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
      sd[keymap[k]] = v
    return sd

-def fix_bf16(weights:Dict[Any, Tensor]):
+
+def fix_bf16(weights: Dict[Any, Tensor]):
    if getenv("SUPPORT_BF16", 1):
      # TODO: without casting to float16, 70B llama OOM on tinybox.
-    return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
    # TODO: check if device supports bf16
-  return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}

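Side note on the FeedForward block reformatted above: it is a SwiGLU layer, w2(silu(w1(x)) * w3(x)). A minimal standalone sketch, assuming tinygrad's Tensor/nn API as used in this file (dimensions are illustrative, not taken from any model config):

  from tinygrad import Tensor, nn

  class FeedForwardSketch:
    def __init__(self, dim: int, hidden_dim: int):
      self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # up projection
      self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection
      self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # the gate
    def __call__(self, x: Tensor) -> Tensor:
      return self.w2(self.w1(x).silu()*self.w3(x))  # SwiGLU

  x = Tensor.randn(1, 8)
  print(FeedForwardSketch(8, 32)(x).shape)  # (1, 8)
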
+ 7 - 3
exo/inference/tinygrad/tinygrad_helpers.py

@@ -8,6 +8,7 @@ from exo.helpers import DEBUG
  from exo.download.hf.hf_helpers import get_allow_patterns
  from fnmatch import fnmatch

+
  # **** helper functions ****
  def concat_weights(models, device=None):
    def convert(name) -> Tensor:
@@ -17,11 +18,14 @@ def concat_weights(models, device=None):
      axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
      lazy_tensors = [data.to(device=device) for data in disk_tensors]
      return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
+
    return {name: convert(name) for name in {name: None for model in models for name in model}}

-def load(fn:str, shard: Shard):
+
+def load(fn: str, shard: Shard):
    if fn.endswith('.index.json'):
-    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
+    with open(fn) as fp:
+      weight_map = json.load(fp)['weight_map']
      parts = {}
      filtered_weight_map = {}
      allow_patterns = get_allow_patterns(weight_map, shard)
@@ -33,7 +37,7 @@ def load(fn:str, shard: Shard):
          if layer_num < shard.start_layer or layer_num > shard.end_layer:
            continue

-      parts[n] = load(str(Path(fn).parent / Path(n).name), shard)
+      parts[n] = load(str(Path(fn).parent/Path(n).name), shard)
        filtered_weight_map[k] = n
      if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
      return {k: parts[n][k] for k, n in filtered_weight_map.items()}

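The load() change above keeps the behaviour but, for sharded .index.json checkpoints, only the files whose keys fall inside the shard's layer range are loaded. A rough sketch of that filtering step (key names and the layer range are illustrative, not taken from a real checkpoint):

  weight_map = {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.31.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
  }
  start_layer, end_layer = 0, 15

  def layer_of(key: str):
    parts = key.split(".")
    return int(parts[2]) if parts[:2] == ["model", "layers"] else None

  kept = {k: v for k, v in weight_map.items()
          if layer_of(k) is None or start_layer <= layer_of(k) <= end_layer}
  print(sorted(set(kept.values())))  # only the shard's weight files need to be fetched
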
+ 41 - 0
exo/inference/tokenizers.py

@@ -0,0 +1,41 @@
+import traceback
+from aiofiles import os as aios
+from transformers import AutoTokenizer, AutoProcessor
+from exo.download.hf.hf_helpers import get_local_snapshot_dir
+from exo.helpers import DEBUG
+
+async def resolve_tokenizer(model_id: str):
+  local_path = await get_local_snapshot_dir(model_id)
+  if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
+  try:
+    if await aios.path.exists(local_path):
+      if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
+      return await _resolve_tokenizer(local_path)
+  except:
+    if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
+    if DEBUG >= 5: traceback.print_exc()
+  return await _resolve_tokenizer(model_id)
+
+async def _resolve_tokenizer(model_id_or_local_path: str):
+  try:
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
+    processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in model_id_or_local_path else False)
+    if not hasattr(processor, 'eos_token_id'):
+      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
+    if not hasattr(processor, 'encode'):
+      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
+    if not hasattr(processor, 'decode'):
+      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
+    return processor
+  except Exception as e:
+    if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())
+
+  try:
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
+    return AutoTokenizer.from_pretrained(model_id_or_local_path)
+  except Exception as e:
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())
+
+  raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")

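exo/inference/tokenizers.py is new: resolve_tokenizer() prefers a locally downloaded snapshot, then falls back to resolving the model id directly, trying AutoProcessor before AutoTokenizer. A usage sketch (assuming the tokenizer files are reachable via Hugging Face or a local snapshot):

  import asyncio
  from exo.inference.tokenizers import resolve_tokenizer

  async def main():
    tokenizer = await resolve_tokenizer("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
    ids = tokenizer.encode("hello exo")
    print(ids, tokenizer.decode(ids))

  asyncio.run(main())
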
+ 29 - 0
exo/models.py

@@ -0,0 +1,29 @@
+from exo.inference.shard import Shard
+
+model_base_shards = {
+  ### llama
+  "llama-3.1-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3.1-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
+  },
+  "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
+  "llama-3-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
+  },
+  ### mistral
+  "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
+  "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),},
+  ### deepseek v2
+  "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),},
+  ### llava
+  "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),},
+}

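exo/models.py is also new: it maps a human-friendly model name to one base Shard per inference engine, with start_layer/end_layer left at 0 (the real layer range is presumably assigned later by the partitioning strategy). Looking up a shard, assuming the inner key is the engine's class name:

  from exo.models import model_base_shards

  shard = model_base_shards["llama-3.1-8b"]["MLXDynamicShardInferenceEngine"]
  print(shard.model_id, shard.n_layers)  # mlx-community/Meta-Llama-3.1-8B-Instruct-4bit 32
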
+ 1 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -27,7 +27,7 @@ class GRPCPeerHandle(PeerHandle):
      return self._device_capabilities

    async def connect(self):
-    self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32 * 1024 * 1024)])
+    self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
      self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)

    async def is_connected(self) -> bool:

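The grpc_peer_handle.py change is formatting only: the peer channel is still created with a 32 MiB metadata cap. For reference, the equivalent standalone call (the address is illustrative):

  import grpc

  channel = grpc.aio.insecure_channel(
    "192.168.1.10:8080",
    options=[("grpc.max_metadata_size", 32*1024*1024)],
  )
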
+ 17 - 14
exo/networking/grpc/grpc_server.py

@@ -1,6 +1,7 @@
  import grpc
  from concurrent import futures
  import numpy as np
+from asyncio import CancelledError
 
 
  from . import node_service_pb2
  from . import node_service_pb2_grpc
@@ -20,9 +21,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
      self.server = grpc.aio.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=[
-        ("grpc.max_metadata_size", 32 * 1024 * 1024),
-        ("grpc.max_send_message_length", 128 * 1024 * 1024),
-        ("grpc.max_receive_message_length", 128 * 1024 * 1024),
+        ("grpc.max_metadata_size", 32*1024*1024),
+        ("grpc.max_send_message_length", 128*1024*1024),
+        ("grpc.max_receive_message_length", 128*1024*1024),
        ],
      )
      node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
@@ -33,8 +34,11 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
 
 
    async def stop(self) -> None:
      if self.server:
-      await self.server.stop(grace=5)
-      await self.server.wait_for_termination()
+      try:
+        await self.server.stop(grace=5)
+        await self.server.wait_for_termination()
+      except CancelledError:
+        pass
        if DEBUG >= 1: print("Server stopped and all connections are closed")

    async def SendPrompt(self, request, context):
@@ -77,9 +81,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
        node_service_pb2.InferenceResult(
          tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)),
          is_finished=result[1],
-      )
-      if result[0] is not None
-      else node_service_pb2.InferenceResult(is_finished=result[1])
+      ) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1])
      )

    async def CollectTopology(self, request, context):
@@ -87,12 +89,13 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
      visited = set(request.visited)
      topology = await self.node.collect_topology(visited, max_depth)
      nodes = {
-      node_id: node_service_pb2.DeviceCapabilities(
-        model=cap.model,
-        chip=cap.chip,
-        memory=cap.memory,
-        flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
-      )
+      node_id:
+        node_service_pb2.DeviceCapabilities(
+          model=cap.model,
+          chip=cap.chip,
+          memory=cap.memory,
+          flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
+        )
        for node_id, cap in topology.nodes.items()
      }
      peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}

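The functional change in grpc_server.py is the shutdown path: stop() now swallows asyncio.CancelledError so tearing the node down mid-shutdown does not raise. A minimal sketch of that pattern against a bare grpc.aio server (ports and handlers omitted):

  from asyncio import CancelledError
  import grpc

  async def stop_server(server: grpc.aio.Server) -> None:
    try:
      await server.stop(grace=5)           # let in-flight RPCs finish
      await server.wait_for_termination()
    except CancelledError:
      pass                                 # shutdown itself was cancelled; ignore
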
Diff file is too large
+ 0 - 3
exo/networking/grpc/node_service_pb2.py


+ 228 - 273
exo/networking/grpc/node_service_pb2_grpc.py

@@ -12,306 +12,261 @@ SCHEDULED_RELEASE_DATE = 'June 25, 2024'
  _version_not_supported = False

  try:
-    from grpc._utilities import first_version_is_lower
-    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+  from grpc._utilities import first_version_is_lower
+  _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
  except ImportError:
-    _version_not_supported = True
+  _version_not_supported = True
 
 
  if _version_not_supported:
-    warnings.warn(
-        f'The grpc package installed is at version {GRPC_VERSION},'
-        + f' but the generated code in node_service_pb2_grpc.py depends on'
-        + f' grpcio>={GRPC_GENERATED_VERSION}.'
-        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
-        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
-        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
-        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
-        RuntimeWarning
-    )
+  warnings.warn(
+    f'The grpc package installed is at version {GRPC_VERSION},' + f' but the generated code in node_service_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' +
+    f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' +
+    f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', RuntimeWarning
+  )
 
 
 
 
  class NodeServiceStub(object):
-    """Missing associated documentation comment in .proto file."""
-
-    def __init__(self, channel):
-        """Constructor.
+  """Missing associated documentation comment in .proto file."""
+  def __init__(self, channel):
+    """Constructor.
 
 
          Args:
              channel: A grpc.Channel.
          """
-        self.SendPrompt = channel.unary_unary(
-                '/node_service.NodeService/SendPrompt',
-                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
-                _registered_method=True)
-        self.SendTensor = channel.unary_unary(
-                '/node_service.NodeService/SendTensor',
-                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
-                _registered_method=True)
-        self.GetInferenceResult = channel.unary_unary(
-                '/node_service.NodeService/GetInferenceResult',
-                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.InferenceResult.FromString,
-                _registered_method=True)
-        self.CollectTopology = channel.unary_unary(
-                '/node_service.NodeService/CollectTopology',
-                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Topology.FromString,
-                _registered_method=True)
-        self.SendResult = channel.unary_unary(
-                '/node_service.NodeService/SendResult',
-                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
-        self.SendOpaqueStatus = channel.unary_unary(
-                '/node_service.NodeService/SendOpaqueStatus',
-                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
+    self.SendPrompt = channel.unary_unary(
+      '/node_service.NodeService/SendPrompt',
+      request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Tensor.FromString,
+      _registered_method=True
+    )
+    self.SendTensor = channel.unary_unary(
+      '/node_service.NodeService/SendTensor',
+      request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Tensor.FromString,
+      _registered_method=True
+    )
+    self.GetInferenceResult = channel.unary_unary(
+      '/node_service.NodeService/GetInferenceResult',
+      request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+      response_deserializer=node__service__pb2.InferenceResult.FromString,
+      _registered_method=True
+    )
+    self.CollectTopology = channel.unary_unary(
+      '/node_service.NodeService/CollectTopology',
+      request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Topology.FromString,
+      _registered_method=True
+    )
+    self.SendResult = channel.unary_unary(
+      '/node_service.NodeService/SendResult',
+      request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Empty.FromString,
+      _registered_method=True
+    )
+    self.SendOpaqueStatus = channel.unary_unary(
+      '/node_service.NodeService/SendOpaqueStatus',
+      request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Empty.FromString,
+      _registered_method=True
+    )
 
 
 
 
  class NodeServiceServicer(object):
+  """Missing associated documentation comment in .proto file."""
+  def SendPrompt(self, request, context):
      """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
-    def SendPrompt(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def SendTensor(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendTensor(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
-    def GetInferenceResult(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def GetInferenceResult(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
-    def CollectTopology(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def CollectTopology(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
-    def SendResult(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendResult(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
-    def SendOpaqueStatus(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendOpaqueStatus(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
 
 
  def add_NodeServiceServicer_to_server(servicer, server):
-    rpc_method_handlers = {
-            'SendPrompt': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendPrompt,
-                    request_deserializer=node__service__pb2.PromptRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
-            ),
-            'SendTensor': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendTensor,
-                    request_deserializer=node__service__pb2.TensorRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
-            ),
-            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.GetInferenceResult,
-                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
-                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
-            ),
-            'CollectTopology': grpc.unary_unary_rpc_method_handler(
-                    servicer.CollectTopology,
-                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
-                    response_serializer=node__service__pb2.Topology.SerializeToString,
-            ),
-            'SendResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendResult,
-                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
-            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendOpaqueStatus,
-                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
-    }
-    generic_handler = grpc.method_handlers_generic_handler(
-            'node_service.NodeService', rpc_method_handlers)
-    server.add_generic_rpc_handlers((generic_handler,))
-    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
+  rpc_method_handlers = {
+    'SendPrompt':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendPrompt,
+        request_deserializer=node__service__pb2.PromptRequest.FromString,
+        response_serializer=node__service__pb2.Tensor.SerializeToString,
+      ),
+    'SendTensor':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendTensor,
+        request_deserializer=node__service__pb2.TensorRequest.FromString,
+        response_serializer=node__service__pb2.Tensor.SerializeToString,
+      ),
+    'GetInferenceResult':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.GetInferenceResult,
+        request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+        response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+      ),
+    'CollectTopology':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.CollectTopology,
+        request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+        response_serializer=node__service__pb2.Topology.SerializeToString,
+      ),
+    'SendResult':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendResult,
+        request_deserializer=node__service__pb2.SendResultRequest.FromString,
+        response_serializer=node__service__pb2.Empty.SerializeToString,
+      ),
+    'SendOpaqueStatus':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendOpaqueStatus,
+        request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+        response_serializer=node__service__pb2.Empty.SerializeToString,
+      ),
+  }
+  generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers)
+  server.add_generic_rpc_handlers((generic_handler,))
+  server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
 
 
 
 
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
  class NodeService(object):
-    """Missing associated documentation comment in .proto file."""
-
-    @staticmethod
-    def SendPrompt(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendPrompt',
-            node__service__pb2.PromptRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  """Missing associated documentation comment in .proto file."""
+  @staticmethod
+  def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendPrompt',
+      node__service__pb2.PromptRequest.SerializeToString,
+      node__service__pb2.Tensor.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
 
-    @staticmethod
-    def SendTensor(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendTensor',
-            node__service__pb2.TensorRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendTensor(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendTensor',
+      node__service__pb2.TensorRequest.SerializeToString,
+      node__service__pb2.Tensor.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
 
-    @staticmethod
-    def GetInferenceResult(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/GetInferenceResult',
-            node__service__pb2.GetInferenceResultRequest.SerializeToString,
-            node__service__pb2.InferenceResult.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def GetInferenceResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/GetInferenceResult',
+      node__service__pb2.GetInferenceResultRequest.SerializeToString,
+      node__service__pb2.InferenceResult.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
 
-    @staticmethod
-    def CollectTopology(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/CollectTopology',
-            node__service__pb2.CollectTopologyRequest.SerializeToString,
-            node__service__pb2.Topology.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def CollectTopology(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/CollectTopology',
+      node__service__pb2.CollectTopologyRequest.SerializeToString,
+      node__service__pb2.Topology.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
 
-    @staticmethod
-    def SendResult(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendResult',
-            node__service__pb2.SendResultRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendResult',
+      node__service__pb2.SendResultRequest.SerializeToString,
+      node__service__pb2.Empty.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
 
-    @staticmethod
-    def SendOpaqueStatus(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendOpaqueStatus',
-            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendOpaqueStatus(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendOpaqueStatus',
+      node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+      node__service__pb2.Empty.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )

+ 9 - 11
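node_service_pb2_grpc.py is regenerated output reformatted to the project's 2-space style; behaviour is unchanged. For orientation, a client would drive it roughly like this (the request field names are inferred from the server code above, so treat them as assumptions):

  import asyncio
  import grpc
  from exo.networking.grpc import node_service_pb2, node_service_pb2_grpc

  async def collect(address: str):
    async with grpc.aio.insecure_channel(address) as channel:
      stub = node_service_pb2_grpc.NodeServiceStub(channel)
      request = node_service_pb2.CollectTopologyRequest(visited=[], max_depth=4)
      print(await stub.CollectTopology(request))

  asyncio.run(collect("localhost:8080"))
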
exo/networking/udp_discovery.py

@@ -98,14 +98,12 @@ class UDPDiscovery(Discovery):
      sock = transport.get_extra_info("socket")
      sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
 
 
-    message = json.dumps(
-      {
-        "type": "discovery",
-        "node_id": self.node_id,
-        "grpc_port": self.node_port,
-        "device_capabilities": self.device_capabilities.to_dict(),
-      }
-    ).encode("utf-8")
+    message = json.dumps({
+      "type": "discovery",
+      "node_id": self.node_id,
+      "grpc_port": self.node_port,
+      "device_capabilities": self.device_capabilities.to_dict(),
+    }).encode("utf-8")
 
 
      while True:
        try:
@@ -165,14 +163,14 @@ class UDPDiscovery(Discovery):
        try:
          current_time = time.time()
          peers_to_remove = [
-          peer_handle.id()
-          for peer_handle, connected_at, last_seen in self.known_peers.values()
+          peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
            if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
          ]
          if DEBUG_DISCOVERY >= 2:
            print(
              "Peer statuses:",
-            {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()},
+            {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}"
+             for peer_handle, connected_at, last_seen in self.known_peers.values()},
            )
          if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0:
            print(f"Cleaning up peers: {peers_to_remove}")

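For context, the discovery payload built above is a small JSON blob broadcast over UDP. A hedged sketch of the send side (the port, node id and capabilities dict are illustrative; the real values come from the node's configuration and device_capabilities().to_dict()):

  import json, socket

  message = json.dumps({
    "type": "discovery",
    "node_id": "node-abc123",
    "grpc_port": 8080,
    "device_capabilities": {"model": "MacBook Pro", "chip": "Apple M2", "memory": 16384,
                            "flops": {"fp32": 3.55, "fp16": 7.10, "int8": 14.20}},
  }).encode("utf-8")

  sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
  sock.sendto(message, ("255.255.255.255", 5678))
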
+ 63 - 70
exo/orchestration/standard_node.py

@@ -26,9 +26,7 @@ class StandardNode(Node):
      discovery: Discovery,
      partitioning_strategy: PartitioningStrategy = None,
      max_generate_tokens: int = 1024,
-    chatgpt_api_endpoints: List[str] = [],
-    web_chat_urls: List[str] = [],
-    disable_tui: Optional[bool] = False,
+    topology_viz: Optional[TopologyViz] = None,
    ):
      self.id = _id
      self.inference_engine = inference_engine
@@ -39,13 +37,25 @@ class StandardNode(Node):
      self.topology: Topology = Topology()
      self.device_capabilities = device_capabilities()
      self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-    self.topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not disable_tui else None
      self.max_generate_tokens = max_generate_tokens
+    self.topology_viz = topology_viz
      self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
      self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
      self._on_opaque_status.register("node_status").on_next(self.on_node_status)
      self.node_download_progress: Dict[str, RepoProgressEvent] = {}

+  async def start(self, wait_for_peers: int = 0) -> None:
+    await self.server.start()
+    await self.discovery.start()
+    await self.update_peers(wait_for_peers)
+    await self.collect_topology()
+    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
+    asyncio.create_task(self.periodic_topology_collection(5))
+
+  async def stop(self) -> None:
+    await self.discovery.stop()
+    await self.server.stop()
+
    def on_node_status(self, request_id, opaque_status):
      try:
        status_data = json.loads(opaque_status)
@@ -66,36 +76,22 @@ class StandardNode(Node):
        if DEBUG >= 1: print(f"Error updating visualization: {e}")
        if DEBUG >= 1: traceback.print_exc()

-  async def start(self, wait_for_peers: int = 0) -> None:
-    await self.server.start()
-    await self.discovery.start()
-    await self.update_peers(wait_for_peers)
-    await self.collect_topology()
-    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
-    asyncio.create_task(self.periodic_topology_collection(5))
-
-  async def stop(self) -> None:
-    await self.discovery.stop()
-    await self.server.stop()
-
    async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
      shard = self.get_current_shard(base_shard)
      asyncio.create_task(
        self.broadcast_opaque_status(
          request_id,
-        json.dumps(
-          {
-            "type": "node_status",
-            "node_id": self.id,
-            "status": "start_process_prompt",
-            "base_shard": base_shard.to_dict(),
-            "shard": shard.to_dict(),
-            "prompt": prompt,
-            "image_str": image_str,
-            "inference_state": inference_state,
-            "request_id": request_id,
-          }
-        ),
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "start_process_prompt",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "prompt": prompt,
+          "image_str": image_str,
+          "inference_state": inference_state,
+          "request_id": request_id,
+        }),
        )
      )
      start_time = time.perf_counter_ns()
@@ -105,21 +101,19 @@ class StandardNode(Node):
      asyncio.create_task(
        self.broadcast_opaque_status(
          request_id,
-        json.dumps(
-          {
-            "type": "node_status",
-            "node_id": self.id,
-            "status": "end_process_prompt",
-            "base_shard": base_shard.to_dict(),
-            "shard": shard.to_dict(),
-            "prompt": prompt,
-            "image_str": image_str,
-            "inference_state": inference_state,
-            "request_id": request_id,
-            "elapsed_time_ns": elapsed_time_ns,
-            "result_size": resp.size if resp is not None else 0,
-          }
-        ),
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "end_process_prompt",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "prompt": prompt,
+          "image_str": image_str,
+          "inference_state": inference_state,
+          "request_id": request_id,
+          "elapsed_time_ns": elapsed_time_ns,
+          "result_size": resp.size if resp is not None else 0,
+        }),
        )
      )
      return resp
@@ -165,19 +159,17 @@ class StandardNode(Node):
      asyncio.create_task(
        self.broadcast_opaque_status(
          request_id,
-        json.dumps(
-          {
-            "type": "node_status",
-            "node_id": self.id,
-            "status": "start_process_tensor",
-            "base_shard": base_shard.to_dict(),
-            "shard": shard.to_dict(),
-            "tensor_size": tensor.size,
-            "tensor_shape": tensor.shape,
-            "request_id": request_id,
-            "inference_state": inference_state,
-          }
-        ),
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "start_process_tensor",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "tensor_size": tensor.size,
+          "tensor_shape": tensor.shape,
+          "request_id": request_id,
+          "inference_state": inference_state,
+        }),
        )
      )
      start_time = time.perf_counter_ns()
@@ -187,18 +179,16 @@ class StandardNode(Node):
      asyncio.create_task(
        self.broadcast_opaque_status(
          request_id,
-        json.dumps(
-          {
-            "type": "node_status",
-            "node_id": self.id,
-            "status": "end_process_tensor",
-            "base_shard": base_shard.to_dict(),
-            "shard": shard.to_dict(),
-            "request_id": request_id,
-            "elapsed_time_ns": elapsed_time_ns,
-            "result_size": resp.size if resp is not None else 0,
-          }
-        ),
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "end_process_tensor",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "request_id": request_id,
+          "elapsed_time_ns": elapsed_time_ns,
+          "result_size": resp.size if resp is not None else 0,
+        }),
        )
      )
      return resp
@@ -256,7 +246,7 @@ class StandardNode(Node):
      current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
      if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
      if current_partition_index is not None:
-      next_partition_index = (current_partition_index + 1) % len(partitions)
+      next_partition_index = (current_partition_index+1) % len(partitions)
        next_partition: Partition = partitions[next_partition_index]
        next_shard = shards[next_partition_index]
        if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
@@ -306,6 +296,7 @@ class StandardNode(Node):
          await self.collect_topology()
        except Exception as e:
          print(f"Error collecting topology: {e}")
+        traceback.print_exc()
 
 
    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
      if request_id not in self.buffered_token_output:
@@ -319,6 +310,7 @@ class StandardNode(Node):
      if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")

      prev_visited = visited.copy()
+    # TODO: should we add our own peer id here?
      visited.update(p.id() for p in self.peers)

      for peer in self.peers:
@@ -371,6 +363,7 @@ class StandardNode(Node):
 
 
    async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
      if DEBUG >= 5: print(f"Broadcasting opaque status: {request_id=} {status=}")
+
      async def send_status_to_peer(peer):
        try:
          await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)

+ 1 - 1
exo/stats/metrics.py

@@ -24,6 +24,6 @@ def start_metrics_server(node: Node, port: int):
    elif status == "end_process_tensor":
      elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
      PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
-      PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns / 1e9)  # Convert ns to seconds
+      PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns/1e9)  # Convert ns to seconds

  node.on_opaque_status.register("stats").on_next(_on_opaque_status)
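A minimal sketch (not part of this diff) of the pattern the stats hook follows, assuming prometheus_client metrics; the metric string names below are illustrative, and the histogram is fed seconds, hence the ns/1e9 conversion above:

  from prometheus_client import Counter, Histogram

  PROCESS_TENSOR_COUNTER = Counter("process_tensor_total", "Tensors processed", ["node_id"])
  PROCESS_TENSOR_TIME = Histogram("process_tensor_seconds", "Time spent processing tensors", ["node_id"])

  def record_end_process_tensor(node_id: str, elapsed_time_ns: int) -> None:
    # Count the event and record its duration in seconds (Prometheus convention).
    PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
    PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns/1e9)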

+ 73 - 64
exo/topology/device_capabilities.py

@@ -44,83 +44,92 @@ CHIP_FLOPS = {
  # Source: https://www.cpu-monkey.com
  # Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative
  ### M chips
-  "Apple M1": DeviceFlops(fp32=2.29 * TFLOPS, fp16=4.58 * TFLOPS, int8=9.16 * TFLOPS),
-  "Apple M1 Pro": DeviceFlops(fp32=5.30 * TFLOPS, fp16=10.60 * TFLOPS, int8=21.20 * TFLOPS),
-  "Apple M1 Max": DeviceFlops(fp32=10.60 * TFLOPS, fp16=21.20 * TFLOPS, int8=42.40 * TFLOPS),
-  "Apple M1 Ultra": DeviceFlops(fp32=21.20 * TFLOPS, fp16=42.40 * TFLOPS, int8=84.80 * TFLOPS),
-  "Apple M2": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
-  "Apple M2 Pro": DeviceFlops(fp32=5.68 * TFLOPS, fp16=11.36 * TFLOPS, int8=22.72 * TFLOPS),
-  "Apple M2 Max": DeviceFlops(fp32=13.49 * TFLOPS, fp16=26.98 * TFLOPS, int8=53.96 * TFLOPS),
-  "Apple M2 Ultra": DeviceFlops(fp32=26.98 * TFLOPS, fp16=53.96 * TFLOPS, int8=107.92 * TFLOPS),
-  "Apple M3": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
-  "Apple M3 Max": DeviceFlops(fp32=14.20 * TFLOPS, fp16=28.40 * TFLOPS, int8=56.80 * TFLOPS),
-  "Apple M3 Pro": DeviceFlops(fp32=4.97 * TFLOPS, fp16=9.94 * TFLOPS, int8=19.88 * TFLOPS),
-  "Apple M4": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
+  "Apple M1": DeviceFlops(fp32=2.29*TFLOPS, fp16=4.58*TFLOPS, int8=9.16*TFLOPS),
+  "Apple M1 Pro": DeviceFlops(fp32=5.30*TFLOPS, fp16=10.60*TFLOPS, int8=21.20*TFLOPS),
+  "Apple M1 Max": DeviceFlops(fp32=10.60*TFLOPS, fp16=21.20*TFLOPS, int8=42.40*TFLOPS),
+  "Apple M1 Ultra": DeviceFlops(fp32=21.20*TFLOPS, fp16=42.40*TFLOPS, int8=84.80*TFLOPS),
+  "Apple M2": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
+  "Apple M2 Pro": DeviceFlops(fp32=5.68*TFLOPS, fp16=11.36*TFLOPS, int8=22.72*TFLOPS),
+  "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS),
+  "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS),
+  "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
+  "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
+  "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS),
+  "Apple M4": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
  ### A chips
-  "Apple A13 Bionic": DeviceFlops(fp32=0.69 * TFLOPS, fp16=1.38 * TFLOPS, int8=2.76 * TFLOPS),
-  "Apple A14 Bionic": DeviceFlops(fp32=0.75 * TFLOPS, fp16=1.50 * TFLOPS, int8=3.00 * TFLOPS),
-  "Apple A15 Bionic": DeviceFlops(fp32=1.37 * TFLOPS, fp16=2.74 * TFLOPS, int8=5.48 * TFLOPS),
-  "Apple A16 Bionic": DeviceFlops(fp32=1.79 * TFLOPS, fp16=3.58 * TFLOPS, int8=7.16 * TFLOPS),
-  "Apple A17 Pro": DeviceFlops(fp32=2.15 * TFLOPS, fp16=4.30 * TFLOPS, int8=8.60 * TFLOPS),
+  "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS),
+  "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS),
+  "Apple A15 Bionic": DeviceFlops(fp32=1.37*TFLOPS, fp16=2.74*TFLOPS, int8=5.48*TFLOPS),
+  "Apple A16 Bionic": DeviceFlops(fp32=1.79*TFLOPS, fp16=3.58*TFLOPS, int8=7.16*TFLOPS),
+  "Apple A17 Pro": DeviceFlops(fp32=2.15*TFLOPS, fp16=4.30*TFLOPS, int8=8.60*TFLOPS),
  ### NVIDIA GPUs
  # RTX 40 series
-  "NVIDIA GEFORCE RTX 4090": DeviceFlops(fp32=82.58 * TFLOPS, fp16=165.16 * TFLOPS, int8=330.32 * TFLOPS),
-  "NVIDIA GEFORCE RTX 4080": DeviceFlops(fp32=48.74 * TFLOPS, fp16=97.48 * TFLOPS, int8=194.96 * TFLOPS),
-  "NVIDIA GEFORCE RTX 4080 SUPER": DeviceFlops(fp32=52.0 * TFLOPS, fp16=104.0 * TFLOPS, int8=208.0 * TFLOPS),
-  "NVIDIA GEFORCE RTX 4070 TI SUPER": DeviceFlops(fp32=40.0 * TFLOPS, fp16=80.0 * TFLOPS, int8=160.0 * TFLOPS),
-  "NVIDIA GEFORCE RTX 4070 TI": DeviceFlops(fp32=39.43 * TFLOPS, fp16=78.86 * TFLOPS, int8=157.72 * TFLOPS),
-  "NVIDIA GEFORCE RTX 4070 SUPER": DeviceFlops(fp32=30.0 * TFLOPS, fp16=60.0 * TFLOPS, int8=120.0 * TFLOPS),
-  "NVIDIA GEFORCE RTX 4070": DeviceFlops(fp32=29.0 * TFLOPS, fp16=58.0 * TFLOPS, int8=116.0 * TFLOPS),
-  "NVIDIA GEFORCE RTX 4060 TI 16GB": DeviceFlops(fp32=22.0 * TFLOPS, fp16=44.0 * TFLOPS, int8=88.0 * TFLOPS),
+  "NVIDIA GEFORCE RTX 4090": DeviceFlops(fp32=82.58*TFLOPS, fp16=165.16*TFLOPS, int8=330.32*TFLOPS),
+  "NVIDIA GEFORCE RTX 4080": DeviceFlops(fp32=48.74*TFLOPS, fp16=97.48*TFLOPS, int8=194.96*TFLOPS),
+  "NVIDIA GEFORCE RTX 4080 SUPER": DeviceFlops(fp32=52.0*TFLOPS, fp16=104.0*TFLOPS, int8=208.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 4070 TI SUPER": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 4070 TI": DeviceFlops(fp32=39.43*TFLOPS, fp16=78.86*TFLOPS, int8=157.72*TFLOPS),
+  "NVIDIA GEFORCE RTX 4070 SUPER": DeviceFlops(fp32=30.0*TFLOPS, fp16=60.0*TFLOPS, int8=120.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 4070": DeviceFlops(fp32=29.0*TFLOPS, fp16=58.0*TFLOPS, int8=116.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 4060 TI 16GB": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS),
  # RTX 30 series
-  "NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11 * TFLOPS, fp16=18.22 * TFLOPS, int8=36.44 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0 * TFLOPS, fp16=26.0 * TFLOPS, int8=52.0 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3060 TI": DeviceFlops(fp32=16.2 * TFLOPS, fp16=32.4 * TFLOPS, int8=64.8 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3070": DeviceFlops(fp32=20.3 * TFLOPS, fp16=40.6 * TFLOPS, int8=81.2 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3070 TI": DeviceFlops(fp32=21.8 * TFLOPS, fp16=43.6 * TFLOPS, int8=87.2 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3080 (10 GB)": DeviceFlops(fp32=29.8 * TFLOPS, fp16=59.6 * TFLOPS, int8=119.2 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3080 (12 GB)": DeviceFlops(fp32=30.6 * TFLOPS, fp16=61.2 * TFLOPS, int8=122.4 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3080 TI": DeviceFlops(fp32=34.1 * TFLOPS, fp16=68.2 * TFLOPS, int8=136.4 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3090": DeviceFlops(fp32=35.6 * TFLOPS, fp16=71.2 * TFLOPS, int8=142.4 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3090 TI": DeviceFlops(fp32=40.0 * TFLOPS, fp16=80.0 * TFLOPS, int8=160.0 * TFLOPS),
+  "NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11*TFLOPS, fp16=18.22*TFLOPS, int8=36.44*TFLOPS),
+  "NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0*TFLOPS, fp16=26.0*TFLOPS, int8=52.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 3060 TI": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS),
+  "NVIDIA GEFORCE RTX 3070": DeviceFlops(fp32=20.3*TFLOPS, fp16=40.6*TFLOPS, int8=81.2*TFLOPS),
+  "NVIDIA GEFORCE RTX 3070 TI": DeviceFlops(fp32=21.8*TFLOPS, fp16=43.6*TFLOPS, int8=87.2*TFLOPS),
+  "NVIDIA GEFORCE RTX 3080 (10 GB)": DeviceFlops(fp32=29.8*TFLOPS, fp16=59.6*TFLOPS, int8=119.2*TFLOPS),
+  "NVIDIA GEFORCE RTX 3080 (12 GB)": DeviceFlops(fp32=30.6*TFLOPS, fp16=61.2*TFLOPS, int8=122.4*TFLOPS),
+  "NVIDIA GEFORCE RTX 3080 TI": DeviceFlops(fp32=34.1*TFLOPS, fp16=68.2*TFLOPS, int8=136.4*TFLOPS),
+  "NVIDIA GEFORCE RTX 3090": DeviceFlops(fp32=35.6*TFLOPS, fp16=71.2*TFLOPS, int8=142.4*TFLOPS),
+  "NVIDIA GEFORCE RTX 3090 TI": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS),
+  # RTX 20 series
+  "NVIDIA GEFORCE RTX 2060": DeviceFlops(fp32=6.45*TFLOPS, fp16=12.9*TFLOPS, int8=25.8*TFLOPS),
+  "NVIDIA GEFORCE RTX 2060 SUPER": DeviceFlops(fp32=7.2*TFLOPS, fp16=14.4*TFLOPS, int8=28.8*TFLOPS),
+  "NVIDIA GEFORCE RTX 2070": DeviceFlops(fp32=7.46*TFLOPS, fp16=14.93*TFLOPS, int8=29.86*TFLOPS),
+  "NVIDIA GEFORCE RTX 2070 SUPER": DeviceFlops(fp32=9.06*TFLOPS, fp16=18.12*TFLOPS, int8=36.24*TFLOPS),
+  "NVIDIA GEFORCE RTX 2080": DeviceFlops(fp32=10.07*TFLOPS, fp16=20.14*TFLOPS, int8=40.28*TFLOPS),
+  "NVIDIA GEFORCE RTX 2080 SUPER": DeviceFlops(fp32=11.15*TFLOPS, fp16=22.30*TFLOPS, int8=44.60*TFLOPS),
+  "NVIDIA TITAN RTX": DeviceFlops(fp32=16.31*TFLOPS, fp16=32.62*TFLOPS, int8=65.24*TFLOPS),
  # QUATRO RTX Ampere series
-  "NVIDIA QUATRO RTX A2000": DeviceFlops(fp32=7.99 * TFLOPS, fp16=7.99 * TFLOPS, int8=31.91 * TFLOPS),
-  "NVIDIA QUATRO RTX A4000": DeviceFlops(fp32=19.17 * TFLOPS, fp16=19.17 * TFLOPS, int8=76.68 * TFLOPS),
-  "NVIDIA QUATRO RTX A4500": DeviceFlops(fp32=23.65 * TFLOPS, fp16=23.65 * TFLOPS, int8=94.6 * TFLOPS),
-  "NVIDIA QUATRO RTX A5000": DeviceFlops(fp32=27.8 * TFLOPS, fp16=27.8 * TFLOPS, int8=111.2 * TFLOPS),
-  "NVIDIA QUATRO RTX A6000": DeviceFlops(fp32=38.71 * TFLOPS, fp16=38.71 * TFLOPS, int8=154.84 * TFLOPS),
+  "NVIDIA QUATRO RTX A2000": DeviceFlops(fp32=7.99*TFLOPS, fp16=7.99*TFLOPS, int8=31.91*TFLOPS),
+  "NVIDIA QUATRO RTX A4000": DeviceFlops(fp32=19.17*TFLOPS, fp16=19.17*TFLOPS, int8=76.68*TFLOPS),
+  "NVIDIA QUATRO RTX A4500": DeviceFlops(fp32=23.65*TFLOPS, fp16=23.65*TFLOPS, int8=94.6*TFLOPS),
+  "NVIDIA QUATRO RTX A5000": DeviceFlops(fp32=27.8*TFLOPS, fp16=27.8*TFLOPS, int8=111.2*TFLOPS),
+  "NVIDIA QUATRO RTX A6000": DeviceFlops(fp32=38.71*TFLOPS, fp16=38.71*TFLOPS, int8=154.84*TFLOPS),
  # Common Server GPUs
-  "NVIDIA A40 48GB PCIE": DeviceFlops(fp32=37.4 * TFLOPS, fp16=149.7 * TFLOPS, int8=299.3 * TFLOPS),
-  "NVIDIA A100 40GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
-  "NVIDIA A800 40GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
-  "NVIDIA A100 80GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
-  "NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
-  "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
-  "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
+  "NVIDIA A40 48GB PCIE": DeviceFlops(fp32=37.4*TFLOPS, fp16=149.7*TFLOPS, int8=299.3*TFLOPS),
+  "NVIDIA A100 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A800 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A100 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
  # ... add more devices if needed ...
  ### AMD GPUs
  # RX 6000 series
-  "AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04 * TFLOPS, fp16=46.08 * TFLOPS, int8=92.16 * TFLOPS),
-  "AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74 * TFLOPS, fp16=41.48 * TFLOPS, int8=82.96 * TFLOPS),
-  "AMD Radeon RX 6800": DeviceFlops(fp32=16.17 * TFLOPS, fp16=32.34 * TFLOPS, int8=64.68 * TFLOPS),
-  "AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21 * TFLOPS, fp16=26.42 * TFLOPS, int8=52.84 * TFLOPS),
-  "AMD Radeon RX 6700": DeviceFlops(fp32=11.4 * TFLOPS, fp16=22.8 * TFLOPS, int8=45.6 * TFLOPS),
-  "AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6 * TFLOPS, fp16=21.2 * TFLOPS, int8=42.4 * TFLOPS),
-  "AMD Radeon RX 6600": DeviceFlops(fp32=8.93 * TFLOPS, fp16=17.86 * TFLOPS, int8=35.72 * TFLOPS),
-  "AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77 * TFLOPS, fp16=11.54 * TFLOPS, int8=23.08 * TFLOPS),
-  "AMD Radeon RX 6400": DeviceFlops(fp32=3.57 * TFLOPS, fp16=7.14 * TFLOPS, int8=14.28 * TFLOPS),
+  "AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04*TFLOPS, fp16=46.08*TFLOPS, int8=92.16*TFLOPS),
+  "AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74*TFLOPS, fp16=41.48*TFLOPS, int8=82.96*TFLOPS),
+  "AMD Radeon RX 6800": DeviceFlops(fp32=16.17*TFLOPS, fp16=32.34*TFLOPS, int8=64.68*TFLOPS),
+  "AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21*TFLOPS, fp16=26.42*TFLOPS, int8=52.84*TFLOPS),
+  "AMD Radeon RX 6700": DeviceFlops(fp32=11.4*TFLOPS, fp16=22.8*TFLOPS, int8=45.6*TFLOPS),
+  "AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6*TFLOPS, fp16=21.2*TFLOPS, int8=42.4*TFLOPS),
+  "AMD Radeon RX 6600": DeviceFlops(fp32=8.93*TFLOPS, fp16=17.86*TFLOPS, int8=35.72*TFLOPS),
+  "AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77*TFLOPS, fp16=11.54*TFLOPS, int8=23.08*TFLOPS),
+  "AMD Radeon RX 6400": DeviceFlops(fp32=3.57*TFLOPS, fp16=7.14*TFLOPS, int8=14.28*TFLOPS),
  # RX 7000 series
-  "AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4 * TFLOPS, fp16=122.8 * TFLOPS, int8=245.6 * TFLOPS),
-  "AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4 * TFLOPS, fp16=106.8 * TFLOPS, int8=213.6 * TFLOPS),
-  "AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6 * TFLOPS, fp16=85.2 * TFLOPS, int8=170.4 * TFLOPS),
-  "AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2 * TFLOPS, fp16=68.4 * TFLOPS, int8=136.8 * TFLOPS),
-  "AMD Radeon RX 7600": DeviceFlops(fp32=21.5 * TFLOPS, fp16=43.0 * TFLOPS, int8=86.0 * TFLOPS),
-  "AMD Radeon RX 7500": DeviceFlops(fp32=16.2 * TFLOPS, fp16=32.4 * TFLOPS, int8=64.8 * TFLOPS),
-  # ... add more devices if needed ...
+  "AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4*TFLOPS, fp16=122.8*TFLOPS, int8=245.6*TFLOPS),
+  "AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4*TFLOPS, fp16=106.8*TFLOPS, int8=213.6*TFLOPS),
+  "AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6*TFLOPS, fp16=85.2*TFLOPS, int8=170.4*TFLOPS),
+  "AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2*TFLOPS, fp16=68.4*TFLOPS, int8=136.8*TFLOPS),
+  "AMD Radeon RX 7600": DeviceFlops(fp32=21.5*TFLOPS, fp16=43.0*TFLOPS, int8=86.0*TFLOPS),
+  "AMD Radeon RX 7500": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS),
  ### Qualcomm embedded chips: TODO
}
CHIP_FLOPS.update({f"LAPTOP GPU {key}": value for key, value in CHIP_FLOPS.items()})
CHIP_FLOPS.update({f"Laptop GPU {key}": value for key, value in CHIP_FLOPS.items()})
+CHIP_FLOPS.update({f"{key} LAPTOP GPU": value for key, value in CHIP_FLOPS.items()})
+CHIP_FLOPS.update({f"{key} Laptop GPU": value for key, value in CHIP_FLOPS.items()})


def device_capabilities() -> DeviceCapabilities:
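A toy illustration (not from the diff) of what the four CHIP_FLOPS.update passes above achieve: every known chip name gains "LAPTOP GPU"/"Laptop GPU" aliases both as a prefix and as a suffix, so lookups succeed regardless of where the vendor string puts the laptop qualifier. "ExampleChip" is a made-up key:

  flops = {"ExampleChip": 1.0}  # stand-in for CHIP_FLOPS
  flops.update({f"LAPTOP GPU {k}": v for k, v in flops.items()})
  flops.update({f"Laptop GPU {k}": v for k, v in flops.items()})
  flops.update({f"{k} LAPTOP GPU": v for k, v in flops.items()})
  flops.update({f"{k} Laptop GPU": v for k, v in flops.items()})
  assert flops["LAPTOP GPU ExampleChip"] == 1.0
  assert flops["ExampleChip Laptop GPU"] == 1.0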
@@ -149,7 +158,7 @@ def mac_device_capabilities() -> DeviceCapabilities:
  memory_units = memory_str.split()
  memory_value = int(memory_units[0])
  if memory_units[1] == "GB":
-    memory = memory_value * 1024
+    memory = memory_value*1024
  else:
    memory = memory_value


+ 2 - 2
exo/topology/partitioning_strategy.py

@@ -22,8 +22,8 @@ class PartitioningStrategy(ABC):
def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]:
  shards = []
  for i, partition in enumerate(partitions):
-    start_layer = int(partition.start * num_layers)
-    end_layer = int(partition.end * num_layers) - 1
+    start_layer = int(partition.start*num_layers)
+    end_layer = int(partition.end*num_layers) - 1

    # Ensure the last partition covers up to num_layers - 1
    if i == len(partitions) - 1:
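A quick worked example (hypothetical numbers, not from the diff) of the layer arithmetic above: with num_layers=32 and fractional partitions [0, 0.4) and [0.4, 1.0), the first shard covers layers 0-11 and the second 12-31, with the final partition clamped to end at num_layers - 1.

  num_layers = 32
  partitions = [(0.0, 0.4), (0.4, 1.0)]  # illustrative fractional ranges
  for start, end in partitions:
    start_layer = int(start*num_layers)
    end_layer = int(end*num_layers) - 1
    print(start_layer, end_layer)  # prints "0 11" then "12 31"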

+ 1 - 1
exo/topology/ring_memory_weighted_partitioning_strategy.py

@@ -12,7 +12,7 @@ class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
    partitions = []
    start = 0
    for node in nodes:
-      end = round(start + (node[1].memory / total_memory), 5)
+      end = round(start + (node[1].memory/total_memory), 5)
      partitions.append(Partition(node[0], start, end))
      start = end
    return partitions
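For illustration (values borrowed from the tests below, not part of the diff), two nodes with 128GB and 192GB of memory split the ring at 128/(128+192) = 0.4, so the partitions become [0, 0.4) and [0.4, 1.0):

  memories = {"node1": 128, "node2": 192}  # GB, illustrative only
  total = sum(memories.values())
  start, partitions = 0.0, []
  for node_id, mem in memories.items():
    end = round(start + mem/total, 5)
    partitions.append((node_id, start, end))
    start = end
  print(partitions)  # [('node1', 0.0, 0.4), ('node2', 0.4, 1.0)]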

+ 1 - 1
exo/topology/test_device_capabilities.py

@@ -80,7 +80,7 @@ Activation Lock Status: Disabled
    self.assertEqual(result.model, "MacBook Pro")
    self.assertEqual(result.chip, "Apple M3 Max")
    self.assertEqual(result.memory, 131072)  # 128 GB in MB
-    self.assertEqual(result.flops, DeviceFlops(fp32=14.20 * TFLOPS, fp16=28.40 * TFLOPS, int8=56.80 * TFLOPS))
+    self.assertEqual(result.flops, DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS))
    self.assertEqual(
      str(result),
      "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",

+ 2 - 2
exo/topology/test_map_partitions.py

@@ -56,8 +56,8 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
    def _broken_map_partitions_to_shards(partitions: List[Partition], num_layers, model_id: str):
      shards = []
      for i, partition in enumerate(partitions):
-        start_layer = int(partition.start * num_layers)
-        end_layer = int(partition.end * num_layers) - 1
+        start_layer = int(partition.start*num_layers)
+        end_layer = int(partition.end*num_layers) - 1
        shards.append(Shard(model_id, start_layer, end_layer, num_layers))
      return shards


+ 3 - 3
exo/topology/test_ring_memory_weighted_partitioning_strategy.py

@@ -49,7 +49,7 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
      DeviceCapabilities(
        model="MacBook Pro",
        chip="test1",
-        memory=128 * 1024 * 1024 * 1024,
+        memory=128*1024*1024*1024,
        flops=DeviceFlops(fp32=0, fp16=0, int8=0),
      ),
    )
@@ -58,7 +58,7 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
      DeviceCapabilities(
        model="Mac Studio",
        chip="test2",
-        memory=192 * 1024 * 1024 * 1024,
+        memory=192*1024*1024*1024,
        flops=DeviceFlops(fp32=0, fp16=0, int8=0),
      ),
    )
@@ -67,7 +67,7 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
      DeviceCapabilities(
        model="MacBook Pro",
        chip="test3",
-        memory=128 * 1024 * 1024 * 1024,
+        memory=128*1024*1024*1024,
        flops=DeviceFlops(fp32=0, fp16=0, int8=0),
      ),
    )

+ 52 - 51
exo/viz/test_topology_viz.py

@@ -9,55 +9,56 @@ from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent


def create_hf_repo_progress_event(
-    completed_files: int = 5,
-    total_files: int = 10,
-    downloaded_bytes: int = 500000000,
-    downloaded_bytes_this_session: int = 250000000,
-    total_bytes: int = 1000000000,
-    overall_speed: int = 5000000,
-    overall_eta: timedelta = timedelta(seconds=100),
-    file_progress: dict = None,
-    status: str = "in_progress"
+  completed_files: int = 5,
+  total_files: int = 10,
+  downloaded_bytes: int = 500000000,
+  downloaded_bytes_this_session: int = 250000000,
+  total_bytes: int = 1000000000,
+  overall_speed: int = 5000000,
+  overall_eta: timedelta = timedelta(seconds=100),
+  file_progress: dict = None,
+  status: str = "in_progress"
) -> RepoProgressEvent:
-    if file_progress is None:
-        file_progress = {
-            "file1.bin": RepoFileProgressEvent(
-                repo_id="repo_id",
-                repo_revision="repo_revision",
-                file_path="file1.bin",
-                downloaded=100000000,
-                downloaded_this_session=50000000,
-                total=200000000,
-                speed=1000000,
-                eta=timedelta(seconds=100),
-                status="in_progress"
-            ),
-            "file2.bin": RepoFileProgressEvent(
-                repo_id="repo_id",
-                repo_revision="repo_revision",
-                file_path="file2.bin",
-                downloaded=200000000,
-                downloaded_this_session=100000000,
-                total=200000000,
-                speed=2000000,
-                eta=timedelta(seconds=0),
-                status="complete"
-            )
-        }
+  if file_progress is None:
+    file_progress = {
+      "file1.bin":
+        RepoFileProgressEvent(
+          repo_id="repo_id",
+          repo_revision="repo_revision",
+          file_path="file1.bin",
+          downloaded=100000000,
+          downloaded_this_session=50000000,
+          total=200000000,
+          speed=1000000,
+          eta=timedelta(seconds=100),
+          status="in_progress"
+        ), "file2.bin":
+          RepoFileProgressEvent(
+            repo_id="repo_id",
+            repo_revision="repo_revision",
+            file_path="file2.bin",
+            downloaded=200000000,
+            downloaded_this_session=100000000,
+            total=200000000,
+            speed=2000000,
+            eta=timedelta(seconds=0),
+            status="complete"
+          )
+    }

-    return RepoProgressEvent(
-        repo_id="repo_id",
-        repo_revision="repo_revision",
-        completed_files=completed_files,
-        total_files=total_files,
-        downloaded_bytes=downloaded_bytes,
-        downloaded_bytes_this_session=downloaded_bytes_this_session,
-        total_bytes=total_bytes,
-        overall_speed=overall_speed,
-        overall_eta=overall_eta,
-        file_progress=file_progress,
-        status=status
-    )
+  return RepoProgressEvent(
+    repo_id="repo_id",
+    repo_revision="repo_revision",
+    completed_files=completed_files,
+    total_files=total_files,
+    downloaded_bytes=downloaded_bytes,
+    downloaded_bytes_this_session=downloaded_bytes_this_session,
+    total_bytes=total_bytes,
+    overall_speed=overall_speed,
+    overall_eta=overall_eta,
+    file_progress=file_progress,
+    status=status
+  )


class TestNodeViz(unittest.IsolatedAsyncioTestCase):
@@ -65,19 +66,19 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
    self.topology = Topology()
    self.topology.update_node(
      "node1",
-      DeviceCapabilities(model="ModelA", chip="ChipA", memory=8 * 1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)),
+      DeviceCapabilities(model="ModelA", chip="ChipA", memory=8*1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)),
    )
    self.topology.update_node(
      "node2",
-      DeviceCapabilities(model="ModelB", chip="ChipB", memory=16 * 1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)),
+      DeviceCapabilities(model="ModelB", chip="ChipB", memory=16*1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)),
    )
    self.topology.update_node(
      "node3",
-      DeviceCapabilities(model="ModelC", chip="ChipC", memory=32 * 1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)),
+      DeviceCapabilities(model="ModelC", chip="ChipC", memory=32*1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)),
    )
    self.topology.update_node(
      "node4",
-      DeviceCapabilities(model="ModelD", chip="ChipD", memory=64 * 1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)),
+      DeviceCapabilities(model="ModelD", chip="ChipD", memory=64*1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)),
    )

    self.top_viz = TopologyViz()

+ 122 - 57
exo/viz/topology_viz.py

@@ -1,17 +1,21 @@
import math
+from collections import OrderedDict
from typing import List, Optional, Tuple, Dict
from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
from exo.topology.topology import Topology
from exo.topology.partitioning_strategy import Partition
from exo.download.hf.hf_helpers import RepoProgressEvent
-from rich.console import Console
-from rich.panel import Panel
+from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
+from rich.console import Console, Group
from rich.text import Text
from rich.live import Live
from rich.style import Style
from rich.table import Table
from rich.layout import Layout
-from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
+from rich.syntax import Syntax
+from rich.panel import Panel
+from rich.markdown import Markdown
+
 
 
class TopologyViz:
  def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []):
@@ -21,17 +25,20 @@ class TopologyViz:
    self.partitions: List[Partition] = []
    self.node_id = None
    self.node_download_progress: Dict[str, RepoProgressEvent] = {}
+    self.requests: OrderedDict[str, Tuple[str, str]] = {}

    self.console = Console()
    self.layout = Layout()
-    self.layout.split(
-      Layout(name="main"),
-      Layout(name="download", size=25)
-    )
+    self.layout.split(Layout(name="main"), Layout(name="prompt_output", size=15), Layout(name="download", size=25))
    self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
+    self.prompt_output_panel = Panel("", title="Prompt and Output", border_style="green")
    self.download_panel = Panel("", title="Download Progress", border_style="cyan")
    self.layout["main"].update(self.main_panel)
+    self.layout["prompt_output"].update(self.prompt_output_panel)
    self.layout["download"].update(self.download_panel)
+
+    # Initially hide the prompt_output panel
+    self.layout["prompt_output"].visible = False
    self.live_panel = Live(self.layout, auto_refresh=False, console=self.console)
    self.live_panel.start()

@@ -43,12 +50,34 @@ class TopologyViz:
      self.node_download_progress = node_download_progress
    self.refresh()

+  def update_prompt(self, request_id: str, prompt: Optional[str] = None):
+    if request_id in self.requests:
+      self.requests[request_id] = [prompt, self.requests[request_id][1]]
+    else:
+      self.requests[request_id] = [prompt, ""]
+    self.refresh()
+
+  def update_prompt_output(self, request_id: str, output: Optional[str] = None):
+    if request_id in self.requests:
+      self.requests[request_id] = [self.requests[request_id][0], output]
+    else:
+      self.requests[request_id] = ["", output]
+    self.refresh()
+
  def refresh(self):
    self.main_panel.renderable = self._generate_main_layout()
    # Update the panel title with the number of nodes and partitions
    node_count = len(self.topology.nodes)
    self.main_panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"

+    # Update and show/hide prompt and output panel
+    if any(r[0] or r[1] for r in self.requests.values()):
+      self.prompt_output_panel = self._generate_prompt_output_layout()
+      self.layout["prompt_output"].update(self.prompt_output_panel)
+      self.layout["prompt_output"].visible = True
+    else:
+      self.layout["prompt_output"].visible = False
+
    # Only show download_panel if there are in-progress downloads
    if any(progress.status == "in_progress" for progress in self.node_download_progress.values()):
      self.download_panel.renderable = self._generate_download_layout()
@@ -58,6 +87,42 @@ class TopologyViz:

    self.live_panel.update(self.layout, refresh=True)

+  def _generate_prompt_output_layout(self) -> Panel:
+    content = []
+    requests = list(self.requests.values())[-3:]  # Get the 3 most recent requests
+    max_width = self.console.width - 6  # Full width minus padding and icon
+    max_lines = 13  # Maximum number of lines for the entire panel content
+
+    for (prompt, output) in reversed(requests):
+      prompt_icon, output_icon = "💬️", "🤖"
+
+      # Process prompt
+      prompt_lines = prompt.split('\n')
+      if len(prompt_lines) > max_lines // 2:
+        prompt_lines = prompt_lines[:max_lines//2 - 1] + ['...']
+      prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
+      prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
+
+      # Process output
+      output_lines = output.split('\n')
+      remaining_lines = max_lines - len(prompt_lines) - 2  # -2 for spacing
+      if len(output_lines) > remaining_lines:
+        output_lines = output_lines[:remaining_lines - 1] + ['...']
+      output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
+      output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
+
+      content.append(prompt_text)
+      content.append(output_text)
+      content.append(Text())  # Empty line between entries
+
+    return Panel(
+      Group(*content),
+      title="",
+      border_style="cyan",
+      height=15,  # Increased height to accommodate multiple lines
+      expand=True  # Allow the panel to expand to full width
+    )
+
  def _generate_main_layout(self) -> str:
    # Calculate visualization parameters
    num_partitions = len(self.partitions)
@@ -74,7 +139,7 @@ class TopologyViz:
    max_line_length = max(len(line) for line in exo_lines)
    for i, line in enumerate(exo_lines):
      centered_line = line.center(max_line_length)
-      start_x = (100 - max_line_length) // 2 + 15
+      start_x = (100-max_line_length) // 2 + 15
      colored_line = Text(centered_line, style=yellow_style)
      for j, char in enumerate(str(colored_line)):
        if 0 <= start_x + j < 100 and i < len(visualization):
@@ -96,18 +161,18 @@ class TopologyViz:

    # Calculate total FLOPS and position on the bar
    total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions)
-    bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
+    bar_pos = (math.tanh(total_flops/20 - 2) + 1)/2

    # Add GPU poor/rich bar
    bar_width = 30
-    bar_start_x = (100 - bar_width) // 2
+    bar_start_x = (100-bar_width) // 2
    bar_y = info_start_y + len(info_lines) + 1

    # Create a gradient bar using emojis
    gradient_bar = Text()
    emojis = ["🟥", "🟧", "🟨", "🟩"]
    for i in range(bar_width):
-      emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
+      emoji_index = min(int(i/(bar_width/len(emojis))), len(emojis) - 1)
      gradient_bar.append(emojis[emoji_index])

    # Add the gradient bar to the visualization
@@ -117,14 +182,14 @@ class TopologyViz:
      visualization[bar_y][bar_start_x + i] = segment

    # Add labels
-    visualization[bar_y - 1][bar_start_x - 10 : bar_start_x - 3] = "GPU poor"
-    visualization[bar_y - 1][bar_start_x + bar_width * 2 + 2 : bar_start_x + bar_width * 2 + 11] = "GPU rich"
+    visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = "GPU poor"
+    visualization[bar_y - 1][bar_start_x + bar_width*2 + 2:bar_start_x + bar_width*2 + 11] = "GPU rich"

    # Add position indicator and FLOPS value
-    pos_x = bar_start_x + int(bar_pos * bar_width)
+    pos_x = bar_start_x + int(bar_pos*bar_width)
    flops_str = f"{total_flops:.2f} TFLOPS"
    visualization[bar_y - 1][pos_x] = "▼"
-    visualization[bar_y + 1][pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
+    visualization[bar_y + 1][pos_x - len(flops_str) // 2:pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
    visualization[bar_y + 2][pos_x] = "▲"

    # Add an extra empty line for spacing
@@ -133,9 +198,9 @@ class TopologyViz:
    for i, partition in enumerate(self.partitions):
      device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)

-      angle = 2 * math.pi * i / num_partitions
-      x = int(center_x + radius_x * math.cos(angle))
-      y = int(center_y + radius_y * math.sin(angle))
+      angle = 2*math.pi*i/num_partitions
+      x = int(center_x + radius_x*math.cos(angle))
+      y = int(center_y + radius_y*math.sin(angle))

      # Place node with different color for active node and this node
      if partition.node_id == self.topology.active_node_id:
@@ -155,8 +220,8 @@ class TopologyViz:
      # Calculate info position based on angle
      info_distance_x = radius_x + 6
      info_distance_y = radius_y + 3
-      info_x = int(center_x + info_distance_x * math.cos(angle))
-      info_y = int(center_y + info_distance_y * math.sin(angle))
+      info_x = int(center_x + info_distance_x*math.cos(angle))
+      info_y = int(center_y + info_distance_y*math.sin(angle))

      # Adjust text position to avoid overwriting the node icon and prevent cutoff
      if info_x < x:
@@ -165,9 +230,9 @@ class TopologyViz:
        info_x = min(99 - len(max(node_info, key=len)), info_x)

      # Adjust for top and bottom nodes
-      if 5 * math.pi / 4 < angle < 7 * math.pi / 4:
+      if 5*math.pi/4 < angle < 7*math.pi/4:
        info_x += 4
-      elif math.pi / 4 < angle < 3 * math.pi / 4:
+      elif math.pi/4 < angle < 3*math.pi/4:
        info_x += 3
        info_y -= 2

@@ -178,16 +243,16 @@ class TopologyViz:
              visualization[info_y + j][info_x + k] = char

      # Draw line to next node
-      next_i = (i + 1) % num_partitions
-      next_angle = 2 * math.pi * next_i / num_partitions
-      next_x = int(center_x + radius_x * math.cos(next_angle))
-      next_y = int(center_y + radius_y * math.sin(next_angle))
+      next_i = (i+1) % num_partitions
+      next_angle = 2*math.pi*next_i/num_partitions
+      next_x = int(center_x + radius_x*math.cos(next_angle))
+      next_y = int(center_y + radius_y*math.sin(next_angle))

      # Simple line drawing
      steps = max(abs(next_x - x), abs(next_y - y))
      for step in range(1, steps):
-        line_x = int(x + (next_x - x) * step / steps)
-        line_y = int(y + (next_y - y) * step / steps)
+        line_x = int(x + (next_x-x)*step/steps)
+        line_y = int(y + (next_y-y)*step/steps)
        if 0 <= line_y < 48 and 0 <= line_x < 100:
          visualization[line_y][line_x] = "-"

@@ -202,41 +267,41 @@

    # Current node download progress
    if self.node_id in self.node_download_progress:
-        download_progress = self.node_download_progress[self.node_id]
-        title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):"
-        summary.add_row(Text(title, style="bold"))
-        progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
-        summary.add_row(progress_info)
+      download_progress = self.node_download_progress[self.node_id]
+      title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):"
+      summary.add_row(Text(title, style="bold"))
+      progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
+      summary.add_row(progress_info)

-        eta_info = f"{download_progress.overall_eta}"
-        summary.add_row(eta_info)
+      eta_info = f"{download_progress.overall_eta}"
+      summary.add_row(eta_info)

-        summary.add_row("")  # Empty row for spacing
+      summary.add_row("")  # Empty row for spacing

-        for file_path, file_progress in download_progress.file_progress.items():
-            if file_progress.status != "complete":
-                progress = int(file_progress.downloaded / file_progress.total * 30)
-                bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
-                percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
-                summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)
+      for file_path, file_progress in download_progress.file_progress.items():
+        if file_progress.status != "complete":
+          progress = int(file_progress.downloaded/file_progress.total*30)
+          bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
+          percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
+          summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)

    summary.add_row("")  # Empty row for spacing

    # Other nodes download progress summary
    summary.add_row(Text("Other Nodes Download Progress:", style="bold"))
    for node_id, progress in self.node_download_progress.items():
-        if node_id != self.node_id:
-            device = self.topology.nodes.get(node_id)
-            partition = next((p for p in self.partitions if p.node_id == node_id), None)
-            partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
-            percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
-            speed = pretty_print_bytes_per_second(progress.overall_speed)
-            device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
-            progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
-            progress_bar = f"[{'=' * int(percentage // 3.33)}{' ' * (30 - int(percentage // 3.33))}]"
-            percentage_str = f"{percentage:.1f}%"
-            eta_str = f"{progress.overall_eta}"
-            summary.add_row(device_info, progress_info, percentage_str)
-            summary.add_row("", progress_bar, eta_str)
-
-    return summary
+      if node_id != self.node_id:
+        device = self.topology.nodes.get(node_id)
+        partition = next((p for p in self.partitions if p.node_id == node_id), None)
+        partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
+        percentage = progress.downloaded_bytes/progress.total_bytes*100 if progress.total_bytes > 0 else 0
+        speed = pretty_print_bytes_per_second(progress.overall_speed)
+        device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
+        progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
+        progress_bar = f"[{'=' * int(percentage // 3.33)}{' ' * (30 - int(percentage // 3.33))}]"
+        percentage_str = f"{percentage:.1f}%"
+        eta_str = f"{progress.overall_eta}"
+        summary.add_row(device_info, progress_info, percentage_str)
+        summary.add_row("", progress_bar, eta_str)
+
+    return summary
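A hedged usage sketch (illustrative, not part of the diff) of the new prompt/output panel: the ChatGPT API calls update_prompt when a request arrives and update_prompt_output as tokens are decoded, and the panel renders the three most recent requests.

  # Assumes a TopologyViz constructed as in main.py; the request id is hypothetical.
  viz = TopologyViz(chatgpt_api_endpoints=[], web_chat_urls=[])
  request_id = "req-1"
  viz.update_prompt(request_id, "Who are you?")
  viz.update_prompt_output(request_id, "I am an example response...")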

+ 34 - 37
extra/download_hf.py

@@ -3,51 +3,48 @@ import asyncio
from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent

DEFAULT_ALLOW_PATTERNS = [
-    "*.json",
-    "*.py",
-    "tokenizer.model",
-    "*.tiktoken",
-    "*.txt",
-    "*.safetensors",
+  "*.json",
+  "*.py",
+  "tokenizer.model",
+  "*.tiktoken",
+  "*.txt",
+  "*.safetensors",
]
# Always ignore `.git` and `.cache/huggingface` folders in commits
DEFAULT_IGNORE_PATTERNS = [
-    ".git",
-    ".git/*",
-    "*/.git",
-    "**/.git/**",
-    ".cache/huggingface",
-    ".cache/huggingface/*",
-    "*/.cache/huggingface",
-    "**/.cache/huggingface/**",
+  ".git",
+  ".git/*",
+  "*/.git",
+  "**/.git/**",
+  ".cache/huggingface",
+  ".cache/huggingface/*",
+  "*/.cache/huggingface",
+  "**/.cache/huggingface/**",
]

+
async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
-    async def progress_callback(event: RepoProgressEvent):
-        print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
-        print(f"Estimated time remaining: {event.overall_eta}")
-        print("File Progress:")
-        for file_path, progress in event.file_progress.items():
-            status_icon = {
-                'not_started': '⚪',
-                'in_progress': '🔵',
-                'complete': '✅'
-            }[progress.status]
-            eta_str = str(progress.eta)
-            print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
-                  f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
-        print("\n")
-
-    await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
+  async def progress_callback(event: RepoProgressEvent):
+    print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
+    print(f"Estimated time remaining: {event.overall_eta}")
+    print("File Progress:")
+    for file_path, progress in event.file_progress.items():
+      status_icon = {'not_started': '⚪', 'in_progress': '🔵', 'complete': '✅'}[progress.status]
+      eta_str = str(progress.eta)
+      print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
+            f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
+    print("\n")
+
+  await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)


if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
-    parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
-    parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
-    parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
-    parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
+  parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
+  parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
+  parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
+  parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
+  parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")

-    args = parser.parse_args()
+  args = parser.parse_args()

-    asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
+  asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))

+ 3 - 0
extra/start_openwebui.sh

@@ -0,0 +1,3 @@
+API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):8000}"
+echo "Using API_ENDPOINT=${API_ENDPOINT}"
+docker run -d -p 3000:8080 -e OPENAI_API_BASE_URL="${API_ENDPOINT}" -e OPENAI_API_KEY=your_secret_key -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main

+ 10 - 87
format.py

@@ -1,99 +1,22 @@
#!/usr/bin/env python
-import re
import subprocess
import sys
import os
-import fnmatch

-DEBUG_PATTERN = re.compile(r'^(\s*)(if\s+DEBUG\s*>=?\s*\d+\s*:.+)$', re.MULTILINE)
-PLACEHOLDER = "###DEBUG_PLACEHOLDER###"

-# Add ignore patterns here
-IGNORE_PATTERNS = [
-  '.venv/*',
-  'setup.py',
-  '*helpers.py',
-  '*node_service_pb2.py',
-  '*node_service_pb2_grpc.py',
-]
-
-
-def should_ignore(file_path):
-  for pattern in IGNORE_PATTERNS:
-    if fnmatch.fnmatch(file_path, pattern):
-      return True
-  return False
-
-
-def preserve_debug_lines(content):
-  def replace(match):
-    indent, line = match.groups()
-    return f"{indent}{PLACEHOLDER}{line.strip()}"
-
-  return DEBUG_PATTERN.sub(replace, content)
-
-
-def restore_debug_lines(content):
-  return re.sub(f"^(\\s*){PLACEHOLDER}(.+)$", r"\1\2", content, flags=re.MULTILINE)
-
-
-def adjust_indentation(content):
-  lines = content.split('\n')
-  adjusted_lines = []
-  for line in lines:
-    if line.strip() and not line.startswith(PLACEHOLDER):
-      indent = len(line) - len(line.lstrip())
-      new_indent = ' ' * (indent // 2)
-      adjusted_lines.append(new_indent + line.lstrip())
-    else:
-      adjusted_lines.append(line)
-  return '\n'.join(adjusted_lines)
-
-
-def process_file(file_path, process_func):
-  with open(file_path, 'r') as file:
-    content = file.read()
-
-  modified_content = process_func(content)
-
-  if content != modified_content:
-    with open(file_path, 'w') as file:
-      file.write(modified_content)
-
-
-def run_black(target):
-  # Convert ignore patterns to Black's --extend-exclude format
-  exclude_patterns = '|'.join(f'({pattern.replace("*", ".*")})' for pattern in IGNORE_PATTERNS)
-  command = ["black", "--line-length", "200", "--extend-exclude", exclude_patterns, target]
-  subprocess.run(command, check=True)
-
-
-def format_files(target):
+def run_yapf(target):
  if os.path.isfile(target):
-    files = [target] if not should_ignore(target) else []
-  elif os.path.isdir(target):
-    files = []
-    for root, _, filenames in os.walk(target):
-      for filename in filenames:
-        if filename.endswith('.py'):
-          file_path = os.path.join(root, filename)
-          if not should_ignore(file_path):
-            files.append(file_path)
+    files = [target]
  else:
-    print(f"Error: {target} is not a valid file or directory")
-    return
-
-  # Preserve debug lines
-  for file in files:
-    process_file(file, preserve_debug_lines)
-
-  # Run Black
-  run_black(target)
+    files = [os.path.join(root, file) for root, _, files in os.walk(target) for file in files if file.endswith('.py')]

-  # Adjust indentation and restore debug lines
  for file in files:
-    process_file(file, adjust_indentation)
-    process_file(file, restore_debug_lines)
+    try:
+      command = ["yapf", "-i", file]
+      subprocess.run(command, check=True, capture_output=True, text=True)
+      print(f"Formatted: {file}")
+    except subprocess.CalledProcessError as e:
+      print(f"Error formatting {file}: {e.stderr}")


def main():
@@ -102,7 +25,7 @@ def main():
    sys.exit(1)

  target = sys.argv[1]
-  format_files(target)
+  run_yapf(target)
  print("Formatting completed.")



+ 120 - 67
main.py

@@ -4,6 +4,7 @@ import signal
import json
import time
import traceback
+import uuid
from exo.orchestration.standard_node import StandardNode
from exo.networking.grpc.grpc_server import GRPCServer
from exo.networking.udp_discovery import UDPDiscovery
@@ -11,8 +12,13 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
from exo.api import ChatGPTAPI
from exo.download.shard_download import ShardDownloader, RepoProgressEvent
from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
from exo.inference.shard import Shard
+from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.inference.tokenizers import resolve_tokenizer
+from exo.orchestration.node import Node
+from exo.models import model_base_shards
+from exo.viz.topology_viz import TopologyViz

# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -31,6 +37,8 @@ parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90,
parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max tokens to generate in each request")
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
+parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
+parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
args = parser.parse_args()

print_yellow_exo()
@@ -44,94 +52,139 @@ inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")

if args.node_port is None:
-    args.node_port = find_available_port(args.node_host)
-    if DEBUG >= 1: print(f"Using available port: {args.node_port}")
+  args.node_port = find_available_port(args.node_host)
+  if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 
 
 args.node_id = args.node_id or get_or_create_node_id()
-discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
-chatgpt_api_endpoints=[f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
-web_chat_urls=[f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
+chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
+web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
 if DEBUG >= 0:
-    print("Chat interface started:")
-    for web_chat_url in web_chat_urls:
-        print(f" - {terminal_link(web_chat_url)}")
-    print("ChatGPT API endpoint served at:")
-    for chatgpt_api_endpoint in chatgpt_api_endpoints:
-        print(f" - {terminal_link(chatgpt_api_endpoint)}")
+  print("Chat interface started:")
+  for web_chat_url in web_chat_urls:
+    print(f" - {terminal_link(web_chat_url)}")
+  print("ChatGPT API endpoint served at:")
+  for chatgpt_api_endpoint in chatgpt_api_endpoints:
+    print(f" - {terminal_link(chatgpt_api_endpoint)}")
+
+discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
+topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
-    args.node_id,
-    None,
-    inference_engine,
-    discovery,
-    chatgpt_api_endpoints=chatgpt_api_endpoints,
-    web_chat_urls=web_chat_urls,
-    partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
-    disable_tui=args.disable_tui,
-    max_generate_tokens=args.max_generate_tokens,
+  args.node_id,
+  None,
+  inference_engine,
+  discovery,
+  partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
+  max_generate_tokens=args.max_generate_tokens,
+  topology_viz=topology_viz
 )
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
-node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
+api = ChatGPTAPI(
+  node,
+  inference_engine.__class__.__name__,
+  response_timeout_secs=args.chatgpt_api_response_timeout_secs,
+  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
+)
+node.on_token.register("update_topology_viz").on_next(
+  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
+)
 def preemptively_start_download(request_id: str, opaque_status: str):
-    try:
-        status = json.loads(opaque_status)
-        if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
-            current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
-            if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-            asyncio.create_task(shard_downloader.ensure_shard(current_shard))
-    except Exception as e:
-        if DEBUG >= 2:
-            print(f"Failed to preemptively start download: {e}")
-            traceback.print_exc()
+  try:
+    status = json.loads(opaque_status)
+    if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
+      current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
+      if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
+      asyncio.create_task(shard_downloader.ensure_shard(current_shard))
+  except Exception as e:
+    if DEBUG >= 2:
+      print(f"Failed to preemptively start download: {e}")
+      traceback.print_exc()
 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
+
 if args.prometheus_client_port:
-    from exo.stats.metrics import start_metrics_server
-    start_metrics_server(node, args.prometheus_client_port)
+  from exo.stats.metrics import start_metrics_server
+  start_metrics_server(node, args.prometheus_client_port)
 
 
 last_broadcast_time = 0
+
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
-    global last_broadcast_time
-    current_time = time.time()
-    if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
-        last_broadcast_time = current_time
-        asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+  global last_broadcast_time
+  current_time = time.time()
+  if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
+    last_broadcast_time = current_time
+    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+
+
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)

+
 async def shutdown(signal, loop):
-    """Gracefully shutdown the server and close the asyncio loop."""
-    print(f"Received exit signal {signal.name}...")
-    print("Thank you for using exo.")
-    print_yellow_exo()
-    server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-    [task.cancel() for task in server_tasks]
-    print(f"Cancelling {len(server_tasks)} outstanding tasks")
-    await asyncio.gather(*server_tasks, return_exceptions=True)
-    await server.stop()
-    loop.stop()
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+  loop.stop()
+
+
+async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
+  shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
+  if not shard:
+    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+    return
+  tokenizer = await resolve_tokenizer(shard.model_id)
+  request_id = str(uuid.uuid4())
+  callback_id = f"cli-wait-response-{request_id}"
+  callback = node.on_token.register(callback_id)
+  if topology_viz:
+    topology_viz.update_prompt(request_id, prompt)
+  prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
+
+  try:
+    print(f"Processing prompt: {prompt}")
+    await node.process_prompt(shard, prompt, None, request_id=request_id)
+
+    _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
+
+    print("\nGenerated response:")
+    print(tokenizer.decode(tokens))
+  except Exception as e:
+    print(f"Error processing prompt: {str(e)}")
+    traceback.print_exc()
+  finally:
+    node.on_token.deregister(callback_id)
+
 
 
 async def main():
-    loop = asyncio.get_running_loop()
+  loop = asyncio.get_running_loop()
 
 
-    # Use a more direct approach to handle signals
-    def handle_exit():
-        asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
+  # Use a more direct approach to handle signals
+  def handle_exit():
+    asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
 
 
-    for s in [signal.SIGINT, signal.SIGTERM]:
-        loop.add_signal_handler(s, handle_exit)
+  for s in [signal.SIGINT, signal.SIGTERM]:
+    loop.add_signal_handler(s, handle_exit)
 
 
-    await node.start(wait_for_peers=args.wait_for_peers)
-    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
+  await node.start(wait_for_peers=args.wait_for_peers)
 
 
+  if args.run_model:
+    await run_model_cli(node, inference_engine, args.run_model, args.prompt)
+  else:
+    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
     await asyncio.Event().wait()

+
 if __name__ == "__main__":
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    try:
-        loop.run_until_complete(main())
-    except KeyboardInterrupt:
-        print("Received keyboard interrupt. Shutting down...")
-    finally:
-        loop.run_until_complete(shutdown(signal.SIGTERM, loop))
-        loop.close()
+  loop = asyncio.new_event_loop()
+  asyncio.set_event_loop(loop)
+  try:
+    loop.run_until_complete(main())
+  except KeyboardInterrupt:
+    print("Received keyboard interrupt. Shutting down...")
+  finally:
+    loop.run_until_complete(shutdown(signal.SIGTERM, loop))
+    loop.close()
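
Taken together, the new --run-model and --prompt flags give a one-shot CLI path next to the usual API-server mode. A sketch of how this might be invoked, assuming llama-3.1-8b (one of the models listed in the web UI below) is a key in model_base_shards for the active inference engine; the exact commands are an illustration, not documented behaviour:

  # one-shot: process a single prompt on the cluster and print the decoded response
  python main.py --run-model llama-3.1-8b --prompt "What is distributed inference?"

  # default mode: start the node with the TUI and the ChatGPT-compatible API server
  python main.py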

+ 0 - 10
pyproject.toml

@@ -1,13 +1,3 @@
-[tool.black]
-line-length = 200
-indent-size = 2
-skip-string-normalization = true
-
-[tool.isort]
-profile = "black"
-line_length = 200
-indent = "  "
-
 [tool.pylint.format]
 indent-string = '  '
 max-line-length = 200
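
With the black and isort sections gone, formatting presumably runs through yapf instead (it is added to the linting extras in setup.py below and wrapped by format.py above). A minimal direct invocation, assuming yapf picks its settings up from the repository's own style config:

  yapf --in-place --recursive exo/ main.py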

+ 41 - 42
setup.py

@@ -4,55 +4,54 @@ from setuptools import find_packages, setup
 
 
 # Base requirements for all platforms
 install_requires = [
-    "aiohttp==3.10.2",
-    "aiohttp_cors==0.7.0",
-    "aiofiles==24.1.0",
-    "blobfile==2.1.1",
-    "grpcio==1.64.1",
-    "grpcio-tools==1.64.1",
-    "hf-transfer==0.1.8",
-    "huggingface-hub==0.24.5",
-    "Jinja2==3.1.4",
-    "netifaces==0.11.0",
-    "numpy==2.0.0",
-    "pillow==10.4.0",
-    "prometheus-client==0.20.0",
-    "protobuf==5.27.1",
-    "psutil==6.0.0",
-    "pynvml==11.5.3",
-    "requests==2.32.3",
-    "rich==13.7.1",
-    "safetensors==0.4.3",
-    "tenacity==9.0.0",
-    "tiktoken==0.7.0",
-    "tokenizers==0.19.1",
-    "tqdm==4.66.4",
-    "transformers==4.43.3",
-    "uuid==1.30",
-    "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@639af3f823cf242a1945dc24183e52a9df0af2b7",
+  "aiohttp==3.10.2",
+  "aiohttp_cors==0.7.0",
+  "aiofiles==24.1.0",
+  "blobfile==2.1.1",
+  "grpcio==1.64.1",
+  "grpcio-tools==1.64.1",
+  "hf-transfer==0.1.8",
+  "huggingface-hub==0.24.5",
+  "Jinja2==3.1.4",
+  "netifaces==0.11.0",
+  "numpy==2.0.0",
+  "pillow==10.4.0",
+  "prometheus-client==0.20.0",
+  "protobuf==5.27.1",
+  "psutil==6.0.0",
+  "pynvml==11.5.3",
+  "requests==2.32.3",
+  "rich==13.7.1",
+  "safetensors==0.4.3",
+  "tenacity==9.0.0",
+  "tiktoken==0.7.0",
+  "tokenizers==0.19.1",
+  "tqdm==4.66.4",
+  "transformers==4.43.3",
+  "uuid==1.30",
+  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@639af3f823cf242a1945dc24183e52a9df0af2b7",
 ]

 # Add macOS-specific packages if on Darwin (macOS)
 if sys.platform.startswith("darwin"):
-    install_requires.extend(
-        [
-            "mlx==0.16.3",
-            "mlx-lm==0.17.0",
-        ]
-    )
+  install_requires.extend([
+    "mlx==0.17.1",
+    "mlx-lm==0.17.0",
+  ])
 
 
 extras_require = {
-    "linting": [
-        "pylint==3.2.6",
-        "ruff==0.5.5",
-        "mypy==1.11.0",
-    ],
+  "linting": [
+    "pylint==3.2.6",
+    "ruff==0.5.5",
+    "mypy==1.11.0",
+    "yapf==0.40.2",
+  ],
 }

 setup(
-    name="exo",
-    version="0.0.1",
-    packages=find_packages(),
-    install_requires=install_requires,
-    extras_require=extras_require,
+  name="exo",
+  version="0.0.1",
+  packages=find_packages(),
+  install_requires=install_requires,
+  extras_require=extras_require,
 )
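
For reference, a local setup against these pinned requirements would typically be an editable install, with the linting extra pulling in pylint, ruff, mypy and the newly added yapf; the exact commands are assumptions, not taken from this diff:

  pip install -e .
  pip install -e ".[linting]"   # assumed: adds pylint, ruff, mypy and yapf from extras_require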

+ 34 - 0
test/test_tokenizers.py

@@ -0,0 +1,34 @@
+from transformers import AutoTokenizer, AutoProcessor
+from exo.models import model_base_shards
+
+
+def test_tokenizer(name, tokenizer, verbose=False):
+    print(f"--- {name} ({tokenizer.__class__.__name__}) ---")
+    text = "Hello! How can I assist you today? Let me know if you need help with something or just want to chat."
+    encoded = tokenizer.encode(text)
+    decoded = tokenizer.decode(encoded)
+
+    print(f"{encoded=}")
+    print(f"{decoded=}")
+
+    reconstructed = ""
+    for token in encoded:
+      if verbose:
+        print(f"{token=}")
+        print(f"{tokenizer.decode([token])=}")
+      reconstructed += tokenizer.decode([token])
+    print(f"{reconstructed=}")
+
+    strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
+    assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
+
+ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "llava-hf/llava-1.5-7b-hf"]
+models = [shard.model_id for shards in model_base_shards.values() for shard in shards.values() if shard.model_id not in ignore]
+
+import os
+verbose = os.environ.get("VERBOSE", "0").lower() == "1"
+for m in models:
+    # TODO: figure out why use_fast=False is giving inconsistent behaviour (no spaces decoding individual tokens) for Mistral-Large-Instruct-2407-4bit
+    # test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=False), verbose)
+    test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=True), verbose)
+    test_tokenizer(m, AutoTokenizer.from_pretrained(m), verbose)
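
One way to exercise this test locally, assuming the models listed in exo.models can be fetched from Hugging Face; VERBOSE is the only knob, and it is read from the environment in the script above:

  VERBOSE=1 python test/test_tokenizers.py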

+ 119 - 131
tinychat/examples/tinychat/index.html

@@ -1,57 +1,44 @@
 <!DOCTYPE html>
 <!DOCTYPE html>
 
 
 <head>
 <head>
-  <title>tinychat</title>
-  <meta name="viewport" content="width=device-width, initial-scale=1">
-  <link rel="icon" href="favicon.svg" type="image/svg+xml">
-
-  <script defer src="https://cdn.jsdelivr.net/npm/@alpine-collective/toolkit@1.0.2/dist/cdn.min.js"></script>
-  <script defer src="https://cdn.jsdelivr.net/npm/@alpinejs/intersect@3.x.x/dist/cdn.min.js"></script>
-  <script defer src="https://cdn.jsdelivr.net/npm/@alpinejs/focus@3.x.x/dist/cdn.min.js"></script>
-  <script defer src="https://unpkg.com/@marcreichel/alpine-autosize@1.3.x/dist/alpine-autosize.min.js"></script>
-  <script defer src="https://unpkg.com/alpinejs@3.x.x/dist/cdn.min.js"></script>
-
-  <script src="https://unpkg.com/dompurify@3.1.5/dist/purify.min.js"></script>
-  <script src="https://unpkg.com/marked@13.0.0/marked.min.js"></script>
-  <script src="https://unpkg.com/marked-highlight@2.1.2/lib/index.umd.js"></script>
-  <script src="https://unpkg.com/@highlightjs/cdn-assets@11.9.0/highlight.min.js"></script>
-
-  <script src="index.js"></script>
-
-  <link rel="preconnect" href="https://fonts.googleapis.com">
-  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
-  <link href="https://fonts.googleapis.com/css2?family=Megrim&display=swap" rel="stylesheet">
-
-  <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css">
-  <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css"
-    integrity="sha512-SnH5WK+bZxgPHs44uWIX+LLJAJ9/2PkPKZ5QiAj6Ta86w+fsb2TkcmfRyVX3pBnMFcV7oQPJkl9QevSCWr3W6A=="
-    crossorigin="anonymous" referrerpolicy="no-referrer" />
-  <link rel="stylesheet" href="https://unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css">
-
-  <link rel="stylesheet" href="index.css">
-  <link rel="stylesheet" href="common.css">
-</head>
-
+<title>tinychat</title>
+<meta content="width=device-width, initial-scale=1" name="viewport"/>
+<link href="favicon.svg" rel="icon" type="image/svg+xml"/>
+<script defer="" src="/static/cdn.jsdelivr.net/npm/@alpine-collective/toolkit@1.0.2/dist/cdn.min.js"></script>
+<script defer="" src="/static/cdn.jsdelivr.net/npm/@alpinejs/intersect@3.x.x/dist/cdn.min.js"></script>
+<script defer="" src="/static/cdn.jsdelivr.net/npm/@alpinejs/focus@3.x.x/dist/cdn.min.js"></script>
+<script defer="" src="/static/unpkg.com/@marcreichel/alpine-autosize@1.3.x/dist/alpine-autosize.min.js"></script>
+<script defer="" src="/static/unpkg.com/alpinejs@3.x.x/dist/cdn.min.js"></script>
+<script src="/static/unpkg.com/dompurify@3.1.5/dist/purify.min.js"></script>
+<script src="/static/unpkg.com/marked@13.0.0/marked.min.js"></script>
+<script src="/static/unpkg.com/marked-highlight@2.1.2/lib/index.umd.js"></script>
+<script src="/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/highlight.min.js"></script>
+<script src="/index.js"></script>
+<link href="/static/fonts.googleapis.com" rel="preconnect"/>
+<link crossorigin="" href="/static/fonts.gstatic.com" rel="preconnect"/>
+<link href="/static/fonts.googleapis.com/css2" rel="stylesheet"/>
+<link href="/static/cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css" rel="stylesheet"/>
+<link crossorigin="anonymous" href="/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css" integrity="sha512-SnH5WK+bZxgPHs44uWIX+LLJAJ9/2PkPKZ5QiAj6Ta86w+fsb2TkcmfRyVX3pBnMFcV7oQPJkl9QevSCWr3W6A==" referrerpolicy="no-referrer" rel="stylesheet">
+<link href="/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css" rel="stylesheet"/>
+<link href="/index.css" rel="stylesheet"/>
+<link href="/common.css" rel="stylesheet"/>
+</link></head>
 <body>
 <body>
-  <main x-data="state" x-init="console.log(endpoint)">
-    <div class="model-selector">
-      <select x-model="cstate.selectedModel" @change="if (cstate) cstate.selectedModel = $event.target.value">
-        <option value="llama-3.1-8b" selected>Llama 3.1 8B</option>
-        <option value="llama-3.1-70b">Llama 3.1 70B</option>
-        <option value="llama-3.1-405b">Llama 3.1 405B</option>
-        <option value="llama-3-8b">Llama 3 8B</option>
-        <option value="llama-3-70b">Llama 3 70B</option>
-        <option value="mistral-nemo">Mistral Nemo</option>
-        <option value="mistral-large">Mistral Large</option>
-        <option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
-        <option value="llava-1.5-7b-hf">LLaVa 1.5 7B (Vision Model)</option>
-      </select>
-    </div>
-    <div class="home centered" x-show="home === 0" x-transition x-effect="
-      $refs.inputForm.focus();
-      if (home === 1) setTimeout(() => home = 2, 100);
-      if (home === -1) setTimeout(() => home = 0, 100);
-    " @popstate.window="
+<main x-data="state" x-init="console.log(endpoint)">
+<div class="model-selector">
+<select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel">
+<option selected="" value="llama-3.1-8b">Llama 3.1 8B</option>
+<option value="llama-3.1-70b">Llama 3.1 70B</option>
+<option value="llama-3.1-405b">Llama 3.1 405B</option>
+<option value="llama-3-8b">Llama 3 8B</option>
+<option value="llama-3-70b">Llama 3 70B</option>
+<option value="mistral-nemo">Mistral Nemo</option>
+<option value="mistral-large">Mistral Large</option>
+<option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
+<option value="llava-1.5-7b-hf">LLaVa 1.5 7B (Vision Model)</option>
+</select>
+</div>
+<div @popstate.window="
      if (home === 2) {
        home = -1;
        cstate = { time: null, messages: [], selectedModel: 'llama-3.1-8b' };
@@ -59,51 +46,55 @@
        tokens_per_second = 0;
        total_tokens = 0;
      }
-    ">
-      <h1 class="title megrim-regular">tinychat</h1>
-      <div class="histories-container-container">
-        <template x-if="histories.length">
-          <div class="histories-start"></div>
-        </template>
-        <div class="histories-container" x-intersect="
+    " class="home centered" x-effect="
+      $refs.inputForm.focus();
+      if (home === 1) setTimeout(() =&gt; home = 2, 100);
+      if (home === -1) setTimeout(() =&gt; home = 0, 100);
+    " x-show="home === 0" x-transition="">
+<h1 class="title megrim-regular">tinychat</h1>
+<div class="histories-container-container">
+<template x-if="histories.length">
+<div class="histories-start"></div>
+</template>
+<div class="histories-container" x-intersect="
          $el.scrollTo({ top: 0, behavior: 'smooth' });
        ">
-          <template x-for="_state in histories.toSorted((a, b) => b.time - a.time)">
-            <div x-data="{ otx: 0, trigger: 75 }" class="history" @click="
+<template x-for="_state in histories.toSorted((a, b) =&gt; b.time - a.time)">
+<div @click="
            cstate = _state;
            if (cstate) cstate.selectedModel = document.querySelector('.model-selector select').value
            // updateTotalTokens(cstate.messages);
            home = 1;
            // ensure that going back in history will go back to home
            window.history.pushState({}, '', '/');
-          " @touchstart="
-            otx = $event.changedTouches[0].clientX;
-          " @touchmove="
-            $el.style.setProperty('--tx', $event.changedTouches[0].clientX - otx);
-            $el.style.setProperty('--opacity', 1 - (Math.abs($event.changedTouches[0].clientX - otx) / trigger));
           " @touchend="
           " @touchend="
-            if (Math.abs($event.changedTouches[0].clientX - otx) > trigger) removeHistory(_state);
+            if (Math.abs($event.changedTouches[0].clientX - otx) &gt; trigger) removeHistory(_state);
             $el.style.setProperty('--tx', 0);
            $el.style.setProperty('--tx', 0);
            $el.style.setProperty('--opacity', 1);
-              <h3 x-text="new Date(_state.time).toLocaleString()"></h3>
-              <p x-text="$truncate(_state.messages[0].content, 80)"></p>
-              <!-- delete button -->
-              <button class="history-delete-button" @click.stop="removeHistory(_state);">
-                <i class=" fas fa-trash"></i>
-              </button>
-            </div>
-          </template>
-        </div>
-        <template x-if="histories.length">
-          <div class="histories-end"></div>
-        </template>
-      </div>
-    </div>
-    <div x-ref="messages" class="messages" x-init="
-      $watch('cstate', value => {
+          " @touchmove="
+            $el.style.setProperty('--tx', $event.changedTouches[0].clientX - otx);
+            $el.style.setProperty('--opacity', 1 - (Math.abs($event.changedTouches[0].clientX - otx) / trigger));
+          " @touchstart="
+            otx = $event.changedTouches[0].clientX;
+          " class="history" x-data="{ otx: 0, trigger: 75 }">
+<h3 x-text="new Date(_state.time).toLocaleString()"></h3>
+<p x-text="$truncate(_state.messages[0].content, 80)"></p>
+<!-- delete button -->
+<button @click.stop="removeHistory(_state);" class="history-delete-button">
+<i class="fas fa-trash"></i>
+</button>
+</div>
+</template>
+</div>
+<template x-if="histories.length">
+<div class="histories-end"></div>
+</template>
+</div>
+</div>
+<div class="messages" x-init="
+      $watch('cstate', value =&gt; {
        $el.innerHTML = '';
-        value.messages.forEach(({ role, content }) => {
+        value.messages.forEach(({ role, content }) =&gt; {
          const div = document.createElement('div');
          div.className = `message message-role-${role}`;
          try {
@@ -115,11 +106,11 @@
 
 
          // add a clipboard button to all code blocks
          const codeBlocks = div.querySelectorAll('.hljs');
-          codeBlocks.forEach(codeBlock => {
+          codeBlocks.forEach(codeBlock =&gt; {
            const button = document.createElement('button');
            button.className = 'clipboard-button';
-            button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
-            button.onclick = () => {
+            button.innerHTML = '&lt;i class=\'fas fa-clipboard\'&gt;&lt;/i&gt;';
+            button.onclick = () =&gt; {
              // navigator.clipboard.writeText(codeBlock.textContent);
              const range = document.createRange();
              range.setStartBefore(codeBlock);
@@ -129,8 +120,8 @@
              document.execCommand('copy');
              window.getSelection()?.removeAllRanges();

-              button.innerHTML = '<i class=\'fas fa-check\'></i>';
-              setTimeout(() => button.innerHTML = '<i class=\'fas fa-clipboard\'></i>', 1000);
+              button.innerHTML = '&lt;i class=\'fas fa-check\'&gt;&lt;/i&gt;';
+              setTimeout(() =&gt; button.innerHTML = '&lt;i class=\'fas fa-clipboard\'&gt;&lt;/i&gt;', 1000);
            };
            codeBlock.appendChild(button);
          });
@@ -142,38 +133,37 @@
      });
    " x-intersect="
      $el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
-    " x-show="home === 2" x-transition>
-    </div>
-    <div class="input-container">
-      <div class="input-performance">
-        <span class="input-performance-point">
-          <p class="monospace" x-text="(time_till_first / 1000).toFixed(2)"></p>
-          <p class="megrim-regular">SEC TO FIRST TOKEN</p>
-        </span>
-        <span class="input-performance-point">
-          <p class="monospace" x-text="tokens_per_second.toFixed(1)"></p>
-          <p class="megrim-regular">TOKENS/SEC</p>
-        </span>
-        <span class="input-performance-point">
-          <p class="monospace" x-text="total_tokens"></p>
-          <p class="megrim-regular">TOKENS</p>
-        </span>
-      </div>
-      <div class="input">
-        <button x-show="cstate.selectedModel === 'llava-1.5-7b-hf'" class="image-input-button" @click="$refs.imageUpload.click()">
-          <i class="fas fa-image"></i>
-        </button>
-        <input x-ref="imageUpload" type="file" id="image-upload" accept="image/*" @change="$data.handleImageUpload($event)" style="display: none;">
-        <div x-show="imagePreview" class="image-preview-container">
-          <img :src="imagePreview" alt="Uploaded Image" class="image-preview">
-          <button @click="imagePreview = null; imageUrl = null;" class="remove-image-button">
-            <i class="fas fa-times"></i>
-          </button>
-        </div>
-        <textarea x-ref="inputForm" id="input-form" class="input-form" autofocus rows=1 x-autosize
-          :placeholder="generating ? 'Generating...' : 'Say something'" :disabled="generating" @input="
+    " x-ref="messages" x-show="home === 2" x-transition="">
+</div>
+<div class="input-container">
+<div class="input-performance">
+<span class="input-performance-point">
+<p class="monospace" x-text="(time_till_first / 1000).toFixed(2)"></p>
+<p class="megrim-regular">SEC TO FIRST TOKEN</p>
+</span>
+<span class="input-performance-point">
+<p class="monospace" x-text="tokens_per_second.toFixed(1)"></p>
+<p class="megrim-regular">TOKENS/SEC</p>
+</span>
+<span class="input-performance-point">
+<p class="monospace" x-text="total_tokens"></p>
+<p class="megrim-regular">TOKENS</p>
+</span>
+</div>
+<div class="input">
+<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf'">
+<i class="fas fa-image"></i>
+</button>
+<input @change="$data.handleImageUpload($event)" accept="image/*" id="image-upload" style="display: none;" type="file" x-ref="imageUpload"/>
+<div class="image-preview-container" x-show="imagePreview">
+<img :src="imagePreview" alt="Uploaded Image" class="image-preview"/>
+<button @click="imagePreview = null; imageUrl = null;" class="remove-image-button">
+<i class="fas fa-times"></i>
+</button>
+</div>
+<textarea :disabled="generating" :placeholder="generating ? 'Generating...' : 'Say something'" @input="
            home = (home === 0) ? 1 : home
-            if (cstate.messages.length === 0 && $el.value === '') home = -1;
+            if (cstate.messages.length === 0 &amp;&amp; $el.value === '') home = -1;
 
 
            if ($el.value !== '') {
              const messages = [...cstate.messages];
@@ -183,19 +173,17 @@
              if (cstate.messages.length === 0) total_tokens = 0;
              // else updateTotalTokens(cstate.messages);
            }
-          " x-effect="
+          " @keydown.enter="await handleEnter($event)" @keydown.escape.window="$focus.focus($el)" autofocus="" class="input-form" id="input-form" rows="1" x-autosize="" x-effect="
            console.log(generating);
-            if (!generating) $nextTick(() => {
+            if (!generating) $nextTick(() =&gt; {
              $el.focus();
-              setTimeout(() => $refs.messages.scrollTo({ top: $refs.messages.scrollHeight, behavior: 'smooth' }), 100);
+              setTimeout(() =&gt; $refs.messages.scrollTo({ top: $refs.messages.scrollHeight, behavior: 'smooth' }), 100);
            });
-          " @keydown.enter="await handleEnter($event)" @keydown.escape.window="$focus.focus($el)"></textarea>
-        <button class="input-button" :disabled="generating" @click="await handleSend()">
-          <i class="fas" :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'"></i>
-        </button>
-      </div>
-    </div>
-  </main>
+          " x-ref="inputForm"></textarea>
+<button :disabled="generating" @click="await handleSend()" class="input-button">
+<i :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'" class="fas"></i>
+</button>
+</div>
+</div>
+</main>
 </body>
 </body>
-
-</html>

Diff file is too large
+ 0 - 0
tinychat/examples/tinychat/static/cdn.jsdelivr.net/npm/@alpine-collective/toolkit@1.0.2/dist/cdn.min.js


Diff file is too large
+ 0 - 0
tinychat/examples/tinychat/static/cdn.jsdelivr.net/npm/@alpinejs/focus@3.x.x/dist/cdn.min.js


+ 1 - 0
tinychat/examples/tinychat/static/cdn.jsdelivr.net/npm/@alpinejs/intersect@3.x.x/dist/cdn.min.js

@@ -0,0 +1 @@
+(()=>{function o(e){e.directive("intersect",e.skipDuringClone((t,{value:i,expression:l,modifiers:n},{evaluateLater:r,cleanup:c})=>{let s=r(l),a={rootMargin:x(n),threshold:f(n)},u=new IntersectionObserver(d=>{d.forEach(h=>{h.isIntersecting!==(i==="leave")&&(s(),n.includes("once")&&u.disconnect())})},a);u.observe(t),c(()=>{u.disconnect()})}))}function f(e){if(e.includes("full"))return .99;if(e.includes("half"))return .5;if(!e.includes("threshold"))return 0;let t=e[e.indexOf("threshold")+1];return t==="100"?1:t==="0"?0:Number(`.${t}`)}function p(e){let t=e.match(/^(-?[0-9]+)(px|%)?$/);return t?t[1]+(t[2]||"px"):void 0}function x(e){let t="margin",i="0px 0px 0px 0px",l=e.indexOf(t);if(l===-1)return i;let n=[];for(let r=1;r<5;r++)n.push(p(e[l+r]||""));return n=n.filter(r=>r!==void 0),n.length?n.join(" ").trim():i}document.addEventListener("alpine:init",()=>{window.Alpine.plugin(o)});})();

+ 11 - 0
tinychat/examples/tinychat/static/cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css

@@ -0,0 +1,11 @@
+/*!
+Pure v3.0.0
+Copyright 2013 Yahoo!
+Licensed under the BSD License.
+https://github.com/pure-css/pure/blob/master/LICENSE
+*/
+/*!
+normalize.css v | MIT License | https://necolas.github.io/normalize.css/
+Copyright (c) Nicolas Gallagher and Jonathan Neal
+*/
+/*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */html{line-height:1.15;-webkit-text-size-adjust:100%}body{margin:0}main{display:block}h1{font-size:2em;margin:.67em 0}hr{box-sizing:content-box;height:0;overflow:visible}pre{font-family:monospace,monospace;font-size:1em}a{background-color:transparent}abbr[title]{border-bottom:none;text-decoration:underline;-webkit-text-decoration:underline dotted;text-decoration:underline dotted}b,strong{font-weight:bolder}code,kbd,samp{font-family:monospace,monospace;font-size:1em}small{font-size:80%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sub{bottom:-.25em}sup{top:-.5em}img{border-style:none}button,input,optgroup,select,textarea{font-family:inherit;font-size:100%;line-height:1.15;margin:0}button,input{overflow:visible}button,select{text-transform:none}[type=button],[type=reset],[type=submit],button{-webkit-appearance:button}[type=button]::-moz-focus-inner,[type=reset]::-moz-focus-inner,[type=submit]::-moz-focus-inner,button::-moz-focus-inner{border-style:none;padding:0}[type=button]:-moz-focusring,[type=reset]:-moz-focusring,[type=submit]:-moz-focusring,button:-moz-focusring{outline:1px dotted ButtonText}fieldset{padding:.35em .75em .625em}legend{box-sizing:border-box;color:inherit;display:table;max-width:100%;padding:0;white-space:normal}progress{vertical-align:baseline}textarea{overflow:auto}[type=checkbox],[type=radio]{box-sizing:border-box;padding:0}[type=number]::-webkit-inner-spin-button,[type=number]::-webkit-outer-spin-button{height:auto}[type=search]{-webkit-appearance:textfield;outline-offset:-2px}[type=search]::-webkit-search-decoration{-webkit-appearance:none}::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}details{display:block}summary{display:list-item}template{display:none}[hidden]{display:none}html{font-family:sans-serif}.hidden,[hidden]{display:none!important}.pure-img{max-width:100%;height:auto;display:block}

Diff file is too large
+ 5 - 0
tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css


BIN
tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.ttf


BIN
tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.woff2


BIN
tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.ttf


BIN
tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.woff2


BIN
tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.ttf


BIN
tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.woff2


BIN
tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.ttf


BIN
tinychat/examples/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.woff2


+ 7 - 0
tinychat/examples/tinychat/static/fonts.googleapis.com/css2

@@ -0,0 +1,7 @@
+@font-face {
+  font-family: 'Megrim';
+  font-style: normal;
+  font-weight: 400;
+  font-display: swap;
+  src: url(https://fonts.gstatic.com/s/megrim/v16/46kulbz5WjvLqJZlbQ.ttf) format('truetype');
+}

Diff file is too large
+ 316 - 0
tinychat/examples/tinychat/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/highlight.min.js


+ 1 - 0
tinychat/examples/tinychat/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css

@@ -0,0 +1 @@
+pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}.hljs{background:#1e1e1e;color:#dcdcdc}.hljs-keyword,.hljs-literal,.hljs-name,.hljs-symbol{color:#569cd6}.hljs-link{color:#569cd6;text-decoration:underline}.hljs-built_in,.hljs-type{color:#4ec9b0}.hljs-class,.hljs-number{color:#b8d7a3}.hljs-meta .hljs-string,.hljs-string{color:#d69d85}.hljs-regexp,.hljs-template-tag{color:#9a5334}.hljs-formula,.hljs-function,.hljs-params,.hljs-subst,.hljs-title{color:#dcdcdc}.hljs-comment,.hljs-quote{color:#57a64a;font-style:italic}.hljs-doctag{color:#608b4e}.hljs-meta,.hljs-meta .hljs-keyword,.hljs-tag{color:#9b9b9b}.hljs-template-variable,.hljs-variable{color:#bd63c5}.hljs-attr,.hljs-attribute{color:#9cdcfe}.hljs-section{color:gold}.hljs-emphasis{font-style:italic}.hljs-strong{font-weight:700}.hljs-bullet,.hljs-selector-attr,.hljs-selector-class,.hljs-selector-id,.hljs-selector-pseudo,.hljs-selector-tag{color:#d7ba7d}.hljs-addition{background-color:#144212;display:inline-block;width:100%}.hljs-deletion{background-color:#600;display:inline-block;width:100%}

Diff file is too large
+ 0 - 0
tinychat/examples/tinychat/static/unpkg.com/@marcreichel/alpine-autosize@1.3.x/dist/alpine-autosize.min.js


Diff file is too large
+ 0 - 0
tinychat/examples/tinychat/static/unpkg.com/alpinejs@3.x.x/dist/cdn.min.js


Diff file is too large
+ 1 - 0
tinychat/examples/tinychat/static/unpkg.com/dompurify@3.1.5/dist/purify.min.js


+ 97 - 0
tinychat/examples/tinychat/static/unpkg.com/marked-highlight@2.1.2/lib/index.umd.js

@@ -0,0 +1,97 @@
+(function (global, factory) {
+  typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) :
+  typeof define === 'function' && define.amd ? define(['exports'], factory) :
+  (global = typeof globalThis !== 'undefined' ? globalThis : global || self, factory(global.markedHighlight = {}));
+})(this, (function (exports) { 'use strict';
+
+  function markedHighlight(options) {
+    if (typeof options === 'function') {
+      options = {
+        highlight: options
+      };
+    }
+
+    if (!options || typeof options.highlight !== 'function') {
+      throw new Error('Must provide highlight function');
+    }
+
+    if (typeof options.langPrefix !== 'string') {
+      options.langPrefix = 'language-';
+    }
+
+    return {
+      async: !!options.async,
+      walkTokens(token) {
+        if (token.type !== 'code') {
+          return;
+        }
+
+        const lang = getLang(token.lang);
+
+        if (options.async) {
+          return Promise.resolve(options.highlight(token.text, lang, token.lang || '')).then(updateToken(token));
+        }
+
+        const code = options.highlight(token.text, lang, token.lang || '');
+        if (code instanceof Promise) {
+          throw new Error('markedHighlight is not set to async but the highlight function is async. Set the async option to true on markedHighlight to await the async highlight function.');
+        }
+        updateToken(token)(code);
+      },
+      useNewRenderer: true,
+      renderer: {
+        code({ text, lang, escaped }) {
+          const language = getLang(lang);
+          const classAttr = language
+            ? ` class="${options.langPrefix}${escape(language)}"`
+            : '';
+          text = text.replace(/\n$/, '');
+          return `<pre><code${classAttr}>${escaped ? text : escape(text, true)}\n</code></pre>`;
+        }
+      }
+    };
+  }
+
+  function getLang(lang) {
+    return (lang || '').match(/\S*/)[0];
+  }
+
+  function updateToken(token) {
+    return (code) => {
+      if (typeof code === 'string' && code !== token.text) {
+        token.escaped = true;
+        token.text = code;
+      }
+    };
+  }
+
+  // copied from marked helpers
+  const escapeTest = /[&<>"']/;
+  const escapeReplace = new RegExp(escapeTest.source, 'g');
+  const escapeTestNoEncode = /[<>"']|&(?!(#\d{1,7}|#[Xx][a-fA-F0-9]{1,6}|\w+);)/;
+  const escapeReplaceNoEncode = new RegExp(escapeTestNoEncode.source, 'g');
+  const escapeReplacements = {
+    '&': '&amp;',
+    '<': '&lt;',
+    '>': '&gt;',
+    '"': '&quot;',
+    "'": '&#39;'
+  };
+  const getEscapeReplacement = (ch) => escapeReplacements[ch];
+  function escape(html, encode) {
+    if (encode) {
+      if (escapeTest.test(html)) {
+        return html.replace(escapeReplace, getEscapeReplacement);
+      }
+    } else {
+      if (escapeTestNoEncode.test(html)) {
+        return html.replace(escapeReplaceNoEncode, getEscapeReplacement);
+      }
+    }
+
+    return html;
+  }
+
+  exports.markedHighlight = markedHighlight;
+
+}));

Diff file is too large
+ 5 - 0
tinychat/examples/tinychat/static/unpkg.com/marked@13.0.0/marked.min.js


+ 90 - 0
tinychat/examples/tinychat/update_deps.py

@@ -0,0 +1,90 @@
+import os
+import requests
+from bs4 import BeautifulSoup
+from urllib.parse import urljoin, urlparse
+import re
+
+def download_file(url, local_path):
+    response = requests.get(url)
+    if response.status_code == 200:
+        os.makedirs(os.path.dirname(local_path), exist_ok=True)
+        with open(local_path, 'wb') as f:
+            f.write(response.content)
+        print(f"Downloaded: {local_path}")
+    else:
+        print(response.status_code)
+        print(f"Failed to download: {url}")
+
+def update_html(html_content, base_url):
+    soup = BeautifulSoup(html_content, 'html.parser')
+
+    for tag in soup.find_all(['script', 'link']):
+        if tag.has_attr('src'):
+            url = tag['src']
+        elif tag.has_attr('href'):
+            url = tag['href']
+        else:
+            continue
+
+        if url.startswith(('http://', 'https://')):
+            full_url = url
+        else:
+            full_url = urljoin(base_url, url)
+
+        parsed_url = urlparse(full_url)
+        local_path = os.path.join('static', parsed_url.netloc, parsed_url.path.lstrip('/'))
+
+        download_file(full_url, local_path)
+
+        relative_path = os.path.relpath(local_path, '.')
+        if tag.name == 'script':
+            tag['src'] = "/" + relative_path
+        elif tag.name == 'link':
+            tag['href'] = "/" + relative_path
+
+    return str(soup)
+
+# Read the HTML file
+with open('./index.html', 'r') as f:
+    html_content = f.read()
+
+# Update HTML and download files
+# updated_html = update_html(html_content, 'https://example.com')
+
+# # Write the updated HTML
+# with open('./index.html', 'w') as f:
+#     f.write(updated_html)
+
+print("HTML file updated with local paths.")
+
+# Download Font Awesome CSS and font files
+base_url = "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/"
+css_url = urljoin(base_url, "css/all.min.css")
+output_dir = "static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2"
+
+# Download CSS file
+css_output_path = os.path.join(output_dir, "css", "all.min.css")
+download_file(css_url, css_output_path)
+
+# Parse CSS file for font URLs
+with open(css_output_path, 'r', encoding='utf-8') as f:
+    css_content = f.read()
+
+# Extract font URLs from the CSS content
+font_urls = re.findall(r'url\((.*?\.(?:woff2|ttf))\)', css_content)
+
+print(f"Found {len(font_urls)} font URLs")
+
+# Download font files
+for font_url in font_urls:
+    font_url = font_url.strip('"\'')
+    if font_url.startswith('../'):
+        font_url = font_url[3:]
+
+    # Use base_url instead of urljoin to keep the version number
+    full_url = base_url + font_url
+    relative_path = font_url
+    output_path = os.path.join(output_dir, relative_path)
+    download_file(full_url, output_path)
+
+print("Download complete!")

Some files were not shown because too many files changed in this diff