(proof-new) Rewrites involving operators in term conversion proof generator (#5072)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 18 Sep 2020 02:06:44 +0000 (21:06 -0500)
committerGitHub <noreply@github.com>
Fri, 18 Sep 2020 02:06:44 +0000 (21:06 -0500)
In some cases, e.g. witness form conversion, we require traversing to operators in term conversion proofs. This updates the term conversion generator to support (term-context-sensitive) rewrites involving operators using HO_CONG. This requires updating the term context utilities with support for operators.

src/expr/term_context.cpp
src/expr/term_context.h
src/expr/term_context_stack.cpp
src/expr/term_context_stack.h
src/expr/term_conversion_proof_generator.cpp
src/expr/term_conversion_proof_generator.h
src/smt/witness_form.cpp

index 6dcdc25ee54cbfd61e08cc1579fa8352eff42422..fcb8edfdce67aea2c3fb5fc4d401ed2e2bc42d94 100644 (file)
 
 namespace CVC4 {
 
+uint32_t TermContext::computeValueOp(TNode t, uint32_t tval) const
+{
+  // default is no change
+  return tval;
+}
+
 uint32_t RtfTermContext::initialValue() const
 {
   // by default, not in a term context or a quantifier
index 87f91f2df52a5c3f78b66fb61c7e2892c73bd72b..0c51dd49306e142d21fb291f8411abdcb4d629b8 100644 (file)
@@ -48,6 +48,11 @@ class TermContext
    * is the term context identifier of t.
    */
   virtual uint32_t computeValue(TNode t, uint32_t tval, size_t index) const = 0;
+  /**
+   * Returns the term context identifier of the operator of t, where tval
+   * is the term context identifier of t.
+   */
+  virtual uint32_t computeValueOp(TNode t, uint32_t tval) const;
 };
 
 /**
index 4c2e32033af508ed8f4878e5ff3b0e15c9cf2ff9..4deb0627c228a788f9841c15a0b2dff986f0b958 100644 (file)
@@ -43,6 +43,13 @@ void TCtxStack::pushChild(Node t, uint32_t tval, size_t index)
   d_stack.push_back(std::pair<Node, uint32_t>(t[index], tcval));
 }
 
+void TCtxStack::pushOp(Node t, uint32_t tval)
+{
+  Assert(t.hasOperator());
+  uint32_t toval = d_tctx->computeValueOp(t, tval);
+  d_stack.push_back(std::pair<Node, uint32_t>(t.getOperator(), toval));
+}
+
 void TCtxStack::push(Node t, uint32_t tval)
 {
   d_stack.push_back(std::pair<Node, uint32_t>(t, tval));
index 9aeea04c2aa40482866c90b54491c0df3dd2888b..6f5dece81d14ed81dd4c3ed0f63ff91d887b7a96 100644 (file)
@@ -41,6 +41,11 @@ class TCtxStack
    * the term context hash of t.
    */
   void pushChild(Node t, uint32_t tval, size_t index);
+  /**
+   * Push the operator of t to the stack, where tval is the term context has
+   * of t.
+   */
+  void pushOp(Node t, uint32_t tval);
   /** Push t to the stack with term context hash tval. */
   void push(Node t, uint32_t tval);
   /** Pop a term from the context */
index 8cd7561b4e068c956bf9e8760118593995dc3f84..215793318befdf8796d17b377330de589db3773e 100644 (file)
@@ -48,13 +48,15 @@ TConvProofGenerator::TConvProofGenerator(ProofNodeManager* pnm,
                                          TConvPolicy pol,
                                          TConvCachePolicy cpol,
                                          std::string name,
-                                         TermContext* tccb)
+                                         TermContext* tccb,
+                                         bool rewriteOps)
     : d_proof(pnm, nullptr, c, name + "::LazyCDProof"),
       d_rewriteMap(c ? c : &d_context),
       d_policy(pol),
       d_cpolicy(cpol),
       d_name(name),
-      d_tcontext(tccb)
+      d_tcontext(tccb),
+      d_rewriteOps(rewriteOps)
 {
 }
 
@@ -196,8 +198,11 @@ std::shared_ptr<ProofNode> TConvProofGenerator::getProofFor(Node f)
       return nullptr;
     }
   }
+  std::shared_ptr<ProofNode> pfn = lpf.getProofFor(f);
   Trace("tconv-pf-gen") << "... success" << std::endl;
-  return lpf.getProofFor(f);
+  Assert (pfn!=nullptr);
+  Trace("tconv-pf-gen-debug") << "... proof is " << *pfn << std::endl;
+  return pfn;
 }
 
 Node TConvProofGenerator::getProofForRewriting(Node t,
@@ -313,11 +318,21 @@ Node TConvProofGenerator::getProofForRewriting(Node t,
       else if (tctx != nullptr)
       {
         visitctx->push(cur, curCVal);
+        // visit operator if apply uf
+        if (d_rewriteOps && cur.getKind() == APPLY_UF)
+        {
+          visitctx->pushOp(cur, curCVal);
+        }
         visitctx->pushChildren(cur, curCVal);
       }
       else
       {
         visit.push_back(cur);
+        // visit operator if apply uf
+        if (d_rewriteOps && cur.getKind() == APPLY_UF)
+        {
+          visit.push_back(cur.getOperator());
+        }
         visit.insert(visit.end(), cur.begin(), cur.end());
       }
     }
@@ -365,8 +380,29 @@ Node TConvProofGenerator::getProofForRewriting(Node t,
         Node retHash = curHash;
         bool childChanged = false;
         std::vector<Node> children;
-        if (cur.getMetaKind() == metakind::PARAMETERIZED)
+        Kind ck = cur.getKind();
+        if (d_rewriteOps && ck == APPLY_UF)
+        {
+          // the operator of APPLY_UF is visited
+          Node cop = cur.getOperator();
+          if (tctx != nullptr)
+          {
+            uint32_t coval = tctx->computeValueOp(cur, curCVal);
+            Node coHash = TCtxNode::computeNodeHash(cop, coval);
+            it = visited.find(coHash);
+          }
+          else
+          {
+            it = visited.find(cop);
+          }
+          Assert(it != visited.end());
+          Assert(!it->second.isNull());
+          childChanged = childChanged || cop != it->second;
+          children.push_back(it->second);
+        }
+        else if (cur.getMetaKind() == metakind::PARAMETERIZED)
         {
+          // all other parametrized operators are unchanged
           children.push_back(cur.getOperator());
         }
         // get the results of the children
@@ -398,10 +434,23 @@ Node TConvProofGenerator::getProofForRewriting(Node t,
         }
         if (childChanged)
         {
-          ret = nm->mkNode(cur.getKind(), children);
+          ret = nm->mkNode(ck, children);
           rewritten[curHash] = ret;
           // congruence to show (cur = ret)
+          PfRule congRule = PfRule::CONG;
           std::vector<Node> pfChildren;
+          std::vector<Node> pfArgs;
+          pfArgs.push_back(ProofRuleChecker::mkKindNode(ck));
+          if (ck == APPLY_UF && children[0] != cur.getOperator())
+          {
+            // use HO_CONG if the operator changed
+            congRule = PfRule::HO_CONG;
+            pfChildren.push_back(cur.getOperator().eqNode(children[0]));
+          }
+          else if (kind::metaKindOf(ck) == kind::metakind::PARAMETERIZED)
+          {
+            pfArgs.push_back(cur.getOperator());
+          }
           for (size_t i = 0, size = cur.getNumChildren(); i < size; i++)
           {
             if (cur[i] == ret[i])
@@ -411,15 +460,8 @@ Node TConvProofGenerator::getProofForRewriting(Node t,
             }
             pfChildren.push_back(cur[i].eqNode(ret[i]));
           }
-          std::vector<Node> pfArgs;
-          Kind k = cur.getKind();
-          pfArgs.push_back(ProofRuleChecker::mkKindNode(k));
-          if (kind::metaKindOf(k) == kind::metakind::PARAMETERIZED)
-          {
-            pfArgs.push_back(cur.getOperator());
-          }
           Node result = cur.eqNode(ret);
-          pf.addStep(result, PfRule::CONG, pfChildren, pfArgs);
+          pf.addStep(result, congRule, pfChildren, pfArgs);
           // must update the hash
           retHash = ret;
           if (tctx != nullptr)
index faee2b9e36f7b7ca2c1a602564dd5629082ce12b..1153995adc6265383f7242cd42313f7880e808c1 100644 (file)
@@ -138,7 +138,8 @@ class TConvProofGenerator : public ProofGenerator
                       TConvPolicy pol = TConvPolicy::FIXPOINT,
                       TConvCachePolicy cpol = TConvCachePolicy::NEVER,
                       std::string name = "TConvProofGenerator",
-                      TermContext* tccb = nullptr);
+                      TermContext* tccb = nullptr,
+                      bool rewriteOps = false);
   ~TConvProofGenerator();
   /**
    * Add rewrite step t --> s based on proof generator.
@@ -209,6 +210,12 @@ class TConvProofGenerator : public ProofGenerator
   std::map<Node, std::shared_ptr<ProofNode> > d_cache;
   /** An (optional) term context object */
   TermContext* d_tcontext;
+  /**
+   * Whether we rewrite operators. If this flag is true, then the main
+   * traversal algorithm of this proof generator traverses operators of
+   * APPLY_UF and uses HO_CONG to justify rewriting of subterms when necessary.
+   */
+  bool d_rewriteOps;
   /** Get rewrite step for (hash value of) term. */
   Node getRewriteStepInternal(Node thash) const;
   /**
index 19795119d82f7f4fbc49cf81378c2d84ac1368a3..48fb2a1f7710d5aee82801f77327383372115da6 100644 (file)
@@ -25,7 +25,9 @@ WitnessFormGenerator::WitnessFormGenerator(ProofNodeManager* pnm)
              nullptr,
              TConvPolicy::FIXPOINT,
              TConvCachePolicy::NEVER,
-             "WfGenerator::TConvProofGenerator"),
+             "WfGenerator::TConvProofGenerator",
+             nullptr,
+             true),
       d_wintroPf(pnm, nullptr, nullptr, "WfGenerator::LazyCDProof")
 {
 }
@@ -110,6 +112,10 @@ Node WitnessFormGenerator::convertToWitnessForm(Node t)
           // It should be the case that cur has children, since the witness
           // form of constants are themselves.
           Assert(cur.getNumChildren() > 0);
+          if (cur.hasOperator())
+          {
+            visit.push_back(cur.getOperator());
+          }
           visit.insert(visit.end(), cur.begin(), cur.end());
         }
       }