- Adds column based iterators.
authorTim King <taking@cs.nyu.edu>
Tue, 22 Feb 2011 01:13:56 +0000 (01:13 +0000)
committerTim King <taking@cs.nyu.edu>
Tue, 22 Feb 2011 01:13:56 +0000 (01:13 +0000)
src/theory/arith/arithvar_set.h
src/theory/arith/row_vector.cpp
src/theory/arith/row_vector.h
src/theory/arith/simplex.cpp
src/theory/arith/tableau.cpp
src/theory/arith/tableau.h

index de215696eea70337891c3a63cb87a5adbac5ecd8..ff75b373a0c58306cdd2bb49b6b49c34f77748a5 100644 (file)
@@ -37,8 +37,9 @@ namespace arith {
  */
 
 class ArithVarSet {
-private:
+public:
   typedef std::vector<ArithVar> VarList;
+private:
   //List of the ArithVars in the set.
   VarList d_list;
 
@@ -49,7 +50,7 @@ private:
 public:
   typedef VarList::const_iterator iterator;
 
- ArithVarSet() :  d_list(), d_posVector() {}
 ArithVarSet() :  d_list(), d_posVector() {}
 
   size_t size() const {
     return d_list.size();
@@ -95,6 +96,9 @@ public:
   iterator begin() const{ return d_list.begin(); }
   iterator end() const{ return d_list.end(); }
 
+  const VarList& getList() const{
+    return d_list;
+  }
 
   /** Invalidates iterators */
   void remove(ArithVar x){
index 2af03bf085d78924057bd80746fe88c9a6cf3f9e..2463adf474b815cedd339cfeed8397594009a674 100644 (file)
@@ -29,6 +29,18 @@ RowVector::~RowVector(){
     Assert(d_rowCount[v] >= 1);
     --(d_rowCount[v]);
   }
+
+  Assert(matchingCounts());
+}
+
+bool RowVector::matchingCounts() const{
+  for(NonZeroIterator i=beginNonZero(), end=endNonZero(); i != end; ++i){
+    ArithVar v = getArithVar(*i);
+    if(d_columnMatrix[v].size() != d_rowCount[v]){
+      return false;
+    }
+  }
+  return true;
 }
 
 bool RowVector::noZeroCoefficients(const VarCoeffArray& arr){
@@ -61,8 +73,9 @@ void RowVector::zip(const std::vector< ArithVar >& variables,
 
 RowVector::RowVector(const std::vector< ArithVar >& variables,
                      const std::vector< Rational >& coefficients,
-                     std::vector<uint32_t>& counts):
-  d_rowCount(counts)
+                     std::vector<uint32_t>& counts,
+                     std::vector<ArithVarSet>& cm):
+  d_rowCount(counts), d_columnMatrix(cm)
 {
   zip(variables, coefficients, d_entries);
 
@@ -94,7 +107,9 @@ void RowVector::merge(VarCoeffArray& arr,
                       ArithVarContainsSet& contains,
                       const VarCoeffArray& other,
                       const Rational& c,
-                      std::vector<uint32_t>& counts){
+                      std::vector<uint32_t>& counts,
+                      std::vector<ArithVarSet>& columnMatrix,
+                      ArithVar basic){
   VarCoeffArray copy = arr;
   arr.clear();
 
@@ -109,7 +124,11 @@ void RowVector::merge(VarCoeffArray& arr,
       arr.push_back(*curr1);
       ++curr1;
     }else if(getArithVar(*curr1) > getArithVar(*curr2)){
+
       ++counts[getArithVar(*curr2)];
+      if(basic != ARITHVAR_SENTINEL){
+        columnMatrix[getArithVar(*curr2)].add(basic);
+      }
 
       addArithVar(contains, getArithVar(*curr2));
       arr.push_back( make_pair(getArithVar(*curr2), c * getCoefficient(*curr2)));
@@ -118,12 +137,15 @@ void RowVector::merge(VarCoeffArray& arr,
       Rational res = getCoefficient(*curr1) + c * getCoefficient(*curr2);
       if(res != 0){
         //The variable is not new so the count stays the same
-        //bug: ++counts[getArithVar(*curr2)];
 
         arr.push_back(make_pair(getArithVar(*curr1), res));
       }else{
         removeArithVar(contains, getArithVar(*curr2));
+
         --counts[getArithVar(*curr2)];
+        if(basic != ARITHVAR_SENTINEL){
+          columnMatrix[getArithVar(*curr2)].remove(basic);
+        }
       }
       ++curr1;
       ++curr2;
@@ -135,6 +157,9 @@ void RowVector::merge(VarCoeffArray& arr,
   }
   while(curr2 != end2){
     ++counts[getArithVar(*curr2)];
+    if(basic != ARITHVAR_SENTINEL){
+      columnMatrix[getArithVar(*curr2)].add(basic);
+    }
 
     addArithVar(contains, getArithVar(*curr2));
 
@@ -151,10 +176,10 @@ void RowVector::multiply(const Rational& c){
   }
 }
 
-void RowVector::addRowTimesConstant(const Rational& c, const RowVector& other){
+void RowVector::addRowTimesConstant(const Rational& c, const RowVector& other, ArithVar basic){
   Assert(c != 0);
 
-  merge(d_entries, d_contains, other.d_entries, c, d_rowCount);
+  merge(d_entries, d_contains, other.d_entries, c, d_rowCount, d_columnMatrix, basic);
 }
 
 void RowVector::printRow(){
@@ -165,18 +190,27 @@ void RowVector::printRow(){
   Debug("row::print") << std::endl;
 }
 
+
 ReducedRowVector::ReducedRowVector(ArithVar basic,
                                    const std::vector<ArithVar>& variables,
                                    const std::vector<Rational>& coefficients,
-                                   std::vector<uint32_t>& count):
-  RowVector(variables, coefficients, count), d_basic(basic){
+                                   std::vector<uint32_t>& count,
+                                   std::vector<ArithVarSet>& columnMatrix):
+  RowVector(variables, coefficients, count, columnMatrix), d_basic(basic){
 
 
+  for(NonZeroIterator i=beginNonZero(), end=endNonZero(); i != end; ++i){
+    //basic is not yet in d_entries
+    Assert(getArithVar(*i) != d_basic);
+    d_columnMatrix[getArithVar(*i)].add(d_basic);
+  }
+
   VarCoeffArray justBasic;
   justBasic.push_back(make_pair(basic, Rational(-1)));
 
-  merge(d_entries, d_contains, justBasic, Rational(1), d_rowCount);
+  merge(d_entries, d_contains, justBasic, Rational(1), d_rowCount, d_columnMatrix, d_basic);
 
+  Assert(matchingCounts());
   Assert(wellFormed());
   Assert(d_rowCount[d_basic] == 1);
 }
@@ -190,10 +224,12 @@ void ReducedRowVector::substitute(const ReducedRowVector& row_s){
   Rational a_rs = lookup(x_s);
   Assert(a_rs != 0);
 
-  addRowTimesConstant(a_rs, row_s);
+  addRowTimesConstant(a_rs, row_s, basic());
+
 
   Assert(!has(x_s));
   Assert(wellFormed());
+  Assert(matchingCounts());
   Assert(d_rowCount[basic()] == 1);
 }
 
@@ -202,8 +238,15 @@ void ReducedRowVector::pivot(ArithVar x_j){
   Assert(basic() != x_j);
   Rational negInverseA_rs = -(lookup(x_j).inverse());
   multiply(negInverseA_rs);
+
+  for(NonZeroIterator i=beginNonZero(), end=endNonZero(); i != end; ++i){
+    d_columnMatrix[getArithVar(*i)].remove(d_basic);
+    d_columnMatrix[getArithVar(*i)].add(x_j);
+  }
+
   d_basic = x_j;
 
+  Assert(matchingCounts());
   Assert(wellFormed());
   //The invariant Assert(d_rowCount[basic()] == 1); does not hold.
   //This is because the pivot is within the row first then
@@ -249,4 +292,12 @@ ReducedRowVector::~ReducedRowVector(){
   //This executes before the super classes destructor RowVector,
   // which will set this to 0.
   Assert(d_rowCount[basic()] == 1);
+
+  NonZeroIterator curr = beginNonZero();
+  NonZeroIterator end = endNonZero();
+  for(;curr != end; ++curr){
+    ArithVar v = getArithVar(*curr);
+    Assert(d_rowCount[v] >= 1);
+    d_columnMatrix[v].remove(basic());
+  }
 }
index 85a1880636c858c345fac03864632e506f3758a3..29b79ddd51695c0597454ba371daa09a6379fd55 100644 (file)
@@ -6,6 +6,7 @@
 #define __CVC4__THEORY__ARITH__ROW_VECTOR_H
 
 #include "theory/arith/arith_utilities.h"
+#include "theory/arith/arithvar_set.h"
 #include "util/rational.h"
 #include <vector>
 
@@ -52,7 +53,9 @@ public:
                     ArithVarContainsSet& contains,
                     const VarCoeffArray& other,
                     const Rational& c,
-                    std::vector<uint32_t>& count);
+                    std::vector<uint32_t>& count,
+                    std::vector<ArithVarSet>& columnMatrix,
+                    ArithVar basic);
 
 protected:
   /**
@@ -62,6 +65,9 @@ protected:
    */
   static bool noZeroCoefficients(const VarCoeffArray& arr);
 
+  /** Debugging code.*/
+  bool matchingCounts() const;
+
   /**
    * Invariants:
    * - isSorted(d_entries, true)
@@ -76,6 +82,7 @@ protected:
   ArithVarContainsSet d_contains;
 
   std::vector<uint32_t>& d_rowCount;
+  std::vector<ArithVarSet>& d_columnMatrix;
 
   NonZeroIterator lower_bound(ArithVar x_j) const{
     return std::lower_bound(d_entries.begin(), d_entries.end(), make_pair(x_j,0), cmp);
@@ -87,7 +94,8 @@ public:
 
   RowVector(const std::vector< ArithVar >& variables,
             const std::vector< Rational >& coefficients,
-            std::vector<uint32_t>& counts);
+            std::vector<uint32_t>& counts,
+            std::vector<ArithVarSet>& columnMatrix);
 
   ~RowVector();
 
@@ -135,7 +143,7 @@ public:
    * Updates the current row to be the sum of itself and
    * another vector times c (c != 0).
    */
-  void addRowTimesConstant(const Rational& c, const RowVector& other);
+  void addRowTimesConstant(const Rational& c, const RowVector& other, ArithVar basic);
 
   void printRow();
 
@@ -176,7 +184,8 @@ public:
   ReducedRowVector(ArithVar basic,
                    const std::vector< ArithVar >& variables,
                    const std::vector< Rational >& coefficients,
-                   std::vector<uint32_t>& count);
+                   std::vector<uint32_t>& count,
+                   std::vector<ArithVarSet>& columnMatrix);
 
   ~ReducedRowVector();
 
index d837d7ac03b094d98592352d9685cb90fb243341..2785222e32c8314fdf0a9dc1e877a4898f56d724 100644 (file)
@@ -168,6 +168,37 @@ bool SimplexDecisionProcedure::AssertEquality(ArithVar x_i, const DeltaRational&
   return false;
 }
 
+set<ArithVar> tableauAndHasSet(Tableau& tab, ArithVar v){
+  set<ArithVar> has;
+  for(ArithVarSet::iterator basicIter = tab.begin();
+      basicIter != tab.end();
+      ++basicIter){
+    ArithVar basic = *basicIter;
+    ReducedRowVector& row = tab.lookup(basic);
+
+    if(row.has(v)){
+      has.insert(basic);
+    }
+  }
+  return has;
+}
+
+set<ArithVar> columnIteratorSet(Tableau& tab,ArithVar v){
+  set<ArithVar> has;
+  ArithVarSet::iterator basicIter = tab.beginColumn(v);
+  ArithVarSet::iterator endIter = tab.endColumn(v);
+  for(; basicIter != endIter; ++basicIter){
+    ArithVar basic = *basicIter;
+    has.insert(basic);
+  }
+  return has;
+}
+
+
+bool matchingSets(Tableau& tab, ArithVar v){
+  return tableauAndHasSet(tab, v) == columnIteratorSet(tab, v);
+}
+
 void SimplexDecisionProcedure::update(ArithVar x_i, const DeltaRational& v){
   Assert(!d_tableau.isBasic(x_i));
   DeltaRational assignment_x_i = d_partialModel.getAssignment(x_i);
@@ -177,22 +208,21 @@ void SimplexDecisionProcedure::update(ArithVar x_i, const DeltaRational& v){
                  << assignment_x_i << "|-> " << v << endl;
   DeltaRational diff = v - assignment_x_i;
 
-  for(ArithVarSet::iterator basicIter = d_tableau.begin();
-      basicIter != d_tableau.end();
-      ++basicIter){
+  Assert(matchingSets(d_tableau, x_i));
+  ArithVarSet::iterator basicIter = d_tableau.beginColumn(x_i);
+  ArithVarSet::iterator endIter   = d_tableau.endColumn(x_i);
+  for(; basicIter != endIter; ++basicIter){
     ArithVar x_j = *basicIter;
     ReducedRowVector& row_j = d_tableau.lookup(x_j);
 
-    if(row_j.has(x_i)){
-      const Rational& a_ji = row_j.lookup(x_i);
+    Assert(row_j.has(x_i));
+    const Rational& a_ji = row_j.lookup(x_i);
 
-      const DeltaRational& assignment = d_partialModel.getAssignment(x_j);
-      DeltaRational  nAssignment = assignment+(diff * a_ji);
-      d_partialModel.setAssignment(x_j, nAssignment);
+    const DeltaRational& assignment = d_partialModel.getAssignment(x_j);
+    DeltaRational  nAssignment = assignment+(diff * a_ji);
+    d_partialModel.setAssignment(x_j, nAssignment);
 
-      d_queue.enqueueIfInconsistent(x_j);
-      //checkBasicVariable(x_j);
-    }
+    d_queue.enqueueIfInconsistent(x_j);
   }
 
   d_partialModel.setAssignment(x_i, v);
@@ -250,12 +280,21 @@ void SimplexDecisionProcedure::pivotAndUpdate(ArithVar x_i, ArithVar x_j, DeltaR
   DeltaRational tmp = d_partialModel.getAssignment(x_j) + theta;
   d_partialModel.setAssignment(x_j, tmp);
 
-  ArithVarSet::iterator basicIter = d_tableau.begin(), end = d_tableau.end();
-  for(; basicIter != end; ++basicIter){
+
+  Assert(matchingSets(d_tableau, x_j));
+  ArithVarSet::iterator basicIter = d_tableau.beginColumn(x_j);
+  ArithVarSet::iterator endIter   = d_tableau.endColumn(x_j);
+  for(; basicIter != endIter; ++basicIter){
+
+  //ArithVarSet::iterator basicIter = d_tableau.begin(), end = d_tableau.end();
+  //for(; basicIter != end; ++basicIter){
     ArithVar x_k = *basicIter;
     ReducedRowVector& row_k = d_tableau.lookup(x_k);
 
-    if(x_k != x_i && row_k.has(x_j)){
+    Assert(row_k.has(x_j));
+
+    //if(x_k != x_i && row_k.has(x_j)){
+    if(x_k != x_i ){
       const Rational& a_kj = row_k.lookup(x_j);
       DeltaRational nextAssignment = d_partialModel.getAssignment(x_k) + (theta * a_kj);
       d_partialModel.setAssignment(x_k, nextAssignment);
index d318a70e6b6974cb0d46b1a8b4d3db6ab026c91a..ebf7dbee8b901ad2764646090c47f5fbd05f8f19 100644 (file)
@@ -41,7 +41,7 @@ void Tableau::addRow(ArithVar basicVar,
   //The new basic variable cannot already be a basic variable
   Assert(!d_basicVariables.isMember(basicVar));
   d_basicVariables.add(basicVar);
-  ReducedRowVector* row_current = new ReducedRowVector(basicVar,variables, coeffs,d_rowCount);
+  ReducedRowVector* row_current = new ReducedRowVector(basicVar,variables, coeffs,d_rowCount, d_columnMatrix);
   d_rowsTable[basicVar] = row_current;
 
   //A variable in the row may have been made non-basic already.
@@ -90,17 +90,22 @@ void Tableau::pivot(ArithVar x_r, ArithVar x_s){
 
   row_s->pivot(x_s);
 
-  for(ArithVarSet::iterator basicIter = begin(), endIter = end();
-      basicIter != endIter; ++basicIter){
+  ArithVarSet::VarList copy(getColumn(x_s).getList());
+  vector<ArithVar>::iterator basicIter = copy.begin(), endIter = copy.end();
+
+  for(; basicIter != endIter; ++basicIter){
     ArithVar basic = *basicIter;
     if(basic == x_s) continue;
 
     ReducedRowVector& row_k = lookup(basic);
-    if(row_k.has(x_s)){
-      row_k.substitute(*row_s);
-    }
+    Assert(row_k.has(x_s));
+
+    row_k.substitute(*row_s);
   }
+  Assert(getColumn(x_s).size() == 1);
+  Assert(getRowCount(x_s) == 1);
 }
+
 void Tableau::printTableau(){
   Debug("tableau") << "Tableau::d_activeRows"  << endl;
 
index 36d61ba25137f065ffa2099817649241e1a8d05f..27aa1305ca49fb8aabfcff773dfa841571cb8ebb 100644 (file)
@@ -37,6 +37,10 @@ namespace CVC4 {
 namespace theory {
 namespace arith {
 
+typedef ArithVarSet Column;
+
+typedef std::vector<Column> ColumnMatrix;
+
 class Tableau {
 private:
 
@@ -47,6 +51,7 @@ private:
   ArithVarSet d_basicVariables;
 
   std::vector<uint32_t> d_rowCount;
+  ColumnMatrix d_columnMatrix;
 
 public:
   /**
@@ -55,7 +60,8 @@ public:
   Tableau() :
     d_rowsTable(),
     d_basicVariables(),
-    d_rowCount()
+    d_rowCount(),
+    d_columnMatrix()
   {}
   ~Tableau();
 
@@ -67,6 +73,16 @@ public:
     d_basicVariables.increaseSize();
     d_rowsTable.push_back(NULL);
     d_rowCount.push_back(0);
+
+    d_columnMatrix.push_back(ArithVarSet());
+
+    //TODO replace with version of ArithVarSet that handles misses as non-entries
+    // not as buffer overflows
+    ColumnMatrix::iterator i = d_columnMatrix.begin(), end = d_columnMatrix.end();
+    for(; i != end; ++i){
+      Column& col = *i;
+      col.increaseSize(d_columnMatrix.size());
+    }
   }
 
   bool isBasic(ArithVar v) const {
@@ -81,6 +97,19 @@ public:
     return d_basicVariables.end();
   }
 
+  const Column& getColumn(ArithVar v){
+    Assert(v < d_columnMatrix.size());
+    return d_columnMatrix[v];
+  }
+
+  Column::iterator beginColumn(ArithVar v){
+    return getColumn(v).begin();
+  }
+  Column::iterator endColumn(ArithVar v){
+    return getColumn(v).end();
+  }
+
+
   ReducedRowVector& lookup(ArithVar var){
     Assert(d_basicVariables.isMember(var));
     Assert(d_rowsTable[var] != NULL);
@@ -90,6 +119,8 @@ public:
 public:
   uint32_t getRowCount(ArithVar x){
     Assert(x < d_rowCount.size());
+    AlwaysAssert(d_rowCount[x] == getColumn(x).size());
+
     return d_rowCount[x];
   }