updating the equality engine to be able to give explanations for terms that were...
authorDejan Jovanović <dejan.jovanovic@gmail.com>
Tue, 6 Mar 2012 19:08:32 +0000 (19:08 +0000)
committerDejan Jovanović <dejan.jovanovic@gmail.com>
Tue, 6 Mar 2012 19:08:32 +0000 (19:08 +0000)
and it's a bit better
http://church.cims.nyu.edu/regress-results/compare_jobs.php?job_id=3738&category=&p=-1&reference_id=3731

src/theory/uf/equality_engine.h
src/theory/uf/equality_engine_impl.h

index 41b39af97e92a1b4687ed61bd13698f1360303cd..7314b6552e5083ab4770ecb7c3ef34da1954b777 100644 (file)
@@ -641,14 +641,14 @@ public:
    * imply t1 = t2. Returns TNodes as the assertion equalities should be hashed somewhere
    * else. 
    */
-  void explainEquality(TNode t1, TNode t2, std::vector<TNode>& equalities) const;
+  void explainEquality(TNode t1, TNode t2, std::vector<TNode>& equalities);
 
   /**
    * Get an explanation of the equality t1 = t2. Returns the asserted equalities that
    * imply t1 = t2. Returns TNodes as the assertion equalities should be hashed somewhere
    * else. 
    */
-  void explainDisequality(TNode t1, TNode t2, std::vector<TNode>& equalities) const;
+  void explainDisequality(TNode t1, TNode t2, std::vector<TNode>& equalities);
 
   /**
    * Add term to the trigger terms. The notify class will get notified when two 
index 02426c8499bad794db36ede42a9ba87447a1e2cb..925410561c35492eb231e8e35cf80f2a5a593c24 100644 (file)
@@ -27,6 +27,19 @@ namespace CVC4 {
 namespace theory {
 namespace uf {
 
+class ScopedBool {
+  bool& watch;
+  bool oldValue;
+public:
+  ScopedBool(bool& watch, bool newValue)
+  : watch(watch), oldValue(watch) {
+    watch = newValue;
+  }
+  ~ScopedBool() {
+    watch = oldValue;
+  }
+};
+
 template <typename NotifyClass>
 void EqualityEngine<NotifyClass>::enqueue(const MergeCandidate& candidate) {
     Debug("equality") << "EqualityEngine::enqueue(" << candidate.toString(*this) << ")" << std::endl;
@@ -195,8 +208,11 @@ void EqualityEngine<NotifyClass>::addDisequality(TNode t1, TNode t2, TNode reaso
 
   Debug("equality") << "EqualityEngine::addDisequality(" << t1 << "," << t2 << ")" << std::endl;
 
-  Node equality = t1.eqNode(t2);
-  addEquality(equality, d_false, reason);
+  Node equality1 = t1.eqNode(t2);
+  addEquality(equality1, d_false, reason);
+
+  Node equality2 = t2.eqNode(t1);
+  addEquality(equality2, d_false, reason);
 }
 
 
@@ -494,9 +510,19 @@ std::string EqualityEngine<NotifyClass>::edgesToString(EqualityEdgeId edgeId) co
 }
 
 template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::explainEquality(TNode t1, TNode t2, std::vector<TNode>& equalities) const {
+void EqualityEngine<NotifyClass>::explainEquality(TNode t1, TNode t2, std::vector<TNode>& equalities) {
   Debug("equality") << "EqualityEngine::explainEquality(" << t1 << "," << t2 << ")" << std::endl;
 
+  // Don't notify during this check
+  ScopedBool turnOfNotify(d_performNotify, false);
+
+  // Push the context, so that we can remove the terms later
+  d_context->push();
+
+  // Add the terms (they might not be there)
+  addTerm(t1);
+  addTerm(t2);
+
   Assert(getRepresentative(t1) == getRepresentative(t2),
          "Cannot explain an equality, because the two terms are not equal!\n"
          "The representative of %s\n"
@@ -510,13 +536,28 @@ void EqualityEngine<NotifyClass>::explainEquality(TNode t1, TNode t2, std::vecto
   EqualityNodeId t1Id = getNodeId(t1);
   EqualityNodeId t2Id = getNodeId(t2);
   getExplanation(t1Id, t2Id, equalities);
+
+  // Pop the possible extra information
+  d_context->pop();
 }
 
 template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::explainDisequality(TNode t1, TNode t2, std::vector<TNode>& equalities) const {
+void EqualityEngine<NotifyClass>::explainDisequality(TNode t1, TNode t2, std::vector<TNode>& equalities) {
   Debug("equality") << "EqualityEngine::explainDisequality(" << t1 << "," << t2 << ")" << std::endl;
 
+  // Don't notify during this check
+  ScopedBool turnOfNotify(d_performNotify, false);
+
+  // Push the context, so that we can remove the terms later
+  d_context->push();
+
+  // Add the terms
+  addTerm(t1);
+  addTerm(t2);
+
+  // Add the equality
   Node equality = t1.eqNode(t2);
+  addTerm(equality);
 
   Assert(getRepresentative(equality) == getRepresentative(d_false),
          "Cannot explain the dis-equality, because the two terms are not dis-equal!\n"
@@ -531,6 +572,9 @@ void EqualityEngine<NotifyClass>::explainDisequality(TNode t1, TNode t2, std::ve
   EqualityNodeId equalityId = getNodeId(equality);
   EqualityNodeId falseId = getNodeId(d_false);
   getExplanation(equalityId, falseId, equalities);
+
+  // Pop the possible extra information
+  d_context->pop();
 }
 
 
@@ -714,7 +758,7 @@ void EqualityEngine<NotifyClass>::propagate() {
     Assert(node1.getFind() == t1classId);
     Assert(node2.getFind() == t2classId);
 
-    // Add the actuall equality to the equality graph
+    // Add the actual equality to the equality graph
     addGraphEdge(current.t1Id, current.t2Id, current.type, current.reason);
 
     // One more equality added
@@ -759,19 +803,6 @@ void EqualityEngine<NotifyClass>::debugPrintGraph() const {
   }
 }
 
-class ScopedBool {
-  bool& watch;
-  bool oldValue;
-public:
-  ScopedBool(bool& watch, bool newValue)
-  : watch(watch), oldValue(watch) {
-    watch = newValue;
-  }
-  ~ScopedBool() {
-    watch = oldValue;
-  }
-};
-
 template <typename NotifyClass>
 bool EqualityEngine<NotifyClass>::areEqual(TNode t1, TNode t2)
 {
@@ -807,17 +838,9 @@ bool EqualityEngine<NotifyClass>::areDisequal(TNode t1, TNode t2)
   addTerm(t2);
 
   // Check (t1 = t2) = false
-  Node equality1 = t1.eqNode(t2);
-  addTerm(equality1);
-  if (getEqualityNode(equality1).getFind() == getEqualityNode(d_false).getFind()) {
-    d_context->pop();
-    return true;
-  }
-
-  // Check (t2 = t1) = false
-  Node equality2 = t2.eqNode(t1);
-  addTerm(equality2);
-  if (getEqualityNode(equality2).getFind() == getEqualityNode(d_false).getFind()) {
+  Node equality = t1.eqNode(t2);
+  addTerm(equality);
+  if (getEqualityNode(equality).getFind() == getEqualityNode(d_false).getFind()) {
     d_context->pop();
     return true;
   }