Committing the work on equality engine, I need to see how it does on the regressions...
authorDejan Jovanović <dejan.jovanovic@gmail.com>
Sun, 27 May 2012 05:44:13 +0000 (05:44 +0000)
committerDejan Jovanović <dejan.jovanovic@gmail.com>
Sun, 27 May 2012 05:44:13 +0000 (05:44 +0000)
* areDisequal(x, y) -> areDisequal(x, y, needProof): when asking for a disequality you must say needProof if you will ask for an explanation later.
* propagation of shared dis-equalities (not yet complete, once case missing)
* changes to the theories that use it, authors should check up on the changes

src/theory/arrays/theory_arrays.cpp
src/theory/arrays/theory_arrays.h
src/theory/bv/bv_subtheory_eq.h
src/theory/shared_terms_database.cpp
src/theory/term_registration_visitor.cpp
src/theory/uf/equality_engine.cpp
src/theory/uf/equality_engine.h
src/theory/uf/equality_engine_types.h
src/theory/uf/theory_uf.cpp

index 44e362f906d4090a733665d0659aa83325084e9b..15636fc72fd65d06614391c7b80edb3773bea3a6 100644 (file)
@@ -86,12 +86,19 @@ TheoryArrays::TheoryArrays(context::Context* c, context::UserContext* u, OutputC
   d_true = NodeManager::currentNM()->mkConst<bool>(true);
   d_false = NodeManager::currentNM()->mkConst<bool>(false);
 
+  // The preprocessing congruence kinds
+  d_ppEqualityEngine.addFunctionKind(kind::SELECT);
+  d_ppEqualityEngine.addFunctionKind(kind::STORE);
+
+  // The mayequal congruence kinds
+  d_mayEqualEqualityEngine.addFunctionKind(kind::SELECT);
+  d_mayEqualEqualityEngine.addFunctionKind(kind::STORE);
+
   // The kinds we are treating as function application in congruence
   d_equalityEngine.addFunctionKind(kind::SELECT);
   if (d_ccStore) {
     d_equalityEngine.addFunctionKind(kind::STORE);
   }
-  d_equalityEngine.addFunctionKind(kind::EQUAL);
   if (d_useArrTable) {
     d_equalityEngine.addFunctionKind(kind::ARR_TABLE_FUN);
   }
@@ -118,12 +125,13 @@ TheoryArrays::~TheoryArrays() {
 
 Node TheoryArrays::ppRewrite(TNode term) {
   if (!d_preprocess) return term;
+  d_ppEqualityEngine.addTerm(term);
   switch (term.getKind()) {
     case kind::SELECT: {
       // select(store(a,i,v),j) = select(a,j)
       //    IF i != j
       if (term[0].getKind() == kind::STORE &&
-          (d_ppEqualityEngine.areDisequal(term[0][1], term[1]) ||
+          (d_ppEqualityEngine.areDisequal(term[0][1], term[1], false) ||
            (term[0][1].isConst() && term[1].isConst() && term[0][1] != term[1]))) {
         return NodeBuilder<2>(kind::SELECT) << term[0][0] << term[1];
       }
@@ -134,7 +142,7 @@ Node TheoryArrays::ppRewrite(TNode term) {
       //    IF i != j and j comes before i in the ordering
       if (term[0].getKind() == kind::STORE &&
           (term[1] < term[0][1]) &&
-          (d_ppEqualityEngine.areDisequal(term[1], term[0][1]) ||
+          (d_ppEqualityEngine.areDisequal(term[1], term[0][1], false) ||
            (term[0][1].isConst() && term[1].isConst() && term[0][1] != term[1]))) {
         Node inner = NodeBuilder<3>(kind::STORE) << term[0][0] << term[1] << term[2];
         Node outer = NodeBuilder<3>(kind::STORE) << inner << term[0][1] << term[0][2];
@@ -198,7 +206,7 @@ Node TheoryArrays::ppRewrite(TNode term) {
                 NodeBuilder<> hyp(kind::AND);
                 for (j = leftWrites - 1; j > i; --j) {
                   index_j = write_j[1];
-                  if (d_ppEqualityEngine.areDisequal(index_i, index_j) ||
+                  if (d_ppEqualityEngine.areDisequal(index_i, index_j, false) ||
                       (index_i.isConst() && index_j.isConst() && index_i != index_j)) {
                     continue;
                   }
@@ -374,6 +382,7 @@ void TheoryArrays::preRegisterTerm(TNode node)
   switch (node.getKind()) {
   case kind::EQUAL:
     // Add the trigger for equality
+    // NOTE: note that if the equality is true or false already, it might not be added
     d_equalityEngine.addTriggerEquality(node);
     break;
   case kind::SELECT: {
@@ -385,6 +394,9 @@ void TheoryArrays::preRegisterTerm(TNode node)
     // Reads
     TNode store = d_equalityEngine.getRepresentative(node[0]);
 
+    // The may equal needs the store
+    d_mayEqualEqualityEngine.addTerm(store);
+
     // Apply RIntro1 rule to any stores equal to store if not done already
     const CTNodeList* stores = d_infoMap.getStores(store);
     CTNodeList::const_iterator it = stores->begin();
@@ -460,7 +472,8 @@ void TheoryArrays::preRegisterTerm(TNode node)
     break;
   }
   // Invariant: preregistered terms are exactly the terms in the equality engine
-  Assert(d_equalityEngine.hasTerm(node));
+  // Disabled, see comment above for kind::EQUAL
+  // Assert(d_equalityEngine.hasTerm(node) || !d_equalityEngine.consistent());
 }
 
 
@@ -525,7 +538,7 @@ EqualityStatus TheoryArrays::getEqualityStatus(TNode a, TNode b) {
     // The terms are implied to be equal
     return EQUALITY_TRUE;
   }
-  if (d_equalityEngine.areDisequal(a, b)) {
+  if (d_equalityEngine.areDisequal(a, b, false)) {
     // The terms are implied to be dis-equal
     return EQUALITY_FALSE;
   }
@@ -576,10 +589,10 @@ void TheoryArrays::computeCareGraph()
 
         if (r1[0] != r2[0]) {
           // If arrays are known to be disequal, or cannot become equal, we can continue
-          Assert(d_equalityEngine.hasTerm(r1[0]) && d_equalityEngine.hasTerm(r2[0]));
+          Assert(d_mayEqualEqualityEngine.hasTerm(r1[0]) && d_mayEqualEqualityEngine.hasTerm(r2[0]));
           if (r1[0].getType() != r2[0].getType() ||
               (!d_mayEqualEqualityEngine.areEqual(r1[0], r2[0])) ||
-              d_equalityEngine.areDisequal(r1[0], r2[0])) {
+              d_equalityEngine.areDisequal(r1[0], r2[0], false)) {
             Debug("arrays::sharing") << "TheoryArrays::computeCareGraph(): arrays can't be equal, skipping" << std::endl;
             continue;
           }
@@ -704,7 +717,7 @@ void TheoryArrays::check(Effort e) {
       case kind::NOT:
         if (fact[0].getKind() == kind::SELECT) {
           d_equalityEngine.assertPredicate(fact[0], false, fact);
-        } else if (!d_equalityEngine.areDisequal(fact[0][0], fact[0][1])) {
+        } else if (!d_equalityEngine.areDisequal(fact[0][0], fact[0][1], false)) {
           // Assert the dis-equality
           d_equalityEngine.assertEquality(fact[0], false, fact);
 
@@ -1141,7 +1154,7 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem)
 
   // If propagating, check propagations
   if (d_propagateLemmas) {
-    if (d_equalityEngine.areDisequal(i,j)) {
+    if (d_equalityEngine.areDisequal(i,j,true)) {
       Trace("arrays-lem") << spaces(getSatContext()->getLevel()) <<"Arrays::queueRowLemma: propagating aj = bj ("<<aj<<", "<<bj<<")\n";
       Node aj_eq_bj = aj.eqNode(bj);
       Node i_eq_j = i.eqNode(j);
@@ -1157,7 +1170,7 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem)
       ++d_numProp;
       return;
     }
-    if (bothExist && d_equalityEngine.areDisequal(aj,bj)) {
+    if (bothExist && d_equalityEngine.areDisequal(aj,bj,true)) {
       Trace("arrays-lem") << spaces(getSatContext()->getLevel()) <<"Arrays::queueRowLemma: propagating i = j ("<<i<<", "<<j<<")\n";
       Node aj_eq_bj = aj.eqNode(bj);
       Node i_eq_j = i.eqNode(j);
@@ -1178,7 +1191,7 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem)
   }
 
   // Prefer equality between indexes so as not to introduce new read terms
-  if (d_eagerIndexSplitting && !bothExist && !d_equalityEngine.areDisequal(i,j)) {
+  if (d_eagerIndexSplitting && !bothExist && !d_equalityEngine.areDisequal(i,j, false)) {
     Node split = d_valuation.ensureLiteral(i.eqNode(j));
     d_out->propagateAsDecision(split);
   }
index 639b03df804346d25cc498d8e3492673d3f0ef0f..03d7e7d8daae4822766ce59720f3752dbcf5dc9d 100644 (file)
@@ -228,7 +228,7 @@ class TheoryArrays : public Theory {
     NotifyClass(TheoryArrays& arrays): d_arrays(arrays) {}
 
     bool eqNotifyTriggerEquality(TNode equality, bool value) {
-      Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false") << ")" << std::endl;
+      Debug("arrays::propagate") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false") << ")" << std::endl;
       // Just forward to arrays
       if (value) {
         return d_arrays.propagate(equality);
@@ -242,7 +242,7 @@ class TheoryArrays : public Theory {
     }
 
     bool eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value) {
-      Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 << ")" << std::endl;
+      Debug("arrays::propagate") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 << ")" << std::endl;
       if (value) {
         if (t1.getType().isArray()) {
           d_arrays.mergeArrays(t1, t2);
@@ -259,7 +259,7 @@ class TheoryArrays : public Theory {
     }
 
     bool eqNotifyConstantTermMerge(TNode t1, TNode t2) {
-      Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << ")" << std::endl;
+      Debug("arrays::propagate") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << ")" << std::endl;
       if (Theory::theoryOf(t1) == THEORY_BOOL) {
         return d_arrays.propagate(t1.iffNode(t2));
       } else {
index 241dd591967aa611cdb1e0ce6844b38cd6e7d5aa..356d12a0655382165a195a47c86eca91f9565c49 100644 (file)
@@ -64,7 +64,7 @@ public:
       // The terms are implied to be equal
       return EQUALITY_TRUE;
     }
-    if (d_equalityEngine.areDisequal(a, b)) {
+    if (d_equalityEngine.areDisequal(a, b, false)) {
       // The terms are implied to be dis-equal
       return EQUALITY_FALSE;
     }
index 90037b90bfce386aec517b9d0db0ff5c4667b74c..0c893482a33feaaffbe2f012dcf34901fac602cd 100644 (file)
@@ -297,7 +297,7 @@ bool SharedTermsDatabase::areEqual(TNode a, TNode b) {
 
 
 bool SharedTermsDatabase::areDisequal(TNode a, TNode b) {
-  return d_equalityEngine.areDisequal(a,b);
+  return d_equalityEngine.areDisequal(a,b,false);
 }
 
 void SharedTermsDatabase::processSharedLiteral(TNode literal, TNode reason)
@@ -305,7 +305,7 @@ void SharedTermsDatabase::processSharedLiteral(TNode literal, TNode reason)
   bool negated = literal.getKind() == kind::NOT;
   TNode atom = negated ? literal[0] : literal;
   if (negated) {
-    Assert(!d_equalityEngine.areDisequal(atom[0], atom[1]));
+    Assert(!d_equalityEngine.areDisequal(atom[0], atom[1],false));
     d_equalityEngine.assertEquality(atom, false, reason);
     //    !!! need to send this out
   }
index 099871ceb20fa0855f570db488a96d41f606f902..22b87c32f5c42e9dc2f1480073fd75c5fc9718a7 100644 (file)
@@ -32,7 +32,6 @@ bool PreRegisterVisitor::alreadyVisited(TNode current, TNode parent) {
 
   Debug("register::internal") << "PreRegisterVisitor::alreadyVisited(" << current << "," << parent << ")" << std::endl;
 
-
   TheoryId currentTheoryId = Theory::theoryOf(current);
   TheoryId parentTheoryId = Theory::theoryOf(parent);
 
index e60d52c7ac59fab56e476d6aa44e1d3d07213b6d..72966d0b333fd2f69fe84f6320ab90af64eef98d 100644 (file)
@@ -90,6 +90,7 @@ EqualityEngine::EqualityEngine(context::Context* context, std::string name)
 , d_stats(name)
 , d_triggerDatabaseSize(context, 0)
 , d_triggerTermSetUpdatesSize(context, 0)
+, d_deducedDisequalitiesSize(context, 0)
 {
   init();
 }
@@ -107,6 +108,7 @@ EqualityEngine::EqualityEngine(EqualityEngineNotify& notify, context::Context* c
 , d_stats(name)
 , d_triggerDatabaseSize(context, 0)
 , d_triggerTermSetUpdatesSize(context, 0)
+, d_deducedDisequalitiesSize(context, 0)
 {
   init();
 }
@@ -152,8 +154,8 @@ EqualityNodeId EqualityEngine::newApplicationNode(TNode original, EqualityNodeId
   }
 
   // Add to the use lists
-  d_equalityNodes[t1ClassId].usedIn(funId, d_useListNodes);
-  d_equalityNodes[t2ClassId].usedIn(funId, d_useListNodes);
+  d_equalityNodes[t1ClassId].usedIn<USE_LIST_FUNCTIONS>(funId, d_useListNodes);
+  d_equalityNodes[t2ClassId].usedIn<USE_LIST_FUNCTIONS>(funId, d_useListNodes);
 
   // Return the new id
   Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << ") => " << funId << std::endl;
@@ -205,6 +207,10 @@ void EqualityEngine::addTerm(TNode t) {
     return;
   }
 
+  if (d_done) {
+    return;
+  }
+
   EqualityNodeId result;
 
   if (t.getKind() == kind::EQUAL) {
@@ -270,6 +276,10 @@ void EqualityEngine::assertEqualityInternal(TNode t1, TNode t2, TNode reason) {
 
   Debug("equality") << "EqualityEngine::addEqualityInternal(" << t1 << "," << t2 << ")" << std::endl;
 
+  if (d_done) {
+    return;
+  }
+
   // Add the terms if they are not already in the database
   addTerm(t1);
   addTerm(t2);
@@ -278,27 +288,85 @@ void EqualityEngine::assertEqualityInternal(TNode t1, TNode t2, TNode reason) {
   EqualityNodeId t1Id = getNodeId(t1);
   EqualityNodeId t2Id = getNodeId(t2);
   enqueue(MergeCandidate(t1Id, t2Id, MERGED_THROUGH_EQUALITY, reason));
-
-  propagate();
 }
 
 void EqualityEngine::assertPredicate(TNode t, bool polarity, TNode reason) {
   Debug("equality") << "EqualityEngine::addPredicate(" << t << "," << (polarity ? "true" : "false") << ")" << std::endl;
   Assert(t.getKind() != kind::EQUAL, "Use assertEquality instead");
   assertEqualityInternal(t, polarity ? d_true : d_false, reason);
+  propagate();
 }
 
 void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason) {
   Debug("equality") << "EqualityEngine::addEquality(" << eq << "," << (polarity ? "true" : "false") << std::endl;
   if (polarity) {
+    // If two terms are already equal, don't assert anything
+    if (hasTerm(eq[0]) && hasTerm(eq[1]) && areEqual(eq[0], eq[1])) {
+      return;
+    }
     // Add equality between terms
     assertEqualityInternal(eq[0], eq[1], reason);
+    propagate();
     // Add eq = true for dis-equality propagation
     assertEqualityInternal(eq, d_true, reason);
+    propagate();    
   } else {
+    // If two terms are already dis-equal, don't assert anything
+    if (hasTerm(eq[0]) && hasTerm(eq[1]) && areDisequal(eq[0], eq[1], false)) {
+      return;
+    }
+    
     assertEqualityInternal(eq, d_false, reason);
-    Node eqSymm = eq[1].eqNode(eq[0]);
-    assertEqualityInternal(eqSymm, d_false, reason);
+    propagate();    
+    assertEqualityInternal(eq[1].eqNode(eq[0]), d_false, reason);
+    propagate();
+  
+    if (d_done) {
+      return;
+    }
+  
+    // If we are adding a disequality, notify of the shared term representatives
+    EqualityNodeId a = getNodeId(eq[0]);
+    EqualityNodeId b = getNodeId(eq[1]);
+    EqualityNodeId eqId = getNodeId(eq);
+    TriggerTermSetRef aTriggerRef = d_nodeIndividualTrigger[a];
+    TriggerTermSetRef bTriggerRef = d_nodeIndividualTrigger[b];
+    if (aTriggerRef != +null_set_id && bTriggerRef != +null_set_id) {
+      // The sets of trigger terms
+      TriggerTermSet& aTriggerTerms = getTriggerTermSet(aTriggerRef);
+      TriggerTermSet& bTriggerTerms = getTriggerTermSet(bTriggerRef);
+      // Go through and notify the shared dis-equalities 
+      Theory::Set aTags = aTriggerTerms.tags;           
+      Theory::Set bTags = bTriggerTerms.tags;           
+      TheoryId aTag = Theory::setPop(aTags);
+      TheoryId bTag = Theory::setPop(bTags);
+      int a_i = 0, b_i = 0;
+      while (aTag != THEORY_LAST && bTag != THEORY_LAST) {
+        if (aTag < bTag) {
+          aTag = Theory::setPop(aTags);
+          ++ a_i;                  
+        } else if (aTag > bTag) {
+          bTag = Theory::setPop(bTags);
+          ++ b_i;
+        } else {
+          // Same tags, notify
+          EqualityNodeId aSharedId = aTriggerTerms.triggers[a_i++];
+          EqualityNodeId bSharedId = bTriggerTerms.triggers[b_i++];
+          d_deducedDisequalityReasons.push_back(EqualityPair(aSharedId, a));
+          d_deducedDisequalityReasons.push_back(EqualityPair(bSharedId, b));
+          d_deducedDisequalityReasons.push_back(EqualityPair(eqId, d_falseId));
+          storePropagatedDisequality(d_nodes[aSharedId], d_nodes[bSharedId], 3);
+          // We notify even if the it's already been sent (they are not 
+          // disequal at assertion, and we need to notify for each tag) 
+          if (!d_notify.eqNotifyTriggerTermEquality(aTag, d_nodes[aSharedId], d_nodes[bSharedId], false)) {
+            break;
+          }
+          // Pop the next tags
+          aTag = Theory::setPop(aTags);
+          bTag = Theory::setPop(bTags);
+        }
+      }
+    }
   }
 }
 
@@ -310,26 +378,11 @@ TNode EqualityEngine::getRepresentative(TNode t) const {
   return d_nodes[representativeId];
 }
 
-bool EqualityEngine::areEqual(TNode t1, TNode t2) const {
-  Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ")" << std::endl;
-
-  Assert(hasTerm(t1));
-  Assert(hasTerm(t2));
-
-  // Both following commands are semantically const
-  EqualityNodeId rep1 = getEqualityNode(t1).getFind();
-  EqualityNodeId rep2 = getEqualityNode(t2).getFind();
-
-  Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ") => " << (rep1 == rep2 ? "true" : "false") << std::endl;
-
-  return rep1 == rep2;
-}
-
-bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vector<TriggerId>& triggers) {
+bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vector<TriggerId>& triggersFired) {
 
   Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << ")" << std::endl;
 
-  Assert(triggers.empty());
+  Assert(triggersFired.empty());
 
   ++ d_stats.mergesCount;
 
@@ -338,15 +391,7 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect
 
   // Check for constant merges
   bool isConstant = d_isConstant[class1Id];
-  if (isConstant && d_isConstant[class2Id]) {
-    if (d_performNotify) {
-      if (!d_notify.eqNotifyConstantTermMerge(d_nodes[class1Id], d_nodes[class2Id])) {
-        // Now merge the so that backtracking is OK
-        class1.merge<true>(class2);
-        return false;
-      } 
-    }
-  } 
+
   // Update class2 representative information
   Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating class " << class2Id << std::endl;
   EqualityNodeId currentId = class2Id;
@@ -369,8 +414,19 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect
         trigger.classId = class1Id;
         // If they became the same, call the trigger
         if (otherTrigger.classId == class1Id) {
+          const TriggerInfo& triggerInfo = d_equalityTriggersOriginal[currentTrigger];
+          if (triggerInfo.trigger.getKind() == kind::EQUAL && !triggerInfo.polarity) {
+            TNode equality = triggerInfo.trigger;
+            EqualityNodeId original = getNodeId(equality);
+            d_deducedDisequalityReasons.push_back(EqualityPair(original, d_falseId));
+            if (!storePropagatedDisequality(equality[0], equality[1], 1)) {
+              // Go to the next trigger
+              currentTrigger = trigger.nextTrigger;
+              continue;
+            }
+          }
           // Id of the real trigger is half the internal one
-          triggers.push_back(currentTrigger);
+          triggersFired.push_back(currentTrigger);
         }
       }
 
@@ -393,7 +449,7 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect
       Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating lookups of node " << currentId << std::endl;
  
       // Go through the uselist and check for congruences
-      UseListNodeId currentUseId = currentNode.getUseList();
+      UseListNodeId currentUseId = currentNode.getUseList<USE_LIST_FUNCTIONS>();
       while (currentUseId != null_uselist_id) {
         // Get the node of the use list
         UseListNode& useNode = d_useListNodes[currentUseId];
@@ -422,7 +478,7 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect
                 enqueue(MergeCandidate(funId, d_falseId, MERGED_THROUGH_CONSTANTS, TNode::null()));
               }
             }
-          }
+          }          
         }
    
         // Go to the next one in the use list
@@ -502,6 +558,15 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect
     }  
   }
 
+  // Notify of the constants merge
+  if (isConstant && d_isConstant[class2Id]) {
+    if (d_performNotify) {
+      if (!d_notify.eqNotifyConstantTermMerge(d_nodes[class1Id], d_nodes[class2Id])) {
+        return false;
+      }
+    }
+  }
+
   // Everything fine
   return true;
 }
@@ -610,9 +675,9 @@ void EqualityEngine::backtrack() {
       const FunctionApplication& app = d_applications[i].normalized;
       if (app.isApplication()) {
         // Remove b from use-list
-        getEqualityNode(app.b).removeTopFromUseList(d_useListNodes);
+        getEqualityNode(app.b).removeTopFromUseList<USE_LIST_FUNCTIONS>(d_useListNodes);
         // Remove a from use-list
-        getEqualityNode(app.a).removeTopFromUseList(d_useListNodes);
+        getEqualityNode(app.a).removeTopFromUseList<USE_LIST_FUNCTIONS>(d_useListNodes);
       }
     }
 
@@ -626,6 +691,20 @@ void EqualityEngine::backtrack() {
     d_equalityGraph.resize(d_nodesCount);
     d_equalityNodes.resize(d_nodesCount);
   }
+
+  if (d_deducedDisequalities.size() > d_deducedDisequalitiesSize) {
+    for(int i = d_deducedDisequalities.size() - 1, i_end = (int)d_deducedDisequalitiesSize; i >= i_end; -- i) {
+      EqualityPair pair = d_deducedDisequalities[i];
+      DisequalityReasonRef reason = d_disequalityReasonsMap[pair];
+      // Remove from the map
+      d_disequalityReasonsMap.erase(pair);
+      std::swap(pair.first, pair.second);
+      d_disequalityReasonsMap.erase(pair);
+      // Resize the reasons vector
+      d_deducedDisequalityReasons.resize(reason.mergesStart);
+    }
+    d_deducedDisequalities.resize(d_deducedDisequalitiesSize);
+  }
 }
 
 void EqualityEngine::addGraphEdge(EqualityNodeId t1, EqualityNodeId t2, MergeReasonType type, TNode reason) {
@@ -658,42 +737,35 @@ std::string EqualityEngine::edgesToString(EqualityEdgeId edgeId) const {
   return out.str();
 }
 
-void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, std::vector<TNode>& equalities) {
-  Debug("equality") << "EqualityEngine::explainEquality(" << t1 << "," << t2 << ")" << std::endl;
+void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, std::vector<TNode>& equalities) const {
+  Debug("equality") << "EqualityEngine::explainEquality(" << t1 << ", " << t2 << ", " << (polarity ? "true" : "false") << ")" << std::endl;
 
-  // Don't notify during this check
-  ScopedBool turnOffNotify(d_performNotify, false);
+  // The terms must be there already
+  Assert(hasTerm(t1) && hasTerm(t2));;
 
-  // Add the terms (they might not be there)
-  addTerm(t1);
-  addTerm(t2);
+  // Get the ids
+  EqualityNodeId t1Id = getNodeId(t1);
+  EqualityNodeId t2Id = getNodeId(t2);
 
   if (polarity) {
     // Get the explanation
-    EqualityNodeId t1Id = getNodeId(t1);
-    EqualityNodeId t2Id = getNodeId(t2);
     getExplanation(t1Id, t2Id, equalities);
   } else {
-    // Add the equality
-    Node equality = t1.eqNode(t2);
-    addTerm(equality);
-
-    // Get the explanation
-    EqualityNodeId equalityId = getNodeId(equality);
-    EqualityNodeId falseId = getNodeId(d_false);
-    getExplanation(equalityId, falseId, equalities);
+    // Get the reason for this disequality
+    EqualityPair pair(t1Id, t2Id);
+    Assert(d_disequalityReasonsMap.find(pair) != d_disequalityReasonsMap.end(), "Don't ask for stuff I didn't notify you about");
+    DisequalityReasonRef reasonRef = d_disequalityReasonsMap.find(pair)->second;
+    for (unsigned i = reasonRef.mergesStart; i < reasonRef.mergesEnd; ++ i) {
+      EqualityPair toExplain = d_deducedDisequalityReasons[i];
+      getExplanation(toExplain.first, toExplain.second, equalities);
+    }
   }
 }
 
-void EqualityEngine::explainPredicate(TNode p, bool polarity, std::vector<TNode>& assertions) {
-  Debug("equality") << "EqualityEngine::explainEquality(" << p << ")" << std::endl;
-
-  // Don't notify during this check
-  ScopedBool turnOffNotify(d_performNotify, false);
-
-  // Add the terms
-  addTerm(p);
-
+void EqualityEngine::explainPredicate(TNode p, bool polarity, std::vector<TNode>& assertions) const {
+  Debug("equality") << "EqualityEngine::explainPredicate(" << p << ")" << std::endl;
+  // Must have the term
+  Assert(hasTerm(p));
   // Get the explanation
   getExplanation(getNodeId(p), polarity ? d_trueId : d_falseId, assertions);
 }
@@ -702,11 +774,16 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, st
 
   Debug("equality") << "EqualityEngine::getExplanation(" << d_nodes[t1Id] << "," << d_nodes[t2Id] << ")" << std::endl;
 
-  // We can only explain the nodes that got merged (or between 
-  // constants since they didn't get merged but we stil added the 
-  // edge in the graph equality 
-  Assert(getEqualityNode(t1Id).getFind() == getEqualityNode(t2Id).getFind() ||
-         (d_isConstant[getEqualityNode(t1Id).getFind()] && d_isConstant[getEqualityNode(t2Id).getFind()]));
+  // We can only explain the nodes that got merged
+#ifdef CVC4_ASSERTIONS
+  bool canExplain = getEqualityNode(t1Id).getFind() == getEqualityNode(t2Id).getFind();
+  if (!canExplain) {
+    Warning() << "Can't explain equality:" << std::endl;
+    Warning() << d_nodes[t1Id] << " with find " << d_nodes[getEqualityNode(t1Id).getFind()] << std::endl;
+    Warning() << d_nodes[t2Id] << " with find " << d_nodes[getEqualityNode(t2Id).getFind()] << std::endl;    
+  }
+  Assert(canExplain);
+#endif
 
   // If the nodes are the same, we're done
   if (t1Id == t2Id) return;
@@ -778,7 +855,7 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, st
               equalities.push_back(d_equalityEdges[currentEdge].getReason());
               break;
             case MERGED_THROUGH_CONSTANTS: {
-              // (a = b) == false bacause a and b are different constants
+              // (a = b) == false because a and b are different constants
               Debug("equality") << "EqualityEngine::getExplanation(): due to constants, going deeper" << std::endl;
               EqualityNodeId eqId = currentNode == d_falseId ? edgeNode : currentNode;
               const FunctionApplication& eq = d_applications[eqId].original;
@@ -822,8 +899,34 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, st
 
 void EqualityEngine::addTriggerEquality(TNode eq) {
   Assert(eq.getKind() == kind::EQUAL);
+
+  if (d_done) {
+    return;
+  }
+
   // Add the terms
+  addTerm(eq[0]);
+  addTerm(eq[1]);
+
+  bool skipTrigger = false;
+
+  // If they are equal or disequal already, no need for the trigger
+  if (areEqual(eq[0], eq[1])) {
+    d_notify.eqNotifyTriggerEquality(eq, true);
+    skipTrigger = true;
+  }
+  if (areDisequal(eq[0], eq[1], true)) {
+    d_notify.eqNotifyTriggerEquality(eq, false);
+    skipTrigger = true;
+  }
+
+  if (skipTrigger) {
+    return;
+  }
+
+  // Add the equality
   addTerm(eq);
+
   // Positive trigger
   addTriggerEqualityInternal(eq[0], eq[1], eq, true);
   // Negative trigger
@@ -833,8 +936,30 @@ void EqualityEngine::addTriggerEquality(TNode eq) {
 void EqualityEngine::addTriggerPredicate(TNode predicate) {
   Assert(predicate.getKind() != kind::NOT && predicate.getKind() != kind::EQUAL);
   Assert(d_congruenceKinds.tst(predicate.getKind()), "No point in adding non-congruence predicates");
+
+  if (d_done) {
+    return;
+  }
+
   // Add the term
   addTerm(predicate);
+
+  bool skipTrigger = false;
+
+  // If it's know already, no need for the trigger
+  if (areEqual(predicate, d_true)) {
+    d_notify.eqNotifyTriggerPredicate(predicate, true);
+    skipTrigger = true;
+  }
+  if (areEqual(predicate, d_false)) {
+    d_notify.eqNotifyTriggerPredicate(predicate, false);
+    skipTrigger = true;
+  }
+
+  if (skipTrigger) {
+    return;
+  }
+
   // Positive trigger
   addTriggerEqualityInternal(predicate, d_true, predicate, true);
   // Negative trigger
@@ -848,41 +973,41 @@ void EqualityEngine::addTriggerEqualityInternal(TNode t1, TNode t2, TNode trigge
   Assert(hasTerm(t1));
   Assert(hasTerm(t2));
 
+  if (d_done) {
+    return;
+  }
+
   // Get the information about t1
   EqualityNodeId t1Id = getNodeId(t1);
   EqualityNodeId t1classId = getEqualityNode(t1Id).getFind();
+  // We will attach it to the class representative, since then we know how to backtrack it
   TriggerId t1TriggerId = d_nodeTriggers[t1classId];
 
   // Get the information about t2
   EqualityNodeId t2Id = getNodeId(t2);
   EqualityNodeId t2classId = getEqualityNode(t2Id).getFind();
+  // We will attach it to the class representative, since then we know how to backtrack it
   TriggerId t2TriggerId = d_nodeTriggers[t2classId];
 
   Debug("equality") << "EqualityEngine::addTrigger(" << trigger << "): " << t1Id << " (" << t1classId << ") = " << t2Id << " (" << t2classId << ")" << std::endl;
 
   // Create the triggers
   TriggerId t1NewTriggerId = d_equalityTriggers.size();
-  TriggerId t2NewTriggerId = t1NewTriggerId | 1;
   d_equalityTriggers.push_back(Trigger(t1classId, t1TriggerId));
   d_equalityTriggersOriginal.push_back(TriggerInfo(trigger, polarity));
+  TriggerId t2NewTriggerId = d_equalityTriggers.size();
   d_equalityTriggers.push_back(Trigger(t2classId, t2TriggerId));
   d_equalityTriggersOriginal.push_back(TriggerInfo(trigger, polarity));
 
   // Update the counters
-  d_equalityTriggersCount = d_equalityTriggersCount + 2;
+  d_equalityTriggersCount = d_equalityTriggers.size();
+  Assert(d_equalityTriggers.size() == d_equalityTriggersOriginal.size());
+  Assert(d_equalityTriggers.size() % 2 == 0);
 
   // Add the trigger to the trigger graph
   d_nodeTriggers[t1classId] = t1NewTriggerId;
   d_nodeTriggers[t2classId] = t2NewTriggerId;
 
-  // If the trigger is already on, we propagate
-  if (t1classId == t2classId) {
-    Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << "): triggered at setup time" << std::endl;
-    if (d_performNotify) {
-      d_notify.eqNotifyTriggerEquality(trigger, polarity); // Don't care about the return value
-    }
-  }
-
   if (Debug.isOn("equality::internal")) {
     debugPrintGraph();
   }
@@ -978,39 +1103,105 @@ void EqualityEngine::debugPrintGraph() const {
   }
 }
 
-bool EqualityEngine::areEqual(TNode t1, TNode t2)
-{
-  // Don't notify during this check
-  ScopedBool turnOffNotify(d_performNotify, false);
+bool EqualityEngine::areEqual(TNode t1, TNode t2) const {
+  Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ")" << std::endl;
 
-  // Add the terms
-  addTerm(t1);
-  addTerm(t2);
-  bool equal = getEqualityNode(t1).getFind() == getEqualityNode(t2).getFind();
+  Assert(hasTerm(t1));
+  Assert(hasTerm(t2));
 
-  // Return whether the two terms are equal
-  return equal;
+  return getEqualityNode(t1).getFind() == getEqualityNode(t2).getFind();
 }
 
-bool EqualityEngine::areDisequal(TNode t1, TNode t2)
+bool EqualityEngine::areDisequal(TNode t1, TNode t2, bool ensureProof) const
 {
-  // Don't notify during this check
-  ScopedBool turnOffNotify(d_performNotify, false);
+  Debug("equality") << "EqualityEngine::areDisequal(" << t1 << "," << t2 << ")" << std::endl;
 
   // Add the terms
-  addTerm(t1);
-  addTerm(t2);
+  Assert(hasTerm(t1));
+  Assert(hasTerm(t2));
+
+  // Get ids
+  EqualityNodeId t1Id = getNodeId(t1);
+  EqualityNodeId t2Id = getNodeId(t2);
 
-  // Check (t1 = t2) = false
-  // No need to check the symmetric version: we can only deduce a disequality from an existing
-  // diseqality, and each of those is asserted in the symmetric version also
-  Node equality = t1.eqNode(t2);
-  addTerm(equality);
-  if (getEqualityNode(equality).getFind() == getEqualityNode(d_false).getFind()) {
+  // Get equivalence classes
+  EqualityNodeId t1ClassId = getEqualityNode(t1Id).getFind();
+  EqualityNodeId t2ClassId = getEqualityNode(t2Id).getFind();
+
+  // We are semantically const, for remembering stuff
+  EqualityEngine* nonConst = const_cast<EqualityEngine*>(this);
+
+  // Check for constants
+  if (d_isConstant[t1ClassId] && d_isConstant[t2ClassId] && t1ClassId != t2ClassId) {
+    if (ensureProof) {
+      nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(t1Id, t1ClassId));
+      nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(t2Id, t2ClassId));
+      storePropagatedDisequality(t1, t2, 2);
+    }
     return true;
   }
 
-  // Return whether the terms are disequal
+  // Check the equality itself if it exists
+  Node eq = t1.eqNode(t2);
+  if (hasTerm(eq)) {
+    if (getEqualityNode(eq).getFind() == getEqualityNode(d_falseId).getFind()) {
+      if (ensureProof) {
+        nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(getNodeId(eq), d_falseId));
+        storePropagatedDisequality(t1, t2, 1);
+      }
+      return true;
+    }
+  }
+  // Check the other equality itself if it exists
+  eq = t2.eqNode(t1);
+  if (hasTerm(eq)) {
+    if (getEqualityNode(eq).getFind() == getEqualityNode(d_false).getFind()) {
+      if (ensureProof) {
+        nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(getNodeId(eq), d_falseId));
+        storePropagatedDisequality(t1, t2, 1);
+      }
+      return true;
+    }
+  }
+  
+  // Create the equality
+  FunctionApplication eqNormalized(true, t1ClassId, t2ClassId);
+  ApplicationIdsMap::const_iterator find = d_applicationLookup.find(eqNormalized);
+  if (find != d_applicationLookup.end()) {
+    if (getEqualityNode(find->second).getFind() == getEqualityNode(d_falseId).getFind()) {
+      if (ensureProof) {
+        const FunctionApplication original = d_applications[find->second].original;
+        nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(t1Id, original.a));
+        nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(t2Id, original.b));
+        nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(original.a, t1ClassId));
+        nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(original.b, t2ClassId));
+        nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(find->second, d_falseId));
+        storePropagatedDisequality(t1, t2, 5);
+      }
+      return true;
+    }
+  }
+  
+  // Check the symmetric disequality
+  std::swap(eqNormalized.a, eqNormalized.b);
+  find = d_applicationLookup.find(eqNormalized);
+  if (find != d_applicationLookup.end()) {
+    if (getEqualityNode(find->second).getFind() == getEqualityNode(d_falseId).getFind()) {
+      if (ensureProof) {
+        const FunctionApplication original = d_applications[find->second].original;
+        nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(t2Id, original.a));
+        nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(t1Id, original.b));
+        nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(original.a, t2ClassId));
+        nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(original.b, t1ClassId));
+        nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(find->second, d_falseId));
+        storePropagatedDisequality(t1, t2, 5);
+      }
+      return true;
+    }
+  }
+    
+  // Couldn't deduce dis-equalityReturn whether the terms are disequal
   return false;
 }
 
@@ -1026,6 +1217,10 @@ void EqualityEngine::addTriggerTerm(TNode t, TheoryId tag)
   Debug("equality::internal") << "EqualityEngine::addTriggerTerm(" << t << ", " << tag << ")" << std::endl;
   Assert(tag != THEORY_LAST);
 
+  if (d_done) {
+    return;
+  }
+
   // Add the term if it's not already there
   addTerm(t);
 
@@ -1040,7 +1235,9 @@ void EqualityEngine::addTriggerTerm(TNode t, TheoryId tag)
     // If the term already is in the equivalence class that a tagged representative, just notify
     if (d_performNotify) {
       EqualityNodeId triggerId = getTriggerTermSet(triggerSetRef).getTrigger(tag);
-      d_notify.eqNotifyTriggerTermEquality(tag, t, d_nodes[triggerId], true);
+      if (!d_notify.eqNotifyTriggerTermEquality(tag, t, d_nodes[triggerId], true)) {
+        d_done = true;
+      }
     }
   } else {
 
@@ -1117,7 +1314,7 @@ void EqualityEngine::getUseListTerms(TNode t, std::set<TNode>& output) {
       // Get the current node
       EqualityNode& currentNode = getEqualityNode(currentId);
       // Go through the use-list
-      UseListNodeId currentUseId = currentNode.getUseList();
+      UseListNodeId currentUseId = currentNode.getUseList<USE_LIST_FUNCTIONS>();
       while (currentUseId != null_uselist_id) {
         // Get the node of the use list
         UseListNode& useNode = d_useListNodes[currentUseId];
@@ -1157,6 +1354,41 @@ EqualityEngine::TriggerTermSetRef EqualityEngine::newTriggerTermSet() {
   return newTriggerSetRef;
 }
 
+bool EqualityEngine::storePropagatedDisequality(TNode lhs, TNode rhs, unsigned reasonsCount) const {
+
+  Assert(reasonsCount > 0);
+
+
+  EqualityNodeId lhsId = getNodeId(lhs);
+  EqualityNodeId rhsId = getNodeId(rhs);
+
+  // We are semantically const, just remembering stuff for later
+  EqualityEngine* nonConst = const_cast<EqualityEngine*>(this);
+
+  Assert(d_deducedDisequalityReasons.size() >= reasonsCount);
+  DisequalityReasonRef ref(d_deducedDisequalityReasons.size() - reasonsCount, d_deducedDisequalityReasons.size());
+
+#ifdef CVC4_ASSERTIONS
+  for (unsigned i = ref.mergesStart; i < ref.mergesEnd; ++ i) {
+    Assert(getEqualityNode(d_deducedDisequalityReasons[i].first).getFind() == getEqualityNode(d_deducedDisequalityReasons[i].second).getFind());
+  }
+#endif
+
+  EqualityPair pair(lhsId, rhsId);
+  DisequalityReasonsMap::const_iterator find = d_disequalityReasonsMap.find(pair);
+  if (find == d_disequalityReasonsMap.end()) {
+    nonConst->d_disequalityReasonsMap[pair] = ref;
+    nonConst->d_deducedDisequalities.push_back(pair);
+    nonConst->d_deducedDisequalitiesSize = d_deducedDisequalities.size();
+    std::swap(pair.first, pair.second);
+    nonConst->d_disequalityReasonsMap[pair] = ref;
+    return true;
+  } else {
+    nonConst->d_deducedDisequalities.resize(d_deducedDisequalitiesSize);
+    return false;
+  }
+}
+
 
 } // Namespace uf
 } // Namespace theory
index 8fc57eb489d39da199e08e9ea5906d4ad4f941cb..5ff8ee4dc3e7adcad546d4f77ce673302a1f3b81 100644 (file)
@@ -187,20 +187,6 @@ private:
   /** A context-dependents count of nodes */
   context::CDO<DefaultSizeType> d_nodesCount;
 
-  /**
-   * At time of addition a function application can already normalize to something, so
-   * we keep both the original, and the normalized version.
-   */
-  struct FunctionApplicationPair {
-    FunctionApplication original;
-    FunctionApplication normalized;
-    FunctionApplicationPair() {}
-    FunctionApplicationPair(const FunctionApplication& original, const FunctionApplication& normalized)
-    : original(original), normalized(normalized) {}
-    bool isNull() const {
-      return !original.isApplication();
-    }
-  };
   /** Map from ids to the applications */
   std::vector<FunctionApplicationPair> d_applications;
 
@@ -336,16 +322,6 @@ private:
    */
   std::vector<Trigger> d_equalityTriggers;
 
-  struct TriggerInfo {
-    /** The trigger itself */
-    Node trigger;
-    /** Polarity of the trigger */
-    bool polarity;
-    TriggerInfo() {}
-    TriggerInfo(Node trigger, bool polarity)
-    : trigger(trigger), polarity(polarity) {}
-  };
-
   /**
    * Vector of original equalities of the triggers.
    */
@@ -515,6 +491,31 @@ private:
    */
   std::vector<TriggerTermSetRef> d_nodeIndividualTrigger;
 
+  typedef std::hash_map<EqualityPair, DisequalityReasonRef, EqualityPairHashFunction> DisequalityReasonsMap;
+
+  /**
+   * A map from pairs of disequal terms, to the reason why we deduced they are disequal.
+   */
+  DisequalityReasonsMap d_disequalityReasonsMap;
+
+  /**
+   * A list of all the disequalities we deduced.
+   */
+  std::vector<EqualityPair> d_deducedDisequalities;
+
+  /**
+   * Context dependent size of the deduced disequalities
+   */
+  context::CDO<size_t> d_deducedDisequalitiesSize;
+
+  /**
+   * For each disequality deduced, we add the pairs of equivalences needed to explain it.
+   */
+  std::vector<EqualityPair> d_deducedDisequalityReasons;
+
+
+  bool storePropagatedDisequality(TNode lhs, TNode rhs, unsigned reasonsCount) const;
+
 public:
 
   /**
@@ -593,24 +594,19 @@ public:
    */
   void getUseListTerms(TNode t, std::set<TNode>& output);
 
-  /**
-   * Returns true if the two nodes are in the same equivalence class.
-   */
-  bool areEqual(TNode t1, TNode t2) const;
-
   /**
    * Get an explanation of the equality t1 = t2 begin true of false. 
    * Returns the reasons (added when asserting) that imply it
    * in the assertions vector.
    */
-  void explainEquality(TNode t1, TNode t2, bool polarity, std::vector<TNode>& assertions);
+  void explainEquality(TNode t1, TNode t2, bool polarity, std::vector<TNode>& assertions) const;
 
   /**
    * Get an explanation of the predicate being true or false. 
    * Returns the reasons (added when asserting) that imply imply it
    * in the assertions vector.
    */
-  void explainPredicate(TNode p, bool polarity, std::vector<TNode>& assertions);
+  void explainPredicate(TNode p, bool polarity, std::vector<TNode>& assertions) const;
 
   /**
    * Add term to the set of trigger terms with a corresponding tag. The notify class will get
@@ -652,14 +648,14 @@ public:
   void addTriggerPredicate(TNode predicate);
 
   /**
-   * Check whether the two terms are equal.
+   * Returns true if the two are currently in the database and equal.
    */
-  bool areEqual(TNode t1, TNode t2);
+  bool areEqual(TNode t1, TNode t2) const;
 
   /**
    * Check whether the two term are dis-equal.
    */
-  bool areDisequal(TNode t1, TNode t2);
+  bool areDisequal(TNode t1, TNode t2, bool ensureProof) const;
 
   /**
    * Return the number of nodes in the equivalence class containing t
@@ -667,6 +663,11 @@ public:
    */
   size_t getSize(TNode t);
 
+  /**
+   * Returns true if the engine is in a consistents state.
+   */
+  bool consistent() const { return !d_done; }
+
 };
 
 } // Namespace uf
index a0d84a1edea219d872d7223993504b247c7301d4..0baf70fcf29c56aed59d52e49410961ac129d2d2 100644 (file)
@@ -98,6 +98,16 @@ struct MergeCandidate {
     t1Id(x), t2Id(y), type(type), reason(reason) {}
 };
 
+/**
+ * Just an index into the reasons array, and the number of merges to consume.
+ */
+struct DisequalityReasonRef {
+  DefaultSizeType mergesStart;
+  DefaultSizeType mergesEnd;
+  DisequalityReasonRef(DefaultSizeType mergesStart = 0, DefaultSizeType mergesEnd = 0)
+  : mergesStart(mergesStart), mergesEnd(mergesEnd) {}
+};
+
 /** 
  * We mantaint uselist where a node appears in, and this is the node
  * of such a list. 
@@ -135,6 +145,15 @@ public:
   }
 };
 
+/** Main types of uselists */
+enum UseListType {
+  /** Use list of functions where the term appears in */
+  USE_LIST_FUNCTIONS,
+  /** Use list of asserted disequalities */
+  USE_LIST_DISEQUALITIES
+};
+
+
 /**
  * Main class for representing nodes in the equivalence class. The 
  * nodes are a circular list, with the representative carrying the
@@ -159,23 +178,28 @@ private:
   /** The use list of this node */
   UseListNodeId d_useList;
 
+  /** The list of asserted disequalities that this node appears in */
+  UseListNodeId d_diseqList;
+
 public:
 
   /**
    * Creates a new node, which is in a list of it's own.
    */
   EqualityNode(EqualityNodeId nodeId = null_id)
-  : d_size(1), 
-    d_findId(nodeId), 
-    d_nextId(nodeId), 
-    d_useList(null_uselist_id)
+  : d_size(1)
+  , d_findId(nodeId) 
+  , d_nextId(nodeId)
+  , d_useList(null_uselist_id)
+  , d_diseqList(null_uselist_id)
   {}
 
   /**
-   * Returns the function uselist.
+   * Returns the requested uselist.
    */
+  template<UseListType type>
   UseListNodeId getUseList() const {
-    return d_useList;
+    return type == USE_LIST_FUNCTIONS ? d_useList : d_diseqList;
   }
 
   /**
@@ -220,24 +244,38 @@ public:
    * Note that this node is used in a function application funId, or
    * a negatively asserted equality (dis-equality) with funId. 
    */
-  template<typename memory_class>
+  template<UseListType type, typename memory_class>
   void usedIn(EqualityNodeId funId, memory_class& memory) {
+    UseListNodeId& useList = type == USE_LIST_FUNCTIONS ? d_useList : d_diseqList;
     UseListNodeId newUseId = memory.size();
-    memory.push_back(UseListNode(funId, d_useList));
-    d_useList = newUseId;
+    memory.push_back(UseListNode(funId, useList));
+    useList = newUseId;
   }
 
   /**
    * For backtracking: remove the first element from the uselist and pop the memory.
    */
-  template<typename memory_class>
+  template<UseListType type, typename memory_class>
   void removeTopFromUseList(memory_class& memory) {
-    Assert ((int)d_useList == (int)memory.size() - 1);
-    d_useList = memory.back().getNext();
+    UseListNodeId& useList = type == USE_LIST_FUNCTIONS ? d_useList : d_diseqList;
+    Assert ((int) useList == (int)memory.size() - 1);
+    useList = memory.back().getNext();
     memory.pop_back();
   }
 };
 
+/** A pair of ids */
+typedef std::pair<EqualityNodeId, EqualityNodeId> EqualityPair;
+
+struct EqualityPairHashFunction {
+  size_t operator () (const EqualityPair& pair) const {
+    size_t hash = 0;
+    hash = 0x9e3779b9 + pair.first;
+    hash ^= 0x9e3779b9 + pair.second + (hash << 6) + (hash >> 2);
+    return hash;
+  }
+};
+
 /**
  * Represents the function APPLY a b. If isEquality is true then it
  * represents the predicate (a = b). Note that since one can not 
@@ -266,6 +304,35 @@ struct FunctionApplicationHashFunction {
   }
 };
 
+/**
+ * At time of addition a function application can already normalize to something, so
+ * we keep both the original, and the normalized version.
+ */
+struct FunctionApplicationPair {
+  FunctionApplication original;
+  FunctionApplication normalized;
+  FunctionApplicationPair() {}
+  FunctionApplicationPair(const FunctionApplication& original, const FunctionApplication& normalized)
+  : original(original), normalized(normalized) {}
+  bool isNull() const {
+    return !original.isApplication();
+  }
+};
+
+/**
+ * Information about the added triggers.
+ */
+struct TriggerInfo {
+  /** The trigger itself */
+  Node trigger;
+  /** Polarity of the trigger */
+  bool polarity;
+  TriggerInfo() {}
+  TriggerInfo(Node trigger, bool polarity)
+  : trigger(trigger), polarity(polarity) {}
+  };
+
+
 } // namespace eq
 } // namespace theory
 } // namespace CVC4
index 9c1229f80b105abda050fabacaac095ea2d84531..ae8bdc8dafa82c16d37d3c2a525c810d22884e08 100644 (file)
@@ -332,7 +332,7 @@ EqualityStatus TheoryUF::getEqualityStatus(TNode a, TNode b) {
   }
 
   // Check for disequality
-  if (d_equalityEngine.areDisequal(a, b)) {
+  if (d_equalityEngine.areDisequal(a, b, false)) {
     // The terms are implied to be dis-equal
     return EQUALITY_FALSE;
   }
@@ -388,16 +388,14 @@ void TheoryUF::computeCareGraph() {
 
           Debug("uf::sharing") << "TheoryUf::computeCareGraph(): checking arguments " << x << " and " << y << std::endl;
 
-          EqualityStatus eqStatusUf = getEqualityStatus(x, y);
-
-          if (eqStatusUf == EQUALITY_FALSE) {
+          if (d_equalityEngine.areDisequal(x, y, false)) {
             // Mark that there is a dis-equal pair and break
             somePairIsDisequal = true;
             Debug("uf::sharing") << "TheoryUf::computeCareGraph(): disequal, disregarding all" << std::endl;
             break;
           }
 
-          if (eqStatusUf == EQUALITY_TRUE) {
+          if (d_equalityEngine.areEqual(x, y)) {
             // We don't need this one
             Debug("uf::sharing") << "TheoryUf::computeCareGraph(): equal" << std::endl;
             continue;