From 86569ce68ed002aeb31d102511d3c9bd8384a7ec Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 23 Apr 2021 13:31:37 -0500 Subject: [PATCH] Add new substitution apply methods fixpoint, sequential, simultaneous (#6429) This adds various methods for applying substitution as an options to MACRO_SR_* rules. It extends the proof checker and the proof post processor to eliminate based on these types. It updates the trust substitutions utility used by non-clausal simplification to use fixed-point semantics for substitution, which is highly important for efficiency. As a result of this PR, we are orders of magnitude faster for checking proofs for problems where non-clausal substitution infers many substitutions. It also makes our regressions noticeably faster: --- src/expr/term_conversion_proof_generator.cpp | 19 +++ src/expr/term_conversion_proof_generator.h | 8 + src/smt/proof_post_processor.cpp | 105 +++++++------ src/theory/builtin/proof_checker.cpp | 146 +++++++++++-------- src/theory/builtin/proof_checker.h | 28 +++- src/theory/strings/infer_proof_cons.cpp | 8 +- src/theory/theory_proof_step_buffer.cpp | 12 +- src/theory/theory_proof_step_buffer.h | 4 + src/theory/trust_substitutions.cpp | 9 +- 9 files changed, 225 insertions(+), 114 deletions(-) diff --git a/src/expr/term_conversion_proof_generator.cpp b/src/expr/term_conversion_proof_generator.cpp index 18057e149..3d68fd181 100644 --- a/src/expr/term_conversion_proof_generator.cpp +++ b/src/expr/term_conversion_proof_generator.cpp @@ -19,6 +19,7 @@ #include "expr/proof_checker.h" #include "expr/proof_node.h" +#include "expr/proof_node_algorithm.h" #include "expr/term_context.h" #include "expr/term_context_stack.h" @@ -236,6 +237,24 @@ std::shared_ptr TConvProofGenerator::getProofFor(Node f) return pfn; } +std::shared_ptr TConvProofGenerator::getProofForRewriting(Node n) +{ + LazyCDProof lpf( + d_proof.getManager(), &d_proof, nullptr, d_name + "::LazyCDProofRew"); + Node conc = getProofForRewriting(n, lpf, d_tcontext); + if (conc[1] == n) + { + // assertion failure in debug + Assert(false) << "TConvProofGenerator::getProofForRewriting: " << identify() + << ": don't ask for trivial proofs"; + lpf.addStep(conc, PfRule::REFL, {}, {n}); + } + std::shared_ptr pfn = lpf.getProofFor(conc); + Assert(pfn != nullptr); + Trace("tconv-pf-gen-debug") << "... proof is " << *pfn << std::endl; + return pfn; +} + Node TConvProofGenerator::getProofForRewriting(Node t, LazyCDProof& pf, TermContext* tctx) diff --git a/src/expr/term_conversion_proof_generator.h b/src/expr/term_conversion_proof_generator.h index bc0987478..e546d23bd 100644 --- a/src/expr/term_conversion_proof_generator.h +++ b/src/expr/term_conversion_proof_generator.h @@ -190,6 +190,14 @@ class TConvProofGenerator : public ProofGenerator std::shared_ptr getProofFor(Node f) override; /** Identify this generator (for debugging, etc..) */ std::string identify() const override; + /** + * Get the proof for how term n would rewrite. This is in contrast to the + * above method where the user provides an equality (= n n'). The motivation + * for this method is when it may be expensive to compute n', and hence it + * is preferred that the proof checker computes the rewritten form of + * n, instead of verifying that n has rewritten form n'. + */ + std::shared_ptr getProofForRewriting(Node n); protected: typedef context::CDHashMap NodeNodeMap; diff --git a/src/smt/proof_post_processor.cpp b/src/smt/proof_post_processor.cpp index 16b7f560b..b36b00bd5 100644 --- a/src/smt/proof_post_processor.cpp +++ b/src/smt/proof_post_processor.cpp @@ -427,19 +427,27 @@ Node ProofPostprocessCallback::expandMacros(PfRule id, { std::vector sargs; sargs.push_back(t); - MethodId sid = MethodId::SB_DEFAULT; + MethodId ids = MethodId::SB_DEFAULT; if (args.size() >= 2) { - if (builtin::BuiltinProofRuleChecker::getMethodId(args[1], sid)) + if (builtin::BuiltinProofRuleChecker::getMethodId(args[1], ids)) { sargs.push_back(args[1]); } } - ts = - builtin::BuiltinProofRuleChecker::applySubstitution(t, children, sid); + MethodId ida = MethodId::SBA_SEQUENTIAL; + if (args.size() >= 3) + { + if (builtin::BuiltinProofRuleChecker::getMethodId(args[2], ida)) + { + sargs.push_back(args[2]); + } + } + ts = builtin::BuiltinProofRuleChecker::applySubstitution( + t, children, ids, ida); Trace("smt-proof-pp-debug") << "...eq intro subs equality is " << t << " == " << ts << ", from " - << sid << std::endl; + << ids << " " << ida << std::endl; if (ts != t) { Node eq = t.eqNode(ts); @@ -459,21 +467,21 @@ Node ProofPostprocessCallback::expandMacros(PfRule id, } std::vector rargs; rargs.push_back(ts); - MethodId rid = MethodId::RW_REWRITE; - if (args.size() >= 3) + MethodId idr = MethodId::RW_REWRITE; + if (args.size() >= 4) { - if (builtin::BuiltinProofRuleChecker::getMethodId(args[2], rid)) + if (builtin::BuiltinProofRuleChecker::getMethodId(args[3], idr)) { - rargs.push_back(args[2]); + rargs.push_back(args[3]); } } builtin::BuiltinProofRuleChecker* builtinPfC = static_cast( d_pnm->getChecker()->getCheckerFor(PfRule::MACRO_SR_EQ_INTRO)); - Node tr = builtinPfC->applyRewrite(ts, rid); + Node tr = builtinPfC->applyRewrite(ts, idr); Trace("smt-proof-pp-debug") << "...eq intro rewrite equality is " << ts << " == " << tr << ", from " - << rid << std::endl; + << idr << std::endl; if (ts != tr) { Node eq = ts.eqNode(tr); @@ -797,6 +805,11 @@ Node ProofPostprocessCallback::expandMacros(PfRule id, { builtin::BuiltinProofRuleChecker::getMethodId(args[1], ids); } + MethodId ida = MethodId::SBA_SEQUENTIAL; + if (args.size() >= 3) + { + builtin::BuiltinProofRuleChecker::getMethodId(args[2], ida); + } std::vector> pfs; std::vector vsList; std::vector ssList; @@ -834,7 +847,7 @@ Node ProofPostprocessCallback::expandMacros(PfRule id, << "...process " << var << " -> " << subs << " (" << childFrom << ", " << ids << ")" << std::endl; // apply the current substitution to the range - if (!vvec.empty()) + if (!vvec.empty() && ida == MethodId::SBA_SEQUENTIAL) { Node ss = subs.substitute(vvec.begin(), vvec.end(), svec.begin(), svec.end()); @@ -889,43 +902,47 @@ Node ProofPostprocessCallback::expandMacros(PfRule id, svec.push_back(subs); pgs.push_back(cdp); } - Node ts = t.substitute(vvec.begin(), vvec.end(), svec.begin(), svec.end()); - Node eq = t.eqNode(ts); - if (ts != t) + // should be implied by the substitution now + TConvPolicy tcpolicy = ida == MethodId::SBA_FIXPOINT ? TConvPolicy::FIXPOINT + : TConvPolicy::ONCE; + TConvProofGenerator tcpg(d_pnm, + nullptr, + tcpolicy, + TConvCachePolicy::NEVER, + "SUBS_TConvProofGenerator", + nullptr, + true); + for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++) { - // should be implied by the substitution now - TConvProofGenerator tcpg(d_pnm, - nullptr, - TConvPolicy::ONCE, - TConvCachePolicy::NEVER, - "SUBS_TConvProofGenerator", - nullptr, - true); - for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++) - { - // substitutions are pre-rewrites - tcpg.addRewriteStep(vvec[j], svec[j], pgs[j], true); - } - // add the proof constructed by the term conversion utility - std::shared_ptr pfn = tcpg.getProofFor(eq); - // should give a proof, if not, then tcpg does not agree with the - // substitution. - Assert(pfn != nullptr); - if (pfn == nullptr) - { - cdp->addStep(eq, PfRule::TRUST_SUBS, {}, {eq}); - } - else - { - cdp->addProof(pfn); - } + // substitutions are pre-rewrites + tcpg.addRewriteStep(vvec[j], svec[j], pgs[j], true); + } + // add the proof constructed by the term conversion utility + std::shared_ptr pfn = tcpg.getProofForRewriting(t); + Node eq = pfn->getResult(); + Node ts = builtin::BuiltinProofRuleChecker::applySubstitution( + t, children, ids, ida); + Node eqq = t.eqNode(ts); + if (eq != eqq) + { + pfn = nullptr; + } + // should give a proof, if not, then tcpg does not agree with the + // substitution. + Assert(pfn != nullptr); + if (pfn == nullptr) + { + AlwaysAssert(false) << "resort to TRUST_SUBS" << std::endl + << eq << std::endl + << eqq << std::endl + << "from " << children << " applied to " << t; + cdp->addStep(eqq, PfRule::TRUST_SUBS, {}, {eqq}); } else { - // should not be necessary typically - cdp->addStep(eq, PfRule::REFL, {}, {t}); + cdp->addProof(pfn); } - return eq; + return eqq; } else if (id == PfRule::REWRITE) { diff --git a/src/theory/builtin/proof_checker.cpp b/src/theory/builtin/proof_checker.cpp index 81cc75983..fdc952bdd 100644 --- a/src/theory/builtin/proof_checker.cpp +++ b/src/theory/builtin/proof_checker.cpp @@ -19,6 +19,7 @@ #include "smt/term_formula_removal.h" #include "theory/evaluator.h" #include "theory/rewriter.h" +#include "theory/substitutions.h" #include "theory/theory.h" using namespace cvc5::kind; @@ -40,6 +41,9 @@ const char* toString(MethodId id) case MethodId::SB_DEFAULT: return "SB_DEFAULT"; case MethodId::SB_LITERAL: return "SB_LITERAL"; case MethodId::SB_FORMULA: return "SB_FORMULA"; + case MethodId::SBA_SEQUENTIAL: return "SBA_SEQUENTIAL"; + case MethodId::SBA_SIMUL: return "SBA_SIMUL"; + case MethodId::SBA_FIXPOINT: return "SBA_FIXPOINT"; default: return "MethodId::Unknown"; }; } @@ -84,9 +88,13 @@ void BuiltinProofRuleChecker::registerTo(ProofChecker* pc) } Node BuiltinProofRuleChecker::applySubstitutionRewrite( - Node n, const std::vector& exp, MethodId ids, MethodId idr) + Node n, + const std::vector& exp, + MethodId ids, + MethodId ida, + MethodId idr) { - Node nks = applySubstitution(n, exp, ids); + Node nks = applySubstitution(n, exp, ids, ida); return applyRewrite(nks, idr); } @@ -187,48 +195,59 @@ bool BuiltinProofRuleChecker::getSubstitutionFor(Node exp, return ret; } -Node BuiltinProofRuleChecker::applySubstitution(Node n, Node exp, MethodId ids) +Node BuiltinProofRuleChecker::applySubstitution(Node n, + Node exp, + MethodId ids, + MethodId ida) { - std::vector vars; - std::vector subs; - std::vector from; - if (!getSubstitutionFor(exp, vars, subs, from, ids)) - { - return Node::null(); - } - Node ns = n; - // apply substitution one at a time, in reverse order - for (size_t i = 0, nvars = vars.size(); i < nvars; i++) - { - TNode v = vars[nvars - 1 - i]; - TNode s = subs[nvars - 1 - i]; - Trace("builtin-pfcheck-debug") - << "applySubstitution (" << ids << "): " << v << " -> " << s - << " (from " << exp << ")" << std::endl; - ns = ns.substitute(v, s); - } - return ns; + std::vector expv{exp}; + return applySubstitution(n, expv, ids, ida); } Node BuiltinProofRuleChecker::applySubstitution(Node n, const std::vector& exp, - MethodId ids) + MethodId ids, + MethodId ida) { - Node curr = n; - // apply substitution one at a time, in reverse order + std::vector vars; + std::vector subs; + std::vector from; for (size_t i = 0, nexp = exp.size(); i < nexp; i++) { - if (exp[nexp - 1 - i].isNull()) + if (exp[i].isNull()) { return Node::null(); } - curr = applySubstitution(curr, exp[nexp - 1 - i], ids); - if (curr.isNull()) + if (!getSubstitutionFor(exp[i], vars, subs, from, ids)) { - break; + return Node::null(); } } - return curr; + if (ida == MethodId::SBA_SIMUL) + { + // simply apply the simultaneous substitution now + return n.substitute(vars.begin(), vars.end(), subs.begin(), subs.end()); + } + else if (ida == MethodId::SBA_FIXPOINT) + { + SubstitutionMap sm; + for (size_t i = 0, nvars = vars.size(); i < nvars; i++) + { + sm.addSubstitution(vars[i], subs[i]); + } + Node ns = sm.apply(n); + return ns; + } + Assert(ida == MethodId::SBA_SEQUENTIAL); + // we prefer n traversals of the term to n^2/2 traversals of range terms + Node ns = n; + for (size_t i = 0, nvars = vars.size(); i < nvars; i++) + { + TNode v = vars[(nvars - 1) - i]; + TNode s = subs[(nvars - 1) - i]; + ns = ns.substitute(v, s); + } + return ns; } bool BuiltinProofRuleChecker::getMethodId(TNode n, MethodId& i) @@ -321,13 +340,13 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id, } else if (id == PfRule::MACRO_SR_EQ_INTRO) { - Assert(1 <= args.size() && args.size() <= 3); - MethodId ids, idr; - if (!getMethodIds(args, ids, idr, 1)) + Assert(1 <= args.size() && args.size() <= 4); + MethodId ids, ida, idr; + if (!getMethodIds(args, ids, ida, idr, 1)) { return Node::null(); } - Node res = applySubstitutionRewrite(args[0], children, ids, idr); + Node res = applySubstitutionRewrite(args[0], children, ids, ida, idr); if (res.isNull()) { return Node::null(); @@ -338,13 +357,13 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id, { Trace("builtin-pfcheck") << "Check " << id << " " << children.size() << " " << args[0] << std::endl; - Assert(1 <= args.size() && args.size() <= 3); - MethodId ids, idr; - if (!getMethodIds(args, ids, idr, 1)) + Assert(1 <= args.size() && args.size() <= 4); + MethodId ids, ida, idr; + if (!getMethodIds(args, ids, ida, idr, 1)) { return Node::null(); } - Node res = applySubstitutionRewrite(args[0], children, ids, idr); + Node res = applySubstitutionRewrite(args[0], children, ids, ida, idr); if (res.isNull()) { return Node::null(); @@ -369,15 +388,15 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id, Trace("builtin-pfcheck") << "Check " << id << " " << children.size() << " " << args.size() << std::endl; Assert(children.size() >= 1); - Assert(args.size() <= 2); + Assert(args.size() <= 3); std::vector exp; exp.insert(exp.end(), children.begin() + 1, children.end()); - MethodId ids, idr; - if (!getMethodIds(args, ids, idr, 0)) + MethodId ids, ida, idr; + if (!getMethodIds(args, ids, ida, idr, 0)) { return Node::null(); } - Node res1 = applySubstitutionRewrite(children[0], exp, ids, idr); + Node res1 = applySubstitutionRewrite(children[0], exp, ids, ida, idr); Trace("builtin-pfcheck") << "Returned " << res1 << std::endl; return res1; } @@ -386,17 +405,17 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id, Trace("builtin-pfcheck") << "Check " << id << " " << children.size() << " " << args.size() << std::endl; Assert(children.size() >= 1); - Assert(1 <= args.size() && args.size() <= 3); + Assert(1 <= args.size() && args.size() <= 4); Assert(args[0].getType().isBoolean()); - MethodId ids, idr; - if (!getMethodIds(args, ids, idr, 1)) + MethodId ids, ida, idr; + if (!getMethodIds(args, ids, ida, idr, 1)) { return Node::null(); } std::vector exp; exp.insert(exp.end(), children.begin() + 1, children.end()); - Node res1 = applySubstitutionRewrite(children[0], exp, ids, idr); - Node res2 = applySubstitutionRewrite(args[0], exp, ids, idr); + Node res1 = applySubstitutionRewrite(children[0], exp, ids, ida, idr); + Node res2 = applySubstitutionRewrite(args[0], exp, ids, ida, idr); // if not already equal, do rewriting if (res1 != res2) { @@ -438,27 +457,28 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id, bool BuiltinProofRuleChecker::getMethodIds(const std::vector& args, MethodId& ids, + MethodId& ida, MethodId& idr, size_t index) { ids = MethodId::SB_DEFAULT; + ida = MethodId::SBA_SEQUENTIAL; idr = MethodId::RW_REWRITE; - if (args.size() > index) + for (size_t offset = 0; offset <= 2; offset++) { - if (!getMethodId(args[index], ids)) + if (args.size() > index + offset) { - Trace("builtin-pfcheck") - << "Failed to get id from " << args[index] << std::endl; - return false; + MethodId& id = offset == 0 ? ids : (offset == 1 ? ida : idr); + if (!getMethodId(args[index + offset], id)) + { + Trace("builtin-pfcheck") + << "Failed to get id from " << args[index + offset] << std::endl; + return false; + } } - } - if (args.size() > index + 1) - { - if (!getMethodId(args[index + 1], idr)) + else { - Trace("builtin-pfcheck") - << "Failed to get id from " << args[index + 1] << std::endl; - return false; + break; } } return true; @@ -466,13 +486,19 @@ bool BuiltinProofRuleChecker::getMethodIds(const std::vector& args, void BuiltinProofRuleChecker::addMethodIds(std::vector& args, MethodId ids, + MethodId ida, MethodId idr) { bool ndefRewriter = (idr != MethodId::RW_REWRITE); - if (ids != MethodId::SB_DEFAULT || ndefRewriter) + bool ndefApply = (ida != MethodId::SBA_SEQUENTIAL); + if (ids != MethodId::SB_DEFAULT || ndefRewriter || ndefApply) { args.push_back(mkMethodId(ids)); } + if (ndefApply || ndefRewriter) + { + args.push_back(mkMethodId(ida)); + } if (ndefRewriter) { args.push_back(mkMethodId(idr)); diff --git a/src/theory/builtin/proof_checker.h b/src/theory/builtin/proof_checker.h index 38eea31c5..81da0a969 100644 --- a/src/theory/builtin/proof_checker.h +++ b/src/theory/builtin/proof_checker.h @@ -64,6 +64,21 @@ enum class MethodId : uint32_t SB_LITERAL, // P is interpreted as P -> true using Node::substitute SB_FORMULA, + //---------------------------- Substitution applications + // multiple substitutions are applied sequentially + SBA_SEQUENTIAL, + // multiple substitutions are applied simultaneously + SBA_SIMUL, + // multiple substitutions are applied to fix point + SBA_FIXPOINT + // For example, for x -> u, y -> f(z), z -> g(x), applying this substituion to + // y gives: + // - f(g(x)) for SBA_SEQUENTIAL + // - f(z) for SBA_SIMUL + // - f(g(u)) for SBA_FIXPOINT + // Notice that SBA_FIXPOINT should provide a terminating rewrite system + // as a substitution, or else non-termination will occur during proof + // checking. }; /** Converts a rewriter id to a string. */ const char* toString(MethodId id); @@ -126,10 +141,12 @@ class BuiltinProofRuleChecker : public ProofRuleChecker */ static Node applySubstitution(Node n, Node exp, - MethodId ids = MethodId::SB_DEFAULT); + MethodId ids = MethodId::SB_DEFAULT, + MethodId ida = MethodId::SBA_SEQUENTIAL); static Node applySubstitution(Node n, const std::vector& exp, - MethodId ids = MethodId::SB_DEFAULT); + MethodId ids = MethodId::SB_DEFAULT, + MethodId ida = MethodId::SBA_SEQUENTIAL); /** Apply substitution + rewriting * * Combines the above two steps. @@ -143,6 +160,7 @@ class BuiltinProofRuleChecker : public ProofRuleChecker Node applySubstitutionRewrite(Node n, const std::vector& exp, MethodId ids = MethodId::SB_DEFAULT, + MethodId ida = MethodId::SBA_SEQUENTIAL, MethodId idr = MethodId::RW_REWRITE); /** get a method identifier from a node, return false if we fail */ static bool getMethodId(TNode n, MethodId& i); @@ -153,13 +171,17 @@ class BuiltinProofRuleChecker : public ProofRuleChecker */ bool getMethodIds(const std::vector& args, MethodId& ids, + MethodId& ida, MethodId& idr, size_t index); /** * Add method identifiers ids and idr as nodes to args. This does not add ids * or idr if their values are the default ones. */ - static void addMethodIds(std::vector& args, MethodId ids, MethodId idr); + static void addMethodIds(std::vector& args, + MethodId ids, + MethodId ida, + MethodId idr); /** get a TheoryId from a node, return false if we fail */ static bool getTheoryId(TNode n, TheoryId& tid); diff --git a/src/theory/strings/infer_proof_cons.cpp b/src/theory/strings/infer_proof_cons.cpp index 2351e7bf3..bc02dcbb6 100644 --- a/src/theory/strings/infer_proof_cons.cpp +++ b/src/theory/strings/infer_proof_cons.cpp @@ -149,8 +149,11 @@ void InferProofCons::convert(InferenceId infer, break; } // may need the "extended equality rewrite" - Node mainEqSRew2 = psb.applyPredElim( - mainEqSRew, {}, MethodId::SB_DEFAULT, MethodId::RW_REWRITE_EQ_EXT); + Node mainEqSRew2 = psb.applyPredElim(mainEqSRew, + {}, + MethodId::SB_DEFAULT, + MethodId::SBA_SEQUENTIAL, + MethodId::RW_REWRITE_EQ_EXT); if (mainEqSRew2 == conc) { useBuffer = true; @@ -286,6 +289,7 @@ void InferProofCons::convert(InferenceId infer, conc, cexp, MethodId::SB_DEFAULT, + MethodId::SBA_SEQUENTIAL, MethodId::RW_REWRITE_EQ_EXT)) { Trace("strings-ipc-core") << "Transformed to " << conc diff --git a/src/theory/theory_proof_step_buffer.cpp b/src/theory/theory_proof_step_buffer.cpp index 06b4b7ad7..70f35afb8 100644 --- a/src/theory/theory_proof_step_buffer.cpp +++ b/src/theory/theory_proof_step_buffer.cpp @@ -31,11 +31,12 @@ bool TheoryProofStepBuffer::applyEqIntro(Node src, Node tgt, const std::vector& exp, MethodId ids, + MethodId ida, MethodId idr) { std::vector args; args.push_back(src); - builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, idr); + builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, ida, idr); Node res = tryStep(PfRule::MACRO_SR_EQ_INTRO, exp, args); if (res.isNull()) { @@ -58,6 +59,7 @@ bool TheoryProofStepBuffer::applyPredTransform(Node src, Node tgt, const std::vector& exp, MethodId ids, + MethodId ida, MethodId idr) { // symmetric equalities @@ -71,7 +73,7 @@ bool TheoryProofStepBuffer::applyPredTransform(Node src, // try to prove that tgt rewrites to src children.insert(children.end(), exp.begin(), exp.end()); args.push_back(tgt); - builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, idr); + builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, ida, idr); Node res = tryStep(PfRule::MACRO_SR_PRED_TRANSFORM, children, args); if (res.isNull()) { @@ -86,11 +88,12 @@ bool TheoryProofStepBuffer::applyPredTransform(Node src, bool TheoryProofStepBuffer::applyPredIntro(Node tgt, const std::vector& exp, MethodId ids, + MethodId ida, MethodId idr) { std::vector args; args.push_back(tgt); - builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, idr); + builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, ida, idr); Node res = tryStep(PfRule::MACRO_SR_PRED_INTRO, exp, args); if (res.isNull()) { @@ -103,13 +106,14 @@ bool TheoryProofStepBuffer::applyPredIntro(Node tgt, Node TheoryProofStepBuffer::applyPredElim(Node src, const std::vector& exp, MethodId ids, + MethodId ida, MethodId idr) { std::vector children; children.push_back(src); children.insert(children.end(), exp.begin(), exp.end()); std::vector args; - builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, idr); + builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, ida, idr); Node srcRew = tryStep(PfRule::MACRO_SR_PRED_ELIM, children, args); if (CDProof::isSame(src, srcRew)) { diff --git a/src/theory/theory_proof_step_buffer.h b/src/theory/theory_proof_step_buffer.h index 328266b89..1832ddb08 100644 --- a/src/theory/theory_proof_step_buffer.h +++ b/src/theory/theory_proof_step_buffer.h @@ -47,6 +47,7 @@ class TheoryProofStepBuffer : public ProofStepBuffer Node tgt, const std::vector& exp, MethodId ids = MethodId::SB_DEFAULT, + MethodId ida = MethodId::SBA_SEQUENTIAL, MethodId idr = MethodId::RW_REWRITE); /** * Apply predicate transform. If this method returns true, it adds (at most @@ -58,6 +59,7 @@ class TheoryProofStepBuffer : public ProofStepBuffer Node tgt, const std::vector& exp, MethodId ids = MethodId::SB_DEFAULT, + MethodId ida = MethodId::SBA_SEQUENTIAL, MethodId idr = MethodId::RW_REWRITE); /** * Apply predicate introduction. If this method returns true, it adds proof @@ -68,6 +70,7 @@ class TheoryProofStepBuffer : public ProofStepBuffer bool applyPredIntro(Node tgt, const std::vector& exp, MethodId ids = MethodId::SB_DEFAULT, + MethodId ida = MethodId::SBA_SEQUENTIAL, MethodId idr = MethodId::RW_REWRITE); /** * Apply predicate elimination. This method returns the result of applying @@ -83,6 +86,7 @@ class TheoryProofStepBuffer : public ProofStepBuffer Node applyPredElim(Node src, const std::vector& exp, MethodId ids = MethodId::SB_DEFAULT, + MethodId ida = MethodId::SBA_SEQUENTIAL, MethodId idr = MethodId::RW_REWRITE); //---------------------------- end utilities builtin proof rules diff --git a/src/theory/trust_substitutions.cpp b/src/theory/trust_substitutions.cpp index 3e5b04d86..d76356f75 100644 --- a/src/theory/trust_substitutions.cpp +++ b/src/theory/trust_substitutions.cpp @@ -204,7 +204,14 @@ std::shared_ptr TrustSubstitutionMap::getProofFor(Node eq) } } Trace("trust-subs-pf") << "...apply eq intro" << std::endl; - if (!d_tspb->applyEqIntro(n, ns, pfChildren, d_ids)) + // We use fixpoint as the substitution-apply identifier. Notice that it + // suffices to use SBA_SEQUENTIAL here, but SBA_FIXPOINT is typically + // more efficient. This is because for substitution of size n, sequential + // substitution can either be implemented as n traversals of the term to + // apply the substitution to, or a single traversal of the term, but n^2/2 + // traversals of the range of the substitution to prepare a simultaneous + // substitution. Both of these options are inefficient. + if (!d_tspb->applyEqIntro(n, ns, pfChildren, d_ids, MethodId::SBA_FIXPOINT)) { // if we fail for any reason, we must use a trusted step instead d_tspb->addStep(PfRule::TRUST_SUBS_MAP, pfChildren, {eq}, eq); -- 2.30.2