From 0ee1b1371e7cf50c14883316fdd6374114799a99 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 2 Sep 2020 13:01:10 -0500 Subject: [PATCH] (proof-new) Make term conversion proof generator optionally term-context sensitive (#4972) This will be used by TermFormulaRemoval. --- src/expr/CMakeLists.txt | 4 + src/expr/term_conversion_proof_generator.cpp | 351 ++++++++++++++----- src/expr/term_conversion_proof_generator.h | 70 +++- 3 files changed, 338 insertions(+), 87 deletions(-) diff --git a/src/expr/CMakeLists.txt b/src/expr/CMakeLists.txt index 8bc732314..aed7a866c 100644 --- a/src/expr/CMakeLists.txt +++ b/src/expr/CMakeLists.txt @@ -63,6 +63,10 @@ libcvc4_add_sources( term_canonize.h term_context.cpp term_context.h + term_context_node.cpp + term_context_node.h + term_context_stack.cpp + term_context_stack.h term_conversion_proof_generator.cpp term_conversion_proof_generator.h type.cpp diff --git a/src/expr/term_conversion_proof_generator.cpp b/src/expr/term_conversion_proof_generator.cpp index 1c4baeed7..8cd7561b4 100644 --- a/src/expr/term_conversion_proof_generator.cpp +++ b/src/expr/term_conversion_proof_generator.cpp @@ -14,6 +14,8 @@ #include "expr/term_conversion_proof_generator.h" +#include "expr/term_context_stack.h" + using namespace CVC4::kind; namespace CVC4 { @@ -45,31 +47,38 @@ TConvProofGenerator::TConvProofGenerator(ProofNodeManager* pnm, context::Context* c, TConvPolicy pol, TConvCachePolicy cpol, - std::string name) + std::string name, + TermContext* tccb) : d_proof(pnm, nullptr, c, name + "::LazyCDProof"), d_rewriteMap(c ? 
c : &d_context), d_policy(pol), d_cpolicy(cpol), - d_name(name) + d_name(name), + d_tcontext(tccb) { } TConvProofGenerator::~TConvProofGenerator() {} -void TConvProofGenerator::addRewriteStep(Node t, Node s, ProofGenerator* pg) +void TConvProofGenerator::addRewriteStep( + Node t, Node s, ProofGenerator* pg, bool isClosed, uint32_t tctx) { - Node eq = registerRewriteStep(t, s); + Node eq = registerRewriteStep(t, s, tctx); if (!eq.isNull()) { - d_proof.addLazyStep(eq, pg); + d_proof.addLazyStep(eq, pg, isClosed); } } -void TConvProofGenerator::addRewriteStep(Node t, Node s, ProofStep ps) +void TConvProofGenerator::addRewriteStep(Node t, + Node s, + ProofStep ps, + uint32_t tctx) { - Node eq = registerRewriteStep(t, s); + Node eq = registerRewriteStep(t, s, tctx); if (!eq.isNull()) { + AlwaysAssert(ps.d_rule != PfRule::ASSUME); d_proof.addStep(eq, ps); } } @@ -78,33 +87,55 @@ void TConvProofGenerator::addRewriteStep(Node t, Node s, PfRule id, const std::vector& children, - const std::vector& args) + const std::vector& args, + uint32_t tctx) { - Node eq = registerRewriteStep(t, s); + Node eq = registerRewriteStep(t, s, tctx); if (!eq.isNull()) { + AlwaysAssert(id != PfRule::ASSUME); d_proof.addStep(eq, id, children, args); } } -bool TConvProofGenerator::hasRewriteStep(Node t) const +bool TConvProofGenerator::hasRewriteStep(Node t, uint32_t tctx) const { - return !getRewriteStep(t).isNull(); + return !getRewriteStep(t, tctx).isNull(); } -Node TConvProofGenerator::registerRewriteStep(Node t, Node s) +Node TConvProofGenerator::getRewriteStep(Node t, uint32_t tctx) const +{ + Node thash = t; + if (d_tcontext != nullptr) + { + thash = TCtxNode::computeNodeHash(t, tctx); + } + return getRewriteStepInternal(thash); +} + +Node TConvProofGenerator::registerRewriteStep(Node t, Node s, uint32_t tctx) { if (t == s) { return Node::null(); } + Node thash = t; + if (d_tcontext != nullptr) + { + thash = TCtxNode::computeNodeHash(t, tctx); + } + else + { + // don't use term context ids 
if not using term context + Assert(tctx == 0); + } // should not rewrite term to two different things - if (!getRewriteStep(t).isNull()) + if (!getRewriteStepInternal(thash).isNull()) { - Assert(getRewriteStep(t) == s); + Assert(getRewriteStepInternal(thash) == s); return Node::null(); } - d_rewriteMap[t] = s; + d_rewriteMap[thash] = s; if (d_cpolicy == TConvCachePolicy::DYNAMIC) { // clear the cache @@ -115,66 +146,140 @@ Node TConvProofGenerator::registerRewriteStep(Node t, Node s) std::shared_ptr TConvProofGenerator::getProofFor(Node f) { - Trace("tconv-pf-gen") << "TConvProofGenerator::getProofFor: " << f - << std::endl; + Trace("tconv-pf-gen") << "TConvProofGenerator::getProofFor: " << identify() + << ": " << f << std::endl; if (f.getKind() != EQUAL) { - Trace("tconv-pf-gen") << "... fail, non-equality" << std::endl; - Assert(false); + std::stringstream serr; + serr << "TConvProofGenerator::getProofFor: " << identify() + << ": fail, non-equality " << f; + Unhandled() << serr.str(); + Trace("tconv-pf-gen") << serr.str() << std::endl; return nullptr; } // we use the existing proofs LazyCDProof lpf( d_proof.getManager(), &d_proof, nullptr, d_name + "::LazyCDProof"); - Node conc = getProofForRewriting(f[0], lpf); - if (conc != f) + if (f[0] == f[1]) { - Trace("tconv-pf-gen") << "...failed, mismatch: returned proof concludes " - << conc << ", expected " << f << std::endl; - Assert(false); - return nullptr; + // assertion failure in debug + Assert(false) << "TConvProofGenerator::getProofFor: " << identify() + << ": don't ask for trivial proofs"; + lpf.addStep(f, PfRule::REFL, {}, {f[0]}); + } + else + { + Node conc = getProofForRewriting(f[0], lpf, d_tcontext); + if (conc != f) + { + Assert(conc.getKind() == EQUAL && conc[0] == f[0]); + std::stringstream serr; + serr << "TConvProofGenerator::getProofFor: " << toStringDebug() + << ": failed, mismatch (see -t tconv-pf-gen-debug for details)" + << std::endl; + serr << " source: " << f[0] << std::endl; + serr << 
"expected after rewriting: " << f[1] << std::endl; + serr << " actual after rewriting: " << conc[1] << std::endl; + + if (Trace.isOn("tconv-pf-gen-debug")) + { + Trace("tconv-pf-gen-debug") << "Printing rewrite steps..." << std::endl; + serr << "Rewrite steps: " << std::endl; + for (NodeNodeMap::const_iterator it = d_rewriteMap.begin(); + it != d_rewriteMap.end(); + ++it) + { + serr << (*it).first << " -> " << (*it).second << std::endl; + } + } + Unhandled() << serr.str(); + return nullptr; + } } Trace("tconv-pf-gen") << "... success" << std::endl; return lpf.getProofFor(f); } -Node TConvProofGenerator::getProofForRewriting(Node t, LazyCDProof& pf) +Node TConvProofGenerator::getProofForRewriting(Node t, + LazyCDProof& pf, + TermContext* tctx) { NodeManager* nm = NodeManager::currentNM(); - // Invariant: if visited[t] = s or rewritten[t] = s and t,s are distinct, - // then pf is able to generate a proof of t=s. + // Invariant: if visited[hash(t)] = s or rewritten[hash(t)] = s and t,s are + // distinct, then pf is able to generate a proof of t=s. We must + // Node in the domains of the maps below due to hashing creating new (SEXPR) + // nodes. 
+ // the final rewritten form of terms - std::unordered_map visited; + std::unordered_map visited; // the rewritten form of terms we have processed so far - std::unordered_map rewritten; - std::unordered_map::iterator it; - std::unordered_map::iterator itr; + std::unordered_map rewritten; + std::unordered_map::iterator it; + std::unordered_map::iterator itr; std::map >::iterator itc; + Trace("tconv-pf-gen-rewrite") + << "TConvProofGenerator::getProofForRewriting: " << toStringDebug() + << std::endl; + Trace("tconv-pf-gen-rewrite") << "Input: " << t << std::endl; + // if provided, we use term context for cache + std::shared_ptr visitctx; + // otherwise, visit is used if we don't have a term context std::vector visit; - TNode cur; - visit.push_back(t); + Node tinitialHash; + if (tctx != nullptr) + { + visitctx = std::make_shared(tctx); + visitctx->pushInitial(t); + tinitialHash = TCtxNode::computeNodeHash(t, tctx->initialValue()); + } + else + { + visit.push_back(t); + tinitialHash = t; + } + Node cur; + uint32_t curCVal = 0; + Node curHash; do { - cur = visit.back(); - visit.pop_back(); + // pop the top element + if (tctx != nullptr) + { + std::pair curPair = visitctx->getCurrent(); + cur = curPair.first; + curCVal = curPair.second; + curHash = TCtxNode::computeNodeHash(cur, curCVal); + visitctx->pop(); + } + else + { + cur = visit.back(); + curHash = cur; + visit.pop_back(); + } + Trace("tconv-pf-gen-rewrite") << "* visit : " << curHash << std::endl; // has the proof for cur been cached? 
- itc = d_cache.find(cur); + itc = d_cache.find(curHash); if (itc != d_cache.end()) { Node res = itc->second->getResult(); Assert(res.getKind() == EQUAL); - visited[cur] = res[1]; + Assert(!res[1].isNull()); + visited[curHash] = res[1]; pf.addProof(itc->second); continue; } - it = visited.find(cur); + it = visited.find(curHash); if (it == visited.end()) { - visited[cur] = Node::null(); + Trace("tconv-pf-gen-rewrite") << "- previsit" << std::endl; + visited[curHash] = Node::null(); // did we rewrite the current node (possibly at pre-rewrite)? - Node rcur = getRewriteStep(cur); + Node rcur = getRewriteStepInternal(curHash); if (!rcur.isNull()) { + Trace("tconv-pf-gen-rewrite") + << "*** " << curHash << " prerewrites to " << rcur << std::endl; // d_proof has a proof of cur = rcur. Hence there is nothing // to do here, as pf will reference d_proof to get its proof. if (d_policy == TConvPolicy::FIXPOINT) @@ -182,18 +287,34 @@ Node TConvProofGenerator::getProofForRewriting(Node t, LazyCDProof& pf) // It may be the case that rcur also rewrites, thus we cannot assign // the final rewritten form for cur yet. Instead we revisit cur after // finishing visiting rcur. 
- rewritten[cur] = rcur; - visit.push_back(cur); - visit.push_back(rcur); + rewritten[curHash] = rcur; + if (tctx != nullptr) + { + visitctx->push(cur, curCVal); + visitctx->push(rcur, curCVal); + } + else + { + visit.push_back(cur); + visit.push_back(rcur); + } } else { Assert(d_policy == TConvPolicy::ONCE); + Trace("tconv-pf-gen-rewrite") << "-> (once, prewrite) " << curHash + << " = " << rcur << std::endl; // not rewriting again, rcur is final - visited[cur] = rcur; - doCache(cur, rcur, pf); + Assert(!rcur.isNull()); + visited[curHash] = rcur; + doCache(curHash, cur, rcur, pf); } } + else if (tctx != nullptr) + { + visitctx->push(cur, curCVal); + visitctx->pushChildren(cur, curCVal); + } else { visit.push_back(cur); @@ -202,7 +323,7 @@ Node TConvProofGenerator::getProofForRewriting(Node t, LazyCDProof& pf) } else if (it->second.isNull()) { - itr = rewritten.find(cur); + itr = rewritten.find(curHash); if (itr != rewritten.end()) { // only can generate partially rewritten nodes when rewrite again is @@ -211,9 +332,17 @@ Node TConvProofGenerator::getProofForRewriting(Node t, LazyCDProof& pf) // if it was rewritten, check the status of the rewritten node, // which should be finished now Node rcur = itr->second; + Trace("tconv-pf-gen-rewrite") + << "- postvisit, previously rewritten to " << rcur << std::endl; + Node rcurHash = rcur; + if (tctx != nullptr) + { + rcurHash = TCtxNode::computeNodeHash(rcur, curCVal); + } Assert(cur != rcur); // the final rewritten form of cur is the final form of rcur - Node rcurFinal = visited[rcur]; + Node rcurFinal = visited[rcurHash]; + Assert(!rcurFinal.isNull()); if (rcurFinal != rcur) { // must connect via TRANS @@ -223,30 +352,54 @@ Node TConvProofGenerator::getProofForRewriting(Node t, LazyCDProof& pf) Node result = cur.eqNode(rcurFinal); pf.addStep(result, PfRule::TRANS, pfChildren, {}); } - visited[cur] = rcurFinal; - doCache(cur, rcurFinal, pf); + Trace("tconv-pf-gen-rewrite") + << "-> (rewritten postrewrite) " << curHash << 
" = " << rcurFinal + << std::endl; + visited[curHash] = rcurFinal; + doCache(curHash, cur, rcurFinal, pf); } else { + Trace("tconv-pf-gen-rewrite") << "- postvisit" << std::endl; Node ret = cur; + Node retHash = curHash; bool childChanged = false; std::vector children; if (cur.getMetaKind() == metakind::PARAMETERIZED) { children.push_back(cur.getOperator()); } - for (const Node& cn : cur) + // get the results of the children + if (tctx != nullptr) + { + for (size_t i = 0, nchild = cur.getNumChildren(); i < nchild; i++) + { + Node cn = cur[i]; + uint32_t cnval = tctx->computeValue(cur, curCVal, i); + Node cnHash = TCtxNode::computeNodeHash(cn, cnval); + it = visited.find(cnHash); + Assert(it != visited.end()); + Assert(!it->second.isNull()); + childChanged = childChanged || cn != it->second; + children.push_back(it->second); + } + } + else { - it = visited.find(cn); - Assert(it != visited.end()); - Assert(!it->second.isNull()); - childChanged = childChanged || cn != it->second; - children.push_back(it->second); + // can use simple loop if not term-context-sensitive + for (const Node& cn : cur) + { + it = visited.find(cn); + Assert(it != visited.end()); + Assert(!it->second.isNull()); + childChanged = childChanged || cn != it->second; + children.push_back(it->second); + } } if (childChanged) { ret = nm->mkNode(cur.getKind(), children); - rewritten[cur] = ret; + rewritten[curHash] = ret; // congruence to show (cur = ret) std::vector pfChildren; for (size_t i = 0, size = cur.getNumChildren(); i < size; i++) @@ -260,62 +413,96 @@ Node TConvProofGenerator::getProofForRewriting(Node t, LazyCDProof& pf) } std::vector pfArgs; Kind k = cur.getKind(); + pfArgs.push_back(ProofRuleChecker::mkKindNode(k)); if (kind::metaKindOf(k) == kind::metakind::PARAMETERIZED) { pfArgs.push_back(cur.getOperator()); } - else - { - pfArgs.push_back(nm->operatorOf(k)); - } Node result = cur.eqNode(ret); pf.addStep(result, PfRule::CONG, pfChildren, pfArgs); + // must update the hash + retHash = 
ret; + if (tctx != nullptr) + { + retHash = TCtxNode::computeNodeHash(ret, curCVal); + } + } + else if (tctx != nullptr) + { + // now we need the hash + retHash = TCtxNode::computeNodeHash(cur, curCVal); } // did we rewrite ret (at post-rewrite)? Node rret; // only if not ONCE policy, which only does pre-rewrite if (d_policy != TConvPolicy::ONCE) { - rret = getRewriteStep(ret); + rret = getRewriteStepInternal(retHash); } if (!rret.isNull()) { - if (cur != ret) - { - visit.push_back(cur); - } + Trace("tconv-pf-gen-rewrite") + << "*** " << retHash << " postrewrites to " << rret << std::endl; // d_proof should have a proof of ret = rret, hence nothing to do // here, for the same reasons as above. It also may be the case that // rret rewrites, hence we must revisit ret. - rewritten[ret] = rret; - visit.push_back(ret); - visit.push_back(rret); + rewritten[retHash] = rret; + if (tctx != nullptr) + { + if (cur != ret) + { + visitctx->push(cur, curCVal); + } + visitctx->push(ret, curCVal); + visitctx->push(rret, curCVal); + } + else + { + if (cur != ret) + { + visit.push_back(cur); + } + visit.push_back(ret); + visit.push_back(rret); + } } else { + Trace("tconv-pf-gen-rewrite") + << "-> (postrewrite) " << curHash << " = " << ret << std::endl; // it is final - visited[cur] = ret; - doCache(cur, ret, pf); + Assert(!ret.isNull()); + visited[curHash] = ret; + doCache(curHash, cur, ret, pf); } } } - } while (!visit.empty()); - Assert(visited.find(t) != visited.end()); - Assert(!visited.find(t)->second.isNull()); + else + { + Trace("tconv-pf-gen-rewrite") << "- already visited" << std::endl; + } + } while (!(tctx != nullptr ? 
visitctx->empty() : visit.empty())); + Assert(visited.find(tinitialHash) != visited.end()); + Assert(!visited.find(tinitialHash)->second.isNull()); + Trace("tconv-pf-gen-rewrite") + << "...finished, return " << visited[tinitialHash] << std::endl; // return the conclusion of the overall proof - return t.eqNode(visited[t]); + return t.eqNode(visited[tinitialHash]); } -void TConvProofGenerator::doCache(Node cur, Node r, LazyCDProof& pf) +void TConvProofGenerator::doCache(Node curHash, + Node cur, + Node r, + LazyCDProof& pf) { if (d_cpolicy != TConvCachePolicy::NEVER) { Node eq = cur.eqNode(r); - d_cache[cur] = pf.getProofFor(eq); + d_cache[curHash] = pf.getProofFor(eq); } } -Node TConvProofGenerator::getRewriteStep(Node t) const +Node TConvProofGenerator::getRewriteStepInternal(Node t) const { NodeNodeMap::const_iterator it = d_rewriteMap.find(t); if (it == d_rewriteMap.end()) @@ -326,4 +513,12 @@ Node TConvProofGenerator::getRewriteStep(Node t) const } std::string TConvProofGenerator::identify() const { return d_name; } +std::string TConvProofGenerator::toStringDebug() const +{ + std::stringstream ss; + ss << identify() << " (policy=" << d_policy << ", cache policy=" << d_cpolicy + << (d_tcontext != nullptr ? ", term-context-sensitive" : "") << ")"; + return ss.str(); +} + } // namespace CVC4 diff --git a/src/expr/term_conversion_proof_generator.h b/src/expr/term_conversion_proof_generator.h index e634b8a83..faee2b9e3 100644 --- a/src/expr/term_conversion_proof_generator.h +++ b/src/expr/term_conversion_proof_generator.h @@ -21,6 +21,7 @@ #include "expr/lazy_proof.h" #include "expr/proof_generator.h" #include "expr/proof_node_manager.h" +#include "expr/term_context.h" namespace CVC4 { @@ -84,6 +85,35 @@ std::ostream& operator<<(std::ostream& out, TConvCachePolicy tcpol); * addRewriteStep. In particular, notice that in the above example, we realize * that f(a) --> c at pre-rewrite instead of post-rewriting a --> b and then * ending with f(a)=f(b). 
+ * + * This class may additionally be used for term-context-sensitive rewrite + * systems. An example is the term formula removal pass which rewrites + * terms depending on whether they occur in a "term position", for details + * see RtfTermContext in expr/term_context.h. To use this class in a way + * that takes into account term contexts, the user of the term conversion + * proof generator should: + * (1) Provide a term context callback to the constructor of this class (tccb), + * (2) Register rewrite steps that indicate the term context identifier of + * the rewrite, which is a uint32_t. + * + * For example, RtfTermContext uses hash value 2 to indicate we are in a "term + * position". Say the user of this class calls: + * addRewriteStep( (and A B), BOOLEAN_TERM_VARIABLE_1, pg, true, 2) + * This indicates that (and A B) should rewrite to BOOLEAN_TERM_VARIABLE_1 if + * (and A B) occurs in a term position, where pg is a proof generator that can + * provide a closed proof of: + * (= (and A B) BOOLEAN_TERM_VARIABLE_1) + * Subsequently, this class may respond to a call to getProofFor on: + * (= + * (or (and A B) (P (and A B))) + * (or (and A B) (P BOOLEAN_TERM_VARIABLE_1))) + * where P is a predicate Bool -> Bool. The proof returned by this class + * involves congruence and pg's proof of the equivalence above. In particular, + * assuming its proof of the equivalence is P1, this proof is: + * (CONG{=} (CONG{or} (REFL (and A B)) (CONG{P} P1))) + * Notice the callback provided to this class ensures that the rewrite is + * replayed in the expected way, e.g. the occurrence of (and A B) that is not + * in term position is not rewritten. */ class TConvProofGenerator : public ProofGenerator { @@ -99,32 +129,48 @@ class TConvProofGenerator : public ProofGenerator * details, see d_policy. * @param cpol The caching policy for this generator. * @param name The name of this generator (for debugging). + * @param tccb The term context callback that this class depends on.
If this + * is non-null, then this class stores a term-context-sensitive rewrite + * system. The rewrite steps should be given term context identifiers. */ TConvProofGenerator(ProofNodeManager* pnm, context::Context* c = nullptr, TConvPolicy pol = TConvPolicy::FIXPOINT, TConvCachePolicy cpol = TConvCachePolicy::NEVER, - std::string name = "TConvProofGenerator"); + std::string name = "TConvProofGenerator", + TermContext* tccb = nullptr); ~TConvProofGenerator(); /** * Add rewrite step t --> s based on proof generator. + * + * @param isClosed whether to expect that pg can provide a closed proof for + * this fact. + * @param tctx The term context identifier for the rewrite step. This + * value should correspond to one generated by the term context callback + * class provided in the argument tccb provided to the constructor of this + * class. */ - void addRewriteStep(Node t, Node s, ProofGenerator* pg); + void addRewriteStep(Node t, + Node s, + ProofGenerator* pg, + bool isClosed = true, + uint32_t tctx = 0); /** Same as above, for a single step */ - void addRewriteStep(Node t, Node s, ProofStep ps); + void addRewriteStep(Node t, Node s, ProofStep ps, uint32_t tctx = 0); /** Same as above, with explicit arguments */ void addRewriteStep(Node t, Node s, PfRule id, const std::vector& children, - const std::vector& args); + const std::vector& args, + uint32_t tctx = 0); /** Has rewrite step for term t */ - bool hasRewriteStep(Node t) const; + bool hasRewriteStep(Node t, uint32_t tctx = 0) const; /** * Get rewrite step for term t, returns the s provided in a call to * addRewriteStep if one exists, or null otherwise. */ - Node getRewriteStep(Node t) const; + Node getRewriteStep(Node t, uint32_t tctx = 0) const; /** * Get the proof for formula f. 
It should be the case that f is of the form * t = t', where t' is the result of rewriting t based on the rewrite steps @@ -161,19 +207,25 @@ class TConvProofGenerator : public ProofGenerator std::string d_name; /** The cache for terms */ std::map > d_cache; + /** An (optional) term context object */ + TermContext* d_tcontext; + /** Get rewrite step for (hash value of) term. */ + Node getRewriteStepInternal(Node thash) const; /** * Adds a proof of t = t' to the proof pf where t' is the result of rewriting * t based on the rewrite steps registered to this class. This method then * returns the proved equality t = t'. */ - Node getProofForRewriting(Node t, LazyCDProof& pf); + Node getProofForRewriting(Node t, LazyCDProof& pf, TermContext* tc = nullptr); /** * Register rewrite step, returns the equality t=s if t is distinct from s * and a rewrite step has not already been registered for t. */ - Node registerRewriteStep(Node t, Node s); + Node registerRewriteStep(Node t, Node s, uint32_t tctx); /** cache that r is the rewritten form of cur, pf can provide a proof */ - void doCache(Node cur, Node r, LazyCDProof& pf); + void doCache(Node curHash, Node cur, Node r, LazyCDProof& pf); + /** get debug information on this generator */ + std::string toStringDebug() const; }; } // namespace CVC4 -- 2.30.2