pretrain_valuenet.py

from tinygrad.codegen.kernel import Kernel
from tqdm import tqdm, trange
import math
import random
from tinygrad.tensor import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
# stuff needed to unpack a kernel: the dataset lines are eval()'d below, so every
# name that can appear in a serialized ast must be in scope
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable
inf, nan = float('inf'), float('nan')  # serialized runtimes can contain these literals
from tinygrad.codegen.kernel import Opt, OptOps
from extra.optimization.helpers import lin_to_feats, MAX_DIMS
# NOTE: this is not the real value of the state, it's just a prediction of the runtime
INNER = 512
class ValueNet:
  def __init__(self, feats=240, out=1):
    self.l1 = Linear(feats,INNER)
    self.l2 = Linear(INNER,INNER)
    self.l3 = Linear(INNER,INNER)
    self.l4 = Linear(INNER,out)
  def __call__(self, x):
    x = self.l1(x).relu()
    x = self.l2(x).relu()
    x = self.l3(x).relu().dropout(0.8)
    return self.l4(x)
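
# A quick shape sanity check (illustrative sketch, commented out so the script's
# behavior is unchanged): a batch of 240-dim kernel feature vectors maps to one
# scalar runtime prediction each.
#   net = ValueNet()
#   print(net(Tensor.randn(4, 240)).shape)  # -> (4, 1)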
if __name__ == "__main__":
  net = ValueNet()
  optim = Adam(get_parameters(net))

  TEST_SIZE = 256
  dset = open("/tmp/logtm").read().strip().split("\n")
  random.seed(1337)
  random.shuffle(dset)

  X,Y = [], []
  for x in tqdm(dset):
    # each line is the repr of an (ast, opts, tms) tuple
    ast, opts, tms = eval(x)
    lin = Kernel(ast)
    for o in opts: lin.apply_opt(o)
    # skip kernels that can't be featurized or that never ran successfully
    if lin.shape_len >= MAX_DIMS: continue
    if min(tms) == float('inf'): continue
    X.append(lin_to_feats(lin))
    Y.append([math.log(min(tms))])  # target is the log of the best measured runtime
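  # Since each target is log(min(tms)), the net predicts log-seconds; a runtime
  # estimate for a featurized kernel would be recovered as (sketch):
  #   math.exp(net(Tensor([lin_to_feats(lin)])).numpy().item())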
  print(f"got {len(X)} samples")

  # hold out the last TEST_SIZE samples (the dataset was already shuffled)
  X_test,Y_test = Tensor(X[-TEST_SIZE:]), Tensor(Y[-TEST_SIZE:])
  X,Y = X[:-TEST_SIZE], Y[:-TEST_SIZE]
  def get_minibatch(X,Y,bs):
    # sample bs training pairs uniformly with replacement
    xs, ys = [], []
    for _ in range(bs):
      sel = random.randint(0, len(X)-1)
      xs.append(X[sel])
      ys.append(Y[sel])
    return Tensor(xs), Tensor(ys)
  Tensor.no_grad, Tensor.training = False, True  # training mode so dropout is active
  losses = []
  test_losses = []
  test_loss = float('inf')
  for i in (t:=trange(2000)):
    x,y = get_minibatch(X,Y,bs=256)
    out = net(x)
    loss = (out-y).square().mean()  # MSE in log-runtime space
    optim.zero_grad()
    loss.backward()
    optim.step()
    t.set_description(f"loss {loss.numpy():7.2f}, test loss {test_loss:7.2f}")
    losses.append(loss.numpy().item())
    test_losses.append(test_loss)
    if i % 10 == 0:
      # refresh the held-out loss every 10 steps, with dropout disabled for eval
      Tensor.training = False
      test_loss = (net(X_test)-Y_test).square().mean().numpy().item()
      Tensor.training = True

  safe_save(get_state_dict(net), "/tmp/valuenet.safetensors")
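  # Reloading the trained net later would look like (sketch; safe_load and
  # load_state_dict are already imported above):
  #   net2 = ValueNet()
  #   load_state_dict(net2, safe_load("/tmp/valuenet.safetensors"))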

  import matplotlib.pyplot as plt
  plt.plot(losses[200:])  # skip the first 200 steps so the initial spike doesn't dominate
  plt.plot(test_losses[200:])
  plt.show()