//**************************************************************************
// Double-precision general matrix multiplication benchmark
//--------------------------------------------------------------------------
//--------------------------------------------------------------------------
// Input/Reference Data
//--------------------------------------------------------------------------
// square_dgemm function
// Computes c0 = a0 * b0, where a0, b0 and c0 are n0 x n0 row-major matrices
// of doubles.
//
// Strategy:
//   1. Copy a0 into scratch matrix a, and the transpose of b0 into scratch
//      matrix b. Both are n x n, with n = n0 rounded up to a multiple of 3,
//      and the extra rows/columns are zero-filled. Storing b0 transposed lets
//      every dot product below walk both operands with unit stride.
//   2. Compute c = a * b^T in 3x3 output tiles, consuming three k-steps per
//      iteration and keeping the nine partial sums in scalar locals.
//   3. Copy the top-left n0 x n0 corner of c into c0.
//
// n0 <= 0 returns immediately without touching c0. Without that guard, n
// would be <= 0 and the scratch VLAs would have a non-positive length, which
// is undefined behavior.
//
// The scratch space is three n*n-element VLAs on the stack, so a large n0 can
// overflow the stack.
void square_dgemm( long n0, const double a0[], const double b0[], double c0[] )
{
  if (n0 <= 0)
    return;

  // Smallest multiple of 3 that is >= n0, so the 3x3 tiles and the 3-wide
  // k-step cover the padded matrices exactly, with no remainder handling.
  const long n = (n0 + 2) / 3 * 3;
  double a[n*n], b[n*n], c[n*n];

  // Step 1a: rows holding real data, zero-padded on the right.
  for (long i = 0; i < n0; i++)
  {
    long j = 0;
    for (; j < n0; j++)
    {
      a[i*n + j] = a0[i*n0 + j];
      b[i*n + j] = b0[j*n0 + i];  // transpose of b0
    }
    for (; j < n; j++)
      a[i*n + j] = b[i*n + j] = 0;
  }

  // Step 1b: the padding rows are entirely zero.
  for (long i = n0; i < n; i++)
    for (long j = 0; j < n; j++)
      a[i*n + j] = b[i*n + j] = 0;

  // Step 2: c[i][j] = dot(row i of a, row j of b), one 3x3 tile at a time.
  for (long i = 0; i < n; i += 3)
  {
    for (long j = 0; j < n; j += 3)
    {
      // Rows i..i+2 of a, and rows j..j+2 of b (= columns j..j+2 of b0).
      // These are named ra*/rb* so they do not shadow the a0/b0 parameters.
      const double *ra0 = a + (i+0)*n, *rb0 = b + (j+0)*n;
      const double *ra1 = a + (i+1)*n, *rb1 = b + (j+1)*n;
      const double *ra2 = a + (i+2)*n, *rb2 = b + (j+2)*n;
      const double *const ra0_end = a + (i+1)*n;

      double s00 = 0, s01 = 0, s02 = 0;
      double s10 = 0, s11 = 0, s12 = 0;
      double s20 = 0, s21 = 0, s22 = 0;

      while (ra0 < ra0_end)
      {
        double a00 = ra0[0], a01 = ra0[1], a02 = ra0[2];
        double b00 = rb0[0], b01 = rb0[1], b02 = rb0[2];
        double a10 = ra1[0], a11 = ra1[1], a12 = ra1[2];
        double b10 = rb1[0], b11 = rb1[1], b12 = rb1[2];
#if defined(__GNUC__)
        // Compiler-only barrier: the empty template emits no instruction,
        // but the "memory" clobber keeps the compiler from moving memory
        // accesses across this point. It is presumably there to shape load
        // scheduling in this hand-tuned loop -- NOTE(review): confirm intent.
        // It is spelled __asm__ rather than asm so it also compiles under
        // strict -std=c99/-std=c11, where asm is not a keyword.
        __asm__ __volatile__ ("" ::: "memory");
#endif
        double a20 = ra2[0], a21 = ra2[1], a22 = ra2[2];
        double b20 = rb2[0], b21 = rb2[1], b22 = rb2[2];

        // The explicit parenthesization fixes the floating-point summation
        // order. Rearranging it can change rounding.
        s00 = a00*b00 + (a01*b01 + (a02*b02 + s00));
        s01 = a00*b10 + (a01*b11 + (a02*b12 + s01));
        s02 = a00*b20 + (a01*b21 + (a02*b22 + s02));
        s10 = a10*b00 + (a11*b01 + (a12*b02 + s10));
        s11 = a10*b10 + (a11*b11 + (a12*b12 + s11));
        s12 = a10*b20 + (a11*b21 + (a12*b22 + s12));
        s20 = a20*b00 + (a21*b01 + (a22*b02 + s20));
        s21 = a20*b10 + (a21*b11 + (a22*b12 + s21));
        s22 = a20*b20 + (a21*b21 + (a22*b22 + s22));

        ra0 += 3; ra1 += 3; ra2 += 3;
        rb0 += 3; rb1 += 3; rb2 += 3;
      }

      c[(i+0)*n + j+0] = s00; c[(i+0)*n + j+1] = s01; c[(i+0)*n + j+2] = s02;
      c[(i+1)*n + j+0] = s10; c[(i+1)*n + j+1] = s11; c[(i+1)*n + j+2] = s12;
      c[(i+2)*n + j+0] = s20; c[(i+2)*n + j+1] = s21; c[(i+2)*n + j+2] = s22;
    }
  }

  // Step 3: copy the n0 x n0 corner of c into c0, three columns at a time,
  // with a scalar tail for the last n0 % 3 columns.
  for (long i = 0; i < n0; i++)
  {
    long j = 0;
    for (; j < n0 - 2; j += 3)
    {
      c0[i*n0 + j + 0] = c[i*n + j + 0];
      c0[i*n0 + j + 1] = c[i*n + j + 1];
      c0[i*n0 + j + 2] = c[i*n + j + 2];
    }
    for (; j < n0; j++)
      c0[i*n0 + j] = c[i*n + j];
  }
}
//--------------------------------------------------------------------------
98 int main( int argc
, char* argv
[] )
100 double results_data
[DATA_SIZE
*DATA_SIZE
];
102 // Output the input array
103 printDoubleArray( "input1", DATA_SIZE
*DATA_SIZE
, input1_data
);
104 printDoubleArray( "input2", DATA_SIZE
*DATA_SIZE
, input2_data
);
105 printDoubleArray( "verify", DATA_SIZE
*DATA_SIZE
, verify_data
);
108 // If needed we preallocate everything in the caches
109 square_dgemm( DATA_SIZE
, input1_data
, input2_data
, results_data
);
114 square_dgemm( DATA_SIZE
, input1_data
, input2_data
, results_data
);
117 // Print out the results
118 printDoubleArray( "results", DATA_SIZE
*DATA_SIZE
, results_data
);
121 return verifyDouble( DATA_SIZE
*DATA_SIZE
, results_data
, verify_data
);