8abe8e6515aebd88d01a60c96d6b0dfdf4674724
[riscv-tests.git] / benchmarks / mm / mm.c
1 #include "common.h"
2 #include <assert.h>
3 #include <math.h>
4 #include <stdint.h>
5 #include <alloca.h>
6
/* Smaller of two values. Each argument may be evaluated twice, so only pass
   expressions without side effects. */
#define MIN(x, y) (((x) < (y)) ? (x) : (y))
8
9 static void mm_naive(size_t m, size_t n, size_t p,
10 t* a, size_t lda, t* b, size_t ldb, t* c, size_t ldc)
11 {
12 for (size_t i = 0; i < m; i++)
13 {
14 for (size_t j = 0; j < n; j++)
15 {
16 t s0 = c[i*ldc+j], s1 = 0, s2 = 0, s3 = 0;
17 for (size_t k = 0; k < p/4*4; k+=4)
18 {
19 s0 = fma(a[i*lda+k+0], b[(k+0)*ldb+j], s0);
20 s1 = fma(a[i*lda+k+1], b[(k+1)*ldb+j], s1);
21 s2 = fma(a[i*lda+k+2], b[(k+2)*ldb+j], s2);
22 s3 = fma(a[i*lda+k+3], b[(k+3)*ldb+j], s3);
23 }
24 for (size_t k = p/4*4; k < p; k++)
25 s0 = fma(a[i*lda+k], b[k*ldb+j], s0);
26 c[i*ldc+j] = (s0 + s1) + (s2 + s3);
27 }
28 }
29 }
30
31 static inline void mm_rb(size_t m, size_t n, size_t p,
32 t* a, size_t lda, t* b, size_t ldb, t* c, size_t ldc)
33 {
34 size_t mb = m/RBM*RBM, nb = n/RBN*RBN;
35 for (size_t i = 0; i < mb; i += RBM)
36 {
37 for (size_t j = 0; j < nb; j += RBN)
38 kloop(p, a+i*lda, lda, b+j, ldb, c+i*ldc+j, ldc);
39 mm_naive(RBM, n - nb, p, a+i*lda, lda, b+nb, ldb, c+i*ldc+nb, ldc);
40 }
41 mm_naive(m - mb, n, p, a+mb*lda, lda, b, ldb, c+mb*ldc, ldc);
42 }
43
44 static inline void repack(t* a, size_t lda, const t* a0, size_t lda0, size_t m, size_t p)
45 {
46 for (size_t i = 0; i < m; i++)
47 {
48 for (size_t j = 0; j < p/8*8; j+=8)
49 {
50 t t0 = a0[i*lda0+j+0];
51 t t1 = a0[i*lda0+j+1];
52 t t2 = a0[i*lda0+j+2];
53 t t3 = a0[i*lda0+j+3];
54 t t4 = a0[i*lda0+j+4];
55 t t5 = a0[i*lda0+j+5];
56 t t6 = a0[i*lda0+j+6];
57 t t7 = a0[i*lda0+j+7];
58 a[i*lda+j+0] = t0;
59 a[i*lda+j+1] = t1;
60 a[i*lda+j+2] = t2;
61 a[i*lda+j+3] = t3;
62 a[i*lda+j+4] = t4;
63 a[i*lda+j+5] = t5;
64 a[i*lda+j+6] = t6;
65 a[i*lda+j+7] = t7;
66 }
67 for (size_t j = p/8*8; j < p; j++)
68 a[i*lda+j] = a0[i*lda0+j];
69 }
70 }
71
72 static void mm_cb(size_t m, size_t n, size_t p,
73 t* a, size_t lda, t* b, size_t ldb, t* c, size_t ldc)
74 {
75 size_t nmb = m/CBM, nnb = n/CBN, npb = p/CBK;
76 size_t mb = nmb*CBM, nb = nnb*CBN, pb = npb*CBK;
77 //t a1[mb*pb], b1[pb*nb], c1[mb*nb];
78 t* a1 = (t*)alloca_aligned(sizeof(t)*mb*pb, 8192);
79 t* b1 = (t*)alloca_aligned(sizeof(t)*pb*nb, 8192);
80 t* c1 = (t*)alloca_aligned(sizeof(t)*mb*nb, 8192);
81
82 for (size_t i = 0; i < mb; i += CBM)
83 for (size_t j = 0; j < pb; j += CBK)
84 repack(a1 + (npb*(i/CBM) + j/CBK)*(CBM*CBK), CBK, a + i*lda + j, lda, CBM, CBK);
85
86 for (size_t i = 0; i < pb; i += CBK)
87 for (size_t j = 0; j < nb; j += CBN)
88 repack(b1 + (nnb*(i/CBK) + j/CBN)*(CBK*CBN), CBN, b + i*ldb + j, ldb, CBK, CBN);
89
90 for (size_t i = 0; i < mb; i += CBM)
91 for (size_t j = 0; j < nb; j += CBN)
92 repack(c1 + (nnb*(i/CBM) + j/CBN)*(CBM*CBN), CBN, c + i*ldc + j, ldc, CBM, CBN);
93
94 for (size_t i = 0; i < mb; i += CBM)
95 {
96 for (size_t j = 0; j < nb; j += CBN)
97 {
98 for (size_t k = 0; k < pb; k += CBK)
99 {
100 mm_rb(CBM, CBN, CBK,
101 a1 + (npb*(i/CBM) + k/CBK)*(CBM*CBK), CBK,
102 b1 + (nnb*(k/CBK) + j/CBN)*(CBK*CBN), CBN,
103 c1 + (nnb*(i/CBM) + j/CBN)*(CBM*CBN), CBN);
104 }
105 if (pb < p)
106 {
107 mm_rb(CBM, CBN, p - pb,
108 a + i*lda + pb, lda,
109 b + pb*ldb + j, ldb,
110 c1 + (nnb*(i/CBM) + j/CBN)*(CBM*CBN), CBN);
111 }
112 }
113 if (nb < n)
114 {
115 for (size_t k = 0; k < p; k += CBK)
116 {
117 mm_rb(CBM, n - nb, MIN(p - k, CBK),
118 a + i*lda + k, lda,
119 b + k*ldb + nb, ldb,
120 c + i*ldc + nb, ldc);
121 }
122 }
123 }
124 if (mb < m)
125 {
126 for (size_t j = 0; j < n; j += CBN)
127 {
128 for (size_t k = 0; k < p; k += CBK)
129 {
130 mm_rb(m - mb, MIN(n - j, CBN), MIN(p - k, CBK),
131 a + mb*lda + k, lda,
132 b + k*ldb + j, ldb,
133 c + mb*ldc + j, ldc);
134 }
135 }
136 }
137
138 for (size_t i = 0; i < mb; i += CBM)
139 for (size_t j = 0; j < nb; j += CBN)
140 repack(c + i*ldc + j, ldc, c1 + (nnb*(i/CBM) + j/CBN)*(CBM*CBN), CBN, CBM, CBN);
141 }
142
143 void mm(size_t m, size_t n, size_t p,
144 t* a, size_t lda, t* b, size_t ldb, t* c, size_t ldc)
145 {
146 if (__builtin_expect(m <= 2*CBM && n <= 2*CBN && p <= 2*CBK, 1))
147 mm_rb(m, n, p, a, lda, b, ldb, c, ldc);
148 else
149 mm_cb(m, n, p, a, lda, b, ldb, c, ldc);
150 }