(proof-new) Make nl-ext factoring lemmas proof producing. (#5698)
authorGereon Kremer <gkremer@stanford.edu>
Mon, 21 Dec 2020 16:20:23 +0000 (17:20 +0100)
committerGitHub <noreply@github.com>
Mon, 21 Dec 2020 16:20:23 +0000 (17:20 +0100)
This PR adds proofs for the lemmas from the nonlinear factoring check.

src/theory/arith/nl/ext/factoring_check.cpp
src/theory/arith/nl/ext/factoring_check.h
src/theory/arith/nl/nonlinear_extension.cpp

index 71c584e91983a3d80eb6c78d5f28cb406854729e..098f4500d619f1fa5c952c68e5fe0c838d736f9f 100644 (file)
@@ -15,6 +15,7 @@
 #include "theory/arith/nl/ext/factoring_check.h"
 
 #include "expr/node.h"
+#include "expr/skolem_manager.h"
 #include "theory/arith/arith_msum.h"
 #include "theory/arith/inference_manager.h"
 #include "theory/arith/nl/nl_model.h"
@@ -24,8 +25,7 @@ namespace theory {
 namespace arith {
 namespace nl {
 
-FactoringCheck::FactoringCheck(InferenceManager& im, NlModel& model)
-    : d_im(im), d_model(model)
+FactoringCheck::FactoringCheck(ExtState* data) : d_data(data)
 {
   d_zero = NodeManager::currentNM()->mkConst(Rational(0));
   d_one = NodeManager::currentNM()->mkConst(Rational(1));
@@ -40,7 +40,7 @@ void FactoringCheck::check(const std::vector<Node>& asserts,
   {
     bool polarity = lit.getKind() != Kind::NOT;
     Node atom = lit.getKind() == Kind::NOT ? lit[0] : lit;
-    Node litv = d_model.computeConcreteModelValue(lit);
+    Node litv = d_data->d_model.computeConcreteModelValue(lit);
     bool considerLit = false;
     // Only consider literals that are in false_asserts.
     considerLit = std::find(false_asserts.begin(), false_asserts.end(), lit)
@@ -120,7 +120,13 @@ void FactoringCheck::check(const std::vector<Node>& asserts,
           sum = Rewriter::rewrite(sum);
           Trace("nl-ext-factor")
               << "* Factored sum for " << x << " : " << sum << std::endl;
-          Node kf = getFactorSkolem(sum);
+
+          CDProof* proof = nullptr;
+          if (d_data->isProofEnabled())
+          {
+            proof = d_data->getProof();
+          }
+          Node kf = getFactorSkolem(sum, proof);
           std::vector<Node> poly;
           poly.push_back(nm->mkNode(Kind::MULT, x, kf));
           std::map<Node, std::vector<Node> >::iterator itfo =
@@ -149,26 +155,41 @@ void FactoringCheck::check(const std::vector<Node>& asserts,
           }
 
           std::vector<Node> lemma_disj;
-          lemma_disj.push_back(lit.negate());
           lemma_disj.push_back(conc_lit);
+          lemma_disj.push_back(lit.negate());
           Node flem = nm->mkNode(Kind::OR, lemma_disj);
           Trace("nl-ext-factor") << "...lemma is " << flem << std::endl;
-          d_im.addPendingArithLemma(flem, InferenceId::NL_FACTOR);
+          if (d_data->isProofEnabled())
+          {
+            Node k_eq = kf.eqNode(sum);
+            Node split = nm->mkNode(Kind::OR, lit, lit.notNode());
+            proof->addStep(split, PfRule::SPLIT, {}, {lit});
+            proof->addStep(
+                flem, PfRule::MACRO_SR_PRED_TRANSFORM, {split, k_eq}, {flem});
+          }
+          d_data->d_im.addPendingArithLemma(
+              flem, InferenceId::NL_FACTOR, proof);
         }
       }
     }
   }
 }
 
-Node FactoringCheck::getFactorSkolem(Node n)
+Node FactoringCheck::getFactorSkolem(Node n, CDProof* proof)
 {
   std::map<Node, Node>::iterator itf = d_factor_skolem.find(n);
   if (itf == d_factor_skolem.end())
   {
     NodeManager* nm = NodeManager::currentNM();
-    Node k = nm->mkSkolem("kf", n.getType());
-    Node k_eq = Rewriter::rewrite(k.eqNode(n));
-    d_im.addPendingArithLemma(k_eq, InferenceId::NL_FACTOR);
+    Node k = nm->getSkolemManager()->mkPurifySkolem(n, "kf");
+    Node k_eq = k.eqNode(n);
+    Trace("nl-ext-factor") << "...adding factor skolem " << k << " == " << n
+                           << std::endl;
+    if (d_data->isProofEnabled())
+    {
+      proof->addStep(k_eq, PfRule::MACRO_SR_PRED_INTRO, {}, {k_eq});
+    }
+    d_data->d_im.addPendingArithLemma(k_eq, InferenceId::NL_FACTOR, proof);
     d_factor_skolem[n] = k;
     return k;
   }
@@ -178,4 +199,4 @@ Node FactoringCheck::getFactorSkolem(Node n)
 }  // namespace nl
 }  // namespace arith
 }  // namespace theory
-}  // namespace CVC4
\ No newline at end of file
+}  // namespace CVC4
index 9f879aa39cdefb7dc30656f300dd2a9796ea7bc8..fa0f8239ab46e42e311041300fc0c94cf0af045d 100644 (file)
@@ -18,8 +18,7 @@
 #include <vector>
 
 #include "expr/node.h"
-#include "theory/arith/inference_manager.h"
-#include "theory/arith/nl/nl_model.h"
+#include "theory/arith/nl/ext/ext_state.h"
 
 namespace CVC4 {
 namespace theory {
@@ -29,12 +28,12 @@ namespace nl {
 class FactoringCheck
 {
  public:
-  FactoringCheck(InferenceManager& im, NlModel& model);
+  FactoringCheck(ExtState* data);
 
   /** check factoring
    *
    * Returns a set of valid theory lemmas, based on a
-   * lemma schema that states a relationship betwen monomials
+   * lemma schema that states a relationship between monomials
    * with common factors that occur in the same constraint.
    *
    * Examples:
@@ -47,17 +46,20 @@ class FactoringCheck
              const std::vector<Node>& false_asserts);
 
  private:
-  /** The inference manager that we push conflicts and lemmas to. */
-  InferenceManager& d_im;
-  /** Reference to the non-linear model object */
-  NlModel& d_model;
+  /** Basic data that is shared with other checks */
+  ExtState* d_data;
+
   /** maps nodes to their factor skolems */
   std::map<Node, Node> d_factor_skolem;
 
   Node d_zero;
   Node d_one;
 
-  Node getFactorSkolem(Node n);
+  /**
+   * Introduces a new purification skolem k for n and adds k=n as lemma.
+   * If proof is not nullptr, it proves this lemma via MACRO_SR_PRED_INTRO.
+   */
+  Node getFactorSkolem(Node n, CDProof* proof);
 };
 
 }  // namespace nl
index b97a53f953489591a39114bad7c01d228d6e15aa..c6787140da925dde8db2242f98713511d8487e54 100644 (file)
@@ -49,7 +49,7 @@ NonlinearExtension::NonlinearExtension(TheoryArith& containing,
       d_model(containing.getSatContext()),
       d_trSlv(d_im, d_model, pnm, containing.getUserContext()),
       d_extState(d_im, d_model, pnm, containing.getUserContext()),
-      d_factoringSlv(d_im, d_model),
+      d_factoringSlv(&d_extState),
       d_monomialBoundsSlv(&d_extState),
       d_monomialSlv(&d_extState),
       d_splitZeroSlv(&d_extState),