Add new substitution apply methods fixpoint, sequential, simultaneous (#6429)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 23 Apr 2021 18:31:37 +0000 (13:31 -0500)
committerGitHub <noreply@github.com>
Fri, 23 Apr 2021 18:31:37 +0000 (18:31 +0000)
This adds various methods for applying substitution as an options to MACRO_SR_* rules. It extends the proof checker and the proof post processor to eliminate based on these types.

It updates the trust substitutions utility used by non-clausal simplification to use fixed-point semantics for substitution, which is highly important for efficiency.

As a result of this PR, we are orders of magnitude faster for checking proofs for problems where non-clausal substitution infers many substitutions. It also makes our regressions noticeably faster:

src/expr/term_conversion_proof_generator.cpp
src/expr/term_conversion_proof_generator.h
src/smt/proof_post_processor.cpp
src/theory/builtin/proof_checker.cpp
src/theory/builtin/proof_checker.h
src/theory/strings/infer_proof_cons.cpp
src/theory/theory_proof_step_buffer.cpp
src/theory/theory_proof_step_buffer.h
src/theory/trust_substitutions.cpp

index 18057e1495b7c1f59bfae89aad82c8923e6eb8c3..3d68fd181935d090ed165273ccd496c1f49db9c4 100644 (file)
@@ -19,6 +19,7 @@
 
 #include "expr/proof_checker.h"
 #include "expr/proof_node.h"
+#include "expr/proof_node_algorithm.h"
 #include "expr/term_context.h"
 #include "expr/term_context_stack.h"
 
@@ -236,6 +237,24 @@ std::shared_ptr<ProofNode> TConvProofGenerator::getProofFor(Node f)
   return pfn;
 }
 
+std::shared_ptr<ProofNode> TConvProofGenerator::getProofForRewriting(Node n)
+{
+  LazyCDProof lpf(
+      d_proof.getManager(), &d_proof, nullptr, d_name + "::LazyCDProofRew");
+  Node conc = getProofForRewriting(n, lpf, d_tcontext);
+  if (conc[1] == n)
+  {
+    // assertion failure in debug
+    Assert(false) << "TConvProofGenerator::getProofForRewriting: " << identify()
+                  << ": don't ask for trivial proofs";
+    lpf.addStep(conc, PfRule::REFL, {}, {n});
+  }
+  std::shared_ptr<ProofNode> pfn = lpf.getProofFor(conc);
+  Assert(pfn != nullptr);
+  Trace("tconv-pf-gen-debug") << "... proof is " << *pfn << std::endl;
+  return pfn;
+}
+
 Node TConvProofGenerator::getProofForRewriting(Node t,
                                                LazyCDProof& pf,
                                                TermContext* tctx)
index bc09874788a07a0f742c6fd1419ccb3d6b290725..e546d23bdee963b43db06faf50e3b188e49436fd 100644 (file)
@@ -190,6 +190,14 @@ class TConvProofGenerator : public ProofGenerator
   std::shared_ptr<ProofNode> getProofFor(Node f) override;
   /** Identify this generator (for debugging, etc..) */
   std::string identify() const override;
+  /**
+   * Get the proof for how term n would rewrite. This is in contrast to the
+   * above method where the user provides an equality (= n n'). The motivation
+   * for this method is when it may be expensive to compute n', and hence it
+   * is preferred that the proof checker computes the rewritten form of
+   * n, instead of verifying that n has rewritten form n'.
+   */
+  std::shared_ptr<ProofNode> getProofForRewriting(Node n);
 
  protected:
   typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeNodeMap;
index 16b7f560bade8af5f253683444dddd27e914c0da..b36b00bd582a7126caa3d2659afcf252bd06c500 100644 (file)
@@ -427,19 +427,27 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
     {
       std::vector<Node> sargs;
       sargs.push_back(t);
-      MethodId sid = MethodId::SB_DEFAULT;
+      MethodId ids = MethodId::SB_DEFAULT;
       if (args.size() >= 2)
       {
-        if (builtin::BuiltinProofRuleChecker::getMethodId(args[1], sid))
+        if (builtin::BuiltinProofRuleChecker::getMethodId(args[1], ids))
         {
           sargs.push_back(args[1]);
         }
       }
-      ts =
-          builtin::BuiltinProofRuleChecker::applySubstitution(t, children, sid);
+      MethodId ida = MethodId::SBA_SEQUENTIAL;
+      if (args.size() >= 3)
+      {
+        if (builtin::BuiltinProofRuleChecker::getMethodId(args[2], ida))
+        {
+          sargs.push_back(args[2]);
+        }
+      }
+      ts = builtin::BuiltinProofRuleChecker::applySubstitution(
+          t, children, ids, ida);
       Trace("smt-proof-pp-debug")
           << "...eq intro subs equality is " << t << " == " << ts << ", from "
-          << sid << std::endl;
+          << ids << " " << ida << std::endl;
       if (ts != t)
       {
         Node eq = t.eqNode(ts);
@@ -459,21 +467,21 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
     }
     std::vector<Node> rargs;
     rargs.push_back(ts);
-    MethodId rid = MethodId::RW_REWRITE;
-    if (args.size() >= 3)
+    MethodId idr = MethodId::RW_REWRITE;
+    if (args.size() >= 4)
     {
-      if (builtin::BuiltinProofRuleChecker::getMethodId(args[2], rid))
+      if (builtin::BuiltinProofRuleChecker::getMethodId(args[3], idr))
       {
-        rargs.push_back(args[2]);
+        rargs.push_back(args[3]);
       }
     }
     builtin::BuiltinProofRuleChecker* builtinPfC =
         static_cast<builtin::BuiltinProofRuleChecker*>(
             d_pnm->getChecker()->getCheckerFor(PfRule::MACRO_SR_EQ_INTRO));
-    Node tr = builtinPfC->applyRewrite(ts, rid);
+    Node tr = builtinPfC->applyRewrite(ts, idr);
     Trace("smt-proof-pp-debug")
         << "...eq intro rewrite equality is " << ts << " == " << tr << ", from "
-        << rid << std::endl;
+        << idr << std::endl;
     if (ts != tr)
     {
       Node eq = ts.eqNode(tr);
@@ -797,6 +805,11 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
     {
       builtin::BuiltinProofRuleChecker::getMethodId(args[1], ids);
     }
+    MethodId ida = MethodId::SBA_SEQUENTIAL;
+    if (args.size() >= 3)
+    {
+      builtin::BuiltinProofRuleChecker::getMethodId(args[2], ida);
+    }
     std::vector<std::shared_ptr<CDProof>> pfs;
     std::vector<TNode> vsList;
     std::vector<TNode> ssList;
@@ -834,7 +847,7 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
           << "...process " << var << " -> " << subs << " (" << childFrom << ", "
           << ids << ")" << std::endl;
       // apply the current substitution to the range
-      if (!vvec.empty())
+      if (!vvec.empty() && ida == MethodId::SBA_SEQUENTIAL)
       {
         Node ss =
             subs.substitute(vvec.begin(), vvec.end(), svec.begin(), svec.end());
@@ -889,43 +902,47 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
       svec.push_back(subs);
       pgs.push_back(cdp);
     }
-    Node ts = t.substitute(vvec.begin(), vvec.end(), svec.begin(), svec.end());
-    Node eq = t.eqNode(ts);
-    if (ts != t)
+    // should be implied by the substitution now
+    TConvPolicy tcpolicy = ida == MethodId::SBA_FIXPOINT ? TConvPolicy::FIXPOINT
+                                                         : TConvPolicy::ONCE;
+    TConvProofGenerator tcpg(d_pnm,
+                             nullptr,
+                             tcpolicy,
+                             TConvCachePolicy::NEVER,
+                             "SUBS_TConvProofGenerator",
+                             nullptr,
+                             true);
+    for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++)
     {
-      // should be implied by the substitution now
-      TConvProofGenerator tcpg(d_pnm,
-                               nullptr,
-                               TConvPolicy::ONCE,
-                               TConvCachePolicy::NEVER,
-                               "SUBS_TConvProofGenerator",
-                               nullptr,
-                               true);
-      for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++)
-      {
-        // substitutions are pre-rewrites
-        tcpg.addRewriteStep(vvec[j], svec[j], pgs[j], true);
-      }
-      // add the proof constructed by the term conversion utility
-      std::shared_ptr<ProofNode> pfn = tcpg.getProofFor(eq);
-      // should give a proof, if not, then tcpg does not agree with the
-      // substitution.
-      Assert(pfn != nullptr);
-      if (pfn == nullptr)
-      {
-        cdp->addStep(eq, PfRule::TRUST_SUBS, {}, {eq});
-      }
-      else
-      {
-        cdp->addProof(pfn);
-      }
+      // substitutions are pre-rewrites
+      tcpg.addRewriteStep(vvec[j], svec[j], pgs[j], true);
+    }
+    // add the proof constructed by the term conversion utility
+    std::shared_ptr<ProofNode> pfn = tcpg.getProofForRewriting(t);
+    Node eq = pfn->getResult();
+    Node ts = builtin::BuiltinProofRuleChecker::applySubstitution(
+        t, children, ids, ida);
+    Node eqq = t.eqNode(ts);
+    if (eq != eqq)
+    {
+      pfn = nullptr;
+    }
+    // should give a proof, if not, then tcpg does not agree with the
+    // substitution.
+    Assert(pfn != nullptr);
+    if (pfn == nullptr)
+    {
+      AlwaysAssert(false) << "resort to TRUST_SUBS" << std::endl
+                          << eq << std::endl
+                          << eqq << std::endl
+                          << "from " << children << " applied to " << t;
+      cdp->addStep(eqq, PfRule::TRUST_SUBS, {}, {eqq});
     }
     else
     {
-      // should not be necessary typically
-      cdp->addStep(eq, PfRule::REFL, {}, {t});
+      cdp->addProof(pfn);
     }
-    return eq;
+    return eqq;
   }
   else if (id == PfRule::REWRITE)
   {
index 81cc759837c1b1645a5d47fe3c52d08c87b37e74..fdc952bdd0a7d6362833dd3a4bf1dcf3639e739f 100644 (file)
@@ -19,6 +19,7 @@
 #include "smt/term_formula_removal.h"
 #include "theory/evaluator.h"
 #include "theory/rewriter.h"
+#include "theory/substitutions.h"
 #include "theory/theory.h"
 
 using namespace cvc5::kind;
@@ -40,6 +41,9 @@ const char* toString(MethodId id)
     case MethodId::SB_DEFAULT: return "SB_DEFAULT";
     case MethodId::SB_LITERAL: return "SB_LITERAL";
     case MethodId::SB_FORMULA: return "SB_FORMULA";
+    case MethodId::SBA_SEQUENTIAL: return "SBA_SEQUENTIAL";
+    case MethodId::SBA_SIMUL: return "SBA_SIMUL";
+    case MethodId::SBA_FIXPOINT: return "SBA_FIXPOINT";
     default: return "MethodId::Unknown";
   };
 }
@@ -84,9 +88,13 @@ void BuiltinProofRuleChecker::registerTo(ProofChecker* pc)
 }
 
 Node BuiltinProofRuleChecker::applySubstitutionRewrite(
-    Node n, const std::vector<Node>& exp, MethodId ids, MethodId idr)
+    Node n,
+    const std::vector<Node>& exp,
+    MethodId ids,
+    MethodId ida,
+    MethodId idr)
 {
-  Node nks = applySubstitution(n, exp, ids);
+  Node nks = applySubstitution(n, exp, ids, ida);
   return applyRewrite(nks, idr);
 }
 
@@ -187,48 +195,59 @@ bool BuiltinProofRuleChecker::getSubstitutionFor(Node exp,
   return ret;
 }
 
-Node BuiltinProofRuleChecker::applySubstitution(Node n, Node exp, MethodId ids)
+Node BuiltinProofRuleChecker::applySubstitution(Node n,
+                                                Node exp,
+                                                MethodId ids,
+                                                MethodId ida)
 {
-  std::vector<TNode> vars;
-  std::vector<TNode> subs;
-  std::vector<TNode> from;
-  if (!getSubstitutionFor(exp, vars, subs, from, ids))
-  {
-    return Node::null();
-  }
-  Node ns = n;
-  // apply substitution one at a time, in reverse order
-  for (size_t i = 0, nvars = vars.size(); i < nvars; i++)
-  {
-    TNode v = vars[nvars - 1 - i];
-    TNode s = subs[nvars - 1 - i];
-    Trace("builtin-pfcheck-debug")
-        << "applySubstitution (" << ids << "): " << v << " -> " << s
-        << " (from " << exp << ")" << std::endl;
-    ns = ns.substitute(v, s);
-  }
-  return ns;
+  std::vector<Node> expv{exp};
+  return applySubstitution(n, expv, ids, ida);
 }
 
 Node BuiltinProofRuleChecker::applySubstitution(Node n,
                                                 const std::vector<Node>& exp,
-                                                MethodId ids)
+                                                MethodId ids,
+                                                MethodId ida)
 {
-  Node curr = n;
-  // apply substitution one at a time, in reverse order
+  std::vector<TNode> vars;
+  std::vector<TNode> subs;
+  std::vector<TNode> from;
   for (size_t i = 0, nexp = exp.size(); i < nexp; i++)
   {
-    if (exp[nexp - 1 - i].isNull())
+    if (exp[i].isNull())
     {
       return Node::null();
     }
-    curr = applySubstitution(curr, exp[nexp - 1 - i], ids);
-    if (curr.isNull())
+    if (!getSubstitutionFor(exp[i], vars, subs, from, ids))
     {
-      break;
+      return Node::null();
     }
   }
-  return curr;
+  if (ida == MethodId::SBA_SIMUL)
+  {
+    // simply apply the simultaneous substitution now
+    return n.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+  }
+  else if (ida == MethodId::SBA_FIXPOINT)
+  {
+    SubstitutionMap sm;
+    for (size_t i = 0, nvars = vars.size(); i < nvars; i++)
+    {
+      sm.addSubstitution(vars[i], subs[i]);
+    }
+    Node ns = sm.apply(n);
+    return ns;
+  }
+  Assert(ida == MethodId::SBA_SEQUENTIAL);
+  // we prefer n traversals of the term to n^2/2 traversals of range terms
+  Node ns = n;
+  for (size_t i = 0, nvars = vars.size(); i < nvars; i++)
+  {
+    TNode v = vars[(nvars - 1) - i];
+    TNode s = subs[(nvars - 1) - i];
+    ns = ns.substitute(v, s);
+  }
+  return ns;
 }
 
 bool BuiltinProofRuleChecker::getMethodId(TNode n, MethodId& i)
@@ -321,13 +340,13 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id,
   }
   else if (id == PfRule::MACRO_SR_EQ_INTRO)
   {
-    Assert(1 <= args.size() && args.size() <= 3);
-    MethodId ids, idr;
-    if (!getMethodIds(args, ids, idr, 1))
+    Assert(1 <= args.size() && args.size() <= 4);
+    MethodId ids, ida, idr;
+    if (!getMethodIds(args, ids, ida, idr, 1))
     {
       return Node::null();
     }
-    Node res = applySubstitutionRewrite(args[0], children, ids, idr);
+    Node res = applySubstitutionRewrite(args[0], children, ids, ida, idr);
     if (res.isNull())
     {
       return Node::null();
@@ -338,13 +357,13 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id,
   {
     Trace("builtin-pfcheck") << "Check " << id << " " << children.size() << " "
                              << args[0] << std::endl;
-    Assert(1 <= args.size() && args.size() <= 3);
-    MethodId ids, idr;
-    if (!getMethodIds(args, ids, idr, 1))
+    Assert(1 <= args.size() && args.size() <= 4);
+    MethodId ids, ida, idr;
+    if (!getMethodIds(args, ids, ida, idr, 1))
     {
       return Node::null();
     }
-    Node res = applySubstitutionRewrite(args[0], children, ids, idr);
+    Node res = applySubstitutionRewrite(args[0], children, ids, ida, idr);
     if (res.isNull())
     {
       return Node::null();
@@ -369,15 +388,15 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id,
     Trace("builtin-pfcheck") << "Check " << id << " " << children.size() << " "
                              << args.size() << std::endl;
     Assert(children.size() >= 1);
-    Assert(args.size() <= 2);
+    Assert(args.size() <= 3);
     std::vector<Node> exp;
     exp.insert(exp.end(), children.begin() + 1, children.end());
-    MethodId ids, idr;
-    if (!getMethodIds(args, ids, idr, 0))
+    MethodId ids, ida, idr;
+    if (!getMethodIds(args, ids, ida, idr, 0))
     {
       return Node::null();
     }
-    Node res1 = applySubstitutionRewrite(children[0], exp, ids, idr);
+    Node res1 = applySubstitutionRewrite(children[0], exp, ids, ida, idr);
     Trace("builtin-pfcheck") << "Returned " << res1 << std::endl;
     return res1;
   }
@@ -386,17 +405,17 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id,
     Trace("builtin-pfcheck") << "Check " << id << " " << children.size() << " "
                              << args.size() << std::endl;
     Assert(children.size() >= 1);
-    Assert(1 <= args.size() && args.size() <= 3);
+    Assert(1 <= args.size() && args.size() <= 4);
     Assert(args[0].getType().isBoolean());
-    MethodId ids, idr;
-    if (!getMethodIds(args, ids, idr, 1))
+    MethodId ids, ida, idr;
+    if (!getMethodIds(args, ids, ida, idr, 1))
     {
       return Node::null();
     }
     std::vector<Node> exp;
     exp.insert(exp.end(), children.begin() + 1, children.end());
-    Node res1 = applySubstitutionRewrite(children[0], exp, ids, idr);
-    Node res2 = applySubstitutionRewrite(args[0], exp, ids, idr);
+    Node res1 = applySubstitutionRewrite(children[0], exp, ids, ida, idr);
+    Node res2 = applySubstitutionRewrite(args[0], exp, ids, ida, idr);
     // if not already equal, do rewriting
     if (res1 != res2)
     {
@@ -438,27 +457,28 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id,
 
 bool BuiltinProofRuleChecker::getMethodIds(const std::vector<Node>& args,
                                            MethodId& ids,
+                                           MethodId& ida,
                                            MethodId& idr,
                                            size_t index)
 {
   ids = MethodId::SB_DEFAULT;
+  ida = MethodId::SBA_SEQUENTIAL;
   idr = MethodId::RW_REWRITE;
-  if (args.size() > index)
+  for (size_t offset = 0; offset <= 2; offset++)
   {
-    if (!getMethodId(args[index], ids))
+    if (args.size() > index + offset)
     {
-      Trace("builtin-pfcheck")
-          << "Failed to get id from " << args[index] << std::endl;
-      return false;
+      MethodId& id = offset == 0 ? ids : (offset == 1 ? ida : idr);
+      if (!getMethodId(args[index + offset], id))
+      {
+        Trace("builtin-pfcheck")
+            << "Failed to get id from " << args[index + offset] << std::endl;
+        return false;
+      }
     }
-  }
-  if (args.size() > index + 1)
-  {
-    if (!getMethodId(args[index + 1], idr))
+    else
     {
-      Trace("builtin-pfcheck")
-          << "Failed to get id from " << args[index + 1] << std::endl;
-      return false;
+      break;
     }
   }
   return true;
@@ -466,13 +486,19 @@ bool BuiltinProofRuleChecker::getMethodIds(const std::vector<Node>& args,
 
 void BuiltinProofRuleChecker::addMethodIds(std::vector<Node>& args,
                                            MethodId ids,
+                                           MethodId ida,
                                            MethodId idr)
 {
   bool ndefRewriter = (idr != MethodId::RW_REWRITE);
-  if (ids != MethodId::SB_DEFAULT || ndefRewriter)
+  bool ndefApply = (ida != MethodId::SBA_SEQUENTIAL);
+  if (ids != MethodId::SB_DEFAULT || ndefRewriter || ndefApply)
   {
     args.push_back(mkMethodId(ids));
   }
+  if (ndefApply || ndefRewriter)
+  {
+    args.push_back(mkMethodId(ida));
+  }
   if (ndefRewriter)
   {
     args.push_back(mkMethodId(idr));
index 38eea31c5c0d6fb6833858cb8dd89098b4ad74d5..81da0a9695dc90458a31e6d11fe525eafc9c897a 100644 (file)
@@ -64,6 +64,21 @@ enum class MethodId : uint32_t
   SB_LITERAL,
   // P is interpreted as P -> true using Node::substitute
   SB_FORMULA,
+  //---------------------------- Substitution applications
+  // multiple substitutions are applied sequentially
+  SBA_SEQUENTIAL,
+  // multiple substitutions are applied simultaneously
+  SBA_SIMUL,
+  // multiple substitutions are applied to fix point
+  SBA_FIXPOINT
+  // For example, for x -> u, y -> f(z), z -> g(x), applying this substituion to
+  // y gives:
+  // - f(g(x)) for SBA_SEQUENTIAL
+  // - f(z) for SBA_SIMUL
+  // - f(g(u)) for SBA_FIXPOINT
+  // Notice that SBA_FIXPOINT should provide a terminating rewrite system
+  // as a substitution, or else non-termination will occur during proof
+  // checking.
 };
 /** Converts a rewriter id to a string. */
 const char* toString(MethodId id);
@@ -126,10 +141,12 @@ class BuiltinProofRuleChecker : public ProofRuleChecker
    */
   static Node applySubstitution(Node n,
                                 Node exp,
-                                MethodId ids = MethodId::SB_DEFAULT);
+                                MethodId ids = MethodId::SB_DEFAULT,
+                                MethodId ida = MethodId::SBA_SEQUENTIAL);
   static Node applySubstitution(Node n,
                                 const std::vector<Node>& exp,
-                                MethodId ids = MethodId::SB_DEFAULT);
+                                MethodId ids = MethodId::SB_DEFAULT,
+                                MethodId ida = MethodId::SBA_SEQUENTIAL);
   /** Apply substitution + rewriting
    *
    * Combines the above two steps.
@@ -143,6 +160,7 @@ class BuiltinProofRuleChecker : public ProofRuleChecker
   Node applySubstitutionRewrite(Node n,
                                 const std::vector<Node>& exp,
                                 MethodId ids = MethodId::SB_DEFAULT,
+                                MethodId ida = MethodId::SBA_SEQUENTIAL,
                                 MethodId idr = MethodId::RW_REWRITE);
   /** get a method identifier from a node, return false if we fail */
   static bool getMethodId(TNode n, MethodId& i);
@@ -153,13 +171,17 @@ class BuiltinProofRuleChecker : public ProofRuleChecker
    */
   bool getMethodIds(const std::vector<Node>& args,
                     MethodId& ids,
+                    MethodId& ida,
                     MethodId& idr,
                     size_t index);
   /**
    * Add method identifiers ids and idr as nodes to args. This does not add ids
    * or idr if their values are the default ones.
    */
-  static void addMethodIds(std::vector<Node>& args, MethodId ids, MethodId idr);
+  static void addMethodIds(std::vector<Node>& args,
+                           MethodId ids,
+                           MethodId ida,
+                           MethodId idr);
 
   /** get a TheoryId from a node, return false if we fail */
   static bool getTheoryId(TNode n, TheoryId& tid);
index 2351e7bf38f984820b2f88477a56a102ba3bad3d..bc02dcbb69a76462df4a5e3c5a893096d09d1d11 100644 (file)
@@ -149,8 +149,11 @@ void InferProofCons::convert(InferenceId infer,
         break;
       }
       // may need the "extended equality rewrite"
-      Node mainEqSRew2 = psb.applyPredElim(
-          mainEqSRew, {}, MethodId::SB_DEFAULT, MethodId::RW_REWRITE_EQ_EXT);
+      Node mainEqSRew2 = psb.applyPredElim(mainEqSRew,
+                                           {},
+                                           MethodId::SB_DEFAULT,
+                                           MethodId::SBA_SEQUENTIAL,
+                                           MethodId::RW_REWRITE_EQ_EXT);
       if (mainEqSRew2 == conc)
       {
         useBuffer = true;
@@ -286,6 +289,7 @@ void InferProofCons::convert(InferenceId infer,
                                    conc,
                                    cexp,
                                    MethodId::SB_DEFAULT,
+                                   MethodId::SBA_SEQUENTIAL,
                                    MethodId::RW_REWRITE_EQ_EXT))
         {
           Trace("strings-ipc-core") << "Transformed to " << conc
index 06b4b7ad7d2cb9a540f1c1d6ddbfa10652a8d601..70f35afb831864f4eadaaeb4443cb7037d80a0fd 100644 (file)
@@ -31,11 +31,12 @@ bool TheoryProofStepBuffer::applyEqIntro(Node src,
                                          Node tgt,
                                          const std::vector<Node>& exp,
                                          MethodId ids,
+                                         MethodId ida,
                                          MethodId idr)
 {
   std::vector<Node> args;
   args.push_back(src);
-  builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, idr);
+  builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, ida, idr);
   Node res = tryStep(PfRule::MACRO_SR_EQ_INTRO, exp, args);
   if (res.isNull())
   {
@@ -58,6 +59,7 @@ bool TheoryProofStepBuffer::applyPredTransform(Node src,
                                                Node tgt,
                                                const std::vector<Node>& exp,
                                                MethodId ids,
+                                               MethodId ida,
                                                MethodId idr)
 {
   // symmetric equalities
@@ -71,7 +73,7 @@ bool TheoryProofStepBuffer::applyPredTransform(Node src,
   // try to prove that tgt rewrites to src
   children.insert(children.end(), exp.begin(), exp.end());
   args.push_back(tgt);
-  builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, idr);
+  builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, ida, idr);
   Node res = tryStep(PfRule::MACRO_SR_PRED_TRANSFORM, children, args);
   if (res.isNull())
   {
@@ -86,11 +88,12 @@ bool TheoryProofStepBuffer::applyPredTransform(Node src,
 bool TheoryProofStepBuffer::applyPredIntro(Node tgt,
                                            const std::vector<Node>& exp,
                                            MethodId ids,
+                                           MethodId ida,
                                            MethodId idr)
 {
   std::vector<Node> args;
   args.push_back(tgt);
-  builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, idr);
+  builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, ida, idr);
   Node res = tryStep(PfRule::MACRO_SR_PRED_INTRO, exp, args);
   if (res.isNull())
   {
@@ -103,13 +106,14 @@ bool TheoryProofStepBuffer::applyPredIntro(Node tgt,
 Node TheoryProofStepBuffer::applyPredElim(Node src,
                                           const std::vector<Node>& exp,
                                           MethodId ids,
+                                          MethodId ida,
                                           MethodId idr)
 {
   std::vector<Node> children;
   children.push_back(src);
   children.insert(children.end(), exp.begin(), exp.end());
   std::vector<Node> args;
-  builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, idr);
+  builtin::BuiltinProofRuleChecker::addMethodIds(args, ids, ida, idr);
   Node srcRew = tryStep(PfRule::MACRO_SR_PRED_ELIM, children, args);
   if (CDProof::isSame(src, srcRew))
   {
index 328266b896cb0591aff098f5d8a10defa80dbca9..1832ddb08ab9da127d3ebe1999a591878e599364 100644 (file)
@@ -47,6 +47,7 @@ class TheoryProofStepBuffer : public ProofStepBuffer
                     Node tgt,
                     const std::vector<Node>& exp,
                     MethodId ids = MethodId::SB_DEFAULT,
+                    MethodId ida = MethodId::SBA_SEQUENTIAL,
                     MethodId idr = MethodId::RW_REWRITE);
   /**
    * Apply predicate transform. If this method returns true, it adds (at most
@@ -58,6 +59,7 @@ class TheoryProofStepBuffer : public ProofStepBuffer
                           Node tgt,
                           const std::vector<Node>& exp,
                           MethodId ids = MethodId::SB_DEFAULT,
+                          MethodId ida = MethodId::SBA_SEQUENTIAL,
                           MethodId idr = MethodId::RW_REWRITE);
   /**
    * Apply predicate introduction. If this method returns true, it adds proof
@@ -68,6 +70,7 @@ class TheoryProofStepBuffer : public ProofStepBuffer
   bool applyPredIntro(Node tgt,
                       const std::vector<Node>& exp,
                       MethodId ids = MethodId::SB_DEFAULT,
+                      MethodId ida = MethodId::SBA_SEQUENTIAL,
                       MethodId idr = MethodId::RW_REWRITE);
   /**
    * Apply predicate elimination. This method returns the result of applying
@@ -83,6 +86,7 @@ class TheoryProofStepBuffer : public ProofStepBuffer
   Node applyPredElim(Node src,
                      const std::vector<Node>& exp,
                      MethodId ids = MethodId::SB_DEFAULT,
+                     MethodId ida = MethodId::SBA_SEQUENTIAL,
                      MethodId idr = MethodId::RW_REWRITE);
   //---------------------------- end utilities builtin proof rules
 
index 3e5b04d867f9409c0190b794678d98b869588dbd..d76356f75f46fcd400589543170b0545439cc082 100644 (file)
@@ -204,7 +204,14 @@ std::shared_ptr<ProofNode> TrustSubstitutionMap::getProofFor(Node eq)
     }
   }
   Trace("trust-subs-pf") << "...apply eq intro" << std::endl;
-  if (!d_tspb->applyEqIntro(n, ns, pfChildren, d_ids))
+  // We use fixpoint as the substitution-apply identifier. Notice that it
+  // suffices to use SBA_SEQUENTIAL here, but SBA_FIXPOINT is typically
+  // more efficient. This is because for substitution of size n, sequential
+  // substitution can either be implemented as n traversals of the term to
+  // apply the substitution to, or a single traversal of the term, but n^2/2
+  // traversals of the range of the substitution to prepare a simultaneous
+  // substitution. Both of these options are inefficient.
+  if (!d_tspb->applyEqIntro(n, ns, pfChildren, d_ids, MethodId::SBA_FIXPOINT))
   {
     // if we fail for any reason, we must use a trusted step instead
     d_tspb->addStep(PfRule::TRUST_SUBS_MAP, pfChildren, {eq}, eq);