Clean up benchmarks; support uarch-specific counters
[riscv-tests.git] / benchmarks / dgemm / dgemm_main.c
1 //**************************************************************************
2 // Double-precision general matrix multiplication benchmark
3 //--------------------------------------------------------------------------
4
5 #include "util.h"
6
7 //--------------------------------------------------------------------------
8 // Input/Reference Data
9
10 #include "dataset1.h"
11
12 //--------------------------------------------------------------------------
13 // square_dgemm function
14
//--------------------------------------------------------------------------
// square_dgemm -- compute C = A * B for square n0 x n0 matrices.
//
// All three matrices are row-major: element (i, j) lives at index i*n0 + j.
//
//   n0  dimension of the matrices. A value <= 0 is treated as an empty
//       problem: nothing is read or written.
//   a0  left operand A (read only).
//   b0  right operand B (read only).
//   c0  result C; all n0*n0 elements are overwritten.
//
// Method: A and the transpose of B are copied into scratch matrices whose
// dimension n is n0 rounded up to a multiple of 3. The extra rows/columns
// are zero-filled. The product is then formed in 3x3 tiles. Each tile is nine
// dot products of contiguous rows of the two scratch matrices. Finally the
// leading n0 x n0 part of the scratch result is copied to c0.
//
// NOTE(review): the scratch matrices are VLAs (3*n*n doubles of automatic
// storage); a very large n0 may exhaust the stack.
void square_dgemm( long n0, const double a0[], const double b0[], double c0[] )
{
  // A VLA whose size is not positive is undefined behaviour. An empty (or
  // negative) problem must therefore return before the scratch arrays exist.
  if (n0 <= 0)
  {
    return;
  }

  long n = (n0+2)/3*3;   // n0 rounded up to a multiple of the 3x3 tile size
  double a[n*n], b[n*n], c[n*n];

  // Pack: a <- A, and b <- transpose(B), so column j of B becomes the
  // contiguous row j of b. Both are zero-padded out to n x n.
  for (long i = 0; i < n0; i++)
  {
    long j;
    for (j = 0; j < n0; j++)
    {
      a[i*n+j] = a0[i*n0+j];
      b[i*n+j] = b0[j*n0+i];
    }
    for ( ; j < n; j++)
    {
      a[i*n+j] = b[i*n+j] = 0;
    }
  }
  for (long i = n0; i < n; i++)
  {
    for (long j = 0; j < n; j++)
    {
      a[i*n+j] = b[i*n+j] = 0;
    }
  }

  // Multiply: tile (i, j) is c[i..i+2][j..j+2]. It is the dot products of
  // rows i..i+2 of a with rows j..j+2 of b, consuming three k-values per
  // loop iteration. Because n is a multiple of 3, the tiles cover the
  // padded matrix exactly.
  for (long i = 0; i < n; i += 3)
  {
    for (long j = 0; j < n; j += 3)
    {
      // Row cursors, named distinctly from the parameters a0/b0 so they
      // do not shadow them.
      const double *ra0 = a + (i+0)*n, *rb0 = b + (j+0)*n;
      const double *ra1 = a + (i+1)*n, *rb1 = b + (j+1)*n;
      const double *ra2 = a + (i+2)*n, *rb2 = b + (j+2)*n;
      const double *const ra0_end = a + (i+1)*n;  // one past the end of row i

      double s00 = 0, s01 = 0, s02 = 0;
      double s10 = 0, s11 = 0, s12 = 0;
      double s20 = 0, s21 = 0, s22 = 0;

      while (ra0 < ra0_end)
      {
        double a00 = ra0[0], a01 = ra0[1], a02 = ra0[2];
        double b00 = rb0[0], b01 = rb0[1], b02 = rb0[2];
        double a10 = ra1[0], a11 = ra1[1], a12 = ra1[2];
        double b10 = rb1[0], b11 = rb1[1], b12 = rb1[2];
        // Compiler-only barrier: an empty asm with a "memory" clobber stops
        // the compiler from moving memory accesses across it. The loads
        // above are therefore kept ahead of the a2/b2 loads below.
        // NOTE(review): presumably this limits how many loads the scheduler
        // keeps live at once (register pressure) -- confirm the intent.
        // It is spelled __asm__ rather than asm so the file also builds in
        // strict ISO modes (-std=c99/-std=c11), where asm is not a keyword.
        __asm__ __volatile__ ("" ::: "memory");
        double a20 = ra2[0], a21 = ra2[1], a22 = ra2[2];
        double b20 = rb2[0], b21 = rb2[1], b22 = rb2[2];

        // Keep this exact association. The result is checked against
        // reference data in main(), and reordering the additions can change
        // the rounding (NOTE(review): assumes that check is exact).
        s00 = a00*b00 + (a01*b01 + (a02*b02 + s00));
        s01 = a00*b10 + (a01*b11 + (a02*b12 + s01));
        s02 = a00*b20 + (a01*b21 + (a02*b22 + s02));
        s10 = a10*b00 + (a11*b01 + (a12*b02 + s10));
        s11 = a10*b10 + (a11*b11 + (a12*b12 + s11));
        s12 = a10*b20 + (a11*b21 + (a12*b22 + s12));
        s20 = a20*b00 + (a21*b01 + (a22*b02 + s20));
        s21 = a20*b10 + (a21*b11 + (a22*b12 + s21));
        s22 = a20*b20 + (a21*b21 + (a22*b22 + s22));

        ra0 += 3; rb0 += 3;
        ra1 += 3; rb1 += 3;
        ra2 += 3; rb2 += 3;
      }

      c[(i+0)*n+j+0] = s00; c[(i+0)*n+j+1] = s01; c[(i+0)*n+j+2] = s02;
      c[(i+1)*n+j+0] = s10; c[(i+1)*n+j+1] = s11; c[(i+1)*n+j+2] = s12;
      c[(i+2)*n+j+0] = s20; c[(i+2)*n+j+1] = s21; c[(i+2)*n+j+2] = s22;
    }
  }

  // Unpack: copy the leading n0 x n0 block of c into c0. Elements are
  // copied three at a time per row, followed by a scalar tail.
  for (long i = 0; i < n0; i++)
  {
    long j;
    for (j = 0; j < n0 - 2; j += 3)
    {
      c0[i*n0+j+0] = c[i*n+j+0];
      c0[i*n0+j+1] = c[i*n+j+1];
      c0[i*n0+j+2] = c[i*n+j+2];
    }
    for ( ; j < n0; j++)
    {
      c0[i*n0+j] = c[i*n+j];
    }
  }
}
94
95 //--------------------------------------------------------------------------
96 // Main
97
98 int main( int argc, char* argv[] )
99 {
100 double results_data[DATA_SIZE*DATA_SIZE];
101
102 // Output the input array
103 printDoubleArray( "input1", DATA_SIZE*DATA_SIZE, input1_data );
104 printDoubleArray( "input2", DATA_SIZE*DATA_SIZE, input2_data );
105 printDoubleArray( "verify", DATA_SIZE*DATA_SIZE, verify_data );
106
107 #if PREALLOCATE
108 // If needed we preallocate everything in the caches
109 square_dgemm( DATA_SIZE, input1_data, input2_data, results_data );
110 #endif
111
112 // Do the dgemm
113 setStats(1);
114 square_dgemm( DATA_SIZE, input1_data, input2_data, results_data );
115 setStats(0);
116
117 // Print out the results
118 printDoubleArray( "results", DATA_SIZE*DATA_SIZE, results_data );
119
120 // Check the results
121 return verifyDouble( DATA_SIZE*DATA_SIZE, results_data, verify_data );
122 }