From: Tim King Date: Sun, 27 Feb 2011 18:29:38 +0000 (+0000) Subject: - Makes VarCoeffPair a class instead of a typedef of pair. This... X-Git-Tag: cvc5-1.0.0~8683 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=57fe149cf7915d721912e1d1866c31346f66e2f8;p=cvc5.git - Makes VarCoeffPair a class instead of a typedef of pair. This addresses a point Dejan brought up in the code review. --- diff --git a/src/theory/arith/row_vector.cpp b/src/theory/arith/row_vector.cpp index 78ec55c2a..090938f28 100644 --- a/src/theory/arith/row_vector.cpp +++ b/src/theory/arith/row_vector.cpp @@ -11,10 +11,10 @@ bool ReducedRowVector::isSorted(const VarCoeffArray& arr, bool strictlySorted) { if(arr.size() >= 2){ const_iterator curr = arr.begin(); const_iterator end = arr.end(); - ArithVar prev = getArithVar(*curr); + ArithVar prev = (*curr).getArithVar(); ++curr; for(;curr != end; ++curr){ - ArithVar v = getArithVar(*curr); + ArithVar v = (*curr).getArithVar(); if(strictlySorted && prev > v) return false; if(!strictlySorted && prev >= v) return false; prev = v; @@ -31,7 +31,7 @@ ReducedRowVector::~ReducedRowVector(){ const_iterator curr = begin(); const_iterator endEntries = end(); for(;curr != endEntries; ++curr){ - ArithVar v = getArithVar(*curr); + ArithVar v = (*curr).getArithVar(); Assert(d_rowCount[v] >= 1); d_columnMatrix[v].remove(basic()); --(d_rowCount[v]); @@ -43,7 +43,7 @@ ReducedRowVector::~ReducedRowVector(){ bool ReducedRowVector::matchingCounts() const{ for(const_iterator i=begin(), endEntries=end(); i != endEntries; ++i){ - ArithVar v = getArithVar(*i); + ArithVar v = (*i).getArithVar(); if(d_columnMatrix[v].size() != d_rowCount[v]){ return false; } @@ -54,7 +54,7 @@ bool ReducedRowVector::matchingCounts() const{ bool ReducedRowVector::noZeroCoefficients(const VarCoeffArray& arr){ for(const_iterator curr = arr.begin(), endEntries = arr.end(); curr != endEntries; ++curr){ - const Rational& coeff = getCoefficient(*curr); + const Rational& coeff = (*curr).getCoefficient(); if(coeff == 0) return false; } return true; @@ -74,7 +74,7 @@ void ReducedRowVector::zip(const std::vector< ArithVar >& variables, const Rational& coeff = *coeffIter; ArithVar var_i = *varIter; - output.push_back(make_pair(var_i, coeff)); + output.push_back(VarCoeffPair(var_i, coeff)); } } @@ -95,13 +95,14 @@ void ReducedRowVector::multiply(const Rational& c){ Assert(c != 0); for(iterator i = d_entries.begin(), end = d_entries.end(); i != end; ++i){ - getCoefficient(*i) *= c; + (*i).getCoefficient() *= c; } } void ReducedRowVector::addRowTimesConstant(const Rational& c, const ReducedRowVector& other){ Assert(c != 0); Assert(d_buffer.empty()); + Assert(wellFormed()); d_buffer.reserve(other.d_entries.size()); @@ -112,32 +113,34 @@ void ReducedRowVector::addRowTimesConstant(const Rational& c, const ReducedRowVe const_iterator end2 = other.d_entries.end(); while(curr1 != end1 && curr2 != end2){ - if(getArithVar(*curr1) < getArithVar(*curr2)){ + ArithVar var1 = (*curr1).getArithVar(); + ArithVar var2 = (*curr2).getArithVar(); + + if(var1 < var2){ d_buffer.push_back(*curr1); ++curr1; - }else if(getArithVar(*curr1) > getArithVar(*curr2)){ + }else if(var1 > var2){ - ++d_rowCount[getArithVar(*curr2)]; - if(d_basic != ARITHVAR_SENTINEL){ - d_columnMatrix[getArithVar(*curr2)].add(d_basic); - } + ++d_rowCount[var2]; + d_columnMatrix[var2].add(d_basic); - addArithVar(d_contains, getArithVar(*curr2)); - d_buffer.push_back( make_pair(getArithVar(*curr2), c * getCoefficient(*curr2))); + addArithVar(d_contains, var2); + const Rational& coeff2 = (*curr2).getCoefficient(); + d_buffer.push_back( VarCoeffPair(var2, c * coeff2)); ++curr2; }else{ - Rational res = getCoefficient(*curr1) + c * getCoefficient(*curr2); + Assert(var1 == var2); + const Rational& coeff1 = (*curr1).getCoefficient(); + const Rational& coeff2 = (*curr2).getCoefficient(); + Rational res = coeff1 + (c * coeff2); if(res != 0){ //The variable is not new so the count stays the same - - d_buffer.push_back(make_pair(getArithVar(*curr1), res)); + d_buffer.push_back(VarCoeffPair(var1, res)); }else{ - removeArithVar(d_contains, getArithVar(*curr2)); + removeArithVar(d_contains, var1); - --d_rowCount[getArithVar(*curr2)]; - if(d_basic != ARITHVAR_SENTINEL){ - d_columnMatrix[getArithVar(*curr2)].remove(d_basic); - } + --d_rowCount[var1]; + d_columnMatrix[var1].remove(d_basic); } ++curr1; ++curr2; @@ -148,14 +151,14 @@ void ReducedRowVector::addRowTimesConstant(const Rational& c, const ReducedRowVe ++curr1; } while(curr2 != end2){ - ++d_rowCount[getArithVar(*curr2)]; - if(d_basic != ARITHVAR_SENTINEL){ - d_columnMatrix[getArithVar(*curr2)].add(d_basic); - } + ArithVar var2 = (*curr2).getArithVar(); + const Rational& coeff2 = (*curr2).getCoefficient(); + ++d_rowCount[var2]; + d_columnMatrix[var2].add(d_basic); - addArithVar(d_contains, getArithVar(*curr2)); + addArithVar(d_contains, var2); - d_buffer.push_back(make_pair(getArithVar(*curr2), c * getCoefficient(*curr2))); + d_buffer.push_back(VarCoeffPair(var2, c * coeff2)); ++curr2; } @@ -167,8 +170,9 @@ void ReducedRowVector::addRowTimesConstant(const Rational& c, const ReducedRowVe void ReducedRowVector::printRow(){ for(const_iterator i = begin(); i != end(); ++i){ - ArithVar nb = getArithVar(*i); - Debug("row::print") << "{" << nb << "," << getCoefficient(*i) << "}"; + ArithVar nb = (*i).getArithVar(); + const Rational& coeff = (*i).getCoefficient(); + Debug("row::print") << "{" << nb << "," << coeff << "}"; } Debug("row::print") << std::endl; } @@ -182,14 +186,15 @@ ReducedRowVector::ReducedRowVector(ArithVar basic, d_basic(basic), d_rowCount(counts), d_columnMatrix(cm) { zip(variables, coefficients, d_entries); - d_entries.push_back(make_pair(basic, Rational(-1))); + d_entries.push_back(VarCoeffPair(basic, Rational(-1))); - std::sort(d_entries.begin(), d_entries.end(), cmp); + std::sort(d_entries.begin(), d_entries.end()); for(const_iterator i=begin(), endEntries=end(); i != endEntries; ++i){ - ++d_rowCount[getArithVar(*i)]; - addArithVar(d_contains, getArithVar(*i)); - d_columnMatrix[getArithVar(*i)].add(d_basic); + ArithVar var = (*i).getArithVar(); + ++d_rowCount[var]; + addArithVar(d_contains, var); + d_columnMatrix[var].add(d_basic); } Assert(isSorted(d_entries, true)); @@ -225,8 +230,9 @@ void ReducedRowVector::pivot(ArithVar x_j){ multiply(negInverseA_rs); for(const_iterator i=begin(), endEntries=end(); i != endEntries; ++i){ - d_columnMatrix[getArithVar(*i)].remove(d_basic); - d_columnMatrix[getArithVar(*i)].add(x_j); + ArithVar var = (*i).getArithVar(); + d_columnMatrix[var].remove(d_basic); + d_columnMatrix[var].add(x_j); } d_basic = x_j; @@ -243,15 +249,40 @@ Node ReducedRowVector::asEquality(const ArithVarToNodeMap& map) const{ using namespace CVC4::kind; Assert(size() >= 2); + + vector nonBasicPairs; + for(const_iterator i = begin(); i != end(); ++i){ + ArithVar nb = (*i).getArithVar(); + if(nb == basic()) continue; + Node var = (map.find(nb))->second; + Node coeff = mkRationalNode((*i).getCoefficient()); + + Node mult = NodeBuilder<2>(MULT) << coeff << var; + nonBasicPairs.push_back(mult); + } + + Node sum = Node::null(); + if(nonBasicPairs.size() == 1 ){ + sum = nonBasicPairs.front(); + }else{ + Assert(nonBasicPairs.size() >= 2); + NodeBuilder<> sumBuilder(PLUS); + sumBuilder.append(nonBasicPairs); + sum = sumBuilder; + } + Node basicVar = (map.find(basic()))->second; + return NodeBuilder<2>(EQUAL) << basicVar << sum; + + /* Node sum = Node::null(); if(size() > 2){ NodeBuilder<> sumBuilder(PLUS); for(const_iterator i = begin(); i != end(); ++i){ - ArithVar nb = getArithVar(*i); + ArithVar nb = (*i).getArithVar(); if(nb == basic()) continue; Node var = (map.find(nb))->second; - Node coeff = mkRationalNode(getCoefficient(*i)); + Node coeff = mkRationalNode((*i).getCoefficient()); Node mult = NodeBuilder<2>(MULT) << coeff << var; sumBuilder << mult; @@ -260,15 +291,16 @@ Node ReducedRowVector::asEquality(const ArithVarToNodeMap& map) const{ }else{ Assert(size() == 2); const_iterator i = begin(); - if(getArithVar(*i) == basic()){ + if((*i).getArithVar() == basic()){ ++i; } - Assert(getArithVar(*i) != basic()); - Node var = (map.find(getArithVar(*i)))->second; - Node coeff = mkRationalNode(getCoefficient(*i)); + Assert((*i).getArithVar() != basic()); + Node var = (map.find((*i).getArithVar()))->second; + Node coeff = mkRationalNode((*i).getCoefficient()); sum = NodeBuilder<2>(MULT) << coeff << var; } Node basicVar = (map.find(basic()))->second; return NodeBuilder<2>(EQUAL) << basicVar << sum; +*/ } diff --git a/src/theory/arith/row_vector.h b/src/theory/arith/row_vector.h index 983e19a0a..0fdfd7f0c 100644 --- a/src/theory/arith/row_vector.h +++ b/src/theory/arith/row_vector.h @@ -14,15 +14,26 @@ namespace CVC4 { namespace theory { namespace arith { -typedef std::pair VarCoeffPair; +class VarCoeffPair { +private: + ArithVar d_variable; + Rational d_coeff; + +public: + VarCoeffPair(ArithVar v, const Rational& q): d_variable(v), d_coeff(q) {} -inline ArithVar getArithVar(const VarCoeffPair& v) { return v.first; } -inline Rational& getCoefficient(VarCoeffPair& v) { return v.second; } -inline const Rational& getCoefficient(const VarCoeffPair& v) { return v.second; } + ArithVar getArithVar() const { return d_variable; } + Rational& getCoefficient() { return d_coeff; } + const Rational& getCoefficient() const { return d_coeff; } -inline bool cmp(const VarCoeffPair& a, const VarCoeffPair& b){ - return getArithVar(a) < getArithVar(b); -} + bool operator<(const VarCoeffPair& other) const{ + return getArithVar() < other.getArithVar(); + } + + static bool variableLess(const VarCoeffPair& a, const VarCoeffPair& b){ + return a < b; + } +}; /** * ReducedRowVector is a sparse vector representation that represents the @@ -109,7 +120,7 @@ public: Assert(has(x_j)); Assert(hasInEntries(x_j)); const_iterator lb = lower_bound(x_j); - return getCoefficient(*lb); + return (*lb).getCoefficient(); } @@ -190,7 +201,7 @@ private: bool matchingCounts() const; const_iterator lower_bound(ArithVar x_j) const{ - return std::lower_bound(d_entries.begin(), d_entries.end(), std::make_pair(x_j,0), cmp); + return std::lower_bound(d_entries.begin(), d_entries.end(), VarCoeffPair(x_j, 0)); } /** Debugging code */ @@ -207,7 +218,7 @@ private: /** Debugging code. */ bool hasInEntries(ArithVar x_j) const { - return std::binary_search(d_entries.begin(), d_entries.end(), std::make_pair(x_j,0), cmp); + return std::binary_search(d_entries.begin(), d_entries.end(), VarCoeffPair(x_j,0)); } }; /* class ReducedRowVector */ diff --git a/src/theory/arith/simplex.cpp b/src/theory/arith/simplex.cpp index 02ce310ff..0809e0788 100644 --- a/src/theory/arith/simplex.cpp +++ b/src/theory/arith/simplex.cpp @@ -257,8 +257,8 @@ void SimplexDecisionProcedure::pivotAndUpdate(ArithVar x_i, ArithVar x_j, DeltaR varIter != row_k.end(); ++varIter){ - ArithVar var = varIter->first; - const Rational& coeff = varIter->second; + ArithVar var = (*varIter).getArithVar(); + const Rational& coeff = (*varIter).getCoefficient(); DeltaRational beta = d_partialModel.getAssignment(var); Debug("arith::pivotAndUpdate") << var << beta << coeff; if(d_partialModel.hasLowerBound(var)){ @@ -334,10 +334,10 @@ ArithVar SimplexDecisionProcedure::selectSlack(ArithVar x_i, bool first){ for(ReducedRowVector::const_iterator nbi = row_i.begin(), end = row_i.end(); nbi != end; ++nbi){ - ArithVar nonbasic = getArithVar(*nbi); + ArithVar nonbasic = (*nbi).getArithVar(); if(nonbasic == x_i) continue; - const Rational& a_ij = nbi->second; + const Rational& a_ij = (*nbi).getCoefficient(); int cmp = a_ij.cmp(d_constants.d_ZERO); if(above){ // beta(x_i) > u_i if( cmp < 0 && d_partialModel.strictlyBelowUpperBound(nonbasic)){ @@ -566,10 +566,10 @@ Node SimplexDecisionProcedure::generateConflictAbove(ArithVar conflictVar){ ReducedRowVector::const_iterator nbi = row_i.begin(), end = row_i.end(); for(; nbi != end; ++nbi){ - ArithVar nonbasic = getArithVar(*nbi); + ArithVar nonbasic = (*nbi).getArithVar(); if(nonbasic == conflictVar) continue; - const Rational& a_ij = nbi->second; + const Rational& a_ij = (*nbi).getCoefficient(); Assert(a_ij != d_constants.d_ZERO); @@ -606,10 +606,10 @@ Node SimplexDecisionProcedure::generateConflictBelow(ArithVar conflictVar){ ReducedRowVector::const_iterator nbi = row_i.begin(), end = row_i.end(); for(; nbi != end; ++nbi){ - ArithVar nonbasic = getArithVar(*nbi); + ArithVar nonbasic = (*nbi).getArithVar(); if(nonbasic == conflictVar) continue; - const Rational& a_ij = nbi->second; + const Rational& a_ij = (*nbi).getCoefficient(); Assert(a_ij != d_constants.d_ZERO); @@ -643,9 +643,9 @@ DeltaRational SimplexDecisionProcedure::computeRowValue(ArithVar x, bool useSafe ReducedRowVector& row = d_tableau.lookup(x); for(ReducedRowVector::const_iterator i = row.begin(), end = row.end(); i != end;++i){ - ArithVar nonbasic = getArithVar(*i); + ArithVar nonbasic = (*i).getArithVar(); if(nonbasic == row.basic()) continue; - const Rational& coeff = getCoefficient(*i); + const Rational& coeff = (*i).getCoefficient(); const DeltaRational& assignment = d_partialModel.getAssignment(nonbasic, useSafe); sum = sum + (assignment * coeff); @@ -671,10 +671,10 @@ void SimplexDecisionProcedure::checkTableau(){ for(ReducedRowVector::const_iterator nonbasicIter = row_k.begin(); nonbasicIter != row_k.end(); ++nonbasicIter){ - ArithVar nonbasic = nonbasicIter->first; + ArithVar nonbasic = (*nonbasicIter).getArithVar(); if(basic == nonbasic) continue; - const Rational& coeff = nonbasicIter->second; + const Rational& coeff = (*nonbasicIter).getCoefficient(); DeltaRational beta = d_partialModel.getAssignment(nonbasic); Debug("paranoid:check_tableau") << nonbasic << beta << coeff<