updates for bitvectors
authorDejan Jovanović <dejan.jovanovic@gmail.com>
Mon, 2 May 2011 19:09:11 +0000 (19:09 +0000)
committerDejan Jovanović <dejan.jovanovic@gmail.com>
Mon, 2 May 2011 19:09:11 +0000 (19:09 +0000)
src/theory/bv/cd_set_collection.h
src/theory/bv/equality_engine.h
src/theory/bv/slice_manager.h
src/theory/bv/theory_bv.cpp
src/theory/bv/theory_bv.h
src/theory/bv/theory_bv_utils.h

index 217ebadcde15fe4abc5d4ae1c062c70fec33668f..30e4e47ec1c1ab7a824ad71db7b4245af63141ba 100644 (file)
@@ -252,7 +252,7 @@ public:
     backtrack();
     Assert(isValid(set));
 
-    const_value_reference candidate_value;
+    const_value_reference candidate_value = value_type();
     bool candidate_found = false;
 
     // Find the biggest node smaleer than value (it must exist)
@@ -279,7 +279,7 @@ public:
     backtrack();
     Assert(isValid(set));
 
-    const_value_reference candidate_value;
+    const_value_reference candidate_value = value_type();
     bool candidate_found = false;
 
     // Find the smallest node bigger than value (it must exist)
index 0450c4535128b8c1d597fa6484e755a173159344..31a4bfd27e3e5be79c8e3b9d333f6eb89012dc2d 100644 (file)
@@ -812,7 +812,7 @@ Node EqualityEngine<OwnerClass, NotifyClass, UnionFindPreferences>::normalizeWit
   if (node != result) {
     std::vector<TNode> equalities;
     getExplanation(result, node, equalities);
-    assumptions.insert(equalities.begin(), equalities.end());
+    utils::getConjuncts(equalities, assumptions);
   }
 
   // If asked, substitute the children with their representatives
index 78ed4f26527d2a786998180d74360d42673e477c..4fb11f1055781eab291f60bfd9b32a432df1caf3 100644 (file)
@@ -301,18 +301,20 @@ bool SliceManager<TheoryBitvector>::sliceAndSolve(std::vector<Node>& lhs, std::v
 
     // We slice constants immediately
     if (sizeDifference > 0 && lhsTerm.getKind() == kind::CONST_BITVECTOR) {
-      BitVector low  = lhsTerm.getConst<BitVector>().extract(utils::getSize(rhsTerm) - 1, 0);
-      BitVector high = lhsTerm.getConst<BitVector>().extract(utils::getSize(lhsTerm) - 1, utils::getSize(rhsTerm));
-      lhs.push_back(utils::mkConst(low));
-      lhs.push_back(utils::mkConst(high));
+      Node low  = utils::mkConst(lhsTerm.getConst<BitVector>().extract(sizeDifference - 1, 0));
+      Node high = utils::mkConst(lhsTerm.getConst<BitVector>().extract(utils::getSize(lhsTerm) - 1, sizeDifference));
+      d_equalityEngine.addTerm(low); d_equalityEngine.addTerm(high);
+      lhs.push_back(low);
+      lhs.push_back(high);
       rhs.push_back(rhsTerm);
       continue;
     }
     if (sizeDifference < 0 && rhsTerm.getKind() == kind::CONST_BITVECTOR) {
-      BitVector low  = rhsTerm.getConst<BitVector>().extract(utils::getSize(lhsTerm) - 1, 0);
-      BitVector high = rhsTerm.getConst<BitVector>().extract(utils::getSize(rhsTerm) - 1, utils::getSize(lhsTerm));
-      rhs.push_back(utils::mkConst(low));
-      rhs.push_back(utils::mkConst(high));
+      Node low  = utils::mkConst(rhsTerm.getConst<BitVector>().extract(-sizeDifference - 1, 0));
+      Node high = utils::mkConst(rhsTerm.getConst<BitVector>().extract(utils::getSize(rhsTerm) - 1, -sizeDifference));
+      d_equalityEngine.addTerm(low); d_equalityEngine.addTerm(high);
+      rhs.push_back(low);
+      rhs.push_back(high);
       lhs.push_back(lhsTerm);
       continue;
     }
@@ -418,7 +420,7 @@ bool SliceManager<TheoryBitvector>::sliceAndSolve(std::vector<Node>& lhs, std::v
         std::vector<TNode> explanation;
         d_equalityEngine.getExplanation(lhsTerm, lhsTermRepresentative, explanation);
         std::set<TNode> additionalAssumptions(assumptions);
-        additionalAssumptions.insert(explanation.begin(), explanation.end());
+        utils::getConjuncts(explanation, additionalAssumptions);
         bool ok = solveEquality(lhsTermRepresentative, concat, additionalAssumptions);
         if (!ok) return false;
       } else {
@@ -438,7 +440,7 @@ bool SliceManager<TheoryBitvector>::sliceAndSolve(std::vector<Node>& lhs, std::v
         std::vector<TNode> explanation;
         d_equalityEngine.getExplanation(rhsTerm, rhsTermRepresentative, explanation);
         std::set<TNode> additionalAssumptions(assumptions);
-        additionalAssumptions.insert(explanation.begin(), explanation.end());
+        utils::getConjuncts(explanation, additionalAssumptions);
         bool ok = solveEquality(rhsTermRepresentative, concat, additionalAssumptions);
         if (!ok) return false;
       } else {
@@ -509,6 +511,8 @@ bool SliceManager<TheoryBitvector>::addSlice(Node node, unsigned slicePoint) {
 
   TNode nodeBase = baseTerm(node);
 
+  Assert(nodeBase.getKind() != kind::CONST_BITVECTOR);
+
   set_reference sliceSet;
   slicing_map::iterator find = d_nodeSlicing.find(nodeBase);
   if (find == d_nodeSlicing.end()) {
@@ -550,7 +554,7 @@ bool SliceManager<TheoryBitvector>::addSlice(Node node, unsigned slicePoint) {
     std::set<TNode> assumptions;
     std::vector<TNode> equalities;
     d_equalityEngine.getExplanation(nodeSlice, nodeSliceRepresentative, equalities);
-    assumptions.insert(equalities.begin(), equalities.end());
+    utils::getConjuncts(equalities, assumptions);
     ok = solveEquality(nodeSliceRepresentative, concat, assumptions);
   }
 
@@ -592,23 +596,32 @@ inline bool SliceManager<TheoryBitvector>::slice(TNode node, std::vector<Node>&
   Assert(d_setCollection.size(nodeSliceSet) >= 2);
 
   BVDebug("slicing") << "SliceManager::slice(" << node << "): current: " << d_setCollection.toString(nodeSliceSet) << std::endl;
-  std::vector<size_t> slicePoints;
-  if (low + 1 < high) {
-    d_setCollection.getElements(nodeSliceSet, low + 1, high - 1, slicePoints);
-  }
+  
   // Go through all the points i_0 <= low < i_1 < ... < i_{n-1} < high <= i_n from the slice set
   // and generate the slices [i_0:low-1][low:i_1-1] [i_1:i2] ... [i_{n-1}:high-1][high:i_n-1]. They are in reverse order,
   // as they should be
-  size_t i_0 = low == 0 ? 0 : d_setCollection.prev(nodeSliceSet, low + 1);
-  BVDebug("slicing") << "SliceManager::slice(" << node << "): i_0: " << i_0 << std::endl;
+  
+  // The high bound already in the slicing
   size_t i_n = high == utils::getSize(nodeBase) ? high: d_setCollection.next(nodeSliceSet, high - 1);
-  BVDebug("slicing") << "SliceManager::slice(" << node << "): i_n: " << i_n << std::endl;
-
-  // Add the new points to the slice set (they might be there already)
+  BVDebug("slicing") << "SliceManager::slice(" << node << "): i_n: " << i_n << std::endl;  
+  // Add the new point to the slice set (they might be there already)
   if (high < i_n) {
     if (!addSlice(nodeBase, high)) return false;
   }
+  // The low bound already in the slicing (slicing might have changed after adding high)
+  size_t i_0 = low == 0 ? 0 : d_setCollection.prev(nodeSliceSet, low + 1);
+  BVDebug("slicing") << "SliceManager::slice(" << node << "): i_0: " << i_0 << std::endl;
+  // Add the new points to the slice set (they might be there already)
+  if (i_0 < low) {
+    if (!addSlice(nodeBase, low)) return false;
+  }
+
+  // Get the slice points
+  std::vector<size_t> slicePoints;
+  if (low + 1 < high) {
+    d_setCollection.getElements(nodeSliceSet, low + 1, high - 1, slicePoints);
+  }
+
   // Construct the actuall slicing
   if (slicePoints.size() > 0) {
     BVDebug("slicing") << "SliceManager::slice(" << node << "): adding" << utils::mkExtract(nodeBase, slicePoints[0] - 1, low) << std::endl;
@@ -622,11 +635,7 @@ inline bool SliceManager<TheoryBitvector>::slice(TNode node, std::vector<Node>&
   } else {
     sliced.push_back(utils::mkExtract(nodeBase, high - 1, low));
   }
-  // Add the new points to the slice set (they might be there already)
-  if (i_0 < low) {
-    if (!addSlice(nodeBase, low)) return false;
-  }
-
+  
   return true;
 }
 
index 314e6b71487dd6c980c74a6f199ef97e12bb669c..7a8ebb85c1ecdd1f9c258669214fce3c3c98baf0 100644 (file)
@@ -46,9 +46,8 @@ void TheoryBV::preRegisterTerm(TNode node) {
         d_eqEngine.addTerm(node[1][i]);
       }
     }
-    size_t triggerId = d_eqEngine.addTrigger(node[0], node[1]);
-    Assert(triggerId == d_triggers.size());
-    d_triggers.push_back(node);
+
+    d_normalization[node] = new Normalization(d_context, node);
   }
 }
 
@@ -56,6 +55,10 @@ void TheoryBV::check(Effort e) {
 
   BVDebug("bitvector") << "TheoryBV::check(" << e << ")" << std::endl;
 
+  // Normalization iterators
+  NormalizationMap::iterator it = d_normalization.begin();
+  NormalizationMap::iterator it_end = d_normalization.end();
+
   // Get all the assertions
   std::vector<TNode> assertionsList;
   while (!done()) {
@@ -65,6 +68,8 @@ void TheoryBV::check(Effort e) {
     assertionsList.push_back(assertion);
   }
 
+  bool normalizeEqualities = false;
+
   for (unsigned i = 0; i < assertionsList.size(); ++ i) {
 
     TNode assertion = assertionsList[i];
@@ -77,27 +82,23 @@ void TheoryBV::check(Effort e) {
       // Slice and solve the equality
       bool ok = d_sliceManager.solveEquality(assertion[0], assertion[1]);
       if (!ok) return;
+      // Normalize all equalities
+      normalizeEqualities = true;
+      it = d_normalization.begin();
+      it = d_normalization.end();
       break;
     }
     case kind::NOT: {
-      // We need to check this as the equality trigger might have been true when we made it
-      TNode equality = assertion[0];
-
-      // Assumptions
-      std::set<TNode> assumptions;
-      Node lhsNormalized = d_eqEngine.normalize(equality[0], assumptions);
-      Node rhsNormalized = d_eqEngine.normalize(equality[1], assumptions);
-
-      BVDebug("bitvector") << "TheoryBV::check(" << e << "): normalizes to " << lhsNormalized << " = " << rhsNormalized << std::endl;
-      
-      // No need to slice the equality, the whole thing *should* be deduced
-      if (lhsNormalized == rhsNormalized) {
-        BVDebug("bitvector") << "TheoryBV::check(" << e << "): conflict with " << utils::setToString(assumptions) << std::endl;
-        assumptions.insert(assertion);
-        d_out->conflict(mkConjunction(assumptions));
-        return;
-      } else {
-        d_disequalities.push_back(assertion);
+      if (!normalizeEqualities) {
+        // We still need to check this dis-equality, as it might have been pre-registered just now
+        // so we didn't have a chance to propagate
+        it = d_normalization.find(assertion[0]);
+        if (it->second->assumptions.size() == 1) {
+          // Just normalize this equality
+          normalizeEqualities = true;
+          it_end = it;
+          it_end ++;
+        }
       }
       break;
     }
@@ -106,28 +107,70 @@ void TheoryBV::check(Effort e) {
     }
   }
 
-  if (fullEffort(e)) {
+  if (normalizeEqualities) {
 
-    BVDebug("bitvector") << "TheoryBV::check(" << e << "): checking dis-equalities" << std::endl;
+    BVDebug("bitvector") << "Checking for propagations" << std::endl;
+  
+    NormalizationMap::iterator it = d_normalization.begin();
+    NormalizationMap::iterator it_end = d_normalization.end();
+    for(; it != it_end; ++ it) {
 
-    for (unsigned i = 0, i_end = d_disequalities.size(); i < i_end; ++ i) {
-
-      BVDebug("bitvector") << "TheoryBV::check(" << e << "): checking " << d_disequalities[i] << std::endl;
+      TNode equality = it->first;
+      BVDebug("bitvector") << "Checking " << equality << std::endl;
+      Normalization& normalization = *it->second;
+      
+      // If asserted, we don't care
+      if (d_assertions.find(equality) != d_assertions.end()) continue; 
 
-      TNode equality = d_disequalities[i][0];
       // Assumptions
-      std::set<TNode> assumptions;
-      Node lhsNormalized = d_eqEngine.normalize(equality[0], assumptions);
-      Node rhsNormalized = d_eqEngine.normalize(equality[1], assumptions);
+      std::set<TNode> assumptions; 
+      TNode lhs = normalization.equalities.back()[0];
+      TNode rhs = normalization.equalities.back()[1];
+      // If already satisfied, do nothing
+      if (lhs == rhs) continue;
+
+      Node lhsNormalized = d_eqEngine.normalize(lhs, assumptions);
+      Node rhsNormalized = d_eqEngine.normalize(rhs, assumptions);
+
+      if (lhsNormalized == lhs && rhsNormalized == rhs) continue;
+
+      normalization.equalities.push_back(lhsNormalized.eqNode(rhsNormalized));
+      normalization.assumptions.push_back(assumptions);
+
+      BVDebug("bitvector") << "Adding normalization " << lhsNormalized.eqNode(rhsNormalized) << std::endl;
+      BVDebug("bitvector") << "       assumptions   " << setToString(assumptions) << std::endl;
+
 
       BVDebug("bitvector") << "TheoryBV::check(" << e << "): normalizes to " << lhsNormalized << " = " << rhsNormalized << std::endl;
 
-      // No need to slice the equality, the whole thing *should* be deduced
-      if (lhsNormalized == rhsNormalized) {
-        BVDebug("bitvector") << "TheoryBV::check(" << e << "): conflict with " << utils::setToString(assumptions) << std::endl;
-        assumptions.insert(d_disequalities[i]);
-        d_out->conflict(mkConjunction(assumptions));
-        return;
+      // If both are equal we can propagate
+      bool propagate = lhsNormalized == rhsNormalized;
+      // otherwise if both are constants, we propagate negation (if not already there)
+      bool propagateNegation = !propagate &&
+          lhsNormalized.getKind() == kind::CONST_BITVECTOR && rhsNormalized.getKind() == kind::CONST_BITVECTOR
+          && d_assertions.find(equality.notNode()) == d_assertions.end();
+          ;
+      if (propagate || propagateNegation) {
+        Node implied = propagate        ? Node(equality)     : equality.notNode() ;
+        Node impliedNegated = propagate ? equality.notNode() : Node(equality)     ;
+        // If the negation of what's implied has been asserted, we are in conflict
+        if (d_assertions.find(impliedNegated) != d_assertions.end()) {
+          BVDebug("bitvector") << "TheoryBV::check(" << e << "): conflict with " << utils::setToString(assumptions) << std::endl;
+          // Construct the assumptions
+          for (unsigned i = 0; i < normalization.assumptions.size(); ++ i) {
+            assumptions.insert(normalization.assumptions[i].begin(), normalization.assumptions[i].end());
+          }
+          // Make the conflict
+          assumptions.insert(impliedNegated);
+          d_out->conflict(mkConjunction(assumptions));
+          return;
+        }
+        // Otherwise we propagate the implication
+        else {
+          BVDebug("bitvector") << "TheoryBV::check(" << e << "): propagating " << implied << std::endl;
+          d_out->propagate(implied);
+          d_assertions.insert(implied);
+        }
       }
     }
   }
@@ -138,6 +181,8 @@ bool TheoryBV::triggerEquality(size_t triggerId) {
   Assert(triggerId < d_triggers.size());
   BVDebug("bitvector") << "TheoryBV::triggerEquality(" << triggerId << "): " << d_triggers[triggerId] << std::endl;
 
+  return true;
+
   TNode equality = d_triggers[triggerId];
 
   // If we have just asserted this equality ignore it
@@ -181,13 +226,17 @@ Node TheoryBV::getValue(TNode n) {
 
 void TheoryBV::explain(TNode node) {
   BVDebug("bitvector") << "TheoryBV::explain(" << node << ")" << std::endl;
-  if(node.getKind() == kind::EQUAL) {
-    std::vector<TNode> reasons;
-    d_eqEngine.getExplanation(node[0], node[1], reasons);
-    std::set<TNode> simpleReasons;
-    utils::getConjuncts(reasons, simpleReasons);
-    d_out->explanation(utils::mkConjunction(simpleReasons));
-    return;
+
+  TNode equality = node.getKind() == kind::NOT ? node[0] : node;
+  Assert(equality.getKind() == kind::EQUAL);
+
+  context::CDList< set<TNode> >& vec = d_normalization[equality]->assumptions;
+  std::set<TNode> assumptions;
+  for (unsigned i = 0; i < vec.size(); ++ i) {
+    BVDebug("bitvector") << "Adding normalization " << d_normalization[equality]->equalities[i] << std::endl;
+    BVDebug("bitvector") << "       assumptions   " << setToString(d_normalization[equality]->assumptions[i]) << std::endl;
+    assumptions.insert(vec[i].begin(), vec[i].end());
   }
-  Unreachable();
+  d_out->explanation(utils::mkConjunction(assumptions));
+  return;
 }
index 748352321c6623e8d1ea57eb28042bff1e3078f8..8c2b59efac34ce970bd1466e95b9a97b1a01ff0b 100644 (file)
@@ -24,6 +24,7 @@
 #include "theory/theory.h"
 #include "context/context.h"
 #include "context/cdset.h"
+#include "context/cdlist.h"
 #include "theory/bv/equality_engine.h"
 #include "theory/bv/slice_manager.h"
 
@@ -82,7 +83,6 @@ public:
 
 private:
 
-
   /** Equality reasoning engine */
   BvEqualityEngine d_eqEngine;
 
@@ -91,6 +91,9 @@ private:
 
   /** Equality triggers indexed by ids from the equality manager */
   std::vector<Node> d_triggers;
+  
+  /** The context we are using */
+  context::Context* d_context;
 
   /** The asserted stuff */
   context::CDSet<TNode, TNodeHashFunction> d_assertions;
@@ -98,13 +101,36 @@ private:
   /** Asserted dis-equalities */
   context::CDList<TNode> d_disequalities;
 
+  struct Normalization {
+    context::CDList<Node> equalities;
+    context::CDList< std::set<TNode> > assumptions;
+    Normalization(context::Context* c, TNode eq)
+    : equalities(c), assumptions(c) {
+      equalities.push_back(eq);
+      assumptions.push_back(std::set<TNode>());
+    }
+  };
+
+  /** Map from equalities to their noramlization information */
+  typedef __gnu_cxx::hash_map<TNode, Normalization*, TNodeHashFunction> NormalizationMap;
+  NormalizationMap d_normalization;
+
   /** Called by the equality managere on triggers */
   bool triggerEquality(size_t triggerId);
 
+  Node d_true;
+
 public:
 
-  TheoryBV(context::Context* c, OutputChannel& out, Valuation valuation) :
-    Theory(THEORY_BV, c, out, valuation), d_eqEngine(*this, c, "theory::bv::EqualityEngine"), d_sliceManager(*this, c), d_assertions(c), d_disequalities(c) {
+  TheoryBV(context::Context* c, OutputChannel& out, Valuation valuation)
+  : Theory(THEORY_BV, c, out, valuation), 
+    d_eqEngine(*this, c, "theory::bv::EqualityEngine"), 
+    d_sliceManager(*this, c), 
+    d_context(c),
+    d_assertions(c), 
+    d_disequalities(c)
+  {
+    d_true = utils::mkTrue();
   }
 
   BvEqualityEngine& getEqualityEngine() {
index 80751fe4c772ccc2604f291493ed174b525ef5f0..a3135f077cec366096f44ae120fcd05015aa3478 100644 (file)
@@ -106,7 +106,6 @@ inline void getConjuncts(std::vector<TNode>& nodes, std::set<TNode>& conjuncts)
   }
 }
 
-
 inline Node mkConjunction(const std::set<TNode> nodes) {
   std::set<TNode> expandedNodes;
 
@@ -115,12 +114,8 @@ inline Node mkConjunction(const std::set<TNode> nodes) {
   while (it != it_end) {
     TNode current = *it;
     if (current != mkTrue()) {
-      Assert(current != mkFalse());
-      if (current.getKind() == kind::AND) {
-        getConjuncts(current, expandedNodes);
-      } else {
-        expandedNodes.insert(current);
-      }
+      Assert(current.getKind() == kind::EQUAL || (current.getKind() == kind::NOT && current[0].getKind() == kind::EQUAL));
+      expandedNodes.insert(current);
     }
     ++ it;
   }