Expand arith's farkas lemma rule as a macro (#6577)
[cvc5.git] / src / smt / proof_post_processor.cpp
index bc701ebc871ab94704900d04c814489bc79da6ff..f98d1d72773a7830f4d80946bb647b5ed87a464c 100644 (file)
@@ -1,16 +1,17 @@
-/*********************                                                        */
-/*! \file proof_post_processor.cpp
- ** \verbatim
- ** Top contributors (to current version):
- **   Andrew Reynolds, Haniel Barbosa
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
- ** in the top-level source directory and their institutional affiliations.
- ** All rights reserved.  See the file COPYING in the top-level source
- ** directory for licensing information.\endverbatim
- **
- ** \brief Implementation of module for processing proof nodes
- **/
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds, Haniel Barbosa, Aina Niemetz
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Implementation of module for processing proof nodes.
+ */
 
 #include "smt/proof_post_processor.h"
 
 #include "theory/rewriter.h"
 #include "theory/theory.h"
 
-using namespace CVC4::kind;
-using namespace CVC4::theory;
+using namespace cvc5::kind;
+using namespace cvc5::theory;
 
-namespace CVC4 {
+namespace cvc5 {
 namespace smt {
 
 ProofPostprocessCallback::ProofPostprocessCallback(ProofNodeManager* pnm,
                                                    SmtEngine* smte,
-                                                   ProofGenerator* pppg)
-    : d_pnm(pnm), d_smte(smte), d_pppg(pppg), d_wfpm(pnm)
+                                                   ProofGenerator* pppg,
+                                                   bool updateScopedAssumptions)
+    : d_pnm(pnm),
+      d_smte(smte),
+      d_pppg(pppg),
+      d_wfpm(pnm),
+      d_updateScopedAssumptions(updateScopedAssumptions)
 {
   d_true = NodeManager::currentNM()->mkConst(true);
-  // always check whether to update ASSUME
-  d_elimRules.insert(PfRule::ASSUME);
 }
 
 void ProofPostprocessCallback::initializeUpdate()
@@ -53,9 +57,26 @@ void ProofPostprocessCallback::setEliminateRule(PfRule rule)
 }
 
 bool ProofPostprocessCallback::shouldUpdate(std::shared_ptr<ProofNode> pn,
+                                            const std::vector<Node>& fa,
                                             bool& continueUpdate)
 {
-  return d_elimRules.find(pn->getRule()) != d_elimRules.end();
+  PfRule id = pn->getRule();
+  if (d_elimRules.find(id) != d_elimRules.end())
+  {
+    return true;
+  }
+  // other than elimination rules, we always update assumptions as long as
+  // d_updateScopedAssumptions is true or they are *not* in scope, i.e., not in
+  // fa
+  if (id != PfRule::ASSUME
+      || (!d_updateScopedAssumptions
+          && std::find(fa.begin(), fa.end(), pn->getResult()) != fa.end()))
+  {
+    Trace("smt-proof-pp-debug")
+        << "... not updating in-scope assumption " << pn->getResult() << "\n";
+    return false;
+  }
+  return true;
 }
 
 bool ProofPostprocessCallback::update(Node res,
@@ -145,7 +166,7 @@ Node ProofPostprocessCallback::eliminateCrowdingLits(
   // get crowding lits and the position of the last clause that includes
   // them. The factoring step must be added after the last inclusion and before
   // its elimination.
-  std::unordered_set<TNode, TNodeHashFunction> crowding;
+  std::unordered_set<TNode> crowding;
   std::vector<std::pair<Node, size_t>> lastInclusion;
   // positions of eliminators of crowding literals, which are the positions of
   // the clauses that eliminate crowding literals *after* their last inclusion
@@ -406,19 +427,27 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
     {
       std::vector<Node> sargs;
       sargs.push_back(t);
-      MethodId sid = MethodId::SB_DEFAULT;
+      MethodId ids = MethodId::SB_DEFAULT;
       if (args.size() >= 2)
       {
-        if (builtin::BuiltinProofRuleChecker::getMethodId(args[1], sid))
+        if (builtin::BuiltinProofRuleChecker::getMethodId(args[1], ids))
         {
           sargs.push_back(args[1]);
         }
       }
-      ts =
-          builtin::BuiltinProofRuleChecker::applySubstitution(t, children, sid);
+      MethodId ida = MethodId::SBA_SEQUENTIAL;
+      if (args.size() >= 3)
+      {
+        if (builtin::BuiltinProofRuleChecker::getMethodId(args[2], ida))
+        {
+          sargs.push_back(args[2]);
+        }
+      }
+      ts = builtin::BuiltinProofRuleChecker::applySubstitution(
+          t, children, ids, ida);
       Trace("smt-proof-pp-debug")
           << "...eq intro subs equality is " << t << " == " << ts << ", from "
-          << sid << std::endl;
+          << ids << " " << ida << std::endl;
       if (ts != t)
       {
         Node eq = t.eqNode(ts);
@@ -438,21 +467,21 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
     }
     std::vector<Node> rargs;
     rargs.push_back(ts);
-    MethodId rid = MethodId::RW_REWRITE;
-    if (args.size() >= 3)
+    MethodId idr = MethodId::RW_REWRITE;
+    if (args.size() >= 4)
     {
-      if (builtin::BuiltinProofRuleChecker::getMethodId(args[2], rid))
+      if (builtin::BuiltinProofRuleChecker::getMethodId(args[3], idr))
       {
-        rargs.push_back(args[2]);
+        rargs.push_back(args[3]);
       }
     }
     builtin::BuiltinProofRuleChecker* builtinPfC =
         static_cast<builtin::BuiltinProofRuleChecker*>(
             d_pnm->getChecker()->getCheckerFor(PfRule::MACRO_SR_EQ_INTRO));
-    Node tr = builtinPfC->applyRewrite(ts, rid);
+    Node tr = builtinPfC->applyRewrite(ts, idr);
     Trace("smt-proof-pp-debug")
         << "...eq intro rewrite equality is " << ts << " == " << tr << ", from "
-        << rid << std::endl;
+        << idr << std::endl;
     if (ts != tr)
     {
       Node eq = ts.eqNode(tr);
@@ -628,7 +657,8 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
     cdp->addStep(args[0], PfRule::EQ_RESOLVE, {children[0], eq}, {});
     return args[0];
   }
-  else if (id == PfRule::MACRO_RESOLUTION)
+  else if (id == PfRule::MACRO_RESOLUTION
+           || id == PfRule::MACRO_RESOLUTION_TRUST)
   {
     // first generate the naive chain_resolution
     std::vector<Node> chainResArgs{args.begin() + 1, args.end()};
@@ -722,7 +752,7 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
       // We build it rather than taking conclusionLits because the order may be
       // different
       std::vector<Node> factoredLits;
-      std::unordered_set<TNode, TNodeHashFunction> clauseSet;
+      std::unordered_set<TNode> clauseSet;
       for (size_t i = 0, size = chainConclusionLits.size(); i < size; ++i)
       {
         if (clauseSet.count(chainConclusionLits[i]))
@@ -775,6 +805,11 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
     {
       builtin::BuiltinProofRuleChecker::getMethodId(args[1], ids);
     }
+    MethodId ida = MethodId::SBA_SEQUENTIAL;
+    if (args.size() >= 3)
+    {
+      builtin::BuiltinProofRuleChecker::getMethodId(args[2], ida);
+    }
     std::vector<std::shared_ptr<CDProof>> pfs;
     std::vector<TNode> vsList;
     std::vector<TNode> ssList;
@@ -812,7 +847,7 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
           << "...process " << var << " -> " << subs << " (" << childFrom << ", "
           << ids << ")" << std::endl;
       // apply the current substitution to the range
-      if (!vvec.empty())
+      if (!vvec.empty() && ida == MethodId::SBA_SEQUENTIAL)
       {
         Node ss =
             subs.substitute(vvec.begin(), vvec.end(), svec.begin(), svec.end());
@@ -867,43 +902,47 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
       svec.push_back(subs);
       pgs.push_back(cdp);
     }
-    Node ts = t.substitute(vvec.begin(), vvec.end(), svec.begin(), svec.end());
-    Node eq = t.eqNode(ts);
-    if (ts != t)
+    // should be implied by the substitution now
+    TConvPolicy tcpolicy = ida == MethodId::SBA_FIXPOINT ? TConvPolicy::FIXPOINT
+                                                         : TConvPolicy::ONCE;
+    TConvProofGenerator tcpg(d_pnm,
+                             nullptr,
+                             tcpolicy,
+                             TConvCachePolicy::NEVER,
+                             "SUBS_TConvProofGenerator",
+                             nullptr,
+                             true);
+    for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++)
     {
-      // should be implied by the substitution now
-      TConvProofGenerator tcpg(d_pnm,
-                               nullptr,
-                               TConvPolicy::ONCE,
-                               TConvCachePolicy::NEVER,
-                               "SUBS_TConvProofGenerator",
-                               nullptr,
-                               true);
-      for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++)
-      {
-        // substitutions are pre-rewrites
-        tcpg.addRewriteStep(vvec[j], svec[j], pgs[j], true);
-      }
-      // add the proof constructed by the term conversion utility
-      std::shared_ptr<ProofNode> pfn = tcpg.getProofFor(eq);
-      // should give a proof, if not, then tcpg does not agree with the
-      // substitution.
-      Assert(pfn != nullptr);
-      if (pfn == nullptr)
-      {
-        cdp->addStep(eq, PfRule::TRUST_SUBS, {}, {eq});
-      }
-      else
-      {
-        cdp->addProof(pfn);
-      }
+      // substitutions are pre-rewrites
+      tcpg.addRewriteStep(vvec[j], svec[j], pgs[j], true);
+    }
+    // add the proof constructed by the term conversion utility
+    std::shared_ptr<ProofNode> pfn = tcpg.getProofForRewriting(t);
+    Node eq = pfn->getResult();
+    Node ts = builtin::BuiltinProofRuleChecker::applySubstitution(
+        t, children, ids, ida);
+    Node eqq = t.eqNode(ts);
+    if (eq != eqq)
+    {
+      pfn = nullptr;
+    }
+    // should give a proof, if not, then tcpg does not agree with the
+    // substitution.
+    Assert(pfn != nullptr);
+    if (pfn == nullptr)
+    {
+      AlwaysAssert(false) << "resort to TRUST_SUBS" << std::endl
+                          << eq << std::endl
+                          << eqq << std::endl
+                          << "from " << children << " applied to " << t;
+      cdp->addStep(eqq, PfRule::TRUST_SUBS, {}, {eqq});
     }
     else
     {
-      // should not be necessary typically
-      cdp->addStep(eq, PfRule::REFL, {}, {t});
+      cdp->addProof(pfn);
     }
-    return eq;
+    return eqq;
   }
   else if (id == PfRule::REWRITE)
   {
@@ -989,6 +1028,58 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
     // otherwise no update
     Trace("final-pf-hole") << "hole: " << id << " : " << eq << std::endl;
   }
+  else if (id == PfRule::MACRO_ARITH_SCALE_SUM_UB)
+  {
+    Debug("macro::arith") << "Expand MACRO_ARITH_SCALE_SUM_UB" << std::endl;
+    if (Debug.isOn("macro::arith"))
+    {
+      for (const auto& child : children)
+      {
+        Debug("macro::arith") << "  child: " << child << std::endl;
+      }
+      Debug("macro::arith") << "   args: " << args << std::endl;
+    }
+    Assert(args.size() == children.size());
+    NodeManager* nm = NodeManager::currentNM();
+    ProofStepBuffer steps{d_pnm->getChecker()};
+
+    // Scale all children, accumulating
+    std::vector<Node> scaledRels;
+    for (size_t i = 0; i < children.size(); ++i)
+    {
+      TNode child = children[i];
+      TNode scalar = args[i];
+      bool isPos = scalar.getConst<Rational>() > 0;
+      Node scalarCmp =
+          nm->mkNode(isPos ? GT : LT, scalar, nm->mkConst(Rational(0)));
+      // (= scalarCmp true)
+      Node scalarCmpOrTrue = steps.tryStep(PfRule::EVALUATE, {}, {scalarCmp});
+      Assert(!scalarCmpOrTrue.isNull());
+      // scalarCmp
+      steps.addStep(PfRule::TRUE_ELIM, {scalarCmpOrTrue}, {}, scalarCmp);
+      // (and scalarCmp relation)
+      Node scalarCmpAndRel =
+          steps.tryStep(PfRule::AND_INTRO, {scalarCmp, child}, {});
+      Assert(!scalarCmpAndRel.isNull());
+      // (=> (and scalarCmp relation) scaled)
+      Node impl =
+          steps.tryStep(isPos ? PfRule::ARITH_MULT_POS : PfRule::ARITH_MULT_NEG,
+                        {},
+                        {scalar, child});
+      Assert(!impl.isNull());
+      // scaled
+      Node scaled =
+          steps.tryStep(PfRule::MODUS_PONENS, {scalarCmpAndRel, impl}, {});
+      Assert(!scaled.isNull());
+      scaledRels.emplace_back(scaled);
+    }
+
+    Node sumBounds = steps.tryStep(PfRule::ARITH_SUM_UB, scaledRels, {});
+    cdp->addSteps(steps);
+    Debug("macro::arith") << "Expansion done. Proved: " << sumBounds
+                          << std::endl;
+    return sumBounds;
+  }
 
   // TRUST, PREPROCESS, THEORY_LEMMA, THEORY_PREPROCESS?
 
@@ -997,7 +1088,7 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
 
 Node ProofPostprocessCallback::addProofForWitnessForm(Node t, CDProof* cdp)
 {
-  Node tw = SkolemManager::getWitnessForm(t);
+  Node tw = SkolemManager::getOriginalForm(t);
   Node eq = t.eqNode(tw);
   if (t == tw)
   {
@@ -1079,25 +1170,18 @@ bool ProofPostprocessCallback::addToTransChildren(Node eq,
 
 ProofPostprocessFinalCallback::ProofPostprocessFinalCallback(
     ProofNodeManager* pnm)
-    : d_ruleCount("finalProof::ruleCount"),
-      d_totalRuleCount("finalProof::totalRuleCount", 0),
-      d_minPedanticLevel("finalProof::minPedanticLevel", 10),
-      d_numFinalProofs("finalProofs::numFinalProofs", 0),
+    : d_ruleCount(smtStatisticsRegistry().registerHistogram<PfRule>(
+        "finalProof::ruleCount")),
+      d_totalRuleCount(
+          smtStatisticsRegistry().registerInt("finalProof::totalRuleCount")),
+      d_minPedanticLevel(
+          smtStatisticsRegistry().registerInt("finalProof::minPedanticLevel")),
+      d_numFinalProofs(
+          smtStatisticsRegistry().registerInt("finalProofs::numFinalProofs")),
       d_pnm(pnm),
       d_pedanticFailure(false)
 {
-  smtStatisticsRegistry()->registerStat(&d_ruleCount);
-  smtStatisticsRegistry()->registerStat(&d_totalRuleCount);
-  smtStatisticsRegistry()->registerStat(&d_minPedanticLevel);
-  smtStatisticsRegistry()->registerStat(&d_numFinalProofs);
-}
-
-ProofPostprocessFinalCallback::~ProofPostprocessFinalCallback()
-{
-  smtStatisticsRegistry()->unregisterStat(&d_ruleCount);
-  smtStatisticsRegistry()->unregisterStat(&d_totalRuleCount);
-  smtStatisticsRegistry()->unregisterStat(&d_minPedanticLevel);
-  smtStatisticsRegistry()->unregisterStat(&d_numFinalProofs);
+  d_minPedanticLevel += 10;
 }
 
 void ProofPostprocessFinalCallback::initializeUpdate()
@@ -1108,6 +1192,7 @@ void ProofPostprocessFinalCallback::initializeUpdate()
 }
 
 bool ProofPostprocessFinalCallback::shouldUpdate(std::shared_ptr<ProofNode> pn,
+                                                 const std::vector<Node>& fa,
                                                  bool& continueUpdate)
 {
   PfRule r = pn->getRule();
@@ -1146,9 +1231,10 @@ bool ProofPostprocessFinalCallback::wasPedanticFailure(std::ostream& out) const
 
 ProofPostproccess::ProofPostproccess(ProofNodeManager* pnm,
                                      SmtEngine* smte,
-                                     ProofGenerator* pppg)
+                                     ProofGenerator* pppg,
+                                     bool updateScopedAssumptions)
     : d_pnm(pnm),
-      d_cb(pnm, smte, pppg),
+      d_cb(pnm, smte, pppg, updateScopedAssumptions),
       // the update merges subproofs
       d_updater(d_pnm, d_cb, true),
       d_finalCb(pnm),
@@ -1185,4 +1271,4 @@ void ProofPostproccess::setEliminateRule(PfRule rule)
 }
 
 }  // namespace smt
-}  // namespace CVC4
+}  // namespace cvc5