# tinygrad.nn — neural network layers (BatchNorm, Conv, Linear, norms, Embedding)
  1. import math
  2. from typing import Optional, Union, Tuple
  3. from tinygrad.tensor import Tensor
  4. from tinygrad.helpers import prod
  5. from tinygrad.nn import optim, state, datasets # noqa: F401
class BatchNorm:
  """
  Applies Batch Normalization over a 2D or 3D input.

  - Described: https://paperswithcode.com/method/batch-normalization
  - Paper: https://arxiv.org/abs/1502.03167v3

  See: `Tensor.batchnorm`

  ```python exec="true" session="tensor"
  from tinygrad import Tensor, dtypes, nn
  import numpy as np
  np.set_printoptions(precision=4)
  ```

  ```python exec="true" source="above" session="tensor" result="python"
  norm = nn.BatchNorm(3)
  t = Tensor.rand(2, 3, 4, 4)
  print(t.mean().item(), t.std().item())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = norm(t)
  print(t.mean().item(), t.std().item())
  ```
  """
  def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
    self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
    # learnable per-channel scale and shift; None when affine=False
    if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
    else: self.weight, self.bias = None, None
    # running statistics are buffers, not parameters: no gradients flow into them
    self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
    self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
  def __call__(self, x:Tensor) -> Tensor:
    # broadcast shape (1, C, 1, ...): channel axis is dim 1, stats broadcast over the rest
    shape_mask = [1, -1, *([1]*(x.ndim-2))]
    if Tensor.training:
      # This requires two full memory accesses to x
      # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
      # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
      # reduce over every axis except the channel axis (dim 1)
      batch_mean = x.mean(axis=(reduce_axes:=tuple(x for x in range(x.ndim) if x != 1)))
      # mean is detached before subtraction so the variance gradient treats it as a constant
      y = (x - batch_mean.detach().reshape(shape=shape_mask)) # d(var)/d(mean) = 0
      batch_var = (y*y).mean(axis=reduce_axes)
      batch_invstd = batch_var.add(self.eps).pow(-0.5)
      # NOTE: wow, this is done all throughout training in most PyTorch models
      if self.track_running_stats:
        # exponential moving average of the batch stats, detached from the graph
        self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
        # prod(y.shape)/(prod(y.shape)-y.shape[1]) == n/(n-1) with n = elements per channel,
        # i.e. Bessel's correction: the running variance stores the unbiased estimate
        self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * prod(y.shape)/(prod(y.shape)-y.shape[1]) * batch_var.detach())
        self.num_batches_tracked += 1
    else:
      batch_mean = self.running_mean
      # NOTE: this can be precomputed for static inference. we expand it here so it fuses
      batch_invstd = self.running_var.reshape(shape=shape_mask).expand(x.shape).add(self.eps).rsqrt()
    return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
# 2D and 3D batchnorm share the implementation: reduce_axes adapts to x.ndim
BatchNorm2d = BatchNorm3d = BatchNorm
  54. def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
  55. """
  56. Applies a 1D convolution over an input signal composed of several input planes.
  57. See: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d
  58. ```python exec="true" source="above" session="tensor" result="python"
  59. conv = nn.Conv1d(1, 1, 3)
  60. t = Tensor.rand(1, 1, 4)
  61. print(t.numpy())
  62. ```
  63. ```python exec="true" source="above" session="tensor" result="python"
  64. t = conv(t)
  65. print(t.numpy())
  66. ```
  67. """
  68. return Conv2d(in_channels, out_channels, (kernel_size,), stride, padding, dilation, groups, bias)
  69. class Conv2d:
  70. """
  71. Applies a 2D convolution over an input signal composed of several input planes.
  72. See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d
  73. ```python exec="true" source="above" session="tensor" result="python"
  74. conv = nn.Conv2d(1, 1, 3)
  75. t = Tensor.rand(1, 1, 4, 4)
  76. print(t.numpy())
  77. ```
  78. ```python exec="true" source="above" session="tensor" result="python"
  79. t = conv(t)
  80. print(t.numpy())
  81. ```
  82. """
  83. def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
  84. self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
  85. self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
  86. scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
  87. self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
  88. self.bias = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
  89. def __call__(self, x:Tensor):
  90. return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
  91. def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
  92. """
  93. Applies a 1D transposed convolution operator over an input signal composed of several input planes.
  94. See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d
  95. ```python exec="true" source="above" session="tensor" result="python"
  96. conv = nn.ConvTranspose1d(1, 1, 3)
  97. t = Tensor.rand(1, 1, 4)
  98. print(t.numpy())
  99. ```
  100. ```python exec="true" source="above" session="tensor" result="python"
  101. t = conv(t)
  102. print(t.numpy())
  103. ```
  104. """
  105. return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)
class ConvTranspose2d(Conv2d):
  """
  Applies a 2D transposed convolution operator over an input image.

  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d

  ```python exec="true" source="above" session="tensor" result="python"
  conv = nn.ConvTranspose2d(1, 1, 3)
  t = Tensor.rand(1, 1, 4, 4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = conv(t)
  print(t.numpy())
  ```
  """
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
    # Conv2d.__init__ sets kernel_size/stride/padding/dilation/groups and bias, and draws
    # a weight we immediately replace: transposed conv swaps the weight's channel axes
    # to (in_channels, out_channels//groups, *kernel_size).
    # NOTE(review): the discarded Tensor.uniform draw presumably advances the RNG,
    # so keep the statement order as-is to preserve initialization values.
    super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
    self.weight = Tensor.uniform(in_channels, out_channels//groups, *self.kernel_size, low=-scale, high=scale)
    self.output_padding = output_padding
  def __call__(self, x:Tensor) -> Tensor:
    return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride,
                              dilation=self.dilation, groups=self.groups)
  128. class Linear:
  129. """
  130. Applies a linear transformation to the incoming data.
  131. See: https://pytorch.org/docs/stable/generated/torch.nn.Linear
  132. ```python exec="true" source="above" session="tensor" result="python"
  133. lin = nn.Linear(3, 4)
  134. t = Tensor.rand(2, 3)
  135. print(t.numpy())
  136. ```
  137. ```python exec="true" source="above" session="tensor" result="python"
  138. t = lin(t)
  139. print(t.numpy())
  140. ```
  141. """
  142. def __init__(self, in_features, out_features, bias=True):
  143. bound = 1 / math.sqrt(in_features)
  144. self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
  145. self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None
  146. def __call__(self, x:Tensor):
  147. return x.linear(self.weight.transpose(), self.bias)
  148. class GroupNorm:
  149. """
  150. Applies Group Normalization over a mini-batch of inputs.
  151. - Described: https://paperswithcode.com/method/group-normalization
  152. - Paper: https://arxiv.org/abs/1803.08494v3
  153. ```python exec="true" source="above" session="tensor" result="python"
  154. norm = nn.GroupNorm(2, 12)
  155. t = Tensor.rand(2, 12, 4, 4) * 2 + 1
  156. print(t.mean().item(), t.std().item())
  157. ```
  158. ```python exec="true" source="above" session="tensor" result="python"
  159. t = norm(t)
  160. print(t.mean().item(), t.std().item())
  161. ```
  162. """
  163. def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
  164. self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
  165. self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None
  166. self.bias: Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
  167. def __call__(self, x:Tensor):
  168. # reshape for layernorm to work as group norm
  169. # subtract mean and divide stddev
  170. x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)
  171. if self.weight is None or self.bias is None: return x
  172. # elementwise_affine on channels
  173. return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
  174. class InstanceNorm:
  175. """
  176. Applies Instance Normalization over a mini-batch of inputs.
  177. - Described: https://paperswithcode.com/method/instance-normalization
  178. - Paper: https://arxiv.org/abs/1607.08022v3
  179. ```python exec="true" source="above" session="tensor" result="python"
  180. norm = nn.InstanceNorm(3)
  181. t = Tensor.rand(2, 3, 4, 4) * 2 + 1
  182. print(t.mean().item(), t.std().item())
  183. ```
  184. ```python exec="true" source="above" session="tensor" result="python"
  185. t = norm(t)
  186. print(t.mean().item(), t.std().item())
  187. ```
  188. """
  189. def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
  190. self.num_features, self.eps = num_features, eps
  191. self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
  192. self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None
  193. def __call__(self, x:Tensor):
  194. x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
  195. if self.weight is None or self.bias is None: return x
  196. return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))
  197. class LayerNorm:
  198. """
  199. Applies Layer Normalization over a mini-batch of inputs.
  200. - Described: https://paperswithcode.com/method/layer-normalization
  201. - Paper: https://arxiv.org/abs/1607.06450v1
  202. ```python exec="true" source="above" session="tensor" result="python"
  203. norm = nn.LayerNorm(3)
  204. t = Tensor.rand(2, 5, 3) * 2 + 1
  205. print(t.mean().item(), t.std().item())
  206. ```
  207. ```python exec="true" source="above" session="tensor" result="python"
  208. t = norm(t)
  209. print(t.mean().item(), t.std().item())
  210. ```
  211. """
  212. def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
  213. self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
  214. self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
  215. self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)
  216. def __call__(self, x:Tensor):
  217. assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
  218. x = x.layernorm(eps=self.eps, axis=self.axis)
  219. if not self.elementwise_affine: return x
  220. return x * self.weight + self.bias
  221. class LayerNorm2d(LayerNorm):
  222. """
  223. Applies Layer Normalization over a mini-batch of 2D inputs.
  224. See: `LayerNorm`
  225. ```python exec="true" source="above" session="tensor" result="python"
  226. norm = nn.LayerNorm2d(3)
  227. t = Tensor.rand(2, 3, 4, 4) * 2 + 1
  228. print(t.mean().item(), t.std().item())
  229. ```
  230. ```python exec="true" source="above" session="tensor" result="python"
  231. t = norm(t)
  232. print(t.mean().item(), t.std().item())
  233. ```
  234. """
  235. def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
  236. class RMSNorm:
  237. """
  238. Applies Root Mean Square Normalization to input.
  239. - Described: https://paperswithcode.com/method/rmsnorm
  240. - Paper: https://arxiv.org/abs/1910.07467
  241. ```python exec="true" source="above" session="tensor" result="python"
  242. norm = nn.RMSNorm(4)
  243. t = Tensor.arange(12, dtype=dtypes.float).reshape(3, 4)
  244. print(t.numpy())
  245. ```
  246. ```python exec="true" source="above" session="tensor" result="python"
  247. print(norm(t).numpy())
  248. ```
  249. """
  250. def __init__(self, dim, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
  251. def _norm(self, x:Tensor): return x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()
  252. def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight
class Embedding:
  """
  A simple lookup table that stores embeddings of a fixed dictionary and size.

  See: https://pytorch.org/docs/stable/generated/torch.nn.Embedding

  ```python exec="true" source="above" session="tensor" result="python"
  emb = nn.Embedding(10, 3)
  print(emb(Tensor([1, 2, 3, 1])).numpy())
  ```
  """
  def __init__(self, vocab_size:int, embed_size:int):
    # weight is (vocab_size, embed_size); row i is the embedding of token i
    self.vocab_sz, self.embed_sz, self.weight = vocab_size, embed_size, Tensor.glorot_uniform(vocab_size, embed_size)
  def __call__(self, idx:Tensor) -> Tensor:
    # empty index tensor -> empty output with the embedding dim appended, skip the lookup machinery
    if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), device=self.weight.device)
    # the lookup is expressed gather-free as (one_hot(idx) * weight).sum(vocab axis):
    # broadcast idx against arange(vocab_sz) to build the one-hot mask.
    # NOTE(review): arange_shp is rank 4 and .sum(2) reduces axis 2, which lines up
    # when idx is 2D (e.g. batch, seqlen) — verify against callers for other ranks.
    arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
    # cache the arange tensor on first use; it depends only on vocab_sz and device
    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
    # (arange == idx) is the one-hot mask; multiplying by the expanded weight and
    # summing over the vocab axis selects each token's embedding row
    return (arange == idx).mul(vals).sum(2)