external_metal_compile_fail.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
# Reproducer for a Metal kernel build failure. Per the original note, running this
# file is expected to fail with:
# AssertionError: Error Domain=AGXMetalG15X_B0 Code=3 "Compiler encountered an internal error"
#
# `src` is the failing kernel and is the test input itself: keep it verbatim. Do not
# reformat, rename or simplify it, because any edit may stop the failure from reproducing.
# (The naming style suggests it was emitted by tinygrad's code generator -- TODO confirm.)
#
# What the kernel text shows:
#   - gid.x, gid.y, lid.x, lid.y carry the annotations /* 64 */, /* 32 */, /* 8 */, /* 16 */;
#     presumably these are the index extents. This script never dispatches the kernel.
#   - A 4-iteration loop over ridx0 accumulates into 36 floats, acc0..acc35, from 72
#     guarded loads of data1 (each load has the form `cond ? *(data1+i) : 0.0f`).
#   - The 36 accumulators are stored to data0 at offsets alu1+0 .. alu1+35.
#   - alu19 = alu8 % 7 always equals alu8 (alu8 = alu2 % 7 with alu2 = ridx0*6 >= 0), so it
#     is redundant. It is left as-is because the kernel text must stay exactly as captured.
src = """
#include <metal_stdlib>
using namespace metal;
kernel void r_64_32_8_16_4_6_6_4(device float* data0, const device float* data1,
uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
int gidx0 = gid.x; /* 64 */
int lidx2 = lid.x; /* 8 */
int gidx1 = gid.y; /* 32 */
int lidx3 = lid.y; /* 16 */
int alu0 = ((gidx0*4096)+(gidx1*16)+(lidx2*512)+lidx3);
int alu1 = ((gidx0*147456)+(gidx1*576)+(lidx2*18432)+(lidx3*36));
float acc0 = 0.0f;
float acc1 = 0.0f;
float acc2 = 0.0f;
float acc3 = 0.0f;
float acc4 = 0.0f;
float acc5 = 0.0f;
float acc6 = 0.0f;
float acc7 = 0.0f;
float acc8 = 0.0f;
float acc9 = 0.0f;
float acc10 = 0.0f;
float acc11 = 0.0f;
float acc12 = 0.0f;
float acc13 = 0.0f;
float acc14 = 0.0f;
float acc15 = 0.0f;
float acc16 = 0.0f;
float acc17 = 0.0f;
float acc18 = 0.0f;
float acc19 = 0.0f;
float acc20 = 0.0f;
float acc21 = 0.0f;
float acc22 = 0.0f;
float acc23 = 0.0f;
float acc24 = 0.0f;
float acc25 = 0.0f;
float acc26 = 0.0f;
float acc27 = 0.0f;
float acc28 = 0.0f;
float acc29 = 0.0f;
float acc30 = 0.0f;
float acc31 = 0.0f;
float acc32 = 0.0f;
float acc33 = 0.0f;
float acc34 = 0.0f;
float acc35 = 0.0f;
for (int ridx0 = 0; ridx0 < 4; ridx0++) {
int alu2 = (ridx0*6);
int alu3 = (alu2+1);
int alu4 = (alu2+2);
int alu5 = (alu2+3);
int alu6 = (alu2+4);
int alu7 = (alu2+5);
int alu8 = (alu2%7);
int alu9 = ((alu8+1)%7);
int alu10 = ((alu8+2)%7);
int alu11 = ((alu8+3)%7);
int alu12 = ((alu8+4)%7);
int alu13 = ((alu8+5)%7);
int alu14 = ((((alu0+(alu3/21))%262144)*144)+(((alu3/7)%3)*3)+(alu9*36));
int alu15 = ((((alu0+(alu4/21))%262144)*144)+(((alu4/7)%3)*3)+(alu10*36));
int alu16 = ((((alu0+(alu5/21))%262144)*144)+(((alu5/7)%3)*3)+(alu11*36));
int alu17 = ((((alu0+(alu6/21))%262144)*144)+(((alu6/7)%3)*3)+(alu12*36));
int alu18 = ((((alu0+(alu7/21))%262144)*144)+(((alu7/7)%3)*3)+(alu13*36));
int alu19 = (alu8%7);
int alu20 = ((((alu0+(alu2/21))%262144)*144)+(((alu2/7)%3)*3)+(alu19*36));
bool alu21 = ((alu2<16)&(alu13<4));
bool alu22 = ((alu2<17)&(alu12<4));
bool alu23 = ((alu2<18)&(alu11<4));
bool alu24 = ((alu2<19)&(alu10<4));
bool alu25 = ((alu2<20)&(alu9<4));
bool alu26 = ((alu2<21)&(alu19<4));
float val0 = (alu25?*(data1+alu14+1):0.0f);
float val1 = (alu25?*(data1+alu14+2):0.0f);
float val2 = (alu25?*(data1+alu14+9):0.0f);
float val3 = (alu25?*(data1+alu14+10):0.0f);
float val4 = (alu25?*(data1+alu14+11):0.0f);
float val5 = (alu25?*(data1+alu14+18):0.0f);
float val6 = (alu25?*(data1+alu14+19):0.0f);
float val7 = (alu25?*(data1+alu14+20):0.0f);
float val8 = (alu25?*(data1+alu14+27):0.0f);
float val9 = (alu25?*(data1+alu14+28):0.0f);
float val10 = (alu25?*(data1+alu14+29):0.0f);
float val11 = (alu24?*(data1+alu15+1):0.0f);
float val12 = (alu24?*(data1+alu15+2):0.0f);
float val13 = (alu24?*(data1+alu15+9):0.0f);
float val14 = (alu24?*(data1+alu15+10):0.0f);
float val15 = (alu24?*(data1+alu15+11):0.0f);
float val16 = (alu24?*(data1+alu15+18):0.0f);
float val17 = (alu24?*(data1+alu15+19):0.0f);
float val18 = (alu24?*(data1+alu15+20):0.0f);
float val19 = (alu24?*(data1+alu15+27):0.0f);
float val20 = (alu24?*(data1+alu15+28):0.0f);
float val21 = (alu24?*(data1+alu15+29):0.0f);
float val22 = (alu23?*(data1+alu16+1):0.0f);
float val23 = (alu23?*(data1+alu16+2):0.0f);
float val24 = (alu23?*(data1+alu16+9):0.0f);
float val25 = (alu23?*(data1+alu16+10):0.0f);
float val26 = (alu23?*(data1+alu16+11):0.0f);
float val27 = (alu23?*(data1+alu16+18):0.0f);
float val28 = (alu23?*(data1+alu16+19):0.0f);
float val29 = (alu23?*(data1+alu16+20):0.0f);
float val30 = (alu23?*(data1+alu16+27):0.0f);
float val31 = (alu23?*(data1+alu16+28):0.0f);
float val32 = (alu23?*(data1+alu16+29):0.0f);
float val33 = (alu22?*(data1+alu17+1):0.0f);
float val34 = (alu22?*(data1+alu17+2):0.0f);
float val35 = (alu22?*(data1+alu17+9):0.0f);
float val36 = (alu22?*(data1+alu17+10):0.0f);
float val37 = (alu22?*(data1+alu17+11):0.0f);
float val38 = (alu22?*(data1+alu17+18):0.0f);
float val39 = (alu22?*(data1+alu17+19):0.0f);
float val40 = (alu22?*(data1+alu17+20):0.0f);
float val41 = (alu22?*(data1+alu17+27):0.0f);
float val42 = (alu22?*(data1+alu17+28):0.0f);
float val43 = (alu22?*(data1+alu17+29):0.0f);
float val44 = (alu21?*(data1+alu18+1):0.0f);
float val45 = (alu21?*(data1+alu18+2):0.0f);
float val46 = (alu21?*(data1+alu18+9):0.0f);
float val47 = (alu21?*(data1+alu18+10):0.0f);
float val48 = (alu21?*(data1+alu18+11):0.0f);
float val49 = (alu21?*(data1+alu18+18):0.0f);
float val50 = (alu21?*(data1+alu18+19):0.0f);
float val51 = (alu21?*(data1+alu18+20):0.0f);
float val52 = (alu21?*(data1+alu18+27):0.0f);
float val53 = (alu21?*(data1+alu18+28):0.0f);
float val54 = (alu21?*(data1+alu18+29):0.0f);
float val55 = (alu26?*(data1+alu20+1):0.0f);
float val56 = (alu26?*(data1+alu20+2):0.0f);
float val57 = (alu26?*(data1+alu20+9):0.0f);
float val58 = (alu26?*(data1+alu20+10):0.0f);
float val59 = (alu26?*(data1+alu20+11):0.0f);
float val60 = (alu26?*(data1+alu20+18):0.0f);
float val61 = (alu26?*(data1+alu20+19):0.0f);
float val62 = (alu26?*(data1+alu20+20):0.0f);
float val63 = (alu26?*(data1+alu20+27):0.0f);
float val64 = (alu26?*(data1+alu20+28):0.0f);
float val65 = (alu26?*(data1+alu20+29):0.0f);
float val66 = (alu25?*(data1+alu14):0.0f);
float val67 = (alu24?*(data1+alu15):0.0f);
float val68 = (alu23?*(data1+alu16):0.0f);
float val69 = (alu22?*(data1+alu17):0.0f);
float val70 = (alu21?*(data1+alu18):0.0f);
float val71 = (alu26?*(data1+alu20):0.0f);
acc0 = (acc0+val71);
acc1 = (acc1+val66);
acc2 = (acc2+val67);
acc3 = (acc3+val68);
acc4 = (acc4+val69);
acc5 = (acc5+val70);
acc6 = (acc6+val57+val55);
acc7 = (acc7+val2+val0);
acc8 = (acc8+val13+val11);
acc9 = (acc9+val24+val22);
acc10 = (acc10+val35+val33);
acc11 = (acc11+val46+val44);
acc12 = (acc12+val60+val58+val56);
acc13 = (acc13+val5+val3+val1);
acc14 = (acc14+val16+val14+val12);
acc15 = (acc15+val27+val25+val23);
acc16 = (acc16+val38+val36+val34);
acc17 = (acc17+val49+val47+val45);
acc18 = (acc18+val63+val61+val59);
acc19 = (acc19+val8+val6+val4);
acc20 = (acc20+val19+val17+val15);
acc21 = (acc21+val30+val28+val26);
acc22 = (acc22+val41+val39+val37);
acc23 = (acc23+val52+val50+val48);
acc24 = (acc24+val64+val62);
acc25 = (acc25+val9+val7);
acc26 = (acc26+val20+val18);
acc27 = (acc27+val31+val29);
acc28 = (acc28+val42+val40);
acc29 = (acc29+val53+val51);
acc30 = (acc30+val65);
acc31 = (acc31+val10);
acc32 = (acc32+val21);
acc33 = (acc33+val32);
acc34 = (acc34+val43);
acc35 = (acc35+val54);
}
*(data0+alu1+1) = acc6;
*(data0+alu1+2) = acc12;
*(data0+alu1+3) = acc18;
*(data0+alu1+4) = acc24;
*(data0+alu1+5) = acc30;
*(data0+alu1+6) = acc1;
*(data0+alu1+7) = acc7;
*(data0+alu1+8) = acc13;
*(data0+alu1+9) = acc19;
*(data0+alu1+10) = acc25;
*(data0+alu1+11) = acc31;
*(data0+alu1+12) = acc2;
*(data0+alu1+13) = acc8;
*(data0+alu1+14) = acc14;
*(data0+alu1+15) = acc20;
*(data0+alu1+16) = acc26;
*(data0+alu1+17) = acc32;
*(data0+alu1+18) = acc3;
*(data0+alu1+19) = acc9;
*(data0+alu1+20) = acc15;
*(data0+alu1+21) = acc21;
*(data0+alu1+22) = acc27;
*(data0+alu1+23) = acc33;
*(data0+alu1+24) = acc4;
*(data0+alu1+25) = acc10;
*(data0+alu1+26) = acc16;
*(data0+alu1+27) = acc22;
*(data0+alu1+28) = acc28;
*(data0+alu1+29) = acc34;
*(data0+alu1+30) = acc5;
*(data0+alu1+31) = acc11;
*(data0+alu1+32) = acc17;
*(data0+alu1+33) = acc23;
*(data0+alu1+34) = acc29;
*(data0+alu1+35) = acc35;
*(data0+alu1) = acc0;
}
"""
  222. from tinygrad.runtime.ops_metal import MetalDevice, MetalCompiler, MetalProgram
  223. if __name__ == "__main__":
  224. dev = MetalDevice("METAL")
  225. lib = MetalCompiler(dev).compile(src)
  226. prg = MetalProgram(dev, "r_64_32_8_16_4_6_6_4", lib)