external_test_checkpoint_loading.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
# Test whether the pretrained weights from the first BERT pretraining phase have been loaded correctly.
# Usage:
# 1. Download the BERT checkpoints with `wikipedia_download.py`
# Command: BASEDIR=/path/to/wiki python3 wikipedia_download.py
# 2. Run this script (adjust EVAL_BS and GPUS as needed). Checkpoint files are read from INIT_CKPT_DIR, which defaults to BASEDIR.
# Command: EVAL_BEAM=4 DEFAULT_FLOAT=half GPUS=6 BASEDIR=/path/to/wiki python3 test/external/mlperf_bert/external_test_checkpoint_loading.py
  7. import os
  8. from tqdm import tqdm
  9. from tinygrad.tensor import Tensor
  10. from tinygrad.device import Device
  11. from tinygrad.helpers import getenv
  12. from tinygrad.nn.state import get_state_dict
  13. from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert
  14. from examples.mlperf.dataloader import batch_load_val_bert
  15. from examples.mlperf.model_train import eval_step_bert
  16. if __name__ == "__main__":
  17. BASEDIR = os.environ["BASEDIR"] = getenv("BASEDIR", "/raid/datasets/wiki")
  18. INIT_CKPT_DIR = getenv("INIT_CKPT_DIR", BASEDIR)
  19. GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  20. EVAL_BS = getenv("EVAL_BS", 4 * len(GPUS))
  21. max_eval_steps = (10000 + EVAL_BS - 1) // EVAL_BS
  22. for i in range(10):
  23. assert os.path.exists(os.path.join(BASEDIR, "eval", f"{i}.pkl")), \
  24. f"File {i}.pkl does not exist in {os.path.join(BASEDIR, 'eval')}"
  25. required_files = ["checkpoint", "model.ckpt-28252.data-00000-of-00001", "model.ckpt-28252.index", "model.ckpt-28252.meta"]
  26. assert all(os.path.exists(os.path.join(INIT_CKPT_DIR, f)) for f in required_files), \
  27. f"Missing checkpoint files in INIT_CKPT_DIR: {required_files}"
  28. Tensor.training = False
  29. model = get_mlperf_bert_model(INIT_CKPT_DIR)
  30. for _, x in get_state_dict(model).items():
  31. x.realize().to_(GPUS)
  32. eval_accuracy = []
  33. eval_it = iter(batch_load_val_bert(EVAL_BS))
  34. for _ in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps):
  35. eval_data = get_data_bert(GPUS, eval_it)
  36. eval_result: dict[str, Tensor] = eval_step_bert(model, eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], \
  37. eval_data["masked_lm_positions"], eval_data["masked_lm_ids"], \
  38. eval_data["masked_lm_weights"], eval_data["next_sentence_labels"])
  39. mlm_accuracy = eval_result["masked_lm_accuracy"].numpy().item()
  40. eval_accuracy.append(mlm_accuracy)
  41. total_lm_accuracy = sum(eval_accuracy) / len(eval_accuracy)
  42. assert total_lm_accuracy >= 0.34, "Checkpoint loaded incorrectly. Accuracy should be very close to 0.34085 as per MLPerf BERT README."
  43. print(f"Checkpoint loaded correctly. Accuracy of {total_lm_accuracy*100:.3f}% achieved. (Reference: 34.085%)")