Introduce tables in the rewriter (#3742)
authorAndres Noetzli <andres.noetzli@gmail.com>
Wed, 11 Mar 2020 13:54:50 +0000 (06:54 -0700)
committerGitHub <noreply@github.com>
Wed, 11 Mar 2020 13:54:50 +0000 (06:54 -0700)
This commit adds tables in the rewriter that store which function should
be used to rewrite which kind. We have separate tables for `EQUAL`
because every theory has its own equality rewriter.

src/theory/rewriter.cpp
src/theory/rewriter.h
src/theory/rewriter_tables_template.h
src/theory/theory_rewriter.h

index 765c2b4c8a0fbc9d0214809003183808c4a7c241..b3f1e23d7e28a0edd2103717de04d1bf45c651ab 100644 (file)
@@ -81,6 +81,11 @@ struct RewriteStackElement {
   NodeBuilder<> d_builder;
 };
 
+RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n)
+{
+  return RewriteResponse(REWRITE_DONE, n);
+}
+
 Node Rewriter::rewrite(TNode node) {
   if (node.getNumChildren() == 0)
   {
@@ -88,8 +93,35 @@ Node Rewriter::rewrite(TNode node) {
     // eagerly for the sake of efficiency here.
     return node;
   }
-  Rewriter& rewriter = getInstance();
-  return rewriter.rewriteTo(theoryOf(node), node);
+  return getInstance().rewriteTo(theoryOf(node), node);
+}
+
+void Rewriter::registerPreRewrite(
+    Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+  Assert(k != kind::EQUAL) << "Register pre-rewrites for EQUAL with registerPreRewriteEqual.";
+  d_preRewriters[k] = fn;
+}
+
+void Rewriter::registerPostRewrite(
+    Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+  Assert(k != kind::EQUAL) << "Register post-rewrites for EQUAL with registerPostRewriteEqual.";
+  d_postRewriters[k] = fn;
+}
+
+void Rewriter::registerPreRewriteEqual(
+    theory::TheoryId tid,
+    std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+  d_preRewritersEqual[tid] = fn;
+}
+
+void Rewriter::registerPostRewriteEqual(
+    theory::TheoryId tid,
+    std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn)
+{
+  d_postRewritersEqual[tid] = fn;
 }
 
 Rewriter& Rewriter::getInstance()
@@ -153,8 +185,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
         for(;;) {
           // Perform the pre-rewrite
           RewriteResponse response =
-              d_theoryRewriters[rewriteStackTop.getTheoryId()]->preRewrite(
-                  rewriteStackTop.d_node);
+              preRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.d_node);
           // Put the rewritten node to the top of the stack
           rewriteStackTop.d_node = response.d_node;
           TheoryId newTheory = theoryOf(rewriteStackTop.d_node);
@@ -225,8 +256,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
       for(;;) {
         // Do the post-rewrite
         RewriteResponse response =
-            d_theoryRewriters[rewriteStackTop.getTheoryId()]->postRewrite(
-                rewriteStackTop.d_node);
+            postRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.d_node);
         // We continue with the response we got
         TheoryId newTheoryId = theoryOf(response.d_node);
         if (newTheoryId != rewriteStackTop.getTheoryId()
@@ -290,6 +320,30 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) {
   Unreachable();
 }/* Rewriter::rewriteTo() */
 
+RewriteResponse Rewriter::preRewrite(theory::TheoryId theoryId, TNode n)
+{
+  Kind k = n.getKind();
+  std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn =
+      (k == kind::EQUAL) ? d_preRewritersEqual[theoryId] : d_preRewriters[k];
+  if (fn == nullptr)
+  {
+    return d_theoryRewriters[theoryId]->preRewrite(n);
+  }
+  return fn(&d_re, n);
+}
+
+RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId, TNode n)
+{
+  Kind k = n.getKind();
+  std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn =
+      (k == kind::EQUAL) ? d_postRewritersEqual[theoryId] : d_postRewriters[k];
+  if (fn == nullptr)
+  {
+    return d_theoryRewriters[theoryId]->postRewrite(n);
+  }
+  return fn(&d_re, n);
+}
+
 void Rewriter::clearCaches() {
   Rewriter& rewriter = getInstance();
 
index e55ca5d1c244751d8a20507d0ec1fd1b0138b8a1..f7298e1fb00a99e3d9dfe7fb3daffefaa6e23e77 100644 (file)
@@ -27,6 +27,23 @@ namespace theory {
 
 class RewriterInitializer;
 
+/**
+ * The rewrite environment holds everything that the individual rewrites have
+ * access to.
+ */
+class RewriteEnvironment
+{
+};
+
+/**
+ * The identity rewrite just returns the original node.
+ *
+ * @param re The rewrite environment
+ * @param n The node to rewrite
+ * @return The original node
+ */
+RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n);
+
 /**
  * The main rewriter class.
  */
@@ -45,6 +62,44 @@ class Rewriter {
    */
   static void clearCaches();
 
+  /**
+   * Register a prerewrite for a given kind.
+   *
+   * @param k The kind to register a rewrite for.
+   * @param fn The function that performs the rewrite.
+   */
+  void registerPreRewrite(
+      Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
+  /**
+   * Register a postrewrite for a given kind.
+   *
+   * @param k The kind to register a rewrite for.
+   * @param fn The function that performs the rewrite.
+   */
+  void registerPostRewrite(
+      Kind k, std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
+  /**
+   * Register a prerewrite for equalities belonging to a given theory.
+   *
+   * @param tid The theory to register a rewrite for.
+   * @param fn The function that performs the rewrite.
+   */
+  void registerPreRewriteEqual(
+      theory::TheoryId tid,
+      std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
+  /**
+   * Register a postrewrite for equalities belonging to a given theory.
+   *
+   * @param tid The theory to register a rewrite for.
+   * @param fn The function that performs the rewrite.
+   */
+  void registerPostRewriteEqual(
+      theory::TheoryId tid,
+      std::function<RewriteResponse(RewriteEnvironment*, TNode)> fn);
+
  private:
   /**
    * Get the (singleton) instance of the rewriter.
@@ -71,10 +126,10 @@ class Rewriter {
   Node rewriteTo(theory::TheoryId theoryId, Node node);
 
   /** Calls the pre-rewriter for the given theory */
-  RewriteResponse callPreRewrite(theory::TheoryId theoryId, TNode node);
+  RewriteResponse preRewrite(theory::TheoryId theoryId, TNode n);
 
   /** Calls the post-rewriter for the given theory */
-  RewriteResponse callPostRewrite(theory::TheoryId theoryId, TNode node);
+  RewriteResponse postRewrite(theory::TheoryId theoryId, TNode n);
 
   /**
    * Calls the equality-rewriter for the given theory.
@@ -88,6 +143,27 @@ class Rewriter {
 
   unsigned long d_iterationCount = 0;
 
+  /** Rewriter table for prewrites. Maps kinds to rewriter function. */
+  std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+      d_preRewriters[kind::LAST_KIND];
+  /** Rewriter table for postrewrites. Maps kinds to rewriter function. */
+  std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+      d_postRewriters[kind::LAST_KIND];
+  /**
+   * Rewriter table for prerewrites of equalities. Maps theory to rewriter
+   * function.
+   */
+  std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+      d_preRewritersEqual[theory::THEORY_LAST];
+  /**
+   * Rewriter table for postrewrites of equalities. Maps theory to rewriter
+   * function.
+   */
+  std::function<RewriteResponse(RewriteEnvironment*, TNode)>
+      d_postRewritersEqual[theory::THEORY_LAST];
+
+  RewriteEnvironment d_re;
+
 #ifdef CVC4_ASSERTIONS
   std::unique_ptr<std::unordered_set<Node, NodeHashFunction>> d_rewriteStack =
       nullptr;
index e1be6355b966c05a20981ce812740699c98565d3..1bb03e253a443027c70041822f19db5b5b6bfba2 100644 (file)
@@ -64,6 +64,19 @@ ${post_rewrite_set_cache}
 Rewriter::Rewriter()
 {
 ${rewrite_init}
+
+for (size_t i = 0; i < kind::LAST_KIND; ++i)
+{
+  d_preRewriters[i] = nullptr;
+  d_postRewriters[i] = nullptr;
+}
+
+for (size_t i = 0; i < theory::THEORY_LAST; ++i)
+{
+  d_preRewritersEqual[i] = nullptr;
+  d_postRewritersEqual[i] = nullptr;
+  d_theoryRewriters[i]->registerRewrites(this);
+}
 }
 
 void Rewriter::clearCachesInternal() {
index e7dc782bb39f130c025afe1eb385b73536ef5020..311ab90208428515183b2d2f6b2d34a0b16d5c74 100644 (file)
@@ -24,6 +24,8 @@
 namespace CVC4 {
 namespace theory {
 
+class Rewriter;
+
 /**
  * Theory rewriters signal whether more rewriting is needed (or not)
  * by using a member of this enumeration.  See RewriteResponse, below.
@@ -63,6 +65,13 @@ class TheoryRewriter
  public:
   virtual ~TheoryRewriter() = default;
 
+  /**
+   * Registers the rewrites of a given theory with the rewriter.
+   *
+   * @param rewriter The rewriter to register the rewrites with.
+   */
+  virtual void registerRewrites(Rewriter* rewriter) {}
+
   /**
    * Performs a pre-rewrite step.
    *