From: Andrew Reynolds Date: Thu, 17 Oct 2019 21:40:18 +0000 (-0500) Subject: Move datatype utility functions to own file (#3397) X-Git-Tag: cvc5-1.0.0~3879 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=5396f014b66cbfd7cc16380c05c1539b1efe583c;p=cvc5.git Move datatype utility functions to own file (#3397) --- diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 14d4ef8ae..b6b4acffb 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -407,6 +407,8 @@ libcvc4_add_sources( theory/datatypes/theory_datatypes.cpp theory/datatypes/theory_datatypes.h theory/datatypes/theory_datatypes_type_rules.h + theory/datatypes/theory_datatypes_utils.cpp + theory/datatypes/theory_datatypes_utils.h theory/datatypes/type_enumerator.cpp theory/datatypes/type_enumerator.h theory/decision_manager.cpp diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp index 802dedcbd..572ddbac2 100644 --- a/src/theory/datatypes/datatypes_rewriter.cpp +++ b/src/theory/datatypes/datatypes_rewriter.cpp @@ -17,6 +17,8 @@ #include "theory/datatypes/datatypes_rewriter.h" #include "expr/node_algorithm.h" +#include "options/datatypes_options.h" +#include "theory/datatypes/theory_datatypes_utils.h" using namespace CVC4; using namespace CVC4::kind; @@ -55,7 +57,7 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) } } TNode constructor = in[0].getOperator(); - size_t constructorIndex = indexOf(constructor); + size_t constructorIndex = utils::indexOf(constructor); const Datatype& dt = Datatype::datatypeOf(constructor.toExpr()); const DatatypeConstructor& c = dt[constructorIndex]; unsigned weight = c.getWeight(); @@ -118,7 +120,7 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) { Trace("dt-sygus-util") << "Rewrite " << in << " by unfolding...\n"; const Datatype& dt = ev.getType().getDatatype(); - unsigned i = indexOf(ev.getOperator()); + unsigned i = utils::indexOf(ev.getOperator()); Node op = Node::fromExpr(dt[i].getSygusOp()); // if it is the "any constant" constructor, return its argument if (op.getAttribute(SygusAnyConstAttribute())) @@ -141,9 +143,9 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) cc.insert(cc.end(), args.begin(), args.end()); children.push_back(nm->mkNode(DT_SYGUS_EVAL, cc)); } - Node ret = mkSygusTerm(dt, i, children); + Node ret = utils::mkSygusTerm(dt, i, children); // apply the appropriate substitution - ret = applySygusArgs(dt, op, ret, args); + ret = utils::applySygusArgs(dt, op, ret, args); Trace("dt-sygus-util") << "...got " << ret << "\n"; return RewriteResponse(REWRITE_AGAIN_FULL, ret); } @@ -215,7 +217,7 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) } if (!cons.isNull()) { - cases.push_back(mkTester(h, cindex, dt)); + cases.push_back(utils::mkTester(h, cindex, dt)); } else { @@ -246,7 +248,7 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) return RewriteResponse(REWRITE_DONE, nm->mkConst(true)); } std::vector rew; - if (checkClash(in[0], in[1], rew)) + if (utils::checkClash(in[0], in[1], rew)) { Trace("datatypes-rewrite") << "Rewrite clashing equality " << in << " to false" << std::endl; @@ -271,154 +273,6 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) return RewriteResponse(REWRITE_DONE, in); } -Node DatatypesRewriter::applySygusArgs(const Datatype& dt, - Node op, - Node n, - const std::vector& args) -{ - if (n.getKind() == BOUND_VARIABLE) - { - Assert(n.hasAttribute(SygusVarNumAttribute())); - int vn = n.getAttribute(SygusVarNumAttribute()); - Assert(Node::fromExpr(dt.getSygusVarList())[vn] == n); - return args[vn]; - } - // n is an application of operator op. - // We must compute the free variables in op to determine if there are - // any substitutions we need to make to n. - TNode val; - if (!op.hasAttribute(SygusVarFreeAttribute())) - { - std::unordered_set fvs; - if (expr::getFreeVariables(op, fvs)) - { - if (fvs.size() == 1) - { - for (const Node& v : fvs) - { - val = v; - } - } - else - { - val = op; - } - } - Trace("dt-sygus-fv") << "Free var in " << op << " : " << val << std::endl; - op.setAttribute(SygusVarFreeAttribute(), val); - } - else - { - val = op.getAttribute(SygusVarFreeAttribute()); - } - if (val.isNull()) - { - return n; - } - if (val.getKind() == BOUND_VARIABLE) - { - // single substitution case - int vn = val.getAttribute(SygusVarNumAttribute()); - TNode sub = args[vn]; - return n.substitute(val, sub); - } - // do the full substitution - std::vector vars; - Node bvl = Node::fromExpr(dt.getSygusVarList()); - for (unsigned i = 0, nvars = bvl.getNumChildren(); i < nvars; i++) - { - vars.push_back(bvl[i]); - } - return n.substitute(vars.begin(), vars.end(), args.begin(), args.end()); -} - -Kind DatatypesRewriter::getOperatorKindForSygusBuiltin(Node op) -{ - Assert(op.getKind() != BUILTIN); - if (op.getKind() == LAMBDA) - { - return APPLY_UF; - } - TypeNode tn = op.getType(); - if (tn.isConstructor()) - { - return APPLY_CONSTRUCTOR; - } - else if (tn.isSelector()) - { - return APPLY_SELECTOR; - } - else if (tn.isTester()) - { - return APPLY_TESTER; - } - else if (tn.isFunction()) - { - return APPLY_UF; - } - return UNDEFINED_KIND; -} - -Node DatatypesRewriter::mkSygusTerm(const Datatype& dt, - unsigned i, - const std::vector& children) -{ - Trace("dt-sygus-util") << "Make sygus term " << dt.getName() << "[" << i - << "] with children: " << children << std::endl; - Assert(i < dt.getNumConstructors()); - Assert(dt.isSygus()); - Assert(!dt[i].getSygusOp().isNull()); - std::vector schildren; - Node op = Node::fromExpr(dt[i].getSygusOp()); - Trace("dt-sygus-util") << "Operator is " << op << std::endl; - if (children.empty()) - { - // no children, return immediately - Trace("dt-sygus-util") << "...return direct op" << std::endl; - return op; - } - // if it is the any constant, we simply return the child - if (op.getAttribute(SygusAnyConstAttribute())) - { - Assert(children.size() == 1); - return children[0]; - } - if (op.getKind() != BUILTIN) - { - schildren.push_back(op); - } - schildren.insert(schildren.end(), children.begin(), children.end()); - Node ret; - if (op.getKind() == BUILTIN) - { - ret = NodeManager::currentNM()->mkNode(op, schildren); - Trace("dt-sygus-util") << "...return (builtin) " << ret << std::endl; - return ret; - } - Kind ok = NodeManager::operatorToKind(op); - Trace("dt-sygus-util") << "operator kind is " << ok << std::endl; - if (ok != UNDEFINED_KIND) - { - // If it is an APPLY_UF operator, we should have at least an operator and - // a child. - Assert(ok != APPLY_UF || schildren.size() != 1); - ret = NodeManager::currentNM()->mkNode(ok, schildren); - Trace("dt-sygus-util") << "...return (op) " << ret << std::endl; - return ret; - } - Kind tok = getOperatorKindForSygusBuiltin(op); - if (schildren.size() == 1 && tok == kind::UNDEFINED_KIND) - { - ret = schildren[0]; - } - else - { - ret = NodeManager::currentNM()->mkNode(tok, schildren); - } - Trace("dt-sygus-util") << "...return " << ret << std::endl; - return ret; -} - RewriteResponse DatatypesRewriter::preRewrite(TNode in) { Trace("datatypes-rewrite-debug") << "pre-rewriting " << in << std::endl; @@ -443,7 +297,7 @@ RewriteResponse DatatypesRewriter::preRewrite(TNode in) Node op = in.getOperator(); // get the constructor object const DatatypeConstructor& dtc = - Datatype::datatypeOf(op.toExpr())[indexOf(op)]; + Datatype::datatypeOf(op.toExpr())[utils::indexOf(op)]; // create ascribed constructor type Node tc = NodeManager::currentNM()->mkConst( AscriptionType(dtc.getSpecializedConstructorType(t))); @@ -496,7 +350,7 @@ RewriteResponse DatatypesRewriter::rewriteSelector(TNode in) TypeNode argType = in[0].getType(); Expr selector = in.getOperator().toExpr(); TNode constructor = in[0].getOperator(); - size_t constructorIndex = indexOf(constructor); + size_t constructorIndex = utils::indexOf(constructor); const Datatype& dt = Datatype::datatypeOf(selector); const DatatypeConstructor& c = dt[constructorIndex]; Trace("datatypes-rewrite-debug") << "Rewriting collapsable selector : " @@ -600,7 +454,8 @@ RewriteResponse DatatypesRewriter::rewriteTester(TNode in) { if (in[0].getKind() == kind::APPLY_CONSTRUCTOR) { - bool result = indexOf(in.getOperator()) == indexOf(in[0].getOperator()); + bool result = + utils::indexOf(in.getOperator()) == utils::indexOf(in[0].getOperator()); Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: " << "Rewrite trivial tester " << in << " " << result << std::endl; @@ -622,189 +477,6 @@ RewriteResponse DatatypesRewriter::rewriteTester(TNode in) return RewriteResponse(REWRITE_DONE, in); } -bool DatatypesRewriter::checkClash(Node n1, Node n2, std::vector& rew) -{ - Trace("datatypes-rewrite-debug") << "Check clash : " << n1 << " " << n2 - << std::endl; - if (n1.getKind() == kind::APPLY_CONSTRUCTOR - && n2.getKind() == kind::APPLY_CONSTRUCTOR) - { - if (n1.getOperator() != n2.getOperator()) - { - Trace("datatypes-rewrite-debug") << "Clash operators : " << n1 << " " - << n2 << " " << n1.getOperator() << " " - << n2.getOperator() << std::endl; - return true; - } - Assert(n1.getNumChildren() == n2.getNumChildren()); - for (unsigned i = 0, size = n1.getNumChildren(); i < size; i++) - { - if (checkClash(n1[i], n2[i], rew)) - { - return true; - } - } - } - else if (n1 != n2) - { - if (n1.isConst() && n2.isConst()) - { - Trace("datatypes-rewrite-debug") << "Clash constants : " << n1 << " " - << n2 << std::endl; - return true; - } - else - { - Node eq = NodeManager::currentNM()->mkNode(kind::EQUAL, n1, n2); - rew.push_back(eq); - } - } - return false; -} -/** get instantiate cons */ -Node DatatypesRewriter::getInstCons(Node n, const Datatype& dt, int index) -{ - Assert(index >= 0 && index < (int)dt.getNumConstructors()); - std::vector children; - NodeManager* nm = NodeManager::currentNM(); - children.push_back(Node::fromExpr(dt[index].getConstructor())); - Type t = n.getType().toType(); - for (unsigned i = 0, nargs = dt[index].getNumArgs(); i < nargs; i++) - { - Node nc = nm->mkNode(kind::APPLY_SELECTOR_TOTAL, - Node::fromExpr(dt[index].getSelectorInternal(t, i)), - n); - children.push_back(nc); - } - Node n_ic = nm->mkNode(kind::APPLY_CONSTRUCTOR, children); - if (dt.isParametric()) - { - TypeNode tn = TypeNode::fromType(t); - // add type ascription for ambiguous constructor types - if (!n_ic.getType().isComparableTo(tn)) - { - Debug("datatypes-parametric") << "DtInstantiate: ambiguous type for " - << n_ic << ", ascribe to " << n.getType() - << std::endl; - Debug("datatypes-parametric") << "Constructor is " << dt[index] - << std::endl; - Type tspec = - dt[index].getSpecializedConstructorType(n.getType().toType()); - Debug("datatypes-parametric") << "Type specification is " << tspec - << std::endl; - children[0] = nm->mkNode(kind::APPLY_TYPE_ASCRIPTION, - nm->mkConst(AscriptionType(tspec)), - children[0]); - n_ic = nm->mkNode(kind::APPLY_CONSTRUCTOR, children); - Assert(n_ic.getType() == tn); - } - } - Assert(isInstCons(n, n_ic, dt) == index); - // n_ic = Rewriter::rewrite( n_ic ); - return n_ic; -} - -int DatatypesRewriter::isInstCons(Node t, Node n, const Datatype& dt) -{ - if (n.getKind() == kind::APPLY_CONSTRUCTOR) - { - int index = indexOf(n.getOperator()); - const DatatypeConstructor& c = dt[index]; - Type nt = n.getType().toType(); - for (unsigned i = 0, size = n.getNumChildren(); i < size; i++) - { - if (n[i].getKind() != kind::APPLY_SELECTOR_TOTAL - || n[i].getOperator() != Node::fromExpr(c.getSelectorInternal(nt, i)) - || n[i][0] != t) - { - return -1; - } - } - return index; - } - return -1; -} - -int DatatypesRewriter::isTester(Node n, Node& a) -{ - if (n.getKind() == kind::APPLY_TESTER) - { - a = n[0]; - return indexOf(n.getOperator()); - } - return -1; -} - -int DatatypesRewriter::isTester(Node n) -{ - if (n.getKind() == kind::APPLY_TESTER) - { - return indexOf(n.getOperator().toExpr()); - } - return -1; -} - -struct DtIndexAttributeId -{ -}; -typedef expr::Attribute DtIndexAttribute; - -unsigned DatatypesRewriter::indexOf(Node n) -{ - if (!n.hasAttribute(DtIndexAttribute())) - { - Assert(n.getType().isConstructor() || n.getType().isTester() - || n.getType().isSelector()); - unsigned index = Datatype::indexOfInternal(n.toExpr()); - n.setAttribute(DtIndexAttribute(), index); - return index; - } - return n.getAttribute(DtIndexAttribute()); -} - -Node DatatypesRewriter::mkTester(Node n, int i, const Datatype& dt) -{ - return NodeManager::currentNM()->mkNode( - kind::APPLY_TESTER, Node::fromExpr(dt[i].getTester()), n); -} - -Node DatatypesRewriter::mkSplit(Node n, const Datatype& dt) -{ - std::vector splits; - for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) - { - Node test = mkTester(n, i, dt); - splits.push_back(test); - } - NodeManager* nm = NodeManager::currentNM(); - return splits.size() == 1 ? splits[0] : nm->mkNode(kind::OR, splits); -} - -bool DatatypesRewriter::isNullaryApplyConstructor(Node n) -{ - Assert(n.getKind() == kind::APPLY_CONSTRUCTOR); - for (unsigned i = 0; i < n.getNumChildren(); i++) - { - if (n[i].getType().isDatatype()) - { - return false; - } - } - return true; -} - -bool DatatypesRewriter::isNullaryConstructor(const DatatypeConstructor& c) -{ - for (unsigned j = 0, nargs = c.getNumArgs(); j < nargs; j++) - { - if (c[j].getType().getRangeType().isDatatype()) - { - return false; - } - } - return true; -} - Node DatatypesRewriter::normalizeCodatatypeConstant(Node n) { Trace("dt-nconst") << "Normalize " << n << std::endl; diff --git a/src/theory/datatypes/datatypes_rewriter.h b/src/theory/datatypes/datatypes_rewriter.h index 1a1735402..d2fdd8f4d 100644 --- a/src/theory/datatypes/datatypes_rewriter.h +++ b/src/theory/datatypes/datatypes_rewriter.h @@ -20,59 +20,11 @@ #define CVC4__THEORY__DATATYPES__DATATYPES_REWRITER_H #include "expr/node_manager_attributes.h" -#include "options/datatypes_options.h" #include "theory/rewriter.h" #include "theory/type_enumerator.h" namespace CVC4 { namespace theory { - -/** sygus var num */ -struct SygusVarNumAttributeId -{ -}; -typedef expr::Attribute SygusVarNumAttribute; - -/** Attribute true for variables that represent any constant */ -struct SygusAnyConstAttributeId -{ -}; -typedef expr::Attribute SygusAnyConstAttribute; - -/** - * Attribute true for enumerators whose current model values were registered by - * the datatypes sygus solver, and were not excluded by sygus symmetry breaking. - * This is set by the datatypes sygus solver during LAST_CALL effort checks for - * each active sygus enumerator. - */ -struct SygusSymBreakOkAttributeId -{ -}; -typedef expr::Attribute - SygusSymBreakOkAttribute; - -/** sygus var free - * - * This attribute is used to mark whether sygus operators have free occurrences - * of variables from the formal argument list of the function-to-synthesize. - * - * We store three possible cases for sygus operators op: - * (1) op.getAttribute(SygusVarFreeAttribute())==Node::null() - * In this case, op has no free variables from the formal argument list of the - * function-to-synthesize. - * (2) op.getAttribute(SygusVarFreeAttribute())==v, where v is a bound variable. - * In this case, op has exactly one free variable, v. - * (3) op.getAttribute(SygusVarFreeAttribute())==op - * In this case, op has an arbitrary set (cardinality >1) of free variables from - * the formal argument list of the function to synthesize. - * - * This attribute is used to compute applySygusArgs below. - */ -struct SygusVarFreeAttributeId -{ -}; -typedef expr::Attribute SygusVarFreeAttribute; - namespace datatypes { class DatatypesRewriter { @@ -83,47 +35,7 @@ public: static inline void init() {} static inline void shutdown() {} - /** get instantiate cons - * - * This returns the term C( sel^{C,1}( n ), ..., sel^{C,m}( n ) ), - * where C is the index^{th} constructor of datatype dt. - */ - static Node getInstCons(Node n, const Datatype& dt, int index); - /** is instantiation cons - * - * If this method returns a value >=0, then that value, call it index, - * is such that n = C( sel^{C,1}( t ), ..., sel^{C,m}( t ) ), - * where C is the index^{th} constructor of dt. - */ - static int isInstCons(Node t, Node n, const Datatype& dt); - /** is tester - * - * This method returns a value >=0 if n is a tester predicate. The return - * value indicates the constructor index that the tester n is for. If this - * method returns a value >=0, then it updates a to the argument that the - * tester n applies to. - */ - static int isTester(Node n, Node& a); - /** is tester, same as above but does not update an argument */ - static int isTester(Node n); - /** - * Get the index of a constructor or tester in its datatype, or the - * index of a selector in its constructor. (Zero is always the - * first index.) - */ - static unsigned indexOf(Node n); - /** make tester is-C( n ), where C is the i^{th} constructor of dt */ - static Node mkTester(Node n, int i, const Datatype& dt); - /** make tester split - * - * Returns the formula (OR is-C1( n ) ... is-Ck( n ) ), where C1...Ck - * are the constructors of n's type (dt). - */ - static Node mkSplit(Node n, const Datatype& dt); - /** returns true iff n is a constructor term with no datatype children */ - static bool isNullaryApplyConstructor(Node n); - /** returns true iff c is a constructor with no datatype children */ - static bool isNullaryConstructor(const DatatypeConstructor& c); + /** normalize codatatype constant * * This returns the normal form of the codatatype constant n. This runs a @@ -142,72 +54,6 @@ public: * on all top-level codatatype subterms of n. */ static Node normalizeConstant(Node n); - /** check clash - * - * This method returns true if and only if n1 and n2 have a skeleton that has - * conflicting constructors at some term position. - * Examples of terms that clash are: - * C( x, y ) and D( z ) - * C( D( x ), y ) and C( E( x ), y ) - * Examples of terms that do not clash are: - * C( x, y ) and C( D( x ), y ) - * C( D( x ), y ) and C( x, E( z ) ) - * C( x, y ) and z - */ - static bool checkClash(Node n1, Node n2, std::vector& rew); - /** get operator kind for sygus builtin - * - * This returns the Kind corresponding to applications of the operator op - * when building the builtin version of sygus terms. This is used by the - * function mkSygusTerm. - */ - static Kind getOperatorKindForSygusBuiltin(Node op); - /** make sygus term - * - * This function returns a builtin term f( children[0], ..., children[n] ) - * where f is the builtin op that the i^th constructor of sygus datatype dt - * encodes. - */ - static Node mkSygusTerm(const Datatype& dt, - unsigned i, - const std::vector& children); - /** - * n is a builtin term that is an application of operator op. - * - * This returns an n' such that (eval n args) is n', where n' is a instance of - * n for the appropriate substitution. - * - * For example, given a function-to-synthesize with formal argument list (x,y), - * say we have grammar: - * A -> A+A | A+x | A+(x+y) | y - * These lead to constructors with sygus ops: - * C1 / (lambda w1 w2. w1+w2) - * C2 / (lambda w1. w1+x) - * C3 / (lambda w1. w1+(x+y)) - * C4 / y - * Examples of calling this function: - * applySygusArgs( dt, C1, (APPLY_UF (lambda w1 w2. w1+w2) t1 t2), { 3, 5 } ) - * ... returns (APPLY_UF (lambda w1 w2. w1+w2) t1 t2). - * applySygusArgs( dt, C2, (APPLY_UF (lambda w1. w1+x) t1), { 3, 5 } ) - * ... returns (APPLY_UF (lambda w1. w1+3) t1). - * applySygusArgs( dt, C3, (APPLY_UF (lambda w1. w1+(x+y)) t1), { 3, 5 } ) - * ... returns (APPLY_UF (lambda w1. w1+(3+5)) t1). - * applySygusArgs( dt, C4, y, { 3, 5 } ) - * ... returns 5. - * Notice the attribute SygusVarFreeAttribute is applied to C1, C2, C3, C4, - * to cache the results of whether the evaluation of this constructor needs - * a substitution over the formal argument list of the function-to-synthesize. - */ - static Node applySygusArgs(const Datatype& dt, - Node op, - Node n, - const std::vector& args); - /** - * Get the builtin sygus operator for constructor term n of sygus datatype - * type. For example, if n is the term C_+( d1, d2 ) where C_+ is a sygus - * constructor whose sygus op is the builtin operator +, this method returns +. - */ - static Node getSygusOpForCTerm(Node n); private: /** rewrite constructor term in */ diff --git a/src/theory/datatypes/datatypes_sygus.cpp b/src/theory/datatypes/datatypes_sygus.cpp index cf05a6029..b04686492 100644 --- a/src/theory/datatypes/datatypes_sygus.cpp +++ b/src/theory/datatypes/datatypes_sygus.cpp @@ -18,10 +18,11 @@ #include "expr/node_manager.h" #include "options/base_options.h" +#include "options/datatypes_options.h" #include "options/quantifiers_options.h" #include "printer/printer.h" -#include "theory/datatypes/datatypes_rewriter.h" #include "theory/datatypes/theory_datatypes.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/sygus/sygus_explain.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_util.h" @@ -143,15 +144,14 @@ Node SygusSymBreakNew::getTermOrderPredicate( Node n1, Node n2 ) { std::vector case_conj; for (unsigned k = 0; k < j; k++) { - case_conj.push_back(DatatypesRewriter::mkTester(n2, k, cdt).negate()); + case_conj.push_back(utils::mkTester(n2, k, cdt).negate()); } if (!case_conj.empty()) { Node corder = nm->mkNode( - kind::OR, - DatatypesRewriter::mkTester(n1, j, cdt).negate(), - case_conj.size() == 1 ? case_conj[0] - : nm->mkNode(kind::AND, case_conj)); + OR, + utils::mkTester(n1, j, cdt).negate(), + case_conj.size() == 1 ? case_conj[0] : nm->mkNode(AND, case_conj)); sz_eq_cases.push_back(corder); } } @@ -394,7 +394,7 @@ Node SygusSymBreakNew::getRelevancyCondition( Node n ) { int sindexi = dt[i].getSelectorIndexInternal(selExpr); if (sindexi != -1) { - disj.push_back(DatatypesRewriter::mkTester(n[0], i, dt).negate()); + disj.push_back(utils::mkTester(n[0], i, dt).negate()); } else { @@ -409,7 +409,7 @@ Node SygusSymBreakNew::getRelevancyCondition( Node n ) { }else{ int sindex = Datatype::cindexOf( selExpr ); Assert( sindex!=-1 ); - cond = DatatypesRewriter::mkTester(n[0], sindex, dt).negate(); + cond = utils::mkTester(n[0], sindex, dt).negate(); } Node c1 = getRelevancyCondition( n[0] ); if( cond.isNull() ){ @@ -596,8 +596,7 @@ Node SygusSymBreakNew::getSimpleSymBreakPred(Node e, if (options::sygusFair() == SYGUS_FAIR_DT_SIZE && !isAnyConstant) { Node szl = nm->mkNode(DT_SIZE, n); - Node szr = - nm->mkNode(DT_SIZE, DatatypesRewriter::getInstCons(n, dt, tindex)); + Node szr = nm->mkNode(DT_SIZE, utils::getInstCons(n, dt, tindex)); szr = Rewriter::rewrite(szr); sbp_conj.push_back(szl.eqNode(szr)); } @@ -703,7 +702,7 @@ Node SygusSymBreakNew::getSimpleSymBreakPred(Node e, const Datatype& cdt = static_cast(ctn.toType()).getDatatype(); Assert(i < static_cast(cdt.getNumConstructors())); - sbp_conj.push_back(DatatypesRewriter::mkTester(children[j], i, cdt)); + sbp_conj.push_back(utils::mkTester(children[j], i, cdt)); } } @@ -828,7 +827,7 @@ Node SygusSymBreakNew::getSimpleSymBreakPred(Node e, { Kind nck = cti.getConsNumKind(k); bool red = false; - Node tester = DatatypesRewriter::mkTester(nc, k, cdt); + Node tester = utils::mkTester(nc, k, cdt); // check if the argument is redundant if (static_cast(k) == anyc_cons_num) { @@ -915,10 +914,10 @@ Node SygusSymBreakNew::getSimpleSymBreakPred(Node e, Node::fromExpr(dt[tindex].getSelectorInternal(tn.toType(), 1)), children[0]); Assert(child11.getType() == children[1].getType()); - Node order_pred_trans = nm->mkNode( - OR, - DatatypesRewriter::mkTester(children[0], tindex, dt).negate(), - getTermOrderPredicate(child11, children[1])); + Node order_pred_trans = + nm->mkNode(OR, + utils::mkTester(children[0], tindex, dt).negate(), + getTermOrderPredicate(child11, children[1])); sbp_conj.push_back(order_pred_trans); } } @@ -933,8 +932,7 @@ Node SygusSymBreakNew::getSimpleSymBreakPred(Node e, << "Simple predicate for " << tn << " index " << tindex << " (" << nk << ") at depth " << depth << " : " << std::endl; Trace("sygus-sb-simple") << " " << sb_pred << std::endl; - sb_pred = nm->mkNode( - kind::OR, DatatypesRewriter::mkTester(n, tindex, dt).negate(), sb_pred); + sb_pred = nm->mkNode(OR, utils::mkTester(n, tindex, dt).negate(), sb_pred); } d_simple_sb_pred[e][tn][tindex][optHashVal][depth] = sb_pred; return sb_pred; @@ -987,7 +985,7 @@ Node SygusSymBreakNew::registerSearchValue(Node a, // we call the body of this function in a bottom-up fashion // this ensures that the "abstraction" of the model value is available if( nv.getNumChildren()>0 ){ - unsigned cindex = DatatypesRewriter::indexOf(nv.getOperator()); + unsigned cindex = utils::indexOf(nv.getOperator()); std::vector rcons_children; rcons_children.push_back(nv.getOperator()); bool childrenChanged = false; @@ -1677,8 +1675,8 @@ bool SygusSymBreakNew::checkValue(Node n, Assert(dt.isSygus()); // ensure that the expected size bound is met - int cindex = DatatypesRewriter::indexOf(vn.getOperator()); - Node tst = DatatypesRewriter::mkTester( n, cindex, dt ); + int cindex = utils::indexOf(vn.getOperator()); + Node tst = utils::mkTester(n, cindex, dt); bool hastst = d_td->getEqualityEngine()->hasTerm(tst); Node tstrep; if (hastst) @@ -1693,7 +1691,7 @@ bool SygusSymBreakNew::checkValue(Node n, if( !hastst ){ // This should not happen generally, it is caused by a sygus term not // being assigned a tester. - Node split = DatatypesRewriter::mkSplit(n, dt); + Node split = utils::mkSplit(n, dt); Trace("sygus-sb") << " SygusSymBreakNew::check: ...WARNING: considered " "missing split for " << n << "." << std::endl; diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 8bac280b6..3d178681e 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -23,14 +23,13 @@ #include "options/datatypes_options.h" #include "options/quantifiers_options.h" #include "options/smt_options.h" -#include "theory/datatypes/datatypes_rewriter.h" +#include "options/theory_options.h" #include "theory/datatypes/theory_datatypes_type_rules.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers_engine.h" #include "theory/theory_model.h" #include "theory/type_enumerator.h" #include "theory/valuation.h" -#include "options/theory_options.h" -#include "options/quantifiers_options.h" using namespace std; using namespace CVC4::kind; @@ -296,7 +295,7 @@ void TheoryDatatypes::check(Effort e) { if( dt.getNumConstructors()==1 ){ //this may not be necessary? //if only one constructor, then this term must be this constructor - Node t = DatatypesRewriter::mkTester( n, 0, dt ); + Node t = utils::mkTester(n, 0, dt); d_pending.push_back( t ); d_pending_exp[ t ] = d_true; Trace("datatypes-infer") << "DtInfer : 1-cons (full) : " << t << std::endl; @@ -304,7 +303,7 @@ void TheoryDatatypes::check(Effort e) { }else{ Assert( consIndex!=-1 || dt.isSygus() ); if( options::dtBinarySplit() && consIndex!=-1 ){ - Node test = DatatypesRewriter::mkTester( n, consIndex, dt ); + Node test = utils::mkTester(n, consIndex, dt); Trace("dt-split") << "*************Split for possible constructor " << dt[consIndex] << " for " << n << endl; test = Rewriter::rewrite( test ); NodeBuilder<> nb(kind::OR); @@ -314,7 +313,7 @@ void TheoryDatatypes::check(Effort e) { d_out->requirePhase( test, true ); }else{ Trace("dt-split") << "*************Split for constructors on " << n << endl; - Node lemma = DatatypesRewriter::mkSplit(n, dt); + Node lemma = utils::mkSplit(n, dt); Trace("dt-split-debug") << "Split lemma is : " << lemma << std::endl; d_out->lemma( lemma, false, false, true ); d_addedLemma = true; @@ -488,7 +487,7 @@ void TheoryDatatypes::assertFact( Node fact, Node exp ){ } //add to tester if applicable Node t_arg; - int tindex = DatatypesRewriter::isTester( atom, t_arg ); + int tindex = utils::isTester(atom, t_arg); if (tindex >= 0) { Trace("dt-tester") << "Assert tester : " << atom << " for " << t_arg << std::endl; @@ -568,7 +567,7 @@ Node TheoryDatatypes::expandDefinition(LogicRequest &logicRequest, Node n) { TypeNode ndt = n[0].getType(); if (options::dtSharedSelectors()) { - size_t selectorIndex = DatatypesRewriter::indexOf(selector); + size_t selectorIndex = utils::indexOf(selector); Trace("dt-expand") << "...selector index = " << selectorIndex << std::endl; Assert(selectorIndex < c.getNumArgs()); @@ -678,9 +677,12 @@ Node TheoryDatatypes::ppRewrite(TNode in) if( in.getKind()==EQUAL ){ Node nn; std::vector< Node > rew; - if( DatatypesRewriter::checkClash(in[0], in[1], rew) ){ + if (utils::checkClash(in[0], in[1], rew)) + { nn = NodeManager::currentNM()->mkConst(false); - }else{ + } + else + { nn = rew.size()==0 ? d_true : ( rew.size()==1 ? rew[0] : NodeManager::currentNM()->mkNode( kind::AND, rew ) ); } @@ -839,14 +841,16 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){ Trace("datatypes-debug") << " constructors : " << cons1 << " " << cons2 << std::endl; Node unifEq = cons1.eqNode( cons2 ); std::vector< Node > rew; - if( DatatypesRewriter::checkClash( cons1, cons2, rew ) ){ + if (utils::checkClash(cons1, cons2, rew)) + { d_conflictNode = explain( unifEq ); Trace("dt-conflict") << "CONFLICT: Clash conflict : " << d_conflictNode << std::endl; d_out->conflict( d_conflictNode ); d_conflict = true; return; - }else{ - + } + else + { //do unification for( int i=0; i<(int)cons1.getNumChildren(); i++ ) { if( !areEqual( cons1[i], cons2[i] ) ){ @@ -964,13 +968,13 @@ Node TheoryDatatypes::getLabel( Node n ) { int TheoryDatatypes::getLabelIndex( EqcInfo* eqc, Node n ){ if( eqc && !eqc->d_constructor.get().isNull() ){ - return DatatypesRewriter::indexOf(eqc->d_constructor.get().getOperator()); + return utils::indexOf(eqc->d_constructor.get().getOperator()); }else{ Node lbl = getLabel( n ); if( lbl.isNull() ){ return -1; }else{ - int tindex = DatatypesRewriter::isTester( lbl ); + int tindex = utils::isTester(lbl); Assert( tindex!=-1 ); return tindex; } @@ -1118,7 +1122,7 @@ void TheoryDatatypes::addTester( { if( i!=ttindex && neg_testers.find( i )==neg_testers.end() ){ Assert( n.getKind()!=APPLY_CONSTRUCTOR ); - Node infer = DatatypesRewriter::mkTester( n, i, dt ).negate(); + Node infer = utils::mkTester(n, i, dt).negate(); Trace("datatypes-infer") << "DtInfer : neg label : " << infer << " by " << t << std::endl; d_infer.push_back( infer ); d_infer_exp.push_back( t ); @@ -1155,7 +1159,9 @@ void TheoryDatatypes::addTester( } } } - Node t_concl = testerIndex==-1 ? NodeManager::currentNM()->mkConst( false ) : DatatypesRewriter::mkTester( t_arg, testerIndex, dt ); + Node t_concl = testerIndex == -1 + ? NodeManager::currentNM()->mkConst(false) + : utils::mkTester(t_arg, testerIndex, dt); Node t_concl_exp = ( nb.getNumChildren() == 1 ) ? nb.getChild( 0 ) : nb; d_pending.push_back( t_concl ); d_pending_exp[ t_concl ] = t_concl_exp; @@ -1219,7 +1225,7 @@ void TheoryDatatypes::addConstructor( Node c, EqcInfo* eqc, Node n ){ //check labels NodeUIntMap::iterator lbl_i = d_labels.find(n); if( lbl_i != d_labels.end() ){ - size_t constructorIndex = DatatypesRewriter::indexOf(c.getOperator()); + size_t constructorIndex = utils::indexOf(c.getOperator()); size_t n_lbl = (*lbl_i).second; for (size_t i = 0; i < n_lbl; i++) { @@ -1306,7 +1312,7 @@ void TheoryDatatypes::collapseSelector( Node s, Node c ) { } if( s.getKind()==kind::APPLY_SELECTOR_TOTAL ){ Expr selectorExpr = s.getOperator().toExpr(); - size_t constructorIndex = DatatypesRewriter::indexOf(c.getOperator()); + size_t constructorIndex = utils::indexOf(c.getOperator()); const Datatype& dt = Datatype::datatypeOf(selectorExpr); const DatatypeConstructor& dtc = dt[constructorIndex]; int selectorIndex = dtc.getSelectorIndexInternal( selectorExpr ); @@ -1560,7 +1566,7 @@ bool TheoryDatatypes::collectModelInfo(TheoryModel* m) //must try the infinite ones first bool cfinite = dt[ i ].isInterpretedFinite( tt ); if( pcons[i] && (r==1)==cfinite ){ - neqc = DatatypesRewriter::getInstCons( eqc, dt, i ); + neqc = utils::getInstCons(eqc, dt, i); //for( unsigned j=0; j TheoryDatatypes::entailmentCheck(TNode lit, const Entailme Node r = d_equalityEngine.getRepresentative( n ); EqcInfo * ei = getOrMakeEqcInfo( r, false ); int l_index = getLabelIndex( ei, r ); - int t_index = - static_cast(DatatypesRewriter::indexOf(atom.getOperator())); + int t_index = static_cast(utils::indexOf(atom.getOperator())); Trace("dt-entail") << " Tester indices are " << t_index << " and " << l_index << std::endl; if( l_index!=-1 && (l_index==t_index)==pol ){ std::vector< TNode > exp_c; diff --git a/src/theory/datatypes/theory_datatypes_utils.cpp b/src/theory/datatypes/theory_datatypes_utils.cpp new file mode 100644 index 000000000..c3b145b15 --- /dev/null +++ b/src/theory/datatypes/theory_datatypes_utils.cpp @@ -0,0 +1,363 @@ +/********************* */ +/*! \file theory_datatypes_utils.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds, Morgan Deters, Paul Meng + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of rewriter for the theory of (co)inductive datatypes. + ** + ** Implementation of rewriter for the theory of (co)inductive datatypes. + **/ + +#include "theory/datatypes/theory_datatypes_utils.h" + +#include "expr/node_algorithm.h" + +using namespace CVC4; +using namespace CVC4::kind; + +namespace CVC4 { +namespace theory { +namespace datatypes { +namespace utils { + +Node applySygusArgs(const Datatype& dt, + Node op, + Node n, + const std::vector& args) +{ + if (n.getKind() == BOUND_VARIABLE) + { + Assert(n.hasAttribute(SygusVarNumAttribute())); + int vn = n.getAttribute(SygusVarNumAttribute()); + Assert(Node::fromExpr(dt.getSygusVarList())[vn] == n); + return args[vn]; + } + // n is an application of operator op. + // We must compute the free variables in op to determine if there are + // any substitutions we need to make to n. + TNode val; + if (!op.hasAttribute(SygusVarFreeAttribute())) + { + std::unordered_set fvs; + if (expr::getFreeVariables(op, fvs)) + { + if (fvs.size() == 1) + { + for (const Node& v : fvs) + { + val = v; + } + } + else + { + val = op; + } + } + Trace("dt-sygus-fv") << "Free var in " << op << " : " << val << std::endl; + op.setAttribute(SygusVarFreeAttribute(), val); + } + else + { + val = op.getAttribute(SygusVarFreeAttribute()); + } + if (val.isNull()) + { + return n; + } + if (val.getKind() == BOUND_VARIABLE) + { + // single substitution case + int vn = val.getAttribute(SygusVarNumAttribute()); + TNode sub = args[vn]; + return n.substitute(val, sub); + } + // do the full substitution + std::vector vars; + Node bvl = Node::fromExpr(dt.getSygusVarList()); + for (unsigned i = 0, nvars = bvl.getNumChildren(); i < nvars; i++) + { + vars.push_back(bvl[i]); + } + return n.substitute(vars.begin(), vars.end(), args.begin(), args.end()); +} + +Kind getOperatorKindForSygusBuiltin(Node op) +{ + Assert(op.getKind() != BUILTIN); + if (op.getKind() == LAMBDA) + { + return APPLY_UF; + } + TypeNode tn = op.getType(); + if (tn.isConstructor()) + { + return APPLY_CONSTRUCTOR; + } + else if (tn.isSelector()) + { + return APPLY_SELECTOR; + } + else if (tn.isTester()) + { + return APPLY_TESTER; + } + else if (tn.isFunction()) + { + return APPLY_UF; + } + return UNDEFINED_KIND; +} + +Node mkSygusTerm(const Datatype& dt, + unsigned i, + const std::vector& children) +{ + Trace("dt-sygus-util") << "Make sygus term " << dt.getName() << "[" << i + << "] with children: " << children << std::endl; + Assert(i < dt.getNumConstructors()); + Assert(dt.isSygus()); + Assert(!dt[i].getSygusOp().isNull()); + std::vector schildren; + Node op = Node::fromExpr(dt[i].getSygusOp()); + Trace("dt-sygus-util") << "Operator is " << op << std::endl; + if (children.empty()) + { + // no children, return immediately + Trace("dt-sygus-util") << "...return direct op" << std::endl; + return op; + } + // if it is the any constant, we simply return the child + if (op.getAttribute(SygusAnyConstAttribute())) + { + Assert(children.size() == 1); + return children[0]; + } + if (op.getKind() != BUILTIN) + { + schildren.push_back(op); + } + schildren.insert(schildren.end(), children.begin(), children.end()); + Node ret; + if (op.getKind() == BUILTIN) + { + ret = NodeManager::currentNM()->mkNode(op, schildren); + Trace("dt-sygus-util") << "...return (builtin) " << ret << std::endl; + return ret; + } + Kind ok = NodeManager::operatorToKind(op); + Trace("dt-sygus-util") << "operator kind is " << ok << std::endl; + if (ok != UNDEFINED_KIND) + { + // If it is an APPLY_UF operator, we should have at least an operator and + // a child. + Assert(ok != APPLY_UF || schildren.size() != 1); + ret = NodeManager::currentNM()->mkNode(ok, schildren); + Trace("dt-sygus-util") << "...return (op) " << ret << std::endl; + return ret; + } + Kind tok = getOperatorKindForSygusBuiltin(op); + if (schildren.size() == 1 && tok == UNDEFINED_KIND) + { + ret = schildren[0]; + } + else + { + ret = NodeManager::currentNM()->mkNode(tok, schildren); + } + Trace("dt-sygus-util") << "...return " << ret << std::endl; + return ret; +} + +/** get instantiate cons */ +Node getInstCons(Node n, const Datatype& dt, int index) +{ + Assert(index >= 0 && index < (int)dt.getNumConstructors()); + std::vector children; + NodeManager* nm = NodeManager::currentNM(); + children.push_back(Node::fromExpr(dt[index].getConstructor())); + Type t = n.getType().toType(); + for (unsigned i = 0, nargs = dt[index].getNumArgs(); i < nargs; i++) + { + Node nc = nm->mkNode(APPLY_SELECTOR_TOTAL, + Node::fromExpr(dt[index].getSelectorInternal(t, i)), + n); + children.push_back(nc); + } + Node n_ic = nm->mkNode(APPLY_CONSTRUCTOR, children); + if (dt.isParametric()) + { + TypeNode tn = TypeNode::fromType(t); + // add type ascription for ambiguous constructor types + if (!n_ic.getType().isComparableTo(tn)) + { + Debug("datatypes-parametric") + << "DtInstantiate: ambiguous type for " << n_ic << ", ascribe to " + << n.getType() << std::endl; + Debug("datatypes-parametric") + << "Constructor is " << dt[index] << std::endl; + Type tspec = + dt[index].getSpecializedConstructorType(n.getType().toType()); + Debug("datatypes-parametric") + << "Type specification is " << tspec << std::endl; + children[0] = nm->mkNode(APPLY_TYPE_ASCRIPTION, + nm->mkConst(AscriptionType(tspec)), + children[0]); + n_ic = nm->mkNode(APPLY_CONSTRUCTOR, children); + Assert(n_ic.getType() == tn); + } + } + Assert(isInstCons(n, n_ic, dt) == index); + // n_ic = Rewriter::rewrite( n_ic ); + return n_ic; +} + +int isInstCons(Node t, Node n, const Datatype& dt) +{ + if (n.getKind() == APPLY_CONSTRUCTOR) + { + int index = indexOf(n.getOperator()); + const DatatypeConstructor& c = dt[index]; + Type nt = n.getType().toType(); + for (unsigned i = 0, size = n.getNumChildren(); i < size; i++) + { + if (n[i].getKind() != APPLY_SELECTOR_TOTAL + || n[i].getOperator() != Node::fromExpr(c.getSelectorInternal(nt, i)) + || n[i][0] != t) + { + return -1; + } + } + return index; + } + return -1; +} + +int isTester(Node n, Node& a) +{ + if (n.getKind() == APPLY_TESTER) + { + a = n[0]; + return indexOf(n.getOperator()); + } + return -1; +} + +int isTester(Node n) +{ + if (n.getKind() == APPLY_TESTER) + { + return indexOf(n.getOperator()); + } + return -1; +} + +struct DtIndexAttributeId +{ +}; +typedef expr::Attribute DtIndexAttribute; + +unsigned indexOf(Node n) +{ + if (!n.hasAttribute(DtIndexAttribute())) + { + Assert(n.getType().isConstructor() || n.getType().isTester() + || n.getType().isSelector()); + unsigned index = Datatype::indexOfInternal(n.toExpr()); + n.setAttribute(DtIndexAttribute(), index); + return index; + } + return n.getAttribute(DtIndexAttribute()); +} + +Node mkTester(Node n, int i, const Datatype& dt) +{ + return NodeManager::currentNM()->mkNode( + APPLY_TESTER, Node::fromExpr(dt[i].getTester()), n); +} + +Node mkSplit(Node n, const Datatype& dt) +{ + std::vector splits; + for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) + { + Node test = mkTester(n, i, dt); + splits.push_back(test); + } + NodeManager* nm = NodeManager::currentNM(); + return splits.size() == 1 ? splits[0] : nm->mkNode(OR, splits); +} + +bool isNullaryApplyConstructor(Node n) +{ + Assert(n.getKind() == APPLY_CONSTRUCTOR); + for (const Node& nc : n) + { + if (nc.getType().isDatatype()) + { + return false; + } + } + return true; +} + +bool isNullaryConstructor(const DatatypeConstructor& c) +{ + for (unsigned j = 0, nargs = c.getNumArgs(); j < nargs; j++) + { + if (c[j].getType().getRangeType().isDatatype()) + { + return false; + } + } + return true; +} + +bool checkClash(Node n1, Node n2, std::vector& rew) +{ + Trace("datatypes-rewrite-debug") + << "Check clash : " << n1 << " " << n2 << std::endl; + if (n1.getKind() == APPLY_CONSTRUCTOR && n2.getKind() == APPLY_CONSTRUCTOR) + { + if (n1.getOperator() != n2.getOperator()) + { + Trace("datatypes-rewrite-debug") + << "Clash operators : " << n1 << " " << n2 << " " << n1.getOperator() + << " " << n2.getOperator() << std::endl; + return true; + } + Assert(n1.getNumChildren() == n2.getNumChildren()); + for (unsigned i = 0, size = n1.getNumChildren(); i < size; i++) + { + if (checkClash(n1[i], n2[i], rew)) + { + return true; + } + } + } + else if (n1 != n2) + { + if (n1.isConst() && n2.isConst()) + { + Trace("datatypes-rewrite-debug") + << "Clash constants : " << n1 << " " << n2 << std::endl; + return true; + } + else + { + Node eq = NodeManager::currentNM()->mkNode(EQUAL, n1, n2); + rew.push_back(eq); + } + } + return false; +} + +} // namespace utils +} // namespace datatypes +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/datatypes/theory_datatypes_utils.h b/src/theory/datatypes/theory_datatypes_utils.h new file mode 100644 index 000000000..ba0643567 --- /dev/null +++ b/src/theory/datatypes/theory_datatypes_utils.h @@ -0,0 +1,199 @@ +/********************* */ +/*! \file theory_datatypes_utils.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Util functions for theory datatypes. + ** + ** Util functions for theory datatypes. + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__THEORY__STRINGS__THEORY_DATATYPES_UTILS_H +#define CVC4__THEORY__STRINGS__THEORY_DATATYPES_UTILS_H + +#include + +#include "expr/node.h" +#include "expr/node_manager_attributes.h" + +namespace CVC4 { +namespace theory { + +// ----------------------- sygus datatype attributes +/** sygus var num */ +struct SygusVarNumAttributeId +{ +}; +typedef expr::Attribute SygusVarNumAttribute; + +/** Attribute true for variables that represent any constant */ +struct SygusAnyConstAttributeId +{ +}; +typedef expr::Attribute SygusAnyConstAttribute; + +/** + * Attribute true for enumerators whose current model values were registered by + * the datatypes sygus solver, and were not excluded by sygus symmetry breaking. + * This is set by the datatypes sygus solver during LAST_CALL effort checks for + * each active sygus enumerator. + */ +struct SygusSymBreakOkAttributeId +{ +}; +typedef expr::Attribute + SygusSymBreakOkAttribute; + +/** sygus var free + * + * This attribute is used to mark whether sygus operators have free occurrences + * of variables from the formal argument list of the function-to-synthesize. + * + * We store three possible cases for sygus operators op: + * (1) op.getAttribute(SygusVarFreeAttribute())==Node::null() + * In this case, op has no free variables from the formal argument list of the + * function-to-synthesize. + * (2) op.getAttribute(SygusVarFreeAttribute())==v, where v is a bound variable. + * In this case, op has exactly one free variable, v. + * (3) op.getAttribute(SygusVarFreeAttribute())==op + * In this case, op has an arbitrary set (cardinality >1) of free variables from + * the formal argument list of the function to synthesize. + * + * This attribute is used to compute applySygusArgs below. + */ +struct SygusVarFreeAttributeId +{ +}; +typedef expr::Attribute SygusVarFreeAttribute; +// ----------------------- end sygus datatype attributes + +namespace datatypes { +namespace utils { + +/** get instantiate cons + * + * This returns the term C( sel^{C,1}( n ), ..., sel^{C,m}( n ) ), + * where C is the index^{th} constructor of datatype dt. + */ +Node getInstCons(Node n, const Datatype& dt, int index); +/** is instantiation cons + * + * If this method returns a value >=0, then that value, call it index, + * is such that n = C( sel^{C,1}( t ), ..., sel^{C,m}( t ) ), + * where C is the index^{th} constructor of dt. + */ +int isInstCons(Node t, Node n, const Datatype& dt); +/** is tester + * + * This method returns a value >=0 if n is a tester predicate. The return + * value indicates the constructor index that the tester n is for. If this + * method returns a value >=0, then it updates a to the argument that the + * tester n applies to. + */ +int isTester(Node n, Node& a); +/** is tester, same as above but does not update an argument */ +int isTester(Node n); +/** + * Get the index of a constructor or tester in its datatype, or the + * index of a selector in its constructor. (Zero is always the + * first index.) + */ +unsigned indexOf(Node n); +/** make tester is-C( n ), where C is the i^{th} constructor of dt */ +Node mkTester(Node n, int i, const Datatype& dt); +/** make tester split + * + * Returns the formula (OR is-C1( n ) ... is-Ck( n ) ), where C1...Ck + * are the constructors of n's type (dt). + */ +Node mkSplit(Node n, const Datatype& dt); +/** returns true iff n is a constructor term with no datatype children */ +bool isNullaryApplyConstructor(Node n); +/** returns true iff c is a constructor with no datatype children */ +bool isNullaryConstructor(const DatatypeConstructor& c); +/** check clash + * + * This method returns true if and only if n1 and n2 have a skeleton that has + * conflicting constructors at some term position. + * Examples of terms that clash are: + * C( x, y ) and D( z ) + * C( D( x ), y ) and C( E( x ), y ) + * Examples of terms that do not clash are: + * C( x, y ) and C( D( x ), y ) + * C( D( x ), y ) and C( x, E( z ) ) + * C( x, y ) and z + */ +bool checkClash(Node n1, Node n2, std::vector& rew); + +// ------------------------ sygus utils + +/** get operator kind for sygus builtin + * + * This returns the Kind corresponding to applications of the operator op + * when building the builtin version of sygus terms. This is used by the + * function mkSygusTerm. + */ +Kind getOperatorKindForSygusBuiltin(Node op); +/** make sygus term + * + * This function returns a builtin term f( children[0], ..., children[n] ) + * where f is the builtin op that the i^th constructor of sygus datatype dt + * encodes. + */ +Node mkSygusTerm(const Datatype& dt, + unsigned i, + const std::vector& children); +/** + * n is a builtin term that is an application of operator op. + * + * This returns an n' such that (eval n args) is n', where n' is a instance of + * n for the appropriate substitution. + * + * For example, given a function-to-synthesize with formal argument list (x,y), + * say we have grammar: + * A -> A+A | A+x | A+(x+y) | y + * These lead to constructors with sygus ops: + * C1 / (lambda w1 w2. w1+w2) + * C2 / (lambda w1. w1+x) + * C3 / (lambda w1. w1+(x+y)) + * C4 / y + * Examples of calling this function: + * applySygusArgs( dt, C1, (APPLY_UF (lambda w1 w2. w1+w2) t1 t2), { 3, 5 } ) + * ... returns (APPLY_UF (lambda w1 w2. w1+w2) t1 t2). + * applySygusArgs( dt, C2, (APPLY_UF (lambda w1. w1+x) t1), { 3, 5 } ) + * ... returns (APPLY_UF (lambda w1. w1+3) t1). + * applySygusArgs( dt, C3, (APPLY_UF (lambda w1. w1+(x+y)) t1), { 3, 5 } ) + * ... returns (APPLY_UF (lambda w1. w1+(3+5)) t1). + * applySygusArgs( dt, C4, y, { 3, 5 } ) + * ... returns 5. + * Notice the attribute SygusVarFreeAttribute is applied to C1, C2, C3, C4, + * to cache the results of whether the evaluation of this constructor needs + * a substitution over the formal argument list of the function-to-synthesize. + */ +Node applySygusArgs(const Datatype& dt, + Node op, + Node n, + const std::vector& args); +/** + * Get the builtin sygus operator for constructor term n of sygus datatype + * type. For example, if n is the term C_+( d1, d2 ) where C_+ is a sygus + * constructor whose sygus op is the builtin operator +, this method returns +. + */ +Node getSygusOpForCTerm(Node n); + +// ------------------------ end sygus utils + +} // namespace utils +} // namespace datatypes +} // namespace theory +} // namespace CVC4 + +#endif diff --git a/src/theory/datatypes/type_enumerator.cpp b/src/theory/datatypes/type_enumerator.cpp index 609106b46..023ade00d 100644 --- a/src/theory/datatypes/type_enumerator.cpp +++ b/src/theory/datatypes/type_enumerator.cpp @@ -14,8 +14,9 @@ ** Enumerators for datatypes. **/ - #include "theory/datatypes/type_enumerator.h" - #include "theory/datatypes/datatypes_rewriter.h" +#include "theory/datatypes/type_enumerator.h" +#include "theory/datatypes/datatypes_rewriter.h" +#include "theory/datatypes/theory_datatypes_utils.h" using namespace CVC4; using namespace theory; @@ -187,7 +188,7 @@ Node DatatypesEnumerator::getTermEnum( TypeNode tn, unsigned i ){ Debug("dt-enum-debug") << "done : " << t << std::endl; Assert( t.getKind()==kind::APPLY_CONSTRUCTOR ); // start with the constructor for which a ground term is constructed - d_zeroCtor = datatypes::DatatypesRewriter::indexOf(t.getOperator()); + d_zeroCtor = datatypes::utils::indexOf(t.getOperator()); d_has_debruijn = 0; } Debug("dt-enum") << "zero ctor : " << d_zeroCtor << std::endl; diff --git a/src/theory/quantifiers/sygus/sygus_enumerator.cpp b/src/theory/quantifiers/sygus/sygus_enumerator.cpp index bd85ea496..472a82e29 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator.cpp +++ b/src/theory/quantifiers/sygus/sygus_enumerator.cpp @@ -16,7 +16,7 @@ #include "options/datatypes_options.h" #include "options/quantifiers_options.h" -#include "theory/datatypes/datatypes_rewriter.h" +#include "theory/datatypes/theory_datatypes_utils.h" using namespace CVC4::kind; @@ -81,7 +81,7 @@ void SygusEnumerator::initialize(Node e) if (sbl.getKind() == NOT) { Node a; - int tst = datatypes::DatatypesRewriter::isTester(sbl[0], a); + int tst = datatypes::utils::isTester(sbl[0], a); if (tst >= 0) { if (a == e) diff --git a/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp b/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp index 7324add50..5286ab6f7 100644 --- a/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp +++ b/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp @@ -15,7 +15,6 @@ #include "theory/quantifiers/sygus/sygus_eval_unfold.h" #include "options/quantifiers_options.h" -#include "theory/datatypes/datatypes_rewriter.h" #include "theory/quantifiers/sygus/term_database_sygus.h" using namespace std; diff --git a/src/theory/quantifiers/sygus/sygus_explain.cpp b/src/theory/quantifiers/sygus/sygus_explain.cpp index f55ce2097..b1baed9cb 100644 --- a/src/theory/quantifiers/sygus/sygus_explain.cpp +++ b/src/theory/quantifiers/sygus/sygus_explain.cpp @@ -14,7 +14,7 @@ #include "theory/quantifiers/sygus/sygus_explain.h" -#include "theory/datatypes/datatypes_rewriter.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/sygus/term_database_sygus.h" using namespace CVC4::kind; @@ -139,8 +139,8 @@ void SygusExplain::getExplanationForEquality(Node n, } Assert(vn.getKind() == kind::APPLY_CONSTRUCTOR); const Datatype& dt = ((DatatypeType)tn.toType()).getDatatype(); - int i = datatypes::DatatypesRewriter::indexOf(vn.getOperator()); - Node tst = datatypes::DatatypesRewriter::mkTester(n, i, dt); + int i = datatypes::utils::indexOf(vn.getOperator()); + Node tst = datatypes::utils::mkTester(n, i, dt); exp.push_back(tst); for (unsigned j = 0; j < vn.getNumChildren(); j++) { @@ -223,9 +223,9 @@ void SygusExplain::getExplanationFor(TermRecBuild& trb, } } const Datatype& dt = ((DatatypeType)ntn.toType()).getDatatype(); - int cindex = datatypes::DatatypesRewriter::indexOf(vn.getOperator()); + int cindex = datatypes::utils::indexOf(vn.getOperator()); Assert(cindex >= 0 && cindex < (int)dt.getNumConstructors()); - Node tst = datatypes::DatatypesRewriter::mkTester(n, cindex, dt); + Node tst = datatypes::utils::mkTester(n, cindex, dt); exp.push_back(tst); // if the operator of vn is different than vnr, then disunification obligation // is met diff --git a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp index 43696bff0..0dc49fa96 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp @@ -20,7 +20,6 @@ #include "options/quantifiers_options.h" #include "printer/sygus_print_callback.h" #include "theory/bv/theory_bv_utils.h" -#include "theory/datatypes/datatypes_rewriter.h" #include "theory/quantifiers/sygus/sygus_grammar_norm.h" #include "theory/quantifiers/sygus/sygus_process_conj.h" #include "theory/quantifiers/sygus/synth_conjecture.h" diff --git a/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp b/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp index ebb92b34b..5e8c9c411 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp @@ -20,7 +20,7 @@ #include "printer/sygus_print_callback.h" #include "smt/smt_engine.h" #include "smt/smt_engine_scope.h" -#include "theory/datatypes/datatypes_rewriter.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/cegqi/ceg_instantiator.h" #include "theory/quantifiers/sygus/sygus_grammar_red.h" #include "theory/quantifiers/sygus/term_database_sygus.h" diff --git a/src/theory/quantifiers/sygus/sygus_pbe.cpp b/src/theory/quantifiers/sygus/sygus_pbe.cpp index c76082b02..64bf0972c 100644 --- a/src/theory/quantifiers/sygus/sygus_pbe.cpp +++ b/src/theory/quantifiers/sygus/sygus_pbe.cpp @@ -16,7 +16,6 @@ #include "expr/datatype.h" #include "options/quantifiers_options.h" -#include "theory/datatypes/datatypes_rewriter.h" #include "theory/quantifiers/sygus/synth_conjecture.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_util.h" diff --git a/src/theory/quantifiers/sygus/sygus_process_conj.cpp b/src/theory/quantifiers/sygus/sygus_process_conj.cpp index 2b9592d4d..66e80523a 100644 --- a/src/theory/quantifiers/sygus/sygus_process_conj.cpp +++ b/src/theory/quantifiers/sygus/sygus_process_conj.cpp @@ -18,7 +18,6 @@ #include "expr/datatype.h" #include "options/quantifiers_options.h" -#include "theory/datatypes/datatypes_rewriter.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_util.h" diff --git a/src/theory/quantifiers/sygus/sygus_repair_const.cpp b/src/theory/quantifiers/sygus/sygus_repair_const.cpp index 39506b593..5511adb18 100644 --- a/src/theory/quantifiers/sygus/sygus_repair_const.cpp +++ b/src/theory/quantifiers/sygus/sygus_repair_const.cpp @@ -20,7 +20,7 @@ #include "smt/smt_engine.h" #include "smt/smt_engine_scope.h" #include "smt/smt_statistics_registry.h" -#include "theory/datatypes/datatypes_rewriter.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/cegqi/ceg_instantiator.h" #include "theory/quantifiers/sygus/sygus_grammar_norm.h" #include "theory/quantifiers/sygus/term_database_sygus.h" @@ -372,7 +372,7 @@ bool SygusRepairConst::isRepairable(Node n, bool useConstantsAsHoles) return false; } Node op = n.getOperator(); - unsigned cindex = datatypes::DatatypesRewriter::indexOf(op); + unsigned cindex = datatypes::utils::indexOf(op); Node sygusOp = Node::fromExpr(dt[cindex].getSygusOp()); if (sygusOp.getAttribute(SygusAnyConstAttribute())) { diff --git a/src/theory/quantifiers/sygus/sygus_unif.cpp b/src/theory/quantifiers/sygus/sygus_unif.cpp index 008947adb..fdc8120ff 100644 --- a/src/theory/quantifiers/sygus/sygus_unif.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif.cpp @@ -14,7 +14,6 @@ #include "theory/quantifiers/sygus/sygus_unif.h" -#include "theory/datatypes/datatypes_rewriter.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_util.h" #include "theory/quantifiers_engine.h" diff --git a/src/theory/quantifiers/sygus/sygus_unif_io.cpp b/src/theory/quantifiers/sygus/sygus_unif_io.cpp index 207aa4c8e..ff58dbe38 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_io.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif_io.cpp @@ -15,7 +15,6 @@ #include "theory/quantifiers/sygus/sygus_unif_io.h" #include "options/quantifiers_options.h" -#include "theory/datatypes/datatypes_rewriter.h" #include "theory/evaluator.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_util.h" diff --git a/src/theory/quantifiers/sygus/sygus_unif_rl.cpp b/src/theory/quantifiers/sygus/sygus_unif_rl.cpp index 3514ccbeb..3f09a4346 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_rl.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif_rl.cpp @@ -17,7 +17,6 @@ #include "options/base_options.h" #include "options/quantifiers_options.h" #include "printer/printer.h" -#include "theory/datatypes/datatypes_rewriter.h" #include "theory/quantifiers/sygus/synth_conjecture.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "util/random.h" diff --git a/src/theory/quantifiers/sygus/sygus_unif_strat.cpp b/src/theory/quantifiers/sygus/sygus_unif_strat.cpp index a41d895b3..e74068ace 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_strat.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif_strat.cpp @@ -14,7 +14,7 @@ #include "theory/quantifiers/sygus/sygus_unif_strat.h" -#include "theory/datatypes/datatypes_rewriter.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/sygus/sygus_unif.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_util.h" @@ -735,8 +735,7 @@ void SygusUnifStrategy::staticLearnRedundantOps( Assert(nc.first < dt.getNumConstructors()); if (!nc.second) { - Node tst = - datatypes::DatatypesRewriter::mkTester(em, nc.first, dt).negate(); + Node tst = datatypes::utils::mkTester(em, nc.first, dt).negate(); if (std::find(lemmas.begin(), lemmas.end(), tst) == lemmas.end()) { @@ -802,7 +801,7 @@ void SygusUnifStrategy::staticLearnRedundantOps( continue; } EnumTypeInfoStrat* etis = snode.d_strats[j]; - unsigned cindex = datatypes::DatatypesRewriter::indexOf(etis->d_cons); + unsigned cindex = datatypes::utils::indexOf(etis->d_cons); // constructors that correspond to strategies are not needed // the intuition is that the strategy itself is responsible for constructing // all terms that use the given constructor diff --git a/src/theory/quantifiers/sygus/synth_conjecture.cpp b/src/theory/quantifiers/sygus/synth_conjecture.cpp index e2a8540d4..78f2f6a7e 100644 --- a/src/theory/quantifiers/sygus/synth_conjecture.cpp +++ b/src/theory/quantifiers/sygus/synth_conjecture.cpp @@ -23,7 +23,7 @@ #include "smt/smt_engine.h" #include "smt/smt_engine_scope.h" #include "smt/smt_statistics_registry.h" -#include "theory/datatypes/datatypes_rewriter.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/first_order_model.h" #include "theory/quantifiers/instantiate.h" #include "theory/quantifiers/quantifiers_attributes.h" diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index ed3eec145..ff9fede0b 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -16,10 +16,11 @@ #include "base/cvc4_check.h" #include "options/base_options.h" +#include "options/datatypes_options.h" #include "options/quantifiers_options.h" #include "printer/printer.h" #include "theory/arith/arith_msum.h" -#include "theory/datatypes/datatypes_rewriter.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/quantifiers_attributes.h" #include "theory/quantifiers/term_database.h" #include "theory/quantifiers/term_util.h" @@ -184,7 +185,7 @@ Node TermDbSygus::mkGeneric(const Datatype& dt, Assert( !a.isNull() ); children.push_back( a ); } - return datatypes::DatatypesRewriter::mkSygusTerm(dt, c, children); + return datatypes::utils::mkSygusTerm(dt, c, children); } Node TermDbSygus::mkGeneric(const Datatype& dt, int c, std::map& pre) @@ -286,7 +287,7 @@ Node TermDbSygus::sygusToBuiltin(Node n, TypeNode tn) } if (n.getKind() == APPLY_CONSTRUCTOR) { - unsigned i = datatypes::DatatypesRewriter::indexOf(n.getOperator()); + unsigned i = datatypes::utils::indexOf(n.getOperator()); Assert(n.getNumChildren() == dt[i].getNumArgs()); std::map pre; for (unsigned j = 0, size = n.getNumChildren(); j < size; j++) @@ -325,7 +326,7 @@ unsigned TermDbSygus::getSygusTermSize( Node n ){ sum += getSygusTermSize(n[i]); } const Datatype& dt = Datatype::datatypeOf(n.getOperator().toExpr()); - int cindex = datatypes::DatatypesRewriter::indexOf(n.getOperator()); + int cindex = datatypes::utils::indexOf(n.getOperator()); Assert(cindex >= 0 && cindex < (int)dt.getNumConstructors()); unsigned weight = dt[cindex].getWeight(); return weight + sum; @@ -422,7 +423,7 @@ void TermDbSygus::registerEnumerator(Node e, // is necessary to generate a term of the form any_constant( x.0 ) for a // fresh variable x.0. Node fv = getFreeVar(stn, 0); - Node exc_val = datatypes::DatatypesRewriter::getInstCons(fv, dt, rindex); + Node exc_val = datatypes::utils::getInstCons(fv, dt, rindex); // should not include the constuctor in any subterm Node x = getFreeVar(stn, 0); Trace("sygus-db") << "Construct symmetry breaking lemma from " << x @@ -776,7 +777,7 @@ bool TermDbSygus::isSymbolicConsApp(Node n) const Assert(tn.isDatatype()); const Datatype& dt = static_cast(tn.toType()).getDatatype(); Assert(dt.isSygus()); - unsigned cindex = datatypes::DatatypesRewriter::indexOf(n.getOperator()); + unsigned cindex = datatypes::utils::indexOf(n.getOperator()); Node sygusOp = Node::fromExpr(dt[cindex].getSygusOp()); // it is symbolic if it represents "any constant" return sygusOp.getAttribute(SygusAnyConstAttribute()); @@ -948,7 +949,7 @@ Node TermDbSygus::unfold( Node en, std::map< Node, Node >& vtm, std::vector< Nod Type headType = en[0].getType().toType(); NodeManager* nm = NodeManager::currentNM(); const Datatype& dt = static_cast(headType).getDatatype(); - unsigned i = datatypes::DatatypesRewriter::indexOf(ev.getOperator()); + unsigned i = datatypes::utils::indexOf(ev.getOperator()); if (track_exp) { // explanation @@ -1007,7 +1008,7 @@ Node TermDbSygus::unfold( Node en, std::map< Node, Node >& vtm, std::vector< Nod } Node ret = mkGeneric(dt, i, pre); // apply the appropriate substitution to ret - ret = datatypes::DatatypesRewriter::applySygusArgs(dt, sop, ret, args); + ret = datatypes::utils::applySygusArgs(dt, sop, ret, args); // rewrite ret = Rewriter::rewrite(ret); return ret; diff --git a/src/theory/quantifiers/sygus/type_info.cpp b/src/theory/quantifiers/sygus/type_info.cpp index 070e2ad9a..818a53711 100644 --- a/src/theory/quantifiers/sygus/type_info.cpp +++ b/src/theory/quantifiers/sygus/type_info.cpp @@ -15,7 +15,7 @@ #include "theory/quantifiers/sygus/type_info.h" #include "base/cvc4_check.h" -#include "theory/datatypes/datatypes_rewriter.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/sygus/term_database_sygus.h" using namespace CVC4::kind;