some uf cleanup
authorDejan Jovanović <dejan.jovanovic@gmail.com>
Tue, 23 Aug 2011 23:43:01 +0000 (23:43 +0000)
committerDejan Jovanović <dejan.jovanovic@gmail.com>
Tue, 23 Aug 2011 23:43:01 +0000 (23:43 +0000)
src/theory/uf/equality_engine.h
src/theory/uf/equality_engine_impl.h

index ac02fe4db90955e9c12688f43258751a47c63a0a..bb5c4694550e5379f324f5cd0ce13ff663c32fb2 100644 (file)
@@ -201,7 +201,7 @@ public:
 };
 
 template <typename NotifyClass>
-class EqualityEngine {
+class EqualityEngine : public context::ContextNotifyObj {
 
 public:
 
@@ -213,26 +213,21 @@ public:
     IntStat termsCount;
     /** Number of function terms managed by the system */
     IntStat functionTermsCount;
-    /** Number of times we performed a backtrack */
-    IntStat backtracksCount;
 
     Statistics(std::string name)
     : mergesCount(name + "::mergesCount", 0),
       termsCount(name + "::termsCount", 0),
-      functionTermsCount(name + "::functionTermsCount", 0),
-      backtracksCount(name + "::backtracksCount", 0)
+      functionTermsCount(name + "::functionTermsCount", 0)
     {
       StatisticsRegistry::registerStat(&mergesCount);
       StatisticsRegistry::registerStat(&termsCount);
       StatisticsRegistry::registerStat(&functionTermsCount);
-      StatisticsRegistry::registerStat(&backtracksCount);
     }
 
     ~Statistics() {
       StatisticsRegistry::unregisterStat(&mergesCount);
       StatisticsRegistry::unregisterStat(&termsCount);
       StatisticsRegistry::unregisterStat(&functionTermsCount);
-      StatisticsRegistry::unregisterStat(&backtracksCount);
     }
   };
 
@@ -374,9 +369,15 @@ private:
   /** Returns the equality node of the given node */
   EqualityNode& getEqualityNode(TNode node);
 
+  /** Returns the equality node of the given node */
+  const EqualityNode& getEqualityNode(TNode node) const;
+
   /** Returns the equality node of the given node */
   EqualityNode& getEqualityNode(EqualityNodeId nodeId);
 
+  /** Returns the equality node of the given node */
+  const EqualityNode& getEqualityNode(EqualityNodeId nodeId) const;
+
   /** Returns the id of the node */
   EqualityNodeId getNodeId(TNode node) const;
 
@@ -470,8 +471,8 @@ private:
   /** Enqueue to the propagation queue */
   void enqueue(const MergeCandidate& candidate);
 
-  /** Do the propagation (if check is on, congruences are checked again) */
-  void propagate(bool check);
+  /** Do the propagation */
+  void propagate();
 
   /**
    * Get an explanation of the equality t1 = t2. Returns the asserted equalities that
@@ -483,7 +484,7 @@ private:
   /**
    * Print the equality graph.
    */
-  void debugPrintGraph();
+  void debugPrintGraph() const;
 
 public:
 
@@ -492,12 +493,24 @@ public:
    * the owner information.
    */
   EqualityEngine(NotifyClass& notify, context::Context* context, std::string name)
-  : d_notify(notify), d_assertedEqualitiesCount(context, 0), d_stats(name) {
+  : ContextNotifyObj(context), d_notify(notify), d_assertedEqualitiesCount(context, 0), d_stats(name) {
     Debug("equality") << "EqualityEdge::EqualityEngine(): id_null = " << +null_id << std::endl;
     Debug("equality") << "EqualityEdge::EqualityEngine(): edge_null = " << +null_edge << std::endl;
     Debug("equality") << "EqualityEdge::EqualityEngine(): trigger_null = " << +null_trigger << std::endl;
   }
 
+  /**
+   * Just a destructor.
+   */
+  virtual ~EqualityEngine() throw(AssertionException) {}
+
+  /**
+   * This method gets called on backtracks from the context manager.
+   */
+  void notify() {
+    backtrack();
+  }
+
   /**
    * Adds a term to the term database.
    */
index a19ec8d66736b6762bb9f6685bb3780452940acd..cc73e1917a4008418429c8461c2328a2a8ef6d23 100644 (file)
@@ -59,7 +59,7 @@ EqualityNodeId EqualityEngine<NotifyClass>::newApplicationNode(TNode original, E
     // If it's there, we need to merge these two
     Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << "): lookup exists, adding to queue" << std::endl;
     enqueue(MergeCandidate(funId, find->second, MERGED_THROUGH_CONGRUENCE, TNode::null()));
-    propagate(false);
+    propagate();
   }
 
   // Add to the use lists
@@ -155,21 +155,29 @@ EqualityNode& EqualityEngine<NotifyClass>::getEqualityNode(EqualityNodeId nodeId
   return d_equalityNodes[nodeId];
 }
 
+template <typename NotifyClass>
+const EqualityNode& EqualityEngine<NotifyClass>::getEqualityNode(TNode t) const {
+  return getEqualityNode(getNodeId(t));
+}
+
+template <typename NotifyClass>
+const EqualityNode& EqualityEngine<NotifyClass>::getEqualityNode(EqualityNodeId nodeId) const {
+  Assert(nodeId < d_equalityNodes.size());
+  return d_equalityNodes[nodeId];
+}
+
 template <typename NotifyClass>
 void EqualityEngine<NotifyClass>::addEquality(TNode t1, TNode t2, TNode reason) {
 
   Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << ")" << std::endl;
 
-  // Backtrack if necessary
-  backtrack();
-
   // Add the terms if they are not already in the database
   EqualityNodeId t1Id = getNodeId(t1);
   EqualityNodeId t2Id = getNodeId(t2);
 
   // Add to the queue and propagate
   enqueue(MergeCandidate(t1Id, t2Id, MERGED_THROUGH_EQUALITY, reason));
-  propagate(false);
+  propagate();
 }
 
 template <typename NotifyClass>
@@ -180,8 +188,7 @@ TNode EqualityEngine<NotifyClass>::getRepresentative(TNode t) const {
   Assert(hasTerm(t));
 
   // Both following commands are semantically const
-  const_cast<EqualityEngine*>(this)->backtrack();
-  EqualityNodeId representativeId = const_cast<EqualityEngine*>(this)->getEqualityNode(t).getFind();
+  EqualityNodeId representativeId = getEqualityNode(t).getFind();
 
   Debug("equality") << "EqualityEngine::getRepresentative(" << t << ") => " << d_nodes[representativeId] << std::endl;
 
@@ -196,9 +203,8 @@ bool EqualityEngine<NotifyClass>::areEqual(TNode t1, TNode t2) const {
   Assert(hasTerm(t2));
 
   // Both following commands are semantically const
-  const_cast<EqualityEngine*>(this)->backtrack();
-  EqualityNodeId rep1 = const_cast<EqualityEngine*>(this)->getEqualityNode(t1).getFind();
-  EqualityNodeId rep2 = const_cast<EqualityEngine*>(this)->getEqualityNode(t2).getFind();
+  EqualityNodeId rep1 = getEqualityNode(t1).getFind();
+  EqualityNodeId rep2 = getEqualityNode(t2).getFind();
 
   Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ") => " << (rep1 == rep2 ? "true" : "false") << std::endl;
 
@@ -383,12 +389,6 @@ void EqualityEngine<NotifyClass>::undoMerge(EqualityNode& class1, EqualityNode&
       // If the id doesn't exist, we'll set it
       if (find == d_applicationLookup.end()) {
         d_applicationLookup[funNormalized] = funId;
-      } else {
-        // Otherwise, we might be congruent agaain
-        if (getEqualityNode(funId).getFind() != getEqualityNode(find->second).getFind()) {
-          // Damn, we might be merging again, but we'll check this later
-          enqueue(MergeCandidate(funId, find->second, MERGED_THROUGH_CONGRUENCE, TNode::null()));
-        }
       }
       // Go to the next one in the use list
       currentUseId = useNode.getNext();
@@ -411,8 +411,6 @@ void EqualityEngine<NotifyClass>::backtrack() {
       d_propagationQueue.pop();
     }
 
-    ++ d_stats.backtracksCount;
-
     Debug("equality") << "EqualityEngine::backtrack(): nodes" << std::endl;
 
     for (int i = (int)d_assertedEqualities.size() - 1, i_end = (int)d_assertedEqualitiesCount; i >= i_end; --i) {
@@ -434,9 +432,6 @@ void EqualityEngine<NotifyClass>::backtrack() {
     }
 
     d_equalityEdges.resize(2 * d_assertedEqualitiesCount);
-
-    // Now repropagate if something got reenqueued
-    propagate(true);
   }
 }
 
@@ -450,7 +445,7 @@ void EqualityEngine<NotifyClass>::addGraphEdge(EqualityNodeId t1, EqualityNodeId
   d_equalityGraph[t2] = edge | 1;
 
   if (Debug.isOn("equality::internal")) {
-    const_cast<EqualityEngine*>(this)->debugPrintGraph();
+    debugPrintGraph();
   }
 }
 
@@ -478,9 +473,6 @@ void EqualityEngine<NotifyClass>::getExplanation(TNode t1, TNode t2, std::vector
 
   Assert(getRepresentative(t1) == getRepresentative(t2));
 
-  // Backtrack if necessary
-  const_cast<EqualityEngine*>(this)->backtrack();
-
   // Get the explanation
   EqualityNodeId t1Id = getNodeId(t1);
   EqualityNodeId t2Id = getNodeId(t2);
@@ -623,7 +615,7 @@ void EqualityEngine<NotifyClass>::addTriggerEquality(TNode t1, TNode t2, TNode t
 }
 
 template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::propagate(bool check) {
+void EqualityEngine<NotifyClass>::propagate() {
 
   Debug("equality") << "EqualityEngine::propagate()" << std::endl;
 
@@ -648,14 +640,6 @@ void EqualityEngine<NotifyClass>::propagate(bool check) {
       continue;
     }
 
-    // If check is on, and a congruence, check the arguments (it might be from a backtrack)
-    if (check && current.type == MERGED_THROUGH_CONGRUENCE) {
-      const FunctionApplication& f1 = d_applications[current.t1Id];
-      const FunctionApplication& f2 = d_applications[current.t2Id];
-      if (getEqualityNode(f1.a).getFind() != getEqualityNode(f2.a).getFind()) continue;
-      if (getEqualityNode(f1.b).getFind() != getEqualityNode(f2.b).getFind()) continue;
-    }
-
     // Get the nodes of the representatives
     EqualityNode& node1 = getEqualityNode(t1classId);
     EqualityNode& node2 = getEqualityNode(t2classId);
@@ -693,7 +677,7 @@ void EqualityEngine<NotifyClass>::propagate(bool check) {
 }
 
 template <typename NotifyClass>
-void EqualityEngine<NotifyClass>::debugPrintGraph() {
+void EqualityEngine<NotifyClass>::debugPrintGraph() const {
   for (EqualityNodeId nodeId = 0; nodeId < d_nodes.size(); ++ nodeId) {
 
     Debug("equality::internal") << d_nodes[nodeId] << " " << nodeId << "(" << getEqualityNode(nodeId).getFind() << "):";