# extract_policynet.py — train a policy network on beam-search cache data
  1. import os, sys, sqlite3, pickle, random
  2. from tqdm import tqdm, trange
  3. from copy import deepcopy
  4. from tinygrad.nn import Linear
  5. from tinygrad.tensor import Tensor
  6. from tinygrad.nn.optim import Adam
  7. from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
  8. from tinygrad.engine.search import actions
  9. from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, assert_same_lin
  10. from tinygrad.codegen.kernel import Kernel
  11. from tinygrad.helpers import getenv
  12. # stuff needed to unpack a kernel
  13. from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
  14. from tinygrad.dtype import dtypes
  15. from tinygrad.shape.shapetracker import ShapeTracker
  16. from tinygrad.shape.view import View
  17. from tinygrad.shape.symbolic import Variable
  18. inf, nan = float('inf'), float('nan')
  19. from tinygrad.codegen.kernel import Opt, OptOps
  20. INNER = 256
  21. class PolicyNet:
  22. def __init__(self):
  23. self.l1 = Linear(1021,INNER)
  24. self.l2 = Linear(INNER,INNER)
  25. self.l3 = Linear(INNER,1+len(actions))
  26. def __call__(self, x):
  27. x = self.l1(x).relu()
  28. x = self.l2(x).relu().dropout(0.9)
  29. return self.l3(x).log_softmax()
  30. def dataset_from_cache(fn):
  31. conn = sqlite3.connect(fn)
  32. cur = conn.cursor()
  33. cur.execute("SELECT * FROM beam_search")
  34. X,A = [], []
  35. for f in tqdm(cur.fetchall()):
  36. Xs,As = [], []
  37. try:
  38. lin = Kernel(eval(f[0]))
  39. opts = pickle.loads(f[-1])
  40. for o in opts:
  41. Xs.append(lin_to_feats(lin, use_sts=True))
  42. As.append(actions.index(o))
  43. lin.apply_opt(o)
  44. Xs.append(lin_to_feats(lin, use_sts=True))
  45. As.append(0)
  46. except Exception:
  47. pass
  48. X += Xs
  49. A += As
  50. return X,A
  51. if __name__ == "__main__":
  52. if getenv("REGEN"):
  53. X,V = dataset_from_cache(sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache")
  54. safe_save({"X": Tensor(X), "V": Tensor(V)}, "/tmp/dataset_policy")
  55. else:
  56. ld = safe_load("/tmp/dataset_policy")
  57. X,V = ld['X'].numpy(), ld['V'].numpy()
  58. print(X.shape, V.shape)
  59. order = list(range(X.shape[0]))
  60. random.shuffle(order)
  61. X, V = X[order], V[order]
  62. ratio = -256
  63. X_test, V_test = Tensor(X[ratio:]), Tensor(V[ratio:])
  64. X,V = X[:ratio], V[:ratio]
  65. print(X.shape, V.shape)
  66. net = PolicyNet()
  67. #if os.path.isfile("/tmp/policynet.safetensors"): load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
  68. optim = Adam(get_parameters(net))
  69. def get_minibatch(X,Y,bs):
  70. xs, ys = [], []
  71. for _ in range(bs):
  72. sel = random.randint(0, len(X)-1)
  73. xs.append(X[sel])
  74. ys.append(Y[sel])
  75. return Tensor(xs), Tensor(ys)
  76. Tensor.no_grad, Tensor.training = False, True
  77. losses = []
  78. test_losses = []
  79. test_accuracy = 0
  80. test_loss = float('inf')
  81. for i in (t:=trange(500)):
  82. x,y = get_minibatch(X,V,bs=256)
  83. out = net(x)
  84. loss = out.sparse_categorical_crossentropy(y)
  85. optim.zero_grad()
  86. loss.backward()
  87. optim.step()
  88. cat = out.argmax(axis=-1)
  89. accuracy = (cat == y).mean()
  90. t.set_description(f"loss {loss.numpy():7.2f} accuracy {accuracy.numpy()*100:7.2f}%, test loss {test_loss:7.2f} test accuracy {test_accuracy*100:7.2f}%")
  91. losses.append(loss.numpy().item())
  92. test_losses.append(test_loss)
  93. if i % 10:
  94. out = net(X_test)
  95. test_loss = out.sparse_categorical_crossentropy(V_test).square().mean().numpy().item()
  96. cat = out.argmax(axis=-1)
  97. test_accuracy = (cat == y).mean().numpy()
  98. safe_save(get_state_dict(net), "/tmp/policynet.safetensors")
  99. import matplotlib.pyplot as plt
  100. plt.plot(losses[10:])
  101. plt.plot(test_losses[10:])
  102. plt.show()