From 19b18be61c365f8506785a41244b74d008fa5976 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 25 May 2022 15:16:40 -0500 Subject: [PATCH] Eliminate static access to dtSharedSelectors (#8804) Towards eliminating option scopes. --- src/expr/dtype_cons.cpp | 24 ++++----- src/expr/dtype_cons.h | 15 +++--- src/theory/datatypes/datatypes_rewriter.cpp | 19 +++---- src/theory/datatypes/datatypes_rewriter.h | 13 +++-- src/theory/datatypes/proof_checker.cpp | 2 +- src/theory/datatypes/proof_checker.h | 4 +- src/theory/datatypes/sygus_extension.cpp | 29 +++++++---- src/theory/datatypes/sygus_extension.h | 4 ++ src/theory/datatypes/theory_datatypes.cpp | 20 +++++--- .../datatypes/theory_datatypes_utils.cpp | 46 ++++++++--------- src/theory/datatypes/theory_datatypes_utils.h | 26 +++++++--- src/theory/datatypes/tuple_utils.cpp | 5 +- .../quantifiers/cegqi/ceg_dt_instantiator.cpp | 19 +++---- .../ematching/candidate_generator.cpp | 51 ++++++++++--------- .../ematching/candidate_generator.h | 29 ++++++++--- .../ematching/inst_match_generator.cpp | 15 +++--- .../quantifiers/fmf/bounded_integers.cpp | 6 +-- src/theory/quantifiers/skolemize.cpp | 20 +++----- .../quantifiers/sygus/sygus_eval_unfold.cpp | 13 +++-- .../quantifiers/sygus/sygus_explain.cpp | 18 ++++--- src/theory/quantifiers/sygus/sygus_explain.h | 5 +- .../quantifiers/sygus/term_database_sygus.cpp | 10 ++-- src/theory/quantifiers/sygus_inst.cpp | 2 +- 23 files changed, 225 insertions(+), 170 deletions(-) diff --git a/src/expr/dtype_cons.cpp b/src/expr/dtype_cons.cpp index cda9f8a63..f4bab0bde 100644 --- a/src/expr/dtype_cons.cpp +++ b/src/expr/dtype_cons.cpp @@ -240,21 +240,21 @@ TypeNode DTypeConstructor::getArgType(size_t index) const return (*this)[index].getType().getDatatypeSelectorRangeType(); } -Node DTypeConstructor::getSelectorInternal(TypeNode domainType, - size_t index) const +Node DTypeConstructor::getSelector(size_t index) const { Assert(isResolved()); Assert(index < getNumArgs()); - if (options::dtSharedSelectors()) - { - computeSharedSelectors(domainType); - Assert(d_sharedSelectors[domainType].size() == getNumArgs()); - return d_sharedSelectors[domainType][index]; - } - else - { - return d_args[index]->getSelector(); - } + return d_args[index]->getSelector(); +} + +Node DTypeConstructor::getSharedSelector(TypeNode domainType, + size_t index) const +{ + Assert(isResolved()); + Assert(index < getNumArgs()); + computeSharedSelectors(domainType); + Assert(d_sharedSelectors[domainType].size() == getNumArgs()); + return d_sharedSelectors[domainType][index]; } int DTypeConstructor::getSelectorIndexInternal(Node sel) const diff --git a/src/expr/dtype_cons.h b/src/expr/dtype_cons.h index 7c63a1ed8..b40f21316 100644 --- a/src/expr/dtype_cons.h +++ b/src/expr/dtype_cons.h @@ -188,13 +188,12 @@ class DTypeConstructor /** get selector internal * - * This gets the selector for the index^th argument - * of this constructor. The type dtt is the datatype - * type whose datatype is the owner of this constructor, - * where this type may be an instantiated parametric datatype. - * - * If shared selectors are enabled, - * this returns a shared (constructor-agnotic) selector, which + * This gets the (unshared) selector for the index^th argument + * of this constructor. + */ + Node getSelector(size_t index) const; + /** + * This returns a shared (constructor-agnotic) selector, which * in the terminology of "DTypes with Shared Selectors", is: * sel_{dtt}^{T,atos(T,C,index)} * where C is this constructor, and T is the type @@ -203,7 +202,7 @@ class DTypeConstructor * type T of constructor term t if one exists, or is * unconstrained otherwise. */ - Node getSelectorInternal(TypeNode dtt, size_t index) const; + Node getSharedSelector(TypeNode dtt, size_t index) const; /** get selector index internal * diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp index 3837e1d21..928b6d1e4 100644 --- a/src/theory/datatypes/datatypes_rewriter.cpp +++ b/src/theory/datatypes/datatypes_rewriter.cpp @@ -37,8 +37,8 @@ namespace cvc5::internal { namespace theory { namespace datatypes { -DatatypesRewriter::DatatypesRewriter(Evaluator* sygusEval) - : d_sygusEval(sygusEval) +DatatypesRewriter::DatatypesRewriter(Evaluator* sygusEval, const Options& opts) + : d_sygusEval(sygusEval), d_opts(opts) { } @@ -780,12 +780,11 @@ Node DatatypesRewriter::replaceDebruijn(Node n, return n; } -Node DatatypesRewriter::expandApplySelector(Node n) +Node DatatypesRewriter::expandApplySelector(Node n, bool sharedSel) { Assert(n.getKind() == APPLY_SELECTOR); Node selector = n.getOperator(); - if (!options::dtSharedSelectors() - || !selector.hasAttribute(DTypeConsIndexAttr())) + if (!sharedSel || !selector.hasAttribute(DTypeConsIndexAttr())) { return n; } @@ -798,10 +797,7 @@ Node DatatypesRewriter::expandApplySelector(Node n) size_t selectorIndex = utils::indexOf(selector); Trace("dt-expand") << "...selector index = " << selectorIndex << std::endl; Assert(selectorIndex < c.getNumArgs()); - Node selector_use = c.getSelectorInternal(ndt, selectorIndex); - NodeManager* nm = NodeManager::currentNM(); - Node sel = nm->mkNode(kind::APPLY_SELECTOR, selector_use, n[0]); - return sel; + return utils::applySelector(c, selectorIndex, true, n[0]); } TrustNode DatatypesRewriter::expandDefinition(Node n) @@ -813,7 +809,7 @@ TrustNode DatatypesRewriter::expandDefinition(Node n) { case kind::APPLY_SELECTOR: { - ret = expandApplySelector(n); + ret = expandApplySelector(n, d_opts.datatypes.dtSharedSelectors); } break; case APPLY_UPDATER: @@ -837,6 +833,7 @@ TrustNode DatatypesRewriter::expandDefinition(Node n) Trace("dt-expand") << "expr is " << n << std::endl; Trace("dt-expand") << "updateIndex is " << updateIndex << std::endl; Trace("dt-expand") << "t is " << tn << std::endl; + bool shareSel = d_opts.datatypes.dtSharedSelectors; for (size_t i = 0, size = dc.getNumArgs(); i < size; ++i) { if (i == updateIndex) @@ -845,7 +842,7 @@ TrustNode DatatypesRewriter::expandDefinition(Node n) } else { - b << nm->mkNode(APPLY_SELECTOR, dc.getSelectorInternal(tn, i), n[0]); + b << utils::applySelector(dc, i, shareSel, n[0]); } } ret = b; diff --git a/src/theory/datatypes/datatypes_rewriter.h b/src/theory/datatypes/datatypes_rewriter.h index 83c6f8049..4c2137d79 100644 --- a/src/theory/datatypes/datatypes_rewriter.h +++ b/src/theory/datatypes/datatypes_rewriter.h @@ -22,6 +22,9 @@ #include "theory/theory_rewriter.h" namespace cvc5::internal { + +class Options; + namespace theory { namespace datatypes { @@ -38,7 +41,7 @@ namespace datatypes { class DatatypesRewriter : public TheoryRewriter { public: - DatatypesRewriter(Evaluator* sygusEval); + DatatypesRewriter(Evaluator* sygusEval, const Options& opts); RewriteResponse postRewrite(TNode in) override; RewriteResponse preRewrite(TNode in) override; @@ -65,13 +68,13 @@ class DatatypesRewriter : public TheoryRewriter * (APPLY_SELECTOR selC x) * its expanded form is * (APPLY_SELECTOR selC' x) - * where f is a skolem function with id SELECTOR_WRONG, and selC' is the - * internal selector function for selC (possibly a shared selector). + * where selC' is the internal selector function for selC (a shared selector + * if sharedSel is true). * Note that we do not introduce an uninterpreted function here, e.g. to * handle when the selector is misapplied. This is because it suffices to * reason about the original selector term e.g. via congruence. */ - static Node expandApplySelector(Node n); + static Node expandApplySelector(Node n, bool sharedSel); /** * Expand a match term into its definition. * For example @@ -200,6 +203,8 @@ class DatatypesRewriter : public TheoryRewriter Node sygusToBuiltinEval(Node n, const std::vector& args); /** Pointer to the evaluator, used as an optimization for the above method */ Evaluator* d_sygusEval; + /** Reference to the options */ + const Options& d_opts; }; } // namespace datatypes diff --git a/src/theory/datatypes/proof_checker.cpp b/src/theory/datatypes/proof_checker.cpp index 25a0bee2c..9ca52043a 100644 --- a/src/theory/datatypes/proof_checker.cpp +++ b/src/theory/datatypes/proof_checker.cpp @@ -74,7 +74,7 @@ Node DatatypesProofRuleChecker::checkInternal(PfRule id, return Node::null(); } Node tester = utils::mkTester(t, i, dt); - Node ticons = utils::getInstCons(t, dt, i); + Node ticons = utils::getInstCons(t, dt, i, d_sharedSel); return tester.eqNode(t.eqNode(ticons)); } else if (id == PfRule::DT_COLLAPSE) diff --git a/src/theory/datatypes/proof_checker.h b/src/theory/datatypes/proof_checker.h index 51e63d5a5..e9eeedcbe 100644 --- a/src/theory/datatypes/proof_checker.h +++ b/src/theory/datatypes/proof_checker.h @@ -30,7 +30,7 @@ namespace datatypes { class DatatypesProofRuleChecker : public ProofRuleChecker { public: - DatatypesProofRuleChecker() {} + DatatypesProofRuleChecker(bool sharedSel) : d_sharedSel(sharedSel) {} ~DatatypesProofRuleChecker() {} /** Register all rules owned by this rule checker into pc. */ @@ -41,6 +41,8 @@ class DatatypesProofRuleChecker : public ProofRuleChecker Node checkInternal(PfRule id, const std::vector& children, const std::vector& args) override; + /** Whether we are using shared selectors */ + bool d_sharedSel; }; } // namespace datatypes diff --git a/src/theory/datatypes/sygus_extension.cpp b/src/theory/datatypes/sygus_extension.cpp index de7e08d37..13e30765a 100644 --- a/src/theory/datatypes/sygus_extension.cpp +++ b/src/theory/datatypes/sygus_extension.cpp @@ -375,8 +375,8 @@ void SygusExtension::assertTesterInternal(int tindex, TNode n, Node exp) { Trace("sygus-sb-debug") << "Do lazy symmetry breaking...\n"; for( unsigned j=0; jmkNode(APPLY_SELECTOR, dt[tindex].getSelectorInternal(ntn, j), n); + Node sel = nm->mkNode( + APPLY_SELECTOR, getSelectorInternal(ntn, dt[tindex], j), n); Trace("sygus-sb-debug2") << " activate child sel : " << sel << std::endl; Assert(d_active_terms.find(sel) == d_active_terms.end()); IntMap::const_iterator itt = d_testers.find( sel ); @@ -602,7 +602,7 @@ Node SygusExtension::getSimpleSymBreakPred(Node e, for (unsigned j = 0; j < dt_index_nargs; j++) { Node sel = - nm->mkNode(APPLY_SELECTOR, dt[tindex].getSelectorInternal(tn, j), n); + nm->mkNode(APPLY_SELECTOR, getSelectorInternal(tn, dt[tindex], j), n); Assert(sel.getType().isDatatype()); children.push_back(sel); } @@ -615,7 +615,10 @@ Node SygusExtension::getSimpleSymBreakPred(Node e, && !isAnyConstant) { Node szl = nm->mkNode(DT_SIZE, n); - Node szr = nm->mkNode(DT_SIZE, utils::getInstCons(n, dt, tindex)); + Node szr = + nm->mkNode(DT_SIZE, + utils::getInstCons( + n, dt, tindex, options().datatypes.dtSharedSelectors)); szr = rewrite(szr); sbp_conj.push_back(szl.eqNode(szr)); } @@ -927,8 +930,9 @@ Node SygusExtension::getSimpleSymBreakPred(Node e, && children[0].getType() == tn && children[1].getType() == tn) { // chainable - Node child11 = nm->mkNode( - APPLY_SELECTOR, dt[tindex].getSelectorInternal(tn, 1), children[0]); + Node child11 = nm->mkNode(APPLY_SELECTOR, + getSelectorInternal(tn, dt[tindex], 1), + children[0]); Assert(child11.getType() == children[1].getType()); Node order_pred_trans = nm->mkNode(OR, @@ -1015,7 +1019,7 @@ Node SygusExtension::registerSearchValue(Node a, for (unsigned i = 0, nchild = nv.getNumChildren(); i < nchild; i++) { Node sel = - nm->mkNode(APPLY_SELECTOR, dt[cindex].getSelectorInternal(tn, i), n); + nm->mkNode(APPLY_SELECTOR, getSelectorInternal(tn, dt[cindex], i), n); Node nvc = registerSearchValue(a, sel, nv[i], @@ -1734,7 +1738,7 @@ bool SygusExtension::checkValue(Node n, TNode vn, int ind) } for( unsigned i=0; imkNode(APPLY_SELECTOR, dt[cindex].getSelectorInternal(tn, i), n); + nm->mkNode(APPLY_SELECTOR, getSelectorInternal(tn, dt[cindex], i), n); if (!checkValue(sel, vn[i], ind + 1)) { return false; @@ -1756,7 +1760,7 @@ Node SygusExtension::getCurrentTemplate( Node n, std::map< TypeNode, int >& var_ children.push_back(dt[tindex].getConstructor()); for( unsigned i=0; imkNode( - APPLY_SELECTOR, dt[tindex].getSelectorInternal(tn, i), n); + APPLY_SELECTOR, getSelectorInternal(tn, dt[tindex], i), n); Node cc = getCurrentTemplate( sel, var_count ); children.push_back( cc ); } @@ -1843,3 +1847,10 @@ int SygusExtension::getGuardStatus( Node g ) { } } +Node SygusExtension::getSelectorInternal(TypeNode dtt, + const DTypeConstructor& dc, + size_t index) const +{ + return utils::getSelector( + dtt, dc, index, options().datatypes.dtSharedSelectors); +} diff --git a/src/theory/datatypes/sygus_extension.h b/src/theory/datatypes/sygus_extension.h index 9a2c12c17..ed296ae06 100644 --- a/src/theory/datatypes/sygus_extension.h +++ b/src/theory/datatypes/sygus_extension.h @@ -710,6 +710,10 @@ private: * false, and 0 if it is not asserted. */ int getGuardStatus( Node g ); + /** Calls util::getSelector based on the value of options::dtShareSel */ + Node getSelectorInternal(TypeNode dtt, + const DTypeConstructor& dc, + size_t index) const; }; } diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 39239ba9f..6377027f3 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -60,10 +60,11 @@ TheoryDatatypes::TheoryDatatypes(Env& env, d_functionTerms(context()), d_singleton_eq(userContext()), d_sygusExtension(nullptr), - d_rewriter(env.getEvaluator()), + d_rewriter(env.getEvaluator(), env.getOptions()), d_state(env, valuation), d_im(env, *this, d_state), d_notify(d_im, *this), + d_checker(env.getOptions().datatypes.dtSharedSelectors), d_cpacb(*this) { @@ -1005,8 +1006,10 @@ bool TheoryDatatypes::collectModelValues(TheoryModel* m, //unsigned orig_size = nodes.size(); std::map< TypeNode, int > typ_enum_map; std::vector< TypeEnumerator > typ_enum; - unsigned index = 0; - while( indexmkNode(APPLY_SELECTOR, s, n); +} + +Node getInstCons(Node n, const DType& dt, size_t index, bool shareSel) { Assert(index < dt.getNumConstructors()); std::vector children; NodeManager* nm = NodeManager::currentNM(); TypeNode tn = n.getType(); - for (unsigned i = 0, nargs = dt[index].getNumArgs(); i < nargs; i++) + for (size_t i = 0, nargs = dt[index].getNumArgs(); i < nargs; i++) { Node nc = - nm->mkNode(APPLY_SELECTOR, dt[index].getSelectorInternal(tn, i), n); + nm->mkNode(APPLY_SELECTOR, getSelector(tn, dt[index], i, shareSel), n); children.push_back(nc); } Node n_ic = mkApplyCons(tn, dt, index, children); Assert(n_ic.getType() == tn); - Assert(static_cast(isInstCons(n, n_ic, dt)) == index); return n_ic; } @@ -68,26 +82,6 @@ Node mkApplyCons(TypeNode tn, return nm->mkNode(APPLY_CONSTRUCTOR, cchildren); } -int isInstCons(Node t, Node n, const DType& dt) -{ - if (n.getKind() == APPLY_CONSTRUCTOR) - { - int index = indexOf(n.getOperator()); - const DTypeConstructor& c = dt[index]; - TypeNode tn = n.getType(); - for (unsigned i = 0, size = n.getNumChildren(); i < size; i++) - { - if (n[i].getKind() != APPLY_SELECTOR - || n[i].getOperator() != c.getSelectorInternal(tn, i) || n[i][0] != t) - { - return -1; - } - } - return index; - } - return -1; -} - int isTester(Node n, Node& a) { if (n.getKind() == APPLY_TESTER) diff --git a/src/theory/datatypes/theory_datatypes_utils.h b/src/theory/datatypes/theory_datatypes_utils.h index 414232e1a..9dfe45f46 100644 --- a/src/theory/datatypes/theory_datatypes_utils.h +++ b/src/theory/datatypes/theory_datatypes_utils.h @@ -29,12 +29,29 @@ namespace theory { namespace datatypes { namespace utils { +/** + * Get the index^th selector of datatype constructor dc whose type is dtt. If + * shareSel is true, this returns the shared selector of dc. + */ +Node getSelector(TypeNode dtt, + const DTypeConstructor& dc, + size_t index, + bool shareSel); +/** + * Apply the indext^th selector of datatype constructor dc to term n. If + * shareSel is true, we use the shared selector of dc. + */ +Node applySelector(const DTypeConstructor& dc, + size_t index, + bool shareSel, + const Node& n); + /** get instantiate cons * * This returns the term C( sel^{C,1}( n ), ..., sel^{C,m}( n ) ), * where C is the index^{th} constructor of datatype dt. */ -Node getInstCons(Node n, const DType& dt, size_t index); +Node getInstCons(Node n, const DType& dt, size_t index, bool shareSel); /** * Apply constructor, taking into account whether the datatype is parametric. * @@ -45,13 +62,6 @@ Node mkApplyCons(TypeNode tn, const DType& dt, size_t index, const std::vector& children); -/** is instantiation cons - * - * If this method returns a value >=0, then that value, call it index, - * is such that n = C( sel^{C,1}( t ), ..., sel^{C,m}( t ) ), - * where C is the index^{th} constructor of dt. - */ -int isInstCons(Node t, Node n, const DType& dt); /** is tester * * This method returns a value >=0 if n is a tester predicate. The return diff --git a/src/theory/datatypes/tuple_utils.cpp b/src/theory/datatypes/tuple_utils.cpp index 74024f508..ec63b1bf4 100644 --- a/src/theory/datatypes/tuple_utils.cpp +++ b/src/theory/datatypes/tuple_utils.cpp @@ -19,6 +19,7 @@ #include "expr/dtype.h" #include "expr/dtype_cons.h" +#include "theory/datatypes/theory_datatypes_utils.h" using namespace cvc5::internal::kind; @@ -69,8 +70,10 @@ Node TupleUtils::nthElementOfTuple(Node tuple, int n_th) } TypeNode tn = tuple.getType(); const DType& dt = tn.getDType(); + // note that shared selectors are irrelevant for datatypes with one + // constructor, hence we pass false here return NodeManager::currentNM()->mkNode( - APPLY_SELECTOR, dt[0].getSelectorInternal(tn, n_th), tuple); + APPLY_SELECTOR, utils::getSelector(tn, dt[0], n_th, false), tuple); } Node TupleUtils::getTupleProjection(const std::vector& indices, diff --git a/src/theory/quantifiers/cegqi/ceg_dt_instantiator.cpp b/src/theory/quantifiers/cegqi/ceg_dt_instantiator.cpp index 9b5d34526..95070121c 100644 --- a/src/theory/quantifiers/cegqi/ceg_dt_instantiator.cpp +++ b/src/theory/quantifiers/cegqi/ceg_dt_instantiator.cpp @@ -18,6 +18,7 @@ #include "expr/dtype.h" #include "expr/dtype_cons.h" #include "expr/node_algorithm.h" +#include "options/datatypes_options.h" #include "theory/datatypes/theory_datatypes_utils.h" using namespace std; @@ -51,7 +52,6 @@ bool DtInstantiator::processEqualTerms(CegInstantiator* ci, Trace("cegqi-dt-debug") << "try based on constructors in equivalence class." << std::endl; // look in equivalence class for a constructor - NodeManager* nm = NodeManager::currentNM(); for (unsigned k = 0, size = eqc.size(); k < size; k++) { Node n = eqc[k]; @@ -64,14 +64,12 @@ bool DtInstantiator::processEqualTerms(CegInstantiator* ci, const DType& dt = d_type.getDType(); unsigned cindex = datatypes::utils::indexOf(n.getOperator()); // now must solve for selectors applied to pv - for (unsigned j = 0, nargs = dt[cindex].getNumArgs(); j < nargs; j++) + Node val = datatypes::utils::getInstCons( + pv, dt, cindex, options().datatypes.dtSharedSelectors); + for (const Node& c : val) { - Node c = nm->mkNode( - APPLY_SELECTOR, dt[cindex].getSelectorInternal(d_type, j), pv); ci->pushStackVariable(c); - children.push_back(c); } - Node val = nm->mkNode(kind::APPLY_CONSTRUCTOR, children); TermProperties pv_prop_dt; if (ci->constructInstantiationInc(pv, val, pv_prop_dt, sf)) { @@ -146,15 +144,14 @@ Node DtInstantiator::solve_dt(Node v, Node a, Node b, Node sa, Node sb) } else { - NodeManager* nm = NodeManager::currentNM(); unsigned cindex = DType::indexOf(a.getOperator()); TypeNode tn = a.getType(); const DType& dt = tn.getDType(); - for (unsigned i = 0, nchild = a.getNumChildren(); i < nchild; i++) + Node val = datatypes::utils::getInstCons( + sb, dt, cindex, options().datatypes.dtSharedSelectors); + for (size_t i = 0, nchild = val.getNumChildren(); i < nchild; i++) { - Node nn = nm->mkNode( - APPLY_SELECTOR, dt[cindex].getSelectorInternal(tn, i), sb); - Node s = solve_dt(v, a[i], Node::null(), sa[i], nn); + Node s = solve_dt(v, a[i], Node::null(), sa[i], val[i]); if (!s.isNull()) { return s; diff --git a/src/theory/quantifiers/ematching/candidate_generator.cpp b/src/theory/quantifiers/ematching/candidate_generator.cpp index 9e0a4597d..53a6a65f8 100644 --- a/src/theory/quantifiers/ematching/candidate_generator.cpp +++ b/src/theory/quantifiers/ematching/candidate_generator.cpp @@ -17,10 +17,12 @@ #include "expr/dtype.h" #include "expr/dtype_cons.h" +#include "options/datatypes_options.h" #include "options/quantifiers_options.h" #include "smt/solver_engine.h" #include "smt/solver_engine_scope.h" #include "theory/datatypes/datatypes_rewriter.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/first_order_model.h" #include "theory/quantifiers/quantifiers_state.h" #include "theory/quantifiers/term_database.h" @@ -34,8 +36,10 @@ namespace theory { namespace quantifiers { namespace inst { -CandidateGenerator::CandidateGenerator(QuantifiersState& qs, TermRegistry& tr) - : d_qs(qs), d_treg(tr) +CandidateGenerator::CandidateGenerator(Env& env, + QuantifiersState& qs, + TermRegistry& tr) + : EnvObj(env), d_qs(qs), d_treg(tr) { } @@ -44,10 +48,11 @@ bool CandidateGenerator::isLegalCandidate( Node n ){ && !quantifiers::TermUtil::hasInstConstAttr(n); } -CandidateGeneratorQE::CandidateGeneratorQE(QuantifiersState& qs, +CandidateGeneratorQE::CandidateGeneratorQE(Env& env, + QuantifiersState& qs, TermRegistry& tr, Node pat) - : CandidateGenerator(qs, tr), + : CandidateGenerator(env, qs, tr), d_termIter(0), d_termIterList(nullptr), d_mode(cand_term_none) @@ -156,10 +161,11 @@ Node CandidateGeneratorQE::getNextCandidateInternal() return Node::null(); } -CandidateGeneratorQELitDeq::CandidateGeneratorQELitDeq(QuantifiersState& qs, +CandidateGeneratorQELitDeq::CandidateGeneratorQELitDeq(Env& env, + QuantifiersState& qs, TermRegistry& tr, Node mpat) - : CandidateGenerator(qs, tr), d_match_pattern(mpat) + : CandidateGenerator(env, qs, tr), d_match_pattern(mpat) { Assert(d_match_pattern.getKind() == EQUAL); d_match_pattern_type = d_match_pattern[0].getType(); @@ -189,10 +195,11 @@ Node CandidateGeneratorQELitDeq::getNextCandidate(){ return Node::null(); } -CandidateGeneratorQEAll::CandidateGeneratorQEAll(QuantifiersState& qs, +CandidateGeneratorQEAll::CandidateGeneratorQEAll(Env& env, + QuantifiersState& qs, TermRegistry& tr, Node mpat) - : CandidateGenerator(qs, tr), d_match_pattern(mpat) + : CandidateGenerator(env, qs, tr), d_match_pattern(mpat) { d_match_pattern_type = mpat.getType(); Assert(mpat.getKind() == INST_CONSTANT); @@ -215,7 +222,7 @@ Node CandidateGeneratorQEAll::getNextCandidate() { { TNode nh = tdb->getEligibleTermInEqc(n); if( !nh.isNull() ){ - if (options::instMaxLevel() != -1) + if (options().quantifiers.instMaxLevel != -1) { nh = d_treg.getModel()->getInternalRepresentative(nh, d_f, d_index); //don't consider this if already the instantiation is ineligible @@ -240,10 +247,11 @@ Node CandidateGeneratorQEAll::getNextCandidate() { return Node::null(); } -CandidateGeneratorConsExpand::CandidateGeneratorConsExpand(QuantifiersState& qs, +CandidateGeneratorConsExpand::CandidateGeneratorConsExpand(Env& env, + QuantifiersState& qs, TermRegistry& tr, Node mpat) - : CandidateGeneratorQE(qs, tr, mpat) + : CandidateGeneratorQE(env, qs, tr, mpat) { Assert(mpat.getKind() == APPLY_CONSTRUCTOR); d_mpat_type = mpat.getType(); @@ -256,7 +264,7 @@ void CandidateGeneratorConsExpand::reset(Node eqc) { // generates too many instantiations at top-level when eqc is null, thus // set mode to none unless option is set. - if (options::consExpandTriggers()) + if (options().quantifiers.consExpandTriggers) { d_termIterList = d_treg.getTermDatabase()->getGroundTermList(d_op); d_mode = cand_term_db; @@ -283,18 +291,11 @@ Node CandidateGeneratorConsExpand::getNextCandidate() return curr; } // expand it - NodeManager* nm = NodeManager::currentNM(); std::vector children; const DType& dt = d_mpat_type.getDType(); Assert(dt.getNumConstructors() == 1); - children.push_back(d_op); - for (unsigned i = 0, nargs = dt[0].getNumArgs(); i < nargs; i++) - { - Node sel = nm->mkNode( - APPLY_SELECTOR, dt[0].getSelectorInternal(d_mpat_type, i), curr); - children.push_back(sel); - } - return nm->mkNode(APPLY_CONSTRUCTOR, children); + return datatypes::utils::getInstCons( + curr, dt, 0, options().datatypes.dtSharedSelectors); } bool CandidateGeneratorConsExpand::isLegalOpCandidate(Node n) @@ -302,16 +303,18 @@ bool CandidateGeneratorConsExpand::isLegalOpCandidate(Node n) return isLegalCandidate(n); } -CandidateGeneratorSelector::CandidateGeneratorSelector(QuantifiersState& qs, +CandidateGeneratorSelector::CandidateGeneratorSelector(Env& env, + QuantifiersState& qs, TermRegistry& tr, Node mpat) - : CandidateGeneratorQE(qs, tr, mpat) + : CandidateGeneratorQE(env, qs, tr, mpat) { Trace("sel-trigger") << "Selector trigger: " << mpat << std::endl; Assert(mpat.getKind() == APPLY_SELECTOR); // Get the expanded form of the selector, meaning that we will match on // the shared selector if shared selectors are enabled. - Node mpatExp = datatypes::DatatypesRewriter::expandApplySelector(mpat); + Node mpatExp = datatypes::DatatypesRewriter::expandApplySelector( + mpat, options().datatypes.dtSharedSelectors); Trace("sel-trigger") << "Expands to: " << mpatExp << std::endl; Assert (mpatExp.getKind() == APPLY_SELECTOR); d_selOp = d_treg.getTermDatabase()->getMatchOperator(mpatExp); diff --git a/src/theory/quantifiers/ematching/candidate_generator.h b/src/theory/quantifiers/ematching/candidate_generator.h index 36d91e4cd..bf7ca06c1 100644 --- a/src/theory/quantifiers/ematching/candidate_generator.h +++ b/src/theory/quantifiers/ematching/candidate_generator.h @@ -18,6 +18,7 @@ #ifndef CVC5__THEORY__QUANTIFIERS__CANDIDATE_GENERATOR_H #define CVC5__THEORY__QUANTIFIERS__CANDIDATE_GENERATOR_H +#include "smt/env_obj.h" #include "theory/theory.h" #include "theory/uf/equality_engine.h" @@ -54,9 +55,10 @@ namespace inst { * }while( !cand.isNull() ); * */ -class CandidateGenerator { +class CandidateGenerator : protected EnvObj +{ public: - CandidateGenerator(QuantifiersState& qs, TermRegistry& tr); + CandidateGenerator(Env& env, QuantifiersState& qs, TermRegistry& tr); virtual ~CandidateGenerator(){} /** reset instantiation round * @@ -97,7 +99,10 @@ class CandidateGeneratorQE : public CandidateGenerator friend class CandidateGeneratorQEDisequal; public: - CandidateGeneratorQE(QuantifiersState& qs, TermRegistry& tr, Node pat); + CandidateGeneratorQE(Env& env, + QuantifiersState& qs, + TermRegistry& tr, + Node pat); /** reset */ void reset(Node eqc) override; /** get next candidate */ @@ -154,7 +159,10 @@ class CandidateGeneratorQELitDeq : public CandidateGenerator * mpat is an equality that we are matching to equalities in the equivalence * class of false */ - CandidateGeneratorQELitDeq(QuantifiersState& qs, TermRegistry& tr, Node mpat); + CandidateGeneratorQELitDeq(Env& env, + QuantifiersState& qs, + TermRegistry& tr, + Node mpat); /** reset */ void reset(Node eqc) override; /** get next candidate */ @@ -194,7 +202,10 @@ class CandidateGeneratorQEAll : public CandidateGenerator std::string identify() const override { return "CandidateGeneratorQEAll"; } public: - CandidateGeneratorQEAll(QuantifiersState& qs, TermRegistry& tr, Node mpat); + CandidateGeneratorQEAll(Env& env, + QuantifiersState& qs, + TermRegistry& tr, + Node mpat); /** reset */ void reset(Node eqc) override; /** get next candidate */ @@ -212,7 +223,8 @@ class CandidateGeneratorQEAll : public CandidateGenerator class CandidateGeneratorConsExpand : public CandidateGeneratorQE { public: - CandidateGeneratorConsExpand(QuantifiersState& qs, + CandidateGeneratorConsExpand(Env& env, + QuantifiersState& qs, TermRegistry& tr, Node mpat); /** reset */ @@ -239,7 +251,10 @@ class CandidateGeneratorConsExpand : public CandidateGeneratorQE class CandidateGeneratorSelector : public CandidateGeneratorQE { public: - CandidateGeneratorSelector(QuantifiersState& qs, TermRegistry& tr, Node mpat); + CandidateGeneratorSelector(Env& env, + QuantifiersState& qs, + TermRegistry& tr, + Node mpat); /** reset */ void reset(Node eqc) override; /** diff --git a/src/theory/quantifiers/ematching/inst_match_generator.cpp b/src/theory/quantifiers/ematching/inst_match_generator.cpp index 88db6ff31..bbd929adf 100644 --- a/src/theory/quantifiers/ematching/inst_match_generator.cpp +++ b/src/theory/quantifiers/ematching/inst_match_generator.cpp @@ -211,8 +211,8 @@ void InstMatchGenerator::initialize(Node q, { // candidates for apply selector are a union of correctly and incorrectly // applied selectors - d_cg = - new inst::CandidateGeneratorSelector(d_qstate, d_treg, d_match_pattern); + d_cg = new inst::CandidateGeneratorSelector( + d_env, d_qstate, d_treg, d_match_pattern); } else if (TriggerTermInfo::isAtomicTriggerKind(mpk)) { @@ -224,13 +224,13 @@ void InstMatchGenerator::initialize(Node q, if (dt.getNumConstructors() == 1) { d_cg = new inst::CandidateGeneratorConsExpand( - d_qstate, d_treg, d_match_pattern); + d_env, d_qstate, d_treg, d_match_pattern); } } if (d_cg == nullptr) { CandidateGeneratorQE* cg = - new CandidateGeneratorQE(d_qstate, d_treg, d_match_pattern); + new CandidateGeneratorQE(d_env, d_qstate, d_treg, d_match_pattern); // we will be scanning lists trying to find ground terms whose operator // is the same as d_match_operator's. d_cg = cg; @@ -255,9 +255,10 @@ void InstMatchGenerator::initialize(Node q, Trace("inst-match-gen") << "Purify dt trigger " << d_pattern << ", will match terms of op " << cOp << std::endl; - d_cg = new inst::CandidateGeneratorQE(d_qstate, d_treg, cOp); + d_cg = new inst::CandidateGeneratorQE(d_env, d_qstate, d_treg, cOp); }else{ - d_cg = new CandidateGeneratorQEAll(d_qstate, d_treg, d_match_pattern); + d_cg = + new CandidateGeneratorQEAll(d_env, d_qstate, d_treg, d_match_pattern); } } else if (mpk == EQUAL) @@ -267,7 +268,7 @@ void InstMatchGenerator::initialize(Node q, { // candidates will be all disequalities d_cg = new inst::CandidateGeneratorQELitDeq( - d_qstate, d_treg, d_match_pattern); + d_env, d_qstate, d_treg, d_match_pattern); } } else diff --git a/src/theory/quantifiers/fmf/bounded_integers.cpp b/src/theory/quantifiers/fmf/bounded_integers.cpp index dd3ddd54d..0c4aaefb7 100644 --- a/src/theory/quantifiers/fmf/bounded_integers.cpp +++ b/src/theory/quantifiers/fmf/bounded_integers.cpp @@ -20,6 +20,7 @@ #include "expr/dtype_cons.h" #include "expr/node_algorithm.h" #include "expr/skolem_manager.h" +#include "options/datatypes_options.h" #include "options/quantifiers_options.h" #include "theory/arith/arith_msum.h" #include "theory/datatypes/theory_datatypes_utils.h" @@ -780,16 +781,15 @@ Node BoundedIntegers::matchBoundVar( Node v, Node t, Node e ){ return Node::null(); } } - NodeManager* nm = NodeManager::currentNM(); const DType& dt = datatypes::utils::datatypeOf(t.getOperator()); unsigned index = datatypes::utils::indexOf(t.getOperator()); + bool sharedSel = options().datatypes.dtSharedSelectors; for( unsigned i=0; imkNode( - APPLY_SELECTOR, dt[index].getSelectorInternal(e.getType(), i), e); + Node se = datatypes::utils::applySelector(dt[index], i, sharedSel, e); u = matchBoundVar( v, t[i], se ); } if( !u.isNull() ){ diff --git a/src/theory/quantifiers/skolemize.cpp b/src/theory/quantifiers/skolemize.cpp index 398e9e6b6..8e5c63d5f 100644 --- a/src/theory/quantifiers/skolemize.cpp +++ b/src/theory/quantifiers/skolemize.cpp @@ -147,33 +147,29 @@ void Skolemize::getSelfSel(const DType& dt, NodeManager* nm = NodeManager::currentNM(); for (unsigned j = 0; j < dc.getNumArgs(); j++) { - std::vector ssc; if (dt.isParametric()) { Trace("sk-ind-debug") << "Compare " << tspec[j] << " " << ntn << std::endl; - if (tspec[j] == ntn) + if (tspec[j] != ntn) { - ssc.push_back(n); + continue; } } else { TypeNode tn = dc[j].getRangeType(); Trace("sk-ind-debug") << "Compare " << tn << " " << ntn << std::endl; - if (tn == ntn) + if (tn != ntn) { - ssc.push_back(n); + continue; } } - for (unsigned k = 0; k < ssc.size(); k++) + // do not use shared selectors + Node ss = nm->mkNode(APPLY_SELECTOR, dc.getSelector(j), n); + if (std::find(selfSel.begin(), selfSel.end(), ss) == selfSel.end()) { - Node ss = - nm->mkNode(APPLY_SELECTOR, dc.getSelectorInternal(n.getType(), j), n); - if (std::find(selfSel.begin(), selfSel.end(), ss) == selfSel.end()) - { - selfSel.push_back(ss); - } + selfSel.push_back(ss); } } } diff --git a/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp b/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp index 47764e70f..23d93df65 100644 --- a/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp +++ b/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp @@ -17,6 +17,7 @@ #include "expr/dtype_cons.h" #include "expr/sygus_datatype.h" +#include "options/datatypes_options.h" #include "options/quantifiers_options.h" #include "theory/datatypes/sygus_datatype_utils.h" #include "theory/quantifiers/sygus/term_database_sygus.h" @@ -287,8 +288,11 @@ Node SygusEvalUnfold::unfold(Node en, } else { + bool shareSel = options().datatypes.dtSharedSelectors; Node ret = nm->mkNode( - APPLY_SELECTOR, dt[i].getSelectorInternal(headType, 0), en[0]); + APPLY_SELECTOR, + datatypes::utils::getSelector(headType, dt[i], 0, shareSel), + en[0]); Trace("sygus-eval-unfold-debug") << "...return (from constructor) " << ret << std::endl; return ret; @@ -297,7 +301,8 @@ Node SygusEvalUnfold::unfold(Node en, Assert(!dt.isParametric()); std::map pre; - for (unsigned j = 0, nargs = dt[i].getNumArgs(); j < nargs; j++) + bool sharedSel = options().datatypes.dtSharedSelectors; + for (size_t j = 0, nargs = dt[i].getNumArgs(); j < nargs; j++) { std::vector cc; Node s; @@ -310,8 +315,8 @@ Node SygusEvalUnfold::unfold(Node en, } else { - s = nm->mkNode( - APPLY_SELECTOR, dt[i].getSelectorInternal(headType, j), en[0]); + Node sel = datatypes::utils::getSelector(headType, dt[i], j, sharedSel); + s = nm->mkNode(APPLY_SELECTOR, sel, en[0]); } cc.push_back(s); if (track_exp) diff --git a/src/theory/quantifiers/sygus/sygus_explain.cpp b/src/theory/quantifiers/sygus/sygus_explain.cpp index 5283eeba9..27380ba91 100644 --- a/src/theory/quantifiers/sygus/sygus_explain.cpp +++ b/src/theory/quantifiers/sygus/sygus_explain.cpp @@ -17,6 +17,7 @@ #include "expr/dtype.h" #include "expr/dtype_cons.h" +#include "options/datatypes_options.h" #include "smt/logic_exception.h" #include "theory/datatypes/sygus_datatype_utils.h" #include "theory/datatypes/theory_datatypes_utils.h" @@ -24,7 +25,6 @@ #include "theory/quantifiers/sygus/term_database_sygus.h" using namespace cvc5::internal::kind; -using namespace std; namespace cvc5::internal { namespace theory { @@ -116,6 +116,10 @@ Node TermRecBuild::build(unsigned d) return NodeManager::currentNM()->mkNode(d_kind[d], children); } +SygusExplain::SygusExplain(Env& env, TermDbSygus* tdb) : EnvObj(env), d_tdb(tdb) +{ +} + void SygusExplain::getExplanationForEquality(Node n, Node vn, std::vector& exp) @@ -148,12 +152,12 @@ void SygusExplain::getExplanationForEquality(Node n, int i = datatypes::utils::indexOf(vn.getOperator()); Node tst = datatypes::utils::mkTester(n, i, dt); exp.push_back(tst); - for (unsigned j = 0; j < vn.getNumChildren(); j++) + bool shareSel = options().datatypes.dtSharedSelectors; + for (size_t j = 0, vnc = vn.getNumChildren(); j < vnc; j++) { if (cexc.find(j) == cexc.end()) { - Node sel = NodeManager::currentNM()->mkNode( - kind::APPLY_SELECTOR, dt[i].getSelectorInternal(tn, j), n); + Node sel = datatypes::utils::applySelector(dt[i], j, shareSel, n); getExplanationForEquality(sel, vn[j], exp); } } @@ -246,10 +250,10 @@ void SygusExplain::getExplanationFor(TermRecBuild& trb, vnr_exp = NodeManager::currentNM()->mkConst(true); } } - for (unsigned i = 0; i < vn.getNumChildren(); i++) + bool shareSel = options().datatypes.dtSharedSelectors; + for (size_t i = 0, vnc = vn.getNumChildren(); i < vnc; i++) { - Node sel = NodeManager::currentNM()->mkNode( - kind::APPLY_SELECTOR, dt[cindex].getSelectorInternal(ntn, i), n); + Node sel = datatypes::utils::applySelector(dt[cindex], i, shareSel, n); Node vnr_c = vnr.isNull() ? vnr : (vn[i] == vnr[i] ? Node::null() : vnr[i]); if (cexc.find(i) == cexc.end()) { diff --git a/src/theory/quantifiers/sygus/sygus_explain.h b/src/theory/quantifiers/sygus/sygus_explain.h index e8f670519..188eee236 100644 --- a/src/theory/quantifiers/sygus/sygus_explain.h +++ b/src/theory/quantifiers/sygus/sygus_explain.h @@ -21,6 +21,7 @@ #include #include "expr/node.h" +#include "smt/env_obj.h" namespace cvc5::internal { namespace theory { @@ -140,10 +141,10 @@ class TermRecBuild * [[exp]]_n = (plus w y) * where w is a fresh variable. */ -class SygusExplain +class SygusExplain : protected EnvObj { public: - SygusExplain(TermDbSygus* tdb) : d_tdb(tdb) {} + SygusExplain(Env& env, TermDbSygus* tdb); ~SygusExplain() {} /** get explanation for equality * diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index 732570141..69c03dabd 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -54,7 +54,7 @@ std::ostream& operator<<(std::ostream& os, EnumeratorRole r) TermDbSygus::TermDbSygus(Env& env, QuantifiersState& qs, OracleChecker* oc) : EnvObj(env), d_qstate(qs), - d_syexp(new SygusExplain(this)), + d_syexp(new SygusExplain(env, this)), d_funDefEval(new FunDefEvaluator(env)), d_eval_unfold(new SygusEvalUnfold(env, this)), d_ochecker(oc) @@ -409,6 +409,7 @@ void TermDbSygus::registerEnumerator(Node e, SygusTypeInfo& eti = getTypeInfo(et); std::vector sf_types; eti.getSubfieldTypes(sf_types); + bool sharedSel = options().datatypes.dtSharedSelectors; // for each type of subfield type of this enumerator for (unsigned i = 0, ntypes = sf_types.size(); i < ntypes; i++) { @@ -440,7 +441,7 @@ void TermDbSygus::registerEnumerator(Node e, // is necessary to generate a term of the form any_constant( x.0 ) for a // fresh variable x.0. Node fv = getFreeVar(stn, 0); - Node exc_val = datatypes::utils::getInstCons(fv, dt, rindex); + Node exc_val = datatypes::utils::getInstCons(fv, dt, rindex, sharedSel); // should not include the constuctor in any subterm Node x = getFreeVar(stn, 0); Trace("sygus-db") << "Construct symmetry breaking lemma from " << x @@ -792,12 +793,13 @@ unsigned TermDbSygus::getSelectorWeight(TypeNode tn, Node sel) const DType& dt = tn.getDType(); Trace("sygus-db") << "Compute selector weights for " << dt.getName() << std::endl; + bool sharedSel = options().datatypes.dtSharedSelectors; for (unsigned i = 0, size = dt.getNumConstructors(); i < size; i++) { unsigned cw = dt[i].getWeight(); - for (unsigned j = 0, size2 = dt[i].getNumArgs(); j < size2; j++) + for (size_t j = 0, size2 = dt[i].getNumArgs(); j < size2; j++) { - Node csel = dt[i].getSelectorInternal(tn, j); + Node csel = datatypes::utils::getSelector(tn, dt[i], j, sharedSel); std::map::iterator its = itsw->second.find(csel); if (its == itsw->second.end() || cw < its->second) { diff --git a/src/theory/quantifiers/sygus_inst.cpp b/src/theory/quantifiers/sygus_inst.cpp index 35bd9003a..fa1a6b57e 100644 --- a/src/theory/quantifiers/sygus_inst.cpp +++ b/src/theory/quantifiers/sygus_inst.cpp @@ -268,7 +268,7 @@ void SygusInst::check(Theory::Effort e, QEffort quant_e) FirstOrderModel* model = d_treg.getModel(); Instantiate* inst = d_qim.getInstantiate(); TermDbSygus* db = d_treg.getTermDatabaseSygus(); - SygusExplain syexplain(db); + SygusExplain syexplain(d_env, db); NodeManager* nm = NodeManager::currentNM(); options::SygusInstMode mode = options().quantifiers.sygusInstMode; -- 2.30.2