From: Andrew Reynolds Date: Thu, 16 Dec 2021 22:16:03 +0000 (-0600) Subject: Eliminate most static calls to rewrite in quantifiers (#7823) X-Git-Tag: cvc5-1.0.0~650 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=03cb06e30a13bdb9e6a3c6c3d54bfe7411f27ec8;p=cvc5.git Eliminate most static calls to rewrite in quantifiers (#7823) --- diff --git a/src/theory/datatypes/sygus_extension.cpp b/src/theory/datatypes/sygus_extension.cpp index 90511112c..0318f7da9 100644 --- a/src/theory/datatypes/sygus_extension.cpp +++ b/src/theory/datatypes/sygus_extension.cpp @@ -1052,7 +1052,7 @@ Node SygusExtension::registerSearchValue(Node a, Trace("dt-sygus") << " * DT builtin : " << n << " -> " << bvr << std::endl; unsigned sz = utils::getSygusTermSize(nv); if( d_tds->involvesDivByZero( bvr ) ){ - quantifiers::DivByZeroSygusInvarianceTest dbzet; + quantifiers::DivByZeroSygusInvarianceTest dbzet(d_env.getRewriter()); Trace("sygus-sb-mexp-debug") << "Minimize explanation for div-by-zero in " << bv << std::endl; registerSymBreakLemmaForValue(a, nv, dbzet, Node::null(), var_count); @@ -1161,7 +1161,7 @@ Node SygusExtension::registerSearchValue(Node a, // generalize the explanation for why the analog of bad_val // is equivalent to bvr - quantifiers::EquivSygusInvarianceTest eset; + quantifiers::EquivSygusInvarianceTest eset(d_env.getRewriter()); eset.init(d_tds, tn, aconj, a, bvr); Trace("sygus-sb-mexp-debug") << "Minimize explanation for eval[" << d_tds->sygusToBuiltin( bad_val ) << "] = " << bvr << std::endl; diff --git a/src/theory/quantifiers/bv_inverter.cpp b/src/theory/quantifiers/bv_inverter.cpp index cfcc5f5a1..75db29207 100644 --- a/src/theory/quantifiers/bv_inverter.cpp +++ b/src/theory/quantifiers/bv_inverter.cpp @@ -31,6 +31,8 @@ namespace cvc5 { namespace theory { namespace quantifiers { +BvInverter::BvInverter(Rewriter* r) : d_rewriter(r) {} + /*---------------------------------------------------------------------------*/ Node BvInverter::getSolveVariable(TypeNode tn) @@ -53,12 +55,16 @@ Node BvInverter::getInversionNode(Node cond, TypeNode tn, BvInverterQuery* m) TNode solve_var = getSolveVariable(tn); // condition should be rewritten - Node new_cond = Rewriter::rewrite(cond); - if (new_cond != cond) + Node new_cond = cond; + if (d_rewriter != nullptr) { - Trace("cegqi-bv-skvinv-debug") - << "Condition " << cond << " was rewritten to " << new_cond - << std::endl; + new_cond = d_rewriter->rewrite(cond); + if (new_cond != cond) + { + Trace("cegqi-bv-skvinv-debug") + << "Condition " << cond << " was rewritten to " << new_cond + << std::endl; + } } // optimization : if condition is ( x = solve_var ) should just return // solve_var and not introduce a Skolem this can happen when we ask for diff --git a/src/theory/quantifiers/bv_inverter.h b/src/theory/quantifiers/bv_inverter.h index e840b53de..835637a30 100644 --- a/src/theory/quantifiers/bv_inverter.h +++ b/src/theory/quantifiers/bv_inverter.h @@ -27,6 +27,9 @@ namespace cvc5 { namespace theory { + +class Rewriter; + namespace quantifiers { /** BvInverterQuery @@ -50,7 +53,7 @@ class BvInverterQuery class BvInverter { public: - BvInverter() {} + BvInverter(Rewriter* r = nullptr); ~BvInverter() {} /** get dummy fresh variable of type tn, used as argument for sv */ Node getSolveVariable(TypeNode tn); @@ -96,9 +99,6 @@ class BvInverter BvInverterQuery* m); private: - /** Dummy variables for each type */ - std::map d_solve_var; - /** Helper function for getPathToPv */ Node getPathToPv(Node lit, Node pv, @@ -125,6 +125,10 @@ class BvInverter * to this call is null. */ Node getInversionNode(Node cond, TypeNode tn, BvInverterQuery* m); + /** (Optional) rewriter used as helper in getInversionNode */ + Rewriter* d_rewriter; + /** Dummy variables for each type */ + std::map d_solve_var; }; } // namespace quantifiers diff --git a/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp b/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp index 2555ff637..bbc853ee4 100644 --- a/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp +++ b/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp @@ -66,7 +66,7 @@ InstStrategyCegqi::InstStrategyCegqi(Env& env, if (options().quantifiers.cegqiBv) { // if doing instantiation for BV, need the inverter class - d_bv_invert.reset(new BvInverter); + d_bv_invert.reset(new BvInverter(env.getRewriter())); } if (options().quantifiers.cegqiNestedQE) { diff --git a/src/theory/quantifiers/quant_bound_inference.cpp b/src/theory/quantifiers/quant_bound_inference.cpp index dfffe64cf..4094a6638 100644 --- a/src/theory/quantifiers/quant_bound_inference.cpp +++ b/src/theory/quantifiers/quant_bound_inference.cpp @@ -16,7 +16,6 @@ #include "theory/quantifiers/quant_bound_inference.h" #include "theory/quantifiers/fmf/bounded_integers.h" -#include "theory/rewriter.h" #include "util/rational.h" using namespace cvc5::kind; @@ -60,13 +59,8 @@ bool QuantifiersBoundInference::mayComplete(TypeNode tn, unsigned maxCard) Cardinality c = tn.getCardinality(); if (!c.isLargeFinite()) { - NodeManager* nm = NodeManager::currentNM(); - Node card = nm->mkConstInt(Rational(c.getFiniteCardinality())); // check if less than fixed upper bound - Node oth = nm->mkConstInt(Rational(maxCard)); - Node eq = nm->mkNode(LEQ, card, oth); - eq = Rewriter::rewrite(eq); - mc = eq.isConst() && eq.getConst(); + mc = (c.getFiniteCardinality() < Integer(maxCard)); } } return mc; diff --git a/src/theory/quantifiers/quantifiers_rewriter.cpp b/src/theory/quantifiers/quantifiers_rewriter.cpp index c53809d6e..10c0a315b 100644 --- a/src/theory/quantifiers/quantifiers_rewriter.cpp +++ b/src/theory/quantifiers/quantifiers_rewriter.cpp @@ -511,9 +511,9 @@ Node QuantifiersRewriter::computeProcessTerms2( { // check if it rewrites to a constant Node nn = nm->mkNode(EQUAL, no, ret[i][j]); - nn = Rewriter::rewrite(nn); childrenIte.push_back(nn); - if (nn.isConst()) + // check if it will rewrite to a constant + if (no == ret[i][j] || (no.isConst() && ret[i][j].isConst())) { doRewrite = true; } diff --git a/src/theory/quantifiers/single_inv_partition.cpp b/src/theory/quantifiers/single_inv_partition.cpp index 2725e1826..8463321dc 100644 --- a/src/theory/quantifiers/single_inv_partition.cpp +++ b/src/theory/quantifiers/single_inv_partition.cpp @@ -28,6 +28,11 @@ namespace cvc5 { namespace theory { namespace quantifiers { +SingleInvocationPartition::SingleInvocationPartition(Env& env) + : EnvObj(env), d_has_input_funcs(false) +{ +} + bool SingleInvocationPartition::init(Node n) { // first, get types of arguments for functions @@ -220,7 +225,7 @@ bool SingleInvocationPartition::init(std::vector& funcs, d_input_funcs.end(), d_input_func_sks.begin(), d_input_func_sks.end()); - cr = TermUtil::getQuantSimplify(cr); + cr = getQuantSimplify(cr); cr = cr.substitute(d_input_func_sks.begin(), d_input_func_sks.end(), d_input_funcs.begin(), @@ -614,6 +619,21 @@ void SingleInvocationPartition::debugPrint(const char* c) Trace(c) << std::endl; } +Node SingleInvocationPartition::getQuantSimplify(TNode n) const +{ + std::unordered_set fvs; + expr::getFreeVariables(n, fvs); + if (fvs.empty()) + { + return rewrite(n); + } + std::vector bvs(fvs.begin(), fvs.end()); + NodeManager* nm = NodeManager::currentNM(); + Node q = nm->mkNode(FORALL, nm->mkNode(BOUND_VAR_LIST, bvs), n); + q = rewrite(q); + return TermUtil::getRemoveQuantifiers(q); +} + } // namespace quantifiers } // namespace theory } // namespace cvc5 diff --git a/src/theory/quantifiers/single_inv_partition.h b/src/theory/quantifiers/single_inv_partition.h index 1b4ea62b0..144db9346 100644 --- a/src/theory/quantifiers/single_inv_partition.h +++ b/src/theory/quantifiers/single_inv_partition.h @@ -24,6 +24,7 @@ #include "expr/node.h" #include "expr/subs.h" #include "expr/type_node.h" +#include "smt/env_obj.h" namespace cvc5 { namespace theory { @@ -56,10 +57,10 @@ namespace quantifiers { * see Example 5 of Reynolds et al. SYNT 2017. * */ -class SingleInvocationPartition +class SingleInvocationPartition : protected EnvObj { public: - SingleInvocationPartition() : d_has_input_funcs(false) {} + SingleInvocationPartition(Env& env); ~SingleInvocationPartition() {} /** initialize this partition for formula n, with input functions funcs * @@ -289,6 +290,9 @@ class SingleInvocationPartition /** get the and node corresponding to d_conjuncts[index] */ Node getConjunct(int index); + /** Quantified simplify (treat free variables in n as quantified and run + * rewriter) */ + Node getQuantSimplify(TNode n) const; }; } // namespace quantifiers diff --git a/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp b/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp index 833abdd22..bcd6ea561 100644 --- a/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp +++ b/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp @@ -39,7 +39,7 @@ namespace quantifiers { CegSingleInv::CegSingleInv(Env& env, TermRegistry& tr, SygusStatistics& s) : EnvObj(env), d_isSolved(false), - d_sip(new SingleInvocationPartition), + d_sip(new SingleInvocationPartition(env)), d_srcons(new SygusReconstruct(env, tr.getTermDatabaseSygus(), s)), d_single_invocation(false), d_treg(tr) diff --git a/src/theory/quantifiers/sygus/cegis.cpp b/src/theory/quantifiers/sygus/cegis.cpp index fdc0b28e0..6bb94f41e 100644 --- a/src/theory/quantifiers/sygus/cegis.cpp +++ b/src/theory/quantifiers/sygus/cegis.cpp @@ -516,7 +516,7 @@ bool Cegis::getRefinementEvalLemmas(const std::vector& vs, Assert(!lem.isNull()); std::map visited; std::map > exp; - EvalSygusInvarianceTest vsit; + EvalSygusInvarianceTest vsit(d_env.getRewriter()); Trace("sygus-cref-eval") << "Check refinement lemma conjunct " << lem << " against current model." << std::endl; Trace("sygus-cref-eval2") << "Check refinement lemma conjunct " << lem diff --git a/src/theory/quantifiers/sygus/enum_stream_substitution.cpp b/src/theory/quantifiers/sygus/enum_stream_substitution.cpp index a5be4ebd6..233d7f17b 100644 --- a/src/theory/quantifiers/sygus/enum_stream_substitution.cpp +++ b/src/theory/quantifiers/sygus/enum_stream_substitution.cpp @@ -33,8 +33,8 @@ namespace cvc5 { namespace theory { namespace quantifiers { -EnumStreamPermutation::EnumStreamPermutation(TermDbSygus* tds) - : d_tds(tds), d_first(true), d_curr_ind(0) +EnumStreamPermutation::EnumStreamPermutation(Env& env, TermDbSygus* tds) + : EnvObj(env), d_tds(tds), d_first(true), d_curr_ind(0) { } @@ -125,7 +125,7 @@ Node EnumStreamPermutation::getNext() { d_first = false; Node bultin_value = d_tds->sygusToBuiltin(d_value, d_value.getType()); - d_perm_values.insert(Rewriter::callExtendedRewrite(bultin_value)); + d_perm_values.insert(extendedRewrite(bultin_value)); return d_value; } unsigned n_classes = d_perm_state_class.size(); @@ -192,9 +192,9 @@ Node EnumStreamPermutation::getNext() bultin_perm_value = d_tds->sygusToBuiltin(perm_value, perm_value.getType()); Trace("synth-stream-concrete-debug") << " ......perm builtin is " << bultin_perm_value; - if (options::sygusSymBreakDynamic()) + if (options().datatypes.sygusSymBreakDynamic) { - bultin_perm_value = Rewriter::callExtendedRewrite(bultin_perm_value); + bultin_perm_value = extendedRewrite(bultin_perm_value); Trace("synth-stream-concrete-debug") << " and rewrites to " << bultin_perm_value; } @@ -327,8 +327,8 @@ bool EnumStreamPermutation::PermutationState::getNextPermutation() return true; } -EnumStreamSubstitution::EnumStreamSubstitution(quantifiers::TermDbSygus* tds) - : d_tds(tds), d_stream_permutations(tds), d_curr_ind(0) +EnumStreamSubstitution::EnumStreamSubstitution(Env& env, TermDbSygus* tds) + : EnvObj(env), d_tds(tds), d_stream_permutations(env, tds), d_curr_ind(0) { } @@ -512,9 +512,9 @@ Node EnumStreamSubstitution::getNext() // construction (unless it's equiv to a constant, e.g. true / false) Node builtin_comb_value = d_tds->sygusToBuiltin(comb_value, comb_value.getType()); - if (options::sygusSymBreakDynamic()) + if (options().datatypes.sygusSymBreakDynamic) { - builtin_comb_value = Rewriter::callExtendedRewrite(builtin_comb_value); + builtin_comb_value = extendedRewrite(builtin_comb_value); } if (Trace.isOn("synth-stream-concrete")) { @@ -606,6 +606,11 @@ bool EnumStreamSubstitution::CombinationState::getNextCombination() return new_comb; } +EnumStreamConcrete::EnumStreamConcrete(Env& env, TermDbSygus* tds) + : EnumValGenerator(env), d_ess(env, tds) +{ +} + void EnumStreamConcrete::initialize(Node e) { d_ess.initialize(e.getType()); } void EnumStreamConcrete::addValue(Node v) { diff --git a/src/theory/quantifiers/sygus/enum_stream_substitution.h b/src/theory/quantifiers/sygus/enum_stream_substitution.h index 05c693ace..d9161b56f 100644 --- a/src/theory/quantifiers/sygus/enum_stream_substitution.h +++ b/src/theory/quantifiers/sygus/enum_stream_substitution.h @@ -19,6 +19,7 @@ #define CVC5__THEORY__QUANTIFIERS__SYGUS__ENUM_STREAM_SUBSTITUTION_H #include "expr/node.h" +#include "smt/env_obj.h" #include "theory/quantifiers/sygus/enum_val_generator.h" namespace cvc5 { @@ -32,10 +33,10 @@ class TermDbSygus; * Generates a new value (modulo rewriting) when queried in which its variables * are permuted (see EnumStreamSubstitution for more details). */ -class EnumStreamPermutation +class EnumStreamPermutation : protected EnvObj { public: - EnumStreamPermutation(TermDbSygus* tds); + EnumStreamPermutation(Env& env, TermDbSygus* tds); ~EnumStreamPermutation() {} /** resets utility * @@ -164,10 +165,10 @@ class EnumStreamPermutation * Therefore when streaming concrete values, permutations and combinations are * generated by the product of the permutations and combinations of each class. */ -class EnumStreamSubstitution +class EnumStreamSubstitution : protected EnvObj { public: - EnumStreamSubstitution(TermDbSygus* tds); + EnumStreamSubstitution(Env& env, TermDbSygus* tds); ~EnumStreamSubstitution() {} /** initializes utility * @@ -283,7 +284,7 @@ class EnumStreamSubstitution class EnumStreamConcrete : public EnumValGenerator { public: - EnumStreamConcrete(TermDbSygus* tds) : d_ess(tds) {} + EnumStreamConcrete(Env& env, TermDbSygus* tds); /** initialize this class with enumerator e */ void initialize(Node e) override; /** get that value v was enumerated */ diff --git a/src/theory/quantifiers/sygus/enum_val_generator.h b/src/theory/quantifiers/sygus/enum_val_generator.h index 64c069087..ace7cc552 100644 --- a/src/theory/quantifiers/sygus/enum_val_generator.h +++ b/src/theory/quantifiers/sygus/enum_val_generator.h @@ -19,6 +19,7 @@ #define CVC5__THEORY__QUANTIFIERS__SYGUS__ENUM_VAL_GENERATOR_H #include "expr/node.h" +#include "smt/env_obj.h" namespace cvc5 { namespace theory { @@ -30,9 +31,10 @@ namespace quantifiers { * values" a1, ..., an, ..., and generate a (possibly larger) stream of * "concrete values" c11, ..., c1{m_1}, ..., cn1, ... cn{m_n}, .... */ -class EnumValGenerator +class EnumValGenerator : protected EnvObj { public: + EnumValGenerator(Env& env) : EnvObj(env) {} virtual ~EnumValGenerator() {} /** initialize this class with enumerator e */ virtual void initialize(Node e) = 0; diff --git a/src/theory/quantifiers/sygus/enum_value_manager.cpp b/src/theory/quantifiers/sygus/enum_value_manager.cpp index 7fbe1c3cd..b289e984e 100644 --- a/src/theory/quantifiers/sygus/enum_value_manager.cpp +++ b/src/theory/quantifiers/sygus/enum_value_manager.cpp @@ -82,7 +82,7 @@ Node EnumValueManager::getEnumeratedValue(bool& activeIncomplete) { if (d_tds->isVariableAgnosticEnumerator(e)) { - d_evg.reset(new EnumStreamConcrete(d_tds)); + d_evg = std::make_unique(d_env, d_tds); } else { @@ -93,12 +93,13 @@ Node EnumValueManager::getEnumeratedValue(bool& activeIncomplete) if (options().quantifiers.sygusActiveGenMode == options::SygusActiveGenMode::ENUM_BASIC) { - d_evg.reset(new EnumValGeneratorBasic(d_tds, e.getType())); + d_evg = + std::make_unique(d_env, d_tds, e.getType()); } else if (options().quantifiers.sygusActiveGenMode == options::SygusActiveGenMode::RANDOM) { - d_evg.reset(new SygusRandomEnumerator(d_tds)); + d_evg = std::make_unique(d_env, d_tds); } else { @@ -118,17 +119,18 @@ Node EnumValueManager::getEnumeratedValue(bool& activeIncomplete) // use the default output for the output of sygusRewVerify out = options().base.out; } - d_secd.reset(new SygusEnumeratorCallbackDefault( - e, &d_stats, d_eec.get(), d_samplerRrV.get(), out)); + d_secd = std::make_unique( + d_env, e, &d_stats, d_eec.get(), d_samplerRrV.get(), out); } // if sygus repair const is enabled, we enumerate terms with free // variables as arguments to any-constant constructors - d_evg.reset( - new SygusEnumerator(d_tds, - d_secd.get(), - &d_stats, - false, - options().quantifiers.sygusRepairConst)); + d_evg = std::make_unique( + d_env, + d_tds, + d_secd.get(), + &d_stats, + false, + options().quantifiers.sygusRepairConst); } } Trace("sygus-active-gen") diff --git a/src/theory/quantifiers/sygus/rcons_type_info.cpp b/src/theory/quantifiers/sygus/rcons_type_info.cpp index 72a8e6a56..20232552a 100644 --- a/src/theory/quantifiers/sygus/rcons_type_info.cpp +++ b/src/theory/quantifiers/sygus/rcons_type_info.cpp @@ -34,7 +34,7 @@ void RConsTypeInfo::initialize(Env& env, NodeManager* nm = NodeManager::currentNM(); SkolemManager* sm = nm->getSkolemManager(); - d_enumerator.reset(new SygusEnumerator(tds, nullptr, &s, true)); + d_enumerator = std::make_unique(env, tds, nullptr, &s, true); d_enumerator->initialize(sm->mkDummySkolem("sygus_rcons", stn)); d_crd.reset(new CandidateRewriteDatabase(env, true, false, true, false)); // since initial samples are not always useful for equivalence checks, set diff --git a/src/theory/quantifiers/sygus/sygus_enumerator.cpp b/src/theory/quantifiers/sygus/sygus_enumerator.cpp index 711d390f8..b6ee1ca89 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator.cpp +++ b/src/theory/quantifiers/sygus/sygus_enumerator.cpp @@ -33,12 +33,14 @@ namespace cvc5 { namespace theory { namespace quantifiers { -SygusEnumerator::SygusEnumerator(TermDbSygus* tds, +SygusEnumerator::SygusEnumerator(Env& env, + TermDbSygus* tds, SygusEnumeratorCallback* sec, SygusStatistics* s, bool enumShapes, bool enumAnyConstHoles) - : d_tds(tds), + : EnumValGenerator(env), + d_tds(tds), d_sec(sec), d_stats(s), d_enumShapes(enumShapes), @@ -55,7 +57,8 @@ void SygusEnumerator::initialize(Node e) // allocate the default callback if (d_sec == nullptr && options::sygusSymBreakDynamic()) { - d_secd.reset(new SygusEnumeratorCallbackDefault(e, d_stats)); + d_secd = + std::make_unique(d_env, e, d_stats); d_sec = d_secd.get(); } d_etype = d_enum.getType(); @@ -88,7 +91,7 @@ void SygusEnumerator::initialize(Node e) { // substitute its active guard by true and rewrite Node slem = lem.substitute(agt, truent); - slem = Rewriter::rewrite(slem); + slem = rewrite(slem); // break into conjuncts std::vector sblc; if (slem.getKind() == AND) diff --git a/src/theory/quantifiers/sygus/sygus_enumerator.h b/src/theory/quantifiers/sygus/sygus_enumerator.h index 612a753af..594cde97f 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator.h +++ b/src/theory/quantifiers/sygus/sygus_enumerator.h @@ -59,6 +59,7 @@ class SygusEnumerator : public EnumValGenerator { public: /** + * @param env Reference to the environment * @param tds Pointer to the term database, required if enumShapes or * enumAnyConstHoles is true, or if we want to include symmetry breaking from * lemmas stored in the sygus term database, @@ -70,7 +71,8 @@ class SygusEnumerator : public EnumValGenerator * @param enumAnyConstHoles If true, this enumerator will generate terms where * free variables are the arguments to any-constant constructors. */ - SygusEnumerator(TermDbSygus* tds = nullptr, + SygusEnumerator(Env& env, + TermDbSygus* tds = nullptr, SygusEnumeratorCallback* sec = nullptr, SygusStatistics* s = nullptr, bool enumShapes = false, diff --git a/src/theory/quantifiers/sygus/sygus_enumerator_basic.cpp b/src/theory/quantifiers/sygus/sygus_enumerator_basic.cpp index 743f67cec..b9578a66a 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator_basic.cpp +++ b/src/theory/quantifiers/sygus/sygus_enumerator_basic.cpp @@ -24,8 +24,10 @@ namespace cvc5 { namespace theory { namespace quantifiers { -EnumValGeneratorBasic::EnumValGeneratorBasic(TermDbSygus* tds, TypeNode tn) - : d_tds(tds), d_te(tn) +EnumValGeneratorBasic::EnumValGeneratorBasic(Env& env, + TermDbSygus* tds, + TypeNode tn) + : EnumValGenerator(env), d_tds(tds), d_te(tn) { } @@ -38,10 +40,10 @@ bool EnumValGeneratorBasic::increment() return false; } d_currTerm = *d_te; - if (options::sygusSymBreakDynamic()) + if (options().datatypes.sygusSymBreakDynamic) { Node nextb = d_tds->sygusToBuiltin(d_currTerm); - nextb = Rewriter::callExtendedRewrite(nextb); + nextb = extendedRewrite(nextb); if (d_cache.find(nextb) == d_cache.end()) { d_cache.insert(nextb); diff --git a/src/theory/quantifiers/sygus/sygus_enumerator_basic.h b/src/theory/quantifiers/sygus/sygus_enumerator_basic.h index 42bce471d..543598a90 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator_basic.h +++ b/src/theory/quantifiers/sygus/sygus_enumerator_basic.h @@ -39,7 +39,7 @@ namespace quantifiers { class EnumValGeneratorBasic : public EnumValGenerator { public: - EnumValGeneratorBasic(TermDbSygus* tds, TypeNode tn); + EnumValGeneratorBasic(Env& env, TermDbSygus* tds, TypeNode tn); ~EnumValGeneratorBasic() {} /** initialize (do nothing) */ void initialize(Node e) override {} diff --git a/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp b/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp index 1170eee82..bde1fdd67 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp +++ b/src/theory/quantifiers/sygus/sygus_enumerator_callback.cpp @@ -25,8 +25,10 @@ namespace cvc5 { namespace theory { namespace quantifiers { -SygusEnumeratorCallback::SygusEnumeratorCallback(Node e, SygusStatistics* s) - : d_enum(e), d_stats(s) +SygusEnumeratorCallback::SygusEnumeratorCallback(Env& env, + Node e, + SygusStatistics* s) + : EnvObj(env), d_enum(e), d_stats(s) { d_tn = e.getType(); } @@ -34,7 +36,7 @@ SygusEnumeratorCallback::SygusEnumeratorCallback(Node e, SygusStatistics* s) bool SygusEnumeratorCallback::addTerm(Node n, std::unordered_set& bterms) { Node bn = datatypes::utils::sygusToBuiltin(n); - Node bnr = Rewriter::callExtendedRewrite(bn); + Node bnr = extendedRewrite(bn); if (d_stats != nullptr) { ++(d_stats->d_enumTermsRewrite); @@ -62,12 +64,16 @@ bool SygusEnumeratorCallback::addTerm(Node n, std::unordered_set& bterms) } SygusEnumeratorCallbackDefault::SygusEnumeratorCallbackDefault( + Env& env, Node e, SygusStatistics* s, ExampleEvalCache* eec, SygusSampler* ssrv, std::ostream* out) - : SygusEnumeratorCallback(e, s), d_eec(eec), d_samplerRrV(ssrv), d_out(out) + : SygusEnumeratorCallback(env, e, s), + d_eec(eec), + d_samplerRrV(ssrv), + d_out(out) { } void SygusEnumeratorCallbackDefault::notifyTermInternal(Node n, diff --git a/src/theory/quantifiers/sygus/sygus_enumerator_callback.h b/src/theory/quantifiers/sygus/sygus_enumerator_callback.h index 8689d876f..9b7c3fd98 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator_callback.h +++ b/src/theory/quantifiers/sygus/sygus_enumerator_callback.h @@ -21,6 +21,7 @@ #include #include "expr/node.h" +#include "smt/env_obj.h" #include "theory/quantifiers/extended_rewrite.h" namespace cvc5 { @@ -36,10 +37,10 @@ class SygusSampler; * provide custom criteria for whether or not enumerated values should be * considered. */ -class SygusEnumeratorCallback +class SygusEnumeratorCallback : protected EnvObj { public: - SygusEnumeratorCallback(Node e, SygusStatistics* s = nullptr); + SygusEnumeratorCallback(Env& env, Node e, SygusStatistics* s = nullptr); virtual ~SygusEnumeratorCallback() {} /** * Add term, return true if the term should be considered in the enumeration. @@ -81,7 +82,8 @@ class SygusEnumeratorCallback class SygusEnumeratorCallbackDefault : public SygusEnumeratorCallback { public: - SygusEnumeratorCallbackDefault(Node e, + SygusEnumeratorCallbackDefault(Env& env, + Node e, SygusStatistics* s = nullptr, ExampleEvalCache* eec = nullptr, SygusSampler* ssrv = nullptr, diff --git a/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp b/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp index 388a4d31f..d24ad25b2 100644 --- a/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp +++ b/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp @@ -177,7 +177,7 @@ void SygusEvalUnfold::registerModelValue(Node a, } else { - EvalSygusInvarianceTest esit; + EvalSygusInvarianceTest esit(d_env.getRewriter()); eval_children.insert( eval_children.end(), it->second[i].begin(), it->second[i].end()); Node conj = nm->mkNode(DT_SYGUS_EVAL, eval_children); diff --git a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp index 7227b7184..a783ce2ca 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp @@ -153,7 +153,7 @@ Node CegGrammarConstructor::process(Node q, sfvl = preGrammarType.getDType().getSygusVarList(); tn = preGrammarType; // normalize type, if user-provided - SygusGrammarNorm sygus_norm(d_tds); + SygusGrammarNorm sygus_norm(d_env, d_tds); tn = sygus_norm.normalizeSygusType(tn, sfvl); }else{ sfvl = SygusUtils::getSygusArgumentListForSynthFun(sf); @@ -1232,7 +1232,6 @@ void CegGrammarConstructor::mkSygusDefaultGrammar( // Do beta reduction on the operator so that its arguments match the // fresh variables of the lambda (op) we are constructing below. sop = datatypes::utils::mkSygusTerm(sop, opLArgs); - sop = Rewriter::rewrite(sop); } opCArgs.push_back(unresAnyConst); Node coeff = nm->mkBoundVar(types[i]); diff --git a/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp b/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp index 209d10297..cf7b71104 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp @@ -69,7 +69,10 @@ bool OpPosTrie::getOrMakeType(TypeNode tn, return d_children[op_pos[ind]].getOrMakeType(tn, unres_tn, op_pos, ind + 1); } -SygusGrammarNorm::SygusGrammarNorm(TermDbSygus* tds) : d_tds(tds) {} +SygusGrammarNorm::SygusGrammarNorm(Env& env, TermDbSygus* tds) + : EnvObj(env), d_tds(tds) +{ +} SygusGrammarNorm::TypeObject::TypeObject(TypeNode src_tn, TypeNode unres_tn) : d_tn(src_tn), @@ -282,9 +285,10 @@ std::unique_ptr SygusGrammarNorm::inferTransf( Trace("sygus-gnorm") << " #cons = " << op_pos.size() << " / " << dt.getNumConstructors() << std::endl; // look for redundant constructors to drop - if (options::sygusMinGrammar() && dt.getNumConstructors() == op_pos.size()) + if (options().quantifiers.sygusMinGrammar + && dt.getNumConstructors() == op_pos.size()) { - SygusRedundantCons src; + SygusRedundantCons src(d_env); src.initialize(d_tds, tn); std::vector rindices; src.getRedundant(rindices); diff --git a/src/theory/quantifiers/sygus/sygus_grammar_norm.h b/src/theory/quantifiers/sygus/sygus_grammar_norm.h index f1d8e01e0..cdaf97487 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_norm.h +++ b/src/theory/quantifiers/sygus/sygus_grammar_norm.h @@ -24,6 +24,7 @@ #include "expr/node.h" #include "expr/sygus_datatype.h" #include "expr/type_node.h" +#include "smt/env_obj.h" namespace cvc5 { namespace theory { @@ -123,10 +124,10 @@ class OpPosTrie * These lighweight transformations are always applied, independently of the * normalization option being enabled. */ -class SygusGrammarNorm +class SygusGrammarNorm : protected EnvObj { public: - SygusGrammarNorm(TermDbSygus* tds); + SygusGrammarNorm(Env& env, TermDbSygus* tds); ~SygusGrammarNorm() {} /** creates a normalized typenode from a given one. * diff --git a/src/theory/quantifiers/sygus/sygus_grammar_red.cpp b/src/theory/quantifiers/sygus/sygus_grammar_red.cpp index fd84f0c0a..a8ff038de 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_red.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_red.cpp @@ -148,7 +148,7 @@ void SygusRedundantCons::getGenericList(TermDbSygus* tds, if (index == dt[c].getNumArgs()) { Node gt = tds->mkGeneric(dt, c, pre); - gt = Rewriter::callExtendedRewrite(gt); + gt = extendedRewrite(gt); terms.push_back(gt); return; } diff --git a/src/theory/quantifiers/sygus/sygus_grammar_red.h b/src/theory/quantifiers/sygus/sygus_grammar_red.h index 2146e1f73..018bc8fe6 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_red.h +++ b/src/theory/quantifiers/sygus/sygus_grammar_red.h @@ -22,6 +22,7 @@ #include #include "expr/node.h" +#include "smt/env_obj.h" namespace cvc5 { namespace theory { @@ -36,10 +37,10 @@ class TermDbSygus; * where tn is a sygus tn. Then, use getRedundant and/or isRedundant to get the * indicies of the constructors of tn that are redundant. */ -class SygusRedundantCons +class SygusRedundantCons : protected EnvObj { public: - SygusRedundantCons() {} + SygusRedundantCons(Env& env) : EnvObj(env) {} ~SygusRedundantCons() {} /** register type tn * diff --git a/src/theory/quantifiers/sygus/sygus_invariance.cpp b/src/theory/quantifiers/sygus/sygus_invariance.cpp index 8048330e4..b35b23c90 100644 --- a/src/theory/quantifiers/sygus/sygus_invariance.cpp +++ b/src/theory/quantifiers/sygus/sygus_invariance.cpp @@ -106,7 +106,7 @@ bool EquivSygusInvarianceTest::invariant(TermDbSygus* tds, Node nvn, Node x) { TypeNode tn = nvn.getType(); Node nbv = tds->sygusToBuiltin(nvn, tn); - Node nbvr = Rewriter::callExtendedRewrite(nbv); + Node nbvr = d_rewriter->extendedRewrite(nbv); Trace("sygus-sb-mexp-debug") << " min-exp check : " << nbv << " -> " << nbvr << std::endl; bool exc_arg = false; @@ -176,7 +176,7 @@ bool DivByZeroSygusInvarianceTest::invariant(TermDbSygus* tds, Node nvn, Node x) { TypeNode tn = nvn.getType(); Node nbv = tds->sygusToBuiltin(nvn, tn); - Node nbvr = Rewriter::callExtendedRewrite(nbv); + Node nbvr = d_rewriter->extendedRewrite(nbv); if (tds->involvesDivByZero(nbvr)) { Trace("sygus-sb-mexp") << "sb-min-exp : " << tds->sygusToBuiltin(nvn) @@ -207,7 +207,7 @@ bool NegContainsSygusInvarianceTest::invariant(TermDbSygus* tds, { TypeNode tn = nvn.getType(); Node nbv = tds->sygusToBuiltin(nvn, tn); - Node nbvr = Rewriter::callExtendedRewrite(nbv); + Node nbvr = d_rewriter->extendedRewrite(nbv); // if for any of the examples, it is not contained, then we can exclude for (unsigned i = 0; i < d_neg_con_indices.size(); i++) { @@ -218,7 +218,7 @@ bool NegContainsSygusInvarianceTest::invariant(TermDbSygus* tds, Node cont = NodeManager::currentNM()->mkNode(kind::STRING_CONTAINS, out, nbvre); Trace("sygus-pbe-cterm-debug") << "Check: " << cont << std::endl; - Node contr = Rewriter::rewrite(cont); + Node contr = d_rewriter->extendedRewrite(cont); if (!contr.isConst()) { if (d_isUniversal) diff --git a/src/theory/quantifiers/sygus/sygus_invariance.h b/src/theory/quantifiers/sygus/sygus_invariance.h index afb59bf73..16abc541d 100644 --- a/src/theory/quantifiers/sygus/sygus_invariance.h +++ b/src/theory/quantifiers/sygus/sygus_invariance.h @@ -25,6 +25,9 @@ namespace cvc5 { namespace theory { + +class Rewriter; + namespace quantifiers { class TermDbSygus; @@ -44,6 +47,7 @@ class SynthConjecture; class SygusInvarianceTest { public: + SygusInvarianceTest(Rewriter* r) : d_rewriter(r) {} virtual ~SygusInvarianceTest() {} /** Is nvn invariant with respect to this test ? @@ -69,6 +73,8 @@ class SygusInvarianceTest /** set updated term */ void setUpdatedTerm(Node n) { d_update_nvn = n; } protected: + /** Pointer to the rewriter */ + Rewriter* d_rewriter; /** result of the node that satisfies this invariant */ Node d_update_nvn; /** check whether nvn[ x ] is invariant */ @@ -98,8 +104,10 @@ class SygusInvarianceTest class EvalSygusInvarianceTest : public SygusInvarianceTest { public: - EvalSygusInvarianceTest() - : d_kind(kind::UNDEFINED_KIND), d_is_conjunctive(false) + EvalSygusInvarianceTest(Rewriter* r) + : SygusInvarianceTest(r), + d_kind(kind::UNDEFINED_KIND), + d_is_conjunctive(false) { } @@ -168,7 +176,10 @@ class EvalSygusInvarianceTest : public SygusInvarianceTest class EquivSygusInvarianceTest : public SygusInvarianceTest { public: - EquivSygusInvarianceTest() : d_conj(nullptr) {} + EquivSygusInvarianceTest(Rewriter* r) + : SygusInvarianceTest(r), d_conj(nullptr) + { + } /** initialize this invariance test * tn is the sygus type for e @@ -209,7 +220,7 @@ class EquivSygusInvarianceTest : public SygusInvarianceTest class DivByZeroSygusInvarianceTest : public SygusInvarianceTest { public: - DivByZeroSygusInvarianceTest() {} + DivByZeroSygusInvarianceTest(Rewriter* r) : SygusInvarianceTest(r) {} protected: /** checks whether nvn involves division by zero. */ @@ -245,7 +256,10 @@ class DivByZeroSygusInvarianceTest : public SygusInvarianceTest class NegContainsSygusInvarianceTest : public SygusInvarianceTest { public: - NegContainsSygusInvarianceTest() : d_isUniversal(false) {} + NegContainsSygusInvarianceTest(Rewriter* r) + : SygusInvarianceTest(r), d_isUniversal(false) + { + } /** initialize this invariance test * e is the enumerator which we are reasoning about (associated with a synth diff --git a/src/theory/quantifiers/sygus/sygus_process_conj.cpp b/src/theory/quantifiers/sygus/sygus_process_conj.cpp index a1f197596..18845665c 100644 --- a/src/theory/quantifiers/sygus/sygus_process_conj.cpp +++ b/src/theory/quantifiers/sygus/sygus_process_conj.cpp @@ -69,8 +69,8 @@ bool SynthConjectureProcessFun::checkMatch( } Node cn_subs = cn.substitute(vars.begin(), vars.end(), subs.begin(), subs.end()); - cn_subs = Rewriter::rewrite(cn_subs); - n = Rewriter::rewrite(n); + cn_subs = rewrite(cn_subs); + n = rewrite(n); return cn_subs == n; } diff --git a/src/theory/quantifiers/sygus/sygus_qe_preproc.cpp b/src/theory/quantifiers/sygus/sygus_qe_preproc.cpp index 2400e0e56..b57c65f0f 100644 --- a/src/theory/quantifiers/sygus/sygus_qe_preproc.cpp +++ b/src/theory/quantifiers/sygus/sygus_qe_preproc.cpp @@ -39,7 +39,7 @@ Node SygusQePreproc::preprocess(Node q) SkolemManager* sm = nm->getSkolemManager(); Trace("cegqi-qep") << "Compute single invocation for " << q << "..." << std::endl; - quantifiers::SingleInvocationPartition sip; + quantifiers::SingleInvocationPartition sip(d_env); std::vector funcs0; funcs0.insert(funcs0.end(), q[0].begin(), q[0].end()); sip.init(funcs0, body); diff --git a/src/theory/quantifiers/sygus/sygus_random_enumerator.cpp b/src/theory/quantifiers/sygus/sygus_random_enumerator.cpp index bf051a897..0711b44ae 100644 --- a/src/theory/quantifiers/sygus/sygus_random_enumerator.cpp +++ b/src/theory/quantifiers/sygus/sygus_random_enumerator.cpp @@ -62,7 +62,7 @@ bool SygusRandomEnumerator::increment() // Generate the next sygus term. n = incrementH(); bn = d_tds->sygusToBuiltin(n); - bn = Rewriter::callExtendedRewrite(bn); + bn = extendedRewrite(bn); // Ensure that the builtin counterpart is unique (up to rewriting). } while (d_cache.find(bn) != d_cache.cend()); d_cache.insert(bn); @@ -174,7 +174,7 @@ Node SygusRandomEnumerator::getMin(Node n) { TypeNode tn = n.getType(); Node bn = d_tds->sygusToBuiltin(n); - bn = Rewriter::callExtendedRewrite(bn); + bn = extendedRewrite(bn); // Did we calculate the size of `n` before? if (d_size.find(n) == d_size.cend()) { diff --git a/src/theory/quantifiers/sygus/sygus_random_enumerator.h b/src/theory/quantifiers/sygus/sygus_random_enumerator.h index b70fe9490..79fc0a090 100644 --- a/src/theory/quantifiers/sygus/sygus_random_enumerator.h +++ b/src/theory/quantifiers/sygus/sygus_random_enumerator.h @@ -50,7 +50,10 @@ class SygusRandomEnumerator : public EnumValGenerator * * @param tds pointer to term database sygus. */ - SygusRandomEnumerator(TermDbSygus* tds) : d_tds(tds){}; + SygusRandomEnumerator(Env& env, TermDbSygus* tds) + : EnumValGenerator(env), d_tds(tds) + { + } /** Initialize this class with enumerator `e`. */ void initialize(Node e) override; diff --git a/src/theory/quantifiers/sygus/sygus_unif_io.cpp b/src/theory/quantifiers/sygus/sygus_unif_io.cpp index 7aa952600..2144b324c 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_io.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif_io.cpp @@ -974,7 +974,7 @@ bool SygusUnifIo::getExplanationForEnumeratorExclude( if (!cmp_indices.empty()) { // we check invariance with respect to a negative contains test - NegContainsSygusInvarianceTest ncset; + NegContainsSygusInvarianceTest ncset(d_env.getRewriter()); if (isConditional) { ncset.setUniversal(); diff --git a/src/theory/quantifiers/term_util.cpp b/src/theory/quantifiers/term_util.cpp index 18f9ec7e2..4f386c008 100644 --- a/src/theory/quantifiers/term_util.cpp +++ b/src/theory/quantifiers/term_util.cpp @@ -123,22 +123,6 @@ Node TermUtil::getRemoveQuantifiers( Node n ) { return getRemoveQuantifiers2( n, visited ); } -//quantified simplify -Node TermUtil::getQuantSimplify( Node n ) { - std::unordered_set fvs; - expr::getFreeVariables(n, fvs); - if (fvs.empty()) - { - return Rewriter::rewrite( n ); - } - std::vector bvs; - bvs.insert(bvs.end(), fvs.begin(), fvs.end()); - NodeManager* nm = NodeManager::currentNM(); - Node q = nm->mkNode(FORALL, nm->mkNode(BOUND_VAR_LIST, bvs), n); - q = Rewriter::rewrite(q); - return getRemoveQuantifiers(q); -} - void TermUtil::computeInstConstContains(Node n, std::vector& ics) { computeVarContainsInternal(n, INST_CONSTANT, ics); @@ -377,22 +361,22 @@ Node TermUtil::mkTypeValueOffset(TypeNode tn, int32_t offset, int32_t& status) { + Assert(val.isConst() && val.getType() == tn); Node val_o; - Node offset_val = mkTypeValue(tn, offset); status = -1; - if (!offset_val.isNull()) + if (tn.isRealOrInt()) { - if (tn.isRealOrInt()) - { - val_o = Rewriter::rewrite( - NodeManager::currentNM()->mkNode(PLUS, val, offset_val)); - status = 0; - } - else if (tn.isBitVector()) - { - val_o = Rewriter::rewrite( - NodeManager::currentNM()->mkNode(BITVECTOR_ADD, val, offset_val)); - } + Rational vval = val.getConst(); + Rational oval(offset); + status = 0; + return NodeManager::currentNM()->mkConstRealOrInt(tn, vval + oval); + } + else if (tn.isBitVector()) + { + BitVector vval = val.getConst(); + uint32_t uv = static_cast(offset); + BitVector oval(tn.getConst(), uv); + return NodeManager::currentNM()->mkConst(vval + oval); } return val_o; } diff --git a/src/theory/quantifiers/term_util.h b/src/theory/quantifiers/term_util.h index fb664dab5..277ce03fc 100644 --- a/src/theory/quantifiers/term_util.h +++ b/src/theory/quantifiers/term_util.h @@ -78,8 +78,6 @@ private: public: //remove quantifiers static Node getRemoveQuantifiers( Node n ); - //quantified simplify (treat free variables in n as quantified and run rewriter) - static Node getQuantSimplify( Node n ); private: /** adds the set of nodes of kind k in n to vars */