helpers.py

from collections import OrderedDict
import unicodedata
import numpy as np
from tinygrad.nn import state
from tinygrad.tensor import Tensor, dtypes
from tinygrad.helpers import getenv

#
# checkpointing utils
#

def invert_dict(d): return {v: k for k, v in reversed(d.items())}
def dedup_dict(d): return invert_dict(invert_dict(d))

# store each tensor into the first key it appears in
def get_training_state(model, optimizer, scheduler):
  # hack: let get_state_dict walk the tree starting with model, so that the checkpoint keys are
  # readable and can be loaded as a model for eval
  train_state = {'model': model, 'optimizer': optimizer, 'scheduler': scheduler}
  return dedup_dict(state.get_state_dict(train_state))

def load_training_state(model, optimizer, scheduler, state_dict):
  # use fresh model to restore duplicate keys
  train_state = {'model': model, 'optimizer': optimizer, 'scheduler': scheduler}
  big_dict = state.get_state_dict(train_state)
  # hack: put back the dupes
  dupe_names = {}
  for k, v in big_dict.items():
    if v not in dupe_names:
      dupe_names[v] = k
      assert k in state_dict
    state_dict[k] = state_dict[dupe_names[v]]
  # scheduler contains optimizer and all params, load each weight only once
  scheduler_state = {'scheduler': scheduler}
  state.load_state_dict(scheduler_state, state_dict)
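
# Illustrative checkpoint round-trip (a sketch, not part of the original file). Assumes `model`,
# `optimizer`, and `scheduler` are already constructed; the .safetensors path is a placeholder.
#
#   from tinygrad.nn.state import safe_save, safe_load
#   safe_save(get_training_state(model, optimizer, scheduler), "/tmp/train_ckpt.safetensors")
#   # ... later, with freshly constructed model/optimizer/scheduler ...
#   load_training_state(model, optimizer, scheduler, safe_load("/tmp/train_ckpt.safetensors"))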

def gaussian_kernel(n, std):
  from scipy import signal
  gaussian_1d = signal.windows.gaussian(n, std)
  gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
  gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
  gaussian_3d = gaussian_3d.reshape(n, n, n)
  gaussian_3d = np.cbrt(gaussian_3d)
  gaussian_3d /= gaussian_3d.max()
  return gaussian_3d
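
# gaussian_kernel above builds an n*n*n Gaussian weighting window from outer products of a 1D
# Gaussian, softens the falloff with a cube root, and scales the peak to 1. prepare_arrays uses it
# as the per-patch weight for blending overlapping ROI predictions during sliding-window inference.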

def prepare_arrays(image, roi_shape=(128, 128, 128)):
  assert len(roi_shape) == 3 and any(roi_shape)
  image_shape = list(image.shape[2:])
  result = np.zeros((1, 3, *image_shape), dtype=image.dtype)
  norm_map = np.zeros_like(result)
  norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype(norm_map.dtype)
  return result, norm_map, norm_patch

def get_slice(image, roi_shape=(128, 128, 128), overlap_factor=0.5):
  assert len(roi_shape) == 3 and any(roi_shape)
  assert 0 < overlap_factor < 1
  image_shape, dim = list(image.shape[2:]), len(image.shape[2:])
  strides = [int(roi_shape[i] * (1 - overlap_factor)) for i in range(dim)]
  size = [(image_shape[i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
  for i in range(0, strides[0] * size[0], strides[0]):
    for j in range(0, strides[1] * size[1], strides[1]):
      for k in range(0, strides[2] * size[2], strides[2]):
        yield i, j, k
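
# Illustrative sliding-window inference loop (a sketch, not part of the original file). Assumes
# `image` is a (1, 1, D, H, W) numpy array and `model` maps a patch to (1, 3, *roi_shape) logits:
#
#   result, norm_map, norm_patch = prepare_arrays(image, roi_shape)
#   for i, j, k in get_slice(image, roi_shape, overlap_factor=0.5):
#     sl = (..., slice(i, i + roi_shape[0]), slice(j, j + roi_shape[1]), slice(k, k + roi_shape[2]))
#     result[sl] += model(image[sl]) * norm_patch
#     norm_map[sl] += norm_patch
#   result /= norm_map  # normalize voxels covered by multiple overlapping patches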

def _get_best_indices(logits, n_best_size):
  index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
  return list(map(lambda x: x[0], index_and_score))[:n_best_size]

def _is_punctuation(char):
  if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127):
    return True
  return unicodedata.category(char).startswith("P")

def _is_whitespace(char):
  if char == " " or char == "\t" or char == "\n" or char == "\r":
    return True
  return unicodedata.category(char) == "Zs"

def _is_control(char):
  if char == "\t" or char == "\n" or char == "\r":
    return False
  return unicodedata.category(char).startswith("C")

def _run_split_on_punc(text):
  if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
    return [text]
  start_new_word = True
  output = []
  for i in range(len(text)):
    if _is_punctuation(char := text[i]):
      output.append([char])
      start_new_word = True
    else:
      if start_new_word:
        output.append([])
      start_new_word = False
      output[-1].append(char)
  return ["".join(x) for x in output]

def _run_strip_accents(text):
  output = []
  for char in unicodedata.normalize("NFD", text):
    if unicodedata.category(char) != "Mn":
      output.append(char)
  return "".join(output)

def _clean_text(text):
  output = []
  for char in text:
    if not ((cp := ord(char)) == 0 or cp == 0xfffd or _is_control(char)):
      output.append(" " if _is_whitespace(char) else char)
  return "".join(output)

def _get_final_text(pred_text, orig_text):
  def _strip_spaces(text):
    ns_text = ""
    ns_to_s_map = OrderedDict()
    for i, c in enumerate(text):
      if c == " ":
        continue
      ns_to_s_map[len(ns_text)] = i
      ns_text += c
    return ns_text, ns_to_s_map

  orig_tokens = _clean_text(orig_text).strip().split()
  split_tokens = []
  for token in orig_tokens:
    if token not in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
      token = token.lower()
      token = _run_strip_accents(token)
    split_tokens.extend(_run_split_on_punc(token))
  tok_text = " ".join(" ".join(split_tokens).strip().split())

  start_position = tok_text.find(pred_text)
  if start_position == -1:
    return orig_text
  end_position = start_position + len(pred_text) - 1

  orig_ns_text, orig_ns_to_s_map = _strip_spaces(orig_text)
  tok_ns_text, tok_ns_to_s_map = _strip_spaces(tok_text)
  if len(orig_ns_text) != len(tok_ns_text):
    return orig_text

  tok_s_to_ns_map = {v: k for k, v in tok_ns_to_s_map.items()}
  orig_start_position = None
  if start_position in tok_s_to_ns_map:
    if (ns_start_position := tok_s_to_ns_map[start_position]) in orig_ns_to_s_map:
      orig_start_position = orig_ns_to_s_map[ns_start_position]
  if orig_start_position is None:
    return orig_text

  orig_end_position = None
  if end_position in tok_s_to_ns_map:
    if (ns_end_position := tok_s_to_ns_map[end_position]) in orig_ns_to_s_map:
      orig_end_position = orig_ns_to_s_map[ns_end_position]
  if orig_end_position is None:
    return orig_text

  output_text = orig_text[orig_start_position:(orig_end_position + 1)]
  return output_text

def get_bert_qa_prediction(features, example, start_end_logits):
  prelim_predictions = []
  for i, feature in enumerate(features):
    for start_index in _get_best_indices(start_end_logits[i][0], 20):
      for end_index in _get_best_indices(start_end_logits[i][1], 20):
        if start_index >= len(feature["tokens"]) or end_index >= len(feature["tokens"]):
          continue
        if start_index not in feature["token_to_orig_map"] or end_index not in feature["token_to_orig_map"]:
          continue
        if not feature["token_is_max_context"].get(start_index, False):
          continue
        if end_index < start_index or end_index - start_index + 1 > 30:
          continue
        prelim_predictions.append({
          "feature_index": i,
          "start_index": start_index,
          "end_index": end_index,
          "start_logit": start_end_logits[i][0, start_index],
          "end_logit": start_end_logits[i][1, end_index]
        })
  predictions = sorted(prelim_predictions, key=lambda x: (x["start_logit"] + x["end_logit"]), reverse=True)

  if len(predictions) > 0:
    feature = features[predictions[0]["feature_index"]]
    tok_tokens = feature["tokens"][predictions[0]["start_index"]:(predictions[0]["end_index"] + 1)]
    orig_doc_start = feature["token_to_orig_map"][predictions[0]["start_index"]]
    orig_doc_end = feature["token_to_orig_map"][predictions[0]["end_index"]]
    orig_tokens = example["context"][orig_doc_start:(orig_doc_end + 1)]
    tok_text = " ".join(tok_tokens).replace(" ##", "").replace("##", "")
    tok_text = " ".join(tok_text.strip().split())
    orig_text = " ".join(orig_tokens)
    return _get_final_text(tok_text, orig_text)
  return "empty"
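
# get_bert_qa_prediction follows the reference BERT SQuAD-style post-processing: it scores
# (start, end) pairs drawn from the top-20 start and end logits of each feature, keeps only valid
# spans (in-context, start <= end, at most 30 tokens, start in a max-context window), picks the
# highest-scoring span, and maps its wordpiece text back onto the original context via _get_final_text.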

def get_mlperf_bert_config():
  """Config is BERT-large"""
  return {
    "attention_probs_dropout_prob": 0.1,
    "hidden_dropout_prob": 0.1,
    "hidden_size": 1024,
    "intermediate_size": 4096,
    "max_position_embeddings": 512,
    "num_attention_heads": 16,
    "num_hidden_layers": 24,
    "type_vocab_size": 2,
    "vocab_size": 30522
  }

def get_mlperf_bert_model(checkpoint_path:str=""):
  from extra.models import bert
  from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert

  bert.Linear = LinearBert
  bert.Embedding = EmbeddingBert
  bert.LayerNorm = LayerNormBert

  from extra.models.bert import BertForPretraining
  config = get_mlperf_bert_config()
  if getenv("DISABLE_DROPOUT", 0):
    config["hidden_dropout_prob"] = config["attention_probs_dropout_prob"] = 0.0
  model = BertForPretraining(**config)
  if checkpoint_path: model.load_from_pretrained(checkpoint_path)
  return model
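
# Illustrative model construction (a sketch, not part of the original file); the checkpoint path is
# a placeholder and is forwarded to model.load_from_pretrained:
#
#   model = get_mlperf_bert_model()                      # randomly initialized BERT-large
#   model = get_mlperf_bert_model("/path/to/checkpoint") # or initialize from a pretrained checkpoint
#
# Setting DISABLE_DROPOUT=1 in the environment zeroes both dropout probabilities before construction.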

def get_data_bert(GPUS:list[str], it):
  data: dict[str, Tensor] = next(it)
  for key in data.keys(): data[key].shard_(GPUS, axis=0)
  return data

def get_fake_data_bert(GPUS:list[str], BS:int):
  return {
    "input_ids": Tensor.empty((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
    "input_mask": Tensor.empty((BS, 512), dtype=dtypes.default_float).contiguous().shard_(GPUS, axis=0),
    "segment_ids": Tensor.empty((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
    "masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
    "masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
    "masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
    "next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
  }
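
# Illustrative use of the data helpers (a sketch, not part of the original file). Assumes GPUS is a
# list of tinygrad device strings and `dataloader_iter` yields dicts of pretraining Tensors:
#
#   batch = get_data_bert(GPUS, dataloader_iter)  # real batch, sharded across GPUS on axis 0
#   fake  = get_fake_data_bert(GPUS, BS=32)       # same-shaped placeholder batch for benchmarking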