From: Andres Noetzli Date: Wed, 11 Mar 2020 13:54:50 +0000 (-0700) Subject: Introduce tables in the rewriter (#3742) X-Git-Tag: cvc5-1.0.0~3517 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=2b355305ef635ddfaad7fe75c29221cb2f744a62;p=cvc5.git Introduce tables in the rewriter (#3742) This commit adds tables in the rewriter that store which function should be used to rewrite which kind. We have separate tables for `EQUAL` because every theory has its own equality rewriter. --- diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp index 765c2b4c8..b3f1e23d7 100644 --- a/src/theory/rewriter.cpp +++ b/src/theory/rewriter.cpp @@ -81,6 +81,11 @@ struct RewriteStackElement { NodeBuilder<> d_builder; }; +RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n) +{ + return RewriteResponse(REWRITE_DONE, n); +} + Node Rewriter::rewrite(TNode node) { if (node.getNumChildren() == 0) { @@ -88,8 +93,35 @@ Node Rewriter::rewrite(TNode node) { // eagerly for the sake of efficiency here. return node; } - Rewriter& rewriter = getInstance(); - return rewriter.rewriteTo(theoryOf(node), node); + return getInstance().rewriteTo(theoryOf(node), node); +} + +void Rewriter::registerPreRewrite( + Kind k, std::function fn) +{ + Assert(k != kind::EQUAL) << "Register pre-rewrites for EQUAL with registerPreRewriteEqual."; + d_preRewriters[k] = fn; +} + +void Rewriter::registerPostRewrite( + Kind k, std::function fn) +{ + Assert(k != kind::EQUAL) << "Register post-rewrites for EQUAL with registerPostRewriteEqual."; + d_postRewriters[k] = fn; +} + +void Rewriter::registerPreRewriteEqual( + theory::TheoryId tid, + std::function fn) +{ + d_preRewritersEqual[tid] = fn; +} + +void Rewriter::registerPostRewriteEqual( + theory::TheoryId tid, + std::function fn) +{ + d_postRewritersEqual[tid] = fn; } Rewriter& Rewriter::getInstance() @@ -153,8 +185,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { for(;;) { // Perform the pre-rewrite RewriteResponse response = - d_theoryRewriters[rewriteStackTop.getTheoryId()]->preRewrite( - rewriteStackTop.d_node); + preRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.d_node); // Put the rewritten node to the top of the stack rewriteStackTop.d_node = response.d_node; TheoryId newTheory = theoryOf(rewriteStackTop.d_node); @@ -225,8 +256,7 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { for(;;) { // Do the post-rewrite RewriteResponse response = - d_theoryRewriters[rewriteStackTop.getTheoryId()]->postRewrite( - rewriteStackTop.d_node); + postRewrite(rewriteStackTop.getTheoryId(), rewriteStackTop.d_node); // We continue with the response we got TheoryId newTheoryId = theoryOf(response.d_node); if (newTheoryId != rewriteStackTop.getTheoryId() @@ -290,6 +320,30 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { Unreachable(); }/* Rewriter::rewriteTo() */ +RewriteResponse Rewriter::preRewrite(theory::TheoryId theoryId, TNode n) +{ + Kind k = n.getKind(); + std::function fn = + (k == kind::EQUAL) ? d_preRewritersEqual[theoryId] : d_preRewriters[k]; + if (fn == nullptr) + { + return d_theoryRewriters[theoryId]->preRewrite(n); + } + return fn(&d_re, n); +} + +RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId, TNode n) +{ + Kind k = n.getKind(); + std::function fn = + (k == kind::EQUAL) ? d_postRewritersEqual[theoryId] : d_postRewriters[k]; + if (fn == nullptr) + { + return d_theoryRewriters[theoryId]->postRewrite(n); + } + return fn(&d_re, n); +} + void Rewriter::clearCaches() { Rewriter& rewriter = getInstance(); diff --git a/src/theory/rewriter.h b/src/theory/rewriter.h index e55ca5d1c..f7298e1fb 100644 --- a/src/theory/rewriter.h +++ b/src/theory/rewriter.h @@ -27,6 +27,23 @@ namespace theory { class RewriterInitializer; +/** + * The rewrite environment holds everything that the individual rewrites have + * access to. + */ +class RewriteEnvironment +{ +}; + +/** + * The identity rewrite just returns the original node. + * + * @param re The rewrite environment + * @param n The node to rewrite + * @return The original node + */ +RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n); + /** * The main rewriter class. */ @@ -45,6 +62,44 @@ class Rewriter { */ static void clearCaches(); + /** + * Register a prerewrite for a given kind. + * + * @param k The kind to register a rewrite for. + * @param fn The function that performs the rewrite. + */ + void registerPreRewrite( + Kind k, std::function fn); + + /** + * Register a postrewrite for a given kind. + * + * @param k The kind to register a rewrite for. + * @param fn The function that performs the rewrite. + */ + void registerPostRewrite( + Kind k, std::function fn); + + /** + * Register a prerewrite for equalities belonging to a given theory. + * + * @param tid The theory to register a rewrite for. + * @param fn The function that performs the rewrite. + */ + void registerPreRewriteEqual( + theory::TheoryId tid, + std::function fn); + + /** + * Register a postrewrite for equalities belonging to a given theory. + * + * @param tid The theory to register a rewrite for. + * @param fn The function that performs the rewrite. + */ + void registerPostRewriteEqual( + theory::TheoryId tid, + std::function fn); + private: /** * Get the (singleton) instance of the rewriter. @@ -71,10 +126,10 @@ class Rewriter { Node rewriteTo(theory::TheoryId theoryId, Node node); /** Calls the pre-rewriter for the given theory */ - RewriteResponse callPreRewrite(theory::TheoryId theoryId, TNode node); + RewriteResponse preRewrite(theory::TheoryId theoryId, TNode n); /** Calls the post-rewriter for the given theory */ - RewriteResponse callPostRewrite(theory::TheoryId theoryId, TNode node); + RewriteResponse postRewrite(theory::TheoryId theoryId, TNode n); /** * Calls the equality-rewriter for the given theory. @@ -88,6 +143,27 @@ class Rewriter { unsigned long d_iterationCount = 0; + /** Rewriter table for prewrites. Maps kinds to rewriter function. */ + std::function + d_preRewriters[kind::LAST_KIND]; + /** Rewriter table for postrewrites. Maps kinds to rewriter function. */ + std::function + d_postRewriters[kind::LAST_KIND]; + /** + * Rewriter table for prerewrites of equalities. Maps theory to rewriter + * function. + */ + std::function + d_preRewritersEqual[theory::THEORY_LAST]; + /** + * Rewriter table for postrewrites of equalities. Maps theory to rewriter + * function. + */ + std::function + d_postRewritersEqual[theory::THEORY_LAST]; + + RewriteEnvironment d_re; + #ifdef CVC4_ASSERTIONS std::unique_ptr> d_rewriteStack = nullptr; diff --git a/src/theory/rewriter_tables_template.h b/src/theory/rewriter_tables_template.h index e1be6355b..1bb03e253 100644 --- a/src/theory/rewriter_tables_template.h +++ b/src/theory/rewriter_tables_template.h @@ -64,6 +64,19 @@ ${post_rewrite_set_cache} Rewriter::Rewriter() { ${rewrite_init} + +for (size_t i = 0; i < kind::LAST_KIND; ++i) +{ + d_preRewriters[i] = nullptr; + d_postRewriters[i] = nullptr; +} + +for (size_t i = 0; i < theory::THEORY_LAST; ++i) +{ + d_preRewritersEqual[i] = nullptr; + d_postRewritersEqual[i] = nullptr; + d_theoryRewriters[i]->registerRewrites(this); +} } void Rewriter::clearCachesInternal() { diff --git a/src/theory/theory_rewriter.h b/src/theory/theory_rewriter.h index e7dc782bb..311ab9020 100644 --- a/src/theory/theory_rewriter.h +++ b/src/theory/theory_rewriter.h @@ -24,6 +24,8 @@ namespace CVC4 { namespace theory { +class Rewriter; + /** * Theory rewriters signal whether more rewriting is needed (or not) * by using a member of this enumeration. See RewriteResponse, below. @@ -63,6 +65,13 @@ class TheoryRewriter public: virtual ~TheoryRewriter() = default; + /** + * Registers the rewrites of a given theory with the rewriter. + * + * @param rewriter The rewriter to register the rewrites with. + */ + virtual void registerRewrites(Rewriter* rewriter) {} + /** * Performs a pre-rewrite step. *