From 01cde22b7d69c1b1037cf1d536ca62becc3bd865 Mon Sep 17 00:00:00 2001 From: Aina Niemetz Date: Tue, 7 Sep 2021 11:23:29 -0700 Subject: [PATCH] sygus: Eliminate calls to Rewriter::rewrite. (#7142) This derives sygus unification utility objects from EnvObj where necessary. There's one remaining occurrence of Rewriter::rewrite in sygus_unif_rl.cpp, which is a bit tricky to address and thus subject to a future PR. --- src/theory/quantifiers/sygus/cegis_unif.cpp | 4 ++- src/theory/quantifiers/sygus/sygus_pbe.cpp | 2 +- src/theory/quantifiers/sygus/sygus_unif.cpp | 9 +++++-- src/theory/quantifiers/sygus/sygus_unif.h | 4 +-- .../quantifiers/sygus/sygus_unif_io.cpp | 27 ++++++++++--------- src/theory/quantifiers/sygus/sygus_unif_io.h | 2 +- .../quantifiers/sygus/sygus_unif_rl.cpp | 27 ++++++++++--------- src/theory/quantifiers/sygus/sygus_unif_rl.h | 2 +- .../quantifiers/sygus/sygus_unif_strat.cpp | 3 +-- .../quantifiers/sygus/sygus_unif_strat.h | 6 +++-- 10 files changed, 49 insertions(+), 37 deletions(-) diff --git a/src/theory/quantifiers/sygus/cegis_unif.cpp b/src/theory/quantifiers/sygus/cegis_unif.cpp index c4d9cbd4a..797aecdab 100644 --- a/src/theory/quantifiers/sygus/cegis_unif.cpp +++ b/src/theory/quantifiers/sygus/cegis_unif.cpp @@ -34,7 +34,9 @@ CegisUnif::CegisUnif(QuantifiersState& qs, QuantifiersInferenceManager& qim, TermDbSygus* tds, SynthConjecture* p) - : Cegis(qs, qim, tds, p), d_sygus_unif(p), d_u_enum_manager(qs, qim, tds, p) + : Cegis(qs, qim, tds, p), + d_sygus_unif(qs.getEnv(), p), + d_u_enum_manager(qs, qim, tds, p) { } diff --git a/src/theory/quantifiers/sygus/sygus_pbe.cpp b/src/theory/quantifiers/sygus/sygus_pbe.cpp index 26621eb96..7601e2117 100644 --- a/src/theory/quantifiers/sygus/sygus_pbe.cpp +++ b/src/theory/quantifiers/sygus/sygus_pbe.cpp @@ -71,7 +71,7 @@ bool SygusPbe::initialize(Node conj, for (const Node& c : candidates) { Assert(ei->hasExamples(c)); - d_sygus_unif[c].reset(new SygusUnifIo(d_parent)); + d_sygus_unif[c].reset(new SygusUnifIo(d_env, d_parent)); Trace("sygus-pbe") << "Initialize unif utility for " << c << "..." << std::endl; std::map> strategy_lemmas; diff --git a/src/theory/quantifiers/sygus/sygus_unif.cpp b/src/theory/quantifiers/sygus/sygus_unif.cpp index 00370ffa2..0787b7913 100644 --- a/src/theory/quantifiers/sygus/sygus_unif.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif.cpp @@ -27,7 +27,11 @@ namespace cvc5 { namespace theory { namespace quantifiers { -SygusUnif::SygusUnif() : d_tds(nullptr), d_enableMinimality(false) {} +SygusUnif::SygusUnif(Env& env) + : EnvObj(env), d_tds(nullptr), d_enableMinimality(false) +{ +} + SygusUnif::~SygusUnif() {} void SygusUnif::initializeCandidate( @@ -39,7 +43,8 @@ void SygusUnif::initializeCandidate( d_tds = tds; d_candidates.push_back(f); // initialize the strategy - d_strategy[f].initialize(tds, f, enums); + d_strategy.emplace(f, SygusUnifStrategy(d_env)); + d_strategy.at(f).initialize(tds, f, enums); } Node SygusUnif::getMinimalTerm(const std::vector& terms) diff --git a/src/theory/quantifiers/sygus/sygus_unif.h b/src/theory/quantifiers/sygus/sygus_unif.h index 2e9e34aa6..80368ea13 100644 --- a/src/theory/quantifiers/sygus/sygus_unif.h +++ b/src/theory/quantifiers/sygus/sygus_unif.h @@ -42,10 +42,10 @@ class TermDbSygus; * Based on the above, solutions can be constructed via calls to * constructSolution. */ -class SygusUnif +class SygusUnif : protected EnvObj { public: - SygusUnif(); + SygusUnif(Env& env); virtual ~SygusUnif(); /** initialize candidate diff --git a/src/theory/quantifiers/sygus/sygus_unif_io.cpp b/src/theory/quantifiers/sygus/sygus_unif_io.cpp index 8207a07f2..9626f7af4 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_io.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif_io.cpp @@ -509,8 +509,9 @@ void SubsumeTrie::getLeaves(const std::vector& vals, getLeavesInternal(vals, pol, v, 0, -2); } -SygusUnifIo::SygusUnifIo(SynthConjecture* p) - : d_parent(p), +SygusUnifIo::SygusUnifIo(Env& env, SynthConjecture* p) + : SygusUnif(env), + d_parent(p), d_check_sol(false), d_cond_count(0), d_sol_term_size(0), @@ -549,7 +550,7 @@ void SygusUnifIo::initializeCandidate( d_ecache.clear(); SygusUnif::initializeCandidate(tds, f, enums, strategy_lemmas); // learn redundant operators based on the strategy - d_strategy[f].staticLearnRedundantOps(strategy_lemmas); + d_strategy.at(f).staticLearnRedundantOps(strategy_lemmas); } void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector& lemmas) @@ -560,7 +561,7 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector& lemmas) Assert(!d_examples.empty()); Assert(d_examples.size() == d_examples_out.size()); - EnumInfo& ei = d_strategy[c].getEnumInfo(e); + EnumInfo& ei = d_strategy.at(c).getEnumInfo(e); // The explanation for why the current value should be excluded in future // iterations. Node exp_exc; @@ -583,7 +584,7 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector& lemmas) for (const Node& xs : ei.d_enum_slave) { Assert(srmap.find(xs) == srmap.end()); - EnumInfo& eiv = d_strategy[c].getEnumInfo(xs); + EnumInfo& eiv = d_strategy.at(c).getEnumInfo(xs); Node templ = eiv.d_template; if (!templ.isNull()) { @@ -628,7 +629,7 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector& lemmas) for (unsigned s = 0; s < ei.d_enum_slave.size(); s++) { Node xs = ei.d_enum_slave[s]; - EnumInfo& eiv = d_strategy[c].getEnumInfo(xs); + EnumInfo& eiv = d_strategy.at(c).getEnumInfo(xs); EnumCache& ecv = d_ecache[xs]; Trace("sygus-sui-enum") << "Process " << xs << " from " << s << std::endl; // bool prevIsCover = false; @@ -829,7 +830,7 @@ Node SygusUnifIo::constructSolutionNode(std::vector& lemmas) initializeConstructSol(); initializeConstructSolFor(c); // call the virtual construct solution method - Node e = d_strategy[c].getRootEnumerator(); + Node e = d_strategy.at(c).getRootEnumerator(); Node vcc = constructSol(c, e, role_equal, 1, lemmas); // if we constructed the solution, and we either did not previously have // a solution, or the new solution is better (smaller). @@ -892,10 +893,10 @@ bool SygusUnifIo::useStrContainsEnumeratorExclude(Node e) << "Is " << e << " is str.contains exclusion?" << std::endl; d_use_str_contains_eexc[e] = true; Node c = d_candidate; - EnumInfo& ei = d_strategy[c].getEnumInfo(e); + EnumInfo& ei = d_strategy.at(c).getEnumInfo(e); for (const Node& sn : ei.d_enum_slave) { - EnumInfo& eis = d_strategy[c].getEnumInfo(sn); + EnumInfo& eis = d_strategy.at(c).getEnumInfo(sn); EnumRole er = eis.getRole(); if (er != enum_io && er != enum_concat_term) { @@ -945,7 +946,7 @@ bool SygusUnifIo::getExplanationForEnumeratorExclude( << "Check enumerator exclusion for " << e << " -> " << d_tds->sygusToBuiltin(v) << " based on str.contains." << std::endl; std::vector cmp_indices; - for (unsigned i = 0, size = results.size(); i < size; i++) + for (size_t i = 0, size = results.size(); i < size; i++) { // If the result is not constant, then it is worthless. It does not // impact whether the term is excluded. @@ -955,7 +956,7 @@ bool SygusUnifIo::getExplanationForEnumeratorExclude( Trace("sygus-sui-cterm-debug") << " " << results[i] << " <> " << d_examples_out[i]; Node cont = nm->mkNode(STRING_CONTAINS, d_examples_out[i], results[i]); - Node contr = Rewriter::rewrite(cont); + Node contr = rewrite(cont); if (contr == d_false) { cmp_indices.push_back(i); @@ -1039,10 +1040,10 @@ Node SygusUnifIo::constructSol( Trace("sygus-sui-dt-debug") << std::endl; } // enumerator type info - EnumTypeInfo& tinfo = d_strategy[f].getEnumTypeInfo(etn); + EnumTypeInfo& tinfo = d_strategy.at(f).getEnumTypeInfo(etn); // get the enumerator information - EnumInfo& einfo = d_strategy[f].getEnumInfo(e); + EnumInfo& einfo = d_strategy.at(f).getEnumInfo(e); EnumCache& ecache = d_ecache[e]; diff --git a/src/theory/quantifiers/sygus/sygus_unif_io.h b/src/theory/quantifiers/sygus/sygus_unif_io.h index fd918c996..79db22ab3 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_io.h +++ b/src/theory/quantifiers/sygus/sygus_unif_io.h @@ -265,7 +265,7 @@ class SygusUnifIo : public SygusUnif friend class UnifContextIo; public: - SygusUnifIo(SynthConjecture* p); + SygusUnifIo(Env& env, SynthConjecture* p); ~SygusUnifIo(); /** initialize diff --git a/src/theory/quantifiers/sygus/sygus_unif_rl.cpp b/src/theory/quantifiers/sygus/sygus_unif_rl.cpp index 09060926b..5af838354 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_rl.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif_rl.cpp @@ -32,8 +32,11 @@ namespace cvc5 { namespace theory { namespace quantifiers { -SygusUnifRl::SygusUnifRl(SynthConjecture* p) - : d_parent(p), d_useCondPool(false), d_useCondPoolIGain(false) +SygusUnifRl::SygusUnifRl(Env& env, SynthConjecture* p) + : SygusUnif(env), + d_parent(p), + d_useCondPool(false), + d_useCondPoolIGain(false) { } SygusUnifRl::~SygusUnifRl() {} @@ -55,7 +58,7 @@ void SygusUnifRl::initializeCandidate( } // register the strategy registerStrategy(f, enums, restrictions.d_unused_strategies); - d_strategy[f].staticLearnRedundantOps(strategy_lemmas, restrictions); + d_strategy.at(f).staticLearnRedundantOps(strategy_lemmas, restrictions); // Copy candidates and check whether CegisUnif for any of them if (d_unif_candidates.find(f) != d_unif_candidates.end()) { @@ -118,7 +121,7 @@ Node SygusUnifRl::purifyLemma(Node n, TNode cand = n[0]; Node tmp = n.substitute(cand, it1->second); // should be concrete, can just use the rewriter - nv = Rewriter::rewrite(tmp); + nv = rewrite(tmp); Trace("sygus-unif-rl-purify") << "PurifyLemma : model value for " << tmp << " is " << nv << "\n"; } @@ -231,7 +234,7 @@ Node SygusUnifRl::purifyLemma(Node n, Trace("sygus-unif-rl-purify") << "PurifyLemma : adding model eq " << model_guards.back() << "\n"; } - nb = Rewriter::rewrite(nb); + nb = rewrite(nb); // every non-top level application of function-to-synthesize must be reduced // to a concrete constant Assert(!ensureConst || nb.isConst()); @@ -262,7 +265,7 @@ Node SygusUnifRl::addRefLemma(Node lemma, model_guards.push_back(plem); plem = NodeManager::currentNM()->mkNode(OR, model_guards); } - plem = Rewriter::rewrite(plem); + plem = rewrite(plem); Trace("sygus-unif-rl-purify") << "Purified lemma : " << plem << "\n"; Trace("sygus-unif-rl-purify") << "Collect new evaluation points...\n"; @@ -316,7 +319,7 @@ bool SygusUnifRl::constructSolution(std::vector& sols, } initializeConstructSolFor(c); Node v = constructSol( - c, d_strategy[c].getRootEnumerator(), role_equal, 0, lemmas); + c, d_strategy.at(c).getRootEnumerator(), role_equal, 0, lemmas); if (v.isNull()) { // we continue trying to build solutions to accumulate potentitial @@ -337,7 +340,7 @@ Node SygusUnifRl::constructSol( Trace("sygus-unif-sol") << "ConstructSol: SygusRL : " << e << std::endl; // retrieve strategy information TypeNode etn = e.getType(); - EnumTypeInfo& tinfo = d_strategy[f].getEnumTypeInfo(etn); + EnumTypeInfo& tinfo = d_strategy.at(f).getEnumTypeInfo(etn); StrategyNode& snode = tinfo.getStrategyNode(nrole); if (nrole != role_equal) { @@ -412,10 +415,10 @@ void SygusUnifRl::registerStrategy( { Trace("sygus-unif-rl-strat") << "Strategy for " << f << " is : " << std::endl; - d_strategy[f].debugPrint("sygus-unif-rl-strat"); + d_strategy.at(f).debugPrint("sygus-unif-rl-strat"); } Trace("sygus-unif-rl-strat") << "Register..." << std::endl; - Node e = d_strategy[f].getRootEnumerator(); + Node e = d_strategy.at(f).getRootEnumerator(); std::map> visited; registerStrategyNode(f, e, role_equal, visited, enums, unused_strats); } @@ -435,7 +438,7 @@ void SygusUnifRl::registerStrategyNode( } visited[e][nrole] = true; TypeNode etn = e.getType(); - EnumTypeInfo& tinfo = d_strategy[f].getEnumTypeInfo(etn); + EnumTypeInfo& tinfo = d_strategy.at(f).getEnumTypeInfo(etn); StrategyNode& snode = tinfo.getStrategyNode(nrole); for (unsigned j = 0, size = snode.d_strats.size(); j < size; j++) { @@ -497,7 +500,7 @@ void SygusUnifRl::registerConditionalEnumerator(Node f, d_cenum_to_stratpt[cond].clear(); } // register that this strategy node has a decision tree construction - d_stratpt_to_dt[e].initialize(cond, this, &d_strategy[f], strategy_index); + d_stratpt_to_dt[e].initialize(cond, this, &d_strategy.at(f), strategy_index); // associate conditional enumerator with strategy node d_cenum_to_stratpt[cond].push_back(e); } diff --git a/src/theory/quantifiers/sygus/sygus_unif_rl.h b/src/theory/quantifiers/sygus/sygus_unif_rl.h index 28506d0bd..11a562d79 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_rl.h +++ b/src/theory/quantifiers/sygus/sygus_unif_rl.h @@ -48,7 +48,7 @@ class SynthConjecture; class SygusUnifRl : public SygusUnif { public: - SygusUnifRl(SynthConjecture* p); + SygusUnifRl(Env& env, SynthConjecture* p); ~SygusUnifRl(); /** initialize */ diff --git a/src/theory/quantifiers/sygus/sygus_unif_strat.cpp b/src/theory/quantifiers/sygus/sygus_unif_strat.cpp index 15b220c74..10db1ef9e 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_strat.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif_strat.cpp @@ -23,7 +23,6 @@ #include "theory/quantifiers/sygus/sygus_unif.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_util.h" -#include "theory/rewriter.h" using namespace std; using namespace cvc5::kind; @@ -443,7 +442,7 @@ void SygusUnifStrategy::buildStrategyGraph(TypeNode tn, NodeRole nrole) teut = children.size() == 1 ? children[0] : nm->mkNode(eut.getKind(), children); - teut = Rewriter::rewrite(teut); + teut = rewrite(teut); } else { diff --git a/src/theory/quantifiers/sygus/sygus_unif_strat.h b/src/theory/quantifiers/sygus/sygus_unif_strat.h index 11950b0c2..fadf68e3f 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_strat.h +++ b/src/theory/quantifiers/sygus/sygus_unif_strat.h @@ -19,7 +19,9 @@ #define CVC5__THEORY__QUANTIFIERS__SYGUS_UNIF_STRAT_H #include + #include "expr/node.h" +#include "smt/env_obj.h" namespace cvc5 { namespace theory { @@ -276,10 +278,10 @@ struct StrategyRestrictions * the grammar of the function to synthesize f. This tree is represented by * the above classes. */ -class SygusUnifStrategy +class SygusUnifStrategy : protected EnvObj { public: - SygusUnifStrategy() : d_tds(nullptr) {} + SygusUnifStrategy(Env& env) : EnvObj(env), d_tds(nullptr) {} /** initialize * * This initializes this class with function-to-synthesize f. We also call -- 2.30.2