| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- import sys, sqlite3, pickle, math
- from collections import defaultdict
- from tqdm import tqdm, trange
- import numpy as np
- # stuff needed to unpack a kernel
- from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
- from tinygrad.dtype import dtypes
- from tinygrad.shape.shapetracker import ShapeTracker
- from tinygrad.shape.view import View
- from tinygrad.shape.symbolic import Variable
- inf, nan = float('inf'), float('nan')
- from tinygrad.codegen.kernel import Opt, OptOps
- # more stuff
- from tinygrad.codegen.kernel import Kernel
- from tinygrad.engine.search import actions
- from extra.optimization.helpers import lin_to_feats
- from extra.optimization.pretrain_valuenet import ValueNet
- from tinygrad.nn.optim import Adam
- from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
- import random
- from tinygrad.tensor import Tensor
- from tinygrad.helpers import getenv
def dataset_from_cache(fn):
  """Build a supervised (state+action -> log speedup) dataset from a timing cache.

  Reads every row of the ``time_linearizer`` table in the SQLite database at
  ``fn``. For each row the first column is the kernel AST key, the last column
  is a pickled collection of timings, and the columns in between form the
  sub-key. Only rows whose second sub-key column equals 1 are used; the first
  sub-key column is evaluated into the sequence of applied opts.

  For every cached opt sequence with at least one opt, one sample is emitted:
  the features of the kernel after all but the last opt, the index of the last
  opt in ``actions``, and ``log(best parent time / best time)``.

  Args:
    fn: path of the SQLite cache database.

  Returns:
    ``(X, V)`` where ``X`` is a float32 array whose rows are the kernel features
    followed by a one-hot encoding of the action, and ``V`` is a float32 array
    of shape ``(n_samples, 1)`` holding the log time ratios.

  SECURITY NOTE: the cache contents are passed to ``pickle.loads`` and ``eval``.
  Only run this on a cache file you generated yourself.
  """
  conn = sqlite3.connect(fn)
  try:
    cur = conn.cursor()
    cur.execute("SELECT * FROM time_linearizer")
    rows = cur.fetchall()
  finally:
    # the rows are fully materialised above, so the connection can always be released here
    conn.close()

  grouped = defaultdict(dict)
  for f in tqdm(rows): grouped[f[0]][f[1:-1]] = pickle.loads(f[-1])

  # (ast, tuple of opts) -> timings, restricted to rows whose second sub-key column is 1
  opts_to_outcome = {}
  for ast,sk in grouped.items():
    for sks,tm in sk.items():
      if sks[1] != 1: continue
      opts_to_outcome[(ast, tuple(eval(sks[0])))] = tm

  S,A,V = [], [], []
  for ast,k in tqdm(opts_to_outcome):
    if len(k) == 0: continue
    # a sample needs the timing of the parent (all opts but the last); skip it when the parent was never cached
    parent_tms = opts_to_outcome.get((ast, k[:-1]))
    if parent_tms is None: continue
    old_tm = min(parent_tms)
    new_tm = min(opts_to_outcome[(ast,k)])
    # infinite or (near) zero timings would make the log ratio meaningless
    if math.isinf(old_tm) or math.isinf(new_tm) or old_tm < 1e-9 or new_tm < 1e-9: continue
    try:
      lin = Kernel(eval(ast))
    except Exception:
      # ASTs that can no longer be turned into a Kernel are dropped (best effort)
      continue
    for opt in k[:-1]: lin.apply_opt(opt)
    act = k[-1]
    log_ratio = math.log(old_tm/new_tm)
    S.append(lin_to_feats(lin, use_sts=True))
    A.append(actions.index(act))
    V.append([log_ratio])  # NOTE: must be a length-1 list so that V ends up 2-D, shape (n, 1)

  S, A, V = np.array(S), np.array(A), np.array(V, dtype=np.float32)
  n_samples, n_feats = S.shape[0], S.shape[1]
  # features first, then a one-hot block of width len(actions) selecting the action taken
  X = np.zeros((n_samples, n_feats+len(actions)), dtype=np.float32)
  X[:, :n_feats] = S
  X[np.arange(n_samples), n_feats+A] = 1.0
  return X, V
def log_likelihood(x:Tensor, mu:Tensor, log_sigma:Tensor):
  """Elementwise Gaussian loss term ``(x-mu)^2 / (2*sigma^2) + log(sigma)``.

  ``sigma`` is supplied as ``log_sigma``. Up to an additive constant this is
  the *negative* log-density of ``x`` under a normal distribution with mean
  ``mu`` and standard deviation ``exp(log_sigma)``, so lower is better and the
  result can be minimised directly as a loss.
  """
  residual = x - mu
  inv_variance = (-2*log_sigma).exp()  # exp(-2*log_sigma) == 1 / sigma^2
  return residual.square() * inv_variance / 2 + log_sigma
if __name__ == "__main__":
  # With REGEN set, rebuild the dataset from the cache database (path from argv[1],
  # default /tmp/tinygrad_cache) and save it; otherwise load the previously saved dataset.
  if getenv("REGEN"):
    X,V = dataset_from_cache(sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache")
    safe_save({"X": Tensor(X), "V": Tensor(V)}, "/tmp/dataset")
  else:
    ld = safe_load("/tmp/dataset")
    X,V = ld['X'].numpy(), ld['V'].numpy()

  print(X.shape, V.shape)

  # shuffle, then hold out the last test_size samples for evaluation
  order = list(range(X.shape[0]))
  random.shuffle(order)
  X, V = X[order], V[order]

  test_size = 512
  X_test, V_test = Tensor(X[-test_size:]), Tensor(V[-test_size:])
  X,V = X[:-test_size], V[:-test_size]
  print(X.shape, V.shape)

  print(X.shape)
  # two outputs per sample: column 0 is used as the predicted mean, column 1 as log sigma
  net = ValueNet(X.shape[1], 2)
  optim = Adam(get_parameters(net))

  def get_minibatch(X,Y,bs):
    """Return ``bs`` rows of X and the matching rows of Y, sampled uniformly with replacement."""
    sel = [random.randint(0, len(X)-1) for _ in range(bs)]
    return Tensor(X[sel]), Tensor(Y[sel])

  Tensor.no_grad, Tensor.training = False, True
  losses = []
  test_losses = []
  test_loss = float('inf')
  for i in (t:=trange(2000)):
    x,y = get_minibatch(X,V,bs=256)
    out = net(x)
    loss = log_likelihood(y, out[:, 0:1], out[:, 1:2]).mean()
    optim.zero_grad()
    loss.backward()
    optim.step()
    loss_val = loss.numpy().item()  # realise the loss once and reuse the Python float
    t.set_description(f"loss {loss_val:7.2f}, test loss {test_loss:7.2f}")
    losses.append(loss_val)
    test_losses.append(test_loss)
    # refresh the held-out MSE of the mean head every 10th step
    # (the previous `if i % 10:` ran on every step EXCEPT multiples of 10)
    if i % 10 == 0: test_loss = (net(X_test)[:, 0:1]-V_test).square().mean().numpy().item()

  safe_save(get_state_dict(net), "/tmp/qnet.safetensors")

  import matplotlib.pyplot as plt
  # skip the first 20 steps so the initial transient (and the inf placeholder) does not dominate the plot
  plt.plot(losses[20:])
  plt.plot(test_losses[20:])
  plt.show()
|