From ec92c8965960fa814c3663e987bc2a7eb80965e5 Mon Sep 17 00:00:00 2001 From: Clifford Wolf Date: Thu, 24 Sep 2015 22:16:37 +0200 Subject: [PATCH] Added pivoting to qwp solver --- passes/cmds/qwp.cc | 57 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/passes/cmds/qwp.cc b/passes/cmds/qwp.cc index eb4c10a73..f76de326a 100644 --- a/passes/cmds/qwp.cc +++ b/passes/cmds/qwp.cc @@ -255,33 +255,62 @@ struct QwpWorker // (least squares fit for "A*x = y") // // Using gaussian elimination to get M := [Id x] - // (no pivoting, so let's hope for the best..) - // eliminate to upper triangular matrix + vector pivot_cache; + vector queue; + + for (int i = 0; i < N; i++) + queue.push_back(i); + + // gaussian elimination for (int i = 0; i < N; i++) { + // find best row + int best_row = queue.front(); + int best_row_queue_idx = 0; + double best_row_absval = 0; + + for (int k = 0; k < GetSize(queue); k++) { + int row = queue[k]; + double absval = fabs(M[i + row*N1]); + if (absval > best_row_absval) { + best_row = row; + best_row_queue_idx = k; + best_row_absval = absval; + } + } + + int row = best_row; + pivot_cache.push_back(row); + + queue[best_row_queue_idx] = queue.back(); + queue.pop_back(); + // normalize row - for (int j = i+1; j < N+1; j++) - M[(N+1)*i + j] /= M[(N+1)*i + i]; - M[(N+1)*i + i] = 1.0; + for (int k = i+1; k < N1; k++) + M[k + row*N1] /= M[i + row*N1]; + M[i + row*N1] = 1.0; // elimination - for (int j = i+1; j < N; j++) { - double d = M[(N+1)*j + i]; - for (int k = 0; k < N+1; k++) - if (k > i) - M[(N+1)*j + k] -= d*M[(N+1)*i + k]; - else - M[(N+1)*j + k] = 0.0; + for (int other_row : queue) { + double d = M[i + other_row*N1]; + for (int k = i+1; k < N1; k++) + M[k + other_row*N1] -= d*M[k + row*N1]; + M[i + other_row*N1] = 0.0; } } + log_assert(queue.empty()); + log_assert(GetSize(pivot_cache) == N); + // back substitution for (int i = N-1; i >= 0; i--) for (int j = i+1; j < N; j++) { - M[(N+1)*i + N] -= M[(N+1)*i + j] * M[(N+1)*j + N]; - M[(N+1)*i + j] = 0.0; + int row = pivot_cache[i]; + int other_row = pivot_cache[j]; + M[N + row*N1] -= M[j + row*N1] * M[N + other_row*N1]; + M[j + row*N1] = 0.0; } #ifdef LOG_MATRICES -- 2.30.2