(proof-new) Distinguish pre vs post rewrites in term conversion proof generator ...
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 29 Jan 2021 19:27:44 +0000 (13:27 -0600)
committerGitHub <noreply@github.com>
Fri, 29 Jan 2021 19:27:44 +0000 (13:27 -0600)
This is work towards resolving two kinds of failures on proof-new:

(1) Functional issues with proofs from the rewriter, in particular when a term pre-rewrites and post-rewrites to different things,
(2) Conversion issues in theory-preprocessing, where all steps are assumed to be post-rewrites but some are in fact pre-rewrites. This leads to the term conversion proof generator proving something different than what is expected.

A followup PR will simplify and fix proofs for theory-preprocessing.

src/expr/term_conversion_proof_generator.cpp
src/expr/term_conversion_proof_generator.h
src/smt/expand_definitions.cpp
src/smt/proof_post_processor.cpp
src/smt/term_formula_removal.cpp
src/smt/witness_form.cpp
src/theory/rewriter.cpp
src/theory/theory_preprocessor.cpp

index d351d88b693be165da84ac9d66b963a52e36f425..d37ccf9147fbe808524f6fc06e2c58193f656f00 100644 (file)
@@ -51,7 +51,8 @@ TConvProofGenerator::TConvProofGenerator(ProofNodeManager* pnm,
                                          TermContext* tccb,
                                          bool rewriteOps)
     : d_proof(pnm, nullptr, c, name + "::LazyCDProof"),
-      d_rewriteMap(c ? c : &d_context),
+      d_preRewriteMap(c ? c : &d_context),
+      d_postRewriteMap(c ? c : &d_context),
       d_policy(pol),
       d_cpolicy(cpol),
       d_name(name),
@@ -65,23 +66,22 @@ TConvProofGenerator::~TConvProofGenerator() {}
 void TConvProofGenerator::addRewriteStep(Node t,
                                          Node s,
                                          ProofGenerator* pg,
+                                         bool isPre,
                                          PfRule trustId,
                                          bool isClosed,
                                          uint32_t tctx)
 {
-  Node eq = registerRewriteStep(t, s, tctx);
+  Node eq = registerRewriteStep(t, s, tctx, isPre);
   if (!eq.isNull())
   {
     d_proof.addLazyStep(eq, pg, trustId, isClosed);
   }
 }
 
-void TConvProofGenerator::addRewriteStep(Node t,
-                                         Node s,
-                                         ProofStep ps,
-                                         uint32_t tctx)
+void TConvProofGenerator::addRewriteStep(
+    Node t, Node s, ProofStep ps, bool isPre, uint32_t tctx)
 {
-  Node eq = registerRewriteStep(t, s, tctx);
+  Node eq = registerRewriteStep(t, s, tctx, isPre);
   if (!eq.isNull())
   {
     d_proof.addStep(eq, ps);
@@ -93,31 +93,39 @@ void TConvProofGenerator::addRewriteStep(Node t,
                                          PfRule id,
                                          const std::vector<Node>& children,
                                          const std::vector<Node>& args,
+                                         bool isPre,
                                          uint32_t tctx)
 {
-  Node eq = registerRewriteStep(t, s, tctx);
+  Node eq = registerRewriteStep(t, s, tctx, isPre);
   if (!eq.isNull())
   {
     d_proof.addStep(eq, id, children, args);
   }
 }
 
-bool TConvProofGenerator::hasRewriteStep(Node t, uint32_t tctx) const
+bool TConvProofGenerator::hasRewriteStep(Node t,
+                                         uint32_t tctx,
+                                         bool isPre) const
 {
-  return !getRewriteStep(t, tctx).isNull();
+  return !getRewriteStep(t, tctx, isPre).isNull();
 }
 
-Node TConvProofGenerator::getRewriteStep(Node t, uint32_t tctx) const
+Node TConvProofGenerator::getRewriteStep(Node t,
+                                         uint32_t tctx,
+                                         bool isPre) const
 {
   Node thash = t;
   if (d_tcontext != nullptr)
   {
     thash = TCtxNode::computeNodeHash(t, tctx);
   }
-  return getRewriteStepInternal(thash);
+  return getRewriteStepInternal(thash, isPre);
 }
 
-Node TConvProofGenerator::registerRewriteStep(Node t, Node s, uint32_t tctx)
+Node TConvProofGenerator::registerRewriteStep(Node t,
+                                              Node s,
+                                              uint32_t tctx,
+                                              bool isPre)
 {
   if (t == s)
   {
@@ -134,14 +142,15 @@ Node TConvProofGenerator::registerRewriteStep(Node t, Node s, uint32_t tctx)
     Assert(tctx == 0);
   }
   // should not rewrite term to two different things
-  if (!getRewriteStepInternal(thash).isNull())
+  if (!getRewriteStepInternal(thash, isPre).isNull())
   {
-    Assert(getRewriteStepInternal(thash) == s)
+    Assert(getRewriteStepInternal(thash, isPre) == s)
         << identify() << " rewriting " << t << " to both " << s << " and "
-        << getRewriteStepInternal(thash);
+        << getRewriteStepInternal(thash, isPre);
     return Node::null();
   }
-  d_rewriteMap[thash] = s;
+  NodeNodeMap& rm = isPre ? d_preRewriteMap : d_postRewriteMap;
+  rm[thash] = s;
   if (d_cpolicy == TConvCachePolicy::DYNAMIC)
   {
     // clear the cache
@@ -188,19 +197,23 @@ std::shared_ptr<ProofNode> TConvProofGenerator::getProofFor(Node f)
         serr << " (see -t tconv-pf-gen-debug for details)";
       }
       serr << std::endl;
-      serr << "                  source: " << f[0] << std::endl;
-      serr << "expected after rewriting: " << f[1] << std::endl;
-      serr << "  actual after rewriting: " << conc[1] << std::endl;
+      serr << "                   source: " << f[0] << std::endl;
+      serr << "     requested conclusion: " << f[1] << std::endl;
+      serr << "conclusion from generator: " << conc[1] << std::endl;
 
       if (debugTraceEnabled)
       {
         Trace("tconv-pf-gen-debug") << "Printing rewrite steps..." << std::endl;
-        serr << "Rewrite steps: " << std::endl;
-        for (NodeNodeMap::const_iterator it = d_rewriteMap.begin();
-             it != d_rewriteMap.end();
-             ++it)
+        for (size_t r = 0; r < 2; r++)
         {
-          serr << (*it).first << " -> " << (*it).second << std::endl;
+          const NodeNodeMap& rm = r == 0 ? d_preRewriteMap : d_postRewriteMap;
+          serr << "Rewrite steps (" << (r == 0 ? "pre" : "post")
+               << "):" << std::endl;
+          for (NodeNodeMap::const_iterator it = rm.begin(); it != rm.end();
+               ++it)
+          {
+            serr << (*it).first << " -> " << (*it).second << std::endl;
+          }
         }
       }
       Unhandled() << serr.str();
@@ -288,8 +301,8 @@ Node TConvProofGenerator::getProofForRewriting(Node t,
     {
       Trace("tconv-pf-gen-rewrite") << "- previsit" << std::endl;
       visited[curHash] = Node::null();
-      // did we rewrite the current node (possibly at pre-rewrite)?
-      Node rcur = getRewriteStepInternal(curHash);
+      // did we rewrite the current node (at pre-rewrite)?
+      Node rcur = getRewriteStepInternal(curHash, true);
       if (!rcur.isNull())
       {
         Trace("tconv-pf-gen-rewrite")
@@ -484,13 +497,8 @@ Node TConvProofGenerator::getProofForRewriting(Node t,
           retHash = TCtxNode::computeNodeHash(cur, curCVal);
         }
         // did we rewrite ret (at post-rewrite)?
-        Node rret;
-        // only if not ONCE policy, which only does pre-rewrite
-        if (d_policy != TConvPolicy::ONCE)
-        {
-          rret = getRewriteStepInternal(retHash);
-        }
-        if (!rret.isNull())
+        Node rret = getRewriteStepInternal(retHash, false);
+        if (!rret.isNull() && d_policy == TConvPolicy::FIXPOINT)
         {
           Trace("tconv-pf-gen-rewrite")
               << "*** " << retHash << " postrewrites to " << rret << std::endl;
@@ -519,6 +527,8 @@ Node TConvProofGenerator::getProofForRewriting(Node t,
         }
         else
         {
+          // take its rewrite if it rewrote and we have ONCE rewriting policy
+          ret = rret.isNull() ? ret : rret;
           Trace("tconv-pf-gen-rewrite")
               << "-> (postrewrite) " << curHash << " = " << ret << std::endl;
           // it is final
@@ -553,10 +563,11 @@ void TConvProofGenerator::doCache(Node curHash,
   }
 }
 
-Node TConvProofGenerator::getRewriteStepInternal(Node t) const
+Node TConvProofGenerator::getRewriteStepInternal(Node t, bool isPre) const
 {
-  NodeNodeMap::const_iterator it = d_rewriteMap.find(t);
-  if (it == d_rewriteMap.end())
+  const NodeNodeMap& rm = isPre ? d_preRewriteMap : d_postRewriteMap;
+  NodeNodeMap::const_iterator it = rm.find(t);
+  if (it == rm.end())
   {
     return Node::null();
   }
index 0a1f4e70a3b96f4d9f1419e0fd2193efe4d8232e..dcf8b3fa46770d398423b0412a99f3dcdcf8d71a 100644 (file)
@@ -144,6 +144,8 @@ class TConvProofGenerator : public ProofGenerator
   /**
    * Add rewrite step t --> s based on proof generator.
    *
+   * @param isPre Whether the rewrite is applied at prerewrite (pre-order
+   * traversal).
    * @param trustId If a null proof generator is provided, we add a step to
    * the proof that has trustId as the rule and expected as the sole argument.
    * @param isClosed whether to expect that pg can provide a closed proof for
@@ -156,25 +158,28 @@ class TConvProofGenerator : public ProofGenerator
   void addRewriteStep(Node t,
                       Node s,
                       ProofGenerator* pg,
+                      bool isPre = false,
                       PfRule trustId = PfRule::ASSUME,
                       bool isClosed = false,
                       uint32_t tctx = 0);
   /** Same as above, for a single step */
-  void addRewriteStep(Node t, Node s, ProofStep ps, uint32_t tctx = 0);
+  void addRewriteStep(
+      Node t, Node s, ProofStep ps, bool isPre = false, uint32_t tctx = 0);
   /** Same as above, with explicit arguments */
   void addRewriteStep(Node t,
                       Node s,
                       PfRule id,
                       const std::vector<Node>& children,
                       const std::vector<Node>& args,
+                      bool isPre = false,
                       uint32_t tctx = 0);
   /** Has rewrite step for term t */
-  bool hasRewriteStep(Node t, uint32_t tctx = 0) const;
+  bool hasRewriteStep(Node t, uint32_t tctx = 0, bool isPre = false) const;
   /** 
    * Get rewrite step for term t, returns the s provided in a call to
    * addRewriteStep if one exists, or null otherwise.
    */
-  Node getRewriteStep(Node t, uint32_t tctx = 0) const;
+  Node getRewriteStep(Node t, uint32_t tctx = 0, bool isPre = false) const;
   /**
    * Get the proof for formula f. It should be the case that f is of the form
    * t = t', where t' is the result of rewriting t based on the rewrite steps
@@ -194,7 +199,8 @@ class TConvProofGenerator : public ProofGenerator
   /** The (lazy) context dependent proof object. */
   LazyCDProof d_proof;
   /** map to rewritten forms */
-  NodeNodeMap d_rewriteMap;
+  NodeNodeMap d_preRewriteMap;
+  NodeNodeMap d_postRewriteMap;
   /**
    * Policy for how rewrites are applied to terms. As a simple example, say we
    * have registered the rewrite steps:
@@ -220,7 +226,7 @@ class TConvProofGenerator : public ProofGenerator
    */
   bool d_rewriteOps;
   /** Get rewrite step for (hash value of) term. */
-  Node getRewriteStepInternal(Node thash) const;
+  Node getRewriteStepInternal(Node thash, bool isPre) const;
   /**
    * Adds a proof of t = t' to the proof pf where t' is the result of rewriting
    * t based on the rewrite steps registered to this class. This method then
@@ -231,7 +237,7 @@ class TConvProofGenerator : public ProofGenerator
    * Register rewrite step, returns the equality t=s if t is distinct from s
    * and a rewrite step has not already been registered for t.
    */
-  Node registerRewriteStep(Node t, Node s, uint32_t tctx);
+  Node registerRewriteStep(Node t, Node s, uint32_t tctx, bool isPre);
   /** cache that r is the rewritten form of cur, pf can provide a proof */
   void doCache(Node curHash, Node cur, Node r, LazyCDProof& pf);
   /** get debug information on this generator */
index 4f49e1bbee3defffa08b925c3f29866a22def8cc..c2c4b6fd228ba7e3a2da77c0893ab9b697a06ebe 100644 (file)
@@ -99,7 +99,7 @@ TrustNode ExpandDefs::expandDefinitions(
             // ------- ASSUME
             // n = f
             Node conc = n.eqNode(f);
-            tpg->addRewriteStep(n, f, PfRule::ASSUME, {}, {conc});
+            tpg->addRewriteStep(n, f, PfRule::ASSUME, {}, {conc}, true);
           }
           // must recursively expand its definition
           TrustNode tfe = expandDefinitions(f, cache, expandOnly, tpg);
@@ -236,7 +236,8 @@ TrustNode ExpandDefs::expandDefinitions(
                                 instance,
                                 PfRule::MACRO_SR_PRED_INTRO,
                                 pfExpChildren,
-                                {conc});
+                                {conc},
+                                true);
           }
         }
         // now, call expand definitions again on the result
@@ -260,7 +261,7 @@ TrustNode ExpandDefs::expandDefinitions(
           if (tpg != nullptr)
           {
             tpg->addRewriteStep(
-                n, node, trn.getGenerator(), PfRule::THEORY_EXPAND_DEF);
+                n, node, trn.getGenerator(), true, PfRule::THEORY_EXPAND_DEF);
           }
         }
         else
index a620a4d22720d83968b1176691877dd2421da749..a57a6d920810521320e6bd97af127ad96905f8eb 100644 (file)
@@ -806,7 +806,8 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
           // add previous rewrite steps
           for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++)
           {
-            tcg.addRewriteStep(vvec[j], svec[j], pgs[j]);
+            // substitutions are pre-rewrites
+            tcg.addRewriteStep(vvec[j], svec[j], pgs[j], true);
           }
           // get the proof for the update to the current substitution
           Node seqss = subs.eqNode(ss);
@@ -851,7 +852,8 @@ Node ProofPostprocessCallback::expandMacros(PfRule id,
                                true);
       for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++)
       {
-        tcpg.addRewriteStep(vvec[j], svec[j], pgs[j]);
+        // substitutions are pre-rewrites
+        tcpg.addRewriteStep(vvec[j], svec[j], pgs[j], true);
       }
       // add the proof constructed by the term conversion utility
       std::shared_ptr<ProofNode> pfn = tcpg.getProofFor(eq);
index 9a856fc149c908443c65f2985616f2359b05eb3b..0df521c0b90cd3bd0e31a0e69f66d59304c91396 100644 (file)
@@ -462,6 +462,7 @@ Node RemoveTermFormulas::runCurrent(std::pair<Node, uint32_t>& curr,
                             PfRule::MACRO_SR_PRED_INTRO,
                             {},
                             {node.eqNode(skolem)},
+                            true,
                             cval);
     }
 
index 9c2c035a8078964632ade172594fe812e46ec035..cba3bd340f4ed5296a50b100ab6aa5d0c16cc661 100644 (file)
@@ -118,7 +118,8 @@ Node WitnessFormGenerator::convertToWitnessForm(Node t)
               true,
               "WitnessFormGenerator::convertToWitnessForm:witness_axiom");
           d_wintroPf.addStep(eq, PfRule::WITNESS_INTRO, {exists}, {});
-          d_tcpg.addRewriteStep(cur, curw, &d_wintroPf, PfRule::ASSUME, true);
+          d_tcpg.addRewriteStep(
+              cur, curw, &d_wintroPf, true, PfRule::ASSUME, true);
         }
         else
         {
index 725fa5cd897cac95b2b0d4aa1bcfff6d325f0226..be26510c61a0cc615ac8cb993a8df86c717fc15b 100644 (file)
@@ -382,11 +382,14 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId,
                 << "with proofs: " << rewriteStackTop.d_node << std::endl;
             Trace("rewriter-proof") << " w/o proofs: " << cached << std::endl;
             Node eq = rewriteStackTop.d_node.eqNode(cached);
+            // we make this a post-rewrite, since we are processing a node that
+            // has finished post-rewriting above
             tcpg->addRewriteStep(rewriteStackTop.d_node,
                                  cached,
                                  PfRule::TRUST_REWRITE,
                                  {},
-                                 {eq});
+                                 {eq},
+                                 false);
             // don't overwrite the cache, should be the same
             rewriteStackTop.d_node = cached;
           }
@@ -485,12 +488,13 @@ RewriteResponse Rewriter::processTrustRewriteResponse(
                            proven[1],
                            PfRule::THEORY_REWRITE,
                            {},
-                           {proven, tidn, rid});
+                           {proven, tidn, rid},
+                           isPre);
     }
     else
     {
       // store proven rewrite step
-      tcpg->addRewriteStep(proven[0], proven[1], pg);
+      tcpg->addRewriteStep(proven[0], proven[1], pg, isPre);
     }
   }
   return RewriteResponse(tresponse.d_status, trn.getNode());
index 7c01eda0fc80e242abd4e3077514c54aa6f85430..3b5e9cee40cf5e9c39b0972c81ba80db06070e1d 100644 (file)
@@ -421,6 +421,7 @@ Node TheoryPreprocessor::ppTheoryRewrite(TNode term)
 
 Node TheoryPreprocessor::rewriteWithProof(Node term)
 {
+  // FIXME (project #37): should properly distinguish pre vs post rewrite
   Node termr = Rewriter::rewrite(term);
   // store rewrite step if tracking proofs and it rewrites
   if (isProofEnabled())
@@ -431,7 +432,7 @@ Node TheoryPreprocessor::rewriteWithProof(Node term)
       Trace("tpp-debug") << "TheoryPreprocessor: addRewriteStep (rewriting) "
                          << term << " -> " << termr << std::endl;
       // always use term context hash 0 (default)
-      d_tpg->addRewriteStep(term, termr, PfRule::REWRITE, {}, {term});
+      d_tpg->addRewriteStep(term, termr, PfRule::REWRITE, {}, {term}, false);
     }
   }
   return termr;
@@ -470,6 +471,7 @@ Node TheoryPreprocessor::preprocessWithProof(Node term)
   }
   Node termr = trn.getNode();
   Assert(term != termr);
+  // FIXME (project #37): should properly distinguish pre vs post rewrite
   if (isProofEnabled())
   {
     if (trn.getGenerator() != nullptr)
@@ -480,15 +482,19 @@ Node TheoryPreprocessor::preprocessWithProof(Node term)
                            "TheoryPreprocessor::preprocessWithProof");
       // always use term context hash 0 (default)
       d_tpg->addRewriteStep(
-          term, termr, trn.getGenerator(), PfRule::ASSUME, true);
+          term, termr, trn.getGenerator(), false, PfRule::ASSUME, true);
     }
     else
     {
       Trace("tpp-debug") << "TheoryPreprocessor: addRewriteStep (trusted) "
                          << term << " -> " << termr << std::endl;
       // small step trust
-      d_tpg->addRewriteStep(
-          term, termr, PfRule::THEORY_PREPROCESS, {}, {term.eqNode(termr)});
+      d_tpg->addRewriteStep(term,
+                            termr,
+                            PfRule::THEORY_PREPROCESS,
+                            {},
+                            {term.eqNode(termr)},
+                            false);
     }
   }
   termr = rewriteWithProof(termr);