Remove Hwacha v3 tests
[riscv-tests.git] / mt / dr_matmul.c
1 #include "stdlib.h"
2
3 #include "util.h"
4
5 #include "dataset.h"
6 void __attribute__((noinline)) matmul(const int coreid, const int ncores, const int lda, const data_t A[], const data_t B[], data_t C[] )
7 {
8
9 // ***************************** //
10 // **** ADD YOUR CODE HERE ***** //
11 // ***************************** //
12 //
13 // feel free to make a separate function for MI and MSI versions.
14 int j2, i2, k2, j, i, k;
15 int tmpC00, tmpC01, tmpC02, tmpC03, tmpC04, tmpC05, tmpC06, tmpC07;
16 int tmpC10, tmpC11, tmpC12, tmpC13, tmpC14, tmpC15, tmpC16, tmpC17;
17 int jBLOCK = 32;
18 int iBLOCK = 16;
19 int kBLOCK = 32;
20 static __thread int tB[4096]; //__thread
21 int startInd = coreid*(lda/ncores);
22 int endInd = (coreid+1)*(lda/ncores);
23
24 //tranpose B (block?)
25 for (i = 0; i < lda; i += 2) {
26 for (j = startInd; j < endInd; j += 2) {
27 tB[j*lda + i] = B[i*lda + j];
28 tB[(j + 1)*lda + i] = B[i*lda + j + 1];
29 tB[j*lda + i + 1] = B[(i + 1)*lda + j];
30 tB[(j + 1)*lda + i + 1] = B[(i + 1)*lda + j + 1];
31 }
32 barrier(ncores);
33 }
34
35 // compute C[j*n + i] += A[j*n + k] + Btranspose[i*n + k]
36 for ( j2 = 0; j2 < lda; j2 += jBLOCK )
37 for ( i2 = startInd; i2 < endInd; i2 += iBLOCK )
38 for ( j = j2; j < j2 + jBLOCK; j += 2 )
39 for ( k2 = 0; k2 < lda; k2 += kBLOCK )
40 for ( i = i2; i < i2 + iBLOCK; i += 8) {
41 tmpC00 = C[j*lda + i + 0]; tmpC10 = C[(j + 1)*lda + i + 0];
42 tmpC01 = C[j*lda + i + 1]; tmpC11 = C[(j + 1)*lda + i + 1];
43 tmpC02 = C[j*lda + i + 2]; tmpC12 = C[(j + 1)*lda + i + 2];
44 tmpC03 = C[j*lda + i + 3]; tmpC13 = C[(j + 1)*lda + i + 3];
45 tmpC04 = C[j*lda + i + 4]; tmpC14 = C[(j + 1)*lda + i + 4];
46 tmpC05 = C[j*lda + i + 5]; tmpC15 = C[(j + 1)*lda + i + 5];
47 tmpC06 = C[j*lda + i + 6]; tmpC16 = C[(j + 1)*lda + i + 6];
48 tmpC07 = C[j*lda + i + 7]; tmpC17 = C[(j + 1)*lda + i + 7];
49 for ( k = k2; k < k2 + kBLOCK; k += 4) {
50 tmpC00 += A[j*lda + k] * tB[(i + 0)*lda + k];
51 tmpC01 += A[j*lda + k] * tB[(i + 1)*lda + k];
52 tmpC02 += A[j*lda + k] * tB[(i + 2)*lda + k];
53 tmpC03 += A[j*lda + k] * tB[(i + 3)*lda + k];
54 tmpC04 += A[j*lda + k] * tB[(i + 4)*lda + k];
55 tmpC05 += A[j*lda + k] * tB[(i + 5)*lda + k];
56 tmpC06 += A[j*lda + k] * tB[(i + 6)*lda + k];
57 tmpC07 += A[j*lda + k] * tB[(i + 7)*lda + k];
58 tmpC10 += A[(j + 1)*lda + k] * tB[(i + 0)*lda + k];
59 tmpC11 += A[(j + 1)*lda + k] * tB[(i + 1)*lda + k];
60 tmpC12 += A[(j + 1)*lda + k] * tB[(i + 2)*lda + k];
61 tmpC13 += A[(j + 1)*lda + k] * tB[(i + 3)*lda + k];
62 tmpC14 += A[(j + 1)*lda + k] * tB[(i + 4)*lda + k];
63 tmpC15 += A[(j + 1)*lda + k] * tB[(i + 5)*lda + k];
64 tmpC16 += A[(j + 1)*lda + k] * tB[(i + 6)*lda + k];
65 tmpC17 += A[(j + 1)*lda + k] * tB[(i + 7)*lda + k];
66
67 tmpC00 += A[j*lda + k + 1] * tB[(i + 0)*lda + k + 1];
68 tmpC01 += A[j*lda + k + 1] * tB[(i + 1)*lda + k + 1];
69 tmpC02 += A[j*lda + k + 1] * tB[(i + 2)*lda + k + 1];
70 tmpC03 += A[j*lda + k + 1] * tB[(i + 3)*lda + k + 1];
71 tmpC04 += A[j*lda + k + 1] * tB[(i + 4)*lda + k + 1];
72 tmpC05 += A[j*lda + k + 1] * tB[(i + 5)*lda + k + 1];
73 tmpC06 += A[j*lda + k + 1] * tB[(i + 6)*lda + k + 1];
74 tmpC07 += A[j*lda + k + 1] * tB[(i + 7)*lda + k + 1];
75 tmpC10 += A[(j + 1)*lda + k + 1] * tB[(i + 0)*lda + k + 1];
76 tmpC11 += A[(j + 1)*lda + k + 1] * tB[(i + 1)*lda + k + 1];
77 tmpC12 += A[(j + 1)*lda + k + 1] * tB[(i + 2)*lda + k + 1];
78 tmpC13 += A[(j + 1)*lda + k + 1] * tB[(i + 3)*lda + k + 1];
79 tmpC14 += A[(j + 1)*lda + k + 1] * tB[(i + 4)*lda + k + 1];
80 tmpC15 += A[(j + 1)*lda + k + 1] * tB[(i + 5)*lda + k + 1];
81 tmpC16 += A[(j + 1)*lda + k + 1] * tB[(i + 6)*lda + k + 1];
82 tmpC17 += A[(j + 1)*lda + k + 1] * tB[(i + 7)*lda + k + 1];
83
84 tmpC00 += A[j*lda + k + 2] * tB[(i + 0)*lda + k + 2];
85 tmpC01 += A[j*lda + k + 2] * tB[(i + 1)*lda + k + 2];
86 tmpC02 += A[j*lda + k + 2] * tB[(i + 2)*lda + k + 2];
87 tmpC03 += A[j*lda + k + 2] * tB[(i + 3)*lda + k + 2];
88 tmpC04 += A[j*lda + k + 2] * tB[(i + 4)*lda + k + 2];
89 tmpC05 += A[j*lda + k + 2] * tB[(i + 5)*lda + k + 2];
90 tmpC06 += A[j*lda + k + 2] * tB[(i + 6)*lda + k + 2];
91 tmpC07 += A[j*lda + k + 2] * tB[(i + 7)*lda + k + 2];
92 tmpC10 += A[(j + 1)*lda + k + 2] * tB[(i + 0)*lda + k + 2];
93 tmpC11 += A[(j + 1)*lda + k + 2] * tB[(i + 1)*lda + k + 2];
94 tmpC12 += A[(j + 1)*lda + k + 2] * tB[(i + 2)*lda + k + 2];
95 tmpC13 += A[(j + 1)*lda + k + 2] * tB[(i + 3)*lda + k + 2];
96 tmpC14 += A[(j + 1)*lda + k + 2] * tB[(i + 4)*lda + k + 2];
97 tmpC15 += A[(j + 1)*lda + k + 2] * tB[(i + 5)*lda + k + 2];
98 tmpC16 += A[(j + 1)*lda + k + 2] * tB[(i + 6)*lda + k + 2];
99 tmpC17 += A[(j + 1)*lda + k + 2] * tB[(i + 7)*lda + k + 2];
100
101 tmpC00 += A[j*lda + k + 3] * tB[(i + 0)*lda + k + 3];
102 tmpC01 += A[j*lda + k + 3] * tB[(i + 1)*lda + k + 3];
103 tmpC02 += A[j*lda + k + 3] * tB[(i + 2)*lda + k + 3];
104 tmpC03 += A[j*lda + k + 3] * tB[(i + 3)*lda + k + 3];
105 tmpC04 += A[j*lda + k + 3] * tB[(i + 4)*lda + k + 3];
106 tmpC05 += A[j*lda + k + 3] * tB[(i + 5)*lda + k + 3];
107 tmpC06 += A[j*lda + k + 3] * tB[(i + 6)*lda + k + 3];
108 tmpC07 += A[j*lda + k + 3] * tB[(i + 7)*lda + k + 3];
109 tmpC10 += A[(j + 1)*lda + k + 3] * tB[(i + 0)*lda + k + 3];
110 tmpC11 += A[(j + 1)*lda + k + 3] * tB[(i + 1)*lda + k + 3];
111 tmpC12 += A[(j + 1)*lda + k + 3] * tB[(i + 2)*lda + k + 3];
112 tmpC13 += A[(j + 1)*lda + k + 3] * tB[(i + 3)*lda + k + 3];
113 tmpC14 += A[(j + 1)*lda + k + 3] * tB[(i + 4)*lda + k + 3];
114 tmpC15 += A[(j + 1)*lda + k + 3] * tB[(i + 5)*lda + k + 3];
115 tmpC16 += A[(j + 1)*lda + k + 3] * tB[(i + 6)*lda + k + 3];
116 tmpC17 += A[(j + 1)*lda + k + 3] * tB[(i + 7)*lda + k + 3];
117 }
118 C[j*lda + i + 0] = tmpC00; C[(j + 1)*lda + i + 0] = tmpC10;
119 C[j*lda + i + 1] = tmpC01; C[(j + 1)*lda + i + 1] = tmpC11;
120 C[j*lda + i + 2] = tmpC02; C[(j + 1)*lda + i + 2] = tmpC12;
121 C[j*lda + i + 3] = tmpC03; C[(j + 1)*lda + i + 3] = tmpC13;
122 C[j*lda + i + 4] = tmpC04; C[(j + 1)*lda + i + 4] = tmpC14;
123 C[j*lda + i + 5] = tmpC05; C[(j + 1)*lda + i + 5] = tmpC15;
124 C[j*lda + i + 6] = tmpC06; C[(j + 1)*lda + i + 6] = tmpC16;
125 C[j*lda + i + 7] = tmpC07; C[(j + 1)*lda + i + 7] = tmpC17;
126 barrier(ncores);
127 }
128 }