unet.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. from tinygrad import Tensor, dtypes
  2. from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm
  3. from typing import Optional, Union, List, Any, Tuple
  4. import math
  5. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
  6. def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
  7. half = dim // 2
  8. freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
  9. args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
  10. return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16)
  11. class ResBlock:
  12. def __init__(self, channels:int, emb_channels:int, out_channels:int):
  13. self.in_layers = [
  14. GroupNorm(32, channels),
  15. Tensor.silu,
  16. Conv2d(channels, out_channels, 3, padding=1),
  17. ]
  18. self.emb_layers = [
  19. Tensor.silu,
  20. Linear(emb_channels, out_channels),
  21. ]
  22. self.out_layers = [
  23. GroupNorm(32, out_channels),
  24. Tensor.silu,
  25. lambda x: x, # needed for weights loading code to work
  26. Conv2d(out_channels, out_channels, 3, padding=1),
  27. ]
  28. self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else (lambda x: x)
  29. def __call__(self, x:Tensor, emb:Tensor) -> Tensor:
  30. h = x.sequential(self.in_layers)
  31. emb_out = emb.sequential(self.emb_layers)
  32. h = h + emb_out.reshape(*emb_out.shape, 1, 1)
  33. h = h.sequential(self.out_layers)
  34. return self.skip_connection(x) + h
  35. class CrossAttention:
  36. def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int):
  37. self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
  38. self.to_k = Linear(ctx_dim, n_heads*d_head, bias=False)
  39. self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False)
  40. self.num_heads = n_heads
  41. self.head_size = d_head
  42. self.to_out = [Linear(n_heads*d_head, query_dim)]
  43. def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
  44. ctx = x if ctx is None else ctx
  45. q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx)
  46. q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
  47. attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2)
  48. h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
  49. return h_.sequential(self.to_out)
  50. class GEGLU:
  51. def __init__(self, dim_in:int, dim_out:int):
  52. self.proj = Linear(dim_in, dim_out * 2)
  53. self.dim_out = dim_out
  54. def __call__(self, x:Tensor) -> Tensor:
  55. x, gate = self.proj(x).chunk(2, dim=-1)
  56. return x * gate.gelu()
  57. class FeedForward:
  58. def __init__(self, dim:int, mult:int=4):
  59. self.net = [
  60. GEGLU(dim, dim*mult),
  61. lambda x: x, # needed for weights loading code to work
  62. Linear(dim*mult, dim)
  63. ]
  64. def __call__(self, x:Tensor) -> Tensor:
  65. return x.sequential(self.net)
  66. class BasicTransformerBlock:
  67. def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int):
  68. self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
  69. self.ff = FeedForward(dim)
  70. self.attn2 = CrossAttention(dim, ctx_dim, n_heads, d_head)
  71. self.norm1 = LayerNorm(dim)
  72. self.norm2 = LayerNorm(dim)
  73. self.norm3 = LayerNorm(dim)
  74. def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
  75. x = x + self.attn1(self.norm1(x))
  76. x = x + self.attn2(self.norm2(x), ctx=ctx)
  77. x = x + self.ff(self.norm3(x))
  78. return x
  79. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/attention.py#L619
  80. class SpatialTransformer:
  81. def __init__(self, channels:int, n_heads:int, d_head:int, ctx_dim:Union[int,List[int]], use_linear:bool, depth:int=1):
  82. if isinstance(ctx_dim, int):
  83. ctx_dim = [ctx_dim]*depth
  84. else:
  85. assert isinstance(ctx_dim, list) and depth == len(ctx_dim)
  86. self.norm = GroupNorm(32, channels)
  87. assert channels == n_heads * d_head
  88. self.proj_in = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
  89. self.transformer_blocks = [BasicTransformerBlock(channels, ctx_dim[d], n_heads, d_head) for d in range(depth)]
  90. self.proj_out = Linear(channels, channels) if use_linear else Conv2d(channels, channels, 1)
  91. self.use_linear = use_linear
  92. def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor:
  93. b, c, h, w = x.shape
  94. x_in = x
  95. x = self.norm(x)
  96. ops = [ (lambda z: z.reshape(b, c, h*w).permute(0,2,1)), (lambda z: self.proj_in(z)) ]
  97. x = x.sequential(ops if self.use_linear else ops[::-1])
  98. for block in self.transformer_blocks:
  99. x = block(x, ctx=ctx)
  100. ops = [ (lambda z: self.proj_out(z)), (lambda z: z.permute(0,2,1).reshape(b, c, h, w)) ]
  101. x = x.sequential(ops if self.use_linear else ops[::-1])
  102. return x + x_in
  103. class Downsample:
  104. def __init__(self, channels:int):
  105. self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
  106. def __call__(self, x:Tensor) -> Tensor:
  107. return self.op(x)
  108. class Upsample:
  109. def __init__(self, channels:int):
  110. self.conv = Conv2d(channels, channels, 3, padding=1)
  111. def __call__(self, x:Tensor) -> Tensor:
  112. bs,c,py,px = x.shape
  113. z = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
  114. return self.conv(z)
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/openaimodel.py#L472
class UNetModel:
  """Diffusion UNet: conv stem, downsampling path, middle block, and upsampling path
  with channel-concatenated skip connections. SpatialTransformer (cross-attention)
  blocks are inserted at the resolutions listed in `attention_resolutions`;
  conditioning enters via the timestep embedding (and optional label embedding)."""
  def __init__(self, adm_in_ch:Optional[int], in_ch:int, out_ch:int, model_ch:int, attention_resolutions:List[int], num_res_blocks:int, channel_mult:List[int], transformer_depth:List[int], ctx_dim:Union[int,List[int]], use_linear:bool=False, d_head:Optional[int]=None, n_heads:Optional[int]=None):
    self.model_ch = model_ch
    # same number of ResBlocks at every resolution level
    self.num_res_blocks = [num_res_blocks] * len(channel_mult)
    self.attention_resolutions = attention_resolutions
    self.d_head = d_head
    self.n_heads = n_heads
    # exactly one of d_head / n_heads must be given; the other is derived per channel count
    def get_d_and_n_heads(dims:int) -> Tuple[int,int]:
      if self.d_head is None:
        assert self.n_heads is not None, f"d_head and n_heads cannot both be None"
        return dims // self.n_heads, self.n_heads
      else:
        assert self.n_heads is None, f"d_head and n_heads cannot both be non-None"
        return self.d_head, dims // self.d_head
    time_embed_dim = model_ch * 4
    self.time_embed = [
      Linear(model_ch, time_embed_dim),
      Tensor.silu,
      Linear(time_embed_dim, time_embed_dim),
    ]
    if adm_in_ch is not None:
      # nested list so weight keys come out as "label_emb.0.*" — presumably to match
      # the reference checkpoint layout; verify against the weights loading code
      self.label_emb = [
        [
          Linear(adm_in_ch, time_embed_dim),
          Tensor.silu,
          Linear(time_embed_dim, time_embed_dim),
        ]
      ]
    self.input_blocks: List[Any] = [
      [Conv2d(in_ch, model_ch, 3, padding=1)]
    ]
    input_block_channels = [model_ch]  # channel count recorded for each skip connection
    ch = model_ch  # channels flowing out of the most recent block
    ds = 1         # current downsampling factor (1, 2, 4, ...)
    for idx, mult in enumerate(channel_mult):
      for _ in range(self.num_res_blocks[idx]):
        layers: List[Any] = [
          ResBlock(ch, time_embed_dim, model_ch*mult),
        ]
        ch = mult * model_ch
        if ds in attention_resolutions:
          d_head, n_heads = get_d_and_n_heads(ch)
          layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx]))
        self.input_blocks.append(layers)
        input_block_channels.append(ch)
      if idx != len(channel_mult) - 1:  # no downsample after the deepest level
        self.input_blocks.append([
          Downsample(ch),
        ])
        input_block_channels.append(ch)
        ds *= 2
    d_head, n_heads = get_d_and_n_heads(ch)
    self.middle_block: List = [
      ResBlock(ch, time_embed_dim, ch),
      SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[-1]),
      ResBlock(ch, time_embed_dim, ch),
    ]
    self.output_blocks = []
    # mirror of the input path; each output block consumes one saved skip tensor,
    # so its first ResBlock sees ch + ich input channels
    for idx, mult in list(enumerate(channel_mult))[::-1]:
      for i in range(self.num_res_blocks[idx] + 1):
        ich = input_block_channels.pop()
        layers = [
          ResBlock(ch + ich, time_embed_dim, model_ch*mult),
        ]
        ch = model_ch * mult
        if ds in attention_resolutions:
          d_head, n_heads = get_d_and_n_heads(ch)
          layers.append(SpatialTransformer(ch, n_heads, d_head, ctx_dim, use_linear, depth=transformer_depth[idx]))
        if idx > 0 and i == self.num_res_blocks[idx]:  # upsample at the end of each level except the top
          layers.append(Upsample(ch))
          ds //= 2
        self.output_blocks.append(layers)
    # NOTE(review): GroupNorm takes ch but Conv2d takes model_ch — these agree only
    # when channel_mult[0] == 1 (true for the SD configs this mirrors); confirm if reused
    self.out = [
      GroupNorm(32, ch),
      Tensor.silu,
      Conv2d(model_ch, out_ch, 3, padding=1),
    ]
  def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Optional[Tensor]=None) -> Tensor:
    # x: latent image (B, C, H, W); tms: per-sample timesteps; ctx: cross-attention
    # context; y: optional extra conditioning (requires adm_in_ch at construction)
    t_emb = timestep_embedding(tms, self.model_ch).cast(dtypes.float16)  # cast is redundant (timestep_embedding already returns fp16) but harmless
    emb = t_emb.sequential(self.time_embed)
    if y is not None:
      assert y.shape[0] == x.shape[0]
      emb = emb + y.sequential(self.label_emb[0])
    # run the whole network in half precision
    emb = emb.cast(dtypes.float16)
    ctx = ctx.cast(dtypes.float16)
    x = x.cast(dtypes.float16)
    # dispatch on block type: ResBlocks take the timestep embedding,
    # SpatialTransformers take the cross-attention context, everything else just x
    def run(x:Tensor, bb) -> Tensor:
      if isinstance(bb, ResBlock): x = bb(x, emb)
      elif isinstance(bb, SpatialTransformer): x = bb(x, ctx)
      else: x = bb(x)
      return x
    saved_inputs = []
    for b in self.input_blocks:
      for bb in b:
        x = run(x, bb)
      saved_inputs.append(x)  # one skip tensor per input block
    for bb in self.middle_block:
      x = run(x, bb)
    for b in self.output_blocks:
      x = x.cat(saved_inputs.pop(), dim=1)  # concatenate skip on the channel axis
      for bb in b:
        x = run(x, bb)
    return x.sequential(self.out)