1 //**************************************************************************
2 // Double-precision general matrix multiplication benchmark
3 //--------------------------------------------------------------------------
8 //--------------------------------------------------------------------------
11 // Set HOST_DEBUG to 1 if you are going to compile this for a host
12 // machine (i.e. Athena/Linux) for debug purposes and set HOST_DEBUG
13 // to 0 if you are compiling with the smips-gcc toolchain.
19 // Set PREALLOCATE to 1 if you want to preallocate the benchmark
20 // function before starting stats. If you have instruction/data
21 // caches and you don't want to count the overhead of misses, then
22 // you will need to use preallocation.
28 // Set SET_STATS to 1 if you want to carve out the piece that actually
29 // does the computation.
35 //--------------------------------------------------------------------------
36 // Input/Reference Data
40 //--------------------------------------------------------------------------
// verify: element-wise comparison of two arrays of doubles.
//   n       - number of elements to compare
//   test    - values under test
//   correct - reference values
// The loop visits indices 0..n-1 and tests test[i] != correct[i], i.e. an
// exact floating-point comparison with no tolerance.
// NOTE(review): this listing omits the lines between the numbered ones
// (opening brace, presumably the declaration of i, the action taken on a
// mismatch, and the return statements), so the returned codes cannot be
// documented from here - confirm against the full file.
43 int verify( long n
, const double test
[], const double correct
[] )
46 for ( i
= 0; i
< n
; i
++ ) {
// Mismatch branch; its body is on lines not present in this listing.
47 if ( test
[i
] != correct
[i
] ) {
// printArray: prints a labelled array of doubles with printf.
//   name - label, printed with " %10s :" (padded to at least 10 characters)
//   n    - number of elements of arr to print
//   arr  - values, each printed with " %8.1f " (one decimal place)
// NOTE(review): the declaration of i and anything printed after the loop
// (e.g. a newline) are on lines omitted from this listing - confirm.
57 void printArray( char name
[], long n
, const double arr
[] )
60 printf( " %10s :", name
);
// One printf call per element, indices 0..n-1.
61 for ( i
= 0; i
< n
; i
++ )
62 printf( " %8.1f ", arr
[i
] );
// setStats: passes `enable` to the target via an inline-asm instruction.
//   enable - value supplied as the register operand of the asm statement
// The asm is compiled only when HOST_DEBUG is 0 and SET_STATS is non-zero;
// otherwise no statement of this function is visible in the listing.
// NOTE(review): presumably this switches statistics collection on and off
// around the measured region (see the SET_STATS comment in the file header);
// the meaning of cr10 is not established by this file - confirm against the
// target's documentation.
67 void setStats( int enable
)
69 #if ( !HOST_DEBUG && SET_STATS )
// Issues "mtpcr %0, cr10" with `enable` bound to %0 as an input register.
70 asm( "mtpcr %0, cr10" : : "r" (enable
) );
74 //--------------------------------------------------------------------------
75 // square_dgemm function
// square_dgemm: matrix kernel working on 3x3 output tiles.
//   n0 - dimension of the caller's matrices (indexed as row*n0 + col)
//   a0 - first input,  read as a0[i*n0+j]
//   b0 - second input, read as b0[j*n0+i] (indices swapped while copying)
//   c0 - output, written as c0[i*n0+j]
//
// Visible steps:
//   1. Copy a0 into local array a and b0, with row/column swapped, into local
//      array b, both indexed with row stride n; entries outside the n0 x n0
//      region are set to 0.
//   2. For every 3x3 tile (i, j) accumulate nine sums s00..s22, where sRC is
//      fed by products of elements from row i+R of a and row j+C of b.
//   3. Store the nine sums into local array c, then copy the n0 x n0 region
//      of c into c0.
//
// NOTE(review): several lines are omitted from this listing: the definition
// of n, the declarations of the outer i/j, some loop headers, and whatever
// advances the row pointers inside the while loop. From the step-3 loops and
// the zero fill of indices n0..n-1, n is presumably n0 rounded up to a
// multiple of 3, and the pointers presumably advance by 3 per iteration,
// which would make the result the product a0 * b0 - confirm both against the
// full file.
77 void square_dgemm( long n0
, const double a0
[], const double b0
[], double c0
[] )
// Three local arrays of n*n doubles each.
// NOTE(review): if n is a run-time value these are variable-length arrays
// with automatic storage; their total size grows with n*n - confirm it is
// acceptable for the target.
80 double a
[n
*n
], b
[n
*n
], c
[n
*n
];
// Rows 0..n0-1: copy the caller's data into the stride-n layout.
82 for (long i
= 0; i
< n0
; i
++)
85 for (j
= 0; j
< n0
; j
++)
// a keeps a0's orientation.
87 a
[i
*n
+j
] = a0
[i
*n0
+j
];
// b receives b0 with row and column indices swapped (element (j,i) of b0).
88 b
[i
*n
+j
] = b0
[j
*n0
+i
];
// Zero a and b at (i, j); the loop header that drives j for this statement
// is on a line omitted from the listing (presumably j runs from n0 to n-1).
92 a
[i
*n
+j
] = b
[i
*n
+j
] = 0;
// Rows n0..n-1: zero every column of a and b.
95 for (long i
= n0
; i
< n
; i
++)
96 for (long j
= 0; j
< n
; j
++)
97 a
[i
*n
+j
] = b
[i
*n
+j
] = 0;
// Tile loops: i and j both advance in steps of 3 over 0..n-1.
100 for (i
= 0; i
< n
; i
+=3)
102 for (j
= 0; j
< n
; j
+=3)
// Row pointers: a0/a1/a2 start at rows i, i+1, i+2 of a; b0/b1/b2 start at
// rows j, j+1, j+2 of b.
// NOTE(review): these locals reuse the names a0 and b0 of the function
// parameters, so the parameters are hidden wherever these declarations are
// in scope.
104 double *a0
= a
+ (i
+0)*n
, *b0
= b
+ (j
+0)*n
;
105 double *a1
= a
+ (i
+1)*n
, *b1
= b
+ (j
+1)*n
;
106 double *a2
= a
+ (i
+2)*n
, *b2
= b
+ (j
+2)*n
;
// Nine accumulators, one per element of the 3x3 output tile, all start at 0.
108 double s00
= 0, s01
= 0, s02
= 0;
109 double s10
= 0, s11
= 0, s12
= 0;
110 double s20
= 0, s21
= 0, s22
= 0;
// Runs while a0 is still before the start of row i+1 of a, i.e. within row i.
112 while (a0
< a
+ (i
+1)*n
)
// Load three consecutive elements from each of the six row pointers.
114 double a00
= a0
[0], a01
= a0
[1], a02
= a0
[2];
115 double b00
= b0
[0], b01
= b0
[1], b02
= b0
[2];
116 double a10
= a1
[0], a11
= a1
[1], a12
= a1
[2];
117 double b10
= b1
[0], b11
= b1
[1], b12
= b1
[2];
// Empty asm with a "memory" clobber: a compiler-level barrier between the
// loads above and the loads below. The reason for placing it here is not
// stated in the file.
118 asm ("" ::: "memory");
119 double a20
= a2
[0], a21
= a2
[1], a22
= a2
[2];
120 double b20
= b2
[0], b21
= b2
[1], b22
= b2
[2];
// sRC += aR0*bC0 + aR1*bC1 + aR2*bC2, written with explicit parentheses so
// the additions are nested right-to-left, ending with the old sRC value.
122 s00
= a00
*b00
+ (a01
*b01
+ (a02
*b02
+ s00
));
123 s01
= a00
*b10
+ (a01
*b11
+ (a02
*b12
+ s01
));
124 s02
= a00
*b20
+ (a01
*b21
+ (a02
*b22
+ s02
));
125 s10
= a10
*b00
+ (a11
*b01
+ (a12
*b02
+ s10
));
126 s11
= a10
*b10
+ (a11
*b11
+ (a12
*b12
+ s11
));
127 s12
= a10
*b20
+ (a11
*b21
+ (a12
*b22
+ s12
));
128 s20
= a20
*b00
+ (a21
*b01
+ (a22
*b02
+ s20
));
129 s21
= a20
*b10
+ (a21
*b11
+ (a22
*b12
+ s21
));
130 s22
= a20
*b20
+ (a21
*b21
+ (a22
*b22
+ s22
));
// Store the tile: sRC goes to row i+R, column j+C of c.
137 c
[(i
+0)*n
+j
+0] = s00
; c
[(i
+0)*n
+j
+1] = s01
; c
[(i
+0)*n
+j
+2] = s02
;
138 c
[(i
+1)*n
+j
+0] = s10
; c
[(i
+1)*n
+j
+1] = s11
; c
[(i
+1)*n
+j
+2] = s12
;
139 c
[(i
+2)*n
+j
+0] = s20
; c
[(i
+2)*n
+j
+1] = s21
; c
[(i
+2)*n
+j
+2] = s22
;
// Copy-out: rows 0..n0-1 of c (stride n) into c0 (stride n0).
143 for (long i
= 0; i
< n0
; i
++)
// Three columns per iteration while j < n0 - 2.
146 for (j
= 0; j
< n0
- 2; j
+=3)
148 c0
[i
*n0
+j
+0] = c
[i
*n
+j
+0];
149 c0
[i
*n0
+j
+1] = c
[i
*n
+j
+1];
150 c0
[i
*n0
+j
+2] = c
[i
*n
+j
+2];
// Single-element copy; its loop header is on an omitted line (presumably it
// finishes the columns left over by the 3-wide loop above).
153 c0
[i
*n0
+j
] = c
[i
*n
+j
];
157 //--------------------------------------------------------------------------
// main: benchmark driver.
// Visible steps: declare a DATA_SIZE*DATA_SIZE result buffer, print the two
// inputs and the reference array, call square_dgemm twice on the same
// inputs and buffer, print the results, and hand the value returned by
// verify() to finishTest().
// DATA_SIZE, input1_data, input2_data, verify_data and finishTest are not
// defined in the visible part of the file.
// NOTE(review): the preprocessor guards and any setStats() calls sit on
// lines omitted from this listing. Presumably the printArray calls are
// host-debug only and the first square_dgemm call is the PREALLOCATE warm-up
// - confirm against the full file.
160 int main( int argc
, char* argv
[] )
162 double results_data
[DATA_SIZE
*DATA_SIZE
];
164 // Output the input array
// NOTE(review): string literals are passed to printArray's non-const
// `char name[]` parameter.
167 printArray( "input1", DATA_SIZE
*DATA_SIZE
, input1_data
);
168 printArray( "input2", DATA_SIZE
*DATA_SIZE
, input2_data
);
169 printArray( "verify", DATA_SIZE
*DATA_SIZE
, verify_data
);
172 // If needed we preallocate everything in the caches
175 square_dgemm( DATA_SIZE
, input1_data
, input2_data
, results_data
);
// Second call with identical arguments; it writes results_data again.
181 square_dgemm( DATA_SIZE
, input1_data
, input2_data
, results_data
);
184 // Print out the results
187 printArray( "results", DATA_SIZE
*DATA_SIZE
, results_data
);
// verify() compares all DATA_SIZE*DATA_SIZE results against verify_data and
// its return value is passed straight to finishTest().
192 finishTest(verify( DATA_SIZE
*DATA_SIZE
, results_data
, verify_data
));