added support for disequalities in the inequality solver
authorLiana Hadarean <lianahady@gmail.com>
Mon, 25 Mar 2013 03:38:33 +0000 (23:38 -0400)
committerLiana Hadarean <lianahady@gmail.com>
Mon, 25 Mar 2013 03:38:33 +0000 (23:38 -0400)
src/theory/bv/bv_inequality_graph.cpp
src/theory/bv/bv_inequality_graph.h
src/theory/bv/bv_subtheory_inequality.cpp
src/theory/bv/theory_bv.h

index 704f99039d09948b5459fff5c7fd024c14cfcf62..4bd31587295f1e85f24f7d882b4211e1712e4658 100644 (file)
@@ -261,9 +261,27 @@ void InequalityGraph::addEdge(TermId a, TermId b, bool strict, TermId reason) {
   d_undoStackIndex = d_undoStackIndex + 1; 
 }
 
+void InequalityGraph::initializeModelValue(TNode node) {
+  TermId id = getTermId(node); 
+  Assert (!hasModelValue(id));
+  bool isConst = node.getKind() == kind::CONST_BITVECTOR;
+  unsigned size = utils::getSize(node); 
+  BitVector value = isConst? node.getConst<BitVector>() : BitVector(size, 0u); 
+  setModelValue(id, ModelValue(value, UndefinedTermId, UndefinedReasonId));
+}
+
+bool InequalityGraph::isRegistered(TNode term) const {
+  return d_termNodeToIdMap.find(term) != d_termNodeToIdMap.end(); 
+}
+
 TermId InequalityGraph::registerTerm(TNode term) {
   if (d_termNodeToIdMap.find(term) != d_termNodeToIdMap.end()) {
-    return d_termNodeToIdMap[term]; 
+    TermId id = d_termNodeToIdMap[term];
+    if (!hasModelValue(id)) {
+      // we could have backtracked and
+      initializeModelValue(term); 
+    }
+    return id; 
   }
 
   // store in node mapping
@@ -275,21 +293,17 @@ TermId InequalityGraph::registerTerm(TNode term) {
   
   // create InequalityNode
   unsigned size = utils::getSize(term);
+
   bool isConst = term.getKind() == kind::CONST_BITVECTOR;
-  BitVector value = isConst? term.getConst<BitVector>() : BitVector(size, 0u); 
-  
   InequalityNode ineq = InequalityNode(id, size, isConst);
-  setModelValue(id, ModelValue(value, UndefinedTermId, UndefinedReasonId));
-  
+
   Assert (d_ineqNodes.size() == id); 
   d_ineqNodes.push_back(ineq);
   
   Assert (d_ineqEdges.size() == id); 
   d_ineqEdges.push_back(Edges());
 
-  // add the default edges min <= term <= max
-  //  addEdge(getMinValueId(size), id, false, AxiomReasonId);
-  // addEdge(id, getMaxValueId(size), false, AxiomReasonId); 
+  initializeModelValue(term); 
   
   return id; 
 }
@@ -314,6 +328,11 @@ TNode InequalityGraph::getTermNode(TermId id) const {
   return d_termNodes[id]; 
 }
 
+TermId InequalityGraph::getTermId(TNode node) const {
+  Assert (d_termNodeToIdMap.find(node) != d_termNodeToIdMap.end());
+  return d_termNodeToIdMap.find(node)->second; 
+}
+
 void InequalityGraph::setConflict(const std::vector<ReasonId>& conflict) {
   Assert (!d_inConflict); 
   d_inConflict = true;
@@ -351,8 +370,8 @@ bool InequalityGraph::hasModelValue(TermId id) const {
 }
 
 BitVector InequalityGraph::getValue(TermId id) const {
-  Assert (hasModelValue(id));
-  BitVector res = (*(d_modelValues.find(id))).second.value; 
+  Assert (hasModelValue(id)); 
+  BitVector res = (*(d_modelValues.find(id))).second.value;
   return res; 
 }
 
@@ -361,6 +380,66 @@ bool InequalityGraph::hasReason(TermId id) const {
   return mv.reason != UndefinedReasonId; 
 }
 
+bool InequalityGraph::addDisequality(TNode a, TNode b, TNode reason) {
+  Debug("bv-inequality") << "InequalityGraph::addDisequality " << reason << "\n"; 
+  d_disequalities.push_back(reason);
+
+  if (!isRegistered(a) || !isRegistered(b)) {
+    splitDisequality(reason);
+    return true; 
+  }
+  TermId id_a = getTermId(a);
+  TermId id_b = getTermId(b);
+  if (!hasModelValue(id_a)) {
+    initializeModelValue(a); 
+  }
+  if (!hasModelValue(id_b)) {
+    initializeModelValue(b); 
+  }
+  const BitVector& val_a = getValue(id_a);
+  const BitVector& val_b = getValue(id_b);
+  if (val_a == val_b) {
+    if (a.getKind() == kind::CONST_BITVECTOR) {
+      // then we know b cannot be smaller  than the assigned value so we try to make it larger
+      return addInequality(a, b, true, reason);
+    }
+    if (b.getKind() == kind::CONST_BITVECTOR) {
+      return addInequality(b, a, true, reason);
+    }
+    // if none of the terms are constants just add the lemma 
+    splitDisequality(reason);
+  } else {
+    Debug("bv-inequality-internal") << "Disequal: " << a << " => " << val_a.toString(10) << "\n"
+                                    << "          " << b << " => " << val_b.toString(10) << "\n"; 
+  }
+  return true; 
+}
+
+void InequalityGraph::splitDisequality(TNode diseq) {
+  Debug("bv-inequality-internal")<<"InequalityGraph::splitDisequality " << diseq <<"\n"; 
+  Assert (diseq.getKind() == kind::NOT && diseq[0].getKind() == kind::EQUAL);
+  TNode a = diseq[0][0];
+  TNode b = diseq[0][1];
+  Node a_lt_b = utils::mkNode(kind::BITVECTOR_ULT, a, b);
+  Node b_lt_a = utils::mkNode(kind::BITVECTOR_ULT, b, a);
+  Node split = utils::mkNode(kind::OR, a_lt_b, b_lt_a);
+  Node lemma = utils::mkNode(kind::IMPLIES, diseq, split);
+  if (d_lemmasAdded.find(lemma) == d_lemmasAdded.end()) { 
+    d_lemmaQueue.push_back(lemma);
+  }
+}
+
+void InequalityGraph::getNewLemmas(std::vector<TNode>& new_lemmas) {
+  for (unsigned i = d_lemmaIndex; i < d_lemmaQueue.size(); ++i)  {
+    TNode lemma = d_lemmaQueue[i];
+    if (d_lemmasAdded.find(lemma) == d_lemmasAdded.end()) {
+      new_lemmas.push_back(lemma);
+      d_lemmasAdded.insert(lemma); 
+    }
+    d_lemmaIndex = d_lemmaIndex + 1; 
+  } 
+}
+
 std::string InequalityGraph::PQueueElement::toString() const {
   ostringstream os;
   os << "(id: " << id << ", lower_bound: " << lower_bound.toString(10) <<", old_value: " << model_value.value.toString(10) << ")"; 
index 57e59f6f5d0481d0ee1f4fa86f5fe396ab551297..1335eff93c337c251517e435096e8bd9556f7d91 100644 (file)
@@ -111,7 +111,7 @@ class InequalityGraph : public context::ContextNotifyObj{
   typedef __gnu_cxx::hash_set<TermId> TermIdSet;
 
   typedef std::priority_queue<PQueueElement> BFSQueue; 
-  
+  typedef __gnu_cxx::hash_set<TNode, TNodeHashFunction> TNodeSet; 
   std::vector<InequalityNode> d_ineqNodes;
   std::vector< Edges > d_ineqEdges;
   
@@ -125,7 +125,8 @@ class InequalityGraph : public context::ContextNotifyObj{
   std::vector<TNode> d_conflict;
   bool d_signed; 
 
-  context::CDHashMap<TermId, ModelValue>  d_modelValues; 
+  context::CDHashMap<TermId, ModelValue>  d_modelValues;
+  void initializeModelValue(TNode node); 
   void setModelValue(TermId term, const ModelValue& mv);
   ModelValue getModelValue(TermId term) const;
   bool hasModelValue(TermId id) const; 
@@ -142,7 +143,8 @@ class InequalityGraph : public context::ContextNotifyObj{
   TermId registerTerm(TNode term);
   TNode getTermNode(TermId id) const; 
   TermId getTermId(TNode node) const;
-
+  bool isRegistered(TNode term) const; 
+  
   ReasonId registerReason(TNode reason);
   TNode getReasonNode(ReasonId id) const;
   
@@ -152,10 +154,6 @@ class InequalityGraph : public context::ContextNotifyObj{
   const InequalityNode& getInequalityNode(TermId id) const { Assert (id < d_ineqNodes.size()); return d_ineqNodes[id]; }
   unsigned getBitwidth(TermId id) const { return getInequalityNode(id).getBitwidth(); }
   bool isConst(TermId id) const { return getInequalityNode(id).isConstant(); }
-  // BitVector maxValue(unsigned bitwidth);
-  // BitVector minValue(unsigned bitwidth);
-  // TermId getMaxValueId(unsigned bitwidth);
-  // TermId getMinValueId(unsigned bitwidth);
   
   BitVector getValue(TermId id) const; 
     
@@ -191,7 +189,18 @@ class InequalityGraph : public context::ContextNotifyObj{
    * @param explanation 
    */
   void computeExplanation(TermId from, TermId to, std::vector<ReasonId>& explanation); 
+  void splitDisequality(TNode diseq); 
 
+  /**
+     Disequality reasoning
+   */
+  
+  /*** The currently asserted disequalities */
+  context::CDQueue<TNode> d_disequalities;
+  context::CDQueue<Node> d_lemmaQueue; 
+  context::CDO<unsigned> d_lemmaIndex; 
+  TNodeSet d_lemmasAdded; 
+  
   /** Backtracking mechanisms **/
   std::vector<std::pair<TermId, InequalityEdge> > d_undoStack;
   context::CDO<unsigned> d_undoStackIndex; 
@@ -213,6 +222,10 @@ public:
       d_conflict(),
       d_signed(s),
       d_modelValues(c),
+      d_disequalities(c),
+      d_lemmaQueue(c),
+      d_lemmaIndex(c, 0),
+      d_lemmasAdded(),
       d_undoStack(),
       d_undoStackIndex(c)
   {}
@@ -227,9 +240,11 @@ public:
    * @return 
    */
   bool addInequality(TNode a, TNode b, bool strict, TNode reason);
+  bool addDisequality(TNode a, TNode b, TNode reason); 
   bool areLessThan(TNode a, TNode b);
   void getConflict(std::vector<TNode>& conflict);
   virtual ~InequalityGraph() throw(AssertionException) {}
+  void getNewLemmas(std::vector<TNode>& new_lemmas);
 }; 
 
 }
index f856c94106dbe0eb5e520250c5665ba51a26cfce..6b9842e8fa561a7784aa0266e16e9f353655ad08 100644 (file)
@@ -27,15 +27,21 @@ using namespace CVC4::theory::bv;
 using namespace CVC4::theory::bv::utils;
 
 bool InequalitySolver::check(Theory::Effort e) {
+  Debug("bv-subtheory-inequality") << "InequalitySolveR::check("<< e <<")\n"; 
   bool ok = true; 
   while (!done() && ok) {
     TNode fact = get();
+    Debug("bv-subtheory-inequality") << "  "<< fact <<"\n"; 
     if (fact.getKind() == kind::EQUAL) {
       TNode a = fact[0];
       TNode b = fact[1];
       ok = d_inequalityGraph.addInequality(a, b, false, fact);
       if (ok)
         ok = d_inequalityGraph.addInequality(b, a, false, fact); 
+    } else if (fact.getKind() == kind::NOT && fact[0].getKind() == kind::EQUAL) {
+      TNode a = fact[0][0];
+      TNode b = fact[0][1];
+      ok = d_inequalityGraph.addDisequality(a, b, fact);
     }
     if (fact.getKind() == kind::NOT && fact[0].getKind() == kind::BITVECTOR_ULE) {
       TNode a = fact[0][1];
@@ -61,6 +67,12 @@ bool InequalitySolver::check(Theory::Effort e) {
     d_bv->setConflict(utils::mkConjunction(conflict));
     return false; 
   }
+  // send out any lemmas
+  std::vector<TNode> lemmas;
+  d_inequalityGraph.getNewLemmas(lemmas); 
+  for(unsigned i = 0; i < lemmas.size(); ++i) {
+    d_bv->lemma(lemmas[i]); 
+  }
   return true; 
 }
 
index 13a475d3d28415e255df42f62c216b9a63dfae8d..54260deb98d0ce377fab656a936131d49af86335 100644 (file)
@@ -137,6 +137,8 @@ private:
 
   void sendConflict();
 
+  void lemma(TNode node) { d_out->lemma(node); }
+  
   friend class Bitblaster;
   friend class BitblastSolver;
   friend class EqualitySolver;