| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411 |
- # Preprocessing of downloaded text from Wikipedia for MLPerf BERT training
- # This is a modified version of the original script:
- # https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/cleanup_scripts/create_pretraining_data.py
- # ENV VARS:
- # MAX_SEQ_LENGTH - Maximum sequence length
- # MAX_PREDICTIONS_PER_SEQ - Maximum number of masked LM predictions per sequence
- # RANDOM_SEED - Random seed
- # DUPE_FACTOR - Number of times to duplicate the input data with different masks
- # MASKED_LM_PROB - Probability of masking a token
- # SHORT_SEQ_PROB - Probability of picking a sequence shorter than MAX_SEQ_LENGTH
- import os, sys, pickle, random, unicodedata
- from pathlib import Path
- import numpy as np
- from tqdm import tqdm
- from tqdm.contrib.concurrent import process_map
- from tinygrad.helpers import diskcache, getenv
- BASEDIR = getenv('BASEDIR', Path(__file__).parent / "wiki")
- ################### Tokenization #####################
def _is_whitespace(char:str) -> bool:
    """Return True for space, tab, newline, carriage return, or any Unicode space separator (category Zs)."""
    return char in (" ", "\t", "\n", "\r") or unicodedata.category(char) == "Zs"
def _is_control(char:str) -> bool:
    """Return True if `char` is in a Unicode "C*" category; tab, newline and carriage return are never treated as control."""
    return char not in ("\t", "\n", "\r") and unicodedata.category(char).startswith("C")
def _is_punctuation(char:str) -> bool:
    """Return True if `char` is punctuation.

    Every non-alphanumeric printable ASCII character counts as punctuation
    (including symbols such as "$", "^" and "`" that Unicode does not put in
    a "P*" category); everything else is decided by its Unicode category.
    """
    cp = ord(char)
    # 33-47:   ! " # $ % & ' ( ) * + , - . /
    # 58-64:   : ; < = > ? @
    # 91-96:   [ \ ] ^ _ `
    # 123-126: { | } ~
    if 33 <= cp <= 47 or 58 <= cp <= 64 or 91 <= cp <= 96 or 123 <= cp <= 126:
        return True
    return unicodedata.category(char).startswith("P")
def _is_chinese_char(cp:int) -> bool:
    """Return True if code point `cp` lies in one of the listed CJK ideograph ranges (bounds inclusive)."""
    cjk_ranges = (
        (0x4E00, 0x9FFF),
        (0x3400, 0x4DBF),
        (0x20000, 0x2A6DF),
        (0x2A700, 0x2B73F),
        (0x2B740, 0x2B81F),
        (0x2B820, 0x2CEAF),
        (0xF900, 0xFAFF),
        (0x2F800, 0x2FA1F),
    )
    return any(low <= cp <= high for low, high in cjk_ranges)
def _run_split_on_punc(text:str) -> list[str]:
    """Split `text` so that every punctuation character becomes its own piece.

    Runs of non-punctuation characters are kept together. The special tokens
    [UNK], [SEP], [PAD], [CLS] and [MASK] are returned unsplit.
    """
    if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
        return [text]
    groups = []
    in_word = False  # True while the last group is an open run of non-punctuation
    for ch in text:
        if _is_punctuation(ch):
            groups.append([ch])
            in_word = False
        elif in_word:
            groups[-1].append(ch)
        else:
            groups.append([ch])
            in_word = True
    return ["".join(group) for group in groups]
def _run_strip_accents(text:str) -> str:
    """Return `text` NFD-decomposed with all nonspacing marks (category Mn) removed."""
    decomposed = unicodedata.normalize("NFD", text)
    return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")
def _clean_text(text:str) -> str:
    """Drop NUL, U+FFFD and control characters from `text`; map every whitespace character to a plain space."""
    kept = []
    for ch in text:
        if ord(ch) in (0, 0xfffd) or _is_control(ch):
            continue
        kept.append(" " if _is_whitespace(ch) else ch)
    return "".join(kept)
def _tokenize_chinese_chars(text:str) -> str:
    """Surround every CJK ideograph in `text` with spaces so that it later splits into its own token."""
    return "".join(f" {ch} " if _is_chinese_char(ord(ch)) else ch for ch in text)
def whitespace_tokenize(text):
    """Split `text` on runs of whitespace; empty or all-whitespace input gives an empty list."""
    stripped = text.strip()
    return stripped.split() if stripped else []
def _wordpiece_tokenize(text:str, vocab:dict[str, int]) -> list[str]:
    """Greedy longest-match-first WordPiece tokenization.

    `text` is split on whitespace and each word is consumed left to right,
    always taking the longest prefix found in `vocab`; pieces after the first
    are looked up with a "##" prefix. A word longer than 200 characters, or
    one that cannot be fully covered by vocab entries, becomes a single
    "[UNK]". Bytes input is decoded as UTF-8, ignoring errors.
    """
    if isinstance(text, bytes):
        text = text.decode("utf-8", "ignore")
    result = []
    for word in text.strip().split():
        word_len = len(word)
        if word_len > 200:
            result.append("[UNK]")
            continue
        pieces = []
        pos = 0
        while pos < word_len:
            prefix = "##" if pos > 0 else ""
            # try the longest remaining substring first, shrinking from the right
            for stop in range(word_len, pos, -1):
                candidate = prefix + word[pos:stop]
                if candidate in vocab:
                    pieces.append(candidate)
                    pos = stop
                    break
            else:
                # nothing starting at `pos` is in the vocab: the whole word is unknown
                pieces = None
                break
        if pieces is None:
            result.append("[UNK]")
        else:
            result.extend(pieces)
    return result
class Tokenizer:
    """BERT-style tokenizer: basic cleanup/splitting followed by WordPiece lookup against a vocab file."""

    def __init__(self, vocab_file):
        """Load the vocabulary from `vocab_file` (one token per line).

        Ids are assigned in file order starting at 0; blank lines and
        repeated tokens are skipped and do not consume an id.
        """
        self.vocab = {}
        # Decode explicitly as UTF-8: relying on the locale default can fail or
        # mis-decode non-ASCII vocab entries on platforms where it is not UTF-8.
        with open(vocab_file, encoding="utf-8") as f:
            for line in f:
                if (token := line.strip()) and token not in self.vocab:
                    self.vocab[token] = len(self.vocab)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}

    def tokenize(self, text:str) -> list[str]:
        """Return the WordPiece tokens of `text` (bytes input is decoded as UTF-8, ignoring errors)."""
        if isinstance(text, bytes):
            text = text.decode("utf-8", "ignore")
        # BasicTokenizer: clean, isolate CJK characters, lowercase, strip accents, split on punctuation
        split_tokens = []
        for token in whitespace_tokenize(_tokenize_chinese_chars(_clean_text(text))):
            split_tokens.extend(_run_split_on_punc(_run_strip_accents(token.lower())))
        split_tokens = " ".join(split_tokens).strip().split()
        # WordpieceTokenizer
        tokens = []
        for token in split_tokens:
            tokens.extend(_wordpiece_tokenize(token, self.vocab))
        return tokens

    def convert_tokens_to_ids(self, tokens:list[str]) -> list[int]:
        """Map tokens to vocab ids; raises KeyError for a token that is not in the vocab."""
        return [self.vocab[token] for token in tokens]

    def convert_ids_to_tokens(self, ids:list[int]) -> list[str]:
        """Map vocab ids back to tokens; raises KeyError for an unknown id."""
        return [self.inv_vocab[token_id] for token_id in ids]
- ##################### Feature transformation #####################
def truncate_seq_pair(tokens_a:list[str], tokens_b:list[str], max_num_tokens:int, rng:random.Random) -> None:
    """Shrink `tokens_a`/`tokens_b` in place until their combined length is at most `max_num_tokens`.

    Each step removes one token from the longer list (`tokens_b` on a tie),
    taken from the front or the back depending on a draw from `rng`.
    """
    while len(tokens_a) + len(tokens_b) > max_num_tokens:
        longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        assert len(longer) >= 1
        if rng.random() < 0.5:
            del longer[0]
        else:
            longer.pop()
def create_masked_lm_predictions(tokens:list[str], tokenizer:Tokenizer, rng:random.Random, vocab_words:list[str]) -> tuple[list[str], list[int], list[str]]:
    """Choose masked-LM targets for `tokens` and return the masked sequence.

    Every position except "[CLS]"/"[SEP]" is a candidate. Candidates are
    shuffled with `rng`, and up to
    min(MAX_PREDICTIONS_PER_SEQ, max(1, round(len(tokens) * MASKED_LM_PROB)))
    of them are selected. A selected token is replaced by "[MASK]" with
    probability 0.8, kept as-is with probability 0.1, and otherwise replaced
    by a random entry of `vocab_words`.

    `tokenizer` is kept for interface compatibility and is not read here.

    Returns (output_tokens, masked_lm_positions, masked_lm_labels), where the
    positions are sorted ascending and the labels are the original tokens at
    those positions.
    """
    cand_indices = [i for i, token in enumerate(tokens) if token not in ("[CLS]", "[SEP]")]
    rng.shuffle(cand_indices)

    output_tokens = list(tokens)
    num_to_predict = min(getenv('MAX_PREDICTIONS_PER_SEQ', 76), max(1, int(round(len(tokens) * getenv("MASKED_LM_PROB", 0.15)))))

    masked_lms = []
    for index in cand_indices:
        if len(masked_lms) >= num_to_predict:
            break
        if rng.random() < 0.8:
            masked_token = "[MASK]"
        elif rng.random() < 0.5:
            masked_token = tokens[index]
        else:
            # Bound the draw by the list actually being indexed, so a vocab_words
            # of a different size than tokenizer.vocab cannot go out of range.
            masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
        output_tokens[index] = masked_token
        masked_lms.append((index, tokens[index]))

    masked_lms.sort(key=lambda x: x[0])
    masked_lm_positions = [position for position, _ in masked_lms]
    masked_lm_labels = [label for _, label in masked_lms]
    return output_tokens, masked_lm_positions, masked_lm_labels
def create_instances_from_document(rng:random.Random, tokenizer:Tokenizer, doc:list[str], di:int, documents:list[list[str]]) -> list[dict]:
    """Build masked-LM / next-sentence training instances from one document.

    `doc` is iterated as a sequence of segments, each a list of tokens;
    `documents` is the full collection and `di` is the index of `doc` in it
    (used to avoid drawing the random "next sentence" from the same document).

    Segments are accumulated until the target length is reached (or the
    document ends), then split into an A part and a B part. With probability
    0.5, or always when only one segment was accumulated, B is instead taken
    from a random document and the unused segments are put back for the next
    chunk. The pair is truncated to fit MAX_SEQ_LENGTH - 3 tokens, wrapped in
    [CLS]/[SEP], and masked.

    Returns a list of dicts with keys "tokens", "segment_ids",
    "masked_lm_positions", "masked_lm_labels" and "is_random_next".
    """
    max_num_tokens = getenv('MAX_SEQ_LENGTH', 512) - 3 # [CLS] + 2 * [SEP]

    target_seq_length = max_num_tokens
    if rng.random() < getenv("SHORT_SEQ_PROB", 0.1):
        target_seq_length = rng.randint(2, max_num_tokens)

    # Loop-invariant: build the vocab word list once per document instead of once per instance.
    vocab_words = list(tokenizer.vocab.keys())

    instances = []
    current_chunk = []
    current_length = 0
    i = 0
    while i < len(doc):
        segment = doc[i]
        current_chunk.append(segment)
        current_length += len(segment)
        if i == len(doc) - 1 or current_length >= target_seq_length:
            if current_chunk:
                # a_end: number of leading segments of the chunk that go into sentence A
                a_end = 1
                if len(current_chunk) >= 2:
                    a_end = rng.randint(1, len(current_chunk) - 1)

                tokens_a = [token for seg in current_chunk[:a_end] for token in seg]

                tokens_b = []
                if len(current_chunk) == 1 or rng.random() < 0.5:
                    is_random_next = True
                    target_b_length = target_seq_length - len(tokens_a)

                    # Try up to 10 times to draw a document other than the current one.
                    for _ in range(10):
                        random_document_index = rng.randint(0, len(documents) - 1)
                        if random_document_index != di:
                            break

                    random_document = documents[random_document_index]
                    random_start = rng.randint(0, len(random_document) - 1)
                    for j in range(random_start, len(random_document)):
                        tokens_b.extend(random_document[j])
                        if len(tokens_b) >= target_b_length:
                            break

                    # The segments that would have formed B were not used: rewind so they are revisited.
                    i -= len(current_chunk) - a_end
                else:
                    is_random_next = False
                    tokens_b = [token for seg in current_chunk[a_end:] for token in seg]
                truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

                assert len(tokens_a) >= 1
                assert len(tokens_b) >= 1

                # [CLS] A... [SEP] carry segment id 0; B... [SEP] carry segment id 1
                tokens = ["[CLS]", *tokens_a, "[SEP]", *tokens_b, "[SEP]"]
                segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

                tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(tokens, tokenizer, rng, vocab_words)
                instances.append({
                    "tokens": tokens,
                    "segment_ids": segment_ids,
                    "masked_lm_positions": masked_lm_positions,
                    "masked_lm_labels": masked_lm_labels,
                    "is_random_next": is_random_next
                })
            current_chunk = []
            current_length = 0
        i += 1
    return instances
def get_documents(rng:random.Random, tokenizer:Tokenizer, fn:str) -> list[list[str]]:
    """Read `BASEDIR / fn` and return its documents, tokenized and shuffled with `rng`.

    A blank line starts a new document; every other line is tokenized and
    appended to the current document if it yields any tokens. Documents left
    empty are dropped. Each returned document is therefore a list of token
    lists, one per non-empty input line.
    """
    documents = [[]]
    # Explicit UTF-8 so decoding does not depend on the locale; iterating the
    # file object streams lines instead of loading the whole file first.
    with open(BASEDIR / fn, encoding="utf-8") as f:
        for line in f:
            if not (line := line.strip()): documents.append([])
            if (tokens := tokenizer.tokenize(line)): documents[-1].append(tokens)
    documents = [x for x in documents if x]
    rng.shuffle(documents)
    return documents
def get_instances(rng:random.Random, tokenizer:Tokenizer, documents:list[list[str]]) -> list[dict]:
    """Generate instances for every document DUPE_FACTOR times (default 10), then shuffle them with `rng`."""
    dupe_factor = getenv('DUPE_FACTOR', 10)
    collected = []
    for _ in range(dupe_factor):
        for doc_index, document in enumerate(documents):
            collected += create_instances_from_document(rng, tokenizer, document, doc_index, documents)
    rng.shuffle(collected)
    return collected
def instance_to_features(instance:dict, tokenizer:Tokenizer) -> dict:
    """Convert one instance dict into fixed-length numpy feature arrays.

    The sequence fields are zero-padded to MAX_SEQ_LENGTH (default 512) and
    the masked-LM fields to MAX_PREDICTIONS_PER_SEQ (default 76). Every
    returned array has a leading axis of size 1.

    `instance` is not modified: the lists it holds are copied before padding,
    so the same instance can safely be converted more than once.

    Raises AssertionError if the instance has more than MAX_SEQ_LENGTH tokens.
    """
    max_seq_length = getenv('MAX_SEQ_LENGTH', 512)
    max_predictions = getenv("MAX_PREDICTIONS_PER_SEQ", 76)

    input_ids = tokenizer.convert_tokens_to_ids(instance["tokens"])
    assert len(input_ids) <= max_seq_length
    seq_padding = [0] * (max_seq_length - len(input_ids))
    input_mask = [1] * len(input_ids) + seq_padding
    # copy (list + list) so the padding never leaks back into the caller's instance
    segment_ids = list(instance["segment_ids"]) + seq_padding
    input_ids = input_ids + seq_padding

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    masked_lm_positions = list(instance["masked_lm_positions"])
    masked_lm_ids = tokenizer.convert_tokens_to_ids(instance["masked_lm_labels"])
    masked_lm_weights = [1.0] * len(masked_lm_ids)

    # padded prediction slots get position 0 / id 0 and weight 0.0
    pred_padding = max(0, max_predictions - len(masked_lm_positions))
    masked_lm_positions += [0] * pred_padding
    masked_lm_ids += [0] * pred_padding
    masked_lm_weights += [0.0] * pred_padding

    next_sentence_label = 1 if instance["is_random_next"] else 0

    return {
        "input_ids": np.expand_dims(np.array(input_ids, dtype=np.int32), 0),
        "input_mask": np.expand_dims(np.array(input_mask, dtype=np.int32), 0),
        "segment_ids": np.expand_dims(np.array(segment_ids, dtype=np.int32), 0),
        "masked_lm_positions": np.expand_dims(np.array(masked_lm_positions, dtype=np.int32), 0),
        "masked_lm_ids": np.expand_dims(np.array(masked_lm_ids, dtype=np.int32), 0),
        "masked_lm_weights": np.expand_dims(np.array(masked_lm_weights, dtype=np.float32), 0),
        "next_sentence_labels": np.expand_dims(np.array([next_sentence_label], dtype=np.int32), 0),
    }
def process_part(part:int):
    """Generate the training feature batches for dataset part `part` and pickle each batch under BASEDIR/train/<part>/."""
    vocab_path = getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt"
    tokenizer = Tokenizer(vocab_path)
    os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
    batches = process_iterate(tokenizer, val=False, part=part)
    for i, feature_batch in enumerate(batches):
        out_path = BASEDIR / f"train/{str(part)}/{part}_{i}_of_{len(feature_batch)}.pkl"
        with open(out_path, "wb") as f:
            pickle.dump(feature_batch, f)
def process_iterate(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dict]: # Convert raw text to masked NSP samples
    """Generator yielding batches (lists) of feature dicts built from the raw text files.

    With `val` set, instances are built from "results4/eval.txt", 10000 of
    them are picked at evenly spaced indices, and they are yielded as 10
    batches of 1000. Otherwise instances are built from training part `part`
    and yielded in batches of up to 1000.
    """
    rng = random.Random(getenv('RANDOM_SEED', 12345))

    if not val:
        documents = get_documents(rng, tokenizer, f"results4/part-{part:05d}-of-00500")
        instances = get_instances(rng, tokenizer, documents)
        while instances:
            # We batch 1000 samples to one file; consumed instances are dropped from the list
            head = instances[:1000]
            del instances[:1000]
            yield [instance_to_features(instance, tokenizer) for instance in head]
        return

    tqdm.write("Getting samples from dataset")
    documents = get_documents(rng, tokenizer, "results4/eval.txt")
    instances = get_instances(rng, tokenizer, documents)

    tqdm.write(f"There are {len(instances)} samples in the dataset")
    tqdm.write(f"Picking 10000 samples")

    stride = len(instances) / 10000
    selected = [instance_to_features(instances[int(k*stride)], tokenizer) for k in range(10000)]
    for lo in range(0, 10000, 1000):
        yield selected[lo:lo+1000]
- ##################### Load files #####################
def get_wiki_val_files():
    """Return the sorted paths of all `*.pkl` files directly under BASEDIR/eval/."""
    eval_dir = BASEDIR / "eval/"
    return sorted(eval_dir.glob("*.pkl"))
@diskcache
def get_wiki_train_files():
    """Return the sorted paths of all `*.pkl` files one directory level below BASEDIR/train/."""
    train_dir = BASEDIR / "train/"
    return sorted(train_dir.glob("*/*.pkl"))
if __name__ == "__main__":
    # Command line:
    #   pre-eval          -> write the 10 eval batches to BASEDIR/eval/<i>.pkl
    #   pre-train <part>  -> write the batches of one training part to BASEDIR/train/<part>/
    #   pre-train all     -> process all 500 training parts in parallel worker processes
    usage = "Usage: python wikipedia.py pre-eval|pre-train [part]|all"

    # Validate the arguments before loading the vocab, so a malformed call
    # prints the usage text instead of failing with IndexError or doing nothing.
    if len(sys.argv) < 2 or sys.argv[1] not in ("pre-eval", "pre-train"):
        sys.exit(usage)
    if sys.argv[1] == "pre-train" and len(sys.argv) < 3:
        sys.exit(usage)

    tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")

    if sys.argv[1] == "pre-eval": # Generate 10000 eval samples
        os.makedirs(BASEDIR / "eval", exist_ok=True)
        for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=True)), total=10):
            with open(BASEDIR / f"eval/{i}.pkl", "wb") as f:
                pickle.dump(feature_batch, f)
    else: # pre-train
        os.makedirs(BASEDIR / "train", exist_ok=True)
        if sys.argv[2] == "all": # Use all 500 parts for training generation
            # os.cpu_count() may return None when the count cannot be determined
            default_workers = min(os.cpu_count() or 1, 32)
            process_map(process_part, list(range(500)), max_workers=getenv('NUM_WORKERS', default_workers), chunksize=1)
        else: # Use a specific part for training generation
            part = int(sys.argv[2])
            os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
            for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=False, part=part))):
                with open(BASEDIR / f"train/{str(part)}/{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
                    pickle.dump(feature_batch, f)
|