From c40d5678a4bbd73bde711149004206e37176661b Mon Sep 17 00:00:00 2001 From: Tim King Date: Tue, 22 Feb 2011 01:13:56 +0000 Subject: [PATCH] - Adds column based iterators. --- src/theory/arith/arithvar_set.h | 8 +++- src/theory/arith/row_vector.cpp | 71 ++++++++++++++++++++++++++++----- src/theory/arith/row_vector.h | 17 ++++++-- src/theory/arith/simplex.cpp | 67 ++++++++++++++++++++++++------- src/theory/arith/tableau.cpp | 17 +++++--- src/theory/arith/tableau.h | 33 ++++++++++++++- 6 files changed, 176 insertions(+), 37 deletions(-) diff --git a/src/theory/arith/arithvar_set.h b/src/theory/arith/arithvar_set.h index de215696e..ff75b373a 100644 --- a/src/theory/arith/arithvar_set.h +++ b/src/theory/arith/arithvar_set.h @@ -37,8 +37,9 @@ namespace arith { */ class ArithVarSet { -private: +public: typedef std::vector VarList; +private: //List of the ArithVars in the set. VarList d_list; @@ -49,7 +50,7 @@ private: public: typedef VarList::const_iterator iterator; - ArithVarSet() : d_list(), d_posVector() {} + ArithVarSet() : d_list(), d_posVector() {} size_t size() const { return d_list.size(); @@ -95,6 +96,9 @@ public: iterator begin() const{ return d_list.begin(); } iterator end() const{ return d_list.end(); } + const VarList& getList() const{ + return d_list; + } /** Invalidates iterators */ void remove(ArithVar x){ diff --git a/src/theory/arith/row_vector.cpp b/src/theory/arith/row_vector.cpp index 2af03bf08..2463adf47 100644 --- a/src/theory/arith/row_vector.cpp +++ b/src/theory/arith/row_vector.cpp @@ -29,6 +29,18 @@ RowVector::~RowVector(){ Assert(d_rowCount[v] >= 1); --(d_rowCount[v]); } + + Assert(matchingCounts()); +} + +bool RowVector::matchingCounts() const{ + for(NonZeroIterator i=beginNonZero(), end=endNonZero(); i != end; ++i){ + ArithVar v = getArithVar(*i); + if(d_columnMatrix[v].size() != d_rowCount[v]){ + return false; + } + } + return true; } bool RowVector::noZeroCoefficients(const VarCoeffArray& arr){ @@ -61,8 +73,9 @@ void RowVector::zip(const std::vector< ArithVar >& variables, RowVector::RowVector(const std::vector< ArithVar >& variables, const std::vector< Rational >& coefficients, - std::vector& counts): - d_rowCount(counts) + std::vector& counts, + std::vector& cm): + d_rowCount(counts), d_columnMatrix(cm) { zip(variables, coefficients, d_entries); @@ -94,7 +107,9 @@ void RowVector::merge(VarCoeffArray& arr, ArithVarContainsSet& contains, const VarCoeffArray& other, const Rational& c, - std::vector& counts){ + std::vector& counts, + std::vector& columnMatrix, + ArithVar basic){ VarCoeffArray copy = arr; arr.clear(); @@ -109,7 +124,11 @@ void RowVector::merge(VarCoeffArray& arr, arr.push_back(*curr1); ++curr1; }else if(getArithVar(*curr1) > getArithVar(*curr2)){ + ++counts[getArithVar(*curr2)]; + if(basic != ARITHVAR_SENTINEL){ + columnMatrix[getArithVar(*curr2)].add(basic); + } addArithVar(contains, getArithVar(*curr2)); arr.push_back( make_pair(getArithVar(*curr2), c * getCoefficient(*curr2))); @@ -118,12 +137,15 @@ void RowVector::merge(VarCoeffArray& arr, Rational res = getCoefficient(*curr1) + c * getCoefficient(*curr2); if(res != 0){ //The variable is not new so the count stays the same - //bug: ++counts[getArithVar(*curr2)]; arr.push_back(make_pair(getArithVar(*curr1), res)); }else{ removeArithVar(contains, getArithVar(*curr2)); + --counts[getArithVar(*curr2)]; + if(basic != ARITHVAR_SENTINEL){ + columnMatrix[getArithVar(*curr2)].remove(basic); + } } ++curr1; ++curr2; @@ -135,6 +157,9 @@ void RowVector::merge(VarCoeffArray& arr, } while(curr2 != end2){ ++counts[getArithVar(*curr2)]; + if(basic != ARITHVAR_SENTINEL){ + columnMatrix[getArithVar(*curr2)].add(basic); + } addArithVar(contains, getArithVar(*curr2)); @@ -151,10 +176,10 @@ void RowVector::multiply(const Rational& c){ } } -void RowVector::addRowTimesConstant(const Rational& c, const RowVector& other){ +void RowVector::addRowTimesConstant(const Rational& c, const RowVector& other, ArithVar basic){ Assert(c != 0); - merge(d_entries, d_contains, other.d_entries, c, d_rowCount); + merge(d_entries, d_contains, other.d_entries, c, d_rowCount, d_columnMatrix, basic); } void RowVector::printRow(){ @@ -165,18 +190,27 @@ void RowVector::printRow(){ Debug("row::print") << std::endl; } + ReducedRowVector::ReducedRowVector(ArithVar basic, const std::vector& variables, const std::vector& coefficients, - std::vector& count): - RowVector(variables, coefficients, count), d_basic(basic){ + std::vector& count, + std::vector& columnMatrix): + RowVector(variables, coefficients, count, columnMatrix), d_basic(basic){ + for(NonZeroIterator i=beginNonZero(), end=endNonZero(); i != end; ++i){ + //basic is not yet in d_entries + Assert(getArithVar(*i) != d_basic); + d_columnMatrix[getArithVar(*i)].add(d_basic); + } + VarCoeffArray justBasic; justBasic.push_back(make_pair(basic, Rational(-1))); - merge(d_entries, d_contains, justBasic, Rational(1), d_rowCount); + merge(d_entries, d_contains, justBasic, Rational(1), d_rowCount, d_columnMatrix, d_basic); + Assert(matchingCounts()); Assert(wellFormed()); Assert(d_rowCount[d_basic] == 1); } @@ -190,10 +224,12 @@ void ReducedRowVector::substitute(const ReducedRowVector& row_s){ Rational a_rs = lookup(x_s); Assert(a_rs != 0); - addRowTimesConstant(a_rs, row_s); + addRowTimesConstant(a_rs, row_s, basic()); + Assert(!has(x_s)); Assert(wellFormed()); + Assert(matchingCounts()); Assert(d_rowCount[basic()] == 1); } @@ -202,8 +238,15 @@ void ReducedRowVector::pivot(ArithVar x_j){ Assert(basic() != x_j); Rational negInverseA_rs = -(lookup(x_j).inverse()); multiply(negInverseA_rs); + + for(NonZeroIterator i=beginNonZero(), end=endNonZero(); i != end; ++i){ + d_columnMatrix[getArithVar(*i)].remove(d_basic); + d_columnMatrix[getArithVar(*i)].add(x_j); + } + d_basic = x_j; + Assert(matchingCounts()); Assert(wellFormed()); //The invariant Assert(d_rowCount[basic()] == 1); does not hold. //This is because the pivot is within the row first then @@ -249,4 +292,12 @@ ReducedRowVector::~ReducedRowVector(){ //This executes before the super classes destructor RowVector, // which will set this to 0. Assert(d_rowCount[basic()] == 1); + + NonZeroIterator curr = beginNonZero(); + NonZeroIterator end = endNonZero(); + for(;curr != end; ++curr){ + ArithVar v = getArithVar(*curr); + Assert(d_rowCount[v] >= 1); + d_columnMatrix[v].remove(basic()); + } } diff --git a/src/theory/arith/row_vector.h b/src/theory/arith/row_vector.h index 85a188063..29b79ddd5 100644 --- a/src/theory/arith/row_vector.h +++ b/src/theory/arith/row_vector.h @@ -6,6 +6,7 @@ #define __CVC4__THEORY__ARITH__ROW_VECTOR_H #include "theory/arith/arith_utilities.h" +#include "theory/arith/arithvar_set.h" #include "util/rational.h" #include @@ -52,7 +53,9 @@ public: ArithVarContainsSet& contains, const VarCoeffArray& other, const Rational& c, - std::vector& count); + std::vector& count, + std::vector& columnMatrix, + ArithVar basic); protected: /** @@ -62,6 +65,9 @@ protected: */ static bool noZeroCoefficients(const VarCoeffArray& arr); + /** Debugging code.*/ + bool matchingCounts() const; + /** * Invariants: * - isSorted(d_entries, true) @@ -76,6 +82,7 @@ protected: ArithVarContainsSet d_contains; std::vector& d_rowCount; + std::vector& d_columnMatrix; NonZeroIterator lower_bound(ArithVar x_j) const{ return std::lower_bound(d_entries.begin(), d_entries.end(), make_pair(x_j,0), cmp); @@ -87,7 +94,8 @@ public: RowVector(const std::vector< ArithVar >& variables, const std::vector< Rational >& coefficients, - std::vector& counts); + std::vector& counts, + std::vector& columnMatrix); ~RowVector(); @@ -135,7 +143,7 @@ public: * Updates the current row to be the sum of itself and * another vector times c (c != 0). */ - void addRowTimesConstant(const Rational& c, const RowVector& other); + void addRowTimesConstant(const Rational& c, const RowVector& other, ArithVar basic); void printRow(); @@ -176,7 +184,8 @@ public: ReducedRowVector(ArithVar basic, const std::vector< ArithVar >& variables, const std::vector< Rational >& coefficients, - std::vector& count); + std::vector& count, + std::vector& columnMatrix); ~ReducedRowVector(); diff --git a/src/theory/arith/simplex.cpp b/src/theory/arith/simplex.cpp index d837d7ac0..2785222e3 100644 --- a/src/theory/arith/simplex.cpp +++ b/src/theory/arith/simplex.cpp @@ -168,6 +168,37 @@ bool SimplexDecisionProcedure::AssertEquality(ArithVar x_i, const DeltaRational& return false; } +set tableauAndHasSet(Tableau& tab, ArithVar v){ + set has; + for(ArithVarSet::iterator basicIter = tab.begin(); + basicIter != tab.end(); + ++basicIter){ + ArithVar basic = *basicIter; + ReducedRowVector& row = tab.lookup(basic); + + if(row.has(v)){ + has.insert(basic); + } + } + return has; +} + +set columnIteratorSet(Tableau& tab,ArithVar v){ + set has; + ArithVarSet::iterator basicIter = tab.beginColumn(v); + ArithVarSet::iterator endIter = tab.endColumn(v); + for(; basicIter != endIter; ++basicIter){ + ArithVar basic = *basicIter; + has.insert(basic); + } + return has; +} + + +bool matchingSets(Tableau& tab, ArithVar v){ + return tableauAndHasSet(tab, v) == columnIteratorSet(tab, v); +} + void SimplexDecisionProcedure::update(ArithVar x_i, const DeltaRational& v){ Assert(!d_tableau.isBasic(x_i)); DeltaRational assignment_x_i = d_partialModel.getAssignment(x_i); @@ -177,22 +208,21 @@ void SimplexDecisionProcedure::update(ArithVar x_i, const DeltaRational& v){ << assignment_x_i << "|-> " << v << endl; DeltaRational diff = v - assignment_x_i; - for(ArithVarSet::iterator basicIter = d_tableau.begin(); - basicIter != d_tableau.end(); - ++basicIter){ + Assert(matchingSets(d_tableau, x_i)); + ArithVarSet::iterator basicIter = d_tableau.beginColumn(x_i); + ArithVarSet::iterator endIter = d_tableau.endColumn(x_i); + for(; basicIter != endIter; ++basicIter){ ArithVar x_j = *basicIter; ReducedRowVector& row_j = d_tableau.lookup(x_j); - if(row_j.has(x_i)){ - const Rational& a_ji = row_j.lookup(x_i); + Assert(row_j.has(x_i)); + const Rational& a_ji = row_j.lookup(x_i); - const DeltaRational& assignment = d_partialModel.getAssignment(x_j); - DeltaRational nAssignment = assignment+(diff * a_ji); - d_partialModel.setAssignment(x_j, nAssignment); + const DeltaRational& assignment = d_partialModel.getAssignment(x_j); + DeltaRational nAssignment = assignment+(diff * a_ji); + d_partialModel.setAssignment(x_j, nAssignment); - d_queue.enqueueIfInconsistent(x_j); - //checkBasicVariable(x_j); - } + d_queue.enqueueIfInconsistent(x_j); } d_partialModel.setAssignment(x_i, v); @@ -250,12 +280,21 @@ void SimplexDecisionProcedure::pivotAndUpdate(ArithVar x_i, ArithVar x_j, DeltaR DeltaRational tmp = d_partialModel.getAssignment(x_j) + theta; d_partialModel.setAssignment(x_j, tmp); - ArithVarSet::iterator basicIter = d_tableau.begin(), end = d_tableau.end(); - for(; basicIter != end; ++basicIter){ + + Assert(matchingSets(d_tableau, x_j)); + ArithVarSet::iterator basicIter = d_tableau.beginColumn(x_j); + ArithVarSet::iterator endIter = d_tableau.endColumn(x_j); + for(; basicIter != endIter; ++basicIter){ + + //ArithVarSet::iterator basicIter = d_tableau.begin(), end = d_tableau.end(); + //for(; basicIter != end; ++basicIter){ ArithVar x_k = *basicIter; ReducedRowVector& row_k = d_tableau.lookup(x_k); - if(x_k != x_i && row_k.has(x_j)){ + Assert(row_k.has(x_j)); + + //if(x_k != x_i && row_k.has(x_j)){ + if(x_k != x_i ){ const Rational& a_kj = row_k.lookup(x_j); DeltaRational nextAssignment = d_partialModel.getAssignment(x_k) + (theta * a_kj); d_partialModel.setAssignment(x_k, nextAssignment); diff --git a/src/theory/arith/tableau.cpp b/src/theory/arith/tableau.cpp index d318a70e6..ebf7dbee8 100644 --- a/src/theory/arith/tableau.cpp +++ b/src/theory/arith/tableau.cpp @@ -41,7 +41,7 @@ void Tableau::addRow(ArithVar basicVar, //The new basic variable cannot already be a basic variable Assert(!d_basicVariables.isMember(basicVar)); d_basicVariables.add(basicVar); - ReducedRowVector* row_current = new ReducedRowVector(basicVar,variables, coeffs,d_rowCount); + ReducedRowVector* row_current = new ReducedRowVector(basicVar,variables, coeffs,d_rowCount, d_columnMatrix); d_rowsTable[basicVar] = row_current; //A variable in the row may have been made non-basic already. @@ -90,17 +90,22 @@ void Tableau::pivot(ArithVar x_r, ArithVar x_s){ row_s->pivot(x_s); - for(ArithVarSet::iterator basicIter = begin(), endIter = end(); - basicIter != endIter; ++basicIter){ + ArithVarSet::VarList copy(getColumn(x_s).getList()); + vector::iterator basicIter = copy.begin(), endIter = copy.end(); + + for(; basicIter != endIter; ++basicIter){ ArithVar basic = *basicIter; if(basic == x_s) continue; ReducedRowVector& row_k = lookup(basic); - if(row_k.has(x_s)){ - row_k.substitute(*row_s); - } + Assert(row_k.has(x_s)); + + row_k.substitute(*row_s); } + Assert(getColumn(x_s).size() == 1); + Assert(getRowCount(x_s) == 1); } + void Tableau::printTableau(){ Debug("tableau") << "Tableau::d_activeRows" << endl; diff --git a/src/theory/arith/tableau.h b/src/theory/arith/tableau.h index 36d61ba25..27aa1305c 100644 --- a/src/theory/arith/tableau.h +++ b/src/theory/arith/tableau.h @@ -37,6 +37,10 @@ namespace CVC4 { namespace theory { namespace arith { +typedef ArithVarSet Column; + +typedef std::vector ColumnMatrix; + class Tableau { private: @@ -47,6 +51,7 @@ private: ArithVarSet d_basicVariables; std::vector d_rowCount; + ColumnMatrix d_columnMatrix; public: /** @@ -55,7 +60,8 @@ public: Tableau() : d_rowsTable(), d_basicVariables(), - d_rowCount() + d_rowCount(), + d_columnMatrix() {} ~Tableau(); @@ -67,6 +73,16 @@ public: d_basicVariables.increaseSize(); d_rowsTable.push_back(NULL); d_rowCount.push_back(0); + + d_columnMatrix.push_back(ArithVarSet()); + + //TODO replace with version of ArithVarSet that handles misses as non-entries + // not as buffer overflows + ColumnMatrix::iterator i = d_columnMatrix.begin(), end = d_columnMatrix.end(); + for(; i != end; ++i){ + Column& col = *i; + col.increaseSize(d_columnMatrix.size()); + } } bool isBasic(ArithVar v) const { @@ -81,6 +97,19 @@ public: return d_basicVariables.end(); } + const Column& getColumn(ArithVar v){ + Assert(v < d_columnMatrix.size()); + return d_columnMatrix[v]; + } + + Column::iterator beginColumn(ArithVar v){ + return getColumn(v).begin(); + } + Column::iterator endColumn(ArithVar v){ + return getColumn(v).end(); + } + + ReducedRowVector& lookup(ArithVar var){ Assert(d_basicVariables.isMember(var)); Assert(d_rowsTable[var] != NULL); @@ -90,6 +119,8 @@ public: public: uint32_t getRowCount(ArithVar x){ Assert(x < d_rowCount.size()); + AlwaysAssert(d_rowCount[x] == getColumn(x).size()); + return d_rowCount[x]; } -- 2.30.2