Added pivoting to qwp solver
authorClifford Wolf <clifford@clifford.at>
Thu, 24 Sep 2015 20:16:37 +0000 (22:16 +0200)
committerClifford Wolf <clifford@clifford.at>
Thu, 24 Sep 2015 20:16:37 +0000 (22:16 +0200)
passes/cmds/qwp.cc

index eb4c10a73a47c06aec3fd7b762f96b4501170b20..f76de326a29bcefce5359416ad23a93037827c81 100644 (file)
@@ -255,33 +255,62 @@ struct QwpWorker
                // (least squares fit for "A*x = y")
                //
                // Using gaussian elimination to get M := [Id x]
-               // (no pivoting, so let's hope for the best..)
 
-               // eliminate to upper triangular matrix
+               vector<int> pivot_cache;
+               vector<int> queue;
+
+               for (int i = 0; i < N; i++)
+                       queue.push_back(i);
+
+               // gaussian elimination
                for (int i = 0; i < N; i++)
                {
+                       // find best row
+                       int best_row = queue.front();
+                       int best_row_queue_idx = 0;
+                       double best_row_absval = 0;
+
+                       for (int k = 0; k < GetSize(queue); k++) {
+                               int row = queue[k];
+                               double absval = fabs(M[i + row*N1]);
+                               if (absval > best_row_absval) {
+                                       best_row = row;
+                                       best_row_queue_idx = k;
+                                       best_row_absval = absval;
+                               }
+                       }
+
+                       int row = best_row;
+                       pivot_cache.push_back(row);
+
+                       queue[best_row_queue_idx] = queue.back();
+                       queue.pop_back();
+
                        // normalize row
-                       for (int j = i+1; j < N+1; j++)
-                               M[(N+1)*i + j] /= M[(N+1)*i + i];
-                       M[(N+1)*i + i] = 1.0;
+                       for (int k = i+1; k < N1; k++)
+                               M[k + row*N1] /= M[i + row*N1];
+                       M[i + row*N1] = 1.0;
 
                        // elimination
-                       for (int j = i+1; j < N; j++) {
-                               double d = M[(N+1)*j + i];
-                               for (int k = 0; k < N+1; k++)
-                                       if (k > i)
-                                               M[(N+1)*j + k] -= d*M[(N+1)*i + k];
-                                       else
-                                               M[(N+1)*j + k] = 0.0;
+                       for (int other_row : queue) {
+                               double d = M[i + other_row*N1];
+                               for (int k = i+1; k < N1; k++)
+                                       M[k + other_row*N1] -= d*M[k + row*N1];
+                               M[i + other_row*N1] = 0.0;
                        }
                }
 
+               log_assert(queue.empty());
+               log_assert(GetSize(pivot_cache) == N);
+
                // back substitution
                for (int i = N-1; i >= 0; i--)
                for (int j = i+1; j < N; j++)
                {
-                       M[(N+1)*i + N] -= M[(N+1)*i + j] * M[(N+1)*j + N];
-                       M[(N+1)*i + j] = 0.0;
+                       int row = pivot_cache[i];
+                       int other_row = pivot_cache[j];
+                       M[N + row*N1] -= M[j + row*N1] * M[N + other_row*N1];
+                       M[j + row*N1] = 0.0;
                }
 
 #ifdef LOG_MATRICES