- import math
- from tinygrad.tensor import Tensor
- from tinygrad.nn import BatchNorm2d
- from tinygrad.helpers import get_child, fetch
- from tinygrad.nn.state import torch_load
class MBConvBlock:
  """Mobile inverted bottleneck convolution block (MBConv).

  Pipeline: optional 1x1 expansion -> depthwise conv -> optional
  squeeze-and-excitation -> 1x1 projection, with an identity skip
  connection whenever input and output shapes match.
  """
  def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True):
    expanded = expand_ratio * input_filters
    if expand_ratio == 1:
      # no expansion stage; _bn0 is intentionally not created in this case
      self._expand_conv = None
    else:
      self._expand_conv = Tensor.glorot_uniform(expanded, input_filters, 1, 1)
      self._bn0 = BatchNorm2d(expanded, track_running_stats=track_running_stats)

    self.strides = strides
    half = (kernel_size - 1) // 2
    # stride-2 blocks use asymmetric "same"-style padding; stride-1 pads evenly
    self.pad = [half - 1, half] * 2 if strides == (2, 2) else [half] * 4

    self._depthwise_conv = Tensor.glorot_uniform(expanded, 1, kernel_size, kernel_size)
    self._bn1 = BatchNorm2d(expanded, track_running_stats=track_running_stats)

    self.has_se = has_se
    if has_se:
      # squeeze-and-excitation bottleneck width, at least one channel
      squeezed = max(1, int(input_filters * se_ratio))
      self._se_reduce = Tensor.glorot_uniform(squeezed, expanded, 1, 1)
      self._se_reduce_bias = Tensor.zeros(squeezed)
      self._se_expand = Tensor.glorot_uniform(expanded, squeezed, 1, 1)
      self._se_expand_bias = Tensor.zeros(expanded)

    self._project_conv = Tensor.glorot_uniform(output_filters, expanded, 1, 1)
    self._bn2 = BatchNorm2d(output_filters, track_running_stats=track_running_stats)

  def __call__(self, inputs):
    """Run the block; assumes inputs is NCHW — TODO confirm against callers."""
    out = inputs if self._expand_conv is None else self._bn0(inputs.conv2d(self._expand_conv)).swish()
    # groups == channel count makes this a depthwise convolution
    out = out.conv2d(self._depthwise_conv, padding=self.pad, stride=self.strides, groups=self._depthwise_conv.shape[0])
    out = self._bn1(out).swish()
    if self.has_se:
      # global-average-pool to 1x1, then reduce/expand and gate channels
      se = out.avg_pool2d(kernel_size=out.shape[2:4])
      se = se.conv2d(self._se_reduce, self._se_reduce_bias).swish()
      se = se.conv2d(self._se_expand, self._se_expand_bias)
      out = out.mul(se.sigmoid())
    out = self._bn2(out.conv2d(self._project_conv))
    # identity skip only when shapes line up (stride 1 and equal channels)
    return out.add(inputs) if out.shape == inputs.shape else out
class EfficientNet:
  """EfficientNet backbone built from MBConvBlock stages.

  number selects the variant: 0..9 map to b0..b8 and l2 scaling
  coefficients; -1 and -2 select reduced custom stacks (still using
  b0 coefficients, since the table is indexed with max(number, 0)).
  """
  def __init__(self, number=0, classes=1000, has_se=True, track_running_stats=True, input_channels=3, has_fc_output=True):
    self.number = number
    # (width multiplier, depth multiplier) per variant
    width_mult, depth_mult = [
      (1.0, 1.0),  # b0
      (1.0, 1.1),  # b1
      (1.1, 1.2),  # b2
      (1.2, 1.4),  # b3
      (1.4, 1.8),  # b4
      (1.6, 2.2),  # b5
      (1.8, 2.6),  # b6
      (2.0, 3.1),  # b7
      (2.2, 3.6),  # b8
      (4.3, 5.3),  # l2
    ][max(number, 0)]

    def round_filters(filters):
      """Scale a channel count by the width multiplier, rounded to a multiple of 8."""
      divisor = 8
      scaled = filters * width_mult
      rounded = max(divisor, int(scaled + divisor / 2) // divisor * divisor)
      if rounded < 0.9 * scaled:  # prevent rounding by more than 10%
        rounded += divisor
      return int(rounded)

    def round_repeats(repeats):
      """Scale a stage's repeat count by the depth multiplier (ceiling)."""
      return int(math.ceil(depth_mult * repeats))

    # stem: 3x3 stride-2 conv
    out_channels = round_filters(32)
    self._conv_stem = Tensor.glorot_uniform(out_channels, input_channels, 3, 3)
    self._bn0 = BatchNorm2d(out_channels, track_running_stats=track_running_stats)

    # per stage: [repeats, kernel, stride, expand_ratio, in_filters, out_filters, se_ratio]
    blocks_args = [
      [1, 3, (1, 1), 1, 32, 16, 0.25],
      [2, 3, (2, 2), 6, 16, 24, 0.25],
      [2, 5, (2, 2), 6, 24, 40, 0.25],
      [3, 3, (2, 2), 6, 40, 80, 0.25],
      [3, 5, (1, 1), 6, 80, 112, 0.25],
      [4, 5, (2, 2), 6, 112, 192, 0.25],
      [1, 3, (1, 1), 6, 192, 320, 0.25],
    ]
    if self.number == -1:
      # shallow 4-stage variant
      blocks_args = [
        [1, 3, (2, 2), 1, 32, 40, 0.25],
        [1, 3, (2, 2), 1, 40, 80, 0.25],
        [1, 3, (2, 2), 1, 80, 192, 0.25],
        [1, 3, (2, 2), 1, 192, 320, 0.25],
      ]
    elif self.number == -2:
      # single aggressive-downsampling block
      blocks_args = [
        [1, 9, (8, 8), 1, 32, 320, 0.25],
      ]

    self._blocks = []
    for num_repeats, kernel_size, stage_strides, expand_ratio, in_f, out_f, se_ratio in blocks_args:
      in_f, out_f = round_filters(in_f), round_filters(out_f)
      for _ in range(round_repeats(num_repeats)):
        self._blocks.append(MBConvBlock(kernel_size, stage_strides, expand_ratio, in_f, out_f, se_ratio, has_se=has_se, track_running_stats=track_running_stats))
        # only the first block of a stage changes channels/stride
        in_f, stage_strides = out_f, (1, 1)

    # head: 1x1 conv up to the classifier width
    in_channels = round_filters(320)
    out_channels = round_filters(1280)
    self._conv_head = Tensor.glorot_uniform(out_channels, in_channels, 1, 1)
    self._bn1 = BatchNorm2d(out_channels, track_running_stats=track_running_stats)
    if has_fc_output:
      self._fc = Tensor.glorot_uniform(out_channels, classes)
      self._fc_bias = Tensor.zeros(classes)
    else:
      self._fc = None

  def forward(self, x):
    """Return class logits (or pooled features when built with has_fc_output=False)."""
    x = self._bn0(x.conv2d(self._conv_stem, padding=(0, 1, 0, 1), stride=2)).swish()
    x = x.sequential(self._blocks)
    x = self._bn1(x.conv2d(self._conv_head)).swish()
    # global average pool then flatten to (batch, channels)
    x = x.avg_pool2d(kernel_size=x.shape[2:4])
    x = x.reshape(shape=(-1, x.shape[1]))
    if self._fc is None:
      return x
    return x.linear(self._fc, self._fc_bias)

  def load_from_pretrained(self):
    """Download b0..b7 PyTorch weights and copy them into this model in place.

    Key names from the checkpoint are rewritten to match this class's
    attribute layout; shape mismatches are reported and skipped.
    """
    model_urls = {
      0: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
      1: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
      2: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
      3: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
      4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
      5: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
      6: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
      7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
    }
    state = torch_load(fetch(model_urls[self.number]))
    for k, v in state.items():
      if k.endswith("num_batches_tracked"):
        continue
      # bare-Tensor weights in this class drop the ".weight"/".bias" suffix
      for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
        if cat in k:
          k = k.replace('.bias', '_bias')
          k = k.replace('.weight', '')
      mv: Tensor = get_child(self, k)
      # the fc weight is stored transposed relative to our linear()
      vnp = v if k != '_fc' else v.clang().T
      if mv.shape == vnp.shape:
        mv.replace(vnp.to(mv.device))
      else:
        print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))
|