joint_matrix_bfloat16.cpp

//==-------- joint_matrix_bfloat16.cpp - DPC++ joint_matrix ---------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix
// RUN: %clangxx -fsycl %s -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
// RUN: %CPU_RUN_PLACEHOLDER %t.out
// RUN: %GPU_RUN_PLACEHOLDER %t.out
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using bfloat16 = sycl::ext::oneapi::bfloat16;

// #define SG_SZ 16
#define SG_SZ 8
#define TM 8
#define TN SG_SZ
#define TK 16
#define BF16_EPSILON 0.00781250
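
// Tile shapes: each sub-group computes one TM x TN tile of C by accumulating
// TM x TK (A) by TK x TN (B) products. TN is tied to SG_SZ because the tile's
// columns are distributed across the sub-group's work-items (a sketch of the
// mapping; the exact element distribution is implementation-defined).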
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
private:
  T *mat;

public:
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
  big_matrix(T *data) : mat(data) {}
};
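
// Note: big_matrix is a thin non-owning wrapper around a raw pointer; the
// dimensions live only at the type level and are never range-checked.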
template <typename T1, typename T2, size_t M, size_t N, size_t K>
void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
                     big_matrix<T2, K / 2, N * 2> &B) {
  size_t NDRangeM = M / TM;
  size_t NDRangeN = N / TN;
  buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, K));
  buffer<bfloat16, 2> bufB(B.get_data(), range<2>(K, N));
  buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));
  auto program = [&](handler &cgh) {
    auto accC = bufC.get_access<access::mode::read_write>(cgh);
    auto accA = bufA.get_access<access::mode::read_write>(cgh);
    auto accB = bufB.get_access<access::mode::read_write>(cgh);

    cgh.parallel_for<class imatrix>(
        nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
        [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
          // The sub-matrix API has to be accessed by all work-items in a
          // sub-group: these functions are called once per sub-group, so
          // there must be no code divergence between the work-items.
          const auto global_idx = spmd_item.get_global_id(0);
          const auto global_idy = spmd_item.get_global_id(1);
          const auto sg_startx = global_idx - spmd_item.get_local_id(0);
          const auto sg_starty = global_idy - spmd_item.get_local_id(1);
          sub_group sg = spmd_item.get_sub_group();
          joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
              sub_a;
          // For B, we assume B has already been VNNI-transformed.
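          // A sketch of the assumed 2-wide VNNI packing for 16-bit types:
          // row-major B[k][n] maps to packed B[k / 2][2 * n + k % 2], so each
          // 32-bit load fetches a pair of K-adjacent bfloat16 values.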
          joint_matrix<sub_group, bfloat16, use::b, TK, TN,
                       ext::intel::experimental::matrix::layout::packed>
              sub_b;
          joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
          joint_matrix_load(sg, sub_c,
                            accC.get_pointer() + (sg_startx * TM) * N +
                                sg_starty / SG_SZ * TN,
                            N, layout::row_major);
          for (int k = 0; k < K / TK; k += 1) {
            joint_matrix_load(
                sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
                K);
            joint_matrix_load(sg, sub_b,
                              accB.get_pointer() + (k * TK / 2) * (N * 2) +
                                  sg_starty / SG_SZ * TN * 2,
                              N * 2);
            sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
          }
          joint_matrix_store(sg, sub_c,
                             accC.get_pointer() + (sg_startx * TM) * N +
                                 sg_starty / SG_SZ * TN,
                             N, layout::row_major);
        }); // parallel_for
  };
  queue q;
  auto start = std::chrono::steady_clock::now();
  auto e = q.submit(program);
  auto submit = std::chrono::steady_clock::now();
  e.wait();
  auto end = std::chrono::steady_clock::now();
  std::cout << "submit: "
            << std::chrono::duration_cast<std::chrono::milliseconds>(submit -
                                                                     start)
                   .count()
            << " ms" << std::endl;
  std::cout << "compute: "
            << std::chrono::duration_cast<std::chrono::milliseconds>(end -
                                                                     submit)
                   .count()
            << " ms" << std::endl;
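  // Caveat: "compute" is submit-to-completion time measured on the host, so
  // a first run may also include JIT compilation of the kernel. For
  // device-side timing, one could instead construct the queue with
  // property::queue::enable_profiling and query the event's profiling info.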
  // Note: the buffers' destructors run at scope exit (copy-back to host plus
  // freeing), which is slow.
}
// #define SCALE 1024
// #define SCALE 64
#define SCALE 256
static constexpr size_t MATRIX_M = TM * SCALE;
static constexpr size_t MATRIX_N = TN * SCALE;
static constexpr size_t MATRIX_K = TK * SCALE;
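
// With SCALE = 256 and SG_SZ = 8, these work out to MATRIX_M = 2048,
// MATRIX_N = 2048, and MATRIX_K = 4096.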
bfloat16 A[MATRIX_M][MATRIX_K];
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
float C[MATRIX_M][MATRIX_N];
float D[MATRIX_M][MATRIX_N];
float make_fp32(bfloat16 x) {
  // bfloat16 holds the upper 16 bits of an IEEE-754 binary32 value, so
  // widening is a 16-bit left shift into a float's bit pattern. (The original
  // read a full int out of the 2-byte object, an out-of-bounds read.)
  uint32_t y = static_cast<uint32_t>(*reinterpret_cast<uint16_t *>(&x)) << 16;
  float res;
  std::memcpy(&res, &y, sizeof(res));
  return res;
}
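
// For example, make_fp32(bfloat16(1.5f)) returns exactly 1.5f: 1.5f's bit
// pattern is 0x3FC00000, whose low 16 bits are already zero, so truncating to
// bfloat16 and widening back is lossless for such values.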
void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
                         int K) {
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      for (int k = 0; k < K; k++) {
        // Because B was assumed VNNIed, each int holds a pair of bfloat16s.
        bfloat16 *va = (bfloat16 *)(A_mem + m * K + k);
        bfloat16 *vb = (bfloat16 *)(B_mem + k * N + n);
        float acc = *((float *)(C_mem + m * N + n));
        for (int i = 0; i < 2; i++) {
          acc += (make_fp32(va[i]) * make_fp32(vb[i]));
        }
        *((float *)(C_mem + m * N + n)) = acc;
      }
    }
}
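
// Note the call-site convention: because each int spans a pair of bfloat16
// values, the caller passes MATRIX_K / 2 as K, and the inner i-loop walks the
// two packed elements of each VNNI pair.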
int main() {
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_K; j++) {
      A[i][j] = bfloat16(1.0f * (i + j));
    }
  }
  for (int i = 0; i < MATRIX_K / 2; i++) {
    for (int j = 0; j < MATRIX_N * 2; j++) {
      B[i][j] = bfloat16(2.0f * i + 3.0f * j);
    }
  }
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 1.0;
      D[i][j] = 1.0;
    }
  }
  std::cout << "M = " << MATRIX_M << ", N = " << MATRIX_N
            << ", K = " << MATRIX_K << std::endl;
  big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
  big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
  big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
  big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);
  matrix_multiply(MC, MA, MB);
  /*
  auto start = std::chrono::steady_clock::now();
  matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
                      MATRIX_N, MATRIX_K / 2);
  auto end = std::chrono::steady_clock::now();
  std::cout << "Elapsed time in milliseconds (reference): "
            << std::chrono::duration_cast<std::chrono::milliseconds>(end -
                                                                     start)
                   .count()
            << " ms" << std::endl;
  bool res = true;
  for (int i = 0; i < MATRIX_M; i++) {
    for (int j = 0; j < MATRIX_N; j++) {
      // Compare the difference of the values, not the difference of their
      // magnitudes (the original check passed for any same-magnitude pair).
      if (fabs(C[i][j] - D[i][j]) > BF16_EPSILON)
        res = false;
    }
  }
  std::cout << (res ? "passed" : "failed") << std::endl;
  return !res;
  */
  return 0;
}