
Image to image generation

Pranav Veldurthi 5 months ago
parent
commit
ca0caad0ae

+ 25 - 1
exo/api/chatgpt_api.py

@@ -20,6 +20,9 @@ from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get
 from typing import Callable, Optional
 from PIL import Image
 import numpy as np
+import base64
+from io import BytesIO
+import mlx.core as mx
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -383,6 +386,7 @@ class ChatGPTAPI:
     stream = data.get("stream", False)
     model = data.get("model", "")
     prompt = data.get("prompt", "")
+    image_url = data.get("image_url", "")
     print(f"model: {model}, prompt: {prompt}, stream: {stream}")
     shard = build_base_shard(model, self.inference_engine_classname)
     print(f"shard: {shard}")
@@ -393,7 +397,11 @@ class ChatGPTAPI:
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback = self.node.on_token.register(callback_id)
     try:
-      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
+      if image_url:
+        img = self.base64_decode(image_url)
+      else:
+        img = None
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
 
 
       response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
@@ -454,3 +462,19 @@ class ChatGPTAPI:
     await runner.setup()
     site = web.TCPSite(runner, host, port)
     await site.start()
+
+  def base64_decode(self, base64_string):
+    # Decode a base64 (or data-URI) image, snap its size down to multiples of 64,
+    # scale pixels to [-1, 1], and add a batch dimension for the VAE encoder
+    if base64_string.startswith('data:image'):
+        base64_string = base64_string.split(',')[1]
+    image_data = base64.b64decode(base64_string)
+    img = Image.open(BytesIO(image_data))
+    W, H = (dim - dim % 64 for dim in (img.width, img.height))
+    if W != img.width or H != img.height:
+        print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
+        img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
+    img = mx.array(np.array(img))
+    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
+    img = img[None]
+    return img
+  
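The new image_url field accepts a base64-encoded image, with or without a data:image/...;base64, prefix; the handler decodes it and threads it through process_prompt as inference_state={"image": img}. A minimal client sketch follows — the route and port are assumptions (this hunk does not show where the handler is mounted); only the JSON fields (model, prompt, image_url) come from the diff:

# Hypothetical client for the image-to-image path; endpoint and port are assumed.
import base64
import requests

with open("input.png", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://localhost:52415/v1/image/generations",  # assumed exo API address/route
    json={
        "model": "stable-diffusion-2-1-base",
        "prompt": "a watercolor version of this photo",
        "image_url": f"data:image/png;base64,{b64}",  # prefix is stripped server-side
    },
    stream=True,
)
for chunk in resp.iter_content(chunk_size=8192):
    ...  # octet-stream of progressive image data; consume as your client requires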

+ 45 - 25
exo/inference/mlx/models/StableDiffusionPipeline.py

@@ -157,7 +157,7 @@ class Model(nn.Module):
         self.config = config
         self.model_path = config.vae['path'].split('/vae')[0]
         self.shard = config.shard
-        self.shard_clip, self.shard_unet, self.shard_vae  = model_shards(config.shard)
+        self.shard_clip, self.shard_encoder, self.shard_unet, self.shard_decoder  = model_shards(config.shard)
         self.config_clip=CLIPArgs.from_dict(config.text_encoder['config'])
         if self.shard_clip.start_layer != -1:
             self.text_encoder = CLIPTextModel(self.config_clip, shard=self.shard_clip)
@@ -172,26 +172,41 @@ class Model(nn.Module):
         else:
             self.unet = nn.Identity()
         self.config_vae=VAEArgs.from_dict(config.vae['config'])
-        if self.shard_vae.start_layer != -1:
-            self.first_stage_model=Autoencoder(self.config_vae, self.shard_vae) 
+        if self.shard_encoder.start_layer != -1:
+            self.encoder=Autoencoder(self.config_vae, self.shard_encoder, "vae_encoder") 
         else:
-            self.first_stage_model = nn.Identity()            
+            self.encoder = nn.Identity()            
+        if self.shard_decoder.start_layer != -1:
+            self.decoder=Autoencoder(self.config_vae, self.shard_decoder, "vae_decoder") 
+        else:
+            self.decoder = nn.Identity()            
 
-    def __call__(self,x, step= 0, cfg_weight: float = 7.5,total_steps=50,conditioning=None,mask=None,residual=None,x_t_prev=None,is_finished=False,is_step_finished=False):
-        t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps)
+    def __call__(self,x, step= 0, cfg_weight: float = 7.5,total_steps=50,conditioning=None,mask=None,residual=None,x_t_prev=None,is_finished=False,is_step_finished=False, image=None, strength=0.7, start_step=None):
+        t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
         is_finished = False
         is_step_finished = False
         if t.item()==1000:
             if self.shard_clip.start_layer == 0:
                 conditioning = x
             if self.shard_clip.start_layer != -1:
-                
                 conditioning, mask= self.text_encoder(conditioning,mask)
             seed = int(time.time()) 
             mx.random.seed(seed)
-            if self.shard_unet.is_first_layer():
-                x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
-                x_t_prev=x
+            if image is None:
+                if self.shard_encoder.is_last_layer():
+                    x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
+                    x_t_prev=x
+                    start_step = self.sampler.max_time
+            else:
+                if self.shard_encoder.start_layer != -1:
+                    image= self.encoder.encode(image)
+                    if self.shard_encoder.is_last_layer():
+                        start_step = self.sampler.max_time*strength
+                        total_steps = int(total_steps*strength)
+                        image = mx.broadcast_to(image, (1,) + image.shape[1:])
+                        x_t_prev=self.sampler.add_noise(image, mx.array(start_step))
+                        image = None
+                        t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
         # Perform the denoising loop
         if self.shard_unet.start_layer != -1:
             with tqdm(total=total_steps,initial=step+1) as pbar:
@@ -211,28 +226,32 @@ class Model(nn.Module):
                         x_t_prev=x
                     mx.eval(x)
                     
-        if self.shard_vae.is_last_layer():
+        if self.shard_decoder.is_last_layer():
             is_step_finished=True
-            if self.shard_vae.start_layer != -1:
-                x=self.first_stage_model.decode(x)
-            if self.shard_vae.is_last_layer():
+            if self.shard_decoder.start_layer != -1:
+                x=self.decoder.decode(x)
+            if self.shard_decoder.is_last_layer():
                 x = mx.clip(x / 2 + 0.5, 0, 1)
-                x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
                 B, H, W, C = x.shape
                 x = x.reshape(1, B // 1, H, W, C).transpose(0, 2, 1, 3, 4)
                 x = x.reshape(1 * H, B // 1 * W, C)
                 x = (x * 255).astype(mx.uint8)
                 if t_prev.item() ==0:
                     is_finished=True   
+        mx.eval(x)
          
-        return x, {'conditioning':conditioning, 'mask':mask,'residual':residual,'x_t_prev':x_t_prev,'is_finished':is_finished,'is_step_finished':is_step_finished, 'step':step, 'total_steps':total_steps}
+        return x, {'conditioning':conditioning, 'mask':mask,'residual':residual,'x_t_prev':x_t_prev,'is_finished':is_finished,'is_step_finished':is_step_finished, 'step':step, 'total_steps':total_steps, 'start_step':start_step, 'image':image}
     
 
     def load(self):
-        if self.shard_vae.start_layer != -1:
+        if self.shard_encoder.start_layer != -1:    
             vae_weights =  mx.load(self.config_vae.weight_files[0])
-            vae_weights = self.first_stage_model.sanitize(vae_weights)
-            self.first_stage_model.load_weights(list(vae_weights.items()), strict=True)
+            vae_weights = self.encoder.sanitize(vae_weights)
+            self.encoder.load_weights(list(vae_weights.items()), strict=True)
+        if self.shard_decoder.start_layer != -1:
+            vae_weights =  mx.load(self.config_vae.weight_files[0])
+            vae_weights = self.decoder.sanitize(vae_weights)
+            self.decoder.load_weights(list(vae_weights.items()), strict=True)
         if self.shard_clip.start_layer != -1:
             clip_weights = mx.load(self.config_clip.weight_files[0])
             clip_weights = self.text_encoder.sanitize(clip_weights)
@@ -242,7 +261,6 @@ class Model(nn.Module):
             unet_weights = self.unet.sanitize(unet_weights)
             self.unet.load_weights(list(unet_weights.items()), strict=True)
 
-
 def model_shards(shard:ShardConfig):
     def create_shard(shard, model_ranges):
         start_layer = shard.start_layer
@@ -268,9 +286,10 @@ def model_shards(shard:ShardConfig):
 
     # Define the ranges for different models
     model_ranges = {
-        'clip': (0, 23),
-        'unet':(23,32),
-        'vae': (32, 37) # Example range for unet
+        'clip': (0, 12),
+        'vae_encoder':(12,17),
+        'unet':(17,26),
+        'vae_decoder': (26, 31) # ranges are half-open: [start, end)
     }
 
     # Call the function and get the shards for all models
@@ -278,10 +297,11 @@ def model_shards(shard:ShardConfig):
 
     # Access individual shards
     shard_clip = shards['clip']
+    shard_encoder = shards['vae_encoder']
     shard_unet = shards['unet']
-    shard_vae = shards['vae']
+    shard_decoder = shards['vae_decoder']
     
-    return shard_clip, shard_unet, shard_vae
+    return shard_clip, shard_encoder, shard_unet, shard_decoder
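With the VAE split into separate encoder and decoder shards, img2img no longer starts from pure noise: the encoded input is noised up to an intermediate timestep set by strength, and only that fraction of the step budget is run. A small sketch of the schedule arithmetic used in __call__ above, assuming the SD 2.1 defaults of sampler.max_time = 1000 and 50 steps:

# Schedule arithmetic from __call__ above; max_time = 1000 is an assumed default.
max_time = 1000
total_steps = 50
strength = 0.7

start_step = max_time * strength            # 700.0: denoising starts here, not at 1000
steps_to_run = int(total_steps * strength)  # 35 of the 50 steps remain

# x_t_prev = sampler.add_noise(encoded_image, mx.array(start_step))
# Lower strength keeps more of the input image; strength = 1.0 reduces to text-to-image.
print(start_step, steps_to_run)  # 700.0 35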
 
 
 

+ 11 - 12
exo/inference/mlx/models/sd_models/clip.py

@@ -1,5 +1,6 @@
 # Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py
 
+import math
 from dataclasses import dataclass
 from typing import List, Optional
 
@@ -99,13 +100,15 @@ class CLIPTextModel(nn.Module):
         super().__init__()
 
         self.shard = shard
-        
+        self.layers_range = range(self.shard.start_layer*2, self.shard.end_layer*2+2) 
         if self.shard.is_first_layer():
             self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
             self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
         self.layers = []
-        for i in range(config.num_layers):
-            if self.shard.start_layer <= i <= self.shard.end_layer:
+        for i in range(math.ceil(config.num_layers/2)):
+            if 2*i in self.layers_range:
+                self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
+            else:
+                self.layers.append(IdentityBlock())
+            if 2*i+1 in self.layers_range and 2*i+1 < config.num_layers:
                 self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
             else:
                 self.layers.append(IdentityBlock())
@@ -136,22 +139,18 @@ class CLIPTextModel(nn.Module):
             # Compute the features from the transformer
             mask = self._get_mask(N, x.dtype)
         
-        hidden_states = []
         for l in self.layers:
             x = l(x, mask)
-            hidden_states.append(x)
         # Apply the final layernorm and return
         
         if self.shard.is_last_layer():
             x = self.final_layer_norm(x)
-            last_hidden_state = x
-
+        
        
 
         return x, mask
     def sanitize(self, weights):
         sanitized_weights = {}
-        
         for key, value in weights.items():
             if "position_ids" in key:
                 continue
@@ -180,13 +179,13 @@ class CLIPTextModel(nn.Module):
             
             if key.startswith("layers."):
                 layer_num = int(key.split(".")[1])
-                if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer:
+                if layer_num not in self.layers_range:
                     continue
-            if not self.shard.start_layer == 0 and "embedding" in key:
+            if not self.shard.is_first_layer() and "embedding" in key:
                 continue
-            if not self.shard.end_layer == 22 and key.startswith("final_layer_norm"):
+            if not self.shard.is_last_layer() and key.startswith("final_layer_norm"):
                 continue
-            if not self.shard.end_layer == 22 and key.startswith("text_projection"):
+            if not self.shard.is_last_layer() and key.startswith("text_projection"):
                 continue
             sanitized_weights[key] = value
         return sanitized_weights
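The clip range in model_shards shrank from 23 slots to 12 because the shard map now packs two CLIP encoder layers per slot, and layers_range expands each slot back into concrete layer indices. A quick sanity check of that mapping, assuming start_layer/end_layer are inclusive slot indices:

# Each slot i covers CLIP layers 2*i and 2*i+1; the last slot of a 23-layer
# encoder covers only layer 22. Inclusive slot indices are an assumption here.
NUM_LAYERS = 23  # CLIP text-encoder depth for SD 2.1 (assumed)

def clip_layers(start_slot: int, end_slot: int) -> list:
    return [i for i in range(start_slot * 2, end_slot * 2 + 2) if i < NUM_LAYERS]

print(clip_layers(0, 11))  # [0, 1, ..., 22]: the whole encoder in 12 slots
print(clip_layers(6, 11))  # [12, ..., 22]: the second half only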

+ 145 - 106
exo/inference/mlx/models/sd_models/vae.py

@@ -128,62 +128,75 @@ class Encoder(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        out_channels: int,
+        latent_channels_out: int,
         block_out_channels: List[int] = [64],
         layers_per_block: int = 2,
         resnet_groups: int = 32,
+        layers_range: List[int] = [],
+        shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
     ):
         super().__init__()
-
-        self.conv_in = nn.Conv2d(
-            in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
-        )
+        self.layers_range = layers_range
+        self.shard = shard
+        if self.shard.is_first_layer():
+            self.conv_in = nn.Conv2d(
+                in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
+            )
 
         channels = [block_out_channels[0]] + list(block_out_channels)
-        self.down_blocks = [
-            EncoderDecoderBlock2D(
-                in_channels,
-                out_channels,
-                num_layers=layers_per_block,
-                resnet_groups=resnet_groups,
-                add_downsample=i < len(block_out_channels) - 1,
-                add_upsample=False,
-            )
-            for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
-        ]
+        self.down_blocks = []
+        current_layer = 1
+        for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
+            if current_layer in self.layers_range:
+                self.down_blocks.append(
+                    EncoderDecoderBlock2D(
+                        in_channels,
+                        out_channels,
+                        num_layers=layers_per_block,
+                        resnet_groups=resnet_groups,
+                        add_downsample=i < len(block_out_channels) - 1,
+                        add_upsample=False,
+                    )
+                )
+            else:
+                self.down_blocks.append(IdentityBlock())
+            current_layer += 1
 
-        self.mid_blocks = [
-            ResnetBlock2D(
-                in_channels=block_out_channels[-1],
-                out_channels=block_out_channels[-1],
-                groups=resnet_groups,
-            ),
-            Attention(block_out_channels[-1], resnet_groups),
-            ResnetBlock2D(
-                in_channels=block_out_channels[-1],
-                out_channels=block_out_channels[-1],
-                groups=resnet_groups,
-            ),
-        ]
+        if self.shard.is_last_layer():
+            self.mid_blocks = [
+                ResnetBlock2D(
+                    in_channels=block_out_channels[-1],
+                    out_channels=block_out_channels[-1],
+                    groups=resnet_groups,
+                ),
+                Attention(block_out_channels[-1], resnet_groups),
+                ResnetBlock2D(
+                    in_channels=block_out_channels[-1],
+                    out_channels=block_out_channels[-1],
+                    groups=resnet_groups,
+                ),
+            ]
 
-        self.conv_norm_out = nn.GroupNorm(
-            resnet_groups, block_out_channels[-1], pytorch_compatible=True
-        )
-        self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, 3, padding=1)
+            self.conv_norm_out = nn.GroupNorm(
+                resnet_groups, block_out_channels[-1], pytorch_compatible=True
+            )
+            self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels_out, 3, padding=1)
 
     def __call__(self, x):
-        x = self.conv_in(x)
+        if self.shard.is_first_layer():
+            x = self.conv_in(x)
 
         for l in self.down_blocks:
             x = l(x)
 
-        x = self.mid_blocks[0](x)
-        x = self.mid_blocks[1](x)
-        x = self.mid_blocks[2](x)
+        if self.shard.is_last_layer():
+            x = self.mid_blocks[0](x)
+            x = self.mid_blocks[1](x)
+            x = self.mid_blocks[2](x)
 
-        x = self.conv_norm_out(x)
-        x = nn.silu(x)
-        x = self.conv_out(x)
+            x = self.conv_norm_out(x)
+            x = nn.silu(x)
+            x = self.conv_out(x)
 
         return x
 
@@ -271,7 +284,7 @@ class Decoder(nn.Module):
 class Autoencoder(nn.Module):
     """The autoencoder that allows us to perform diffusion in the latent space."""
 
-    def __init__(self, config: AutoencoderConfig, shard: Shard):
+    def __init__(self, config: AutoencoderConfig, shard: Shard, model_shard: str):
         super().__init__()
         self.shard = shard
         self.start_layer = shard.start_layer
@@ -279,46 +292,51 @@ class Autoencoder(nn.Module):
         self.layers_range = list(range(self.start_layer, self.end_layer+1))
         self.latent_channels = config.latent_channels_in
         self.scaling_factor = config.scaling_factor
-        self.decoder_only = True  # stable diffusion text to speech only uses decoder from the autoencoder
-        if not self.decoder_only:
+        self.model_shard = model_shard
+        if self.model_shard == "vae_encoder":
             self.encoder = Encoder(
                 config.in_channels,
                 config.latent_channels_out,
                 config.block_out_channels,
                 config.layers_per_block,
                 resnet_groups=config.norm_num_groups,
+                layers_range=self.layers_range,
+                shard=shard
             )
-            self.quant_proj = nn.Linear(
-            config.latent_channels_out, config.latent_channels_out
-            )
-        self.decoder = Decoder(
-            config.latent_channels_in,
-            config.out_channels,
-            shard,
-            self.layers_range,
-            config.block_out_channels,
-            config.layers_per_block + 1,
-            resnet_groups=config.norm_num_groups,
-        )
-        if 0 in self.layers_range:
-            self.post_quant_proj = nn.Linear(
-                config.latent_channels_in, config.latent_channels_in
+            if self.shard.is_last_layer():
+                self.quant_proj = nn.Linear(
+                config.latent_channels_out, config.latent_channels_out
+                )
+        if self.model_shard == "vae_decoder":
+            self.decoder = Decoder(
+                config.latent_channels_in,
+                config.out_channels,
+                shard,
+                self.layers_range,
+                config.block_out_channels,
+                config.layers_per_block + 1,
+                resnet_groups=config.norm_num_groups,
             )
+            if self.shard.is_first_layer():
+                self.post_quant_proj = nn.Linear(
+                    config.latent_channels_in, config.latent_channels_in
+                )
 
     def decode(self, z):
-        if 0 in self.layers_range:
+        if self.shard.is_first_layer():
             z = z / self.scaling_factor
             z=self.post_quant_proj(z)
         return self.decoder(z)
 
     def encode(self, x):
         x = self.encoder(x)
-        x = self.quant_proj(x)
-        mean, logvar = x.split(2, axis=-1)
-        mean = mean * self.scaling_factor
-        logvar = logvar + 2 * math.log(self.scaling_factor)
-
-        return mean, logvar
+        if self.shard.is_last_layer():   
+            x = self.quant_proj(x)
+            mean, logvar = x.split(2, axis=-1)
+            mean = mean * self.scaling_factor
+            logvar = logvar + 2 * math.log(self.scaling_factor)
+            x = mean
+        return x
 
     def __call__(self, x, key=None):
         mean, logvar = self.encode(x)
@@ -328,46 +346,53 @@ class Autoencoder(nn.Module):
         return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
 
     def sanitize(self, weights):
+        shard = self.shard
         layers = self.layers_range
         sanitized_weights = {}
         for key, value in weights.items():
-            if 'decoder' in key and self.decoder_only:
-                if "downsamplers" in key:
-                    key = key.replace("downsamplers.0.conv", "downsample")
-                if "upsamplers" in key:
-                    key = key.replace("upsamplers.0.conv", "upsample")
-
-                # Map attention layers
-                if "key" in key:
-                    key = key.replace("key", "key_proj")
-                if "proj_attn" in key:
-                    key = key.replace("proj_attn", "out_proj")
-                if "query" in key:
-                    key = key.replace("query", "query_proj")
-                if "value" in key:
-                    key = key.replace("value", "value_proj")
-
-                # Map the mid block
-                if "mid_block.resnets.0" in key:
-                    key = key.replace("mid_block.resnets.0", "mid_blocks.0")
-                if "mid_block.attentions.0" in key:
-                    key = key.replace("mid_block.attentions.0", "mid_blocks.1")
-                if "mid_block.resnets.1" in key:
-                    key = key.replace("mid_block.resnets.1", "mid_blocks.2")
-        
-                # Map the quant/post_quant layers
-                if "quant_conv" in key:
-                    key = key.replace("quant_conv", "quant_proj")
-                    value = value.squeeze()
-                    
-                # Map the conv_shortcut to linear
-                if "conv_shortcut.weight" in key:
-                    value = value.squeeze()
-
-                if len(value.shape) == 4:
-                    value = value.transpose(0, 2, 3, 1)
-                    value = value.reshape(-1).reshape(value.shape)
 
+            if "downsamplers" in key:
+                key = key.replace("downsamplers.0.conv", "downsample")
+            if "upsamplers" in key:
+                key = key.replace("upsamplers.0.conv", "upsample")
+
+            # Map attention layers
+            if "key" in key:
+                key = key.replace("key", "key_proj")
+            if "proj_attn" in key:
+                key = key.replace("proj_attn", "out_proj")
+            if "query" in key:
+                key = key.replace("query", "query_proj")
+            if "value" in key:
+                key = key.replace("value", "value_proj")
+
+            # Map the mid block
+            if "mid_block.resnets.0" in key:
+                key = key.replace("mid_block.resnets.0", "mid_blocks.0")
+            if "mid_block.attentions.0" in key:
+                key = key.replace("mid_block.attentions.0", "mid_blocks.1")
+            if "mid_block.resnets.1" in key:
+                key = key.replace("mid_block.resnets.1", "mid_blocks.2")
+    
+            # Map the quant/post_quant layers
+            if "quant_conv" in key:
+                key = key.replace("quant_conv", "quant_proj")
+                value = value.squeeze()
+                
+            # Map the conv_shortcut to linear
+            if "conv_shortcut.weight" in key:
+                value = value.squeeze()
+
+            if len(value.shape) == 4:
+                value = value.transpose(0, 2, 3, 1)
+                value = value.reshape(-1).reshape(value.shape)
+
+
+            if "post_quant_conv" in key :
+                key = key.replace("quant_conv", "quant_proj")
+                value = value.squeeze()
+            
+            if 'decoder' in key and self.model_shard == "vae_decoder":
                 if key.startswith("decoder.mid_blocks."):
                     if 0 in layers:
                         sanitized_weights[key] = value
@@ -381,10 +406,24 @@ class Autoencoder(nn.Module):
                     sanitized_weights[key] = value
                 if key.startswith("decoder.conv_out") and 4 in layers:
                     sanitized_weights[key] = value
-                
-            if "post_quant_conv" in key and 0 in layers:
-                key = key.replace("quant_conv", "quant_proj")
-                value = value.squeeze()
-                sanitized_weights[key] = value
+            if self.model_shard == "vae_decoder":
+                if key.startswith("post_quant_proj") and 0 in layers:
+                    sanitized_weights[key] = value
+            if self.model_shard == "vae_encoder":
+                if key.startswith("encoder."):
+                    if "conv_in" in key and shard.is_first_layer():
+                        sanitized_weights[key] = value
+                    if key.startswith("encoder.down_blocks."):
+                        layer_num = int(key.split(".")[2])+1
+                        if layer_num in layers:
+                            sanitized_weights[key] = value
+                    if key.startswith("encoder.mid_blocks.") and shard.is_last_layer():
+                        sanitized_weights[key] = value
+                    if "conv_norm_out" in key and shard.is_last_layer():
+                        sanitized_weights[key] = value
+                    if "conv_out" in key and shard.is_last_layer():
+                        sanitized_weights[key] = value
+                if key.startswith("quant_proj") and shard.is_last_layer():
+                    sanitized_weights[key] = value
         return sanitized_weights
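The +1 on layer_num mirrors current_layer = 1 in the new Encoder constructor: slot 0 of a vae_encoder shard holds conv_in, the down_blocks occupy the following slots, and the mid/out weights load only on the shard's last layer. A sketch of that key-to-slot routing under those assumptions:

# Illustration only; slot layout inferred from sanitize() and Encoder.__init__ above.
from typing import Optional

def encoder_slot(key: str) -> Optional[int]:
    if "conv_in" in key:
        return 0                           # loaded when shard.is_first_layer()
    if key.startswith("encoder.down_blocks."):
        return int(key.split(".")[2]) + 1  # the +1 matches current_layer = 1
    return None                            # mid_blocks/conv_norm_out/conv_out gate on is_last_layer()

print(encoder_slot("encoder.down_blocks.2.resnets.0.conv1.weight"))  # 3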
 

+ 1 - 1
exo/models.py

@@ -81,7 +81,7 @@ model_cards = {
   "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
   "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
   # stable diffusion
-  "stable-diffusion-2-1-base": { "layers": 37, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
+  "stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
   # dummy
   "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
 }
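The drop from 37 to 31 keeps model_cards in sync with the new shard map in StableDiffusionPipeline.py; a one-line consistency check (ranges are half-open):

ranges = {'clip': (0, 12), 'vae_encoder': (12, 17), 'unet': (17, 26), 'vae_decoder': (26, 31)}
assert sum(end - start for start, end in ranges.values()) == 31  # 12 + 5 + 9 + 5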

+ 2 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -68,7 +68,7 @@ class GRPCPeerHandle(PeerHandle):
         traceback.print_exc()
       return False
 
-  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
       shard=node_service_pb2.Shard(
@@ -78,6 +78,7 @@ class GRPCPeerHandle(PeerHandle):
         n_layers=shard.n_layers,
       ),
       request_id=request_id,
+      inference_state=self.serialize_inference_state(inference_state)
     )
     response = await self.stub.SendPrompt(request)
 

+ 2 - 1
exo/networking/grpc/grpc_server.py

@@ -52,7 +52,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     prompt = request.prompt
     request_id = request.request_id
-    result = await self.node.process_prompt(shard, prompt, request_id)
+    inference_state = self.deserialize_inference_state(request.inference_state)
+    result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
     if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()

+ 1 - 0
exo/networking/grpc/node_service.proto

@@ -23,6 +23,7 @@ message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
   optional string request_id = 3;
+  optional InferenceState inference_state = 4;
 }
 
 message TensorRequest {
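The InferenceState message itself is defined elsewhere in the proto and is not shown in this diff, and neither are the serialize_inference_state/deserialize_inference_state helpers called in grpc_peer_handle.py and grpc_server.py. Whatever their exact shape, they must round-trip mx arrays (such as the decoded image) as bytes plus shape and dtype; an illustrative sketch, not the actual implementation:

# Hypothetical round-trip for mx.array values inside inference_state.
import numpy as np
import mlx.core as mx

def to_wire(state: dict) -> dict:
    out = {}
    for k, v in state.items():
        if isinstance(v, mx.array):
            arr = np.array(v)
            out[k] = {"bytes": arr.tobytes(), "shape": arr.shape, "dtype": str(arr.dtype)}
        else:
            out[k] = v
    return out

def from_wire(state: dict) -> dict:
    out = {}
    for k, v in state.items():
        if isinstance(v, dict) and "bytes" in v:
            out[k] = mx.array(np.frombuffer(v["bytes"], dtype=v["dtype"]).reshape(v["shape"]))
        else:
            out[k] = v
    return out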

Diff file suppressed because it is too large
+ 1 - 1
exo/networking/grpc/node_service_pb2.py


+ 44 - 44
exo/networking/grpc/node_service_pb2_grpc.py

@@ -3,7 +3,7 @@
 import grpc
 import warnings
 
-from exo.networking.grpc import node_service_pb2 as exo_dot_networking_dot_grpc_dot_node__service__pb2
+from exo.networking.grpc import node_service_pb2 as node__service__pb2
 
 GRPC_GENERATED_VERSION = '1.64.1'
 GRPC_VERSION = grpc.__version__
@@ -20,7 +20,7 @@ except ImportError:
 if _version_not_supported:
     warnings.warn(
         f'The grpc package installed is at version {GRPC_VERSION},'
-        + f' but the generated code in exo/networking/grpc/node_service_pb2_grpc.py depends on'
+        + f' but the generated code in node_service_pb2_grpc.py depends on'
         + f' grpcio>={GRPC_GENERATED_VERSION}.'
         + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
         + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
@@ -41,38 +41,38 @@ class NodeServiceStub(object):
         """
         self.SendPrompt = channel.unary_unary(
                 '/node_service.NodeService/SendPrompt',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
+                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
                 _registered_method=True)
         self.SendTensor = channel.unary_unary(
                 '/node_service.NodeService/SendTensor',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
+                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
                 _registered_method=True)
         self.GetInferenceResult = channel.unary_unary(
                 '/node_service.NodeService/GetInferenceResult',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
+                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.InferenceResult.FromString,
                 _registered_method=True)
         self.CollectTopology = channel.unary_unary(
                 '/node_service.NodeService/CollectTopology',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
+                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Topology.FromString,
                 _registered_method=True)
         self.SendResult = channel.unary_unary(
                 '/node_service.NodeService/SendResult',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
+                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.SendOpaqueStatus = channel.unary_unary(
                 '/node_service.NodeService/SendOpaqueStatus',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
+                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.HealthCheck = channel.unary_unary(
                 '/node_service.NodeService/HealthCheck',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
+                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
+                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
                 _registered_method=True)
 
 
@@ -126,38 +126,38 @@ def add_NodeServiceServicer_to_server(servicer, server):
     rpc_method_handlers = {
             'SendPrompt': grpc.unary_unary_rpc_method_handler(
                     servicer.SendPrompt,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
+                    request_deserializer=node__service__pb2.PromptRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
             ),
             'SendTensor': grpc.unary_unary_rpc_method_handler(
                     servicer.SendTensor,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
+                    request_deserializer=node__service__pb2.TensorRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
             ),
             'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
                     servicer.GetInferenceResult,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.SerializeToString,
+                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
             ),
             'CollectTopology': grpc.unary_unary_rpc_method_handler(
                     servicer.CollectTopology,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.SerializeToString,
+                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+                    response_serializer=node__service__pb2.Topology.SerializeToString,
             ),
             'SendResult': grpc.unary_unary_rpc_method_handler(
                     servicer.SendResult,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
+                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
             'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
                     servicer.SendOpaqueStatus,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
+                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
             'HealthCheck': grpc.unary_unary_rpc_method_handler(
                     servicer.HealthCheck,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.SerializeToString,
+                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
+                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
             ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
@@ -185,8 +185,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendPrompt',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
+            node__service__pb2.PromptRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
             options,
             channel_credentials,
             insecure,
@@ -212,8 +212,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendTensor',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
+            node__service__pb2.TensorRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
             options,
             channel_credentials,
             insecure,
@@ -239,8 +239,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/GetInferenceResult',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
+            node__service__pb2.GetInferenceResultRequest.SerializeToString,
+            node__service__pb2.InferenceResult.FromString,
             options,
             channel_credentials,
             insecure,
@@ -266,8 +266,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/CollectTopology',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
+            node__service__pb2.CollectTopologyRequest.SerializeToString,
+            node__service__pb2.Topology.FromString,
             options,
             channel_credentials,
             insecure,
@@ -293,8 +293,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendResult',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
+            node__service__pb2.SendResultRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
             options,
             channel_credentials,
             insecure,
@@ -320,8 +320,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendOpaqueStatus',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
+            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
             options,
             channel_credentials,
             insecure,
@@ -347,8 +347,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/HealthCheck',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
+            node__service__pb2.HealthCheckRequest.SerializeToString,
+            node__service__pb2.HealthCheckResponse.FromString,
             options,
             channel_credentials,
             insecure,

+ 2 - 2
exo/orchestration/node.py

@@ -16,11 +16,11 @@ class Node(ABC):
     pass
 
   @abstractmethod
-  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
     pass
 
   @abstractmethod
-  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
     pass
 
   @abstractmethod

+ 3 - 2
exo/orchestration/standard_node.py

@@ -190,7 +190,7 @@ class StandardNode(Node):
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
     if not shard.is_first_layer():
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      resp = await self.forward_prompt(shard, prompt, request_id, 0)
+      resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state)
       return None
     else:
       result,inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
@@ -268,6 +268,7 @@ class StandardNode(Node):
     prompt: str,
     request_id: str,
     target_index: int,
+    inference_state: Optional[dict] = None,
   ) -> None:
     if DEBUG >= 1: print(f"target partition index: {target_index}")
     target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
@@ -280,7 +281,7 @@ class StandardNode(Node):
       if not target_peer:
         raise ValueError(f"Peer for {target_index} not found")
       if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
-      await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+      await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)
   
   async def forward_tensor(
     self,

BIN
exo/tinychat/images/8014d04e-b85a-44a2-88a3-29091c42bff5.png


+ 0 - 3
exo/tinychat/images/README.md

@@ -1,3 +0,0 @@
-# images dir
-
-Images generated in tinychat are stored and served from here.

+ 11 - 1
exo/tinychat/index.html

@@ -120,6 +120,16 @@
                 const img = document.createElement('img');
                 img.src = imageUrl;
                 img.alt = 'Generated Image';
+                img.onclick = async () => {
+                  try {
+                    const response = await fetch(img.src);
+                    const blob = await response.blob();
+                    const file = new File([blob], 'image.png', { type: 'image/png' });
+                    handleImageUpload({ target: { files: [file] } });
+                  } catch (error) {
+                    console.error('Error fetching image:', error);
+                  }
+                };
                 div.appendChild(img);
               } else {
                 div.innerHTML = DOMPurify.sanitize(marked.parse(content));
@@ -207,7 +217,7 @@
 </span>
 </div>
 <div class="input">
-<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf'">
+<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf' || cstate.selectedModel === 'stable-diffusion-2-1-base'">
 <i class="fas fa-image"></i>
 </button>
 <input @change="$data.handleImageUpload($event)" accept="image/*" id="image-upload" style="display: none;" type="file" x-ref="imageUpload"/>

+ 1 - 0
exo/tinychat/index.js

@@ -243,6 +243,7 @@ document.addEventListener("alpine:init", () => {
             body: JSON.stringify({
               "model": 'stable-diffusion-2-1-base',
               "prompt": apiMessages[apiMessages.length - 1].content,
+              "image_url": this.imageUrl
             }),
           });
       

Some files were not shown because too many files changed in this diff