Pranav Veldurthi 5 months ago
parent
commit
0f10244900
100 files changed, 9482 additions, 0 deletions
  1. build/lib/exo/__init__.py  +1 -0
  2. build/lib/exo/api/__init__.py  +1 -0
  3. build/lib/exo/api/chatgpt_api.py  +539 -0
  4. build/lib/exo/apputil/__init__.py  +1 -0
  5. build/lib/exo/apputil/anim.py  +161 -0
  6. build/lib/exo/download/__init__.py  +0 -0
  7. build/lib/exo/download/download_progress.py  +61 -0
  8. build/lib/exo/download/hf/__init__.py  +0 -0
  9. build/lib/exo/download/hf/hf_helpers.py  +447 -0
  10. build/lib/exo/download/hf/hf_shard_download.py  +79 -0
  11. build/lib/exo/download/shard_download.py  +36 -0
  12. build/lib/exo/helpers.py  +274 -0
  13. build/lib/exo/inference/__init__.py  +0 -0
  14. build/lib/exo/inference/debug_inference_engine.py  +58 -0
  15. build/lib/exo/inference/dummy_inference_engine.py  +34 -0
  16. build/lib/exo/inference/inference_engine.py  +58 -0
  17. build/lib/exo/inference/mlx/__init__.py  +0 -0
  18. build/lib/exo/inference/mlx/models/StableDiffusionPipeline.py  +307 -0
  19. build/lib/exo/inference/mlx/models/__init__.py  +0 -0
  20. build/lib/exo/inference/mlx/models/base.py  +9 -0
  21. build/lib/exo/inference/mlx/models/deepseek_v2.py  +127 -0
  22. build/lib/exo/inference/mlx/models/gemma2.py  +118 -0
  23. build/lib/exo/inference/mlx/models/llama.py  +125 -0
  24. build/lib/exo/inference/mlx/models/llava.py  +585 -0
  25. build/lib/exo/inference/mlx/models/qwen2.py  +128 -0
  26. build/lib/exo/inference/mlx/sharded_inference_engine.py  +77 -0
  27. build/lib/exo/inference/mlx/sharded_utils.py  +256 -0
  28. build/lib/exo/inference/mlx/stateful_model.py  +45 -0
  29. build/lib/exo/inference/mlx/test_sharded_llama.py  +40 -0
  30. build/lib/exo/inference/mlx/test_sharded_llava.py  +64 -0
  31. build/lib/exo/inference/mlx/test_sharded_model.py  +52 -0
  32. build/lib/exo/inference/shard.py  +39 -0
  33. build/lib/exo/inference/test_dummy_inference_engine.py  +53 -0
  34. build/lib/exo/inference/test_inference_engine.py  +56 -0
  35. build/lib/exo/inference/tinygrad/__init__.py  +0 -0
  36. build/lib/exo/inference/tinygrad/inference.py  +99 -0
  37. build/lib/exo/inference/tinygrad/models/__init__.py  +0 -0
  38. build/lib/exo/inference/tinygrad/models/llama.py  +282 -0
  39. build/lib/exo/inference/tinygrad/stateful_model.py  +42 -0
  40. build/lib/exo/inference/tinygrad/tinygrad_helpers.py  +52 -0
  41. build/lib/exo/inference/tokenizers.py  +64 -0
  42. build/lib/exo/main.py  +274 -0
  43. build/lib/exo/models.py  +151 -0
  44. build/lib/exo/networking/__init__.py  +5 -0
  45. build/lib/exo/networking/discovery.py  +17 -0
  46. build/lib/exo/networking/grpc/__init__.py  +0 -0
  47. build/lib/exo/networking/grpc/grpc_peer_handle.py  +173 -0
  48. build/lib/exo/networking/grpc/grpc_server.py  +147 -0
  49. build/lib/exo/networking/grpc/node_service_pb2.py  +16 -0
  50. build/lib/exo/networking/grpc/node_service_pb2_grpc.py  +360 -0
  51. build/lib/exo/networking/manual/__init__.py  +0 -0
  52. build/lib/exo/networking/manual/manual_discovery.py  +71 -0
  53. build/lib/exo/networking/manual/network_topology_config.py  +31 -0
  54. build/lib/exo/networking/manual/test_manual_discovery.py  +103 -0
  55. build/lib/exo/networking/manual/test_network_topology_config.py  +49 -0
  56. build/lib/exo/networking/peer_handle.py  +56 -0
  57. build/lib/exo/networking/server.py  +11 -0
  58. build/lib/exo/networking/tailscale/__init__.py  +0 -0
  59. build/lib/exo/networking/tailscale/tailscale_discovery.py  +178 -0
  60. build/lib/exo/networking/tailscale/tailscale_helpers.py  +125 -0
  61. build/lib/exo/networking/tailscale/test_tailscale_discovery.py  +43 -0
  62. build/lib/exo/networking/udp/__init__.py  +0 -0
  63. build/lib/exo/networking/udp/test_udp_discovery.py  +77 -0
  64. build/lib/exo/networking/udp/udp_discovery.py  +215 -0
  65. build/lib/exo/orchestration/__init__.py  +4 -0
  66. build/lib/exo/orchestration/node.py  +47 -0
  67. build/lib/exo/orchestration/standard_node.py  +488 -0
  68. build/lib/exo/orchestration/test_node.py  +57 -0
  69. build/lib/exo/stats/__init__.py  +0 -0
  70. build/lib/exo/stats/metrics.py  +29 -0
  71. build/lib/exo/test_callbacks.py  +50 -0
  72. build/lib/exo/tinychat/common.css  +130 -0
  73. build/lib/exo/tinychat/favicon.svg  +25 -0
  74. build/lib/exo/tinychat/index.css  +484 -0
  75. build/lib/exo/tinychat/index.html  +255 -0
  76. build/lib/exo/tinychat/index.js  +687 -0
  77. build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/@alpine-collective/toolkit@1.0.2/dist/cdn.min.js  +0 -0
  78. build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/@alpinejs/focus@3.x.x/dist/cdn.min.js  +0 -0
  79. build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/@alpinejs/intersect@3.x.x/dist/cdn.min.js  +1 -0
  80. build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css  +11 -0
  81. build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css  +5 -0
  82. build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.ttf  BIN
  83. build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.woff2  BIN
  84. build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.ttf  BIN
  85. build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.woff2  BIN
  86. build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.ttf  BIN
  87. build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.woff2  BIN
  88. build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.ttf  BIN
  89. build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.woff2  BIN
  90. build/lib/exo/tinychat/static/fonts.googleapis.com/css2  +7 -0
  91. build/lib/exo/tinychat/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/highlight.min.js  +316 -0
  92. build/lib/exo/tinychat/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css  +1 -0
  93. build/lib/exo/tinychat/static/unpkg.com/@marcreichel/alpine-autosize@1.3.x/dist/alpine-autosize.min.js  +0 -0
  94. build/lib/exo/tinychat/static/unpkg.com/alpinejs@3.x.x/dist/cdn.min.js  +0 -0
  95. build/lib/exo/tinychat/static/unpkg.com/dompurify@3.1.5/dist/purify.min.js  +1 -0
  96. build/lib/exo/tinychat/static/unpkg.com/marked-highlight@2.1.2/lib/index.umd.js  +97 -0
  97. build/lib/exo/tinychat/static/unpkg.com/marked@13.0.0/marked.min.js  +5 -0
  98. build/lib/exo/tinychat/update_deps.py  +93 -0
  99. build/lib/exo/topology/__init__.py  +0 -0
  100. build/lib/exo/topology/device_capabilities.py  +217 -0

+ 1 - 0
build/lib/exo/__init__.py

@@ -0,0 +1 @@
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

+ 1 - 0
build/lib/exo/api/__init__.py

@@ -0,0 +1 @@
+from exo.api.chatgpt_api import ChatGPTAPI as ChatGPTAPI

+ 539 - 0
build/lib/exo/api/chatgpt_api.py

@@ -0,0 +1,539 @@
+import uuid
+import time
+import asyncio
+import json
+from pathlib import Path
+from transformers import AutoTokenizer
+from typing import List, Literal, Union, Dict
+from aiohttp import web
+import aiohttp_cors
+import traceback
+import signal
+from exo import DEBUG, VERSION
+from exo.download.download_progress import RepoProgressEvent
+from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
+from exo.inference.tokenizers import resolve_tokenizer
+from exo.orchestration import Node
+from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
+from exo.apputil import create_animation_mp4
+from typing import Callable, Optional
+from PIL import Image
+import numpy as np
+import base64
+from io import BytesIO
+import mlx.core as mx
+import tempfile
+
+class Message:
+  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
+    self.role = role
+    self.content = content
+
+  def to_dict(self):
+    return {"role": self.role, "content": self.content}
+
+
+
+class ChatCompletionRequest:
+  def __init__(self, model: str, messages: List[Message], temperature: float):
+    self.model = model
+    self.messages = messages
+    self.temperature = temperature
+
+  def to_dict(self):
+    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
+
+
+def generate_completion(
+  chat_request: ChatCompletionRequest,
+  tokenizer,
+  prompt: str,
+  request_id: str,
+  tokens: List[int],
+  stream: bool,
+  finish_reason: Union[Literal["length", "stop"], None],
+  object_type: Literal["chat.completion", "text_completion"],
+) -> dict:
+  completion = {
+    "id": f"chatcmpl-{request_id}",
+    "object": object_type,
+    "created": int(time.time()),
+    "model": chat_request.model,
+    "system_fingerprint": f"exo_{VERSION}",
+    "choices": [{
+      "index": 0,
+      "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
+      "logprobs": None,
+      "finish_reason": finish_reason,
+    }],
+  }
+
+  if not stream:
+    completion["usage"] = {
+      "prompt_tokens": len(tokenizer.encode(prompt)),
+      "completion_tokens": len(tokens),
+      "total_tokens": len(tokenizer.encode(prompt)) + len(tokens),
+    }
+
+  choice = completion["choices"][0]
+  if object_type.startswith("chat.completion"):
+    key_name = "delta" if stream else "message"
+    choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
+  elif object_type == "text_completion":
+    choice["text"] = tokenizer.decode(tokens)
+  else:
+    raise ValueError(f"Unsupported response type: {object_type}")
+
+  return completion
+
+
+def remap_messages(messages: List[Message]) -> List[Message]:
+  remapped_messages = []
+  last_image = None
+  for message in messages:
+    if not isinstance(message.content, list):
+      remapped_messages.append(message)
+      continue
+
+    remapped_content = []
+    for content in message.content:
+      if isinstance(content, dict):
+        if content.get("type") in ["image_url", "image"]:
+          image_url = content.get("image_url", {}).get("url") or content.get("image")
+          if image_url:
+            last_image = {"type": "image", "image": image_url}
+            remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
+        else:
+          remapped_content.append(content)
+      else:
+        remapped_content.append(content)
+    remapped_messages.append(Message(role=message.role, content=remapped_content))
+
+  if last_image:
+    # Replace the last image placeholder with the actual image content
+    for message in reversed(remapped_messages):
+      for i, content in enumerate(message.content):
+        if isinstance(content, dict):
+          if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
+            message.content[i] = last_image
+            return remapped_messages
+
+  return remapped_messages
+
+
+def build_prompt(tokenizer, _messages: List[Message]):
+  messages = remap_messages(_messages)
+  prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
+  for message in messages:
+    if not isinstance(message.content, list):
+      continue
+
+  return prompt
+
+
+def parse_message(data: dict):
+  if "role" not in data or "content" not in data:
+    raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
+  return Message(data["role"], data["content"])
+
+
+def parse_chat_request(data: dict, default_model: str):
+  return ChatCompletionRequest(
+    data.get("model", default_model),
+    [parse_message(msg) for msg in data["messages"]],
+    data.get("temperature", 0.0),
+  )
+
+
+class PromptSession:
+  def __init__(self, request_id: str, timestamp: int, prompt: str):
+    self.request_id = request_id
+    self.timestamp = timestamp
+    self.prompt = prompt
+
+class ChatGPTAPI:
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
+    self.node = node
+    self.inference_engine_classname = inference_engine_classname
+    self.response_timeout = response_timeout
+    self.on_chat_completion_request = on_chat_completion_request
+    self.app = web.Application(client_max_size=100*1024*1024)  # 100MB to support image upload
+    self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
+    self.prev_token_lens: Dict[str, int] = {}
+    self.stream_tasks: Dict[str, asyncio.Task] = {}
+    self.default_model = default_model or "llama-3.2-1b"
+
+    cors = aiohttp_cors.setup(self.app)
+    cors_options = aiohttp_cors.ResourceOptions(
+      allow_credentials=True,
+      expose_headers="*",
+      allow_headers="*",
+      allow_methods="*",
+    )
+    cors.add(self.app.router.add_get("/models", self.handle_get_models), {"*": cors_options})
+    cors.add(self.app.router.add_get("/v1/models", self.handle_get_models), {"*": cors_options})
+    cors.add(self.app.router.add_post("/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
+    cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
+    cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
+    cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
+    cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
+    cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
+    cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
+    cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
+    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
+    cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
+    cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
+
+      
+    if "__compiled__" not in globals():
+      self.static_dir = Path(__file__).parent.parent/"tinychat"
+      self.app.router.add_get("/", self.handle_root)
+      self.app.router.add_static("/", self.static_dir, name="static")
+      self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')
+
+    self.app.middlewares.append(self.timeout_middleware)
+    self.app.middlewares.append(self.log_request)
+
+  async def handle_quit(self, request):
+    if DEBUG>=1: print("Received quit signal")
+    response = web.json_response({"detail": "Quit signal received"}, status=200)
+    await response.prepare(request)
+    await response.write_eof()
+    await shutdown(signal.SIGINT, asyncio.get_event_loop(), self.node.server)
+
+  async def timeout_middleware(self, app, handler):
+    async def middleware(request):
+      try:
+        return await asyncio.wait_for(handler(request), timeout=self.response_timeout)
+      except asyncio.TimeoutError:
+        return web.json_response({"detail": "Request timed out"}, status=408)
+
+    return middleware
+
+  async def log_request(self, app, handler):
+    async def middleware(request):
+      if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
+      return await handler(request)
+
+    return middleware
+
+  async def handle_root(self, request):
+    return web.FileResponse(self.static_dir/"index.html")
+
+  async def handle_healthcheck(self, request):
+    return web.json_response({"status": "ok"})
+
+  async def handle_model_support(self, request):
+    return web.json_response({
+      "model pool": {
+        model_name: pretty_name.get(model_name, model_name)
+        for model_name in get_supported_models(self.node.topology_inference_engines_pool)
+      }
+    })
+
+  async def handle_get_models(self, request):
+    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
+
+  async def handle_post_chat_token_encode(self, request):
+    data = await request.json()
+    shard = build_base_shard(self.default_model, self.inference_engine_classname)
+    messages = [parse_message(msg) for msg in data.get("messages", [])]
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
+    return web.json_response({"length": len(build_prompt(tokenizer, messages))})
+
+  async def handle_get_download_progress(self, request):
+    progress_data = {}
+    for node_id, progress_event in self.node.node_download_progress.items():
+      if isinstance(progress_event, RepoProgressEvent):
+        progress_data[node_id] = progress_event.to_dict()
+      else:
+        print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
+    return web.json_response(progress_data)
+
+  async def handle_post_chat_completions(self, request):
+    data = await request.json()
+    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
+    stream = data.get("stream", False)
+    chat_request = parse_chat_request(data, self.default_model)
+    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to default model
+      chat_request.model = self.default_model
+    if not chat_request.model or chat_request.model not in model_cards:
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
+      chat_request.model = self.default_model
+    shard = build_base_shard(chat_request.model, self.inference_engine_classname)
+    if not shard:
+      supported_models = [model for model, info in model_cards.items() if self.inference_engine_classname in info.get("repo", {})]
+      return web.json_response(
+        {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
+        status=400,
+      )
+
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
+    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
+
+    prompt = build_prompt(tokenizer, chat_request.messages)
+    request_id = str(uuid.uuid4())
+    if self.on_chat_completion_request:
+      try:
+        self.on_chat_completion_request(request_id, chat_request, prompt)
+      except Exception as e:
+        if DEBUG >= 2: traceback.print_exc()
+    # request_id = None
+    # match = self.prompts.find_longest_prefix(prompt)
+    # if match and len(prompt) > len(match[1].prompt):
+    #     if DEBUG >= 2:
+    #       print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
+    #     request_id = match[1].request_id
+    #     self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
+    #     # remove the matching prefix from the prompt
+    #     prompt = prompt[len(match[1].prompt):]
+    # else:
+    #   request_id = str(uuid.uuid4())
+    #   self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
+
+    callback_id = f"chatgpt-api-wait-response-{request_id}"
+    callback = self.node.on_token.register(callback_id)
+
+    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
+
+    try:
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
+
+      if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
+
+      if stream:
+        response = web.StreamResponse(
+          status=200,
+          reason="OK",
+          headers={
+            "Content-Type": "text/event-stream",
+            "Cache-Control": "no-cache",
+          },
+        )
+        await response.prepare(request)
+
+        async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
+          prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
+          self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
+          new_tokens = tokens[prev_last_tokens_len:]
+          finish_reason = None
+          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
+                                                                                                                             AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
+          if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
+            new_tokens = new_tokens[:-1]
+            if is_finished:
+              finish_reason = "stop"
+          if is_finished and not finish_reason:
+            finish_reason = "length"
+
+          completion = generate_completion(
+            chat_request,
+            tokenizer,
+            prompt,
+            request_id,
+            new_tokens,
+            stream,
+            finish_reason,
+            "chat.completion",
+          )
+          if DEBUG >= 2: print(f"Streaming completion: {completion}")
+          try:
+            await response.write(f"data: {json.dumps(completion)}\n\n".encode())
+          except Exception as e:
+            if DEBUG >= 2: print(f"Error streaming completion: {e}")
+            if DEBUG >= 2: traceback.print_exc()
+
+        def on_result(_request_id: str, tokens: List[int], is_finished: bool):
+          if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))
+
+          return _request_id == request_id and is_finished
+
+        _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout)
+        if request_id in self.stream_tasks:  # in case there is still a stream task running, wait for it to complete
+          if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
+          try:
+            await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
+          except asyncio.TimeoutError:
+            print("WARNING: Stream task timed out. This should not happen.")
+        await response.write_eof()
+        return response
+      else:
+        _, tokens, _ = await callback.wait(
+          lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
+          timeout=self.response_timeout,
+        )
+
+        finish_reason = "length"
+        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
+        if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
+        if tokens[-1] == eos_token_id:
+          tokens = tokens[:-1]
+          finish_reason = "stop"
+
+        return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
+    except asyncio.TimeoutError:
+      return web.json_response({"detail": "Response generation timed out"}, status=408)
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
+    finally:
+      deregistered_callback = self.node.on_token.deregister(callback_id)
+      if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
+
+  
+  async def handle_post_image_generations(self, request):
+    data = await request.json()
+
+    if DEBUG >= 2: print(f"Handling image generations request from {request.remote}: {data}")
+    stream = data.get("stream", False)
+    model = data.get("model", "")
+    prompt = data.get("prompt", "")
+    image_url = data.get("image_url", "")
+    print(f"model: {model}, prompt: {prompt}, stream: {stream}")
+    shard = build_base_shard(model, self.inference_engine_classname)
+    print(f"shard: {shard}")
+    if not shard:
+        return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
+
+    request_id = str(uuid.uuid4())
+    callback_id = f"chatgpt-api-wait-response-{request_id}"
+    callback = self.node.on_token.register(callback_id)
+    try:
+      if image_url != "" and image_url is not None:
+        img = self.base64_decode(image_url)
+      else:
+        img = None
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
+
+
+      response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
+      await response.prepare(request)
+
+      def get_progress_bar(current_step, total_steps, bar_length=50):
+        # Calculate the percentage of completion
+        percent = float(current_step) / total_steps
+        # Build the arrow for the filled portion of the bar
+        arrow = '-' * int(round(percent * bar_length) - 1) + '>'
+        spaces = ' ' * (bar_length - len(arrow))
+        
+        # Create the progress bar string
+        progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
+        return progress_bar
+
+      async def stream_image(_request_id: str, result, is_finished: bool):
+          if isinstance(result, list):
+              await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
+
+          elif isinstance(result, np.ndarray):
+            im = Image.fromarray(np.array(result))
+            images_folder = get_exo_images_dir()
+            # Save the image to a file
+            image_filename = f"{_request_id}.png"
+            image_path = images_folder / image_filename
+            im.save(image_path)
+            image_url = request.app.router['static_images'].url_for(filename=image_filename)
+            base_url = f"{request.scheme}://{request.host}"
+            # Construct the full URL correctly
+            full_image_url = base_url + str(image_url)
+            
+            await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+            if is_finished:
+              await response.write_eof()
+              
+
+      stream_task = None
+      def on_result(_request_id: str, result, is_finished: bool):
+          nonlocal stream_task
+          stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
+          return _request_id == request_id and is_finished
+
+      await callback.wait(on_result, timeout=self.response_timeout*10)
+      
+      if stream_task:
+          # Wait for the stream task to complete before returning
+          await stream_task
+
+      return response
+
+    except Exception as e:
+        if DEBUG >= 2: traceback.print_exc()
+        return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
+  
+  async def handle_create_animation(self, request):
+    try:
+      data = await request.json()
+      replacement_image_path = data.get("replacement_image_path")
+      device_name = data.get("device_name", "Local Device")
+      prompt_text = data.get("prompt", "")
+
+      if DEBUG >= 2: print(f"Creating animation with params: replacement_image={replacement_image_path}, device={device_name}, prompt={prompt_text}")
+
+      if not replacement_image_path:
+        return web.json_response({"error": "replacement_image_path is required"}, status=400)
+
+      # Create temp directory if it doesn't exist
+      tmp_dir = Path(tempfile.gettempdir())/"exo_animations"
+      tmp_dir.mkdir(parents=True, exist_ok=True)
+
+      # Generate unique output filename in temp directory
+      output_filename = f"animation_{uuid.uuid4()}.mp4"
+      output_path = str(tmp_dir/output_filename)
+
+      if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
+
+      # Create the animation
+      create_animation_mp4(
+        replacement_image_path,
+        output_path,
+        device_name,
+        prompt_text
+      )
+
+      return web.json_response({
+        "status": "success",
+        "output_path": output_path
+      })
+
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"error": str(e)}, status=500)
+
+  async def handle_post_download(self, request):
+    try:
+      data = await request.json()
+      model_name = data.get("model")
+      if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
+      if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
+      shard = build_base_shard(model_name, self.inference_engine_classname)
+      if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
+      asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
+
+      return web.json_response({
+        "status": "success",
+        "message": f"Download started for model: {model_name}"
+      })
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"error": str(e)}, status=500)
+
+  async def run(self, host: str = "0.0.0.0", port: int = 52415):
+    runner = web.AppRunner(self.app)
+    await runner.setup()
+    site = web.TCPSite(runner, host, port)
+    await site.start()
+
+  def base64_decode(self, base64_string):
+    #decode and reshape image
+    if base64_string.startswith('data:image'):
+        base64_string = base64_string.split(',')[1]
+    image_data = base64.b64decode(base64_string)
+    img = Image.open(BytesIO(image_data))
+    W, H = (dim - dim % 64 for dim in (img.width, img.height))
+    if W != img.width or H != img.height:
+        print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
+        img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
+    img = mx.array(np.array(img))
+    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
+    img = img[None]
+    return img
+  
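
For context, the ChatGPTAPI class added above exposes an OpenAI-style /v1/chat/completions endpoint, served on port 52415 by default (see run()). A minimal client sketch, not part of this commit, assuming a node is running locally and the default model is available:

import asyncio
import aiohttp

async def main():
  payload = {
    "model": "llama-3.2-1b",  # the API's default model; unknown model names fall back to it
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.0,
    "stream": False,
  }
  async with aiohttp.ClientSession() as session:
    async with session.post("http://localhost:52415/v1/chat/completions", json=payload) as resp:
      print(await resp.json())

asyncio.run(main())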

+ 1 - 0
build/lib/exo/apputil/__init__.py

@@ -0,0 +1 @@
+from exo.apputil.anim import create_animation_mp4

+ 161 - 0
build/lib/exo/apputil/anim.py

@@ -0,0 +1,161 @@
+from PIL import Image, ImageDraw, ImageFont, ImageFilter
+import os
+import numpy as np
+import cv2
+
+def draw_rounded_rectangle(draw, coords, radius, fill):
+  left, top, right, bottom = coords
+  diameter = radius * 2
+  draw.rectangle([left + radius, top, right - radius, bottom], fill=fill)
+  draw.rectangle([left, top + radius, right, bottom - radius], fill=fill)
+  draw.pieslice([left, top, left + diameter, top + diameter], 180, 270, fill=fill)
+  draw.pieslice([right - diameter, top, right, top + diameter], 270, 360, fill=fill)
+  draw.pieslice([left, bottom - diameter, left + diameter, bottom], 90, 180, fill=fill)
+  draw.pieslice([right - diameter, bottom - diameter, right, bottom], 0, 90, fill=fill)
+
+def draw_centered_text_rounded(draw, text, font, rect_coords, radius=10, text_color="yellow", bg_color=(43,33,44)):
+  bbox = font.getbbox(text)
+  text_width = bbox[2] - bbox[0]
+  text_height = bbox[3] - bbox[1]
+  rect_left, rect_top, rect_right, rect_bottom = rect_coords
+  rect_width = rect_right - rect_left
+  rect_height = rect_bottom - rect_top
+  text_x = rect_left + (rect_width - text_width) // 2
+  text_y = rect_top + (rect_height - text_height) // 2
+  draw_rounded_rectangle(draw, rect_coords, radius, bg_color)
+  draw.text((text_x, text_y), text, fill=text_color, font=font)
+
+def draw_left_aligned_text_rounded(draw, text, font, rect_coords, padding_left=20, radius=10, text_color="yellow", bg_color=(43,33,44)):
+  bbox = font.getbbox(text)
+  text_height = bbox[3] - bbox[1]
+  rect_left, rect_top, rect_right, rect_bottom = rect_coords
+  rect_height = rect_bottom - rect_top
+  text_y = rect_top + (rect_height - text_height) // 2
+  text_x = rect_left + padding_left
+  draw_rounded_rectangle(draw, rect_coords, radius, bg_color)
+  draw.text((text_x, text_y), text, fill=text_color, font=font)
+
+def draw_right_text_dynamic_width_rounded(draw, text, font, base_coords, padding=20, radius=10, text_color="yellow", bg_color=(43,33,44)):
+  bbox = font.getbbox(text)
+  text_width = bbox[2] - bbox[0]
+  text_height = bbox[3] - bbox[1]
+  _, rect_top, rect_right, rect_bottom = base_coords
+  rect_height = rect_bottom - rect_top
+  new_rect_left = rect_right - (text_width + (padding * 2))
+  text_y = rect_top + (rect_height - text_height) // 2
+  text_x = new_rect_left + padding
+  draw_rounded_rectangle(draw, (new_rect_left, rect_top, rect_right, rect_bottom), radius, bg_color)
+  draw.text((text_x, text_y), text, fill=text_color, font=font)
+  return new_rect_left
+
+def draw_progress_bar(draw, progress, coords, color="yellow", bg_color=(70, 70, 70)):
+  left, top, right, bottom = coords
+  total_width = right - left
+  draw.rectangle(coords, fill=bg_color)
+  progress_width = int(total_width * progress)
+  if progress_width > 0:
+    draw.rectangle((left, top, left + progress_width, bottom), fill=color)
+
+def crop_image(image, top_crop=70):
+  width, height = image.size
+  return image.crop((0, top_crop, width, height))
+
+def create_animation_mp4(
+  replacement_image_path,
+  output_path,
+  device_name,
+  prompt_text,
+  fps=30,
+  target_size=(512, 512),
+  target_position=(139, 755),
+  progress_coords=(139, 1285, 655, 1295),
+  device_coords=(1240, 370, 1640, 416),
+  prompt_coords=(332, 1702, 2662, 1745)
+):
+  frames = []
+  try:
+    font = ImageFont.truetype("/System/Library/Fonts/SFNSMono.ttf", 20)
+    promptfont = ImageFont.truetype("/System/Library/Fonts/SFNSMono.ttf", 24)
+  except Exception:
+    font = ImageFont.load_default()
+    promptfont = ImageFont.load_default()
+
+  # Process first frame
+  base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image1.png"))
+  draw = ImageDraw.Draw(base_img)
+  draw_centered_text_rounded(draw, device_name, font, device_coords)
+  frames.extend([crop_image(base_img)] * 30)  # 1 second at 30fps
+
+  # Process second frame with typing animation
+  base_img2 = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image2.png"))
+  for i in range(len(prompt_text) + 1):
+    current_frame = base_img2.copy()
+    draw = ImageDraw.Draw(current_frame)
+    draw_centered_text_rounded(draw, device_name, font, device_coords)
+    if i > 0:  # Only draw if we have at least one character
+      draw_left_aligned_text_rounded(draw, prompt_text[:i], promptfont, prompt_coords)
+    frames.extend([crop_image(current_frame)] * 2)  # 2 frames per character for smooth typing
+  
+  # Hold the complete prompt for a moment
+  frames.extend([frames[-1]] * 30)  # Hold for 1 second
+
+  # Create blur sequence
+  replacement_img = Image.open(replacement_image_path)
+  base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image3.png"))
+  blur_steps = [int(80 * (1 - i/8)) for i in range(9)]
+
+  for i, blur_amount in enumerate(blur_steps):
+    new_frame = base_img.copy()
+    draw = ImageDraw.Draw(new_frame)
+
+    replacement_copy = replacement_img.copy()
+    replacement_copy.thumbnail(target_size, Image.Resampling.LANCZOS)
+    if blur_amount > 0:
+      replacement_copy = replacement_copy.filter(ImageFilter.GaussianBlur(radius=blur_amount))
+
+    mask = replacement_copy.split()[-1] if replacement_copy.mode in ('RGBA', 'LA') else None
+    new_frame.paste(replacement_copy, target_position, mask)
+
+    draw_progress_bar(draw, (i + 1) / 9, progress_coords)
+    draw_centered_text_rounded(draw, device_name, font, device_coords)
+    draw_right_text_dynamic_width_rounded(draw, prompt_text, promptfont, (None, 590, 2850, 685), padding=30)
+
+    frames.extend([crop_image(new_frame)] * 15)  # 0.5 seconds at 30fps
+
+  # Create and add final frame (image4)
+  final_base = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image4.png"))
+  draw = ImageDraw.Draw(final_base)
+
+  draw_centered_text_rounded(draw, device_name, font, device_coords)
+  draw_right_text_dynamic_width_rounded(draw, prompt_text, promptfont, (None, 590, 2850, 685), padding=30)
+
+  replacement_copy = replacement_img.copy()
+  replacement_copy.thumbnail(target_size, Image.Resampling.LANCZOS)
+  mask = replacement_copy.split()[-1] if replacement_copy.mode in ('RGBA', 'LA') else None
+  final_base.paste(replacement_copy, target_position, mask)
+
+  frames.extend([crop_image(final_base)] * 30)  # 1 second at 30fps
+
+  # Convert frames to video using H.264 codec
+  if frames:
+    first_frame = np.array(frames[0])
+    height, width = first_frame.shape[:2]
+    fourcc = cv2.VideoWriter_fourcc(*'avc1')
+    out = cv2.VideoWriter(
+      output_path,
+      fourcc,
+      fps,
+      (width, height),
+      isColor=True
+    )
+
+    if not out.isOpened():
+      print("Error: VideoWriter failed to open")
+      return
+
+    for frame in frames:
+      frame_array = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
+      out.write(frame_array)
+    
+    out.release()
+    print(f"Video saved successfully to {output_path}")
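
A minimal usage sketch for create_animation_mp4 above (not part of this commit; the input image path, device name and prompt are placeholders, and the bundled base images are expected to be installed alongside the module):

from exo.apputil import create_animation_mp4

create_animation_mp4(
  "/tmp/replacement.png",            # hypothetical replacement image
  "/tmp/exo_animation.mp4",          # output video, written with the H.264 (avc1) codec
  "MacBook Pro",                     # device name rendered onto each frame
  "a watercolor painting of a fox",  # prompt text revealed by the typing animation
)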

+ 0 - 0
build/lib/exo/download/__init__.py


+ 61 - 0
build/lib/exo/download/download_progress.py

@@ -0,0 +1,61 @@
+from typing import Dict, Callable, Coroutine, Any, Literal
+from dataclasses import dataclass
+from datetime import timedelta
+
+
+@dataclass
+class RepoFileProgressEvent:
+  repo_id: str
+  repo_revision: str
+  file_path: str
+  downloaded: int
+  downloaded_this_session: int
+  total: int
+  speed: int
+  eta: timedelta
+  status: Literal["not_started", "in_progress", "complete"]
+
+  def to_dict(self):
+    return {
+      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
+      "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
+    }
+
+  @classmethod
+  def from_dict(cls, data):
+    if 'eta' in data: data['eta'] = timedelta(seconds=data['eta'])
+    return cls(**data)
+
+
+@dataclass
+class RepoProgressEvent:
+  repo_id: str
+  repo_revision: str
+  completed_files: int
+  total_files: int
+  downloaded_bytes: int
+  downloaded_bytes_this_session: int
+  total_bytes: int
+  overall_speed: int
+  overall_eta: timedelta
+  file_progress: Dict[str, RepoFileProgressEvent]
+  status: Literal["not_started", "in_progress", "complete"]
+
+  def to_dict(self):
+    return {
+      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
+      "downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
+      "file_progress": {k: v.to_dict()
+                        for k, v in self.file_progress.items()}, "status": self.status
+    }
+
+  @classmethod
+  def from_dict(cls, data):
+    if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta'])
+    if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}
+
+    return cls(**data)
+
+
+RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]]
+RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]]
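
A small round-trip sketch (not part of this commit) showing how these events serialize: to_dict() flattens eta into seconds and from_dict() restores the timedelta, which is also the form in which the /v1/download/progress endpoint above reports progress. The repo ID below is a placeholder.

from datetime import timedelta
from exo.download.download_progress import RepoFileProgressEvent

event = RepoFileProgressEvent(
  repo_id="some-org/some-model",  # placeholder repo ID
  repo_revision="main",
  file_path="model.safetensors",
  downloaded=1024,
  downloaded_this_session=1024,
  total=4096,
  speed=512,
  eta=timedelta(seconds=6),
  status="in_progress",
)
assert RepoFileProgressEvent.from_dict(event.to_dict()) == event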

+ 0 - 0
build/lib/exo/download/hf/__init__.py


+ 447 - 0
build/lib/exo/download/hf/hf_helpers.py

@@ -0,0 +1,447 @@
+import aiofiles.os as aios
+from typing import Union
+import asyncio
+import aiohttp
+import json
+import os
+import sys
+import shutil
+from urllib.parse import urljoin
+from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
+from datetime import datetime, timedelta
+from fnmatch import fnmatch
+from pathlib import Path
+from typing import Generator, Iterable, TypeVar, TypedDict
+from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
+from exo.helpers import DEBUG, is_frozen
+from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
+from exo.inference.shard import Shard
+import aiofiles
+
+T = TypeVar("T")
+
+async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
+  refs_dir = get_repo_root(repo_id)/"refs"
+  refs_file = refs_dir/revision
+  if await aios.path.exists(refs_file):
+    async with aiofiles.open(refs_file, 'r') as f:
+      commit_hash = (await f.read()).strip()
+      snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash
+      return snapshot_dir
+  return None
+
+
+def filter_repo_objects(
+  items: Iterable[T],
+  *,
+  allow_patterns: Optional[Union[List[str], str]] = None,
+  ignore_patterns: Optional[Union[List[str], str]] = None,
+  key: Optional[Callable[[T], str]] = None,
+) -> Generator[T, None, None]:
+  if isinstance(allow_patterns, str):
+    allow_patterns = [allow_patterns]
+  if isinstance(ignore_patterns, str):
+    ignore_patterns = [ignore_patterns]
+  if allow_patterns is not None:
+    allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
+  if ignore_patterns is not None:
+    ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
+
+  if key is None:
+
+    def _identity(item: T) -> str:
+      if isinstance(item, str):
+        return item
+      if isinstance(item, Path):
+        return str(item)
+      raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
+
+    key = _identity
+
+  for item in items:
+    path = key(item)
+    if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
+      continue
+    if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
+      continue
+    yield item
+
+
+def _add_wildcard_to_directories(pattern: str) -> str:
+  if pattern[-1] == "/":
+    return pattern + "*"
+  return pattern
+
+
+def get_hf_endpoint() -> str:
+  return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
+
+
+def get_hf_home() -> Path:
+  """Get the Hugging Face home directory."""
+  return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))
+
+
+async def get_hf_token():
+  """Retrieve the Hugging Face token from the user's HF_HOME directory."""
+  token_path = get_hf_home()/"token"
+  if await aios.path.exists(token_path):
+    async with aiofiles.open(token_path, 'r') as f:
+      return (await f.read()).strip()
+  return None
+
+
+async def get_auth_headers():
+  """Get authentication headers if a token is available."""
+  token = await get_hf_token()
+  if token:
+    return {"Authorization": f"Bearer {token}"}
+  return {}
+
+
+def get_repo_root(repo_id: str) -> Path:
+  """Get the root directory for a given repo ID in the Hugging Face cache."""
+  sanitized_repo_id = str(repo_id).replace("/", "--")
+  return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
+
+async def move_models_to_hf(seed_dir: Union[str, Path]):
+  """Move model in resources folder of app to .cache/huggingface/hub"""
+  source_dir = Path(seed_dir)
+  dest_dir = get_hf_home()/"hub"
+  await aios.makedirs(dest_dir, exist_ok=True)  
+  for path in source_dir.iterdir():
+    if path.is_dir() and path.name.startswith("models--"):
+      dest_path = dest_dir / path.name
+      if await aios.path.exists(dest_path):
+        print('Skipping moving model to .cache directory')
+      else:
+        try:
+          await aios.rename(str(path), str(dest_path))
+        except Exception as e:
+          print(f'Error moving model to .cache: {e}')
+    
+    
+    
+async def fetch_file_list(session, repo_id, revision, path=""):
+  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
+  url = f"{api_url}/{path}" if path else api_url
+
+  headers = await get_auth_headers()
+  async with session.get(url, headers=headers) as response:
+    if response.status == 200:
+      data = await response.json()
+      files = []
+      for item in data:
+        if item["type"] == "file":
+          files.append({"path": item["path"], "size": item["size"]})
+        elif item["type"] == "directory":
+          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
+          files.extend(subfiles)
+      return files
+    else:
+      raise Exception(f"Failed to fetch file list: {response.status}")
+
+
+@retry(
+  stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
+)
+async def download_file(
+  session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
+):
+  base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
+  url = urljoin(base_url, file_path)
+  local_path = os.path.join(save_directory, file_path)
+
+  await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
+
+  # Check if file already exists and get its size
+  local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
+
+  headers = await get_auth_headers()
+  if use_range_request:
+    headers["Range"] = f"bytes={local_file_size}-"
+
+  async with session.get(url, headers=headers) as response:
+    total_size = int(response.headers.get('Content-Length', 0))
+    downloaded_size = local_file_size
+    downloaded_this_session = 0
+    mode = 'ab' if use_range_request else 'wb'
+    if downloaded_size == total_size:
+      if DEBUG >= 2: print(f"File already downloaded: {file_path}")
+      if progress_callback:
+        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+      return
+
+    if response.status == 200:
+      # File doesn't support range requests or we're not using them, start from beginning
+      mode = 'wb'
+      downloaded_size = 0
+    elif response.status == 206:
+      # Partial content, resume download
+      content_range = response.headers.get('Content-Range', '')
+      try:
+        total_size = int(content_range.split('/')[-1])
+      except ValueError:
+        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
+    elif response.status == 416:
+      # Range not satisfiable, get the actual file size
+      content_range = response.headers.get('Content-Range', '')
+      try:
+        total_size = int(content_range.split('/')[-1])
+        if downloaded_size == total_size:
+          if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
+          if progress_callback:
+            await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+          return
+      except ValueError:
+        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
+        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
+    else:
+      raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
+
+    if downloaded_size == total_size:
+      print(f"File already downloaded: {file_path}")
+      if progress_callback:
+        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+      return
+
+    DOWNLOAD_CHUNK_SIZE = 32768
+    start_time = datetime.now()
+    async with aiofiles.open(local_path, mode) as f:
+      async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
+        await f.write(chunk)
+        downloaded_size += len(chunk)
+        downloaded_this_session += len(chunk)
+        if progress_callback and total_size:
+          elapsed_time = (datetime.now() - start_time).total_seconds()
+          speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
+          remaining_size = total_size - downloaded_size
+          eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
+          status = "in_progress" if downloaded_size < total_size else "complete"
+          if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
+          await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
+    if DEBUG >= 2: print(f"Downloaded: {file_path}")
+
+
+async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
+  repo_root = get_repo_root(repo_id)
+  refs_dir = repo_root/"refs"
+  refs_file = refs_dir/revision
+
+  # Check if we have a cached commit hash
+  if await aios.path.exists(refs_file):
+    async with aiofiles.open(refs_file, 'r') as f:
+      commit_hash = (await f.read()).strip()
+      if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
+      return commit_hash
+
+  # Fetch the commit hash for the given revision
+  async with aiohttp.ClientSession() as session:
+    api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
+    headers = await get_auth_headers()
+    async with session.get(api_url, headers=headers) as response:
+      if response.status != 200:
+        raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
+      revision_info = await response.json()
+      commit_hash = revision_info['sha']
+
+  # Cache the commit hash
+  await aios.makedirs(refs_dir, exist_ok=True)
+  async with aiofiles.open(refs_file, 'w') as f:
+    await f.write(commit_hash)
+
+  return commit_hash
+
+
+async def download_repo_files(
+  repo_id: str,
+  revision: str = "main",
+  progress_callback: Optional[RepoProgressCallback] = None,
+  allow_patterns: Optional[Union[List[str], str]] = None,
+  ignore_patterns: Optional[Union[List[str], str]] = None,
+  max_parallel_downloads: int = 4
+) -> Path:
+  repo_root = get_repo_root(repo_id)
+  snapshots_dir = repo_root/"snapshots"
+  cachedreqs_dir = repo_root/"cachedreqs"
+
+  # Ensure directories exist
+  await aios.makedirs(snapshots_dir, exist_ok=True)
+  await aios.makedirs(cachedreqs_dir, exist_ok=True)
+
+  # Resolve revision to commit hash
+  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
+
+  # Set up the snapshot directory
+  snapshot_dir = snapshots_dir/commit_hash
+  await aios.makedirs(snapshot_dir, exist_ok=True)
+
+  # Set up the cached file list directory
+  cached_file_list_dir = cachedreqs_dir/commit_hash
+  await aios.makedirs(cached_file_list_dir, exist_ok=True)
+  cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"
+
+  async with aiohttp.ClientSession() as session:
+    # Check if we have a cached file list
+    if await aios.path.exists(cached_file_list_path):
+      async with aiofiles.open(cached_file_list_path, 'r') as f:
+        file_list = json.loads(await f.read())
+      if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
+    else:
+      file_list = await fetch_file_list(session, repo_id, revision)
+      # Cache the file list
+      async with aiofiles.open(cached_file_list_path, 'w') as f:
+        await f.write(json.dumps(file_list))
+      if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
+
+    model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
+    if model_index_exists:
+      allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
+
+    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
+    total_files = len(filtered_file_list)
+    total_bytes = sum(file["size"] for file in filtered_file_list)
+    file_progress: Dict[str, RepoFileProgressEvent] = {
+      file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
+      for file in filtered_file_list
+    }
+    start_time = datetime.now()
+
+    async def download_with_progress(file_info, progress_state):
+      local_path = snapshot_dir/file_info["path"]
+      if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
+        if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
+        progress_state['completed_files'] += 1
+        progress_state['downloaded_bytes'] += file_info["size"]
+        file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
+        if progress_callback:
+          elapsed_time = (datetime.now() - start_time).total_seconds()
+          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
+          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+          status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+          await progress_callback(
+            RepoProgressEvent(
+              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
+              overall_eta, file_progress, status
+            )
+          )
+        return
+
+      async def file_progress_callback(event: RepoFileProgressEvent):
+        progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
+        progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
+        file_progress[event.file_path] = event
+        if progress_callback:
+          elapsed_time = (datetime.now() - start_time).total_seconds()
+          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
+          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+          status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
+          await progress_callback(
+            RepoProgressEvent(
+              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
+              overall_eta, file_progress, status
+            )
+          )
+
+      await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
+      progress_state['completed_files'] += 1
+      file_progress[
+        file_info["path"]
+      ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
+      if progress_callback:
+        elapsed_time = (datetime.now() - start_time).total_seconds()
+        overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
+        remaining_bytes = total_bytes - progress_state['downloaded_bytes']
+        overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+        status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
+        await progress_callback(
+          RepoProgressEvent(
+            repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
+            overall_eta, file_progress, status
+          )
+        )
+
+    progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
+
+    semaphore = asyncio.Semaphore(max_parallel_downloads)
+
+    async def download_with_semaphore(file_info):
+      async with semaphore:
+        await download_with_progress(file_info, progress_state)
+
+    tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
+    await asyncio.gather(*tasks)
+
+  return snapshot_dir
+
+
+async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
+  """
+  Retrieve the weight map from the model.safetensors.index.json file.
+
+  Args:
+    repo_id (str): The Hugging Face repository ID.
+    revision (str): The revision of the repository to use.
+
+  Returns:
+    Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
+  """
+
+  # Download the index file
+  await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
+
+  # Check if the file exists
+  repo_root = get_repo_root(repo_id)
+  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
+  snapshot_dir = repo_root/"snapshots"/commit_hash
+  index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
+
+  if index_file:
+    index_file_path = snapshot_dir/index_file
+    if await aios.path.exists(index_file_path):
+      async with aiofiles.open(index_file_path, 'r') as f:
+        index_data = json.loads(await f.read())
+      return index_data.get("weight_map")
+
+  return None
+
+
+def extract_layer_num(tensor_name: str) -> Optional[int]:
+  # Naive heuristic: return the first all-digit component of the tensor name (e.g. "model.layers.12.self_attn..." -> 12); adjust if the naming convention differs.
+  parts = tensor_name.split('.')
+  for part in parts:
+    if part.isdigit():
+      return int(part)
+  return None
+
+
+def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
+  default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
+  shard_specific_patterns = set()
+  if weight_map:
+    for tensor_name, filename in weight_map.items():
+      layer_num = extract_layer_num(tensor_name)
+      if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
+        shard_specific_patterns.add(filename)
+    sorted_file_names = sorted(weight_map.values())
+    if shard.is_first_layer():
+      shard_specific_patterns.add(sorted_file_names[0])
+    elif shard.is_last_layer():
+      shard_specific_patterns.add(sorted_file_names[-1])
+  else:
+    shard_specific_patterns = set(["*.safetensors"])
+  if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
+  return list(default_patterns | shard_specific_patterns)
+
+async def has_hf_home_read_access() -> bool:
+  hf_home = get_hf_home()
+  try: return await aios.access(hf_home, os.R_OK)
+  except OSError: return False
+
+async def has_hf_home_write_access() -> bool:
+  hf_home = get_hf_home()
+  try: return await aios.access(hf_home, os.W_OK)
+  except OSError: return False
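
A minimal usage sketch for the helpers above (get_weight_map and get_allow_patterns), assuming network access and the aiohttp/aiofiles dependencies this module already imports; the repo id and layer split are placeholder values, not anything prescribed by exo:

# Sketch only: fetch a repo's weight map, then compute the allow-patterns for one shard.
import asyncio

from exo.inference.shard import Shard
from exo.download.hf.hf_helpers import get_weight_map, get_allow_patterns

async def main():
  repo_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"  # example repo id
  shard = Shard(model_id="llama-3.1-8b", start_layer=0, end_layer=15, n_layers=32)
  weight_map = await get_weight_map(repo_id)        # downloads model.safetensors.index.json if needed
  patterns = get_allow_patterns(weight_map, shard)  # json/tokenizer files plus this shard's safetensors
  print(patterns)

asyncio.run(main())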

+ 79 - 0
build/lib/exo/download/hf/hf_shard_download.py

@@ -0,0 +1,79 @@
+import asyncio
+import traceback
+from pathlib import Path
+from typing import Dict, List, Tuple
+from exo.inference.shard import Shard
+from exo.download.shard_download import ShardDownloader
+from exo.download.download_progress import RepoProgressEvent
+from exo.download.hf.hf_helpers import download_repo_files, get_weight_map, get_allow_patterns, get_repo_root
+from exo.helpers import AsyncCallbackSystem, DEBUG
+from exo.models import model_cards, get_repo
+
+
+class HFShardDownloader(ShardDownloader):
+  def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
+    self.quick_check = quick_check
+    self.max_parallel_downloads = max_parallel_downloads
+    self.active_downloads: Dict[Shard, asyncio.Task] = {}
+    self.completed_downloads: Dict[Shard, Path] = {}
+    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    repo_name = get_repo(shard.model_id, inference_engine_name)
+    if shard in self.completed_downloads:
+      return self.completed_downloads[shard]
+    if self.quick_check:
+      repo_root = get_repo_root(repo_name)
+      snapshots_dir = repo_root/"snapshots"
+      if snapshots_dir.exists():
+        visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
+        if visible_dirs:
+          most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime)
+          return most_recent_dir
+
+    # If a download on this shard is already in progress, keep that one
+    for active_shard in self.active_downloads:
+      if active_shard == shard:
+        if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
+        return await self.active_downloads[shard]
+
+    # Cancel any downloads for this model_id on a different shard
+    existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
+    for active_shard in existing_active_shards:
+      if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
+      task = self.active_downloads[active_shard]
+      task.cancel()
+      try:
+        await task
+      except asyncio.CancelledError:
+        pass  # This is expected when cancelling a task
+      except Exception as e:
+        if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
+        traceback.print_exc()
+    self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
+
+    # Start new download
+    download_task = asyncio.create_task(self._download_shard(shard, repo_name))
+    self.active_downloads[shard] = download_task
+    try:
+      path = await download_task
+      self.completed_downloads[shard] = path
+      return path
+    finally:
+      # Ensure the task is removed even if an exception occurs
+      if DEBUG >= 2: print(f"Removing download task for {shard}: {shard in self.active_downloads}")
+      if shard in self.active_downloads:
+        self.active_downloads.pop(shard)
+
+  async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
+    async def wrapped_progress_callback(event: RepoProgressEvent):
+      self._on_progress.trigger_all(shard, event)
+
+    weight_map = await get_weight_map(repo_name)
+    allow_patterns = get_allow_patterns(weight_map, shard)
+
+    return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return self._on_progress
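
A hedged sketch of driving HFShardDownloader directly, outside of a node: the model id and the inference-engine name passed to ensure_shard are assumptions (exo's node code normally supplies both), and the progress callback simply prints each RepoProgressEvent:

# Sketch only: download one shard and watch progress via on_progress.
import asyncio

from exo.inference.shard import Shard
from exo.download.hf.hf_shard_download import HFShardDownloader

async def main():
  downloader = HFShardDownloader(quick_check=False, max_parallel_downloads=4)
  downloader.on_progress.register("cli").on_next(lambda s, ev: print(ev))
  shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16)  # example values
  path = await downloader.ensure_shard(shard, "MLXDynamicShardInferenceEngine")  # engine name is an assumption
  print("snapshot directory:", path)

asyncio.run(main())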

+ 36 - 0
build/lib/exo/download/shard_download.py

@@ -0,0 +1,36 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Tuple
+from pathlib import Path
+from exo.inference.shard import Shard
+from exo.download.download_progress import RepoProgressEvent
+from exo.helpers import AsyncCallbackSystem
+
+
+class ShardDownloader(ABC):
+  @abstractmethod
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    """
+    Ensures that the shard is downloaded.
+    Does not allow multiple overlapping downloads at once: if you try to download a Shard
+    which overlaps a Shard that is already being downloaded, the existing download is
+    cancelled and a new download is started.
+
+    Args:
+      shard (Shard): The shard to download.
+      inference_engine_name (str): The name of the inference engine used on the node hosting the shard.
+    """
+    pass
+
+  @property
+  @abstractmethod
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    pass
+
+
+class NoopShardDownloader(ShardDownloader):
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    return Path("/tmp/noop_shard")
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return AsyncCallbackSystem()
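
The abstract interface above is small enough to implement against local storage; a sketch of a custom downloader that serves shards from a pre-populated directory (the one-folder-per-model layout is an assumption made for illustration):

# Sketch only: a ShardDownloader backed by a local directory instead of Hugging Face.
from pathlib import Path
from typing import Tuple

from exo.inference.shard import Shard
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import AsyncCallbackSystem


class LocalDirShardDownloader(ShardDownloader):
  def __init__(self, root: Path):
    self.root = root
    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
    # Nothing to download; point straight at the local copy of the model.
    return self.root / shard.model_id

  @property
  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
    return self._on_progress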

+ 274 - 0
build/lib/exo/helpers.py

@@ -0,0 +1,274 @@
+import os
+import sys
+import asyncio
+from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
+import socket
+import random
+import platform
+import psutil
+import uuid
+import netifaces
+from pathlib import Path
+import tempfile
+
+DEBUG = int(os.getenv("DEBUG", default="0"))
+DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
+VERSION = "0.0.1"
+
+exo_text = r"""
+  _____  _____  
+ / _ \ \/ / _ \ 
+|  __/>  < (_) |
+ \___/_/\_\___/ 
+    """
+
+
+def get_system_info():
+  if psutil.MACOS:
+    if platform.machine() == "arm64":
+      return "Apple Silicon Mac"
+    if platform.machine() in ["x86_64", "i386"]:
+      return "Intel Mac"
+    return "Unknown Mac architecture"
+  if psutil.LINUX:
+    return "Linux"
+  return "Non-Mac, non-Linux system"
+
+
+def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
+  used_ports_file = os.path.join(tempfile.gettempdir(), "exo_used_ports")
+
+  def read_used_ports():
+    if os.path.exists(used_ports_file):
+      with open(used_ports_file, "r") as f:
+        return [int(line.strip()) for line in f if line.strip().isdigit()]
+    return []
+
+  def write_used_port(port, used_ports):
+    with open(used_ports_file, "w") as f:
+      for p in used_ports[-19:] + [port]:
+        f.write(f"{p}\n")
+
+  used_ports = read_used_ports()
+  available_ports = set(range(min_port, max_port + 1)) - set(used_ports)
+
+  while available_ports:
+    port = random.choice(list(available_ports))
+    if DEBUG >= 2: print(f"Trying to find available port {port=}")
+    try:
+      with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind((host, port))
+      write_used_port(port, used_ports)
+      return port
+    except socket.error:
+      available_ports.remove(port)
+
+  raise RuntimeError("No available ports in the specified range")
+
+
+def print_exo():
+  print(exo_text)
+
+
+def print_yellow_exo():
+  yellow = "\033[93m"  # ANSI escape code for yellow
+  reset = "\033[0m"  # ANSI escape code to reset color
+  print(f"{yellow}{exo_text}{reset}")
+
+
+def terminal_link(uri, label=None):
+  if label is None:
+    label = uri
+  parameters = ""
+
+  # OSC 8 ; params ; URI ST <name> OSC 8 ;; ST
+  escape_mask = "\033]8;{};{}\033\\{}\033]8;;\033\\"
+
+  return escape_mask.format(parameters, uri, label)
+
+
+T = TypeVar("T")
+K = TypeVar("K")
+
+
+class AsyncCallback(Generic[T]):
+  def __init__(self) -> None:
+    self.condition: asyncio.Condition = asyncio.Condition()
+    self.result: Optional[Tuple[T, ...]] = None
+    self.observers: list[Callable[..., None]] = []
+
+  async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]:
+    async with self.condition:
+      await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout)
+      assert self.result is not None  # for type checking
+      return self.result
+
+  def on_next(self, callback: Callable[..., None]) -> None:
+    self.observers.append(callback)
+
+  def set(self, *args: T) -> None:
+    self.result = args
+    for observer in self.observers:
+      observer(*args)
+    asyncio.create_task(self.notify())
+
+  async def notify(self) -> None:
+    async with self.condition:
+      self.condition.notify_all()
+
+
+class AsyncCallbackSystem(Generic[K, T]):
+  def __init__(self) -> None:
+    self.callbacks: Dict[K, AsyncCallback[T]] = {}
+
+  def register(self, name: K) -> AsyncCallback[T]:
+    if name not in self.callbacks:
+      self.callbacks[name] = AsyncCallback[T]()
+    return self.callbacks[name]
+
+  def deregister(self, name: K) -> None:
+    if name in self.callbacks:
+      del self.callbacks[name]
+
+  def trigger(self, name: K, *args: T) -> None:
+    if name in self.callbacks:
+      self.callbacks[name].set(*args)
+
+  def trigger_all(self, *args: T) -> None:
+    for callback in self.callbacks.values():
+      callback.set(*args)
+
+
+K = TypeVar('K', bound=str)
+V = TypeVar('V')
+
+
+class PrefixDict(Generic[K, V]):
+  def __init__(self):
+    self.items: Dict[K, V] = {}
+
+  def add(self, key: K, value: V) -> None:
+    self.items[key] = value
+
+  def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
+    return [(key, value) for key, value in self.items.items() if argument.startswith(key)]
+
+  def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
+    matches = self.find_prefix(argument)
+    if len(matches) == 0:
+      return None
+
+    return max(matches, key=lambda x: len(x[0]))
+
+
+def is_valid_uuid(val):
+  try:
+    uuid.UUID(str(val))
+    return True
+  except ValueError:
+    return False
+
+
+def get_or_create_node_id():
+  NODE_ID_FILE = Path(tempfile.gettempdir())/".exo_node_id"
+  try:
+    if NODE_ID_FILE.is_file():
+      with open(NODE_ID_FILE, "r") as f:
+        stored_id = f.read().strip()
+      if is_valid_uuid(stored_id):
+        if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
+        return stored_id
+      else:
+        if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")
+
+    new_id = str(uuid.uuid4())
+    with open(NODE_ID_FILE, "w") as f:
+      f.write(new_id)
+
+    if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
+    return new_id
+  except IOError as e:
+    if DEBUG >= 2: print(f"IO error creating node_id: {e}")
+    return str(uuid.uuid4())
+  except Exception as e:
+    if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
+    return str(uuid.uuid4())
+
+
+def pretty_print_bytes(size_in_bytes: int) -> str:
+  if size_in_bytes < 1024:
+    return f"{size_in_bytes} B"
+  elif size_in_bytes < 1024**2:
+    return f"{size_in_bytes / 1024:.2f} KB"
+  elif size_in_bytes < 1024**3:
+    return f"{size_in_bytes / (1024 ** 2):.2f} MB"
+  elif size_in_bytes < 1024**4:
+    return f"{size_in_bytes / (1024 ** 3):.2f} GB"
+  else:
+    return f"{size_in_bytes / (1024 ** 4):.2f} TB"
+
+
+def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
+  if bytes_per_second < 1024:
+    return f"{bytes_per_second} B/s"
+  elif bytes_per_second < 1024**2:
+    return f"{bytes_per_second / 1024:.2f} KB/s"
+  elif bytes_per_second < 1024**3:
+    return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
+  elif bytes_per_second < 1024**4:
+    return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
+  else:
+    return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
+
+
+def get_all_ip_addresses():
+  try:
+    ip_addresses = []
+    for interface in netifaces.interfaces():
+      ifaddresses = netifaces.ifaddresses(interface)
+      if netifaces.AF_INET in ifaddresses:
+        for link in ifaddresses[netifaces.AF_INET]:
+          ip = link['addr']
+          ip_addresses.append(ip)
+    return list(set(ip_addresses))
+  except Exception:
+    if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
+    return ["localhost"]
+
+
+async def shutdown(signal, loop, server):
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  for task in server_tasks:
+    task.cancel()
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+
+
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
+
+
+def get_exo_home() -> Path:
+  if os.name == "nt":  # Check if the OS is Windows
+    docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
+  else:
+    docs_folder = Path.home() / "Documents"
+  exo_folder = docs_folder / "Exo"
+  if not exo_folder.exists():
+    exo_folder.mkdir(parents=True, exist_ok=True)
+  return exo_folder
+
+def get_exo_images_dir() -> Path:
+  exo_home = get_exo_home()
+  images_dir = exo_home / "Images"
+  if not images_dir.exists():
+    images_dir.mkdir(parents=True, exist_ok=True)
+  return images_dir
+  
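
The AsyncCallback/AsyncCallbackSystem pair above is the backbone of exo's progress reporting; a small self-contained sketch of the register / on_next / trigger / wait pattern (the "status" key and payload are made up for the demo):

# Sketch only: observe a callback and also await a specific value with a timeout.
import asyncio
from exo.helpers import AsyncCallbackSystem

async def main():
  callbacks = AsyncCallbackSystem[str, str]()
  cb = callbacks.register("status")
  cb.on_next(lambda *args: print("observer got:", args))

  # Trigger from the event loop a bit later, as another task or callback would.
  asyncio.get_running_loop().call_later(0.1, callbacks.trigger, "status", "ready")

  # Block until a value matching the predicate arrives (or the timeout expires).
  result = await cb.wait(lambda value: value == "ready", timeout=1)
  print("wait returned:", result)

asyncio.run(main())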

+ 0 - 0
build/lib/exo/inference/__init__.py


+ 58 - 0
build/lib/exo/inference/debug_inference_engine.py

@@ -0,0 +1,58 @@
+from exo.inference.inference_engine import InferenceEngine
+from exo.inference.shard import Shard
+from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+import asyncio
+import numpy as np
+
+
+# An inference engine should work the same for any number of Shards, as long as the Shards are contiguous.
+async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
+  from exo.inference.tinygrad.inference import Tokenizer
+  from pathlib import Path
+
+  _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))
+
+  prompt = "In a single word only, what is the last name of the president of the United States? "
+  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+  token_full = await inference_engine_1.sample(resp_full)
+
+  next_resp_full = await inference_engine_1.infer_tensor(
+    "A",
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
+    input_data=token_full,
+  )
+
+  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+  resp2 = await inference_engine_2.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
+    input_data=resp1,
+  )
+  token2 = await inference_engine_2.sample(resp2)
+  resp3 = await inference_engine_1.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
+    input_data=token2,
+  )
+  resp4 = await inference_engine_2.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
+    input_data=resp3,
+  )
+
+  print(f"{resp2=}")
+  print(f"full: {_tokenizer.decode(resp_full)}")
+  print(f"next full: {_tokenizer.decode(next_resp_full)}")
+  print(f"resp2: {_tokenizer.decode(resp2)}")
+  print(f"{resp4=}")
+  print(f"resp4: {_tokenizer.decode(resp4)}")
+
+  assert np.array_equal(resp_full, resp2)
+  assert np.array_equal(next_resp_full, resp4)
+
+
+asyncio.run(test_inference_engine(
+  TinygradDynamicShardInferenceEngine(),
+  TinygradDynamicShardInferenceEngine(),
+  "llama3-8b-sfr",
+))

+ 34 - 0
build/lib/exo/inference/dummy_inference_engine.py

@@ -0,0 +1,34 @@
+from typing import Optional, Tuple, TYPE_CHECKING
+import numpy as np
+from exo.inference.inference_engine import InferenceEngine
+from exo.inference.shard import Shard
+from exo.inference.tokenizers import DummyTokenizer
+
+class DummyInferenceEngine(InferenceEngine):
+  def __init__(self):
+    self.shard = None
+    self.vocab_size = 1000
+    self.hidden_size = 256
+    self.eos_token_id = 0
+    self.latency_mean = 0.1
+    self.latency_stddev = 0.02
+    self.num_generate_dummy_tokens = 10
+    self.tokenizer = DummyTokenizer()
+
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    return np.array(self.tokenizer.encode(prompt))
+  
+  async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
+    if x[0] > self.num_generate_dummy_tokens: return np.array([self.tokenizer.eos_token_id])
+    return x
+
+  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
+    return self.tokenizer.decode(tokens)
+
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+    await self.ensure_shard(shard)
+    return input_data + 1 if self.shard.is_last_layer() else input_data
+
+  async def ensure_shard(self, shard: Shard):
+    if self.shard == shard: return
+    self.shard = shard

+ 58 - 0
build/lib/exo/inference/inference_engine.py

@@ -0,0 +1,58 @@
+import numpy as np
+import os
+from exo.helpers import DEBUG  # Make sure to import DEBUG
+
+from typing import Tuple, Optional
+from abc import ABC, abstractmethod
+from .shard import Shard
+
+
+class InferenceEngine(ABC):
+  @abstractmethod
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    pass
+  
+  @abstractmethod
+  async def sample(self, x: np.ndarray) -> np.ndarray:
+    pass
+
+  @abstractmethod
+  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
+    pass
+
+  @abstractmethod
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> Tuple[np.ndarray, Optional[dict]]:
+    pass
+
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> Tuple[np.ndarray, Optional[dict]]:
+    tokens = await self.encode(shard, prompt)
+    if shard.model_id != 'stable-diffusion-2-1-base':
+      x = tokens.reshape(1, -1)
+    else:
+      x = tokens
+    output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
+    return output_data, inference_state
+
+inference_engine_classes = {
+  "mlx": "MLXDynamicShardInferenceEngine",
+  "tinygrad": "TinygradDynamicShardInferenceEngine",
+  "dummy": "DummyInferenceEngine",
+}
+
+def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+  if DEBUG >= 2:
+    print(f"get_inference_engine called with: {inference_engine_name}")
+  if inference_engine_name == "mlx":
+    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+
+    return MLXDynamicShardInferenceEngine(shard_downloader)
+  elif inference_engine_name == "tinygrad":
+    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+    import tinygrad.helpers
+    tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
+
+    return TinygradDynamicShardInferenceEngine(shard_downloader)
+  elif inference_engine_name == "dummy":
+    from exo.inference.dummy_inference_engine import DummyInferenceEngine
+    return DummyInferenceEngine()
+  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")
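
A hedged end-to-end sketch of the engine API using the dummy engine from the factory above. It sticks to encode/infer_tensor/sample/decode (the infer_prompt state handling differs between engine versions), and the shard/model id values are arbitrary:

# Sketch only: run a few dummy decode steps against a single shard covering all layers.
import asyncio

from exo.inference.shard import Shard
from exo.inference.inference_engine import get_inference_engine
from exo.download.shard_download import NoopShardDownloader

async def main():
  engine = get_inference_engine("dummy", NoopShardDownloader())
  shard = Shard(model_id="dummy", start_layer=0, end_layer=7, n_layers=8)

  tokens = await engine.encode(shard, "hello world")
  for _ in range(5):
    out = await engine.infer_tensor("req-1", shard, tokens)
    tokens = await engine.sample(out)
  print(await engine.decode(shard, tokens))

asyncio.run(main())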

+ 0 - 0
build/lib/exo/inference/mlx/__init__.py


+ 307 - 0
build/lib/exo/inference/mlx/models/StableDiffusionPipeline.py

@@ -0,0 +1,307 @@
+# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/__init__.py
+
+import time
+from typing import Optional, Tuple
+import inspect
+
+import mlx.core as mx
+import mlx.nn as nn
+from pathlib import Path
+
+from tqdm import tqdm
+
+from .sd_models.vae import ModelArgs as VAEArgs
+from .sd_models.vae import Autoencoder
+from .sd_models.tokenizer import load_tokenizer
+from .sd_models.clip import CLIPTextModel
+from .sd_models.clip import ModelArgs as CLIPArgs
+from .sd_models.unet import UNetConfig, UNetModel
+
+from dataclasses import dataclass, field
+from exo.inference.shard import Shard
+
+@dataclass
+class DiffusionConfig:
+    beta_schedule: str = "scaled_linear"
+    beta_start: float = 0.00085
+    beta_end: float = 0.012
+    num_train_steps: int = 1000
+
+    @classmethod
+    def from_dict(cls, params):
+        return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+
+#Sampler
+def _linspace(a, b, num):
+    x = mx.arange(0, num) / (num - 1)
+    return (b - a) * x + a
+
+
+def _interp(y, x_new):
+    """Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
+    x_low = x_new.astype(mx.int32)
+    x_high = mx.minimum(x_low + 1, len(y) - 1)
+
+    y_low = y[x_low]
+    y_high = y[x_high]
+    delta_x = x_new - x_low
+    y_new = y_low * (1 - delta_x) + delta_x * y_high
+
+    return y_new
+
+class SimpleEulerSampler:
+    """A simple Euler integrator that can be used to sample from our diffusion models.
+
+    The method ``step()`` performs one Euler step from x_t to x_t_prev.
+    """
+
+    def __init__(self, config: DiffusionConfig):
+        # Compute the noise schedule
+        if config.beta_schedule == "linear":
+            betas = _linspace(
+                config.beta_start, config.beta_end, config.num_train_steps
+            )
+        elif config.beta_schedule == "scaled_linear":
+            betas = _linspace(
+                config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
+            ).square()
+        else:
+            raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
+
+        alphas = 1 - betas
+        alphas_cumprod = mx.cumprod(alphas)
+
+        self._sigmas = mx.concatenate(
+            [mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
+        )
+
+    @property
+    def max_time(self):
+        return len(self._sigmas) - 1
+
+    def sample_prior(self, shape, dtype=mx.float32, key=None):
+        noise = mx.random.normal(shape, key=key)
+        return (
+            noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
+        ).astype(dtype)
+
+    def add_noise(self, x, t, key=None):
+        noise = mx.random.normal(x.shape, key=key)
+        s = self.sigmas(t)
+        return (x + noise * s) * (s.square() + 1).rsqrt()
+
+    def sigmas(self, t):
+        return _interp(self._sigmas, t)
+
+    def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
+        start_time = start_time or (len(self._sigmas) - 1)
+        assert 0 < start_time <= (len(self._sigmas) - 1)
+        steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
+        return list(zip(steps, steps[1:]))
+
+    def current_timestep(self, step, total_steps, start_time=None):
+        if step < total_steps:
+            steps = self.timesteps(total_steps, start_time)
+            return steps[step]
+        else:
+            return mx.array(0),mx.array(0)
+
+    def step(self, eps_pred, x_t, t, t_prev):
+        sigma = self.sigmas(t).astype(eps_pred.dtype)
+        sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
+
+        dt = sigma_prev - sigma
+        x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
+
+        x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
+
+        return x_t_prev
+
+@dataclass
+class ShardConfig:
+    model_id:str
+    start_layer:int
+    end_layer:int
+    n_layers:int
+
+@dataclass
+class StableDiffusionConfig:
+    model_type:str
+    vae:VAEArgs
+    text_encoder:CLIPArgs
+    scheduler:DiffusionConfig
+    unet:UNetConfig
+    shard:ShardConfig
+    
+    @classmethod
+    def from_dict(cls, params):
+        return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+@dataclass
+class ModelArgs(StableDiffusionConfig):
+    shard:Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+    def __post_init__(self):
+        if isinstance(self.shard, dict):
+            self.shard = Shard(**self.shard)
+
+        if not isinstance(self.shard, Shard):
+            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+
+class Model(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.model_type = config.model_type
+        self.config = config
+        self.model_path = config.vae['path'].split('/vae')[0]
+        self.shard = config.shard
+        self.shard_clip, self.shard_encoder, self.shard_unet, self.shard_decoder  = model_shards(config.shard)
+        self.config_clip=CLIPArgs.from_dict(config.text_encoder['config'])
+        if self.shard_clip.start_layer != -1:
+            self.text_encoder = CLIPTextModel(self.config_clip, shard=self.shard_clip)
+        else:
+            self.text_encoder = nn.Identity()    
+        self.tokenizer = load_tokenizer(Path(self.model_path), "vocab.json", "merges.txt")
+        self.diffusion_config = DiffusionConfig.from_dict(config.scheduler['config'])
+        self.sampler = SimpleEulerSampler(self.diffusion_config)
+        if self.shard_unet.start_layer!=-1:
+            self.config_unet = UNetConfig.from_dict(config.unet['config'])
+            self.unet = UNetModel(self.config_unet, self.shard_unet)
+        else:
+            self.unet = nn.Identity()
+        self.config_vae=VAEArgs.from_dict(config.vae['config'])
+        if self.shard_encoder.start_layer != -1:
+            self.encoder=Autoencoder(self.config_vae, self.shard_encoder, "vae_encoder") 
+        else:
+            self.encoder = nn.Identity()            
+        if self.shard_decoder.start_layer != -1:
+            self.decoder=Autoencoder(self.config_vae, self.shard_decoder, "vae_decoder") 
+        else:
+            self.decoder = nn.Identity()            
+
+    def __call__(self,x, step= 0, cfg_weight: float = 7.5,total_steps=50,conditioning=None,mask=None,residual=None,x_t_prev=None,is_finished=False,is_step_finished=False, image=None, strength=0.7, start_step=None):
+        t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
+        is_finished = False
+        is_step_finished = False
+        if t.item()==1000:
+            if self.shard_clip.start_layer == 0:
+                conditioning = x
+            if self.shard_clip.start_layer != -1:
+                conditioning, mask= self.text_encoder(conditioning,mask)
+            seed = int(time.time()) 
+            mx.random.seed(seed)
+            if image is None:
+                if self.shard_encoder.is_last_layer():
+                    x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
+                    x_t_prev=x
+                    start_step = self.sampler.max_time
+            else:
+                if self.shard_encoder.start_layer != -1:
+                    image= self.encoder.encode(image)
+                    if self.shard_encoder.is_last_layer():
+                        start_step = self.sampler.max_time*strength
+                        total_steps = int(total_steps*strength)
+                        image = mx.broadcast_to(image, (1,) + image.shape[1:])
+                        x_t_prev=self.sampler.add_noise(image, mx.array(start_step))
+                        image = None
+                        t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
+        # Perform the denoising loop
+        if self.shard_unet.start_layer != -1:
+            with tqdm(total=total_steps,initial=step+1) as pbar:
+                if step<total_steps:
+                    x = x_t_prev
+                    if self.shard_unet.is_first_layer():
+                        x_t_unet = mx.concatenate([x] * 2, axis=0) if cfg_weight> 1 else x
+                    else:
+                        x_t_unet = x
+                    t_unet = mx.broadcast_to(t, [len(x_t_unet)])
+                    x, residual= self.unet(x_t_unet, t_unet, encoder_x=conditioning, residuals=residual)
+                    if self.shard_unet.is_last_layer():
+                        if cfg_weight > 1:
+                            eps_text, eps_neg = x.split(2)
+                            eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
+                        else:
+                            eps_pred = x
+                        x = self.sampler.step(eps_pred, x_t_prev, t, t_prev)
+                        x_t_prev=x
+                    mx.eval(x)
+                    
+        if self.shard_decoder.is_last_layer():
+            is_step_finished=True
+            if self.shard_decoder.start_layer != -1:
+                x=self.decoder.decode(x)
+            if self.shard_decoder.is_last_layer():
+                x = mx.clip(x / 2 + 0.5, 0, 1)
+                B, H, W, C = x.shape
+                x = x.reshape(1, B // 1, H, W, C).transpose(0, 2, 1, 3, 4)
+                x = x.reshape(1 * H, B // 1 * W, C)
+                x = (x * 255).astype(mx.uint8)
+                if t_prev.item() ==0:
+                    is_finished=True   
+        mx.eval(x)
+         
+        return x, {'conditioning':conditioning, 'mask':mask,'residual':residual,'x_t_prev':x_t_prev,'is_finished':is_finished,'is_step_finished':is_step_finished, 'step':step, 'total_steps':total_steps, 'start_step':start_step, 'image':image}
+    
+
+    def load(self):
+        if self.shard_encoder.start_layer != -1:    
+            vae_weights =  mx.load(self.config_vae.weight_files[0])
+            vae_weights = self.encoder.sanitize(vae_weights)
+            self.encoder.load_weights(list(vae_weights.items()), strict=True)
+        if self.shard_decoder.start_layer != -1:
+            vae_weights =  mx.load(self.config_vae.weight_files[0])
+            vae_weights = self.decoder.sanitize(vae_weights)
+            self.decoder.load_weights(list(vae_weights.items()), strict=True)
+        if self.shard_clip.start_layer != -1:
+            clip_weights = mx.load(self.config_clip.weight_files[0])
+            clip_weights = self.text_encoder.sanitize(clip_weights)
+            self.text_encoder.load_weights(list(clip_weights.items()), strict=True)
+        if self.shard_unet.start_layer !=-1:
+            unet_weights = mx.load(self.config_unet.weight_files[0])
+            unet_weights = self.unet.sanitize(unet_weights)
+            self.unet.load_weights(list(unet_weights.items()), strict=True)
+
+def model_shards(shard:ShardConfig):
+    def create_shard(shard, model_ranges):
+        start_layer = shard.start_layer
+        end_layer = shard.end_layer
+        
+        shards = {}
+        
+        for model_name, (range_start, range_end) in model_ranges.items():
+            if start_layer < range_end and end_layer >= range_start:
+                # Calculate the overlap with the model range
+                overlap_start = max(start_layer, range_start)
+                overlap_end = min(end_layer, range_end - 1)
+
+                # Adjust the layers relative to the model's range
+                relative_start = overlap_start - range_start
+                relative_end = overlap_end - range_start
+                shards[model_name] = Shard(model_name, relative_start, relative_end, range_end - range_start)
+            else:
+                # If no overlap, create a zero-layer shard
+                shards[model_name] = Shard(model_name, -1, -1, range_end - range_start)
+        
+        return shards
+
+    # Layer ranges assigned to each sub-model (31 pipeline layers in total)
+    model_ranges = {
+        'clip': (0, 12),
+        'vae_encoder': (12, 17),
+        'unet': (17, 26),
+        'vae_decoder': (26, 31),
+    }
+
+    # Call the function and get the shards for all models
+    shards = create_shard(shard, model_ranges)
+
+    # Access individual shards
+    shard_clip = shards['clip']
+    shard_encoder = shards['vae_encoder']
+    shard_unet = shards['unet']
+    shard_decoder = shards['vae_decoder']
+    
+    return shard_clip, shard_encoder, shard_unet, shard_decoder
+
+
+
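
The split performed by model_shards is easiest to see on a concrete shard. A sketch, assuming the mlx dependencies this module imports are installed; the 10-20 layer range is arbitrary:

# Sketch only: map one pipeline shard onto the clip / vae_encoder / unet / vae_decoder sub-shards.
from exo.inference.shard import Shard
from exo.inference.mlx.models.StableDiffusionPipeline import model_shards

shard = Shard(model_id="stable-diffusion-2-1-base", start_layer=10, end_layer=20, n_layers=31)
clip, enc, unet, dec = model_shards(shard)
print(clip)  # layers 10-11 of 12: the tail of the CLIP text encoder
print(enc)   # layers 0-4 of 5: the whole VAE encoder
print(unet)  # layers 0-3 of 9: the first part of the UNet
print(dec)   # (-1, -1): no overlap, so the VAE decoder is an identity on this node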

+ 0 - 0
build/lib/exo/inference/mlx/models/__init__.py


+ 9 - 0
build/lib/exo/inference/mlx/models/base.py

@@ -0,0 +1,9 @@
+from typing import Optional
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.cache import KVCache
+
+
+class IdentityBlock(nn.Module):
+  def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array:
+    return x

+ 127 - 0
build/lib/exo/inference/mlx/models/deepseek_v2.py

@@ -0,0 +1,127 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.cache import KVCache
+from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
+from .base import IdentityBlock
+from exo.inference.shard import Shard
+
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+
+class DeepseekV2Model(nn.Module):
+  def __init__(self, config: ModelArgs):
+    super().__init__()
+    self.args = config
+    self.num_hidden_layers = config.num_hidden_layers
+    self.vocab_size = config.vocab_size
+    if self.args.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+        self.layers.append(DeepseekV2DecoderLayer(config, i))
+      else:
+        self.layers.append(IdentityBlock())
+
+    if self.args.shard.is_last_layer():
+      self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+  def __call__(
+    self,
+    x: mx.array,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(x)
+    else:
+      h = x
+
+    mask = None
+    T = h.shape[1]
+    if T > 1:
+      mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
+      mask = mask.astype(h.dtype)
+
+    if cache is None:
+      cache = [None]*len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, config: ModelArgs):
+    super().__init__()
+    self.args = config
+    self.model_type = config.model_type
+    self.model = DeepseekV2Model(config)
+    if self.args.shard.is_last_layer():
+      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache: Optional[KVCache] = None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      return self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
+        shard_state_dict[key] = value
+
+    for l in range(self.args.num_hidden_layers):
+      prefix = f"model.layers.{l}"
+      for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
+        for k in ["weight", "scales", "biases"]:
+          if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
+            to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)]
+            shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return (
+      self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
+      self.args.v_head_dim,
+    )
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads
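
The expert-stacking step inside sanitize above can be hard to picture; a toy numpy sketch of the same transform (expert count, layer index, and shapes are made up; the real code uses mx.stack on weight/scales/biases tensors):

# Sketch only: collapse per-expert gate_proj weights into one stacked switch_mlp tensor.
import numpy as np

n_routed_experts = 4  # example value
weights = {
  f"model.layers.0.mlp.experts.{e}.gate_proj.weight": np.full((2, 2), e, dtype=np.float32)
  for e in range(n_routed_experts)
}

stacked = np.stack([
  weights.pop(f"model.layers.0.mlp.experts.{e}.gate_proj.weight") for e in range(n_routed_experts)
])
weights["model.layers.0.mlp.switch_mlp.gate_proj.weight"] = stacked
print(stacked.shape)  # (4, 2, 2): a new leading axis indexing the experts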

+ 118 - 0
build/lib/exo/inference/mlx/models/gemma2.py

@@ -0,0 +1,118 @@
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import create_attention_mask
+from mlx_lm.models.gemma2 import TransformerBlock, ModelArgs, RMSNorm
+
+from ...shard import Shard
+from .base import IdentityBlock
+
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+
+class GemmaModel(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.vocab_size = args.vocab_size
+    self.num_hidden_layers = args.num_hidden_layers
+    assert self.vocab_size > 0
+    if args.shard.is_first_layer() or args.shard.is_last_layer():
+      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if args.shard.start_layer <= i <= args.shard.end_layer:
+        self.layers.append(TransformerBlock(args=args))
+      else:
+        self.layers.append(IdentityBlock())
+    if args.shard.is_last_layer():
+      self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(inputs)
+      h = h * (self.args.hidden_size**0.5)
+    else:
+      h = inputs
+
+    mask = None
+    if h.ndim > 1 and h.shape[1] > 1:
+      mask = create_attention_mask(h, cache)
+
+    if cache is None:
+      cache = [None]*len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, cache=c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.model_type = args.model_type
+    self.model = GemmaModel(args)
+    if args.shard.is_last_layer():
+      self.final_logit_softcapping = args.final_logit_softcapping
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      out = self.model.embed_tokens.as_linear(out)
+      out = mx.tanh(out / self.final_logit_softcapping)
+      out = out * self.final_logit_softcapping
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif (self.args.shard.is_first_layer() or self.args.shard.is_last_layer()) and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
+        shard_state_dict[key] = value
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return self.args.head_dim
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads
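
The final-logit softcapping applied in Model.__call__ above bounds every logit to roughly +/- final_logit_softcapping; a tiny numpy sketch with made-up values (the real code uses mx.tanh on the lm-head output):

# Sketch only: tanh(logits / cap) * cap squashes large logits toward the cap.
import numpy as np

cap = 30.0  # example final_logit_softcapping
logits = np.array([-100.0, -10.0, 0.0, 10.0, 100.0])
print(np.tanh(logits / cap) * cap)  # approx [-29.9, -9.7, 0.0, 9.7, 29.9]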

+ 125 - 0
build/lib/exo/inference/mlx/models/llama.py

@@ -0,0 +1,125 @@
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import create_attention_mask
+from mlx_lm.models.llama import TransformerBlock, ModelArgs
+
+from ...shard import Shard
+from .base import IdentityBlock
+
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    super().__post_init__()  # Ensure parent initializations are respected
+
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+
+class LlamaModel(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.vocab_size = args.vocab_size
+    self.num_hidden_layers = args.num_hidden_layers
+    assert self.vocab_size > 0
+    if args.shard.is_first_layer() or (args.shard.is_last_layer() and args.tie_word_embeddings):
+      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if args.shard.start_layer <= i <= args.shard.end_layer:
+        self.layers.append(TransformerBlock(args=args))
+      else:
+        self.layers.append(IdentityBlock())
+    if args.shard.is_last_layer():
+      self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(inputs)
+    else:
+      h = inputs
+
+    mask = None
+    if h.ndim > 1 and h.shape[1] > 1:
+      mask = create_attention_mask(h, cache)
+
+    if cache is None:
+      cache = [None]*len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, cache=c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.model_type = args.model_type
+    self.model = LlamaModel(args)
+    if args.shard.is_last_layer():
+      if not args.tie_word_embeddings:
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      if self.args.tie_word_embeddings:
+        out = self.model.embed_tokens.as_linear(out)
+      else:
+        out = self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
+        shard_state_dict[key] = value
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return (self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads)
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads
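
A toy sketch of the key filtering that sanitize above performs for a middle shard (the key names and the 8-15 layer range are illustrative): middle shards keep only their own decoder layers, while embed_tokens, model.norm and lm_head stay with the first and last shards.

# Sketch only: which checkpoint keys a middle shard (layers 8-15 of 32) would keep.
start_layer, end_layer = 8, 15
keys = [
  "model.embed_tokens.weight",
  "model.layers.7.self_attn.q_proj.weight",
  "model.layers.8.self_attn.q_proj.weight",
  "model.layers.15.mlp.up_proj.weight",
  "model.norm.weight",
  "lm_head.weight",
]

def keep(key: str) -> bool:
  if key.startswith("model.layers."):
    return start_layer <= int(key.split(".")[2]) <= end_layer
  return False  # a middle shard is neither the first nor the last layer

print([k for k in keys if keep(k)])
# ['model.layers.8.self_attn.q_proj.weight', 'model.layers.15.mlp.up_proj.weight']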

+ 585 - 0
build/lib/exo/inference/mlx/models/llava.py

@@ -0,0 +1,585 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+import inspect
+from dataclasses import dataclass, field
+from typing import Optional, Dict, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.base import BaseModelArgs, KVCache
+from exo.inference.shard import Shard
+from .base import IdentityBlock
+import numpy as np
+
+
+@dataclass
+class VisionConfig:
+  model_type: str
+  num_hidden_layers: int = 24
+  hidden_size: int = 1024
+  intermediate_size: int = 4096
+  num_attention_heads: int = 16
+  image_size: int = 336
+  patch_size: int = 14
+  projection_dim: int = 768
+  vocab_size: int = 32000
+  num_channels: int = 3
+  layer_norm_eps: float = 1e-5
+
+  @classmethod
+  def from_dict(cls, params):
+    return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+
+class VisionAttention(nn.Module):
+  def __init__(
+    self,
+    dims: int,
+    num_heads: int,
+    query_input_dims: Optional[int] = None,
+    key_input_dims: Optional[int] = None,
+    value_input_dims: Optional[int] = None,
+    value_dims: Optional[int] = None,
+    value_output_dims: Optional[int] = None,
+    bias: bool = False,
+  ):
+    super().__init__()
+
+    if (dims % num_heads) != 0:
+      raise ValueError("The input feature dimensions should be divisible by the "
+                       f"number of heads ({dims} % {num_heads}) != 0")
+
+    query_input_dims = query_input_dims or dims
+    key_input_dims = key_input_dims or dims
+    value_input_dims = value_input_dims or key_input_dims
+    value_dims = value_dims or dims
+    value_output_dims = value_output_dims or dims
+
+    self.num_heads = num_heads
+    self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+    self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+    self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+    self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+  def __call__(self, queries, keys, values, mask=None):
+    queries = self.q_proj(queries)
+    keys = self.k_proj(keys)
+    values = self.v_proj(values)
+
+    num_heads = self.num_heads
+    B, L, D = queries.shape
+    _, S, _ = keys.shape
+    queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+    keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
+    values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+    scale = math.sqrt(1/queries.shape[-1])
+    scores = (queries*scale) @ keys
+    if mask is not None:
+      scores = scores + mask.astype(scores.dtype)
+    scores = mx.softmax(scores, axis=-1)
+    values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+    return self.out_proj(values_hat)
+
+
+class VisionMLP(nn.Module):
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.activation_fn = nn.GELU(approx="fast")
+    self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+    self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+  def __call__(self, x: mx.array) -> mx.array:
+    x = self.activation_fn(self.fc1(x))
+    x = self.fc2(x)
+    return x
+
+
+class VisionEncoderLayer(nn.Module):
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.embed_dim = config.hidden_size
+    self.self_attn = VisionAttention(config.hidden_size, config.num_attention_heads, bias=True)
+    self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+    self.mlp = VisionMLP(config)
+    self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+  def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+    y = self.layer_norm1(x)
+    y = self.self_attn(y, y, y, mask)
+    x = x + y
+    y = self.layer_norm2(x)
+    y = self.mlp(y)
+    return x + y
+
+
+class VisionEncoder(nn.Module):
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+
+class VisionEmbeddings(nn.Module):
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.config = config
+    self.embed_dim = config.hidden_size
+    self.image_size = config.image_size
+    self.patch_size = config.patch_size
+
+    self.class_embedding = mx.zeros((config.hidden_size,))
+
+    self.patch_embedding = nn.Conv2d(
+      in_channels=config.num_channels,
+      out_channels=self.embed_dim,
+      kernel_size=self.patch_size,
+      stride=self.patch_size,
+      bias=False,
+    )
+
+    self.num_patches = (self.image_size // self.patch_size)**2
+    self.num_positions = self.num_patches + 1
+    self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+  def __call__(self, x: mx.array) -> mx.array:
+    batch_size = x.shape[0]
+    patch_embeddings = self.patch_embedding(x)
+    patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+    embed_dim = patch_embeddings.shape[-1]
+    cls_embeddings = mx.broadcast_to(self.class_embedding, (batch_size, 1, embed_dim))
+    embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
+    embeddings += self.position_embedding.weight
+    return embeddings
+
+
+class ClipVisionModel(nn.Module):
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+    self.embeddings = VisionEmbeddings(config)
+    self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
+    self.encoder = VisionEncoder(config)
+    self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+  def __call__(
+    self,
+    x: mx.array,
+    output_hidden_states: Optional[bool] = None,
+  ) -> mx.array:
+    x = self.embeddings(x)
+    x = self.pre_layrnorm(x)
+
+    encoder_states = (x,) if output_hidden_states else None
+
+    for l in self.encoder.layers:
+      x = l(x, mask=None)
+      if output_hidden_states:
+        encoder_states = encoder_states + (x,)
+
+    pooler_output = self.post_layernorm(x[:, 0, :])
+    return pooler_output, x, encoder_states
+
+
+class VisionModel(nn.Module):
+  def __init__(self, config: VisionConfig):
+    super().__init__()
+
+    self.model_type = config.model_type
+    if self.model_type != "clip_vision_model":
+      raise ValueError(f"Unsupported model type: {self.model_type}")
+
+    self.vision_model = ClipVisionModel(config)
+
+  def __call__(self, x: mx.array, output_hidden_states: Optional[bool] = None) -> mx.array:
+    return self.vision_model(x, output_hidden_states)
+
+  def sanitize(self, weights):
+    sanitized_weights = {}
+    for k, v in weights.items():
+      if "position_ids" in k:
+        # Remove unused position_ids
+        continue
+      elif "patch_embedding.weight" in k:
+        # PyTorch conv2d weight tensors have shape:
+        #   [out_channels, in_channels, kH, KW]
+        # MLX conv2d expects the weight be of shape:
+        #   [out_channels, kH, KW, in_channels]
+        sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+      else:
+        sanitized_weights[k] = v
+
+    return sanitized_weights
+
+
+@dataclass
+class TextConfig:
+  model_type: str
+  hidden_size: int = 4096
+  num_hidden_layers: int = 32
+  intermediate_size: int = 11008
+  num_attention_heads: int = 32
+  head_dim: int = None
+  rms_norm_eps: float = 1e-6
+  vocab_size: int = 32000
+  num_key_value_heads: int = None
+  rope_theta: float = 10000
+  rope_traditional: bool = False
+  rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+
+  @classmethod
+  def from_dict(cls, params):
+    return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+  def __post_init__(self):
+    if self.num_key_value_heads is None:
+      self.num_key_value_heads = self.num_attention_heads
+
+    if self.head_dim is None:
+      self.head_dim = self.hidden_size // self.num_attention_heads
+
+    if self.model_type is None:
+      self.model_type = "llama"
+
+    if self.rope_scaling:
+      required_keys = {"factor", "type"}
+      if not all(key in self.rope_scaling for key in required_keys):
+        raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+      if self.rope_scaling["type"] != "linear":
+        raise ValueError("rope_scaling 'type' currently only supports 'linear'")
+
+
+class TextAttention(nn.Module):
+  def __init__(self, config: TextConfig):
+    super().__init__()
+
+    dim = config.hidden_size
+    self.n_heads = n_heads = config.num_attention_heads
+    self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+
+    self.repeats = n_heads // n_kv_heads
+
+    head_dim = config.hidden_size // n_heads
+    self.scale = head_dim**-0.5
+
+    self.q_proj = nn.Linear(dim, n_heads*head_dim, bias=False)
+    self.k_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False)
+    self.v_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False)
+    self.o_proj = nn.Linear(n_heads*head_dim, dim, bias=False)
+
+    rope_scale = (1/config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1)
+    self.rope = nn.RoPE(
+      head_dim,
+      traditional=config.rope_traditional,
+      base=config.rope_theta,
+      scale=rope_scale,
+    )
+
+  def __call__(
+    self,
+    x: mx.array,
+    mask: Optional[mx.array] = None,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    B, L, D = x.shape
+
+    queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+    # Prepare the queries, keys and values for the attention computation
+    queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+    keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+    values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+    if cache is not None:
+      queries = self.rope(queries, offset=cache.offset)
+      keys = self.rope(keys, offset=cache.offset)
+      keys, values = cache.update_and_fetch(keys, values)
+    else:
+      queries = self.rope(queries)
+      keys = self.rope(keys)
+
+    output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale, mask=mask)
+    output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+    return self.o_proj(output)
+
+
+class TextMLP(nn.Module):
+  def __init__(self, dim, hidden_dim):
+    super().__init__()
+    self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+    self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+    self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+  def __call__(self, x) -> mx.array:
+    return self.down_proj(nn.silu(self.gate_proj(x))*self.up_proj(x))
+
+
+class TransformerBlock(nn.Module):
+  def __init__(self, config: TextConfig):
+    super().__init__()
+    self.num_attention_heads = config.num_attention_heads
+    self.hidden_size = config.hidden_size
+    self.self_attn = TextAttention(config)
+    self.mlp = TextMLP(config.hidden_size, config.intermediate_size)
+    self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+    self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+    self.config = config
+
+  def __call__(
+    self,
+    x: mx.array,
+    mask: Optional[mx.array] = None,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    r = self.self_attn(self.input_layernorm(x), mask, cache)
+    h = x + r
+    r = self.mlp(self.post_attention_layernorm(h))
+    out = h + r
+    return out
+
+
+class Llama(nn.Module):
+  def __init__(self, config: TextConfig, shard: Shard):
+    super().__init__()
+    self.config = config
+    self.shard = shard
+    self.vocab_size = config.vocab_size
+    self.model_type = config.model_type
+    self.num_hidden_layers = config.num_hidden_layers
+    self.num_key_value_heads = config.num_key_value_heads
+    self.head_dim = config.head_dim
+    assert self.vocab_size > 0
+    if self.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.shard.start_layer <= i <= self.shard.end_layer:
+        self.layers.append(TransformerBlock(config=config))
+      else:
+        self.layers.append(IdentityBlock())
+    if self.shard.is_last_layer():
+      self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+    inputs_embeds=None,
+  ):
+    # for passing merged input embeddings
+    if inputs_embeds is None:
+      if self.shard.is_first_layer():
+        h = self.embed_tokens(inputs)
+      else:
+        h = inputs
+    else:
+      h = inputs_embeds
+
+    mask = None
+    if h.shape[1] > 1:
+      mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
+      mask = mask.astype(h.dtype)
+
+    if cache is None:
+      cache = [None]*len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, c)
+
+    if self.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class LanguageModel(nn.Module):
+  def __init__(self, config: TextConfig, shard: Shard):
+    super().__init__()
+    self.model_type = config.model_type
+    if self.model_type != "llama":
+      raise ValueError(f"Model type {self.model_type} not supported. Currently only 'llama' is supported")
+    self.shard = shard
+    self.model = Llama(config, shard)
+    if self.shard.is_last_layer():
+      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+    inputs_embeds=None,
+  ):
+    out = self.model(inputs, cache, inputs_embeds)
+    if self.shard.is_last_layer():
+      out = self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+
+      if key.startswith('language_model.model.layers.'):
+        layer_num = int(key.split('.')[3])
+        if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer:
+          continue
+      if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'):
+        continue
+      elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')):
+        continue
+
+      shard_state_dict[key] = value
+
+    return shard_state_dict
+
+
+@dataclass
+class LlaVAConfig(BaseModelArgs):
+  text_config: TextConfig
+  vision_config: Optional[VisionConfig] = None
+  model_type: str = "llava"
+  ignore_index: int = -100
+  image_token_index: int = 32000
+  vision_feature_select_strategy: str = "default"
+  vision_feature_layer: int = -2
+  vocab_size: int = 32000
+
+  @classmethod
+  def from_dict(cls, params):
+    updated_params = {}
+    class_params = inspect.signature(cls).parameters
+    for k, v in params.items():
+      if k in class_params:
+        if k in ["text_config", "vision_config"]:
+          v = class_params[k].annotation.from_dict(v)
+        updated_params.update({k: v})
+
+    return cls(**updated_params)
+
+
+@dataclass
+class ModelArgs(LlaVAConfig):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    if isinstance(self.shard, dict):
+      self.shard = Shard(**self.shard)
+
+    if not isinstance(self.shard, Shard):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    if not self.shard.is_first_layer():
+      self.vision_config = None
+
+
+class LlavaMultiModalProjector(nn.Module):
+  def __init__(self, config: LlaVAConfig):
+    super().__init__()
+    self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
+    self.gelu = nn.GELU()
+    self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
+
+  def __call__(self, x: mx.array) -> mx.array:
+    x = self.linear_1(x)
+    x = self.gelu(x)
+    x = self.linear_2(x)
+    return x
+
+
+class Model(nn.Module):
+  def __init__(self, config: ModelArgs):
+    super().__init__()
+    self.config = config
+    self.model_type = config.model_type
+    if config.vision_config:
+      self.vision_tower = VisionModel(config.vision_config)
+      self.multi_modal_projector = LlavaMultiModalProjector(config)
+      self.vision_feature_layer = config.vision_feature_layer
+      self.vision_feature_select_strategy = config.vision_feature_select_strategy
+    self.language_model = LanguageModel(config.text_config, config.shard)
+
+  def get_input_embeddings(
+    self,
+    input_ids: Optional[mx.array] = None,
+    pixel_values: Optional[mx.array] = None,
+  ):
+    if pixel_values is None:
+      return self.language_model(input_ids)
+
+    # Get the input embeddings from the language model
+    inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+    # Get the output hidden states from the vision model
+    *_, hidden_states = self.vision_tower(pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True)
+
+    # Select the hidden states from the desired layer
+    selected_image_feature = hidden_states[self.vision_feature_layer]
+
+    if self.vision_feature_select_strategy == "default":
+      selected_image_feature = selected_image_feature[:, 1:]
+    elif self.vision_feature_select_strategy == "full":
+      selected_image_feature = selected_image_feature
+    else:
+      raise ValueError("Unexpected feature selection strategy: "
+                       f"{self.vision_feature_select_strategy}")
+
+    # Pass image features through the multi-modal projector
+    image_features = self.multi_modal_projector(selected_image_feature)
+
+    # Insert special image tokens in the input_ids
+    final_inputs_embeds = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids)
+    return final_inputs_embeds
+
+  def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids):
+    image_token_index = self.config.image_token_index
+    num_images, num_image_patches, embed_dim = image_features.shape
+
+    # Positions of <image> tokens in input_ids, assuming batch size is 1
+    image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
+
+    if len(image_positions) != num_images:
+      raise ValueError(f"The number of image tokens ({len(image_positions)}) does not "
+                       f" match the number of image inputs ({num_images}).")
+
+    text_segments = []
+    start_idx = 0
+
+    for position in image_positions:
+      text_segments.append(inputs_embeds[:, start_idx:position])
+      start_idx = position + 1
+
+    image_embeddings = mx.split(image_features, image_features.shape[0])
+    final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
+    final_embeddings += [inputs_embeds[:, start_idx:]]
+
+    # Create a final embedding of shape
+    # (1, num_image_patches*num_images + sequence_len, embed_dim)
+    return mx.concatenate(final_embeddings, axis=1)
+
+  def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
+    input_embeddings = None
+    if pixel_values is not None:
+      input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
+    logits = self.language_model(input_ids, cache=cache, inputs_embeds=input_embeddings)
+    return logits
+
+  def sanitize(self, weights):
+    if self.config.vision_config:
+      weights = self.vision_tower.sanitize(weights)
+    else:
+      weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))}
+    weights = self.language_model.sanitize(weights)
+    return weights
+
+  @property
+  def layers(self):
+    return self.language_model.model.layers
+
+  @property
+  def head_dim(self):
+    return (self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads)
+
+  @property
+  def n_kv_heads(self):
+    return self.language_model.model.num_key_value_heads
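
A minimal sketch of the patch-embedding weight transpose that VisionModel.sanitize applies above; the shape used here is illustrative, not taken from a real checkpoint:

import mlx.core as mx

# Hypothetical CLIP-style patch embedding weight in PyTorch layout:
# [out_channels, in_channels, kH, kW]
torch_style = mx.zeros((1024, 3, 14, 14))

# Same transpose as VisionModel.sanitize performs before load_weights:
mlx_style = torch_style.transpose(0, 2, 3, 1)

# MLX layout: [out_channels, kH, kW, in_channels]
assert tuple(mlx_style.shape) == (1024, 14, 14, 3)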

+ 128 - 0
build/lib/exo/inference/mlx/models/qwen2.py

@@ -0,0 +1,128 @@
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import create_attention_mask
+from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs
+
+from ...shard import Shard
+from .base import IdentityBlock
+
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    super().__post_init__()  # Ensure parent initializations are respected
+
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+
+class Qwen2Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.vocab_size = args.vocab_size
+    self.num_hidden_layers = args.num_hidden_layers
+    assert self.vocab_size > 0
+    if self.args.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+        self.layers.append(TransformerBlock(args=args))
+      else:
+        self.layers.append(IdentityBlock())
+    if self.args.shard.is_last_layer():
+      self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(inputs)
+    else:
+      h = inputs
+
+    mask = None
+    if h.shape[1] > 1:
+      mask = create_attention_mask(h, cache)
+
+    if cache is None:
+      cache = [None]*len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.model_type = args.model_type
+    self.model = Qwen2Model(args)
+    if self.args.shard.is_last_layer():
+      if not args.tie_word_embeddings:
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      if self.args.tie_word_embeddings:
+        out = self.model.embed_tokens.as_linear(out)
+      else:
+        out = self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
+        shard_state_dict[key] = value
+
+    if self.args.tie_word_embeddings:
+      shard_state_dict.pop("lm_head.weight", None)
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return self.args.hidden_size // self.args.num_attention_heads
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads
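
The sanitize method above drops every weight the local shard does not own. A standalone sketch of that key filtering, using a hypothetical first-half shard and made-up key names (the tie_word_embeddings cases are left out for brevity):

from exo.inference.shard import Shard

shard = Shard("qwen2-example", 0, 13, 28)  # hypothetical shard covering layers 0-13 of 28

def keep(key: str) -> bool:
  # Mirrors the layer-range test in Model.sanitize: model.layers.<n>.* survives
  # only when start_layer <= n <= end_layer.
  if key.startswith("model.layers."):
    layer_num = int(key.split(".")[2])
    return shard.start_layer <= layer_num <= shard.end_layer
  if key.startswith("model.embed_tokens"):
    return shard.is_first_layer()
  if key.startswith("model.norm") or key.startswith("lm_head"):
    return shard.is_last_layer()
  return False

assert keep("model.layers.5.self_attn.q_proj.weight")
assert not keep("model.layers.20.mlp.up_proj.weight")
assert keep("model.embed_tokens.weight")
assert not keep("model.norm.weight")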

+ 77 - 0
build/lib/exo/inference/mlx/sharded_inference_engine.py

@@ -0,0 +1,77 @@
+import numpy as np
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.sample_utils import top_p_sampling
+from ..inference_engine import InferenceEngine
+from .stateful_model import StatefulModel
+from .sharded_utils import load_shard
+from ..shard import Shard
+from typing import Dict, Optional, Tuple
+from exo.download.shard_download import ShardDownloader
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+
+def sample_logits(
+  logits: mx.array,
+  temp: float = 0.0,
+  top_p: float = 1.0,
+  logit_bias: Optional[Dict[int, float]] = None
+) -> mx.array:
+  if logit_bias:
+    indices = mx.array(list(logit_bias.keys()))
+    values = mx.array(list(logit_bias.values()))
+    logits[:, indices] += values
+
+  if temp == 0:
+    token = mx.argmax(logits, axis=-1)
+  else:
+    if top_p > 0 and top_p < 1.0:
+      token = top_p_sampling(logits, top_p, temp)
+    else:
+      token = mx.random.categorical(logits*(1/temp))
+
+  return token
+
+class MLXDynamicShardInferenceEngine(InferenceEngine):
+  def __init__(self, shard_downloader: ShardDownloader):
+    self.shard = None
+    self.shard_downloader = shard_downloader
+    self.executor = ThreadPoolExecutor(max_workers=1)
+
+  async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
+    y = mx.array(x)
+    logits = y[:, -1, :]
+    out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
+    return out
+
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    await self.ensure_shard(shard)
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    return np.array(tokens)
+
+  async def decode(self, shard: Shard, tokens) -> str:
+    await self.ensure_shard(shard)
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
+    return tokens
+    
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> Tuple[np.ndarray, Optional[dict]]:
+    await self.ensure_shard(shard)
+    output_data, inference_state = await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id, inference_state)
+    output_data = np.array(output_data)
+    return output_data, inference_state
+
+  async def ensure_shard(self, shard: Shard):
+    if self.shard == shard:
+      return
+
+    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
+
+    if self.shard != shard:
+      loop = asyncio.get_running_loop()
+
+      def load_shard_wrapper():
+        return asyncio.run(load_shard(model_path, shard))
+
+      model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
+      self.shard = shard
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 
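
A rough usage sketch of the engine above (encode a prompt, run one forward pass, sample, decode); the model id and layer count are borrowed from the tests further down, and ensure_shard handles download and loading on first use:

import asyncio
from exo.inference.shard import Shard
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.download.hf.hf_shard_download import HFShardDownloader

async def generate_one_token() -> str:
  engine = MLXDynamicShardInferenceEngine(HFShardDownloader())
  shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16)
  tokens = await engine.encode(shard, "Hello")                      # 1-D array of token ids
  logits, _state = await engine.infer_tensor("req-1", shard, tokens.reshape(1, -1))
  next_token = await engine.sample(logits, temp=0.0)                # greedy when temp == 0
  return await engine.decode(shard, next_token)

if __name__ == "__main__":
  print(asyncio.run(generate_one_token()))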

+ 256 - 0
build/lib/exo/inference/mlx/sharded_utils.py

@@ -0,0 +1,256 @@
+# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py
+
+import glob
+import importlib
+import json
+import logging
+import asyncio
+import aiohttp
+from functools import partial
+from pathlib import Path
+from typing import Optional, Tuple, Union, List, Callable
+from PIL import Image
+from io import BytesIO
+import base64
+import traceback
+
+import mlx.core as mx
+import mlx.nn as nn
+from transformers import AutoProcessor
+
+from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
+
+from exo import DEBUG
+from exo.inference.tokenizers import resolve_tokenizer
+from ..shard import Shard
+
+
+class ModelNotFoundError(Exception):
+  def __init__(self, message):
+    self.message = message
+    super().__init__(self.message)
+
+
+MODEL_REMAPPING = {
+  "mistral": "llama",  # mistral is compatible with llama
+  "phi-msft": "phixtral",
+}
+
+
+def _get_classes(config: dict):
+  """
+  Retrieve the model and model args classes based on the configuration.
+
+  Args:
+   config (dict): The model configuration.
+
+  Returns:
+   A tuple containing the Model class and the ModelArgs class.
+  """
+  model_type = config["model_type"]
+  model_type = MODEL_REMAPPING.get(model_type, model_type)
+  try:
+    arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
+  except ImportError:
+    msg = f"Model type {model_type} not supported."
+    logging.error(msg)
+    traceback.print_exc()
+    raise ValueError(msg)
+
+  return arch.Model, arch.ModelArgs
+
+
+def load_config(model_path: Path) -> dict:
+  try:
+    config_path = model_path / "config.json"
+    if config_path.exists():
+      with open(config_path, "r") as f:
+        config = json.load(f)
+      return config
+    
+    model_index_path = model_path / "model_index.json"
+    if model_index_path.exists():
+      config = load_model_index(model_path, model_index_path)
+      return config
+  except FileNotFoundError:
+    logging.error(f"Config file not found in {model_path}")
+    raise
+  raise FileNotFoundError(f"No config.json or model_index.json found in {model_path}")
+
+def load_model_shard(
+  model_path: Path,
+  shard: Shard,
+  lazy: bool = False,
+  model_config: dict = {},
+) -> nn.Module:
+  """
+  Load and initialize the model from a given path.
+
+  Args:
+   model_path (Path): The path to load the model from.
+   lazy (bool): If False eval the model parameters to make sure they are
+    loaded in memory before returning, otherwise they will be loaded
+    when needed. Default: ``False``
+   model_config(dict, optional): Configuration parameters for the model.
+    Defaults to an empty dictionary.
+
+  Returns:
+   nn.Module: The loaded and initialized model.
+
+  Raises:
+   FileNotFoundError: If the weight files (.safetensors) are not found.
+   ValueError: If the model class or args class are not found or cannot be instantiated.
+  """
+  config = load_config(model_path)
+  config.update(model_config)
+
+  # TODO hack
+  config["shard"] = {
+    "model_id": model_path.name,
+    "start_layer": shard.start_layer,
+    "end_layer": shard.end_layer,
+    "n_layers": shard.n_layers,
+  }
+
+  weight_files = glob.glob(str(model_path/"model*.safetensors"))
+
+  if not weight_files:
+    # Try weight for back-compat
+    weight_files = glob.glob(str(model_path/"weight*.safetensors"))
+
+  model_class, model_args_class = _get_classes(config=config)
+
+  class ShardedModel(model_class):
+    def __init__(self, args):
+      super().__init__(args)
+      self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
+
+    def __call__(self, x, *args, **kwargs):
+      y = super().__call__(x, *args, **kwargs)
+      return y
+
+  model_args = model_args_class.from_dict(config)
+  model = ShardedModel(model_args)
+
+  if config.get("model_index", False):
+    model.load()
+    return model
+
+  if not weight_files:
+    logging.error(f"No safetensors found in {model_path}")
+    raise FileNotFoundError(f"No safetensors found in {model_path}")
+
+  weights = {}
+  for wf in sorted(weight_files):
+    if DEBUG >= 8:
+      layer_nums = set()
+      for k in mx.load(wf):
+        if k.startswith("model.layers."):
+          layer_num = int(k.split(".")[2])
+          layer_nums.add(layer_num)
+        if k.startswith("language_model.model.layers."):
+          layer_num = int(k.split(".")[3])
+          layer_nums.add(layer_num)
+      print(f"\"{wf.split('/')[-1]}\": {sorted(layer_nums)},")
+
+    weights.update(mx.load(wf))
+
+  if hasattr(model, "sanitize"):
+    weights = model.sanitize(weights)
+
+  if (quantization := config.get("quantization", None)) is not None:
+    # Handle legacy models which may not have everything quantized
+    def class_predicate(p, m):
+      if not hasattr(m, "to_quantized"):
+        return False
+      return f"{p}.scales" in weights
+
+    nn.quantize(
+      model,
+      **quantization,
+      class_predicate=class_predicate,
+    )
+
+  model.load_weights(list(weights.items()), strict=True)
+
+  if not lazy:
+    mx.eval(model.parameters())
+
+  model.eval()
+  return model
+
+async def load_shard(
+  model_path: str,
+  shard: Shard,
+  tokenizer_config={},
+  model_config={},
+  adapter_path: Optional[str] = None,
+  lazy: bool = False,
+) -> Tuple[nn.Module, TokenizerWrapper]:
+  model = load_model_shard(model_path, shard, lazy, model_config)
+
+  # TODO: figure out a generic solution
+  if model.model_type == "llava":
+    processor = AutoProcessor.from_pretrained(model_path)
+    processor.eos_token_id = processor.tokenizer.eos_token_id
+    processor.encode = processor.tokenizer.encode
+    return model, processor
+  elif hasattr(model, "tokenizer"):
+    tokenizer = model.tokenizer
+    return model, tokenizer
+  else:
+    tokenizer = await resolve_tokenizer(model_path)
+    return model, tokenizer
+
+
+async def get_image_from_str(_image_str: str):
+  image_str = _image_str.strip()
+
+  if image_str.startswith("http"):
+    async with aiohttp.ClientSession() as session:
+      async with session.get(image_str, timeout=10) as response:
+        content = await response.read()
+        return Image.open(BytesIO(content)).convert("RGB")
+  elif image_str.startswith("data:image/"):
+    # Extract the image format and base64 data
+    format_prefix, base64_data = image_str.split(";base64,")
+    image_format = format_prefix.split("/")[1].lower()
+    if DEBUG >= 2: print(f"{image_str=} {image_format=}")
+    imgdata = base64.b64decode(base64_data)
+    img = Image.open(BytesIO(imgdata))
+
+    # Convert to RGB if not already
+    if img.mode != "RGB":
+      img = img.convert("RGB")
+
+    return img
+  else:
+    raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
+
+# loading a combined config for all models in the index
+def load_model_index(model_path: Path, model_index_path: Path):
+  models_config = {}
+  with open(model_index_path, "r") as f:
+    model_index = json.load(f)
+  models_config["model_index"] = True
+  models_config["model_type"] = model_index["_class_name"]
+  models_config["models"] = {}
+  for model in model_index.keys():
+    model_config_path = glob.glob(str(model_path / model / "*config.json"))
+    if len(model_config_path) > 0:
+      with open(model_config_path[0], "r") as f:
+        model_config = {}
+        model_config["model_type"] = model
+        model_config["config"] = json.load(f)
+        model_config["path"] = model_path / model
+        weight_files = glob.glob(str(model_config["path"]/"*model.safetensors"))
+        if weight_files:
+          model_config["config"].update({"weight_files": weight_files})
+        model_config["path"] = str(model_path / model)
+        models_config[model] = model_config
+  models_config = json.dumps(models_config)
+  models_config = json.loads(models_config)
+  return models_config
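
_get_classes above resolves a config's model_type to a module under exo.inference.mlx.models, remapping a few aliases first. A tiny sketch of just that lookup, reusing the same MODEL_REMAPPING table:

REMAPPING = {"mistral": "llama", "phi-msft": "phixtral"}  # copied from MODEL_REMAPPING above

def module_for(model_type: str) -> str:
  # Same name resolution _get_classes performs before importlib.import_module.
  return f"exo.inference.mlx.models.{REMAPPING.get(model_type, model_type)}"

assert module_for("mistral") == "exo.inference.mlx.models.llama"
assert module_for("qwen2") == "exo.inference.mlx.models.qwen2"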

+ 45 - 0
build/lib/exo/inference/mlx/stateful_model.py

@@ -0,0 +1,45 @@
+from typing import Dict, Tuple, Optional
+from collections import OrderedDict
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.cache import make_prompt_cache
+
+from ..shard import Shard
+
+class StatefulModel(nn.Module):
+  def __init__(self, model, max_kv_size: int = 1024, max_caches: int = 2):
+    super().__init__()
+    self.model = model
+    self.max_kv_size = max_kv_size
+    self.max_caches = max_caches
+    self.caches = OrderedDict()
+  
+  def init_cache(self, request_id: str):
+    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
+    # if self.max_kv_size is not None:
+      # cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
+      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    # else:
+      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    cache = make_prompt_cache(self.model)
+
+    if len(self.caches) >= self.max_caches:
+      self.caches.popitem(last=False)
+
+    self.caches[request_id] = cache
+
+  def __call__(self, x, request_id: str, inference_state: Optional[dict] = None):
+    if self.model.model_type != 'StableDiffusionPipeline':
+      if request_id not in self.caches:
+        self.init_cache(request_id)
+      else:
+        self.caches.move_to_end(request_id)
+
+      cache = self.caches[request_id]
+
+      y = self.model(x, cache=cache)
+    else:
+      y, inference_state = self.model(x, **inference_state)
+    return y, inference_state
+    
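
StatefulModel above keeps at most max_caches per-request prompt caches in LRU order. A small sketch of just that eviction policy, with the actual KV cache replaced by a stand-in object:

from collections import OrderedDict

class RequestCacheLRU:
  # Simplified stand-in for StatefulModel's request_id -> cache bookkeeping.
  def __init__(self, max_caches: int = 2):
    self.max_caches = max_caches
    self.caches = OrderedDict()

  def get(self, request_id: str):
    if request_id not in self.caches:
      if len(self.caches) >= self.max_caches:
        self.caches.popitem(last=False)   # evict the least recently used request
      self.caches[request_id] = object()  # stand-in for make_prompt_cache(model)
    else:
      self.caches.move_to_end(request_id)  # mark as most recently used
    return self.caches[request_id]

lru = RequestCacheLRU(max_caches=2)
lru.get("a"); lru.get("b"); lru.get("c")
assert list(lru.caches) == ["b", "c"]  # "a" was evicted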

+ 40 - 0
build/lib/exo/inference/mlx/test_sharded_llama.py

@@ -0,0 +1,40 @@
+import asyncio
+
+import mlx.core as mx
+from exo.inference.mlx.stateful_model import StatefulModel
+from exo.inference.mlx.sharded_utils import load_shard
+from exo.inference.shard import Shard
+
+# 79, 80 for Llama-3-70B
+shard_full = Shard("llama", 0, 31, 32)
+shard1 = Shard("llama", 0, 12, 32)
+shard2 = Shard("llama", 13, 31, 32)
+
+full_model_shard, full_tokenizer = asyncio.run(load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard_full))
+model_shard1, tokenizer1 = asyncio.run(load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1))
+model_shard2, tokenizer2 = asyncio.run(load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2))
+
+full = StatefulModel(full_model_shard)
+m1 = StatefulModel(model_shard1)
+m2 = StatefulModel(model_shard2)
+
+prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
+prompt_tokens = mx.array(full_tokenizer.encode(prompt))
+max_tokens = 50
+
+resp = prompt_tokens[None]
+full_generated_tokens = []
+for _ in range(max_tokens):
+  logits, _ = full(resp, "full")
+  token = mx.argmax(logits[:, -1, :], axis=-1)
+  full_generated_tokens.append(token.item())
+  resp = token[None]
+
+print("full response: ", full_tokenizer.decode(full_generated_tokens))
+
+sharded_generated_tokens = []
+sharded_resp = prompt_tokens[None]
+for _ in range(max_tokens):
+  hidden, _ = m1(sharded_resp, "shard")
+  logits, _ = m2(hidden, "shard")
+  token = mx.argmax(logits[:, -1, :], axis=-1)
+  sharded_generated_tokens.append(token.item())
+  sharded_resp = token[None]
+
+print("sharded response: ", tokenizer1.decode(sharded_generated_tokens))
+
+assert tokenizer1.decode(full_generated_tokens) == tokenizer1.decode(sharded_generated_tokens)

+ 64 - 0
build/lib/exo/inference/mlx/test_sharded_llava.py

@@ -0,0 +1,64 @@
+import codecs
+import asyncio
+import requests
+from PIL import Image
+from io import BytesIO
+
+import mlx.core as mx
+from mlx_lm.models.cache import KVCache
+
+from exo.inference.mlx.stateful_model import StatefulModel
+from exo.inference.mlx.sharded_utils import load_shard
+from exo.inference.shard import Shard
+
+shard_full = Shard("llava", 0, 31, 32)
+shard1 = Shard("llava", 0, 12, 32)
+shard2 = Shard("llava", 13, 31, 32)
+
+model_path = "llava-hf/llava-1.5-7b-hf"
+
+full_model_shard, full_processor = asyncio.run(load_shard(model_path, shard=shard_full))
+model_shard1, processor1 = asyncio.run(load_shard(model_path, shard=shard1))
+model_shard2, processor2 = asyncio.run(load_shard(model_path, shard=shard2))
+
+full = StatefulModel(full_model_shard)
+m1 = StatefulModel(model_shard1)
+m2 = StatefulModel(model_shard2)
+
+PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
+IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
+response = requests.get(IMAGE_FILE)
+img = Image.open(BytesIO(response.content))
+prompt = codecs.decode(PROMPT, "unicode_escape")
+inputs = full_processor(prompt, img, return_tensors="np")
+pixel_values = mx.array(inputs["pixel_values"])
+input_ids = mx.array(inputs["input_ids"])
+
+print(prompt)
+y = full.step("full", input_ids, pixel_values, temp=0)
+full_generated_tokens = [y.item()]
+
+for _ in range(13):
+  y = full.step("full", y, temp=0)
+  full_generated_tokens.append(y.item())
+
+full_response = full_processor.tokenizer.decode(full_generated_tokens)
+print("full response:", full_response)
+
+inputs = processor1(prompt, img, return_tensors="np")
+pixel_values = mx.array(inputs["pixel_values"])
+input_ids = mx.array(inputs["input_ids"])
+
+y = m1.step("shard", input_ids, pixel_values, temp=0)
+y = m2.step("shard", y, temp=0)
+full_generated_tokens = [y.item()]
+
+for _ in range(13):
+  y = m1.step("shard", y, temp=0)
+  y = m2.step("shard", y, temp=0)
+  full_generated_tokens.append(y.item())
+
+sharded_response = processor2.tokenizer.decode(full_generated_tokens)
+print("sharded response:", sharded_response)
+
+assert full_response == sharded_response

+ 52 - 0
build/lib/exo/inference/mlx/test_sharded_model.py

@@ -0,0 +1,52 @@
+from exo.inference.shard import Shard
+import mlx.core as mx
+import mlx.nn as nn
+from typing import Optional
+import numpy as np
+
+
+class DummyModel(nn.Module):
+  def __init__(self, shard: Optional[Shard] = None):
+    super().__init__()
+    self.shard = shard
+    self.layers = [
+      nn.Linear(8, 128),
+      nn.Linear(128, 128),
+      nn.Linear(128, 128),
+      nn.Linear(128, 128),
+      nn.Linear(128, 8),
+    ]
+
+    self.n_kv_heads = 4
+    self.head_dim = 4
+
+  def __call__(self, x, cache=None):
+    if self.shard:
+      for layer in self.layers[self.shard.start_layer:self.shard.end_layer + 1]:
+        x = layer(x)
+      if self.shard.is_last_layer():
+        x = x.reshape((1, 2, 4))
+    else:
+      for layer in self.layers:
+        x = layer(x)
+      x = x.reshape((1, 2, 4))
+
+    return x
+
+
+model = DummyModel()
+model.save_weights("./test_weights.npz")
+n_layers = 5
+shard1 = Shard("test", 0, n_layers // 2, n_layers)
+sharded_model1 = DummyModel(shard1)
+shard2 = Shard("test", n_layers//2 + 1, n_layers - 1, n_layers)
+sharded_model2 = DummyModel(shard2)
+
+model.load_weights("./test_weights.npz")
+sharded_model1.load_weights("./test_weights.npz")
+sharded_model2.load_weights("./test_weights.npz")
+
+fullresp = model(mx.array([1, 2, 3, 4, 5, 6, 7, 8]))
+resp1 = sharded_model1(mx.array([1, 2, 3, 4, 5, 6, 7, 8]))
+resp2 = sharded_model2(resp1)
+
+assert np.all(np.array(fullresp) == np.array(resp2))

+ 39 - 0
build/lib/exo/inference/shard.py

@@ -0,0 +1,39 @@
+from dataclasses import dataclass, field
+
+
+@dataclass(frozen=True)
+class Shard:
+  model_id: str
+  start_layer: int
+  end_layer: int
+  n_layers: int
+
+  def __hash__(self):
+    return hash((self.model_id, self.start_layer, self.end_layer, self.n_layers))
+
+  def is_first_layer(self) -> bool:
+    return self.start_layer == 0
+
+  def is_last_layer(self) -> bool:
+    return self.end_layer == self.n_layers - 1
+
+  def get_layer_count(self) -> int:
+    return self.end_layer - self.start_layer + 1
+
+  def to_dict(self) -> dict:
+    return {
+      "model_id": self.model_id,
+      "start_layer": self.start_layer,
+      "end_layer": self.end_layer,
+      "n_layers": self.n_layers,
+    }
+
+  @staticmethod
+  def from_dict(data: dict) -> 'Shard':
+    return Shard(**data)
+
+  def overlaps(self, other: 'Shard') -> bool:
+    return shards_overlap(self, other)
+
+
+def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
+  return (shard1.model_id == shard2.model_id and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer))
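
A quick sketch of the Shard semantics defined above, splitting a hypothetical 32-layer model into two contiguous halves:

from exo.inference.shard import Shard

first = Shard(model_id="example-model", start_layer=0, end_layer=15, n_layers=32)
second = Shard(model_id="example-model", start_layer=16, end_layer=31, n_layers=32)

assert first.is_first_layer() and not first.is_last_layer()
assert second.is_last_layer()
assert first.get_layer_count() + second.get_layer_count() == 32
assert not first.overlaps(second)  # contiguous halves share no layers
assert first.overlaps(Shard("example-model", 10, 20, 32))  # layers 10-15 are shared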

+ 53 - 0
build/lib/exo/inference/test_dummy_inference_engine.py

@@ -0,0 +1,53 @@
+import pytest
+import json
+import numpy as np
+from exo.inference.dummy_inference_engine import DummyInferenceEngine
+from exo.inference.shard import Shard
+
+
+class MockShardDownloader:
+  async def ensure_shard(self, shard):
+    pass
+
+
+@pytest.mark.asyncio
+async def test_dummy_inference_specific():
+  engine = DummyInferenceEngine(MockShardDownloader())
+  test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+  test_prompt = "This is a test prompt"
+
+  result = await engine.infer_prompt("test_request", test_shard, test_prompt)
+
+  print(f"Inference result shape: {result.shape}")
+
+  assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
+
+
+@pytest.mark.asyncio
+async def test_dummy_inference_engine():
+  # Initialize the DummyInferenceEngine
+  engine = DummyInferenceEngine(MockShardDownloader())
+
+  # Create a test shard
+  shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+
+  # Test infer_prompt
+  output = await engine.infer_prompt("test_id", shard, "Test prompt")
+
+  assert isinstance(output, np.ndarray), "Output should be a numpy array"
+  assert output.ndim == 2, "Output should be 2-dimensional"
+
+  # Test infer_tensor
+  input_tensor = np.array([[1, 2, 3]])
+  output = await engine.infer_tensor("test_id", shard, input_tensor)
+
+  assert isinstance(output, np.ndarray), "Output should be a numpy array"
+  assert output.ndim == 2, "Output should be 2-dimensional"
+
+  print("All tests passed!")
+
+
+if __name__ == "__main__":
+  import asyncio
+  asyncio.run(test_dummy_inference_engine())
+  asyncio.run(test_dummy_inference_specific())

+ 56 - 0
build/lib/exo/inference/test_inference_engine.py

@@ -0,0 +1,56 @@
+from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.inference.inference_engine import InferenceEngine
+from exo.inference.shard import Shard
+from exo.helpers import DEBUG
+import os
+import asyncio
+import numpy as np
+
+
+# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
+async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
+  prompt = "In a single word only, what is the last name of the current president of the USA?"
+  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
+  token_full = await inference_engine_1.sample(resp_full)
+  token_full = token_full.reshape(1, -1)
+  next_resp_full = await inference_engine_1.infer_tensor(
+    "A",
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
+    input_data=token_full,
+  )
+
+  pp = n_layers // 2
+  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
+  resp2 = await inference_engine_2.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
+    input_data=resp1,
+  )
+  tokens2 = await inference_engine_1.sample(resp2)
+  tokens2 = tokens2.reshape(1, -1)
+  resp3 = await inference_engine_1.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
+    input_data=tokens2,
+  )
+  resp4 = await inference_engine_2.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
+    input_data=resp3,
+  )
+
+  assert np.array_equal(resp_full, resp2)
+  assert np.array_equal(next_resp_full, resp4)
+
+
+asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16))
+
+if os.getenv("RUN_TINYGRAD", default="0") == "1":
+  import tinygrad
+  import os
+  from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+  tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
+  asyncio.run(
+    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32)
+  )

+ 0 - 0
build/lib/exo/inference/tinygrad/__init__.py


+ 99 - 0
build/lib/exo/inference/tinygrad/inference.py

@@ -0,0 +1,99 @@
+from pathlib import Path
+import json
+import os
+from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16, sample_logits
+from exo.inference.shard import Shard
+from exo.inference.tokenizers import resolve_tokenizer
+from tinygrad.nn.state import load_state_dict
+from tinygrad import Tensor, nn, Context
+from exo.inference.inference_engine import InferenceEngine
+import numpy as np
+from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
+from exo.download.shard_download import ShardDownloader
+from concurrent.futures import ThreadPoolExecutor
+from .stateful_model import StatefulModel
+import asyncio
+
+Tensor.no_grad = True
+# default settings
+TEMPERATURE = float(os.getenv("TEMPERATURE", 0.85))
+TOP_K = 25
+TOP_P = 0.9
+ALPHA_F = 0.1
+ALPHA_P = 0.0
+MODEL_PARAMS = {
+  "1B": {
+    "args": {
+      "dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
+      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
+    }, "files": 1
+  }, "3B": {
+    "args": {
+      "dim": 3072, "n_heads": 24, "n_kv_heads": 8, "n_layers": 28, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
+      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
+    }, "files": 1
+  }, "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
+  "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
+}
+
+
+def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
+  # build model
+  linear = nn.Linear
+  model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
+
+  # load weights
+  if model_path.is_dir():
+    if (model_path/"model.safetensors.index.json").exists(): weights = load(str(model_path/"model.safetensors.index.json"), shard)
+    elif (model_path/"model.safetensors").exists(): weights = load(str(model_path/"model.safetensors"), shard)
+    else: weights = concat_weights([load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
+  else:
+    weights = load(str(model_path), shard)
+  weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
+  weights = fix_bf16(weights)
+
+  with Context(BEAM=0):
+    # replace weights in model
+    load_state_dict(model, weights, strict=False, consume=False)  # consume=True
+  return model
+
+class TinygradDynamicShardInferenceEngine(InferenceEngine):
+  def __init__(self, shard_downloader: ShardDownloader):
+    self.shard = None
+    self.shard_downloader = shard_downloader
+    self.executor = ThreadPoolExecutor(max_workers=1)
+
+  async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
+    logits = x[:, -1, :]
+    def sample_wrapper():
+      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
+
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    await self.ensure_shard(shard)
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)
+  
+  async def decode(self, shard: Shard, tokens) -> str:
+    await self.ensure_shard(shard)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
+
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+    await self.ensure_shard(shard)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize().numpy())
+
+  async def ensure_shard(self, shard: Shard):
+    if self.shard == shard:
+      return
+
+    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
+
+    if self.shard != shard:
+      loop = asyncio.get_running_loop()
+      parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
+      model_shard = await loop.run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
+
+      tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
+      self.tokenizer = await resolve_tokenizer(tokenizer_path)
+      self.shard = shard
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 

+ 0 - 0
build/lib/exo/inference/tinygrad/models/__init__.py


+ 282 - 0
build/lib/exo/inference/tinygrad/models/llama.py

@@ -0,0 +1,282 @@
+from typing import Tuple, Union, Optional, Dict, Any, List
+from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
+from tinygrad.helpers import getenv
+from collections import OrderedDict
+
+
+# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half, rope_scaling: Optional[Dict[str, float]] = None) -> Tensor:
+  freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
+
+  if rope_scaling:
+    factor = rope_scaling.get('factor', 1.0)
+    low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
+    high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
+    original_max_pos_emb = rope_scaling.get('original_max_position_embeddings', end)
+
+    freqs[:dim // 4] *= low_freq_factor
+    freqs[dim // 4:] = freqs[dim // 4:].contiguous()*high_freq_factor
+    freqs *= (original_max_pos_emb/end)**(1.0/factor)
+
+  freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
+  # TODO: move dtype outside this
+  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
+
+
+# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
+def complex_mult(A, c, d):
+  a, b = A[..., 0:1], A[..., 1:2]
+  ro = a*c - b*d
+  co = a*d + b*c
+  return ro.cat(co, dim=-1)
+
+
+def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
+  assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
+  xq = xq.reshape(*xq.shape[0:-1], -1, 2)
+  xk = xk.reshape(*xk.shape[0:-1], -1, 2)
+  assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5
+  c, d = freqs_cis[..., 0:1], freqs_cis[..., 1:2]
+  xq_out = complex_mult(xq, c, d)
+  xk_out = complex_mult(xk, c, d)
+  return xq_out.flatten(3), xk_out.flatten(3)
+
+
+def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
+  bs, seqlen, n_kv_heads, head_dim = x.shape
+  if n_rep == 1: return x
+  # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
+  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
+
+class Attention:
+  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
+    self.n_heads = n_heads
+    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
+    self.head_dim = dim // n_heads
+    self.n_rep = self.n_heads // self.n_kv_heads
+    self.max_context = max_context
+
+    self.wq = linear(dim, self.n_heads*self.head_dim, bias=False)
+    self.wk = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
+    self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
+    self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)
+
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None) -> Tensor:
+    if getenv("WQKV"):
+      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
+      xqkv = x @ self.wqkv.T
+      xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
+    else:
+      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
+    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
+    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
+
+    xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
+    bsz, seqlen, _, _ = xq.shape
+
+    if cache is not None:
+      # update the cache
+      assert xk.dtype == xv.dtype == cache.dtype, f"{xk.dtype=}, {xv.dtype=}, {cache.dtype=}"
+      cache.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
+
+      keys = cache[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
+      values = cache[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+    else:
+      keys = xk
+      values = xv
+
+    keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
+    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
+    attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
+    attn = attn.reshape(bsz, seqlen, -1)
+    return self.wo(attn)
+
+
+class FeedForward:
+  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
+    self.w1 = linear(dim, hidden_dim, bias=False)
+    self.w2 = linear(hidden_dim, dim, bias=False)
+    self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit
+
+  def __call__(self, x: Tensor) -> Tensor:
+    return self.w2(self.w1(x).silu()*self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]
+
+
+class TransformerBlock:
+  def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int, linear=nn.Linear, feed_forward=FeedForward):
+    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
+    self.feed_forward = feed_forward(dim, hidden_dim, linear)
+    self.attention_norm = nn.RMSNorm(dim, norm_eps)
+    self.ffn_norm = nn.RMSNorm(dim, norm_eps)
+
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None):
+    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask, cache=cache)
+    return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
+
+
+# standard openai sampling
+def sample_logits(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
+  assert logits.ndim == 1, "only works on 1d tensors"
+  assert 0 <= p <= 1, "p must be between 0 and 1"
+  assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
+
+  # if temperature is very low just use argmax
+  if temp < 1e-6: return logits.argmax().reshape(1)
+
+  # alpha sampling
+  if af or ap:
+    if not hasattr(sample_logits, "alpha_counter"):
+      setattr(sample_logits, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
+    logits = logits - (sample_logits.alpha_counter*af + (sample_logits.alpha_counter > 0)*ap)
+
+  # replace NaNs with -inf
+  logits = (logits != logits).where(-float("inf"), logits)
+
+  # softmax
+  t = (logits/temp).softmax()
+
+  counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
+  # top k
+  if k:
+    output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
+    for i in range(k):
+      t_argmax = (t.numel() - ((t == (t_max := t.max()))*counter2).max() - 1).cast(dtypes.default_int)
+      output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
+      output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
+      t = (counter == t_argmax).where(0, t)
+
+    # approximate top p
+    # because we are already limited to top k elements we can do top p "without sorting"
+    output_cumsum = output[::-1]._cumsum()[::-1] + t.sum()
+    output = (output_cumsum >= (1 - p))*output
+    output_indices = (output_cumsum >= (1 - p))*output_indices
+
+    # sample
+    output_idx = output.multinomial()
+    output_token = output_indices[output_idx]
+  else:
+    output_token = t.multinomial()
+
+  # increase alpha counter
+  if af or ap:
+    sample_logits.alpha_counter = (counter == output_token).where(sample_logits.alpha_counter + 1, sample_logits.alpha_counter)
+
+  return output_token
+
+
+from exo.inference.shard import Shard
+
+
+class Transformer:
+  def __init__(
+    self,
+    dim: int,
+    hidden_dim: int,
+    n_heads: int,
+    n_layers: int,
+    norm_eps: float,
+    vocab_size,
+    shard: Shard = None,
+    linear=nn.Linear,
+    n_kv_heads=None,
+    rope_theta=10000,
+    max_context=1024,
+    jit=True,
+    feed_forward=FeedForward,
+    rope_scaling: Optional[Dict[str, float]] = None,
+    tie_word_embeddings=False,
+  ):
+    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
+    self.norm = nn.RMSNorm(dim, norm_eps)
+    self.tok_embeddings = nn.Embedding(vocab_size, dim)
+    self.output = nn.Linear(dim, vocab_size, bias=False)
+    if tie_word_embeddings:
+      self.output.weight = self.tok_embeddings.weight
+    self.max_context = max_context
+    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
+    self.forward_jit = TinyJit(self.forward_base) if jit else None
+    self.shard = shard
+
+  def forward_base(self, x: Tensor, start_pos: Union[Variable, int], cache: Optional[List[Tensor]] = None):
+    seqlen = x.shape[1]
+    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
+    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
+
+    h = x
+
+    if cache is None:
+      cache = [None for _ in range(self.shard.start_layer, self.shard.end_layer + 1)]  
+    for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), cache):
+      layer = self.layers[i]
+      h = layer(h, start_pos, freqs_cis, mask, cache=c)
+
+    if self.shard.is_last_layer():
+      logits = self.output(self.norm(h)).float().realize()
+      return logits
+    else:
+      return h
+
+  def embed(self, inputs: Tensor):
+    if self.shard.is_first_layer():
+      h = self.tok_embeddings(inputs)
+    else:
+      h = inputs
+    return h
+
+  def forward(self, x: Tensor, start_pos: int, cache: Optional[List[Tensor]] = None):
+    if x.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
+      return self.forward_jit(x, Variable("start_pos", 1, self.max_context).bind(start_pos), cache=cache)
+    return self.forward_base(x, start_pos, cache=cache)
+
+  def __call__(self, tokens: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
+    # TODO: better way to handle the first call vs. the rest?
+    h = self.embed(tokens)
+    return self.forward(h, start_pos, cache=cache)
+
+
+# *** helpers ***
+
+
+def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
+  def permute(v: Tensor, n_heads: int):
+    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
+
+  keymap = {
+    "model.embed_tokens.weight": "tok_embeddings.weight",
+    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
+       for l in range(len(model.layers))},
+    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
+       for x in ["q", "k", "v", "o"]
+       for l in range(len(model.layers))},
+    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
+       for l in range(len(model.layers))},
+    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
+       for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
+       for l in range(len(model.layers))},
+    "model.norm.weight": "norm.weight",
+    "lm_head.weight": "output.weight",
+  }
+  sd = {}
+  for k, v in weights.items():
+    if ".rotary_emb." in k: continue
+    v = v.to(Device.DEFAULT)
+    if "model.layers" in k:
+      if "q_proj" in k:
+        v = permute(v, n_heads)
+      elif "k_proj" in k:
+        v = permute(v, n_kv_heads)
+    if k in keymap:
+      sd[keymap[k]] = v
+    else:
+      sd[k] = v
+  return sd
+
+
+def fix_bf16(weights: Dict[Any, Tensor]):
+  if getenv("SUPPORT_BF16", 1):
+    # TODO: without casting to float16, 70B llama OOM on tinybox.
+    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
+  # TODO: check if device supports bf16
+  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}

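For reference, a minimal sketch of how the helpers above fit together when loading a HuggingFace checkpoint into the sharded Transformer. The hyperparameters, shard bounds and file path are illustrative assumptions, not values taken from this commit.

# Hypothetical sketch: load one safetensors file into the sharded tinygrad Transformer.
from tinygrad.nn.state import safe_load, load_state_dict
from exo.inference.shard import Shard
from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16

# Assumed Llama-3.2-1B-style hyperparameters; in practice these come from the checkpoint's config.json.
shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16)
model = Transformer(dim=2048, hidden_dim=8192, n_heads=32, n_layers=16, norm_eps=1e-5,
                    vocab_size=128256, shard=shard, n_kv_heads=8)

weights = safe_load("/path/to/model.safetensors")                 # raw HF-format tensors (placeholder path)
weights = convert_from_huggingface(weights, model, n_heads=32, n_kv_heads=8)
weights = fix_bf16(weights)                                       # cast bf16 tensors where needed
load_state_dict(model, weights, strict=False)                     # strict=False: only this shard's layers are present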
+ 42 - 0
build/lib/exo/inference/tinygrad/stateful_model.py

@@ -0,0 +1,42 @@
+from tinygrad import Tensor, Variable
+from tinygrad.helpers import getenv  # used below for the SHARD_KVCACHE check
+from collections import OrderedDict
+from typing import List
+
+def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
+  cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
+  if isinstance(x.device, tuple):
+    # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
+    cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
+  return cache_kv.realize()
+
+class ModelState:
+  cache: List[Tensor]
+  start: int 
+  def __init__(self, cache: List[Tensor], start: int = 0):
+    self.cache = cache
+    self.start = start
+
+class StatefulModel:
+  def __init__(self, model, max_states: int = 2):
+    super().__init__()
+    self.model = model
+    self.max_states = max_states
+    self.states = OrderedDict()
+ 
+  def init_cache(self, x: Tensor, request_id: str):
+    cache = [create_kv_cache(x, self.model.layers[i].attention.max_context, self.model.layers[i].attention.n_kv_heads, self.model.layers[i].attention.head_dim) for i in range(self.model.shard.start_layer, self.model.shard.end_layer + 1)]
+    if len(self.states) >= self.max_states:
+      self.states.popitem(last=False)
+
+    self.states[request_id] = ModelState(cache)
+
+  def __call__(self, x: Tensor, request_id: str): 
+    h = self.model.embed(x)
+    if request_id not in self.states:
+      self.init_cache(h, request_id)
+    else:
+      self.states.move_to_end(request_id)
+    out = self.model.forward(h, self.states[request_id].start, cache=self.states[request_id].cache)
+    self.states[request_id].start += h.shape[1]
+    return out
+

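For reference, a brief sketch of how StatefulModel is meant to be driven: it keeps one KV cache per request_id and evicts the least-recently-used cache once max_states is exceeded. The token values and request ids are illustrative assumptions; `model` is a sharded Transformer as defined in models/llama.py.

# Hypothetical usage sketch.
from tinygrad import Tensor
from exo.inference.tinygrad.stateful_model import StatefulModel

stateful = StatefulModel(model, max_states=2)
out_a = stateful(Tensor([[1, 2, 3, 4]]), request_id="req-a")  # creates a fresh KV cache for "req-a"
out_b = stateful(Tensor([[5, 6]]), request_id="req-b")        # second cache
out_a = stateful(Tensor([[7]]), request_id="req-a")           # reuses "req-a"'s cache; its start offset has advanced by 4
# A third distinct request_id would evict the least-recently-used entry ("req-b" at this point).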
+ 52 - 0
build/lib/exo/inference/tinygrad/tinygrad_helpers.py

@@ -0,0 +1,52 @@
+from tinygrad.nn.state import safe_load, torch_load
+from tinygrad import Tensor
+from pathlib import Path
+import json
+from typing import List
+from exo.inference.shard import Shard
+from exo.helpers import DEBUG
+from exo.download.hf.hf_helpers import get_allow_patterns
+from fnmatch import fnmatch
+import re
+
+
+# **** helper functions ****
+def concat_weights(models, device=None):
+  def convert(name) -> Tensor:
+    disk_tensors: List[Tensor] = [model[name] for model in models]
+    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
+      return disk_tensors[0].to(device=device)
+    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
+    lazy_tensors = [data.to(device=device) for data in disk_tensors]
+    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
+
+  return {name: convert(name) for name in {name: None for model in models for name in model}}
+
+
+def load(fn: str, shard: Shard):
+  if fn.endswith('.index.json'):
+    with open(fn) as fp:
+      weight_map = json.load(fp)['weight_map']
+    parts = {}
+    filtered_weight_map = {}
+    allow_patterns = get_allow_patterns(weight_map, shard)
+    for k, n in weight_map.items():
+      if allow_patterns is not None and not any(fnmatch(n, r) for r in allow_patterns):
+        continue
+      if k.startswith("model.layers."):
+        layer_num = int(k.split('.')[2])
+        if layer_num < shard.start_layer or layer_num > shard.end_layer:
+          continue
+
+      parts[n] = load(str(Path(fn).parent/Path(n).name), shard)
+      filtered_weight_map[k] = n
+    if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
+    return {k: parts[n][k] for k, n in filtered_weight_map.items()}
+  elif fn.endswith(".safetensors"):
+    weight_map = safe_load(fn)
+    for k in list(weight_map):
+      if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
+        del weight_map[k]
+    return weight_map
+  else:
+    return torch_load(fn)

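For reference, a short sketch of load() with a sharded checkpoint: given a *.index.json it only opens the weight files whose keys fall inside the shard's layer range, plus the non-layer parameters allowed by the shard's patterns. The snapshot path is a placeholder assumption.

# Hypothetical sketch of sharded weight loading.
from pathlib import Path
from exo.inference.shard import Shard
from exo.inference.tinygrad.tinygrad_helpers import load

shard = Shard(model_id="llama-3.1-8b", start_layer=0, end_layer=15, n_layers=32)  # first half of a 32-layer model
snapshot_dir = Path("/path/to/hf/snapshot")  # placeholder: local HF snapshot directory
weights = load(str(snapshot_dir/"model.safetensors.index.json"), shard)
# `weights` now maps only the parameter names this shard needs to their tensors.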
+ 64 - 0
build/lib/exo/inference/tokenizers.py

@@ -0,0 +1,64 @@
+import traceback
+from aiofiles import os as aios
+from os import PathLike
+from pathlib import Path
+from typing import Union
+from transformers import AutoTokenizer, AutoProcessor
+import numpy as np
+from exo.download.hf.hf_helpers import get_local_snapshot_dir
+from exo.helpers import DEBUG
+
+
+class DummyTokenizer:
+  def __init__(self):
+    self.eos_token_id = 69
+    self.vocab_size = 1000
+
+  def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
+    return "dummy_tokenized_prompt"
+
+  def encode(self, text):
+    return np.array([1])
+
+  def decode(self, tokens):
+    return "dummy" * len(tokens)
+
+
+async def resolve_tokenizer(model_id: str):
+  if model_id == "dummy":
+    return DummyTokenizer()
+  local_path = await get_local_snapshot_dir(model_id)
+  if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
+  try:
+    if local_path and await aios.path.exists(local_path):
+      if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
+      return await _resolve_tokenizer(local_path)
+  except Exception:
+    if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
+    if DEBUG >= 5: traceback.print_exc()
+  return await _resolve_tokenizer(model_id)
+
+
+async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
+  try:
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
+    processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True)
+    if not hasattr(processor, 'eos_token_id'):
+      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
+    if not hasattr(processor, 'encode'):
+      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
+    if not hasattr(processor, 'decode'):
+      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
+    return processor
+  except Exception as e:
+    if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())
+
+  try:
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
+    return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
+  except Exception as e:
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(traceback.format_exc())
+
+  raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")

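For reference, a small sketch of resolve_tokenizer in use; the repo id is one of those listed in models.py later in this commit, and the prompt content is illustrative.

# Hypothetical sketch: resolve a tokenizer and build a chat prompt.
import asyncio
from exo.inference.tokenizers import resolve_tokenizer

async def demo():
  tokenizer = await resolve_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
  prompt = tokenizer.apply_chat_template([{"role": "user", "content": "Hello"}], tokenize=False, add_generation_prompt=True)
  print(prompt, tokenizer.eos_token_id)

asyncio.run(demo())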
+ 274 - 0
build/lib/exo/main.py

@@ -0,0 +1,274 @@
+import argparse
+import asyncio
+import atexit
+import signal
+import json
+import logging
+import platform
+import os
+import sys
+import time
+import traceback
+import uuid
+from exo.networking.manual.manual_discovery import ManualDiscovery
+from exo.networking.manual.network_topology_config import NetworkTopology
+from exo.orchestration.standard_node import StandardNode
+from exo.networking.grpc.grpc_server import GRPCServer
+from exo.networking.udp.udp_discovery import UDPDiscovery
+from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
+from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
+from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
+from exo.api import ChatGPTAPI
+from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
+from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
+from exo.inference.shard import Shard
+from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.inference.tokenizers import resolve_tokenizer
+from exo.orchestration.node import Node
+from exo.models import build_base_shard, get_repo
+from exo.viz.topology_viz import TopologyViz
+from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
+
+# parse args
+parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
+parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
+parser.add_argument("model_name", nargs="?", help="Model name to run")
+parser.add_argument("--default-model", type=str, default=None, help="Default model")
+parser.add_argument("--node-id", type=str, default=None, help="Node ID")
+parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
+parser.add_argument("--node-port", type=int, default=None, help="Node port")
+parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
+parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
+parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
+parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
+parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
+parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
+parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
+parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
+parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
+parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
+parser.add_argument("--chatgpt-api-port", type=int, default=52415, help="ChatGPT API port")
+parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
+parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
+parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
+parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
+parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
+parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
+parser.add_argument("--default-temp", type=float, help="Default token sampling temperature", default=0.0)
+parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
+parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
+parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
+args = parser.parse_args()
+print(f"Selected inference engine: {args.inference_engine}")
+
+print_yellow_exo()
+
+system_info = get_system_info()
+print(f"Detected system: {system_info}")
+
+shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
+                                                      max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
+inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
+print(f"Inference engine name after selection: {inference_engine_name}")
+
+inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
+print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
+
+if args.node_port is None:
+  args.node_port = find_available_port(args.node_host)
+  if DEBUG >= 1: print(f"Using available port: {args.node_port}")
+
+args.node_id = args.node_id or get_or_create_node_id()
+chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
+web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
+if DEBUG >= 0:
+  print("Chat interface started:")
+  for web_chat_url in web_chat_urls:
+    print(f" - {terminal_link(web_chat_url)}")
+  print("ChatGPT API endpoint served at:")
+  for chatgpt_api_endpoint in chatgpt_api_endpoints:
+    print(f" - {terminal_link(chatgpt_api_endpoint)}")
+
+# Convert node-id-filter to list if provided
+allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None
+
+if args.discovery_module == "udp":
+  discovery = UDPDiscovery(
+    args.node_id,
+    args.node_port,
+    args.listen_port,
+    args.broadcast_port,
+    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    discovery_timeout=args.discovery_timeout,
+    allowed_node_ids=allowed_node_ids
+  )
+elif args.discovery_module == "tailscale":
+  discovery = TailscaleDiscovery(
+    args.node_id,
+    args.node_port,
+    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    discovery_timeout=args.discovery_timeout,
+    tailscale_api_key=args.tailscale_api_key,
+    tailnet=args.tailnet_name,
+    allowed_node_ids=allowed_node_ids
+  )
+elif args.discovery_module == "manual":
+  if not args.discovery_config_path:
+    raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
+  discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
+node = StandardNode(
+  args.node_id,
+  None,
+  inference_engine,
+  discovery,
+  partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
+  max_generate_tokens=args.max_generate_tokens,
+  topology_viz=topology_viz,
+  shard_downloader=shard_downloader,
+  default_sample_temperature=args.default_temp
+)
+server = GRPCServer(node, args.node_host, args.node_port)
+node.server = server
+api = ChatGPTAPI(
+  node,
+  inference_engine.__class__.__name__,
+  response_timeout=args.chatgpt_api_response_timeout,
+  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
+  default_model=args.default_model
+)
+node.on_token.register("update_topology_viz").on_next(
+  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") and inference_engine.shard.model_id != 'stable-diffusion-2-1-base' else None
+)
+
+def preemptively_start_download(request_id: str, opaque_status: str):
+  try:
+    status = json.loads(opaque_status)
+    if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
+      current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
+      if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
+      asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
+  except Exception as e:
+    if DEBUG >= 2:
+      print(f"Failed to preemptively start download: {e}")
+      traceback.print_exc()
+
+
+node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
+
+if args.prometheus_client_port:
+  from exo.stats.metrics import start_metrics_server
+  start_metrics_server(node, args.prometheus_client_port)
+
+last_broadcast_time = 0
+
+
+def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
+  global last_broadcast_time
+  current_time = time.time()
+  if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
+    last_broadcast_time = current_time
+    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+
+
+shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
+
+async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
+  inference_class = inference_engine.__class__.__name__
+  shard = build_base_shard(model_name, inference_class)
+  if not shard:
+    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
+    return
+  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
+  request_id = str(uuid.uuid4())
+  callback_id = f"cli-wait-response-{request_id}"
+  callback = node.on_token.register(callback_id)
+  if topology_viz:
+    topology_viz.update_prompt(request_id, prompt)
+  prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
+
+  try:
+    print(f"Processing prompt: {prompt}")
+    await node.process_prompt(shard, prompt, request_id=request_id)
+
+    _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
+
+    print("\nGenerated response:")
+    print(tokenizer.decode(tokens))
+  except Exception as e:
+    print(f"Error processing prompt: {str(e)}")
+    traceback.print_exc()
+  finally:
+    node.on_token.deregister(callback_id)
+
+def clean_path(path):
+  """Clean and resolve path"""
+  if path.startswith("Optional("):
+    path = path.strip('Optional("').rstrip('")')
+  return os.path.expanduser(path)
+
+async def main():
+  loop = asyncio.get_running_loop()
+
+  # Check HuggingFace directory permissions
+  hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
+  if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
+  print(f"{has_read=}, {has_write=}")
+  if not has_read or not has_write:
+    print(f"""
+          WARNING: Limited permissions for model storage directory: {hf_home}.
+          This may prevent model downloads from working correctly.
+          {"❌ No read access" if not has_read else ""}
+          {"❌ No write access" if not has_write else ""}
+          """)
+    
+  if args.models_seed_dir is not None:
+    try:
+      models_seed_dir = clean_path(args.models_seed_dir)
+      await move_models_to_hf(models_seed_dir)
+    except Exception as e:
+      print(f"Error moving models to .cache/huggingface: {e}")
+
+  def restore_cursor():
+    if platform.system() != "Windows":
+      os.system("tput cnorm")  # Show cursor
+
+  # Restore the cursor when the program exits
+  atexit.register(restore_cursor)
+
+  # Use a more direct approach to handle signals
+  def handle_exit():
+    asyncio.ensure_future(shutdown(signal.SIGTERM, loop, node.server))
+
+  if platform.system() != "Windows":
+    for s in [signal.SIGINT, signal.SIGTERM]:
+      loop.add_signal_handler(s, handle_exit)
+
+  await node.start(wait_for_peers=args.wait_for_peers)
+
+  if args.command == "run" or args.run_model:
+    model_name = args.model_name or args.run_model
+    if not model_name:
+      print("Error: Model name is required when using 'run' command or --run-model")
+      return
+    await run_model_cli(node, inference_engine, model_name, args.prompt)
+  else:
+    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
+    await asyncio.Event().wait()
+
+
+def run():
+  loop = asyncio.new_event_loop()
+  asyncio.set_event_loop(loop)
+  try:
+    loop.run_until_complete(main())
+  except KeyboardInterrupt:
+    print("Received keyboard interrupt. Shutting down...")
+  finally:
+    loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
+    loop.close()
+
+
+if __name__ == "__main__":
+  run()

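For reference, a sketch of the opaque-status payload that preemptively_start_download reacts to. The field names mirror the checks in the handler above; the concrete values are illustrative assumptions.

# Illustrative payload only; inside a running node this is what arrives as `opaque_status`.
import json

example_status = json.dumps({
  "type": "node_status",
  "status": "start_process_prompt",
  "shard": {"model_id": "llama-3.2-1b", "start_layer": 0, "end_layer": 15, "n_layers": 16},
})
# preemptively_start_download("some-request-id", example_status) would then schedule
# shard_downloader.ensure_shard(...) for this node's slice of the model.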
+ 151 - 0
build/lib/exo/models.py

@@ -0,0 +1,151 @@
+from exo.inference.shard import Shard
+from typing import Optional, List
+
+model_cards = {
+  ### llama
+  "llama-3.2-1b": {
+    "layers": 16,
+    "repo": {
+      "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
+      "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
+    },
+  },
+  "llama-3.2-3b": {
+    "layers": 28,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+    },
+  },
+  "llama-3.1-8b": {
+    "layers": 32,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated",
+    },
+  },
+  "llama-3.1-70b": {
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+    },
+  },
+  "llama-3.1-70b-bf16": {
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED",
+       "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+    },
+  },
+  "llama-3-8b": {
+    "layers": 32,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
+    },
+  },
+  "llama-3-70b": {
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-70B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R",
+    },
+  },
+  "llama-3.1-405b": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-4bit", }, },
+  "llama-3.1-405b-8bit": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", }, },
+  ### mistral
+  "mistral-nemo": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Nemo-Instruct-2407-4bit", }, },
+  "mistral-large": { "layers": 88, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Large-Instruct-2407-4bit", }, },
+  ### deepseek
+  "deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
+  "deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
+  ### llava
+  "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
+  ### qwen
+  "qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
+  "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
+  "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
+  "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
+  "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
+  "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
+  "qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
+  "qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
+  "qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
+  "qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
+  "qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
+  ### nemotron
+  "nemotron-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", }, },
+  "nemotron-70b-bf16": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", }, },
+  # gemma
+  "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
+  "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
+  # stable diffusion
+  "stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
+  # dummy
+  "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
+}
+
+pretty_name = {
+  "llama-3.2-1b": "Llama 3.2 1B",
+  "llama-3.2-3b": "Llama 3.2 3B",
+  "llama-3.1-8b": "Llama 3.1 8B",
+  "llama-3.1-70b": "Llama 3.1 70B",
+  "llama-3.1-70b-bf16": "Llama 3.1 70B (BF16)",
+  "llama-3.1-405b": "Llama 3.1 405B",
+  "llama-3.1-405b-8bit": "Llama 3.1 405B (8-bit)",
+  "gemma2-9b": "Gemma2 9B",
+  "gemma2-27b": "Gemma2 27B",
+  "nemotron-70b": "Nemotron 70B",
+  "nemotron-70b-bf16": "Nemotron 70B (BF16)",
+  "mistral-nemo": "Mistral Nemo",
+  "mistral-large": "Mistral Large",
+  "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
+  "deepseek-coder-v2.5": "Deepseek Coder V2.5",
+  "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
+  "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
+  "qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
+  "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
+  "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
+  "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
+  "qwen-2.5-7b": "Qwen 2.5 7B",
+  "qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
+  "qwen-2.5-14b": "Qwen 2.5 14B",
+  "qwen-2.5-72b": "Qwen 2.5 72B",
+  "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
+  "llama-3-8b": "Llama 3 8B",
+  "llama-3-70b": "Llama 3 70B",
+  "stable-diffusion-2-1-base": "Stable Diffusion 2.1",
+}
+
+def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
+  return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
+
+def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
+  repo = get_repo(model_id, inference_engine_classname)
+  n_layers = model_cards.get(model_id, {}).get("layers", 0)
+  if repo is None or n_layers < 1:
+    return None
+  return Shard(model_id, 0, 0, n_layers)
+
+def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
+  if not supported_inference_engine_lists:
+    return list(model_cards.keys())
+
+  from exo.inference.inference_engine import inference_engine_classes
+  supported_inference_engine_lists = [
+    [inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list]
+    for engine_list in supported_inference_engine_lists
+  ]
+
+  def has_any_engine(model_info: dict, engine_list: List[str]) -> bool:
+    return any(engine in model_info.get("repo", {}) for engine in engine_list)
+
+  def supports_all_engine_lists(model_info: dict) -> bool:
+    return all(has_any_engine(model_info, engine_list)
+              for engine_list in supported_inference_engine_lists)
+
+  return [
+    model_id for model_id, model_info in model_cards.items()
+    if supports_all_engine_lists(model_info)
+  ]

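For reference, a short sketch of the registry helpers above; the return values shown in comments follow directly from the model_cards table.

# Usage sketch for the model registry.
from exo.models import build_base_shard, get_repo, get_supported_models

get_repo("llama-3.2-1b", "TinygradDynamicShardInferenceEngine")
# -> "unsloth/Llama-3.2-1B-Instruct"

build_base_shard("llama-3.2-1b", "TinygradDynamicShardInferenceEngine")
# -> Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16)

get_supported_models([["MLXDynamicShardInferenceEngine"], ["TinygradDynamicShardInferenceEngine"]])
# -> only the models that have a repo for MLX *and* one for tinygrad (e.g. the llama entries above)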
+ 5 - 0
build/lib/exo/networking/__init__.py

@@ -0,0 +1,5 @@
+from .discovery import Discovery
+from .peer_handle import PeerHandle
+from .server import Server
+
+__all__ = ["Discovery", "PeerHandle", "Server"]

+ 17 - 0
build/lib/exo/networking/discovery.py

@@ -0,0 +1,17 @@
+from abc import ABC, abstractmethod
+from typing import List
+from .peer_handle import PeerHandle
+
+
+class Discovery(ABC):
+  @abstractmethod
+  async def start(self) -> None:
+    pass
+
+  @abstractmethod
+  async def stop(self) -> None:
+    pass
+
+  @abstractmethod
+  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+    pass

+ 0 - 0
build/lib/exo/networking/grpc/__init__.py


+ 173 - 0
build/lib/exo/networking/grpc/grpc_peer_handle.py

@@ -0,0 +1,173 @@
+import grpc
+import numpy as np
+import asyncio
+from typing import Optional, Tuple, List
+
+from . import node_service_pb2
+from . import node_service_pb2_grpc
+
+from ..peer_handle import PeerHandle
+from exo.inference.shard import Shard
+from exo.topology.topology import Topology
+from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
+from exo.helpers import DEBUG
+import json
+import mlx.core as mx
+
+class GRPCPeerHandle(PeerHandle):
+  def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
+    self._id = _id
+    self.address = address
+    self._device_capabilities = device_capabilities
+    self.channel = None
+    self.stub = None
+
+  def id(self) -> str:
+    return self._id
+
+  def addr(self) -> str:
+    return self.address
+
+  def device_capabilities(self) -> DeviceCapabilities:
+    return self._device_capabilities
+
+  async def connect(self):
+    if self.channel is None:
+      self.channel = grpc.aio.insecure_channel(self.address, options=[
+        ("grpc.max_metadata_size", 32*1024*1024),
+        ('grpc.max_receive_message_length', 32*1024*1024),
+        ('grpc.max_send_message_length', 32*1024*1024)
+      ])
+      self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
+    await self.channel.channel_ready()
+
+  async def is_connected(self) -> bool:
+    return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
+
+  async def disconnect(self):
+    if self.channel:
+      await self.channel.close()
+    self.channel = None
+    self.stub = None
+
+  async def _ensure_connected(self):
+    if not await self.is_connected(): await asyncio.wait_for(self.connect(), timeout=5)
+
+  async def health_check(self) -> bool:
+    try:
+      await self._ensure_connected()
+      request = node_service_pb2.HealthCheckRequest()
+      response = await asyncio.wait_for(self.stub.HealthCheck(request), timeout=5)
+      return response.is_healthy
+    except asyncio.TimeoutError:
+      return False
+    except Exception:
+      if DEBUG >= 4:
+        print(f"Health check failed for {self._id}@{self.address}.")
+        import traceback
+        traceback.print_exc()
+      return False
+
+  async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
+    request = node_service_pb2.PromptRequest(
+      prompt=prompt,
+      shard=node_service_pb2.Shard(
+        model_id=shard.model_id,
+        start_layer=shard.start_layer,
+        end_layer=shard.end_layer,
+        n_layers=shard.n_layers,
+      ),
+      request_id=request_id,
+      inference_state=self.serialize_inference_state(inference_state)
+    )
+    response = await self.stub.SendPrompt(request)
+
+    if not response.tensor_data or not response.shape or not response.dtype:
+      return None
+
+    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
+
+  async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
+    request = node_service_pb2.TensorRequest(
+      shard=node_service_pb2.Shard(
+        model_id=shard.model_id,
+        start_layer=shard.start_layer,
+        end_layer=shard.end_layer,
+        n_layers=shard.n_layers,
+      ),
+      tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
+      request_id=request_id,
+      inference_state=self.serialize_inference_state(inference_state)
+    )
+    response = await self.stub.SendTensor(request)
+
+    if not response.tensor_data or not response.shape or not response.dtype:
+      return None
+
+    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
+
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+    request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
+    response = await self.stub.GetInferenceResult(request)
+    if response.tensor is None:
+      return None, response.is_finished
+    return (
+      np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape),
+      response.is_finished,
+    )
+
+  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
+    request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
+    response = await self.stub.CollectTopology(request)
+    topology = Topology()
+    for node_id, capabilities in response.nodes.items():
+      device_capabilities = DeviceCapabilities(
+        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
+      )
+      topology.update_node(node_id, device_capabilities)
+    for node_id, peers in response.peer_graph.items():
+      for peer_id in peers.peer_ids:
+        topology.add_edge(node_id, peer_id)
+    return topology
+
+  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    tensor = None
+    if isinstance(result, np.ndarray):
+      tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
+      result = []
+    request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished)
+    await self.stub.SendResult(request)
+
+  async def send_opaque_status(self, request_id: str, status: str) -> None:
+    request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
+    await self.stub.SendOpaqueStatus(request)
+
+  def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
+    proto_inference_state = node_service_pb2.InferenceState()
+    other_data = {}
+    for k, v in inference_state.items():
+      if isinstance(v, mx.array):
+        np_array = np.array(v)
+        tensor_data = node_service_pb2.Tensor(
+          tensor_data=np_array.tobytes(),
+          shape=list(np_array.shape),
+          dtype=str(np_array.dtype)
+        )
+        proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
+      elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
+        tensor_list = node_service_pb2.TensorList()
+        for tensor in v:
+          np_array = np.array(tensor)
+          tensor_data = node_service_pb2.Tensor(
+            tensor_data=np_array.tobytes(),
+            shape=list(np_array.shape),
+            dtype=str(np_array.dtype)
+          )
+          tensor_list.tensors.append(tensor_data)
+        proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
+      else:
+        # For non-tensor data, we'll still use JSON
+        other_data[k] = v
+    if other_data:
+      proto_inference_state.other_data_json = json.dumps(other_data)
+    return proto_inference_state

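For reference, a minimal sketch of driving GRPCPeerHandle directly; the peer id, address and zeroed device capabilities are placeholder assumptions.

# Hypothetical sketch: connect to a peer and run a health check.
import asyncio
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops

async def demo():
  caps = DeviceCapabilities(model="unknown", chip="unknown", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
  peer = GRPCPeerHandle("peer-1", "192.168.1.10:5678", caps)  # placeholder address
  print(f"{peer.id()} healthy: {await peer.health_check()}")  # connects lazily with a 5s timeout

asyncio.run(demo())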
+ 147 - 0
build/lib/exo/networking/grpc/grpc_server.py

@@ -0,0 +1,147 @@
+import grpc
+from concurrent import futures
+import numpy as np
+from asyncio import CancelledError
+
+from . import node_service_pb2
+from . import node_service_pb2_grpc
+from exo import DEBUG
+from exo.inference.shard import Shard
+from exo.orchestration import Node
+import json
+import mlx.core as mx
+
+
+class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
+  def __init__(self, node: Node, host: str, port: int):
+    self.node = node
+    self.host = host
+    self.port = port
+    self.server = None
+
+  async def start(self) -> None:
+    self.server = grpc.aio.server(
+      futures.ThreadPoolExecutor(max_workers=10),
+      options=[
+        ("grpc.max_metadata_size", 32*1024*1024),
+        ("grpc.max_send_message_length", 128*1024*1024),
+        ("grpc.max_receive_message_length", 128*1024*1024),
+      ],
+    )
+    node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
+    listen_addr = f"{self.host}:{self.port}"
+    self.server.add_insecure_port(listen_addr)
+    await self.server.start()
+    if DEBUG >= 1: print(f"Server started, listening on {listen_addr}")
+
+  async def stop(self) -> None:
+    if self.server:
+      try:
+        await self.server.stop(grace=5)
+        await self.server.wait_for_termination()
+      except CancelledError:
+        pass
+      if DEBUG >= 1: print("Server stopped and all connections are closed")
+
+  async def SendPrompt(self, request, context):
+    shard = Shard(
+      model_id=request.shard.model_id,
+      start_layer=request.shard.start_layer,
+      end_layer=request.shard.end_layer,
+      n_layers=request.shard.n_layers,
+    )
+    prompt = request.prompt
+    request_id = request.request_id
+    inference_state = self.deserialize_inference_state(request.inference_state)
+    result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
+    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
+    tensor_data = result.tobytes() if result is not None else None
+    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
+
+  async def SendTensor(self, request, context):
+    shard = Shard(
+      model_id=request.shard.model_id,
+      start_layer=request.shard.start_layer,
+      end_layer=request.shard.end_layer,
+      n_layers=request.shard.n_layers,
+    )
+    tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
+    request_id = request.request_id
+
+    inference_state = self.deserialize_inference_state(request.inference_state)
+
+    result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
+    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
+    tensor_data = result.tobytes() if result is not None else None
+    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
+
+  async def GetInferenceResult(self, request, context):
+    request_id = request.request_id
+    result = await self.node.get_inference_result(request_id)
+    if DEBUG >= 5: print(f"GetInferenceResult {request_id=}: {result}")
+    tensor_data = result[0].tobytes() if result[0] is not None else None
+    return (
+      node_service_pb2.InferenceResult(
+        tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)),
+        is_finished=result[1],
+      ) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1])
+    )
+
+  async def CollectTopology(self, request, context):
+    max_depth = request.max_depth
+    visited = set(request.visited)
+    topology = await self.node.collect_topology(visited, max_depth)
+    nodes = {
+      node_id:
+        node_service_pb2.DeviceCapabilities(
+          model=cap.model,
+          chip=cap.chip,
+          memory=cap.memory,
+          flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
+        )
+      for node_id, cap in topology.nodes.items()
+    }
+    peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
+    if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
+    return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
+
+  async def SendResult(self, request, context):
+    request_id = request.request_id
+    result = request.result
+    is_finished = request.is_finished
+    img = request.tensor
+    if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
+    result = list(result)
+    if len(img.tensor_data) > 0:
+      result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
+    self.node.on_token.trigger_all(request_id, result, is_finished)
+    return node_service_pb2.Empty()
+
+  async def SendOpaqueStatus(self, request, context):
+    request_id = request.request_id
+    status = request.status
+    if DEBUG >= 8: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
+    self.node.on_opaque_status.trigger_all(request_id, status)
+    return node_service_pb2.Empty()
+
+  async def HealthCheck(self, request, context):
+    return node_service_pb2.HealthCheckResponse(is_healthy=True)
+
+  def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
+    inference_state = {}
+
+    for k, tensor_data in inference_state_proto.tensor_data.items():
+      np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
+      inference_state[k] = mx.array(np_array)
+
+    for k, tensor_list in inference_state_proto.tensor_list_data.items():
+      inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors]
+
+    if inference_state_proto.other_data_json:
+      other_data = json.loads(inference_state_proto.other_data_json)
+      inference_state.update(other_data)
+
+    return inference_state

The file diff is too large, so it has been suppressed
+ 16 - 0
build/lib/exo/networking/grpc/node_service_pb2.py


+ 360 - 0
build/lib/exo/networking/grpc/node_service_pb2_grpc.py

@@ -0,0 +1,360 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+import warnings
+
+from exo.networking.grpc import node_service_pb2 as node__service__pb2
+
+GRPC_GENERATED_VERSION = '1.64.1'
+GRPC_VERSION = grpc.__version__
+EXPECTED_ERROR_RELEASE = '1.65.0'
+SCHEDULED_RELEASE_DATE = 'June 25, 2024'
+_version_not_supported = False
+
+try:
+    from grpc._utilities import first_version_is_lower
+    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+except ImportError:
+    _version_not_supported = True
+
+if _version_not_supported:
+    warnings.warn(
+        f'The grpc package installed is at version {GRPC_VERSION},'
+        + f' but the generated code in node_service_pb2_grpc.py depends on'
+        + f' grpcio>={GRPC_GENERATED_VERSION}.'
+        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
+        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
+        RuntimeWarning
+    )
+
+
+class NodeServiceStub(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
+
+        Args:
+            channel: A grpc.Channel.
+        """
+        self.SendPrompt = channel.unary_unary(
+                '/node_service.NodeService/SendPrompt',
+                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
+                _registered_method=True)
+        self.SendTensor = channel.unary_unary(
+                '/node_service.NodeService/SendTensor',
+                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
+                _registered_method=True)
+        self.GetInferenceResult = channel.unary_unary(
+                '/node_service.NodeService/GetInferenceResult',
+                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.InferenceResult.FromString,
+                _registered_method=True)
+        self.CollectTopology = channel.unary_unary(
+                '/node_service.NodeService/CollectTopology',
+                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Topology.FromString,
+                _registered_method=True)
+        self.SendResult = channel.unary_unary(
+                '/node_service.NodeService/SendResult',
+                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+        self.SendOpaqueStatus = channel.unary_unary(
+                '/node_service.NodeService/SendOpaqueStatus',
+                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
+                _registered_method=True)
+        self.HealthCheck = channel.unary_unary(
+                '/node_service.NodeService/HealthCheck',
+                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
+                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
+                _registered_method=True)
+
+
+class NodeServiceServicer(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def SendPrompt(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def SendTensor(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def GetInferenceResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def CollectTopology(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def SendResult(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def SendOpaqueStatus(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def HealthCheck(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+
+def add_NodeServiceServicer_to_server(servicer, server):
+    rpc_method_handlers = {
+            'SendPrompt': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendPrompt,
+                    request_deserializer=node__service__pb2.PromptRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+            ),
+            'SendTensor': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendTensor,
+                    request_deserializer=node__service__pb2.TensorRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+            ),
+            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.GetInferenceResult,
+                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+            ),
+            'CollectTopology': grpc.unary_unary_rpc_method_handler(
+                    servicer.CollectTopology,
+                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+                    response_serializer=node__service__pb2.Topology.SerializeToString,
+            ),
+            'SendResult': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendResult,
+                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
+                    servicer.SendOpaqueStatus,
+                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
+            ),
+            'HealthCheck': grpc.unary_unary_rpc_method_handler(
+                    servicer.HealthCheck,
+                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
+                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
+            ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+            'node_service.NodeService', rpc_method_handlers)
+    server.add_generic_rpc_handlers((generic_handler,))
+    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
+
+
+ # This class is part of an EXPERIMENTAL API.
+class NodeService(object):
+    """Missing associated documentation comment in .proto file."""
+
+    @staticmethod
+    def SendPrompt(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendPrompt',
+            node__service__pb2.PromptRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def SendTensor(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendTensor',
+            node__service__pb2.TensorRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def GetInferenceResult(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/GetInferenceResult',
+            node__service__pb2.GetInferenceResultRequest.SerializeToString,
+            node__service__pb2.InferenceResult.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def CollectTopology(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/CollectTopology',
+            node__service__pb2.CollectTopologyRequest.SerializeToString,
+            node__service__pb2.Topology.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def SendResult(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendResult',
+            node__service__pb2.SendResultRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def SendOpaqueStatus(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/SendOpaqueStatus',
+            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def HealthCheck(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/node_service.NodeService/HealthCheck',
+            node__service__pb2.HealthCheckRequest.SerializeToString,
+            node__service__pb2.HealthCheckResponse.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)

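For quick manual checks, the one-shot experimental helpers above can be called without building a stub. A minimal sketch, assuming the enclosing generated class is named NodeService (the usual grpc codegen convention for the node_service.NodeService service) and using a placeholder target address:

from exo.networking.grpc import node_service_pb2, node_service_pb2_grpc

# One-shot HealthCheck over a plaintext channel; the address and timeout are placeholders.
response = node_service_pb2_grpc.NodeService.HealthCheck(
  node_service_pb2.HealthCheckRequest(),
  target="localhost:50051",
  insecure=True,
  timeout=5.0,
)
print(response)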
+ 0 - 0
build/lib/exo/networking/manual/__init__.py


+ 71 - 0
build/lib/exo/networking/manual/manual_discovery.py

@@ -0,0 +1,71 @@
+import asyncio
+from exo.networking.discovery import Discovery
+from typing import Dict, List, Callable
+
+from exo.topology.device_capabilities import DeviceCapabilities
+from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
+from exo.helpers import DEBUG_DISCOVERY
+from exo.networking.peer_handle import PeerHandle
+
+
+class ManualDiscovery(Discovery):
+  def __init__(
+    self,
+    network_config_path: str,
+    node_id: str,
+    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+  ):
+    self.topology = NetworkTopology.from_path(network_config_path)
+    self.create_peer_handle = create_peer_handle
+
+    if node_id not in self.topology.peers:
+      raise ValueError(
+        f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {list(self.topology.peers.keys())}"
+      )
+
+    self.listen_task = None
+
+    self.known_peers: Dict[str, PeerHandle] = {}
+    self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
+    self.peers_in_network.pop(node_id)
+
+  async def start(self) -> None:
+    self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
+
+  async def stop(self) -> None:
+    if self.listen_task:
+      self.listen_task.cancel()
+
+  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+    if wait_for_peers > 0:
+      while len(self.known_peers) < wait_for_peers:
+        if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
+        await asyncio.sleep(0.1)
+    if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}")
+    return list(self.known_peers.values())
+
+  async def task_find_peers_from_config(self):
+    if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
+    while True:
+      for peer_id, peer_config in self.peers_in_network.items():
+        try:
+          if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
+          peer = self.known_peers.get(peer_id)
+          if not peer:
+            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
+            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
+          is_healthy = await peer.health_check()
+          if is_healthy:
+            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
+            self.known_peers[peer_id] = peer
+          else:
+            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
+            try:
+              del self.known_peers[peer_id]
+            except KeyError:
+              pass
+        except Exception as e:
+          if DEBUG_DISCOVERY >= 2: print(f"Exception occurred when attempting to add {peer_id=}: {e}")
+      await asyncio.sleep(1.0)
+
+      if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")

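As a usage sketch (the config path and node id below are placeholders, not values from this change), ManualDiscovery is typically paired with the GRPCPeerHandle factory, mirroring how the tests later in this diff wire it up:

import asyncio
from exo.networking.manual.manual_discovery import ManualDiscovery
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle

async def main():
  discovery = ManualDiscovery(
    "path/to/network_config.json",  # placeholder config path
    node_id="node1",                # must match a key in the config's "peers" dict
    create_peer_handle=lambda peer_id, address, caps: GRPCPeerHandle(peer_id, address, caps),
  )
  await discovery.start()  # spawns the background polling/health-check task
  peers = await discovery.discover_peers(wait_for_peers=1)  # block until one healthy peer is known
  print([peer.id() for peer in peers])
  await discovery.stop()

asyncio.run(main())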
+ 31 - 0
build/lib/exo/networking/manual/network_topology_config.py

@@ -0,0 +1,31 @@
+from typing import Dict
+from pydantic import BaseModel, ValidationError
+
+from exo.topology.device_capabilities import DeviceCapabilities
+
+
+class PeerConfig(BaseModel):
+  address: str
+  port: int
+  device_capabilities: DeviceCapabilities
+
+
+class NetworkTopology(BaseModel):
+  """Configuration of the network. A collection outlining all nodes in the network, including the node this is running from."""
+
+  peers: Dict[str, PeerConfig]
+  """
+  node_id to PeerConfig. The node_id is used to identify the peer in the discovery process. The node that this is running from should be included in this dict.
+  """
+  @classmethod
+  def from_path(cls, path: str) -> "NetworkTopology":
+    try:
+      with open(path, "r") as f:
+        config_data = f.read()
+    except FileNotFoundError as e:
+      raise FileNotFoundError(f"Config file not found at {path}") from e
+
+    try:
+      return cls.model_validate_json(config_data)
+    except ValidationError as e:
+      raise ValueError(f"Error validating network topology config from {path}: {e}") from e

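To make the expected on-disk schema concrete, here is an illustrative config validated straight from a JSON string. The field names mirror those exercised by test_network_topology_config.py further down; the address and numbers are placeholders:

from exo.networking.manual.network_topology_config import NetworkTopology

config_json = """
{
  "peers": {
    "node1": {
      "address": "192.168.1.10",
      "port": 50051,
      "device_capabilities": {
        "model": "Unknown Model",
        "chip": "Unknown Chip",
        "memory": 0,
        "flops": {"fp32": 0, "fp16": 0, "int8": 0}
      }
    }
  }
}
"""

topology = NetworkTopology.model_validate_json(config_json)
print(topology.peers["node1"].port)  # 50051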
+ 103 - 0
build/lib/exo/networking/manual/test_manual_discovery.py

@@ -0,0 +1,103 @@
+import asyncio
+import unittest
+from unittest import mock
+from exo.networking.manual.manual_discovery import ManualDiscovery
+from exo.networking.manual.network_topology_config import NetworkTopology
+from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
+from exo.networking.grpc.grpc_server import GRPCServer
+from exo.orchestration.node import Node
+
+root_path = "./exo/networking/manual/test_data/test_config.json"
+
+
+class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    self.peer1 = mock.AsyncMock()
+    self.peer1.connect = mock.AsyncMock()
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    _ = self.discovery1.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+
+  async def test_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=0)
+    assert len(peers1) == 0
+
+    self.peer1.connect.assert_not_called()
+
+
+class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    self.peer1 = mock.AsyncMock()
+    self.peer2 = mock.AsyncMock()
+    self.peer1.connect = mock.AsyncMock()
+    self.peer2.connect = mock.AsyncMock()
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+
+  async def test_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    assert len(peers1) == 1
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    assert len(peers2) == 1
+
+    # connect has to be explicitly called after discovery
+    self.peer1.connect.assert_not_called()
+    self.peer2.connect.assert_not_called()
+
+
+class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    config = NetworkTopology.from_path(root_path)
+
+    self.node1 = mock.AsyncMock(spec=Node)
+    self.node2 = mock.AsyncMock(spec=Node)
+    self.server1 = GRPCServer(self.node1, config.peers["node1"].address, config.peers["node1"].port)
+    self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
+    await self.server1.start()
+    await self.server2.start()
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+    await self.server1.stop()
+    await self.server2.stop()
+
+  async def test_grpc_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    assert len(peers1) == 1
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    assert len(peers2) == 1
+
+    # Connect
+    await peers1[0].connect()
+    await peers2[0].connect()
+    self.assertTrue(await peers1[0].is_connected())
+    self.assertTrue(await peers2[0].is_connected())
+
+    # Kill server1
+    await self.server1.stop()
+
+    self.assertTrue(await peers1[0].is_connected())
+    self.assertFalse(await peers2[0].is_connected())
+
+    # Kill server2
+    await self.server2.stop()
+
+    self.assertFalse(await peers1[0].is_connected())
+    self.assertFalse(await peers2[0].is_connected())
+
+
+if __name__ == "__main__":
+  asyncio.run(unittest.main())

+ 49 - 0
build/lib/exo/networking/manual/test_network_topology_config.py

@@ -0,0 +1,49 @@
+import unittest
+
+from exo.networking.manual.network_topology_config import NetworkTopology
+
+root_path = "./exo/networking/manual/test_data/"
+
+
+class TestNetworkTopologyConfig(unittest.TestCase):
+  def test_from_path_invalid_path(self):
+    with self.assertRaises(FileNotFoundError) as e:
+      NetworkTopology.from_path("invalid_path")
+    self.assertEqual(str(e.exception), "Config file not found at invalid_path")
+
+  def test_from_path_invalid_json(self):
+    with self.assertRaises(ValueError) as e:
+      NetworkTopology.from_path(root_path + "invalid_json.json")
+    self.assertIn("Error validating network topology config from", str(e.exception))
+    self.assertIn("1 validation error for NetworkTopology\n  Invalid JSON: EOF while parsing a value at line 1 column 0", str(e.exception))
+
+  def test_from_path_invalid_config(self):
+    with self.assertRaises(ValueError) as e:
+      NetworkTopology.from_path(root_path + "invalid_config.json")
+    self.assertIn("Error validating network topology config from", str(e.exception))
+    self.assertIn("port\n  Field required", str(e.exception))
+
+  def test_from_path_valid(self):
+    config = NetworkTopology.from_path(root_path + "test_config.json")
+
+    self.assertEqual(config.peers["node1"].port, 50051)
+    self.assertEqual(config.peers["node1"].device_capabilities.model, "Unknown Model")
+    self.assertEqual(config.peers["node1"].address, "localhost")
+    self.assertEqual(config.peers["node1"].device_capabilities.chip, "Unknown Chip")
+    self.assertEqual(config.peers["node1"].device_capabilities.memory, 0)
+    self.assertEqual(config.peers["node1"].device_capabilities.flops.fp32, 0)
+    self.assertEqual(config.peers["node1"].device_capabilities.flops.fp16, 0)
+    self.assertEqual(config.peers["node1"].device_capabilities.flops.int8, 0)
+
+    self.assertEqual(config.peers["node2"].port, 50052)
+    self.assertEqual(config.peers["node2"].device_capabilities.model, "Unknown Model")
+    self.assertEqual(config.peers["node2"].address, "localhost")
+    self.assertEqual(config.peers["node2"].device_capabilities.chip, "Unknown Chip")
+    self.assertEqual(config.peers["node2"].device_capabilities.memory, 0)
+    self.assertEqual(config.peers["node2"].device_capabilities.flops.fp32, 0)
+    self.assertEqual(config.peers["node2"].device_capabilities.flops.fp16, 0)
+    self.assertEqual(config.peers["node2"].device_capabilities.flops.int8, 0)
+
+
+if __name__ == "__main__":
+  unittest.main()

+ 56 - 0
build/lib/exo/networking/peer_handle.py

@@ -0,0 +1,56 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Tuple, List
+import numpy as np
+from exo.inference.shard import Shard
+from exo.topology.device_capabilities import DeviceCapabilities
+from exo.topology.topology import Topology
+
+
+class PeerHandle(ABC):
+  @abstractmethod
+  def id(self) -> str:
+    pass
+
+  @abstractmethod
+  def addr(self) -> str:
+    pass
+
+  @abstractmethod
+  def device_capabilities(self) -> DeviceCapabilities:
+    pass
+
+  @abstractmethod
+  async def connect(self) -> None:
+    pass
+
+  @abstractmethod
+  async def is_connected(self) -> bool:
+    pass
+
+  @abstractmethod
+  async def disconnect(self) -> None:
+    pass
+
+  @abstractmethod
+  async def health_check(self) -> bool:
+    pass
+
+  @abstractmethod
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
+    pass
+
+  @abstractmethod
+  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
+    pass
+
+  @abstractmethod
+  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    pass
+
+  @abstractmethod
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+    pass
+
+  @abstractmethod
+  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
+    pass

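For illustration only (not part of this change), a minimal in-memory implementation of the interface above, handy when exercising discovery or orchestration logic without a real transport. It assumes Topology() can be constructed with no arguments and includes an inference_state keyword to match how standard_node.py invokes send_prompt/send_tensor later in this diff:

from typing import List, Optional, Tuple
import numpy as np
from exo.inference.shard import Shard
from exo.networking.peer_handle import PeerHandle
from exo.topology.device_capabilities import DeviceCapabilities
from exo.topology.topology import Topology


class InMemoryPeerHandle(PeerHandle):
  """Stub peer: tracks connection state, always reports healthy, returns empty results."""
  def __init__(self, peer_id: str, address: str, caps: DeviceCapabilities):
    self._id, self._addr, self._caps = peer_id, address, caps
    self._connected = False

  def id(self) -> str: return self._id
  def addr(self) -> str: return self._addr
  def device_capabilities(self) -> DeviceCapabilities: return self._caps
  async def connect(self) -> None: self._connected = True
  async def is_connected(self) -> bool: return self._connected
  async def disconnect(self) -> None: self._connected = False
  async def health_check(self) -> bool: return True
  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]: return None
  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]: return None
  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None: pass
  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]: return None, False
  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology: return Topology()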
+ 11 - 0
build/lib/exo/networking/server.py

@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+
+
+class Server(ABC):
+  @abstractmethod
+  async def start(self) -> None:
+    pass
+
+  @abstractmethod
+  async def stop(self) -> None:
+    pass

+ 0 - 0
build/lib/exo/networking/tailscale/__init__.py


+ 178 - 0
build/lib/exo/networking/tailscale/tailscale_discovery.py

@@ -0,0 +1,178 @@
+import asyncio
+import time
+import traceback
+from typing import List, Dict, Callable, Tuple
+from exo.networking.discovery import Discovery
+from exo.networking.peer_handle import PeerHandle
+from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
+from exo.helpers import DEBUG, DEBUG_DISCOVERY
+from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices, Device
+
+
+class TailscaleDiscovery(Discovery):
+  def __init__(
+    self,
+    node_id: str,
+    node_port: int,
+    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    discovery_interval: int = 5,
+    discovery_timeout: int = 30,
+    update_interval: int = 15,
+    device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
+    tailscale_api_key: str = None,
+    tailnet: str = None,
+    allowed_node_ids: List[str] = None,
+  ):
+    self.node_id = node_id
+    self.node_port = node_port
+    self.create_peer_handle = create_peer_handle
+    self.discovery_interval = discovery_interval
+    self.discovery_timeout = discovery_timeout
+    self.update_interval = update_interval
+    self.device_capabilities = device_capabilities
+    self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
+    self.discovery_task = None
+    self.cleanup_task = None
+    self.tailscale_api_key = tailscale_api_key
+    self.tailnet = tailnet
+    self.allowed_node_ids = allowed_node_ids
+    self._device_id = None
+    self.update_task = None
+
+  async def start(self):
+    self.device_capabilities = device_capabilities()
+    self.discovery_task = asyncio.create_task(self.task_discover_peers())
+    self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
+    self.update_task = asyncio.create_task(self.task_update_device_posture_attributes())
+
+  async def task_update_device_posture_attributes(self):
+    while True:
+      try:
+        await self.update_device_posture_attributes()
+        if DEBUG_DISCOVERY >= 2:
+          print(f"Updated device posture attributes")
+      except Exception as e:
+        print(f"Error updating device posture attributes: {e}")
+        print(traceback.format_exc())
+      finally:
+        await asyncio.sleep(self.update_interval)
+
+  async def get_device_id(self):
+    if self._device_id:
+      return self._device_id
+    self._device_id = await get_device_id()
+    return self._device_id
+
+  async def update_device_posture_attributes(self):
+    await update_device_attributes(await self.get_device_id(), self.tailscale_api_key, self.node_id, self.node_port, self.device_capabilities)
+
+  async def task_discover_peers(self):
+    while True:
+      try:
+        devices: dict[str, Device] = await get_tailscale_devices(self.tailscale_api_key, self.tailnet)
+        current_time = time.time()
+
+        active_devices = {name: device for name, device in devices.items() if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30}
+
+        if DEBUG_DISCOVERY >= 4: print(f"Found tailscale devices: {devices}")
+        if DEBUG_DISCOVERY >= 2: print(f"Active tailscale devices: {len(active_devices)}/{len(devices)}")
+        if DEBUG_DISCOVERY >= 2: print("Time since last seen tailscale devices", [(current_time - device.last_seen.timestamp()) for device in devices.values() if device.last_seen is not None])
+
+        for device in active_devices.values():
+          if device.name == self.node_id: continue
+          peer_host = device.addresses[0]
+          peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale_api_key)
+          if not peer_id:
+            if DEBUG_DISCOVERY >= 4: print(f"{device.device_id} does not have exo node attributes. skipping.")
+            continue
+
+          if self.allowed_node_ids and peer_id not in self.allowed_node_ids:
+            if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as it's not in the allowed node IDs list")
+            continue
+
+          if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
+            new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+            if not await new_peer_handle.health_check():
+              if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
+              continue
+
+            if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
+            self.known_peers[peer_id] = (
+              new_peer_handle,
+              current_time,
+              current_time,
+            )
+          else:
+            if not await self.known_peers[peer_id][0].health_check():
+              if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Removing.")
+              if peer_id in self.known_peers: del self.known_peers[peer_id]
+              continue
+            self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], current_time)
+
+      except Exception as e:
+        print(f"Error in discover peers: {e}")
+        print(traceback.format_exc())
+      finally:
+        await asyncio.sleep(self.discovery_interval)
+
+  async def stop(self):
+    if self.discovery_task:
+      self.discovery_task.cancel()
+    if self.cleanup_task:
+      self.cleanup_task.cancel()
+    if self.update_task:
+      self.update_task.cancel()
+    if self.discovery_task or self.cleanup_task or self.update_task:
+      await asyncio.gather(self.discovery_task, self.cleanup_task, self.update_task, return_exceptions=True)
+
+  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+    if wait_for_peers > 0:
+      while len(self.known_peers) < wait_for_peers:
+        if DEBUG_DISCOVERY >= 2:
+          print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
+        await asyncio.sleep(0.1)
+    return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
+
+  async def task_cleanup_peers(self):
+    while True:
+      try:
+        current_time = time.time()
+        peers_to_remove = []
+
+        peer_ids = list(self.known_peers.keys())
+        results = await asyncio.gather(*[self.check_peer(peer_id, current_time) for peer_id in peer_ids], return_exceptions=True)
+
+        for peer_id, should_remove in zip(peer_ids, results):
+          if should_remove: peers_to_remove.append(peer_id)
+
+        if DEBUG_DISCOVERY >= 2:
+          print(
+            "Peer statuses:", {
+              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}"
+              for peer_handle, connected_at, last_seen in self.known_peers.values()
+            }
+          )
+
+        for peer_id in peers_to_remove:
+          if peer_id in self.known_peers:
+            del self.known_peers[peer_id]
+            if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity or failed health check.")
+      except Exception as e:
+        print(f"Error in cleanup peers: {e}")
+        print(traceback.format_exc())
+      finally:
+        await asyncio.sleep(self.discovery_interval)
+
+  async def check_peer(self, peer_id: str, current_time: float) -> bool:
+    peer_handle, connected_at, last_seen = self.known_peers.get(peer_id, (None, None, None))
+    if peer_handle is None: return False
+
+    try:
+      is_connected = await peer_handle.is_connected()
+      health_ok = await peer_handle.health_check()
+    except Exception as e:
+      if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
+      return True
+
+    should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
+    return should_remove

+ 125 - 0
build/lib/exo/networking/tailscale/tailscale_helpers.py

@@ -0,0 +1,125 @@
+import json
+import asyncio
+import aiohttp
+import re
+from typing import Dict, Any, Tuple, List, Optional
+from exo.helpers import DEBUG_DISCOVERY
+from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
+from datetime import datetime, timezone
+
+
+class Device:
+  def __init__(self, device_id: str, name: str, addresses: List[str], last_seen: Optional[datetime] = None):
+    self.device_id = device_id
+    self.name = name
+    self.addresses = addresses
+    self.last_seen = last_seen
+
+  @classmethod
+  def from_dict(cls, data: Dict[str, Any]) -> 'Device':
+    return cls(device_id=data.get('id', ''), name=data.get('name', ''), addresses=data.get('addresses', []), last_seen=cls.parse_datetime(data.get('lastSeen')))
+
+  @staticmethod
+  def parse_datetime(date_string: Optional[str]) -> Optional[datetime]:
+    if not date_string:
+      return None
+    return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
+
+
+async def get_device_id() -> str:
+  try:
+    process = await asyncio.create_subprocess_exec('tailscale', 'status', '--json', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
+    stdout, stderr = await process.communicate()
+    if process.returncode != 0:
+      raise Exception(f"Command failed with exit code {process.returncode}: {stderr.decode().strip()}.")
+    if DEBUG_DISCOVERY >= 4: print(f"tailscale status: {stdout.decode()}")
+    data = json.loads(stdout.decode())
+    return data['Self']['ID']
+  except Exception as e:
+    raise Exception(f"{str(e)} Do you have the tailscale cli installed? See: https://tailscale.com/kb/1080/cli")
+
+
+async def update_device_attributes(device_id: str, api_key: str, node_id: str, node_port: int, device_capabilities: DeviceCapabilities):
+  async with aiohttp.ClientSession() as session:
+    base_url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
+    headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}
+
+    attributes = {
+      "custom:exo_node_id": node_id.replace('-', '_'), "custom:exo_node_port": node_port, "custom:exo_device_capability_chip": sanitize_attribute(device_capabilities.chip),
+      "custom:exo_device_capability_model": sanitize_attribute(device_capabilities.model), "custom:exo_device_capability_memory": str(device_capabilities.memory),
+      "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16), "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
+      "custom:exo_device_capability_flops_int8": str(device_capabilities.flops.int8)
+    }
+
+    for attr_name, attr_value in attributes.items():
+      url = f"{base_url}/{attr_name}"
+      data = {"value": str(attr_value).replace(' ', '_')}  # Ensure all values are strings for JSON
+      async with session.post(url, headers=headers, json=data) as response:
+        if response.status == 200:
+          if DEBUG_DISCOVERY >= 1: print(f"Updated device posture attribute {attr_name} for device {device_id}")
+        else:
+          print(f"Failed to update device posture attribute {attr_name}: {response.status} {await response.text()}")
+
+
+async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int, DeviceCapabilities]:
+  async with aiohttp.ClientSession() as session:
+    url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
+    headers = {'Authorization': f'Bearer {api_key}'}
+    async with session.get(url, headers=headers) as response:
+      if response.status == 200:
+        data = await response.json()
+        attributes = data.get("attributes", {})
+        node_id = attributes.get("custom:exo_node_id", "").replace('_', '-')
+        node_port = int(attributes.get("custom:exo_node_port", 0))
+        device_capabilities = DeviceCapabilities(
+          model=attributes.get("custom:exo_device_capability_model", "").replace('_', ' '),
+          chip=attributes.get("custom:exo_device_capability_chip", "").replace('_', ' '),
+          memory=int(attributes.get("custom:exo_device_capability_memory", 0)),
+          flops=DeviceFlops(
+            fp16=float(attributes.get("custom:exo_device_capability_flops_fp16", 0)),
+            fp32=float(attributes.get("custom:exo_device_capability_flops_fp32", 0)),
+            int8=float(attributes.get("custom:exo_device_capability_flops_int8", 0))
+          )
+        )
+        return node_id, node_port, device_capabilities
+      else:
+        print(f"Failed to fetch posture attributes for {device_id}: {response.status}")
+        return "", 0, DeviceCapabilities(model="", chip="", memory=0, flops=DeviceFlops(fp16=0, fp32=0, int8=0))
+
+
+def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
+  result = {}
+  prefix = "custom:exo_"
+  for key, value in data.items():
+    if key.startswith(prefix):
+      attr_name = key.replace(prefix, "")
+      if attr_name in ["node_id", "node_port", "device_capability_chip", "device_capability_model"]:
+        result[attr_name] = value.replace('_', ' ')
+      elif attr_name in ["device_capability_memory", "device_capability_flops_fp16", "device_capability_flops_fp32", "device_capability_flops_int8"]:
+        result[attr_name] = float(value)
+  return result
+
+
+def sanitize_attribute(value: str) -> str:
+  # Replace invalid characters with underscores
+  sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value)
+  # Truncate to 50 characters
+  return sanitized_value[:50]
+
+
+async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]:
+  async with aiohttp.ClientSession() as session:
+    url = f"https://api.tailscale.com/api/v2/tailnet/{tailnet}/devices"
+    headers = {"Authorization": f"Bearer {api_key}"}
+
+    async with session.get(url, headers=headers) as response:
+      response.raise_for_status()
+      data = await response.json()
+
+      devices = {}
+      for device_data in data.get("devices", []):
+        if DEBUG_DISCOVERY >= 4: print("Device data: ", device_data)
+        device = Device.from_dict(device_data)
+        devices[device.name] = device
+
+      return devices

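The attribute helpers above rely on a simple encoding convention: the sanitizer restricts values to [a-zA-Z0-9_.], so spaces and dashes become underscores on write and are mapped back to spaces or dashes on read. A small illustrative round trip (the chip string is made up):

from exo.networking.tailscale.tailscale_helpers import sanitize_attribute

chip = "Apple M3 Max"                  # hypothetical chip string
encoded = sanitize_attribute(chip)     # "Apple_M3_Max" -- safe to store as a posture attribute value
decoded = encoded.replace('_', ' ')    # "Apple M3 Max" -- what get_device_attributes reconstructs
assert decoded == chip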
+ 43 - 0
build/lib/exo/networking/tailscale/test_tailscale_discovery.py

@@ -0,0 +1,43 @@
+import os
+import asyncio
+import unittest
+from unittest import mock
+from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
+from exo.networking.peer_handle import PeerHandle
+
+
+class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    self.tailscale_api_key = os.environ.get("TAILSCALE_API_KEY", "")
+    self.tailnet = os.environ.get("TAILSCALE_TAILNET", "")
+    self.discovery = TailscaleDiscovery(
+      node_id="test_node",
+      node_port=50051,
+      create_peer_handle=lambda peer_id, address, device_capabilities: unittest.mock.Mock(spec=PeerHandle, id=lambda: peer_id),
+      tailscale_api_key=self.tailscale_api_key,
+      tailnet=self.tailnet
+    )
+    await self.discovery.start()
+
+  async def asyncTearDown(self):
+    await self.discovery.stop()
+
+  async def test_discovery(self):
+    # Wait for a short period to allow discovery to happen
+    await asyncio.sleep(15)
+
+    # Get discovered peers
+    peers = await self.discovery.discover_peers()
+
+    # Check if any peers were discovered
+    self.assertGreater(len(peers), 0, "No peers were discovered")
+
+    # Print discovered peers for debugging
+    print(f"Discovered peers: {[peer.id() for peer in peers]}")
+
+    # Print the discovered peer handles for manual inspection
+    print(peers)
+
+
+if __name__ == '__main__':
+  unittest.main()

+ 0 - 0
build/lib/exo/networking/udp/__init__.py


+ 77 - 0
build/lib/exo/networking/udp/test_udp_discovery.py

@@ -0,0 +1,77 @@
+import asyncio
+import unittest
+from unittest import mock
+from exo.networking.udp.udp_discovery import UDPDiscovery
+from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
+from exo.networking.grpc.grpc_server import GRPCServer
+from exo.orchestration.node import Node
+
+
+class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    self.peer1 = mock.AsyncMock()
+    self.peer2 = mock.AsyncMock()
+    self.peer1.connect = mock.AsyncMock()
+    self.peer2.connect = mock.AsyncMock()
+    self.discovery1 = UDPDiscovery("discovery1", 50051, 5678, 5679, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    self.discovery2 = UDPDiscovery("discovery2", 50052, 5679, 5678, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+
+  async def test_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    assert len(peers1) == 1
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    assert len(peers2) == 1
+
+    # connect has to be explicitly called after discovery
+    self.peer1.connect.assert_not_called()
+    self.peer2.connect.assert_not_called()
+
+
+class TestUDPDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    self.node1 = mock.AsyncMock(spec=Node)
+    self.node2 = mock.AsyncMock(spec=Node)
+    self.server1 = GRPCServer(self.node1, "localhost", 50053)
+    self.server2 = GRPCServer(self.node2, "localhost", 50054)
+    await self.server1.start()
+    await self.server2.start()
+    self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+    await self.server1.stop()
+    await self.server2.stop()
+
+  async def test_grpc_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    assert len(peers1) == 1
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    assert len(peers2) == 1
+    assert not await peers1[0].is_connected()
+    assert not await peers2[0].is_connected()
+
+    # Connect
+    await peers1[0].connect()
+    await peers2[0].connect()
+    assert await peers1[0].is_connected()
+    assert await peers2[0].is_connected()
+
+    # Kill server1
+    await self.server1.stop()
+
+    assert await peers1[0].is_connected()
+    assert not await peers2[0].is_connected()
+
+
+if __name__ == "__main__":
+  asyncio.run(unittest.main())

+ 215 - 0
build/lib/exo/networking/udp/udp_discovery.py

@@ -0,0 +1,215 @@
+import asyncio
+import json
+import socket
+import time
+import traceback
+from typing import List, Dict, Callable, Tuple, Coroutine
+from exo.networking.discovery import Discovery
+from exo.networking.peer_handle import PeerHandle
+from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
+from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses
+
+
+class ListenProtocol(asyncio.DatagramProtocol):
+  def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
+    super().__init__()
+    self.on_message = on_message
+    self.loop = asyncio.get_event_loop()
+
+  def connection_made(self, transport):
+    self.transport = transport
+
+  def datagram_received(self, data, addr):
+    asyncio.create_task(self.on_message(data, addr))
+
+
+class BroadcastProtocol(asyncio.DatagramProtocol):
+  def __init__(self, message: str, broadcast_port: int):
+    self.message = message
+    self.broadcast_port = broadcast_port
+
+  def connection_made(self, transport):
+    sock = transport.get_extra_info("socket")
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+    transport.sendto(self.message.encode("utf-8"), ("<broadcast>", self.broadcast_port))
+
+
+class UDPDiscovery(Discovery):
+  def __init__(
+    self,
+    node_id: str,
+    node_port: int,
+    listen_port: int,
+    broadcast_port: int,
+    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    broadcast_interval: int = 1,
+    discovery_timeout: int = 30,
+    device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
+    allowed_node_ids: List[str] = None,
+  ):
+    self.node_id = node_id
+    self.node_port = node_port
+    self.listen_port = listen_port
+    self.broadcast_port = broadcast_port
+    self.create_peer_handle = create_peer_handle
+    self.broadcast_interval = broadcast_interval
+    self.discovery_timeout = discovery_timeout
+    self.device_capabilities = device_capabilities
+    self.allowed_node_ids = allowed_node_ids
+    self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
+    self.broadcast_task = None
+    self.listen_task = None
+    self.cleanup_task = None
+
+  async def start(self):
+    self.device_capabilities = device_capabilities()
+    self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
+    self.listen_task = asyncio.create_task(self.task_listen_for_peers())
+    self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
+
+  async def stop(self):
+    if self.broadcast_task: self.broadcast_task.cancel()
+    if self.listen_task: self.listen_task.cancel()
+    if self.cleanup_task: self.cleanup_task.cancel()
+    if self.broadcast_task or self.listen_task or self.cleanup_task:
+      await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)
+
+  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+    if wait_for_peers > 0:
+      while len(self.known_peers) < wait_for_peers:
+        if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
+        await asyncio.sleep(0.1)
+    return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]
+
+  async def task_broadcast_presence(self):
+    if DEBUG_DISCOVERY >= 2: print("Starting task_broadcast_presence...")
+
+    while True:
+      # Explicitly broadcast on every assigned IP, since broadcasting on `0.0.0.0` on macOS does not reach
+      # the Thunderbolt bridge when other connection modalities such as WiFi or Ethernet are active.
+      for addr in get_all_ip_addresses():
+        message = json.dumps({
+          "type": "discovery",
+          "node_id": self.node_id,
+          "grpc_port": self.node_port,
+          "device_capabilities": self.device_capabilities.to_dict(),
+          "priority": 1,  # For now, every interface has the same priority. We can make this better by prioritising interfaces based on bandwidth, latency, and jitter, e.g. prioritise Thunderbolt over WiFi.
+        })
+        if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr}): {message}")
+
+        transport = None
+        try:
+          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
+          if DEBUG_DISCOVERY >= 3:
+            print(f"Broadcasting presence at ({addr})")
+        except Exception as e:
+          print(f"Error in broadcast presence ({addr}): {e}")
+        finally:
+          if transport:
+            try:
+              transport.close()
+            except Exception as e:
+              if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
+              if DEBUG_DISCOVERY >= 2: traceback.print_exc()
+      await asyncio.sleep(self.broadcast_interval)
+
+  async def on_listen_message(self, data, addr):
+    if not data:
+      return
+
+    decoded_data = data.decode("utf-8", errors="ignore")
+
+    # Check if the decoded data starts with a valid JSON character
+    if not (decoded_data.strip() and decoded_data.strip()[0] in "{["):
+      if DEBUG_DISCOVERY >= 2: print(f"Received invalid JSON data from {addr}: {decoded_data[:100]}")
+      return
+
+    try:
+      decoder = json.JSONDecoder(strict=False)
+      message = decoder.decode(decoded_data)
+    except json.JSONDecodeError as e:
+      if DEBUG_DISCOVERY >= 2: print(f"Error decoding JSON data from {addr}: {e}")
+      return
+
+    if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}")
+
+    if message["type"] == "discovery" and message["node_id"] != self.node_id:
+      peer_id = message["node_id"]
+      
+      # Skip if peer_id is not in allowed list
+      if self.allowed_node_ids and peer_id not in self.allowed_node_ids:
+        if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as it's not in the allowed node IDs list")
+        return
+
+      peer_host = addr[0]
+      peer_port = message["grpc_port"]
+      peer_prio = message["priority"]
+      device_capabilities = DeviceCapabilities(**message["device_capabilities"])
+
+      if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
+        if peer_id in self.known_peers:
+          existing_peer_prio = self.known_peers[peer_id][3]
+          if existing_peer_prio >= peer_prio:
+            if DEBUG >= 1:
+              print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
+            return
+        new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
+        if not await new_peer_handle.health_check():
+          if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
+          return
+        if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
+        self.known_peers[peer_id] = (new_peer_handle, time.time(), time.time(), peer_prio)
+      else:
+        if not await self.known_peers[peer_id][0].health_check():
+          if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Removing.")
+          if peer_id in self.known_peers: del self.known_peers[peer_id]
+          return
+        if peer_id in self.known_peers: self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)
+
+  async def task_listen_for_peers(self):
+    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port))
+    if DEBUG_DISCOVERY >= 2: print("Started listen task")
+
+  async def task_cleanup_peers(self):
+    while True:
+      try:
+        current_time = time.time()
+        peers_to_remove = []
+
+        peer_ids = list(self.known_peers.keys())
+        results = await asyncio.gather(*[self.check_peer(peer_id, current_time) for peer_id in peer_ids], return_exceptions=True)
+
+        for peer_id, should_remove in zip(peer_ids, results):
+          if should_remove: peers_to_remove.append(peer_id)
+
+        if DEBUG_DISCOVERY >= 2:
+          print(
+            "Peer statuses:", {
+              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}, prio={prio}"
+              for peer_handle, connected_at, last_seen, prio in self.known_peers.values()
+            }
+          )
+
+        for peer_id in peers_to_remove:
+          if peer_id in self.known_peers:
+            del self.known_peers[peer_id]
+            if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity or failed health check.")
+      except Exception as e:
+        print(f"Error in cleanup peers: {e}")
+        print(traceback.format_exc())
+      finally:
+        await asyncio.sleep(self.broadcast_interval)
+
+  async def check_peer(self, peer_id: str, current_time: float) -> bool:
+    peer_handle, connected_at, last_seen, prio = self.known_peers.get(peer_id, (None, None, None, None))
+    if peer_handle is None: return False
+
+    try:
+      is_connected = await peer_handle.is_connected()
+      health_ok = await peer_handle.health_check()
+    except Exception as e:
+      if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
+      return True
+
+    should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
+    return should_remove

+ 4 - 0
build/lib/exo/orchestration/__init__.py

@@ -0,0 +1,4 @@
+from .node import Node
+from .standard_node import StandardNode
+
+__all__ = ["Node", "StandardNode"]

+ 47 - 0
build/lib/exo/orchestration/node.py

@@ -0,0 +1,47 @@
+from typing import Optional, Tuple, List
+import numpy as np
+from abc import ABC, abstractmethod
+from exo.helpers import AsyncCallbackSystem
+from exo.inference.shard import Shard
+from exo.topology.topology import Topology
+
+
+class Node(ABC):
+  @abstractmethod
+  async def start(self, wait_for_peers: int = 0) -> None:
+    pass
+
+  @abstractmethod
+  async def stop(self) -> None:
+    pass
+
+  @abstractmethod
+  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
+    pass
+
+  @abstractmethod
+  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
+    pass
+
+  @abstractmethod
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+    pass
+
+  @abstractmethod
+  async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology:
+    pass
+
+  @property
+  @abstractmethod
+  def current_topology(self) -> Topology:
+    pass
+
+  @property
+  @abstractmethod
+  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
+    pass
+
+  @property
+  @abstractmethod
+  def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
+    pass

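As a small consumption sketch (the callback key "token_printer" is arbitrary and the helper below is hypothetical), a caller can subscribe to a node's token stream through the AsyncCallbackSystem properties above, following the register(...).on_next(...) pattern used in standard_node.py below:

from typing import List
from exo.orchestration.node import Node


def attach_token_printer(node: Node) -> None:
  # Callback arguments follow the Tuple[str, List[int], bool] payload of on_token.
  def on_token(request_id: str, tokens: List[int], is_finished: bool) -> None:
    print(f"[{request_id}] {len(tokens)} tokens buffered, finished={is_finished}")

  node.on_token.register("token_printer").on_next(on_token)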
+ 488 - 0
build/lib/exo/orchestration/standard_node.py

@@ -0,0 +1,488 @@
+import numpy as np
+import json
+import asyncio
+import uuid
+import time
+import traceback
+from typing import List, Dict, Optional, Tuple, Union, Set
+from exo.networking import Discovery, PeerHandle, Server
+from exo.inference.inference_engine import InferenceEngine, Shard
+from .node import Node
+from exo.topology.topology import Topology
+from exo.topology.device_capabilities import device_capabilities
+from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards
+from exo import DEBUG
+from exo.helpers import AsyncCallbackSystem
+from exo.viz.topology_viz import TopologyViz
+from exo.download.hf.hf_helpers import RepoProgressEvent
+from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
+
+class StandardNode(Node):
+  def __init__(
+    self,
+    _id: str,
+    server: Server,
+    inference_engine: InferenceEngine,
+    discovery: Discovery,
+    partitioning_strategy: PartitioningStrategy = None,
+    max_generate_tokens: int = 1024,
+    default_sample_temperature: float = 0.0,
+    topology_viz: Optional[TopologyViz] = None,
+    shard_downloader: Optional[HFShardDownloader] = None,
+  ):
+    self.id = _id
+    self.inference_engine = inference_engine
+    self.server = server
+    self.discovery = discovery
+    self.partitioning_strategy = partitioning_strategy
+    self.peers: List[PeerHandle] = []
+    self.topology: Topology = Topology()
+    self.device_capabilities = device_capabilities()
+    self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
+    self.buffered_logits: Dict[str, List[np.ndarray]] = {}
+    self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
+    self.max_generate_tokens = max_generate_tokens
+    self.topology_viz = topology_viz
+    self.default_sample_temperature = default_sample_temperature
+    self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
+    self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
+    self._on_opaque_status.register("node_status").on_next(self.on_node_status)
+    self.node_download_progress: Dict[str, RepoProgressEvent] = {}
+    self.topology_inference_engines_pool: List[List[str]] = []
+    self.shard_downloader = shard_downloader
+
+  async def start(self, wait_for_peers: int = 0) -> None:
+    await self.server.start()
+    await self.discovery.start()
+    await self.update_peers(wait_for_peers)
+    await self.collect_topology()
+    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
+    asyncio.create_task(self.periodic_topology_collection(1.0))
+
+  async def stop(self) -> None:
+    await self.discovery.stop()
+    await self.server.stop()
+
+  def on_node_status(self, request_id, opaque_status):
+    try:
+      status_data = json.loads(opaque_status)
+      if status_data.get("type", "") == "supported_inference_engines":
+        node_id = status_data.get("node_id")
+        engines = status_data.get("engines", [])
+        self.topology_inference_engines_pool.append(engines)
+      if status_data.get("type", "") == "node_status":
+        if status_data.get("status", "").startswith("start_"):
+          self.current_topology.active_node_id = status_data.get("node_id")
+        elif status_data.get("status", "").startswith("end_"):
+          if status_data.get("node_id") == self.current_topology.active_node_id:
+            self.current_topology.active_node_id = None
+      download_progress = None
+      if status_data.get("type", "") == "download_progress":
+        if DEBUG >= 8: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
+        download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
+        self.node_download_progress[status_data.get('node_id')] = download_progress
+      if self.topology_viz:
+        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id, self.node_download_progress)
+    except Exception as e:
+      if DEBUG >= 1: print(f"Error updating visualization: {e}")
+      if DEBUG >= 1: traceback.print_exc()
+
+  def get_supported_inference_engines(self):
+    supported_engine_names = []
+    if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
+      supported_engine_names.append('mlx')
+      supported_engine_names.append('tinygrad')
+    else:
+      supported_engine_names.append('tinygrad')
+    return supported_engine_names
+
+  async def broadcast_supported_engines(self, supported_engines_names: List[str]):
+    status_message = json.dumps({"type": "supported_inference_engines", "node_id": self.id, "engines": supported_engines_names})
+    await self.broadcast_opaque_status("", status_message)
+
+  def get_topology_inference_engines(self) -> List[List[str]]:
+    return self.topology_inference_engines_pool
+  
+  async def process_inference_result(
+    self,
+    shard,
+    result: np.ndarray,
+    request_id: Optional[str] = None,
+    inference_state: Optional[dict] = None,
+  ):
+    if shard.model_id != 'stable-diffusion-2-1-base':
+      if request_id not in self.buffered_token_output:
+        self.buffered_token_output[request_id] = ([], False)
+      is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+      if shard.is_last_layer() and not is_finished:
+        token = await self.inference_engine.sample(result, temp = self.default_sample_temperature)
+        self.buffered_token_output[request_id][0].append(token.item())
+        intermediate_result = self.buffered_token_output[request_id][0]
+        if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+        is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
+        forward = token.reshape(1, -1)
+      else:
+        # Not the last layer (or the token limit was already reached): the tokens buffered so far are
+        # the intermediate result, and the raw activations are what gets forwarded if generation continues.
+        intermediate_result = self.buffered_token_output[request_id][0]
+        forward = result
+    else:
+      await self.inference_engine.ensure_shard(shard)
+      is_finished = inference_state.get('is_finished', False)
+      intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result)
+      forward = result
+    if shard.is_last_layer():
+      asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished))
+      self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished)
+    if is_finished:
+      if shard.model_id != 'stable-diffusion-2-1-base':
+        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+        intermediate_result = self.buffered_token_output[request_id][0]
+    else:
+      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))
+
+    return np.array(intermediate_result)
+  
+  async def process_prompt(
+    self,
+    base_shard: Shard,
+    prompt: str,
+    request_id: Optional[str] = None,
+    inference_state: Optional[dict] = {},
+  ) -> Optional[np.ndarray]:
+    shard = self.get_current_shard(base_shard)
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "start_process_prompt",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "prompt": prompt,
+          "request_id": request_id,
+        }),
+      )
+    )
+    start_time = time.perf_counter_ns()
+    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
+    end_time = time.perf_counter_ns()
+    elapsed_time_ns = end_time - start_time
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "end_process_prompt",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "prompt": prompt,
+          "request_id": request_id,
+          "elapsed_time_ns": elapsed_time_ns,
+          "result_size": resp.size if resp is not None else 0,
+        }),
+      )
+    )
+    return resp
+
+  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
+    if request_id is None:
+      request_id = str(uuid.uuid4())
+    shard = self.get_current_shard(base_shard)
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+    if not shard.is_first_layer():
+      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+      resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state)
+      return None
+    else:
+      result, inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
+      ret = await self.process_inference_result(shard, result, request_id, inference_state) 
+      return result
+
+  async def process_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: Optional[str] = None,
+    inference_state: Optional[dict] = None,
+  ) -> Optional[np.ndarray]:
+    shard = self.get_current_shard(base_shard)
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "start_process_tensor",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "tensor_size": tensor.size,
+          "tensor_shape": tensor.shape,
+          "request_id": request_id,
+        }),
+      )
+    )
+    start_time = time.perf_counter_ns()
+    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
+    end_time = time.perf_counter_ns()
+    elapsed_time_ns = end_time - start_time
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "end_process_tensor",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "request_id": request_id,
+          "elapsed_time_ns": elapsed_time_ns,
+          "result_size": resp.size if resp is not None else 0,
+        }),
+      )
+    )
+    return resp
+
+  async def _process_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: Optional[str] = None,
+    inference_state: Optional[dict] = None,
+  ) -> Optional[np.ndarray]:
+    if request_id is None:
+      request_id = str(uuid.uuid4())
+    shard = self.get_current_shard(base_shard)
+
+    if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
+    try:
+      result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
+      ret = await self.process_inference_result(shard, result, request_id, inference_state) 
+      return ret
+    except Exception as e:
+      print(f"Error processing tensor for shard {shard}: {e}")
+      traceback.print_exc()
+      return None
+
+  async def forward_prompt(
+    self,
+    base_shard: Shard,
+    prompt: str,
+    request_id: str,
+    target_index: int,
+    inference_state: Optional[dict] = None,
+  ) -> None:
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
+    if target_id == self.id:
+      await self.process_prompt(next_shard, prompt, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
+      await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)
+  
+  async def forward_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: str,
+    target_index: int,
+    inference_state: Optional[dict] = None,
+  ) -> None:
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
+    if target_id == self.id:
+      await self.process_tensor(next_shard, tensor, request_id, inference_state)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
+      await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state)
+
+  def get_partition_index(self, offset: int = 0):
+    if not self.partitioning_strategy:
+      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
+      return None
+    partitions = self.partitioning_strategy.partition(self.topology)
+    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
+    if current_partition_index is None:
+      raise ValueError(f"No current partition found for node: {self.id}")
+    return (current_partition_index + offset) % len(partitions)
+
+  def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard:
+    if index is None:
+      index = self.get_partition_index()
+    partitions = self.partitioning_strategy.partition(self.topology)
+    shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
+    return shards[index]
+
+  async def update_peers(self, wait_for_peers: int = 0) -> bool:
+    next_peers = await self.discovery.discover_peers(wait_for_peers)
+    current_peer_ids = {peer.id() for peer in self.peers}
+    next_peer_ids = {peer.id() for peer in next_peers}
+    peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids]
+    peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids]
+    peers_updated = [peer for peer in next_peers if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())]
+    peers_unchanged = [peer for peer in next_peers if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())]
+    peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()]
+    peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()]
+
+    def _pretty(peers: List[PeerHandle]) -> List[str]:
+      return [f"{peer.id()}@{peer.addr()}" for peer in peers]
+
+    if DEBUG >= 2:
+      print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
+
+    async def disconnect_with_timeout(peer, timeout=5):
+      try:
+        await asyncio.wait_for(peer.disconnect(), timeout)
+        return True
+      except Exception as e:
+        print(f"Error disconnecting peer {peer.id()}@{peer.addr()}: {e}")
+        traceback.print_exc()
+        return False
+
+    async def connect_with_timeout(peer, timeout=5):
+      try:
+        await asyncio.wait_for(peer.connect(), timeout)
+        return True
+      except Exception as e:
+        print(f"Error connecting peer {peer.id()}@{peer.addr()}: {e}")
+        traceback.print_exc()
+        return False
+
+    disconnect_results = await asyncio.gather(*(disconnect_with_timeout(peer) for peer in peers_to_disconnect), return_exceptions=True)
+    connect_results = await asyncio.gather(*(connect_with_timeout(peer) for peer in peers_to_connect), return_exceptions=True)
+
+    successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True]
+    failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False]
+    successful_connects = [peer for peer, result in zip(peers_to_connect, connect_results) if result is True]
+    failed_connects = [peer for peer, result in zip(peers_to_connect, connect_results) if result is False]
+    if DEBUG >= 1:
+      if successful_disconnects: print(f"Successfully disconnected peers: {_pretty(successful_disconnects)}")
+      if failed_disconnects: print(f"Failed to disconnect peers: {_pretty(failed_disconnects)}")
+      if successful_connects: print(f"Successfully connected peers: {_pretty(successful_connects)}")
+      if failed_connects: print(f"Failed to connect peers: {_pretty(failed_connects)}")
+
+    self.peers = next_peers
+    return len(peers_added) > 0 or len(peers_removed) > 0 or len(peers_updated) > 0
+
+  async def select_best_inference_engine(self):
+    if self.inference_engine.__class__.__name__ == 'DummyInferenceEngine': return
+    supported_engines = self.get_supported_inference_engines()
+    await self.broadcast_supported_engines(supported_engines)
+    if len(self.get_topology_inference_engines()):
+      self.inference_engine = get_inference_engine(supported_engines[0], self.shard_downloader)
+
+  async def periodic_topology_collection(self, interval: int):
+    while True:
+      await asyncio.sleep(interval)
+      try:
+        did_peers_change = await self.update_peers()
+        if DEBUG >= 2: print(f"{did_peers_change=}")
+        if did_peers_change:
+          await self.collect_topology()
+          await self.select_best_inference_engine()
+      except Exception as e:
+        print(f"Error collecting topology: {e}")
+        traceback.print_exc()
+
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+    if request_id not in self.buffered_token_output:
+      return None, False
+    return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
+
+  async def collect_topology(self, visited: Optional[set[str]] = None, max_depth: int = 4) -> Topology:
+    if visited is None: visited = set()  # avoid sharing a mutable default set across calls
+    next_topology = Topology()
+    next_topology.update_node(self.id, self.device_capabilities)
+
+    if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")
+
+    prev_visited = visited.copy()
+    visited.add(self.id)
+    visited.update(p.id() for p in self.peers)
+
+    for peer in self.peers:
+      next_topology.update_node(peer.id(), peer.device_capabilities())
+      next_topology.add_edge(self.id, peer.id())
+
+      if peer.id() in prev_visited:
+        continue
+
+      if max_depth <= 0:
+        if DEBUG >= 2: print("Max depth reached. Skipping...")
+        continue
+
+      try:
+        other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0)
+        if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
+        self.topology.merge(other_topology)
+      except Exception as e:
+        print(f"Error collecting topology from {peer.id()}: {e}")
+        traceback.print_exc()
+
+    next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
+    self.topology = next_topology
+    if self.topology_viz:
+      self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id)
+    return next_topology
+
+  @property
+  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
+    return self._on_token
+
+  @property
+  def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
+    return self._on_opaque_status
+
+  def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
+    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
+    self.on_token.trigger_all(request_id, tokens, is_finished)
+  
+  async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    async def send_result_to_peer(peer):
+      try:
+        await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
+      except asyncio.TimeoutError:
+        print(f"Timeout broadcasting result to {peer.id()}")
+      except Exception as e:
+        print(f"Error broadcasting result to {peer.id()}: {e}")
+        traceback.print_exc()
+
+    await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
+
+  async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
+    if DEBUG >= 8: print(f"Broadcasting opaque status: {request_id=} {status=}")
+
+    async def send_status_to_peer(peer):
+      try:
+        await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
+      except asyncio.TimeoutError:
+        print(f"Timeout sending opaque status to {peer.id()}")
+      except Exception as e:
+        print(f"Error sending opaque status to {peer.id()}: {e}")
+        traceback.print_exc()
+
+    await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
+    # in the case of opaque status, we also want to receive our own opaque statuses
+    self.on_opaque_status.trigger_all(request_id, status)
+
+  @property
+  def current_topology(self) -> Topology:
+    return self.topology
+
+  def handle_stable_diffusion(self, inference_state, result):
+    if inference_state['is_step_finished']:
+      inference_state['step'] += 1
+    progress = [inference_state['step'], inference_state['total_steps']]
+    intermediate_result = result
+    if progress[0] == progress[1]:
+      intermediate_result = result
+    return intermediate_result, inference_state
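
(Illustrative sketch, not part of the commit: the "node_status" payloads assembled in process_prompt/process_tensor above are plain JSON strings broadcast via broadcast_opaque_status, so any consumer can subscribe through the on_opaque_status callback system, exactly as the stats module further down does. The `node` below is a hypothetical StandardNode instance; the JSON field names mirror the code above.)

import json

def log_node_status(request_id: str, opaque_status: str) -> None:
  # Parse the broadcast JSON and report how long a tensor step took on which node.
  status = json.loads(opaque_status)
  if status.get("type") != "node_status":
    return
  if status.get("status") == "end_process_tensor":
    elapsed_s = status.get("elapsed_time_ns", 0) / 1e9
    print(f"{status.get('node_id')} finished {request_id} in {elapsed_s:.3f}s (result_size={status.get('result_size')})")

# node.on_opaque_status.register("status-logger").on_next(log_node_status)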

+ 57 - 0
build/lib/exo/orchestration/test_node.py

@@ -0,0 +1,57 @@
+import unittest
+from unittest.mock import Mock, AsyncMock
+import numpy as np
+
+from .standard_node import StandardNode
+from exo.networking.peer_handle import PeerHandle
+
+
+class TestNode(unittest.IsolatedAsyncioTestCase):
+  def setUp(self):
+    self.mock_inference_engine = AsyncMock()
+    self.mock_server = AsyncMock()
+    self.mock_server.start = AsyncMock()
+    self.mock_server.stop = AsyncMock()
+    self.mock_discovery = AsyncMock()
+    self.mock_discovery.start = AsyncMock()
+    self.mock_discovery.stop = AsyncMock()
+    mock_peer1 = Mock(spec=PeerHandle)
+    mock_peer1.id.return_value = "peer1"
+    mock_peer2 = Mock(spec=PeerHandle)
+    mock_peer2.id.return_value = "peer2"
+    self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
+
+    self.node = StandardNode("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery)
+
+  async def asyncSetUp(self):
+    await self.node.start()
+
+  async def asyncTearDown(self):
+    await self.node.stop()
+
+  async def test_node_initialization(self):
+    self.assertEqual(self.node.node_id, "test_node")
+    self.assertEqual(self.node.host, "localhost")
+    self.assertEqual(self.node.port, 50051)
+
+  async def test_node_start(self):
+    self.mock_server.start.assert_called_once_with("localhost", 50051)
+
+  async def test_node_stop(self):
+    await self.node.stop()
+    self.mock_server.stop.assert_called_once()
+
+  async def test_discover_and_connect_to_peers(self):
+    await self.node.discover_and_connect_to_peers()
+    self.assertEqual(len(self.node.peers), 2)
+    self.assertIn("peer1", map(lambda p: p.id(), self.node.peers))
+    self.assertIn("peer2", map(lambda p: p.id(), self.node.peers))
+
+  async def test_process_tensor_calls_inference_engine(self):
+    mock_peer = Mock()
+    self.node.peers = [mock_peer]
+
+    input_tensor = np.array([69, 1, 2])
+    await self.node.process_tensor(input_tensor, None)
+
+    self.node.inference_engine.process_shard.assert_called_once_with(input_tensor)
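
(Hedged sketch for reference: process_tensor, as defined in standard_node.py above, takes a Shard first, then the numpy tensor and an optional request_id. The Shard field values and the pre-built `node` here are placeholders for illustration, not values from this commit.)

import numpy as np
from exo.inference.shard import Shard

async def run_one_tensor(node) -> None:
  # A single-layer shard spanning the whole (hypothetical) model, so this node runs the inference itself.
  shard = Shard(model_id="example-model", start_layer=0, end_layer=0, n_layers=1)
  tensor = np.array([[1, 2, 3]])
  await node.process_tensor(shard, tensor, request_id="req-1")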

+ 0 - 0
build/lib/exo/stats/__init__.py


+ 29 - 0
build/lib/exo/stats/metrics.py

@@ -0,0 +1,29 @@
+from exo.orchestration import Node
+from prometheus_client import start_http_server, Counter, Histogram
+import json
+
+# Create metrics to track time spent and requests made.
+PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"])
+PROCESS_TENSOR_COUNTER = Counter("process_tensor_total", "Total number of tensors processed", ["node_id"])
+PROCESS_TENSOR_TIME = Histogram("process_tensor_seconds", "Time spent processing tensor", ["node_id"])
+
+
+def start_metrics_server(node: Node, port: int):
+  start_http_server(port)
+
+  def _on_opaque_status(request_id, opaque_status: str):
+    status_data = json.loads(opaque_status)
+    _type = status_data.get("type", "")
+    node_id = status_data.get("node_id", "")
+    if _type != "node_status":
+      return
+    status = status_data.get("status", "")
+
+    if status == "end_process_prompt":
+      PROCESS_PROMPT_COUNTER.labels(node_id=node_id).inc()
+    elif status == "end_process_tensor":
+      elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
+      PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
+      PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns/1e9)  # Convert ns to seconds
+
+  node.on_opaque_status.register("stats").on_next(_on_opaque_status)
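
(A minimal usage sketch, assuming an already-constructed exo node; only start_metrics_server comes from the file above, and the port choice is arbitrary.)

from exo.stats.metrics import start_metrics_server

def enable_stats(node, port: int = 9090) -> None:
  # Starts the prometheus_client HTTP exporter and registers the opaque-status listener defined above.
  # Metrics are then scrapeable at http://<host>:<port>/metrics.
  start_metrics_server(node, port)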

+ 50 - 0
build/lib/exo/test_callbacks.py

@@ -0,0 +1,50 @@
+import asyncio
+from typing import Any, Callable
+from exo.helpers import AsyncCallbackSystem, AsyncCallback
+
+
+# Usage example
+async def main() -> None:
+  callback_system = AsyncCallbackSystem[str, Any]()
+
+  # Register callbacks
+  callback1 = callback_system.register("callback1")
+  callback2 = callback_system.register("callback2")
+
+  def on_next_callback(name: str) -> Callable[..., None]:
+    def callback(*args: Any) -> None:
+      print(f"{name} received values: {args}")
+
+    return callback
+
+  callback1.on_next(on_next_callback("Callback1"))
+  callback2.on_next(on_next_callback("Callback2"))
+
+  async def wait_for_callback(name: str, callback: AsyncCallback[Any], condition: Callable[..., bool]) -> None:
+    try:
+      result = await callback.wait(condition, timeout=2)
+      print(f"{name} wait completed with result: {result}")
+    except asyncio.TimeoutError:
+      print(f"{name} wait timed out")
+
+  # Trigger all callbacks at once
+  callback_system.trigger_all("Hello", 42, True)
+
+  # Wait for all callbacks with different conditions
+  await asyncio.gather(
+    wait_for_callback("Callback1", callback1, lambda msg, num, flag: isinstance(msg, str) and num > 0),
+    wait_for_callback("Callback2", callback2, lambda msg, num, flag: flag is True),
+  )
+
+  # Trigger individual callback
+  callback_system.trigger("callback2", "World", -10, False)
+
+  # Demonstrate timeout
+  new_callback = callback_system.register("new_callback")
+  new_callback.on_next(on_next_callback("NewCallback"))
+  await wait_for_callback("NewCallback", new_callback, lambda msg, num, flag: num > 100)
+
+  callback_system.trigger("callback2", "World", 200, False)
+
+
+asyncio.run(main())

+ 130 - 0
build/lib/exo/tinychat/common.css

@@ -0,0 +1,130 @@
+/* make it responsive */
+@media(min-width: 852px) {
+  body {
+    font-size: 14px;
+  }
+}
+@media(max-width: 852px) {
+  body {
+    font-size: 12px;
+  }
+}
+
+/* resets */
+html, body {
+  width: 100%;
+  height: 100%;
+}
+
+*::-webkit-scrollbar {
+  display: none;
+}
+
+* {
+  -ms-overflow-style: none;
+  scrollbar-width: none;
+}
+
+* {
+  -moz-box-sizing: border-box;
+  -webkit-box-sizing: border-box;
+  box-sizing: border-box;
+}
+
+/* default */
+body {
+  margin: 0;
+  background-color: var(--primary-bg-color);
+  color: var(--foreground-color);
+}
+
+h1, h2, h3, h4, h5, h6 {
+  margin: 0em;
+}
+
+hr {
+  width: 92%;
+}
+
+button {
+  cursor: pointer;
+  border: none;
+  background-color: transparent;
+}
+button:hover {
+}
+button:active {
+}
+
+/* components */
+.container {
+  margin: 0 auto;
+  padding: 1rem;
+}
+
+.centered {
+  display: flex;
+  flex-direction: column;
+  justify-content: center;
+  align-items: center;
+}
+
+.centered-w-only {
+  position: absolute;
+  left: 50%;
+  transform: translateX(-50%);
+}
+
+.centered-h-only {
+  position: absolute;
+  top: 50%;
+  transform: translateY(-50%);
+}
+
+.card {
+  padding: 0;
+}
+
+.card-header {
+  padding: 0.5rem 1rem;
+}
+
+.card-container {
+  width: 96vw;
+  height: 100%;
+  gap: 1rem;
+  display: flex;
+  flex-direction: row;
+  flex-wrap: wrap;
+  justify-content: center;
+  align-items: center;
+}
+
+.clean-a {
+  text-decoration: underline;
+  text-decoration-color: #006fc1;
+  text-decoration-thickness: 2px;
+  color: inherit;
+}
+
+.hover-underline {
+  text-decoration: underline;
+  text-decoration-color: #228039;
+  text-decoration-thickness: 2px;
+  color: inherit;
+}
+
+.flex-horizontal {
+  display: flex;
+  flex-direction: row;
+  justify-content: space-between;
+  align-items: center;
+}
+
+.vertical-separator {
+  padding: 0 0.5rem;
+}
+
+[x-cloak] {
+  display: none !important;
+}

+ 25 - 0
build/lib/exo/tinychat/favicon.svg

@@ -0,0 +1,25 @@
+<svg xmlns="http://www.w3.org/2000/svg" viewBox="-10 -10 150 70" shape-rendering="crispEdges">
+  <g id="logo">
+    <!-- t -->
+    <polygon points="10,40 10,20 0,20 0,10 10,10 10,0 20,0 20,10 30,10 30,20 20,20 20,30 30,30 30,40" />
+    <!-- i -->
+    <polygon points="40,40 40,20 50,20 50,40" />
+    <polygon points="40,10 40,0 50,0 50,10" />
+    <!-- n -->
+    <polygon points="60,40 60,10 80,10 80,40 90,40 90,20 70,20 70,40" />
+    <!-- y -->
+    <polygon points="100,50 100,40 130,40 130,10 120,10 120,20 110,20 110,10 100,10 100,30 120,30 120,50" />
+  </g>
+  <style>
+  @media (prefers-color-scheme: dark) {
+    #logo {
+      fill: #fff;
+    }
+  }
+  @media (prefers-color-scheme: light) {
+    #logo {
+      fill: #000;
+    }
+  }
+  </style>
+</svg>

+ 484 - 0
build/lib/exo/tinychat/index.css

@@ -0,0 +1,484 @@
+/* define colors */
+:root {
+  --primary-color: #fff;
+  --secondary-color: #2a2a2a;
+  --secondary-color-transparent: #ffffff66;
+  --primary-bg-color: #1a1a1a;
+  --foreground-color: #f0f0f0;
+  --red-color: #a52e4d;
+}
+
+main {
+  width: 100%;
+  height: 100%;
+
+  display: flex;
+  flex-direction: column;
+
+  place-items: center;
+}
+
+.home {
+  width: 100%;
+  height: 90%;
+
+  margin-bottom: 10rem;
+}
+
+.title {
+  font-size: 3rem;
+  margin: 1rem 0;
+  margin-top: 3rem;
+}
+
+.histories-container-container {
+  width: 100%;
+  max-height: 75%;
+
+  position: relative;
+}
+
+.histories-container {
+  overflow-y: auto;
+  overflow-x: hidden;
+  width: 100%;
+  height: 100%;
+
+  display: flex;
+  flex-direction: column;
+  gap: 1rem;
+  align-items: center;
+
+  margin: 0;
+  padding: 3rem 1rem;
+}
+
+.histories-start {
+  height: 3rem;
+  width: 100%;
+
+  z-index: 999;
+  top: 0;
+  position: absolute;
+
+  background: linear-gradient(
+    180deg,
+    var(--primary-bg-color) 0%,
+    transparent 100%
+  );
+}
+.histories-end {
+  height: 3rem;
+  width: 100%;
+
+  z-index: 999;
+  bottom: 0;
+  position: absolute;
+
+  background: linear-gradient(
+    0deg,
+    var(--primary-bg-color) 0%,
+    transparent 100%
+  );
+}
+
+.history {
+  padding: 1rem;
+  width: 100%;
+  max-width: 40rem;
+
+  background-color: var(--secondary-color);
+  border-radius: 10px;
+  border-left: 2px solid var(--primary-color);
+
+  cursor: pointer;
+
+  transform: translateX(calc(1px * var(--tx, 0)));
+  opacity: var(--opacity, 1);
+}
+.history:hover {
+  background-color: var(--secondary-color);
+}
+
+.history-delete-button {
+  position: absolute;
+  top: 0;
+  right: 0;
+  padding: 0.5rem;
+  margin: 0;
+  outline: none;
+  border: none;
+  background-color: var(--secondary-color);
+  color: var(--foreground-color);
+  border-radius: 0 0 0 10px;
+  cursor: pointer;
+  transition: 0.2s;
+}
+.history-delete-button:hover {
+  background-color: var(--secondary-color);
+  padding: 0.75rem;
+}
+
+.messages {
+  overflow-y: auto;
+  height: 100%;
+  width: 100%;
+  max-width: 1200px;
+
+  display: flex;
+  flex-direction: column;
+  gap: 1rem;
+  align-items: center;
+  padding-top: 1rem;
+  padding-bottom: 11rem;
+}
+
+.message {
+  max-width: 75%;
+  padding: 0.5rem 1rem;
+  border-radius: 20px;
+}
+.message-role-assistant {
+  background-color: var(--secondary-color);
+  margin-right: auto;
+  color: #fff;
+}
+.message-role-user {
+  margin-left: auto;
+  background-color: var(--primary-color);
+  color: #000;
+}
+.download-progress {
+  margin-bottom: 12em;
+  overflow-y: auto;
+  min-height: 350px;
+  padding: 2rem;
+}
+.message > pre {
+  white-space: pre-wrap;
+}
+
+.progress-bar-container {
+  width: 100%;
+  background-color: #e0e0e0;
+  border-radius: 4px;
+  margin: 10px 0;
+}
+.progress-bar {
+  height: 20px;
+  border-radius: 4px;
+  transition: width 0.5s ease-in-out;
+}
+.progress-bar.complete {
+  background-color: #4CAF50;
+}
+.progress-bar.in-progress {
+  background-color: #2196F3;
+}
+
+.toast {
+    width: 100%;
+    background-color: #fc2a2a;
+    color: #fff;
+    text-align: left;
+    border-radius: 2px;
+    padding: 16px;
+    position: fixed;
+    z-index: 9999;
+    top: 0;
+    left: 0;
+    right: 0;
+    display: flex;
+    flex-direction: column;
+    white-space: pre-wrap;
+    font-family: monospace;
+}
+
+.toast-header {
+    display: flex;
+    justify-content: space-between;
+    align-items: center;
+    width: 100%;
+}
+
+.toast-error-message {
+    flex-grow: 1;
+}
+
+.toast-header-buttons {
+    display: flex;
+    align-items: center;
+    gap: 16px;
+    margin-left: 24px;
+}
+
+.toast-expand-button {
+    background: none;
+    border: none;
+    color: white;
+    padding: 4px;
+    cursor: pointer;
+    font-size: 1em;
+}
+
+.toast-close-button {
+    background: none;
+    border: none;
+    color: white;
+    padding: 4px;
+    cursor: pointer;
+    font-size: 1.2em;
+    line-height: 1;
+}
+
+.toast-expand-button:hover,
+.toast-close-button:hover {
+    opacity: 0.8;
+}
+
+.toast-content {
+    margin-top: 10px;
+    padding: 10px;
+    background-color: rgba(0, 0, 0, 0.2);
+    border-radius: 4px;
+}
+
+.hljs {
+  width: 100%;
+  position: relative;
+  border-radius: 10px;
+  /* wrap code blocks */
+  white-space: pre-wrap;
+}
+/* put clipboard button in the top right corner of the code block */
+.clipboard-button {
+  position: absolute;
+  top: 0;
+  right: 0;
+  padding: 0.5rem;
+  margin: 0;
+  outline: none;
+  border: none;
+  background-color: var(--secondary-color);
+  color: var(--foreground-color);
+  border-radius: 0 0 0 10px;
+  cursor: pointer;
+  transition: 0.2s;
+}
+.clipboard-button:hover {
+  background-color: var(--secondary-color);
+  padding: 0.75rem;
+}
+
+.input-container {
+  position: absolute;
+  bottom: 0;
+
+  /* linear gradient from background-color to transparent on the top */
+  background: linear-gradient(
+    0deg,
+    var(--primary-bg-color) 55%,
+    transparent 100%
+  );
+
+  width: 100%;
+  max-width: 1200px;
+  display: flex;
+  flex-direction: column;
+  justify-content: center;
+  align-items: center;
+  z-index: 999;
+}
+
+.input-performance {
+  margin-top: 4rem;
+
+  display: flex;
+  flex-direction: row;
+  gap: 1rem;
+}
+
+.input-performance-point {
+  display: flex;
+  flex-direction: row;
+  place-items: center;
+  gap: 0.5rem;
+}
+.input-performance-point > p {
+  height: 1rem;
+  line-height: normal;
+}
+
+.input {
+  width: 90%;
+  min-height: 3rem;
+  flex-shrink: 0;
+
+  display: flex;
+  flex-direction: row;
+  justify-content: center;
+  gap: 0.5rem;
+
+  align-items: flex-end;
+  margin-bottom: 2rem;
+}
+
+.input-form {
+  width: 100%;
+  padding: 1rem;
+  min-height: 3rem;
+  max-height: 8rem;
+
+  background-color: var(--secondary-color);
+  color: var(--foreground-color);
+  border-radius: 10px;
+  border: none;
+  resize: none;
+  outline: none;
+}
+
+.input-button {
+  height: 3rem;
+  width: 4rem;
+
+  background-color: var(--primary-color);
+  color: var(--secondary-color);
+  border-radius: 10px;
+  padding: 0.5rem;
+  cursor: pointer;
+}
+.input-button:hover {
+  background-color: var(--secondary-color-transparent);
+}
+.input-button:disabled {
+  background-color: var(--secondary-color);
+  cursor: not-allowed;
+}
+
+/* wrap text */
+p {
+  white-space: pre-wrap;
+}
+
+/* fonts */
+.megrim-regular {
+  font-family: "Megrim", system-ui;
+  font-weight: 400;
+  font-style: normal;
+}
+
+.monospace {
+  font-family: monospace;
+}
+
+.model-selector {
+  display: flex;
+  justify-content: center;
+  padding: 20px 0;
+}
+.model-selector select {
+  padding: 10px 20px;
+  font-size: 16px;
+  border: 1px solid #ccc;
+  border-radius: 5px;
+  background-color: #f8f8f8;
+  cursor: pointer;
+}
+.model-selector select:focus {
+  outline: none;
+  border-color: #007bff;
+  box-shadow: 0 0 0 2px rgba(0,123,255,.25);
+}
+
+/* Image upload button styles */
+.image-input-button {
+  background-color: var(--secondary-color);
+  color: var(--foreground-color);
+  border: none;
+  border-radius: 50%;
+  width: 40px;
+  height: 40px;
+  font-size: 18px;
+  cursor: pointer;
+  transition: all 0.3s ease;
+  display: flex;
+  align-items: center;
+  justify-content: center;
+  margin-right: 10px;
+}
+
+.image-input-button:hover {
+  background-color: var(--secondary-color-transparent);
+  transform: scale(1.1);
+}
+
+.image-input-button:focus {
+  outline: none;
+  box-shadow: 0 0 0 3px rgba(var(--secondary-color-rgb), 0.5);
+}
+
+.image-input-button i {
+  transition: all 0.3s ease;
+}
+
+.image-input-button:hover i {
+  transform: scale(1.2);
+}
+
+/* Hidden file input styles */
+#image-upload {
+  display: none;
+}
+
+.image-preview-container {
+  position: relative;
+  display: inline-block;
+  margin-right: 10px;
+}
+
+.image-preview {
+  max-width: 100px;
+  max-height: 100px;
+  object-fit: cover;
+  border-radius: 5px;
+}
+
+.remove-image-button {
+  position: absolute;
+  top: -5px;
+  right: -5px;
+  background-color: rgba(255, 255, 255, 0.8);
+  border: none;
+  border-radius: 50%;
+  padding: 2px 5px;
+  cursor: pointer;
+}
+
+.message > p > img {
+  max-width: 100%;
+  max-height: 100%;
+  object-fit: contain;
+}
+
+.clear-history-button {
+  background-color: var(--red-color);
+  color: white;
+  padding: 10px 20px;
+  border-radius: 5px;
+  display: flex;
+  align-items: center;
+  gap: 8px;
+  transition: all 0.3s ease;
+  margin: 1rem auto;
+  border: none;
+  cursor: pointer;
+}
+
+.clear-history-button:hover {
+  opacity: 0.8;
+  transform: scale(1.05);
+}
+
+.clear-history-button i {
+  font-size: 14px;
+}

+ 255 - 0
build/lib/exo/tinychat/index.html

@@ -0,0 +1,255 @@
+<!DOCTYPE html>
+
+<head>
+<title>tinychat</title>
+<meta content="width=device-width, initial-scale=1" name="viewport"/>
+<link href="favicon.svg" rel="icon" type="image/svg+xml"/>
+<script defer="" src="/static/cdn.jsdelivr.net/npm/@alpine-collective/toolkit@1.0.2/dist/cdn.min.js"></script>
+<script defer="" src="/static/cdn.jsdelivr.net/npm/@alpinejs/intersect@3.x.x/dist/cdn.min.js"></script>
+<script defer="" src="/static/cdn.jsdelivr.net/npm/@alpinejs/focus@3.x.x/dist/cdn.min.js"></script>
+<script defer="" src="/static/unpkg.com/@marcreichel/alpine-autosize@1.3.x/dist/alpine-autosize.min.js"></script>
+<script defer="" src="/static/unpkg.com/alpinejs@3.x.x/dist/cdn.min.js"></script>
+<script src="/static/unpkg.com/dompurify@3.1.5/dist/purify.min.js"></script>
+<script src="/static/unpkg.com/marked@13.0.0/marked.min.js"></script>
+<script src="/static/unpkg.com/marked-highlight@2.1.2/lib/index.umd.js"></script>
+<script src="/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/highlight.min.js"></script>
+<script src="/index.js"></script>
+<link href="/static/fonts.googleapis.com" rel="preconnect"/>
+<link crossorigin="" href="/static/fonts.gstatic.com" rel="preconnect"/>
+<link href="/static/fonts.googleapis.com/css2" rel="stylesheet"/>
+<link href="/static/cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css" rel="stylesheet"/>
+<link crossorigin="anonymous" href="/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css" integrity="sha512-SnH5WK+bZxgPHs44uWIX+LLJAJ9/2PkPKZ5QiAj6Ta86w+fsb2TkcmfRyVX3pBnMFcV7oQPJkl9QevSCWr3W6A==" referrerpolicy="no-referrer" rel="stylesheet">
+<link href="/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css" rel="stylesheet"/>
+<link href="/index.css" rel="stylesheet"/>
+<link href="/common.css" rel="stylesheet"/>
+</head>
+<body>
+<main x-data="state" x-init="console.log(endpoint)">
+     <!-- Error Toast -->
+    <div x-show="errorMessage" x-transition.opacity class="toast">
+        <div class="toast-header">
+            <span class="toast-error-message" x-text="errorMessage.basic"></span>
+            <div class="toast-header-buttons">
+                <button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }" 
+                        class="toast-expand-button" 
+                        x-show="errorMessage.stack">
+                    <span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
+                </button>
+                <button @click="errorMessage = null; errorExpanded = false;" class="toast-close-button">
+                    <i class="fas fa-times"></i>
+                </button>
+            </div>
+        </div>
+        <div class="toast-content" x-show="errorExpanded" x-transition>
+            <span x-text="errorMessage.stack"></span>
+        </div>
+    </div>
+<div class="model-selector">
+  <select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" x-init="await populateSelector()" class='model-select'>
+  </select>
+</div>
+<div @popstate.window="
+      if (home === 2) {
+        home = -1;
+        cstate = { time: null, messages: [], selectedModel: 'llama-3.1-8b' };
+        time_till_first = 0;
+        tokens_per_second = 0;
+        total_tokens = 0;
+      }
+    " class="home centered" x-effect="
+      $refs.inputForm.focus();
+      if (home === 1) setTimeout(() =&gt; home = 2, 100);
+      if (home === -1) setTimeout(() =&gt; home = 0, 100);
+    " x-show="home === 0" x-transition="">
+<h1 class="title megrim-regular">tinychat</h1>
+<template x-if="histories.length">
+  <button 
+    @click="if(confirm('Are you sure you want to clear all history?')) clearAllHistory();" 
+    class="clear-history-button">
+    <i class="fas fa-trash"></i> Clear All History
+  </button>
+</template>
+<div class="histories-container-container">
+<template x-if="histories.length">
+<div class="histories-start"></div>
+</template>
+<div class="histories-container" x-intersect="
+          $el.scrollTo({ top: 0, behavior: 'smooth' });
+        ">
+<template x-for="_state in histories.toSorted((a, b) =&gt; b.time - a.time)">
+<div @click="
+            cstate = _state;
+            if (cstate) cstate.selectedModel = document.querySelector('.model-selector select').value
+            // updateTotalTokens(cstate.messages);
+            home = 1;
+            // ensure that going back in history will go back to home
+            window.history.pushState({}, '', '/');
+          " @touchend="
+            if (Math.abs($event.changedTouches[0].clientX - otx) &gt; trigger) removeHistory(_state);
+            $el.style.setProperty('--tx', 0);
+            $el.style.setProperty('--opacity', 1);
+          " @touchmove="
+            $el.style.setProperty('--tx', $event.changedTouches[0].clientX - otx);
+            $el.style.setProperty('--opacity', 1 - (Math.abs($event.changedTouches[0].clientX - otx) / trigger));
+          " @touchstart="
+            otx = $event.changedTouches[0].clientX;
+          " class="history" x-data="{ otx: 0, trigger: 75 }">
+<h3 x-text="new Date(_state.time).toLocaleString()"></h3>
+<p x-text="$truncate(_state.messages[0].content, 80)"></p>
+<!-- delete button -->
+<button @click.stop="removeHistory(_state);" class="history-delete-button">
+<i class="fas fa-trash"></i>
+</button>
+</div>
+</template>
+</div>
+<template x-if="histories.length">
+<div class="histories-end"></div>
+</template>
+</div>
+</div>
+<div class="messages" x-init="
+      $watch('cstate', value =&gt; {
+        $el.innerHTML = '';
+        value.messages.forEach(({ role, content }) =&gt; {
+          const div = document.createElement('div');
+          div.className = `message message-role-${role}`;
+          try {
+              if (content.includes('![Generated Image]')) {
+                const imageUrl = content.match(/\((.*?)\)/)[1];
+                const img = document.createElement('img');
+                img.src = imageUrl;
+                img.alt = 'Generated Image';
+                img.onclick = async () => {
+                  try {
+                    const response = await fetch(img.src);
+                    const blob = await response.blob();
+                    const file = new File([blob], 'image.png', { type: 'image/png' });
+                    handleImageUpload({ target: { files: [file] } });
+                  } catch (error) {
+                    console.error('Error fetching image:', error);
+                  }
+                };
+                div.appendChild(img);
+              } else {
+                div.innerHTML = DOMPurify.sanitize(marked.parse(content));
+              }
+          } catch (e) {
+            console.log(content);
+            console.error(e);
+          }
+
+          // add a clipboard button to all code blocks
+          const codeBlocks = div.querySelectorAll('.hljs');
+          codeBlocks.forEach(codeBlock =&gt; {
+            const button = document.createElement('button');
+            button.className = 'clipboard-button';
+            button.innerHTML = '&lt;i class=\'fas fa-clipboard\'&gt;&lt;/i&gt;';
+            button.onclick = () =&gt; {
+              // navigator.clipboard.writeText(codeBlock.textContent);
+              const range = document.createRange();
+              range.setStartBefore(codeBlock);
+              range.setEndAfter(codeBlock);
+              window.getSelection()?.removeAllRanges();
+              window.getSelection()?.addRange(range);
+              document.execCommand('copy');
+              window.getSelection()?.removeAllRanges();
+
+              button.innerHTML = '&lt;i class=\'fas fa-check\'&gt;&lt;/i&gt;';
+              setTimeout(() =&gt; button.innerHTML = '&lt;i class=\'fas fa-clipboard\'&gt;&lt;/i&gt;', 1000);
+            };
+            codeBlock.appendChild(button);
+          });
+
+          $el.appendChild(div);
+        });
+
+        $el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
+      });
+    " x-intersect="
+      $el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
+    " x-ref="messages" x-show="home === 2" x-transition="">
+</div>
+
+<!-- Download Progress Section -->
+<template x-if="downloadProgress && downloadProgress.length > 0">
+  <div class="download-progress message message-role-assistant">
+    <h2>Download Progress</h2>
+    <br>
+    <template x-for="(progress, index) in downloadProgress" :key="index">
+      <div class="download-progress-node">
+        <br>
+        <h3 x-text="`Download ${index + 1}`"></h3>
+        <p><strong>Model:</strong> <span x-text="progress.repo_id + '@' + progress.repo_revision"></span></p>
+        <p><strong>Status:</strong> <span x-text="progress.status"></span></p>
+        <div class="progress-bar-container">
+          <div class="progress-bar" 
+               :class="progress.isComplete ? 'complete' : 'in-progress'"
+               :style="`width: ${progress.percentage}%;`">
+          </div>
+        </div>
+        <template x-if="!progress.isComplete">
+          <div>
+            <p><strong>Progress:</strong> <span x-text="`${progress.downloaded_bytes_display} / ${progress.total_bytes_display} (${progress.percentage}%)`"></span></p>
+            <p><strong>Speed:</strong> <span x-text="progress.overall_speed_display || 'N/A'"></span></p>
+            <p><strong>ETA:</strong> <span x-text="progress.overall_eta_display || 'N/A'"></span></p>
+          </div>
+        </template>
+      </div>
+    </template>
+  </div>
+</template>
+
+
+<div class="input-container">
+<div class="input-performance">
+<span class="input-performance-point">
+<p class="monospace" x-text="(time_till_first / 1000).toFixed(2)"></p>
+<p class="megrim-regular">SEC TO FIRST TOKEN</p>
+</span>
+<span class="input-performance-point">
+<p class="monospace" x-text="tokens_per_second.toFixed(1)"></p>
+<p class="megrim-regular">TOKENS/SEC</p>
+</span>
+<span class="input-performance-point">
+<p class="monospace" x-text="total_tokens"></p>
+<p class="megrim-regular">TOKENS</p>
+</span>
+</div>
+<div class="input">
+<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf' || cstate.selectedModel === 'stable-diffusion-2-1-base'">
+<i class="fas fa-image"></i>
+</button>
+<input @change="$data.handleImageUpload($event)" accept="image/*" id="image-upload" style="display: none;" type="file" x-ref="imageUpload"/>
+<div class="image-preview-container" x-show="imagePreview">
+<img :src="imagePreview" alt="Uploaded Image" class="image-preview"/>
+<button @click="imagePreview = null; imageUrl = null;" class="remove-image-button">
+<i class="fas fa-times"></i>
+</button>
+</div>
+<textarea :disabled="generating" :placeholder="generating ? 'Generating...' : 'Say something'" @input="
+            home = (home === 0) ? 1 : home
+            if (cstate.messages.length === 0 &amp;&amp; $el.value === '') home = -1;
+
+            if ($el.value !== '') {
+              const messages = [...cstate.messages];
+              messages.push({ role: 'user', content: $el.value });
+              // updateTotalTokens(messages);
+            } else {
+              if (cstate.messages.length === 0) total_tokens = 0;
+              // else updateTotalTokens(cstate.messages);
+            }
+          " @keydown.enter="await handleEnter($event)" @keydown.escape.window="$focus.focus($el)" autofocus="" class="input-form" id="input-form" rows="1" x-autosize="" x-effect="
+            console.log(generating);
+            if (!generating) $nextTick(() =&gt; {
+              $el.focus();
+              setTimeout(() =&gt; $refs.messages.scrollTo({ top: $refs.messages.scrollHeight, behavior: 'smooth' }), 100);
+            });
+          " x-ref="inputForm"></textarea>
+<button :disabled="generating" @click="await handleSend()" class="input-button">
+<i :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'" class="fas"></i>
+</button>
+</div>
+</div>
+</main>
+</body>
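
(Sketch only, not part of the commit: the chat UI above drives the OpenAI-compatible /v1/chat/completions streaming endpoint, as index.js below shows. The host/port, the use of the requests library, and the defensive "[DONE]" handling are assumptions for illustration; the model id matches the UI default.)

import json
import requests

def stream_chat(prompt: str, base_url: str = "http://localhost:52415/v1") -> str:
  # Adjust base_url to wherever the exo API is listening.
  resp = requests.post(
    f"{base_url}/chat/completions",
    json={"model": "llama-3.2-1b", "messages": [{"role": "user", "content": prompt}], "stream": True},
    stream=True,
  )
  resp.raise_for_status()
  out = []
  for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
      continue
    payload = line[len("data: "):]
    if payload.strip() == "[DONE]":  # defensive: some OpenAI-compatible servers send this sentinel
      break
    choice = json.loads(payload).get("choices", [{}])[0]
    if choice.get("finish_reason") == "stop":
      break
    out.append(choice.get("delta", {}).get("content") or "")
  return "".join(out)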

+ 687 - 0
build/lib/exo/tinychat/index.js

@@ -0,0 +1,687 @@
+document.addEventListener("alpine:init", () => {
+  Alpine.data("state", () => ({
+    // current state
+    cstate: {
+      time: null,
+      messages: [],
+      selectedModel: 'llama-3.2-1b',
+    },    
+
+    // historical state
+    histories: JSON.parse(localStorage.getItem("histories")) || [],
+
+    home: 0,
+    generating: false,
+    endpoint: `${window.location.origin}/v1`,
+    errorMessage: null,
+    errorExpanded: false,
+    errorTimeout: null,
+
+    // performance tracking
+    time_till_first: 0,
+    tokens_per_second: 0,
+    total_tokens: 0,
+
+    // image handling
+    imagePreview: null,
+
+    // download progress
+    downloadProgress: null,
+    downloadProgressInterval: null, // To keep track of the polling interval
+
+    // Pending message storage
+    pendingMessage: null,
+
+    init() {
+      // Clean up any pending messages
+      localStorage.removeItem("pendingMessage");
+
+      // Start polling for download progress
+      this.startDownloadProgressPolling();
+    },
+
+    removeHistory(cstate) {
+      const index = this.histories.findIndex((state) => {
+        return state.time === cstate.time;
+      });
+      if (index !== -1) {
+        this.histories.splice(index, 1);
+        localStorage.setItem("histories", JSON.stringify(this.histories));
+      }
+    },
+
+    clearAllHistory() {
+      this.histories = [];
+      localStorage.setItem("histories", JSON.stringify([]));
+    },
+
+    // Utility functions
+    formatBytes(bytes) {
+      if (bytes === 0) return '0 B';
+      const k = 1024;
+      const sizes = ['B', 'KB', 'MB', 'GB', 'TB'];
+      const i = Math.floor(Math.log(bytes) / Math.log(k));
+      return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i];
+    },
+
+    formatDuration(seconds) {
+      if (seconds === null || seconds === undefined || isNaN(seconds)) return '';
+      const h = Math.floor(seconds / 3600);
+      const m = Math.floor((seconds % 3600) / 60);
+      const s = Math.floor(seconds % 60);
+      if (h > 0) return `${h}h ${m}m ${s}s`;
+      if (m > 0) return `${m}m ${s}s`;
+      return `${s}s`;
+    },
+
+    async populateSelector() {
+      try {
+        const response = await fetch(`${window.location.origin}/modelpool`);
+        const responseText = await response.text(); // Get raw response text first
+        
+        if (!response.ok) {
+          throw new Error(`HTTP error! status: ${response.status}`);
+        }
+        
+        // Try to parse the response text
+        let responseJson;
+        try {
+          responseJson = JSON.parse(responseText);
+        } catch (parseError) {
+          console.error('Failed to parse JSON:', parseError);
+          throw new Error(`Invalid JSON response: ${responseText}`);
+        }
+
+        const sel = document.querySelector(".model-select");
+        if (!sel) {
+          throw new Error("Could not find model selector element");
+        }
+
+        // Clear the current options and add new ones
+        sel.innerHTML = '';
+          
+        const modelDict = responseJson["model pool"];
+        if (!modelDict) {
+          throw new Error("Response missing 'model pool' property");
+        }
+
+        Object.entries(modelDict).forEach(([key, value]) => {
+          const opt = document.createElement("option");
+          opt.value = key;
+          opt.textContent = value;
+          sel.appendChild(opt);
+        });
+
+        // Set initial value to the first model
+        const firstKey = Object.keys(modelDict)[0];
+        if (firstKey) {
+          sel.value = firstKey;
+          this.cstate.selectedModel = firstKey;
+        }
+      } catch (error) {
+        console.error("Error populating model selector:", error);
+        this.errorMessage = `Failed to load models: ${error.message}`;
+      }
+    },
+
+    async handleImageUpload(event) {
+      const file = event.target.files[0];
+      if (file) {
+        const reader = new FileReader();
+        reader.onload = (e) => {
+          this.imagePreview = e.target.result;
+          this.imageUrl = e.target.result; // Store the image URL
+          // Add image preview to the chat
+          this.cstate.messages.push({
+            role: "user",
+            content: `![Uploaded Image](${this.imagePreview})`,
+          });
+        };
+        reader.readAsDataURL(file);
+      }
+    },
+
+
+    async handleSend() {
+      try {
+        const el = document.getElementById("input-form");
+        const value = el.value.trim();
+        if (!value && !this.imagePreview) return;
+
+        if (this.generating) return;
+        this.generating = true;
+        if (this.home === 0) this.home = 1;
+
+        // ensure that going back in history will go back to home
+        window.history.pushState({}, "", "/");
+
+        // add message to list
+        if (value) {
+          this.cstate.messages.push({ role: "user", content: value });
+        }
+
+        // clear textarea
+        el.value = "";
+        el.style.height = "auto";
+        el.style.height = el.scrollHeight + "px";
+
+        localStorage.setItem("pendingMessage", value);
+        this.processMessage(value);
+      } catch (error) {
+        console.error('error', error);
+        const errorDetails = {
+            message: error.message || 'Unknown error',
+            stack: error.stack,
+            name: error.name || 'Error'
+        };
+        
+        this.errorMessage = {
+            basic: `${errorDetails.name}: ${errorDetails.message}`,
+            stack: errorDetails.stack
+        };
+
+        // Clear any existing timeout
+        if (this.errorTimeout) {
+            clearTimeout(this.errorTimeout);
+        }
+
+        // Only set the timeout if the error details aren't expanded
+        if (!this.errorExpanded) {
+            this.errorTimeout = setTimeout(() => {
+                this.errorMessage = null;
+                this.errorExpanded = false;
+            }, 30 * 1000);
+        }
+        this.generating = false;
+      }
+    },
+
+    async processMessage(value) {
+      try {
+        // reset performance tracking
+        const prefill_start = Date.now();
+        let start_time = 0;
+        let tokens = 0;
+        this.tokens_per_second = 0;
+
+        // prepare messages for API request
+        let apiMessages = this.cstate.messages.map(msg => {
+          if (msg.content.startsWith('![Uploaded Image]')) {
+            return {
+              role: "user",
+              content: [
+                {
+                  type: "image_url",
+                  image_url: {
+                    url: this.imageUrl
+                  }
+                },
+                {
+                  type: "text",
+                  text: value // Use the actual text the user typed
+                }
+              ]
+            };
+          } else {
+            return {
+              role: msg.role,
+              content: msg.content
+            };
+          }
+        });
+        
+        if (this.cstate.selectedModel === "stable-diffusion-2-1-base") {
+          // Send a request to the image generation endpoint
+          console.log(apiMessages[apiMessages.length - 1].content)
+          console.log(this.cstate.selectedModel)  
+          console.log(this.endpoint)
+          const response = await fetch(`${this.endpoint}/image/generations`, {
+            method: "POST",
+            headers: {
+              "Content-Type": "application/json",
+            },
+            body: JSON.stringify({
+              "model": 'stable-diffusion-2-1-base',
+              "prompt": apiMessages[apiMessages.length - 1].content,
+              "image_url": this.imageUrl
+            }),
+          });
+      
+          if (!response.ok) {
+            throw new Error("Failed to fetch");
+          }
+          const reader = response.body.getReader();
+          let done = false;
+          let gottenFirstChunk = false;
+  
+          while (!done) {
+            const { value, done: readerDone } = await reader.read();
+            done = readerDone;
+            const decoder = new TextDecoder();
+  
+            if (value) {
+              // Assume non-binary data (text) comes first
+              const chunk = decoder.decode(value, { stream: true });
+              const parsed = JSON.parse(chunk);
+              console.log(parsed)
+  
+              if (parsed.progress) {
+                if (!gottenFirstChunk) {
+                  this.cstate.messages.push({ role: "assistant", content: "" });
+                  gottenFirstChunk = true;
+                }
+                this.cstate.messages[this.cstate.messages.length - 1].content = parsed.progress;
+              }
+              else if (parsed.images) {
+                if (!gottenFirstChunk) {
+                  this.cstate.messages.push({ role: "assistant", content: "" });
+                  gottenFirstChunk = true;
+                }
+                const imageUrl = parsed.images[0].url;
+                console.log(imageUrl)
+                this.cstate.messages[this.cstate.messages.length - 1].content = `![Generated Image](${imageUrl}?t=${Date.now()})`;
+              }
+            }
+          }
+        }
+        
+        else{        
+          const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
+          if (containsImage) {
+            // Map all messages with string content to object with type text
+            apiMessages = apiMessages.map(msg => {
+              if (typeof msg.content === 'string') {
+                return {
+                  ...msg,
+                  content: [
+                    {
+                      type: "text",
+                      text: msg.content
+                    }
+                  ]
+                };
+              }
+              return msg;
+            });
+          }
+
+          console.log(apiMessages)
+          //start receiving server sent events
+          let gottenFirstChunk = false;
+          for await (
+            const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
+          ) {
+            if (!gottenFirstChunk) {
+              this.cstate.messages.push({ role: "assistant", content: "" });
+              gottenFirstChunk = true;
+            }
+
+            // add chunk to the last message
+            this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
+
+            // calculate performance tracking
+            tokens += 1;
+            this.total_tokens += 1;
+            if (start_time === 0) {
+              start_time = Date.now();
+              this.time_till_first = start_time - prefill_start;
+            } else {
+              const diff = Date.now() - start_time;
+              if (diff > 0) {
+                this.tokens_per_second = tokens / (diff / 1000);
+              }
+            }
+          }
+        }
+        // Clean the cstate before adding it to histories
+        const cleanedCstate = JSON.parse(JSON.stringify(this.cstate));
+        cleanedCstate.messages = cleanedCstate.messages.map(msg => {
+          if (Array.isArray(msg.content)) {
+            return {
+              ...msg,
+              content: msg.content.map(item =>
+                item.type === 'image_url' ? { type: 'image_url', image_url: { url: '[IMAGE_PLACEHOLDER]' } } : item
+              )
+            };
+          }
+          return msg;
+        });
+
+        // Update the state in histories or add it if it doesn't exist
+        const index = this.histories.findIndex((cstate) => cstate.time === cleanedCstate.time);
+        cleanedCstate.time = Date.now();
+        if (index !== -1) {
+          // Update the existing entry
+          this.histories[index] = cleanedCstate;
+        } else {
+          // Add a new entry
+          this.histories.push(cleanedCstate);
+        }
+        console.log(this.histories)
+        // update in local storage
+        try {
+          localStorage.setItem("histories", JSON.stringify(this.histories));
+        } catch (error) {
+          console.error("Failed to save histories to localStorage:", error);
+        }
+      } catch (error) {
+        console.error('error', error);
+        const errorDetails = {
+            message: error.message || 'Unknown error',
+            stack: error.stack,
+            name: error.name || 'Error'
+        };
+        
+        this.errorMessage = {
+            basic: `${errorDetails.name}: ${errorDetails.message}`,
+            stack: errorDetails.stack
+        };
+
+        // Clear any existing timeout
+        if (this.errorTimeout) {
+            clearTimeout(this.errorTimeout);
+        }
+
+        // Only set the timeout if the error details aren't expanded
+        if (!this.errorExpanded) {
+            this.errorTimeout = setTimeout(() => {
+                this.errorMessage = null;
+                this.errorExpanded = false;
+            }, 30 * 1000);
+        }
+      } finally {
+        this.generating = false;
+      }
+    },
+
+    async handleEnter(event) {
+      // if shift is not pressed
+      if (!event.shiftKey) {
+        event.preventDefault();
+        await this.handleSend();
+      }
+    },
+
+    updateTotalTokens(messages) {
+      fetch(`${this.endpoint}/chat/token/encode`, {
+        method: "POST",
+        headers: { "Content-Type": "application/json" },
+        body: JSON.stringify({ messages }),
+      }).then((response) => response.json()).then((data) => {
+        this.total_tokens = data.length;
+      }).catch(console.error);
+    },
+
+    async *openaiChatCompletion(model, messages) {
+      // stream response
+      console.log("model", model)
+      const response = await fetch(`${this.endpoint}/chat/completions`, {
+        method: "POST",
+        headers: {
+          "Content-Type": "application/json",
+        },
+        body: JSON.stringify({
+          "model": model,
+          "messages": messages,
+          "stream": true,
+        }),
+      });
+      if (!response.ok) {
+        const errorResBody = await response.json()
+        if (errorResBody?.detail) {
+          throw new Error(`Failed to fetch completions: ${errorResBody.detail}`);
+        } else {
+          throw new Error("Failed to fetch completions: Unknown error");
+        }
+      }
+
+      const reader = response.body.pipeThrough(new TextDecoderStream())
+        .pipeThrough(new EventSourceParserStream()).getReader();
+      while (true) {
+        const { done, value } = await reader.read();
+        if (done) {
+          break;
+        }
+        if (value.type === "event") {
+          const json = JSON.parse(value.data);
+          if (json.choices) {
+            const choice = json.choices[0];
+            if (choice.finish_reason === "stop") {
+              break;
+            }
+            yield choice.delta.content;
+          }
+        }
+      }
+    },
+
+    async fetchDownloadProgress() {
+      try {
+        const response = await fetch(`${this.endpoint}/download/progress`);
+        if (response.ok) {
+          const data = await response.json();
+          const progressArray = Object.values(data);
+          if (progressArray.length > 0) {
+            this.downloadProgress = progressArray.map(progress => {
+              // Check if download is complete
+              if (progress.status === "complete") {
+                return {
+                  ...progress,
+                  isComplete: true,
+                  percentage: 100
+                };
+              } else if (progress.status === "failed") {
+                return {
+                  ...progress,
+                  isComplete: false,
+                  errorMessage: "Download failed"
+                };
+              } else {
+                return {
+                  ...progress,
+                  isComplete: false,
+                  downloaded_bytes_display: this.formatBytes(progress.downloaded_bytes),
+                  total_bytes_display: this.formatBytes(progress.total_bytes),
+                  overall_speed_display: progress.overall_speed ? this.formatBytes(progress.overall_speed) + '/s' : '',
+                  overall_eta_display: progress.overall_eta ? this.formatDuration(progress.overall_eta) : '',
+                  percentage: ((progress.downloaded_bytes / progress.total_bytes) * 100).toFixed(2)
+                };
+              }
+            });
+            const allComplete = this.downloadProgress.every(progress => progress.isComplete);
+            if (allComplete) {
+              // Check for pendingMessage
+              const savedMessage = localStorage.getItem("pendingMessage");
+              if (savedMessage) {
+                // Clear pendingMessage
+                localStorage.removeItem("pendingMessage");
+                // Retry the saved message only if the previous attempt failed
+                if (this.lastErrorMessage) {
+                  await this.processMessage(savedMessage);
+                }
+              }
+              this.lastErrorMessage = null;
+              this.downloadProgress = null;
+            }
+          } else {
+            // No ongoing download
+            this.downloadProgress = null;
+          }
+        }
+      } catch (error) {
+        console.error("Error fetching download progress:", error);
+        this.downloadProgress = null;
+      }
+    },
+
+    startDownloadProgressPolling() {
+      if (this.downloadProgressInterval) {
+        // Already polling
+        return;
+      }
+      this.fetchDownloadProgress(); // Fetch immediately
+      this.downloadProgressInterval = setInterval(() => {
+        this.fetchDownloadProgress();
+      }, 1000); // Poll every second
+    },
+  }));
+});
+
+const { markedHighlight } = globalThis.markedHighlight;
+marked.use(markedHighlight({
+  langPrefix: "hljs language-",
+  highlight(code, lang, _info) {
+    const language = hljs.getLanguage(lang) ? lang : "plaintext";
+    return hljs.highlight(code, { language }).value;
+  },
+}));
+
+// **** eventsource-parser ****
+class EventSourceParserStream extends TransformStream {
+  constructor() {
+    let parser;
+
+    super({
+      start(controller) {
+        parser = createParser((event) => {
+          if (event.type === "event") {
+            controller.enqueue(event);
+          }
+        });
+      },
+
+      transform(chunk) {
+        parser.feed(chunk);
+      },
+    });
+  }
+}
+
+function createParser(onParse) {
+  let isFirstChunk;
+  let buffer;
+  let startingPosition;
+  let startingFieldLength;
+  let eventId;
+  let eventName;
+  let data;
+  reset();
+  return {
+    feed,
+    reset,
+  };
+  function reset() {
+    isFirstChunk = true;
+    buffer = "";
+    startingPosition = 0;
+    startingFieldLength = -1;
+    eventId = void 0;
+    eventName = void 0;
+    data = "";
+  }
+  function feed(chunk) {
+    buffer = buffer ? buffer + chunk : chunk;
+    if (isFirstChunk && hasBom(buffer)) {
+      buffer = buffer.slice(BOM.length);
+    }
+    isFirstChunk = false;
+    const length = buffer.length;
+    let position = 0;
+    let discardTrailingNewline = false;
+    while (position < length) {
+      if (discardTrailingNewline) {
+        if (buffer[position] === "\n") {
+          ++position;
+        }
+        discardTrailingNewline = false;
+      }
+      let lineLength = -1;
+      let fieldLength = startingFieldLength;
+      let character;
+      for (
+        let index = startingPosition;
+        lineLength < 0 && index < length;
+        ++index
+      ) {
+        character = buffer[index];
+        if (character === ":" && fieldLength < 0) {
+          fieldLength = index - position;
+        } else if (character === "\r") {
+          discardTrailingNewline = true;
+          lineLength = index - position;
+        } else if (character === "\n") {
+          lineLength = index - position;
+        }
+      }
+      if (lineLength < 0) {
+        startingPosition = length - position;
+        startingFieldLength = fieldLength;
+        break;
+      } else {
+        startingPosition = 0;
+        startingFieldLength = -1;
+      }
+      parseEventStreamLine(buffer, position, fieldLength, lineLength);
+      position += lineLength + 1;
+    }
+    if (position === length) {
+      buffer = "";
+    } else if (position > 0) {
+      buffer = buffer.slice(position);
+    }
+  }
+  function parseEventStreamLine(lineBuffer, index, fieldLength, lineLength) {
+    if (lineLength === 0) {
+      if (data.length > 0) {
+        onParse({
+          type: "event",
+          id: eventId,
+          event: eventName || void 0,
+          data: data.slice(0, -1),
+          // remove trailing newline
+        });
+
+        data = "";
+        eventId = void 0;
+      }
+      eventName = void 0;
+      return;
+    }
+    const noValue = fieldLength < 0;
+    const field = lineBuffer.slice(
+      index,
+      index + (noValue ? lineLength : fieldLength),
+    );
+    let step = 0;
+    if (noValue) {
+      step = lineLength;
+    } else if (lineBuffer[index + fieldLength + 1] === " ") {
+      step = fieldLength + 2;
+    } else {
+      step = fieldLength + 1;
+    }
+    const position = index + step;
+    const valueLength = lineLength - step;
+    const value = lineBuffer.slice(position, position + valueLength).toString();
+    if (field === "data") {
+      data += value ? "".concat(value, "\n") : "\n";
+    } else if (field === "event") {
+      eventName = value;
+    } else if (field === "id" && !value.includes("\0")) {
+      eventId = value;
+    } else if (field === "retry") {
+      const retry = parseInt(value, 10);
+      if (!Number.isNaN(retry)) {
+        onParse({
+          type: "reconnect-interval",
+          value: retry,
+        });
+      }
+    }
+  }
+}
+
+const BOM = [239, 187, 191];
+function hasBom(buffer) {
+  return BOM.every((charCode, index) => buffer.charCodeAt(index) === charCode);
+}
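
A note for anyone driving these endpoints from outside the browser: the sketch below exercises the same streaming flow from Python. It is a minimal illustration, not the project's client: the base URL and model id are assumptions (BASE_URL stands in for this.endpoint above, so match it to your deployment), and the [DONE] check is purely defensive. The request body (model, messages, stream) and the handling of choices[0].delta.content and finish_reason === "stop" mirror the JavaScript above, as does the /download/progress shape used by fetchDownloadProgress.

# Python sketch of the streaming chat flow used by the frontend above.
# Assumptions: BASE_URL and MODEL are placeholders for your deployment; the
# server emits OpenAI-style SSE lines ("data: {...}") carrying
# choices[0].delta.content and finish_reason, as parsed by the JS.
import json
import requests

BASE_URL = "http://localhost:52415"  # assumed; corresponds to this.endpoint above
MODEL = "llama-3.2-1b"               # assumed model id

def stream_chat(messages):
    resp = requests.post(
        f"{BASE_URL}/chat/completions",
        json={"model": MODEL, "messages": messages, "stream": True},
        stream=True,
    )
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload.strip() == "[DONE]":  # defensive: some servers send this sentinel
            break
        choice = json.loads(payload)["choices"][0]
        if choice.get("finish_reason") == "stop":
            break
        yield choice.get("delta", {}).get("content") or ""

def poll_download_progress():
    # Mirrors fetchDownloadProgress above: an empty object means no active download.
    data = requests.get(f"{BASE_URL}/download/progress").json()
    for progress in data.values():
        total = progress.get("total_bytes") or 0
        pct = 100.0 * progress.get("downloaded_bytes", 0) / total if total else 0.0
        print(progress.get("status"), f"{pct:.2f}%")

if __name__ == "__main__":
    for token in stream_chat([{"role": "user", "content": "Hello!"}]):
        print(token, end="", flush=True)
    print()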

The file diff is too large, so it has been suppressed
+ 0 - 0
build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/@alpine-collective/toolkit@1.0.2/dist/cdn.min.js


The file diff is too large, so it has been suppressed
+ 0 - 0
build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/@alpinejs/focus@3.x.x/dist/cdn.min.js


+ 1 - 0
build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/@alpinejs/intersect@3.x.x/dist/cdn.min.js

@@ -0,0 +1 @@
+(()=>{function o(e){e.directive("intersect",e.skipDuringClone((t,{value:i,expression:l,modifiers:n},{evaluateLater:r,cleanup:c})=>{let s=r(l),a={rootMargin:x(n),threshold:f(n)},u=new IntersectionObserver(d=>{d.forEach(h=>{h.isIntersecting!==(i==="leave")&&(s(),n.includes("once")&&u.disconnect())})},a);u.observe(t),c(()=>{u.disconnect()})}))}function f(e){if(e.includes("full"))return .99;if(e.includes("half"))return .5;if(!e.includes("threshold"))return 0;let t=e[e.indexOf("threshold")+1];return t==="100"?1:t==="0"?0:Number(`.${t}`)}function p(e){let t=e.match(/^(-?[0-9]+)(px|%)?$/);return t?t[1]+(t[2]||"px"):void 0}function x(e){let t="margin",i="0px 0px 0px 0px",l=e.indexOf(t);if(l===-1)return i;let n=[];for(let r=1;r<5;r++)n.push(p(e[l+r]||""));return n=n.filter(r=>r!==void 0),n.length?n.join(" ").trim():i}document.addEventListener("alpine:init",()=>{window.Alpine.plugin(o)});})();

+ 11 - 0
build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css

@@ -0,0 +1,11 @@
+/*!
+Pure v3.0.0
+Copyright 2013 Yahoo!
+Licensed under the BSD License.
+https://github.com/pure-css/pure/blob/master/LICENSE
+*/
+/*!
+normalize.css v | MIT License | https://necolas.github.io/normalize.css/
+Copyright (c) Nicolas Gallagher and Jonathan Neal
+*/
+/*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */html{line-height:1.15;-webkit-text-size-adjust:100%}body{margin:0}main{display:block}h1{font-size:2em;margin:.67em 0}hr{box-sizing:content-box;height:0;overflow:visible}pre{font-family:monospace,monospace;font-size:1em}a{background-color:transparent}abbr[title]{border-bottom:none;text-decoration:underline;-webkit-text-decoration:underline dotted;text-decoration:underline dotted}b,strong{font-weight:bolder}code,kbd,samp{font-family:monospace,monospace;font-size:1em}small{font-size:80%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sub{bottom:-.25em}sup{top:-.5em}img{border-style:none}button,input,optgroup,select,textarea{font-family:inherit;font-size:100%;line-height:1.15;margin:0}button,input{overflow:visible}button,select{text-transform:none}[type=button],[type=reset],[type=submit],button{-webkit-appearance:button}[type=button]::-moz-focus-inner,[type=reset]::-moz-focus-inner,[type=submit]::-moz-focus-inner,button::-moz-focus-inner{border-style:none;padding:0}[type=button]:-moz-focusring,[type=reset]:-moz-focusring,[type=submit]:-moz-focusring,button:-moz-focusring{outline:1px dotted ButtonText}fieldset{padding:.35em .75em .625em}legend{box-sizing:border-box;color:inherit;display:table;max-width:100%;padding:0;white-space:normal}progress{vertical-align:baseline}textarea{overflow:auto}[type=checkbox],[type=radio]{box-sizing:border-box;padding:0}[type=number]::-webkit-inner-spin-button,[type=number]::-webkit-outer-spin-button{height:auto}[type=search]{-webkit-appearance:textfield;outline-offset:-2px}[type=search]::-webkit-search-decoration{-webkit-appearance:none}::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}details{display:block}summary{display:list-item}template{display:none}[hidden]{display:none}html{font-family:sans-serif}.hidden,[hidden]{display:none!important}.pure-img{max-width:100%;height:auto;display:block}

The file diff is too large, so it has been suppressed
+ 5 - 0
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css


BIN
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.ttf


BIN
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.woff2


BIN
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.ttf


BIN
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.woff2


BIN
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.ttf


BIN
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.woff2


BIN
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.ttf


BIN
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.woff2


+ 7 - 0
build/lib/exo/tinychat/static/fonts.googleapis.com/css2

@@ -0,0 +1,7 @@
+@font-face {
+  font-family: 'Megrim';
+  font-style: normal;
+  font-weight: 400;
+  font-display: swap;
+  src: url(https://fonts.gstatic.com/s/megrim/v16/46kulbz5WjvLqJZlbQ.ttf) format('truetype');
+}

The file diff is too large, so it has been suppressed
+ 316 - 0
build/lib/exo/tinychat/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/highlight.min.js


+ 1 - 0
build/lib/exo/tinychat/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css

@@ -0,0 +1 @@
+pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}.hljs{background:#1e1e1e;color:#dcdcdc}.hljs-keyword,.hljs-literal,.hljs-name,.hljs-symbol{color:#569cd6}.hljs-link{color:#569cd6;text-decoration:underline}.hljs-built_in,.hljs-type{color:#4ec9b0}.hljs-class,.hljs-number{color:#b8d7a3}.hljs-meta .hljs-string,.hljs-string{color:#d69d85}.hljs-regexp,.hljs-template-tag{color:#9a5334}.hljs-formula,.hljs-function,.hljs-params,.hljs-subst,.hljs-title{color:#dcdcdc}.hljs-comment,.hljs-quote{color:#57a64a;font-style:italic}.hljs-doctag{color:#608b4e}.hljs-meta,.hljs-meta .hljs-keyword,.hljs-tag{color:#9b9b9b}.hljs-template-variable,.hljs-variable{color:#bd63c5}.hljs-attr,.hljs-attribute{color:#9cdcfe}.hljs-section{color:gold}.hljs-emphasis{font-style:italic}.hljs-strong{font-weight:700}.hljs-bullet,.hljs-selector-attr,.hljs-selector-class,.hljs-selector-id,.hljs-selector-pseudo,.hljs-selector-tag{color:#d7ba7d}.hljs-addition{background-color:#144212;display:inline-block;width:100%}.hljs-deletion{background-color:#600;display:inline-block;width:100%}

The file diff is too large, so it has been suppressed
+ 0 - 0
build/lib/exo/tinychat/static/unpkg.com/@marcreichel/alpine-autosize@1.3.x/dist/alpine-autosize.min.js


The file diff is too large, so it has been suppressed
+ 0 - 0
build/lib/exo/tinychat/static/unpkg.com/alpinejs@3.x.x/dist/cdn.min.js


The file diff is too large, so it has been suppressed
+ 1 - 0
build/lib/exo/tinychat/static/unpkg.com/dompurify@3.1.5/dist/purify.min.js


+ 97 - 0
build/lib/exo/tinychat/static/unpkg.com/marked-highlight@2.1.2/lib/index.umd.js

@@ -0,0 +1,97 @@
+(function (global, factory) {
+  typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) :
+  typeof define === 'function' && define.amd ? define(['exports'], factory) :
+  (global = typeof globalThis !== 'undefined' ? globalThis : global || self, factory(global.markedHighlight = {}));
+})(this, (function (exports) { 'use strict';
+
+  function markedHighlight(options) {
+    if (typeof options === 'function') {
+      options = {
+        highlight: options
+      };
+    }
+
+    if (!options || typeof options.highlight !== 'function') {
+      throw new Error('Must provide highlight function');
+    }
+
+    if (typeof options.langPrefix !== 'string') {
+      options.langPrefix = 'language-';
+    }
+
+    return {
+      async: !!options.async,
+      walkTokens(token) {
+        if (token.type !== 'code') {
+          return;
+        }
+
+        const lang = getLang(token.lang);
+
+        if (options.async) {
+          return Promise.resolve(options.highlight(token.text, lang, token.lang || '')).then(updateToken(token));
+        }
+
+        const code = options.highlight(token.text, lang, token.lang || '');
+        if (code instanceof Promise) {
+          throw new Error('markedHighlight is not set to async but the highlight function is async. Set the async option to true on markedHighlight to await the async highlight function.');
+        }
+        updateToken(token)(code);
+      },
+      useNewRenderer: true,
+      renderer: {
+        code({ text, lang, escaped }) {
+          const language = getLang(lang);
+          const classAttr = language
+            ? ` class="${options.langPrefix}${escape(language)}"`
+            : '';
+          text = text.replace(/\n$/, '');
+          return `<pre><code${classAttr}>${escaped ? text : escape(text, true)}\n</code></pre>`;
+        }
+      }
+    };
+  }
+
+  function getLang(lang) {
+    return (lang || '').match(/\S*/)[0];
+  }
+
+  function updateToken(token) {
+    return (code) => {
+      if (typeof code === 'string' && code !== token.text) {
+        token.escaped = true;
+        token.text = code;
+      }
+    };
+  }
+
+  // copied from marked helpers
+  const escapeTest = /[&<>"']/;
+  const escapeReplace = new RegExp(escapeTest.source, 'g');
+  const escapeTestNoEncode = /[<>"']|&(?!(#\d{1,7}|#[Xx][a-fA-F0-9]{1,6}|\w+);)/;
+  const escapeReplaceNoEncode = new RegExp(escapeTestNoEncode.source, 'g');
+  const escapeReplacements = {
+    '&': '&amp;',
+    '<': '&lt;',
+    '>': '&gt;',
+    '"': '&quot;',
+    "'": '&#39;'
+  };
+  const getEscapeReplacement = (ch) => escapeReplacements[ch];
+  function escape(html, encode) {
+    if (encode) {
+      if (escapeTest.test(html)) {
+        return html.replace(escapeReplace, getEscapeReplacement);
+      }
+    } else {
+      if (escapeTestNoEncode.test(html)) {
+        return html.replace(escapeReplaceNoEncode, getEscapeReplacement);
+      }
+    }
+
+    return html;
+  }
+
+  exports.markedHighlight = markedHighlight;
+
+}));

The file diff is too large, so it has been suppressed
+ 5 - 0
build/lib/exo/tinychat/static/unpkg.com/marked@13.0.0/marked.min.js


+ 93 - 0
build/lib/exo/tinychat/update_deps.py

@@ -0,0 +1,93 @@
+import os
+import requests
+from bs4 import BeautifulSoup
+from urllib.parse import urljoin, urlparse
+import re
+
+
+def download_file(url, local_path):
+  response = requests.get(url)
+  if response.status_code == 200:
+    os.makedirs(os.path.dirname(local_path), exist_ok=True)
+    with open(local_path, 'wb') as f:
+      f.write(response.content)
+    print(f"Downloaded: {local_path}")
+  else:
+    print(response.status_code)
+    print(f"Failed to download: {url}")
+
+
+def update_html(html_content, base_url):
+  soup = BeautifulSoup(html_content, 'html.parser')
+
+  for tag in soup.find_all(['script', 'link']):
+    if tag.has_attr('src'):
+      url = tag['src']
+    elif tag.has_attr('href'):
+      url = tag['href']
+    else:
+      continue
+
+    if url.startswith(('http://', 'https://')):
+      full_url = url
+    else:
+      full_url = urljoin(base_url, url)
+
+    parsed_url = urlparse(full_url)
+    local_path = os.path.join('static', parsed_url.netloc, parsed_url.path.lstrip('/'))
+
+    download_file(full_url, local_path)
+
+    relative_path = os.path.relpath(local_path, '.')
+    if tag.name == 'script':
+      tag['src'] = "/" + relative_path
+    elif tag.name == 'link':
+      tag['href'] = "/" + relative_path
+
+  return str(soup)
+
+
+# Read the HTML file
+with open('./index.html', 'r') as f:
+  html_content = f.read()
+
+# Update HTML and download files
+# updated_html = update_html(html_content, 'https://example.com')
+
+# # Write the updated HTML
+# with open('./index.html', 'w') as f:
+#     f.write(updated_html)
+
+print("HTML file updated with local paths.")
+
+# Download Font Awesome CSS and font files
+base_url = "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/"
+css_url = urljoin(base_url, "css/all.min.css")
+output_dir = "static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2"
+
+# Download CSS file
+css_output_path = os.path.join(output_dir, "css", "all.min.css")
+download_file(css_url, css_output_path)
+
+# Parse CSS file for font URLs
+with open(css_output_path, 'r', encoding='utf-8') as f:
+  css_content = f.read()
+
+# Extract font URLs from the CSS content
+font_urls = re.findall(r'url\((.*?\.(?:woff2|ttf))\)', css_content)
+
+print(f"Found {len(font_urls)} font URLs")
+
+# Download font files
+for font_url in font_urls:
+  font_url = font_url.strip('"\'')
+  if font_url.startswith('../'):
+    font_url = font_url[3:]
+
+  # Use base_url instead of urljoin to keep the version number
+  full_url = base_url + font_url
+  relative_path = font_url
+  output_path = os.path.join(output_dir, relative_path)
+  download_file(full_url, output_path)
+
+print("Download complete!")

+ 0 - 0
build/lib/exo/topology/__init__.py


+ 217 - 0
build/lib/exo/topology/device_capabilities.py

@@ -0,0 +1,217 @@
+from typing import Any
+from pydantic import BaseModel
+from exo import DEBUG
+import subprocess
+import psutil
+
+TFLOPS = 1.00
+
+
+class DeviceFlops(BaseModel):
+  # units of TFLOPS
+  fp32: float
+  fp16: float
+  int8: float
+
+  def __str__(self):
+    return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS"
+
+  def to_dict(self):
+    return self.model_dump()
+
+
+class DeviceCapabilities(BaseModel):
+  model: str
+  chip: str
+  memory: int
+  flops: DeviceFlops
+
+  def __str__(self):
+    return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}"
+
+  def model_post_init(self, __context: Any) -> None:
+    if isinstance(self.flops, dict):
+      self.flops = DeviceFlops(**self.flops)
+
+  def to_dict(self):
+    return {"model": self.model, "chip": self.chip, "memory": self.memory, "flops": self.flops.to_dict()}
+
+
+UNKNOWN_DEVICE_CAPABILITIES = DeviceCapabilities(model="Unknown Model", chip="Unknown Chip", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
+
+CHIP_FLOPS = {
+  # Source: https://www.cpu-monkey.com
+  # Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative
+  ### M chips
+  "Apple M1": DeviceFlops(fp32=2.29*TFLOPS, fp16=4.58*TFLOPS, int8=9.16*TFLOPS),
+  "Apple M1 Pro": DeviceFlops(fp32=5.30*TFLOPS, fp16=10.60*TFLOPS, int8=21.20*TFLOPS),
+  "Apple M1 Max": DeviceFlops(fp32=10.60*TFLOPS, fp16=21.20*TFLOPS, int8=42.40*TFLOPS),
+  "Apple M1 Ultra": DeviceFlops(fp32=21.20*TFLOPS, fp16=42.40*TFLOPS, int8=84.80*TFLOPS),
+  "Apple M2": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
+  "Apple M2 Pro": DeviceFlops(fp32=5.68*TFLOPS, fp16=11.36*TFLOPS, int8=22.72*TFLOPS),
+  "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS),
+  "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS),
+  "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
+  "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS),
+  "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
+  "Apple M4": DeviceFlops(fp32=4.26*TFLOPS, fp16=8.52*TFLOPS, int8=17.04*TFLOPS),
+  "Apple M4 Pro": DeviceFlops(fp32=5.72*TFLOPS, fp16=11.44*TFLOPS, int8=22.88*TFLOPS),
+  "Apple M4 Max": DeviceFlops(fp32=18.03*TFLOPS, fp16=36.07*TFLOPS, int8=72.14*TFLOPS),
+  ### A chips
+  "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS),
+  "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS),
+  "Apple A15 Bionic": DeviceFlops(fp32=1.37*TFLOPS, fp16=2.74*TFLOPS, int8=5.48*TFLOPS),
+  "Apple A16 Bionic": DeviceFlops(fp32=1.79*TFLOPS, fp16=3.58*TFLOPS, int8=7.16*TFLOPS),
+  "Apple A17 Pro": DeviceFlops(fp32=2.15*TFLOPS, fp16=4.30*TFLOPS, int8=8.60*TFLOPS),
+  ### NVIDIA GPUs
+  # RTX 40 series
+  "NVIDIA GEFORCE RTX 4090": DeviceFlops(fp32=82.58*TFLOPS, fp16=165.16*TFLOPS, int8=330.32*TFLOPS),
+  "NVIDIA GEFORCE RTX 4080": DeviceFlops(fp32=48.74*TFLOPS, fp16=97.48*TFLOPS, int8=194.96*TFLOPS),
+  "NVIDIA GEFORCE RTX 4080 SUPER": DeviceFlops(fp32=52.0*TFLOPS, fp16=104.0*TFLOPS, int8=208.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 4070 TI SUPER": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 4070 TI": DeviceFlops(fp32=39.43*TFLOPS, fp16=78.86*TFLOPS, int8=157.72*TFLOPS),
+  "NVIDIA GEFORCE RTX 4070 SUPER": DeviceFlops(fp32=30.0*TFLOPS, fp16=60.0*TFLOPS, int8=120.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 4070": DeviceFlops(fp32=29.0*TFLOPS, fp16=58.0*TFLOPS, int8=116.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 4060 TI 16GB": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 4060 TI": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS),
+  # RTX 30 series
+  "NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11*TFLOPS, fp16=18.22*TFLOPS, int8=36.44*TFLOPS),
+  "NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0*TFLOPS, fp16=26.0*TFLOPS, int8=52.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 3060 TI": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS),
+  "NVIDIA GEFORCE RTX 3070": DeviceFlops(fp32=20.3*TFLOPS, fp16=40.6*TFLOPS, int8=81.2*TFLOPS),
+  "NVIDIA GEFORCE RTX 3070 TI": DeviceFlops(fp32=21.8*TFLOPS, fp16=43.6*TFLOPS, int8=87.2*TFLOPS),
+  "NVIDIA GEFORCE RTX 3080 (10 GB)": DeviceFlops(fp32=29.8*TFLOPS, fp16=59.6*TFLOPS, int8=119.2*TFLOPS),
+  "NVIDIA GEFORCE RTX 3080 (12 GB)": DeviceFlops(fp32=30.6*TFLOPS, fp16=61.2*TFLOPS, int8=122.4*TFLOPS),
+  "NVIDIA GEFORCE RTX 3080 TI": DeviceFlops(fp32=34.1*TFLOPS, fp16=68.2*TFLOPS, int8=136.4*TFLOPS),
+  "NVIDIA GEFORCE RTX 3090": DeviceFlops(fp32=35.6*TFLOPS, fp16=71.2*TFLOPS, int8=142.4*TFLOPS),
+  "NVIDIA GEFORCE RTX 3090 TI": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS),
+  # RTX 20 series
+  "NVIDIA GEFORCE RTX 2060": DeviceFlops(fp32=6.45*TFLOPS, fp16=12.9*TFLOPS, int8=25.8*TFLOPS),
+  "NVIDIA GEFORCE RTX 2060 SUPER": DeviceFlops(fp32=7.2*TFLOPS, fp16=14.4*TFLOPS, int8=28.8*TFLOPS),
+  "NVIDIA GEFORCE RTX 2070": DeviceFlops(fp32=7.46*TFLOPS, fp16=14.93*TFLOPS, int8=29.86*TFLOPS),
+  "NVIDIA GEFORCE RTX 2070 SUPER": DeviceFlops(fp32=9.06*TFLOPS, fp16=18.12*TFLOPS, int8=36.24*TFLOPS),
+  "NVIDIA GEFORCE RTX 2080": DeviceFlops(fp32=10.07*TFLOPS, fp16=20.14*TFLOPS, int8=40.28*TFLOPS),
+  "NVIDIA GEFORCE RTX 2080 TI": DeviceFlops(fp32=13.45*TFLOPS, fp16=26.9*TFLOPS, int8=40.28*TFLOPS),
+  "NVIDIA GEFORCE RTX 2080 SUPER": DeviceFlops(fp32=11.15*TFLOPS, fp16=22.30*TFLOPS, int8=44.60*TFLOPS),
+  "NVIDIA TITAN RTX": DeviceFlops(fp32=16.31*TFLOPS, fp16=32.62*TFLOPS, int8=65.24*TFLOPS),
+  # GTX 10 series
+  "NVIDIA GEFORCE GTX 1050 TI": DeviceFlops(fp32=2.0*TFLOPS, fp16=4.0*TFLOPS, int8=8.0*TFLOPS),
+  "NVIDIA GEFORCE GTX 1070": DeviceFlops(fp32=6.463*TFLOPS, fp16=0.101*TFLOPS, int8=25.852*TFLOPS),
+  "NVIDIA GEFORCE GTX 1080": DeviceFlops(fp32=8.873*TFLOPS, fp16=0.138*TFLOPS, int8=35.492*TFLOPS),
+  "NVIDIA GEFORCE GTX 1080 TI": DeviceFlops(fp32=11.34*TFLOPS, fp16=0.177*TFLOPS, int8=45.36*TFLOPS),
+  # GTX 16 series
+  "NVIDIA GeForce GTX 1660 TI": DeviceFlops(fp32=4.8*TFLOPS, fp16=9.6*TFLOPS, int8=19.2*TFLOPS),
+  # QUADRO RTX Ampere series
+  "NVIDIA RTX A2000": DeviceFlops(fp32=7.99*TFLOPS, fp16=7.99*TFLOPS, int8=31.91*TFLOPS),
+  "NVIDIA RTX A4000": DeviceFlops(fp32=19.17*TFLOPS, fp16=19.17*TFLOPS, int8=76.68*TFLOPS),
+  "NVIDIA RTX A4500": DeviceFlops(fp32=23.65*TFLOPS, fp16=23.65*TFLOPS, int8=94.6*TFLOPS),
+  "NVIDIA RTX A5000": DeviceFlops(fp32=27.8*TFLOPS, fp16=27.8*TFLOPS, int8=111.2*TFLOPS),
+  "NVIDIA RTX A6000": DeviceFlops(fp32=38.71*TFLOPS, fp16=38.71*TFLOPS, int8=154.84*TFLOPS),
+  # NVIDIA Ada Lovelace Architecture-Based
+  "NVIDIA RTX 4000 ADA GENERATION": DeviceFlops(fp32=26.7*TFLOPS, fp16=26.7*TFLOPS, int8=258.0*TFLOPS),
+  # Common Server GPUs
+  "NVIDIA A40 48GB PCIE": DeviceFlops(fp32=37.4*TFLOPS, fp16=149.7*TFLOPS, int8=299.3*TFLOPS),
+  "NVIDIA A100 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A800 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A100 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  # ... add more devices if needed ...
+  ### AMD GPUs
+  # RX 6000 series
+  "AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04*TFLOPS, fp16=46.08*TFLOPS, int8=92.16*TFLOPS),
+  "AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74*TFLOPS, fp16=41.48*TFLOPS, int8=82.96*TFLOPS),
+  "AMD Radeon RX 6800": DeviceFlops(fp32=16.17*TFLOPS, fp16=32.34*TFLOPS, int8=64.68*TFLOPS),
+  "AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21*TFLOPS, fp16=26.42*TFLOPS, int8=52.84*TFLOPS),
+  "AMD Radeon RX 6700": DeviceFlops(fp32=11.4*TFLOPS, fp16=22.8*TFLOPS, int8=45.6*TFLOPS),
+  "AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6*TFLOPS, fp16=21.2*TFLOPS, int8=42.4*TFLOPS),
+  "AMD Radeon RX 6600": DeviceFlops(fp32=8.93*TFLOPS, fp16=17.86*TFLOPS, int8=35.72*TFLOPS),
+  "AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77*TFLOPS, fp16=11.54*TFLOPS, int8=23.08*TFLOPS),
+  "AMD Radeon RX 6400": DeviceFlops(fp32=3.57*TFLOPS, fp16=7.14*TFLOPS, int8=14.28*TFLOPS),
+  # RX 7000 series
+  "AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4*TFLOPS, fp16=122.8*TFLOPS, int8=245.6*TFLOPS),
+  "AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4*TFLOPS, fp16=106.8*TFLOPS, int8=213.6*TFLOPS),
+  "AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6*TFLOPS, fp16=85.2*TFLOPS, int8=170.4*TFLOPS),
+  "AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2*TFLOPS, fp16=68.4*TFLOPS, int8=136.8*TFLOPS),
+  "AMD Radeon RX 7600": DeviceFlops(fp32=21.5*TFLOPS, fp16=43.0*TFLOPS, int8=86.0*TFLOPS),
+  "AMD Radeon RX 7500": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS),
+  ### Qualcomm embedded chips: TODO
+}
+CHIP_FLOPS.update({f"LAPTOP GPU {key}": value for key, value in CHIP_FLOPS.items()})
+CHIP_FLOPS.update({f"Laptop GPU {key}": value for key, value in CHIP_FLOPS.items()})
+CHIP_FLOPS.update({f"{key} LAPTOP GPU": value for key, value in CHIP_FLOPS.items()})
+CHIP_FLOPS.update({f"{key} Laptop GPU": value for key, value in CHIP_FLOPS.items()})
+
+
+def device_capabilities() -> DeviceCapabilities:
+  if psutil.MACOS:
+    return mac_device_capabilities()
+  elif psutil.LINUX:
+    return linux_device_capabilities()
+  else:
+    return DeviceCapabilities(
+      model="Unknown Device",
+      chip="Unknown Chip",
+      memory=psutil.virtual_memory().total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )
+
+
+def mac_device_capabilities() -> DeviceCapabilities:
+  # Fetch the model of the Mac using system_profiler
+  model = subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8")
+  model_line = next((line for line in model.split("\n") if "Model Name" in line), None)
+  model_id = model_line.split(": ")[1] if model_line else "Unknown Model"
+  chip_line = next((line for line in model.split("\n") if "Chip" in line), None)
+  chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
+  memory_line = next((line for line in model.split("\n") if "Memory" in line), None)
+  memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory"
+  memory_units = memory_str.split()
+  memory_value = int(memory_units[0])
+  if memory_units[1] == "GB":
+    memory = memory_value*1024
+  else:
+    memory = memory_value
+
+  # Assuming static values for other attributes for demonstration
+  return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0)))
+
+
+def linux_device_capabilities() -> DeviceCapabilities:
+  import psutil
+  from tinygrad import Device
+
+  if DEBUG >= 2: print(f"tinygrad {Device.DEFAULT=}")
+  if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT == "GPU":
+    import pynvml
+
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+    gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
+    gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
+    gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+
+    if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
+
+    return DeviceCapabilities(
+      model=f"Linux Box ({gpu_name})",
+      chip=gpu_name,
+      memory=gpu_memory_info.total // 2**20,
+      flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
+    )
+  elif Device.DEFAULT == "AMD":
+    # TODO AMD support
+    return DeviceCapabilities(
+      model="Linux Box (AMD)",
+      chip="Unknown AMD",
+      memory=psutil.virtual_memory().total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )
+  else:
+    return DeviceCapabilities(
+      model=f"Linux Box (Device: {Device.DEFAULT})",
+      chip=f"Unknown Chip (Device: {Device.DEFAULT})",
+      memory=psutil.virtual_memory().total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )
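
As a usage sketch for the module above (assuming the package from this diff is importable as exo.topology.device_capabilities), the detection and the FLOPS table can be exercised like this; the chip name is only an example key, and unknown chips fall back to zeroed DeviceFlops:

# Usage sketch; run on the host you want to inspect.
from exo.topology.device_capabilities import (
    CHIP_FLOPS,
    DeviceFlops,
    device_capabilities,
)

caps = device_capabilities()
print(caps)            # e.g. "Model: MacBook Pro. Chip: Apple M2 Max. Memory: 65536MB. Flops: ..."
print(caps.to_dict())  # same data as a plain dict

# CHIP_FLOPS lookups are exact string matches on the detected chip name.
print(CHIP_FLOPS.get("Apple M2 Max", DeviceFlops(fp32=0, fp16=0, int8=0)))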

Too many files were changed in this diff, so some files are not shown