(proof-new) Proofs for non-clausal simplification (#5409)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sat, 14 Nov 2020 14:17:15 +0000 (08:17 -0600)
committerGitHub <noreply@github.com>
Sat, 14 Nov 2020 14:17:15 +0000 (08:17 -0600)
Adds proof support in non-clausal simplification, connecting the proofs from circuit propagator.

src/preprocessing/passes/non_clausal_simp.cpp
src/preprocessing/passes/non_clausal_simp.h

index 2b788098a38c2d759953dab7e9601a43f417688e..cedee6d3cca3a98ba381c790dc6334e31878453e 100644 (file)
@@ -47,7 +47,19 @@ NonClausalSimp::Statistics::~Statistics()
 /* -------------------------------------------------------------------------- */
 
 NonClausalSimp::NonClausalSimp(PreprocessingPassContext* preprocContext)
-    : PreprocessingPass(preprocContext, "non-clausal-simp")
+    : PreprocessingPass(preprocContext, "non-clausal-simp"),
+      d_pnm(preprocContext->getProofNodeManager()),
+      d_llpg(d_pnm ? new smt::PreprocessProofGenerator(
+                         d_pnm,
+                         preprocContext->getUserContext(),
+                         "NonClausalSimp::llpg")
+                   : nullptr),
+      d_llra(d_pnm ? new LazyCDProof(d_pnm,
+                                     nullptr,
+                                     preprocContext->getUserContext(),
+                                     "NonClausalSimp::llra")
+                   : nullptr),
+      d_tsubsList(preprocContext->getUserContext())
 {
 }
 
@@ -101,11 +113,10 @@ PreprocessingPassResult NonClausalSimp::applyInternal(
         << "conflict in non-clausal propagation" << std::endl;
     Assert(!options::unsatCores());
     assertionsToPreprocess->clear();
-    Node n = NodeManager::currentNM()->mkConst<bool>(false);
-    assertionsToPreprocess->push_back(n);
+    assertionsToPreprocess->pushBackTrusted(conf);
     if (options::unsatCores())
     {
-      ProofManager::currentPM()->addDependence(n, Node::null());
+      ProofManager::currentPM()->addDependence(conf.getNode(), Node::null());
     }
     propagator->setNeedsFinish(true);
     return PreprocessingPassResult::CONFLICT;
@@ -115,41 +126,41 @@ PreprocessingPassResult NonClausalSimp::applyInternal(
       << "Iterate through " << propagator->getLearnedLiterals().size()
       << " learned literals." << std::endl;
   // No conflict, go through the literals and solve them
+  context::Context* u = d_preprocContext->getUserContext();
   TrustSubstitutionMap& ttls = d_preprocContext->getTopLevelSubstitutions();
   CVC4_UNUSED SubstitutionMap& top_level_substs = ttls.get();
-  SubstitutionMap constantPropagations(d_preprocContext->getUserContext());
-  TrustSubstitutionMap tnewSubstituions(d_preprocContext->getUserContext(),
-                                        nullptr);
-  SubstitutionMap& newSubstitutions = tnewSubstituions.get();
-  SubstitutionMap::iterator pos;
+  // constant propagations
+  std::shared_ptr<TrustSubstitutionMap> constantPropagations =
+      std::make_shared<TrustSubstitutionMap>(
+          u, d_pnm, "NonClausalSimp::cprop", PfRule::PREPROCESS_LEMMA);
+  SubstitutionMap& cps = constantPropagations->get();
+  // new substitutions
+  std::shared_ptr<TrustSubstitutionMap> newSubstitutions =
+      std::make_shared<TrustSubstitutionMap>(
+          u, d_pnm, "NonClausalSimp::newSubs", PfRule::PREPROCESS_LEMMA);
+  SubstitutionMap& nss = newSubstitutions->get();
+
   size_t j = 0;
   std::vector<TrustNode>& learned_literals = propagator->getLearnedLiterals();
+  // if proofs are enabled, we will need to track the proofs of learned literals
+  if (isProofEnabled())
+  {
+    d_tsubsList.push_back(constantPropagations);
+    d_tsubsList.push_back(newSubstitutions);
+    for (const TrustNode& tll : learned_literals)
+    {
+      d_llpg->notifyNewTrustedAssert(tll);
+    }
+  }
   for (size_t i = 0, size = learned_literals.size(); i < size; ++i)
   {
     // Simplify the literal we learned wrt previous substitutions
     Node learnedLiteral = learned_literals[i].getNode();
     Assert(Rewriter::rewrite(learnedLiteral) == learnedLiteral);
     Assert(top_level_substs.apply(learnedLiteral) == learnedLiteral);
-    Trace("non-clausal-simplify")
-        << "Process learnedLiteral : " << learnedLiteral << std::endl;
-    Node learnedLiteralNew = newSubstitutions.apply(learnedLiteral);
-    if (learnedLiteral != learnedLiteralNew)
-    {
-      learnedLiteral = Rewriter::rewrite(learnedLiteralNew);
-    }
-    Trace("non-clausal-simplify")
-        << "Process learnedLiteral, after newSubs : " << learnedLiteral
-        << std::endl;
-    for (;;)
-    {
-      learnedLiteralNew = constantPropagations.apply(learnedLiteral);
-      if (learnedLiteralNew == learnedLiteral)
-      {
-        break;
-      }
-      d_statistics.d_numConstantProps += 1;
-      learnedLiteral = Rewriter::rewrite(learnedLiteralNew);
-    }
+    // process the learned literal with substitutions and const propagations
+    learnedLiteral = processLearnedLit(
+        learnedLiteral, newSubstitutions.get(), constantPropagations.get());
     Trace("non-clausal-simplify")
         << "Process learnedLiteral, after constProp : " << learnedLiteral
         << std::endl;
@@ -169,7 +180,7 @@ PreprocessingPassResult NonClausalSimp::applyInternal(
         Assert(!options::unsatCores());
         assertionsToPreprocess->clear();
         Node n = NodeManager::currentNM()->mkConst<bool>(false);
-        assertionsToPreprocess->push_back(n);
+        assertionsToPreprocess->push_back(n, false, false, d_llpg.get());
         if (options::unsatCores())
         {
           ProofManager::currentPM()->addDependence(n, Node::null());
@@ -182,11 +193,12 @@ PreprocessingPassResult NonClausalSimp::applyInternal(
     // Solve it with the corresponding theory, possibly adding new
     // substitutions to newSubstitutions
     Trace("non-clausal-simplify") << "solving " << learnedLiteral << std::endl;
+
     TrustNode tlearnedLiteral =
-        TrustNode::mkTrustLemma(learnedLiteral, nullptr);
+        TrustNode::mkTrustLemma(learnedLiteral, d_llpg.get());
     Theory::PPAssertStatus solveStatus =
         d_preprocContext->getTheoryEngine()->solve(tlearnedLiteral,
-                                                   tnewSubstituions);
+                                                   *newSubstitutions.get());
 
     switch (solveStatus)
     {
@@ -195,16 +207,7 @@ PreprocessingPassResult NonClausalSimp::applyInternal(
         // The literal should rewrite to true
         Trace("non-clausal-simplify")
             << "solved " << learnedLiteral << std::endl;
-        Assert(Rewriter::rewrite(newSubstitutions.apply(learnedLiteral))
-                   .isConst());
-        //        vector<pair<Node, Node> > equations;
-        //        constantPropagations.simplifyLHS(top_level_substs, equations,
-        //        true); if (equations.empty()) {
-        //          break;
-        //        }
-        //        Assert(equations[0].first.isConst() &&
-        //        equations[0].second.isConst() && equations[0].first !=
-        //        equations[0].second);
+        Assert(Rewriter::rewrite(nss.apply(learnedLiteral)).isConst());
         // else fall through
         break;
       }
@@ -242,21 +245,19 @@ PreprocessingPassResult NonClausalSimp::applyInternal(
             c = learnedLiteral[1];
           }
           Assert(!t.isConst());
-          Assert(constantPropagations.apply(t) == t);
+          Assert(cps.apply(t) == t);
           Assert(top_level_substs.apply(t) == t);
-          Assert(newSubstitutions.apply(t) == t);
-          constantPropagations.addSubstitution(t, c);
-          // vector<pair<Node,Node> > equations;
-          // constantPropagations.simplifyLHS(t, c, equations, true);
-          // if (!equations.empty()) {
-          //   Assert(equations[0].first.isConst() &&
-          //   equations[0].second.isConst() && equations[0].first !=
-          //   equations[0].second); assertionsToPreprocess->clear();
-          //   Node n = NodeManager::currentNM()->mkConst<bool>(false);
-          //   assertionsToPreprocess->push_back(n);
-          //   false); return;
-          // }
-          // top_level_substs.simplifyRHS(constantPropagations);
+          Assert(nss.apply(t) == t);
+          // also add to learned literal
+          ProofGenerator* cpg = constantPropagations->addSubstitutionSolved(
+              t, c, tlearnedLiteral);
+          // We need to justify (= t c) as a literal, since it is reasserted
+          // to the assertion pipeline below. We do this with the proof
+          // generator returned by the above call.
+          if (isProofEnabled())
+          {
+            d_llpg->notifyNewAssert(t.eqNode(c), cpg);
+          }
         }
         else
         {
@@ -280,34 +281,19 @@ PreprocessingPassResult NonClausalSimp::applyInternal(
   // r' another constant propagation, then l'[l/r] -> r' should be a
   //    constant propagation too
   // 4. each lhs of constantPropagations is different from each rhs
-  for (pos = newSubstitutions.begin(); pos != newSubstitutions.end(); ++pos)
+  for (SubstitutionMap::iterator pos = nss.begin(); pos != nss.end(); ++pos)
   {
     Assert((*pos).first.isVar());
     Assert(top_level_substs.apply((*pos).first) == (*pos).first);
     Assert(top_level_substs.apply((*pos).second) == (*pos).second);
-    Assert(newSubstitutions.apply(newSubstitutions.apply((*pos).second))
-           == newSubstitutions.apply((*pos).second));
+    Node app = nss.apply((*pos).second);
+    Assert(nss.apply(app) == app);
   }
-  for (pos = constantPropagations.begin(); pos != constantPropagations.end();
-       ++pos)
+  for (SubstitutionMap::iterator pos = cps.begin(); pos != cps.end(); ++pos)
   {
     Assert((*pos).second.isConst());
     Assert(Rewriter::rewrite((*pos).first) == (*pos).first);
-    // Node newLeft = top_level_substs.apply((*pos).first);
-    // if (newLeft != (*pos).first) {
-    //   newLeft = Rewriter::rewrite(newLeft);
-    //   Assert(newLeft == (*pos).second ||
-    //          (constantPropagations.hasSubstitution(newLeft) &&
-    //          constantPropagations.apply(newLeft) == (*pos).second));
-    // }
-    // newLeft = constantPropagations.apply((*pos).first);
-    // if (newLeft != (*pos).first) {
-    //   newLeft = Rewriter::rewrite(newLeft);
-    //   Assert(newLeft == (*pos).second ||
-    //          (constantPropagations.hasSubstitution(newLeft) &&
-    //          constantPropagations.apply(newLeft) == (*pos).second));
-    // }
-    Assert(constantPropagations.apply((*pos).second) == (*pos).second);
+    Assert(cps.apply((*pos).second) == (*pos).second);
   }
 #endif /* CVC4_ASSERTIONS */
 
@@ -320,33 +306,31 @@ PreprocessingPassResult NonClausalSimp::applyInternal(
   for (size_t i = 0, size = assertionsToPreprocess->size(); i < size; ++i)
   {
     Node assertion = (*assertionsToPreprocess)[i];
-    Node assertionNew = newSubstitutions.apply(assertion);
+    TrustNode assertionNew = newSubstitutions->apply(assertion);
     Trace("non-clausal-simplify") << "assertion = " << assertion << std::endl;
-    Trace("non-clausal-simplify")
-        << "assertionNew = " << assertionNew << std::endl;
-    if (assertion != assertionNew)
+    if (!assertionNew.isNull())
     {
-      assertion = Rewriter::rewrite(assertionNew);
       Trace("non-clausal-simplify")
-          << "rewrite(assertion) = " << assertion << std::endl;
+          << "assertionNew = " << assertionNew.getNode() << std::endl;
+      assertionsToPreprocess->replaceTrusted(i, assertionNew);
+      assertion = assertionNew.getNode();
+      Assert(Rewriter::rewrite(assertion) == assertion);
     }
-    Assert(Rewriter::rewrite(assertion) == assertion);
     for (;;)
     {
-      assertionNew = constantPropagations.apply(assertion);
-      if (assertionNew == assertion)
+      assertionNew = constantPropagations->apply(assertion);
+      if (assertionNew.isNull())
       {
         break;
       }
+      Assert(assertionNew.getNode() != assertion);
+      assertionsToPreprocess->replaceTrusted(i, assertionNew);
+      assertion = assertionNew.getNode();
       d_statistics.d_numConstantProps += 1;
       Trace("non-clausal-simplify")
-          << "assertionNew = " << assertionNew << std::endl;
-      assertion = Rewriter::rewrite(assertionNew);
-      Trace("non-clausal-simplify")
-          << "assertionNew = " << assertionNew << std::endl;
+          << "assertionNew = " << assertion << std::endl;
     }
     s.insert(assertion);
-    assertionsToPreprocess->replace(i, assertion);
     Trace("non-clausal-simplify")
         << "non-clausal preprocessed: " << assertion << std::endl;
   }
@@ -355,10 +339,11 @@ PreprocessingPassResult NonClausalSimp::applyInternal(
   TheoryModel* m = d_preprocContext->getTheoryEngine()->getModel();
   Assert(m != nullptr);
   NodeManager* nm = NodeManager::currentNM();
-  for (pos = newSubstitutions.begin(); pos != newSubstitutions.end(); ++pos)
+  for (SubstitutionMap::iterator pos = nss.begin(); pos != nss.end(); ++pos)
   {
     Node lhs = (*pos).first;
-    Node rhs = newSubstitutions.apply((*pos).second);
+    TrustNode trhs = newSubstitutions->apply((*pos).second);
+    Node rhs = trhs.isNull() ? (*pos).second : trhs.getNode();
     // If using incremental, we must check whether this variable has occurred
     // before now. If it hasn't we can add this as a substitution.
     if (!assertionsToPreprocess->storeSubstsInAsserts()
@@ -376,34 +361,26 @@ PreprocessingPassResult NonClausalSimp::applyInternal(
       Trace("non-clausal-simplify")
           << "substitute: will notify SAT layer of substitution: " << eq
           << std::endl;
-      assertionsToPreprocess->addSubstitutionNode(eq);
+       trhs = newSubstitutions->apply((*pos).first);
+       Assert(!trhs.isNull());
+       assertionsToPreprocess->addSubstitutionNode(trhs.getProven(),
+       trhs.getGenerator());
     }
   }
 
   Assert(assertionsToPreprocess->getRealAssertionsEnd()
          <= assertionsToPreprocess->size());
+  // Learned literals to conjoin. If proofs are enabled, all these are
+  // justified by d_llpg.
   std::vector<Node> learnedLitsToConjoin;
 
   for (size_t i = 0; i < learned_literals.size(); ++i)
   {
     Node learned = learned_literals[i].getNode();
     Assert(top_level_substs.apply(learned) == learned);
-    Node learnedNew = newSubstitutions.apply(learned);
-    if (learned != learnedNew)
-    {
-      learned = Rewriter::rewrite(learnedNew);
-    }
-    Assert(Rewriter::rewrite(learned) == learned);
-    for (;;)
-    {
-      learnedNew = constantPropagations.apply(learned);
-      if (learnedNew == learned)
-      {
-        break;
-      }
-      d_statistics.d_numConstantProps += 1;
-      learned = Rewriter::rewrite(learnedNew);
-    }
+    // process learned literal
+    learned = processLearnedLit(
+        learned, newSubstitutions.get(), constantPropagations.get());
     if (s.find(learned) != s.end())
     {
       continue;
@@ -415,17 +392,12 @@ PreprocessingPassResult NonClausalSimp::applyInternal(
   }
   learned_literals.clear();
 
-  for (pos = constantPropagations.begin(); pos != constantPropagations.end();
-       ++pos)
+  for (SubstitutionMap::iterator pos = cps.begin(); pos != cps.end(); ++pos)
   {
     Node cProp = (*pos).first.eqNode((*pos).second);
     Assert(top_level_substs.apply(cProp) == cProp);
-    Node cPropNew = newSubstitutions.apply(cProp);
-    if (cProp != cPropNew)
-    {
-      cProp = Rewriter::rewrite(cPropNew);
-      Assert(Rewriter::rewrite(cProp) == cProp);
-    }
+    // process learned literal (substitutions only)
+    cProp = processLearnedLit(cProp, newSubstitutions.get(), nullptr);
     if (s.find(cProp) != s.end())
     {
       continue;
@@ -440,20 +412,89 @@ PreprocessingPassResult NonClausalSimp::applyInternal(
   // Note that we don't have to keep rhs's in full solved form
   // because SubstitutionMap::apply does a fixed-point iteration when
   // substituting
-  top_level_substs.addSubstitutions(newSubstitutions);
+  ttls.addSubstitutions(*newSubstitutions.get());
 
   if (!learnedLitsToConjoin.empty())
   {
     size_t replIndex = assertionsToPreprocess->getRealAssertionsEnd() - 1;
     Node newConj = NodeManager::currentNM()->mkAnd(learnedLitsToConjoin);
-    assertionsToPreprocess->conjoin(replIndex, newConj);
+    Trace("non-clausal-simplify")
+        << "non-clausal simplification, reassert: " << newConj << std::endl;
+    ProofGenerator* pg = nullptr;
+    if (isProofEnabled())
+    {
+      // justify in d_llra
+      for (const Node& lit : learnedLitsToConjoin)
+      {
+        d_llra->addLazyStep(lit, d_llpg.get());
+      }
+      if (learnedLitsToConjoin.size() > 1)
+      {
+        d_llra->addStep(newConj, PfRule::AND_INTRO, learnedLitsToConjoin, {});
+        pg = d_llra.get();
+      }
+      else
+      {
+        // otherwise we ask the learned literal proof generator directly
+        pg = d_llpg.get();
+      }
+    }
+    // ------- from d_llpg    --------- from d_llpg
+    //  conj[0]       ....    d_conj[n]
+    // -------------------------------- AND_INTRO
+    //  newConj
+    // where newConj is conjoined at the given index
+    assertionsToPreprocess->conjoin(replIndex, newConj, pg);
   }
 
   propagator->setNeedsFinish(true);
   return PreprocessingPassResult::NO_CONFLICT;
-}  // namespace passes
+}
 
-/* -------------------------------------------------------------------------- */
+bool NonClausalSimp::isProofEnabled() const { return d_pnm != nullptr; }
+
+Node NonClausalSimp::processLearnedLit(Node lit,
+                                       theory::TrustSubstitutionMap* subs,
+                                       theory::TrustSubstitutionMap* cp)
+{
+  TrustNode tlit;
+  if (subs != nullptr)
+  {
+    tlit = subs->apply(lit);
+    if (!tlit.isNull())
+    {
+      lit = processRewrittenLearnedLit(tlit);
+    }
+    Trace("non-clausal-simplify")
+        << "Process learnedLiteral, after newSubs : " << lit << std::endl;
+  }
+  // apply to fixed point
+  if (cp != nullptr)
+  {
+    for (;;)
+    {
+      tlit = cp->apply(lit);
+      if (tlit.isNull())
+      {
+        break;
+      }
+      Assert(lit != tlit.getNode());
+      lit = processRewrittenLearnedLit(tlit);
+      d_statistics.d_numConstantProps += 1;
+    }
+  }
+  return lit;
+}
+
+Node NonClausalSimp::processRewrittenLearnedLit(theory::TrustNode trn)
+{
+  if (isProofEnabled())
+  {
+    d_llpg->notifyTrustedPreprocessed(trn);
+  }
+  // return the node
+  return trn.getNode();
+}
 
 }  // namespace passes
 }  // namespace preprocessing
index defb3cc82dfcc541da60ee13979741c75e4c206b..d3a9e0b1aca5c32a76a135a3b443e8687babc4d3 100644 (file)
 
 #include <vector>
 
+#include "expr/lazy_proof.h"
 #include "expr/node.h"
 #include "preprocessing/preprocessing_pass.h"
 #include "preprocessing/preprocessing_pass_context.h"
+#include "smt/preprocess_proof_generator.h"
+#include "theory/trust_node.h"
 
 namespace CVC4 {
 namespace preprocessing {
@@ -45,9 +48,44 @@ class NonClausalSimp : public PreprocessingPass
   };
 
   Statistics d_statistics;
-
-  /** Learned literals */
-  std::vector<Node> d_nonClausalLearnedLiterals;
+  /**
+   * Transform learned literal lit. We apply substitutions in subs once and then
+   * apply constant propagations cp to fixed point. Return the rewritten
+   * form of lit.
+   *
+   * If proofs are enabled, then we require that the learned literal preprocess
+   * proof generator (d_llpg) has a proof of lit when this method is called,
+   * and ensure that the return literal also has a proof in d_llpg.
+   */
+  Node processLearnedLit(Node lit,
+                         theory::TrustSubstitutionMap* subs,
+                         theory::TrustSubstitutionMap* cp);
+  /**
+   * Process rewritten learned literal. This is called when we have a
+   * learned literal lit that is rewritten to litr based on the proof generator
+   * contained in trn (if it exists). The trust node trn should be of kind
+   * REWRITE and proving (= lit litr).
+   *
+   * This tracks the proof in the learned literal preprocess proof generator
+   * d_llpg below and returns the rewritten learned literal.
+   */
+  Node processRewrittenLearnedLit(theory::TrustNode trn);
+  /** Is proof enabled? */
+  bool isProofEnabled() const;
+  /** The proof node manager */
+  ProofNodeManager* d_pnm;
+  /** the learned literal preprocess proof generator */
+  std::unique_ptr<smt::PreprocessProofGenerator> d_llpg;
+  /**
+   * An lazy proof for learned literals that are reasserted into the assertions
+   * pipeline by this class.
+   */
+  std::unique_ptr<LazyCDProof> d_llra;
+  /**
+   * A context-dependent list of trust substitution maps, which are required
+   * for storing proofs.
+   */
+  context::CDList<std::shared_ptr<theory::TrustSubstitutionMap> > d_tsubsList;
 };
 
 }  // namespace passes