From b6ce0f23ce0aaa0552767e8067fe58dbceee11cb Mon Sep 17 00:00:00 2001 From: Andres Noetzli Date: Mon, 9 Dec 2019 11:19:10 -0800 Subject: [PATCH] Make theory rewriters non-static (#3547) This commit changes theory rewriters to be non-static. This refactoring is needed as a stepping stone to making our rewriter configurable: If we have multiple solver objects with different rewrite configurations, we cannot use `static` variables for the rewriter table in the BV rewriter for example. It is also in line with our goal of getting rid of singletons in general. Note that the `Rewriter` class is still a singleton, which will be changed in a future commit. --- src/CMakeLists.txt | 1 + src/theory/arith/arith_rewriter.h | 21 +- src/theory/arrays/theory_arrays_rewriter.h | 18 +- src/theory/booleans/theory_bool_rewriter.h | 18 +- .../builtin/theory_builtin_rewriter.cpp | 3 +- src/theory/builtin/theory_builtin_rewriter.h | 21 +- src/theory/bv/theory_bv_rewriter.cpp | 18 +- src/theory/bv/theory_bv_rewriter.h | 43 +- src/theory/datatypes/datatypes_rewriter.h | 237 ++++---- src/theory/fp/theory_fp_rewriter.cpp | 549 +++++++++--------- src/theory/fp/theory_fp_rewriter.h | 35 +- src/theory/mkrewriter | 15 +- src/theory/quantifiers/quantifiers_rewriter.h | 16 +- src/theory/rewriter.cpp | 125 ++-- src/theory/rewriter.h | 99 ++-- src/theory/rewriter_tables_template.h | 23 +- src/theory/sep/theory_sep_rewriter.h | 30 +- src/theory/sets/theory_sets_rewriter.h | 28 +- src/theory/strings/regexp_elim.cpp | 1 + src/theory/strings/theory_strings_rewriter.h | 17 +- src/theory/theory_rewriter.h | 79 +++ src/theory/uf/theory_uf_rewriter.h | 22 +- 22 files changed, 711 insertions(+), 708 deletions(-) create mode 100644 src/theory/theory_rewriter.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5b6c3aeca..e62507d8e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -629,6 +629,7 @@ libcvc4_add_sources( theory/rewriter.h theory/rewriter_attributes.h theory/sep/theory_sep.cpp + theory/theory_rewriter.h theory/sep/theory_sep.h theory/sep/theory_sep_rewriter.cpp theory/sep/theory_sep_rewriter.h diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index de6b3f228..600f38d4f 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -21,24 +21,19 @@ #define CVC4__THEORY__ARITH__ARITH_REWRITER_H #include "theory/theory.h" -#include "theory/rewriter.h" +#include "theory/theory_rewriter.h" namespace CVC4 { namespace theory { namespace arith { -class ArithRewriter { -public: - - static RewriteResponse preRewrite(TNode n); - static RewriteResponse postRewrite(TNode n); - - static void init() { } - - static void shutdown() { } - -private: +class ArithRewriter : public TheoryRewriter +{ + public: + RewriteResponse preRewrite(TNode n) override; + RewriteResponse postRewrite(TNode n) override; + private: static Node makeSubtractionNode(TNode l, TNode r); static Node makeUnaryMinusNode(TNode n); @@ -70,7 +65,7 @@ private: return !isAtom(n); } -};/* class ArithRewriter */ +}; /* class ArithRewriter */ }/* CVC4::theory::arith namespace */ }/* CVC4::theory namespace */ diff --git a/src/theory/arrays/theory_arrays_rewriter.h b/src/theory/arrays/theory_arrays_rewriter.h index 5f3fcf48d..851d2ca5d 100644 --- a/src/theory/arrays/theory_arrays_rewriter.h +++ b/src/theory/arrays/theory_arrays_rewriter.h @@ -24,6 +24,7 @@ #include #include "theory/rewriter.h" +#include "theory/theory_rewriter.h" #include "theory/type_enumerator.h" namespace CVC4 { @@ -39,11 +40,13 @@ static inline Node mkEqNode(Node a, Node b) { return a.eqNode(b); } -class TheoryArraysRewriter { +class TheoryArraysRewriter : public TheoryRewriter +{ static Node normalizeConstant(TNode node) { return normalizeConstant(node, node[1].getType().getCardinality()); } -public: + + public: //this function is called by printers when using the option "--model-u-dt-enum" static Node normalizeConstant(TNode node, Cardinality indexCard) { TNode store = node[0]; @@ -229,9 +232,9 @@ public: return n; } -public: - - static RewriteResponse postRewrite(TNode node) { + public: + RewriteResponse postRewrite(TNode node) override + { Trace("arrays-postrewrite") << "Arrays::postRewrite start " << node << std::endl; switch (node.getKind()) { case kind::SELECT: { @@ -403,7 +406,8 @@ public: return RewriteResponse(REWRITE_DONE, node); } - static inline RewriteResponse preRewrite(TNode node) { + RewriteResponse preRewrite(TNode node) override + { Trace("arrays-prerewrite") << "Arrays::preRewrite start " << node << std::endl; switch (node.getKind()) { case kind::SELECT: { @@ -503,7 +507,7 @@ public: static inline void init() {} static inline void shutdown() {} -};/* class TheoryArraysRewriter */ +}; /* class TheoryArraysRewriter */ }/* CVC4::theory::arrays namespace */ }/* CVC4::theory namespace */ diff --git a/src/theory/booleans/theory_bool_rewriter.h b/src/theory/booleans/theory_bool_rewriter.h index 250f36efa..8fc65932e 100644 --- a/src/theory/booleans/theory_bool_rewriter.h +++ b/src/theory/booleans/theory_bool_rewriter.h @@ -20,23 +20,19 @@ #ifndef CVC4__THEORY__BOOLEANS__THEORY_BOOL_REWRITER_H #define CVC4__THEORY__BOOLEANS__THEORY_BOOL_REWRITER_H -#include "theory/rewriter.h" +#include "theory/theory_rewriter.h" namespace CVC4 { namespace theory { namespace booleans { -class TheoryBoolRewriter { +class TheoryBoolRewriter : public TheoryRewriter +{ + public: + RewriteResponse preRewrite(TNode node) override; + RewriteResponse postRewrite(TNode node) override; -public: - - static RewriteResponse preRewrite(TNode node); - static RewriteResponse postRewrite(TNode node); - - static void init() {} - static void shutdown() {} - -};/* class TheoryBoolRewriter */ +}; /* class TheoryBoolRewriter */ }/* CVC4::theory::booleans namespace */ }/* CVC4::theory namespace */ diff --git a/src/theory/builtin/theory_builtin_rewriter.cpp b/src/theory/builtin/theory_builtin_rewriter.cpp index d483bf994..d9fe2fecc 100644 --- a/src/theory/builtin/theory_builtin_rewriter.cpp +++ b/src/theory/builtin/theory_builtin_rewriter.cpp @@ -15,11 +15,12 @@ ** \todo document this file **/ -#include "expr/attribute.h" #include "theory/builtin/theory_builtin_rewriter.h" +#include "expr/attribute.h" #include "expr/chain.h" #include "expr/node_algorithm.h" +#include "theory/rewriter.h" using namespace std; diff --git a/src/theory/builtin/theory_builtin_rewriter.h b/src/theory/builtin/theory_builtin_rewriter.h index 5f703fa00..05b1b643c 100644 --- a/src/theory/builtin/theory_builtin_rewriter.h +++ b/src/theory/builtin/theory_builtin_rewriter.h @@ -20,15 +20,15 @@ #ifndef CVC4__THEORY__BUILTIN__THEORY_BUILTIN_REWRITER_H #define CVC4__THEORY__BUILTIN__THEORY_BUILTIN_REWRITER_H -#include "theory/rewriter.h" #include "theory/theory.h" +#include "theory/theory_rewriter.h" namespace CVC4 { namespace theory { namespace builtin { -class TheoryBuiltinRewriter { - +class TheoryBuiltinRewriter : public TheoryRewriter +{ static Node blastDistinct(TNode node); static Node blastChain(TNode node); @@ -45,17 +45,12 @@ public: } } - static RewriteResponse postRewrite(TNode node); - - static inline RewriteResponse preRewrite(TNode node) { - return doRewrite(node); - } + RewriteResponse postRewrite(TNode node) override; - static inline void init() {} - static inline void shutdown() {} + RewriteResponse preRewrite(TNode node) override { return doRewrite(node); } -// conversions between lambdas and arrays -private: + // conversions between lambdas and arrays + private: /** recursive helper for getLambdaForArrayRepresentation */ static Node getLambdaForArrayRepresentationRec( TNode a, TNode bvl, unsigned bvlIndex, std::unordered_map< TNode, Node, TNodeHashFunction >& visited ); @@ -124,7 +119,7 @@ private: * to n, this method returns null. */ static Node getArrayRepresentationForLambda(TNode n); -};/* class TheoryBuiltinRewriter */ +}; /* class TheoryBuiltinRewriter */ }/* CVC4::theory::builtin namespace */ }/* CVC4::theory namespace */ diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp index 6b160ea67..282ae2e27 100644 --- a/src/theory/bv/theory_bv_rewriter.cpp +++ b/src/theory/bv/theory_bv_rewriter.cpp @@ -29,23 +29,7 @@ using namespace CVC4; using namespace CVC4::theory; using namespace CVC4::theory::bv; - -// thread_local AllRewriteRules* TheoryBVRewriter::s_allRules = NULL; -// thread_local TimerStat* TheoryBVRewriter::d_rewriteTimer = NULL; -RewriteFunction TheoryBVRewriter::d_rewriteTable[kind::LAST_KIND]; -void TheoryBVRewriter::init() { - // s_allRules = new AllRewriteRules; - // d_rewriteTimer = new TimerStat("theory::bv::rewriteTimer"); - // smtStatisticsRegistry()->registerStat(d_rewriteTimer); - initializeRewrites(); - -} - -void TheoryBVRewriter::shutdown() { - // delete s_allRules; - // smtStatisticsRegistry()->unregisterStat(d_rewriteTimer); - //delete d_rewriteTimer; -} +TheoryBVRewriter::TheoryBVRewriter() { initializeRewrites(); } RewriteResponse TheoryBVRewriter::preRewrite(TNode node) { RewriteResponse res = d_rewriteTable[node.getKind()](node, true); diff --git a/src/theory/bv/theory_bv_rewriter.h b/src/theory/bv/theory_bv_rewriter.h index bf707c268..8c8b7846c 100644 --- a/src/theory/bv/theory_bv_rewriter.h +++ b/src/theory/bv/theory_bv_rewriter.h @@ -20,7 +20,7 @@ #ifndef CVC4__THEORY__BV__THEORY_BV_REWRITER_H #define CVC4__THEORY__BV__THEORY_BV_REWRITER_H -#include "theory/rewriter.h" +#include "theory/theory_rewriter.h" #include "util/statistics_registry.h" namespace CVC4 { @@ -30,14 +30,27 @@ namespace bv { struct AllRewriteRules; typedef RewriteResponse (*RewriteFunction) (TNode, bool); -class TheoryBVRewriter { +class TheoryBVRewriter : public TheoryRewriter +{ + public: + /** + * Temporary hack for devision-by-zero until we refactor theory code from + * smt engine. + * + * @param node + * + * @return + */ + static Node eliminateBVSDiv(TNode node); - static RewriteFunction d_rewriteTable[kind::LAST_KIND]; + TheoryBVRewriter(); + RewriteResponse postRewrite(TNode node) override; + RewriteResponse preRewrite(TNode node) override; + + private: static RewriteResponse IdentityRewrite(TNode node, bool prerewrite = false); static RewriteResponse UndefinedRewrite(TNode node, bool prerewrite = false); - - static void initializeRewrites(); static RewriteResponse RewriteEqual(TNode node, bool prerewrite = false); static RewriteResponse RewriteUlt(TNode node, bool prerewrite = false); @@ -86,24 +99,10 @@ class TheoryBVRewriter { static RewriteResponse RewriteBVToNat(TNode node, bool prerewrite = false); static RewriteResponse RewriteIntToBV(TNode node, bool prerewrite = false); -public: + void initializeRewrites(); - static RewriteResponse postRewrite(TNode node); - - static RewriteResponse preRewrite(TNode node); - - static void init(); - static void shutdown(); - /** - * Temporary hack for devision-by-zero until we refactor theory code from - * smt engine. - * - * @param node - * - * @return - */ - static Node eliminateBVSDiv(TNode node); -};/* class TheoryBVRewriter */ + RewriteFunction d_rewriteTable[kind::LAST_KIND]; +}; /* class TheoryBVRewriter */ }/* CVC4::theory::bv namespace */ }/* CVC4::theory namespace */ diff --git a/src/theory/datatypes/datatypes_rewriter.h b/src/theory/datatypes/datatypes_rewriter.h index d2fdd8f4d..8f24a5aef 100644 --- a/src/theory/datatypes/datatypes_rewriter.h +++ b/src/theory/datatypes/datatypes_rewriter.h @@ -20,138 +20,129 @@ #define CVC4__THEORY__DATATYPES__DATATYPES_REWRITER_H #include "expr/node_manager_attributes.h" -#include "theory/rewriter.h" +#include "theory/theory_rewriter.h" #include "theory/type_enumerator.h" namespace CVC4 { namespace theory { namespace datatypes { -class DatatypesRewriter { -public: - static RewriteResponse postRewrite(TNode in); +class DatatypesRewriter : public TheoryRewriter +{ + public: + RewriteResponse postRewrite(TNode in) override; + RewriteResponse preRewrite(TNode in) override; - static RewriteResponse preRewrite(TNode in); + /** normalize codatatype constant + * + * This returns the normal form of the codatatype constant n. This runs a + * DFA minimization algorithm based on the private functions below. + * + * In particular, we first call collectRefs to setup initial information + * about what terms occur in n. Then, we run a DFA minimization algorithm to + * partition these subterms in equivalence classes. Finally, we call + * normalizeCodatatypeConstantEqc to construct the normalized codatatype + * constant that is equivalent to n. + */ + static Node normalizeCodatatypeConstant(Node n); + /** normalize constant + * + * This method returns the normal form of n, which calls the above function + * on all top-level codatatype subterms of n. + */ + static Node normalizeConstant(Node n); - static inline void init() {} - static inline void shutdown() {} + private: + /** rewrite constructor term in */ + static RewriteResponse rewriteConstructor(TNode in); + /** rewrite selector term in */ + static RewriteResponse rewriteSelector(TNode in); + /** rewrite tester term in */ + static RewriteResponse rewriteTester(TNode in); - /** normalize codatatype constant - * - * This returns the normal form of the codatatype constant n. This runs a - * DFA minimization algorithm based on the private functions below. - * - * In particular, we first call collectRefs to setup initial information - * about what terms occur in n. Then, we run a DFA minimization algorithm to - * partition these subterms in equivalence classes. Finally, we call - * normalizeCodatatypeConstantEqc to construct the normalized codatatype - * constant that is equivalent to n. - */ - static Node normalizeCodatatypeConstant(Node n); - /** normalize constant - * - * This method returns the normal form of n, which calls the above function - * on all top-level codatatype subterms of n. - */ - static Node normalizeConstant(Node n); - -private: - /** rewrite constructor term in */ - static RewriteResponse rewriteConstructor(TNode in); - /** rewrite selector term in */ - static RewriteResponse rewriteSelector(TNode in); - /** rewrite tester term in */ - static RewriteResponse rewriteTester(TNode in); - - /** collect references - * - * This function, given as input a codatatype term n, collects the necessary - * information for constructing a (canonical) codatatype constant that is - * equivalent to n if one exists, or null otherwise. - * - * In particular it returns a term ret such that all non-codatatype datatype - * subterms of n are replaced by a constant that is equal to them via a - * (mutually) recursive call to normalizeConstant above. Additionally, this - * function replaces references to mu-binders with fresh variables. - * In detail, mu-terms are represented by uninterpreted constants of datatype - * type that carry their Debruijn index. - * - * Consider the example of a codatatype representing a stream of integers: - * Stream := cons( head : Int, tail : Stream ) - * The stream 1,0,1,0,1,0... when written in mu-notation is the term: - * mu x. cons( 1, mu y. cons( 0, x ) ) - * This is represented in CVC4 by the Node: - * cons( 1, cons( 0, c[1] ) ) - * where c[1] is a uninterpreted constant datatype with Debruijn index 1, - * indicating that c[1] is nested underneath 1 level on the path to the - * term which it binds. On the other hand, the stream 1,0,0,0,0,... is - * represented by the codatatype term: - * cons( 1, cons( 0, c[0] ) ) - * - * Subterms that are references to mu-binders in n are replaced by a new - * variable. If n contains any subterm that is a reference to a mu-binder not - * bound in n, then we return null. For example we return null when n is: - * cons( 1, cons( 0, c[2] ) ) - * since c[2] is not bound by this codatatype term. - * - * All valid references to mu-binders are replaced by a variable that is unique - * for the term it references. For example, for the infinite tree codatatype: - * Tree : node( data : Int, left : Tree, right : Tree ) - * If n is the term: - * node( 0, c[0], node( 1, c[0], c[1] ) ) - * then the return value ret of this function is: - * node( 0, x, node( 1, y, x ) ) - * where x refers to the root of the term and y refers to the right tree of the - * root. - * - * The argument sk stores the current set of node that we are traversing - * beneath. The argument rf_pending stores, for each node that we are - * traversing beneath either null or the free variable that we are using to - * refer to its mu-binder. The remaining arguments store information that is - * relevant when performing normalization of n using the value of ret: - * - * rf : maps subterms of n to the corresponding term in ret for all subterms - * where the corresponding term in ret is different. - * terms : stores all subterms of ret. - * cdts : for each term t in terms, stores whether t is a codatatype. - */ - static Node collectRef(Node n, - std::vector& sk, - std::map& rf, - std::vector& rf_pending, - std::vector& terms, - std::map& cdts); - /** normalize codatatype constant eqc - * - * This recursive function returns a codatatype constant that is equivalent to - * n based on a pre-computed partition of the subterms of n into equivalence - * classes, as stored in the mapping eqc, which maps the subterms of n to - * equivalence class ids. The arguments eqc_stack and depth store information - * about the traversal in a term we have recursed, where - * - * eqc_stack : maps the depth of each term we have traversed to its equivalence - * class id. - * depth : the number of levels which we have traversed. - */ - static Node normalizeCodatatypeConstantEqc(Node n, - std::map& eqc_stack, - std::map& eqc, - int depth); - /** replace debruijn - * - * This function, given codatatype term n, returns a node - * where all subterms of n that have Debruijn indices that refer to a - * term of input depth are replaced by orig. For example, for the infinite Tree - * datatype, - * replaceDebruijn( node( 0, c[0], node( 1, c[0], c[1] ) ), t, Tree, 0 ) - * returns - * node( 0, t, node( 1, c[0], t ) ). - */ - static Node replaceDebruijn(Node n, - Node orig, - TypeNode orig_tn, - unsigned depth); -};/* class DatatypesRewriter */ + /** collect references + * + * This function, given as input a codatatype term n, collects the necessary + * information for constructing a (canonical) codatatype constant that is + * equivalent to n if one exists, or null otherwise. + * + * In particular it returns a term ret such that all non-codatatype datatype + * subterms of n are replaced by a constant that is equal to them via a + * (mutually) recursive call to normalizeConstant above. Additionally, this + * function replaces references to mu-binders with fresh variables. + * In detail, mu-terms are represented by uninterpreted constants of datatype + * type that carry their Debruijn index. + * + * Consider the example of a codatatype representing a stream of integers: + * Stream := cons( head : Int, tail : Stream ) + * The stream 1,0,1,0,1,0... when written in mu-notation is the term: + * mu x. cons( 1, mu y. cons( 0, x ) ) + * This is represented in CVC4 by the Node: + * cons( 1, cons( 0, c[1] ) ) + * where c[1] is a uninterpreted constant datatype with Debruijn index 1, + * indicating that c[1] is nested underneath 1 level on the path to the + * term which it binds. On the other hand, the stream 1,0,0,0,0,... is + * represented by the codatatype term: + * cons( 1, cons( 0, c[0] ) ) + * + * Subterms that are references to mu-binders in n are replaced by a new + * variable. If n contains any subterm that is a reference to a mu-binder not + * bound in n, then we return null. For example we return null when n is: + * cons( 1, cons( 0, c[2] ) ) + * since c[2] is not bound by this codatatype term. + * + * All valid references to mu-binders are replaced by a variable that is + * unique for the term it references. For example, for the infinite tree + * codatatype: Tree : node( data : Int, left : Tree, right : Tree ) If n is + * the term: node( 0, c[0], node( 1, c[0], c[1] ) ) then the return value ret + * of this function is: node( 0, x, node( 1, y, x ) ) where x refers to the + * root of the term and y refers to the right tree of the root. + * + * The argument sk stores the current set of node that we are traversing + * beneath. The argument rf_pending stores, for each node that we are + * traversing beneath either null or the free variable that we are using to + * refer to its mu-binder. The remaining arguments store information that is + * relevant when performing normalization of n using the value of ret: + * + * rf : maps subterms of n to the corresponding term in ret for all subterms + * where the corresponding term in ret is different. + * terms : stores all subterms of ret. + * cdts : for each term t in terms, stores whether t is a codatatype. + */ + static Node collectRef(Node n, + std::vector& sk, + std::map& rf, + std::vector& rf_pending, + std::vector& terms, + std::map& cdts); + /** normalize codatatype constant eqc + * + * This recursive function returns a codatatype constant that is equivalent to + * n based on a pre-computed partition of the subterms of n into equivalence + * classes, as stored in the mapping eqc, which maps the subterms of n to + * equivalence class ids. The arguments eqc_stack and depth store information + * about the traversal in a term we have recursed, where + * + * eqc_stack : maps the depth of each term we have traversed to its + * equivalence class id. depth : the number of levels which we have traversed. + */ + static Node normalizeCodatatypeConstantEqc(Node n, + std::map& eqc_stack, + std::map& eqc, + int depth); + /** replace debruijn + * + * This function, given codatatype term n, returns a node + * where all subterms of n that have Debruijn indices that refer to a + * term of input depth are replaced by orig. For example, for the infinite + * Tree datatype, replaceDebruijn( node( 0, c[0], node( 1, c[0], c[1] ) ), t, + * Tree, 0 ) returns node( 0, t, node( 1, c[0], t ) ). + */ + static Node replaceDebruijn(Node n, + Node orig, + TypeNode orig_tn, + unsigned depth); +}; /* class DatatypesRewriter */ }/* CVC4::theory::datatypes namespace */ }/* CVC4::theory namespace */ diff --git a/src/theory/fp/theory_fp_rewriter.cpp b/src/theory/fp/theory_fp_rewriter.cpp index 7f5ce4afb..5d1e3d4da 100644 --- a/src/theory/fp/theory_fp_rewriter.cpp +++ b/src/theory/fp/theory_fp_rewriter.cpp @@ -978,269 +978,287 @@ namespace constantFold { }; /* CVC4::theory::fp::constantFold */ -RewriteFunction TheoryFpRewriter::preRewriteTable[kind::LAST_KIND]; -RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND]; -RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND]; - - /** * Initialize the rewriter. */ - void TheoryFpRewriter::init() { - - /* Set up the pre-rewrite dispatch table */ - for (unsigned i = 0; i < kind::LAST_KIND; ++i) { - preRewriteTable[i] = rewrite::notFP; - } - - /******** Constants ********/ - /* No rewriting possible for constants */ - preRewriteTable[kind::CONST_FLOATINGPOINT] = rewrite::identity; - preRewriteTable[kind::CONST_ROUNDINGMODE] = rewrite::identity; - - /******** Sorts(?) ********/ - /* These kinds should only appear in types */ - //preRewriteTable[kind::ROUNDINGMODE_TYPE] = rewrite::type; - preRewriteTable[kind::FLOATINGPOINT_TYPE] = rewrite::type; - - /******** Operations ********/ - preRewriteTable[kind::FLOATINGPOINT_FP] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_ABS] = rewrite::compactAbs; - preRewriteTable[kind::FLOATINGPOINT_NEG] = rewrite::removeDoubleNegation; - preRewriteTable[kind::FLOATINGPOINT_PLUS] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_SUB] = rewrite::convertSubtractionToAddition; - preRewriteTable[kind::FLOATINGPOINT_MULT] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_DIV] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_FMA] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_SQRT] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_REM] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_RTI] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_MIN] = rewrite::compactMinMax; - preRewriteTable[kind::FLOATINGPOINT_MAX] = rewrite::compactMinMax; - preRewriteTable[kind::FLOATINGPOINT_MIN_TOTAL] = rewrite::compactMinMax; - preRewriteTable[kind::FLOATINGPOINT_MAX_TOTAL] = rewrite::compactMinMax; - - /******** Comparisons ********/ - preRewriteTable[kind::FLOATINGPOINT_EQ] = rewrite::then; - preRewriteTable[kind::FLOATINGPOINT_LEQ] = - rewrite::then; - preRewriteTable[kind::FLOATINGPOINT_LT] = - rewrite::then; - preRewriteTable[kind::FLOATINGPOINT_GEQ] = rewrite::then; - preRewriteTable[kind::FLOATINGPOINT_GT] = rewrite::then; - - /******** Classifications ********/ - preRewriteTable[kind::FLOATINGPOINT_ISN] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_ISSN] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_ISZ] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_ISINF] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_ISNAN] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_ISNEG] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_ISPOS] = rewrite::identity; - - /******** Conversions ********/ - preRewriteTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::removed; - preRewriteTable[kind::FLOATINGPOINT_TO_UBV] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_TO_SBV] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_TO_REAL] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_TO_UBV_TOTAL] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_TO_SBV_TOTAL] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_TO_REAL_TOTAL] = rewrite::identity; - - /******** Variables ********/ - preRewriteTable[kind::VARIABLE] = rewrite::variable; - preRewriteTable[kind::BOUND_VARIABLE] = rewrite::variable; - preRewriteTable[kind::SKOLEM] = rewrite::variable; - preRewriteTable[kind::INST_CONSTANT] = rewrite::variable; - - preRewriteTable[kind::EQUAL] = rewrite::equal; - - - /******** Components for bit-blasting ********/ - preRewriteTable[kind::FLOATINGPOINT_COMPONENT_NAN] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_COMPONENT_INF] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_COMPONENT_ZERO] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_COMPONENT_SIGN] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_COMPONENT_EXPONENT] = rewrite::identity; - preRewriteTable[kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND] = rewrite::identity; - preRewriteTable[kind::ROUNDINGMODE_BITBLAST] = rewrite::identity; - - - - - /* Set up the post-rewrite dispatch table */ - for (unsigned i = 0; i < kind::LAST_KIND; ++i) { - postRewriteTable[i] = rewrite::notFP; - } - - /******** Constants ********/ - /* No rewriting possible for constants */ - postRewriteTable[kind::CONST_FLOATINGPOINT] = rewrite::identity; - postRewriteTable[kind::CONST_ROUNDINGMODE] = rewrite::identity; - - /******** Sorts(?) ********/ - /* These kinds should only appear in types */ - //postRewriteTable[kind::ROUNDINGMODE_TYPE] = rewrite::type; - postRewriteTable[kind::FLOATINGPOINT_TYPE] = rewrite::type; - - /******** Operations ********/ - postRewriteTable[kind::FLOATINGPOINT_FP] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_ABS] = rewrite::compactAbs; - postRewriteTable[kind::FLOATINGPOINT_NEG] = rewrite::removeDoubleNegation; - postRewriteTable[kind::FLOATINGPOINT_PLUS] = rewrite::reorderBinaryOperation; - postRewriteTable[kind::FLOATINGPOINT_SUB] = rewrite::removed; - postRewriteTable[kind::FLOATINGPOINT_MULT] = rewrite::reorderBinaryOperation; - postRewriteTable[kind::FLOATINGPOINT_DIV] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_FMA] = rewrite::reorderFMA; - postRewriteTable[kind::FLOATINGPOINT_SQRT] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_REM] = rewrite::compactRemainder; - postRewriteTable[kind::FLOATINGPOINT_RTI] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_MIN] = rewrite::compactMinMax; - postRewriteTable[kind::FLOATINGPOINT_MAX] = rewrite::compactMinMax; - postRewriteTable[kind::FLOATINGPOINT_MIN_TOTAL] = rewrite::compactMinMax; - postRewriteTable[kind::FLOATINGPOINT_MAX_TOTAL] = rewrite::compactMinMax; - - /******** Comparisons ********/ - postRewriteTable[kind::FLOATINGPOINT_EQ] = rewrite::removed; - postRewriteTable[kind::FLOATINGPOINT_LEQ] = rewrite::leqId; - postRewriteTable[kind::FLOATINGPOINT_LT] = rewrite::ltId; - postRewriteTable[kind::FLOATINGPOINT_GEQ] = rewrite::removed; - postRewriteTable[kind::FLOATINGPOINT_GT] = rewrite::removed; - - /******** Classifications ********/ - postRewriteTable[kind::FLOATINGPOINT_ISN] = rewrite::removeSignOperations; - postRewriteTable[kind::FLOATINGPOINT_ISSN] = rewrite::removeSignOperations; - postRewriteTable[kind::FLOATINGPOINT_ISZ] = rewrite::removeSignOperations; - postRewriteTable[kind::FLOATINGPOINT_ISINF] = rewrite::removeSignOperations; - postRewriteTable[kind::FLOATINGPOINT_ISNAN] = rewrite::removeSignOperations; - postRewriteTable[kind::FLOATINGPOINT_ISNEG] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_ISPOS] = rewrite::identity; - - /******** Conversions ********/ - postRewriteTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::removed; - postRewriteTable[kind::FLOATINGPOINT_TO_UBV] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_TO_SBV] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_TO_REAL] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_TO_UBV_TOTAL] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_TO_SBV_TOTAL] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_TO_REAL_TOTAL] = rewrite::identity; - - /******** Variables ********/ - postRewriteTable[kind::VARIABLE] = rewrite::variable; - postRewriteTable[kind::BOUND_VARIABLE] = rewrite::variable; - postRewriteTable[kind::SKOLEM] = rewrite::variable; - postRewriteTable[kind::INST_CONSTANT] = rewrite::variable; - - postRewriteTable[kind::EQUAL] = rewrite::equal; - - - /******** Components for bit-blasting ********/ - postRewriteTable[kind::FLOATINGPOINT_COMPONENT_NAN] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_COMPONENT_INF] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_COMPONENT_ZERO] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_COMPONENT_SIGN] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_COMPONENT_EXPONENT] = rewrite::identity; - postRewriteTable[kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND] = rewrite::identity; - postRewriteTable[kind::ROUNDINGMODE_BITBLAST] = rewrite::identity; - - - - /* Set up the post-rewrite constant fold table */ - for (unsigned i = 0; i < kind::LAST_KIND; ++i) { - // Note that this is identity, not notFP - // Constant folding is called after post-rewrite - // So may have to deal with cases of things being - // re-written to non-floating-point sorts (i.e. true). - constantFoldTable[i] = rewrite::identity; - } - - /******** Constants ********/ - /* Already folded! */ - constantFoldTable[kind::CONST_FLOATINGPOINT] = rewrite::identity; - constantFoldTable[kind::CONST_ROUNDINGMODE] = rewrite::identity; - - /******** Sorts(?) ********/ - /* These kinds should only appear in types */ - constantFoldTable[kind::FLOATINGPOINT_TYPE] = rewrite::type; - - /******** Operations ********/ - constantFoldTable[kind::FLOATINGPOINT_FP] = constantFold::fpLiteral; - constantFoldTable[kind::FLOATINGPOINT_ABS] = constantFold::abs; - constantFoldTable[kind::FLOATINGPOINT_NEG] = constantFold::neg; - constantFoldTable[kind::FLOATINGPOINT_PLUS] = constantFold::plus; - constantFoldTable[kind::FLOATINGPOINT_SUB] = rewrite::removed; - constantFoldTable[kind::FLOATINGPOINT_MULT] = constantFold::mult; - constantFoldTable[kind::FLOATINGPOINT_DIV] = constantFold::div; - constantFoldTable[kind::FLOATINGPOINT_FMA] = constantFold::fma; - constantFoldTable[kind::FLOATINGPOINT_SQRT] = constantFold::sqrt; - constantFoldTable[kind::FLOATINGPOINT_REM] = constantFold::rem; - constantFoldTable[kind::FLOATINGPOINT_RTI] = constantFold::rti; - constantFoldTable[kind::FLOATINGPOINT_MIN] = constantFold::min; - constantFoldTable[kind::FLOATINGPOINT_MAX] = constantFold::max; - constantFoldTable[kind::FLOATINGPOINT_MIN_TOTAL] = constantFold::minTotal; - constantFoldTable[kind::FLOATINGPOINT_MAX_TOTAL] = constantFold::maxTotal; - - /******** Comparisons ********/ - constantFoldTable[kind::FLOATINGPOINT_EQ] = rewrite::removed; - constantFoldTable[kind::FLOATINGPOINT_LEQ] = constantFold::leq; - constantFoldTable[kind::FLOATINGPOINT_LT] = constantFold::lt; - constantFoldTable[kind::FLOATINGPOINT_GEQ] = rewrite::removed; - constantFoldTable[kind::FLOATINGPOINT_GT] = rewrite::removed; - - /******** Classifications ********/ - constantFoldTable[kind::FLOATINGPOINT_ISN] = constantFold::isNormal; - constantFoldTable[kind::FLOATINGPOINT_ISSN] = constantFold::isSubnormal; - constantFoldTable[kind::FLOATINGPOINT_ISZ] = constantFold::isZero; - constantFoldTable[kind::FLOATINGPOINT_ISINF] = constantFold::isInfinite; - constantFoldTable[kind::FLOATINGPOINT_ISNAN] = constantFold::isNaN; - constantFoldTable[kind::FLOATINGPOINT_ISNEG] = constantFold::isNegative; - constantFoldTable[kind::FLOATINGPOINT_ISPOS] = constantFold::isPositive; - - /******** Conversions ********/ - constantFoldTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = constantFold::convertFromIEEEBitVectorLiteral; - constantFoldTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] = constantFold::constantConvert; - constantFoldTable[kind::FLOATINGPOINT_TO_FP_REAL] = constantFold::convertFromRealLiteral; - constantFoldTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] = constantFold::convertFromSBV; - constantFoldTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] = constantFold::convertFromUBV; - constantFoldTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::removed; - constantFoldTable[kind::FLOATINGPOINT_TO_UBV] = constantFold::convertToUBV; - constantFoldTable[kind::FLOATINGPOINT_TO_SBV] = constantFold::convertToSBV; - constantFoldTable[kind::FLOATINGPOINT_TO_REAL] = constantFold::convertToReal; - constantFoldTable[kind::FLOATINGPOINT_TO_UBV_TOTAL] = constantFold::convertToUBVTotal; - constantFoldTable[kind::FLOATINGPOINT_TO_SBV_TOTAL] = constantFold::convertToSBVTotal; - constantFoldTable[kind::FLOATINGPOINT_TO_REAL_TOTAL] = constantFold::convertToRealTotal; - - /******** Variables ********/ - constantFoldTable[kind::VARIABLE] = rewrite::variable; - constantFoldTable[kind::BOUND_VARIABLE] = rewrite::variable; - - constantFoldTable[kind::EQUAL] = constantFold::equal; - - - /******** Components for bit-blasting ********/ - constantFoldTable[kind::FLOATINGPOINT_COMPONENT_NAN] = constantFold::componentFlag; - constantFoldTable[kind::FLOATINGPOINT_COMPONENT_INF] = constantFold::componentFlag; - constantFoldTable[kind::FLOATINGPOINT_COMPONENT_ZERO] = constantFold::componentFlag; - constantFoldTable[kind::FLOATINGPOINT_COMPONENT_SIGN] = constantFold::componentFlag; - constantFoldTable[kind::FLOATINGPOINT_COMPONENT_EXPONENT] = constantFold::componentExponent; - constantFoldTable[kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND] = constantFold::componentSignificand; - constantFoldTable[kind::ROUNDINGMODE_BITBLAST] = constantFold::roundingModeBitBlast; - - - } - - - +TheoryFpRewriter::TheoryFpRewriter() +{ + /* Set up the pre-rewrite dispatch table */ + for (unsigned i = 0; i < kind::LAST_KIND; ++i) + { + d_preRewriteTable[i] = rewrite::notFP; + } + + /******** Constants ********/ + /* No rewriting possible for constants */ + d_preRewriteTable[kind::CONST_FLOATINGPOINT] = rewrite::identity; + d_preRewriteTable[kind::CONST_ROUNDINGMODE] = rewrite::identity; + + /******** Sorts(?) ********/ + /* These kinds should only appear in types */ + // d_preRewriteTable[kind::ROUNDINGMODE_TYPE] = rewrite::type; + d_preRewriteTable[kind::FLOATINGPOINT_TYPE] = rewrite::type; + + /******** Operations ********/ + d_preRewriteTable[kind::FLOATINGPOINT_FP] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_ABS] = rewrite::compactAbs; + d_preRewriteTable[kind::FLOATINGPOINT_NEG] = rewrite::removeDoubleNegation; + d_preRewriteTable[kind::FLOATINGPOINT_PLUS] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_SUB] = + rewrite::convertSubtractionToAddition; + d_preRewriteTable[kind::FLOATINGPOINT_MULT] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_DIV] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_FMA] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_SQRT] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_REM] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_RTI] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_MIN] = rewrite::compactMinMax; + d_preRewriteTable[kind::FLOATINGPOINT_MAX] = rewrite::compactMinMax; + d_preRewriteTable[kind::FLOATINGPOINT_MIN_TOTAL] = rewrite::compactMinMax; + d_preRewriteTable[kind::FLOATINGPOINT_MAX_TOTAL] = rewrite::compactMinMax; + + /******** Comparisons ********/ + d_preRewriteTable[kind::FLOATINGPOINT_EQ] = + rewrite::then; + d_preRewriteTable[kind::FLOATINGPOINT_LEQ] = + rewrite::then; + d_preRewriteTable[kind::FLOATINGPOINT_LT] = + rewrite::then; + d_preRewriteTable[kind::FLOATINGPOINT_GEQ] = + rewrite::then; + d_preRewriteTable[kind::FLOATINGPOINT_GT] = + rewrite::then; + + /******** Classifications ********/ + d_preRewriteTable[kind::FLOATINGPOINT_ISN] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_ISSN] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_ISZ] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_ISINF] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_ISNAN] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_ISNEG] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_ISPOS] = rewrite::identity; + + /******** Conversions ********/ + d_preRewriteTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = + rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] = + rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] = + rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] = + rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::removed; + d_preRewriteTable[kind::FLOATINGPOINT_TO_UBV] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_TO_SBV] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_TO_REAL] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_TO_UBV_TOTAL] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_TO_SBV_TOTAL] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_TO_REAL_TOTAL] = rewrite::identity; + + /******** Variables ********/ + d_preRewriteTable[kind::VARIABLE] = rewrite::variable; + d_preRewriteTable[kind::BOUND_VARIABLE] = rewrite::variable; + d_preRewriteTable[kind::SKOLEM] = rewrite::variable; + d_preRewriteTable[kind::INST_CONSTANT] = rewrite::variable; + + d_preRewriteTable[kind::EQUAL] = rewrite::equal; + + /******** Components for bit-blasting ********/ + d_preRewriteTable[kind::FLOATINGPOINT_COMPONENT_NAN] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_COMPONENT_INF] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_COMPONENT_ZERO] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_COMPONENT_SIGN] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_COMPONENT_EXPONENT] = rewrite::identity; + d_preRewriteTable[kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND] = + rewrite::identity; + d_preRewriteTable[kind::ROUNDINGMODE_BITBLAST] = rewrite::identity; + + /* Set up the post-rewrite dispatch table */ + for (unsigned i = 0; i < kind::LAST_KIND; ++i) + { + d_postRewriteTable[i] = rewrite::notFP; + } + + /******** Constants ********/ + /* No rewriting possible for constants */ + d_postRewriteTable[kind::CONST_FLOATINGPOINT] = rewrite::identity; + d_postRewriteTable[kind::CONST_ROUNDINGMODE] = rewrite::identity; + + /******** Sorts(?) ********/ + /* These kinds should only appear in types */ + // d_postRewriteTable[kind::ROUNDINGMODE_TYPE] = rewrite::type; + d_postRewriteTable[kind::FLOATINGPOINT_TYPE] = rewrite::type; + + /******** Operations ********/ + d_postRewriteTable[kind::FLOATINGPOINT_FP] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_ABS] = rewrite::compactAbs; + d_postRewriteTable[kind::FLOATINGPOINT_NEG] = rewrite::removeDoubleNegation; + d_postRewriteTable[kind::FLOATINGPOINT_PLUS] = + rewrite::reorderBinaryOperation; + d_postRewriteTable[kind::FLOATINGPOINT_SUB] = rewrite::removed; + d_postRewriteTable[kind::FLOATINGPOINT_MULT] = + rewrite::reorderBinaryOperation; + d_postRewriteTable[kind::FLOATINGPOINT_DIV] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_FMA] = rewrite::reorderFMA; + d_postRewriteTable[kind::FLOATINGPOINT_SQRT] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_REM] = rewrite::compactRemainder; + d_postRewriteTable[kind::FLOATINGPOINT_RTI] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_MIN] = rewrite::compactMinMax; + d_postRewriteTable[kind::FLOATINGPOINT_MAX] = rewrite::compactMinMax; + d_postRewriteTable[kind::FLOATINGPOINT_MIN_TOTAL] = rewrite::compactMinMax; + d_postRewriteTable[kind::FLOATINGPOINT_MAX_TOTAL] = rewrite::compactMinMax; + + /******** Comparisons ********/ + d_postRewriteTable[kind::FLOATINGPOINT_EQ] = rewrite::removed; + d_postRewriteTable[kind::FLOATINGPOINT_LEQ] = rewrite::leqId; + d_postRewriteTable[kind::FLOATINGPOINT_LT] = rewrite::ltId; + d_postRewriteTable[kind::FLOATINGPOINT_GEQ] = rewrite::removed; + d_postRewriteTable[kind::FLOATINGPOINT_GT] = rewrite::removed; + + /******** Classifications ********/ + d_postRewriteTable[kind::FLOATINGPOINT_ISN] = rewrite::removeSignOperations; + d_postRewriteTable[kind::FLOATINGPOINT_ISSN] = rewrite::removeSignOperations; + d_postRewriteTable[kind::FLOATINGPOINT_ISZ] = rewrite::removeSignOperations; + d_postRewriteTable[kind::FLOATINGPOINT_ISINF] = rewrite::removeSignOperations; + d_postRewriteTable[kind::FLOATINGPOINT_ISNAN] = rewrite::removeSignOperations; + d_postRewriteTable[kind::FLOATINGPOINT_ISNEG] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_ISPOS] = rewrite::identity; + + /******** Conversions ********/ + d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = + rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] = + rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] = + rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] = + rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::removed; + d_postRewriteTable[kind::FLOATINGPOINT_TO_UBV] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_TO_SBV] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_TO_REAL] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_TO_UBV_TOTAL] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_TO_SBV_TOTAL] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_TO_REAL_TOTAL] = rewrite::identity; + + /******** Variables ********/ + d_postRewriteTable[kind::VARIABLE] = rewrite::variable; + d_postRewriteTable[kind::BOUND_VARIABLE] = rewrite::variable; + d_postRewriteTable[kind::SKOLEM] = rewrite::variable; + d_postRewriteTable[kind::INST_CONSTANT] = rewrite::variable; + + d_postRewriteTable[kind::EQUAL] = rewrite::equal; + + /******** Components for bit-blasting ********/ + d_postRewriteTable[kind::FLOATINGPOINT_COMPONENT_NAN] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_COMPONENT_INF] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_COMPONENT_ZERO] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_COMPONENT_SIGN] = rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_COMPONENT_EXPONENT] = + rewrite::identity; + d_postRewriteTable[kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND] = + rewrite::identity; + d_postRewriteTable[kind::ROUNDINGMODE_BITBLAST] = rewrite::identity; + + /* Set up the post-rewrite constant fold table */ + for (unsigned i = 0; i < kind::LAST_KIND; ++i) + { + // Note that this is identity, not notFP + // Constant folding is called after post-rewrite + // So may have to deal with cases of things being + // re-written to non-floating-point sorts (i.e. true). + d_constantFoldTable[i] = rewrite::identity; + } + + /******** Constants ********/ + /* Already folded! */ + d_constantFoldTable[kind::CONST_FLOATINGPOINT] = rewrite::identity; + d_constantFoldTable[kind::CONST_ROUNDINGMODE] = rewrite::identity; + + /******** Sorts(?) ********/ + /* These kinds should only appear in types */ + d_constantFoldTable[kind::FLOATINGPOINT_TYPE] = rewrite::type; + + /******** Operations ********/ + d_constantFoldTable[kind::FLOATINGPOINT_FP] = constantFold::fpLiteral; + d_constantFoldTable[kind::FLOATINGPOINT_ABS] = constantFold::abs; + d_constantFoldTable[kind::FLOATINGPOINT_NEG] = constantFold::neg; + d_constantFoldTable[kind::FLOATINGPOINT_PLUS] = constantFold::plus; + d_constantFoldTable[kind::FLOATINGPOINT_SUB] = rewrite::removed; + d_constantFoldTable[kind::FLOATINGPOINT_MULT] = constantFold::mult; + d_constantFoldTable[kind::FLOATINGPOINT_DIV] = constantFold::div; + d_constantFoldTable[kind::FLOATINGPOINT_FMA] = constantFold::fma; + d_constantFoldTable[kind::FLOATINGPOINT_SQRT] = constantFold::sqrt; + d_constantFoldTable[kind::FLOATINGPOINT_REM] = constantFold::rem; + d_constantFoldTable[kind::FLOATINGPOINT_RTI] = constantFold::rti; + d_constantFoldTable[kind::FLOATINGPOINT_MIN] = constantFold::min; + d_constantFoldTable[kind::FLOATINGPOINT_MAX] = constantFold::max; + d_constantFoldTable[kind::FLOATINGPOINT_MIN_TOTAL] = constantFold::minTotal; + d_constantFoldTable[kind::FLOATINGPOINT_MAX_TOTAL] = constantFold::maxTotal; + + /******** Comparisons ********/ + d_constantFoldTable[kind::FLOATINGPOINT_EQ] = rewrite::removed; + d_constantFoldTable[kind::FLOATINGPOINT_LEQ] = constantFold::leq; + d_constantFoldTable[kind::FLOATINGPOINT_LT] = constantFold::lt; + d_constantFoldTable[kind::FLOATINGPOINT_GEQ] = rewrite::removed; + d_constantFoldTable[kind::FLOATINGPOINT_GT] = rewrite::removed; + + /******** Classifications ********/ + d_constantFoldTable[kind::FLOATINGPOINT_ISN] = constantFold::isNormal; + d_constantFoldTable[kind::FLOATINGPOINT_ISSN] = constantFold::isSubnormal; + d_constantFoldTable[kind::FLOATINGPOINT_ISZ] = constantFold::isZero; + d_constantFoldTable[kind::FLOATINGPOINT_ISINF] = constantFold::isInfinite; + d_constantFoldTable[kind::FLOATINGPOINT_ISNAN] = constantFold::isNaN; + d_constantFoldTable[kind::FLOATINGPOINT_ISNEG] = constantFold::isNegative; + d_constantFoldTable[kind::FLOATINGPOINT_ISPOS] = constantFold::isPositive; + + /******** Conversions ********/ + d_constantFoldTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = + constantFold::convertFromIEEEBitVectorLiteral; + d_constantFoldTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] = + constantFold::constantConvert; + d_constantFoldTable[kind::FLOATINGPOINT_TO_FP_REAL] = + constantFold::convertFromRealLiteral; + d_constantFoldTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] = + constantFold::convertFromSBV; + d_constantFoldTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] = + constantFold::convertFromUBV; + d_constantFoldTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::removed; + d_constantFoldTable[kind::FLOATINGPOINT_TO_UBV] = constantFold::convertToUBV; + d_constantFoldTable[kind::FLOATINGPOINT_TO_SBV] = constantFold::convertToSBV; + d_constantFoldTable[kind::FLOATINGPOINT_TO_REAL] = + constantFold::convertToReal; + d_constantFoldTable[kind::FLOATINGPOINT_TO_UBV_TOTAL] = + constantFold::convertToUBVTotal; + d_constantFoldTable[kind::FLOATINGPOINT_TO_SBV_TOTAL] = + constantFold::convertToSBVTotal; + d_constantFoldTable[kind::FLOATINGPOINT_TO_REAL_TOTAL] = + constantFold::convertToRealTotal; + + /******** Variables ********/ + d_constantFoldTable[kind::VARIABLE] = rewrite::variable; + d_constantFoldTable[kind::BOUND_VARIABLE] = rewrite::variable; + + d_constantFoldTable[kind::EQUAL] = constantFold::equal; + + /******** Components for bit-blasting ********/ + d_constantFoldTable[kind::FLOATINGPOINT_COMPONENT_NAN] = + constantFold::componentFlag; + d_constantFoldTable[kind::FLOATINGPOINT_COMPONENT_INF] = + constantFold::componentFlag; + d_constantFoldTable[kind::FLOATINGPOINT_COMPONENT_ZERO] = + constantFold::componentFlag; + d_constantFoldTable[kind::FLOATINGPOINT_COMPONENT_SIGN] = + constantFold::componentFlag; + d_constantFoldTable[kind::FLOATINGPOINT_COMPONENT_EXPONENT] = + constantFold::componentExponent; + d_constantFoldTable[kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND] = + constantFold::componentSignificand; + d_constantFoldTable[kind::ROUNDINGMODE_BITBLAST] = + constantFold::roundingModeBitBlast; +} /** * Rewrite a node into the normal form for the theory of fp @@ -1254,7 +1272,7 @@ RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND]; RewriteResponse TheoryFpRewriter::preRewrite(TNode node) { Trace("fp-rewrite") << "TheoryFpRewriter::preRewrite(): " << node << std::endl; - RewriteResponse res = preRewriteTable [node.getKind()] (node, true); + RewriteResponse res = d_preRewriteTable[node.getKind()](node, true); if (res.node != node) { Debug("fp-rewrite") << "TheoryFpRewriter::preRewrite(): before " << node << std::endl; Debug("fp-rewrite") << "TheoryFpRewriter::preRewrite(): after " << res.node << std::endl; @@ -1287,7 +1305,7 @@ RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND]; RewriteResponse TheoryFpRewriter::postRewrite(TNode node) { Trace("fp-rewrite") << "TheoryFpRewriter::postRewrite(): " << node << std::endl; - RewriteResponse res = postRewriteTable [node.getKind()] (node, false); + RewriteResponse res = d_postRewriteTable[node.getKind()](node, false); if (res.node != node) { Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): before " << node << std::endl; Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): after " << res.node << std::endl; @@ -1366,9 +1384,10 @@ RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND]; wRTP)))); } } else { - RewriteResponse tmp = constantFoldTable [res.node.getKind()] (res.node, false); - rs = tmp.status; - rn = tmp.node; + RewriteResponse tmp = + d_constantFoldTable[res.node.getKind()](res.node, false); + rs = tmp.status; + rn = tmp.node; } RewriteResponse constRes(rs,rn); diff --git a/src/theory/fp/theory_fp_rewriter.h b/src/theory/fp/theory_fp_rewriter.h index fef97afea..790e9d83d 100644 --- a/src/theory/fp/theory_fp_rewriter.h +++ b/src/theory/fp/theory_fp_rewriter.h @@ -20,7 +20,7 @@ #ifndef CVC4__THEORY__FP__THEORY_FP_REWRITER_H #define CVC4__THEORY__FP__THEORY_FP_REWRITER_H -#include "theory/rewriter.h" +#include "theory/theory_rewriter.h" namespace CVC4 { namespace theory { @@ -28,37 +28,28 @@ namespace fp { typedef RewriteResponse (*RewriteFunction) (TNode, bool); -class TheoryFpRewriter { - protected : - static RewriteFunction preRewriteTable[kind::LAST_KIND]; - static RewriteFunction postRewriteTable[kind::LAST_KIND]; - static RewriteFunction constantFoldTable[kind::LAST_KIND]; - - +class TheoryFpRewriter : public TheoryRewriter +{ public: + TheoryFpRewriter(); - static RewriteResponse preRewrite(TNode node); - static RewriteResponse postRewrite(TNode node); - + RewriteResponse preRewrite(TNode node) override; + RewriteResponse postRewrite(TNode node) override; /** * Rewrite an equality, in case special handling is required. */ - static Node rewriteEquality(TNode equality) { + Node rewriteEquality(TNode equality) + { // often this will suffice return postRewrite(equality).node; } - static void init(); - - /** - * Shut down the rewriter. - */ - static inline void shutdown() { - // nothing to do - } - -};/* class TheoryFpRewriter */ + protected: + RewriteFunction d_preRewriteTable[kind::LAST_KIND]; + RewriteFunction d_postRewriteTable[kind::LAST_KIND]; + RewriteFunction d_constantFoldTable[kind::LAST_KIND]; +}; /* class TheoryFpRewriter */ }/* CVC4::theory::fp namespace */ }/* CVC4::theory namespace */ diff --git a/src/theory/mkrewriter b/src/theory/mkrewriter index ea7deb270..dd5abd219 100755 --- a/src/theory/mkrewriter +++ b/src/theory/mkrewriter @@ -38,13 +38,10 @@ template=$1; shift rewriter_includes= rewrite_init= -rewrite_shutdown= -pre_rewrite_calls= pre_rewrite_get_cache= pre_rewrite_set_cache= -post_rewrite_calls= post_rewrite_get_cache= post_rewrite_set_cache= @@ -143,24 +140,18 @@ function rewriter { rewriter_includes="${rewriter_includes}#include \"$header\" " - rewrite_init="${rewrite_init} ${class}::init(); -" - rewrite_shutdown="${rewrite_shutdown} ${class}::shutdown(); + rewrite_init="${rewrite_init} d_theoryRewriters[${theory_id}].reset(new ${class}); " pre_rewrite_attribute_ids="${pre_rewrite_attribute_ids} preids.push_back(expr::attr::AttributeManager::getAttributeId(RewriteAttibute<${theory_id}>::pre_rewrite())); " post_rewrite_attribute_ids="${post_rewrite_attribute_ids} postids.push_back(expr::attr::AttributeManager::getAttributeId(RewriteAttibute<${theory_id}>::post_rewrite())); " - pre_rewrite_calls="${pre_rewrite_calls} case ${theory_id}: return ${class}::preRewrite(node); -" pre_rewrite_get_cache="${pre_rewrite_get_cache} case ${theory_id}: return RewriteAttibute<${theory_id}>::getPreRewriteCache(node); " pre_rewrite_set_cache="${pre_rewrite_set_cache} case ${theory_id}: return RewriteAttibute<${theory_id}>::setPreRewriteCache(node, cache); " - post_rewrite_calls="${post_rewrite_calls} case ${theory_id}: return ${class}::postRewrite(node); -" post_rewrite_get_cache="${post_rewrite_get_cache} case ${theory_id}: return RewriteAttibute<${theory_id}>::getPostRewriteCache(node); " post_rewrite_set_cache="${post_rewrite_set_cache} case ${theory_id}: return RewriteAttibute<${theory_id}>::setPostRewriteCache(node, cache); @@ -262,13 +253,11 @@ nl -ba -s' ' "$template" | grep '^ *[0-9][0-9]* # *line' | text=$(cat "$template") for var in \ rewriter_includes \ - pre_rewrite_calls \ - post_rewrite_calls \ pre_rewrite_get_cache \ post_rewrite_get_cache \ pre_rewrite_set_cache \ post_rewrite_set_cache \ - rewrite_init rewrite_shutdown \ + rewrite_init \ pre_rewrite_attribute_ids \ post_rewrite_attribute_ids \ template \ diff --git a/src/theory/quantifiers/quantifiers_rewriter.h b/src/theory/quantifiers/quantifiers_rewriter.h index 5d5e23c75..56eac761e 100644 --- a/src/theory/quantifiers/quantifiers_rewriter.h +++ b/src/theory/quantifiers/quantifiers_rewriter.h @@ -19,7 +19,7 @@ #ifndef CVC4__THEORY__QUANTIFIERS__QUANTIFIERS_REWRITER_H #define CVC4__THEORY__QUANTIFIERS__QUANTIFIERS_REWRITER_H -#include "theory/rewriter.h" +#include "theory/theory_rewriter.h" namespace CVC4 { namespace theory { @@ -27,8 +27,9 @@ namespace quantifiers { struct QAttributes; -class QuantifiersRewriter { -public: +class QuantifiersRewriter : public TheoryRewriter +{ + public: static bool isLiteral( Node n ); //-------------------------------------variable elimination utilities /** is variable elimination @@ -194,10 +195,9 @@ private: }; static Node computeOperation( Node f, int computeOption, QAttributes& qa ); public: - static RewriteResponse preRewrite(TNode in); - static RewriteResponse postRewrite(TNode in); - static inline void init() {} - static inline void shutdown() {} + RewriteResponse preRewrite(TNode in) override; + RewriteResponse postRewrite(TNode in) override; + private: /** options */ static bool doOperation( Node f, int computeOption, QAttributes& qa ); @@ -223,7 +223,7 @@ public: static Node mkForAll( std::vector< Node >& args, Node body, QAttributes& qa ); static Node mkForall( std::vector< Node >& args, Node body, bool marked = false ); static Node mkForall( std::vector< Node >& args, Node body, std::vector< Node >& iplc, bool marked = false ); -};/* class QuantifiersRewriter */ +}; /* class QuantifiersRewriter */ }/* CVC4::theory::quantifiers namespace */ }/* CVC4::theory namespace */ diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp index 866883681..3380694e7 100644 --- a/src/theory/rewriter.cpp +++ b/src/theory/rewriter.cpp @@ -28,37 +28,34 @@ using namespace std; namespace CVC4 { namespace theory { -unsigned long Rewriter::d_iterationCount = 0; - static TheoryId theoryOf(TNode node) { return Theory::theoryOf(THEORY_OF_TYPE_BASED, node); } -#ifdef CVC4_ASSERTIONS -static thread_local std::unique_ptr> - s_rewriteStack = nullptr; -#endif /* CVC4_ASSERTIONS */ - -class RewriterInitializer { - static RewriterInitializer s_rewriterInitializer; - RewriterInitializer() { - Rewriter::init(); - } - ~RewriterInitializer() { Rewriter::shutdown(); } -};/* class RewriterInitializer */ - -/** - * This causes initialization of the rewriter before first use, - * and tear-down at exit time. - */ -RewriterInitializer RewriterInitializer::s_rewriterInitializer; - /** * TheoryEngine::rewrite() keeps a stack of things that are being pre- * and post-rewritten. Each element of the stack is a * RewriteStackElement. */ struct RewriteStackElement { + /** + * Construct a fresh stack element. + */ + RewriteStackElement(TNode node, TheoryId theoryId) + : node(node), + original(node), + theoryId(theoryId), + originalTheoryId(theoryId), + nextChild(0) + { + } + + TheoryId getTheoryId() { return static_cast(theoryId); } + + TheoryId getOriginalTheoryId() + { + return static_cast(originalTheoryId); + } /** The node we're currently rewriting */ Node node; @@ -72,21 +69,17 @@ struct RewriteStackElement { unsigned nextChild : 32; /** Builder for this node */ NodeBuilder<> builder; - - /** - * Construct a fresh stack element. - */ - RewriteStackElement(TNode node, TheoryId theoryId) : - node(node), - original(node), - theoryId(theoryId), - originalTheoryId(theoryId), - nextChild(0) { - } }; Node Rewriter::rewrite(TNode node) { - return rewriteTo(theoryOf(node), node); + Rewriter& rewriter = getInstance(); + return rewriter.rewriteTo(theoryOf(node), node); +} + +Rewriter& Rewriter::getInstance() +{ + thread_local static Rewriter rewriter; + return rewriter; } Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { @@ -94,9 +87,9 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { #ifdef CVC4_ASSERTIONS bool isEquality = node.getKind() == kind::EQUAL && (!node[0].getType().isBoolean()); - if (s_rewriteStack == nullptr) + if (d_rewriteStack == nullptr) { - s_rewriteStack.reset(new std::unordered_set()); + d_rewriteStack.reset(new std::unordered_set()); } #endif @@ -129,29 +122,38 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { // Get the top of the recursion stack RewriteStackElement& rewriteStackTop = rewriteStack.back(); - Trace("rewriter") << "Rewriter::rewriting: " << (TheoryId) rewriteStackTop.theoryId << "," << rewriteStackTop.node << std::endl; + Trace("rewriter") << "Rewriter::rewriting: " + << rewriteStackTop.getTheoryId() << "," + << rewriteStackTop.node << std::endl; // Before rewriting children we need to do a pre-rewrite of the node if (rewriteStackTop.nextChild == 0) { // Check if the pre-rewrite has already been done (it's in the cache) - Node cached = Rewriter::getPreRewriteCache((TheoryId) rewriteStackTop.theoryId, rewriteStackTop.node); + Node cached = getPreRewriteCache(rewriteStackTop.getTheoryId(), + rewriteStackTop.node); if (cached.isNull()) { // Rewrite until fix-point is reached for(;;) { // Perform the pre-rewrite - RewriteResponse response = Rewriter::callPreRewrite((TheoryId) rewriteStackTop.theoryId, rewriteStackTop.node); + RewriteResponse response = + d_theoryRewriters[rewriteStackTop.getTheoryId()]->preRewrite( + rewriteStackTop.node); // Put the rewritten node to the top of the stack rewriteStackTop.node = response.node; TheoryId newTheory = theoryOf(rewriteStackTop.node); // In the pre-rewrite, if changing theories, we just call the other theories pre-rewrite - if (newTheory == (TheoryId) rewriteStackTop.theoryId && response.status == REWRITE_DONE) { + if (newTheory == rewriteStackTop.getTheoryId() + && response.status == REWRITE_DONE) + { break; } rewriteStackTop.theoryId = newTheory; } // Cache the rewrite - Rewriter::setPreRewriteCache((TheoryId) rewriteStackTop.originalTheoryId, rewriteStackTop.original, rewriteStackTop.node); + setPreRewriteCache(rewriteStackTop.getOriginalTheoryId(), + rewriteStackTop.original, + rewriteStackTop.node); } // Otherwise we're have already been pre-rewritten (in pre-rewrite cache) else { @@ -163,7 +165,8 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { rewriteStackTop.original =rewriteStackTop.node; // Now it's time to rewrite the children, check if this has already been done - Node cached = Rewriter::getPostRewriteCache((TheoryId) rewriteStackTop.theoryId, rewriteStackTop.node); + Node cached = getPostRewriteCache(rewriteStackTop.getTheoryId(), + rewriteStackTop.node); // If not, go through the children if(cached.isNull()) { @@ -202,26 +205,33 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { // Done with all pre-rewriting, so let's do the post rewrite for(;;) { // Do the post-rewrite - RewriteResponse response = Rewriter::callPostRewrite((TheoryId) rewriteStackTop.theoryId, rewriteStackTop.node); + RewriteResponse response = + d_theoryRewriters[rewriteStackTop.getTheoryId()]->postRewrite( + rewriteStackTop.node); // We continue with the response we got TheoryId newTheoryId = theoryOf(response.node); - if (newTheoryId != (TheoryId) rewriteStackTop.theoryId || response.status == REWRITE_AGAIN_FULL) { + if (newTheoryId != rewriteStackTop.getTheoryId() + || response.status == REWRITE_AGAIN_FULL) + { // In the post rewrite if we've changed theories, we must do a full rewrite Assert(response.node != rewriteStackTop.node); //TODO: this is not thread-safe - should make this assertion dependent on sequential build #ifdef CVC4_ASSERTIONS - Assert(s_rewriteStack->find(response.node) == s_rewriteStack->end()); - s_rewriteStack->insert(response.node); + Assert(d_rewriteStack->find(response.node) == d_rewriteStack->end()); + d_rewriteStack->insert(response.node); #endif Node rewritten = rewriteTo(newTheoryId, response.node); rewriteStackTop.node = rewritten; #ifdef CVC4_ASSERTIONS - s_rewriteStack->erase(response.node); + d_rewriteStack->erase(response.node); #endif break; - } else if (response.status == REWRITE_DONE) { + } + else if (response.status == REWRITE_DONE) + { #ifdef CVC4_ASSERTIONS - RewriteResponse r2 = Rewriter::callPostRewrite(newTheoryId, response.node); + RewriteResponse r2 = + d_theoryRewriters[newTheoryId]->postRewrite(response.node); Assert(r2.node == response.node); #endif rewriteStackTop.node = response.node; @@ -229,15 +239,16 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { } // Check for trivial rewrite loops of size 1 or 2 Assert(response.node != rewriteStackTop.node); - Assert(Rewriter::callPostRewrite((TheoryId)rewriteStackTop.theoryId, - response.node) + Assert(d_theoryRewriters[rewriteStackTop.getTheoryId()] + ->postRewrite(response.node) .node != rewriteStackTop.node); rewriteStackTop.node = response.node; } // We're done with the post rewrite, so we add to the cache - Rewriter::setPostRewriteCache((TheoryId) rewriteStackTop.originalTheoryId, rewriteStackTop.original, rewriteStackTop.node); - + setPostRewriteCache(rewriteStackTop.getOriginalTheoryId(), + rewriteStackTop.original, + rewriteStackTop.node); } else { // We were already in cache, so just remember it rewriteStackTop.node = cached; @@ -260,13 +271,13 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { }/* Rewriter::rewriteTo() */ void Rewriter::clearCaches() { + Rewriter& rewriter = getInstance(); + #ifdef CVC4_ASSERTIONS - if (s_rewriteStack != nullptr) - { - s_rewriteStack.reset(nullptr); - } + rewriter.d_rewriteStack.reset(nullptr); #endif - Rewriter::clearCachesInternal(); + + rewriter.clearCachesInternal(); } }/* CVC4::theory namespace */ diff --git a/src/theory/rewriter.h b/src/theory/rewriter.h index 5a15d15fb..e55ca5d1c 100644 --- a/src/theory/rewriter.h +++ b/src/theory/rewriter.h @@ -19,100 +19,79 @@ #pragma once #include "expr/node.h" +#include "theory/theory_rewriter.h" #include "util/unsafe_interrupt_exception.h" namespace CVC4 { namespace theory { -/** - * Theory rewriters signal whether more rewriting is needed (or not) - * by using a member of this enumeration. See RewriteResponse, below. - */ -enum RewriteStatus { - REWRITE_DONE, - REWRITE_AGAIN, - REWRITE_AGAIN_FULL -};/* enum RewriteStatus */ - -/** - * Instances of this class serve as response codes from - * Theory::preRewrite() and Theory::postRewrite(). Instances of - * derived classes RewriteComplete(n), RewriteAgain(n), and - * FullRewriteNeeded(n) should be used, giving self-documenting - * rewrite behavior. - */ -struct RewriteResponse { - const RewriteStatus status; - const Node node; - RewriteResponse(RewriteStatus status, Node node) : - status(status), node(node) {} -};/* struct RewriteResponse */ - class RewriterInitializer; /** - * The main rewriter class. All functionality is static. + * The main rewriter class. */ class Rewriter { + public: + Rewriter(); + + /** + * Rewrites the node using theoryOf() to determine which rewriter to + * use on the node. + */ + static Node rewrite(TNode node); + + /** + * Garbage collects the rewrite caches. + */ + static void clearCaches(); + + private: + /** + * Get the (singleton) instance of the rewriter. + * + * TODO(#3468): Get rid of this singleton + */ + static Rewriter& getInstance(); - friend class RewriterInitializer; - static unsigned long d_iterationCount; /** Returns the appropriate cache for a node */ - static Node getPreRewriteCache(theory::TheoryId theoryId, TNode node); + Node getPreRewriteCache(theory::TheoryId theoryId, TNode node); /** Returns the appropriate cache for a node */ - static Node getPostRewriteCache(theory::TheoryId theoryId, TNode node); + Node getPostRewriteCache(theory::TheoryId theoryId, TNode node); /** Sets the appropriate cache for a node */ - static void setPreRewriteCache(theory::TheoryId theoryId, - TNode node, TNode cache); + void setPreRewriteCache(theory::TheoryId theoryId, TNode node, TNode cache); /** Sets the appropriate cache for a node */ - static void setPostRewriteCache(theory::TheoryId theoryId, - TNode node, TNode cache); - - // disable construction of rewriters; all functionality is static - Rewriter() = delete; - Rewriter(const Rewriter&) = delete; + void setPostRewriteCache(theory::TheoryId theoryId, TNode node, TNode cache); /** * Rewrites the node using the given theory rewriter. */ - static Node rewriteTo(theory::TheoryId theoryId, Node node); + Node rewriteTo(theory::TheoryId theoryId, Node node); /** Calls the pre-rewriter for the given theory */ - static RewriteResponse callPreRewrite(theory::TheoryId theoryId, TNode node); + RewriteResponse callPreRewrite(theory::TheoryId theoryId, TNode node); /** Calls the post-rewriter for the given theory */ - static RewriteResponse callPostRewrite(theory::TheoryId theoryId, TNode node); + RewriteResponse callPostRewrite(theory::TheoryId theoryId, TNode node); /** * Calls the equality-rewriter for the given theory. */ - static Node callRewriteEquality(theory::TheoryId theoryId, TNode equality); + Node callRewriteEquality(theory::TheoryId theoryId, TNode equality); - /** - * Should be called before the rewriter gets used for the first time. - */ - static void init(); + void clearCachesInternal(); - /** - * Should be called to clean up any state. - */ - static void shutdown(); - static void clearCachesInternal(); -public: + /** Theory rewriters managed by this rewriter instance */ + std::unique_ptr d_theoryRewriters[theory::THEORY_LAST]; - /** - * Rewrites the node using theoryOf() to determine which rewriter to - * use on the node. - */ - static Node rewrite(TNode node); + unsigned long d_iterationCount = 0; - /** - * Garbage collects the rewrite caches. - */ - static void clearCaches(); +#ifdef CVC4_ASSERTIONS + std::unique_ptr> d_rewriteStack = + nullptr; +#endif /* CVC4_ASSERTIONS */ };/* class Rewriter */ }/* CVC4::theory namespace */ diff --git a/src/theory/rewriter_tables_template.h b/src/theory/rewriter_tables_template.h index 531b62a49..e1be6355b 100644 --- a/src/theory/rewriter_tables_template.h +++ b/src/theory/rewriter_tables_template.h @@ -29,22 +29,6 @@ ${rewriter_includes} namespace CVC4 { namespace theory { -RewriteResponse Rewriter::callPreRewrite(theory::TheoryId theoryId, TNode node) { - switch(theoryId) { -${pre_rewrite_calls} - default: - Unreachable(); - } -} - -RewriteResponse Rewriter::callPostRewrite(theory::TheoryId theoryId, TNode node) { - switch(theoryId) { -${post_rewrite_calls} - default: - Unreachable(); - } -} - Node Rewriter::getPreRewriteCache(theory::TheoryId theoryId, TNode node) { switch(theoryId) { ${pre_rewrite_get_cache} @@ -77,14 +61,11 @@ ${post_rewrite_set_cache} } } -void Rewriter::init() { +Rewriter::Rewriter() +{ ${rewrite_init} } -void Rewriter::shutdown() { -${rewrite_shutdown} -} - void Rewriter::clearCachesInternal() { typedef CVC4::expr::attr::AttributeUniqueId AttributeUniqueId; std::vector preids; diff --git a/src/theory/sep/theory_sep_rewriter.h b/src/theory/sep/theory_sep_rewriter.h index f4639efa3..0d1de3463 100644 --- a/src/theory/sep/theory_sep_rewriter.h +++ b/src/theory/sep/theory_sep_rewriter.h @@ -20,30 +20,32 @@ #ifndef CVC4__THEORY__SEP__THEORY_SEP_REWRITER_H #define CVC4__THEORY__SEP__THEORY_SEP_REWRITER_H -#include "theory/rewriter.h" +#include "theory/theory_rewriter.h" #include "theory/type_enumerator.h" namespace CVC4 { namespace theory { namespace sep { - -class TheorySepRewriter { -private: - static void getStarChildren( Node n, std::vector< Node >& s_children, std::vector< Node >& ns_children ); - static void getAndChildren( Node n, std::vector< Node >& s_children, std::vector< Node >& ns_children ); - static bool isSpatial( Node n, std::map< Node, bool >& visited ); -public: - - static RewriteResponse postRewrite(TNode node); - static inline RewriteResponse preRewrite(TNode node) { +class TheorySepRewriter : public TheoryRewriter +{ + public: + RewriteResponse postRewrite(TNode node) override; + RewriteResponse preRewrite(TNode node) override + { Trace("sep-prerewrite") << "Sep::preRewrite returning " << node << std::endl; return RewriteResponse(REWRITE_DONE, node); } - static inline void init() {} - static inline void shutdown() {} -};/* class TheorySepRewriter */ + private: + static void getStarChildren(Node n, + std::vector& s_children, + std::vector& ns_children); + static void getAndChildren(Node n, + std::vector& s_children, + std::vector& ns_children); + static bool isSpatial(Node n, std::map& visited); +}; /* class TheorySepRewriter */ }/* CVC4::theory::sep namespace */ }/* CVC4::theory namespace */ diff --git a/src/theory/sets/theory_sets_rewriter.h b/src/theory/sets/theory_sets_rewriter.h index 463d02d6d..b6c14e30a 100644 --- a/src/theory/sets/theory_sets_rewriter.h +++ b/src/theory/sets/theory_sets_rewriter.h @@ -25,9 +25,9 @@ namespace CVC4 { namespace theory { namespace sets { -class TheorySetsRewriter { -public: - +class TheorySetsRewriter : public TheoryRewriter +{ + public: /** * Rewrite a node into the normal form for the theory of sets. * Called in post-order (really reverse-topological order) when @@ -49,7 +49,7 @@ public: * expression belongs to a different theory, it will be fully * rewritten by that theory's rewriter. */ - static RewriteResponse postRewrite(TNode node); + RewriteResponse postRewrite(TNode node) override; /** * Rewrite a node into the normal form for the theory of sets @@ -60,30 +60,18 @@ public: * nasty expression). Since it's only an optimization, the * implementation here can do nothing. */ - static RewriteResponse preRewrite(TNode node); + RewriteResponse preRewrite(TNode node) override; /** * Rewrite an equality, in case special handling is required. */ - static Node rewriteEquality(TNode equality) { + Node rewriteEquality(TNode equality) + { // often this will suffice return postRewrite(equality).node; } - /** - * Initialize the rewriter. - */ - static inline void init() { - // nothing to do - } - - /** - * Shut down the rewriter. - */ - static inline void shutdown() { - // nothing to do - } -};/* class TheorySetsRewriter */ +}; /* class TheorySetsRewriter */ }/* CVC4::theory::sets namespace */ }/* CVC4::theory namespace */ diff --git a/src/theory/strings/regexp_elim.cpp b/src/theory/strings/regexp_elim.cpp index 42d679692..1942938c3 100644 --- a/src/theory/strings/regexp_elim.cpp +++ b/src/theory/strings/regexp_elim.cpp @@ -16,6 +16,7 @@ #include "theory/strings/regexp_elim.h" #include "options/strings_options.h" +#include "theory/rewriter.h" #include "theory/strings/theory_strings_rewriter.h" #include "theory/strings/theory_strings_utils.h" diff --git a/src/theory/strings/theory_strings_rewriter.h b/src/theory/strings/theory_strings_rewriter.h index b19d49e67..35805e1c2 100644 --- a/src/theory/strings/theory_strings_rewriter.h +++ b/src/theory/strings/theory_strings_rewriter.h @@ -18,19 +18,20 @@ #ifndef CVC4__THEORY__STRINGS__THEORY_STRINGS_REWRITER_H #define CVC4__THEORY__STRINGS__THEORY_STRINGS_REWRITER_H +#include #include #include -#include "theory/rewriter.h" -#include "theory/type_enumerator.h" #include "expr/attribute.h" -#include +#include "theory/theory_rewriter.h" +#include "theory/type_enumerator.h" namespace CVC4 { namespace theory { namespace strings { -class TheoryStringsRewriter { +class TheoryStringsRewriter : public TheoryRewriter +{ private: /** simple regular expression consume * @@ -155,11 +156,9 @@ class TheoryStringsRewriter { static Node returnRewrite(Node node, Node ret, const char* c); public: - static RewriteResponse postRewrite(TNode node); - static RewriteResponse preRewrite(TNode node); + RewriteResponse postRewrite(TNode node) override; + RewriteResponse preRewrite(TNode node) override; - static inline void init() {} - static inline void shutdown() {} /** get the cardinality of the alphabet used, based on the options */ static unsigned getAlphabetCardinality(); /** rewrite equality @@ -769,7 +768,7 @@ class TheoryStringsRewriter { * and the list of nodes that are compared to the empty string */ static std::pair > collectEmptyEqs(Node x); -};/* class TheoryStringsRewriter */ +}; /* class TheoryStringsRewriter */ }/* CVC4::theory::strings namespace */ }/* CVC4::theory namespace */ diff --git a/src/theory/theory_rewriter.h b/src/theory/theory_rewriter.h new file mode 100644 index 000000000..b0171813b --- /dev/null +++ b/src/theory/theory_rewriter.h @@ -0,0 +1,79 @@ +/********************* */ +/*! \file theory_rewriter.h + ** \verbatim + ** Top contributors (to current version): + ** Andres Noetzli + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief The TheoryRewriter class + ** + ** The TheoryRewriter class is the interface that theory rewriters implement. + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__THEORY__THEORY_REWRITER_H +#define CVC4__THEORY__THEORY_REWRITER_H + +#include "expr/node.h" + +namespace CVC4 { +namespace theory { + +/** + * Theory rewriters signal whether more rewriting is needed (or not) + * by using a member of this enumeration. See RewriteResponse, below. + */ +enum RewriteStatus +{ + /** The node is fully rewritten (no more rewrites apply) */ + REWRITE_DONE, + /** The node may be rewritten further */ + REWRITE_AGAIN, + /** Subnodes of the node may be rewritten further */ + REWRITE_AGAIN_FULL +}; /* enum RewriteStatus */ + +/** + * Instances of this class serve as response codes from + * TheoryRewriter::preRewrite() and TheoryRewriter::postRewrite(). The response + * consists of the rewritten node as well as a status that indicates whether + * more rewriting is needed or not. + */ +struct RewriteResponse +{ + const RewriteStatus status; + const Node node; + RewriteResponse(RewriteStatus status, Node node) : status(status), node(node) + { + } +}; /* struct RewriteResponse */ + +class TheoryRewriter +{ + public: + virtual ~TheoryRewriter() = default; + + /** + * Performs a pre-rewrite step. + * + * @param node The node to rewrite + */ + virtual RewriteResponse postRewrite(TNode node) = 0; + + /** + * Performs a post-rewrite step. + * + * @param node The node to rewrite + */ + virtual RewriteResponse preRewrite(TNode node) = 0; +}; + +} // namespace theory +} // namespace CVC4 + +#endif /* CVC4__THEORY__THEORY_REWRITER_H */ diff --git a/src/theory/uf/theory_uf_rewriter.h b/src/theory/uf/theory_uf_rewriter.h index 7f4c1c164..1cc8ab876 100644 --- a/src/theory/uf/theory_uf_rewriter.h +++ b/src/theory/uf/theory_uf_rewriter.h @@ -21,19 +21,19 @@ #define CVC4__THEORY__UF__THEORY_UF_REWRITER_H #include "expr/node_algorithm.h" -#include "theory/rewriter.h" -#include "theory/substitutions.h" #include "options/uf_options.h" +#include "theory/substitutions.h" +#include "theory/theory_rewriter.h" namespace CVC4 { namespace theory { namespace uf { -class TheoryUfRewriter { - -public: - - static RewriteResponse postRewrite(TNode node) { +class TheoryUfRewriter : public TheoryRewriter +{ + public: + RewriteResponse postRewrite(TNode node) override + { if(node.getKind() == kind::EQUAL) { if(node[0] == node[1]) { return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true)); @@ -142,7 +142,8 @@ public: return RewriteResponse(REWRITE_DONE, node); } - static RewriteResponse preRewrite(TNode node) { + RewriteResponse preRewrite(TNode node) override + { if(node.getKind() == kind::EQUAL) { if(node[0] == node[1]) { return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true)); @@ -154,9 +155,6 @@ public: return RewriteResponse(REWRITE_DONE, node); } - static inline void init() {} - static inline void shutdown() {} - public: //conversion between HO_APPLY AND APPLY_UF // converts an APPLY_UF to a curried HO_APPLY e.g. (f a b) becomes (@ (@ f a) b) static Node getHoApplyForApplyUf(TNode n) { @@ -207,7 +205,7 @@ public: //conversion between HO_APPLY AND APPLY_UF static inline bool canUseAsApplyUfOperator(TNode n){ return n.isVar(); } -};/* class TheoryUfRewriter */ +}; /* class TheoryUfRewriter */ }/* CVC4::theory::uf namespace */ }/* CVC4::theory namespace */ -- 2.30.2