Merge pull request #8 from riscv/sqrt-171
[riscv-tests.git] / benchmarks / spmv / spmv_gendata.scala
1 #!/usr/bin/env scala
2 !#
3
// Command-line arguments: <rows> <cols> <approx_nnz>.
// Fail with a usage message instead of an ArrayIndexOutOfBoundsException
// when arguments are missing; extra arguments are ignored as before.
if (args.length < 3) {
  System.err.println("usage: spmv_gendata.scala <rows> <cols> <approx_nnz>")
  sys.exit(1)
}
val m = args(0).toInt
val n = args(1).toInt
val approx_nnz = args(2).toInt
7
// Probability that any given (row, col) entry is nonzero. The product m*n is
// taken in Long: as an Int it overflows for large matrices (e.g. 50000x50000)
// and would yield a negative or otherwise wrong density.
val pnnz = approx_nnz.toDouble / (m.toLong * n)
// Column index of every generated nonzero, in row-major order.
val idx = collection.mutable.ArrayBuffer[Int]()
// Row pointers: the nonzeros of row i are entries p(i) until p(i+1) of idx.
val p = collection.mutable.ArrayBuffer(0)

for (i <- 0 until m) {
  for (j <- 0 until n) {
    if (util.Random.nextDouble() < pnnz)
      idx += j
  }
  p += idx.length
}
19
// Actual nonzero count; only approximately equal to approx_nnz.
val nnz = idx.length
// Dense input vector (length n), then one value per nonzero, each drawn
// uniformly from [0, 1000). v is drawn before d to keep the RNG order.
val v = Array.fill(n)(util.Random.nextInt(1000))
val d = Array.fill(nnz)(util.Random.nextInt(1000))
23
/** Prints `data` as a C array definition: `const <t> <name>[<len>] = { ... };`
  * with one element per line.
  *
  * Uses `mkString`, which (unlike `reduceLeft`) does not throw on empty input
  * (nnz can be 0) and builds the text in linear time.
  *
  * @param t    C element type to emit, e.g. "int" or "double"
  * @param name C identifier of the array
  * @param data elements to print
  */
def printVec(t: String, name: String, data: Seq[Int]) = {
  println(s"const $t $name[${data.length}] = {")
  println(" " + data.mkString(",\n "))
  println("};")
}
29
/** Reference sparse matrix-vector product y = A * v.
  *
  * A is given in compressed-row form: the nonzeros of row i are positions
  * k in [p(i), p(i+1)), with value d(k) in column idx(k).
  *
  * Each row sum is accumulated in Long and narrowed with `Math.toIntExact`,
  * so a row whose result does not fit in Int raises ArithmeticException
  * instead of silently wrapping and producing wrong verification data.
  *
  * @param p   row pointers (length rows + 1)
  * @param d   nonzero values
  * @param idx column index of each nonzero
  * @param v   dense input vector
  * @return one Int per row
  */
def spmv(p: Seq[Int], d: Seq[Int], idx: Seq[Int], v: Seq[Int]) = {
  val y = collection.mutable.ArrayBuffer[Int]()
  for (i <- 0 until p.length - 1) {
    val rowSum = (p(i) until p(i + 1)).foldLeft(0L) { (acc, k) =>
      acc + d(k).toLong * v(idx(k))
    }
    y += Math.toIntExact(rowSum)
  }
  y
}
40
// Emit the generated problem as C source: dimension macros, the sparse
// matrix arrays, the dense input vector, and the reference result last.
println(s"#define R $m")
println(s"#define C $n")
println(s"#define NNZ $nnz")
printVec("double", "val", d)
printVec("int", "idx", idx)
printVec("double", "x", v)
printVec("int", "ptr", p)
val verifyData = spmv(p, d, idx, v)
printVec("double", "verify_data", verifyData)