7c8ce7c9ae92450b9326c80fb2a69c231da02e1a
1 //**************************************************************************
2 // Double-precision general matrix multiplication benchmark
3 //--------------------------------------------------------------------------
5 //--------------------------------------------------------------------------
8 // Set HOST_DEBUG to 1 if you are going to compile this for a host
9 // machine (i.e., Athena/Linux) for debug purposes, and set HOST_DEBUG
10 // to 0 if you are compiling with the smips-gcc toolchain.
16 // Set PREALLOCATE to 1 if you want to preallocate the benchmark
17 // function before starting stats. If you have instruction/data
18 // caches and you don't want to count the overhead of misses, then
19 // you will need to use preallocation.
25 // Set SET_STATS to 1 if you want to carve out the piece that actually
26 // does the computation.
32 //--------------------------------------------------------------------------
33 // Input/Reference Data
37 //--------------------------------------------------------------------------
// Compare test[] against correct[] element by element over indices 0..n-1.
// The comparison is an exact `!=` on doubles; no tolerance is applied.
// NOTE(review): the declaration of i and the return statements are not
// visible in this extract, so the result codes cannot be confirmed here.
40 int verify( long n
, const double test
[], const double correct
[] )
// Visit every index in [0, n).
43 for ( i
= 0; i
< n
; i
++ ) {
// Mismatch branch: taken when the two arrays differ at index i.
44 if ( test
[i
] != correct
[i
] ) {
// Print a label followed by the n elements of arr[] using printf.
//   name - label, printed in a 10-character field followed by " :"
//   n    - number of elements of arr[] to print
//   arr  - values to print, each formatted with " %8.1f "
// NOTE(review): name is declared as non-const char[], yet the visible callers
// pass string literals; `const char name[]` would describe the usage better.
// NOTE(review): the declaration of i is not visible in this extract.
54 void printArray( char name
[], long n
, const double arr
[] )
// Label first.
57 printf( " %10s :", name
);
// Then one formatted value per element, in index order.
58 for ( i
= 0; i
< n
; i
++ )
59 printf( " %8.1f ", arr
[i
] );
// Report the test outcome and hand toHostValue to the target machine.
// A value of 1 prints the PASSED banner; the FAILED banner includes the value.
// NOTE(review): the `else` and any preprocessor guards between these
// statements are not visible in this extract -- confirm against the full file.
64 void finishTest( int toHostValue
)
67 if ( toHostValue
== 1 )
68 printf( "*** PASSED ***\n" );
70 printf( "*** FAILED *** (tohost = %d)\n", toHostValue
);
// Write the value to control register cr30 with the mtpcr instruction.
// Presumably cr30 is the "tohost" register, going by the message text above
// -- TODO confirm against the target ISA documentation.
73 asm( "mtpcr %0, cr30" : : "r" (toHostValue
) );
// Write `enable` to control register cr10 with the mtpcr instruction.
// The write is compiled in only when HOST_DEBUG is 0 and SET_STATS is non-zero;
// otherwise the visible part of this function does nothing.
// Presumably cr10 switches statistics collection on and off, going by the
// function name -- TODO confirm against the target ISA documentation.
// NOTE(review): the matching #endif is not visible in this extract.
78 void setStats( int enable
)
80 #if ( !HOST_DEBUG && SET_STATS )
81 asm( "mtpcr %0, cr10" : : "r" (enable
) );
85 //--------------------------------------------------------------------------
86 // square_dgemm function
// Multiply two n0 x n0 matrices stored with row stride n0 and write the
// product into c0 (c0 is overwritten, not accumulated into).
//   n0 - dimension of the input and output matrices
//   a0 - left operand, read as a0[i*n0+j]
//   b0 - right operand, read as b0[j*n0+i] (i.e. transposed while copying)
//   c0 - output, written as c0[i*n0+j]
// The work is done on padded local copies a, b, c with row stride n, in
// 3x3 blocks of the result.
// NOTE(review): the definition of n, the declarations of i and j outside the
// `for (long ...)` headers, the pointer advances inside the while loop and
// several loop headers/braces are not visible in this extract. The step-3
// loops and the zero fill of indices n0..n-1 suggest n is n0 rounded up to
// a multiple of 3 -- TODO confirm against the full file.
88 void square_dgemm( long n0
, const double a0
[], const double b0
[], double c0
[] )
// Three n*n variable-length arrays of doubles as local working storage.
91 double a
[n
*n
], b
[n
*n
], c
[n
*n
];
// Copy the inputs into the padded buffers, one row i at a time.
93 for (long i
= 0; i
< n0
; i
++)
96 for (j
= 0; j
< n0
; j
++)
// a keeps the layout of a0; only the row stride changes from n0 to n.
98 a
[i
*n
+j
] = a0
[i
*n0
+j
];
// b receives b0 transposed: row i of b is column i of b0, so the kernel
// below can read both operands along rows.
99 b
[i
*n
+j
] = b0
[j
*n0
+i
];
// Zero one padding element of a and b.
// NOTE(review): the loop header controlling j here is not visible;
// presumably j continues from n0 up to n -- confirm.
103 a
[i
*n
+j
] = b
[i
*n
+j
] = 0;
// Zero the padding rows n0..n-1 of a and b across all n columns.
106 for (long i
= n0
; i
< n
; i
++)
107 for (long j
= 0; j
< n
; j
++)
108 a
[i
*n
+j
] = b
[i
*n
+j
] = 0;
// Main kernel: step over the result in 3x3 blocks with top-left corner (i, j).
111 for (i
= 0; i
< n
; i
+=3)
113 for (j
= 0; j
< n
; j
+=3)
// Pointers to the starts of rows i..i+2 of a and rows j..j+2 of b.
// Note: these locals a0 and b0 shadow the function parameters a0 and b0.
115 double *a0
= a
+ (i
+0)*n
, *b0
= b
+ (j
+0)*n
;
116 double *a1
= a
+ (i
+1)*n
, *b1
= b
+ (j
+1)*n
;
117 double *a2
= a
+ (i
+2)*n
, *b2
= b
+ (j
+2)*n
;
// Nine accumulators: sRC collects (row i+R of a) dot (row j+C of b).
119 double s00
= 0, s01
= 0, s02
= 0;
120 double s10
= 0, s11
= 0, s12
= 0;
121 double s20
= 0, s21
= 0, s22
= 0;
// Run until the row-i pointer reaches the end of row i of a.
// NOTE(review): the statements that advance the six pointers are not
// visible in this extract.
123 while (a0
< a
+ (i
+1)*n
)
// Load three consecutive elements from each of the six rows.
125 double a00
= a0
[0], a01
= a0
[1], a02
= a0
[2];
126 double b00
= b0
[0], b01
= b0
[1], b02
= b0
[2];
127 double a10
= a1
[0], a11
= a1
[1], a12
= a1
[2];
128 double b10
= b1
[0], b11
= b1
[1], b12
= b1
[2];
// Empty asm with a "memory" clobber: emits no instructions but acts as a
// compiler-level memory barrier between the loads above and below.
// Presumably placed here to influence load scheduling -- TODO confirm.
129 asm ("" ::: "memory");
130 double a20
= a2
[0], a21
= a2
[1], a22
= a2
[2];
131 double b20
= b2
[0], b21
= b2
[1], b22
= b2
[2];
// Add three products to each accumulator. The parentheses fix the order of
// the floating-point additions: the old sum is added to the third product
// first, then the second, then the first.
133 s00
= a00
*b00
+ (a01
*b01
+ (a02
*b02
+ s00
));
134 s01
= a00
*b10
+ (a01
*b11
+ (a02
*b12
+ s01
));
135 s02
= a00
*b20
+ (a01
*b21
+ (a02
*b22
+ s02
));
136 s10
= a10
*b00
+ (a11
*b01
+ (a12
*b02
+ s10
));
137 s11
= a10
*b10
+ (a11
*b11
+ (a12
*b12
+ s11
));
138 s12
= a10
*b20
+ (a11
*b21
+ (a12
*b22
+ s12
));
139 s20
= a20
*b00
+ (a21
*b01
+ (a22
*b02
+ s20
));
140 s21
= a20
*b10
+ (a21
*b11
+ (a22
*b12
+ s21
));
141 s22
= a20
*b20
+ (a21
*b21
+ (a22
*b22
+ s22
));
// Store the finished 3x3 block into rows i..i+2, columns j..j+2 of c.
148 c
[(i
+0)*n
+j
+0] = s00
; c
[(i
+0)*n
+j
+1] = s01
; c
[(i
+0)*n
+j
+2] = s02
;
149 c
[(i
+1)*n
+j
+0] = s10
; c
[(i
+1)*n
+j
+1] = s11
; c
[(i
+1)*n
+j
+2] = s12
;
150 c
[(i
+2)*n
+j
+0] = s20
; c
[(i
+2)*n
+j
+1] = s21
; c
[(i
+2)*n
+j
+2] = s22
;
// Copy the n0 x n0 result out of the padded buffer c (stride n) into the
// caller's c0 (stride n0), one row i at a time.
154 for (long i
= 0; i
< n0
; i
++)
// Three columns per iteration while a full group of three still fits.
157 for (j
= 0; j
< n0
- 2; j
+=3)
159 c0
[i
*n0
+j
+0] = c
[i
*n
+j
+0];
160 c0
[i
*n0
+j
+1] = c
[i
*n
+j
+1];
161 c0
[i
*n0
+j
+2] = c
[i
*n
+j
+2];
// Single-column copy.
// NOTE(review): the loop header is not visible; presumably this handles the
// columns left over when n0 is not a multiple of 3 -- confirm.
164 c0
[i
*n0
+j
] = c
[i
*n
+j
];
168 //--------------------------------------------------------------------------
// Benchmark driver: print the data sets, run square_dgemm on
// input1_data and input2_data, print the result and check it against
// verify_data.
// NOTE(review): input1_data, input2_data, verify_data and DATA_SIZE are
// defined outside the visible lines, and the preprocessor guards and
// setStats calls around the statements below are not visible in this extract.
171 int main( int argc
, char* argv
[] )
// Buffer that receives the DATA_SIZE x DATA_SIZE product.
173 double results_data
[DATA_SIZE
*DATA_SIZE
];
175 // Output the input array
// Print both operands and the reference result, DATA_SIZE*DATA_SIZE values each.
178 printArray( "input1", DATA_SIZE
*DATA_SIZE
, input1_data
);
179 printArray( "input2", DATA_SIZE
*DATA_SIZE
, input2_data
);
180 printArray( "verify", DATA_SIZE
*DATA_SIZE
, verify_data
);
183 // If needed we preallocate everything in the caches
// First call: per the comment above, a warm-up run.
// Presumably compiled in only when PREALLOCATE is set -- TODO confirm.
186 square_dgemm( DATA_SIZE
, input1_data
, input2_data
, results_data
);
// Second call with the same arguments; it overwrites results_data.
// Presumably this is the measured run -- TODO confirm.
192 square_dgemm( DATA_SIZE
, input1_data
, input2_data
, results_data
);
195 // Print out the results
198 printArray( "results", DATA_SIZE
*DATA_SIZE
, results_data
);
// Compare the computed product with the reference data and pass the
// return value of verify straight to finishTest.
203 finishTest(verify( DATA_SIZE
*DATA_SIZE
, results_data
, verify_data
));