Kaynağa Gözat

Merge pull request #474 from pranav4501/stable-stable-diffusion-mlx

Stable diffusion mlx
Alex Cheema 3 ay önce
ebeveyn
işleme
b5cbcbc7a2

+ 2 - 0
.gitignore

@@ -171,3 +171,5 @@ cython_debug/
 
 **/*.xcodeproj/*
 .aider*
+
+exo/tinychat/images/*.png

+ 105 - 1
exo/api/chatgpt_api.py

@@ -12,11 +12,17 @@ import traceback
 import signal
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict, shutdown
+from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable, Optional
+from PIL import Image
+import numpy as np
+import base64
+from io import BytesIO
+import mlx.core as mx
+import tempfile
 from exo.download.hf.hf_shard_download import HFShardDownloader
 import shutil
 from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
@@ -185,6 +191,7 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
     cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
+    cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
@@ -195,10 +202,12 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
     cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
 
+      
     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
       self.app.router.add_get("/", self.handle_root)
       self.app.router.add_static("/", self.static_dir, name="static")
+      self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
@@ -457,6 +466,85 @@ class ChatGPTAPI:
       deregistered_callback = self.node.on_token.deregister(callback_id)
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
 
+  
+  async def handle_post_image_generations(self, request):
+    data = await request.json()
+
+    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
+    stream = data.get("stream", False)
+    model = data.get("model", "")
+    prompt = data.get("prompt", "")
+    image_url = data.get("image_url", "")
+    if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")
+    shard = build_base_shard(model, self.inference_engine_classname)
+    if DEBUG >= 2: print(f"shard: {shard}")
+    if not shard:
+        return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
+
+    request_id = str(uuid.uuid4())
+    callback_id = f"chatgpt-api-wait-response-{request_id}"
+    callback = self.node.on_token.register(callback_id)
+    try:
+      if image_url != "" and image_url != None:
+        img = self.base64_decode(image_url)
+      else:
+        img = None
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
+
+
+      response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
+      await response.prepare(request)
+
+      def get_progress_bar(current_step, total_steps, bar_length=50):
+        # Calculate the percentage of completion
+        percent = float(current_step) / total_steps
+        # Calculate the number of hashes to display
+        arrow = '-' * int(round(percent * bar_length) - 1) + '>'
+        spaces = ' ' * (bar_length - len(arrow))
+        
+        # Create the progress bar string
+        progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
+        return progress_bar
+
+      async def stream_image(_request_id: str, result, is_finished: bool):
+          if isinstance(result, list):
+              await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
+
+          elif isinstance(result, np.ndarray):
+            im = Image.fromarray(np.array(result))
+            images_folder = get_exo_images_dir()
+            # Save the image to a file
+            image_filename = f"{_request_id}.png"
+            image_path = images_folder / image_filename
+            im.save(image_path)
+            image_url = request.app.router['static_images'].url_for(filename=image_filename)
+            base_url = f"{request.scheme}://{request.host}"
+            # Construct the full URL correctly
+            full_image_url = base_url + str(image_url)
+            
+            await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+            if is_finished:
+              await response.write_eof()
+              
+
+      stream_task = None
+      def on_result(_request_id: str, result, is_finished: bool):
+          nonlocal stream_task
+          stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
+          return _request_id == request_id and is_finished
+
+      await callback.wait(on_result, timeout=self.response_timeout*10)
+      
+      if stream_task:
+          # Wait for the stream task to complete before returning
+          await stream_task
+
+      return response
+
+    except Exception as e:
+        if DEBUG >= 2: traceback.print_exc()
+        return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
+  
   async def handle_delete_model(self, request):
     try:
       model_name = request.match_info.get('model_name')
@@ -598,3 +686,19 @@ class ChatGPTAPI:
     await runner.setup()
     site = web.TCPSite(runner, host, port)
     await site.start()
+
+  def base64_decode(self, base64_string):
+    #decode and reshape image
+    if base64_string.startswith('data:image'):
+        base64_string = base64_string.split(',')[1]
+    image_data = base64.b64decode(base64_string)
+    img = Image.open(BytesIO(image_data))
+    W, H = (dim - dim % 64 for dim in (img.width, img.height))
+    if W != img.width or H != img.height:
+        if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
+        img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
+    img = mx.array(np.array(img))
+    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
+    img = img[None]
+    return img
+  

+ 4 - 0
exo/download/hf/hf_helpers.py

@@ -303,6 +303,10 @@ async def download_repo_files(
         await f.write(json.dumps(file_list))
       if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
 
+    model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
+    if model_index_exists:
+      allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
+
     filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
     total_files = len(filtered_file_list)
     total_bytes = sum(file["size"] for file in filtered_file_list)

+ 11 - 7
exo/download/hf/hf_shard_download.py

@@ -104,15 +104,19 @@ class HFShardDownloader(ShardDownloader):
           print(f"No snapshot directory found for {self.current_repo_id}")
         return None
 
+      if not await aios.path.exists(snapshot_dir/"model_index.json"):
       # Get the weight map to know what files we need
-      weight_map = await get_weight_map(self.current_repo_id, self.revision)
-      if not weight_map:
-        if DEBUG >= 2:
-          print(f"No weight map found for {self.current_repo_id}")
-        return None
+        weight_map = await get_weight_map(self.current_repo_id, self.revision)
+        if not weight_map:
+          if DEBUG >= 2:
+            print(f"No weight map found for {self.current_repo_id}")
+          return None
+
+        # Get all files needed for this shard
+        patterns = get_allow_patterns(weight_map, self.current_shard)
+      else:
+        patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
 
-      # Get all files needed for this shard
-      patterns = get_allow_patterns(weight_map, self.current_shard)
 
       # Check download status for all relevant files
       status = {}

+ 20 - 1
exo/helpers.py

@@ -325,4 +325,23 @@ async def shutdown(signal, loop, server):
 def is_frozen():
   return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
     or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
-    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
+
+
+def get_exo_home() -> Path:
+  if os.name == "nt":  # Check if the OS is Windows
+    docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
+  else:
+    docs_folder = Path.home() / "Documents"
+  exo_folder = docs_folder / "Exo"
+  if not exo_folder.exists():
+    exo_folder.mkdir()
+  return exo_folder
+
+def get_exo_images_dir() -> Path:
+  exo_home = get_exo_home()
+  images_dir = exo_home / "Images"
+  if not images_dir.exists():
+    images_dir.mkdir()
+  return images_dir
+  

+ 8 - 4
exo/inference/inference_engine.py

@@ -39,11 +39,15 @@ class InferenceEngine(ABC):
   async def clear_session(self):
     self.session.empty()
   
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> np.ndarray:
     tokens = await self.encode(shard, prompt)
-    x = tokens.reshape(1, -1)
-    output_data = await self.infer_tensor(request_id, shard, x)
-    return output_data 
+    if shard.model_id != 'stable-diffusion-2-1-base':
+      x = tokens.reshape(1, -1)
+    else:
+      x = tokens
+    output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
+
+    return output_data, inference_state
 
 inference_engine_classes = {
   "mlx": "MLXDynamicShardInferenceEngine",

+ 307 - 0
exo/inference/mlx/models/StableDiffusionPipeline.py

@@ -0,0 +1,307 @@
+# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/__init__.py
+
+import time
+from typing import Optional, Tuple
+import inspect
+
+import mlx.core as mx
+import mlx.nn as nn
+from pathlib import Path
+
+from tqdm import tqdm
+
+from .sd_models.vae import ModelArgs as VAEArgs
+from .sd_models.vae import Autoencoder
+from .sd_models.tokenizer import load_tokenizer
+from .sd_models.clip import CLIPTextModel
+from .sd_models.clip import ModelArgs as CLIPArgs
+from .sd_models.unet import UNetConfig, UNetModel
+
+from dataclasses import dataclass, field
+from exo.inference.shard import Shard
+
+@dataclass
+class DiffusionConfig:
+    beta_schedule: str = "scaled_linear"
+    beta_start: float = 0.00085
+    beta_end: float = 0.012
+    num_train_steps: int = 1000
+
+    @classmethod
+    def from_dict(cls, params):
+        return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+
+#Sampler
+def _linspace(a, b, num):
+    x = mx.arange(0, num) / (num - 1)
+    return (b - a) * x + a
+
+
+def _interp(y, x_new):
+    """Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
+    x_low = x_new.astype(mx.int32)
+    x_high = mx.minimum(x_low + 1, len(y) - 1)
+
+    y_low = y[x_low]
+    y_high = y[x_high]
+    delta_x = x_new - x_low
+    y_new = y_low * (1 - delta_x) + delta_x * y_high
+
+    return y_new
+
+class SimpleEulerSampler:
+    """A simple Euler integrator that can be used to sample from our diffusion models.
+
+    The method ``step()`` performs one Euler step from x_t to x_t_prev.
+    """
+
+    def __init__(self, config: DiffusionConfig):
+        # Compute the noise schedule
+        if config.beta_schedule == "linear":
+            betas = _linspace(
+                config.beta_start, config.beta_end, config.num_train_steps
+            )
+        elif config.beta_schedule == "scaled_linear":
+            betas = _linspace(
+                config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
+            ).square()
+        else:
+            raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
+
+        alphas = 1 - betas
+        alphas_cumprod = mx.cumprod(alphas)
+
+        self._sigmas = mx.concatenate(
+            [mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
+        )
+
+    @property
+    def max_time(self):
+        return len(self._sigmas) - 1
+
+    def sample_prior(self, shape, dtype=mx.float32, key=None):
+        noise = mx.random.normal(shape, key=key)
+        return (
+            noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
+        ).astype(dtype)
+
+    def add_noise(self, x, t, key=None):
+        noise = mx.random.normal(x.shape, key=key)
+        s = self.sigmas(t)
+        return (x + noise * s) * (s.square() + 1).rsqrt()
+
+    def sigmas(self, t):
+        return _interp(self._sigmas, t)
+
+    def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
+        start_time = start_time or (len(self._sigmas) - 1)
+        assert 0 < start_time <= (len(self._sigmas) - 1)
+        steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
+        return list(zip(steps, steps[1:]))
+
+    def current_timestep(self, step, total_steps, start_time=None):
+        if step < total_steps:
+            steps = self.timesteps(total_steps, start_time)
+            return steps[step]
+        else:
+            return mx.array(0),mx.array(0)
+
+    def step(self, eps_pred, x_t, t, t_prev):
+        sigma = self.sigmas(t).astype(eps_pred.dtype)
+        sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
+
+        dt = sigma_prev - sigma
+        x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
+
+        x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
+
+        return x_t_prev
+
+@dataclass
+class ShardConfig:
+    model_id:str
+    start_layer:int
+    end_layer:int
+    n_layers:int
+
+@dataclass
+class StableDiffusionConfig:
+    model_type:str
+    vae:VAEArgs
+    text_encoder:CLIPArgs
+    scheduler:DiffusionConfig
+    unet:UNetConfig
+    shard:ShardConfig
+    
+    @classmethod
+    def from_dict(cls, params):
+        return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+@dataclass
+class ModelArgs(StableDiffusionConfig):
+    shard:Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+    def __post_init__(self):
+        if isinstance(self.shard, dict):
+            self.shard = Shard(**self.shard)
+
+        if not isinstance(self.shard, Shard):
+            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+
+class Model(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.model_type = config.model_type
+        self.config = config
+        self.model_path = config.vae['path'].split('/vae')[0]
+        self.shard = config.shard
+        self.shard_clip, self.shard_encoder, self.shard_unet, self.shard_decoder  = model_shards(config.shard)
+        self.config_clip=CLIPArgs.from_dict(config.text_encoder['config'])
+        if self.shard_clip.start_layer != -1:
+            self.text_encoder = CLIPTextModel(self.config_clip, shard=self.shard_clip)
+        else:
+            self.text_encoder = nn.Identity()    
+        self.tokenizer = load_tokenizer(Path(self.model_path), "vocab.json", "merges.txt")
+        self.diffusion_config = DiffusionConfig.from_dict(config.scheduler['config'])
+        self.sampler = SimpleEulerSampler(self.diffusion_config)
+        if self.shard_unet.start_layer!=-1:
+            self.config_unet = UNetConfig.from_dict(config.unet['config'])
+            self.unet = UNetModel(self.config_unet, self.shard_unet)
+        else:
+            self.unet = nn.Identity()
+        self.config_vae=VAEArgs.from_dict(config.vae['config'])
+        if self.shard_encoder.start_layer != -1:
+            self.encoder=Autoencoder(self.config_vae, self.shard_encoder, "vae_encoder") 
+        else:
+            self.encoder = nn.Identity()            
+        if self.shard_decoder.start_layer != -1:
+            self.decoder=Autoencoder(self.config_vae, self.shard_decoder, "vae_decoder") 
+        else:
+            self.decoder = nn.Identity()            
+
+    def __call__(self,x, step= 0, cfg_weight: float = 7.5,total_steps=50,conditioning=None,mask=None,residual=None,x_t_prev=None,is_finished=False,is_step_finished=False, image=None, strength=0.65, start_step=None):
+        t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
+        is_finished = False
+        is_step_finished = False
+        if t.item()==1000:
+            if self.shard_clip.start_layer == 0:
+                conditioning = x
+            if self.shard_clip.start_layer != -1:
+                conditioning, mask= self.text_encoder(conditioning,mask)
+            seed = int(time.time()) 
+            mx.random.seed(seed)
+            if image is None:
+                if self.shard_encoder.is_last_layer():
+                    x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
+                    x_t_prev=x
+                    start_step = self.sampler.max_time
+            else:
+                if self.shard_encoder.start_layer != -1:
+                    image= self.encoder.encode(image)
+                    if self.shard_encoder.is_last_layer():
+                        start_step = self.sampler.max_time*strength
+                        total_steps = int(total_steps*strength)
+                        image = mx.broadcast_to(image, (1,) + image.shape[1:])
+                        x_t_prev=self.sampler.add_noise(image, mx.array(start_step))
+                        image = None
+                        t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
+        # Perform the denoising loop
+        if self.shard_unet.start_layer != -1:
+            with tqdm(total=total_steps,initial=step+1) as pbar:
+                if step<total_steps:
+                    x = x_t_prev
+                    if self.shard_unet.is_first_layer():
+                        x_t_unet = mx.concatenate([x] * 2, axis=0) if cfg_weight> 1 else x
+                    else:
+                        x_t_unet = x
+                    t_unet = mx.broadcast_to(t, [len(x_t_unet)])
+                    x, residual= self.unet(x_t_unet, t_unet, encoder_x=conditioning, residuals=residual)
+                    if self.shard_unet.is_last_layer():
+                        if cfg_weight > 1:
+                            eps_text, eps_neg = x.split(2)
+                            eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
+                        x = self.sampler.step(eps_pred, x_t_prev, t, t_prev)
+                        x_t_prev=x
+                    mx.eval(x)
+                    
+        if self.shard_decoder.is_last_layer():
+            is_step_finished=True
+            if self.shard_decoder.start_layer != -1:
+                x=self.decoder.decode(x)
+            if self.shard_decoder.is_last_layer():
+                x = mx.clip(x / 2 + 0.5, 0, 1)
+                B, H, W, C = x.shape
+                x = x.reshape(1, B // 1, H, W, C).transpose(0, 2, 1, 3, 4)
+                x = x.reshape(1 * H, B // 1 * W, C)
+                x = (x * 255).astype(mx.uint8)
+                if t_prev.item() ==0:
+                    is_finished=True   
+        mx.eval(x)
+         
+        return x, {'conditioning':conditioning, 'mask':mask,'residual':residual,'x_t_prev':x_t_prev,'is_finished':is_finished,'is_step_finished':is_step_finished, 'step':step, 'total_steps':total_steps, 'start_step':start_step, 'image':image}
+    
+
+    def load(self):
+        if self.shard_encoder.start_layer != -1:    
+            vae_weights =  mx.load(self.config_vae.weight_files[0])
+            vae_weights = self.encoder.sanitize(vae_weights)
+            self.encoder.load_weights(list(vae_weights.items()), strict=True)
+        if self.shard_decoder.start_layer != -1:
+            vae_weights =  mx.load(self.config_vae.weight_files[0])
+            vae_weights = self.decoder.sanitize(vae_weights)
+            self.decoder.load_weights(list(vae_weights.items()), strict=True)
+        if self.shard_clip.start_layer != -1:
+            clip_weights = mx.load(self.config_clip.weight_files[0])
+            clip_weights = self.text_encoder.sanitize(clip_weights)
+            self.text_encoder.load_weights(list(clip_weights.items()), strict=True)
+        if self.shard_unet.start_layer !=-1:
+            unet_weights = mx.load(self.config_unet.weight_files[0])
+            unet_weights = self.unet.sanitize(unet_weights)
+            self.unet.load_weights(list(unet_weights.items()), strict=True)
+
+def model_shards(shard:ShardConfig):
+    def create_shard(shard, model_ranges):
+        start_layer = shard.start_layer
+        end_layer = shard.end_layer
+        
+        shards = {}
+        
+        for model_name, (range_start, range_end) in model_ranges.items():
+            if start_layer < range_end and end_layer >= range_start:
+                # Calculate the overlap with the model range
+                overlap_start = max(start_layer, range_start)
+                overlap_end = min(end_layer, range_end - 1)
+
+                # Adjust the layers relative to the model's range
+                relative_start = overlap_start - range_start
+                relative_end = overlap_end - range_start
+                shards[model_name] = Shard(model_name, relative_start, relative_end, range_end - range_start)
+            else:
+                # If no overlap, create a zero-layer shard
+                shards[model_name] = Shard(model_name, -1, -1, range_end - range_start)
+        
+        return shards
+
+    # Define the ranges for different models
+    model_ranges = {
+        'clip': (0, 12),
+        'vae_encoder':(12,17),
+        'unet':(17,26),
+        'vae_decoder': (26, 31) # Example range for unet
+    }
+
+    # Call the function and get the shards for all models
+    shards = create_shard(shard, model_ranges)
+
+    # Access individual shards
+    shard_clip = shards['clip']
+    shard_encoder = shards['vae_encoder']
+    shard_unet = shards['unet']
+    shard_decoder = shards['vae_decoder']
+    
+    return shard_clip, shard_encoder, shard_unet, shard_decoder
+
+
+

+ 191 - 0
exo/inference/mlx/models/sd_models/clip.py

@@ -0,0 +1,191 @@
+# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from dataclasses import field, dataclass
+from exo.inference.shard import Shard
+from exo.inference.mlx.models.base import IdentityBlock 
+
+_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
+
+
+
+@dataclass
+class CLIPTextModelConfig:
+    num_layers: int = 23
+    model_dims: int = 1024
+    num_heads: int = 16
+    max_length: int = 77
+    vocab_size: int = 49408
+    projection_dim: Optional[int] = None
+    hidden_act: str = "quick_gelu"
+
+    @classmethod
+    def from_dict(cls, config):
+        return ModelArgs(
+            num_layers=config["num_hidden_layers"],
+            model_dims=config["hidden_size"],
+            num_heads=config["num_attention_heads"],
+            max_length=config["max_position_embeddings"],
+            vocab_size=config["vocab_size"],
+            projection_dim=config["projection_dim"] if "WithProjection" in config['architectures'][0] else None,
+            hidden_act=config.get("hidden_act", "quick_gelu"),
+            weight_files=config.get("weight_files", [])
+            )
+
+@dataclass
+class ModelArgs(CLIPTextModelConfig):
+    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+    weight_files: List[str] = field(default_factory=lambda: [])
+    def __post_init__(self):
+        if isinstance(self.shard, dict):
+            self.shard = Shard(**self.shard)
+
+        if not isinstance(self.shard, Shard):
+            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+        if not self.shard.is_first_layer():
+            self.vision_config = None
+
+
+@dataclass
+class CLIPOutput:
+    pooled_output: Optional[mx.array] = None
+    last_hidden_state: Optional[mx.array] = None
+    hidden_states: Optional[List[mx.array]] = None
+
+
+class CLIPEncoderLayer(nn.Module):
+    """The transformer encoder layer from CLIP."""
+
+    def __init__(self, model_dims: int, num_heads: int, activation: str):
+        super().__init__()
+
+        self.layer_norm1 = nn.LayerNorm(model_dims)
+        self.layer_norm2 = nn.LayerNorm(model_dims)
+
+        self.attention = nn.MultiHeadAttention(model_dims, num_heads)
+        self.attention.query_proj.bias = mx.zeros(model_dims)
+        self.attention.key_proj.bias = mx.zeros(model_dims)
+        self.attention.value_proj.bias = mx.zeros(model_dims)
+        self.attention.out_proj.bias = mx.zeros(model_dims)
+
+        self.linear1 = nn.Linear(model_dims, 4 * model_dims)
+        self.linear2 = nn.Linear(4 * model_dims, model_dims)
+
+        self.act = _ACTIVATIONS[activation]
+
+    def __call__(self, x, attn_mask=None):
+        
+        y = self.layer_norm1(x)
+        y = self.attention(y, y, y, attn_mask)
+        x = y + x
+        
+        y = self.layer_norm2(x)
+        y = self.linear1(y)
+        y = self.act(y)
+        y = self.linear2(y)
+        x = y + x
+        return x
+
+
+class CLIPTextModel(nn.Module):
+    """Implements the text encoder transformer from CLIP."""
+
+    def __init__(self, config: CLIPTextModelConfig, shard: Shard):
+        super().__init__()
+
+        self.shard = shard
+        self.layers_range = range(self.shard.start_layer*2, self.shard.end_layer*2+2) 
+        if self.shard.is_first_layer():
+            self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
+            self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
+        self.layers = []
+        for i in range(math.ceil(config.num_layers/2)):
+            if  2*i in self.layers_range:
+                self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
+            if 2*i+1 in self.layers_range and 2*i+1 < config.num_layers:
+                self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
+            else:
+                self.layers.append(IdentityBlock())
+        if self.shard.is_last_layer():
+            self.final_layer_norm = nn.LayerNorm(config.model_dims)
+
+        if config.projection_dim is not None:
+            self.text_projection = nn.Linear(
+                config.model_dims, config.projection_dim, bias=False
+            )
+
+    def _get_mask(self, N, dtype):
+        indices = mx.arange(N)
+        mask = indices[:, None] < indices[None]
+        mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
+        return mask
+
+    def __call__(self, x, mask=None):
+        # Extract some shapes
+        if self.shard.is_first_layer():
+            B, N = x.shape
+            eos_tokens = x.argmax(-1)
+            
+            # Compute the embeddings
+            x = self.token_embedding(x)
+           
+            x = x + self.position_embedding.weight[:N]
+            # Compute the features from the transformer
+            mask = self._get_mask(N, x.dtype)
+        
+        for l in self.layers:
+            x = l(x, mask)
+        # Apply the final layernorm and return
+        
+        if self.shard.is_last_layer():
+            x = self.final_layer_norm(x)
+        
+       
+
+        return x, mask
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for key, value in weights.items():
+            if "position_ids" in key:
+                continue
+            if key.startswith("text_model."):
+                key = key[11:]
+            if key.startswith("embeddings."):
+                key = key[11:]
+            if key.startswith("encoder."):
+                key = key[8:]
+
+            # Map attention layers
+            if "self_attn." in key:
+                key = key.replace("self_attn.", "attention.")
+            if "q_proj." in key:
+                key = key.replace("q_proj.", "query_proj.")
+            if "k_proj." in key:
+                key = key.replace("k_proj.", "key_proj.")
+            if "v_proj." in key:
+                key = key.replace("v_proj.", "value_proj.")
+
+            # Map ffn layers
+            if "mlp.fc1" in key:
+                key = key.replace("mlp.fc1", "linear1")
+            if "mlp.fc2" in key:
+                key = key.replace("mlp.fc2", "linear2")
+            
+            if key.startswith("layers."):
+                layer_num = int(key.split(".")[1])
+                if layer_num not in self.layers_range:
+                    continue
+            if not self.shard.is_first_layer() and "embedding" in key:
+                continue
+            if not self.shard.is_last_layer() and key.startswith("final_layer_norm"):
+                continue
+            if not self.shard.is_last_layer() and key.startswith("text_projection"):
+                continue
+            sanitized_weights[key] = value
+        return sanitized_weights

+ 131 - 0
exo/inference/mlx/models/sd_models/tokenizer.py

@@ -0,0 +1,131 @@
+# adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py
+
+import regex
+import json
+import glob
+
+
+class Tokenizer:
+    """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
+
+    def __init__(self, bpe_ranks, vocab):
+        self.bpe_ranks = bpe_ranks
+        self.vocab = vocab
+        self.pat = regex.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            regex.IGNORECASE,
+        )
+
+        self._cache = {self.bos: self.bos, self.eos: self.eos}
+
+    @property
+    def bos(self):
+        return "<|startoftext|>"
+
+    @property
+    def bos_token(self):
+        return self.vocab[self.bos]
+
+    @property
+    def eos(self):
+        return "<|endoftext|>"
+
+    @property
+    def eos_token(self):
+        return self.vocab[self.eos]
+
+    def bpe(self, text):
+        if text in self._cache:
+            return self._cache[text]
+
+        unigrams = list(text[:-1]) + [text[-1] + "</w>"]
+        unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+        if not unique_bigrams:
+            return unigrams
+
+        # In every iteration try to merge the two most likely bigrams. If none
+        # was merged we are done.
+        #
+        # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
+        while unique_bigrams:
+            bigram = min(
+                unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
+            )
+            if bigram not in self.bpe_ranks:
+                break
+
+            new_unigrams = []
+            skip = False
+            for a, b in zip(unigrams, unigrams[1:]):
+                if skip:
+                    skip = False
+                    continue
+
+                if (a, b) == bigram:
+                    new_unigrams.append(a + b)
+                    skip = True
+
+                else:
+                    new_unigrams.append(a)
+
+            if not skip:
+                new_unigrams.append(b)
+
+            unigrams = new_unigrams
+            unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+        self._cache[text] = unigrams
+
+        return unigrams
+
+    def tokenize(self, text, prepend_bos=True, append_eos=True):
+        if isinstance(text, list):
+            return [self.tokenize(t, prepend_bos, append_eos) for t in text]
+
+        # Lower case cleanup and split according to self.pat. Hugging Face does
+        # a much more thorough job here but this should suffice for 95% of
+        # cases.
+        clean_text = regex.sub(r"\s+", " ", text.lower())
+        tokens = regex.findall(self.pat, clean_text)
+
+        # Split the tokens according to the byte-pair merge file
+        bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
+
+        # Map to token ids and return
+        tokens = [self.vocab[t] for t in bpe_tokens]
+        if prepend_bos:
+            tokens = [self.bos_token] + tokens
+        if append_eos:
+            tokens.append(self.eos_token)
+
+        return tokens
+    
+    def encode(self, prompt):
+        tokens = [self.tokenize(prompt)]
+        negative_text = ""
+        if negative_text is not None:
+            tokens += [self.tokenize(negative_text)]
+        lengths = [len(t) for t in tokens]
+        N = max(lengths)
+        tokens = [t + [0] * (N - len(t)) for t in tokens]
+        return tokens
+
+def load_tokenizer(
+    model_path: str,
+    vocab_key: str = "tokenizer_vocab",
+    merges_key: str = "tokenizer_merges",
+):
+
+    vocab_file = glob.glob(str(model_path/"tokenizer"/vocab_key))[0]
+    with open(vocab_file, encoding="utf-8") as f:
+        vocab = json.load(f)
+
+    merges_file = glob.glob(str(model_path/"tokenizer"/merges_key))[0]
+    with open(merges_file, encoding="utf-8") as f:
+        bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
+    bpe_merges = [tuple(m.split()) for m in bpe_merges]
+    bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
+
+    return Tokenizer(bpe_ranks, vocab)
+

+ 629 - 0
exo/inference/mlx/models/sd_models/unet.py

@@ -0,0 +1,629 @@
+# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
+
+import math
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from dataclasses import dataclass, field
+from typing import Tuple, Optional, List
+from exo.inference.shard import Shard
+
+@dataclass
+class UNetConfig:
+    in_channels: int = 4
+    out_channels: int = 4
+    conv_in_kernel: int = 3
+    conv_out_kernel: int = 3
+    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
+    layers_per_block: Tuple[int] = (2, 2, 2, 2)
+    mid_block_layers: int = 2
+    transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
+    num_attention_heads: Tuple[int] = (5, 10, 20, 20)
+    cross_attention_dim: Tuple[int] = (1024,) * 4
+    norm_num_groups: int = 32
+    down_block_types: Tuple[str] = (
+        "CrossAttnDownBlock2D",
+        "CrossAttnDownBlock2D",
+        "CrossAttnDownBlock2D",
+        "DownBlock2D",
+    )
+    up_block_types: Tuple[str] = (
+        "UpBlock2D",
+        "CrossAttnUpBlock2D",
+        "CrossAttnUpBlock2D",
+        "CrossAttnUpBlock2D",
+    )
+    addition_embed_type: Optional[str] = None
+    addition_time_embed_dim: Optional[int] = None
+    projection_class_embeddings_input_dim: Optional[int] = None
+    weight_files: List[str] = field(default_factory=lambda: [])
+
+
+
+    @classmethod
+    def from_dict(cls,config):
+        n_blocks = len(config['block_out_channels'])
+        return UNetConfig(
+            in_channels=config["in_channels"],
+            out_channels=config["out_channels"],
+            block_out_channels=config["block_out_channels"],
+            layers_per_block=[config["layers_per_block"]] * n_blocks,
+            transformer_layers_per_block=config.get(
+                "transformer_layers_per_block", (1,) * 4
+            ),
+            num_attention_heads=(
+                [config["attention_head_dim"]] * n_blocks
+                if isinstance(config["attention_head_dim"], int)
+                else config["attention_head_dim"]
+            ),
+            cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
+            norm_num_groups=config["norm_num_groups"],
+            down_block_types=config["down_block_types"],
+            up_block_types=config["up_block_types"][::-1],
+            addition_embed_type=config.get("addition_embed_type", None),
+            addition_time_embed_dim=config.get("addition_time_embed_dim", None),
+            projection_class_embeddings_input_dim=config.get(
+                "projection_class_embeddings_input_dim", None
+            ),
+            weight_files=config.get("weight_files", [])
+
+        )
+
+
+def upsample_nearest(x, scale: int = 2):
+    B, H, W, C = x.shape
+    x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
+    x = x.reshape(B, H * scale, W * scale, C)
+
+    return x
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(self, in_channels: int, time_embed_dim: int):
+        super().__init__()
+
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+
+    def __call__(self, x):
+        x = self.linear_1(x)
+        x = nn.silu(x)
+        x = self.linear_2(x)
+
+        return x
+
+
+class TransformerBlock(nn.Module):
+    def __init__(
+        self,
+        model_dims: int,
+        num_heads: int,
+        hidden_dims: Optional[int] = None,
+        memory_dims: Optional[int] = None,
+    ):
+        super().__init__()
+
+        self.norm1 = nn.LayerNorm(model_dims)
+        self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
+        self.attn1.out_proj.bias = mx.zeros(model_dims)
+
+        memory_dims = memory_dims or model_dims
+        self.norm2 = nn.LayerNorm(model_dims)
+        self.attn2 = nn.MultiHeadAttention(
+            model_dims, num_heads, key_input_dims=memory_dims
+        )
+        self.attn2.out_proj.bias = mx.zeros(model_dims)
+
+        hidden_dims = hidden_dims or 4 * model_dims
+        self.norm3 = nn.LayerNorm(model_dims)
+        self.linear1 = nn.Linear(model_dims, hidden_dims)
+        self.linear2 = nn.Linear(model_dims, hidden_dims)
+        self.linear3 = nn.Linear(hidden_dims, model_dims)
+
+    def __call__(self, x, memory, attn_mask, memory_mask):
+        # Self attention
+        y = self.norm1(x)
+        y = self.attn1(y, y, y, attn_mask)
+        x = x + y
+
+        # Cross attention
+        y = self.norm2(x)
+        y = self.attn2(y, memory, memory, memory_mask)
+        x = x + y
+
+        # FFN
+        y = self.norm3(x)
+        y_a = self.linear1(y)
+        y_b = self.linear2(y)
+        y = y_a * nn.gelu(y_b)
+        y = self.linear3(y)
+        x = x + y
+
+        return x
+
+
+class Transformer2D(nn.Module):
+    """A transformer model for inputs with 2 spatial dimensions."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        model_dims: int,
+        encoder_dims: int,
+        num_heads: int,
+        num_layers: int = 1,
+        norm_num_groups: int = 32,
+    ):
+        super().__init__()
+
+        self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
+        self.proj_in = nn.Linear(in_channels, model_dims)
+        self.transformer_blocks = [
+            TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
+            for i in range(num_layers)
+        ]
+        self.proj_out = nn.Linear(model_dims, in_channels)
+
+    def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
+        # Save the input to add to the output
+        input_x = x
+        dtype = x.dtype
+
+        # Perform the input norm and projection
+        B, H, W, C = x.shape
+        x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
+        x = self.proj_in(x)
+
+        # Apply the transformer
+        for block in self.transformer_blocks:
+            x = block(x, encoder_x, attn_mask, encoder_attn_mask)
+
+        # Apply the output projection and reshape
+        x = self.proj_out(x)
+        x = x.reshape(B, H, W, C)
+
+        return x + input_x
+
+
+class ResnetBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        groups: int = 32,
+        temb_channels: Optional[int] = None,
+    ):
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+
+        self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
+        self.conv1 = nn.Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        if temb_channels is not None:
+            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
+        self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
+        self.conv2 = nn.Conv2d(
+            out_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+
+        if in_channels != out_channels:
+            self.conv_shortcut = nn.Linear(in_channels, out_channels)
+
+    def __call__(self, x, temb=None):
+        dtype = x.dtype
+
+        if temb is not None:
+            temb = self.time_emb_proj(nn.silu(temb))
+        y = self.norm1(x.astype(mx.float32)).astype(dtype)
+        
+        y = nn.silu(y)
+        
+        y = self.conv1(y)
+
+        
+        if temb is not None:
+            y = y + temb[:, None, None, :]
+        y = self.norm2(y.astype(mx.float32)).astype(dtype)
+        y = nn.silu(y)
+        y = self.conv2(y)
+        
+        x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
+        return x
+
+
+class UNetBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        prev_out_channels: Optional[int] = None,
+        num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
+        num_attention_heads: int = 8,
+        cross_attention_dim=1280,
+        resnet_groups: int = 32,
+        add_downsample=True,
+        add_upsample=True,
+        add_cross_attention=True,
+    ):
+        super().__init__()
+
+        # Prepare the in channels list for the resnets
+        if prev_out_channels is None:
+            in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
+        else:
+            in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
+            res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
+            in_channels_list = [
+                a + b for a, b in zip(in_channels_list, res_channels_list)
+            ]
+
+        # Add resnet blocks that also process the time embedding
+        self.resnets = [
+            ResnetBlock2D(
+                in_channels=ic,
+                out_channels=out_channels,
+                temb_channels=temb_channels,
+                groups=resnet_groups,
+            )
+            for ic in in_channels_list
+        ]
+
+        # Add optional cross attention layers
+        if add_cross_attention:
+            self.attentions = [
+                Transformer2D(
+                    in_channels=out_channels,
+                    model_dims=out_channels,
+                    num_heads=num_attention_heads,
+                    num_layers=transformer_layers_per_block,
+                    encoder_dims=cross_attention_dim,
+                )
+                for i in range(num_layers)
+            ]
+
+        # Add an optional downsampling layer
+        if add_downsample:
+            self.downsample = nn.Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=2, padding=1
+            )
+
+        # or upsampling layer
+        if add_upsample:
+            self.upsample = nn.Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=1, padding=1
+            )
+
+    def __call__(
+        self,
+        x,
+        encoder_x=None,
+        temb=None,
+        attn_mask=None,
+        encoder_attn_mask=None,
+        residual_hidden_states=None,
+    ):
+        output_states = []
+
+        for i in range(len(self.resnets)):
+            if residual_hidden_states is not None:
+                x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
+
+            x = self.resnets[i](x, temb)
+
+            if "attentions" in self:
+                x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
+
+            output_states.append(x)
+
+        if "downsample" in self:
+            x = self.downsample(x)
+            output_states.append(x)
+
+        if "upsample" in self:
+            x = self.upsample(upsample_nearest(x))
+            output_states.append(x)
+
+        return x, output_states
+
+
+class UNetModel(nn.Module):
+    """The conditional 2D UNet model that actually performs the denoising."""
+
+    def __init__(self, config: UNetConfig, shard: Shard):
+        super().__init__()
+        self.shard = shard
+        self.start_layer = shard.start_layer
+        self.end_layer = shard.end_layer
+        self.layers_range = list(range(self.start_layer, self.end_layer+1))
+        if shard.is_first_layer(): 
+            self.conv_in = nn.Conv2d(
+                config.in_channels,
+                config.block_out_channels[0],
+                config.conv_in_kernel,
+                padding=(config.conv_in_kernel - 1) // 2,
+            )
+
+        self.timesteps = nn.SinusoidalPositionalEncoding(
+            config.block_out_channels[0],
+            max_freq=1,
+            min_freq=math.exp(
+                -math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
+            ),
+            scale=1.0,
+            cos_first=True,
+            full_turns=False,
+        )
+        self.time_embedding = TimestepEmbedding(
+            config.block_out_channels[0],
+            config.block_out_channels[0] * 4,
+        )
+
+        if config.addition_embed_type == "text_time":
+            self.add_time_proj = nn.SinusoidalPositionalEncoding(
+                config.addition_time_embed_dim,
+                max_freq=1,
+                min_freq=math.exp(
+                    -math.log(10000)
+                    + 2 * math.log(10000) / config.addition_time_embed_dim
+                ),
+                scale=1.0,
+                cos_first=True,
+                full_turns=False,
+            )
+            self.add_embedding = TimestepEmbedding(
+                config.projection_class_embeddings_input_dim,
+                config.block_out_channels[0] * 4,
+            )
+
+        # Make the downsampling blocks
+        block_channels = [config.block_out_channels[0]] + list(
+            config.block_out_channels
+        )
+        self.down_blocks = []
+
+        for i, (in_channels, out_channels) in enumerate(zip(block_channels, block_channels[1:])):    
+            if i in self.layers_range:
+                self.down_blocks.append(
+                    UNetBlock2D(
+                        in_channels=in_channels,
+                out_channels=out_channels,
+                temb_channels=config.block_out_channels[0] * 4,
+                num_layers=config.layers_per_block[i],
+                transformer_layers_per_block=config.transformer_layers_per_block[i],
+                num_attention_heads=config.num_attention_heads[i],
+                cross_attention_dim=config.cross_attention_dim[i],
+                resnet_groups=config.norm_num_groups,
+                add_downsample=(i < len(config.block_out_channels) - 1),
+                add_upsample=False,
+                        add_cross_attention="CrossAttn" in config.down_block_types[i],
+                    )
+                )
+            else:
+                self.down_blocks.append(nn.Identity())
+            
+         
+        # Make the middle block
+        if 4 in self.layers_range:
+            self.mid_blocks = [
+                ResnetBlock2D(
+                    in_channels=config.block_out_channels[-1],
+                    out_channels=config.block_out_channels[-1],
+                    temb_channels=config.block_out_channels[0] * 4,
+                    groups=config.norm_num_groups,
+                ),
+                Transformer2D(
+                    in_channels=config.block_out_channels[-1],
+                    model_dims=config.block_out_channels[-1],
+                    num_heads=config.num_attention_heads[-1],
+                    num_layers=config.transformer_layers_per_block[-1],
+                    encoder_dims=config.cross_attention_dim[-1],
+                ),
+                ResnetBlock2D(
+                    in_channels=config.block_out_channels[-1],
+                    out_channels=config.block_out_channels[-1],
+                    temb_channels=config.block_out_channels[0] * 4,
+                    groups=config.norm_num_groups,
+                ),
+            ]
+
+        # Make the upsampling blocks
+        block_channels = (
+            [config.block_out_channels[0]]
+            + list(config.block_out_channels)
+            + [config.block_out_channels[-1]]
+        )
+
+        total_items = len(block_channels) - 3
+        reversed_channels = list(reversed(list(zip(block_channels, block_channels[1:], block_channels[2:]))))
+
+        self.up_blocks = []
+        for rev_i, (in_channels, out_channels, prev_out_channels) in enumerate(reversed_channels):  
+            i = total_items - rev_i 
+            if rev_i+5 in self.layers_range:
+                self.up_blocks.append(
+                    UNetBlock2D(
+                        in_channels=in_channels,
+                out_channels=out_channels,
+                temb_channels=config.block_out_channels[0] * 4,
+                prev_out_channels=prev_out_channels,
+                num_layers=config.layers_per_block[i] + 1,
+                transformer_layers_per_block=config.transformer_layers_per_block[i],
+                num_attention_heads=config.num_attention_heads[i],
+                cross_attention_dim=config.cross_attention_dim[i],
+                resnet_groups=config.norm_num_groups,
+                add_downsample=False,
+                add_upsample=(i > 0),
+                        add_cross_attention="CrossAttn" in config.up_block_types[i],
+                    )
+                )
+            else:
+                self.up_blocks.append(nn.Identity())
+            
+        
+        if shard.is_last_layer():
+            self.conv_norm_out = nn.GroupNorm(
+                config.norm_num_groups,
+                config.block_out_channels[0],
+                pytorch_compatible=True,
+            )
+            self.conv_out = nn.Conv2d(
+                config.block_out_channels[0],
+                config.out_channels,
+                config.conv_out_kernel,
+                padding=(config.conv_out_kernel - 1) // 2,
+            )
+
+    def __call__(
+        self,
+        x,
+        timestep,
+        encoder_x,
+        attn_mask=None,
+        encoder_attn_mask=None,
+        text_time=None,
+        residuals=None,
+    ):
+        # Compute the time embeddings
+        
+        temb = self.timesteps(timestep).astype(x.dtype)
+        temb = self.time_embedding(temb)
+
+        # Add the extra text_time conditioning
+        if text_time is not None:
+            text_emb, time_ids = text_time
+            emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype)
+            emb = mx.concatenate([text_emb, emb], axis=-1)
+            emb = self.add_embedding(emb)
+            temb = temb + emb
+        
+        if self.shard.is_first_layer():
+            # Preprocess the input
+            x = self.conv_in(x)
+            residuals = [x]
+        # Run the downsampling part of the unet
+        
+        for i in range(len(self.down_blocks)):
+            if i in self.layers_range:
+                x, res = self.down_blocks[i](
+                    x,
+                    encoder_x=encoder_x,
+                    temb=temb,
+                    attn_mask=attn_mask,
+                    encoder_attn_mask=encoder_attn_mask,
+                )
+                residuals.extend(res)
+            else:
+                x= self.down_blocks[i](x)
+
+        if 4 in self.layers_range:
+            # Run the middle part of the unet
+            x = self.mid_blocks[0](x, temb)
+            x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
+            x = self.mid_blocks[2](x, temb)
+
+        # Run the upsampling part of the unet
+        for i in range(len(self.up_blocks)):
+            if i+5 in self.layers_range:
+                x, _ = self.up_blocks[i](
+                    x,
+                    encoder_x=encoder_x,
+                    temb=temb,
+                    attn_mask=attn_mask,
+                    encoder_attn_mask=encoder_attn_mask,
+                    residual_hidden_states=residuals,
+                )
+            else:
+                x= self.up_blocks[i](x)
+
+        # Postprocess the output
+        if self.shard.is_last_layer():
+            dtype = x.dtype
+            x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
+            x = nn.silu(x)
+            x = self.conv_out(x)
+
+        return x, residuals
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for key, value in weights.items():
+            k1=""
+            k2=""
+            if "downsamplers" in key:
+                key = key.replace("downsamplers.0.conv", "downsample")
+            if "upsamplers" in key:
+                key = key.replace("upsamplers.0.conv", "upsample")
+
+            # Map the mid block
+            if "mid_block.resnets.0" in key:
+                key = key.replace("mid_block.resnets.0", "mid_blocks.0")
+            if "mid_block.attentions.0" in key:
+                key = key.replace("mid_block.attentions.0", "mid_blocks.1")
+            if "mid_block.resnets.1" in key:
+                key = key.replace("mid_block.resnets.1", "mid_blocks.2")
+
+            # Map attention layers
+            if "to_k" in key:
+                key = key.replace("to_k", "key_proj")
+            if "to_out.0" in key:
+                key = key.replace("to_out.0", "out_proj")
+            if "to_q" in key:
+                key = key.replace("to_q", "query_proj")
+            if "to_v" in key:
+                key = key.replace("to_v", "value_proj")
+
+            # Map transformer ffn
+            if "ff.net.2" in key:
+                key = key.replace("ff.net.2", "linear3")
+            if "ff.net.0" in key:
+                k1 = key.replace("ff.net.0.proj", "linear1")
+                k2 = key.replace("ff.net.0.proj", "linear2")
+                v1, v2 = mx.split(value, 2)
+                
+
+            if "conv_shortcut.weight" in key:
+                value = value.squeeze()
+            
+            # Transform the weights from 1x1 convs to linear
+            if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
+                value = value.squeeze()
+
+            if len(value.shape) == 4:
+                value = value.transpose(0, 2, 3, 1)
+                value = value.reshape(-1).reshape(value.shape)
+
+            if key.startswith("conv_in") :
+                if 0 not in self.layers_range:
+                    continue
+            
+            if key.startswith("down_blocks"):
+                layer_num = int(key.split(".")[1])
+                if layer_num not in self.layers_range:
+                    continue
+
+            if key.startswith("mid_block"):
+                if 4 not in self.layers_range:
+                    continue
+
+            if key.startswith("up_blocks"):
+                layer_num = int(key.split(".")[1])
+                if (layer_num+5) not in self.layers_range:
+                    continue
+            
+            if key.startswith("conv_out") or key.startswith("conv_norm_out"):
+                if 8 not in self.layers_range:
+                    continue
+
+            if len(k1)>0:
+                sanitized_weights[k1] = v1
+                sanitized_weights[k2] = v2
+            else:
+                sanitized_weights[key] = value
+
+
+        return sanitized_weights

+ 429 - 0
exo/inference/mlx/models/sd_models/vae.py

@@ -0,0 +1,429 @@
+# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/vae.py
+
+import math
+from typing import List
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .unet import ResnetBlock2D, upsample_nearest
+from dataclasses import dataclass, field
+from exo.inference.shard import Shard
+from typing import Tuple
+import inspect
+from ..base import IdentityBlock
+
+@dataclass
+class AutoencoderConfig:
+    in_channels: int = 3
+    out_channels: int = 3
+    latent_channels_out: int = 8
+    latent_channels_in: int = 4
+    block_out_channels: Tuple[int] = (128, 256, 512, 512)
+    layers_per_block: int = 2
+    norm_num_groups: int = 32
+    scaling_factor: float = 0.18215
+    weight_files: List[str] = field(default_factory=lambda: [])
+    @classmethod
+    def from_dict(cls, params):
+        return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+
+@dataclass
+class ModelArgs(AutoencoderConfig):
+    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+    def __post_init__(self):
+        if isinstance(self.shard, dict):
+            self.shard = Shard(**self.shard)
+
+        if not isinstance(self.shard, Shard):
+            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+        if not self.shard.is_first_layer():
+            self.vision_config = None
+
+
+class Attention(nn.Module):
+    """A single head unmasked attention for use with the VAE."""
+
+    def __init__(self, dims: int, norm_groups: int = 32):
+        super().__init__()
+
+        self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
+        self.query_proj = nn.Linear(dims, dims)
+        self.key_proj = nn.Linear(dims, dims)
+        self.value_proj = nn.Linear(dims, dims)
+        self.out_proj = nn.Linear(dims, dims)
+
+    def __call__(self, x):
+        B, H, W, C = x.shape
+
+        y = self.group_norm(x)
+
+        queries = self.query_proj(y).reshape(B, H * W, C)
+        keys = self.key_proj(y).reshape(B, H * W, C)
+        values = self.value_proj(y).reshape(B, H * W, C)
+
+        scale = 1 / math.sqrt(queries.shape[-1])
+        scores = (queries * scale) @ keys.transpose(0, 2, 1)
+        attn = mx.softmax(scores, axis=-1)
+        y = (attn @ values).reshape(B, H, W, C)
+
+        y = self.out_proj(y)
+        x = x + y
+
+        return x
+
+
+class EncoderDecoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_layers: int = 1,
+        resnet_groups: int = 32,
+        add_downsample=True,
+        add_upsample=True,
+    ):
+        super().__init__()
+
+        # Add the resnet blocks
+        self.resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels if i == 0 else out_channels,
+                out_channels=out_channels,
+                groups=resnet_groups,
+            )
+            for i in range(num_layers)
+        ]
+
+        # Add an optional downsampling layer
+        if add_downsample:
+            self.downsample = nn.Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=2, padding=0
+            )
+
+        # or upsampling layer
+        if add_upsample:
+            self.upsample = nn.Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=1, padding=1
+            )
+
+    def __call__(self, x):
+        for resnet in self.resnets:
+            x = resnet(x)
+        if "downsample" in self:
+            x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
+            x = self.downsample(x)
+
+        if "upsample" in self:
+            x = self.upsample(upsample_nearest(x))
+        return x
+
+
+class Encoder(nn.Module):
+    """Implements the encoder side of the Autoencoder."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        latent_channels_out: int,
+        block_out_channels: List[int] = [64],
+        layers_per_block: int = 2,
+        resnet_groups: int = 32,
+        layers_range: List[int] = [],
+        shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+    ):
+        super().__init__()
+        self.layers_range = layers_range
+        self.shard = shard
+        if self.shard.is_first_layer():
+            self.conv_in = nn.Conv2d(
+                in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
+            )
+
+        channels = [block_out_channels[0]] + list(block_out_channels)
+        self.down_blocks = []
+        current_layer = 1
+        for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
+            if current_layer in self.layers_range:
+                self.down_blocks.append(
+                    EncoderDecoderBlock2D(
+                        in_channels,
+                        out_channels,
+                        num_layers=layers_per_block,
+                        resnet_groups=resnet_groups,
+                        add_downsample=i < len(block_out_channels) - 1,
+                        add_upsample=False,
+                    )
+                )
+            else:
+                self.down_blocks.append(IdentityBlock())
+            current_layer += 1
+
+        if self.shard.is_last_layer():
+            self.mid_blocks = [
+                ResnetBlock2D(
+                    in_channels=block_out_channels[-1],
+                    out_channels=block_out_channels[-1],
+                    groups=resnet_groups,
+                ),
+                Attention(block_out_channels[-1], resnet_groups),
+                ResnetBlock2D(
+                    in_channels=block_out_channels[-1],
+                    out_channels=block_out_channels[-1],
+                    groups=resnet_groups,
+                ),
+            ]
+
+            self.conv_norm_out = nn.GroupNorm(
+                resnet_groups, block_out_channels[-1], pytorch_compatible=True
+            )
+            self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels_out, 3, padding=1)
+
+    def __call__(self, x):
+        if self.shard.is_first_layer():
+            x = self.conv_in(x)
+
+        for l in self.down_blocks:
+            x = l(x)
+
+        if self.shard.is_last_layer():
+            x = self.mid_blocks[0](x)
+            x = self.mid_blocks[1](x)
+            x = self.mid_blocks[2](x)
+
+            x = self.conv_norm_out(x)
+            x = nn.silu(x)
+            x = self.conv_out(x)
+
+        return x
+
+
+class Decoder(nn.Module):
+    """Implements the decoder side of the Autoencoder."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        shard: Shard,
+        layer_range: List[int],
+        block_out_channels: List[int] = [64],
+        layers_per_block: int = 2,
+        resnet_groups: int = 32,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.layers_range = layer_range
+        if 0 in layer_range:
+            self.conv_in = nn.Conv2d(
+                in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
+            )
+        
+        if 0 in layer_range:
+            self.mid_blocks = [
+                ResnetBlock2D(
+                    in_channels=block_out_channels[-1],
+                    out_channels=block_out_channels[-1],
+                    groups=resnet_groups,
+                ),
+                Attention(block_out_channels[-1], resnet_groups),
+                ResnetBlock2D(
+                    in_channels=block_out_channels[-1],
+                    out_channels=block_out_channels[-1],
+                    groups=resnet_groups,
+                ),
+            ]
+
+        channels = list(reversed(block_out_channels))
+        channels = [channels[0]] + channels
+        
+        self.up_blocks = []
+        current_layer = 1
+
+        for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
+            if current_layer in layer_range:
+                self.up_blocks.append(
+                    EncoderDecoderBlock2D(
+                        in_channels,
+                        out_channels,
+                        num_layers=layers_per_block,
+                        resnet_groups=resnet_groups,
+                        add_downsample=False,
+                        add_upsample=i < len(block_out_channels) - 1,
+                    )
+                )
+            else:
+                self.up_blocks.append(IdentityBlock())
+            current_layer += 1
+        if 4 in layer_range:
+            self.conv_norm_out = nn.GroupNorm(
+                resnet_groups, block_out_channels[0], pytorch_compatible=True
+            )
+            self.conv_out = nn.Conv2d(block_out_channels[0], self.out_channels, 3, padding=1)
+
+
+    def __call__(self, x):
+        if 0 in self.layers_range:
+            x = self.conv_in(x)
+            x = self.mid_blocks[0](x)
+            x = self.mid_blocks[1](x)
+            x = self.mid_blocks[2](x)
+        
+        for l in self.up_blocks:
+            x = l(x)
+        if 4 in self.layers_range:
+            x = self.conv_norm_out(x)
+            x = nn.silu(x)
+            x = self.conv_out(x)
+        return x
+
+
+class Autoencoder(nn.Module):
+    """The autoencoder that allows us to perform diffusion in the latent space."""
+
+    def __init__(self, config: AutoencoderConfig, shard: Shard, model_shard: str):
+        super().__init__()
+        self.shard = shard
+        self.start_layer = shard.start_layer
+        self.end_layer = shard.end_layer
+        self.layers_range = list(range(self.start_layer, self.end_layer+1))
+        self.latent_channels = config.latent_channels_in
+        self.scaling_factor = config.scaling_factor
+        self.model_shard = model_shard
+        if self.model_shard == "vae_encoder":
+            self.encoder = Encoder(
+                config.in_channels,
+                config.latent_channels_out,
+                config.block_out_channels,
+                config.layers_per_block,
+                resnet_groups=config.norm_num_groups,
+                layers_range=self.layers_range,
+                shard=shard
+            )
+            if self.shard.is_last_layer():
+                self.quant_proj = nn.Linear(
+                config.latent_channels_out, config.latent_channels_out
+                )
+        if self.model_shard == "vae_decoder":
+            self.decoder = Decoder(
+                config.latent_channels_in,
+                config.out_channels,
+                shard,
+                self.layers_range,
+                config.block_out_channels,
+                config.layers_per_block + 1,
+                resnet_groups=config.norm_num_groups,
+            )
+            if self.shard.is_first_layer():
+                self.post_quant_proj = nn.Linear(
+                    config.latent_channels_in, config.latent_channels_in
+                )
+
+    def decode(self, z):
+        if self.shard.is_first_layer():
+            z = z / self.scaling_factor
+            z=self.post_quant_proj(z)
+        return self.decoder(z)
+
+    def encode(self, x):
+        x = self.encoder(x)
+        if self.shard.is_last_layer():   
+            x = self.quant_proj(x)
+            mean, logvar = x.split(2, axis=-1)
+            mean = mean * self.scaling_factor
+            logvar = logvar + 2 * math.log(self.scaling_factor)
+            x = mean
+        return x
+
+    def __call__(self, x, key=None):
+        mean, logvar = self.encode(x)
+        z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
+        x_hat = self.decode(z)
+
+        return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
+
+    def sanitize(self, weights):
+        shard = self.shard
+        layers = self.layers_range
+        sanitized_weights = {}
+        for key, value in weights.items():
+
+            if "downsamplers" in key:
+                key = key.replace("downsamplers.0.conv", "downsample")
+            if "upsamplers" in key:
+                key = key.replace("upsamplers.0.conv", "upsample")
+
+            # Map attention layers
+            if "key" in key:
+                key = key.replace("key", "key_proj")
+            if "proj_attn" in key:
+                key = key.replace("proj_attn", "out_proj")
+            if "query" in key:
+                key = key.replace("query", "query_proj")
+            if "value" in key:
+                key = key.replace("value", "value_proj")
+
+            # Map the mid block
+            if "mid_block.resnets.0" in key:
+                key = key.replace("mid_block.resnets.0", "mid_blocks.0")
+            if "mid_block.attentions.0" in key:
+                key = key.replace("mid_block.attentions.0", "mid_blocks.1")
+            if "mid_block.resnets.1" in key:
+                key = key.replace("mid_block.resnets.1", "mid_blocks.2")
+    
+            # Map the quant/post_quant layers
+            if "quant_conv" in key:
+                key = key.replace("quant_conv", "quant_proj")
+                value = value.squeeze()
+                
+            # Map the conv_shortcut to linear
+            if "conv_shortcut.weight" in key:
+                value = value.squeeze()
+
+            if len(value.shape) == 4:
+                value = value.transpose(0, 2, 3, 1)
+                value = value.reshape(-1).reshape(value.shape)
+
+
+            if "post_quant_conv" in key :
+                key = key.replace("quant_conv", "quant_proj")
+                value = value.squeeze()
+            
+            if 'decoder' in key and self.model_shard == "vae_decoder":
+                if key.startswith("decoder.mid_blocks."):
+                    if 0 in layers:
+                        sanitized_weights[key] = value
+                if "conv_in" in key and 0 in layers:
+                    sanitized_weights[key] = value
+                if key.startswith("decoder.up_blocks."):
+                    layer_num = int(key.split(".")[2])+1
+                    if layer_num in layers:
+                        sanitized_weights[key] = value
+                if key.startswith("decoder.conv_norm_out") and 4 in layers:
+                    sanitized_weights[key] = value
+                if key.startswith("decoder.conv_out") and 4 in layers:
+                    sanitized_weights[key] = value
+            if self.model_shard == "vae_decoder":
+                if key.startswith("post_quant_proj") and 0 in layers:
+                    sanitized_weights[key] = value
+            if self.model_shard == "vae_encoder":
+                if key.startswith("encoder."):
+                    if "conv_in" in key and shard.is_first_layer():
+                        sanitized_weights[key] = value
+                    if key.startswith("encoder.down_blocks."):
+                        layer_num = int(key.split(".")[2])+1
+                        if layer_num in layers:
+                            sanitized_weights[key] = value
+                    if key.startswith("encoder.mid_blocks.") and shard.is_last_layer():
+                        sanitized_weights[key] = value
+                    if "conv_norm_out" in key and shard.is_last_layer():
+                        sanitized_weights[key] = value
+                    if "conv_out" in key and shard.is_last_layer():
+                        sanitized_weights[key] = value
+                if key.startswith("quant_proj") and shard.is_last_layer():
+                    sanitized_weights[key] = value
+        return sanitized_weights
+

+ 8 - 4
exo/inference/mlx/sharded_inference_engine.py

@@ -77,13 +77,17 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     await self.ensure_shard(shard)
     await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
     
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
     await self.ensure_shard(shard)
     loop = asyncio.get_running_loop()
-    state = await self.poll_state(request_id)
+    state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
     x = mx.array(input_data)
-    output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, lambda: self.model(x, **state)))
-    return output_data
+    if self.model.model_type != 'StableDiffusionPipeline':
+      output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
+    else:
+      output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **inference_state))
+    output_data = np.array(output_data)
+    return output_data, inference_state
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
     await self.ensure_shard(shard)

+ 56 - 15
exo/inference/mlx/sharded_utils.py

@@ -62,8 +62,16 @@ def _get_classes(config: dict):
 
 def load_config(model_path: Path) -> dict:
   try:
-    with open(model_path/"config.json", "r") as f:
-      config = json.load(f)
+    config_path = model_path / "config.json"
+    if config_path.exists():
+      with open(config_path, "r") as f:
+        config = json.load(f)
+      return config
+    
+    model_index_path = model_path / "model_index.json"
+    if model_index_path.exists():
+      config = load_model_index(model_path, model_index_path)
+      return config
   except FileNotFoundError:
     logging.error(f"Config file not found in {model_path}")
     raise
@@ -110,6 +118,24 @@ def load_model_shard(
     # Try weight for back-compat
     weight_files = glob.glob(str(model_path/"weight*.safetensors"))
 
+  model_class, model_args_class = _get_classes(config=config)
+
+  class ShardedModel(model_class):
+    def __init__(self, args):
+      super().__init__(args)
+      self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
+
+    def __call__(self, x, *args, **kwargs):
+      y = super().__call__(x, *args, **kwargs)
+      return y
+
+  model_args = model_args_class.from_dict(config)
+  model = ShardedModel(model_args)
+
+  if config.get("model_index", False):
+    model.load()
+    return model
+
   if not weight_files:
     logging.error(f"No safetensors found in {model_path}")
     raise FileNotFoundError(f"No safetensors found in {model_path}")
@@ -129,19 +155,7 @@ def load_model_shard(
 
     weights.update(mx.load(wf))
 
-  model_class, model_args_class = _get_classes(config=config)
-
-  class ShardedModel(model_class):
-    def __init__(self, args):
-      super().__init__(args)
-      self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
-
-    def __call__(self, x, *args, **kwargs):
-      y = super().__call__(x, *args, **kwargs)
-      return y
-
-  model_args = model_args_class.from_dict(config)
-  model = ShardedModel(model_args)
+  
 
   if hasattr(model, "sanitize"):
     weights = model.sanitize(weights)
@@ -186,6 +200,9 @@ async def load_shard(
     processor.eos_token_id = processor.tokenizer.eos_token_id
     processor.encode = processor.tokenizer.encode
     return model, processor
+  elif hasattr(model, "tokenizer"):
+    tokenizer = model.tokenizer
+    return model, tokenizer
   else:
     tokenizer = await resolve_tokenizer(model_path)
     return model, tokenizer
@@ -214,3 +231,27 @@ async def get_image_from_str(_image_str: str):
     return img
   else:
     raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
+
+# loading a combined config for all models in the index
+def load_model_index(model_path: Path, model_index_path: Path):
+  models_config = {}
+  with open(model_index_path, "r") as f:
+      model_index = json.load(f)
+  models_config["model_index"] = True
+  models_config["model_type"] = model_index["_class_name"]
+  models_config["models"] = {}
+  for model in model_index.keys():
+    model_config_path = glob.glob(str(model_path / model / "*config.json"))
+    if len(model_config_path)>0:
+      with open(model_config_path[0], "r") as f:
+        model_config = { }
+        model_config["model_type"] = model
+        model_config["config"] = json.load(f)
+        model_config["path"] = model_path / model
+        if model_config["path"]/"*model.safetensors":
+          model_config["config"].update({"weight_files": list(glob.glob(str(model_config["path"]/"*model.safetensors")))})
+        model_config["path"] = str(model_path / model)
+        m = {}
+        m[model] = model_config
+        models_config.update(m)
+  return models_config

+ 3 - 3
exo/inference/tinygrad/inference.py

@@ -15,7 +15,7 @@ from .stateful_model import make_prompt_state
 from .losses import length_masked_ce_loss
 from collections import OrderedDict
 import asyncio
-
+from typing import Optional
 Tensor.no_grad = True 
 # default settings
 TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
@@ -104,7 +104,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
     safe_save(state_dict, path) 
   
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
     await self.ensure_shard(shard)
     def wrap_infer():
       x = Tensor(input_data)
@@ -114,7 +114,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
       self.states[request_id].start += x.shape[1]
       return out.realize()
     output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
-    return output_data.numpy()
+    return output_data.numpy(), inference_state
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
     def step(x, y, l):

+ 1 - 1
exo/main.py

@@ -151,7 +151,7 @@ api = ChatGPTAPI(
   system_prompt=args.system_prompt
 )
 node.on_token.register("update_topology_viz").on_next(
-  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
+  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") and inference_engine.shard.model_id != 'stable-diffusion-2-1-base' else None
 )
 
 def preemptively_start_download(request_id: str, opaque_status: str):

+ 3 - 0
exo/models.py

@@ -111,6 +111,8 @@ model_cards = {
   # gemma
   "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
   "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
+  # stable diffusion
+  "stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
   # phi
   "phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, },
   "phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, },
@@ -156,6 +158,7 @@ pretty_name = {
   "phi-4": "Phi-4",
   "llama-3-8b": "Llama 3 8B",
   "llama-3-70b": "Llama 3 70B",
+  "stable-diffusion-2-1-base": "Stable Diffusion 2.1",
 }
 
 def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:

+ 41 - 4
exo/networking/grpc/grpc_peer_handle.py

@@ -11,7 +11,8 @@ from exo.inference.shard import Shard
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.helpers import DEBUG
-
+import json
+import mlx.core as mx
 
 class GRPCPeerHandle(PeerHandle):
   def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
@@ -71,7 +72,7 @@ class GRPCPeerHandle(PeerHandle):
         traceback.print_exc()
       return False
 
-  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
       shard=node_service_pb2.Shard(
@@ -81,6 +82,7 @@ class GRPCPeerHandle(PeerHandle):
         n_layers=shard.n_layers,
       ),
       request_id=request_id,
+      inference_state=self.serialize_inference_state(inference_state)
     )
     response = await self.stub.SendPrompt(request)
 
@@ -89,7 +91,7 @@ class GRPCPeerHandle(PeerHandle):
 
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
-  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
+  async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -99,6 +101,7 @@ class GRPCPeerHandle(PeerHandle):
       ),
       tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
       request_id=request_id,
+      inference_state=self.serialize_inference_state(inference_state)
     )
     response = await self.stub.SendTensor(request)
 
@@ -175,9 +178,43 @@ class GRPCPeerHandle(PeerHandle):
     return topology
 
   async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-    request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
+    tensor = None
+    if isinstance(result, np.ndarray):
+      tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
+      result = []
+    request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished)
     await self.stub.SendResult(request)
 
   async def send_opaque_status(self, request_id: str, status: str) -> None:
     request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
     await self.stub.SendOpaqueStatus(request)
+
+  def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
+    proto_inference_state = node_service_pb2.InferenceState()
+    other_data = {}
+    for k, v in inference_state.items():
+        if isinstance(v, mx.array):
+            np_array = np.array(v)
+            tensor_data = node_service_pb2.Tensor(
+                tensor_data=np_array.tobytes(),
+                shape=list(np_array.shape),
+                dtype=str(np_array.dtype)
+            )
+            proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
+        elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
+            tensor_list = node_service_pb2.TensorList()
+            for tensor in v:
+                np_array = np.array(tensor)
+                tensor_data = node_service_pb2.Tensor(
+                    tensor_data=np_array.tobytes(),
+                    shape=list(np_array.shape),
+                    dtype=str(np_array.dtype)
+                )
+                tensor_list.tensors.append(tensor_data)
+            proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
+        else:
+            # For non-tensor data, we'll still use JSON
+            other_data[k] = v
+    if other_data:
+      proto_inference_state.other_data_json = json.dumps(other_data)
+    return proto_inference_state

+ 30 - 2
exo/networking/grpc/grpc_server.py

@@ -8,6 +8,8 @@ from . import node_service_pb2_grpc
 from exo import DEBUG
 from exo.inference.shard import Shard
 from exo.orchestration import Node
+import json
+import mlx.core as mx
 
 
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
@@ -50,7 +52,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     prompt = request.prompt
     request_id = request.request_id
-    result = await self.node.process_prompt(shard, prompt, request_id)
+    inference_state = self.deserialize_inference_state(request.inference_state)
+    result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
     if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
@@ -65,7 +68,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     request_id = request.request_id
 
-    result = await self.node.process_tensor(shard, tensor, request_id)
+    inference_state = self.deserialize_inference_state(request.inference_state)
+
+    result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
@@ -122,7 +127,11 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     request_id = request.request_id
     result = request.result
     is_finished = request.is_finished
+    img = request.tensor
     if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
+    result = list(result)
+    if len(img.tensor_data) > 0:
+      result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
     self.node.on_token.trigger_all(request_id, result, is_finished)
     return node_service_pb2.Empty()
 
@@ -135,3 +144,22 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
 
   async def HealthCheck(self, request, context):
     return node_service_pb2.HealthCheckResponse(is_healthy=True)
+
+  def deserialize_inference_state(self,inference_state_proto: node_service_pb2.InferenceState) -> dict:
+    inference_state = {}
+    
+    for k, tensor_data in inference_state_proto.tensor_data.items():
+        np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
+        inference_state[k] = mx.array(np_array)
+    
+    for k, tensor_list in inference_state_proto.tensor_list_data.items():
+        inference_state[k] = [
+            mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape))
+            for tensor in tensor_list.tensors
+        ]
+    
+    if inference_state_proto.other_data_json:
+        other_data = json.loads(inference_state_proto.other_data_json)
+        inference_state.update(other_data)
+    
+    return inference_state

+ 14 - 1
exo/networking/grpc/node_service.proto

@@ -24,12 +24,14 @@ message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
   optional string request_id = 3;
+  optional InferenceState inference_state = 4;
 }
 
 message TensorRequest {
   Shard shard = 1;
   Tensor tensor = 2;
   optional string request_id = 3;
+  optional InferenceState inference_state = 4;
 }
 
 message ExampleRequest {
@@ -61,6 +63,16 @@ message Tensor {
   string dtype = 3;
 }
 
+message TensorList {
+  repeated Tensor tensors = 1;
+}
+
+message InferenceState {
+  map<string, Tensor> tensor_data = 1;
+  map<string, TensorList> tensor_list_data = 2;
+  string other_data_json = 3;
+}
+
 message CollectTopologyRequest {
   repeated string visited = 1;
   int32 max_depth = 2;
@@ -96,7 +108,8 @@ message DeviceCapabilities {
 message SendResultRequest {
   string request_id = 1;
   repeated int32 result = 2;
-  bool is_finished = 3;
+  optional Tensor tensor = 3;
+  bool is_finished = 4;
 }
 
 message SendOpaqueStatusRequest {

Dosya farkı çok büyük olduğundan ihmal edildi
+ 2 - 2
exo/networking/grpc/node_service_pb2.py


+ 50 - 50
exo/networking/grpc/node_service_pb2_grpc.py

@@ -3,7 +3,7 @@
 import grpc
 import warnings
 
-from . import node_service_pb2 as node__service__pb2
+from exo.networking.grpc import node_service_pb2 as exo_dot_networking_dot_grpc_dot_node__service__pb2
 
 GRPC_GENERATED_VERSION = '1.68.0'
 GRPC_VERSION = grpc.__version__
@@ -18,7 +18,7 @@ except ImportError:
 if _version_not_supported:
     raise RuntimeError(
         f'The grpc package installed is at version {GRPC_VERSION},'
-        + f' but the generated code in node_service_pb2_grpc.py depends on'
+        + f' but the generated code in exo/networking/grpc/node_service_pb2_grpc.py depends on'
         + f' grpcio>={GRPC_GENERATED_VERSION}.'
         + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
         + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
@@ -36,43 +36,43 @@ class NodeServiceStub(object):
         """
         self.SendPrompt = channel.unary_unary(
                 '/node_service.NodeService/SendPrompt',
-                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
                 _registered_method=True)
         self.SendTensor = channel.unary_unary(
                 '/node_service.NodeService/SendTensor',
-                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
                 _registered_method=True)
         self.SendExample = channel.unary_unary(
                 '/node_service.NodeService/SendExample',
-                request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Loss.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
                 _registered_method=True)
         self.GetInferenceResult = channel.unary_unary(
                 '/node_service.NodeService/GetInferenceResult',
-                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.InferenceResult.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
                 _registered_method=True)
         self.CollectTopology = channel.unary_unary(
                 '/node_service.NodeService/CollectTopology',
-                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Topology.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
                 _registered_method=True)
         self.SendResult = channel.unary_unary(
                 '/node_service.NodeService/SendResult',
-                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.SendOpaqueStatus = channel.unary_unary(
                 '/node_service.NodeService/SendOpaqueStatus',
-                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.HealthCheck = channel.unary_unary(
                 '/node_service.NodeService/HealthCheck',
-                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
-                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
+                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
+                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
                 _registered_method=True)
 
 
@@ -132,43 +132,43 @@ def add_NodeServiceServicer_to_server(servicer, server):
     rpc_method_handlers = {
             'SendPrompt': grpc.unary_unary_rpc_method_handler(
                     servicer.SendPrompt,
-                    request_deserializer=node__service__pb2.PromptRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
             ),
             'SendTensor': grpc.unary_unary_rpc_method_handler(
                     servicer.SendTensor,
-                    request_deserializer=node__service__pb2.TensorRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
             ),
             'SendExample': grpc.unary_unary_rpc_method_handler(
                     servicer.SendExample,
-                    request_deserializer=node__service__pb2.ExampleRequest.FromString,
-                    response_serializer=node__service__pb2.Loss.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.SerializeToString,
             ),
             'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
                     servicer.GetInferenceResult,
-                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
-                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.SerializeToString,
             ),
             'CollectTopology': grpc.unary_unary_rpc_method_handler(
                     servicer.CollectTopology,
-                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
-                    response_serializer=node__service__pb2.Topology.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.SerializeToString,
             ),
             'SendResult': grpc.unary_unary_rpc_method_handler(
                     servicer.SendResult,
-                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
             ),
             'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
                     servicer.SendOpaqueStatus,
-                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
             ),
             'HealthCheck': grpc.unary_unary_rpc_method_handler(
                     servicer.HealthCheck,
-                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
-                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
+                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.FromString,
+                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.SerializeToString,
             ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
@@ -196,8 +196,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendPrompt',
-            node__service__pb2.PromptRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
             options,
             channel_credentials,
             insecure,
@@ -223,8 +223,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendTensor',
-            node__service__pb2.TensorRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
             options,
             channel_credentials,
             insecure,
@@ -250,8 +250,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendExample',
-            node__service__pb2.ExampleRequest.SerializeToString,
-            node__service__pb2.Loss.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
             options,
             channel_credentials,
             insecure,
@@ -277,8 +277,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/GetInferenceResult',
-            node__service__pb2.GetInferenceResultRequest.SerializeToString,
-            node__service__pb2.InferenceResult.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
             options,
             channel_credentials,
             insecure,
@@ -304,8 +304,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/CollectTopology',
-            node__service__pb2.CollectTopologyRequest.SerializeToString,
-            node__service__pb2.Topology.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
             options,
             channel_credentials,
             insecure,
@@ -331,8 +331,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendResult',
-            node__service__pb2.SendResultRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
             options,
             channel_credentials,
             insecure,
@@ -358,8 +358,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendOpaqueStatus',
-            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
             options,
             channel_credentials,
             insecure,
@@ -385,8 +385,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/HealthCheck',
-            node__service__pb2.HealthCheckRequest.SerializeToString,
-            node__service__pb2.HealthCheckResponse.FromString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
+            exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
             options,
             channel_credentials,
             insecure,

+ 53 - 28
exo/orchestration/node.py

@@ -112,37 +112,49 @@ class Node:
     shard,
     result: np.ndarray,
     request_id: Optional[str] = None,
+    inference_state: Optional[dict] = None,
   ):
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
-    is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-    if shard.is_last_layer() and not is_finished:
-      token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
-      await self.inference_engine.ensure_shard(shard)
-      self.buffered_token_output[request_id][0].append(token.item())
-      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-      asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
-      forward = token.reshape(1, -1)
-      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
+    if shard.model_id != 'stable-diffusion-2-1-base':
+      if request_id not in self.buffered_token_output:
+        self.buffered_token_output[request_id] = ([], False)
+      is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+      if shard.is_last_layer() and not is_finished:
+        token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
+        await self.inference_engine.ensure_shard(shard)
+        self.buffered_token_output[request_id][0].append(token.item())
+        is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+        if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+        asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
+        forward = token.reshape(1, -1)
+        intermediate_result = self.buffered_token_output[request_id][0]
+      else:
+        forward = result
     else:
+      await self.inference_engine.ensure_shard(shard)
+      is_finished = inference_state.get("is_finished", False)
+      intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result)
       forward = result
+    if shard.is_last_layer():
+      self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished)
+      asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished))
 
     if is_finished:
-      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+      if shard.model_id != 'stable-diffusion-2-1-base':
+        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
       self.outstanding_requests.pop(request_id)
     else:
       self.outstanding_requests[request_id] = "waiting"
-      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
+      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))
+
+    return  np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result
 
-    return np.array(self.buffered_token_output[request_id][0])
 
   async def process_prompt(
     self,
     base_shard: Shard,
     prompt: str,
     request_id: Optional[str] = None,
+    inference_state: Optional[dict] = {},
   ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
@@ -160,7 +172,7 @@ class Node:
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(base_shard, prompt, request_id)
+    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -181,7 +193,7 @@ class Node:
     )
     return resp
 
-  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     shard = self.get_current_shard(base_shard)
@@ -190,12 +202,12 @@ class Node:
     if not shard.is_first_layer():
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
       self.outstanding_requests[request_id] = "waiting"
-      resp = await self.forward_prompt(shard, prompt, request_id, 0)
+      resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state)
       return None
     else:
       self.outstanding_requests[request_id] = "processing"
-      result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
-      ret = await self.process_inference_result(shard, result, request_id)
+      result,inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
+      ret = await self.process_inference_result(shard, result, request_id, inference_state)
       return result
 
   async def enqueue_example(
@@ -340,6 +352,7 @@ class Node:
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
+    inference_state: Optional[dict] = None,
   ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
@@ -358,7 +371,7 @@ class Node:
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_tensor(shard, tensor, request_id)
+    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -383,6 +396,7 @@ class Node:
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
+    inference_state: Optional[dict] = None,
   ) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
@@ -391,8 +405,8 @@ class Node:
     if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
       self.outstanding_requests[request_id] = "processing"
-      result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
-      ret = await self.process_inference_result(shard, result, request_id) 
+      result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
+      ret = await self.process_inference_result(shard, result, request_id, inference_state) 
       return ret
     except Exception as e:
       self.outstanding_requests.pop(request_id)
@@ -427,19 +441,20 @@ class Node:
     prompt: str,
     request_id: str,
     target_index: int,
+    inference_state: Optional[dict] = None,
   ) -> None:
     if DEBUG >= 1: print(f"target partition index: {target_index}")
     target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
     next_shard = self.get_current_shard(base_shard, target_index)
     if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
     if target_id == self.id:
-      await self.process_prompt(next_shard, prompt, request_id)
+      await self.process_prompt(next_shard, prompt, request_id, inference_state)
     else:
       target_peer = next((p for p in self.peers if p.id() == target_id), None)
       if not target_peer:
         raise ValueError(f"Peer for {target_index} not found")
       if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
-      await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+      await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)
   
   async def forward_tensor(
     self,
@@ -447,19 +462,20 @@ class Node:
     tensor: np.ndarray,
     request_id: str,
     target_index: int,
+    inference_state: Optional[dict] = None,
   ) -> None:
     if DEBUG >= 1: print(f"target partition index: {target_index}")
     target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
     next_shard = self.get_current_shard(base_shard, target_index)
     if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
     if target_id == self.id:
-      await self.process_tensor(next_shard, tensor, request_id)
+      await self.process_tensor(next_shard, tensor, request_id, inference_state)
     else:
       target_peer = next((p for p in self.peers if p.id() == target_id), None)
       if not target_peer:
         raise ValueError(f"Peer for {target_index} not found")
       if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
-      await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
+      await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state)
 
   def get_partition_index(self, offset: int = 0):
     if not self.partitioning_strategy:
@@ -632,3 +648,12 @@ class Node:
   @property
   def current_topology(self) -> Topology:
     return self.topology
+
+  def handle_stable_diffusion(self, inference_state, result):
+    if inference_state['is_step_finished']:
+      inference_state['step']+=1
+    progress = [inference_state['step'],inference_state['total_steps']]
+    intermediate_result = result
+    if progress[0] == progress[1]:
+      intermediate_result = result
+    return intermediate_result, inference_state

+ 20 - 2
exo/tinychat/index.html

@@ -182,7 +182,25 @@
           const div = document.createElement('div');
           div.className = `message message-role-${role}`;
           try {
-            div.innerHTML = DOMPurify.sanitize(marked.parse(content));
+              if (content.includes('![Generated Image]')) {
+                const imageUrl = content.match(/\((.*?)\)/)[1];
+                const img = document.createElement('img');
+                img.src = imageUrl;
+                img.alt = 'Generated Image';
+                img.onclick = async () => {
+                  try {
+                    const response = await fetch(img.src);
+                    const blob = await response.blob();
+                    const file = new File([blob], 'image.png', { type: 'image/png' });
+                    handleImageUpload({ target: { files: [file] } });
+                  } catch (error) {
+                    console.error('Error fetching image:', error);
+                  }
+                };
+                div.appendChild(img);
+              } else {
+                div.innerHTML = DOMPurify.sanitize(marked.parse(content));
+              }
           } catch (e) {
             console.log(content);
             console.error(e);
@@ -266,7 +284,7 @@
 </span>
 </div>
 <div class="input">
-<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf'">
+<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf' || cstate.selectedModel === 'stable-diffusion-2-1-base'">
 <i class="fas fa-image"></i>
 </button>
 <input @change="$data.handleImageUpload($event)" accept="image/*" id="image-upload" style="display: none;" type="file" x-ref="imageUpload"/>

+ 96 - 39
exo/tinychat/index.js

@@ -228,53 +228,110 @@ document.addEventListener("alpine:init", () => {
             };
           }
         });
-        const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
-        if (containsImage) {
-          // Map all messages with string content to object with type text
-          apiMessages = apiMessages.map(msg => {
-            if (typeof msg.content === 'string') {
-              return {
-                ...msg,
-                content: [
-                  {
-                    type: "text",
-                    text: msg.content
-                  }
-                ]
-              };
-            }
-            return msg;
+        
+        if (this.cstate.selectedModel === "stable-diffusion-2-1-base") {
+          // Send a request to the image generation endpoint
+          console.log(apiMessages[apiMessages.length - 1].content)
+          console.log(this.cstate.selectedModel)  
+          console.log(this.endpoint)
+          const response = await fetch(`${this.endpoint}/image/generations`, {
+            method: "POST",
+            headers: {
+              "Content-Type": "application/json",
+            },
+            body: JSON.stringify({
+              "model": 'stable-diffusion-2-1-base',
+              "prompt": apiMessages[apiMessages.length - 1].content,
+              "image_url": this.imageUrl
+            }),
           });
+      
+          if (!response.ok) {
+            throw new Error("Failed to fetch");
+          }
+          const reader = response.body.getReader();
+          let done = false;
+          let gottenFirstChunk = false;
+  
+          while (!done) {
+            const { value, done: readerDone } = await reader.read();
+            done = readerDone;
+            const decoder = new TextDecoder();
+  
+            if (value) {
+              // Assume non-binary data (text) comes first
+              const chunk = decoder.decode(value, { stream: true });
+              const parsed = JSON.parse(chunk);
+              console.log(parsed)
+  
+              if (parsed.progress) {
+                if (!gottenFirstChunk) {
+                  this.cstate.messages.push({ role: "assistant", content: "" });
+                  gottenFirstChunk = true;
+                }
+                this.cstate.messages[this.cstate.messages.length - 1].content = parsed.progress;
+              }
+              else if (parsed.images) {
+                if (!gottenFirstChunk) {
+                  this.cstate.messages.push({ role: "assistant", content: "" });
+                  gottenFirstChunk = true;
+                }
+                const imageUrl = parsed.images[0].url;
+                console.log(imageUrl)
+                this.cstate.messages[this.cstate.messages.length - 1].content = `![Generated Image](${imageUrl}?t=${Date.now()})`;
+              }
+            }
+          }
         }
-
-
-        // start receiving server sent events
-        let gottenFirstChunk = false;
-        for await (
-          const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
-        ) {
-          if (!gottenFirstChunk) {
-            this.cstate.messages.push({ role: "assistant", content: "" });
-            gottenFirstChunk = true;
+        
+        else{        
+          const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
+          if (containsImage) {
+            // Map all messages with string content to object with type text
+            apiMessages = apiMessages.map(msg => {
+              if (typeof msg.content === 'string') {
+                return {
+                  ...msg,
+                  content: [
+                    {
+                      type: "text",
+                      text: msg.content
+                    }
+                  ]
+                };
+              }
+              return msg;
+            });
           }
 
-          // add chunk to the last message
-          this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
+          console.log(apiMessages)
+          //start receiving server sent events
+          let gottenFirstChunk = false;
+          for await (
+            const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
+          ) {
+            if (!gottenFirstChunk) {
+              this.cstate.messages.push({ role: "assistant", content: "" });
+              gottenFirstChunk = true;
+            }
 
-          // calculate performance tracking
-          tokens += 1;
-          this.total_tokens += 1;
-          if (start_time === 0) {
-            start_time = Date.now();
-            this.time_till_first = start_time - prefill_start;
-          } else {
-            const diff = Date.now() - start_time;
-            if (diff > 0) {
-              this.tokens_per_second = tokens / (diff / 1000);
+            // add chunk to the last message
+            this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
+
+            // calculate performance tracking
+            tokens += 1;
+            this.total_tokens += 1;
+            if (start_time === 0) {
+              start_time = Date.now();
+              this.time_till_first = start_time - prefill_start;
+            } else {
+              const diff = Date.now() - start_time;
+              if (diff > 0) {
+                this.tokens_per_second = tokens / (diff / 1000);
+              }
             }
           }
         }
-
         // Clean the cstate before adding it to histories
         const cleanedCstate = JSON.parse(JSON.stringify(this.cstate));
         cleanedCstate.messages = cleanedCstate.messages.map(msg => {

Bu fark içinde çok fazla dosya değişikliği olduğu için bazı dosyalar gösterilmiyor