From: Andrew Reynolds Date: Fri, 18 Sep 2020 02:06:44 +0000 (-0500) Subject: (proof-new) Rewrites involving operators in term conversion proof generator (#5072) X-Git-Tag: cvc5-1.0.0~2841 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=cb438c1aca9e205359313f2e661fef910e1132b6;p=cvc5.git (proof-new) Rewrites involving operators in term conversion proof generator (#5072) In some cases, e.g. witness form conversion, we require traversing to operators in term conversion proofs. This updates the term conversion generator to support (term-context-sensitive) rewrites involving operators using HO_CONG. This requires updating the term context utilities with support for operators. --- diff --git a/src/expr/term_context.cpp b/src/expr/term_context.cpp index 6dcdc25ee..fcb8edfdc 100644 --- a/src/expr/term_context.cpp +++ b/src/expr/term_context.cpp @@ -16,6 +16,12 @@ namespace CVC4 { +uint32_t TermContext::computeValueOp(TNode t, uint32_t tval) const +{ + // default is no change + return tval; +} + uint32_t RtfTermContext::initialValue() const { // by default, not in a term context or a quantifier diff --git a/src/expr/term_context.h b/src/expr/term_context.h index 87f91f2df..0c51dd493 100644 --- a/src/expr/term_context.h +++ b/src/expr/term_context.h @@ -48,6 +48,11 @@ class TermContext * is the term context identifier of t. */ virtual uint32_t computeValue(TNode t, uint32_t tval, size_t index) const = 0; + /** + * Returns the term context identifier of the operator of t, where tval + * is the term context identifier of t. + */ + virtual uint32_t computeValueOp(TNode t, uint32_t tval) const; }; /** diff --git a/src/expr/term_context_stack.cpp b/src/expr/term_context_stack.cpp index 4c2e32033..4deb0627c 100644 --- a/src/expr/term_context_stack.cpp +++ b/src/expr/term_context_stack.cpp @@ -43,6 +43,13 @@ void TCtxStack::pushChild(Node t, uint32_t tval, size_t index) d_stack.push_back(std::pair(t[index], tcval)); } +void TCtxStack::pushOp(Node t, uint32_t tval) +{ + Assert(t.hasOperator()); + uint32_t toval = d_tctx->computeValueOp(t, tval); + d_stack.push_back(std::pair(t.getOperator(), toval)); +} + void TCtxStack::push(Node t, uint32_t tval) { d_stack.push_back(std::pair(t, tval)); diff --git a/src/expr/term_context_stack.h b/src/expr/term_context_stack.h index 9aeea04c2..6f5dece81 100644 --- a/src/expr/term_context_stack.h +++ b/src/expr/term_context_stack.h @@ -41,6 +41,11 @@ class TCtxStack * the term context hash of t. */ void pushChild(Node t, uint32_t tval, size_t index); + /** + * Push the operator of t to the stack, where tval is the term context has + * of t. + */ + void pushOp(Node t, uint32_t tval); /** Push t to the stack with term context hash tval. */ void push(Node t, uint32_t tval); /** Pop a term from the context */ diff --git a/src/expr/term_conversion_proof_generator.cpp b/src/expr/term_conversion_proof_generator.cpp index 8cd7561b4..215793318 100644 --- a/src/expr/term_conversion_proof_generator.cpp +++ b/src/expr/term_conversion_proof_generator.cpp @@ -48,13 +48,15 @@ TConvProofGenerator::TConvProofGenerator(ProofNodeManager* pnm, TConvPolicy pol, TConvCachePolicy cpol, std::string name, - TermContext* tccb) + TermContext* tccb, + bool rewriteOps) : d_proof(pnm, nullptr, c, name + "::LazyCDProof"), d_rewriteMap(c ? c : &d_context), d_policy(pol), d_cpolicy(cpol), d_name(name), - d_tcontext(tccb) + d_tcontext(tccb), + d_rewriteOps(rewriteOps) { } @@ -196,8 +198,11 @@ std::shared_ptr TConvProofGenerator::getProofFor(Node f) return nullptr; } } + std::shared_ptr pfn = lpf.getProofFor(f); Trace("tconv-pf-gen") << "... success" << std::endl; - return lpf.getProofFor(f); + Assert (pfn!=nullptr); + Trace("tconv-pf-gen-debug") << "... proof is " << *pfn << std::endl; + return pfn; } Node TConvProofGenerator::getProofForRewriting(Node t, @@ -313,11 +318,21 @@ Node TConvProofGenerator::getProofForRewriting(Node t, else if (tctx != nullptr) { visitctx->push(cur, curCVal); + // visit operator if apply uf + if (d_rewriteOps && cur.getKind() == APPLY_UF) + { + visitctx->pushOp(cur, curCVal); + } visitctx->pushChildren(cur, curCVal); } else { visit.push_back(cur); + // visit operator if apply uf + if (d_rewriteOps && cur.getKind() == APPLY_UF) + { + visit.push_back(cur.getOperator()); + } visit.insert(visit.end(), cur.begin(), cur.end()); } } @@ -365,8 +380,29 @@ Node TConvProofGenerator::getProofForRewriting(Node t, Node retHash = curHash; bool childChanged = false; std::vector children; - if (cur.getMetaKind() == metakind::PARAMETERIZED) + Kind ck = cur.getKind(); + if (d_rewriteOps && ck == APPLY_UF) + { + // the operator of APPLY_UF is visited + Node cop = cur.getOperator(); + if (tctx != nullptr) + { + uint32_t coval = tctx->computeValueOp(cur, curCVal); + Node coHash = TCtxNode::computeNodeHash(cop, coval); + it = visited.find(coHash); + } + else + { + it = visited.find(cop); + } + Assert(it != visited.end()); + Assert(!it->second.isNull()); + childChanged = childChanged || cop != it->second; + children.push_back(it->second); + } + else if (cur.getMetaKind() == metakind::PARAMETERIZED) { + // all other parametrized operators are unchanged children.push_back(cur.getOperator()); } // get the results of the children @@ -398,10 +434,23 @@ Node TConvProofGenerator::getProofForRewriting(Node t, } if (childChanged) { - ret = nm->mkNode(cur.getKind(), children); + ret = nm->mkNode(ck, children); rewritten[curHash] = ret; // congruence to show (cur = ret) + PfRule congRule = PfRule::CONG; std::vector pfChildren; + std::vector pfArgs; + pfArgs.push_back(ProofRuleChecker::mkKindNode(ck)); + if (ck == APPLY_UF && children[0] != cur.getOperator()) + { + // use HO_CONG if the operator changed + congRule = PfRule::HO_CONG; + pfChildren.push_back(cur.getOperator().eqNode(children[0])); + } + else if (kind::metaKindOf(ck) == kind::metakind::PARAMETERIZED) + { + pfArgs.push_back(cur.getOperator()); + } for (size_t i = 0, size = cur.getNumChildren(); i < size; i++) { if (cur[i] == ret[i]) @@ -411,15 +460,8 @@ Node TConvProofGenerator::getProofForRewriting(Node t, } pfChildren.push_back(cur[i].eqNode(ret[i])); } - std::vector pfArgs; - Kind k = cur.getKind(); - pfArgs.push_back(ProofRuleChecker::mkKindNode(k)); - if (kind::metaKindOf(k) == kind::metakind::PARAMETERIZED) - { - pfArgs.push_back(cur.getOperator()); - } Node result = cur.eqNode(ret); - pf.addStep(result, PfRule::CONG, pfChildren, pfArgs); + pf.addStep(result, congRule, pfChildren, pfArgs); // must update the hash retHash = ret; if (tctx != nullptr) diff --git a/src/expr/term_conversion_proof_generator.h b/src/expr/term_conversion_proof_generator.h index faee2b9e3..1153995ad 100644 --- a/src/expr/term_conversion_proof_generator.h +++ b/src/expr/term_conversion_proof_generator.h @@ -138,7 +138,8 @@ class TConvProofGenerator : public ProofGenerator TConvPolicy pol = TConvPolicy::FIXPOINT, TConvCachePolicy cpol = TConvCachePolicy::NEVER, std::string name = "TConvProofGenerator", - TermContext* tccb = nullptr); + TermContext* tccb = nullptr, + bool rewriteOps = false); ~TConvProofGenerator(); /** * Add rewrite step t --> s based on proof generator. @@ -209,6 +210,12 @@ class TConvProofGenerator : public ProofGenerator std::map > d_cache; /** An (optional) term context object */ TermContext* d_tcontext; + /** + * Whether we rewrite operators. If this flag is true, then the main + * traversal algorithm of this proof generator traverses operators of + * APPLY_UF and uses HO_CONG to justify rewriting of subterms when necessary. + */ + bool d_rewriteOps; /** Get rewrite step for (hash value of) term. */ Node getRewriteStepInternal(Node thash) const; /** diff --git a/src/smt/witness_form.cpp b/src/smt/witness_form.cpp index 19795119d..48fb2a1f7 100644 --- a/src/smt/witness_form.cpp +++ b/src/smt/witness_form.cpp @@ -25,7 +25,9 @@ WitnessFormGenerator::WitnessFormGenerator(ProofNodeManager* pnm) nullptr, TConvPolicy::FIXPOINT, TConvCachePolicy::NEVER, - "WfGenerator::TConvProofGenerator"), + "WfGenerator::TConvProofGenerator", + nullptr, + true), d_wintroPf(pnm, nullptr, nullptr, "WfGenerator::LazyCDProof") { } @@ -110,6 +112,10 @@ Node WitnessFormGenerator::convertToWitnessForm(Node t) // It should be the case that cur has children, since the witness // form of constants are themselves. Assert(cur.getNumChildren() > 0); + if (cur.hasOperator()) + { + visit.push_back(cur.getOperator()); + } visit.insert(visit.end(), cur.begin(), cur.end()); } }