initializers.py 4.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import math
  2. from typing import Union, Tuple
  3. from tinygrad import Tensor, nn, dtypes
  4. from tinygrad.helpers import prod, argfix
  5. # rejection sampling truncated randn
  6. def rand_truncn(*shape, dtype=None, truncstds=2, **kwargs) -> Tensor:
  7. CNT=8
  8. x = Tensor.randn(*(*shape, CNT), dtype=dtype, **kwargs)
  9. ctr = Tensor.arange(CNT).reshape((1,) * len(x.shape[:-1]) + (CNT,)).expand(x.shape)
  10. take = (x.abs() <= truncstds).where(ctr, CNT).min(axis=-1, keepdim=True) # set to 0 if no good samples
  11. return (ctr == take).where(x, 0).sum(axis=-1)
  12. # https://github.com/keras-team/keras/blob/v2.15.0/keras/initializers/initializers.py#L1026-L1065
  13. def he_normal(*shape, a: float = 0.00, **kwargs) -> Tensor:
  14. std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:])) / 0.87962566103423978
  15. return std * rand_truncn(*shape, **kwargs)
  16. class Conv2dHeNormal(nn.Conv2d):
  17. def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
  18. super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
  19. self.in_channels, self.out_channels = in_channels, out_channels # for testing
  20. self.weight = he_normal(out_channels, in_channels//groups, *self.kernel_size, a=0.0, dtype=dtypes.float32)
  21. if bias: self.bias = self.bias.cast(dtypes.float32)
  22. def __call__(self, x: Tensor):
  23. return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
  24. padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
  25. class Linear(nn.Linear):
  26. def __init__(self, in_features, out_features, bias=True):
  27. super().__init__(in_features, out_features, bias=bias)
  28. self.weight = Tensor.normal((out_features, in_features), mean=0.0, std=0.01, dtype=dtypes.float32)
  29. if bias: self.bias = Tensor.zeros(out_features, dtype=dtypes.float32)
  30. def __call__(self, x:Tensor):
  31. return x.linear(self.weight.cast(dtypes.default_float).transpose(), self.bias.cast(dtypes.default_float) if self.bias is not None else None)
  32. class LinearBert(nn.Linear):
  33. def __init__(self, in_features, out_features, bias=True, std=0.02):
  34. self.weight = std * rand_truncn(out_features, in_features, dtype=dtypes.float32)
  35. self.bias = Tensor.zeros(out_features, dtype=dtypes.float32) if bias else None
  36. def __call__(self, x:Tensor):
  37. return x.cast(dtypes.default_float).linear(self.weight.cast(dtypes.default_float).transpose(), self.bias.cast(dtypes.default_float) if self.bias is not None else None)
  38. class EmbeddingBert(nn.Embedding):
  39. def __init__(self, vocab_size:int, embed_size:int, std=0.02):
  40. self.vocab_sz, self.embed_sz = vocab_size, embed_size
  41. self.weight = std * rand_truncn(vocab_size, embed_size, dtype=dtypes.float32)
  42. def __call__(self, idx:Tensor) -> Tensor:
  43. if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), dtype=self.weight.dtype, device=self.weight.device)
  44. arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
  45. if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
  46. arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp)
  47. return (arange == idx).mul(vals).sum(2, acc_dtype=vals.dtype)
  48. class LayerNormBert:
  49. def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-12, elementwise_affine:bool=True):
  50. self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
  51. self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
  52. self.weight, self.bias = (Tensor.ones(*self.normalized_shape, dtype=dtypes.float32), Tensor.zeros(*self.normalized_shape, dtype=dtypes.float32)) if elementwise_affine else (None, None)
  53. def __call__(self, x:Tensor):
  54. assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
  55. xn = x.cast(dtypes.float32).layernorm(eps=self.eps, axis=self.axis).cast(x.dtype)
  56. if not self.elementwise_affine: return xn
  57. return (xn * self.weight.cast(dtypes.default_float) + self.bias.cast(dtypes.default_float))