@@ -128,62 +128,75 @@ class Encoder(nn.Module):
   def __init__(
     self,
     in_channels: int,
-    out_channels: int,
+    latent_channels_out: int,
     block_out_channels: List[int] = [64],
     layers_per_block: int = 2,
     resnet_groups: int = 32,
+    layers_range: List[int] = [],
+    shard: Shard = Shard("", 0, 0, 0),
   ):
     super().__init__()
-
-    self.conv_in = nn.Conv2d(
-      in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
-    )
+    self.layers_range = layers_range
+    self.shard = shard
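+    # Only the first shard in the pipeline owns conv_in; later shards
+    # receive activations that are already embedded.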
+    if self.shard.is_first_layer():
+      self.conv_in = nn.Conv2d(
+        in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
+      )
 
     channels = [block_out_channels[0]] + list(block_out_channels)
-    self.down_blocks = [
-      EncoderDecoderBlock2D(
-        in_channels,
-        out_channels,
-        num_layers=layers_per_block,
-        resnet_groups=resnet_groups,
-        add_downsample=i < len(block_out_channels) - 1,
-        add_upsample=False,
-      )
-      for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
-    ]
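+    # Blocks outside this shard's layers_range become IdentityBlock
+    # placeholders, so module indices still line up with the weight names.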
+    self.down_blocks = []
+    current_layer = 1
+    for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
+      if current_layer in self.layers_range:
+        self.down_blocks.append(
+          EncoderDecoderBlock2D(
+            in_channels,
+            out_channels,
+            num_layers=layers_per_block,
+            resnet_groups=resnet_groups,
+            add_downsample=i < len(block_out_channels) - 1,
+            add_upsample=False,
+          )
+        )
+      else:
+        self.down_blocks.append(IdentityBlock())
+      current_layer += 1
 
-    self.mid_blocks = [
-      ResnetBlock2D(
-        in_channels=block_out_channels[-1],
-        out_channels=block_out_channels[-1],
-        groups=resnet_groups,
-      ),
-      Attention(block_out_channels[-1], resnet_groups),
-      ResnetBlock2D(
-        in_channels=block_out_channels[-1],
-        out_channels=block_out_channels[-1],
-        groups=resnet_groups,
-      ),
-    ]
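+    # Only the shard holding the last encoder layer builds the mid blocks
+    # and the output head.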
+    if self.shard.is_last_layer():
+      self.mid_blocks = [
+        ResnetBlock2D(
+          in_channels=block_out_channels[-1],
+          out_channels=block_out_channels[-1],
+          groups=resnet_groups,
+        ),
+        Attention(block_out_channels[-1], resnet_groups),
+        ResnetBlock2D(
+          in_channels=block_out_channels[-1],
+          out_channels=block_out_channels[-1],
+          groups=resnet_groups,
+        ),
+      ]
 
-    self.conv_norm_out = nn.GroupNorm(
-      resnet_groups, block_out_channels[-1], pytorch_compatible=True
-    )
-    self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, 3, padding=1)
+      self.conv_norm_out = nn.GroupNorm(
+        resnet_groups, block_out_channels[-1], pytorch_compatible=True
+      )
+      self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels_out, 3, padding=1)
 
   def __call__(self, x):
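+    # Each shard applies only the stages it owns; the tensor passes through
+    # the remaining branches untouched.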
-    x = self.conv_in(x)
+    if self.shard.is_first_layer():
+      x = self.conv_in(x)
 
     for l in self.down_blocks:
       x = l(x)
 
-    x = self.mid_blocks[0](x)
-    x = self.mid_blocks[1](x)
-    x = self.mid_blocks[2](x)
+    if self.shard.is_last_layer():
+      x = self.mid_blocks[0](x)
+      x = self.mid_blocks[1](x)
+      x = self.mid_blocks[2](x)
 
-    x = self.conv_norm_out(x)
-    x = nn.silu(x)
-    x = self.conv_out(x)
+      x = self.conv_norm_out(x)
+      x = nn.silu(x)
+      x = self.conv_out(x)
 
     return x
 
@@ -271,7 +284,7 @@ class Decoder(nn.Module):
 class Autoencoder(nn.Module):
   """The autoencoder that allows us to perform diffusion in the latent space."""
 
-  def __init__(self, config: AutoencoderConfig, shard: Shard):
+  def __init__(self, config: AutoencoderConfig, shard: Shard, model_shard: str):
     super().__init__()
     self.shard = shard
     self.start_layer = shard.start_layer
@@ -279,46 +292,51 @@ class Autoencoder(nn.Module):
     self.layers_range = list(range(self.start_layer, self.end_layer+1))
     self.latent_channels = config.latent_channels_in
    self.scaling_factor = config.scaling_factor
-    self.decoder_only = True  # stable diffusion text-to-image only uses the decoder of the autoencoder
-    if not self.decoder_only:
+    self.model_shard = model_shard
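+    # A node hosts either the encoder stage or the decoder stage of the VAE,
+    # so it only builds (and loads weights for) the half it actually runs.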
+    if self.model_shard == "vae_encoder":
       self.encoder = Encoder(
         config.in_channels,
         config.latent_channels_out,
         config.block_out_channels,
         config.layers_per_block,
         resnet_groups=config.norm_num_groups,
+        layers_range=self.layers_range,
+        shard=shard,
       )
-      self.quant_proj = nn.Linear(
-        config.latent_channels_out, config.latent_channels_out
-      )
-    self.decoder = Decoder(
-      config.latent_channels_in,
-      config.out_channels,
-      shard,
-      self.layers_range,
-      config.block_out_channels,
-      config.layers_per_block + 1,
-      resnet_groups=config.norm_num_groups,
-    )
-    if 0 in self.layers_range:
-      self.post_quant_proj = nn.Linear(
-        config.latent_channels_in, config.latent_channels_in
+      if self.shard.is_last_layer():
+        self.quant_proj = nn.Linear(
+          config.latent_channels_out, config.latent_channels_out
+        )
+    if self.model_shard == "vae_decoder":
+      self.decoder = Decoder(
+        config.latent_channels_in,
+        config.out_channels,
+        shard,
+        self.layers_range,
+        config.block_out_channels,
+        config.layers_per_block + 1,
+        resnet_groups=config.norm_num_groups,
       )
+      if self.shard.is_first_layer():
+        self.post_quant_proj = nn.Linear(
+          config.latent_channels_in, config.latent_channels_in
+        )
 
   def decode(self, z):
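+    # The latent rescale and post_quant projection run once, on the shard
+    # that holds the first decoder layer.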
-    if 0 in self.layers_range:
+    if self.shard.is_first_layer():
       z = z / self.scaling_factor
       z = self.post_quant_proj(z)
     return self.decoder(z)
 
   def encode(self, x):
     x = self.encoder(x)
-    x = self.quant_proj(x)
-    mean, logvar = x.split(2, axis=-1)
-    mean = mean * self.scaling_factor
-    logvar = logvar + 2 * math.log(self.scaling_factor)
-
-    return mean, logvar
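+    # Intermediate encoder shards return the raw feature map; only the last
+    # shard projects to (mean, logvar) and returns the scaled mean.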
+    if self.shard.is_last_layer():
+      x = self.quant_proj(x)
+      mean, logvar = x.split(2, axis=-1)
+      mean = mean * self.scaling_factor
+      logvar = logvar + 2 * math.log(self.scaling_factor)
+      x = mean
+    return x
 
   def __call__(self, x, key=None):
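+    # NOTE: encode() now returns a single tensor, so this (mean, logvar)
+    # unpacking no longer matches; __call__ appears unused in the sharded path.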
     mean, logvar = self.encode(x)
@@ -328,46 +346,53 @@ class Autoencoder(nn.Module):
     return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
 
   def sanitize(self, weights):
+    shard = self.shard
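+    # Keep only the weights that belong to this node's stage (encoder or
+    # decoder) and to the layers inside its shard range.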
     layers = self.layers_range
     sanitized_weights = {}
     for key, value in weights.items():
-      if 'decoder' in key and self.decoder_only:
-        if "downsamplers" in key:
-          key = key.replace("downsamplers.0.conv", "downsample")
-        if "upsamplers" in key:
-          key = key.replace("upsamplers.0.conv", "upsample")
-
-        # Map attention layers
-        if "key" in key:
-          key = key.replace("key", "key_proj")
-        if "proj_attn" in key:
-          key = key.replace("proj_attn", "out_proj")
-        if "query" in key:
-          key = key.replace("query", "query_proj")
-        if "value" in key:
-          key = key.replace("value", "value_proj")
-
-        # Map the mid block
-        if "mid_block.resnets.0" in key:
-          key = key.replace("mid_block.resnets.0", "mid_blocks.0")
-        if "mid_block.attentions.0" in key:
-          key = key.replace("mid_block.attentions.0", "mid_blocks.1")
-        if "mid_block.resnets.1" in key:
-          key = key.replace("mid_block.resnets.1", "mid_blocks.2")
-
-        # Map the quant/post_quant layers
-        if "quant_conv" in key:
-          key = key.replace("quant_conv", "quant_proj")
-          value = value.squeeze()
-
-        # Map the conv_shortcut to linear
-        if "conv_shortcut.weight" in key:
-          value = value.squeeze()
-
-        if len(value.shape) == 4:
-          value = value.transpose(0, 2, 3, 1)
-          value = value.reshape(-1).reshape(value.shape)
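+      # The HF-to-MLX key remapping now runs for every weight rather than
+      # only decoder keys, since encoder shards need the same renames.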
+      if "downsamplers" in key:
+        key = key.replace("downsamplers.0.conv", "downsample")
+      if "upsamplers" in key:
+        key = key.replace("upsamplers.0.conv", "upsample")
+
+      # Map attention layers
+      if "key" in key:
+        key = key.replace("key", "key_proj")
+      if "proj_attn" in key:
+        key = key.replace("proj_attn", "out_proj")
+      if "query" in key:
+        key = key.replace("query", "query_proj")
+      if "value" in key:
+        key = key.replace("value", "value_proj")
+
+      # Map the mid block
+      if "mid_block.resnets.0" in key:
+        key = key.replace("mid_block.resnets.0", "mid_blocks.0")
+      if "mid_block.attentions.0" in key:
+        key = key.replace("mid_block.attentions.0", "mid_blocks.1")
+      if "mid_block.resnets.1" in key:
+        key = key.replace("mid_block.resnets.1", "mid_blocks.2")
+
+      # Map the quant/post_quant layers
+      if "quant_conv" in key:
+        key = key.replace("quant_conv", "quant_proj")
+        value = value.squeeze()
+
+      # Map the conv_shortcut to linear
+      if "conv_shortcut.weight" in key:
+        value = value.squeeze()
+
+      if len(value.shape) == 4:
+        value = value.transpose(0, 2, 3, 1)
+        value = value.reshape(-1).reshape(value.shape)
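+
+      # Note: the "quant_conv" branch above already renamed
+      # "post_quant_conv" keys, so this extra pass is just a safeguard.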
+      if "post_quant_conv" in key:
+        key = key.replace("quant_conv", "quant_proj")
+        value = value.squeeze()
+
+      if 'decoder' in key and self.model_shard == "vae_decoder":
         if key.startswith("decoder.mid_blocks."):
           if 0 in layers:
             sanitized_weights[key] = value
@@ -381,10 +406,24 @@ class Autoencoder(nn.Module):
            sanitized_weights[key] = value
         if key.startswith("decoder.conv_out") and 4 in layers:
           sanitized_weights[key] = value
-
-      if "post_quant_conv" in key and 0 in layers:
-        key = key.replace("quant_conv", "quant_proj")
-        value = value.squeeze()
-        sanitized_weights[key] = value
+      if self.model_shard == "vae_decoder":
+        if key.startswith("post_quant_proj") and 0 in layers:
+          sanitized_weights[key] = value
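+      # Encoder weights: encoder.down_blocks.N maps to pipeline layer N+1;
+      # conv_in belongs to the first shard, the heads to the last shard.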
+      if self.model_shard == "vae_encoder":
+        if key.startswith("encoder."):
+          if "conv_in" in key and shard.is_first_layer():
+            sanitized_weights[key] = value
+          if key.startswith("encoder.down_blocks."):
+            layer_num = int(key.split(".")[2]) + 1
+            if layer_num in layers:
+              sanitized_weights[key] = value
+          if key.startswith("encoder.mid_blocks.") and shard.is_last_layer():
+            sanitized_weights[key] = value
+          if "conv_norm_out" in key and shard.is_last_layer():
+            sanitized_weights[key] = value
+          if "conv_out" in key and shard.is_last_layer():
+            sanitized_weights[key] = value
+        if key.startswith("quant_proj") and shard.is_last_layer():
+          sanitized_weights[key] = value
     return sanitized_weights