Cache explanations in the equality engine (#2937)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 17 Apr 2019 21:35:51 +0000 (16:35 -0500)
committerGitHub <noreply@github.com>
Wed, 17 Apr 2019 21:35:51 +0000 (16:35 -0500)
src/theory/uf/equality_engine.cpp
src/theory/uf/equality_engine.h

index d1fc8341c8aea2bc8effefb5e5c22cf5c4f2183b..148a5e427ac4d6726d8ec786e3bb3b95fc764241 100644 (file)
@@ -929,9 +929,9 @@ std::string EqualityEngine::edgesToString(EqualityEdgeId edgeId) const {
 void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity,
                                      std::vector<TNode>& equalities,
                                      EqProof* eqp) const {
-  Debug("equality") << d_name << "::eq::explainEquality(" << t1 << ", " << t2
-                    << ", " << (polarity ? "true" : "false") << ")"
-                    << ", proof = " << (eqp ? "ON" : "OFF") << std::endl;
+  Debug("pf::ee") << d_name << "::eq::explainEquality(" << t1 << ", " << t2
+                  << ", " << (polarity ? "true" : "false") << ")"
+                  << ", proof = " << (eqp ? "ON" : "OFF") << std::endl;
 
   // The terms must be there already
   Assert(hasTerm(t1) && hasTerm(t2));;
@@ -940,9 +940,10 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity,
   EqualityNodeId t1Id = getNodeId(t1);
   EqualityNodeId t2Id = getNodeId(t2);
 
+  std::map<std::pair<EqualityNodeId, EqualityNodeId>, EqProof*> cache;
   if (polarity) {
     // Get the explanation
-    getExplanation(t1Id, t2Id, equalities, eqp);
+    getExplanation(t1Id, t2Id, equalities, cache, eqp);
   } else {
     if (eqp) {
       eqp->d_id = eq::MERGED_THROUGH_TRANS;
@@ -964,12 +965,15 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity,
         eqpc = std::make_shared<EqProof>();
       }
 
-      getExplanation(toExplain.first, toExplain.second, equalities, eqpc.get());
+      getExplanation(
+          toExplain.first, toExplain.second, equalities, cache, eqpc.get());
 
       if (eqpc) {
-        Debug("pf::ee") << "Child proof is:" << std::endl;
-        eqpc->debug_print("pf::ee", 1);
-
+        if (Debug.isOn("pf::ee"))
+        {
+          Debug("pf::ee") << "Child proof is:" << std::endl;
+          eqpc->debug_print("pf::ee", 1);
+        }
         if (eqpc->d_id == eq::MERGED_THROUGH_TRANS) {
           std::vector<std::shared_ptr<EqProof>> orderedChildren;
           bool nullCongruenceFound = false;
@@ -987,8 +991,13 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity,
 
           if (nullCongruenceFound) {
             eqpc->d_children = orderedChildren;
-            Debug("pf::ee") << "Child proof's children have been reordered. It is now:" << std::endl;
-            eqpc->debug_print("pf::ee", 1);
+            if (Debug.isOn("pf::ee"))
+            {
+              Debug("pf::ee")
+                  << "Child proof's children have been reordered. It is now:"
+                  << std::endl;
+              eqpc->debug_print("pf::ee", 1);
+            }
           }
         }
 
@@ -1011,8 +1020,11 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity,
         *eqp = *temp;
       }
 
-      Debug("pf::ee") << "Disequality explanation final proof: " << std::endl;
-      eqp->debug_print("pf::ee", 1);
+      if (Debug.isOn("pf::ee"))
+      {
+        Debug("pf::ee") << "Disequality explanation final proof: " << std::endl;
+        eqp->debug_print("pf::ee", 1);
+      }
     }
   }
 }
@@ -1024,15 +1036,51 @@ void EqualityEngine::explainPredicate(TNode p, bool polarity,
                     << std::endl;
   // Must have the term
   Assert(hasTerm(p));
+  std::map<std::pair<EqualityNodeId, EqualityNodeId>, EqProof*> cache;
   // Get the explanation
-  getExplanation(getNodeId(p), polarity ? d_trueId : d_falseId, assertions,
-                 eqp);
+  getExplanation(
+      getNodeId(p), polarity ? d_trueId : d_falseId, assertions, cache, eqp);
 }
 
-void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id,
-                                    std::vector<TNode>& equalities,
-                                    EqProof* eqp) const {
-  Debug("equality") << d_name << "::eq::getExplanation(" << d_nodes[t1Id] << "," << d_nodes[t2Id] << ")" << std::endl;
+void EqualityEngine::getExplanation(
+    EqualityNodeId t1Id,
+    EqualityNodeId t2Id,
+    std::vector<TNode>& equalities,
+    std::map<std::pair<EqualityNodeId, EqualityNodeId>, EqProof*>& cache,
+    EqProof* eqp) const
+{
+  Trace("eq-exp") << d_name << "::eq::getExplanation(" << d_nodes[t1Id] << ","
+                  << d_nodes[t2Id] << ") size = " << cache.size() << std::endl;
+
+  // We order the ids, since explaining t1 = t2 is the same as explaining
+  // t2 = t1.
+  std::pair<EqualityNodeId, EqualityNodeId> cacheKey = std::minmax(t1Id, t2Id);
+  std::map<std::pair<EqualityNodeId, EqualityNodeId>, EqProof*>::iterator it =
+      cache.find(cacheKey);
+  if (it != cache.end())
+  {
+    // copy one level
+    if (eqp)
+    {
+      if (it->second)
+      {
+        eqp->d_node = it->second->d_node;
+        eqp->d_id = it->second->d_id;
+        eqp->d_children.insert(eqp->d_children.end(),
+                               it->second->d_children.begin(),
+                               it->second->d_children.end());
+      }
+      else
+      {
+        // We may have cached null in its place, create the trivial proof now.
+        Assert(d_nodes[t1Id] == d_nodes[t2Id]);
+        Assert(eqp->d_id == MERGED_THROUGH_REFLEXIVITY);
+        eqp->d_node = d_nodes[t1Id];
+      }
+    }
+    return;
+  }
+  cache[cacheKey] = eqp;
 
   // We can only explain the nodes that got merged
 #ifdef CVC4_ASSERTIONS
@@ -1136,11 +1184,11 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id,
               Debug("equality") << "Explaining left hand side equalities" << std::endl;
               std::shared_ptr<EqProof> eqpc1 =
                   eqpc ? std::make_shared<EqProof>() : nullptr;
-              getExplanation(f1.a, f2.a, equalities, eqpc1.get());
+              getExplanation(f1.a, f2.a, equalities, cache, eqpc1.get());
               Debug("equality") << "Explaining right hand side equalities" << std::endl;
               std::shared_ptr<EqProof> eqpc2 =
                   eqpc ? std::make_shared<EqProof>() : nullptr;
-              getExplanation(f1.b, f2.b, equalities, eqpc2.get());
+              getExplanation(f1.b, f2.b, equalities, cache, eqpc2.get());
               if( eqpc ){
                 eqpc->d_children.push_back( eqpc1 );
                 eqpc->d_children.push_back( eqpc2 );
@@ -1185,7 +1233,7 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id,
               Debug("equality") << push;
               std::shared_ptr<EqProof> eqpc1 =
                   eqpc ? std::make_shared<EqProof>() : nullptr;
-              getExplanation(eq.a, eq.b, equalities, eqpc1.get());
+              getExplanation(eq.a, eq.b, equalities, cache, eqpc1.get());
               if( eqpc ){
                 eqpc->d_children.push_back( eqpc1 );
               }
@@ -1211,13 +1259,20 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id,
                 Assert(isConstant(childId));
                 std::shared_ptr<EqProof> eqpcc =
                     eqpc ? std::make_shared<EqProof>() : nullptr;
-                getExplanation(childId, getEqualityNode(childId).getFind(),
-                               equalities, eqpcc.get());
+                getExplanation(childId,
+                               getEqualityNode(childId).getFind(),
+                               equalities,
+                               cache,
+                               eqpcc.get());
                 if( eqpc ) {
                   eqpc->d_children.push_back( eqpcc );
-
-                  Debug("pf::ee") << "MERGED_THROUGH_CONSTANTS. Dumping the child proof" << std::endl;
-                  eqpc->debug_print("pf::ee", 1);
+                  if (Debug.isOn("pf::ee"))
+                  {
+                    Debug("pf::ee")
+                        << "MERGED_THROUGH_CONSTANTS. Dumping the child proof"
+                        << std::endl;
+                    eqpc->debug_print("pf::ee", 1);
+                  }
                 }
               }
 
@@ -1255,7 +1310,6 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id,
                 }
                 eqpc->d_id = reasonType;
               }
-
               equalities.push_back(reason);
               break;
             }
@@ -1288,8 +1342,10 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id,
               eqp->d_children.insert( eqp->d_children.end(), eqp_trans.begin(), eqp_trans.end() );
               eqp->d_node = NodeManager::currentNM()->mkNode(kind::EQUAL, d_nodes[t1Id], d_nodes[t2Id]);
             }
-
-            eqp->debug_print("pf::ee", 1);
+            if (Debug.isOn("pf::ee"))
+            {
+              eqp->debug_print("pf::ee", 1);
+            }
           }
 
           // Done
@@ -2236,27 +2292,48 @@ bool EqClassIterator::isFinished() const {
 }
 
 void EqProof::debug_print(const char* c, unsigned tb, PrettyPrinter* prettyPrinter) const {
-  for(unsigned i=0; i<tb; i++) { Debug( c ) << "  "; }
+  std::stringstream ss;
+  debug_print(ss, tb, prettyPrinter);
+  Debug(c) << ss.str();
+}
+void EqProof::debug_print(std::ostream& os,
+                          unsigned tb,
+                          PrettyPrinter* prettyPrinter) const
+{
+  for (unsigned i = 0; i < tb; i++)
+  {
+    os << "  ";
+  }
 
   if (prettyPrinter)
-    Debug( c ) << prettyPrinter->printTag(d_id);
+  {
+    os << prettyPrinter->printTag(d_id);
+  }
   else
-    Debug( c ) << d_id;
+  {
+    os << d_id;
+  }
 
-  Debug( c ) << "(";
+  os << "(";
   if( !d_children.empty() || !d_node.isNull() ){
     if( !d_node.isNull() ){
-      Debug( c ) << std::endl;
-      for( unsigned i=0; i<tb+1; i++ ) { Debug( c ) << "  "; }
-      Debug( c ) << d_node;
+      os << std::endl;
+      for (unsigned i = 0; i < tb + 1; i++)
+      {
+        os << "  ";
+      }
+      os << d_node;
     }
     for( unsigned i=0; i<d_children.size(); i++ ){
-      if( i>0 || !d_node.isNull() ) Debug( c ) << ",";
-      Debug( c ) << std::endl;
-      d_children[i]->debug_print( c, tb+1, prettyPrinter );
+      if (i > 0 || !d_node.isNull())
+      {
+        os << ",";
+      }
+      os << std::endl;
+      d_children[i]->debug_print(os, tb + 1, prettyPrinter);
     }
   }
-  Debug( c ) << ")" << std::endl;
+  os << ")" << std::endl;
 }
 
 } // Namespace uf
index b93ff6d6dd6264c235a86b4ecb9193d4fb3ca048..73d8bd4e9bc54a7556fa389f20f8ede8f4bce336 100644 (file)
@@ -516,11 +516,24 @@ private:
   bool d_inPropagate;
 
   /**
-   * Get an explanation of the equality t1 = t2. Returns the asserted equalities that
-   * imply t1 = t2. Returns TNodes as the assertion equalities should be hashed somewhere
-   * else.
+   * Get an explanation of the equality t1 = t2. Returns the asserted equalities
+   * that imply t1 = t2. Returns TNodes as the assertion equalities should be
+   * hashed somewhere else.
+   *
+   * This call refers to terms t1 and t2 by their ids t1Id and t2Id.
+   *
+   * If eqp is non-null, then this method populates eqp's information and
+   * children such that it is a proof of t1 = t2.
+   *
+   * We cache results of this call in cache, where cache[t1Id][t2Id] stores
+   * a proof of t1 = t2.
    */
-  void getExplanation(EqualityEdgeId t1Id, EqualityNodeId t2Id, std::vector<TNode>& equalities, EqProof* eqp) const;
+  void getExplanation(
+      EqualityEdgeId t1Id,
+      EqualityNodeId t2Id,
+      std::vector<TNode>& equalities,
+      std::map<std::pair<EqualityNodeId, EqualityNodeId>, EqProof*>& cache,
+      EqProof* eqp) const;
 
   /**
    * Print the equality graph.
@@ -941,8 +954,19 @@ public:
   unsigned d_id;
   Node d_node;
   std::vector<std::shared_ptr<EqProof>> d_children;
+  /**
+   * Debug print this proof on debug trace c with tabulation tb and pretty
+   * printer prettyPrinter.
+   */
   void debug_print(const char* c, unsigned tb = 0,
                    PrettyPrinter* prettyPrinter = nullptr) const;
+  /**
+   * Debug print this proof on output stream os with tabulation tb and pretty
+   * printer prettyPrinter.
+   */
+  void debug_print(std::ostream& os,
+                   unsigned tb = 0,
+                   PrettyPrinter* prettyPrinter = nullptr) const;
 };/* class EqProof */
 
 } // Namespace eq