# simple_conv_bn.py
  1. from tinygrad.tensor import Tensor
  2. from tinygrad.nn import Conv2d, BatchNorm2d
  3. from tinygrad.nn.state import get_parameters
  4. if __name__ == "__main__":
  5. with Tensor.train():
  6. BS, C1, H, W = 4, 16, 224, 224
  7. C2, K, S, P = 64, 7, 2, 1
  8. x = Tensor.uniform(BS, C1, H, W)
  9. conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
  10. bn = BatchNorm2d(C2, track_running_stats=False)
  11. for t in get_parameters([x, conv, bn]): t.realize()
  12. print("running network")
  13. x.sequential([conv, bn]).numpy()