getEqualityStatus now also queries the inequality solver
authorlianah <lianahady@gmail.com>
Mon, 25 Mar 2013 22:24:29 +0000 (18:24 -0400)
committerlianah <lianahady@gmail.com>
Mon, 25 Mar 2013 22:24:29 +0000 (18:24 -0400)
src/theory/bv/bv_inequality_graph.cpp
src/theory/bv/bv_inequality_graph.h
src/theory/bv/bv_subtheory.h
src/theory/bv/bv_subtheory_bitblast.cpp
src/theory/bv/bv_subtheory_bitblast.h
src/theory/bv/bv_subtheory_core.h
src/theory/bv/bv_subtheory_inequality.cpp
src/theory/bv/bv_subtheory_inequality.h
src/theory/bv/theory_bv.cpp

index 4bd31587295f1e85f24f7d882b4211e1712e4658..a1d2efbb5ce23a25fd5941baf0df37afe20a88b8 100644 (file)
@@ -29,55 +29,6 @@ const ReasonId CVC4::theory::bv::UndefinedReasonId = -1;
 const ReasonId CVC4::theory::bv::AxiomReasonId = -2;
 
 
-// BitVector InequalityGraph::maxValue(unsigned bitwidth) {
-//   if (d_signed) {
-//     return BitVector(1, 0u).concat(~BitVector(bitwidth - 1, 0u)); 
-//   }
-//   return ~BitVector(bitwidth, 0u);
-// }
-
-// BitVector InequalityGraph::minValue(unsigned bitwidth) {
-//   if (d_signed) {
-//     return ~BitVector(bitwidth, 0u); 
-//   } 
-//   return BitVector(bitwidth, 0u);
-// }
-
-// TermId InequalityGraph::getMaxValueId(unsigned bitwidth) {
-//   BitVector bv = maxValue(bitwidth); 
-//   Node max = utils::mkConst(bv); 
-  
-//   if (d_termNodeToIdMap.find(max) == d_termNodeToIdMap.end()) {
-//     TermId id = d_termNodes.size(); 
-//     d_termNodes.push_back(max);
-//     d_termNodeToIdMap[max] = id;
-//     InequalityNode node(id, bitwidth, true, bv);
-//     d_ineqNodes.push_back(node); 
-
-//     // although it will never have out edges we need this to keep the size of
-//     // d_termNodes and d_ineqEdges in sync
-//     d_ineqEdges.push_back(Edges());
-//     return id; 
-//   }
-//   return d_termNodeToIdMap[max]; 
-// }
-
-// TermId InequalityGraph::getMinValueId(unsigned bitwidth) {
-//   BitVector bv = minValue(bitwidth); 
-//   Node min = utils::mkConst(bv); 
-
-//   if (d_termNodeToIdMap.find(min) == d_termNodeToIdMap.end()) {
-//     TermId id = d_termNodes.size(); 
-//     d_termNodes.push_back(min);
-//     d_termNodeToIdMap[min] = id;
-//     d_ineqEdges.push_back(Edges());
-//     InequalityNode node = InequalityNode(id, bitwidth, true, bv);
-//     d_ineqNodes.push_back(node); 
-//     return id; 
-//   }
-//   return d_termNodeToIdMap[min]; 
-// }
-
 bool InequalityGraph::addInequality(TNode a, TNode b, bool strict, TNode reason) {
   Debug("bv-inequality") << "InequlityGraph::addInequality " << a << " " << b << " strict: " << strict << "\n"; 
 
@@ -121,24 +72,21 @@ bool InequalityGraph::addInequality(TNode a, TNode b, bool strict, TNode reason)
   
   // add the inequality edge
   addEdge(id_a, id_b, strict, id_reason);
-  BFSQueue queue;
-  ModelValue mv = hasModelValue(id_a) ? getModelValue(id_a) : ModelValue();
-  queue.push(PQueueElement(id_a,  getValue(id_a), mv));
-  TermIdSet seen; 
-  return computeValuesBFS(queue, id_a, seen); 
+  BFSQueue queue(&d_modelValues);
+  Assert (hasModelValue(id_a)); 
+  queue.push(id_a);
+  return processQueue(queue, id_a); 
 }
 
-bool InequalityGraph::updateValue(const PQueueElement& el, TermId start, const TermIdSet& seen, bool& changed) {
-  TermId id = el.id;
-  const BitVector& lower_bound = el.lower_bound; 
-  InequalityNode& ineqNode = getInequalityNode(id);
-                               
-  if (ineqNode.isConstant()) {
+bool InequalityGraph::updateValue(TermId id, ModelValue new_mv, TermId start, bool& changed) {
+  BitVector lower_bound = new_mv.value;
+  
+  if (isConst(id)) {
     if (getValue(id) < lower_bound) {
       Debug("bv-inequality") << "Conflict: constant " << getValue(id) << "\n"; 
       std::vector<ReasonId> conflict;
-      TermId parent = el.model_value.parent; 
-      ReasonId reason = el.model_value.reason; 
+      TermId parent = new_mv.parent; 
+      ReasonId reason = new_mv.reason; 
       conflict.push_back(reason); 
       computeExplanation(UndefinedTermId, parent, conflict);
       Debug("bv-inequality") << "InequalityGraph::addInequality conflict: constant\n"; 
@@ -146,12 +94,12 @@ bool InequalityGraph::updateValue(const PQueueElement& el, TermId start, const T
       return false; 
     }
   } else {
-    // if not constant we can update the value
+    // if not constant we can try to update the value
     if (getValue(id) < lower_bound) {
       // if we are updating the term we started with we must be in a cycle
-      if (seen.count(id) && id == start) {
-        TermId parent = el.model_value.parent;
-        ReasonId reason = el.model_value.reason;
+      if (id == start) {
+        TermId parent = new_mv.parent;
+        ReasonId reason = new_mv.reason;
         std::vector<TermId> conflict;
         conflict.push_back(reason);
         computeExplanation(id, parent, conflict);
@@ -163,68 +111,66 @@ bool InequalityGraph::updateValue(const PQueueElement& el, TermId start, const T
                                       << "  from " << getValue(id) << "\n"
                                       << "  to " << lower_bound << "\n";
       changed = true;
-      ModelValue mv = el.model_value;
-      mv.value = lower_bound; 
-      setModelValue(id, mv); 
+      setModelValue(id, new_mv); 
     }
   }
   return true; 
 }
 
-bool InequalityGraph::computeValuesBFS(BFSQueue& queue, TermId start, TermIdSet& seen) {
-  if (queue.empty())
-    return true;
-
-  const PQueueElement current = queue.top();
-  queue.pop();
-  Debug("bv-inequality-internal") << "InequalityGraph::computeValuesBFS proceessing " << getTermNode(current.id) << " " << current.toString() << "\n";
-  bool updated_current = false; 
-  if (!updateValue(current, start, seen, updated_current)) {
-    return false; 
-  }
-  if (seen.count(current.id) && current.id == start) {
-    // we know what we didn't update start or we would have had a conflict 
-    Debug("bv-inequality-internal") << "InequalityGraph::computeValuesBFS equal cycle."; 
-    // this means we are in a cycle where all the values are forced to be equal
-    // TODO: make sure we collapse this cycle into one big node. 
-    return computeValuesBFS(queue, start, seen); 
-  }
-  
-  if (!updated_current && !(seen.count(current.id) == 0 && current.id == start)) {
-    // if we didn't update current we don't need to readd to the queue it's children 
-    seen.insert(current.id);
-    Debug("bv-inequality-internal") << "  unchanged " << getTermNode(current.id) << "\n";  
-    return computeValuesBFS(queue, start, seen); 
-  }
+bool InequalityGraph::processQueue(BFSQueue& queue, TermId start) {
+  while (!queue.empty()) {
+    TermId current = queue.top();
+    queue.pop();
+    Debug("bv-inequality-internal") << "InequalityGraph::processQueue proceessing " << getTermNode(current) << "\n";
   
-  seen.insert(current.id);
+    BitVector current_value = getValue(current);
   
-  const BitVector& current_value = getValue(current.id);
+    unsigned size = getBitwidth(current);
+    const BitVector zero(size, 0u); 
+    const BitVector one(size, 1u); 
   
-  unsigned size = getBitwidth(current.id);
-  const BitVector zero(size, 0u); 
-  const BitVector one(size, 1u); 
-  
-  const Edges& edges = getEdges(current.id);
-  for (Edges::const_iterator it = edges.begin(); it!= edges.end(); ++it) {
-    TermId next = it->next;
-    const BitVector increment = it->strict ? one : zero; 
-    const BitVector& next_lower_bound = current_value + increment;
-    if (next_lower_bound < current_value) {
-      // it means we have an overflow and hence a conflict
+    const Edges& edges = getEdges(current);
+    for (Edges::const_iterator it = edges.begin(); it!= edges.end(); ++it) {
+      TermId next = it->next;
+      ReasonId reason = it->reason;
+
+      const BitVector increment = it->strict ? one : zero; 
+      const BitVector next_lower_bound = current_value + increment;
+
+      if (next_lower_bound < current_value) {
+        // it means we have an overflow and hence a conflict
         std::vector<TermId> conflict;
         conflict.push_back(it->reason);
-        computeExplanation(start, current.id, conflict);
+        computeExplanation(start, current, conflict);
         Debug("bv-inequality") << "InequalityGraph::addInequality conflict: cycle \n"; 
         setConflict(conflict); 
         return false; 
+      }
+      
+      ModelValue new_mv(next_lower_bound, current, reason);       
+      bool updated = false; 
+      if (!updateValue(next, new_mv, start, updated)) {
+        return false; 
+      }
+      
+      if (next == start) {
+        // we know what we didn't update start or we would have had a conflict 
+        // this means we are in a cycle where all the values are forced to be equal
+        Debug("bv-inequality-internal") << "InequalityGraph::processQueue equal cycle."; 
+        continue; 
+      }
+      
+      if (!updated) {
+        // if we didn't update current we don't need to add to the queue it's children 
+        Debug("bv-inequality-internal") << "  unchanged " << getTermNode(next) << "\n";  
+        continue; 
+      }
+
+      queue.push(next);
+      Debug("bv-inequality-internal") << "   enqueue " << getTermNode(next) << "\n"; 
     }
-    const BitVector& value = getValue(next); 
-    PQueueElement el = PQueueElement(next, next_lower_bound, ModelValue(value, current.id, it->reason)); 
-    queue.push(el);
-    Debug("bv-inequality-internal") << "   enqueue " << getTermNode(el.id) << " " << el.toString() << "\n"; 
   }
-  return computeValuesBFS(queue, start, seen)
+  return true
 }
 
 void InequalityGraph::computeExplanation(TermId from, TermId to, std::vector<ReasonId>& explanation) {
@@ -371,8 +317,7 @@ bool InequalityGraph::hasModelValue(TermId id) const {
 
 BitVector InequalityGraph::getValue(TermId id) const {
   Assert (hasModelValue(id)); 
-  BitVector res = (*(d_modelValues.find(id))).second.value;
-  return res; 
+  return (*(d_modelValues.find(id))).second.value;
 }
 
 bool InequalityGraph::hasReason(TermId id) const {
@@ -396,12 +341,21 @@ bool InequalityGraph::addDisequality(TNode a, TNode b, TNode reason) {
   if (!hasModelValue(id_b)) {
     initializeModelValue(b); 
   }
-  const BitVector& val_a = getValue(id_a);
-  const BitVector& val_b = getValue(id_b);
+  const BitVector val_a = getValue(id_a);
+  const BitVector val_b = getValue(id_b);
   if (val_a == val_b) {
     if (a.getKind() == kind::CONST_BITVECTOR) {
       // then we know b cannot be smaller  than the assigned value so we try to make it larger
-      return addInequality(a, b, true, reason);
+      std::vector<ReasonId> explanation_ids; 
+      computeExplanation(UndefinedTermId, id_b, explanation_ids); 
+      std::vector<TNode> explanation_nodes;
+      explanation_nodes.push_back(reason);
+      for (unsigned i = 0; i < explanation_ids.size(); ++i) {
+        explanation_nodes.push_back(getReasonNode(explanation_ids[i])); 
+      }
+      Node explanation = utils::mkAnd(explanation_nodes);
+      d_reasonSet.insert(explanation); 
+      return addInequality(a, b, true, explanation);
     }
     if (b.getKind() == kind::CONST_BITVECTOR) {
       return addInequality(b, a, true, reason);
@@ -418,32 +372,26 @@ bool InequalityGraph::addDisequality(TNode a, TNode b, TNode reason) {
 void InequalityGraph::splitDisequality(TNode diseq) {
   Debug("bv-inequality-internal")<<"InequalityGraph::splitDisequality " << diseq <<"\n"; 
   Assert (diseq.getKind() == kind::NOT && diseq[0].getKind() == kind::EQUAL);
-  TNode a = diseq[0][0];
-  TNode b = diseq[0][1];
-  Node a_lt_b = utils::mkNode(kind::BITVECTOR_ULT, a, b);
-  Node b_lt_a = utils::mkNode(kind::BITVECTOR_ULT, b, a);
-  Node split = utils::mkNode(kind::OR, a_lt_b, b_lt_a);
-  Node lemma = utils::mkNode(kind::IMPLIES, diseq, split);
-  if (d_lemmasAdded.find(lemma) == d_lemmasAdded.end()) { 
-    d_lemmaQueue.push_back(lemma);
+  if (d_disequalitiesAlreadySplit.find(diseq) == d_disequalitiesAlreadySplit.end()) {
+    d_disequalitiesToSplit.push_back(diseq); 
   }
 }
 
-void InequalityGraph::getNewLemmas(std::vector<TNode>& new_lemmas) {
-  for (unsigned i = d_lemmaIndex; i < d_lemmaQueue.size(); ++i)  {
-    TNode lemma = d_lemmaQueue[i];
-    if (d_lemmasAdded.find(lemma) == d_lemmasAdded.end()) {
-      new_lemmas.push_back(lemma);
-      d_lemmasAdded.insert(lemma); 
+void InequalityGraph::getNewLemmas(std::vector<Node>& new_lemmas) {
+  for (unsigned i = d_diseqToSplitIndex; i < d_disequalitiesToSplit.size(); ++i)  {
+    TNode diseq = d_disequalitiesToSplit[i];
+    if (d_disequalitiesAlreadySplit.find(diseq) == d_disequalitiesAlreadySplit.end()) {
+        TNode a = diseq[0][0];
+        TNode b = diseq[0][1];
+        Node a_lt_b = utils::mkNode(kind::BITVECTOR_ULT, a, b);
+        Node b_lt_a = utils::mkNode(kind::BITVECTOR_ULT, b, a);
+        Node eq = diseq[0]; 
+        Node lemma = utils::mkNode(kind::OR, a_lt_b, b_lt_a, eq);
+        new_lemmas.push_back(lemma);
+        d_disequalitiesAlreadySplit.insert(diseq); 
     }
-    d_lemmaIndex = d_lemmaIndex + 1; 
-  } 
-}
-
-std::string InequalityGraph::PQueueElement::toString() const {
-  ostringstream os;
-  os << "(id: " << id << ", lower_bound: " << lower_bound.toString(10) <<", old_value: " << model_value.value.toString(10) << ")"; 
-  return os.str(); 
+    d_diseqToSplitIndex = d_diseqToSplitIndex + 1; 
+  }
 }
 
 void InequalityGraph::backtrack() {
@@ -467,3 +415,37 @@ void InequalityGraph::backtrack() {
     edges.pop_back(); 
   }
 }
+
+void InequalityGraph::checkDisequalities() {
+  for (CDQueue<TNode>::const_iterator it = d_disequalities.begin(); it != d_disequalities.end(); ++it) {
+    if (d_disequalitiesAlreadySplit.find(*it) == d_disequalitiesAlreadySplit.end()) {
+      // if we haven't already split on this disequality
+      TNode diseq = *it;
+      TermId a_id = registerTerm(diseq[0][0]);
+      TermId b_id = registerTerm(diseq[0][1]);
+      if (getValue(a_id) == getValue(b_id)) {
+        // if the disequality is not satisified by the model 
+        d_disequalitiesToSplit.push_back(diseq); 
+      }
+    }
+  }
+}
+
+bool InequalityGraph::isLessThan(TNode a, TNode b) {
+  Assert (isRegistered(a) && isRegistered(b));
+  Unimplemented(); 
+}
+
+bool InequalityGraph::hasValueInModel(TNode node) const {
+  if (isRegistered(node)) {
+    TermId id = getTermId(node);
+    return hasModelValue(id); 
+  }
+  return false; 
+}
+
+BitVector InequalityGraph::getValueInModel(TNode node) const {
+  TermId id = getTermId(node); 
+  Assert (hasModelValue(id));
+  return getValue(id); 
+}
index 1335eff93c337c251517e435096e8bd9556f7d91..b23ea77047522789c902ac552877de2ab5c1e8ae 100644 (file)
@@ -87,34 +87,37 @@ class InequalityGraph : public context::ContextNotifyObj{
         value(val)
     {}
   };
+  
+  typedef context::CDHashMap<TermId, ModelValue> Model;
 
-  struct PQueueElement {
-    TermId id;
-    BitVector lower_bound;
-    ModelValue model_value; 
-    PQueueElement(TermId id, const BitVector& lb, const ModelValue& mv)
-      : id(id),
-        lower_bound(lb),
-        model_value(mv)
+  struct QueueComparator {
+    const Model* d_model;
+    QueueComparator(const Model* model)
+      : d_model(model)
     {}
-    
-    bool operator< (const PQueueElement& other) const {
-      return model_value.value > other.model_value.value;
+    bool operator() (TermId left, TermId right) const {
+      Assert (d_model->find(left) != d_model->end() &&
+              d_model->find(right) != d_model->end());
+      
+      return (*(d_model->find(left))).second.value < (*(d_model->find(right))).second.value; 
     }
-    std::string toString() const; 
-  };
-  
+  }; 
+
   typedef __gnu_cxx::hash_map<TNode, ReasonId, TNodeHashFunction> ReasonToIdMap;
   typedef __gnu_cxx::hash_map<TNode, TermId, TNodeHashFunction> TermNodeToIdMap;
 
   typedef std::vector<InequalityEdge> Edges; 
   typedef __gnu_cxx::hash_set<TermId> TermIdSet;
 
-  typedef std::priority_queue<PQueueElement> BFSQueue; 
-  typedef __gnu_cxx::hash_set<TNode, TNodeHashFunction> TNodeSet; 
+  typedef std::priority_queue<TermId, std::vector<TermId>, QueueComparator> BFSQueue; 
+  typedef __gnu_cxx::hash_set<TNode, TNodeHashFunction> TNodeSet;
+  typedef __gnu_cxx::hash_set<Node, NodeHashFunction> NodeSet;
+
   std::vector<InequalityNode> d_ineqNodes;
   std::vector< Edges > d_ineqEdges;
-  
+
+  // to keep the explanation nodes alive
+  NodeSet d_reasonSet; 
   std::vector<TNode> d_reasonNodes;
   ReasonToIdMap d_reasonToIdMap;
   
@@ -125,7 +128,7 @@ class InequalityGraph : public context::ContextNotifyObj{
   std::vector<TNode> d_conflict;
   bool d_signed; 
 
-  context::CDHashMap<TermId, ModelValue>  d_modelValues;
+  Model  d_modelValues;
   void initializeModelValue(TNode node); 
   void setModelValue(TermId term, const ModelValue& mv);
   ModelValue getModelValue(TermId term) const;
@@ -163,23 +166,21 @@ class InequalityGraph : public context::ContextNotifyObj{
   /** 
    * If necessary update the value in the model of the current queue element. 
    * 
-   * @param el current queue element we are updating
+   * @param id current queue element we are updating
    * @param start node we started with, to detect cycles
-   * @param seen 
    * 
    * @return 
    */
-  bool updateValue(const PQueueElement& el, TermId start, const TermIdSet& seen, bool& changed);
+  bool updateValue(TermId id, ModelValue new_mv, TermId start, bool& changed);
   /** 
    * Update the current model starting with the start term. 
    * 
    * @param queue 
    * @param start 
-   * @param seen 
    * 
    * @return 
    */
-  bool computeValuesBFS(BFSQueue& queue, TermId start, TermIdSet& seen);
+  bool processQueue(BFSQueue& queue, TermId start);
   /** 
    * Return the reasons why from <= to. If from is undefined we just
    * explain the current value of to. 
@@ -197,9 +198,10 @@ class InequalityGraph : public context::ContextNotifyObj{
   
   /*** The currently asserted disequalities */
   context::CDQueue<TNode> d_disequalities;
-  context::CDQueue<Node> d_lemmaQueue; 
-  context::CDO<unsigned> d_lemmaIndex; 
-  TNodeSet d_lemmasAdded; 
+  context::CDQueue<TNode> d_disequalitiesToSplit; 
+  context::CDO<unsigned> d_diseqToSplitIndex; 
+  TNodeSet d_lemmasAdded;
+  TNodeSet d_disequalitiesAlreadySplit; 
   
   /** Backtracking mechanisms **/
   std::vector<std::pair<TermId, InequalityEdge> > d_undoStack;
@@ -223,28 +225,72 @@ public:
       d_signed(s),
       d_modelValues(c),
       d_disequalities(c),
-      d_lemmaQueue(c),
-      d_lemmaIndex(c, 0),
-      d_lemmasAdded(),
+      d_disequalitiesToSplit(c),
+      d_diseqToSplitIndex(c, 0),
+      d_disequalitiesAlreadySplit(),
       d_undoStack(),
       d_undoStackIndex(c)
   {}
   /** 
-   * 
+   * Add a new inequality to the graph 
    * 
    * @param a 
    * @param b 
-   * @param diff 
+   * @param strict 
    * @param reason 
    * 
    * @return 
    */
   bool addInequality(TNode a, TNode b, bool strict, TNode reason);
+  /** 
+   * Add a new disequality to the graph. This may lead in a lemma. 
+   * 
+   * @param a 
+   * @param b 
+   * @param reason 
+   * 
+   * @return 
+   */
   bool addDisequality(TNode a, TNode b, TNode reason); 
-  bool areLessThan(TNode a, TNode b);
   void getConflict(std::vector<TNode>& conflict);
   virtual ~InequalityGraph() throw(AssertionException) {}
-  void getNewLemmas(std::vector<TNode>& new_lemmas);
+  /** 
+   * Get any new lemmas (resulting from disequalities splits) that need
+   * to be added. 
+   * 
+   * @param new_lemmas 
+   */
+  void getNewLemmas(std::vector<Node>& new_lemmas);
+  /** 
+   * Check that the currently asserted disequalities that have not been split on
+   * are still true in the current model. 
+   */
+  void checkDisequalities();
+  /** 
+   * Return true if a < b is entailed by the current set of assertions. 
+   * 
+   * @param a 
+   * @param b 
+   * 
+   * @return 
+   */
+  bool isLessThan(TNode a, TNode b);
+  /** 
+   * Returns true if the term has a value in the model (i.e. if we have seen it)
+   * 
+   * @param a 
+   * 
+   * @return 
+   */
+  bool hasValueInModel(TNode a) const;
+  /** 
+   * Return the value of a in the current model. 
+   * 
+   * @param a 
+   * 
+   * @return 
+   */
+  BitVector getValueInModel(TNode a) const; 
 }; 
 
 }
index c442fa6dd2ba4cb9f14e339c932d1607a023130f..00b3526c05b56254b862b333b6106abbbdebfee9 100644 (file)
@@ -91,6 +91,9 @@ public:
   virtual void preRegister(TNode node) {}
   virtual void propagate(Theory::Effort e) {}
   virtual void collectModelInfo(TheoryModel* m) = 0;
+  virtual bool isComplete() = 0;
+  virtual EqualityStatus getEqualityStatus(TNode a, TNode b) = 0;
+  
   bool done() { return d_assertionQueue.size() == d_assertionIndex; }
   TNode get() {
     Assert (!done()); 
@@ -98,8 +101,7 @@ public:
     d_assertionIndex = d_assertionIndex + 1;
     return res; 
   }
-  void assertFact(TNode fact) { d_assertionQueue.push_back(fact); }
-
+  virtual void assertFact(TNode fact) { d_assertionQueue.push_back(fact); }
 }; 
 
 }
index 2f76e32d381e48bb6275c187fb722a3acf7fdcfc..20da2511cedc72985516e46e02a9144c6ec59ea2 100644 (file)
@@ -74,7 +74,7 @@ bool BitblastSolver::check(Theory::Effort e) {
     d_bitblastQueue.pop();
   }
 
-  // Processingssertions  
+  // Processing assertions  
   while (!done()) {
     TNode fact = get(); 
     if (!d_bv->inConflict() && !d_bv->propagatedBy(fact, SUB_BITBLAST)) {
index 318fdd230c6e24cfff92194179740e5533be6a60..47bed07dd920071395b32966fdeeacc8826f71fa 100644 (file)
@@ -46,6 +46,7 @@ public:
   void  explain(TNode literal, std::vector<TNode>& assumptions);
   EqualityStatus getEqualityStatus(TNode a, TNode b);
   void collectModelInfo(TheoryModel* m); 
+  bool isComplete() { return true; }
 };
 
 }
index 868f3754f370583e3a1d084f3a4bf773977a2b78..5eb37b50a757cbd2e9bc081a7f2fb035787f20cf 100644 (file)
@@ -71,7 +71,7 @@ class CoreSolver : public SubtheorySolver {
 public: 
   CoreSolver(context::Context* c, TheoryBV* bv);
   ~CoreSolver();
-  bool  isCoreTheory() { return d_isCoreTheory; }
+  bool  isComplete() { return d_isCoreTheory; }
   void  setMasterEqualityEngine(eq::EqualityEngine* eq);
   void  preRegister(TNode node);
   bool  check(Theory::Effort e);
index 6b9842e8fa561a7784aa0266e16e9f353655ad08..33802668145a853f691afba001c0ef619a26561e 100644 (file)
@@ -61,14 +61,20 @@ bool InequalitySolver::check(Theory::Effort e) {
       ok = d_inequalityGraph.addInequality(a, b, false, fact);
     }
   }
+  
   if (!ok) {
     std::vector<TNode> conflict;
     d_inequalityGraph.getConflict(conflict); 
-    d_bv->setConflict(utils::mkConjunction(conflict));
+    d_bv->setConflict(utils::flattenAnd(conflict));
     return false; 
   }
+
+  // make sure all the disequalities we didn't split on are still satisifed
+  // and split on the ones that are not
+  d_inequalityGraph.checkDisequalities();
+
   // send out any lemmas
-  std::vector<TNode> lemmas;
+  std::vector<Node> lemmas;
   d_inequalityGraph.getNewLemmas(lemmas); 
   for(unsigned i = 0; i < lemmas.size(); ++i) {
     d_bv->lemma(lemmas[i]); 
@@ -76,6 +82,38 @@ bool InequalitySolver::check(Theory::Effort e) {
   return true; 
 }
 
+EqualityStatus InequalitySolver::getEqualityStatus(TNode a, TNode b) {
+  Node a_lt_b = utils::mkNode(kind::BITVECTOR_ULT, a, b);
+  Node b_lt_a = utils::mkNode(kind::BITVECTOR_ULT, b, a);
+
+  // if an inequality containing the terms has been asserted then we know
+  // the equality is false
+  if (d_assertionSet.contains(a_lt_b) || d_assertionSet.contains(b_lt_a)) {
+    return EQUALITY_FALSE; 
+  }
+  
+  if (!d_inequalityGraph.hasValueInModel(a) ||
+      !d_inequalityGraph.hasValueInModel(b)) {
+    return EQUALITY_UNKNOWN; 
+  }
+
+  // TODO: check if this disequality is entailed by inequalities via transitivity
+  
+  BitVector a_val = d_inequalityGraph.getValueInModel(a);
+  BitVector b_val = d_inequalityGraph.getValueInModel(b);
+  
+  if (a_val == b_val) {
+    return EQUALITY_TRUE_IN_MODEL; 
+  } else {
+    return EQUALITY_FALSE_IN_MODEL; 
+  }
+}
+
+void InequalitySolver::assertFact(TNode fact) {
+  d_assertionQueue.push_back(fact);
+  d_assertionSet.insert(fact); 
+}
+
 void InequalitySolver::explain(TNode literal, std::vector<TNode>& assumptions) {
   Assert (false); 
 }
index 07c561c848481ee1cedfdca8e5b855a6a4655eca..6d1d77c7efb0ebb8b9d1c3ed22e2bd2afe017395 100644 (file)
 
 #include "theory/bv/bv_subtheory.h"
 #include "theory/bv/bv_inequality_graph.h"
+#include "context/cdhashset.h"
 
 namespace CVC4 {
 namespace theory {
 namespace bv {
 
 class InequalitySolver: public SubtheorySolver {
+  context::CDHashSet<Node, NodeHashFunction> d_assertionSet; 
   InequalityGraph d_inequalityGraph;
 public:
   
   InequalitySolver(context::Context* c, TheoryBV* bv)
     : SubtheorySolver(c, bv),
+      d_assertionSet(c),
       d_inequalityGraph(c)
   {}
   
   bool check(Theory::Effort e);
   void propagate(Theory::Effort e); 
   void explain(TNode literal, std::vector<TNode>& assumptions);
-  bool isInequalityTheory() { return true; }
-  virtual void collectModelInfo(TheoryModel* m) {}
+  bool isComplete() { return true; }
+  void collectModelInfo(TheoryModel* m) {}
+  EqualityStatus getEqualityStatus(TNode a, TNode b);
+  void assertFact(TNode fact); 
 }; 
 
 }
index bc8e39e67b5562e72e1a66ef1c55dc7ead4a1857..bdf93eadccfe4da22c1ca612a60c05093e3d38ed 100644 (file)
@@ -32,9 +32,6 @@ using namespace CVC4::context;
 using namespace std;
 using namespace CVC4::theory::bv::utils;
 
-
-
-
 TheoryBV::TheoryBV(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo, QuantifiersEngine* qe)
   : Theory(THEORY_BV, c, u, out, valuation, logicInfo, qe),
     d_context(c),
@@ -122,11 +119,11 @@ void TheoryBV::check(Effort e)
   }
   Assert (!ok == inConflict()); 
 
-  if (!inConflict() && !d_coreSolver.isCoreTheory()) {
+  if (!inConflict() && !d_coreSolver.isComplete()) {
     ok = d_inequalitySolver.check(e); 
   }
 
-  Assert (!ok == inConflict());
+  // Assert (!ok == inConflict());
   // if (!inConflict() && !d_coreSolver.isCoreTheory()) {
   // if (!inConflict() && !d_inequalitySolver.isInequalityTheory()) {
   //   ok = d_bitblastSolver.check(e); 
@@ -303,6 +300,9 @@ EqualityStatus TheoryBV::getEqualityStatus(TNode a, TNode b)
   }
 
   EqualityStatus status = d_coreSolver.getEqualityStatus(a, b);
+  if (status == EQUALITY_UNKNOWN) {
+    status = d_inequalitySolver.getEqualityStatus(a, b); 
+  }
   if (status == EQUALITY_UNKNOWN) {
     status = d_bitblastSolver.getEqualityStatus(a, b);
   }