1 // See LICENSE for license details.
// Smaller of a and b. Both arguments and the whole expansion are
// parenthesized, but the selected argument is evaluated a second time, so do
// not pass expressions with side effects.
9 #define MIN(a, b) ((a) < (b) ? (a) : (b))
11 static void mm_naive(size_t m
, size_t n
, size_t p
,
12 t
* a
, size_t lda
, t
* b
, size_t ldb
, t
* c
, size_t ldc
)
14 for (size_t i
= 0; i
< m
; i
++)
16 for (size_t j
= 0; j
< n
; j
++)
18 t s0
= c
[i
*ldc
+j
], s1
= 0, s2
= 0, s3
= 0;
19 for (size_t k
= 0; k
< p
/4*4; k
+=4)
21 s0
= fma(a
[i
*lda
+k
+0], b
[(k
+0)*ldb
+j
], s0
);
22 s1
= fma(a
[i
*lda
+k
+1], b
[(k
+1)*ldb
+j
], s1
);
23 s2
= fma(a
[i
*lda
+k
+2], b
[(k
+2)*ldb
+j
], s2
);
24 s3
= fma(a
[i
*lda
+k
+3], b
[(k
+3)*ldb
+j
], s3
);
26 for (size_t k
= p
/4*4; k
< p
; k
++)
27 s0
= fma(a
[i
*lda
+k
], b
[k
*ldb
+j
], s0
);
28 c
[i
*ldc
+j
] = (s0
+ s1
) + (s2
+ s3
);
33 static inline void mm_rb(size_t m
, size_t n
, size_t p
,
34 t
* a
, size_t lda
, t
* b
, size_t ldb
, t
* c
, size_t ldc
)
36 size_t mb
= m
/RBM
*RBM
, nb
= n
/RBN
*RBN
;
37 for (size_t i
= 0; i
< mb
; i
+= RBM
)
39 for (size_t j
= 0; j
< nb
; j
+= RBN
)
40 kloop(p
, a
+i
*lda
, lda
, b
+j
, ldb
, c
+i
*ldc
+j
, ldc
);
41 mm_naive(RBM
, n
- nb
, p
, a
+i
*lda
, lda
, b
+nb
, ldb
, c
+i
*ldc
+nb
, ldc
);
43 mm_naive(m
- mb
, n
, p
, a
+mb
*lda
, lda
, b
, ldb
, c
+mb
*ldc
, ldc
);
// Move an m-row by p-column block from a0 (row stride lda0) into a (row
// stride lda). Columns below p/8*8 are handled eight at a time; the rest are
// copied one element per iteration by the tail loop.
// NOTE(review): for the eight-wide group only the loads into t0..t7 are
// visible in this view; the matching stores are presumably
// a[i*lda+j+0..7] = t0..t7 -- confirm against the complete file.
46 static inline void repack(t
* a
, size_t lda
, const t
* a0
, size_t lda0
, size_t m
, size_t p
)
48 for (size_t i
= 0; i
< m
; i
++)
// Eight columns per step: read a0[i*lda0+j+0..7] into t0..t7.
50 for (size_t j
= 0; j
< p
/8*8; j
+=8)
52 t t0
= a0
[i
*lda0
+j
+0];
53 t t1
= a0
[i
*lda0
+j
+1];
54 t t2
= a0
[i
*lda0
+j
+2];
55 t t3
= a0
[i
*lda0
+j
+3];
56 t t4
= a0
[i
*lda0
+j
+4];
57 t t5
= a0
[i
*lda0
+j
+5];
58 t t6
= a0
[i
*lda0
+j
+6];
59 t t7
= a0
[i
*lda0
+j
+7];
// Tail: the p%8 columns from p/8*8 up to p, copied directly.
69 for (size_t j
= p
/8*8; j
< p
; j
++)
70 a
[i
*lda
+j
] = a0
[i
*lda0
+j
];
// Cache-blocked driver working on CBM x CBN x CBK tiles.
//
//   m, n, p        - problem extents; they are forwarded in the (m, n, p)
//                    positions of the mm_rb calls below
//   a, b, c        - matrices addressed as base + row*ld + column, with row
//                    strides lda, ldb and ldc
//
// Steps visible here:
//   1. Round m, n and p down to whole tiles (mb, nb, pb).
//   2. Obtain scratch buffers a1, b1, c1 and repack the whole-tile parts of
//      a, b and c into them, one tile after another.
//   3. Walk the (i, j, k) tile loop nest on the packed copies, then pass the
//      leftover p - pb, n - nb and m - mb extents to mm_rb.
//   4. Repack c1 back into c.
//
// NOTE(review): several lines of step 3 (the callee of the per-tile call, some
// a/b argument lines and any conditions guarding the leftover calls) are not
// visible in this view; confirm them against the complete file before relying
// on the notes below.
74 static void mm_cb(size_t m
, size_t n
, size_t p
,
75 t
* a
, size_t lda
, t
* b
, size_t ldb
, t
* c
, size_t ldc
)
// Number of whole tiles along m, n and p (integer division drops the rest).
77 size_t nmb
= m
/CBM
, nnb
= n
/CBN
, npb
= p
/CBK
;
// Extents covered by those whole tiles.
78 size_t mb
= nmb
*CBM
, nb
= nnb
*CBN
, pb
= npb
*CBK
;
79 //t a1[mb*pb], b1[pb*nb], c1[mb*nb];
// Scratch buffers holding mb*pb, pb*nb and mb*nb elements.
// NOTE(review): alloca_aligned is defined elsewhere; presumably stack storage
// aligned to 8192 bytes -- confirm.
80 t
* a1
= (t
*)alloca_aligned(sizeof(t
)*mb
*pb
, 8192);
81 t
* b1
= (t
*)alloca_aligned(sizeof(t
)*pb
*nb
, 8192);
82 t
* c1
= (t
*)alloca_aligned(sizeof(t
)*mb
*nb
, 8192);
// Pack a: the CBM x CBK tile at (i, j) goes to tile slot npb*(i/CBM) + j/CBK
// of a1, stored with row stride CBK.
84 for (size_t i
= 0; i
< mb
; i
+= CBM
)
85 for (size_t j
= 0; j
< pb
; j
+= CBK
)
86 repack(a1
+ (npb
*(i
/CBM
) + j
/CBK
)*(CBM
*CBK
), CBK
, a
+ i
*lda
+ j
, lda
, CBM
, CBK
);
// Pack b: the CBK x CBN tile at (i, j) goes to tile slot nnb*(i/CBK) + j/CBN
// of b1, stored with row stride CBN.
88 for (size_t i
= 0; i
< pb
; i
+= CBK
)
89 for (size_t j
= 0; j
< nb
; j
+= CBN
)
90 repack(b1
+ (nnb
*(i
/CBK
) + j
/CBN
)*(CBK
*CBN
), CBN
, b
+ i
*ldb
+ j
, ldb
, CBK
, CBN
);
// Pack c: the CBM x CBN tile at (i, j) goes to tile slot nnb*(i/CBM) + j/CBN
// of c1, stored with row stride CBN.
92 for (size_t i
= 0; i
< mb
; i
+= CBM
)
93 for (size_t j
= 0; j
< nb
; j
+= CBN
)
94 repack(c1
+ (nnb
*(i
/CBM
) + j
/CBN
)*(CBM
*CBN
), CBN
, c
+ i
*ldc
+ j
, ldc
, CBM
, CBN
);
// Tile loop nest over (i, j, k) on the whole-tile region.
96 for (size_t i
= 0; i
< mb
; i
+= CBM
)
98 for (size_t j
= 0; j
< nb
; j
+= CBN
)
100 for (size_t k
= 0; k
< pb
; k
+= CBK
)
// Trailing arguments of the per-tile call: packed a tile (i, k), packed b tile
// (k, j) and packed c tile (i, j), each with its packed row stride.
// NOTE(review): the callee name and leading arguments are not visible here.
103 a1
+ (npb
*(i
/CBM
) + k
/CBK
)*(CBM
*CBK
), CBK
,
104 b1
+ (nnb
*(k
/CBK
) + j
/CBN
)*(CBK
*CBN
), CBN
,
105 c1
+ (nnb
*(i
/CBM
) + j
/CBN
)*(CBM
*CBN
), CBN
);
// Leftover inner extent p - pb for tile (i, j); the output argument is the
// packed c tile. NOTE(review): the a and b arguments are not visible here.
109 mm_rb(CBM
, CBN
, p
- pb
,
112 c1
+ (nnb
*(i
/CBM
) + j
/CBN
)*(CBM
*CBN
), CBN
);
// Leftover columns nb..n for the CBM rows starting at i, in k steps of CBK
// (the last step clipped by MIN). The output goes to c itself, not to c1.
117 for (size_t k
= 0; k
< p
; k
+= CBK
)
119 mm_rb(CBM
, n
- nb
, MIN(p
- k
, CBK
),
122 c
+ i
*ldc
+ nb
, ldc
);
// Leftover rows mb..m: CBN-wide column strips and CBK-deep k steps, both
// clipped by MIN at the far edge. The output goes to c itself, not to c1.
128 for (size_t j
= 0; j
< n
; j
+= CBN
)
130 for (size_t k
= 0; k
< p
; k
+= CBK
)
132 mm_rb(m
- mb
, MIN(n
- j
, CBN
), MIN(p
- k
, CBK
),
135 c
+ mb
*ldc
+ j
, ldc
);
// Unpack: repack every c1 tile back to its position in c (row stride ldc).
140 for (size_t i
= 0; i
< mb
; i
+= CBM
)
141 for (size_t j
= 0; j
< nb
; j
+= CBN
)
142 repack(c
+ i
*ldc
+ j
, ldc
, c1
+ (nnb
*(i
/CBM
) + j
/CBN
)*(CBM
*CBN
), CBN
, CBM
, CBN
);
// Externally visible entry point; it takes the same argument list as mm_rb
// and mm_cb and forwards it unchanged.
//
// When m, n and p are each at most twice the corresponding cache-block size
// (CBM, CBN, CBK) -- the outcome marked as expected through __builtin_expect --
// mm_rb is called on the operands directly. mm_cb is the cache-blocked path.
//
// NOTE(review): no `else` is visible between the two calls in this view;
// confirm that mm_cb runs only when the size test fails, otherwise small
// inputs would be processed by both paths.
145 void mm(size_t m
, size_t n
, size_t p
,
146 t
* a
, size_t lda
, t
* b
, size_t ldb
, t
* c
, size_t ldc
)
148 if (__builtin_expect(m
<= 2*CBM
&& n
<= 2*CBN
&& p
<= 2*CBK
, 1))
149 mm_rb(m
, n
, p
, a
, lda
, b
, ldb
, c
, ldc
);
151 mm_cb(m
, n
, p
, a
, lda
, b
, ldb
, c
, ldc
);