# convnext.py — ConvNeXt image classifier implemented with tinygrad
  1. from tinygrad.tensor import Tensor
  2. from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
  3. from tinygrad.helpers import fetch, get_child
  4. class Block:
  5. def __init__(self, dim):
  6. self.dwconv = Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
  7. self.norm = LayerNorm(dim, eps=1e-6)
  8. self.pwconv1 = Linear(dim, 4 * dim)
  9. self.pwconv2 = Linear(4 * dim, dim)
  10. self.gamma = Tensor.ones(dim)
  11. def __call__(self, x:Tensor):
  12. return x + x.sequential([
  13. self.dwconv, lambda x: x.permute(0, 2, 3, 1), self.norm,
  14. self.pwconv1, Tensor.gelu, self.pwconv2, lambda x: (self.gamma * x).permute(0, 3, 1, 2)
  15. ])
  16. class ConvNeXt:
  17. def __init__(self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]):
  18. self.downsample_layers = [
  19. [Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm2d(dims[0], eps=1e-6)],
  20. *[[LayerNorm2d(dims[i], eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
  21. ]
  22. self.stages = [[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))]
  23. self.norm = LayerNorm(dims[-1])
  24. self.head = Linear(dims[-1], num_classes)
  25. def __call__(self, x:Tensor):
  26. for downsample, stage in zip(self.downsample_layers, self.stages):
  27. x = x.sequential(downsample).sequential(stage)
  28. return x.mean([-2, -1]).sequential([self.norm, self.head])
  29. # *** model definition is done ***
# Published ConvNeXt-1k configurations: per-stage block counts (depths)
# and per-stage channel widths (dims). Keys match the checkpoint names
# used in the download URL in get_model().
versions = {
  "tiny": {"depths": [3, 3, 9, 3], "dims": [96, 192, 384, 768]},
  "small": {"depths": [3, 3, 27, 3], "dims": [96, 192, 384, 768]},
  "base": {"depths": [3, 3, 9, 3], "dims": [128, 256, 512, 1024]},
  "large": {"depths": [3, 3, 27, 3], "dims": [192, 384, 768, 1536]},
  "xlarge": {"depths": [3, 3, 27, 3], "dims": [256, 512, 1024, 2048]}
}
  37. def get_model(version, load_weights=False):
  38. model = ConvNeXt(**versions[version])
  39. if load_weights:
  40. from tinygrad.nn.state import torch_load
  41. weights = torch_load(fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth'))['model']
  42. for k,v in weights.items():
  43. mv = get_child(model, k)
  44. mv.assign(v.reshape(mv.shape).to(mv.device)).realize()
  45. return model
  46. if __name__ == "__main__":
  47. model = get_model("tiny", True)
  48. # load image
  49. from test.models.test_efficientnet import chicken_img, preprocess, _LABELS
  50. img = Tensor(preprocess(chicken_img))
  51. Tensor.training = False
  52. Tensor.no_grad = True
  53. out = model(img).numpy()
  54. print(_LABELS[out.argmax()])