| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
# Test whether the pretrained weights from the first BERT pretraining phase have been loaded correctly.
# Usage:
# 1. Download the BERT checkpoints with `wikipedia_download.py`.
#    Command: BASEDIR=/path/to/wiki python3 wikipedia_download.py
# 2. Run this script. (Adjust EVAL_BS and GPUS as needed; set INIT_CKPT_DIR if the checkpoint is not stored in BASEDIR.)
#    Command: EVAL_BEAM=4 DEFAULT_FLOAT=half GPUS=6 BASEDIR=/path/to/wiki python3 test/external/mlperf_bert/external_test_checkpoint_loading.py
- import os
- from tqdm import tqdm
- from tinygrad.tensor import Tensor
- from tinygrad.device import Device
- from tinygrad.helpers import getenv
- from tinygrad.nn.state import get_state_dict
- from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert
- from examples.mlperf.dataloader import batch_load_val_bert
- from examples.mlperf.model_train import eval_step_bert
if __name__ == "__main__":
  # Write the resolved BASEDIR back into the environment so that code reading it from the environment
  # (presumably the eval dataloader helpers — TODO confirm) sees the same default this script uses.
  BASEDIR = os.environ["BASEDIR"] = getenv("BASEDIR", "/raid/datasets/wiki")
  INIT_CKPT_DIR = getenv("INIT_CKPT_DIR", BASEDIR)
  GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  EVAL_BS = getenv("EVAL_BS", 4 * len(GPUS))
  if EVAL_BS <= 0: raise ValueError(f"EVAL_BS must be positive, got {EVAL_BS}")

  NUM_EVAL_SAMPLES = 10000  # number of eval samples run through the model
  NUM_EVAL_FILES = 10       # eval shards expected at BASEDIR/eval/{0..NUM_EVAL_FILES-1}.pkl
  ACCURACY_THRESHOLD = 0.34  # minimum masked-LM accuracy accepted as a correct load
  # ceil(NUM_EVAL_SAMPLES / EVAL_BS): the final step covers any remainder
  max_eval_steps = (NUM_EVAL_SAMPLES + EVAL_BS - 1) // EVAL_BS

  # Explicit raises instead of `assert`: these checks must still run under `python -O`.
  eval_dir = os.path.join(BASEDIR, "eval")
  missing_eval = [f"{i}.pkl" for i in range(NUM_EVAL_FILES) if not os.path.exists(os.path.join(eval_dir, f"{i}.pkl"))]
  if missing_eval:
    raise FileNotFoundError(f"Missing eval files in {eval_dir}: {missing_eval}")

  required_files = ["checkpoint", "model.ckpt-28252.data-00000-of-00001", "model.ckpt-28252.index", "model.ckpt-28252.meta"]
  missing_ckpt = [f for f in required_files if not os.path.exists(os.path.join(INIT_CKPT_DIR, f))]
  if missing_ckpt:
    # report only the files that are actually absent, plus the directory that was searched
    raise FileNotFoundError(f"Missing checkpoint files in INIT_CKPT_DIR ({INIT_CKPT_DIR}): {missing_ckpt}")

  Tensor.training = False
  model = get_mlperf_bert_model(INIT_CKPT_DIR)
  # Realize every weight and move it onto the GPUS device list in place via Tensor.to_
  for weight in get_state_dict(model).values():
    weight.realize().to_(GPUS)

  eval_accuracy: list[float] = []
  eval_it = iter(batch_load_val_bert(EVAL_BS))
  for _ in tqdm(range(max_eval_steps), desc="Evaluating"):
    eval_data = get_data_bert(GPUS, eval_it)
    eval_result: dict[str, Tensor] = eval_step_bert(
      model, eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], eval_data["masked_lm_positions"],
      eval_data["masked_lm_ids"], eval_data["masked_lm_weights"], eval_data["next_sentence_labels"])
    eval_accuracy.append(eval_result["masked_lm_accuracy"].numpy().item())

  # NOTE(review): unweighted mean of per-batch accuracies. If NUM_EVAL_SAMPLES % EVAL_BS != 0 the last batch is
  # partial (or padded) yet counts as much as a full batch — confirm how batch_load_val_bert handles the remainder.
  total_lm_accuracy = sum(eval_accuracy) / len(eval_accuracy)
  # `not >=` (rather than `<`) so that a NaN accuracy also fails, matching the original assert
  if not total_lm_accuracy >= ACCURACY_THRESHOLD:
    raise AssertionError(f"Checkpoint loaded incorrectly. Accuracy of {total_lm_accuracy*100:.3f}% is below {ACCURACY_THRESHOLD*100:.3f}%; "
                         "it should be very close to 0.34085 as per MLPerf BERT README.")
  print(f"Checkpoint loaded correctly. Accuracy of {total_lm_accuracy*100:.3f}% achieved. (Reference: 34.085%)")
|