From d9a103f371cd800615b37fa378ad9d8b7681ee1c Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 17 Apr 2019 16:35:51 -0500 Subject: [PATCH] Cache explanations in the equality engine (#2937) --- src/theory/uf/equality_engine.cpp | 157 ++++++++++++++++++++++-------- src/theory/uf/equality_engine.h | 32 +++++- 2 files changed, 145 insertions(+), 44 deletions(-) diff --git a/src/theory/uf/equality_engine.cpp b/src/theory/uf/equality_engine.cpp index d1fc8341c..148a5e427 100644 --- a/src/theory/uf/equality_engine.cpp +++ b/src/theory/uf/equality_engine.cpp @@ -929,9 +929,9 @@ std::string EqualityEngine::edgesToString(EqualityEdgeId edgeId) const { void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, std::vector& equalities, EqProof* eqp) const { - Debug("equality") << d_name << "::eq::explainEquality(" << t1 << ", " << t2 - << ", " << (polarity ? "true" : "false") << ")" - << ", proof = " << (eqp ? "ON" : "OFF") << std::endl; + Debug("pf::ee") << d_name << "::eq::explainEquality(" << t1 << ", " << t2 + << ", " << (polarity ? "true" : "false") << ")" + << ", proof = " << (eqp ? "ON" : "OFF") << std::endl; // The terms must be there already Assert(hasTerm(t1) && hasTerm(t2));; @@ -940,9 +940,10 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, EqualityNodeId t1Id = getNodeId(t1); EqualityNodeId t2Id = getNodeId(t2); + std::map, EqProof*> cache; if (polarity) { // Get the explanation - getExplanation(t1Id, t2Id, equalities, eqp); + getExplanation(t1Id, t2Id, equalities, cache, eqp); } else { if (eqp) { eqp->d_id = eq::MERGED_THROUGH_TRANS; @@ -964,12 +965,15 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, eqpc = std::make_shared(); } - getExplanation(toExplain.first, toExplain.second, equalities, eqpc.get()); + getExplanation( + toExplain.first, toExplain.second, equalities, cache, eqpc.get()); if (eqpc) { - Debug("pf::ee") << "Child proof is:" << std::endl; - eqpc->debug_print("pf::ee", 1); - + if (Debug.isOn("pf::ee")) + { + Debug("pf::ee") << "Child proof is:" << std::endl; + eqpc->debug_print("pf::ee", 1); + } if (eqpc->d_id == eq::MERGED_THROUGH_TRANS) { std::vector> orderedChildren; bool nullCongruenceFound = false; @@ -987,8 +991,13 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, if (nullCongruenceFound) { eqpc->d_children = orderedChildren; - Debug("pf::ee") << "Child proof's children have been reordered. It is now:" << std::endl; - eqpc->debug_print("pf::ee", 1); + if (Debug.isOn("pf::ee")) + { + Debug("pf::ee") + << "Child proof's children have been reordered. It is now:" + << std::endl; + eqpc->debug_print("pf::ee", 1); + } } } @@ -1011,8 +1020,11 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, *eqp = *temp; } - Debug("pf::ee") << "Disequality explanation final proof: " << std::endl; - eqp->debug_print("pf::ee", 1); + if (Debug.isOn("pf::ee")) + { + Debug("pf::ee") << "Disequality explanation final proof: " << std::endl; + eqp->debug_print("pf::ee", 1); + } } } } @@ -1024,15 +1036,51 @@ void EqualityEngine::explainPredicate(TNode p, bool polarity, << std::endl; // Must have the term Assert(hasTerm(p)); + std::map, EqProof*> cache; // Get the explanation - getExplanation(getNodeId(p), polarity ? d_trueId : d_falseId, assertions, - eqp); + getExplanation( + getNodeId(p), polarity ? d_trueId : d_falseId, assertions, cache, eqp); } -void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, - std::vector& equalities, - EqProof* eqp) const { - Debug("equality") << d_name << "::eq::getExplanation(" << d_nodes[t1Id] << "," << d_nodes[t2Id] << ")" << std::endl; +void EqualityEngine::getExplanation( + EqualityNodeId t1Id, + EqualityNodeId t2Id, + std::vector& equalities, + std::map, EqProof*>& cache, + EqProof* eqp) const +{ + Trace("eq-exp") << d_name << "::eq::getExplanation(" << d_nodes[t1Id] << "," + << d_nodes[t2Id] << ") size = " << cache.size() << std::endl; + + // We order the ids, since explaining t1 = t2 is the same as explaining + // t2 = t1. + std::pair cacheKey = std::minmax(t1Id, t2Id); + std::map, EqProof*>::iterator it = + cache.find(cacheKey); + if (it != cache.end()) + { + // copy one level + if (eqp) + { + if (it->second) + { + eqp->d_node = it->second->d_node; + eqp->d_id = it->second->d_id; + eqp->d_children.insert(eqp->d_children.end(), + it->second->d_children.begin(), + it->second->d_children.end()); + } + else + { + // We may have cached null in its place, create the trivial proof now. + Assert(d_nodes[t1Id] == d_nodes[t2Id]); + Assert(eqp->d_id == MERGED_THROUGH_REFLEXIVITY); + eqp->d_node = d_nodes[t1Id]; + } + } + return; + } + cache[cacheKey] = eqp; // We can only explain the nodes that got merged #ifdef CVC4_ASSERTIONS @@ -1136,11 +1184,11 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, Debug("equality") << "Explaining left hand side equalities" << std::endl; std::shared_ptr eqpc1 = eqpc ? std::make_shared() : nullptr; - getExplanation(f1.a, f2.a, equalities, eqpc1.get()); + getExplanation(f1.a, f2.a, equalities, cache, eqpc1.get()); Debug("equality") << "Explaining right hand side equalities" << std::endl; std::shared_ptr eqpc2 = eqpc ? std::make_shared() : nullptr; - getExplanation(f1.b, f2.b, equalities, eqpc2.get()); + getExplanation(f1.b, f2.b, equalities, cache, eqpc2.get()); if( eqpc ){ eqpc->d_children.push_back( eqpc1 ); eqpc->d_children.push_back( eqpc2 ); @@ -1185,7 +1233,7 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, Debug("equality") << push; std::shared_ptr eqpc1 = eqpc ? std::make_shared() : nullptr; - getExplanation(eq.a, eq.b, equalities, eqpc1.get()); + getExplanation(eq.a, eq.b, equalities, cache, eqpc1.get()); if( eqpc ){ eqpc->d_children.push_back( eqpc1 ); } @@ -1211,13 +1259,20 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, Assert(isConstant(childId)); std::shared_ptr eqpcc = eqpc ? std::make_shared() : nullptr; - getExplanation(childId, getEqualityNode(childId).getFind(), - equalities, eqpcc.get()); + getExplanation(childId, + getEqualityNode(childId).getFind(), + equalities, + cache, + eqpcc.get()); if( eqpc ) { eqpc->d_children.push_back( eqpcc ); - - Debug("pf::ee") << "MERGED_THROUGH_CONSTANTS. Dumping the child proof" << std::endl; - eqpc->debug_print("pf::ee", 1); + if (Debug.isOn("pf::ee")) + { + Debug("pf::ee") + << "MERGED_THROUGH_CONSTANTS. Dumping the child proof" + << std::endl; + eqpc->debug_print("pf::ee", 1); + } } } @@ -1255,7 +1310,6 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, } eqpc->d_id = reasonType; } - equalities.push_back(reason); break; } @@ -1288,8 +1342,10 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, eqp->d_children.insert( eqp->d_children.end(), eqp_trans.begin(), eqp_trans.end() ); eqp->d_node = NodeManager::currentNM()->mkNode(kind::EQUAL, d_nodes[t1Id], d_nodes[t2Id]); } - - eqp->debug_print("pf::ee", 1); + if (Debug.isOn("pf::ee")) + { + eqp->debug_print("pf::ee", 1); + } } // Done @@ -2236,27 +2292,48 @@ bool EqClassIterator::isFinished() const { } void EqProof::debug_print(const char* c, unsigned tb, PrettyPrinter* prettyPrinter) const { - for(unsigned i=0; iprintTag(d_id); + { + os << prettyPrinter->printTag(d_id); + } else - Debug( c ) << d_id; + { + os << d_id; + } - Debug( c ) << "("; + os << "("; if( !d_children.empty() || !d_node.isNull() ){ if( !d_node.isNull() ){ - Debug( c ) << std::endl; - for( unsigned i=0; i0 || !d_node.isNull() ) Debug( c ) << ","; - Debug( c ) << std::endl; - d_children[i]->debug_print( c, tb+1, prettyPrinter ); + if (i > 0 || !d_node.isNull()) + { + os << ","; + } + os << std::endl; + d_children[i]->debug_print(os, tb + 1, prettyPrinter); } } - Debug( c ) << ")" << std::endl; + os << ")" << std::endl; } } // Namespace uf diff --git a/src/theory/uf/equality_engine.h b/src/theory/uf/equality_engine.h index b93ff6d6d..73d8bd4e9 100644 --- a/src/theory/uf/equality_engine.h +++ b/src/theory/uf/equality_engine.h @@ -516,11 +516,24 @@ private: bool d_inPropagate; /** - * Get an explanation of the equality t1 = t2. Returns the asserted equalities that - * imply t1 = t2. Returns TNodes as the assertion equalities should be hashed somewhere - * else. + * Get an explanation of the equality t1 = t2. Returns the asserted equalities + * that imply t1 = t2. Returns TNodes as the assertion equalities should be + * hashed somewhere else. + * + * This call refers to terms t1 and t2 by their ids t1Id and t2Id. + * + * If eqp is non-null, then this method populates eqp's information and + * children such that it is a proof of t1 = t2. + * + * We cache results of this call in cache, where cache[t1Id][t2Id] stores + * a proof of t1 = t2. */ - void getExplanation(EqualityEdgeId t1Id, EqualityNodeId t2Id, std::vector& equalities, EqProof* eqp) const; + void getExplanation( + EqualityEdgeId t1Id, + EqualityNodeId t2Id, + std::vector& equalities, + std::map, EqProof*>& cache, + EqProof* eqp) const; /** * Print the equality graph. @@ -941,8 +954,19 @@ public: unsigned d_id; Node d_node; std::vector> d_children; + /** + * Debug print this proof on debug trace c with tabulation tb and pretty + * printer prettyPrinter. + */ void debug_print(const char* c, unsigned tb = 0, PrettyPrinter* prettyPrinter = nullptr) const; + /** + * Debug print this proof on output stream os with tabulation tb and pretty + * printer prettyPrinter. + */ + void debug_print(std::ostream& os, + unsigned tb = 0, + PrettyPrinter* prettyPrinter = nullptr) const; };/* class EqProof */ } // Namespace eq -- 2.30.2