wikipedia.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. # Preprocessing of downloaded text from Wikipedia for MLPerf BERT training
  2. # This is a modified version of the original script:
  3. # https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/cleanup_scripts/create_pretraining_data.py
  4. # ENV VARS:
  5. # MAX_SEQ_LENGTH - Maximum sequence length
  6. # MAX_PREDICTIONS_PER_SEQ - Maximum number of masked LM predictions per sequence
  7. # RANDOM_SEED - Random seed
  8. # DUPE_FACTOR - Number of times to duplicate the input data with different masks
  9. # MASKED_LM_PROB - Probability of masking a token
  10. # SHORT_SEQ_PROB - Probability of picking a sequence shorter than MAX_SEQ_LENGTH
  11. import os, sys, pickle, random, unicodedata
  12. from pathlib import Path
  13. import numpy as np
  14. from tqdm import tqdm
  15. from tqdm.contrib.concurrent import process_map
  16. from tinygrad.helpers import diskcache, getenv
  17. BASEDIR = getenv('BASEDIR', Path(__file__).parent / "wiki")
  18. ################### Tokenization #####################
  19. def _is_whitespace(char:str) -> bool:
  20. if char == " " or char == "\t" or char == "\n" or char == "\r":
  21. return True
  22. return unicodedata.category(char) == "Zs"
  23. def _is_control(char:str) -> bool:
  24. if char == "\t" or char == "\n" or char == "\r":
  25. return False
  26. return unicodedata.category(char).startswith("C")
  27. def _is_punctuation(char:str) -> bool:
  28. # range(33, 48) -> ! " # $ % & ' ( ) * + , - . /
  29. # range(58, 65) -> : ; < = > ? @
  30. # range(91, 97) -> [ \ ] ^ _
  31. # range(123, 127) -> { | } ~
  32. if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127):
  33. return True
  34. return unicodedata.category(char).startswith("P")
  35. def _is_chinese_char(cp:int) -> bool:
  36. if ((cp >= 0x4E00 and cp <= 0x9FFF) or
  37. (cp >= 0x3400 and cp <= 0x4DBF) or
  38. (cp >= 0x20000 and cp <= 0x2A6DF) or
  39. (cp >= 0x2A700 and cp <= 0x2B73F) or
  40. (cp >= 0x2B740 and cp <= 0x2B81F) or
  41. (cp >= 0x2B820 and cp <= 0x2CEAF) or
  42. (cp >= 0xF900 and cp <= 0xFAFF) or
  43. (cp >= 0x2F800 and cp <= 0x2FA1F)):
  44. return True
  45. return False
  46. def _run_split_on_punc(text:str) -> list[str]:
  47. if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
  48. return [text]
  49. start_new_word = True
  50. output = []
  51. for i in range(len(text)):
  52. if _is_punctuation(char := text[i]):
  53. output.append([char])
  54. start_new_word = True
  55. else:
  56. if start_new_word:
  57. output.append([])
  58. start_new_word = False
  59. output[-1].append(char)
  60. return ["".join(x) for x in output]
  61. def _run_strip_accents(text:str) -> str:
  62. output = []
  63. for char in unicodedata.normalize("NFD", text):
  64. if unicodedata.category(char) != "Mn":
  65. output.append(char)
  66. return "".join(output)
  67. def _clean_text(text:str) -> str:
  68. output = []
  69. for char in text:
  70. if not ((cp := ord(char)) == 0 or cp == 0xfffd or _is_control(char)):
  71. output.append(" " if _is_whitespace(char) else char)
  72. return "".join(output)
  73. def _tokenize_chinese_chars(text:str) -> str:
  74. output = []
  75. for char in text:
  76. cp = ord(char)
  77. if _is_chinese_char(cp):
  78. output.append(" ")
  79. output.append(char)
  80. output.append(" ")
  81. else:
  82. output.append(char)
  83. return "".join(output)
  84. def whitespace_tokenize(text):
  85. if not (text := text.strip()): return []
  86. return text.split()
  87. def _wordpiece_tokenize(text:str, vocab:dict[str, int]) -> list[str]:
  88. text = text.decode("utf-8", "ignore") if isinstance(text, bytes) else text
  89. output_tokens = []
  90. for token in text.strip().split():
  91. chars = list(token)
  92. if len(chars) > 200:
  93. output_tokens.append("[UNK]")
  94. continue
  95. is_bad = False
  96. start = 0
  97. sub_tokens = []
  98. while start < len(chars):
  99. end = len(chars)
  100. cur_substr = None
  101. while start < end:
  102. substr = "".join(chars[start:end])
  103. if start > 0: substr = "##" + substr
  104. if substr in vocab:
  105. cur_substr = substr
  106. break
  107. end -= 1
  108. if cur_substr is None:
  109. is_bad = True
  110. break
  111. sub_tokens.append(cur_substr)
  112. start = end
  113. if is_bad: output_tokens.append("[UNK]")
  114. else: output_tokens.extend(sub_tokens)
  115. return output_tokens
  116. class Tokenizer:
  117. def __init__(self, vocab_file):
  118. self.vocab = {}
  119. with open(vocab_file) as f:
  120. for line in f:
  121. line = line.decode("utf-8", "ignore") if isinstance(line, bytes) else line
  122. if (token := line.strip()) and token not in self.vocab: self.vocab[token] = len(self.vocab)
  123. self.inv_vocab = {v: k for k, v in self.vocab.items()}
  124. def tokenize(self, text:str) -> list[str]:
  125. # BasicTokenizer
  126. split_tokens = []
  127. for token in whitespace_tokenize(_tokenize_chinese_chars(_clean_text(text.decode("utf-8", "ignore") if isinstance(text, bytes) else text))):
  128. split_tokens.extend(_run_split_on_punc(_run_strip_accents(token.lower())))
  129. split_tokens = " ".join(split_tokens).strip().split()
  130. # WordpieceTokenizer
  131. tokens = []
  132. for token in split_tokens:
  133. tokens.extend(_wordpiece_tokenize(token, self.vocab))
  134. return tokens
  135. def convert_tokens_to_ids(self, tokens:list[str]) -> list[int]: return [self.vocab[token] for token in tokens]
  136. def convert_ids_to_tokens(self, ids:list[int]) -> list[str]: return [self.inv_vocab[id] for id in ids]
  137. ##################### Feature transformation #####################
  138. def truncate_seq_pair(tokens_a:list[str], tokens_b:list[str], max_num_tokens:int, rng:random.Random) -> None:
  139. while True:
  140. total_length = len(tokens_a) + len(tokens_b)
  141. if total_length <= max_num_tokens:
  142. break
  143. trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
  144. assert len(trunc_tokens) >= 1
  145. if rng.random() < 0.5:
  146. del trunc_tokens[0]
  147. else:
  148. trunc_tokens.pop()
  149. def create_masked_lm_predictions(tokens:list[str], tokenizer:Tokenizer, rng:random.Random, vocab_words:list[str]) -> tuple[list[str], list[int], list[str]]:
  150. cand_indices = []
  151. for i, token in enumerate(tokens):
  152. if token == "[CLS]" or token == "[SEP]":
  153. continue
  154. cand_indices.append(i)
  155. rng.shuffle(cand_indices)
  156. output_tokens = list(tokens)
  157. num_to_predict = min(getenv('MAX_PREDICTIONS_PER_SEQ', 76), max(1, int(round(len(tokens) * getenv("MASKED_LM_PROB", 0.15)))))
  158. masked_lms = []
  159. covered_indices = set()
  160. for index in cand_indices:
  161. if len(masked_lms) >= num_to_predict:
  162. break
  163. if index in covered_indices:
  164. continue
  165. covered_indices.add(index)
  166. masked_token = None
  167. if rng.random() < 0.8:
  168. masked_token = "[MASK]"
  169. else:
  170. if rng.random() < 0.5:
  171. masked_token = tokens[index]
  172. else:
  173. masked_token = vocab_words[rng.randint(0, len(tokenizer.vocab) - 1)]
  174. output_tokens[index] = masked_token
  175. masked_lms.append((index, tokens[index]))
  176. masked_lms = sorted(masked_lms, key=lambda x: x[0])
  177. masked_lm_positions = []
  178. masked_lm_labels = []
  179. for p in masked_lms:
  180. masked_lm_positions.append(p[0])
  181. masked_lm_labels.append(p[1])
  182. return output_tokens, masked_lm_positions, masked_lm_labels
def create_instances_from_document(rng:random.Random, tokenizer:Tokenizer, doc:list[str], di:int, documents:list[list[str]]) -> list[dict]:
  """Build next-sentence-prediction + masked-LM instances from one document.

  Greedily accumulates segments from `doc` into chunks of roughly
  `target_seq_length` tokens, splits each chunk into an A/B sentence pair
  (B is drawn from a random other document ~50% of the time), then applies
  token masking. `di` is `doc`'s index in `documents`, used to avoid picking
  the same document as the random-B source.
  """
  max_num_tokens = getenv('MAX_SEQ_LENGTH', 512) - 3 # [CLS] + 2 * [SEP]
  target_seq_length = max_num_tokens
  # occasionally aim for a shorter sequence so the model also sees short inputs
  if rng.random() < getenv("SHORT_SEQ_PROB", 0.1):
    target_seq_length = rng.randint(2, max_num_tokens)
  instances = []
  current_chunk = []   # segments (token lists) accumulated so far
  current_length = 0   # total token count of current_chunk
  i = 0
  while i < len(doc):
    segment = doc[i]
    current_chunk.append(segment)
    current_length += len(segment)
    # flush the chunk at end-of-document or once the target length is reached
    if i == len(doc) - 1 or current_length >= target_seq_length:
      if current_chunk:
        # a_end = number of leading segments that form sentence A
        a_end = 1
        if len(current_chunk) >= 2:
          a_end = rng.randint(1, len(current_chunk) - 1)
        tokens_a = []
        for j in range(a_end):
          tokens_a.extend(current_chunk[j])
        tokens_b = []
        is_random_next = False
        # random-next: always for single-segment chunks, otherwise a coin flip
        if len(current_chunk) == 1 or rng.random() < 0.5:
          is_random_next = True
          target_b_length = target_seq_length - len(tokens_a)
          # up to 10 tries to pick a different document than this one
          for _ in range(10):
            random_document_index = rng.randint(0, len(documents) - 1)
            if random_document_index != di:
              break
          random_document = documents[random_document_index]
          random_start = rng.randint(0, len(random_document) - 1)
          for j in range(random_start, len(random_document)):
            tokens_b.extend(random_document[j])
            if len(tokens_b) >= target_b_length:
              break
          # segments not consumed by sentence A are put back for the next chunk
          num_unused_segments = len(current_chunk) - a_end
          i -= num_unused_segments
        else:
          # actual-next: sentence B is simply the remainder of the chunk
          is_random_next = False
          for j in range(a_end, len(current_chunk)):
            tokens_b.extend(current_chunk[j])
        truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
        assert len(tokens_a) >= 1
        assert len(tokens_b) >= 1
        # assemble [CLS] A [SEP] B [SEP]; segment id 0 for A's half, 1 for B's
        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
          tokens.append(token)
          segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)
        for token in tokens_b:
          tokens.append(token)
          segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)
        tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(tokens, tokenizer, rng, list(tokenizer.vocab.keys()))
        instances.append({
          "tokens": tokens,
          "segment_ids": segment_ids,
          "masked_lm_positions": masked_lm_positions,
          "masked_lm_labels": masked_lm_labels,
          "is_random_next": is_random_next
        })
      current_chunk = []
      current_length = 0
    i += 1
  return instances
  254. def get_documents(rng:random.Random, tokenizer:Tokenizer, fn:str) -> list[list[str]]:
  255. documents = [[]]
  256. with open(BASEDIR / fn) as f:
  257. for line in f.readlines():
  258. if not (line := line.decode("utf-8", "ignore") if isinstance(line, bytes) else line): break
  259. if not (line := line.strip()): documents.append([])
  260. if (tokens := tokenizer.tokenize(line)): documents[-1].append(tokens)
  261. documents = [x for x in documents if x]
  262. rng.shuffle(documents)
  263. return documents
  264. def get_instances(rng:random.Random, tokenizer:Tokenizer, documents:list[list[str]]) -> list[dict]:
  265. instances = []
  266. for _ in range(getenv('DUPE_FACTOR', 10)):
  267. for di, doc in enumerate(documents):
  268. instances.extend(create_instances_from_document(rng, tokenizer, doc, di, documents))
  269. rng.shuffle(instances)
  270. return instances
  271. def instance_to_features(instance:dict, tokenizer:Tokenizer) -> dict:
  272. input_ids = tokenizer.convert_tokens_to_ids(instance["tokens"])
  273. input_mask = [1] * len(input_ids)
  274. segment_ids = instance["segment_ids"]
  275. max_seq_length = getenv('MAX_SEQ_LENGTH', 512)
  276. assert len(input_ids) <= max_seq_length
  277. while len(input_ids) < max_seq_length:
  278. input_ids.append(0)
  279. input_mask.append(0)
  280. segment_ids.append(0)
  281. assert len(input_ids) == max_seq_length
  282. assert len(input_mask) == max_seq_length
  283. assert len(segment_ids) == max_seq_length
  284. masked_lm_positions = instance["masked_lm_positions"]
  285. masked_lm_ids = tokenizer.convert_tokens_to_ids(instance["masked_lm_labels"])
  286. masked_lm_weights = [1.0] * len(masked_lm_ids)
  287. while len(masked_lm_positions) < getenv("MAX_PREDICTIONS_PER_SEQ", 76):
  288. masked_lm_positions.append(0)
  289. masked_lm_ids.append(0)
  290. masked_lm_weights.append(0.0)
  291. next_sentence_label = 1 if instance["is_random_next"] else 0
  292. return {
  293. "input_ids": np.expand_dims(np.array(input_ids, dtype=np.int32), 0),
  294. "input_mask": np.expand_dims(np.array(input_mask, dtype=np.int32), 0),
  295. "segment_ids": np.expand_dims(np.array(segment_ids, dtype=np.int32), 0),
  296. "masked_lm_positions": np.expand_dims(np.array(masked_lm_positions, dtype=np.int32), 0),
  297. "masked_lm_ids": np.expand_dims(np.array(masked_lm_ids, dtype=np.int32), 0),
  298. "masked_lm_weights": np.expand_dims(np.array(masked_lm_weights, dtype=np.float32), 0),
  299. "next_sentence_labels": np.expand_dims(np.array([next_sentence_label], dtype=np.int32), 0),
  300. }
  301. def process_part(part:int):
  302. tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
  303. os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
  304. for i, feature_batch in enumerate(process_iterate(tokenizer, val=False, part=part)):
  305. with open(BASEDIR / f"train/{str(part)}/{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
  306. pickle.dump(feature_batch, f)
  307. def process_iterate(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dict]: # Convert raw text to masked NSP samples
  308. rng = random.Random(getenv('RANDOM_SEED', 12345))
  309. if val:
  310. tqdm.write("Getting samples from dataset")
  311. documents = get_documents(rng, tokenizer, "results4/eval.txt")
  312. instances = get_instances(rng, tokenizer, documents)
  313. tqdm.write(f"There are {len(instances)} samples in the dataset")
  314. tqdm.write(f"Picking 10000 samples")
  315. pick_ratio = len(instances) / 10000
  316. picks = [instance_to_features(instances[int(inst*pick_ratio)], tokenizer) for inst in range(10000)]
  317. for batch in range(10):
  318. yield picks[batch*1000:(batch+1)*1000]
  319. else:
  320. documents = get_documents(rng, tokenizer, f"results4/part-{part:05d}-of-00500")
  321. instances = get_instances(rng, tokenizer, documents)
  322. while len(instances) > 0:
  323. batch_size = min(1000, len(instances)) # We batch 1000 samples to one file
  324. batch = instances[:batch_size]
  325. del instances[:batch_size]
  326. yield [instance_to_features(instance, tokenizer) for instance in batch]
  327. ##################### Load files #####################
  328. def get_wiki_val_files(): return sorted(list((BASEDIR / "eval/").glob("*.pkl")))
  329. @diskcache
  330. def get_wiki_train_files(): return sorted(list((BASEDIR / "train/").glob("*/*.pkl")))
  331. if __name__ == "__main__":
  332. tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
  333. assert len(sys.argv) > 1, "Usage: python wikipedia.py pre-eval|pre-train [part]|all"
  334. if sys.argv[1] == "pre-eval": # Generate 10000 eval samples
  335. os.makedirs(BASEDIR / "eval", exist_ok=True)
  336. for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=True)), total=10):
  337. with open(BASEDIR / f"eval/{i}.pkl", "wb") as f:
  338. pickle.dump(feature_batch, f)
  339. elif sys.argv[1] == "pre-train":
  340. os.makedirs(BASEDIR / "train", exist_ok=True)
  341. if sys.argv[2] == "all": # Use all 500 parts for training generation
  342. process_map(process_part, [part for part in range(500)], max_workers=getenv('NUM_WORKERS', min(os.cpu_count(), 32)), chunksize=1)
  343. else: # Use a specific part for training generation
  344. part = int(sys.argv[2])
  345. os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
  346. for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=False, part=part))):
  347. with open(BASEDIR / f"train/{str(part)}/{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
  348. pickle.dump(feature_batch, f)