# extract_sa_pairs.py
# stdlib
import math
import pickle
import random
import sqlite3
import sys
from collections import defaultdict
from contextlib import closing  # used by dataset_from_cache to guarantee the db connection is closed

# third-party
import numpy as np
from tqdm import tqdm, trange

# stuff needed to unpack a kernel (names referenced by the eval()'d ASTs in the cache)
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable
inf, nan = float('inf'), float('nan')  # also referenced by eval()'d ASTs
from tinygrad.codegen.kernel import Opt, OptOps
# more stuff
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import actions
from extra.optimization.helpers import lin_to_feats
from extra.optimization.pretrain_valuenet import ValueNet
from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
  23. def dataset_from_cache(fn):
  24. conn = sqlite3.connect(fn)
  25. cur = conn.cursor()
  26. cur.execute("SELECT * FROM time_linearizer")
  27. grouped = defaultdict(dict)
  28. for f in tqdm(cur.fetchall()): grouped[f[0]][f[1:-1]] = pickle.loads(f[-1])
  29. opts_to_outcome = {}
  30. for ast,sk in grouped.items():
  31. cnts = defaultdict(int)
  32. for sks,tm in sk.items():
  33. if sks[1] != 1: continue
  34. opts = eval(sks[0])
  35. cnts[(len(opts), sks[1])] += 1
  36. opts_to_outcome[(ast, tuple(opts))] = tm
  37. #print(cnts)
  38. S,A,V = [], [], []
  39. for ast,k in tqdm(opts_to_outcome):
  40. if len(k) == 0: continue
  41. old_tm = min(opts_to_outcome[(ast,k[:-1])])
  42. new_tm = min(opts_to_outcome[(ast,k)])
  43. if math.isinf(old_tm) or math.isinf(new_tm) or old_tm < 1e-9 or new_tm < 1e-9: continue
  44. try:
  45. lin = Kernel(eval(ast))
  46. except Exception:
  47. continue
  48. for opt in k[:-1]: lin.apply_opt(opt)
  49. act = k[-1]
  50. log_ratio = math.log(old_tm/new_tm)
  51. #print(f"ratio: {old_tm/new_tm:6.2f}x (log {log_ratio:5.2f}) from {str(act):50s} on {lin.colored_shape()}")
  52. S.append(lin_to_feats(lin, use_sts=True))
  53. A.append(actions.index(act))
  54. V.append([log_ratio]) # NOTE: i have written the bug many times with this having the wrong dim
  55. S, A, V = np.array(S), np.array(A), np.array(V, dtype=np.float32)
  56. X = np.zeros((S.shape[0], S.shape[1]+len(actions)), dtype=np.float32)
  57. X[:, :S.shape[1]] = S
  58. X[range(S.shape[0]), S.shape[1]+A] = 1.0
  59. return X, V
  60. def log_likelihood(x:Tensor, mu:Tensor, log_sigma:Tensor):
  61. #print(x.shape, mu.shape, log_sigma.shape)
  62. #return (x-mu).abs() * (-log_sigma).exp() + log_sigma
  63. return (x-mu).square() * (-2*log_sigma).exp() / 2 + log_sigma
  64. if __name__ == "__main__":
  65. if getenv("REGEN"):
  66. X,V = dataset_from_cache(sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache")
  67. safe_save({"X": Tensor(X), "V": Tensor(V)}, "/tmp/dataset")
  68. else:
  69. ld = safe_load("/tmp/dataset")
  70. X,V = ld['X'].numpy(), ld['V'].numpy()
  71. print(X.shape, V.shape)
  72. order = list(range(X.shape[0]))
  73. random.shuffle(order)
  74. X, V = X[order], V[order]
  75. ratio = -512
  76. X_test, V_test = Tensor(X[ratio:]), Tensor(V[ratio:])
  77. X,V = X[:ratio], V[:ratio]
  78. print(X.shape, V.shape)
  79. #print(X[0], V[0])
  80. #print(X[-1], V[-1])
  81. print(X.shape)
  82. net = ValueNet(X.shape[1], 2)
  83. optim = Adam(get_parameters(net))
  def get_minibatch(X, Y, bs):
    """Draw `bs` (row, target) pairs uniformly at random with replacement and pack them as Tensors."""
    picks = [random.randint(0, len(X)-1) for _ in range(bs)]
    return Tensor([X[i] for i in picks]), Tensor([Y[i] for i in picks])
  # train with gradients enabled
  Tensor.no_grad, Tensor.training = False, True
  losses = []       # per-step training loss (python floats)
  test_losses = []  # test loss as of each step (stale between evals)
  test_loss = float('inf')
  for i in (t:=trange(2000)):
    x,y = get_minibatch(X,V,bs=256)
    # out[:, 0:1] = predicted mean, out[:, 1:2] = predicted log-sigma
    out = net(x)
    #loss = (out-y).square().mean()
    # heteroscedastic gaussian NLL of the observed log-speedup
    loss = log_likelihood(y, out[:, 0:1], out[:, 1:2]).mean()
    optim.zero_grad()
    loss.backward()
    optim.step()
    t.set_description(f"loss {loss.numpy():7.2f}, test loss {test_loss:7.2f}")
    losses.append(loss.numpy().item())
    test_losses.append(test_loss)
    # NOTE(review): `if i % 10:` recomputes the test loss on every step EXCEPT
    # multiples of 10 — presumably `i % 10 == 0` (eval every 10th step) was
    # intended; confirm before changing. Test metric is MSE of the mean head only.
    if i % 10: test_loss = (net(X_test)[:, 0:1]-V_test).square().mean().numpy().item()
  safe_save(get_state_dict(net), "/tmp/qnet.safetensors")
  import matplotlib.pyplot as plt
  # drop the first 20 noisy steps from the plot
  plt.plot(losses[20:])
  plt.plot(test_losses[20:])
  plt.show()