# helpers.py
# stuff needed to unpack a kernel
# NOTE(review): several of these names (Tuple, TernaryOps, dtypes, View, ...) look
# unused, but the dataset strings passed to eval() below reference them (e.g. the
# "dtypes.image" filter in load_worlds), so they must live in this module's namespace.
from typing import Tuple
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer, MetaOps
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable, NumNode
# inf/nan must be module-level names so eval() of serialized ASTs can resolve them
inf, nan = float('inf'), float('nan')
# kernel unpacker
from tinygrad.codegen.kernel import Kernel
  12. def ast_str_to_ast(ast_str:str) -> LazyOp: return LazyOp(MetaOps.KERNEL, val) if isinstance(val:=eval(ast_str), tuple) else val
  13. def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str), opts=opts)
  14. def kern_str_to_lin(kern_str:str, opts=None):
  15. (ast, applied_opts,) = eval(kern_str)
  16. k = Kernel(ast, opts=opts)
  17. for opt in applied_opts:
  18. k.apply_opt(opt)
  19. return k
  20. # load worlds, a dataset of about 12k kernels
  21. import gzip
  22. from pathlib import Path
  23. import random
  24. from tinygrad.helpers import dedup
  25. def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True):
  26. fn = Path(__file__).parent.parent / "datasets/sops.gz"
  27. ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n"))
  28. if filter_reduce: ast_strs = [x for x in ast_strs if "ReduceOps" in x]
  29. if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
  30. if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x]
  31. random.seed(1337)
  32. random.shuffle(ast_strs)
  33. return ast_strs
  34. def assert_same_lin(l1, l2):
  35. assert l1.colored_shape() == l2.colored_shape()
  36. assert all(x==y for x,y in zip(l1.sts, l2.sts))
  37. # get features
  38. import math
  39. from tinygrad.shape.symbolic import Node
MAX_DIMS = 16  # feature vector reserves 17 features for each of up to 16 shape dims
MAX_BUFS = 9   # feature vector reserves room for up to 9 deduped shapetrackers
def lin_to_feats(lin:Kernel, use_sts=True):
  """Flatten a Kernel into a fixed-length feature vector.

  Layout (lengths are asserted below):
    2 scalars (upcasted, local_dims)
    + 17 features per dim, zero-padded to MAX_DIMS dims  -> 274 total
    + (if use_sts) 3 + 5*MAX_DIMS features per shapetracker,
      zero-padded to MAX_BUFS shapetrackers              -> 1021 total
  """
  assert lin.shape_len < MAX_DIMS, "too many dims"

  all_colors = ["blue", "cyan", "white", "green", "red", "magenta", "yellow"]
  lc = [all_colors.index(x) for x in lin.colors()]

  ret = []
  # before, some generic linearizer stuff
  ret.append(lin.upcasted)
  ret.append(lin.local_dims)

  # first, the full shape, including the colors
  for s,os,c in zip(lin.full_shape,lin.output_shape,lc):
    if isinstance(s, Node):
      # symbolic dim: flag False, zero out the 9 numeric features
      ret.append(False)
      ret += [0]*9
    else:
      # concrete dim: flag True, then log-scale, clamped, and divisibility features
      ret.append(True)
      ret.append(math.log2(s))
      ret.append(min(33, s))
      ret.append(math.log2(os))
      ret.append(min(33, os))
      ret.append(s%2 == 0)
      ret.append(s%3 == 0)
      ret.append(s%4 == 0)
      ret.append(s%8 == 0)
      ret.append(s%16 == 0)
    # one-hot color of this dim (1 flag + 9 numeric + 7 color = 17 per dim)
    cc = [0]*7
    cc[c] = 1
    ret += cc
  # zero-pad the unused dims up to MAX_DIMS
  ret += [0] * (17*(MAX_DIMS-len(lin.full_shape)))
  ret = [float(x) for x in ret]

  if use_sts:
    # per-shapetracker summary: (is full shape, real strides, has mask, num views)
    my_sts = dedup([(x.shape == lin.full_shape, x.real_strides(), any(v.mask is not None for v in x.views), len(x.views)) for x in lin.sts])
    assert len(my_sts) < MAX_BUFS
    sts_len = 3 + 5*MAX_DIMS
    # NOTE(review): these features are appended AFTER the float() cast above, so
    # they may remain bool/int rather than float -- confirm downstream consumers
    for s in my_sts:
      ret.append(s[0]) # reduce
      ret.append(s[2]) # has mask
      ret.append(s[3]) # len views
      # 5 features per stride entry; None (unknown stride) maps to -1 sentinels
      for d in s[1]:
        ret.append(d is None)
        ret.append(d == 0)
        ret.append(d == 1)
        ret.append(min(33, d) if d is not None else -1)
        if d is not None and d >= 1: ret.append(math.log2(d))
        else: ret.append(-1)
      ret += [0] * (5*(MAX_DIMS - len(s[1])))
    ret += [0] * (sts_len*(MAX_BUFS - len(my_sts)))
    assert len(ret) == 1021, f"wrong len {len(ret)}"
  else:
    assert len(ret) == 274, f"wrong len {len(ret)}"
  return ret