/* Smaller of two values. Function-like macro: an argument may be evaluated
 * twice, so never pass expressions with side effects (e.g. i++). */
#define MIN(a, b) ((a) < (b) ? (a) : (b))
9 static void mm_naive(size_t m
, size_t n
, size_t p
,
10 t
* a
, size_t lda
, t
* b
, size_t ldb
, t
* c
, size_t ldc
)
12 for (size_t i
= 0; i
< m
; i
++)
14 for (size_t j
= 0; j
< n
; j
++)
16 t s0
= c
[i
*ldc
+j
], s1
= 0, s2
= 0, s3
= 0;
17 for (size_t k
= 0; k
< p
/4*4; k
+=4)
19 s0
= fma(a
[i
*lda
+k
+0], b
[(k
+0)*ldb
+j
], s0
);
20 s1
= fma(a
[i
*lda
+k
+1], b
[(k
+1)*ldb
+j
], s1
);
21 s2
= fma(a
[i
*lda
+k
+2], b
[(k
+2)*ldb
+j
], s2
);
22 s3
= fma(a
[i
*lda
+k
+3], b
[(k
+3)*ldb
+j
], s3
);
24 for (size_t k
= p
/4*4; k
< p
; k
++)
25 s0
= fma(a
[i
*lda
+k
], b
[k
*ldb
+j
], s0
);
26 c
[i
*ldc
+j
] = (s0
+ s1
) + (s2
+ s3
);
31 static inline void mm_rb(size_t m
, size_t n
, size_t p
,
32 t
* a
, size_t lda
, t
* b
, size_t ldb
, t
* c
, size_t ldc
)
34 size_t mb
= m
/RBM
*RBM
, nb
= n
/RBN
*RBN
;
35 for (size_t i
= 0; i
< mb
; i
+= RBM
)
37 for (size_t j
= 0; j
< nb
; j
+= RBN
)
38 kloop(p
, a
+i
*lda
, lda
, b
+j
, ldb
, c
+i
*ldc
+j
, ldc
);
39 mm_naive(RBM
, n
- nb
, p
, a
+i
*lda
, lda
, b
+nb
, ldb
, c
+i
*ldc
+nb
, ldc
);
41 mm_naive(m
- mb
, n
, p
, a
+mb
*lda
, lda
, b
, ldb
, c
+mb
*ldc
, ldc
);
// repack: copy the m x p block stored at a0 (row stride lda0) into a
// (row stride lda), row by row. a0 is only read; a is written.
// mm_cb uses it both to pack tiles into contiguous buffers and to copy
// result tiles back out.
// Each row is processed in groups of 8 columns (loaded into t0..t7),
// followed by a scalar tail loop for the remaining p % 8 columns.
// NOTE(review): this text was extracted from a listing that dropped all
// brace-only lines and the listing lines 58-66; it does not compile as shown.
44 static inline void repack(t
* a
, size_t lda
, const t
* a0
, size_t lda0
, size_t m
, size_t p
)
46 for (size_t i
= 0; i
< m
; i
++)
48 for (size_t j
= 0; j
< p
/8*8; j
+=8)
50 t t0
= a0
[i
*lda0
+j
+0];
51 t t1
= a0
[i
*lda0
+j
+1];
52 t t2
= a0
[i
*lda0
+j
+2];
53 t t3
= a0
[i
*lda0
+j
+3];
54 t t4
= a0
[i
*lda0
+j
+4];
55 t t5
= a0
[i
*lda0
+j
+5];
56 t t6
= a0
[i
*lda0
+j
+6];
57 t t7
= a0
[i
*lda0
+j
+7];
// NOTE(review): listing lines 58-66 are missing here. They presumably
// store t0..t7 into a[i*lda+j+0..7] (matching the tail loop below) --
// confirm against the upstream source before rebuilding this function.
// Tail: copy the columns left over after the groups of 8.
67 for (size_t j
= p
/8*8; j
< p
; j
++)
68 a
[i
*lda
+j
] = a0
[i
*lda0
+j
];
// mm_cb: C += A * B for the large problems that mm() does not send to
// mm_rb, i.e. when m > 2*CBM, n > 2*CBN or p > 2*CBK.
// Operand layout matches mm_rb: A is m x p (stride lda), B is p x n
// (stride ldb), C is m x n (stride ldc), row-major.
//
// Variables:
// - nmb/nnb/npb count the full CBM/CBN/CBK blocks along m/n/p.
// - mb/nb/pb are the extents covered by those full blocks.
//
// Steps:
// 1. The full-block parts of a, b and c are repacked into a1, b1 and c1.
//    Each tile is stored contiguously (row strides CBK, CBN and CBN).
//    Tile (I, K) of a1 starts at (npb*I + K)*(CBM*CBK); b1 and c1 are
//    indexed analogously.
// 2. The packed tiles are multiplied, and the k-remainder [pb, p) is added
//    into c1 with mm_rb.
// 3. The column strip [nb, n) and row strip [mb, m) are computed directly
//    in c.
// 4. The c1 tiles are copied back into c.
//
// NOTE(review): this text comes from a listing that dropped all brace-only
// lines and several content lines: 99-100, 104-106, 108-109, 111-114, 116,
// 118-119, 121-125, 127, 129, 131-132 and 134-137 by the embedded listing
// numbers. It does not compile as shown; rebuild it from the upstream source.
72 static void mm_cb(size_t m
, size_t n
, size_t p
,
73 t
* a
, size_t lda
, t
* b
, size_t ldb
, t
* c
, size_t ldc
)
75 size_t nmb
= m
/CBM
, nnb
= n
/CBN
, npb
= p
/CBK
;
76 size_t mb
= nmb
*CBM
, nb
= nnb
*CBN
, pb
= npb
*CBK
;
77 //t a1[mb*pb], b1[pb*nb], c1[mb*nb];
// NOTE(review): if alloca_aligned allocates on the stack (as the name and
// the commented-out VLA above suggest), these buffers hold whole packed
// matrices. Since only large problems reach mm_cb, this may overflow the
// stack -- confirm, and consider heap allocation.
78 t
* a1
= (t
*)alloca_aligned(sizeof(t
)*mb
*pb
, 8192);
79 t
* b1
= (t
*)alloca_aligned(sizeof(t
)*pb
*nb
, 8192);
80 t
* c1
= (t
*)alloca_aligned(sizeof(t
)*mb
*nb
, 8192);
// Pack every full CBM x CBK tile of a into a1 (tile row stride CBK).
82 for (size_t i
= 0; i
< mb
; i
+= CBM
)
83 for (size_t j
= 0; j
< pb
; j
+= CBK
)
84 repack(a1
+ (npb
*(i
/CBM
) + j
/CBK
)*(CBM
*CBK
), CBK
, a
+ i
*lda
+ j
, lda
, CBM
, CBK
);
// Pack every full CBK x CBN tile of b into b1 (tile row stride CBN).
86 for (size_t i
= 0; i
< pb
; i
+= CBK
)
87 for (size_t j
= 0; j
< nb
; j
+= CBN
)
88 repack(b1
+ (nnb
*(i
/CBK
) + j
/CBN
)*(CBK
*CBN
), CBN
, b
+ i
*ldb
+ j
, ldb
, CBK
, CBN
);
// Copy the current contents of each full CBM x CBN tile of c into c1, so
// the products below accumulate onto the existing values.
90 for (size_t i
= 0; i
< mb
; i
+= CBM
)
91 for (size_t j
= 0; j
< nb
; j
+= CBN
)
92 repack(c1
+ (nnb
*(i
/CBM
) + j
/CBN
)*(CBM
*CBN
), CBN
, c
+ i
*ldc
+ j
, ldc
, CBM
, CBN
);
// Tile loop: for each (row band i, column band j), multiply the packed a1
// and b1 tiles over k in steps of CBK into the packed c1 tile.
// NOTE(review): the callee name for the tile product (listing lines
// 99-100) is missing; only its a1/b1/c1 arguments survive below.
94 for (size_t i
= 0; i
< mb
; i
+= CBM
)
96 for (size_t j
= 0; j
< nb
; j
+= CBN
)
98 for (size_t k
= 0; k
< pb
; k
+= CBK
)
101 a1
+ (npb
*(i
/CBM
) + k
/CBK
)*(CBM
*CBK
), CBK
,
102 b1
+ (nnb
*(k
/CBK
) + j
/CBN
)*(CBK
*CBN
), CBN
,
103 c1
+ (nnb
*(i
/CBM
) + j
/CBN
)*(CBM
*CBK
), CBN
);
// k-remainder [pb, p) for this tile, accumulated into the c1 tile.
// NOTE(review): the a/b argument lines (listing 108-109) are missing.
107 mm_rb(CBM
, CBN
, p
- pb
,
110 c1
+ (nnb
*(i
/CBM
) + j
/CBN
)*(CBM
*CBN
), CBN
);
// Column strip [nb, n) of row band i: computed directly in c, in k-chunks
// of at most CBK. NOTE(review): the a/b argument lines (listing 118-119)
// are missing.
115 for (size_t k
= 0; k
< p
; k
+= CBK
)
117 mm_rb(CBM
, n
- nb
, MIN(p
- k
, CBK
),
120 c
+ i
*ldc
+ nb
, ldc
);
// Row strip [mb, m), all columns: computed directly in c in CBN-wide,
// CBK-deep chunks. NOTE(review): the a/b argument lines (listing 131-132)
// are missing.
126 for (size_t j
= 0; j
< n
; j
+= CBN
)
128 for (size_t k
= 0; k
< p
; k
+= CBK
)
130 mm_rb(m
- mb
, MIN(n
- j
, CBN
), MIN(p
- k
, CBK
),
133 c
+ mb
*ldc
+ j
, ldc
);
// Copy the accumulated c1 tiles back into their places in c.
138 for (size_t i
= 0; i
< mb
; i
+= CBM
)
139 for (size_t j
= 0; j
< nb
; j
+= CBN
)
140 repack(c
+ i
*ldc
+ j
, ldc
, c1
+ (nnb
*(i
/CBM
) + j
/CBN
)*(CBM
*CBN
), CBN
, CBM
, CBN
);
143 void mm(size_t m
, size_t n
, size_t p
,
144 t
* a
, size_t lda
, t
* b
, size_t ldb
, t
* c
, size_t ldc
)
146 if (__builtin_expect(m
<= 2*CBM
&& n
<= 2*CBN
&& p
<= 2*CBK
, 1))
147 mm_rb(m
, n
, p
, a
, lda
, b
, ldb
, c
, ldc
);
149 mm_cb(m
, n
, p
, a
, lda
, b
, ldb
, c
, ldc
);