# f16_w_uint32.py
import numpy as np
from tinygrad import Device, dtypes, Tensor
  3. # TODO: will be better when tinygrad does math in the target dtype, can remove the floor and use a mul
  4. def bit_extract(x, s, e) -> Tensor:
  5. # extract the top bits we don't want
  6. top_bits = (x / (1<<(s+1))).floor() * (1<<(s+1))
  7. x = (x - top_bits) / (1<<e)
  8. return x.contiguous()
  9. def u16_to_f16(x):
  10. sign = bit_extract(x, 15, 15).float()
  11. exponent = bit_extract(x, 14, 10).float()
  12. fraction = bit_extract(x, 9, 0).float()
  13. return sign.where(-1, 1) * exponent.where((exponent - 15).exp2() * (1 + fraction / 0x400), 6.103515625e-5 * (fraction / 0x400))
  14. def u32_to_f16(oo):
  15. oo1 = (oo/0x10000).floor().contiguous()
  16. # TODO: this is wrong and unextractable until we do this math in u32
  17. oo2 = (oo-(oo1*0x10000)).floor().contiguous()
  18. f1 = u16_to_f16(oo1)
  19. f2 = u16_to_f16(oo2)
  20. return Tensor.cat(f2.reshape(-1, 1), f1.reshape(-1, 1), dim=1).flatten()
  21. if __name__ == "__main__":
  22. # random float16
  23. Tensor.manual_seed(2)
  24. a = Tensor.randn(100, dtype=dtypes.float16)
  25. # this converts it to u32 on disk
  26. oo = a.to("disk:/tmp/f16").cast(dtypes.uint32)[:50].to(Device.DEFAULT).realize()
  27. # convert to 2xf16 using tinygrad math ops
  28. f16 = u32_to_f16(oo)
  29. ref = a.numpy()
  30. out = f16.numpy().astype(np.float16)
  31. print(ref-out)
  32. np.testing.assert_allclose(ref, out)