# sdxl.py
  1. # This file incorporates code from the following:
  2. # Github Name | License | Link
  3. # Stability-AI/generative-models | MIT | https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/LICENSE-CODE
  4. # mlfoundations/open_clip | MIT | https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/LICENSE
  5. from tinygrad import Tensor, TinyJit, dtypes
  6. from tinygrad.nn import Conv2d, GroupNorm
  7. from tinygrad.nn.state import safe_load, load_state_dict
  8. from tinygrad.helpers import fetch, trange, colored, Timing, GlobalCounters
  9. from extra.models.clip import Embedder, FrozenClosedClipEmbedder, FrozenOpenClipEmbedder
  10. from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embedding
  11. from examples.stable_diffusion import ResnetBlock, Mid
  12. import numpy as np
  13. from typing import Dict, List, Callable, Optional, Any, Set, Tuple
  14. import argparse, tempfile
  15. from abc import ABC, abstractmethod
  16. from pathlib import Path
  17. from PIL import Image
  18. # configs:
  19. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/configs/inference/sd_xl_base.yaml
  20. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/configs/inference/sd_xl_refiner.yaml
  21. configs: Dict = {
  22. "SDXL_Base": {
  23. "model": {"adm_in_ch": 2816, "in_ch": 4, "out_ch": 4, "model_ch": 320, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4], "d_head": 64, "transformer_depth": [1, 2, 10], "ctx_dim": 2048, "use_linear": True},
  24. "conditioner": {"concat_embedders": ["original_size_as_tuple", "crop_coords_top_left", "target_size_as_tuple"]},
  25. "first_stage_model": {"ch": 128, "in_ch": 3, "out_ch": 3, "z_ch": 4, "ch_mult": [1, 2, 4, 4], "num_res_blocks": 2, "resolution": 256},
  26. "denoiser": {"num_idx": 1000},
  27. },
  28. "SDXL_Refiner": {
  29. "model": {"adm_in_ch": 2560, "in_ch": 4, "out_ch": 4, "model_ch": 384, "attention_resolutions": [4, 2], "num_res_blocks": 2, "channel_mult": [1, 2, 4, 4], "d_head": 64, "transformer_depth": [4, 4, 4, 4], "ctx_dim": [1280, 1280, 1280, 1280], "use_linear": True},
  30. "conditioner": {"concat_embedders": ["original_size_as_tuple", "crop_coords_top_left", "aesthetic_score"]},
  31. "first_stage_model": {"ch": 128, "in_ch": 3, "out_ch": 3, "z_ch": 4, "ch_mult": [1, 2, 4, 4], "num_res_blocks": 2, "resolution": 256},
  32. "denoiser": {"num_idx": 1000},
  33. }
  34. }
  35. def tensor_identity(x:Tensor) -> Tensor:
  36. return x
  37. class DiffusionModel:
  38. def __init__(self, *args, **kwargs):
  39. self.diffusion_model = UNetModel(*args, **kwargs)
  40. class Embedder(ABC):
  41. input_key: str
  42. @abstractmethod
  43. def __call__(self, x:Tensor) -> Tensor:
  44. pass
  45. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L913
  46. class ConcatTimestepEmbedderND(Embedder):
  47. def __init__(self, outdim:int, input_key:str):
  48. self.outdim = outdim
  49. self.input_key = input_key
  50. def __call__(self, x:Tensor):
  51. assert len(x.shape) == 2
  52. emb = timestep_embedding(x.flatten(), self.outdim)
  53. emb = emb.reshape((x.shape[0],-1))
  54. return emb
  55. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L71
  56. class Conditioner:
  57. OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
  58. KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
  59. embedders: List[Embedder]
  60. def __init__(self, concat_embedders:List[str]):
  61. self.embedders = [
  62. FrozenClosedClipEmbedder(ret_layer_idx=11),
  63. FrozenOpenClipEmbedder(dims=1280, n_heads=20, layers=32, return_pooled=True),
  64. ]
  65. for input_key in concat_embedders:
  66. self.embedders.append(ConcatTimestepEmbedderND(256, input_key))
  67. def get_keys(self) -> Set[str]:
  68. return set(e.input_key for e in self.embedders)
  69. def __call__(self, batch:Dict, force_zero_embeddings:List=[]) -> Dict[str,Tensor]:
  70. output: Dict[str,Tensor] = {}
  71. for embedder in self.embedders:
  72. emb_out = embedder(batch[embedder.input_key])
  73. if isinstance(emb_out, Tensor):
  74. emb_out = [emb_out]
  75. else:
  76. assert isinstance(emb_out, (list, tuple))
  77. for emb in emb_out:
  78. if embedder.input_key in force_zero_embeddings:
  79. emb = Tensor.zeros_like(emb)
  80. out_key = self.OUTPUT_DIM2KEYS[len(emb.shape)]
  81. if out_key in output:
  82. output[out_key] = Tensor.cat(output[out_key], emb, dim=self.KEY2CATDIM[out_key])
  83. else:
  84. output[out_key] = emb
  85. return output
  86. class FirstStage:
  87. """
  88. Namespace for First Stage Model components
  89. """
  90. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/model.py#L487
  91. class Encoder:
  92. def __init__(self, ch:int, in_ch:int, out_ch:int, z_ch:int, ch_mult:List[int], num_res_blocks:int, resolution:int):
  93. self.conv_in = Conv2d(in_ch, ch, kernel_size=3, stride=1, padding=1)
  94. in_ch_mult = (1,) + tuple(ch_mult)
  95. class BlockEntry:
  96. def __init__(self, block:List[ResnetBlock], downsample):
  97. self.block = block
  98. self.downsample = downsample
  99. self.down: List[BlockEntry] = []
  100. for i_level in range(len(ch_mult)):
  101. block = []
  102. block_in = ch * in_ch_mult[i_level]
  103. block_out = ch * ch_mult [i_level]
  104. for _ in range(num_res_blocks):
  105. block.append(ResnetBlock(block_in, block_out))
  106. block_in = block_out
  107. downsample = tensor_identity if (i_level == len(ch_mult)-1) else Downsample(block_in)
  108. self.down.append(BlockEntry(block, downsample))
  109. self.mid = Mid(block_in)
  110. self.norm_out = GroupNorm(32, block_in)
  111. self.conv_out = Conv2d(block_in, 2*z_ch, kernel_size=3, stride=1, padding=1)
  112. def __call__(self, x:Tensor) -> Tensor:
  113. h = self.conv_in(x)
  114. for down in self.down:
  115. for block in down.block:
  116. h = block(h)
  117. h = down.downsample(h)
  118. h = h.sequential([self.mid.block_1, self.mid.attn_1, self.mid.block_2])
  119. h = h.sequential([self.norm_out, Tensor.swish, self.conv_out ])
  120. return h
  121. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/model.py#L604
  122. class Decoder:
  123. def __init__(self, ch:int, in_ch:int, out_ch:int, z_ch:int, ch_mult:List[int], num_res_blocks:int, resolution:int):
  124. block_in = ch * ch_mult[-1]
  125. curr_res = resolution // 2 ** (len(ch_mult) - 1)
  126. self.z_shape = (1, z_ch, curr_res, curr_res)
  127. self.conv_in = Conv2d(z_ch, block_in, kernel_size=3, stride=1, padding=1)
  128. self.mid = Mid(block_in)
  129. class BlockEntry:
  130. def __init__(self, block:List[ResnetBlock], upsample:Callable[[Any],Any]):
  131. self.block = block
  132. self.upsample = upsample
  133. self.up: List[BlockEntry] = []
  134. for i_level in reversed(range(len(ch_mult))):
  135. block = []
  136. block_out = ch * ch_mult[i_level]
  137. for _ in range(num_res_blocks + 1):
  138. block.append(ResnetBlock(block_in, block_out))
  139. block_in = block_out
  140. upsample = tensor_identity if i_level == 0 else Upsample(block_in)
  141. self.up.insert(0, BlockEntry(block, upsample)) # type: ignore
  142. self.norm_out = GroupNorm(32, block_in)
  143. self.conv_out = Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
  144. def __call__(self, z:Tensor) -> Tensor:
  145. h = z.sequential([self.conv_in, self.mid.block_1, self.mid.attn_1, self.mid.block_2])
  146. for up in self.up[::-1]:
  147. for block in up.block:
  148. h = block(h)
  149. h = up.upsample(h)
  150. h = h.sequential([self.norm_out, Tensor.swish, self.conv_out])
  151. return h
  152. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/autoencoder.py#L102
  153. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/autoencoder.py#L437
  154. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/autoencoder.py#L508
  155. class FirstStageModel:
  156. def __init__(self, embed_dim:int=4, **kwargs):
  157. self.encoder = FirstStage.Encoder(**kwargs)
  158. self.decoder = FirstStage.Decoder(**kwargs)
  159. self.quant_conv = Conv2d(2*kwargs["z_ch"], 2*embed_dim, 1)
  160. self.post_quant_conv = Conv2d(embed_dim, kwargs["z_ch"], 1)
  161. def decode(self, z:Tensor) -> Tensor:
  162. return z.sequential([self.post_quant_conv, self.decoder])
  163. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/discretizer.py#L42
  164. class LegacyDDPMDiscretization:
  165. def __init__(self, linear_start:float=0.00085, linear_end:float=0.0120, num_timesteps:int=1000):
  166. self.num_timesteps = num_timesteps
  167. betas = np.linspace(linear_start**0.5, linear_end**0.5, num_timesteps, dtype=np.float32) ** 2
  168. alphas = 1.0 - betas
  169. self.alphas_cumprod = np.cumprod(alphas, axis=0)
  170. def __call__(self, n:int, flip:bool=False) -> Tensor:
  171. if n < self.num_timesteps:
  172. timesteps = np.linspace(self.num_timesteps - 1, 0, n, endpoint=False).astype(int)[::-1]
  173. alphas_cumprod = self.alphas_cumprod[timesteps]
  174. elif n == self.num_timesteps:
  175. alphas_cumprod = self.alphas_cumprod
  176. sigmas = Tensor((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
  177. sigmas = Tensor.cat(Tensor.zeros((1,)), sigmas)
  178. return sigmas if flip else sigmas.flip(axis=0) # sigmas is "pre-flipped", need to do oposite of flag
  179. def append_dims(x:Tensor, t:Tensor) -> Tensor:
  180. dims_to_append = len(t.shape) - len(x.shape)
  181. assert dims_to_append >= 0
  182. return x.reshape(x.shape + (1,)*dims_to_append)
  183. @TinyJit
  184. def run(model, x, tms, ctx, y, c_out, add):
  185. return (model(x, tms, ctx, y)*c_out + add).realize()
  186. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/models/diffusion.py#L19
  187. class SDXL:
  188. def __init__(self, config:Dict):
  189. self.conditioner = Conditioner(**config["conditioner"])
  190. self.first_stage_model = FirstStageModel(**config["first_stage_model"])
  191. self.model = DiffusionModel(**config["model"])
  192. self.discretization = LegacyDDPMDiscretization()
  193. self.sigmas = self.discretization(config["denoiser"]["num_idx"], flip=True)
  194. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/inference/helpers.py#L173
  195. def create_conditioning(self, pos_prompt:str, img_width:int, img_height:int, aesthetic_score:float=5.0) -> Tuple[Dict,Dict]:
  196. batch_c : Dict = {
  197. "txt": pos_prompt,
  198. "original_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
  199. "crop_coords_top_left": Tensor([0,0]).repeat(N,1),
  200. "target_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
  201. "aesthetic_score": Tensor([aesthetic_score]).repeat(N,1),
  202. }
  203. batch_uc: Dict = {
  204. "txt": "",
  205. "original_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
  206. "crop_coords_top_left": Tensor([0,0]).repeat(N,1),
  207. "target_size_as_tuple": Tensor([img_height,img_width]).repeat(N,1),
  208. "aesthetic_score": Tensor([aesthetic_score]).repeat(N,1),
  209. }
  210. return model.conditioner(batch_c), model.conditioner(batch_uc, force_zero_embeddings=["txt"])
  211. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/denoiser.py#L42
  212. def denoise(self, x:Tensor, sigma:Tensor, cond:Dict) -> Tensor:
  213. def sigma_to_idx(s:Tensor) -> Tensor:
  214. dists = s - self.sigmas.unsqueeze(1)
  215. return dists.abs().argmin(axis=0).view(*s.shape)
  216. sigma = self.sigmas[sigma_to_idx(sigma)]
  217. sigma_shape = sigma.shape
  218. sigma = append_dims(sigma, x)
  219. c_out = -sigma
  220. c_in = 1 / (sigma**2 + 1.0) ** 0.5
  221. c_noise = sigma_to_idx(sigma.reshape(sigma_shape))
  222. def prep(*tensors:Tensor):
  223. return tuple(t.cast(dtypes.float16).realize() for t in tensors)
  224. return run(self.model.diffusion_model, *prep(x*c_in, c_noise, cond["crossattn"], cond["vector"], c_out, x))
  225. def decode(self, x:Tensor) -> Tensor:
  226. return self.first_stage_model.decode(1.0 / 0.13025 * x)
  227. class VanillaCFG:
  228. def __init__(self, scale:float):
  229. self.scale = scale
  230. def prepare_inputs(self, x:Tensor, s:float, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Tensor]:
  231. c_out = {}
  232. for k in c:
  233. assert k in ["vector", "crossattn", "concat"]
  234. c_out[k] = Tensor.cat(uc[k], c[k], dim=0)
  235. return Tensor.cat(x, x), Tensor.cat(s, s), c_out
  236. def __call__(self, x:Tensor, sigma:float) -> Tensor:
  237. x_u, x_c = x.chunk(2)
  238. x_pred = x_u + self.scale*(x_c - x_u)
  239. return x_pred
  240. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L21
  241. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L287
  242. class DPMPP2MSampler:
  243. def __init__(self, cfg_scale:float):
  244. self.discretization = LegacyDDPMDiscretization()
  245. self.guider = VanillaCFG(cfg_scale)
  246. def sampler_step(self, old_denoised:Optional[Tensor], prev_sigma:Optional[Tensor], sigma:Tensor, next_sigma:Tensor, denoiser, x:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor]:
  247. denoised = denoiser(*self.guider.prepare_inputs(x, sigma, c, uc))
  248. denoised = self.guider(denoised, sigma)
  249. t, t_next = sigma.log().neg(), next_sigma.log().neg()
  250. h = t_next - t
  251. r = None if prev_sigma is None else (t - prev_sigma.log().neg()) / h
  252. mults = [t_next.neg().exp()/t.neg().exp(), (-h).exp().sub(1)]
  253. if r is not None:
  254. mults.extend([1 + 1/(2*r), 1/(2*r)])
  255. mults = [append_dims(m, x) for m in mults]
  256. x_standard = mults[0]*x - mults[1]*denoised
  257. if (old_denoised is None) or (next_sigma.sum().numpy().item() < 1e-14):
  258. return x_standard, denoised
  259. denoised_d = mults[2]*denoised - mults[3]*old_denoised
  260. x_advanced = mults[0]*x - mults[1]*denoised_d
  261. x = Tensor.where(append_dims(next_sigma, x) > 0.0, x_advanced, x_standard)
  262. return x, denoised
  263. def __call__(self, denoiser, x:Tensor, c:Dict, uc:Dict, num_steps:int, timing=False) -> Tensor:
  264. sigmas = self.discretization(num_steps)
  265. x *= Tensor.sqrt(1.0 + sigmas[0] ** 2.0)
  266. num_sigmas = len(sigmas)
  267. old_denoised = None
  268. for i in trange(num_sigmas - 1):
  269. with Timing("step in ", enabled=timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
  270. x, old_denoised = self.sampler_step(
  271. old_denoised=old_denoised,
  272. prev_sigma=(None if i==0 else sigmas[i-1].reshape(x.shape[0])),
  273. sigma=sigmas[i].reshape(x.shape[0]),
  274. next_sigma=sigmas[i+1].reshape(x.shape[0]),
  275. denoiser=denoiser,
  276. x=x,
  277. c=c,
  278. uc=uc,
  279. )
  280. x.realize()
  281. old_denoised.realize()
  282. return x
  283. if __name__ == "__main__":
  284. default_prompt = "a horse sized cat eating a bagel"
  285. parser = argparse.ArgumentParser(description="Run SDXL", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  286. parser.add_argument('--steps', type=int, default=10, help="The number of diffusion steps")
  287. parser.add_argument('--prompt', type=str, default=default_prompt, help="Description of image to generate")
  288. parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
  289. parser.add_argument('--seed', type=int, help="Set the random latent seed")
  290. parser.add_argument('--guidance', type=float, default=6.0, help="Prompt strength")
  291. parser.add_argument('--width', type=int, default=1024, help="The output image width")
  292. parser.add_argument('--height', type=int, default=1024, help="The output image height")
  293. parser.add_argument('--weights', type=str, help="Custom path to weights")
  294. parser.add_argument('--timing', action='store_true', help="Print timing per step")
  295. parser.add_argument('--noshow', action='store_true', help="Don't show the image")
  296. args = parser.parse_args()
  297. Tensor.no_grad = True
  298. if args.seed is not None:
  299. Tensor.manual_seed(args.seed)
  300. model = SDXL(configs["SDXL_Base"])
  301. default_weight_url = 'https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors'
  302. weights = args.weights if args.weights else fetch(default_weight_url, 'sd_xl_base_1.0.safetensors')
  303. load_state_dict(model, safe_load(weights), strict=False)
  304. N = 1
  305. C = 4
  306. F = 8
  307. assert args.width % F == 0, f"img_width must be multiple of {F}, got {args.width}"
  308. assert args.height % F == 0, f"img_height must be multiple of {F}, got {args.height}"
  309. c, uc = model.create_conditioning(args.prompt, args.width, args.height)
  310. del model.conditioner
  311. for v in c .values(): v.realize()
  312. for v in uc.values(): v.realize()
  313. print("created batch")
  314. # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/inference/helpers.py#L101
  315. shape = (N, C, args.height // F, args.width // F)
  316. randn = Tensor.randn(shape)
  317. sampler = DPMPP2MSampler(args.guidance)
  318. z = sampler(model.denoise, randn, c, uc, args.steps, timing=args.timing)
  319. print("created samples")
  320. x = model.decode(z).realize()
  321. print("decoded samples")
  322. # make image correct size and scale
  323. x = (x + 1.0) / 2.0
  324. x = x.reshape(3,args.height,args.width).permute(1,2,0).clip(0,1).mul(255).cast(dtypes.uint8)
  325. print(x.shape)
  326. im = Image.fromarray(x.numpy())
  327. print(f"saving {args.out}")
  328. im.save(args.out)
  329. if not args.noshow:
  330. im.show()
  331. # validation!
  332. if args.prompt == default_prompt and args.steps == 10 and args.seed == 0 and args.guidance == 6.0 and args.width == args.height == 1024 \
  333. and not args.weights:
  334. ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "sdxl_seed0.png")))
  335. distance = (((x - ref_image).cast(dtypes.float) / ref_image.max())**2).mean().item()
  336. assert distance < 2e-3, colored(f"validation failed with {distance=}", "red")
  337. print(colored(f"output validated with {distance=}", "green"))