# squad.py — SQuAD v1.1 dev-set loading and BERT feature-conversion helpers.
import json
import os
from pathlib import Path
from transformers import BertTokenizer
import numpy as np
from tinygrad.helpers import fetch

# Local cache directory for the downloaded SQuAD dataset files.
BASEDIR = Path(__file__).parent / "squad"
  8. def init_dataset():
  9. os.makedirs(BASEDIR, exist_ok=True)
  10. fetch("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", BASEDIR / "dev-v1.1.json")
  11. with open(BASEDIR / "dev-v1.1.json") as f:
  12. data = json.load(f)["data"]
  13. examples = []
  14. for article in data:
  15. for paragraph in article["paragraphs"]:
  16. text = paragraph["context"]
  17. doc_tokens = []
  18. prev_is_whitespace = True
  19. for c in text:
  20. if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
  21. prev_is_whitespace = True
  22. else:
  23. if prev_is_whitespace:
  24. doc_tokens.append(c)
  25. else:
  26. doc_tokens[-1] += c
  27. prev_is_whitespace = False
  28. for qa in paragraph["qas"]:
  29. qa_id = qa["id"]
  30. q_text = qa["question"]
  31. examples.append({
  32. "id": qa_id,
  33. "question": q_text,
  34. "context": doc_tokens,
  35. "answers": list(map(lambda x: x["text"], qa["answers"]))
  36. })
  37. return examples
  38. def _check_is_max_context(doc_spans, cur_span_index, position):
  39. best_score, best_span_index = None, None
  40. for di, (doc_start, doc_length) in enumerate(doc_spans):
  41. end = doc_start + doc_length - 1
  42. if position < doc_start:
  43. continue
  44. if position > end:
  45. continue
  46. num_left_context = position - doc_start
  47. num_right_context = end - position
  48. score = min(num_left_context, num_right_context) + 0.01 * doc_length
  49. if best_score is None or score > best_score:
  50. best_score = score
  51. best_span_index = di
  52. return cur_span_index == best_span_index
  53. def convert_example_to_features(example, tokenizer):
  54. query_tokens = tokenizer.tokenize(example["question"])
  55. if len(query_tokens) > 64:
  56. query_tokens = query_tokens[:64]
  57. tok_to_orig_index = []
  58. orig_to_tok_index = []
  59. all_doc_tokens = []
  60. for i, token in enumerate(example["context"]):
  61. orig_to_tok_index.append(len(all_doc_tokens))
  62. sub_tokens = tokenizer.tokenize(token)
  63. for sub_token in sub_tokens:
  64. tok_to_orig_index.append(i)
  65. all_doc_tokens.append(sub_token)
  66. max_tokens_for_doc = 384 - len(query_tokens) - 3
  67. doc_spans = []
  68. start_offset = 0
  69. while start_offset < len(all_doc_tokens):
  70. length = len(all_doc_tokens) - start_offset
  71. length = min(length, max_tokens_for_doc)
  72. doc_spans.append((start_offset, length))
  73. if start_offset + length == len(all_doc_tokens):
  74. break
  75. start_offset += min(length, 128)
  76. outputs = []
  77. for di, (doc_start, doc_length) in enumerate(doc_spans):
  78. tokens = []
  79. token_to_orig_map = {}
  80. token_is_max_context = {}
  81. segment_ids = []
  82. tokens.append("[CLS]")
  83. segment_ids.append(0)
  84. for token in query_tokens:
  85. tokens.append(token)
  86. segment_ids.append(0)
  87. tokens.append("[SEP]")
  88. segment_ids.append(0)
  89. for i in range(doc_length):
  90. split_token_index = doc_start + i
  91. token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
  92. token_is_max_context[len(tokens)] = _check_is_max_context(doc_spans, di, split_token_index)
  93. tokens.append(all_doc_tokens[split_token_index])
  94. segment_ids.append(1)
  95. tokens.append("[SEP]")
  96. segment_ids.append(1)
  97. input_ids = tokenizer.convert_tokens_to_ids(tokens)
  98. input_mask = [1] * len(input_ids)
  99. while len(input_ids) < 384:
  100. input_ids.append(0)
  101. input_mask.append(0)
  102. segment_ids.append(0)
  103. assert len(input_ids) == 384
  104. assert len(input_mask) == 384
  105. assert len(segment_ids) == 384
  106. outputs.append({
  107. "input_ids": np.expand_dims(np.array(input_ids), 0).astype(np.float32),
  108. "input_mask": np.expand_dims(np.array(input_mask), 0).astype(np.float32),
  109. "segment_ids": np.expand_dims(np.array(segment_ids), 0).astype(np.float32),
  110. "token_to_orig_map": token_to_orig_map,
  111. "token_is_max_context": token_is_max_context,
  112. "tokens": tokens,
  113. })
  114. return outputs
  115. def iterate(tokenizer, start=0):
  116. examples = init_dataset()
  117. print(f"there are {len(examples)} pairs in the dataset")
  118. for i in range(start, len(examples)):
  119. example = examples[i]
  120. features = convert_example_to_features(example, tokenizer)
  121. # we need to yield all features here as the f1 score is the maximum over all features
  122. yield features, example
  123. if __name__ == "__main__":
  124. tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "weights" / "bert_vocab.txt"))
  125. X, Y = next(iterate(tokenizer))
  126. print(" ".join(X[0]["tokens"]))
  127. print(X[0]["input_ids"].shape, Y)