import os, time, math, functools
from pathlib import Path
from tqdm import tqdm
import multiprocessing

from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear
from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup

from extra.lr_scheduler import LRSchedulerGroup
from examples.mlperf.helpers import get_training_state, load_training_state

def train_resnet():
  from extra.models import resnet
  from examples.mlperf.dataloader import batch_load_resnet
  from extra.datasets.imagenet import get_train_files, get_val_files
  from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup
  from examples.mlperf.initializers import Conv2dHeNormal, Linear
  from examples.hlb_cifar10 import UnsyncedBatchNorm

  config = {}
  seed = config["seed"] = getenv("SEED", 42)
  Tensor.manual_seed(seed)  # seed for weight initialization

  INITMLPERF = getenv("INITMLPERF")
  RUNMLPERF = getenv("RUNMLPERF")
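  # INITMLPERF: untimed init/warmup pass on fake data (e.g. to populate BEAM kernel caches); logs INIT_START/INIT_STOP
  # RUNMLPERF: the timed run on real data; logs RUN_START (and RUN_STOP once the target accuracy is reached)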
  if getenv("LOGMLPERF"):
    from mlperf_logging import mllog
    import mlperf_logging.mllog.constants as mllog_constants
    mllog.config(filename=f"result_{seed}.txt")
    mllog.config(root_dir=Path(__file__).parents[3].as_posix())  # truncate to log this. "file": "tinygrad/examples/mlperf/model_train.py"
    MLLOGGER = mllog.get_mllogger()
    if INITMLPERF:
      # common.yaml
      MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
      MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
      MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
      MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
      # closed_common.yaml
      MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.RESNET)
      diskcache_clear()
      MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
      MLLOGGER.start(key=mllog_constants.INIT_START)
    if RUNMLPERF:
      MLLOGGER.start(key=mllog_constants.RUN_START)
      MLLOGGER.event(key=mllog_constants.SEED, value=seed)
  else:
    MLLOGGER = None
  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  print(f"training on {GPUS}")
  for x in GPUS: Device[x]
  TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
  EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)

  # ** model definition and initializers **
  num_classes = 1000
  resnet.Conv2d = Conv2dHeNormal
  resnet.Linear = Linear
  if not getenv("SYNCBN"): resnet.BatchNorm = functools.partial(UnsyncedBatchNorm, num_devices=len(GPUS))
  model = resnet.ResNet50(num_classes)

  # shard weights and initialize in order
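  # when SYNCBN is off, BatchNorm running stats stay per-device (sharded on axis 0);
  # every other parameter is replicated to all GPUS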
  for k, x in get_state_dict(model).items():
    if not getenv("SYNCBN") and ("running_mean" in k or "running_var" in k):
      x.realize().shard_(GPUS, axis=0)
    else:
      x.realize().to_(GPUS)
  parameters = get_parameters(model)

  # ** hyperparameters **
  epochs = config["epochs"] = getenv("EPOCHS", 37)
  BS = config["BS"] = getenv("BS", 104 * len(GPUS))  # fp32 GPUS<=6 7900xtx can fit BS=112
  EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", BS)
  base_lr = config["base_lr"] = getenv("LR", 7.2 * (BS/1536))
  lr_warmup_epochs = config["lr_warmup_epochs"] = getenv("WARMUP_EPOCHS", 2)
  decay = config["decay"] = getenv("DECAY", 2e-4)
  loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 128.0 if dtypes.default_float == dtypes.float16 else 1.0)
  target, achieved = getenv("TARGET", 0.759), False
  eval_start_epoch = getenv("EVAL_START_EPOCH", 0)
  eval_freq = getenv("EVAL_FREQ", 1)
  steps_in_train_epoch = config["steps_in_train_epoch"] = (round_up(len(get_train_files()), BS) // BS)
  steps_in_val_epoch = config["steps_in_val_epoch"] = (round_up(len(get_val_files()), EVAL_BS) // EVAL_BS)
  config["DEFAULT_FLOAT"] = dtypes.default_float.name
  config["BEAM"] = BEAM.value
  config["TRAIN_BEAM"] = TRAIN_BEAM
  config["EVAL_BEAM"] = EVAL_BEAM
  config["WINO"] = WINO.value
  config["SYNCBN"] = getenv("SYNCBN")

  # ** Optimizer **
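  # BatchNorm params, biases and the downsample shortcut's BN ("downsample.1") are excluded from LARS
  # and from weight decay; they are trained with plain SGD instead (optimizer_skip below)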
  skip_list = [v for k, v in get_state_dict(model).items() if "bn" in k or "bias" in k or "downsample.1" in k]
  parameters = [x for x in parameters if x not in set(skip_list)]
  optimizer = LARS(parameters, base_lr, momentum=.9, weight_decay=decay)
  optimizer_skip = SGD(skip_list, base_lr, momentum=.9, weight_decay=0.0, classic=True)
  optimizer_group = OptimizerGroup(optimizer, optimizer_skip)

  # ** LR scheduler **
  scheduler = PolynomialDecayWithWarmup(optimizer, initial_lr=base_lr, end_lr=1e-4,
      train_steps=epochs * steps_in_train_epoch, warmup=lr_warmup_epochs * steps_in_train_epoch)
  scheduler_skip = PolynomialDecayWithWarmup(optimizer_skip, initial_lr=base_lr, end_lr=1e-4,
      train_steps=epochs * steps_in_train_epoch, warmup=lr_warmup_epochs * steps_in_train_epoch)
  scheduler_group = LRSchedulerGroup(scheduler, scheduler_skip)
  print(f"training with batch size {BS} for {epochs} epochs")

  # log mlperf hparams
  if MLLOGGER:
    if RUNMLPERF:
      MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=BS)
      from extra.datasets.imagenet import get_train_files, get_val_files
      MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=len(get_train_files()))
      MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=len(get_val_files()))
      MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
      MLLOGGER.event(key=mllog_constants.OPT_NAME, value="lars")
      assert scheduler.initial_lr == scheduler_skip.initial_lr
      assert scheduler.end_lr == scheduler_skip.end_lr
      assert scheduler.power == scheduler_skip.power
      MLLOGGER.event(key=mllog_constants.LARS_OPT_BASE_LEARNING_RATE, value=scheduler.initial_lr)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_END_LR, value=scheduler.end_lr)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_LR_DECAY_POLY_POWER, value=scheduler.power)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_LR_DECAY_STEPS, value=epochs)
      MLLOGGER.event(key=mllog_constants.LARS_EPSILON, value=0)  # does not support epsilon != 0
      MLLOGGER.event(key=mllog_constants.LARS_OPT_LEARNING_RATE_WARMUP_EPOCHS, value=lr_warmup_epochs)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_MOMENTUM, value=optimizer.momentum)
      MLLOGGER.event(key=mllog_constants.LARS_OPT_WEIGHT_DECAY, value=optimizer.wd)

  # ** resume from checkpointing **
  start_epoch = 0
  if ckpt:=getenv("RESUME", ""):
    load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
    start_epoch = int(scheduler.epoch_counter.numpy().item() / steps_in_train_epoch)
    print(f"resuming from {ckpt} at epoch {start_epoch}")

  # ** init wandb **
  WANDB = getenv("WANDB")
  if WANDB:
    import wandb
    wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
    wandb.init(config=config, **wandb_args)
  BENCHMARK = getenv("BENCHMARK")

  # ** jitted steps **
  input_mean = Tensor([123.68, 116.78, 103.94], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
  # mlperf reference resnet does not divide by input_std for some reason
  # input_std = Tensor([0.229, 0.224, 0.225], device=GPUS, dtype=dtypes.float32).reshape(1, -1, 1, 1)
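  # batches arrive as NHWC uint8; normalize() moves them to NCHW, subtracts the per-channel mean
  # and casts to the training dtype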
  def normalize(x): return (x.permute([0, 3, 1, 2]) - input_mean).cast(dtypes.default_float)

  @TinyJit
  def train_step(X, Y):
    optimizer_group.zero_grad()
    X = normalize(X)
    out = model.forward(X)
    loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
    top_1 = (out.argmax(-1) == Y).sum()
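    # loss scaling: scale the loss up before backward, then divide the gradients back down,
    # so fp16 gradients do not underflow (loss_scaler defaults to 1.0 for fp32)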
    (loss * loss_scaler).backward()
    for t in optimizer_group.params: t.grad = t.grad.contiguous() / loss_scaler
    optimizer_group.step()
    scheduler_group.step()
    return loss.realize(), top_1.realize()

  @TinyJit
  def eval_step(X, Y):
    X = normalize(X)
    out = model.forward(X)
    loss = out.cast(dtypes.float32).sparse_categorical_crossentropy(Y, label_smoothing=0.1)
    top_1 = (out.argmax(-1) == Y).sum()
    return loss.realize(), top_1.realize()
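
  # both data helpers return (sharded X, sharded Y, raw labels, cookie); holding the cookie keeps the
  # dataloader's batch buffer alive until it is dropped (see the prev_cookies handling below)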
  def fake_data_get(batch_size):
    x = Tensor.zeros(batch_size, 224, 224, 3, dtype=dtypes.uchar).contiguous()
    y = [0] * batch_size
    return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, None

  def data_get(it):
    x, y, cookie = next(it)
    return x.shard(GPUS, axis=0).realize(), Tensor(y, requires_grad=False).shard(GPUS, axis=0), y, cookie

  # ** epoch loop **
  step_times = []
  for e in range(start_epoch, epochs):
    # ** train loop **
    if MLLOGGER and RUNMLPERF:
      MLLOGGER.start(key=mllog_constants.EPOCH_START, value=e+1, metadata=dict(epoch_num=e+1))
    Tensor.training = True
    BEAM.value = TRAIN_BEAM
    if INITMLPERF:
      i, proc = 0, fake_data_get(BS)
    else:
      batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e, pad_first_batch=True)
      it = iter(tqdm(batch_loader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
      i, proc = 0, data_get(it)
    prev_cookies = []
    st = time.perf_counter()
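    # steady-state loop: enqueue the jitted train_step, then prefetch the next batch while the GPU
    # is busy; pt/dt/cl split each step into python, data-fetch and device time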
    while proc is not None:
      GlobalCounters.reset()
      (loss, top_1), y, proc = train_step(proc[0], proc[1]), proc[2], proc[3]
      pt = time.perf_counter()
      if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
      try:
        if INITMLPERF:
          next_proc = fake_data_get(BS)
        else:
          next_proc = data_get(it)
      except StopIteration:
        next_proc = None
      dt = time.perf_counter()
      device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
      loss, top_1 = loss.numpy().item(), top_1.numpy().item()
      top_1_acc = top_1 / sum(yi != -1 for yi in y)
      cl = time.perf_counter()
      if BENCHMARK:
        step_times.append(cl - st)
      tqdm.write(
        f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
        f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {top_1_acc:3.2f} acc, {optimizer.lr.numpy()[0]:.6f} LR, "
        f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
      if WANDB:
        wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/top_1_acc": top_1_acc, "train/step_time": cl - st,
                   "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
                   "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": e + (i + 1) / steps_in_train_epoch})
      st = cl
      prev_cookies.append(proc)
      proc, next_proc = next_proc, None  # return old cookie
      i += 1
      if i == BENCHMARK:
        assert not math.isnan(loss)
        median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
        print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
        print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
              f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
        # if we are doing beam search, run the first eval too
        if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
        return
    if MLLOGGER and RUNMLPERF:
      MLLOGGER.event(key=mllog_constants.EPOCH_STOP, value=e+1, metadata=dict(epoch_num=e+1))

    # ** eval loop **
    # always eval for epoch >= 33 to stop the clock as soon as the eval target is hit; convergence typically lands in epochs 33-37
    if steps_in_val_epoch > 0 and ((e + 1 - eval_start_epoch) % eval_freq == 0 or e + 1 >= 33):
      if MLLOGGER and RUNMLPERF:
        MLLOGGER.start(key=mllog_constants.EVAL_START, value=e+1, metadata=dict(epoch_num=e+1))
      if getenv("RESET_STEP", 1): train_step.reset()  # free the train step memory :(
      eval_times = []
      eval_loss = 0.0
      eval_top_1 = 0
      eval_num_samples = 0
      Tensor.training = False
      BEAM.value = EVAL_BEAM
      if INITMLPERF:
        i, proc = 0, fake_data_get(EVAL_BS)
      else:
        it = iter(tqdm(batch_load_resnet(batch_size=EVAL_BS, val=True, shuffle=False, pad_first_batch=True), total=steps_in_val_epoch))
        i, proc = 0, data_get(it)
      prev_cookies = []
      while proc is not None:
        GlobalCounters.reset()
        st = time.time()
        (loss, top_1), y, proc = eval_step(proc[0], proc[1]), proc[2], proc[3]  # drop inputs, keep cookie
        if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
        try:
          if INITMLPERF:
            next_proc = fake_data_get(EVAL_BS)
          else:
            next_proc = data_get(it)
        except StopIteration:
          next_proc = None
        loss, top_1 = loss.numpy().item(), top_1.numpy().item()
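        # the dataloader pads incomplete batches with label -1; padded entries are excluded from the sample count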
        num_samples = sum(yi != -1 for yi in y)
        eval_loss += loss * num_samples
        eval_top_1 += top_1
        eval_num_samples += num_samples
        prev_cookies.append(proc)
        proc, next_proc = next_proc, None
        i += 1
        if i == BENCHMARK:
          # assume INITMLPERF has BENCHMARK set
          if MLLOGGER and INITMLPERF:
            MLLOGGER.event(key=mllog_constants.INIT_STOP)
          return
        et = time.time()
        eval_times.append(et - st)

      if getenv("RESET_STEP", 1): eval_step.reset()
      if not BENCHMARK:
        assert eval_num_samples == len(get_val_files()), f"eval sample count mismatched. {eval_num_samples=} != {len(get_val_files())}"
      total_loss = eval_loss / eval_num_samples
      total_top_1 = eval_top_1 / eval_num_samples
      total_fw_time = sum(eval_times) / len(eval_times)
      tqdm.write(f"eval loss: {total_loss:.2f}, eval time: {total_fw_time:.2f}, eval top 1 acc: {total_top_1:.3f}")
      if WANDB:
        wandb.log({"eval/loss": total_loss, "eval/top_1_acc": total_top_1, "eval/forward_time": total_fw_time, "epoch": e + 1})
      if MLLOGGER and RUNMLPERF:
        MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=total_top_1, metadata=dict(epoch_num=e+1))
        MLLOGGER.event(key=mllog_constants.EVAL_STOP, value=e+1, metadata=dict(epoch_num=e+1))
      # save model if achieved target
      if not achieved and total_top_1 >= target:
        # stop once the target is achieved
        if MLLOGGER and RUNMLPERF:
          MLLOGGER.event(key=mllog_constants.RUN_STOP, metadata=dict(status=mllog_constants.SUCCESS))
        if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
        fn = f"./ckpts/resnet50_{seed}.safe"
        safe_save(get_state_dict(model), fn)
        print(f" *** Model saved to {fn} ***")
        achieved = True
        break

      # checkpoint every time we eval
      if getenv("CKPT"):
        if not os.path.exists("./ckpts"): os.mkdir("./ckpts")
        if WANDB and wandb.run is not None:
          fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{e}.safe"
        else:
          fn = f"./ckpts/{time.strftime('%Y%m%d_%H%M%S')}_e{e}.safe"
        print(f"saving ckpt to {fn}")
        safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)

def train_retinanet():
  # TODO: Retinanet
  pass

def train_unet3d():
  # TODO: Unet3d
  pass

def train_rnnt():
  # TODO: RNN-T
  pass

@TinyJit
def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
  optimizer.zero_grad()
  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
  loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
  (loss * loss_scaler).backward()
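  # unscale the gradients, then clip by global L2 norm: if the norm exceeds 1.0, every gradient is divided by it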
  global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device).realize()
  for p in optimizer.params:
    p.grad = p.grad / loss_scaler
    global_norm += p.grad.float().square().sum()
  global_norm = global_norm.sqrt()
  for p in optimizer.params: p.grad = (p.grad / Tensor.where(global_norm > 1.0, global_norm, 1.0)).cast(p.grad.dtype)
  optimizer.step()
  scheduler.step()
  return loss.realize()

@TinyJit
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
  masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
  return {
    "masked_lm_accuracy": masked_lm_accuracy.realize(),
    "next_sentence_accuracy": seq_relationship_accuracy.realize(),
    "masked_lm_loss": masked_lm_loss.realize(),
    "next_sentence_loss": next_sentence_loss.realize()
  }

def train_bert():
  # NOTE: pip install tensorflow, wandb required
  from examples.mlperf.dataloader import batch_load_train_bert, batch_load_val_bert
  from examples.mlperf.helpers import get_mlperf_bert_model, get_data_bert, get_fake_data_bert
  from examples.mlperf.lr_schedulers import PolynomialDecayWithWarmup

  config = {}
  BASEDIR = getenv("BASEDIR", Path(__file__).parent.parents[1] / "extra" / "datasets" / "wiki")

  GPUS = config["GPUS"] = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
  print(f"training on {GPUS}")
  for x in GPUS: Device[x]
  seed = config["seed"] = getenv("SEED", 12345)

  INITMLPERF = getenv("INITMLPERF")
  RUNMLPERF = getenv("RUNMLPERF")
  BENCHMARK = getenv("BENCHMARK")  # read here (rather than after wandb init) so the INITMLPERF assert below can see it
  if getenv("LOGMLPERF"):
    from mlperf_logging import mllog
    import mlperf_logging.mllog.constants as mllog_constants
    mllog.config(filename="bert.log")
    mllog.config(root_dir=Path(__file__).parents[3].as_posix())
    MLLOGGER = mllog.get_mllogger()
    MLLOGGER.logger.propagate = False
    if INITMLPERF:
      assert BENCHMARK, "BENCHMARK must be set for INITMLPERF"
      MLLOGGER.event(key=mllog_constants.SUBMISSION_ORG, value="tinycorp")
      MLLOGGER.event(key=mllog_constants.SUBMISSION_PLATFORM, value=getenv("SUBMISSION_PLATFORM", "tinybox"))
      MLLOGGER.event(key=mllog_constants.SUBMISSION_DIVISION, value=mllog_constants.CLOSED)
      MLLOGGER.event(key=mllog_constants.SUBMISSION_STATUS, value=mllog_constants.ONPREM)
      MLLOGGER.event(key=mllog_constants.SUBMISSION_BENCHMARK, value=mllog_constants.BERT)
      diskcache_clear()
      MLLOGGER.event(key=mllog_constants.CACHE_CLEAR, value=True)
      MLLOGGER.start(key=mllog_constants.INIT_START, value=None)
    if RUNMLPERF:
      MLLOGGER.start(key=mllog_constants.RUN_START, value=None)
  else:
    MLLOGGER = None

  # ** hyperparameters **
  BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 16 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
  EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
  max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.00035 * math.sqrt(BS/256))
  train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 4800000 // BS)
  warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1)
  max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS)  # EVAL_BS * MAX_EVAL_STEPS >= 10000
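  # default eval cadence follows the MLPerf BERT convention: roughly every 5% of the prescribed training
  # samples (230.23 * BS + 3e6), rounded down to a multiple of 25000 samples, then converted to steps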
  eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS))  # Round down
  save_ckpt_freq = config["SAVE_CKPT_FREQ"] = getenv("SAVE_CKPT_FREQ", 1000)
  keep_ckpt_amount = config["KEEP_CKPT_AMOUNT"] = getenv("KEEP_CKPT_AMOUNT", 5)
  save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
  init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)
  loss_scaler = config["LOSS_SCALER"] = getenv("LOSS_SCALER", 2.0**9 if dtypes.default_float == dtypes.float16 else 1.0)
  decay = config["DECAY"] = getenv("DECAY", 0.01)
  epsilon = config["EPSILON"] = getenv("EPSILON", 1e-6)
  poly_power = config["POLY_POWER"] = getenv("POLY_POWER", 1.0)
  target, achieved = getenv("TARGET", 0.72), False
  config["DEFAULT_FLOAT"] = dtypes.default_float.name
  config["DISABLE_DROPOUT"] = getenv("DISABLE_DROPOUT", 0)
  config["TRAIN_BEAM"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
  config["EVAL_BEAM"] = EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)

  Tensor.manual_seed(seed)  # seed for weight initialization
  model = get_mlperf_bert_model(init_ckpt)
  for _, x in get_state_dict(model).items():
    x.realize().to_(GPUS)
  parameters = get_parameters(model)
  assert 10000 <= (EVAL_BS * max_eval_steps), "EVAL_BS * MAX_EVAL_STEPS must be greater than or equal to 10000 to iterate over the full eval dataset"

  # ** Log run config **
  for key, value in config.items(): print(f'HParam: "{key}": {value}')

  # ** Optimizer **
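  # biases and LayerNorm parameters are excluded from weight decay; they get their own LAMB instance with weight_decay=0.0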
  parameters_no_wd = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
  parameters = [x for x in parameters if x not in set(parameters_no_wd)]
  optimizer_wd = LAMB(parameters, lr=max_lr, eps=epsilon, weight_decay=decay, adam=False)
  optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=epsilon, weight_decay=0.0, adam=False)
  optimizer_group = OptimizerGroup(optimizer_wd, optimizer_no_wd)

  # ** LR scheduler **
  scheduler_wd = PolynomialDecayWithWarmup(optimizer_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
  scheduler_no_wd = PolynomialDecayWithWarmup(optimizer_no_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
  scheduler_group = LRSchedulerGroup(scheduler_wd, scheduler_no_wd)
  print(f"training with batch size {BS} for one epoch with {train_steps} steps")

  # log mlperf hparams
  if MLLOGGER:
    if RUNMLPERF:
      MLLOGGER.event(key=mllog_constants.GLOBAL_BATCH_SIZE, value=config["GLOBAL_BATCH_SIZE"])
      MLLOGGER.event(key=mllog_constants.MAX_SEQUENCE_LENGTH, value=512)
      MLLOGGER.event(key="max_predictions_per_seq", value=76)
      MLLOGGER.event(key=mllog_constants.OPT_NAME, value="LAMB")
      MLLOGGER.event(key=mllog_constants.OPT_BASE_LR, value=config["OPT_BASE_LEARNING_RATE"])
      MLLOGGER.event(key=mllog_constants.OPT_LAMB_WEIGHT_DECAY, value=config["DECAY"])
      MLLOGGER.event(key=mllog_constants.OPT_LAMB_BETA_1, value=optimizer_wd.b1)
      MLLOGGER.event(key=mllog_constants.OPT_LAMB_BETA_2, value=optimizer_wd.b2)
      MLLOGGER.event(key=mllog_constants.OPT_LAMB_LR_DECAY_POLY_POWER, value=config["POLY_POWER"])
      MLLOGGER.event(key=mllog_constants.OPT_LAMB_EPSILON, value=config["EPSILON"])
      MLLOGGER.event(key=mllog_constants.OPT_LR_WARMUP_STEPS, value=config["NUM_WARMUP_STEPS"])
      MLLOGGER.event(key=mllog_constants.NUM_WARMUP_STEPS, value=config["NUM_WARMUP_STEPS"])
      MLLOGGER.event(key='start_warmup_step', value=0)
      MLLOGGER.event(key='opt_learning_rate_training_steps', value=config["TRAIN_STEPS"])
      MLLOGGER.event(key=mllog_constants.GRADIENT_ACCUMULATION_STEPS, value=1)
      MLLOGGER.event(key=mllog_constants.EVAL_SAMPLES, value=config["EVAL_BS"] * config["MAX_EVAL_STEPS"])
      MLLOGGER.event(key=mllog_constants.TRAIN_SAMPLES, value=config["GLOBAL_BATCH_SIZE"] * config["TRAIN_STEPS"])

  # ** resume from checkpointing **
  start_step = 1
  previous_step = None
  if ckpt:=getenv("RESUME", ""):
    load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
    start_step = int(scheduler_wd.epoch_counter.numpy().item())
    print(f"resuming from {ckpt} at step {start_step}")

  # ** init wandb **
  WANDB = getenv("WANDB")
  if WANDB:
    import wandb
    wandb_args = {"id": wandb_id, "resume": "must"} if (wandb_id := getenv("WANDB_RESUME", "")) else {}
    wandb.init(config=config, **wandb_args, project="MLPerf-BERT")

  if not INITMLPERF:
    eval_it = iter(batch_load_val_bert(EVAL_BS))
    train_it = iter(tqdm(batch_load_train_bert(BS, start_step), initial=start_step, total=train_steps, disable=BENCHMARK))
  step_times = []

  # ** train loop **
  wc_start = time.perf_counter()
  if INITMLPERF:
    i, train_data = start_step, get_fake_data_bert(GPUS, BS)
  else:
    i, train_data = start_step, get_data_bert(GPUS, train_it)
  while train_data is not None and i < train_steps and not achieved:
    Tensor.training = True
    BEAM.value = TRAIN_BEAM
    st = time.perf_counter()
    GlobalCounters.reset()
    loss = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
      train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"],
      train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])
    pt = time.perf_counter()
    try:
      if INITMLPERF:
        next_data = get_fake_data_bert(GPUS, BS)
      else:
        next_data = get_data_bert(GPUS, train_it)
    except StopIteration:
      next_data = None
    dt = time.perf_counter()
    device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
    loss = loss.numpy().item()
    cl = time.perf_counter()
    if BENCHMARK: step_times.append(cl - st)
    tqdm.write(
      f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
      f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
    if WANDB:
      wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
                 "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
                 "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st)})
    train_data, next_data = next_data, None
    i += 1
    if i == BENCHMARK:
      median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
      estimated_total_minutes = int(median_step_time * train_steps / 60)
      print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
      print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
            f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")

    # ** eval loop **
    if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK):
      if MLLOGGER and RUNMLPERF:
        MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": 1, "epoch_count": 1, "step_num": i})
      train_step_bert.reset()
      eval_lm_losses = []
      eval_clsf_losses = []
      eval_lm_accs = []
      eval_clsf_accs = []
      eval_times = []
      Tensor.training = False
      BEAM.value = EVAL_BEAM
      for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
        if INITMLPERF:
          eval_data = get_fake_data_bert(GPUS, EVAL_BS)
        else:
          eval_data = get_data_bert(GPUS, eval_it)
        GlobalCounters.reset()
        st = time.time()
        eval_result: dict[str, Tensor] = eval_step_bert(model,
          eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], eval_data["masked_lm_positions"],
          eval_data["masked_lm_ids"], eval_data["masked_lm_weights"], eval_data["next_sentence_labels"])
        lm_loss, clsf_loss = eval_result["masked_lm_loss"].item(), eval_result["next_sentence_loss"].item()
        lm_acc, clsf_acc = eval_result["masked_lm_accuracy"].item(), eval_result["next_sentence_accuracy"].item()
        eval_lm_losses.append(lm_loss)
        eval_clsf_losses.append(clsf_loss)
        eval_lm_accs.append(lm_acc)
        eval_clsf_accs.append(clsf_acc)
        et = time.time()
        eval_times.append(et - st)
        if BENCHMARK and j == BENCHMARK:
          # assume INITMLPERF has BENCHMARK set
          if MLLOGGER and INITMLPERF:
            MLLOGGER.event(key=mllog_constants.INIT_STOP, value=None)
          return
      eval_step_bert.reset()
      avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses)
      avg_clsf_loss = sum(eval_clsf_losses) / len(eval_clsf_losses)
      avg_lm_acc = sum(eval_lm_accs) / len(eval_lm_accs)
      avg_clsf_acc = sum(eval_clsf_accs) / len(eval_clsf_accs)
      avg_fw_time = sum(eval_times) / len(eval_times)
      results = f"eval lm loss: {avg_lm_loss:.2f}, eval clsf loss: {avg_clsf_loss:.2f}, eval lm accuracy: {avg_lm_acc:.6f}, \
        eval clsf accuracy: {avg_clsf_acc:.2f}, avg eval step time: {avg_fw_time:.2f}"
      tqdm.write(results)
      if WANDB:
        wandb.log({"eval/lm_loss": avg_lm_loss, "eval/clsf_loss": avg_clsf_loss, "eval/lm_accuracy": avg_lm_acc,
                   "eval/clsf_accuracy": avg_clsf_acc, "eval/forward_time": avg_fw_time})
      if MLLOGGER and RUNMLPERF:
        MLLOGGER.end(key=mllog_constants.EVAL_STOP, value=i, metadata={"epoch_count": 1, "step_num": i, "samples_count": config["EVAL_BS"] * config["MAX_EVAL_STEPS"]})
        MLLOGGER.event(key=mllog_constants.EVAL_ACCURACY, value=avg_lm_acc, metadata={"epoch_num": 1, "masked_lm_accuracy": avg_lm_acc})

      # save model if achieved target
      if not achieved and avg_lm_acc >= target:
        wc_end = time.perf_counter()
        if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
        fn = f"{ckpt_dir}/bert-large.safe"
        safe_save(get_state_dict(model), fn)
        print(f" *** Model saved to {fn} ***")
        total_seconds = wc_end - wc_start
        hours = int(total_seconds // 3600)
        minutes = int((total_seconds % 3600) // 60)
        seconds = total_seconds % 60
        print(f"Reference Convergence point reached after {i * BS} datasamples and {hours}h{minutes}m{seconds:.2f}s.")
        achieved = True
        if MLLOGGER and RUNMLPERF:
          MLLOGGER.end(key=mllog_constants.RUN_STOP, metadata=dict(status=mllog_constants.SUCCESS))
        # stop once hitting the target
        break

    if getenv("CKPT", 1) and i % save_ckpt_freq == 0:
      if MLLOGGER and RUNMLPERF:
        if previous_step:
          MLLOGGER.end(key=mllog_constants.BLOCK_STOP, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "first_step_num": i, "step_num": i, "step_count": i - previous_step})
        MLLOGGER.start(key="checkpoint_start", value=None, metadata={"step_num" : i})
      if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
      if WANDB and wandb.run is not None:
        fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"
      else:
        fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}.safe"
      print(f"saving ckpt to {fn}")
      safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
      ckpt_files = [f for f in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, f))]
      ckpt_files.sort(key=lambda x: os.path.getmtime(os.path.join(ckpt_dir, x)))
      while len(ckpt_files) > keep_ckpt_amount:
        last = ckpt_files.pop(0)
        print(f"Removing old ckpt {last}")
        os.remove(os.path.join(ckpt_dir, last))
      if MLLOGGER and RUNMLPERF:
        MLLOGGER.end(key="checkpoint_stop", value=None, metadata={"step_num": i})
        MLLOGGER.start(key=mllog_constants.BLOCK_START, value=None, metadata={"first_epoch_num": 1, "epoch_num": 1, "epoch_count": 1, "samples_count": config["EVAL_BS"] * config["MAX_EVAL_STEPS"], "step_num": i, "first_step_num": i+1})
        previous_step = i

def train_maskrcnn():
  # TODO: Mask RCNN
  pass

if __name__ == "__main__":
  multiprocessing.set_start_method('spawn')
  with Tensor.train():
    for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
      nm = f"train_{m}"
      if nm in globals():
        print(f"training {m}")
        globals()[nm]()