#!/usr/bin/env python3
# tinygrad implementation of https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
# https://myrtle.ai/learn/how-to-train-your-resnet-8-bag-of-tricks/
# https://siboehm.com/articles/22/CUDA-MMM
import random, time
import numpy as np
from typing import Optional
from extra.datasets import fetch_cifar, cifar_mean, cifar_std
from extra.lr_scheduler import OneCycleLR
from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
from tinygrad.nn.state import get_state_dict, get_parameters
from tinygrad.nn import optim
from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
from tinygrad.multi import MultiLazyBuffer

BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
EVAL_BS = getenv("EVAL_BS", BS)
GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
assert EVAL_BS % len(GPUS) == 0, f"{EVAL_BS=} is not a multiple of {len(GPUS)=}, uneven multi GPU is slow"
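
# Env-var knobs used throughout this script (all read via getenv): BS, EVAL_BS, STEPS, GPUS, SYNCBN, RANDOM_CROP,
# RANDOM_FLIP, CUTMIX, EMA, EVAL_STEPS, DISABLE_BACKWARD, LATEBEAM, LATEWINO, SEED, TARGET_EVAL_ACC_PCT.

# UnsyncedBatchNorm keeps a separate set of batchnorm statistics per device, so multi-GPU training never has to
# all-reduce the per-layer mean/var; set SYNCBN=1 to use the regular synced nn.BatchNorm2d instead (see BatchNorm below).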
class UnsyncedBatchNorm:
  def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1, num_devices=len(GPUS)):
    self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
    self.num_devices = num_devices

    if affine: self.weight, self.bias = Tensor.ones(sz, dtype=dtypes.float32), Tensor.zeros(sz, dtype=dtypes.float32)
    else: self.weight, self.bias = None, None

    self.running_mean = Tensor.zeros(num_devices, sz, dtype=dtypes.float32, requires_grad=False)
    self.running_var = Tensor.ones(num_devices, sz, dtype=dtypes.float32, requires_grad=False)
    self.num_batches_tracked = Tensor.zeros(1, dtype=dtypes.int, requires_grad=False)

  def __call__(self, x:Tensor):
    if isinstance(x.lazydata, MultiLazyBuffer): assert x.lazydata.axis is None or x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices

    xr = x.reshape(self.num_devices, -1, *x.shape[1:]).cast(dtypes.float32)
    batch_mean, batch_invstd = self.calc_stats(xr)
    ret = xr.batchnorm(
      self.weight.reshape(1, -1).expand((self.num_devices, -1)),
      self.bias.reshape(1, -1).expand((self.num_devices, -1)),
      batch_mean, batch_invstd, axis=(0, 2))
    return ret.reshape(x.shape).cast(x.dtype)

  def calc_stats(self, x:Tensor):
    if Tensor.training:
      # This requires two full memory accesses to x
      # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
      # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
      batch_mean = x.mean(axis=(1,3,4))
      y = (x - batch_mean.detach().reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1]))  # d(var)/d(mean) = 0
      batch_var = (y*y).mean(axis=(1,3,4))
      batch_invstd = batch_var.add(self.eps).pow(-0.5)

      # NOTE: wow, this is done all throughout training in most PyTorch models
      if self.track_running_stats:
        self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach().cast(self.running_mean.dtype))
        batch_var_adjust = prod(y.shape[1:])/(prod(y.shape[1:])-y.shape[2])
        self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * batch_var_adjust * batch_var.detach().cast(self.running_var.dtype))
        self.num_batches_tracked += 1
    else:
      batch_mean = self.running_mean
      # NOTE: this can be precomputed for static inference. we expand it here so it fuses
      batch_invstd = self.running_var.reshape(self.running_var.shape[0], 1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt()
    return batch_mean, batch_invstd
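
# The batchnorm scale (weight) is frozen and only the bias is trained; eps is tiny and the momentum high.
# track_running_stats=False because the eval loop below keeps Tensor.training enabled (see the note there),
# so batch statistics are always used and the running buffers are never consulted.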
class BatchNorm(nn.BatchNorm2d if getenv("SYNCBN") else UnsyncedBatchNorm):
  def __init__(self, num_features):
    super().__init__(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
    self.weight.requires_grad = False
    self.bias.requires_grad = True
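
# ConvGroup: conv -> 2x2 max pool -> batchnorm -> quick_gelu, then a second conv/norm/gelu with a residual add.
# Activations are upcast to float32 just for the batchnorm and cast back to the default dtype afterwards,
# presumably to keep the normalization statistics stable when training in half precision.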
class ConvGroup:
  def __init__(self, channels_in, channels_out):
    self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
    self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)

    self.norm1 = BatchNorm(channels_out)
    self.norm2 = BatchNorm(channels_out)

  def __call__(self, x):
    x = self.conv1(x)
    x = x.max_pool2d(2)
    x = x.float()
    x = self.norm1(x)
    x = x.cast(dtypes.default_float)
    x = x.quick_gelu()
    residual = x

    x = self.conv2(x)
    x = x.float()
    x = self.norm2(x)
    x = x.cast(dtypes.default_float)
    x = x.quick_gelu()

    return x + residual
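
# SpeedyResNet: a fixed (non-trainable) whitening conv producing 12 channels, a 1x1 conv up to 32 channels,
# three ConvGroups (32->64->256->512), a global max pool over HxW, and a bias-free linear head whose logits
# are divided by 9. At eval time the forward pass is averaged with a horizontally flipped copy of the input
# as cheap test-time augmentation.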
class SpeedyResNet:
  def __init__(self, W):
    self.whitening = W
    self.net = [
      nn.Conv2d(12, 32, kernel_size=1, bias=False),
      lambda x: x.quick_gelu(),
      ConvGroup(32, 64),
      ConvGroup(64, 256),
      ConvGroup(256, 512),
      lambda x: x.max((2,3)),
      nn.Linear(512, 10, bias=False),
      lambda x: x / 9.,
    ]

  def __call__(self, x, training=True):
    # pad to 32x32 because whitening conv creates 31x31 images that are awfully slow to compute with
    # TODO: remove the pad but instead let the kernel optimize itself
    forward = lambda x: x.conv2d(self.whitening).pad2d((1,0,0,1)).sequential(self.net)
    return forward(x) if training else (forward(x) + forward(x[..., ::-1])) / 2.

# hyper-parameters are exactly the same as in the original repo
bias_scaler = 58
hyp = {
  'seed' : 209,
  'opt': {
    'bias_lr':            1.76 * bias_scaler/512,
    'non_bias_lr':        1.76 / 512,
    'bias_decay':         1.08 * 6.45e-4 * BS/bias_scaler,
    'non_bias_decay':     1.08 * 6.45e-4 * BS,
    'final_lr_ratio':     0.025,
    'initial_div_factor': 1e6,
    'label_smoothing':    0.20,
    'momentum':           0.85,
    'percent_start':      0.23,
    'loss_scale_scaler':  1./128,   # (range: ~1/512 - 16+, 1/128 w/ FP16)
  },
  'net': {
    'kernel_size': 2,               # kernel size for the whitening layer
    'cutmix_size': 3,
    'cutmix_steps': 499,
    'pad_amount': 2,
  },
  'ema': {
    'steps': 399,
    'decay_base': .95,
    'decay_pow': 1.6,
    'every_n_steps': 5,
  },
}
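
# NOTE: the learning rates are expressed per 512-sample reference batch (hence the /512) while weight decay scales
# with BS, presumably so the effective optimizer settings stay comparable when the batch size is changed.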

def train_cifar():
  def set_seed(seed):
    Tensor.manual_seed(seed)
    random.seed(seed)

  # ========== Model ==========
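  # The whitening layer is a fixed 2x2 conv with 12 filters (3 channels * 2 * 2), computed from the training data:
  # take every 2x2 patch, eigendecompose the patch covariance, and scale each eigenvector by 1/sqrt(eigenvalue + 1e-2),
  # so the first trainable conv sees roughly decorrelated, unit-variance inputs (PCA whitening).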
  def whitening(X, kernel_size=hyp['net']['kernel_size']):
    def _cov(X):
      return (X.T @ X) / (X.shape[0] - 1)

    def _patches(data, patch_size=(kernel_size,kernel_size)):
      h, w = patch_size
      c = data.shape[1]
      axis = (2, 3)
      return np.lib.stride_tricks.sliding_window_view(data, window_shape=(h,w), axis=axis).transpose((0,3,2,1,4,5)).reshape((-1,c,h,w))

    def _eigens(patches):
      n,c,h,w = patches.shape
      Σ = _cov(patches.reshape(n, c*h*w))
      Λ, V = np.linalg.eigh(Σ, UPLO='U')
      return np.flip(Λ, 0), np.flip(V.T.reshape(c*h*w, c, h, w), 0)

    # NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually
    Λ, V = _eigens(_patches(X.float().numpy()))
    W = V/np.sqrt(Λ+1e-2)[:,None,None,None]

    return Tensor(W.astype(np.float32), requires_grad=False).cast(dtypes.default_float)

  # ========== Loss ==========
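  # label_smoothing spreads a fraction of the target mass uniformly over the classes: with smoothing 0.2 and
  # 10 classes, a one-hot target becomes 0.8 + 0.02 = 0.82 on the true class and 0.02 everywhere else.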
  def cross_entropy(x:Tensor, y:Tensor, reduction:str='mean', label_smoothing:float=0.0) -> Tensor:
    divisor = y.shape[1]
    assert isinstance(divisor, int), "only int divisor is supported"
    y = (1 - label_smoothing)*y + label_smoothing / divisor
    ret = -x.log_softmax(axis=1).mul(y).sum(axis=1)
    if reduction=='none': return ret
    if reduction=='sum': return ret.sum()
    if reduction=='mean': return ret.mean()
    raise NotImplementedError(reduction)

  # ========== Preprocessing ==========
  # NOTE: this only works for RGB in the format of NxCxHxW and pads the HxW
  def pad_reflect(X, size=2) -> Tensor:
    X = X[...,:,1:size+1].flip(-1).cat(X, X[...,:,-(size+1):-1].flip(-1), dim=-1)
    X = X[...,1:size+1,:].flip(-2).cat(X, X[...,-(size+1):-1,:].flip(-2), dim=-2)
    return X
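
  # NOTE: with pad_amount=2 this reflect-padding grows the 32x32 CIFAR images to 36x36 once up front;
  # random_crop below then takes a fresh 32x32 window from the padded images every epoch.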

  # return a binary mask in the format of BS x C x H x W where H x W contains a random square mask
  def make_square_mask(shape, mask_size) -> Tensor:
    BS, _, H, W = shape
    low_x = Tensor.randint(BS, low=0, high=W-mask_size).reshape(BS,1,1,1)
    low_y = Tensor.randint(BS, low=0, high=H-mask_size).reshape(BS,1,1,1)
    idx_x = Tensor.arange(W, dtype=dtypes.int32).reshape((1,1,1,W))
    idx_y = Tensor.arange(H, dtype=dtypes.int32).reshape((1,1,H,1))
    return (idx_x >= low_x) * (idx_x < (low_x + mask_size)) * (idx_y >= low_y) * (idx_y < (low_y + mask_size))

  def random_crop(X:Tensor, crop_size=32):
    mask = make_square_mask(X.shape, crop_size)
    mask = mask.expand((-1,3,-1,-1))
    X_cropped = Tensor(X.numpy()[mask.numpy()])
    return X_cropped.reshape((-1, 3, crop_size, crop_size))

  def cutmix(X:Tensor, Y:Tensor, mask_size=3):
    # fill the square with randomly selected images from the same batch
    mask = make_square_mask(X.shape, mask_size)
    order = list(range(0, X.shape[0]))
    random.shuffle(order)
    X_patch = Tensor(X.numpy()[order], device=X.device, dtype=X.dtype)
    Y_patch = Tensor(Y.numpy()[order], device=Y.device, dtype=Y.dtype)
    X_cutmix = mask.where(X_patch, X)
    mix_portion = float(mask_size**2)/(X.shape[-2]*X.shape[-1])
    Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
    return X_cutmix, Y_cutmix
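
  # The label is mixed in proportion to the pasted area: a 3x3 patch on a 32x32 image moves 9/1024 ≈ 0.9%
  # of the label mass over to the randomly chosen partner image.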

  # the operations that remain inside the batch fetcher are the ones that involve random operations
  def fetch_batches(X_in:Tensor, Y_in:Tensor, BS:int, is_train:bool):
    step, epoch = 0, 0
    while True:
      st = time.monotonic()
      X, Y = X_in, Y_in
      if is_train:
        # TODO: these are not jitted
        if getenv("RANDOM_CROP", 1):
          X = random_crop(X, crop_size=32)
        if getenv("RANDOM_FLIP", 1):
          X = (Tensor.rand(X.shape[0],1,1,1) < 0.5).where(X.flip(-1), X)  # flip LR
        if getenv("CUTMIX", 1):
          if step >= hyp['net']['cutmix_steps']:
            X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
        order = list(range(0, X.shape[0]))
        random.shuffle(order)
        X, Y = X.numpy()[order], Y.numpy()[order]
      else:
        X, Y = X.numpy(), Y.numpy()
      et = time.monotonic()
      print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({epoch=})")
      for i in range(0, X.shape[0], BS):
        # pad the last batch  # TODO: not correct for test
        batch_end = min(i+BS, Y.shape[0])
        x = Tensor(X[batch_end-BS:batch_end], device=X_in.device, dtype=X_in.dtype)
        y = Tensor(Y[batch_end-BS:batch_end], device=Y_in.device, dtype=Y_in.dtype)
        step += 1
        yield x, y
      epoch += 1
      if not is_train: break
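
  # NOTE: the slice X[batch_end-BS:batch_end] pads a short last batch by re-using the preceding samples, so every
  # yielded batch has exactly BS elements (hence the TODO above: test samples near the end get counted twice in eval).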

  transform = [
    lambda x: x / 255.0,
    lambda x: x.reshape((-1,3,32,32)) - Tensor(cifar_mean, device=x.device, dtype=x.dtype).reshape((1,3,1,1)),
    lambda x: x / Tensor(cifar_std, device=x.device, dtype=x.dtype).reshape((1,3,1,1)),
  ]

  class modelEMA():
    def __init__(self, w, net):
      # self.model_ema = copy.deepcopy(net)  # won't work for opencl due to unpicklable pyopencl._cl.Buffer
      self.net_ema = SpeedyResNet(w)
      for net_ema_param, net_param in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).values()):
        net_ema_param.requires_grad = False
        net_ema_param.assign(net_param.numpy())

    @TinyJit
    def update(self, net, decay):
      # TODO: with Tensor.no_grad()
      Tensor.no_grad = True
      for net_ema_param, (param_name, net_param) in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).items()):
        # batchnorm currently is not being tracked
        if not ("num_batches_tracked" in param_name) and not ("running" in param_name):
          net_ema_param.assign(net_ema_param.detach()*decay + net_param.detach()*(1.-decay)).realize()
      Tensor.no_grad = False
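
  # decay is passed into update() as a Tensor rather than a Python float, so the TinyJit-compiled update can be
  # replayed with a new decay value on every call (scalars would get baked into the jitted kernels on first run).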

  set_seed(getenv('SEED', hyp['seed']))

  X_train, Y_train, X_test, Y_test = fetch_cifar()
  # load data and labels onto the GPU and convert to dtype accordingly
  X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
  Y_train, Y_test = Y_train.to(device=Device.DEFAULT), Y_test.to(device=Device.DEFAULT)
  # one-hot encode labels
  Y_train, Y_test = Y_train.one_hot(10), Y_test.one_hot(10)
  # preprocess data
  X_train, X_test = X_train.sequential(transform), X_test.sequential(transform)
  # precompute whitening patches
  W = whitening(X_train)
  # initialize model weights
  model = SpeedyResNet(W)
  # padding is not timed in the original repo since it can be done all at once
  X_train = pad_reflect(X_train, size=hyp['net']['pad_amount'])
  # convert data and labels to the default dtype
  X_train, Y_train = X_train.cast(dtypes.default_float), Y_train.cast(dtypes.default_float)
  X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float)

  if len(GPUS) > 1:
    for k, x in get_state_dict(model).items():
      if not getenv('SYNCBN') and ('running_mean' in k or 'running_var' in k):
        x.shard_(GPUS, axis=0)
      else:
        x.to_(GPUS)

  # parse the training params into bias and non-bias
  params_dict = get_state_dict(model)
  params_bias = []
  params_non_bias = []
  for params in params_dict:
    if params_dict[params].requires_grad is not False:
      if 'bias' in params:
        params_bias.append(params_dict[params])
      else:
        params_non_bias.append(params_dict[params])

  opt_bias     = optim.SGD(params_bias,     lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
  opt_non_bias = optim.SGD(params_non_bias, lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])

  # NOTE: taken from the hlb-CIFAR10 repository, might need to be tuned
  initial_div_factor = hyp['opt']['initial_div_factor']
  final_lr_ratio = hyp['opt']['final_lr_ratio']
  pct_start = hyp['opt']['percent_start']
  lr_sched_bias     = OneCycleLR(opt_bias,     max_lr=hyp['opt']['bias_lr'],     pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
  lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
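
  # With final_div_factor = 1/(initial_div_factor*final_lr_ratio), the schedule starts at max_lr/initial_div_factor
  # and ends at max_lr*final_lr_ratio, i.e. the final learning rate works out to 2.5% of the peak.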

  def train_step(model, optimizer, lr_scheduler, X, Y):
    out = model(X)
    loss_batchsize_scaler = 512/BS
    loss = cross_entropy(out, Y, reduction='none', label_smoothing=hyp['opt']['label_smoothing']).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])

    if not getenv("DISABLE_BACKWARD"):
      # index 0 for bias and 1 for non-bias
      optimizer.zero_grad()
      loss.backward()

      optimizer.step()
      lr_scheduler[0].step()
      lr_scheduler[1].step()
    return loss.realize()
  train_step_jitted = TinyJit(train_step)
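
  # The per-sample losses are summed (not averaged) and scaled by 512/BS, so the gradient magnitude per step matches
  # a 512-sample reference batch regardless of BS. The mul/div by loss_scale_scaler cancels out of the reported loss;
  # per the comment in hyp it is a loss-scale knob for FP16 runs.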

  def eval_step(model, X, Y):
    out = model(X, training=False)
    loss = cross_entropy(out, Y, reduction='mean')
    correct = out.argmax(axis=1) == Y.argmax(axis=1)
    return correct.realize(), loss.realize()
  eval_step_jitted     = TinyJit(eval_step)
  eval_step_ema_jitted = TinyJit(eval_step)
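
  # Separate TinyJit wrappers for the live model and the EMA model: a jit instance captures the weight buffers of
  # whichever model it first runs with, so the two models each need their own compiled eval_step.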

  # 97 steps in 2 seconds = 20ms / step
  # step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136
  # 4 seconds for tfloat32 ~ 28 TFLOPS, 41% of max 68
  # 6.4 seconds for float32 ~ 17 TFLOPS, 50% of max 34.1
  # 4.7 seconds for float32 w/o channels last. 24 TFLOPS. we get 50ms then i'll be happy. only 64x off

  # https://www.anandtech.com/show/16727/nvidia-announces-geforce-rtx-3080-ti-3070-ti-upgraded-cards-coming-in-june
  # 136 TFLOPS is the theoretical max w float16 on 3080 Ti

  model_ema: Optional[modelEMA] = None
  projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps']
  i = 0
  eval_acc_pct = 0.0
  batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
  with Tensor.train():
    st = time.monotonic()
    while i <= STEPS:
      if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1 and not getenv("DISABLE_BACKWARD"):
        # NOTE: using Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
        corrects = []
        corrects_ema = []
        losses = []
        losses_ema = []
        for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
          if len(GPUS) > 1:
            Xt.shard_(GPUS, axis=0)
            Yt.shard_(GPUS, axis=0)
          correct, loss = eval_step_jitted(model, Xt, Yt)
          losses.append(loss.numpy().tolist())
          corrects.extend(correct.numpy().tolist())
          if model_ema:
            correct_ema, loss_ema = eval_step_ema_jitted(model_ema.net_ema, Xt, Yt)
            losses_ema.append(loss_ema.numpy().tolist())
            corrects_ema.extend(correct_ema.numpy().tolist())

        # collect accuracy across ranks
        correct_sum, correct_len = sum(corrects), len(corrects)
        if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
        eval_acc_pct = correct_sum/correct_len*100.0
        if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
        print(f"eval {correct_sum}/{correct_len} {eval_acc_pct:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
        if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")

      if STEPS == 0 or i == STEPS: break
      GlobalCounters.reset()
      X, Y = next(batcher)
      if len(GPUS) > 1:
        X.shard_(GPUS, axis=0)
        Y.shard_(GPUS, axis=0)
      with Context(BEAM=getenv("LATEBEAM", BEAM.value), WINO=getenv("LATEWINO", WINO.value)):
        loss = train_step_jitted(model, optim.OptimizerGroup(opt_bias, opt_non_bias), [lr_sched_bias, lr_sched_non_bias], X, Y)
      et = time.monotonic()
      loss_cpu = loss.numpy()
      # EMA for network weights
      if getenv("EMA") and i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
        if model_ema is None:
          model_ema = modelEMA(W, model)
        model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
      cl = time.monotonic()
      device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
      #  53  221.74 ms run,    2.22 ms python,  219.52 ms CL,  803.39 loss, 0.000807 LR, 4.66 GB used,   3042.49 GFLOPS,    674.65 GOPS
      print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms {device_str}, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS, {GlobalCounters.global_ops*1e-9:9.2f} GOPS")
      st = cl
      i += 1

  # verify eval acc
  if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
    if eval_acc_pct >= target:
      print(colored(f"{eval_acc_pct=} >= {target}", "green"))
    else:
      raise ValueError(colored(f"{eval_acc_pct=} < {target}", "red"))

if __name__ == "__main__":
  train_cifar()