Improved qwp performance
authorClifford Wolf <clifford@clifford.at>
Thu, 24 Sep 2015 19:50:37 +0000 (21:50 +0200)
committerClifford Wolf <clifford@clifford.at>
Thu, 24 Sep 2015 19:50:37 +0000 (21:50 +0200)
passes/cmds/qwp.cc

index 2349c20ef77d85c202bf95a58fff2045603ab735..eb4c10a73a47c06aec3fd7b762f96b4501170b20 100644 (file)
@@ -203,42 +203,6 @@ struct QwpWorker
 
        void solve(bool alt_mode = false)
        {
-               int observation_matrix_m = GetSize(edges) + GetSize(nodes);
-               int observation_matrix_n = GetSize(nodes);
-
-               // Column-major order
-               vector<double> observation_matrix(observation_matrix_m * observation_matrix_n);
-               vector<double> observation_rhs_vector(observation_matrix_m);
-
-               int i = 0;
-               for (auto &edge : edges) {
-                       int idx1 = edge.first.first;
-                       int idx2 = edge.first.second;
-                       double weight = edge.second * (1.0 + xorshift32() * 1e-3);
-                       observation_matrix[i + observation_matrix_m*idx1] = weight;
-                       observation_matrix[i + observation_matrix_m*idx2] = -weight;
-                       i++;
-               }
-
-               int j = 0;
-               for (auto &node : nodes) {
-                       double weight = 1e-3;
-                       if (alt_mode ? node.alt_tied : node.tied) weight = 1e3;
-                       weight *= (1.0 + xorshift32() * 1e-3);
-                       observation_matrix[i + observation_matrix_m*j] = weight;
-                       observation_rhs_vector[i] = (alt_mode ? node.alt_pos : node.pos) * weight;
-                       i++, j++;
-               }
-
-#ifdef LOG_MATRICES
-               log("----\n");
-               for (int i = 0; i < observation_matrix_m; i++) {
-                       for (int j = 0; j < observation_matrix_n; j++)
-                               log(" %10.2e", observation_matrix[i + observation_matrix_m*j]);
-                       log(" |%9.2e\n", observation_rhs_vector[i]);
-               }
-#endif
-
                // A := observation_matrix
                // y := observation_rhs_vector
                //
@@ -248,22 +212,34 @@ struct QwpWorker
                // M := [AA Ay]
 
                // Row major order
-               vector<double> M(observation_matrix_n * (observation_matrix_n+1));
-               int N = observation_matrix_n;
+               int N = GetSize(nodes), N1 = N+1;
+               vector<double> M(N * N1);
 
-               for (int i = 0; i < N; i++)
-               for (int j = 0; j < N; j++) {
-                       double sum = 0;
-                       for (int k = 0; k < observation_matrix_m; k++)
-                               sum += observation_matrix[k + observation_matrix_m*i] * observation_matrix[k + observation_matrix_m*j];
-                       M[(N+1)*i + j] = sum;
+               for (auto &edge : edges)
+               {
+                       int idx1 = edge.first.first;
+                       int idx2 = edge.first.second;
+                       double weight = edge.second * (1.0 + xorshift32() * 1e-3);
+
+                       M[idx1 + idx1*N1] += weight * weight;
+                       M[idx2 + idx2*N1] += weight * weight;
+
+                       M[idx1 + idx2*N1] += -weight * weight;
+                       M[idx2 + idx1*N1] += -weight * weight;
                }
 
-               for (int i = 0; i < N; i++) {
-                       double sum = 0;
-                       for (int k = 0; k < observation_matrix_m; k++)
-                               sum += observation_matrix[k + observation_matrix_m*i] * observation_rhs_vector[k];
-                       M[(N+1)*i + N] = sum;
+               for (int idx = 0; idx < GetSize(nodes); idx++)
+               {
+                       auto &node = nodes[idx];
+                       double rhs = (alt_mode ? node.alt_pos : node.pos);
+
+                       double weight = 1e-3;
+                       if (alt_mode ? node.alt_tied : node.tied)
+                               weight = 1e3;
+                       weight *= (1.0 + xorshift32() * 1e-3);
+
+                       M[idx + idx*N1] += weight * weight;
+                       M[N + idx*N1] += rhs * weight * weight;
                }
 
 #ifdef LOG_MATRICES