from tinygrad import Tensor, dtypes
from tinygrad.helpers import fetch
from tinygrad.nn import Linear, LayerNorm, Embedding
from typing import List, Optional, Union, Tuple
from abc import ABC, abstractmethod
from functools import lru_cache
import re, gzip

@lru_cache()
def default_bpe():
  # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
  return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")

class Tokenizer:
  """
  Namespace for CLIP text tokenizer components.
  """
  @staticmethod
  def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in a word.
    A word is represented as a tuple of symbols (symbols being variable-length strings).
    """
    return set(zip(word, word[1:]))

  @staticmethod
  def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text

  @staticmethod
  def bytes_to_unicode():
    """
    Returns a dict mapping utf-8 bytes to unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
    which also avoid mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1)) + list(range(ord("¡"), ord("¬")+1)) + list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
      if b not in bs:
        bs.append(b)
        cs.append(2**8+n)
        n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
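
  # Illustrative note (not part of the original file): printable bytes map to themselves,
  # while the remaining bytes are shifted into unused code points above 255, so every
  # byte gets a printable, non-whitespace stand-in. For example, per the code above:
  #   Tokenizer.bytes_to_unicode()[ord("a")]  # -> "a"
  #   Tokenizer.bytes_to_unicode()[0]         # -> chr(256), i.e. "Ā"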

  class ClipTokenizer:
    def __init__(self):
      self.byte_encoder = Tokenizer.bytes_to_unicode()
      merges = gzip.open(default_bpe()).read().decode("utf-8").split('\n')
      merges = merges[1:49152-256-2+1]
      merges = [tuple(merge.split()) for merge in merges]
      vocab = list(Tokenizer.bytes_to_unicode().values())
      vocab = vocab + [v+'</w>' for v in vocab]
      for merge in merges:
        vocab.append(''.join(merge))
      vocab.extend(['<|startoftext|>', '<|endoftext|>'])
      self.encoder = dict(zip(vocab, range(len(vocab))))
      self.bpe_ranks = dict(zip(merges, range(len(merges))))
      self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
      self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)

    def bpe(self, token):
      if token in self.cache:
        return self.cache[token]
      word = tuple(token[:-1]) + (token[-1] + '</w>',)
      pairs = Tokenizer.get_pairs(word)

      if not pairs:
        return token+'</w>'

      while True:
        bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
        if bigram not in self.bpe_ranks:
          break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
          try:
            j = word.index(first, i)
            new_word.extend(word[i:j])
            i = j
          except Exception:
            new_word.extend(word[i:])
            break

          if word[i] == first and i < len(word)-1 and word[i+1] == second:
            new_word.append(first+second)
            i += 2
          else:
            new_word.append(word[i])
            i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
          break
        pairs = Tokenizer.get_pairs(word)
      word = ' '.join(word)
      self.cache[token] = word
      return word
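
    # Hedged illustration (not part of the original file): bpe() greedily merges the
    # highest-ranked adjacent pair until no remaining pair appears in bpe_ranks, then
    # returns the sub-words joined by single spaces, the last one ending in "</w>".
    #   pieces = Tokenizer.ClipTokenizer().bpe("astronaut").split(' ')
    #   # every piece is a key of self.encoder; encode() below relies on this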

    def encode(self, text:str, pad_with_zeros:bool=False):
      bpe_tokens: List[int] = []
      text = Tokenizer.whitespace_clean(text.strip()).lower()
      for token in re.findall(self.pat, text):
        token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
        bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
      # Truncation, keeping two slots for start and end tokens.
      if len(bpe_tokens) > 75:
        bpe_tokens = bpe_tokens[:75]
      return [49406] + bpe_tokens + [49407] + ([0] if pad_with_zeros else [49407]) * (77 - len(bpe_tokens) - 2)
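
    # Hedged usage sketch (not part of the original file): encode() always returns a
    # 77-entry list: <|startoftext|> (49406), the BPE ids, <|endoftext|> (49407), then
    # padding with either zeros or repeated 49407.
    #   ids = Tokenizer.ClipTokenizer().encode("a photograph of an astronaut riding a horse")
    #   # len(ids) == 77, ids[0] == 49406, ids[-1] == 49407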

class Embedder(ABC):
  input_key: str
  @abstractmethod
  def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
    pass

class Closed:
  """
  Namespace for OpenAI CLIP model components.
  """
  class ClipMlp:
    def __init__(self):
      self.fc1 = Linear(768, 3072)
      self.fc2 = Linear(3072, 768)

    def __call__(self, h:Tensor) -> Tensor:
      h = self.fc1(h)
      h = h.quick_gelu()
      h = self.fc2(h)
      return h

  class ClipAttention:
    def __init__(self):
      self.embed_dim = 768
      self.num_heads = 12
      self.head_dim = self.embed_dim // self.num_heads
      self.k_proj = Linear(self.embed_dim, self.embed_dim)
      self.v_proj = Linear(self.embed_dim, self.embed_dim)
      self.q_proj = Linear(self.embed_dim, self.embed_dim)
      self.out_proj = Linear(self.embed_dim, self.embed_dim)

    def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
      bsz, tgt_len, embed_dim = hidden_states.shape
      q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
      q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
      attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
      return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))

  class ClipEncoderLayer:
    def __init__(self):
      self.self_attn = Closed.ClipAttention()
      self.layer_norm1 = LayerNorm(768)
      self.mlp = Closed.ClipMlp()
      self.layer_norm2 = LayerNorm(768)

    def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
      residual = hidden_states
      hidden_states = self.layer_norm1(hidden_states)
      hidden_states = self.self_attn(hidden_states, causal_attention_mask)
      hidden_states = residual + hidden_states

      residual = hidden_states
      hidden_states = self.layer_norm2(hidden_states)
      hidden_states = self.mlp(hidden_states)
      hidden_states = residual + hidden_states
      return hidden_states

  class ClipTextEmbeddings:
    def __init__(self):
      self.token_embedding = Embedding(49408, 768)
      self.position_embedding = Embedding(77, 768)

    def __call__(self, input_ids:Tensor, position_ids:Tensor) -> Tensor:
      return self.token_embedding(input_ids) + self.position_embedding(position_ids)

  class ClipEncoder:
    def __init__(self, layer_count:int=12):
      self.layers = [Closed.ClipEncoderLayer() for _ in range(layer_count)]

    def __call__(self, x:Tensor, causal_attention_mask:Tensor, ret_layer_idx:Optional[int]=None) -> Tensor:
      # the indexing of layers is NOT off by 1, the original code considers the "input" as the first hidden state
      layers = self.layers if ret_layer_idx is None else self.layers[:ret_layer_idx]
      for l in layers:
        x = l(x, causal_attention_mask)
      return x

  class ClipTextTransformer:
    def __init__(self, ret_layer_idx:Optional[int]=None):
      self.embeddings = Closed.ClipTextEmbeddings()
      self.encoder = Closed.ClipEncoder()
      self.final_layer_norm = LayerNorm(768)
      self.ret_layer_idx = ret_layer_idx

    def __call__(self, input_ids:Tensor) -> Tensor:
      x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1))
      x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1), self.ret_layer_idx)
      return self.final_layer_norm(x) if (self.ret_layer_idx is None) else x

  class ClipTextModel:
    def __init__(self, ret_layer_idx:Optional[int]):
      self.text_model = Closed.ClipTextTransformer(ret_layer_idx=ret_layer_idx)

  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L331
  class FrozenClosedClipEmbedder(Embedder):
    def __init__(self, ret_layer_idx:Optional[int]=None):
      self.tokenizer = Tokenizer.ClipTokenizer()
      self.transformer = Closed.ClipTextModel(ret_layer_idx)
      self.input_key = "txt"

    def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
      tokens = Tensor(self.tokenizer.encode(text))
      return self.transformer.text_model(tokens.reshape(1,-1))
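
  # Hedged usage sketch (not part of the original file): the embedder maps a prompt to
  # per-token hidden states of shape (1, 77, 768). ret_layer_idx=11 is an assumed setting
  # that stops before the last encoder layer and skips the final LayerNorm, in the style
  # of SDXL-like pipelines that use the penultimate layer of this text encoder.
  #   clip_l = Closed.FrozenClosedClipEmbedder(ret_layer_idx=11)
  #   hidden = clip_l("a photograph of an astronaut riding a horse")  # (1, 77, 768)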

class Open:
  """
  Namespace for OpenCLIP model components.
  """
  class MultiheadAttention:
    def __init__(self, dims:int, n_heads:int):
      self.dims = dims
      self.n_heads = n_heads
      self.d_head = self.dims // self.n_heads
      self.in_proj_bias = Tensor.empty(3*dims)
      self.in_proj_weight = Tensor.empty(3*dims, dims)
      self.out_proj = Linear(dims, dims)

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      # x arrives sequence-first as (T, B, C); fused qkv projection, then split into
      # per-head (B, n_heads, T, d_head) tensors before attention
      T,B,C = x.shape
      proj = x.linear(self.in_proj_weight.T, self.in_proj_bias)
      proj = proj.unflatten(-1, (3,C)).unsqueeze(0).transpose(0,-2)
      q,k,v = proj.chunk(3)
      q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in (q,k,v)]

      attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
      attn_output = attn_output.permute(2,0,1,3).reshape(B*T, C)
      attn_output = self.out_proj(attn_output)
      attn_output = attn_output.reshape(T, B, C)
      return attn_output

  class Mlp:
    def __init__(self, dims, hidden_dims):
      self.c_fc = Linear(dims, hidden_dims)
      self.c_proj = Linear(hidden_dims, dims)

    def __call__(self, x:Tensor) -> Tensor:
      return x.sequential([self.c_fc, Tensor.gelu, self.c_proj])

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210
  class ResidualAttentionBlocks:
    def __init__(self, dims:int, n_heads:int, mlp_ratio:float):
      self.ln_1 = LayerNorm(dims)
      self.attn = Open.MultiheadAttention(dims, n_heads)
      self.ln_2 = LayerNorm(dims)
      self.mlp = Open.Mlp(dims, int(dims * mlp_ratio))

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      x = x + self.attn(self.ln_1(x), attn_mask=attn_mask)
      x = x + self.mlp(self.ln_2(x))
      return x

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L317
  class ClipTransformer:
    def __init__(self, dims:int, layers:int, n_heads:int, mlp_ratio:float=4.0):
      self.resblocks = [
        Open.ResidualAttentionBlocks(dims, n_heads, mlp_ratio) for _ in range(layers)
      ]

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      x = x.transpose(0, 1).contiguous()
      for r in self.resblocks:
        x = r(x, attn_mask=attn_mask)
      x = x.transpose(0, 1)
      return x

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/model.py#L220
  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L661
  class ClipTextTransformer:
    def __init__(self, dims:int, n_heads:int, layers:int, vocab_size:int=49408, ctx_length:int=77):
      self.token_embedding = Embedding(vocab_size, dims)
      self.positional_embedding = Tensor.empty(ctx_length, dims)
      self.transformer = Open.ClipTransformer(dims, layers, n_heads)
      self.ln_final = LayerNorm(dims)
      self.text_projection = Tensor.empty(dims, dims)

    @property
    def attn_mask(self) -> Tensor:
      # lazily built causal mask, cached after the first access
      if not hasattr(self, "_attn_mask"):
        self._attn_mask = Tensor.full((77, 77), float("-inf")).triu(1)
      return self._attn_mask

    def __call__(self, text:Tensor) -> Tensor:
      seq_len = text.shape[1]
      x = self.token_embedding(text)
      x = x + self.positional_embedding[:seq_len]
      x = self.transformer(x, attn_mask=self.attn_mask)
      x = self.ln_final(x)
      # pool at the end-of-text position (argmax works because <|endoftext|> has the highest token id)
      pooled = x[:, text.argmax(axis=-1)] @ self.text_projection
      return pooled

  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
  # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498
  class FrozenOpenClipEmbedder(Embedder):
    def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False):
      self.tokenizer = Tokenizer.ClipTokenizer()
      self.model = Open.ClipTextTransformer(dims, n_heads, layers)
      self.return_pooled = return_pooled
      self.input_key = "txt"
      self.ln_penultimate = ln_penultimate

    def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
      for r in self.model.transformer.resblocks:
        x, penultimate = r(x, attn_mask=attn_mask), x
      return x.permute(1,0,2), penultimate.permute(1,0,2)

    def __call__(self, text:str) -> Union[Tensor,Tuple[Tensor,...]]:
      tokens = Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int64).reshape(1,-1)

      x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2)
      x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)

      if self.ln_penultimate:
        penultimate = self.model.ln_final(penultimate)

      if self.return_pooled:
        x = self.model.ln_final(x)
        pooled = x[Tensor.arange(x.shape[0]), tokens.argmax(axis=-1).numpy().item()] @ self.model.text_projection
        return penultimate, pooled
      else:
        return penultimate
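
  # Hedged usage sketch (not part of the original file): the constructor arguments below
  # are assumptions matching OpenCLIP ViT-bigG's text tower; the real values come from
  # whatever checkpoint is loaded into this module.
  #   clip_g = Open.FrozenOpenClipEmbedder(dims=1280, n_heads=20, layers=32, return_pooled=True)
  #   penultimate, pooled = clip_g("a photograph of an astronaut riding a horse")
  #   # penultimate: (1, 77, 1280) per-token states; pooled: (1, 1280) projected EOT embedding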