# run_qnet.py (1.2 KB)
  1. from typing import List, Tuple
  2. from tinygrad.codegen.kernel import Kernel
  3. from tinygrad.engine.search import get_kernel_actions, actions
# Cached ValueNet instance. Stays None until the first beam_q_estimate() call
# constructs it and loads its weights; later calls reuse the same object.
_net = None
  5. def beam_q_estimate(beam:List[Tuple[Kernel, float]]) -> List[Tuple[Kernel, float]]:
  6. global _net
  7. if _net is None:
  8. from tinygrad.nn.state import load_state_dict, safe_load
  9. from extra.optimization.pretrain_valuenet import ValueNet
  10. _net = ValueNet(1021+len(actions), 2)
  11. load_state_dict(_net, safe_load("/tmp/qnet.safetensors"), verbose=False)
  12. from tinygrad.tensor import Tensor
  13. from tinygrad.helpers import Context
  14. from extra.optimization.helpers import lin_to_feats
  15. import numpy as np
  16. feats = []
  17. lins = []
  18. base_tms = []
  19. for lin,tm in beam:
  20. lin_feats = lin_to_feats(lin)
  21. for a,v in get_kernel_actions(lin, include_0=False).items():
  22. acts = np.zeros(len(actions))
  23. acts[a-1] = 1.0
  24. feats.append(np.concatenate([lin_feats, acts]))
  25. lins.append(v)
  26. base_tms.append(tm)
  27. with Context(BEAM=0):
  28. with Tensor.train(False):
  29. preds = _net(Tensor(feats)).numpy()
  30. pred_time = np.array(base_tms) / np.exp(preds[:, 0])
  31. return sorted(zip(lins, pred_time), key=lambda x: x[1])