# unet3d.py — 3D U-Net model definition
from pathlib import Path
import torch
from tinygrad import nn
from tinygrad.tensor import Tensor
from tinygrad.helpers import fetch, get_child
  6. class DownsampleBlock:
  7. def __init__(self, c0, c1, stride=2):
  8. self.conv1 = [nn.Conv2d(c0, c1, kernel_size=(3,3,3), stride=stride, padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
  9. self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
  10. def __call__(self, x):
  11. return x.sequential(self.conv1).sequential(self.conv2)
  12. class UpsampleBlock:
  13. def __init__(self, c0, c1):
  14. self.upsample_conv = [nn.ConvTranspose2d(c0, c1, kernel_size=(2,2,2), stride=2)]
  15. self.conv1 = [nn.Conv2d(2 * c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
  16. self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
  17. def __call__(self, x, skip):
  18. x = x.sequential(self.upsample_conv)
  19. x = Tensor.cat(x, skip, dim=1)
  20. return x.sequential(self.conv1).sequential(self.conv2)
  21. class UNet3D:
  22. def __init__(self, in_channels=1, n_class=3):
  23. filters = [32, 64, 128, 256, 320]
  24. inp, out = filters[:-1], filters[1:]
  25. self.input_block = DownsampleBlock(in_channels, filters[0], stride=1)
  26. self.downsample = [DownsampleBlock(i, o) for i, o in zip(inp, out)]
  27. self.bottleneck = DownsampleBlock(filters[-1], filters[-1])
  28. self.upsample = [UpsampleBlock(filters[-1], filters[-1])] + [UpsampleBlock(i, o) for i, o in zip(out[::-1], inp[::-1])]
  29. self.output = {"conv": nn.Conv2d(filters[0], n_class, kernel_size=(1, 1, 1))}
  30. def __call__(self, x):
  31. x = self.input_block(x)
  32. outputs = [x]
  33. for downsample in self.downsample:
  34. x = downsample(x)
  35. outputs.append(x)
  36. x = self.bottleneck(x)
  37. for upsample, skip in zip(self.upsample, outputs[::-1]):
  38. x = upsample(x, skip)
  39. x = self.output["conv"](x)
  40. return x
  41. def load_from_pretrained(self):
  42. fn = Path(__file__).parents[1] / "weights" / "unet-3d.ckpt"
  43. fetch("https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1", fn)
  44. state_dict = torch.jit.load(fn, map_location=torch.device("cpu")).state_dict()
  45. for k, v in state_dict.items():
  46. obj = get_child(self, k)
  47. assert obj.shape == v.shape, (k, obj.shape, v.shape)
  48. obj.assign(v.numpy())
  49. if __name__ == "__main__":
  50. mdl = UNet3D()
  51. mdl.load_from_pretrained()