(proof-new) Add ONCE and FIXPOINT modes for term conversion proof generator (#4709)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sat, 11 Jul 2020 02:19:01 +0000 (21:19 -0500)
committerGitHub <noreply@github.com>
Sat, 11 Jul 2020 02:19:01 +0000 (21:19 -0500)
We need multiple policies for generic proofs for term conversion, in particular, substitution has a "apply once" semantics which does not apply rewrite steps to a fixpoint like the rewriter.

This is required for eliminating SUBS steps in the final proof.

Also note that an internal method for getting proofs was generalized, which will be required for doing multiple variants of proofs in this utility in the future.

src/expr/term_conversion_proof_generator.cpp
src/expr/term_conversion_proof_generator.h

index 70b42951073229517f5b914eeb8f2d9bfe68d9d0..bad1633751193ac8bffe280119574ab73c91af2b 100644 (file)
@@ -18,9 +18,21 @@ using namespace CVC4::kind;
 
 namespace CVC4 {
 
+std::ostream& operator<<(std::ostream& out, TConvPolicy tcpol)
+{
+  switch (tcpol)
+  {
+    case TConvPolicy::FIXPOINT: out << "FIXPOINT"; break;
+    case TConvPolicy::ONCE: out << "ONCE"; break;
+    default: out << "TConvPolicy:unknown"; break;
+  }
+  return out;
+}
+
 TConvProofGenerator::TConvProofGenerator(ProofNodeManager* pnm,
-                                         context::Context* c)
-    : d_proof(pnm, nullptr, c), d_rewriteMap(c ? c : &d_context)
+                                         context::Context* c,
+                                         TConvPolicy tpol)
+    : d_proof(pnm, nullptr, c), d_rewriteMap(c ? c : &d_context), d_policy(tpol)
 {
 }
 
@@ -28,20 +40,20 @@ TConvProofGenerator::~TConvProofGenerator() {}
 
 void TConvProofGenerator::addRewriteStep(Node t, Node s, ProofGenerator* pg)
 {
-  // should not rewrite term more than once
-  Assert(!hasRewriteStep(t));
-  Node eq = t.eqNode(s);
-  d_proof.addLazyStep(eq, pg);
-  d_rewriteMap[t] = s;
+  Node eq = registerRewriteStep(t, s);
+  if (!eq.isNull())
+  {
+    d_proof.addLazyStep(eq, pg);
+  }
 }
 
 void TConvProofGenerator::addRewriteStep(Node t, Node s, ProofStep ps)
 {
-  // should not rewrite term more than once
-  Assert(!hasRewriteStep(t));
-  Node eq = t.eqNode(s);
-  d_proof.addStep(eq, ps);
-  d_rewriteMap[t] = s;
+  Node eq = registerRewriteStep(t, s);
+  if (!eq.isNull())
+  {
+    d_proof.addStep(eq, ps);
+  }
 }
 
 void TConvProofGenerator::addRewriteStep(Node t,
@@ -50,11 +62,11 @@ void TConvProofGenerator::addRewriteStep(Node t,
                                          const std::vector<Node>& children,
                                          const std::vector<Node>& args)
 {
-  // should not rewrite term more than once
-  Assert(!hasRewriteStep(t));
-  Node eq = t.eqNode(s);
-  d_proof.addStep(eq, id, children, args);
-  d_rewriteMap[t] = s;
+  Node eq = registerRewriteStep(t, s);
+  if (!eq.isNull())
+  {
+    d_proof.addStep(eq, id, children, args);
+  }
 }
 
 bool TConvProofGenerator::hasRewriteStep(Node t) const
@@ -62,6 +74,22 @@ bool TConvProofGenerator::hasRewriteStep(Node t) const
   return !getRewriteStep(t).isNull();
 }
 
+Node TConvProofGenerator::registerRewriteStep(Node t, Node s)
+{
+  if (t == s)
+  {
+    return Node::null();
+  }
+  // should not rewrite term to two different things
+  if (!getRewriteStep(t).isNull())
+  {
+    Assert(getRewriteStep(t) == s);
+    return Node::null();
+  }
+  d_rewriteMap[t] = s;
+  return t.eqNode(s);
+}
+
 std::shared_ptr<ProofNode> TConvProofGenerator::getProofFor(Node f)
 {
   Trace("tconv-pf-gen") << "TConvProofGenerator::getProofFor: " << f
@@ -72,29 +100,22 @@ std::shared_ptr<ProofNode> TConvProofGenerator::getProofFor(Node f)
     Assert(false);
     return nullptr;
   }
-  std::shared_ptr<ProofNode> pf = getProofForRewriting(f[0]);
-  if (pf == nullptr)
-  {
-    // failed to generate proof
-    Trace("tconv-pf-gen") << "...failed to get proof" << std::endl;
-    Assert(false);
-    return pf;
-  }
-  if (pf->getResult() != f)
+  // we use the existing proofs
+  LazyCDProof lpf(d_proof.getManager(), &d_proof);
+  Node conc = getProofForRewriting(f[0], lpf);
+  if (conc != f)
   {
     Trace("tconv-pf-gen") << "...failed, mismatch: returned proof concludes "
-                          << pf->getResult() << std::endl;
+                          << conc << ", expected " << f << std::endl;
     Assert(false);
     return nullptr;
   }
   Trace("tconv-pf-gen") << "... success" << std::endl;
-  return pf;
+  return lpf.getProofFor(f);
 }
 
-std::shared_ptr<ProofNode> TConvProofGenerator::getProofForRewriting(Node t)
+Node TConvProofGenerator::getProofForRewriting(Node t, LazyCDProof& pf)
 {
-  // we use the existing proofs
-  LazyCDProof pf(d_proof.getManager(), &d_proof);
   NodeManager* nm = NodeManager::currentNM();
   // Invariant: if visited[t] = s or rewritten[t] = s and t,s are distinct,
   // then pf is able to generate a proof of t=s.
@@ -120,14 +141,23 @@ std::shared_ptr<ProofNode> TConvProofGenerator::getProofForRewriting(Node t)
       Node rcur = getRewriteStep(cur);
       if (!rcur.isNull())
       {
-        // d_proof should have a proof of cur = rcur. Hence there is nothing
-        // to do here, as pf will reference prg to get the proof from d_proof.
-        // It may be the case that rcur also rewrites, thus we cannot assign
-        // the final rewritten form for cur yet. Instead we revisit cur after
-        // finishing visiting rcur.
-        rewritten[cur] = rcur;
-        visit.push_back(cur);
-        visit.push_back(rcur);
+        // d_proof has a proof of cur = rcur. Hence there is nothing
+        // to do here, as pf will reference d_proof to get its proof.
+        if (d_policy == TConvPolicy::FIXPOINT)
+        {
+          // It may be the case that rcur also rewrites, thus we cannot assign
+          // the final rewritten form for cur yet. Instead we revisit cur after
+          // finishing visiting rcur.
+          rewritten[cur] = rcur;
+          visit.push_back(cur);
+          visit.push_back(rcur);
+        }
+        else
+        {
+          Assert(d_policy == TConvPolicy::ONCE);
+          // not rewriting again, rcur is final
+          visited[cur] = rcur;
+        }
       }
       else
       {
@@ -140,6 +170,9 @@ std::shared_ptr<ProofNode> TConvProofGenerator::getProofForRewriting(Node t)
       itr = rewritten.find(cur);
       if (itr != rewritten.end())
       {
+        // only can generate partially rewritten nodes when rewrite again is
+        // true.
+        Assert(d_policy != TConvPolicy::ONCE);
         // if it was rewritten, check the status of the rewritten node,
         // which should be finished now
         Node rcur = itr->second;
@@ -203,7 +236,12 @@ std::shared_ptr<ProofNode> TConvProofGenerator::getProofForRewriting(Node t)
           pf.addStep(result, PfRule::CONG, pfChildren, pfArgs);
         }
         // did we rewrite ret (at post-rewrite)?
-        Node rret = getRewriteStep(ret);
+        Node rret;
+        // only if not ONCE policy, which only does pre-rewrite
+        if (d_policy != TConvPolicy::ONCE)
+        {
+          rret = getRewriteStep(ret);
+        }
         if (!rret.isNull())
         {
           if (cur != ret)
@@ -227,9 +265,8 @@ std::shared_ptr<ProofNode> TConvProofGenerator::getProofForRewriting(Node t)
   } while (!visit.empty());
   Assert(visited.find(t) != visited.end());
   Assert(!visited.find(t)->second.isNull());
-  // make the overall proof
-  Node teq = t.eqNode(visited[t]);
-  return pf.getProofFor(teq);
+  // return the conclusion of the overall proof
+  return t.eqNode(visited[t]);
 }
 
 Node TConvProofGenerator::getRewriteStep(Node t) const
index 56b89f65cbfcd932fe12e63b4a58c3a9a3e73db7..d7ff6e8f620bc308fd1c482bac244297816a5a91 100644 (file)
 
 namespace CVC4 {
 
+/** A policy for how rewrite steps are applied in TConvProofGenerator */
+enum class TConvPolicy : uint32_t
+{
+  // steps are applied to fix-point, common use case is PfRule::REWRITE
+  FIXPOINT,
+  // steps are applied once at pre-rewrite, common use case is PfRule::SUBS
+  ONCE,
+};
+/** Writes a term conversion policy name to a stream. */
+std::ostream& operator<<(std::ostream& out, TConvPolicy tcpol);
+
 /**
  * The term conversion proof generator.
  *
@@ -69,8 +80,12 @@ class TConvProofGenerator : public ProofGenerator
    * @param pnm The proof node manager for constructing ProofNode objects.
    * @param c The context that this class depends on. If none is provided,
    * this class is context-independent.
+   * @param tpol The policy for applying rewrite steps of this class. For
+   * details, see d_policy.
    */
-  TConvProofGenerator(ProofNodeManager* pnm, context::Context* c = nullptr);
+  TConvProofGenerator(ProofNodeManager* pnm,
+                      context::Context* c = nullptr,
+                      TConvPolicy pol = TConvPolicy::FIXPOINT);
   ~TConvProofGenerator();
   /**
    * Add rewrite step t --> s based on proof generator.
@@ -96,7 +111,7 @@ class TConvProofGenerator : public ProofGenerator
    * t = t', where t' is the result of rewriting t based on the rewrite steps
    * registered to this class.
    *
-   * @param f The fact to get the proof for.
+   * @param f The equality fact to get the proof for.
    * @return The proof for f.
    */
   std::shared_ptr<ProofNode> getProofFor(Node f) override;
@@ -112,10 +127,26 @@ class TConvProofGenerator : public ProofGenerator
   /** map to rewritten forms */
   NodeNodeMap d_rewriteMap;
   /**
-   * Get the proof for term t. Returns a proof of t = t' where t' is the
-   * result of rewriting t based on the rewrite steps registered to this class.
+   * Policy for how rewrites are applied to terms. As a simple example, say we
+   * have registered the rewrite steps:
+   *   addRewriteStep( a, f(c), p1 )
+   *   addRewriteStep( c, d, p2 )
+   * Then getProofForRewriting(f(a,c),pf) returns a proof of:
+   *   f(a,c) = f(f(d),d) if d_policy is FIXPOINT,
+   *   f(a,c) = f(f(c),d) if d_policy is ONCE.
+   */
+  TConvPolicy d_policy;
+  /**
+   * Adds a proof of t = t' to the proof pf where t' is the result of rewriting
+   * t based on the rewrite steps registered to this class. This method then
+   * returns the proved equality t = t'.
+   */
+  Node getProofForRewriting(Node t, LazyCDProof& pf);
+  /**
+   * Register rewrite step, returns the equality t=s if t is distinct from s
+   * and a rewrite step has not already been registered for t.
    */
-  std::shared_ptr<ProofNode> getProofForRewriting(Node t);
+  Node registerRewriteStep(Node t, Node s);
 };
 
 }  // namespace CVC4