- Makes VarCoeffPair a class instead of a typedef of pair<ArithVar, Rational>. This...
authorTim King <taking@cs.nyu.edu>
Sun, 27 Feb 2011 18:29:38 +0000 (18:29 +0000)
committerTim King <taking@cs.nyu.edu>
Sun, 27 Feb 2011 18:29:38 +0000 (18:29 +0000)
src/theory/arith/row_vector.cpp
src/theory/arith/row_vector.h
src/theory/arith/simplex.cpp

index 78ec55c2a3775a9614f6b8b270f0a1ea74486ac3..090938f283200be2e2b9937a49fce1cbd1d3d96a 100644 (file)
@@ -11,10 +11,10 @@ bool ReducedRowVector::isSorted(const VarCoeffArray& arr, bool strictlySorted) {
   if(arr.size() >= 2){
     const_iterator curr = arr.begin();
     const_iterator end = arr.end();
-    ArithVar prev = getArithVar(*curr);
+    ArithVar prev = (*curr).getArithVar();
     ++curr;
     for(;curr != end; ++curr){
-      ArithVar v = getArithVar(*curr);
+      ArithVar v = (*curr).getArithVar();
       if(strictlySorted && prev > v) return false;
       if(!strictlySorted && prev >= v) return false;
       prev = v;
@@ -31,7 +31,7 @@ ReducedRowVector::~ReducedRowVector(){
   const_iterator curr = begin();
   const_iterator endEntries = end();
   for(;curr != endEntries; ++curr){
-    ArithVar v = getArithVar(*curr);
+    ArithVar v = (*curr).getArithVar();
     Assert(d_rowCount[v] >= 1);
     d_columnMatrix[v].remove(basic());
     --(d_rowCount[v]);
@@ -43,7 +43,7 @@ ReducedRowVector::~ReducedRowVector(){
 
 bool ReducedRowVector::matchingCounts() const{
   for(const_iterator i=begin(), endEntries=end(); i != endEntries; ++i){
-    ArithVar v = getArithVar(*i);
+    ArithVar v = (*i).getArithVar();
     if(d_columnMatrix[v].size() != d_rowCount[v]){
       return false;
     }
@@ -54,7 +54,7 @@ bool ReducedRowVector::matchingCounts() const{
 bool ReducedRowVector::noZeroCoefficients(const VarCoeffArray& arr){
   for(const_iterator curr = arr.begin(), endEntries = arr.end();
       curr != endEntries; ++curr){
-    const Rational& coeff = getCoefficient(*curr);
+    const Rational& coeff = (*curr).getCoefficient();
     if(coeff == 0) return false;
   }
   return true;
@@ -74,7 +74,7 @@ void ReducedRowVector::zip(const std::vector< ArithVar >& variables,
     const Rational& coeff = *coeffIter;
     ArithVar var_i = *varIter;
 
-    output.push_back(make_pair(var_i, coeff));
+    output.push_back(VarCoeffPair(var_i, coeff));
   }
 }
 
@@ -95,13 +95,14 @@ void ReducedRowVector::multiply(const Rational& c){
   Assert(c != 0);
 
   for(iterator i = d_entries.begin(), end = d_entries.end(); i != end; ++i){
-    getCoefficient(*i) *= c;
+    (*i).getCoefficient() *= c;
   }
 }
 
 void ReducedRowVector::addRowTimesConstant(const Rational& c, const ReducedRowVector& other){
   Assert(c != 0);
   Assert(d_buffer.empty());
+  Assert(wellFormed());
 
   d_buffer.reserve(other.d_entries.size());
 
@@ -112,32 +113,34 @@ void ReducedRowVector::addRowTimesConstant(const Rational& c, const ReducedRowVe
   const_iterator end2 = other.d_entries.end();
 
   while(curr1 != end1 && curr2 != end2){
-    if(getArithVar(*curr1) < getArithVar(*curr2)){
+    ArithVar var1 = (*curr1).getArithVar();
+    ArithVar var2 = (*curr2).getArithVar();
+
+    if(var1 < var2){
       d_buffer.push_back(*curr1);
       ++curr1;
-    }else if(getArithVar(*curr1) > getArithVar(*curr2)){
+    }else if(var1 > var2){
 
-      ++d_rowCount[getArithVar(*curr2)];
-      if(d_basic != ARITHVAR_SENTINEL){
-        d_columnMatrix[getArithVar(*curr2)].add(d_basic);
-      }
+      ++d_rowCount[var2];
+      d_columnMatrix[var2].add(d_basic);
 
-      addArithVar(d_contains, getArithVar(*curr2));
-      d_buffer.push_back( make_pair(getArithVar(*curr2), c * getCoefficient(*curr2)));
+      addArithVar(d_contains, var2);
+      const Rational& coeff2 = (*curr2).getCoefficient();
+      d_buffer.push_back( VarCoeffPair(var2, c * coeff2));
       ++curr2;
     }else{
-      Rational res = getCoefficient(*curr1) + c * getCoefficient(*curr2);
+      Assert(var1 == var2);
+      const Rational& coeff1 = (*curr1).getCoefficient();
+      const Rational& coeff2 = (*curr2).getCoefficient();
+      Rational res = coeff1 + (c * coeff2);
       if(res != 0){
         //The variable is not new so the count stays the same
-
-        d_buffer.push_back(make_pair(getArithVar(*curr1), res));
+        d_buffer.push_back(VarCoeffPair(var1, res));
       }else{
-        removeArithVar(d_contains, getArithVar(*curr2));
+        removeArithVar(d_contains, var1);
 
-        --d_rowCount[getArithVar(*curr2)];
-        if(d_basic != ARITHVAR_SENTINEL){
-          d_columnMatrix[getArithVar(*curr2)].remove(d_basic);
-        }
+        --d_rowCount[var1];
+        d_columnMatrix[var1].remove(d_basic);
       }
       ++curr1;
       ++curr2;
@@ -148,14 +151,14 @@ void ReducedRowVector::addRowTimesConstant(const Rational& c, const ReducedRowVe
     ++curr1;
   }
   while(curr2 != end2){
-    ++d_rowCount[getArithVar(*curr2)];
-    if(d_basic != ARITHVAR_SENTINEL){
-      d_columnMatrix[getArithVar(*curr2)].add(d_basic);
-    }
+    ArithVar var2 = (*curr2).getArithVar();
+    const Rational& coeff2 = (*curr2).getCoefficient();
+    ++d_rowCount[var2];
+    d_columnMatrix[var2].add(d_basic);
 
-    addArithVar(d_contains, getArithVar(*curr2));
+    addArithVar(d_contains, var2);
 
-    d_buffer.push_back(make_pair(getArithVar(*curr2), c * getCoefficient(*curr2)));
+    d_buffer.push_back(VarCoeffPair(var2, c * coeff2));
     ++curr2;
   }
 
@@ -167,8 +170,9 @@ void ReducedRowVector::addRowTimesConstant(const Rational& c, const ReducedRowVe
 
 void ReducedRowVector::printRow(){
   for(const_iterator i = begin(); i != end(); ++i){
-    ArithVar nb = getArithVar(*i);
-    Debug("row::print") << "{" << nb << "," << getCoefficient(*i) << "}";
+    ArithVar nb = (*i).getArithVar();
+    const Rational& coeff = (*i).getCoefficient();
+    Debug("row::print") << "{" << nb << "," << coeff << "}";
   }
   Debug("row::print") << std::endl;
 }
@@ -182,14 +186,15 @@ ReducedRowVector::ReducedRowVector(ArithVar basic,
   d_basic(basic), d_rowCount(counts),  d_columnMatrix(cm)
 {
   zip(variables, coefficients, d_entries);
-  d_entries.push_back(make_pair(basic, Rational(-1)));
+  d_entries.push_back(VarCoeffPair(basic, Rational(-1)));
 
-  std::sort(d_entries.begin(), d_entries.end(), cmp);
+  std::sort(d_entries.begin(), d_entries.end());
 
   for(const_iterator i=begin(), endEntries=end(); i != endEntries; ++i){
-    ++d_rowCount[getArithVar(*i)];
-    addArithVar(d_contains, getArithVar(*i));
-    d_columnMatrix[getArithVar(*i)].add(d_basic);
+    ArithVar var = (*i).getArithVar();
+    ++d_rowCount[var];
+    addArithVar(d_contains, var);
+    d_columnMatrix[var].add(d_basic);
   }
 
   Assert(isSorted(d_entries, true));
@@ -225,8 +230,9 @@ void ReducedRowVector::pivot(ArithVar x_j){
   multiply(negInverseA_rs);
 
   for(const_iterator i=begin(), endEntries=end(); i != endEntries; ++i){
-    d_columnMatrix[getArithVar(*i)].remove(d_basic);
-    d_columnMatrix[getArithVar(*i)].add(x_j);
+    ArithVar var = (*i).getArithVar();
+    d_columnMatrix[var].remove(d_basic);
+    d_columnMatrix[var].add(x_j);
   }
 
   d_basic = x_j;
@@ -243,15 +249,40 @@ Node ReducedRowVector::asEquality(const ArithVarToNodeMap& map) const{
   using namespace CVC4::kind;
 
   Assert(size() >= 2);
+
+  vector<Node> nonBasicPairs;
+  for(const_iterator i = begin(); i != end(); ++i){
+    ArithVar nb = (*i).getArithVar();
+    if(nb == basic()) continue;
+    Node var = (map.find(nb))->second;
+    Node coeff = mkRationalNode((*i).getCoefficient());
+
+    Node mult = NodeBuilder<2>(MULT) << coeff << var;
+    nonBasicPairs.push_back(mult);
+  }
+
+  Node sum = Node::null();
+  if(nonBasicPairs.size() == 1 ){
+    sum = nonBasicPairs.front();
+  }else{
+    Assert(nonBasicPairs.size() >= 2);
+    NodeBuilder<> sumBuilder(PLUS);
+    sumBuilder.append(nonBasicPairs);
+    sum = sumBuilder;
+  }
+  Node basicVar = (map.find(basic()))->second;
+  return NodeBuilder<2>(EQUAL) << basicVar << sum;
+
+  /*
   Node sum = Node::null();
   if(size() > 2){
     NodeBuilder<> sumBuilder(PLUS);
 
     for(const_iterator i = begin(); i != end(); ++i){
-      ArithVar nb = getArithVar(*i);
+      ArithVar nb = (*i).getArithVar();
       if(nb == basic()) continue;
       Node var = (map.find(nb))->second;
-      Node coeff = mkRationalNode(getCoefficient(*i));
+      Node coeff = mkRationalNode((*i).getCoefficient());
 
       Node mult = NodeBuilder<2>(MULT) << coeff << var;
       sumBuilder << mult;
@@ -260,15 +291,16 @@ Node ReducedRowVector::asEquality(const ArithVarToNodeMap& map) const{
   }else{
     Assert(size() == 2);
     const_iterator i = begin();
-    if(getArithVar(*i) == basic()){
+    if((*i).getArithVar() == basic()){
       ++i;
     }
-    Assert(getArithVar(*i) != basic());
-    Node var = (map.find(getArithVar(*i)))->second;
-    Node coeff = mkRationalNode(getCoefficient(*i));
+    Assert((*i).getArithVar() != basic());
+    Node var = (map.find((*i).getArithVar()))->second;
+    Node coeff = mkRationalNode((*i).getCoefficient());
     sum = NodeBuilder<2>(MULT) << coeff << var;
   }
   Node basicVar = (map.find(basic()))->second;
   return NodeBuilder<2>(EQUAL) << basicVar << sum;
+*/
 }
 
index 983e19a0a61976e4c05f6d7095f821a878d1bd05..0fdfd7f0c3d88795fbbb1764416ddd8dceadac24 100644 (file)
@@ -14,15 +14,26 @@ namespace CVC4 {
 namespace theory {
 namespace arith {
 
-typedef std::pair<ArithVar, Rational> VarCoeffPair;
+class VarCoeffPair {
+private:
+  ArithVar d_variable;
+  Rational d_coeff;
+
+public:
+  VarCoeffPair(ArithVar v, const Rational& q): d_variable(v), d_coeff(q) {}
 
-inline ArithVar getArithVar(const VarCoeffPair& v) { return v.first; }
-inline Rational& getCoefficient(VarCoeffPair& v) { return v.second; }
-inline const Rational& getCoefficient(const VarCoeffPair& v) { return v.second; }
+  ArithVar getArithVar() const { return d_variable; }
+  Rational& getCoefficient() { return d_coeff; }
+  const Rational& getCoefficient() const { return d_coeff; }
 
-inline bool cmp(const VarCoeffPair& a, const VarCoeffPair& b){
-  return getArithVar(a) < getArithVar(b);
-}
+  bool operator<(const VarCoeffPair& other) const{
+    return getArithVar() < other.getArithVar();
+  }
+
+  static bool variableLess(const VarCoeffPair& a, const VarCoeffPair& b){
+    return a < b;
+  }
+};
 
 /**
  * ReducedRowVector is a sparse vector representation that represents the
@@ -109,7 +120,7 @@ public:
     Assert(has(x_j));
     Assert(hasInEntries(x_j));
     const_iterator lb = lower_bound(x_j);
-    return getCoefficient(*lb);
+    return (*lb).getCoefficient();
   }
 
 
@@ -190,7 +201,7 @@ private:
   bool matchingCounts() const;
 
   const_iterator lower_bound(ArithVar x_j) const{
-    return std::lower_bound(d_entries.begin(), d_entries.end(), std::make_pair(x_j,0), cmp);
+    return std::lower_bound(d_entries.begin(), d_entries.end(), VarCoeffPair(x_j, 0));
   }
 
   /** Debugging code */
@@ -207,7 +218,7 @@ private:
 
   /** Debugging code. */
   bool hasInEntries(ArithVar x_j) const {
-    return std::binary_search(d_entries.begin(), d_entries.end(), std::make_pair(x_j,0), cmp);
+    return std::binary_search(d_entries.begin(), d_entries.end(), VarCoeffPair(x_j,0));
   }
 
 }; /* class ReducedRowVector */
index 02ce310ff58059bf94150592bcc5191136c81801..0809e07882959623631498129688ee945030f691 100644 (file)
@@ -257,8 +257,8 @@ void SimplexDecisionProcedure::pivotAndUpdate(ArithVar x_i, ArithVar x_j, DeltaR
         varIter != row_k.end();
         ++varIter){
 
-      ArithVar var = varIter->first;
-      const Rational& coeff = varIter->second;
+      ArithVar var = (*varIter).getArithVar();
+      const Rational& coeff = (*varIter).getCoefficient();
       DeltaRational beta = d_partialModel.getAssignment(var);
       Debug("arith::pivotAndUpdate") << var << beta << coeff;
       if(d_partialModel.hasLowerBound(var)){
@@ -334,10 +334,10 @@ ArithVar SimplexDecisionProcedure::selectSlack(ArithVar x_i, bool first){
 
   for(ReducedRowVector::const_iterator nbi = row_i.begin(), end = row_i.end();
       nbi != end; ++nbi){
-    ArithVar nonbasic = getArithVar(*nbi);
+    ArithVar nonbasic = (*nbi).getArithVar();
     if(nonbasic == x_i) continue;
 
-    const Rational& a_ij = nbi->second;
+    const Rational& a_ij = (*nbi).getCoefficient();
     int cmp = a_ij.cmp(d_constants.d_ZERO);
     if(above){ // beta(x_i) > u_i
       if( cmp < 0 && d_partialModel.strictlyBelowUpperBound(nonbasic)){
@@ -566,10 +566,10 @@ Node SimplexDecisionProcedure::generateConflictAbove(ArithVar conflictVar){
   ReducedRowVector::const_iterator nbi = row_i.begin(), end = row_i.end();
 
   for(; nbi != end; ++nbi){
-    ArithVar nonbasic = getArithVar(*nbi);
+    ArithVar nonbasic = (*nbi).getArithVar();
     if(nonbasic == conflictVar) continue;
 
-    const Rational& a_ij = nbi->second;
+    const Rational& a_ij = (*nbi).getCoefficient();
 
     Assert(a_ij != d_constants.d_ZERO);
 
@@ -606,10 +606,10 @@ Node SimplexDecisionProcedure::generateConflictBelow(ArithVar conflictVar){
 
   ReducedRowVector::const_iterator nbi = row_i.begin(), end = row_i.end();
   for(; nbi != end; ++nbi){
-    ArithVar nonbasic = getArithVar(*nbi);
+    ArithVar nonbasic = (*nbi).getArithVar();
     if(nonbasic == conflictVar) continue;
 
-    const Rational& a_ij = nbi->second;
+    const Rational& a_ij = (*nbi).getCoefficient();
 
     Assert(a_ij != d_constants.d_ZERO);
 
@@ -643,9 +643,9 @@ DeltaRational SimplexDecisionProcedure::computeRowValue(ArithVar x, bool useSafe
   ReducedRowVector& row = d_tableau.lookup(x);
   for(ReducedRowVector::const_iterator i = row.begin(), end = row.end();
       i != end;++i){
-    ArithVar nonbasic = getArithVar(*i);
+    ArithVar nonbasic = (*i).getArithVar();
     if(nonbasic == row.basic()) continue;
-    const Rational& coeff = getCoefficient(*i);
+    const Rational& coeff = (*i).getCoefficient();
 
     const DeltaRational& assignment = d_partialModel.getAssignment(nonbasic, useSafe);
     sum = sum + (assignment * coeff);
@@ -671,10 +671,10 @@ void SimplexDecisionProcedure::checkTableau(){
     for(ReducedRowVector::const_iterator nonbasicIter = row_k.begin();
         nonbasicIter != row_k.end();
         ++nonbasicIter){
-      ArithVar nonbasic = nonbasicIter->first;
+      ArithVar nonbasic = (*nonbasicIter).getArithVar();
       if(basic == nonbasic) continue;
 
-      const Rational& coeff = nonbasicIter->second;
+      const Rational& coeff = (*nonbasicIter).getCoefficient();
       DeltaRational beta = d_partialModel.getAssignment(nonbasic);
       Debug("paranoid:check_tableau") << nonbasic << beta << coeff<<endl;
       sum = sum + (beta*coeff);