Expand arith's farkas lemma rule as a macro (#6577)
[cvc5.git] / src / smt / proof_post_processor.cpp
index 40e61964c341a9c64b7af2667f3eb81116fb143e..f98d1d72773a7830f4d80946bb647b5ed87a464c 100644 (file)
@@ -1,20 +1,23 @@
-/*********************                                                        */
-/*! \file proof_post_processor.cpp
- ** \verbatim
- ** Top contributors (to current version):
- **   Andrew Reynolds
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
- ** in the top-level source directory and their institutional affiliations.
- ** All rights reserved.  See the file COPYING in the top-level source
- ** directory for licensing information.\endverbatim
- **
- ** \brief Implementation of module for processing proof nodes
- **/
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds, Haniel Barbosa, Aina Niemetz
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Implementation of module for processing proof nodes.
+ */
 
 #include "smt/proof_post_processor.h"
 
+#include "expr/proof_node_manager.h"
 #include "expr/skolem_manager.h"
+#include "options/proof_options.h"
 #include "options/smt_options.h"
 #include "preprocessing/assertion_pipeline.h"
 #include "smt/smt_engine.h"
 #include "theory/rewriter.h"
 #include "theory/theory.h"
 
-using namespace CVC4::kind;
-using namespace CVC4::theory;
+using namespace cvc5::kind;
+using namespace cvc5::theory;
 
-namespace CVC4 {
+namespace cvc5 {
 namespace smt {
 
 ProofPostprocessCallback::ProofPostprocessCallback(ProofNodeManager* pnm,
                                                    SmtEngine* smte,
-                                                   ProofGenerator* pppg)
-    : d_pnm(pnm), d_smte(smte), d_pppg(pppg), d_wfpm(pnm)
+                                                   ProofGenerator* pppg,
+                                                   bool updateScopedAssumptions)
+    : d_pnm(pnm),
+      d_smte(smte),
+      d_pppg(pppg),
+      d_wfpm(pnm),
+      d_updateScopedAssumptions(updateScopedAssumptions)
 {
   d_true = NodeManager::currentNM()->mkConst(true);
-  // always check whether to update ASSUME
-  d_elimRules.insert(PfRule::ASSUME);
 }
 
 void ProofPostprocessCallback::initializeUpdate()
@@ -51,9 +57,26 @@ void ProofPostprocessCallback::setEliminateRule(PfRule rule)
 }
 
 bool ProofPostprocessCallback::shouldUpdate(std::shared_ptr<ProofNode> pn,
+                                            const std::vector<Node>& fa,
                                             bool& continueUpdate)
 {
-  return d_elimRules.find(pn->getRule()) != d_elimRules.end();
+  PfRule id = pn->getRule();
+  if (d_elimRules.find(id) != d_elimRules.end())
+  {
+    return true;
+  }
+  // other than elimination rules, we always update assumptions as long as
+  // d_updateScopedAssumptions is true or they are *not* in scope, i.e., not in
+  // fa
+  if (id != PfRule::ASSUME
+      || (!d_updateScopedAssumptions
+          && std::find(fa.begin(), fa.end(), pn->getResult()) != fa.end()))
+  {
+    Trace("smt-proof-pp-debug")
+        << "... not updating in-scope assumption " << pn->getResult() << "\n";
+    return false;
+  }
+  return true;
 }
 
 bool ProofPostprocessCallback::update(Node res,
@@ -130,6 +153,257 @@ bool ProofPostprocessCallback::updateInternal(Node res,
   return update(res, id, children, args, cdp, continueUpdate);
 }
 
+Node ProofPostprocessCallback::eliminateCrowdingLits(
+    const std::vector<Node>& clauseLits,
+    const std::vector<Node>& targetClauseLits,
+    const std::vector<Node>& children,
+    const std::vector<Node>& args,
+    CDProof* cdp)
+{
+  Trace("smt-proof-pp-debug2") << push;
+  NodeManager* nm = NodeManager::currentNM();
+  Node trueNode = nm->mkConst(true);
+  // get crowding lits and the position of the last clause that includes
+  // them. The factoring step must be added after the last inclusion and before
+  // its elimination.
+  std::unordered_set<TNode> crowding;
+  std::vector<std::pair<Node, size_t>> lastInclusion;
+  // positions of eliminators of crowding literals, which are the positions of
+  // the clauses that eliminate crowding literals *after* their last inclusion
+  std::vector<size_t> eliminators;
+  for (size_t i = 0, size = clauseLits.size(); i < size; ++i)
+  {
+    if (!crowding.count(clauseLits[i])
+        && std::find(
+               targetClauseLits.begin(), targetClauseLits.end(), clauseLits[i])
+               == targetClauseLits.end())
+    {
+      Node crowdLit = clauseLits[i];
+      crowding.insert(crowdLit);
+      Trace("smt-proof-pp-debug2") << "crowding lit " << crowdLit << "\n";
+      // found crowding lit, now get its last inclusion position, which is the
+      // position of the last resolution link that introduces the crowding
+      // literal. Note that this position has to be *before* the last link, as a
+      // link *after* the last inclusion must eliminate the crowding literal.
+      size_t j;
+      for (j = children.size() - 1; j > 0; --j)
+      {
+        // notice that only non-singleton clauses may be introducing the
+        // crowding literal, so we only care about non-singleton OR nodes. We
+        // check then against the kind and whether the whole OR node occurs as a
+        // pivot of the respective resolution
+        if (children[j - 1].getKind() != kind::OR)
+        {
+          continue;
+        }
+        uint64_t pivotIndex = 2 * (j - 1);
+        if (args[pivotIndex] == children[j - 1]
+            || args[pivotIndex].notNode() == children[j - 1])
+        {
+          continue;
+        }
+        if (std::find(children[j - 1].begin(), children[j - 1].end(), crowdLit)
+            != children[j - 1].end())
+        {
+          break;
+        }
+      }
+      Assert(j > 0);
+      lastInclusion.emplace_back(crowdLit, j - 1);
+
+      Trace("smt-proof-pp-debug2") << "last inc " << j - 1 << "\n";
+      // get elimination position, starting from the following link as the last
+      // inclusion one. The result is the last (in the chain, but first from
+      // this point on) resolution link that eliminates the crowding literal. A
+      // literal l is eliminated by a link if it contains a literal l' with
+      // opposite polarity to l.
+      for (; j < children.size(); ++j)
+      {
+        bool posFirst = args[(2 * j) - 1] == trueNode;
+        Node pivot = args[(2 * j)];
+        Trace("smt-proof-pp-debug2")
+            << "\tcheck w/ args " << posFirst << " / " << pivot << "\n";
+        // To eliminate the crowding literal (crowdLit), the clause must contain
+        // it with opposite polarity. There are three successful cases,
+        // according to the pivot and its sign
+        //
+        // - crowdLit is the same as the pivot and posFirst is true, which means
+        //   that the clause contains its negation and eliminates it
+        //
+        // - crowdLit is the negation of the pivot and posFirst is false, so the
+        //   clause contains the node whose negation is crowdLit. Note that this
+        //   case may either be crowdLit.notNode() == pivot or crowdLit ==
+        //   pivot.notNode().
+        if ((crowdLit == pivot && posFirst)
+            || (crowdLit.notNode() == pivot && !posFirst)
+            || (pivot.notNode() == crowdLit && !posFirst))
+        {
+          Trace("smt-proof-pp-debug2") << "\t\tfound it!\n";
+          eliminators.push_back(j);
+          break;
+        }
+      }
+      AlwaysAssert(j < children.size());
+    }
+  }
+  Assert(!lastInclusion.empty());
+  // order map so that we process crowding literals in the order of the clauses
+  // that last introduce them
+  auto cmp = [](std::pair<Node, size_t>& a, std::pair<Node, size_t>& b) {
+    return a.second < b.second;
+  };
+  std::sort(lastInclusion.begin(), lastInclusion.end(), cmp);
+  // order eliminators
+  std::sort(eliminators.begin(), eliminators.end());
+  if (Trace.isOn("smt-proof-pp-debug"))
+  {
+    Trace("smt-proof-pp-debug") << "crowding lits last inclusion:\n";
+    for (const auto& pair : lastInclusion)
+    {
+      Trace("smt-proof-pp-debug")
+          << "\t- [" << pair.second << "] : " << pair.first << "\n";
+    }
+    Trace("smt-proof-pp-debug") << "eliminators:";
+    for (size_t elim : eliminators)
+    {
+      Trace("smt-proof-pp-debug") << " " << elim;
+    }
+    Trace("smt-proof-pp-debug") << "\n";
+  }
+  // TODO (cvc4-wishues/issues/77): implement also simpler version and compare
+  //
+  // We now start to break the chain, one step at a time. Naively this breaking
+  // down would be one resolution/factoring to each crowding literal, but we can
+  // merge some of the cases. Effectively we do the following:
+  //
+  //
+  // lastClause   children[start] ... children[end]
+  // ---------------------------------------------- CHAIN_RES
+  //         C
+  //    ----------- FACTORING
+  //    lastClause'                children[start'] ... children[end']
+  //    -------------------------------------------------------------- CHAIN_RES
+  //                                    ...
+  //
+  // where
+  //   lastClause_0 = children[0]
+  //   start_0 = 1
+  //   end_0 = eliminators[0] - 1
+  //   start_i+1 = nextGuardedElimPos - 1
+  //
+  // The important point is how end_i+1 is computed. It is based on what we call
+  // the "nextGuardedElimPos", i.e., the next elimination position that requires
+  // removal of duplicates. The intuition is that a factoring step may eliminate
+  // the duplicates of crowding literals l1 and l2. If the last inclusion of l2
+  // is before the elimination of l1, then we can go ahead and also perform the
+  // elimination of l2 without another factoring. However if another literal l3
+  // has its last inclusion after the elimination of l2, then the elimination of
+  // l3 is the next guarded elimination.
+  //
+  // To do the above computation then we determine, after a resolution/factoring
+  // step, the first crowded literal to have its last inclusion after "end". The
+  // first elimination position to be bigger than the position of that crowded
+  // literal is the next guarded elimination position.
+  size_t lastElim = 0;
+  Node lastClause = children[0];
+  std::vector<Node> childrenRes;
+  std::vector<Node> childrenResArgs;
+  Node resPlaceHolder;
+  size_t nextGuardedElimPos = eliminators[0];
+  do
+  {
+    size_t start = lastElim + 1;
+    size_t end = nextGuardedElimPos - 1;
+    Trace("smt-proof-pp-debug2")
+        << "res with:\n\tlastClause: " << lastClause << "\n\tstart: " << start
+        << "\n\tend: " << end << "\n";
+    childrenRes.push_back(lastClause);
+    // note that the interval of insert is exclusive in the end, so we add 1
+    childrenRes.insert(childrenRes.end(),
+                       children.begin() + start,
+                       children.begin() + end + 1);
+    childrenResArgs.insert(childrenResArgs.end(),
+                           args.begin() + (2 * start) - 1,
+                           args.begin() + (2 * end) + 1);
+    Trace("smt-proof-pp-debug2") << "res children: " << childrenRes << "\n";
+    Trace("smt-proof-pp-debug2") << "res args: " << childrenResArgs << "\n";
+    resPlaceHolder = d_pnm->getChecker()->checkDebug(PfRule::CHAIN_RESOLUTION,
+                                                     childrenRes,
+                                                     childrenResArgs,
+                                                     Node::null(),
+                                                     "");
+    Trace("smt-proof-pp-debug2")
+        << "resPlaceHorder: " << resPlaceHolder << "\n";
+    cdp->addStep(
+        resPlaceHolder, PfRule::CHAIN_RESOLUTION, childrenRes, childrenResArgs);
+    // I need to add factoring if end < children.size(). Otherwise, this is
+    // to be handled by the caller
+    if (end < children.size() - 1)
+    {
+      lastClause = d_pnm->getChecker()->checkDebug(
+          PfRule::FACTORING, {resPlaceHolder}, {}, Node::null(), "");
+      if (!lastClause.isNull())
+      {
+        cdp->addStep(lastClause, PfRule::FACTORING, {resPlaceHolder}, {});
+      }
+      else
+      {
+        lastClause = resPlaceHolder;
+      }
+      Trace("smt-proof-pp-debug2") << "lastClause: " << lastClause << "\n";
+    }
+    else
+    {
+      lastClause = resPlaceHolder;
+      break;
+    }
+    // update for next round
+    childrenRes.clear();
+    childrenResArgs.clear();
+    lastElim = end;
+
+    // find the position of the last inclusion of the next crowded literal
+    size_t nextCrowdedInclusionPos = lastInclusion.size();
+    for (size_t i = 0, size = lastInclusion.size(); i < size; ++i)
+    {
+      if (lastInclusion[i].second > lastElim)
+      {
+        nextCrowdedInclusionPos = i;
+        break;
+      }
+    }
+    Trace("smt-proof-pp-debug2")
+        << "nextCrowdedInclusion/Pos: "
+        << lastInclusion[nextCrowdedInclusionPos].second << "/"
+        << nextCrowdedInclusionPos << "\n";
+    // if there is none, then the remaining literals will be used in the next
+    // round
+    if (nextCrowdedInclusionPos == lastInclusion.size())
+    {
+      nextGuardedElimPos = children.size();
+    }
+    else
+    {
+      nextGuardedElimPos = children.size();
+      for (size_t i = 0, size = eliminators.size(); i < size; ++i)
+      {
+        //  nextGuardedElimPos is the largest element of
+        // eliminators bigger the next crowded literal's last inclusion
+        if (eliminators[i] > lastInclusion[nextCrowdedInclusionPos].second)
+        {
+          nextGuardedElimPos = eliminators[i];
+          break;
+        }
+      }
+      Assert(nextGuardedElimPos < children.size());
+    }
+    Trace("smt-proof-pp-debug2")
+        << "nextGuardedElimPos: " << nextGuardedElimPos << "\n";
+  } while (true);
+  Trace("smt-proof-pp-debug2") << pop;
+  return lastClause;
+}
+
 Node ProofPostprocessCallback::expandMacros(PfRule id,
                                             const std::vector<Node>& children,
                                             const std::vector<Node>& args,
@@ -153,19 +427,27 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
     {
       std::vector<Node> sargs;
       sargs.push_back(t);
-      MethodId sid = MethodId::SB_DEFAULT;
+      MethodId ids = MethodId::SB_DEFAULT;
       if (args.size() >= 2)
       {
-        if (builtin::BuiltinProofRuleChecker::getMethodId(args[1], sid))
+        if (builtin::BuiltinProofRuleChecker::getMethodId(args[1], ids))
         {
           sargs.push_back(args[1]);
         }
       }
-      ts =
-          builtin::BuiltinProofRuleChecker::applySubstitution(t, children, sid);
+      MethodId ida = MethodId::SBA_SEQUENTIAL;
+      if (args.size() >= 3)
+      {
+        if (builtin::BuiltinProofRuleChecker::getMethodId(args[2], ida))
+        {
+          sargs.push_back(args[2]);
+        }
+      }
+      ts = builtin::BuiltinProofRuleChecker::applySubstitution(
+          t, children, ids, ida);
       Trace("smt-proof-pp-debug")
           << "...eq intro subs equality is " << t << " == " << ts << ", from "
-          << sid << std::endl;
+          << ids << " " << ida << std::endl;
       if (ts != t)
       {
         Node eq = t.eqNode(ts);
@@ -185,21 +467,21 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
     }
     std::vector<Node> rargs;
     rargs.push_back(ts);
-    MethodId rid = MethodId::RW_REWRITE;
-    if (args.size() >= 3)
+    MethodId idr = MethodId::RW_REWRITE;
+    if (args.size() >= 4)
     {
-      if (builtin::BuiltinProofRuleChecker::getMethodId(args[2], rid))
+      if (builtin::BuiltinProofRuleChecker::getMethodId(args[3], idr))
       {
-        rargs.push_back(args[2]);
+        rargs.push_back(args[3]);
       }
     }
     builtin::BuiltinProofRuleChecker* builtinPfC =
         static_cast<builtin::BuiltinProofRuleChecker*>(
             d_pnm->getChecker()->getCheckerFor(PfRule::MACRO_SR_EQ_INTRO));
-    Node tr = builtinPfC->applyRewrite(ts, rid);
+    Node tr = builtinPfC->applyRewrite(ts, idr);
     Trace("smt-proof-pp-debug")
         << "...eq intro rewrite equality is " << ts << " == " << tr << ", from "
-        << rid << std::endl;
+        << idr << std::endl;
     if (ts != tr)
     {
       Node eq = ts.eqNode(tr);
@@ -375,6 +657,128 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
     cdp->addStep(args[0], PfRule::EQ_RESOLVE, {children[0], eq}, {});
     return args[0];
   }
+  else if (id == PfRule::MACRO_RESOLUTION
+           || id == PfRule::MACRO_RESOLUTION_TRUST)
+  {
+    // first generate the naive chain_resolution
+    std::vector<Node> chainResArgs{args.begin() + 1, args.end()};
+    Node chainConclusion = d_pnm->getChecker()->checkDebug(
+        PfRule::CHAIN_RESOLUTION, children, chainResArgs, Node::null(), "");
+    Trace("smt-proof-pp-debug") << "Original conclusion: " << args[0] << "\n";
+    Trace("smt-proof-pp-debug")
+        << "chainRes conclusion: " << chainConclusion << "\n";
+    // There are n cases:
+    // - if the conclusion is the same, just replace
+    // - if they have the same literals but in different quantity, add a
+    //   FACTORING step
+    // - if the order is not the same, add a REORDERING step
+    // - if there are literals in chainConclusion that are not in the original
+    //   conclusion, we need to transform the MACRO_RESOLUTION into a series of
+    //   CHAIN_RESOLUTION + FACTORING steps, so that we explicitly eliminate all
+    //   these "crowding" literals. We do this via FACTORING so we avoid adding
+    //   an exponential number of premises, which would happen if we just
+    //   repeated in the premises the clauses needed for eliminating crowding
+    //   literals, which could themselves add crowding literals.
+    if (chainConclusion == args[0])
+    {
+      cdp->addStep(
+          chainConclusion, PfRule::CHAIN_RESOLUTION, children, chainResArgs);
+      return chainConclusion;
+    }
+    NodeManager* nm = NodeManager::currentNM();
+    // If we got here, then chainConclusion is NECESSARILY an OR node
+    Assert(chainConclusion.getKind() == kind::OR);
+    // get the literals in the chain conclusion
+    std::vector<Node> chainConclusionLits{chainConclusion.begin(),
+                                          chainConclusion.end()};
+    std::set<Node> chainConclusionLitsSet{chainConclusion.begin(),
+                                          chainConclusion.end()};
+    // is args[0] a singleton clause? If it's not an OR node, then yes.
+    // Otherwise, it's only a singleton if it occurs in chainConclusionLitsSet
+    std::vector<Node> conclusionLits;
+    // whether conclusion is singleton
+    if (chainConclusionLitsSet.count(args[0]))
+    {
+      conclusionLits.push_back(args[0]);
+    }
+    else
+    {
+      Assert(args[0].getKind() == kind::OR);
+      conclusionLits.insert(
+          conclusionLits.end(), args[0].begin(), args[0].end());
+    }
+    std::set<Node> conclusionLitsSet{conclusionLits.begin(),
+                                     conclusionLits.end()};
+    // If the sets are different, there are "crowding" literals, i.e. literals
+    // that were removed by implicit multi-usage of premises in the resolution
+    // chain.
+    if (chainConclusionLitsSet != conclusionLitsSet)
+    {
+      chainConclusion = eliminateCrowdingLits(
+          chainConclusionLits, conclusionLits, children, args, cdp);
+      // update vector of lits. Note that the set is no longer used, so we don't
+      // need to update it
+      //
+      // We need again to check whether chainConclusion is a singleton
+      // clause. As above, it's a singleton if it's in the original
+      // chainConclusionLitsSet.
+      chainConclusionLits.clear();
+      if (chainConclusionLitsSet.count(chainConclusion))
+      {
+        chainConclusionLits.push_back(chainConclusion);
+      }
+      else
+      {
+        Assert(chainConclusion.getKind() == kind::OR);
+        chainConclusionLits.insert(chainConclusionLits.end(),
+                                   chainConclusion.begin(),
+                                   chainConclusion.end());
+      }
+    }
+    else
+    {
+      cdp->addStep(
+          chainConclusion, PfRule::CHAIN_RESOLUTION, children, chainResArgs);
+    }
+    Trace("smt-proof-pp-debug")
+        << "Conclusion after chain_res/elimCrowd: " << chainConclusion << "\n";
+    Trace("smt-proof-pp-debug")
+        << "Conclusion lits: " << chainConclusionLits << "\n";
+    // Placeholder for running conclusion
+    Node n = chainConclusion;
+    // factoring
+    if (chainConclusionLits.size() != conclusionLits.size())
+    {
+      // We build it rather than taking conclusionLits because the order may be
+      // different
+      std::vector<Node> factoredLits;
+      std::unordered_set<TNode> clauseSet;
+      for (size_t i = 0, size = chainConclusionLits.size(); i < size; ++i)
+      {
+        if (clauseSet.count(chainConclusionLits[i]))
+        {
+          continue;
+        }
+        factoredLits.push_back(n[i]);
+        clauseSet.insert(n[i]);
+      }
+      Node factored = factoredLits.empty()
+                          ? nm->mkConst(false)
+                          : factoredLits.size() == 1
+                                ? factoredLits[0]
+                                : nm->mkNode(kind::OR, factoredLits);
+      cdp->addStep(factored, PfRule::FACTORING, {n}, {});
+      n = factored;
+    }
+    // either same node or n as a clause
+    Assert(n == args[0] || n.getKind() == kind::OR);
+    // reordering
+    if (n != args[0])
+    {
+      cdp->addStep(args[0], PfRule::REORDERING, {n}, {args[0]});
+    }
+    return args[0];
+  }
   else if (id == PfRule::SUBS)
   {
     NodeManager* nm = NodeManager::currentNM();
@@ -401,6 +805,11 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
     {
       builtin::BuiltinProofRuleChecker::getMethodId(args[1], ids);
     }
+    MethodId ida = MethodId::SBA_SEQUENTIAL;
+    if (args.size() >= 3)
+    {
+      builtin::BuiltinProofRuleChecker::getMethodId(args[2], ida);
+    }
     std::vector<std::shared_ptr<CDProof>> pfs;
     std::vector<TNode> vsList;
     std::vector<TNode> ssList;
@@ -438,7 +847,7 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
           << "...process " << var << " -> " << subs << " (" << childFrom << ", "
           << ids << ")" << std::endl;
       // apply the current substitution to the range
-      if (!vvec.empty())
+      if (!vvec.empty() && ida == MethodId::SBA_SEQUENTIAL)
       {
         Node ss =
             subs.substitute(vvec.begin(), vvec.end(), svec.begin(), svec.end());
@@ -461,7 +870,8 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
           // add previous rewrite steps
           for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++)
           {
-            tcg.addRewriteStep(vvec[j], svec[j], pgs[j]);
+            // substitutions are pre-rewrites
+            tcg.addRewriteStep(vvec[j], svec[j], pgs[j], true);
           }
           // get the proof for the update to the current substitution
           Node seqss = subs.eqNode(ss);
@@ -492,42 +902,47 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
       svec.push_back(subs);
       pgs.push_back(cdp);
     }
-    Node ts = t.substitute(vvec.begin(), vvec.end(), svec.begin(), svec.end());
-    Node eq = t.eqNode(ts);
-    if (ts != t)
-    {
-      // should be implied by the substitution now
-      TConvProofGenerator tcpg(d_pnm,
-                               nullptr,
-                               TConvPolicy::ONCE,
-                               TConvCachePolicy::NEVER,
-                               "SUBS_TConvProofGenerator",
-                               nullptr,
-                               true);
-      for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++)
-      {
-        tcpg.addRewriteStep(vvec[j], svec[j], pgs[j]);
-      }
-      // add the proof constructed by the term conversion utility
-      std::shared_ptr<ProofNode> pfn = tcpg.getProofFor(eq);
-      // should give a proof, if not, then tcpg does not agree with the
-      // substitution.
-      Assert(pfn != nullptr);
-      if (pfn == nullptr)
-      {
-        cdp->addStep(eq, PfRule::TRUST_SUBS, {}, {eq});
-      }
-      else
-      {
-        cdp->addProof(pfn);
-      }
+    // should be implied by the substitution now
+    TConvPolicy tcpolicy = ida == MethodId::SBA_FIXPOINT ? TConvPolicy::FIXPOINT
+                                                         : TConvPolicy::ONCE;
+    TConvProofGenerator tcpg(d_pnm,
+                             nullptr,
+                             tcpolicy,
+                             TConvCachePolicy::NEVER,
+                             "SUBS_TConvProofGenerator",
+                             nullptr,
+                             true);
+    for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++)
+    {
+      // substitutions are pre-rewrites
+      tcpg.addRewriteStep(vvec[j], svec[j], pgs[j], true);
+    }
+    // add the proof constructed by the term conversion utility
+    std::shared_ptr<ProofNode> pfn = tcpg.getProofForRewriting(t);
+    Node eq = pfn->getResult();
+    Node ts = builtin::BuiltinProofRuleChecker::applySubstitution(
+        t, children, ids, ida);
+    Node eqq = t.eqNode(ts);
+    if (eq != eqq)
+    {
+      pfn = nullptr;
+    }
+    // should give a proof, if not, then tcpg does not agree with the
+    // substitution.
+    Assert(pfn != nullptr);
+    if (pfn == nullptr)
+    {
+      AlwaysAssert(false) << "resort to TRUST_SUBS" << std::endl
+                          << eq << std::endl
+                          << eqq << std::endl
+                          << "from " << children << " applied to " << t;
+      cdp->addStep(eqq, PfRule::TRUST_SUBS, {}, {eqq});
     }
     else
     {
-      // should not be necessary typically
-      cdp->addStep(eq, PfRule::REFL, {}, {t});
+      cdp->addProof(pfn);
     }
-    return eq;
+    return eqq;
   }
   else if (id == PfRule::REWRITE)
   {
@@ -613,6 +1028,58 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
     // otherwise no update
     Trace("final-pf-hole") << "hole: " << id << " : " << eq << std::endl;
   }
+  else if (id == PfRule::MACRO_ARITH_SCALE_SUM_UB)
+  {
+    Debug("macro::arith") << "Expand MACRO_ARITH_SCALE_SUM_UB" << std::endl;
+    if (Debug.isOn("macro::arith"))
+    {
+      for (const auto& child : children)
+      {
+        Debug("macro::arith") << "  child: " << child << std::endl;
+      }
+      Debug("macro::arith") << "   args: " << args << std::endl;
+    }
+    Assert(args.size() == children.size());
+    NodeManager* nm = NodeManager::currentNM();
+    ProofStepBuffer steps{d_pnm->getChecker()};
+
+    // Scale all children, accumulating
+    std::vector<Node> scaledRels;
+    for (size_t i = 0; i < children.size(); ++i)
+    {
+      TNode child = children[i];
+      TNode scalar = args[i];
+      bool isPos = scalar.getConst<Rational>() > 0;
+      Node scalarCmp =
+          nm->mkNode(isPos ? GT : LT, scalar, nm->mkConst(Rational(0)));
+      // (= scalarCmp true)
+      Node scalarCmpOrTrue = steps.tryStep(PfRule::EVALUATE, {}, {scalarCmp});
+      Assert(!scalarCmpOrTrue.isNull());
+      // scalarCmp
+      steps.addStep(PfRule::TRUE_ELIM, {scalarCmpOrTrue}, {}, scalarCmp);
+      // (and scalarCmp relation)
+      Node scalarCmpAndRel =
+          steps.tryStep(PfRule::AND_INTRO, {scalarCmp, child}, {});
+      Assert(!scalarCmpAndRel.isNull());
+      // (=> (and scalarCmp relation) scaled)
+      Node impl =
+          steps.tryStep(isPos ? PfRule::ARITH_MULT_POS : PfRule::ARITH_MULT_NEG,
+                        {},
+                        {scalar, child});
+      Assert(!impl.isNull());
+      // scaled
+      Node scaled =
+          steps.tryStep(PfRule::MODUS_PONENS, {scalarCmpAndRel, impl}, {});
+      Assert(!scaled.isNull());
+      scaledRels.emplace_back(scaled);
+    }
+
+    Node sumBounds = steps.tryStep(PfRule::ARITH_SUM_UB, scaledRels, {});
+    cdp->addSteps(steps);
+    Debug("macro::arith") << "Expansion done. Proved: " << sumBounds
+                          << std::endl;
+    return sumBounds;
+  }
 
   // TRUST, PREPROCESS, THEORY_LEMMA, THEORY_PREPROCESS?
 
@@ -621,7 +1088,7 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
 
 Node ProofPostprocessCallback::addProofForWitnessForm(Node t, CDProof* cdp)
 {
-  Node tw = SkolemManager::getWitnessForm(t);
+  Node tw = SkolemManager::getOriginalForm(t);
   Node eq = t.eqNode(tw);
   if (t == tw)
   {
@@ -703,25 +1170,18 @@ bool ProofPostprocessCallback::addToTransChildren(Node eq,
 
 ProofPostprocessFinalCallback::ProofPostprocessFinalCallback(
     ProofNodeManager* pnm)
-    : d_ruleCount("finalProof::ruleCount"),
-      d_totalRuleCount("finalProof::totalRuleCount", 0),
-      d_minPedanticLevel("finalProof::minPedanticLevel", 10),
-      d_numFinalProofs("finalProofs::numFinalProofs", 0),
+    : d_ruleCount(smtStatisticsRegistry().registerHistogram<PfRule>(
+        "finalProof::ruleCount")),
+      d_totalRuleCount(
+          smtStatisticsRegistry().registerInt("finalProof::totalRuleCount")),
+      d_minPedanticLevel(
+          smtStatisticsRegistry().registerInt("finalProof::minPedanticLevel")),
+      d_numFinalProofs(
+          smtStatisticsRegistry().registerInt("finalProofs::numFinalProofs")),
       d_pnm(pnm),
       d_pedanticFailure(false)
 {
-  smtStatisticsRegistry()->registerStat(&d_ruleCount);
-  smtStatisticsRegistry()->registerStat(&d_totalRuleCount);
-  smtStatisticsRegistry()->registerStat(&d_minPedanticLevel);
-  smtStatisticsRegistry()->registerStat(&d_numFinalProofs);
-}
-
-ProofPostprocessFinalCallback::~ProofPostprocessFinalCallback()
-{
-  smtStatisticsRegistry()->unregisterStat(&d_ruleCount);
-  smtStatisticsRegistry()->unregisterStat(&d_totalRuleCount);
-  smtStatisticsRegistry()->unregisterStat(&d_minPedanticLevel);
-  smtStatisticsRegistry()->unregisterStat(&d_numFinalProofs);
+  d_minPedanticLevel += 10;
 }
 
 void ProofPostprocessFinalCallback::initializeUpdate()
@@ -732,11 +1192,12 @@ void ProofPostprocessFinalCallback::initializeUpdate()
 }
 
 bool ProofPostprocessFinalCallback::shouldUpdate(std::shared_ptr<ProofNode> pn,
+                                                 const std::vector<Node>& fa,
                                                  bool& continueUpdate)
 {
   PfRule r = pn->getRule();
   // if not doing eager pedantic checking, fail if below threshold
-  if (!options::proofNewPedanticEager())
+  if (!options::proofEagerChecking())
   {
     if (!d_pedanticFailure)
     {
@@ -770,9 +1231,10 @@ bool ProofPostprocessFinalCallback::wasPedanticFailure(std::ostream& out) const
 
 ProofPostproccess::ProofPostproccess(ProofNodeManager* pnm,
                                      SmtEngine* smte,
-                                     ProofGenerator* pppg)
+                                     ProofGenerator* pppg,
+                                     bool updateScopedAssumptions)
     : d_pnm(pnm),
-      d_cb(pnm, smte, pppg),
+      d_cb(pnm, smte, pppg, updateScopedAssumptions),
       // the update merges subproofs
       d_updater(d_pnm, d_cb, true),
       d_finalCb(pnm),
@@ -809,4 +1271,4 @@ void ProofPostproccess::setEliminateRule(PfRule rule)
 }
 
 }  // namespace smt
-}  // namespace CVC4
+}  // namespace cvc5