import json
import os
from pathlib import Path
from transformers import BertTokenizer
import numpy as np
from tinygrad.helpers import fetch

BASEDIR = Path(__file__).parent / "squad"
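
# init_dataset downloads the SQuAD v1.1 dev set (once) into BASEDIR, whitespace-tokenizes
# each paragraph context, and returns one example dict per question holding its id, the
# question text, the context tokens, and the list of reference answer strings.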
def init_dataset():
  os.makedirs(BASEDIR, exist_ok=True)
  fetch("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", BASEDIR / "dev-v1.1.json")

  with open(BASEDIR / "dev-v1.1.json") as f:
    data = json.load(f)["data"]

  examples = []
  for article in data:
    for paragraph in article["paragraphs"]:
      text = paragraph["context"]
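      # split the raw context on whitespace (0x202F is the narrow no-break space),
      # producing one word-level token per whitespace-separated chunk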
      doc_tokens = []
      prev_is_whitespace = True
      for c in text:
        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
          prev_is_whitespace = True
        else:
          if prev_is_whitespace:
            doc_tokens.append(c)
          else:
            doc_tokens[-1] += c
          prev_is_whitespace = False

      for qa in paragraph["qas"]:
        qa_id = qa["id"]
        q_text = qa["question"]

        examples.append({
          "id": qa_id,
          "question": q_text,
          "context": doc_tokens,
          "answers": list(map(lambda x: x["text"], qa["answers"]))
        })
  return examples
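
# a context token can land in several overlapping doc spans; score each span for a token
# as min(left context, right context) + 0.01 * span length and report whether the current
# span is the one that gives this token its maximum context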
def _check_is_max_context(doc_spans, cur_span_index, position):
  best_score, best_span_index = None, None
  for di, (doc_start, doc_length) in enumerate(doc_spans):
    end = doc_start + doc_length - 1
    if position < doc_start:
      continue
    if position > end:
      continue
    num_left_context = position - doc_start
    num_right_context = end - position
    score = min(num_left_context, num_right_context) + 0.01 * doc_length
    if best_score is None or score > best_score:
      best_score = score
      best_span_index = di
  return cur_span_index == best_span_index
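
# convert a single example into fixed-length model features: tokenize the question
# (truncated to 64 wordpieces), wordpiece-tokenize the context, split long contexts into
# overlapping windows, and build padded input_ids / input_mask / segment_ids of length 384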
def convert_example_to_features(example, tokenizer):
  query_tokens = tokenizer.tokenize(example["question"])

  if len(query_tokens) > 64:
    query_tokens = query_tokens[:64]

  tok_to_orig_index = []
  orig_to_tok_index = []
  all_doc_tokens = []
  for i, token in enumerate(example["context"]):
    orig_to_tok_index.append(len(all_doc_tokens))
    sub_tokens = tokenizer.tokenize(token)
    for sub_token in sub_tokens:
      tok_to_orig_index.append(i)
      all_doc_tokens.append(sub_token)

  max_tokens_for_doc = 384 - len(query_tokens) - 3
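
  # the 3 reserved positions are [CLS], the [SEP] after the question, and the final [SEP];
  # contexts longer than max_tokens_for_doc are split into overlapping windows below,
  # advancing with a stride of 128 wordpieces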
  doc_spans = []
  start_offset = 0
  while start_offset < len(all_doc_tokens):
    length = len(all_doc_tokens) - start_offset
    length = min(length, max_tokens_for_doc)
    doc_spans.append((start_offset, length))
    if start_offset + length == len(all_doc_tokens):
      break
    start_offset += min(length, 128)

  outputs = []
  for di, (doc_start, doc_length) in enumerate(doc_spans):
    tokens = []
    token_to_orig_map = {}
    token_is_max_context = {}
    segment_ids = []

    tokens.append("[CLS]")
    segment_ids.append(0)
    for token in query_tokens:
      tokens.append(token)
      segment_ids.append(0)
    tokens.append("[SEP]")
    segment_ids.append(0)

    for i in range(doc_length):
      split_token_index = doc_start + i
      token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
      token_is_max_context[len(tokens)] = _check_is_max_context(doc_spans, di, split_token_index)
      tokens.append(all_doc_tokens[split_token_index])
      segment_ids.append(1)
    tokens.append("[SEP]")
    segment_ids.append(1)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
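
    # zero-pad ids, mask, and segment ids up to the fixed sequence length of 384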
    while len(input_ids) < 384:
      input_ids.append(0)
      input_mask.append(0)
      segment_ids.append(0)

    assert len(input_ids) == 384
    assert len(input_mask) == 384
    assert len(segment_ids) == 384

    outputs.append({
      "input_ids": np.expand_dims(np.array(input_ids), 0).astype(np.float32),
      "input_mask": np.expand_dims(np.array(input_mask), 0).astype(np.float32),
      "segment_ids": np.expand_dims(np.array(segment_ids), 0).astype(np.float32),
      "token_to_orig_map": token_to_orig_map,
      "token_is_max_context": token_is_max_context,
      "tokens": tokens,
    })

  return outputs
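
# generator over the dataset: yields (features, example) pairs, where features is the list
# of windowed feature dicts produced for that example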
def iterate(tokenizer, start=0):
  examples = init_dataset()
  print(f"there are {len(examples)} pairs in the dataset")

  for i in range(start, len(examples)):
    example = examples[i]
    features = convert_example_to_features(example, tokenizer)
    # we need to yield all features here as the f1 score is the maximum over all features
    yield features, example
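
# quick smoke test: build features for the first example and print its tokens and shapes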
if __name__ == "__main__":
  tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "weights" / "bert_vocab.txt"))

  X, Y = next(iterate(tokenizer))
  print(" ".join(X[0]["tokens"]))
  print(X[0]["input_ids"].shape, Y)