Forráskód Böngészése

Merge branch 'main' into runners2

Alex Cheema 3 hónapja
szülő
commit
218c1e79d9
39 módosított fájl, 2974 hozzáadás és 249 törlés
  1. 2 0
      .gitignore
  2. 8 6
      README.md
  3. 1 1
      configure_mlx.sh
  4. 111 0
      examples/function_calling.py
  5. 157 37
      exo/api/chatgpt_api.py
  6. 4 0
      exo/download/hf/hf_helpers.py
  7. 11 7
      exo/download/hf/hf_shard_download.py
  8. 18 1
      exo/helpers.py
  9. 8 4
      exo/inference/inference_engine.py
  10. 307 0
      exo/inference/mlx/models/StableDiffusionPipeline.py
  11. 117 0
      exo/inference/mlx/models/phi3.py
  12. 4 3
      exo/inference/mlx/models/qwen2.py
  13. 191 0
      exo/inference/mlx/models/sd_models/clip.py
  14. 131 0
      exo/inference/mlx/models/sd_models/tokenizer.py
  15. 629 0
      exo/inference/mlx/models/sd_models/unet.py
  16. 429 0
      exo/inference/mlx/models/sd_models/vae.py
  17. 35 18
      exo/inference/mlx/sharded_inference_engine.py
  18. 56 15
      exo/inference/mlx/sharded_utils.py
  19. 81 0
      exo/inference/mlx/test_non_blocking.py
  20. 3 3
      exo/inference/tinygrad/inference.py
  21. 1 1
      exo/inference/tokenizers.py
  22. 6 4
      exo/main.py
  23. 20 6
      exo/models.py
  24. 49 7
      exo/networking/grpc/grpc_peer_handle.py
  25. 39 8
      exo/networking/grpc/grpc_server.py
  26. 15 2
      exo/networking/grpc/node_service.proto
  27. 0 0
      exo/networking/grpc/node_service_pb2.py
  28. 51 22
      exo/networking/manual/manual_discovery.py
  29. 1 1
      exo/networking/manual/test_data/test_config.json
  30. 84 6
      exo/networking/manual/test_manual_discovery.py
  31. 61 39
      exo/orchestration/node.py
  32. 166 0
      exo/orchestration/tracing.py
  33. 20 2
      exo/tinychat/index.html
  34. 96 39
      exo/tinychat/index.js
  35. 58 13
      exo/viz/topology_viz.py
  36. 1 1
      install.sh
  37. 1 1
      scripts/compile_grpc.sh
  38. 1 1
      test/reconnect.sh
  39. 1 1
      test/test_tokenizers.py

+ 2 - 0
.gitignore

@@ -171,3 +171,5 @@ cython_debug/
 
 **/*.xcodeproj/*
 .aider*
+
+exo/tinychat/images/*.png

+ 8 - 6
README.md

@@ -18,6 +18,8 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l
 [![Tests](https://dl.circleci.com/status-badge/img/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main)
 [![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)
 
+<a href="https://trendshift.io/repositories/11849" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11849" alt="exo-explore%2Fexo | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
+
 </div>
 
 ---
@@ -38,7 +40,7 @@ We also welcome contributions from the community. We have a list of bounties in
 
 ### Wide Model Support
 
-exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen and Deepseek.
+exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen, and Deepseek.
 
 ### Dynamic Model Partitioning
 
@@ -100,13 +102,13 @@ source install.sh
 
 - There are a number of things users have empirically found to improve performance on Apple Silicon Macs:
 
-1. Upgrade to the latest version of MacOS 15.
+1. Upgrade to the latest version of macOS Sequoia.
 2. Run `./configure_mlx.sh`. This runs commands to optimize GPU memory allocation on Apple Silicon Macs.
 
 
 ## Documentation
 
-### Example Usage on Multiple MacOS Devices
+### Example Usage on Multiple macOS Devices
 
 #### Device 1:
 
@@ -177,9 +179,9 @@ curl http://localhost:52415/v1/chat/completions \
    }'
 ```
 
-### Example Usage on Multiple Heterogenous Devices (MacOS + Linux)
+### Example Usage on Multiple Heterogenous Devices (macOS + Linux)
 
-#### Device 1 (MacOS):
+#### Device 1 (macOS):
 
 ```sh
 exo
@@ -244,7 +246,7 @@ python3 format.py ./exo
 
 ## Known Issues
 
-- On some versions of MacOS/Python, certificates are not installed properly which can lead to SSL errors (e.g. SSL error with huggingface.co). To fix this, run the Install Certificates command, usually:
+- On certain versions of Python on macOS, certificates may not installed correctly, potentially causing SSL errors (e.g., when accessing huggingface.co). To resolve this, run the `Install Certificates` command, typicall as follows:
 
 ```sh
 /Applications/Python 3.x/Install Certificates.command

+ 1 - 1
configure_mlx.sh

@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
 
 # Get the total memory in MB
 TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))

+ 111 - 0
examples/function_calling.py

@@ -0,0 +1,111 @@
+import json
+import re
+import requests
+
+def get_current_weather(location: str, unit: str = "celsius"):
+  """Mock weather data function"""
+  # Hardcoded response for demo purposes
+  return {
+    "location": location,
+    "temperature": 22 if unit == "celsius" else 72,
+    "unit": unit,
+    "forecast": "Sunny with light clouds"
+  }
+
+def try_parse_tool_calls(content: str):
+  """Try parse the tool calls."""
+  tool_calls = []
+  offset = 0
+  for i, m in enumerate(re.finditer(r"<tool_call>\n(.+)?\n</tool_call>", content)):
+    if i == 0:
+      offset = m.start()
+    try:
+      func = json.loads(m.group(1))
+      tool_calls.append({"type": "function", "function": func})
+      if isinstance(func["arguments"], str):
+        func["arguments"] = json.loads(func["arguments"])
+    except json.JSONDecodeError as e:
+      print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}")
+      pass
+  if tool_calls:
+    if offset > 0 and content[:offset].strip():
+      c = content[:offset]
+    else:
+      c = ""
+    return {"role": "assistant", "content": c, "tool_calls": tool_calls}
+  return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)}
+
+def chat_completion(messages):
+  """Send chat completion request to local server"""
+  response = requests.post(
+    "http://localhost:52415/v1/chat/completions",
+    json={
+      "model": "qwen-2.5-1.5b",
+      "messages": messages,
+      "tools": [{
+        "type": "function",
+        "function": {
+          "name": "get_current_weather",
+          "description": "Get the current weather in a given location",
+          "parameters": {
+            "type": "object",
+            "properties": {
+              "location": {
+                "type": "string",
+                "description": "The city and state, e.g. San Francisco, CA"
+              },
+              "unit": {
+                "type": "string",
+                "enum": ["celsius", "fahrenheit"]
+              }
+            },
+            "required": ["location"]
+          }
+        }
+      }],
+      "tool_choice": "auto"
+    }
+  )
+  return response.json()
+
+def main():
+  # Initial conversation
+  messages = [{
+    "role": "user",
+    "content": "Hi there, what's the weather in Boston?"
+  }]
+  
+  # Get initial response
+  response = chat_completion(messages)
+  print(f"First response: {response}")
+  assistant_message = try_parse_tool_calls(response["choices"][0]["message"]["content"])
+  messages.append(assistant_message)
+  
+  # If there are tool calls, execute them and continue conversation
+  if "tool_calls" in assistant_message:
+    for tool_call in assistant_message["tool_calls"]:
+      if tool_call["function"]["name"] == "get_current_weather":
+        args = tool_call["function"]["arguments"]
+        weather_data = get_current_weather(**args)
+        
+        # Add tool response to messages
+        messages.append({
+          "role": "tool",
+          "content": json.dumps(weather_data),
+          "name": tool_call["function"]["name"]
+        })
+    
+    # Get final response with weather data
+    response = chat_completion(messages)
+    print(f"Final response: {response}")
+    messages.append({
+      "role": "assistant",
+      "content": response["choices"][0]["message"]["content"]
+    })
+  
+  # Print full conversation
+  for msg in messages:
+    print(f"\n{msg['role'].upper()}: {msg['content']}")
+
+if __name__ == "__main__":
+  main()

+ 157 - 37
exo/api/chatgpt_api.py

@@ -5,18 +5,24 @@ import json
 import os
 from pathlib import Path
 from transformers import AutoTokenizer
-from typing import List, Literal, Union, Dict
+from typing import List, Literal, Union, Dict, Optional
 from aiohttp import web
 import aiohttp_cors
 import traceback
 import signal
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict, shutdown
+from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable, Optional
+from PIL import Image
+import numpy as np
+import base64
+from io import BytesIO
+import mlx.core as mx
+import tempfile
 from exo.download.hf.hf_shard_download import HFShardDownloader
 import shutil
 from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
@@ -24,23 +30,28 @@ from exo.apputil import create_animation_mp4
 from collections import defaultdict
 
 class Message:
-  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
+  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
     self.role = role
     self.content = content
+    self.tools = tools
 
   def to_dict(self):
-    return {"role": self.role, "content": self.content}
+    data = {"role": self.role, "content": self.content}
+    if self.tools:
+      data["tools"] = self.tools
+    return data
 
 
 
 class ChatCompletionRequest:
-  def __init__(self, model: str, messages: List[Message], temperature: float):
+  def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
     self.model = model
     self.messages = messages
     self.temperature = temperature
+    self.tools = tools
 
   def to_dict(self):
-    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
+    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}
 
 
 def generate_completion(
@@ -120,20 +131,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:
   return remapped_messages
 
 
-def build_prompt(tokenizer, _messages: List[Message]):
+def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
   messages = remap_messages(_messages)
-  prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
-  for message in messages:
-    if not isinstance(message.content, list):
-      continue
+  chat_template_args = {
+    "conversation": [m.to_dict() for m in messages],
+    "tokenize": False,
+    "add_generation_prompt": True
+  }
+  if tools: chat_template_args["tools"] = tools
 
+  prompt = tokenizer.apply_chat_template(**chat_template_args)
+  print(f"!!! Prompt: {prompt}")
   return prompt
 
 
 def parse_message(data: dict):
   if "role" not in data or "content" not in data:
     raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
-  return Message(data["role"], data["content"])
+  return Message(data["role"], data["content"], data.get("tools"))
 
 
 def parse_chat_request(data: dict, default_model: str):
@@ -141,6 +156,7 @@ def parse_chat_request(data: dict, default_model: str):
     data.get("model", default_model),
     [parse_message(msg) for msg in data["messages"]],
     data.get("temperature", 0.0),
+    data.get("tools", None),
   )
 
 
@@ -151,7 +167,7 @@ class PromptSession:
     self.prompt = prompt
 
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -166,6 +182,7 @@ class ChatGPTAPI:
     # Get the callback system and register our handler
     self.token_callback = node.on_token.register("chatgpt-api-token-handler")
     self.token_callback.on_next(lambda _request_id, token, is_finished: asyncio.create_task(self.handle_token(_request_id, token, is_finished)))
+    self.system_prompt = system_prompt
 
     cors = aiohttp_cors.setup(self.app)
     cors_options = aiohttp_cors.ResourceOptions(
@@ -180,6 +197,7 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
     cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
+    cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
@@ -191,10 +209,12 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options})
     cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
 
+      
     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
       self.app.router.add_get("/", self.handle_root)
       self.app.router.add_static("/", self.static_dir, name="static")
+      self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
@@ -241,7 +261,7 @@ class ChatGPTAPI:
         )
         await response.prepare(request)
 
-        for model_name, pretty in pretty_name.items():
+        async def process_model(model_name, pretty):
             if model_name in model_cards:
                 model_info = model_cards[model_name]
 
@@ -269,6 +289,12 @@ class ChatGPTAPI:
 
                         await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
 
+        # Process all models in parallel
+        await asyncio.gather(*[
+            process_model(model_name, pretty)
+            for model_name, pretty in pretty_name.items()
+        ])
+
         await response.write(b"data: [DONE]\n\n")
         return response
 
@@ -281,7 +307,8 @@ class ChatGPTAPI:
         )
 
   async def handle_get_models(self, request):
-    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
+    models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
+    return web.json_response({"object": "list", "data": models_list})
 
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
@@ -294,7 +321,7 @@ class ChatGPTAPI:
     shard = build_base_shard(model, self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    prompt = build_prompt(tokenizer, messages)
+    prompt = build_prompt(tokenizer, messages, data.get("tools", None))
     tokens = tokenizer.encode(prompt)
     return web.json_response({
       "length": len(prompt),
@@ -314,13 +341,13 @@ class ChatGPTAPI:
 
   async def handle_post_chat_completions(self, request):
     data = await request.json()
-    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
+    if DEBUG >= 2: print(f"[ChatGPTAPI] Handling chat completions request from {request.remote}: {data}")
     stream = data.get("stream", False)
     chat_request = parse_chat_request(data, self.default_model)
     if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to default model
       chat_request.model = self.default_model
     if not chat_request.model or chat_request.model not in model_cards:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
+      if DEBUG >= 1: print(f"[ChatGPTAPI] Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
       chat_request.model = self.default_model
     shard = build_base_shard(chat_request.model, self.inference_engine_classname)
     if not shard:
@@ -331,34 +358,26 @@ class ChatGPTAPI:
       )
 
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
+    if DEBUG >= 4: print(f"[ChatGPTAPI] Resolved tokenizer: {tokenizer}")
+
+    # Add system prompt if set
+    if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
+      chat_request.messages.insert(0, Message("system", self.system_prompt))
 
-    prompt = build_prompt(tokenizer, chat_request.messages)
+    prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
     request_id = str(uuid.uuid4())
     if self.on_chat_completion_request:
       try:
         self.on_chat_completion_request(request_id, chat_request, prompt)
       except Exception as e:
         if DEBUG >= 2: traceback.print_exc()
-    # request_id = None
-    # match = self.prompts.find_longest_prefix(prompt)
-    # if match and len(prompt) > len(match[1].prompt):
-    #     if DEBUG >= 2:
-    #       print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
-    #     request_id = match[1].request_id
-    #     self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
-    #     # remove the matching prefix from the prompt
-    #     prompt = prompt[len(match[1].prompt):]
-    # else:
-    #   request_id = str(uuid.uuid4())
-    #   self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
-
-    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
+
+    if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}")
 
     try:
       await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
 
-      if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
+      if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. timeout={self.response_timeout}s")
 
       if stream:
         response = web.StreamResponse(
@@ -374,10 +393,12 @@ class ChatGPTAPI:
         try:
           # Stream tokens while waiting for inference to complete
           while True:
+            if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
             token, is_finished = await asyncio.wait_for(
               self.token_queues[request_id].get(),
               timeout=self.response_timeout
             )
+            if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {token=} {is_finished=}")
 
             finish_reason = None
             eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None)
@@ -408,10 +429,13 @@ class ChatGPTAPI:
           return response
 
         except asyncio.TimeoutError:
+          if DEBUG >= 2: print(f"[ChatGPTAPI] Timeout waiting for token: {request_id=}")
           return web.json_response({"detail": "Response generation timed out"}, status=408)
 
         except Exception as e:
-          if DEBUG >= 2: traceback.print_exc()
+          if DEBUG >= 2: 
+            print(f"[ChatGPTAPI] Error processing prompt: {e}")
+            traceback.print_exc()
           return web.json_response(
             {"detail": f"Error processing prompt: {str(e)}"},
             status=500
@@ -420,6 +444,7 @@ class ChatGPTAPI:
         finally:
           # Clean up the queue for this request
           if request_id in self.token_queues:
+            if DEBUG >= 2: print(f"[ChatGPTAPI] Cleaning up token queue: {request_id=}")
             del self.token_queues[request_id]
       else:
         tokens = []
@@ -441,6 +466,85 @@ class ChatGPTAPI:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
 
+  
+  async def handle_post_image_generations(self, request):
+    data = await request.json()
+
+    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
+    stream = data.get("stream", False)
+    model = data.get("model", "")
+    prompt = data.get("prompt", "")
+    image_url = data.get("image_url", "")
+    if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")
+    shard = build_base_shard(model, self.inference_engine_classname)
+    if DEBUG >= 2: print(f"shard: {shard}")
+    if not shard:
+        return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
+
+    request_id = str(uuid.uuid4())
+    callback_id = f"chatgpt-api-wait-response-{request_id}"
+    callback = self.node.on_token.register(callback_id)
+    try:
+      if image_url != "" and image_url != None:
+        img = self.base64_decode(image_url)
+      else:
+        img = None
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
+
+
+      response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
+      await response.prepare(request)
+
+      def get_progress_bar(current_step, total_steps, bar_length=50):
+        # Calculate the percentage of completion
+        percent = float(current_step) / total_steps
+        # Calculate the number of hashes to display
+        arrow = '-' * int(round(percent * bar_length) - 1) + '>'
+        spaces = ' ' * (bar_length - len(arrow))
+        
+        # Create the progress bar string
+        progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
+        return progress_bar
+
+      async def stream_image(_request_id: str, result, is_finished: bool):
+          if isinstance(result, list):
+              await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
+
+          elif isinstance(result, np.ndarray):
+            im = Image.fromarray(np.array(result))
+            images_folder = get_exo_images_dir()
+            # Save the image to a file
+            image_filename = f"{_request_id}.png"
+            image_path = images_folder / image_filename
+            im.save(image_path)
+            image_url = request.app.router['static_images'].url_for(filename=image_filename)
+            base_url = f"{request.scheme}://{request.host}"
+            # Construct the full URL correctly
+            full_image_url = base_url + str(image_url)
+            
+            await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+            if is_finished:
+              await response.write_eof()
+              
+
+      stream_task = None
+      def on_result(_request_id: str, result, is_finished: bool):
+          nonlocal stream_task
+          stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
+          return _request_id == request_id and is_finished
+
+      await callback.wait(on_result, timeout=self.response_timeout*10)
+      
+      if stream_task:
+          # Wait for the stream task to complete before returning
+          await stream_task
+
+      return response
+
+    except Exception as e:
+        if DEBUG >= 2: traceback.print_exc()
+        return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
+  
   async def handle_delete_model(self, request):
     try:
       model_name = request.match_info.get('model_name')
@@ -553,7 +657,7 @@ class ChatGPTAPI:
       if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
       shard = build_base_shard(model_name, self.inference_engine_classname)
       if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
-      asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
+      asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
 
       return web.json_response({
         "status": "success",
@@ -585,3 +689,19 @@ class ChatGPTAPI:
     await runner.setup()
     site = web.TCPSite(runner, host, port)
     await site.start()
+
+  def base64_decode(self, base64_string):
+    #decode and reshape image
+    if base64_string.startswith('data:image'):
+        base64_string = base64_string.split(',')[1]
+    image_data = base64.b64decode(base64_string)
+    img = Image.open(BytesIO(image_data))
+    W, H = (dim - dim % 64 for dim in (img.width, img.height))
+    if W != img.width or H != img.height:
+        if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
+        img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
+    img = mx.array(np.array(img))
+    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
+    img = img[None]
+    return img
+  

+ 4 - 0
exo/download/hf/hf_helpers.py

@@ -303,6 +303,10 @@ async def download_repo_files(
         await f.write(json.dumps(file_list))
       if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
 
+    model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
+    if model_index_exists:
+      allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
+
     filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
     total_files = len(filtered_file_list)
     total_bytes = sum(file["size"] for file in filtered_file_list)

+ 11 - 7
exo/download/hf/hf_shard_download.py

@@ -104,15 +104,19 @@ class HFShardDownloader(ShardDownloader):
           print(f"No snapshot directory found for {self.current_repo_id}")
         return None
 
+      if not await aios.path.exists(snapshot_dir/"model_index.json"):
       # Get the weight map to know what files we need
-      weight_map = await get_weight_map(self.current_repo_id, self.revision)
-      if not weight_map:
-        if DEBUG >= 2:
-          print(f"No weight map found for {self.current_repo_id}")
-        return None
+        weight_map = await get_weight_map(self.current_repo_id, self.revision)
+        if not weight_map:
+          if DEBUG >= 2:
+            print(f"No weight map found for {self.current_repo_id}")
+          return None
+
+        # Get all files needed for this shard
+        patterns = get_allow_patterns(weight_map, self.current_shard)
+      else:
+        patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
 
-      # Get all files needed for this shard
-      patterns = get_allow_patterns(weight_map, self.current_shard)
 
       # Check download status for all relevant files
       status = {}

+ 18 - 1
exo/helpers.py

@@ -350,4 +350,21 @@ async def get_mac_system_info() -> Tuple[str, str, int]:
         return model_id, chip_id, memory
     except Exception as e:
         if DEBUG >= 2: print(f"Error getting Mac system info: {e}")
-        return "Unknown Model", "Unknown Chip", 0
+        return "Unknown Model", "Unknown Chip", 0
+
+def get_exo_home() -> Path:
+  if os.name == "nt":  # Check if the OS is Windows
+    docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
+  else:
+    docs_folder = Path.home() / "Documents"
+  exo_folder = docs_folder / "Exo"
+  if not exo_folder.exists():
+    exo_folder.mkdir()
+  return exo_folder
+
+def get_exo_images_dir() -> Path:
+  exo_home = get_exo_home()
+  images_dir = exo_home / "Images"
+  if not images_dir.exists():
+    images_dir.mkdir()
+  return images_dir

+ 8 - 4
exo/inference/inference_engine.py

@@ -39,11 +39,15 @@ class InferenceEngine(ABC):
   async def clear_session(self):
     self.session.empty()
   
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> np.ndarray:
     tokens = await self.encode(shard, prompt)
-    x = tokens.reshape(1, -1)
-    output_data = await self.infer_tensor(request_id, shard, x)
-    return output_data 
+    if shard.model_id != 'stable-diffusion-2-1-base':
+      x = tokens.reshape(1, -1)
+    else:
+      x = tokens
+    output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
+
+    return output_data, inference_state
 
 inference_engine_classes = {
   "mlx": "MLXDynamicShardInferenceEngine",

+ 307 - 0
exo/inference/mlx/models/StableDiffusionPipeline.py

@@ -0,0 +1,307 @@
+# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/__init__.py
+
+import time
+from typing import Optional, Tuple
+import inspect
+
+import mlx.core as mx
+import mlx.nn as nn
+from pathlib import Path
+
+from tqdm import tqdm
+
+from .sd_models.vae import ModelArgs as VAEArgs
+from .sd_models.vae import Autoencoder
+from .sd_models.tokenizer import load_tokenizer
+from .sd_models.clip import CLIPTextModel
+from .sd_models.clip import ModelArgs as CLIPArgs
+from .sd_models.unet import UNetConfig, UNetModel
+
+from dataclasses import dataclass, field
+from exo.inference.shard import Shard
+
+@dataclass
+class DiffusionConfig:
+    beta_schedule: str = "scaled_linear"
+    beta_start: float = 0.00085
+    beta_end: float = 0.012
+    num_train_steps: int = 1000
+
+    @classmethod
+    def from_dict(cls, params):
+        return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+
+#Sampler
+def _linspace(a, b, num):
+    x = mx.arange(0, num) / (num - 1)
+    return (b - a) * x + a
+
+
+def _interp(y, x_new):
+    """Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
+    x_low = x_new.astype(mx.int32)
+    x_high = mx.minimum(x_low + 1, len(y) - 1)
+
+    y_low = y[x_low]
+    y_high = y[x_high]
+    delta_x = x_new - x_low
+    y_new = y_low * (1 - delta_x) + delta_x * y_high
+
+    return y_new
+
+class SimpleEulerSampler:
+    """A simple Euler integrator that can be used to sample from our diffusion models.
+
+    The method ``step()`` performs one Euler step from x_t to x_t_prev.
+    """
+
+    def __init__(self, config: DiffusionConfig):
+        # Compute the noise schedule
+        if config.beta_schedule == "linear":
+            betas = _linspace(
+                config.beta_start, config.beta_end, config.num_train_steps
+            )
+        elif config.beta_schedule == "scaled_linear":
+            betas = _linspace(
+                config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
+            ).square()
+        else:
+            raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
+
+        alphas = 1 - betas
+        alphas_cumprod = mx.cumprod(alphas)
+
+        self._sigmas = mx.concatenate(
+            [mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
+        )
+
+    @property
+    def max_time(self):
+        return len(self._sigmas) - 1
+
+    def sample_prior(self, shape, dtype=mx.float32, key=None):
+        noise = mx.random.normal(shape, key=key)
+        return (
+            noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
+        ).astype(dtype)
+
+    def add_noise(self, x, t, key=None):
+        noise = mx.random.normal(x.shape, key=key)
+        s = self.sigmas(t)
+        return (x + noise * s) * (s.square() + 1).rsqrt()
+
+    def sigmas(self, t):
+        return _interp(self._sigmas, t)
+
+    def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
+        start_time = start_time or (len(self._sigmas) - 1)
+        assert 0 < start_time <= (len(self._sigmas) - 1)
+        steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
+        return list(zip(steps, steps[1:]))
+
+    def current_timestep(self, step, total_steps, start_time=None):
+        if step < total_steps:
+            steps = self.timesteps(total_steps, start_time)
+            return steps[step]
+        else:
+            return mx.array(0),mx.array(0)
+
+    def step(self, eps_pred, x_t, t, t_prev):
+        sigma = self.sigmas(t).astype(eps_pred.dtype)
+        sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
+
+        dt = sigma_prev - sigma
+        x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
+
+        x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
+
+        return x_t_prev
+
+@dataclass
+class ShardConfig:
+    model_id:str
+    start_layer:int
+    end_layer:int
+    n_layers:int
+
+@dataclass
+class StableDiffusionConfig:
+    model_type:str
+    vae:VAEArgs
+    text_encoder:CLIPArgs
+    scheduler:DiffusionConfig
+    unet:UNetConfig
+    shard:ShardConfig
+    
+    @classmethod
+    def from_dict(cls, params):
+        return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+@dataclass
+class ModelArgs(StableDiffusionConfig):
+    shard:Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+    def __post_init__(self):
+        if isinstance(self.shard, dict):
+            self.shard = Shard(**self.shard)
+
+        if not isinstance(self.shard, Shard):
+            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+
+class Model(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.model_type = config.model_type
+        self.config = config
+        self.model_path = config.vae['path'].split('/vae')[0]
+        self.shard = config.shard
+        self.shard_clip, self.shard_encoder, self.shard_unet, self.shard_decoder  = model_shards(config.shard)
+        self.config_clip=CLIPArgs.from_dict(config.text_encoder['config'])
+        if self.shard_clip.start_layer != -1:
+            self.text_encoder = CLIPTextModel(self.config_clip, shard=self.shard_clip)
+        else:
+            self.text_encoder = nn.Identity()    
+        self.tokenizer = load_tokenizer(Path(self.model_path), "vocab.json", "merges.txt")
+        self.diffusion_config = DiffusionConfig.from_dict(config.scheduler['config'])
+        self.sampler = SimpleEulerSampler(self.diffusion_config)
+        if self.shard_unet.start_layer!=-1:
+            self.config_unet = UNetConfig.from_dict(config.unet['config'])
+            self.unet = UNetModel(self.config_unet, self.shard_unet)
+        else:
+            self.unet = nn.Identity()
+        self.config_vae=VAEArgs.from_dict(config.vae['config'])
+        if self.shard_encoder.start_layer != -1:
+            self.encoder=Autoencoder(self.config_vae, self.shard_encoder, "vae_encoder") 
+        else:
+            self.encoder = nn.Identity()            
+        if self.shard_decoder.start_layer != -1:
+            self.decoder=Autoencoder(self.config_vae, self.shard_decoder, "vae_decoder") 
+        else:
+            self.decoder = nn.Identity()            
+
+    def __call__(self,x, step= 0, cfg_weight: float = 7.5,total_steps=50,conditioning=None,mask=None,residual=None,x_t_prev=None,is_finished=False,is_step_finished=False, image=None, strength=0.65, start_step=None):
+        t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
+        is_finished = False
+        is_step_finished = False
+        if t.item()==1000:
+            if self.shard_clip.start_layer == 0:
+                conditioning = x
+            if self.shard_clip.start_layer != -1:
+                conditioning, mask= self.text_encoder(conditioning,mask)
+            seed = int(time.time()) 
+            mx.random.seed(seed)
+            if image is None:
+                if self.shard_encoder.is_last_layer():
+                    x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
+                    x_t_prev=x
+                    start_step = self.sampler.max_time
+            else:
+                if self.shard_encoder.start_layer != -1:
+                    image= self.encoder.encode(image)
+                    if self.shard_encoder.is_last_layer():
+                        start_step = self.sampler.max_time*strength
+                        total_steps = int(total_steps*strength)
+                        image = mx.broadcast_to(image, (1,) + image.shape[1:])
+                        x_t_prev=self.sampler.add_noise(image, mx.array(start_step))
+                        image = None
+                        t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
+        # Perform the denoising loop
+        if self.shard_unet.start_layer != -1:
+            with tqdm(total=total_steps,initial=step+1) as pbar:
+                if step<total_steps:
+                    x = x_t_prev
+                    if self.shard_unet.is_first_layer():
+                        x_t_unet = mx.concatenate([x] * 2, axis=0) if cfg_weight> 1 else x
+                    else:
+                        x_t_unet = x
+                    t_unet = mx.broadcast_to(t, [len(x_t_unet)])
+                    x, residual= self.unet(x_t_unet, t_unet, encoder_x=conditioning, residuals=residual)
+                    if self.shard_unet.is_last_layer():
+                        if cfg_weight > 1:
+                            eps_text, eps_neg = x.split(2)
+                            eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
+                        x = self.sampler.step(eps_pred, x_t_prev, t, t_prev)
+                        x_t_prev=x
+                    mx.eval(x)
+                    
+        if self.shard_decoder.is_last_layer():
+            is_step_finished=True
+            if self.shard_decoder.start_layer != -1:
+                x=self.decoder.decode(x)
+            if self.shard_decoder.is_last_layer():
+                x = mx.clip(x / 2 + 0.5, 0, 1)
+                B, H, W, C = x.shape
+                x = x.reshape(1, B // 1, H, W, C).transpose(0, 2, 1, 3, 4)
+                x = x.reshape(1 * H, B // 1 * W, C)
+                x = (x * 255).astype(mx.uint8)
+                if t_prev.item() ==0:
+                    is_finished=True   
+        mx.eval(x)
+         
+        return x, {'conditioning':conditioning, 'mask':mask,'residual':residual,'x_t_prev':x_t_prev,'is_finished':is_finished,'is_step_finished':is_step_finished, 'step':step, 'total_steps':total_steps, 'start_step':start_step, 'image':image}
+    
+
+    def load(self):
+        if self.shard_encoder.start_layer != -1:    
+            vae_weights =  mx.load(self.config_vae.weight_files[0])
+            vae_weights = self.encoder.sanitize(vae_weights)
+            self.encoder.load_weights(list(vae_weights.items()), strict=True)
+        if self.shard_decoder.start_layer != -1:
+            vae_weights =  mx.load(self.config_vae.weight_files[0])
+            vae_weights = self.decoder.sanitize(vae_weights)
+            self.decoder.load_weights(list(vae_weights.items()), strict=True)
+        if self.shard_clip.start_layer != -1:
+            clip_weights = mx.load(self.config_clip.weight_files[0])
+            clip_weights = self.text_encoder.sanitize(clip_weights)
+            self.text_encoder.load_weights(list(clip_weights.items()), strict=True)
+        if self.shard_unet.start_layer !=-1:
+            unet_weights = mx.load(self.config_unet.weight_files[0])
+            unet_weights = self.unet.sanitize(unet_weights)
+            self.unet.load_weights(list(unet_weights.items()), strict=True)
+
+def model_shards(shard:ShardConfig):
+    def create_shard(shard, model_ranges):
+        start_layer = shard.start_layer
+        end_layer = shard.end_layer
+        
+        shards = {}
+        
+        for model_name, (range_start, range_end) in model_ranges.items():
+            if start_layer < range_end and end_layer >= range_start:
+                # Calculate the overlap with the model range
+                overlap_start = max(start_layer, range_start)
+                overlap_end = min(end_layer, range_end - 1)
+
+                # Adjust the layers relative to the model's range
+                relative_start = overlap_start - range_start
+                relative_end = overlap_end - range_start
+                shards[model_name] = Shard(model_name, relative_start, relative_end, range_end - range_start)
+            else:
+                # If no overlap, create a zero-layer shard
+                shards[model_name] = Shard(model_name, -1, -1, range_end - range_start)
+        
+        return shards
+
+    # Define the ranges for different models
+    model_ranges = {
+        'clip': (0, 12),
+        'vae_encoder':(12,17),
+        'unet':(17,26),
+        'vae_decoder': (26, 31) # Example range for unet
+    }
+
+    # Call the function and get the shards for all models
+    shards = create_shard(shard, model_ranges)
+
+    # Access individual shards
+    shard_clip = shards['clip']
+    shard_encoder = shards['vae_encoder']
+    shard_unet = shards['unet']
+    shard_decoder = shards['vae_decoder']
+    
+    return shard_clip, shard_encoder, shard_unet, shard_decoder
+
+
+

+ 117 - 0
exo/inference/mlx/models/phi3.py

@@ -0,0 +1,117 @@
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import create_attention_mask
+from mlx_lm.models.phi3 import TransformerBlock, ModelArgs
+
+from ...shard import Shard
+from .base import IdentityBlock
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    super().__post_init__()
+
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+class Phi3Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.vocab_size = args.vocab_size
+    self.num_hidden_layers = args.num_hidden_layers
+    assert self.vocab_size > 0
+    
+    if self.args.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+    
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+        self.layers.append(TransformerBlock(args=args))
+      else:
+        self.layers.append(IdentityBlock())
+        
+    if self.args.shard.is_last_layer():
+      self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(inputs)
+    else:
+      h = inputs
+
+    mask = None
+    if h.shape[1] > 1:
+      mask = create_attention_mask(h, cache)
+
+    if cache is None:
+      cache = [None] * len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+class Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.model_type = args.model_type
+    self.model = Phi3Model(args)
+    if self.args.shard.is_last_layer():
+      self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      out = self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if "self_attn.rope.inv_freq" in key:
+        continue
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('lm_head') or key.startswith('model.norm')):
+        shard_state_dict[key] = value
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return self.args.hidden_size // self.args.num_attention_heads
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads

+ 4 - 3
exo/inference/mlx/models/qwen2.py

@@ -9,13 +9,12 @@ from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs
 from ...shard import Shard
 from .base import IdentityBlock
 
-
 @dataclass
 class ModelArgs(ModelArgs):
   shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
 
   def __post_init__(self):
-    super().__post_init__()  # Ensure parent initializations are respected
+    super().__post_init__()
 
     if isinstance(self.shard, Shard):
       return
@@ -24,7 +23,6 @@ class ModelArgs(ModelArgs):
 
     self.shard = Shard(**self.shard)
 
-
 class Qwen2Model(nn.Module):
   def __init__(self, args: ModelArgs):
     super().__init__()
@@ -32,14 +30,17 @@ class Qwen2Model(nn.Module):
     self.vocab_size = args.vocab_size
     self.num_hidden_layers = args.num_hidden_layers
     assert self.vocab_size > 0
+
     if self.args.shard.is_first_layer():
       self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+
     self.layers = []
     for i in range(self.num_hidden_layers):
       if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
         self.layers.append(TransformerBlock(args=args))
       else:
         self.layers.append(IdentityBlock())
+
     if self.args.shard.is_last_layer():
       self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 

+ 191 - 0
exo/inference/mlx/models/sd_models/clip.py

@@ -0,0 +1,191 @@
+# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from dataclasses import field, dataclass
+from exo.inference.shard import Shard
+from exo.inference.mlx.models.base import IdentityBlock 
+
+_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
+
+
+
+@dataclass
+class CLIPTextModelConfig:
+    num_layers: int = 23
+    model_dims: int = 1024
+    num_heads: int = 16
+    max_length: int = 77
+    vocab_size: int = 49408
+    projection_dim: Optional[int] = None
+    hidden_act: str = "quick_gelu"
+
+    @classmethod
+    def from_dict(cls, config):
+        return ModelArgs(
+            num_layers=config["num_hidden_layers"],
+            model_dims=config["hidden_size"],
+            num_heads=config["num_attention_heads"],
+            max_length=config["max_position_embeddings"],
+            vocab_size=config["vocab_size"],
+            projection_dim=config["projection_dim"] if "WithProjection" in config['architectures'][0] else None,
+            hidden_act=config.get("hidden_act", "quick_gelu"),
+            weight_files=config.get("weight_files", [])
+            )
+
+@dataclass
+class ModelArgs(CLIPTextModelConfig):
+    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+    weight_files: List[str] = field(default_factory=lambda: [])
+    def __post_init__(self):
+        if isinstance(self.shard, dict):
+            self.shard = Shard(**self.shard)
+
+        if not isinstance(self.shard, Shard):
+            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+        if not self.shard.is_first_layer():
+            self.vision_config = None
+
+
+@dataclass
+class CLIPOutput:
+    pooled_output: Optional[mx.array] = None
+    last_hidden_state: Optional[mx.array] = None
+    hidden_states: Optional[List[mx.array]] = None
+
+
+class CLIPEncoderLayer(nn.Module):
+    """The transformer encoder layer from CLIP."""
+
+    def __init__(self, model_dims: int, num_heads: int, activation: str):
+        super().__init__()
+
+        self.layer_norm1 = nn.LayerNorm(model_dims)
+        self.layer_norm2 = nn.LayerNorm(model_dims)
+
+        self.attention = nn.MultiHeadAttention(model_dims, num_heads)
+        self.attention.query_proj.bias = mx.zeros(model_dims)
+        self.attention.key_proj.bias = mx.zeros(model_dims)
+        self.attention.value_proj.bias = mx.zeros(model_dims)
+        self.attention.out_proj.bias = mx.zeros(model_dims)
+
+        self.linear1 = nn.Linear(model_dims, 4 * model_dims)
+        self.linear2 = nn.Linear(4 * model_dims, model_dims)
+
+        self.act = _ACTIVATIONS[activation]
+
+    def __call__(self, x, attn_mask=None):
+        
+        y = self.layer_norm1(x)
+        y = self.attention(y, y, y, attn_mask)
+        x = y + x
+        
+        y = self.layer_norm2(x)
+        y = self.linear1(y)
+        y = self.act(y)
+        y = self.linear2(y)
+        x = y + x
+        return x
+
+
+class CLIPTextModel(nn.Module):
+    """Implements the text encoder transformer from CLIP."""
+
+    def __init__(self, config: CLIPTextModelConfig, shard: Shard):
+        super().__init__()
+
+        self.shard = shard
+        self.layers_range = range(self.shard.start_layer*2, self.shard.end_layer*2+2) 
+        if self.shard.is_first_layer():
+            self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
+            self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
+        self.layers = []
+        for i in range(math.ceil(config.num_layers/2)):
+            if  2*i in self.layers_range:
+                self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
+            if 2*i+1 in self.layers_range and 2*i+1 < config.num_layers:
+                self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
+            else:
+                self.layers.append(IdentityBlock())
+        if self.shard.is_last_layer():
+            self.final_layer_norm = nn.LayerNorm(config.model_dims)
+
+        if config.projection_dim is not None:
+            self.text_projection = nn.Linear(
+                config.model_dims, config.projection_dim, bias=False
+            )
+
+    def _get_mask(self, N, dtype):
+        indices = mx.arange(N)
+        mask = indices[:, None] < indices[None]
+        mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
+        return mask
+
+    def __call__(self, x, mask=None):
+        # Extract some shapes
+        if self.shard.is_first_layer():
+            B, N = x.shape
+            eos_tokens = x.argmax(-1)
+            
+            # Compute the embeddings
+            x = self.token_embedding(x)
+           
+            x = x + self.position_embedding.weight[:N]
+            # Compute the features from the transformer
+            mask = self._get_mask(N, x.dtype)
+        
+        for l in self.layers:
+            x = l(x, mask)
+        # Apply the final layernorm and return
+        
+        if self.shard.is_last_layer():
+            x = self.final_layer_norm(x)
+        
+       
+
+        return x, mask
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for key, value in weights.items():
+            if "position_ids" in key:
+                continue
+            if key.startswith("text_model."):
+                key = key[11:]
+            if key.startswith("embeddings."):
+                key = key[11:]
+            if key.startswith("encoder."):
+                key = key[8:]
+
+            # Map attention layers
+            if "self_attn." in key:
+                key = key.replace("self_attn.", "attention.")
+            if "q_proj." in key:
+                key = key.replace("q_proj.", "query_proj.")
+            if "k_proj." in key:
+                key = key.replace("k_proj.", "key_proj.")
+            if "v_proj." in key:
+                key = key.replace("v_proj.", "value_proj.")
+
+            # Map ffn layers
+            if "mlp.fc1" in key:
+                key = key.replace("mlp.fc1", "linear1")
+            if "mlp.fc2" in key:
+                key = key.replace("mlp.fc2", "linear2")
+            
+            if key.startswith("layers."):
+                layer_num = int(key.split(".")[1])
+                if layer_num not in self.layers_range:
+                    continue
+            if not self.shard.is_first_layer() and "embedding" in key:
+                continue
+            if not self.shard.is_last_layer() and key.startswith("final_layer_norm"):
+                continue
+            if not self.shard.is_last_layer() and key.startswith("text_projection"):
+                continue
+            sanitized_weights[key] = value
+        return sanitized_weights

+ 131 - 0
exo/inference/mlx/models/sd_models/tokenizer.py

@@ -0,0 +1,131 @@
+# adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py
+
+import regex
+import json
+import glob
+
+
+class Tokenizer:
+    """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
+
+    def __init__(self, bpe_ranks, vocab):
+        self.bpe_ranks = bpe_ranks
+        self.vocab = vocab
+        self.pat = regex.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            regex.IGNORECASE,
+        )
+
+        self._cache = {self.bos: self.bos, self.eos: self.eos}
+
+    @property
+    def bos(self):
+        return "<|startoftext|>"
+
+    @property
+    def bos_token(self):
+        return self.vocab[self.bos]
+
+    @property
+    def eos(self):
+        return "<|endoftext|>"
+
+    @property
+    def eos_token(self):
+        return self.vocab[self.eos]
+
+    def bpe(self, text):
+        if text in self._cache:
+            return self._cache[text]
+
+        unigrams = list(text[:-1]) + [text[-1] + "</w>"]
+        unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+        if not unique_bigrams:
+            return unigrams
+
+        # In every iteration try to merge the two most likely bigrams. If none
+        # was merged we are done.
+        #
+        # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
+        while unique_bigrams:
+            bigram = min(
+                unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
+            )
+            if bigram not in self.bpe_ranks:
+                break
+
+            new_unigrams = []
+            skip = False
+            for a, b in zip(unigrams, unigrams[1:]):
+                if skip:
+                    skip = False
+                    continue
+
+                if (a, b) == bigram:
+                    new_unigrams.append(a + b)
+                    skip = True
+
+                else:
+                    new_unigrams.append(a)
+
+            if not skip:
+                new_unigrams.append(b)
+
+            unigrams = new_unigrams
+            unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+        self._cache[text] = unigrams
+
+        return unigrams
+
+    def tokenize(self, text, prepend_bos=True, append_eos=True):
+        if isinstance(text, list):
+            return [self.tokenize(t, prepend_bos, append_eos) for t in text]
+
+        # Lower case cleanup and split according to self.pat. Hugging Face does
+        # a much more thorough job here but this should suffice for 95% of
+        # cases.
+        clean_text = regex.sub(r"\s+", " ", text.lower())
+        tokens = regex.findall(self.pat, clean_text)
+
+        # Split the tokens according to the byte-pair merge file
+        bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
+
+        # Map to token ids and return
+        tokens = [self.vocab[t] for t in bpe_tokens]
+        if prepend_bos:
+            tokens = [self.bos_token] + tokens
+        if append_eos:
+            tokens.append(self.eos_token)
+
+        return tokens
+    
+    def encode(self, prompt):
+        tokens = [self.tokenize(prompt)]
+        negative_text = ""
+        if negative_text is not None:
+            tokens += [self.tokenize(negative_text)]
+        lengths = [len(t) for t in tokens]
+        N = max(lengths)
+        tokens = [t + [0] * (N - len(t)) for t in tokens]
+        return tokens
+
+def load_tokenizer(
+    model_path: str,
+    vocab_key: str = "tokenizer_vocab",
+    merges_key: str = "tokenizer_merges",
+):
+
+    vocab_file = glob.glob(str(model_path/"tokenizer"/vocab_key))[0]
+    with open(vocab_file, encoding="utf-8") as f:
+        vocab = json.load(f)
+
+    merges_file = glob.glob(str(model_path/"tokenizer"/merges_key))[0]
+    with open(merges_file, encoding="utf-8") as f:
+        bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
+    bpe_merges = [tuple(m.split()) for m in bpe_merges]
+    bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
+
+    return Tokenizer(bpe_ranks, vocab)
+

+ 629 - 0
exo/inference/mlx/models/sd_models/unet.py

@@ -0,0 +1,629 @@
+# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
+
+import math
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from dataclasses import dataclass, field
+from typing import Tuple, Optional, List
+from exo.inference.shard import Shard
+
+@dataclass
+class UNetConfig:
+    in_channels: int = 4
+    out_channels: int = 4
+    conv_in_kernel: int = 3
+    conv_out_kernel: int = 3
+    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
+    layers_per_block: Tuple[int] = (2, 2, 2, 2)
+    mid_block_layers: int = 2
+    transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
+    num_attention_heads: Tuple[int] = (5, 10, 20, 20)
+    cross_attention_dim: Tuple[int] = (1024,) * 4
+    norm_num_groups: int = 32
+    down_block_types: Tuple[str] = (
+        "CrossAttnDownBlock2D",
+        "CrossAttnDownBlock2D",
+        "CrossAttnDownBlock2D",
+        "DownBlock2D",
+    )
+    up_block_types: Tuple[str] = (
+        "UpBlock2D",
+        "CrossAttnUpBlock2D",
+        "CrossAttnUpBlock2D",
+        "CrossAttnUpBlock2D",
+    )
+    addition_embed_type: Optional[str] = None
+    addition_time_embed_dim: Optional[int] = None
+    projection_class_embeddings_input_dim: Optional[int] = None
+    weight_files: List[str] = field(default_factory=lambda: [])
+
+
+
+    @classmethod
+    def from_dict(cls,config):
+        n_blocks = len(config['block_out_channels'])
+        return UNetConfig(
+            in_channels=config["in_channels"],
+            out_channels=config["out_channels"],
+            block_out_channels=config["block_out_channels"],
+            layers_per_block=[config["layers_per_block"]] * n_blocks,
+            transformer_layers_per_block=config.get(
+                "transformer_layers_per_block", (1,) * 4
+            ),
+            num_attention_heads=(
+                [config["attention_head_dim"]] * n_blocks
+                if isinstance(config["attention_head_dim"], int)
+                else config["attention_head_dim"]
+            ),
+            cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
+            norm_num_groups=config["norm_num_groups"],
+            down_block_types=config["down_block_types"],
+            up_block_types=config["up_block_types"][::-1],
+            addition_embed_type=config.get("addition_embed_type", None),
+            addition_time_embed_dim=config.get("addition_time_embed_dim", None),
+            projection_class_embeddings_input_dim=config.get(
+                "projection_class_embeddings_input_dim", None
+            ),
+            weight_files=config.get("weight_files", [])
+
+        )
+
+
+def upsample_nearest(x, scale: int = 2):
+    B, H, W, C = x.shape
+    x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
+    x = x.reshape(B, H * scale, W * scale, C)
+
+    return x
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(self, in_channels: int, time_embed_dim: int):
+        super().__init__()
+
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+
+    def __call__(self, x):
+        x = self.linear_1(x)
+        x = nn.silu(x)
+        x = self.linear_2(x)
+
+        return x
+
+
+class TransformerBlock(nn.Module):
+    def __init__(
+        self,
+        model_dims: int,
+        num_heads: int,
+        hidden_dims: Optional[int] = None,
+        memory_dims: Optional[int] = None,
+    ):
+        super().__init__()
+
+        self.norm1 = nn.LayerNorm(model_dims)
+        self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
+        self.attn1.out_proj.bias = mx.zeros(model_dims)
+
+        memory_dims = memory_dims or model_dims
+        self.norm2 = nn.LayerNorm(model_dims)
+        self.attn2 = nn.MultiHeadAttention(
+            model_dims, num_heads, key_input_dims=memory_dims
+        )
+        self.attn2.out_proj.bias = mx.zeros(model_dims)
+
+        hidden_dims = hidden_dims or 4 * model_dims
+        self.norm3 = nn.LayerNorm(model_dims)
+        self.linear1 = nn.Linear(model_dims, hidden_dims)
+        self.linear2 = nn.Linear(model_dims, hidden_dims)
+        self.linear3 = nn.Linear(hidden_dims, model_dims)
+
+    def __call__(self, x, memory, attn_mask, memory_mask):
+        # Self attention
+        y = self.norm1(x)
+        y = self.attn1(y, y, y, attn_mask)
+        x = x + y
+
+        # Cross attention
+        y = self.norm2(x)
+        y = self.attn2(y, memory, memory, memory_mask)
+        x = x + y
+
+        # FFN
+        y = self.norm3(x)
+        y_a = self.linear1(y)
+        y_b = self.linear2(y)
+        y = y_a * nn.gelu(y_b)
+        y = self.linear3(y)
+        x = x + y
+
+        return x
+
+
+class Transformer2D(nn.Module):
+    """A transformer model for inputs with 2 spatial dimensions."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        model_dims: int,
+        encoder_dims: int,
+        num_heads: int,
+        num_layers: int = 1,
+        norm_num_groups: int = 32,
+    ):
+        super().__init__()
+
+        self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
+        self.proj_in = nn.Linear(in_channels, model_dims)
+        self.transformer_blocks = [
+            TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
+            for i in range(num_layers)
+        ]
+        self.proj_out = nn.Linear(model_dims, in_channels)
+
+    def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
+        # Save the input to add to the output
+        input_x = x
+        dtype = x.dtype
+
+        # Perform the input norm and projection
+        B, H, W, C = x.shape
+        x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
+        x = self.proj_in(x)
+
+        # Apply the transformer
+        for block in self.transformer_blocks:
+            x = block(x, encoder_x, attn_mask, encoder_attn_mask)
+
+        # Apply the output projection and reshape
+        x = self.proj_out(x)
+        x = x.reshape(B, H, W, C)
+
+        return x + input_x
+
+
+class ResnetBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        groups: int = 32,
+        temb_channels: Optional[int] = None,
+    ):
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+
+        self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
+        self.conv1 = nn.Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        if temb_channels is not None:
+            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
+        self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
+        self.conv2 = nn.Conv2d(
+            out_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+
+        if in_channels != out_channels:
+            self.conv_shortcut = nn.Linear(in_channels, out_channels)
+
+    def __call__(self, x, temb=None):
+        dtype = x.dtype
+
+        if temb is not None:
+            temb = self.time_emb_proj(nn.silu(temb))
+        y = self.norm1(x.astype(mx.float32)).astype(dtype)
+        
+        y = nn.silu(y)
+        
+        y = self.conv1(y)
+
+        
+        if temb is not None:
+            y = y + temb[:, None, None, :]
+        y = self.norm2(y.astype(mx.float32)).astype(dtype)
+        y = nn.silu(y)
+        y = self.conv2(y)
+        
+        x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
+        return x
+
+
+class UNetBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        prev_out_channels: Optional[int] = None,
+        num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
+        num_attention_heads: int = 8,
+        cross_attention_dim=1280,
+        resnet_groups: int = 32,
+        add_downsample=True,
+        add_upsample=True,
+        add_cross_attention=True,
+    ):
+        super().__init__()
+
+        # Prepare the in channels list for the resnets
+        if prev_out_channels is None:
+            in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
+        else:
+            in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
+            res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
+            in_channels_list = [
+                a + b for a, b in zip(in_channels_list, res_channels_list)
+            ]
+
+        # Add resnet blocks that also process the time embedding
+        self.resnets = [
+            ResnetBlock2D(
+                in_channels=ic,
+                out_channels=out_channels,
+                temb_channels=temb_channels,
+                groups=resnet_groups,
+            )
+            for ic in in_channels_list
+        ]
+
+        # Add optional cross attention layers
+        if add_cross_attention:
+            self.attentions = [
+                Transformer2D(
+                    in_channels=out_channels,
+                    model_dims=out_channels,
+                    num_heads=num_attention_heads,
+                    num_layers=transformer_layers_per_block,
+                    encoder_dims=cross_attention_dim,
+                )
+                for i in range(num_layers)
+            ]
+
+        # Add an optional downsampling layer
+        if add_downsample:
+            self.downsample = nn.Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=2, padding=1
+            )
+
+        # or upsampling layer
+        if add_upsample:
+            self.upsample = nn.Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=1, padding=1
+            )
+
+    def __call__(
+        self,
+        x,
+        encoder_x=None,
+        temb=None,
+        attn_mask=None,
+        encoder_attn_mask=None,
+        residual_hidden_states=None,
+    ):
+        output_states = []
+
+        for i in range(len(self.resnets)):
+            if residual_hidden_states is not None:
+                x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
+
+            x = self.resnets[i](x, temb)
+
+            if "attentions" in self:
+                x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
+
+            output_states.append(x)
+
+        if "downsample" in self:
+            x = self.downsample(x)
+            output_states.append(x)
+
+        if "upsample" in self:
+            x = self.upsample(upsample_nearest(x))
+            output_states.append(x)
+
+        return x, output_states
+
+
+class UNetModel(nn.Module):
+    """The conditional 2D UNet model that actually performs the denoising."""
+
+    def __init__(self, config: UNetConfig, shard: Shard):
+        super().__init__()
+        self.shard = shard
+        self.start_layer = shard.start_layer
+        self.end_layer = shard.end_layer
+        self.layers_range = list(range(self.start_layer, self.end_layer+1))
+        if shard.is_first_layer(): 
+            self.conv_in = nn.Conv2d(
+                config.in_channels,
+                config.block_out_channels[0],
+                config.conv_in_kernel,
+                padding=(config.conv_in_kernel - 1) // 2,
+            )
+
+        self.timesteps = nn.SinusoidalPositionalEncoding(
+            config.block_out_channels[0],
+            max_freq=1,
+            min_freq=math.exp(
+                -math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
+            ),
+            scale=1.0,
+            cos_first=True,
+            full_turns=False,
+        )
+        self.time_embedding = TimestepEmbedding(
+            config.block_out_channels[0],
+            config.block_out_channels[0] * 4,
+        )
+
+        if config.addition_embed_type == "text_time":
+            self.add_time_proj = nn.SinusoidalPositionalEncoding(
+                config.addition_time_embed_dim,
+                max_freq=1,
+                min_freq=math.exp(
+                    -math.log(10000)
+                    + 2 * math.log(10000) / config.addition_time_embed_dim
+                ),
+                scale=1.0,
+                cos_first=True,
+                full_turns=False,
+            )
+            self.add_embedding = TimestepEmbedding(
+                config.projection_class_embeddings_input_dim,
+                config.block_out_channels[0] * 4,
+            )
+
+        # Make the downsampling blocks
+        block_channels = [config.block_out_channels[0]] + list(
+            config.block_out_channels
+        )
+        self.down_blocks = []
+
+        for i, (in_channels, out_channels) in enumerate(zip(block_channels, block_channels[1:])):    
+            if i in self.layers_range:
+                self.down_blocks.append(
+                    UNetBlock2D(
+                        in_channels=in_channels,
+                out_channels=out_channels,
+                temb_channels=config.block_out_channels[0] * 4,
+                num_layers=config.layers_per_block[i],
+                transformer_layers_per_block=config.transformer_layers_per_block[i],
+                num_attention_heads=config.num_attention_heads[i],
+                cross_attention_dim=config.cross_attention_dim[i],
+                resnet_groups=config.norm_num_groups,
+                add_downsample=(i < len(config.block_out_channels) - 1),
+                add_upsample=False,
+                        add_cross_attention="CrossAttn" in config.down_block_types[i],
+                    )
+                )
+            else:
+                self.down_blocks.append(nn.Identity())
+            
+         
+        # Make the middle block
+        if 4 in self.layers_range:
+            self.mid_blocks = [
+                ResnetBlock2D(
+                    in_channels=config.block_out_channels[-1],
+                    out_channels=config.block_out_channels[-1],
+                    temb_channels=config.block_out_channels[0] * 4,
+                    groups=config.norm_num_groups,
+                ),
+                Transformer2D(
+                    in_channels=config.block_out_channels[-1],
+                    model_dims=config.block_out_channels[-1],
+                    num_heads=config.num_attention_heads[-1],
+                    num_layers=config.transformer_layers_per_block[-1],
+                    encoder_dims=config.cross_attention_dim[-1],
+                ),
+                ResnetBlock2D(
+                    in_channels=config.block_out_channels[-1],
+                    out_channels=config.block_out_channels[-1],
+                    temb_channels=config.block_out_channels[0] * 4,
+                    groups=config.norm_num_groups,
+                ),
+            ]
+
+        # Make the upsampling blocks
+        block_channels = (
+            [config.block_out_channels[0]]
+            + list(config.block_out_channels)
+            + [config.block_out_channels[-1]]
+        )
+
+        total_items = len(block_channels) - 3
+        reversed_channels = list(reversed(list(zip(block_channels, block_channels[1:], block_channels[2:]))))
+
+        self.up_blocks = []
+        for rev_i, (in_channels, out_channels, prev_out_channels) in enumerate(reversed_channels):  
+            i = total_items - rev_i 
+            if rev_i+5 in self.layers_range:
+                self.up_blocks.append(
+                    UNetBlock2D(
+                        in_channels=in_channels,
+                out_channels=out_channels,
+                temb_channels=config.block_out_channels[0] * 4,
+                prev_out_channels=prev_out_channels,
+                num_layers=config.layers_per_block[i] + 1,
+                transformer_layers_per_block=config.transformer_layers_per_block[i],
+                num_attention_heads=config.num_attention_heads[i],
+                cross_attention_dim=config.cross_attention_dim[i],
+                resnet_groups=config.norm_num_groups,
+                add_downsample=False,
+                add_upsample=(i > 0),
+                        add_cross_attention="CrossAttn" in config.up_block_types[i],
+                    )
+                )
+            else:
+                self.up_blocks.append(nn.Identity())
+            
+        
+        if shard.is_last_layer():
+            self.conv_norm_out = nn.GroupNorm(
+                config.norm_num_groups,
+                config.block_out_channels[0],
+                pytorch_compatible=True,
+            )
+            self.conv_out = nn.Conv2d(
+                config.block_out_channels[0],
+                config.out_channels,
+                config.conv_out_kernel,
+                padding=(config.conv_out_kernel - 1) // 2,
+            )
+
+    def __call__(
+        self,
+        x,
+        timestep,
+        encoder_x,
+        attn_mask=None,
+        encoder_attn_mask=None,
+        text_time=None,
+        residuals=None,
+    ):
+        # Compute the time embeddings
+        
+        temb = self.timesteps(timestep).astype(x.dtype)
+        temb = self.time_embedding(temb)
+
+        # Add the extra text_time conditioning
+        if text_time is not None:
+            text_emb, time_ids = text_time
+            emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype)
+            emb = mx.concatenate([text_emb, emb], axis=-1)
+            emb = self.add_embedding(emb)
+            temb = temb + emb
+        
+        if self.shard.is_first_layer():
+            # Preprocess the input
+            x = self.conv_in(x)
+            residuals = [x]
+        # Run the downsampling part of the unet
+        
+        for i in range(len(self.down_blocks)):
+            if i in self.layers_range:
+                x, res = self.down_blocks[i](
+                    x,
+                    encoder_x=encoder_x,
+                    temb=temb,
+                    attn_mask=attn_mask,
+                    encoder_attn_mask=encoder_attn_mask,
+                )
+                residuals.extend(res)
+            else:
+                x= self.down_blocks[i](x)
+
+        if 4 in self.layers_range:
+            # Run the middle part of the unet
+            x = self.mid_blocks[0](x, temb)
+            x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
+            x = self.mid_blocks[2](x, temb)
+
+        # Run the upsampling part of the unet
+        for i in range(len(self.up_blocks)):
+            if i+5 in self.layers_range:
+                x, _ = self.up_blocks[i](
+                    x,
+                    encoder_x=encoder_x,
+                    temb=temb,
+                    attn_mask=attn_mask,
+                    encoder_attn_mask=encoder_attn_mask,
+                    residual_hidden_states=residuals,
+                )
+            else:
+                x= self.up_blocks[i](x)
+
+        # Postprocess the output
+        if self.shard.is_last_layer():
+            dtype = x.dtype
+            x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
+            x = nn.silu(x)
+            x = self.conv_out(x)
+
+        return x, residuals
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for key, value in weights.items():
+            k1=""
+            k2=""
+            if "downsamplers" in key:
+                key = key.replace("downsamplers.0.conv", "downsample")
+            if "upsamplers" in key:
+                key = key.replace("upsamplers.0.conv", "upsample")
+
+            # Map the mid block
+            if "mid_block.resnets.0" in key:
+                key = key.replace("mid_block.resnets.0", "mid_blocks.0")
+            if "mid_block.attentions.0" in key:
+                key = key.replace("mid_block.attentions.0", "mid_blocks.1")
+            if "mid_block.resnets.1" in key:
+                key = key.replace("mid_block.resnets.1", "mid_blocks.2")
+
+            # Map attention layers
+            if "to_k" in key:
+                key = key.replace("to_k", "key_proj")
+            if "to_out.0" in key:
+                key = key.replace("to_out.0", "out_proj")
+            if "to_q" in key:
+                key = key.replace("to_q", "query_proj")
+            if "to_v" in key:
+                key = key.replace("to_v", "value_proj")
+
+            # Map transformer ffn
+            if "ff.net.2" in key:
+                key = key.replace("ff.net.2", "linear3")
+            if "ff.net.0" in key:
+                k1 = key.replace("ff.net.0.proj", "linear1")
+                k2 = key.replace("ff.net.0.proj", "linear2")
+                v1, v2 = mx.split(value, 2)
+                
+
+            if "conv_shortcut.weight" in key:
+                value = value.squeeze()
+            
+            # Transform the weights from 1x1 convs to linear
+            if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
+                value = value.squeeze()
+
+            if len(value.shape) == 4:
+                value = value.transpose(0, 2, 3, 1)
+                value = value.reshape(-1).reshape(value.shape)
+
+            if key.startswith("conv_in") :
+                if 0 not in self.layers_range:
+                    continue
+            
+            if key.startswith("down_blocks"):
+                layer_num = int(key.split(".")[1])
+                if layer_num not in self.layers_range:
+                    continue
+
+            if key.startswith("mid_block"):
+                if 4 not in self.layers_range:
+                    continue
+
+            if key.startswith("up_blocks"):
+                layer_num = int(key.split(".")[1])
+                if (layer_num+5) not in self.layers_range:
+                    continue
+            
+            if key.startswith("conv_out") or key.startswith("conv_norm_out"):
+                if 8 not in self.layers_range:
+                    continue
+
+            if len(k1)>0:
+                sanitized_weights[k1] = v1
+                sanitized_weights[k2] = v2
+            else:
+                sanitized_weights[key] = value
+
+
+        return sanitized_weights

+ 429 - 0
exo/inference/mlx/models/sd_models/vae.py

@@ -0,0 +1,429 @@
+# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/vae.py
+
+import math
+from typing import List
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .unet import ResnetBlock2D, upsample_nearest
+from dataclasses import dataclass, field
+from exo.inference.shard import Shard
+from typing import Tuple
+import inspect
+from ..base import IdentityBlock
+
+@dataclass
+class AutoencoderConfig:
+    in_channels: int = 3
+    out_channels: int = 3
+    latent_channels_out: int = 8
+    latent_channels_in: int = 4
+    block_out_channels: Tuple[int] = (128, 256, 512, 512)
+    layers_per_block: int = 2
+    norm_num_groups: int = 32
+    scaling_factor: float = 0.18215
+    weight_files: List[str] = field(default_factory=lambda: [])
+    @classmethod
+    def from_dict(cls, params):
+        return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+
+@dataclass
+class ModelArgs(AutoencoderConfig):
+    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+    def __post_init__(self):
+        if isinstance(self.shard, dict):
+            self.shard = Shard(**self.shard)
+
+        if not isinstance(self.shard, Shard):
+            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+        if not self.shard.is_first_layer():
+            self.vision_config = None
+
+
+class Attention(nn.Module):
+    """A single head unmasked attention for use with the VAE."""
+
+    def __init__(self, dims: int, norm_groups: int = 32):
+        super().__init__()
+
+        self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
+        self.query_proj = nn.Linear(dims, dims)
+        self.key_proj = nn.Linear(dims, dims)
+        self.value_proj = nn.Linear(dims, dims)
+        self.out_proj = nn.Linear(dims, dims)
+
+    def __call__(self, x):
+        B, H, W, C = x.shape
+
+        y = self.group_norm(x)
+
+        queries = self.query_proj(y).reshape(B, H * W, C)
+        keys = self.key_proj(y).reshape(B, H * W, C)
+        values = self.value_proj(y).reshape(B, H * W, C)
+
+        scale = 1 / math.sqrt(queries.shape[-1])
+        scores = (queries * scale) @ keys.transpose(0, 2, 1)
+        attn = mx.softmax(scores, axis=-1)
+        y = (attn @ values).reshape(B, H, W, C)
+
+        y = self.out_proj(y)
+        x = x + y
+
+        return x
+
+
+class EncoderDecoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_layers: int = 1,
+        resnet_groups: int = 32,
+        add_downsample=True,
+        add_upsample=True,
+    ):
+        super().__init__()
+
+        # Add the resnet blocks
+        self.resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels if i == 0 else out_channels,
+                out_channels=out_channels,
+                groups=resnet_groups,
+            )
+            for i in range(num_layers)
+        ]
+
+        # Add an optional downsampling layer
+        if add_downsample:
+            self.downsample = nn.Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=2, padding=0
+            )
+
+        # or upsampling layer
+        if add_upsample:
+            self.upsample = nn.Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=1, padding=1
+            )
+
+    def __call__(self, x):
+        for resnet in self.resnets:
+            x = resnet(x)
+        if "downsample" in self:
+            x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
+            x = self.downsample(x)
+
+        if "upsample" in self:
+            x = self.upsample(upsample_nearest(x))
+        return x
+
+
+class Encoder(nn.Module):
+    """Implements the encoder side of the Autoencoder."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        latent_channels_out: int,
+        block_out_channels: List[int] = [64],
+        layers_per_block: int = 2,
+        resnet_groups: int = 32,
+        layers_range: List[int] = [],
+        shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+    ):
+        super().__init__()
+        self.layers_range = layers_range
+        self.shard = shard
+        if self.shard.is_first_layer():
+            self.conv_in = nn.Conv2d(
+                in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
+            )
+
+        channels = [block_out_channels[0]] + list(block_out_channels)
+        self.down_blocks = []
+        current_layer = 1
+        for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
+            if current_layer in self.layers_range:
+                self.down_blocks.append(
+                    EncoderDecoderBlock2D(
+                        in_channels,
+                        out_channels,
+                        num_layers=layers_per_block,
+                        resnet_groups=resnet_groups,
+                        add_downsample=i < len(block_out_channels) - 1,
+                        add_upsample=False,
+                    )
+                )
+            else:
+                self.down_blocks.append(IdentityBlock())
+            current_layer += 1
+
+        if self.shard.is_last_layer():
+            self.mid_blocks = [
+                ResnetBlock2D(
+                    in_channels=block_out_channels[-1],
+                    out_channels=block_out_channels[-1],
+                    groups=resnet_groups,
+                ),
+                Attention(block_out_channels[-1], resnet_groups),
+                ResnetBlock2D(
+                    in_channels=block_out_channels[-1],
+                    out_channels=block_out_channels[-1],
+                    groups=resnet_groups,
+                ),
+            ]
+
+            self.conv_norm_out = nn.GroupNorm(
+                resnet_groups, block_out_channels[-1], pytorch_compatible=True
+            )
+            self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels_out, 3, padding=1)
+
+    def __call__(self, x):
+        if self.shard.is_first_layer():
+            x = self.conv_in(x)
+
+        for l in self.down_blocks:
+            x = l(x)
+
+        if self.shard.is_last_layer():
+            x = self.mid_blocks[0](x)
+            x = self.mid_blocks[1](x)
+            x = self.mid_blocks[2](x)
+
+            x = self.conv_norm_out(x)
+            x = nn.silu(x)
+            x = self.conv_out(x)
+
+        return x
+
+
+class Decoder(nn.Module):
+    """Implements the decoder side of the Autoencoder."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        shard: Shard,
+        layer_range: List[int],
+        block_out_channels: List[int] = [64],
+        layers_per_block: int = 2,
+        resnet_groups: int = 32,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.layers_range = layer_range
+        if 0 in layer_range:
+            self.conv_in = nn.Conv2d(
+                in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
+            )
+        
+        if 0 in layer_range:
+            self.mid_blocks = [
+                ResnetBlock2D(
+                    in_channels=block_out_channels[-1],
+                    out_channels=block_out_channels[-1],
+                    groups=resnet_groups,
+                ),
+                Attention(block_out_channels[-1], resnet_groups),
+                ResnetBlock2D(
+                    in_channels=block_out_channels[-1],
+                    out_channels=block_out_channels[-1],
+                    groups=resnet_groups,
+                ),
+            ]
+
+        channels = list(reversed(block_out_channels))
+        channels = [channels[0]] + channels
+        
+        self.up_blocks = []
+        current_layer = 1
+
+        for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
+            if current_layer in layer_range:
+                self.up_blocks.append(
+                    EncoderDecoderBlock2D(
+                        in_channels,
+                        out_channels,
+                        num_layers=layers_per_block,
+                        resnet_groups=resnet_groups,
+                        add_downsample=False,
+                        add_upsample=i < len(block_out_channels) - 1,
+                    )
+                )
+            else:
+                self.up_blocks.append(IdentityBlock())
+            current_layer += 1
+        if 4 in layer_range:
+            self.conv_norm_out = nn.GroupNorm(
+                resnet_groups, block_out_channels[0], pytorch_compatible=True
+            )
+            self.conv_out = nn.Conv2d(block_out_channels[0], self.out_channels, 3, padding=1)
+
+
+    def __call__(self, x):
+        if 0 in self.layers_range:
+            x = self.conv_in(x)
+            x = self.mid_blocks[0](x)
+            x = self.mid_blocks[1](x)
+            x = self.mid_blocks[2](x)
+        
+        for l in self.up_blocks:
+            x = l(x)
+        if 4 in self.layers_range:
+            x = self.conv_norm_out(x)
+            x = nn.silu(x)
+            x = self.conv_out(x)
+        return x
+
+
+class Autoencoder(nn.Module):
+    """The autoencoder that allows us to perform diffusion in the latent space."""
+
+    def __init__(self, config: AutoencoderConfig, shard: Shard, model_shard: str):
+        super().__init__()
+        self.shard = shard
+        self.start_layer = shard.start_layer
+        self.end_layer = shard.end_layer
+        self.layers_range = list(range(self.start_layer, self.end_layer+1))
+        self.latent_channels = config.latent_channels_in
+        self.scaling_factor = config.scaling_factor
+        self.model_shard = model_shard
+        if self.model_shard == "vae_encoder":
+            self.encoder = Encoder(
+                config.in_channels,
+                config.latent_channels_out,
+                config.block_out_channels,
+                config.layers_per_block,
+                resnet_groups=config.norm_num_groups,
+                layers_range=self.layers_range,
+                shard=shard
+            )
+            if self.shard.is_last_layer():
+                self.quant_proj = nn.Linear(
+                config.latent_channels_out, config.latent_channels_out
+                )
+        if self.model_shard == "vae_decoder":
+            self.decoder = Decoder(
+                config.latent_channels_in,
+                config.out_channels,
+                shard,
+                self.layers_range,
+                config.block_out_channels,
+                config.layers_per_block + 1,
+                resnet_groups=config.norm_num_groups,
+            )
+            if self.shard.is_first_layer():
+                self.post_quant_proj = nn.Linear(
+                    config.latent_channels_in, config.latent_channels_in
+                )
+
+    def decode(self, z):
+        if self.shard.is_first_layer():
+            z = z / self.scaling_factor
+            z=self.post_quant_proj(z)
+        return self.decoder(z)
+
+    def encode(self, x):
+        x = self.encoder(x)
+        if self.shard.is_last_layer():   
+            x = self.quant_proj(x)
+            mean, logvar = x.split(2, axis=-1)
+            mean = mean * self.scaling_factor
+            logvar = logvar + 2 * math.log(self.scaling_factor)
+            x = mean
+        return x
+
+    def __call__(self, x, key=None):
+        mean, logvar = self.encode(x)
+        z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
+        x_hat = self.decode(z)
+
+        return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
+
+    def sanitize(self, weights):
+        shard = self.shard
+        layers = self.layers_range
+        sanitized_weights = {}
+        for key, value in weights.items():
+
+            if "downsamplers" in key:
+                key = key.replace("downsamplers.0.conv", "downsample")
+            if "upsamplers" in key:
+                key = key.replace("upsamplers.0.conv", "upsample")
+
+            # Map attention layers
+            if "key" in key:
+                key = key.replace("key", "key_proj")
+            if "proj_attn" in key:
+                key = key.replace("proj_attn", "out_proj")
+            if "query" in key:
+                key = key.replace("query", "query_proj")
+            if "value" in key:
+                key = key.replace("value", "value_proj")
+
+            # Map the mid block
+            if "mid_block.resnets.0" in key:
+                key = key.replace("mid_block.resnets.0", "mid_blocks.0")
+            if "mid_block.attentions.0" in key:
+                key = key.replace("mid_block.attentions.0", "mid_blocks.1")
+            if "mid_block.resnets.1" in key:
+                key = key.replace("mid_block.resnets.1", "mid_blocks.2")
+    
+            # Map the quant/post_quant layers
+            if "quant_conv" in key:
+                key = key.replace("quant_conv", "quant_proj")
+                value = value.squeeze()
+                
+            # Map the conv_shortcut to linear
+            if "conv_shortcut.weight" in key:
+                value = value.squeeze()
+
+            if len(value.shape) == 4:
+                value = value.transpose(0, 2, 3, 1)
+                value = value.reshape(-1).reshape(value.shape)
+
+
+            if "post_quant_conv" in key :
+                key = key.replace("quant_conv", "quant_proj")
+                value = value.squeeze()
+            
+            if 'decoder' in key and self.model_shard == "vae_decoder":
+                if key.startswith("decoder.mid_blocks."):
+                    if 0 in layers:
+                        sanitized_weights[key] = value
+                if "conv_in" in key and 0 in layers:
+                    sanitized_weights[key] = value
+                if key.startswith("decoder.up_blocks."):
+                    layer_num = int(key.split(".")[2])+1
+                    if layer_num in layers:
+                        sanitized_weights[key] = value
+                if key.startswith("decoder.conv_norm_out") and 4 in layers:
+                    sanitized_weights[key] = value
+                if key.startswith("decoder.conv_out") and 4 in layers:
+                    sanitized_weights[key] = value
+            if self.model_shard == "vae_decoder":
+                if key.startswith("post_quant_proj") and 0 in layers:
+                    sanitized_weights[key] = value
+            if self.model_shard == "vae_encoder":
+                if key.startswith("encoder."):
+                    if "conv_in" in key and shard.is_first_layer():
+                        sanitized_weights[key] = value
+                    if key.startswith("encoder.down_blocks."):
+                        layer_num = int(key.split(".")[2])+1
+                        if layer_num in layers:
+                            sanitized_weights[key] = value
+                    if key.startswith("encoder.mid_blocks.") and shard.is_last_layer():
+                        sanitized_weights[key] = value
+                    if "conv_norm_out" in key and shard.is_last_layer():
+                        sanitized_weights[key] = value
+                    if "conv_out" in key and shard.is_last_layer():
+                        sanitized_weights[key] = value
+                if key.startswith("quant_proj") and shard.is_last_layer():
+                    sanitized_weights[key] = value
+        return sanitized_weights
+

+ 35 - 18
exo/inference/mlx/sharded_inference_engine.py

@@ -12,6 +12,7 @@ from exo.download.shard_download import ShardDownloader
 import asyncio
 from collections import OrderedDict
 from mlx_lm.models.cache import make_prompt_cache
+from concurrent.futures import ThreadPoolExecutor
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -20,6 +21,12 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.caches = OrderedDict()
     self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1)
     self.sampler = make_sampler(*self.sampler_params)
+    self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
+    self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
+
+  async def _eval_mlx(self, *args):
+    loop = asyncio.get_running_loop()
+    await loop.run_in_executor(self._mlx_thread, mx.eval, *args)
 
   async def poll_state(self, request_id: str, max_caches=2):
     if request_id in self.caches:
@@ -38,16 +45,19 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     logits = mx.array(x)
     logits = logits[:, -1, :]
     logprobs = logits - mx.logsumexp(logits, keepdims=True)
-    return np.asarray(self.sampler(logprobs), dtype=int)
+    result = self.sampler(logprobs)
+    await self._eval_mlx(result)
+    return np.asarray(result, dtype=int)
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
-    tokens = self.tokenizer.encode(prompt)
-    return np.asarray(tokens)
+    loop = asyncio.get_running_loop()
+    return np.asarray(await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.encode, prompt))
 
   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
-    return self.tokenizer.decode(tokens)
+    loop = asyncio.get_running_loop()
+    return await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.decode, tokens)
 
   async def save_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
@@ -56,13 +66,18 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
   async def load_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
     self.model.load_weights(path)
-
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+    
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
     await self.ensure_shard(shard)
-    state = await self.poll_state(request_id)
+    loop = asyncio.get_running_loop()
+    state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
     x = mx.array(input_data)
-    output_data = np.array(self.model(x, **state), copy=False)
-    return output_data
+    if self.model.model_type != 'StableDiffusionPipeline':
+      output_data = self.model(x, **state, **inference_state)
+    else:
+      output_data, inference_state = self.model(x, **state, **inference_state)
+    output_data = np.array(output_data, copy=False)
+    return output_data, inference_state
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
     await self.ensure_shard(shard)
@@ -87,26 +102,25 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     return True
 
   async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
-    loop = asyncio.get_running_loop()
-    nothin = await self.ensure_train(shard, loss, opt, lr)
+    await self.ensure_train(shard, loss, opt, lr)
+
     def train_step(inp, tar, lng):
       lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
       gradlayers = grad['model']['layers']
       self.session['opt'].update(self.model, grad)
-      mx.eval(self.model.parameters(), self.session['opt'].state, lval)
-      return lval, gradlayers
+      return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval)
 
     x = mx.array(inputs)
     y = mx.array(targets)
     l = mx.array(lengths)
 
-    score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
-    #print(f"{score=}")
+    score, gradients, eval_args = train_step(x, y, l)
+    await self._eval_mlx(*eval_args)
 
     layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
-    #print(layers[0])
-
-    return score, np.array(layers[0]['input_layernorm'], copy=False)
+    first_layer = np.array(layers[0]['input_layernorm'], copy=False)
+    await self._eval_mlx(first_layer)
+    return score, first_layer
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
@@ -121,3 +135,6 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
       self.caches = OrderedDict()
       self.session = {}
 
+  async def cleanup(self):
+    self._mlx_thread.shutdown(wait=True)
+

+ 56 - 15
exo/inference/mlx/sharded_utils.py

@@ -62,8 +62,16 @@ def _get_classes(config: dict):
 
 def load_config(model_path: Path) -> dict:
   try:
-    with open(model_path/"config.json", "r") as f:
-      config = json.load(f)
+    config_path = model_path / "config.json"
+    if config_path.exists():
+      with open(config_path, "r") as f:
+        config = json.load(f)
+      return config
+    
+    model_index_path = model_path / "model_index.json"
+    if model_index_path.exists():
+      config = load_model_index(model_path, model_index_path)
+      return config
   except FileNotFoundError:
     logging.error(f"Config file not found in {model_path}")
     raise
@@ -110,6 +118,24 @@ def load_model_shard(
     # Try weight for back-compat
     weight_files = glob.glob(str(model_path/"weight*.safetensors"))
 
+  model_class, model_args_class = _get_classes(config=config)
+
+  class ShardedModel(model_class):
+    def __init__(self, args):
+      super().__init__(args)
+      self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
+
+    def __call__(self, x, *args, **kwargs):
+      y = super().__call__(x, *args, **kwargs)
+      return y
+
+  model_args = model_args_class.from_dict(config)
+  model = ShardedModel(model_args)
+
+  if config.get("model_index", False):
+    model.load()
+    return model
+
   if not weight_files:
     logging.error(f"No safetensors found in {model_path}")
     raise FileNotFoundError(f"No safetensors found in {model_path}")
@@ -129,19 +155,7 @@ def load_model_shard(
 
     weights.update(mx.load(wf))
 
-  model_class, model_args_class = _get_classes(config=config)
-
-  class ShardedModel(model_class):
-    def __init__(self, args):
-      super().__init__(args)
-      self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
-
-    def __call__(self, x, *args, **kwargs):
-      y = super().__call__(x, *args, **kwargs)
-      return y
-
-  model_args = model_args_class.from_dict(config)
-  model = ShardedModel(model_args)
+  
 
   if hasattr(model, "sanitize"):
     weights = model.sanitize(weights)
@@ -186,6 +200,9 @@ async def load_shard(
     processor.eos_token_id = processor.tokenizer.eos_token_id
     processor.encode = processor.tokenizer.encode
     return model, processor
+  elif hasattr(model, "tokenizer"):
+    tokenizer = model.tokenizer
+    return model, tokenizer
   else:
     tokenizer = await resolve_tokenizer(model_path)
     return model, tokenizer
@@ -214,3 +231,27 @@ async def get_image_from_str(_image_str: str):
     return img
   else:
     raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
+
+# loading a combined config for all models in the index
+def load_model_index(model_path: Path, model_index_path: Path):
+  models_config = {}
+  with open(model_index_path, "r") as f:
+      model_index = json.load(f)
+  models_config["model_index"] = True
+  models_config["model_type"] = model_index["_class_name"]
+  models_config["models"] = {}
+  for model in model_index.keys():
+    model_config_path = glob.glob(str(model_path / model / "*config.json"))
+    if len(model_config_path)>0:
+      with open(model_config_path[0], "r") as f:
+        model_config = { }
+        model_config["model_type"] = model
+        model_config["config"] = json.load(f)
+        model_config["path"] = model_path / model
+        if model_config["path"]/"*model.safetensors":
+          model_config["config"].update({"weight_files": list(glob.glob(str(model_config["path"]/"*model.safetensors")))})
+        model_config["path"] = str(model_path / model)
+        m = {}
+        m[model] = model_config
+        models_config.update(m)
+  return models_config

+ 81 - 0
exo/inference/mlx/test_non_blocking.py

@@ -0,0 +1,81 @@
+import asyncio
+import time
+import numpy as np
+from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.inference.shard import Shard
+from exo.models import build_base_shard
+from collections import deque
+from statistics import mean, median
+
+async def test_non_blocking():
+    # Setup
+    shard_downloader = HFShardDownloader()
+    engine = MLXDynamicShardInferenceEngine(shard_downloader)
+    _shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
+    shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers)
+    await engine.ensure_shard(shard)
+    
+    queue = asyncio.Queue()
+    measurements = deque(maxlen=1000000)
+    running = True
+
+    async def mlx_worker():
+        try:
+            start_time = time.time()
+            count = 0
+            while running and (time.time() - start_time) < 5:  # Hard time limit
+                start = time.perf_counter_ns()
+                await engine.infer_prompt("req1", shard, "test prompt")
+                duration = (time.perf_counter_ns() - start) / 1_000_000  # Convert to ms
+                count += 1
+                print(f"MLX operation {count} took: {duration:.3f}ms")
+        except asyncio.CancelledError:
+            pass
+        finally:
+            print(f"\nTotal MLX operations completed: {count}")
+            print(f"Average rate: {count/5:.1f} ops/second")
+
+    async def latency_producer():
+        try:
+            start_time = time.perf_counter_ns()
+            count = 0
+            while running:
+                await queue.put(time.perf_counter_ns())
+                count += 1
+                await asyncio.sleep(0)  # Yield to event loop without delay
+            duration = (time.perf_counter_ns() - start_time) / 1e9  # Convert to seconds
+            print(f"\nProducer iterations: {count}")
+            print(f"Producer rate: {count/duration:.1f} iterations/second")
+        except asyncio.CancelledError:
+            pass
+
+    async def latency_consumer():
+        try:
+            while running:
+                timestamp = await queue.get()
+                latency = (time.perf_counter_ns() - timestamp) / 1_000_000  # Convert to ms
+                measurements.append(latency)
+                queue.task_done()
+        except asyncio.CancelledError:
+            pass
+
+    tasks = [
+        asyncio.create_task(mlx_worker()),
+        asyncio.create_task(latency_producer()),
+        asyncio.create_task(latency_consumer())
+    ]
+    
+    try:
+        await asyncio.wait_for(asyncio.gather(*tasks), timeout=6)
+    except asyncio.TimeoutError:
+        print("\nTest timed out")
+    finally:
+        running = False
+        for task in tasks:
+            task.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+        print(f"\nFinal measurement count: {len(measurements)}")
+
+if __name__ == "__main__":
+    asyncio.run(test_non_blocking())

+ 3 - 3
exo/inference/tinygrad/inference.py

@@ -15,7 +15,7 @@ from .stateful_model import make_prompt_state
 from .losses import length_masked_ce_loss
 from collections import OrderedDict
 import asyncio
-
+from typing import Optional
 Tensor.no_grad = True 
 # default settings
 TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
@@ -104,7 +104,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
     safe_save(state_dict, path) 
   
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
     await self.ensure_shard(shard)
     def wrap_infer():
       x = Tensor(input_data)
@@ -114,7 +114,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
       self.states[request_id].start += x.shape[1]
       return out.realize()
     output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
-    return output_data.numpy()
+    return output_data.numpy(), inference_state
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
     def step(x, y, l):

+ 1 - 1
exo/inference/tokenizers.py

@@ -14,7 +14,7 @@ class DummyTokenizer:
     self.eos_token_id = 69
     self.vocab_size = 1000
 
-  def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
+  def apply_chat_template(self, conversation, tokenize=True, add_generation_prompt=True, tools=None, **kwargs):
     return "dummy_tokenized_prompt"
 
   def encode(self, text):

+ 6 - 4
exo/main.py

@@ -103,6 +103,7 @@ parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailsca
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
 parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)")
+parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
 args = parser.parse_args()
 print(f"Selected inference engine: {args.inference_engine}")
 
@@ -182,11 +183,12 @@ api = ChatGPTAPI(
   inference_engine.__class__.__name__,
   response_timeout=args.chatgpt_api_response_timeout,
   on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
-  default_model=args.default_model
+  default_model=args.default_model,
+  system_prompt=args.system_prompt
+)
+node.on_token.register("update_topology_viz").on_next(
+  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") and inference_engine.shard.model_id != 'stable-diffusion-2-1-base' else None
 )
-# node.on_token.register("update_topology_viz").on_next(
-#   lambda req_id, token, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode([token])) if topology_viz and hasattr(inference_engine, "tokenizer") else None
-# )
 
 def preemptively_start_download(request_id: str, opaque_status: str):
   try:

+ 20 - 6
exo/models.py

@@ -92,14 +92,17 @@ model_cards = {
   "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
   ### qwen
   "qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
+  "qwen-2.5-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", }, },
   "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
+  "qwen-2.5-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit", }, },
   "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
-  "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
-  "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
-  "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
   "qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
+  "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
   "qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
   "qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
+  "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
+  "qwen-2.5-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit", }, },
+  "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
   "qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
   "qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
   ### nemotron
@@ -108,6 +111,11 @@ model_cards = {
   # gemma
   "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
   "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
+  # stable diffusion
+  "stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
+  # phi
+  "phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, },
+  "phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, },
   # dummy
   "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
 }
@@ -133,18 +141,24 @@ pretty_name = {
   "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
   "deepseek-coder-v2.5": "Deepseek Coder V2.5",
   "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
+  "qwen-2.5-1.5b": "Qwen 2.5 1.5B",
   "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
+  "qwen-2.5-3b": "Qwen 2.5 3B",
   "qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
-  "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
-  "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
-  "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
   "qwen-2.5-7b": "Qwen 2.5 7B",
+  "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
   "qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
   "qwen-2.5-14b": "Qwen 2.5 14B",
+  "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
+  "qwen-2.5-32b": "Qwen 2.5 32B",
+  "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
   "qwen-2.5-72b": "Qwen 2.5 72B",
   "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
+  "phi-3.5-mini": "Phi-3.5 Mini",
+  "phi-4": "Phi-4",
   "llama-3-8b": "Llama 3 8B",
   "llama-3-70b": "Llama 3 70B",
+  "stable-diffusion-2-1-base": "Stable Diffusion 2.1",
 }
 
 def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:

+ 49 - 7
exo/networking/grpc/grpc_peer_handle.py

@@ -11,7 +11,8 @@ from exo.inference.shard import Shard
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.helpers import DEBUG
-
+import json
+import mlx.core as mx
 
 class GRPCPeerHandle(PeerHandle):
   def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
@@ -90,7 +91,7 @@ class GRPCPeerHandle(PeerHandle):
         traceback.print_exc()
       return False
 
-  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> None:
+  async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
       shard=node_service_pb2.Shard(
@@ -100,10 +101,11 @@ class GRPCPeerHandle(PeerHandle):
         n_layers=shard.n_layers,
       ),
       request_id=request_id,
+      inference_state=self.serialize_inference_state(inference_state)
     )
     await self.stub.SendPrompt(request)
 
-  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> None:
+  async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -113,8 +115,14 @@ class GRPCPeerHandle(PeerHandle):
       ),
       tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
       request_id=request_id,
+      inference_state=self.serialize_inference_state(inference_state)
     )
-    await self.stub.SendTensor(request)
+    response =await self.stub.SendTensor(request)
+
+    if not response.tensor_data or not response.shape or not response.dtype:
+      return None
+
+    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
   
   async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.ExampleRequest(
@@ -173,10 +181,44 @@ class GRPCPeerHandle(PeerHandle):
         topology.add_edge(node_id, conn.to_id, conn.description)
     return topology
 
-  async def send_new_token(self, request_id: str, token: int, is_finished: bool) -> None:
-    request = node_service_pb2.SendNewTokenRequest(request_id=request_id, token=token, is_finished=is_finished)
-    await self.stub.SendNewToken(request)
+  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    tensor = None
+    if isinstance(result, np.ndarray):
+      tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
+      result = []
+    request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished)
+    await self.stub.SendResult(request)
 
   async def send_opaque_status(self, request_id: str, status: str) -> None:
     request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
     await self.stub.SendOpaqueStatus(request)
+
+  def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
+    proto_inference_state = node_service_pb2.InferenceState()
+    other_data = {}
+    for k, v in inference_state.items():
+        if isinstance(v, mx.array):
+            np_array = np.array(v)
+            tensor_data = node_service_pb2.Tensor(
+                tensor_data=np_array.tobytes(),
+                shape=list(np_array.shape),
+                dtype=str(np_array.dtype)
+            )
+            proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
+        elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
+            tensor_list = node_service_pb2.TensorList()
+            for tensor in v:
+                np_array = np.array(tensor)
+                tensor_data = node_service_pb2.Tensor(
+                    tensor_data=np_array.tobytes(),
+                    shape=list(np_array.shape),
+                    dtype=str(np_array.dtype)
+                )
+                tensor_list.tensors.append(tensor_data)
+            proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
+        else:
+            # For non-tensor data, we'll still use JSON
+            other_data[k] = v
+    if other_data:
+      proto_inference_state.other_data_json = json.dumps(other_data)
+    return proto_inference_state

+ 39 - 8
exo/networking/grpc/grpc_server.py

@@ -8,6 +8,8 @@ from . import node_service_pb2_grpc
 from exo import DEBUG
 from exo.inference.shard import Shard
 from exo.orchestration import Node
+import json
+import mlx.core as mx
 
 
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
@@ -58,9 +60,11 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     prompt = request.prompt
     request_id = request.request_id
-    await self.node.process_prompt(shard, prompt, request_id)
-    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=}")
-    return node_service_pb2.Empty()
+    inference_state = self.deserialize_inference_state(request.inference_state)
+    result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
+    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
+    tensor_data = result.tobytes() if result is not None else None
+    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
   async def SendTensor(self, request, context):
     shard = Shard(
@@ -71,9 +75,13 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     request_id = request.request_id
-    await self.node.process_tensor(shard, tensor, request_id)
-    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=}")
-    return node_service_pb2.Empty()
+
+    inference_state = self.deserialize_inference_state(request.inference_state)
+
+    result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
+    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
+    tensor_data = result.tobytes() if result is not None else None
+    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
   
   async def SendExample(self, request, context):
     shard = Shard(
@@ -127,8 +135,12 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     request_id = request.request_id
     token = request.token
     is_finished = request.is_finished
-    if DEBUG >= 5: print(f"Received SendNewToken request: {request_id=} {token=} {is_finished=}")
-    self.node.on_token.trigger_all(request_id, token, is_finished)
+    img = request.tensor
+    if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
+    result = list(result)
+    if len(img.tensor_data) > 0:
+      result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
+    self.node.on_token.trigger_all(request_id, result, is_finished)
     return node_service_pb2.Empty()
 
   async def SendOpaqueStatus(self, request, context):
@@ -140,3 +152,22 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
 
   async def HealthCheck(self, request, context):
     return node_service_pb2.HealthCheckResponse(is_healthy=True)
+
+  def deserialize_inference_state(self,inference_state_proto: node_service_pb2.InferenceState) -> dict:
+    inference_state = {}
+    
+    for k, tensor_data in inference_state_proto.tensor_data.items():
+        np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
+        inference_state[k] = mx.array(np_array)
+    
+    for k, tensor_list in inference_state_proto.tensor_list_data.items():
+        inference_state[k] = [
+            mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape))
+            for tensor in tensor_list.tensors
+        ]
+    
+    if inference_state_proto.other_data_json:
+        other_data = json.loads(inference_state_proto.other_data_json)
+        inference_state.update(other_data)
+    
+    return inference_state

+ 15 - 2
exo/networking/grpc/node_service.proto

@@ -23,12 +23,14 @@ message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
   optional string request_id = 3;
+  optional InferenceState inference_state = 4;
 }
 
 message TensorRequest {
   Shard shard = 1;
   Tensor tensor = 2;
   optional string request_id = 3;
+  optional InferenceState inference_state = 4;
 }
 
 message ExampleRequest {
@@ -51,6 +53,16 @@ message Tensor {
   string dtype = 3;
 }
 
+message TensorList {
+  repeated Tensor tensors = 1;
+}
+
+message InferenceState {
+  map<string, Tensor> tensor_data = 1;
+  map<string, TensorList> tensor_list_data = 2;
+  string other_data_json = 3;
+}
+
 message CollectTopologyRequest {
   repeated string visited = 1;
   int32 max_depth = 2;
@@ -85,8 +97,9 @@ message DeviceCapabilities {
 
 message SendNewTokenRequest {
   string request_id = 1;
-  int32 token = 2;
-  bool is_finished = 3;
+  repeated int32 result = 2;
+  optional Tensor tensor = 3;
+  bool is_finished = 4;
 }
 
 message SendOpaqueStatusRequest {

A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 51 - 22
exo/networking/manual/manual_discovery.py

@@ -1,7 +1,9 @@
+import os
 import asyncio
-from exo.networking.discovery import Discovery
-from typing import Dict, List, Callable
+from typing import Dict, List, Callable, Optional
+from concurrent.futures import ThreadPoolExecutor
 
+from exo.networking.discovery import Discovery
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
 from exo.helpers import DEBUG_DISCOVERY
@@ -13,28 +15,25 @@ class ManualDiscovery(Discovery):
     self,
     network_config_path: str,
     node_id: str,
-    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+    create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
   ):
-    self.topology = NetworkTopology.from_path(network_config_path)
+    self.network_config_path = network_config_path
+    self.node_id = node_id
     self.create_peer_handle = create_peer_handle
 
-    if node_id not in self.topology.peers:
-      raise ValueError(
-        f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
-      )
-
     self.listen_task = None
-
     self.known_peers: Dict[str, PeerHandle] = {}
-    self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
-    self.peers_in_network.pop(node_id)
+
+    self._cached_peers: Dict[str, PeerConfig] = {}
+    self._last_modified_time: Optional[float] = None
+    self._file_executor = ThreadPoolExecutor(max_workers=1)
 
   async def start(self) -> None:
     self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
 
   async def stop(self) -> None:
-    if self.listen_task:
-      self.listen_task.cancel()
+    if self.listen_task: self.listen_task.cancel()
+    self._file_executor.shutdown(wait=True)
 
   async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
     if wait_for_peers > 0:
@@ -47,7 +46,9 @@ class ManualDiscovery(Discovery):
   async def task_find_peers_from_config(self):
     if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
     while True:
-      for peer_id, peer_config in self.peers_in_network.items():
+      peers_from_config = await self._get_peers()
+      new_known_peers = {}
+      for peer_id, peer_config in peers_from_config.items():
         try:
           if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
           peer = self.known_peers.get(peer_id)
@@ -57,15 +58,43 @@ class ManualDiscovery(Discovery):
           is_healthy = await peer.health_check()
           if is_healthy:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
-            self.known_peers[peer_id] = peer
-          else:
-            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
-            try:
-              del self.known_peers[peer_id]
-            except KeyError:
-              pass
+            new_known_peers[peer_id] = peer
+          elif DEBUG_DISCOVERY >= 2:
+            print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
         except Exception as e:
           if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
       await asyncio.sleep(5.0)
 
       if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
+
+  async def _get_peers(self):
+    try:
+      loop = asyncio.get_running_loop()
+      current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)
+
+      if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
+        return self._cached_peers
+
+      topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)
+
+      if self.node_id not in topology.peers:
+        raise ValueError(
+          f"Node ID {self.node_id} not found in network config file "
+          f"{self.network_config_path}. Please run with `node_id` set to "
+          f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
+        )
+
+      peers_in_network = topology.peers
+      peers_in_network.pop(self.node_id)
+
+      self._cached_peers = peers_in_network
+      self._last_modified_time = current_mtime
+
+      return peers_in_network
+
+    except Exception as e:
+      if DEBUG_DISCOVERY >= 2:
+        print(f"Error when loading network config file from {self.network_config_path}. "
+              f"Please update the config file in order to successfully discover peers. "
+              f"Exception: {e}")
+      return self._cached_peers

+ 1 - 1
exo/networking/manual/test_data/test_config.json

@@ -29,4 +29,4 @@
       }
     }
   }
-}
+}

+ 84 - 6
exo/networking/manual/test_manual_discovery.py

@@ -1,3 +1,4 @@
+import json
 import asyncio
 import unittest
 from unittest import mock
@@ -14,8 +15,12 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
     self.peer1 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
-    _ = self.discovery1.start()
+    self.discovery1 = ManualDiscovery(
+      root_path,
+      "node1",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
+    )
+    await self.discovery1.start()
 
   async def asyncTearDown(self):
     await self.discovery1.stop()
@@ -33,8 +38,16 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
     self.peer2 = mock.AsyncMock()
     self.peer1.connect = mock.AsyncMock()
     self.peer2.connect = mock.AsyncMock()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
-    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
+    self.discovery1 = ManualDiscovery(
+      root_path,
+      "node1",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
+    )
+    self.discovery2 = ManualDiscovery(
+      root_path,
+      "node2",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
+    )
     await self.discovery1.start()
     await self.discovery2.start()
 
@@ -63,8 +76,16 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
     await self.server1.start()
     await self.server2.start()
-    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
-    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
+    self.discovery1 = ManualDiscovery(
+      root_path,
+      "node1",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
+    )
+    self.discovery2 = ManualDiscovery(
+      root_path,
+      "node2",
+      create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
+    )
     await self.discovery1.start()
     await self.discovery2.start()
 
@@ -98,6 +119,63 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
     self.assertFalse(await peers1[0].is_connected())
     self.assertFalse(await peers2[0].is_connected())
 
+  async def test_dynamic_config_update(self):
+    initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
+    self.assertEqual(len(initial_peers), 1)
+
+    # Save original config for cleanup
+    with open(root_path, "r") as f:
+      original_config = json.load(f)
+
+    try:
+      updated_config = {
+        "peers": {
+          **original_config["peers"],
+          "node3": {
+            "address": "localhost",
+            "port": 50053,
+            "device_capabilities": {
+              "model": "Unknown Model",
+              "chip": "Unknown Chip",
+              "memory": 0,
+              "flops": {"fp32": 0, "fp16": 0, "int8": 0},
+            },
+          },
+        }
+      }
+
+      with open(root_path, "w") as f:
+        json.dump(updated_config, f, indent=2)
+
+      node3 = mock.AsyncMock(spec=Node)
+      server3 = GRPCServer(node3, "localhost", 50053)
+      await server3.start()
+
+      try:
+        # Wait for the config to be reloaded
+        await asyncio.sleep(1.5)
+
+        updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
+        self.assertEqual(len(updated_peers), 2)
+
+        for peer in updated_peers:
+          await peer.connect()
+          self.assertTrue(await peer.is_connected())
+
+      finally:
+        await server3.stop()
+
+    finally:
+      # Restore the original config file
+      with open(root_path, "w") as f:
+        json.dump(original_config, f, indent=2)
+
+    # Wait for the config to be reloaded again
+    await asyncio.sleep(1.5)
+
+    updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
+    self.assertEqual(len(updated_peers), 1)
+
 
 if __name__ == "__main__":
   asyncio.run(unittest.main())

+ 61 - 39
exo/orchestration/node.py

@@ -118,44 +118,50 @@ class Node:
     shard,
     result: np.ndarray,
     request_id: Optional[str] = None,
+    inference_state: Optional[dict] = None,
   ):
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
-    is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-    
-    if shard.is_last_layer() and not is_finished:
-      self.token_count += 1
-      if self.token_count == 1:
-        self.first_token_time = time.perf_counter_ns()
-      if self.token_count % 20 == 0:
-        print(f"[{request_id}] TPS: {self.token_count / ((time.perf_counter_ns() - self.first_token_time) / 1e9)}")
-
-      token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
-      await self.inference_engine.ensure_shard(shard)
-      self.buffered_token_output[request_id][0].append(token.item())
-      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-      forward = token.reshape(1, -1)
-      self.trigger_on_token_callbacks(request_id, token.item(), is_finished)
-      asyncio.create_task(self.broadcast_new_token(request_id, token.item(), is_finished))
+    if shard.model_id != 'stable-diffusion-2-1-base':
+      if request_id not in self.buffered_token_output:
+        self.buffered_token_output[request_id] = ([], False)
+      is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+      if shard.is_last_layer() and not is_finished:
+        token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
+        await self.inference_engine.ensure_shard(shard)
+        self.buffered_token_output[request_id][0].append(token.item())
+        is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+        if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+        asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
+        forward = token.reshape(1, -1)
+        intermediate_result = self.buffered_token_output[request_id][0]
+      else:
+        forward = result
     else:
+      await self.inference_engine.ensure_shard(shard)
+      is_finished = inference_state.get("is_finished", False)
+      intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result)
       forward = result
+    if shard.is_last_layer():
+      self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished)
+      asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished))
 
     if is_finished:
-      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+      if shard.model_id != 'stable-diffusion-2-1-base':
+        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
       self.outstanding_requests.pop(request_id)
     else:
       self.outstanding_requests[request_id] = "waiting"
-      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
+      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))
+
+    return  np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result
 
-    return np.array(self.buffered_token_output[request_id][0])
 
   async def process_prompt(
     self,
     base_shard: Shard,
     prompt: str,
     request_id: Optional[str] = None,
-  ) -> None:
+    inference_state: Optional[dict] = {},
+  ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     start_time = time.perf_counter_ns()
     asyncio.create_task(
@@ -172,7 +178,8 @@ class Node:
         }),
       )
     )
-    await self._process_prompt(base_shard, prompt, request_id)
+    start_time = time.perf_counter_ns()
+    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -192,7 +199,7 @@ class Node:
     )
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}")
 
-  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     shard = self.get_current_shard(base_shard)
@@ -201,12 +208,13 @@ class Node:
     if not shard.is_first_layer():
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
       self.outstanding_requests[request_id] = "waiting"
-      await self.forward_prompt(shard, prompt, request_id, 0)
+      resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state)
       return None
-
-    self.outstanding_requests[request_id] = "processing"
-    result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
-    await self.process_inference_result(shard, result, request_id)
+    else:
+      self.outstanding_requests[request_id] = "processing"
+      result, inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
+      ret = await self.process_inference_result(shard, result, request_id, inference_state)
+      return result
 
   async def enqueue_example(
     self,
@@ -350,10 +358,11 @@ class Node:
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
-  ) -> None:
+    inference_state: Optional[dict] = None,
+  ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     start_time = time.perf_counter_ns()
-    await self._process_tensor(shard, tensor, request_id)
+    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
@@ -363,15 +372,17 @@ class Node:
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
-  ) -> None:
+    inference_state: Optional[dict] = None,
+  ) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     shard = self.get_current_shard(base_shard)
 
     try:
       self.outstanding_requests[request_id] = "processing"
-      result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
-      await self.process_inference_result(shard, result, request_id) 
+      result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
+      ret = await self.process_inference_result(shard, result, request_id, inference_state) 
+      return ret
     except Exception as e:
       self.outstanding_requests.pop(request_id)
       print(f"Error processing tensor for shard {shard}: {e}")
@@ -404,19 +415,20 @@ class Node:
     prompt: str,
     request_id: str,
     target_index: int,
+    inference_state: Optional[dict] = None,
   ) -> None:
     if DEBUG >= 1: print(f"target partition index: {target_index}")
     target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
     next_shard = self.get_current_shard(base_shard, target_index)
     if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
     if target_id == self.id:
-      await self.process_prompt(next_shard, prompt, request_id)
+      await self.process_prompt(next_shard, prompt, request_id, inference_state)
     else:
       target_peer = next((p for p in self.peers if p.id() == target_id), None)
       if not target_peer:
         raise ValueError(f"Peer for {target_index} not found")
       if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
-      await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+      await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)
   
   async def forward_tensor(
     self,
@@ -424,19 +436,20 @@ class Node:
     tensor: np.ndarray,
     request_id: str,
     target_index: int,
+    inference_state: Optional[dict] = None,
   ) -> None:
     if DEBUG >= 1: print(f"target partition index: {target_index}")
     target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
     next_shard = self.get_current_shard(base_shard, target_index)
     if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
     if target_id == self.id:
-      await self.process_tensor(next_shard, tensor, request_id)
+      await self.process_tensor(next_shard, tensor, request_id, inference_state)
     else:
       target_peer = next((p for p in self.peers if p.id() == target_id), None)
       if not target_peer:
         raise ValueError(f"Peer for {target_index} not found")
       if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
-      await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
+      await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state)
 
   def get_partition_index(self, offset: int = 0):
     if not self.partitioning_strategy:
@@ -604,3 +617,12 @@ class Node:
   @property
   def current_topology(self) -> Topology:
     return self.topology
+
+  def handle_stable_diffusion(self, inference_state, result):
+    if inference_state['is_step_finished']:
+      inference_state['step']+=1
+    progress = [inference_state['step'],inference_state['total_steps']]
+    intermediate_result = result
+    if progress[0] == progress[1]:
+      intermediate_result = result
+    return intermediate_result, inference_state

+ 166 - 0
exo/orchestration/tracing.py

@@ -0,0 +1,166 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Any
+from opentelemetry import trace, context
+from opentelemetry.trace import Status, StatusCode, SpanContext
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
+from contextlib import contextmanager
+import time
+from threading import Lock
+
+@dataclass
+class TraceContext:
+  request_id: str
+  sequence_number: int
+  current_span: Optional[trace.Span] = None
+  trace_parent: Optional[str] = None
+  token_group_span: Optional[trace.Span] = None
+  token_count: int = 0
+  token_group_size: int = 10  # Default group size
+  request_span: Optional[trace.Span] = None  # Track the main request span
+
+class Tracer:
+  def __init__(self):
+    self.tracer = trace.get_tracer("exo")
+    self.contexts: Dict[str, TraceContext] = {}
+    self._lock = Lock()
+    self.propagator = TraceContextTextMapPropagator()
+    
+  def get_context(self, request_id: str) -> Optional[TraceContext]:
+    with self._lock:
+      return self.contexts.get(request_id)
+
+  def set_context(self, request_id: str, context: TraceContext):
+    with self._lock:
+      self.contexts[request_id] = context
+
+  def inject_context(self, span: trace.Span) -> str:
+    """Inject current span context into carrier for propagation"""
+    carrier = {}
+    ctx = trace.set_span_in_context(span)
+    self.propagator.inject(carrier, context=ctx)
+    return carrier.get("traceparent", "")
+
+  def extract_context(self, trace_parent: str) -> Optional[context.Context]:
+    """Extract span context from carrier"""
+    if not trace_parent:
+      return None
+    carrier = {"traceparent": trace_parent}
+    return self.propagator.extract(carrier)
+
+  def create_context_from_parent(self, request_id: str, trace_parent: str, sequence_number: int = 0) -> TraceContext:
+    """Create a new context with the given trace parent"""
+    parent_ctx = self.extract_context(trace_parent)
+    if parent_ctx:
+      # Create a new request span that links to the parent context
+      request_span = self.tracer.start_span(
+        "request",
+        context=parent_ctx,
+        attributes={
+          "request_id": request_id,
+          "sequence_number": sequence_number
+        }
+      )
+      return TraceContext(
+        request_id=request_id,
+        sequence_number=sequence_number,
+        request_span=request_span,
+        current_span=request_span,
+        trace_parent=trace_parent
+      )
+    return TraceContext(request_id=request_id, sequence_number=sequence_number)
+
+  def handle_token(self, context: TraceContext, token: int, is_finished: bool = False):
+    """Handle token generation and manage token group spans"""
+    context.token_count += 1
+    
+    # Start a new token group span if needed
+    if not context.token_group_span and context.request_span:
+      group_number = (context.token_count - 1) // context.token_group_size + 1
+      
+      # Create token group span as child of request span
+      parent_ctx = trace.set_span_in_context(context.request_span)
+      context.token_group_span = self.tracer.start_span(
+        f"token_group_{group_number}",
+        context=parent_ctx,
+        attributes={
+          "request_id": context.request_id,
+          "group.number": group_number,
+          "group.start_token": context.token_count,
+          "group.max_tokens": context.token_group_size
+        }
+      )
+    
+    # Add token to current group span
+    if context.token_group_span:
+      relative_pos = ((context.token_count - 1) % context.token_group_size) + 1
+      context.token_group_span.set_attribute(f"token.{relative_pos}", token)
+      context.token_group_span.set_attribute("token.count", relative_pos)
+      
+      # End current group span if we've reached the group size or if generation is finished
+      if context.token_count % context.token_group_size == 0 or is_finished:
+        context.token_group_span.set_attribute("token.final_count", relative_pos)
+        context.token_group_span.end()
+        context.token_group_span = None
+
+  @contextmanager
+  def start_span(self, name: str, context: TraceContext, extra_attributes: Optional[Dict[str, Any]] = None):
+    """Start a new span with proper parent context"""
+    attributes = {
+      "request_id": context.request_id,
+      "sequence_number": context.sequence_number
+    }
+    if extra_attributes:
+      attributes.update(extra_attributes)
+      
+    # Use request span as parent if available
+    parent_ctx = None
+    if context.request_span:
+      parent_ctx = trace.set_span_in_context(context.request_span)
+    elif context.trace_parent:
+      parent_ctx = self.extract_context(context.trace_parent)
+      if parent_ctx and not context.request_span:
+        # Create a new request span that links to the parent context
+        context.request_span = self.tracer.start_span(
+          "request",
+          context=parent_ctx,
+          attributes={
+            "request_id": context.request_id,
+            "sequence_number": context.sequence_number
+          }
+        )
+        parent_ctx = trace.set_span_in_context(context.request_span)
+    elif context.current_span:
+      parent_ctx = trace.set_span_in_context(context.current_span)
+    
+    # Create span with parent context if it exists
+    if parent_ctx:
+      span = self.tracer.start_span(
+        name,
+        context=parent_ctx,
+        attributes=attributes
+      )
+    else:
+      span = self.tracer.start_span(
+        name,
+        attributes=attributes
+      )
+    
+    # Update context with current span
+    prev_span = context.current_span
+    context.current_span = span
+    
+    try:
+      start_time = time.perf_counter()
+      yield span
+      duration = time.perf_counter() - start_time
+      span.set_attribute("duration_s", duration)
+      span.set_status(Status(StatusCode.OK))
+    except Exception as e:
+      span.set_status(Status(StatusCode.ERROR, str(e)))
+      raise
+    finally:
+      span.end()
+      context.current_span = prev_span
+
+# Global tracer instance
+tracer = Tracer() 

+ 20 - 2
exo/tinychat/index.html

@@ -197,7 +197,25 @@
           const div = document.createElement('div');
           div.className = `message message-role-${role}`;
           try {
-            div.innerHTML = DOMPurify.sanitize(marked.parse(content));
+              if (content.includes('![Generated Image]')) {
+                const imageUrl = content.match(/\((.*?)\)/)[1];
+                const img = document.createElement('img');
+                img.src = imageUrl;
+                img.alt = 'Generated Image';
+                img.onclick = async () => {
+                  try {
+                    const response = await fetch(img.src);
+                    const blob = await response.blob();
+                    const file = new File([blob], 'image.png', { type: 'image/png' });
+                    handleImageUpload({ target: { files: [file] } });
+                  } catch (error) {
+                    console.error('Error fetching image:', error);
+                  }
+                };
+                div.appendChild(img);
+              } else {
+                div.innerHTML = DOMPurify.sanitize(marked.parse(content));
+              }
           } catch (e) {
             console.log(content);
             console.error(e);
@@ -281,7 +299,7 @@
 </span>
 </div>
 <div class="input">
-<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf'">
+<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf' || cstate.selectedModel === 'stable-diffusion-2-1-base'">
 <i class="fas fa-image"></i>
 </button>
 <input @change="$data.handleImageUpload($event)" accept="image/*" id="image-upload" style="display: none;" type="file" x-ref="imageUpload"/>

+ 96 - 39
exo/tinychat/index.js

@@ -231,53 +231,110 @@ document.addEventListener("alpine:init", () => {
             };
           }
         });
-        const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
-        if (containsImage) {
-          // Map all messages with string content to object with type text
-          apiMessages = apiMessages.map(msg => {
-            if (typeof msg.content === 'string') {
-              return {
-                ...msg,
-                content: [
-                  {
-                    type: "text",
-                    text: msg.content
-                  }
-                ]
-              };
-            }
-            return msg;
+        
+        if (this.cstate.selectedModel === "stable-diffusion-2-1-base") {
+          // Send a request to the image generation endpoint
+          console.log(apiMessages[apiMessages.length - 1].content)
+          console.log(this.cstate.selectedModel)  
+          console.log(this.endpoint)
+          const response = await fetch(`${this.endpoint}/image/generations`, {
+            method: "POST",
+            headers: {
+              "Content-Type": "application/json",
+            },
+            body: JSON.stringify({
+              "model": 'stable-diffusion-2-1-base',
+              "prompt": apiMessages[apiMessages.length - 1].content,
+              "image_url": this.imageUrl
+            }),
           });
+      
+          if (!response.ok) {
+            throw new Error("Failed to fetch");
+          }
+          const reader = response.body.getReader();
+          let done = false;
+          let gottenFirstChunk = false;
+  
+          while (!done) {
+            const { value, done: readerDone } = await reader.read();
+            done = readerDone;
+            const decoder = new TextDecoder();
+  
+            if (value) {
+              // Assume non-binary data (text) comes first
+              const chunk = decoder.decode(value, { stream: true });
+              const parsed = JSON.parse(chunk);
+              console.log(parsed)
+  
+              if (parsed.progress) {
+                if (!gottenFirstChunk) {
+                  this.cstate.messages.push({ role: "assistant", content: "" });
+                  gottenFirstChunk = true;
+                }
+                this.cstate.messages[this.cstate.messages.length - 1].content = parsed.progress;
+              }
+              else if (parsed.images) {
+                if (!gottenFirstChunk) {
+                  this.cstate.messages.push({ role: "assistant", content: "" });
+                  gottenFirstChunk = true;
+                }
+                const imageUrl = parsed.images[0].url;
+                console.log(imageUrl)
+                this.cstate.messages[this.cstate.messages.length - 1].content = `![Generated Image](${imageUrl}?t=${Date.now()})`;
+              }
+            }
+          }
         }
-
-
-        // start receiving server sent events
-        let gottenFirstChunk = false;
-        for await (
-          const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
-        ) {
-          if (!gottenFirstChunk) {
-            this.cstate.messages.push({ role: "assistant", content: "" });
-            gottenFirstChunk = true;
+        
+        else{        
+          const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
+          if (containsImage) {
+            // Map all messages with string content to object with type text
+            apiMessages = apiMessages.map(msg => {
+              if (typeof msg.content === 'string') {
+                return {
+                  ...msg,
+                  content: [
+                    {
+                      type: "text",
+                      text: msg.content
+                    }
+                  ]
+                };
+              }
+              return msg;
+            });
           }
 
-          // add chunk to the last message
-          this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
+          console.log(apiMessages)
+          //start receiving server sent events
+          let gottenFirstChunk = false;
+          for await (
+            const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
+          ) {
+            if (!gottenFirstChunk) {
+              this.cstate.messages.push({ role: "assistant", content: "" });
+              gottenFirstChunk = true;
+            }
 
-          // calculate performance tracking
-          tokens += 1;
-          this.total_tokens += 1;
-          if (start_time === 0) {
-            start_time = Date.now();
-            this.time_till_first = start_time - prefill_start;
-          } else {
-            const diff = Date.now() - start_time;
-            if (diff > 0) {
-              this.tokens_per_second = tokens / (diff / 1000);
+            // add chunk to the last message
+            this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
+
+            // calculate performance tracking
+            tokens += 1;
+            this.total_tokens += 1;
+            if (start_time === 0) {
+              start_time = Date.now();
+              this.time_till_first = start_time - prefill_start;
+            } else {
+              const diff = Date.now() - start_time;
+              if (diff > 0) {
+                this.tokens_per_second = tokens / (diff / 1000);
+              }
             }
           }
         }
-
         // Clean the cstate before adding it to histories
         const cleanedCstate = JSON.parse(JSON.stringify(this.cstate));
         cleanedCstate.messages = cleanedCstate.messages.map(msg => {

+ 58 - 13
exo/viz/topology_viz.py

@@ -91,25 +91,70 @@ class TopologyViz:
     content = []
     requests = list(self.requests.values())[-3:]  # Get the 3 most recent requests
     max_width = self.console.width - 6  # Full width minus padding and icon
-    max_lines = 13  # Maximum number of lines for the entire panel content
+
+    # Calculate available height for content
+    panel_height = 15  # Fixed panel height
+    available_lines = panel_height - 2  # Subtract 2 for panel borders
+    lines_per_entry = available_lines // len(requests) if requests else 0
 
     for (prompt, output) in reversed(requests):
       prompt_icon, output_icon = "💬️", "🤖"
 
+      # Calculate max lines for prompt and output
+      max_prompt_lines = lines_per_entry // 3  # Allocate 1/3 for prompt
+      max_output_lines = lines_per_entry - max_prompt_lines - 1  # Remaining space minus spacing
+
       # Process prompt
-      prompt_lines = prompt.split('\n')
-      if len(prompt_lines) > max_lines // 2:
-        prompt_lines = prompt_lines[:max_lines//2 - 1] + ['...']
+      prompt_lines = []
+      for line in prompt.split('\n'):
+        words = line.split()
+        current_line = []
+        current_length = 0
+
+        for word in words:
+          if current_length + len(word) + 1 <= max_width:
+            current_line.append(word)
+            current_length += len(word) + 1
+          else:
+            if current_line:
+              prompt_lines.append(' '.join(current_line))
+            current_line = [word]
+            current_length = len(word)
+
+        if current_line:
+          prompt_lines.append(' '.join(current_line))
+
+      if len(prompt_lines) > max_prompt_lines:
+        prompt_lines = prompt_lines[:max_prompt_lines - 1] + ['...']
+
       prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
-      prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
+      prompt_text.append('\n'.join(prompt_lines), style="white")
+
+      # Process output - same word-aware wrapping
+      output_lines = []
+      for line in output.split('\n'):
+        words = line.split()
+        current_line = []
+        current_length = 0
+
+        for word in words:
+          if current_length + len(word) + 1 <= max_width:
+            current_line.append(word)
+            current_length += len(word) + 1
+          else:
+            if current_line:
+              output_lines.append(' '.join(current_line))
+            current_line = [word]
+            current_length = len(word)
+
+        if current_line:
+          output_lines.append(' '.join(current_line))
+
+      if len(output_lines) > max_output_lines:
+        output_lines = output_lines[:max_output_lines - 1] + ['...']
 
-      # Process output
-      output_lines = output.split('\n')
-      remaining_lines = max_lines - len(prompt_lines) - 2  # -2 for spacing
-      if len(output_lines) > remaining_lines:
-        output_lines = output_lines[:remaining_lines - 1] + ['...']
       output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
-      output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
+      output_text.append('\n'.join(output_lines), style="white")
 
       content.append(prompt_text)
       content.append(output_text)
@@ -119,8 +164,8 @@ class TopologyViz:
       Group(*content),
       title="",
       border_style="cyan",
-      height=15,  # Increased height to accommodate multiple lines
-      expand=True  # Allow the panel to expand to full width
+      height=panel_height,
+      expand=True
     )
 
   def _generate_main_layout(self) -> str:

+ 1 - 1
install.sh

@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
 
 if command -v python3.12 &>/dev/null; then
     echo "Python 3.12 is installed, proceeding with python3.12..."

+ 1 - 1
scripts/compile_grpc.sh

@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
 source ./install.sh
 pushd exo/networking/grpc
 python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto

+ 1 - 1
test/reconnect.sh

@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
 
 echo "Starting node 1"
 DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &

+ 1 - 1
test/test_tokenizers.py

@@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
     strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
     assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
 
-ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"]
+ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit"]
 ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
 models = []
 for model_id in model_cards:

Nem az összes módosított fájl került megjelenítésre, mert túl sok fájl változott