This commit merges the branch branches/arithmetic/quick-row-has into trunk. quick...
authorTim King <taking@cs.nyu.edu>
Thu, 17 Feb 2011 21:30:57 +0000 (21:30 +0000)
committerTim King <taking@cs.nyu.edu>
Thu, 17 Feb 2011 21:30:57 +0000 (21:30 +0000)
src/theory/arith/row_vector.cpp
src/theory/arith/row_vector.h

index 01131c4c98326a01d3715803de6ca8fdcef4fe62..6486077fbb63fc7908c4cdfdf54f147bdbdf3141 100644 (file)
@@ -48,6 +48,7 @@ void RowVector::zip(const std::vector< ArithVar >& variables,
   }
 }
 
+
 RowVector::RowVector(const std::vector< ArithVar >& variables,
                      const std::vector< Rational >& coefficients,
                      std::vector<uint32_t>& counts):
@@ -59,13 +60,28 @@ RowVector::RowVector(const std::vector< ArithVar >& variables,
 
   for(NonZeroIterator i=beginNonZero(), end=endNonZero(); i != end; ++i){
     ++d_rowCount[getArithVar(*i)];
+    addArithVar(d_contains, getArithVar(*i));
   }
 
   Assert(isSorted(d_entries, true));
   Assert(noZeroCoefficients(d_entries));
 }
 
+void RowVector::addArithVar(ArithVarContainsSet& contains, ArithVar v){
+  if(v >= contains.size()){
+    contains.resize(v+1, false);
+  }
+  contains[v] = true;
+}
+
+void RowVector::removeArithVar(ArithVarContainsSet& contains, ArithVar v){
+  Assert(v < contains.size());
+  Assert(contains[v]);
+  contains[v] = false;
+}
+
 void RowVector::merge(VarCoeffArray& arr,
+                      ArithVarContainsSet& contains,
                       const VarCoeffArray& other,
                       const Rational& c,
                       std::vector<uint32_t>& counts){
@@ -85,6 +101,7 @@ void RowVector::merge(VarCoeffArray& arr,
     }else if(getArithVar(*curr1) > getArithVar(*curr2)){
       ++counts[getArithVar(*curr2)];
 
+      addArithVar(contains, getArithVar(*curr2));
       arr.push_back( make_pair(getArithVar(*curr2), c * getCoefficient(*curr2)));
       ++curr2;
     }else{
@@ -94,6 +111,7 @@ void RowVector::merge(VarCoeffArray& arr,
 
         arr.push_back(make_pair(getArithVar(*curr1), res));
       }else{
+        removeArithVar(contains, getArithVar(*curr2));
         --counts[getArithVar(*curr2)];
       }
       ++curr1;
@@ -107,6 +125,8 @@ void RowVector::merge(VarCoeffArray& arr,
   while(curr2 != end2){
     ++counts[getArithVar(*curr2)];
 
+    addArithVar(contains, getArithVar(*curr2));
+
     arr.push_back(make_pair(getArithVar(*curr2), c * getCoefficient(*curr2)));
     ++curr2;
   }
@@ -123,7 +143,7 @@ void RowVector::multiply(const Rational& c){
 void RowVector::addRowTimesConstant(const Rational& c, const RowVector& other){
   Assert(c != 0);
 
-  merge(d_entries, other.d_entries, c, d_rowCount);
+  merge(d_entries, d_contains, other.d_entries, c, d_rowCount);
 }
 
 void RowVector::printRow(){
@@ -144,7 +164,7 @@ ReducedRowVector::ReducedRowVector(ArithVar basic,
   VarCoeffArray justBasic;
   justBasic.push_back(make_pair(basic, Rational(-1)));
 
-  merge(d_entries,justBasic, Rational(1), d_rowCount);
+  merge(d_entries, d_contains, justBasic, Rational(1), d_rowCount);
 
   Assert(wellFormed());
 }
index a967f8d68fd7f5a2a63d00fe4ce3b6ef5fb6e260..05ceeb98672a0fab3056aeb341d22dcbe1cc05a6 100644 (file)
@@ -31,6 +31,8 @@ public:
   typedef std::vector<VarCoeffPair> VarCoeffArray;
   typedef VarCoeffArray::const_iterator NonZeroIterator;
 
+  typedef std::vector<bool> ArithVarContainsSet;
+
   /**
    * Let c be -1 if strictlySorted is true and c be 0 otherwise.
    * isSorted(arr, strictlySorted) is then equivalent to
@@ -38,12 +40,6 @@ public:
    */
   static bool isSorted(const VarCoeffArray& arr, bool strictlySorted);
 
-  /**
-   * noZeroCoefficients(arr) is equivalent to
-   *  0 != getCoefficient(arr[i]) for all i.
-   */
-  static bool noZeroCoefficients(const VarCoeffArray& arr);
-
   /**
    * Zips together an array of variables and coefficients and appends
    * it to the end of an output vector.
@@ -52,10 +48,20 @@ public:
                   const std::vector< Rational >& coefficients,
                   VarCoeffArray& output);
 
-  static void merge(VarCoeffArray& arr, const VarCoeffArray& other, const Rational& c, std::vector<uint32_t>& count);
-
+  static void merge(VarCoeffArray& arr,
+                    ArithVarContainsSet& contains,
+                    const VarCoeffArray& other,
+                    const Rational& c,
+                    std::vector<uint32_t>& count);
 
 protected:
+  /**
+   * Debugging code.
+   * noZeroCoefficients(arr) is equivalent to
+   *  0 != getCoefficient(arr[i]) for all i.
+   */
+  static bool noZeroCoefficients(const VarCoeffArray& arr);
+
   /**
    * Invariants:
    * - isSorted(d_entries, true)
@@ -63,6 +69,12 @@ protected:
    */
   VarCoeffArray d_entries;
 
+  /**
+   * Invariants:
+   * - This set is the same as the set maintained in d_entries.
+   */
+  ArithVarContainsSet d_contains;
+
   std::vector<uint32_t>& d_rowCount;
 
   NonZeroIterator lower_bound(ArithVar x_j) const{
@@ -89,14 +101,26 @@ public:
 
   /** Returns true if the variable is in the row. */
   bool has(ArithVar x_j) const{
+    if(x_j >= d_contains.size()){
+      return false;
+    }else{
+      return d_contains[x_j];
+    }
+  }
+
+private:
+  /** Debugging code. */
+  bool hasInEntries(ArithVar x_j) const {
     return std::binary_search(d_entries.begin(), d_entries.end(), make_pair(x_j,0), cmp);
   }
+public:
 
   /**
    * Returns the coefficient of a variable in the row.
    */
   const Rational& lookup(ArithVar x_j) const{
     Assert(has(x_j));
+    Assert(hasInEntries(x_j));
     NonZeroIterator lb = lower_bound(x_j);
     return getCoefficient(*lb);
   }
@@ -113,6 +137,17 @@ public:
   void addRowTimesConstant(const Rational& c, const RowVector& other);
 
   void printRow();
+
+protected:
+  /**
+   * Adds v to d_contains.
+   * This may resize d_contains.
+   */
+  static void addArithVar(ArithVarContainsSet& contains, ArithVar v);
+
+  /** Removes v from d_contains. */
+  static void removeArithVar(ArithVarContainsSet& contains, ArithVar v);
+
 }; /* class RowVector */
 
 /**
@@ -148,6 +183,15 @@ public:
     return d_basic;
   }
 
+  /** Return true if x is in the row and is not the basic variable. */
+  bool hasNonBasic(ArithVar x) const {
+    if(x == basic()){
+      return false;
+    }else{
+      return has(x);
+    }
+  }
+
   void pivot(ArithVar x_j);
 
   void substitute(const ReducedRowVector& other);