Rework benchmarks to run in M-mode
[riscv-tests.git] / benchmarks / mm / mm.c
1 // See LICENSE for license details.
2
3 #include "common.h"
4 #include <assert.h>
5 #include <math.h>
6 #include <stdint.h>
7 #include <alloca.h>
8
/* Smaller of a and b. Function-like macro: whichever operand is selected is
 * evaluated twice, so pass only side-effect-free expressions. */
#define MIN(a, b) ((a) < (b) ? (a) : (b))
10
11 static void mm_naive(size_t m, size_t n, size_t p,
12 t* a, size_t lda, t* b, size_t ldb, t* c, size_t ldc)
13 {
14 for (size_t i = 0; i < m; i++)
15 {
16 for (size_t j = 0; j < n; j++)
17 {
18 t s0 = c[i*ldc+j], s1 = 0, s2 = 0, s3 = 0;
19 for (size_t k = 0; k < p/4*4; k+=4)
20 {
21 s0 = fma(a[i*lda+k+0], b[(k+0)*ldb+j], s0);
22 s1 = fma(a[i*lda+k+1], b[(k+1)*ldb+j], s1);
23 s2 = fma(a[i*lda+k+2], b[(k+2)*ldb+j], s2);
24 s3 = fma(a[i*lda+k+3], b[(k+3)*ldb+j], s3);
25 }
26 for (size_t k = p/4*4; k < p; k++)
27 s0 = fma(a[i*lda+k], b[k*ldb+j], s0);
28 c[i*ldc+j] = (s0 + s1) + (s2 + s3);
29 }
30 }
31 }
32
33 static inline void mm_rb(size_t m, size_t n, size_t p,
34 t* a, size_t lda, t* b, size_t ldb, t* c, size_t ldc)
35 {
36 size_t mb = m/RBM*RBM, nb = n/RBN*RBN;
37 for (size_t i = 0; i < mb; i += RBM)
38 {
39 for (size_t j = 0; j < nb; j += RBN)
40 kloop(p, a+i*lda, lda, b+j, ldb, c+i*ldc+j, ldc);
41 mm_naive(RBM, n - nb, p, a+i*lda, lda, b+nb, ldb, c+i*ldc+nb, ldc);
42 }
43 mm_naive(m - mb, n, p, a+mb*lda, lda, b, ldb, c+mb*ldc, ldc);
44 }
45
46 static inline void repack(t* a, size_t lda, const t* a0, size_t lda0, size_t m, size_t p)
47 {
48 for (size_t i = 0; i < m; i++)
49 {
50 for (size_t j = 0; j < p/8*8; j+=8)
51 {
52 t t0 = a0[i*lda0+j+0];
53 t t1 = a0[i*lda0+j+1];
54 t t2 = a0[i*lda0+j+2];
55 t t3 = a0[i*lda0+j+3];
56 t t4 = a0[i*lda0+j+4];
57 t t5 = a0[i*lda0+j+5];
58 t t6 = a0[i*lda0+j+6];
59 t t7 = a0[i*lda0+j+7];
60 a[i*lda+j+0] = t0;
61 a[i*lda+j+1] = t1;
62 a[i*lda+j+2] = t2;
63 a[i*lda+j+3] = t3;
64 a[i*lda+j+4] = t4;
65 a[i*lda+j+5] = t5;
66 a[i*lda+j+6] = t6;
67 a[i*lda+j+7] = t7;
68 }
69 for (size_t j = p/8*8; j < p; j++)
70 a[i*lda+j] = a0[i*lda0+j];
71 }
72 }
73
74 static void mm_cb(size_t m, size_t n, size_t p,
75 t* a, size_t lda, t* b, size_t ldb, t* c, size_t ldc)
76 {
77 size_t nmb = m/CBM, nnb = n/CBN, npb = p/CBK;
78 size_t mb = nmb*CBM, nb = nnb*CBN, pb = npb*CBK;
79 //t a1[mb*pb], b1[pb*nb], c1[mb*nb];
80 t* a1 = (t*)alloca_aligned(sizeof(t)*mb*pb, 8192);
81 t* b1 = (t*)alloca_aligned(sizeof(t)*pb*nb, 8192);
82 t* c1 = (t*)alloca_aligned(sizeof(t)*mb*nb, 8192);
83
84 for (size_t i = 0; i < mb; i += CBM)
85 for (size_t j = 0; j < pb; j += CBK)
86 repack(a1 + (npb*(i/CBM) + j/CBK)*(CBM*CBK), CBK, a + i*lda + j, lda, CBM, CBK);
87
88 for (size_t i = 0; i < pb; i += CBK)
89 for (size_t j = 0; j < nb; j += CBN)
90 repack(b1 + (nnb*(i/CBK) + j/CBN)*(CBK*CBN), CBN, b + i*ldb + j, ldb, CBK, CBN);
91
92 for (size_t i = 0; i < mb; i += CBM)
93 for (size_t j = 0; j < nb; j += CBN)
94 repack(c1 + (nnb*(i/CBM) + j/CBN)*(CBM*CBN), CBN, c + i*ldc + j, ldc, CBM, CBN);
95
96 for (size_t i = 0; i < mb; i += CBM)
97 {
98 for (size_t j = 0; j < nb; j += CBN)
99 {
100 for (size_t k = 0; k < pb; k += CBK)
101 {
102 mm_rb(CBM, CBN, CBK,
103 a1 + (npb*(i/CBM) + k/CBK)*(CBM*CBK), CBK,
104 b1 + (nnb*(k/CBK) + j/CBN)*(CBK*CBN), CBN,
105 c1 + (nnb*(i/CBM) + j/CBN)*(CBM*CBN), CBN);
106 }
107 if (pb < p)
108 {
109 mm_rb(CBM, CBN, p - pb,
110 a + i*lda + pb, lda,
111 b + pb*ldb + j, ldb,
112 c1 + (nnb*(i/CBM) + j/CBN)*(CBM*CBN), CBN);
113 }
114 }
115 if (nb < n)
116 {
117 for (size_t k = 0; k < p; k += CBK)
118 {
119 mm_rb(CBM, n - nb, MIN(p - k, CBK),
120 a + i*lda + k, lda,
121 b + k*ldb + nb, ldb,
122 c + i*ldc + nb, ldc);
123 }
124 }
125 }
126 if (mb < m)
127 {
128 for (size_t j = 0; j < n; j += CBN)
129 {
130 for (size_t k = 0; k < p; k += CBK)
131 {
132 mm_rb(m - mb, MIN(n - j, CBN), MIN(p - k, CBK),
133 a + mb*lda + k, lda,
134 b + k*ldb + j, ldb,
135 c + mb*ldc + j, ldc);
136 }
137 }
138 }
139
140 for (size_t i = 0; i < mb; i += CBM)
141 for (size_t j = 0; j < nb; j += CBN)
142 repack(c + i*ldc + j, ldc, c1 + (nnb*(i/CBM) + j/CBN)*(CBM*CBN), CBN, CBM, CBN);
143 }
144
145 void mm(size_t m, size_t n, size_t p,
146 t* a, size_t lda, t* b, size_t ldb, t* c, size_t ldc)
147 {
148 if (__builtin_expect(m <= 2*CBM && n <= 2*CBN && p <= 2*CBK, 1))
149 mm_rb(m, n, p, a, lda, b, ldb, c, ldc);
150 else
151 mm_cb(m, n, p, a, lda, b, ldb, c, ldc);
152 }