Standardize Rewriter::rewriteViaMethod call (#7119)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 3 Sep 2021 18:42:18 +0000 (13:42 -0500)
committerGitHub <noreply@github.com>
Fri, 3 Sep 2021 18:42:18 +0000 (18:42 +0000)
This moves the standard method for rewrites in proofs from TheoryBuiltinProofRuleChecker to Rewriter.  The motivation for this change is to make various kinds of rewrite methods (standard rewrite, extended rewrite, extended equality rewrite, evaluate) accessible throughout the code in a standard way.  After this PR, it is possible to know variants of the REWRITE proof rule application by having access to the rewriter, instead of having to get the builtin proof rule checker.  Note that TheoryBuiltinProofRuleChecker::applyRewrite *cannot* be static since access to the rewriter is not longer permitted to be static.

It also removes some unused infrastructure from Rewriter.

Followup PRs will remove applyRewrite for TheoryBuiltinProofRuleChecker in favor of calling the rewriter directly.

src/theory/builtin/proof_checker.cpp
src/theory/builtin/proof_checker.h
src/theory/rewriter.cpp
src/theory/rewriter.h
src/theory/rewriter_tables_template.h

index bb0f9a413e68ae7038e71eff15e4c35abecefb00..e51db4ce3726b4488fec1f0ee4a0eb421abba896 100644 (file)
@@ -70,36 +70,9 @@ Node BuiltinProofRuleChecker::applySubstitutionRewrite(
   return applyRewrite(nks, idr);
 }
 
-Node BuiltinProofRuleChecker::applyRewrite(Node n, MethodId idr)
+Node BuiltinProofRuleChecker::applyRewrite(TNode n, MethodId idr)
 {
-  Trace("builtin-pfcheck-debug")
-      << "applyRewrite (" << idr << "): " << n << std::endl;
-  if (idr == MethodId::RW_REWRITE)
-  {
-    return Rewriter::rewrite(n);
-  }
-  if (idr == MethodId::RW_EXT_REWRITE)
-  {
-    return d_ext_rewriter.extendedRewrite(n);
-  }
-  if (idr == MethodId::RW_REWRITE_EQ_EXT)
-  {
-    return d_env.getRewriter()->rewriteEqualityExt(n);
-  }
-  if (idr == MethodId::RW_EVALUATE)
-  {
-    Evaluator eval;
-    return eval.eval(n, {}, {}, false);
-  }
-  if (idr == MethodId::RW_IDENTITY)
-  {
-    // does nothing
-    return n;
-  }
-  // unknown rewriter
-  Assert(false) << "BuiltinProofRuleChecker::applyRewrite: no rewriter for "
-                << idr << std::endl;
-  return n;
+  return d_env.getRewriter()->rewriteViaMethod(n, idr);
 }
 
 bool BuiltinProofRuleChecker::getSubstitutionForLit(Node exp,
index d7edd2c537b762eca2f0604b040792558a566a56..bb746e467f1d82bff4de8c523769f02fab881ba2 100644 (file)
@@ -48,7 +48,7 @@ class BuiltinProofRuleChecker : public ProofRuleChecker
    * specifying a call to Rewriter::rewrite.
    * @return The rewritten form of n.
    */
-  Node applyRewrite(Node n, MethodId idr = MethodId::RW_REWRITE);
+  Node applyRewrite(TNode n, MethodId idr = MethodId::RW_REWRITE);
   /**
    * Get substitution for literal exp. Updates vars/subs to the substitution
    * specified by exp for the substitution method ids.
index bcd0952652ba140d330e468f0d93f31e4193b5c6..5c4cc553613878efa46b89eefc942f9e2739fe84 100644 (file)
@@ -21,6 +21,8 @@
 #include "smt/smt_engine_scope.h"
 #include "smt/smt_statistics_registry.h"
 #include "theory/builtin/proof_checker.h"
+#include "theory/evaluator.h"
+#include "theory/quantifiers/extended_rewrite.h"
 #include "theory/rewriter_tables.h"
 #include "theory/theory.h"
 #include "util/resource_manager.h"
@@ -150,34 +152,6 @@ void Rewriter::registerTheoryRewriter(theory::TheoryId tid,
   d_theoryRewriters[tid] = trew;
 }
 
-void Rewriter::registerPreRewrite(
-    Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
-{
-  Assert(k != kind::EQUAL) << "Register pre-rewrites for EQUAL with registerPreRewriteEqual.";
-  d_preRewriters[k] = fn;
-}
-
-void Rewriter::registerPostRewrite(
-    Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
-{
-  Assert(k != kind::EQUAL) << "Register post-rewrites for EQUAL with registerPostRewriteEqual.";
-  d_postRewriters[k] = fn;
-}
-
-void Rewriter::registerPreRewriteEqual(
-    theory::TheoryId tid,
-    std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
-{
-  d_preRewritersEqual[tid] = fn;
-}
-
-void Rewriter::registerPostRewriteEqual(
-    theory::TheoryId tid,
-    std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
-{
-  d_postRewritersEqual[tid] = fn;
-}
-
 TheoryRewriter* Rewriter::getTheoryRewriter(theory::TheoryId theoryId)
 {
   return d_theoryRewriters[theoryId];
@@ -428,44 +402,30 @@ RewriteResponse Rewriter::preRewrite(theory::TheoryId theoryId,
                                      TNode n,
                                      TConvProofGenerator* tcpg)
 {
-  Kind k = n.getKind();
-  std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn =
-      (k == kind::EQUAL) ? d_preRewritersEqual[theoryId] : d_preRewriters[k];
-  if (fn == nullptr)
+  if (tcpg != nullptr)
   {
-    if (tcpg != nullptr)
-    {
-      // call the trust rewrite response interface
-      TrustRewriteResponse tresponse =
-          d_theoryRewriters[theoryId]->preRewriteWithProof(n);
-      // process the trust rewrite response: store the proof step into
-      // tcpg if necessary and then convert to rewrite response.
-      return processTrustRewriteResponse(theoryId, tresponse, true, tcpg);
-    }
-    return d_theoryRewriters[theoryId]->preRewrite(n);
+    // call the trust rewrite response interface
+    TrustRewriteResponse tresponse =
+        d_theoryRewriters[theoryId]->preRewriteWithProof(n);
+    // process the trust rewrite response: store the proof step into
+    // tcpg if necessary and then convert to rewrite response.
+    return processTrustRewriteResponse(theoryId, tresponse, true, tcpg);
   }
-  return fn(&d_re, n);
+  return d_theoryRewriters[theoryId]->preRewrite(n);
 }
 
 RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId,
                                       TNode n,
                                       TConvProofGenerator* tcpg)
 {
-  Kind k = n.getKind();
-  std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn =
-      (k == kind::EQUAL) ? d_postRewritersEqual[theoryId] : d_postRewriters[k];
-  if (fn == nullptr)
+  if (tcpg != nullptr)
   {
-    if (tcpg != nullptr)
-    {
-      // same as above, for post-rewrite
-      TrustRewriteResponse tresponse =
-          d_theoryRewriters[theoryId]->postRewriteWithProof(n);
-      return processTrustRewriteResponse(theoryId, tresponse, false, tcpg);
-    }
-    return d_theoryRewriters[theoryId]->postRewrite(n);
+    // same as above, for post-rewrite
+    TrustRewriteResponse tresponse =
+        d_theoryRewriters[theoryId]->postRewriteWithProof(n);
+    return processTrustRewriteResponse(theoryId, tresponse, false, tcpg);
   }
-  return fn(&d_re, n);
+  return d_theoryRewriters[theoryId]->postRewrite(n);
 }
 
 RewriteResponse Rewriter::processTrustRewriteResponse(
@@ -512,5 +472,36 @@ void Rewriter::clearCaches()
   clearCachesInternal();
 }
 
+Node Rewriter::rewriteViaMethod(TNode n, MethodId idr)
+{
+  if (idr == MethodId::RW_REWRITE)
+  {
+    return rewrite(n);
+  }
+  if (idr == MethodId::RW_EXT_REWRITE)
+  {
+    quantifiers::ExtendedRewriter er;
+    return er.extendedRewrite(n);
+  }
+  if (idr == MethodId::RW_REWRITE_EQ_EXT)
+  {
+    return rewriteEqualityExt(n);
+  }
+  if (idr == MethodId::RW_EVALUATE)
+  {
+    Evaluator eval;
+    return eval.eval(n, {}, {}, false);
+  }
+  if (idr == MethodId::RW_IDENTITY)
+  {
+    // does nothing
+    return n;
+  }
+  // unknown rewriter
+  Unhandled() << "Rewriter::rewriteViaMethod: no rewriter for " << idr
+              << std::endl;
+  return n;
+}
+
 }  // namespace theory
 }  // namespace cvc5
index 23a9914bd3711776f6070056b9f96508a3f314ec..63628b0afe62c77c9eaefc2f4978fc0fcccea196 100644 (file)
@@ -18,6 +18,7 @@
 #pragma once
 
 #include "expr/node.h"
+#include "proof/method_id.h"
 #include "theory/theory_rewriter.h"
 
 namespace cvc5 {
@@ -105,46 +106,19 @@ class Rewriter {
    */
   void registerTheoryRewriter(theory::TheoryId tid, TheoryRewriter* trew);
 
-  /**
-   * Register a prerewrite for a given kind.
-   *
-   * @param k The kind to register a rewrite for.
-   * @param fn The function that performs the rewrite.
-   */
-  void registerPreRewrite(
-      Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
-
-  /**
-   * Register a postrewrite for a given kind.
-   *
-   * @param k The kind to register a rewrite for.
-   * @param fn The function that performs the rewrite.
-   */
-  void registerPostRewrite(
-      Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
-
-  /**
-   * Register a prerewrite for equalities belonging to a given theory.
-   *
-   * @param tid The theory to register a rewrite for.
-   * @param fn The function that performs the rewrite.
-   */
-  void registerPreRewriteEqual(
-      theory::TheoryId tid,
-      std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+  /** Get the theory rewriter for the given id */
+  TheoryRewriter* getTheoryRewriter(theory::TheoryId theoryId);
 
   /**
-   * Register a postrewrite for equalities belonging to a given theory.
+   * Apply rewrite on n via the rewrite method identifier idr (see method_id.h).
+   * This encapsulates the exact behavior of a REWRITE step in a proof.
    *
-   * @param tid The theory to register a rewrite for.
-   * @param fn The function that performs the rewrite.
+   * @param n The node to rewrite,
+   * @param idr The method identifier of the rewriter, by default RW_REWRITE
+   * specifying a call to rewrite.
+   * @return The rewritten form of n.
    */
-  void registerPostRewriteEqual(
-      theory::TheoryId tid,
-      std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
-
-  /** Get the theory rewriter for the given id */
-  TheoryRewriter* getTheoryRewriter(theory::TheoryId theoryId);
+  Node rewriteViaMethod(TNode n, MethodId idr = MethodId::RW_REWRITE);
 
  private:
   /**
@@ -200,25 +174,6 @@ class Rewriter {
   /** Theory rewriters used by this rewriter instance */
   TheoryRewriter* d_theoryRewriters[theory::THEORY_LAST];
 
-  /** Rewriter table for prewrites. Maps kinds to rewriter function. */
-  std::function<RewriteResponse(RewriteEnvironment*, TNode)>
-      d_preRewriters[kind::LAST_KIND];
-  /** Rewriter table for postrewrites. Maps kinds to rewriter function. */
-  std::function<RewriteResponse(RewriteEnvironment*, TNode)>
-      d_postRewriters[kind::LAST_KIND];
-  /**
-   * Rewriter table for prerewrites of equalities. Maps theory to rewriter
-   * function.
-   */
-  std::function<RewriteResponse(RewriteEnvironment*, TNode)>
-      d_preRewritersEqual[theory::THEORY_LAST];
-  /**
-   * Rewriter table for postrewrites of equalities. Maps theory to rewriter
-   * function.
-   */
-  std::function<RewriteResponse(RewriteEnvironment*, TNode)>
-      d_postRewritersEqual[theory::THEORY_LAST];
-
   RewriteEnvironment d_re;
 
   /** The proof generator */
index c549f8cfbc525b4ffa4b008763d534bb5af1c061..36d320fb7ec1320a031d37f2f12af92b8e7b771a 100644 (file)
@@ -82,17 +82,7 @@ ${post_rewrite_set_cache}
 
 Rewriter::Rewriter() : d_tpg(nullptr)
 {
-  for (size_t i = 0; i < kind::LAST_KIND; ++i)
-  {
-    d_preRewriters[i] = nullptr;
-    d_postRewriters[i] = nullptr;
-  }
 
-  for (size_t i = 0; i < theory::THEORY_LAST; ++i)
-  {
-    d_preRewritersEqual[i] = nullptr;
-    d_postRewritersEqual[i] = nullptr;
-  }
 }
 
 void Rewriter::clearCachesInternal()