# clip.py
from tinygrad import Tensor, dtypes
from tinygrad.helpers import fetch
from tinygrad.nn import Linear, LayerNorm, Embedding
from typing import List, Optional, Union, Tuple
from abc import ABC, abstractmethod
from functools import lru_cache
import re, gzip
  8. @lru_cache()
  9. def default_bpe():
  10. # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
  11. return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")
  12. class Tokenizer:
  13. """
  14. Namespace for CLIP Text Tokenizer components.
  15. """
  16. @staticmethod
  17. def get_pairs(word):
  18. """
  19. Return set of symbol pairs in a word.
  20. Word is represented as tuple of symbols (symbols being variable-length strings).
  21. """
  22. return set(zip(word, word[1:]))
  23. @staticmethod
  24. def whitespace_clean(text):
  25. text = re.sub(r'\s+', ' ', text)
  26. text = text.strip()
  27. return text
  28. @staticmethod
  29. def bytes_to_unicode():
  30. """
  31. Returns list of utf-8 byte and a corresponding list of unicode strings.
  32. The reversible bpe codes work on unicode strings.
  33. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
  34. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
  35. This is a significant percentage of your normal, say, 32K bpe vocab.
  36. To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
  37. And avoids mapping to whitespace/control characters the bpe code barfs on.
  38. """
  39. bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
  40. cs = bs[:]
  41. n = 0
  42. for b in range(2**8):
  43. if b not in bs:
  44. bs.append(b)
  45. cs.append(2**8+n)
  46. n += 1
  47. cs = [chr(n) for n in cs]
  48. return dict(zip(bs, cs))
  49. class ClipTokenizer:
  50. def __init__(self):
  51. self.byte_encoder = Tokenizer.bytes_to_unicode()
  52. merges = gzip.open(default_bpe()).read().decode("utf-8").split('\n')
  53. merges = merges[1:49152-256-2+1]
  54. merges = [tuple(merge.split()) for merge in merges]
  55. vocab = list(Tokenizer.bytes_to_unicode().values())
  56. vocab = vocab + [v+'</w>' for v in vocab]
  57. for merge in merges:
  58. vocab.append(''.join(merge))
  59. vocab.extend(['<|startoftext|>', '<|endoftext|>'])
  60. self.encoder = dict(zip(vocab, range(len(vocab))))
  61. self.bpe_ranks = dict(zip(merges, range(len(merges))))
  62. self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
  63. self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
  64. def bpe(self, token):
  65. if token in self.cache:
  66. return self.cache[token]
  67. word = tuple(token[:-1]) + ( token[-1] + '</w>',)
  68. pairs = Tokenizer.get_pairs(word)
  69. if not pairs:
  70. return token+'</w>'
  71. while True:
  72. bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
  73. if bigram not in self.bpe_ranks:
  74. break
  75. first, second = bigram
  76. new_word = []
  77. i = 0
  78. while i < len(word):
  79. try:
  80. j = word.index(first, i)
  81. new_word.extend(word[i:j])
  82. i = j
  83. except Exception:
  84. new_word.extend(word[i:])
  85. break
  86. if word[i] == first and i < len(word)-1 and word[i+1] == second:
  87. new_word.append(first+second)
  88. i += 2
  89. else:
  90. new_word.append(word[i])
  91. i += 1
  92. new_word = tuple(new_word)
  93. word = new_word
  94. if len(word) == 1:
  95. break
  96. pairs = Tokenizer.get_pairs(word)
  97. word = ' '.join(word)
  98. self.cache[token] = word
  99. return word
  100. def encode(self, text:str, pad_with_zeros:bool=False):
  101. bpe_tokens: List[int] = []
  102. text = Tokenizer.whitespace_clean(text.strip()).lower()
  103. for token in re.findall(self.pat, text):
  104. token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
  105. bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
  106. # Truncation, keeping two slots for start and end tokens.
  107. if len(bpe_tokens) > 75:
  108. bpe_tokens = bpe_tokens[:75]
  109. return [49406] + bpe_tokens + [49407] + ([0] if pad_with_zeros else [49407]) * (77 - len(bpe_tokens) - 2)
  110. class Embedder(ABC):
  111. input_key: str
  112. @abstractmethod
  113. def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
  114. pass
  115. class Closed:
  116. """
  117. Namespace for OpenAI CLIP model components.
  118. """
  119. class ClipMlp:
  120. def __init__(self):
  121. self.fc1 = Linear(768, 3072)
  122. self.fc2 = Linear(3072, 768)
  123. def __call__(self, h:Tensor) -> Tensor:
  124. h = self.fc1(h)
  125. h = h.quick_gelu()
  126. h = self.fc2(h)
  127. return h
  128. class ClipAttention:
  129. def __init__(self):
  130. self.embed_dim = 768
  131. self.num_heads = 12
  132. self.head_dim = self.embed_dim // self.num_heads
  133. self.k_proj = Linear(self.embed_dim, self.embed_dim)
  134. self.v_proj = Linear(self.embed_dim, self.embed_dim)
  135. self.q_proj = Linear(self.embed_dim, self.embed_dim)
  136. self.out_proj = Linear(self.embed_dim, self.embed_dim)
  137. def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
  138. bsz, tgt_len, embed_dim = hidden_states.shape
  139. q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
  140. q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
  141. attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
  142. return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))
  143. class ClipEncoderLayer:
  144. def __init__(self):
  145. self.self_attn = Closed.ClipAttention()
  146. self.layer_norm1 = LayerNorm(768)
  147. self.mlp = Closed.ClipMlp()
  148. self.layer_norm2 = LayerNorm(768)
  149. def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
  150. residual = hidden_states
  151. hidden_states = self.layer_norm1(hidden_states)
  152. hidden_states = self.self_attn(hidden_states, causal_attention_mask)
  153. hidden_states = residual + hidden_states
  154. residual = hidden_states
  155. hidden_states = self.layer_norm2(hidden_states)
  156. hidden_states = self.mlp(hidden_states)
  157. hidden_states = residual + hidden_states
  158. return hidden_states
  159. class ClipTextEmbeddings:
  160. def __init__(self):
  161. self.token_embedding = Embedding(49408, 768)
  162. self.position_embedding = Embedding(77, 768)
  163. def __call__(self, input_ids:Tensor, position_ids:Tensor) -> Tensor:
  164. return self.token_embedding(input_ids) + self.position_embedding(position_ids)
  165. class ClipEncoder:
  166. def __init__(self, layer_count:int=12):
  167. self.layers = [Closed.ClipEncoderLayer() for _ in range(layer_count)]
  168. def __call__(self, x:Tensor, causal_attention_mask:Tensor, ret_layer_idx:Optional[int]=None) -> Tensor:
  169. # the indexing of layers is NOT off by 1, the original code considers the "input" as the first hidden state
  170. layers = self.layers if ret_layer_idx is None else self.layers[:ret_layer_idx]
  171. for l in layers:
  172. x = l(x, causal_attention_mask)
  173. return x
  174. class ClipTextTransformer:
  175. def __init__(self, ret_layer_idx:Optional[int]=None):
  176. self.embeddings = Closed.ClipTextEmbeddings()
  177. self.encoder = Closed.ClipEncoder()
  178. self.final_layer_norm = LayerNorm(768)
  179. self.ret_layer_idx = ret_layer_idx
  180. def __call__(self, input_ids:Tensor) -> Tensor:
  181. x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1))
  182. x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1), self.ret_layer_idx)
  183. return self.final_layer_norm(x) if (self.ret_layer_idx is None) else x
  184. class ClipTextModel:
  185. def __init__(self, ret_layer_idx:Optional[int]):
  186. self.text_model = Closed.ClipTextTransformer(ret_layer_idx=ret_layer_idx)
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L331
  188. class FrozenClosedClipEmbedder(Embedder):
  189. def __init__(self, ret_layer_idx:Optional[int]=None):
  190. self.tokenizer = Tokenizer.ClipTokenizer()
  191. self.transformer = Closed.ClipTextModel(ret_layer_idx)
  192. self.input_key = "txt"
  193. def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
  194. tokens = Tensor(self.tokenizer.encode(text))
  195. return self.transformer.text_model(tokens.reshape(1,-1))
class Open:
  """
  Namespace for OpenCLIP model components.
  """
  class MultiheadAttention:
    """Multi-head self-attention with a single fused q/k/v in-projection.

    NOTE(review): parameter names/shapes (in_proj_weight (3*dims, dims),
    in_proj_bias (3*dims,)) look like torch.nn.MultiheadAttention's fused layout —
    confirm against the checkpoint when loading weights.
    """
    def __init__(self, dims:int, n_heads:int):
      self.dims = dims
      self.n_heads = n_heads
      self.d_head = self.dims // self.n_heads
      # fused q/k/v projection parameters
      self.in_proj_bias = Tensor.empty(3*dims)
      self.in_proj_weight = Tensor.empty(3*dims, dims)
      self.out_proj = Linear(dims, dims)
    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      # x is sequence-first (T, B, C); ClipTransformer transposes into this layout before calling
      T,B,C = x.shape
      # one matmul computes q, k and v at once along the last axis
      proj = x.linear(self.in_proj_weight.T, self.in_proj_bias)
      # split the fused (..., 3*C) projection into a leading axis of 3 so chunk(3) yields q, k, v
      proj = proj.unflatten(-1, (3,C)).unsqueeze(0).transpose(0,-2)
      q,k,v = proj.chunk(3)
      # reshape each to (B, n_heads, T, d_head) for scaled_dot_product_attention
      q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in (q,k,v)]
      attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
      # back to sequence-first layout, merge heads, then output projection
      attn_output = attn_output.permute(2,0,1,3).reshape(B*T, C)
      attn_output = self.out_proj(attn_output)
      attn_output = attn_output.reshape(T, B, C)
      return attn_output
  class Mlp:
    """Transformer feed-forward block: c_fc -> gelu -> c_proj."""
    def __init__(self, dims, hidden_dims):
      self.c_fc = Linear(dims, hidden_dims)
      self.c_proj = Linear(hidden_dims, dims)
    def __call__(self, x:Tensor) -> Tensor:
      return x.sequential([self.c_fc, Tensor.gelu, self.c_proj])
  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210
  class ResidualAttentionBlocks:
    """Pre-norm residual transformer block: x + attn(ln_1(x)), then x + mlp(ln_2(x))."""
    def __init__(self, dims:int, n_heads:int, mlp_ratio:float):
      self.ln_1 = LayerNorm(dims)
      self.attn = Open.MultiheadAttention(dims, n_heads)
      self.ln_2 = LayerNorm(dims)
      self.mlp = Open.Mlp(dims, int(dims * mlp_ratio))
    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      x = x + self.attn(self.ln_1(x), attn_mask=attn_mask)
      x = x + self.mlp(self.ln_2(x))
      return x
  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L317
  class ClipTransformer:
    """Stack of residual attention blocks; runs sequence-first internally."""
    def __init__(self, dims:int, layers:int, n_heads:int, mlp_ratio:float=4.0):
      self.resblocks = [
        Open.ResidualAttentionBlocks(dims, n_heads, mlp_ratio) for _ in range(layers)
      ]
    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      # (B, T, C) -> (T, B, C) for the blocks, and back again afterwards
      x = x.transpose(0, 1).contiguous()
      for r in self.resblocks:
        x = r(x, attn_mask=attn_mask)
      x = x.transpose(0, 1)
      return x
  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/model.py#L220
  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L661
  class ClipTextTransformer:
    """OpenCLIP text tower: token + positional embeddings, transformer, final LayerNorm, text projection."""
    def __init__(self, dims:int, n_heads:int, layers:int, vocab_size:int=49408, ctx_length:int=77):
      self.token_embedding = Embedding(vocab_size, dims)
      self.positional_embedding = Tensor.empty(ctx_length, dims)
      self.transformer = Open.ClipTransformer(dims, layers, n_heads)
      self.ln_final = LayerNorm(dims)
      self.text_projection = Tensor.empty(dims, dims)
    @property
    def attn_mask(self) -> Tensor:
      # lazily build and cache the 77x77 causal mask (-inf above the diagonal)
      if not hasattr(self, "_attn_mask"):
        self._attn_mask = Tensor.full((77, 77), float("-inf")).triu(1)
      return self._attn_mask
    def __call__(self, text:Tensor) -> Tensor:
      seq_len = text.shape[1]
      x = self.token_embedding(text)
      x = x + self.positional_embedding[:seq_len]
      x = self.transformer(x, attn_mask=self.attn_mask)
      x = self.ln_final(x)
      # pool at the position of the highest token id (the end token, 49407, is the
      # largest id this file's tokenizer emits).
      # NOTE(review): open_clip indexes with (arange(batch), argmax) pairs; `x[:, ...]`
      # keeps an extra axis when batch > 1 — confirm callers only pass batch size 1 here.
      pooled = x[:, text.argmax(dim=-1)] @ self.text_projection
      return pooled
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498
  272. class FrozenOpenClipEmbedder(Embedder):
  273. def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False):
  274. self.tokenizer = Tokenizer.ClipTokenizer()
  275. self.model = Open.ClipTextTransformer(dims, n_heads, layers)
  276. self.return_pooled = return_pooled
  277. self.input_key = "txt"
  278. self.ln_penultimate = ln_penultimate
  279. def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
  280. for r in self.model.transformer.resblocks:
  281. x, penultimate = r(x, attn_mask=attn_mask), x
  282. return x.permute(1,0,2), penultimate.permute(1,0,2)
  283. def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
  284. tokens = Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int64).reshape(1,-1)
  285. x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2)
  286. x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
  287. if self.ln_penultimate:
  288. penultimate = self.model.ln_final(penultimate)
  289. if self.return_pooled:
  290. x = self.model.ln_final(x)
  291. pooled = x[Tensor.arange(x.shape[0]), tokens.argmax(axis=-1).numpy().item()] @ self.model.text_projection
  292. return penultimate, pooled
  293. else:
  294. return penultimate