From d6b3329e3f2b6e29e5f4af6cf09fd32e26c47e15 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 8 Sep 2021 12:30:31 -0500 Subject: [PATCH] Towards standard usage of ExtendedRewriter (#7145) This PR: Adds extendedRewrite to EnvObj and Rewriter. Eliminates static calls to Rewriter::rewrite from within the extended rewriter. Instead, the use of extended rewriter is always through Rewriter, which passes itself to the ExtendedRewriter. Make most uses of extended rewriter non-static. I've added a placeholder method Rewriter::callExtendedRewrite for places in the code that call the extended rewriter are currently difficult to eliminate. --- .../passes/extended_rewriter_pass.cpp | 6 +-- src/smt/env_obj.cpp | 5 +++ src/smt/env_obj.h | 5 +++ src/smt/preprocess_proof_generator.cpp | 3 +- src/smt/quant_elim_solver.cpp | 11 ++--- src/smt/quant_elim_solver.h | 7 +--- src/theory/arith/nl/cad/cdcac.cpp | 9 ++--- src/theory/arith/nl/cad/cdcac.h | 6 +-- src/theory/builtin/proof_checker.h | 4 -- src/theory/datatypes/sygus_extension.cpp | 8 ++-- src/theory/datatypes/sygus_extension.h | 6 ++- src/theory/datatypes/theory_datatypes.cpp | 2 +- .../candidate_rewrite_database.cpp | 16 ++++---- .../quantifiers/candidate_rewrite_database.h | 8 ++-- src/theory/quantifiers/expr_miner_manager.cpp | 2 +- src/theory/quantifiers/expr_miner_manager.h | 6 --- src/theory/quantifiers/extended_rewrite.cpp | 21 +++++----- src/theory/quantifiers/extended_rewrite.h | 7 +++- .../quantifiers/quantifiers_modules.cpp | 5 ++- src/theory/quantifiers/quantifiers_modules.h | 3 +- .../quantifiers/quantifiers_rewriter.cpp | 3 +- .../sygus/ce_guided_single_inv.cpp | 4 +- src/theory/quantifiers/sygus/cegis.cpp | 2 +- .../sygus/enum_stream_substitution.cpp | 18 ++++----- .../quantifiers/sygus/enum_value_manager.cpp | 6 ++- .../quantifiers/sygus/enum_value_manager.h | 6 ++- .../sygus/sygus_enumerator_basic.cpp | 3 +- .../sygus/sygus_enumerator_callback.cpp | 3 +- .../sygus/sygus_enumerator_callback.h | 2 - .../quantifiers/sygus/sygus_grammar_red.cpp | 3 +- .../quantifiers/sygus/sygus_invariance.cpp | 6 +-- src/theory/quantifiers/sygus/sygus_pbe.cpp | 2 +- .../quantifiers/sygus/sygus_unif_io.cpp | 2 +- .../quantifiers/sygus/synth_conjecture.cpp | 19 ++++----- .../quantifiers/sygus/synth_conjecture.h | 6 ++- src/theory/quantifiers/sygus/synth_engine.cpp | 9 +++-- src/theory/quantifiers/sygus/synth_engine.h | 3 +- .../quantifiers/sygus/term_database_sygus.cpp | 8 ++-- .../quantifiers/sygus/term_database_sygus.h | 11 ++--- src/theory/quantifiers/term_registry.cpp | 6 ++- src/theory/quantifiers/term_registry.h | 3 +- src/theory/quantifiers/theory_quantifiers.cpp | 4 +- src/theory/quantifiers_engine.cpp | 6 ++- src/theory/quantifiers_engine.h | 13 ++++-- src/theory/rewriter.cpp | 18 ++++++--- src/theory/rewriter.h | 40 ++++++++----------- test/unit/theory/sequences_rewriter_white.cpp | 18 ++++----- test/unit/theory/theory_arith_cad_white.cpp | 5 +-- 48 files changed, 189 insertions(+), 180 deletions(-) diff --git a/src/preprocessing/passes/extended_rewriter_pass.cpp b/src/preprocessing/passes/extended_rewriter_pass.cpp index a36388c26..7242be54c 100644 --- a/src/preprocessing/passes/extended_rewriter_pass.cpp +++ b/src/preprocessing/passes/extended_rewriter_pass.cpp @@ -20,7 +20,6 @@ #include "options/smt_options.h" #include "preprocessing/assertion_pipeline.h" #include "preprocessing/preprocessing_pass_context.h" -#include "theory/quantifiers/extended_rewrite.h" namespace cvc5 { namespace preprocessing { @@ -32,11 +31,12 @@ ExtRewPre::ExtRewPre(PreprocessingPassContext* preprocContext) PreprocessingPassResult ExtRewPre::applyInternal( AssertionPipeline* assertionsToPreprocess) { - theory::quantifiers::ExtendedRewriter extr(options::extRewPrepAgg()); for (unsigned i = 0, size = assertionsToPreprocess->size(); i < size; ++i) { assertionsToPreprocess->replace( - i, extr.extendedRewrite((*assertionsToPreprocess)[i])); + i, + extendedRewrite((*assertionsToPreprocess)[i], + options::extRewPrepAgg())); } return PreprocessingPassResult::NO_CONFLICT; } diff --git a/src/smt/env_obj.cpp b/src/smt/env_obj.cpp index acc3ab038..fc50a359b 100644 --- a/src/smt/env_obj.cpp +++ b/src/smt/env_obj.cpp @@ -26,6 +26,11 @@ EnvObj::EnvObj(Env& env) : d_env(env) {} Node EnvObj::rewrite(TNode node) { return d_env.getRewriter()->rewrite(node); } +Node EnvObj::extendedRewrite(TNode node, bool aggr) +{ + return d_env.getRewriter()->extendedRewrite(node, aggr); +} + const LogicInfo& EnvObj::logicInfo() const { return d_env.getLogicInfo(); } const Options& EnvObj::options() const { return d_env.getOptions(); } diff --git a/src/smt/env_obj.h b/src/smt/env_obj.h index d1c882b96..ebe304dcf 100644 --- a/src/smt/env_obj.h +++ b/src/smt/env_obj.h @@ -49,6 +49,11 @@ class EnvObj * This is a wrapper around theory::Rewriter::rewrite via Env. */ Node rewrite(TNode node); + /** + * Extended rewrite a node. + * This is a wrapper around theory::Rewriter::extendedRewrite via Env. + */ + Node extendedRewrite(TNode node, bool aggr = true); /** Get the current logic information. */ const LogicInfo& logicInfo() const; diff --git a/src/smt/preprocess_proof_generator.cpp b/src/smt/preprocess_proof_generator.cpp index 1e322ccd3..e2730151e 100644 --- a/src/smt/preprocess_proof_generator.cpp +++ b/src/smt/preprocess_proof_generator.cpp @@ -180,8 +180,7 @@ std::shared_ptr PreprocessProofGenerator::getProofFor(Node f) if (!proofStepProcessed) { // maybe its just an (extended) rewrite? - theory::quantifiers::ExtendedRewriter extr(true); - Node pr = extr.extendedRewrite(proven[0]); + Node pr = theory::Rewriter::callExtendedRewrite(proven[0]); if (proven[1] == pr) { Node idr = mkMethodId(MethodId::RW_EXT_REWRITE); diff --git a/src/smt/quant_elim_solver.cpp b/src/smt/quant_elim_solver.cpp index 087aa0e06..08ba5b416 100644 --- a/src/smt/quant_elim_solver.cpp +++ b/src/smt/quant_elim_solver.cpp @@ -20,9 +20,7 @@ #include "expr/subs.h" #include "smt/smt_solver.h" #include "theory/quantifiers/cegqi/nested_qe.h" -#include "theory/quantifiers/extended_rewrite.h" #include "theory/quantifiers_engine.h" -#include "theory/rewriter.h" #include "theory/theory_engine.h" #include "util/string.h" @@ -33,7 +31,7 @@ namespace cvc5 { namespace smt { QuantElimSolver::QuantElimSolver(Env& env, SmtSolver& sms) - : d_env(env), d_smtSolver(sms) + : EnvObj(env), d_smtSolver(sms) { } @@ -52,7 +50,7 @@ Node QuantElimSolver::getQuantifierElimination(Assertions& as, } NodeManager* nm = NodeManager::currentNM(); // ensure the body is rewritten - q = nm->mkNode(q.getKind(), q[0], Rewriter::rewrite(q[1])); + q = nm->mkNode(q.getKind(), q[0], rewrite(q[1])); // do nested quantifier elimination if necessary q = quantifiers::NestedQe::doNestedQe(d_env, q, true); Trace("smt-qe") << "QuantElimSolver: after nested quantifier elimination : " @@ -110,7 +108,7 @@ Node QuantElimSolver::getQuantifierElimination(Assertions& as, Trace("smt-qe") << "QuantElimSolver returned : " << ret << std::endl; if (q.getKind() == EXISTS) { - ret = Rewriter::rewrite(ret.negate()); + ret = rewrite(ret.negate()); } } else @@ -118,8 +116,7 @@ Node QuantElimSolver::getQuantifierElimination(Assertions& as, ret = nm->mkConst(q.getKind() != EXISTS); } // do extended rewrite to minimize the size of the formula aggressively - theory::quantifiers::ExtendedRewriter extr(true); - ret = extr.extendedRewrite(ret); + ret = extendedRewrite(ret); // if we are not an internal subsolver, convert to witness form, since // internally generated skolems should not escape if (!isInternalSubsolver) diff --git a/src/smt/quant_elim_solver.h b/src/smt/quant_elim_solver.h index f890deba0..a0b43d09d 100644 --- a/src/smt/quant_elim_solver.h +++ b/src/smt/quant_elim_solver.h @@ -20,10 +20,9 @@ #include "expr/node.h" #include "smt/assertions.h" +#include "smt/env_obj.h" namespace cvc5 { -class Env; - namespace smt { class SmtSolver; @@ -36,7 +35,7 @@ class SmtSolver; * quantifier instantiations used for unsat which are in turn used for * constructing the solution for the quantifier elimination query. */ -class QuantElimSolver +class QuantElimSolver : protected EnvObj { public: QuantElimSolver(Env& env, SmtSolver& sms); @@ -97,8 +96,6 @@ class QuantElimSolver bool isInternalSubsolver); private: - /** Reference to the env */ - Env& d_env; /** The SMT solver, which is used during doQuantifierElimination. */ SmtSolver& d_smtSolver; }; diff --git a/src/theory/arith/nl/cad/cdcac.cpp b/src/theory/arith/nl/cad/cdcac.cpp index 9b37a135f..9b7678388 100644 --- a/src/theory/arith/nl/cad/cdcac.cpp +++ b/src/theory/arith/nl/cad/cdcac.cpp @@ -23,7 +23,7 @@ #include "theory/arith/nl/cad/projections.h" #include "theory/arith/nl/cad/variable_ordering.h" #include "theory/arith/nl/nl_model.h" -#include "theory/quantifiers/extended_rewrite.h" +#include "theory/rewriter.h" namespace std { /** Generic streaming operator for std::vector. */ @@ -42,7 +42,7 @@ namespace nl { namespace cad { CDCAC::CDCAC(Env& env, const std::vector& ordering) - : d_env(env), d_variableOrdering(ordering) + : EnvObj(env), d_variableOrdering(ordering) { if (d_env.isTheoryProofProducing()) { @@ -276,9 +276,8 @@ PolyVector requiredCoefficientsLazardModified( Kind::EQUAL, nl::as_cvc_polynomial(coeff, vm), zero)); } // if phi is false (i.e. p can not vanish) - quantifiers::ExtendedRewriter rew; - Node rewritten = - rew.extendedRewrite(NodeManager::currentNM()->mkAnd(conditions)); + Node rewritten = Rewriter::callExtendedRewrite( + NodeManager::currentNM()->mkAnd(conditions)); if (rewritten.isConst()) { Assert(rewritten.getKind() == Kind::CONST_BOOLEAN); diff --git a/src/theory/arith/nl/cad/cdcac.h b/src/theory/arith/nl/cad/cdcac.h index b504998d8..be72e4063 100644 --- a/src/theory/arith/nl/cad/cdcac.h +++ b/src/theory/arith/nl/cad/cdcac.h @@ -26,6 +26,7 @@ #include #include "smt/env.h" +#include "smt/env_obj.h" #include "theory/arith/nl/cad/cdcac_utils.h" #include "theory/arith/nl/cad/constraints.h" #include "theory/arith/nl/cad/proof_generator.h" @@ -44,7 +45,7 @@ namespace cad { * This class implements Cylindrical Algebraic Coverings as presented in * https://arxiv.org/pdf/2003.05633.pdf */ -class CDCAC +class CDCAC : protected EnvObj { public: /** Initialize this method with the given variable ordering. */ @@ -184,9 +185,6 @@ class CDCAC */ void pruneRedundantIntervals(std::vector& intervals); - /** A reference to the environment */ - Env& d_env; - /** * The current assignment. When the method terminates with SAT, it contains a * model for the input constraints. diff --git a/src/theory/builtin/proof_checker.h b/src/theory/builtin/proof_checker.h index 8b3988f27..59a84c86e 100644 --- a/src/theory/builtin/proof_checker.h +++ b/src/theory/builtin/proof_checker.h @@ -22,7 +22,6 @@ #include "proof/method_id.h" #include "proof/proof_checker.h" #include "proof/proof_node.h" -#include "theory/quantifiers/extended_rewrite.h" namespace cvc5 { @@ -114,9 +113,6 @@ class BuiltinProofRuleChecker : public ProofRuleChecker const std::vector& children, const std::vector& args) override; - /** extended rewriter object */ - quantifiers::ExtendedRewriter d_ext_rewriter; - private: /** Reference to the environment. */ Env& d_env; diff --git a/src/theory/datatypes/sygus_extension.cpp b/src/theory/datatypes/sygus_extension.cpp index e66b70934..8e3c22fb8 100644 --- a/src/theory/datatypes/sygus_extension.cpp +++ b/src/theory/datatypes/sygus_extension.cpp @@ -42,10 +42,12 @@ using namespace cvc5::context; using namespace cvc5::theory; using namespace cvc5::theory::datatypes; -SygusExtension::SygusExtension(TheoryState& s, +SygusExtension::SygusExtension(Env& env, + TheoryState& s, InferenceManager& im, quantifiers::TermDbSygus* tds) - : d_state(s), + : EnvObj(env), + d_state(s), d_im(im), d_tds(tds), d_ssb(tds), @@ -1037,7 +1039,7 @@ Node SygusExtension::registerSearchValue(Node a, << ", type=" << tn << std::endl; Node bv = d_tds->sygusToBuiltin(cnv, tn); Trace("sygus-sb-debug") << " ......builtin is " << bv << std::endl; - Node bvr = d_tds->getExtRewriter()->extendedRewrite(bv); + Node bvr = extendedRewrite(bv); Trace("sygus-sb-debug") << " ......search value rewrites to " << bvr << std::endl; Trace("dt-sygus") << " * DT builtin : " << n << " -> " << bvr << std::endl; unsigned sz = utils::getSygusTermSize(nv); diff --git a/src/theory/datatypes/sygus_extension.h b/src/theory/datatypes/sygus_extension.h index 3c7607eaf..5860dca99 100644 --- a/src/theory/datatypes/sygus_extension.h +++ b/src/theory/datatypes/sygus_extension.h @@ -25,6 +25,7 @@ #include "context/cdhashset.h" #include "context/context.h" #include "expr/node.h" +#include "smt/env_obj.h" #include "theory/datatypes/sygus_simple_sym.h" #include "theory/decision_manager.h" #include "theory/quantifiers/sygus_sampler.h" @@ -62,7 +63,7 @@ class InferenceManager; * We prioritize decisions of form (1) before (2). Both kinds of decision are * critical for solution completeness, which is enforced by DecisionManager. */ -class SygusExtension +class SygusExtension : protected EnvObj { typedef context::CDHashMap IntMap; typedef context::CDHashMap NodeMap; @@ -70,7 +71,8 @@ class SygusExtension typedef context::CDHashSet NodeSet; public: - SygusExtension(TheoryState& s, + SygusExtension(Env& env, + TheoryState& s, InferenceManager& im, quantifiers::TermDbSygus* tds); ~SygusExtension(); diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 58d0dbaab..4a8976876 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -110,7 +110,7 @@ void TheoryDatatypes::finishInit() { quantifiers::TermDbSygus* tds = getQuantifiersEngine()->getTermDatabaseSygus(); - d_sygusExtension.reset(new SygusExtension(d_state, d_im, tds)); + d_sygusExtension.reset(new SygusExtension(d_env, d_state, d_im, tds)); // do congruence on evaluation functions d_equalityEngine->addFunctionKind(kind::DT_SYGUS_EVAL); } diff --git a/src/theory/quantifiers/candidate_rewrite_database.cpp b/src/theory/quantifiers/candidate_rewrite_database.cpp index 0fd0eebd6..475df0b43 100644 --- a/src/theory/quantifiers/candidate_rewrite_database.cpp +++ b/src/theory/quantifiers/candidate_rewrite_database.cpp @@ -37,7 +37,7 @@ CandidateRewriteDatabase::CandidateRewriteDatabase( Env& env, bool doCheck, bool rewAccel, bool silent, bool filterPairs) : ExprMiner(env), d_tds(nullptr), - d_ext_rewrite(nullptr), + d_useExtRewriter(false), d_doCheck(doCheck), d_rewAccel(rewAccel), d_silent(silent), @@ -52,7 +52,7 @@ void CandidateRewriteDatabase::initialize(const std::vector& vars, d_candidate = Node::null(); d_using_sygus = false; d_tds = nullptr; - d_ext_rewrite = nullptr; + d_useExtRewriter = false; if (d_filterPairs) { d_crewrite_filter.initialize(ss, nullptr, false); @@ -69,7 +69,7 @@ void CandidateRewriteDatabase::initializeSygus(const std::vector& vars, d_candidate = f; d_using_sygus = true; d_tds = tds; - d_ext_rewrite = nullptr; + d_useExtRewriter = false; if (d_filterPairs) { d_crewrite_filter.initialize(ss, d_tds, d_using_sygus); @@ -121,10 +121,10 @@ Node CandidateRewriteDatabase::addTerm(Node sol, // get the rewritten form Node solbr; Node eq_solr; - if (d_ext_rewrite != nullptr) + if (d_useExtRewriter) { - solbr = d_ext_rewrite->extendedRewrite(solb); - eq_solr = d_ext_rewrite->extendedRewrite(eq_solb); + solbr = extendedRewrite(solb); + eq_solr = extendedRewrite(eq_solb); } else { @@ -289,9 +289,9 @@ bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out) void CandidateRewriteDatabase::setSilent(bool flag) { d_silent = flag; } -void CandidateRewriteDatabase::setExtendedRewriter(ExtendedRewriter* er) +void CandidateRewriteDatabase::enableExtendedRewriter() { - d_ext_rewrite = er; + d_useExtRewriter = true; } } // namespace quantifiers diff --git a/src/theory/quantifiers/candidate_rewrite_database.h b/src/theory/quantifiers/candidate_rewrite_database.h index 71ae5649f..c0e783fc1 100644 --- a/src/theory/quantifiers/candidate_rewrite_database.h +++ b/src/theory/quantifiers/candidate_rewrite_database.h @@ -100,14 +100,14 @@ class CandidateRewriteDatabase : public ExprMiner bool addTerm(Node sol, std::ostream& out) override; /** sets whether this class should output candidate rewrites it finds */ void setSilent(bool flag); - /** set the (extended) rewriter used by this class */ - void setExtendedRewriter(ExtendedRewriter* er); + /** Enable the (extended) rewriter for this class */ + void enableExtendedRewriter(); private: /** (required) pointer to the sygus term database of d_qe */ TermDbSygus* d_tds; - /** an extended rewriter object */ - ExtendedRewriter* d_ext_rewrite; + /** Whether we use the extended rewriter */ + bool d_useExtRewriter; /** the function-to-synthesize we are testing (if sygus) */ Node d_candidate; /** whether we are checking equivalence using subsolver */ diff --git a/src/theory/quantifiers/expr_miner_manager.cpp b/src/theory/quantifiers/expr_miner_manager.cpp index ae20d4909..8af456ea8 100644 --- a/src/theory/quantifiers/expr_miner_manager.cpp +++ b/src/theory/quantifiers/expr_miner_manager.cpp @@ -87,7 +87,7 @@ void ExpressionMinerManager::enableRewriteRuleSynth() { d_crd.initialize(vars, &d_sampler); } - d_crd.setExtendedRewriter(&d_ext_rew); + d_crd.enableExtendedRewriter(); d_crd.setSilent(false); } diff --git a/src/theory/quantifiers/expr_miner_manager.h b/src/theory/quantifiers/expr_miner_manager.h index 92450b3ba..43a615c97 100644 --- a/src/theory/quantifiers/expr_miner_manager.h +++ b/src/theory/quantifiers/expr_miner_manager.h @@ -21,15 +21,11 @@ #include "expr/node.h" #include "smt/env_obj.h" #include "theory/quantifiers/candidate_rewrite_database.h" -#include "theory/quantifiers/extended_rewrite.h" #include "theory/quantifiers/query_generator.h" #include "theory/quantifiers/solution_filter.h" #include "theory/quantifiers/sygus_sampler.h" namespace cvc5 { - -class Env; - namespace theory { namespace quantifiers { @@ -114,8 +110,6 @@ class ExpressionMinerManager : protected EnvObj SolutionFilterStrength d_sols; /** sygus sampler object */ SygusSampler d_sampler; - /** extended rewriter object */ - ExtendedRewriter d_ext_rew; }; } // namespace quantifiers diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index 58a78b4aa..40e28eb78 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -42,7 +42,8 @@ struct ExtRewriteAggAttributeId }; typedef expr::Attribute ExtRewriteAggAttribute; -ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr) +ExtendedRewriter::ExtendedRewriter(Rewriter& rew, bool aggr) + : d_rew(rew), d_aggr(aggr) { d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); @@ -97,7 +98,7 @@ bool ExtendedRewriter::addToChildren(Node nc, Node ExtendedRewriter::extendedRewrite(Node n) const { - n = Rewriter::rewrite(n); + n = d_rew.rewrite(n); // has it already been computed? Node ncache = getCache(n); @@ -204,7 +205,7 @@ Node ExtendedRewriter::extendedRewrite(Node n) const } } } - ret = Rewriter::rewrite(ret); + ret = d_rew.rewrite(ret); //--------------------end rewrite children // now, do extended rewrite @@ -496,7 +497,7 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end()); if (nn != t2) { - nn = Rewriter::rewrite(nn); + nn = d_rew.rewrite(nn); if (nn == t1) { new_ret = t2; @@ -508,7 +509,7 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const // must use partial substitute here, to avoid substitution into witness std::map rkinds; nn = partialSubstitute(t1, vars, subs, rkinds); - nn = Rewriter::rewrite(nn); + nn = d_rew.rewrite(nn); if (nn != t1) { // If full=false, then we've duplicated a term u in the children of n. @@ -537,7 +538,7 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full) const Node nn = partialSubstitute(t2, assign, rkinds); if (nn != t2) { - nn = Rewriter::rewrite(nn); + nn = d_rew.rewrite(nn); if (nn == t1) { new_ret = nn; @@ -625,7 +626,7 @@ Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n) const { children[ii] = n[i][j + 1]; Node pull = nm->mkNode(n.getKind(), children); - Node pullr = Rewriter::rewrite(pull); + Node pullr = d_rew.rewrite(pull); children[ii] = n[i]; ite_c[i][j] = pullr; } @@ -688,7 +689,7 @@ Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n) const Assert(nite.getKind() == itek); // now, simply pull the ITE and try ITE rewrites Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]); - pull_ite = Rewriter::rewrite(pull_ite); + pull_ite = d_rew.rewrite(pull_ite); if (pull_ite.getKind() == ITE) { Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false); @@ -887,7 +888,7 @@ Node ExtendedRewriter::extendedRewriteBcp(Kind andk, ccs = cpol ? ccs : TermUtil::mkNegate(notk, ccs); Trace("ext-rew-bcp") << "BCP: propagated " << c << " -> " << ccs << std::endl; - ccs = Rewriter::rewrite(ccs); + ccs = d_rew.rewrite(ccs); Trace("ext-rew-bcp") << "BCP: rewritten to " << ccs << std::endl; to_process.push_back(ccs); // store this as a node that propagation touched. This marks c so that @@ -1522,7 +1523,7 @@ Node ExtendedRewriter::extendedRewriteEqChain( index--; new_ret = nm->mkNode(eqk, children[index], new_ret); } - new_ret = Rewriter::rewrite(new_ret); + new_ret = d_rew.rewrite(new_ret); if (new_ret != ret) { return new_ret; diff --git a/src/theory/quantifiers/extended_rewrite.h b/src/theory/quantifiers/extended_rewrite.h index b1b08657d..b4dcab041 100644 --- a/src/theory/quantifiers/extended_rewrite.h +++ b/src/theory/quantifiers/extended_rewrite.h @@ -24,6 +24,9 @@ namespace cvc5 { namespace theory { + +class Rewriter; + namespace quantifiers { /** Extended rewriter @@ -48,12 +51,14 @@ namespace quantifiers { class ExtendedRewriter { public: - ExtendedRewriter(bool aggr = true); + ExtendedRewriter(Rewriter& rew, bool aggr = true); ~ExtendedRewriter() {} /** return the extended rewritten form of n */ Node extendedRewrite(Node n) const; private: + /** The underlying rewriter that we are extending */ + Rewriter& d_rew; /** cache that the extended rewritten form of n is ret */ void setCache(Node n, Node ret) const; /** get the cache for n */ diff --git a/src/theory/quantifiers/quantifiers_modules.cpp b/src/theory/quantifiers/quantifiers_modules.cpp index 27ec187a9..6cfc48fb9 100644 --- a/src/theory/quantifiers/quantifiers_modules.cpp +++ b/src/theory/quantifiers/quantifiers_modules.cpp @@ -41,7 +41,8 @@ QuantifiersModules::QuantifiersModules() { } QuantifiersModules::~QuantifiersModules() {} -void QuantifiersModules::initialize(QuantifiersState& qs, +void QuantifiersModules::initialize(Env& env, + QuantifiersState& qs, QuantifiersInferenceManager& qim, QuantifiersRegistry& qr, TermRegistry& tr, @@ -72,7 +73,7 @@ void QuantifiersModules::initialize(QuantifiersState& qs, } if (options::sygus()) { - d_synth_e.reset(new SynthEngine(qs, qim, qr, tr)); + d_synth_e.reset(new SynthEngine(env, qs, qim, qr, tr)); modules.push_back(d_synth_e.get()); } // bounded integer instantiation is used when the user requests it via diff --git a/src/theory/quantifiers/quantifiers_modules.h b/src/theory/quantifiers/quantifiers_modules.h index f41e81f34..9878e79ae 100644 --- a/src/theory/quantifiers/quantifiers_modules.h +++ b/src/theory/quantifiers/quantifiers_modules.h @@ -57,7 +57,8 @@ class QuantifiersModules * This constructs the above modules based on the current options. It adds * a pointer to each module it constructs to modules. */ - void initialize(QuantifiersState& qs, + void initialize(Env& env, + QuantifiersState& qs, QuantifiersInferenceManager& qim, QuantifiersRegistry& qr, TermRegistry& tr, diff --git a/src/theory/quantifiers/quantifiers_rewriter.cpp b/src/theory/quantifiers/quantifiers_rewriter.cpp index 6d8570287..e5662cdc6 100644 --- a/src/theory/quantifiers/quantifiers_rewriter.cpp +++ b/src/theory/quantifiers/quantifiers_rewriter.cpp @@ -548,8 +548,7 @@ Node QuantifiersRewriter::computeExtendedRewrite(Node q) { Node body = q[1]; // apply extended rewriter - ExtendedRewriter er; - Node bodyr = er.extendedRewrite(body); + Node bodyr = Rewriter::callExtendedRewrite(body); if (body != bodyr) { std::vector children; diff --git a/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp b/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp index d2c616238..80f4af984 100644 --- a/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp +++ b/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp @@ -400,7 +400,7 @@ Node CegSingleInv::getSolutionFromInst(size_t index) } //simplify the solution using the extended rewriter Trace("csi-sol") << "Solution (pre-simplification): " << s << std::endl; - s = d_treg.getTermDatabaseSygus()->getExtRewriter()->extendedRewrite(s); + s = extendedRewrite(s); Trace("csi-sol") << "Solution (post-simplification): " << s << std::endl; // wrap into lambda, as needed return SygusUtils::wrapSolutionForSynthFun(prog, s); @@ -467,7 +467,7 @@ Node CegSingleInv::reconstructToSyntax(Node s, { Trace("csi-sol") << "Post-process solution..." << std::endl; Node prev = sol; - sol = d_treg.getTermDatabaseSygus()->getExtRewriter()->extendedRewrite(sol); + sol = extendedRewrite(sol); if (prev != sol) { Trace("csi-sol") << "Solution (after post process) : " << sol diff --git a/src/theory/quantifiers/sygus/cegis.cpp b/src/theory/quantifiers/sygus/cegis.cpp index 57b763044..8d1bfd9b6 100644 --- a/src/theory/quantifiers/sygus/cegis.cpp +++ b/src/theory/quantifiers/sygus/cegis.cpp @@ -345,7 +345,7 @@ void Cegis::addRefinementLemma(Node lem) d_rl_vals.end()); } // rewrite with extended rewriter - slem = d_tds->getExtRewriter()->extendedRewrite(slem); + slem = extendedRewrite(slem); // collect all variables in slem expr::getSymbols(slem, d_refinement_lemma_vars); std::vector waiting; diff --git a/src/theory/quantifiers/sygus/enum_stream_substitution.cpp b/src/theory/quantifiers/sygus/enum_stream_substitution.cpp index f853ac8e8..a5be4ebd6 100644 --- a/src/theory/quantifiers/sygus/enum_stream_substitution.cpp +++ b/src/theory/quantifiers/sygus/enum_stream_substitution.cpp @@ -16,15 +16,16 @@ #include "theory/quantifiers/sygus/enum_stream_substitution.h" +#include // for std::iota +#include + #include "expr/dtype_cons.h" #include "options/base_options.h" #include "options/datatypes_options.h" #include "options/quantifiers_options.h" #include "printer/printer.h" #include "theory/quantifiers/sygus/term_database_sygus.h" - -#include // for std::iota -#include +#include "theory/rewriter.h" using namespace cvc5::kind; @@ -32,7 +33,7 @@ namespace cvc5 { namespace theory { namespace quantifiers { -EnumStreamPermutation::EnumStreamPermutation(quantifiers::TermDbSygus* tds) +EnumStreamPermutation::EnumStreamPermutation(TermDbSygus* tds) : d_tds(tds), d_first(true), d_curr_ind(0) { } @@ -124,8 +125,7 @@ Node EnumStreamPermutation::getNext() { d_first = false; Node bultin_value = d_tds->sygusToBuiltin(d_value, d_value.getType()); - d_perm_values.insert( - d_tds->getExtRewriter()->extendedRewrite(bultin_value)); + d_perm_values.insert(Rewriter::callExtendedRewrite(bultin_value)); return d_value; } unsigned n_classes = d_perm_state_class.size(); @@ -194,8 +194,7 @@ Node EnumStreamPermutation::getNext() << " ......perm builtin is " << bultin_perm_value; if (options::sygusSymBreakDynamic()) { - bultin_perm_value = - d_tds->getExtRewriter()->extendedRewrite(bultin_perm_value); + bultin_perm_value = Rewriter::callExtendedRewrite(bultin_perm_value); Trace("synth-stream-concrete-debug") << " and rewrites to " << bultin_perm_value; } @@ -515,8 +514,7 @@ Node EnumStreamSubstitution::getNext() d_tds->sygusToBuiltin(comb_value, comb_value.getType()); if (options::sygusSymBreakDynamic()) { - builtin_comb_value = - d_tds->getExtRewriter()->extendedRewrite(builtin_comb_value); + builtin_comb_value = Rewriter::callExtendedRewrite(builtin_comb_value); } if (Trace.isOn("synth-stream-concrete")) { diff --git a/src/theory/quantifiers/sygus/enum_value_manager.cpp b/src/theory/quantifiers/sygus/enum_value_manager.cpp index 8a2d70bfa..1d0ba5bee 100644 --- a/src/theory/quantifiers/sygus/enum_value_manager.cpp +++ b/src/theory/quantifiers/sygus/enum_value_manager.cpp @@ -33,13 +33,15 @@ namespace cvc5 { namespace theory { namespace quantifiers { -EnumValueManager::EnumValueManager(Node e, +EnumValueManager::EnumValueManager(Env& env, QuantifiersState& qs, QuantifiersInferenceManager& qim, TermRegistry& tr, SygusStatistics& s, + Node e, bool hasExamples) - : d_enum(e), + : EnvObj(env), + d_enum(e), d_qstate(qs), d_qim(qim), d_treg(tr), diff --git a/src/theory/quantifiers/sygus/enum_value_manager.h b/src/theory/quantifiers/sygus/enum_value_manager.h index c786bb6f1..23fdc7391 100644 --- a/src/theory/quantifiers/sygus/enum_value_manager.h +++ b/src/theory/quantifiers/sygus/enum_value_manager.h @@ -19,6 +19,7 @@ #define CVC5__THEORY__QUANTIFIERS__SYGUS__ENUM_VALUE_MANAGER_H #include "expr/node.h" +#include "smt/env_obj.h" #include "theory/quantifiers/sygus/enum_val_generator.h" #include "theory/quantifiers/sygus/example_eval_cache.h" #include "theory/quantifiers/sygus/sygus_enumerator_callback.h" @@ -38,14 +39,15 @@ class SygusStatistics; * not actively generated, or may be determined by the (fast) enumerator * when it is actively generated. */ -class EnumValueManager +class EnumValueManager : protected EnvObj { public: - EnumValueManager(Node e, + EnumValueManager(Env& env, QuantifiersState& qs, QuantifiersInferenceManager& qim, TermRegistry& tr, SygusStatistics& s, + Node e, bool hasExamples); ~EnumValueManager(); /** diff --git a/src/theory/quantifiers/sygus/sygus_enumerator_basic.cpp b/src/theory/quantifiers/sygus/sygus_enumerator_basic.cpp index f45b976ec..743f67cec 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator_basic.cpp +++ b/src/theory/quantifiers/sygus/sygus_enumerator_basic.cpp @@ -15,6 +15,7 @@ #include "theory/quantifiers/sygus/sygus_enumerator_basic.h" #include "options/datatypes_options.h" +#include "theory/rewriter.h" using namespace cvc5::kind; using namespace std; @@ -40,7 +41,7 @@ bool EnumValGeneratorBasic::increment() if (options::sygusSymBreakDynamic()) { Node nextb = d_tds->sygusToBuiltin(d_currTerm); - nextb = d_tds->getExtRewriter()->extendedRewrite(nextb); + nextb = Rewriter::callExtendedRewrite(nextb); if (d_cache.find(nextb) == d_cache.end()) { d_cache.insert(nextb); diff --git a/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp b/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp index 3b536695f..1b5b3f5af 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp +++ b/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp @@ -19,6 +19,7 @@ #include "theory/quantifiers/sygus/example_eval_cache.h" #include "theory/quantifiers/sygus/sygus_stats.h" #include "theory/quantifiers/sygus_sampler.h" +#include "theory/rewriter.h" namespace cvc5 { namespace theory { @@ -33,7 +34,7 @@ SygusEnumeratorCallback::SygusEnumeratorCallback(Node e, SygusStatistics* s) bool SygusEnumeratorCallback::addTerm(Node n, std::unordered_set& bterms) { Node bn = datatypes::utils::sygusToBuiltin(n); - Node bnr = d_extr.extendedRewrite(bn); + Node bnr = Rewriter::callExtendedRewrite(bn); if (d_stats != nullptr) { ++(d_stats->d_enumTermsRewrite); diff --git a/src/theory/quantifiers/sygus/sygus_enumerator_callback.h b/src/theory/quantifiers/sygus/sygus_enumerator_callback.h index 5ed28b309..8689d876f 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator_callback.h +++ b/src/theory/quantifiers/sygus/sygus_enumerator_callback.h @@ -74,8 +74,6 @@ class SygusEnumeratorCallback Node d_enum; /** The type of enum */ TypeNode d_tn; - /** extended rewriter */ - ExtendedRewriter d_extr; /** pointer to the statistics */ SygusStatistics* d_stats; }; diff --git a/src/theory/quantifiers/sygus/sygus_grammar_red.cpp b/src/theory/quantifiers/sygus/sygus_grammar_red.cpp index a51fcce25..fd84f0c0a 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_red.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_red.cpp @@ -21,6 +21,7 @@ #include "options/quantifiers_options.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_util.h" +#include "theory/rewriter.h" using namespace std; using namespace cvc5::kind; @@ -147,7 +148,7 @@ void SygusRedundantCons::getGenericList(TermDbSygus* tds, if (index == dt[c].getNumArgs()) { Node gt = tds->mkGeneric(dt, c, pre); - gt = tds->getExtRewriter()->extendedRewrite(gt); + gt = Rewriter::callExtendedRewrite(gt); terms.push_back(gt); return; } diff --git a/src/theory/quantifiers/sygus/sygus_invariance.cpp b/src/theory/quantifiers/sygus/sygus_invariance.cpp index cb7e2b84e..29557fe5c 100644 --- a/src/theory/quantifiers/sygus/sygus_invariance.cpp +++ b/src/theory/quantifiers/sygus/sygus_invariance.cpp @@ -111,7 +111,7 @@ bool EquivSygusInvarianceTest::invariant(TermDbSygus* tds, Node nvn, Node x) { TypeNode tn = nvn.getType(); Node nbv = tds->sygusToBuiltin(nvn, tn); - Node nbvr = tds->getExtRewriter()->extendedRewrite(nbv); + Node nbvr = Rewriter::callExtendedRewrite(nbv); Trace("sygus-sb-mexp-debug") << " min-exp check : " << nbv << " -> " << nbvr << std::endl; bool exc_arg = false; @@ -181,7 +181,7 @@ bool DivByZeroSygusInvarianceTest::invariant(TermDbSygus* tds, Node nvn, Node x) { TypeNode tn = nvn.getType(); Node nbv = tds->sygusToBuiltin(nvn, tn); - Node nbvr = tds->getExtRewriter()->extendedRewrite(nbv); + Node nbvr = Rewriter::callExtendedRewrite(nbv); if (tds->involvesDivByZero(nbvr)) { Trace("sygus-sb-mexp") << "sb-min-exp : " << tds->sygusToBuiltin(nvn) @@ -212,7 +212,7 @@ bool NegContainsSygusInvarianceTest::invariant(TermDbSygus* tds, { TypeNode tn = nvn.getType(); Node nbv = tds->sygusToBuiltin(nvn, tn); - Node nbvr = tds->getExtRewriter()->extendedRewrite(nbv); + Node nbvr = Rewriter::callExtendedRewrite(nbv); // if for any of the examples, it is not contained, then we can exclude for (unsigned i = 0; i < d_neg_con_indices.size(); i++) { diff --git a/src/theory/quantifiers/sygus/sygus_pbe.cpp b/src/theory/quantifiers/sygus/sygus_pbe.cpp index 7601e2117..52bca1586 100644 --- a/src/theory/quantifiers/sygus/sygus_pbe.cpp +++ b/src/theory/quantifiers/sygus/sygus_pbe.cpp @@ -131,7 +131,7 @@ bool SygusPbe::initialize(Node conj, // Apply extended rewriting on the lemma. This helps utilities like // SygusEnumerator more easily recognize the shape of this lemma, e.g. // ( ~is-ite(x) or ( ~is-ite(x) ^ P ) ) --> ~is-ite(x). - lem = d_tds->getExtRewriter()->extendedRewrite(lem); + lem = extendedRewrite(lem); Trace("sygus-pbe") << " static redundant op lemma : " << lem << std::endl; // Register as a symmetry breaking lemma with the term database. diff --git a/src/theory/quantifiers/sygus/sygus_unif_io.cpp b/src/theory/quantifiers/sygus/sygus_unif_io.cpp index 9626f7af4..3fb80f917 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_io.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif_io.cpp @@ -569,7 +569,7 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector& lemmas) std::vector base_results; TypeNode xtn = e.getType(); Node bv = d_tds->sygusToBuiltin(v, xtn); - bv = d_tds->getExtRewriter()->extendedRewrite(bv); + bv = extendedRewrite(bv); Trace("sygus-sui-enum") << "PBE Compute Examples for " << bv << std::endl; // compte the results (should be cached) ExampleEvalCache* eec = d_parent->getExampleEvalCache(e); diff --git a/src/theory/quantifiers/sygus/synth_conjecture.cpp b/src/theory/quantifiers/sygus/synth_conjecture.cpp index 3e7095c12..e87857c3b 100644 --- a/src/theory/quantifiers/sygus/synth_conjecture.cpp +++ b/src/theory/quantifiers/sygus/synth_conjecture.cpp @@ -45,12 +45,14 @@ namespace cvc5 { namespace theory { namespace quantifiers { -SynthConjecture::SynthConjecture(QuantifiersState& qs, +SynthConjecture::SynthConjecture(Env& env, + QuantifiersState& qs, QuantifiersInferenceManager& qim, QuantifiersRegistry& qr, TermRegistry& tr, SygusStatistics& s) - : d_qstate(qs), + : EnvObj(env), + d_qstate(qs), d_qim(qim), d_qreg(qr), d_treg(tr), @@ -58,11 +60,11 @@ SynthConjecture::SynthConjecture(QuantifiersState& qs, d_tds(tr.getTermDatabaseSygus()), d_verify(qs.options(), qs.getLogicInfo(), d_tds), d_hasSolution(false), - d_ceg_si(new CegSingleInv(qs.getEnv(), tr, s)), + d_ceg_si(new CegSingleInv(env, tr, s)), d_templInfer(new SygusTemplateInfer), d_ceg_proc(new SynthConjectureProcess), d_ceg_gc(new CegGrammarConstructor(d_tds, this)), - d_sygus_rconst(new SygusRepairConst(qs.getEnv(), d_tds)), + d_sygus_rconst(new SygusRepairConst(env, d_tds)), d_exampleInfer(new ExampleInfer(d_tds)), d_ceg_pbe(new SygusPbe(qs, qim, d_tds, this)), d_ceg_cegis(new Cegis(qs, qim, d_tds, this)), @@ -609,8 +611,7 @@ bool SynthConjecture::checkSideCondition(const std::vector& cvals) const } Trace("sygus-engine") << "Check side condition..." << std::endl; Trace("cegqi-debug") << "Check side condition : " << sc << std::endl; - Env& env = d_qstate.getEnv(); - Result r = checkWithSubsolver(sc, env.getOptions(), env.getLogicInfo()); + Result r = checkWithSubsolver(sc, options(), logicInfo()); Trace("cegqi-debug") << "...got side condition : " << r << std::endl; if (r == Result::UNSAT) { @@ -763,8 +764,8 @@ EnumValueManager* SynthConjecture::getEnumValueManagerFor(Node e) Node f = d_tds->getSynthFunForEnumerator(e); bool hasExamples = (d_exampleInfer->hasExamples(f) && d_exampleInfer->getNumExamples(f) != 0); - d_enumManager[e].reset( - new EnumValueManager(e, d_qstate, d_qim, d_treg, d_stats, hasExamples)); + d_enumManager[e].reset(new EnumValueManager( + d_env, d_qstate, d_qim, d_treg, d_stats, e, hasExamples)); EnumValueManager* eman = d_enumManager[e].get(); // set up the examples if (hasExamples) @@ -885,7 +886,7 @@ void SynthConjecture::printSynthSolutionInternal(std::ostream& out) d_exprm.find(prog); if (its == d_exprm.end()) { - d_exprm[prog].reset(new ExpressionMinerManager(d_qstate.getEnv())); + d_exprm[prog].reset(new ExpressionMinerManager(d_env)); ExpressionMinerManager* emm = d_exprm[prog].get(); emm->initializeSygus( d_tds, d_candidates[i], options::sygusSamples(), true); diff --git a/src/theory/quantifiers/sygus/synth_conjecture.h b/src/theory/quantifiers/sygus/synth_conjecture.h index 9cc488fd2..d7635c816 100644 --- a/src/theory/quantifiers/sygus/synth_conjecture.h +++ b/src/theory/quantifiers/sygus/synth_conjecture.h @@ -21,6 +21,7 @@ #include +#include "smt/env_obj.h" #include "theory/quantifiers/expr_miner_manager.h" #include "theory/quantifiers/sygus/ce_guided_single_inv.h" #include "theory/quantifiers/sygus/cegis.h" @@ -51,10 +52,11 @@ class EnumValueManager; * determines which approach and optimizations are applicable to the * conjecture, and has interfaces for implementing them. */ -class SynthConjecture +class SynthConjecture : protected EnvObj { public: - SynthConjecture(QuantifiersState& qs, + SynthConjecture(Env& env, + QuantifiersState& qs, QuantifiersInferenceManager& qim, QuantifiersRegistry& qr, TermRegistry& tr, diff --git a/src/theory/quantifiers/sygus/synth_engine.cpp b/src/theory/quantifiers/sygus/synth_engine.cpp index cdcbeb85d..64227793d 100644 --- a/src/theory/quantifiers/sygus/synth_engine.cpp +++ b/src/theory/quantifiers/sygus/synth_engine.cpp @@ -26,14 +26,15 @@ namespace cvc5 { namespace theory { namespace quantifiers { -SynthEngine::SynthEngine(QuantifiersState& qs, +SynthEngine::SynthEngine(Env& env, + QuantifiersState& qs, QuantifiersInferenceManager& qim, QuantifiersRegistry& qr, TermRegistry& tr) : QuantifiersModule(qs, qim, qr, tr), d_conj(nullptr), d_sqp(qs.getEnv()) { d_conjs.push_back(std::unique_ptr( - new SynthConjecture(qs, qim, qr, tr, d_statistics))); + new SynthConjecture(env, qs, qim, qr, tr, d_statistics))); d_conj = d_conjs.back().get(); } @@ -153,8 +154,8 @@ void SynthEngine::assignConjecture(Node q) // allocate a new synthesis conjecture if not assigned if (d_conjs.back()->isAssigned()) { - d_conjs.push_back(std::unique_ptr( - new SynthConjecture(d_qstate, d_qim, d_qreg, d_treg, d_statistics))); + d_conjs.push_back(std::unique_ptr(new SynthConjecture( + d_env, d_qstate, d_qim, d_qreg, d_treg, d_statistics))); } d_conjs.back()->assign(q); } diff --git a/src/theory/quantifiers/sygus/synth_engine.h b/src/theory/quantifiers/sygus/synth_engine.h index d37df4e28..c623d9c0f 100644 --- a/src/theory/quantifiers/sygus/synth_engine.h +++ b/src/theory/quantifiers/sygus/synth_engine.h @@ -34,7 +34,8 @@ class SynthEngine : public QuantifiersModule typedef context::CDHashMap NodeBoolMap; public: - SynthEngine(QuantifiersState& qs, + SynthEngine(Env& env, + QuantifiersState& qs, QuantifiersInferenceManager& qim, QuantifiersRegistry& qr, TermRegistry& tr); diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index 3b0ea3312..9c9a90255 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -51,10 +51,10 @@ std::ostream& operator<<(std::ostream& os, EnumeratorRole r) return os; } -TermDbSygus::TermDbSygus(QuantifiersState& qs) - : d_qstate(qs), +TermDbSygus::TermDbSygus(Env& env, QuantifiersState& qs) + : EnvObj(env), + d_qstate(qs), d_syexp(new SygusExplain(this)), - d_ext_rw(new ExtendedRewriter(true)), d_eval(new Evaluator), d_funDefEval(new FunDefEvaluator), d_eval_unfold(new SygusEvalUnfold(this)) @@ -1036,7 +1036,7 @@ Node TermDbSygus::evaluateWithUnfolding(Node n, } if (options::sygusExtRew()) { - ret = getExtRewriter()->extendedRewrite(ret); + ret = extendedRewrite(ret); } // use rewriting, possibly involving recursive functions ret = rewriteNode(ret); diff --git a/src/theory/quantifiers/sygus/term_database_sygus.h b/src/theory/quantifiers/sygus/term_database_sygus.h index 80411b258..a44ebd297 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.h +++ b/src/theory/quantifiers/sygus/term_database_sygus.h @@ -21,6 +21,7 @@ #include #include "expr/dtype.h" +#include "smt/env_obj.h" #include "theory/evaluator.h" #include "theory/quantifiers/extended_rewrite.h" #include "theory/quantifiers/fun_def_evaluator.h" @@ -53,9 +54,10 @@ enum EnumeratorRole std::ostream& operator<<(std::ostream& os, EnumeratorRole r); // TODO :issue #1235 split and document this class -class TermDbSygus { +class TermDbSygus : protected EnvObj +{ public: - TermDbSygus(QuantifiersState& qs); + TermDbSygus(Env& env, QuantifiersState& qs); ~TermDbSygus() {} /** Finish init, which sets the inference manager */ void finishInit(QuantifiersInferenceManager* qim); @@ -78,8 +80,6 @@ class TermDbSygus { //------------------------------utilities /** get the explanation utility */ SygusExplain* getExplain() { return d_syexp.get(); } - /** get the extended rewrite utility */ - ExtendedRewriter* getExtRewriter() { return d_ext_rw.get(); } /** get the evaluator */ Evaluator* getEvaluator() { return d_eval.get(); } /** (recursive) function evaluator utility */ @@ -324,8 +324,6 @@ class TermDbSygus { //------------------------------utilities /** sygus explanation */ std::unique_ptr d_syexp; - /** extended rewriter */ - std::unique_ptr d_ext_rw; /** evaluator */ std::unique_ptr d_eval; /** (recursive) function evaluator utility */ @@ -461,7 +459,6 @@ class TermDbSygus { /** get anchor */ static Node getAnchor( Node n ); static unsigned getAnchorDepth( Node n ); - }; } // namespace quantifiers diff --git a/src/theory/quantifiers/term_registry.cpp b/src/theory/quantifiers/term_registry.cpp index 324217798..36dc8865c 100644 --- a/src/theory/quantifiers/term_registry.cpp +++ b/src/theory/quantifiers/term_registry.cpp @@ -29,7 +29,9 @@ namespace cvc5 { namespace theory { namespace quantifiers { -TermRegistry::TermRegistry(QuantifiersState& qs, QuantifiersRegistry& qr) +TermRegistry::TermRegistry(Env& env, + QuantifiersState& qs, + QuantifiersRegistry& qr) : d_presolve(qs.getUserContext(), true), d_presolveCache(qs.getUserContext()), d_termEnum(new TermEnumeration), @@ -42,7 +44,7 @@ TermRegistry::TermRegistry(QuantifiersState& qs, QuantifiersRegistry& qr) if (options::sygus() || options::sygusInst()) { // must be constructed here since it is required for datatypes finistInit - d_sygusTdb.reset(new TermDbSygus(qs)); + d_sygusTdb.reset(new TermDbSygus(env, qs)); } Trace("quant-engine-debug") << "Initialize quantifiers engine." << std::endl; Trace("quant-engine-debug") diff --git a/src/theory/quantifiers/term_registry.h b/src/theory/quantifiers/term_registry.h index c3e4fcf4c..e0ce73286 100644 --- a/src/theory/quantifiers/term_registry.h +++ b/src/theory/quantifiers/term_registry.h @@ -42,8 +42,7 @@ class TermRegistry using NodeSet = context::CDHashSet; public: - TermRegistry(QuantifiersState& qs, - QuantifiersRegistry& qr); + TermRegistry(Env& env, QuantifiersState& qs, QuantifiersRegistry& qr); /** Finish init, which sets the inference manager on modules of this class */ void finishInit(FirstOrderModel* fm, QuantifiersInferenceManager* qim); /** Presolve */ diff --git a/src/theory/quantifiers/theory_quantifiers.cpp b/src/theory/quantifiers/theory_quantifiers.cpp index dff0ac979..137e25c89 100644 --- a/src/theory/quantifiers/theory_quantifiers.cpp +++ b/src/theory/quantifiers/theory_quantifiers.cpp @@ -36,13 +36,13 @@ TheoryQuantifiers::TheoryQuantifiers(Env& env, : Theory(THEORY_QUANTIFIERS, env, out, valuation), d_qstate(env, valuation, logicInfo()), d_qreg(), - d_treg(d_qstate, d_qreg), + d_treg(env, d_qstate, d_qreg), d_qim(env, *this, d_qstate, d_qreg, d_treg, d_pnm), d_qengine(nullptr) { // construct the quantifiers engine d_qengine.reset( - new QuantifiersEngine(d_qstate, d_qreg, d_treg, d_qim, d_pnm)); + new QuantifiersEngine(env, d_qstate, d_qreg, d_treg, d_qim, d_pnm)); // indicate we are using the quantifiers theory state object d_theoryState = &d_qstate; diff --git a/src/theory/quantifiers_engine.cpp b/src/theory/quantifiers_engine.cpp index 213b6c55e..40923ad0d 100644 --- a/src/theory/quantifiers_engine.cpp +++ b/src/theory/quantifiers_engine.cpp @@ -45,12 +45,14 @@ namespace cvc5 { namespace theory { QuantifiersEngine::QuantifiersEngine( + Env& env, quantifiers::QuantifiersState& qs, quantifiers::QuantifiersRegistry& qr, quantifiers::TermRegistry& tr, quantifiers::QuantifiersInferenceManager& qim, ProofNodeManager* pnm) - : d_qstate(qs), + : EnvObj(env), + d_qstate(qs), d_qim(qim), d_te(nullptr), d_pnm(pnm), @@ -113,7 +115,7 @@ void QuantifiersEngine::finishInit(TheoryEngine* te) // Initialize the modules and the utilities here. d_qmodules.reset(new quantifiers::QuantifiersModules); d_qmodules->initialize( - d_qstate, d_qim, d_qreg, d_treg, d_builder.get(), d_modules); + d_env, d_qstate, d_qim, d_qreg, d_treg, d_builder.get(), d_modules); if (d_qmodules->d_rel_dom.get()) { d_util.push_back(d_qmodules->d_rel_dom.get()); diff --git a/src/theory/quantifiers_engine.h b/src/theory/quantifiers_engine.h index 23b6d9708..e8c385fcd 100644 --- a/src/theory/quantifiers_engine.h +++ b/src/theory/quantifiers_engine.h @@ -24,6 +24,7 @@ #include "context/cdhashmap.h" #include "context/cdhashset.h" #include "context/cdlist.h" +#include "smt/env_obj.h" #include "theory/quantifiers/quant_util.h" namespace cvc5 { @@ -51,14 +52,18 @@ class TermEnumeration; class TermRegistry; } -// TODO: organize this more/review this, github issue #1163 -class QuantifiersEngine { +/** + * The main class that manages techniques for quantified formulas. + */ +class QuantifiersEngine : protected EnvObj +{ friend class ::cvc5::TheoryEngine; typedef context::CDHashMap BoolMap; typedef context::CDHashSet NodeSet; public: - QuantifiersEngine(quantifiers::QuantifiersState& qstate, + QuantifiersEngine(Env& env, + quantifiers::QuantifiersState& qstate, quantifiers::QuantifiersRegistry& qr, quantifiers::TermRegistry& tr, quantifiers::QuantifiersInferenceManager& qim, @@ -208,7 +213,7 @@ public: std::map d_quants_red_lem; /** Number of rounds we have instantiated */ uint32_t d_numInstRoundsLemma; -};/* class QuantifiersEngine */ +}; /* class QuantifiersEngine */ } // namespace theory } // namespace cvc5 diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp index 5c4cc5536..460813084 100644 --- a/src/theory/rewriter.cpp +++ b/src/theory/rewriter.cpp @@ -92,10 +92,6 @@ struct RewriteStackElement { NodeBuilder d_builder; }; -RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n) -{ - return RewriteResponse(REWRITE_DONE, n); -} Node Rewriter::rewrite(TNode node) { if (node.getNumChildren() == 0) @@ -107,6 +103,17 @@ Node Rewriter::rewrite(TNode node) { return getInstance()->rewriteTo(theoryOf(node), node); } +Node Rewriter::callExtendedRewrite(TNode node, bool aggr) +{ + return getInstance()->extendedRewrite(node, aggr); +} + +Node Rewriter::extendedRewrite(TNode node, bool aggr) +{ + quantifiers::ExtendedRewriter er(*this, aggr); + return er.extendedRewrite(node); +} + TrustNode Rewriter::rewriteWithProof(TNode node, bool isExtEq) { @@ -480,8 +487,7 @@ Node Rewriter::rewriteViaMethod(TNode n, MethodId idr) } if (idr == MethodId::RW_EXT_REWRITE) { - quantifiers::ExtendedRewriter er; - return er.extendedRewrite(n); + return extendedRewrite(n); } if (idr == MethodId::RW_REWRITE_EQ_EXT) { diff --git a/src/theory/rewriter.h b/src/theory/rewriter.h index 63628b0af..d87043a67 100644 --- a/src/theory/rewriter.h +++ b/src/theory/rewriter.h @@ -29,41 +29,25 @@ class TrustNode; namespace theory { -namespace builtin { -class BuiltinProofRuleChecker; -} - -/** - * The rewrite environment holds everything that the individual rewrites have - * access to. - */ -class RewriteEnvironment -{ -}; - -/** - * The identity rewrite just returns the original node. - * - * @param re The rewrite environment - * @param n The node to rewrite - * @return The original node - */ -RewriteResponse identityRewrite(RewriteEnvironment* re, TNode n); - /** * The main rewriter class. */ class Rewriter { - friend builtin::BuiltinProofRuleChecker; public: Rewriter(); /** + * !!! Temporary until static access to rewriter is eliminated. + * * Rewrites the node using theoryOf() to determine which rewriter to * use on the node. */ static Node rewrite(TNode node); + /** + * !!! Temporary until static access to rewriter is eliminated. + */ + static Node callExtendedRewrite(TNode node, bool aggr = true); /** * Rewrites the equality node using theoryOf() to determine which rewriter to @@ -77,6 +61,16 @@ class Rewriter { */ Node rewriteEqualityExt(TNode node); + /** + * Extended rewrite of the given node. This method is implemented by a + * custom ExtendRewriter class that wraps this class to perform custom + * rewrites (usually those that are not useful for solving, but e.g. useful + * for SyGuS symmetry breaking). + * @param node The node to rewrite + * @param aggr Whether to perform aggressive rewrites. + */ + Node extendedRewrite(TNode node, bool aggr = true); + /** * Rewrite with proof production, which is managed by the term conversion * proof generator managed by this class (d_tpg). This method requires a call @@ -174,8 +168,6 @@ class Rewriter { /** Theory rewriters used by this rewriter instance */ TheoryRewriter* d_theoryRewriters[theory::THEORY_LAST]; - RewriteEnvironment d_re; - /** The proof generator */ std::unique_ptr d_tpg; #ifdef CVC5_ASSERTIONS diff --git a/test/unit/theory/sequences_rewriter_white.cpp b/test/unit/theory/sequences_rewriter_white.cpp index 77671dc34..b7339942e 100644 --- a/test/unit/theory/sequences_rewriter_white.cpp +++ b/test/unit/theory/sequences_rewriter_white.cpp @@ -20,7 +20,6 @@ #include "expr/node.h" #include "expr/node_manager.h" #include "test_smt.h" -#include "theory/quantifiers/extended_rewrite.h" #include "theory/rewriter.h" #include "theory/strings/arith_entail.h" #include "theory/strings/sequences_rewriter.h" @@ -32,7 +31,6 @@ namespace cvc5 { using namespace theory; -using namespace theory::quantifiers; using namespace theory::strings; namespace test { @@ -44,10 +42,10 @@ class TestTheoryWhiteSequencesRewriter : public TestSmt { TestSmt::SetUp(); Options opts; - d_rewriter.reset(new ExtendedRewriter(true)); + d_rewriter = d_smtEngine->getRewriter(); } - std::unique_ptr d_rewriter; + Rewriter* d_rewriter; void inNormalForm(Node t) { @@ -155,7 +153,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption) Node slen_y = d_nodeManager->mkNode(kind::STRING_LENGTH, y); Node x_plus_slen_y = d_nodeManager->mkNode(kind::PLUS, x, slen_y); - Node x_plus_slen_y_eq_zero = Rewriter::rewrite( + Node x_plus_slen_y_eq_zero = d_rewriter->rewrite( d_nodeManager->mkNode(kind::EQUAL, x_plus_slen_y, zero)); // x + (str.len y) = 0 |= 0 >= x --> true @@ -166,7 +164,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption) ASSERT_FALSE( ArithEntail::checkWithAssumption(x_plus_slen_y_eq_zero, zero, x, true)); - Node x_plus_slen_y_plus_z_eq_zero = Rewriter::rewrite(d_nodeManager->mkNode( + Node x_plus_slen_y_plus_z_eq_zero = d_rewriter->rewrite(d_nodeManager->mkNode( kind::EQUAL, d_nodeManager->mkNode(kind::PLUS, x_plus_slen_y, z), zero)); // x + (str.len y) + z = 0 |= 0 > x --> false @@ -174,7 +172,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption) x_plus_slen_y_plus_z_eq_zero, zero, x, true)); Node x_plus_slen_y_plus_slen_y_eq_zero = - Rewriter::rewrite(d_nodeManager->mkNode( + d_rewriter->rewrite(d_nodeManager->mkNode( kind::EQUAL, d_nodeManager->mkNode(kind::PLUS, x_plus_slen_y, slen_y), zero)); @@ -187,7 +185,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption) Node six = d_nodeManager->mkConst(Rational(6)); Node x_plus_five = d_nodeManager->mkNode(kind::PLUS, x, five); Node x_plus_five_lt_six = - Rewriter::rewrite(d_nodeManager->mkNode(kind::LT, x_plus_five, six)); + d_rewriter->rewrite(d_nodeManager->mkNode(kind::LT, x_plus_five, six)); // x + 5 < 6 |= 0 >= x --> true ASSERT_TRUE( @@ -199,7 +197,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption) Node neg_x = d_nodeManager->mkNode(kind::UMINUS, x); Node x_plus_five_lt_five = - Rewriter::rewrite(d_nodeManager->mkNode(kind::LT, x_plus_five, five)); + d_rewriter->rewrite(d_nodeManager->mkNode(kind::LT, x_plus_five, five)); // x + 5 < 5 |= -x >= 0 --> true ASSERT_TRUE(ArithEntail::checkWithAssumption( @@ -210,7 +208,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption) ArithEntail::checkWithAssumption(x_plus_five_lt_five, zero, x, false)); // 0 < x |= x >= (str.len (int.to.str x)) - Node assm = Rewriter::rewrite(d_nodeManager->mkNode(kind::LT, zero, x)); + Node assm = d_rewriter->rewrite(d_nodeManager->mkNode(kind::LT, zero, x)); ASSERT_TRUE(ArithEntail::checkWithAssumption( assm, x, diff --git a/test/unit/theory/theory_arith_cad_white.cpp b/test/unit/theory/theory_arith_cad_white.cpp index 4c9b104c7..3d5f7a8c5 100644 --- a/test/unit/theory/theory_arith_cad_white.cpp +++ b/test/unit/theory/theory_arith_cad_white.cpp @@ -29,7 +29,7 @@ #include "theory/arith/nl/nl_lemma_utils.h" #include "theory/arith/nl/poly_conversion.h" #include "theory/arith/theory_arith.h" -#include "theory/quantifiers/extended_rewrite.h" +#include "theory/rewriter.h" #include "theory/theory.h" #include "theory/theory_engine.h" #include "util/poly_util.h" @@ -193,8 +193,7 @@ TEST_F(TestTheoryWhiteArithCAD, lazard_simp) EXPECT_NE(rewritten, d_nodeManager->mkConst(false)); } { - quantifiers::ExtendedRewriter extrew; - Node rewritten = extrew.extendedRewrite(orig); + Node rewritten = Rewriter::callExtendedRewrite(orig); EXPECT_EQ(rewritten, d_nodeManager->mkConst(false)); } } -- 2.30.2