| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- from tinygrad.codegen.kernel import Kernel
- from tqdm import tqdm, trange
- import math
- import random
- from tinygrad.tensor import Tensor
- from tinygrad.nn import Linear
- from tinygrad.nn.optim import Adam
- from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
- # stuff needed to unpack a kernel
- from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
- from tinygrad.dtype import dtypes
- from tinygrad.shape.shapetracker import ShapeTracker
- from tinygrad.shape.view import View
- from tinygrad.shape.symbolic import Variable
# inf/nan must be module-level names: the logged kernel reprs that get eval()'d
# below can contain bare `inf` / `nan` literals.
inf, nan = float('inf'), float('nan')
from tinygrad.codegen.kernel import Opt, OptOps
from extra.optimization.helpers import lin_to_feats, MAX_DIMS
# NOTE: this is not real value of the state, it's just a prediction of the runtime
INNER = 512  # hidden-layer width of ValueNet
class ValueNet:
  """Four-layer MLP that predicts (log) kernel runtime from a feature vector.

  Attribute names l1..l4 are load-bearing: they become the state-dict keys
  used by safe_save/safe_load, so they must not be renamed.
  """
  def __init__(self, feats=240, out=1):
    self.l1 = Linear(feats, INNER)
    self.l2 = Linear(INNER, INNER)
    self.l3 = Linear(INNER, INNER)
    self.l4 = Linear(INNER, out)

  def __call__(self, x):
    # relu after every hidden layer; dropout only before the output layer
    for hidden in (self.l1, self.l2):
      x = hidden(x).relu()
    x = self.l3(x).relu().dropout(0.8)
    return self.l4(x)
if __name__ == "__main__":
  net = ValueNet()
  optim = Adam(get_parameters(net))

  TEST_SIZE = 256  # number of (already-shuffled) samples held out for evaluation

  # Each log line is the repr of an (ast, opts, tms) tuple; eval() is how the
  # tinygrad objects are rebuilt (inf/nan/MemBuffer/... are imported above for it).
  # SECURITY NOTE(review): eval of a log file is arbitrary code execution --
  # only run this against a trusted /tmp/logtm.
  # BUGFIX: use a context manager so the file handle is closed (was a bare open()).
  with open("/tmp/logtm") as f:
    dset = f.read().strip().split("\n")
  random.seed(1337)
  random.shuffle(dset)

  X, Y = [], []
  for entry in tqdm(dset):
    ast, opts, tms = eval(entry)  # trusted input only, see note above
    lin = Kernel(ast)
    for o in opts: lin.apply_opt(o)
    if lin.shape_len >= MAX_DIMS: continue   # feature extractor handles at most MAX_DIMS dims
    if min(tms) == float('inf'): continue    # every timing attempt failed; drop sample
    X.append(lin_to_feats(lin))
    Y.append([math.log(min(tms))])           # target is log of the best measured runtime
  print(f"got {len(X)} samples")

  # dset was shuffled above, so the tail is a random held-out test split
  X_test, Y_test = Tensor(X[-TEST_SIZE:]), Tensor(Y[-TEST_SIZE:])
  X, Y = X[:-TEST_SIZE], Y[:-TEST_SIZE]

  def get_minibatch(X, Y, bs):
    """Sample bs (x, y) pairs with replacement and stack them into Tensors."""
    xs, ys = [], []
    for _ in range(bs):
      sel = random.randint(0, len(X)-1)
      xs.append(X[sel])
      ys.append(Y[sel])
    return Tensor(xs), Tensor(ys)

  Tensor.no_grad, Tensor.training = False, True  # training mode: dropout active
  losses = []
  test_losses = []
  test_loss = float('inf')
  for i in (t := trange(2000)):
    x, y = get_minibatch(X, Y, bs=256)
    out = net(x)
    loss = (out - y).square().mean()             # MSE in log-runtime space
    optim.zero_grad()
    loss.backward()
    optim.step()
    t.set_description(f"loss {loss.numpy():7.2f}, test loss {test_loss:7.2f}")
    losses.append(loss.numpy().item())
    test_losses.append(test_loss)
    # BUGFIX: was `if i % 10:`, which evaluated on 9 of every 10 steps and
    # skipped step 0; evaluate every 10th step instead, with dropout disabled.
    if i % 10 == 0:
      Tensor.training = False
      test_loss = (net(X_test) - Y_test).square().mean().numpy().item()
      Tensor.training = True

  safe_save(get_state_dict(net), "/tmp/valuenet.safetensors")

  # plot train/test loss curves, skipping the noisy first 200 steps
  import matplotlib.pyplot as plt
  plt.plot(losses[200:])
  plt.plot(test_losses[200:])
  plt.show()
|