From: Andres Noetzli Date: Fri, 3 Apr 2020 21:52:45 +0000 (-0700) Subject: Update theory rewriter ownership, add stats to strings (#4202) X-Git-Tag: cvc5-1.0.0~3408 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=aeede74491d1db9c5bac771e78b79934ca4ab552;p=cvc5.git Update theory rewriter ownership, add stats to strings (#4202) This commit adds statistics for string rewrites. This is work towards proof support in the string solver. At a high level, this commit adds a pointer to a `SequenceStatistics` in the rewriters and modifies `SequencesRewriter::returnRewrite()` to count the rewrites done. In practice, to make this work requires a couple of changes, some of them temporary: - We can't have a single `Rewriter` instance shared between different `SmtEngine` instances anymore. Thus the `Rewriter` is now owned by the `SmtEngine` and calling the rewriter retrieves the rewriter associated with the current `SmtEngine`. This is a temporary workaround before we get rid of singletons. - Methods in the `SequencesRewriter` and the `StringsRewriter` are made non-`static` because they need access to the statistics instance. - `StringsEntail` now has non-`static` methods because it needs a reference to the sequences rewriter that it can call. - The interaction between the `StringsRewriter` and the `SequencesRewriter` changed: the `StringsRewriter` is now a proper `TheoryRewriter` that inherits from `SequencesRewriter` and calls its `postRewrite()` before applying its own rewrites (this is essentially a reversal of roles from before: the `SequencesRewriter` used to call `static` methods in the `StringsRewriter`). - The theory rewriters are now owned by the individual theories. This design mirrors the `EqualityEngine`s owned by the individual theories. --- diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 2e1716543..03d9409be 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -802,6 +802,7 @@ SmtEngine::SmtEngine(ExprManager* em) d_theoryEngine(nullptr), d_propEngine(nullptr), d_proofManager(nullptr), + d_rewriter(new theory::Rewriter()), d_definedFunctions(nullptr), d_fmfRecFunctionsDefined(nullptr), d_assertionList(nullptr), diff --git a/src/smt/smt_engine.h b/src/smt/smt_engine.h index 37b89cfb7..3f24c8bab 100644 --- a/src/smt/smt_engine.h +++ b/src/smt/smt_engine.h @@ -108,6 +108,7 @@ namespace smt { namespace theory { class TheoryModel; + class Rewriter; }/* CVC4::theory namespace */ // TODO: SAT layer (esp. CNF- versus non-clausal solvers under the @@ -134,6 +135,7 @@ class CVC4_PUBLIC SmtEngine friend class ::CVC4::LogicRequest; friend class ::CVC4::Model; // to access d_modelCommands friend class ::CVC4::theory::TheoryModel; + friend class ::CVC4::theory::Rewriter; /* ....................................................................... */ public: @@ -876,6 +878,9 @@ class CVC4_PUBLIC SmtEngine /** Get a pointer to the ProofManager owned by this SmtEngine. */ ProofManager* getProofManager() { return d_proofManager.get(); }; + /** Get a pointer to the Rewriter owned by this SmtEngine. */ + theory::Rewriter* getRewriter() { return d_rewriter.get(); } + /** Get a pointer to the StatisticsRegistry owned by this SmtEngine. */ StatisticsRegistry* getStatisticsRegistry() { @@ -1085,6 +1090,14 @@ class CVC4_PUBLIC SmtEngine /** The proof manager */ std::unique_ptr d_proofManager; + /** + * The rewriter associated with this SmtEngine. We have a different instance + * of the rewriter for each SmtEngine instance. This is because rewriters may + * hold references to objects that belong to theory solvers, which are + * specific to an SmtEngine/TheoryEngine instance. + */ + std::unique_ptr d_rewriter; + /** An index of our defined functions */ DefinedFunctionMap* d_definedFunctions; diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index 2c748f188..e4aeca980 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -54,9 +54,9 @@ TheoryArith::~TheoryArith(){ delete d_internal; } -std::unique_ptr TheoryArith::mkTheoryRewriter() +TheoryRewriter* TheoryArith::getTheoryRewriter() { - return std::unique_ptr(new ArithRewriter()); + return d_internal->getTheoryRewriter(); } void TheoryArith::preRegisterTerm(TNode n){ diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h index 92892d2ae..a16f1ed5e 100644 --- a/src/theory/arith/theory_arith.h +++ b/src/theory/arith/theory_arith.h @@ -51,7 +51,7 @@ public: Valuation valuation, const LogicInfo& logicInfo); virtual ~TheoryArith(); - std::unique_ptr mkTheoryRewriter() override; + TheoryRewriter* getTheoryRewriter() override; /** * Does non-context dependent setup for a node connected to a theory. diff --git a/src/theory/arith/theory_arith_private.h b/src/theory/arith/theory_arith_private.h index f964c4e04..a4469f3b5 100644 --- a/src/theory/arith/theory_arith_private.h +++ b/src/theory/arith/theory_arith_private.h @@ -426,6 +426,8 @@ public: TheoryArithPrivate(TheoryArith& containing, context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo); ~TheoryArithPrivate(); + TheoryRewriter* getTheoryRewriter() { return &d_rewriter; } + /** * Does non-context dependent setup for a node connected to a theory. */ @@ -882,6 +884,8 @@ private: NodeMap d_int_div_skolem; NodeMap d_nlin_inverse_skolem; + /** The theory rewriter for this theory. */ + ArithRewriter d_rewriter; };/* class TheoryArithPrivate */ }/* CVC4::theory::arith namespace */ diff --git a/src/theory/arrays/theory_arrays.cpp b/src/theory/arrays/theory_arrays.cpp index 787ae84e2..e4b1e1c4c 100644 --- a/src/theory/arrays/theory_arrays.cpp +++ b/src/theory/arrays/theory_arrays.cpp @@ -179,11 +179,6 @@ TheoryArrays::~TheoryArrays() { smtStatisticsRegistry()->unregisterStat(&d_numSetModelValConflicts); } -std::unique_ptr TheoryArrays::mkTheoryRewriter() -{ - return std::unique_ptr(new TheoryArraysRewriter()); -} - void TheoryArrays::setMasterEqualityEngine(eq::EqualityEngine* eq) { d_equalityEngine.setMasterEqualityEngine(eq); } diff --git a/src/theory/arrays/theory_arrays.h b/src/theory/arrays/theory_arrays.h index d1f912d95..34cf6c424 100644 --- a/src/theory/arrays/theory_arrays.h +++ b/src/theory/arrays/theory_arrays.h @@ -27,6 +27,7 @@ #include "context/cdqueue.h" #include "theory/arrays/array_info.h" #include "theory/arrays/array_proof_reconstruction.h" +#include "theory/arrays/theory_arrays_rewriter.h" #include "theory/theory.h" #include "theory/uf/equality_engine.h" #include "util/statistics_registry.h" @@ -144,7 +145,7 @@ class TheoryArrays : public Theory { std::string name = ""); ~TheoryArrays(); - std::unique_ptr mkTheoryRewriter() override; + TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } void setMasterEqualityEngine(eq::EqualityEngine* eq) override; @@ -177,6 +178,9 @@ class TheoryArrays : public Theory { bool ppDisequal(TNode a, TNode b); Node solveWrite(TNode term, bool solve1, bool solve2, bool ppCheck); + /** The theory rewriter for this theory. */ + TheoryArraysRewriter d_rewriter; + public: PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions) override; Node ppRewrite(TNode atom) override; diff --git a/src/theory/booleans/theory_bool.cpp b/src/theory/booleans/theory_bool.cpp index e670121d1..29f5bb82d 100644 --- a/src/theory/booleans/theory_bool.cpp +++ b/src/theory/booleans/theory_bool.cpp @@ -33,9 +33,13 @@ namespace CVC4 { namespace theory { namespace booleans { -std::unique_ptr TheoryBool::mkTheoryRewriter() +TheoryBool::TheoryBool(context::Context* c, + context::UserContext* u, + OutputChannel& out, + Valuation valuation, + const LogicInfo& logicInfo) + : Theory(THEORY_BOOL, c, u, out, valuation, logicInfo) { - return std::unique_ptr(new TheoryBoolRewriter()); } Theory::PPAssertStatus TheoryBool::ppAssert(TNode in, SubstitutionMap& outSubstitutions) { diff --git a/src/theory/booleans/theory_bool.h b/src/theory/booleans/theory_bool.h index 75e375ee6..ae498165f 100644 --- a/src/theory/booleans/theory_bool.h +++ b/src/theory/booleans/theory_bool.h @@ -19,27 +19,33 @@ #ifndef CVC4__THEORY__BOOLEANS__THEORY_BOOL_H #define CVC4__THEORY__BOOLEANS__THEORY_BOOL_H -#include "theory/theory.h" #include "context/context.h" +#include "theory/booleans/theory_bool_rewriter.h" +#include "theory/theory.h" namespace CVC4 { namespace theory { namespace booleans { class TheoryBool : public Theory { -public: - TheoryBool(context::Context* c, context::UserContext* u, OutputChannel& out, - Valuation valuation, const LogicInfo& logicInfo) - : Theory(THEORY_BOOL, c, u, out, valuation, logicInfo) - {} + public: + TheoryBool(context::Context* c, + context::UserContext* u, + OutputChannel& out, + Valuation valuation, + const LogicInfo& logicInfo); - std::unique_ptr mkTheoryRewriter() override; + TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions) override; //void check(Effort); std::string identify() const override { return std::string("TheoryBool"); } + + private: + /** The theory rewriter for this theory. */ + TheoryBoolRewriter d_rewriter; };/* class TheoryBool */ }/* CVC4::theory::booleans namespace */ diff --git a/src/theory/builtin/theory_builtin.cpp b/src/theory/builtin/theory_builtin.cpp index 8df5a8535..b9d05b833 100644 --- a/src/theory/builtin/theory_builtin.cpp +++ b/src/theory/builtin/theory_builtin.cpp @@ -36,11 +36,6 @@ TheoryBuiltin::TheoryBuiltin(context::Context* c, { } -std::unique_ptr TheoryBuiltin::mkTheoryRewriter() -{ - return std::unique_ptr(new TheoryBuiltinRewriter()); -} - std::string TheoryBuiltin::identify() const { return std::string("TheoryBuiltin"); diff --git a/src/theory/builtin/theory_builtin.h b/src/theory/builtin/theory_builtin.h index d240f4f63..bf99003ec 100644 --- a/src/theory/builtin/theory_builtin.h +++ b/src/theory/builtin/theory_builtin.h @@ -19,6 +19,7 @@ #ifndef CVC4__THEORY__BUILTIN__THEORY_BUILTIN_H #define CVC4__THEORY__BUILTIN__THEORY_BUILTIN_H +#include "theory/builtin/theory_builtin_rewriter.h" #include "theory/theory.h" namespace CVC4 { @@ -34,12 +35,16 @@ class TheoryBuiltin : public Theory Valuation valuation, const LogicInfo& logicInfo); - std::unique_ptr mkTheoryRewriter() override; + TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } std::string identify() const override; /** finish initialization */ void finishInit() override; + + private: + /** The theory rewriter for this theory. */ + TheoryBuiltinRewriter d_rewriter; }; /* class TheoryBuiltin */ } // namespace builtin diff --git a/src/theory/builtin/theory_builtin_rewriter.cpp b/src/theory/builtin/theory_builtin_rewriter.cpp index a39d4231b..dd6d434ca 100644 --- a/src/theory/builtin/theory_builtin_rewriter.cpp +++ b/src/theory/builtin/theory_builtin_rewriter.cpp @@ -84,8 +84,10 @@ RewriteResponse TheoryBuiltinRewriter::postRewrite(TNode node) { Assert(retNode.getType() == node.getType()); Assert(expr::hasFreeVar(node) == expr::hasFreeVar(retNode)); return RewriteResponse(REWRITE_DONE, retNode); - } - }else{ + } + } + else + { Trace("builtin-rewrite-debug") << "...failed to get array representation." << std::endl; } return RewriteResponse(REWRITE_DONE, node); diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index 27718b63f..94fc1e34c 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -112,11 +112,6 @@ TheoryBV::TheoryBV(context::Context* c, TheoryBV::~TheoryBV() {} -std::unique_ptr TheoryBV::mkTheoryRewriter() -{ - return std::unique_ptr(new TheoryBVRewriter()); -} - void TheoryBV::setMasterEqualityEngine(eq::EqualityEngine* eq) { if (options::bitblastMode() == options::BitblastMode::EAGER) { diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h index ff1c9245a..bc54a09e7 100644 --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@ -26,6 +26,7 @@ #include "context/cdlist.h" #include "context/context.h" #include "theory/bv/bv_subtheory.h" +#include "theory/bv/theory_bv_rewriter.h" #include "theory/bv/theory_bv_utils.h" #include "theory/theory.h" #include "util/hash.h" @@ -72,7 +73,7 @@ public: ~TheoryBV(); - std::unique_ptr mkTheoryRewriter() override; + TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } void setMasterEqualityEngine(eq::EqualityEngine* eq) override; @@ -259,6 +260,9 @@ public: void checkForLemma(TNode node); + /** The theory rewriter for this theory. */ + TheoryBVRewriter d_rewriter; + friend class LazyBitblaster; friend class TLazyBitblaster; friend class EagerBitblaster; diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 15220b9dc..b3244fe91 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -86,11 +86,6 @@ TheoryDatatypes::~TheoryDatatypes() { } } -std::unique_ptr TheoryDatatypes::mkTheoryRewriter() -{ - return std::unique_ptr(new DatatypesRewriter()); -} - void TheoryDatatypes::setMasterEqualityEngine(eq::EqualityEngine* eq) { d_equalityEngine.setMasterEqualityEngine(eq); } diff --git a/src/theory/datatypes/theory_datatypes.h b/src/theory/datatypes/theory_datatypes.h index 7ccd04f39..cc54241d0 100644 --- a/src/theory/datatypes/theory_datatypes.h +++ b/src/theory/datatypes/theory_datatypes.h @@ -26,6 +26,7 @@ #include "expr/attribute.h" #include "expr/datatype.h" #include "expr/node_trie.h" +#include "theory/datatypes/datatypes_rewriter.h" #include "theory/datatypes/sygus_extension.h" #include "theory/theory.h" #include "theory/uf/equality_engine.h" @@ -271,7 +272,7 @@ private: const LogicInfo& logicInfo); ~TheoryDatatypes(); - std::unique_ptr mkTheoryRewriter() override; + TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } void setMasterEqualityEngine(eq::EqualityEngine* eq) override; @@ -370,10 +371,13 @@ private: bool areDisequal( TNode a, TNode b ); bool areCareDisequal( TNode x, TNode y ); TNode getRepresentative( TNode a ); -private: - /** sygus symmetry breaking utility */ - std::unique_ptr d_sygusExtension; + private: + /** sygus symmetry breaking utility */ + std::unique_ptr d_sygusExtension; + + /** The theory rewriter for this theory. */ + DatatypesRewriter d_rewriter; };/* class TheoryDatatypes */ }/* CVC4::theory::datatypes namespace */ diff --git a/src/theory/fp/theory_fp.cpp b/src/theory/fp/theory_fp.cpp index 5ab285766..b028184cd 100644 --- a/src/theory/fp/theory_fp.cpp +++ b/src/theory/fp/theory_fp.cpp @@ -174,14 +174,8 @@ TheoryFp::TheoryFp(context::Context *c, d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_COMPONENT_EXPONENT); d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND); d_equalityEngine.addFunctionKind(kind::ROUNDINGMODE_BITBLAST); - } /* TheoryFp::TheoryFp() */ -std::unique_ptr TheoryFp::mkTheoryRewriter() -{ - return std::unique_ptr(new TheoryFpRewriter()); -} - Node TheoryFp::minUF(Node node) { Assert(node.getKind() == kind::FLOATINGPOINT_MIN); TypeNode t(node.getType()); diff --git a/src/theory/fp/theory_fp.h b/src/theory/fp/theory_fp.h index 802a70435..ae4a2a1cb 100644 --- a/src/theory/fp/theory_fp.h +++ b/src/theory/fp/theory_fp.h @@ -25,6 +25,7 @@ #include "context/cdo.h" #include "theory/fp/fp_converter.h" +#include "theory/fp/theory_fp_rewriter.h" #include "theory/theory.h" #include "theory/uf/equality_engine.h" @@ -38,7 +39,7 @@ class TheoryFp : public Theory { TheoryFp(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo); - std::unique_ptr mkTheoryRewriter() override; + TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } Node expandDefinition(LogicRequest& lr, Node node) override; @@ -144,6 +145,8 @@ class TheoryFp : public Theory { bool refineAbstraction(TheoryModel* m, TNode abstract, TNode concrete); + /** The theory rewriter for this theory. */ + TheoryFpRewriter d_rewriter; }; /* class TheoryFp */ } // namespace fp diff --git a/src/theory/quantifiers/candidate_rewrite_database.cpp b/src/theory/quantifiers/candidate_rewrite_database.cpp index e00bc957f..9070ed51d 100644 --- a/src/theory/quantifiers/candidate_rewrite_database.cpp +++ b/src/theory/quantifiers/candidate_rewrite_database.cpp @@ -293,39 +293,6 @@ void CandidateRewriteDatabase::setExtendedRewriter(ExtendedRewriter* er) d_ext_rewrite = er; } -CandidateRewriteDatabaseGen::CandidateRewriteDatabaseGen( - std::vector& vars, unsigned nsamples) - : d_qe(nullptr), d_vars(vars.begin(), vars.end()), d_nsamples(nsamples) -{ -} - -bool CandidateRewriteDatabaseGen::addTerm(Node n, std::ostream& out) -{ - ExtendedRewriter* er = &d_ext_rewrite; - Node nr; - if (er == nullptr) - { - nr = Rewriter::rewrite(n); - } - else - { - nr = er->extendedRewrite(n); - } - TypeNode tn = nr.getType(); - std::map::iterator itc = d_cdbs.find(tn); - if (itc == d_cdbs.end()) - { - Trace("synth-rr-dbg") << "Initialize database for " << tn << std::endl; - // initialize with the extended rewriter owned by this class - d_cdbs[tn].initialize(d_vars, &d_sampler[tn]); - d_cdbs[tn].setExtendedRewriter(er); - itc = d_cdbs.find(tn); - Trace("synth-rr-dbg") << "...finish." << std::endl; - } - Trace("synth-rr-dbg") << "Add term " << nr << " for " << tn << std::endl; - return itc->second.addTerm(nr, false, out); -} - } /* CVC4::theory::quantifiers namespace */ } /* CVC4::theory namespace */ } /* CVC4 namespace */ diff --git a/src/theory/quantifiers/candidate_rewrite_database.h b/src/theory/quantifiers/candidate_rewrite_database.h index b68b20998..a9ca659e1 100644 --- a/src/theory/quantifiers/candidate_rewrite_database.h +++ b/src/theory/quantifiers/candidate_rewrite_database.h @@ -103,44 +103,6 @@ class CandidateRewriteDatabase : public ExprMiner bool d_silent; }; -/** - * This class generates and stores candidate rewrite databases for multiple - * types as needed. - */ -class CandidateRewriteDatabaseGen -{ - public: - /** constructor - * - * vars : the variables we are testing substitutions for, for all types, - * nsamples : number of sample points this class will test for all types. - */ - CandidateRewriteDatabaseGen(std::vector& vars, unsigned nsamples); - /** add term - * - * This registers term n with this class. We generate the candidate rewrite - * database of the appropriate type (if not allocated already), and register - * n with this database. This may result in "candidate-rewrite" being - * printed on the output stream out. We return true if the term sol is - * distinct (up to equivalence) with all previous terms added to this class. - */ - bool addTerm(Node n, std::ostream& out); - - private: - /** reference to quantifier engine */ - QuantifiersEngine* d_qe; - /** the variables */ - std::vector d_vars; - /** sygus sampler object for each type */ - std::map d_sampler; - /** the number of samples */ - unsigned d_nsamples; - /** candidate rewrite databases for each type */ - std::map d_cdbs; - /** an extended rewriter object */ - ExtendedRewriter d_ext_rewrite; -}; - } /* CVC4::theory::quantifiers namespace */ } /* CVC4::theory namespace */ } /* CVC4 namespace */ diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index 7920ecbeb..b0a474c56 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -1692,7 +1692,7 @@ Node ExtendedRewriter::extendedRewriteStrings(Node ret) if (ret.getKind() == EQUAL) { - new_ret = strings::SequencesRewriter::rewriteEqualityExt(ret); + new_ret = strings::SequencesRewriter(nullptr).rewriteEqualityExt(ret); } return new_ret; diff --git a/src/theory/quantifiers/theory_quantifiers.cpp b/src/theory/quantifiers/theory_quantifiers.cpp index e3e3c3824..5253638a7 100644 --- a/src/theory/quantifiers/theory_quantifiers.cpp +++ b/src/theory/quantifiers/theory_quantifiers.cpp @@ -54,11 +54,6 @@ TheoryQuantifiers::TheoryQuantifiers(Context* c, context::UserContext* u, Output TheoryQuantifiers::~TheoryQuantifiers() { } -std::unique_ptr TheoryQuantifiers::mkTheoryRewriter() -{ - return std::unique_ptr(new QuantifiersRewriter()); -} - void TheoryQuantifiers::finishInit() { // quantifiers are not evaluated in getModelValue diff --git a/src/theory/quantifiers/theory_quantifiers.h b/src/theory/quantifiers/theory_quantifiers.h index 7efe7419c..991a47ad0 100644 --- a/src/theory/quantifiers/theory_quantifiers.h +++ b/src/theory/quantifiers/theory_quantifiers.h @@ -23,6 +23,7 @@ #include "context/context.h" #include "expr/node.h" #include "theory/output_channel.h" +#include "theory/quantifiers/quantifiers_rewriter.h" #include "theory/theory.h" #include "theory/theory_engine.h" #include "theory/valuation.h" @@ -39,7 +40,7 @@ class TheoryQuantifiers : public Theory { const LogicInfo& logicInfo); ~TheoryQuantifiers(); - std::unique_ptr mkTheoryRewriter() override; + TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } /** finish initialization */ void finishInit() override; @@ -58,6 +59,9 @@ class TheoryQuantifiers : public Theory { std::vector node_values, std::string str_value) override; + private: + /** The theory rewriter for this theory. */ + QuantifiersRewriter d_rewriter; };/* class TheoryQuantifiers */ }/* CVC4::theory::quantifiers namespace */ diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp index 54b9f319d..ca552de66 100644 --- a/src/theory/rewriter.cpp +++ b/src/theory/rewriter.cpp @@ -18,6 +18,7 @@ #include "theory/rewriter.h" #include "options/theory_options.h" +#include "smt/smt_engine.h" #include "smt/smt_engine_scope.h" #include "smt/smt_statistics_registry.h" #include "theory/rewriter_tables.h" @@ -93,13 +94,13 @@ Node Rewriter::rewrite(TNode node) { // eagerly for the sake of efficiency here. return node; } - return getInstance().rewriteTo(theoryOf(node), node); + return getInstance()->rewriteTo(theoryOf(node), node); } void Rewriter::registerTheoryRewriter(theory::TheoryId tid, - std::unique_ptr trew) + TheoryRewriter* trew) { - getInstance().d_theoryRewriters[tid] = std::move(trew); + getInstance()->d_theoryRewriters[tid] = trew; } void Rewriter::registerPreRewrite( @@ -130,10 +131,9 @@ void Rewriter::registerPostRewriteEqual( d_postRewritersEqual[tid] = fn; } -Rewriter& Rewriter::getInstance() +Rewriter* Rewriter::getInstance() { - thread_local static Rewriter rewriter; - return rewriter; + return smt::currentSmtEngine()->getRewriter(); } Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { @@ -351,13 +351,13 @@ RewriteResponse Rewriter::postRewrite(theory::TheoryId theoryId, TNode n) } void Rewriter::clearCaches() { - Rewriter& rewriter = getInstance(); + Rewriter* rewriter = getInstance(); #ifdef CVC4_ASSERTIONS - rewriter.d_rewriteStack.reset(nullptr); + rewriter->d_rewriteStack.reset(nullptr); #endif - rewriter.clearCachesInternal(); + rewriter->clearCachesInternal(); } }/* CVC4::theory namespace */ diff --git a/src/theory/rewriter.h b/src/theory/rewriter.h index 8a641743b..32a8005d1 100644 --- a/src/theory/rewriter.h +++ b/src/theory/rewriter.h @@ -63,14 +63,14 @@ class Rewriter { static void clearCaches(); /** - * Registers a theory rewriter with this rewriter. This transfers the - * ownership of the theory rewriter to the rewriter. + * Registers a theory rewriter with this rewriter. The rewriter does not own + * the theory rewriters. * * @param tid The theory that the theory rewriter should be associated with. * @param trew The theory rewriter to register. */ static void registerTheoryRewriter(theory::TheoryId tid, - std::unique_ptr trew); + TheoryRewriter* trew); /** * Register a prerewrite for a given kind. @@ -112,11 +112,12 @@ class Rewriter { private: /** - * Get the (singleton) instance of the rewriter. + * Get the rewriter associated with the SmtEngine in scope. * - * TODO(#3468): Get rid of this singleton + * TODO(#3468): Get rid of this function (it relies on there being an + * singleton with the current SmtEngine in scope) */ - static Rewriter& getInstance(); + static Rewriter* getInstance(); /** Returns the appropriate cache for a node */ Node getPreRewriteCache(theory::TheoryId theoryId, TNode node); @@ -148,8 +149,8 @@ class Rewriter { void clearCachesInternal(); - /** Theory rewriters managed by this rewriter instance */ - std::unique_ptr d_theoryRewriters[theory::THEORY_LAST]; + /** Theory rewriters used by this rewriter instance */ + TheoryRewriter* d_theoryRewriters[theory::THEORY_LAST]; unsigned long d_iterationCount = 0; diff --git a/src/theory/sep/theory_sep.cpp b/src/theory/sep/theory_sep.cpp index 1d0ad904c..d90be94f9 100644 --- a/src/theory/sep/theory_sep.cpp +++ b/src/theory/sep/theory_sep.cpp @@ -65,11 +65,6 @@ TheorySep::~TheorySep() { } } -std::unique_ptr TheorySep::mkTheoryRewriter() -{ - return std::unique_ptr(new TheorySepRewriter()); -} - void TheorySep::setMasterEqualityEngine(eq::EqualityEngine* eq) { d_equalityEngine.setMasterEqualityEngine(eq); } diff --git a/src/theory/sep/theory_sep.h b/src/theory/sep/theory_sep.h index df3294882..ce1498f52 100644 --- a/src/theory/sep/theory_sep.h +++ b/src/theory/sep/theory_sep.h @@ -23,6 +23,7 @@ #include "context/cdhashset.h" #include "context/cdlist.h" #include "context/cdqueue.h" +#include "theory/sep/theory_sep_rewriter.h" #include "theory/theory.h" #include "theory/uf/equality_engine.h" #include "util/statistics_registry.h" @@ -56,6 +57,8 @@ class TheorySep : public Theory { //whether bounds have been initialized bool d_bounds_init; + TheorySepRewriter d_rewriter; + Node mkAnd( std::vector< TNode >& assumptions ); int processAssertion( Node n, std::map< int, std::map< Node, int > >& visited, @@ -66,7 +69,7 @@ class TheorySep : public Theory { TheorySep(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo); ~TheorySep(); - std::unique_ptr mkTheoryRewriter() override; + TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } void setMasterEqualityEngine(eq::EqualityEngine* eq) override; diff --git a/src/theory/sets/theory_sets.cpp b/src/theory/sets/theory_sets.cpp index 0b9e6d934..f4b265d98 100644 --- a/src/theory/sets/theory_sets.cpp +++ b/src/theory/sets/theory_sets.cpp @@ -46,9 +46,9 @@ TheorySets::~TheorySets() // Do not move me to the header. See explanation in the constructor. } -std::unique_ptr TheorySets::mkTheoryRewriter() +TheoryRewriter* TheorySets::getTheoryRewriter() { - return std::unique_ptr(new TheorySetsRewriter()); + return d_internal->getTheoryRewriter(); } void TheorySets::finishInit() diff --git a/src/theory/sets/theory_sets.h b/src/theory/sets/theory_sets.h index a55a22600..5fc1a61a3 100644 --- a/src/theory/sets/theory_sets.h +++ b/src/theory/sets/theory_sets.h @@ -42,7 +42,7 @@ class TheorySets : public Theory const LogicInfo& logicInfo); ~TheorySets() override; - std::unique_ptr mkTheoryRewriter() override; + TheoryRewriter* getTheoryRewriter() override; /** finish initialization */ void finishInit() override; diff --git a/src/theory/sets/theory_sets_private.h b/src/theory/sets/theory_sets_private.h index 5ef8c4825..ab9071793 100644 --- a/src/theory/sets/theory_sets_private.h +++ b/src/theory/sets/theory_sets_private.h @@ -26,6 +26,7 @@ #include "theory/sets/inference_manager.h" #include "theory/sets/solver_state.h" #include "theory/sets/theory_sets_rels.h" +#include "theory/sets/theory_sets_rewriter.h" #include "theory/theory.h" #include "theory/uf/equality_engine.h" @@ -161,6 +162,8 @@ class TheorySetsPrivate { ~TheorySetsPrivate(); + TheoryRewriter* getTheoryRewriter() { return &d_rewriter; } + void setMasterEqualityEngine(eq::EqualityEngine* eq); void addSharedTerm(TNode); @@ -279,6 +282,9 @@ class TheorySetsPrivate { * involving cardinality constraints is asserted to this theory. */ bool d_card_enabled; + + /** The theory rewriter for this theory. */ + TheorySetsRewriter d_rewriter; };/* class TheorySetsPrivate */ diff --git a/src/theory/strings/arith_entail.cpp b/src/theory/strings/arith_entail.cpp index 5933f2586..71680264d 100644 --- a/src/theory/strings/arith_entail.cpp +++ b/src/theory/strings/arith_entail.cpp @@ -66,11 +66,9 @@ bool ArithEntail::check(Node a, bool strict) return a.getConst().sgn() >= (strict ? 1 : 0); } - Node ar = - strict - ? NodeManager::currentNM()->mkNode( + Node ar = strict ? NodeManager::currentNM()->mkNode( kind::MINUS, a, NodeManager::currentNM()->mkConst(Rational(1))) - : a; + : a; ar = Rewriter::rewrite(ar); if (ar.getAttribute(StrCheckEntailArithComputedAttr())) diff --git a/src/theory/strings/extf_solver.cpp b/src/theory/strings/extf_solver.cpp index a1c04848a..0c46881e7 100644 --- a/src/theory/strings/extf_solver.cpp +++ b/src/theory/strings/extf_solver.cpp @@ -32,18 +32,20 @@ ExtfSolver::ExtfSolver(context::Context* c, SolverState& s, InferenceManager& im, SkolemCache& skc, + StringsRewriter& rewriter, BaseSolver& bs, CoreSolver& cs, ExtTheory* et, - SequencesStatistics& stats) + SequencesStatistics& statistics) : d_state(s), d_im(im), d_skCache(skc), + d_rewriter(rewriter), d_bsolver(bs), d_csolver(cs), d_extt(et), - d_statistics(stats), - d_preproc(&skc, u, stats), + d_statistics(statistics), + d_preproc(&skc, u, statistics), d_hasExtf(c, false), d_extfInferCache(c) { @@ -620,7 +622,7 @@ void ExtfSolver::checkExtfInference(Node n, if (inferEqr.getKind() == EQUAL) { // try to use the extended rewriter for equalities - inferEqrr = SequencesRewriter::rewriteEqualityExt(inferEqr); + inferEqrr = d_rewriter.rewriteEqualityExt(inferEqr); } if (inferEqrr != inferEqr) { diff --git a/src/theory/strings/extf_solver.h b/src/theory/strings/extf_solver.h index 9ca72fed2..e7e2512bd 100644 --- a/src/theory/strings/extf_solver.h +++ b/src/theory/strings/extf_solver.h @@ -17,8 +17,8 @@ #ifndef CVC4__THEORY__STRINGS__EXTF_SOLVER_H #define CVC4__THEORY__STRINGS__EXTF_SOLVER_H -#include #include +#include #include "context/cdo.h" #include "expr/node.h" @@ -29,6 +29,7 @@ #include "theory/strings/sequences_stats.h" #include "theory/strings/skolem_cache.h" #include "theory/strings/solver_state.h" +#include "theory/strings/strings_rewriter.h" #include "theory/strings/theory_strings_preprocess.h" namespace CVC4 { @@ -87,10 +88,11 @@ class ExtfSolver SolverState& s, InferenceManager& im, SkolemCache& skc, + StringsRewriter& rewriter, BaseSolver& bs, CoreSolver& cs, ExtTheory* et, - SequencesStatistics& stats); + SequencesStatistics& statistics); ~ExtfSolver(); /** check extended functions evaluation @@ -180,6 +182,8 @@ class ExtfSolver InferenceManager& d_im; /** cache of all skolems */ SkolemCache& d_skCache; + /** The theory rewriter for this theory. */ + StringsRewriter& d_rewriter; /** reference to the base solver, used for certain queries */ BaseSolver& d_bsolver; /** reference to the core solver, used for certain queries */ diff --git a/src/theory/strings/regexp_entail.cpp b/src/theory/strings/regexp_entail.cpp index d03893483..a43ec4430 100644 --- a/src/theory/strings/regexp_entail.cpp +++ b/src/theory/strings/regexp_entail.cpp @@ -438,7 +438,9 @@ bool RegExpEntail::testConstStringInRegExp(CVC4::String& s, return true; } } - case REGEXP_EMPTY: { return false; + case REGEXP_EMPTY: + { + return false; } case REGEXP_SIGMA: { diff --git a/src/theory/strings/sequences_rewriter.cpp b/src/theory/strings/sequences_rewriter.cpp index 9ccda55c2..152f71019 100644 --- a/src/theory/strings/sequences_rewriter.cpp +++ b/src/theory/strings/sequences_rewriter.cpp @@ -21,7 +21,6 @@ #include "theory/rewriter.h" #include "theory/strings/arith_entail.h" #include "theory/strings/regexp_entail.h" -#include "theory/strings/strings_entail.h" #include "theory/strings/strings_rewriter.h" #include "theory/strings/theory_strings_utils.h" #include "theory/strings/word.h" @@ -33,6 +32,11 @@ namespace CVC4 { namespace theory { namespace strings { +SequencesRewriter::SequencesRewriter(HistogramStat* statistics) + : d_statistics(statistics), d_stringsEntail(*this) +{ +} + Node SequencesRewriter::rewriteEquality(Node node) { Assert(node.getKind() == kind::EQUAL); @@ -53,7 +57,7 @@ Node SequencesRewriter::rewriteEquality(Node node) // must call rewrite contains directly to avoid infinite loop // we do a fix point since we may rewrite contains terms to simpler // contains terms. - Node ctn = StringsEntail::checkContains(node[r], node[1 - r], false); + Node ctn = d_stringsEntail.checkContains(node[r], node[1 - r], false); if (!ctn.isNull()) { if (!ctn.getConst()) @@ -810,7 +814,7 @@ Node SequencesRewriter::rewriteConcatRegExp(TNode node) // e.g. this ensures we rewrite (a)* ++ (_)* ---> (_)* while (!cvec.empty() && RegExpEntail::isConstRegExp(cvec.back()) && RegExpEntail::testConstStringInRegExp( - emptyStr, 0, cvec.back())) + emptyStr, 0, cvec.back())) { cvec.pop_back(); } @@ -1337,8 +1341,8 @@ Node SequencesRewriter::rewriteMembership(TNode node) RewriteResponse SequencesRewriter::postRewrite(TNode node) { - Trace("strings-postrewrite") - << "Strings::postRewrite start " << node << std::endl; + Trace("sequences-postrewrite") + << "Strings::SequencesRewriter::postRewrite start " << node << std::endl; Node retNode = node; Kind nk = node.getKind(); if (nk == kind::STRING_CONCAT) @@ -1365,14 +1369,6 @@ RewriteResponse SequencesRewriter::postRewrite(TNode node) { retNode = rewriteContains(node); } - else if (nk == kind::STRING_LT) - { - retNode = StringsRewriter::rewriteStringLt(node); - } - else if (nk == kind::STRING_LEQ) - { - retNode = StringsRewriter::rewriteStringLeq(node); - } else if (nk == kind::STRING_STRIDOF) { retNode = rewriteIndexof(node); @@ -1385,10 +1381,6 @@ RewriteResponse SequencesRewriter::postRewrite(TNode node) { retNode = rewriteReplaceAll(node); } - else if (nk == STRING_TOLOWER || nk == STRING_TOUPPER) - { - retNode = StringsRewriter::rewriteStrConvert(node); - } else if (nk == STRING_REV) { retNode = rewriteStrReverse(node); @@ -1397,30 +1389,10 @@ RewriteResponse SequencesRewriter::postRewrite(TNode node) { retNode = rewritePrefixSuffix(node); } - else if (nk == STRING_IS_DIGIT) - { - retNode = StringsRewriter::rewriteStringIsDigit(node); - } - else if (nk == kind::STRING_ITOS) - { - retNode = StringsRewriter::rewriteIntToStr(node); - } - else if (nk == kind::STRING_STOI) - { - retNode = StringsRewriter::rewriteStrToInt(node); - } else if (nk == kind::STRING_IN_REGEXP) { retNode = rewriteMembership(node); } - else if (nk == STRING_TO_CODE) - { - retNode = StringsRewriter::rewriteStringToCode(node); - } - else if (nk == STRING_FROM_CODE) - { - retNode = StringsRewriter::rewriteStringFromCode(node); - } else if (nk == REGEXP_CONCAT) { retNode = rewriteConcatRegExp(node); @@ -1458,12 +1430,13 @@ RewriteResponse SequencesRewriter::postRewrite(TNode node) retNode = rewriteRepeatRegExp(node); } - Trace("strings-postrewrite") - << "Strings::postRewrite returning " << retNode << std::endl; + Trace("sequences-postrewrite") + << "Strings::SequencesRewriter::postRewrite returning " << retNode + << std::endl; if (node != retNode) { - Trace("strings-rewrite-debug") - << "Strings: post-rewrite " << node << " to " << retNode << std::endl; + Trace("strings-rewrite-debug") << "Strings::SequencesRewriter::postRewrite " + << node << " to " << retNode << std::endl; return RewriteResponse(REWRITE_AGAIN_FULL, retNode); } return RewriteResponse(REWRITE_DONE, retNode); @@ -1866,7 +1839,7 @@ Node SequencesRewriter::rewriteContains(Node node) } else if (node[0].getKind() == STRING_STRREPL) { - Node rplDomain = StringsEntail::checkContains(node[0][1], node[1]); + Node rplDomain = d_stringsEntail.checkContains(node[0][1], node[1]); if (!rplDomain.isNull() && !rplDomain.getConst()) { Node d1 = nm->mkNode(STRING_STRCTN, node[0][0], node[1]); @@ -1892,7 +1865,7 @@ Node SequencesRewriter::rewriteContains(Node node) // component-wise containment std::vector nc1rb; std::vector nc1re; - if (StringsEntail::componentContains(nc1, nc2, nc1rb, nc1re) != -1) + if (d_stringsEntail.componentContains(nc1, nc2, nc1rb, nc1re) != -1) { Node ret = NodeManager::currentNM()->mkConst(true); return returnRewrite(node, ret, Rewrite::CTN_COMPONENT); @@ -1920,10 +1893,10 @@ Node SequencesRewriter::rewriteContains(Node node) // replacement does not change y. (str.contains x w) checks that if the // replacement changes anything in y, the w makes it impossible for it to // occur in x. - Node ctnConst = StringsEntail::checkContains(node[0], n[0]); + Node ctnConst = d_stringsEntail.checkContains(node[0], n[0]); if (!ctnConst.isNull() && !ctnConst.getConst()) { - Node ctnConst2 = StringsEntail::checkContains(node[0], n[2]); + Node ctnConst2 = d_stringsEntail.checkContains(node[0], n[2]); if (!ctnConst2.isNull() && !ctnConst2.getConst()) { Node res = nm->mkConst(false); @@ -2091,7 +2064,7 @@ Node SequencesRewriter::rewriteContains(Node node) // if (str.contains z w) ---> false and (str.len w) = 1 if (StringsEntail::checkLengthOne(node[1])) { - Node ctn = StringsEntail::checkContains(node[1], node[0][2]); + Node ctn = d_stringsEntail.checkContains(node[1], node[0][2]); if (!ctn.isNull() && !ctn.getConst()) { Node empty = nm->mkConst(String("")); @@ -2235,7 +2208,7 @@ Node SequencesRewriter::rewriteIndexof(Node node) fstr = Rewriter::rewrite(fstr); } - Node cmp_conr = StringsEntail::checkContains(fstr, node[1]); + Node cmp_conr = d_stringsEntail.checkContains(fstr, node[1]); Trace("strings-rewrite-debug") << "For " << node << ", check contains(" << fstr << ", " << node[1] << ")" << std::endl; Trace("strings-rewrite-debug") << "...got " << cmp_conr << std::endl; @@ -2250,7 +2223,7 @@ Node SequencesRewriter::rewriteIndexof(Node node) // past the first position in node[0] that contains node[1], we can drop std::vector nb; std::vector ne; - int cc = StringsEntail::componentContains( + int cc = d_stringsEntail.componentContains( children0, children1, nb, ne, true, 1); if (cc != -1 && !ne.empty()) { @@ -2445,14 +2418,14 @@ Node SequencesRewriter::rewriteReplace(Node node) // check if contains definitely does (or does not) hold Node cmp_con = nm->mkNode(kind::STRING_STRCTN, node[0], node[1]); Node cmp_conr = Rewriter::rewrite(cmp_con); - if (!StringsEntail::checkContains(node[0], node[1]).isNull()) + if (!d_stringsEntail.checkContains(node[0], node[1]).isNull()) { if (cmp_conr.getConst()) { // component-wise containment std::vector cb; std::vector ce; - int cc = StringsEntail::componentContains( + int cc = d_stringsEntail.componentContains( children0, children1, cb, ce, true, 1); if (cc != -1) { @@ -2673,7 +2646,7 @@ Node SequencesRewriter::rewriteReplace(Node node) return returnRewrite(node, node[0], Rewrite::REPL_REPL2_INV_ID); } bool dualReplIteSuccess = false; - Node cmp_con2 = StringsEntail::checkContains(node[1][0], node[1][2]); + Node cmp_con2 = d_stringsEntail.checkContains(node[1][0], node[1][2]); if (!cmp_con2.isNull() && !cmp_con2.getConst()) { // str.contains( x, z ) ---> false @@ -2688,10 +2661,10 @@ Node SequencesRewriter::rewriteReplace(Node node) // implies // str.replace( x, str.replace( x, y, z ), w ) ---> // ite( str.contains( x, y ), x, w ) - cmp_con2 = StringsEntail::checkContains(node[1][1], node[1][2]); + cmp_con2 = d_stringsEntail.checkContains(node[1][1], node[1][2]); if (!cmp_con2.isNull() && !cmp_con2.getConst()) { - cmp_con2 = StringsEntail::checkContains(node[1][2], node[1][1]); + cmp_con2 = d_stringsEntail.checkContains(node[1][2], node[1][1]); if (!cmp_con2.isNull() && !cmp_con2.getConst()) { dualReplIteSuccess = true; @@ -2721,7 +2694,7 @@ Node SequencesRewriter::rewriteReplace(Node node) // str.contains(y, z) ----> false and ( y == w or x == w ) implies // implies // str.replace(x, str.replace(y, x, z), w) ---> str.replace(x, y, w) - Node cmp_con2 = StringsEntail::checkContains(node[1][0], node[1][2]); + Node cmp_con2 = d_stringsEntail.checkContains(node[1][0], node[1][2]); invSuccess = !cmp_con2.isNull() && !cmp_con2.getConst(); } } @@ -2730,10 +2703,10 @@ Node SequencesRewriter::rewriteReplace(Node node) // str.contains(x, z) ----> false and str.contains(x, w) ----> false // implies // str.replace(x, str.replace(y, z, w), u) ---> str.replace(x, y, u) - Node cmp_con2 = StringsEntail::checkContains(node[0], node[1][1]); + Node cmp_con2 = d_stringsEntail.checkContains(node[0], node[1][1]); if (!cmp_con2.isNull() && !cmp_con2.getConst()) { - cmp_con2 = StringsEntail::checkContains(node[0], node[1][2]); + cmp_con2 = d_stringsEntail.checkContains(node[0], node[1][2]); invSuccess = !cmp_con2.isNull() && !cmp_con2.getConst(); } } @@ -2749,7 +2722,7 @@ Node SequencesRewriter::rewriteReplace(Node node) { // str.contains( z, w ) ----> false implies // str.replace( x, w, str.replace( z, x, y ) ) ---> str.replace( x, w, z ) - Node cmp_con2 = StringsEntail::checkContains(node[1], node[2][0]); + Node cmp_con2 = d_stringsEntail.checkContains(node[1], node[2][0]); if (!cmp_con2.isNull() && !cmp_con2.getConst()) { Node res = @@ -2769,7 +2742,7 @@ Node SequencesRewriter::rewriteReplace(Node node) { // str.contains( x, z ) ----> false implies // str.replace( x, y, str.replace( y, z, w ) ) ---> x - cmp_con = StringsEntail::checkContains(node[0], node[2][1]); + cmp_con = d_stringsEntail.checkContains(node[0], node[2][1]); success = !cmp_con.isNull() && !cmp_con.getConst(); } if (success) @@ -2795,7 +2768,7 @@ Node SequencesRewriter::rewriteReplace(Node node) checkLhs.end(), children0.begin(), children0.begin() + checkIndex); Node lhs = utils::mkConcat(checkLhs, stype); Node rhs = children0[checkIndex]; - Node ctn = StringsEntail::checkContains(lhs, rhs); + Node ctn = d_stringsEntail.checkContains(lhs, rhs); if (!ctn.isNull() && ctn.getConst()) { lastLhs = lhs; @@ -3102,6 +3075,11 @@ Node SequencesRewriter::returnRewrite(Node node, Node ret, Rewrite r) NodeManager* nm = NodeManager::currentNM(); + if (d_statistics != nullptr) + { + (*d_statistics) << r; + } + // standard post-processing // We rewrite (string) equalities immediately here. This allows us to forego // the standard invariant on equality rewrites (that s=t must rewrite to one diff --git a/src/theory/strings/sequences_rewriter.h b/src/theory/strings/sequences_rewriter.h index 7391a7bd0..56b74f536 100644 --- a/src/theory/strings/sequences_rewriter.h +++ b/src/theory/strings/sequences_rewriter.h @@ -22,6 +22,8 @@ #include "expr/node.h" #include "theory/strings/rewrites.h" +#include "theory/strings/sequences_stats.h" +#include "theory/strings/strings_entail.h" #include "theory/theory_rewriter.h" namespace CVC4 { @@ -30,82 +32,85 @@ namespace strings { class SequencesRewriter : public TheoryRewriter { + public: + SequencesRewriter(HistogramStat* statistics); + protected: /** rewrite regular expression concatenation * * This is the entry point for post-rewriting applications of re.++. * Returns the rewritten form of node. */ - static Node rewriteConcatRegExp(TNode node); + Node rewriteConcatRegExp(TNode node); /** rewrite regular expression star * * This is the entry point for post-rewriting applications of re.*. * Returns the rewritten form of node. */ - static Node rewriteStarRegExp(TNode node); + Node rewriteStarRegExp(TNode node); /** rewrite regular expression intersection/union * * This is the entry point for post-rewriting applications of re.inter and * re.union. Returns the rewritten form of node. */ - static Node rewriteAndOrRegExp(TNode node); + Node rewriteAndOrRegExp(TNode node); /** rewrite regular expression loop * * This is the entry point for post-rewriting applications of re.loop. * Returns the rewritten form of node. */ - static Node rewriteLoopRegExp(TNode node); + Node rewriteLoopRegExp(TNode node); /** rewrite regular expression repeat * * This is the entry point for post-rewriting applications of re.repeat. * Returns the rewritten form of node. */ - static Node rewriteRepeatRegExp(TNode node); + Node rewriteRepeatRegExp(TNode node); /** rewrite regular expression option * * This is the entry point for post-rewriting applications of re.opt. * Returns the rewritten form of node. */ - static Node rewriteOptionRegExp(TNode node); + Node rewriteOptionRegExp(TNode node); /** rewrite regular expression plus * * This is the entry point for post-rewriting applications of re.+. * Returns the rewritten form of node. */ - static Node rewritePlusRegExp(TNode node); + Node rewritePlusRegExp(TNode node); /** rewrite regular expression difference * * This is the entry point for post-rewriting applications of re.diff. * Returns the rewritten form of node. */ - static Node rewriteDifferenceRegExp(TNode node); + Node rewriteDifferenceRegExp(TNode node); /** rewrite regular expression range * * This is the entry point for post-rewriting applications of re.range. * Returns the rewritten form of node. */ - static Node rewriteRangeRegExp(TNode node); + Node rewriteRangeRegExp(TNode node); /** rewrite regular expression membership * * This is the entry point for post-rewriting applications of str.in.re * Returns the rewritten form of node. */ - static Node rewriteMembership(TNode node); + Node rewriteMembership(TNode node); /** rewrite string equality extended * * This method returns a formula that is equivalent to the equality between * two strings s = t, given by node. It is called by rewriteEqualityExt. */ - static Node rewriteStrEqualityExt(Node node); + Node rewriteStrEqualityExt(Node node); /** rewrite arithmetic equality extended * * This method returns a formula that is equivalent to the equality between * two arithmetic string terms s = t, given by node. t is called by * rewriteEqualityExt. */ - static Node rewriteArithEqualityExt(Node node); + Node rewriteArithEqualityExt(Node node); /** * Called when node rewrites to ret. * @@ -117,7 +122,7 @@ class SequencesRewriter : public TheoryRewriter * additional rewrites on ret, after which we return the result of this call. * Otherwise, this method simply returns ret. */ - static Node returnRewrite(Node node, Node ret, Rewrite r); + Node returnRewrite(Node node, Node ret, Rewrite r); public: RewriteResponse postRewrite(TNode node) override; @@ -129,7 +134,7 @@ class SequencesRewriter : public TheoryRewriter * two strings s = t, given by node. The result of rewrite is one of * { s = t, t = s, true, false }. */ - static Node rewriteEquality(Node node); + Node rewriteEquality(Node node); /** rewrite equality extended * * This method returns a formula that is equivalent to the equality between @@ -140,31 +145,31 @@ class SequencesRewriter : public TheoryRewriter * Specifically, this function performs rewrites whose conclusion is not * necessarily one of { s = t, t = s, true, false }. */ - static Node rewriteEqualityExt(Node node); + Node rewriteEqualityExt(Node node); /** rewrite string length * This is the entry point for post-rewriting terms node of the form * str.len( t ) * Returns the rewritten form of node. */ - static Node rewriteLength(Node node); + Node rewriteLength(Node node); /** rewrite concat * This is the entry point for post-rewriting terms node of the form * str.++( t1, .., tn ) * Returns the rewritten form of node. */ - static Node rewriteConcat(Node node); + Node rewriteConcat(Node node); /** rewrite character at * This is the entry point for post-rewriting terms node of the form * str.charat( s, i1 ) * Returns the rewritten form of node. */ - static Node rewriteCharAt(Node node); + Node rewriteCharAt(Node node); /** rewrite substr * This is the entry point for post-rewriting terms node of the form * str.substr( s, i1, i2 ) * Returns the rewritten form of node. */ - static Node rewriteSubstr(Node node); + Node rewriteSubstr(Node node); /** rewrite contains * This is the entry point for post-rewriting terms node of the form * str.contains( t, s ) @@ -174,51 +179,51 @@ class SequencesRewriter : public TheoryRewriter * 7 of Reynolds et al "Scaling Up DPLL(T) String Solvers Using * Context-Dependent Rewriting", CAV 2017. */ - static Node rewriteContains(Node node); + Node rewriteContains(Node node); /** rewrite indexof * This is the entry point for post-rewriting terms n of the form * str.indexof( s, t, n ) * Returns the rewritten form of node. */ - static Node rewriteIndexof(Node node); + Node rewriteIndexof(Node node); /** rewrite replace * This is the entry point for post-rewriting terms n of the form * str.replace( s, t, r ) * Returns the rewritten form of node. */ - static Node rewriteReplace(Node node); + Node rewriteReplace(Node node); /** rewrite replace all * This is the entry point for post-rewriting terms n of the form * str.replaceall( s, t, r ) * Returns the rewritten form of node. */ - static Node rewriteReplaceAll(Node node); + Node rewriteReplaceAll(Node node); /** rewrite replace internal * * This method implements rewrite rules that apply to both str.replace and * str.replaceall. If it returns a non-null ret, then node rewrites to ret. */ - static Node rewriteReplaceInternal(Node node); + Node rewriteReplaceInternal(Node node); /** rewrite string reverse * * This is the entry point for post-rewriting terms n of the form * str.rev( s ) * Returns the rewritten form of node. */ - static Node rewriteStrReverse(Node node); + Node rewriteStrReverse(Node node); /** rewrite prefix/suffix * This is the entry point for post-rewriting terms n of the form * str.prefixof( s, t ) / str.suffixof( s, t ) * Returns the rewritten form of node. */ - static Node rewritePrefixSuffix(Node node); + Node rewritePrefixSuffix(Node node); /** rewrite str.to_code * This is the entry point for post-rewriting terms n of the form * str.to_code( t ) * Returns the rewritten form of node. */ - static Node rewriteStringToCode(Node node); + Node rewriteStringToCode(Node node); /** length preserving rewrite * @@ -235,6 +240,12 @@ class SequencesRewriter : public TheoryRewriter * string exists. */ static Node canonicalStrForSymbolicLength(Node n, TypeNode stype); + + /** Reference to the rewriter statistics. */ + HistogramStat* d_statistics; + + /** Instance of the entailment checker for strings. */ + StringsEntail d_stringsEntail; }; /* class SequencesRewriter */ } // namespace strings diff --git a/src/theory/strings/sequences_stats.cpp b/src/theory/strings/sequences_stats.cpp index 5cd844290..afcfb1a60 100644 --- a/src/theory/strings/sequences_stats.cpp +++ b/src/theory/strings/sequences_stats.cpp @@ -27,6 +27,7 @@ SequencesStatistics::SequencesStatistics() d_reductions("theory::strings::reductions"), d_regexpUnfoldingsPos("theory::strings::regexpUnfoldingsPos"), d_regexpUnfoldingsNeg("theory::strings::regexpUnfoldingsNeg"), + d_rewrites("theory::strings::rewrites"), d_conflictsEqEngine("theory::strings::conflictsEqEngine", 0), d_conflictsEagerPrefix("theory::strings::conflictsEagerPrefix", 0), d_conflictsInfer("theory::strings::conflictsInfer", 0), @@ -42,6 +43,7 @@ SequencesStatistics::SequencesStatistics() smtStatisticsRegistry()->registerStat(&d_reductions); smtStatisticsRegistry()->registerStat(&d_regexpUnfoldingsPos); smtStatisticsRegistry()->registerStat(&d_regexpUnfoldingsNeg); + smtStatisticsRegistry()->registerStat(&d_rewrites); smtStatisticsRegistry()->registerStat(&d_conflictsEqEngine); smtStatisticsRegistry()->registerStat(&d_conflictsEagerPrefix); smtStatisticsRegistry()->registerStat(&d_conflictsInfer); @@ -59,6 +61,7 @@ SequencesStatistics::~SequencesStatistics() smtStatisticsRegistry()->unregisterStat(&d_reductions); smtStatisticsRegistry()->unregisterStat(&d_regexpUnfoldingsPos); smtStatisticsRegistry()->unregisterStat(&d_regexpUnfoldingsNeg); + smtStatisticsRegistry()->unregisterStat(&d_rewrites); smtStatisticsRegistry()->unregisterStat(&d_conflictsEqEngine); smtStatisticsRegistry()->unregisterStat(&d_conflictsEagerPrefix); smtStatisticsRegistry()->unregisterStat(&d_conflictsInfer); diff --git a/src/theory/strings/sequences_stats.h b/src/theory/strings/sequences_stats.h index 65f50dbbc..63d9f55eb 100644 --- a/src/theory/strings/sequences_stats.h +++ b/src/theory/strings/sequences_stats.h @@ -19,6 +19,7 @@ #include "expr/kind.h" #include "theory/strings/infer_info.h" +#include "theory/strings/rewrites.h" #include "util/statistics_registry.h" namespace CVC4 { @@ -77,6 +78,8 @@ class SequencesStatistics HistogramStat d_regexpUnfoldingsPos; HistogramStat d_regexpUnfoldingsNeg; //--------------- end of inferences + /** Counts the number of applications of each type of rewrite rule */ + HistogramStat d_rewrites; //--------------- conflicts, partition of calls to OutputChannel::conflict /** Number of equality engine conflicts */ IntStat d_conflictsEqEngine; diff --git a/src/theory/strings/strings_entail.cpp b/src/theory/strings/strings_entail.cpp index a1abfabe5..99219af82 100644 --- a/src/theory/strings/strings_entail.cpp +++ b/src/theory/strings/strings_entail.cpp @@ -27,6 +27,10 @@ namespace CVC4 { namespace theory { namespace strings { +StringsEntail::StringsEntail(SequencesRewriter& rewriter) : d_rewriter(rewriter) +{ +} + bool StringsEntail::canConstantContainConcat(Node c, Node n, int& firstc, @@ -468,10 +472,10 @@ bool StringsEntail::componentContainsBase( { // (str.contains (str.replace x y z) w) ---> true // if (str.contains x w) --> true and (str.contains z w) ---> true - Node xCtnW = StringsEntail::checkContains(n1[0], n2); + Node xCtnW = checkContains(n1[0], n2); if (!xCtnW.isNull() && xCtnW.getConst()) { - Node zCtnW = StringsEntail::checkContains(n1[2], n2); + Node zCtnW = checkContains(n1[2], n2); if (!zCtnW.isNull() && zCtnW.getConst()) { return true; @@ -680,7 +684,7 @@ Node StringsEntail::checkContains(Node a, Node b, bool fullRewriter) do { prev = ctn; - ctn = SequencesRewriter::rewriteContains(ctn); + ctn = d_rewriter.rewriteContains(ctn); } while (prev != ctn && ctn.getKind() == STRING_STRCTN); } diff --git a/src/theory/strings/strings_entail.h b/src/theory/strings/strings_entail.h index d4993faf4..379c09043 100644 --- a/src/theory/strings/strings_entail.h +++ b/src/theory/strings/strings_entail.h @@ -25,6 +25,8 @@ namespace CVC4 { namespace theory { namespace strings { +class SequencesRewriter; + /** * Entailment tests involving strings. * Some of these techniques are described in Reynolds et al, "High Level @@ -33,6 +35,8 @@ namespace strings { class StringsEntail { public: + StringsEntail(SequencesRewriter& rewriter); + /** can constant contain list * return true if constant c can contain the list l in order * firstc/lastc store which indices in l were used to determine the return @@ -153,12 +157,12 @@ class StringsEntail * n1 is updated to { "c", x, "def" }, * nb is updated to { y, "ab" } */ - static int componentContains(std::vector& n1, - std::vector& n2, - std::vector& nb, - std::vector& ne, - bool computeRemainder = false, - int remainderDir = 0); + int componentContains(std::vector& n1, + std::vector& n2, + std::vector& nb, + std::vector& ne, + bool computeRemainder = false, + int remainderDir = 0); /** strip constant endpoints * This function is used when rewriting str.contains( t1, t2 ), where * n1 is the vector form of t1 @@ -208,7 +212,7 @@ class StringsEntail * @return true node if it can be shown that `a` contains `b`, false node if * it can be shown that `a` does not contain `b`, null node otherwise */ - static Node checkContains(Node a, Node b, bool fullRewriter = true); + Node checkContains(Node a, Node b, bool fullRewriter = true); /** entail non-empty * @@ -346,7 +350,7 @@ class StringsEntail * Since we do not wish to introduce ITE terms in the rewriter, we instead * return false, indicating that we cannot compute the remainder. */ - static bool componentContainsBase( + bool componentContainsBase( Node n1, Node n2, Node& n1rb, Node& n1re, int dir, bool computeRemainder); /** * Simplifies a given node `a` s.t. the result is a concatenation of string @@ -362,6 +366,13 @@ class StringsEntail * @return A concatenation that can be interpreted as a multiset */ static Node getMultisetApproximation(Node a); + + private: + /** + * Reference to the sequences rewriter that owns this `StringsEntail` + * instance. + */ + SequencesRewriter& d_rewriter; }; } // namespace strings diff --git a/src/theory/strings/strings_rewriter.cpp b/src/theory/strings/strings_rewriter.cpp index 28ed14095..f27a19065 100644 --- a/src/theory/strings/strings_rewriter.cpp +++ b/src/theory/strings/strings_rewriter.cpp @@ -25,6 +25,67 @@ namespace CVC4 { namespace theory { namespace strings { +StringsRewriter::StringsRewriter(HistogramStat* statistics) + : SequencesRewriter(statistics) +{ +} + +RewriteResponse StringsRewriter::postRewrite(TNode node) +{ + Trace("strings-postrewrite") + << "Strings::StringsRewriter::postRewrite start " << node << std::endl; + + Node retNode = node; + Kind nk = node.getKind(); + if (nk == kind::STRING_LT) + { + retNode = rewriteStringLt(node); + } + else if (nk == kind::STRING_LEQ) + { + retNode = rewriteStringLeq(node); + } + else if (nk == STRING_TOLOWER || nk == STRING_TOUPPER) + { + retNode = rewriteStrConvert(node); + } + else if (nk == STRING_IS_DIGIT) + { + retNode = rewriteStringIsDigit(node); + } + else if (nk == kind::STRING_ITOS) + { + retNode = rewriteIntToStr(node); + } + else if (nk == kind::STRING_STOI) + { + retNode = rewriteStrToInt(node); + } + else if (nk == STRING_TO_CODE) + { + retNode = rewriteStringToCode(node); + } + else if (nk == STRING_FROM_CODE) + { + retNode = rewriteStringFromCode(node); + } + else + { + return SequencesRewriter::postRewrite(node); + } + + Trace("strings-postrewrite") + << "Strings::StringsRewriter::postRewrite returning " << retNode + << std::endl; + if (node != retNode) + { + Trace("strings-rewrite-debug") << "Strings::StringsRewriter::postRewrite " + << node << " to " << retNode << std::endl; + return RewriteResponse(REWRITE_AGAIN_FULL, retNode); + } + return RewriteResponse(REWRITE_DONE, retNode); +} + Node StringsRewriter::rewriteStrToInt(Node node) { Assert(node.getKind() == STRING_STOI); diff --git a/src/theory/strings/strings_rewriter.h b/src/theory/strings/strings_rewriter.h index 0c5b0b2f8..ce4be476d 100644 --- a/src/theory/strings/strings_rewriter.h +++ b/src/theory/strings/strings_rewriter.h @@ -32,13 +32,17 @@ namespace strings { class StringsRewriter : public SequencesRewriter { public: + StringsRewriter(HistogramStat* statistics); + + RewriteResponse postRewrite(TNode node) override; + /** rewrite string to integer * * This is the entry point for post-rewriting terms n of the form * str.to_int( s ) * Returns the rewritten form of n. */ - static Node rewriteStrToInt(Node n); + Node rewriteStrToInt(Node n); /** rewrite integer to string * @@ -46,7 +50,7 @@ class StringsRewriter : public SequencesRewriter * str.from_int( i ) * Returns the rewritten form of n. */ - static Node rewriteIntToStr(Node n); + Node rewriteIntToStr(Node n); /** rewrite string convert * @@ -54,7 +58,7 @@ class StringsRewriter : public SequencesRewriter * str.tolower( s ) and str.toupper( s ) * Returns the rewritten form of n. */ - static Node rewriteStrConvert(Node n); + Node rewriteStrConvert(Node n); /** rewrite string less than * @@ -62,7 +66,7 @@ class StringsRewriter : public SequencesRewriter * str.<( t, s ) * Returns the rewritten form of n. */ - static Node rewriteStringLt(Node n); + Node rewriteStringLt(Node n); /** rewrite string less than or equal * @@ -70,7 +74,7 @@ class StringsRewriter : public SequencesRewriter * str.<=( t, s ) * Returns the rewritten form of n. */ - static Node rewriteStringLeq(Node n); + Node rewriteStringLeq(Node n); /** rewrite str.from_code * @@ -78,7 +82,7 @@ class StringsRewriter : public SequencesRewriter * str.from_code( t ) * Returns the rewritten form of n. */ - static Node rewriteStringFromCode(Node n); + Node rewriteStringFromCode(Node n); /** rewrite str.to_code * @@ -86,7 +90,7 @@ class StringsRewriter : public SequencesRewriter * str.to_code( t ) * Returns the rewritten form of n. */ - static Node rewriteStringToCode(Node n); + Node rewriteStringToCode(Node n); /** rewrite is digit * @@ -94,7 +98,7 @@ class StringsRewriter : public SequencesRewriter * str.is_digit( t ) * Returns the rewritten form of n. */ - static Node rewriteStringIsDigit(Node n); + Node rewriteStringIsDigit(Node n); }; } // namespace strings diff --git a/src/theory/strings/theory_strings.cpp b/src/theory/strings/theory_strings.cpp index d5eb2dbbd..d74a0e9ca 100644 --- a/src/theory/strings/theory_strings.cpp +++ b/src/theory/strings/theory_strings.cpp @@ -26,7 +26,6 @@ #include "smt/smt_statistics_registry.h" #include "theory/ext_theory.h" #include "theory/rewriter.h" -#include "theory/strings/sequences_rewriter.h" #include "theory/strings/theory_strings_utils.h" #include "theory/strings/type_enumerator.h" #include "theory/strings/word.h" @@ -70,6 +69,7 @@ TheoryStrings::TheoryStrings(context::Context* c, const LogicInfo& logicInfo) : Theory(THEORY_STRINGS, c, u, out, valuation, logicInfo), d_notify(*this), + d_statistics(), d_equalityEngine(d_notify, c, "theory::strings::ee", true), d_state(c, d_equalityEngine, d_valuation), d_im(*this, c, u, d_state, d_sk_cache, out, d_statistics), @@ -77,6 +77,7 @@ TheoryStrings::TheoryStrings(context::Context* c, d_registered_terms_cache(u), d_functionsTerms(c), d_has_str_code(false), + d_rewriter(&d_statistics.d_rewrites), d_bsolver(c, u, d_state, d_im), d_csolver(c, u, d_state, d_im, d_sk_cache, d_bsolver), d_esolver(nullptr), @@ -91,6 +92,7 @@ TheoryStrings::TheoryStrings(context::Context* c, d_state, d_im, d_sk_cache, + d_rewriter, d_bsolver, d_csolver, extt, @@ -131,11 +133,6 @@ TheoryStrings::~TheoryStrings() { } -std::unique_ptr TheoryStrings::mkTheoryRewriter() -{ - return std::unique_ptr(new SequencesRewriter()); -} - bool TheoryStrings::areCareDisequal( TNode x, TNode y ) { Assert(d_equalityEngine.hasTerm(x)); Assert(d_equalityEngine.hasTerm(y)); diff --git a/src/theory/strings/theory_strings.h b/src/theory/strings/theory_strings.h index 0e95628bc..5ae0ac7a9 100644 --- a/src/theory/strings/theory_strings.h +++ b/src/theory/strings/theory_strings.h @@ -39,6 +39,7 @@ #include "theory/strings/skolem_cache.h" #include "theory/strings/solver_state.h" #include "theory/strings/strings_fmf.h" +#include "theory/strings/strings_rewriter.h" #include "theory/theory.h" #include "theory/uf/equality_engine.h" @@ -109,7 +110,7 @@ class TheoryStrings : public Theory { const LogicInfo& logicInfo); ~TheoryStrings(); - std::unique_ptr mkTheoryRewriter() override; + TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } void setMasterEqualityEngine(eq::EqualityEngine* eq) override; @@ -352,6 +353,9 @@ private: // Symbolic Regular Expression private: + /** The theory rewriter for this theory. */ + StringsRewriter d_rewriter; + /** * The base solver, responsible for reasoning about congruent terms and * inferring constants for equivalence classes. diff --git a/src/theory/strings/theory_strings_preprocess.cpp b/src/theory/strings/theory_strings_preprocess.cpp index 097cef235..5fc13f023 100644 --- a/src/theory/strings/theory_strings_preprocess.cpp +++ b/src/theory/strings/theory_strings_preprocess.cpp @@ -79,7 +79,7 @@ Node StringsPreprocess::simplify( Node t, std::vector< Node > &new_nodes ) { Node sk2 = ArithEntail::check(t12, lt0) ? emp : d_sc->mkSkolemCached( - s, t12, SkolemCache::SK_SUFFIX_REM, "sssufr"); + s, t12, SkolemCache::SK_SUFFIX_REM, "sssufr"); Node b11 = s.eqNode(nm->mkNode(STRING_CONCAT, sk1, skt, sk2)); //length of first skolem is second argument Node b12 = nm->mkNode(STRING_LENGTH, sk1).eqNode(n); diff --git a/src/theory/theory.cpp b/src/theory/theory.cpp index a159787f9..635a3216a 100644 --- a/src/theory/theory.cpp +++ b/src/theory/theory.cpp @@ -28,6 +28,7 @@ #include "theory/ext_theory.h" #include "theory/quantifiers_engine.h" #include "theory/substitutions.h" +#include "theory/theory_rewriter.h" using namespace std; diff --git a/src/theory/theory.h b/src/theory/theory.h index 63ca46b41..a6751e1ec 100644 --- a/src/theory/theory.h +++ b/src/theory/theory.h @@ -56,6 +56,7 @@ class QuantifiersEngine; class TheoryModel; class SubstitutionMap; class ExtTheory; +class TheoryRewriter; class EntailmentCheckParameters; class EntailmentCheckSideEffects; @@ -79,9 +80,7 @@ namespace eq { * all calls to them.) */ class Theory { - -private: - + private: friend class ::CVC4::TheoryEngine; // Disallow default construction, copy, assignment. @@ -140,7 +139,6 @@ private: protected: - // === STATISTICS === /** time spent in check calls */ TimerStat d_checkTime; @@ -318,9 +316,9 @@ public: virtual ~Theory(); /** - * Creates a new theory rewriter for the theory. + * @return The theory rewriter associated with this theory. */ - virtual std::unique_ptr mkTheoryRewriter() = 0; + virtual TheoryRewriter* getTheoryRewriter() = 0; /** * Subclasses of Theory may add additional efforts. DO NOT CHECK diff --git a/src/theory/theory_engine.h b/src/theory/theory_engine.h index dec654e76..809ef5139 100644 --- a/src/theory/theory_engine.h +++ b/src/theory/theory_engine.h @@ -495,7 +495,7 @@ class TheoryEngine { theory::Valuation(this), d_logicInfo); theory::Rewriter::registerTheoryRewriter( - theoryId, d_theoryTable[theoryId]->mkTheoryRewriter()); + theoryId, d_theoryTable[theoryId]->getTheoryRewriter()); } void setPropEngine(prop::PropEngine* propEngine) diff --git a/src/theory/uf/theory_uf.cpp b/src/theory/uf/theory_uf.cpp index 1ea5449b7..3b42fa6a1 100644 --- a/src/theory/uf/theory_uf.cpp +++ b/src/theory/uf/theory_uf.cpp @@ -66,11 +66,6 @@ TheoryUF::TheoryUF(context::Context* c, TheoryUF::~TheoryUF() { } -std::unique_ptr TheoryUF::mkTheoryRewriter() -{ - return std::unique_ptr(new TheoryUfRewriter()); -} - void TheoryUF::setMasterEqualityEngine(eq::EqualityEngine* eq) { d_equalityEngine.setMasterEqualityEngine(eq); } diff --git a/src/theory/uf/theory_uf.h b/src/theory/uf/theory_uf.h index 623c5c64b..50b7a65cb 100644 --- a/src/theory/uf/theory_uf.h +++ b/src/theory/uf/theory_uf.h @@ -26,6 +26,7 @@ #include "theory/theory.h" #include "theory/uf/equality_engine.h" #include "theory/uf/symmetry_breaker.h" +#include "theory/uf/theory_uf_rewriter.h" namespace CVC4 { namespace theory { @@ -188,7 +189,7 @@ private: ~TheoryUF(); - std::unique_ptr mkTheoryRewriter() override; + TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } void setMasterEqualityEngine(eq::EqualityEngine* eq) override; void finishInit() override; @@ -225,6 +226,8 @@ private: TNodeTrie* t2, unsigned arity, unsigned depth); + + TheoryUfRewriter d_rewriter; };/* class TheoryUF */ }/* CVC4::theory::uf namespace */ diff --git a/test/unit/theory/sequences_rewriter_white.h b/test/unit/theory/sequences_rewriter_white.h index 4cc679ca8..7e45296a9 100644 --- a/test/unit/theory/sequences_rewriter_white.h +++ b/test/unit/theory/sequences_rewriter_white.h @@ -14,6 +14,12 @@ ** Unit tests for the strings/sequences rewriter. **/ +#include + +#include +#include +#include + #include "expr/node.h" #include "expr/node_manager.h" #include "smt/smt_engine.h" @@ -23,11 +29,7 @@ #include "theory/strings/arith_entail.h" #include "theory/strings/sequences_rewriter.h" #include "theory/strings/strings_entail.h" - -#include -#include -#include -#include +#include "theory/strings/strings_rewriter.h" using namespace CVC4; using namespace CVC4::smt; @@ -246,7 +248,7 @@ class SequencesRewriterWhite : public CxxTest::TestSuite // (str.substr "A" x x) --> "" Node n = d_nm->mkNode(kind::STRING_SUBSTR, a, x, x); - Node res = SequencesRewriter::rewriteSubstr(n); + Node res = StringsRewriter(nullptr).rewriteSubstr(n); TS_ASSERT_EQUALS(res, empty); // (str.substr "A" (+ x 1) x) -> "" @@ -254,7 +256,7 @@ class SequencesRewriterWhite : public CxxTest::TestSuite a, d_nm->mkNode(kind::PLUS, x, d_nm->mkConst(Rational(1))), x); - res = SequencesRewriter::rewriteSubstr(n); + res = StringsRewriter(nullptr).rewriteSubstr(n); TS_ASSERT_EQUALS(res, empty); // (str.substr "A" (+ x (str.len s2)) x) -> "" @@ -263,24 +265,24 @@ class SequencesRewriterWhite : public CxxTest::TestSuite a, d_nm->mkNode(kind::PLUS, x, d_nm->mkNode(kind::STRING_LENGTH, s)), x); - res = SequencesRewriter::rewriteSubstr(n); + res = StringsRewriter(nullptr).rewriteSubstr(n); TS_ASSERT_EQUALS(res, empty); // (str.substr "A" x y) -> (str.substr "A" x y) n = d_nm->mkNode(kind::STRING_SUBSTR, a, x, y); - res = SequencesRewriter::rewriteSubstr(n); + res = StringsRewriter(nullptr).rewriteSubstr(n); TS_ASSERT_EQUALS(res, n); // (str.substr "ABCD" (+ x 3) x) -> "" n = d_nm->mkNode( kind::STRING_SUBSTR, abcd, d_nm->mkNode(kind::PLUS, x, three), x); - res = SequencesRewriter::rewriteSubstr(n); + res = StringsRewriter(nullptr).rewriteSubstr(n); TS_ASSERT_EQUALS(res, empty); // (str.substr "ABCD" (+ x 2) x) -> (str.substr "ABCD" (+ x 2) x) n = d_nm->mkNode( kind::STRING_SUBSTR, abcd, d_nm->mkNode(kind::PLUS, x, two), x); - res = SequencesRewriter::rewriteSubstr(n); + res = StringsRewriter(nullptr).rewriteSubstr(n); TS_ASSERT_EQUALS(res, n); // (str.substr (str.substr s x x) x x) -> "" diff --git a/test/unit/theory/theory_engine_white.h b/test/unit/theory/theory_engine_white.h index 4a019ac08..992251f16 100644 --- a/test/unit/theory/theory_engine_white.h +++ b/test/unit/theory/theory_engine_white.h @@ -113,18 +113,6 @@ class FakeTheoryRewriter : public TheoryRewriter template class FakeTheory : public Theory { - /** - * This fake theory class is equally useful for bool, uf, arith, etc. It - * keeps an identifier to identify itself. - */ - std::string d_id; - - /** - * The expected sequence of rewrite calls. Filled by FakeTheory::expect() and - * consumed by FakeTheory::preRewrite() and FakeTheory::postRewrite(). - */ - // static std::deque s_expected; - public: FakeTheory(context::Context* ctxt, context::UserContext* uctxt, @@ -135,10 +123,7 @@ class FakeTheory : public Theory { } - std::unique_ptr mkTheoryRewriter() override - { - return std::unique_ptr(new FakeTheoryRewriter()); - } + TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } /** Register an expected rewrite call */ static void expect(RewriteType type, @@ -176,6 +161,16 @@ class FakeTheory : public Theory return Node::null(); } Node getValue(TNode n) { return Node::null(); } + + private: + /** + * This fake theory class is equally useful for bool, uf, arith, etc. It + * keeps an identifier to identify itself. + */ + std::string d_id; + + /** The theory rewriter for this theory. */ + FakeTheoryRewriter d_rewriter; }; /* class FakeTheory */ /* definition of the s_expected static field in FakeTheory; see above */ diff --git a/test/unit/theory/theory_white.h b/test/unit/theory/theory_white.h index eb43e00cb..0d5238aa9 100644 --- a/test/unit/theory/theory_white.h +++ b/test/unit/theory/theory_white.h @@ -102,10 +102,7 @@ class DummyTheory : public Theory { : Theory(theory::THEORY_BUILTIN, ctxt, uctxt, out, valuation, logicInfo) {} - std::unique_ptr mkTheoryRewriter() - { - return std::unique_ptr(); - } + TheoryRewriter* getTheoryRewriter() { return nullptr; } void registerTerm(TNode n) { // check that we registerTerm() a term only once