de2740b90793ba803e5123e043eb02b463ab654f
[riscv-tests.git] / benchmarks / dgemm / dgemm_gendata.scala
1 #!/usr/bin/env scala
2 !#
3
/** Matrix dimension taken from the first command-line argument.
  *
  * A missing, non-numeric, or non-positive argument would otherwise surface as
  * an ArrayIndexOutOfBounds/NumberFormat exception, or silently produce an
  * unusable `#define DATA_SIZE 0` header, so report usage and exit instead.
  */
val size: Int = {
  val parsed = args.headOption.flatMap(s => scala.util.Try(s.trim.toInt).toOption)
  parsed.filter(_ > 0).getOrElse {
    Console.err.println("usage: dgemm_gendata.scala <size>  (size must be a positive integer)")
    sys.exit(1)
  }
}
5
/** Writes `data` to stdout as a C `const double` array definition named `name`.
  *
  * The declared length is always `DATA_SIZE*DATA_SIZE`, independent of `rows`
  * and `cols`; this script only calls it with size x size matrices.
  * Each of the `rows` output lines holds the `cols` consecutive elements of one
  * row-major slice, joined by ", ". Every line except the last gets a trailing
  * ", " so the initializer list continues across lines.
  *
  * @param name C identifier for the array
  * @param rows number of output lines (rows)
  * @param cols elements per line (columns)
  * @param data matrix elements, rendered with Scala's default Double formatting
  */
def print_matrix(name: String, rows: Int, cols: Int, data: Array[Double]) = {
  println(s"const double $name[DATA_SIZE*DATA_SIZE] = {")
  val lastRow = rows - 1
  (0 until rows).foreach { r =>
    val rowText = data.slice(r * cols, (r + 1) * cols).mkString(", ")
    val separator = if (r == lastRow) "" else ", "
    println(rowText + separator)
  }
  println("};")
}
/** Returns a `rows`*`cols` array of random whole numbers in `[0, bound)`,
  * stored as doubles (each draw is `rng.nextInt(bound)` widened to Double).
  *
  * Elements are drawn in index order, so a seeded `rng` yields reproducible
  * matrices.
  *
  * @param rows  number of rows (must be non-negative)
  * @param cols  number of columns (must be non-negative)
  * @param bound exclusive upper limit of each element; defaults to 1000, the
  *              value this script always used (must be positive, as required
  *              by `Random.nextInt`)
  * @param rng   random source; defaults to a fresh unseeded generator
  * @return row-major array of length `rows * cols`
  */
def rand_matrix(rows: Int, cols: Int, bound: Int = 1000,
                rng: scala.util.Random = new scala.util.Random): Array[Double] = {
  require(rows >= 0 && cols >= 0, s"matrix dimensions must be non-negative, got ${rows}x$cols")
  Array.fill(rows * cols)(rng.nextInt(bound).toDouble)
}
/** Computes the matrix product `a * b` for row-major matrices.
  *
  * The previous indexing (`a(i*n+l)`, `b(l*k+j)`) was only correct when
  * n == k; this version indexes each operand by its own row length.
  *
  * @param a row-major `m`x`k` matrix
  * @param b row-major `k`x`n` matrix
  * @param m rows of `a` and of the result
  * @param n columns of `b` and of the result
  * @param k shared inner dimension (columns of `a`, rows of `b`)
  * @return row-major `m`x`n` product
  * @throws IllegalArgumentException if a dimension is negative or an operand
  *         has fewer elements than its declared shape needs
  */
def matmul(a: Array[Double], b: Array[Double], m: Int, n: Int, k: Int): Array[Double] = {
  require(m >= 0 && n >= 0 && k >= 0, s"dimensions must be non-negative, got m=$m n=$n k=$k")
  require(a.length >= m * k, s"a has ${a.length} elements, need at least ${m * k} for ${m}x$k")
  require(b.length >= k * n, s"b has ${b.length} elements, need at least ${k * n} for ${k}x$n")
  val c = new Array[Double](m * n)
  for (i <- 0 until m; j <- 0 until n) {
    // Accumulate in l = 0 .. k-1 order starting from 0.0, matching the original
    // summation order so square results are bit-for-bit unchanged.
    var acc = 0.0
    for (l <- 0 until k)
      acc += a(i * k + l) * b(l * n + j)
    c(i * n + j) = acc
  }
  c
}
28
// Emit the size macro first; every array declaration printed below is sized
// as DATA_SIZE*DATA_SIZE.
println(s"#define DATA_SIZE $size")

// Two random square operands and their reference product.
val a = rand_matrix(size, size)
val b = rand_matrix(size, size)
val c = matmul(a, b, size, size, size)

// Dump both inputs, then the expected result, in that fixed order.
for ((label, matrix) <- Seq("input1_data" -> a, "input2_data" -> b, "verify_data" -> c))
  print_matrix(label, size, size, matrix)