matmul.c 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. // clang -Ofast -Wno-unused-result -march=native matmul.c
  2. #include <stdio.h>
  3. #include <stdlib.h>
  4. #include <time.h>
  // Statically allocated buffers handed to both kernels timed in main().
  // Sizes are spelled as products of the dimensions the kernels index with.
  float b52[256 * 3072];                  // result buffer: 256 rows of 3072 values
  float b49[256 * 768];                   // input buffer: 256 rows of 768 values
  float h_0_mlp_c_fc_weight[3072 * 768];  // weights: 3072 rows of 768 values
  float h_0_mlp_c_fc_bias[3072];          // one bias value per output column
  /*
   * Reference linear-layer forward pass: out = inp * weight^T (+ bias).
   *
   *   out    - (B, T, OC) result; every element is overwritten
   *   inp    - (B, T, C) input, read only
   *   weight - (OC, C) row-major weights, one row per output channel, read only
   *   bias   - (OC) per-channel bias, or NULL to start each sum from 0.0f
   *   B, T   - leading dimensions; each (b, t) pair is one independent row
   *   C      - length of every dot product
   *   OC     - number of output channels
   *
   * Row offsets are computed in size_t so that tensors with more than
   * INT_MAX elements do not overflow signed int arithmetic.
   * If any of B, T or OC is <= 0 nothing is written.
   */
  void matmul_forward(float* out,
                      float* inp, float* weight, float* bias,
                      int B, int T, int C, int OC) {
      // most of the running time is spent here and in matmul_backward
      // OC is short for "output channels"
      // The pragma only has an effect when the compiler is run with OpenMP
      // enabled; otherwise the loops simply execute serially.
      #pragma omp parallel for collapse(2)
      for (int b = 0; b < B; b++) {
          for (int t = 0; t < T; t++) {
              // Flattened index of row (b, t), widened before multiplying.
              const size_t bt = (size_t)b * (size_t)T + (size_t)t;
              float* out_bt = out + bt * (size_t)OC;
              const float* inp_bt = inp + bt * (size_t)C;
              for (int o = 0; o < OC; o++) {
                  // The bias seeds the accumulator, so it is added first.
                  float val = (bias != NULL) ? bias[o] : 0.0f;
                  const float* wrow = weight + (size_t)o * (size_t)C;
                  for (int i = 0; i < C; i++) {
                      val += inp_bt[i] * wrow[i];
                  }
                  out_bt[o] = val;
              }
          }
      }
  }
  /*
   * Fixed-shape matrix product with bias.
   *
   * For every row r in [0, 256) and column c in [0, 3072):
   *   data0[r*3072 + c] = sum_{k < 768} data1[r*768 + k] * data2[c*768 + k]
   *                       + data3[c]
   *
   * The sum starts at 0.0f and data3[c] is added once the dot product is
   * complete. All four pointers are restrict-qualified; only data0 is written.
   */
  void r_256_3072_768(float* restrict data0, const float* restrict data1, const float* restrict data2, const float* restrict data3) {
      enum { ROWS = 256, COLS = 3072, DEPTH = 768 };

      for (int row = 0; row < ROWS; row++) {
          const float* lhs = data1 + row * DEPTH;
          float* dst = data0 + row * COLS;

          for (int col = 0; col < COLS; col++) {
              const float* rhs = data2 + col * DEPTH;
              float sum = 0.0f;

              for (int k = 0; k < DEPTH; k++) {
                  sum += lhs[k] * rhs[k];
              }
              dst[col] = sum + data3[col];
          }
      }
  }
  46. int main() {
  47. for (int i = 0; i < 5; i++) {
  48. struct timespec t1, t2, t3;
  49. clock_gettime(CLOCK_MONOTONIC, &t1);
  50. r_256_3072_768(b52, b49, h_0_mlp_c_fc_weight, h_0_mlp_c_fc_bias);
  51. clock_gettime(CLOCK_MONOTONIC, &t2);
  52. matmul_forward(b52, b49, h_0_mlp_c_fc_weight, h_0_mlp_c_fc_bias, 4, 64, 768, 3072);
  53. clock_gettime(CLOCK_MONOTONIC, &t3);
  54. double time_gen = (t2.tv_sec - t1.tv_sec) + (t2.tv_nsec - t1.tv_nsec) / 1e9;
  55. double time_real = (t3.tv_sec - t2.tv_sec) + (t3.tv_nsec - t2.tv_nsec) / 1e9;
  56. printf("%.2f ms gen vs %.2f ms reference\n", time_gen*1e3, time_real*1e3);
  57. }
  58. }