bert.py

import re, os
from pathlib import Path
from tinygrad.tensor import Tensor, cast
from tinygrad import nn, dtypes
from tinygrad.helpers import fetch, get_child
from tinygrad.nn.state import get_parameters

# allow for monkeypatching
Embedding = nn.Embedding
Linear = nn.Linear
LayerNorm = nn.LayerNorm

class BertForQuestionAnswering:
  def __init__(self, hidden_size=1024, intermediate_size=4096, max_position_embeddings=512, num_attention_heads=16, num_hidden_layers=24, type_vocab_size=2, vocab_size=30522, attention_probs_dropout_prob=0.1, hidden_dropout_prob=0.1):
    self.bert = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
    self.qa_outputs = Linear(hidden_size, 2)

  def load_from_pretrained(self):
    fn = Path(__file__).parents[1] / "weights/bert_for_qa.pt"
    fetch("https://zenodo.org/record/3733896/files/model.pytorch?download=1", fn)
    fn_vocab = Path(__file__).parents[1] / "weights/bert_vocab.txt"
    fetch("https://zenodo.org/record/3733896/files/vocab.txt?download=1", fn_vocab)

    import torch
    with open(fn, "rb") as f:
      state_dict = torch.load(f, map_location="cpu")

    for k, v in state_dict.items():
      if "dropout" in k: continue  # skip dropout
      if "pooler" in k: continue  # skip pooler
      get_child(self, k).assign(v.numpy()).realize()

  def __call__(self, input_ids:Tensor, attention_mask:Tensor, token_type_ids:Tensor):
    sequence_output = self.bert(input_ids, attention_mask, token_type_ids)
    logits = self.qa_outputs(sequence_output)
    start_logits, end_logits = logits.chunk(2, dim=-1)
    start_logits = start_logits.reshape(-1, 1)
    end_logits = end_logits.reshape(-1, 1)
    return Tensor.stack(start_logits, end_logits)

class BertForPretraining:
  def __init__(self, hidden_size:int=1024, intermediate_size:int=4096, max_position_embeddings:int=512, num_attention_heads:int=16, num_hidden_layers:int=24, type_vocab_size:int=2, vocab_size:int=30522, attention_probs_dropout_prob:float=0.1, hidden_dropout_prob:float=0.1):
    """Default is BERT-large"""
    self.bert = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
    self.cls = BertPreTrainingHeads(hidden_size, vocab_size, self.bert.embeddings.word_embeddings.weight)

  def __call__(self, input_ids:Tensor, attention_mask:Tensor, masked_lm_positions:Tensor, token_type_ids:Tensor):
    output = self.bert(input_ids, attention_mask, token_type_ids)
    return self.cls(output, masked_lm_positions)

  def loss(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
    # Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315
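    # one-hot cross entropy over the masked positions: positions whose label equals ignore_index are
    # zeroed out, and the summed negative log-likelihood is normalized by the count of valid positions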
    def sparse_categorical_crossentropy(predictions:Tensor, labels:Tensor, ignore_index=-1):
      log_probs, loss_mask = predictions.log_softmax(), (labels != ignore_index)
      y_counter = Tensor.arange(predictions.shape[-1], requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1])
      y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1])
      return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5)  # Small constant to avoid division by zero

    masked_lm_loss = sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights)
    next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
    return masked_lm_loss + next_sentence_loss

  def accuracy(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
    valid = masked_lm_ids != 0
    masked_lm_predictions = prediction_logits.log_softmax().argmax(-1)
    masked_lm_accuracy = (masked_lm_predictions == masked_lm_ids) * valid
    masked_lm_loss = prediction_logits.sparse_categorical_crossentropy(masked_lm_ids, ignore_index=masked_lm_weights)

    seq_relationship_predictions = seq_relationship_logits.log_softmax().argmax(-1)
    seq_relationship_accuracy = (seq_relationship_predictions == next_sentence_labels)
    next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)

    return masked_lm_accuracy.sum() / valid.sum(), seq_relationship_accuracy.mean(), masked_lm_loss, next_sentence_loss

  def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"):
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Mute tf flag info
    # load from tensorflow
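    # checkpoint variables are named like "bert/encoder/layer_0/attention/self/query/kernel";
    # the loop below walks each name through this module's attribute tree and copies the array in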
    import tensorflow as tf
    import numpy as np

    state_dict = {}
    for name, _ in tf.train.list_variables(str(tf_weight_path)):
      state_dict[name] = tf.train.load_variable(str(tf_weight_path), name)

    for k, v in state_dict.items():
      m = k.split("/")
      if any(n in ["adam_v", "adam_m", "global_step", "LAMB", "LAMB_1", "beta1_power", "beta2_power"] for n in m):
        continue

      pointer = self
      n = m[-1]  # this is just to stop python from complaining about possibly unbound local variable
      for i, n in enumerate(m):
        if re.fullmatch(r'[A-Za-z]+_\d+', n):
          l = re.split(r'_(\d+)', n)[:-1]
        else:
          l = [n]
        if l[0] in ["kernel", "gamma", "output_weights"]:
          pointer = getattr(pointer, "weight")
        elif l[0] in ["output_bias", "beta"]:
          pointer = getattr(pointer, "bias")
        elif l[0] == "pooler":
          pointer = getattr(getattr(self, "cls"), "pooler")
        else:
          pointer = getattr(pointer, l[0])
        if len(l) == 2:  # layers
          pointer = pointer[int(l[1])]

      if n[-11:] == "_embeddings":
        pointer = getattr(pointer, "weight")
      elif n == "kernel":
        v = np.transpose(v)
      cast(Tensor, pointer).assign(v).realize()

    params = get_parameters(self)
    count = 0
    for p in params:
      param_count = 1
      for s in p.shape:
        param_count *= s
      count += param_count
    print(f"Total parameters: {count / 1000 / 1000}M")
    return self

class BertPreTrainingHeads:
  def __init__(self, hidden_size:int, vocab_size:int, embeddings_weight:Tensor):
    self.predictions = BertLMPredictionHead(hidden_size, vocab_size, embeddings_weight)
    self.pooler = BertPooler(hidden_size)
    self.seq_relationship = Linear(hidden_size, 2)

  def __call__(self, sequence_output:Tensor, masked_lm_positions:Tensor):
    prediction_logits = self.predictions(gather(sequence_output, masked_lm_positions))
    seq_relationship_logits = self.seq_relationship(self.pooler(sequence_output))
    return prediction_logits, seq_relationship_logits

class BertLMPredictionHead:
  def __init__(self, hidden_size:int, vocab_size:int, embeddings_weight:Tensor):
    self.transform = BertPredictionHeadTransform(hidden_size)
    self.embedding_weight = embeddings_weight
    self.bias = Tensor.zeros(vocab_size, dtype=dtypes.float32)

  def __call__(self, hidden_states:Tensor):
    return self.transform(hidden_states) @ self.embedding_weight.T + self.bias

class BertPredictionHeadTransform:
  def __init__(self, hidden_size:int):
    self.dense = Linear(hidden_size, hidden_size)
    self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)

  def __call__(self, hidden_states:Tensor):
    return self.LayerNorm(gelu(self.dense(hidden_states)))

class BertPooler:
  def __init__(self, hidden_size:int):
    self.dense = Linear(hidden_size, hidden_size)

  def __call__(self, hidden_states:Tensor):
    return self.dense(hidden_states[:, 0]).tanh()
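
# one-hot gather: builds one_hot[batch, num_masked, seq_len] from masked_lm_positions and matmuls it with the
# [batch, seq_len, hidden] input to pick out the masked token rows -> [batch, num_masked, hidden]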
def gather(prediction_logits:Tensor, masked_lm_positions:Tensor):
  counter = Tensor.arange(prediction_logits.shape[1], device=prediction_logits.device, requires_grad=False).reshape(1, 1, prediction_logits.shape[1]).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
  onehot = counter == masked_lm_positions.unsqueeze(2).expand(*masked_lm_positions.shape, prediction_logits.shape[1])
  return onehot @ prediction_logits

class Bert:
  def __init__(self, hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob):
    self.embeddings = BertEmbeddings(hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob)
    self.encoder = BertEncoder(hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob)

  def __call__(self, input_ids, attention_mask, token_type_ids):
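    # turn the 0/1 padding mask [batch, seq_len] into an additive mask [batch, 1, 1, seq_len]:
    # 0 where tokens are attended to, -10000 at padding, added to the attention scores before softmax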
    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

    embedding_output = self.embeddings(input_ids, token_type_ids)
    encoder_outputs = self.encoder(embedding_output, extended_attention_mask)
    return encoder_outputs

class BertEmbeddings:
  def __init__(self, hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob):
    self.word_embeddings = Embedding(vocab_size, hidden_size)
    self.position_embeddings = Embedding(max_position_embeddings, hidden_size)
    self.token_type_embeddings = Embedding(type_vocab_size, hidden_size)
    self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
    self.dropout = hidden_dropout_prob

  def __call__(self, input_ids, token_type_ids):
    input_shape = input_ids.shape
    seq_length = input_shape[1]

    position_ids = Tensor.arange(seq_length, requires_grad=False, device=input_ids.device).unsqueeze(0).expand(*input_shape)
    words_embeddings = self.word_embeddings(input_ids)
    position_embeddings = self.position_embeddings(position_ids)
    token_type_embeddings = self.token_type_embeddings(token_type_ids)

    embeddings = words_embeddings + position_embeddings + token_type_embeddings
    embeddings = self.LayerNorm(embeddings)
    embeddings = embeddings.dropout(self.dropout)
    return embeddings

class BertEncoder:
  def __init__(self, hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob):
    self.layer = [BertLayer(hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob) for _ in range(num_hidden_layers)]

  def __call__(self, hidden_states, attention_mask):
    for layer in self.layer:
      hidden_states = layer(hidden_states, attention_mask)
    return hidden_states

class BertLayer:
  def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
    self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob)
    self.intermediate = BertIntermediate(hidden_size, intermediate_size)
    self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)

  def __call__(self, hidden_states, attention_mask):
    attention_output = self.attention(hidden_states, attention_mask)
    intermediate_output = self.intermediate(attention_output)
    layer_output = self.output(intermediate_output, attention_output)
    return layer_output

class BertOutput:
  def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
    self.dense = Linear(intermediate_size, hidden_size)
    self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
    self.dropout = hidden_dropout_prob

  def __call__(self, hidden_states, input_tensor):
    hidden_states = self.dense(hidden_states)
    hidden_states = hidden_states.dropout(self.dropout)
    hidden_states = self.LayerNorm(hidden_states + input_tensor)
    return hidden_states

def gelu(x):
  return x * 0.5 * (1.0 + erf(x / 1.41421))

# approximation of the error function
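# (Abramowitz & Stegun formula 7.1.26 polynomial approximation, maximum absolute error ~1.5e-7)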
def erf(x):
  t = (1 + 0.3275911 * x.abs()).reciprocal()
  return x.sign() * (1 - ((((1.061405429 * t + -1.453152027) * t + 1.421413741) * t + -0.284496736) * t + 0.254829592) * t * (-(x.square())).exp())

class BertIntermediate:
  def __init__(self, hidden_size, intermediate_size):
    self.dense = Linear(hidden_size, intermediate_size)

  def __call__(self, hidden_states):
    x = self.dense(hidden_states)
    # tinygrad gelu is openai gelu but we need the original bert gelu
    return gelu(x)

class BertAttention:
  def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
    self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
    self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)

  def __call__(self, hidden_states, attention_mask):
    self_output = self.self(hidden_states, attention_mask)
    attention_output = self.output(self_output, hidden_states)
    return attention_output

class BertSelfAttention:
  def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
    self.num_attention_heads = num_attention_heads
    self.attention_head_size = int(hidden_size / num_attention_heads)
    self.all_head_size = self.num_attention_heads * self.attention_head_size

    self.query = Linear(hidden_size, self.all_head_size)
    self.key = Linear(hidden_size, self.all_head_size)
    self.value = Linear(hidden_size, self.all_head_size)

    self.dropout = attention_probs_dropout_prob

  def __call__(self, hidden_states, attention_mask):
    mixed_query_layer = self.query(hidden_states)
    mixed_key_layer = self.key(hidden_states)
    mixed_value_layer = self.value(hidden_states)

    query_layer = self.transpose_for_scores(mixed_query_layer)
    key_layer = self.transpose_for_scores(mixed_key_layer)
    value_layer = self.transpose_for_scores(mixed_value_layer)

    context_layer = Tensor.scaled_dot_product_attention(query_layer, key_layer, value_layer, attention_mask, self.dropout)

    context_layer = context_layer.transpose(1, 2)
    context_layer = context_layer.reshape(context_layer.shape[0], context_layer.shape[1], self.all_head_size)
    return context_layer
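
  # [batch, seq_len, all_head_size] -> [batch, num_heads, seq_len, head_size] so attention runs per head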
  def transpose_for_scores(self, x):
    x = x.reshape(x.shape[0], x.shape[1], self.num_attention_heads, self.attention_head_size)
    return x.transpose(1, 2)

class BertSelfOutput:
  def __init__(self, hidden_size, hidden_dropout_prob):
    self.dense = Linear(hidden_size, hidden_size)
    self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
    self.dropout = hidden_dropout_prob

  def __call__(self, hidden_states, input_tensor):
    hidden_states = self.dense(hidden_states)
    hidden_states = hidden_states.dropout(self.dropout)
    hidden_states = self.LayerNorm(hidden_states + input_tensor)
    return hidden_states
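
# Minimal usage sketch (illustrative only, not part of the model): build a tiny BertForPretraining and run one
# forward pass on toy inputs. The config values, token ids, and shapes below are assumptions for demonstration;
# the outputs are prediction_logits [batch, num_masked, vocab_size] and seq_relationship_logits [batch, 2].
if __name__ == "__main__":
  model = BertForPretraining(hidden_size=64, intermediate_size=256, max_position_embeddings=32,
                             num_attention_heads=4, num_hidden_layers=2, type_vocab_size=2, vocab_size=100)
  batch, seq_len = 2, 8
  input_ids = Tensor([[5, 6, 7, 8, 9, 10, 11, 12]] * batch)   # toy token ids, shape [batch, seq_len]
  attention_mask = Tensor([[1.0] * seq_len] * batch)          # 1 = real token, 0 = padding
  token_type_ids = Tensor([[0] * seq_len] * batch)            # single segment
  masked_lm_positions = Tensor([[1, 3, 5]] * batch)           # positions of the masked tokens
  prediction_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_lm_positions, token_type_ids)
  print(prediction_logits.shape, seq_relationship_logits.shape)  # expect (2, 3, 100) and (2, 2)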