瀏覽代碼

Remove build

Pranav Veldurthi 4 月之前
父節點
當前提交
3475be9e9e
共有 100 個文件被更改,包括 0 次插入9482 次删除
  1. 0 1
      build/lib/exo/__init__.py
  2. 0 1
      build/lib/exo/api/__init__.py
  3. 0 539
      build/lib/exo/api/chatgpt_api.py
  4. 0 1
      build/lib/exo/apputil/__init__.py
  5. 0 161
      build/lib/exo/apputil/anim.py
  6. 0 0
      build/lib/exo/download/__init__.py
  7. 0 61
      build/lib/exo/download/download_progress.py
  8. 0 0
      build/lib/exo/download/hf/__init__.py
  9. 0 447
      build/lib/exo/download/hf/hf_helpers.py
  10. 0 79
      build/lib/exo/download/hf/hf_shard_download.py
  11. 0 36
      build/lib/exo/download/shard_download.py
  12. 0 274
      build/lib/exo/helpers.py
  13. 0 0
      build/lib/exo/inference/__init__.py
  14. 0 58
      build/lib/exo/inference/debug_inference_engine.py
  15. 0 34
      build/lib/exo/inference/dummy_inference_engine.py
  16. 0 58
      build/lib/exo/inference/inference_engine.py
  17. 0 0
      build/lib/exo/inference/mlx/__init__.py
  18. 0 307
      build/lib/exo/inference/mlx/models/StableDiffusionPipeline.py
  19. 0 0
      build/lib/exo/inference/mlx/models/__init__.py
  20. 0 9
      build/lib/exo/inference/mlx/models/base.py
  21. 0 127
      build/lib/exo/inference/mlx/models/deepseek_v2.py
  22. 0 118
      build/lib/exo/inference/mlx/models/gemma2.py
  23. 0 125
      build/lib/exo/inference/mlx/models/llama.py
  24. 0 585
      build/lib/exo/inference/mlx/models/llava.py
  25. 0 128
      build/lib/exo/inference/mlx/models/qwen2.py
  26. 0 77
      build/lib/exo/inference/mlx/sharded_inference_engine.py
  27. 0 256
      build/lib/exo/inference/mlx/sharded_utils.py
  28. 0 45
      build/lib/exo/inference/mlx/stateful_model.py
  29. 0 40
      build/lib/exo/inference/mlx/test_sharded_llama.py
  30. 0 64
      build/lib/exo/inference/mlx/test_sharded_llava.py
  31. 0 52
      build/lib/exo/inference/mlx/test_sharded_model.py
  32. 0 39
      build/lib/exo/inference/shard.py
  33. 0 53
      build/lib/exo/inference/test_dummy_inference_engine.py
  34. 0 56
      build/lib/exo/inference/test_inference_engine.py
  35. 0 0
      build/lib/exo/inference/tinygrad/__init__.py
  36. 0 99
      build/lib/exo/inference/tinygrad/inference.py
  37. 0 0
      build/lib/exo/inference/tinygrad/models/__init__.py
  38. 0 282
      build/lib/exo/inference/tinygrad/models/llama.py
  39. 0 42
      build/lib/exo/inference/tinygrad/stateful_model.py
  40. 0 52
      build/lib/exo/inference/tinygrad/tinygrad_helpers.py
  41. 0 64
      build/lib/exo/inference/tokenizers.py
  42. 0 274
      build/lib/exo/main.py
  43. 0 151
      build/lib/exo/models.py
  44. 0 5
      build/lib/exo/networking/__init__.py
  45. 0 17
      build/lib/exo/networking/discovery.py
  46. 0 0
      build/lib/exo/networking/grpc/__init__.py
  47. 0 173
      build/lib/exo/networking/grpc/grpc_peer_handle.py
  48. 0 147
      build/lib/exo/networking/grpc/grpc_server.py
  49. 0 16
      build/lib/exo/networking/grpc/node_service_pb2.py
  50. 0 360
      build/lib/exo/networking/grpc/node_service_pb2_grpc.py
  51. 0 0
      build/lib/exo/networking/manual/__init__.py
  52. 0 71
      build/lib/exo/networking/manual/manual_discovery.py
  53. 0 31
      build/lib/exo/networking/manual/network_topology_config.py
  54. 0 103
      build/lib/exo/networking/manual/test_manual_discovery.py
  55. 0 49
      build/lib/exo/networking/manual/test_network_topology_config.py
  56. 0 56
      build/lib/exo/networking/peer_handle.py
  57. 0 11
      build/lib/exo/networking/server.py
  58. 0 0
      build/lib/exo/networking/tailscale/__init__.py
  59. 0 178
      build/lib/exo/networking/tailscale/tailscale_discovery.py
  60. 0 125
      build/lib/exo/networking/tailscale/tailscale_helpers.py
  61. 0 43
      build/lib/exo/networking/tailscale/test_tailscale_discovery.py
  62. 0 0
      build/lib/exo/networking/udp/__init__.py
  63. 0 77
      build/lib/exo/networking/udp/test_udp_discovery.py
  64. 0 215
      build/lib/exo/networking/udp/udp_discovery.py
  65. 0 4
      build/lib/exo/orchestration/__init__.py
  66. 0 47
      build/lib/exo/orchestration/node.py
  67. 0 488
      build/lib/exo/orchestration/standard_node.py
  68. 0 57
      build/lib/exo/orchestration/test_node.py
  69. 0 0
      build/lib/exo/stats/__init__.py
  70. 0 29
      build/lib/exo/stats/metrics.py
  71. 0 50
      build/lib/exo/test_callbacks.py
  72. 0 130
      build/lib/exo/tinychat/common.css
  73. 0 25
      build/lib/exo/tinychat/favicon.svg
  74. 0 484
      build/lib/exo/tinychat/index.css
  75. 0 255
      build/lib/exo/tinychat/index.html
  76. 0 687
      build/lib/exo/tinychat/index.js
  77. 0 0
      build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/@alpine-collective/toolkit@1.0.2/dist/cdn.min.js
  78. 0 0
      build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/@alpinejs/focus@3.x.x/dist/cdn.min.js
  79. 0 1
      build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/@alpinejs/intersect@3.x.x/dist/cdn.min.js
  80. 0 11
      build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css
  81. 0 5
      build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css
  82. 二進制
      build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.ttf
  83. 二進制
      build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.woff2
  84. 二進制
      build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.ttf
  85. 二進制
      build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.woff2
  86. 二進制
      build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.ttf
  87. 二進制
      build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.woff2
  88. 二進制
      build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.ttf
  89. 二進制
      build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.woff2
  90. 0 7
      build/lib/exo/tinychat/static/fonts.googleapis.com/css2
  91. 0 316
      build/lib/exo/tinychat/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/highlight.min.js
  92. 0 1
      build/lib/exo/tinychat/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css
  93. 0 0
      build/lib/exo/tinychat/static/unpkg.com/@marcreichel/alpine-autosize@1.3.x/dist/alpine-autosize.min.js
  94. 0 0
      build/lib/exo/tinychat/static/unpkg.com/alpinejs@3.x.x/dist/cdn.min.js
  95. 0 1
      build/lib/exo/tinychat/static/unpkg.com/dompurify@3.1.5/dist/purify.min.js
  96. 0 97
      build/lib/exo/tinychat/static/unpkg.com/marked-highlight@2.1.2/lib/index.umd.js
  97. 0 5
      build/lib/exo/tinychat/static/unpkg.com/marked@13.0.0/marked.min.js
  98. 0 93
      build/lib/exo/tinychat/update_deps.py
  99. 0 0
      build/lib/exo/topology/__init__.py
  100. 0 217
      build/lib/exo/topology/device_capabilities.py

+ 0 - 1
build/lib/exo/__init__.py

@@ -1 +0,0 @@
-from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

+ 0 - 1
build/lib/exo/api/__init__.py

@@ -1 +0,0 @@
-from exo.api.chatgpt_api import ChatGPTAPI as ChatGPTAPI

+ 0 - 539
build/lib/exo/api/chatgpt_api.py

@@ -1,539 +0,0 @@
-import uuid
-import time
-import asyncio
-import json
-from pathlib import Path
-from transformers import AutoTokenizer
-from typing import List, Literal, Union, Dict
-from aiohttp import web
-import aiohttp_cors
-import traceback
-import signal
-from exo import DEBUG, VERSION
-from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
-from exo.inference.tokenizers import resolve_tokenizer
-from exo.orchestration import Node
-from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
-from exo.apputil import create_animation_mp4
-from typing import Callable, Optional
-from PIL import Image
-import numpy as np
-import base64
-from io import BytesIO
-import mlx.core as mx
-import tempfile
-
-class Message:
-  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
-    self.role = role
-    self.content = content
-
-  def to_dict(self):
-    return {"role": self.role, "content": self.content}
-
-
-
-class ChatCompletionRequest:
-  def __init__(self, model: str, messages: List[Message], temperature: float):
-    self.model = model
-    self.messages = messages
-    self.temperature = temperature
-
-  def to_dict(self):
-    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
-
-
-def generate_completion(
-  chat_request: ChatCompletionRequest,
-  tokenizer,
-  prompt: str,
-  request_id: str,
-  tokens: List[int],
-  stream: bool,
-  finish_reason: Union[Literal["length", "stop"], None],
-  object_type: Literal["chat.completion", "text_completion"],
-) -> dict:
-  completion = {
-    "id": f"chatcmpl-{request_id}",
-    "object": object_type,
-    "created": int(time.time()),
-    "model": chat_request.model,
-    "system_fingerprint": f"exo_{VERSION}",
-    "choices": [{
-      "index": 0,
-      "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
-      "logprobs": None,
-      "finish_reason": finish_reason,
-    }],
-  }
-
-  if not stream:
-    completion["usage"] = {
-      "prompt_tokens": len(tokenizer.encode(prompt)),
-      "completion_tokens": len(tokens),
-      "total_tokens": len(tokenizer.encode(prompt)) + len(tokens),
-    }
-
-  choice = completion["choices"][0]
-  if object_type.startswith("chat.completion"):
-    key_name = "delta" if stream else "message"
-    choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
-  elif object_type == "text_completion":
-    choice["text"] = tokenizer.decode(tokens)
-  else:
-    ValueError(f"Unsupported response type: {object_type}")
-
-  return completion
-
-
-def remap_messages(messages: List[Message]) -> List[Message]:
-  remapped_messages = []
-  last_image = None
-  for message in messages:
-    if not isinstance(message.content, list):
-      remapped_messages.append(message)
-      continue
-
-    remapped_content = []
-    for content in message.content:
-      if isinstance(content, dict):
-        if content.get("type") in ["image_url", "image"]:
-          image_url = content.get("image_url", {}).get("url") or content.get("image")
-          if image_url:
-            last_image = {"type": "image", "image": image_url}
-            remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
-        else:
-          remapped_content.append(content)
-      else:
-        remapped_content.append(content)
-    remapped_messages.append(Message(role=message.role, content=remapped_content))
-
-  if last_image:
-    # Replace the last image placeholder with the actual image content
-    for message in reversed(remapped_messages):
-      for i, content in enumerate(message.content):
-        if isinstance(content, dict):
-          if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
-            message.content[i] = last_image
-            return remapped_messages
-
-  return remapped_messages
-
-
-def build_prompt(tokenizer, _messages: List[Message]):
-  messages = remap_messages(_messages)
-  prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
-  for message in messages:
-    if not isinstance(message.content, list):
-      continue
-
-  return prompt
-
-
-def parse_message(data: dict):
-  if "role" not in data or "content" not in data:
-    raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
-  return Message(data["role"], data["content"])
-
-
-def parse_chat_request(data: dict, default_model: str):
-  return ChatCompletionRequest(
-    data.get("model", default_model),
-    [parse_message(msg) for msg in data["messages"]],
-    data.get("temperature", 0.0),
-  )
-
-
-class PromptSession:
-  def __init__(self, request_id: str, timestamp: int, prompt: str):
-    self.request_id = request_id
-    self.timestamp = timestamp
-    self.prompt = prompt
-
-class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
-    self.node = node
-    self.inference_engine_classname = inference_engine_classname
-    self.response_timeout = response_timeout
-    self.on_chat_completion_request = on_chat_completion_request
-    self.app = web.Application(client_max_size=100*1024*1024)  # 100MB to support image upload
-    self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
-    self.prev_token_lens: Dict[str, int] = {}
-    self.stream_tasks: Dict[str, asyncio.Task] = {}
-    self.default_model = default_model or "llama-3.2-1b"
-
-    cors = aiohttp_cors.setup(self.app)
-    cors_options = aiohttp_cors.ResourceOptions(
-      allow_credentials=True,
-      expose_headers="*",
-      allow_headers="*",
-      allow_methods="*",
-    )
-    cors.add(self.app.router.add_get("/models", self.handle_get_models), {"*": cors_options})
-    cors.add(self.app.router.add_get("/v1/models", self.handle_get_models), {"*": cors_options})
-    cors.add(self.app.router.add_post("/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
-    cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
-    cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
-    cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
-    cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
-    cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
-    cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
-    cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
-    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
-    cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
-    cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
-
-      
-    if "__compiled__" not in globals():
-      self.static_dir = Path(__file__).parent.parent/"tinychat"
-      self.app.router.add_get("/", self.handle_root)
-      self.app.router.add_static("/", self.static_dir, name="static")
-      self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')
-
-    self.app.middlewares.append(self.timeout_middleware)
-    self.app.middlewares.append(self.log_request)
-
-  async def handle_quit(self, request):
-    if DEBUG>=1: print("Received quit signal")
-    response = web.json_response({"detail": "Quit signal received"}, status=200)
-    await response.prepare(request)
-    await response.write_eof()
-    await shutdown(signal.SIGINT, asyncio.get_event_loop(), self.node.server)
-
-  async def timeout_middleware(self, app, handler):
-    async def middleware(request):
-      try:
-        return await asyncio.wait_for(handler(request), timeout=self.response_timeout)
-      except asyncio.TimeoutError:
-        return web.json_response({"detail": "Request timed out"}, status=408)
-
-    return middleware
-
-  async def log_request(self, app, handler):
-    async def middleware(request):
-      if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
-      return await handler(request)
-
-    return middleware
-
-  async def handle_root(self, request):
-    return web.FileResponse(self.static_dir/"index.html")
-
-  async def handle_healthcheck(self, request):
-    return web.json_response({"status": "ok"})
-
-  async def handle_model_support(self, request):
-    return web.json_response({
-      "model pool": {
-        model_name: pretty_name.get(model_name, model_name)
-        for model_name in get_supported_models(self.node.topology_inference_engines_pool)
-      }
-    })
-
-  async def handle_get_models(self, request):
-    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
-
-  async def handle_post_chat_token_encode(self, request):
-    data = await request.json()
-    shard = build_base_shard(self.default_model, self.inference_engine_classname)
-    messages = [parse_message(msg) for msg in data.get("messages", [])]
-    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
-
-  async def handle_get_download_progress(self, request):
-    progress_data = {}
-    for node_id, progress_event in self.node.node_download_progress.items():
-      if isinstance(progress_event, RepoProgressEvent):
-        progress_data[node_id] = progress_event.to_dict()
-      else:
-        print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
-    return web.json_response(progress_data)
-
-  async def handle_post_chat_completions(self, request):
-    data = await request.json()
-    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
-    stream = data.get("stream", False)
-    chat_request = parse_chat_request(data, self.default_model)
-    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to default model
-      chat_request.model = self.default_model
-    if not chat_request.model or chat_request.model not in model_cards:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
-      chat_request.model = self.default_model
-    shard = build_base_shard(chat_request.model, self.inference_engine_classname)
-    if not shard:
-      supported_models = [model for model, info in model_cards.items() if self.inference_engine_classname in info.get("repo", {})]
-      return web.json_response(
-        {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
-        status=400,
-      )
-
-    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
-
-    prompt = build_prompt(tokenizer, chat_request.messages)
-    request_id = str(uuid.uuid4())
-    if self.on_chat_completion_request:
-      try:
-        self.on_chat_completion_request(request_id, chat_request, prompt)
-      except Exception as e:
-        if DEBUG >= 2: traceback.print_exc()
-    # request_id = None
-    # match = self.prompts.find_longest_prefix(prompt)
-    # if match and len(prompt) > len(match[1].prompt):
-    #     if DEBUG >= 2:
-    #       print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
-    #     request_id = match[1].request_id
-    #     self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
-    #     # remove the matching prefix from the prompt
-    #     prompt = prompt[len(match[1].prompt):]
-    # else:
-    #   request_id = str(uuid.uuid4())
-    #   self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
-
-    callback_id = f"chatgpt-api-wait-response-{request_id}"
-    callback = self.node.on_token.register(callback_id)
-
-    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
-
-    try:
-      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
-
-      if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
-
-      if stream:
-        response = web.StreamResponse(
-          status=200,
-          reason="OK",
-          headers={
-            "Content-Type": "text/event-stream",
-            "Cache-Control": "no-cache",
-          },
-        )
-        await response.prepare(request)
-
-        async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
-          prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
-          self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
-          new_tokens = tokens[prev_last_tokens_len:]
-          finish_reason = None
-          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
-                                                                                                                             AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
-          if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
-            new_tokens = new_tokens[:-1]
-            if is_finished:
-              finish_reason = "stop"
-          if is_finished and not finish_reason:
-            finish_reason = "length"
-
-          completion = generate_completion(
-            chat_request,
-            tokenizer,
-            prompt,
-            request_id,
-            new_tokens,
-            stream,
-            finish_reason,
-            "chat.completion",
-          )
-          if DEBUG >= 2: print(f"Streaming completion: {completion}")
-          try:
-            await response.write(f"data: {json.dumps(completion)}\n\n".encode())
-          except Exception as e:
-            if DEBUG >= 2: print(f"Error streaming completion: {e}")
-            if DEBUG >= 2: traceback.print_exc()
-
-        def on_result(_request_id: str, tokens: List[int], is_finished: bool):
-          if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))
-
-          return _request_id == request_id and is_finished
-
-        _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout)
-        if request_id in self.stream_tasks:  # in case there is still a stream task running, wait for it to complete
-          if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
-          try:
-            await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
-          except asyncio.TimeoutError:
-            print("WARNING: Stream task timed out. This should not happen.")
-        await response.write_eof()
-        return response
-      else:
-        _, tokens, _ = await callback.wait(
-          lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
-          timeout=self.response_timeout,
-        )
-
-        finish_reason = "length"
-        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
-        if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
-        if tokens[-1] == eos_token_id:
-          tokens = tokens[:-1]
-          finish_reason = "stop"
-
-        return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
-    except asyncio.TimeoutError:
-      return web.json_response({"detail": "Response generation timed out"}, status=408)
-    except Exception as e:
-      if DEBUG >= 2: traceback.print_exc()
-      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
-    finally:
-      deregistered_callback = self.node.on_token.deregister(callback_id)
-      if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
-
-  
-  async def handle_post_image_generations(self, request):
-    data = await request.json()
-
-    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
-    stream = data.get("stream", False)
-    model = data.get("model", "")
-    prompt = data.get("prompt", "")
-    image_url = data.get("image_url", "")
-    print(f"model: {model}, prompt: {prompt}, stream: {stream}")
-    shard = build_base_shard(model, self.inference_engine_classname)
-    print(f"shard: {shard}")
-    if not shard:
-        return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
-
-    request_id = str(uuid.uuid4())
-    callback_id = f"chatgpt-api-wait-response-{request_id}"
-    callback = self.node.on_token.register(callback_id)
-    try:
-      if image_url != "" and image_url != None:
-        img = self.base64_decode(image_url)
-      else:
-        img = None
-      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
-
-
-      response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
-      await response.prepare(request)
-
-      def get_progress_bar(current_step, total_steps, bar_length=50):
-        # Calculate the percentage of completion
-        percent = float(current_step) / total_steps
-        # Calculate the number of hashes to display
-        arrow = '-' * int(round(percent * bar_length) - 1) + '>'
-        spaces = ' ' * (bar_length - len(arrow))
-        
-        # Create the progress bar string
-        progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
-        return progress_bar
-
-      async def stream_image(_request_id: str, result, is_finished: bool):
-          if isinstance(result, list):
-              await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
-
-          elif isinstance(result, np.ndarray):
-            im = Image.fromarray(np.array(result))
-            images_folder = get_exo_images_dir()
-            # Save the image to a file
-            image_filename = f"{_request_id}.png"
-            image_path = images_folder / image_filename
-            im.save(image_path)
-            image_url = request.app.router['static_images'].url_for(filename=image_filename)
-            base_url = f"{request.scheme}://{request.host}"
-            # Construct the full URL correctly
-            full_image_url = base_url + str(image_url)
-            
-            await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
-            if is_finished:
-              await response.write_eof()
-              
-
-      stream_task = None
-      def on_result(_request_id: str, result, is_finished: bool):
-          nonlocal stream_task
-          stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
-          return _request_id == request_id and is_finished
-
-      await callback.wait(on_result, timeout=self.response_timeout*10)
-      
-      if stream_task:
-          # Wait for the stream task to complete before returning
-          await stream_task
-
-      return response
-
-    except Exception as e:
-        if DEBUG >= 2: traceback.print_exc()
-        return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
-  
-  async def handle_create_animation(self, request):
-    try:
-      data = await request.json()
-      replacement_image_path = data.get("replacement_image_path")
-      device_name = data.get("device_name", "Local Device")
-      prompt_text = data.get("prompt", "")
-
-      if DEBUG >= 2: print(f"Creating animation with params: replacement_image={replacement_image_path}, device={device_name}, prompt={prompt_text}")
-
-      if not replacement_image_path:
-        return web.json_response({"error": "replacement_image_path is required"}, status=400)
-
-      # Create temp directory if it doesn't exist
-      tmp_dir = Path(tempfile.gettempdir())/"exo_animations"
-      tmp_dir.mkdir(parents=True, exist_ok=True)
-
-      # Generate unique output filename in temp directory
-      output_filename = f"animation_{uuid.uuid4()}.mp4"
-      output_path = str(tmp_dir/output_filename)
-
-      if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
-
-      # Create the animation
-      create_animation_mp4(
-        replacement_image_path,
-        output_path,
-        device_name,
-        prompt_text
-      )
-
-      return web.json_response({
-        "status": "success",
-        "output_path": output_path
-      })
-
-    except Exception as e:
-      if DEBUG >= 2: traceback.print_exc()
-      return web.json_response({"error": str(e)}, status=500)
-
-  async def handle_post_download(self, request):
-    try:
-      data = await request.json()
-      model_name = data.get("model")
-      if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
-      if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
-      shard = build_base_shard(model_name, self.inference_engine_classname)
-      if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
-      asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
-
-      return web.json_response({
-        "status": "success",
-        "message": f"Download started for model: {model_name}"
-      })
-    except Exception as e:
-      if DEBUG >= 2: traceback.print_exc()
-      return web.json_response({"error": str(e)}, status=500)
-
-  async def run(self, host: str = "0.0.0.0", port: int = 52415):
-    runner = web.AppRunner(self.app)
-    await runner.setup()
-    site = web.TCPSite(runner, host, port)
-    await site.start()
-
-  def base64_decode(self, base64_string):
-    #decode and reshape image
-    if base64_string.startswith('data:image'):
-        base64_string = base64_string.split(',')[1]
-    image_data = base64.b64decode(base64_string)
-    img = Image.open(BytesIO(image_data))
-    W, H = (dim - dim % 64 for dim in (img.width, img.height))
-    if W != img.width or H != img.height:
-        print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
-        img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
-    img = mx.array(np.array(img))
-    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
-    img = img[None]
-    return img
-  

+ 0 - 1
build/lib/exo/apputil/__init__.py

@@ -1 +0,0 @@
-from exo.apputil.anim import create_animation_mp4

+ 0 - 161
build/lib/exo/apputil/anim.py

@@ -1,161 +0,0 @@
-from PIL import Image, ImageDraw, ImageFont, ImageFilter
-import os
-import numpy as np
-import cv2
-
-def draw_rounded_rectangle(draw, coords, radius, fill):
-  left, top, right, bottom = coords
-  diameter = radius * 2
-  draw.rectangle([left + radius, top, right - radius, bottom], fill=fill)
-  draw.rectangle([left, top + radius, right, bottom - radius], fill=fill)
-  draw.pieslice([left, top, left + diameter, top + diameter], 180, 270, fill=fill)
-  draw.pieslice([right - diameter, top, right, top + diameter], 270, 360, fill=fill)
-  draw.pieslice([left, bottom - diameter, left + diameter, bottom], 90, 180, fill=fill)
-  draw.pieslice([right - diameter, bottom - diameter, right, bottom], 0, 90, fill=fill)
-
-def draw_centered_text_rounded(draw, text, font, rect_coords, radius=10, text_color="yellow", bg_color=(43,33,44)):
-  bbox = font.getbbox(text)
-  text_width = bbox[2] - bbox[0]
-  text_height = bbox[3] - bbox[1]
-  rect_left, rect_top, rect_right, rect_bottom = rect_coords
-  rect_width = rect_right - rect_left
-  rect_height = rect_bottom - rect_top
-  text_x = rect_left + (rect_width - text_width) // 2
-  text_y = rect_top + (rect_height - text_height) // 2
-  draw_rounded_rectangle(draw, rect_coords, radius, bg_color)
-  draw.text((text_x, text_y), text, fill=text_color, font=font)
-
-def draw_left_aligned_text_rounded(draw, text, font, rect_coords, padding_left=20, radius=10, text_color="yellow", bg_color=(43,33,44)):
-  bbox = font.getbbox(text)
-  text_height = bbox[3] - bbox[1]
-  rect_left, rect_top, rect_right, rect_bottom = rect_coords
-  rect_height = rect_bottom - rect_top
-  text_y = rect_top + (rect_height - text_height) // 2
-  text_x = rect_left + padding_left
-  draw_rounded_rectangle(draw, rect_coords, radius, bg_color)
-  draw.text((text_x, text_y), text, fill=text_color, font=font)
-
-def draw_right_text_dynamic_width_rounded(draw, text, font, base_coords, padding=20, radius=10, text_color="yellow", bg_color=(43,33,44)):
-  bbox = font.getbbox(text)
-  text_width = bbox[2] - bbox[0]
-  text_height = bbox[3] - bbox[1]
-  _, rect_top, rect_right, rect_bottom = base_coords
-  rect_height = rect_bottom - rect_top
-  new_rect_left = rect_right - (text_width + (padding * 2))
-  text_y = rect_top + (rect_height - text_height) // 2
-  text_x = new_rect_left + padding
-  draw_rounded_rectangle(draw, (new_rect_left, rect_top, rect_right, rect_bottom), radius, bg_color)
-  draw.text((text_x, text_y), text, fill=text_color, font=font)
-  return new_rect_left
-
-def draw_progress_bar(draw, progress, coords, color="yellow", bg_color=(70, 70, 70)):
-  left, top, right, bottom = coords
-  total_width = right - left
-  draw.rectangle(coords, fill=bg_color)
-  progress_width = int(total_width * progress)
-  if progress_width > 0:
-    draw.rectangle((left, top, left + progress_width, bottom), fill=color)
-
-def crop_image(image, top_crop=70):
-  width, height = image.size
-  return image.crop((0, top_crop, width, height))
-
-def create_animation_mp4(
-  replacement_image_path,
-  output_path,
-  device_name,
-  prompt_text,
-  fps=30,
-  target_size=(512, 512),
-  target_position=(139, 755),
-  progress_coords=(139, 1285, 655, 1295),
-  device_coords=(1240, 370, 1640, 416),
-  prompt_coords=(332, 1702, 2662, 1745)
-):
-  frames = []
-  try:
-    font = ImageFont.truetype("/System/Library/Fonts/SFNSMono.ttf", 20)
-    promptfont = ImageFont.truetype("/System/Library/Fonts/SFNSMono.ttf", 24)
-  except:
-    font = ImageFont.load_default()
-    promptfont = ImageFont.load_default()
-
-  # Process first frame
-  base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image1.png"))
-  draw = ImageDraw.Draw(base_img)
-  draw_centered_text_rounded(draw, device_name, font, device_coords)
-  frames.extend([crop_image(base_img)] * 30)  # 1 second at 30fps
-
-  # Process second frame with typing animation
-  base_img2 = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image2.png"))
-  for i in range(len(prompt_text) + 1):
-    current_frame = base_img2.copy()
-    draw = ImageDraw.Draw(current_frame)
-    draw_centered_text_rounded(draw, device_name, font, device_coords)
-    if i > 0:  # Only draw if we have at least one character
-      draw_left_aligned_text_rounded(draw, prompt_text[:i], promptfont, prompt_coords)
-    frames.extend([crop_image(current_frame)] * 2)  # 2 frames per character for smooth typing
-  
-  # Hold the complete prompt for a moment
-  frames.extend([frames[-1]] * 30)  # Hold for 1 second
-
-  # Create blur sequence
-  replacement_img = Image.open(replacement_image_path)
-  base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image3.png"))
-  blur_steps = [int(80 * (1 - i/8)) for i in range(9)]
-
-  for i, blur_amount in enumerate(blur_steps):
-    new_frame = base_img.copy()
-    draw = ImageDraw.Draw(new_frame)
-
-    replacement_copy = replacement_img.copy()
-    replacement_copy.thumbnail(target_size, Image.Resampling.LANCZOS)
-    if blur_amount > 0:
-      replacement_copy = replacement_copy.filter(ImageFilter.GaussianBlur(radius=blur_amount))
-
-    mask = replacement_copy.split()[-1] if replacement_copy.mode in ('RGBA', 'LA') else None
-    new_frame.paste(replacement_copy, target_position, mask)
-
-    draw_progress_bar(draw, (i + 1) / 9, progress_coords)
-    draw_centered_text_rounded(draw, device_name, font, device_coords)
-    draw_right_text_dynamic_width_rounded(draw, prompt_text, promptfont, (None, 590, 2850, 685), padding=30)
-
-    frames.extend([crop_image(new_frame)] * 15)  # 0.5 seconds at 30fps
-
-  # Create and add final frame (image4)
-  final_base = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image4.png"))
-  draw = ImageDraw.Draw(final_base)
-
-  draw_centered_text_rounded(draw, device_name, font, device_coords)
-  draw_right_text_dynamic_width_rounded(draw, prompt_text, promptfont, (None, 590, 2850, 685), padding=30)
-
-  replacement_copy = replacement_img.copy()
-  replacement_copy.thumbnail(target_size, Image.Resampling.LANCZOS)
-  mask = replacement_copy.split()[-1] if replacement_copy.mode in ('RGBA', 'LA') else None
-  final_base.paste(replacement_copy, target_position, mask)
-
-  frames.extend([crop_image(final_base)] * 30)  # 1 second at 30fps
-
-  # Convert frames to video using H.264 codec
-  if frames:
-    first_frame = np.array(frames[0])
-    height, width = first_frame.shape[:2]
-    fourcc = cv2.VideoWriter_fourcc(*'avc1')
-    out = cv2.VideoWriter(
-      output_path,
-      fourcc,
-      fps,
-      (width, height),
-      isColor=True
-    )
-
-    if not out.isOpened():
-      print("Error: VideoWriter failed to open")
-      return
-
-    for frame in frames:
-      frame_array = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
-      out.write(frame_array)
-    
-    out.release()
-    print(f"Video saved successfully to {output_path}")

+ 0 - 0
build/lib/exo/download/__init__.py


+ 0 - 61
build/lib/exo/download/download_progress.py

@@ -1,61 +0,0 @@
-from typing import Dict, Callable, Coroutine, Any, Literal
-from dataclasses import dataclass
-from datetime import timedelta
-
-
-@dataclass
-class RepoFileProgressEvent:
-  repo_id: str
-  repo_revision: str
-  file_path: str
-  downloaded: int
-  downloaded_this_session: int
-  total: int
-  speed: int
-  eta: timedelta
-  status: Literal["not_started", "in_progress", "complete"]
-
-  def to_dict(self):
-    return {
-      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
-      "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
-    }
-
-  @classmethod
-  def from_dict(cls, data):
-    if 'eta' in data: data['eta'] = timedelta(seconds=data['eta'])
-    return cls(**data)
-
-
-@dataclass
-class RepoProgressEvent:
-  repo_id: str
-  repo_revision: str
-  completed_files: int
-  total_files: int
-  downloaded_bytes: int
-  downloaded_bytes_this_session: int
-  total_bytes: int
-  overall_speed: int
-  overall_eta: timedelta
-  file_progress: Dict[str, RepoFileProgressEvent]
-  status: Literal["not_started", "in_progress", "complete"]
-
-  def to_dict(self):
-    return {
-      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
-      "downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
-      "file_progress": {k: v.to_dict()
-                        for k, v in self.file_progress.items()}, "status": self.status
-    }
-
-  @classmethod
-  def from_dict(cls, data):
-    if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta'])
-    if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}
-
-    return cls(**data)
-
-
-RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]]
-RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]]

+ 0 - 0
build/lib/exo/download/hf/__init__.py


+ 0 - 447
build/lib/exo/download/hf/hf_helpers.py

@@ -1,447 +0,0 @@
-import aiofiles.os as aios
-from typing import Union
-import asyncio
-import aiohttp
-import json
-import os
-import sys
-import shutil
-from urllib.parse import urljoin
-from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
-from datetime import datetime, timedelta
-from fnmatch import fnmatch
-from pathlib import Path
-from typing import Generator, Iterable, TypeVar, TypedDict
-from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG, is_frozen
-from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
-from exo.inference.shard import Shard
-import aiofiles
-
-T = TypeVar("T")
-
-async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
-  refs_dir = get_repo_root(repo_id)/"refs"
-  refs_file = refs_dir/revision
-  if await aios.path.exists(refs_file):
-    async with aiofiles.open(refs_file, 'r') as f:
-      commit_hash = (await f.read()).strip()
-      snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash
-      return snapshot_dir
-  return None
-
-
-def filter_repo_objects(
-  items: Iterable[T],
-  *,
-  allow_patterns: Optional[Union[List[str], str]] = None,
-  ignore_patterns: Optional[Union[List[str], str]] = None,
-  key: Optional[Callable[[T], str]] = None,
-) -> Generator[T, None, None]:
-  if isinstance(allow_patterns, str):
-    allow_patterns = [allow_patterns]
-  if isinstance(ignore_patterns, str):
-    ignore_patterns = [ignore_patterns]
-  if allow_patterns is not None:
-    allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
-  if ignore_patterns is not None:
-    ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
-
-  if key is None:
-
-    def _identity(item: T) -> str:
-      if isinstance(item, str):
-        return item
-      if isinstance(item, Path):
-        return str(item)
-      raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
-
-    key = _identity
-
-  for item in items:
-    path = key(item)
-    if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
-      continue
-    if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
-      continue
-    yield item
-
-
-def _add_wildcard_to_directories(pattern: str) -> str:
-  if pattern[-1] == "/":
-    return pattern + "*"
-  return pattern
-
-
-def get_hf_endpoint() -> str:
-  return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
-
-
-def get_hf_home() -> Path:
-  """Get the Hugging Face home directory."""
-  return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))
-
-
-async def get_hf_token():
-  """Retrieve the Hugging Face token from the user's HF_HOME directory."""
-  token_path = get_hf_home()/"token"
-  if await aios.path.exists(token_path):
-    async with aiofiles.open(token_path, 'r') as f:
-      return (await f.read()).strip()
-  return None
-
-
-async def get_auth_headers():
-  """Get authentication headers if a token is available."""
-  token = await get_hf_token()
-  if token:
-    return {"Authorization": f"Bearer {token}"}
-  return {}
-
-
-def get_repo_root(repo_id: str) -> Path:
-  """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = str(repo_id).replace("/", "--")
-  return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
-
-async def move_models_to_hf(seed_dir: Union[str, Path]):
-  """Move model in resources folder of app to .cache/huggingface/hub"""
-  source_dir = Path(seed_dir)
-  dest_dir = get_hf_home()/"hub"
-  await aios.makedirs(dest_dir, exist_ok=True)  
-  for path in source_dir.iterdir():
-    if path.is_dir() and path.name.startswith("models--"):
-      dest_path = dest_dir / path.name
-      if await aios.path.exists(dest_path):
-        print('Skipping moving model to .cache directory')
-      else:
-        try:
-          await aios.rename(str(path), str(dest_path))
-        except Exception as e:
-          print(f'Error moving model to .cache: {e}')
-    
-    
-    
-async def fetch_file_list(session, repo_id, revision, path=""):
-  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
-  url = f"{api_url}/{path}" if path else api_url
-
-  headers = await get_auth_headers()
-  async with session.get(url, headers=headers) as response:
-    if response.status == 200:
-      data = await response.json()
-      files = []
-      for item in data:
-        if item["type"] == "file":
-          files.append({"path": item["path"], "size": item["size"]})
-        elif item["type"] == "directory":
-          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
-          files.extend(subfiles)
-      return files
-    else:
-      raise Exception(f"Failed to fetch file list: {response.status}")
-
-
-@retry(
-  stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
-)
-async def download_file(
-  session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
-):
-  base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-  url = urljoin(base_url, file_path)
-  local_path = os.path.join(save_directory, file_path)
-
-  await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
-
-  # Check if file already exists and get its size
-  local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
-
-  headers = await get_auth_headers()
-  if use_range_request:
-    headers["Range"] = f"bytes={local_file_size}-"
-
-  async with session.get(url, headers=headers) as response:
-    total_size = int(response.headers.get('Content-Length', 0))
-    downloaded_size = local_file_size
-    downloaded_this_session = 0
-    mode = 'ab' if use_range_request else 'wb'
-    if downloaded_size == total_size:
-      if DEBUG >= 2: print(f"File already downloaded: {file_path}")
-      if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-      return
-
-    if response.status == 200:
-      # File doesn't support range requests or we're not using them, start from beginning
-      mode = 'wb'
-      downloaded_size = 0
-    elif response.status == 206:
-      # Partial content, resume download
-      content_range = response.headers.get('Content-Range', '')
-      try:
-        total_size = int(content_range.split('/')[-1])
-      except ValueError:
-        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-    elif response.status == 416:
-      # Range not satisfiable, get the actual file size
-      content_range = response.headers.get('Content-Range', '')
-      try:
-        total_size = int(content_range.split('/')[-1])
-        if downloaded_size == total_size:
-          if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
-          if progress_callback:
-            await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-          return
-      except ValueError:
-        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-    else:
-      raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
-
-    if downloaded_size == total_size:
-      print(f"File already downloaded: {file_path}")
-      if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-      return
-
-    DOWNLOAD_CHUNK_SIZE = 32768
-    start_time = datetime.now()
-    async with aiofiles.open(local_path, mode) as f:
-      async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
-        await f.write(chunk)
-        downloaded_size += len(chunk)
-        downloaded_this_session += len(chunk)
-        if progress_callback and total_size:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
-          remaining_size = total_size - downloaded_size
-          eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
-          status = "in_progress" if downloaded_size < total_size else "complete"
-          if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
-          await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
-    if DEBUG >= 2: print(f"Downloaded: {file_path}")
-
-
-async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
-  repo_root = get_repo_root(repo_id)
-  refs_dir = repo_root/"refs"
-  refs_file = refs_dir/revision
-
-  # Check if we have a cached commit hash
-  if await aios.path.exists(refs_file):
-    async with aiofiles.open(refs_file, 'r') as f:
-      commit_hash = (await f.read()).strip()
-      if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
-      return commit_hash
-
-  # Fetch the commit hash for the given revision
-  async with aiohttp.ClientSession() as session:
-    api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
-    headers = await get_auth_headers()
-    async with session.get(api_url, headers=headers) as response:
-      if response.status != 200:
-        raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
-      revision_info = await response.json()
-      commit_hash = revision_info['sha']
-
-  # Cache the commit hash
-  await aios.makedirs(refs_dir, exist_ok=True)
-  async with aiofiles.open(refs_file, 'w') as f:
-    await f.write(commit_hash)
-
-  return commit_hash
-
-
-async def download_repo_files(
-  repo_id: str,
-  revision: str = "main",
-  progress_callback: Optional[RepoProgressCallback] = None,
-  allow_patterns: Optional[Union[List[str], str]] = None,
-  ignore_patterns: Optional[Union[List[str], str]] = None,
-  max_parallel_downloads: int = 4
-) -> Path:
-  repo_root = get_repo_root(repo_id)
-  snapshots_dir = repo_root/"snapshots"
-  cachedreqs_dir = repo_root/"cachedreqs"
-
-  # Ensure directories exist
-  await aios.makedirs(snapshots_dir, exist_ok=True)
-  await aios.makedirs(cachedreqs_dir, exist_ok=True)
-
-  # Resolve revision to commit hash
-  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
-
-  # Set up the snapshot directory
-  snapshot_dir = snapshots_dir/commit_hash
-  await aios.makedirs(snapshot_dir, exist_ok=True)
-
-  # Set up the cached file list directory
-  cached_file_list_dir = cachedreqs_dir/commit_hash
-  await aios.makedirs(cached_file_list_dir, exist_ok=True)
-  cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"
-
-  async with aiohttp.ClientSession() as session:
-    # Check if we have a cached file list
-    if await aios.path.exists(cached_file_list_path):
-      async with aiofiles.open(cached_file_list_path, 'r') as f:
-        file_list = json.loads(await f.read())
-      if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
-    else:
-      file_list = await fetch_file_list(session, repo_id, revision)
-      # Cache the file list
-      async with aiofiles.open(cached_file_list_path, 'w') as f:
-        await f.write(json.dumps(file_list))
-      if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
-
-    model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
-    if model_index_exists:
-      allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
-
-    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
-    total_files = len(filtered_file_list)
-    total_bytes = sum(file["size"] for file in filtered_file_list)
-    file_progress: Dict[str, RepoFileProgressEvent] = {
-      file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
-      for file in filtered_file_list
-    }
-    start_time = datetime.now()
-
-    async def download_with_progress(file_info, progress_state):
-      local_path = snapshot_dir/file_info["path"]
-      if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
-        if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
-        progress_state['completed_files'] += 1
-        progress_state['downloaded_bytes'] += file_info["size"]
-        file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
-        if progress_callback:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-          status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-          await progress_callback(
-            RepoProgressEvent(
-              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-              overall_eta, file_progress, status
-            )
-          )
-        return
-
-      async def file_progress_callback(event: RepoFileProgressEvent):
-        progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
-        progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
-        file_progress[event.file_path] = event
-        if progress_callback:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-          status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
-          await progress_callback(
-            RepoProgressEvent(
-              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-              overall_eta, file_progress, status
-            )
-          )
-
-      await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
-      progress_state['completed_files'] += 1
-      file_progress[
-        file_info["path"]
-      ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
-      if progress_callback:
-        elapsed_time = (datetime.now() - start_time).total_seconds()
-        overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-        remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-        overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-        status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-        await progress_callback(
-          RepoProgressEvent(
-            repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-            overall_eta, file_progress, status
-          )
-        )
-
-    progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
-
-    semaphore = asyncio.Semaphore(max_parallel_downloads)
-
-    async def download_with_semaphore(file_info):
-      async with semaphore:
-        await download_with_progress(file_info, progress_state)
-
-    tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
-    await asyncio.gather(*tasks)
-
-  return snapshot_dir
-
-
-async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
-  """
-    Retrieve the weight map from the model.safetensors.index.json file.
-
-    Args:
-        repo_id (str): The Hugging Face repository ID.
-        revision (str): The revision of the repository to use.
-
-    Returns:
-        Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
-    """
-
-  # Download the index file
-  await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
-
-  # Check if the file exists
-  repo_root = get_repo_root(repo_id)
-  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
-  snapshot_dir = repo_root/"snapshots"/commit_hash
-  index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
-
-  if index_file:
-    index_file_path = snapshot_dir/index_file
-    if await aios.path.exists(index_file_path):
-      async with aiofiles.open(index_file_path, 'r') as f:
-        index_data = json.loads(await f.read())
-      return index_data.get("weight_map")
-
-  return None
-
-
-def extract_layer_num(tensor_name: str) -> Optional[int]:
-  # This is a simple example and might need to be adjusted based on the actual naming convention
-  parts = tensor_name.split('.')
-  for part in parts:
-    if part.isdigit():
-      return int(part)
-  return None
-
-
-def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
-  default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
-  shard_specific_patterns = set()
-  if weight_map:
-    for tensor_name, filename in weight_map.items():
-      layer_num = extract_layer_num(tensor_name)
-      if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
-        shard_specific_patterns.add(filename)
-    sorted_file_names = sorted(weight_map.values())
-    if shard.is_first_layer():
-      shard_specific_patterns.add(sorted_file_names[0])
-    elif shard.is_last_layer():
-      shard_specific_patterns.add(sorted_file_names[-1])
-  else:
-    shard_specific_patterns = set(["*.safetensors"])
-  if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
-  return list(default_patterns | shard_specific_patterns)
-
-async def has_hf_home_read_access() -> bool:
-  hf_home = get_hf_home()
-  try: return await aios.access(hf_home, os.R_OK)
-  except OSError: return False
-
-async def has_hf_home_write_access() -> bool:
-  hf_home = get_hf_home()
-  try: return await aios.access(hf_home, os.W_OK)
-  except OSError: return False

+ 0 - 79
build/lib/exo/download/hf/hf_shard_download.py

@@ -1,79 +0,0 @@
-import asyncio
-import traceback
-from pathlib import Path
-from typing import Dict, List, Tuple
-from exo.inference.shard import Shard
-from exo.download.shard_download import ShardDownloader
-from exo.download.download_progress import RepoProgressEvent
-from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
-from exo.helpers import AsyncCallbackSystem, DEBUG
-from exo.models import model_cards, get_repo
-
-
-class HFShardDownloader(ShardDownloader):
-  def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
-    self.quick_check = quick_check
-    self.max_parallel_downloads = max_parallel_downloads
-    self.active_downloads: Dict[Shard, asyncio.Task] = {}
-    self.completed_downloads: Dict[Shard, Path] = {}
-    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
-
-  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
-    repo_name = get_repo(shard.model_id, inference_engine_name)
-    if shard in self.completed_downloads:
-      return self.completed_downloads[shard]
-    if self.quick_check:
-      repo_root = get_repo_root(repo_name)
-      snapshots_dir = repo_root/"snapshots"
-      if snapshots_dir.exists():
-        visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
-        if visible_dirs:
-          most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime)
-          return most_recent_dir
-
-    # If a download on this shard is already in progress, keep that one
-    for active_shard in self.active_downloads:
-      if active_shard == shard:
-        if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
-        return await self.active_downloads[shard]
-
-    # Cancel any downloads for this model_id on a different shard
-    existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
-    for active_shard in existing_active_shards:
-      if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
-      task = self.active_downloads[active_shard]
-      task.cancel()
-      try:
-        await task
-      except asyncio.CancelledError:
-        pass  # This is expected when cancelling a task
-      except Exception as e:
-        if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
-        traceback.print_exc()
-    self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
-
-    # Start new download
-    download_task = asyncio.create_task(self._download_shard(shard, repo_name))
-    self.active_downloads[shard] = download_task
-    try:
-      path = await download_task
-      self.completed_downloads[shard] = path
-      return path
-    finally:
-      # Ensure the task is removed even if an exception occurs
-      print(f"Removing download task for {shard}: {shard in self.active_downloads}")
-      if shard in self.active_downloads:
-        self.active_downloads.pop(shard)
-
-  async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
-    async def wrapped_progress_callback(event: RepoProgressEvent):
-      self._on_progress.trigger_all(shard, event)
-
-    weight_map = await get_weight_map(repo_name)
-    allow_patterns = get_allow_patterns(weight_map, shard)
-
-    return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
-
-  @property
-  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
-    return self._on_progress

+ 0 - 36
build/lib/exo/download/shard_download.py

@@ -1,36 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Optional, Tuple
-from pathlib import Path
-from exo.inference.shard import Shard
-from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import AsyncCallbackSystem
-
-
-class ShardDownloader(ABC):
-  @abstractmethod
-  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
-    """
-        Ensures that the shard is downloaded.
-        Does not allow multiple overlapping downloads at once.
-        If you try to download a Shard which overlaps a Shard that is already being downloaded,
-        the download will be cancelled and a new download will start.
-
-        Args:
-            shard (Shard): The shard to download.
-            inference_engine_name (str): The inference engine used on the node hosting the shard
-        """
-    pass
-
-  @property
-  @abstractmethod
-  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
-    pass
-
-
-class NoopShardDownloader(ShardDownloader):
-  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
-    return Path("/tmp/noop_shard")
-
-  @property
-  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
-    return AsyncCallbackSystem()

+ 0 - 274
build/lib/exo/helpers.py

@@ -1,274 +0,0 @@
-import os
-import sys
-import asyncio
-from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
-import socket
-import random
-import platform
-import psutil
-import uuid
-import netifaces
-from pathlib import Path
-import tempfile
-
-DEBUG = int(os.getenv("DEBUG", default="0"))
-DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
-VERSION = "0.0.1"
-
-exo_text = r"""
-  _____  _____  
- / _ \ \/ / _ \ 
-|  __/>  < (_) |
- \___/_/\_\___/ 
-    """
-
-
-def get_system_info():
-  if psutil.MACOS:
-    if platform.machine() == "arm64":
-      return "Apple Silicon Mac"
-    if platform.machine() in ["x86_64", "i386"]:
-      return "Intel Mac"
-    return "Unknown Mac architecture"
-  if psutil.LINUX:
-    return "Linux"
-  return "Non-Mac, non-Linux system"
-
-
-def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
-  used_ports_file = os.path.join(tempfile.gettempdir(), "exo_used_ports")
-
-  def read_used_ports():
-    if os.path.exists(used_ports_file):
-      with open(used_ports_file, "r") as f:
-        return [int(line.strip()) for line in f if line.strip().isdigit()]
-    return []
-
-  def write_used_port(port, used_ports):
-    with open(used_ports_file, "w") as f:
-      print(used_ports[-19:])
-      for p in used_ports[-19:] + [port]:
-        f.write(f"{p}\n")
-
-  used_ports = read_used_ports()
-  available_ports = set(range(min_port, max_port + 1)) - set(used_ports)
-
-  while available_ports:
-    port = random.choice(list(available_ports))
-    if DEBUG >= 2: print(f"Trying to find available port {port=}")
-    try:
-      with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind((host, port))
-      write_used_port(port, used_ports)
-      return port
-    except socket.error:
-      available_ports.remove(port)
-
-  raise RuntimeError("No available ports in the specified range")
-
-
-def print_exo():
-  print(exo_text)
-
-
-def print_yellow_exo():
-  yellow = "\033[93m"  # ANSI escape code for yellow
-  reset = "\033[0m"  # ANSI escape code to reset color
-  print(f"{yellow}{exo_text}{reset}")
-
-
-def terminal_link(uri, label=None):
-  if label is None:
-    label = uri
-  parameters = ""
-
-  # OSC 8 ; params ; URI ST <name> OSC 8 ;; ST
-  escape_mask = "\033]8;{};{}\033\\{}\033]8;;\033\\"
-
-  return escape_mask.format(parameters, uri, label)
-
-
-T = TypeVar("T")
-K = TypeVar("K")
-
-
-class AsyncCallback(Generic[T]):
-  def __init__(self) -> None:
-    self.condition: asyncio.Condition = asyncio.Condition()
-    self.result: Optional[Tuple[T, ...]] = None
-    self.observers: list[Callable[..., None]] = []
-
-  async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]:
-    async with self.condition:
-      await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout)
-      assert self.result is not None  # for type checking
-      return self.result
-
-  def on_next(self, callback: Callable[..., None]) -> None:
-    self.observers.append(callback)
-
-  def set(self, *args: T) -> None:
-    self.result = args
-    for observer in self.observers:
-      observer(*args)
-    asyncio.create_task(self.notify())
-
-  async def notify(self) -> None:
-    async with self.condition:
-      self.condition.notify_all()
-
-
-class AsyncCallbackSystem(Generic[K, T]):
-  def __init__(self) -> None:
-    self.callbacks: Dict[K, AsyncCallback[T]] = {}
-
-  def register(self, name: K) -> AsyncCallback[T]:
-    if name not in self.callbacks:
-      self.callbacks[name] = AsyncCallback[T]()
-    return self.callbacks[name]
-
-  def deregister(self, name: K) -> None:
-    if name in self.callbacks:
-      del self.callbacks[name]
-
-  def trigger(self, name: K, *args: T) -> None:
-    if name in self.callbacks:
-      self.callbacks[name].set(*args)
-
-  def trigger_all(self, *args: T) -> None:
-    for callback in self.callbacks.values():
-      callback.set(*args)
-
-
-K = TypeVar('K', bound=str)
-V = TypeVar('V')
-
-
-class PrefixDict(Generic[K, V]):
-  def __init__(self):
-    self.items: Dict[K, V] = {}
-
-  def add(self, key: K, value: V) -> None:
-    self.items[key] = value
-
-  def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
-    return [(key, value) for key, value in self.items.items() if argument.startswith(key)]
-
-  def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
-    matches = self.find_prefix(argument)
-    if len(matches) == 0:
-      return None
-
-    return max(matches, key=lambda x: len(x[0]))
-
-
-def is_valid_uuid(val):
-  try:
-    uuid.UUID(str(val))
-    return True
-  except ValueError:
-    return False
-
-
-def get_or_create_node_id():
-  NODE_ID_FILE = Path(tempfile.gettempdir())/".exo_node_id"
-  try:
-    if NODE_ID_FILE.is_file():
-      with open(NODE_ID_FILE, "r") as f:
-        stored_id = f.read().strip()
-      if is_valid_uuid(stored_id):
-        if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
-        return stored_id
-      else:
-        if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")
-
-    new_id = str(uuid.uuid4())
-    with open(NODE_ID_FILE, "w") as f:
-      f.write(new_id)
-
-    if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
-    return new_id
-  except IOError as e:
-    if DEBUG >= 2: print(f"IO error creating node_id: {e}")
-    return str(uuid.uuid4())
-  except Exception as e:
-    if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
-    return str(uuid.uuid4())
-
-
-def pretty_print_bytes(size_in_bytes: int) -> str:
-  if size_in_bytes < 1024:
-    return f"{size_in_bytes} B"
-  elif size_in_bytes < 1024**2:
-    return f"{size_in_bytes / 1024:.2f} KB"
-  elif size_in_bytes < 1024**3:
-    return f"{size_in_bytes / (1024 ** 2):.2f} MB"
-  elif size_in_bytes < 1024**4:
-    return f"{size_in_bytes / (1024 ** 3):.2f} GB"
-  else:
-    return f"{size_in_bytes / (1024 ** 4):.2f} TB"
-
-
-def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
-  if bytes_per_second < 1024:
-    return f"{bytes_per_second} B/s"
-  elif bytes_per_second < 1024**2:
-    return f"{bytes_per_second / 1024:.2f} KB/s"
-  elif bytes_per_second < 1024**3:
-    return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
-  elif bytes_per_second < 1024**4:
-    return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
-  else:
-    return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
-
-
-def get_all_ip_addresses():
-  try:
-    ip_addresses = []
-    for interface in netifaces.interfaces():
-      ifaddresses = netifaces.ifaddresses(interface)
-      if netifaces.AF_INET in ifaddresses:
-        for link in ifaddresses[netifaces.AF_INET]:
-          ip = link['addr']
-          ip_addresses.append(ip)
-    return list(set(ip_addresses))
-  except:
-    if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
-    return ["localhost"]
-
-
-async def shutdown(signal, loop, server):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-
-
-def is_frozen():
-  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
-    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
-    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
-
-
-def get_exo_home() -> Path:
-  if os.name == "nt":  # Check if the OS is Windows
-    docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
-  else:
-    docs_folder = Path.home() / "Documents"
-  exo_folder = docs_folder / "Exo"
-  if not exo_folder.exists():
-    exo_folder.mkdir()
-  return exo_folder
-
-def get_exo_images_dir() -> Path:
-  exo_home = get_exo_home()
-  images_dir = exo_home / "Images"
-  if not images_dir.exists():
-    images_dir.mkdir()
-  return images_dir
-  

+ 0 - 0
build/lib/exo/inference/__init__.py


+ 0 - 58
build/lib/exo/inference/debug_inference_engine.py

@@ -1,58 +0,0 @@
-from exo.inference.inference_engine import InferenceEngine
-from exo.inference.shard import Shard
-from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-import asyncio
-import numpy as np
-
-
-# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
-async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
-  from exo.inference.tinygrad.inference import Tokenizer
-  from pathlib import Path
-
-  _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))
-
-  prompt = "In a single word only, what is the last name of the president of the United States? "
-  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-  token_full = await inference_engine_1.sample(resp_full)
-
-  next_resp_full = await inference_engine_1.infer_tensor(
-    "A",
-    shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
-    input_data=token_full,
-  )
-
-  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-  resp2 = await inference_engine_2.infer_tensor(
-    "B",
-    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
-    input_data=resp1,
-  )
-  token2 = await inference_engine_2.sample(resp2)
-  resp3 = await inference_engine_1.infer_tensor(
-    "B",
-    shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
-    input_data=token2,
-  )
-  resp4 = await inference_engine_2.infer_tensor(
-    "B",
-    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
-    input_data=resp3,
-  )
-
-  print(f"{resp2=}")
-  print(f"full: {_tokenizer.decode(resp_full)}")
-  print(f"next full: {_tokenizer.decode(next_resp_full)}")
-  print(f"resp2: {_tokenizer.decode(resp2)}")
-  print(f"{resp4=}")
-  print(f"resp4: {_tokenizer.decode(resp4)}")
-
-  assert np.array_equal(resp_full, resp2)
-  assert np.array_equal(next_resp_full, resp4)
-
-
-asyncio.run(test_inference_engine(
-  TinygradDynamicShardInferenceEngine(),
-  TinygradDynamicShardInferenceEngine(),
-  "llama3-8b-sfr",
-))

+ 0 - 34
build/lib/exo/inference/dummy_inference_engine.py

@@ -1,34 +0,0 @@
-from typing import Optional, Tuple, TYPE_CHECKING
-import numpy as np
-from exo.inference.inference_engine import InferenceEngine
-from exo.inference.shard import Shard
-from exo.inference.tokenizers import DummyTokenizer
-
-class DummyInferenceEngine(InferenceEngine):
-  def __init__(self):
-    self.shard = None
-    self.vocab_size = 1000
-    self.hidden_size = 256
-    self.eos_token_id = 0
-    self.latency_mean = 0.1
-    self.latency_stddev = 0.02
-    self.num_generate_dummy_tokens = 10
-    self.tokenizer = DummyTokenizer()
-
-  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
-    return np.array(self.tokenizer.encode(prompt))
-  
-  async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
-    if x[0] > self.num_generate_dummy_tokens: return np.array([self.tokenizer.eos_token_id])
-    return x
-
-  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
-    return self.tokenizer.decode(tokens)
-
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
-    await self.ensure_shard(shard)
-    return input_data + 1 if self.shard.is_last_layer() else input_data
-
-  async def ensure_shard(self, shard: Shard):
-    if self.shard == shard: return
-    self.shard = shard

+ 0 - 58
build/lib/exo/inference/inference_engine.py

@@ -1,58 +0,0 @@
-import numpy as np
-import os
-from exo.helpers import DEBUG  # Make sure to import DEBUG
-
-from typing import Tuple, Optional
-from abc import ABC, abstractmethod
-from .shard import Shard
-
-
-class InferenceEngine(ABC):
-  @abstractmethod
-  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
-    pass
-  
-  @abstractmethod
-  async def sample(self, x: np.ndarray) -> np.ndarray:
-    pass
-
-  @abstractmethod
-  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
-    pass
-
-  @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
-    pass
-  
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> np.ndarray:
-    tokens = await self.encode(shard, prompt)
-    if shard.model_id != 'stable-diffusion-2-1-base':
-      x = tokens.reshape(1, -1)
-    else:
-      x = tokens
-    output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
-    return output_data, inference_state
-
-inference_engine_classes = {
-  "mlx": "MLXDynamicShardInferenceEngine",
-  "tinygrad": "TinygradDynamicShardInferenceEngine",
-  "dummy": "DummyInferenceEngine",
-}
-
-def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
-  if DEBUG >= 2:
-    print(f"get_inference_engine called with: {inference_engine_name}")
-  if inference_engine_name == "mlx":
-    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-
-    return MLXDynamicShardInferenceEngine(shard_downloader)
-  elif inference_engine_name == "tinygrad":
-    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-    import tinygrad.helpers
-    tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-
-    return TinygradDynamicShardInferenceEngine(shard_downloader)
-  elif inference_engine_name == "dummy":
-    from exo.inference.dummy_inference_engine import DummyInferenceEngine
-    return DummyInferenceEngine()
-  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")

+ 0 - 0
build/lib/exo/inference/mlx/__init__.py


+ 0 - 307
build/lib/exo/inference/mlx/models/StableDiffusionPipeline.py

@@ -1,307 +0,0 @@
-# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/__init__.py
-
-import time
-from typing import Optional, Tuple
-import inspect
-
-import mlx.core as mx
-import mlx.nn as nn
-from pathlib import Path
-
-from tqdm import tqdm
-
-from .sd_models.vae import ModelArgs as VAEArgs
-from .sd_models.vae import Autoencoder
-from .sd_models.tokenizer import load_tokenizer
-from .sd_models.clip import CLIPTextModel
-from .sd_models.clip import ModelArgs as CLIPArgs
-from .sd_models.unet import UNetConfig, UNetModel
-
-from dataclasses import dataclass, field
-from exo.inference.shard import Shard
-
-@dataclass
-class DiffusionConfig:
-    beta_schedule: str = "scaled_linear"
-    beta_start: float = 0.00085
-    beta_end: float = 0.012
-    num_train_steps: int = 1000
-
-    @classmethod
-    def from_dict(cls, params):
-        return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
-
-
-#Sampler
-def _linspace(a, b, num):
-    x = mx.arange(0, num) / (num - 1)
-    return (b - a) * x + a
-
-
-def _interp(y, x_new):
-    """Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
-    x_low = x_new.astype(mx.int32)
-    x_high = mx.minimum(x_low + 1, len(y) - 1)
-
-    y_low = y[x_low]
-    y_high = y[x_high]
-    delta_x = x_new - x_low
-    y_new = y_low * (1 - delta_x) + delta_x * y_high
-
-    return y_new
-
-class SimpleEulerSampler:
-    """A simple Euler integrator that can be used to sample from our diffusion models.
-
-    The method ``step()`` performs one Euler step from x_t to x_t_prev.
-    """
-
-    def __init__(self, config: DiffusionConfig):
-        # Compute the noise schedule
-        if config.beta_schedule == "linear":
-            betas = _linspace(
-                config.beta_start, config.beta_end, config.num_train_steps
-            )
-        elif config.beta_schedule == "scaled_linear":
-            betas = _linspace(
-                config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
-            ).square()
-        else:
-            raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
-
-        alphas = 1 - betas
-        alphas_cumprod = mx.cumprod(alphas)
-
-        self._sigmas = mx.concatenate(
-            [mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
-        )
-
-    @property
-    def max_time(self):
-        return len(self._sigmas) - 1
-
-    def sample_prior(self, shape, dtype=mx.float32, key=None):
-        noise = mx.random.normal(shape, key=key)
-        return (
-            noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
-        ).astype(dtype)
-
-    def add_noise(self, x, t, key=None):
-        noise = mx.random.normal(x.shape, key=key)
-        s = self.sigmas(t)
-        return (x + noise * s) * (s.square() + 1).rsqrt()
-
-    def sigmas(self, t):
-        return _interp(self._sigmas, t)
-
-    def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
-        start_time = start_time or (len(self._sigmas) - 1)
-        assert 0 < start_time <= (len(self._sigmas) - 1)
-        steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
-        return list(zip(steps, steps[1:]))
-
-    def current_timestep(self, step, total_steps, start_time=None):
-        if step < total_steps:
-            steps = self.timesteps(total_steps, start_time)
-            return steps[step]
-        else:
-            return mx.array(0),mx.array(0)
-
-    def step(self, eps_pred, x_t, t, t_prev):
-        sigma = self.sigmas(t).astype(eps_pred.dtype)
-        sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
-
-        dt = sigma_prev - sigma
-        x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
-
-        x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
-
-        return x_t_prev
-
-@dataclass
-class ShardConfig:
-    model_id:str
-    start_layer:int
-    end_layer:int
-    n_layers:int
-
-@dataclass
-class StableDiffusionConfig:
-    model_type:str
-    vae:VAEArgs
-    text_encoder:CLIPArgs
-    scheduler:DiffusionConfig
-    unet:UNetConfig
-    shard:ShardConfig
-    
-    @classmethod
-    def from_dict(cls, params):
-        return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
-
-@dataclass
-class ModelArgs(StableDiffusionConfig):
-    shard:Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
-
-    def __post_init__(self):
-        if isinstance(self.shard, dict):
-            self.shard = Shard(**self.shard)
-
-        if not isinstance(self.shard, Shard):
-            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
-
-
-class Model(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.model_type = config.model_type
-        self.config = config
-        self.model_path = config.vae['path'].split('/vae')[0]
-        self.shard = config.shard
-        self.shard_clip, self.shard_encoder, self.shard_unet, self.shard_decoder  = model_shards(config.shard)
-        self.config_clip=CLIPArgs.from_dict(config.text_encoder['config'])
-        if self.shard_clip.start_layer != -1:
-            self.text_encoder = CLIPTextModel(self.config_clip, shard=self.shard_clip)
-        else:
-            self.text_encoder = nn.Identity()    
-        self.tokenizer = load_tokenizer(Path(self.model_path), "vocab.json", "merges.txt")
-        self.diffusion_config = DiffusionConfig.from_dict(config.scheduler['config'])
-        self.sampler = SimpleEulerSampler(self.diffusion_config)
-        if self.shard_unet.start_layer!=-1:
-            self.config_unet = UNetConfig.from_dict(config.unet['config'])
-            self.unet = UNetModel(self.config_unet, self.shard_unet)
-        else:
-            self.unet = nn.Identity()
-        self.config_vae=VAEArgs.from_dict(config.vae['config'])
-        if self.shard_encoder.start_layer != -1:
-            self.encoder=Autoencoder(self.config_vae, self.shard_encoder, "vae_encoder") 
-        else:
-            self.encoder = nn.Identity()            
-        if self.shard_decoder.start_layer != -1:
-            self.decoder=Autoencoder(self.config_vae, self.shard_decoder, "vae_decoder") 
-        else:
-            self.decoder = nn.Identity()            
-
-    def __call__(self,x, step= 0, cfg_weight: float = 7.5,total_steps=50,conditioning=None,mask=None,residual=None,x_t_prev=None,is_finished=False,is_step_finished=False, image=None, strength=0.7, start_step=None):
-        t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
-        is_finished = False
-        is_step_finished = False
-        if t.item()==1000:
-            if self.shard_clip.start_layer == 0:
-                conditioning = x
-            if self.shard_clip.start_layer != -1:
-                conditioning, mask= self.text_encoder(conditioning,mask)
-            seed = int(time.time()) 
-            mx.random.seed(seed)
-            if image is None:
-                if self.shard_encoder.is_last_layer():
-                    x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
-                    x_t_prev=x
-                    start_step = self.sampler.max_time
-            else:
-                if self.shard_encoder.start_layer != -1:
-                    image= self.encoder.encode(image)
-                    if self.shard_encoder.is_last_layer():
-                        start_step = self.sampler.max_time*strength
-                        total_steps = int(total_steps*strength)
-                        image = mx.broadcast_to(image, (1,) + image.shape[1:])
-                        x_t_prev=self.sampler.add_noise(image, mx.array(start_step))
-                        image = None
-                        t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
-        # Perform the denoising loop
-        if self.shard_unet.start_layer != -1:
-            with tqdm(total=total_steps,initial=step+1) as pbar:
-                if step<total_steps:
-                    x = x_t_prev
-                    if self.shard_unet.is_first_layer():
-                        x_t_unet = mx.concatenate([x] * 2, axis=0) if cfg_weight> 1 else x
-                    else:
-                        x_t_unet = x
-                    t_unet = mx.broadcast_to(t, [len(x_t_unet)])
-                    x, residual= self.unet(x_t_unet, t_unet, encoder_x=conditioning, residuals=residual)
-                    if self.shard_unet.is_last_layer():
-                        if cfg_weight > 1:
-                            eps_text, eps_neg = x.split(2)
-                            eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
-                        x = self.sampler.step(eps_pred, x_t_prev, t, t_prev)
-                        x_t_prev=x
-                    mx.eval(x)
-                    
-        if self.shard_decoder.is_last_layer():
-            is_step_finished=True
-            if self.shard_decoder.start_layer != -1:
-                x=self.decoder.decode(x)
-            if self.shard_decoder.is_last_layer():
-                x = mx.clip(x / 2 + 0.5, 0, 1)
-                B, H, W, C = x.shape
-                x = x.reshape(1, B // 1, H, W, C).transpose(0, 2, 1, 3, 4)
-                x = x.reshape(1 * H, B // 1 * W, C)
-                x = (x * 255).astype(mx.uint8)
-                if t_prev.item() ==0:
-                    is_finished=True   
-        mx.eval(x)
-         
-        return x, {'conditioning':conditioning, 'mask':mask,'residual':residual,'x_t_prev':x_t_prev,'is_finished':is_finished,'is_step_finished':is_step_finished, 'step':step, 'total_steps':total_steps, 'start_step':start_step, 'image':image}
-    
-
-    def load(self):
-        if self.shard_encoder.start_layer != -1:    
-            vae_weights =  mx.load(self.config_vae.weight_files[0])
-            vae_weights = self.encoder.sanitize(vae_weights)
-            self.encoder.load_weights(list(vae_weights.items()), strict=True)
-        if self.shard_decoder.start_layer != -1:
-            vae_weights =  mx.load(self.config_vae.weight_files[0])
-            vae_weights = self.decoder.sanitize(vae_weights)
-            self.decoder.load_weights(list(vae_weights.items()), strict=True)
-        if self.shard_clip.start_layer != -1:
-            clip_weights = mx.load(self.config_clip.weight_files[0])
-            clip_weights = self.text_encoder.sanitize(clip_weights)
-            self.text_encoder.load_weights(list(clip_weights.items()), strict=True)
-        if self.shard_unet.start_layer !=-1:
-            unet_weights = mx.load(self.config_unet.weight_files[0])
-            unet_weights = self.unet.sanitize(unet_weights)
-            self.unet.load_weights(list(unet_weights.items()), strict=True)
-
-def model_shards(shard:ShardConfig):
-    def create_shard(shard, model_ranges):
-        start_layer = shard.start_layer
-        end_layer = shard.end_layer
-        
-        shards = {}
-        
-        for model_name, (range_start, range_end) in model_ranges.items():
-            if start_layer < range_end and end_layer >= range_start:
-                # Calculate the overlap with the model range
-                overlap_start = max(start_layer, range_start)
-                overlap_end = min(end_layer, range_end - 1)
-
-                # Adjust the layers relative to the model's range
-                relative_start = overlap_start - range_start
-                relative_end = overlap_end - range_start
-                shards[model_name] = Shard(model_name, relative_start, relative_end, range_end - range_start)
-            else:
-                # If no overlap, create a zero-layer shard
-                shards[model_name] = Shard(model_name, -1, -1, range_end - range_start)
-        
-        return shards
-
-    # Define the ranges for different models
-    model_ranges = {
-        'clip': (0, 12),
-        'vae_encoder':(12,17),
-        'unet':(17,26),
-        'vae_decoder': (26, 31) # Example range for unet
-    }
-
-    # Call the function and get the shards for all models
-    shards = create_shard(shard, model_ranges)
-
-    # Access individual shards
-    shard_clip = shards['clip']
-    shard_encoder = shards['vae_encoder']
-    shard_unet = shards['unet']
-    shard_decoder = shards['vae_decoder']
-    
-    return shard_clip, shard_encoder, shard_unet, shard_decoder
-
-
-

+ 0 - 0
build/lib/exo/inference/mlx/models/__init__.py


+ 0 - 9
build/lib/exo/inference/mlx/models/base.py

@@ -1,9 +0,0 @@
-from typing import Optional
-import mlx.core as mx
-import mlx.nn as nn
-from mlx_lm.models.cache import KVCache
-
-
-class IdentityBlock(nn.Module):
-  def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array:
-    return x

+ 0 - 127
build/lib/exo/inference/mlx/models/deepseek_v2.py

@@ -1,127 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Optional
-
-import mlx.core as mx
-import mlx.nn as nn
-
-from mlx_lm.models.cache import KVCache
-from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
-from .base import IdentityBlock
-from exo.inference.shard import Shard
-
-
-@dataclass
-class ModelArgs(ModelArgs):
-  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
-
-  def __post_init__(self):
-    if isinstance(self.shard, Shard):
-      return
-    if not isinstance(self.shard, dict):
-      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
-
-    self.shard = Shard(**self.shard)
-
-
-class DeepseekV2Model(nn.Module):
-  def __init__(self, config: ModelArgs):
-    super().__init__()
-    self.args = config
-    self.num_hidden_layers = config.num_hidden_layers
-    self.vocab_size = config.vocab_size
-    if self.args.shard.is_first_layer():
-      self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-
-    self.layers = []
-    for i in range(self.num_hidden_layers):
-      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
-        self.layers.append(DeepseekV2DecoderLayer(config, i))
-      else:
-        self.layers.append(IdentityBlock())
-
-    if self.args.shard.is_last_layer():
-      self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-  def __call__(
-    self,
-    x: mx.array,
-    cache: Optional[KVCache] = None,
-  ) -> mx.array:
-    if self.args.shard.is_first_layer():
-      h = self.embed_tokens(x)
-    else:
-      h = x
-
-    mask = None
-    T = h.shape[1]
-    if T > 1:
-      mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
-      mask = mask.astype(h.dtype)
-
-    if cache is None:
-      cache = [None]*len(self.layers)
-
-    for layer, c in zip(self.layers, cache):
-      h = layer(h, mask, c)
-
-    if self.args.shard.is_last_layer():
-      h = self.norm(h)
-    return h
-
-
-class Model(nn.Module):
-  def __init__(self, config: ModelArgs):
-    super().__init__()
-    self.args = config
-    self.model_type = config.model_type
-    self.model = DeepseekV2Model(config)
-    if self.args.shard.is_last_layer():
-      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-  def __call__(
-    self,
-    inputs: mx.array,
-    cache: Optional[KVCache] = None,
-  ):
-    out = self.model(inputs, cache)
-    if self.args.shard.is_last_layer():
-      return self.lm_head(out)
-    return out
-
-  def sanitize(self, weights):
-    shard_state_dict = {}
-
-    for key, value in weights.items():
-      if key.startswith('model.layers.'):
-        layer_num = int(key.split('.')[2])
-        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
-          shard_state_dict[key] = value
-      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
-        shard_state_dict[key] = value
-      elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
-        shard_state_dict[key] = value
-
-    for l in range(self.args.num_hidden_layers):
-      prefix = f"model.layers.{l}"
-      for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
-        for k in ["weight", "scales", "biases"]:
-          if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
-            to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)]
-            shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
-
-    return shard_state_dict
-
-  @property
-  def layers(self):
-    return self.model.layers
-
-  @property
-  def head_dim(self):
-    return (
-      self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
-      self.args.v_head_dim,
-    )
-
-  @property
-  def n_kv_heads(self):
-    return self.args.num_key_value_heads

+ 0 - 118
build/lib/exo/inference/mlx/models/gemma2.py

@@ -1,118 +0,0 @@
-from dataclasses import dataclass, field
-
-import mlx.core as mx
-import mlx.nn as nn
-
-from mlx_lm.models.base import create_attention_mask
-from mlx_lm.models.gemma2 import TransformerBlock, ModelArgs, RMSNorm
-
-from ...shard import Shard
-from .base import IdentityBlock
-
-
-@dataclass
-class ModelArgs(ModelArgs):
-  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
-
-  def __post_init__(self):
-    if isinstance(self.shard, Shard):
-      return
-    if not isinstance(self.shard, dict):
-      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
-
-    self.shard = Shard(**self.shard)
-
-
-class GemmaModel(nn.Module):
-  def __init__(self, args: ModelArgs):
-    super().__init__()
-    self.args = args
-    self.vocab_size = args.vocab_size
-    self.num_hidden_layers = args.num_hidden_layers
-    assert self.vocab_size > 0
-    if args.shard.is_first_layer() or args.shard.is_last_layer():
-      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
-    self.layers = []
-    for i in range(self.num_hidden_layers):
-      if args.shard.start_layer <= i <= args.shard.end_layer:
-        self.layers.append(TransformerBlock(args=args))
-      else:
-        self.layers.append(IdentityBlock())
-    if args.shard.is_last_layer():
-      self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-
-  def __call__(
-    self,
-    inputs: mx.array,
-    cache=None,
-  ):
-    if self.args.shard.is_first_layer():
-      h = self.embed_tokens(inputs)
-      h = h * (self.args.hidden_size**0.5)
-    else:
-      h = inputs
-
-    mask = None
-    if h.ndim > 1 and h.shape[1] > 1:
-      mask = create_attention_mask(h, cache)
-
-    if cache is None:
-      cache = [None]*len(self.layers)
-
-    for layer, c in zip(self.layers, cache):
-      h = layer(h, mask, cache=c)
-
-    if self.args.shard.is_last_layer():
-      h = self.norm(h)
-    return h
-
-
-class Model(nn.Module):
-  def __init__(self, args: ModelArgs):
-    super().__init__()
-    self.args = args
-    self.model_type = args.model_type
-    self.model = GemmaModel(args)
-    if args.shard.is_last_layer():
-      self.final_logit_softcapping = args.final_logit_softcapping
-
-  def __call__(
-    self,
-    inputs: mx.array,
-    cache=None,
-  ):
-    out = self.model(inputs, cache)
-    if self.args.shard.is_last_layer():
-      out = self.model.embed_tokens.as_linear(out)
-      out = mx.tanh(out / self.final_logit_softcapping)
-      out = out * self.final_logit_softcapping
-    return out
-
-  def sanitize(self, weights):
-    shard_state_dict = {}
-
-    for key, value in weights.items():
-      if "self_attn.rotary_emb.inv_freq" in key:
-        continue
-      if key.startswith('model.layers.'):
-        layer_num = int(key.split('.')[2])
-        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
-          shard_state_dict[key] = value
-      elif (self.args.shard.is_first_layer() or self.args.shard.is_last_layer()) and key.startswith('model.embed_tokens'):
-        shard_state_dict[key] = value
-      elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
-        shard_state_dict[key] = value
-
-    return shard_state_dict
-
-  @property
-  def layers(self):
-    return self.model.layers
-
-  @property
-  def head_dim(self):
-    return self.args.head_dim
-
-  @property
-  def n_kv_heads(self):
-    return self.args.num_key_value_heads

+ 0 - 125
build/lib/exo/inference/mlx/models/llama.py

@@ -1,125 +0,0 @@
-from dataclasses import dataclass, field
-
-import mlx.core as mx
-import mlx.nn as nn
-
-from mlx_lm.models.base import create_attention_mask
-from mlx_lm.models.llama import TransformerBlock, ModelArgs
-
-from ...shard import Shard
-from .base import IdentityBlock
-
-
-@dataclass
-class ModelArgs(ModelArgs):
-  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
-
-  def __post_init__(self):
-    super().__post_init__()  # Ensure parent initializations are respected
-
-    if isinstance(self.shard, Shard):
-      return
-    if not isinstance(self.shard, dict):
-      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
-
-    self.shard = Shard(**self.shard)
-
-
-class LlamaModel(nn.Module):
-  def __init__(self, args: ModelArgs):
-    super().__init__()
-    self.args = args
-    self.vocab_size = args.vocab_size
-    self.num_hidden_layers = args.num_hidden_layers
-    assert self.vocab_size > 0
-    if args.shard.is_first_layer() or (args.shard.is_last_layer() and args.tie_word_embeddings):
-      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
-    self.layers = []
-    for i in range(self.num_hidden_layers):
-      if args.shard.start_layer <= i <= args.shard.end_layer:
-        self.layers.append(TransformerBlock(args=args))
-      else:
-        self.layers.append(IdentityBlock())
-    if args.shard.is_last_layer():
-      self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-
-  def __call__(
-    self,
-    inputs: mx.array,
-    cache=None,
-  ):
-    if self.args.shard.is_first_layer():
-      h = self.embed_tokens(inputs)
-    else:
-      h = inputs
-
-    mask = None
-    if h.ndim > 1 and h.shape[1] > 1:
-      mask = create_attention_mask(h, cache)
-
-    if cache is None:
-      cache = [None]*len(self.layers)
-
-    for layer, c in zip(self.layers, cache):
-      h = layer(h, mask, cache=c)
-
-    if self.args.shard.is_last_layer():
-      h = self.norm(h)
-    return h
-
-
-class Model(nn.Module):
-  def __init__(self, args: ModelArgs):
-    super().__init__()
-    self.args = args
-    self.model_type = args.model_type
-    self.model = LlamaModel(args)
-    if args.shard.is_last_layer():
-      if not args.tie_word_embeddings:
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
-
-  def __call__(
-    self,
-    inputs: mx.array,
-    cache=None,
-  ):
-    out = self.model(inputs, cache)
-    if self.args.shard.is_last_layer():
-      if self.args.tie_word_embeddings:
-        out = self.model.embed_tokens.as_linear(out)
-      else:
-        out = self.lm_head(out)
-    return out
-
-  def sanitize(self, weights):
-    shard_state_dict = {}
-
-    for key, value in weights.items():
-      if "self_attn.rotary_emb.inv_freq" in key:
-        continue
-      if key.startswith('model.layers.'):
-        layer_num = int(key.split('.')[2])
-        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
-          shard_state_dict[key] = value
-      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
-        shard_state_dict[key] = value
-      elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'):
-        shard_state_dict[key] = value
-      elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'):
-        shard_state_dict[key] = value
-      elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
-        shard_state_dict[key] = value
-
-    return shard_state_dict
-
-  @property
-  def layers(self):
-    return self.model.layers
-
-  @property
-  def head_dim(self):
-    return (self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads)
-
-  @property
-  def n_kv_heads(self):
-    return self.args.num_key_value_heads

+ 0 - 585
build/lib/exo/inference/mlx/models/llava.py

@@ -1,585 +0,0 @@
-# Copyright © 2024 Apple Inc.
-
-import math
-import inspect
-from dataclasses import dataclass, field
-from typing import Optional, Dict, Union
-
-import mlx.core as mx
-import mlx.nn as nn
-from mlx_lm.models.base import BaseModelArgs, KVCache
-from exo.inference.shard import Shard
-from .base import IdentityBlock
-import numpy as np
-
-
-@dataclass
-class VisionConfig:
-  model_type: str
-  num_hidden_layers: int = 24
-  hidden_size: int = 1024
-  intermediate_size: int = 4096
-  num_attention_heads: int = 16
-  image_size: int = 336
-  patch_size: int = 14
-  projection_dim: int = 768
-  vocab_size: int = 32000
-  num_channels: int = 3
-  layer_norm_eps: float = 1e-5
-
-  @classmethod
-  def from_dict(cls, params):
-    return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
-
-
-class VisionAttention(nn.Module):
-  def __init__(
-    self,
-    dims: int,
-    num_heads: int,
-    query_input_dims: Optional[int] = None,
-    key_input_dims: Optional[int] = None,
-    value_input_dims: Optional[int] = None,
-    value_dims: Optional[int] = None,
-    value_output_dims: Optional[int] = None,
-    bias: bool = False,
-  ):
-    super().__init__()
-
-    if (dims % num_heads) != 0:
-      raise ValueError("The input feature dimensions should be divisible by the "
-                       f"number of heads ({dims} % {num_heads}) != 0")
-
-    query_input_dims = query_input_dims or dims
-    key_input_dims = key_input_dims or dims
-    value_input_dims = value_input_dims or key_input_dims
-    value_dims = value_dims or dims
-    value_output_dims = value_output_dims or dims
-
-    self.num_heads = num_heads
-    self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
-    self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
-    self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
-    self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
-
-  def __call__(self, queries, keys, values, mask=None):
-    queries = self.q_proj(queries)
-    keys = self.k_proj(keys)
-    values = self.v_proj(values)
-
-    num_heads = self.num_heads
-    B, L, D = queries.shape
-    _, S, _ = keys.shape
-    queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-    keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
-    values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-    scale = math.sqrt(1/queries.shape[-1])
-    scores = (queries*scale) @ keys
-    if mask is not None:
-      scores = scores + mask.astype(scores.dtype)
-    scores = mx.softmax(scores, axis=-1)
-    values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-    return self.out_proj(values_hat)
-
-
-class VisionMLP(nn.Module):
-  def __init__(self, config: VisionConfig):
-    super().__init__()
-    self.activation_fn = nn.GELU(approx="fast")
-    self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
-    self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
-
-  def __call__(self, x: mx.array) -> mx.array:
-    x = self.activation_fn(self.fc1(x))
-    x = self.fc2(x)
-    return x
-
-
-class VisionEncoderLayer(nn.Module):
-  def __init__(self, config: VisionConfig):
-    super().__init__()
-    self.embed_dim = config.hidden_size
-    self.self_attn = VisionAttention(config.hidden_size, config.num_attention_heads, bias=True)
-    self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-    self.mlp = VisionMLP(config)
-    self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-
-  def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
-    y = self.layer_norm1(x)
-    y = self.self_attn(y, y, y, mask)
-    x = x + y
-    y = self.layer_norm2(x)
-    y = self.mlp(y)
-    return x + y
-
-
-class VisionEncoder(nn.Module):
-  def __init__(self, config: VisionConfig):
-    super().__init__()
-    self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
-
-
-class VisionEmbeddings(nn.Module):
-  def __init__(self, config: VisionConfig):
-    super().__init__()
-    self.config = config
-    self.embed_dim = config.hidden_size
-    self.image_size = config.image_size
-    self.patch_size = config.patch_size
-
-    self.class_embedding = mx.zeros((config.hidden_size,))
-
-    self.patch_embedding = nn.Conv2d(
-      in_channels=config.num_channels,
-      out_channels=self.embed_dim,
-      kernel_size=self.patch_size,
-      stride=self.patch_size,
-      bias=False,
-    )
-
-    self.num_patches = (self.image_size // self.patch_size)**2
-    self.num_positions = self.num_patches + 1
-    self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-
-  def __call__(self, x: mx.array) -> mx.array:
-    batch_size = x.shape[0]
-    patch_embeddings = self.patch_embedding(x)
-    patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
-    embed_dim = patch_embeddings.shape[-1]
-    cls_embeddings = mx.broadcast_to(self.class_embedding, (batch_size, 1, embed_dim))
-    embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
-    embeddings += self.position_embedding.weight
-    return embeddings
-
-
-class ClipVisionModel(nn.Module):
-  def __init__(self, config: VisionConfig):
-    super().__init__()
-    self.embeddings = VisionEmbeddings(config)
-    self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
-    self.encoder = VisionEncoder(config)
-    self.post_layernorm = nn.LayerNorm(config.hidden_size)
-
-  def __call__(
-    self,
-    x: mx.array,
-    output_hidden_states: Optional[bool] = None,
-  ) -> mx.array:
-    x = self.embeddings(x)
-    x = self.pre_layrnorm(x)
-
-    encoder_states = (x,) if output_hidden_states else None
-
-    for l in self.encoder.layers:
-      x = l(x, mask=None)
-      if output_hidden_states:
-        encoder_states = encoder_states + (x,)
-
-    pooler_output = self.post_layernorm(x[:, 0, :])
-    return pooler_output, x, encoder_states
-
-
-class VisionModel(nn.Module):
-  def __init__(self, config: VisionConfig):
-    super().__init__()
-
-    self.model_type = config.model_type
-    if self.model_type != "clip_vision_model":
-      raise ValueError(f"Unsupported model type: {self.model_type}")
-
-    self.vision_model = ClipVisionModel(config)
-
-  def __call__(self, x: mx.array, output_hidden_states: Optional[bool] = None) -> mx.array:
-    return self.vision_model(x, output_hidden_states)
-
-  def sanitize(self, weights):
-    sanitized_weights = {}
-    for k, v in weights.items():
-      if "position_ids" in k:
-        # Remove unused position_ids
-        continue
-      elif "patch_embedding.weight" in k:
-        # PyTorch conv2d weight tensors have shape:
-        #   [out_channels, in_channels, kH, KW]
-        # MLX conv2d expects the weight be of shape:
-        #   [out_channels, kH, KW, in_channels]
-        sanitized_weights[k] = v.transpose(0, 2, 3, 1)
-      else:
-        sanitized_weights[k] = v
-
-    return sanitized_weights
-
-
-@dataclass
-class TextConfig:
-  model_type: str
-  hidden_size: int = 4096
-  num_hidden_layers: int = 32
-  intermediate_size: int = 11008
-  num_attention_heads: int = 32
-  head_dim: int = None
-  rms_norm_eps: float = 1e-6
-  vocab_size: int = 32000
-  num_key_value_heads: int = None
-  rope_theta: float = 10000
-  rope_traditional: bool = False
-  rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-
-  @classmethod
-  def from_dict(cls, params):
-    return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
-
-  def __post_init__(self):
-    if self.num_key_value_heads is None:
-      self.num_key_value_heads = self.num_attention_heads
-
-    if self.head_dim is None:
-      self.head_dim = self.hidden_size // self.num_attention_heads
-
-    if self.model_type is None:
-      self.model_type = "llama"
-
-    if self.rope_scaling:
-      required_keys = {"factor", "type"}
-      if not all(key in self.rope_scaling for key in required_keys):
-        raise ValueError(f"rope_scaling must contain keys {required_keys}")
-
-      if self.rope_scaling["type"] != "linear":
-        raise ValueError("rope_scaling 'type' currently only supports 'linear'")
-
-
-class TextAttention(nn.Module):
-  def __init__(self, config: TextConfig):
-    super().__init__()
-
-    dim = config.hidden_size
-    self.n_heads = n_heads = config.num_attention_heads
-    self.n_kv_heads = n_kv_heads = config.num_key_value_heads
-
-    self.repeats = n_heads // n_kv_heads
-
-    head_dim = config.hidden_size // n_heads
-    self.scale = head_dim**-0.5
-
-    self.q_proj = nn.Linear(dim, n_heads*head_dim, bias=False)
-    self.k_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False)
-    self.v_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False)
-    self.o_proj = nn.Linear(n_heads*head_dim, dim, bias=False)
-
-    rope_scale = (1/config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1)
-    self.rope = nn.RoPE(
-      head_dim,
-      traditional=config.rope_traditional,
-      base=config.rope_theta,
-      scale=rope_scale,
-    )
-
-  def __call__(
-    self,
-    x: mx.array,
-    mask: Optional[mx.array] = None,
-    cache: Optional[KVCache] = None,
-  ) -> mx.array:
-    B, L, D = x.shape
-
-    queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-    # Prepare the queries, keys and values for the attention computation
-    queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
-    keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-    values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-
-    if cache is not None:
-      queries = self.rope(queries, offset=cache.offset)
-      keys = self.rope(keys, offset=cache.offset)
-      keys, values = cache.update_and_fetch(keys, values)
-    else:
-      queries = self.rope(queries)
-      keys = self.rope(keys)
-
-    output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale, mask=mask)
-    output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-    return self.o_proj(output)
-
-
-class TextMLP(nn.Module):
-  def __init__(self, dim, hidden_dim):
-    super().__init__()
-    self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
-    self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
-    self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
-
-  def __call__(self, x) -> mx.array:
-    return self.down_proj(nn.silu(self.gate_proj(x))*self.up_proj(x))
-
-
-class TransformerBlock(nn.Module):
-  def __init__(self, config: TextConfig):
-    super().__init__()
-    self.num_attention_heads = config.num_attention_heads
-    self.hidden_size = config.hidden_size
-    self.self_attn = TextAttention(config)
-    self.mlp = TextMLP(config.hidden_size, config.intermediate_size)
-    self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-    self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-    self.config = config
-
-  def __call__(
-    self,
-    x: mx.array,
-    mask: Optional[mx.array] = None,
-    cache: Optional[KVCache] = None,
-  ) -> mx.array:
-    r = self.self_attn(self.input_layernorm(x), mask, cache)
-    h = x + r
-    r = self.mlp(self.post_attention_layernorm(h))
-    out = h + r
-    return out
-
-
-class Llama(nn.Module):
-  def __init__(self, config: TextConfig, shard: Shard):
-    super().__init__()
-    self.config = config
-    self.shard = shard
-    self.vocab_size = config.vocab_size
-    self.model_type = config.model_type
-    self.num_hidden_layers = config.num_hidden_layers
-    self.num_key_value_heads = config.num_key_value_heads
-    self.head_dim = config.head_dim
-    assert self.vocab_size > 0
-    if self.shard.is_first_layer():
-      self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-    self.layers = []
-    for i in range(self.num_hidden_layers):
-      if self.shard.start_layer <= i <= self.shard.end_layer:
-        self.layers.append(TransformerBlock(config=config))
-      else:
-        self.layers.append(IdentityBlock())
-    if self.shard.is_last_layer():
-      self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-  def __call__(
-    self,
-    inputs: mx.array,
-    cache=None,
-    inputs_embeds=None,
-  ):
-    # for passing merged input embeddings
-    if inputs_embeds is None:
-      if self.shard.is_first_layer():
-        h = self.embed_tokens(inputs)
-      else:
-        h = inputs
-    else:
-      h = inputs_embeds
-
-    mask = None
-    if h.shape[1] > 1:
-      mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
-      mask = mask.astype(h.dtype)
-
-    if cache is None:
-      cache = [None]*len(self.layers)
-
-    for layer, c in zip(self.layers, cache):
-      h = layer(h, mask, c)
-
-    if self.shard.is_last_layer():
-      h = self.norm(h)
-    return h
-
-
-class LanguageModel(nn.Module):
-  def __init__(self, config: TextConfig, shard: Shard):
-    super().__init__()
-    self.model_type = config.model_type
-    if self.model_type != "llama":
-      raise ValueError(f"Model type {self.model_type} not supported. Currently only 'llama' is supported")
-    self.shard = shard
-    self.model = Llama(config, shard)
-    if self.shard.is_last_layer():
-      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-  def __call__(
-    self,
-    inputs: mx.array,
-    cache=None,
-    inputs_embeds=None,
-  ):
-    out = self.model(inputs, cache, inputs_embeds)
-    if self.shard.is_last_layer():
-      out = self.lm_head(out)
-    return out
-
-  def sanitize(self, weights):
-    shard_state_dict = {}
-    for key, value in weights.items():
-      if "self_attn.rotary_emb.inv_freq" in key:
-        continue
-
-      if key.startswith('language_model.model.layers.'):
-        layer_num = int(key.split('.')[3])
-        if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer:
-          continue
-      if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'):
-        continue
-      elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')):
-        continue
-
-      shard_state_dict[key] = value
-
-    return shard_state_dict
-
-
-@dataclass
-class LlaVAConfig(BaseModelArgs):
-  text_config: TextConfig
-  vision_config: VisionConfig = None
-  model_type: str = "llava"
-  ignore_index: int = -100
-  image_token_index: int = 32000
-  vision_feature_select_strategy: str = "default"
-  vision_feature_layer: int = -2
-  vocab_size: int = 32000
-
-  @classmethod
-  def from_dict(cls, params):
-    updated_params = {}
-    class_params = inspect.signature(cls).parameters
-    for k, v in params.items():
-      if k in class_params:
-        if k in ["text_config", "vision_config"]:
-          v = class_params[k].annotation.from_dict(v)
-        updated_params.update({k: v})
-
-    return cls(**updated_params)
-
-
-@dataclass
-class ModelArgs(LlaVAConfig):
-  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
-
-  def __post_init__(self):
-    if isinstance(self.shard, dict):
-      self.shard = Shard(**self.shard)
-
-    if not isinstance(self.shard, Shard):
-      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
-
-    if not self.shard.is_first_layer():
-      self.vision_config = None
-
-
-class LlavaMultiModalProjector(nn.Module):
-  def __init__(self, config: LlaVAConfig):
-    super().__init__()
-    self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
-    self.gelu = nn.GELU()
-    self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
-
-  def __call__(self, x: mx.array) -> mx.array:
-    x = self.linear_1(x)
-    x = self.gelu(x)
-    x = self.linear_2(x)
-    return x
-
-
-class Model(nn.Module):
-  def __init__(self, config: ModelArgs):
-    super().__init__()
-    self.config = config
-    self.model_type = config.model_type
-    if config.vision_config:
-      self.vision_tower = VisionModel(config.vision_config)
-      self.multi_modal_projector = LlavaMultiModalProjector(config)
-      self.vision_feature_layer = config.vision_feature_layer
-      self.vision_feature_select_strategy = config.vision_feature_select_strategy
-    self.language_model = LanguageModel(config.text_config, config.shard)
-
-  def get_input_embeddings(
-    self,
-    input_ids: Optional[mx.array] = None,
-    pixel_values: Optional[mx.array] = None,
-  ):
-    if pixel_values is None:
-      return self.language_model(input_ids)
-
-    # Get the input embeddings from the language model
-    inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-
-    # Get the ouptut hidden states from the vision model
-    *_, hidden_states = self.vision_tower(pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True)
-
-    # Select the hidden states from the desired layer
-    selected_image_feature = hidden_states[self.vision_feature_layer]
-
-    if self.vision_feature_select_strategy == "default":
-      selected_image_feature = selected_image_feature[:, 1:]
-    elif self.vision_feature_select_strategy == "full":
-      selected_image_feature = selected_image_feature
-    else:
-      raise ValueError("Unexpected feature selection strategy: "
-                       f"{self.vision_feature_select_strategy}")
-
-    # Pass image features through the multi-modal projector
-    image_features = self.multi_modal_projector(selected_image_feature)
-
-    # Insert special image tokens in the input_ids
-    final_inputs_embeds = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids)
-    return final_inputs_embeds
-
-  def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids):
-    image_token_index = self.config.image_token_index
-    num_images, num_image_patches, embed_dim = image_features.shape
-
-    # Positions of <image> tokens in input_ids, assuming batch size is 1
-    image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
-
-    if len(image_positions) != num_images:
-      raise ValueError(f"The number of image tokens ({len(image_positions)}) does not "
-                       f" match the number of image inputs ({num_images}).")
-
-    text_segments = []
-    start_idx = 0
-
-    for position in image_positions:
-      text_segments.append(inputs_embeds[:, start_idx:position])
-      start_idx = position + 1
-
-    image_embeddings = mx.split(image_features, image_features.shape[0])
-    final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
-    final_embeddings += [inputs_embeds[:, start_idx:]]
-
-    # Create a final embedding of shape
-    # (1, num_image_patches*num_images + sequence_len, embed_dim)
-    return mx.concatenate(final_embeddings, axis=1)
-
-  def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
-    input_embddings = None
-    if pixel_values is not None:
-      input_embddings = self.get_input_embeddings(input_ids, pixel_values)
-    logits = self.language_model(input_ids, cache=cache, inputs_embeds=input_embddings)
-    return logits
-
-  def sanitize(self, weights):
-    if self.config.vision_config:
-      weights = self.vision_tower.sanitize(weights)
-    else:
-      weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))}
-    weights = self.language_model.sanitize(weights)
-    return weights
-
-  @property
-  def layers(self):
-    return self.language_model.model.layers
-
-  @property
-  def head_dim(self):
-    return (self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads)
-
-  @property
-  def n_kv_heads(self):
-    return self.language_model.model.num_key_value_heads

+ 0 - 128
build/lib/exo/inference/mlx/models/qwen2.py

@@ -1,128 +0,0 @@
-from dataclasses import dataclass, field
-
-import mlx.core as mx
-import mlx.nn as nn
-
-from mlx_lm.models.base import create_attention_mask
-from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs
-
-from ...shard import Shard
-from .base import IdentityBlock
-
-
-@dataclass
-class ModelArgs(ModelArgs):
-  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
-
-  def __post_init__(self):
-    super().__post_init__()  # Ensure parent initializations are respected
-
-    if isinstance(self.shard, Shard):
-      return
-    if not isinstance(self.shard, dict):
-      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
-
-    self.shard = Shard(**self.shard)
-
-
-class Qwen2Model(nn.Module):
-  def __init__(self, args: ModelArgs):
-    super().__init__()
-    self.args = args
-    self.vocab_size = args.vocab_size
-    self.num_hidden_layers = args.num_hidden_layers
-    assert self.vocab_size > 0
-    if self.args.shard.is_first_layer():
-      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
-    self.layers = []
-    for i in range(self.num_hidden_layers):
-      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
-        self.layers.append(TransformerBlock(args=args))
-      else:
-        self.layers.append(IdentityBlock())
-    if self.args.shard.is_last_layer():
-      self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-
-  def __call__(
-    self,
-    inputs: mx.array,
-    cache=None,
-  ):
-    if self.args.shard.is_first_layer():
-      h = self.embed_tokens(inputs)
-    else:
-      h = inputs
-
-    mask = None
-    if h.shape[1] > 1:
-      mask = create_attention_mask(h, cache)
-
-    if cache is None:
-      cache = [None]*len(self.layers)
-
-    for layer, c in zip(self.layers, cache):
-      h = layer(h, mask, c)
-
-    if self.args.shard.is_last_layer():
-      h = self.norm(h)
-    return h
-
-
-class Model(nn.Module):
-  def __init__(self, args: ModelArgs):
-    super().__init__()
-    self.args = args
-    self.model_type = args.model_type
-    self.model = Qwen2Model(args)
-    if self.args.shard.is_last_layer():
-      if not args.tie_word_embeddings:
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
-
-  def __call__(
-    self,
-    inputs: mx.array,
-    cache=None,
-  ):
-    out = self.model(inputs, cache)
-    if self.args.shard.is_last_layer():
-      if self.args.tie_word_embeddings:
-        out = self.model.embed_tokens.as_linear(out)
-      else:
-        out = self.lm_head(out)
-    return out
-
-  def sanitize(self, weights):
-    shard_state_dict = {}
-
-    for key, value in weights.items():
-      if "self_attn.rotary_emb.inv_freq" in key:
-        continue
-      if key.startswith('model.layers.'):
-        layer_num = int(key.split('.')[2])
-        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
-          shard_state_dict[key] = value
-      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
-        shard_state_dict[key] = value
-      elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'):
-        shard_state_dict[key] = value
-      elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'):
-        shard_state_dict[key] = value
-      elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
-        shard_state_dict[key] = value
-
-    if self.args.tie_word_embeddings:
-      shard_state_dict.pop("lm_head.weight", None)
-
-    return shard_state_dict
-
-  @property
-  def layers(self):
-    return self.model.layers
-
-  @property
-  def head_dim(self):
-    return self.args.hidden_size // self.args.num_attention_heads
-
-  @property
-  def n_kv_heads(self):
-    return self.args.num_key_value_heads

+ 0 - 77
build/lib/exo/inference/mlx/sharded_inference_engine.py

@@ -1,77 +0,0 @@
-import numpy as np
-import mlx.core as mx
-import mlx.nn as nn
-from mlx_lm.sample_utils import top_p_sampling
-from ..inference_engine import InferenceEngine
-from .stateful_model import StatefulModel
-from .sharded_utils import load_shard
-from ..shard import Shard
-from typing import Dict, Optional, Tuple
-from exo.download.shard_download import ShardDownloader
-import asyncio
-from concurrent.futures import ThreadPoolExecutor
-
-def sample_logits(
-  logits: mx.array,
-  temp: float = 0.0,
-  top_p: float = 1.0,
-  logit_bias: Optional[Dict[int, float]] = None
-) -> Tuple[mx.array, float]:
-  if logit_bias:
-    indices = mx.array(list(logit_bias.keys()))
-    values = mx.array(list(logit_bias.values()))
-    logits[:, indices] += values
-
-  if temp == 0:
-    token = mx.argmax(logits, axis=-1)
-  else:
-    if top_p > 0 and top_p < 1.0:
-      token = top_p_sampling(logits, top_p, temp)
-    else:
-      token = mx.random.categorical(logits*(1/temp))
-
-  return token
-
-class MLXDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self, shard_downloader: ShardDownloader):
-    self.shard = None
-    self.shard_downloader = shard_downloader
-    self.executor = ThreadPoolExecutor(max_workers=1)
-
-  async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
-    y = mx.array(x)
-    logits = y[:, -1, :]
-    out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
-    return out
-
-  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
-    await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    return np.array(tokens)
-
-  async def decode(self, shard: Shard, tokens) -> str:
-    await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
-    return tokens
-    
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
-    await self.ensure_shard(shard)
-    output_data, inference_state = await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id, inference_state)
-    output_data = np.array(output_data)
-    return output_data, inference_state
-
-  async def ensure_shard(self, shard: Shard):
-    if self.shard == shard:
-      return
-
-    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
-
-    if self.shard != shard:
-      loop = asyncio.get_running_loop()
-
-      def load_shard_wrapper():
-        return asyncio.run(load_shard(model_path, shard))
-
-      model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
-      self.shard = shard
-      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 

+ 0 - 256
build/lib/exo/inference/mlx/sharded_utils.py

@@ -1,256 +0,0 @@
-# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py
-
-import glob
-import importlib
-import json
-import logging
-import asyncio
-import aiohttp
-from functools import partial
-from pathlib import Path
-from typing import Optional, Tuple, Union, List, Callable
-from PIL import Image
-from io import BytesIO
-import base64
-import traceback
-
-import mlx.core as mx
-import mlx.nn as nn
-from transformers import AutoProcessor
-
-from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
-
-from exo import DEBUG
-from exo.inference.tokenizers import resolve_tokenizer
-from ..shard import Shard
-
-
-class ModelNotFoundError(Exception):
-  def __init__(self, message):
-    self.message = message
-    super().__init__(self.message)
-
-
-MODEL_REMAPPING = {
-  "mistral": "llama",  # mistral is compatible with llama
-  "phi-msft": "phixtral",
-}
-
-
-def _get_classes(config: dict):
-  """
-  Retrieve the model and model args classes based on the configuration.
-
-  Args:
-   config (dict): The model configuration.
-
-  Returns:
-   A tuple containing the Model class and the ModelArgs class.
-  """
-  model_type = config["model_type"]
-  model_type = MODEL_REMAPPING.get(model_type, model_type)
-  try:
-    arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
-  except ImportError:
-    msg = f"Model type {model_type} not supported."
-    logging.error(msg)
-    traceback.print_exc()
-    raise ValueError(msg)
-
-  return arch.Model, arch.ModelArgs
-
-
-def load_config(model_path: Path) -> dict:
-  try:
-    config_path = model_path / "config.json"
-    if config_path.exists():
-      with open(config_path, "r") as f:
-        config = json.load(f)
-      return config
-    
-    model_index_path = model_path / "model_index.json"
-    if model_index_path.exists():
-      config = load_model_index(model_path, model_index_path)
-      return config
-  except FileNotFoundError:
-    logging.error(f"Config file not found in {model_path}")
-    raise
-  return config
-
-def load_model_shard(
-  model_path: Path,
-  shard: Shard,
-  lazy: bool = False,
-  model_config: dict = {},
-) -> nn.Module:
-  """
-  Load and initialize the model from a given path.
-
-  Args:
-   model_path (Path): The path to load the model from.
-   lazy (bool): If False eval the model parameters to make sure they are
-    loaded in memory before returning, otherwise they will be loaded
-    when needed. Default: ``False``
-   model_config(dict, optional): Configuration parameters for the model.
-    Defaults to an empty dictionary.
-
-  Returns:
-   nn.Module: The loaded and initialized model.
-
-  Raises:
-   FileNotFoundError: If the weight files (.safetensors) are not found.
-   ValueError: If the model class or args class are not found or cannot be instantiated.
-  """
-  config = load_config(model_path)
-  config.update(model_config)
-
-  # TODO hack
-  config["shard"] = {
-    "model_id": model_path.name,
-    "start_layer": shard.start_layer,
-    "end_layer": shard.end_layer,
-    "n_layers": shard.n_layers,
-  }
-
-  weight_files = glob.glob(str(model_path/"model*.safetensors"))
-
-  if not weight_files:
-    # Try weight for back-compat
-    weight_files = glob.glob(str(model_path/"weight*.safetensors"))
-
-  model_class, model_args_class = _get_classes(config=config)
-
-  class ShardedModel(model_class):
-    def __init__(self, args):
-      super().__init__(args)
-      self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
-
-    def __call__(self, x, *args, **kwargs):
-      y = super().__call__(x, *args, **kwargs)
-      return y
-
-  model_args = model_args_class.from_dict(config)
-  model = ShardedModel(model_args)
-
-  if config.get("model_index", False):
-    model.load()
-    return model
-
-  if not weight_files:
-    logging.error(f"No safetensors found in {model_path}")
-    raise FileNotFoundError(f"No safetensors found in {model_path}")
-
-  weights = {}
-  for wf in sorted(weight_files):
-    if DEBUG >= 8:
-      layer_nums = set()
-      for k in mx.load(wf):
-        if k.startswith("model.layers."):
-          layer_num = int(k.split(".")[2])
-          layer_nums.add(layer_num)
-        if k.startswith("language_model.model.layers."):
-          layer_num = int(k.split(".")[3])
-          layer_nums.add(layer_num)
-      print(f"\"{wf.split('/')[-1]}\": {sorted(layer_nums)},")
-
-    weights.update(mx.load(wf))
-
-  
-
-  if hasattr(model, "sanitize"):
-    weights = model.sanitize(weights)
-
-  if (quantization := config.get("quantization", None)) is not None:
-    # Handle legacy models which may not have everything quantized
-    def class_predicate(p, m):
-      if not hasattr(m, "to_quantized"):
-        return False
-      return f"{p}.scales" in weights
-
-    nn.quantize(
-      model,
-      **quantization,
-      class_predicate=class_predicate,
-    )
-
-  model.load_weights(list(weights.items()), strict=True)
-
-  if not lazy:
-    mx.eval(model.parameters())
-
-  model.eval()
-  return model
-
-async def load_shard(
-  model_path: str,
-  shard: Shard,
-  tokenizer_config={},
-  model_config={},
-  adapter_path: Optional[str] = None,
-  lazy: bool = False,
-) -> Tuple[nn.Module, TokenizerWrapper]:
-  model = load_model_shard(model_path, shard, lazy, model_config)
-
-  # TODO: figure out a generic solution
-  if model.model_type == "llava":
-    processor = AutoProcessor.from_pretrained(model_path)
-    processor.eos_token_id = processor.tokenizer.eos_token_id
-    processor.encode = processor.tokenizer.encode
-    return model, processor
-  elif hasattr(model, "tokenizer"):
-    tokenizer = model.tokenizer
-    return model, tokenizer
-  else:
-    tokenizer = await resolve_tokenizer(model_path)
-    return model, tokenizer
-
-
-async def get_image_from_str(_image_str: str):
-  image_str = _image_str.strip()
-
-  if image_str.startswith("http"):
-    async with aiohttp.ClientSession() as session:
-      async with session.get(image_str, timeout=10) as response:
-        content = await response.read()
-        return Image.open(BytesIO(content)).convert("RGB")
-  elif image_str.startswith("data:image/"):
-    # Extract the image format and base64 data
-    format_prefix, base64_data = image_str.split(";base64,")
-    image_format = format_prefix.split("/")[1].lower()
-    if DEBUG >= 2: print(f"{image_str=} {image_format=}")
-    imgdata = base64.b64decode(base64_data)
-    img = Image.open(BytesIO(imgdata))
-
-    # Convert to RGB if not already
-    if img.mode != "RGB":
-      img = img.convert("RGB")
-
-    return img
-  else:
-    raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
-
-# loading a combined config for all models in the index
-def load_model_index(model_path: Path, model_index_path: Path):
-  models_config = {}
-  with open(model_index_path, "r") as f:
-      model_index = json.load(f)
-  models_config["model_index"] = True
-  models_config["model_type"] = model_index["_class_name"]
-  models_config["models"] = {}
-  for model in model_index.keys():
-    model_config_path = glob.glob(str(model_path / model / "*config.json"))
-    if len(model_config_path)>0:
-      with open(model_config_path[0], "r") as f:
-        model_config = { }
-        model_config["model_type"] = model
-        model_config["config"] = json.load(f)
-        model_config["path"] = model_path / model
-        if model_config["path"]/"*model.safetensors":
-          model_config["config"].update({"weight_files": list(glob.glob(str(model_config["path"]/"*model.safetensors")))})
-        model_config["path"] = str(model_path / model)
-        m = {}
-        m[model] = model_config
-        models_config.update(m)
-  models_config = json.dumps(models_config)
-  models_config = json.loads(models_config)
-  return models_config

+ 0 - 45
build/lib/exo/inference/mlx/stateful_model.py

@@ -1,45 +0,0 @@
-from typing import Dict, Tuple, Optional
-from collections import OrderedDict
-
-import mlx.core as mx
-import mlx.nn as nn
-from mlx_lm.models.cache import make_prompt_cache
-
-from ..shard import Shard
-
-class StatefulModel(nn.Module):
-  def __init__(self, model, max_kv_size: int = 1024, max_caches: int = 2):
-    super().__init__()
-    self.model = model
-    self.max_kv_size = max_kv_size
-    self.max_caches = max_caches
-    self.caches = OrderedDict()
-  
-  def init_cache(self, request_id: str):
-    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
-    # if self.max_kv_size is not None:
-      # cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
-      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
-    # else:
-      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
-    cache = make_prompt_cache(self.model)
-
-    if len(self.caches) >= self.max_caches:
-      self.caches.popitem(last=False)
-
-    self.caches[request_id] = cache
-
-  def __call__(self, x, request_id: str, inference_state: Optional[dict] = None):
-    if self.model.model_type !='StableDiffusionPipeline':
-      if request_id not in self.caches:
-        self.init_cache(request_id)
-      else:
-        self.caches.move_to_end(request_id)
-
-      cache = self.caches[request_id]
-
-      y = self.model(x, cache=cache)
-    else:
-      y, inference_state = self.model(x, **inference_state)
-    return y, inference_state
-    

+ 0 - 40
build/lib/exo/inference/mlx/test_sharded_llama.py

@@ -1,40 +0,0 @@
-import mlx.core as mx
-from exo.inference.mlx.stateful_model import StatefulModel
-from exo.inference.mlx.sharded_utils import load_shard
-from exo.inference.shard import Shard
-
-# 79, 80 for Llama-3-70B
-shard_full = Shard("llama", 0, 31, 32)
-shard1 = Shard("llama", 0, 12, 32)
-shard2 = Shard("llama", 13, 31, 32)
-
-full_model_shard, full_tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard_full)
-model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1)
-model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2)
-
-full = StatefulModel(shard_full, full_model_shard)
-m1 = StatefulModel(shard1, model_shard1)
-m2 = StatefulModel(shard2, model_shard2)
-
-prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
-prompt_tokens = mx.array(full_tokenizer.encode(prompt))
-max_tokens = 50
-
-resp = prompt_tokens
-full_generated_tokens = []
-for _ in range(max_tokens):
-  resp = full.step(resp)
-  full_generated_tokens.append(resp.item())
-
-print("full response: ", full_tokenizer.decode(full_generated_tokens))
-
-sharded_generated_tokens = []
-sharded_resp = prompt_tokens
-for _ in range(max_tokens):
-  resp1 = m1.step(sharded_resp)
-  sharded_resp = m2.step(resp1)
-  sharded_generated_tokens.append(sharded_resp.item())
-
-print("sharded response: ", tokenizer1.decode(sharded_generated_tokens))
-
-assert tokenizer1.decode(full_generated_tokens) == tokenizer1.decode(sharded_generated_tokens)

+ 0 - 64
build/lib/exo/inference/mlx/test_sharded_llava.py

@@ -1,64 +0,0 @@
-import codecs
-import asyncio
-import requests
-from PIL import Image
-from io import BytesIO
-
-import mlx.core as mx
-from mlx_lm.models.cache import KVCache
-
-from exo.inference.mlx.stateful_model import StatefulModel
-from exo.inference.mlx.sharded_utils import load_shard
-from exo.inference.shard import Shard
-
-shard_full = Shard("llava", 0, 31, 32)
-shard1 = Shard("llava", 0, 12, 32)
-shard2 = Shard("llava", 13, 31, 32)
-
-model_path = "llava-hf/llava-1.5-7b-hf"
-
-full_model_shard, full_processor = asyncio.run(load_shard(model_path, shard=shard_full))
-model_shard1, processor1 = asyncio.run(load_shard(model_path, shard=shard1))
-model_shard2, processor2 = asyncio.run(load_shard(model_path, shard=shard2))
-
-full = StatefulShardedModel(shard_full, full_model_shard)
-m1 = StatefulShardedModel(shard1, model_shard1)
-m2 = StatefulShardedModel(shard2, model_shard2)
-
-PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
-IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg"
-response = requests.get(IMAGE_FILE)
-img = Image.open(BytesIO(response.content))
-prompt = codecs.decode(PROMPT, "unicode_escape")
-inputs = full_processor(prompt, img, return_tensors="np")
-pixel_values = mx.array(inputs["pixel_values"])
-input_ids = mx.array(inputs["input_ids"])
-
-print(prompt)
-y = full.step("full", input_ids, pixel_values, temp=0)
-full_generated_tokens = [y.item()]
-
-for _ in range(13):
-  y = full.step("full", y, temp=0)
-  full_generated_tokens.append(y.item())
-
-full_response = full_processor.tokenizer.decode(full_generated_tokens)
-print("full response:", full_response)
-
-inputs = processor1(prompt, img, return_tensors="np")
-pixel_values = mx.array(inputs["pixel_values"])
-input_ids = mx.array(inputs["input_ids"])
-
-y = m1.step("shard", input_ids, pixel_values, temp=0)
-y = m2.step("shard", y, temp=0)
-full_generated_tokens = [y.item()]
-
-for _ in range(13):
-  y = m1.step("shard", y, temp=0)
-  y = m2.step("shard", y, temp=0)
-  full_generated_tokens.append(y.item())
-
-sharded_response = processor2.tokenizer.decode(full_generated_tokens)
-print("sharded response:", sharded_response)
-
-assert full_response == sharded_response

+ 0 - 52
build/lib/exo/inference/mlx/test_sharded_model.py

@@ -1,52 +0,0 @@
-from exo.inference.shard import Shard
-import mlx.core as mx
-import mlx.nn as nn
-from typing import Optional
-import numpy as np
-
-
-class DummyModel(nn.Module):
-  def __init__(self, shard: Optional[Shard] = None):
-    self.shard = shard
-    self.layers = [
-      nn.Linear(8, 128),
-      nn.Linear(128, 128),
-      nn.Linear(128, 128),
-      nn.Linear(128, 128),
-      nn.Linear(128, 8),
-    ]
-
-    self.n_kv_heads = 4
-    self.head_dim = 4
-
-  def __call__(self, x, cache=None):
-    if self.shard:
-      for layer in self.layers[self.shard.start_layer:self.shard.end_layer + 1]:
-        x = layer(x)
-      if self.shard.is_last_layer():
-        x = x.reshape((1, 2, 4))
-    else:
-      for layer in self.layers:
-        x = layer(x)
-      x = x.reshape((1, 2, 4))
-
-    return x
-
-
-model = DummyModel()
-model.save_weights("./test_weights.npz")
-n_layers = 5
-shard1 = Shard("test", 0, n_layers // 2, n_layers)
-sharded_model1 = DummyModel(shard1)
-shard2 = Shard("test", n_layers//2 + 1, n_layers - 1, n_layers)
-sharded_model2 = DummyModel(shard2)
-
-model.load_weights("./test_weights.npz")
-sharded_model1.load_weights("./test_weights.npz")
-sharded_model2.load_weights("./test_weights.npz")
-
-fullresp = model(mx.array([1, 2, 3, 4, 5, 6, 7, 8]))
-resp1 = sharded_model1(mx.array([1, 2, 3, 4, 5, 6, 7, 8]))
-resp2 = sharded_model2(resp1)
-
-assert np.all(np.array(fullresp) == np.array(resp2))

+ 0 - 39
build/lib/exo/inference/shard.py

@@ -1,39 +0,0 @@
-from dataclasses import dataclass, field
-
-
-@dataclass(frozen=True)
-class Shard:
-  model_id: str
-  start_layer: int
-  end_layer: int
-  n_layers: int
-
-  def __hash__(self):
-    return hash((self.model_id, self.start_layer, self.end_layer, self.n_layers))
-
-  def is_first_layer(self) -> bool:
-    return self.start_layer == 0
-
-  def is_last_layer(self) -> bool:
-    return self.end_layer == self.n_layers - 1
-
-  def get_layer_count(self) -> int:
-    return self.end_layer - self.start_layer + 1
-
-  def to_dict(self) -> dict:
-    return {
-      "model_id": self.model_id,
-      "start_layer": self.start_layer,
-      "end_layer": self.end_layer,
-      "n_layers": self.n_layers,
-    }
-
-  def from_dict(data: dict) -> 'Shard':
-    return Shard(**data)
-
-  def overlaps(self, other: 'Shard') -> bool:
-    return shards_overlap(self, other)
-
-
-def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
-  return (shard1.model_id == shard2.model_id and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer))

+ 0 - 53
build/lib/exo/inference/test_dummy_inference_engine.py

@@ -1,53 +0,0 @@
-import pytest
-import json
-import numpy as np
-from exo.inference.dummy_inference_engine import DummyInferenceEngine
-from exo.inference.shard import Shard
-
-
-class MockShardDownloader:
-  async def ensure_shard(self, shard):
-    pass
-
-
-@pytest.mark.asyncio
-async def test_dummy_inference_specific():
-  engine = DummyInferenceEngine(MockShardDownloader())
-  test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
-  test_prompt = "This is a test prompt"
-
-  result = await engine.infer_prompt("test_request", test_shard, test_prompt)
-
-  print(f"Inference result shape: {result.shape}")
-
-  assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
-
-
-@pytest.mark.asyncio
-async def test_dummy_inference_engine():
-  # Initialize the DummyInferenceEngine
-  engine = DummyInferenceEngine(MockShardDownloader())
-
-  # Create a test shard
-  shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
-
-  # Test infer_prompt
-  output = await engine.infer_prompt("test_id", shard, "Test prompt")
-
-  assert isinstance(output, np.ndarray), "Output should be a numpy array"
-  assert output.ndim == 2, "Output should be 2-dimensional"
-
-  # Test infer_tensor
-  input_tensor = np.array([[1, 2, 3]])
-  output = await engine.infer_tensor("test_id", shard, input_tensor)
-
-  assert isinstance(output, np.ndarray), "Output should be a numpy array"
-  assert output.ndim == 2, "Output should be 2-dimensional"
-
-  print("All tests passed!")
-
-
-if __name__ == "__main__":
-  import asyncio
-  asyncio.run(test_dummy_inference_engine())
-  asyncio.run(test_dummy_inference_specific())

+ 0 - 56
build/lib/exo/inference/test_inference_engine.py

@@ -1,56 +0,0 @@
-from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.inference.inference_engine import InferenceEngine
-from exo.inference.shard import Shard
-from exo.helpers import DEBUG
-import os
-import asyncio
-import numpy as np
-
-
-# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
-async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
-  prompt = "In a single word only, what is the last name of the current president of the USA?"
-  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
-  token_full = await inference_engine_1.sample(resp_full)
-  token_full = token_full.reshape(1, -1)
-  next_resp_full = await inference_engine_1.infer_tensor(
-    "A",
-    shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
-    input_data=token_full,
-  )
-
-  pp = n_layers // 2
-  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
-  resp2 = await inference_engine_2.infer_tensor(
-    "B",
-    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
-    input_data=resp1,
-  )
-  tokens2 = await inference_engine_1.sample(resp2)
-  tokens2 = tokens2.reshape(1, -1)
-  resp3 = await inference_engine_1.infer_tensor(
-    "B",
-    shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
-    input_data=tokens2,
-  )
-  resp4 = await inference_engine_2.infer_tensor(
-    "B",
-    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
-    input_data=resp3,
-  )
-
-  assert np.array_equal(resp_full, resp2)
-  assert np.array_equal(next_resp_full, resp4)
-
-
-asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16))
-
-if os.getenv("RUN_TINYGRAD", default="0") == "1":
-  import tinygrad
-  import os
-  from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-  tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-  asyncio.run(
-    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32)
-  )

+ 0 - 0
build/lib/exo/inference/tinygrad/__init__.py


+ 0 - 99
build/lib/exo/inference/tinygrad/inference.py

@@ -1,99 +0,0 @@
-from pathlib import Path
-import json
-import os
-from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16, sample_logits
-from exo.inference.shard import Shard
-from exo.inference.tokenizers import resolve_tokenizer
-from tinygrad.nn.state import load_state_dict
-from tinygrad import Tensor, nn, Context
-from exo.inference.inference_engine import InferenceEngine
-import numpy as np
-from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
-from exo.download.shard_download import ShardDownloader
-from concurrent.futures import ThreadPoolExecutor
-from .stateful_model import StatefulModel
-import asyncio
-
-Tensor.no_grad = True
-# default settings
-TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
-TOP_K = 25
-TOP_P = 0.9
-ALPHA_F = 0.1
-ALPHA_P = 0.0
-MODEL_PARAMS = {
-  "1B": {
-    "args": {
-      "dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
-      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
-    }, "files": 1
-  }, "3B": {
-    "args": {
-      "dim": 3072, "n_heads": 24, "n_kv_heads": 8, "n_layers": 28, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
-      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
-    }, "files": 1
-  }, "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
-  "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
-}
-
-
-def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
-  # build model
-  linear = nn.Linear
-  model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
-
-  # load weights
-  if model_path.is_dir():
-    if (model_path/"model.safetensors.index.json").exists(): weights = load(str(model_path/"model.safetensors.index.json"), shard)
-    elif (model_path/"model.safetensors").exists(): weights = load(str(model_path/"model.safetensors"), shard)
-    else: weights = concat_weights([load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
-  else:
-    weights = load(str(model_path), shard)
-  weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
-  weights = fix_bf16(weights)
-
-  with Context(BEAM=0):
-    # replace weights in model
-    load_state_dict(model, weights, strict=False, consume=False)  # consume=True
-  return model
-
-class TinygradDynamicShardInferenceEngine(InferenceEngine):
-  def __init__(self, shard_downloader: ShardDownloader):
-    self.shard = None
-    self.shard_downloader = shard_downloader
-    self.executor = ThreadPoolExecutor(max_workers=1)
-
-  async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
-    logits = x[:, -1, :]
-    def sample_wrapper():
-      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
-    return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
-
-  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
-    await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)
-  
-  async def decode(self, shard: Shard, tokens) -> str:
-    await self.ensure_shard(shard)
-    return await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
-
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
-    await self.ensure_shard(shard)
-    return await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize().numpy())
-
-  async def ensure_shard(self, shard: Shard):
-    if self.shard == shard:
-      return
-
-    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
-
-    if self.shard != shard:
-      loop = asyncio.get_running_loop()
-      parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
-      model_shard = await loop.run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
-
-      tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
-      self.tokenizer = await resolve_tokenizer(tokenizer_path)
-      self.shard = shard
-      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 

+ 0 - 0
build/lib/exo/inference/tinygrad/models/__init__.py


+ 0 - 282
build/lib/exo/inference/tinygrad/models/llama.py

@@ -1,282 +0,0 @@
-from typing import Tuple, Union, Optional, Dict, Any, List
-from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
-from tinygrad.helpers import getenv
-from collections import OrderedDict
-
-
-# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half, rope_scaling: Optional[Dict[str, float]] = None) -> Tensor:
-  freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
-
-  if rope_scaling:
-    factor = rope_scaling.get('factor', 1.0)
-    low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
-    high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
-    original_max_pos_emb = rope_scaling.get('original_max_position_embeddings', end)
-
-    freqs[:dim // 4] *= low_freq_factor
-    freqs[dim // 4:] = freqs[dim // 4:].contiguous()*high_freq_factor
-    freqs *= (original_max_pos_emb/end)**(1.0/factor)
-
-  freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
-  # TODO: move dtype outside this
-  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
-
-
-# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
-def complex_mult(A, c, d):
-  a, b = A[..., 0:1], A[..., 1:2]
-  ro = a*c - b*d
-  co = a*d + b*c
-  return ro.cat(co, dim=-1)
-
-
-def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
-  assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
-  xq = xq.reshape(*xq.shape[0:-1], -1, 2)
-  xk = xk.reshape(*xk.shape[0:-1], -1, 2)
-  assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5
-  c, d = freqs_cis[..., 0:1], freqs_cis[..., 1:2]
-  xq_out = complex_mult(xq, c, d)
-  xk_out = complex_mult(xk, c, d)
-  return xq_out.flatten(3), xk_out.flatten(3)
-
-
-def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
-  bs, seqlen, n_kv_heads, head_dim = x.shape
-  if n_rep == 1: return x
-  # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
-  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
-
-class Attention:
-  def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
-    self.n_heads = n_heads
-    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
-    self.head_dim = dim // n_heads
-    self.n_rep = self.n_heads // self.n_kv_heads
-    self.max_context = max_context
-
-    self.wq = linear(dim, self.n_heads*self.head_dim, bias=False)
-    self.wk = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
-    self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
-    self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)
-
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None) -> Tensor:
-    if getenv("WQKV"):
-      if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
-      xqkv = x @ self.wqkv.T
-      xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
-    else:
-      xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-
-    xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
-    xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
-    xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
-
-    xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
-    bsz, seqlen, _, _ = xq.shape
-
-    if cache is not None:
-      # update the cache
-      assert xk.dtype == xv.dtype == cache.dtype, f"{xk.dtype=}, {xv.dtype=}, {cache.dtype=}"
-      cache.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
-
-      keys = cache[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
-      values = cache[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
-    else:
-      keys = xk
-      values = xv
-
-    keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
-    xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
-    attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
-    attn = attn.reshape(bsz, seqlen, -1)
-    return self.wo(attn)
-
-
-class FeedForward:
-  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
-    self.w1 = linear(dim, hidden_dim, bias=False)
-    self.w2 = linear(hidden_dim, dim, bias=False)
-    self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit
-
-  def __call__(self, x: Tensor) -> Tensor:
-    return self.w2(self.w1(x).silu()*self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]
-
-
-class TransformerBlock:
-  def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int, linear=nn.Linear, feed_forward=FeedForward):
-    self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
-    self.feed_forward = feed_forward(dim, hidden_dim, linear)
-    self.attention_norm = nn.RMSNorm(dim, norm_eps)
-    self.ffn_norm = nn.RMSNorm(dim, norm_eps)
-
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None):
-    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask, cache=cache)
-    return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
-
-
-# standard openai sampling
-def sample_logits(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
-  assert logits.ndim == 1, "only works on 1d tensors"
-  assert 0 <= p <= 1, "p must be between 0 and 1"
-  assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
-
-  # if temperature is very low just use argmax
-  if temp < 1e-6: return logits.argmax().reshape(1)
-
-  # alpha sampling
-  if af or ap:
-    if not hasattr(sample, "alpha_counter"):
-      setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
-    logits = logits - (sample.alpha_counter*af + (sample.alpha_counter > 0)*ap)
-
-  # replace NaNs with -inf
-  logits = (logits != logits).where(-float("inf"), logits)
-
-  # softmax
-  t = (logits/temp).softmax()
-
-  counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
-  # top k
-  if k:
-    output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
-    for i in range(k):
-      t_argmax = (t.numel() - ((t == (t_max := t.max()))*counter2).max() - 1).cast(dtypes.default_int)
-      output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
-      output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
-      t = (counter == t_argmax).where(0, t)
-
-    # approximate top p
-    # because we are already limited to top k elements we can do top p "without sorting"
-    output_cumsum = output[::-1]._cumsum()[::-1] + t.sum()
-    output = (output_cumsum >= (1 - p))*output
-    output_indices = (output_cumsum >= (1 - p))*output_indices
-
-    # sample
-    output_idx = output.multinomial()
-    output_token = output_indices[output_idx]
-  else:
-    output_token = t.multinomial()
-
-  # increase alpha counter
-  if af or ap:
-    sample.alpha_counter = (counter == output_token).where(sample.alpha_counter + 1, sample.alpha_counter)
-
-  return output_token
-
-
-from exo.inference.shard import Shard
-
-
-class Transformer:
-  def __init__(
-    self,
-    dim: int,
-    hidden_dim: int,
-    n_heads: int,
-    n_layers: int,
-    norm_eps: float,
-    vocab_size,
-    shard: Shard = None,
-    linear=nn.Linear,
-    n_kv_heads=None,
-    rope_theta=10000,
-    max_context=1024,
-    jit=True,
-    feed_forward=FeedForward,
-    rope_scaling: Optional[Dict[str, float]] = None,
-    tie_word_embeddings=False,
-  ):
-    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
-    self.norm = nn.RMSNorm(dim, norm_eps)
-    self.tok_embeddings = nn.Embedding(vocab_size, dim)
-    self.output = nn.Linear(dim, vocab_size, bias=False)
-    if tie_word_embeddings:
-      self.output.weight = self.tok_embeddings.weight
-    self.max_context = max_context
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
-    self.forward_jit = TinyJit(self.forward_base) if jit else None
-    self.shard = shard
-
-  def forward_base(self, x: Tensor, start_pos: Union[Variable, int], cache: Optional[List[Tensor]] = None):
-    seqlen = x.shape[1]
-    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
-    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
-
-    h = x
-
-    if cache is None:
-      cache = [None for _ in range(self.shard.start_layer, self.shard.end_layer + 1)]  
-    for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), cache):
-      layer = self.layers[i]
-      h = layer(h, start_pos, freqs_cis, mask, cache=c)
-
-    if self.shard.is_last_layer():
-      logits = self.output(self.norm(h)).float().realize()
-      return logits
-    else:
-      return h
-
-  def embed(self, inputs: Tensor):
-    if self.shard.is_first_layer():
-      h = self.tok_embeddings(inputs)
-    else:
-      h = inputs
-    return h
-
-  def forward(self, x: Tensor, start_pos: int, cache: Optional[List[Tensor]] = None):
-    if x.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
-      return self.forward_jit(x, Variable("start_pos", 1, self.max_context).bind(start_pos), cache=cache)
-    return self.forward_base(x, start_pos, cache=cache)
-
-  def __call__(self, tokens: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
-    # TODO: better way to handle the first call v.s. the rest?
-    h = self.embed(x)
-    return self.forward(h, start_pos, cache=cache)
-
-
-# *** helpers ***
-
-
-def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
-  def permute(v: Tensor, n_heads: int):
-    return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
-
-  keymap = {
-    "model.embed_tokens.weight": "tok_embeddings.weight",
-    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
-       for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
-       for x in ["q", "k", "v", "o"]
-       for l in range(len(model.layers))},
-    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
-       for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
-       for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
-       for l in range(len(model.layers))},
-    "model.norm.weight": "norm.weight",
-    "lm_head.weight": "output.weight",
-  }
-  sd = {}
-  for k, v in weights.items():
-    if ".rotary_emb." in k: continue
-    v = v.to(Device.DEFAULT)
-    if "model.layers" in k:
-      if "q_proj" in k:
-        v = permute(v, n_heads)
-      elif "k_proj" in k:
-        v = permute(v, n_kv_heads)
-    if k in keymap:
-      sd[keymap[k]] = v
-    else:
-      sd[k] = v
-  return sd
-
-
-def fix_bf16(weights: Dict[Any, Tensor]):
-  if getenv("SUPPORT_BF16", 1):
-    # TODO: without casting to float16, 70B llama OOM on tinybox.
-    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
-  # TODO: check if device supports bf16
-  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}

+ 0 - 42
build/lib/exo/inference/tinygrad/stateful_model.py

@@ -1,42 +0,0 @@
-from tinygrad import Tensor, Variable 
-from collections import OrderedDict
-from typing import List
-
-def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
-  cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
-  if isinstance(x.device, tuple):
-    # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-    cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
-  return cache_kv.realize()
-
-class ModelState:
-  cache: List[Tensor]
-  start: int 
-  def __init__(self, cache: List[Tensor], start: int = 0):
-    self.cache = cache
-    self.start = start
-
-class StatefulModel:
-  def __init__(self, model, max_states: int = 2):
-    super().__init__()
-    self.model = model
-    self.max_states = max_states
-    self.states = OrderedDict()
- 
-  def init_cache(self, x: Tensor, request_id: str):
-    cache = [create_kv_cache(x, self.model.layers[i].attention.max_context, self.model.layers[i].attention.n_kv_heads, self.model.layers[i].attention.head_dim) for i in range(self.model.shard.start_layer, self.model.shard.end_layer + 1)]
-    if len(self.states) >= self.max_states:
-      self.states.popitem(last=False)
-
-    self.states[request_id] = ModelState(cache)
-
-  def __call__(self, x: Tensor, request_id: str): 
-    h = self.model.embed(x)
-    if request_id not in self.states:
-      self.init_cache(h, request_id)
-    else:
-      self.states.move_to_end(request_id)
-    out = self.model.forward(h, self.states[request_id].start, cache=self.states[request_id].cache)
-    self.states[request_id].start += h.shape[1]
-    return out
-

+ 0 - 52
build/lib/exo/inference/tinygrad/tinygrad_helpers.py

@@ -1,52 +0,0 @@
-from tinygrad.nn.state import safe_load, torch_load
-from tinygrad import Tensor
-from pathlib import Path
-import json
-from typing import List
-from exo.inference.shard import Shard
-from exo.helpers import DEBUG
-from exo.download.hf.hf_helpers import get_allow_patterns
-from fnmatch import fnmatch
-import re
-
-
-# **** helper functions ****
-def concat_weights(models, device=None):
-  def convert(name) -> Tensor:
-    disk_tensors: List[Tensor] = [model[name] for model in models]
-    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
-      return disk_tensors[0].to(device=device)
-    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
-    lazy_tensors = [data.to(device=device) for data in disk_tensors]
-    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
-
-  return {name: convert(name) for name in {name: None for model in models for name in model}}
-
-
-def load(fn: str, shard: Shard):
-  if fn.endswith('.index.json'):
-    with open(fn) as fp:
-      weight_map = json.load(fp)['weight_map']
-    parts = {}
-    filtered_weight_map = {}
-    allow_patterns = get_allow_patterns(weight_map, shard)
-    for k, n in weight_map.items():
-      if allow_patterns is not None and not any(fnmatch(n, r) for r in allow_patterns):
-        continue
-      if k.startswith("model.layers."):
-        layer_num = int(k.split('.')[2])
-        if layer_num < shard.start_layer or layer_num > shard.end_layer:
-          continue
-
-      parts[n] = load(str(Path(fn).parent/Path(n).name), shard)
-      filtered_weight_map[k] = n
-    if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
-    return {k: parts[n][k] for k, n in filtered_weight_map.items()}
-  elif fn.endswith(".safetensors"):
-    weight_map = safe_load(fn)
-    for k in list(weight_map):
-      if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
-          del weight_map[k]
-    return weight_map
-  else:
-    return torch_load(fn)

+ 0 - 64
build/lib/exo/inference/tokenizers.py

@@ -1,64 +0,0 @@
-import traceback
-from aiofiles import os as aios
-from os import PathLike
-from pathlib import Path
-from typing import Union
-from transformers import AutoTokenizer, AutoProcessor
-import numpy as np
-from exo.download.hf.hf_helpers import get_local_snapshot_dir
-from exo.helpers import DEBUG
-
-
-class DummyTokenizer:
-  def __init__(self):
-    self.eos_token_id = 69
-    self.vocab_size = 1000
-
-  def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
-    return "dummy_tokenized_prompt"
-
-  def encode(self, text):
-    return np.array([1])
-
-  def decode(self, tokens):
-    return "dummy" * len(tokens)
-
-
-async def resolve_tokenizer(model_id: str):
-  if model_id == "dummy":
-    return DummyTokenizer()
-  local_path = await get_local_snapshot_dir(model_id)
-  if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
-  try:
-    if local_path and await aios.path.exists(local_path):
-      if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
-      return await _resolve_tokenizer(local_path)
-  except:
-    if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
-    if DEBUG >= 5: traceback.print_exc()
-  return await _resolve_tokenizer(model_id)
-
-
-async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
-  try:
-    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
-    processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True)
-    if not hasattr(processor, 'eos_token_id'):
-      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
-    if not hasattr(processor, 'encode'):
-      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
-    if not hasattr(processor, 'decode'):
-      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
-    return processor
-  except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
-    if DEBUG >= 4: print(traceback.format_exc())
-
-  try:
-    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
-    return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
-  except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
-    if DEBUG >= 4: print(traceback.format_exc())
-
-  raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")

+ 0 - 274
build/lib/exo/main.py

@@ -1,274 +0,0 @@
-import argparse
-import asyncio
-import atexit
-import signal
-import json
-import logging
-import platform
-import os
-import sys
-import time
-import traceback
-import uuid
-from exo.networking.manual.manual_discovery import ManualDiscovery
-from exo.networking.manual.network_topology_config import NetworkTopology
-from exo.orchestration.standard_node import StandardNode
-from exo.networking.grpc.grpc_server import GRPCServer
-from exo.networking.udp.udp_discovery import UDPDiscovery
-from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
-from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
-from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
-from exo.api import ChatGPTAPI
-from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
-from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
-from exo.inference.shard import Shard
-from exo.inference.inference_engine import get_inference_engine, InferenceEngine
-from exo.inference.tokenizers import resolve_tokenizer
-from exo.orchestration.node import Node
-from exo.models import build_base_shard, get_repo
-from exo.viz.topology_viz import TopologyViz
-from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
-
-# parse args
-parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
-parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
-parser.add_argument("model_name", nargs="?", help="Model name to run")
-parser.add_argument("--default-model", type=str, default=None, help="Default model")
-parser.add_argument("--node-id", type=str, default=None, help="Node ID")
-parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
-parser.add_argument("--node-port", type=int, default=None, help="Node port")
-parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
-parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
-parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
-parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
-parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
-parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
-parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
-parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
-parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
-parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
-parser.add_argument("--chatgpt-api-port", type=int, default=52415, help="ChatGPT API port")
-parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
-parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
-parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
-parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
-parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
-parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
-parser.add_argument("--default-temp", type=float, help="Default token sampling temperature", default=0.0)
-parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
-parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
-parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
-args = parser.parse_args()
-print(f"Selected inference engine: {args.inference_engine}")
-
-print_yellow_exo()
-
-system_info = get_system_info()
-print(f"Detected system: {system_info}")
-
-shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
-                                                      max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
-inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
-print(f"Inference engine name after selection: {inference_engine_name}")
-
-inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
-print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
-
-if args.node_port is None:
-  args.node_port = find_available_port(args.node_host)
-  if DEBUG >= 1: print(f"Using available port: {args.node_port}")
-
-args.node_id = args.node_id or get_or_create_node_id()
-chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()]
-web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()]
-if DEBUG >= 0:
-  print("Chat interface started:")
-  for web_chat_url in web_chat_urls:
-    print(f" - {terminal_link(web_chat_url)}")
-  print("ChatGPT API endpoint served at:")
-  for chatgpt_api_endpoint in chatgpt_api_endpoints:
-    print(f" - {terminal_link(chatgpt_api_endpoint)}")
-
-# Convert node-id-filter to list if provided
-allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None
-
-if args.discovery_module == "udp":
-  discovery = UDPDiscovery(
-    args.node_id,
-    args.node_port,
-    args.listen_port,
-    args.broadcast_port,
-    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
-    discovery_timeout=args.discovery_timeout,
-    allowed_node_ids=allowed_node_ids
-  )
-elif args.discovery_module == "tailscale":
-  discovery = TailscaleDiscovery(
-    args.node_id,
-    args.node_port,
-    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
-    discovery_timeout=args.discovery_timeout,
-    tailscale_api_key=args.tailscale_api_key,
-    tailnet=args.tailnet_name,
-    allowed_node_ids=allowed_node_ids
-  )
-elif args.discovery_module == "manual":
-  if not args.discovery_config_path:
-    raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
-  discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
-topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
-node = StandardNode(
-  args.node_id,
-  None,
-  inference_engine,
-  discovery,
-  partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
-  max_generate_tokens=args.max_generate_tokens,
-  topology_viz=topology_viz,
-  shard_downloader=shard_downloader,
-  default_sample_temperature=args.default_temp
-)
-server = GRPCServer(node, args.node_host, args.node_port)
-node.server = server
-api = ChatGPTAPI(
-  node,
-  inference_engine.__class__.__name__,
-  response_timeout=args.chatgpt_api_response_timeout,
-  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
-  default_model=args.default_model
-)
-node.on_token.register("update_topology_viz").on_next(
-  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") and inference_engine.shard.model_id != 'stable-diffusion-2-1-base' else None
-)
-
-def preemptively_start_download(request_id: str, opaque_status: str):
-  try:
-    status = json.loads(opaque_status)
-    if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
-      current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
-      if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-      asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
-  except Exception as e:
-    if DEBUG >= 2:
-      print(f"Failed to preemptively start download: {e}")
-      traceback.print_exc()
-
-
-node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
-
-if args.prometheus_client_port:
-  from exo.stats.metrics import start_metrics_server
-  start_metrics_server(node, args.prometheus_client_port)
-
-last_broadcast_time = 0
-
-
-def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
-  global last_broadcast_time
-  current_time = time.time()
-  if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
-    last_broadcast_time = current_time
-    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
-
-
-shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
-
-async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
-  inference_class = inference_engine.__class__.__name__
-  shard = build_base_shard(model_name, inference_class)
-  if not shard:
-    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
-    return
-  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
-  request_id = str(uuid.uuid4())
-  callback_id = f"cli-wait-response-{request_id}"
-  callback = node.on_token.register(callback_id)
-  if topology_viz:
-    topology_viz.update_prompt(request_id, prompt)
-  prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
-
-  try:
-    print(f"Processing prompt: {prompt}")
-    await node.process_prompt(shard, prompt, request_id=request_id)
-
-    _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
-
-    print("\nGenerated response:")
-    print(tokenizer.decode(tokens))
-  except Exception as e:
-    print(f"Error processing prompt: {str(e)}")
-    traceback.print_exc()
-  finally:
-    node.on_token.deregister(callback_id)
-
-def clean_path(path):
-    """Clean and resolve path"""
-    if path.startswith("Optional("):
-        path = path.strip('Optional("').rstrip('")')
-    return os.path.expanduser(path)
-
-async def main():
-  loop = asyncio.get_running_loop()
-
-  # Check HuggingFace directory permissions
-  hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
-  if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
-  print(f"{has_read=}, {has_write=}")
-  if not has_read or not has_write:
-    print(f"""
-          WARNING: Limited permissions for model storage directory: {hf_home}.
-          This may prevent model downloads from working correctly.
-          {"❌ No read access" if not has_read else ""}
-          {"❌ No write access" if not has_write else ""}
-          """)
-    
-  if not args.models_seed_dir is None:
-    try:
-      models_seed_dir = clean_path(args.models_seed_dir)
-      await move_models_to_hf(models_seed_dir)
-    except Exception as e:
-      print(f"Error moving models to .cache/huggingface: {e}")
-
-  def restore_cursor():
-    if platform.system() != "Windows":
-        os.system("tput cnorm")  # Show cursor
-
-  # Restore the cursor when the program exits
-  atexit.register(restore_cursor)
-
-  # Use a more direct approach to handle signals
-  def handle_exit():
-    asyncio.ensure_future(shutdown(signal.SIGTERM, loop, node.server))
-
-  if platform.system() != "Windows":
-    for s in [signal.SIGINT, signal.SIGTERM]:
-      loop.add_signal_handler(s, handle_exit)
-
-  await node.start(wait_for_peers=args.wait_for_peers)
-
-  if args.command == "run" or args.run_model:
-    model_name = args.model_name or args.run_model
-    if not model_name:
-      print("Error: Model name is required when using 'run' command or --run-model")
-      return
-    await run_model_cli(node, inference_engine, model_name, args.prompt)
-  else:
-    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
-    await asyncio.Event().wait()
-
-
-def run():
-  loop = asyncio.new_event_loop()
-  asyncio.set_event_loop(loop)
-  try:
-    loop.run_until_complete(main())
-  except KeyboardInterrupt:
-    print("Received keyboard interrupt. Shutting down...")
-  finally:
-    loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
-    loop.close()
-
-
-if __name__ == "__main__":
-  run()

+ 0 - 151
build/lib/exo/models.py

@@ -1,151 +0,0 @@
-from exo.inference.shard import Shard
-from typing import Optional, List
-
-model_cards = {
-  ### llama
-  "llama-3.2-1b": {
-    "layers": 16,
-    "repo": {
-      "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
-      "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
-    },
-  },
-  "llama-3.2-3b": {
-    "layers": 28,
-    "repo": {
-       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
-       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
-    },
-  },
-  "llama-3.1-8b": {
-    "layers": 32,
-    "repo": {
-       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
-       "TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated",
-    },
-  },
-  "llama-3.1-70b": {
-    "layers": 80,
-    "repo": {
-       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
-       "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
-    },
-  },
-  "llama-3.1-70b-bf16": {
-    "layers": 80,
-    "repo": {
-       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED",
-       "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
-    },
-  },
-  "llama-3-8b": {
-    "layers": 32,
-    "repo": {
-       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-       "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
-    },
-  },
-  "llama-3-70b": {
-    "layers": 80,
-    "repo": {
-       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-70B-Instruct-4bit",
-       "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R",
-    },
-  },
-  "llama-3.1-405b": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-4bit", }, },
-  "llama-3.1-405b-8bit": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", }, },
-  ### mistral
-  "mistral-nemo": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Nemo-Instruct-2407-4bit", }, },
-  "mistral-large": { "layers": 88, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Large-Instruct-2407-4bit", }, },
-  ### deepseek
-  "deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
-  "deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
-  ### llava
-  "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
-  ### qwen
-  "qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
-  "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
-  "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
-  "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
-  "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
-  "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
-  "qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
-  "qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
-  "qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
-  "qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
-  "qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
-  ### nemotron
-  "nemotron-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", }, },
-  "nemotron-70b-bf16": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", }, },
-  # gemma
-  "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
-  "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
-  # stable diffusion
-  "stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
-  # dummy
-  "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
-}
-
-pretty_name = {
-  "llama-3.2-1b": "Llama 3.2 1B",
-  "llama-3.2-3b": "Llama 3.2 3B",
-  "llama-3.1-8b": "Llama 3.1 8B",
-  "llama-3.1-70b": "Llama 3.1 70B",
-  "llama-3.1-70b-bf16": "Llama 3.1 70B (BF16)",
-  "llama-3.1-405b": "Llama 3.1 405B",
-  "llama-3.1-405b-8bit": "Llama 3.1 405B (8-bit)",
-  "gemma2-9b": "Gemma2 9B",
-  "gemma2-27b": "Gemma2 27B",
-  "nemotron-70b": "Nemotron 70B",
-  "nemotron-70b-bf16": "Nemotron 70B (BF16)",
-  "mistral-nemo": "Mistral Nemo",
-  "mistral-large": "Mistral Large",
-  "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
-  "deepseek-coder-v2.5": "Deepseek Coder V2.5",
-  "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
-  "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
-  "qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
-  "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
-  "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
-  "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
-  "qwen-2.5-7b": "Qwen 2.5 7B",
-  "qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
-  "qwen-2.5-14b": "Qwen 2.5 14B",
-  "qwen-2.5-72b": "Qwen 2.5 72B",
-  "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
-  "llama-3-8b": "Llama 3 8B",
-  "llama-3-70b": "Llama 3 70B",
-  "stable-diffusion-2-1-base": "Stable Diffusion 2.1",
-}
-
-def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
-  return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
-
-def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
-  repo = get_repo(model_id, inference_engine_classname)
-  n_layers = model_cards.get(model_id, {}).get("layers", 0)
-  if repo is None or n_layers < 1:
-    return None
-  return Shard(model_id, 0, 0, n_layers)
-
-def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
-  if not supported_inference_engine_lists:
-    return list(model_cards.keys())
-
-  from exo.inference.inference_engine import inference_engine_classes
-  supported_inference_engine_lists = [
-    [inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list]
-    for engine_list in supported_inference_engine_lists
-  ]
-
-  def has_any_engine(model_info: dict, engine_list: List[str]) -> bool:
-    return any(engine in model_info.get("repo", {}) for engine in engine_list)
-
-  def supports_all_engine_lists(model_info: dict) -> bool:
-    return all(has_any_engine(model_info, engine_list)
-              for engine_list in supported_inference_engine_lists)
-
-  return [
-    model_id for model_id, model_info in model_cards.items()
-    if supports_all_engine_lists(model_info)
-  ]

+ 0 - 5
build/lib/exo/networking/__init__.py

@@ -1,5 +0,0 @@
-from .discovery import Discovery
-from .peer_handle import PeerHandle
-from .server import Server
-
-__all__ = ["Discovery", "PeerHandle", "Server"]

+ 0 - 17
build/lib/exo/networking/discovery.py

@@ -1,17 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import List
-from .peer_handle import PeerHandle
-
-
-class Discovery(ABC):
-  @abstractmethod
-  async def start(self) -> None:
-    pass
-
-  @abstractmethod
-  async def stop(self) -> None:
-    pass
-
-  @abstractmethod
-  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-    pass

+ 0 - 0
build/lib/exo/networking/grpc/__init__.py


+ 0 - 173
build/lib/exo/networking/grpc/grpc_peer_handle.py

@@ -1,173 +0,0 @@
-import grpc
-import numpy as np
-import asyncio
-from typing import Optional, Tuple, List
-
-from . import node_service_pb2
-from . import node_service_pb2_grpc
-
-from ..peer_handle import PeerHandle
-from exo.inference.shard import Shard
-from exo.topology.topology import Topology
-from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
-from exo.helpers import DEBUG
-import json
-import mlx.core as mx
-
-class GRPCPeerHandle(PeerHandle):
-  def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
-    self._id = _id
-    self.address = address
-    self._device_capabilities = device_capabilities
-    self.channel = None
-    self.stub = None
-
-  def id(self) -> str:
-    return self._id
-
-  def addr(self) -> str:
-    return self.address
-
-  def device_capabilities(self) -> DeviceCapabilities:
-    return self._device_capabilities
-
-  async def connect(self):
-    if self.channel is None:
-      self.channel = grpc.aio.insecure_channel(self.address, options=[
-        ("grpc.max_metadata_size", 32*1024*1024),
-        ('grpc.max_receive_message_length', 32*1024*1024),
-        ('grpc.max_send_message_length', 32*1024*1024)
-      ])
-      self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
-    await self.channel.channel_ready()
-
-  async def is_connected(self) -> bool:
-    return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
-
-  async def disconnect(self):
-    if self.channel:
-      await self.channel.close()
-    self.channel = None
-    self.stub = None
-
-  async def _ensure_connected(self):
-    if not await self.is_connected(): await asyncio.wait_for(self.connect(), timeout=5)
-
-  async def health_check(self) -> bool:
-    try:
-      await self._ensure_connected()
-      request = node_service_pb2.HealthCheckRequest()
-      response = await asyncio.wait_for(self.stub.HealthCheck(request), timeout=5)
-      return response.is_healthy
-    except asyncio.TimeoutError:
-      return False
-    except Exception:
-      if DEBUG >= 4:
-        print(f"Health check failed for {self._id}@{self.address}.")
-        import traceback
-        traceback.print_exc()
-      return False
-
-  async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
-    request = node_service_pb2.PromptRequest(
-      prompt=prompt,
-      shard=node_service_pb2.Shard(
-        model_id=shard.model_id,
-        start_layer=shard.start_layer,
-        end_layer=shard.end_layer,
-        n_layers=shard.n_layers,
-      ),
-      request_id=request_id,
-      inference_state=self.serialize_inference_state(inference_state)
-    )
-    response = await self.stub.SendPrompt(request)
-
-    if not response.tensor_data or not response.shape or not response.dtype:
-      return None
-
-    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
-
-  async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
-    request = node_service_pb2.TensorRequest(
-      shard=node_service_pb2.Shard(
-        model_id=shard.model_id,
-        start_layer=shard.start_layer,
-        end_layer=shard.end_layer,
-        n_layers=shard.n_layers,
-      ),
-      tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
-      request_id=request_id,
-      inference_state=self.serialize_inference_state(inference_state)
-    )
-    response = await self.stub.SendTensor(request)
-
-    if not response.tensor_data or not response.shape or not response.dtype:
-      return None
-
-    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
-
-  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-    request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
-    response = await self.stub.GetInferenceResult(request)
-    if response.tensor is None:
-      return None, response.is_finished
-    return (
-      np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape),
-      response.is_finished,
-    )
-
-  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
-    request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
-    response = await self.stub.CollectTopology(request)
-    topology = Topology()
-    for node_id, capabilities in response.nodes.items():
-      device_capabilities = DeviceCapabilities(
-        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
-      )
-      topology.update_node(node_id, device_capabilities)
-    for node_id, peers in response.peer_graph.items():
-      for peer_id in peers.peer_ids:
-        topology.add_edge(node_id, peer_id)
-    return topology
-
-  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-    tensor = None
-    if isinstance(result, np.ndarray):
-      tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
-      result = []
-    request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished)
-    await self.stub.SendResult(request)
-
-  async def send_opaque_status(self, request_id: str, status: str) -> None:
-    request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
-    await self.stub.SendOpaqueStatus(request)
-
-  def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
-    proto_inference_state = node_service_pb2.InferenceState()
-    other_data = {}
-    for k, v in inference_state.items():
-        if isinstance(v, mx.array):
-            np_array = np.array(v)
-            tensor_data = node_service_pb2.Tensor(
-                tensor_data=np_array.tobytes(),
-                shape=list(np_array.shape),
-                dtype=str(np_array.dtype)
-            )
-            proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
-        elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
-            tensor_list = node_service_pb2.TensorList()
-            for tensor in v:
-                np_array = np.array(tensor)
-                tensor_data = node_service_pb2.Tensor(
-                    tensor_data=np_array.tobytes(),
-                    shape=list(np_array.shape),
-                    dtype=str(np_array.dtype)
-                )
-                tensor_list.tensors.append(tensor_data)
-            proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
-        else:
-            # For non-tensor data, we'll still use JSON
-            other_data[k] = v
-    if other_data:
-      proto_inference_state.other_data_json = json.dumps(other_data)
-    return proto_inference_state

+ 0 - 147
build/lib/exo/networking/grpc/grpc_server.py

@@ -1,147 +0,0 @@
-import grpc
-from concurrent import futures
-import numpy as np
-from asyncio import CancelledError
-
-from . import node_service_pb2
-from . import node_service_pb2_grpc
-from exo import DEBUG
-from exo.inference.shard import Shard
-from exo.orchestration import Node
-import json
-import mlx.core as mx
-
-
-class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
-  def __init__(self, node: Node, host: str, port: int):
-    self.node = node
-    self.host = host
-    self.port = port
-    self.server = None
-
-  async def start(self) -> None:
-    self.server = grpc.aio.server(
-      futures.ThreadPoolExecutor(max_workers=10),
-      options=[
-        ("grpc.max_metadata_size", 32*1024*1024),
-        ("grpc.max_send_message_length", 128*1024*1024),
-        ("grpc.max_receive_message_length", 128*1024*1024),
-      ],
-    )
-    node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
-    listen_addr = f"{self.host}:{self.port}"
-    self.server.add_insecure_port(listen_addr)
-    await self.server.start()
-    if DEBUG >= 1: print(f"Server started, listening on {listen_addr}")
-
-  async def stop(self) -> None:
-    if self.server:
-      try:
-        await self.server.stop(grace=5)
-        await self.server.wait_for_termination()
-      except CancelledError:
-        pass
-      if DEBUG >= 1: print("Server stopped and all connections are closed")
-
-  async def SendPrompt(self, request, context):
-    shard = Shard(
-      model_id=request.shard.model_id,
-      start_layer=request.shard.start_layer,
-      end_layer=request.shard.end_layer,
-      n_layers=request.shard.n_layers,
-    )
-    prompt = request.prompt
-    request_id = request.request_id
-    inference_state = self.deserialize_inference_state(request.inference_state)
-    result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
-    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
-    tensor_data = result.tobytes() if result is not None else None
-    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
-
-  async def SendTensor(self, request, context):
-    shard = Shard(
-      model_id=request.shard.model_id,
-      start_layer=request.shard.start_layer,
-      end_layer=request.shard.end_layer,
-      n_layers=request.shard.n_layers,
-    )
-    tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
-    request_id = request.request_id
-
-    inference_state = self.deserialize_inference_state(request.inference_state)
-
-    result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
-    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
-    tensor_data = result.tobytes() if result is not None else None
-    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
-
-  async def GetInferenceResult(self, request, context):
-    request_id = request.request_id
-    result = await self.node.get_inference_result(request_id)
-    if DEBUG >= 5: print(f"GetInferenceResult {request_id=}: {result}")
-    tensor_data = result[0].tobytes() if result[0] is not None else None
-    return (
-      node_service_pb2.InferenceResult(
-        tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)),
-        is_finished=result[1],
-      ) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1])
-    )
-
-  async def CollectTopology(self, request, context):
-    max_depth = request.max_depth
-    visited = set(request.visited)
-    topology = await self.node.collect_topology(visited, max_depth)
-    nodes = {
-      node_id:
-        node_service_pb2.DeviceCapabilities(
-          model=cap.model,
-          chip=cap.chip,
-          memory=cap.memory,
-          flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
-        )
-      for node_id, cap in topology.nodes.items()
-    }
-    peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
-    if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
-    return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
-
-  async def SendResult(self, request, context):
-    request_id = request.request_id
-    result = request.result
-    is_finished = request.is_finished
-    img = request.tensor
-    if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
-    result = list(result)
-    if len(img.tensor_data) > 0:
-      result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
-    self.node.on_token.trigger_all(request_id, result, is_finished)
-    return node_service_pb2.Empty()
-
-  async def SendOpaqueStatus(self, request, context):
-    request_id = request.request_id
-    status = request.status
-    if DEBUG >= 8: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
-    self.node.on_opaque_status.trigger_all(request_id, status)
-    return node_service_pb2.Empty()
-
-  async def HealthCheck(self, request, context):
-    return node_service_pb2.HealthCheckResponse(is_healthy=True)
-
-  def deserialize_inference_state(self,inference_state_proto: node_service_pb2.InferenceState) -> dict:
-    inference_state = {}
-    
-    for k, tensor_data in inference_state_proto.tensor_data.items():
-        np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
-        inference_state[k] = mx.array(np_array)
-    
-    for k, tensor_list in inference_state_proto.tensor_list_data.items():
-        inference_state[k] = [
-            mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape))
-            for tensor in tensor_list.tensors
-        ]
-    
-    if inference_state_proto.other_data_json:
-        other_data = json.loads(inference_state_proto.other_data_json)
-        inference_state.update(other_data)
-    
-    return inference_state

文件差異過大導致無法顯示
+ 0 - 16
build/lib/exo/networking/grpc/node_service_pb2.py


+ 0 - 360
build/lib/exo/networking/grpc/node_service_pb2_grpc.py

@@ -1,360 +0,0 @@
-# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
-"""Client and server classes corresponding to protobuf-defined services."""
-import grpc
-import warnings
-
-from exo.networking.grpc import node_service_pb2 as node__service__pb2
-
-GRPC_GENERATED_VERSION = '1.64.1'
-GRPC_VERSION = grpc.__version__
-EXPECTED_ERROR_RELEASE = '1.65.0'
-SCHEDULED_RELEASE_DATE = 'June 25, 2024'
-_version_not_supported = False
-
-try:
-    from grpc._utilities import first_version_is_lower
-    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
-except ImportError:
-    _version_not_supported = True
-
-if _version_not_supported:
-    warnings.warn(
-        f'The grpc package installed is at version {GRPC_VERSION},'
-        + f' but the generated code in node_service_pb2_grpc.py depends on'
-        + f' grpcio>={GRPC_GENERATED_VERSION}.'
-        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
-        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
-        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
-        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
-        RuntimeWarning
-    )
-
-
-class NodeServiceStub(object):
-    """Missing associated documentation comment in .proto file."""
-
-    def __init__(self, channel):
-        """Constructor.
-
-        Args:
-            channel: A grpc.Channel.
-        """
-        self.SendPrompt = channel.unary_unary(
-                '/node_service.NodeService/SendPrompt',
-                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
-                _registered_method=True)
-        self.SendTensor = channel.unary_unary(
-                '/node_service.NodeService/SendTensor',
-                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
-                _registered_method=True)
-        self.GetInferenceResult = channel.unary_unary(
-                '/node_service.NodeService/GetInferenceResult',
-                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.InferenceResult.FromString,
-                _registered_method=True)
-        self.CollectTopology = channel.unary_unary(
-                '/node_service.NodeService/CollectTopology',
-                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Topology.FromString,
-                _registered_method=True)
-        self.SendResult = channel.unary_unary(
-                '/node_service.NodeService/SendResult',
-                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
-        self.SendOpaqueStatus = channel.unary_unary(
-                '/node_service.NodeService/SendOpaqueStatus',
-                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
-        self.HealthCheck = channel.unary_unary(
-                '/node_service.NodeService/HealthCheck',
-                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
-                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
-                _registered_method=True)
-
-
-class NodeServiceServicer(object):
-    """Missing associated documentation comment in .proto file."""
-
-    def SendPrompt(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def SendTensor(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def GetInferenceResult(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def CollectTopology(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def SendResult(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def SendOpaqueStatus(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def HealthCheck(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-
-def add_NodeServiceServicer_to_server(servicer, server):
-    rpc_method_handlers = {
-            'SendPrompt': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendPrompt,
-                    request_deserializer=node__service__pb2.PromptRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
-            ),
-            'SendTensor': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendTensor,
-                    request_deserializer=node__service__pb2.TensorRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
-            ),
-            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.GetInferenceResult,
-                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
-                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
-            ),
-            'CollectTopology': grpc.unary_unary_rpc_method_handler(
-                    servicer.CollectTopology,
-                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
-                    response_serializer=node__service__pb2.Topology.SerializeToString,
-            ),
-            'SendResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendResult,
-                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
-            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendOpaqueStatus,
-                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
-            'HealthCheck': grpc.unary_unary_rpc_method_handler(
-                    servicer.HealthCheck,
-                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
-                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
-            ),
-    }
-    generic_handler = grpc.method_handlers_generic_handler(
-            'node_service.NodeService', rpc_method_handlers)
-    server.add_generic_rpc_handlers((generic_handler,))
-    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
-
-
- # This class is part of an EXPERIMENTAL API.
-class NodeService(object):
-    """Missing associated documentation comment in .proto file."""
-
-    @staticmethod
-    def SendPrompt(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendPrompt',
-            node__service__pb2.PromptRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
-    @staticmethod
-    def SendTensor(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendTensor',
-            node__service__pb2.TensorRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
-    @staticmethod
-    def GetInferenceResult(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/GetInferenceResult',
-            node__service__pb2.GetInferenceResultRequest.SerializeToString,
-            node__service__pb2.InferenceResult.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
-    @staticmethod
-    def CollectTopology(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/CollectTopology',
-            node__service__pb2.CollectTopologyRequest.SerializeToString,
-            node__service__pb2.Topology.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
-    @staticmethod
-    def SendResult(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendResult',
-            node__service__pb2.SendResultRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
-    @staticmethod
-    def SendOpaqueStatus(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendOpaqueStatus',
-            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
-    @staticmethod
-    def HealthCheck(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/HealthCheck',
-            node__service__pb2.HealthCheckRequest.SerializeToString,
-            node__service__pb2.HealthCheckResponse.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)

+ 0 - 0
build/lib/exo/networking/manual/__init__.py


+ 0 - 71
build/lib/exo/networking/manual/manual_discovery.py

@@ -1,71 +0,0 @@
-import asyncio
-from exo.networking.discovery import Discovery
-from typing import Dict, List, Callable
-
-from exo.topology.device_capabilities import DeviceCapabilities
-from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
-from exo.helpers import DEBUG_DISCOVERY
-from exo.networking.peer_handle import PeerHandle
-
-
-class ManualDiscovery(Discovery):
-  def __init__(
-    self,
-    network_config_path: str,
-    node_id: str,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
-  ):
-    self.topology = NetworkTopology.from_path(network_config_path)
-    self.create_peer_handle = create_peer_handle
-
-    if node_id not in self.topology.peers:
-      raise ValueError(
-        f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
-      )
-
-    self.listen_task = None
-
-    self.known_peers: Dict[str, PeerHandle] = {}
-    self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
-    self.peers_in_network.pop(node_id)
-
-  async def start(self) -> None:
-    self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
-
-  async def stop(self) -> None:
-    if self.listen_task:
-      self.listen_task.cancel()
-
-  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-    if wait_for_peers > 0:
-      while len(self.known_peers) < wait_for_peers:
-        if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
-        await asyncio.sleep(0.1)
-    if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}")
-    return list(self.known_peers.values())
-
-  async def task_find_peers_from_config(self):
-    if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
-    while True:
-      for peer_id, peer_config in self.peers_in_network.items():
-        try:
-          if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
-          peer = self.known_peers.get(peer_id)
-          if not peer:
-            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
-            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
-          is_healthy = await peer.health_check()
-          if is_healthy:
-            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
-            self.known_peers[peer_id] = peer
-          else:
-            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
-            try:
-              del self.known_peers[peer_id]
-            except KeyError:
-              pass
-        except Exception as e:
-          if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
-      await asyncio.sleep(1.0)
-
-      if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")

+ 0 - 31
build/lib/exo/networking/manual/network_topology_config.py

@@ -1,31 +0,0 @@
-from typing import Dict
-from pydantic import BaseModel, ValidationError
-
-from exo.topology.device_capabilities import DeviceCapabilities
-
-
-class PeerConfig(BaseModel):
-  address: str
-  port: int
-  device_capabilities: DeviceCapabilities
-
-
-class NetworkTopology(BaseModel):
-  """Configuration of the network. A collection outlining all nodes in the network, including the node this is running from."""
-
-  peers: Dict[str, PeerConfig]
-  """
-  node_id to PeerConfig. The node_id is used to identify the peer in the discovery process. The node that this is running from should be included in this dict.
-  """
-  @classmethod
-  def from_path(cls, path: str) -> "NetworkTopology":
-    try:
-      with open(path, "r") as f:
-        config_data = f.read()
-    except FileNotFoundError as e:
-      raise FileNotFoundError(f"Config file not found at {path}") from e
-
-    try:
-      return cls.model_validate_json(config_data)
-    except ValidationError as e:
-      raise ValueError(f"Error validating network topology config from {path}: {e}") from e

+ 0 - 103
build/lib/exo/networking/manual/test_manual_discovery.py

@@ -1,103 +0,0 @@
-import asyncio
-import unittest
-from unittest import mock
-from exo.networking.manual.manual_discovery import ManualDiscovery
-from exo.networking.manual.network_topology_config import NetworkTopology
-from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
-from exo.networking.grpc.grpc_server import GRPCServer
-from exo.orchestration.node import Node
-
-root_path = "./exo/networking/manual/test_data/test_config.json"
-
-
-class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
-  async def asyncSetUp(self):
-    self.peer1 = mock.AsyncMock()
-    self.peer1.connect = mock.AsyncMock()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
-    _ = self.discovery1.start()
-
-  async def asyncTearDown(self):
-    await self.discovery1.stop()
-
-  async def test_discovery(self):
-    peers1 = await self.discovery1.discover_peers(wait_for_peers=0)
-    assert len(peers1) == 0
-
-    self.peer1.connect.assert_not_called()
-
-
-class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
-  async def asyncSetUp(self):
-    self.peer1 = mock.AsyncMock()
-    self.peer2 = mock.AsyncMock()
-    self.peer1.connect = mock.AsyncMock()
-    self.peer2.connect = mock.AsyncMock()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
-    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
-    await self.discovery1.start()
-    await self.discovery2.start()
-
-  async def asyncTearDown(self):
-    await self.discovery1.stop()
-    await self.discovery2.stop()
-
-  async def test_discovery(self):
-    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
-    assert len(peers1) == 1
-    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
-    assert len(peers2) == 1
-
-    # connect has to be explicitly called after discovery
-    self.peer1.connect.assert_not_called()
-    self.peer2.connect.assert_not_called()
-
-
-class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
-  async def asyncSetUp(self):
-    config = NetworkTopology.from_path(root_path)
-
-    self.node1 = mock.AsyncMock(spec=Node)
-    self.node2 = mock.AsyncMock(spec=Node)
-    self.server1 = GRPCServer(self.node1, config.peers["node1"].address, config.peers["node1"].port)
-    self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
-    await self.server1.start()
-    await self.server2.start()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
-    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
-    await self.discovery1.start()
-    await self.discovery2.start()
-
-  async def asyncTearDown(self):
-    await self.discovery1.stop()
-    await self.discovery2.stop()
-    await self.server1.stop()
-    await self.server2.stop()
-
-  async def test_grpc_discovery(self):
-    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
-    assert len(peers1) == 1
-    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
-    assert len(peers2) == 1
-
-    # Connect
-    await peers1[0].connect()
-    await peers2[0].connect()
-    self.assertTrue(await peers1[0].is_connected())
-    self.assertTrue(await peers2[0].is_connected())
-
-    # Kill server1
-    await self.server1.stop()
-
-    self.assertTrue(await peers1[0].is_connected())
-    self.assertFalse(await peers2[0].is_connected())
-
-    # Kill server2
-    await self.server2.stop()
-
-    self.assertFalse(await peers1[0].is_connected())
-    self.assertFalse(await peers2[0].is_connected())
-
-
-if __name__ == "__main__":
-  asyncio.run(unittest.main())

+ 0 - 49
build/lib/exo/networking/manual/test_network_topology_config.py

@@ -1,49 +0,0 @@
-import unittest
-
-from exo.networking.manual.network_topology_config import NetworkTopology
-
-root_path = "./exo/networking/manual/test_data/"
-
-
-class TestNetworkTopologyConfig(unittest.TestCase):
-  def test_from_path_invalid_path(self):
-    with self.assertRaises(FileNotFoundError) as e:
-      NetworkTopology.from_path("invalid_path")
-    self.assertEqual(str(e.exception), "Config file not found at invalid_path")
-
-  def test_from_path_invalid_json(self):
-    with self.assertRaises(ValueError) as e:
-      NetworkTopology.from_path(root_path + "invalid_json.json")
-    self.assertIn("Error validating network topology config from", str(e.exception))
-    self.assertIn("1 validation error for NetworkTopology\n  Invalid JSON: EOF while parsing a value at line 1 column 0", str(e.exception))
-
-  def test_from_path_invalid_config(self):
-    with self.assertRaises(ValueError) as e:
-      NetworkTopology.from_path(root_path + "invalid_config.json")
-    self.assertIn("Error validating network topology config from", str(e.exception))
-    self.assertIn("port\n  Field required", str(e.exception))
-
-  def test_from_path_valid(self):
-    config = NetworkTopology.from_path(root_path + "test_config.json")
-
-    self.assertEqual(config.peers["node1"].port, 50051)
-    self.assertEqual(config.peers["node1"].device_capabilities.model, "Unknown Model")
-    self.assertEqual(config.peers["node1"].address, "localhost")
-    self.assertEqual(config.peers["node1"].device_capabilities.chip, "Unknown Chip")
-    self.assertEqual(config.peers["node1"].device_capabilities.memory, 0)
-    self.assertEqual(config.peers["node1"].device_capabilities.flops.fp32, 0)
-    self.assertEqual(config.peers["node1"].device_capabilities.flops.fp16, 0)
-    self.assertEqual(config.peers["node1"].device_capabilities.flops.int8, 0)
-
-    self.assertEqual(config.peers["node2"].port, 50052)
-    self.assertEqual(config.peers["node2"].device_capabilities.model, "Unknown Model")
-    self.assertEqual(config.peers["node2"].address, "localhost")
-    self.assertEqual(config.peers["node2"].device_capabilities.chip, "Unknown Chip")
-    self.assertEqual(config.peers["node2"].device_capabilities.memory, 0)
-    self.assertEqual(config.peers["node2"].device_capabilities.flops.fp32, 0)
-    self.assertEqual(config.peers["node2"].device_capabilities.flops.fp16, 0)
-    self.assertEqual(config.peers["node2"].device_capabilities.flops.int8, 0)
-
-
-if __name__ == "__main__":
-  unittest.main()

+ 0 - 56
build/lib/exo/networking/peer_handle.py

@@ -1,56 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Optional, Tuple, List
-import numpy as np
-from exo.inference.shard import Shard
-from exo.topology.device_capabilities import DeviceCapabilities
-from exo.topology.topology import Topology
-
-
-class PeerHandle(ABC):
-  @abstractmethod
-  def id(self) -> str:
-    pass
-
-  @abstractmethod
-  def addr(self) -> str:
-    pass
-
-  @abstractmethod
-  def device_capabilities(self) -> DeviceCapabilities:
-    pass
-
-  @abstractmethod
-  async def connect(self) -> None:
-    pass
-
-  @abstractmethod
-  async def is_connected(self) -> bool:
-    pass
-
-  @abstractmethod
-  async def disconnect(self) -> None:
-    pass
-
-  @abstractmethod
-  async def health_check(self) -> bool:
-    pass
-
-  @abstractmethod
-  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
-    pass
-
-  @abstractmethod
-  async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None) -> Optional[np.array]:
-    pass
-
-  @abstractmethod
-  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-    pass
-
-  @abstractmethod
-  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-    pass
-
-  @abstractmethod
-  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
-    pass

+ 0 - 11
build/lib/exo/networking/server.py

@@ -1,11 +0,0 @@
-from abc import ABC, abstractmethod
-
-
-class Server(ABC):
-  @abstractmethod
-  async def start(self) -> None:
-    pass
-
-  @abstractmethod
-  async def stop(self) -> None:
-    pass

+ 0 - 0
build/lib/exo/networking/tailscale/__init__.py


+ 0 - 178
build/lib/exo/networking/tailscale/tailscale_discovery.py

@@ -1,178 +0,0 @@
-import asyncio
-import time
-import traceback
-from typing import List, Dict, Callable, Tuple
-from exo.networking.discovery import Discovery
-from exo.networking.peer_handle import PeerHandle
-from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
-from exo.helpers import DEBUG, DEBUG_DISCOVERY
-from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices, Device
-
-
-class TailscaleDiscovery(Discovery):
-  def __init__(
-    self,
-    node_id: str,
-    node_port: int,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
-    discovery_interval: int = 5,
-    discovery_timeout: int = 30,
-    update_interval: int = 15,
-    device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
-    tailscale_api_key: str = None,
-    tailnet: str = None,
-    allowed_node_ids: List[str] = None,
-  ):
-    self.node_id = node_id
-    self.node_port = node_port
-    self.create_peer_handle = create_peer_handle
-    self.discovery_interval = discovery_interval
-    self.discovery_timeout = discovery_timeout
-    self.update_interval = update_interval
-    self.device_capabilities = device_capabilities
-    self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
-    self.discovery_task = None
-    self.cleanup_task = None
-    self.tailscale_api_key = tailscale_api_key
-    self.tailnet = tailnet
-    self.allowed_node_ids = allowed_node_ids
-    self._device_id = None
-    self.update_task = None
-
-  async def start(self):
-    self.device_capabilities = device_capabilities()
-    self.discovery_task = asyncio.create_task(self.task_discover_peers())
-    self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
-    self.update_task = asyncio.create_task(self.task_update_device_posture_attributes())
-
-  async def task_update_device_posture_attributes(self):
-    while True:
-      try:
-        await self.update_device_posture_attributes()
-        if DEBUG_DISCOVERY >= 2:
-          print(f"Updated device posture attributes")
-      except Exception as e:
-        print(f"Error updating device posture attributes: {e}")
-        print(traceback.format_exc())
-      finally:
-        await asyncio.sleep(self.update_interval)
-
-  async def get_device_id(self):
-    if self._device_id:
-      return self._device_id
-    self._device_id = await get_device_id()
-    return self._device_id
-
-  async def update_device_posture_attributes(self):
-    await update_device_attributes(await self.get_device_id(), self.tailscale_api_key, self.node_id, self.node_port, self.device_capabilities)
-
-  async def task_discover_peers(self):
-    while True:
-      try:
-        devices: dict[str, Device] = await get_tailscale_devices(self.tailscale_api_key, self.tailnet)
-        current_time = time.time()
-
-        active_devices = {name: device for name, device in devices.items() if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30}
-
-        if DEBUG_DISCOVERY >= 4: print(f"Found tailscale devices: {devices}")
-        if DEBUG_DISCOVERY >= 2: print(f"Active tailscale devices: {len(active_devices)}/{len(devices)}")
-        if DEBUG_DISCOVERY >= 2: print("Time since last seen tailscale devices", [(current_time - device.last_seen.timestamp()) for device in devices.values()])
-
-        for device in active_devices.values():
-          if device.name == self.node_id: continue
-          peer_host = device.addresses[0]
-          peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale_api_key)
-          if not peer_id:
-            if DEBUG_DISCOVERY >= 4: print(f"{device.device_id} does not have exo node attributes. skipping.")
-            continue
-
-          if self.allowed_node_ids and peer_id not in self.allowed_node_ids:
-            if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as it's not in the allowed node IDs list")
-            continue
-
-          if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
-            new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
-            if not await new_peer_handle.health_check():
-              if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
-              continue
-
-            if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
-            self.known_peers[peer_id] = (
-              new_peer_handle,
-              current_time,
-              current_time,
-            )
-          else:
-            if not await self.known_peers[peer_id][0].health_check():
-              if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Removing.")
-              if peer_id in self.known_peers: del self.known_peers[peer_id]
-              continue
-            self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], current_time)
-
-      except Exception as e:
-        print(f"Error in discover peers: {e}")
-        print(traceback.format_exc())
-      finally:
-        await asyncio.sleep(self.discovery_interval)
-
-  async def stop(self):
-    if self.discovery_task:
-      self.discovery_task.cancel()
-    if self.cleanup_task:
-      self.cleanup_task.cancel()
-    if self.update_task:
-      self.update_task.cancel()
-    if self.discovery_task or self.cleanup_task or self.update_task:
-      await asyncio.gather(self.discovery_task, self.cleanup_task, self.update_task, return_exceptions=True)
-
-  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-    if wait_for_peers > 0:
-      while len(self.known_peers) < wait_for_peers:
-        if DEBUG_DISCOVERY >= 2:
-          print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
-        await asyncio.sleep(0.1)
-    return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
-
-  async def task_cleanup_peers(self):
-    while True:
-      try:
-        current_time = time.time()
-        peers_to_remove = []
-
-        peer_ids = list(self.known_peers.keys())
-        results = await asyncio.gather(*[self.check_peer(peer_id, current_time) for peer_id in peer_ids], return_exceptions=True)
-
-        for peer_id, should_remove in zip(peer_ids, results):
-          if should_remove: peers_to_remove.append(peer_id)
-
-        if DEBUG_DISCOVERY >= 2:
-          print(
-            "Peer statuses:", {
-              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}"
-              for peer_handle, connected_at, last_seen in self.known_peers.values()
-            }
-          )
-
-        for peer_id in peers_to_remove:
-          if peer_id in self.known_peers:
-            del self.known_peers[peer_id]
-            if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity or failed health check.")
-      except Exception as e:
-        print(f"Error in cleanup peers: {e}")
-        print(traceback.format_exc())
-      finally:
-        await asyncio.sleep(self.discovery_interval)
-
-  async def check_peer(self, peer_id: str, current_time: float) -> bool:
-    peer_handle, connected_at, last_seen = self.known_peers.get(peer_id, (None, None, None))
-    if peer_handle is None: return False
-
-    try:
-      is_connected = await peer_handle.is_connected()
-      health_ok = await peer_handle.health_check()
-    except Exception as e:
-      if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
-      return True
-
-    should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
-    return should_remove

+ 0 - 125
build/lib/exo/networking/tailscale/tailscale_helpers.py

@@ -1,125 +0,0 @@
-import json
-import asyncio
-import aiohttp
-import re
-from typing import Dict, Any, Tuple, List, Optional
-from exo.helpers import DEBUG_DISCOVERY
-from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
-from datetime import datetime, timezone
-
-
-class Device:
-  def __init__(self, device_id: str, name: str, addresses: List[str], last_seen: Optional[datetime] = None):
-    self.device_id = device_id
-    self.name = name
-    self.addresses = addresses
-    self.last_seen = last_seen
-
-  @classmethod
-  def from_dict(cls, data: Dict[str, Any]) -> 'Device':
-    return cls(device_id=data.get('id', ''), name=data.get('name', ''), addresses=data.get('addresses', []), last_seen=cls.parse_datetime(data.get('lastSeen')))
-
-  @staticmethod
-  def parse_datetime(date_string: Optional[str]) -> Optional[datetime]:
-    if not date_string:
-      return None
-    return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
-
-
-async def get_device_id() -> str:
-  try:
-    process = await asyncio.create_subprocess_exec('tailscale', 'status', '--json', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
-    stdout, stderr = await process.communicate()
-    if process.returncode != 0:
-      raise Exception(f"Command failed with exit code {process.returncode}: {stderr.decode().strip()}.")
-    if DEBUG_DISCOVERY >= 4: print(f"tailscale status: {stdout.decode()}")
-    data = json.loads(stdout.decode())
-    return data['Self']['ID']
-  except Exception as e:
-    raise Exception(f"{str(e)} Do you have the tailscale cli installed? See: https://tailscale.com/kb/1080/cli")
-
-
-async def update_device_attributes(device_id: str, api_key: str, node_id: str, node_port: int, device_capabilities: DeviceCapabilities):
-  async with aiohttp.ClientSession() as session:
-    base_url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
-    headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}
-
-    attributes = {
-      "custom:exo_node_id": node_id.replace('-', '_'), "custom:exo_node_port": node_port, "custom:exo_device_capability_chip": sanitize_attribute(device_capabilities.chip),
-      "custom:exo_device_capability_model": sanitize_attribute(device_capabilities.model), "custom:exo_device_capability_memory": str(device_capabilities.memory),
-      "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16), "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
-      "custom:exo_device_capability_flops_int8": str(device_capabilities.flops.int8)
-    }
-
-    for attr_name, attr_value in attributes.items():
-      url = f"{base_url}/{attr_name}"
-      data = {"value": str(attr_value).replace(' ', '_')}  # Ensure all values are strings for JSON
-      async with session.post(url, headers=headers, json=data) as response:
-        if response.status == 200:
-          if DEBUG_DISCOVERY >= 1: print(f"Updated device posture attribute {attr_name} for device {device_id}")
-        else:
-          print(f"Failed to update device posture attribute {attr_name}: {response.status} {await response.text()}")
-
-
-async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int, DeviceCapabilities]:
-  async with aiohttp.ClientSession() as session:
-    url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
-    headers = {'Authorization': f'Bearer {api_key}'}
-    async with session.get(url, headers=headers) as response:
-      if response.status == 200:
-        data = await response.json()
-        attributes = data.get("attributes", {})
-        node_id = attributes.get("custom:exo_node_id", "").replace('_', '-')
-        node_port = int(attributes.get("custom:exo_node_port", 0))
-        device_capabilities = DeviceCapabilities(
-          model=attributes.get("custom:exo_device_capability_model", "").replace('_', ' '),
-          chip=attributes.get("custom:exo_device_capability_chip", "").replace('_', ' '),
-          memory=int(attributes.get("custom:exo_device_capability_memory", 0)),
-          flops=DeviceFlops(
-            fp16=float(attributes.get("custom:exo_device_capability_flops_fp16", 0)),
-            fp32=float(attributes.get("custom:exo_device_capability_flops_fp32", 0)),
-            int8=float(attributes.get("custom:exo_device_capability_flops_int8", 0))
-          )
-        )
-        return node_id, node_port, device_capabilities
-      else:
-        print(f"Failed to fetch posture attributes for {device_id}: {response.status}")
-        return "", 0, DeviceCapabilities(model="", chip="", memory=0, flops=DeviceFlops(fp16=0, fp32=0, int8=0))
-
-
-def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
-  result = {}
-  prefix = "custom:exo_"
-  for key, value in data.items():
-    if key.startswith(prefix):
-      attr_name = key.replace(prefix, "")
-      if attr_name in ["node_id", "node_port", "device_capability_chip", "device_capability_model"]:
-        result[attr_name] = value.replace('_', ' ')
-      elif attr_name in ["device_capability_memory", "device_capability_flops_fp16", "device_capability_flops_fp32", "device_capability_flops_int8"]:
-        result[attr_name] = float(value)
-  return result
-
-
-def sanitize_attribute(value: str) -> str:
-  # Replace invalid characters with underscores
-  sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value)
-  # Truncate to 50 characters
-  return sanitized_value[:50]
-
-
-async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]:
-  async with aiohttp.ClientSession() as session:
-    url = f"https://api.tailscale.com/api/v2/tailnet/{tailnet}/devices"
-    headers = {"Authorization": f"Bearer {api_key}"}
-
-    async with session.get(url, headers=headers) as response:
-      response.raise_for_status()
-      data = await response.json()
-
-      devices = {}
-      for device_data in data.get("devices", []):
-        print("Device data: ", device_data)
-        device = Device.from_dict(device_data)
-        devices[device.name] = device
-
-      return devices

+ 0 - 43
build/lib/exo/networking/tailscale/test_tailscale_discovery.py

@@ -1,43 +0,0 @@
-import os
-import asyncio
-import unittest
-from unittest import mock
-from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
-from exo.networking.peer_handle import PeerHandle
-
-
-class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
-  async def asyncSetUp(self):
-    self.tailscale_api_key = os.environ.get("TAILSCALE_API_KEY", "")
-    self.tailnet = os.environ.get("TAILSCALE_TAILNET", "")
-    self.discovery = TailscaleDiscovery(
-      node_id="test_node",
-      node_port=50051,
-      create_peer_handle=lambda peer_id, address, device_capabilities: unittest.mock.Mock(spec=PeerHandle, id=lambda: peer_id),
-      tailscale_api_key=self.tailscale_api_key,
-      tailnet=self.tailnet
-    )
-    await self.discovery.start()
-
-  async def asyncTearDown(self):
-    await self.discovery.stop()
-
-  async def test_discovery(self):
-    # Wait for a short period to allow discovery to happen
-    await asyncio.sleep(15)
-
-    # Get discovered peers
-    peers = await self.discovery.discover_peers()
-
-    # Check if any peers were discovered
-    self.assertGreater(len(peers), 0, "No peers were discovered")
-
-    # Print discovered peers for debugging
-    print(f"Discovered peers: {[peer.id() for peer in peers]}")
-
-    # Check if discovered peers are instances of GRPCPeerHandle
-    print(peers)
-
-
-if __name__ == '__main__':
-  unittest.main()

+ 0 - 0
build/lib/exo/networking/udp/__init__.py


+ 0 - 77
build/lib/exo/networking/udp/test_udp_discovery.py

@@ -1,77 +0,0 @@
-import asyncio
-import unittest
-from unittest import mock
-from exo.networking.udp.udp_discovery import UDPDiscovery
-from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
-from exo.networking.grpc.grpc_server import GRPCServer
-from exo.orchestration.node import Node
-
-
-class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
-  async def asyncSetUp(self):
-    self.peer1 = mock.AsyncMock()
-    self.peer2 = mock.AsyncMock()
-    self.peer1.connect = mock.AsyncMock()
-    self.peer2.connect = mock.AsyncMock()
-    self.discovery1 = UDPDiscovery("discovery1", 50051, 5678, 5679, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
-    self.discovery2 = UDPDiscovery("discovery2", 50052, 5679, 5678, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
-    await self.discovery1.start()
-    await self.discovery2.start()
-
-  async def asyncTearDown(self):
-    await self.discovery1.stop()
-    await self.discovery2.stop()
-
-  async def test_discovery(self):
-    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
-    assert len(peers1) == 1
-    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
-    assert len(peers2) == 1
-
-    # connect has to be explicitly called after discovery
-    self.peer1.connect.assert_not_called()
-    self.peer2.connect.assert_not_called()
-
-
-class TestUDPDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
-  async def asyncSetUp(self):
-    self.node1 = mock.AsyncMock(spec=Node)
-    self.node2 = mock.AsyncMock(spec=Node)
-    self.server1 = GRPCServer(self.node1, "localhost", 50053)
-    self.server2 = GRPCServer(self.node2, "localhost", 50054)
-    await self.server1.start()
-    await self.server2.start()
-    self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
-    self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
-    await self.discovery1.start()
-    await self.discovery2.start()
-
-  async def asyncTearDown(self):
-    await self.discovery1.stop()
-    await self.discovery2.stop()
-    await self.server1.stop()
-    await self.server2.stop()
-
-  async def test_grpc_discovery(self):
-    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
-    assert len(peers1) == 1
-    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
-    assert len(peers2) == 1
-    assert not await peers1[0].is_connected()
-    assert not await peers2[0].is_connected()
-
-    # Connect
-    await peers1[0].connect()
-    await peers2[0].connect()
-    assert await peers1[0].is_connected()
-    assert await peers2[0].is_connected()
-
-    # Kill server1
-    await self.server1.stop()
-
-    assert await peers1[0].is_connected()
-    assert not await peers2[0].is_connected()
-
-
-if __name__ == "__main__":
-  asyncio.run(unittest.main())

+ 0 - 215
build/lib/exo/networking/udp/udp_discovery.py

@@ -1,215 +0,0 @@
-import asyncio
-import json
-import socket
-import time
-import traceback
-from typing import List, Dict, Callable, Tuple, Coroutine
-from exo.networking.discovery import Discovery
-from exo.networking.peer_handle import PeerHandle
-from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
-from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses
-
-
-class ListenProtocol(asyncio.DatagramProtocol):
-  def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
-    super().__init__()
-    self.on_message = on_message
-    self.loop = asyncio.get_event_loop()
-
-  def connection_made(self, transport):
-    self.transport = transport
-
-  def datagram_received(self, data, addr):
-    asyncio.create_task(self.on_message(data, addr))
-
-
-class BroadcastProtocol(asyncio.DatagramProtocol):
-  def __init__(self, message: str, broadcast_port: int):
-    self.message = message
-    self.broadcast_port = broadcast_port
-
-  def connection_made(self, transport):
-    sock = transport.get_extra_info("socket")
-    sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
-    transport.sendto(self.message.encode("utf-8"), ("<broadcast>", self.broadcast_port))
-
-
-class UDPDiscovery(Discovery):
-  def __init__(
-    self,
-    node_id: str,
-    node_port: int,
-    listen_port: int,
-    broadcast_port: int,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
-    broadcast_interval: int = 1,
-    discovery_timeout: int = 30,
-    device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
-    allowed_node_ids: List[str] = None,
-  ):
-    self.node_id = node_id
-    self.node_port = node_port
-    self.listen_port = listen_port
-    self.broadcast_port = broadcast_port
-    self.create_peer_handle = create_peer_handle
-    self.broadcast_interval = broadcast_interval
-    self.discovery_timeout = discovery_timeout
-    self.device_capabilities = device_capabilities
-    self.allowed_node_ids = allowed_node_ids
-    self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
-    self.broadcast_task = None
-    self.listen_task = None
-    self.cleanup_task = None
-
-  async def start(self):
-    self.device_capabilities = device_capabilities()
-    self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
-    self.listen_task = asyncio.create_task(self.task_listen_for_peers())
-    self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
-
-  async def stop(self):
-    if self.broadcast_task: self.broadcast_task.cancel()
-    if self.listen_task: self.listen_task.cancel()
-    if self.cleanup_task: self.cleanup_task.cancel()
-    if self.broadcast_task or self.listen_task or self.cleanup_task:
-      await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)
-
-  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-    if wait_for_peers > 0:
-      while len(self.known_peers) < wait_for_peers:
-        if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
-        await asyncio.sleep(0.1)
-    return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]
-
-  async def task_broadcast_presence(self):
-    if DEBUG_DISCOVERY >= 2: print("Starting task_broadcast_presence...")
-
-    while True:
-      # Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
-      # the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
-      for addr in get_all_ip_addresses():
-        message = json.dumps({
-          "type": "discovery",
-          "node_id": self.node_id,
-          "grpc_port": self.node_port,
-          "device_capabilities": self.device_capabilities.to_dict(),
-          "priority": 1,  # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
-        })
-        if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr}): {message}")
-
-        transport = None
-        try:
-          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
-          if DEBUG_DISCOVERY >= 3:
-            print(f"Broadcasting presence at ({addr})")
-        except Exception as e:
-          print(f"Error in broadcast presence ({addr}): {e}")
-        finally:
-          if transport:
-            try:
-              transport.close()
-            except Exception as e:
-              if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
-              if DEBUG_DISCOVERY >= 2: traceback.print_exc()
-      await asyncio.sleep(self.broadcast_interval)
-
-  async def on_listen_message(self, data, addr):
-    if not data:
-      return
-
-    decoded_data = data.decode("utf-8", errors="ignore")
-
-    # Check if the decoded data starts with a valid JSON character
-    if not (decoded_data.strip() and decoded_data.strip()[0] in "{["):
-      if DEBUG_DISCOVERY >= 2: print(f"Received invalid JSON data from {addr}: {decoded_data[:100]}")
-      return
-
-    try:
-      decoder = json.JSONDecoder(strict=False)
-      message = decoder.decode(decoded_data)
-    except json.JSONDecodeError as e:
-      if DEBUG_DISCOVERY >= 2: print(f"Error decoding JSON data from {addr}: {e}")
-      return
-
-    if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}")
-
-    if message["type"] == "discovery" and message["node_id"] != self.node_id:
-      peer_id = message["node_id"]
-      
-      # Skip if peer_id is not in allowed list
-      if self.allowed_node_ids and peer_id not in self.allowed_node_ids:
-        if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as it's not in the allowed node IDs list")
-        return
-
-      peer_host = addr[0]
-      peer_port = message["grpc_port"]
-      peer_prio = message["priority"]
-      device_capabilities = DeviceCapabilities(**message["device_capabilities"])
-
-      if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
-        if peer_id in self.known_peers:
-          existing_peer_prio = self.known_peers[peer_id][3]
-          if existing_peer_prio >= peer_prio:
-            if DEBUG >= 1:
-              print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
-            return
-        new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
-        if not await new_peer_handle.health_check():
-          if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
-          return
-        if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
-        self.known_peers[peer_id] = (new_peer_handle, time.time(), time.time(), peer_prio)
-      else:
-        if not await self.known_peers[peer_id][0].health_check():
-          if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Removing.")
-          if peer_id in self.known_peers: del self.known_peers[peer_id]
-          return
-        if peer_id in self.known_peers: self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)
-
-  async def task_listen_for_peers(self):
-    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port))
-    if DEBUG_DISCOVERY >= 2: print("Started listen task")
-
-  async def task_cleanup_peers(self):
-    while True:
-      try:
-        current_time = time.time()
-        peers_to_remove = []
-
-        peer_ids = list(self.known_peers.keys())
-        results = await asyncio.gather(*[self.check_peer(peer_id, current_time) for peer_id in peer_ids], return_exceptions=True)
-
-        for peer_id, should_remove in zip(peer_ids, results):
-          if should_remove: peers_to_remove.append(peer_id)
-
-        if DEBUG_DISCOVERY >= 2:
-          print(
-            "Peer statuses:", {
-              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}, prio={prio}"
-              for peer_handle, connected_at, last_seen, prio in self.known_peers.values()
-            }
-          )
-
-        for peer_id in peers_to_remove:
-          if peer_id in self.known_peers:
-            del self.known_peers[peer_id]
-            if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity or failed health check.")
-      except Exception as e:
-        print(f"Error in cleanup peers: {e}")
-        print(traceback.format_exc())
-      finally:
-        await asyncio.sleep(self.broadcast_interval)
-
-  async def check_peer(self, peer_id: str, current_time: float) -> bool:
-    peer_handle, connected_at, last_seen, prio = self.known_peers.get(peer_id, (None, None, None, None))
-    if peer_handle is None: return False
-
-    try:
-      is_connected = await peer_handle.is_connected()
-      health_ok = await peer_handle.health_check()
-    except Exception as e:
-      if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
-      return True
-
-    should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
-    return should_remove

+ 0 - 4
build/lib/exo/orchestration/__init__.py

@@ -1,4 +0,0 @@
-from .node import Node
-from .standard_node import StandardNode
-
-__all__ = ["Node", "StandardNode"]

+ 0 - 47
build/lib/exo/orchestration/node.py

@@ -1,47 +0,0 @@
-from typing import Optional, Tuple, List
-import numpy as np
-from abc import ABC, abstractmethod
-from exo.helpers import AsyncCallbackSystem
-from exo.inference.shard import Shard
-from exo.topology.topology import Topology
-
-
-class Node(ABC):
-  @abstractmethod
-  async def start(self, wait_for_peers: int = 0) -> None:
-    pass
-
-  @abstractmethod
-  async def stop(self) -> None:
-    pass
-
-  @abstractmethod
-  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
-    pass
-
-  @abstractmethod
-  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
-    pass
-
-  @abstractmethod
-  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-    pass
-
-  @abstractmethod
-  async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology:
-    pass
-
-  @property
-  @abstractmethod
-  def current_topology(self) -> Topology:
-    pass
-
-  @property
-  @abstractmethod
-  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
-    pass
-
-  @property
-  @abstractmethod
-  def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
-    pass

+ 0 - 488
build/lib/exo/orchestration/standard_node.py

@@ -1,488 +0,0 @@
-import numpy as np
-import json
-import asyncio
-import uuid
-import time
-import traceback
-from typing import List, Dict, Optional, Tuple, Union, Set
-from exo.networking import Discovery, PeerHandle, Server
-from exo.inference.inference_engine import InferenceEngine, Shard
-from .node import Node
-from exo.topology.topology import Topology
-from exo.topology.device_capabilities import device_capabilities
-from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards
-from exo import DEBUG
-from exo.helpers import AsyncCallbackSystem
-from exo.viz.topology_viz import TopologyViz
-from exo.download.hf.hf_helpers import RepoProgressEvent
-from exo.inference.inference_engine import get_inference_engine, InferenceEngine
-from exo.download.hf.hf_shard_download import HFShardDownloader
-
-class StandardNode(Node):
-  def __init__(
-    self,
-    _id: str,
-    server: Server,
-    inference_engine: InferenceEngine,
-    discovery: Discovery,
-    partitioning_strategy: PartitioningStrategy = None,
-    max_generate_tokens: int = 1024,
-    default_sample_temperature: float = 0.0,
-    topology_viz: Optional[TopologyViz] = None,
-    shard_downloader: Optional[HFShardDownloader] = None,
-  ):
-    self.id = _id
-    self.inference_engine = inference_engine
-    self.server = server
-    self.discovery = discovery
-    self.partitioning_strategy = partitioning_strategy
-    self.peers: List[PeerHandle] = {}
-    self.topology: Topology = Topology()
-    self.device_capabilities = device_capabilities()
-    self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-    self.buffered_logits: Dict[str, List[np.ndarray]] = {}
-    self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
-    self.max_generate_tokens = max_generate_tokens
-    self.topology_viz = topology_viz
-    self.default_sample_temperature = default_sample_temperature
-    self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
-    self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
-    self._on_opaque_status.register("node_status").on_next(self.on_node_status)
-    self.node_download_progress: Dict[str, RepoProgressEvent] = {}
-    self.topology_inference_engines_pool: List[List[str]] = []
-    self.shard_downloader = shard_downloader
-
-  async def start(self, wait_for_peers: int = 0) -> None:
-    await self.server.start()
-    await self.discovery.start()
-    await self.update_peers(wait_for_peers)
-    await self.collect_topology()
-    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
-    asyncio.create_task(self.periodic_topology_collection(1.0))
-
-  async def stop(self) -> None:
-    await self.discovery.stop()
-    await self.server.stop()
-
-  def on_node_status(self, request_id, opaque_status):
-    try:
-      status_data = json.loads(opaque_status)
-      if status_data.get("type", "") == "supported_inference_engines":
-        node_id = status_data.get("node_id")
-        engines = status_data.get("engines", [])
-        self.topology_inference_engines_pool.append(engines)
-      if status_data.get("type", "") == "node_status":
-        if status_data.get("status", "").startswith("start_"):
-          self.current_topology.active_node_id = status_data.get("node_id")
-        elif status_data.get("status", "").startswith("end_"):
-          if status_data.get("node_id") == self.current_topology.active_node_id:
-            self.current_topology.active_node_id = None
-      download_progress = None
-      if status_data.get("type", "") == "download_progress":
-        if DEBUG >= 8: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
-        download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
-        self.node_download_progress[status_data.get('node_id')] = download_progress
-      if self.topology_viz:
-        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id, self.node_download_progress)
-    except Exception as e:
-      if DEBUG >= 1: print(f"Error updating visualization: {e}")
-      if DEBUG >= 1: traceback.print_exc()
-
-  def get_supported_inference_engines(self):
-    supported_engine_names = []
-    if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
-      supported_engine_names.append('mlx')
-      supported_engine_names.append('tinygrad')
-    else:
-      supported_engine_names.append('tinygrad')
-    return supported_engine_names
-
-  async def broadcast_supported_engines(self, supported_engines_names: List[str]):
-    status_message = json.dumps({"type": "supported_inference_engines", "node_id": self.id, "engines": supported_engines_names})
-    await self.broadcast_opaque_status("", status_message)
-
-  def get_topology_inference_engines(self) -> List[List[str]]:
-    return self.topology_inference_engines_pool
-  
-  async def process_inference_result(
-    self,
-    shard,
-    result: np.ndarray,
-    request_id: Optional[str] = None,
-    inference_state: Optional[dict] = None,
-  ):
-    if shard.model_id != 'stable-diffusion-2-1-base':
-      if request_id not in self.buffered_token_output:
-        self.buffered_token_output[request_id] = ([], False)
-      is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-      if shard.is_last_layer() and not is_finished:
-        token = await self.inference_engine.sample(result, temp = self.default_sample_temperature)
-        self.buffered_token_output[request_id][0].append(token.item())
-        intermediate_result = self.buffered_token_output[request_id][0]
-        if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-        is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
-        forward = token.reshape(1, -1)
-      else:
-        forward = result
-    else:
-      await self.inference_engine.ensure_shard(shard)
-      is_finished = inference_state.get('is_finished', False)
-      intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result)
-      forward = result
-    if shard.is_last_layer():
-      asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished))
-      self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished)
-    if is_finished:
-      if shard.model_id != 'stable-diffusion-2-1-base':
-          self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-          intermediate_result = self.buffered_token_output[request_id][0]
-    else:
-      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))
-
-    return np.array(intermediate_result)
-  
-  async def process_prompt(
-    self,
-    base_shard: Shard,
-    prompt: str,
-    request_id: Optional[str] = None,
-    inference_state: Optional[dict] = {},
-  ) -> Optional[np.ndarray]:
-    shard = self.get_current_shard(base_shard)
-    asyncio.create_task(
-      self.broadcast_opaque_status(
-        request_id,
-        json.dumps({
-          "type": "node_status",
-          "node_id": self.id,
-          "status": "start_process_prompt",
-          "base_shard": base_shard.to_dict(),
-          "shard": shard.to_dict(),
-          "prompt": prompt,
-          "request_id": request_id,
-        }),
-      )
-    )
-    start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
-    end_time = time.perf_counter_ns()
-    elapsed_time_ns = end_time - start_time
-    asyncio.create_task(
-      self.broadcast_opaque_status(
-        request_id,
-        json.dumps({
-          "type": "node_status",
-          "node_id": self.id,
-          "status": "end_process_prompt",
-          "base_shard": base_shard.to_dict(),
-          "shard": shard.to_dict(),
-          "prompt": prompt,
-          "request_id": request_id,
-          "elapsed_time_ns": elapsed_time_ns,
-          "result_size": resp.size if resp is not None else 0,
-        }),
-      )
-    )
-    return resp
-
-  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
-    if request_id is None:
-      request_id = str(uuid.uuid4())
-    shard = self.get_current_shard(base_shard)
-    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
-    if not shard.is_first_layer():
-      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state)
-      return None
-    else:
-      result,inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
-      ret = await self.process_inference_result(shard, result, request_id, inference_state) 
-      return result
-
-  async def process_tensor(
-    self,
-    base_shard: Shard,
-    tensor: np.ndarray,
-    request_id: Optional[str] = None,
-    inference_state: Optional[dict] = None,
-  ) -> Optional[np.ndarray]:
-    shard = self.get_current_shard(base_shard)
-    asyncio.create_task(
-      self.broadcast_opaque_status(
-        request_id,
-        json.dumps({
-          "type": "node_status",
-          "node_id": self.id,
-          "status": "start_process_tensor",
-          "base_shard": base_shard.to_dict(),
-          "shard": shard.to_dict(),
-          "tensor_size": tensor.size,
-          "tensor_shape": tensor.shape,
-          "request_id": request_id,
-        }),
-      )
-    )
-    start_time = time.perf_counter_ns()
-    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
-    end_time = time.perf_counter_ns()
-    elapsed_time_ns = end_time - start_time
-    asyncio.create_task(
-      self.broadcast_opaque_status(
-        request_id,
-        json.dumps({
-          "type": "node_status",
-          "node_id": self.id,
-          "status": "end_process_tensor",
-          "base_shard": base_shard.to_dict(),
-          "shard": shard.to_dict(),
-          "request_id": request_id,
-          "elapsed_time_ns": elapsed_time_ns,
-          "result_size": resp.size if resp is not None else 0,
-        }),
-      )
-    )
-    return resp
-
-  async def _process_tensor(
-    self,
-    base_shard: Shard,
-    tensor: np.ndarray,
-    request_id: Optional[str] = None,
-    inference_state: Optional[dict] = None,
-  ) -> Optional[np.ndarray]:
-    if request_id is None:
-      request_id = str(uuid.uuid4())
-    shard = self.get_current_shard(base_shard)
-
-    if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-    try:
-      result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
-      ret = await self.process_inference_result(shard, result, request_id, inference_state) 
-      return ret
-    except Exception as e:
-      print(f"Error processing tensor for shard {shard}: {e}")
-      traceback.print_exc()
-      return None
-
-  async def forward_prompt(
-    self,
-    base_shard: Shard,
-    prompt: str,
-    request_id: str,
-    target_index: int,
-    inference_state: Optional[dict] = None,
-  ) -> None:
-    if DEBUG >= 1: print(f"target partition index: {target_index}")
-    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
-    next_shard = self.get_current_shard(base_shard, target_index)
-    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
-    if target_id == self.id:
-      await self.process_prompt(next_shard, prompt, request_id)
-    else:
-      target_peer = next((p for p in self.peers if p.id() == target_id), None)
-      if not target_peer:
-        raise ValueError(f"Peer for {target_index} not found")
-      if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
-      await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)
-  
-  async def forward_tensor(
-    self,
-    base_shard: Shard,
-    tensor: np.ndarray,
-    request_id: str,
-    target_index: int,
-    inference_state: Optional[dict] = None,
-  ) -> None:
-    if DEBUG >= 1: print(f"target partition index: {target_index}")
-    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
-    next_shard = self.get_current_shard(base_shard, target_index)
-    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
-    if target_id == self.id:
-      await self.process_tensor(next_shard, tensor, request_id, inference_state)
-    else:
-      target_peer = next((p for p in self.peers if p.id() == target_id), None)
-      if not target_peer:
-        raise ValueError(f"Peer for {target_index} not found")
-      if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
-      await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state)
-
-  def get_partition_index(self, offset: int = 0):
-    if not self.partitioning_strategy:
-      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
-      return None
-    partitions = self.partitioning_strategy.partition(self.topology)
-    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
-    if current_partition_index is None:
-      raise ValueError(f"No current partition found for node: {self.id}")
-    return (current_partition_index + offset) % len(partitions)
-
-  def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard:
-    if index is None:
-      index = self.get_partition_index()
-    partitions = self.partitioning_strategy.partition(self.topology)
-    shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
-    return shards[index]
-
-  async def update_peers(self, wait_for_peers: int = 0) -> bool:
-    next_peers = await self.discovery.discover_peers(wait_for_peers)
-    current_peer_ids = {peer.id() for peer in self.peers}
-    next_peer_ids = {peer.id() for peer in next_peers}
-    peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids]
-    peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids]
-    peers_updated = [peer for peer in next_peers if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())]
-    peers_unchanged = [peer for peer in next_peers if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())]
-    peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()]
-    peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()]
-
-    def _pretty(peers: List[PeerHandle]) -> List[str]:
-      return [f"{peer.id()}@{peer.addr()}" for peer in peers]
-
-    if DEBUG >= 2:
-      print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
-
-    async def disconnect_with_timeout(peer, timeout=5):
-      try:
-        await asyncio.wait_for(peer.disconnect(), timeout)
-        return True
-      except Exception as e:
-        print(f"Error disconnecting peer {peer.id()}@{peer.addr()}: {e}")
-        traceback.print_exc()
-        return False
-
-    async def connect_with_timeout(peer, timeout=5):
-      try:
-        await asyncio.wait_for(peer.connect(), timeout)
-        return True
-      except Exception as e:
-        print(f"Error connecting peer {peer.id()}@{peer.addr()}: {e}")
-        traceback.print_exc()
-        return False
-
-    disconnect_results = await asyncio.gather(*(disconnect_with_timeout(peer) for peer in peers_to_disconnect), return_exceptions=True)
-    connect_results = await asyncio.gather(*(connect_with_timeout(peer) for peer in peers_to_connect), return_exceptions=True)
-
-    successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True]
-    failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False]
-    successful_connects = [peer for peer, result in zip(peers_to_connect, connect_results) if result is True]
-    failed_connects = [peer for peer, result in zip(peers_to_connect, connect_results) if result is False]
-    if DEBUG >= 1:
-      if successful_disconnects: print(f"Successfully disconnected peers: {_pretty(successful_disconnects)}")
-      if failed_disconnects: print(f"Failed to disconnect peers: {_pretty(failed_disconnects)}")
-      if successful_connects: print(f"Successfully connected peers: {_pretty(successful_connects)}")
-      if failed_connects: print(f"Failed to connect peers: {_pretty(failed_connects)}")
-
-    self.peers = next_peers
-    return len(peers_added) > 0 or len(peers_removed) > 0 or len(peers_updated) > 0
-
-  async def select_best_inference_engine(self):
-    if self.inference_engine.__class__.__name__ == 'DummyInferenceEngine': return
-    supported_engines = self.get_supported_inference_engines()
-    await self.broadcast_supported_engines(supported_engines)
-    if len(self.get_topology_inference_engines()):
-      self.inference_engine = get_inference_engine(supported_engines[0], self.shard_downloader)
-
-  async def periodic_topology_collection(self, interval: int):
-    while True:
-      await asyncio.sleep(interval)
-      try:
-        did_peers_change = await self.update_peers()
-        if DEBUG >= 2: print(f"{did_peers_change=}")
-        if did_peers_change:
-          await self.collect_topology()
-          await self.select_best_inference_engine()
-      except Exception as e:
-        print(f"Error collecting topology: {e}")
-        traceback.print_exc()
-
-  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-    if request_id not in self.buffered_token_output:
-      return None, False
-    return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
-
-  async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
-    next_topology = Topology()
-    next_topology.update_node(self.id, self.device_capabilities)
-
-    if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")
-
-    prev_visited = visited.copy()
-    visited.add(self.id)
-    visited.update(p.id() for p in self.peers)
-
-    for peer in self.peers:
-      next_topology.update_node(peer.id(), peer.device_capabilities())
-      next_topology.add_edge(self.id, peer.id())
-
-      if peer.id() in prev_visited:
-        continue
-
-      if max_depth <= 0:
-        if DEBUG >= 2: print("Max depth reached. Skipping...")
-        continue
-
-      try:
-        other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0)
-        if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
-        self.topology.merge(other_topology)
-      except Exception as e:
-        print(f"Error collecting topology from {peer.id()}: {e}")
-        traceback.print_exc()
-
-    next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
-    self.topology = next_topology
-    if self.topology_viz:
-      self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id)
-    return next_topology
-
-  @property
-  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
-    return self._on_token
-
-  @property
-  def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
-    return self._on_opaque_status
-
-  def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
-    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
-    self.on_token.trigger_all(request_id, tokens, is_finished)
-  
-  async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-    async def send_result_to_peer(peer):
-      try:
-        await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
-      except asyncio.TimeoutError:
-        print(f"Timeout broadcasting result to {peer.id()}")
-      except Exception as e:
-        print(f"Error broadcasting result to {peer.id()}: {e}")
-        traceback.print_exc()
-
-    await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
-
-  async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
-    if DEBUG >= 8: print(f"Broadcasting opaque status: {request_id=} {status=}")
-
-    async def send_status_to_peer(peer):
-      try:
-        await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
-      except asyncio.TimeoutError:
-        print(f"Timeout sending opaque status to {peer.id()}")
-      except Exception as e:
-        print(f"Error sending opaque status to {peer.id()}: {e}")
-        traceback.print_exc()
-
-    await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
-    # in the case of opaque status, we also want to receive our own opaque statuses
-    self.on_opaque_status.trigger_all(request_id, status)
-
-  @property
-  def current_topology(self) -> Topology:
-    return self.topology
-
-  def handle_stable_diffusion(self, inference_state, result):
-    if inference_state['is_step_finished']:
-      inference_state['step']+=1
-    progress = [inference_state['step'],inference_state['total_steps']]
-    intermediate_result = result
-    if progress[0] == progress[1]:
-      intermediate_result = result
-    return intermediate_result, inference_state

+ 0 - 57
build/lib/exo/orchestration/test_node.py

@@ -1,57 +0,0 @@
-import unittest
-from unittest.mock import Mock, AsyncMock
-import numpy as np
-
-from .standard_node import StandardNode
-from exo.networking.peer_handle import PeerHandle
-
-
-class TestNode(unittest.IsolatedAsyncioTestCase):
-  def setUp(self):
-    self.mock_inference_engine = AsyncMock()
-    self.mock_server = AsyncMock()
-    self.mock_server.start = AsyncMock()
-    self.mock_server.stop = AsyncMock()
-    self.mock_discovery = AsyncMock()
-    self.mock_discovery.start = AsyncMock()
-    self.mock_discovery.stop = AsyncMock()
-    mock_peer1 = Mock(spec=PeerHandle)
-    mock_peer1.id.return_value = "peer1"
-    mock_peer2 = Mock(spec=PeerHandle)
-    mock_peer2.id.return_value = "peer2"
-    self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
-
-    self.node = StandardNode("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery)
-
-  async def asyncSetUp(self):
-    await self.node.start()
-
-  async def asyncTearDown(self):
-    await self.node.stop()
-
-  async def test_node_initialization(self):
-    self.assertEqual(self.node.node_id, "test_node")
-    self.assertEqual(self.node.host, "localhost")
-    self.assertEqual(self.node.port, 50051)
-
-  async def test_node_start(self):
-    self.mock_server.start.assert_called_once_with("localhost", 50051)
-
-  async def test_node_stop(self):
-    await self.node.stop()
-    self.mock_server.stop.assert_called_once()
-
-  async def test_discover_and_connect_to_peers(self):
-    await self.node.discover_and_connect_to_peers()
-    self.assertEqual(len(self.node.peers), 2)
-    self.assertIn("peer1", map(lambda p: p.id(), self.node.peers))
-    self.assertIn("peer2", map(lambda p: p.id(), self.node.peers))
-
-  async def test_process_tensor_calls_inference_engine(self):
-    mock_peer = Mock()
-    self.node.peers = [mock_peer]
-
-    input_tensor = np.array([69, 1, 2])
-    await self.node.process_tensor(input_tensor, None)
-
-    self.node.inference_engine.process_shard.assert_called_once_with(input_tensor)

+ 0 - 0
build/lib/exo/stats/__init__.py


+ 0 - 29
build/lib/exo/stats/metrics.py

@@ -1,29 +0,0 @@
-from exo.orchestration import Node
-from prometheus_client import start_http_server, Counter, Histogram
-import json
-
-# Create metrics to track time spent and requests made.
-PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"])
-PROCESS_TENSOR_COUNTER = Counter("process_tensor_total", "Total number of tensors processed", ["node_id"])
-PROCESS_TENSOR_TIME = Histogram("process_tensor_seconds", "Time spent processing tensor", ["node_id"])
-
-
-def start_metrics_server(node: Node, port: int):
-  start_http_server(port)
-
-  def _on_opaque_status(request_id, opaque_status: str):
-    status_data = json.loads(opaque_status)
-    _type = status_data.get("type", "")
-    node_id = status_data.get("node_id", "")
-    if _type != "node_status":
-      return
-    status = status_data.get("status", "")
-
-    if status == "end_process_prompt":
-      PROCESS_PROMPT_COUNTER.labels(node_id=node_id).inc()
-    elif status == "end_process_tensor":
-      elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
-      PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
-      PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns/1e9)  # Convert ns to seconds
-
-  node.on_opaque_status.register("stats").on_next(_on_opaque_status)

+ 0 - 50
build/lib/exo/test_callbacks.py

@@ -1,50 +0,0 @@
-import asyncio
-from typing import Any, Callable
-from exo.helpers import AsyncCallbackSystem, AsyncCallback
-
-
-# Usage example
-async def main() -> None:
-  callback_system = AsyncCallbackSystem[str, Any]()
-
-  # Register callbacks
-  callback1 = callback_system.register("callback1")
-  callback2 = callback_system.register("callback2")
-
-  def on_next_callback(name: str) -> Callable[..., None]:
-    def callback(*args: Any) -> None:
-      print(f"{name} received values: {args}")
-
-    return callback
-
-  callback1.on_next(on_next_callback("Callback1"))
-  callback2.on_next(on_next_callback("Callback2"))
-
-  async def wait_for_callback(name: str, callback: AsyncCallback[Any], condition: Callable[..., bool]) -> None:
-    try:
-      result = await callback.wait(condition, timeout=2)
-      print(f"{name} wait completed with result: {result}")
-    except asyncio.TimeoutError:
-      print(f"{name} wait timed out")
-
-  # Trigger all callbacks at once
-  callback_system.trigger_all("Hello", 42, True)
-
-  # Wait for all callbacks with different conditions
-  await asyncio.gather(
-    wait_for_callback("Callback1", callback1, lambda msg, num, flag: isinstance(msg, str) and num > 0),
-    wait_for_callback("Callback2", callback2, lambda msg, num, flag: flag is True),
-  )
-
-  # Trigger individual callback
-  callback_system.trigger("callback2", "World", -10, False)
-
-  # Demonstrate timeout
-  new_callback = callback_system.register("new_callback")
-  new_callback.on_next(on_next_callback("NewCallback"))
-  await wait_for_callback("NewCallback", new_callback, lambda msg, num, flag: num > 100)
-
-  callback_system.trigger("callback2", "World", 200, False)
-
-
-asyncio.run(main())

+ 0 - 130
build/lib/exo/tinychat/common.css

@@ -1,130 +0,0 @@
-/* make it responsive */
-@media(min-width: 852px) {
-  body {
-    font-size: 14px;
-  }
-}
-@media(max-width: 852px) {
-  body {
-    font-size: 12px;
-  }
-}
-
-/* resets */
-html, body {
-  width: 100%;
-  height: 100%;
-}
-
-*::-webkit-scrollbar {
-  display: none;
-}
-
-* {
-  -ms-overflow-style: none;
-  scrollbar-width: none;
-}
-
-* {
-  -moz-box-sizing: border-box;
-  -webkit-box-sizing: border-box;
-  box-sizing: border-box;
-}
-
-/* default */
-body {
-  margin: 0;
-  background-color: var(--primary-bg-color);
-  color: var(--foreground-color);
-}
-
-h1, h2, h3, h4, h5, h6 {
-  margin: 0em;
-}
-
-hr {
-  width: 92%;
-}
-
-button {
-  cursor: pointer;
-  border: none;
-  background-color: transparent;
-}
-button:hover {
-}
-button:active {
-}
-
-/* components */
-.container {
-  margin: 0 auto;
-  padding: 1rem;
-}
-
-.centered {
-  display: flex;
-  flex-direction: column;
-  justify-content: center;
-  align-items: center;
-}
-
-.centered-w-only {
-  position: absolute;
-  left: 50%;
-  transform: translateX(-50%);
-}
-
-.centered-h-only {
-  position: absolute;
-  top: 50%;
-  transform: translateY(-50%);
-}
-
-.card {
-  padding: 0;
-}
-
-.card-header {
-  padding: 0.5rem 1rem;
-}
-
-.card-container {
-  width: 96vw;
-  height: 100%;
-  gap: 1rem;
-  display: flex;
-  flex-direction: row;
-  flex-wrap: wrap;
-  justify-content: center;
-  align-items: center;
-}
-
-.clean-a {
-  text-decoration: underline;
-  text-decoration-color: #006fc1;
-  text-decoration-thickness: 2px;
-  color: inherit;
-}
-
-.hover-underline {
-  text-decoration: underline;
-  text-decoration-color: #228039;
-  text-decoration-thickness: 2px;
-  color: inherit;
-}
-
-.flex-horizontal {
-  display: flex;
-  flex-direction: row;
-  justify-content: space-between;
-  align-items: center;
-}
-
-.vertical-separator {
-  padding: 0 0.5rem;
-}
-
-[x-cloak] {
-  display: none !important;
-}

+ 0 - 25
build/lib/exo/tinychat/favicon.svg

@@ -1,25 +0,0 @@
-<svg xmlns="http://www.w3.org/2000/svg" viewBox="-10 -10 150 70" shape-rendering="crispEdges">
-  <g id="logo">
-    <!-- t -->
-    <polygon points="10,40 10,20 0,20 0,10 10,10 10,0 20,0 20,10 30,10 30,20 20,20 20,30 30,30 30,40" />
-    <!-- i -->
-    <polygon points="40,40 40,20 50,20 50,40" />
-    <polygon points="40,10 40,0 50,0 50,10" />
-    <!-- n -->
-    <polygon points="60,40 60,10 80,10 80,40 90,40 90,20 70,20 70,40" />
-    <!-- y -->
-    <polygon points="100,50 100,40 130,40 130,10 120,10 120,20 110,20 110,10 100,10 100,30 120,30 120,50" />
-  </g>
-  <style>
-  @media (prefers-color-scheme: dark) {
-    #logo {
-      fill: #fff;
-    }
-  }
-  @media (prefers-color-scheme: light) {
-    #logo {
-      fill: #000;
-    }
-  }
-  </style>
-</svg>

+ 0 - 484
build/lib/exo/tinychat/index.css

@@ -1,484 +0,0 @@
-/* define colors */
-:root {
-  --primary-color: #fff;
-  --secondary-color: #2a2a2a;
-  --secondary-color-transparent: #ffffff66;
-  --primary-bg-color: #1a1a1a;
-  --foreground-color: #f0f0f0;
-  --red-color: #a52e4d;
-}
-
-main {
-  width: 100%;
-  height: 100%;
-
-  display: flex;
-  flex-direction: column;
-
-  place-items: center;
-}
-
-.home {
-  width: 100%;
-  height: 90%;
-
-  margin-bottom: 10rem;
-}
-
-.title {
-  font-size: 3rem;
-  margin: 1rem 0;
-  margin-top: 3rem;
-}
-
-.histories-container-container {
-  width: 100%;
-  max-height: 75%;
-
-  position: relative;
-}
-
-.histories-container {
-  overflow-y: auto;
-  overflow-x: hidden;
-  width: 100%;
-  height: 100%;
-
-  display: flex;
-  flex-direction: column;
-  gap: 1rem;
-  align-items: center;
-
-  margin: 0;
-  padding: 3rem 1rem;
-}
-
-.histories-start {
-  height: 3rem;
-  width: 100%;
-
-  z-index: 999;
-  top: 0;
-  position: absolute;
-
-  background: linear-gradient(
-    180deg,
-    var(--primary-bg-color) 0%,
-    transparent 100%
-  );
-}
-.histories-end {
-  height: 3rem;
-  width: 100%;
-
-  z-index: 999;
-  bottom: 0;
-  position: absolute;
-
-  background: linear-gradient(
-    0deg,
-    var(--primary-bg-color) 0%,
-    transparent 100%
-  );
-}
-
-.history {
-  padding: 1rem;
-  width: 100%;
-  max-width: 40rem;
-
-  background-color: var(--secondary-color);
-  border-radius: 10px;
-  border-left: 2px solid var(--primary-color);
-
-  cursor: pointer;
-
-  transform: translateX(calc(1px * var(--tx, 0)));
-  opacity: var(--opacity, 1);
-}
-.history:hover {
-  background-color: var(--secondary-color);
-}
-
-.history-delete-button {
-  position: absolute;
-  top: 0;
-  right: 0;
-  padding: 0.5rem;
-  margin: 0;
-  outline: none;
-  border: none;
-  background-color: var(--secondary-color);
-  color: var(--foreground-color);
-  border-radius: 0 0 0 10px;
-  cursor: pointer;
-  transition: 0.2s;
-}
-.history-delete-button:hover {
-  background-color: var(--secondary-color);
-  padding: 0.75rem;
-}
-
-.messages {
-  overflow-y: auto;
-  height: 100%;
-  width: 100%;
-  max-width: 1200px;
-
-  display: flex;
-  flex-direction: column;
-  gap: 1rem;
-  align-items: center;
-  padding-top: 1rem;
-  padding-bottom: 11rem;
-}
-
-.message {
-  max-width: 75%;
-  padding: 0.5rem 1rem;
-  border-radius: 20px;
-}
-.message-role-assistant {
-  background-color: var(--secondary-color);
-  margin-right: auto;
-  color: #fff;
-}
-.message-role-user {
-  margin-left: auto;
-  background-color: var(--primary-color);
-  color: #000;
-}
-.download-progress {
-  margin-bottom: 12em;
-  overflow-y: auto;
-  min-height: 350px;
-  padding: 2rem;
-}
-.message > pre {
-  white-space: pre-wrap;
-}
-
-.progress-bar-container {
-  width: 100%;
-  background-color: #e0e0e0;
-  border-radius: 4px;
-  margin: 10px 0;
-}
-.progress-bar {
-  height: 20px;
-  border-radius: 4px;
-  transition: width 0.5s ease-in-out;
-}
-.progress-bar.complete {
-  background-color: #4CAF50;
-}
-.progress-bar.in-progress {
-  background-color: #2196F3;
-}
-
-.toast {
-    width: 100%;
-    background-color: #fc2a2a;
-    color: #fff;
-    text-align: left;
-    border-radius: 2px;
-    padding: 16px;
-    position: fixed;
-    z-index: 9999;
-    top: 0;
-    left: 0;
-    right: 0;
-    display: flex;
-    flex-direction: column;
-    white-space: pre-wrap;
-    font-family: monospace;
-}
-
-.toast-header {
-    display: flex;
-    justify-content: space-between;
-    align-items: center;
-    width: 100%;
-}
-
-.toast-error-message {
-    flex-grow: 1;
-}
-
-.toast-header-buttons {
-    display: flex;
-    align-items: center;
-    gap: 16px;
-    margin-left: 24px;
-}
-
-.toast-expand-button {
-    background: none;
-    border: none;
-    color: white;
-    padding: 4px;
-    cursor: pointer;
-    font-size: 1em;
-}
-
-.toast-close-button {
-    background: none;
-    border: none;
-    color: white;
-    padding: 4px;
-    cursor: pointer;
-    font-size: 1.2em;
-    line-height: 1;
-}
-
-.toast-expand-button:hover,
-.toast-close-button:hover {
-    opacity: 0.8;
-}
-
-.toast-content {
-    margin-top: 10px;
-    padding: 10px;
-    background-color: rgba(0, 0, 0, 0.2);
-    border-radius: 4px;
-}
-
-.hljs {
-  width: 100%;
-  position: relative;
-  border-radius: 10px;
-  /* wrap code blocks */
-  white-space: pre-wrap;
-}
-/* put clipboard button in the top right corner of the code block */
-.clipboard-button {
-  position: absolute;
-  top: 0;
-  right: 0;
-  padding: 0.5rem;
-  margin: 0;
-  outline: none;
-  border: none;
-  background-color: var(--secondary-color);
-  color: var(--foreground-color);
-  border-radius: 0 0 0 10px;
-  cursor: pointer;
-  transition: 0.2s;
-}
-.clipboard-button:hover {
-  background-color: var(--secondary-color);
-  padding: 0.75rem;
-}
-
-.input-container {
-  position: absolute;
-  bottom: 0;
-
-  /* linear gradient from background-color to transparent on the top */
-  background: linear-gradient(
-    0deg,
-    var(--primary-bg-color) 55%,
-    transparent 100%
-  );
-
-  width: 100%;
-  max-width: 1200px;
-  display: flex;
-  flex-direction: column;
-  justify-content: center;
-  align-items: center;
-  z-index: 999;
-}
-
-.input-performance {
-  margin-top: 4rem;
-
-  display: flex;
-  flex-direction: row;
-  gap: 1rem;
-}
-
-.input-performance-point {
-  display: flex;
-  flex-direction: row;
-  place-items: center;
-  gap: 0.5rem;
-}
-.input-performance-point > p {
-  height: 1rem;
-  line-height: normal;
-}
-
-.input {
-  width: 90%;
-  min-height: 3rem;
-  flex-shrink: 0;
-
-  display: flex;
-  flex-direction: row;
-  justify-content: center;
-  gap: 0.5rem;
-
-  align-items: flex-end;
-  margin-bottom: 2rem;
-}
-
-.input-form {
-  width: 100%;
-  padding: 1rem;
-  min-height: 3rem;
-  max-height: 8rem;
-
-  background-color: var(--secondary-color);
-  color: var(--foreground-color);
-  border-radius: 10px;
-  border: none;
-  resize: none;
-  outline: none;
-}
-
-.input-button {
-  height: 3rem;
-  width: 4rem;
-
-  background-color: var(--primary-color);
-  color: var(--secondary-color);
-  border-radius: 10px;
-  padding: 0.5rem;
-  cursor: pointer;
-}
-.input-button:hover {
-  background-color: var(--secondary-color-transparent);
-}
-.input-button:disabled {
-  background-color: var(--secondary-color);
-  cursor: not-allowed;
-}
-
-/* wrap text */
-p {
-  white-space: pre-wrap;
-}
-
-/* fonts */
-.megrim-regular {
-  font-family: "Megrim", system-ui;
-  font-weight: 400;
-  font-style: normal;
-}
-
-.monospace {
-  font-family: monospace;
-}
-
-.model-selector {
-  display: flex;
-  justify-content: center;
-  padding: 20px 0;
-}
-.model-selector select {
-  padding: 10px 20px;
-  font-size: 16px;
-  border: 1px solid #ccc;
-  border-radius: 5px;
-  background-color: #f8f8f8;
-  cursor: pointer;
-}
-.model-selector select:focus {
-  outline: none;
-  border-color: #007bff;
-  box-shadow: 0 0 0 2px rgba(0,123,255,.25);
-}
-
-/* Image upload button styles */
-.image-input-button {
-  background-color: var(--secondary-color);
-  color: var(--foreground-color);
-  border: none;
-  border-radius: 50%;
-  width: 40px;
-  height: 40px;
-  font-size: 18px;
-  cursor: pointer;
-  transition: all 0.3s ease;
-  display: flex;
-  align-items: center;
-  justify-content: center;
-  margin-right: 10px;
-}
-
-.image-input-button:hover {
-  background-color: var(--secondary-color-transparent);
-  transform: scale(1.1);
-}
-
-.image-input-button:focus {
-  outline: none;
-  box-shadow: 0 0 0 3px rgba(var(--secondary-color-rgb), 0.5);
-}
-
-.image-input-button i {
-  transition: all 0.3s ease;
-}
-
-.image-input-button:hover i {
-  transform: scale(1.2);
-}
-
-/* Hidden file input styles */
-#image-upload {
-  display: none;
-}
-
-.image-preview-container {
-  position: relative;
-  display: inline-block;
-  margin-right: 10px;
-}
-
-.image-preview {
-  max-width: 100px;
-  max-height: 100px;
-  object-fit: cover;
-  border-radius: 5px;
-}
-
-.remove-image-button {
-  position: absolute;
-  top: -5px;
-  right: -5px;
-  background-color: rgba(255, 255, 255, 0.8);
-  border: none;
-  border-radius: 50%;
-  padding: 2px 5px;
-  cursor: pointer;
-}
-
-.message > p > img {
-  max-width: 100%;
-  max-height: 100%;
-  object-fit: contain;
-}
-
-.clear-history-button {
-  background-color: var(--red-color);
-  color: white;
-  padding: 10px 20px;
-  border-radius: 5px;
-  display: flex;
-  align-items: center;
-  gap: 8px;
-  transition: all 0.3s ease;
-  margin: 1rem auto;
-  border: none;
-  cursor: pointer;
-}
-
-.clear-history-button:hover {
-  opacity: 0.8;
-  transform: scale(1.05);
-}
-
-.clear-history-button i {
-  font-size: 14px;
-}

+ 0 - 255
build/lib/exo/tinychat/index.html

@@ -1,255 +0,0 @@
-<!DOCTYPE html>
-
-<head>
-<title>tinychat</title>
-<meta content="width=device-width, initial-scale=1" name="viewport"/>
-<link href="favicon.svg" rel="icon" type="image/svg+xml"/>
-<script defer="" src="/static/cdn.jsdelivr.net/npm/@alpine-collective/toolkit@1.0.2/dist/cdn.min.js"></script>
-<script defer="" src="/static/cdn.jsdelivr.net/npm/@alpinejs/intersect@3.x.x/dist/cdn.min.js"></script>
-<script defer="" src="/static/cdn.jsdelivr.net/npm/@alpinejs/focus@3.x.x/dist/cdn.min.js"></script>
-<script defer="" src="/static/unpkg.com/@marcreichel/alpine-autosize@1.3.x/dist/alpine-autosize.min.js"></script>
-<script defer="" src="/static/unpkg.com/alpinejs@3.x.x/dist/cdn.min.js"></script>
-<script src="/static/unpkg.com/dompurify@3.1.5/dist/purify.min.js"></script>
-<script src="/static/unpkg.com/marked@13.0.0/marked.min.js"></script>
-<script src="/static/unpkg.com/marked-highlight@2.1.2/lib/index.umd.js"></script>
-<script src="/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/highlight.min.js"></script>
-<script src="/index.js"></script>
-<link href="/static/fonts.googleapis.com" rel="preconnect"/>
-<link crossorigin="" href="/static/fonts.gstatic.com" rel="preconnect"/>
-<link href="/static/fonts.googleapis.com/css2" rel="stylesheet"/>
-<link href="/static/cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css" rel="stylesheet"/>
-<link crossorigin="anonymous" href="/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css" integrity="sha512-SnH5WK+bZxgPHs44uWIX+LLJAJ9/2PkPKZ5QiAj6Ta86w+fsb2TkcmfRyVX3pBnMFcV7oQPJkl9QevSCWr3W6A==" referrerpolicy="no-referrer" rel="stylesheet">
-<link href="/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css" rel="stylesheet"/>
-<link href="/index.css" rel="stylesheet"/>
-<link href="/common.css" rel="stylesheet"/>
-</head>
-<body>
-<main x-data="state" x-init="console.log(endpoint)">
-     <!-- Error Toast -->
-    <div x-show="errorMessage" x-transition.opacity class="toast">
-        <div class="toast-header">
-            <span class="toast-error-message" x-text="errorMessage.basic"></span>
-            <div class="toast-header-buttons">
-                <button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }" 
-                        class="toast-expand-button" 
-                        x-show="errorMessage.stack">
-                    <span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
-                </button>
-                <button @click="errorMessage = null; errorExpanded = false;" class="toast-close-button">
-                    <i class="fas fa-times"></i>
-                </button>
-            </div>
-        </div>
-        <div class="toast-content" x-show="errorExpanded" x-transition>
-            <span x-text="errorMessage.stack"></span>
-        </div>
-    </div>
-<div class="model-selector">
-  <select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" x-init="await populateSelector()" class='model-select'>
-  </select>
-</div>
-<div @popstate.window="
-      if (home === 2) {
-        home = -1;
-        cstate = { time: null, messages: [], selectedModel: 'llama-3.1-8b' };
-        time_till_first = 0;
-        tokens_per_second = 0;
-        total_tokens = 0;
-      }
-    " class="home centered" x-effect="
-      $refs.inputForm.focus();
-      if (home === 1) setTimeout(() =&gt; home = 2, 100);
-      if (home === -1) setTimeout(() =&gt; home = 0, 100);
-    " x-show="home === 0" x-transition="">
-<h1 class="title megrim-regular">tinychat</h1>
-<template x-if="histories.length">
-  <button 
-    @click="if(confirm('Are you sure you want to clear all history?')) clearAllHistory();" 
-    class="clear-history-button">
-    <i class="fas fa-trash"></i> Clear All History
-  </button>
-</template>
-<div class="histories-container-container">
-<template x-if="histories.length">
-<div class="histories-start"></div>
-</template>
-<div class="histories-container" x-intersect="
-          $el.scrollTo({ top: 0, behavior: 'smooth' });
-        ">
-<template x-for="_state in histories.toSorted((a, b) =&gt; b.time - a.time)">
-<div @click="
-            cstate = _state;
-            if (cstate) cstate.selectedModel = document.querySelector('.model-selector select').value
-            // updateTotalTokens(cstate.messages);
-            home = 1;
-            // ensure that going back in history will go back to home
-            window.history.pushState({}, '', '/');
-          " @touchend="
-            if (Math.abs($event.changedTouches[0].clientX - otx) &gt; trigger) removeHistory(_state);
-            $el.style.setProperty('--tx', 0);
-            $el.style.setProperty('--opacity', 1);
-          " @touchmove="
-            $el.style.setProperty('--tx', $event.changedTouches[0].clientX - otx);
-            $el.style.setProperty('--opacity', 1 - (Math.abs($event.changedTouches[0].clientX - otx) / trigger));
-          " @touchstart="
-            otx = $event.changedTouches[0].clientX;
-          " class="history" x-data="{ otx: 0, trigger: 75 }">
-<h3 x-text="new Date(_state.time).toLocaleString()"></h3>
-<p x-text="$truncate(_state.messages[0].content, 80)"></p>
-<!-- delete button -->
-<button @click.stop="removeHistory(_state);" class="history-delete-button">
-<i class="fas fa-trash"></i>
-</button>
-</div>
-</template>
-</div>
-<template x-if="histories.length">
-<div class="histories-end"></div>
-</template>
-</div>
-</div>
-<div class="messages" x-init="
-      $watch('cstate', value =&gt; {
-        $el.innerHTML = '';
-        value.messages.forEach(({ role, content }) =&gt; {
-          const div = document.createElement('div');
-          div.className = `message message-role-${role}`;
-          try {
-              if (content.includes('![Generated Image]')) {
-                const imageUrl = content.match(/\((.*?)\)/)[1];
-                const img = document.createElement('img');
-                img.src = imageUrl;
-                img.alt = 'Generated Image';
-                img.onclick = async () => {
-                  try {
-                    const response = await fetch(img.src);
-                    const blob = await response.blob();
-                    const file = new File([blob], 'image.png', { type: 'image/png' });
-                    handleImageUpload({ target: { files: [file] } });
-                  } catch (error) {
-                    console.error('Error fetching image:', error);
-                  }
-                };
-                div.appendChild(img);
-              } else {
-                div.innerHTML = DOMPurify.sanitize(marked.parse(content));
-              }
-          } catch (e) {
-            console.log(content);
-            console.error(e);
-          }
-
-          // add a clipboard button to all code blocks
-          const codeBlocks = div.querySelectorAll('.hljs');
-          codeBlocks.forEach(codeBlock =&gt; {
-            const button = document.createElement('button');
-            button.className = 'clipboard-button';
-            button.innerHTML = '&lt;i class=\'fas fa-clipboard\'&gt;&lt;/i&gt;';
-            button.onclick = () =&gt; {
-              // navigator.clipboard.writeText(codeBlock.textContent);
-              const range = document.createRange();
-              range.setStartBefore(codeBlock);
-              range.setEndAfter(codeBlock);
-              window.getSelection()?.removeAllRanges();
-              window.getSelection()?.addRange(range);
-              document.execCommand('copy');
-              window.getSelection()?.removeAllRanges();
-
-              button.innerHTML = '&lt;i class=\'fas fa-check\'&gt;&lt;/i&gt;';
-              setTimeout(() =&gt; button.innerHTML = '&lt;i class=\'fas fa-clipboard\'&gt;&lt;/i&gt;', 1000);
-            };
-            codeBlock.appendChild(button);
-          });
-
-          $el.appendChild(div);
-        });
-
-        $el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
-      });
-    " x-intersect="
-      $el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
-    " x-ref="messages" x-show="home === 2" x-transition="">
-</div>
-
-<!-- Download Progress Section -->
-<template x-if="downloadProgress && downloadProgress.length > 0">
-  <div class="download-progress message message-role-assistant">
-    <h2>Download Progress</h2>
-    <br>
-    <template x-for="(progress, index) in downloadProgress" :key="index">
-      <div class="download-progress-node">
-        <br>
-        <h3 x-text="`Download ${index + 1}`"></h3>
-        <p><strong>Model:</strong> <span x-text="progress.repo_id + '@' + progress.repo_revision"></span></p>
-        <p><strong>Status:</strong> <span x-text="progress.status"></span></p>
-        <div class="progress-bar-container">
-          <div class="progress-bar" 
-               :class="progress.isComplete ? 'complete' : 'in-progress'"
-               :style="`width: ${progress.percentage}%;`">
-          </div>
-        </div>
-        <template x-if="!progress.isComplete">
-          <div>
-            <p><strong>Progress:</strong> <span x-text="`${progress.downloaded_bytes_display} / ${progress.total_bytes_display} (${progress.percentage}%)`"></span></p>
-            <p><strong>Speed:</strong> <span x-text="progress.overall_speed_display || 'N/A'"></span></p>
-            <p><strong>ETA:</strong> <span x-text="progress.overall_eta_display || 'N/A'"></span></p>
-          </div>
-        </template>
-      </div>
-    </template>
-  </div>
-</template>
-
-
-<div class="input-container">
-<div class="input-performance">
-<span class="input-performance-point">
-<p class="monospace" x-text="(time_till_first / 1000).toFixed(2)"></p>
-<p class="megrim-regular">SEC TO FIRST TOKEN</p>
-</span>
-<span class="input-performance-point">
-<p class="monospace" x-text="tokens_per_second.toFixed(1)"></p>
-<p class="megrim-regular">TOKENS/SEC</p>
-</span>
-<span class="input-performance-point">
-<p class="monospace" x-text="total_tokens"></p>
-<p class="megrim-regular">TOKENS</p>
-</span>
-</div>
-<div class="input">
-<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf' || cstate.selectedModel === 'stable-diffusion-2-1-base'">
-<i class="fas fa-image"></i>
-</button>
-<input @change="$data.handleImageUpload($event)" accept="image/*" id="image-upload" style="display: none;" type="file" x-ref="imageUpload"/>
-<div class="image-preview-container" x-show="imagePreview">
-<img :src="imagePreview" alt="Uploaded Image" class="image-preview"/>
-<button @click="imagePreview = null; imageUrl = null;" class="remove-image-button">
-<i class="fas fa-times"></i>
-</button>
-</div>
-<textarea :disabled="generating" :placeholder="generating ? 'Generating...' : 'Say something'" @input="
-            home = (home === 0) ? 1 : home
-            if (cstate.messages.length === 0 &amp;&amp; $el.value === '') home = -1;
-
-            if ($el.value !== '') {
-              const messages = [...cstate.messages];
-              messages.push({ role: 'user', content: $el.value });
-              // updateTotalTokens(messages);
-            } else {
-              if (cstate.messages.length === 0) total_tokens = 0;
-              // else updateTotalTokens(cstate.messages);
-            }
-          " @keydown.enter="await handleEnter($event)" @keydown.escape.window="$focus.focus($el)" autofocus="" class="input-form" id="input-form" rows="1" x-autosize="" x-effect="
-            console.log(generating);
-            if (!generating) $nextTick(() =&gt; {
-              $el.focus();
-              setTimeout(() =&gt; $refs.messages.scrollTo({ top: $refs.messages.scrollHeight, behavior: 'smooth' }), 100);
-            });
-          " x-ref="inputForm"></textarea>
-<button :disabled="generating" @click="await handleSend()" class="input-button">
-<i :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'" class="fas"></i>
-</button>
-</div>
-</div>
-</main>
-</body>

+ 0 - 687
build/lib/exo/tinychat/index.js

@@ -1,687 +0,0 @@
-document.addEventListener("alpine:init", () => {
-  Alpine.data("state", () => ({
-    // current state
-    cstate: {
-      time: null,
-      messages: [],
-      selectedModel: 'llama-3.2-1b',
-    },    
-
-    // historical state
-    histories: JSON.parse(localStorage.getItem("histories")) || [],
-
-    home: 0,
-    generating: false,
-    endpoint: `${window.location.origin}/v1`,
-    errorMessage: null,
-    errorExpanded: false,
-    errorTimeout: null,
-
-    // performance tracking
-    time_till_first: 0,
-    tokens_per_second: 0,
-    total_tokens: 0,
-
-    // image handling
-    imagePreview: null,
-
-    // download progress
-    downloadProgress: null,
-    downloadProgressInterval: null, // To keep track of the polling interval
-
-    // Pending message storage
-    pendingMessage: null,
-
-    init() {
-      // Clean up any pending messages
-      localStorage.removeItem("pendingMessage");
-
-      // Start polling for download progress
-      this.startDownloadProgressPolling();
-    },
-
-    removeHistory(cstate) {
-      const index = this.histories.findIndex((state) => {
-        return state.time === cstate.time;
-      });
-      if (index !== -1) {
-        this.histories.splice(index, 1);
-        localStorage.setItem("histories", JSON.stringify(this.histories));
-      }
-    },
-
-    clearAllHistory() {
-      this.histories = [];
-      localStorage.setItem("histories", JSON.stringify([]));
-    },
-
-    // Utility functions
-    formatBytes(bytes) {
-      if (bytes === 0) return '0 B';
-      const k = 1024;
-      const sizes = ['B', 'KB', 'MB', 'GB', 'TB'];
-      const i = Math.floor(Math.log(bytes) / Math.log(k));
-      return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i];
-    },
-
-    formatDuration(seconds) {
-      if (seconds === null || seconds === undefined || isNaN(seconds)) return '';
-      const h = Math.floor(seconds / 3600);
-      const m = Math.floor((seconds % 3600) / 60);
-      const s = Math.floor(seconds % 60);
-      if (h > 0) return `${h}h ${m}m ${s}s`;
-      if (m > 0) return `${m}m ${s}s`;
-      return `${s}s`;
-    },
-
-    async populateSelector() {
-      try {
-        const response = await fetch(`${window.location.origin}/modelpool`);
-        const responseText = await response.text(); // Get raw response text first
-        
-        if (!response.ok) {
-          throw new Error(`HTTP error! status: ${response.status}`);
-        }
-        
-        // Try to parse the response text
-        let responseJson;
-        try {
-          responseJson = JSON.parse(responseText);
-        } catch (parseError) {
-          console.error('Failed to parse JSON:', parseError);
-          throw new Error(`Invalid JSON response: ${responseText}`);
-        }
-
-        const sel = document.querySelector(".model-select");
-        if (!sel) {
-          throw new Error("Could not find model selector element");
-        }
-
-        // Clear the current options and add new ones
-        sel.innerHTML = '';
-          
-        const modelDict = responseJson["model pool"];
-        if (!modelDict) {
-          throw new Error("Response missing 'model pool' property");
-        }
-
-        Object.entries(modelDict).forEach(([key, value]) => {
-          const opt = document.createElement("option");
-          opt.value = key;
-          opt.textContent = value;
-          sel.appendChild(opt);
-        });
-
-        // Set initial value to the first model
-        const firstKey = Object.keys(modelDict)[0];
-        if (firstKey) {
-          sel.value = firstKey;
-          this.cstate.selectedModel = firstKey;
-        }
-      } catch (error) {
-        console.error("Error populating model selector:", error);
-        this.errorMessage = `Failed to load models: ${error.message}`;
-      }
-    },
-
-    async handleImageUpload(event) {
-      const file = event.target.files[0];
-      if (file) {
-        const reader = new FileReader();
-        reader.onload = (e) => {
-          this.imagePreview = e.target.result;
-          this.imageUrl = e.target.result; // Store the image URL
-          // Add image preview to the chat
-          this.cstate.messages.push({
-            role: "user",
-            content: `![Uploaded Image](${this.imagePreview})`,
-          });
-        };
-        reader.readAsDataURL(file);
-      }
-    },
-
-
-    async handleSend() {
-      try {
-        const el = document.getElementById("input-form");
-        const value = el.value.trim();
-        if (!value && !this.imagePreview) return;
-
-        if (this.generating) return;
-        this.generating = true;
-        if (this.home === 0) this.home = 1;
-
-        // ensure that going back in history will go back to home
-        window.history.pushState({}, "", "/");
-
-        // add message to list
-        if (value) {
-          this.cstate.messages.push({ role: "user", content: value });
-        }
-
-        // clear textarea
-        el.value = "";
-        el.style.height = "auto";
-        el.style.height = el.scrollHeight + "px";
-
-        localStorage.setItem("pendingMessage", value);
-        this.processMessage(value);
-      } catch (error) {
-        console.error('error', error);
-        const errorDetails = {
-            message: error.message || 'Unknown error',
-            stack: error.stack,
-            name: error.name || 'Error'
-        };
-        
-        this.errorMessage = {
-            basic: `${errorDetails.name}: ${errorDetails.message}`,
-            stack: errorDetails.stack
-        };
-
-        // Clear any existing timeout
-        if (this.errorTimeout) {
-            clearTimeout(this.errorTimeout);
-        }
-
-        // Only set the timeout if the error details aren't expanded
-        if (!this.errorExpanded) {
-            this.errorTimeout = setTimeout(() => {
-                this.errorMessage = null;
-                this.errorExpanded = false;
-            }, 30 * 1000);
-        }
-        this.generating = false;
-      }
-    },
-
-    async processMessage(value) {
-      try {
-        // reset performance tracking
-        const prefill_start = Date.now();
-        let start_time = 0;
-        let tokens = 0;
-        this.tokens_per_second = 0;
-
-        // prepare messages for API request
-        let apiMessages = this.cstate.messages.map(msg => {
-          if (msg.content.startsWith('![Uploaded Image]')) {
-            return {
-              role: "user",
-              content: [
-                {
-                  type: "image_url",
-                  image_url: {
-                    url: this.imageUrl
-                  }
-                },
-                {
-                  type: "text",
-                  text: value // Use the actual text the user typed
-                }
-              ]
-            };
-          } else {
-            return {
-              role: msg.role,
-              content: msg.content
-            };
-          }
-        });
-        
-        if (this.cstate.selectedModel === "stable-diffusion-2-1-base") {
-          // Send a request to the image generation endpoint
-          console.log(apiMessages[apiMessages.length - 1].content)
-          console.log(this.cstate.selectedModel)  
-          console.log(this.endpoint)
-          const response = await fetch(`${this.endpoint}/image/generations`, {
-            method: "POST",
-            headers: {
-              "Content-Type": "application/json",
-            },
-            body: JSON.stringify({
-              "model": 'stable-diffusion-2-1-base',
-              "prompt": apiMessages[apiMessages.length - 1].content,
-              "image_url": this.imageUrl
-            }),
-          });
-      
-          if (!response.ok) {
-            throw new Error("Failed to fetch");
-          }
-          const reader = response.body.getReader();
-          let done = false;
-          let gottenFirstChunk = false;
-  
-          while (!done) {
-            const { value, done: readerDone } = await reader.read();
-            done = readerDone;
-            const decoder = new TextDecoder();
-  
-            if (value) {
-              // Assume non-binary data (text) comes first
-              const chunk = decoder.decode(value, { stream: true });
-              const parsed = JSON.parse(chunk);
-              console.log(parsed)
-  
-              if (parsed.progress) {
-                if (!gottenFirstChunk) {
-                  this.cstate.messages.push({ role: "assistant", content: "" });
-                  gottenFirstChunk = true;
-                }
-                this.cstate.messages[this.cstate.messages.length - 1].content = parsed.progress;
-              }
-              else if (parsed.images) {
-                if (!gottenFirstChunk) {
-                  this.cstate.messages.push({ role: "assistant", content: "" });
-                  gottenFirstChunk = true;
-                }
-                const imageUrl = parsed.images[0].url;
-                console.log(imageUrl)
-                this.cstate.messages[this.cstate.messages.length - 1].content = `![Generated Image](${imageUrl}?t=${Date.now()})`;
-              }
-            }
-          }
-        }
-        
-        else{        
-          const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
-          if (containsImage) {
-            // Map all messages with string content to object with type text
-            apiMessages = apiMessages.map(msg => {
-              if (typeof msg.content === 'string') {
-                return {
-                  ...msg,
-                  content: [
-                    {
-                      type: "text",
-                      text: msg.content
-                    }
-                  ]
-                };
-              }
-              return msg;
-            });
-          }
-
-          console.log(apiMessages)
-          //start receiving server sent events
-          let gottenFirstChunk = false;
-          for await (
-            const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
-          ) {
-            if (!gottenFirstChunk) {
-              this.cstate.messages.push({ role: "assistant", content: "" });
-              gottenFirstChunk = true;
-            }
-
-            // add chunk to the last message
-            this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
-
-            // calculate performance tracking
-            tokens += 1;
-            this.total_tokens += 1;
-            if (start_time === 0) {
-              start_time = Date.now();
-              this.time_till_first = start_time - prefill_start;
-            } else {
-              const diff = Date.now() - start_time;
-              if (diff > 0) {
-                this.tokens_per_second = tokens / (diff / 1000);
-              }
-            }
-          }
-        }
-        // Clean the cstate before adding it to histories
-        const cleanedCstate = JSON.parse(JSON.stringify(this.cstate));
-        cleanedCstate.messages = cleanedCstate.messages.map(msg => {
-          if (Array.isArray(msg.content)) {
-            return {
-              ...msg,
-              content: msg.content.map(item =>
-                item.type === 'image_url' ? { type: 'image_url', image_url: { url: '[IMAGE_PLACEHOLDER]' } } : item
-              )
-            };
-          }
-          return msg;
-        });
-
-        // Update the state in histories or add it if it doesn't exist
-        const index = this.histories.findIndex((cstate) => cstate.time === cleanedCstate.time);
-        cleanedCstate.time = Date.now();
-        if (index !== -1) {
-          // Update the existing entry
-          this.histories[index] = cleanedCstate;
-        } else {
-          // Add a new entry
-          this.histories.push(cleanedCstate);
-        }
-        console.log(this.histories)
-        // update in local storage
-        try {
-          localStorage.setItem("histories", JSON.stringify(this.histories));
-        } catch (error) {
-          console.error("Failed to save histories to localStorage:", error);
-        }
-      } catch (error) {
-        console.error('error', error);
-        const errorDetails = {
-            message: error.message || 'Unknown error',
-            stack: error.stack,
-            name: error.name || 'Error'
-        };
-        
-        this.errorMessage = {
-            basic: `${errorDetails.name}: ${errorDetails.message}`,
-            stack: errorDetails.stack
-        };
-
-        // Clear any existing timeout
-        if (this.errorTimeout) {
-            clearTimeout(this.errorTimeout);
-        }
-
-        // Only set the timeout if the error details aren't expanded
-        if (!this.errorExpanded) {
-            this.errorTimeout = setTimeout(() => {
-                this.errorMessage = null;
-                this.errorExpanded = false;
-            }, 30 * 1000);
-        }
-      } finally {
-        this.generating = false;
-      }
-    },
-
-    async handleEnter(event) {
-      // if shift is not pressed
-      if (!event.shiftKey) {
-        event.preventDefault();
-        await this.handleSend();
-      }
-    },
-
-    updateTotalTokens(messages) {
-      fetch(`${this.endpoint}/chat/token/encode`, {
-        method: "POST",
-        headers: { "Content-Type": "application/json" },
-        body: JSON.stringify({ messages }),
-      }).then((response) => response.json()).then((data) => {
-        this.total_tokens = data.length;
-      }).catch(console.error);
-    },
-
-    async *openaiChatCompletion(model, messages) {
-      // stream response
-      console.log("model", model)
-      const response = await fetch(`${this.endpoint}/chat/completions`, {
-        method: "POST",
-        headers: {
-          "Content-Type": "application/json",
-        },
-        body: JSON.stringify({
-          "model": model,
-          "messages": messages,
-          "stream": true,
-        }),
-      });
-      if (!response.ok) {
-        const errorResBody = await response.json()
-        if (errorResBody?.detail) {
-          throw new Error(`Failed to fetch completions: ${errorResBody.detail}`);
-        } else {
-          throw new Error("Failed to fetch completions: Unknown error");
-        }
-      }
-
-      const reader = response.body.pipeThrough(new TextDecoderStream())
-        .pipeThrough(new EventSourceParserStream()).getReader();
-      while (true) {
-        const { done, value } = await reader.read();
-        if (done) {
-          break;
-        }
-        if (value.type === "event") {
-          const json = JSON.parse(value.data);
-          if (json.choices) {
-            const choice = json.choices[0];
-            if (choice.finish_reason === "stop") {
-              break;
-            }
-            yield choice.delta.content;
-          }
-        }
-      }
-    },
-
-    async fetchDownloadProgress() {
-      try {
-        const response = await fetch(`${this.endpoint}/download/progress`);
-        if (response.ok) {
-          const data = await response.json();
-          const progressArray = Object.values(data);
-          if (progressArray.length > 0) {
-            this.downloadProgress = progressArray.map(progress => {
-              // Check if download is complete
-              if (progress.status === "complete") {
-                return {
-                  ...progress,
-                  isComplete: true,
-                  percentage: 100
-                };
-              } else if (progress.status === "failed") {
-                return {
-                  ...progress,
-                  isComplete: false,
-                  errorMessage: "Download failed"
-                };
-              } else {
-                return {
-                  ...progress,
-                  isComplete: false,
-                  downloaded_bytes_display: this.formatBytes(progress.downloaded_bytes),
-                  total_bytes_display: this.formatBytes(progress.total_bytes),
-                  overall_speed_display: progress.overall_speed ? this.formatBytes(progress.overall_speed) + '/s' : '',
-                  overall_eta_display: progress.overall_eta ? this.formatDuration(progress.overall_eta) : '',
-                  percentage: ((progress.downloaded_bytes / progress.total_bytes) * 100).toFixed(2)
-                };
-              }
-            });
-            const allComplete = this.downloadProgress.every(progress => progress.isComplete);
-            if (allComplete) {
-              // Check for pendingMessage
-              const savedMessage = localStorage.getItem("pendingMessage");
-              if (savedMessage) {
-                // Clear pendingMessage
-                localStorage.removeItem("pendingMessage");
-                // Call processMessage() with savedMessage
-                if (this.lastErrorMessage) {
-                  await this.processMessage(savedMessage);
-                }
-              }
-              this.lastErrorMessage = null;
-              this.downloadProgress = null;
-            }
-          } else {
-            // No ongoing download
-            this.downloadProgress = null;
-          }
-        }
-      } catch (error) {
-        console.error("Error fetching download progress:", error);
-        this.downloadProgress = null;
-      }
-    },
-
-    startDownloadProgressPolling() {
-      if (this.downloadProgressInterval) {
-        // Already polling
-        return;
-      }
-      this.fetchDownloadProgress(); // Fetch immediately
-      this.downloadProgressInterval = setInterval(() => {
-        this.fetchDownloadProgress();
-      }, 1000); // Poll every second
-    },
-  }));
-});
-
-const { markedHighlight } = globalThis.markedHighlight;
-marked.use(markedHighlight({
-  langPrefix: "hljs language-",
-  highlight(code, lang, _info) {
-    const language = hljs.getLanguage(lang) ? lang : "plaintext";
-    return hljs.highlight(code, { language }).value;
-  },
-}));
-
-// **** eventsource-parser ****
-class EventSourceParserStream extends TransformStream {
-  constructor() {
-    let parser;
-
-    super({
-      start(controller) {
-        parser = createParser((event) => {
-          if (event.type === "event") {
-            controller.enqueue(event);
-          }
-        });
-      },
-
-      transform(chunk) {
-        parser.feed(chunk);
-      },
-    });
-  }
-}
-
-function createParser(onParse) {
-  let isFirstChunk;
-  let buffer;
-  let startingPosition;
-  let startingFieldLength;
-  let eventId;
-  let eventName;
-  let data;
-  reset();
-  return {
-    feed,
-    reset,
-  };
-  function reset() {
-    isFirstChunk = true;
-    buffer = "";
-    startingPosition = 0;
-    startingFieldLength = -1;
-    eventId = void 0;
-    eventName = void 0;
-    data = "";
-  }
-  function feed(chunk) {
-    buffer = buffer ? buffer + chunk : chunk;
-    if (isFirstChunk && hasBom(buffer)) {
-      buffer = buffer.slice(BOM.length);
-    }
-    isFirstChunk = false;
-    const length = buffer.length;
-    let position = 0;
-    let discardTrailingNewline = false;
-    while (position < length) {
-      if (discardTrailingNewline) {
-        if (buffer[position] === "\n") {
-          ++position;
-        }
-        discardTrailingNewline = false;
-      }
-      let lineLength = -1;
-      let fieldLength = startingFieldLength;
-      let character;
-      for (
-        let index = startingPosition;
-        lineLength < 0 && index < length;
-        ++index
-      ) {
-        character = buffer[index];
-        if (character === ":" && fieldLength < 0) {
-          fieldLength = index - position;
-        } else if (character === "\r") {
-          discardTrailingNewline = true;
-          lineLength = index - position;
-        } else if (character === "\n") {
-          lineLength = index - position;
-        }
-      }
-      if (lineLength < 0) {
-        startingPosition = length - position;
-        startingFieldLength = fieldLength;
-        break;
-      } else {
-        startingPosition = 0;
-        startingFieldLength = -1;
-      }
-      parseEventStreamLine(buffer, position, fieldLength, lineLength);
-      position += lineLength + 1;
-    }
-    if (position === length) {
-      buffer = "";
-    } else if (position > 0) {
-      buffer = buffer.slice(position);
-    }
-  }
-  function parseEventStreamLine(lineBuffer, index, fieldLength, lineLength) {
-    if (lineLength === 0) {
-      if (data.length > 0) {
-        onParse({
-          type: "event",
-          id: eventId,
-          event: eventName || void 0,
-          data: data.slice(0, -1),
-          // remove trailing newline
-        });
-
-        data = "";
-        eventId = void 0;
-      }
-      eventName = void 0;
-      return;
-    }
-    const noValue = fieldLength < 0;
-    const field = lineBuffer.slice(
-      index,
-      index + (noValue ? lineLength : fieldLength),
-    );
-    let step = 0;
-    if (noValue) {
-      step = lineLength;
-    } else if (lineBuffer[index + fieldLength + 1] === " ") {
-      step = fieldLength + 2;
-    } else {
-      step = fieldLength + 1;
-    }
-    const position = index + step;
-    const valueLength = lineLength - step;
-    const value = lineBuffer.slice(position, position + valueLength).toString();
-    if (field === "data") {
-      data += value ? "".concat(value, "\n") : "\n";
-    } else if (field === "event") {
-      eventName = value;
-    } else if (field === "id" && !value.includes("\0")) {
-      eventId = value;
-    } else if (field === "retry") {
-      const retry = parseInt(value, 10);
-      if (!Number.isNaN(retry)) {
-        onParse({
-          type: "reconnect-interval",
-          value: retry,
-        });
-      }
-    }
-  }
-}
-
-const BOM = [239, 187, 191];
-function hasBom(buffer) {
-  return BOM.every((charCode, index) => buffer.charCodeAt(index) === charCode);
-}

文件差異過大導致無法顯示
+ 0 - 0
build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/@alpine-collective/toolkit@1.0.2/dist/cdn.min.js


文件差異過大導致無法顯示
+ 0 - 0
build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/@alpinejs/focus@3.x.x/dist/cdn.min.js


+ 0 - 1
build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/@alpinejs/intersect@3.x.x/dist/cdn.min.js

@@ -1 +0,0 @@
-(()=>{function o(e){e.directive("intersect",e.skipDuringClone((t,{value:i,expression:l,modifiers:n},{evaluateLater:r,cleanup:c})=>{let s=r(l),a={rootMargin:x(n),threshold:f(n)},u=new IntersectionObserver(d=>{d.forEach(h=>{h.isIntersecting!==(i==="leave")&&(s(),n.includes("once")&&u.disconnect())})},a);u.observe(t),c(()=>{u.disconnect()})}))}function f(e){if(e.includes("full"))return .99;if(e.includes("half"))return .5;if(!e.includes("threshold"))return 0;let t=e[e.indexOf("threshold")+1];return t==="100"?1:t==="0"?0:Number(`.${t}`)}function p(e){let t=e.match(/^(-?[0-9]+)(px|%)?$/);return t?t[1]+(t[2]||"px"):void 0}function x(e){let t="margin",i="0px 0px 0px 0px",l=e.indexOf(t);if(l===-1)return i;let n=[];for(let r=1;r<5;r++)n.push(p(e[l+r]||""));return n=n.filter(r=>r!==void 0),n.length?n.join(" ").trim():i}document.addEventListener("alpine:init",()=>{window.Alpine.plugin(o)});})();

+ 0 - 11
build/lib/exo/tinychat/static/cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css

@@ -1,11 +0,0 @@
-/*!
-Pure v3.0.0
-Copyright 2013 Yahoo!
-Licensed under the BSD License.
-https://github.com/pure-css/pure/blob/master/LICENSE
-*/
-/*!
-normalize.css v | MIT License | https://necolas.github.io/normalize.css/
-Copyright (c) Nicolas Gallagher and Jonathan Neal
-*/
-/*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */html{line-height:1.15;-webkit-text-size-adjust:100%}body{margin:0}main{display:block}h1{font-size:2em;margin:.67em 0}hr{box-sizing:content-box;height:0;overflow:visible}pre{font-family:monospace,monospace;font-size:1em}a{background-color:transparent}abbr[title]{border-bottom:none;text-decoration:underline;-webkit-text-decoration:underline dotted;text-decoration:underline dotted}b,strong{font-weight:bolder}code,kbd,samp{font-family:monospace,monospace;font-size:1em}small{font-size:80%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sub{bottom:-.25em}sup{top:-.5em}img{border-style:none}button,input,optgroup,select,textarea{font-family:inherit;font-size:100%;line-height:1.15;margin:0}button,input{overflow:visible}button,select{text-transform:none}[type=button],[type=reset],[type=submit],button{-webkit-appearance:button}[type=button]::-moz-focus-inner,[type=reset]::-moz-focus-inner,[type=submit]::-moz-focus-inner,button::-moz-focus-inner{border-style:none;padding:0}[type=button]:-moz-focusring,[type=reset]:-moz-focusring,[type=submit]:-moz-focusring,button:-moz-focusring{outline:1px dotted ButtonText}fieldset{padding:.35em .75em .625em}legend{box-sizing:border-box;color:inherit;display:table;max-width:100%;padding:0;white-space:normal}progress{vertical-align:baseline}textarea{overflow:auto}[type=checkbox],[type=radio]{box-sizing:border-box;padding:0}[type=number]::-webkit-inner-spin-button,[type=number]::-webkit-outer-spin-button{height:auto}[type=search]{-webkit-appearance:textfield;outline-offset:-2px}[type=search]::-webkit-search-decoration{-webkit-appearance:none}::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}details{display:block}summary{display:list-item}template{display:none}[hidden]{display:none}html{font-family:sans-serif}.hidden,[hidden]{display:none!important}.pure-img{max-width:100%;height:auto;display:block}

文件差異過大導致無法顯示
+ 0 - 5
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css


二進制
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.ttf


二進制
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.woff2


二進制
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.ttf


二進制
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.woff2


二進制
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.ttf


二進制
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.woff2


二進制
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.ttf


二進制
build/lib/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.woff2


+ 0 - 7
build/lib/exo/tinychat/static/fonts.googleapis.com/css2

@@ -1,7 +0,0 @@
-@font-face {
-  font-family: 'Megrim';
-  font-style: normal;
-  font-weight: 400;
-  font-display: swap;
-  src: url(https://fonts.gstatic.com/s/megrim/v16/46kulbz5WjvLqJZlbQ.ttf) format('truetype');
-}

文件差異過大導致無法顯示
+ 0 - 316
build/lib/exo/tinychat/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/highlight.min.js


+ 0 - 1
build/lib/exo/tinychat/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css

@@ -1 +0,0 @@
-pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}.hljs{background:#1e1e1e;color:#dcdcdc}.hljs-keyword,.hljs-literal,.hljs-name,.hljs-symbol{color:#569cd6}.hljs-link{color:#569cd6;text-decoration:underline}.hljs-built_in,.hljs-type{color:#4ec9b0}.hljs-class,.hljs-number{color:#b8d7a3}.hljs-meta .hljs-string,.hljs-string{color:#d69d85}.hljs-regexp,.hljs-template-tag{color:#9a5334}.hljs-formula,.hljs-function,.hljs-params,.hljs-subst,.hljs-title{color:#dcdcdc}.hljs-comment,.hljs-quote{color:#57a64a;font-style:italic}.hljs-doctag{color:#608b4e}.hljs-meta,.hljs-meta .hljs-keyword,.hljs-tag{color:#9b9b9b}.hljs-template-variable,.hljs-variable{color:#bd63c5}.hljs-attr,.hljs-attribute{color:#9cdcfe}.hljs-section{color:gold}.hljs-emphasis{font-style:italic}.hljs-strong{font-weight:700}.hljs-bullet,.hljs-selector-attr,.hljs-selector-class,.hljs-selector-id,.hljs-selector-pseudo,.hljs-selector-tag{color:#d7ba7d}.hljs-addition{background-color:#144212;display:inline-block;width:100%}.hljs-deletion{background-color:#600;display:inline-block;width:100%}

文件差異過大導致無法顯示
+ 0 - 0
build/lib/exo/tinychat/static/unpkg.com/@marcreichel/alpine-autosize@1.3.x/dist/alpine-autosize.min.js


文件差異過大導致無法顯示
+ 0 - 0
build/lib/exo/tinychat/static/unpkg.com/alpinejs@3.x.x/dist/cdn.min.js


文件差異過大導致無法顯示
+ 0 - 1
build/lib/exo/tinychat/static/unpkg.com/dompurify@3.1.5/dist/purify.min.js


+ 0 - 97
build/lib/exo/tinychat/static/unpkg.com/marked-highlight@2.1.2/lib/index.umd.js

@@ -1,97 +0,0 @@
-(function (global, factory) {
-  typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) :
-  typeof define === 'function' && define.amd ? define(['exports'], factory) :
-  (global = typeof globalThis !== 'undefined' ? globalThis : global || self, factory(global.markedHighlight = {}));
-})(this, (function (exports) { 'use strict';
-
-  function markedHighlight(options) {
-    if (typeof options === 'function') {
-      options = {
-        highlight: options
-      };
-    }
-
-    if (!options || typeof options.highlight !== 'function') {
-      throw new Error('Must provide highlight function');
-    }
-
-    if (typeof options.langPrefix !== 'string') {
-      options.langPrefix = 'language-';
-    }
-
-    return {
-      async: !!options.async,
-      walkTokens(token) {
-        if (token.type !== 'code') {
-          return;
-        }
-
-        const lang = getLang(token.lang);
-
-        if (options.async) {
-          return Promise.resolve(options.highlight(token.text, lang, token.lang || '')).then(updateToken(token));
-        }
-
-        const code = options.highlight(token.text, lang, token.lang || '');
-        if (code instanceof Promise) {
-          throw new Error('markedHighlight is not set to async but the highlight function is async. Set the async option to true on markedHighlight to await the async highlight function.');
-        }
-        updateToken(token)(code);
-      },
-      useNewRenderer: true,
-      renderer: {
-        code({ text, lang, escaped }) {
-          const language = getLang(lang);
-          const classAttr = language
-            ? ` class="${options.langPrefix}${escape(language)}"`
-            : '';
-          text = text.replace(/\n$/, '');
-          return `<pre><code${classAttr}>${escaped ? text : escape(text, true)}\n</code></pre>`;
-        }
-      }
-    };
-  }
-
-  function getLang(lang) {
-    return (lang || '').match(/\S*/)[0];
-  }
-
-  function updateToken(token) {
-    return (code) => {
-      if (typeof code === 'string' && code !== token.text) {
-        token.escaped = true;
-        token.text = code;
-      }
-    };
-  }
-
-  // copied from marked helpers
-  const escapeTest = /[&<>"']/;
-  const escapeReplace = new RegExp(escapeTest.source, 'g');
-  const escapeTestNoEncode = /[<>"']|&(?!(#\d{1,7}|#[Xx][a-fA-F0-9]{1,6}|\w+);)/;
-  const escapeReplaceNoEncode = new RegExp(escapeTestNoEncode.source, 'g');
-  const escapeReplacements = {
-    '&': '&amp;',
-    '<': '&lt;',
-    '>': '&gt;',
-    '"': '&quot;',
-    "'": '&#39;'
-  };
-  const getEscapeReplacement = (ch) => escapeReplacements[ch];
-  function escape(html, encode) {
-    if (encode) {
-      if (escapeTest.test(html)) {
-        return html.replace(escapeReplace, getEscapeReplacement);
-      }
-    } else {
-      if (escapeTestNoEncode.test(html)) {
-        return html.replace(escapeReplaceNoEncode, getEscapeReplacement);
-      }
-    }
-
-    return html;
-  }
-
-  exports.markedHighlight = markedHighlight;
-
-}));

文件差異過大導致無法顯示
+ 0 - 5
build/lib/exo/tinychat/static/unpkg.com/marked@13.0.0/marked.min.js


+ 0 - 93
build/lib/exo/tinychat/update_deps.py

@@ -1,93 +0,0 @@
-import os
-import requests
-from bs4 import BeautifulSoup
-from urllib.parse import urljoin, urlparse
-import re
-
-
-def download_file(url, local_path):
-  response = requests.get(url)
-  if response.status_code == 200:
-    os.makedirs(os.path.dirname(local_path), exist_ok=True)
-    with open(local_path, 'wb') as f:
-      f.write(response.content)
-    print(f"Downloaded: {local_path}")
-  else:
-    print(response.status_code)
-    print(f"Failed to download: {url}")
-
-
-def update_html(html_content, base_url):
-  soup = BeautifulSoup(html_content, 'html.parser')
-
-  for tag in soup.find_all(['script', 'link']):
-    if tag.has_attr('src'):
-      url = tag['src']
-    elif tag.has_attr('href'):
-      url = tag['href']
-    else:
-      continue
-
-    if url.startswith(('http://', 'https://')):
-      full_url = url
-    else:
-      full_url = urljoin(base_url, url)
-
-    parsed_url = urlparse(full_url)
-    local_path = os.path.join('static', parsed_url.netloc, parsed_url.path.lstrip('/'))
-
-    download_file(full_url, local_path)
-
-    relative_path = os.path.relpath(local_path, '.')
-    if tag.name == 'script':
-      tag['src'] = "/" + relative_path
-    elif tag.name == 'link':
-      tag['href'] = "/" + relative_path
-
-  return str(soup)
-
-
-# Read the HTML file
-with open('./index.html', 'r') as f:
-  html_content = f.read()
-
-# Update HTML and download files
-# updated_html = update_html(html_content, 'https://example.com')
-
-# # Write the updated HTML
-# with open('./index.html', 'w') as f:
-#     f.write(updated_html)
-
-print("HTML file updated with local paths.")
-
-# Download Font Awesome CSS and font files
-base_url = "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/"
-css_url = urljoin(base_url, "css/all.min.css")
-output_dir = "static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2"
-
-# Download CSS file
-css_output_path = os.path.join(output_dir, "css", "all.min.css")
-download_file(css_url, css_output_path)
-
-# Parse CSS file for font URLs
-with open(css_output_path, 'r', encoding='utf-8') as f:
-  css_content = f.read()
-
-# Extract font URLs from the CSS content
-font_urls = re.findall(r'url\((.*?\.(?:woff2|ttf))\)', css_content)
-
-print(f"Found {len(font_urls)} font URLs")
-
-# Download font files
-for font_url in font_urls:
-  font_url = font_url.strip('"\'')
-  if font_url.startswith('../'):
-    font_url = font_url[3:]
-
-  # Use base_url instead of urljoin to keep the version number
-  full_url = base_url + font_url
-  relative_path = font_url
-  output_path = os.path.join(output_dir, relative_path)
-  download_file(full_url, output_path)
-
-print("Download complete!")

+ 0 - 0
build/lib/exo/topology/__init__.py


+ 0 - 217
build/lib/exo/topology/device_capabilities.py

@@ -1,217 +0,0 @@
-from typing import Any
-from pydantic import BaseModel
-from exo import DEBUG
-import subprocess
-import psutil
-
-TFLOPS = 1.00
-
-
-class DeviceFlops(BaseModel):
-  # units of TFLOPS
-  fp32: float
-  fp16: float
-  int8: float
-
-  def __str__(self):
-    return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS"
-
-  def to_dict(self):
-    return self.model_dump()
-
-
-class DeviceCapabilities(BaseModel):
-  model: str
-  chip: str
-  memory: int
-  flops: DeviceFlops
-
-  def __str__(self):
-    return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}"
-
-  def model_post_init(self, __context: Any) -> None:
-    if isinstance(self.flops, dict):
-      self.flops = DeviceFlops(**self.flops)
-
-  def to_dict(self):
-    return {"model": self.model, "chip": self.chip, "memory": self.memory, "flops": self.flops.to_dict()}
-
-
-UNKNOWN_DEVICE_CAPABILITIES = DeviceCapabilities(model="Unknown Model", chip="Unknown Chip", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
-
-CHIP_FLOPS = {
-  # Source: https://www.cpu-monkey.com
-  # Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative
-  ### M chips
-  "Apple M1": DeviceFlops(fp32=2.29*TFLOPS, fp16=4.58*TFLOPS, int8=9.16*TFLOPS),
-  "Apple M1 Pro": DeviceFlops(fp32=5.30*TFLOPS, fp16=10.60*TFLOPS, int8=21.20*TFLOPS),
-  "Apple M1 Max": DeviceFlops(fp32=10.60*TFLOPS, fp16=21.20*TFLOPS, int8=42.40*TFLOPS),
-  "Apple M1 Ultra": DeviceFlops(fp32=21.20*TFLOPS, fp16=42.40*TFLOPS, int8=84.80*TFLOPS),
-  "Apple M2": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
-  "Apple M2 Pro": DeviceFlops(fp32=5.68*TFLOPS, fp16=11.36*TFLOPS, int8=22.72*TFLOPS),
-  "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS),
-  "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS),
-  "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
-  "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS),
-  "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
-  "Apple M4": DeviceFlops(fp32=4.26*TFLOPS, fp16=8.52*TFLOPS, int8=17.04*TFLOPS),
-  "Apple M4 Pro": DeviceFlops(fp32=5.72*TFLOPS, fp16=11.44*TFLOPS, int8=22.88*TFLOPS),
-  "Apple M4 Max": DeviceFlops(fp32=18.03*TFLOPS, fp16=36.07*TFLOPS, int8=72.14*TFLOPS),
-  ### A chips
-  "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS),
-  "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS),
-  "Apple A15 Bionic": DeviceFlops(fp32=1.37*TFLOPS, fp16=2.74*TFLOPS, int8=5.48*TFLOPS),
-  "Apple A16 Bionic": DeviceFlops(fp32=1.79*TFLOPS, fp16=3.58*TFLOPS, int8=7.16*TFLOPS),
-  "Apple A17 Pro": DeviceFlops(fp32=2.15*TFLOPS, fp16=4.30*TFLOPS, int8=8.60*TFLOPS),
-  ### NVIDIA GPUs
-  # RTX 40 series
-  "NVIDIA GEFORCE RTX 4090": DeviceFlops(fp32=82.58*TFLOPS, fp16=165.16*TFLOPS, int8=330.32*TFLOPS),
-  "NVIDIA GEFORCE RTX 4080": DeviceFlops(fp32=48.74*TFLOPS, fp16=97.48*TFLOPS, int8=194.96*TFLOPS),
-  "NVIDIA GEFORCE RTX 4080 SUPER": DeviceFlops(fp32=52.0*TFLOPS, fp16=104.0*TFLOPS, int8=208.0*TFLOPS),
-  "NVIDIA GEFORCE RTX 4070 TI SUPER": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS),
-  "NVIDIA GEFORCE RTX 4070 TI": DeviceFlops(fp32=39.43*TFLOPS, fp16=78.86*TFLOPS, int8=157.72*TFLOPS),
-  "NVIDIA GEFORCE RTX 4070 SUPER": DeviceFlops(fp32=30.0*TFLOPS, fp16=60.0*TFLOPS, int8=120.0*TFLOPS),
-  "NVIDIA GEFORCE RTX 4070": DeviceFlops(fp32=29.0*TFLOPS, fp16=58.0*TFLOPS, int8=116.0*TFLOPS),
-  "NVIDIA GEFORCE RTX 4060 TI 16GB": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS),
-  "NVIDIA GEFORCE RTX 4060 TI": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS),
-  # RTX 30 series
-  "NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11*TFLOPS, fp16=18.22*TFLOPS, int8=36.44*TFLOPS),
-  "NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0*TFLOPS, fp16=26.0*TFLOPS, int8=52.0*TFLOPS),
-  "NVIDIA GEFORCE RTX 3060 TI": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS),
-  "NVIDIA GEFORCE RTX 3070": DeviceFlops(fp32=20.3*TFLOPS, fp16=40.6*TFLOPS, int8=81.2*TFLOPS),
-  "NVIDIA GEFORCE RTX 3070 TI": DeviceFlops(fp32=21.8*TFLOPS, fp16=43.6*TFLOPS, int8=87.2*TFLOPS),
-  "NVIDIA GEFORCE RTX 3080 (10 GB)": DeviceFlops(fp32=29.8*TFLOPS, fp16=59.6*TFLOPS, int8=119.2*TFLOPS),
-  "NVIDIA GEFORCE RTX 3080 (12 GB)": DeviceFlops(fp32=30.6*TFLOPS, fp16=61.2*TFLOPS, int8=122.4*TFLOPS),
-  "NVIDIA GEFORCE RTX 3080 TI": DeviceFlops(fp32=34.1*TFLOPS, fp16=68.2*TFLOPS, int8=136.4*TFLOPS),
-  "NVIDIA GEFORCE RTX 3090": DeviceFlops(fp32=35.6*TFLOPS, fp16=71.2*TFLOPS, int8=142.4*TFLOPS),
-  "NVIDIA GEFORCE RTX 3090 TI": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS),
-  # RTX 20 series
-  "NVIDIA GEFORCE RTX 2060": DeviceFlops(fp32=6.45*TFLOPS, fp16=12.9*TFLOPS, int8=25.8*TFLOPS),
-  "NVIDIA GEFORCE RTX 2060 SUPER": DeviceFlops(fp32=7.2*TFLOPS, fp16=14.4*TFLOPS, int8=28.8*TFLOPS),
-  "NVIDIA GEFORCE RTX 2070": DeviceFlops(fp32=7.46*TFLOPS, fp16=14.93*TFLOPS, int8=29.86*TFLOPS),
-  "NVIDIA GEFORCE RTX 2070 SUPER": DeviceFlops(fp32=9.06*TFLOPS, fp16=18.12*TFLOPS, int8=36.24*TFLOPS),
-  "NVIDIA GEFORCE RTX 2080": DeviceFlops(fp32=10.07*TFLOPS, fp16=20.14*TFLOPS, int8=40.28*TFLOPS),
-  "NVIDIA GEFORCE RTX 2080 TI": DeviceFlops(fp32=13.45*TFLOPS, fp16=26.9*TFLOPS, int8=40.28*TFLOPS),
-  "NVIDIA GEFORCE RTX 2080 SUPER": DeviceFlops(fp32=11.15*TFLOPS, fp16=22.30*TFLOPS, int8=44.60*TFLOPS),
-  "NVIDIA TITAN RTX": DeviceFlops(fp32=16.31*TFLOPS, fp16=32.62*TFLOPS, int8=65.24*TFLOPS),
-  # GTX 10 series
-  "NVIDIA GEFORCE GTX 1050 TI": DeviceFlops(fp32=2.0*TFLOPS, fp16=4.0*TFLOPS, int8=8.0*TFLOPS),
-  "NVIDIA GEFORCE GTX 1070": DeviceFlops(fp32=6.463*TFLOPS, fp16=0.101*TFLOPS, int8=25.852*TFLOPS),
-  "NVIDIA GEFORCE GTX 1080": DeviceFlops(fp32=8.873*TFLOPS, fp16=0.138*TFLOPS, int8=35.492*TFLOPS),
-  "NVIDIA GEFORCE GTX 1080 TI": DeviceFlops(fp32=11.34*TFLOPS, fp16=0.177*TFLOPS, int8=45.36*TFLOPS),
-  # GTX 16 series
-  "NVIDIA GeForce GTX 1660 TI": DeviceFlops(fp32=4.8*TFLOPS, fp16=9.6*TFLOPS, int8=19.2*TFLOPS),
-  # QUADRO RTX Ampere series
-  "NVIDIA RTX A2000": DeviceFlops(fp32=7.99*TFLOPS, fp16=7.99*TFLOPS, int8=31.91*TFLOPS),
-  "NVIDIA RTX A4000": DeviceFlops(fp32=19.17*TFLOPS, fp16=19.17*TFLOPS, int8=76.68*TFLOPS),
-  "NVIDIA RTX A4500": DeviceFlops(fp32=23.65*TFLOPS, fp16=23.65*TFLOPS, int8=94.6*TFLOPS),
-  "NVIDIA RTX A5000": DeviceFlops(fp32=27.8*TFLOPS, fp16=27.8*TFLOPS, int8=111.2*TFLOPS),
-  "NVIDIA RTX A6000": DeviceFlops(fp32=38.71*TFLOPS, fp16=38.71*TFLOPS, int8=154.84*TFLOPS),
-  # NVIDIA Ada Lovelace Architecture-Based
-  "NVIDIA RTX 4000 ADA GENERATION": DeviceFlops(fp32=26.7*TFLOPS, fp16=26.7*TFLOPS, int8=258.0*TFLOPS),
-  # Common Server GPUs
-  "NVIDIA A40 48GB PCIE": DeviceFlops(fp32=37.4*TFLOPS, fp16=149.7*TFLOPS, int8=299.3*TFLOPS),
-  "NVIDIA A100 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
-  "NVIDIA A800 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
-  "NVIDIA A100 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
-  "NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
-  "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
-  "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
-  # ... add more devices if needed ...
-  ### AMD GPUs
-  # RX 6000 series
-  "AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04*TFLOPS, fp16=46.08*TFLOPS, int8=92.16*TFLOPS),
-  "AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74*TFLOPS, fp16=41.48*TFLOPS, int8=82.96*TFLOPS),
-  "AMD Radeon RX 6800": DeviceFlops(fp32=16.17*TFLOPS, fp16=32.34*TFLOPS, int8=64.68*TFLOPS),
-  "AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21*TFLOPS, fp16=26.42*TFLOPS, int8=52.84*TFLOPS),
-  "AMD Radeon RX 6700": DeviceFlops(fp32=11.4*TFLOPS, fp16=22.8*TFLOPS, int8=45.6*TFLOPS),
-  "AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6*TFLOPS, fp16=21.2*TFLOPS, int8=42.4*TFLOPS),
-  "AMD Radeon RX 6600": DeviceFlops(fp32=8.93*TFLOPS, fp16=17.86*TFLOPS, int8=35.72*TFLOPS),
-  "AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77*TFLOPS, fp16=11.54*TFLOPS, int8=23.08*TFLOPS),
-  "AMD Radeon RX 6400": DeviceFlops(fp32=3.57*TFLOPS, fp16=7.14*TFLOPS, int8=14.28*TFLOPS),
-  # RX 7000 series
-  "AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4*TFLOPS, fp16=122.8*TFLOPS, int8=245.6*TFLOPS),
-  "AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4*TFLOPS, fp16=106.8*TFLOPS, int8=213.6*TFLOPS),
-  "AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6*TFLOPS, fp16=85.2*TFLOPS, int8=170.4*TFLOPS),
-  "AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2*TFLOPS, fp16=68.4*TFLOPS, int8=136.8*TFLOPS),
-  "AMD Radeon RX 7600": DeviceFlops(fp32=21.5*TFLOPS, fp16=43.0*TFLOPS, int8=86.0*TFLOPS),
-  "AMD Radeon RX 7500": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS),
-  ### Qualcomm embedded chips: TODO
-}
-CHIP_FLOPS.update({f"LAPTOP GPU {key}": value for key, value in CHIP_FLOPS.items()})
-CHIP_FLOPS.update({f"Laptop GPU {key}": value for key, value in CHIP_FLOPS.items()})
-CHIP_FLOPS.update({f"{key} LAPTOP GPU": value for key, value in CHIP_FLOPS.items()})
-CHIP_FLOPS.update({f"{key} Laptop GPU": value for key, value in CHIP_FLOPS.items()})
-
-
-def device_capabilities() -> DeviceCapabilities:
-  if psutil.MACOS:
-    return mac_device_capabilities()
-  elif psutil.LINUX:
-    return linux_device_capabilities()
-  else:
-    return DeviceCapabilities(
-      model="Unknown Device",
-      chip="Unknown Chip",
-      memory=psutil.virtual_memory().total // 2**20,
-      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
-    )
-
-
-def mac_device_capabilities() -> DeviceCapabilities:
-  # Fetch the model of the Mac using system_profiler
-  model = subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8")
-  model_line = next((line for line in model.split("\n") if "Model Name" in line), None)
-  model_id = model_line.split(": ")[1] if model_line else "Unknown Model"
-  chip_line = next((line for line in model.split("\n") if "Chip" in line), None)
-  chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
-  memory_line = next((line for line in model.split("\n") if "Memory" in line), None)
-  memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory"
-  memory_units = memory_str.split()
-  memory_value = int(memory_units[0])
-  if memory_units[1] == "GB":
-    memory = memory_value*1024
-  else:
-    memory = memory_value
-
-  # Assuming static values for other attributes for demonstration
-  return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0)))
-
-
-def linux_device_capabilities() -> DeviceCapabilities:
-  import psutil
-  from tinygrad import Device
-
-  if DEBUG >= 2: print(f"tinygrad {Device.DEFAULT=}")
-  if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT == "GPU":
-    import pynvml
-
-    pynvml.nvmlInit()
-    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
-    gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
-    gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
-    gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
-
-    if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
-
-    return DeviceCapabilities(
-      model=f"Linux Box ({gpu_name})",
-      chip=gpu_name,
-      memory=gpu_memory_info.total // 2**20,
-      flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
-    )
-  elif Device.DEFAULT == "AMD":
-    # TODO AMD support
-    return DeviceCapabilities(
-      model="Linux Box (AMD)",
-      chip="Unknown AMD",
-      memory=psutil.virtual_memory().total // 2**20,
-      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
-    )
-  else:
-    return DeviceCapabilities(
-      model=f"Linux Box (Device: {Device.DEFAULT})",
-      chip=f"Unknown Chip (Device: {Device.DEFAULT})",
-      memory=psutil.virtual_memory().total // 2**20,
-      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
-    )

部分文件因文件數量過多而無法顯示