Updated README to recursively initialize repos
[riscv-tests.git] / benchmarks / dgemm / dgemm_main.c
1 //**************************************************************************
2 // Double-precision general matrix multiplication benchmark
3 //--------------------------------------------------------------------------
4
// Core count for this benchmark, fixed at 1 (main() runs a single
// square_dgemm call on one thread of control).
// NOTE(review): defined before util.h is included, presumably because
// util.h refers to ncores — confirm against util.h before moving it.
int ncores = 1;
6 #include "util.h"
7
8 //--------------------------------------------------------------------------
9 // Macros
10
// Build-configuration switches. Each may be overridden on the compiler
// command line (e.g. -DHOST_DEBUG=1); otherwise it defaults to 0.

// HOST_DEBUG: use 1 when building natively on a host machine (e.g. Linux)
// for debugging, so the printArray() dumps are compiled in; keep 0 when
// cross-compiling for the target (the original note names the smips-gcc
// toolchain).
#if !defined(HOST_DEBUG)
#define HOST_DEBUG 0
#endif

// PREALLOCATE: use 1 to run the benchmark function once before statistics
// collection starts. With instruction/data caches this keeps first-touch
// miss overhead out of the measured run.
#if !defined(PREALLOCATE)
#define PREALLOCATE 0
#endif

// SET_STATS: use 1 to have setStats() mark only the computation itself as
// the measured region.
#if !defined(SET_STATS)
#define SET_STATS 0
#endif
34
35 //--------------------------------------------------------------------------
36 // Input/Reference Data
37
38 #include "dataset1.h"
39
40 //--------------------------------------------------------------------------
41 // Helper functions
42
// Check the first n elements of test[] against correct[].
//
// The comparison is exact floating-point equality (no tolerance), so any
// NaN in either array counts as a mismatch.
//
// Returns 1 when all n elements match (including n <= 0) and 2 at the
// first mismatch; main() hands this code to finishTest().
int verify( long n, const double test[], const double correct[] )
{
  // Use a long index to match n: an int index would overflow (UB) for
  // n > INT_MAX.
  for ( long i = 0; i < n; i++ ) {
    if ( test[i] != correct[i] ) {
      return 2;
    }
  }
  return 1;
}
53
#if HOST_DEBUG
#include <stdio.h>
#include <stdlib.h>
// Host-debug helper: print the label `name` followed by the first n
// elements of arr on one line, each formatted as %8.1f.
// `name` is const because callers pass string literals.
void printArray( const char name[], long n, const double arr[] )
{
  printf( " %10s :", name );
  // Use a long index to match n; an int index could overflow for n > INT_MAX.
  for ( long i = 0; i < n; i++ ) {
    printf( " %8.1f ", arr[i] );
  }
  printf( "\n" );
}
#endif
66
// Mark the measured region: main() calls setStats(1) just before the timed
// square_dgemm() call and setStats(0) right after it.
// Only target builds with SET_STATS enabled emit anything. They emit a
// single `mtpcr` that writes `enable` into control register cr10.
// NOTE(review): cr10 is presumably the stats-enable register of the target
// core — confirm against the ISA/core documentation.
// Every other configuration gets an empty function.
#if ( !HOST_DEBUG && SET_STATS )
void setStats( int enable )
{
  __asm__ ( "mtpcr %0, cr10" : : "r" (enable) );
}
#else
void setStats( int enable )
{
  (void) enable;  // statistics marking compiled out
}
#endif
73
74 //--------------------------------------------------------------------------
75 // square_dgemm function
76
// square_dgemm: C = A * B for n0 x n0 row-major matrices, i.e.
//   c0[i*n0+j] = sum over k of a0[i*n0+k] * b0[k*n0+j].
//
// Approach:
//  1. Pack A into `a` and the transpose of B into `b`. Both are n x n, with
//     n = n0 rounded up to a multiple of 3, and the padding is zero-filled.
//  2. Compute the padded product `c` in 3x3 tiles. Each tile entry is a dot
//     product of a row of `a` with a row of `b`, so both operands are read
//     contiguously.
//  3. Copy the leading n0 x n0 part of `c` back into c0.
//
// Notes:
//  - The scratch arrays are VLAs holding 3*n*n doubles of automatic
//    storage, so large n0 needs a correspondingly large stack.
//  - The association order of the tile sums fixes the floating-point
//    result. Results are checked with exact equality, so keep it unchanged.
//  - n0 <= 0 is a no-op. This avoids declaring zero- or negative-length
//    VLAs, which is undefined behavior.
void square_dgemm( long n0, const double a0[], const double b0[], double c0[] )
{
  if ( n0 <= 0 )
  {
    return;
  }

  const long n = (n0 + 2) / 3 * 3;   // n0 rounded up to a multiple of 3
  double a[n*n], b[n*n], c[n*n];

  // Pack: a <- A, b <- transpose(B), zero-filling padding columns n0..n-1 ...
  for (long i = 0; i < n0; i++)
  {
    long j;
    for (j = 0; j < n0; j++)
    {
      a[i*n+j] = a0[i*n0+j];
      b[i*n+j] = b0[j*n0+i];   // row i of b is column i of B
    }
    for ( ; j < n; j++)
    {
      a[i*n+j] = b[i*n+j] = 0;
    }
  }
  // ... and zero-filling padding rows n0..n-1.
  for (long i = n0; i < n; i++)
  {
    for (long j = 0; j < n; j++)
    {
      a[i*n+j] = b[i*n+j] = 0;
    }
  }

  // Kernel: each (i, j) step computes the tile c[i..i+2][j..j+2] from rows
  // i..i+2 of a and rows j..j+2 of b, consuming 3 elements per row per pass.
  for (long i = 0; i < n; i += 3)
  {
    for (long j = 0; j < n; j += 3)
    {
      // Row cursors, named ra*/rb* so they do not shadow the a0/b0 params.
      const double *ra0 = a + (i+0)*n, *rb0 = b + (j+0)*n;
      const double *ra1 = a + (i+1)*n, *rb1 = b + (j+1)*n;
      const double *ra2 = a + (i+2)*n, *rb2 = b + (j+2)*n;
      const double *const ra0_end = a + (i+1)*n;  // one past row i of a

      double s00 = 0, s01 = 0, s02 = 0;
      double s10 = 0, s11 = 0, s12 = 0;
      double s20 = 0, s21 = 0, s22 = 0;

      // n is a multiple of 3, so stepping by 3 lands exactly on ra0_end.
      while (ra0 < ra0_end)
      {
        double a00 = ra0[0], a01 = ra0[1], a02 = ra0[2];
        double b00 = rb0[0], b01 = rb0[1], b02 = rb0[2];
        double a10 = ra1[0], a11 = ra1[1], a12 = ra1[2];
        double b10 = rb1[0], b11 = rb1[1], b12 = rb1[2];
        // Empty asm with a "memory" clobber. It emits no instruction but
        // stops the compiler from moving memory accesses across it, which
        // splits the loads into two groups (presumably a scheduling /
        // register-pressure hint). It is spelled __asm__ so it also compiles
        // under strict ISO modes such as -std=c11, where `asm` is not a
        // keyword.
        __asm__ __volatile__ ("" ::: "memory");
        double a20 = ra2[0], a21 = ra2[1], a22 = ra2[2];
        double b20 = rb2[0], b21 = rb2[1], b22 = rb2[2];

        // Fixed association order (see header note): innermost term first.
        s00 = a00*b00 + (a01*b01 + (a02*b02 + s00));
        s01 = a00*b10 + (a01*b11 + (a02*b12 + s01));
        s02 = a00*b20 + (a01*b21 + (a02*b22 + s02));
        s10 = a10*b00 + (a11*b01 + (a12*b02 + s10));
        s11 = a10*b10 + (a11*b11 + (a12*b12 + s11));
        s12 = a10*b20 + (a11*b21 + (a12*b22 + s12));
        s20 = a20*b00 + (a21*b01 + (a22*b02 + s20));
        s21 = a20*b10 + (a21*b11 + (a22*b12 + s21));
        s22 = a20*b20 + (a21*b21 + (a22*b22 + s22));

        ra0 += 3; rb0 += 3;
        ra1 += 3; rb1 += 3;
        ra2 += 3; rb2 += 3;
      }

      c[(i+0)*n+j+0] = s00; c[(i+0)*n+j+1] = s01; c[(i+0)*n+j+2] = s02;
      c[(i+1)*n+j+0] = s10; c[(i+1)*n+j+1] = s11; c[(i+1)*n+j+2] = s12;
      c[(i+2)*n+j+0] = s20; c[(i+2)*n+j+1] = s21; c[(i+2)*n+j+2] = s22;
    }
  }

  // Unpack: copy the leading n0 x n0 part of c into c0, three columns at a
  // time, followed by a scalar tail for the remaining n0 % 3 columns.
  for (long i = 0; i < n0; i++)
  {
    long j;
    for (j = 0; j < n0 - 2; j += 3)
    {
      c0[i*n0+j+0] = c[i*n+j+0];
      c0[i*n0+j+1] = c[i*n+j+1];
      c0[i*n0+j+2] = c[i*n+j+2];
    }
    for ( ; j < n0; j++)
    {
      c0[i*n0+j] = c[i*n+j];
    }
  }
}
156
157 //--------------------------------------------------------------------------
158 // Main
159
160 int main( int argc, char* argv[] )
161 {
162 double results_data[DATA_SIZE*DATA_SIZE];
163
164 // Output the input array
165
166 #if HOST_DEBUG
167 printArray( "input1", DATA_SIZE*DATA_SIZE, input1_data );
168 printArray( "input2", DATA_SIZE*DATA_SIZE, input2_data );
169 printArray( "verify", DATA_SIZE*DATA_SIZE, verify_data );
170 #endif
171
172 // If needed we preallocate everything in the caches
173
174 #if PREALLOCATE
175 square_dgemm( DATA_SIZE, input1_data, input2_data, results_data );
176 #endif
177
178 // Do the dgemm
179
180 setStats(1);
181 square_dgemm( DATA_SIZE, input1_data, input2_data, results_data );
182 setStats(0);
183
184 // Print out the results
185
186 #if HOST_DEBUG
187 printArray( "results", DATA_SIZE*DATA_SIZE, results_data );
188 #endif
189
190 // Check the results
191
192 finishTest(verify( DATA_SIZE*DATA_SIZE, results_data, verify_data ));
193
194 }